Refine tool call

This commit is contained in:
Aiden Dai
2024-04-11 20:53:26 +08:00
parent f11a95cc19
commit 8a9ab560f1
2 changed files with 242 additions and 175 deletions

View File

@@ -24,7 +24,6 @@ from api.schema import (
TextContent, TextContent,
ResponseFunction, ResponseFunction,
ToolCall, ToolCall,
Tool,
# Embeddings # Embeddings
EmbeddingsRequest, EmbeddingsRequest,
EmbeddingsResponse, EmbeddingsResponse,
@@ -88,39 +87,60 @@ class BedrockModel(BaseChatModel, ABC):
contentType=self.content_type, contentType=self.content_type,
) )
@staticmethod
def merge_message(messages: list[dict]) -> list[dict]:
"""Merge the request messages with the same role as previous message
"""
merged_messages = []
prev_role = None
merged_content = ""
for message in messages:
role = message["role"]
content = message["content"]
if role != prev_role or isinstance(content, list):
if prev_role:
merged_messages.append({"role": prev_role, "content": merged_content})
if isinstance(content, str):
merged_content = content
prev_role = role
else:
merged_messages.append({"role": role, "content": content})
prev_role = None
merged_content = ""
else:
if content == merged_content:
# ignore duplicates
continue
merged_content += "\n" + content
if merged_content:
merged_messages.append({"role": prev_role, "content": merged_content})
return merged_messages
@staticmethod
def _create_response( def _create_response(
self,
model: str, model: str,
message: str,
message_id: str, message_id: str,
tools_message: str | None = None, message: str | None = None,
finish_reason: str | None = None,
tools: list[ToolCall] | None = None,
input_tokens: int = 0, input_tokens: int = 0,
output_tokens: int = 0, output_tokens: int = 0,
) -> ChatResponse: ) -> ChatResponse:
if tools_message: response = ChatResponse(
# For tool response, the content is empty id=message_id,
tools = self._parse_tools_response(tools_message) model=model,
choice = Choice( choices=[Choice(
index=0, index=0,
message=ChatResponseMessage( message=ChatResponseMessage(
role="assistant", role="assistant",
tool_calls=tools, tool_calls=tools,
),
finish_reason="stop",
)
else:
choice = Choice(
index=0,
message=ChatResponseMessage(
role="assistant",
content=message, content=message,
), ),
finish_reason="stop", finish_reason="tool_calls" if tools else finish_reason,
) )],
response = ChatResponse(
id=message_id,
model=model,
choices=[choice],
usage=Usage( usage=Usage(
prompt_tokens=input_tokens, prompt_tokens=input_tokens,
completion_tokens=output_tokens, completion_tokens=output_tokens,
@@ -131,21 +151,26 @@ class BedrockModel(BaseChatModel, ABC):
logger.info("Proxy response :" + response.model_dump_json()) logger.info("Proxy response :" + response.model_dump_json())
return response return response
@staticmethod
def _create_response_stream( def _create_response_stream(
self, model: str, message_id: str, chunk_message: str, finish_reason: str | None model: str,
message_id: str,
chunk_message: str | None = None,
finish_reason: str | None = None,
tools: list[ToolCall] | None = None,
) -> ChatStreamResponse: ) -> ChatStreamResponse:
choice = ChoiceDelta(
index=0,
delta=ChatResponseMessage(
role="assistant",
content=chunk_message,
),
finish_reason=finish_reason,
)
response = ChatStreamResponse( response = ChatStreamResponse(
id=message_id, id=message_id,
model=model, model=model,
choices=[choice], choices=[ChoiceDelta(
index=0,
delta=ChatResponseMessage(
role="assistant",
tool_calls=tools,
content=chunk_message,
),
finish_reason="tool_calls" if tools else finish_reason,
)],
) )
if DEBUG: if DEBUG:
logger.info("Proxy response :" + response.model_dump_json()) logger.info("Proxy response :" + response.model_dump_json())
@@ -154,28 +179,194 @@ class BedrockModel(BaseChatModel, ABC):
class ClaudeModel(BedrockModel): class ClaudeModel(BedrockModel):
anthropic_version = "bedrock-2023-05-31" anthropic_version = "bedrock-2023-05-31"
# follow these instructions for tool uses:
tool_prompt = """You have access to the following tools:
{tools}
def _parse_tools_response(self, tools_messages: str) -> list[ToolCall]: Please think if you need to use a tool or not for user's question, you must:
"""Parse the tools response 1. Respond Y or N inside a <Tool></Tool> xml tag first to indicate that.
2. If a tool is needed, MUST respond a JSON object matching the following schema inside a <Func></Func> xml tag:
{{"name": $TOOL_NAME, "arguments": {{"$PARAMETER_NAME": "$PARAMETER_VALUE", ...}}}}
3. If no tools is needed, respond with normal text."""
Example tool message like: def _parse_args(self, chat_request: ChatRequest) -> dict:
\n{\n "name": "get_current_weather",\n "arguments": {\n "location": "Shanghai"... }\n}\n args = {
""" "anthropic_version": self.anthropic_version,
function = json.loads( "max_tokens": chat_request.max_tokens,
tools_messages.replace("\n", " ").encode("unicode_escape") "top_p": chat_request.top_p,
"temperature": chat_request.temperature,
}
system_prompt = ""
converted_messages = []
for message in chat_request.messages:
if message.role == "system":
system_prompt += message.content + "\n"
elif message.role == "user" and not isinstance(message.content, str):
converted_messages.append(
{
"role": message.role,
"content": self._parse_content_parts(message.content),
}
)
elif message.role == "assistant" and not message.content:
# if content is empty
# create the content using the tool call info.
tool_content = "[Tool use for `{}` with id `{}` with the following `input`]\n{}".format(
message.tool_calls[0].function.name,
message.tool_calls[0].id,
message.tool_calls[0].function.arguments,
)
converted_messages.append(
{"role": message.role, "content": tool_content}
)
elif message.role == "tool":
# Since bedrock does not support tool role
# Convert the tool message to a user message.
converted_messages.append(
{
"role": "user",
"content": "[Tool result with matching id `{}` of `{}`] ".format(
message.tool_call_id,
message.content),
}
)
else:
converted_messages.append(
{"role": message.role, "content": message.content}
) )
if chat_request.tools:
tools_str = json.dumps(
[tool.function.model_dump() for tool in chat_request.tools]
)
system_prompt += self.tool_prompt.format(tools=tools_str)
converted_messages.append({
'role': 'assistant',
'content': '<Tool>'
})
args["stop_sequences"] = ['</Func>']
args["messages"] = self.merge_message(converted_messages)
if system_prompt:
if DEBUG:
logger.info("System Prompt: " + system_prompt)
args["system"] = system_prompt
return args
def chat(self, chat_request: ChatRequest) -> ChatResponse:
    """Run one non-streaming chat completion against Bedrock.

    Invokes the model, then inspects the text for the tool-use sentinel
    the system prompt asks for: a leading "Y</Tool>" means a tool call
    follows (parsed out of the message), a leading "N</Tool>" is stripped
    to leave the plain answer.

    Args:
        chat_request: The incoming OpenAI-style chat request.

    Returns:
        ChatResponse built via _create_response.
    """
    if DEBUG:
        logger.info("Raw request: " + chat_request.model_dump_json())
    invoke_args = self._parse_args(chat_request)
    raw_response = self._invoke_model(
        args=invoke_args, model_id=chat_request.model
    )
    response_body = json.loads(raw_response.get("body").read())
    if DEBUG:
        logger.info("Bedrock response body: " + str(response_body))

    message = response_body["content"][0]["text"]
    tools = None
    if chat_request.tools:
        if message.startswith("Y</Tool>"):
            # Model chose to call a tool; the text carries the call info.
            tools = self._parse_tool_message(message)
            message = None
        elif message.startswith("N</Tool>"):
            # Drop the sentinel prefix (8 chars) and any newline padding.
            message = message[len("N</Tool>"):].lstrip("\n")

    return self._create_response(
        model=chat_request.model,
        message_id=response_body["id"],
        message=message,
        tools=tools,
        finish_reason=response_body["stop_reason"],
        input_tokens=response_body["usage"]["input_tokens"],
        output_tokens=response_body["usage"]["output_tokens"],
    )
def chat_stream(self, chat_request: ChatRequest) -> AsyncIterable[bytes]:
    """Stream a chat completion from Bedrock as SSE-ready byte chunks.

    Consumes the Bedrock event stream and yields one encoded
    ChatStreamResponse per content delta. When tools are enabled, the
    stream is filtered for the prompt's sentinel protocol: a first "Y"
    token switches to buffering the whole message for tool-call parsing
    (emitted once at the end), while an "N</Tool>" prefix is dropped
    before normal text is forwarded.

    Args:
        chat_request: The incoming OpenAI-style chat request.

    Yields:
        bytes: encoded ChatStreamResponse chunks.
    """
    response = self._invoke_model(
        args=self._parse_args(chat_request),
        model_id=chat_request.model,
        with_stream=True,
    )
    msg_id = ""
    chunk_id = 0
    tool_message = ""  # non-empty => buffering a tool-call message
    first_token = True  # first forwarded text token (after "N</Tool>")
    index = 0  # counts sentinel tokens skipped in the "N" case
    for event in response.get("body"):
        if DEBUG:
            logger.info("Bedrock response chunk: " + str(event))
        chunk = json.loads(event["chunk"]["bytes"])
        chunk_id += 1
        if chunk["type"] == "message_start":
            msg_id = chunk["message"]["id"]
            continue
        if chunk["type"] == "message_delta":
            # Final delta: no text, just the stop reason.
            chunk_message = ""
            finish_reason = chunk["delta"]["stop_reason"]
        elif chunk["type"] == "content_block_delta":
            chunk_message = chunk["delta"]["text"]
            finish_reason = None
            if chat_request.tools:
                # Check first token
                if not tool_message and chunk_message == "Y":
                    tool_message = "Y"
                    continue
                if tool_message:
                    # Buffer all chunk message
                    # in order to extract tool call info
                    tool_message += chunk_message
                    continue
                if index < 3:
                    # Ignore the N</Tool>, which is 3 tokens
                    # (assumes the tokenizer splits it as N / </ / Tool> —
                    # TODO(review): confirm against the model's tokenization)
                    index += 1
                    continue
                if first_token:
                    chunk_message = chunk_message.lstrip("\n")
                    first_token = False
        else:
            # Ignore other event types (e.g. content_block_start/stop).
            continue
        response = self._create_response_stream(
            model=chat_request.model,
            message_id=msg_id,
            chunk_message=chunk_message,
            finish_reason=finish_reason,
        )
        yield self._stream_response_to_bytes(response)
    if chat_request.tools and tool_message:
        # Entire message was buffered: emit a single tool_calls chunk.
        tools = self._parse_tool_message(tool_message)
        response = self._create_response_stream(
            model=chat_request.model,
            message_id=msg_id,
            tools=tools,
        )
        yield self._stream_response_to_bytes(response)
def _parse_tool_message(self, tool_message: str) -> list[ToolCall]:
if DEBUG:
logger.info("Tool message: " + tool_message.replace("\n", " "))
try:
tool_messages = tool_message[tool_message.rindex("<Func>") + 6:]
function = json.loads(tool_messages.replace("\n", " "))
args = json.dumps(function.get("arguments", {})) args = json.dumps(function.get("arguments", {}))
function = ResponseFunction( function = ResponseFunction(
name=function["name"], arguments=args.replace("\\\\n", "\\n") name=function["name"], arguments=args
) )
return [ return [
ToolCall( ToolCall(
id="0", # id="0",
function=function, function=function,
) )
] ]
except Exception as e:
logger.error("Failed to parse tool response")
raise HTTPException(status_code=500, detail="Failed to parse tool response")
def _get_base64_image(self, image_url: str) -> tuple[str, str]: def _get_base64_image(self, image_url: str) -> tuple[str, str]:
"""Try to get the base64 data from an image url. """Try to get the base64 data from an image url.
@@ -229,131 +420,6 @@ class ClaudeModel(BedrockModel):
) )
return content_parts return content_parts
def _create_tool_prompt(self, tools: list[Tool]) -> str:
    """Build the system-prompt section describing the available tools.

    Serializes the tool function specs to JSON and appends instructions
    telling the model to answer with a JSON object inside a
    <tool></tool> tag when it wants to call one.

    Args:
        tools: Tool definitions from the chat request.

    Returns:
        The tool instructions to append to the system prompt.
    """
    tool_prompt = "\nYou have access to the following tools:\n"
    tool_prompt += json.dumps(
        [tool.function.model_dump() for tool in tools], indent=2
    )
    tool_prompt += (
        "\nIf you need to use one of the above tools, "
        "only respond with a JSON object matching the following schema inside a <tool></tool> xml tag: \n"
        # Fixed: the example schema was missing its closing brace, showing
        # the model malformed JSON ({"name": ..., "arguments": {...}).
        '{"name": $TOOL_NAME, "arguments": {"$PARAMETER_NAME": "$PARAMETER_VALUE", ...}}\n'
    )
    return tool_prompt
def _parse_args(self, chat_request: ChatRequest) -> dict:
    """Translate an OpenAI-style chat request into Bedrock invoke args.

    System messages are folded into the Anthropic "system" field, tool
    and tool-result messages are rewritten as plain text (Bedrock has no
    tool role), and multimodal user content is converted to content parts.

    Args:
        chat_request: The incoming OpenAI-style chat request.

    Returns:
        dict of keyword arguments for the Bedrock invoke call.
    """
    args = {
        "anthropic_version": self.anthropic_version,
        "max_tokens": chat_request.max_tokens,
        "top_p": chat_request.top_p,
        "temperature": chat_request.temperature,
    }
    system_prompt = ""
    converted_messages = []
    for msg in chat_request.messages:
        role, content = msg.role, msg.content
        if role == "system":
            # Accumulated separately; Bedrock takes system text out-of-band.
            system_prompt += content + "\n"
            continue
        if role == "user" and not isinstance(content, str):
            # Multimodal content: convert each part (e.g. images).
            entry = {
                "role": role,
                "content": self._parse_content_parts(content),
            }
        elif role == "assistant" and not content:
            # Empty assistant content means a prior tool call:
            # reconstruct a textual description of it.
            call = msg.tool_calls[0]
            entry = {
                "role": role,
                "content": "Should use {} tool with args: {}".format(
                    call.function.name,
                    call.function.arguments,
                ),
            }
        elif role == "tool":
            # Bedrock has no tool role; surface the result as user text.
            entry = {
                "role": "user",
                "content": "The result of the tool call is " + content,
            }
        else:
            entry = {"role": role, "content": content}
        converted_messages.append(entry)
    if chat_request.tools:
        system_prompt += self._create_tool_prompt(chat_request.tools)
    args["messages"] = converted_messages
    if system_prompt:
        if DEBUG:
            logger.info("System Prompt: " + system_prompt)
        args["system"] = system_prompt.replace("\n", "")
    return args
def chat(self, chat_request: ChatRequest) -> ChatResponse:
    """Run one non-streaming chat completion against Bedrock.

    Invokes the model and, if the reply embeds a <tool>...</tool> span,
    extracts its inner text as the tool-call payload for _create_response.

    Args:
        chat_request: The incoming OpenAI-style chat request.

    Returns:
        ChatResponse built via _create_response.
    """
    if DEBUG:
        logger.info("Raw request: " + chat_request.model_dump_json())
    bedrock_response = self._invoke_model(
        args=self._parse_args(chat_request), model_id=chat_request.model
    )
    response_body = json.loads(bedrock_response.get("body").read())
    if DEBUG:
        logger.info("Bedrock response body: " + str(response_body))

    message = response_body["content"][0]["text"]
    open_at = message.find("<tool>")
    close_at = message.find("</tool>")
    # Only treat it as a tool call when both tags are present;
    # +6 skips past the opening "<tool>" tag.
    tools_message = (
        message[open_at + 6: close_at]
        if open_at != -1 and close_at != -1
        else None
    )
    return self._create_response(
        model=chat_request.model,
        message=response_body["content"][0]["text"],
        message_id=response_body["id"],
        tools_message=tools_message,
        input_tokens=response_body["usage"]["input_tokens"],
        output_tokens=response_body["usage"]["output_tokens"],
    )
def chat_stream(self, chat_request: ChatRequest) -> AsyncIterable[bytes]:
    """Stream a chat completion from Bedrock as SSE-ready byte chunks.

    Consumes the Bedrock event stream and yields one encoded
    ChatStreamResponse per content delta; the final message_delta event
    is forwarded with an empty message and a "stop" finish reason.

    Args:
        chat_request: The incoming OpenAI-style chat request.

    Yields:
        bytes: encoded ChatStreamResponse chunks.
    """
    response = self._invoke_model(
        args=self._parse_args(chat_request),
        model_id=chat_request.model,
        with_stream=True,
    )
    msg_id = ""
    chunk_id = 0
    for event in response.get("body"):
        if DEBUG:
            logger.info("Bedrock response chunk: " + str(event))
        chunk = json.loads(event["chunk"]["bytes"])
        chunk_id += 1
        if chunk["type"] == "message_start":
            # Capture the message id used for every subsequent chunk.
            msg_id = chunk["message"]["id"]
            continue
        if chunk["type"] == "message_delta":
            # Final delta: no text, signal completion.
            chunk_message = ""
            finish_reason = "stop"
        elif chunk["type"] == "content_block_delta":
            chunk_message = chunk["delta"]["text"]
            finish_reason = None
        else:
            # Ignore other event types (e.g. content_block_start/stop).
            continue
        response = self._create_response_stream(
            model=chat_request.model,
            message_id=msg_id,
            chunk_message=chunk_message,
            finish_reason=finish_reason,
        )
        yield self._stream_response_to_bytes(response)
class Llama2Model(BedrockModel): class Llama2Model(BedrockModel):

View File

@@ -1,4 +1,5 @@
import time import time
import uuid
from typing import Literal, Iterable from typing import Literal, Iterable
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
@@ -22,7 +23,7 @@ class ResponseFunction(BaseModel):
class ToolCall(BaseModel): class ToolCall(BaseModel):
id: str id: str = Field(default_factory=lambda: str(uuid.uuid4())[:8])
type: Literal["function"] = "function" type: Literal["function"] = "function"
function: ResponseFunction function: ResponseFunction