Update embedding API

This commit is contained in:
Aiden Dai
2024-04-02 09:30:18 +08:00
parent a06703d92b
commit beed3b04b7
5 changed files with 86 additions and 83 deletions

51
src/api/models/base.py Normal file
View File

@@ -0,0 +1,51 @@
import uuid
from abc import ABC, abstractmethod
from typing import AsyncIterable
from api.schema import (
# Chat
ChatResponse,
ChatRequest,
ChatStreamResponse,
# Embeddings
EmbeddingsRequest,
EmbeddingsResponse,
)
class BaseChatModel(ABC):
    """Abstract interface every chat model implementation must satisfy.

    Today only Bedrock-backed models implement this, but the contract is
    generic enough to cover SageMaker-hosted models later.
    """

    @abstractmethod
    def chat(self, chat_request: ChatRequest) -> ChatResponse:
        """Handle a basic chat completion requests."""

    @abstractmethod
    def chat_stream(self, chat_request: ChatRequest) -> AsyncIterable[bytes]:
        """Handle a basic chat completion requests with stream response."""

    def _generate_message_id(self) -> str:
        # "chatcmpl-" followed by the first 8 characters of a fresh UUID4.
        suffix = str(uuid.uuid4())[:8]
        return f"chatcmpl-{suffix}"

    def _stream_response_to_bytes(self, response: ChatStreamResponse) -> bytes:
        # Server-sent-events framing: "data: <json>\n\n", UTF-8 encoded.
        payload = response.model_dump_json()
        return f"data: {payload}\n\n".encode("utf-8")
class BaseEmbeddingsModel(ABC):
    """Abstract interface every embeddings model implementation must satisfy.

    Today only Bedrock-provided models implement this, but the contract is
    generic enough to cover SageMaker-hosted models later.
    """

    @abstractmethod
    def embed(self, embeddings_request: EmbeddingsRequest) -> EmbeddingsResponse:
        """Handle a basic embeddings request."""

    def _generate_message_id(self) -> str:
        # "embeddings-" followed by the first 8 characters of a fresh UUID4.
        suffix = str(uuid.uuid4())[:8]
        return f"embeddings-{suffix}"

View File

@@ -1,11 +1,10 @@
import json import json
import logging import logging
import uuid
from abc import ABC, abstractmethod
from typing import AsyncIterable from typing import AsyncIterable
import boto3 import boto3
from api.models.base import BaseChatModel, BaseEmbeddingsModel
from api.schema import ( from api.schema import (
# Chat # Chat
ChatResponse, ChatResponse,
@@ -50,28 +49,6 @@ SUPPORTED_BEDROCK_EMBEDDING_MODELS = {
"amazon.titan-embed-image-v1": "Titan Multimodal Embeddings G1" "amazon.titan-embed-image-v1": "Titan Multimodal Embeddings G1"
} }
class BaseChatModel(ABC):
"""Represent a basic chat model
Currently, only Bedrock model is supported, but may be used for SageMaker models if needed.
"""
@abstractmethod
def chat(self, chat_request: ChatRequest) -> ChatResponse:
"""Handle a basic chat completion requests."""
pass
@abstractmethod
def chat_stream(self, chat_request: ChatRequest) -> AsyncIterable[bytes]:
"""Handle a basic chat completion requests with stream response."""
pass
def _generate_message_id(self) -> str:
return "chatcmpl-" + str(uuid.uuid4())[:8]
def _stream_response_to_bytes(self, response: ChatStreamResponse) -> bytes:
return "data: {}\n\n".format(response.model_dump_json()).encode("utf-8")
# https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters.html # https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters.html
class BedrockModel(BaseChatModel): class BedrockModel(BaseChatModel):
@@ -98,12 +75,12 @@ class BedrockModel(BaseChatModel):
) )
def _create_response( def _create_response(
self, self,
model: str, model: str,
message: str, message: str,
message_id: str, message_id: str,
input_tokens: int = 0, input_tokens: int = 0,
output_tokens: int = 0, output_tokens: int = 0,
) -> ChatResponse: ) -> ChatResponse:
choice = Choice( choice = Choice(
index=0, index=0,
@@ -128,7 +105,7 @@ class BedrockModel(BaseChatModel):
return response return response
def _create_response_stream( def _create_response_stream(
self, model: str, message_id: str, chunk_message: str, finish_reason: str | None self, model: str, message_id: str, chunk_message: str, finish_reason: str | None
) -> ChatStreamResponse: ) -> ChatStreamResponse:
choice = ChoiceDelta( choice = ChoiceDelta(
index=0, index=0,
@@ -403,37 +380,15 @@ class MistralModel(BedrockModel):
yield self._stream_response_to_bytes(response) yield self._stream_response_to_bytes(response)
class BaseEmbeddingsModel(ABC):
"""Represents a basic embeddings model.
Currently, only Bedrock-provided models are supported, but it may be used for SageMaker models if needed.
"""
@abstractmethod
def embed(self, embeddings_request: EmbeddingsRequest) -> EmbeddingsResponse:
"""Handle a basic embeddings request."""
pass
def _generate_message_id(self) -> str:
return "embeddings-" + str(uuid.uuid4())[:8]
class BedrockEmbeddingsModel(BaseEmbeddingsModel): class BedrockEmbeddingsModel(BaseEmbeddingsModel):
accept = "application/json" accept = "application/json"
content_type = "application/json" content_type = "application/json"
def _invoke_model(self, args: dict, model_id: str, with_stream: bool = False): def _invoke_model(self, args: dict, model_id: str):
body = json.dumps(args) body = json.dumps(args)
if DEBUG: if DEBUG:
logger.info("Invoke Bedrock Model: " + model_id) logger.info("Invoke Bedrock Model: " + model_id)
logger.info("Bedrock request body: " + body) logger.info("Bedrock request body: " + body)
if with_stream:
return bedrock_runtime.invoke_model_with_response_stream(
body=body,
modelId=model_id,
accept=self.accept,
contentType=self.content_type,
)
return bedrock_runtime.invoke_model( return bedrock_runtime.invoke_model(
body=body, body=body,
modelId=model_id, modelId=model_id,
@@ -442,18 +397,18 @@ class BedrockEmbeddingsModel(BaseEmbeddingsModel):
) )
def _create_response( def _create_response(
self, self,
embeddings: list[float], embeddings: list[float],
model: str, model: str,
input_tokens: int = 0, input_tokens: int = 0,
output_tokens: int = 0, output_tokens: int = 0,
) -> EmbeddingsResponse: ) -> EmbeddingsResponse:
data = [ data = [
Embedding( Embedding(
index=i, index=i,
embedding=embedding embedding=embedding
) for i, embedding in enumerate(embeddings) ) for i, embedding in enumerate(embeddings)
] ]
response = EmbeddingsResponse( response = EmbeddingsResponse(
data=data, data=data,
model=model, model=model,
@@ -462,6 +417,7 @@ class BedrockEmbeddingsModel(BaseEmbeddingsModel):
total_tokens=input_tokens + output_tokens, total_tokens=input_tokens + output_tokens,
), ),
) )
if DEBUG: if DEBUG:
logger.info("Proxy response :" + response.model_dump_json()) logger.info("Proxy response :" + response.model_dump_json())
return response return response
@@ -487,10 +443,12 @@ class CohereEmbeddingsModel(BedrockEmbeddingsModel):
texts = [embeddings_request.input] texts = [embeddings_request.input]
elif isinstance(embeddings_request.input, list): elif isinstance(embeddings_request.input, list):
texts = embeddings_request.input texts = embeddings_request.input
# Maximum of 2048 characters
args = { args = {
"texts": texts, "texts": texts,
"input_type": embeddings_request.input_type if embeddings_request.input_type else "search_document", "input_type": "search_document",
"truncate": embeddings_request.truncate if embeddings_request.truncate else "NONE", "truncate": "END", # "NONE|START|END"
} }
return args return args
@@ -522,7 +480,8 @@ class TitanEmbeddingsModel(BedrockEmbeddingsModel):
# Note: inputImage is not supported! # Note: inputImage is not supported!
} }
if embeddings_request.model == "amazon.titan-embed-image-v1": if embeddings_request.model == "amazon.titan-embed-image-v1":
args["embeddingConfig"] = embeddings_request.embedding_config if embeddings_request.embedding_config else {"outputEmbeddingLength": 1024} args["embeddingConfig"] = embeddings_request.embedding_config if embeddings_request.embedding_config else {
"outputEmbeddingLength": 1024}
return args return args
def embed(self, embeddings_request: EmbeddingsRequest) -> EmbeddingsResponse: def embed(self, embeddings_request: EmbeddingsRequest) -> EmbeddingsResponse:

View File

@@ -12,7 +12,6 @@ router = APIRouter()
router = APIRouter( router = APIRouter(
prefix="/chat", prefix="/chat",
tags=["items"],
dependencies=[Depends(api_key_auth)], dependencies=[Depends(api_key_auth)],
# responses={404: {"description": "Not found"}}, # responses={404: {"description": "Not found"}},
) )

View File

@@ -11,7 +11,6 @@ router = APIRouter()
router = APIRouter( router = APIRouter(
prefix="/embeddings", prefix="/embeddings",
tags=["items"],
dependencies=[Depends(api_key_auth)], dependencies=[Depends(api_key_auth)],
) )
@@ -38,6 +37,7 @@ async def embeddings(
raise HTTPException(status_code=400, detail="Unsupported Model Id " + embeddings_request.model) raise HTTPException(status_code=400, detail="Unsupported Model Id " + embeddings_request.model)
try: try:
model = get_embeddings_model(embeddings_request.model) model = get_embeddings_model(embeddings_request.model)
# TODO: Check type of input
return model.embed(embeddings_request) return model.embed(embeddings_request)
except ValueError as e: except ValueError as e:
raise HTTPException(status_code=400, detail=str(e)) raise HTTPException(status_code=400, detail=str(e))

View File

@@ -1,5 +1,5 @@
import time import time
from typing import Literal from typing import Literal, Iterable
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
@@ -81,18 +81,11 @@ class ChatStreamResponse(BaseChatResponse):
class EmbeddingsRequest(BaseModel): class EmbeddingsRequest(BaseModel):
input: str | list[str] input: str | list[str] | Iterable[int] | Iterable[Iterable[int]]
model: str
# Cohere Embed
input_type: Literal["search_document", "search_query", "classification", "clustering"] | None = None
truncate: Literal["NONE", "LEFT", "RIGHT"] | None = None
# Titan Embeddings
embedding_config: dict | None = None
class BaseEmbeddingsResponse(BaseModel):
created: int = Field(default_factory=lambda: int(time.time()))
model: str model: str
encoding_format: Literal["float", "base64"] = "float" # not used.
dimensions: int | None = None # not used.
user: str | None = None # not used.
class Embedding(BaseModel): class Embedding(BaseModel):
@@ -106,7 +99,8 @@ class EmbeddingsUsage(BaseModel):
total_tokens: int total_tokens: int
class EmbeddingsResponse(BaseEmbeddingsResponse): class EmbeddingsResponse(BaseModel):
data: list[Embedding]
object: Literal["list"] = "list" object: Literal["list"] = "list"
usage: EmbeddingsUsage data: list[Embedding]
model: str
usage: EmbeddingsUsage