Update API response

This commit is contained in:
Aiden Dai
2024-06-07 10:58:44 +08:00
parent 5f84cef13a
commit 6ef7641a0d
2 changed files with 35 additions and 3 deletions

View File

@@ -2,6 +2,7 @@ import base64
import json import json
import logging import logging
import re import re
import time
from abc import ABC from abc import ABC
from typing import AsyncIterable, Iterable, Literal from typing import AsyncIterable, Iterable, Literal
@@ -403,12 +404,14 @@ class BedrockModel(BaseChatModel):
message.tool_calls = [ message.tool_calls = [
ToolCall( ToolCall(
id=tool["toolUseId"], id=tool["toolUseId"],
type="function",
function=ResponseFunction( function=ResponseFunction(
name=tool["name"], name=tool["name"],
arguments=json.dumps(tool["input"]), arguments=json.dumps(tool["input"]),
), ),
) )
] ]
message.content = None
else: else:
message.content = content[0]["text"] message.content = content[0]["text"]
@@ -417,8 +420,10 @@ class BedrockModel(BaseChatModel):
model=model, model=model,
choices=[ choices=[
Choice( Choice(
index=0,
message=message, message=message,
finish_reason=finish_reason, finish_reason=self._convert_finish_reason(finish_reason),
logprobs=None,
) )
], ],
usage=Usage( usage=Usage(
@@ -427,6 +432,9 @@ class BedrockModel(BaseChatModel):
total_tokens=input_tokens + output_tokens, total_tokens=input_tokens + output_tokens,
), ),
) )
response.system_fingerprint = "fp"
response.object = "chat.completion"
response.created = int(time.time())
return response return response
def _create_response_stream( def _create_response_stream(
@@ -455,6 +463,7 @@ class BedrockModel(BaseChatModel):
message = ChatResponseMessage( message = ChatResponseMessage(
tool_calls=[ tool_calls=[
ToolCall( ToolCall(
type="function",
id=delta["toolUse"]["toolUseId"], id=delta["toolUse"]["toolUseId"],
function=ResponseFunction( function=ResponseFunction(
name=delta["toolUse"]["name"], name=delta["toolUse"]["name"],
@@ -475,6 +484,7 @@ class BedrockModel(BaseChatModel):
message = ChatResponseMessage( message = ChatResponseMessage(
tool_calls=[ tool_calls=[
ToolCall( ToolCall(
type="function",
function=ResponseFunction( function=ResponseFunction(
arguments=delta["toolUse"]["input"], arguments=delta["toolUse"]["input"],
) )
@@ -509,7 +519,7 @@ class BedrockModel(BaseChatModel):
index=0, index=0,
delta=message, delta=message,
logprobs=None, logprobs=None,
finish_reason=finish_reason, finish_reason=self._convert_finish_reason(finish_reason),
) )
], ],
usage=usage, usage=usage,
@@ -615,6 +625,28 @@ class BedrockModel(BaseChatModel):
} }
} }
def _convert_finish_reason(self, finish_reason: str | None) -> str | None:
"""
Below is a list of finish reason according to OpenAI doc:
- stop: if the model hit a natural stop point or a provided stop sequence,
- length: if the maximum number of tokens specified in the request was reached,
- content_filter: if content was omitted due to a flag from our content filters,
- tool_calls: if the model called a tool
"""
if finish_reason:
finish_reason_mapping = {
"tool_use": "tool_calls",
"finished": "stop",
"end_turn": "stop",
"max_tokens": "length",
"stop_sequence": "stop",
"complete": "stop",
"content_filtered": "content_filter"
}
return finish_reason_mapping.get(finish_reason.lower(), finish_reason.lower())
return None
class BedrockEmbeddingsModel(BaseEmbeddingsModel, ABC): class BedrockEmbeddingsModel(BaseEmbeddingsModel, ABC):
accept = "application/json" accept = "application/json"

View File

@@ -15,7 +15,7 @@ router = APIRouter(
) )
@router.post("/completions", response_model=ChatResponse | ChatStreamResponse, response_model_exclude_none=True) @router.post("/completions", response_model=ChatResponse | ChatStreamResponse, response_model_exclude_unset=True)
async def chat_completions( async def chat_completions(
chat_request: Annotated[ chat_request: Annotated[
ChatRequest, ChatRequest,