models: fix Application Inference Profiles mapping (#175)
* models: fix Application Inference Profiles mapping to include all profiles per model_id; switch to defaultdict(set) and emit all AIPs * Fix rebase issue --------- Co-authored-by: Jeremy Brockett <313937+jbrockett@users.noreply.github.com>
This commit is contained in:
@@ -4,6 +4,7 @@ import logging
|
|||||||
import re
|
import re
|
||||||
import time
|
import time
|
||||||
from abc import ABC
|
from abc import ABC
|
||||||
|
from collections import defaultdict
|
||||||
from typing import AsyncIterable, Iterable, Literal
|
from typing import AsyncIterable, Iterable, Literal
|
||||||
|
|
||||||
import boto3
|
import boto3
|
||||||
@@ -103,7 +104,8 @@ def list_bedrock_models() -> dict:
|
|||||||
model_list = {}
|
model_list = {}
|
||||||
try:
|
try:
|
||||||
profile_list = []
|
profile_list = []
|
||||||
app_profile_dict = {}
|
# Map foundation model_id -> set of application inference profile ARNs
|
||||||
|
app_profiles_by_model = defaultdict(set)
|
||||||
|
|
||||||
if ENABLE_CROSS_REGION_INFERENCE:
|
if ENABLE_CROSS_REGION_INFERENCE:
|
||||||
# List system defined inference profile IDs
|
# List system defined inference profile IDs
|
||||||
@@ -128,7 +130,7 @@ def list_bedrock_models() -> dict:
|
|||||||
if model_arn:
|
if model_arn:
|
||||||
model_id = model_arn.split('/')[-1] if '/' in model_arn else model_arn
|
model_id = model_arn.split('/')[-1] if '/' in model_arn else model_arn
|
||||||
if model_id:
|
if model_id:
|
||||||
app_profile_dict[model_id] = profile_arn
|
app_profiles_by_model[model_id].add(profile_arn)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Error processing application profile: {e}")
|
logger.warning(f"Error processing application profile: {e}")
|
||||||
continue
|
continue
|
||||||
@@ -156,9 +158,10 @@ def list_bedrock_models() -> dict:
|
|||||||
if profile_id in profile_list:
|
if profile_id in profile_list:
|
||||||
model_list[profile_id] = {"modalities": input_modalities}
|
model_list[profile_id] = {"modalities": input_modalities}
|
||||||
|
|
||||||
# Add application inference profiles
|
# Add application inference profiles (emit all profiles for this model)
|
||||||
if model_id in app_profile_dict:
|
if model_id in app_profiles_by_model:
|
||||||
model_list[app_profile_dict[model_id]] = {"modalities": input_modalities}
|
for profile_arn in app_profiles_by_model[model_id]:
|
||||||
|
model_list[profile_arn] = {"modalities": input_modalities}
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Unable to list models: {str(e)}")
|
logger.error(f"Unable to list models: {str(e)}")
|
||||||
|
|||||||
Reference in New Issue
Block a user