diff --git a/src/api/models/__init__.py b/src/api/models/__init__.py
index eb89239..e69de29 100644
--- a/src/api/models/__init__.py
+++ b/src/api/models/__init__.py
@@ -1,7 +0,0 @@
-from api.models.bedrock import (
- ClaudeModel,
- SUPPORTED_BEDROCK_MODELS,
- SUPPORTED_BEDROCK_EMBEDDING_MODELS,
- get_model,
- get_embeddings_model,
-)
diff --git a/src/api/models/base.py b/src/api/models/base.py
index 9dc74dd..9d9db7f 100644
--- a/src/api/models/base.py
+++ b/src/api/models/base.py
@@ -1,3 +1,4 @@
+import time
import uuid
from abc import ABC, abstractmethod
from typing import AsyncIterable
@@ -19,6 +20,14 @@ class BaseChatModel(ABC):
Currently, only Bedrock models are supported, but this may be extended to SageMaker models if needed.
"""
+ def list_models(self) -> list[str]:
+ """Return a list of supported models"""
+ return []
+
+ def validate(self, chat_request: ChatRequest):
+ """Validate chat completion requests."""
+ pass
+
@abstractmethod
def chat(self, chat_request: ChatRequest) -> ChatResponse:
"""Handle a basic chat completion requests."""
@@ -38,7 +47,11 @@ class BaseChatModel(ABC):
response: ChatStreamResponse | None = None
) -> bytes:
if response:
- return "data: {}\n\n".format(response.model_dump_json()).encode("utf-8")
+ # populate these fields so they are not dropped when dumping with exclude_unset=True
+ response.system_fingerprint = "fp"
+ response.object = "chat.completion.chunk"
+ response.created = int(time.time())
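+ # The result is an OpenAI-style SSE line, e.g. b'data: {"id": ..., "choices": [...]}\n\n'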
+ return "data: {}\n\n".format(response.model_dump_json(exclude_unset=True)).encode("utf-8")
return "data: [DONE]\n\n".encode("utf-8")
diff --git a/src/api/models/bedrock.py b/src/api/models/bedrock.py
index 11f4a90..da0f188 100644
--- a/src/api/models/bedrock.py
+++ b/src/api/models/bedrock.py
@@ -2,7 +2,7 @@ import base64
import json
import logging
import re
-from abc import ABC, abstractmethod
+from abc import ABC
from typing import AsyncIterable, Iterable, Literal
import boto3
@@ -20,11 +20,15 @@ from api.schema import (
ChatResponseMessage,
Usage,
ChatStreamResponse,
- ChoiceDelta,
ImageContent,
TextContent,
- ResponseFunction,
ToolCall,
+ ChoiceDelta,
+ UserMessage,
+ AssistantMessage,
+ ToolMessage,
+ Function,
+ ResponseFunction,
# Embeddings
EmbeddingsRequest,
EmbeddingsResponse,
@@ -40,24 +44,6 @@ bedrock_runtime = boto3.client(
region_name=AWS_REGION,
)
-SUPPORTED_BEDROCK_MODELS = {
- "anthropic.claude-instant-v1": "Claude Instant",
- "anthropic.claude-v2:1": "Claude",
- "anthropic.claude-v2": "Claude",
- "anthropic.claude-3-sonnet-20240229-v1:0": "Claude 3 Sonnet",
- "anthropic.claude-3-opus-20240229-v1:0": "Claude 3 Opus",
- "anthropic.claude-3-haiku-20240307-v1:0": "Claude 3 Haiku",
- "meta.llama2-13b-chat-v1": "Llama 2 Chat 13B",
- "meta.llama2-70b-chat-v1": "Llama 2 Chat 70B",
- "meta.llama3-8b-instruct-v1:0": "Llama 3 8B Instruct",
- "meta.llama3-70b-instruct-v1:0": "Llama 3 70B Instruct",
- "mistral.mistral-7b-instruct-v0:2": "Mistral 7B Instruct",
- "mistral.mixtral-8x7b-instruct-v0:1": "Mixtral 8x7B Instruct",
- "mistral.mistral-large-2402-v1:0": "Mistral Large",
- "cohere.command-r-v1:0": "Command R",
- "cohere.command-r-plus-v1:0": "Command R+",
-}
-
SUPPORTED_BEDROCK_EMBEDDING_MODELS = {
"cohere.embed-multilingual-v3": "Cohere Embed Multilingual",
"cohere.embed-english-v3": "Cohere Embed English",
@@ -69,58 +55,199 @@ SUPPORTED_BEDROCK_EMBEDDING_MODELS = {
ENCODER = tiktoken.get_encoding("cl100k_base")
-# https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters.html
-class BedrockModel(BaseChatModel, ABC):
- accept = "application/json"
- content_type = "application/json"
+class BedrockModel(BaseChatModel):
+ # https://docs.aws.amazon.com/bedrock/latest/userguide/conversation-inference.html#conversation-inference-supported-models-features
+ _supported_models = {
+ "amazon.titan-text-premier-v1:0": {
+ "system": True,
+ "multimodal": False,
+ "tool_call": False,
+ "stream_tool_call": False,
+ },
+ "anthropic.claude-instant-v1": {
+ "system": True,
+ "multimodal": False,
+ "tool_call": False,
+ "stream_tool_call": False,
+ },
+ "anthropic.claude-v2:1": {
+ "system": True,
+ "multimodal": False,
+ "tool_call": False,
+ "stream_tool_call": False,
+ },
+ "anthropic.claude-v2": {
+ "system": True,
+ "multimodal": False,
+ "tool_call": False,
+ "stream_tool_call": False,
+ },
+ "anthropic.claude-3-sonnet-20240229-v1:0": {
+ "system": True,
+ "multimodal": True,
+ "tool_call": True,
+ "stream_tool_call": True,
+ },
+ "anthropic.claude-3-opus-20240229-v1:0": {
+ "system": True,
+ "multimodal": True,
+ "tool_call": True,
+ "stream_tool_call": True,
+ },
+ "anthropic.claude-3-haiku-20240307-v1:0": {
+ "system": True,
+ "multimodal": True,
+ "tool_call": True,
+ "stream_tool_call": True,
+ },
+ "meta.llama2-13b-chat-v1": {
+ "system": True,
+ "multimodal": False,
+ "tool_call": False,
+ "stream_tool_call": False,
+ },
+ "meta.llama2-70b-chat-v1": {
+ "system": True,
+ "multimodal": False,
+ "tool_call": False,
+ "stream_tool_call": False,
+ },
+ "meta.llama3-8b-instruct-v1:0": {
+ "system": True,
+ "multimodal": False,
+ "tool_call": False,
+ "stream_tool_call": False,
+ },
+ "meta.llama3-70b-instruct-v1:0": {
+ "system": True,
+ "multimodal": False,
+ "tool_call": False,
+ "stream_tool_call": False,
+ },
+ "mistral.mistral-7b-instruct-v0:2": {
+ "system": False,
+ "multimodal": False,
+ "tool_call": False,
+ "stream_tool_call": False,
+ },
+ "mistral.mixtral-8x7b-instruct-v0:1": {
+ "system": False,
+ "multimodal": False,
+ "tool_call": False,
+ "stream_tool_call": False,
+ },
+ "mistral.mistral-small-2402-v1:0": {
+ "system": True,
+ "multimodal": False,
+ "tool_call": False,
+ "stream_tool_call": False,
+ },
+ "mistral.mistral-large-2402-v1:0": {
+ "system": True,
+ "multimodal": False,
+ "tool_call": True,
+ "stream_tool_call": False,
+ },
+ "cohere.command-r-v1:0": {
+ "system": True,
+ "multimodal": False,
+ "tool_call": True,
+ "stream_tool_call": False,
+ },
+ "cohere.command-r-plus-v1:0": {
+ "system": True,
+ "multimodal": False,
+ "tool_call": True,
+ "stream_tool_call": False,
+ },
+ }
- # Default field name to get the response message
- text_field_name = "text"
+ def list_models(self) -> list[str]:
+ return list(self._supported_models.keys())
- # Default field name to get the response finish reason
- finish_reason_field_name = "finish_reason"
+ def validate(self, chat_request: ChatRequest):
+ """Perform basic validation on requests"""
+ error = ""
+ # check if model is supported
+ if chat_request.model not in self._supported_models.keys():
+ error = f"Unsupported model {chat_request.model}, please use models API to get a list of supported models"
- @abstractmethod
- def compose_request_body(self, chat_request: ChatRequest) -> str:
- """Since the request body to Bedrock varies,
- each model should implement this to compose the request body.
+ # check if tool call is supported
+ elif chat_request.tools and not self._is_tool_call_supported(chat_request.model, stream=chat_request.stream):
+ tool_call_info = "Tool call with streaming" if chat_request.stream else "Tool call"
+ error = f"{tool_call_info} is currently not supported by {chat_request.model}"
- :param chat_request:
- :return: request body as a string
- """
- raise NotImplementedError()
+ # check if the system prompt is supported when one is provided
+ # better to raise an error than to silently ignore it.
+ elif any(m.role == "system" for m in chat_request.messages) and not self._is_system_prompt_supported(chat_request.model):
+ error = f"System message is currently not supported by {chat_request.model}"
+
+ if error:
+ raise HTTPException(
+ status_code=400,
+ detail=error,
+ )
+
+ def _invoke_bedrock(self, chat_request: ChatRequest, stream=False):
+ """Common logic for invoke bedrock models"""
+ if DEBUG:
+ logger.info("Raw request: " + chat_request.model_dump_json())
+
+ # convert OpenAI chat request to Bedrock SDK request
+ args = self._parse_request(chat_request)
+ if DEBUG:
+ logger.info("Bedrock request: " + json.dumps(args))
+
+ try:
+ if stream:
+ response = bedrock_runtime.converse_stream(**args)
+ else:
+ response = bedrock_runtime.converse(**args)
+ except bedrock_runtime.exceptions.ValidationException as e:
+ logger.error("Validation Error: " + str(e))
+ raise HTTPException(status_code=400, detail=str(e))
+ except Exception as e:
+ logger.error(e)
+ raise HTTPException(status_code=500, detail=str(e))
+ return response
def chat(self, chat_request: ChatRequest) -> ChatResponse:
"""Default implementation for Chat API."""
- if DEBUG:
- logger.info("Raw request: " + chat_request.model_dump_json())
- request_body = self.compose_request_body(chat_request)
- if DEBUG:
- logger.info("Bedrock request: " + request_body)
-
- response = self.invoke_model(
- request_body=request_body,
- model_id=chat_request.model,
- )
message_id = self.generate_message_id()
- return self.parse_response(chat_request, response, message_id)
+ response = self._invoke_bedrock(chat_request)
+
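+ # The Converse API response is expected to look roughly like:
+ # {"output": {"message": {...}}, "usage": {"inputTokens": N, "outputTokens": N}, "stopReason": "end_turn"}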
+ output_message = response["output"]["message"]
+ input_tokens = response["usage"]["inputTokens"]
+ output_tokens = response["usage"]["outputTokens"]
+ finish_reason = response["stopReason"]
+
+ chat_response = self._create_response(
+ model=chat_request.model,
+ message_id=message_id,
+ content=output_message["content"],
+ finish_reason=finish_reason,
+ input_tokens=input_tokens,
+ output_tokens=output_tokens,
+ )
+ if DEBUG:
+ logger.info("Proxy response :" + chat_response.model_dump_json())
+ return chat_response
def chat_stream(self, chat_request: ChatRequest) -> AsyncIterable[bytes]:
"""Default implementation for Chat Stream API"""
- if DEBUG:
- logger.info("Raw request: " + chat_request.model_dump_json())
- request_body = self.compose_request_body(chat_request)
- response = self.invoke_model(
- request_body=request_body,
- model_id=chat_request.model,
- with_stream=True,
- )
-
+ response = self._invoke_bedrock(chat_request, stream=True)
message_id = self.generate_message_id()
- for stream_response in self.parse_stream_response(
- chat_request, response, message_id
- ):
+
+ stream = response.get("stream")
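+ # Each chunk is a Converse stream event; _create_response_stream maps it to an OpenAI-style chunk.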
+ for chunk in stream:
+ stream_response = self._create_response_stream(
+ model_id=chat_request.model, message_id=message_id, chunk=chunk
+ )
+ if not stream_response:
+ continue
+ if DEBUG:
+ logger.info("Proxy response :" + stream_response.model_dump_json())
if stream_response.choices:
yield self.stream_response_to_bytes(stream_response)
elif (
@@ -134,156 +261,169 @@ class BedrockModel(BaseChatModel, ABC):
# and the choices field will always be an empty array.
# All other chunks will also include a usage field, but with a null value.
yield self.stream_response_to_bytes(stream_response)
+
# return an [DONE] message at the end.
yield self.stream_response_to_bytes()
- def get_message_text(self, response_body: dict) -> str | None:
- """Default func to get the response message.
+ def _parse_system_prompts(self, chat_request: ChatRequest) -> list[dict[str, str]]:
+ """Create system prompts.
+ Note that not all models support system prompts.
- Ideally, only the field name should be changed."""
- return response_body.get(self.text_field_name)
+ example output: [{"text" : system_prompt}]
- def get_message_finish_reason(self, response_body: dict) -> str | None:
- """Default func to get the finish message.
+ See example:
+ https://docs.aws.amazon.com/bedrock/latest/userguide/conversation-inference.html#message-inference-examples
+ """
- Ideally, only the field name should be changed."""
- return response_body.get(self.finish_reason_field_name)
+ system_prompts = []
+ for message in chat_request.messages:
+ if message.role != "system":
+ # ignore system messages here
+ continue
+ assert isinstance(message.content, str)
+ system_prompts.append({"text": message.content})
- def get_message_usage(self, response_body: dict) -> tuple[int, int]:
- """Default func to get the finish message.
+ return system_prompts
- Can be overridden in the detail model for complex cases."""
- input_tokens = int(response_body.get("prompt_token_count", "0"))
- output_tokens = int(response_body.get("generation_token_count", "0"))
- return input_tokens, output_tokens
+ def _parse_messages(self, chat_request: ChatRequest) -> list[dict]:
+ """
+ The Converse API only supports user and assistant messages.
- def parse_response(
- self, chat_request: ChatRequest, service_response: dict, message_id: str
- ) -> ChatResponse:
- response_body = json.loads(service_response.get("body").read())
- if DEBUG:
- logger.info("Bedrock response body: " + str(response_body))
+ example output: [{
+ "role": "user",
+ "content": [{"text": input_text}]
+ }]
- input_tokens, output_tokens = self.get_message_usage(response_body)
- return self.create_response(
- model=chat_request.model,
- message_id=message_id,
- message=self.get_message_text(response_body),
- finish_reason=self.get_message_finish_reason(response_body),
- input_tokens=input_tokens,
- output_tokens=output_tokens,
- )
-
- def parse_stream_response(
- self, chat_request: ChatRequest, service_response: dict, message_id: str
- ) -> Iterable[ChatStreamResponse]:
-
- chunk_id = 0
- for event in service_response.get("body"):
- if DEBUG:
- logger.info("Bedrock response chunk: " + str(event))
- chunk = json.loads(event["chunk"]["bytes"])
- chunk_id += 1
-
- response = self.create_response_stream(
- model=chat_request.model,
- message_id=message_id,
- chunk_message=self.get_message_text(chunk),
- finish_reason=self.get_message_finish_reason(chunk),
- )
- yield response
- # Get the usage for streaming response anyway.
- if "amazon-bedrock-invocationMetrics" in chunk:
- yield self.create_response_stream(
- model=chat_request.model,
- message_id=message_id,
- input_tokens=chunk["amazon-bedrock-invocationMetrics"][
- "inputTokenCount"
- ],
- output_tokens=chunk["amazon-bedrock-invocationMetrics"][
- "outputTokenCount"
- ],
+ See example:
+ https://docs.aws.amazon.com/bedrock/latest/userguide/conversation-inference.html#message-inference-examples
+ """
+ messages = []
+ for message in chat_request.messages:
+ if isinstance(message, UserMessage):
+ messages.append(
+ {
+ "role": message.role,
+ "content": self._parse_content_parts(
+ message, chat_request.model
+ ),
+ }
)
-
- def invoke_model(self, request_body: str, model_id: str, with_stream: bool = False):
- if DEBUG:
- logger.info("Invoke Bedrock Model: " + model_id)
- logger.info("Bedrock request body: " + request_body)
- try:
- if with_stream:
- return bedrock_runtime.invoke_model_with_response_stream(
- body=request_body,
- modelId=model_id,
- accept=self.accept,
- contentType=self.content_type,
- )
- return bedrock_runtime.invoke_model(
- body=request_body,
- modelId=model_id,
- accept=self.accept,
- contentType=self.content_type,
- )
- except bedrock_runtime.exceptions.ValidationException as e:
- logger.error("Validation Error: " + str(e))
- raise HTTPException(status_code=400, detail=str(e))
- except Exception as e:
- logger.error(e)
- raise HTTPException(status_code=500, detail=str(e))
-
- @staticmethod
- def merge_message(messages: list[dict]) -> list[dict]:
- """Merge the request messages with the same role as previous message"""
- merged_messages = []
- prev_role = None
- merged_content = ""
-
- for message in messages:
- role = message["role"]
- content = message["content"]
- if role != prev_role or isinstance(content, list):
- if prev_role:
- merged_messages.append(
- {"role": prev_role, "content": merged_content}
+ elif isinstance(message, AssistantMessage):
+ if message.content:
+ # Text message
+ messages.append(
+ {"role": message.role, "content": [{"text": message.content}]}
)
- if isinstance(content, str):
- merged_content = content
- prev_role = role
else:
- merged_messages.append({"role": role, "content": content})
- prev_role = None
- merged_content = ""
+ # Tool use message
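+ # OpenAI sends tool arguments as a JSON string; Converse expects a dict for the toolUse input.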
+ tool_input = json.loads(message.tool_calls[0].function.arguments)
+ messages.append(
+ {
+ "role": message.role,
+ "content": [
+ {
+ "toolUse": {
+ "toolUseId": message.tool_calls[0].id,
+ "name": message.tool_calls[0].function.name,
+ "input": tool_input
+ }
+ }
+ ],
+ }
+ )
+ elif isinstance(message, ToolMessage):
+ # Bedrock does not support the tool role,
+ # so add a toolResult block to the content instead.
+ # https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ToolResultBlock.html
+ messages.append(
+ {
+ "role": "user",
+ "content": [
+ {
+ "toolResult": {
+ "toolUseId": message.tool_call_id,
+ "content": [{"text": message.content}],
+ }
+ }
+ ],
+ }
+ )
+
else:
- if content == merged_content:
- # ignore duplicates
- continue
- merged_content += "\n" + content
+ # ignore others, such as system messages
+ continue
+ return messages
- if merged_content:
- merged_messages.append({"role": prev_role, "content": merged_content})
- return merged_messages
+ def _parse_request(self, chat_request: ChatRequest) -> dict:
+ """Create default converse request body.
- @staticmethod
- def create_response(
+ Also performs additional validation where needed (e.g. tool calls, multimodal content).
+
+ Ref: https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_Converse.html
+ """
+
+ messages = self._parse_messages(chat_request)
+ system_prompts = self._parse_system_prompts(chat_request)
+
+ # Base inference parameters.
+ inference_config = {
+ "temperature": chat_request.temperature,
+ "maxTokens": chat_request.max_tokens,
+ "topP": chat_request.top_p,
+ }
+
+ args = {
+ "modelId": chat_request.model,
+ "messages": messages,
+ "system": system_prompts,
+ "inferenceConfig": inference_config,
+ }
+ # add tool config
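+ # Each entry produced by _convert_tool_spec has the shape
+ # {"toolSpec": {"name": ..., "description": ..., "inputSchema": {"json": {...}}}}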
+ if chat_request.tools:
+ args["toolConfig"] = {
+ "tools": [
+ self._convert_tool_spec(t.function) for t in chat_request.tools
+ ]
+ }
+ return args
+
+ def _create_response(
+ self,
model: str,
message_id: str,
- message: str | None = None,
+ content: list[dict] = None,
finish_reason: str | None = None,
- tools: list[ToolCall] | None = None,
input_tokens: int = 0,
output_tokens: int = 0,
) -> ChatResponse:
+
+ message = ChatResponseMessage(
+ role="assistant",
+ )
+ if finish_reason == "tool_use":
+ # https://docs.aws.amazon.com/bedrock/latest/userguide/tool-use.html#tool-use-examples
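+ # Converse returns the toolUse input as a dict; serialize it back to a JSON string
+ # to match the OpenAI tool_calls format.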
+ for part in content:
+ if "toolUse" in part:
+ tool = part["toolUse"]
+ message.tool_calls = [
+ ToolCall(
+ id=tool["toolUseId"],
+ function=ResponseFunction(
+ name=tool["name"],
+ arguments=json.dumps(tool["input"]),
+ ),
+ )
+ ]
+ else:
+ message.content = content[0]["text"]
+
response = ChatResponse(
id=message_id,
model=model,
choices=[
Choice(
- index=0,
- message=ChatResponseMessage(
- role="assistant",
- tool_calls=tools,
- content=message,
- ),
- finish_reason="tool_calls" if tools else finish_reason,
+ message=message,
+ finish_reason=finish_reason,
)
],
usage=Usage(
@@ -292,250 +432,100 @@ class BedrockModel(BaseChatModel, ABC):
total_tokens=input_tokens + output_tokens,
),
)
- if DEBUG:
- logger.info("Proxy response :" + response.model_dump_json())
return response
- @staticmethod
- def create_response_stream(
- model: str,
- message_id: str,
- chunk_message: str | None = None,
- finish_reason: str | None = None,
- tools: list[ToolCall] | None = None,
- input_tokens: int = 0,
- output_tokens: int = 0,
- ) -> ChatStreamResponse:
- if chunk_message or finish_reason or tools:
- response = ChatStreamResponse(
+ def _create_response_stream(
+ self, model_id: str, message_id: str, chunk: dict
+ ) -> ChatStreamResponse | None:
+ """Parsing the Bedrock stream response chunk.
+
+ Ref: https://docs.aws.amazon.com/bedrock/latest/userguide/conversation-inference.html#message-inference-examples
+ """
+ if DEBUG:
+ logger.info("Bedrock response chunk: " + str(chunk))
+
+ finish_reason = None
+ message = None
+ usage = None
+
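+ # A Converse stream emits one event per chunk: messageStart, contentBlockStart (tool call),
+ # contentBlockDelta (text or tool arguments), messageStop, and finally metadata with token usage.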
+ if "messageStart" in chunk:
+ message = ChatResponseMessage(
+ role=chunk["messageStart"]["role"],
+ content="",
+ )
+ if "contentBlockStart" in chunk:
+ # tool call start
+ delta = chunk["contentBlockStart"]["start"]
+ if "toolUse" in delta:
+ message = ChatResponseMessage(
+ tool_calls=[
+ ToolCall(
+ id=delta["toolUse"]["toolUseId"],
+ function=ResponseFunction(
+ name=delta["toolUse"]["name"],
+ arguments="",
+ ),
+ )
+ ]
+ )
+ if "contentBlockDelta" in chunk:
+ delta = chunk["contentBlockDelta"]["delta"]
+ if "text" in delta:
+ # stream content
+ message = ChatResponseMessage(
+ content=delta["text"],
+ )
+ else:
+ # tool use
+ message = ChatResponseMessage(
+ tool_calls=[
+ ToolCall(
+ function=ResponseFunction(
+ arguments=delta["toolUse"]["input"],
+ )
+ )
+ ]
+ )
+ if "messageStop" in chunk:
+ message = ChatResponseMessage()
+ finish_reason = chunk["messageStop"]["stopReason"]
+
+ if "metadata" in chunk:
+ # usage information in metadata.
+ metadata = chunk["metadata"]
+ if "usage" in metadata:
+ # token usage
+ return ChatStreamResponse(
+ id=message_id,
+ model=model_id,
+ choices=[],
+ usage=Usage(
+ prompt_tokens=metadata["usage"]["inputTokens"],
+ completion_tokens=metadata["usage"]["outputTokens"],
+ total_tokens=metadata["usage"]["totalTokens"],
+ ),
+ )
+ if message:
+ return ChatStreamResponse(
id=message_id,
- model=model,
+ model=model_id,
choices=[
ChoiceDelta(
index=0,
- delta=ChatResponseMessage(
- role="assistant",
- tool_calls=tools,
- content=chunk_message,
- ),
+ delta=message,
+ logprobs=None,
finish_reason=finish_reason,
)
],
- )
- else:
- response = ChatStreamResponse(
- id=message_id,
- model=model,
- choices=[],
- usage=Usage(
- prompt_tokens=input_tokens,
- completion_tokens=output_tokens,
- total_tokens=input_tokens + output_tokens,
- ),
- )
- if DEBUG:
- logger.info("Proxy response :" + response.model_dump_json())
- return response
-
-
-class ClaudeModel(BedrockModel):
- anthropic_version = "bedrock-2023-05-31"
- # follow these instructions for tool uses:
- tool_prompt = """You have access to the following tools:
-{tools}
-
-Please think if you need to use a tool or not for user's question, you must:
-1. Respond Y or N within tags first to indicate that.
-2. If a tool is needed, MUST respond a JSON object matching the following schema within tags:
- {{"name": $TOOL_NAME, "arguments": {{"$PARAMETER_NAME": "$PARAMETER_VALUE", ...}}}}
-3. If no tools is needed, respond with normal text."""
-
- def compose_request_body(self, chat_request: ChatRequest) -> str:
- args = {
- "anthropic_version": self.anthropic_version,
- "max_tokens": chat_request.max_tokens,
- "top_p": chat_request.top_p,
- "temperature": chat_request.temperature,
- }
- system_prompt = ""
- converted_messages = []
- for message in chat_request.messages:
- if message.role == "system":
- assert isinstance(message.content, str)
- system_prompt += message.content + "\n"
- elif message.role == "user" and not isinstance(message.content, str):
- converted_messages.append(
- {
- "role": message.role,
- "content": self._parse_content_parts(message.content),
- }
- )
- elif message.role == "assistant" and not message.content:
- # if content is empty
- # create the content using the tool call info.
- tool_content = "[Tool use for `{}` with id `{}` with the following `input`]\n{}".format(
- message.tool_calls[0].function.name,
- message.tool_calls[0].id,
- message.tool_calls[0].function.arguments,
- )
- converted_messages.append(
- {"role": message.role, "content": tool_content}
- )
- elif message.role == "tool":
- # Since bedrock does not support tool role
- # Convert the tool message to a user message.
- converted_messages.append(
- {
- "role": "user",
- "content": "[Tool result with matching id `{}` of `{}`] ".format(
- message.tool_call_id, message.content
- ),
- }
- )
- else:
- converted_messages.append(
- {"role": message.role, "content": message.content}
- )
-
- if chat_request.tools:
- tools_str = json.dumps(
- [tool.function.model_dump() for tool in chat_request.tools]
- )
- system_prompt += self.tool_prompt.format(tools=tools_str)
- converted_messages.append({"role": "assistant", "content": ""})
- args["stop_sequences"] = [""]
- args["messages"] = self.merge_message(converted_messages)
- if system_prompt:
- if DEBUG:
- logger.info("System Prompt: " + system_prompt)
- args["system"] = system_prompt
-
- return json.dumps(args)
-
- def parse_response(
- self, chat_request: ChatRequest, service_response: dict, message_id: str
- ) -> ChatResponse:
- response_body = json.loads(service_response.get("body").read())
- if DEBUG:
- logger.info("Bedrock response body: " + str(response_body))
- message = response_body["content"][0]["text"]
- finish_reason = response_body["stop_reason"]
- tools = None
- if chat_request.tools:
- if message.startswith("Y"):
- tools = self._parse_tool_message(message)
- message = None
- finish_reason = "tool_calls"
- elif message.startswith("N"):
- message = message[8:].lstrip("\n")
- return self.create_response(
- model=chat_request.model,
- message_id=message_id,
- message=message,
- tools=tools,
- finish_reason=finish_reason,
- input_tokens=response_body["usage"]["input_tokens"],
- output_tokens=response_body["usage"]["output_tokens"],
- )
-
- def parse_stream_response(
- self, chat_request: ChatRequest, service_response: dict, message_id: str
- ) -> Iterable[ChatStreamResponse]:
-
- chunk_id = 0
- tool_message = ""
- first_token = True
- index = 0
- for event in service_response.get("body"):
- if DEBUG:
- logger.info("Bedrock response chunk: " + str(event))
- chunk = json.loads(event["chunk"]["bytes"])
- chunk_id += 1
-
- if chunk["type"] == "message_stop":
- # Get the usage for streaming response anyway.
- if "amazon-bedrock-invocationMetrics" in chunk:
- yield self.create_response_stream(
- model=chat_request.model,
- message_id=message_id,
- input_tokens=chunk["amazon-bedrock-invocationMetrics"][
- "inputTokenCount"
- ],
- output_tokens=chunk["amazon-bedrock-invocationMetrics"][
- "outputTokenCount"
- ],
- )
- break
-
- elif chunk["type"] == "message_delta":
- chunk_message = ""
- finish_reason = chunk["delta"]["stop_reason"]
-
- # Send tool message first if any.
- if chat_request.tools and tool_message:
- tools = self._parse_tool_message(tool_message)
- finish_reason = "tool_calls"
- response = self.create_response_stream(
- model=chat_request.model,
- message_id=message_id,
- tools=tools,
- )
- yield response
-
- elif chunk["type"] == "content_block_delta":
- chunk_message = chunk["delta"]["text"]
- finish_reason = None
- if chat_request.tools:
- # Check first token
- if not tool_message and chunk_message == "Y":
- tool_message = "Y"
- continue
- if tool_message:
- # Buffer all chunk message
- # in order to extract tool call info
- tool_message += chunk_message
- continue
- if index < 3:
- # Ignore the N, which is 3 tokens
- index += 1
- continue
- if first_token:
- chunk_message = chunk_message.lstrip("\n")
- first_token = False
- else:
- continue
- response = self.create_response_stream(
- model=chat_request.model,
- message_id=message_id,
- chunk_message=chunk_message,
- finish_reason=finish_reason,
+ usage=usage,
)
- yield response
+ return None
- def _parse_tool_message(self, tool_message: str) -> list[ToolCall]:
- if DEBUG:
- logger.info("Tool message: " + tool_message.replace("\n", " "))
- try:
- tool_messages = tool_message[tool_message.rindex("") + len(""):]
- function = json.loads(tool_messages.replace("\n", " "))
- args = json.dumps(function.get("arguments", {}))
- function = ResponseFunction(name=function["name"], arguments=args)
-
- return [
- ToolCall(
- # id="0",
- function=function,
- )
- ]
-
- except Exception as e:
- logger.error("Failed to parse tool response" + str(e))
- raise HTTPException(status_code=500, detail="Failed to parse tool response")
-
- def _get_base64_image(self, image_url: str) -> tuple[str, str]:
- """Try to get the base64 data from an image url.
+ def _parse_image(self, image_url: str) -> tuple[bytes, str]:
+ """Try to get the raw data from an image url.
+ Ref: https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ImageSource.html
returns a tuple of (Image Data, Content Type)
"""
pattern = r"^data:(image/[a-z]*);base64,\s*"
@@ -544,7 +534,7 @@ Please think if you need to use a tool or not for user's question, you must:
# Only supports 'image/jpeg', 'image/png', 'image/gif' or 'image/webp'
if content_type:
image_data = re.sub(pattern, "", image_url)
- return image_data, content_type.group(1)
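+ # Converse expects raw image bytes, so decode the base64 payload from the data URL.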
+ return base64.b64decode(image_data), content_type.group(1)
# Send a request to the image URL
response = requests.get(image_url)
@@ -556,228 +546,80 @@ Please think if you need to use a tool or not for user's question, you must:
content_type = "image/jpeg"
# Get the image content
image_content = response.content
- # Encode the image content as base64
- base64_image = base64.b64encode(image_content)
- return base64_image.decode("utf-8"), content_type
+ return image_content, content_type
else:
raise HTTPException(
status_code=500, detail="Unable to access the image url"
)
def _parse_content_parts(
- self, content: list[TextContent | ImageContent]
+ self,
+ message: UserMessage,
+ model_id: str,
) -> list[dict]:
- # See: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html
+ if isinstance(message.content, str):
+ return [
+ {
+ "text": message.content,
+ }
+ ]
content_parts = []
- for part in content:
+ for part in message.content:
if isinstance(part, TextContent):
- content_parts.append(part.model_dump())
- else:
- image_data, content_type = self._get_base64_image(part.image_url.url)
content_parts.append(
{
- "type": "image",
- "source": {
- "type": "base64",
- "media_type": content_type,
- "data": image_data,
+ "text": part.text,
+ }
+ )
+ elif isinstance(part, ImageContent):
+ if not self._is_multimodal_supported(model_id):
+ raise HTTPException(
+ status_code=400,
+ detail=f"Multimodal message is currently not supported by {model_id}",
+ )
+ image_data, content_type = self._parse_image(part.image_url.url)
+ content_parts.append(
+ {
+ "image": {
+ "format": content_type[6:], # image/
+ "source": {"bytes": image_data},
},
}
)
+ else:
+ # Ignore unsupported content parts.
+ continue
return content_parts
+ def _is_tool_call_supported(self, model_id: str, stream: bool = False) -> bool:
+ feature = self._supported_models.get(model_id)
+ if not feature:
+ return False
+ return feature["stream_tool_call"] if stream else feature["tool_call"]
-class LlamaModel(BedrockModel):
- text_field_name = "generation"
- finish_reason_field_name = "stop_reason"
+ def _is_multimodal_supported(self, model_id: str) -> bool:
+ feature = self._supported_models.get(model_id)
+ if not feature:
+ return False
+ return feature["multimodal"]
- @staticmethod
- def create_llama3_prompt(chat_request: ChatRequest) -> str:
- """Create a prompt message for Llama 3 following below example:
+ def _is_system_prompt_supported(self, model_id: str) -> bool:
+ feature = self._supported_models.get(model_id)
+ if not feature:
+ return False
+ return feature["system"]
- <|begin_of_text|><|start_header_id|>system<|end_header_id|>
-
- {{ system_prompt }}<|eot_id|><|start_header_id|>user<|end_header_id|>
-
- {{ user_message_1 }}<|eot_id|><|start_header_id|>assistant<|end_header_id|>
-
- {{ model_answer_1 }}<|eot_id|><|start_header_id|>user<|end_header_id|>
-
- {{ user_message_2 }}<|eot_id|><|start_header_id|>assistant<|end_header_id|>
- """
- if DEBUG:
- logger.info("Convert below messages to prompt for Llama 3: ")
- for msg in chat_request.messages:
- logger.info(msg.model_dump_json())
- bos_token = "<|begin_of_text|>"
-
- prompt_lines = []
- for msg in chat_request.messages:
- prompt_lines.append(
- f"<|start_header_id|>{msg.role}<|end_header_id|>\n\n{msg.content}<|eot_id|>"
- )
- prompt_lines.append(f"<|start_header_id|>assistant<|end_header_id|>\n\n")
- prompt = bos_token + "".join(prompt_lines)
- if DEBUG:
- logger.info("Converted prompt: " + prompt.replace("\n", "\\n"))
- return prompt
-
- @staticmethod
- def create_llama2_prompt(chat_request: ChatRequest) -> str:
- """Create a prompt message for Llama 2 following below example:
-
- [INST] <>\n{your_system_message}\n<>\n\n{user_message_1} [/INST] {model_reply_1}
- [INST] {user_message_2} [/INST]
- """
- if DEBUG:
- logger.info("Convert below messages to prompt for Llama 2: ")
- for msg in chat_request.messages:
- logger.info(msg.model_dump_json())
- bos_token = ""
- eos_token = ""
- prompt = ""
- end_turn = False
- system_prompt = ""
- for msg in chat_request.messages:
- if msg.role == "system":
- system_prompt += "\n" + msg.content + "\n"
- continue
- if msg.role == "tool":
- raise HTTPException(
- status_code=500,
- detail="Tool prompt is not supported for Llama 2 model",
- )
- if not isinstance(msg.content, str):
- raise HTTPException(
- status_code=400, detail="Content must be a string for Llama 2 model"
- )
- if msg.role == "user":
- if end_turn:
- prompt += bos_token + "[INST] "
- prompt += msg.content + " [/INST] "
- end_turn = False
- else:
- prompt += msg.content + eos_token
- end_turn = True
-
- if system_prompt:
- system_prompt = "<>" + system_prompt + "<>"
- prompt = bos_token + "[INST] " + system_prompt + prompt
- if DEBUG:
- logger.info("Converted prompt: " + prompt.replace("\n", "\\n"))
- return prompt
-
- def compose_request_body(self, chat_request: ChatRequest) -> str:
- if chat_request.model.startswith("meta.llama2"):
- prompt = self.create_llama2_prompt(chat_request)
- else:
- prompt = self.create_llama3_prompt(chat_request)
- args = {
- "prompt": prompt,
- "max_gen_len": chat_request.max_tokens,
- "temperature": chat_request.temperature,
- "top_p": chat_request.top_p,
- }
- return json.dumps(args)
-
-
-class MistralModel(BedrockModel):
- text_field_name = "text"
- finish_reason_field_name = "stop_reason"
-
- def _convert_prompt(self, chat_request: ChatRequest) -> str:
- """Create a prompt message follow below example:
-
- [INST] {your_system_message}\n{user_message_1} [/INST] {model_reply_1}
- [INST] {user_message_2} [/INST]
- """
- # TODO: maybe reuse the Llama 2 one.
- if DEBUG:
- logger.info("Convert below messages to prompt for Mistral/Mixtral model: ")
- for msg in chat_request.messages:
- logger.info(msg.model_dump_json())
- bos_token = ""
- eos_token = ""
- prompt = ""
- end_turn = False
- system_prompt = ""
- for msg in chat_request.messages:
- if msg.role == "system":
- system_prompt += "\n" + msg.content + "\n"
- continue
- if msg.role == "tool":
- raise HTTPException(
- status_code=500,
- detail="Tool prompt is not supported for Mistral/Mixtral model",
- )
- if not isinstance(msg.content, str):
- raise HTTPException(
- status_code=400,
- detail="Content must be a string for Mistral/Mixtral model",
- )
- if msg.role == "user":
- if end_turn:
- prompt += bos_token + "[INST] "
- prompt += msg.content + " [/INST] "
- end_turn = False
- else:
- prompt += msg.content + eos_token
- end_turn = True
-
- prompt = bos_token + "[INST] " + system_prompt + prompt
- if DEBUG:
- logger.info("Converted prompt: " + prompt.replace("\n", "\\n"))
- return prompt
-
- def compose_request_body(self, chat_request: ChatRequest) -> str:
- prompt = self._convert_prompt(chat_request)
- args = {
- "prompt": prompt,
- "max_tokens": chat_request.max_tokens,
- "temperature": chat_request.temperature,
- "top_p": chat_request.top_p,
- }
- return json.dumps(args)
-
- def get_message_text(self, response_body: dict) -> str | None:
- return super().get_message_text(response_body["outputs"][0])
-
- def get_message_finish_reason(self, response_body: dict) -> str | None:
- return super().get_message_finish_reason(response_body["outputs"][0])
-
- def get_message_usage(self, response_body: dict) -> tuple[int, int]:
- # Mistral/Mixtral does not provide info about usage
- return 0, 0
-
-
-class CohereCommandModel(BedrockModel):
-
- def _parse_message(self, message) -> dict:
- if message.role not in ["user", "assistant"]:
- raise HTTPException(
- status_code=400, detail="Only user or assistant message is supported"
- )
+ def _convert_tool_spec(self, func: Function) -> dict:
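+ # Map an OpenAI function definition to a Bedrock Converse toolSpec block.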
return {
- "role": "USER" if message.role == "user" else "CHATBOT",
- "message": message.content,
+ "toolSpec": {
+ "name": func.name,
+ "description": func.description,
+ "inputSchema": {
+ "json": func.parameters,
+ },
+ }
}
- def compose_request_body(self, chat_request: ChatRequest) -> str:
- messages = chat_request.messages
- if messages[-1].role != "user":
- raise HTTPException(
- status_code=400, detail="Last message should be a valid user message"
- )
- chat_history = [self._parse_message(msg) for msg in messages[:-1]]
- args = {
- "message": messages[-1].content,
- "chat_history": chat_history,
- "max_tokens": chat_request.max_tokens,
- "temperature": chat_request.temperature,
- "p": chat_request.top_p,
- }
- return json.dumps(args)
-
class BedrockEmbeddingsModel(BaseEmbeddingsModel, ABC):
accept = "application/json"
@@ -919,25 +761,6 @@ class TitanEmbeddingsModel(BedrockEmbeddingsModel):
)
-def get_model(model_id: str) -> BedrockModel:
- if DEBUG:
- logger.info("model id is " + model_id)
- if model_id not in SUPPORTED_BEDROCK_MODELS.keys():
- logger.error("Unsupported model id " + model_id)
- raise HTTPException(
- status_code=400,
- detail="Unsupported model id " + model_id,
- )
- if model_id.startswith("anthropic.claude"):
- return ClaudeModel()
- elif model_id.startswith("meta.llama"):
- return LlamaModel()
- elif model_id.startswith("mistral.mistral") or model_id.startswith("mistral.mixtral"):
- return MistralModel()
- elif model_id.startswith("cohere.command-r"):
- return CohereCommandModel()
-
-
def get_embeddings_model(model_id: str) -> BedrockEmbeddingsModel:
model_name = SUPPORTED_BEDROCK_EMBEDDING_MODELS.get(model_id, "")
if DEBUG:
diff --git a/src/api/routers/chat.py b/src/api/routers/chat.py
index f2758fa..58ea73f 100644
--- a/src/api/routers/chat.py
+++ b/src/api/routers/chat.py
@@ -4,7 +4,7 @@ from fastapi import APIRouter, Depends, Body
from fastapi.responses import StreamingResponse
from api.auth import api_key_auth
-from api.models import get_model
+from api.models.bedrock import BedrockModel
from api.schema import ChatRequest, ChatResponse, ChatStreamResponse
from api.setting import DEFAULT_MODEL
@@ -34,9 +34,10 @@ async def chat_completions(
):
if chat_request.model.lower().startswith("gpt-"):
chat_request.model = DEFAULT_MODEL
-
+
# Exception will be raised if model not supported.
- model = get_model(chat_request.model)
+ model = BedrockModel()
+ model.validate(chat_request)
if chat_request.stream:
return StreamingResponse(
content=model.chat_stream(chat_request), media_type="text/event-stream"
diff --git a/src/api/routers/embeddings.py b/src/api/routers/embeddings.py
index 135fc27..e5cde31 100644
--- a/src/api/routers/embeddings.py
+++ b/src/api/routers/embeddings.py
@@ -3,7 +3,7 @@ from typing import Annotated
from fastapi import APIRouter, Depends, Body
from api.auth import api_key_auth
-from api.models import get_embeddings_model
+from api.models.bedrock import get_embeddings_model
from api.schema import EmbeddingsRequest, EmbeddingsResponse
from api.setting import DEFAULT_EMBEDDING_MODEL
diff --git a/src/api/routers/model.py b/src/api/routers/model.py
index 2640d7d..ce5e8a1 100644
--- a/src/api/routers/model.py
+++ b/src/api/routers/model.py
@@ -3,7 +3,7 @@ from typing import Annotated
from fastapi import APIRouter, Depends, HTTPException, Path
from api.auth import api_key_auth
-from api.models import SUPPORTED_BEDROCK_MODELS, SUPPORTED_BEDROCK_EMBEDDING_MODELS
+from api.models.bedrock import BedrockModel
from api.schema import Models, Model
router = APIRouter(
@@ -12,16 +12,19 @@ router = APIRouter(
# responses={404: {"description": "Not found"}},
)
+chat_model = BedrockModel()
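+# A single shared BedrockModel instance is sufficient here; it holds no per-request state.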
+
async def validate_model_id(model_id: str):
- if model_id not in (SUPPORTED_BEDROCK_MODELS | SUPPORTED_BEDROCK_EMBEDDING_MODELS).keys():
+ if model_id not in chat_model.list_models():
raise HTTPException(status_code=500, detail="Unsupported Model Id")
@router.get("", response_model=Models)
async def list_models():
- model_list = [Model(id=model_id) for model_id in
- (SUPPORTED_BEDROCK_MODELS | SUPPORTED_BEDROCK_EMBEDDING_MODELS).keys()]
+ model_list = [
+ Model(id=model_id) for model_id in chat_model.list_models()
+ ]
return Models(data=model_list)
diff --git a/src/api/schema.py b/src/api/schema.py
index da133d5..2b2f2fb 100644
--- a/src/api/schema.py
+++ b/src/api/schema.py
@@ -1,5 +1,4 @@
import time
-import uuid
from typing import Literal, Iterable
from pydantic import BaseModel, Field
@@ -18,12 +17,12 @@ class Models(BaseModel):
class ResponseFunction(BaseModel):
- name: str
+ name: str | None = None
arguments: str
class ToolCall(BaseModel):
- id: str = Field(default_factory=lambda: str(uuid.uuid4())[:8])
+ id: str | None = None
type: Literal["function"] = "function"
function: ResponseFunction
@@ -113,8 +112,8 @@ class ChatResponseMessage(BaseModel):
class BaseChoice(BaseModel):
- index: int
- finish_reason: str | None
+ index: int | None = 0
+ finish_reason: str | None = None
logprobs: dict | None = None
diff --git a/src/api/setting.py b/src/api/setting.py
index 56e99c4..9543e20 100644
--- a/src/api/setting.py
+++ b/src/api/setting.py
@@ -11,27 +11,11 @@ DESCRIPTION = """
Use OpenAI-Compatible RESTful APIs for Amazon Bedrock models.
List of Amazon Bedrock models currently supported:
-
-# Chat
-- anthropic.claude-instant-v1
-- anthropic.claude-v2:1
-- anthropic.claude-v2
-- anthropic.claude-3-opus-20240229-v1:0
-- anthropic.claude-3-sonnet-20240229-v1:0
-- anthropic.claude-3-haiku-20240307-v1:0
-- meta.llama2-13b-chat-v1
-- meta.llama2-70b-chat-v1
-- meta.llama3-8b-instruct-v1:0
-- meta.llama3-70b-instruct-v1:0
-- mistral.mistral-7b-instruct-v0:2
-- mistral.mixtral-8x7b-instruct-v0:1
-- mistral.mistral-large-2402-v1:0
-- cohere.command-r-v1:0
-- cohere.command-r-plus-v1:0
-
-# Embeddings
-- cohere.embed-multilingual-v3
-- cohere.embed-english-v3
+- Anthropic Claude 2 / 3 (Haiku/Sonnet/Opus)
+- Meta Llama 2 / 3
+- Mistral / Mixtral
+- Cohere Command R / R+
+- Cohere Embedding
"""
DEBUG = os.environ.get("DEBUG", "false").lower() != "false"
diff --git a/src/requirements.txt b/src/requirements.txt
index 4dc944e..3af1653 100644
--- a/src/requirements.txt
+++ b/src/requirements.txt
@@ -1,7 +1,9 @@
-fastapi==0.110.2
+fastapi==0.111.0
pydantic==2.7.1
uvicorn==0.29.0
mangum==0.17.0
tiktoken==0.6.0
-requests==2.32.0
-numpy==1.26.4
\ No newline at end of file
+requests==2.32.3
+numpy==1.26.4
+boto3==1.34.117
+botocore==1.34.117
\ No newline at end of file