diff --git a/.flake8 b/.flake8
deleted file mode 100644
index 364a711..0000000
--- a/.flake8
+++ /dev/null
@@ -1,19 +0,0 @@
-[flake8]
-max-line-length = 120
-ignore =
-    E203,W191,W503
-exclude =
-    build
-    .git
-    __pycache__
-    .tox
-    venv
-    .venv
-    .venv-test
-    tmp*
-    deployment
-    cdk.out
-    node_modules
-
-max-complexity = 10
-require-code = True
\ No newline at end of file
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
new file mode 100644
index 0000000..6198a25
--- /dev/null
+++ b/.pre-commit-config.yaml
@@ -0,0 +1,9 @@
+repos:
+  - repo: https://github.com/astral-sh/ruff-pre-commit
+    # Ruff version.
+    rev: v0.9.10
+    hooks:
+      # Run the linter.
+      - id: ruff
+      # Run the formatter.
+      - id: ruff-format
\ No newline at end of file
diff --git a/ruff.toml b/ruff.toml
new file mode 100644
index 0000000..dd77d11
--- /dev/null
+++ b/ruff.toml
@@ -0,0 +1,23 @@
+line-length = 120
+indent-width = 4
+target-version = "py312"
+
+exclude = [
+    ".venv",
+    ".vscode",
+    "test/*"
+]
+
+[lint]
+select = ["E", "F"]
+ignore = [
+    "E501",
+    "B008",
+    "C901",
+    "F401",
+    "W191",
+]
+
+[format]
+# use double quotes for strings.
+quote-style = "double"
\ No newline at end of file
diff --git a/src/api/auth.py b/src/api/auth.py
index 22e711f..1a64653 100644
--- a/src/api/auth.py
+++ b/src/api/auth.py
@@ -16,9 +16,7 @@ if api_key_param:
     # For backward compatibility.
     # Please now use secrets manager instead.
     ssm = boto3.client("ssm")
-    api_key = ssm.get_parameter(Name=api_key_param, WithDecryption=True)["Parameter"][
-        "Value"
-    ]
+    api_key = ssm.get_parameter(Name=api_key_param, WithDecryption=True)["Parameter"]["Value"]
 elif api_key_secret_arn:
     sm = boto3.client("secretsmanager")
     try:
@@ -26,11 +24,9 @@ elif api_key_secret_arn:
         if "SecretString" in response:
             secret = json.loads(response["SecretString"])
             api_key = secret["api_key"]
-    except ClientError as e:
-        raise RuntimeError(
-            "Unable to retrieve API KEY, please ensure the secret ARN is correct"
-        )
-    except KeyError as e:
+    except ClientError:
+        raise RuntimeError("Unable to retrieve API KEY, please ensure the secret ARN is correct")
+    except KeyError:
         raise RuntimeError('Please ensure the secret contains a "api_key" field')
 elif api_key_env:
     api_key = api_key_env
@@ -45,6 +41,4 @@ def api_key_auth(
     credentials: Annotated[HTTPAuthorizationCredentials, Depends(security)],
 ):
     if credentials.credentials != api_key:
-        raise HTTPException(
-            status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid API Key"
-        )
+        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid API Key")
diff --git a/src/api/models/base.py b/src/api/models/base.py
index 9d9db7f..659fdff 100644
--- a/src/api/models/base.py
+++ b/src/api/models/base.py
@@ -43,9 +43,7 @@ class BaseChatModel(ABC):
         return "chatcmpl-" + str(uuid.uuid4())[:8]
 
     @staticmethod
-    def stream_response_to_bytes(
-        response: ChatStreamResponse | None = None
-    ) -> bytes:
+    def stream_response_to_bytes(response: ChatStreamResponse | None = None) -> bytes:
         if response:
             # to populate other fields when using exclude_unset=True
             response.system_fingerprint = "fp"
diff --git a/src/api/models/bedrock.py b/src/api/models/bedrock.py
index a9ecd99..7a35157 100644
--- a/src/api/models/bedrock.py
+++ b/src/api/models/bedrock.py
@@ -36,7 +36,6 @@ from api.schema import (
     EmbeddingsResponse,
     EmbeddingsUsage,
     Embedding,
-
 )
 
 from api.setting import DEBUG, AWS_REGION, ENABLE_CROSS_REGION_INFERENCE, DEFAULT_MODEL
@@ -50,15 +49,15 @@ bedrock_runtime = boto3.client(
     config=config,
 )
 bedrock_client = boto3.client(
-    service_name='bedrock',
+    service_name="bedrock",
     region_name=AWS_REGION,
     config=config,
 )
 
 
 def get_inference_region_prefix():
-    if AWS_REGION.startswith('ap-'):
-        return 'apac'
+    if AWS_REGION.startswith("ap-"):
+        return "apac"
     return AWS_REGION[:2]
 
 
@@ -88,49 +87,38 @@ def list_bedrock_models() -> dict:
         profile_list = []
         if ENABLE_CROSS_REGION_INFERENCE:
             # List system defined inference profile IDs
-            response = bedrock_client.list_inference_profiles(
-                maxResults=1000,
-                typeEquals='SYSTEM_DEFINED'
-            )
-            profile_list = [p['inferenceProfileId'] for p in response['inferenceProfileSummaries']]
+            response = bedrock_client.list_inference_profiles(maxResults=1000, typeEquals="SYSTEM_DEFINED")
+            profile_list = [p["inferenceProfileId"] for p in response["inferenceProfileSummaries"]]
 
         # List foundation models, only cares about text outputs here.
-        response = bedrock_client.list_foundation_models(
-            byOutputModality='TEXT'
-        )
+        response = bedrock_client.list_foundation_models(byOutputModality="TEXT")
 
-        for model in response['modelSummaries']:
-            model_id = model.get('modelId', 'N/A')
-            stream_supported = model.get('responseStreamingSupported', True)
-            status = model['modelLifecycle'].get('status', 'ACTIVE')
+        for model in response["modelSummaries"]:
+            model_id = model.get("modelId", "N/A")
+            stream_supported = model.get("responseStreamingSupported", True)
+            status = model["modelLifecycle"].get("status", "ACTIVE")
 
             # currently, use this to filter out rerank models and legacy models
             if not stream_supported or status not in ["ACTIVE", "LEGACY"]:
                 continue
 
-            inference_types = model.get('inferenceTypesSupported', [])
-            input_modalities = model['inputModalities']
+            inference_types = model.get("inferenceTypesSupported", [])
+            input_modalities = model["inputModalities"]
 
             # Add on-demand model list
-            if 'ON_DEMAND' in inference_types:
-                model_list[model_id] = {
-                    'modalities': input_modalities
-                }
+            if "ON_DEMAND" in inference_types:
+                model_list[model_id] = {"modalities": input_modalities}
 
             # Add cross-region inference model list.
-            profile_id = cr_inference_prefix + '.' + model_id
+            profile_id = cr_inference_prefix + "." + model_id
             if profile_id in profile_list:
-                model_list[profile_id] = {
-                    'modalities': input_modalities
-                }
+                model_list[profile_id] = {"modalities": input_modalities}
 
     except Exception as e:
         logger.error(f"Unable to list models: {str(e)}")
 
     if not model_list:
         # In case stack not updated.
-        model_list[DEFAULT_MODEL] = {
-            'modalities': ["TEXT", "IMAGE"]
-        }
+        model_list[DEFAULT_MODEL] = {"modalities": ["TEXT", "IMAGE"]}
 
     return model_list
@@ -140,7 +128,6 @@ bedrock_model_list = list_bedrock_models()
 
 
 class BedrockModel(BaseChatModel):
-
     def list_models(self) -> list[str]:
         """Always refresh the latest model list"""
         global bedrock_model_list
@@ -224,10 +211,7 @@ class BedrockModel(BaseChatModel):
                 logger.info("Proxy response :" + stream_response.model_dump_json())
             if stream_response.choices:
                 yield self.stream_response_to_bytes(stream_response)
-            elif (
-                chat_request.stream_options
-                and chat_request.stream_options.include_usage
-            ):
+            elif chat_request.stream_options and chat_request.stream_options.include_usage:
                 # An empty choices for Usage as per OpenAI doc below:
                 # if you set stream_options: {"include_usage": true}.
                 # an additional chunk will be streamed before the data: [DONE] message.
@@ -277,9 +261,7 @@ class BedrockModel(BaseChatModel):
                 messages.append(
                     {
                         "role": message.role,
-                        "content": self._parse_content_parts(
-                            message, chat_request.model
-                        ),
+                        "content": self._parse_content_parts(message, chat_request.model),
                     }
                 )
             elif isinstance(message, AssistantMessage):
@@ -288,9 +270,7 @@ class BedrockModel(BaseChatModel):
                     messages.append(
                         {
                            "role": message.role,
-                            "content": self._parse_content_parts(
-                                message, chat_request.model
-                            ),
+                            "content": self._parse_content_parts(message, chat_request.model),
                         }
                     )
                 if message.tool_calls:
@@ -305,7 +285,7 @@ class BedrockModel(BaseChatModel):
                                     "toolUse": {
                                         "toolUseId": tool_call.id,
                                         "name": tool_call.function.name,
-                                        "input": tool_input
+                                        "input": tool_input,
                                     }
                                 }
                             ],
@@ -335,7 +315,7 @@ class BedrockModel(BaseChatModel):
         return self._reframe_multi_payloard(messages)
 
     def _reframe_multi_payloard(self, messages: list) -> list:
-        """ Receive messages and reformat them to comply with the Claude format
+        """Receive messages and reformat them to comply with the Claude format
         With OpenAI format requests, it's not a problem to repeatedly receive
         messages from the same role, but with Claude format requests, you cannot
         repeatedly receive messages from the same role.
@@ -364,16 +344,13 @@ class BedrockModel(BaseChatModel):
 
         # Search through the list of messages and combine messages from the same role into one list
        for message in messages:
-            next_role = message['role']
-            next_content = message['content']
+            next_role = message["role"]
+            next_content = message["content"]
 
            # If the next role is different from the previous message, add the previous role's messages to the list
            if next_role != current_role:
                if current_content:
-                    reformatted_messages.append({
-                        "role": current_role,
-                        "content": current_content
-                    })
+                    reformatted_messages.append({"role": current_role, "content": current_content})
                # Switch to the new role
                current_role = next_role
                current_content = []
@@ -386,10 +363,7 @@ class BedrockModel(BaseChatModel):
 
         # Add the last role's messages to the list
         if current_content:
-            reformatted_messages.append({
-                "role": current_role,
-                "content": current_content
-            })
+            reformatted_messages.append({"role": current_role, "content": current_content})
 
         return reformatted_messages
 
@@ -426,25 +400,20 @@ class BedrockModel(BaseChatModel):
             # From OpenAI api, the max_token is not supported in reasoning mode
             # Use max_completion_tokens if provided.
-            max_tokens = chat_request.max_completion_tokens if chat_request.max_completion_tokens else chat_request.max_tokens
+            max_tokens = (
+                chat_request.max_completion_tokens if chat_request.max_completion_tokens else chat_request.max_tokens
+            )
             budget_tokens = self._calc_budget_tokens(max_tokens, chat_request.reasoning_effort)
             inference_config["maxTokens"] = max_tokens
             # unset topP - Not supported
             inference_config.pop("topP")
 
             args["additionalModelRequestFields"] = {
-                "reasoning_config": {
-                    "type": "enabled",
-                    "budget_tokens": budget_tokens
-                }
+                "reasoning_config": {"type": "enabled", "budget_tokens": budget_tokens}
             }
 
         # add tool config
         if chat_request.tools:
-            args["toolConfig"] = {
-                "tools": [
-                    self._convert_tool_spec(t.function) for t in chat_request.tools
-                ]
-            }
+            args["toolConfig"] = {"tools": [self._convert_tool_spec(t.function) for t in chat_request.tools]}
 
         if chat_request.tool_choice and not chat_request.model.startswith("meta.llama3-1-"):
             if isinstance(chat_request.tool_choice, str):
@@ -458,19 +427,19 @@ class BedrockModel(BaseChatModel):
                 # Specific tool to use
                 assert "function" in chat_request.tool_choice
                 args["toolConfig"]["toolChoice"] = {
-                    "tool": {"name": chat_request.tool_choice["function"].get("name", "")}}
+                    "tool": {"name": chat_request.tool_choice["function"].get("name", "")}
+                }
         return args
 
     def _create_response(
-            self,
-            model: str,
-            message_id: str,
-            content: list[dict] = None,
-            finish_reason: str | None = None,
-            input_tokens: int = 0,
-            output_tokens: int = 0,
+        self,
+        model: str,
+        message_id: str,
+        content: list[dict] = None,
+        finish_reason: str | None = None,
+        input_tokens: int = 0,
+        output_tokens: int = 0,
     ) -> ChatResponse:
-
         message = ChatResponseMessage(
             role="assistant",
         )
@@ -524,9 +493,7 @@ class BedrockModel(BaseChatModel):
         response.created = int(time.time())
         return response
 
-    def _create_response_stream(
-        self, model_id: str, message_id: str, chunk: dict
-    ) -> ChatStreamResponse | None:
+    def _create_response_stream(self, model_id: str, message_id: str, chunk: dict) -> ChatStreamResponse | None:
         """Parsing the Bedrock stream response chunk.
 
         Ref: https://docs.aws.amazon.com/bedrock/latest/userguide/conversation-inference.html#message-inference-examples
@@ -583,7 +550,7 @@ class BedrockModel(BaseChatModel):
                                 index=index,
                                 function=ResponseFunction(
                                     arguments=delta["toolUse"]["input"],
-                                )
+                                ),
                             )
                         ]
                     )
@@ -641,7 +608,6 @@ class BedrockModel(BaseChatModel):
         response = requests.get(image_url)
         # Check if the request was successful
         if response.status_code == 200:
-
             content_type = response.headers.get("Content-Type")
             if not content_type.startswith("image"):
                 content_type = "image/jpeg"
@@ -649,14 +615,12 @@ class BedrockModel(BaseChatModel):
             image_content = response.content
             return image_content, content_type
         else:
-            raise HTTPException(
-                status_code=500, detail="Unable to access the image url"
-            )
+            raise HTTPException(status_code=500, detail="Unable to access the image url")
 
     def _parse_content_parts(
-            self,
-            message: UserMessage,
-            model_id: str,
+        self,
+        message: UserMessage,
+        model_id: str,
     ) -> list[dict]:
         if isinstance(message.content, str):
             return [
@@ -695,7 +659,7 @@ class BedrockModel(BaseChatModel):
     @staticmethod
     def is_supported_modality(model_id: str, modality: str = "IMAGE") -> bool:
         model = bedrock_model_list.get(model_id)
-        modalities = model.get('modalities', [])
+        modalities = model.get("modalities", [])
         if modality in modalities:
             return True
         return False
@@ -740,7 +704,7 @@ class BedrockModel(BaseChatModel):
             "max_tokens": "length",
             "stop_sequence": "stop",
             "complete": "stop",
-            "content_filtered": "content_filter"
+            "content_filtered": "content_filter",
         }
         return finish_reason_mapping.get(finish_reason.lower(), finish_reason.lower())
     return None
@@ -773,12 +737,12 @@ class BedrockEmbeddingsModel(BaseEmbeddingsModel, ABC):
             raise HTTPException(status_code=500, detail=str(e))
 
     def _create_response(
-            self,
-            embeddings: list[float],
-            model: str,
-            input_tokens: int = 0,
-            output_tokens: int = 0,
-            encoding_format: Literal["float", "base64"] = "float",
+        self,
+        embeddings: list[float],
+        model: str,
+        input_tokens: int = 0,
+        output_tokens: int = 0,
+        encoding_format: Literal["float", "base64"] = "float",
     ) -> EmbeddingsResponse:
         data = []
         for i, embedding in enumerate(embeddings):
@@ -803,7 +767,6 @@ class BedrockEmbeddingsModel(BaseEmbeddingsModel, ABC):
 
 
 class CohereEmbeddingsModel(BedrockEmbeddingsModel):
-
     def _parse_args(self, embeddings_request: EmbeddingsRequest) -> dict:
         texts = []
         if isinstance(embeddings_request.input, str):
@@ -834,9 +797,7 @@ class CohereEmbeddingsModel(BedrockEmbeddingsModel):
         return args
 
     def embed(self, embeddings_request: EmbeddingsRequest) -> EmbeddingsResponse:
-        response = self._invoke_model(
-            args=self._parse_args(embeddings_request), model_id=embeddings_request.model
-        )
+        response = self._invoke_model(args=self._parse_args(embeddings_request), model_id=embeddings_request.model)
         response_body = json.loads(response.get("body").read())
         if DEBUG:
             logger.info("Bedrock response body: " + str(response_body))
@@ -849,19 +810,13 @@ class CohereEmbeddingsModel(BedrockEmbeddingsModel):
 
 
 class TitanEmbeddingsModel(BedrockEmbeddingsModel):
-
     def _parse_args(self, embeddings_request: EmbeddingsRequest) -> dict:
         if isinstance(embeddings_request.input, str):
             input_text = embeddings_request.input
-        elif (
-            isinstance(embeddings_request.input, list)
-            and len(embeddings_request.input) == 1
-        ):
+        elif isinstance(embeddings_request.input, list) and len(embeddings_request.input) == 1:
            input_text = embeddings_request.input[0]
         else:
-            raise ValueError(
-                "Amazon Titan Embeddings models support only single strings as input."
-            )
+            raise ValueError("Amazon Titan Embeddings models support only single strings as input.")
         args = {
             "inputText": input_text,
             # Note: inputImage is not supported!
@@ -875,9 +830,7 @@ class TitanEmbeddingsModel(BedrockEmbeddingsModel):
         return args
 
     def embed(self, embeddings_request: EmbeddingsRequest) -> EmbeddingsResponse:
-        response = self._invoke_model(
-            args=self._parse_args(embeddings_request), model_id=embeddings_request.model
-        )
+        response = self._invoke_model(args=self._parse_args(embeddings_request), model_id=embeddings_request.model)
         response_body = json.loads(response.get("body").read())
         if DEBUG:
             logger.info("Bedrock response body: " + str(response_body))
diff --git a/src/api/routers/chat.py b/src/api/routers/chat.py
index 1e48a48..aa200bf 100644
--- a/src/api/routers/chat.py
+++ b/src/api/routers/chat.py
@@ -17,20 +17,20 @@ router = APIRouter(
 
 @router.post("/completions", response_model=ChatResponse | ChatStreamResponse, response_model_exclude_unset=True)
 async def chat_completions(
-        chat_request: Annotated[
-            ChatRequest,
-            Body(
-                examples=[
-                    {
-                        "model": "anthropic.claude-3-sonnet-20240229-v1:0",
-                        "messages": [
-                            {"role": "system", "content": "You are a helpful assistant."},
-                            {"role": "user", "content": "Hello!"},
-                        ],
-                    }
-                ],
-            ),
-        ]
+    chat_request: Annotated[
+        ChatRequest,
+        Body(
+            examples=[
+                {
+                    "model": "anthropic.claude-3-sonnet-20240229-v1:0",
+                    "messages": [
+                        {"role": "system", "content": "You are a helpful assistant."},
+                        {"role": "user", "content": "Hello!"},
+                    ],
+                }
+            ],
+        ),
+    ],
 ):
     if chat_request.model.lower().startswith("gpt-"):
         chat_request.model = DEFAULT_MODEL
@@ -39,7 +39,5 @@ async def chat_completions(
     model = BedrockModel()
     model.validate(chat_request)
     if chat_request.stream:
-        return StreamingResponse(
-            content=model.chat_stream(chat_request), media_type="text/event-stream"
-        )
+        return StreamingResponse(content=model.chat_stream(chat_request), media_type="text/event-stream")
     return model.chat(chat_request)
diff --git a/src/api/routers/embeddings.py b/src/api/routers/embeddings.py
index e5cde31..1ccf627 100644
--- a/src/api/routers/embeddings.py
+++ b/src/api/routers/embeddings.py
@@ -15,19 +15,17 @@ router = APIRouter(
 
 @router.post("", response_model=EmbeddingsResponse)
 async def embeddings(
-        embeddings_request: Annotated[
-            EmbeddingsRequest,
-            Body(
-                examples=[
-                    {
-                        "model": "cohere.embed-multilingual-v3",
-                        "input": [
-                            "Your text string goes here"
-                        ],
-                    }
-                ],
-            ),
-        ]
+    embeddings_request: Annotated[
+        EmbeddingsRequest,
+        Body(
+            examples=[
+                {
+                    "model": "cohere.embed-multilingual-v3",
+                    "input": ["Your text string goes here"],
+                }
+            ],
+        ),
+    ],
 ):
     if embeddings_request.model.lower().startswith("text-embedding-"):
         embeddings_request.model = DEFAULT_EMBEDDING_MODEL
diff --git a/src/api/routers/model.py b/src/api/routers/model.py
index ce5e8a1..71054fb 100644
--- a/src/api/routers/model.py
+++ b/src/api/routers/model.py
@@ -22,9 +22,7 @@ async def validate_model_id(model_id: str):
 
 @router.get("", response_model=Models)
 async def list_models():
-    model_list = [
-        Model(id=model_id) for model_id in chat_model.list_models()
-    ]
+    model_list = [Model(id=model_id) for model_id in chat_model.list_models()]
     return Models(data=model_list)
 
 
@@ -33,10 +31,10 @@ async def list_models():
     response_model=Model,
 )
 async def get_model(
-        model_id: Annotated[
-            str,
-            Path(description="Model ID", example="anthropic.claude-3-sonnet-20240229-v1:0"),
-        ]
+    model_id: Annotated[
+        str,
+        Path(description="Model ID", example="anthropic.claude-3-sonnet-20240229-v1:0"),
+    ],
 ):
     await validate_model_id(model_id)
     return Model(id=model_id)
diff --git a/src/api/setting.py b/src/api/setting.py
index 2bce1dc..e090300 100644
--- a/src/api/setting.py
+++ b/src/api/setting.py
@@ -13,10 +13,6 @@ Use OpenAI-Compatible RESTful APIs for Amazon Bedrock models.
 
 DEBUG = os.environ.get("DEBUG", "false").lower() != "false"
 AWS_REGION = os.environ.get("AWS_REGION", "us-west-2")
-DEFAULT_MODEL = os.environ.get(
-    "DEFAULT_MODEL", "anthropic.claude-3-sonnet-20240229-v1:0"
-)
-DEFAULT_EMBEDDING_MODEL = os.environ.get(
-    "DEFAULT_EMBEDDING_MODEL", "cohere.embed-multilingual-v3"
-)
+DEFAULT_MODEL = os.environ.get("DEFAULT_MODEL", "anthropic.claude-3-sonnet-20240229-v1:0")
+DEFAULT_EMBEDDING_MODEL = os.environ.get("DEFAULT_EMBEDDING_MODEL", "cohere.embed-multilingual-v3")
 ENABLE_CROSS_REGION_INFERENCE = os.environ.get("ENABLE_CROSS_REGION_INFERENCE", "true").lower() != "false"
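
Usage note: this patch replaces the flake8 setup with Ruff for both linting and formatting, wired into git via pre-commit. A typical local workflow against these configs might look like the following (standard ruff and pre-commit CLI commands; installing via pip is an assumption, any installer works):

    pip install ruff pre-commit   # install the tools (assumed; use your preferred installer)
    pre-commit install            # register the hooks from .pre-commit-config.yaml
    ruff check src/               # lint with the E/F rules selected in ruff.toml
    ruff format src/              # format (120-char lines, double quotes) per ruff.toml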