Refactor model implementation

Aiden Dai
2024-05-09 16:58:04 +08:00
parent 180c199da9
commit 9f6b334385
4 changed files with 316 additions and 186 deletions

View File

@@ -29,11 +29,17 @@ class BaseChatModel(ABC):
"""Handle a basic chat completion requests with stream response.""" """Handle a basic chat completion requests with stream response."""
pass pass
def _generate_message_id(self) -> str: @staticmethod
def generate_message_id() -> str:
return "chatcmpl-" + str(uuid.uuid4())[:8] return "chatcmpl-" + str(uuid.uuid4())[:8]
def _stream_response_to_bytes(self, response: ChatStreamResponse) -> bytes: @staticmethod
return "data: {}\n\n".format(response.model_dump_json()).encode("utf-8") def stream_response_to_bytes(
response: ChatStreamResponse | None = None
) -> bytes:
if response:
return "data: {}\n\n".format(response.model_dump_json()).encode("utf-8")
return "data: [DONE]\n\n".encode("utf-8")
class BaseEmbeddingsModel(ABC): class BaseEmbeddingsModel(ABC):
@@ -46,6 +52,3 @@ class BaseEmbeddingsModel(ABC):
     def embed(self, embeddings_request: EmbeddingsRequest) -> EmbeddingsResponse:
         """Handle a basic embeddings request."""
         pass
-
-    def _generate_message_id(self) -> str:
-        return "embeddings-" + str(uuid.uuid4())[:8]

View File

@@ -2,7 +2,7 @@ import base64
 import json
 import logging
 import re
-from abc import ABC
+from abc import ABC, abstractmethod
 from typing import AsyncIterable, Iterable, Literal

 import boto3
@@ -54,6 +54,8 @@ SUPPORTED_BEDROCK_MODELS = {
"mistral.mistral-7b-instruct-v0:2": "Mistral 7B Instruct", "mistral.mistral-7b-instruct-v0:2": "Mistral 7B Instruct",
"mistral.mixtral-8x7b-instruct-v0:1": "Mixtral 8x7B Instruct", "mistral.mixtral-8x7b-instruct-v0:1": "Mixtral 8x7B Instruct",
"mistral.mistral-large-2402-v1:0": "Mistral Large", "mistral.mistral-large-2402-v1:0": "Mistral Large",
"cohere.command-r-v1:0": "Command R",
"cohere.command-r-plus-v1:0": "Command R+",
} }
SUPPORTED_BEDROCK_EMBEDDING_MODELS = { SUPPORTED_BEDROCK_EMBEDDING_MODELS = {
@@ -72,28 +74,157 @@ class BedrockModel(BaseChatModel, ABC):
accept = "application/json" accept = "application/json"
content_type = "application/json" content_type = "application/json"
def _invoke_model(self, args: dict, model_id: str, with_stream: bool = False): # Default field name to get the response message
body = json.dumps(args) text_field_name = "text"
# Default field name to get the response finish reason
finish_reason_field_name = "finish_reason"
@abstractmethod
def compose_request_body(self, chat_request: ChatRequest) -> str:
"""Since the request body to Bedrock varies,
each model should implement this to compose the request body.
:param chat_request:
:return: request body as a string
"""
raise NotImplementedError()
def chat(self, chat_request: ChatRequest) -> ChatResponse:
"""Default implementation for Chat API."""
if DEBUG:
logger.info("Raw request: " + chat_request.model_dump_json())
request_body = self.compose_request_body(chat_request)
if DEBUG:
logger.info("Bedrock request: " + request_body)
response = self.invoke_model(
request_body=request_body,
model_id=chat_request.model,
)
message_id = self.generate_message_id()
return self.parse_response(chat_request, response, message_id)
def chat_stream(self, chat_request: ChatRequest) -> AsyncIterable[bytes]:
"""Default implementation for Chat Stream API"""
if DEBUG:
logger.info("Raw request: " + chat_request.model_dump_json())
request_body = self.compose_request_body(chat_request)
response = self.invoke_model(
request_body=request_body,
model_id=chat_request.model,
with_stream=True,
)
message_id = self.generate_message_id()
for stream_response in self.parse_stream_response(
chat_request, response, message_id
):
if stream_response.choices:
yield self.stream_response_to_bytes(stream_response)
elif (
chat_request.stream_options
and chat_request.stream_options.include_usage
):
# An empty choices for Usage as per OpenAI doc below:
# if you set stream_options: {"include_usage": true}.
# an additional chunk will be streamed before the data: [DONE] message.
# The usage field on this chunk shows the token usage statistics for the entire request,
# and the choices field will always be an empty array.
# All other chunks will also include a usage field, but with a null value.
yield self.stream_response_to_bytes(stream_response)
# return an [DONE] message at the end.
yield self.stream_response_to_bytes()
def get_message_text(self, response_body: dict) -> str | None:
"""Default func to get the response message.
Ideally, only the field name should be changed."""
return response_body.get(self.text_field_name)
def get_message_finish_reason(self, response_body: dict) -> str | None:
"""Default func to get the finish message.
Ideally, only the field name should be changed."""
return response_body.get(self.finish_reason_field_name)
def get_message_usage(self, response_body: dict) -> tuple[int, int]:
"""Default func to get the finish message.
Can be overridden in the detail model for complex cases."""
input_tokens = int(response_body.get("prompt_token_count", "0"))
output_tokens = int(response_body.get("generation_token_count", "0"))
return input_tokens, output_tokens
def parse_response(
self, chat_request: ChatRequest, service_response: dict, message_id: str
) -> ChatResponse:
response_body = json.loads(service_response.get("body").read())
if DEBUG:
logger.info("Bedrock response body: " + str(response_body))
input_tokens, output_tokens = self.get_message_usage(response_body)
return self.create_response(
model=chat_request.model,
message_id=message_id,
message=self.get_message_text(response_body),
finish_reason=self.get_message_finish_reason(response_body),
input_tokens=input_tokens,
output_tokens=output_tokens,
)
def parse_stream_response(
self, chat_request: ChatRequest, service_response: dict, message_id: str
) -> Iterable[ChatStreamResponse]:
chunk_id = 0
for event in service_response.get("body"):
if DEBUG:
logger.info("Bedrock response chunk: " + str(event))
chunk = json.loads(event["chunk"]["bytes"])
chunk_id += 1
response = self.create_response_stream(
model=chat_request.model,
message_id=message_id,
chunk_message=self.get_message_text(chunk),
finish_reason=self.get_message_finish_reason(chunk),
)
yield response
# Get the usage for streaming response anyway.
if "amazon-bedrock-invocationMetrics" in chunk:
yield self.create_response_stream(
model=chat_request.model,
message_id=message_id,
input_tokens=chunk["amazon-bedrock-invocationMetrics"][
"inputTokenCount"
],
output_tokens=chunk["amazon-bedrock-invocationMetrics"][
"outputTokenCount"
],
)
def invoke_model(self, request_body: str, model_id: str, with_stream: bool = False):
if DEBUG: if DEBUG:
logger.info("Invoke Bedrock Model: " + model_id) logger.info("Invoke Bedrock Model: " + model_id)
logger.info("Bedrock request body: " + body) logger.info("Bedrock request body: " + request_body)
try: try:
if with_stream: if with_stream:
return bedrock_runtime.invoke_model_with_response_stream( return bedrock_runtime.invoke_model_with_response_stream(
body=body, body=request_body,
modelId=model_id, modelId=model_id,
accept=self.accept, accept=self.accept,
contentType=self.content_type, contentType=self.content_type,
) )
return bedrock_runtime.invoke_model( return bedrock_runtime.invoke_model(
body=body, body=request_body,
modelId=model_id, modelId=model_id,
accept=self.accept, accept=self.accept,
contentType=self.content_type, contentType=self.content_type,
) )
except bedrock_runtime.exceptions.ValidationException as e: except bedrock_runtime.exceptions.ValidationException as e:
print("Validation Exception") logger.error("Validation Error: " + str(e))
print(e)
raise HTTPException(status_code=400, detail=str(e)) raise HTTPException(status_code=400, detail=str(e))
except Exception as e: except Exception as e:
logger.error(e) logger.error(e)
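
The net effect of this hunk is a template-method base class: chat and chat_stream own the invoke/parse/stream plumbing, and each concrete model only supplies a request body plus two response field names. A hypothetical minimal subclass (not part of this commit) would look like:

    class ExampleModel(BedrockModel):
        # Where this model's reply text and finish reason live in the body.
        text_field_name = "text"
        finish_reason_field_name = "finish_reason"

        def compose_request_body(self, chat_request: ChatRequest) -> str:
            # Assumed single-turn body, for illustration only.
            return json.dumps(
                {
                    "prompt": chat_request.messages[-1].content,
                    "max_tokens": chat_request.max_tokens,
                }
            )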
@@ -101,8 +232,7 @@ class BedrockModel(BaseChatModel, ABC):
     @staticmethod
     def merge_message(messages: list[dict]) -> list[dict]:
-        """Merge the request messages with the same role as previous message
-        """
+        """Merge the request messages with the same role as previous message"""
         merged_messages = []
         prev_role = None
         merged_content = ""
@@ -110,10 +240,11 @@ class BedrockModel(BaseChatModel, ABC):
         for message in messages:
             role = message["role"]
             content = message["content"]
             if role != prev_role or isinstance(content, list):
                 if prev_role:
-                    merged_messages.append({"role": prev_role, "content": merged_content})
+                    merged_messages.append(
+                        {"role": prev_role, "content": merged_content}
+                    )
                 if isinstance(content, str):
                     merged_content = content
                     prev_role = role
@@ -132,7 +263,7 @@ class BedrockModel(BaseChatModel, ABC):
         return merged_messages

     @staticmethod
-    def _create_response(
+    def create_response(
         model: str,
         message_id: str,
         message: str | None = None,
@@ -144,15 +275,17 @@ class BedrockModel(BaseChatModel, ABC):
         response = ChatResponse(
             id=message_id,
             model=model,
-            choices=[Choice(
-                index=0,
-                message=ChatResponseMessage(
-                    role="assistant",
-                    tool_calls=tools,
-                    content=message,
-                ),
-                finish_reason="tool_calls" if tools else finish_reason,
-            )],
+            choices=[
+                Choice(
+                    index=0,
+                    message=ChatResponseMessage(
+                        role="assistant",
+                        tool_calls=tools,
+                        content=message,
+                    ),
+                    finish_reason="tool_calls" if tools else finish_reason,
+                )
+            ],
             usage=Usage(
                 prompt_tokens=input_tokens,
                 completion_tokens=output_tokens,
@@ -164,26 +297,42 @@ class BedrockModel(BaseChatModel, ABC):
         return response

     @staticmethod
-    def _create_response_stream(
+    def create_response_stream(
         model: str,
         message_id: str,
         chunk_message: str | None = None,
         finish_reason: str | None = None,
         tools: list[ToolCall] | None = None,
+        input_tokens: int = 0,
+        output_tokens: int = 0,
     ) -> ChatStreamResponse:
-        response = ChatStreamResponse(
-            id=message_id,
-            model=model,
-            choices=[ChoiceDelta(
-                index=0,
-                delta=ChatResponseMessage(
-                    role="assistant",
-                    tool_calls=tools,
-                    content=chunk_message,
-                ),
-                finish_reason="tool_calls" if tools else finish_reason,
-            )],
-        )
+        if chunk_message or finish_reason or tools:
+            response = ChatStreamResponse(
+                id=message_id,
+                model=model,
+                choices=[
+                    ChoiceDelta(
+                        index=0,
+                        delta=ChatResponseMessage(
+                            role="assistant",
+                            tool_calls=tools,
+                            content=chunk_message,
+                        ),
+                        finish_reason=finish_reason,
+                    )
+                ],
+            )
+        else:
+            response = ChatStreamResponse(
+                id=message_id,
+                model=model,
+                choices=[],
+                usage=Usage(
+                    prompt_tokens=input_tokens,
+                    completion_tokens=output_tokens,
+                    total_tokens=input_tokens + output_tokens,
+                ),
+            )
         if DEBUG:
             logger.info("Proxy response :" + response.model_dump_json())
         return response
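
When create_response_stream is called with only token counts, it now yields the OpenAI-style usage chunk (empty choices, populated usage). A sketch with made-up values:

    usage_chunk = BedrockModel.create_response_stream(
        model="anthropic.claude-3-haiku-20240307-v1:0",  # illustrative model id
        message_id="chatcmpl-1a2b3c4d",
        input_tokens=25,
        output_tokens=80,
    )
    assert usage_chunk.choices == []
    assert usage_chunk.usage.total_tokens == 105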
@@ -201,7 +350,7 @@ Please think if you need to use a tool or not for user's question, you must:
{{"name": $TOOL_NAME, "arguments": {{"$PARAMETER_NAME": "$PARAMETER_VALUE", ...}}}} {{"name": $TOOL_NAME, "arguments": {{"$PARAMETER_NAME": "$PARAMETER_VALUE", ...}}}}
3. If no tools is needed, respond with normal text.""" 3. If no tools is needed, respond with normal text."""
def _parse_args(self, chat_request: ChatRequest) -> dict: def compose_request_body(self, chat_request: ChatRequest) -> str:
args = { args = {
"anthropic_version": self.anthropic_version, "anthropic_version": self.anthropic_version,
"max_tokens": chat_request.max_tokens, "max_tokens": chat_request.max_tokens,
@@ -239,8 +388,8 @@ Please think if you need to use a tool or not for user's question, you must:
                     {
                         "role": "user",
                         "content": "[Tool result with matching id `{}` of `{}`] ".format(
-                            message.tool_call_id,
-                            message.content),
+                            message.tool_call_id, message.content
+                        ),
                     }
                 )
             else:
@@ -253,74 +402,86 @@ Please think if you need to use a tool or not for user's question, you must:
                 [tool.function.model_dump() for tool in chat_request.tools]
             )
             system_prompt += self.tool_prompt.format(tools=tools_str)
-            converted_messages.append({
-                'role': 'assistant',
-                'content': '<tool>'
-            })
-            args["stop_sequences"] = ['</function>']
+            converted_messages.append({"role": "assistant", "content": "<tool>"})
+            args["stop_sequences"] = ["</function>"]
         args["messages"] = self.merge_message(converted_messages)
         if system_prompt:
             if DEBUG:
                 logger.info("System Prompt: " + system_prompt)
             args["system"] = system_prompt
-        return args
+        return json.dumps(args)

-    def chat(self, chat_request: ChatRequest) -> ChatResponse:
-        if DEBUG:
-            logger.info("Raw request: " + chat_request.model_dump_json())
-        response = self._invoke_model(
-            args=self._parse_args(chat_request), model_id=chat_request.model
-        )
-        response_body = json.loads(response.get("body").read())
+    def parse_response(
+        self, chat_request: ChatRequest, service_response: dict, message_id: str
+    ) -> ChatResponse:
+        response_body = json.loads(service_response.get("body").read())
         if DEBUG:
             logger.info("Bedrock response body: " + str(response_body))
         message = response_body["content"][0]["text"]
+        finish_reason = response_body["stop_reason"]
         tools = None
         if chat_request.tools:
             if message.startswith("Y</tool>"):
                 tools = self._parse_tool_message(message)
                 message = None
+                finish_reason = "tool_calls"
             elif message.startswith("N</tool>"):
                 message = message[8:].lstrip("\n")
-        return self._create_response(
+        return self.create_response(
             model=chat_request.model,
-            message_id=response_body["id"],
+            message_id=message_id,
             message=message,
             tools=tools,
-            finish_reason=response_body["stop_reason"],
+            finish_reason=finish_reason,
             input_tokens=response_body["usage"]["input_tokens"],
             output_tokens=response_body["usage"]["output_tokens"],
         )

-    def chat_stream(self, chat_request: ChatRequest) -> AsyncIterable[bytes]:
-        if DEBUG:
-            logger.info("Raw request: " + chat_request.model_dump_json())
-        response = self._invoke_model(
-            args=self._parse_args(chat_request),
-            model_id=chat_request.model,
-            with_stream=True,
-        )
-        msg_id = ""
+    def parse_stream_response(
+        self, chat_request: ChatRequest, service_response: dict, message_id: str
+    ) -> Iterable[ChatStreamResponse]:
         chunk_id = 0
         tool_message = ""
         first_token = True
         index = 0
-        for event in response.get("body"):
+        for event in service_response.get("body"):
             if DEBUG:
                 logger.info("Bedrock response chunk: " + str(event))
             chunk = json.loads(event["chunk"]["bytes"])
             chunk_id += 1
-            if chunk["type"] == "message_start":
-                msg_id = chunk["message"]["id"]
-                continue
-            if chunk["type"] == "message_delta":
+            if chunk["type"] == "message_stop":
+                # Get the usage for streaming response anyway.
+                if "amazon-bedrock-invocationMetrics" in chunk:
+                    yield self.create_response_stream(
+                        model=chat_request.model,
+                        message_id=message_id,
+                        input_tokens=chunk["amazon-bedrock-invocationMetrics"][
+                            "inputTokenCount"
+                        ],
+                        output_tokens=chunk["amazon-bedrock-invocationMetrics"][
+                            "outputTokenCount"
+                        ],
+                    )
+                break
+            elif chunk["type"] == "message_delta":
                 chunk_message = ""
                 finish_reason = chunk["delta"]["stop_reason"]
+                # Send tool message first if any.
+                if chat_request.tools and tool_message:
+                    tools = self._parse_tool_message(tool_message)
+                    finish_reason = "tool_calls"
+                    response = self.create_response_stream(
+                        model=chat_request.model,
+                        message_id=message_id,
+                        tools=tools,
+                    )
+                    yield response
             elif chunk["type"] == "content_block_delta":
                 chunk_message = chunk["delta"]["text"]
                 finish_reason = None
@@ -343,22 +504,14 @@ Please think if you need to use a tool or not for user's question, you must:
                 first_token = False
             else:
                 continue
-            response = self._create_response_stream(
+            response = self.create_response_stream(
                 model=chat_request.model,
-                message_id=msg_id,
+                message_id=message_id,
                 chunk_message=chunk_message,
                 finish_reason=finish_reason,
             )
-            yield self._stream_response_to_bytes(response)
-
-        if chat_request.tools and tool_message:
-            tools = self._parse_tool_message(tool_message)
-            response = self._create_response_stream(
-                model=chat_request.model,
-                message_id=msg_id,
-                tools=tools,
-            )
-            yield self._stream_response_to_bytes(response)
+            yield response

     def _parse_tool_message(self, tool_message: str) -> list[ToolCall]:
         if DEBUG:
@@ -367,9 +520,7 @@ Please think if you need to use a tool or not for user's question, you must:
         tool_messages = tool_message[tool_message.rindex("<function>") + len("<function>"):]
         function = json.loads(tool_messages.replace("\n", " "))
         args = json.dumps(function.get("arguments", {}))
-        function = ResponseFunction(
-            name=function["name"], arguments=args
-        )
+        function = ResponseFunction(name=function["name"], arguments=args)
         return [
             ToolCall(
@@ -400,7 +551,7 @@ Please think if you need to use a tool or not for user's question, you must:
         # Check if the request was successful
         if response.status_code == 200:
-            content_type = response.headers.get('Content-Type')
+            content_type = response.headers.get("Content-Type")
             if not content_type.startswith("image"):
                 content_type = "image/jpeg"
             # Get the image content
@@ -437,6 +588,8 @@ Please think if you need to use a tool or not for user's question, you must:
 class LlamaModel(BedrockModel):
+    text_field_name = "generation"
+    finish_reason_field_name = "stop_reason"

     @staticmethod
     def create_llama3_prompt(chat_request: ChatRequest) -> str:
@@ -460,7 +613,9 @@ class LlamaModel(BedrockModel):
         prompt_lines = []
         for msg in chat_request.messages:
-            prompt_lines.append(f"<|start_header_id|>{msg.role}<|end_header_id|>\n\n{msg.content}<|eot_id|>")
+            prompt_lines.append(
+                f"<|start_header_id|>{msg.role}<|end_header_id|>\n\n{msg.content}<|eot_id|>"
+            )
         prompt_lines.append(f"<|start_header_id|>assistant<|end_header_id|>\n\n")
         prompt = bos_token + "".join(prompt_lines)
         if DEBUG:
@@ -512,59 +667,24 @@ class LlamaModel(BedrockModel):
logger.info("Converted prompt: " + prompt.replace("\n", "\\n")) logger.info("Converted prompt: " + prompt.replace("\n", "\\n"))
return prompt return prompt
def _parse_args(self, chat_request: ChatRequest) -> dict: def compose_request_body(self, chat_request: ChatRequest) -> str:
if chat_request.model.startswith("meta.llama2"): if chat_request.model.startswith("meta.llama2"):
prompt = self.create_llama2_prompt(chat_request) prompt = self.create_llama2_prompt(chat_request)
else: else:
prompt = self.create_llama3_prompt(chat_request) prompt = self.create_llama3_prompt(chat_request)
# Currently, there is no way to set stop sequence for Llama 3 models. args = {
return {
"prompt": prompt, "prompt": prompt,
"max_gen_len": chat_request.max_tokens, "max_gen_len": chat_request.max_tokens,
"temperature": chat_request.temperature, "temperature": chat_request.temperature,
"top_p": chat_request.top_p, "top_p": chat_request.top_p,
} }
return json.dumps(args)
def chat(self, chat_request: ChatRequest) -> ChatResponse:
response = self._invoke_model(
args=self._parse_args(chat_request), model_id=chat_request.model
)
response_body = json.loads(response.get("body").read())
if DEBUG:
logger.info("Bedrock response body: " + str(response_body))
message_id = self._generate_message_id()
return self._create_response(
model=chat_request.model,
message=response_body["generation"],
message_id=message_id,
input_tokens=response_body["prompt_token_count"],
output_tokens=response_body["generation_token_count"],
)
def chat_stream(self, chat_request: ChatRequest) -> AsyncIterable[bytes]:
response = self._invoke_model(
args=self._parse_args(chat_request),
model_id=chat_request.model,
with_stream=True,
)
msg_id = ""
chunk_id = 0
for event in response.get("body"):
if DEBUG:
logger.info("Bedrock response chunk: " + str(event))
chunk = json.loads(event["chunk"]["bytes"])
chunk_id += 1
response = self._create_response_stream(
model=chat_request.model,
message_id=msg_id,
chunk_message=chunk["generation"],
finish_reason=chunk["stop_reason"],
)
yield self._stream_response_to_bytes(response)
class MistralModel(BedrockModel): class MistralModel(BedrockModel):
text_field_name = "text"
finish_reason_field_name = "stop_reason"
def _convert_prompt(self, chat_request: ChatRequest) -> str: def _convert_prompt(self, chat_request: ChatRequest) -> str:
"""Create a prompt message follow below example: """Create a prompt message follow below example:
@@ -609,51 +729,54 @@ class MistralModel(BedrockModel):
logger.info("Converted prompt: " + prompt.replace("\n", "\\n")) logger.info("Converted prompt: " + prompt.replace("\n", "\\n"))
return prompt return prompt
def _parse_args(self, chat_request: ChatRequest) -> dict: def compose_request_body(self, chat_request: ChatRequest) -> str:
prompt = self._convert_prompt(chat_request) prompt = self._convert_prompt(chat_request)
return { args = {
"prompt": prompt, "prompt": prompt,
"max_tokens": chat_request.max_tokens, "max_tokens": chat_request.max_tokens,
"temperature": chat_request.temperature, "temperature": chat_request.temperature,
"top_p": chat_request.top_p, "top_p": chat_request.top_p,
} }
return json.dumps(args)
def chat(self, chat_request: ChatRequest) -> ChatResponse: def get_message_text(self, response_body: dict) -> str | None:
return super().get_message_text(response_body["outputs"][0])
response = self._invoke_model( def get_message_finish_reason(self, response_body: dict) -> str | None:
args=self._parse_args(chat_request), model_id=chat_request.model return super().get_message_finish_reason(response_body["outputs"][0])
)
response_body = json.loads(response.get("body").read())
if DEBUG:
logger.info("Bedrock response body: " + str(response_body))
message_id = self._generate_message_id()
return self._create_response( def get_message_usage(self, response_body: dict) -> tuple[int, int]:
model=chat_request.model, # Mistral/Mixtral does not provide info about usage
message=response_body["outputs"][0]["text"], return 0, 0
message_id=message_id,
)
def chat_stream(self, chat_request: ChatRequest) -> AsyncIterable[bytes]:
response = self._invoke_model( class CohereCommandModel(BedrockModel):
args=self._parse_args(chat_request),
model_id=chat_request.model, def _parse_message(self, message) -> dict:
with_stream=True, if message.role not in ["user", "assistant"]:
) raise HTTPException(
msg_id = "" status_code=400, detail="Only user or assistant message is supported"
chunk_id = 0
for event in response.get("body"):
if DEBUG:
logger.info("Bedrock response chunk: " + str(event))
chunk = json.loads(event["chunk"]["bytes"])
chunk_id += 1
response = self._create_response_stream(
model=chat_request.model,
message_id=msg_id,
chunk_message=chunk["outputs"][0]["text"],
finish_reason=chunk["outputs"][0]["stop_reason"],
) )
yield self._stream_response_to_bytes(response) return {
"role": "USER" if message.role == "user" else "CHATBOT",
"message": message.content,
}
def compose_request_body(self, chat_request: ChatRequest) -> str:
messages = chat_request.messages
if messages[-1].role != "user":
raise HTTPException(
status_code=400, detail="Last message should be a valid user message"
)
chat_history = [self._parse_message(msg) for msg in messages[:-1]]
args = {
"message": messages[-1].content,
"chat_history": chat_history,
"max_tokens": chat_request.max_tokens,
"temperature": chat_request.temperature,
"p": chat_request.top_p,
}
return json.dumps(args)
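
For reference, CohereCommandModel.compose_request_body splits the history from the final user turn; a sketch (assuming pydantic coerces the plain-dict messages into the schema's message types, and defaults for unset fields):

    request = ChatRequest(
        model="cohere.command-r-v1:0",
        messages=[
            {"role": "user", "content": "Hi"},
            {"role": "assistant", "content": "Hello, how can I help?"},
            {"role": "user", "content": "Tell me a joke."},
        ],
    )
    body = json.loads(CohereCommandModel().compose_request_body(request))
    # body["message"] == "Tell me a joke."
    # body["chat_history"] == [
    #     {"role": "USER", "message": "Hi"},
    #     {"role": "CHATBOT", "message": "Hello, how can I help?"},
    # ]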
 class BedrockEmbeddingsModel(BaseEmbeddingsModel, ABC):
@@ -673,8 +796,7 @@ class BedrockEmbeddingsModel(BaseEmbeddingsModel, ABC):
                 contentType=self.content_type,
             )
         except bedrock_runtime.exceptions.ValidationException as e:
-            print("Validation Exception")
-            print(e)
+            logger.error("Validation Error: " + str(e))
             raise HTTPException(status_code=400, detail=str(e))
         except Exception as e:
             logger.error(e)
@@ -705,7 +827,6 @@ class BedrockEmbeddingsModel(BaseEmbeddingsModel, ABC):
                 total_tokens=input_tokens + output_tokens,
             ),
         )
-
         if DEBUG:
             logger.info("Proxy response :" + response.model_dump_json())
         return response
@@ -799,24 +920,22 @@ class TitanEmbeddingsModel(BedrockEmbeddingsModel):
 def get_model(model_id: str) -> BedrockModel:
-    model_name = SUPPORTED_BEDROCK_MODELS.get(model_id, "")
     if DEBUG:
-        logger.info("model name is " + model_name)
-    # Not using start_with here in case of complex scenarios.
-    # The downside is to change this everytime for a new model supported.
-    match model_name:
-        case "Claude Instant" | "Claude" | "Claude 3 Sonnet" | "Claude 3 Haiku" | "Claude 3 Opus":
-            return ClaudeModel()
-        case "Llama 2 Chat 13B" | "Llama 2 Chat 70B" | "Llama 3 8B Instruct" | "Llama 3 70B Instruct":
-            return LlamaModel()
-        case "Mistral 7B Instruct" | "Mixtral 8x7B Instruct" | "Mistral Large":
-            return MistralModel()
-        case _:
-            logger.error("Unsupported model id " + model_id)
-            raise HTTPException(
-                status_code=400,
-                detail="Unsupported model id " + model_id,
-            )
+        logger.info("model id is " + model_id)
+    if model_id not in SUPPORTED_BEDROCK_MODELS.keys():
+        logger.error("Unsupported model id " + model_id)
+        raise HTTPException(
+            status_code=400,
+            detail="Unsupported model id " + model_id,
+        )
+    if model_id.startswith("anthropic.claude"):
+        return ClaudeModel()
+    elif model_id.startswith("meta.llama"):
+        return LlamaModel()
+    elif model_id.startswith("mistral.mistral") or model_id.startswith("mistral.mixtral"):
+        return MistralModel()
+    elif model_id.startswith("cohere.command-r"):
+        return CohereCommandModel()


 def get_embeddings_model(model_id: str) -> BedrockEmbeddingsModel:
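
A quick sanity check of the new prefix-based dispatch (the Llama model id below is assumed from the supported-models list; ids must also be present in SUPPORTED_BEDROCK_MODELS to pass the membership check):

    assert isinstance(get_model("cohere.command-r-plus-v1:0"), CohereCommandModel)
    assert isinstance(get_model("meta.llama3-8b-instruct-v1:0"), LlamaModel)  # assumed id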

View File

@@ -79,12 +79,17 @@ class Tool(BaseModel):
     function: Function


+class StreamOptions(BaseModel):
+    include_usage: bool = True
+
+
 class ChatRequest(BaseModel):
     messages: list[SystemMessage | UserMessage | AssistantMessage | ToolMessage]
     model: str
     frequency_penalty: float | None = Field(default=0.0, le=2.0, ge=-2.0)  # Not used
     presence_penalty: float | None = Field(default=0.0, le=2.0, ge=-2.0)  # Not used
     stream: bool | None = False
+    stream_options: StreamOptions | None = None
     temperature: float | None = Field(default=1.0, le=2.0, ge=0.0)
     top_p: float | None = Field(default=1.0, le=1.0, ge=0.0)
     user: str | None = None  # Not used
@@ -138,6 +143,7 @@ class ChatResponse(BaseChatResponse):
 class ChatStreamResponse(BaseChatResponse):
     choices: list[ChoiceDelta]
     object: Literal["chat.completion.chunk"] = "chat.completion.chunk"
+    usage: Usage | None = None


 class EmbeddingsRequest(BaseModel):
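
With the new schema fields, a client opting into the trailing usage chunk would construct a request like this sketch (dict message assumed to be coerced by pydantic; model id from the supported list):

    request = ChatRequest(
        model="cohere.command-r-plus-v1:0",
        messages=[{"role": "user", "content": "Hello"}],
        stream=True,
        stream_options=StreamOptions(include_usage=True),
    )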

View File

@@ -26,6 +26,8 @@ List of Amazon Bedrock models currently supported:
 - mistral.mistral-7b-instruct-v0:2
 - mistral.mixtral-8x7b-instruct-v0:1
 - mistral.mistral-large-2402-v1:0
+- cohere.command-r-v1:0
+- cohere.command-r-plus-v1:0

 # Embeddings

 - cohere.embed-multilingual-v3