Update embedding API

This commit is contained in:
Aiden Dai
2024-04-02 09:30:18 +08:00
parent a06703d92b
commit beed3b04b7
5 changed files with 86 additions and 83 deletions

51
src/api/models/base.py Normal file
View File

@@ -0,0 +1,51 @@
import uuid
from abc import ABC, abstractmethod
from typing import AsyncIterable
from api.schema import (
# Chat
ChatResponse,
ChatRequest,
ChatStreamResponse,
# Embeddings
EmbeddingsRequest,
EmbeddingsResponse,
)
class BaseChatModel(ABC):
"""Represent a basic chat model
Currently, only Bedrock model is supported, but may be used for SageMaker models if needed.
"""
@abstractmethod
def chat(self, chat_request: ChatRequest) -> ChatResponse:
"""Handle a basic chat completion requests."""
pass
@abstractmethod
def chat_stream(self, chat_request: ChatRequest) -> AsyncIterable[bytes]:
"""Handle a basic chat completion requests with stream response."""
pass
def _generate_message_id(self) -> str:
return "chatcmpl-" + str(uuid.uuid4())[:8]
def _stream_response_to_bytes(self, response: ChatStreamResponse) -> bytes:
return "data: {}\n\n".format(response.model_dump_json()).encode("utf-8")
class BaseEmbeddingsModel(ABC):
"""Represents a basic embeddings model.
Currently, only Bedrock-provided models are supported, but it may be used for SageMaker models if needed.
"""
@abstractmethod
def embed(self, embeddings_request: EmbeddingsRequest) -> EmbeddingsResponse:
"""Handle a basic embeddings request."""
pass
def _generate_message_id(self) -> str:
return "embeddings-" + str(uuid.uuid4())[:8]

View File

@@ -1,11 +1,10 @@
import json
import logging
import uuid
from abc import ABC, abstractmethod
from typing import AsyncIterable
import boto3
from api.models.base import BaseChatModel, BaseEmbeddingsModel
from api.schema import (
# Chat
ChatResponse,
@@ -50,28 +49,6 @@ SUPPORTED_BEDROCK_EMBEDDING_MODELS = {
"amazon.titan-embed-image-v1": "Titan Multimodal Embeddings G1"
}
class BaseChatModel(ABC):
"""Represent a basic chat model
Currently, only Bedrock model is supported, but may be used for SageMaker models if needed.
"""
@abstractmethod
def chat(self, chat_request: ChatRequest) -> ChatResponse:
"""Handle a basic chat completion requests."""
pass
@abstractmethod
def chat_stream(self, chat_request: ChatRequest) -> AsyncIterable[bytes]:
"""Handle a basic chat completion requests with stream response."""
pass
def _generate_message_id(self) -> str:
return "chatcmpl-" + str(uuid.uuid4())[:8]
def _stream_response_to_bytes(self, response: ChatStreamResponse) -> bytes:
return "data: {}\n\n".format(response.model_dump_json()).encode("utf-8")
# https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters.html
class BedrockModel(BaseChatModel):
@@ -98,12 +75,12 @@ class BedrockModel(BaseChatModel):
)
def _create_response(
self,
model: str,
message: str,
message_id: str,
input_tokens: int = 0,
output_tokens: int = 0,
self,
model: str,
message: str,
message_id: str,
input_tokens: int = 0,
output_tokens: int = 0,
) -> ChatResponse:
choice = Choice(
index=0,
@@ -128,7 +105,7 @@ class BedrockModel(BaseChatModel):
return response
def _create_response_stream(
self, model: str, message_id: str, chunk_message: str, finish_reason: str | None
self, model: str, message_id: str, chunk_message: str, finish_reason: str | None
) -> ChatStreamResponse:
choice = ChoiceDelta(
index=0,
@@ -403,37 +380,15 @@ class MistralModel(BedrockModel):
yield self._stream_response_to_bytes(response)
class BaseEmbeddingsModel(ABC):
"""Represents a basic embeddings model.
Currently, only Bedrock-provided models are supported, but it may be used for SageMaker models if needed.
"""
@abstractmethod
def embed(self, embeddings_request: EmbeddingsRequest) -> EmbeddingsResponse:
"""Handle a basic embeddings request."""
pass
def _generate_message_id(self) -> str:
return "embeddings-" + str(uuid.uuid4())[:8]
class BedrockEmbeddingsModel(BaseEmbeddingsModel):
accept = "application/json"
content_type = "application/json"
def _invoke_model(self, args: dict, model_id: str, with_stream: bool = False):
def _invoke_model(self, args: dict, model_id: str):
body = json.dumps(args)
if DEBUG:
logger.info("Invoke Bedrock Model: " + model_id)
logger.info("Bedrock request body: " + body)
if with_stream:
return bedrock_runtime.invoke_model_with_response_stream(
body=body,
modelId=model_id,
accept=self.accept,
contentType=self.content_type,
)
return bedrock_runtime.invoke_model(
body=body,
modelId=model_id,
@@ -442,18 +397,18 @@ class BedrockEmbeddingsModel(BaseEmbeddingsModel):
)
def _create_response(
self,
embeddings: list[float],
model: str,
input_tokens: int = 0,
output_tokens: int = 0,
self,
embeddings: list[float],
model: str,
input_tokens: int = 0,
output_tokens: int = 0,
) -> EmbeddingsResponse:
data = [
Embedding(
index=i,
embedding=embedding
) for i, embedding in enumerate(embeddings)
]
index=i,
embedding=embedding
) for i, embedding in enumerate(embeddings)
]
response = EmbeddingsResponse(
data=data,
model=model,
@@ -462,6 +417,7 @@ class BedrockEmbeddingsModel(BaseEmbeddingsModel):
total_tokens=input_tokens + output_tokens,
),
)
if DEBUG:
logger.info("Proxy response :" + response.model_dump_json())
return response
@@ -487,10 +443,12 @@ class CohereEmbeddingsModel(BedrockEmbeddingsModel):
texts = [embeddings_request.input]
elif isinstance(embeddings_request.input, list):
texts = embeddings_request.input
# Maximum of 2048 characters
args = {
"texts": texts,
"input_type": embeddings_request.input_type if embeddings_request.input_type else "search_document",
"truncate": embeddings_request.truncate if embeddings_request.truncate else "NONE",
"input_type": "search_document",
"truncate": "END", # "NONE|START|END"
}
return args
@@ -522,7 +480,8 @@ class TitanEmbeddingsModel(BedrockEmbeddingsModel):
# Note: inputImage is not supported!
}
if embeddings_request.model == "amazon.titan-embed-image-v1":
args["embeddingConfig"] = embeddings_request.embedding_config if embeddings_request.embedding_config else {"outputEmbeddingLength": 1024}
args["embeddingConfig"] = embeddings_request.embedding_config if embeddings_request.embedding_config else {
"outputEmbeddingLength": 1024}
return args
def embed(self, embeddings_request: EmbeddingsRequest) -> EmbeddingsResponse:

View File

@@ -12,7 +12,6 @@ router = APIRouter()
router = APIRouter(
prefix="/chat",
tags=["items"],
dependencies=[Depends(api_key_auth)],
# responses={404: {"description": "Not found"}},
)

View File

@@ -11,7 +11,6 @@ router = APIRouter()
router = APIRouter(
prefix="/embeddings",
tags=["items"],
dependencies=[Depends(api_key_auth)],
)
@@ -38,6 +37,7 @@ async def embeddings(
raise HTTPException(status_code=400, detail="Unsupported Model Id " + embeddings_request.model)
try:
model = get_embeddings_model(embeddings_request.model)
# TODO: Check type of input
return model.embed(embeddings_request)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))

View File

@@ -1,5 +1,5 @@
import time
from typing import Literal
from typing import Literal, Iterable
from pydantic import BaseModel, Field
@@ -81,18 +81,11 @@ class ChatStreamResponse(BaseChatResponse):
class EmbeddingsRequest(BaseModel):
input: str | list[str]
model: str
# Cohere Embed
input_type: Literal["search_document", "search_query", "classification", "clustering"] | None = None
truncate: Literal["NONE", "LEFT", "RIGHT"] | None = None
# Titan Embeddings
embedding_config: dict | None = None
class BaseEmbeddingsResponse(BaseModel):
created: int = Field(default_factory=lambda: int(time.time()))
input: str | list[str] | Iterable[int] | Iterable[Iterable[int]]
model: str
encoding_format: Literal["float", "base64"] = "float" # not used.
dimensions: int | None = None # not used.
user: str | None = None # not used.
class Embedding(BaseModel):
@@ -106,7 +99,8 @@ class EmbeddingsUsage(BaseModel):
total_tokens: int
class EmbeddingsResponse(BaseEmbeddingsResponse):
data: list[Embedding]
class EmbeddingsResponse(BaseModel):
object: Literal["list"] = "list"
usage: EmbeddingsUsage
data: list[Embedding]
model: str
usage: EmbeddingsUsage