From 6ef7641a0d4b24c41c465ea8b5b9c07cd676f46d Mon Sep 17 00:00:00 2001 From: Aiden Dai Date: Fri, 7 Jun 2024 10:58:44 +0800 Subject: [PATCH] Update api response --- src/api/models/bedrock.py | 36 ++++++++++++++++++++++++++++++++++-- src/api/routers/chat.py | 2 +- 2 files changed, 35 insertions(+), 3 deletions(-) diff --git a/src/api/models/bedrock.py b/src/api/models/bedrock.py index 0fe25d4..96cdc05 100644 --- a/src/api/models/bedrock.py +++ b/src/api/models/bedrock.py @@ -2,6 +2,7 @@ import base64 import json import logging import re +import time from abc import ABC from typing import AsyncIterable, Iterable, Literal @@ -403,12 +404,14 @@ class BedrockModel(BaseChatModel): message.tool_calls = [ ToolCall( id=tool["toolUseId"], + type="function", function=ResponseFunction( name=tool["name"], arguments=json.dumps(tool["input"]), ), ) ] + message.content = None else: message.content = content[0]["text"] @@ -417,8 +420,10 @@ class BedrockModel(BaseChatModel): model=model, choices=[ Choice( + index=0, message=message, - finish_reason=finish_reason, + finish_reason=self._convert_finish_reason(finish_reason), + logprobs=None, ) ], usage=Usage( @@ -427,6 +432,9 @@ class BedrockModel(BaseChatModel): total_tokens=input_tokens + output_tokens, ), ) + response.system_fingerprint = "fp" + response.object = "chat.completion" + response.created = int(time.time()) return response def _create_response_stream( @@ -455,6 +463,7 @@ class BedrockModel(BaseChatModel): message = ChatResponseMessage( tool_calls=[ ToolCall( + type="function", id=delta["toolUse"]["toolUseId"], function=ResponseFunction( name=delta["toolUse"]["name"], @@ -475,6 +484,7 @@ class BedrockModel(BaseChatModel): message = ChatResponseMessage( tool_calls=[ ToolCall( + type="function", function=ResponseFunction( arguments=delta["toolUse"]["input"], ) @@ -509,7 +519,7 @@ class BedrockModel(BaseChatModel): index=0, delta=message, logprobs=None, - finish_reason=finish_reason, + finish_reason=self._convert_finish_reason(finish_reason), ) ], usage=usage, @@ -615,6 +625,28 @@ class BedrockModel(BaseChatModel): } } + def _convert_finish_reason(self, finish_reason: str | None) -> str | None: + """ + Below is a list of finish reason according to OpenAI doc: + + - stop: if the model hit a natural stop point or a provided stop sequence, + - length: if the maximum number of tokens specified in the request was reached, + - content_filter: if content was omitted due to a flag from our content filters, + - tool_calls: if the model called a tool + """ + if finish_reason: + finish_reason_mapping = { + "tool_use": "tool_calls", + "finished": "stop", + "end_turn": "stop", + "max_tokens": "length", + "stop_sequence": "stop", + "complete": "stop", + "content_filtered": "content_filter" + } + return finish_reason_mapping.get(finish_reason.lower(), finish_reason.lower()) + return None + class BedrockEmbeddingsModel(BaseEmbeddingsModel, ABC): accept = "application/json" diff --git a/src/api/routers/chat.py b/src/api/routers/chat.py index 58ea73f..1e48a48 100644 --- a/src/api/routers/chat.py +++ b/src/api/routers/chat.py @@ -15,7 +15,7 @@ router = APIRouter( ) -@router.post("/completions", response_model=ChatResponse | ChatStreamResponse, response_model_exclude_none=True) +@router.post("/completions", response_model=ChatResponse | ChatStreamResponse, response_model_exclude_unset=True) async def chat_completions( chat_request: Annotated[ ChatRequest,