Update embedding API

This commit is contained in:
Aiden Dai
2024-04-02 09:30:18 +08:00
parent a06703d92b
commit beed3b04b7
5 changed files with 86 additions and 83 deletions

View File

@@ -1,11 +1,10 @@
import json
import logging
import uuid
from abc import ABC, abstractmethod
from typing import AsyncIterable
import boto3
from api.models.base import BaseChatModel, BaseEmbeddingsModel
from api.schema import (
# Chat
ChatResponse,
@@ -50,28 +49,6 @@ SUPPORTED_BEDROCK_EMBEDDING_MODELS = {
"amazon.titan-embed-image-v1": "Titan Multimodal Embeddings G1"
}
class BaseChatModel(ABC):
"""Represent a basic chat model
Currently, only Bedrock model is supported, but may be used for SageMaker models if needed.
"""
@abstractmethod
def chat(self, chat_request: ChatRequest) -> ChatResponse:
"""Handle a basic chat completion requests."""
pass
@abstractmethod
def chat_stream(self, chat_request: ChatRequest) -> AsyncIterable[bytes]:
"""Handle a basic chat completion requests with stream response."""
pass
def _generate_message_id(self) -> str:
return "chatcmpl-" + str(uuid.uuid4())[:8]
def _stream_response_to_bytes(self, response: ChatStreamResponse) -> bytes:
return "data: {}\n\n".format(response.model_dump_json()).encode("utf-8")
# https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters.html
class BedrockModel(BaseChatModel):
@@ -98,12 +75,12 @@ class BedrockModel(BaseChatModel):
)
def _create_response(
self,
model: str,
message: str,
message_id: str,
input_tokens: int = 0,
output_tokens: int = 0,
self,
model: str,
message: str,
message_id: str,
input_tokens: int = 0,
output_tokens: int = 0,
) -> ChatResponse:
choice = Choice(
index=0,
@@ -128,7 +105,7 @@ class BedrockModel(BaseChatModel):
return response
def _create_response_stream(
self, model: str, message_id: str, chunk_message: str, finish_reason: str | None
self, model: str, message_id: str, chunk_message: str, finish_reason: str | None
) -> ChatStreamResponse:
choice = ChoiceDelta(
index=0,
@@ -403,37 +380,15 @@ class MistralModel(BedrockModel):
yield self._stream_response_to_bytes(response)
class BaseEmbeddingsModel(ABC):
"""Represents a basic embeddings model.
Currently, only Bedrock-provided models are supported, but it may be used for SageMaker models if needed.
"""
@abstractmethod
def embed(self, embeddings_request: EmbeddingsRequest) -> EmbeddingsResponse:
"""Handle a basic embeddings request."""
pass
def _generate_message_id(self) -> str:
return "embeddings-" + str(uuid.uuid4())[:8]
class BedrockEmbeddingsModel(BaseEmbeddingsModel):
accept = "application/json"
content_type = "application/json"
def _invoke_model(self, args: dict, model_id: str, with_stream: bool = False):
def _invoke_model(self, args: dict, model_id: str):
body = json.dumps(args)
if DEBUG:
logger.info("Invoke Bedrock Model: " + model_id)
logger.info("Bedrock request body: " + body)
if with_stream:
return bedrock_runtime.invoke_model_with_response_stream(
body=body,
modelId=model_id,
accept=self.accept,
contentType=self.content_type,
)
return bedrock_runtime.invoke_model(
body=body,
modelId=model_id,
@@ -442,18 +397,18 @@ class BedrockEmbeddingsModel(BaseEmbeddingsModel):
)
def _create_response(
self,
embeddings: list[float],
model: str,
input_tokens: int = 0,
output_tokens: int = 0,
self,
embeddings: list[float],
model: str,
input_tokens: int = 0,
output_tokens: int = 0,
) -> EmbeddingsResponse:
data = [
Embedding(
index=i,
embedding=embedding
) for i, embedding in enumerate(embeddings)
]
index=i,
embedding=embedding
) for i, embedding in enumerate(embeddings)
]
response = EmbeddingsResponse(
data=data,
model=model,
@@ -462,6 +417,7 @@ class BedrockEmbeddingsModel(BaseEmbeddingsModel):
total_tokens=input_tokens + output_tokens,
),
)
if DEBUG:
logger.info("Proxy response :" + response.model_dump_json())
return response
@@ -487,10 +443,12 @@ class CohereEmbeddingsModel(BedrockEmbeddingsModel):
texts = [embeddings_request.input]
elif isinstance(embeddings_request.input, list):
texts = embeddings_request.input
# Maximum of 2048 characters
args = {
"texts": texts,
"input_type": embeddings_request.input_type if embeddings_request.input_type else "search_document",
"truncate": embeddings_request.truncate if embeddings_request.truncate else "NONE",
"input_type": "search_document",
"truncate": "END", # "NONE|START|END"
}
return args
@@ -522,7 +480,8 @@ class TitanEmbeddingsModel(BedrockEmbeddingsModel):
# Note: inputImage is not supported!
}
if embeddings_request.model == "amazon.titan-embed-image-v1":
args["embeddingConfig"] = embeddings_request.embedding_config if embeddings_request.embedding_config else {"outputEmbeddingLength": 1024}
args["embeddingConfig"] = embeddings_request.embedding_config if embeddings_request.embedding_config else {
"outputEmbeddingLength": 1024}
return args
def embed(self, embeddings_request: EmbeddingsRequest) -> EmbeddingsResponse: