From f1440602ced87fc3f1fa4f4d815bbcd4b94a3d14 Mon Sep 17 00:00:00 2001 From: Aiden Dai Date: Wed, 3 Apr 2024 11:10:19 +0800 Subject: [PATCH] Add Tool call support --- src/api/models/bedrock.py | 329 ++++++++++++++++++++++++++------------ src/api/schema.py | 47 ++++-- src/requirements.txt | 4 +- 3 files changed, 264 insertions(+), 116 deletions(-) diff --git a/src/api/models/bedrock.py b/src/api/models/bedrock.py index f5bd865..c86d79b 100644 --- a/src/api/models/bedrock.py +++ b/src/api/models/bedrock.py @@ -1,6 +1,7 @@ import base64 import json import logging +from abc import ABC from typing import AsyncIterable, Iterable import boto3 @@ -13,17 +14,21 @@ from api.schema import ( # Chat ChatResponse, ChatRequest, - ChatRequestMessage, Choice, ChatResponseMessage, Usage, ChatStreamResponse, ChoiceDelta, + ImageContent, + TextContent, + ResponseFunction, + ToolCall, + Tool, # Embeddings EmbeddingsRequest, EmbeddingsResponse, EmbeddingsUsage, - Embedding, TextContent, + Embedding, ) from api.setting import DEBUG, AWS_REGION @@ -58,7 +63,7 @@ ENCODER = tiktoken.get_encoding("cl100k_base") # https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters.html -class BedrockModel(BaseChatModel): +class BedrockModel(BaseChatModel, ABC): accept = "application/json" content_type = "application/json" @@ -86,17 +91,30 @@ class BedrockModel(BaseChatModel): model: str, message: str, message_id: str, + tools_message: str | None = None, input_tokens: int = 0, output_tokens: int = 0, ) -> ChatResponse: - choice = Choice( - index=0, - message=ChatResponseMessage( - role="assistant", - content=message, - ), - finish_reason="stop", - ) + if tools_message: + # For tool response, the content is empty + tools = self._parse_tools_response(tools_message) + choice = Choice( + index=0, + message=ChatResponseMessage( + role="assistant", + tool_calls=tools, + ), + finish_reason="stop", + ) + else: + choice = Choice( + index=0, + message=ChatResponseMessage( + role="assistant", + 
content=message, + ), + finish_reason="stop", + ) response = ChatResponse( id=message_id, model=model, @@ -132,25 +150,31 @@ class BedrockModel(BaseChatModel): return response -def get_model(model_id: str) -> BedrockModel: - model_name = SUPPORTED_BEDROCK_MODELS.get(model_id, "") - if DEBUG: - logger.info("model name is " + model_name) - if model_name in ["Claude Instant", "Claude", "Claude 3 Sonnet", "Claude 3 Haiku"]: - return ClaudeModel() - elif model_name in ["Llama 2 Chat 13B", "Llama 2 Chat 70B"]: - return Llama2Model() - elif model_name in ["Mistral 7B Instruct", "Mixtral 8x7B Instruct"]: - return MistralModel() - else: - logger.error("Unsupported model id " + model_id) - raise ValueError("Invalid model ID") - - class ClaudeModel(BedrockModel): anthropic_version = "bedrock-2023-05-31" - def _get_base64_image(self, image_url: str): + def _parse_tools_response(self, tools_messages: str) -> list[ToolCall]: + """Parse the tools response + + Example tool message like: + \n{\n "name": "get_current_weather",\n "arguments": {\n "location": "Shanghai"... 
}\n}\n + """ + function = json.loads( + tools_messages.replace("\n", " ").encode("unicode_escape") + ) + + args = json.dumps(function.get("arguments", {})) + function = ResponseFunction( + name=function["name"], arguments=args.replace("\\\\n", "\\n") + ) + return [ + ToolCall( + id="0", + function=function, + ) + ] + + def _get_base64_image(self, image_url: str) -> str: # Send a request to the image URL response = requests.get(image_url) # Check if the request was successful @@ -159,34 +183,44 @@ class ClaudeModel(BedrockModel): image_content = response.content # Encode the image content as base64 base64_image = base64.b64encode(image_content) - return base64_image.decode('utf-8') + return base64_image.decode("utf-8") else: - raise HTTPException(status_code=500, detail="Unable to access the image url") + raise HTTPException( + status_code=500, detail="Unable to access the image url" + ) - def _parse_messages(self, messages: list[ChatRequestMessage]) -> list[dict]: - # Refer to: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html - converted_messages = [] - for msg in messages: - if isinstance(msg.content, str): - converted_messages.append({"role": msg.role, "content": msg.content}) - continue - - content_parts = [] - for part in msg.content: - if isinstance(part, TextContent): - content_parts.append(part.model_dump()) - else: - content_parts.append({ + def _parse_content_parts( + self, content: list[TextContent | ImageContent] + ) -> list[dict]: + # See: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html + content_parts = [] + for part in content: + if isinstance(part, TextContent): + content_parts.append(part.model_dump()) + else: + content_parts.append( + { "type": "image", "source": { "type": "base64", "media_type": "image/jpeg", - "data": self._get_base64_image(part.image_url.url) - } - }) + "data": self._get_base64_image(part.image_url.url), + }, + } + ) + return 
content_parts

-        converted_messages.append({"role": msg.role, "content": content_parts})
-        return converted_messages
+    def _create_tool_prompt(self, tools: list[Tool]) -> str:
+        tool_prompt = "\nYou have access to the following tools:\n"
+        tool_prompt += json.dumps(
+            [tool.function.model_dump() for tool in tools], indent=2
+        )
+        tool_prompt += (
+            "\nIf you need to use one of the above tools, "
+            "only respond with a JSON object matching the following schema inside a <tool></tool> xml tag: \n"
+            '{"name": $TOOL_NAME, "arguments": {"$PARAMETER_NAME": "$PARAMETER_VALUE", ...}\n'
+        )
+        return tool_prompt

     def _parse_args(self, chat_request: ChatRequest) -> dict:
         args = {
@@ -195,25 +229,73 @@ class ClaudeModel(BedrockModel):
             "top_p": chat_request.top_p,
             "temperature": chat_request.temperature,
         }
-        start = 0
-        if chat_request.messages[0].role == "system":
-            args["system"] = chat_request.messages[0].content
-            start = 1
-        args["messages"] = self._parse_messages(chat_request.messages[start:])
+        system_prompt = ""
+        converted_messages = []
+        for message in chat_request.messages:
+            if message.role == "system":
+                system_prompt += message.content + "\n"
+            elif message.role == "user" and not isinstance(message.content, str):
+                converted_messages.append(
+                    {
+                        "role": message.role,
+                        "content": self._parse_content_parts(message.content),
+                    }
+                )
+            elif message.role == "assistant" and not message.content:
+                # if content is empty
+                # create the content using the tool call info.
+                tool_content = "Should use {} tool with args: {}".format(
+                    message.tool_calls[0].function.name,
+                    message.tool_calls[0].function.arguments,
+                )
+                converted_messages.append(
+                    {"role": message.role, "content": tool_content}
+                )
+            elif message.role == "tool":
+                # Since bedrock does not support tool role
+                # Convert the tool message to a user message.
+ converted_messages.append( + { + "role": "user", + "content": "The result of the tool call is " + message.content, + } + ) + else: + converted_messages.append( + {"role": message.role, "content": message.content} + ) + + if chat_request.tools: + system_prompt += self._create_tool_prompt(chat_request.tools) + + args["messages"] = converted_messages + if system_prompt: + if DEBUG: + logger.info("System Prompt: " + system_prompt) + args["system"] = system_prompt.replace("\n", "") return args def chat(self, chat_request: ChatRequest) -> ChatResponse: + if DEBUG: + logger.info("Raw request: " + chat_request.model_dump_json()) response = self._invoke_model( args=self._parse_args(chat_request), model_id=chat_request.model ) response_body = json.loads(response.get("body").read()) if DEBUG: logger.info("Bedrock response body: " + str(response_body)) + message = response_body["content"][0]["text"] + tools_message = None + start = message.find("") + end = message.find("") + if start != -1 and end != -1: + tools_message = message[start + 6: end] return self._create_response( model=chat_request.model, message=response_body["content"][0]["text"], message_id=response_body["id"], + tools_message=tools_message, input_tokens=response_body["usage"]["input_tokens"], output_tokens=response_body["usage"]["output_tokens"], ) @@ -256,7 +338,7 @@ class ClaudeModel(BedrockModel): class Llama2Model(BedrockModel): - def _convert_prompt(self, messages: list[ChatRequestMessage]) -> str: + def _convert_prompt(self, chat_request: ChatRequest) -> str: """Create a prompt message follow below example: [INST] <>\n{your_system_message}\n<>\n\n{user_message_1} [/INST] {model_reply_1} @@ -264,21 +346,26 @@ class Llama2Model(BedrockModel): """ if DEBUG: logger.info("Convert below messages to prompt for Llama 2: ") - for msg in messages: + for msg in chat_request.messages: logger.info(msg.model_dump_json()) bos_token = "" eos_token = "" - prompt = bos_token + "[INST] " - start = 0 + prompt = "" end_turn 
= False
-        if messages[0].role == "system":
-            prompt += "<<SYS>>\n" + messages[0].content + "\n<</SYS>>\n\n"
-            start = 1
-        # TODO: Add validation
-        for i in range(start, len(messages)):
-            msg = messages[i]
+        system_prompt = ""
+        for msg in chat_request.messages:
+            if msg.role == "system":
+                system_prompt += "\n" + msg.content + "\n"
+                continue
+            if msg.role == "tool":
+                raise HTTPException(
+                    status_code=500,
+                    detail="Tool prompt is not supported for Llama 2 model",
+                )
             if not isinstance(msg.content, str):
-                raise HTTPException(status_code=400, detail="Content must be a string for Llama 2 model")
+                raise HTTPException(
+                    status_code=400, detail="Content must be a string for Llama 2 model"
+                )
             if msg.role == "user":
                 if end_turn:
                     prompt += bos_token + "[INST] "
@@ -287,12 +374,16 @@ class Llama2Model(BedrockModel):
                 prompt += msg.content + " [/INST] "
             else:
                 prompt += msg.content + eos_token
                 end_turn = True
+
+        if system_prompt:
+            system_prompt = "<<SYS>>" + system_prompt + "<</SYS>>"
+        prompt = bos_token + "[INST] " + system_prompt + prompt
         if DEBUG:
             logger.info("Converted prompt: " + prompt.replace("\n", "\\n"))
         return prompt

     def _parse_args(self, chat_request: ChatRequest) -> dict:
-        prompt = self._convert_prompt(chat_request.messages)
+        prompt = self._convert_prompt(chat_request)
         return {
             "prompt": prompt,
             "max_gen_len": chat_request.max_tokens,
@@ -340,29 +431,36 @@


 class MistralModel(BedrockModel):
-    def _convert_prompt(self, messages: list[ChatRequestMessage]) -> str:
+    def _convert_prompt(self, chat_request: ChatRequest) -> str:
         """Create a prompt message follow below example:

        <s>[INST] {your_system_message}\n{user_message_1} [/INST] {model_reply_1}</s>
        <s>[INST] {user_message_2} [/INST]
         """
+        # TODO: maybe reuse the Llama 2 one.
if DEBUG:
-            logger.info("Convert below messages to prompt for Llama 2: ")
+            logger.info("Convert below messages to prompt for Mistral/Mixtral model: ")
-        for msg in messages:
+        for msg in chat_request.messages:
             logger.info(msg.model_dump_json())
         bos_token = "<s>"
         eos_token = "</s>"
-        prompt = bos_token + "[INST] "
-        start = 0
+        prompt = ""
         end_turn = False
-        if messages[0].role == "system":
-            prompt += messages[0].content + "\n"
-            start = 1
-        # TODO: Add validation
-        for i in range(start, len(messages)):
-            msg = messages[i]
+        system_prompt = ""
+        for msg in chat_request.messages:
+            if msg.role == "system":
+                system_prompt += "\n" + msg.content + "\n"
+                continue
+            if msg.role == "tool":
+                raise HTTPException(
+                    status_code=500,
+                    detail="Tool prompt is not supported for Mistral/Mixtral model",
+                )
             if not isinstance(msg.content, str):
-                raise HTTPException(status_code=400, detail="Content must be a string for Mistral/Mixtral model")
+                raise HTTPException(
+                    status_code=400,
+                    detail="Content must be a string for Mistral/Mixtral model",
+                )
             if msg.role == "user":
                 if end_turn:
                     prompt += bos_token + "[INST] "
@@ -371,12 +469,14 @@
                 prompt += msg.content + " [/INST] "
             else:
                 prompt += msg.content + eos_token
                 end_turn = True
+
+        prompt = bos_token + "[INST] " + system_prompt + prompt
         if DEBUG:
             logger.info("Converted prompt: " + prompt.replace("\n", "\\n"))
         return prompt

     def _parse_args(self, chat_request: ChatRequest) -> dict:
-        prompt = self._convert_prompt(chat_request.messages)
+        prompt = self._convert_prompt(chat_request)
         return {
             "prompt": prompt,
             "max_tokens": chat_request.max_tokens,
@@ -422,7 +522,7 @@
         yield self._stream_response_to_bytes(response)


-class BedrockEmbeddingsModel(BaseEmbeddingsModel):
+class BedrockEmbeddingsModel(BaseEmbeddingsModel, ABC):
     accept = "application/json"
     content_type = "application/json"
@@ -446,10 +546,8 @@ class BedrockEmbeddingsModel(BaseEmbeddingsModel):
         output_tokens: int = 0,
     ) -> EmbeddingsResponse:
         data =
[ - Embedding( - index=i, - embedding=embedding - ) for i, embedding in enumerate(embeddings) + Embedding(index=i, embedding=embedding) + for i, embedding in enumerate(embeddings) ] response = EmbeddingsResponse( data=data, @@ -465,19 +563,6 @@ class BedrockEmbeddingsModel(BaseEmbeddingsModel): return response -def get_embeddings_model(model_id: str) -> BedrockEmbeddingsModel: - model_name = SUPPORTED_BEDROCK_EMBEDDING_MODELS.get(model_id, "") - if DEBUG: - logger.info("model name is " + model_name) - if model_name in ["Cohere Embed Multilingual", "Cohere Embed English"]: - return CohereEmbeddingsModel() - elif model_name in ["Titan Embeddings G1 - Text", "Titan Multimodal Embeddings G1"]: - return TitanEmbeddingsModel() - else: - logger.error("Unsupported model id " + model_id) - raise ValueError("Invalid model ID") - - class CohereEmbeddingsModel(BedrockEmbeddingsModel): def _parse_args(self, embeddings_request: EmbeddingsRequest) -> dict: @@ -528,17 +613,25 @@ class TitanEmbeddingsModel(BedrockEmbeddingsModel): def _parse_args(self, embeddings_request: EmbeddingsRequest) -> dict: if isinstance(embeddings_request.input, str): input_text = embeddings_request.input - elif isinstance(embeddings_request.input, list) and len(embeddings_request.input) == 1: + elif ( + isinstance(embeddings_request.input, list) + and len(embeddings_request.input) == 1 + ): input_text = embeddings_request.input[0] else: - raise ValueError("Amazon Titan Embeddings models support only single strings as input.") + raise ValueError( + "Amazon Titan Embeddings models support only single strings as input." + ) args = { "inputText": input_text, # Note: inputImage is not supported! 
} if embeddings_request.model == "amazon.titan-embed-image-v1": - args["embeddingConfig"] = embeddings_request.embedding_config if embeddings_request.embedding_config else { - "outputEmbeddingLength": 1024} + args["embeddingConfig"] = ( + embeddings_request.embedding_config + if embeddings_request.embedding_config + else {"outputEmbeddingLength": 1024} + ) return args def embed(self, embeddings_request: EmbeddingsRequest) -> EmbeddingsResponse: @@ -552,5 +645,39 @@ class TitanEmbeddingsModel(BedrockEmbeddingsModel): return self._create_response( embeddings=[response_body["embedding"]], model=embeddings_request.model, - input_tokens=response_body["inputTextTokenCount"] + input_tokens=response_body["inputTextTokenCount"], + ) + + +def get_model(model_id: str) -> BedrockModel: + model_name = SUPPORTED_BEDROCK_MODELS.get(model_id, "") + if DEBUG: + logger.info("model name is " + model_name) + if model_name in ["Claude Instant", "Claude", "Claude 3 Sonnet", "Claude 3 Haiku"]: + return ClaudeModel() + elif model_name in ["Llama 2 Chat 13B", "Llama 2 Chat 70B"]: + return Llama2Model() + elif model_name in ["Mistral 7B Instruct", "Mixtral 8x7B Instruct"]: + return MistralModel() + else: + logger.error("Unsupported model id " + model_id) + raise HTTPException( + status_code=500, + detail="Unsupported model id " + model_id, + ) + + +def get_embeddings_model(model_id: str) -> BedrockEmbeddingsModel: + model_name = SUPPORTED_BEDROCK_EMBEDDING_MODELS.get(model_id, "") + if DEBUG: + logger.info("model name is " + model_name) + if model_name in ["Cohere Embed Multilingual", "Cohere Embed English"]: + return CohereEmbeddingsModel() + elif model_name in ["Titan Embeddings G1 - Text", "Titan Multimodal Embeddings G1"]: + return TitanEmbeddingsModel() + else: + logger.error("Unsupported model id " + model_id) + raise HTTPException( + status_code=500, + detail="Unsupported model id " + model_id, ) diff --git a/src/api/schema.py b/src/api/schema.py index aa87a58..321dad8 100644 --- 
a/src/api/schema.py +++ b/src/api/schema.py @@ -16,6 +16,17 @@ class Models(BaseModel): data: list[Model] = [] +class ResponseFunction(BaseModel): + name: str + arguments: str + + +class ToolCall(BaseModel): + id: str + type: Literal["function"] = "function" + function: ResponseFunction + + class TextContent(BaseModel): type: Literal["text"] = "text" text: str @@ -31,12 +42,31 @@ class ImageContent(BaseModel): image_url: ImageUrl -class ChatRequestMessage(BaseModel): +class SystemMessage(BaseModel): name: str | None = None - role: Literal["user", "assistant", "system"] + role: Literal["system"] = "system" + content: str + + +class UserMessage(BaseModel): + name: str | None = None + role: Literal["user"] = "user" content: str | list[TextContent | ImageContent] +class AssistantMessage(BaseModel): + name: str | None = None + role: Literal["assistant"] = "assistant" + content: str | None + tool_calls: list[ToolCall] | None = None + + +class ToolMessage(BaseModel): + role: Literal["tool"] = "tool" + content: str + tool_call_id: str + + class Function(BaseModel): name: str description: str | None = None @@ -49,7 +79,7 @@ class Tool(BaseModel): class ChatRequest(BaseModel): - messages: list[ChatRequestMessage] + messages: list[SystemMessage | UserMessage | AssistantMessage | ToolMessage] model: str frequency_penalty: float | None = Field(default=0.0, le=2.0, ge=-2.0) # Not used presence_penalty: float | None = Field(default=0.0, le=2.0, ge=-2.0) # Not used @@ -69,17 +99,6 @@ class Usage(BaseModel): total_tokens: int -class ResponseFunction(BaseModel): - name: str - arguments: str - - -class ToolCall(BaseModel): - id: str - type: Literal["function"] = "function" - function: ResponseFunction - - class ChatResponseMessage(BaseModel): # tool_calls role: Literal["assistant"] | None = None diff --git a/src/requirements.txt b/src/requirements.txt index 322918e..49019b8 100644 --- a/src/requirements.txt +++ b/src/requirements.txt @@ -1,4 +1,6 @@ fastapi==0.110.0 pydantic==2.6.3 
uvicorn==0.27.0.post1 -mangum==0.17.0 \ No newline at end of file +mangum==0.17.0 +tiktoken==0.6.0 +requests==2.31.0 \ No newline at end of file