performance improvement: make chat endpoints async and run blocking boto3 converse/converse_stream calls in a thread pool to avoid blocking the event loop

This commit is contained in:
Aiden Dai
2025-03-13 18:24:08 +08:00
parent fa14ae8c05
commit 0ead770069
3 changed files with 13 additions and 10 deletions

View File

@@ -29,12 +29,12 @@ class BaseChatModel(ABC):
pass
@abstractmethod
def chat(self, chat_request: ChatRequest) -> ChatResponse:
async def chat(self, chat_request: ChatRequest) -> ChatResponse:
"""Handle a basic chat completion requests."""
pass
@abstractmethod
def chat_stream(self, chat_request: ChatRequest) -> AsyncIterable[bytes]:
async def chat_stream(self, chat_request: ChatRequest) -> AsyncIterable[bytes]:
"""Handle a basic chat completion requests with stream response."""
pass

View File

@@ -12,6 +12,7 @@ import requests
import tiktoken
from botocore.config import Config
from fastapi import HTTPException
from starlette.concurrency import run_in_threadpool
from api.models.base import BaseChatModel, BaseEmbeddingsModel
from api.schema import (
@@ -145,7 +146,7 @@ class BedrockModel(BaseChatModel):
detail=error,
)
def _invoke_bedrock(self, chat_request: ChatRequest, stream=False):
async def _invoke_bedrock(self, chat_request: ChatRequest, stream=False):
"""Common logic for invoke bedrock models"""
if DEBUG:
logger.info("Raw request: " + chat_request.model_dump_json())
@@ -157,9 +158,11 @@ class BedrockModel(BaseChatModel):
try:
if stream:
response = bedrock_runtime.converse_stream(**args)
# Run the blocking boto3 call in a thread pool
response = await run_in_threadpool(bedrock_runtime.converse_stream, **args)
else:
response = bedrock_runtime.converse(**args)
# Run the blocking boto3 call in a thread pool
response = await run_in_threadpool(bedrock_runtime.converse, **args)
except bedrock_runtime.exceptions.ValidationException as e:
logger.error("Validation Error: " + str(e))
raise HTTPException(status_code=400, detail=str(e))
@@ -171,11 +174,11 @@ class BedrockModel(BaseChatModel):
raise HTTPException(status_code=500, detail=str(e))
return response
def chat(self, chat_request: ChatRequest) -> ChatResponse:
async def chat(self, chat_request: ChatRequest) -> ChatResponse:
"""Default implementation for Chat API."""
message_id = self.generate_message_id()
response = self._invoke_bedrock(chat_request)
response = await self._invoke_bedrock(chat_request)
output_message = response["output"]["message"]
input_tokens = response["usage"]["inputTokens"]
@@ -194,9 +197,9 @@ class BedrockModel(BaseChatModel):
logger.info("Proxy response :" + chat_response.model_dump_json())
return chat_response
def chat_stream(self, chat_request: ChatRequest) -> AsyncIterable[bytes]:
async def chat_stream(self, chat_request: ChatRequest) -> AsyncIterable[bytes]:
"""Default implementation for Chat Stream API"""
response = self._invoke_bedrock(chat_request, stream=True)
response = await self._invoke_bedrock(chat_request, stream=True)
message_id = self.generate_message_id()
stream = response.get("stream")
for chunk in stream:

View File

@@ -40,4 +40,4 @@ async def chat_completions(
model.validate(chat_request)
if chat_request.stream:
return StreamingResponse(content=model.chat_stream(chat_request), media_type="text/event-stream")
return model.chat(chat_request)
return await model.chat(chat_request)