refactor(bedrock): unify inference profile metadata handling and cleanup
- Add unified profile_metadata dictionary for both SYSTEM_DEFINED and APPLICATION inference profiles
- Remove unused region prefix functions and defaultdict import
- Add TEMPERATURE_TOPP_CONFLICT_MODELS set for Claude model parameter conflicts
- Improve model ARN parsing and error handling in profile enumeration
- Consolidate profile metadata storage to enable consistent feature detection
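For orientation, here is a minimal sketch of the lookup this change introduces. It is not part of the diff; the profile ID, ARN, account number, and model ID below are hypothetical examples.

    # Hypothetical entries; in the gateway this dict is populated by list_bedrock_models()
    profile_metadata = {
        "us.anthropic.claude-3-5-sonnet-20240620-v1:0": {
            "underlying_model_id": "anthropic.claude-3-5-sonnet-20240620-v1:0",
            "profile_type": "SYSTEM_DEFINED",
        },
        "arn:aws:bedrock:us-east-1:111122223333:application-inference-profile/example": {
            "underlying_model_id": "anthropic.claude-3-5-sonnet-20240620-v1:0",
            "profile_type": "APPLICATION",
        },
    }

    def resolve_to_foundation_model(model_id: str) -> str:
        # Plain dictionary lookup; unknown identifiers pass through unchanged
        entry = profile_metadata.get(model_id)
        return entry["underlying_model_id"] if entry else model_id

Feature checks (prompt caching, cache-token limits, reasoning_effort, temperature/topP conflicts) then run against the resolved foundation model ID instead of the raw profile identifier.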
@@ -4,7 +4,6 @@ import logging
import re
import time
from abc import ABC
from collections import defaultdict
from typing import AsyncIterable, Iterable, Literal

import boto3
@@ -74,16 +73,6 @@ bedrock_client = boto3.client(
config=config,
)


def get_inference_region_prefix():
if AWS_REGION.startswith("ap-"):
return "apac"
return AWS_REGION[:2]


# https://docs.aws.amazon.com/bedrock/latest/userguide/inference-profiles-support.html
cr_inference_prefix = get_inference_region_prefix()

SUPPORTED_BEDROCK_EMBEDDING_MODELS = {
"cohere.embed-multilingual-v3": "Cohere Embed Multilingual",
"cohere.embed-english-v3": "Cohere Embed English",
@@ -95,6 +84,18 @@ SUPPORTED_BEDROCK_EMBEDDING_MODELS = {

ENCODER = tiktoken.get_encoding("cl100k_base")

# Global mapping: Profile ID/ARN → Foundation Model ID
# Handles both SYSTEM_DEFINED (cross-region) and APPLICATION profiles
# This enables feature detection for all profile types without pattern matching
profile_metadata = {}

# Models that don't support both temperature and topP simultaneously
# When both are provided, temperature takes precedence and topP is removed
TEMPERATURE_TOPP_CONFLICT_MODELS = {
"claude-sonnet-4-5",
"claude-haiku-4-5",
}


def list_bedrock_models() -> dict:
"""Automatically getting a list of supported models.
@@ -106,15 +107,26 @@ def list_bedrock_models() -> dict:
"""
model_list = {}
try:
profile_list = []
# Map foundation model_id -> set of application inference profile ARNs
app_profiles_by_model = defaultdict(set)

if ENABLE_CROSS_REGION_INFERENCE:
# List system defined inference profile IDs
# List system defined inference profile IDs and store underlying model mapping
paginator = bedrock_client.get_paginator('list_inference_profiles')
for page in paginator.paginate(maxResults=1000, typeEquals="SYSTEM_DEFINED"):
profile_list.extend([p["inferenceProfileId"] for p in page["inferenceProfileSummaries"]])
for profile in page["inferenceProfileSummaries"]:
profile_id = profile.get("inferenceProfileId")
if not profile_id:
continue

# Extract underlying model from first model in the profile
models = profile.get("models", [])
if models:
model_arn = models[0].get("modelArn", "")
if model_arn:
# Extract foundation model ID from ARN
model_id = model_arn.split('/')[-1]
profile_metadata[profile_id] = {
"underlying_model_id": model_id,
"profile_type": "SYSTEM_DEFINED",
}

if ENABLE_APPLICATION_INFERENCE_PROFILES:
# List application defined inference profile IDs and create mapping
@@ -125,15 +137,28 @@ def list_bedrock_models() -> dict:
profile_arn = profile.get("inferenceProfileArn")
if not profile_arn:
continue


# Process all models in the profile
models = profile.get("models", [])
for model in models:
model_arn = model.get("modelArn", "")
if model_arn:
model_id = model_arn.split('/')[-1] if '/' in model_arn else model_arn
if model_id:
app_profiles_by_model[model_id].add(profile_arn)
if not models:
logger.warning(f"Application profile {profile_arn} has no models")
continue

# Take first model - all models in array are same type (regional instances)
first_model = models[0]
model_arn = first_model.get("modelArn", "")
if not model_arn:
continue

# Extract model ID from ARN (works for both foundation models and cross-region profiles)
model_id = model_arn.split('/')[-1] if '/' in model_arn else model_arn

# Store in unified profile metadata for feature detection
profile_metadata[profile_arn] = {
"underlying_model_id": model_id,
"profile_type": "APPLICATION",
"profile_name": profile.get("inferenceProfileName", ""),
}
except Exception as e:
logger.warning(f"Error processing application profile: {e}")
continue
@@ -156,20 +181,10 @@ def list_bedrock_models() -> dict:
if "ON_DEMAND" in inference_types:
model_list[model_id] = {"modalities": input_modalities}

# Add cross-region inference model list.
profile_id = cr_inference_prefix + "." + model_id
if profile_id in profile_list:
model_list[profile_id] = {"modalities": input_modalities}

# Add global cross-region inference profiles
global_profile_id = "global." + model_id
if global_profile_id in profile_list:
model_list[global_profile_id] = {"modalities": input_modalities}

# Add application inference profiles (emit all profiles for this model)
if model_id in app_profiles_by_model:
for profile_arn in app_profiles_by_model[model_id]:
model_list[profile_arn] = {"modalities": input_modalities}
# Add all inference profiles (cross-region and application) for this model
for profile_id, metadata in profile_metadata.items():
if metadata.get("underlying_model_id") == model_id:
model_list[profile_id] = {"modalities": input_modalities}

except Exception as e:
logger.error(f"Unable to list models: {str(e)}")
@@ -197,17 +212,56 @@ class BedrockModel(BaseChatModel):
error = ""
# check if model is supported
if chat_request.model not in bedrock_model_list.keys():
error = f"Unsupported model {chat_request.model}, please use models API to get a list of supported models"
# Provide helpful error for application profiles
if "application-inference-profile" in chat_request.model:
error = (
f"Application profile {chat_request.model} not found. "
f"Available profiles can be listed via GET /models API. "
f"Ensure ENABLE_APPLICATION_INFERENCE_PROFILES=true and "
f"the profile exists in your AWS account."
)
else:
error = f"Unsupported model {chat_request.model}, please use models API to get a list of supported models"
logger.error("Unsupported model: %s", chat_request.model)

# Validate profile has resolvable underlying model
if not error and chat_request.model in profile_metadata:
resolved = self._resolve_to_foundation_model(chat_request.model)
if resolved == chat_request.model:
logger.warning(
f"Could not resolve profile {chat_request.model} "
f"to underlying model. Some features may not work correctly."
)

if error:
raise HTTPException(
status_code=400,
detail=error,
)

@staticmethod
def _supports_prompt_caching(model_id: str) -> bool:
def _resolve_to_foundation_model(self, model_id: str) -> str:
"""
Resolve any model identifier to foundation model ID for feature detection.

Handles:
- Cross-region profiles (us.*, eu.*, apac.*, global.*)
- Application profiles (arn:aws:bedrock:...:application-inference-profile/...)
- Foundation models (pass through unchanged)

No pattern matching needed - just dictionary lookup.
Unknown identifiers pass through unchanged (graceful fallback).

Args:
model_id: Can be foundation model ID, cross-region profile, or app profile ARN

Returns:
Foundation model ID if mapping exists, otherwise original model_id
"""
if model_id in profile_metadata:
return profile_metadata[model_id]["underlying_model_id"]
return model_id

def _supports_prompt_caching(self, model_id: str) -> bool:
"""
Check if model supports prompt caching based on model ID pattern.

@@ -221,10 +275,12 @@ class BedrockModel(BaseChatModel):
Returns:
bool: True if model supports prompt caching
"""
model_lower = model_id.lower()
# Resolve profile to underlying model for feature detection
resolved_model = self._resolve_to_foundation_model(model_id)
model_lower = resolved_model.lower()

# Claude models pattern matching
if "anthropic.claude" in model_lower or ".anthropic.claude" in model_lower:
if "anthropic.claude" in model_lower:
# Exclude very old models that don't support caching
excluded_patterns = ["claude-instant", "claude-v1", "claude-v2"]
if any(pattern in model_lower for pattern in excluded_patterns):
@@ -232,7 +288,7 @@ class BedrockModel(BaseChatModel):
return True

# Nova models pattern matching
if "amazon.nova" in model_lower or ".amazon.nova" in model_lower:
if "amazon.nova" in model_lower:
return True

# Future providers can be added here
@@ -240,8 +296,7 @@ class BedrockModel(BaseChatModel):

return False

@staticmethod
def _get_max_cache_tokens(model_id: str) -> int | None:
def _get_max_cache_tokens(self, model_id: str) -> int | None:
"""
Get maximum cacheable tokens limit for the model.

@@ -252,14 +307,16 @@ class BedrockModel(BaseChatModel):
Returns:
int | None: Max tokens, or None if unlimited
"""
model_lower = model_id.lower()
# Resolve profile to underlying model for feature detection
resolved_model = self._resolve_to_foundation_model(model_id)
model_lower = resolved_model.lower()

# Nova models have 20K limit
if "amazon.nova" in model_lower or ".amazon.nova" in model_lower:
if "amazon.nova" in model_lower:
return 20_000

# Claude: No explicit limit
if "anthropic.claude" in model_lower or ".anthropic.claude" in model_lower:
if "anthropic.claude" in model_lower:
return None

return None
@@ -269,6 +326,14 @@ class BedrockModel(BaseChatModel):
if DEBUG:
logger.info("Raw request: " + chat_request.model_dump_json())

# Log profile resolution for debugging
if chat_request.model in profile_metadata:
resolved = self._resolve_to_foundation_model(chat_request.model)
profile_type = profile_metadata[chat_request.model].get("profile_type", "UNKNOWN")
logger.info(
f"Profile resolution: {chat_request.model} ({profile_type}) → {resolved}"
)

# convert OpenAI chat request to Bedrock SDK request
args = self._parse_request(chat_request)
if DEBUG:
@@ -667,15 +732,27 @@ class BedrockModel(BaseChatModel):

# Base inference parameters.
inference_config = {
"temperature": chat_request.temperature,
"maxTokens": chat_request.max_tokens,
"topP": chat_request.top_p,
}

# Claude Sonnet 4.5 doesn't support both temperature and topP
# Remove topP for this model
if "claude-sonnet-4-5" in chat_request.model.lower():
inference_config.pop("topP", None)
# Only include optional parameters when specified
if chat_request.temperature is not None:
inference_config["temperature"] = chat_request.temperature
if chat_request.top_p is not None:
inference_config["topP"] = chat_request.top_p

# Some models (Claude Sonnet 4.5, Haiku 4.5) don't support both temperature and topP
# When both are provided, keep temperature and remove topP
# Resolve profile to underlying model for feature detection
resolved_model = self._resolve_to_foundation_model(chat_request.model)
model_lower = resolved_model.lower()

# Check if model is in the conflict list and both parameters are present
if "temperature" in inference_config and "topP" in inference_config:
if any(conflict_model in model_lower for conflict_model in TEMPERATURE_TOPP_CONFLICT_MODELS):
inference_config.pop("topP", None)
if DEBUG:
logger.info(f"Removed topP for {chat_request.model} (conflicts with temperature)")

if chat_request.stop is not None:
stop = chat_request.stop
@@ -692,9 +769,11 @@ class BedrockModel(BaseChatModel):
if chat_request.reasoning_effort:
# reasoning_effort is supported by Claude and DeepSeek v3
# Different models use different formats
model_lower = chat_request.model.lower()
# Resolve profile to underlying model for feature detection
resolved_model = self._resolve_to_foundation_model(chat_request.model)
model_lower = resolved_model.lower()

if "anthropic.claude" in model_lower or ".anthropic.claude" in model_lower:
if "anthropic.claude" in model_lower:
# Claude format: reasoning_config = object with budget_tokens
max_tokens = (
chat_request.max_completion_tokens

@@ -97,8 +97,8 @@ class ChatRequest(BaseModel):
presence_penalty: float | None = Field(default=0.0, le=2.0, ge=-2.0) # Not used
stream: bool | None = False
stream_options: StreamOptions | None = None
temperature: float | None = Field(default=1.0, le=2.0, ge=0.0)
top_p: float | None = Field(default=1.0, le=1.0, ge=0.0)
temperature: float | None = Field(default=None, le=2.0, ge=0.0)
top_p: float | None = Field(default=None, le=1.0, ge=0.0)
user: str | None = None # Not used
max_tokens: int | None = 2048
max_completion_tokens: int | None = None
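Together with the ChatRequest change above (temperature and top_p now default to None), the parameter handling behaves roughly as sketched below. This is an illustration, not part of the diff; the request values are made up, and the conflict branch assumes the resolved model matches an entry in TEMPERATURE_TOPP_CONFLICT_MODELS.

    # Hypothetical client-supplied values
    temperature, top_p, max_tokens = 0.7, 0.9, 2048

    inference_config = {"maxTokens": max_tokens}
    if temperature is not None:  # only forward parameters the client actually set
        inference_config["temperature"] = temperature
    if top_p is not None:
        inference_config["topP"] = top_p

    # On conflict models, temperature takes precedence and topP is dropped
    if "temperature" in inference_config and "topP" in inference_config:
        inference_config.pop("topP")
    # Result: {"maxTokens": 2048, "temperature": 0.7}

If the client sends neither parameter, nothing is forwarded and Bedrock's own defaults apply, instead of the previously hard-coded temperature=1.0 and top_p=1.0.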