Update API response
@@ -2,6 +2,7 @@ import base64
import json
import logging
import re
import time
from abc import ABC
from typing import AsyncIterable, Iterable, Literal

@@ -403,12 +404,14 @@ class BedrockModel(BaseChatModel):
    message.tool_calls = [
        ToolCall(
            id=tool["toolUseId"],
            type="function",
            function=ResponseFunction(
                name=tool["name"],
                arguments=json.dumps(tool["input"]),
            ),
        )
    ]
    message.content = None
else:
    message.content = content[0]["text"]

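Note (not part of the diff): the hunk above converts a Bedrock toolUse block into an OpenAI-style assistant message. A rough sketch of the target shape, with the id, tool name, and arguments invented for illustration:

# Illustrative only: OpenAI-style assistant message carrying a tool call.
# content is nulled out and arguments is a JSON-encoded string (json.dumps).
example_tool_call_message = {
    "role": "assistant",
    "content": None,
    "tool_calls": [
        {
            "id": "tooluse_abc123",
            "type": "function",
            "function": {
                "name": "get_weather",
                "arguments": "{\"city\": \"Berlin\"}",
            },
        }
    ],
}
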
@@ -417,8 +420,10 @@ class BedrockModel(BaseChatModel):
            model=model,
            choices=[
                Choice(
                    index=0,
                    message=message,
-                   finish_reason=finish_reason,
+                   finish_reason=self._convert_finish_reason(finish_reason),
                    logprobs=None,
                )
            ],
            usage=Usage(

@@ -427,6 +432,9 @@ class BedrockModel(BaseChatModel):
                total_tokens=input_tokens + output_tokens,
            ),
        )
+       response.system_fingerprint = "fp"
+       response.object = "chat.completion"
+       response.created = int(time.time())
        return response

    def _create_response_stream(

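Note (not part of the diff): with these two hunks the non-streaming response is filled out like a standard OpenAI chat.completion payload. Every value below is a placeholder; only the field names and the literal "fp" / "chat.completion" come from the code above:

# Sketch of the payload _create_response now produces; values are invented.
example_chat_completion = {
    "id": "chatcmpl-0000",                     # placeholder id
    "object": "chat.completion",               # set explicitly above
    "created": 1700000000,                     # int(time.time())
    "model": "anthropic.claude-3-sonnet",      # assumed model id, not from the diff
    "system_fingerprint": "fp",
    "choices": [
        {
            "index": 0,
            "message": {"role": "assistant", "content": "Hi there!"},
            "logprobs": None,
            "finish_reason": "stop",           # Bedrock "end_turn" after _convert_finish_reason
        }
    ],
    "usage": {"prompt_tokens": 10, "completion_tokens": 3, "total_tokens": 13},
}
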
@@ -455,6 +463,7 @@ class BedrockModel(BaseChatModel):
            message = ChatResponseMessage(
                tool_calls=[
                    ToolCall(
                        type="function",
                        id=delta["toolUse"]["toolUseId"],
                        function=ResponseFunction(
                            name=delta["toolUse"]["name"],

@@ -475,6 +484,7 @@ class BedrockModel(BaseChatModel):
            message = ChatResponseMessage(
                tool_calls=[
                    ToolCall(
                        type="function",
                        function=ResponseFunction(
                            arguments=delta["toolUse"]["input"],
                        )

@@ -509,7 +519,7 @@ class BedrockModel(BaseChatModel):
                    index=0,
                    delta=message,
                    logprobs=None,
-                   finish_reason=finish_reason,
+                   finish_reason=self._convert_finish_reason(finish_reason),
                )
            ],
            usage=usage,

@@ -615,6 +625,28 @@ class BedrockModel(BaseChatModel):
            }
        }

+   def _convert_finish_reason(self, finish_reason: str | None) -> str | None:
+       """
+       Below is a list of finish reasons according to the OpenAI docs:
+
+       - stop: if the model hit a natural stop point or a provided stop sequence,
+       - length: if the maximum number of tokens specified in the request was reached,
+       - content_filter: if content was omitted due to a flag from our content filters,
+       - tool_calls: if the model called a tool
+       """
+       if finish_reason:
+           finish_reason_mapping = {
+               "tool_use": "tool_calls",
+               "finished": "stop",
+               "end_turn": "stop",
+               "max_tokens": "length",
+               "stop_sequence": "stop",
+               "complete": "stop",
+               "content_filtered": "content_filter"
+           }
+           return finish_reason_mapping.get(finish_reason.lower(), finish_reason.lower())
+       return None


class BedrockEmbeddingsModel(BaseEmbeddingsModel, ABC):
    accept = "application/json"
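Note (not part of the diff): a standalone restatement of the new mapping, useful as a quick sanity check. The guardrail_intervened value is only an example of an unmapped Bedrock stop reason and does not appear in the commit:

# Minimal re-statement of _convert_finish_reason for illustration.
def convert_finish_reason(finish_reason):
    mapping = {
        "tool_use": "tool_calls",
        "finished": "stop",
        "end_turn": "stop",
        "max_tokens": "length",
        "stop_sequence": "stop",
        "complete": "stop",
        "content_filtered": "content_filter",
    }
    if finish_reason:
        return mapping.get(finish_reason.lower(), finish_reason.lower())
    return None

assert convert_finish_reason("end_turn") == "stop"        # natural stop
assert convert_finish_reason("MAX_TOKENS") == "length"    # lookup is case-insensitive
assert convert_finish_reason("tool_use") == "tool_calls"  # tool invocation
assert convert_finish_reason("guardrail_intervened") == "guardrail_intervened"  # unmapped values pass through lowercased
assert convert_finish_reason(None) is None                # no finish reason (e.g. mid-stream chunk)
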
@@ -15,7 +15,7 @@ router = APIRouter(
)


-@router.post("/completions", response_model=ChatResponse | ChatStreamResponse, response_model_exclude_none=True)
+@router.post("/completions", response_model=ChatResponse | ChatStreamResponse, response_model_exclude_unset=True)
async def chat_completions(
    chat_request: Annotated[
        ChatRequest,
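Note (not part of the diff): the decorator change swaps FastAPI's response_model_exclude_none for response_model_exclude_unset. Instead of dropping every field whose value is None, the route now only drops fields that were never assigned, so fields explicitly set to None survive serialization. A minimal FastAPI sketch of the difference, unrelated to the gateway's own models:

# Toy FastAPI app; the Item model and /item route are invented for illustration.
from fastapi import FastAPI
from pydantic import BaseModel

app = FastAPI()

class Item(BaseModel):
    name: str
    tax: float | None = None   # will be set explicitly (to None)
    note: str | None = None    # will never be set

@app.get("/item", response_model=Item, response_model_exclude_unset=True)
def read_item() -> Item:
    # Serialized as {"name": "hammer", "tax": null}: "note" is unset and dropped,
    # while "tax" was explicitly assigned None and is kept.
    # With response_model_exclude_none, "tax" would be dropped as well.
    return Item(name="hammer", tax=None)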