optimize error response in streaming

Aiden Dai
2025-03-26 11:32:39 +08:00
parent 4f1a75b49f
commit c98e123c8f
4 changed files with 64 additions and 40 deletions

View File

@@ -11,6 +11,7 @@ from api.schema import (
# Embeddings
EmbeddingsRequest,
EmbeddingsResponse,
Error,
)
@@ -43,14 +44,19 @@ class BaseChatModel(ABC):
return "chatcmpl-" + str(uuid.uuid4())[:8]
@staticmethod
def stream_response_to_bytes(response: ChatStreamResponse | None = None) -> bytes:
if response:
def stream_response_to_bytes(response: ChatStreamResponse | Error | None = None) -> bytes:
if isinstance(response, Error):
data = response.model_dump_json()
elif isinstance(response, ChatStreamResponse):
# set these fields explicitly so they are included when serializing with exclude_unset=True
response.system_fingerprint = "fp"
response.object = "chat.completion.chunk"
response.created = int(time.time())
return "data: {}\n\n".format(response.model_dump_json(exclude_unset=True)).encode("utf-8")
return "data: [DONE]\n\n".encode("utf-8")
data = response.model_dump_json(exclude_unset=True)
else:
data = "[DONE]"
return f"data: {data}\n\n".encode("utf-8")
class BaseEmbeddingsModel(ABC):
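
For reference, a minimal sketch (not part of the diff) of the frames the updated stream_response_to_bytes helper emits; it assumes BaseChatModel is importable from api.models.base and that pydantic v2 produces compact JSON:

from api.models.base import BaseChatModel  # module path assumed, not shown in the diff
from api.schema import Error, ErrorMessage

# Error branch: the error envelope is serialized in full and framed as an SSE data event.
frame = BaseChatModel.stream_response_to_bytes(Error(error=ErrorMessage(message="throttled")))
assert frame == b'data: {"error":{"message":"throttled"}}\n\n'

# No-argument branch: the stream terminator that OpenAI-compatible clients wait for.
assert BaseChatModel.stream_response_to_bytes() == b"data: [DONE]\n\n"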

View File

@@ -1,4 +1,3 @@
import asyncio
import base64
import json
import logging
@@ -28,6 +27,8 @@ from api.schema import (
EmbeddingsRequest,
EmbeddingsResponse,
EmbeddingsUsage,
Error,
ErrorMessage,
Function,
ImageContent,
ResponseFunction,
@@ -198,33 +199,41 @@ class BedrockModel(BaseChatModel):
logger.info("Proxy response :" + chat_response.model_dump_json())
return chat_response
async def _async_iterate(self, stream):
"""Helper method to convert sync iterator to async iterator"""
for chunk in stream:
# round-trip through the threadpool so other event-loop tasks can run between chunks
await run_in_threadpool(lambda: chunk)
yield chunk
async def chat_stream(self, chat_request: ChatRequest) -> AsyncIterable[bytes]:
"""Default implementation for Chat Stream API"""
response = await self._invoke_bedrock(chat_request, stream=True)
message_id = self.generate_message_id()
stream = response.get("stream")
for chunk in stream:
stream_response = self._create_response_stream(
model_id=chat_request.model, message_id=message_id, chunk=chunk
)
if not stream_response:
continue
if DEBUG:
logger.info("Proxy response :" + stream_response.model_dump_json())
if stream_response.choices:
yield self.stream_response_to_bytes(stream_response)
elif chat_request.stream_options and chat_request.stream_options.include_usage:
# Emit a usage-only chunk with empty choices, per the OpenAI docs below:
# if you set stream_options: {"include_usage": true},
# an additional chunk will be streamed before the data: [DONE] message.
# The usage field on this chunk shows the token usage statistics for the entire request,
# and the choices field will always be an empty array.
# All other chunks will also include a usage field, but with a null value.
yield self.stream_response_to_bytes(stream_response)
await asyncio.sleep(0)
try:
response = await self._invoke_bedrock(chat_request, stream=True)
message_id = self.generate_message_id()
stream = response.get("stream")
async for chunk in self._async_iterate(stream):
args = {"model_id": chat_request.model, "message_id": message_id, "chunk": chunk}
stream_response = self._create_response_stream(**args)
if not stream_response:
continue
if DEBUG:
logger.info("Proxy response :" + stream_response.model_dump_json())
if stream_response.choices:
yield self.stream_response_to_bytes(stream_response)
elif chat_request.stream_options and chat_request.stream_options.include_usage:
# Emit a usage-only chunk with empty choices, per the OpenAI docs below:
# if you set stream_options: {"include_usage": true},
# an additional chunk will be streamed before the data: [DONE] message.
# The usage field on this chunk shows the token usage statistics for the entire request,
# and the choices field will always be an empty array.
# All other chunks will also include a usage field, but with a null value.
yield self.stream_response_to_bytes(stream_response)
# return a [DONE] message at the end.
yield self.stream_response_to_bytes()
# return a [DONE] message at the end.
yield self.stream_response_to_bytes()
except Exception as e:
error_event = Error(error=ErrorMessage(message=str(e)))
yield self.stream_response_to_bytes(error_event)
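
Because the HTTP response has typically already started streaming with status 200 when a failure occurs, the error reaches the caller as a regular SSE data frame rather than an HTTP error. A hedged client-side sketch; httpx, the URL, and the auth header are illustrative assumptions:

import json

import httpx

def consume_chat_stream(payload: dict, url: str = "http://localhost:8000/api/v1/chat/completions") -> None:
    headers = {"Authorization": "Bearer <api-key>"}  # placeholder credentials
    with httpx.stream("POST", url, json=payload, headers=headers, timeout=None) as resp:
        for line in resp.iter_lines():
            if not line.startswith("data: "):
                continue
            data = line[len("data: "):]
            if data == "[DONE]":
                break
            event = json.loads(data)
            if "error" in event:
                # Mid-stream failures arrive here instead of as an HTTP error status.
                raise RuntimeError(event["error"]["message"])
            if event.get("choices"):
                print(event["choices"][0].get("delta", {}))
            elif event.get("usage"):
                print("usage:", event["usage"])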
def _parse_system_prompts(self, chat_request: ChatRequest) -> list[dict[str, str]]:
"""Create system prompts.
@@ -416,29 +425,28 @@ class BedrockModel(BaseChatModel):
}
# add tool config
if chat_request.tools:
args["toolConfig"] = {"tools": [self._convert_tool_spec(t.function) for t in chat_request.tools]}
tool_config = {"tools": [self._convert_tool_spec(t.function) for t in chat_request.tools]}
if chat_request.tool_choice and not chat_request.model.startswith("meta.llama3-1-"):
if isinstance(chat_request.tool_choice, str):
# auto (default) is mapped to {"auto" : {}}
# required is mapped to {"any" : {}}
if chat_request.tool_choice == "required":
args["toolConfig"]["toolChoice"] = {"any": {}}
tool_config["toolChoice"] = {"any": {}}
else:
args["toolConfig"]["toolChoice"] = {"auto": {}}
tool_config["toolChoice"] = {"auto": {}}
else:
# Specific tool to use
assert "function" in chat_request.tool_choice
args["toolConfig"]["toolChoice"] = {
"tool": {"name": chat_request.tool_choice["function"].get("name", "")}
}
tool_config["toolChoice"] = {"tool": {"name": chat_request.tool_choice["function"].get("name", "")}}
args["toolConfig"] = tool_config
return args
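
For context, a sketch of how the OpenAI-style tool_choice values map onto the Bedrock Converse toolConfig assembled above; the toolSpec entry is an assumption about what _convert_tool_spec returns (that helper is outside this diff), and get_weather is a hypothetical tool:

# tool_choice == "required"                      -> toolConfig["toolChoice"] == {"any": {}}
# tool_choice == "auto" (or any other string)    -> toolConfig["toolChoice"] == {"auto": {}}
# tool_choice == {"type": "function", "function": {"name": "get_weather"}}
#                                                -> toolConfig["toolChoice"] == {"tool": {"name": "get_weather"}}
# tool_choice unset, or a meta.llama3-1-* model  -> no toolChoice key at all
example_tool_config = {
    "tools": [
        {
            # assumed output shape of _convert_tool_spec for the hypothetical tool
            "toolSpec": {
                "name": "get_weather",
                "description": "Get the current weather for a city",
                "inputSchema": {"json": {"type": "object", "properties": {"city": {"type": "string"}}}},
            }
        }
    ],
    "toolChoice": {"tool": {"name": "get_weather"}},
}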
def _create_response(
self,
model: str,
message_id: str,
content: list[dict] = None,
content: list[dict] | None = None,
finish_reason: str | None = None,
input_tokens: int = 0,
output_tokens: int = 0,
@@ -622,7 +630,7 @@ class BedrockModel(BaseChatModel):
def _parse_content_parts(
self,
message: UserMessage,
message: UserMessage | AssistantMessage,
model_id: str,
) -> list[dict]:
if isinstance(message.content, str):
@@ -661,7 +669,7 @@ class BedrockModel(BaseChatModel):
@staticmethod
def is_supported_modality(model_id: str, modality: str = "IMAGE") -> bool:
model = bedrock_model_list.get(model_id)
model = bedrock_model_list.get(model_id, {})
modalities = model.get("modalities", [])
if modality in modalities:
return True
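
The added default means a lookup for an unknown model id now degrades gracefully instead of raising; a quick sketch of the difference:

# Before: bedrock_model_list.get("not-a-model") returned None,
# and None.get("modalities", []) raised AttributeError.
# After: the empty-dict default keeps the lookup safe.
model = bedrock_model_list.get("not-a-model", {})  # {} for unknown ids
modalities = model.get("modalities", [])  # [] rather than an AttributeError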

View File

@@ -5,7 +5,7 @@ from fastapi.responses import StreamingResponse
from api.auth import api_key_auth
from api.models.bedrock import BedrockModel
from api.schema import ChatRequest, ChatResponse, ChatStreamResponse
from api.schema import ChatRequest, ChatResponse, ChatStreamResponse, Error
from api.setting import DEFAULT_MODEL
router = APIRouter(
@@ -15,7 +15,9 @@ router = APIRouter(
)
@router.post("/completions", response_model=ChatResponse | ChatStreamResponse, response_model_exclude_unset=True)
@router.post(
"/completions", response_model=ChatResponse | ChatStreamResponse | Error, response_model_exclude_unset=True
)
async def chat_completions(
chat_request: Annotated[
ChatRequest,

View File

@@ -176,3 +176,11 @@ class EmbeddingsResponse(BaseModel):
data: list[Embedding]
model: str
usage: EmbeddingsUsage
class ErrorMessage(BaseModel):
message: str
class Error(BaseModel):
error: ErrorMessage