diff --git a/src/api/models/bedrock.py b/src/api/models/bedrock.py index e49c82c..db33cc9 100644 --- a/src/api/models/bedrock.py +++ b/src/api/models/bedrock.py @@ -3,9 +3,10 @@ import json import logging import re from abc import ABC -from typing import AsyncIterable, Iterable +from typing import AsyncIterable, Iterable, Literal import boto3 +import numpy as np import requests import tiktoken from fastapi import HTTPException @@ -685,11 +686,17 @@ class BedrockEmbeddingsModel(BaseEmbeddingsModel, ABC): model: str, input_tokens: int = 0, output_tokens: int = 0, + encoding_format: Literal["float", "base64"] = "float", ) -> EmbeddingsResponse: - data = [ - Embedding(index=i, embedding=embedding) - for i, embedding in enumerate(embeddings) - ] + data = [] + for i, embedding in enumerate(embeddings): + if encoding_format == "base64": + arr = np.array(embedding, dtype=np.float32) + arr_bytes = arr.tobytes() + encoded_embedding = base64.b64encode(arr_bytes) + data.append(Embedding(index=i, embedding=encoded_embedding)) + else: + data.append(Embedding(index=i, embedding=embedding)) response = EmbeddingsResponse( data=data, model=model, @@ -746,6 +753,7 @@ class CohereEmbeddingsModel(BedrockEmbeddingsModel): return self._create_response( embeddings=response_body["embeddings"], model=embeddings_request.model, + encoding_format=embeddings_request.encoding_format, ) diff --git a/src/api/schema.py b/src/api/schema.py index 39b9baf..fc4e786 100644 --- a/src/api/schema.py +++ b/src/api/schema.py @@ -143,14 +143,14 @@ class ChatStreamResponse(BaseChatResponse): class EmbeddingsRequest(BaseModel): input: str | list[str] | Iterable[int | Iterable[int]] model: str - encoding_format: Literal["float", "base64"] = "float" # not used. + encoding_format: Literal["float", "base64"] = "float" dimensions: int | None = None # not used. user: str | None = None # not used. class Embedding(BaseModel): object: Literal["embedding"] = "embedding" - embedding: list[float] + embedding: list[float] | bytes index: int diff --git a/src/requirements.txt b/src/requirements.txt index 49019b8..2b0808a 100644 --- a/src/requirements.txt +++ b/src/requirements.txt @@ -1,6 +1,7 @@ -fastapi==0.110.0 -pydantic==2.6.3 -uvicorn==0.27.0.post1 +fastapi==0.110.2 +pydantic==2.7.1 +uvicorn==0.29.0 mangum==0.17.0 tiktoken==0.6.0 -requests==2.31.0 \ No newline at end of file +requests==2.31.0 +numpy==1.26.4 \ No newline at end of file