Add Tool call support
@@ -1,6 +1,7 @@
 import base64
 import json
 import logging
+from abc import ABC
 from typing import AsyncIterable, Iterable
 
 import boto3
@@ -13,17 +14,21 @@ from api.schema import (
     # Chat
     ChatResponse,
     ChatRequest,
     ChatRequestMessage,
     Choice,
     ChatResponseMessage,
     Usage,
     ChatStreamResponse,
     ChoiceDelta,
     ImageContent,
+    TextContent,
+    ResponseFunction,
+    ToolCall,
+    Tool,
     # Embeddings
     EmbeddingsRequest,
     EmbeddingsResponse,
     EmbeddingsUsage,
-    Embedding, TextContent,
+    Embedding,
 )
 from api.setting import DEBUG, AWS_REGION
 
@@ -58,7 +63,7 @@ ENCODER = tiktoken.get_encoding("cl100k_base")
 
 
 # https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters.html
-class BedrockModel(BaseChatModel):
+class BedrockModel(BaseChatModel, ABC):
     accept = "application/json"
     content_type = "application/json"
 
@@ -86,17 +91,30 @@ class BedrockModel(BaseChatModel):
         model: str,
         message: str,
         message_id: str,
+        tools_message: str | None = None,
         input_tokens: int = 0,
         output_tokens: int = 0,
     ) -> ChatResponse:
-        choice = Choice(
-            index=0,
-            message=ChatResponseMessage(
-                role="assistant",
-                content=message,
-            ),
-            finish_reason="stop",
-        )
+        if tools_message:
+            # For tool response, the content is empty
+            tools = self._parse_tools_response(tools_message)
+            choice = Choice(
+                index=0,
+                message=ChatResponseMessage(
+                    role="assistant",
+                    tool_calls=tools,
+                ),
+                finish_reason="stop",
+            )
+        else:
+            choice = Choice(
+                index=0,
+                message=ChatResponseMessage(
+                    role="assistant",
+                    content=message,
+                ),
+                finish_reason="stop",
+            )
         response = ChatResponse(
             id=message_id,
             model=model,
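When `tools_message` is present, the choice mirrors the OpenAI tool-call shape: `tool_calls` is populated and `content` is left empty. As a hypothetical illustration (values invented, not from the commit), a serialized choice would look roughly like:

    {"index": 0, "finish_reason": "stop",
     "message": {"role": "assistant",
                 "tool_calls": [{"id": "0", "type": "function",
                                 "function": {"name": "get_current_weather",
                                              "arguments": "{\"location\": \"Shanghai\"}"}}]}}

Note that `finish_reason` stays "stop" in both branches rather than OpenAI's "tool_calls".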
@@ -132,25 +150,31 @@ class BedrockModel(BaseChatModel):
         return response
 
 
-def get_model(model_id: str) -> BedrockModel:
-    model_name = SUPPORTED_BEDROCK_MODELS.get(model_id, "")
-    if DEBUG:
-        logger.info("model name is " + model_name)
-    if model_name in ["Claude Instant", "Claude", "Claude 3 Sonnet", "Claude 3 Haiku"]:
-        return ClaudeModel()
-    elif model_name in ["Llama 2 Chat 13B", "Llama 2 Chat 70B"]:
-        return Llama2Model()
-    elif model_name in ["Mistral 7B Instruct", "Mixtral 8x7B Instruct"]:
-        return MistralModel()
-    else:
-        logger.error("Unsupported model id " + model_id)
-        raise ValueError("Invalid model ID")
-
-
 class ClaudeModel(BedrockModel):
     anthropic_version = "bedrock-2023-05-31"
 
-    def _get_base64_image(self, image_url: str):
+    def _parse_tools_response(self, tools_messages: str) -> list[ToolCall]:
+        """Parse the tools response.
+
+        Example tool message:
+        \n{\n "name": "get_current_weather",\n "arguments": {\n "location": "Shanghai"... }\n}\n
+        """
+        function = json.loads(
+            tools_messages.replace("\n", " ").encode("unicode_escape")
+        )
+
+        args = json.dumps(function.get("arguments", {}))
+        function = ResponseFunction(
+            name=function["name"], arguments=args.replace("\\\\n", "\\n")
+        )
+        return [
+            ToolCall(
+                id="0",
+                function=function,
+            )
+        ]
+
+    def _get_base64_image(self, image_url: str) -> str:
         # Send a request to the image URL
         response = requests.get(image_url)
         # Check if the request was successful
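As a rough sketch of the new parser in action (input string invented for illustration):

    raw = '\n{\n  "name": "get_current_weather",\n  "arguments": {"location": "Shanghai"}\n}\n'
    calls = ClaudeModel()._parse_tools_response(raw)
    # calls[0].function.name == "get_current_weather"
    # calls[0].function.arguments == '{"location": "Shanghai"}'
    # calls[0].id == "0" -- the id is hard-coded, since Claude returns no call id here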
@@ -159,34 +183,44 @@ class ClaudeModel(BedrockModel):
             image_content = response.content
             # Encode the image content as base64
             base64_image = base64.b64encode(image_content)
-            return base64_image.decode('utf-8')
+            return base64_image.decode("utf-8")
         else:
-            raise HTTPException(status_code=500, detail="Unable to access the image url")
+            raise HTTPException(
+                status_code=500, detail="Unable to access the image url"
+            )
 
-    def _parse_messages(self, messages: list[ChatRequestMessage]) -> list[dict]:
-        # Refer to: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html
-        converted_messages = []
-        for msg in messages:
-            if isinstance(msg.content, str):
-                converted_messages.append({"role": msg.role, "content": msg.content})
-                continue
-
-            content_parts = []
-            for part in msg.content:
-                if isinstance(part, TextContent):
-                    content_parts.append(part.model_dump())
-                else:
-                    content_parts.append({
+    def _parse_content_parts(
+        self, content: list[TextContent | ImageContent]
+    ) -> list[dict]:
+        # See: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html
+        content_parts = []
+        for part in content:
+            if isinstance(part, TextContent):
+                content_parts.append(part.model_dump())
+            else:
+                content_parts.append(
+                    {
                         "type": "image",
                         "source": {
                             "type": "base64",
                             "media_type": "image/jpeg",
-                            "data": self._get_base64_image(part.image_url.url)
-                        }
-                    })
+                            "data": self._get_base64_image(part.image_url.url),
+                        },
+                    }
+                )
+        return content_parts
 
-            converted_messages.append({"role": msg.role, "content": content_parts})
-        return converted_messages
+    def _create_tool_prompt(self, tools: list[Tool]) -> str:
+        tool_prompt = "\nYou have access to the following tools:\n"
+        tool_prompt += json.dumps(
+            [tool.function.model_dump() for tool in tools], indent=2
+        )
+        tool_prompt += (
+            "\nIf you need to use one of the above tools, "
+            "only respond with a JSON object matching the following schema inside a <tool></tool> xml tag: \n"
+            '{"name": $TOOL_NAME, "arguments": {"$PARAMETER_NAME": "$PARAMETER_VALUE", ...}}\n'
+        )
+        return tool_prompt
 
     def _parse_args(self, chat_request: ChatRequest) -> dict:
         args = {
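For a single hypothetical `get_current_weather` function, the suffix appended to the system prompt would read roughly (name, description, and parameters invented for illustration):

    You have access to the following tools:
    [
      {
        "name": "get_current_weather",
        "description": "Get the current weather for a location",
        "parameters": { ... }
      }
    ]
    If you need to use one of the above tools, only respond with a JSON object matching the following schema inside a <tool></tool> xml tag:
    {"name": $TOOL_NAME, "arguments": {"$PARAMETER_NAME": "$PARAMETER_VALUE", ...}}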
@@ -195,25 +229,73 @@ class ClaudeModel(BedrockModel):
             "top_p": chat_request.top_p,
             "temperature": chat_request.temperature,
         }
-        start = 0
-        if chat_request.messages[0].role == "system":
-            args["system"] = chat_request.messages[0].content
-            start = 1
-        args["messages"] = self._parse_messages(chat_request.messages[start:])
+        system_prompt = ""
+        converted_messages = []
+        for message in chat_request.messages:
+            if message.role == "system":
+                system_prompt += message.content + "\n"
+            elif message.role == "user" and not isinstance(message.content, str):
+                converted_messages.append(
+                    {
+                        "role": message.role,
+                        "content": self._parse_content_parts(message.content),
+                    }
+                )
+            elif message.role == "assistant" and not message.content:
+                # If the content is empty,
+                # create it from the tool call info.
+                tool_content = "Should use {} tool with args: {}".format(
+                    message.tool_calls[0].function.name,
+                    message.tool_calls[0].function.arguments,
+                )
+                converted_messages.append(
+                    {"role": message.role, "content": tool_content}
+                )
+            elif message.role == "tool":
+                # Since Bedrock does not support the tool role,
+                # convert the tool message to a user message.
+                converted_messages.append(
+                    {
+                        "role": "user",
+                        "content": "The result of the tool call is " + message.content,
+                    }
+                )
+            else:
+                converted_messages.append(
+                    {"role": message.role, "content": message.content}
+                )
+
+        if chat_request.tools:
+            system_prompt += self._create_tool_prompt(chat_request.tools)
+
+        args["messages"] = converted_messages
+        if system_prompt:
+            if DEBUG:
+                logger.info("System Prompt: " + system_prompt)
+            args["system"] = system_prompt.replace("\n", "")
         return args
 
     def chat(self, chat_request: ChatRequest) -> ChatResponse:
         if DEBUG:
             logger.info("Raw request: " + chat_request.model_dump_json())
         response = self._invoke_model(
             args=self._parse_args(chat_request), model_id=chat_request.model
         )
         response_body = json.loads(response.get("body").read())
         if DEBUG:
             logger.info("Bedrock response body: " + str(response_body))
+        message = response_body["content"][0]["text"]
+
+        tools_message = None
+        start = message.find("<tool>")
+        end = message.find("</tool>")
+        if start != -1 and end != -1:
+            tools_message = message[start + 6: end]
         return self._create_response(
             model=chat_request.model,
             message=response_body["content"][0]["text"],
             message_id=response_body["id"],
+            tools_message=tools_message,
             input_tokens=response_body["usage"]["input_tokens"],
             output_tokens=response_body["usage"]["output_tokens"],
         )
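Putting it together, a hypothetical Claude reply that uses a tool would be handled like this (response text invented):

    message = 'Let me check. <tool>{"name": "get_current_weather", "arguments": {"location": "Shanghai"}}</tool>'
    start = message.find("<tool>")   # found
    end = message.find("</tool>")
    tools_message = message[start + 6: end]
    # tools_message == '{"name": "get_current_weather", "arguments": {"location": "Shanghai"}}'
    # _create_response() then emits a choice with tool_calls instead of content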
@@ -256,7 +338,7 @@ class ClaudeModel(BedrockModel):
 
 class Llama2Model(BedrockModel):
 
-    def _convert_prompt(self, messages: list[ChatRequestMessage]) -> str:
+    def _convert_prompt(self, chat_request: ChatRequest) -> str:
         """Create a prompt message following the example below:
 
         <s>[INST] <<SYS>>\n{your_system_message}\n<</SYS>>\n\n{user_message_1} [/INST] {model_reply_1}</s>
@@ -264,21 +346,26 @@ class Llama2Model(BedrockModel):
         """
         if DEBUG:
             logger.info("Convert below messages to prompt for Llama 2: ")
-            for msg in messages:
+            for msg in chat_request.messages:
                 logger.info(msg.model_dump_json())
         bos_token = "<s>"
         eos_token = "</s>"
-        prompt = bos_token + "[INST] "
-        start = 0
+        prompt = ""
         end_turn = False
-        if messages[0].role == "system":
-            prompt += "<<SYS>>\n" + messages[0].content + "\n<<SYS>>\n\n"
-            start = 1
-        # TODO: Add validation
-        for i in range(start, len(messages)):
-            msg = messages[i]
+        system_prompt = ""
+        for msg in chat_request.messages:
+            if msg.role == "system":
+                system_prompt += "\n" + msg.content + "\n"
+                continue
+            if msg.role == "tool":
+                raise HTTPException(
+                    status_code=500,
+                    detail="Tool prompt is not supported for Llama 2 model",
+                )
             if not isinstance(msg.content, str):
-                raise HTTPException(status_code=400, detail="Content must be a string for Llama 2 model")
+                raise HTTPException(
+                    status_code=400, detail="Content must be a string for Llama 2 model"
+                )
             if msg.role == "user":
                 if end_turn:
                     prompt += bos_token + "[INST] "
@@ -287,12 +374,16 @@ class Llama2Model(BedrockModel):
             else:
                 prompt += msg.content + eos_token
                 end_turn = True
+
+        if system_prompt:
+            system_prompt = "<<SYS>>" + system_prompt + "<</SYS>>"
+        prompt = bos_token + "[INST] " + system_prompt + prompt
         if DEBUG:
             logger.info("Converted prompt: " + prompt.replace("\n", "\\n"))
         return prompt
 
     def _parse_args(self, chat_request: ChatRequest) -> dict:
-        prompt = self._convert_prompt(chat_request.messages)
+        prompt = self._convert_prompt(chat_request)
         return {
             "prompt": prompt,
             "max_gen_len": chat_request.max_tokens,
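Under the new logic, a hypothetical chat (system prompt, one completed user/assistant turn, one pending user turn) would serialize roughly as:

    <s>[INST] <<SYS>>
    You are a helpful assistant.
    <</SYS>>
    Hi there [/INST] Hello! How can I help?</s><s>[INST] What's the weather? [/INST]

(message texts invented; the system block is now prepended once, after the whole history is accumulated).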
@@ -340,29 +431,36 @@ class Llama2Model(BedrockModel):
 
 
 class MistralModel(BedrockModel):
-    def _convert_prompt(self, messages: list[ChatRequestMessage]) -> str:
+    def _convert_prompt(self, chat_request: ChatRequest) -> str:
         """Create a prompt message following the example below:
 
         <s>[INST] {your_system_message}\n{user_message_1} [/INST] {model_reply_1}</s>
         <s>[INST] {user_message_2} [/INST]
         """
+        # TODO: maybe reuse the Llama 2 one.
         if DEBUG:
-            logger.info("Convert below messages to prompt for Llama 2: ")
-            for msg in messages:
+            logger.info("Convert below messages to prompt for Mistral/Mixtral model: ")
+            for msg in chat_request.messages:
                 logger.info(msg.model_dump_json())
         bos_token = "<s>"
         eos_token = "</s>"
-        prompt = bos_token + "[INST] "
-        start = 0
+        prompt = ""
         end_turn = False
-        if messages[0].role == "system":
-            prompt += messages[0].content + "\n"
-            start = 1
-        # TODO: Add validation
-        for i in range(start, len(messages)):
-            msg = messages[i]
+        system_prompt = ""
+        for msg in chat_request.messages:
+            if msg.role == "system":
+                system_prompt += "\n" + msg.content + "\n"
+                continue
+            if msg.role == "tool":
+                raise HTTPException(
+                    status_code=500,
+                    detail="Tool prompt is not supported for Mistral/Mixtral model",
+                )
             if not isinstance(msg.content, str):
-                raise HTTPException(status_code=400, detail="Content must be a string for Mistral/Mixtral model")
+                raise HTTPException(
+                    status_code=400,
+                    detail="Content must be a string for Mistral/Mixtral model",
+                )
             if msg.role == "user":
                 if end_turn:
                     prompt += bos_token + "[INST] "
@@ -371,12 +469,14 @@ class MistralModel(BedrockModel):
             else:
                 prompt += msg.content + eos_token
                 end_turn = True
+
+        prompt = bos_token + "[INST] " + system_prompt + prompt
         if DEBUG:
             logger.info("Converted prompt: " + prompt.replace("\n", "\\n"))
         return prompt
 
     def _parse_args(self, chat_request: ChatRequest) -> dict:
-        prompt = self._convert_prompt(chat_request.messages)
+        prompt = self._convert_prompt(chat_request)
         return {
             "prompt": prompt,
             "max_tokens": chat_request.max_tokens,
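The Mistral/Mixtral variant follows the same flow without the <<SYS>> markers, so the hypothetical chat above would become roughly:

    <s>[INST] You are a helpful assistant.
    Hi there [/INST] Hello! How can I help?</s><s>[INST] What's the weather? [/INST]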
@@ -422,7 +522,7 @@ class MistralModel(BedrockModel):
         yield self._stream_response_to_bytes(response)
 
 
-class BedrockEmbeddingsModel(BaseEmbeddingsModel):
+class BedrockEmbeddingsModel(BaseEmbeddingsModel, ABC):
     accept = "application/json"
     content_type = "application/json"
 
@@ -446,10 +546,8 @@ class BedrockEmbeddingsModel(BaseEmbeddingsModel):
         output_tokens: int = 0,
     ) -> EmbeddingsResponse:
         data = [
-            Embedding(
-                index=i,
-                embedding=embedding
-            ) for i, embedding in enumerate(embeddings)
+            Embedding(index=i, embedding=embedding)
+            for i, embedding in enumerate(embeddings)
         ]
         response = EmbeddingsResponse(
             data=data,
@@ -465,19 +563,6 @@ class BedrockEmbeddingsModel(BaseEmbeddingsModel):
         return response
 
 
-def get_embeddings_model(model_id: str) -> BedrockEmbeddingsModel:
-    model_name = SUPPORTED_BEDROCK_EMBEDDING_MODELS.get(model_id, "")
-    if DEBUG:
-        logger.info("model name is " + model_name)
-    if model_name in ["Cohere Embed Multilingual", "Cohere Embed English"]:
-        return CohereEmbeddingsModel()
-    elif model_name in ["Titan Embeddings G1 - Text", "Titan Multimodal Embeddings G1"]:
-        return TitanEmbeddingsModel()
-    else:
-        logger.error("Unsupported model id " + model_id)
-        raise ValueError("Invalid model ID")
-
-
 class CohereEmbeddingsModel(BedrockEmbeddingsModel):
 
     def _parse_args(self, embeddings_request: EmbeddingsRequest) -> dict:
@@ -528,17 +613,25 @@ class TitanEmbeddingsModel(BedrockEmbeddingsModel):
     def _parse_args(self, embeddings_request: EmbeddingsRequest) -> dict:
         if isinstance(embeddings_request.input, str):
             input_text = embeddings_request.input
-        elif isinstance(embeddings_request.input, list) and len(embeddings_request.input) == 1:
+        elif (
+            isinstance(embeddings_request.input, list)
+            and len(embeddings_request.input) == 1
+        ):
             input_text = embeddings_request.input[0]
         else:
-            raise ValueError("Amazon Titan Embeddings models support only single strings as input.")
+            raise ValueError(
+                "Amazon Titan Embeddings models support only single strings as input."
+            )
         args = {
             "inputText": input_text,
             # Note: inputImage is not supported!
         }
         if embeddings_request.model == "amazon.titan-embed-image-v1":
-            args["embeddingConfig"] = embeddings_request.embedding_config if embeddings_request.embedding_config else {
-                "outputEmbeddingLength": 1024}
+            args["embeddingConfig"] = (
+                embeddings_request.embedding_config
+                if embeddings_request.embedding_config
+                else {"outputEmbeddingLength": 1024}
+            )
         return args
 
     def embed(self, embeddings_request: EmbeddingsRequest) -> EmbeddingsResponse:
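For example (input text invented), a request against amazon.titan-embed-image-v1 with no embedding_config would produce the Bedrock arguments:

    {"inputText": "hello world", "embeddingConfig": {"outputEmbeddingLength": 1024}}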
@@ -552,5 +645,39 @@ class TitanEmbeddingsModel(BedrockEmbeddingsModel):
         return self._create_response(
             embeddings=[response_body["embedding"]],
             model=embeddings_request.model,
-            input_tokens=response_body["inputTextTokenCount"]
+            input_tokens=response_body["inputTextTokenCount"],
         )
+
+
+def get_model(model_id: str) -> BedrockModel:
+    model_name = SUPPORTED_BEDROCK_MODELS.get(model_id, "")
+    if DEBUG:
+        logger.info("model name is " + model_name)
+    if model_name in ["Claude Instant", "Claude", "Claude 3 Sonnet", "Claude 3 Haiku"]:
+        return ClaudeModel()
+    elif model_name in ["Llama 2 Chat 13B", "Llama 2 Chat 70B"]:
+        return Llama2Model()
+    elif model_name in ["Mistral 7B Instruct", "Mixtral 8x7B Instruct"]:
+        return MistralModel()
+    else:
+        logger.error("Unsupported model id " + model_id)
+        raise HTTPException(
+            status_code=500,
+            detail="Unsupported model id " + model_id,
+        )
+
+
+def get_embeddings_model(model_id: str) -> BedrockEmbeddingsModel:
+    model_name = SUPPORTED_BEDROCK_EMBEDDING_MODELS.get(model_id, "")
+    if DEBUG:
+        logger.info("model name is " + model_name)
+    if model_name in ["Cohere Embed Multilingual", "Cohere Embed English"]:
+        return CohereEmbeddingsModel()
+    elif model_name in ["Titan Embeddings G1 - Text", "Titan Multimodal Embeddings G1"]:
+        return TitanEmbeddingsModel()
+    else:
+        logger.error("Unsupported model id " + model_id)
+        raise HTTPException(
+            status_code=500,
+            detail="Unsupported model id " + model_id,
+        )
@@ -16,6 +16,17 @@ class Models(BaseModel):
     data: list[Model] = []
 
 
+class ResponseFunction(BaseModel):
+    name: str
+    arguments: str
+
+
+class ToolCall(BaseModel):
+    id: str
+    type: Literal["function"] = "function"
+    function: ResponseFunction
+
+
 class TextContent(BaseModel):
     type: Literal["text"] = "text"
     text: str
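A minimal sketch of these models in use (values invented):

    call = ToolCall(
        id="0",
        function=ResponseFunction(
            name="get_current_weather",
            arguments='{"location": "Shanghai"}',
        ),
    )
    # call.type defaults to "function", matching the OpenAI tool-call shape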
@@ -31,12 +42,31 @@ class ImageContent(BaseModel):
     image_url: ImageUrl
 
 
-class ChatRequestMessage(BaseModel):
+class SystemMessage(BaseModel):
     name: str | None = None
-    role: Literal["user", "assistant", "system"]
+    role: Literal["system"] = "system"
+    content: str
+
+
+class UserMessage(BaseModel):
+    name: str | None = None
+    role: Literal["user"] = "user"
     content: str | list[TextContent | ImageContent]
+
+
+class AssistantMessage(BaseModel):
+    name: str | None = None
+    role: Literal["assistant"] = "assistant"
+    content: str | None
+    tool_calls: list[ToolCall] | None = None
+
+
+class ToolMessage(BaseModel):
+    role: Literal["tool"] = "tool"
+    content: str
+    tool_call_id: str
 
 
 class Function(BaseModel):
     name: str
     description: str | None = None
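A hypothetical tool round-trip expressed with the four new message types (contents invented):

    messages = [
        SystemMessage(content="You are a helpful assistant."),
        UserMessage(content="What is the weather in Shanghai?"),
        AssistantMessage(              # the model requested a tool call
            content=None,
            tool_calls=[ToolCall(id="0", function=ResponseFunction(
                name="get_current_weather",
                arguments='{"location": "Shanghai"}',
            ))],
        ),
        ToolMessage(content="Sunny, 25 degrees", tool_call_id="0"),
    ]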
@@ -49,7 +79,7 @@ class Tool(BaseModel):
 
 
 class ChatRequest(BaseModel):
-    messages: list[ChatRequestMessage]
+    messages: list[SystemMessage | UserMessage | AssistantMessage | ToolMessage]
     model: str
     frequency_penalty: float | None = Field(default=0.0, le=2.0, ge=-2.0)  # Not used
     presence_penalty: float | None = Field(default=0.0, le=2.0, ge=-2.0)  # Not used
@@ -69,17 +99,6 @@ class Usage(BaseModel):
     total_tokens: int
 
 
-class ResponseFunction(BaseModel):
-    name: str
-    arguments: str
-
-
-class ToolCall(BaseModel):
-    id: str
-    type: Literal["function"] = "function"
-    function: ResponseFunction
-
-
 class ChatResponseMessage(BaseModel):
     # tool_calls
     role: Literal["assistant"] | None = None
@@ -1,4 +1,6 @@
 fastapi==0.110.0
 pydantic==2.6.3
 uvicorn==0.27.0.post1
-mangum==0.17.0
+mangum==0.17.0
+tiktoken==0.6.0
+requests==2.31.0