Added support for Cohere Embed and Titan Embeddings models
This commit is contained in:
@@ -7,6 +7,7 @@ from typing import AsyncIterable
|
||||
import boto3
|
||||
|
||||
from api.schema import (
|
||||
# Chat
|
||||
ChatResponse,
|
||||
ChatRequest,
|
||||
ChatRequestMessage,
|
||||
@@ -15,6 +16,11 @@ from api.schema import (
|
||||
Usage,
|
||||
ChatStreamResponse,
|
||||
ChoiceDelta,
|
||||
# Embeddings
|
||||
EmbeddingsRequest,
|
||||
EmbeddingsResponse,
|
||||
EmbeddingsUsage,
|
||||
Embedding,
|
||||
)
|
||||
from api.setting import DEBUG, AWS_REGION
|
||||
|
||||
@@ -37,6 +43,12 @@ SUPPORTED_BEDROCK_MODELS = {
|
||||
"mistral.mixtral-8x7b-instruct-v0:1": "Mixtral 8x7B Instruct",
|
||||
}
|
||||
|
||||
SUPPORTED_BEDROCK_EMBEDDING_MODELS = {
|
||||
"cohere.embed-multilingual-v3": "Cohere Embed Multilingual",
|
||||
"cohere.embed-english-v3": "Cohere Embed English",
|
||||
"amazon.titan-embed-text-v1": "Titan Embeddings G1 - Text",
|
||||
"amazon.titan-embed-image-v1": "Titan Multimodal Embeddings G1"
|
||||
}
|
||||
|
||||
class BaseChatModel(ABC):
|
||||
"""Represent a basic chat model
|
||||
@@ -389,3 +401,136 @@ class MistralModel(BedrockModel):
|
||||
finish_reason=chunk["outputs"][0]["stop_reason"],
|
||||
)
|
||||
yield self._stream_response_to_bytes(response)
|
||||
|
||||
|
||||
class BaseEmbeddingsModel(ABC):
|
||||
"""Represents a basic embeddings model.
|
||||
|
||||
Currently, only Bedrock-provided models are supported, but it may be used for SageMaker models if needed.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def embed(self, embeddings_request: EmbeddingsRequest) -> EmbeddingsResponse:
|
||||
"""Handle a basic embeddings request."""
|
||||
pass
|
||||
|
||||
def _generate_message_id(self) -> str:
|
||||
return "embeddings-" + str(uuid.uuid4())[:8]
|
||||
|
||||
|
||||
class BedrockEmbeddingsModel(BaseEmbeddingsModel):
|
||||
accept = "application/json"
|
||||
content_type = "application/json"
|
||||
|
||||
def _invoke_model(self, args: dict, model_id: str, with_stream: bool = False):
|
||||
body = json.dumps(args)
|
||||
if DEBUG:
|
||||
logger.info("Invoke Bedrock Model: " + model_id)
|
||||
logger.info("Bedrock request body: " + body)
|
||||
if with_stream:
|
||||
return bedrock_runtime.invoke_model_with_response_stream(
|
||||
body=body,
|
||||
modelId=model_id,
|
||||
accept=self.accept,
|
||||
contentType=self.content_type,
|
||||
)
|
||||
return bedrock_runtime.invoke_model(
|
||||
body=body,
|
||||
modelId=model_id,
|
||||
accept=self.accept,
|
||||
contentType=self.content_type,
|
||||
)
|
||||
|
||||
def _create_response(
|
||||
self,
|
||||
embeddings: list[float],
|
||||
model: str,
|
||||
input_tokens: int = 0,
|
||||
output_tokens: int = 0,
|
||||
) -> EmbeddingsResponse:
|
||||
data = [
|
||||
Embedding(
|
||||
index=i,
|
||||
embedding=embedding
|
||||
) for i, embedding in enumerate(embeddings)
|
||||
]
|
||||
response = EmbeddingsResponse(
|
||||
data=data,
|
||||
model=model,
|
||||
usage=EmbeddingsUsage(
|
||||
prompt_tokens=input_tokens,
|
||||
total_tokens=input_tokens + output_tokens,
|
||||
),
|
||||
)
|
||||
if DEBUG:
|
||||
logger.info("Proxy response :" + response.model_dump_json())
|
||||
return response
|
||||
|
||||
|
||||
def get_embeddings_model(model_id: str) -> BedrockEmbeddingsModel:
|
||||
model_name = SUPPORTED_BEDROCK_EMBEDDING_MODELS.get(model_id, "")
|
||||
if DEBUG:
|
||||
logger.info("model name is " + model_name)
|
||||
if model_name in ["Cohere Embed Multilingual", "Cohere Embed English"]:
|
||||
return CohereEmbeddingsModel()
|
||||
elif model_name in ["Titan Embeddings G1 - Text", "Titan Multimodal Embeddings G1"]:
|
||||
return TitanEmbeddingsModel()
|
||||
else:
|
||||
logger.error("Unsupported model id " + model_id)
|
||||
raise ValueError("Invalid model ID")
|
||||
|
||||
|
||||
class CohereEmbeddingsModel(BedrockEmbeddingsModel):
|
||||
|
||||
def _parse_args(self, embeddings_request: EmbeddingsRequest) -> dict:
|
||||
args = {
|
||||
"texts": embeddings_request.input,
|
||||
"input_type": embeddings_request.input_type if embeddings_request.input_type else "search_document",
|
||||
"truncate": embeddings_request.truncate if embeddings_request.truncate else "NONE",
|
||||
}
|
||||
return args
|
||||
|
||||
def embed(self, embeddings_request: EmbeddingsRequest) -> EmbeddingsResponse:
|
||||
response = self._invoke_model(
|
||||
args=self._parse_args(embeddings_request), model_id=embeddings_request.model
|
||||
)
|
||||
response_body = json.loads(response.get("body").read())
|
||||
if DEBUG:
|
||||
logger.info("Bedrock response body: " + str(response_body))
|
||||
|
||||
return self._create_response(
|
||||
embeddings=response_body["embeddings"],
|
||||
model=embeddings_request.model,
|
||||
)
|
||||
|
||||
|
||||
class TitanEmbeddingsModel(BedrockEmbeddingsModel):
|
||||
|
||||
def _parse_args(self, embeddings_request: EmbeddingsRequest) -> dict:
|
||||
if isinstance(embeddings_request.input, str):
|
||||
input_text = embeddings_request.input
|
||||
elif isinstance(embeddings_request.input, list) and len(embeddings_request.input) == 1:
|
||||
input_text = embeddings_request.input[0]
|
||||
else:
|
||||
raise ValueError("Amazon Titan Embeddings models support only single strings as input.")
|
||||
args = {
|
||||
"inputText": input_text,
|
||||
# Note: inputImage is not supported!
|
||||
}
|
||||
if embeddings_request.model == "amazon.titan-embed-image-v1":
|
||||
args["embeddingConfig"] = embeddings_request.embedding_config if embeddings_request.embedding_config else {"outputEmbeddingLength": 1024}
|
||||
return args
|
||||
|
||||
def embed(self, embeddings_request: EmbeddingsRequest) -> EmbeddingsResponse:
|
||||
response = self._invoke_model(
|
||||
args=self._parse_args(embeddings_request), model_id=embeddings_request.model
|
||||
)
|
||||
response_body = json.loads(response.get("body").read())
|
||||
if DEBUG:
|
||||
logger.info("Bedrock response body: " + str(response_body))
|
||||
|
||||
return self._create_response(
|
||||
embeddings=[response_body["embedding"]],
|
||||
model=embeddings_request.model,
|
||||
input_tokens=response_body["inputTextTokenCount"]
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user