import base64 import json import logging import re from abc import ABC from typing import AsyncIterable, Iterable import boto3 import requests import tiktoken from fastapi import HTTPException from api.models.base import BaseChatModel, BaseEmbeddingsModel from api.schema import ( # Chat ChatResponse, ChatRequest, Choice, ChatResponseMessage, Usage, ChatStreamResponse, ChoiceDelta, ImageContent, TextContent, ResponseFunction, ToolCall, Tool, # Embeddings EmbeddingsRequest, EmbeddingsResponse, EmbeddingsUsage, Embedding, ) from api.setting import DEBUG, AWS_REGION logger = logging.getLogger(__name__) bedrock_runtime = boto3.client( service_name="bedrock-runtime", region_name=AWS_REGION, ) SUPPORTED_BEDROCK_MODELS = { "anthropic.claude-instant-v1": "Claude Instant", "anthropic.claude-v2:1": "Claude", "anthropic.claude-v2": "Claude", "anthropic.claude-3-sonnet-20240229-v1:0": "Claude 3 Sonnet", "anthropic.claude-3-haiku-20240307-v1:0": "Claude 3 Haiku", "meta.llama2-13b-chat-v1": "Llama 2 Chat 13B", "meta.llama2-70b-chat-v1": "Llama 2 Chat 70B", "mistral.mistral-7b-instruct-v0:2": "Mistral 7B Instruct", "mistral.mixtral-8x7b-instruct-v0:1": "Mixtral 8x7B Instruct", "mistral.mistral-large-2402-v1:0": "Mistral Large", } SUPPORTED_BEDROCK_EMBEDDING_MODELS = { "cohere.embed-multilingual-v3": "Cohere Embed Multilingual", "cohere.embed-english-v3": "Cohere Embed English", # Disable Titan embedding. # "amazon.titan-embed-text-v1": "Titan Embeddings G1 - Text", # "amazon.titan-embed-image-v1": "Titan Multimodal Embeddings G1" } ENCODER = tiktoken.get_encoding("cl100k_base") # https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters.html class BedrockModel(BaseChatModel, ABC): accept = "application/json" content_type = "application/json" def _invoke_model(self, args: dict, model_id: str, with_stream: bool = False): body = json.dumps(args) if DEBUG: logger.info("Invoke Bedrock Model: " + model_id) logger.info("Bedrock request body: " + body) if with_stream: return bedrock_runtime.invoke_model_with_response_stream( body=body, modelId=model_id, accept=self.accept, contentType=self.content_type, ) return bedrock_runtime.invoke_model( body=body, modelId=model_id, accept=self.accept, contentType=self.content_type, ) def _create_response( self, model: str, message: str, message_id: str, tools_message: str | None = None, input_tokens: int = 0, output_tokens: int = 0, ) -> ChatResponse: if tools_message: # For tool response, the content is empty tools = self._parse_tools_response(tools_message) choice = Choice( index=0, message=ChatResponseMessage( role="assistant", tool_calls=tools, ), finish_reason="stop", ) else: choice = Choice( index=0, message=ChatResponseMessage( role="assistant", content=message, ), finish_reason="stop", ) response = ChatResponse( id=message_id, model=model, choices=[choice], usage=Usage( prompt_tokens=input_tokens, completion_tokens=output_tokens, total_tokens=input_tokens + output_tokens, ), ) if DEBUG: logger.info("Proxy response :" + response.model_dump_json()) return response def _create_response_stream( self, model: str, message_id: str, chunk_message: str, finish_reason: str | None ) -> ChatStreamResponse: choice = ChoiceDelta( index=0, delta=ChatResponseMessage( role="assistant", content=chunk_message, ), finish_reason=finish_reason, ) response = ChatStreamResponse( id=message_id, model=model, choices=[choice], ) if DEBUG: logger.info("Proxy response :" + response.model_dump_json()) return response class ClaudeModel(BedrockModel): anthropic_version = "bedrock-2023-05-31" def _parse_tools_response(self, tools_messages: str) -> list[ToolCall]: """Parse the tools response Example tool message like: \n{\n "name": "get_current_weather",\n "arguments": {\n "location": "Shanghai"... }\n}\n """ function = json.loads( tools_messages.replace("\n", " ").encode("unicode_escape") ) args = json.dumps(function.get("arguments", {})) function = ResponseFunction( name=function["name"], arguments=args.replace("\\\\n", "\\n") ) return [ ToolCall( id="0", function=function, ) ] def _get_base64_image(self, image_url: str) -> tuple[str, str]: """Try to get the base64 data from an image url. returns a tuple of (Image Data, Content Type) """ pattern = r"^data:(image/[a-z]*);base64,\s*" content_type = re.search(pattern, image_url) # if already base64 encoded. # Only supports 'image/jpeg', 'image/png', 'image/gif' or 'image/webp' if content_type: image_data = re.sub(pattern, "", image_url) return image_data, content_type.group(1) # Send a request to the image URL response = requests.get(image_url) content_type = response.headers.get('Content-Type') if not content_type.startswith("image"): content_type = "image/jpeg" # Check if the request was successful if response.status_code == 200: # Get the image content image_content = response.content # Encode the image content as base64 base64_image = base64.b64encode(image_content) return base64_image.decode("utf-8"), content_type else: raise HTTPException( status_code=500, detail="Unable to access the image url" ) def _parse_content_parts( self, content: list[TextContent | ImageContent] ) -> list[dict]: # See: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html content_parts = [] for part in content: if isinstance(part, TextContent): content_parts.append(part.model_dump()) else: image_data, content_type = self._get_base64_image(part.image_url.url) content_parts.append( { "type": "image", "source": { "type": "base64", "media_type": content_type, "data": image_data, }, } ) return content_parts def _create_tool_prompt(self, tools: list[Tool]) -> str: tool_prompt = "\nYou have access to the following tools:\n" tool_prompt += json.dumps( [tool.function.model_dump() for tool in tools], indent=2 ) tool_prompt += ( "\nIf you need to use one of the above tools, " "only respond with a JSON object matching the following schema inside a xml tag: \n" '{"name": $TOOL_NAME, "arguments": {"$PARAMETER_NAME": "$PARAMETER_VALUE", ...}\n' ) return tool_prompt def _parse_args(self, chat_request: ChatRequest) -> dict: args = { "anthropic_version": self.anthropic_version, "max_tokens": chat_request.max_tokens, "top_p": chat_request.top_p, "temperature": chat_request.temperature, } system_prompt = "" converted_messages = [] for message in chat_request.messages: if message.role == "system": system_prompt += message.content + "\n" elif message.role == "user" and not isinstance(message.content, str): converted_messages.append( { "role": message.role, "content": self._parse_content_parts(message.content), } ) elif message.role == "assistant" and not message.content: # if content is empty # create the content using the tool call info. tool_content = "Should use {} tool with args: {}".format( message.tool_calls[0].function.name, message.tool_calls[0].function.arguments, ) converted_messages.append( {"role": message.role, "content": tool_content} ) elif message.role == "tool": # Since bedrock does not support tool role # Convert the tool message to a user message. converted_messages.append( { "role": "user", "content": "The result of the tool call is " + message.content, } ) else: converted_messages.append( {"role": message.role, "content": message.content} ) if chat_request.tools: system_prompt += self._create_tool_prompt(chat_request.tools) args["messages"] = converted_messages if system_prompt: if DEBUG: logger.info("System Prompt: " + system_prompt) args["system"] = system_prompt.replace("\n", "") return args def chat(self, chat_request: ChatRequest) -> ChatResponse: if DEBUG: logger.info("Raw request: " + chat_request.model_dump_json()) response = self._invoke_model( args=self._parse_args(chat_request), model_id=chat_request.model ) response_body = json.loads(response.get("body").read()) if DEBUG: logger.info("Bedrock response body: " + str(response_body)) message = response_body["content"][0]["text"] tools_message = None start = message.find("") end = message.find("") if start != -1 and end != -1: tools_message = message[start + 6: end] return self._create_response( model=chat_request.model, message=response_body["content"][0]["text"], message_id=response_body["id"], tools_message=tools_message, input_tokens=response_body["usage"]["input_tokens"], output_tokens=response_body["usage"]["output_tokens"], ) def chat_stream(self, chat_request: ChatRequest) -> AsyncIterable[bytes]: response = self._invoke_model( args=self._parse_args(chat_request), model_id=chat_request.model, with_stream=True, ) msg_id = "" chunk_id = 0 for event in response.get("body"): if DEBUG: logger.info("Bedrock response chunk: " + str(event)) chunk = json.loads(event["chunk"]["bytes"]) chunk_id += 1 if chunk["type"] == "message_start": msg_id = chunk["message"]["id"] continue if chunk["type"] == "message_delta": chunk_message = "" finish_reason = "stop" elif chunk["type"] == "content_block_delta": chunk_message = chunk["delta"]["text"] finish_reason = None else: continue response = self._create_response_stream( model=chat_request.model, message_id=msg_id, chunk_message=chunk_message, finish_reason=finish_reason, ) yield self._stream_response_to_bytes(response) class Llama2Model(BedrockModel): def _convert_prompt(self, chat_request: ChatRequest) -> str: """Create a prompt message follow below example: [INST] <>\n{your_system_message}\n<>\n\n{user_message_1} [/INST] {model_reply_1} [INST] {user_message_2} [/INST] """ if DEBUG: logger.info("Convert below messages to prompt for Llama 2: ") for msg in chat_request.messages: logger.info(msg.model_dump_json()) bos_token = "" eos_token = "" prompt = "" end_turn = False system_prompt = "" for msg in chat_request.messages: if msg.role == "system": system_prompt += "\n" + msg.content + "\n" continue if msg.role == "tool": raise HTTPException( status_code=500, detail="Tool prompt is not supported for Llama 2 model", ) if not isinstance(msg.content, str): raise HTTPException( status_code=400, detail="Content must be a string for Llama 2 model" ) if msg.role == "user": if end_turn: prompt += bos_token + "[INST] " prompt += msg.content + " [/INST] " end_turn = False else: prompt += msg.content + eos_token end_turn = True if system_prompt: system_prompt = "<>" + system_prompt + "<>" prompt = bos_token + "[INST] " + system_prompt + prompt if DEBUG: logger.info("Converted prompt: " + prompt.replace("\n", "\\n")) return prompt def _parse_args(self, chat_request: ChatRequest) -> dict: prompt = self._convert_prompt(chat_request) return { "prompt": prompt, "max_gen_len": chat_request.max_tokens, "temperature": chat_request.temperature, "top_p": chat_request.top_p, } def chat(self, chat_request: ChatRequest) -> ChatResponse: response = self._invoke_model( args=self._parse_args(chat_request), model_id=chat_request.model ) response_body = json.loads(response.get("body").read()) if DEBUG: logger.info("Bedrock response body: " + str(response_body)) message_id = self._generate_message_id() return self._create_response( model=chat_request.model, message=response_body["generation"], message_id=message_id, input_tokens=response_body["prompt_token_count"], output_tokens=response_body["generation_token_count"], ) def chat_stream(self, chat_request: ChatRequest) -> AsyncIterable[bytes]: response = self._invoke_model( args=self._parse_args(chat_request), model_id=chat_request.model, with_stream=True, ) msg_id = "" chunk_id = 0 for event in response.get("body"): if DEBUG: logger.info("Bedrock response chunk: " + str(event)) chunk = json.loads(event["chunk"]["bytes"]) chunk_id += 1 response = self._create_response_stream( model=chat_request.model, message_id=msg_id, chunk_message=chunk["generation"], finish_reason=chunk["stop_reason"], ) yield self._stream_response_to_bytes(response) class MistralModel(BedrockModel): def _convert_prompt(self, chat_request: ChatRequest) -> str: """Create a prompt message follow below example: [INST] {your_system_message}\n{user_message_1} [/INST] {model_reply_1} [INST] {user_message_2} [/INST] """ # TODO: maybe reuse the Llama 2 one. if DEBUG: logger.info("Convert below messages to prompt for Mistral/Mixtral model: ") for msg in chat_request.messages: logger.info(msg.model_dump_json()) bos_token = "" eos_token = "" prompt = "" end_turn = False system_prompt = "" for msg in chat_request.messages: if msg.role == "system": system_prompt += "\n" + msg.content + "\n" continue if msg.role == "tool": raise HTTPException( status_code=500, detail="Tool prompt is not supported for Mistral/Mixtral model", ) if not isinstance(msg.content, str): raise HTTPException( status_code=400, detail="Content must be a string for Mistral/Mixtral model", ) if msg.role == "user": if end_turn: prompt += bos_token + "[INST] " prompt += msg.content + " [/INST] " end_turn = False else: prompt += msg.content + eos_token end_turn = True prompt = bos_token + "[INST] " + system_prompt + prompt if DEBUG: logger.info("Converted prompt: " + prompt.replace("\n", "\\n")) return prompt def _parse_args(self, chat_request: ChatRequest) -> dict: prompt = self._convert_prompt(chat_request) return { "prompt": prompt, "max_tokens": chat_request.max_tokens, "temperature": chat_request.temperature, "top_p": chat_request.top_p, } def chat(self, chat_request: ChatRequest) -> ChatResponse: response = self._invoke_model( args=self._parse_args(chat_request), model_id=chat_request.model ) response_body = json.loads(response.get("body").read()) if DEBUG: logger.info("Bedrock response body: " + str(response_body)) message_id = self._generate_message_id() return self._create_response( model=chat_request.model, message=response_body["outputs"][0]["text"], message_id=message_id, ) def chat_stream(self, chat_request: ChatRequest) -> AsyncIterable[bytes]: response = self._invoke_model( args=self._parse_args(chat_request), model_id=chat_request.model, with_stream=True, ) msg_id = "" chunk_id = 0 for event in response.get("body"): if DEBUG: logger.info("Bedrock response chunk: " + str(event)) chunk = json.loads(event["chunk"]["bytes"]) chunk_id += 1 response = self._create_response_stream( model=chat_request.model, message_id=msg_id, chunk_message=chunk["outputs"][0]["text"], finish_reason=chunk["outputs"][0]["stop_reason"], ) yield self._stream_response_to_bytes(response) class BedrockEmbeddingsModel(BaseEmbeddingsModel, ABC): accept = "application/json" content_type = "application/json" def _invoke_model(self, args: dict, model_id: str): body = json.dumps(args) if DEBUG: logger.info("Invoke Bedrock Model: " + model_id) logger.info("Bedrock request body: " + body) return bedrock_runtime.invoke_model( body=body, modelId=model_id, accept=self.accept, contentType=self.content_type, ) def _create_response( self, embeddings: list[float], model: str, input_tokens: int = 0, output_tokens: int = 0, ) -> EmbeddingsResponse: data = [ Embedding(index=i, embedding=embedding) for i, embedding in enumerate(embeddings) ] response = EmbeddingsResponse( data=data, model=model, usage=EmbeddingsUsage( prompt_tokens=input_tokens, total_tokens=input_tokens + output_tokens, ), ) if DEBUG: logger.info("Proxy response :" + response.model_dump_json()) return response class CohereEmbeddingsModel(BedrockEmbeddingsModel): def _parse_args(self, embeddings_request: EmbeddingsRequest) -> dict: texts = [] if isinstance(embeddings_request.input, str): texts = [embeddings_request.input] elif isinstance(embeddings_request.input, list): texts = embeddings_request.input elif isinstance(embeddings_request.input, Iterable): # For encoded input # The workaround is to use tiktoken to decode to get the original text. encodings = [] for inner in embeddings_request.input: if isinstance(inner, int): # Iterable[int] encodings.append(inner) else: # Iterable[Iterable[int]] text = ENCODER.decode(list(inner)) texts.append(text) if encodings: texts.append(ENCODER.decode(encodings)) # Maximum of 2048 characters args = { "texts": texts, "input_type": "search_document", "truncate": "END", # "NONE|START|END" } return args def embed(self, embeddings_request: EmbeddingsRequest) -> EmbeddingsResponse: response = self._invoke_model( args=self._parse_args(embeddings_request), model_id=embeddings_request.model ) response_body = json.loads(response.get("body").read()) if DEBUG: logger.info("Bedrock response body: " + str(response_body)) return self._create_response( embeddings=response_body["embeddings"], model=embeddings_request.model, ) class TitanEmbeddingsModel(BedrockEmbeddingsModel): def _parse_args(self, embeddings_request: EmbeddingsRequest) -> dict: if isinstance(embeddings_request.input, str): input_text = embeddings_request.input elif ( isinstance(embeddings_request.input, list) and len(embeddings_request.input) == 1 ): input_text = embeddings_request.input[0] else: raise ValueError( "Amazon Titan Embeddings models support only single strings as input." ) args = { "inputText": input_text, # Note: inputImage is not supported! } if embeddings_request.model == "amazon.titan-embed-image-v1": args["embeddingConfig"] = ( embeddings_request.embedding_config if embeddings_request.embedding_config else {"outputEmbeddingLength": 1024} ) return args def embed(self, embeddings_request: EmbeddingsRequest) -> EmbeddingsResponse: response = self._invoke_model( args=self._parse_args(embeddings_request), model_id=embeddings_request.model ) response_body = json.loads(response.get("body").read()) if DEBUG: logger.info("Bedrock response body: " + str(response_body)) return self._create_response( embeddings=[response_body["embedding"]], model=embeddings_request.model, input_tokens=response_body["inputTextTokenCount"], ) def get_model(model_id: str) -> BedrockModel: model_name = SUPPORTED_BEDROCK_MODELS.get(model_id, "") if DEBUG: logger.info("model name is " + model_name) if model_name in ["Claude Instant", "Claude", "Claude 3 Sonnet", "Claude 3 Haiku"]: return ClaudeModel() elif model_name in ["Llama 2 Chat 13B", "Llama 2 Chat 70B"]: return Llama2Model() elif model_name in ["Mistral 7B Instruct", "Mixtral 8x7B Instruct", "Mistral Large"]: return MistralModel() else: logger.error("Unsupported model id " + model_id) raise HTTPException( status_code=500, detail="Unsupported model id " + model_id, ) def get_embeddings_model(model_id: str) -> BedrockEmbeddingsModel: model_name = SUPPORTED_BEDROCK_EMBEDDING_MODELS.get(model_id, "") if DEBUG: logger.info("model name is " + model_name) if model_name in ["Cohere Embed Multilingual", "Cohere Embed English"]: return CohereEmbeddingsModel() elif model_name in ["Titan Embeddings G1 - Text", "Titan Multimodal Embeddings G1"]: return TitanEmbeddingsModel() else: logger.error("Unsupported model id " + model_id) raise HTTPException( status_code=500, detail="Unsupported model id " + model_id, )