optimize error response in streaming
@@ -11,6 +11,7 @@ from api.schema import (
     # Embeddings
     EmbeddingsRequest,
     EmbeddingsResponse,
+    Error,
 )
 
 
@@ -43,14 +44,19 @@ class BaseChatModel(ABC):
         return "chatcmpl-" + str(uuid.uuid4())[:8]
 
     @staticmethod
-    def stream_response_to_bytes(response: ChatStreamResponse | None = None) -> bytes:
-        if response:
+    def stream_response_to_bytes(response: ChatStreamResponse | Error | None = None) -> bytes:
+        if isinstance(response, Error):
+            data = response.model_dump_json()
+        elif isinstance(response, ChatStreamResponse):
             # to populate other fields when using exclude_unset=True
             response.system_fingerprint = "fp"
             response.object = "chat.completion.chunk"
             response.created = int(time.time())
-            return "data: {}\n\n".format(response.model_dump_json(exclude_unset=True)).encode("utf-8")
-        return "data: [DONE]\n\n".encode("utf-8")
+            data = response.model_dump_json(exclude_unset=True)
+        else:
+            data = "[DONE]"
+
+        return f"data: {data}\n\n".encode("utf-8")
 
 
 class BaseEmbeddingsModel(ABC):
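Note on the helper above: with the new Error branch, a failure can be serialized into the same SSE frame format as regular chunks. A minimal sketch of the two non-chunk outputs (module paths are assumed, not shown in this commit; the models are the ones imported above):

# Sketch only -- assumes BaseChatModel lives in api.models.base and uses the
# Error/ErrorMessage models added to api.schema further down in this commit.
from api.models.base import BaseChatModel
from api.schema import Error, ErrorMessage

err = Error(error=ErrorMessage(messsage="ThrottlingException: rate exceeded"))
BaseChatModel.stream_response_to_bytes(err)
# -> b'data: {"error":{"messsage":"ThrottlingException: rate exceeded"}}\n\n'

BaseChatModel.stream_response_to_bytes()
# -> b'data: [DONE]\n\n'

Regular ChatStreamResponse chunks still get object, created, and system_fingerprint stamped before serialization so that exclude_unset=True keeps those fields.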
@@ -1,4 +1,3 @@
-import asyncio
 import base64
 import json
 import logging
@@ -28,6 +27,8 @@ from api.schema import (
     EmbeddingsRequest,
     EmbeddingsResponse,
     EmbeddingsUsage,
+    Error,
+    ErrorMessage,
     Function,
     ImageContent,
     ResponseFunction,
@@ -198,33 +199,41 @@ class BedrockModel(BaseChatModel):
             logger.info("Proxy response :" + chat_response.model_dump_json())
         return chat_response
 
+    async def _async_iterate(self, stream):
+        """Helper method to convert sync iterator to async iterator"""
+        for chunk in stream:
+            await run_in_threadpool(lambda: chunk)
+            yield chunk
+
     async def chat_stream(self, chat_request: ChatRequest) -> AsyncIterable[bytes]:
         """Default implementation for Chat Stream API"""
-        response = await self._invoke_bedrock(chat_request, stream=True)
-        message_id = self.generate_message_id()
-        stream = response.get("stream")
-        for chunk in stream:
-            stream_response = self._create_response_stream(
-                model_id=chat_request.model, message_id=message_id, chunk=chunk
-            )
-            if not stream_response:
-                continue
-            if DEBUG:
-                logger.info("Proxy response :" + stream_response.model_dump_json())
-            if stream_response.choices:
-                yield self.stream_response_to_bytes(stream_response)
-            elif chat_request.stream_options and chat_request.stream_options.include_usage:
-                # An empty choices for Usage as per OpenAI doc below:
-                # if you set stream_options: {"include_usage": true}.
-                # an additional chunk will be streamed before the data: [DONE] message.
-                # The usage field on this chunk shows the token usage statistics for the entire request,
-                # and the choices field will always be an empty array.
-                # All other chunks will also include a usage field, but with a null value.
-                yield self.stream_response_to_bytes(stream_response)
-            await asyncio.sleep(0)
+        try:
+            response = await self._invoke_bedrock(chat_request, stream=True)
+            message_id = self.generate_message_id()
+            stream = response.get("stream")
+            async for chunk in self._async_iterate(stream):
+                args = {"model_id": chat_request.model, "message_id": message_id, "chunk": chunk}
+                stream_response = self._create_response_stream(**args)
+                if not stream_response:
+                    continue
+                if DEBUG:
+                    logger.info("Proxy response :" + stream_response.model_dump_json())
+                if stream_response.choices:
+                    yield self.stream_response_to_bytes(stream_response)
+                elif chat_request.stream_options and chat_request.stream_options.include_usage:
+                    # An empty choices for Usage as per OpenAI doc below:
+                    # if you set stream_options: {"include_usage": true}.
+                    # an additional chunk will be streamed before the data: [DONE] message.
+                    # The usage field on this chunk shows the token usage statistics for the entire request,
+                    # and the choices field will always be an empty array.
+                    # All other chunks will also include a usage field, but with a null value.
+                    yield self.stream_response_to_bytes(stream_response)
 
-        # return an [DONE] message at the end.
-        yield self.stream_response_to_bytes()
+            # return an [DONE] message at the end.
+            yield self.stream_response_to_bytes()
+        except Exception as e:
+            error_event = Error(error=ErrorMessage(messsage=str(e)))
+            yield self.stream_response_to_bytes(error_event)
 
     def _parse_system_prompts(self, chat_request: ChatRequest) -> list[dict[str, str]]:
         """Create system prompts.
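The new _async_iterate wrapper replaces the per-chunk await asyncio.sleep(0): each chunk is awaited through Starlette's threadpool so other coroutines get scheduled while the synchronous Bedrock event stream is consumed, and the whole loop now sits in a try/except so a failure is emitted as an Error event instead of tearing down the HTTP response. A standalone sketch of the same sync-to-async bridging pattern (illustrative, not the project code):

import asyncio
from starlette.concurrency import run_in_threadpool  # fastapi.concurrency re-exports this

async def aiterate(sync_iterable):
    """Yield items of a synchronous iterable from an async generator."""
    for item in sync_iterable:
        # The hop through the threadpool yields control to the event loop
        # between chunks; the blocking next() itself still runs inline.
        await run_in_threadpool(lambda: item)
        yield item

async def main():
    async for item in aiterate(["chunk-1", "chunk-2"]):
        print(item)

asyncio.run(main())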
@@ -416,29 +425,28 @@ class BedrockModel(BaseChatModel):
         }
         # add tool config
         if chat_request.tools:
-            args["toolConfig"] = {"tools": [self._convert_tool_spec(t.function) for t in chat_request.tools]}
+            tool_config = {"tools": [self._convert_tool_spec(t.function) for t in chat_request.tools]}
 
             if chat_request.tool_choice and not chat_request.model.startswith("meta.llama3-1-"):
                 if isinstance(chat_request.tool_choice, str):
                     # auto (default) is mapped to {"auto" : {}}
                     # required is mapped to {"any" : {}}
                     if chat_request.tool_choice == "required":
-                        args["toolConfig"]["toolChoice"] = {"any": {}}
+                        tool_config["toolChoice"] = {"any": {}}
                     else:
-                        args["toolConfig"]["toolChoice"] = {"auto": {}}
+                        tool_config["toolChoice"] = {"auto": {}}
                 else:
                     # Specific tool to use
                     assert "function" in chat_request.tool_choice
-                    args["toolConfig"]["toolChoice"] = {
-                        "tool": {"name": chat_request.tool_choice["function"].get("name", "")}
-                    }
+                    tool_config["toolChoice"] = {"tool": {"name": chat_request.tool_choice["function"].get("name", "")}}
+            args["toolConfig"] = tool_config
         return args
 
     def _create_response(
         self,
         model: str,
         message_id: str,
-        content: list[dict] = None,
+        content: list[dict] | None = None,
         finish_reason: str | None = None,
         input_tokens: int = 0,
         output_tokens: int = 0,
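The tool configuration is now assembled in a local tool_config dict and attached to args once at the end; the tool_choice mapping itself is unchanged. A small standalone sketch of that mapping (tool name is a placeholder):

# Sketch of the OpenAI-style tool_choice -> Bedrock toolChoice mapping used above.
def map_tool_choice(tool_choice):
    if isinstance(tool_choice, str):
        # "required" -> {"any": {}}, anything else (e.g. "auto") -> {"auto": {}}
        return {"any": {}} if tool_choice == "required" else {"auto": {}}
    # a specific tool: {"type": "function", "function": {"name": ...}}
    return {"tool": {"name": tool_choice["function"].get("name", "")}}

print(map_tool_choice("required"))                             # {'any': {}}
print(map_tool_choice("auto"))                                 # {'auto': {}}
print(map_tool_choice({"function": {"name": "get_weather"}}))  # {'tool': {'name': 'get_weather'}}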
@@ -622,7 +630,7 @@ class BedrockModel(BaseChatModel):
 
     def _parse_content_parts(
         self,
-        message: UserMessage,
+        message: UserMessage | AssistantMessage,
         model_id: str,
     ) -> list[dict]:
         if isinstance(message.content, str):
@@ -661,7 +669,7 @@ class BedrockModel(BaseChatModel):
 
     @staticmethod
     def is_supported_modality(model_id: str, modality: str = "IMAGE") -> bool:
-        model = bedrock_model_list.get(model_id)
+        model = bedrock_model_list.get(model_id, {})
         modalities = model.get("modalities", [])
         if modality in modalities:
             return True
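The .get(model_id, {}) default means an unknown model id now falls through to an empty modality list (and a False return) instead of raising AttributeError on None; roughly:

# Illustration only; the model ids and list contents are placeholders.
bedrock_model_list = {"some.known-model-id": {"modalities": ["TEXT", "IMAGE"]}}

model = bedrock_model_list.get("some.unknown-model-id")       # None -> None.get(...) raised before
model = bedrock_model_list.get("some.unknown-model-id", {})   # {} -> .get("modalities", []) == []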
@@ -5,7 +5,7 @@ from fastapi.responses import StreamingResponse
 
 from api.auth import api_key_auth
 from api.models.bedrock import BedrockModel
-from api.schema import ChatRequest, ChatResponse, ChatStreamResponse
+from api.schema import ChatRequest, ChatResponse, ChatStreamResponse, Error
 from api.setting import DEFAULT_MODEL
 
 router = APIRouter(
@@ -15,7 +15,9 @@ router = APIRouter(
 )
 
 
-@router.post("/completions", response_model=ChatResponse | ChatStreamResponse, response_model_exclude_unset=True)
+@router.post(
+    "/completions", response_model=ChatResponse | ChatStreamResponse | Error, response_model_exclude_unset=True
+)
 async def chat_completions(
     chat_request: Annotated[
         ChatRequest,
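Since errors are now emitted inside the stream, a streaming client has to look for an error event in the SSE body rather than relying on the HTTP status alone. A rough client-side sketch (base URL, path prefix, API key, and model id are placeholders, not taken from this commit):

import json
import httpx

payload = {
    "model": "some-model-id",
    "messages": [{"role": "user", "content": "Hello"}],
    "stream": True,
}

with httpx.stream(
    "POST",
    "http://localhost:8000/api/v1/chat/completions",
    json=payload,
    headers={"Authorization": "Bearer <api-key>"},
    timeout=None,
) as resp:
    for line in resp.iter_lines():
        if not line.startswith("data: "):
            continue
        data = line[len("data: "):]
        if data == "[DONE]":
            break
        event = json.loads(data)
        if "error" in event:
            # terminal error event emitted by the proxy's except branch
            raise RuntimeError(event["error"])
        for choice in event.get("choices", []):
            print(choice.get("delta", {}).get("content") or "", end="")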
@@ -176,3 +176,11 @@ class EmbeddingsResponse(BaseModel):
     data: list[Embedding]
     model: str
     usage: EmbeddingsUsage
+
+
+class ErrorMessage(BaseModel):
+    messsage: str
+
+
+class Error(BaseModel):
+    error: ErrorMessage