Refactor model implementation
This commit is contained in:
@@ -29,11 +29,17 @@ class BaseChatModel(ABC):
|
||||
"""Handle a basic chat completion requests with stream response."""
|
||||
pass
|
||||
|
||||
def _generate_message_id(self) -> str:
|
||||
@staticmethod
|
||||
def generate_message_id() -> str:
|
||||
return "chatcmpl-" + str(uuid.uuid4())[:8]
|
||||
|
||||
def _stream_response_to_bytes(self, response: ChatStreamResponse) -> bytes:
|
||||
return "data: {}\n\n".format(response.model_dump_json()).encode("utf-8")
|
||||
@staticmethod
|
||||
def stream_response_to_bytes(
|
||||
response: ChatStreamResponse | None = None
|
||||
) -> bytes:
|
||||
if response:
|
||||
return "data: {}\n\n".format(response.model_dump_json()).encode("utf-8")
|
||||
return "data: [DONE]\n\n".encode("utf-8")
|
||||
|
||||
|
||||
class BaseEmbeddingsModel(ABC):
|
||||
@@ -46,6 +52,3 @@ class BaseEmbeddingsModel(ABC):
|
||||
def embed(self, embeddings_request: EmbeddingsRequest) -> EmbeddingsResponse:
|
||||
"""Handle a basic embeddings request."""
|
||||
pass
|
||||
|
||||
def _generate_message_id(self) -> str:
|
||||
return "embeddings-" + str(uuid.uuid4())[:8]
|
||||
|
||||
@@ -2,7 +2,7 @@ import base64
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from abc import ABC
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import AsyncIterable, Iterable, Literal
|
||||
|
||||
import boto3
|
||||
@@ -54,6 +54,8 @@ SUPPORTED_BEDROCK_MODELS = {
|
||||
"mistral.mistral-7b-instruct-v0:2": "Mistral 7B Instruct",
|
||||
"mistral.mixtral-8x7b-instruct-v0:1": "Mixtral 8x7B Instruct",
|
||||
"mistral.mistral-large-2402-v1:0": "Mistral Large",
|
||||
"cohere.command-r-v1:0": "Command R",
|
||||
"cohere.command-r-plus-v1:0": "Command R+",
|
||||
}
|
||||
|
||||
SUPPORTED_BEDROCK_EMBEDDING_MODELS = {
|
||||
@@ -72,28 +74,157 @@ class BedrockModel(BaseChatModel, ABC):
|
||||
accept = "application/json"
|
||||
content_type = "application/json"
|
||||
|
||||
def _invoke_model(self, args: dict, model_id: str, with_stream: bool = False):
|
||||
body = json.dumps(args)
|
||||
# Default field name to get the response message
|
||||
text_field_name = "text"
|
||||
|
||||
# Default field name to get the response finish reason
|
||||
finish_reason_field_name = "finish_reason"
|
||||
|
||||
@abstractmethod
|
||||
def compose_request_body(self, chat_request: ChatRequest) -> str:
|
||||
"""Since the request body to Bedrock varies,
|
||||
each model should implement this to compose the request body.
|
||||
|
||||
:param chat_request:
|
||||
:return: request body as a string
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def chat(self, chat_request: ChatRequest) -> ChatResponse:
|
||||
"""Default implementation for Chat API."""
|
||||
if DEBUG:
|
||||
logger.info("Raw request: " + chat_request.model_dump_json())
|
||||
request_body = self.compose_request_body(chat_request)
|
||||
|
||||
if DEBUG:
|
||||
logger.info("Bedrock request: " + request_body)
|
||||
|
||||
response = self.invoke_model(
|
||||
request_body=request_body,
|
||||
model_id=chat_request.model,
|
||||
)
|
||||
message_id = self.generate_message_id()
|
||||
return self.parse_response(chat_request, response, message_id)
|
||||
|
||||
def chat_stream(self, chat_request: ChatRequest) -> AsyncIterable[bytes]:
|
||||
"""Default implementation for Chat Stream API"""
|
||||
if DEBUG:
|
||||
logger.info("Raw request: " + chat_request.model_dump_json())
|
||||
request_body = self.compose_request_body(chat_request)
|
||||
response = self.invoke_model(
|
||||
request_body=request_body,
|
||||
model_id=chat_request.model,
|
||||
with_stream=True,
|
||||
)
|
||||
|
||||
message_id = self.generate_message_id()
|
||||
for stream_response in self.parse_stream_response(
|
||||
chat_request, response, message_id
|
||||
):
|
||||
if stream_response.choices:
|
||||
yield self.stream_response_to_bytes(stream_response)
|
||||
elif (
|
||||
chat_request.stream_options
|
||||
and chat_request.stream_options.include_usage
|
||||
):
|
||||
# An empty choices for Usage as per OpenAI doc below:
|
||||
# if you set stream_options: {"include_usage": true}.
|
||||
# an additional chunk will be streamed before the data: [DONE] message.
|
||||
# The usage field on this chunk shows the token usage statistics for the entire request,
|
||||
# and the choices field will always be an empty array.
|
||||
# All other chunks will also include a usage field, but with a null value.
|
||||
yield self.stream_response_to_bytes(stream_response)
|
||||
# return an [DONE] message at the end.
|
||||
yield self.stream_response_to_bytes()
|
||||
|
||||
def get_message_text(self, response_body: dict) -> str | None:
|
||||
"""Default func to get the response message.
|
||||
|
||||
Ideally, only the field name should be changed."""
|
||||
return response_body.get(self.text_field_name)
|
||||
|
||||
def get_message_finish_reason(self, response_body: dict) -> str | None:
|
||||
"""Default func to get the finish message.
|
||||
|
||||
Ideally, only the field name should be changed."""
|
||||
return response_body.get(self.finish_reason_field_name)
|
||||
|
||||
def get_message_usage(self, response_body: dict) -> tuple[int, int]:
|
||||
"""Default func to get the finish message.
|
||||
|
||||
Can be overridden in the detail model for complex cases."""
|
||||
input_tokens = int(response_body.get("prompt_token_count", "0"))
|
||||
output_tokens = int(response_body.get("generation_token_count", "0"))
|
||||
return input_tokens, output_tokens
|
||||
|
||||
def parse_response(
|
||||
self, chat_request: ChatRequest, service_response: dict, message_id: str
|
||||
) -> ChatResponse:
|
||||
response_body = json.loads(service_response.get("body").read())
|
||||
if DEBUG:
|
||||
logger.info("Bedrock response body: " + str(response_body))
|
||||
|
||||
input_tokens, output_tokens = self.get_message_usage(response_body)
|
||||
return self.create_response(
|
||||
model=chat_request.model,
|
||||
message_id=message_id,
|
||||
message=self.get_message_text(response_body),
|
||||
finish_reason=self.get_message_finish_reason(response_body),
|
||||
input_tokens=input_tokens,
|
||||
output_tokens=output_tokens,
|
||||
)
|
||||
|
||||
def parse_stream_response(
|
||||
self, chat_request: ChatRequest, service_response: dict, message_id: str
|
||||
) -> Iterable[ChatStreamResponse]:
|
||||
|
||||
chunk_id = 0
|
||||
for event in service_response.get("body"):
|
||||
if DEBUG:
|
||||
logger.info("Bedrock response chunk: " + str(event))
|
||||
chunk = json.loads(event["chunk"]["bytes"])
|
||||
chunk_id += 1
|
||||
|
||||
response = self.create_response_stream(
|
||||
model=chat_request.model,
|
||||
message_id=message_id,
|
||||
chunk_message=self.get_message_text(chunk),
|
||||
finish_reason=self.get_message_finish_reason(chunk),
|
||||
)
|
||||
yield response
|
||||
# Get the usage for streaming response anyway.
|
||||
if "amazon-bedrock-invocationMetrics" in chunk:
|
||||
yield self.create_response_stream(
|
||||
model=chat_request.model,
|
||||
message_id=message_id,
|
||||
input_tokens=chunk["amazon-bedrock-invocationMetrics"][
|
||||
"inputTokenCount"
|
||||
],
|
||||
output_tokens=chunk["amazon-bedrock-invocationMetrics"][
|
||||
"outputTokenCount"
|
||||
],
|
||||
)
|
||||
|
||||
def invoke_model(self, request_body: str, model_id: str, with_stream: bool = False):
|
||||
if DEBUG:
|
||||
logger.info("Invoke Bedrock Model: " + model_id)
|
||||
logger.info("Bedrock request body: " + body)
|
||||
logger.info("Bedrock request body: " + request_body)
|
||||
try:
|
||||
if with_stream:
|
||||
return bedrock_runtime.invoke_model_with_response_stream(
|
||||
body=body,
|
||||
body=request_body,
|
||||
modelId=model_id,
|
||||
accept=self.accept,
|
||||
contentType=self.content_type,
|
||||
)
|
||||
return bedrock_runtime.invoke_model(
|
||||
body=body,
|
||||
body=request_body,
|
||||
modelId=model_id,
|
||||
accept=self.accept,
|
||||
contentType=self.content_type,
|
||||
)
|
||||
except bedrock_runtime.exceptions.ValidationException as e:
|
||||
print("Validation Exception")
|
||||
print(e)
|
||||
logger.error("Validation Error: " + str(e))
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
@@ -101,8 +232,7 @@ class BedrockModel(BaseChatModel, ABC):
|
||||
|
||||
@staticmethod
|
||||
def merge_message(messages: list[dict]) -> list[dict]:
|
||||
"""Merge the request messages with the same role as previous message
|
||||
"""
|
||||
"""Merge the request messages with the same role as previous message"""
|
||||
merged_messages = []
|
||||
prev_role = None
|
||||
merged_content = ""
|
||||
@@ -110,10 +240,11 @@ class BedrockModel(BaseChatModel, ABC):
|
||||
for message in messages:
|
||||
role = message["role"]
|
||||
content = message["content"]
|
||||
|
||||
if role != prev_role or isinstance(content, list):
|
||||
if prev_role:
|
||||
merged_messages.append({"role": prev_role, "content": merged_content})
|
||||
merged_messages.append(
|
||||
{"role": prev_role, "content": merged_content}
|
||||
)
|
||||
if isinstance(content, str):
|
||||
merged_content = content
|
||||
prev_role = role
|
||||
@@ -132,7 +263,7 @@ class BedrockModel(BaseChatModel, ABC):
|
||||
return merged_messages
|
||||
|
||||
@staticmethod
|
||||
def _create_response(
|
||||
def create_response(
|
||||
model: str,
|
||||
message_id: str,
|
||||
message: str | None = None,
|
||||
@@ -144,15 +275,17 @@ class BedrockModel(BaseChatModel, ABC):
|
||||
response = ChatResponse(
|
||||
id=message_id,
|
||||
model=model,
|
||||
choices=[Choice(
|
||||
index=0,
|
||||
message=ChatResponseMessage(
|
||||
role="assistant",
|
||||
tool_calls=tools,
|
||||
content=message,
|
||||
),
|
||||
finish_reason="tool_calls" if tools else finish_reason,
|
||||
)],
|
||||
choices=[
|
||||
Choice(
|
||||
index=0,
|
||||
message=ChatResponseMessage(
|
||||
role="assistant",
|
||||
tool_calls=tools,
|
||||
content=message,
|
||||
),
|
||||
finish_reason="tool_calls" if tools else finish_reason,
|
||||
)
|
||||
],
|
||||
usage=Usage(
|
||||
prompt_tokens=input_tokens,
|
||||
completion_tokens=output_tokens,
|
||||
@@ -164,26 +297,42 @@ class BedrockModel(BaseChatModel, ABC):
|
||||
return response
|
||||
|
||||
@staticmethod
|
||||
def _create_response_stream(
|
||||
def create_response_stream(
|
||||
model: str,
|
||||
message_id: str,
|
||||
chunk_message: str | None = None,
|
||||
finish_reason: str | None = None,
|
||||
tools: list[ToolCall] | None = None,
|
||||
input_tokens: int = 0,
|
||||
output_tokens: int = 0,
|
||||
) -> ChatStreamResponse:
|
||||
response = ChatStreamResponse(
|
||||
id=message_id,
|
||||
model=model,
|
||||
choices=[ChoiceDelta(
|
||||
index=0,
|
||||
delta=ChatResponseMessage(
|
||||
role="assistant",
|
||||
tool_calls=tools,
|
||||
content=chunk_message,
|
||||
if chunk_message or finish_reason or tools:
|
||||
response = ChatStreamResponse(
|
||||
id=message_id,
|
||||
model=model,
|
||||
choices=[
|
||||
ChoiceDelta(
|
||||
index=0,
|
||||
delta=ChatResponseMessage(
|
||||
role="assistant",
|
||||
tool_calls=tools,
|
||||
content=chunk_message,
|
||||
),
|
||||
finish_reason=finish_reason,
|
||||
)
|
||||
],
|
||||
)
|
||||
else:
|
||||
response = ChatStreamResponse(
|
||||
id=message_id,
|
||||
model=model,
|
||||
choices=[],
|
||||
usage=Usage(
|
||||
prompt_tokens=input_tokens,
|
||||
completion_tokens=output_tokens,
|
||||
total_tokens=input_tokens + output_tokens,
|
||||
),
|
||||
finish_reason="tool_calls" if tools else finish_reason,
|
||||
)],
|
||||
)
|
||||
)
|
||||
if DEBUG:
|
||||
logger.info("Proxy response :" + response.model_dump_json())
|
||||
return response
|
||||
@@ -201,7 +350,7 @@ Please think if you need to use a tool or not for user's question, you must:
|
||||
{{"name": $TOOL_NAME, "arguments": {{"$PARAMETER_NAME": "$PARAMETER_VALUE", ...}}}}
|
||||
3. If no tools is needed, respond with normal text."""
|
||||
|
||||
def _parse_args(self, chat_request: ChatRequest) -> dict:
|
||||
def compose_request_body(self, chat_request: ChatRequest) -> str:
|
||||
args = {
|
||||
"anthropic_version": self.anthropic_version,
|
||||
"max_tokens": chat_request.max_tokens,
|
||||
@@ -239,8 +388,8 @@ Please think if you need to use a tool or not for user's question, you must:
|
||||
{
|
||||
"role": "user",
|
||||
"content": "[Tool result with matching id `{}` of `{}`] ".format(
|
||||
message.tool_call_id,
|
||||
message.content),
|
||||
message.tool_call_id, message.content
|
||||
),
|
||||
}
|
||||
)
|
||||
else:
|
||||
@@ -253,74 +402,86 @@ Please think if you need to use a tool or not for user's question, you must:
|
||||
[tool.function.model_dump() for tool in chat_request.tools]
|
||||
)
|
||||
system_prompt += self.tool_prompt.format(tools=tools_str)
|
||||
converted_messages.append({
|
||||
'role': 'assistant',
|
||||
'content': '<tool>'
|
||||
})
|
||||
args["stop_sequences"] = ['</function>']
|
||||
converted_messages.append({"role": "assistant", "content": "<tool>"})
|
||||
args["stop_sequences"] = ["</function>"]
|
||||
args["messages"] = self.merge_message(converted_messages)
|
||||
if system_prompt:
|
||||
if DEBUG:
|
||||
logger.info("System Prompt: " + system_prompt)
|
||||
args["system"] = system_prompt
|
||||
|
||||
return args
|
||||
return json.dumps(args)
|
||||
|
||||
def chat(self, chat_request: ChatRequest) -> ChatResponse:
|
||||
if DEBUG:
|
||||
logger.info("Raw request: " + chat_request.model_dump_json())
|
||||
response = self._invoke_model(
|
||||
args=self._parse_args(chat_request), model_id=chat_request.model
|
||||
)
|
||||
response_body = json.loads(response.get("body").read())
|
||||
def parse_response(
|
||||
self, chat_request: ChatRequest, service_response: dict, message_id: str
|
||||
) -> ChatResponse:
|
||||
response_body = json.loads(service_response.get("body").read())
|
||||
if DEBUG:
|
||||
logger.info("Bedrock response body: " + str(response_body))
|
||||
message = response_body["content"][0]["text"]
|
||||
|
||||
finish_reason = response_body["stop_reason"]
|
||||
tools = None
|
||||
if chat_request.tools:
|
||||
if message.startswith("Y</tool>"):
|
||||
tools = self._parse_tool_message(message)
|
||||
message = None
|
||||
finish_reason = "tool_calls"
|
||||
elif message.startswith("N</tool>"):
|
||||
message = message[8:].lstrip("\n")
|
||||
return self._create_response(
|
||||
return self.create_response(
|
||||
model=chat_request.model,
|
||||
message_id=response_body["id"],
|
||||
message_id=message_id,
|
||||
message=message,
|
||||
tools=tools,
|
||||
finish_reason=response_body["stop_reason"],
|
||||
finish_reason=finish_reason,
|
||||
input_tokens=response_body["usage"]["input_tokens"],
|
||||
output_tokens=response_body["usage"]["output_tokens"],
|
||||
)
|
||||
|
||||
def chat_stream(self, chat_request: ChatRequest) -> AsyncIterable[bytes]:
|
||||
if DEBUG:
|
||||
logger.info("Raw request: " + chat_request.model_dump_json())
|
||||
response = self._invoke_model(
|
||||
args=self._parse_args(chat_request),
|
||||
model_id=chat_request.model,
|
||||
with_stream=True,
|
||||
)
|
||||
msg_id = ""
|
||||
def parse_stream_response(
|
||||
self, chat_request: ChatRequest, service_response: dict, message_id: str
|
||||
) -> Iterable[ChatStreamResponse]:
|
||||
|
||||
chunk_id = 0
|
||||
tool_message = ""
|
||||
first_token = True
|
||||
index = 0
|
||||
for event in response.get("body"):
|
||||
for event in service_response.get("body"):
|
||||
if DEBUG:
|
||||
logger.info("Bedrock response chunk: " + str(event))
|
||||
chunk = json.loads(event["chunk"]["bytes"])
|
||||
chunk_id += 1
|
||||
|
||||
if chunk["type"] == "message_start":
|
||||
msg_id = chunk["message"]["id"]
|
||||
continue
|
||||
if chunk["type"] == "message_stop":
|
||||
# Get the usage for streaming response anyway.
|
||||
if "amazon-bedrock-invocationMetrics" in chunk:
|
||||
yield self.create_response_stream(
|
||||
model=chat_request.model,
|
||||
message_id=message_id,
|
||||
input_tokens=chunk["amazon-bedrock-invocationMetrics"][
|
||||
"inputTokenCount"
|
||||
],
|
||||
output_tokens=chunk["amazon-bedrock-invocationMetrics"][
|
||||
"outputTokenCount"
|
||||
],
|
||||
)
|
||||
break
|
||||
|
||||
if chunk["type"] == "message_delta":
|
||||
elif chunk["type"] == "message_delta":
|
||||
chunk_message = ""
|
||||
finish_reason = chunk["delta"]["stop_reason"]
|
||||
|
||||
# Send tool message first if any.
|
||||
if chat_request.tools and tool_message:
|
||||
tools = self._parse_tool_message(tool_message)
|
||||
finish_reason = "tool_calls"
|
||||
response = self.create_response_stream(
|
||||
model=chat_request.model,
|
||||
message_id=message_id,
|
||||
tools=tools,
|
||||
)
|
||||
yield response
|
||||
|
||||
elif chunk["type"] == "content_block_delta":
|
||||
chunk_message = chunk["delta"]["text"]
|
||||
finish_reason = None
|
||||
@@ -343,22 +504,14 @@ Please think if you need to use a tool or not for user's question, you must:
|
||||
first_token = False
|
||||
else:
|
||||
continue
|
||||
response = self._create_response_stream(
|
||||
response = self.create_response_stream(
|
||||
model=chat_request.model,
|
||||
message_id=msg_id,
|
||||
message_id=message_id,
|
||||
chunk_message=chunk_message,
|
||||
finish_reason=finish_reason,
|
||||
)
|
||||
|
||||
yield self._stream_response_to_bytes(response)
|
||||
if chat_request.tools and tool_message:
|
||||
tools = self._parse_tool_message(tool_message)
|
||||
response = self._create_response_stream(
|
||||
model=chat_request.model,
|
||||
message_id=msg_id,
|
||||
tools=tools,
|
||||
)
|
||||
yield self._stream_response_to_bytes(response)
|
||||
yield response
|
||||
|
||||
def _parse_tool_message(self, tool_message: str) -> list[ToolCall]:
|
||||
if DEBUG:
|
||||
@@ -367,9 +520,7 @@ Please think if you need to use a tool or not for user's question, you must:
|
||||
tool_messages = tool_message[tool_message.rindex("<function>") + len("<function>"):]
|
||||
function = json.loads(tool_messages.replace("\n", " "))
|
||||
args = json.dumps(function.get("arguments", {}))
|
||||
function = ResponseFunction(
|
||||
name=function["name"], arguments=args
|
||||
)
|
||||
function = ResponseFunction(name=function["name"], arguments=args)
|
||||
|
||||
return [
|
||||
ToolCall(
|
||||
@@ -400,7 +551,7 @@ Please think if you need to use a tool or not for user's question, you must:
|
||||
# Check if the request was successful
|
||||
if response.status_code == 200:
|
||||
|
||||
content_type = response.headers.get('Content-Type')
|
||||
content_type = response.headers.get("Content-Type")
|
||||
if not content_type.startswith("image"):
|
||||
content_type = "image/jpeg"
|
||||
# Get the image content
|
||||
@@ -437,6 +588,8 @@ Please think if you need to use a tool or not for user's question, you must:
|
||||
|
||||
|
||||
class LlamaModel(BedrockModel):
|
||||
text_field_name = "generation"
|
||||
finish_reason_field_name = "stop_reason"
|
||||
|
||||
@staticmethod
|
||||
def create_llama3_prompt(chat_request: ChatRequest) -> str:
|
||||
@@ -460,7 +613,9 @@ class LlamaModel(BedrockModel):
|
||||
|
||||
prompt_lines = []
|
||||
for msg in chat_request.messages:
|
||||
prompt_lines.append(f"<|start_header_id|>{msg.role}<|end_header_id|>\n\n{msg.content}<|eot_id|>")
|
||||
prompt_lines.append(
|
||||
f"<|start_header_id|>{msg.role}<|end_header_id|>\n\n{msg.content}<|eot_id|>"
|
||||
)
|
||||
prompt_lines.append(f"<|start_header_id|>assistant<|end_header_id|>\n\n")
|
||||
prompt = bos_token + "".join(prompt_lines)
|
||||
if DEBUG:
|
||||
@@ -512,59 +667,24 @@ class LlamaModel(BedrockModel):
|
||||
logger.info("Converted prompt: " + prompt.replace("\n", "\\n"))
|
||||
return prompt
|
||||
|
||||
def _parse_args(self, chat_request: ChatRequest) -> dict:
|
||||
def compose_request_body(self, chat_request: ChatRequest) -> str:
|
||||
if chat_request.model.startswith("meta.llama2"):
|
||||
prompt = self.create_llama2_prompt(chat_request)
|
||||
else:
|
||||
prompt = self.create_llama3_prompt(chat_request)
|
||||
# Currently, there is no way to set stop sequence for Llama 3 models.
|
||||
return {
|
||||
args = {
|
||||
"prompt": prompt,
|
||||
"max_gen_len": chat_request.max_tokens,
|
||||
"temperature": chat_request.temperature,
|
||||
"top_p": chat_request.top_p,
|
||||
}
|
||||
|
||||
def chat(self, chat_request: ChatRequest) -> ChatResponse:
|
||||
response = self._invoke_model(
|
||||
args=self._parse_args(chat_request), model_id=chat_request.model
|
||||
)
|
||||
response_body = json.loads(response.get("body").read())
|
||||
if DEBUG:
|
||||
logger.info("Bedrock response body: " + str(response_body))
|
||||
message_id = self._generate_message_id()
|
||||
|
||||
return self._create_response(
|
||||
model=chat_request.model,
|
||||
message=response_body["generation"],
|
||||
message_id=message_id,
|
||||
input_tokens=response_body["prompt_token_count"],
|
||||
output_tokens=response_body["generation_token_count"],
|
||||
)
|
||||
|
||||
def chat_stream(self, chat_request: ChatRequest) -> AsyncIterable[bytes]:
|
||||
response = self._invoke_model(
|
||||
args=self._parse_args(chat_request),
|
||||
model_id=chat_request.model,
|
||||
with_stream=True,
|
||||
)
|
||||
msg_id = ""
|
||||
chunk_id = 0
|
||||
for event in response.get("body"):
|
||||
if DEBUG:
|
||||
logger.info("Bedrock response chunk: " + str(event))
|
||||
chunk = json.loads(event["chunk"]["bytes"])
|
||||
chunk_id += 1
|
||||
response = self._create_response_stream(
|
||||
model=chat_request.model,
|
||||
message_id=msg_id,
|
||||
chunk_message=chunk["generation"],
|
||||
finish_reason=chunk["stop_reason"],
|
||||
)
|
||||
yield self._stream_response_to_bytes(response)
|
||||
return json.dumps(args)
|
||||
|
||||
|
||||
class MistralModel(BedrockModel):
|
||||
text_field_name = "text"
|
||||
finish_reason_field_name = "stop_reason"
|
||||
|
||||
def _convert_prompt(self, chat_request: ChatRequest) -> str:
|
||||
"""Create a prompt message follow below example:
|
||||
|
||||
@@ -609,51 +729,54 @@ class MistralModel(BedrockModel):
|
||||
logger.info("Converted prompt: " + prompt.replace("\n", "\\n"))
|
||||
return prompt
|
||||
|
||||
def _parse_args(self, chat_request: ChatRequest) -> dict:
|
||||
def compose_request_body(self, chat_request: ChatRequest) -> str:
|
||||
prompt = self._convert_prompt(chat_request)
|
||||
return {
|
||||
args = {
|
||||
"prompt": prompt,
|
||||
"max_tokens": chat_request.max_tokens,
|
||||
"temperature": chat_request.temperature,
|
||||
"top_p": chat_request.top_p,
|
||||
}
|
||||
return json.dumps(args)
|
||||
|
||||
def chat(self, chat_request: ChatRequest) -> ChatResponse:
|
||||
def get_message_text(self, response_body: dict) -> str | None:
|
||||
return super().get_message_text(response_body["outputs"][0])
|
||||
|
||||
response = self._invoke_model(
|
||||
args=self._parse_args(chat_request), model_id=chat_request.model
|
||||
)
|
||||
response_body = json.loads(response.get("body").read())
|
||||
if DEBUG:
|
||||
logger.info("Bedrock response body: " + str(response_body))
|
||||
message_id = self._generate_message_id()
|
||||
def get_message_finish_reason(self, response_body: dict) -> str | None:
|
||||
return super().get_message_finish_reason(response_body["outputs"][0])
|
||||
|
||||
return self._create_response(
|
||||
model=chat_request.model,
|
||||
message=response_body["outputs"][0]["text"],
|
||||
message_id=message_id,
|
||||
)
|
||||
def get_message_usage(self, response_body: dict) -> tuple[int, int]:
|
||||
# Mistral/Mixtral does not provide info about usage
|
||||
return 0, 0
|
||||
|
||||
def chat_stream(self, chat_request: ChatRequest) -> AsyncIterable[bytes]:
|
||||
response = self._invoke_model(
|
||||
args=self._parse_args(chat_request),
|
||||
model_id=chat_request.model,
|
||||
with_stream=True,
|
||||
)
|
||||
msg_id = ""
|
||||
chunk_id = 0
|
||||
for event in response.get("body"):
|
||||
if DEBUG:
|
||||
logger.info("Bedrock response chunk: " + str(event))
|
||||
chunk = json.loads(event["chunk"]["bytes"])
|
||||
chunk_id += 1
|
||||
response = self._create_response_stream(
|
||||
model=chat_request.model,
|
||||
message_id=msg_id,
|
||||
chunk_message=chunk["outputs"][0]["text"],
|
||||
finish_reason=chunk["outputs"][0]["stop_reason"],
|
||||
|
||||
class CohereCommandModel(BedrockModel):
|
||||
|
||||
def _parse_message(self, message) -> dict:
|
||||
if message.role not in ["user", "assistant"]:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Only user or assistant message is supported"
|
||||
)
|
||||
yield self._stream_response_to_bytes(response)
|
||||
return {
|
||||
"role": "USER" if message.role == "user" else "CHATBOT",
|
||||
"message": message.content,
|
||||
}
|
||||
|
||||
def compose_request_body(self, chat_request: ChatRequest) -> str:
|
||||
messages = chat_request.messages
|
||||
if messages[-1].role != "user":
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Last message should be a valid user message"
|
||||
)
|
||||
chat_history = [self._parse_message(msg) for msg in messages[:-1]]
|
||||
args = {
|
||||
"message": messages[-1].content,
|
||||
"chat_history": chat_history,
|
||||
"max_tokens": chat_request.max_tokens,
|
||||
"temperature": chat_request.temperature,
|
||||
"p": chat_request.top_p,
|
||||
}
|
||||
return json.dumps(args)
|
||||
|
||||
|
||||
class BedrockEmbeddingsModel(BaseEmbeddingsModel, ABC):
|
||||
@@ -673,8 +796,7 @@ class BedrockEmbeddingsModel(BaseEmbeddingsModel, ABC):
|
||||
contentType=self.content_type,
|
||||
)
|
||||
except bedrock_runtime.exceptions.ValidationException as e:
|
||||
print("Validation Exception")
|
||||
print(e)
|
||||
logger.error("Validation Error: " + str(e))
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
@@ -705,7 +827,6 @@ class BedrockEmbeddingsModel(BaseEmbeddingsModel, ABC):
|
||||
total_tokens=input_tokens + output_tokens,
|
||||
),
|
||||
)
|
||||
|
||||
if DEBUG:
|
||||
logger.info("Proxy response :" + response.model_dump_json())
|
||||
return response
|
||||
@@ -799,24 +920,22 @@ class TitanEmbeddingsModel(BedrockEmbeddingsModel):
|
||||
|
||||
|
||||
def get_model(model_id: str) -> BedrockModel:
|
||||
model_name = SUPPORTED_BEDROCK_MODELS.get(model_id, "")
|
||||
if DEBUG:
|
||||
logger.info("model name is " + model_name)
|
||||
# Not using start_with here in case of complex scenarios.
|
||||
# The downside is to change this everytime for a new model supported.
|
||||
match model_name:
|
||||
case "Claude Instant" | "Claude" | "Claude 3 Sonnet" | "Claude 3 Haiku" | "Claude 3 Opus":
|
||||
return ClaudeModel()
|
||||
case "Llama 2 Chat 13B" | "Llama 2 Chat 70B" | "Llama 3 8B Instruct" | "Llama 3 70B Instruct":
|
||||
return LlamaModel()
|
||||
case "Mistral 7B Instruct" | "Mixtral 8x7B Instruct" | "Mistral Large":
|
||||
return MistralModel()
|
||||
case _:
|
||||
logger.error("Unsupported model id " + model_id)
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Unsupported model id " + model_id,
|
||||
)
|
||||
logger.info("model id is " + model_id)
|
||||
if model_id not in SUPPORTED_BEDROCK_MODELS.keys():
|
||||
logger.error("Unsupported model id " + model_id)
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Unsupported model id " + model_id,
|
||||
)
|
||||
if model_id.startswith("anthropic.claude"):
|
||||
return ClaudeModel()
|
||||
elif model_id.startswith("meta.llama"):
|
||||
return LlamaModel()
|
||||
elif model_id.startswith("mistral.mistral") or model_id.startswith("mistral.mixtral"):
|
||||
return MistralModel()
|
||||
elif model_id.startswith("cohere.command-r"):
|
||||
return CohereCommandModel()
|
||||
|
||||
|
||||
def get_embeddings_model(model_id: str) -> BedrockEmbeddingsModel:
|
||||
|
||||
Reference in New Issue
Block a user