Update embedding API
This commit is contained in:
51
src/api/models/base.py
Normal file
51
src/api/models/base.py
Normal file
@@ -0,0 +1,51 @@
|
|||||||
|
import uuid
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import AsyncIterable
|
||||||
|
|
||||||
|
from api.schema import (
|
||||||
|
# Chat
|
||||||
|
ChatResponse,
|
||||||
|
ChatRequest,
|
||||||
|
ChatStreamResponse,
|
||||||
|
# Embeddings
|
||||||
|
EmbeddingsRequest,
|
||||||
|
EmbeddingsResponse,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class BaseChatModel(ABC):
    """Abstract interface implemented by every chat model backend.

    Only Bedrock models are supported at the moment; the same interface
    may be reused for SageMaker models if that becomes necessary.
    """

    @abstractmethod
    def chat(self, chat_request: ChatRequest) -> ChatResponse:
        """Process a non-streaming chat completion request.

        Args:
            chat_request: The incoming chat completion request.

        Returns:
            The complete chat response.
        """

    @abstractmethod
    def chat_stream(self, chat_request: ChatRequest) -> AsyncIterable[bytes]:
        """Process a chat completion request whose reply is streamed.

        Args:
            chat_request: The incoming chat completion request.

        Returns:
            An async iterable yielding the response as byte chunks.
        """

    def _generate_message_id(self) -> str:
        """Return ``"chatcmpl-"`` plus the first 8 hex digits of a random UUID4."""
        # The first 8 characters of str(uuid4) are exactly the first 8 hex
        # digits (the first dash sits at index 8), so .hex[:8] is equivalent.
        suffix = uuid.uuid4().hex[:8]
        return "chatcmpl-" + suffix

    def _stream_response_to_bytes(self, response: ChatStreamResponse) -> bytes:
        """Encode one stream chunk as a UTF-8 ``data: <json>`` frame.

        The chunk is serialized with its ``model_dump_json()`` method and
        terminated by a blank line.
        """
        payload = response.model_dump_json()
        frame = "data: {}\n\n".format(payload)
        return frame.encode("utf-8")
|
||||||
|
|
||||||
|
|
||||||
|
class BaseEmbeddingsModel(ABC):
    """Abstract interface implemented by every embeddings model backend.

    Only Bedrock-provided models are supported at the moment; the same
    interface may be reused for SageMaker models if that becomes necessary.
    """

    @abstractmethod
    def embed(self, embeddings_request: EmbeddingsRequest) -> EmbeddingsResponse:
        """Process an embeddings request.

        Args:
            embeddings_request: The incoming embeddings request.

        Returns:
            The embeddings response for the request.
        """

    def _generate_message_id(self) -> str:
        """Return ``"embeddings-"`` plus the first 8 hex digits of a random UUID4."""
        # Equivalent to str(uuid.uuid4())[:8]: the first dash of the
        # canonical form only appears at index 8.
        suffix = uuid.uuid4().hex[:8]
        return "embeddings-" + suffix
|
||||||
@@ -1,11 +1,10 @@
|
|||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import uuid
|
|
||||||
from abc import ABC, abstractmethod
|
|
||||||
from typing import AsyncIterable
|
from typing import AsyncIterable
|
||||||
|
|
||||||
import boto3
|
import boto3
|
||||||
|
|
||||||
|
from api.models.base import BaseChatModel, BaseEmbeddingsModel
|
||||||
from api.schema import (
|
from api.schema import (
|
||||||
# Chat
|
# Chat
|
||||||
ChatResponse,
|
ChatResponse,
|
||||||
@@ -50,28 +49,6 @@ SUPPORTED_BEDROCK_EMBEDDING_MODELS = {
|
|||||||
"amazon.titan-embed-image-v1": "Titan Multimodal Embeddings G1"
|
"amazon.titan-embed-image-v1": "Titan Multimodal Embeddings G1"
|
||||||
}
|
}
|
||||||
|
|
||||||
class BaseChatModel(ABC):
|
|
||||||
"""Represent a basic chat model
|
|
||||||
|
|
||||||
Currently, only Bedrock model is supported, but may be used for SageMaker models if needed.
|
|
||||||
"""
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def chat(self, chat_request: ChatRequest) -> ChatResponse:
|
|
||||||
"""Handle a basic chat completion requests."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def chat_stream(self, chat_request: ChatRequest) -> AsyncIterable[bytes]:
|
|
||||||
"""Handle a basic chat completion requests with stream response."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
def _generate_message_id(self) -> str:
|
|
||||||
return "chatcmpl-" + str(uuid.uuid4())[:8]
|
|
||||||
|
|
||||||
def _stream_response_to_bytes(self, response: ChatStreamResponse) -> bytes:
|
|
||||||
return "data: {}\n\n".format(response.model_dump_json()).encode("utf-8")
|
|
||||||
|
|
||||||
|
|
||||||
# https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters.html
|
# https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters.html
|
||||||
class BedrockModel(BaseChatModel):
|
class BedrockModel(BaseChatModel):
|
||||||
@@ -403,37 +380,15 @@ class MistralModel(BedrockModel):
|
|||||||
yield self._stream_response_to_bytes(response)
|
yield self._stream_response_to_bytes(response)
|
||||||
|
|
||||||
|
|
||||||
class BaseEmbeddingsModel(ABC):
|
|
||||||
"""Represents a basic embeddings model.
|
|
||||||
|
|
||||||
Currently, only Bedrock-provided models are supported, but it may be used for SageMaker models if needed.
|
|
||||||
"""
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def embed(self, embeddings_request: EmbeddingsRequest) -> EmbeddingsResponse:
|
|
||||||
"""Handle a basic embeddings request."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
def _generate_message_id(self) -> str:
|
|
||||||
return "embeddings-" + str(uuid.uuid4())[:8]
|
|
||||||
|
|
||||||
|
|
||||||
class BedrockEmbeddingsModel(BaseEmbeddingsModel):
|
class BedrockEmbeddingsModel(BaseEmbeddingsModel):
|
||||||
accept = "application/json"
|
accept = "application/json"
|
||||||
content_type = "application/json"
|
content_type = "application/json"
|
||||||
|
|
||||||
def _invoke_model(self, args: dict, model_id: str, with_stream: bool = False):
|
def _invoke_model(self, args: dict, model_id: str):
|
||||||
body = json.dumps(args)
|
body = json.dumps(args)
|
||||||
if DEBUG:
|
if DEBUG:
|
||||||
logger.info("Invoke Bedrock Model: " + model_id)
|
logger.info("Invoke Bedrock Model: " + model_id)
|
||||||
logger.info("Bedrock request body: " + body)
|
logger.info("Bedrock request body: " + body)
|
||||||
if with_stream:
|
|
||||||
return bedrock_runtime.invoke_model_with_response_stream(
|
|
||||||
body=body,
|
|
||||||
modelId=model_id,
|
|
||||||
accept=self.accept,
|
|
||||||
contentType=self.content_type,
|
|
||||||
)
|
|
||||||
return bedrock_runtime.invoke_model(
|
return bedrock_runtime.invoke_model(
|
||||||
body=body,
|
body=body,
|
||||||
modelId=model_id,
|
modelId=model_id,
|
||||||
@@ -462,6 +417,7 @@ class BedrockEmbeddingsModel(BaseEmbeddingsModel):
|
|||||||
total_tokens=input_tokens + output_tokens,
|
total_tokens=input_tokens + output_tokens,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
if DEBUG:
|
if DEBUG:
|
||||||
logger.info("Proxy response :" + response.model_dump_json())
|
logger.info("Proxy response :" + response.model_dump_json())
|
||||||
return response
|
return response
|
||||||
@@ -487,10 +443,12 @@ class CohereEmbeddingsModel(BedrockEmbeddingsModel):
|
|||||||
texts = [embeddings_request.input]
|
texts = [embeddings_request.input]
|
||||||
elif isinstance(embeddings_request.input, list):
|
elif isinstance(embeddings_request.input, list):
|
||||||
texts = embeddings_request.input
|
texts = embeddings_request.input
|
||||||
|
|
||||||
|
# Maximum of 2048 characters
|
||||||
args = {
|
args = {
|
||||||
"texts": texts,
|
"texts": texts,
|
||||||
"input_type": embeddings_request.input_type if embeddings_request.input_type else "search_document",
|
"input_type": "search_document",
|
||||||
"truncate": embeddings_request.truncate if embeddings_request.truncate else "NONE",
|
"truncate": "END", # "NONE|START|END"
|
||||||
}
|
}
|
||||||
return args
|
return args
|
||||||
|
|
||||||
@@ -522,7 +480,8 @@ class TitanEmbeddingsModel(BedrockEmbeddingsModel):
|
|||||||
# Note: inputImage is not supported!
|
# Note: inputImage is not supported!
|
||||||
}
|
}
|
||||||
if embeddings_request.model == "amazon.titan-embed-image-v1":
|
if embeddings_request.model == "amazon.titan-embed-image-v1":
|
||||||
args["embeddingConfig"] = embeddings_request.embedding_config if embeddings_request.embedding_config else {"outputEmbeddingLength": 1024}
|
args["embeddingConfig"] = embeddings_request.embedding_config if embeddings_request.embedding_config else {
|
||||||
|
"outputEmbeddingLength": 1024}
|
||||||
return args
|
return args
|
||||||
|
|
||||||
def embed(self, embeddings_request: EmbeddingsRequest) -> EmbeddingsResponse:
|
def embed(self, embeddings_request: EmbeddingsRequest) -> EmbeddingsResponse:
|
||||||
|
|||||||
@@ -12,7 +12,6 @@ router = APIRouter()
|
|||||||
|
|
||||||
router = APIRouter(
|
router = APIRouter(
|
||||||
prefix="/chat",
|
prefix="/chat",
|
||||||
tags=["items"],
|
|
||||||
dependencies=[Depends(api_key_auth)],
|
dependencies=[Depends(api_key_auth)],
|
||||||
# responses={404: {"description": "Not found"}},
|
# responses={404: {"description": "Not found"}},
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -11,7 +11,6 @@ router = APIRouter()
|
|||||||
|
|
||||||
router = APIRouter(
|
router = APIRouter(
|
||||||
prefix="/embeddings",
|
prefix="/embeddings",
|
||||||
tags=["items"],
|
|
||||||
dependencies=[Depends(api_key_auth)],
|
dependencies=[Depends(api_key_auth)],
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -38,6 +37,7 @@ async def embeddings(
|
|||||||
raise HTTPException(status_code=400, detail="Unsupported Model Id " + embeddings_request.model)
|
raise HTTPException(status_code=400, detail="Unsupported Model Id " + embeddings_request.model)
|
||||||
try:
|
try:
|
||||||
model = get_embeddings_model(embeddings_request.model)
|
model = get_embeddings_model(embeddings_request.model)
|
||||||
|
# TODO: Check type of input
|
||||||
return model.embed(embeddings_request)
|
return model.embed(embeddings_request)
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
raise HTTPException(status_code=400, detail=str(e))
|
raise HTTPException(status_code=400, detail=str(e))
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
import time
|
import time
|
||||||
from typing import Literal
|
from typing import Literal, Iterable
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
@@ -81,18 +81,11 @@ class ChatStreamResponse(BaseChatResponse):
|
|||||||
|
|
||||||
|
|
||||||
class EmbeddingsRequest(BaseModel):
|
class EmbeddingsRequest(BaseModel):
|
||||||
input: str | list[str]
|
input: str | list[str] | Iterable[int] | Iterable[Iterable[int]]
|
||||||
model: str
|
|
||||||
# Cohere Embed
|
|
||||||
input_type: Literal["search_document", "search_query", "classification", "clustering"] | None = None
|
|
||||||
truncate: Literal["NONE", "LEFT", "RIGHT"] | None = None
|
|
||||||
# Titan Embeddings
|
|
||||||
embedding_config: dict | None = None
|
|
||||||
|
|
||||||
|
|
||||||
class BaseEmbeddingsResponse(BaseModel):
|
|
||||||
created: int = Field(default_factory=lambda: int(time.time()))
|
|
||||||
model: str
|
model: str
|
||||||
|
encoding_format: Literal["float", "base64"] = "float" # not used.
|
||||||
|
dimensions: int | None = None # not used.
|
||||||
|
user: str | None = None # not used.
|
||||||
|
|
||||||
|
|
||||||
class Embedding(BaseModel):
|
class Embedding(BaseModel):
|
||||||
@@ -106,7 +99,8 @@ class EmbeddingsUsage(BaseModel):
|
|||||||
total_tokens: int
|
total_tokens: int
|
||||||
|
|
||||||
|
|
||||||
class EmbeddingsResponse(BaseEmbeddingsResponse):
|
class EmbeddingsResponse(BaseModel):
|
||||||
data: list[Embedding]
|
|
||||||
object: Literal["list"] = "list"
|
object: Literal["list"] = "list"
|
||||||
|
data: list[Embedding]
|
||||||
|
model: str
|
||||||
usage: EmbeddingsUsage
|
usage: EmbeddingsUsage
|
||||||
Reference in New Issue
Block a user