models: fix Application Inference Profiles mapping (#175)

* models: fix the Application Inference Profile (AIP) mapping so that every profile for a given model_id is included; switch to defaultdict(set) and emit all AIPs

* Fix rebase issue

---------

Co-authored-by: Jeremy Brockett <313937+jbrockett@users.noreply.github.com>
This commit is contained in:
jbrockett
2025-08-14 03:21:14 -04:00
committed by GitHub
parent a2110ff648
commit 911dfe26d6

View File

@@ -4,6 +4,7 @@ import logging
import re import re
import time import time
from abc import ABC from abc import ABC
from collections import defaultdict
from typing import AsyncIterable, Iterable, Literal from typing import AsyncIterable, Iterable, Literal
import boto3 import boto3
@@ -103,7 +104,8 @@ def list_bedrock_models() -> dict:
model_list = {} model_list = {}
try: try:
profile_list = [] profile_list = []
app_profile_dict = {} # Map foundation model_id -> set of application inference profile ARNs
app_profiles_by_model = defaultdict(set)
if ENABLE_CROSS_REGION_INFERENCE: if ENABLE_CROSS_REGION_INFERENCE:
# List system defined inference profile IDs # List system defined inference profile IDs
@@ -128,7 +130,7 @@ def list_bedrock_models() -> dict:
if model_arn: if model_arn:
model_id = model_arn.split('/')[-1] if '/' in model_arn else model_arn model_id = model_arn.split('/')[-1] if '/' in model_arn else model_arn
if model_id: if model_id:
app_profile_dict[model_id] = profile_arn app_profiles_by_model[model_id].add(profile_arn)
except Exception as e: except Exception as e:
logger.warning(f"Error processing application profile: {e}") logger.warning(f"Error processing application profile: {e}")
continue continue
@@ -156,9 +158,10 @@ def list_bedrock_models() -> dict:
if profile_id in profile_list: if profile_id in profile_list:
model_list[profile_id] = {"modalities": input_modalities} model_list[profile_id] = {"modalities": input_modalities}
# Add application inference profiles # Add application inference profiles (emit all profiles for this model)
if model_id in app_profile_dict: if model_id in app_profiles_by_model:
model_list[app_profile_dict[model_id]] = {"modalities": input_modalities} for profile_arn in app_profiles_by_model[model_id]:
model_list[profile_arn] = {"modalities": input_modalities}
except Exception as e: except Exception as e:
logger.error(f"Unable to list models: {str(e)}") logger.error(f"Unable to list models: {str(e)}")