Update API response

This commit is contained in:
Aiden Dai
2024-06-07 10:58:44 +08:00
parent 5f84cef13a
commit 6ef7641a0d
2 changed files with 35 additions and 3 deletions

View File

@@ -2,6 +2,7 @@ import base64
import json
import logging
import re
import time
from abc import ABC
from typing import AsyncIterable, Iterable, Literal
@@ -403,12 +404,14 @@ class BedrockModel(BaseChatModel):
message.tool_calls = [
ToolCall(
id=tool["toolUseId"],
type="function",
function=ResponseFunction(
name=tool["name"],
arguments=json.dumps(tool["input"]),
),
)
]
message.content = None
else:
message.content = content[0]["text"]
@@ -417,8 +420,10 @@ class BedrockModel(BaseChatModel):
model=model,
choices=[
Choice(
index=0,
message=message,
finish_reason=finish_reason,
finish_reason=self._convert_finish_reason(finish_reason),
logprobs=None,
)
],
usage=Usage(
@@ -427,6 +432,9 @@ class BedrockModel(BaseChatModel):
total_tokens=input_tokens + output_tokens,
),
)
response.system_fingerprint = "fp"
response.object = "chat.completion"
response.created = int(time.time())
return response
def _create_response_stream(
@@ -455,6 +463,7 @@ class BedrockModel(BaseChatModel):
message = ChatResponseMessage(
tool_calls=[
ToolCall(
type="function",
id=delta["toolUse"]["toolUseId"],
function=ResponseFunction(
name=delta["toolUse"]["name"],
@@ -475,6 +484,7 @@ class BedrockModel(BaseChatModel):
message = ChatResponseMessage(
tool_calls=[
ToolCall(
type="function",
function=ResponseFunction(
arguments=delta["toolUse"]["input"],
)
@@ -509,7 +519,7 @@ class BedrockModel(BaseChatModel):
index=0,
delta=message,
logprobs=None,
finish_reason=finish_reason,
finish_reason=self._convert_finish_reason(finish_reason),
)
],
usage=usage,
@@ -615,6 +625,28 @@ class BedrockModel(BaseChatModel):
}
}
def _convert_finish_reason(self, finish_reason: str | None) -> str | None:
"""
Below is a list of finish reason according to OpenAI doc:
- stop: if the model hit a natural stop point or a provided stop sequence,
- length: if the maximum number of tokens specified in the request was reached,
- content_filter: if content was omitted due to a flag from our content filters,
- tool_calls: if the model called a tool
"""
if finish_reason:
finish_reason_mapping = {
"tool_use": "tool_calls",
"finished": "stop",
"end_turn": "stop",
"max_tokens": "length",
"stop_sequence": "stop",
"complete": "stop",
"content_filtered": "content_filter"
}
return finish_reason_mapping.get(finish_reason.lower(), finish_reason.lower())
return None
class BedrockEmbeddingsModel(BaseEmbeddingsModel, ABC):
accept = "application/json"

View File

@@ -15,7 +15,7 @@ router = APIRouter(
)
@router.post("/completions", response_model=ChatResponse | ChatStreamResponse, response_model_exclude_none=True)
@router.post("/completions", response_model=ChatResponse | ChatStreamResponse, response_model_exclude_unset=True)
async def chat_completions(
chat_request: Annotated[
ChatRequest,