Added support for Cohere Embed and Titan Embeddings models

This commit is contained in:
Joao Galego
2024-03-27 17:27:23 +00:00
parent a26b8e6833
commit b24a43f6f4
7 changed files with 242 additions and 5 deletions

View File

@@ -7,7 +7,7 @@ from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import PlainTextResponse
from mangum import Mangum
from api.routers import model, chat
from api.routers import model, chat, embeddings
from api.setting import API_ROUTE_PREFIX, TITLE, DESCRIPTION, SUMMARY, VERSION
config = {
@@ -33,6 +33,7 @@ app.add_middleware(
app.include_router(model.router, prefix=API_ROUTE_PREFIX)
app.include_router(chat.router, prefix=API_ROUTE_PREFIX)
app.include_router(embeddings.router, prefix=API_ROUTE_PREFIX)
@app.get("/health")

View File

@@ -1 +1,7 @@
from api.models.bedrock import ClaudeModel, SUPPORTED_BEDROCK_MODELS, get_model
from api.models.bedrock import (
ClaudeModel,
SUPPORTED_BEDROCK_MODELS,
SUPPORTED_BEDROCK_EMBEDDING_MODELS,
get_model,
get_embeddings_model,
)

View File

@@ -7,6 +7,7 @@ from typing import AsyncIterable
import boto3
from api.schema import (
# Chat
ChatResponse,
ChatRequest,
ChatRequestMessage,
@@ -15,6 +16,11 @@ from api.schema import (
Usage,
ChatStreamResponse,
ChoiceDelta,
# Embeddings
EmbeddingsRequest,
EmbeddingsResponse,
EmbeddingsUsage,
Embedding,
)
from api.setting import DEBUG, AWS_REGION
@@ -37,6 +43,12 @@ SUPPORTED_BEDROCK_MODELS = {
"mistral.mixtral-8x7b-instruct-v0:1": "Mixtral 8x7B Instruct",
}
SUPPORTED_BEDROCK_EMBEDDING_MODELS = {
"cohere.embed-multilingual-v3": "Cohere Embed Multilingual",
"cohere.embed-english-v3": "Cohere Embed English",
"amazon.titan-embed-text-v1": "Titan Embeddings G1 - Text",
"amazon.titan-embed-image-v1": "Titan Multimodal Embeddings G1"
}
class BaseChatModel(ABC):
"""Represent a basic chat model
@@ -389,3 +401,136 @@ class MistralModel(BedrockModel):
finish_reason=chunk["outputs"][0]["stop_reason"],
)
yield self._stream_response_to_bytes(response)
class BaseEmbeddingsModel(ABC):
"""Represents a basic embeddings model.
Currently, only Bedrock-provided models are supported, but it may be used for SageMaker models if needed.
"""
@abstractmethod
def embed(self, embeddings_request: EmbeddingsRequest) -> EmbeddingsResponse:
"""Handle a basic embeddings request."""
pass
def _generate_message_id(self) -> str:
return "embeddings-" + str(uuid.uuid4())[:8]
class BedrockEmbeddingsModel(BaseEmbeddingsModel):
accept = "application/json"
content_type = "application/json"
def _invoke_model(self, args: dict, model_id: str, with_stream: bool = False):
body = json.dumps(args)
if DEBUG:
logger.info("Invoke Bedrock Model: " + model_id)
logger.info("Bedrock request body: " + body)
if with_stream:
return bedrock_runtime.invoke_model_with_response_stream(
body=body,
modelId=model_id,
accept=self.accept,
contentType=self.content_type,
)
return bedrock_runtime.invoke_model(
body=body,
modelId=model_id,
accept=self.accept,
contentType=self.content_type,
)
def _create_response(
self,
embeddings: list[float],
model: str,
input_tokens: int = 0,
output_tokens: int = 0,
) -> EmbeddingsResponse:
data = [
Embedding(
index=i,
embedding=embedding
) for i, embedding in enumerate(embeddings)
]
response = EmbeddingsResponse(
data=data,
model=model,
usage=EmbeddingsUsage(
prompt_tokens=input_tokens,
total_tokens=input_tokens + output_tokens,
),
)
if DEBUG:
logger.info("Proxy response :" + response.model_dump_json())
return response
def get_embeddings_model(model_id: str) -> BedrockEmbeddingsModel:
model_name = SUPPORTED_BEDROCK_EMBEDDING_MODELS.get(model_id, "")
if DEBUG:
logger.info("model name is " + model_name)
if model_name in ["Cohere Embed Multilingual", "Cohere Embed English"]:
return CohereEmbeddingsModel()
elif model_name in ["Titan Embeddings G1 - Text", "Titan Multimodal Embeddings G1"]:
return TitanEmbeddingsModel()
else:
logger.error("Unsupported model id " + model_id)
raise ValueError("Invalid model ID")
class CohereEmbeddingsModel(BedrockEmbeddingsModel):
def _parse_args(self, embeddings_request: EmbeddingsRequest) -> dict:
args = {
"texts": embeddings_request.input,
"input_type": embeddings_request.input_type if embeddings_request.input_type else "search_document",
"truncate": embeddings_request.truncate if embeddings_request.truncate else "NONE",
}
return args
def embed(self, embeddings_request: EmbeddingsRequest) -> EmbeddingsResponse:
response = self._invoke_model(
args=self._parse_args(embeddings_request), model_id=embeddings_request.model
)
response_body = json.loads(response.get("body").read())
if DEBUG:
logger.info("Bedrock response body: " + str(response_body))
return self._create_response(
embeddings=response_body["embeddings"],
model=embeddings_request.model,
)
class TitanEmbeddingsModel(BedrockEmbeddingsModel):
def _parse_args(self, embeddings_request: EmbeddingsRequest) -> dict:
if isinstance(embeddings_request.input, str):
input_text = embeddings_request.input
elif isinstance(embeddings_request.input, list) and len(embeddings_request.input) == 1:
input_text = embeddings_request.input[0]
else:
raise ValueError("Amazon Titan Embeddings models support only single strings as input.")
args = {
"inputText": input_text,
# Note: inputImage is not supported!
}
if embeddings_request.model == "amazon.titan-embed-image-v1":
args["embeddingConfig"] = embeddings_request.embedding_config if embeddings_request.embedding_config else {"outputEmbeddingLength": 1024}
return args
def embed(self, embeddings_request: EmbeddingsRequest) -> EmbeddingsResponse:
response = self._invoke_model(
args=self._parse_args(embeddings_request), model_id=embeddings_request.model
)
response_body = json.loads(response.get("body").read())
if DEBUG:
logger.info("Bedrock response body: " + str(response_body))
return self._create_response(
embeddings=[response_body["embedding"]],
model=embeddings_request.model,
input_tokens=response_body["inputTextTokenCount"]
)

View File

@@ -0,0 +1,43 @@
from typing import Annotated
from fastapi import APIRouter, Depends, Body, HTTPException
from api.auth import api_key_auth
from api.models import get_embeddings_model, SUPPORTED_BEDROCK_EMBEDDING_MODELS
from api.schema import EmbeddingsRequest, EmbeddingsResponse
from api.setting import DEFAULT_EMBEDDING_MODEL
router = APIRouter()
router = APIRouter(
prefix="/embeddings",
tags=["items"],
dependencies=[Depends(api_key_auth)],
)
@router.post("/", response_model=EmbeddingsResponse)
async def embeddings(
embeddings_request: Annotated[
EmbeddingsRequest,
Body(
examples=[
{
"model": "cohere.embed-multilingual-v3",
"input": [
"Your text string goes here"
],
}
],
),
]
):
if embeddings_request.model.lower().startswith("text-embedding-"):
embeddings_request.model = DEFAULT_EMBEDDING_MODEL
if embeddings_request.model not in SUPPORTED_BEDROCK_EMBEDDING_MODELS.keys():
raise HTTPException(status_code=400, detail="Unsupported Model Id " + embeddings_request.model)
try:
model = get_embeddings_model(embeddings_request.model)
return model.embed(embeddings_request)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))

View File

@@ -3,7 +3,7 @@ from typing import Annotated
from fastapi import APIRouter, Depends, HTTPException, Path
from api.auth import api_key_auth
from api.models import SUPPORTED_BEDROCK_MODELS
from api.models import SUPPORTED_BEDROCK_MODELS, SUPPORTED_BEDROCK_EMBEDDING_MODELS
from api.schema import Models, Model
router = APIRouter()
@@ -17,13 +17,13 @@ router = APIRouter(
async def validate_model_id(model_id: str):
if model_id not in SUPPORTED_BEDROCK_MODELS.keys():
if model_id not in (SUPPORTED_BEDROCK_MODELS | SUPPORTED_BEDROCK_EMBEDDING_MODELS).keys():
raise HTTPException(status_code=400, detail="Unsupported Model Id")
@router.get("/", response_model=Models)
async def list_models():
model_list = [Model(id=model_id) for model_id in SUPPORTED_BEDROCK_MODELS.keys()]
model_list = [Model(id=model_id) for model_id in (SUPPORTED_BEDROCK_MODELS | SUPPORTED_BEDROCK_EMBEDDING_MODELS).keys()]
return Models(data=model_list)

View File

@@ -78,3 +78,36 @@ class ChatResponse(BaseChatResponse):
class ChatStreamResponse(BaseChatResponse):
choices: list[ChoiceDelta]
object: Literal["chat.completion.chunk"] = "chat.completion.chunk"
class EmbeddingsRequest(BaseModel):
input: str | list[str]
model: str
# Cohere Embed
input_type: Literal["search_document", "search_query", "classification", "clustering"] | None = None
truncate: Literal["NONE", "LEFT", "RIGHT"] | None = None
# Titan Embeddings
embedding_config: dict | None = None
class BaseEmbeddingsResponse(BaseModel):
created: int = Field(default_factory=lambda: int(time.time()))
model: str
system_fingerprint: str = "fp_e97c09dd4e26"
class Embedding(BaseModel):
object: Literal["embedding"] = "embedding"
embedding: list[float]
index: int
class EmbeddingsUsage(BaseModel):
prompt_tokens: int
total_tokens: int
class EmbeddingsResponse(BaseEmbeddingsResponse):
data: list[Embedding]
object: Literal["list"] = "list"
usage: EmbeddingsUsage

View File

@@ -11,6 +11,8 @@ DESCRIPTION = """
Use OpenAI-Compatible RESTful APIs for Amazon Bedrock models.
List of Amazon Bedrock models currently supported:
# Chat
- anthropic.claude-instant-v1
- anthropic.claude-v2:1
- anthropic.claude-v2
@@ -20,8 +22,15 @@ List of Amazon Bedrock models currently supported:
- meta.llama2-70b-chat-v1
- mistral.mistral-7b-instruct-v0:2
- mistral.mixtral-8x7b-instruct-v0:1
# Embeddings
- cohere.embed-multilingual-v3
- cohere.embed-english-v3
- amazon.titan-embed-text-v1
- amazon.titan-embed-image-v1
"""
DEBUG = os.environ.get("DEBUG", "false").lower() != "false"
AWS_REGION = os.environ.get("AWS_REGION", "us-west-2")
DEFAULT_MODEL = os.environ.get("DEFAULT_MODEL", "anthropic.claude-3-sonnet-20240229-v1:0")
DEFAULT_EMBEDDING_MODEL = os.environ.get("DEFAULT_EMBEDDING_MODEL", "cohere.embed-multilingual-v3")