performance improvement
This commit is contained in:
@@ -29,12 +29,12 @@ class BaseChatModel(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def chat(self, chat_request: ChatRequest) -> ChatResponse:
|
||||
async def chat(self, chat_request: ChatRequest) -> ChatResponse:
|
||||
"""Handle a basic chat completion requests."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def chat_stream(self, chat_request: ChatRequest) -> AsyncIterable[bytes]:
|
||||
async def chat_stream(self, chat_request: ChatRequest) -> AsyncIterable[bytes]:
|
||||
"""Handle a basic chat completion requests with stream response."""
|
||||
pass
|
||||
|
||||
|
||||
@@ -12,6 +12,7 @@ import requests
|
||||
import tiktoken
|
||||
from botocore.config import Config
|
||||
from fastapi import HTTPException
|
||||
from starlette.concurrency import run_in_threadpool
|
||||
|
||||
from api.models.base import BaseChatModel, BaseEmbeddingsModel
|
||||
from api.schema import (
|
||||
@@ -145,7 +146,7 @@ class BedrockModel(BaseChatModel):
|
||||
detail=error,
|
||||
)
|
||||
|
||||
def _invoke_bedrock(self, chat_request: ChatRequest, stream=False):
|
||||
async def _invoke_bedrock(self, chat_request: ChatRequest, stream=False):
|
||||
"""Common logic for invoke bedrock models"""
|
||||
if DEBUG:
|
||||
logger.info("Raw request: " + chat_request.model_dump_json())
|
||||
@@ -157,9 +158,11 @@ class BedrockModel(BaseChatModel):
|
||||
|
||||
try:
|
||||
if stream:
|
||||
response = bedrock_runtime.converse_stream(**args)
|
||||
# Run the blocking boto3 call in a thread pool
|
||||
response = await run_in_threadpool(bedrock_runtime.converse_stream, **args)
|
||||
else:
|
||||
response = bedrock_runtime.converse(**args)
|
||||
# Run the blocking boto3 call in a thread pool
|
||||
response = await run_in_threadpool(bedrock_runtime.converse, **args)
|
||||
except bedrock_runtime.exceptions.ValidationException as e:
|
||||
logger.error("Validation Error: " + str(e))
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
@@ -171,11 +174,11 @@ class BedrockModel(BaseChatModel):
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
return response
|
||||
|
||||
def chat(self, chat_request: ChatRequest) -> ChatResponse:
|
||||
async def chat(self, chat_request: ChatRequest) -> ChatResponse:
|
||||
"""Default implementation for Chat API."""
|
||||
|
||||
message_id = self.generate_message_id()
|
||||
response = self._invoke_bedrock(chat_request)
|
||||
response = await self._invoke_bedrock(chat_request)
|
||||
|
||||
output_message = response["output"]["message"]
|
||||
input_tokens = response["usage"]["inputTokens"]
|
||||
@@ -194,9 +197,9 @@ class BedrockModel(BaseChatModel):
|
||||
logger.info("Proxy response :" + chat_response.model_dump_json())
|
||||
return chat_response
|
||||
|
||||
def chat_stream(self, chat_request: ChatRequest) -> AsyncIterable[bytes]:
|
||||
async def chat_stream(self, chat_request: ChatRequest) -> AsyncIterable[bytes]:
|
||||
"""Default implementation for Chat Stream API"""
|
||||
response = self._invoke_bedrock(chat_request, stream=True)
|
||||
response = await self._invoke_bedrock(chat_request, stream=True)
|
||||
message_id = self.generate_message_id()
|
||||
stream = response.get("stream")
|
||||
for chunk in stream:
|
||||
|
||||
@@ -40,4 +40,4 @@ async def chat_completions(
|
||||
model.validate(chat_request)
|
||||
if chat_request.stream:
|
||||
return StreamingResponse(content=model.chat_stream(chat_request), media_type="text/event-stream")
|
||||
return model.chat(chat_request)
|
||||
return await model.chat(chat_request)
|
||||
|
||||
Reference in New Issue
Block a user