Refine tool call
@@ -24,7 +24,6 @@ from api.schema import (
     TextContent,
     ResponseFunction,
     ToolCall,
-    Tool,
     # Embeddings
     EmbeddingsRequest,
     EmbeddingsResponse,
@@ -88,39 +87,60 @@ class BedrockModel(BaseChatModel, ABC):
             contentType=self.content_type,
         )
 
+    @staticmethod
+    def merge_message(messages: list[dict]) -> list[dict]:
+        """Merge request messages that share the same role as the previous message
+        """
+        merged_messages = []
+        prev_role = None
+        merged_content = ""
+
+        for message in messages:
+            role = message["role"]
+            content = message["content"]
+
+            if role != prev_role or isinstance(content, list):
+                if prev_role:
+                    merged_messages.append({"role": prev_role, "content": merged_content})
+                if isinstance(content, str):
+                    merged_content = content
+                    prev_role = role
+                else:
+                    merged_messages.append({"role": role, "content": content})
+                    prev_role = None
+                    merged_content = ""
+            else:
+                if content == merged_content:
+                    # ignore duplicates
+                    continue
+                merged_content += "\n" + content
+
+        if merged_content:
+            merged_messages.append({"role": prev_role, "content": merged_content})
+        return merged_messages
+
+    @staticmethod
     def _create_response(
-        self,
         model: str,
-        message: str,
         message_id: str,
-        tools_message: str | None = None,
+        message: str | None = None,
+        finish_reason: str | None = None,
+        tools: list[ToolCall] | None = None,
         input_tokens: int = 0,
         output_tokens: int = 0,
     ) -> ChatResponse:
-        if tools_message:
-            # For tool response, the content is empty
-            tools = self._parse_tools_response(tools_message)
-            choice = Choice(
+        response = ChatResponse(
+            id=message_id,
+            model=model,
+            choices=[Choice(
                 index=0,
                 message=ChatResponseMessage(
                     role="assistant",
                     tool_calls=tools,
-                ),
-                finish_reason="stop",
-            )
-        else:
-            choice = Choice(
-                index=0,
-                message=ChatResponseMessage(
-                    role="assistant",
                     content=message,
                 ),
-                finish_reason="stop",
-            )
-        response = ChatResponse(
-            id=message_id,
-            model=model,
-            choices=[choice],
+                finish_reason="tool_calls" if tools else finish_reason,
+            )],
             usage=Usage(
                 prompt_tokens=input_tokens,
                 completion_tokens=output_tokens,
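
For illustration, a minimal sketch of what the new merge_message helper does. The sample input is invented; the expected output follows from the logic above, assuming the BedrockModel class shown here is importable:

    messages = [
        {"role": "user", "content": "Hello"},
        {"role": "user", "content": "Hello"},            # exact duplicate: dropped
        {"role": "user", "content": "What is 2 + 2?"},   # same role: concatenated
        {"role": "assistant", "content": "4"},
    ]
    merged = BedrockModel.merge_message(messages)
    # merged == [
    #     {"role": "user", "content": "Hello\nWhat is 2 + 2?"},
    #     {"role": "assistant", "content": "4"},
    # ]
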
@@ -131,21 +151,26 @@ class BedrockModel(BaseChatModel, ABC):
             logger.info("Proxy response :" + response.model_dump_json())
         return response
 
+    @staticmethod
     def _create_response_stream(
-        self, model: str, message_id: str, chunk_message: str, finish_reason: str | None
+        model: str,
+        message_id: str,
+        chunk_message: str | None = None,
+        finish_reason: str | None = None,
+        tools: list[ToolCall] | None = None,
     ) -> ChatStreamResponse:
-        choice = ChoiceDelta(
-            index=0,
-            delta=ChatResponseMessage(
-                role="assistant",
-                content=chunk_message,
-            ),
-            finish_reason=finish_reason,
-        )
         response = ChatStreamResponse(
             id=message_id,
             model=model,
-            choices=[choice],
+            choices=[ChoiceDelta(
+                index=0,
+                delta=ChatResponseMessage(
+                    role="assistant",
+                    tool_calls=tools,
+                    content=chunk_message,
+                ),
+                finish_reason="tool_calls" if tools else finish_reason,
+            )],
         )
         if DEBUG:
             logger.info("Proxy response :" + response.model_dump_json())
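
Aside: a sketch of the OpenAI-style stream chunk this helper produces once buffered tool calls are attached. The id and model values below are made up, and the exact wire format depends on how ChatStreamResponse serializes:

    chunk = {
        "id": "msg_0123456789",
        "model": "anthropic.claude-v2",
        "choices": [
            {
                "index": 0,
                "delta": {
                    "role": "assistant",
                    "tool_calls": [{
                        "id": "1f3a9c2b",  # auto-generated 8-char id (see the schema change below)
                        "type": "function",
                        "function": {
                            "name": "get_current_weather",
                            "arguments": "{\"location\": \"Shanghai\"}",
                        },
                    }],
                },
                "finish_reason": "tool_calls",  # flips whenever tools are attached
            }
        ],
    }
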
@@ -154,27 +179,193 @@ class BedrockModel(BaseChatModel, ABC):
 
 class ClaudeModel(BedrockModel):
     anthropic_version = "bedrock-2023-05-31"
+    # Follow these instructions for tool use:
+    tool_prompt = """You have access to the following tools:
+{tools}
+
+Please think about whether you need to use a tool for the user's question. You must:
+1. Respond Y or N inside a <Tool></Tool> xml tag first to indicate that.
+2. If a tool is needed, you MUST respond with a JSON object matching the following schema inside a <Func></Func> xml tag:
+{{"name": $TOOL_NAME, "arguments": {{"$PARAMETER_NAME": "$PARAMETER_VALUE", ...}}}}
+3. If no tool is needed, respond with normal text."""
 
-    def _parse_tools_response(self, tools_messages: str) -> list[ToolCall]:
-        """Parse the tools response
-
-        Example tool message like:
-        \n{\n "name": "get_current_weather",\n "arguments": {\n "location": "Shanghai"... }\n}\n
-        """
-        function = json.loads(
-            tools_messages.replace("\n", " ").encode("unicode_escape")
-        )
-        args = json.dumps(function.get("arguments", {}))
-        function = ResponseFunction(
-            name=function["name"], arguments=args.replace("\\\\n", "\\n")
-        )
-        return [
-            ToolCall(
-                id="0",
-                function=function,
-            )
-        ]
+    def _parse_args(self, chat_request: ChatRequest) -> dict:
+        args = {
+            "anthropic_version": self.anthropic_version,
+            "max_tokens": chat_request.max_tokens,
+            "top_p": chat_request.top_p,
+            "temperature": chat_request.temperature,
+        }
+        system_prompt = ""
+        converted_messages = []
+        for message in chat_request.messages:
+            if message.role == "system":
+                system_prompt += message.content + "\n"
+            elif message.role == "user" and not isinstance(message.content, str):
+                converted_messages.append(
+                    {
+                        "role": message.role,
+                        "content": self._parse_content_parts(message.content),
+                    }
+                )
+            elif message.role == "assistant" and not message.content:
+                # If the content is empty,
+                # create the content from the tool call info.
+                tool_content = "[Tool use for `{}` with id `{}` with the following `input`]\n{}".format(
+                    message.tool_calls[0].function.name,
+                    message.tool_calls[0].id,
+                    message.tool_calls[0].function.arguments,
+                )
+                converted_messages.append(
+                    {"role": message.role, "content": tool_content}
+                )
+            elif message.role == "tool":
+                # Since Bedrock does not support the tool role,
+                # convert the tool message to a user message.
+                converted_messages.append(
+                    {
+                        "role": "user",
+                        "content": "[Tool result with matching id `{}` of `{}`] ".format(
+                            message.tool_call_id,
+                            message.content),
+                    }
+                )
+            else:
+                converted_messages.append(
+                    {"role": message.role, "content": message.content}
+                )
+
+        if chat_request.tools:
+            tools_str = json.dumps(
+                [tool.function.model_dump() for tool in chat_request.tools]
+            )
+            system_prompt += self.tool_prompt.format(tools=tools_str)
+            converted_messages.append({
+                'role': 'assistant',
+                'content': '<Tool>'
+            })
+            args["stop_sequences"] = ['</Func>']
+        args["messages"] = self.merge_message(converted_messages)
+        if system_prompt:
+            if DEBUG:
+                logger.info("System Prompt: " + system_prompt)
+            args["system"] = system_prompt
+
+        return args
+
+    def chat(self, chat_request: ChatRequest) -> ChatResponse:
+        if DEBUG:
+            logger.info("Raw request: " + chat_request.model_dump_json())
+        response = self._invoke_model(
+            args=self._parse_args(chat_request), model_id=chat_request.model
+        )
+        response_body = json.loads(response.get("body").read())
+        if DEBUG:
+            logger.info("Bedrock response body: " + str(response_body))
+        message = response_body["content"][0]["text"]
+
+        tools = None
+        if chat_request.tools:
+            if message.startswith("Y</Tool>"):
+                tools = self._parse_tool_message(message)
+                message = None
+            elif message.startswith("N</Tool>"):
+                message = message[8:].lstrip("\n")
+        return self._create_response(
+            model=chat_request.model,
+            message_id=response_body["id"],
+            message=message,
+            tools=tools,
+            finish_reason=response_body["stop_reason"],
+            input_tokens=response_body["usage"]["input_tokens"],
+            output_tokens=response_body["usage"]["output_tokens"],
+        )
+
+    def chat_stream(self, chat_request: ChatRequest) -> AsyncIterable[bytes]:
+        response = self._invoke_model(
+            args=self._parse_args(chat_request),
+            model_id=chat_request.model,
+            with_stream=True,
+        )
+        msg_id = ""
+        chunk_id = 0
+        tool_message = ""
+        first_token = True
+        index = 0
+        for event in response.get("body"):
+            if DEBUG:
+                logger.info("Bedrock response chunk: " + str(event))
+            chunk = json.loads(event["chunk"]["bytes"])
+            chunk_id += 1
+
+            if chunk["type"] == "message_start":
+                msg_id = chunk["message"]["id"]
+                continue
+
+            if chunk["type"] == "message_delta":
+                chunk_message = ""
+                finish_reason = chunk["delta"]["stop_reason"]
+
+            elif chunk["type"] == "content_block_delta":
+                chunk_message = chunk["delta"]["text"]
+                finish_reason = None
+                if chat_request.tools:
+                    # Check the first token
+                    if not tool_message and chunk_message == "Y":
+                        tool_message = "Y"
+                        continue
+                    if tool_message:
+                        # Buffer all chunk messages
+                        # in order to extract the tool call info
+                        tool_message += chunk_message
+                        continue
+                    if index < 3:
+                        # Ignore the N</Tool>, which is 3 tokens
+                        index += 1
+                        continue
+                    if first_token:
+                        chunk_message = chunk_message.lstrip("\n")
+                        first_token = False
+            else:
+                continue
+            response = self._create_response_stream(
+                model=chat_request.model,
+                message_id=msg_id,
+                chunk_message=chunk_message,
+                finish_reason=finish_reason,
+            )
+
+            yield self._stream_response_to_bytes(response)
+        if chat_request.tools and tool_message:
+            tools = self._parse_tool_message(tool_message)
+            response = self._create_response_stream(
+                model=chat_request.model,
+                message_id=msg_id,
+                tools=tools,
+            )
+            yield self._stream_response_to_bytes(response)
+
+    def _parse_tool_message(self, tool_message: str) -> list[ToolCall]:
+        if DEBUG:
+            logger.info("Tool message: " + tool_message.replace("\n", " "))
+        try:
+            tool_messages = tool_message[tool_message.rindex("<Func>") + 6:]
+            function = json.loads(tool_messages.replace("\n", " "))
+            args = json.dumps(function.get("arguments", {}))
+            function = ResponseFunction(
+                name=function["name"], arguments=args
+            )
+
+            return [
+                ToolCall(
+                    # id="0",
+                    function=function,
+                )
+            ]
+
+        except Exception as e:
+            logger.error("Failed to parse tool response: " + str(e))
+            raise HTTPException(status_code=500, detail="Failed to parse tool response")
 
     def _get_base64_image(self, image_url: str) -> tuple[str, str]:
         """Try to get the base64 data from an image url.
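
To make the new <Tool>/<Func> protocol concrete, a small self-contained sketch of a tool-use completion and the extraction _parse_tool_message performs. get_current_weather is just the example function from the old docstring; generation stops at the "</Func>" stop sequence, so the closing tag never appears in the buffered text:

    import json

    completion = 'Y</Tool>\n<Func>{"name": "get_current_weather", "arguments": {"location": "Shanghai"}}'

    # Take everything after the last "<Func>" (6 characters), then parse the JSON.
    payload = completion[completion.rindex("<Func>") + 6:]
    function = json.loads(payload.replace("\n", " "))
    assert function["name"] == "get_current_weather"
    assert function["arguments"] == {"location": "Shanghai"}
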
@@ -229,131 +420,6 @@ class ClaudeModel(BedrockModel):
             )
         return content_parts
 
-    def _create_tool_prompt(self, tools: list[Tool]) -> str:
-        tool_prompt = "\nYou have access to the following tools:\n"
-        tool_prompt += json.dumps(
-            [tool.function.model_dump() for tool in tools], indent=2
-        )
-        tool_prompt += (
-            "\nIf you need to use one of the above tools, "
-            "only respond with a JSON object matching the following schema inside a <tool></tool> xml tag: \n"
-            '{"name": $TOOL_NAME, "arguments": {"$PARAMETER_NAME": "$PARAMETER_VALUE", ...}\n'
-        )
-        return tool_prompt
-
-    def _parse_args(self, chat_request: ChatRequest) -> dict:
-        args = {
-            "anthropic_version": self.anthropic_version,
-            "max_tokens": chat_request.max_tokens,
-            "top_p": chat_request.top_p,
-            "temperature": chat_request.temperature,
-        }
-        system_prompt = ""
-        converted_messages = []
-        for message in chat_request.messages:
-            if message.role == "system":
-                system_prompt += message.content + "\n"
-            elif message.role == "user" and not isinstance(message.content, str):
-                converted_messages.append(
-                    {
-                        "role": message.role,
-                        "content": self._parse_content_parts(message.content),
-                    }
-                )
-            elif message.role == "assistant" and not message.content:
-                # if content is empty
-                # create the content using the tool call info.
-                tool_content = "Should use {} tool with args: {}".format(
-                    message.tool_calls[0].function.name,
-                    message.tool_calls[0].function.arguments,
-                )
-                converted_messages.append(
-                    {"role": message.role, "content": tool_content}
-                )
-            elif message.role == "tool":
-                # Since bedrock does not support tool role
-                # Convert the tool message to a user message.
-                converted_messages.append(
-                    {
-                        "role": "user",
-                        "content": "The result of the tool call is " + message.content,
-                    }
-                )
-            else:
-                converted_messages.append(
-                    {"role": message.role, "content": message.content}
-                )
-
-        if chat_request.tools:
-            system_prompt += self._create_tool_prompt(chat_request.tools)
-
-        args["messages"] = converted_messages
-        if system_prompt:
-            if DEBUG:
-                logger.info("System Prompt: " + system_prompt)
-            args["system"] = system_prompt.replace("\n", "")
-        return args
-
-    def chat(self, chat_request: ChatRequest) -> ChatResponse:
-        if DEBUG:
-            logger.info("Raw request: " + chat_request.model_dump_json())
-        response = self._invoke_model(
-            args=self._parse_args(chat_request), model_id=chat_request.model
-        )
-        response_body = json.loads(response.get("body").read())
-        if DEBUG:
-            logger.info("Bedrock response body: " + str(response_body))
-        message = response_body["content"][0]["text"]
-
-        tools_message = None
-        start = message.find("<tool>")
-        end = message.find("</tool>")
-        if start != -1 and end != -1:
-            tools_message = message[start + 6: end]
-        return self._create_response(
-            model=chat_request.model,
-            message=response_body["content"][0]["text"],
-            message_id=response_body["id"],
-            tools_message=tools_message,
-            input_tokens=response_body["usage"]["input_tokens"],
-            output_tokens=response_body["usage"]["output_tokens"],
-        )
-
-    def chat_stream(self, chat_request: ChatRequest) -> AsyncIterable[bytes]:
-        response = self._invoke_model(
-            args=self._parse_args(chat_request),
-            model_id=chat_request.model,
-            with_stream=True,
-        )
-        msg_id = ""
-        chunk_id = 0
-        for event in response.get("body"):
-            if DEBUG:
-                logger.info("Bedrock response chunk: " + str(event))
-            chunk = json.loads(event["chunk"]["bytes"])
-            chunk_id += 1
-            if chunk["type"] == "message_start":
-                msg_id = chunk["message"]["id"]
-                continue
-
-            if chunk["type"] == "message_delta":
-                chunk_message = ""
-                finish_reason = "stop"
-
-            elif chunk["type"] == "content_block_delta":
-                chunk_message = chunk["delta"]["text"]
-                finish_reason = None
-            else:
-                continue
-            response = self._create_response_stream(
-                model=chat_request.model,
-                message_id=msg_id,
-                chunk_message=chunk_message,
-                finish_reason=finish_reason,
-            )
-
-            yield self._stream_response_to_bytes(response)
-
 
 class Llama2Model(BedrockModel):
 

@@ -1,4 +1,5 @@
 import time
+import uuid
 from typing import Literal, Iterable
 
 from pydantic import BaseModel, Field
@@ -22,7 +23,7 @@ class ResponseFunction(BaseModel):
 
 
 class ToolCall(BaseModel):
-    id: str
+    id: str = Field(default_factory=lambda: str(uuid.uuid4())[:8])
     type: Literal["function"] = "function"
     function: ResponseFunction
 
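
A quick sketch of the new default: constructing a ToolCall without an explicit id now yields a random 8-character one. The ResponseFunction fields below are assumed from how it is used above (a name plus JSON-string arguments):

    import uuid
    from typing import Literal

    from pydantic import BaseModel, Field

    class ResponseFunction(BaseModel):  # assumed shape, mirroring its use above
        name: str
        arguments: str

    class ToolCall(BaseModel):
        id: str = Field(default_factory=lambda: str(uuid.uuid4())[:8])
        type: Literal["function"] = "function"
        function: ResponseFunction

    call = ToolCall(function=ResponseFunction(
        name="get_current_weather",
        arguments='{"location": "Shanghai"}',
    ))
    print(call.id)  # e.g. "1f3a9c2b": generated because no id was supplied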