Update api response
This commit is contained in:
@@ -2,6 +2,7 @@ import base64
|
|||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
|
import time
|
||||||
from abc import ABC
|
from abc import ABC
|
||||||
from typing import AsyncIterable, Iterable, Literal
|
from typing import AsyncIterable, Iterable, Literal
|
||||||
|
|
||||||
@@ -403,12 +404,14 @@ class BedrockModel(BaseChatModel):
|
|||||||
message.tool_calls = [
|
message.tool_calls = [
|
||||||
ToolCall(
|
ToolCall(
|
||||||
id=tool["toolUseId"],
|
id=tool["toolUseId"],
|
||||||
|
type="function",
|
||||||
function=ResponseFunction(
|
function=ResponseFunction(
|
||||||
name=tool["name"],
|
name=tool["name"],
|
||||||
arguments=json.dumps(tool["input"]),
|
arguments=json.dumps(tool["input"]),
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
|
message.content = None
|
||||||
else:
|
else:
|
||||||
message.content = content[0]["text"]
|
message.content = content[0]["text"]
|
||||||
|
|
||||||
@@ -417,8 +420,10 @@ class BedrockModel(BaseChatModel):
|
|||||||
model=model,
|
model=model,
|
||||||
choices=[
|
choices=[
|
||||||
Choice(
|
Choice(
|
||||||
|
index=0,
|
||||||
message=message,
|
message=message,
|
||||||
finish_reason=finish_reason,
|
finish_reason=self._convert_finish_reason(finish_reason),
|
||||||
|
logprobs=None,
|
||||||
)
|
)
|
||||||
],
|
],
|
||||||
usage=Usage(
|
usage=Usage(
|
||||||
@@ -427,6 +432,9 @@ class BedrockModel(BaseChatModel):
|
|||||||
total_tokens=input_tokens + output_tokens,
|
total_tokens=input_tokens + output_tokens,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
response.system_fingerprint = "fp"
|
||||||
|
response.object = "chat.completion"
|
||||||
|
response.created = int(time.time())
|
||||||
return response
|
return response
|
||||||
|
|
||||||
def _create_response_stream(
|
def _create_response_stream(
|
||||||
@@ -455,6 +463,7 @@ class BedrockModel(BaseChatModel):
|
|||||||
message = ChatResponseMessage(
|
message = ChatResponseMessage(
|
||||||
tool_calls=[
|
tool_calls=[
|
||||||
ToolCall(
|
ToolCall(
|
||||||
|
type="function",
|
||||||
id=delta["toolUse"]["toolUseId"],
|
id=delta["toolUse"]["toolUseId"],
|
||||||
function=ResponseFunction(
|
function=ResponseFunction(
|
||||||
name=delta["toolUse"]["name"],
|
name=delta["toolUse"]["name"],
|
||||||
@@ -475,6 +484,7 @@ class BedrockModel(BaseChatModel):
|
|||||||
message = ChatResponseMessage(
|
message = ChatResponseMessage(
|
||||||
tool_calls=[
|
tool_calls=[
|
||||||
ToolCall(
|
ToolCall(
|
||||||
|
type="function",
|
||||||
function=ResponseFunction(
|
function=ResponseFunction(
|
||||||
arguments=delta["toolUse"]["input"],
|
arguments=delta["toolUse"]["input"],
|
||||||
)
|
)
|
||||||
@@ -509,7 +519,7 @@ class BedrockModel(BaseChatModel):
|
|||||||
index=0,
|
index=0,
|
||||||
delta=message,
|
delta=message,
|
||||||
logprobs=None,
|
logprobs=None,
|
||||||
finish_reason=finish_reason,
|
finish_reason=self._convert_finish_reason(finish_reason),
|
||||||
)
|
)
|
||||||
],
|
],
|
||||||
usage=usage,
|
usage=usage,
|
||||||
@@ -615,6 +625,28 @@ class BedrockModel(BaseChatModel):
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def _convert_finish_reason(self, finish_reason: str | None) -> str | None:
|
||||||
|
"""
|
||||||
|
Below is a list of finish reason according to OpenAI doc:
|
||||||
|
|
||||||
|
- stop: if the model hit a natural stop point or a provided stop sequence,
|
||||||
|
- length: if the maximum number of tokens specified in the request was reached,
|
||||||
|
- content_filter: if content was omitted due to a flag from our content filters,
|
||||||
|
- tool_calls: if the model called a tool
|
||||||
|
"""
|
||||||
|
if finish_reason:
|
||||||
|
finish_reason_mapping = {
|
||||||
|
"tool_use": "tool_calls",
|
||||||
|
"finished": "stop",
|
||||||
|
"end_turn": "stop",
|
||||||
|
"max_tokens": "length",
|
||||||
|
"stop_sequence": "stop",
|
||||||
|
"complete": "stop",
|
||||||
|
"content_filtered": "content_filter"
|
||||||
|
}
|
||||||
|
return finish_reason_mapping.get(finish_reason.lower(), finish_reason.lower())
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
class BedrockEmbeddingsModel(BaseEmbeddingsModel, ABC):
|
class BedrockEmbeddingsModel(BaseEmbeddingsModel, ABC):
|
||||||
accept = "application/json"
|
accept = "application/json"
|
||||||
|
|||||||
@@ -15,7 +15,7 @@ router = APIRouter(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@router.post("/completions", response_model=ChatResponse | ChatStreamResponse, response_model_exclude_none=True)
|
@router.post("/completions", response_model=ChatResponse | ChatStreamResponse, response_model_exclude_unset=True)
|
||||||
async def chat_completions(
|
async def chat_completions(
|
||||||
chat_request: Annotated[
|
chat_request: Annotated[
|
||||||
ChatRequest,
|
ChatRequest,
|
||||||
|
|||||||
Reference in New Issue
Block a user