From b3509ee0f0233017783e53ddc43629683d4c61a5 Mon Sep 17 00:00:00 2001 From: Aiden Dai Date: Tue, 11 Jun 2024 16:58:26 +0800 Subject: [PATCH] Support multiple tool calls --- src/api/models/bedrock.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/src/api/models/bedrock.py b/src/api/models/bedrock.py index 6063ffb..63da5c9 100644 --- a/src/api/models/bedrock.py +++ b/src/api/models/bedrock.py @@ -398,10 +398,11 @@ class BedrockModel(BaseChatModel): ) if finish_reason == "tool_use": # https://docs.aws.amazon.com/bedrock/latest/userguide/tool-use.html#tool-use-examples + tool_calls = [] for part in content: if "toolUse" in part: tool = part["toolUse"] - message.tool_calls = [ + tool_calls.append( ToolCall( id=tool["toolUseId"], type="function", @@ -410,8 +411,9 @@ class BedrockModel(BaseChatModel): arguments=json.dumps(tool["input"]), ), ) - ] - message.content = None + ) + message.tool_calls = tool_calls + message.content = None else: message.content = content[0]["text"] @@ -450,7 +452,6 @@ class BedrockModel(BaseChatModel): finish_reason = None message = None usage = None - if "messageStart" in chunk: message = ChatResponseMessage( role=chunk["messageStart"]["role"], @@ -460,10 +461,12 @@ class BedrockModel(BaseChatModel): # tool call start delta = chunk["contentBlockStart"]["start"] if "toolUse" in delta: + # first index is content + index = chunk["contentBlockStart"]["contentBlockIndex"] - 1 message = ChatResponseMessage( tool_calls=[ ToolCall( - index=0, + index=index, type="function", id=delta["toolUse"]["toolUseId"], function=ResponseFunction( @@ -482,10 +485,11 @@ class BedrockModel(BaseChatModel): ) else: # tool use + index = chunk["contentBlockDelta"]["contentBlockIndex"] - 1 message = ChatResponseMessage( tool_calls=[ ToolCall( - index=0, + index=index, function=ResponseFunction( arguments=delta["toolUse"]["input"], )