From 76a3614f1768e6f0ce161bdd7940dfcb6b16e9b0 Mon Sep 17 00:00:00 2001 From: Mengxin Zhu <843303+zxkane@users.noreply.github.com> Date: Mon, 30 Jun 2025 00:14:07 +0800 Subject: [PATCH] fix: properly handle tool_use messages in conversation --- .gitignore | 3 +- src/api/app.py | 10 +++++ src/api/models/base.py | 4 ++ src/api/models/bedrock.py | 87 ++++++++++++++++++++++++++++++++++++--- src/api/schema.py | 7 +++- 5 files changed, 103 insertions(+), 8 deletions(-) diff --git a/.gitignore b/.gitignore index d8b355e..23212b8 100644 --- a/.gitignore +++ b/.gitignore @@ -159,4 +159,5 @@ cython_debug/ # option (not recommended) you can uncomment the following to ignore the entire idea folder. .idea/ -Config \ No newline at end of file +Config +.vscode/launch.json diff --git a/src/api/app.py b/src/api/app.py index 49a0519..5ea7ae7 100644 --- a/src/api/app.py +++ b/src/api/app.py @@ -45,6 +45,16 @@ async def health(): @app.exception_handler(RequestValidationError) async def validation_exception_handler(request, exc): + logger = logging.getLogger(__name__) + + # Log essential info only - avoid sensitive data and performance overhead + logger.warning( + "Request validation failed: %s %s - %s", + request.method, + request.url.path, + str(exc).split('\n')[0] # First line only + ) + return PlainTextResponse(str(exc), status_code=400) diff --git a/src/api/models/base.py b/src/api/models/base.py index 6d45340..5e7a9cb 100644 --- a/src/api/models/base.py +++ b/src/api/models/base.py @@ -1,3 +1,4 @@ +import logging import time import uuid from abc import ABC, abstractmethod @@ -14,6 +15,8 @@ from api.schema import ( Error, ) +logger = logging.getLogger(__name__) + class BaseChatModel(ABC): """Represent a basic chat model @@ -46,6 +49,7 @@ class BaseChatModel(ABC): @staticmethod def stream_response_to_bytes(response: ChatStreamResponse | Error | None = None) -> bytes: if isinstance(response, Error): + logger.error("Stream error: %s", response.error.message if response.error else "Unknown error") data = response.model_dump_json() elif isinstance(response, ChatStreamResponse): # to populate other fields when using exclude_unset=True diff --git a/src/api/models/bedrock.py b/src/api/models/bedrock.py index d17b300..9a8fd3c 100644 --- a/src/api/models/bedrock.py +++ b/src/api/models/bedrock.py @@ -34,6 +34,7 @@ from api.schema import ( ResponseFunction, TextContent, ToolCall, + ToolContent, ToolMessage, Usage, UserMessage, @@ -48,7 +49,15 @@ from api.setting import ( logger = logging.getLogger(__name__) -config = Config(connect_timeout=60, read_timeout=120, retries={"max_attempts": 1}) +config = Config( + connect_timeout=60, # Connection timeout: 60 seconds + read_timeout=900, # Read timeout: 15 minutes (suitable for long streaming responses) + retries={ + 'max_attempts': 8, # Maximum retry attempts + 'mode': 'adaptive' # Adaptive retry mode + }, + max_pool_connections=50 # Maximum connection pool size + ) bedrock_runtime = boto3.client( service_name="bedrock-runtime", @@ -177,6 +186,7 @@ class BedrockModel(BaseChatModel): # check if model is supported if chat_request.model not in bedrock_model_list.keys(): error = f"Unsupported model {chat_request.model}, please use models API to get a list of supported models" + logger.error("Unsupported model: %s", chat_request.model) if error: raise HTTPException( @@ -204,13 +214,13 @@ class BedrockModel(BaseChatModel): # Run the blocking boto3 call in a thread pool response = await run_in_threadpool(bedrock_runtime.converse, **args) except bedrock_runtime.exceptions.ValidationException as e: - logger.error("Validation Error: " + str(e)) + logger.error("Bedrock validation error for model %s: %s", chat_request.model, str(e)) raise HTTPException(status_code=400, detail=str(e)) except bedrock_runtime.exceptions.ThrottlingException as e: - logger.error("Throttling Error: " + str(e)) + logger.warning("Bedrock throttling for model %s: %s", chat_request.model, str(e)) raise HTTPException(status_code=429, detail=str(e)) except Exception as e: - logger.error(e) + logger.error("Bedrock invocation failed for model %s: %s", chat_request.model, str(e)) raise HTTPException(status_code=500, detail=str(e)) return response @@ -270,6 +280,7 @@ class BedrockModel(BaseChatModel): # return an [DONE] message at the end. yield self.stream_response_to_bytes() except Exception as e: + logger.error("Stream error for model %s: %s", chat_request.model, str(e)) error_event = Error(error=ErrorMessage(message=str(e))) yield self.stream_response_to_bytes(error_event) @@ -317,7 +328,16 @@ class BedrockModel(BaseChatModel): } ) elif isinstance(message, AssistantMessage): - if message.content.strip(): + # Check if message has content that's not empty + has_content = False + if isinstance(message.content, str): + has_content = message.content.strip() != "" + elif isinstance(message.content, list): + has_content = len(message.content) > 0 + elif message.content is not None: + has_content = True + + if has_content: # Text message messages.append( { @@ -349,6 +369,10 @@ class BedrockModel(BaseChatModel): # Bedrock does not support tool role, # Add toolResult to content # https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ToolResultBlock.html + + # Handle different content formats from OpenAI SDK + tool_content = self._extract_tool_content(message.content) + messages.append( { "role": "user", @@ -356,7 +380,7 @@ class BedrockModel(BaseChatModel): { "toolResult": { "toolUseId": message.tool_call_id, - "content": [{"text": message.content}], + "content": [{"text": tool_content}], } } ], @@ -368,6 +392,57 @@ class BedrockModel(BaseChatModel): continue return self._reframe_multi_payloard(messages) + def _extract_tool_content(self, content) -> str: + """Extract text content from various OpenAI SDK tool message formats. + + Handles: + - String content (legacy format) + - List of content objects (OpenAI SDK 1.91.0+) + - Nested JSON structures within text content + """ + try: + if isinstance(content, str): + return content + + if isinstance(content, list): + text_parts = [] + for i, item in enumerate(content): + if isinstance(item, dict): + # Handle dict with 'text' field + if "text" in item: + item_text = item["text"] + if isinstance(item_text, str): + # Try to parse as JSON if it looks like JSON + if item_text.strip().startswith('{') and item_text.strip().endswith('}'): + try: + parsed_json = json.loads(item_text) + # Convert JSON object to readable text + text_parts.append(json.dumps(parsed_json, indent=2)) + except json.JSONDecodeError: + # Silently fallback to original text + text_parts.append(item_text) + else: + text_parts.append(item_text) + else: + text_parts.append(str(item_text)) + else: + # Handle other dict formats - convert to JSON string + text_parts.append(json.dumps(item, indent=2)) + elif hasattr(item, 'text'): + # Handle ToolContent objects + text_parts.append(item.text) + else: + # Convert any other type to string + text_parts.append(str(item)) + return "\n".join(text_parts) + + # Fallback for any other type + return str(content) + except Exception as e: + logger.warning("Tool content extraction failed: %s", str(e)) + # Return a safe fallback + return str(content) if content is not None else "" + def _reframe_multi_payloard(self, messages: list) -> list: """Receive messages and reformat them to comply with the Claude format diff --git a/src/api/schema.py b/src/api/schema.py index df80534..b6b8c15 100644 --- a/src/api/schema.py +++ b/src/api/schema.py @@ -45,6 +45,11 @@ class ImageContent(BaseModel): image_url: ImageUrl +class ToolContent(BaseModel): + type: Literal["text"] = "text" + text: str + + class SystemMessage(BaseModel): name: str | None = None role: Literal["system"] = "system" @@ -66,7 +71,7 @@ class AssistantMessage(BaseModel): class ToolMessage(BaseModel): role: Literal["tool"] = "tool" - content: str + content: str | list[ToolContent] | list[dict] tool_call_id: str