fix: properly handle tool_use messages in conversation

This commit is contained in:
Mengxin Zhu
2025-06-30 00:14:07 +08:00
parent 01836087b1
commit 76a3614f17
5 changed files with 103 additions and 8 deletions

3
.gitignore vendored
View File

@@ -159,4 +159,5 @@ cython_debug/
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
.idea/
Config
Config
.vscode/launch.json

View File

@@ -45,6 +45,16 @@ async def health():
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request, exc):
    """Handle request-validation failures: log a one-line summary and
    return the full validation error to the client as plain text (400)."""
    log = logging.getLogger(__name__)
    # Only the first line of the error goes to the log — keeps sensitive
    # payload data out of log files and avoids formatting overhead.
    summary = str(exc).partition('\n')[0]
    log.warning(
        "Request validation failed: %s %s - %s",
        request.method,
        request.url.path,
        summary,
    )
    return PlainTextResponse(str(exc), status_code=400)

View File

@@ -1,3 +1,4 @@
import logging
import time
import uuid
from abc import ABC, abstractmethod
@@ -14,6 +15,8 @@ from api.schema import (
Error,
)
logger = logging.getLogger(__name__)
class BaseChatModel(ABC):
"""Represent a basic chat model
@@ -46,6 +49,7 @@ class BaseChatModel(ABC):
@staticmethod
def stream_response_to_bytes(response: ChatStreamResponse | Error | None = None) -> bytes:
if isinstance(response, Error):
logger.error("Stream error: %s", response.error.message if response.error else "Unknown error")
data = response.model_dump_json()
elif isinstance(response, ChatStreamResponse):
# to populate other fields when using exclude_unset=True

View File

@@ -34,6 +34,7 @@ from api.schema import (
ResponseFunction,
TextContent,
ToolCall,
ToolContent,
ToolMessage,
Usage,
UserMessage,
@@ -48,7 +49,15 @@ from api.setting import (
logger = logging.getLogger(__name__)
# Bedrock runtime client configuration.
# NOTE(review): read_timeout is deliberately long (15 min) to accommodate
# long-running streaming responses; adaptive retry mode backs off under
# throttling. The earlier short-timeout Config assignment was stale and
# immediately overwritten, so only the final configuration is kept.
config = Config(
    connect_timeout=60,       # Connection timeout: 60 seconds
    read_timeout=900,         # Read timeout: 15 minutes (suitable for long streaming responses)
    retries={
        'max_attempts': 8,    # Maximum retry attempts
        'mode': 'adaptive',   # Adaptive retry mode
    },
    max_pool_connections=50,  # Maximum connection pool size
)
bedrock_runtime = boto3.client(
service_name="bedrock-runtime",
@@ -177,6 +186,7 @@ class BedrockModel(BaseChatModel):
# check if model is supported
if chat_request.model not in bedrock_model_list.keys():
error = f"Unsupported model {chat_request.model}, please use models API to get a list of supported models"
logger.error("Unsupported model: %s", chat_request.model)
if error:
raise HTTPException(
@@ -204,13 +214,13 @@ class BedrockModel(BaseChatModel):
# Run the blocking boto3 call in a thread pool
response = await run_in_threadpool(bedrock_runtime.converse, **args)
except bedrock_runtime.exceptions.ValidationException as e:
logger.error("Validation Error: " + str(e))
logger.error("Bedrock validation error for model %s: %s", chat_request.model, str(e))
raise HTTPException(status_code=400, detail=str(e))
except bedrock_runtime.exceptions.ThrottlingException as e:
logger.error("Throttling Error: " + str(e))
logger.warning("Bedrock throttling for model %s: %s", chat_request.model, str(e))
raise HTTPException(status_code=429, detail=str(e))
except Exception as e:
logger.error(e)
logger.error("Bedrock invocation failed for model %s: %s", chat_request.model, str(e))
raise HTTPException(status_code=500, detail=str(e))
return response
@@ -270,6 +280,7 @@ class BedrockModel(BaseChatModel):
# return an [DONE] message at the end.
yield self.stream_response_to_bytes()
except Exception as e:
logger.error("Stream error for model %s: %s", chat_request.model, str(e))
error_event = Error(error=ErrorMessage(message=str(e)))
yield self.stream_response_to_bytes(error_event)
@@ -317,7 +328,16 @@ class BedrockModel(BaseChatModel):
}
)
elif isinstance(message, AssistantMessage):
if message.content.strip():
# Check if message has content that's not empty
has_content = False
if isinstance(message.content, str):
has_content = message.content.strip() != ""
elif isinstance(message.content, list):
has_content = len(message.content) > 0
elif message.content is not None:
has_content = True
if has_content:
# Text message
messages.append(
{
@@ -349,6 +369,10 @@ class BedrockModel(BaseChatModel):
# Bedrock does not support tool role,
# Add toolResult to content
# https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ToolResultBlock.html
# Handle different content formats from OpenAI SDK
tool_content = self._extract_tool_content(message.content)
messages.append(
{
"role": "user",
@@ -356,7 +380,7 @@ class BedrockModel(BaseChatModel):
{
"toolResult": {
"toolUseId": message.tool_call_id,
"content": [{"text": message.content}],
"content": [{"text": tool_content}],
}
}
],
@@ -368,6 +392,57 @@ class BedrockModel(BaseChatModel):
continue
return self._reframe_multi_payloard(messages)
def _extract_tool_content(self, content) -> str:
"""Extract text content from various OpenAI SDK tool message formats.
Handles:
- String content (legacy format)
- List of content objects (OpenAI SDK 1.91.0+)
- Nested JSON structures within text content
"""
try:
if isinstance(content, str):
return content
if isinstance(content, list):
text_parts = []
for i, item in enumerate(content):
if isinstance(item, dict):
# Handle dict with 'text' field
if "text" in item:
item_text = item["text"]
if isinstance(item_text, str):
# Try to parse as JSON if it looks like JSON
if item_text.strip().startswith('{') and item_text.strip().endswith('}'):
try:
parsed_json = json.loads(item_text)
# Convert JSON object to readable text
text_parts.append(json.dumps(parsed_json, indent=2))
except json.JSONDecodeError:
# Silently fallback to original text
text_parts.append(item_text)
else:
text_parts.append(item_text)
else:
text_parts.append(str(item_text))
else:
# Handle other dict formats - convert to JSON string
text_parts.append(json.dumps(item, indent=2))
elif hasattr(item, 'text'):
# Handle ToolContent objects
text_parts.append(item.text)
else:
# Convert any other type to string
text_parts.append(str(item))
return "\n".join(text_parts)
# Fallback for any other type
return str(content)
except Exception as e:
logger.warning("Tool content extraction failed: %s", str(e))
# Return a safe fallback
return str(content) if content is not None else ""
def _reframe_multi_payloard(self, messages: list) -> list:
"""Receive messages and reformat them to comply with the Claude format

View File

@@ -45,6 +45,11 @@ class ImageContent(BaseModel):
image_url: ImageUrl
class ToolContent(BaseModel):
    """A single text part of a tool message's content list (OpenAI-style
    content parts; only the "text" variant is modeled here)."""
    # Discriminator — this schema only supports text parts.
    type: Literal["text"] = "text"
    # The textual payload of the tool result part.
    text: str
class SystemMessage(BaseModel):
name: str | None = None
role: Literal["system"] = "system"
@@ -66,7 +71,7 @@ class AssistantMessage(BaseModel):
class ToolMessage(BaseModel):
    """A tool-result message in the conversation.

    ``content`` accepts a plain string (legacy format), a list of typed
    ``ToolContent`` parts, or a list of raw dicts, matching the shapes
    emitted by different OpenAI SDK versions (1.91.0+ sends lists).
    The stale string-only ``content: str`` declaration from the diff is
    dropped — keeping both would leave a duplicate, overridden field.
    """
    role: Literal["tool"] = "tool"
    # String, typed parts, or raw dicts from the SDK.
    content: str | list[ToolContent] | list[dict]
    # ID of the tool call this result responds to.
    tool_call_id: str