refactor(bedrock): unify inference profile metadata handling and cleanup

- Add unified profile_metadata dictionary for both SYSTEM_DEFINED and APPLICATION inference profiles
- Remove unused region prefix functions and defaultdict import
- Add TEMPERATURE_TOPP_CONFLICT_MODELS set for Claude model parameter conflicts
- Improve model ARN parsing and error handling in profile enumeration
- Consolidate profile metadata storage to enable consistent feature detection
Author: Kane Zhu
Date: 2025-10-16 15:24:02 +08:00
Parent: b4800c54a0
Commit: d86e64eed3
2 changed files with 138 additions and 59 deletions
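
For context, the core of the change is the shared profile_metadata lookup described in the message above. Below is a minimal standalone sketch of how the mapping is shaped and consumed; the profile ID, ARN, and model IDs are illustrative placeholders, not values taken from the commit:

# Sketch of the unified lookup this commit introduces.
# Keys are either a SYSTEM_DEFINED profile ID or an APPLICATION profile ARN;
# values map back to the underlying foundation model ID.
# (All identifiers below are made up for illustration.)
profile_metadata = {
    "us.anthropic.claude-sonnet-4-5-v1:0": {
        "underlying_model_id": "anthropic.claude-sonnet-4-5-v1:0",
        "profile_type": "SYSTEM_DEFINED",
    },
    "arn:aws:bedrock:us-east-1:123456789012:application-inference-profile/example": {
        "underlying_model_id": "anthropic.claude-sonnet-4-5-v1:0",
        "profile_type": "APPLICATION",
        "profile_name": "my-app-profile",
    },
}

def resolve_to_foundation_model(model_id: str) -> str:
    """Dictionary lookup instead of region-prefix pattern matching."""
    if model_id in profile_metadata:
        return profile_metadata[model_id]["underlying_model_id"]
    return model_id  # unknown identifiers pass through unchanged

assert resolve_to_foundation_model("us.anthropic.claude-sonnet-4-5-v1:0") == "anthropic.claude-sonnet-4-5-v1:0"

Because both profile types land in the same dictionary, feature checks such as prompt-caching support and the temperature/topP conflict handling only ever see a foundation model ID.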

View File

@@ -4,7 +4,6 @@ import logging
import re
import time
from abc import ABC
from collections import defaultdict
from typing import AsyncIterable, Iterable, Literal
import boto3
@@ -74,16 +73,6 @@ bedrock_client = boto3.client(
config=config,
)
def get_inference_region_prefix():
if AWS_REGION.startswith("ap-"):
return "apac"
return AWS_REGION[:2]
# https://docs.aws.amazon.com/bedrock/latest/userguide/inference-profiles-support.html
cr_inference_prefix = get_inference_region_prefix()
SUPPORTED_BEDROCK_EMBEDDING_MODELS = {
"cohere.embed-multilingual-v3": "Cohere Embed Multilingual",
"cohere.embed-english-v3": "Cohere Embed English",
@@ -95,6 +84,18 @@ SUPPORTED_BEDROCK_EMBEDDING_MODELS = {
ENCODER = tiktoken.get_encoding("cl100k_base")
# Global mapping: Profile ID/ARN → Foundation Model ID
# Handles both SYSTEM_DEFINED (cross-region) and APPLICATION profiles
# This enables feature detection for all profile types without pattern matching
profile_metadata = {}
# Models that don't support both temperature and topP simultaneously
# When both are provided, temperature takes precedence and topP is removed
TEMPERATURE_TOPP_CONFLICT_MODELS = {
"claude-sonnet-4-5",
"claude-haiku-4-5",
}
def list_bedrock_models() -> dict:
"""Automatically getting a list of supported models.
@@ -106,15 +107,26 @@ def list_bedrock_models() -> dict:
"""
model_list = {}
try:
profile_list = []
# Map foundation model_id -> set of application inference profile ARNs
app_profiles_by_model = defaultdict(set)
if ENABLE_CROSS_REGION_INFERENCE:
# List system defined inference profile IDs
# List system defined inference profile IDs and store underlying model mapping
paginator = bedrock_client.get_paginator('list_inference_profiles')
for page in paginator.paginate(maxResults=1000, typeEquals="SYSTEM_DEFINED"):
profile_list.extend([p["inferenceProfileId"] for p in page["inferenceProfileSummaries"]])
for profile in page["inferenceProfileSummaries"]:
profile_id = profile.get("inferenceProfileId")
if not profile_id:
continue
# Extract underlying model from first model in the profile
models = profile.get("models", [])
if models:
model_arn = models[0].get("modelArn", "")
if model_arn:
# Extract foundation model ID from ARN
model_id = model_arn.split('/')[-1]
profile_metadata[profile_id] = {
"underlying_model_id": model_id,
"profile_type": "SYSTEM_DEFINED",
}
if ENABLE_APPLICATION_INFERENCE_PROFILES:
# List application defined inference profile IDs and create mapping
@@ -125,15 +137,28 @@ def list_bedrock_models() -> dict:
profile_arn = profile.get("inferenceProfileArn")
if not profile_arn:
continue
# Process all models in the profile
models = profile.get("models", [])
for model in models:
model_arn = model.get("modelArn", "")
if model_arn:
model_id = model_arn.split('/')[-1] if '/' in model_arn else model_arn
if model_id:
app_profiles_by_model[model_id].add(profile_arn)
if not models:
logger.warning(f"Application profile {profile_arn} has no models")
continue
# Take first model - all models in array are same type (regional instances)
first_model = models[0]
model_arn = first_model.get("modelArn", "")
if not model_arn:
continue
# Extract model ID from ARN (works for both foundation models and cross-region profiles)
model_id = model_arn.split('/')[-1] if '/' in model_arn else model_arn
# Store in unified profile metadata for feature detection
profile_metadata[profile_arn] = {
"underlying_model_id": model_id,
"profile_type": "APPLICATION",
"profile_name": profile.get("inferenceProfileName", ""),
}
except Exception as e:
logger.warning(f"Error processing application profile: {e}")
continue
@@ -156,20 +181,10 @@ def list_bedrock_models() -> dict:
if "ON_DEMAND" in inference_types:
model_list[model_id] = {"modalities": input_modalities}
# Add cross-region inference model list.
profile_id = cr_inference_prefix + "." + model_id
if profile_id in profile_list:
model_list[profile_id] = {"modalities": input_modalities}
# Add global cross-region inference profiles
global_profile_id = "global." + model_id
if global_profile_id in profile_list:
model_list[global_profile_id] = {"modalities": input_modalities}
# Add application inference profiles (emit all profiles for this model)
if model_id in app_profiles_by_model:
for profile_arn in app_profiles_by_model[model_id]:
model_list[profile_arn] = {"modalities": input_modalities}
# Add all inference profiles (cross-region and application) for this model
for profile_id, metadata in profile_metadata.items():
if metadata.get("underlying_model_id") == model_id:
model_list[profile_id] = {"modalities": input_modalities}
except Exception as e:
logger.error(f"Unable to list models: {str(e)}")
@@ -197,17 +212,56 @@ class BedrockModel(BaseChatModel):
error = ""
# check if model is supported
if chat_request.model not in bedrock_model_list.keys():
error = f"Unsupported model {chat_request.model}, please use models API to get a list of supported models"
# Provide helpful error for application profiles
if "application-inference-profile" in chat_request.model:
error = (
f"Application profile {chat_request.model} not found. "
f"Available profiles can be listed via GET /models API. "
f"Ensure ENABLE_APPLICATION_INFERENCE_PROFILES=true and "
f"the profile exists in your AWS account."
)
else:
error = f"Unsupported model {chat_request.model}, please use models API to get a list of supported models"
logger.error("Unsupported model: %s", chat_request.model)
# Validate profile has resolvable underlying model
if not error and chat_request.model in profile_metadata:
resolved = self._resolve_to_foundation_model(chat_request.model)
if resolved == chat_request.model:
logger.warning(
f"Could not resolve profile {chat_request.model} "
f"to underlying model. Some features may not work correctly."
)
if error:
raise HTTPException(
status_code=400,
detail=error,
)
@staticmethod
def _supports_prompt_caching(model_id: str) -> bool:
def _resolve_to_foundation_model(self, model_id: str) -> str:
"""
Resolve any model identifier to foundation model ID for feature detection.
Handles:
- Cross-region profiles (us.*, eu.*, apac.*, global.*)
- Application profiles (arn:aws:bedrock:...:application-inference-profile/...)
- Foundation models (pass through unchanged)
No pattern matching needed - just dictionary lookup.
Unknown identifiers pass through unchanged (graceful fallback).
Args:
model_id: Can be foundation model ID, cross-region profile, or app profile ARN
Returns:
Foundation model ID if mapping exists, otherwise original model_id
"""
if model_id in profile_metadata:
return profile_metadata[model_id]["underlying_model_id"]
return model_id
def _supports_prompt_caching(self, model_id: str) -> bool:
"""
Check if model supports prompt caching based on model ID pattern.
@@ -221,10 +275,12 @@ class BedrockModel(BaseChatModel):
Returns:
bool: True if model supports prompt caching
"""
model_lower = model_id.lower()
# Resolve profile to underlying model for feature detection
resolved_model = self._resolve_to_foundation_model(model_id)
model_lower = resolved_model.lower()
# Claude models pattern matching
if "anthropic.claude" in model_lower or ".anthropic.claude" in model_lower:
if "anthropic.claude" in model_lower:
# Exclude very old models that don't support caching
excluded_patterns = ["claude-instant", "claude-v1", "claude-v2"]
if any(pattern in model_lower for pattern in excluded_patterns):
@@ -232,7 +288,7 @@ class BedrockModel(BaseChatModel):
return True
# Nova models pattern matching
if "amazon.nova" in model_lower or ".amazon.nova" in model_lower:
if "amazon.nova" in model_lower:
return True
# Future providers can be added here
@@ -240,8 +296,7 @@ class BedrockModel(BaseChatModel):
return False
@staticmethod
def _get_max_cache_tokens(model_id: str) -> int | None:
def _get_max_cache_tokens(self, model_id: str) -> int | None:
"""
Get maximum cacheable tokens limit for the model.
@@ -252,14 +307,16 @@ class BedrockModel(BaseChatModel):
Returns:
int | None: Max tokens, or None if unlimited
"""
model_lower = model_id.lower()
# Resolve profile to underlying model for feature detection
resolved_model = self._resolve_to_foundation_model(model_id)
model_lower = resolved_model.lower()
# Nova models have 20K limit
if "amazon.nova" in model_lower or ".amazon.nova" in model_lower:
if "amazon.nova" in model_lower:
return 20_000
# Claude: No explicit limit
if "anthropic.claude" in model_lower or ".anthropic.claude" in model_lower:
if "anthropic.claude" in model_lower:
return None
return None
@@ -269,6 +326,14 @@ class BedrockModel(BaseChatModel):
if DEBUG:
logger.info("Raw request: " + chat_request.model_dump_json())
# Log profile resolution for debugging
if chat_request.model in profile_metadata:
resolved = self._resolve_to_foundation_model(chat_request.model)
profile_type = profile_metadata[chat_request.model].get("profile_type", "UNKNOWN")
logger.info(
f"Profile resolution: {chat_request.model} ({profile_type}) → {resolved}"
)
# convert OpenAI chat request to Bedrock SDK request
args = self._parse_request(chat_request)
if DEBUG:
@@ -667,15 +732,27 @@ class BedrockModel(BaseChatModel):
# Base inference parameters.
inference_config = {
"temperature": chat_request.temperature,
"maxTokens": chat_request.max_tokens,
"topP": chat_request.top_p,
}
# Claude Sonnet 4.5 doesn't support both temperature and topP
# Remove topP for this model
if "claude-sonnet-4-5" in chat_request.model.lower():
inference_config.pop("topP", None)
# Only include optional parameters when specified
if chat_request.temperature is not None:
inference_config["temperature"] = chat_request.temperature
if chat_request.top_p is not None:
inference_config["topP"] = chat_request.top_p
# Some models (Claude Sonnet 4.5, Haiku 4.5) don't support both temperature and topP
# When both are provided, keep temperature and remove topP
# Resolve profile to underlying model for feature detection
resolved_model = self._resolve_to_foundation_model(chat_request.model)
model_lower = resolved_model.lower()
# Check if model is in the conflict list and both parameters are present
if "temperature" in inference_config and "topP" in inference_config:
if any(conflict_model in model_lower for conflict_model in TEMPERATURE_TOPP_CONFLICT_MODELS):
inference_config.pop("topP", None)
if DEBUG:
logger.info(f"Removed topP for {chat_request.model} (conflicts with temperature)")
if chat_request.stop is not None:
stop = chat_request.stop
@@ -692,9 +769,11 @@ class BedrockModel(BaseChatModel):
if chat_request.reasoning_effort:
# reasoning_effort is supported by Claude and DeepSeek v3
# Different models use different formats
model_lower = chat_request.model.lower()
# Resolve profile to underlying model for feature detection
resolved_model = self._resolve_to_foundation_model(chat_request.model)
model_lower = resolved_model.lower()
if "anthropic.claude" in model_lower or ".anthropic.claude" in model_lower:
if "anthropic.claude" in model_lower:
# Claude format: reasoning_config = object with budget_tokens
max_tokens = (
chat_request.max_completion_tokens

View File

@@ -97,8 +97,8 @@ class ChatRequest(BaseModel):
presence_penalty: float | None = Field(default=0.0, le=2.0, ge=-2.0) # Not used
stream: bool | None = False
stream_options: StreamOptions | None = None
temperature: float | None = Field(default=1.0, le=2.0, ge=0.0)
top_p: float | None = Field(default=1.0, le=1.0, ge=0.0)
temperature: float | None = Field(default=None, le=2.0, ge=0.0)
top_p: float | None = Field(default=None, le=1.0, ge=0.0)
user: str | None = None # Not used
max_tokens: int | None = 2048
max_completion_tokens: int | None = None
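
Combined with the request-parsing change in the first file, the practical effect of making temperature and top_p optional is roughly the following. This is a standalone sketch mirroring the diff above, not code from the commit; the literal values are made up for illustration:

# The conflict list from the first file, reproduced so the sketch runs on its own.
TEMPERATURE_TOPP_CONFLICT_MODELS = {"claude-sonnet-4-5", "claude-haiku-4-5"}

def build_inference_config(model_lower: str, max_tokens: int,
                           temperature: float | None, top_p: float | None) -> dict:
    config = {"maxTokens": max_tokens}
    # Only include optional parameters the caller actually set (new default is None).
    if temperature is not None:
        config["temperature"] = temperature
    if top_p is not None:
        config["topP"] = top_p
    # Some Claude models reject temperature and topP together; temperature wins.
    if "temperature" in config and "topP" in config:
        if any(m in model_lower for m in TEMPERATURE_TOPP_CONFLICT_MODELS):
            config.pop("topP", None)
    return config

# Illustrative values only.
print(build_inference_config("anthropic.claude-sonnet-4-5-v1:0", 2048, 0.7, 0.9))
# -> {'maxTokens': 2048, 'temperature': 0.7}
print(build_inference_config("anthropic.claude-3-7-sonnet", 2048, None, 0.9))
# -> {'maxTokens': 2048, 'topP': 0.9}

Previously the defaults of 1.0 meant both keys were always sent to Bedrock, which forced the hard-coded claude-sonnet-4-5 special case that this commit replaces with the conflict set.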