From 911dfe26d6b2c7bfaaef46fccfbac79776d07e81 Mon Sep 17 00:00:00 2001 From: jbrockett Date: Thu, 14 Aug 2025 03:21:14 -0400 Subject: [PATCH] models: fix Application Inference Profiles mapping (#175) * models: fix Application Inference Profiles mapping to include all profiles per model_id; switch to defaultdict(set) and emit all AIPs * Fix rebase issue --------- Co-authored-by: Jeremy Brockett <313937+jbrockett@users.noreply.github.com> --- src/api/models/bedrock.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/src/api/models/bedrock.py b/src/api/models/bedrock.py index bf08165..374fcd1 100644 --- a/src/api/models/bedrock.py +++ b/src/api/models/bedrock.py @@ -4,6 +4,7 @@ import logging import re import time from abc import ABC +from collections import defaultdict from typing import AsyncIterable, Iterable, Literal import boto3 @@ -103,7 +104,8 @@ def list_bedrock_models() -> dict: model_list = {} try: profile_list = [] - app_profile_dict = {} + # Map foundation model_id -> set of application inference profile ARNs + app_profiles_by_model = defaultdict(set) if ENABLE_CROSS_REGION_INFERENCE: # List system defined inference profile IDs @@ -128,7 +130,7 @@ def list_bedrock_models() -> dict: if model_arn: model_id = model_arn.split('/')[-1] if '/' in model_arn else model_arn if model_id: - app_profile_dict[model_id] = profile_arn + app_profiles_by_model[model_id].add(profile_arn) except Exception as e: logger.warning(f"Error processing application profile: {e}") continue @@ -156,9 +158,10 @@ def list_bedrock_models() -> dict: if profile_id in profile_list: model_list[profile_id] = {"modalities": input_modalities} - # Add application inference profiles - if model_id in app_profile_dict: - model_list[app_profile_dict[model_id]] = {"modalities": input_modalities} + # Add application inference profiles (emit all profiles for this model) + if model_id in app_profiles_by_model: + for profile_arn in app_profiles_by_model[model_id]: + model_list[profile_arn] = {"modalities": input_modalities} except Exception as e: logger.error(f"Unable to list models: {str(e)}")