Added support for Cohere Embed and Titan Embeddings models

Joao Galego
2024-03-27 17:27:23 +00:00
parent a26b8e6833
commit b24a43f6f4
7 changed files with 242 additions and 5 deletions
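To illustrate what this enables, here is a minimal client-side sketch, assuming the gateway keeps its OpenAI-compatible surface for the new embeddings route (the base URL, API key, and route are illustrative assumptions, not taken from this diff):

    # Hypothetical client usage; base_url and api_key are placeholders.
    from openai import OpenAI

    client = OpenAI(base_url="http://localhost:8000/api/v1", api_key="placeholder")

    # Cohere Embed accepts a list of texts; Titan accepts a single string.
    resp = client.embeddings.create(
        model="cohere.embed-multilingual-v3",
        input=["Hello, world!", "Bonjour, le monde!"],
    )
    print(len(resp.data), len(resp.data[0].embedding))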


@@ -7,6 +7,7 @@ from typing import AsyncIterable
import boto3
from api.schema import (
    # Chat
    ChatResponse,
    ChatRequest,
    ChatRequestMessage,
@@ -15,6 +16,11 @@ from api.schema import (
    Usage,
    ChatStreamResponse,
    ChoiceDelta,
    # Embeddings
    EmbeddingsRequest,
    EmbeddingsResponse,
    EmbeddingsUsage,
    Embedding,
)
from api.setting import DEBUG, AWS_REGION
@@ -37,6 +43,12 @@ SUPPORTED_BEDROCK_MODELS = {
"mistral.mixtral-8x7b-instruct-v0:1": "Mixtral 8x7B Instruct",
}

SUPPORTED_BEDROCK_EMBEDDING_MODELS = {
    "cohere.embed-multilingual-v3": "Cohere Embed Multilingual",
    "cohere.embed-english-v3": "Cohere Embed English",
    "amazon.titan-embed-text-v1": "Titan Embeddings G1 - Text",
    "amazon.titan-embed-image-v1": "Titan Multimodal Embeddings G1",
}


class BaseChatModel(ABC):
    """Represent a basic chat model
@@ -389,3 +401,136 @@ class MistralModel(BedrockModel):
                    finish_reason=chunk["outputs"][0]["stop_reason"],
                )
                yield self._stream_response_to_bytes(response)


class BaseEmbeddingsModel(ABC):
    """Represents a basic embeddings model.

    Currently, only Bedrock-provided models are supported,
    but it may be used for SageMaker models if needed.
    """

    @abstractmethod
    def embed(self, embeddings_request: EmbeddingsRequest) -> EmbeddingsResponse:
        """Handle a basic embeddings request."""
        pass

    def _generate_message_id(self) -> str:
        return "embeddings-" + str(uuid.uuid4())[:8]


class BedrockEmbeddingsModel(BaseEmbeddingsModel):
    accept = "application/json"
    content_type = "application/json"

    def _invoke_model(self, args: dict, model_id: str, with_stream: bool = False):
        body = json.dumps(args)
        if DEBUG:
            logger.info("Invoke Bedrock Model: " + model_id)
            logger.info("Bedrock request body: " + body)
        if with_stream:
            return bedrock_runtime.invoke_model_with_response_stream(
                body=body,
                modelId=model_id,
                accept=self.accept,
                contentType=self.content_type,
            )
        return bedrock_runtime.invoke_model(
            body=body,
            modelId=model_id,
            accept=self.accept,
            contentType=self.content_type,
        )

    def _create_response(
        self,
        embeddings: list[list[float]],
        model: str,
        input_tokens: int = 0,
        output_tokens: int = 0,
    ) -> EmbeddingsResponse:
        data = [
            Embedding(index=i, embedding=embedding)
            for i, embedding in enumerate(embeddings)
        ]
        response = EmbeddingsResponse(
            data=data,
            model=model,
            usage=EmbeddingsUsage(
                prompt_tokens=input_tokens,
                total_tokens=input_tokens + output_tokens,
            ),
        )
        if DEBUG:
            logger.info("Proxy response: " + response.model_dump_json())
        return response
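
    # --- Illustrative note (not part of this commit) ---
    # Serialized with model_dump_json(), the response follows OpenAI-style
    # embeddings output, roughly (assuming the pydantic schema mirrors the
    # keyword arguments above):
    #   {"data": [{"index": 0, "embedding": [...]}],
    #    "model": "cohere.embed-english-v3",
    #    "usage": {"prompt_tokens": 3, "total_tokens": 3}}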


def get_embeddings_model(model_id: str) -> BedrockEmbeddingsModel:
    model_name = SUPPORTED_BEDROCK_EMBEDDING_MODELS.get(model_id, "")
    if DEBUG:
        logger.info("model name is " + model_name)
    if model_name in ["Cohere Embed Multilingual", "Cohere Embed English"]:
        return CohereEmbeddingsModel()
    elif model_name in ["Titan Embeddings G1 - Text", "Titan Multimodal Embeddings G1"]:
        return TitanEmbeddingsModel()
    else:
        logger.error("Unsupported model id " + model_id)
        raise ValueError("Invalid model ID")


class CohereEmbeddingsModel(BedrockEmbeddingsModel):
    def _parse_args(self, embeddings_request: EmbeddingsRequest) -> dict:
        args = {
            "texts": embeddings_request.input,
            "input_type": embeddings_request.input_type or "search_document",
            "truncate": embeddings_request.truncate or "NONE",
        }
        return args

    def embed(self, embeddings_request: EmbeddingsRequest) -> EmbeddingsResponse:
        response = self._invoke_model(
            args=self._parse_args(embeddings_request), model_id=embeddings_request.model
        )
        response_body = json.loads(response.get("body").read())
        if DEBUG:
            logger.info("Bedrock response body: " + str(response_body))
        return self._create_response(
            embeddings=response_body["embeddings"],
            model=embeddings_request.model,
        )
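
    # --- Illustrative note (not part of this commit) ---
    # The request body serialized by _invoke_model looks like:
    #   {"texts": ["hello"], "input_type": "search_document", "truncate": "NONE"}
    # and Bedrock's Cohere response carries one vector per input text under
    # "embeddings", which is why embed() forwards response_body["embeddings"]
    # to _create_response unchanged.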


class TitanEmbeddingsModel(BedrockEmbeddingsModel):
    def _parse_args(self, embeddings_request: EmbeddingsRequest) -> dict:
        if isinstance(embeddings_request.input, str):
            input_text = embeddings_request.input
        elif isinstance(embeddings_request.input, list) and len(embeddings_request.input) == 1:
            input_text = embeddings_request.input[0]
        else:
            raise ValueError("Amazon Titan Embeddings models support only single strings as input.")
        args = {
            "inputText": input_text,
            # Note: inputImage is not supported!
        }
        if embeddings_request.model == "amazon.titan-embed-image-v1":
            args["embeddingConfig"] = (
                embeddings_request.embedding_config or {"outputEmbeddingLength": 1024}
            )
        return args

    def embed(self, embeddings_request: EmbeddingsRequest) -> EmbeddingsResponse:
        response = self._invoke_model(
            args=self._parse_args(embeddings_request), model_id=embeddings_request.model
        )
        response_body = json.loads(response.get("body").read())
        if DEBUG:
            logger.info("Bedrock response body: " + str(response_body))
        return self._create_response(
            embeddings=[response_body["embedding"]],
            model=embeddings_request.model,
            input_tokens=response_body["inputTextTokenCount"],
        )
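
    # --- Illustrative note (not part of this commit) ---
    # Titan takes a single string, so the serialized request is:
    #   {"inputText": "hello"}
    # and the response holds one vector plus a token count:
    #   {"embedding": [0.012, ...], "inputTextTokenCount": 3}
    # hence the one-element list around response_body["embedding"] and the
    # input_tokens forwarded to _create_response.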