Update embedding API
This commit is contained in:
@@ -1,8 +1,9 @@
|
|||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
from typing import AsyncIterable
|
from typing import AsyncIterable, Iterable
|
||||||
|
|
||||||
import boto3
|
import boto3
|
||||||
|
import tiktoken
|
||||||
|
|
||||||
from api.models.base import BaseChatModel, BaseEmbeddingsModel
|
from api.models.base import BaseChatModel, BaseEmbeddingsModel
|
||||||
from api.schema import (
|
from api.schema import (
|
||||||
@@ -45,10 +46,13 @@ SUPPORTED_BEDROCK_MODELS = {
|
|||||||
SUPPORTED_BEDROCK_EMBEDDING_MODELS = {
|
SUPPORTED_BEDROCK_EMBEDDING_MODELS = {
|
||||||
"cohere.embed-multilingual-v3": "Cohere Embed Multilingual",
|
"cohere.embed-multilingual-v3": "Cohere Embed Multilingual",
|
||||||
"cohere.embed-english-v3": "Cohere Embed English",
|
"cohere.embed-english-v3": "Cohere Embed English",
|
||||||
"amazon.titan-embed-text-v1": "Titan Embeddings G1 - Text",
|
# Disable Titan embedding.
|
||||||
"amazon.titan-embed-image-v1": "Titan Multimodal Embeddings G1"
|
# "amazon.titan-embed-text-v1": "Titan Embeddings G1 - Text",
|
||||||
|
# "amazon.titan-embed-image-v1": "Titan Multimodal Embeddings G1"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ENCODER = tiktoken.get_encoding("cl100k_base")
|
||||||
|
|
||||||
|
|
||||||
# https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters.html
|
# https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters.html
|
||||||
class BedrockModel(BaseChatModel):
|
class BedrockModel(BaseChatModel):
|
||||||
@@ -439,10 +443,25 @@ def get_embeddings_model(model_id: str) -> BedrockEmbeddingsModel:
|
|||||||
class CohereEmbeddingsModel(BedrockEmbeddingsModel):
|
class CohereEmbeddingsModel(BedrockEmbeddingsModel):
|
||||||
|
|
||||||
def _parse_args(self, embeddings_request: EmbeddingsRequest) -> dict:
|
def _parse_args(self, embeddings_request: EmbeddingsRequest) -> dict:
|
||||||
|
texts = []
|
||||||
if isinstance(embeddings_request.input, str):
|
if isinstance(embeddings_request.input, str):
|
||||||
texts = [embeddings_request.input]
|
texts = [embeddings_request.input]
|
||||||
elif isinstance(embeddings_request.input, list):
|
elif isinstance(embeddings_request.input, list):
|
||||||
texts = embeddings_request.input
|
texts = embeddings_request.input
|
||||||
|
elif isinstance(embeddings_request.input, Iterable):
|
||||||
|
# For encoded input
|
||||||
|
# The workaround is to use tiktoken to decode to get the original text.
|
||||||
|
encodings = []
|
||||||
|
for inner in embeddings_request.input:
|
||||||
|
if isinstance(inner, int):
|
||||||
|
# Iterable[int]
|
||||||
|
encodings.append(inner)
|
||||||
|
else:
|
||||||
|
# Iterable[Iterable[int]]
|
||||||
|
text = ENCODER.decode(list(inner))
|
||||||
|
texts.append(text)
|
||||||
|
if encodings:
|
||||||
|
texts.append(ENCODER.decode(encodings))
|
||||||
|
|
||||||
# Maximum of 2048 characters
|
# Maximum of 2048 characters
|
||||||
args = {
|
args = {
|
||||||
|
|||||||
@@ -37,7 +37,6 @@ async def embeddings(
|
|||||||
raise HTTPException(status_code=400, detail="Unsupported Model Id " + embeddings_request.model)
|
raise HTTPException(status_code=400, detail="Unsupported Model Id " + embeddings_request.model)
|
||||||
try:
|
try:
|
||||||
model = get_embeddings_model(embeddings_request.model)
|
model = get_embeddings_model(embeddings_request.model)
|
||||||
# TODO: Check type of input
|
|
||||||
return model.embed(embeddings_request)
|
return model.embed(embeddings_request)
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
raise HTTPException(status_code=400, detail=str(e))
|
raise HTTPException(status_code=400, detail=str(e))
|
||||||
|
|||||||
@@ -66,7 +66,7 @@ class BaseChatResponse(BaseModel):
|
|||||||
id: str
|
id: str
|
||||||
created: int = Field(default_factory=lambda: int(time.time()))
|
created: int = Field(default_factory=lambda: int(time.time()))
|
||||||
model: str
|
model: str
|
||||||
system_fingerprint: str = "fp_e97c09dd4e26"
|
system_fingerprint: str = "fp"
|
||||||
|
|
||||||
|
|
||||||
class ChatResponse(BaseChatResponse):
|
class ChatResponse(BaseChatResponse):
|
||||||
@@ -81,7 +81,7 @@ class ChatStreamResponse(BaseChatResponse):
|
|||||||
|
|
||||||
|
|
||||||
class EmbeddingsRequest(BaseModel):
|
class EmbeddingsRequest(BaseModel):
|
||||||
input: str | list[str] | Iterable[int] | Iterable[Iterable[int]]
|
input: str | list[str] | Iterable[int | Iterable[int]]
|
||||||
model: str
|
model: str
|
||||||
encoding_format: Literal["float", "base64"] = "float" # not used.
|
encoding_format: Literal["float", "base64"] = "float" # not used.
|
||||||
dimensions: int | None = None # not used.
|
dimensions: int | None = None # not used.
|
||||||
|
|||||||
Reference in New Issue
Block a user