Support multiple tool calls

This commit is contained in:
Aiden Dai
2024-06-11 16:58:26 +08:00
parent 56786f9e32
commit b3509ee0f0

View File

@@ -398,10 +398,11 @@ class BedrockModel(BaseChatModel):
) )
if finish_reason == "tool_use": if finish_reason == "tool_use":
# https://docs.aws.amazon.com/bedrock/latest/userguide/tool-use.html#tool-use-examples # https://docs.aws.amazon.com/bedrock/latest/userguide/tool-use.html#tool-use-examples
tool_calls = []
for part in content: for part in content:
if "toolUse" in part: if "toolUse" in part:
tool = part["toolUse"] tool = part["toolUse"]
message.tool_calls = [ tool_calls.append(
ToolCall( ToolCall(
id=tool["toolUseId"], id=tool["toolUseId"],
type="function", type="function",
@@ -410,8 +411,9 @@ class BedrockModel(BaseChatModel):
arguments=json.dumps(tool["input"]), arguments=json.dumps(tool["input"]),
), ),
) )
] )
message.content = None message.tool_calls = tool_calls
message.content = None
else: else:
message.content = content[0]["text"] message.content = content[0]["text"]
@@ -450,7 +452,6 @@ class BedrockModel(BaseChatModel):
finish_reason = None finish_reason = None
message = None message = None
usage = None usage = None
if "messageStart" in chunk: if "messageStart" in chunk:
message = ChatResponseMessage( message = ChatResponseMessage(
role=chunk["messageStart"]["role"], role=chunk["messageStart"]["role"],
@@ -460,10 +461,12 @@ class BedrockModel(BaseChatModel):
# tool call start # tool call start
delta = chunk["contentBlockStart"]["start"] delta = chunk["contentBlockStart"]["start"]
if "toolUse" in delta: if "toolUse" in delta:
# first index is content
index = chunk["contentBlockStart"]["contentBlockIndex"] - 1
message = ChatResponseMessage( message = ChatResponseMessage(
tool_calls=[ tool_calls=[
ToolCall( ToolCall(
index=0, index=index,
type="function", type="function",
id=delta["toolUse"]["toolUseId"], id=delta["toolUse"]["toolUseId"],
function=ResponseFunction( function=ResponseFunction(
@@ -482,10 +485,11 @@ class BedrockModel(BaseChatModel):
) )
else: else:
# tool use # tool use
index = chunk["contentBlockDelta"]["contentBlockIndex"] - 1
message = ChatResponseMessage( message = ChatResponseMessage(
tool_calls=[ tool_calls=[
ToolCall( ToolCall(
index=0, index=index,
function=ResponseFunction( function=ResponseFunction(
arguments=delta["toolUse"]["input"], arguments=delta["toolUse"]["input"],
) )