feat: add support to include application inference profiles as models (#131)

---------

Co-authored-by: Mengxin Zhu <843303+zxkane@users.noreply.github.com>
This commit is contained in:
Gagan M
2025-06-23 20:19:27 +05:30
committed by GitHub
parent dd191d7cd9
commit 01836087b1
5 changed files with 139 additions and 19 deletions

View File

@@ -26,6 +26,7 @@ If you find this GitHub repository useful, please consider giving it a free star
- [x] Support Embedding API - [x] Support Embedding API
- [x] Support Multimodal API - [x] Support Multimodal API
- [x] Support Cross-Region Inference - [x] Support Cross-Region Inference
- [x] Support Application Inference Profiles (**new**)
- [x] Support Reasoning (**new**) - [x] Support Reasoning (**new**)
Please check [Usage Guide](./docs/Usage.md) for more details about how to use the new APIs. Please check [Usage Guide](./docs/Usage.md) for more details about how to use the new APIs.
@@ -148,7 +149,48 @@ print(completion.choices[0].message.content)
Please check [Usage Guide](./docs/Usage.md) for more details about how to use embedding API, multimodal API and tool call. Please check [Usage Guide](./docs/Usage.md) for more details about how to use embedding API, multimodal API and tool call.
### Application Inference Profiles
This proxy now supports **Application Inference Profiles**, which allow you to track usage and costs for your model invocations. You can use application inference profiles created in your AWS account for cost tracking and monitoring purposes.
**Using Application Inference Profiles:**
```bash
# Use an application inference profile ARN as the model ID
curl $OPENAI_BASE_URL/chat/completions \
-H "Content-Type: application/json" \
-H "Authorization: Bearer $OPENAI_API_KEY" \
-d '{
"model": "arn:aws:bedrock:us-west-2:123456789012:application-inference-profile/your-profile-id",
"messages": [
{
"role": "user",
"content": "Hello!"
}
]
}'
```
**SDK Usage with Application Inference Profiles:**
```python
from openai import OpenAI
client = OpenAI()
completion = client.chat.completions.create(
model="arn:aws:bedrock:us-west-2:123456789012:application-inference-profile/your-profile-id",
messages=[{"role": "user", "content": "Hello!"}],
)
print(completion.choices[0].message.content)
```
**Benefits of Application Inference Profiles:**
- **Cost Tracking**: Track usage and costs for specific applications or use cases
- **Usage Monitoring**: Monitor model invocation metrics through CloudWatch
- **Tag-based Cost Allocation**: Use AWS cost allocation tags for detailed billing analysis
For more information about creating and managing application inference profiles, see the [Amazon Bedrock User Guide](https://docs.aws.amazon.com/bedrock/latest/userguide/inference-profiles-create.html).
## Other Examples ## Other Examples

View File

@@ -151,6 +151,7 @@ Resources:
Resource: Resource:
- arn:aws:bedrock:*::foundation-model/* - arn:aws:bedrock:*::foundation-model/*
- arn:aws:bedrock:*:*:inference-profile/* - arn:aws:bedrock:*:*:inference-profile/*
- arn:aws:bedrock:*:*:application-inference-profile/*
- Action: - Action:
- secretsmanager:GetSecretValue - secretsmanager:GetSecretValue
- secretsmanager:DescribeSecret - secretsmanager:DescribeSecret
@@ -185,6 +186,7 @@ Resources:
Ref: DefaultModelId Ref: DefaultModelId
DEFAULT_EMBEDDING_MODEL: cohere.embed-multilingual-v3 DEFAULT_EMBEDDING_MODEL: cohere.embed-multilingual-v3
ENABLE_CROSS_REGION_INFERENCE: "true" ENABLE_CROSS_REGION_INFERENCE: "true"
ENABLE_APPLICATION_INFERENCE_PROFILES: "true"
MemorySize: 1024 MemorySize: 1024
PackageType: Image PackageType: Image
Role: Role:

View File

@@ -193,6 +193,7 @@ Resources:
Resource: Resource:
- arn:aws:bedrock:*::foundation-model/* - arn:aws:bedrock:*::foundation-model/*
- arn:aws:bedrock:*:*:inference-profile/* - arn:aws:bedrock:*:*:inference-profile/*
- arn:aws:bedrock:*:*:application-inference-profile/*
Version: "2012-10-17" Version: "2012-10-17"
PolicyName: ProxyTaskRoleDefaultPolicy933321B8 PolicyName: ProxyTaskRoleDefaultPolicy933321B8
Roles: Roles:
@@ -222,6 +223,8 @@ Resources:
Value: cohere.embed-multilingual-v3 Value: cohere.embed-multilingual-v3
- Name: ENABLE_CROSS_REGION_INFERENCE - Name: ENABLE_CROSS_REGION_INFERENCE
Value: "true" Value: "true"
- Name: ENABLE_APPLICATION_INFERENCE_PROFILES
Value: "true"
Essential: true Essential: true
Image: Image:
Fn::Join: Fn::Join:

View File

@@ -38,7 +38,13 @@ from api.schema import (
Usage, Usage,
UserMessage, UserMessage,
) )
from api.setting import AWS_REGION, DEBUG, DEFAULT_MODEL, ENABLE_CROSS_REGION_INFERENCE from api.setting import (
AWS_REGION,
DEBUG,
DEFAULT_MODEL,
ENABLE_CROSS_REGION_INFERENCE,
ENABLE_APPLICATION_INFERENCE_PROFILES,
)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -83,15 +89,40 @@ def list_bedrock_models() -> dict:
Returns a model list combines: Returns a model list combines:
- ON_DEMAND models. - ON_DEMAND models.
- Cross-Region Inference Profiles (if enabled via Env) - Cross-Region Inference Profiles (if enabled via Env)
- Application Inference Profiles (if enabled via Env)
""" """
model_list = {} model_list = {}
try: try:
profile_list = [] profile_list = []
app_profile_dict = {}
if ENABLE_CROSS_REGION_INFERENCE: if ENABLE_CROSS_REGION_INFERENCE:
# List system defined inference profile IDs # List system defined inference profile IDs
response = bedrock_client.list_inference_profiles(maxResults=1000, typeEquals="SYSTEM_DEFINED") response = bedrock_client.list_inference_profiles(maxResults=1000, typeEquals="SYSTEM_DEFINED")
profile_list = [p["inferenceProfileId"] for p in response["inferenceProfileSummaries"]] profile_list = [p["inferenceProfileId"] for p in response["inferenceProfileSummaries"]]
if ENABLE_APPLICATION_INFERENCE_PROFILES:
# List application defined inference profile IDs and create mapping
response = bedrock_client.list_inference_profiles(maxResults=1000, typeEquals="APPLICATION")
for profile in response["inferenceProfileSummaries"]:
try:
profile_arn = profile.get("inferenceProfileArn")
if not profile_arn:
continue
# Process all models in the profile
models = profile.get("models", [])
for model in models:
model_arn = model.get("modelArn", "")
if model_arn:
model_id = model_arn.split('/')[-1] if '/' in model_arn else model_arn
if model_id:
app_profile_dict[model_id] = profile_arn
except Exception as e:
logger.warning(f"Error processing application profile: {e}")
continue
# List foundation models, only cares about text outputs here. # List foundation models, only cares about text outputs here.
response = bedrock_client.list_foundation_models(byOutputModality="TEXT") response = bedrock_client.list_foundation_models(byOutputModality="TEXT")
@@ -115,6 +146,10 @@ def list_bedrock_models() -> dict:
if profile_id in profile_list: if profile_id in profile_list:
model_list[profile_id] = {"modalities": input_modalities} model_list[profile_id] = {"modalities": input_modalities}
# Add application inference profiles
if model_id in app_profile_dict:
model_list[app_profile_dict[model_id]] = {"modalities": input_modalities}
except Exception as e: except Exception as e:
logger.error(f"Unable to list models: {str(e)}") logger.error(f"Unable to list models: {str(e)}")
@@ -162,7 +197,9 @@ class BedrockModel(BaseChatModel):
try: try:
if stream: if stream:
# Run the blocking boto3 call in a thread pool # Run the blocking boto3 call in a thread pool
response = await run_in_threadpool(bedrock_runtime.converse_stream, **args) response = await run_in_threadpool(
bedrock_runtime.converse_stream, **args
)
else: else:
# Run the blocking boto3 call in a thread pool # Run the blocking boto3 call in a thread pool
response = await run_in_threadpool(bedrock_runtime.converse, **args) response = await run_in_threadpool(bedrock_runtime.converse, **args)
@@ -274,7 +311,9 @@ class BedrockModel(BaseChatModel):
messages.append( messages.append(
{ {
"role": message.role, "role": message.role,
"content": self._parse_content_parts(message, chat_request.model), "content": self._parse_content_parts(
message, chat_request.model
),
} }
) )
elif isinstance(message, AssistantMessage): elif isinstance(message, AssistantMessage):
@@ -283,7 +322,9 @@ class BedrockModel(BaseChatModel):
messages.append( messages.append(
{ {
"role": message.role, "role": message.role,
"content": self._parse_content_parts(message, chat_request.model), "content": self._parse_content_parts(
message, chat_request.model
),
} }
) )
if message.tool_calls: if message.tool_calls:
@@ -363,7 +404,9 @@ class BedrockModel(BaseChatModel):
# If the next role is different from the previous message, add the previous role's messages to the list # If the next role is different from the previous message, add the previous role's messages to the list
if next_role != current_role: if next_role != current_role:
if current_content: if current_content:
reformatted_messages.append({"role": current_role, "content": current_content}) reformatted_messages.append(
{"role": current_role, "content": current_content}
)
# Switch to the new role # Switch to the new role
current_role = next_role current_role = next_role
current_content = [] current_content = []
@@ -376,7 +419,9 @@ class BedrockModel(BaseChatModel):
# Add the last role's messages to the list # Add the last role's messages to the list
if current_content: if current_content:
reformatted_messages.append({"role": current_role, "content": current_content}) reformatted_messages.append(
{"role": current_role, "content": current_content}
)
return reformatted_messages return reformatted_messages
@@ -414,9 +459,13 @@ class BedrockModel(BaseChatModel):
# Use max_completion_tokens if provided. # Use max_completion_tokens if provided.
max_tokens = ( max_tokens = (
chat_request.max_completion_tokens if chat_request.max_completion_tokens else chat_request.max_tokens chat_request.max_completion_tokens
if chat_request.max_completion_tokens
else chat_request.max_tokens
)
budget_tokens = self._calc_budget_tokens(
max_tokens, chat_request.reasoning_effort
) )
budget_tokens = self._calc_budget_tokens(max_tokens, chat_request.reasoning_effort)
inference_config["maxTokens"] = max_tokens inference_config["maxTokens"] = max_tokens
# unset topP - Not supported # unset topP - Not supported
inference_config.pop("topP") inference_config.pop("topP")
@@ -428,7 +477,9 @@ class BedrockModel(BaseChatModel):
if chat_request.tools: if chat_request.tools:
tool_config = {"tools": [self._convert_tool_spec(t.function) for t in chat_request.tools]} tool_config = {"tools": [self._convert_tool_spec(t.function) for t in chat_request.tools]}
if chat_request.tool_choice and not chat_request.model.startswith("meta.llama3-1-"): if chat_request.tool_choice and not chat_request.model.startswith(
"meta.llama3-1-"
):
if isinstance(chat_request.tool_choice, str): if isinstance(chat_request.tool_choice, str):
# auto (default) is mapped to {"auto" : {}} # auto (default) is mapped to {"auto" : {}}
# required is mapped to {"any" : {}} # required is mapped to {"any" : {}}
@@ -477,11 +528,15 @@ class BedrockModel(BaseChatModel):
message.content = "" message.content = ""
for c in content: for c in content:
if "reasoningContent" in c: if "reasoningContent" in c:
message.reasoning_content = c["reasoningContent"]["reasoningText"].get("text", "") message.reasoning_content = c["reasoningContent"][
"reasoningText"
].get("text", "")
elif "text" in c: elif "text" in c:
message.content = c["text"] message.content = c["text"]
else: else:
logger.warning("Unknown tag in message content " + ",".join(c.keys())) logger.warning(
"Unknown tag in message content " + ",".join(c.keys())
)
response = ChatResponse( response = ChatResponse(
id=message_id, id=message_id,
@@ -505,7 +560,9 @@ class BedrockModel(BaseChatModel):
response.created = int(time.time()) response.created = int(time.time())
return response return response
def _create_response_stream(self, model_id: str, message_id: str, chunk: dict) -> ChatStreamResponse | None: def _create_response_stream(
self, model_id: str, message_id: str, chunk: dict
) -> ChatStreamResponse | None:
"""Parsing the Bedrock stream response chunk. """Parsing the Bedrock stream response chunk.
Ref: https://docs.aws.amazon.com/bedrock/latest/userguide/conversation-inference.html#message-inference-examples Ref: https://docs.aws.amazon.com/bedrock/latest/userguide/conversation-inference.html#message-inference-examples
@@ -627,7 +684,9 @@ class BedrockModel(BaseChatModel):
image_content = response.content image_content = response.content
return image_content, content_type return image_content, content_type
else: else:
raise HTTPException(status_code=500, detail="Unable to access the image url") raise HTTPException(
status_code=500, detail="Unable to access the image url"
)
def _parse_content_parts( def _parse_content_parts(
self, self,
@@ -687,7 +746,9 @@ class BedrockModel(BaseChatModel):
} }
} }
def _calc_budget_tokens(self, max_tokens: int, reasoning_effort: Literal["low", "medium", "high"]) -> int: def _calc_budget_tokens(
self, max_tokens: int, reasoning_effort: Literal["low", "medium", "high"]
) -> int:
# Helper function to calculate budget_tokens based on the max_tokens. # Helper function to calculate budget_tokens based on the max_tokens.
# Ratio for efforts: Low - 30%, medium - 60%, High: Max token - 1 # Ratio for efforts: Low - 30%, medium - 60%, High: Max token - 1
# Note that The minimum budget_tokens is 1,024 tokens so far. # Note that The minimum budget_tokens is 1,024 tokens so far.
@@ -718,7 +779,9 @@ class BedrockModel(BaseChatModel):
"complete": "stop", "complete": "stop",
"content_filtered": "content_filter", "content_filtered": "content_filter",
} }
return finish_reason_mapping.get(finish_reason.lower(), finish_reason.lower()) return finish_reason_mapping.get(
finish_reason.lower(), finish_reason.lower()
)
return None return None
@@ -809,7 +872,9 @@ class CohereEmbeddingsModel(BedrockEmbeddingsModel):
return args return args
def embed(self, embeddings_request: EmbeddingsRequest) -> EmbeddingsResponse: def embed(self, embeddings_request: EmbeddingsRequest) -> EmbeddingsResponse:
response = self._invoke_model(args=self._parse_args(embeddings_request), model_id=embeddings_request.model) response = self._invoke_model(
args=self._parse_args(embeddings_request), model_id=embeddings_request.model
)
response_body = json.loads(response.get("body").read()) response_body = json.loads(response.get("body").read())
if DEBUG: if DEBUG:
logger.info("Bedrock response body: " + str(response_body)) logger.info("Bedrock response body: " + str(response_body))
@@ -825,10 +890,15 @@ class TitanEmbeddingsModel(BedrockEmbeddingsModel):
def _parse_args(self, embeddings_request: EmbeddingsRequest) -> dict: def _parse_args(self, embeddings_request: EmbeddingsRequest) -> dict:
if isinstance(embeddings_request.input, str): if isinstance(embeddings_request.input, str):
input_text = embeddings_request.input input_text = embeddings_request.input
elif isinstance(embeddings_request.input, list) and len(embeddings_request.input) == 1: elif (
isinstance(embeddings_request.input, list)
and len(embeddings_request.input) == 1
):
input_text = embeddings_request.input[0] input_text = embeddings_request.input[0]
else: else:
raise ValueError("Amazon Titan Embeddings models support only single strings as input.") raise ValueError(
"Amazon Titan Embeddings models support only single strings as input."
)
args = { args = {
"inputText": input_text, "inputText": input_text,
# Note: inputImage is not supported! # Note: inputImage is not supported!
@@ -842,7 +912,9 @@ class TitanEmbeddingsModel(BedrockEmbeddingsModel):
return args return args
def embed(self, embeddings_request: EmbeddingsRequest) -> EmbeddingsResponse: def embed(self, embeddings_request: EmbeddingsRequest) -> EmbeddingsResponse:
response = self._invoke_model(args=self._parse_args(embeddings_request), model_id=embeddings_request.model) response = self._invoke_model(
args=self._parse_args(embeddings_request), model_id=embeddings_request.model
)
response_body = json.loads(response.get("body").read()) response_body = json.loads(response.get("body").read())
if DEBUG: if DEBUG:
logger.info("Bedrock response body: " + str(response_body)) logger.info("Bedrock response body: " + str(response_body))

View File

@@ -16,3 +16,4 @@ AWS_REGION = os.environ.get("AWS_REGION", "us-west-2")
DEFAULT_MODEL = os.environ.get("DEFAULT_MODEL", "anthropic.claude-3-sonnet-20240229-v1:0") DEFAULT_MODEL = os.environ.get("DEFAULT_MODEL", "anthropic.claude-3-sonnet-20240229-v1:0")
DEFAULT_EMBEDDING_MODEL = os.environ.get("DEFAULT_EMBEDDING_MODEL", "cohere.embed-multilingual-v3") DEFAULT_EMBEDDING_MODEL = os.environ.get("DEFAULT_EMBEDDING_MODEL", "cohere.embed-multilingual-v3")
ENABLE_CROSS_REGION_INFERENCE = os.environ.get("ENABLE_CROSS_REGION_INFERENCE", "true").lower() != "false" ENABLE_CROSS_REGION_INFERENCE = os.environ.get("ENABLE_CROSS_REGION_INFERENCE", "true").lower() != "false"
ENABLE_APPLICATION_INFERENCE_PROFILES = os.environ.get("ENABLE_APPLICATION_INFERENCE_PROFILES", "true").lower() != "false"