Add base64-encoded embedding support
This commit is contained in:
@@ -3,9 +3,10 @@ import json
|
|||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
from abc import ABC
|
from abc import ABC
|
||||||
from typing import AsyncIterable, Iterable
|
from typing import AsyncIterable, Iterable, Literal
|
||||||
|
|
||||||
import boto3
|
import boto3
|
||||||
|
import numpy as np
|
||||||
import requests
|
import requests
|
||||||
import tiktoken
|
import tiktoken
|
||||||
from fastapi import HTTPException
|
from fastapi import HTTPException
|
||||||
@@ -685,11 +686,17 @@ class BedrockEmbeddingsModel(BaseEmbeddingsModel, ABC):
|
|||||||
model: str,
|
model: str,
|
||||||
input_tokens: int = 0,
|
input_tokens: int = 0,
|
||||||
output_tokens: int = 0,
|
output_tokens: int = 0,
|
||||||
|
encoding_format: Literal["float", "base64"] = "float",
|
||||||
) -> EmbeddingsResponse:
|
) -> EmbeddingsResponse:
|
||||||
data = [
|
data = []
|
||||||
Embedding(index=i, embedding=embedding)
|
for i, embedding in enumerate(embeddings):
|
||||||
for i, embedding in enumerate(embeddings)
|
if encoding_format == "base64":
|
||||||
]
|
arr = np.array(embedding, dtype=np.float32)
|
||||||
|
arr_bytes = arr.tobytes()
|
||||||
|
encoded_embedding = base64.b64encode(arr_bytes)
|
||||||
|
data.append(Embedding(index=i, embedding=encoded_embedding))
|
||||||
|
else:
|
||||||
|
data.append(Embedding(index=i, embedding=embedding))
|
||||||
response = EmbeddingsResponse(
|
response = EmbeddingsResponse(
|
||||||
data=data,
|
data=data,
|
||||||
model=model,
|
model=model,
|
||||||
@@ -746,6 +753,7 @@ class CohereEmbeddingsModel(BedrockEmbeddingsModel):
|
|||||||
return self._create_response(
|
return self._create_response(
|
||||||
embeddings=response_body["embeddings"],
|
embeddings=response_body["embeddings"],
|
||||||
model=embeddings_request.model,
|
model=embeddings_request.model,
|
||||||
|
encoding_format=embeddings_request.encoding_format,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -143,14 +143,14 @@ class ChatStreamResponse(BaseChatResponse):
|
|||||||
class EmbeddingsRequest(BaseModel):
|
class EmbeddingsRequest(BaseModel):
|
||||||
input: str | list[str] | Iterable[int | Iterable[int]]
|
input: str | list[str] | Iterable[int | Iterable[int]]
|
||||||
model: str
|
model: str
|
||||||
encoding_format: Literal["float", "base64"] = "float" # not used.
|
encoding_format: Literal["float", "base64"] = "float"
|
||||||
dimensions: int | None = None # not used.
|
dimensions: int | None = None # not used.
|
||||||
user: str | None = None # not used.
|
user: str | None = None # not used.
|
||||||
|
|
||||||
|
|
||||||
class Embedding(BaseModel):
|
class Embedding(BaseModel):
|
||||||
object: Literal["embedding"] = "embedding"
|
object: Literal["embedding"] = "embedding"
|
||||||
embedding: list[float]
|
embedding: list[float] | bytes
|
||||||
index: int
|
index: int
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
fastapi==0.110.0
|
fastapi==0.110.2
|
||||||
pydantic==2.6.3
|
pydantic==2.7.1
|
||||||
uvicorn==0.27.0.post1
|
uvicorn==0.29.0
|
||||||
mangum==0.17.0
|
mangum==0.17.0
|
||||||
tiktoken==0.6.0
|
tiktoken==0.6.0
|
||||||
requests==2.31.0
|
requests==2.31.0
|
||||||
|
numpy==1.26.4
|
||||||
Reference in New Issue
Block a user