Add base64-encoded embedding support

This commit is contained in:
Aiden Dai
2024-04-26 13:46:46 +08:00
parent 27d253fddb
commit 180c199da9
3 changed files with 20 additions and 11 deletions

View File

@@ -3,9 +3,10 @@ import json
import logging
import re
from abc import ABC
from typing import AsyncIterable, Iterable
from typing import AsyncIterable, Iterable, Literal
import boto3
import numpy as np
import requests
import tiktoken
from fastapi import HTTPException
@@ -685,11 +686,17 @@ class BedrockEmbeddingsModel(BaseEmbeddingsModel, ABC):
model: str,
input_tokens: int = 0,
output_tokens: int = 0,
encoding_format: Literal["float", "base64"] = "float",
) -> EmbeddingsResponse:
data = [
Embedding(index=i, embedding=embedding)
for i, embedding in enumerate(embeddings)
]
data = []
for i, embedding in enumerate(embeddings):
if encoding_format == "base64":
arr = np.array(embedding, dtype=np.float32)
arr_bytes = arr.tobytes()
encoded_embedding = base64.b64encode(arr_bytes)
data.append(Embedding(index=i, embedding=encoded_embedding))
else:
data.append(Embedding(index=i, embedding=embedding))
response = EmbeddingsResponse(
data=data,
model=model,
@@ -746,6 +753,7 @@ class CohereEmbeddingsModel(BedrockEmbeddingsModel):
return self._create_response(
embeddings=response_body["embeddings"],
model=embeddings_request.model,
encoding_format=embeddings_request.encoding_format,
)

View File

@@ -143,14 +143,14 @@ class ChatStreamResponse(BaseChatResponse):
class EmbeddingsRequest(BaseModel):
input: str | list[str] | Iterable[int | Iterable[int]]
model: str
encoding_format: Literal["float", "base64"] = "float" # not used.
encoding_format: Literal["float", "base64"] = "float"
dimensions: int | None = None # not used.
user: str | None = None # not used.
class Embedding(BaseModel):
object: Literal["embedding"] = "embedding"
embedding: list[float]
embedding: list[float] | bytes
index: int