diff --git a/src/api/models/bedrock.py b/src/api/models/bedrock.py
index 9e83460..f131179 100644
--- a/src/api/models/bedrock.py
+++ b/src/api/models/bedrock.py
@@ -24,7 +24,6 @@ from api.schema import (
     TextContent,
     ResponseFunction,
     ToolCall,
-    Tool,
     # Embeddings
     EmbeddingsRequest,
    EmbeddingsResponse,
@@ -88,39 +87,60 @@ class BedrockModel(BaseChatModel, ABC):
             contentType=self.content_type,
         )
 
+    @staticmethod
+    def merge_message(messages: list[dict]) -> list[dict]:
+        """Merge consecutive messages that share the same role.
+
+        String contents are joined with a newline and exact duplicates are
+        skipped; messages with list content are kept as separate messages.
+        """
+        merged_messages = []
+        prev_role = None
+        merged_content = ""
+
+        for message in messages:
+            role = message["role"]
+            content = message["content"]
+
+            if role != prev_role or isinstance(content, list):
+                if prev_role:
+                    merged_messages.append({"role": prev_role, "content": merged_content})
+                if isinstance(content, str):
+                    merged_content = content
+                    prev_role = role
+                else:
+                    merged_messages.append({"role": role, "content": content})
+                    prev_role = None
+                    merged_content = ""
+            else:
+                if content == merged_content:
+                    # Ignore exact duplicates.
+                    continue
+                merged_content += "\n" + content
+
+        if merged_content:
+            merged_messages.append({"role": prev_role, "content": merged_content})
+        return merged_messages
+
+    @staticmethod
     def _create_response(
-        self,
         model: str,
-        message: str,
         message_id: str,
-        tools_message: str | None = None,
+        message: str | None = None,
+        finish_reason: str | None = None,
+        tools: list[ToolCall] | None = None,
         input_tokens: int = 0,
         output_tokens: int = 0,
     ) -> ChatResponse:
-        if tools_message:
-            # For tool response, the content is empty
-            tools = self._parse_tools_response(tools_message)
-            choice = Choice(
+        response = ChatResponse(
+            id=message_id,
+            model=model,
+            choices=[Choice(
                 index=0,
                 message=ChatResponseMessage(
                     role="assistant",
                     tool_calls=tools,
-                ),
-                finish_reason="stop",
-            )
-        else:
-            choice = Choice(
-                index=0,
-                message=ChatResponseMessage(
-                    role="assistant",
                     content=message,
                 ),
-                finish_reason="stop",
-            )
-        response = ChatResponse(
-            id=message_id,
-            model=model,
-            choices=[choice],
+                finish_reason="tool_calls" if tools else finish_reason,
+            )],
             usage=Usage(
                 prompt_tokens=input_tokens,
                 completion_tokens=output_tokens,
@@ -131,21 +151,26 @@ class BedrockModel(BaseChatModel, ABC):
             logger.info("Proxy response :" + response.model_dump_json())
         return response
 
+    @staticmethod
     def _create_response_stream(
-        self, model: str, message_id: str, chunk_message: str, finish_reason: str | None
+        model: str,
+        message_id: str,
+        chunk_message: str | None = None,
+        finish_reason: str | None = None,
+        tools: list[ToolCall] | None = None,
     ) -> ChatStreamResponse:
-        choice = ChoiceDelta(
-            index=0,
-            delta=ChatResponseMessage(
-                role="assistant",
-                content=chunk_message,
-            ),
-            finish_reason=finish_reason,
-        )
         response = ChatStreamResponse(
             id=message_id,
             model=model,
-            choices=[choice],
+            choices=[ChoiceDelta(
+                index=0,
+                delta=ChatResponseMessage(
+                    role="assistant",
+                    tool_calls=tools,
+                    content=chunk_message,
+                ),
+                finish_reason="tool_calls" if tools else finish_reason,
+            )],
         )
         if DEBUG:
             logger.info("Proxy response :" + response.model_dump_json())
@@ -154,27 +179,193 @@ class BedrockModel(BaseChatModel, ABC):
 
 class ClaudeModel(BedrockModel):
     anthropic_version = "bedrock-2023-05-31"
+    # Prompt template instructing the model how to decide on and format tool calls.
+    tool_prompt = """You have access to the following tools:
+{tools}
 
-    def _parse_tools_response(self, tools_messages: str) -> list[ToolCall]:
-        """Parse the tools response
+Please think about whether you need to use a tool for the user's question. You must:
+1. First respond Y or N inside a <flag> xml tag to indicate that.
+2. If a tool is needed, you MUST respond with a JSON object matching the following schema inside a <tool> xml tag:
+  {{"name": $TOOL_NAME, "arguments": {{"$PARAMETER_NAME": "$PARAMETER_VALUE", ...}}}}
+3. If no tool is needed, respond with normal text."""
 
-        Example tool message like:
-        <tool>\n{\n  "name": "get_current_weather",\n  "arguments": {\n  "location": "Shanghai"... }\n}\n</tool>
-        """
-        function = json.loads(
-            tools_messages.replace("\n", " ").encode("unicode_escape")
-        )
+    def _parse_args(self, chat_request: ChatRequest) -> dict:
+        args = {
+            "anthropic_version": self.anthropic_version,
+            "max_tokens": chat_request.max_tokens,
+            "top_p": chat_request.top_p,
+            "temperature": chat_request.temperature,
+        }
+        system_prompt = ""
+        converted_messages = []
+        for message in chat_request.messages:
+            if message.role == "system":
+                system_prompt += message.content + "\n"
+            elif message.role == "user" and not isinstance(message.content, str):
+                converted_messages.append(
+                    {
+                        "role": message.role,
+                        "content": self._parse_content_parts(message.content),
+                    }
+                )
+            elif message.role == "assistant" and not message.content:
+                # If the content is empty, recreate it from the tool call info.
+                tool_content = "[Tool use for `{}` with id `{}` and the following `input`]\n{}".format(
+                    message.tool_calls[0].function.name,
+                    message.tool_calls[0].id,
+                    message.tool_calls[0].function.arguments,
+                )
+                converted_messages.append(
+                    {"role": message.role, "content": tool_content}
+                )
+            elif message.role == "tool":
+                # Bedrock does not support the tool role,
+                # so convert the tool message to a user message.
+                converted_messages.append(
+                    {
+                        "role": "user",
+                        "content": "[Tool result with matching id `{}` of `{}`] ".format(
+                            message.tool_call_id,
+                            message.content),
+                    }
+                )
+            else:
+                converted_messages.append(
+                    {"role": message.role, "content": message.content}
+                )
 
-        args = json.dumps(function.get("arguments", {}))
-        function = ResponseFunction(
-            name=function["name"], arguments=args.replace("\\\\n", "\\n")
-        )
-        return [
-            ToolCall(
-                id="0",
-                function=function,
+        if chat_request.tools:
+            tools_str = json.dumps(
+                [tool.function.model_dump() for tool in chat_request.tools]
             )
-        ]
+            system_prompt += self.tool_prompt.format(tools=tools_str)
+            # Prefill the assistant turn with the opening <flag> tag so the
+            # response starts directly with Y or N.
+            converted_messages.append({
+                'role': 'assistant',
+                'content': '<flag>'
+            })
+            # Stop generation at the closing tool tag so the buffered text
+            # ends right after the tool-call JSON.
+            args["stop_sequences"] = ['</tool>']
+        args["messages"] = self.merge_message(converted_messages)
+        if system_prompt:
+            if DEBUG:
+                logger.info("System Prompt: " + system_prompt)
+            args["system"] = system_prompt
+
+        return args
+
+    def chat(self, chat_request: ChatRequest) -> ChatResponse:
+        if DEBUG:
+            logger.info("Raw request: " + chat_request.model_dump_json())
+        response = self._invoke_model(
+            args=self._parse_args(chat_request), model_id=chat_request.model
+        )
+        response_body = json.loads(response.get("body").read())
+        if DEBUG:
+            logger.info("Bedrock response body: " + str(response_body))
+        message = response_body["content"][0]["text"]
+
+        tools = None
+        if chat_request.tools:
+            if message.startswith("Y"):
+                tools = self._parse_tool_message(message)
+                message = None
+            elif message.startswith("N"):
+                # Strip the leading "N</flag>" marker.
+                message = message[8:].lstrip("\n")
+        return self._create_response(
+            model=chat_request.model,
+            message_id=response_body["id"],
+            message=message,
+            tools=tools,
+            finish_reason=response_body["stop_reason"],
+            input_tokens=response_body["usage"]["input_tokens"],
+            output_tokens=response_body["usage"]["output_tokens"],
+        )
+
+    def chat_stream(self, chat_request: ChatRequest) -> AsyncIterable[bytes]:
+        response = self._invoke_model(
+            args=self._parse_args(chat_request),
+            model_id=chat_request.model,
+            with_stream=True,
+        )
+        msg_id = ""
+        chunk_id = 0
+        tool_message = ""
+        first_token = True
+        index = 0
+        for event in response.get("body"):
+            if DEBUG:
+                logger.info("Bedrock response chunk: " + str(event))
+            chunk = json.loads(event["chunk"]["bytes"])
+            chunk_id += 1
+
+            if chunk["type"] == "message_start":
+                msg_id = chunk["message"]["id"]
+                continue
+
+            if chunk["type"] == "message_delta":
+                chunk_message = ""
+                finish_reason = chunk["delta"]["stop_reason"]
+
+            elif chunk["type"] == "content_block_delta":
+                chunk_message = chunk["delta"]["text"]
+                finish_reason = None
+                if chat_request.tools:
+                    # Check whether the first token is the Y flag.
+                    if not tool_message and chunk_message == "Y":
+                        tool_message = "Y"
+                        continue
+                    if tool_message:
+                        # Buffer all subsequent chunks so the tool call
+                        # info can be extracted once the stream ends.
+                        tool_message += chunk_message
+                        continue
+                    if index < 3:
+                        # Skip the leading "N</flag>" marker, which streams as 3 tokens.
+                        index += 1
+                        continue
+                    if first_token:
+                        chunk_message = chunk_message.lstrip("\n")
+                        first_token = False
+            else:
+                continue
+            response = self._create_response_stream(
+                model=chat_request.model,
+                message_id=msg_id,
+                chunk_message=chunk_message,
+                finish_reason=finish_reason,
+            )
+
+            yield self._stream_response_to_bytes(response)
+        if chat_request.tools and tool_message:
+            tools = self._parse_tool_message(tool_message)
+            response = self._create_response_stream(
+                model=chat_request.model,
+                message_id=msg_id,
+                tools=tools,
+            )
+            yield self._stream_response_to_bytes(response)
+
+    def _parse_tool_message(self, tool_message: str) -> list[ToolCall]:
+        if DEBUG:
+            logger.info("Tool message: " + tool_message.replace("\n", " "))
+        try:
+            # Generation stops at the "</tool>" stop sequence, so everything
+            # after the last opening <tool> tag is the tool-call JSON.
+            tool_messages = tool_message[tool_message.rindex("<tool>") + 6:]
+            function = json.loads(tool_messages.replace("\n", " "))
+            args = json.dumps(function.get("arguments", {}))
+            function = ResponseFunction(
+                name=function["name"], arguments=args
+            )
+
+            return [
+                ToolCall(
+                    # id is auto-generated by the schema's default factory.
+                    function=function,
+                )
+            ]
+
+        except Exception as e:
+            logger.error("Failed to parse tool response: " + str(e))
+            raise HTTPException(status_code=500, detail="Failed to parse tool response")
 
     def _get_base64_image(self, image_url: str) -> tuple[str, str]:
         """Try to get the base64 data from an image url.
@@ -229,131 +420,6 @@ class ClaudeModel(BedrockModel):
         )
         return content_parts
 
-    def _create_tool_prompt(self, tools: list[Tool]) -> str:
-        tool_prompt = "\nYou have access to the following tools:\n"
-        tool_prompt += json.dumps(
-            [tool.function.model_dump() for tool in tools], indent=2
-        )
-        tool_prompt += (
-            "\nIf you need to use one of the above tools, "
-            "only respond with a JSON object matching the following schema inside a <tool> xml tag: \n"
-            '{"name": $TOOL_NAME, "arguments": {"$PARAMETER_NAME": "$PARAMETER_VALUE", ...}\n'
-        )
-        return tool_prompt
-
-    def _parse_args(self, chat_request: ChatRequest) -> dict:
-        args = {
-            "anthropic_version": self.anthropic_version,
-            "max_tokens": chat_request.max_tokens,
-            "top_p": chat_request.top_p,
-            "temperature": chat_request.temperature,
-        }
-        system_prompt = ""
-        converted_messages = []
-        for message in chat_request.messages:
-            if message.role == "system":
-                system_prompt += message.content + "\n"
-            elif message.role == "user" and not isinstance(message.content, str):
-                converted_messages.append(
-                    {
-                        "role": message.role,
-                        "content": self._parse_content_parts(message.content),
-                    }
-                )
-            elif message.role == "assistant" and not message.content:
-                # if content is empty
-                # create the content using the tool call info.
-                tool_content = "Should use {} tool with args: {}".format(
-                    message.tool_calls[0].function.name,
-                    message.tool_calls[0].function.arguments,
-                )
-                converted_messages.append(
-                    {"role": message.role, "content": tool_content}
-                )
-            elif message.role == "tool":
-                # Since bedrock does not support tool role
-                # Convert the tool message to a user message.
-                converted_messages.append(
-                    {
-                        "role": "user",
-                        "content": "The result of the tool call is " + message.content,
-                    }
-                )
-            else:
-                converted_messages.append(
-                    {"role": message.role, "content": message.content}
-                )
-
-        if chat_request.tools:
-            system_prompt += self._create_tool_prompt(chat_request.tools)
-
-        args["messages"] = converted_messages
-        if system_prompt:
-            if DEBUG:
-                logger.info("System Prompt: " + system_prompt)
-            args["system"] = system_prompt.replace("\n", "")
-        return args
-
-    def chat(self, chat_request: ChatRequest) -> ChatResponse:
-        if DEBUG:
-            logger.info("Raw request: " + chat_request.model_dump_json())
-        response = self._invoke_model(
-            args=self._parse_args(chat_request), model_id=chat_request.model
-        )
-        response_body = json.loads(response.get("body").read())
-        if DEBUG:
-            logger.info("Bedrock response body: " + str(response_body))
-        message = response_body["content"][0]["text"]
-
-        tools_message = None
-        start = message.find("<tool>")
-        end = message.find("</tool>")
-        if start != -1 and end != -1:
-            tools_message = message[start + 6: end]
-        return self._create_response(
-            model=chat_request.model,
-            message=response_body["content"][0]["text"],
-            message_id=response_body["id"],
-            tools_message=tools_message,
-            input_tokens=response_body["usage"]["input_tokens"],
-            output_tokens=response_body["usage"]["output_tokens"],
-        )
-
-    def chat_stream(self, chat_request: ChatRequest) -> AsyncIterable[bytes]:
-        response = self._invoke_model(
-            args=self._parse_args(chat_request),
-            model_id=chat_request.model,
-            with_stream=True,
-        )
-        msg_id = ""
-        chunk_id = 0
-        for event in response.get("body"):
-            if DEBUG:
-                logger.info("Bedrock response chunk: " + str(event))
-            chunk = json.loads(event["chunk"]["bytes"])
-            chunk_id += 1
-            if chunk["type"] == "message_start":
-                msg_id = chunk["message"]["id"]
-                continue
-
-            if chunk["type"] == "message_delta":
-                chunk_message = ""
-                finish_reason = "stop"
-
-            elif chunk["type"] == "content_block_delta":
-                chunk_message = chunk["delta"]["text"]
-                finish_reason = None
-            else:
-                continue
-            response = self._create_response_stream(
-                model=chat_request.model,
-                message_id=msg_id,
-                chunk_message=chunk_message,
-                finish_reason=finish_reason,
-            )
-
-            yield self._stream_response_to_bytes(response)
-
 
 class Llama2Model(BedrockModel):
diff --git a/src/api/schema.py b/src/api/schema.py
index 321dad8..39b9baf 100644
--- a/src/api/schema.py
+++ b/src/api/schema.py
@@ -1,4 +1,5 @@
 import time
+import uuid
 from typing import Literal, Iterable
 
 from pydantic import BaseModel, Field
@@ -22,7 +23,7 @@ class ResponseFunction(BaseModel):
 
 class ToolCall(BaseModel):
-    id: str
+    id: str = Field(default_factory=lambda: str(uuid.uuid4())[:8])
     type: Literal["function"] = "function"
     function: ResponseFunction
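
For context, a minimal client-side sketch of the tool-call round trip this diff enables through the proxy's OpenAI-compatible API. The base URL, API key, Bedrock model ID, and the weather tool schema are illustrative assumptions, not part of the change.

# Hypothetical usage sketch; endpoint, key, model ID, and tool schema are
# assumptions for illustration only.
import json
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/api/v1", api_key="placeholder")

tools = [{
    "type": "function",
    "function": {
        "name": "get_current_weather",
        "description": "Get the current weather for a city",
        "parameters": {
            "type": "object",
            "properties": {"location": {"type": "string"}},
            "required": ["location"],
        },
    },
}]
messages = [{"role": "user", "content": "What is the weather in Shanghai?"}]

# First call: the proxy returns finish_reason == "tool_calls" with an
# auto-generated ToolCall id (see the schema.py change above).
first = client.chat.completions.create(
    model="anthropic.claude-3-sonnet-20240229-v1:0", messages=messages, tools=tools
)
call = first.choices[0].message.tool_calls[0]

# Echo the assistant tool call and the tool result back; _parse_args
# rewrites both into plain text messages for Bedrock.
messages.append({"role": "assistant", "content": None,
                 "tool_calls": [call.model_dump()]})
messages.append({"role": "tool", "tool_call_id": call.id,
                 "content": json.dumps({"temperature_c": 21})})
final = client.chat.completions.create(
    model="anthropic.claude-3-sonnet-20240229-v1:0", messages=messages, tools=tools
)
print(final.choices[0].message.content)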