@@ -2,7 +2,7 @@ import base64
import json
import logging
import re
from abc import ABC
from abc import ABC, abstractmethod
from typing import AsyncIterable, Iterable, Literal

import boto3
@@ -54,6 +54,8 @@ SUPPORTED_BEDROCK_MODELS = {
    "mistral.mistral-7b-instruct-v0:2": "Mistral 7B Instruct",
    "mistral.mixtral-8x7b-instruct-v0:1": "Mixtral 8x7B Instruct",
    "mistral.mistral-large-2402-v1:0": "Mistral Large",
    "cohere.command-r-v1:0": "Command R",
    "cohere.command-r-plus-v1:0": "Command R+",
}

SUPPORTED_BEDROCK_EMBEDDING_MODELS = {
@@ -72,28 +74,157 @@ class BedrockModel(BaseChatModel, ABC):
    accept = "application/json"
    content_type = "application/json"

    def _invoke_model(self, args: dict, model_id: str, with_stream: bool = False):
        body = json.dumps(args)
    # Default field name to get the response message
    text_field_name = "text"

    # Default field name to get the response finish reason
    finish_reason_field_name = "finish_reason"

    @abstractmethod
    def compose_request_body(self, chat_request: ChatRequest) -> str:
        """Since the request body to Bedrock varies,
        each model should implement this to compose the request body.

        :param chat_request:
        :return: request body as a string
        """
        raise NotImplementedError()

    def chat(self, chat_request: ChatRequest) -> ChatResponse:
        """Default implementation for Chat API."""
        if DEBUG:
            logger.info("Raw request: " + chat_request.model_dump_json())
        request_body = self.compose_request_body(chat_request)

        if DEBUG:
            logger.info("Bedrock request: " + request_body)

        response = self.invoke_model(
            request_body=request_body,
            model_id=chat_request.model,
        )
        message_id = self.generate_message_id()
        return self.parse_response(chat_request, response, message_id)

    def chat_stream(self, chat_request: ChatRequest) -> AsyncIterable[bytes]:
        """Default implementation for Chat Stream API"""
        if DEBUG:
            logger.info("Raw request: " + chat_request.model_dump_json())
        request_body = self.compose_request_body(chat_request)
        response = self.invoke_model(
            request_body=request_body,
            model_id=chat_request.model,
            with_stream=True,
        )

        message_id = self.generate_message_id()
        for stream_response in self.parse_stream_response(
            chat_request, response, message_id
        ):
            if stream_response.choices:
                yield self.stream_response_to_bytes(stream_response)
            elif (
                chat_request.stream_options
                and chat_request.stream_options.include_usage
            ):
                # An empty choices for Usage as per OpenAI doc below:
                # if you set stream_options: {"include_usage": true},
                # an additional chunk will be streamed before the data: [DONE] message.
                # The usage field on this chunk shows the token usage statistics for the entire request,
                # and the choices field will always be an empty array.
                # All other chunks will also include a usage field, but with a null value.
                yield self.stream_response_to_bytes(stream_response)
        # return a [DONE] message at the end.
        yield self.stream_response_to_bytes()

    def get_message_text(self, response_body: dict) -> str | None:
        """Default func to get the response message.

        Ideally, only the field name should be changed."""
        return response_body.get(self.text_field_name)

    def get_message_finish_reason(self, response_body: dict) -> str | None:
        """Default func to get the finish reason.

        Ideally, only the field name should be changed."""
        return response_body.get(self.finish_reason_field_name)

    def get_message_usage(self, response_body: dict) -> tuple[int, int]:
        """Default func to get the usage.

        Can be overridden in the detail model for complex cases."""
        input_tokens = int(response_body.get("prompt_token_count", "0"))
        output_tokens = int(response_body.get("generation_token_count", "0"))
        return input_tokens, output_tokens

    def parse_response(
        self, chat_request: ChatRequest, service_response: dict, message_id: str
    ) -> ChatResponse:
        response_body = json.loads(service_response.get("body").read())
        if DEBUG:
            logger.info("Bedrock response body: " + str(response_body))

        input_tokens, output_tokens = self.get_message_usage(response_body)
        return self.create_response(
            model=chat_request.model,
            message_id=message_id,
            message=self.get_message_text(response_body),
            finish_reason=self.get_message_finish_reason(response_body),
            input_tokens=input_tokens,
            output_tokens=output_tokens,
        )

    def parse_stream_response(
        self, chat_request: ChatRequest, service_response: dict, message_id: str
    ) -> Iterable[ChatStreamResponse]:

        chunk_id = 0
        for event in service_response.get("body"):
            if DEBUG:
                logger.info("Bedrock response chunk: " + str(event))
            chunk = json.loads(event["chunk"]["bytes"])
            chunk_id += 1

            response = self.create_response_stream(
                model=chat_request.model,
                message_id=message_id,
                chunk_message=self.get_message_text(chunk),
                finish_reason=self.get_message_finish_reason(chunk),
            )
            yield response
            # Get the usage for streaming response anyway.
            if "amazon-bedrock-invocationMetrics" in chunk:
                yield self.create_response_stream(
                    model=chat_request.model,
                    message_id=message_id,
                    input_tokens=chunk["amazon-bedrock-invocationMetrics"][
                        "inputTokenCount"
                    ],
                    output_tokens=chunk["amazon-bedrock-invocationMetrics"][
                        "outputTokenCount"
                    ],
                )

    def invoke_model(self, request_body: str, model_id: str, with_stream: bool = False):
        if DEBUG:
            logger.info("Invoke Bedrock Model: " + model_id)
            logger.info("Bedrock request body: " + body)
            logger.info("Bedrock request body: " + request_body)
        try:
            if with_stream:
                return bedrock_runtime.invoke_model_with_response_stream(
                    body=body,
                    body=request_body,
                    modelId=model_id,
                    accept=self.accept,
                    contentType=self.content_type,
                )
            return bedrock_runtime.invoke_model(
                body=body,
                body=request_body,
                modelId=model_id,
                accept=self.accept,
                contentType=self.content_type,
            )
        except bedrock_runtime.exceptions.ValidationException as e:
            print("Validation Exception")
            print(e)
            logger.error("Validation Error: " + str(e))
            raise HTTPException(status_code=400, detail=str(e))
        except Exception as e:
            logger.error(e)
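The hunk above turns BedrockModel into a template: subclasses supply compose_request_body (and, where the defaults do not fit, override text_field_name, finish_reason_field_name, or the get_message_* helpers), while chat, chat_stream, parse_response, and parse_stream_response are inherited. A minimal sketch of a hypothetical adapter under that contract; ExampleModel and its request/response field names are illustrative assumptions, not part of this change, and the sketch relies on the module's existing imports (json, ChatRequest):

    # Hypothetical adapter sketch -- not part of the diff.
    class ExampleModel(BedrockModel):
        # Field names the inherited get_message_* helpers will read from the
        # provider's response body (assumed names, for illustration only).
        text_field_name = "output_text"
        finish_reason_field_name = "stop_reason"

        def compose_request_body(self, chat_request: ChatRequest) -> str:
            # Build the provider-specific payload and return it as a JSON
            # string, as the new abstract method requires.
            args = {
                "prompt": chat_request.messages[-1].content,
                "max_tokens": chat_request.max_tokens,
                "temperature": chat_request.temperature,
            }
            return json.dumps(args)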
@@ -101,8 +232,7 @@ class BedrockModel(BaseChatModel, ABC):

    @staticmethod
    def merge_message(messages: list[dict]) -> list[dict]:
        """Merge the request messages with the same role as previous message
        """
        """Merge the request messages with the same role as previous message"""
        merged_messages = []
        prev_role = None
        merged_content = ""
@@ -110,10 +240,11 @@ class BedrockModel(BaseChatModel, ABC):
        for message in messages:
            role = message["role"]
            content = message["content"]

            if role != prev_role or isinstance(content, list):
                if prev_role:
                    merged_messages.append({"role": prev_role, "content": merged_content})
                    merged_messages.append(
                        {"role": prev_role, "content": merged_content}
                    )
                if isinstance(content, str):
                    merged_content = content
                prev_role = role
@@ -132,7 +263,7 @@ class BedrockModel(BaseChatModel, ABC):
        return merged_messages

    @staticmethod
    def _create_response(
    def create_response(
        model: str,
        message_id: str,
        message: str | None = None,
@@ -144,15 +275,17 @@ class BedrockModel(BaseChatModel, ABC):
        response = ChatResponse(
            id=message_id,
            model=model,
            choices=[Choice(
                index=0,
                message=ChatResponseMessage(
                    role="assistant",
                    tool_calls=tools,
                    content=message,
                ),
                finish_reason="tool_calls" if tools else finish_reason,
            )],
            choices=[
                Choice(
                    index=0,
                    message=ChatResponseMessage(
                        role="assistant",
                        tool_calls=tools,
                        content=message,
                    ),
                    finish_reason="tool_calls" if tools else finish_reason,
                )
            ],
            usage=Usage(
                prompt_tokens=input_tokens,
                completion_tokens=output_tokens,
@@ -164,26 +297,42 @@ class BedrockModel(BaseChatModel, ABC):
        return response

    @staticmethod
    def _create_response_stream(
    def create_response_stream(
        model: str,
        message_id: str,
        chunk_message: str | None = None,
        finish_reason: str | None = None,
        tools: list[ToolCall] | None = None,
        input_tokens: int = 0,
        output_tokens: int = 0,
    ) -> ChatStreamResponse:
        response = ChatStreamResponse(
            id=message_id,
            model=model,
            choices=[ChoiceDelta(
                index=0,
                delta=ChatResponseMessage(
                    role="assistant",
                    tool_calls=tools,
                    content=chunk_message,
        if chunk_message or finish_reason or tools:
            response = ChatStreamResponse(
                id=message_id,
                model=model,
                choices=[
                    ChoiceDelta(
                        index=0,
                        delta=ChatResponseMessage(
                            role="assistant",
                            tool_calls=tools,
                            content=chunk_message,
                        ),
                        finish_reason=finish_reason,
                    )
                ],
            )
        else:
            response = ChatStreamResponse(
                id=message_id,
                model=model,
                choices=[],
                usage=Usage(
                    prompt_tokens=input_tokens,
                    completion_tokens=output_tokens,
                    total_tokens=input_tokens + output_tokens,
                ),
                finish_reason="tool_calls" if tools else finish_reason,
            )],
        )
            )
        if DEBUG:
            logger.info("Proxy response :" + response.model_dump_json())
        return response
@@ -201,7 +350,7 @@ Please think if you need to use a tool or not for user's question, you must:
{{"name": $TOOL_NAME, "arguments": {{"$PARAMETER_NAME": "$PARAMETER_VALUE", ...}}}}
3. If no tools is needed, respond with normal text."""

    def _parse_args(self, chat_request: ChatRequest) -> dict:
    def compose_request_body(self, chat_request: ChatRequest) -> str:
        args = {
            "anthropic_version": self.anthropic_version,
            "max_tokens": chat_request.max_tokens,
@@ -239,8 +388,8 @@ Please think if you need to use a tool or not for user's question, you must:
                    {
                        "role": "user",
                        "content": "[Tool result with matching id `{}` of `{}`] ".format(
                            message.tool_call_id,
                            message.content),
                            message.tool_call_id, message.content
                        ),
                    }
                )
            else:
@@ -253,74 +402,86 @@ Please think if you need to use a tool or not for user's question, you must:
                [tool.function.model_dump() for tool in chat_request.tools]
            )
            system_prompt += self.tool_prompt.format(tools=tools_str)
            converted_messages.append({
                'role': 'assistant',
                'content': '<tool>'
            })
            args["stop_sequences"] = ['</function>']
            converted_messages.append({"role": "assistant", "content": "<tool>"})
            args["stop_sequences"] = ["</function>"]
        args["messages"] = self.merge_message(converted_messages)
        if system_prompt:
            if DEBUG:
                logger.info("System Prompt: " + system_prompt)
            args["system"] = system_prompt

        return args
        return json.dumps(args)

    def chat(self, chat_request: ChatRequest) -> ChatResponse:
        if DEBUG:
            logger.info("Raw request: " + chat_request.model_dump_json())
        response = self._invoke_model(
            args=self._parse_args(chat_request), model_id=chat_request.model
        )
        response_body = json.loads(response.get("body").read())
    def parse_response(
        self, chat_request: ChatRequest, service_response: dict, message_id: str
    ) -> ChatResponse:
        response_body = json.loads(service_response.get("body").read())
        if DEBUG:
            logger.info("Bedrock response body: " + str(response_body))
        message = response_body["content"][0]["text"]

        finish_reason = response_body["stop_reason"]
        tools = None
        if chat_request.tools:
            if message.startswith("Y</tool>"):
                tools = self._parse_tool_message(message)
                message = None
                finish_reason = "tool_calls"
            elif message.startswith("N</tool>"):
                message = message[8:].lstrip("\n")
        return self._create_response(
        return self.create_response(
            model=chat_request.model,
            message_id=response_body["id"],
            message_id=message_id,
            message=message,
            tools=tools,
            finish_reason=response_body["stop_reason"],
            finish_reason=finish_reason,
            input_tokens=response_body["usage"]["input_tokens"],
            output_tokens=response_body["usage"]["output_tokens"],
        )

    def chat_stream(self, chat_request: ChatRequest) -> AsyncIterable[bytes]:
        if DEBUG:
            logger.info("Raw request: " + chat_request.model_dump_json())
        response = self._invoke_model(
            args=self._parse_args(chat_request),
            model_id=chat_request.model,
            with_stream=True,
        )
        msg_id = ""
    def parse_stream_response(
        self, chat_request: ChatRequest, service_response: dict, message_id: str
    ) -> Iterable[ChatStreamResponse]:

        chunk_id = 0
        tool_message = ""
        first_token = True
        index = 0
        for event in response.get("body"):
        for event in service_response.get("body"):
            if DEBUG:
                logger.info("Bedrock response chunk: " + str(event))
            chunk = json.loads(event["chunk"]["bytes"])
            chunk_id += 1

            if chunk["type"] == "message_start":
                msg_id = chunk["message"]["id"]
                continue
            if chunk["type"] == "message_stop":
                # Get the usage for streaming response anyway.
                if "amazon-bedrock-invocationMetrics" in chunk:
                    yield self.create_response_stream(
                        model=chat_request.model,
                        message_id=message_id,
                        input_tokens=chunk["amazon-bedrock-invocationMetrics"][
                            "inputTokenCount"
                        ],
                        output_tokens=chunk["amazon-bedrock-invocationMetrics"][
                            "outputTokenCount"
                        ],
                    )
                break

            if chunk["type"] == "message_delta":
            elif chunk["type"] == "message_delta":
                chunk_message = ""
                finish_reason = chunk["delta"]["stop_reason"]

                # Send tool message first if any.
                if chat_request.tools and tool_message:
                    tools = self._parse_tool_message(tool_message)
                    finish_reason = "tool_calls"
                    response = self.create_response_stream(
                        model=chat_request.model,
                        message_id=message_id,
                        tools=tools,
                    )
                    yield response

            elif chunk["type"] == "content_block_delta":
                chunk_message = chunk["delta"]["text"]
                finish_reason = None
@@ -343,22 +504,14 @@ Please think if you need to use a tool or not for user's question, you must:
                    first_token = False
            else:
                continue
            response = self._create_response_stream(
            response = self.create_response_stream(
                model=chat_request.model,
                message_id=msg_id,
                message_id=message_id,
                chunk_message=chunk_message,
                finish_reason=finish_reason,
            )

            yield self._stream_response_to_bytes(response)
        if chat_request.tools and tool_message:
            tools = self._parse_tool_message(tool_message)
            response = self._create_response_stream(
                model=chat_request.model,
                message_id=msg_id,
                tools=tools,
            )
            yield self._stream_response_to_bytes(response)
            yield response

    def _parse_tool_message(self, tool_message: str) -> list[ToolCall]:
        if DEBUG:
@@ -367,9 +520,7 @@ Please think if you need to use a tool or not for user's question, you must:
        tool_messages = tool_message[tool_message.rindex("<function>") + len("<function>"):]
        function = json.loads(tool_messages.replace("\n", " "))
        args = json.dumps(function.get("arguments", {}))
        function = ResponseFunction(
            name=function["name"], arguments=args
        )
        function = ResponseFunction(name=function["name"], arguments=args)

        return [
            ToolCall(
@@ -400,7 +551,7 @@ Please think if you need to use a tool or not for user's question, you must:
        # Check if the request was successful
        if response.status_code == 200:

            content_type = response.headers.get('Content-Type')
            content_type = response.headers.get("Content-Type")
            if not content_type.startswith("image"):
                content_type = "image/jpeg"
            # Get the image content
@@ -437,6 +588,8 @@ Please think if you need to use a tool or not for user's question, you must:


class LlamaModel(BedrockModel):
    text_field_name = "generation"
    finish_reason_field_name = "stop_reason"

    @staticmethod
    def create_llama3_prompt(chat_request: ChatRequest) -> str:
@@ -460,7 +613,9 @@ class LlamaModel(BedrockModel):

        prompt_lines = []
        for msg in chat_request.messages:
            prompt_lines.append(f"<|start_header_id|>{msg.role}<|end_header_id|>\n\n{msg.content}<|eot_id|>")
            prompt_lines.append(
                f"<|start_header_id|>{msg.role}<|end_header_id|>\n\n{msg.content}<|eot_id|>"
            )
        prompt_lines.append(f"<|start_header_id|>assistant<|end_header_id|>\n\n")
        prompt = bos_token + "".join(prompt_lines)
        if DEBUG:
@@ -512,59 +667,24 @@ class LlamaModel(BedrockModel):
            logger.info("Converted prompt: " + prompt.replace("\n", "\\n"))
        return prompt

    def _parse_args(self, chat_request: ChatRequest) -> dict:
    def compose_request_body(self, chat_request: ChatRequest) -> str:
        if chat_request.model.startswith("meta.llama2"):
            prompt = self.create_llama2_prompt(chat_request)
        else:
            prompt = self.create_llama3_prompt(chat_request)
        # Currently, there is no way to set stop sequence for Llama 3 models.
        return {
        args = {
            "prompt": prompt,
            "max_gen_len": chat_request.max_tokens,
            "temperature": chat_request.temperature,
            "top_p": chat_request.top_p,
        }

    def chat(self, chat_request: ChatRequest) -> ChatResponse:
        response = self._invoke_model(
            args=self._parse_args(chat_request), model_id=chat_request.model
        )
        response_body = json.loads(response.get("body").read())
        if DEBUG:
            logger.info("Bedrock response body: " + str(response_body))
        message_id = self._generate_message_id()

        return self._create_response(
            model=chat_request.model,
            message=response_body["generation"],
            message_id=message_id,
            input_tokens=response_body["prompt_token_count"],
            output_tokens=response_body["generation_token_count"],
        )

    def chat_stream(self, chat_request: ChatRequest) -> AsyncIterable[bytes]:
        response = self._invoke_model(
            args=self._parse_args(chat_request),
            model_id=chat_request.model,
            with_stream=True,
        )
        msg_id = ""
        chunk_id = 0
        for event in response.get("body"):
            if DEBUG:
                logger.info("Bedrock response chunk: " + str(event))
            chunk = json.loads(event["chunk"]["bytes"])
            chunk_id += 1
            response = self._create_response_stream(
                model=chat_request.model,
                message_id=msg_id,
                chunk_message=chunk["generation"],
                finish_reason=chunk["stop_reason"],
            )
            yield self._stream_response_to_bytes(response)
        return json.dumps(args)


class MistralModel(BedrockModel):
    text_field_name = "text"
    finish_reason_field_name = "stop_reason"

    def _convert_prompt(self, chat_request: ChatRequest) -> str:
        """Create a prompt message follow below example:
@@ -609,51 +729,54 @@ class MistralModel(BedrockModel):
            logger.info("Converted prompt: " + prompt.replace("\n", "\\n"))
        return prompt

    def _parse_args(self, chat_request: ChatRequest) -> dict:
    def compose_request_body(self, chat_request: ChatRequest) -> str:
        prompt = self._convert_prompt(chat_request)
        return {
        args = {
            "prompt": prompt,
            "max_tokens": chat_request.max_tokens,
            "temperature": chat_request.temperature,
            "top_p": chat_request.top_p,
        }
        return json.dumps(args)

    def chat(self, chat_request: ChatRequest) -> ChatResponse:
    def get_message_text(self, response_body: dict) -> str | None:
        return super().get_message_text(response_body["outputs"][0])

        response = self._invoke_model(
            args=self._parse_args(chat_request), model_id=chat_request.model
        )
        response_body = json.loads(response.get("body").read())
        if DEBUG:
            logger.info("Bedrock response body: " + str(response_body))
        message_id = self._generate_message_id()
    def get_message_finish_reason(self, response_body: dict) -> str | None:
        return super().get_message_finish_reason(response_body["outputs"][0])

        return self._create_response(
            model=chat_request.model,
            message=response_body["outputs"][0]["text"],
            message_id=message_id,
        )
    def get_message_usage(self, response_body: dict) -> tuple[int, int]:
        # Mistral/Mixtral does not provide info about usage
        return 0, 0

    def chat_stream(self, chat_request: ChatRequest) -> AsyncIterable[bytes]:
        response = self._invoke_model(
            args=self._parse_args(chat_request),
            model_id=chat_request.model,
            with_stream=True,
        )
        msg_id = ""
        chunk_id = 0
        for event in response.get("body"):
            if DEBUG:
                logger.info("Bedrock response chunk: " + str(event))
            chunk = json.loads(event["chunk"]["bytes"])
            chunk_id += 1
            response = self._create_response_stream(
                model=chat_request.model,
                message_id=msg_id,
                chunk_message=chunk["outputs"][0]["text"],
                finish_reason=chunk["outputs"][0]["stop_reason"],

class CohereCommandModel(BedrockModel):

    def _parse_message(self, message) -> dict:
        if message.role not in ["user", "assistant"]:
            raise HTTPException(
                status_code=400, detail="Only user or assistant message is supported"
            )
            yield self._stream_response_to_bytes(response)
        return {
            "role": "USER" if message.role == "user" else "CHATBOT",
            "message": message.content,
        }

    def compose_request_body(self, chat_request: ChatRequest) -> str:
        messages = chat_request.messages
        if messages[-1].role != "user":
            raise HTTPException(
                status_code=400, detail="Last message should be a valid user message"
            )
        chat_history = [self._parse_message(msg) for msg in messages[:-1]]
        args = {
            "message": messages[-1].content,
            "chat_history": chat_history,
            "max_tokens": chat_request.max_tokens,
            "temperature": chat_request.temperature,
            "p": chat_request.top_p,
        }
        return json.dumps(args)


class BedrockEmbeddingsModel(BaseEmbeddingsModel, ABC):
@@ -673,8 +796,7 @@ class BedrockEmbeddingsModel(BaseEmbeddingsModel, ABC):
                contentType=self.content_type,
            )
        except bedrock_runtime.exceptions.ValidationException as e:
            print("Validation Exception")
            print(e)
            logger.error("Validation Error: " + str(e))
            raise HTTPException(status_code=400, detail=str(e))
        except Exception as e:
            logger.error(e)
@@ -705,7 +827,6 @@ class BedrockEmbeddingsModel(BaseEmbeddingsModel, ABC):
                total_tokens=input_tokens + output_tokens,
            ),
        )

        if DEBUG:
            logger.info("Proxy response :" + response.model_dump_json())
        return response
@@ -799,24 +920,22 @@ class TitanEmbeddingsModel(BedrockEmbeddingsModel):


def get_model(model_id: str) -> BedrockModel:
    model_name = SUPPORTED_BEDROCK_MODELS.get(model_id, "")
    if DEBUG:
        logger.info("model name is " + model_name)
    # Not using start_with here in case of complex scenarios.
    # The downside is to change this everytime for a new model supported.
    match model_name:
        case "Claude Instant" | "Claude" | "Claude 3 Sonnet" | "Claude 3 Haiku" | "Claude 3 Opus":
            return ClaudeModel()
        case "Llama 2 Chat 13B" | "Llama 2 Chat 70B" | "Llama 3 8B Instruct" | "Llama 3 70B Instruct":
            return LlamaModel()
        case "Mistral 7B Instruct" | "Mixtral 8x7B Instruct" | "Mistral Large":
            return MistralModel()
        case _:
            logger.error("Unsupported model id " + model_id)
            raise HTTPException(
                status_code=400,
                detail="Unsupported model id " + model_id,
            )
        logger.info("model id is " + model_id)
    if model_id not in SUPPORTED_BEDROCK_MODELS.keys():
        logger.error("Unsupported model id " + model_id)
        raise HTTPException(
            status_code=400,
            detail="Unsupported model id " + model_id,
        )
    if model_id.startswith("anthropic.claude"):
        return ClaudeModel()
    elif model_id.startswith("meta.llama"):
        return LlamaModel()
    elif model_id.startswith("mistral.mistral") or model_id.startswith("mistral.mixtral"):
        return MistralModel()
    elif model_id.startswith("cohere.command-r"):
        return CohereCommandModel()


def get_embeddings_model(model_id: str) -> BedrockEmbeddingsModel:
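With the streaming branch in the base class now emitting a usage-only chunk when stream_options.include_usage is set, an OpenAI-compatible client can read token counts from the final chunk. A minimal client-side sketch, illustrative only: the base_url, port, and API key are placeholders for wherever the gateway is deployed, and the model id is one of those added to SUPPORTED_BEDROCK_MODELS in this change.

    # Hypothetical client usage -- not part of the diff.
    from openai import OpenAI

    client = OpenAI(base_url="http://localhost:8000/api/v1", api_key="placeholder")
    stream = client.chat.completions.create(
        model="mistral.mistral-large-2402-v1:0",
        messages=[{"role": "user", "content": "Hello"}],
        stream=True,
        stream_options={"include_usage": True},  # final chunk: empty choices, populated usage
    )
    for chunk in stream:
        if chunk.choices:
            print(chunk.choices[0].delta.content or "", end="")
        elif chunk.usage:
            print("\ntokens:", chunk.usage.prompt_tokens, chunk.usage.completion_tokens)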