import base64
import json
import logging
import re
import time
from abc import ABC
from typing import AsyncIterable, Iterable, Literal

import boto3
import numpy as np
import requests
import tiktoken
from botocore.config import Config
from fastapi import HTTPException

from api.models.base import BaseChatModel, BaseEmbeddingsModel
from api.schema import (
    # Chat
    ChatResponse,
    ChatRequest,
    Choice,
    ChatResponseMessage,
    Usage,
    ChatStreamResponse,
    ImageContent,
    TextContent,
    ToolCall,
    ChoiceDelta,
    UserMessage,
    AssistantMessage,
    ToolMessage,
    Function,
    ResponseFunction,
    # Embeddings
    EmbeddingsRequest,
    EmbeddingsResponse,
    EmbeddingsUsage,
    Embedding,
)
from api.setting import DEBUG, AWS_REGION, ENABLE_CROSS_REGION_INFERENCE, DEFAULT_MODEL

logger = logging.getLogger(__name__)

# Shared client configuration: long timeouts for model invocations and a single attempt.
config = Config(connect_timeout=60, read_timeout=120, retries={"max_attempts": 1})

bedrock_runtime = boto3.client(
    service_name="bedrock-runtime",
    region_name=AWS_REGION,
    config=config,
)
bedrock_client = boto3.client(
    service_name="bedrock",
    region_name=AWS_REGION,
    config=config,
)


def get_inference_region_prefix():
    """Return the cross-region inference profile prefix for ``AWS_REGION``.

    GovCloud regions map to ``us-gov`` and Asia Pacific regions (``ap-*``) map to
    ``apac``. Any other region uses its first two characters (e.g. ``us``, ``eu``).
    A prefix with no matching system profile is harmless: such profile IDs are
    only used when they appear in the listed profiles.
    """
    if AWS_REGION.startswith("us-gov-"):
        return "us-gov"
    if AWS_REGION.startswith("ap-"):
        return "apac"
    return AWS_REGION[:2]


# https://docs.aws.amazon.com/bedrock/latest/userguide/inference-profiles-support.html
cr_inference_prefix = get_inference_region_prefix()

SUPPORTED_BEDROCK_EMBEDDING_MODELS = {
    "cohere.embed-multilingual-v3": "Cohere Embed Multilingual",
    "cohere.embed-english-v3": "Cohere Embed English",
    # Disable Titan embedding.
    # "amazon.titan-embed-text-v1": "Titan Embeddings G1 - Text",
    # "amazon.titan-embed-image-v1": "Titan Multimodal Embeddings G1"
}

ENCODER = tiktoken.get_encoding("cl100k_base")


def _list_system_inference_profile_ids() -> set[str]:
    """Return the IDs of all system-defined inference profiles, following ``nextToken`` pages."""
    profile_ids: set[str] = set()
    request_args = {"maxResults": 1000, "typeEquals": "SYSTEM_DEFINED"}
    while True:
        response = bedrock_client.list_inference_profiles(**request_args)
        profile_ids.update(
            p["inferenceProfileId"] for p in response.get("inferenceProfileSummaries", [])
        )
        next_token = response.get("nextToken")
        if not next_token:
            return profile_ids
        request_args["nextToken"] = next_token


def list_bedrock_models() -> dict:
    """Automatically build the list of supported models.

    The returned model list combines:
    - ON_DEMAND foundation models with text output.
    - Cross-Region Inference Profiles (only if enabled via env).

    Returns:
        A dict mapping model/profile ID to ``{"modalities": [...]}``. If listing
        fails or yields nothing, it falls back to ``DEFAULT_MODEL`` with TEXT and
        IMAGE modalities.
    """
    model_list = {}
    try:
        profile_ids: set[str] = set()
        if ENABLE_CROSS_REGION_INFERENCE:
            # List system defined inference profile IDs (all pages).
            profile_ids = _list_system_inference_profile_ids()

        # List foundation models, only cares about text outputs here.
        response = bedrock_client.list_foundation_models(byOutputModality="TEXT")

        for model in response["modelSummaries"]:
            model_id = model.get("modelId", "N/A")
            stream_supported = model.get("responseStreamingSupported", True)
            status = model.get("modelLifecycle", {}).get("status", "ACTIVE")

            # currently, use this to filter out rerank models and legacy models
            if not stream_supported or status not in ("ACTIVE", "LEGACY"):
                continue

            inference_types = model.get("inferenceTypesSupported", [])
            input_modalities = model.get("inputModalities", [])

            # Add on-demand model list
            if "ON_DEMAND" in inference_types:
                model_list[model_id] = {"modalities": input_modalities}

            # Add cross-region inference model list.
            profile_id = f"{cr_inference_prefix}.{model_id}"
            if profile_id in profile_ids:
                model_list[profile_id] = {"modalities": input_modalities}

    except Exception as e:
        logger.error(f"Unable to list models: {str(e)}")

    if not model_list:
        # In case stack not updated.
        model_list[DEFAULT_MODEL] = {"modalities": ["TEXT", "IMAGE"]}

    return model_list


# Initialize the model list.
bedrock_model_list = list_bedrock_models() class BedrockModel(BaseChatModel): def list_models(self) -> list[str]: """Always refresh the latest model list""" global bedrock_model_list bedrock_model_list = list_bedrock_models() return list(bedrock_model_list.keys()) def validate(self, chat_request: ChatRequest): """Perform basic validation on requests""" error = "" # check if model is supported if chat_request.model not in bedrock_model_list.keys(): error = f"Unsupported model {chat_request.model}, please use models API to get a list of supported models" if error: raise HTTPException( status_code=400, detail=error, ) def _invoke_bedrock(self, chat_request: ChatRequest, stream=False): """Common logic for invoke bedrock models""" if DEBUG: logger.info("Raw request: " + chat_request.model_dump_json()) # convert OpenAI chat request to Bedrock SDK request args = self._parse_request(chat_request) if DEBUG: logger.info("Bedrock request: " + json.dumps(str(args))) try: if stream: response = bedrock_runtime.converse_stream(**args) else: response = bedrock_runtime.converse(**args) except bedrock_runtime.exceptions.ValidationException as e: logger.error("Validation Error: " + str(e)) raise HTTPException(status_code=400, detail=str(e)) except Exception as e: logger.error(e) raise HTTPException(status_code=500, detail=str(e)) return response def chat(self, chat_request: ChatRequest) -> ChatResponse: """Default implementation for Chat API.""" message_id = self.generate_message_id() response = self._invoke_bedrock(chat_request) output_message = response["output"]["message"] input_tokens = response["usage"]["inputTokens"] output_tokens = response["usage"]["outputTokens"] finish_reason = response["stopReason"] chat_response = self._create_response( model=chat_request.model, message_id=message_id, content=output_message["content"], finish_reason=finish_reason, input_tokens=input_tokens, output_tokens=output_tokens, ) if DEBUG: logger.info("Proxy response :" + 
chat_response.model_dump_json()) return chat_response def chat_stream(self, chat_request: ChatRequest) -> AsyncIterable[bytes]: """Default implementation for Chat Stream API""" response = self._invoke_bedrock(chat_request, stream=True) message_id = self.generate_message_id() stream = response.get("stream") for chunk in stream: stream_response = self._create_response_stream( model_id=chat_request.model, message_id=message_id, chunk=chunk ) if not stream_response: continue if DEBUG: logger.info("Proxy response :" + stream_response.model_dump_json()) if stream_response.choices: yield self.stream_response_to_bytes(stream_response) elif ( chat_request.stream_options and chat_request.stream_options.include_usage ): # An empty choices for Usage as per OpenAI doc below: # if you set stream_options: {"include_usage": true}. # an additional chunk will be streamed before the data: [DONE] message. # The usage field on this chunk shows the token usage statistics for the entire request, # and the choices field will always be an empty array. # All other chunks will also include a usage field, but with a null value. yield self.stream_response_to_bytes(stream_response) # return an [DONE] message at the end. yield self.stream_response_to_bytes() def _parse_system_prompts(self, chat_request: ChatRequest) -> list[dict[str, str]]: """Create system prompts. Note that not all models support system prompts. example output: [{"text" : system_prompt}] See example: https://docs.aws.amazon.com/bedrock/latest/userguide/conversation-inference.html#message-inference-examples """ system_prompts = [] for message in chat_request.messages: if message.role != "system": # ignore system messages here continue assert isinstance(message.content, str) system_prompts.append({"text": message.content}) return system_prompts def _parse_messages(self, chat_request: ChatRequest) -> list[dict]: """ Converse API only support user and assistant messages. 
example output: [{ "role": "user", "content": [{"text": input_text}] }] See example: https://docs.aws.amazon.com/bedrock/latest/userguide/conversation-inference.html#message-inference-examples """ messages = [] for message in chat_request.messages: if isinstance(message, UserMessage): messages.append( { "role": message.role, "content": self._parse_content_parts( message, chat_request.model ), } ) elif isinstance(message, AssistantMessage): if message.content: # Text message messages.append( { "role": message.role, "content": self._parse_content_parts( message, chat_request.model ), } ) else: # Tool use message tool_input = json.loads(message.tool_calls[0].function.arguments) messages.append( { "role": message.role, "content": [ { "toolUse": { "toolUseId": message.tool_calls[0].id, "name": message.tool_calls[0].function.name, "input": tool_input } } ], } ) elif isinstance(message, ToolMessage): # Bedrock does not support tool role, # Add toolResult to content # https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ToolResultBlock.html messages.append( { "role": "user", "content": [ { "toolResult": { "toolUseId": message.tool_call_id, "content": [{"text": message.content}], } } ], } ) else: # ignore others, such as system messages continue return self._reframe_multi_payloard(messages) def _reframe_multi_payloard(self, messages: list) -> list: """ Receive messages and reformat them to comply with the Claude format With OpenAI format requests, it's not a problem to repeatedly receive messages from the same role, but with Claude format requests, you cannot repeatedly receive messages from the same role. This method searches through the OpenAI format messages in order and reformats them to the Claude format. 
``` openai_format_messages=[ {"role": "user", "content": "Hello"}, {"role": "user", "content": "Who are you?"}, ] bedrock_format_messages=[ { "role": "user", "content": [ {"text": "Hello"}, {"text": "Who are you?"} ] }, ] """ reformatted_messages = [] current_role = None current_content = [] # Search through the list of messages and combine messages from the same role into one list for message in messages: next_role = message['role'] next_content = message['content'] # If the next role is different from the previous message, add the previous role's messages to the list if next_role != current_role: if current_content: reformatted_messages.append({ "role": current_role, "content": current_content }) # Switch to the new role current_role = next_role current_content = [] # Add the message content to current_content if isinstance(next_content, str): current_content.append({"text": next_content}) elif isinstance(next_content, list): current_content.extend(next_content) # Add the last role's messages to the list if current_content: reformatted_messages.append({ "role": current_role, "content": current_content }) return reformatted_messages def _parse_request(self, chat_request: ChatRequest) -> dict: """Create default converse request body. Also perform validations to tool call etc. Ref: https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_Converse.html """ messages = self._parse_messages(chat_request) system_prompts = self._parse_system_prompts(chat_request) # Base inference parameters. 
inference_config = { "temperature": chat_request.temperature, "maxTokens": chat_request.max_tokens, "topP": chat_request.top_p, } if chat_request.stop is not None: stop = chat_request.stop if isinstance(stop, str): stop = [stop] inference_config["stopSequences"] = stop args = { "modelId": chat_request.model, "messages": messages, "system": system_prompts, "inferenceConfig": inference_config, } if chat_request.reasoning_effort: # From OpenAI api, the max_token is not supported in reasoning mode # Use max_completion_tokens if provided. max_tokens = chat_request.max_completion_tokens if chat_request.max_completion_tokens else chat_request.max_tokens inference_config["maxTokens"] = max_tokens # unset topP - Not supported inference_config.pop("topP") args["additionalModelRequestFields"] = { "reasoning_config": { "type": "enabled", "budget_tokens": max_tokens - 1 } } # add tool config if chat_request.tools: args["toolConfig"] = { "tools": [ self._convert_tool_spec(t.function) for t in chat_request.tools ] } if chat_request.tool_choice and not chat_request.model.startswith("meta.llama3-1-"): if isinstance(chat_request.tool_choice, str): # auto (default) is mapped to {"auto" : {}} # required is mapped to {"any" : {}} if chat_request.tool_choice == "required": args["toolConfig"]["toolChoice"] = {"any": {}} else: args["toolConfig"]["toolChoice"] = {"auto": {}} else: # Specific tool to use assert "function" in chat_request.tool_choice args["toolConfig"]["toolChoice"] = { "tool": {"name": chat_request.tool_choice["function"].get("name", "")}} return args def _create_response( self, model: str, message_id: str, content: list[dict] = None, finish_reason: str | None = None, input_tokens: int = 0, output_tokens: int = 0, ) -> ChatResponse: message = ChatResponseMessage( role="assistant", ) if finish_reason == "tool_use": # https://docs.aws.amazon.com/bedrock/latest/userguide/tool-use.html#tool-use-examples tool_calls = [] for part in content: if "toolUse" in part: tool = 
part["toolUse"] tool_calls.append( ToolCall( id=tool["toolUseId"], type="function", function=ResponseFunction( name=tool["name"], arguments=json.dumps(tool["input"]), ), ) ) message.tool_calls = tool_calls message.content = None else: message.content = "" for c in content: if "reasoningContent" in c: message.reasoning_content = c["reasoningContent"]["reasoningText"].get("text", "") if "text" in c: message.content = c["text"] else: logger.warning("Unknown tag in message content " + ",".join(c.keys())) response = ChatResponse( id=message_id, model=model, choices=[ Choice( index=0, message=message, finish_reason=self._convert_finish_reason(finish_reason), logprobs=None, ) ], usage=Usage( prompt_tokens=input_tokens, completion_tokens=output_tokens, total_tokens=input_tokens + output_tokens, ), ) response.system_fingerprint = "fp" response.object = "chat.completion" response.created = int(time.time()) return response def _create_response_stream( self, model_id: str, message_id: str, chunk: dict ) -> ChatStreamResponse | None: """Parsing the Bedrock stream response chunk. 
Ref: https://docs.aws.amazon.com/bedrock/latest/userguide/conversation-inference.html#message-inference-examples """ if DEBUG: logger.info("Bedrock response chunk: " + str(chunk)) finish_reason = None message = None usage = None if "messageStart" in chunk: message = ChatResponseMessage( role=chunk["messageStart"]["role"], content="", ) if "contentBlockStart" in chunk: # tool call start delta = chunk["contentBlockStart"]["start"] if "toolUse" in delta: # first index is content index = chunk["contentBlockStart"]["contentBlockIndex"] - 1 message = ChatResponseMessage( tool_calls=[ ToolCall( index=index, type="function", id=delta["toolUse"]["toolUseId"], function=ResponseFunction( name=delta["toolUse"]["name"], arguments="", ), ) ] ) if "contentBlockDelta" in chunk: delta = chunk["contentBlockDelta"]["delta"] if "text" in delta: # stream content message = ChatResponseMessage( content=delta["text"], ) else: # tool use index = chunk["contentBlockDelta"]["contentBlockIndex"] - 1 message = ChatResponseMessage( tool_calls=[ ToolCall( index=index, function=ResponseFunction( arguments=delta["toolUse"]["input"], ) ) ] ) if "messageStop" in chunk: message = ChatResponseMessage() finish_reason = chunk["messageStop"]["stopReason"] if "metadata" in chunk: # usage information in metadata. metadata = chunk["metadata"] if "usage" in metadata: # token usage return ChatStreamResponse( id=message_id, model=model_id, choices=[], usage=Usage( prompt_tokens=metadata["usage"]["inputTokens"], completion_tokens=metadata["usage"]["outputTokens"], total_tokens=metadata["usage"]["totalTokens"], ), ) if message: return ChatStreamResponse( id=message_id, model=model_id, choices=[ ChoiceDelta( index=0, delta=message, logprobs=None, finish_reason=self._convert_finish_reason(finish_reason), ) ], usage=usage, ) return None def _parse_image(self, image_url: str) -> tuple[bytes, str]: """Try to get the raw data from an image url. 
Ref: https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ImageSource.html returns a tuple of (Image Data, Content Type) """ pattern = r"^data:(image/[a-z]*);base64,\s*" content_type = re.search(pattern, image_url) # if already base64 encoded. # Only supports 'image/jpeg', 'image/png', 'image/gif' or 'image/webp' if content_type: image_data = re.sub(pattern, "", image_url) return base64.b64decode(image_data), content_type.group(1) # Send a request to the image URL response = requests.get(image_url) # Check if the request was successful if response.status_code == 200: content_type = response.headers.get("Content-Type") if not content_type.startswith("image"): content_type = "image/jpeg" # Get the image content image_content = response.content return image_content, content_type else: raise HTTPException( status_code=500, detail="Unable to access the image url" ) def _parse_content_parts( self, message: UserMessage, model_id: str, ) -> list[dict]: if isinstance(message.content, str): return [ { "text": message.content, } ] content_parts = [] for part in message.content: if isinstance(part, TextContent): content_parts.append( { "text": part.text, } ) elif isinstance(part, ImageContent): if not self.is_supported_modality(model_id, modality="IMAGE"): raise HTTPException( status_code=400, detail=f"Multimodal message is currently not supported by {model_id}", ) image_data, content_type = self._parse_image(part.image_url.url) content_parts.append( { "image": { "format": content_type[6:], # image/ "source": {"bytes": image_data}, }, } ) else: # Ignore.. 
continue return content_parts @staticmethod def is_supported_modality(model_id: str, modality: str = "IMAGE") -> bool: model = bedrock_model_list.get(model_id) modalities = model.get('modalities', []) if modality in modalities: return True return False def _convert_tool_spec(self, func: Function) -> dict: return { "toolSpec": { "name": func.name, "description": func.description, "inputSchema": { "json": func.parameters, }, } } def _convert_finish_reason(self, finish_reason: str | None) -> str | None: """ Below is a list of finish reason according to OpenAI doc: - stop: if the model hit a natural stop point or a provided stop sequence, - length: if the maximum number of tokens specified in the request was reached, - content_filter: if content was omitted due to a flag from our content filters, - tool_calls: if the model called a tool """ if finish_reason: finish_reason_mapping = { "tool_use": "tool_calls", "finished": "stop", "end_turn": "stop", "max_tokens": "length", "stop_sequence": "stop", "complete": "stop", "content_filtered": "content_filter" } return finish_reason_mapping.get(finish_reason.lower(), finish_reason.lower()) return None class BedrockEmbeddingsModel(BaseEmbeddingsModel, ABC): accept = "application/json" content_type = "application/json" def _invoke_model(self, args: dict, model_id: str): body = json.dumps(args) if DEBUG: logger.info("Invoke Bedrock Model: " + model_id) logger.info("Bedrock request body: " + body) try: return bedrock_runtime.invoke_model( body=body, modelId=model_id, accept=self.accept, contentType=self.content_type, ) except bedrock_runtime.exceptions.ValidationException as e: logger.error("Validation Error: " + str(e)) raise HTTPException(status_code=400, detail=str(e)) except Exception as e: logger.error(e) raise HTTPException(status_code=500, detail=str(e)) def _create_response( self, embeddings: list[float], model: str, input_tokens: int = 0, output_tokens: int = 0, encoding_format: Literal["float", "base64"] = "float", ) -> 
EmbeddingsResponse: data = [] for i, embedding in enumerate(embeddings): if encoding_format == "base64": arr = np.array(embedding, dtype=np.float32) arr_bytes = arr.tobytes() encoded_embedding = base64.b64encode(arr_bytes) data.append(Embedding(index=i, embedding=encoded_embedding)) else: data.append(Embedding(index=i, embedding=embedding)) response = EmbeddingsResponse( data=data, model=model, usage=EmbeddingsUsage( prompt_tokens=input_tokens, total_tokens=input_tokens + output_tokens, ), ) if DEBUG: logger.info("Proxy response :" + response.model_dump_json()) return response class CohereEmbeddingsModel(BedrockEmbeddingsModel): def _parse_args(self, embeddings_request: EmbeddingsRequest) -> dict: texts = [] if isinstance(embeddings_request.input, str): texts = [embeddings_request.input] elif isinstance(embeddings_request.input, list): texts = embeddings_request.input elif isinstance(embeddings_request.input, Iterable): # For encoded input # The workaround is to use tiktoken to decode to get the original text. 
encodings = [] for inner in embeddings_request.input: if isinstance(inner, int): # Iterable[int] encodings.append(inner) else: # Iterable[Iterable[int]] text = ENCODER.decode(list(inner)) texts.append(text) if encodings: texts.append(ENCODER.decode(encodings)) # Maximum of 2048 characters args = { "texts": texts, "input_type": "search_document", "truncate": "END", # "NONE|START|END" } return args def embed(self, embeddings_request: EmbeddingsRequest) -> EmbeddingsResponse: response = self._invoke_model( args=self._parse_args(embeddings_request), model_id=embeddings_request.model ) response_body = json.loads(response.get("body").read()) if DEBUG: logger.info("Bedrock response body: " + str(response_body)) return self._create_response( embeddings=response_body["embeddings"], model=embeddings_request.model, encoding_format=embeddings_request.encoding_format, ) class TitanEmbeddingsModel(BedrockEmbeddingsModel): def _parse_args(self, embeddings_request: EmbeddingsRequest) -> dict: if isinstance(embeddings_request.input, str): input_text = embeddings_request.input elif ( isinstance(embeddings_request.input, list) and len(embeddings_request.input) == 1 ): input_text = embeddings_request.input[0] else: raise ValueError( "Amazon Titan Embeddings models support only single strings as input." ) args = { "inputText": input_text, # Note: inputImage is not supported! 
} if embeddings_request.model == "amazon.titan-embed-image-v1": args["embeddingConfig"] = ( embeddings_request.embedding_config if embeddings_request.embedding_config else {"outputEmbeddingLength": 1024} ) return args def embed(self, embeddings_request: EmbeddingsRequest) -> EmbeddingsResponse: response = self._invoke_model( args=self._parse_args(embeddings_request), model_id=embeddings_request.model ) response_body = json.loads(response.get("body").read()) if DEBUG: logger.info("Bedrock response body: " + str(response_body)) return self._create_response( embeddings=[response_body["embedding"]], model=embeddings_request.model, input_tokens=response_body["inputTextTokenCount"], ) def get_embeddings_model(model_id: str) -> BedrockEmbeddingsModel: model_name = SUPPORTED_BEDROCK_EMBEDDING_MODELS.get(model_id, "") if DEBUG: logger.info("model name is " + model_name) match model_name: case "Cohere Embed Multilingual" | "Cohere Embed English": return CohereEmbeddingsModel() case _: logger.error("Unsupported model id " + model_id) raise HTTPException( status_code=400, detail="Unsupported embedding model id " + model_id, )