diff --git a/src/api/models/bedrock.py b/src/api/models/bedrock.py index dff1c5b..ad68ce7 100644 --- a/src/api/models/bedrock.py +++ b/src/api/models/bedrock.py @@ -38,7 +38,7 @@ from api.schema import ( Embedding, ) -from api.setting import DEBUG, AWS_REGION, ENABLE_CROSS_REGION_INFERENCE +from api.setting import DEBUG, AWS_REGION, ENABLE_CROSS_REGION_INFERENCE, DEFAULT_MODEL logger = logging.getLogger(__name__) @@ -126,6 +126,12 @@ def list_bedrock_models() -> dict: except Exception as e: logger.error(f"Unable to list models: {str(e)}") + if not model_list: + # In case stack not updated. + model_list[DEFAULT_MODEL] = { + 'modalities': ["TEXT", "IMAGE"] + } + return model_list