From beed3b04b7866356cb9c4012fc7884a88e8dff9c Mon Sep 17 00:00:00 2001
From: Aiden Dai
Date: Tue, 2 Apr 2024 09:30:18 +0800
Subject: [PATCH] Update embedding API

---
 src/api/models/base.py        | 51 ++++++++++++++++++++
 src/api/models/bedrock.py     | 91 ++++++++++-------------------------
 src/api/routers/chat.py       |  1 -
 src/api/routers/embeddings.py |  2 +-
 src/api/schema.py             | 24 ++++-----
 5 files changed, 86 insertions(+), 83 deletions(-)
 create mode 100644 src/api/models/base.py

diff --git a/src/api/models/base.py b/src/api/models/base.py
new file mode 100644
index 0000000..3c492b4
--- /dev/null
+++ b/src/api/models/base.py
@@ -0,0 +1,51 @@
+import uuid
+from abc import ABC, abstractmethod
+from typing import AsyncIterable
+
+from api.schema import (
+    # Chat
+    ChatResponse,
+    ChatRequest,
+    ChatStreamResponse,
+    # Embeddings
+    EmbeddingsRequest,
+    EmbeddingsResponse,
+)
+
+
+class BaseChatModel(ABC):
+    """Represents a basic chat model.
+
+    Currently, only Bedrock models are supported, but it may be used for SageMaker models if needed.
+    """
+
+    @abstractmethod
+    def chat(self, chat_request: ChatRequest) -> ChatResponse:
+        """Handle a basic chat completion request."""
+        pass
+
+    @abstractmethod
+    def chat_stream(self, chat_request: ChatRequest) -> AsyncIterable[bytes]:
+        """Handle a basic chat completion request with a streaming response."""
+        pass
+
+    def _generate_message_id(self) -> str:
+        return "chatcmpl-" + str(uuid.uuid4())[:8]
+
+    def _stream_response_to_bytes(self, response: ChatStreamResponse) -> bytes:
+        return "data: {}\n\n".format(response.model_dump_json()).encode("utf-8")
+
+
+class BaseEmbeddingsModel(ABC):
+    """Represents a basic embeddings model.
+
+    Currently, only Bedrock-provided models are supported, but it may be used for SageMaker models if needed.
+    """
+
+    @abstractmethod
+    def embed(self, embeddings_request: EmbeddingsRequest) -> EmbeddingsResponse:
+        """Handle a basic embeddings request."""
+        pass
+
+    def _generate_message_id(self) -> str:
+        return "embeddings-" + str(uuid.uuid4())[:8]
diff --git a/src/api/models/bedrock.py b/src/api/models/bedrock.py
index 48d54cc..a880a1a 100644
--- a/src/api/models/bedrock.py
+++ b/src/api/models/bedrock.py
@@ -1,11 +1,10 @@
 import json
 import logging
-import uuid
-from abc import ABC, abstractmethod
 from typing import AsyncIterable
 
 import boto3
 
+from api.models.base import BaseChatModel, BaseEmbeddingsModel
 from api.schema import (
     # Chat
     ChatResponse,
@@ -50,28 +49,6 @@ SUPPORTED_BEDROCK_EMBEDDING_MODELS = {
     "amazon.titan-embed-image-v1": "Titan Multimodal Embeddings G1"
 }
 
-class BaseChatModel(ABC):
-    """Represent a basic chat model
-
-    Currently, only Bedrock model is supported, but may be used for SageMaker models if needed.
-    """
-
-    @abstractmethod
-    def chat(self, chat_request: ChatRequest) -> ChatResponse:
-        """Handle a basic chat completion requests."""
-        pass
-
-    @abstractmethod
-    def chat_stream(self, chat_request: ChatRequest) -> AsyncIterable[bytes]:
-        """Handle a basic chat completion requests with stream response."""
-        pass
-
-    def _generate_message_id(self) -> str:
-        return "chatcmpl-" + str(uuid.uuid4())[:8]
-
-    def _stream_response_to_bytes(self, response: ChatStreamResponse) -> bytes:
-        return "data: {}\n\n".format(response.model_dump_json()).encode("utf-8")
-
 
 # https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters.html
 class BedrockModel(BaseChatModel):
@@ -98,12 +75,12 @@ class BedrockModel(BaseChatModel):
         )
 
     def _create_response(
-        self,
-        model: str,
-        message: str,
-        message_id: str,
-        input_tokens: int = 0,
-        output_tokens: int = 0,
+            self,
+            model: str,
+            message: str,
+            message_id: str,
+            input_tokens: int = 0,
+            output_tokens: int = 0,
     ) -> ChatResponse:
         choice = Choice(
             index=0,
@@ -128,7 +105,7 @@ class BedrockModel(BaseChatModel):
         return response
 
     def _create_response_stream(
-        self, model: str, message_id: str, chunk_message: str, finish_reason: str | None
+            self, model: str, message_id: str, chunk_message: str, finish_reason: str | None
     ) -> ChatStreamResponse:
         choice = ChoiceDelta(
             index=0,
@@ -403,37 +380,15 @@ class MistralModel(BedrockModel):
             yield self._stream_response_to_bytes(response)
 
 
-class BaseEmbeddingsModel(ABC):
-    """Represents a basic embeddings model.
-
-    Currently, only Bedrock-provided models are supported, but it may be used for SageMaker models if needed.
-    """
-
-    @abstractmethod
-    def embed(self, embeddings_request: EmbeddingsRequest) -> EmbeddingsResponse:
-        """Handle a basic embeddings request."""
-        pass
-
-    def _generate_message_id(self) -> str:
-        return "embeddings-" + str(uuid.uuid4())[:8]
-
-
 class BedrockEmbeddingsModel(BaseEmbeddingsModel):
     accept = "application/json"
     content_type = "application/json"
 
-    def _invoke_model(self, args: dict, model_id: str, with_stream: bool = False):
+    def _invoke_model(self, args: dict, model_id: str):
         body = json.dumps(args)
         if DEBUG:
             logger.info("Invoke Bedrock Model: " + model_id)
             logger.info("Bedrock request body: " + body)
-        if with_stream:
-            return bedrock_runtime.invoke_model_with_response_stream(
-                body=body,
-                modelId=model_id,
-                accept=self.accept,
-                contentType=self.content_type,
-            )
         return bedrock_runtime.invoke_model(
             body=body,
             modelId=model_id,
@@ -442,18 +397,18 @@ class BedrockEmbeddingsModel(BaseEmbeddingsModel):
         )
 
     def _create_response(
-        self,
-        embeddings: list[float],
-        model: str,
-        input_tokens: int = 0,
-        output_tokens: int = 0,
+            self,
+            embeddings: list[float],
+            model: str,
+            input_tokens: int = 0,
+            output_tokens: int = 0,
     ) -> EmbeddingsResponse:
         data = [
             Embedding(
-                index=i,
-                embedding=embedding
-            ) for i, embedding in enumerate(embeddings)
-        ]
+                    index=i,
+                    embedding=embedding
+                ) for i, embedding in enumerate(embeddings)
+            ]
         response = EmbeddingsResponse(
             data=data,
             model=model,
@@ -462,6 +417,7 @@ class BedrockEmbeddingsModel(BaseEmbeddingsModel):
                 total_tokens=input_tokens + output_tokens,
             ),
         )
+
         if DEBUG:
             logger.info("Proxy response :" + response.model_dump_json())
         return response
@@ -487,10 +443,12 @@ class CohereEmbeddingsModel(BedrockEmbeddingsModel):
             texts = [embeddings_request.input]
         elif isinstance(embeddings_request.input, list):
             texts = embeddings_request.input
+
+        # Maximum of 2048 characters
         args = {
             "texts": texts,
-            "input_type": embeddings_request.input_type if embeddings_request.input_type else "search_document",
-            "truncate": embeddings_request.truncate if embeddings_request.truncate else "NONE",
+            "input_type": "search_document",
+            "truncate": "END",  # "NONE|START|END"
         }
         return args
 
@@ -522,7 +480,8 @@ class TitanEmbeddingsModel(BedrockEmbeddingsModel):
             # Note: inputImage is not supported!
         }
         if embeddings_request.model == "amazon.titan-embed-image-v1":
-            args["embeddingConfig"] = embeddings_request.embedding_config if embeddings_request.embedding_config else {"outputEmbeddingLength": 1024}
+            args["embeddingConfig"] = embeddings_request.embedding_config if embeddings_request.embedding_config else {
+                "outputEmbeddingLength": 1024}
         return args
 
     def embed(self, embeddings_request: EmbeddingsRequest) -> EmbeddingsResponse:
diff --git a/src/api/routers/chat.py b/src/api/routers/chat.py
index d9b2fa6..efc3ced 100644
--- a/src/api/routers/chat.py
+++ b/src/api/routers/chat.py
@@ -12,7 +12,6 @@ router = APIRouter()
 
 router = APIRouter(
     prefix="/chat",
-    tags=["items"],
     dependencies=[Depends(api_key_auth)],
     # responses={404: {"description": "Not found"}},
 )
diff --git a/src/api/routers/embeddings.py b/src/api/routers/embeddings.py
index f20aea3..f283553 100644
--- a/src/api/routers/embeddings.py
+++ b/src/api/routers/embeddings.py
@@ -11,7 +11,6 @@ router = APIRouter()
 
 router = APIRouter(
     prefix="/embeddings",
-    tags=["items"],
    dependencies=[Depends(api_key_auth)],
 )
 
@@ -38,6 +37,7 @@ async def embeddings(
         raise HTTPException(status_code=400, detail="Unsupported Model Id " + embeddings_request.model)
     try:
         model = get_embeddings_model(embeddings_request.model)
+        # TODO: Check type of input
         return model.embed(embeddings_request)
     except ValueError as e:
         raise HTTPException(status_code=400, detail=str(e))
diff --git a/src/api/schema.py b/src/api/schema.py
index 05f3d1f..d31da54 100644
--- a/src/api/schema.py
+++ b/src/api/schema.py
@@ -1,5 +1,5 @@
 import time
-from typing import Literal
+from typing import Literal, Iterable
 
 from pydantic import BaseModel, Field
 
@@ -81,18 +81,11 @@ class ChatStreamResponse(BaseChatResponse):
 
 
 class EmbeddingsRequest(BaseModel):
-    input: str | list[str]
-    model: str
-    # Cohere Embed
-    input_type: Literal["search_document", "search_query", "classification", "clustering"] | None = None
-    truncate: Literal["NONE", "LEFT", "RIGHT"] | None = None
-    # Titan Embeddings
-    embedding_config: dict | None = None
-
-
-class BaseEmbeddingsResponse(BaseModel):
-    created: int = Field(default_factory=lambda: int(time.time()))
+    input: str | list[str] | Iterable[int] | Iterable[Iterable[int]]
     model: str
+    encoding_format: Literal["float", "base64"] = "float"  # not used.
+    dimensions: int | None = None  # not used.
+    user: str | None = None  # not used.
 
 
 class Embedding(BaseModel):
@@ -106,7 +99,8 @@ class EmbeddingsUsage(BaseModel):
     total_tokens: int
 
 
-class EmbeddingsResponse(BaseEmbeddingsResponse):
-    data: list[Embedding]
+class EmbeddingsResponse(BaseModel):
     object: Literal["list"] = "list"
-    usage: EmbeddingsUsage
\ No newline at end of file
+    data: list[Embedding]
+    model: str
+    usage: EmbeddingsUsage
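
With the schema changes above, the embeddings route accepts the same request shape as OpenAI's embeddings endpoint, so a standard OpenAI client can call it. The sketch below is illustrative only and not part of the patch: the base URL, API key, and model ID are placeholders that depend on how the gateway is deployed and which Bedrock embedding models it supports.

# Illustrative client-side usage; endpoint URL, API key, and model ID are assumptions.
from openai import OpenAI

client = OpenAI(
    api_key="placeholder-key",                 # whatever API key the gateway is configured to accept
    base_url="http://localhost:8000/api/v1",   # assumed gateway address; adjust to your deployment
)

# The updated EmbeddingsRequest accepts a single string or a list of strings.
response = client.embeddings.create(
    model="cohere.embed-multilingual-v3",      # must be one of the gateway's supported embedding models
    input=["Hello world", "Bonjour le monde"],
)

for item in response.data:
    print(item.index, len(item.embedding))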