diff --git a/src/api/models/bedrock.py b/src/api/models/bedrock.py
index f5bd865..c86d79b 100644
--- a/src/api/models/bedrock.py
+++ b/src/api/models/bedrock.py
@@ -1,6 +1,7 @@
import base64
import json
import logging
+from abc import ABC
from typing import AsyncIterable, Iterable
import boto3
@@ -13,17 +14,21 @@ from api.schema import (
# Chat
ChatResponse,
ChatRequest,
- ChatRequestMessage,
Choice,
ChatResponseMessage,
Usage,
ChatStreamResponse,
ChoiceDelta,
+ ImageContent,
+ TextContent,
+ ResponseFunction,
+ ToolCall,
+ Tool,
# Embeddings
EmbeddingsRequest,
EmbeddingsResponse,
EmbeddingsUsage,
- Embedding, TextContent,
+ Embedding,
)
from api.setting import DEBUG, AWS_REGION
@@ -58,7 +63,7 @@ ENCODER = tiktoken.get_encoding("cl100k_base")
# https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters.html
-class BedrockModel(BaseChatModel):
+class BedrockModel(BaseChatModel, ABC):
accept = "application/json"
content_type = "application/json"
@@ -86,17 +91,30 @@ class BedrockModel(BaseChatModel):
model: str,
message: str,
message_id: str,
+ tools_message: str | None = None,
input_tokens: int = 0,
output_tokens: int = 0,
) -> ChatResponse:
- choice = Choice(
- index=0,
- message=ChatResponseMessage(
- role="assistant",
- content=message,
- ),
- finish_reason="stop",
- )
+ if tools_message:
+            # For a tool response, the content is empty.
+ tools = self._parse_tools_response(tools_message)
+ choice = Choice(
+ index=0,
+ message=ChatResponseMessage(
+ role="assistant",
+ tool_calls=tools,
+ ),
+ finish_reason="stop",
+ )
+ else:
+ choice = Choice(
+ index=0,
+ message=ChatResponseMessage(
+ role="assistant",
+ content=message,
+ ),
+ finish_reason="stop",
+ )
response = ChatResponse(
id=message_id,
model=model,
@@ -132,25 +150,31 @@ class BedrockModel(BaseChatModel):
return response
-def get_model(model_id: str) -> BedrockModel:
- model_name = SUPPORTED_BEDROCK_MODELS.get(model_id, "")
- if DEBUG:
- logger.info("model name is " + model_name)
- if model_name in ["Claude Instant", "Claude", "Claude 3 Sonnet", "Claude 3 Haiku"]:
- return ClaudeModel()
- elif model_name in ["Llama 2 Chat 13B", "Llama 2 Chat 70B"]:
- return Llama2Model()
- elif model_name in ["Mistral 7B Instruct", "Mixtral 8x7B Instruct"]:
- return MistralModel()
- else:
- logger.error("Unsupported model id " + model_id)
- raise ValueError("Invalid model ID")
-
-
class ClaudeModel(BedrockModel):
anthropic_version = "bedrock-2023-05-31"
- def _get_base64_image(self, image_url: str):
+ def _parse_tools_response(self, tools_messages: str) -> list[ToolCall]:
+        """Parse the tools response.
+
+        An example tool message looks like:
+        <tool>\n{\n  "name": "get_current_weather",\n  "arguments": {\n  "location": "Shanghai"... }\n}\n</tool>
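+
+        A sketch of the parsed result for the example above (the id is
+        synthetic, since the model response carries none):
+        [ToolCall(id="0", function=ResponseFunction(
+            name="get_current_weather", arguments='{"location": "Shanghai"}'))]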
+ """
+ function = json.loads(
+ tools_messages.replace("\n", " ").encode("unicode_escape")
+ )
+
+ args = json.dumps(function.get("arguments", {}))
+ function = ResponseFunction(
+ name=function["name"], arguments=args.replace("\\\\n", "\\n")
+ )
+ return [
+ ToolCall(
+ id="0",
+ function=function,
+ )
+ ]
+
+ def _get_base64_image(self, image_url: str) -> str:
# Send a request to the image URL
response = requests.get(image_url)
# Check if the request was successful
@@ -159,34 +183,44 @@ class ClaudeModel(BedrockModel):
image_content = response.content
# Encode the image content as base64
base64_image = base64.b64encode(image_content)
- return base64_image.decode('utf-8')
+ return base64_image.decode("utf-8")
else:
- raise HTTPException(status_code=500, detail="Unable to access the image url")
+ raise HTTPException(
+ status_code=500, detail="Unable to access the image url"
+ )
- def _parse_messages(self, messages: list[ChatRequestMessage]) -> list[dict]:
- # Refer to: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html
- converted_messages = []
- for msg in messages:
- if isinstance(msg.content, str):
- converted_messages.append({"role": msg.role, "content": msg.content})
- continue
-
- content_parts = []
- for part in msg.content:
- if isinstance(part, TextContent):
- content_parts.append(part.model_dump())
- else:
- content_parts.append({
+ def _parse_content_parts(
+ self, content: list[TextContent | ImageContent]
+ ) -> list[dict]:
+ # See: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html
+ content_parts = []
+ for part in content:
+ if isinstance(part, TextContent):
+ content_parts.append(part.model_dump())
+ else:
+ content_parts.append(
+ {
"type": "image",
"source": {
"type": "base64",
"media_type": "image/jpeg",
- "data": self._get_base64_image(part.image_url.url)
- }
- })
+ "data": self._get_base64_image(part.image_url.url),
+ },
+ }
+ )
+ return content_parts
- converted_messages.append({"role": msg.role, "content": content_parts})
- return converted_messages
+ def _create_tool_prompt(self, tools: list[Tool]) -> str:
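+        """Build a tool-use instruction block to append to the system prompt.
+
+        A sketch of the result for a single hypothetical get_current_weather
+        tool (shape follows the code below; values are illustrative):
+
+        You have access to the following tools:
+        [{"name": "get_current_weather", "description": "...", "parameters": {...}}]
+        If you need to use one of the above tools, only respond with a JSON
+        object matching the following schema inside a <tool></tool> xml tag: ...
+        """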
+ tool_prompt = "\nYou have access to the following tools:\n"
+ tool_prompt += json.dumps(
+ [tool.function.model_dump() for tool in tools], indent=2
+ )
+ tool_prompt += (
+ "\nIf you need to use one of the above tools, "
+            "only respond with a JSON object matching the following schema inside a <tool></tool> xml tag: \n"
+            '<tool>{"name": $TOOL_NAME, "arguments": {"$PARAMETER_NAME": "$PARAMETER_VALUE", ...}}</tool>\n'
+ )
+ return tool_prompt
def _parse_args(self, chat_request: ChatRequest) -> dict:
args = {
@@ -195,25 +229,73 @@ class ClaudeModel(BedrockModel):
"top_p": chat_request.top_p,
"temperature": chat_request.temperature,
}
- start = 0
- if chat_request.messages[0].role == "system":
- args["system"] = chat_request.messages[0].content
- start = 1
- args["messages"] = self._parse_messages(chat_request.messages[start:])
+ system_prompt = ""
+ converted_messages = []
+ for message in chat_request.messages:
+ if message.role == "system":
+ system_prompt += message.content + "\n"
+ elif message.role == "user" and not isinstance(message.content, str):
+ converted_messages.append(
+ {
+ "role": message.role,
+ "content": self._parse_content_parts(message.content),
+ }
+ )
+ elif message.role == "assistant" and not message.content:
+                # If the content is empty, create it
+                # from the tool call information.
+ tool_content = "Should use {} tool with args: {}".format(
+ message.tool_calls[0].function.name,
+ message.tool_calls[0].function.arguments,
+ )
+ converted_messages.append(
+ {"role": message.role, "content": tool_content}
+ )
+ elif message.role == "tool":
+                # Bedrock does not support the tool role,
+                # so convert the tool message to a user message.
+ converted_messages.append(
+ {
+ "role": "user",
+ "content": "The result of the tool call is " + message.content,
+ }
+ )
+ else:
+ converted_messages.append(
+ {"role": message.role, "content": message.content}
+ )
+
+ if chat_request.tools:
+ system_prompt += self._create_tool_prompt(chat_request.tools)
+
+ args["messages"] = converted_messages
+ if system_prompt:
+ if DEBUG:
+ logger.info("System Prompt: " + system_prompt)
+ args["system"] = system_prompt.replace("\n", "")
return args
def chat(self, chat_request: ChatRequest) -> ChatResponse:
+ if DEBUG:
+ logger.info("Raw request: " + chat_request.model_dump_json())
response = self._invoke_model(
args=self._parse_args(chat_request), model_id=chat_request.model
)
response_body = json.loads(response.get("body").read())
if DEBUG:
logger.info("Bedrock response body: " + str(response_body))
+ message = response_body["content"][0]["text"]
+ tools_message = None
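+        # The tool prompt asks the model to wrap a tool call in <tool></tool>
+        # tags, so extract the JSON between the first such pair.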
+        start = message.find("<tool>")
+        end = message.find("</tool>")
+ if start != -1 and end != -1:
+ tools_message = message[start + 6: end]
return self._create_response(
model=chat_request.model,
message=response_body["content"][0]["text"],
message_id=response_body["id"],
+ tools_message=tools_message,
input_tokens=response_body["usage"]["input_tokens"],
output_tokens=response_body["usage"]["output_tokens"],
)
@@ -256,7 +338,7 @@ class ClaudeModel(BedrockModel):
class Llama2Model(BedrockModel):
- def _convert_prompt(self, messages: list[ChatRequestMessage]) -> str:
+ def _convert_prompt(self, chat_request: ChatRequest) -> str:
"""Create a prompt message follow below example:
[INST] <>\n{your_system_message}\n<>\n\n{user_message_1} [/INST] {model_reply_1}
@@ -264,21 +346,26 @@ class Llama2Model(BedrockModel):
"""
if DEBUG:
logger.info("Convert below messages to prompt for Llama 2: ")
- for msg in messages:
+ for msg in chat_request.messages:
logger.info(msg.model_dump_json())
        bos_token = "<s>"
        eos_token = "</s>"
- prompt = bos_token + "[INST] "
- start = 0
+ prompt = ""
end_turn = False
- if messages[0].role == "system":
-            prompt += "<<SYS>>\n" + messages[0].content + "\n<</SYS>>\n\n"
- start = 1
- # TODO: Add validation
- for i in range(start, len(messages)):
- msg = messages[i]
+ system_prompt = ""
+ for msg in chat_request.messages:
+ if msg.role == "system":
+ system_prompt += "\n" + msg.content + "\n"
+ continue
+ if msg.role == "tool":
+ raise HTTPException(
+ status_code=500,
+ detail="Tool prompt is not supported for Llama 2 model",
+ )
if not isinstance(msg.content, str):
- raise HTTPException(status_code=400, detail="Content must be a string for Llama 2 model")
+ raise HTTPException(
+ status_code=400, detail="Content must be a string for Llama 2 model"
+ )
if msg.role == "user":
if end_turn:
prompt += bos_token + "[INST] "
@@ -287,12 +374,16 @@ class Llama2Model(BedrockModel):
else:
prompt += msg.content + eos_token
end_turn = True
+
+ if system_prompt:
+            system_prompt = "<<SYS>>" + system_prompt + "<</SYS>>"
+ prompt = bos_token + "[INST] " + system_prompt + prompt
if DEBUG:
logger.info("Converted prompt: " + prompt.replace("\n", "\\n"))
return prompt
def _parse_args(self, chat_request: ChatRequest) -> dict:
- prompt = self._convert_prompt(chat_request.messages)
+ prompt = self._convert_prompt(chat_request)
return {
"prompt": prompt,
"max_gen_len": chat_request.max_tokens,
@@ -340,29 +431,36 @@ class Llama2Model(BedrockModel):
class MistralModel(BedrockModel):
- def _convert_prompt(self, messages: list[ChatRequestMessage]) -> str:
+ def _convert_prompt(self, chat_request: ChatRequest) -> str:
"""Create a prompt message follow below example:
[INST] {your_system_message}\n{user_message_1} [/INST] {model_reply_1}
[INST] {user_message_2} [/INST]
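+
+        A rendered two-turn chat would look roughly like (illustrative values):
+        <s>[INST] You are helpful\nHi [/INST] Hello!</s><s>[INST] Bye [/INST]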
"""
+ # TODO: maybe reuse the Llama 2 one.
if DEBUG:
- logger.info("Convert below messages to prompt for Llama 2: ")
- for msg in messages:
+ logger.info("Convert below messages to prompt for Mistral/Mixtral model: ")
+ for msg in chat_request.messages:
logger.info(msg.model_dump_json())
        bos_token = "<s>"
        eos_token = "</s>"
- prompt = bos_token + "[INST] "
- start = 0
+ prompt = ""
end_turn = False
- if messages[0].role == "system":
- prompt += messages[0].content + "\n"
- start = 1
- # TODO: Add validation
- for i in range(start, len(messages)):
- msg = messages[i]
+ system_prompt = ""
+ for msg in chat_request.messages:
+ if msg.role == "system":
+ system_prompt += "\n" + msg.content + "\n"
+ continue
+ if msg.role == "tool":
+ raise HTTPException(
+ status_code=500,
+ detail="Tool prompt is not supported for Mistral/Mixtral model",
+ )
if not isinstance(msg.content, str):
- raise HTTPException(status_code=400, detail="Content must be a string for Mistral/Mixtral model")
+ raise HTTPException(
+ status_code=400,
+ detail="Content must be a string for Mistral/Mixtral model",
+ )
if msg.role == "user":
if end_turn:
prompt += bos_token + "[INST] "
@@ -371,12 +469,14 @@ class MistralModel(BedrockModel):
else:
prompt += msg.content + eos_token
end_turn = True
+
+ prompt = bos_token + "[INST] " + system_prompt + prompt
if DEBUG:
logger.info("Converted prompt: " + prompt.replace("\n", "\\n"))
return prompt
def _parse_args(self, chat_request: ChatRequest) -> dict:
- prompt = self._convert_prompt(chat_request.messages)
+ prompt = self._convert_prompt(chat_request)
return {
"prompt": prompt,
"max_tokens": chat_request.max_tokens,
@@ -422,7 +522,7 @@ class MistralModel(BedrockModel):
yield self._stream_response_to_bytes(response)
-class BedrockEmbeddingsModel(BaseEmbeddingsModel):
+class BedrockEmbeddingsModel(BaseEmbeddingsModel, ABC):
accept = "application/json"
content_type = "application/json"
@@ -446,10 +546,8 @@ class BedrockEmbeddingsModel(BaseEmbeddingsModel):
output_tokens: int = 0,
) -> EmbeddingsResponse:
data = [
- Embedding(
- index=i,
- embedding=embedding
- ) for i, embedding in enumerate(embeddings)
+ Embedding(index=i, embedding=embedding)
+ for i, embedding in enumerate(embeddings)
]
response = EmbeddingsResponse(
data=data,
@@ -465,19 +563,6 @@ class BedrockEmbeddingsModel(BaseEmbeddingsModel):
return response
-def get_embeddings_model(model_id: str) -> BedrockEmbeddingsModel:
- model_name = SUPPORTED_BEDROCK_EMBEDDING_MODELS.get(model_id, "")
- if DEBUG:
- logger.info("model name is " + model_name)
- if model_name in ["Cohere Embed Multilingual", "Cohere Embed English"]:
- return CohereEmbeddingsModel()
- elif model_name in ["Titan Embeddings G1 - Text", "Titan Multimodal Embeddings G1"]:
- return TitanEmbeddingsModel()
- else:
- logger.error("Unsupported model id " + model_id)
- raise ValueError("Invalid model ID")
-
-
class CohereEmbeddingsModel(BedrockEmbeddingsModel):
def _parse_args(self, embeddings_request: EmbeddingsRequest) -> dict:
@@ -528,17 +613,25 @@ class TitanEmbeddingsModel(BedrockEmbeddingsModel):
def _parse_args(self, embeddings_request: EmbeddingsRequest) -> dict:
if isinstance(embeddings_request.input, str):
input_text = embeddings_request.input
- elif isinstance(embeddings_request.input, list) and len(embeddings_request.input) == 1:
+ elif (
+ isinstance(embeddings_request.input, list)
+ and len(embeddings_request.input) == 1
+ ):
input_text = embeddings_request.input[0]
else:
- raise ValueError("Amazon Titan Embeddings models support only single strings as input.")
+ raise ValueError(
+ "Amazon Titan Embeddings models support only single strings as input."
+ )
args = {
"inputText": input_text,
# Note: inputImage is not supported!
}
if embeddings_request.model == "amazon.titan-embed-image-v1":
- args["embeddingConfig"] = embeddings_request.embedding_config if embeddings_request.embedding_config else {
- "outputEmbeddingLength": 1024}
+ args["embeddingConfig"] = (
+ embeddings_request.embedding_config
+ if embeddings_request.embedding_config
+ else {"outputEmbeddingLength": 1024}
+ )
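+        # Illustrative result for the image model (values are an example):
+        # {"inputText": "a cat", "embeddingConfig": {"outputEmbeddingLength": 1024}}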
return args
def embed(self, embeddings_request: EmbeddingsRequest) -> EmbeddingsResponse:
@@ -552,5 +645,39 @@ class TitanEmbeddingsModel(BedrockEmbeddingsModel):
return self._create_response(
embeddings=[response_body["embedding"]],
model=embeddings_request.model,
- input_tokens=response_body["inputTextTokenCount"]
+ input_tokens=response_body["inputTextTokenCount"],
+ )
+
+
+def get_model(model_id: str) -> BedrockModel:
+ model_name = SUPPORTED_BEDROCK_MODELS.get(model_id, "")
+ if DEBUG:
+ logger.info("model name is " + model_name)
+ if model_name in ["Claude Instant", "Claude", "Claude 3 Sonnet", "Claude 3 Haiku"]:
+ return ClaudeModel()
+ elif model_name in ["Llama 2 Chat 13B", "Llama 2 Chat 70B"]:
+ return Llama2Model()
+ elif model_name in ["Mistral 7B Instruct", "Mixtral 8x7B Instruct"]:
+ return MistralModel()
+ else:
+ logger.error("Unsupported model id " + model_id)
+ raise HTTPException(
+ status_code=500,
+ detail="Unsupported model id " + model_id,
+ )
+
+
+def get_embeddings_model(model_id: str) -> BedrockEmbeddingsModel:
+ model_name = SUPPORTED_BEDROCK_EMBEDDING_MODELS.get(model_id, "")
+ if DEBUG:
+ logger.info("model name is " + model_name)
+ if model_name in ["Cohere Embed Multilingual", "Cohere Embed English"]:
+ return CohereEmbeddingsModel()
+ elif model_name in ["Titan Embeddings G1 - Text", "Titan Multimodal Embeddings G1"]:
+ return TitanEmbeddingsModel()
+ else:
+ logger.error("Unsupported model id " + model_id)
+ raise HTTPException(
+ status_code=500,
+ detail="Unsupported model id " + model_id,
)
diff --git a/src/api/schema.py b/src/api/schema.py
index aa87a58..321dad8 100644
--- a/src/api/schema.py
+++ b/src/api/schema.py
@@ -16,6 +16,17 @@ class Models(BaseModel):
data: list[Model] = []
+class ResponseFunction(BaseModel):
+ name: str
+ arguments: str
+
+
+class ToolCall(BaseModel):
+ id: str
+ type: Literal["function"] = "function"
+ function: ResponseFunction
+
+
class TextContent(BaseModel):
type: Literal["text"] = "text"
text: str
@@ -31,12 +42,31 @@ class ImageContent(BaseModel):
image_url: ImageUrl
-class ChatRequestMessage(BaseModel):
+class SystemMessage(BaseModel):
name: str | None = None
- role: Literal["user", "assistant", "system"]
+ role: Literal["system"] = "system"
+ content: str
+
+
+class UserMessage(BaseModel):
+ name: str | None = None
+ role: Literal["user"] = "user"
content: str | list[TextContent | ImageContent]
+class AssistantMessage(BaseModel):
+ name: str | None = None
+ role: Literal["assistant"] = "assistant"
+ content: str | None
+ tool_calls: list[ToolCall] | None = None
+
+
+class ToolMessage(BaseModel):
+ role: Literal["tool"] = "tool"
+ content: str
+ tool_call_id: str
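+
+# Illustrative tool round trip (an assumed usage pattern): the assistant returns
+# AssistantMessage(content=None, tool_calls=[ToolCall(id="0", function=...)]),
+# the client runs the function, then replies with
+# ToolMessage(content="<result>", tool_call_id="0").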
+
+
class Function(BaseModel):
name: str
description: str | None = None
@@ -49,7 +79,7 @@ class Tool(BaseModel):
class ChatRequest(BaseModel):
- messages: list[ChatRequestMessage]
+ messages: list[SystemMessage | UserMessage | AssistantMessage | ToolMessage]
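+    # Note: pydantic resolves this union via each message class's role Literal.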
model: str
frequency_penalty: float | None = Field(default=0.0, le=2.0, ge=-2.0) # Not used
presence_penalty: float | None = Field(default=0.0, le=2.0, ge=-2.0) # Not used
@@ -69,17 +99,6 @@ class Usage(BaseModel):
total_tokens: int
-class ResponseFunction(BaseModel):
- name: str
- arguments: str
-
-
-class ToolCall(BaseModel):
- id: str
- type: Literal["function"] = "function"
- function: ResponseFunction
-
-
class ChatResponseMessage(BaseModel):
# tool_calls
role: Literal["assistant"] | None = None
diff --git a/src/requirements.txt b/src/requirements.txt
index 322918e..49019b8 100644
--- a/src/requirements.txt
+++ b/src/requirements.txt
@@ -1,4 +1,6 @@
fastapi==0.110.0
pydantic==2.6.3
uvicorn==0.27.0.post1
-mangum==0.17.0
\ No newline at end of file
+mangum==0.17.0
+tiktoken==0.6.0
+requests==2.31.0
\ No newline at end of file