From 31ae10a2758f7f6249531145c85e7a5f5dda24bd Mon Sep 17 00:00:00 2001 From: Aiden Dai Date: Tue, 2 Apr 2024 10:39:55 +0800 Subject: [PATCH] Update embedding API --- src/api/models/bedrock.py | 25 ++++++++++++++++++++++--- src/api/routers/embeddings.py | 1 - src/api/schema.py | 4 ++-- 3 files changed, 24 insertions(+), 6 deletions(-) diff --git a/src/api/models/bedrock.py b/src/api/models/bedrock.py index a880a1a..7f6035c 100644 --- a/src/api/models/bedrock.py +++ b/src/api/models/bedrock.py @@ -1,8 +1,9 @@ import json import logging -from typing import AsyncIterable +from typing import AsyncIterable, Iterable import boto3 +import tiktoken from api.models.base import BaseChatModel, BaseEmbeddingsModel from api.schema import ( @@ -45,10 +46,13 @@ SUPPORTED_BEDROCK_MODELS = { SUPPORTED_BEDROCK_EMBEDDING_MODELS = { "cohere.embed-multilingual-v3": "Cohere Embed Multilingual", "cohere.embed-english-v3": "Cohere Embed English", - "amazon.titan-embed-text-v1": "Titan Embeddings G1 - Text", - "amazon.titan-embed-image-v1": "Titan Multimodal Embeddings G1" + # Disable Titan embedding. + # "amazon.titan-embed-text-v1": "Titan Embeddings G1 - Text", + # "amazon.titan-embed-image-v1": "Titan Multimodal Embeddings G1" } +ENCODER = tiktoken.get_encoding("cl100k_base") + # https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters.html class BedrockModel(BaseChatModel): @@ -439,10 +443,25 @@ def get_embeddings_model(model_id: str) -> BedrockEmbeddingsModel: class CohereEmbeddingsModel(BedrockEmbeddingsModel): def _parse_args(self, embeddings_request: EmbeddingsRequest) -> dict: + texts = [] if isinstance(embeddings_request.input, str): texts = [embeddings_request.input] elif isinstance(embeddings_request.input, list): texts = embeddings_request.input + elif isinstance(embeddings_request.input, Iterable): + # For encoded input + # The workaround is to use tiktoken to decode to get the original text. + encodings = [] + for inner in embeddings_request.input: + if isinstance(inner, int): + # Iterable[int] + encodings.append(inner) + else: + # Iterable[Iterable[int]] + text = ENCODER.decode(list(inner)) + texts.append(text) + if encodings: + texts.append(ENCODER.decode(encodings)) # Maximum of 2048 characters args = { diff --git a/src/api/routers/embeddings.py b/src/api/routers/embeddings.py index f283553..e823aa7 100644 --- a/src/api/routers/embeddings.py +++ b/src/api/routers/embeddings.py @@ -37,7 +37,6 @@ async def embeddings( raise HTTPException(status_code=400, detail="Unsupported Model Id " + embeddings_request.model) try: model = get_embeddings_model(embeddings_request.model) - # TODO: Check type of input return model.embed(embeddings_request) except ValueError as e: raise HTTPException(status_code=400, detail=str(e)) diff --git a/src/api/schema.py b/src/api/schema.py index d31da54..5791441 100644 --- a/src/api/schema.py +++ b/src/api/schema.py @@ -66,7 +66,7 @@ class BaseChatResponse(BaseModel): id: str created: int = Field(default_factory=lambda: int(time.time())) model: str - system_fingerprint: str = "fp_e97c09dd4e26" + system_fingerprint: str = "fp" class ChatResponse(BaseChatResponse): @@ -81,7 +81,7 @@ class ChatStreamResponse(BaseChatResponse): class EmbeddingsRequest(BaseModel): - input: str | list[str] | Iterable[int] | Iterable[Iterable[int]] + input: str | list[str] | Iterable[int | Iterable[int]] model: str encoding_format: Literal["float", "base64"] = "float" # not used. dimensions: int | None = None # not used.