diff --git a/src/api/models/bedrock.py b/src/api/models/bedrock.py index 63da5c9..fa8e319 100644 --- a/src/api/models/bedrock.py +++ b/src/api/models/bedrock.py @@ -381,6 +381,20 @@ class BedrockModel(BaseChatModel): self._convert_tool_spec(t.function) for t in chat_request.tools ] } + + if chat_request.tool_choice: + if isinstance(chat_request.tool_choice, str): + # auto (default) is mapped to {"auto" : {}} + # required is mapped to {"any" : {}} + if chat_request.tool_choice == "required": + args["toolConfig"]["toolChoice"] = {"any": {}} + else: + args["toolConfig"]["toolChoice"] = {"auto": {}} + else: + # Specific tool to use + assert "function" in chat_request.tool_choice + args["toolConfig"]["toolChoice"] = { + "tool": {"name": chat_request.tool_choice["function"].get("name", "")}} return args def _create_response(