apply ruff linter

Aiden Dai
2025-03-13 13:50:57 +08:00
parent 33e8fcfd3b
commit f21b9a2e84
10 changed files with 128 additions and 181 deletions

View File

@@ -16,9 +16,7 @@ if api_key_param:
# For backward compatibility.
# Please now use secrets manager instead.
ssm = boto3.client("ssm")
-    api_key = ssm.get_parameter(Name=api_key_param, WithDecryption=True)["Parameter"][
-        "Value"
-    ]
+    api_key = ssm.get_parameter(Name=api_key_param, WithDecryption=True)["Parameter"]["Value"]
elif api_key_secret_arn:
sm = boto3.client("secretsmanager")
try:
@@ -26,11 +24,9 @@ elif api_key_secret_arn:
if "SecretString" in response:
secret = json.loads(response["SecretString"])
api_key = secret["api_key"]
-    except ClientError as e:
-        raise RuntimeError(
-            "Unable to retrieve API KEY, please ensure the secret ARN is correct"
-        )
-    except KeyError as e:
+    except ClientError:
+        raise RuntimeError("Unable to retrieve API KEY, please ensure the secret ARN is correct")
+    except KeyError:
        raise RuntimeError('Please ensure the secret contains an "api_key" field')
elif api_key_env:
api_key = api_key_env
@@ -45,6 +41,4 @@ def api_key_auth(
credentials: Annotated[HTTPAuthorizationCredentials, Depends(security)],
):
if credentials.credentials != api_key:
-        raise HTTPException(
-            status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid API Key"
-        )
+        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid API Key")
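
For context, the api_key_auth dependency above guards every route. A minimal, self-contained sketch of how such a bearer-token check is typically wired into FastAPI; the app construction below is an assumption for illustration, not part of this commit:

from typing import Annotated

from fastapi import Depends, FastAPI, HTTPException, status
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer

security = HTTPBearer()
api_key = "my-secret-key"  # the gateway resolves this from SSM, Secrets Manager, or an env var, as above

def api_key_auth(credentials: Annotated[HTTPAuthorizationCredentials, Depends(security)]):
    if credentials.credentials != api_key:
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid API Key")

# Applying the dependency app-wide means every request must send "Authorization: Bearer <key>".
app = FastAPI(dependencies=[Depends(api_key_auth)])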

View File

@@ -43,9 +43,7 @@ class BaseChatModel(ABC):
return "chatcmpl-" + str(uuid.uuid4())[:8]
@staticmethod
-    def stream_response_to_bytes(
-            response: ChatStreamResponse | None = None
-    ) -> bytes:
+    def stream_response_to_bytes(response: ChatStreamResponse | None = None) -> bytes:
if response:
# to populate other fields when using exclude_unset=True
response.system_fingerprint = "fp"
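
The helper above frames one chunk for server-sent events. A hedged sketch of the framing it presumably produces (assumption: the real method serializes the pydantic model to JSON and the stream ends with the OpenAI-style [DONE] sentinel):

def to_sse_bytes(json_payload: str | None = None) -> bytes:
    # None mirrors the "if response:" branch above: no payload means end of stream
    if json_payload is None:
        return b"data: [DONE]\n\n"
    return f"data: {json_payload}\n\n".encode("utf-8")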

View File

@@ -36,7 +36,6 @@ from api.schema import (
EmbeddingsResponse,
EmbeddingsUsage,
Embedding,
)
from api.setting import DEBUG, AWS_REGION, ENABLE_CROSS_REGION_INFERENCE, DEFAULT_MODEL
@@ -50,15 +49,15 @@ bedrock_runtime = boto3.client(
config=config,
)
bedrock_client = boto3.client(
-    service_name='bedrock',
+    service_name="bedrock",
region_name=AWS_REGION,
config=config,
)
def get_inference_region_prefix():
-    if AWS_REGION.startswith('ap-'):
-        return 'apac'
+    if AWS_REGION.startswith("ap-"):
+        return "apac"
return AWS_REGION[:2]
@@ -88,49 +87,38 @@ def list_bedrock_models() -> dict:
profile_list = []
if ENABLE_CROSS_REGION_INFERENCE:
# List system defined inference profile IDs
-            response = bedrock_client.list_inference_profiles(
-                maxResults=1000,
-                typeEquals='SYSTEM_DEFINED'
-            )
-            profile_list = [p['inferenceProfileId'] for p in response['inferenceProfileSummaries']]
+            response = bedrock_client.list_inference_profiles(maxResults=1000, typeEquals="SYSTEM_DEFINED")
+            profile_list = [p["inferenceProfileId"] for p in response["inferenceProfileSummaries"]]
        # List foundation models; only text-output models are considered here.
-        response = bedrock_client.list_foundation_models(
-            byOutputModality='TEXT'
-        )
+        response = bedrock_client.list_foundation_models(byOutputModality="TEXT")
-        for model in response['modelSummaries']:
-            model_id = model.get('modelId', 'N/A')
-            stream_supported = model.get('responseStreamingSupported', True)
-            status = model['modelLifecycle'].get('status', 'ACTIVE')
+        for model in response["modelSummaries"]:
+            model_id = model.get("modelId", "N/A")
+            stream_supported = model.get("responseStreamingSupported", True)
+            status = model["modelLifecycle"].get("status", "ACTIVE")
# currently, use this to filter out rerank models and legacy models
if not stream_supported or status not in ["ACTIVE", "LEGACY"]:
continue
-            inference_types = model.get('inferenceTypesSupported', [])
-            input_modalities = model['inputModalities']
+            inference_types = model.get("inferenceTypesSupported", [])
+            input_modalities = model["inputModalities"]
# Add on-demand model list
-            if 'ON_DEMAND' in inference_types:
-                model_list[model_id] = {
-                    'modalities': input_modalities
-                }
+            if "ON_DEMAND" in inference_types:
+                model_list[model_id] = {"modalities": input_modalities}
# Add cross-region inference model list.
-            profile_id = cr_inference_prefix + '.' + model_id
+            profile_id = cr_inference_prefix + "." + model_id
if profile_id in profile_list:
-                model_list[profile_id] = {
-                    'modalities': input_modalities
-                }
+                model_list[profile_id] = {"modalities": input_modalities}
except Exception as e:
logger.error(f"Unable to list models: {str(e)}")
if not model_list:
# In case stack not updated.
-        model_list[DEFAULT_MODEL] = {
-            'modalities': ["TEXT", "IMAGE"]
-        }
+        model_list[DEFAULT_MODEL] = {"modalities": ["TEXT", "IMAGE"]}
return model_list
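
The cross-region profile IDs assembled above are simply the model ID with a region prefix. A standalone re-implementation of the prefix rule from get_inference_region_prefix(), with worked examples:

def region_prefix(region: str) -> str:
    if region.startswith("ap-"):
        return "apac"  # Asia-Pacific regions share the "apac" prefix
    return region[:2]  # e.g. "us-west-2" -> "us", "eu-central-1" -> "eu"

assert region_prefix("us-west-2") == "us"         # profile id: "us.<modelId>"
assert region_prefix("eu-central-1") == "eu"      # profile id: "eu.<modelId>"
assert region_prefix("ap-southeast-1") == "apac"  # profile id: "apac.<modelId>"
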
@@ -140,7 +128,6 @@ bedrock_model_list = list_bedrock_models()
class BedrockModel(BaseChatModel):
def list_models(self) -> list[str]:
"""Always refresh the latest model list"""
global bedrock_model_list
@@ -224,10 +211,7 @@ class BedrockModel(BaseChatModel):
logger.info("Proxy response :" + stream_response.model_dump_json())
if stream_response.choices:
yield self.stream_response_to_bytes(stream_response)
-            elif (
-                chat_request.stream_options
-                and chat_request.stream_options.include_usage
-            ):
+            elif chat_request.stream_options and chat_request.stream_options.include_usage:
                # An empty choices list is emitted for usage, per the OpenAI doc below:
# if you set stream_options: {"include_usage": true}.
# an additional chunk will be streamed before the data: [DONE] message.
@@ -277,9 +261,7 @@ class BedrockModel(BaseChatModel):
messages.append(
{
"role": message.role,
"content": self._parse_content_parts(
message, chat_request.model
),
"content": self._parse_content_parts(message, chat_request.model),
}
)
elif isinstance(message, AssistantMessage):
@@ -288,9 +270,7 @@ class BedrockModel(BaseChatModel):
messages.append(
{
"role": message.role,
"content": self._parse_content_parts(
message, chat_request.model
),
"content": self._parse_content_parts(message, chat_request.model),
}
)
if message.tool_calls:
@@ -305,7 +285,7 @@ class BedrockModel(BaseChatModel):
"toolUse": {
"toolUseId": tool_call.id,
"name": tool_call.function.name,
"input": tool_input
"input": tool_input,
}
}
],
@@ -335,7 +315,7 @@ class BedrockModel(BaseChatModel):
return self._reframe_multi_payloard(messages)
def _reframe_multi_payloard(self, messages: list) -> list:
""" Receive messages and reformat them to comply with the Claude format
"""Receive messages and reformat them to comply with the Claude format
        With OpenAI-format requests it is fine to receive consecutive messages from the same role, but
        Claude-format requests reject consecutive messages from the same role.
@@ -364,16 +344,13 @@ class BedrockModel(BaseChatModel):
# Search through the list of messages and combine messages from the same role into one list
for message in messages:
-            next_role = message['role']
-            next_content = message['content']
+            next_role = message["role"]
+            next_content = message["content"]
# If the next role is different from the previous message, add the previous role's messages to the list
if next_role != current_role:
if current_content:
-                    reformatted_messages.append({
-                        "role": current_role,
-                        "content": current_content
-                    })
+                    reformatted_messages.append({"role": current_role, "content": current_content})
# Switch to the new role
current_role = next_role
current_content = []
@@ -386,10 +363,7 @@ class BedrockModel(BaseChatModel):
# Add the last role's messages to the list
if current_content:
-            reformatted_messages.append({
-                "role": current_role,
-                "content": current_content
-            })
+            reformatted_messages.append({"role": current_role, "content": current_content})
return reformatted_messages
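
To see what _reframe_multi_payloard does, here is a self-contained sketch mirroring the merging loop above: consecutive messages from the same role collapse into a single message whose content is a list.

def merge_same_role(messages: list[dict]) -> list[dict]:
    merged: list[dict] = []
    current_role, current_content = None, []
    for message in messages:
        # A role change flushes the accumulated content for the previous role
        if message["role"] != current_role:
            if current_content:
                merged.append({"role": current_role, "content": current_content})
            current_role, current_content = message["role"], []
        content = message["content"]
        current_content.extend(content if isinstance(content, list) else [content])
    if current_content:
        merged.append({"role": current_role, "content": current_content})
    return merged

# [{'role': 'user', 'content': ['a', 'b']}, {'role': 'assistant', 'content': ['c']}]
print(merge_same_role([
    {"role": "user", "content": "a"},
    {"role": "user", "content": "b"},
    {"role": "assistant", "content": "c"},
]))
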
@@ -426,25 +400,20 @@ class BedrockModel(BaseChatModel):
            # Per the OpenAI API, max_tokens is not supported in reasoning mode.
            # Use max_completion_tokens if provided.
-            max_tokens = chat_request.max_completion_tokens if chat_request.max_completion_tokens else chat_request.max_tokens
+            max_tokens = (
+                chat_request.max_completion_tokens if chat_request.max_completion_tokens else chat_request.max_tokens
+            )
budget_tokens = self._calc_budget_tokens(max_tokens, chat_request.reasoning_effort)
inference_config["maxTokens"] = max_tokens
# unset topP - Not supported
inference_config.pop("topP")
args["additionalModelRequestFields"] = {
"reasoning_config": {
"type": "enabled",
"budget_tokens": budget_tokens
}
"reasoning_config": {"type": "enabled", "budget_tokens": budget_tokens}
}
# add tool config
if chat_request.tools:
args["toolConfig"] = {
"tools": [
self._convert_tool_spec(t.function) for t in chat_request.tools
]
}
args["toolConfig"] = {"tools": [self._convert_tool_spec(t.function) for t in chat_request.tools]}
if chat_request.tool_choice and not chat_request.model.startswith("meta.llama3-1-"):
if isinstance(chat_request.tool_choice, str):
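
For reasoning-capable models, the branch above drops topP and moves the thinking budget into additionalModelRequestFields. An illustrative shape of the resulting Converse API arguments; the model ID and numbers are placeholders, not values from this commit:

args = {
    "modelId": "anthropic.claude-3-sonnet-20240229-v1:0",  # placeholder
    "inferenceConfig": {"maxTokens": 4096},  # "topP" popped above: not supported in reasoning mode
    "additionalModelRequestFields": {
        "reasoning_config": {"type": "enabled", "budget_tokens": 1024}
    },
}
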
@@ -458,19 +427,19 @@ class BedrockModel(BaseChatModel):
# Specific tool to use
assert "function" in chat_request.tool_choice
args["toolConfig"]["toolChoice"] = {
"tool": {"name": chat_request.tool_choice["function"].get("name", "")}}
"tool": {"name": chat_request.tool_choice["function"].get("name", "")}
}
return args
def _create_response(
-            self,
-            model: str,
-            message_id: str,
-            content: list[dict] = None,
-            finish_reason: str | None = None,
-            input_tokens: int = 0,
-            output_tokens: int = 0,
+        self,
+        model: str,
+        message_id: str,
+        content: list[dict] = None,
+        finish_reason: str | None = None,
+        input_tokens: int = 0,
+        output_tokens: int = 0,
) -> ChatResponse:
message = ChatResponseMessage(
role="assistant",
)
@@ -524,9 +493,7 @@ class BedrockModel(BaseChatModel):
response.created = int(time.time())
return response
-    def _create_response_stream(
-        self, model_id: str, message_id: str, chunk: dict
-    ) -> ChatStreamResponse | None:
+    def _create_response_stream(self, model_id: str, message_id: str, chunk: dict) -> ChatStreamResponse | None:
"""Parsing the Bedrock stream response chunk.
Ref: https://docs.aws.amazon.com/bedrock/latest/userguide/conversation-inference.html#message-inference-examples
@@ -583,7 +550,7 @@ class BedrockModel(BaseChatModel):
index=index,
function=ResponseFunction(
arguments=delta["toolUse"]["input"],
-                            )
+                            ),
)
]
)
@@ -641,7 +608,6 @@ class BedrockModel(BaseChatModel):
response = requests.get(image_url)
# Check if the request was successful
if response.status_code == 200:
content_type = response.headers.get("Content-Type")
if not content_type.startswith("image"):
content_type = "image/jpeg"
@@ -649,14 +615,12 @@ class BedrockModel(BaseChatModel):
image_content = response.content
return image_content, content_type
else:
-            raise HTTPException(
-                status_code=500, detail="Unable to access the image url"
-            )
+            raise HTTPException(status_code=500, detail="Unable to access the image url")
def _parse_content_parts(
-            self,
-            message: UserMessage,
-            model_id: str,
+        self,
+        message: UserMessage,
+        model_id: str,
) -> list[dict]:
if isinstance(message.content, str):
return [
@@ -695,7 +659,7 @@ class BedrockModel(BaseChatModel):
@staticmethod
def is_supported_modality(model_id: str, modality: str = "IMAGE") -> bool:
model = bedrock_model_list.get(model_id)
-        modalities = model.get('modalities', [])
+        modalities = model.get("modalities", [])
if modality in modalities:
return True
return False
@@ -740,7 +704,7 @@ class BedrockModel(BaseChatModel):
"max_tokens": "length",
"stop_sequence": "stop",
"complete": "stop",
"content_filtered": "content_filter"
"content_filtered": "content_filter",
}
return finish_reason_mapping.get(finish_reason.lower(), finish_reason.lower())
return None
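
A quick check of the normalization above: Bedrock stop reasons map onto OpenAI finish_reason values, and anything unmapped passes through lowercased.

finish_reason_mapping = {
    "max_tokens": "length",
    "stop_sequence": "stop",
    "complete": "stop",
    "content_filtered": "content_filter",
}
for reason in ("max_tokens", "end_turn"):
    print(finish_reason_mapping.get(reason.lower(), reason.lower()))  # "length", then "end_turn"
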
@@ -773,12 +737,12 @@ class BedrockEmbeddingsModel(BaseEmbeddingsModel, ABC):
raise HTTPException(status_code=500, detail=str(e))
def _create_response(
-            self,
-            embeddings: list[float],
-            model: str,
-            input_tokens: int = 0,
-            output_tokens: int = 0,
-            encoding_format: Literal["float", "base64"] = "float",
+        self,
+        embeddings: list[float],
+        model: str,
+        input_tokens: int = 0,
+        output_tokens: int = 0,
+        encoding_format: Literal["float", "base64"] = "float",
) -> EmbeddingsResponse:
data = []
for i, embedding in enumerate(embeddings):
@@ -803,7 +767,6 @@ class BedrockEmbeddingsModel(BaseEmbeddingsModel, ABC):
class CohereEmbeddingsModel(BedrockEmbeddingsModel):
def _parse_args(self, embeddings_request: EmbeddingsRequest) -> dict:
texts = []
if isinstance(embeddings_request.input, str):
@@ -834,9 +797,7 @@ class CohereEmbeddingsModel(BedrockEmbeddingsModel):
return args
def embed(self, embeddings_request: EmbeddingsRequest) -> EmbeddingsResponse:
-        response = self._invoke_model(
-            args=self._parse_args(embeddings_request), model_id=embeddings_request.model
-        )
+        response = self._invoke_model(args=self._parse_args(embeddings_request), model_id=embeddings_request.model)
response_body = json.loads(response.get("body").read())
if DEBUG:
logger.info("Bedrock response body: " + str(response_body))
@@ -849,19 +810,13 @@ class CohereEmbeddingsModel(BedrockEmbeddingsModel):
class TitanEmbeddingsModel(BedrockEmbeddingsModel):
def _parse_args(self, embeddings_request: EmbeddingsRequest) -> dict:
if isinstance(embeddings_request.input, str):
input_text = embeddings_request.input
-        elif (
-            isinstance(embeddings_request.input, list)
-            and len(embeddings_request.input) == 1
-        ):
+        elif isinstance(embeddings_request.input, list) and len(embeddings_request.input) == 1:
input_text = embeddings_request.input[0]
else:
-            raise ValueError(
-                "Amazon Titan Embeddings models support only single strings as input."
-            )
+            raise ValueError("Amazon Titan Embeddings models support only single strings as input.")
args = {
"inputText": input_text,
# Note: inputImage is not supported!
@@ -875,9 +830,7 @@ class TitanEmbeddingsModel(BedrockEmbeddingsModel):
return args
def embed(self, embeddings_request: EmbeddingsRequest) -> EmbeddingsResponse:
-        response = self._invoke_model(
-            args=self._parse_args(embeddings_request), model_id=embeddings_request.model
-        )
+        response = self._invoke_model(args=self._parse_args(embeddings_request), model_id=embeddings_request.model)
response_body = json.loads(response.get("body").read())
if DEBUG:
logger.info("Bedrock response body: " + str(response_body))

View File

@@ -17,20 +17,20 @@ router = APIRouter(
@router.post("/completions", response_model=ChatResponse | ChatStreamResponse, response_model_exclude_unset=True)
async def chat_completions(
-        chat_request: Annotated[
-            ChatRequest,
-            Body(
-                examples=[
-                    {
-                        "model": "anthropic.claude-3-sonnet-20240229-v1:0",
-                        "messages": [
-                            {"role": "system", "content": "You are a helpful assistant."},
-                            {"role": "user", "content": "Hello!"},
-                        ],
-                    }
-                ],
-            ),
-        ]
+    chat_request: Annotated[
+        ChatRequest,
+        Body(
+            examples=[
+                {
+                    "model": "anthropic.claude-3-sonnet-20240229-v1:0",
+                    "messages": [
+                        {"role": "system", "content": "You are a helpful assistant."},
+                        {"role": "user", "content": "Hello!"},
+                    ],
+                }
+            ],
+        ),
+    ],
):
if chat_request.model.lower().startswith("gpt-"):
chat_request.model = DEFAULT_MODEL
@@ -39,7 +39,5 @@ async def chat_completions(
model = BedrockModel()
model.validate(chat_request)
if chat_request.stream:
-        return StreamingResponse(
-            content=model.chat_stream(chat_request), media_type="text/event-stream"
-        )
+        return StreamingResponse(content=model.chat_stream(chat_request), media_type="text/event-stream")
return model.chat(chat_request)
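
A hedged client-side sketch of calling this route. Only the /completions path and the request body shape come from this commit; the host, port, and "/api/v1" prefix are assumptions about a typical deployment:

import requests

resp = requests.post(
    "http://localhost:8000/api/v1/chat/completions",  # assumed base URL
    headers={"Authorization": "Bearer my-secret-key"},  # checked by api_key_auth
    json={
        "model": "anthropic.claude-3-sonnet-20240229-v1:0",
        "messages": [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "Hello!"},
        ],
    },
    timeout=60,
)
print(resp.json()["choices"][0]["message"]["content"])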

View File

@@ -15,19 +15,17 @@ router = APIRouter(
@router.post("", response_model=EmbeddingsResponse)
async def embeddings(
-        embeddings_request: Annotated[
-            EmbeddingsRequest,
-            Body(
-                examples=[
-                    {
-                        "model": "cohere.embed-multilingual-v3",
-                        "input": [
-                            "Your text string goes here"
-                        ],
-                    }
-                ],
-            ),
-        ]
+    embeddings_request: Annotated[
+        EmbeddingsRequest,
+        Body(
+            examples=[
+                {
+                    "model": "cohere.embed-multilingual-v3",
+                    "input": ["Your text string goes here"],
+                }
+            ],
+        ),
+    ],
):
if embeddings_request.model.lower().startswith("text-embedding-"):
embeddings_request.model = DEFAULT_EMBEDDING_MODEL
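
The same style of call works for this route (host and "/api/v1" prefix again assumed):

import requests

resp = requests.post(
    "http://localhost:8000/api/v1/embeddings",  # assumed base URL
    headers={"Authorization": "Bearer my-secret-key"},
    json={"model": "cohere.embed-multilingual-v3", "input": ["Your text string goes here"]},
    timeout=60,
)
print(len(resp.json()["data"][0]["embedding"]))  # vector dimension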

View File

@@ -22,9 +22,7 @@ async def validate_model_id(model_id: str):
@router.get("", response_model=Models)
async def list_models():
-    model_list = [
-        Model(id=model_id) for model_id in chat_model.list_models()
-    ]
+    model_list = [Model(id=model_id) for model_id in chat_model.list_models()]
return Models(data=model_list)
@@ -33,10 +31,10 @@ async def list_models():
response_model=Model,
)
async def get_model(
-        model_id: Annotated[
-            str,
-            Path(description="Model ID", example="anthropic.claude-3-sonnet-20240229-v1:0"),
-        ]
+    model_id: Annotated[
+        str,
+        Path(description="Model ID", example="anthropic.claude-3-sonnet-20240229-v1:0"),
+    ],
):
await validate_model_id(model_id)
return Model(id=model_id)
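
Both routes above can be exercised the same way (URL assumptions as before):

import requests

base = "http://localhost:8000/api/v1"  # assumed base URL
headers = {"Authorization": "Bearer my-secret-key"}
print(requests.get(f"{base}/models", headers=headers, timeout=30).json()["data"][:3])
print(requests.get(f"{base}/models/anthropic.claude-3-sonnet-20240229-v1:0", headers=headers, timeout=30).json())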

View File

@@ -13,10 +13,6 @@ Use OpenAI-Compatible RESTful APIs for Amazon Bedrock models.
DEBUG = os.environ.get("DEBUG", "false").lower() != "false"
AWS_REGION = os.environ.get("AWS_REGION", "us-west-2")
-DEFAULT_MODEL = os.environ.get(
-    "DEFAULT_MODEL", "anthropic.claude-3-sonnet-20240229-v1:0"
-)
-DEFAULT_EMBEDDING_MODEL = os.environ.get(
-    "DEFAULT_EMBEDDING_MODEL", "cohere.embed-multilingual-v3"
-)
+DEFAULT_MODEL = os.environ.get("DEFAULT_MODEL", "anthropic.claude-3-sonnet-20240229-v1:0")
+DEFAULT_EMBEDDING_MODEL = os.environ.get("DEFAULT_EMBEDDING_MODEL", "cohere.embed-multilingual-v3")
ENABLE_CROSS_REGION_INFERENCE = os.environ.get("ENABLE_CROSS_REGION_INFERENCE", "true").lower() != "false"
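
Note the flag-parsing rule used for DEBUG and ENABLE_CROSS_REGION_INFERENCE: a flag is enabled whenever the variable holds anything other than the literal string "false", case-insensitively. A quick demonstration:

import os

os.environ["DEBUG"] = "0"
print(os.environ.get("DEBUG", "false").lower() != "false")  # True: even "0" enables DEBUG
os.environ["DEBUG"] = "False"
print(os.environ.get("DEBUG", "false").lower() != "false")  # False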