fix: properly handle tool_use messages in conversation
This commit is contained in:
1
.gitignore
vendored
1
.gitignore
vendored
@@ -160,3 +160,4 @@ cython_debug/
|
|||||||
.idea/
|
.idea/
|
||||||
|
|
||||||
Config
|
Config
|
||||||
|
.vscode/launch.json
|
||||||
|
|||||||
@@ -45,6 +45,16 @@ async def health():
|
|||||||
|
|
||||||
@app.exception_handler(RequestValidationError)
|
@app.exception_handler(RequestValidationError)
|
||||||
async def validation_exception_handler(request, exc):
|
async def validation_exception_handler(request, exc):
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Log essential info only - avoid sensitive data and performance overhead
|
||||||
|
logger.warning(
|
||||||
|
"Request validation failed: %s %s - %s",
|
||||||
|
request.method,
|
||||||
|
request.url.path,
|
||||||
|
str(exc).split('\n')[0] # First line only
|
||||||
|
)
|
||||||
|
|
||||||
return PlainTextResponse(str(exc), status_code=400)
|
return PlainTextResponse(str(exc), status_code=400)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
|
import logging
|
||||||
import time
|
import time
|
||||||
import uuid
|
import uuid
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
@@ -14,6 +15,8 @@ from api.schema import (
|
|||||||
Error,
|
Error,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class BaseChatModel(ABC):
|
class BaseChatModel(ABC):
|
||||||
"""Represent a basic chat model
|
"""Represent a basic chat model
|
||||||
@@ -46,6 +49,7 @@ class BaseChatModel(ABC):
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def stream_response_to_bytes(response: ChatStreamResponse | Error | None = None) -> bytes:
|
def stream_response_to_bytes(response: ChatStreamResponse | Error | None = None) -> bytes:
|
||||||
if isinstance(response, Error):
|
if isinstance(response, Error):
|
||||||
|
logger.error("Stream error: %s", response.error.message if response.error else "Unknown error")
|
||||||
data = response.model_dump_json()
|
data = response.model_dump_json()
|
||||||
elif isinstance(response, ChatStreamResponse):
|
elif isinstance(response, ChatStreamResponse):
|
||||||
# to populate other fields when using exclude_unset=True
|
# to populate other fields when using exclude_unset=True
|
||||||
|
|||||||
@@ -34,6 +34,7 @@ from api.schema import (
|
|||||||
ResponseFunction,
|
ResponseFunction,
|
||||||
TextContent,
|
TextContent,
|
||||||
ToolCall,
|
ToolCall,
|
||||||
|
ToolContent,
|
||||||
ToolMessage,
|
ToolMessage,
|
||||||
Usage,
|
Usage,
|
||||||
UserMessage,
|
UserMessage,
|
||||||
@@ -48,7 +49,15 @@ from api.setting import (
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
config = Config(connect_timeout=60, read_timeout=120, retries={"max_attempts": 1})
|
config = Config(
|
||||||
|
connect_timeout=60, # Connection timeout: 60 seconds
|
||||||
|
read_timeout=900, # Read timeout: 15 minutes (suitable for long streaming responses)
|
||||||
|
retries={
|
||||||
|
'max_attempts': 8, # Maximum retry attempts
|
||||||
|
'mode': 'adaptive' # Adaptive retry mode
|
||||||
|
},
|
||||||
|
max_pool_connections=50 # Maximum connection pool size
|
||||||
|
)
|
||||||
|
|
||||||
bedrock_runtime = boto3.client(
|
bedrock_runtime = boto3.client(
|
||||||
service_name="bedrock-runtime",
|
service_name="bedrock-runtime",
|
||||||
@@ -177,6 +186,7 @@ class BedrockModel(BaseChatModel):
|
|||||||
# check if model is supported
|
# check if model is supported
|
||||||
if chat_request.model not in bedrock_model_list.keys():
|
if chat_request.model not in bedrock_model_list.keys():
|
||||||
error = f"Unsupported model {chat_request.model}, please use models API to get a list of supported models"
|
error = f"Unsupported model {chat_request.model}, please use models API to get a list of supported models"
|
||||||
|
logger.error("Unsupported model: %s", chat_request.model)
|
||||||
|
|
||||||
if error:
|
if error:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
@@ -204,13 +214,13 @@ class BedrockModel(BaseChatModel):
|
|||||||
# Run the blocking boto3 call in a thread pool
|
# Run the blocking boto3 call in a thread pool
|
||||||
response = await run_in_threadpool(bedrock_runtime.converse, **args)
|
response = await run_in_threadpool(bedrock_runtime.converse, **args)
|
||||||
except bedrock_runtime.exceptions.ValidationException as e:
|
except bedrock_runtime.exceptions.ValidationException as e:
|
||||||
logger.error("Validation Error: " + str(e))
|
logger.error("Bedrock validation error for model %s: %s", chat_request.model, str(e))
|
||||||
raise HTTPException(status_code=400, detail=str(e))
|
raise HTTPException(status_code=400, detail=str(e))
|
||||||
except bedrock_runtime.exceptions.ThrottlingException as e:
|
except bedrock_runtime.exceptions.ThrottlingException as e:
|
||||||
logger.error("Throttling Error: " + str(e))
|
logger.warning("Bedrock throttling for model %s: %s", chat_request.model, str(e))
|
||||||
raise HTTPException(status_code=429, detail=str(e))
|
raise HTTPException(status_code=429, detail=str(e))
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(e)
|
logger.error("Bedrock invocation failed for model %s: %s", chat_request.model, str(e))
|
||||||
raise HTTPException(status_code=500, detail=str(e))
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
return response
|
return response
|
||||||
|
|
||||||
@@ -270,6 +280,7 @@ class BedrockModel(BaseChatModel):
|
|||||||
# return an [DONE] message at the end.
|
# return an [DONE] message at the end.
|
||||||
yield self.stream_response_to_bytes()
|
yield self.stream_response_to_bytes()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
logger.error("Stream error for model %s: %s", chat_request.model, str(e))
|
||||||
error_event = Error(error=ErrorMessage(message=str(e)))
|
error_event = Error(error=ErrorMessage(message=str(e)))
|
||||||
yield self.stream_response_to_bytes(error_event)
|
yield self.stream_response_to_bytes(error_event)
|
||||||
|
|
||||||
@@ -317,7 +328,16 @@ class BedrockModel(BaseChatModel):
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
elif isinstance(message, AssistantMessage):
|
elif isinstance(message, AssistantMessage):
|
||||||
if message.content.strip():
|
# Check if message has content that's not empty
|
||||||
|
has_content = False
|
||||||
|
if isinstance(message.content, str):
|
||||||
|
has_content = message.content.strip() != ""
|
||||||
|
elif isinstance(message.content, list):
|
||||||
|
has_content = len(message.content) > 0
|
||||||
|
elif message.content is not None:
|
||||||
|
has_content = True
|
||||||
|
|
||||||
|
if has_content:
|
||||||
# Text message
|
# Text message
|
||||||
messages.append(
|
messages.append(
|
||||||
{
|
{
|
||||||
@@ -349,6 +369,10 @@ class BedrockModel(BaseChatModel):
|
|||||||
# Bedrock does not support tool role,
|
# Bedrock does not support tool role,
|
||||||
# Add toolResult to content
|
# Add toolResult to content
|
||||||
# https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ToolResultBlock.html
|
# https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ToolResultBlock.html
|
||||||
|
|
||||||
|
# Handle different content formats from OpenAI SDK
|
||||||
|
tool_content = self._extract_tool_content(message.content)
|
||||||
|
|
||||||
messages.append(
|
messages.append(
|
||||||
{
|
{
|
||||||
"role": "user",
|
"role": "user",
|
||||||
@@ -356,7 +380,7 @@ class BedrockModel(BaseChatModel):
|
|||||||
{
|
{
|
||||||
"toolResult": {
|
"toolResult": {
|
||||||
"toolUseId": message.tool_call_id,
|
"toolUseId": message.tool_call_id,
|
||||||
"content": [{"text": message.content}],
|
"content": [{"text": tool_content}],
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
@@ -368,6 +392,57 @@ class BedrockModel(BaseChatModel):
|
|||||||
continue
|
continue
|
||||||
return self._reframe_multi_payloard(messages)
|
return self._reframe_multi_payloard(messages)
|
||||||
|
|
||||||
|
def _extract_tool_content(self, content) -> str:
|
||||||
|
"""Extract text content from various OpenAI SDK tool message formats.
|
||||||
|
|
||||||
|
Handles:
|
||||||
|
- String content (legacy format)
|
||||||
|
- List of content objects (OpenAI SDK 1.91.0+)
|
||||||
|
- Nested JSON structures within text content
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
if isinstance(content, str):
|
||||||
|
return content
|
||||||
|
|
||||||
|
if isinstance(content, list):
|
||||||
|
text_parts = []
|
||||||
|
for i, item in enumerate(content):
|
||||||
|
if isinstance(item, dict):
|
||||||
|
# Handle dict with 'text' field
|
||||||
|
if "text" in item:
|
||||||
|
item_text = item["text"]
|
||||||
|
if isinstance(item_text, str):
|
||||||
|
# Try to parse as JSON if it looks like JSON
|
||||||
|
if item_text.strip().startswith('{') and item_text.strip().endswith('}'):
|
||||||
|
try:
|
||||||
|
parsed_json = json.loads(item_text)
|
||||||
|
# Convert JSON object to readable text
|
||||||
|
text_parts.append(json.dumps(parsed_json, indent=2))
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
# Silently fallback to original text
|
||||||
|
text_parts.append(item_text)
|
||||||
|
else:
|
||||||
|
text_parts.append(item_text)
|
||||||
|
else:
|
||||||
|
text_parts.append(str(item_text))
|
||||||
|
else:
|
||||||
|
# Handle other dict formats - convert to JSON string
|
||||||
|
text_parts.append(json.dumps(item, indent=2))
|
||||||
|
elif hasattr(item, 'text'):
|
||||||
|
# Handle ToolContent objects
|
||||||
|
text_parts.append(item.text)
|
||||||
|
else:
|
||||||
|
# Convert any other type to string
|
||||||
|
text_parts.append(str(item))
|
||||||
|
return "\n".join(text_parts)
|
||||||
|
|
||||||
|
# Fallback for any other type
|
||||||
|
return str(content)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning("Tool content extraction failed: %s", str(e))
|
||||||
|
# Return a safe fallback
|
||||||
|
return str(content) if content is not None else ""
|
||||||
|
|
||||||
def _reframe_multi_payloard(self, messages: list) -> list:
|
def _reframe_multi_payloard(self, messages: list) -> list:
|
||||||
"""Receive messages and reformat them to comply with the Claude format
|
"""Receive messages and reformat them to comply with the Claude format
|
||||||
|
|
||||||
|
|||||||
@@ -45,6 +45,11 @@ class ImageContent(BaseModel):
|
|||||||
image_url: ImageUrl
|
image_url: ImageUrl
|
||||||
|
|
||||||
|
|
||||||
|
class ToolContent(BaseModel):
|
||||||
|
type: Literal["text"] = "text"
|
||||||
|
text: str
|
||||||
|
|
||||||
|
|
||||||
class SystemMessage(BaseModel):
|
class SystemMessage(BaseModel):
|
||||||
name: str | None = None
|
name: str | None = None
|
||||||
role: Literal["system"] = "system"
|
role: Literal["system"] = "system"
|
||||||
@@ -66,7 +71,7 @@ class AssistantMessage(BaseModel):
|
|||||||
|
|
||||||
class ToolMessage(BaseModel):
|
class ToolMessage(BaseModel):
|
||||||
role: Literal["tool"] = "tool"
|
role: Literal["tool"] = "tool"
|
||||||
content: str
|
content: str | list[ToolContent] | list[dict]
|
||||||
tool_call_id: str
|
tool_call_id: str
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user