diff --git a/src/api/app.py b/src/api/app.py
index cafdf70..7fa2f01 100644
--- a/src/api/app.py
+++ b/src/api/app.py
@@ -7,7 +7,7 @@ from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import PlainTextResponse
 from mangum import Mangum
 
-from api.routers import model, chat
+from api.routers import model, chat, embeddings
 from api.setting import API_ROUTE_PREFIX, TITLE, DESCRIPTION, SUMMARY, VERSION
 
 config = {
@@ -33,6 +33,7 @@ app.add_middleware(
 
 app.include_router(model.router, prefix=API_ROUTE_PREFIX)
 app.include_router(chat.router, prefix=API_ROUTE_PREFIX)
+app.include_router(embeddings.router, prefix=API_ROUTE_PREFIX)
 
 
 @app.get("/health")
diff --git a/src/api/models/__init__.py b/src/api/models/__init__.py
index 6cb3b00..eb89239 100644
--- a/src/api/models/__init__.py
+++ b/src/api/models/__init__.py
@@ -1 +1,7 @@
-from api.models.bedrock import ClaudeModel, SUPPORTED_BEDROCK_MODELS, get_model
+from api.models.bedrock import (
+    ClaudeModel,
+    SUPPORTED_BEDROCK_MODELS,
+    SUPPORTED_BEDROCK_EMBEDDING_MODELS,
+    get_model,
+    get_embeddings_model,
+)
diff --git a/src/api/models/bedrock.py b/src/api/models/bedrock.py
index 6bdd132..75f4f3f 100644
--- a/src/api/models/bedrock.py
+++ b/src/api/models/bedrock.py
@@ -7,6 +7,7 @@ from typing import AsyncIterable
 import boto3
 
 from api.schema import (
+    # Chat
     ChatResponse,
     ChatRequest,
     ChatRequestMessage,
@@ -15,6 +16,11 @@ from api.schema import (
     Usage,
     ChatStreamResponse,
     ChoiceDelta,
+    # Embeddings
+    EmbeddingsRequest,
+    EmbeddingsResponse,
+    EmbeddingsUsage,
+    Embedding,
 )
 from api.setting import DEBUG, AWS_REGION
 
@@ -37,6 +43,12 @@ SUPPORTED_BEDROCK_MODELS = {
     "mistral.mixtral-8x7b-instruct-v0:1": "Mixtral 8x7B Instruct",
 }
 
+SUPPORTED_BEDROCK_EMBEDDING_MODELS = {
+    "cohere.embed-multilingual-v3": "Cohere Embed Multilingual",
+    "cohere.embed-english-v3": "Cohere Embed English",
+    "amazon.titan-embed-text-v1": "Titan Embeddings G1 - Text",
+    "amazon.titan-embed-image-v1": "Titan Multimodal Embeddings G1"
+}
 
 class BaseChatModel(ABC):
     """Represent a basic chat model
@@ -389,3 +401,136 @@ class MistralModel(BedrockModel):
                 finish_reason=chunk["outputs"][0]["stop_reason"],
             )
             yield self._stream_response_to_bytes(response)
+
+
+class BaseEmbeddingsModel(ABC):
+    """Represents a basic embeddings model.
+
+    Currently, only Bedrock-provided models are supported, but it may be used for SageMaker models if needed.
+    """
+
+    @abstractmethod
+    def embed(self, embeddings_request: EmbeddingsRequest) -> EmbeddingsResponse:
+        """Handle a basic embeddings request."""
+        pass
+
+    def _generate_message_id(self) -> str:
+        return "embeddings-" + str(uuid.uuid4())[:8]
+
+
+class BedrockEmbeddingsModel(BaseEmbeddingsModel):
+    accept = "application/json"
+    content_type = "application/json"
+
+    def _invoke_model(self, args: dict, model_id: str, with_stream: bool = False):
+        body = json.dumps(args)
+        if DEBUG:
+            logger.info("Invoke Bedrock Model: " + model_id)
+            logger.info("Bedrock request body: " + body)
+        if with_stream:
+            return bedrock_runtime.invoke_model_with_response_stream(
+                body=body,
+                modelId=model_id,
+                accept=self.accept,
+                contentType=self.content_type,
+            )
+        return bedrock_runtime.invoke_model(
+            body=body,
+            modelId=model_id,
+            accept=self.accept,
+            contentType=self.content_type,
+        )
+
+    def _create_response(
+        self,
+        embeddings: list[float],
+        model: str,
+        input_tokens: int = 0,
+        output_tokens: int = 0,
+    ) -> EmbeddingsResponse:
+        data = [
+            Embedding(
+                index=i,
+                embedding=embedding
+            ) for i, embedding in enumerate(embeddings)
+        ]
+        response = EmbeddingsResponse(
+            data=data,
+            model=model,
+            usage=EmbeddingsUsage(
+                prompt_tokens=input_tokens,
+                total_tokens=input_tokens + output_tokens,
+            ),
+        )
+        if DEBUG:
+            logger.info("Proxy response :" + response.model_dump_json())
+        return response
+
+
+def get_embeddings_model(model_id: str) -> BedrockEmbeddingsModel:
+    model_name = SUPPORTED_BEDROCK_EMBEDDING_MODELS.get(model_id, "")
+    if DEBUG:
+        logger.info("model name is " + model_name)
+    if model_name in ["Cohere Embed Multilingual", "Cohere Embed English"]:
+        return CohereEmbeddingsModel()
+    elif model_name in ["Titan Embeddings G1 - Text", "Titan Multimodal Embeddings G1"]:
+        return TitanEmbeddingsModel()
+    else:
+        logger.error("Unsupported model id " + model_id)
+        raise ValueError("Invalid model ID")
+
+
+class CohereEmbeddingsModel(BedrockEmbeddingsModel):
+
+    def _parse_args(self, embeddings_request: EmbeddingsRequest) -> dict:
+        args = {
+            "texts": embeddings_request.input,
+            "input_type": embeddings_request.input_type if embeddings_request.input_type else "search_document",
+            "truncate": embeddings_request.truncate if embeddings_request.truncate else "NONE",
+        }
+        return args
+
+    def embed(self, embeddings_request: EmbeddingsRequest) -> EmbeddingsResponse:
+        response = self._invoke_model(
+            args=self._parse_args(embeddings_request), model_id=embeddings_request.model
+        )
+        response_body = json.loads(response.get("body").read())
+        if DEBUG:
+            logger.info("Bedrock response body: " + str(response_body))
+
+        return self._create_response(
+            embeddings=response_body["embeddings"],
+            model=embeddings_request.model,
+        )
+
+
+class TitanEmbeddingsModel(BedrockEmbeddingsModel):
+
+    def _parse_args(self, embeddings_request: EmbeddingsRequest) -> dict:
+        if isinstance(embeddings_request.input, str):
+            input_text = embeddings_request.input
+        elif isinstance(embeddings_request.input, list) and len(embeddings_request.input) == 1:
+            input_text = embeddings_request.input[0]
+        else:
+            raise ValueError("Amazon Titan Embeddings models support only single strings as input.")
+        args = {
+            "inputText": input_text,
+            # Note: inputImage is not supported!
+        }
+        if embeddings_request.model == "amazon.titan-embed-image-v1":
+            args["embeddingConfig"] = embeddings_request.embedding_config if embeddings_request.embedding_config else {"outputEmbeddingLength": 1024}
+        return args
+
+    def embed(self, embeddings_request: EmbeddingsRequest) -> EmbeddingsResponse:
+        response = self._invoke_model(
+            args=self._parse_args(embeddings_request), model_id=embeddings_request.model
+        )
+        response_body = json.loads(response.get("body").read())
+        if DEBUG:
+            logger.info("Bedrock response body: " + str(response_body))
+
+        return self._create_response(
+            embeddings=[response_body["embedding"]],
+            model=embeddings_request.model,
+            input_tokens=response_body["inputTextTokenCount"]
+        )
diff --git a/src/api/routers/embeddings.py b/src/api/routers/embeddings.py
new file mode 100644
index 0000000..f20aea3
--- /dev/null
+++ b/src/api/routers/embeddings.py
@@ -0,0 +1,43 @@
+from typing import Annotated
+
+from fastapi import APIRouter, Depends, Body, HTTPException
+
+from api.auth import api_key_auth
+from api.models import get_embeddings_model, SUPPORTED_BEDROCK_EMBEDDING_MODELS
+from api.schema import EmbeddingsRequest, EmbeddingsResponse
+from api.setting import DEFAULT_EMBEDDING_MODEL
+
+router = APIRouter()
+
+router = APIRouter(
+    prefix="/embeddings",
+    tags=["items"],
+    dependencies=[Depends(api_key_auth)],
+)
+
+
+@router.post("/", response_model=EmbeddingsResponse)
+async def embeddings(
+    embeddings_request: Annotated[
+        EmbeddingsRequest,
+        Body(
+            examples=[
+                {
+                    "model": "cohere.embed-multilingual-v3",
+                    "input": [
+                        "Your text string goes here"
+                    ],
+                }
+            ],
+        ),
+    ]
+):
+    if embeddings_request.model.lower().startswith("text-embedding-"):
+        embeddings_request.model = DEFAULT_EMBEDDING_MODEL
+    if embeddings_request.model not in SUPPORTED_BEDROCK_EMBEDDING_MODELS.keys():
+        raise HTTPException(status_code=400, detail="Unsupported Model Id " + embeddings_request.model)
+    try:
+        model = get_embeddings_model(embeddings_request.model)
+        return model.embed(embeddings_request)
+    except ValueError as e:
+        raise HTTPException(status_code=400, detail=str(e))
diff --git a/src/api/routers/model.py b/src/api/routers/model.py
index 4d10f98..985f902 100644
--- a/src/api/routers/model.py
+++ b/src/api/routers/model.py
@@ -3,7 +3,7 @@ from typing import Annotated
 from fastapi import APIRouter, Depends, HTTPException, Path
 
 from api.auth import api_key_auth
-from api.models import SUPPORTED_BEDROCK_MODELS
+from api.models import SUPPORTED_BEDROCK_MODELS, SUPPORTED_BEDROCK_EMBEDDING_MODELS
 from api.schema import Models, Model
 
 router = APIRouter()
@@ -17,13 +17,13 @@ router = APIRouter(
 
 
 async def validate_model_id(model_id: str):
-    if model_id not in SUPPORTED_BEDROCK_MODELS.keys():
+    if model_id not in (SUPPORTED_BEDROCK_MODELS | SUPPORTED_BEDROCK_EMBEDDING_MODELS).keys():
         raise HTTPException(status_code=400, detail="Unsupported Model Id")
 
 
 @router.get("/", response_model=Models)
 async def list_models():
-    model_list = [Model(id=model_id) for model_id in SUPPORTED_BEDROCK_MODELS.keys()]
+    model_list = [Model(id=model_id) for model_id in (SUPPORTED_BEDROCK_MODELS | SUPPORTED_BEDROCK_EMBEDDING_MODELS).keys()]
     return Models(data=model_list)
 
 
diff --git a/src/api/schema.py b/src/api/schema.py
index 53732c8..17969c7 100644
--- a/src/api/schema.py
+++ b/src/api/schema.py
@@ -78,3 +78,36 @@ class ChatResponse(BaseChatResponse):
 class ChatStreamResponse(BaseChatResponse):
     choices: list[ChoiceDelta]
     object: Literal["chat.completion.chunk"] = "chat.completion.chunk"
+
+
+class EmbeddingsRequest(BaseModel):
+    input: str | list[str]
+    model: str
+    # Cohere Embed
+    input_type: Literal["search_document", "search_query", "classification", "clustering"] | None = None
+    truncate: Literal["NONE", "LEFT", "RIGHT"] | None = None
+    # Titan Embeddings
+    embedding_config: dict | None = None
+
+
+class BaseEmbeddingsResponse(BaseModel):
+    created: int = Field(default_factory=lambda: int(time.time()))
+    model: str
+    system_fingerprint: str = "fp_e97c09dd4e26"
+
+
+class Embedding(BaseModel):
+    object: Literal["embedding"] = "embedding"
+    embedding: list[float]
+    index: int
+
+
+class EmbeddingsUsage(BaseModel):
+    prompt_tokens: int
+    total_tokens: int
+
+
+class EmbeddingsResponse(BaseEmbeddingsResponse):
+    data: list[Embedding]
+    object: Literal["list"] = "list"
+    usage: EmbeddingsUsage
\ No newline at end of file
diff --git a/src/api/setting.py b/src/api/setting.py
index f183aa2..6a14f22 100644
--- a/src/api/setting.py
+++ b/src/api/setting.py
@@ -11,6 +11,8 @@ DESCRIPTION = """
 Use OpenAI-Compatible RESTful APIs for Amazon Bedrock models.
 
 List of Amazon Bedrock models currently supported:
+
+# Chat
 - anthropic.claude-instant-v1
 - anthropic.claude-v2:1
 - anthropic.claude-v2
@@ -20,8 +22,15 @@ List of Amazon Bedrock models currently supported:
 - meta.llama2-70b-chat-v1
 - mistral.mistral-7b-instruct-v0:2
 - mistral.mixtral-8x7b-instruct-v0:1
+
+# Embeddings
+- cohere.embed-multilingual-v3
+- cohere.embed-english-v3
+- amazon.titan-embed-text-v1
+- amazon.titan-embed-image-v1
 """
 
 DEBUG = os.environ.get("DEBUG", "false").lower() != "false"
 AWS_REGION = os.environ.get("AWS_REGION", "us-west-2")
 DEFAULT_MODEL = os.environ.get("DEFAULT_MODEL", "anthropic.claude-3-sonnet-20240229-v1:0")
+DEFAULT_EMBEDDING_MODEL = os.environ.get("DEFAULT_EMBEDDING_MODEL", "cohere.embed-multilingual-v3")
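
Reviewer note: below is a minimal, hypothetical smoke test for the new route, written against the OpenAI Python SDK that this gateway is meant to be compatible with. The base URL, port, and API key are assumptions about a local deployment (the value of API_ROUTE_PREFIX is not shown in this diff); adjust them to your environment.

    # smoke_test_embeddings.py -- illustrative sketch only, not part of this change.
    from openai import OpenAI

    client = OpenAI(
        base_url="http://localhost:8000/api/v1",  # assumes API_ROUTE_PREFIX resolves to "/api/v1"
        api_key="bedrock",                        # placeholder; use any key accepted by api_key_auth
    )

    # Cohere backends take a list of strings per request.
    cohere = client.embeddings.create(
        model="cohere.embed-multilingual-v3",
        input=["Your text string goes here"],
    )
    print(len(cohere.data[0].embedding))

    # Titan backends take a single string (or a one-element list); larger batches
    # are rejected by TitanEmbeddingsModel._parse_args and surface as HTTP 400.
    titan = client.embeddings.create(
        model="amazon.titan-embed-text-v1",
        input="Your text string goes here",
    )
    print(len(titan.data[0].embedding), titan.usage.prompt_tokens)

Two behaviours worth confirming in review: any requested model name starting with "text-embedding-" is silently rewritten to DEFAULT_EMBEDDING_MODEL, and with prefix="/embeddings" plus @router.post("/") the route registers as /embeddings/ (trailing slash), so clients posting to /embeddings rely on FastAPI's slash redirect.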
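For reference, a sketch of the Bedrock-native payloads the two adapters construct. It only exercises the _parse_args helpers added above (no Bedrock request is made) and assumes it is run from src/ with the project's dependencies installed.

    from api.models.bedrock import CohereEmbeddingsModel, TitanEmbeddingsModel
    from api.schema import EmbeddingsRequest

    # Cohere: the input list passes through, with defaults filled in for
    # input_type and truncate.
    print(CohereEmbeddingsModel()._parse_args(
        EmbeddingsRequest(model="cohere.embed-multilingual-v3", input=["a", "b"])
    ))
    # {'texts': ['a', 'b'], 'input_type': 'search_document', 'truncate': 'NONE'}

    # Titan image model: single string only, plus a default embeddingConfig.
    print(TitanEmbeddingsModel()._parse_args(
        EmbeddingsRequest(model="amazon.titan-embed-image-v1", input="a caption")
    ))
    # {'inputText': 'a caption', 'embeddingConfig': {'outputEmbeddingLength': 1024}}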