Refactor model implementation

This commit is contained in:
Aiden Dai
2024-05-09 16:58:04 +08:00
parent 180c199da9
commit 9f6b334385
4 changed files with 316 additions and 186 deletions

View File

@@ -29,11 +29,17 @@ class BaseChatModel(ABC):
"""Handle a basic chat completion requests with stream response."""
pass
def _generate_message_id(self) -> str:
@staticmethod
def generate_message_id() -> str:
return "chatcmpl-" + str(uuid.uuid4())[:8]
def _stream_response_to_bytes(self, response: ChatStreamResponse) -> bytes:
@staticmethod
def stream_response_to_bytes(
response: ChatStreamResponse | None = None
) -> bytes:
if response:
return "data: {}\n\n".format(response.model_dump_json()).encode("utf-8")
return "data: [DONE]\n\n".encode("utf-8")
class BaseEmbeddingsModel(ABC):
@@ -46,6 +52,3 @@ class BaseEmbeddingsModel(ABC):
def embed(self, embeddings_request: EmbeddingsRequest) -> EmbeddingsResponse:
"""Handle a basic embeddings request."""
pass
def _generate_message_id(self) -> str:
return "embeddings-" + str(uuid.uuid4())[:8]

View File

@@ -2,7 +2,7 @@ import base64
import json
import logging
import re
from abc import ABC
from abc import ABC, abstractmethod
from typing import AsyncIterable, Iterable, Literal
import boto3
@@ -54,6 +54,8 @@ SUPPORTED_BEDROCK_MODELS = {
"mistral.mistral-7b-instruct-v0:2": "Mistral 7B Instruct",
"mistral.mixtral-8x7b-instruct-v0:1": "Mixtral 8x7B Instruct",
"mistral.mistral-large-2402-v1:0": "Mistral Large",
"cohere.command-r-v1:0": "Command R",
"cohere.command-r-plus-v1:0": "Command R+",
}
SUPPORTED_BEDROCK_EMBEDDING_MODELS = {
@@ -72,28 +74,157 @@ class BedrockModel(BaseChatModel, ABC):
accept = "application/json"
content_type = "application/json"
def _invoke_model(self, args: dict, model_id: str, with_stream: bool = False):
body = json.dumps(args)
# Default field name to get the response message
text_field_name = "text"
# Default field name to get the response finish reason
finish_reason_field_name = "finish_reason"
@abstractmethod
def compose_request_body(self, chat_request: ChatRequest) -> str:
"""Since the request body to Bedrock varies,
each model should implement this to compose the request body.
:param chat_request:
:return: request body as a string
"""
raise NotImplementedError()
def chat(self, chat_request: ChatRequest) -> ChatResponse:
"""Default implementation for Chat API."""
if DEBUG:
logger.info("Raw request: " + chat_request.model_dump_json())
request_body = self.compose_request_body(chat_request)
if DEBUG:
logger.info("Bedrock request: " + request_body)
response = self.invoke_model(
request_body=request_body,
model_id=chat_request.model,
)
message_id = self.generate_message_id()
return self.parse_response(chat_request, response, message_id)
def chat_stream(self, chat_request: ChatRequest) -> AsyncIterable[bytes]:
"""Default implementation for Chat Stream API"""
if DEBUG:
logger.info("Raw request: " + chat_request.model_dump_json())
request_body = self.compose_request_body(chat_request)
response = self.invoke_model(
request_body=request_body,
model_id=chat_request.model,
with_stream=True,
)
message_id = self.generate_message_id()
for stream_response in self.parse_stream_response(
chat_request, response, message_id
):
if stream_response.choices:
yield self.stream_response_to_bytes(stream_response)
elif (
chat_request.stream_options
and chat_request.stream_options.include_usage
):
# An empty choices for Usage as per OpenAI doc below:
# if you set stream_options: {"include_usage": true}.
# an additional chunk will be streamed before the data: [DONE] message.
# The usage field on this chunk shows the token usage statistics for the entire request,
# and the choices field will always be an empty array.
# All other chunks will also include a usage field, but with a null value.
yield self.stream_response_to_bytes(stream_response)
# return an [DONE] message at the end.
yield self.stream_response_to_bytes()
def get_message_text(self, response_body: dict) -> str | None:
"""Default func to get the response message.
Ideally, only the field name should be changed."""
return response_body.get(self.text_field_name)
def get_message_finish_reason(self, response_body: dict) -> str | None:
"""Default func to get the finish message.
Ideally, only the field name should be changed."""
return response_body.get(self.finish_reason_field_name)
def get_message_usage(self, response_body: dict) -> tuple[int, int]:
"""Default func to get the finish message.
Can be overridden in the detail model for complex cases."""
input_tokens = int(response_body.get("prompt_token_count", "0"))
output_tokens = int(response_body.get("generation_token_count", "0"))
return input_tokens, output_tokens
def parse_response(
self, chat_request: ChatRequest, service_response: dict, message_id: str
) -> ChatResponse:
response_body = json.loads(service_response.get("body").read())
if DEBUG:
logger.info("Bedrock response body: " + str(response_body))
input_tokens, output_tokens = self.get_message_usage(response_body)
return self.create_response(
model=chat_request.model,
message_id=message_id,
message=self.get_message_text(response_body),
finish_reason=self.get_message_finish_reason(response_body),
input_tokens=input_tokens,
output_tokens=output_tokens,
)
def parse_stream_response(
self, chat_request: ChatRequest, service_response: dict, message_id: str
) -> Iterable[ChatStreamResponse]:
chunk_id = 0
for event in service_response.get("body"):
if DEBUG:
logger.info("Bedrock response chunk: " + str(event))
chunk = json.loads(event["chunk"]["bytes"])
chunk_id += 1
response = self.create_response_stream(
model=chat_request.model,
message_id=message_id,
chunk_message=self.get_message_text(chunk),
finish_reason=self.get_message_finish_reason(chunk),
)
yield response
# Get the usage for streaming response anyway.
if "amazon-bedrock-invocationMetrics" in chunk:
yield self.create_response_stream(
model=chat_request.model,
message_id=message_id,
input_tokens=chunk["amazon-bedrock-invocationMetrics"][
"inputTokenCount"
],
output_tokens=chunk["amazon-bedrock-invocationMetrics"][
"outputTokenCount"
],
)
def invoke_model(self, request_body: str, model_id: str, with_stream: bool = False):
if DEBUG:
logger.info("Invoke Bedrock Model: " + model_id)
logger.info("Bedrock request body: " + body)
logger.info("Bedrock request body: " + request_body)
try:
if with_stream:
return bedrock_runtime.invoke_model_with_response_stream(
body=body,
body=request_body,
modelId=model_id,
accept=self.accept,
contentType=self.content_type,
)
return bedrock_runtime.invoke_model(
body=body,
body=request_body,
modelId=model_id,
accept=self.accept,
contentType=self.content_type,
)
except bedrock_runtime.exceptions.ValidationException as e:
print("Validation Exception")
print(e)
logger.error("Validation Error: " + str(e))
raise HTTPException(status_code=400, detail=str(e))
except Exception as e:
logger.error(e)
@@ -101,8 +232,7 @@ class BedrockModel(BaseChatModel, ABC):
@staticmethod
def merge_message(messages: list[dict]) -> list[dict]:
"""Merge the request messages with the same role as previous message
"""
"""Merge the request messages with the same role as previous message"""
merged_messages = []
prev_role = None
merged_content = ""
@@ -110,10 +240,11 @@ class BedrockModel(BaseChatModel, ABC):
for message in messages:
role = message["role"]
content = message["content"]
if role != prev_role or isinstance(content, list):
if prev_role:
merged_messages.append({"role": prev_role, "content": merged_content})
merged_messages.append(
{"role": prev_role, "content": merged_content}
)
if isinstance(content, str):
merged_content = content
prev_role = role
@@ -132,7 +263,7 @@ class BedrockModel(BaseChatModel, ABC):
return merged_messages
@staticmethod
def _create_response(
def create_response(
model: str,
message_id: str,
message: str | None = None,
@@ -144,7 +275,8 @@ class BedrockModel(BaseChatModel, ABC):
response = ChatResponse(
id=message_id,
model=model,
choices=[Choice(
choices=[
Choice(
index=0,
message=ChatResponseMessage(
role="assistant",
@@ -152,7 +284,8 @@ class BedrockModel(BaseChatModel, ABC):
content=message,
),
finish_reason="tool_calls" if tools else finish_reason,
)],
)
],
usage=Usage(
prompt_tokens=input_tokens,
completion_tokens=output_tokens,
@@ -164,25 +297,41 @@ class BedrockModel(BaseChatModel, ABC):
return response
@staticmethod
def _create_response_stream(
def create_response_stream(
model: str,
message_id: str,
chunk_message: str | None = None,
finish_reason: str | None = None,
tools: list[ToolCall] | None = None,
input_tokens: int = 0,
output_tokens: int = 0,
) -> ChatStreamResponse:
if chunk_message or finish_reason or tools:
response = ChatStreamResponse(
id=message_id,
model=model,
choices=[ChoiceDelta(
choices=[
ChoiceDelta(
index=0,
delta=ChatResponseMessage(
role="assistant",
tool_calls=tools,
content=chunk_message,
),
finish_reason="tool_calls" if tools else finish_reason,
)],
finish_reason=finish_reason,
)
],
)
else:
response = ChatStreamResponse(
id=message_id,
model=model,
choices=[],
usage=Usage(
prompt_tokens=input_tokens,
completion_tokens=output_tokens,
total_tokens=input_tokens + output_tokens,
),
)
if DEBUG:
logger.info("Proxy response :" + response.model_dump_json())
@@ -201,7 +350,7 @@ Please think if you need to use a tool or not for user's question, you must:
{{"name": $TOOL_NAME, "arguments": {{"$PARAMETER_NAME": "$PARAMETER_VALUE", ...}}}}
3. If no tools is needed, respond with normal text."""
def _parse_args(self, chat_request: ChatRequest) -> dict:
def compose_request_body(self, chat_request: ChatRequest) -> str:
args = {
"anthropic_version": self.anthropic_version,
"max_tokens": chat_request.max_tokens,
@@ -239,8 +388,8 @@ Please think if you need to use a tool or not for user's question, you must:
{
"role": "user",
"content": "[Tool result with matching id `{}` of `{}`] ".format(
message.tool_call_id,
message.content),
message.tool_call_id, message.content
),
}
)
else:
@@ -253,74 +402,86 @@ Please think if you need to use a tool or not for user's question, you must:
[tool.function.model_dump() for tool in chat_request.tools]
)
system_prompt += self.tool_prompt.format(tools=tools_str)
converted_messages.append({
'role': 'assistant',
'content': '<tool>'
})
args["stop_sequences"] = ['</function>']
converted_messages.append({"role": "assistant", "content": "<tool>"})
args["stop_sequences"] = ["</function>"]
args["messages"] = self.merge_message(converted_messages)
if system_prompt:
if DEBUG:
logger.info("System Prompt: " + system_prompt)
args["system"] = system_prompt
return args
return json.dumps(args)
def chat(self, chat_request: ChatRequest) -> ChatResponse:
if DEBUG:
logger.info("Raw request: " + chat_request.model_dump_json())
response = self._invoke_model(
args=self._parse_args(chat_request), model_id=chat_request.model
)
response_body = json.loads(response.get("body").read())
def parse_response(
self, chat_request: ChatRequest, service_response: dict, message_id: str
) -> ChatResponse:
response_body = json.loads(service_response.get("body").read())
if DEBUG:
logger.info("Bedrock response body: " + str(response_body))
message = response_body["content"][0]["text"]
finish_reason = response_body["stop_reason"]
tools = None
if chat_request.tools:
if message.startswith("Y</tool>"):
tools = self._parse_tool_message(message)
message = None
finish_reason = "tool_calls"
elif message.startswith("N</tool>"):
message = message[8:].lstrip("\n")
return self._create_response(
return self.create_response(
model=chat_request.model,
message_id=response_body["id"],
message_id=message_id,
message=message,
tools=tools,
finish_reason=response_body["stop_reason"],
finish_reason=finish_reason,
input_tokens=response_body["usage"]["input_tokens"],
output_tokens=response_body["usage"]["output_tokens"],
)
def chat_stream(self, chat_request: ChatRequest) -> AsyncIterable[bytes]:
if DEBUG:
logger.info("Raw request: " + chat_request.model_dump_json())
response = self._invoke_model(
args=self._parse_args(chat_request),
model_id=chat_request.model,
with_stream=True,
)
msg_id = ""
def parse_stream_response(
self, chat_request: ChatRequest, service_response: dict, message_id: str
) -> Iterable[ChatStreamResponse]:
chunk_id = 0
tool_message = ""
first_token = True
index = 0
for event in response.get("body"):
for event in service_response.get("body"):
if DEBUG:
logger.info("Bedrock response chunk: " + str(event))
chunk = json.loads(event["chunk"]["bytes"])
chunk_id += 1
if chunk["type"] == "message_start":
msg_id = chunk["message"]["id"]
continue
if chunk["type"] == "message_stop":
# Get the usage for streaming response anyway.
if "amazon-bedrock-invocationMetrics" in chunk:
yield self.create_response_stream(
model=chat_request.model,
message_id=message_id,
input_tokens=chunk["amazon-bedrock-invocationMetrics"][
"inputTokenCount"
],
output_tokens=chunk["amazon-bedrock-invocationMetrics"][
"outputTokenCount"
],
)
break
if chunk["type"] == "message_delta":
elif chunk["type"] == "message_delta":
chunk_message = ""
finish_reason = chunk["delta"]["stop_reason"]
# Send tool message first if any.
if chat_request.tools and tool_message:
tools = self._parse_tool_message(tool_message)
finish_reason = "tool_calls"
response = self.create_response_stream(
model=chat_request.model,
message_id=message_id,
tools=tools,
)
yield response
elif chunk["type"] == "content_block_delta":
chunk_message = chunk["delta"]["text"]
finish_reason = None
@@ -343,22 +504,14 @@ Please think if you need to use a tool or not for user's question, you must:
first_token = False
else:
continue
response = self._create_response_stream(
response = self.create_response_stream(
model=chat_request.model,
message_id=msg_id,
message_id=message_id,
chunk_message=chunk_message,
finish_reason=finish_reason,
)
yield self._stream_response_to_bytes(response)
if chat_request.tools and tool_message:
tools = self._parse_tool_message(tool_message)
response = self._create_response_stream(
model=chat_request.model,
message_id=msg_id,
tools=tools,
)
yield self._stream_response_to_bytes(response)
yield response
def _parse_tool_message(self, tool_message: str) -> list[ToolCall]:
if DEBUG:
@@ -367,9 +520,7 @@ Please think if you need to use a tool or not for user's question, you must:
tool_messages = tool_message[tool_message.rindex("<function>") + len("<function>"):]
function = json.loads(tool_messages.replace("\n", " "))
args = json.dumps(function.get("arguments", {}))
function = ResponseFunction(
name=function["name"], arguments=args
)
function = ResponseFunction(name=function["name"], arguments=args)
return [
ToolCall(
@@ -400,7 +551,7 @@ Please think if you need to use a tool or not for user's question, you must:
# Check if the request was successful
if response.status_code == 200:
content_type = response.headers.get('Content-Type')
content_type = response.headers.get("Content-Type")
if not content_type.startswith("image"):
content_type = "image/jpeg"
# Get the image content
@@ -437,6 +588,8 @@ Please think if you need to use a tool or not for user's question, you must:
class LlamaModel(BedrockModel):
text_field_name = "generation"
finish_reason_field_name = "stop_reason"
@staticmethod
def create_llama3_prompt(chat_request: ChatRequest) -> str:
@@ -460,7 +613,9 @@ class LlamaModel(BedrockModel):
prompt_lines = []
for msg in chat_request.messages:
prompt_lines.append(f"<|start_header_id|>{msg.role}<|end_header_id|>\n\n{msg.content}<|eot_id|>")
prompt_lines.append(
f"<|start_header_id|>{msg.role}<|end_header_id|>\n\n{msg.content}<|eot_id|>"
)
prompt_lines.append(f"<|start_header_id|>assistant<|end_header_id|>\n\n")
prompt = bos_token + "".join(prompt_lines)
if DEBUG:
@@ -512,59 +667,24 @@ class LlamaModel(BedrockModel):
logger.info("Converted prompt: " + prompt.replace("\n", "\\n"))
return prompt
def _parse_args(self, chat_request: ChatRequest) -> dict:
def compose_request_body(self, chat_request: ChatRequest) -> str:
if chat_request.model.startswith("meta.llama2"):
prompt = self.create_llama2_prompt(chat_request)
else:
prompt = self.create_llama3_prompt(chat_request)
# Currently, there is no way to set stop sequence for Llama 3 models.
return {
args = {
"prompt": prompt,
"max_gen_len": chat_request.max_tokens,
"temperature": chat_request.temperature,
"top_p": chat_request.top_p,
}
def chat(self, chat_request: ChatRequest) -> ChatResponse:
response = self._invoke_model(
args=self._parse_args(chat_request), model_id=chat_request.model
)
response_body = json.loads(response.get("body").read())
if DEBUG:
logger.info("Bedrock response body: " + str(response_body))
message_id = self._generate_message_id()
return self._create_response(
model=chat_request.model,
message=response_body["generation"],
message_id=message_id,
input_tokens=response_body["prompt_token_count"],
output_tokens=response_body["generation_token_count"],
)
def chat_stream(self, chat_request: ChatRequest) -> AsyncIterable[bytes]:
response = self._invoke_model(
args=self._parse_args(chat_request),
model_id=chat_request.model,
with_stream=True,
)
msg_id = ""
chunk_id = 0
for event in response.get("body"):
if DEBUG:
logger.info("Bedrock response chunk: " + str(event))
chunk = json.loads(event["chunk"]["bytes"])
chunk_id += 1
response = self._create_response_stream(
model=chat_request.model,
message_id=msg_id,
chunk_message=chunk["generation"],
finish_reason=chunk["stop_reason"],
)
yield self._stream_response_to_bytes(response)
return json.dumps(args)
class MistralModel(BedrockModel):
text_field_name = "text"
finish_reason_field_name = "stop_reason"
def _convert_prompt(self, chat_request: ChatRequest) -> str:
"""Create a prompt message follow below example:
@@ -609,51 +729,54 @@ class MistralModel(BedrockModel):
logger.info("Converted prompt: " + prompt.replace("\n", "\\n"))
return prompt
def _parse_args(self, chat_request: ChatRequest) -> dict:
def compose_request_body(self, chat_request: ChatRequest) -> str:
prompt = self._convert_prompt(chat_request)
return {
args = {
"prompt": prompt,
"max_tokens": chat_request.max_tokens,
"temperature": chat_request.temperature,
"top_p": chat_request.top_p,
}
return json.dumps(args)
def chat(self, chat_request: ChatRequest) -> ChatResponse:
def get_message_text(self, response_body: dict) -> str | None:
return super().get_message_text(response_body["outputs"][0])
response = self._invoke_model(
args=self._parse_args(chat_request), model_id=chat_request.model
)
response_body = json.loads(response.get("body").read())
if DEBUG:
logger.info("Bedrock response body: " + str(response_body))
message_id = self._generate_message_id()
def get_message_finish_reason(self, response_body: dict) -> str | None:
return super().get_message_finish_reason(response_body["outputs"][0])
return self._create_response(
model=chat_request.model,
message=response_body["outputs"][0]["text"],
message_id=message_id,
)
def get_message_usage(self, response_body: dict) -> tuple[int, int]:
# Mistral/Mixtral does not provide info about usage
return 0, 0
def chat_stream(self, chat_request: ChatRequest) -> AsyncIterable[bytes]:
response = self._invoke_model(
args=self._parse_args(chat_request),
model_id=chat_request.model,
with_stream=True,
class CohereCommandModel(BedrockModel):
def _parse_message(self, message) -> dict:
if message.role not in ["user", "assistant"]:
raise HTTPException(
status_code=400, detail="Only user or assistant message is supported"
)
msg_id = ""
chunk_id = 0
for event in response.get("body"):
if DEBUG:
logger.info("Bedrock response chunk: " + str(event))
chunk = json.loads(event["chunk"]["bytes"])
chunk_id += 1
response = self._create_response_stream(
model=chat_request.model,
message_id=msg_id,
chunk_message=chunk["outputs"][0]["text"],
finish_reason=chunk["outputs"][0]["stop_reason"],
return {
"role": "USER" if message.role == "user" else "CHATBOT",
"message": message.content,
}
def compose_request_body(self, chat_request: ChatRequest) -> str:
messages = chat_request.messages
if messages[-1].role != "user":
raise HTTPException(
status_code=400, detail="Last message should be a valid user message"
)
yield self._stream_response_to_bytes(response)
chat_history = [self._parse_message(msg) for msg in messages[:-1]]
args = {
"message": messages[-1].content,
"chat_history": chat_history,
"max_tokens": chat_request.max_tokens,
"temperature": chat_request.temperature,
"p": chat_request.top_p,
}
return json.dumps(args)
class BedrockEmbeddingsModel(BaseEmbeddingsModel, ABC):
@@ -673,8 +796,7 @@ class BedrockEmbeddingsModel(BaseEmbeddingsModel, ABC):
contentType=self.content_type,
)
except bedrock_runtime.exceptions.ValidationException as e:
print("Validation Exception")
print(e)
logger.error("Validation Error: " + str(e))
raise HTTPException(status_code=400, detail=str(e))
except Exception as e:
logger.error(e)
@@ -705,7 +827,6 @@ class BedrockEmbeddingsModel(BaseEmbeddingsModel, ABC):
total_tokens=input_tokens + output_tokens,
),
)
if DEBUG:
logger.info("Proxy response :" + response.model_dump_json())
return response
@@ -799,24 +920,22 @@ class TitanEmbeddingsModel(BedrockEmbeddingsModel):
def get_model(model_id: str) -> BedrockModel:
model_name = SUPPORTED_BEDROCK_MODELS.get(model_id, "")
if DEBUG:
logger.info("model name is " + model_name)
# Not using start_with here in case of complex scenarios.
# The downside is to change this everytime for a new model supported.
match model_name:
case "Claude Instant" | "Claude" | "Claude 3 Sonnet" | "Claude 3 Haiku" | "Claude 3 Opus":
return ClaudeModel()
case "Llama 2 Chat 13B" | "Llama 2 Chat 70B" | "Llama 3 8B Instruct" | "Llama 3 70B Instruct":
return LlamaModel()
case "Mistral 7B Instruct" | "Mixtral 8x7B Instruct" | "Mistral Large":
return MistralModel()
case _:
logger.info("model id is " + model_id)
if model_id not in SUPPORTED_BEDROCK_MODELS.keys():
logger.error("Unsupported model id " + model_id)
raise HTTPException(
status_code=400,
detail="Unsupported model id " + model_id,
)
if model_id.startswith("anthropic.claude"):
return ClaudeModel()
elif model_id.startswith("meta.llama"):
return LlamaModel()
elif model_id.startswith("mistral.mistral") or model_id.startswith("mistral.mixtral"):
return MistralModel()
elif model_id.startswith("cohere.command-r"):
return CohereCommandModel()
def get_embeddings_model(model_id: str) -> BedrockEmbeddingsModel:

View File

@@ -79,12 +79,17 @@ class Tool(BaseModel):
function: Function
class StreamOptions(BaseModel):
include_usage: bool = True
class ChatRequest(BaseModel):
messages: list[SystemMessage | UserMessage | AssistantMessage | ToolMessage]
model: str
frequency_penalty: float | None = Field(default=0.0, le=2.0, ge=-2.0) # Not used
presence_penalty: float | None = Field(default=0.0, le=2.0, ge=-2.0) # Not used
stream: bool | None = False
stream_options: StreamOptions | None = None
temperature: float | None = Field(default=1.0, le=2.0, ge=0.0)
top_p: float | None = Field(default=1.0, le=1.0, ge=0.0)
user: str | None = None # Not used
@@ -138,6 +143,7 @@ class ChatResponse(BaseChatResponse):
class ChatStreamResponse(BaseChatResponse):
choices: list[ChoiceDelta]
object: Literal["chat.completion.chunk"] = "chat.completion.chunk"
usage: Usage | None = None
class EmbeddingsRequest(BaseModel):

View File

@@ -26,6 +26,8 @@ List of Amazon Bedrock models currently supported:
- mistral.mistral-7b-instruct-v0:2
- mistral.mixtral-8x7b-instruct-v0:1
- mistral.mistral-large-2402-v1:0
- cohere.command-r-v1:0
- cohere.command-r-plus-v1:0
# Embeddings
- cohere.embed-multilingual-v3