From 268c5ef0f1ccee344d86461f2fe281ccfa79adbf Mon Sep 17 00:00:00 2001 From: Aiden Dai Date: Wed, 3 Apr 2024 11:20:31 +0800 Subject: [PATCH] Clean up code --- src/api/routers/model.py | 14 +++++++------- src/api/setting.py | 10 ++++++---- 2 files changed, 13 insertions(+), 11 deletions(-) diff --git a/src/api/routers/model.py b/src/api/routers/model.py index 985f902..4e7cbac 100644 --- a/src/api/routers/model.py +++ b/src/api/routers/model.py @@ -10,7 +10,6 @@ router = APIRouter() router = APIRouter( prefix="/models", - tags=["items"], dependencies=[Depends(api_key_auth)], # responses={404: {"description": "Not found"}}, ) @@ -18,12 +17,13 @@ router = APIRouter( async def validate_model_id(model_id: str): if model_id not in (SUPPORTED_BEDROCK_MODELS | SUPPORTED_BEDROCK_EMBEDDING_MODELS).keys(): - raise HTTPException(status_code=400, detail="Unsupported Model Id") + raise HTTPException(status_code=500, detail="Unsupported Model Id") @router.get("/", response_model=Models) async def list_models(): - model_list = [Model(id=model_id) for model_id in (SUPPORTED_BEDROCK_MODELS | SUPPORTED_BEDROCK_EMBEDDING_MODELS).keys()] + model_list = [Model(id=model_id) for model_id in + (SUPPORTED_BEDROCK_MODELS | SUPPORTED_BEDROCK_EMBEDDING_MODELS).keys()] return Models(data=model_list) @@ -32,10 +32,10 @@ async def list_models(): response_model=Model, ) async def get_model( - model_id: Annotated[ - str, - Path(description="Model ID", example="anthropic.claude-3-sonnet-20240229-v1:0"), - ] + model_id: Annotated[ + str, + Path(description="Model ID", example="anthropic.claude-3-sonnet-20240229-v1:0"), + ] ): await validate_model_id(model_id) return Model(id=model_id) diff --git a/src/api/setting.py b/src/api/setting.py index 6a14f22..5f42f2c 100644 --- a/src/api/setting.py +++ b/src/api/setting.py @@ -26,11 +26,13 @@ List of Amazon Bedrock models currently supported: # Embeddings - cohere.embed-multilingual-v3 - cohere.embed-english-v3 -- amazon.titan-embed-text-v1 -- amazon.titan-embed-image-v1 """ DEBUG = os.environ.get("DEBUG", "false").lower() != "false" AWS_REGION = os.environ.get("AWS_REGION", "us-west-2") -DEFAULT_MODEL = os.environ.get("DEFAULT_MODEL", "anthropic.claude-3-sonnet-20240229-v1:0") -DEFAULT_EMBEDDING_MODEL = os.environ.get("DEFAULT_EMBEDDING_MODEL", "cohere.embed-multilingual-v3") +DEFAULT_MODEL = os.environ.get( + "DEFAULT_MODEL", "anthropic.claude-3-sonnet-20240229-v1:0" +) +DEFAULT_EMBEDDING_MODEL = os.environ.get( + "DEFAULT_EMBEDDING_MODEL", "cohere.embed-multilingual-v3" +)