Added support for Cohere Embed and Titan Embeddings models

This commit is contained in:
Joao Galego
2024-03-27 17:27:23 +00:00
parent a26b8e6833
commit b24a43f6f4
7 changed files with 242 additions and 5 deletions

View File

@@ -7,7 +7,7 @@ from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import PlainTextResponse from fastapi.responses import PlainTextResponse
from mangum import Mangum from mangum import Mangum
from api.routers import model, chat from api.routers import model, chat, embeddings
from api.setting import API_ROUTE_PREFIX, TITLE, DESCRIPTION, SUMMARY, VERSION from api.setting import API_ROUTE_PREFIX, TITLE, DESCRIPTION, SUMMARY, VERSION
config = { config = {
@@ -33,6 +33,7 @@ app.add_middleware(
app.include_router(model.router, prefix=API_ROUTE_PREFIX) app.include_router(model.router, prefix=API_ROUTE_PREFIX)
app.include_router(chat.router, prefix=API_ROUTE_PREFIX) app.include_router(chat.router, prefix=API_ROUTE_PREFIX)
app.include_router(embeddings.router, prefix=API_ROUTE_PREFIX)
@app.get("/health") @app.get("/health")

View File

@@ -1 +1,7 @@
from api.models.bedrock import ClaudeModel, SUPPORTED_BEDROCK_MODELS, get_model from api.models.bedrock import (
ClaudeModel,
SUPPORTED_BEDROCK_MODELS,
SUPPORTED_BEDROCK_EMBEDDING_MODELS,
get_model,
get_embeddings_model,
)

View File

@@ -7,6 +7,7 @@ from typing import AsyncIterable
import boto3 import boto3
from api.schema import ( from api.schema import (
# Chat
ChatResponse, ChatResponse,
ChatRequest, ChatRequest,
ChatRequestMessage, ChatRequestMessage,
@@ -15,6 +16,11 @@ from api.schema import (
Usage, Usage,
ChatStreamResponse, ChatStreamResponse,
ChoiceDelta, ChoiceDelta,
# Embeddings
EmbeddingsRequest,
EmbeddingsResponse,
EmbeddingsUsage,
Embedding,
) )
from api.setting import DEBUG, AWS_REGION from api.setting import DEBUG, AWS_REGION
@@ -37,6 +43,12 @@ SUPPORTED_BEDROCK_MODELS = {
"mistral.mixtral-8x7b-instruct-v0:1": "Mixtral 8x7B Instruct", "mistral.mixtral-8x7b-instruct-v0:1": "Mixtral 8x7B Instruct",
} }
# Embedding models this proxy supports, keyed by Bedrock model ID.
# Values are human-readable display names; get_embeddings_model() uses this
# table to decide which handler class serves a given model ID.
SUPPORTED_BEDROCK_EMBEDDING_MODELS = {
    "cohere.embed-multilingual-v3": "Cohere Embed Multilingual",
    "cohere.embed-english-v3": "Cohere Embed English",
    "amazon.titan-embed-text-v1": "Titan Embeddings G1 - Text",
    "amazon.titan-embed-image-v1": "Titan Multimodal Embeddings G1"
}
class BaseChatModel(ABC): class BaseChatModel(ABC):
"""Represent a basic chat model """Represent a basic chat model
@@ -389,3 +401,136 @@ class MistralModel(BedrockModel):
finish_reason=chunk["outputs"][0]["stop_reason"], finish_reason=chunk["outputs"][0]["stop_reason"],
) )
yield self._stream_response_to_bytes(response) yield self._stream_response_to_bytes(response)
class BaseEmbeddingsModel(ABC):
    """Abstract interface for embeddings models.

    Only Bedrock-hosted models are implemented today, but the same
    interface could back SageMaker-hosted models later if needed.
    """

    @abstractmethod
    def embed(self, embeddings_request: EmbeddingsRequest) -> EmbeddingsResponse:
        """Turn one embeddings request into an embeddings response."""
        ...

    def _generate_message_id(self) -> str:
        # Short random identifier, e.g. "embeddings-1a2b3c4d".
        return f"embeddings-{uuid.uuid4().hex[:8]}"
class BedrockEmbeddingsModel(BaseEmbeddingsModel):
    """Base class for embeddings models served through Amazon Bedrock.

    Subclasses build the model-specific request body and translate the
    Bedrock output into an OpenAI-style EmbeddingsResponse.
    """

    # MIME types passed to the Bedrock InvokeModel API.
    accept = "application/json"
    content_type = "application/json"

    def _invoke_model(self, args: dict, model_id: str, with_stream: bool = False):
        """Serialize args to JSON and call Bedrock InvokeModel.

        Returns the raw boto3 response; callers read its "body" stream.
        with_stream selects invoke_model_with_response_stream instead.
        """
        body = json.dumps(args)
        if DEBUG:
            logger.info("Invoke Bedrock Model: " + model_id)
            logger.info("Bedrock request body: " + body)
        if with_stream:
            return bedrock_runtime.invoke_model_with_response_stream(
                body=body,
                modelId=model_id,
                accept=self.accept,
                contentType=self.content_type,
            )
        return bedrock_runtime.invoke_model(
            body=body,
            modelId=model_id,
            accept=self.accept,
            contentType=self.content_type,
        )

    def _create_response(
        self,
        embeddings: list[list[float]],  # one vector per input text
        model: str,
        input_tokens: int = 0,
        output_tokens: int = 0,
    ) -> EmbeddingsResponse:
        """Wrap raw embedding vectors in an OpenAI-style response.

        Token counts default to 0 because not every Bedrock model
        reports them (e.g. the Cohere path does not).
        """
        data = [
            Embedding(
                index=i,
                embedding=embedding
            ) for i, embedding in enumerate(embeddings)
        ]
        response = EmbeddingsResponse(
            data=data,
            model=model,
            usage=EmbeddingsUsage(
                prompt_tokens=input_tokens,
                total_tokens=input_tokens + output_tokens,
            ),
        )
        if DEBUG:
            logger.info("Proxy response :" + response.model_dump_json())
        return response
def get_embeddings_model(model_id: str) -> BedrockEmbeddingsModel:
    """Return the embeddings handler for a supported Bedrock model ID.

    Dispatches on the model-ID vendor prefix rather than the display name:
    display names in SUPPORTED_BEDROCK_EMBEDDING_MODELS are presentation
    data, and renaming one must not silently break routing.

    Raises:
        ValueError: if model_id is not a supported embeddings model.
    """
    model_name = SUPPORTED_BEDROCK_EMBEDDING_MODELS.get(model_id, "")
    if DEBUG:
        logger.info("model name is " + model_name)
    if model_id not in SUPPORTED_BEDROCK_EMBEDDING_MODELS:
        logger.error("Unsupported model id " + model_id)
        # Include the offending ID so API clients get an actionable error.
        raise ValueError("Unsupported embeddings model ID: " + model_id)
    if model_id.startswith("cohere.embed"):
        return CohereEmbeddingsModel()
    # All remaining supported IDs are amazon.titan-embed-* models.
    return TitanEmbeddingsModel()
class CohereEmbeddingsModel(BedrockEmbeddingsModel):
    """Embeddings via Cohere Embed models on Bedrock."""

    def _parse_args(self, embeddings_request: EmbeddingsRequest) -> dict:
        """Build the Cohere invoke-model request body.

        The OpenAI-style schema allows `input` to be a bare string, but
        Cohere's `texts` field must be a list of strings — wrap a string
        input in a single-element list instead of passing it through.
        """
        texts = embeddings_request.input
        if isinstance(texts, str):
            texts = [texts]
        return {
            "texts": texts,
            # Defaults: embed as documents, never truncate.
            "input_type": embeddings_request.input_type if embeddings_request.input_type else "search_document",
            "truncate": embeddings_request.truncate if embeddings_request.truncate else "NONE",
        }

    def embed(self, embeddings_request: EmbeddingsRequest) -> EmbeddingsResponse:
        """Invoke the model and convert the output to an OpenAI-style response.

        Cohere does not report token counts, so usage fields stay at 0.
        """
        response = self._invoke_model(
            args=self._parse_args(embeddings_request), model_id=embeddings_request.model
        )
        response_body = json.loads(response.get("body").read())
        if DEBUG:
            logger.info("Bedrock response body: " + str(response_body))
        return self._create_response(
            embeddings=response_body["embeddings"],
            model=embeddings_request.model,
        )
class TitanEmbeddingsModel(BedrockEmbeddingsModel):
    """Embeddings via Amazon Titan models on Bedrock."""

    def _parse_args(self, embeddings_request: EmbeddingsRequest) -> dict:
        """Build the Titan invoke-model request body.

        Titan accepts exactly one text per call, so only a bare string or
        a single-element list is allowed.
        """
        source = embeddings_request.input
        if isinstance(source, str):
            input_text = source
        elif isinstance(source, list) and len(source) == 1:
            input_text = source[0]
        else:
            raise ValueError("Amazon Titan Embeddings models support only single strings as input.")
        args = {
            "inputText": input_text,
            # Note: inputImage is not supported!
        }
        # Only the multimodal model takes an embeddingConfig.
        if embeddings_request.model == "amazon.titan-embed-image-v1":
            args["embeddingConfig"] = (
                embeddings_request.embedding_config
                if embeddings_request.embedding_config
                else {"outputEmbeddingLength": 1024}
            )
        return args

    def embed(self, embeddings_request: EmbeddingsRequest) -> EmbeddingsResponse:
        """Invoke Titan and convert the output to an OpenAI-style response."""
        raw_response = self._invoke_model(
            args=self._parse_args(embeddings_request), model_id=embeddings_request.model
        )
        response_body = json.loads(raw_response.get("body").read())
        if DEBUG:
            logger.info("Bedrock response body: " + str(response_body))
        # Titan returns one vector per call; _create_response expects a
        # list of vectors, so wrap it.
        return self._create_response(
            embeddings=[response_body["embedding"]],
            model=embeddings_request.model,
            input_tokens=response_body["inputTextTokenCount"]
        )

View File

@@ -0,0 +1,43 @@
from typing import Annotated
from fastapi import APIRouter, Depends, Body, HTTPException
from api.auth import api_key_auth
from api.models import get_embeddings_model, SUPPORTED_BEDROCK_EMBEDDING_MODELS
from api.schema import EmbeddingsRequest, EmbeddingsResponse
from api.setting import DEFAULT_EMBEDDING_MODEL
# NOTE: the original created a bare APIRouter() here and immediately
# shadowed it with the configured one below — the dead assignment is removed.
router = APIRouter(
    prefix="/embeddings",
    # was tags=["items"], a copy-paste leftover from the FastAPI tutorial
    tags=["embeddings"],
    dependencies=[Depends(api_key_auth)],
)


@router.post("/", response_model=EmbeddingsResponse)
async def embeddings(
    embeddings_request: Annotated[
        EmbeddingsRequest,
        Body(
            examples=[
                {
                    "model": "cohere.embed-multilingual-v3",
                    "input": [
                        "Your text string goes here"
                    ],
                }
            ],
        ),
    ]
):
    """OpenAI-compatible embeddings endpoint backed by Bedrock.

    OpenAI model names (``text-embedding-*``) are transparently mapped to
    the configured default Bedrock embedding model; unknown model IDs and
    invalid inputs produce HTTP 400.
    """
    if embeddings_request.model.lower().startswith("text-embedding-"):
        embeddings_request.model = DEFAULT_EMBEDDING_MODEL
    if embeddings_request.model not in SUPPORTED_BEDROCK_EMBEDDING_MODELS:
        raise HTTPException(status_code=400, detail="Unsupported Model Id " + embeddings_request.model)
    try:
        model = get_embeddings_model(embeddings_request.model)
        return model.embed(embeddings_request)
    except ValueError as e:
        # Model lookup and request parsing signal bad input via ValueError.
        raise HTTPException(status_code=400, detail=str(e))

View File

@@ -3,7 +3,7 @@ from typing import Annotated
from fastapi import APIRouter, Depends, HTTPException, Path from fastapi import APIRouter, Depends, HTTPException, Path
from api.auth import api_key_auth from api.auth import api_key_auth
from api.models import SUPPORTED_BEDROCK_MODELS from api.models import SUPPORTED_BEDROCK_MODELS, SUPPORTED_BEDROCK_EMBEDDING_MODELS
from api.schema import Models, Model from api.schema import Models, Model
router = APIRouter() router = APIRouter()
@@ -17,13 +17,13 @@ router = APIRouter(
async def validate_model_id(model_id: str): async def validate_model_id(model_id: str):
if model_id not in SUPPORTED_BEDROCK_MODELS.keys(): if model_id not in (SUPPORTED_BEDROCK_MODELS | SUPPORTED_BEDROCK_EMBEDDING_MODELS).keys():
raise HTTPException(status_code=400, detail="Unsupported Model Id") raise HTTPException(status_code=400, detail="Unsupported Model Id")
@router.get("/", response_model=Models) @router.get("/", response_model=Models)
async def list_models(): async def list_models():
model_list = [Model(id=model_id) for model_id in SUPPORTED_BEDROCK_MODELS.keys()] model_list = [Model(id=model_id) for model_id in (SUPPORTED_BEDROCK_MODELS | SUPPORTED_BEDROCK_EMBEDDING_MODELS).keys()]
return Models(data=model_list) return Models(data=model_list)

View File

@@ -78,3 +78,36 @@ class ChatResponse(BaseChatResponse):
class ChatStreamResponse(BaseChatResponse): class ChatStreamResponse(BaseChatResponse):
choices: list[ChoiceDelta] choices: list[ChoiceDelta]
object: Literal["chat.completion.chunk"] = "chat.completion.chunk" object: Literal["chat.completion.chunk"] = "chat.completion.chunk"
class EmbeddingsRequest(BaseModel):
    """OpenAI-style embeddings request with Bedrock-specific extras."""

    # A single text or a batch of texts to embed.
    input: str | list[str]
    # Bedrock model ID, e.g. "cohere.embed-multilingual-v3".
    model: str
    # Cohere Embed
    input_type: Literal["search_document", "search_query", "classification", "clustering"] | None = None
    truncate: Literal["NONE", "LEFT", "RIGHT"] | None = None
    # Titan Embeddings
    embedding_config: dict | None = None
class BaseEmbeddingsResponse(BaseModel):
    """Fields shared by all embeddings responses."""

    # Unix timestamp of when the response object was created.
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    # Static placeholder mirroring OpenAI's system_fingerprint field.
    system_fingerprint: str = "fp_e97c09dd4e26"
class Embedding(BaseModel):
    """One embedding vector plus its position in the input batch."""

    object: Literal["embedding"] = "embedding"
    embedding: list[float]
    # Index of the corresponding input text.
    index: int
class EmbeddingsUsage(BaseModel):
    """Token accounting for an embeddings call (OpenAI-compatible shape)."""

    prompt_tokens: int
    total_tokens: int
class EmbeddingsResponse(BaseEmbeddingsResponse):
    """Full embeddings response: one Embedding per input text."""

    data: list[Embedding]
    object: Literal["list"] = "list"
    usage: EmbeddingsUsage

View File

@@ -11,6 +11,8 @@ DESCRIPTION = """
Use OpenAI-Compatible RESTful APIs for Amazon Bedrock models. Use OpenAI-Compatible RESTful APIs for Amazon Bedrock models.
List of Amazon Bedrock models currently supported: List of Amazon Bedrock models currently supported:
# Chat
- anthropic.claude-instant-v1 - anthropic.claude-instant-v1
- anthropic.claude-v2:1 - anthropic.claude-v2:1
- anthropic.claude-v2 - anthropic.claude-v2
@@ -20,8 +22,15 @@ List of Amazon Bedrock models currently supported:
- meta.llama2-70b-chat-v1 - meta.llama2-70b-chat-v1
- mistral.mistral-7b-instruct-v0:2 - mistral.mistral-7b-instruct-v0:2
- mistral.mixtral-8x7b-instruct-v0:1 - mistral.mixtral-8x7b-instruct-v0:1
# Embeddings
- cohere.embed-multilingual-v3
- cohere.embed-english-v3
- amazon.titan-embed-text-v1
- amazon.titan-embed-image-v1
""" """
DEBUG = os.environ.get("DEBUG", "false").lower() != "false" DEBUG = os.environ.get("DEBUG", "false").lower() != "false"
AWS_REGION = os.environ.get("AWS_REGION", "us-west-2") AWS_REGION = os.environ.get("AWS_REGION", "us-west-2")
DEFAULT_MODEL = os.environ.get("DEFAULT_MODEL", "anthropic.claude-3-sonnet-20240229-v1:0") DEFAULT_MODEL = os.environ.get("DEFAULT_MODEL", "anthropic.claude-3-sonnet-20240229-v1:0")
DEFAULT_EMBEDDING_MODEL = os.environ.get("DEFAULT_EMBEDDING_MODEL", "cohere.embed-multilingual-v3")