Update embedding API

This commit is contained in:
Aiden Dai
2024-04-02 10:39:55 +08:00
parent beed3b04b7
commit 31ae10a275
3 changed files with 24 additions and 6 deletions

View File

@@ -1,8 +1,9 @@
import json import json
import logging import logging
from typing import AsyncIterable from typing import AsyncIterable, Iterable
import boto3 import boto3
import tiktoken
from api.models.base import BaseChatModel, BaseEmbeddingsModel from api.models.base import BaseChatModel, BaseEmbeddingsModel
from api.schema import ( from api.schema import (
@@ -45,10 +46,13 @@ SUPPORTED_BEDROCK_MODELS = {
SUPPORTED_BEDROCK_EMBEDDING_MODELS = { SUPPORTED_BEDROCK_EMBEDDING_MODELS = {
"cohere.embed-multilingual-v3": "Cohere Embed Multilingual", "cohere.embed-multilingual-v3": "Cohere Embed Multilingual",
"cohere.embed-english-v3": "Cohere Embed English", "cohere.embed-english-v3": "Cohere Embed English",
"amazon.titan-embed-text-v1": "Titan Embeddings G1 - Text", # Disable Titan embedding.
"amazon.titan-embed-image-v1": "Titan Multimodal Embeddings G1" # "amazon.titan-embed-text-v1": "Titan Embeddings G1 - Text",
# "amazon.titan-embed-image-v1": "Titan Multimodal Embeddings G1"
} }
ENCODER = tiktoken.get_encoding("cl100k_base")
# https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters.html # https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters.html
class BedrockModel(BaseChatModel): class BedrockModel(BaseChatModel):
@@ -439,10 +443,25 @@ def get_embeddings_model(model_id: str) -> BedrockEmbeddingsModel:
class CohereEmbeddingsModel(BedrockEmbeddingsModel): class CohereEmbeddingsModel(BedrockEmbeddingsModel):
def _parse_args(self, embeddings_request: EmbeddingsRequest) -> dict: def _parse_args(self, embeddings_request: EmbeddingsRequest) -> dict:
texts = []
if isinstance(embeddings_request.input, str): if isinstance(embeddings_request.input, str):
texts = [embeddings_request.input] texts = [embeddings_request.input]
elif isinstance(embeddings_request.input, list): elif isinstance(embeddings_request.input, list):
texts = embeddings_request.input texts = embeddings_request.input
elif isinstance(embeddings_request.input, Iterable):
# For encoded input
# The workaround is to use tiktoken to decode to get the original text.
encodings = []
for inner in embeddings_request.input:
if isinstance(inner, int):
# Iterable[int]
encodings.append(inner)
else:
# Iterable[Iterable[int]]
text = ENCODER.decode(list(inner))
texts.append(text)
if encodings:
texts.append(ENCODER.decode(encodings))
# Maximum of 2048 characters # Maximum of 2048 characters
args = { args = {

View File

@@ -37,7 +37,6 @@ async def embeddings(
raise HTTPException(status_code=400, detail="Unsupported Model Id " + embeddings_request.model) raise HTTPException(status_code=400, detail="Unsupported Model Id " + embeddings_request.model)
try: try:
model = get_embeddings_model(embeddings_request.model) model = get_embeddings_model(embeddings_request.model)
# TODO: Check type of input
return model.embed(embeddings_request) return model.embed(embeddings_request)
except ValueError as e: except ValueError as e:
raise HTTPException(status_code=400, detail=str(e)) raise HTTPException(status_code=400, detail=str(e))

View File

@@ -66,7 +66,7 @@ class BaseChatResponse(BaseModel):
id: str id: str
created: int = Field(default_factory=lambda: int(time.time())) created: int = Field(default_factory=lambda: int(time.time()))
model: str model: str
system_fingerprint: str = "fp_e97c09dd4e26" system_fingerprint: str = "fp"
class ChatResponse(BaseChatResponse): class ChatResponse(BaseChatResponse):
@@ -81,7 +81,7 @@ class ChatStreamResponse(BaseChatResponse):
class EmbeddingsRequest(BaseModel): class EmbeddingsRequest(BaseModel):
input: str | list[str] | Iterable[int] | Iterable[Iterable[int]] input: str | list[str] | Iterable[int | Iterable[int]]
model: str model: str
encoding_format: Literal["float", "base64"] = "float" # not used. encoding_format: Literal["float", "base64"] = "float" # not used.
dimensions: int | None = None # not used. dimensions: int | None = None # not used.