Added support for Cohere Embed and Titan Embeddings models
This commit is contained in:
@@ -7,7 +7,7 @@ from fastapi.middleware.cors import CORSMiddleware
|
|||||||
from fastapi.responses import PlainTextResponse
|
from fastapi.responses import PlainTextResponse
|
||||||
from mangum import Mangum
|
from mangum import Mangum
|
||||||
|
|
||||||
from api.routers import model, chat
|
from api.routers import model, chat, embeddings
|
||||||
from api.setting import API_ROUTE_PREFIX, TITLE, DESCRIPTION, SUMMARY, VERSION
|
from api.setting import API_ROUTE_PREFIX, TITLE, DESCRIPTION, SUMMARY, VERSION
|
||||||
|
|
||||||
config = {
|
config = {
|
||||||
@@ -33,6 +33,7 @@ app.add_middleware(
|
|||||||
|
|
||||||
app.include_router(model.router, prefix=API_ROUTE_PREFIX)
|
app.include_router(model.router, prefix=API_ROUTE_PREFIX)
|
||||||
app.include_router(chat.router, prefix=API_ROUTE_PREFIX)
|
app.include_router(chat.router, prefix=API_ROUTE_PREFIX)
|
||||||
|
app.include_router(embeddings.router, prefix=API_ROUTE_PREFIX)
|
||||||
|
|
||||||
|
|
||||||
@app.get("/health")
|
@app.get("/health")
|
||||||
|
|||||||
@@ -1 +1,7 @@
|
|||||||
from api.models.bedrock import ClaudeModel, SUPPORTED_BEDROCK_MODELS, get_model
|
from api.models.bedrock import (
|
||||||
|
ClaudeModel,
|
||||||
|
SUPPORTED_BEDROCK_MODELS,
|
||||||
|
SUPPORTED_BEDROCK_EMBEDDING_MODELS,
|
||||||
|
get_model,
|
||||||
|
get_embeddings_model,
|
||||||
|
)
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ from typing import AsyncIterable
|
|||||||
import boto3
|
import boto3
|
||||||
|
|
||||||
from api.schema import (
|
from api.schema import (
|
||||||
|
# Chat
|
||||||
ChatResponse,
|
ChatResponse,
|
||||||
ChatRequest,
|
ChatRequest,
|
||||||
ChatRequestMessage,
|
ChatRequestMessage,
|
||||||
@@ -15,6 +16,11 @@ from api.schema import (
|
|||||||
Usage,
|
Usage,
|
||||||
ChatStreamResponse,
|
ChatStreamResponse,
|
||||||
ChoiceDelta,
|
ChoiceDelta,
|
||||||
|
# Embeddings
|
||||||
|
EmbeddingsRequest,
|
||||||
|
EmbeddingsResponse,
|
||||||
|
EmbeddingsUsage,
|
||||||
|
Embedding,
|
||||||
)
|
)
|
||||||
from api.setting import DEBUG, AWS_REGION
|
from api.setting import DEBUG, AWS_REGION
|
||||||
|
|
||||||
@@ -37,6 +43,12 @@ SUPPORTED_BEDROCK_MODELS = {
|
|||||||
"mistral.mixtral-8x7b-instruct-v0:1": "Mixtral 8x7B Instruct",
|
"mistral.mixtral-8x7b-instruct-v0:1": "Mixtral 8x7B Instruct",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# Embedding models that can be requested, keyed by Bedrock model id.
# The value is the display name; get_embeddings_model() selects the adapter
# class by comparing against these display names.
SUPPORTED_BEDROCK_EMBEDDING_MODELS = {
    # Cohere Embed
    "cohere.embed-multilingual-v3": "Cohere Embed Multilingual",
    "cohere.embed-english-v3": "Cohere Embed English",
    # Amazon Titan
    "amazon.titan-embed-text-v1": "Titan Embeddings G1 - Text",
    "amazon.titan-embed-image-v1": "Titan Multimodal Embeddings G1",
}
|
||||||
|
|
||||||
class BaseChatModel(ABC):
|
class BaseChatModel(ABC):
|
||||||
"""Represent a basic chat model
|
"""Represent a basic chat model
|
||||||
@@ -389,3 +401,136 @@ class MistralModel(BedrockModel):
|
|||||||
finish_reason=chunk["outputs"][0]["stop_reason"],
|
finish_reason=chunk["outputs"][0]["stop_reason"],
|
||||||
)
|
)
|
||||||
yield self._stream_response_to_bytes(response)
|
yield self._stream_response_to_bytes(response)
|
||||||
|
|
||||||
|
|
||||||
|
class BaseEmbeddingsModel(ABC):
    """Abstract interface every embeddings backend implements.

    Only Bedrock-provided models implement it at the moment; it may also be
    used for SageMaker models if needed.
    """

    @abstractmethod
    def embed(self, embeddings_request: EmbeddingsRequest) -> EmbeddingsResponse:
        """Handle an embeddings request and return the embeddings response.

        Concrete subclasses must override this method.
        """

    def _generate_message_id(self) -> str:
        """Return ``"embeddings-"`` followed by the first 8 characters of a random UUID4."""
        short_uuid = str(uuid.uuid4())[:8]
        return f"embeddings-{short_uuid}"
|
||||||
|
|
||||||
|
|
||||||
|
class BedrockEmbeddingsModel(BaseEmbeddingsModel):
    """Shared plumbing for embeddings models invoked through ``bedrock_runtime``.

    Subclasses build the provider-specific request body, call ``_invoke_model``
    and hand the vectors they extract from the reply to ``_create_response``.
    """

    # Sent as ``accept`` / ``contentType`` on every invocation.
    accept = "application/json"
    content_type = "application/json"

    def _invoke_model(self, args: dict, model_id: str, with_stream: bool = False):
        """JSON-encode ``args`` and invoke ``model_id`` on the Bedrock runtime client.

        Args:
            args: Request body; serialized with ``json.dumps`` before sending.
            model_id: Bedrock model identifier passed as ``modelId``.
            with_stream: When true, ``invoke_model_with_response_stream`` is
                called instead of ``invoke_model``.

        Returns:
            Whatever the selected ``bedrock_runtime`` call returns.
        """
        body = json.dumps(args)
        if DEBUG:
            logger.info("Invoke Bedrock Model: " + model_id)
            logger.info("Bedrock request body: " + body)
        # Both client calls take the same keyword arguments, so only the
        # callable differs between the streaming and non-streaming paths.
        if with_stream:
            invoke = bedrock_runtime.invoke_model_with_response_stream
        else:
            invoke = bedrock_runtime.invoke_model
        return invoke(
            body=body,
            modelId=model_id,
            accept=self.accept,
            contentType=self.content_type,
        )

    def _create_response(
        self,
        embeddings: list[list[float]],
        model: str,
        input_tokens: int = 0,
        output_tokens: int = 0,
    ) -> EmbeddingsResponse:
        """Wrap embedding vectors in an ``EmbeddingsResponse``.

        Args:
            embeddings: One vector per input, in input order. Each vector
                becomes an ``Embedding`` whose ``index`` is its position here.
            model: Model id echoed back in the response.
            input_tokens: Reported as ``prompt_tokens``.
            output_tokens: Added to ``input_tokens`` to form ``total_tokens``.

        Returns:
            The assembled response.
        """
        data = [
            Embedding(index=position, embedding=vector)
            for position, vector in enumerate(embeddings)
        ]
        response = EmbeddingsResponse(
            data=data,
            model=model,
            usage=EmbeddingsUsage(
                prompt_tokens=input_tokens,
                total_tokens=input_tokens + output_tokens,
            ),
        )
        if DEBUG:
            logger.info("Proxy response :" + response.model_dump_json())
        return response
|
||||||
|
|
||||||
|
|
||||||
|
def get_embeddings_model(model_id: str) -> BedrockEmbeddingsModel:
    """Return a new embeddings adapter for ``model_id``.

    The id is looked up in ``SUPPORTED_BEDROCK_EMBEDDING_MODELS`` and the
    adapter class is chosen from the display name found there.

    Args:
        model_id: Bedrock model identifier.

    Returns:
        A ``CohereEmbeddingsModel`` or ``TitanEmbeddingsModel`` instance.

    Raises:
        ValueError: If ``model_id`` does not map to a known display name.
    """
    # Unknown ids resolve to "" and therefore fall through to the error path.
    display_name = SUPPORTED_BEDROCK_EMBEDDING_MODELS.get(model_id, "")
    if DEBUG:
        logger.info("model name is " + display_name)

    if display_name in ("Cohere Embed Multilingual", "Cohere Embed English"):
        return CohereEmbeddingsModel()
    if display_name in ("Titan Embeddings G1 - Text", "Titan Multimodal Embeddings G1"):
        return TitanEmbeddingsModel()

    logger.error("Unsupported model id " + model_id)
    raise ValueError("Invalid model ID")
|
||||||
|
|
||||||
|
|
||||||
|
class CohereEmbeddingsModel(BedrockEmbeddingsModel):
    """Adapter for the Cohere Embed models."""

    def _parse_args(self, embeddings_request: EmbeddingsRequest) -> dict:
        """Build the Cohere Embed request body from ``embeddings_request``.

        Args:
            embeddings_request: Incoming request. ``input`` may be a single
                string or a list of strings.

        Returns:
            A dict with ``texts``, ``input_type`` (default
            ``"search_document"``) and ``truncate`` (default ``"NONE"``).
        """
        texts = embeddings_request.input
        if isinstance(texts, str):
            # The request schema accepts a bare string, but Cohere Embed takes
            # an array of strings under "texts"; wrap it instead of sending a
            # scalar the model would reject.
            texts = [texts]
        return {
            "texts": texts,
            "input_type": embeddings_request.input_type or "search_document",
            "truncate": embeddings_request.truncate or "NONE",
        }

    def embed(self, embeddings_request: EmbeddingsRequest) -> EmbeddingsResponse:
        """Invoke the model and return its vectors as an ``EmbeddingsResponse``.

        The vectors are read from the ``"embeddings"`` key of the reply body.
        No token counts are passed on, so usage is reported with the defaults
        of ``_create_response``.
        """
        response = self._invoke_model(
            args=self._parse_args(embeddings_request),
            model_id=embeddings_request.model,
        )
        response_body = json.loads(response.get("body").read())
        if DEBUG:
            logger.info("Bedrock response body: " + str(response_body))

        return self._create_response(
            embeddings=response_body["embeddings"],
            model=embeddings_request.model,
        )
|
||||||
|
|
||||||
|
|
||||||
|
class TitanEmbeddingsModel(BedrockEmbeddingsModel):
    """Adapter for the Amazon Titan embeddings models (text input only)."""

    def _parse_args(self, embeddings_request: EmbeddingsRequest) -> dict:
        """Build the Titan request body from ``embeddings_request``.

        Accepts a single string or a one-element list; for
        ``amazon.titan-embed-image-v1`` an ``embeddingConfig`` is added, using
        the request's ``embedding_config`` or an output length of 1024.

        Raises:
            ValueError: If ``input`` is anything other than one string.
        """
        raw_input = embeddings_request.input
        if isinstance(raw_input, str):
            text = raw_input
        elif isinstance(raw_input, list) and len(raw_input) == 1:
            text = raw_input[0]
        else:
            raise ValueError("Amazon Titan Embeddings models support only single strings as input.")

        # Note: inputImage is not supported!
        payload = {"inputText": text}
        if embeddings_request.model == "amazon.titan-embed-image-v1":
            payload["embeddingConfig"] = (
                embeddings_request.embedding_config or {"outputEmbeddingLength": 1024}
            )
        return payload

    def embed(self, embeddings_request: EmbeddingsRequest) -> EmbeddingsResponse:
        """Invoke the model and return its single vector as an ``EmbeddingsResponse``.

        The reply's ``"embedding"`` is wrapped in a one-element list and
        ``"inputTextTokenCount"`` is reported as the input token count.
        """
        model_id = embeddings_request.model
        raw_response = self._invoke_model(
            args=self._parse_args(embeddings_request),
            model_id=model_id,
        )
        response_body = json.loads(raw_response.get("body").read())
        if DEBUG:
            logger.info("Bedrock response body: " + str(response_body))

        return self._create_response(
            embeddings=[response_body["embedding"]],
            model=embeddings_request.model,
            input_tokens=response_body["inputTextTokenCount"],
        )
|
||||||
|
|||||||
43
src/api/routers/embeddings.py
Normal file
43
src/api/routers/embeddings.py
Normal file
@@ -0,0 +1,43 @@
|
|||||||
|
from typing import Annotated
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Depends, Body, HTTPException
|
||||||
|
|
||||||
|
from api.auth import api_key_auth
|
||||||
|
from api.models import get_embeddings_model, SUPPORTED_BEDROCK_EMBEDDING_MODELS
|
||||||
|
from api.schema import EmbeddingsRequest, EmbeddingsResponse
|
||||||
|
from api.setting import DEFAULT_EMBEDDING_MODEL
|
||||||
|
|
||||||
|
# Router for the embeddings endpoint: every route is prefixed with
# "/embeddings" and depends on api_key_auth.
router = APIRouter(
    prefix="/embeddings",
    tags=["items"],
    dependencies=[Depends(api_key_auth)],
)
|
||||||
|
|
||||||
|
|
||||||
|
# POST handler for embeddings requests.
# - A model name starting with "text-embedding-" (case-insensitive) is replaced
#   by DEFAULT_EMBEDDING_MODEL before validation.
# - Unsupported model ids and ValueErrors raised while embedding are both
#   reported to the client as HTTP 400.
# Kept as comments on purpose: a docstring here would be published by FastAPI
# as the operation description.
@router.post("/", response_model=EmbeddingsResponse)
async def embeddings(
    embeddings_request: Annotated[
        EmbeddingsRequest,
        Body(
            examples=[
                {
                    "model": "cohere.embed-multilingual-v3",
                    "input": ["Your text string goes here"],
                }
            ],
        ),
    ]
):
    if embeddings_request.model.lower().startswith("text-embedding-"):
        embeddings_request.model = DEFAULT_EMBEDDING_MODEL

    model_id = embeddings_request.model
    if model_id not in SUPPORTED_BEDROCK_EMBEDDING_MODELS:
        raise HTTPException(status_code=400, detail="Unsupported Model Id " + model_id)

    try:
        return get_embeddings_model(model_id).embed(embeddings_request)
    except ValueError as err:
        raise HTTPException(status_code=400, detail=str(err))
|
||||||
@@ -3,7 +3,7 @@ from typing import Annotated
|
|||||||
from fastapi import APIRouter, Depends, HTTPException, Path
|
from fastapi import APIRouter, Depends, HTTPException, Path
|
||||||
|
|
||||||
from api.auth import api_key_auth
|
from api.auth import api_key_auth
|
||||||
from api.models import SUPPORTED_BEDROCK_MODELS
|
from api.models import SUPPORTED_BEDROCK_MODELS, SUPPORTED_BEDROCK_EMBEDDING_MODELS
|
||||||
from api.schema import Models, Model
|
from api.schema import Models, Model
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
@@ -17,13 +17,13 @@ router = APIRouter(
|
|||||||
|
|
||||||
|
|
||||||
async def validate_model_id(model_id: str):
|
async def validate_model_id(model_id: str):
|
||||||
if model_id not in SUPPORTED_BEDROCK_MODELS.keys():
|
if model_id not in (SUPPORTED_BEDROCK_MODELS | SUPPORTED_BEDROCK_EMBEDDING_MODELS).keys():
|
||||||
raise HTTPException(status_code=400, detail="Unsupported Model Id")
|
raise HTTPException(status_code=400, detail="Unsupported Model Id")
|
||||||
|
|
||||||
|
|
||||||
@router.get("/", response_model=Models)
|
@router.get("/", response_model=Models)
|
||||||
async def list_models():
|
async def list_models():
|
||||||
model_list = [Model(id=model_id) for model_id in SUPPORTED_BEDROCK_MODELS.keys()]
|
model_list = [Model(id=model_id) for model_id in (SUPPORTED_BEDROCK_MODELS | SUPPORTED_BEDROCK_EMBEDDING_MODELS).keys()]
|
||||||
return Models(data=model_list)
|
return Models(data=model_list)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -78,3 +78,36 @@ class ChatResponse(BaseChatResponse):
|
|||||||
class ChatStreamResponse(BaseChatResponse):
|
class ChatStreamResponse(BaseChatResponse):
|
||||||
choices: list[ChoiceDelta]
|
choices: list[ChoiceDelta]
|
||||||
object: Literal["chat.completion.chunk"] = "chat.completion.chunk"
|
object: Literal["chat.completion.chunk"] = "chat.completion.chunk"
|
||||||
|
|
||||||
|
|
||||||
|
# Request body of the embeddings endpoint.
# (Documented with comments rather than a docstring so the generated JSON
# schema stays unchanged.)
class EmbeddingsRequest(BaseModel):
    # Text to embed: one string or a list of strings.
    input: str | list[str]
    # Requested model id.
    model: str

    # Cohere Embed
    input_type: (
        Literal["search_document", "search_query", "classification", "clustering"] | None
    ) = None
    truncate: Literal["NONE", "LEFT", "RIGHT"] | None = None

    # Titan Embeddings
    embedding_config: dict | None = None
|
||||||
|
|
||||||
|
|
||||||
|
# Fields common to embeddings responses.
class BaseEmbeddingsResponse(BaseModel):
    # Unix timestamp in whole seconds, taken when the response object is built.
    created: int = Field(default_factory=lambda: int(time.time()))
    # Model id the response is attributed to.
    model: str
    # Fixed placeholder value.
    system_fingerprint: str = "fp_e97c09dd4e26"
|
||||||
|
|
||||||
|
|
||||||
|
# One embedding vector inside an embeddings response.
class Embedding(BaseModel):
    # Constant type tag.
    object: Literal["embedding"] = "embedding"
    # The vector itself.
    embedding: list[float]
    # Position of this vector in the response's data list.
    index: int
|
||||||
|
|
||||||
|
|
||||||
|
# Token accounting attached to an embeddings response.
class EmbeddingsUsage(BaseModel):
    # Tokens counted for the input.
    prompt_tokens: int
    # Overall token count for the request.
    total_tokens: int
|
||||||
|
|
||||||
|
|
||||||
|
# Full response body of the embeddings endpoint.
class EmbeddingsResponse(BaseEmbeddingsResponse):
    # The embedding vectors.
    data: list[Embedding]
    # Constant type tag.
    object: Literal["list"] = "list"
    # Token usage for the request.
    usage: EmbeddingsUsage
|
||||||
@@ -11,6 +11,8 @@ DESCRIPTION = """
|
|||||||
Use OpenAI-Compatible RESTful APIs for Amazon Bedrock models.
|
Use OpenAI-Compatible RESTful APIs for Amazon Bedrock models.
|
||||||
|
|
||||||
List of Amazon Bedrock models currently supported:
|
List of Amazon Bedrock models currently supported:
|
||||||
|
|
||||||
|
# Chat
|
||||||
- anthropic.claude-instant-v1
|
- anthropic.claude-instant-v1
|
||||||
- anthropic.claude-v2:1
|
- anthropic.claude-v2:1
|
||||||
- anthropic.claude-v2
|
- anthropic.claude-v2
|
||||||
@@ -20,8 +22,15 @@ List of Amazon Bedrock models currently supported:
|
|||||||
- meta.llama2-70b-chat-v1
|
- meta.llama2-70b-chat-v1
|
||||||
- mistral.mistral-7b-instruct-v0:2
|
- mistral.mistral-7b-instruct-v0:2
|
||||||
- mistral.mixtral-8x7b-instruct-v0:1
|
- mistral.mixtral-8x7b-instruct-v0:1
|
||||||
|
|
||||||
|
# Embeddings
|
||||||
|
- cohere.embed-multilingual-v3
|
||||||
|
- cohere.embed-english-v3
|
||||||
|
- amazon.titan-embed-text-v1
|
||||||
|
- amazon.titan-embed-image-v1
|
||||||
"""
|
"""
|
||||||
|
|
||||||
DEBUG = os.environ.get("DEBUG", "false").lower() != "false"
|
DEBUG = os.environ.get("DEBUG", "false").lower() != "false"
|
||||||
AWS_REGION = os.environ.get("AWS_REGION", "us-west-2")
|
AWS_REGION = os.environ.get("AWS_REGION", "us-west-2")
|
||||||
DEFAULT_MODEL = os.environ.get("DEFAULT_MODEL", "anthropic.claude-3-sonnet-20240229-v1:0")
|
DEFAULT_MODEL = os.environ.get("DEFAULT_MODEL", "anthropic.claude-3-sonnet-20240229-v1:0")
|
||||||
|
# Embedding model used when no specific Bedrock model applies; overridable via
# the DEFAULT_EMBEDDING_MODEL environment variable.
DEFAULT_EMBEDDING_MODEL = os.getenv(
    "DEFAULT_EMBEDDING_MODEL",
    "cohere.embed-multilingual-v3",
)
|
||||||
|
|||||||
Reference in New Issue
Block a user