apply ruff linter
@@ -16,9 +16,7 @@ if api_key_param:
     # For backward compatibility.
     # Please now use secrets manager instead.
     ssm = boto3.client("ssm")
-    api_key = ssm.get_parameter(Name=api_key_param, WithDecryption=True)["Parameter"][
-        "Value"
-    ]
+    api_key = ssm.get_parameter(Name=api_key_param, WithDecryption=True)["Parameter"]["Value"]
 elif api_key_secret_arn:
     sm = boto3.client("secretsmanager")
     try:
@@ -26,11 +24,9 @@ elif api_key_secret_arn:
         if "SecretString" in response:
             secret = json.loads(response["SecretString"])
             api_key = secret["api_key"]
-    except ClientError as e:
-        raise RuntimeError(
-            "Unable to retrieve API KEY, please ensure the secret ARN is correct"
-        )
-    except KeyError as e:
+    except ClientError:
+        raise RuntimeError("Unable to retrieve API KEY, please ensure the secret ARN is correct")
+    except KeyError:
         raise RuntimeError('Please ensure the secret contains a "api_key" field')
 elif api_key_env:
     api_key = api_key_env
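
Aside — a sketch, not part of this commit: the hunk above starts inside the try block, so the call that produces response sits outside the diff context. It presumably looks something like the following, assuming the standard boto3 Secrets Manager call and a secret stored as a JSON object with an "api_key" field:

import json

import boto3
from botocore.exceptions import ClientError

api_key_secret_arn = "arn:aws:secretsmanager:..."  # placeholder, configured elsewhere
sm = boto3.client("secretsmanager")
try:
    # get_secret_value is the standard Secrets Manager read call.
    response = sm.get_secret_value(SecretId=api_key_secret_arn)
    if "SecretString" in response:
        secret = json.loads(response["SecretString"])  # e.g. {"api_key": "..."}
        api_key = secret["api_key"]
except ClientError:
    raise RuntimeError("Unable to retrieve API KEY, please ensure the secret ARN is correct")
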
@@ -45,6 +41,4 @@ def api_key_auth(
     credentials: Annotated[HTTPAuthorizationCredentials, Depends(security)],
 ):
     if credentials.credentials != api_key:
-        raise HTTPException(
-            status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid API Key"
-        )
+        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid API Key")

@@ -43,9 +43,7 @@ class BaseChatModel(ABC):
         return "chatcmpl-" + str(uuid.uuid4())[:8]

     @staticmethod
-    def stream_response_to_bytes(
-            response: ChatStreamResponse | None = None
-    ) -> bytes:
+    def stream_response_to_bytes(response: ChatStreamResponse | None = None) -> bytes:
         if response:
             # to populate other fields when using exclude_unset=True
             response.system_fingerprint = "fp"

@@ -36,7 +36,6 @@ from api.schema import (
     EmbeddingsResponse,
     EmbeddingsUsage,
     Embedding,
-
 )
 from api.setting import DEBUG, AWS_REGION, ENABLE_CROSS_REGION_INFERENCE, DEFAULT_MODEL

@@ -50,15 +49,15 @@ bedrock_runtime = boto3.client(
     config=config,
 )
 bedrock_client = boto3.client(
-    service_name='bedrock',
+    service_name="bedrock",
     region_name=AWS_REGION,
     config=config,
 )


 def get_inference_region_prefix():
-    if AWS_REGION.startswith('ap-'):
-        return 'apac'
+    if AWS_REGION.startswith("ap-"):
+        return "apac"
     return AWS_REGION[:2]

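Aside — a quick illustration of what the prefix helper yields (example regions assumed; cross-region inference profile IDs take the form "<prefix>.<model_id>"):

def region_prefix(region: str) -> str:
    # Same logic as get_inference_region_prefix above, parameterized for illustration.
    if region.startswith("ap-"):
        return "apac"
    return region[:2]

assert region_prefix("us-east-1") == "us"
assert region_prefix("eu-central-1") == "eu"
assert region_prefix("ap-northeast-1") == "apac"
# e.g. "us" + "." + "anthropic.claude-3-sonnet-20240229-v1:0"
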
@@ -88,49 +87,38 @@ def list_bedrock_models() -> dict:
         profile_list = []
         if ENABLE_CROSS_REGION_INFERENCE:
             # List system defined inference profile IDs
-            response = bedrock_client.list_inference_profiles(
-                maxResults=1000,
-                typeEquals='SYSTEM_DEFINED'
-            )
-            profile_list = [p['inferenceProfileId'] for p in response['inferenceProfileSummaries']]
+            response = bedrock_client.list_inference_profiles(maxResults=1000, typeEquals="SYSTEM_DEFINED")
+            profile_list = [p["inferenceProfileId"] for p in response["inferenceProfileSummaries"]]

         # List foundation models, only cares about text outputs here.
-        response = bedrock_client.list_foundation_models(
-            byOutputModality='TEXT'
-        )
+        response = bedrock_client.list_foundation_models(byOutputModality="TEXT")

-        for model in response['modelSummaries']:
-            model_id = model.get('modelId', 'N/A')
-            stream_supported = model.get('responseStreamingSupported', True)
-            status = model['modelLifecycle'].get('status', 'ACTIVE')
+        for model in response["modelSummaries"]:
+            model_id = model.get("modelId", "N/A")
+            stream_supported = model.get("responseStreamingSupported", True)
+            status = model["modelLifecycle"].get("status", "ACTIVE")

             # currently, use this to filter out rerank models and legacy models
             if not stream_supported or status not in ["ACTIVE", "LEGACY"]:
                 continue

-            inference_types = model.get('inferenceTypesSupported', [])
-            input_modalities = model['inputModalities']
+            inference_types = model.get("inferenceTypesSupported", [])
+            input_modalities = model["inputModalities"]
             # Add on-demand model list
-            if 'ON_DEMAND' in inference_types:
-                model_list[model_id] = {
-                    'modalities': input_modalities
-                }
+            if "ON_DEMAND" in inference_types:
+                model_list[model_id] = {"modalities": input_modalities}

             # Add cross-region inference model list.
-            profile_id = cr_inference_prefix + '.' + model_id
+            profile_id = cr_inference_prefix + "." + model_id
             if profile_id in profile_list:
-                model_list[profile_id] = {
-                    'modalities': input_modalities
-                }
+                model_list[profile_id] = {"modalities": input_modalities}

     except Exception as e:
         logger.error(f"Unable to list models: {str(e)}")

     if not model_list:
         # In case stack not updated.
-        model_list[DEFAULT_MODEL] = {
-            'modalities': ["TEXT", "IMAGE"]
-        }
+        model_list[DEFAULT_MODEL] = {"modalities": ["TEXT", "IMAGE"]}

     return model_list

@@ -140,7 +128,6 @@ bedrock_model_list = list_bedrock_models()


 class BedrockModel(BaseChatModel):
-
     def list_models(self) -> list[str]:
         """Always refresh the latest model list"""
         global bedrock_model_list
@@ -224,10 +211,7 @@ class BedrockModel(BaseChatModel):
                 logger.info("Proxy response :" + stream_response.model_dump_json())
             if stream_response.choices:
                 yield self.stream_response_to_bytes(stream_response)
-            elif (
-                chat_request.stream_options
-                and chat_request.stream_options.include_usage
-            ):
+            elif chat_request.stream_options and chat_request.stream_options.include_usage:
                 # An empty choices for Usage as per OpenAI doc below:
                 # if you set stream_options: {"include_usage": true}.
                 # an additional chunk will be streamed before the data: [DONE] message.
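
Aside — per the OpenAI streaming behavior the comments cite, the extra chunk carries an empty choices list plus a usage block. A sketch of its shape (field values invented):

usage_chunk = {
    "id": "chatcmpl-xxxxxxxx",
    "object": "chat.completion.chunk",
    "model": "anthropic.claude-3-sonnet-20240229-v1:0",
    "choices": [],  # empty on the final usage-only chunk
    "usage": {"prompt_tokens": 12, "completion_tokens": 34, "total_tokens": 46},
}
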
@@ -277,9 +261,7 @@ class BedrockModel(BaseChatModel):
                 messages.append(
                     {
                         "role": message.role,
-                        "content": self._parse_content_parts(
-                            message, chat_request.model
-                        ),
+                        "content": self._parse_content_parts(message, chat_request.model),
                     }
                 )
             elif isinstance(message, AssistantMessage):
@@ -288,9 +270,7 @@ class BedrockModel(BaseChatModel):
                     messages.append(
                         {
                             "role": message.role,
-                            "content": self._parse_content_parts(
-                                message, chat_request.model
-                            ),
+                            "content": self._parse_content_parts(message, chat_request.model),
                         }
                     )
                 if message.tool_calls:
@@ -305,7 +285,7 @@ class BedrockModel(BaseChatModel):
                                 "toolUse": {
                                     "toolUseId": tool_call.id,
                                     "name": tool_call.function.name,
-                                    "input": tool_input
+                                    "input": tool_input,
                                 }
                             }
                         ],
@@ -335,7 +315,7 @@ class BedrockModel(BaseChatModel):
         return self._reframe_multi_payloard(messages)

     def _reframe_multi_payloard(self, messages: list) -> list:
-        """ Receive messages and reformat them to comply with the Claude format
+        """Receive messages and reformat them to comply with the Claude format

         With OpenAI format requests, it's not a problem to repeatedly receive messages from the same role, but
         with Claude format requests, you cannot repeatedly receive messages from the same role.
@@ -364,16 +344,13 @@ class BedrockModel(BaseChatModel):

         # Search through the list of messages and combine messages from the same role into one list
         for message in messages:
-            next_role = message['role']
-            next_content = message['content']
+            next_role = message["role"]
+            next_content = message["content"]

             # If the next role is different from the previous message, add the previous role's messages to the list
             if next_role != current_role:
                 if current_content:
-                    reformatted_messages.append({
-                        "role": current_role,
-                        "content": current_content
-                    })
+                    reformatted_messages.append({"role": current_role, "content": current_content})
                 # Switch to the new role
                 current_role = next_role
                 current_content = []
@@ -386,10 +363,7 @@ class BedrockModel(BaseChatModel):

         # Add the last role's messages to the list
         if current_content:
-            reformatted_messages.append({
-                "role": current_role,
-                "content": current_content
-            })
+            reformatted_messages.append({"role": current_role, "content": current_content})

         return reformatted_messages

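Aside — a sketch of the reframing this method performs (message values invented; the exact content handling follows the method body, parts of which fall outside these hunks):

messages = [
    {"role": "user", "content": [{"text": "Hello"}]},
    {"role": "user", "content": [{"text": "Are you there?"}]},
    {"role": "assistant", "content": [{"text": "Yes."}]},
]
# After reframing, consecutive same-role messages merge into one:
# [
#     {"role": "user", "content": [{"text": "Hello"}, {"text": "Are you there?"}]},
#     {"role": "assistant", "content": [{"text": "Yes."}]},
# ]
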
@@ -426,25 +400,20 @@ class BedrockModel(BaseChatModel):
             # From OpenAI api, the max_token is not supported in reasoning mode
             # Use max_completion_tokens if provided.

-            max_tokens = chat_request.max_completion_tokens if chat_request.max_completion_tokens else chat_request.max_tokens
+            max_tokens = (
+                chat_request.max_completion_tokens if chat_request.max_completion_tokens else chat_request.max_tokens
+            )
             budget_tokens = self._calc_budget_tokens(max_tokens, chat_request.reasoning_effort)
             inference_config["maxTokens"] = max_tokens
             # unset topP - Not supported
             inference_config.pop("topP")

             args["additionalModelRequestFields"] = {
-                "reasoning_config": {
-                    "type": "enabled",
-                    "budget_tokens": budget_tokens
-                }
+                "reasoning_config": {"type": "enabled", "budget_tokens": budget_tokens}
             }
         # add tool config
         if chat_request.tools:
-            args["toolConfig"] = {
-                "tools": [
-                    self._convert_tool_spec(t.function) for t in chat_request.tools
-                ]
-            }
+            args["toolConfig"] = {"tools": [self._convert_tool_spec(t.function) for t in chat_request.tools]}

         if chat_request.tool_choice and not chat_request.model.startswith("meta.llama3-1-"):
             if isinstance(chat_request.tool_choice, str):
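
Aside — a sketch of the request fields assembled above when reasoning is enabled (values invented; the keys follow the hunk):

args = {
    "inferenceConfig": {"maxTokens": 4096},  # "topP" popped, as above
    "additionalModelRequestFields": {
        "reasoning_config": {"type": "enabled", "budget_tokens": 1024},
    },
}
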
@@ -458,19 +427,19 @@ class BedrockModel(BaseChatModel):
                 # Specific tool to use
                 assert "function" in chat_request.tool_choice
                 args["toolConfig"]["toolChoice"] = {
-                    "tool": {"name": chat_request.tool_choice["function"].get("name", "")}}
+                    "tool": {"name": chat_request.tool_choice["function"].get("name", "")}
+                }
         return args

     def _create_response(
-            self,
-            model: str,
-            message_id: str,
-            content: list[dict] = None,
-            finish_reason: str | None = None,
-            input_tokens: int = 0,
-            output_tokens: int = 0,
+        self,
+        model: str,
+        message_id: str,
+        content: list[dict] = None,
+        finish_reason: str | None = None,
+        input_tokens: int = 0,
+        output_tokens: int = 0,
     ) -> ChatResponse:
-
         message = ChatResponseMessage(
             role="assistant",
         )
@@ -524,9 +493,7 @@ class BedrockModel(BaseChatModel):
         response.created = int(time.time())
         return response

-    def _create_response_stream(
-            self, model_id: str, message_id: str, chunk: dict
-    ) -> ChatStreamResponse | None:
+    def _create_response_stream(self, model_id: str, message_id: str, chunk: dict) -> ChatStreamResponse | None:
         """Parsing the Bedrock stream response chunk.

         Ref: https://docs.aws.amazon.com/bedrock/latest/userguide/conversation-inference.html#message-inference-examples
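
Aside — the Converse stream delivers tagged events; a sketch of typical chunks this parser receives (shapes per the AWS docs linked above, values invented):

text_chunk = {"contentBlockDelta": {"delta": {"text": "Hello"}, "contentBlockIndex": 0}}
stop_chunk = {"messageStop": {"stopReason": "end_turn"}}
usage_chunk = {"metadata": {"usage": {"inputTokens": 12, "outputTokens": 34, "totalTokens": 46}}}
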
@@ -583,7 +550,7 @@ class BedrockModel(BaseChatModel):
                                 index=index,
                                 function=ResponseFunction(
                                     arguments=delta["toolUse"]["input"],
-                                )
+                                ),
                             )
                         ]
                     )
@@ -641,7 +608,6 @@ class BedrockModel(BaseChatModel):
         response = requests.get(image_url)
         # Check if the request was successful
         if response.status_code == 200:
-
             content_type = response.headers.get("Content-Type")
             if not content_type.startswith("image"):
                 content_type = "image/jpeg"
@@ -649,14 +615,12 @@ class BedrockModel(BaseChatModel):
             image_content = response.content
             return image_content, content_type
         else:
-            raise HTTPException(
-                status_code=500, detail="Unable to access the image url"
-            )
+            raise HTTPException(status_code=500, detail="Unable to access the image url")

     def _parse_content_parts(
-            self,
-            message: UserMessage,
-            model_id: str,
+        self,
+        message: UserMessage,
+        model_id: str,
     ) -> list[dict]:
         if isinstance(message.content, str):
             return [
@@ -695,7 +659,7 @@ class BedrockModel(BaseChatModel):
     @staticmethod
     def is_supported_modality(model_id: str, modality: str = "IMAGE") -> bool:
         model = bedrock_model_list.get(model_id)
-        modalities = model.get('modalities', [])
+        modalities = model.get("modalities", [])
         if modality in modalities:
             return True
         return False
@@ -740,7 +704,7 @@ class BedrockModel(BaseChatModel):
             "max_tokens": "length",
             "stop_sequence": "stop",
             "complete": "stop",
-            "content_filtered": "content_filter"
+            "content_filtered": "content_filter",
         }
         return finish_reason_mapping.get(finish_reason.lower(), finish_reason.lower())
     return None
@@ -773,12 +737,12 @@ class BedrockEmbeddingsModel(BaseEmbeddingsModel, ABC):
             raise HTTPException(status_code=500, detail=str(e))

     def _create_response(
-            self,
-            embeddings: list[float],
-            model: str,
-            input_tokens: int = 0,
-            output_tokens: int = 0,
-            encoding_format: Literal["float", "base64"] = "float",
+        self,
+        embeddings: list[float],
+        model: str,
+        input_tokens: int = 0,
+        output_tokens: int = 0,
+        encoding_format: Literal["float", "base64"] = "float",
     ) -> EmbeddingsResponse:
         data = []
         for i, embedding in enumerate(embeddings):
@@ -803,7 +767,6 @@ class BedrockEmbeddingsModel(BaseEmbeddingsModel, ABC):


 class CohereEmbeddingsModel(BedrockEmbeddingsModel):
-
     def _parse_args(self, embeddings_request: EmbeddingsRequest) -> dict:
         texts = []
         if isinstance(embeddings_request.input, str):
@@ -834,9 +797,7 @@ class CohereEmbeddingsModel(BedrockEmbeddingsModel):
         return args

     def embed(self, embeddings_request: EmbeddingsRequest) -> EmbeddingsResponse:
-        response = self._invoke_model(
-            args=self._parse_args(embeddings_request), model_id=embeddings_request.model
-        )
+        response = self._invoke_model(args=self._parse_args(embeddings_request), model_id=embeddings_request.model)
         response_body = json.loads(response.get("body").read())
         if DEBUG:
             logger.info("Bedrock response body: " + str(response_body))
@@ -849,19 +810,13 @@ class CohereEmbeddingsModel(BedrockEmbeddingsModel):


 class TitanEmbeddingsModel(BedrockEmbeddingsModel):
-
     def _parse_args(self, embeddings_request: EmbeddingsRequest) -> dict:
         if isinstance(embeddings_request.input, str):
             input_text = embeddings_request.input
-        elif (
-            isinstance(embeddings_request.input, list)
-            and len(embeddings_request.input) == 1
-        ):
+        elif isinstance(embeddings_request.input, list) and len(embeddings_request.input) == 1:
             input_text = embeddings_request.input[0]
         else:
-            raise ValueError(
-                "Amazon Titan Embeddings models support only single strings as input."
-            )
+            raise ValueError("Amazon Titan Embeddings models support only single strings as input.")
         args = {
             "inputText": input_text,
             # Note: inputImage is not supported!
@@ -875,9 +830,7 @@ class TitanEmbeddingsModel(BedrockEmbeddingsModel):
         return args

     def embed(self, embeddings_request: EmbeddingsRequest) -> EmbeddingsResponse:
-        response = self._invoke_model(
-            args=self._parse_args(embeddings_request), model_id=embeddings_request.model
-        )
+        response = self._invoke_model(args=self._parse_args(embeddings_request), model_id=embeddings_request.model)
         response_body = json.loads(response.get("body").read())
         if DEBUG:
             logger.info("Bedrock response body: " + str(response_body))

@@ -17,20 +17,20 @@ router = APIRouter(

 @router.post("/completions", response_model=ChatResponse | ChatStreamResponse, response_model_exclude_unset=True)
 async def chat_completions(
-        chat_request: Annotated[
-            ChatRequest,
-            Body(
-                examples=[
-                    {
-                        "model": "anthropic.claude-3-sonnet-20240229-v1:0",
-                        "messages": [
-                            {"role": "system", "content": "You are a helpful assistant."},
-                            {"role": "user", "content": "Hello!"},
-                        ],
-                    }
-                ],
-            ),
-        ]
+    chat_request: Annotated[
+        ChatRequest,
+        Body(
+            examples=[
+                {
+                    "model": "anthropic.claude-3-sonnet-20240229-v1:0",
+                    "messages": [
+                        {"role": "system", "content": "You are a helpful assistant."},
+                        {"role": "user", "content": "Hello!"},
+                    ],
+                }
+            ],
+        ),
+    ],
 ):
     if chat_request.model.lower().startswith("gpt-"):
         chat_request.model = DEFAULT_MODEL
@@ -39,7 +39,5 @@ async def chat_completions(
     model = BedrockModel()
     model.validate(chat_request)
     if chat_request.stream:
-        return StreamingResponse(
-            content=model.chat_stream(chat_request), media_type="text/event-stream"
-        )
+        return StreamingResponse(content=model.chat_stream(chat_request), media_type="text/event-stream")
     return model.chat(chat_request)

@@ -15,19 +15,17 @@ router = APIRouter(

 @router.post("", response_model=EmbeddingsResponse)
 async def embeddings(
-        embeddings_request: Annotated[
-            EmbeddingsRequest,
-            Body(
-                examples=[
-                    {
-                        "model": "cohere.embed-multilingual-v3",
-                        "input": [
-                            "Your text string goes here"
-                        ],
-                    }
-                ],
-            ),
-        ]
+    embeddings_request: Annotated[
+        EmbeddingsRequest,
+        Body(
+            examples=[
+                {
+                    "model": "cohere.embed-multilingual-v3",
+                    "input": ["Your text string goes here"],
+                }
+            ],
+        ),
+    ],
 ):
     if embeddings_request.model.lower().startswith("text-embedding-"):
         embeddings_request.model = DEFAULT_EMBEDDING_MODEL

@@ -22,9 +22,7 @@ async def validate_model_id(model_id: str):

 @router.get("", response_model=Models)
 async def list_models():
-    model_list = [
-        Model(id=model_id) for model_id in chat_model.list_models()
-    ]
+    model_list = [Model(id=model_id) for model_id in chat_model.list_models()]
     return Models(data=model_list)


@@ -33,10 +31,10 @@ async def list_models():
     response_model=Model,
 )
 async def get_model(
-        model_id: Annotated[
-            str,
-            Path(description="Model ID", example="anthropic.claude-3-sonnet-20240229-v1:0"),
-        ]
+    model_id: Annotated[
+        str,
+        Path(description="Model ID", example="anthropic.claude-3-sonnet-20240229-v1:0"),
+    ],
 ):
     await validate_model_id(model_id)
     return Model(id=model_id)

@@ -13,10 +13,6 @@ Use OpenAI-Compatible RESTful APIs for Amazon Bedrock models.

 DEBUG = os.environ.get("DEBUG", "false").lower() != "false"
 AWS_REGION = os.environ.get("AWS_REGION", "us-west-2")
-DEFAULT_MODEL = os.environ.get(
-    "DEFAULT_MODEL", "anthropic.claude-3-sonnet-20240229-v1:0"
-)
-DEFAULT_EMBEDDING_MODEL = os.environ.get(
-    "DEFAULT_EMBEDDING_MODEL", "cohere.embed-multilingual-v3"
-)
+DEFAULT_MODEL = os.environ.get("DEFAULT_MODEL", "anthropic.claude-3-sonnet-20240229-v1:0")
+DEFAULT_EMBEDDING_MODEL = os.environ.get("DEFAULT_EMBEDDING_MODEL", "cohere.embed-multilingual-v3")
 ENABLE_CROSS_REGION_INFERENCE = os.environ.get("ENABLE_CROSS_REGION_INFERENCE", "true").lower() != "false"