Support of reasoning
This commit is contained in:
@@ -210,7 +210,6 @@ class BedrockModel(BaseChatModel):
|
|||||||
"""Default implementation for Chat Stream API"""
|
"""Default implementation for Chat Stream API"""
|
||||||
response = self._invoke_bedrock(chat_request, stream=True)
|
response = self._invoke_bedrock(chat_request, stream=True)
|
||||||
message_id = self.generate_message_id()
|
message_id = self.generate_message_id()
|
||||||
|
|
||||||
stream = response.get("stream")
|
stream = response.get("stream")
|
||||||
for chunk in stream:
|
for chunk in stream:
|
||||||
stream_response = self._create_response_stream(
|
stream_response = self._create_response_stream(
|
||||||
@@ -398,7 +397,6 @@ class BedrockModel(BaseChatModel):
|
|||||||
|
|
||||||
Ref: https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_Converse.html
|
Ref: https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_Converse.html
|
||||||
"""
|
"""
|
||||||
|
|
||||||
messages = self._parse_messages(chat_request)
|
messages = self._parse_messages(chat_request)
|
||||||
system_prompts = self._parse_system_prompts(chat_request)
|
system_prompts = self._parse_system_prompts(chat_request)
|
||||||
|
|
||||||
@@ -424,14 +422,17 @@ class BedrockModel(BaseChatModel):
|
|||||||
if chat_request.reasoning_effort:
|
if chat_request.reasoning_effort:
|
||||||
# From OpenAI api, the max_token is not supported in reasoning mode
|
# From OpenAI api, the max_token is not supported in reasoning mode
|
||||||
# Use max_completion_tokens if provided.
|
# Use max_completion_tokens if provided.
|
||||||
|
|
||||||
max_tokens = chat_request.max_completion_tokens if chat_request.max_completion_tokens else chat_request.max_tokens
|
max_tokens = chat_request.max_completion_tokens if chat_request.max_completion_tokens else chat_request.max_tokens
|
||||||
|
budget_tokens = self._calc_budget_tokens(max_tokens, chat_request.reasoning_effort)
|
||||||
inference_config["maxTokens"] = max_tokens
|
inference_config["maxTokens"] = max_tokens
|
||||||
# unset topP - Not supported
|
# unset topP - Not supported
|
||||||
inference_config.pop("topP")
|
inference_config.pop("topP")
|
||||||
|
|
||||||
args["additionalModelRequestFields"] = {
|
args["additionalModelRequestFields"] = {
|
||||||
"reasoning_config": {
|
"reasoning_config": {
|
||||||
"type": "enabled",
|
"type": "enabled",
|
||||||
"budget_tokens": max_tokens - 1
|
"budget_tokens": budget_tokens
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
# add tool config
|
# add tool config
|
||||||
@@ -493,7 +494,7 @@ class BedrockModel(BaseChatModel):
|
|||||||
for c in content:
|
for c in content:
|
||||||
if "reasoningContent" in c:
|
if "reasoningContent" in c:
|
||||||
message.reasoning_content = c["reasoningContent"]["reasoningText"].get("text", "")
|
message.reasoning_content = c["reasoningContent"]["reasoningText"].get("text", "")
|
||||||
if "text" in c:
|
elif "text" in c:
|
||||||
message.content = c["text"]
|
message.content = c["text"]
|
||||||
else:
|
else:
|
||||||
logger.warning("Unknown tag in message content " + ",".join(c.keys()))
|
logger.warning("Unknown tag in message content " + ",".join(c.keys()))
|
||||||
@@ -564,6 +565,12 @@ class BedrockModel(BaseChatModel):
|
|||||||
message = ChatResponseMessage(
|
message = ChatResponseMessage(
|
||||||
content=delta["text"],
|
content=delta["text"],
|
||||||
)
|
)
|
||||||
|
elif "reasoningContent" in delta:
|
||||||
|
# ignore "signature" in the delta.
|
||||||
|
if "text" in delta["reasoningContent"]:
|
||||||
|
message = ChatResponseMessage(
|
||||||
|
reasoning_content=delta["reasoningContent"]["text"],
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
# tool use
|
# tool use
|
||||||
index = chunk["contentBlockDelta"]["contentBlockIndex"] - 1
|
index = chunk["contentBlockDelta"]["contentBlockIndex"] - 1
|
||||||
@@ -701,6 +708,18 @@ class BedrockModel(BaseChatModel):
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def _calc_budget_tokens(self, max_tokens: int, reasoning_effort: Literal["low", "medium", "high"]) -> int:
|
||||||
|
# Helper function to calculate budget_tokens based on the max_tokens.
|
||||||
|
# Ratio for efforts: Low - 30%, medium - 60%, High: Max token - 1
|
||||||
|
# Note that The minimum budget_tokens is 1,024 tokens so far.
|
||||||
|
# But it may be changed for different models in the future.
|
||||||
|
if reasoning_effort == "low":
|
||||||
|
return int(max_tokens * 0.3)
|
||||||
|
elif reasoning_effort == "medium":
|
||||||
|
return int(max_tokens * 0.6)
|
||||||
|
else:
|
||||||
|
return max_tokens - 1
|
||||||
|
|
||||||
def _convert_finish_reason(self, finish_reason: str | None) -> str | None:
|
def _convert_finish_reason(self, finish_reason: str | None) -> str | None:
|
||||||
"""
|
"""
|
||||||
Below is a list of finish reason according to OpenAI doc:
|
Below is a list of finish reason according to OpenAI doc:
|
||||||
|
|||||||
Reference in New Issue
Block a user