diff --git a/src/api/models/__init__.py b/src/api/models/__init__.py index eb89239..e69de29 100644 --- a/src/api/models/__init__.py +++ b/src/api/models/__init__.py @@ -1,7 +0,0 @@ -from api.models.bedrock import ( - ClaudeModel, - SUPPORTED_BEDROCK_MODELS, - SUPPORTED_BEDROCK_EMBEDDING_MODELS, - get_model, - get_embeddings_model, -) diff --git a/src/api/models/base.py b/src/api/models/base.py index 9dc74dd..9d9db7f 100644 --- a/src/api/models/base.py +++ b/src/api/models/base.py @@ -1,3 +1,4 @@ +import time import uuid from abc import ABC, abstractmethod from typing import AsyncIterable @@ -19,6 +20,14 @@ class BaseChatModel(ABC): Currently, only Bedrock model is supported, but may be used for SageMaker models if needed. """ + def list_models(self) -> list[str]: + """Return a list of supported models""" + return [] + + def validate(self, chat_request: ChatRequest): + """Validate chat completion requests.""" + pass + @abstractmethod def chat(self, chat_request: ChatRequest) -> ChatResponse: """Handle a basic chat completion requests.""" @@ -38,7 +47,11 @@ class BaseChatModel(ABC): response: ChatStreamResponse | None = None ) -> bytes: if response: - return "data: {}\n\n".format(response.model_dump_json()).encode("utf-8") + # to populate other fields when using exclude_unset=True + response.system_fingerprint = "fp" + response.object = "chat.completion.chunk" + response.created = int(time.time()) + return "data: {}\n\n".format(response.model_dump_json(exclude_unset=True)).encode("utf-8") return "data: [DONE]\n\n".encode("utf-8") diff --git a/src/api/models/bedrock.py b/src/api/models/bedrock.py index 11f4a90..da0f188 100644 --- a/src/api/models/bedrock.py +++ b/src/api/models/bedrock.py @@ -2,7 +2,7 @@ import base64 import json import logging import re -from abc import ABC, abstractmethod +from abc import ABC from typing import AsyncIterable, Iterable, Literal import boto3 @@ -20,11 +20,15 @@ from api.schema import ( ChatResponseMessage, Usage, ChatStreamResponse, - ChoiceDelta, ImageContent, TextContent, - ResponseFunction, ToolCall, + ChoiceDelta, + UserMessage, + AssistantMessage, + ToolMessage, + Function, + ResponseFunction, # Embeddings EmbeddingsRequest, EmbeddingsResponse, @@ -40,24 +44,6 @@ bedrock_runtime = boto3.client( region_name=AWS_REGION, ) -SUPPORTED_BEDROCK_MODELS = { - "anthropic.claude-instant-v1": "Claude Instant", - "anthropic.claude-v2:1": "Claude", - "anthropic.claude-v2": "Claude", - "anthropic.claude-3-sonnet-20240229-v1:0": "Claude 3 Sonnet", - "anthropic.claude-3-opus-20240229-v1:0": "Claude 3 Opus", - "anthropic.claude-3-haiku-20240307-v1:0": "Claude 3 Haiku", - "meta.llama2-13b-chat-v1": "Llama 2 Chat 13B", - "meta.llama2-70b-chat-v1": "Llama 2 Chat 70B", - "meta.llama3-8b-instruct-v1:0": "Llama 3 8B Instruct", - "meta.llama3-70b-instruct-v1:0": "Llama 3 70B Instruct", - "mistral.mistral-7b-instruct-v0:2": "Mistral 7B Instruct", - "mistral.mixtral-8x7b-instruct-v0:1": "Mixtral 8x7B Instruct", - "mistral.mistral-large-2402-v1:0": "Mistral Large", - "cohere.command-r-v1:0": "Command R", - "cohere.command-r-plus-v1:0": "Command R+", -} - SUPPORTED_BEDROCK_EMBEDDING_MODELS = { "cohere.embed-multilingual-v3": "Cohere Embed Multilingual", "cohere.embed-english-v3": "Cohere Embed English", @@ -69,58 +55,199 @@ SUPPORTED_BEDROCK_EMBEDDING_MODELS = { ENCODER = tiktoken.get_encoding("cl100k_base") -# https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters.html -class BedrockModel(BaseChatModel, ABC): - accept = "application/json" - content_type = 
"application/json" +class BedrockModel(BaseChatModel): + # https://docs.aws.amazon.com/bedrock/latest/userguide/conversation-inference.html#conversation-inference-supported-models-features + _supported_models = { + "amazon.titan-text-premier-v1:0": { + "system": True, + "multimodal": False, + "tool_call": False, + "stream_tool_call": False, + }, + "anthropic.claude-instant-v1": { + "system": True, + "multimodal": False, + "tool_call": False, + "stream_tool_call": False, + }, + "anthropic.claude-v2:1": { + "system": True, + "multimodal": False, + "tool_call": False, + "stream_tool_call": False, + }, + "anthropic.claude-v2": { + "system": True, + "multimodal": False, + "tool_call": False, + "stream_tool_call": False, + }, + "anthropic.claude-3-sonnet-20240229-v1:0": { + "system": True, + "multimodal": True, + "tool_call": True, + "stream_tool_call": True, + }, + "anthropic.claude-3-opus-20240229-v1:0": { + "system": True, + "multimodal": True, + "tool_call": True, + "stream_tool_call": True, + }, + "anthropic.claude-3-haiku-20240307-v1:0": { + "system": True, + "multimodal": True, + "tool_call": True, + "stream_tool_call": True, + }, + "meta.llama2-13b-chat-v1": { + "system": True, + "multimodal": False, + "tool_call": False, + "stream_tool_call": False, + }, + "meta.llama2-70b-chat-v1": { + "system": True, + "multimodal": False, + "tool_call": False, + "stream_tool_call": False, + }, + "meta.llama3-8b-instruct-v1:0": { + "system": True, + "multimodal": False, + "tool_call": False, + "stream_tool_call": False, + }, + "meta.llama3-70b-instruct-v1:0": { + "system": True, + "multimodal": False, + "tool_call": False, + "stream_tool_call": False, + }, + "mistral.mistral-7b-instruct-v0:2": { + "system": False, + "multimodal": False, + "tool_call": False, + "stream_tool_call": False, + }, + "mistral.mixtral-8x7b-instruct-v0:1": { + "system": False, + "multimodal": False, + "tool_call": False, + "stream_tool_call": False, + }, + "mistral.mistral-small-2402-v1:0": { + "system": True, + "multimodal": False, + "tool_call": False, + "stream_tool_call": False, + }, + "mistral.mistral-large-2402-v1:0": { + "system": True, + "multimodal": False, + "tool_call": True, + "stream_tool_call": False, + }, + "cohere.command-r-v1:0": { + "system": True, + "multimodal": False, + "tool_call": True, + "stream_tool_call": False, + }, + "cohere.command-r-plus-v1:0": { + "system": True, + "multimodal": False, + "tool_call": True, + "stream_tool_call": False, + }, + } - # Default field name to get the response message - text_field_name = "text" + def list_models(self) -> list[str]: + return list(self._supported_models.keys()) - # Default field name to get the response finish reason - finish_reason_field_name = "finish_reason" + def validate(self, chat_request: ChatRequest): + """Perform basic validation on requests""" + error = "" + # check if model is supported + if chat_request.model not in self._supported_models.keys(): + error = f"Unsupported model {chat_request.model}, please use models API to get a list of supported models" - @abstractmethod - def compose_request_body(self, chat_request: ChatRequest) -> str: - """Since the request body to Bedrock varies, - each model should implement this to compose the request body. 
+        # check if tool call is supported
+        elif chat_request.tools and not self._is_tool_call_supported(chat_request.model, stream=chat_request.stream):
+            tool_call_info = "Tool call with streaming" if chat_request.stream else "Tool call"
+            error = f"{tool_call_info} is currently not supported by {chat_request.model}"

-        :param chat_request:
-        :return: request body as a string
-        """
-        raise NotImplementedError()
+        # check if the system prompt is supported
+        # better to raise an explicit error than to silently ignore it
+        elif not self._is_system_prompt_supported(chat_request.model):
+            error = f"System message is currently not supported by {chat_request.model}"
+
+        if error:
+            raise HTTPException(
+                status_code=400,
+                detail=error,
+            )
+
+    def _invoke_bedrock(self, chat_request: ChatRequest, stream=False):
+        """Common logic for invoking Bedrock models."""
+        if DEBUG:
+            logger.info("Raw request: " + chat_request.model_dump_json())
+
+        # convert OpenAI chat request to Bedrock SDK request
+        args = self._parse_request(chat_request)
+        if DEBUG:
+            logger.info("Bedrock request: " + json.dumps(args))
+
+        try:
+            if stream:
+                response = bedrock_runtime.converse_stream(**args)
+            else:
+                response = bedrock_runtime.converse(**args)
+        except bedrock_runtime.exceptions.ValidationException as e:
+            logger.error("Validation Error: " + str(e))
+            raise HTTPException(status_code=400, detail=str(e))
+        except Exception as e:
+            logger.error(e)
+            raise HTTPException(status_code=500, detail=str(e))
+        return response

     def chat(self, chat_request: ChatRequest) -> ChatResponse:
         """Default implementation for Chat API."""
-        if DEBUG:
-            logger.info("Raw request: " + chat_request.model_dump_json())
-        request_body = self.compose_request_body(chat_request)
-        if DEBUG:
-            logger.info("Bedrock request: " + request_body)
-
-        response = self.invoke_model(
-            request_body=request_body,
-            model_id=chat_request.model,
-        )
         message_id = self.generate_message_id()
-        return self.parse_response(chat_request, response, message_id)
+        response = self._invoke_bedrock(chat_request)
+
+        output_message = response["output"]["message"]
+        input_tokens = response["usage"]["inputTokens"]
+        output_tokens = response["usage"]["outputTokens"]
+        finish_reason = response["stopReason"]
+
+        chat_response = self._create_response(
+            model=chat_request.model,
+            message_id=message_id,
+            content=output_message["content"],
+            finish_reason=finish_reason,
+            input_tokens=input_tokens,
+            output_tokens=output_tokens,
+        )
+        if DEBUG:
+            logger.info("Proxy response: " + chat_response.model_dump_json())
+        return chat_response

     def chat_stream(self, chat_request: ChatRequest) -> AsyncIterable[bytes]:
         """Default implementation for Chat Stream API"""
-        if DEBUG:
-            logger.info("Raw request: " + chat_request.model_dump_json())
-        request_body = self.compose_request_body(chat_request)
-        response = self.invoke_model(
-            request_body=request_body,
-            model_id=chat_request.model,
-            with_stream=True,
-        )
-
+        response = self._invoke_bedrock(chat_request, stream=True)
         message_id = self.generate_message_id()
-        for stream_response in self.parse_stream_response(
-            chat_request, response, message_id
-        ):
+
+        stream = response.get("stream")
+        for chunk in stream:
+            stream_response = self._create_response_stream(
+                model_id=chat_request.model, message_id=message_id, chunk=chunk
+            )
+            if not stream_response:
+                continue
+            if DEBUG:
+                logger.info("Proxy response: " + stream_response.model_dump_json())
             if stream_response.choices:
                 yield self.stream_response_to_bytes(stream_response)
             elif (
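To illustrate how the `_supported_models` capability table gates these validation checks, here is a quick sketch (illustrative only; it reads the private table directly, which application code would not normally do):

    model = BedrockModel()
    # Claude 3 models support tool calls on both streaming and non-streaming paths
    sonnet = model._supported_models["anthropic.claude-3-sonnet-20240229-v1:0"]
    assert sonnet["tool_call"] and sonnet["stream_tool_call"]
    # Mistral Large advertises tool calls, but not while streaming,
    # so validate() rejects a streaming request that carries tools
    large = model._supported_models["mistral.mistral-large-2402-v1:0"]
    assert large["tool_call"] and not large["stream_tool_call"]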
@@ -134,156 +261,169 @@ class BedrockModel(BaseChatModel, ABC):
             # and the choices field will always be an empty array.
             # All other chunks will also include a usage field, but with a null value.
                 yield self.stream_response_to_bytes(stream_response)
+        # return a [DONE] message at the end.
         yield self.stream_response_to_bytes()

-    def get_message_text(self, response_body: dict) -> str | None:
-        """Default func to get the response message.
+    def _parse_system_prompts(self, chat_request: ChatRequest) -> list[dict[str, str]]:
+        """Create system prompts.
+        Note that not all models support system prompts.

-        Ideally, only the field name should be changed."""
-        return response_body.get(self.text_field_name)
+        example output: [{"text" : system_prompt}]

-    def get_message_finish_reason(self, response_body: dict) -> str | None:
-        """Default func to get the finish message.
+        See example:
+        https://docs.aws.amazon.com/bedrock/latest/userguide/conversation-inference.html#message-inference-examples
+        """

-        Ideally, only the field name should be changed."""
-        return response_body.get(self.finish_reason_field_name)
+        system_prompts = []
+        for message in chat_request.messages:
+            if message.role != "system":
+                # ignore non-system messages here
+                continue
+            assert isinstance(message.content, str)
+            system_prompts.append({"text": message.content})

-    def get_message_usage(self, response_body: dict) -> tuple[int, int]:
-        """Default func to get the usage.
+        return system_prompts

-        Can be overridden in the detail model for complex cases."""
-        input_tokens = int(response_body.get("prompt_token_count", "0"))
-        output_tokens = int(response_body.get("generation_token_count", "0"))
-        return input_tokens, output_tokens
+    def _parse_messages(self, chat_request: ChatRequest) -> list[dict]:
+        """
+        The Converse API only supports user and assistant messages.

-    def parse_response(
-        self, chat_request: ChatRequest, service_response: dict, message_id: str
-    ) -> ChatResponse:
-        response_body = json.loads(service_response.get("body").read())
-        if DEBUG:
-            logger.info("Bedrock response body: " + str(response_body))
+        example output: [{
+            "role": "user",
+            "content": [{"text": input_text}]
+        }]

-        input_tokens, output_tokens = self.get_message_usage(response_body)
-        return self.create_response(
-            model=chat_request.model,
-            message_id=message_id,
-            message=self.get_message_text(response_body),
-            finish_reason=self.get_message_finish_reason(response_body),
-            input_tokens=input_tokens,
-            output_tokens=output_tokens,
-        )
-
-    def parse_stream_response(
-        self, chat_request: ChatRequest, service_response: dict, message_id: str
-    ) -> Iterable[ChatStreamResponse]:
-
-        chunk_id = 0
-        for event in service_response.get("body"):
-            if DEBUG:
-                logger.info("Bedrock response chunk: " + str(event))
-            chunk = json.loads(event["chunk"]["bytes"])
-            chunk_id += 1
-
-            response = self.create_response_stream(
-                model=chat_request.model,
-                message_id=message_id,
-                chunk_message=self.get_message_text(chunk),
-                finish_reason=self.get_message_finish_reason(chunk),
-            )
-            yield response
-            # Get the usage for streaming response anyway.
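As a reference for the new `chat_stream` flow above, here is a minimal standalone sketch of consuming the Converse streaming API directly (assumes AWS credentials and region are configured; the model id is just an example, and event keys follow the Bedrock Converse documentation):

    import boto3

    client = boto3.client("bedrock-runtime", region_name="us-east-1")
    response = client.converse_stream(
        modelId="anthropic.claude-3-haiku-20240307-v1:0",
        messages=[{"role": "user", "content": [{"text": "Hello"}]}],
    )
    for chunk in response["stream"]:
        # text deltas arrive in contentBlockDelta; token usage arrives in metadata
        if "contentBlockDelta" in chunk:
            print(chunk["contentBlockDelta"]["delta"].get("text", ""), end="")
        elif "metadata" in chunk:
            print("\ntokens:", chunk["metadata"]["usage"]["totalTokens"])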
- if "amazon-bedrock-invocationMetrics" in chunk: - yield self.create_response_stream( - model=chat_request.model, - message_id=message_id, - input_tokens=chunk["amazon-bedrock-invocationMetrics"][ - "inputTokenCount" - ], - output_tokens=chunk["amazon-bedrock-invocationMetrics"][ - "outputTokenCount" - ], + See example: + https://docs.aws.amazon.com/bedrock/latest/userguide/conversation-inference.html#message-inference-examples + """ + messages = [] + for message in chat_request.messages: + if isinstance(message, UserMessage): + messages.append( + { + "role": message.role, + "content": self._parse_content_parts( + message, chat_request.model + ), + } ) - - def invoke_model(self, request_body: str, model_id: str, with_stream: bool = False): - if DEBUG: - logger.info("Invoke Bedrock Model: " + model_id) - logger.info("Bedrock request body: " + request_body) - try: - if with_stream: - return bedrock_runtime.invoke_model_with_response_stream( - body=request_body, - modelId=model_id, - accept=self.accept, - contentType=self.content_type, - ) - return bedrock_runtime.invoke_model( - body=request_body, - modelId=model_id, - accept=self.accept, - contentType=self.content_type, - ) - except bedrock_runtime.exceptions.ValidationException as e: - logger.error("Validation Error: " + str(e)) - raise HTTPException(status_code=400, detail=str(e)) - except Exception as e: - logger.error(e) - raise HTTPException(status_code=500, detail=str(e)) - - @staticmethod - def merge_message(messages: list[dict]) -> list[dict]: - """Merge the request messages with the same role as previous message""" - merged_messages = [] - prev_role = None - merged_content = "" - - for message in messages: - role = message["role"] - content = message["content"] - if role != prev_role or isinstance(content, list): - if prev_role: - merged_messages.append( - {"role": prev_role, "content": merged_content} + elif isinstance(message, AssistantMessage): + if message.content: + # Text message + messages.append( + {"role": message.role, "content": [{"text": message.content}]} ) - if isinstance(content, str): - merged_content = content - prev_role = role else: - merged_messages.append({"role": role, "content": content}) - prev_role = None - merged_content = "" + # Tool use message + tool_input = json.loads(message.tool_calls[0].function.arguments) + messages.append( + { + "role": message.role, + "content": [ + { + "toolUse": { + "toolUseId": message.tool_calls[0].id, + "name": message.tool_calls[0].function.name, + "input": tool_input + } + } + ], + } + ) + elif isinstance(message, ToolMessage): + # Bedrock does not support tool role, + # Add toolResult to content + # https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ToolResultBlock.html + messages.append( + { + "role": "user", + "content": [ + { + "toolResult": { + "toolUseId": message.tool_call_id, + "content": [{"text": message.content}], + } + } + ], + } + ) + else: - if content == merged_content: - # ignore duplicates - continue - merged_content += "\n" + content + # ignore others, such as system messages + continue + return messages - if merged_content: - merged_messages.append({"role": prev_role, "content": merged_content}) - return merged_messages + def _parse_request(self, chat_request: ChatRequest) -> dict: + """Create default converse request body. - @staticmethod - def create_response( + Also perform validations to tool call etc. 
+ + Ref: https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_Converse.html + """ + + messages = self._parse_messages(chat_request) + system_prompts = self._parse_system_prompts(chat_request) + + # Base inference parameters. + inference_config = { + "temperature": chat_request.temperature, + "maxTokens": chat_request.max_tokens, + "topP": chat_request.top_p, + } + + args = { + "modelId": chat_request.model, + "messages": messages, + "system": system_prompts, + "inferenceConfig": inference_config, + } + # add tool config + if chat_request.tools: + args["toolConfig"] = { + "tools": [ + self._convert_tool_spec(t.function) for t in chat_request.tools + ] + } + return args + + def _create_response( + self, model: str, message_id: str, - message: str | None = None, + content: list[dict] = None, finish_reason: str | None = None, - tools: list[ToolCall] | None = None, input_tokens: int = 0, output_tokens: int = 0, ) -> ChatResponse: + + message = ChatResponseMessage( + role="assistant", + ) + if finish_reason == "tool_use": + # https://docs.aws.amazon.com/bedrock/latest/userguide/tool-use.html#tool-use-examples + for part in content: + if "toolUse" in part: + tool = part["toolUse"] + message.tool_calls = [ + ToolCall( + id=tool["toolUseId"], + function=ResponseFunction( + name=tool["name"], + arguments=json.dumps(tool["input"]), + ), + ) + ] + else: + message.content = content[0]["text"] + response = ChatResponse( id=message_id, model=model, choices=[ Choice( - index=0, - message=ChatResponseMessage( - role="assistant", - tool_calls=tools, - content=message, - ), - finish_reason="tool_calls" if tools else finish_reason, + message=message, + finish_reason=finish_reason, ) ], usage=Usage( @@ -292,250 +432,100 @@ class BedrockModel(BaseChatModel, ABC): total_tokens=input_tokens + output_tokens, ), ) - if DEBUG: - logger.info("Proxy response :" + response.model_dump_json()) return response - @staticmethod - def create_response_stream( - model: str, - message_id: str, - chunk_message: str | None = None, - finish_reason: str | None = None, - tools: list[ToolCall] | None = None, - input_tokens: int = 0, - output_tokens: int = 0, - ) -> ChatStreamResponse: - if chunk_message or finish_reason or tools: - response = ChatStreamResponse( + def _create_response_stream( + self, model_id: str, message_id: str, chunk: dict + ) -> ChatStreamResponse | None: + """Parsing the Bedrock stream response chunk. 
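For reference, a non-streaming request with one tool yields an `args` dict from `_parse_request` along these lines (values illustrative), which is then passed straight to `bedrock_runtime.converse(**args)`:

    args = {
        "modelId": "anthropic.claude-3-sonnet-20240229-v1:0",
        "messages": [
            {"role": "user", "content": [{"text": "Weather in Paris?"}]}
        ],
        "system": [{"text": "You are a helpful assistant."}],
        "inferenceConfig": {"temperature": 0.7, "maxTokens": 500, "topP": 0.9},
        "toolConfig": {"tools": [{"toolSpec": {
            "name": "get_weather",
            "description": "Get the current weather for a city",
            "inputSchema": {"json": {
                "type": "object",
                "properties": {"city": {"type": "string"}},
                "required": ["city"],
            }},
        }}]},
    }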
+ + Ref: https://docs.aws.amazon.com/bedrock/latest/userguide/conversation-inference.html#message-inference-examples + """ + if DEBUG: + logger.info("Bedrock response chunk: " + str(chunk)) + + finish_reason = None + message = None + usage = None + + if "messageStart" in chunk: + message = ChatResponseMessage( + role=chunk["messageStart"]["role"], + content="", + ) + if "contentBlockStart" in chunk: + # tool call start + delta = chunk["contentBlockStart"]["start"] + if "toolUse" in delta: + message = ChatResponseMessage( + tool_calls=[ + ToolCall( + id=delta["toolUse"]["toolUseId"], + function=ResponseFunction( + name=delta["toolUse"]["name"], + arguments="", + ), + ) + ] + ) + if "contentBlockDelta" in chunk: + delta = chunk["contentBlockDelta"]["delta"] + if "text" in delta: + # stream content + message = ChatResponseMessage( + content=delta["text"], + ) + else: + # tool use + message = ChatResponseMessage( + tool_calls=[ + ToolCall( + function=ResponseFunction( + arguments=delta["toolUse"]["input"], + ) + ) + ] + ) + if "messageStop" in chunk: + message = ChatResponseMessage() + finish_reason = chunk["messageStop"]["stopReason"] + + if "metadata" in chunk: + # usage information in metadata. + metadata = chunk["metadata"] + if "usage" in metadata: + # token usage + return ChatStreamResponse( + id=message_id, + model=model_id, + choices=[], + usage=Usage( + prompt_tokens=metadata["usage"]["inputTokens"], + completion_tokens=metadata["usage"]["outputTokens"], + total_tokens=metadata["usage"]["totalTokens"], + ), + ) + if message: + return ChatStreamResponse( id=message_id, - model=model, + model=model_id, choices=[ ChoiceDelta( index=0, - delta=ChatResponseMessage( - role="assistant", - tool_calls=tools, - content=chunk_message, - ), + delta=message, + logprobs=None, finish_reason=finish_reason, ) ], - ) - else: - response = ChatStreamResponse( - id=message_id, - model=model, - choices=[], - usage=Usage( - prompt_tokens=input_tokens, - completion_tokens=output_tokens, - total_tokens=input_tokens + output_tokens, - ), - ) - if DEBUG: - logger.info("Proxy response :" + response.model_dump_json()) - return response - - -class ClaudeModel(BedrockModel): - anthropic_version = "bedrock-2023-05-31" - # follow these instructions for tool uses: - tool_prompt = """You have access to the following tools: -{tools} - -Please think if you need to use a tool or not for user's question, you must: -1. Respond Y or N within tags first to indicate that. -2. If a tool is needed, MUST respond a JSON object matching the following schema within tags: - {{"name": $TOOL_NAME, "arguments": {{"$PARAMETER_NAME": "$PARAMETER_VALUE", ...}}}} -3. If no tools is needed, respond with normal text.""" - - def compose_request_body(self, chat_request: ChatRequest) -> str: - args = { - "anthropic_version": self.anthropic_version, - "max_tokens": chat_request.max_tokens, - "top_p": chat_request.top_p, - "temperature": chat_request.temperature, - } - system_prompt = "" - converted_messages = [] - for message in chat_request.messages: - if message.role == "system": - assert isinstance(message.content, str) - system_prompt += message.content + "\n" - elif message.role == "user" and not isinstance(message.content, str): - converted_messages.append( - { - "role": message.role, - "content": self._parse_content_parts(message.content), - } - ) - elif message.role == "assistant" and not message.content: - # if content is empty - # create the content using the tool call info. 
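One behavioral note on the new code path: `_create_response` and `_create_response_stream` pass the Converse stop reason (`end_turn`, `tool_use`, `max_tokens`, `stop_sequence`) through as the OpenAI `finish_reason` instead of translating it to the values OpenAI clients expect (`stop`, `tool_calls`, `length`). Strict clients may care; a mapping such as this sketch (not part of this diff) would close that gap:

    # Hypothetical helper, not in this PR: translate Converse stop reasons
    # to the values OpenAI clients expect.
    STOP_REASON_MAP = {
        "end_turn": "stop",
        "stop_sequence": "stop",
        "max_tokens": "length",
        "tool_use": "tool_calls",
    }

    def to_openai_finish_reason(stop_reason: str | None) -> str | None:
        return None if stop_reason is None else STOP_REASON_MAP.get(stop_reason, "stop")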
- tool_content = "[Tool use for `{}` with id `{}` with the following `input`]\n{}".format( - message.tool_calls[0].function.name, - message.tool_calls[0].id, - message.tool_calls[0].function.arguments, - ) - converted_messages.append( - {"role": message.role, "content": tool_content} - ) - elif message.role == "tool": - # Since bedrock does not support tool role - # Convert the tool message to a user message. - converted_messages.append( - { - "role": "user", - "content": "[Tool result with matching id `{}` of `{}`] ".format( - message.tool_call_id, message.content - ), - } - ) - else: - converted_messages.append( - {"role": message.role, "content": message.content} - ) - - if chat_request.tools: - tools_str = json.dumps( - [tool.function.model_dump() for tool in chat_request.tools] - ) - system_prompt += self.tool_prompt.format(tools=tools_str) - converted_messages.append({"role": "assistant", "content": ""}) - args["stop_sequences"] = [""] - args["messages"] = self.merge_message(converted_messages) - if system_prompt: - if DEBUG: - logger.info("System Prompt: " + system_prompt) - args["system"] = system_prompt - - return json.dumps(args) - - def parse_response( - self, chat_request: ChatRequest, service_response: dict, message_id: str - ) -> ChatResponse: - response_body = json.loads(service_response.get("body").read()) - if DEBUG: - logger.info("Bedrock response body: " + str(response_body)) - message = response_body["content"][0]["text"] - finish_reason = response_body["stop_reason"] - tools = None - if chat_request.tools: - if message.startswith("Y"): - tools = self._parse_tool_message(message) - message = None - finish_reason = "tool_calls" - elif message.startswith("N"): - message = message[8:].lstrip("\n") - return self.create_response( - model=chat_request.model, - message_id=message_id, - message=message, - tools=tools, - finish_reason=finish_reason, - input_tokens=response_body["usage"]["input_tokens"], - output_tokens=response_body["usage"]["output_tokens"], - ) - - def parse_stream_response( - self, chat_request: ChatRequest, service_response: dict, message_id: str - ) -> Iterable[ChatStreamResponse]: - - chunk_id = 0 - tool_message = "" - first_token = True - index = 0 - for event in service_response.get("body"): - if DEBUG: - logger.info("Bedrock response chunk: " + str(event)) - chunk = json.loads(event["chunk"]["bytes"]) - chunk_id += 1 - - if chunk["type"] == "message_stop": - # Get the usage for streaming response anyway. - if "amazon-bedrock-invocationMetrics" in chunk: - yield self.create_response_stream( - model=chat_request.model, - message_id=message_id, - input_tokens=chunk["amazon-bedrock-invocationMetrics"][ - "inputTokenCount" - ], - output_tokens=chunk["amazon-bedrock-invocationMetrics"][ - "outputTokenCount" - ], - ) - break - - elif chunk["type"] == "message_delta": - chunk_message = "" - finish_reason = chunk["delta"]["stop_reason"] - - # Send tool message first if any. 
- if chat_request.tools and tool_message: - tools = self._parse_tool_message(tool_message) - finish_reason = "tool_calls" - response = self.create_response_stream( - model=chat_request.model, - message_id=message_id, - tools=tools, - ) - yield response - - elif chunk["type"] == "content_block_delta": - chunk_message = chunk["delta"]["text"] - finish_reason = None - if chat_request.tools: - # Check first token - if not tool_message and chunk_message == "Y": - tool_message = "Y" - continue - if tool_message: - # Buffer all chunk message - # in order to extract tool call info - tool_message += chunk_message - continue - if index < 3: - # Ignore the N, which is 3 tokens - index += 1 - continue - if first_token: - chunk_message = chunk_message.lstrip("\n") - first_token = False - else: - continue - response = self.create_response_stream( - model=chat_request.model, - message_id=message_id, - chunk_message=chunk_message, - finish_reason=finish_reason, + usage=usage, ) - yield response + return None - def _parse_tool_message(self, tool_message: str) -> list[ToolCall]: - if DEBUG: - logger.info("Tool message: " + tool_message.replace("\n", " ")) - try: - tool_messages = tool_message[tool_message.rindex("") + len(""):] - function = json.loads(tool_messages.replace("\n", " ")) - args = json.dumps(function.get("arguments", {})) - function = ResponseFunction(name=function["name"], arguments=args) - - return [ - ToolCall( - # id="0", - function=function, - ) - ] - - except Exception as e: - logger.error("Failed to parse tool response" + str(e)) - raise HTTPException(status_code=500, detail="Failed to parse tool response") - - def _get_base64_image(self, image_url: str) -> tuple[str, str]: - """Try to get the base64 data from an image url. + def _parse_image(self, image_url: str) -> tuple[bytes, str]: + """Try to get the raw data from an image url. 
+ Ref: https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ImageSource.html returns a tuple of (Image Data, Content Type) """ pattern = r"^data:(image/[a-z]*);base64,\s*" @@ -544,7 +534,7 @@ Please think if you need to use a tool or not for user's question, you must: # Only supports 'image/jpeg', 'image/png', 'image/gif' or 'image/webp' if content_type: image_data = re.sub(pattern, "", image_url) - return image_data, content_type.group(1) + return base64.b64decode(image_data), content_type.group(1) # Send a request to the image URL response = requests.get(image_url) @@ -556,228 +546,80 @@ Please think if you need to use a tool or not for user's question, you must: content_type = "image/jpeg" # Get the image content image_content = response.content - # Encode the image content as base64 - base64_image = base64.b64encode(image_content) - return base64_image.decode("utf-8"), content_type + return image_content, content_type else: raise HTTPException( status_code=500, detail="Unable to access the image url" ) def _parse_content_parts( - self, content: list[TextContent | ImageContent] + self, + message: UserMessage, + model_id: str, ) -> list[dict]: - # See: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html + if isinstance(message.content, str): + return [ + { + "text": message.content, + } + ] content_parts = [] - for part in content: + for part in message.content: if isinstance(part, TextContent): - content_parts.append(part.model_dump()) - else: - image_data, content_type = self._get_base64_image(part.image_url.url) content_parts.append( { - "type": "image", - "source": { - "type": "base64", - "media_type": content_type, - "data": image_data, + "text": part.text, + } + ) + elif isinstance(part, ImageContent): + if not self._is_multimodal_supported(model_id): + raise HTTPException( + status_code=400, + detail=f"Multimodal message is currently not supported by {model_id}", + ) + image_data, content_type = self._parse_image(part.image_url.url) + content_parts.append( + { + "image": { + "format": content_type[6:], # image/ + "source": {"bytes": image_data}, }, } ) + else: + # Ignore.. 
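A worked example of the image path above, using a tiny but valid base64 payload for illustration: a data URL is decoded to raw bytes and wrapped in a Converse image block, mirroring what `_parse_image` and `_parse_content_parts` produce.

    import base64
    import re

    image_url = "data:image/png;base64,iVBORw0KGgo="
    pattern = r"^data:(image/[a-z]*);base64,\s*"
    content_type = re.match(pattern, image_url).group(1)        # "image/png"
    image_bytes = base64.b64decode(re.sub(pattern, "", image_url))
    part = {
        "image": {
            "format": content_type[6:],        # strip the "image/" prefix
            "source": {"bytes": image_bytes},
        }
    }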
+                continue
         return content_parts

+    def _is_tool_call_supported(self, model_id: str, stream: bool = False) -> bool:
+        feature = self._supported_models.get(model_id)
+        if not feature:
+            return False
+        return feature["stream_tool_call"] if stream else feature["tool_call"]

-class LlamaModel(BedrockModel):
-    text_field_name = "generation"
-    finish_reason_field_name = "stop_reason"
+    def _is_multimodal_supported(self, model_id: str) -> bool:
+        feature = self._supported_models.get(model_id)
+        if not feature:
+            return False
+        return feature["multimodal"]

-    @staticmethod
-    def create_llama3_prompt(chat_request: ChatRequest) -> str:
-        """Create a prompt message for Llama 3 following below example:
+    def _is_system_prompt_supported(self, model_id: str) -> bool:
+        feature = self._supported_models.get(model_id)
+        if not feature:
+            return False
+        return feature["system"]

-        <|begin_of_text|><|start_header_id|>system<|end_header_id|>
-
-        {{ system_prompt }}<|eot_id|><|start_header_id|>user<|end_header_id|>
-
-        {{ user_message_1 }}<|eot_id|><|start_header_id|>assistant<|end_header_id|>
-
-        {{ model_answer_1 }}<|eot_id|><|start_header_id|>user<|end_header_id|>
-
-        {{ user_message_2 }}<|eot_id|><|start_header_id|>assistant<|end_header_id|>
-        """
-        if DEBUG:
-            logger.info("Convert below messages to prompt for Llama 3: ")
-            for msg in chat_request.messages:
-                logger.info(msg.model_dump_json())
-        bos_token = "<|begin_of_text|>"
-
-        prompt_lines = []
-        for msg in chat_request.messages:
-            prompt_lines.append(
-                f"<|start_header_id|>{msg.role}<|end_header_id|>\n\n{msg.content}<|eot_id|>"
-            )
-        prompt_lines.append(f"<|start_header_id|>assistant<|end_header_id|>\n\n")
-        prompt = bos_token + "".join(prompt_lines)
-        if DEBUG:
-            logger.info("Converted prompt: " + prompt.replace("\n", "\\n"))
-        return prompt
-
-    @staticmethod
-    def create_llama2_prompt(chat_request: ChatRequest) -> str:
-        """Create a prompt message for Llama 2 following below example:
-
-        <s>[INST] <<SYS>>\n{your_system_message}\n<</SYS>>\n\n{user_message_1} [/INST] {model_reply_1}</s>
-        <s>[INST] {user_message_2} [/INST]
-        """
-        if DEBUG:
-            logger.info("Convert below messages to prompt for Llama 2: ")
-            for msg in chat_request.messages:
-                logger.info(msg.model_dump_json())
-        bos_token = "<s>"
-        eos_token = "</s>"
-        prompt = ""
-        end_turn = False
-        system_prompt = ""
-        for msg in chat_request.messages:
-            if msg.role == "system":
-                system_prompt += "\n" + msg.content + "\n"
-                continue
-            if msg.role == "tool":
-                raise HTTPException(
-                    status_code=500,
-                    detail="Tool prompt is not supported for Llama 2 model",
-                )
-            if not isinstance(msg.content, str):
-                raise HTTPException(
-                    status_code=400, detail="Content must be a string for Llama 2 model"
-                )
-            if msg.role == "user":
-                if end_turn:
-                    prompt += bos_token + "[INST] "
-                prompt += msg.content + " [/INST] "
-                end_turn = False
-            else:
-                prompt += msg.content + eos_token
-                end_turn = True
-
-        if system_prompt:
-            system_prompt = "<<SYS>>" + system_prompt + "<</SYS>>"
-        prompt = bos_token + "[INST] " + system_prompt + prompt
-        if DEBUG:
-            logger.info("Converted prompt: " + prompt.replace("\n", "\\n"))
-        return prompt
-
-    def compose_request_body(self, chat_request: ChatRequest) -> str:
-        if chat_request.model.startswith("meta.llama2"):
-            prompt = self.create_llama2_prompt(chat_request)
-        else:
-            prompt = self.create_llama3_prompt(chat_request)
-        args = {
-            "prompt": prompt,
-            "max_gen_len": chat_request.max_tokens,
-            "temperature": chat_request.temperature,
-            "top_p": chat_request.top_p,
-        }
-        return json.dumps(args)
-
-
-class MistralModel(BedrockModel):
-    text_field_name = "text"
-    finish_reason_field_name = "stop_reason"
-
-    def _convert_prompt(self, chat_request: ChatRequest) -> str:
-        """Create a prompt message follow below example:
-
-        <s>[INST] {your_system_message}\n{user_message_1} [/INST] {model_reply_1}</s>
-        <s>[INST] {user_message_2} [/INST]
-        """
-        # TODO: maybe reuse the Llama 2 one.
-        if DEBUG:
-            logger.info("Convert below messages to prompt for Mistral/Mixtral model: ")
-            for msg in chat_request.messages:
-                logger.info(msg.model_dump_json())
-        bos_token = "<s>"
-        eos_token = "</s>"
-        prompt = ""
-        end_turn = False
-        system_prompt = ""
-        for msg in chat_request.messages:
-            if msg.role == "system":
-                system_prompt += "\n" + msg.content + "\n"
-                continue
-            if msg.role == "tool":
-                raise HTTPException(
-                    status_code=500,
-                    detail="Tool prompt is not supported for Mistral/Mixtral model",
-                )
-            if not isinstance(msg.content, str):
-                raise HTTPException(
-                    status_code=400,
-                    detail="Content must be a string for Mistral/Mixtral model",
-                )
-            if msg.role == "user":
-                if end_turn:
-                    prompt += bos_token + "[INST] "
-                prompt += msg.content + " [/INST] "
-                end_turn = False
-            else:
-                prompt += msg.content + eos_token
-                end_turn = True
-
-        prompt = bos_token + "[INST] " + system_prompt + prompt
-        if DEBUG:
-            logger.info("Converted prompt: " + prompt.replace("\n", "\\n"))
-        return prompt
-
-    def compose_request_body(self, chat_request: ChatRequest) -> str:
-        prompt = self._convert_prompt(chat_request)
-        args = {
-            "prompt": prompt,
-            "max_tokens": chat_request.max_tokens,
-            "temperature": chat_request.temperature,
-            "top_p": chat_request.top_p,
-        }
-        return json.dumps(args)
-
-    def get_message_text(self, response_body: dict) -> str | None:
-        return super().get_message_text(response_body["outputs"][0])
-
-    def get_message_finish_reason(self, response_body: dict) -> str | None:
-        return super().get_message_finish_reason(response_body["outputs"][0])
-
-    def get_message_usage(self, response_body: dict) -> tuple[int, int]:
-        # Mistral/Mixtral does not provide info about usage
-        return 0, 0
-
-
-class CohereCommandModel(BedrockModel):
-
-    def _parse_message(self, message) -> dict:
-        if message.role not in ["user", "assistant"]:
-            raise HTTPException(
-                status_code=400, detail="Only user or assistant message is supported"
-            )
+    def _convert_tool_spec(self, func: Function) -> dict:
         return {
-            "role": "USER" if message.role == "user" else "CHATBOT",
-            "message": message.content,
+            "toolSpec": {
+                "name": func.name,
+                "description": func.description,
+                "inputSchema": {
+                    "json": func.parameters,
+                },
+            }
         }

-    def compose_request_body(self, chat_request: ChatRequest) -> str:
-        messages = chat_request.messages
-        if messages[-1].role != "user":
-            raise HTTPException(
-                status_code=400, detail="Last message should be a valid user message"
-            )
-        chat_history = [self._parse_message(msg) for msg in messages[:-1]]
-        args = {
-            "message": messages[-1].content,
-            "chat_history": chat_history,
-            "max_tokens": chat_request.max_tokens,
-            "temperature": chat_request.temperature,
-            "p": chat_request.top_p,
-        }
-        return json.dumps(args)
-

 class BedrockEmbeddingsModel(BaseEmbeddingsModel, ABC):
     accept = "application/json"
@@ -919,25 +761,6 @@ class TitanEmbeddingsModel(BedrockEmbeddingsModel):
         )


-def get_model(model_id: str) -> BedrockModel:
-    if DEBUG:
-        logger.info("model id is " + model_id)
-    if model_id not in SUPPORTED_BEDROCK_MODELS.keys():
-        logger.error("Unsupported model id " + model_id)
-        raise HTTPException(
-            status_code=400,
detail="Unsupported model id " + model_id, - ) - if model_id.startswith("anthropic.claude"): - return ClaudeModel() - elif model_id.startswith("meta.llama"): - return LlamaModel() - elif model_id.startswith("mistral.mistral") or model_id.startswith("mistral.mixtral"): - return MistralModel() - elif model_id.startswith("cohere.command-r"): - return CohereCommandModel() - - def get_embeddings_model(model_id: str) -> BedrockEmbeddingsModel: model_name = SUPPORTED_BEDROCK_EMBEDDING_MODELS.get(model_id, "") if DEBUG: diff --git a/src/api/routers/chat.py b/src/api/routers/chat.py index f2758fa..58ea73f 100644 --- a/src/api/routers/chat.py +++ b/src/api/routers/chat.py @@ -4,7 +4,7 @@ from fastapi import APIRouter, Depends, Body from fastapi.responses import StreamingResponse from api.auth import api_key_auth -from api.models import get_model +from api.models.bedrock import BedrockModel from api.schema import ChatRequest, ChatResponse, ChatStreamResponse from api.setting import DEFAULT_MODEL @@ -34,9 +34,10 @@ async def chat_completions( ): if chat_request.model.lower().startswith("gpt-"): chat_request.model = DEFAULT_MODEL - + # Exception will be raised if model not supported. - model = get_model(chat_request.model) + model = BedrockModel() + model.validate(chat_request) if chat_request.stream: return StreamingResponse( content=model.chat_stream(chat_request), media_type="text/event-stream" diff --git a/src/api/routers/embeddings.py b/src/api/routers/embeddings.py index 135fc27..e5cde31 100644 --- a/src/api/routers/embeddings.py +++ b/src/api/routers/embeddings.py @@ -3,7 +3,7 @@ from typing import Annotated from fastapi import APIRouter, Depends, Body from api.auth import api_key_auth -from api.models import get_embeddings_model +from api.models.bedrock import get_embeddings_model from api.schema import EmbeddingsRequest, EmbeddingsResponse from api.setting import DEFAULT_EMBEDDING_MODEL diff --git a/src/api/routers/model.py b/src/api/routers/model.py index 2640d7d..ce5e8a1 100644 --- a/src/api/routers/model.py +++ b/src/api/routers/model.py @@ -3,7 +3,7 @@ from typing import Annotated from fastapi import APIRouter, Depends, HTTPException, Path from api.auth import api_key_auth -from api.models import SUPPORTED_BEDROCK_MODELS, SUPPORTED_BEDROCK_EMBEDDING_MODELS +from api.models.bedrock import BedrockModel from api.schema import Models, Model router = APIRouter( @@ -12,16 +12,19 @@ router = APIRouter( # responses={404: {"description": "Not found"}}, ) +chat_model = BedrockModel() + async def validate_model_id(model_id: str): - if model_id not in (SUPPORTED_BEDROCK_MODELS | SUPPORTED_BEDROCK_EMBEDDING_MODELS).keys(): + if model_id not in chat_model.list_models(): raise HTTPException(status_code=500, detail="Unsupported Model Id") @router.get("", response_model=Models) async def list_models(): - model_list = [Model(id=model_id) for model_id in - (SUPPORTED_BEDROCK_MODELS | SUPPORTED_BEDROCK_EMBEDDING_MODELS).keys()] + model_list = [ + Model(id=model_id) for model_id in chat_model.list_models() + ] return Models(data=model_list) diff --git a/src/api/schema.py b/src/api/schema.py index da133d5..2b2f2fb 100644 --- a/src/api/schema.py +++ b/src/api/schema.py @@ -1,5 +1,4 @@ import time -import uuid from typing import Literal, Iterable from pydantic import BaseModel, Field @@ -18,12 +17,12 @@ class Models(BaseModel): class ResponseFunction(BaseModel): - name: str + name: str | None = None arguments: str class ToolCall(BaseModel): - id: str = Field(default_factory=lambda: str(uuid.uuid4())[:8]) + id: 
str | None = None type: Literal["function"] = "function" function: ResponseFunction @@ -113,8 +112,8 @@ class ChatResponseMessage(BaseModel): class BaseChoice(BaseModel): - index: int - finish_reason: str | None + index: int | None = 0 + finish_reason: str | None = None logprobs: dict | None = None diff --git a/src/api/setting.py b/src/api/setting.py index 56e99c4..9543e20 100644 --- a/src/api/setting.py +++ b/src/api/setting.py @@ -11,27 +11,11 @@ DESCRIPTION = """ Use OpenAI-Compatible RESTful APIs for Amazon Bedrock models. List of Amazon Bedrock models currently supported: - -# Chat -- anthropic.claude-instant-v1 -- anthropic.claude-v2:1 -- anthropic.claude-v2 -- anthropic.claude-3-opus-20240229-v1:0 -- anthropic.claude-3-sonnet-20240229-v1:0 -- anthropic.claude-3-haiku-20240307-v1:0 -- meta.llama2-13b-chat-v1 -- meta.llama2-70b-chat-v1 -- meta.llama3-8b-instruct-v1:0 -- meta.llama3-70b-instruct-v1:0 -- mistral.mistral-7b-instruct-v0:2 -- mistral.mixtral-8x7b-instruct-v0:1 -- mistral.mistral-large-2402-v1:0 -- cohere.command-r-v1:0 -- cohere.command-r-plus-v1:0 - -# Embeddings -- cohere.embed-multilingual-v3 -- cohere.embed-english-v3 +- Anthropic Claude 2 / 3 (Haiku/Sonnet/Opus) +- Meta Llama 2 / 3 +- Mistral / Mixtral +- Cohere Command R / R+ +- Cohere Embedding """ DEBUG = os.environ.get("DEBUG", "false").lower() != "false" diff --git a/src/requirements.txt b/src/requirements.txt index 4dc944e..3af1653 100644 --- a/src/requirements.txt +++ b/src/requirements.txt @@ -1,7 +1,9 @@ -fastapi==0.110.2 +fastapi==0.111.0 pydantic==2.7.1 uvicorn==0.29.0 mangum==0.17.0 tiktoken==0.6.0 -requests==2.32.0 -numpy==1.26.4 \ No newline at end of file +requests==2.32.3 +numpy==1.26.4 +boto3==1.34.117 +botocore==1.34.117 \ No newline at end of file
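With these changes, any OpenAI-compatible client can exercise the proxy, including streaming. A usage sketch (base URL, API key, and port are deployment-specific placeholders, not values from this diff):

    from openai import OpenAI

    client = OpenAI(
        base_url="http://localhost:8000/api/v1",  # placeholder for your gateway URL
        api_key="placeholder-key",
    )
    stream = client.chat.completions.create(
        model="anthropic.claude-3-haiku-20240307-v1:0",
        messages=[{"role": "user", "content": "Say hi"}],
        stream=True,
    )
    for chunk in stream:
        print(chunk.choices[0].delta.content or "", end="")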