diff --git a/src/api/models/bedrock.py b/src/api/models/bedrock.py
index 9e83460..f131179 100644
--- a/src/api/models/bedrock.py
+++ b/src/api/models/bedrock.py
@@ -24,7 +24,6 @@ from api.schema import (
     TextContent,
     ResponseFunction,
     ToolCall,
-    Tool,
     # Embeddings
     EmbeddingsRequest,
     EmbeddingsResponse,
@@ -88,39 +87,60 @@ class BedrockModel(BaseChatModel, ABC):
             contentType=self.content_type,
         )
+    @staticmethod
+    def merge_message(messages: list[dict]) -> list[dict]:
+        """Merge consecutive request messages that share the same role."""
+        merged_messages = []
+        prev_role = None
+        merged_content = ""
+
+        for message in messages:
+            role = message["role"]
+            content = message["content"]
+
+            if role != prev_role or isinstance(content, list):
+                if prev_role:
+                    merged_messages.append({"role": prev_role, "content": merged_content})
+                if isinstance(content, str):
+                    merged_content = content
+                    prev_role = role
+                else:
+                    # Keep list content (e.g. multimodal parts) as its own message.
+                    merged_messages.append({"role": role, "content": content})
+                    prev_role = None
+                    merged_content = ""
+            else:
+                if content == merged_content:
+                    # Ignore duplicates.
+                    continue
+                merged_content += "\n" + content
+
+        if merged_content:
+            merged_messages.append({"role": prev_role, "content": merged_content})
+        return merged_messages
+
+    @staticmethod
     def _create_response(
-        self,
         model: str,
-        message: str,
         message_id: str,
-        tools_message: str | None = None,
+        message: str | None = None,
+        finish_reason: str | None = None,
+        tools: list[ToolCall] | None = None,
         input_tokens: int = 0,
         output_tokens: int = 0,
     ) -> ChatResponse:
-        if tools_message:
-            # For tool response, the content is empty
-            tools = self._parse_tools_response(tools_message)
-            choice = Choice(
+        response = ChatResponse(
+            id=message_id,
+            model=model,
+            choices=[Choice(
                 index=0,
                 message=ChatResponseMessage(
                     role="assistant",
                     tool_calls=tools,
-                ),
-                finish_reason="stop",
-            )
-        else:
-            choice = Choice(
-                index=0,
-                message=ChatResponseMessage(
-                    role="assistant",
                     content=message,
                 ),
-                finish_reason="stop",
-            )
-        response = ChatResponse(
-            id=message_id,
-            model=model,
-            choices=[choice],
+                finish_reason="tool_calls" if tools else finish_reason,
+            )],
             usage=Usage(
                 prompt_tokens=input_tokens,
                 completion_tokens=output_tokens,
@@ -131,21 +151,26 @@ class BedrockModel(BaseChatModel, ABC):
             logger.info("Proxy response :" + response.model_dump_json())
         return response
+    @staticmethod
     def _create_response_stream(
-        self, model: str, message_id: str, chunk_message: str, finish_reason: str | None
+        model: str,
+        message_id: str,
+        chunk_message: str | None = None,
+        finish_reason: str | None = None,
+        tools: list[ToolCall] | None = None,
     ) -> ChatStreamResponse:
-        choice = ChoiceDelta(
-            index=0,
-            delta=ChatResponseMessage(
-                role="assistant",
-                content=chunk_message,
-            ),
-            finish_reason=finish_reason,
-        )
         response = ChatStreamResponse(
             id=message_id,
             model=model,
-            choices=[choice],
+            choices=[ChoiceDelta(
+                index=0,
+                delta=ChatResponseMessage(
+                    role="assistant",
+                    tool_calls=tools,
+                    content=chunk_message,
+                ),
+                finish_reason="tool_calls" if tools else finish_reason,
+            )],
         )
         if DEBUG:
             logger.info("Proxy response :" + response.model_dump_json())
@@ -154,27 +179,193 @@ class BedrockModel(BaseChatModel, ABC):
class ClaudeModel(BedrockModel):
     anthropic_version = "bedrock-2023-05-31"
+    # Prompt suffix instructing the model how to request tool use.
+    tool_prompt = """You have access to the following tools:
+{tools}
-    def _parse_tools_response(self, tools_messages: str) -> list[ToolCall]:
-        """Parse the tools response
+Please think about whether you need to use a tool for the user's question; you must:
+1. Respond Y or N inside an <ans> xml tag first to indicate that.
+2. If a tool is needed, you MUST respond with a JSON object matching the following schema inside a <tool> xml tag:
+   {{"name": $TOOL_NAME, "arguments": {{"$PARAMETER_NAME": "$PARAMETER_VALUE", ...}}}}
+3. If no tool is needed, respond with normal text."""
-        Example tool message like:
-        <tool>\n{\n    "name": "get_current_weather",\n    "arguments": {\n        "location": "Shanghai"... }\n}\n</tool>
-        """
-        function = json.loads(
-            tools_messages.replace("\n", " ").encode("unicode_escape")
-        )
+    def _parse_args(self, chat_request: ChatRequest) -> dict:
+        args = {
+            "anthropic_version": self.anthropic_version,
+            "max_tokens": chat_request.max_tokens,
+            "top_p": chat_request.top_p,
+            "temperature": chat_request.temperature,
+        }
+        system_prompt = ""
+        converted_messages = []
+        for message in chat_request.messages:
+            if message.role == "system":
+                system_prompt += message.content + "\n"
+            elif message.role == "user" and not isinstance(message.content, str):
+                converted_messages.append(
+                    {
+                        "role": message.role,
+                        "content": self._parse_content_parts(message.content),
+                    }
+                )
+            elif message.role == "assistant" and not message.content:
+                # if content is empty
+                # create the content using the tool call info.
+                tool_content = "[Tool use for `{}` with id `{}` with the following `input`]\n{}".format(
+                    message.tool_calls[0].function.name,
+                    message.tool_calls[0].id,
+                    message.tool_calls[0].function.arguments,
+                )
+                converted_messages.append(
+                    {"role": message.role, "content": tool_content}
+                )
+            elif message.role == "tool":
+                # Since bedrock does not support the tool role,
+                # convert the tool message to a user message.
+                converted_messages.append(
+                    {
+                        "role": "user",
+                        "content": "[Tool result with matching id `{}` of `{}`] ".format(
+                            message.tool_call_id,
+                            message.content),
+                    }
+                )
+            else:
+                converted_messages.append(
+                    {"role": message.role, "content": message.content}
+                )
-        args = json.dumps(function.get("arguments", {}))
-        function = ResponseFunction(
-            name=function["name"], arguments=args.replace("\\\\n", "\\n")
-        )
-        return [
-            ToolCall(
-                id="0",
-                function=function,
+        if chat_request.tools:
+            tools_str = json.dumps(
+                [tool.function.model_dump() for tool in chat_request.tools]
             )
-        ]
+            system_prompt += self.tool_prompt.format(tools=tools_str)
+            # Prefill the assistant turn so the reply starts inside the answer tag.
+            converted_messages.append({
+                "role": "assistant",
+                "content": "<ans>"
+            })
+            # Stop generation once the tool call's closing tag is reached.
+            args["stop_sequences"] = ["</tool>"]
+ args["messages"] = self.merge_message(converted_messages)
+ if system_prompt:
+ if DEBUG:
+ logger.info("System Prompt: " + system_prompt)
+ args["system"] = system_prompt
+
+ return args
+
+    def chat(self, chat_request: ChatRequest) -> ChatResponse:
+        if DEBUG:
+            logger.info("Raw request: " + chat_request.model_dump_json())
+        response = self._invoke_model(
+            args=self._parse_args(chat_request), model_id=chat_request.model
+        )
+        response_body = json.loads(response.get("body").read())
+        if DEBUG:
+            logger.info("Bedrock response body: " + str(response_body))
+        message = response_body["content"][0]["text"]
+
+        tools = None
+        if chat_request.tools:
+            if message.startswith("Y"):
+                # The model chose to call a tool; parse the call and drop the text.
+                tools = self._parse_tool_message(message)
+                message = None
+            elif message.startswith("N"):
+                # Strip the 8-character "N</ans>\n" answer prefix before the real reply.
+                message = message[8:].lstrip("\n")
+        return self._create_response(
+            model=chat_request.model,
+            message_id=response_body["id"],
+            message=message,
+            tools=tools,
+            finish_reason=response_body["stop_reason"],
+            input_tokens=response_body["usage"]["input_tokens"],
+            output_tokens=response_body["usage"]["output_tokens"],
+        )
+
+    def chat_stream(self, chat_request: ChatRequest) -> AsyncIterable[bytes]:
+        response = self._invoke_model(
+            args=self._parse_args(chat_request),
+            model_id=chat_request.model,
+            with_stream=True,
+        )
+        msg_id = ""
+        chunk_id = 0
+        tool_message = ""
+        first_token = True
+        index = 0
+        for event in response.get("body"):
+            if DEBUG:
+                logger.info("Bedrock response chunk: " + str(event))
+            chunk = json.loads(event["chunk"]["bytes"])
+            chunk_id += 1
+
+            if chunk["type"] == "message_start":
+                msg_id = chunk["message"]["id"]
+                continue
+
+            if chunk["type"] == "message_delta":
+                chunk_message = ""
+                finish_reason = chunk["delta"]["stop_reason"]
+
+            elif chunk["type"] == "content_block_delta":
+                chunk_message = chunk["delta"]["text"]
+                finish_reason = None
+                if chat_request.tools:
+                    # A "Y" as the very first token means a tool call follows.
+                    if not tool_message and index == 0 and chunk_message == "Y":
+                        tool_message = "Y"
+                        continue
+                    if tool_message:
+                        # Buffer all chunk messages
+                        # in order to extract the tool call info at the end.
+                        tool_message += chunk_message
+                        continue
+                    if index < 3:
+                        # Skip the "N</ans>" answer marker, which arrives as 3 tokens.
+                        index += 1
+                        continue
+                    if first_token:
+                        chunk_message = chunk_message.lstrip("\n")
+                        first_token = False
+            else:
+                continue
+            response = self._create_response_stream(
+                model=chat_request.model,
+                message_id=msg_id,
+                chunk_message=chunk_message,
+                finish_reason=finish_reason,
+            )
+
+            yield self._stream_response_to_bytes(response)
+        if chat_request.tools and tool_message:
+            # Emit the buffered tool call as one final delta.
+            tools = self._parse_tool_message(tool_message)
+            response = self._create_response_stream(
+                model=chat_request.model,
+                message_id=msg_id,
+                tools=tools,
+            )
+            yield self._stream_response_to_bytes(response)
+
+    def _parse_tool_message(self, tool_message: str) -> list[ToolCall]:
+        if DEBUG:
+            logger.info("Tool message: " + tool_message.replace("\n", " "))
+        try:
+            # Everything after the last "<tool>" tag is the JSON payload;
+            # the closing "</tool>" was consumed as a stop sequence.
+            tool_json = tool_message[tool_message.rindex("<tool>") + 6:]
+            function = json.loads(tool_json.replace("\n", " "))
+            args = json.dumps(function.get("arguments", {}))
+            function = ResponseFunction(
+                name=function["name"], arguments=args
+            )
+
+            return [
+                ToolCall(
+                    # The id now defaults to a generated value in the schema.
+                    function=function,
+                )
+            ]
+
+        except Exception as e:
+            logger.error("Failed to parse tool response: %s", e)
+            raise HTTPException(status_code=500, detail="Failed to parse tool response")
     def _get_base64_image(self, image_url: str) -> tuple[str, str]:
         """Try to get the base64 data from an image url.
@@ -229,131 +420,6 @@ class ClaudeModel(BedrockModel):
                 )
         return content_parts
-    def _create_tool_prompt(self, tools: list[Tool]) -> str:
-        tool_prompt = "\nYou have access to the following tools:\n"
-        tool_prompt += json.dumps(
-            [tool.function.model_dump() for tool in tools], indent=2
-        )
-        tool_prompt += (
-            "\nIf you need to use one of the above tools, "
-            "only respond with a JSON object matching the following schema inside a <tool> xml tag: \n"
-            '{"name": $TOOL_NAME, "arguments": {"$PARAMETER_NAME": "$PARAMETER_VALUE", ...}\n'
-        )
-        return tool_prompt
-
-    def _parse_args(self, chat_request: ChatRequest) -> dict:
-        args = {
-            "anthropic_version": self.anthropic_version,
-            "max_tokens": chat_request.max_tokens,
-            "top_p": chat_request.top_p,
-            "temperature": chat_request.temperature,
-        }
-        system_prompt = ""
-        converted_messages = []
-        for message in chat_request.messages:
-            if message.role == "system":
-                system_prompt += message.content + "\n"
-            elif message.role == "user" and not isinstance(message.content, str):
-                converted_messages.append(
-                    {
-                        "role": message.role,
-                        "content": self._parse_content_parts(message.content),
-                    }
-                )
-            elif message.role == "assistant" and not message.content:
-                # if content is empty
-                # create the content using the tool call info.
-                tool_content = "Should use {} tool with args: {}".format(
-                    message.tool_calls[0].function.name,
-                    message.tool_calls[0].function.arguments,
-                )
-                converted_messages.append(
-                    {"role": message.role, "content": tool_content}
-                )
-            elif message.role == "tool":
-                # Since bedrock does not support tool role
-                # Convert the tool message to a user message.
-                converted_messages.append(
-                    {
-                        "role": "user",
-                        "content": "The result of the tool call is " + message.content,
-                    }
-                )
-            else:
-                converted_messages.append(
-                    {"role": message.role, "content": message.content}
-                )
-
-        if chat_request.tools:
-            system_prompt += self._create_tool_prompt(chat_request.tools)
-
-        args["messages"] = converted_messages
-        if system_prompt:
-            if DEBUG:
-                logger.info("System Prompt: " + system_prompt)
-            args["system"] = system_prompt.replace("\n", "")
-        return args
-
-    def chat(self, chat_request: ChatRequest) -> ChatResponse:
-        if DEBUG:
-            logger.info("Raw request: " + chat_request.model_dump_json())
-        response = self._invoke_model(
-            args=self._parse_args(chat_request), model_id=chat_request.model
-        )
-        response_body = json.loads(response.get("body").read())
-        if DEBUG:
-            logger.info("Bedrock response body: " + str(response_body))
-        message = response_body["content"][0]["text"]
-
-        tools_message = None
-        start = message.find("<tool>")
-        end = message.find("</tool>")
-        if start != -1 and end != -1:
-            tools_message = message[start + 6: end]
-        return self._create_response(
-            model=chat_request.model,
-            message=response_body["content"][0]["text"],
-            message_id=response_body["id"],
-            tools_message=tools_message,
-            input_tokens=response_body["usage"]["input_tokens"],
-            output_tokens=response_body["usage"]["output_tokens"],
-        )
-
-    def chat_stream(self, chat_request: ChatRequest) -> AsyncIterable[bytes]:
-        response = self._invoke_model(
-            args=self._parse_args(chat_request),
-            model_id=chat_request.model,
-            with_stream=True,
-        )
-        msg_id = ""
-        chunk_id = 0
-        for event in response.get("body"):
-            if DEBUG:
-                logger.info("Bedrock response chunk: " + str(event))
-            chunk = json.loads(event["chunk"]["bytes"])
-            chunk_id += 1
-            if chunk["type"] == "message_start":
-                msg_id = chunk["message"]["id"]
-                continue
-
-            if chunk["type"] == "message_delta":
-                chunk_message = ""
-                finish_reason = "stop"
-
-            elif chunk["type"] == "content_block_delta":
-                chunk_message = chunk["delta"]["text"]
-                finish_reason = None
-            else:
-                continue
-            response = self._create_response_stream(
-                model=chat_request.model,
-                message_id=msg_id,
-                chunk_message=chunk_message,
-                finish_reason=finish_reason,
-            )
-
-            yield self._stream_response_to_bytes(response)
-
class Llama2Model(BedrockModel):
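
Note: to illustrate how the prompt-based tool protocol above is meant to round-trip, here is a minimal, hypothetical sketch. It assumes the "<ans>"/"<tool>" tag names and the "</tool>" stop sequence as reconstructed in this patch; the reply text is invented, not captured model output.

    import json

    # merge_message folds consecutive same-role string messages into one turn,
    # since Anthropic's messages API expects strictly alternating roles.
    msgs = [
        {"role": "user", "content": "Hi"},
        {"role": "user", "content": "What is the weather in Shanghai?"},
    ]
    # -> [{"role": "user", "content": "Hi\nWhat is the weather in Shanghai?"}]

    # Hypothetical reply: the assistant turn was prefilled with "<ans>", and
    # generation stopped at the "</tool>" stop sequence, so the closing tag
    # never appears in the text.
    reply = 'Y</ans>\n<tool>\n{"name": "get_current_weather", "arguments": {"location": "Shanghai"}}'

    if reply.startswith("Y"):
        # Everything after the last "<tool>" tag is the JSON payload.
        payload = reply[reply.rindex("<tool>") + len("<tool>"):]
        call = json.loads(payload)
        print(call["name"], call["arguments"])
        # get_current_weather {'location': 'Shanghai'}
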
diff --git a/src/api/schema.py b/src/api/schema.py
index 321dad8..39b9baf 100644
--- a/src/api/schema.py
+++ b/src/api/schema.py
@@ -1,4 +1,5 @@
import time
+import uuid
from typing import Literal, Iterable
from pydantic import BaseModel, Field
@@ -22,7 +23,7 @@ class ResponseFunction(BaseModel):
class ToolCall(BaseModel):
-    id: str
+    id: str = Field(default_factory=lambda: str(uuid.uuid4())[:8])
     type: Literal["function"] = "function"
     function: ResponseFunction
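
Note: a small usage sketch of the new ToolCall id default. The models are re-declared here with only the fields shown in this diff, so the snippet runs standalone:

    import uuid
    from typing import Literal

    from pydantic import BaseModel, Field

    class ResponseFunction(BaseModel):
        name: str
        arguments: str

    class ToolCall(BaseModel):
        # Callers may now omit `id`; it defaults to the first 8 hex
        # characters of a random UUID4, e.g. "3f9c1a2b".
        id: str = Field(default_factory=lambda: str(uuid.uuid4())[:8])
        type: Literal["function"] = "function"
        function: ResponseFunction

    call = ToolCall(
        function=ResponseFunction(
            name="get_current_weather", arguments='{"location": "Shanghai"}'
        )
    )
    assert len(call.id) == 8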