Add reasoning support to the Bedrock chat model.

- schema.py: accept `max_completion_tokens` and `reasoning_effort` on ChatRequest,
  and expose `reasoning_content` on ChatResponseMessage.
- bedrock.py: when `reasoning_effort` is set, send a `reasoning_config` in
  `additionalModelRequestFields`, and copy reasoning text from the response
  content blocks into `message.reasoning_content`.
- requirements.txt: bump fastapi, boto3 and botocore.

diff --git a/src/api/models/bedrock.py b/src/api/models/bedrock.py
index 54c709c..3e77eec 100644
--- a/src/api/models/bedrock.py
+++ b/src/api/models/bedrock.py
@@ -420,6 +420,19 @@ class BedrockModel(BaseChatModel):
             "system": system_prompts,
             "inferenceConfig": inference_config,
         }
+        if chat_request.reasoning_effort:
+            # The OpenAI API does not support max_tokens in reasoning mode,
+            # so prefer max_completion_tokens when it is provided.
+            max_tokens = chat_request.max_completion_tokens or chat_request.max_tokens
+            inference_config["maxTokens"] = max_tokens
+            # topP is not supported with reasoning; drop it without failing if it is absent.
+            inference_config.pop("topP", None)
+            args["additionalModelRequestFields"] = {
+                "reasoning_config": {
+                    "type": "enabled",
+                    "budget_tokens": max_tokens - 1,
+                }
+            }
         # add tool config
         if chat_request.tools:
             args["toolConfig"] = {
@@ -476,8 +489,14 @@ class BedrockModel(BaseChatModel):
             message.content = None
         else:
             message.content = ""
-            if content:
-                message.content = content[0]["text"]
+            for c in content:
+                # One tag per block (Converse union type); redacted reasoning has no "reasoningText".
+                if "reasoningContent" in c:
+                    message.reasoning_content = c["reasoningContent"].get("reasoningText", {}).get("text", "")
+                elif "text" in c:
+                    message.content = c["text"]
+                else:
+                    logger.warning("Unknown tag in message content " + ",".join(c.keys()))
 
         response = ChatResponse(
             id=message_id,
diff --git a/src/api/schema.py b/src/api/schema.py
index b8ec75d..5d081cf 100644
--- a/src/api/schema.py
+++ b/src/api/schema.py
@@ -94,6 +94,8 @@ class ChatRequest(BaseModel):
     top_p: float | None = Field(default=1.0, le=1.0, ge=0.0)
     user: str | None = None  # Not used
     max_tokens: int | None = 2048
+    max_completion_tokens: int | None = None
+    reasoning_effort: Literal["low", "medium", "high"] | None = None
     n: int | None = 1  # Not used
     tools: list[Tool] | None = None
     tool_choice: str | object = "auto"
@@ -111,6 +113,7 @@ class ChatResponseMessage(BaseModel):
     role: Literal["assistant"] | None = None
     content: str | None = None
     tool_calls: list[ToolCall] | None = None
+    reasoning_content: str | None = None
 
 
 class BaseChoice(BaseModel):
diff --git a/src/requirements.txt b/src/requirements.txt
index 9d2b47c..7ad8fb1 100644
--- a/src/requirements.txt
+++ b/src/requirements.txt
@@ -1,9 +1,9 @@
-fastapi==0.115.6
+fastapi==0.115.8
 pydantic==2.7.1
 uvicorn==0.29.0
 mangum==0.17.0
 tiktoken==0.6.0
 requests==2.32.3
 numpy==1.26.4
-boto3==1.35.81
-botocore==1.35.81
\ No newline at end of file
+boto3==1.37.0
+botocore==1.37.0
\ No newline at end of file