Added support for Cohere Embed and Titan Embeddings models
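
With this change the gateway exposes an OpenAI-compatible embeddings endpoint backed by Bedrock. A minimal client-side sketch (the base URL and API key are placeholders for your deployment; API_ROUTE_PREFIX is assumed to be /api/v1):

from openai import OpenAI

client = OpenAI(
    base_url="http://localhost:8000/api/v1",  # placeholder gateway URL + route prefix
    api_key="placeholder-api-key",
)

response = client.embeddings.create(
    # Any "text-embedding-*" alias is mapped to DEFAULT_EMBEDDING_MODEL by the router.
    model="cohere.embed-multilingual-v3",
    input=["Your text string goes here"],
)
print(len(response.data[0].embedding))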
src/api/app.py
@@ -7,7 +7,7 @@ from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import PlainTextResponse
 from mangum import Mangum
 
-from api.routers import model, chat
+from api.routers import model, chat, embeddings
 from api.setting import API_ROUTE_PREFIX, TITLE, DESCRIPTION, SUMMARY, VERSION
 
 config = {
@@ -33,6 +33,7 @@ app.add_middleware(
 
 app.include_router(model.router, prefix=API_ROUTE_PREFIX)
 app.include_router(chat.router, prefix=API_ROUTE_PREFIX)
+app.include_router(embeddings.router, prefix=API_ROUTE_PREFIX)
 
 
 @app.get("/health")
src/api/models/__init__.py
@@ -1 +1,7 @@
-from api.models.bedrock import ClaudeModel, SUPPORTED_BEDROCK_MODELS, get_model
+from api.models.bedrock import (
+    ClaudeModel,
+    SUPPORTED_BEDROCK_MODELS,
+    SUPPORTED_BEDROCK_EMBEDDING_MODELS,
+    get_model,
+    get_embeddings_model,
+)
src/api/models/bedrock.py
@@ -7,6 +7,7 @@ from typing import AsyncIterable
 import boto3
 
 from api.schema import (
+    # Chat
     ChatResponse,
     ChatRequest,
     ChatRequestMessage,
@@ -15,6 +16,11 @@ from api.schema import (
     Usage,
     ChatStreamResponse,
     ChoiceDelta,
+    # Embeddings
+    EmbeddingsRequest,
+    EmbeddingsResponse,
+    EmbeddingsUsage,
+    Embedding,
 )
 from api.setting import DEBUG, AWS_REGION
@@ -37,6 +43,12 @@ SUPPORTED_BEDROCK_MODELS = {
     "mistral.mixtral-8x7b-instruct-v0:1": "Mixtral 8x7B Instruct",
 }
 
+SUPPORTED_BEDROCK_EMBEDDING_MODELS = {
+    "cohere.embed-multilingual-v3": "Cohere Embed Multilingual",
+    "cohere.embed-english-v3": "Cohere Embed English",
+    "amazon.titan-embed-text-v1": "Titan Embeddings G1 - Text",
+    "amazon.titan-embed-image-v1": "Titan Multimodal Embeddings G1",
+}
 
 class BaseChatModel(ABC):
     """Represent a basic chat model
@@ -389,3 +401,136 @@ class MistralModel(BedrockModel):
                 finish_reason=chunk["outputs"][0]["stop_reason"],
             )
             yield self._stream_response_to_bytes(response)
+
+
+class BaseEmbeddingsModel(ABC):
+    """Represents a basic embeddings model.
+
+    Currently, only Bedrock-provided models are supported, but this could be extended to SageMaker models if needed.
+    """
+
+    @abstractmethod
+    def embed(self, embeddings_request: EmbeddingsRequest) -> EmbeddingsResponse:
+        """Handle a basic embeddings request."""
+        pass
+
+    def _generate_message_id(self) -> str:
+        return "embeddings-" + str(uuid.uuid4())[:8]
+
+
+class BedrockEmbeddingsModel(BaseEmbeddingsModel):
+    accept = "application/json"
+    content_type = "application/json"
+
+    def _invoke_model(self, args: dict, model_id: str, with_stream: bool = False):
+        body = json.dumps(args)
+        if DEBUG:
+            logger.info("Invoke Bedrock Model: " + model_id)
+            logger.info("Bedrock request body: " + body)
+        if with_stream:
+            return bedrock_runtime.invoke_model_with_response_stream(
+                body=body,
+                modelId=model_id,
+                accept=self.accept,
+                contentType=self.content_type,
+            )
+        return bedrock_runtime.invoke_model(
+            body=body,
+            modelId=model_id,
+            accept=self.accept,
+            contentType=self.content_type,
+        )
+
+    def _create_response(
+        self,
+        embeddings: list[float],
+        model: str,
+        input_tokens: int = 0,
+        output_tokens: int = 0,
+    ) -> EmbeddingsResponse:
+        data = [
+            Embedding(
+                index=i,
+                embedding=embedding,
+            ) for i, embedding in enumerate(embeddings)
+        ]
+        response = EmbeddingsResponse(
+            data=data,
+            model=model,
+            usage=EmbeddingsUsage(
+                prompt_tokens=input_tokens,
+                total_tokens=input_tokens + output_tokens,
+            ),
+        )
+        if DEBUG:
+            logger.info("Proxy response: " + response.model_dump_json())
+        return response
+
+
+def get_embeddings_model(model_id: str) -> BedrockEmbeddingsModel:
+    model_name = SUPPORTED_BEDROCK_EMBEDDING_MODELS.get(model_id, "")
+    if DEBUG:
+        logger.info("model name is " + model_name)
+    if model_name in ["Cohere Embed Multilingual", "Cohere Embed English"]:
+        return CohereEmbeddingsModel()
+    elif model_name in ["Titan Embeddings G1 - Text", "Titan Multimodal Embeddings G1"]:
+        return TitanEmbeddingsModel()
+    else:
+        logger.error("Unsupported model id " + model_id)
+        raise ValueError("Invalid model ID")
+
+
+class CohereEmbeddingsModel(BedrockEmbeddingsModel):
+
+    def _parse_args(self, embeddings_request: EmbeddingsRequest) -> dict:
+        args = {
+            "texts": embeddings_request.input,
+            "input_type": embeddings_request.input_type if embeddings_request.input_type else "search_document",
+            "truncate": embeddings_request.truncate if embeddings_request.truncate else "NONE",
+        }
+        return args
+
+    def embed(self, embeddings_request: EmbeddingsRequest) -> EmbeddingsResponse:
+        response = self._invoke_model(
+            args=self._parse_args(embeddings_request), model_id=embeddings_request.model
+        )
+        response_body = json.loads(response.get("body").read())
+        if DEBUG:
+            logger.info("Bedrock response body: " + str(response_body))
+
+        return self._create_response(
+            embeddings=response_body["embeddings"],
+            model=embeddings_request.model,
+        )
+
+
+class TitanEmbeddingsModel(BedrockEmbeddingsModel):
+
+    def _parse_args(self, embeddings_request: EmbeddingsRequest) -> dict:
+        if isinstance(embeddings_request.input, str):
+            input_text = embeddings_request.input
+        elif isinstance(embeddings_request.input, list) and len(embeddings_request.input) == 1:
+            input_text = embeddings_request.input[0]
+        else:
+            raise ValueError("Amazon Titan Embeddings models support only a single string as input.")
+        args = {
+            "inputText": input_text,
+            # Note: inputImage is not supported!
+        }
+        if embeddings_request.model == "amazon.titan-embed-image-v1":
+            args["embeddingConfig"] = embeddings_request.embedding_config if embeddings_request.embedding_config else {"outputEmbeddingLength": 1024}
+        return args
+
+    def embed(self, embeddings_request: EmbeddingsRequest) -> EmbeddingsResponse:
+        response = self._invoke_model(
+            args=self._parse_args(embeddings_request), model_id=embeddings_request.model
+        )
+        response_body = json.loads(response.get("body").read())
+        if DEBUG:
+            logger.info("Bedrock response body: " + str(response_body))
+
+        return self._create_response(
+            embeddings=[response_body["embedding"]],
+            model=embeddings_request.model,
+            input_tokens=response_body["inputTextTokenCount"],
+        )
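
For reference, a minimal sketch of how the pieces above compose (all names come from this diff; running it assumes AWS credentials with Bedrock access, and the input string is illustrative):

from api.models.bedrock import get_embeddings_model
from api.schema import EmbeddingsRequest

request = EmbeddingsRequest(
    model="amazon.titan-embed-text-v1",
    input="Your text string goes here",  # Titan models accept a single string only
)
model = get_embeddings_model(request.model)  # dispatches to TitanEmbeddingsModel
response = model.embed(request)              # calls bedrock_runtime.invoke_model and wraps the result
print(response.usage.prompt_tokens, len(response.data[0].embedding))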
src/api/routers/embeddings.py (new file, 43 lines)
@@ -0,0 +1,43 @@
+from typing import Annotated
+
+from fastapi import APIRouter, Depends, Body, HTTPException
+
+from api.auth import api_key_auth
+from api.models import get_embeddings_model, SUPPORTED_BEDROCK_EMBEDDING_MODELS
+from api.schema import EmbeddingsRequest, EmbeddingsResponse
+from api.setting import DEFAULT_EMBEDDING_MODEL
+
+router = APIRouter()
+
+router = APIRouter(
+    prefix="/embeddings",
+    tags=["items"],
+    dependencies=[Depends(api_key_auth)],
+)
+
+
+@router.post("/", response_model=EmbeddingsResponse)
+async def embeddings(
+    embeddings_request: Annotated[
+        EmbeddingsRequest,
+        Body(
+            examples=[
+                {
+                    "model": "cohere.embed-multilingual-v3",
+                    "input": [
+                        "Your text string goes here"
+                    ],
+                }
+            ],
+        ),
+    ]
+):
+    if embeddings_request.model.lower().startswith("text-embedding-"):
+        embeddings_request.model = DEFAULT_EMBEDDING_MODEL
+    if embeddings_request.model not in SUPPORTED_BEDROCK_EMBEDDING_MODELS.keys():
+        raise HTTPException(status_code=400, detail="Unsupported Model Id " + embeddings_request.model)
+    try:
+        model = get_embeddings_model(embeddings_request.model)
+        return model.embed(embeddings_request)
+    except ValueError as e:
+        raise HTTPException(status_code=400, detail=str(e))
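
Once mounted under API_ROUTE_PREFIX, the endpoint speaks the OpenAI wire format. A hedged sketch of a raw HTTP call (host, prefix, and key below are placeholders; the prefix is assumed to be /api/v1):

import requests

resp = requests.post(
    "http://localhost:8000/api/v1/embeddings",  # placeholder host; assumed route prefix
    headers={"Authorization": "Bearer placeholder-api-key"},
    json={
        "model": "text-embedding-ada-002",  # OpenAI-style names fall back to DEFAULT_EMBEDDING_MODEL
        "input": ["Your text string goes here"],
    },
    timeout=30,
)
resp.raise_for_status()
print(resp.json()["data"][0]["embedding"][:5])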
src/api/routers/model.py
@@ -3,7 +3,7 @@ from typing import Annotated
 from fastapi import APIRouter, Depends, HTTPException, Path
 
 from api.auth import api_key_auth
-from api.models import SUPPORTED_BEDROCK_MODELS
+from api.models import SUPPORTED_BEDROCK_MODELS, SUPPORTED_BEDROCK_EMBEDDING_MODELS
 from api.schema import Models, Model
 
 router = APIRouter()
@@ -17,13 +17,13 @@ router = APIRouter(
 
 
 async def validate_model_id(model_id: str):
-    if model_id not in SUPPORTED_BEDROCK_MODELS.keys():
+    if model_id not in (SUPPORTED_BEDROCK_MODELS | SUPPORTED_BEDROCK_EMBEDDING_MODELS).keys():
         raise HTTPException(status_code=400, detail="Unsupported Model Id")
 
 
 @router.get("/", response_model=Models)
 async def list_models():
-    model_list = [Model(id=model_id) for model_id in SUPPORTED_BEDROCK_MODELS.keys()]
+    model_list = [Model(id=model_id) for model_id in (SUPPORTED_BEDROCK_MODELS | SUPPORTED_BEDROCK_EMBEDDING_MODELS).keys()]
     return Models(data=model_list)
src/api/schema.py
@@ -78,3 +78,36 @@ class ChatResponse(BaseChatResponse):
 class ChatStreamResponse(BaseChatResponse):
     choices: list[ChoiceDelta]
     object: Literal["chat.completion.chunk"] = "chat.completion.chunk"
+
+
+class EmbeddingsRequest(BaseModel):
+    input: str | list[str]
+    model: str
+    # Cohere Embed
+    input_type: Literal["search_document", "search_query", "classification", "clustering"] | None = None
+    truncate: Literal["NONE", "LEFT", "RIGHT"] | None = None
+    # Titan Embeddings
+    embedding_config: dict | None = None
+
+
+class BaseEmbeddingsResponse(BaseModel):
+    created: int = Field(default_factory=lambda: int(time.time()))
+    model: str
+    system_fingerprint: str = "fp_e97c09dd4e26"
+
+
+class Embedding(BaseModel):
+    object: Literal["embedding"] = "embedding"
+    embedding: list[float]
+    index: int
+
+
+class EmbeddingsUsage(BaseModel):
+    prompt_tokens: int
+    total_tokens: int
+
+
+class EmbeddingsResponse(BaseEmbeddingsResponse):
+    data: list[Embedding]
+    object: Literal["list"] = "list"
+    usage: EmbeddingsUsage
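
The resulting wire format matches OpenAI's embeddings response; an illustrative sketch of one serialized instance (the vector and token counts are made up):

from api.schema import Embedding, EmbeddingsResponse, EmbeddingsUsage

resp = EmbeddingsResponse(
    data=[Embedding(index=0, embedding=[0.12, -0.03, 0.88])],  # made-up vector
    model="cohere.embed-english-v3",
    usage=EmbeddingsUsage(prompt_tokens=7, total_tokens=7),
)
print(resp.model_dump_json(indent=2))  # "object": "list", "data": [...], "usage": {...}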
src/api/setting.py
@@ -11,6 +11,8 @@ DESCRIPTION = """
 Use OpenAI-Compatible RESTful APIs for Amazon Bedrock models.
 
 List of Amazon Bedrock models currently supported:
+
+# Chat
 - anthropic.claude-instant-v1
 - anthropic.claude-v2:1
 - anthropic.claude-v2
@@ -20,8 +22,15 @@ List of Amazon Bedrock models currently supported:
 - meta.llama2-70b-chat-v1
 - mistral.mistral-7b-instruct-v0:2
 - mistral.mixtral-8x7b-instruct-v0:1
+
+# Embeddings
+- cohere.embed-multilingual-v3
+- cohere.embed-english-v3
+- amazon.titan-embed-text-v1
+- amazon.titan-embed-image-v1
 """
 
 DEBUG = os.environ.get("DEBUG", "false").lower() != "false"
 AWS_REGION = os.environ.get("AWS_REGION", "us-west-2")
 DEFAULT_MODEL = os.environ.get("DEFAULT_MODEL", "anthropic.claude-3-sonnet-20240229-v1:0")
+DEFAULT_EMBEDDING_MODEL = os.environ.get("DEFAULT_EMBEDDING_MODEL", "cohere.embed-multilingual-v3")