feat: add support to include application inference profiles as models (#131)
Co-authored-by: Mengxin Zhu <843303+zxkane@users.noreply.github.com>
README.md (42 lines changed)
@@ -26,6 +26,7 @@ If you find this GitHub repository useful, please consider giving it a free star
 - [x] Support Embedding API
 - [x] Support Multimodal API
 - [x] Support Cross-Region Inference
+- [x] Support Application Inference Profiles (**new**)
 - [x] Support Reasoning (**new**)
 
 Please check [Usage Guide](./docs/Usage.md) for more details about how to use the new APIs.
@@ -148,7 +149,48 @@ print(completion.choices[0].message.content)
 
 Please check [Usage Guide](./docs/Usage.md) for more details about how to use embedding API, multimodal API and tool call.
 
+### Application Inference Profiles
+
+This proxy now supports **Application Inference Profiles**, which allow you to track usage and costs for your model invocations. You can use application inference profiles created in your AWS account for cost tracking and monitoring purposes.
+
+**Using Application Inference Profiles:**
+
+```bash
+# Use an application inference profile ARN as the model ID
+curl $OPENAI_BASE_URL/chat/completions \
+  -H "Content-Type: application/json" \
+  -H "Authorization: Bearer $OPENAI_API_KEY" \
+  -d '{
+    "model": "arn:aws:bedrock:us-west-2:123456789012:application-inference-profile/your-profile-id",
+    "messages": [
+      {
+        "role": "user",
+        "content": "Hello!"
+      }
+    ]
+  }'
+```
+
+**SDK Usage with Application Inference Profiles:**
+
+```python
+from openai import OpenAI
+
+client = OpenAI()
+completion = client.chat.completions.create(
+    model="arn:aws:bedrock:us-west-2:123456789012:application-inference-profile/your-profile-id",
+    messages=[{"role": "user", "content": "Hello!"}],
+)
+
+print(completion.choices[0].message.content)
+```
+
+**Benefits of Application Inference Profiles:**
+
+- **Cost Tracking**: Track usage and costs for specific applications or use cases
+- **Usage Monitoring**: Monitor model invocation metrics through CloudWatch
+- **Tag-based Cost Allocation**: Use AWS cost allocation tags for detailed billing analysis
+
+For more information about creating and managing application inference profiles, see the [Amazon Bedrock User Guide](https://docs.aws.amazon.com/bedrock/latest/userguide/inference-profiles-create.html).
+
 ## Other Examples
 
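For completeness: application inference profiles are created outside the proxy. Below is a hedged sketch using boto3's `create_inference_profile`; the profile name, description, tags, and source model ARN are placeholders, not values from this commit.

```python
# Sketch: create an application inference profile for cost tracking.
# Requires a boto3 version with Bedrock application inference profile support.
import boto3

bedrock = boto3.client("bedrock", region_name="us-west-2")

response = bedrock.create_inference_profile(
    inferenceProfileName="my-app-profile",  # hypothetical name
    description="Cost tracking for my application",
    modelSource={
        # Copy from an existing foundation model (or inference profile) ARN
        "copyFrom": "arn:aws:bedrock:us-west-2::foundation-model/anthropic.claude-3-sonnet-20240229-v1:0"
    },
    tags=[{"key": "project", "value": "demo"}],  # drives cost allocation
)

# The returned ARN is what the proxy accepts as the OpenAI "model" value.
print(response["inferenceProfileArn"])
```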
@@ -151,6 +151,7 @@ Resources:
           Resource:
             - arn:aws:bedrock:*::foundation-model/*
             - arn:aws:bedrock:*:*:inference-profile/*
+            - arn:aws:bedrock:*:*:application-inference-profile/*
         - Action:
             - secretsmanager:GetSecretValue
             - secretsmanager:DescribeSecret
@@ -185,6 +186,7 @@ Resources:
             Ref: DefaultModelId
           DEFAULT_EMBEDDING_MODEL: cohere.embed-multilingual-v3
           ENABLE_CROSS_REGION_INFERENCE: "true"
+          ENABLE_APPLICATION_INFERENCE_PROFILES: "true"
       MemorySize: 1024
       PackageType: Image
       Role:
@@ -193,6 +193,7 @@ Resources:
           Resource:
             - arn:aws:bedrock:*::foundation-model/*
             - arn:aws:bedrock:*:*:inference-profile/*
+            - arn:aws:bedrock:*:*:application-inference-profile/*
         Version: "2012-10-17"
       PolicyName: ProxyTaskRoleDefaultPolicy933321B8
       Roles:
@@ -222,6 +223,8 @@ Resources:
               Value: cohere.embed-multilingual-v3
             - Name: ENABLE_CROSS_REGION_INFERENCE
               Value: "true"
+            - Name: ENABLE_APPLICATION_INFERENCE_PROFILES
+              Value: "true"
           Essential: true
           Image:
             Fn::Join:
@@ -38,7 +38,13 @@ from api.schema import (
     Usage,
     UserMessage,
 )
-from api.setting import AWS_REGION, DEBUG, DEFAULT_MODEL, ENABLE_CROSS_REGION_INFERENCE
+from api.setting import (
+    AWS_REGION,
+    DEBUG,
+    DEFAULT_MODEL,
+    ENABLE_CROSS_REGION_INFERENCE,
+    ENABLE_APPLICATION_INFERENCE_PROFILES,
+)
 
 logger = logging.getLogger(__name__)
 
@@ -83,15 +89,40 @@ def list_bedrock_models() -> dict:
     Returns a model list combines:
     - ON_DEMAND models.
     - Cross-Region Inference Profiles (if enabled via Env)
+    - Application Inference Profiles (if enabled via Env)
     """
     model_list = {}
     try:
         profile_list = []
+        app_profile_dict = {}
+
         if ENABLE_CROSS_REGION_INFERENCE:
             # List system defined inference profile IDs
             response = bedrock_client.list_inference_profiles(maxResults=1000, typeEquals="SYSTEM_DEFINED")
             profile_list = [p["inferenceProfileId"] for p in response["inferenceProfileSummaries"]]
 
+        if ENABLE_APPLICATION_INFERENCE_PROFILES:
+            # List application defined inference profile IDs and create mapping
+            response = bedrock_client.list_inference_profiles(maxResults=1000, typeEquals="APPLICATION")
+
+            for profile in response["inferenceProfileSummaries"]:
+                try:
+                    profile_arn = profile.get("inferenceProfileArn")
+                    if not profile_arn:
+                        continue
+
+                    # Process all models in the profile
+                    models = profile.get("models", [])
+                    for model in models:
+                        model_arn = model.get("modelArn", "")
+                        if model_arn:
+                            model_id = model_arn.split('/')[-1] if '/' in model_arn else model_arn
+                            if model_id:
+                                app_profile_dict[model_id] = profile_arn
+                except Exception as e:
+                    logger.warning(f"Error processing application profile: {e}")
+                    continue
+
         # List foundation models, only cares about text outputs here.
         response = bedrock_client.list_foundation_models(byOutputModality="TEXT")
 
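The new mapping keys each underlying model ID to its application profile ARN, so the profile can later be surfaced alongside the foundation model. A minimal, self-contained sketch of that extraction logic; the sample data and ARNs are hypothetical:

```python
# Sketch of the modelArn -> profile ARN mapping built above (no AWS calls).
def map_models_to_profile(profile_summaries: list[dict]) -> dict[str, str]:
    app_profile_dict: dict[str, str] = {}
    for profile in profile_summaries:
        profile_arn = profile.get("inferenceProfileArn")
        if not profile_arn:
            continue
        for model in profile.get("models", []):
            model_arn = model.get("modelArn", "")
            # The model ID is the last path segment of the model ARN.
            model_id = model_arn.split("/")[-1] if "/" in model_arn else model_arn
            if model_id:
                app_profile_dict[model_id] = profile_arn
    return app_profile_dict

summaries = [{
    "inferenceProfileArn": "arn:aws:bedrock:us-west-2:123456789012:application-inference-profile/abc123",
    "models": [{"modelArn": "arn:aws:bedrock:us-west-2::foundation-model/anthropic.claude-3-sonnet-20240229-v1:0"}],
}]
print(map_models_to_profile(summaries))
# {'anthropic.claude-3-sonnet-20240229-v1:0': 'arn:aws:...:application-inference-profile/abc123'}
```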
@@ -115,6 +146,10 @@ def list_bedrock_models() -> dict:
             if profile_id in profile_list:
                 model_list[profile_id] = {"modalities": input_modalities}
 
+            # Add application inference profiles
+            if model_id in app_profile_dict:
+                model_list[app_profile_dict[model_id]] = {"modalities": input_modalities}
+
     except Exception as e:
         logger.error(f"Unable to list models: {str(e)}")
 
@@ -162,7 +197,9 @@ class BedrockModel(BaseChatModel):
         try:
             if stream:
                 # Run the blocking boto3 call in a thread pool
-                response = await run_in_threadpool(bedrock_runtime.converse_stream, **args)
+                response = await run_in_threadpool(
+                    bedrock_runtime.converse_stream, **args
+                )
             else:
                 # Run the blocking boto3 call in a thread pool
                 response = await run_in_threadpool(bedrock_runtime.converse, **args)
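This hunk only re-wraps the call, but the pattern it preserves is worth noting: `run_in_threadpool` keeps the async event loop responsive while the synchronous boto3 client blocks. A minimal sketch of the pattern, assuming a FastAPI app; `slow_call` is a hypothetical stand-in for the blocking boto3 call:

```python
# Sketch: offload a blocking call from an async endpoint (run with uvicorn).
import time

from fastapi import FastAPI
from fastapi.concurrency import run_in_threadpool

app = FastAPI()

def slow_call() -> str:
    time.sleep(1)  # blocking work, e.g. bedrock_runtime.converse(**args)
    return "done"

@app.get("/demo")
async def demo():
    # slow_call runs in a worker thread, so the event loop stays free.
    result = await run_in_threadpool(slow_call)
    return {"result": result}
```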
@@ -274,7 +311,9 @@ class BedrockModel(BaseChatModel):
             messages.append(
                 {
                     "role": message.role,
-                    "content": self._parse_content_parts(message, chat_request.model),
+                    "content": self._parse_content_parts(
+                        message, chat_request.model
+                    ),
                 }
             )
         elif isinstance(message, AssistantMessage):
@@ -283,7 +322,9 @@ class BedrockModel(BaseChatModel):
             messages.append(
                 {
                     "role": message.role,
-                    "content": self._parse_content_parts(message, chat_request.model),
+                    "content": self._parse_content_parts(
+                        message, chat_request.model
+                    ),
                 }
             )
             if message.tool_calls:
@@ -363,7 +404,9 @@ class BedrockModel(BaseChatModel):
             # If the next role is different from the previous message, add the previous role's messages to the list
             if next_role != current_role:
                 if current_content:
-                    reformatted_messages.append({"role": current_role, "content": current_content})
+                    reformatted_messages.append(
+                        {"role": current_role, "content": current_content}
+                    )
                 # Switch to the new role
                 current_role = next_role
                 current_content = []
@@ -376,7 +419,9 @@ class BedrockModel(BaseChatModel):
 
         # Add the last role's messages to the list
         if current_content:
-            reformatted_messages.append({"role": current_role, "content": current_content})
+            reformatted_messages.append(
+                {"role": current_role, "content": current_content}
+            )
 
         return reformatted_messages
 
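Both of the preceding hunks reformat the same grouping pass: consecutive messages with the same role are merged into a single entry before being sent to Bedrock's Converse API, which expects alternating roles. A self-contained sketch of that pass, with illustrative names:

```python
# Sketch: merge consecutive same-role messages into single entries.
def merge_consecutive_roles(messages: list[dict]) -> list[dict]:
    reformatted: list[dict] = []
    current_role, current_content = None, []
    for msg in messages:
        if msg["role"] != current_role:
            if current_content:
                reformatted.append({"role": current_role, "content": current_content})
            current_role, current_content = msg["role"], []
        current_content.extend(msg["content"])
    if current_content:
        reformatted.append({"role": current_role, "content": current_content})
    return reformatted

msgs = [
    {"role": "user", "content": [{"text": "Hi"}]},
    {"role": "user", "content": [{"text": "there"}]},
    {"role": "assistant", "content": [{"text": "Hello!"}]},
]
print(merge_consecutive_roles(msgs))
# [{'role': 'user', 'content': [{'text': 'Hi'}, {'text': 'there'}]},
#  {'role': 'assistant', 'content': [{'text': 'Hello!'}]}]
```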
@@ -414,9 +459,13 @@ class BedrockModel(BaseChatModel):
             # Use max_completion_tokens if provided.
             max_tokens = (
-                chat_request.max_completion_tokens if chat_request.max_completion_tokens else chat_request.max_tokens
+                chat_request.max_completion_tokens
+                if chat_request.max_completion_tokens
+                else chat_request.max_tokens
+            )
+            budget_tokens = self._calc_budget_tokens(
+                max_tokens, chat_request.reasoning_effort
             )
-            budget_tokens = self._calc_budget_tokens(max_tokens, chat_request.reasoning_effort)
 
             inference_config["maxTokens"] = max_tokens
             # unset topP - Not supported
             inference_config.pop("topP")
@@ -428,7 +477,9 @@ class BedrockModel(BaseChatModel):
         if chat_request.tools:
             tool_config = {"tools": [self._convert_tool_spec(t.function) for t in chat_request.tools]}
 
-            if chat_request.tool_choice and not chat_request.model.startswith("meta.llama3-1-"):
+            if chat_request.tool_choice and not chat_request.model.startswith(
+                "meta.llama3-1-"
+            ):
                 if isinstance(chat_request.tool_choice, str):
                     # auto (default) is mapped to {"auto" : {}}
                     # required is mapped to {"any" : {}}
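The comments above document the OpenAI-to-Bedrock `tool_choice` mapping. A simplified, hedged sketch of that conversion follows; the named-function branch is an assumption based on Bedrock's `toolChoice` shape and is not shown in this hunk:

```python
# Sketch: map an OpenAI-style tool_choice to Bedrock's toolChoice shape.
def convert_tool_choice(tool_choice) -> dict:
    if isinstance(tool_choice, str):
        # "auto" (default) -> {"auto": {}}; "required" -> {"any": {}}
        if tool_choice == "required":
            return {"any": {}}
        return {"auto": {}}
    # Named function choice -> {"tool": {"name": ...}} (assumed shape)
    return {"tool": {"name": tool_choice["function"]["name"]}}

print(convert_tool_choice("auto"))      # {'auto': {}}
print(convert_tool_choice("required"))  # {'any': {}}
print(convert_tool_choice({"type": "function", "function": {"name": "get_weather"}}))
```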
@@ -477,11 +528,15 @@ class BedrockModel(BaseChatModel):
             message.content = ""
             for c in content:
                 if "reasoningContent" in c:
-                    message.reasoning_content = c["reasoningContent"]["reasoningText"].get("text", "")
+                    message.reasoning_content = c["reasoningContent"][
+                        "reasoningText"
+                    ].get("text", "")
                 elif "text" in c:
                     message.content = c["text"]
                 else:
-                    logger.warning("Unknown tag in message content " + ",".join(c.keys()))
+                    logger.warning(
+                        "Unknown tag in message content " + ",".join(c.keys())
+                    )
 
         response = ChatResponse(
             id=message_id,
@@ -505,7 +560,9 @@ class BedrockModel(BaseChatModel):
         response.created = int(time.time())
         return response
 
-    def _create_response_stream(self, model_id: str, message_id: str, chunk: dict) -> ChatStreamResponse | None:
+    def _create_response_stream(
+        self, model_id: str, message_id: str, chunk: dict
+    ) -> ChatStreamResponse | None:
         """Parsing the Bedrock stream response chunk.
 
         Ref: https://docs.aws.amazon.com/bedrock/latest/userguide/conversation-inference.html#message-inference-examples
@@ -627,7 +684,9 @@ class BedrockModel(BaseChatModel):
             image_content = response.content
             return image_content, content_type
         else:
-            raise HTTPException(status_code=500, detail="Unable to access the image url")
+            raise HTTPException(
+                status_code=500, detail="Unable to access the image url"
+            )
 
     def _parse_content_parts(
         self,
@@ -687,7 +746,9 @@ class BedrockModel(BaseChatModel):
             }
         }
 
-    def _calc_budget_tokens(self, max_tokens: int, reasoning_effort: Literal["low", "medium", "high"]) -> int:
+    def _calc_budget_tokens(
+        self, max_tokens: int, reasoning_effort: Literal["low", "medium", "high"]
+    ) -> int:
         # Helper function to calculate budget_tokens based on the max_tokens.
         # Ratio for efforts: Low - 30%, medium - 60%, High: Max token - 1
         # Note that The minimum budget_tokens is 1,024 tokens so far.
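The comments spell out the budget rule: 30% of `max_tokens` for low effort, 60% for medium, `max_tokens - 1` for high, with a 1,024-token floor. A standalone sketch of that rule as documented; the project's actual method may differ in details not shown in this hunk:

```python
# Sketch of the documented budget_tokens rule; not the exact project method.
from typing import Literal

def calc_budget_tokens(max_tokens: int, effort: Literal["low", "medium", "high"]) -> int:
    if effort == "high":
        budget = max_tokens - 1
    elif effort == "medium":
        budget = int(max_tokens * 0.6)
    else:  # "low"
        budget = int(max_tokens * 0.3)
    return max(budget, 1024)  # minimum budget_tokens is 1,024

print(calc_budget_tokens(4096, "low"))     # 1228
print(calc_budget_tokens(4096, "medium"))  # 2457
print(calc_budget_tokens(4096, "high"))    # 4095
```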
@@ -718,7 +779,9 @@ class BedrockModel(BaseChatModel):
                 "complete": "stop",
                 "content_filtered": "content_filter",
             }
-            return finish_reason_mapping.get(finish_reason.lower(), finish_reason.lower())
+            return finish_reason_mapping.get(
+                finish_reason.lower(), finish_reason.lower()
+            )
         return None
 
 
@@ -809,7 +872,9 @@ class CohereEmbeddingsModel(BedrockEmbeddingsModel):
         return args
 
     def embed(self, embeddings_request: EmbeddingsRequest) -> EmbeddingsResponse:
-        response = self._invoke_model(args=self._parse_args(embeddings_request), model_id=embeddings_request.model)
+        response = self._invoke_model(
+            args=self._parse_args(embeddings_request), model_id=embeddings_request.model
+        )
         response_body = json.loads(response.get("body").read())
         if DEBUG:
             logger.info("Bedrock response body: " + str(response_body))
@@ -825,10 +890,15 @@ class TitanEmbeddingsModel(BedrockEmbeddingsModel):
     def _parse_args(self, embeddings_request: EmbeddingsRequest) -> dict:
         if isinstance(embeddings_request.input, str):
             input_text = embeddings_request.input
-        elif isinstance(embeddings_request.input, list) and len(embeddings_request.input) == 1:
+        elif (
+            isinstance(embeddings_request.input, list)
+            and len(embeddings_request.input) == 1
+        ):
             input_text = embeddings_request.input[0]
         else:
-            raise ValueError("Amazon Titan Embeddings models support only single strings as input.")
+            raise ValueError(
+                "Amazon Titan Embeddings models support only single strings as input."
+            )
         args = {
             "inputText": input_text,
             # Note: inputImage is not supported!
@@ -842,7 +912,9 @@ class TitanEmbeddingsModel(BedrockEmbeddingsModel):
         return args
 
     def embed(self, embeddings_request: EmbeddingsRequest) -> EmbeddingsResponse:
-        response = self._invoke_model(args=self._parse_args(embeddings_request), model_id=embeddings_request.model)
+        response = self._invoke_model(
+            args=self._parse_args(embeddings_request), model_id=embeddings_request.model
+        )
         response_body = json.loads(response.get("body").read())
         if DEBUG:
             logger.info("Bedrock response body: " + str(response_body))
@@ -16,3 +16,4 @@ AWS_REGION = os.environ.get("AWS_REGION", "us-west-2")
 DEFAULT_MODEL = os.environ.get("DEFAULT_MODEL", "anthropic.claude-3-sonnet-20240229-v1:0")
 DEFAULT_EMBEDDING_MODEL = os.environ.get("DEFAULT_EMBEDDING_MODEL", "cohere.embed-multilingual-v3")
 ENABLE_CROSS_REGION_INFERENCE = os.environ.get("ENABLE_CROSS_REGION_INFERENCE", "true").lower() != "false"
+ENABLE_APPLICATION_INFERENCE_PROFILES = os.environ.get("ENABLE_APPLICATION_INFERENCE_PROFILES", "true").lower() != "false"
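Note the parsing convention shared by both flags: the feature is enabled unless the variable is exactly "false" (case-insensitive). A quick illustration of the expression used above:

```python
# Illustration of the flag parsing: only "false" (any case) disables it.
import os

for value in ("true", "TRUE", "1", "yes", "false", "FALSE"):
    os.environ["ENABLE_APPLICATION_INFERENCE_PROFILES"] = value
    enabled = os.environ.get("ENABLE_APPLICATION_INFERENCE_PROFILES", "true").lower() != "false"
    print(f"{value!r:10} -> {enabled}")
# Everything except "false"/"FALSE" prints True.
```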