Automatically detect model list

This commit is contained in:
Aiden Dai
2024-12-16 16:01:19 +08:00
parent cb38d328aa
commit d4938a0af2
3 changed files with 108 additions and 332 deletions

View File

@@ -7,10 +7,10 @@ from abc import ABC
from typing import AsyncIterable, Iterable, Literal from typing import AsyncIterable, Iterable, Literal
import boto3 import boto3
from botocore.config import Config
import numpy as np import numpy as np
import requests import requests
import tiktoken import tiktoken
from botocore.config import Config
from fastapi import HTTPException from fastapi import HTTPException
from api.models.base import BaseChatModel, BaseEmbeddingsModel from api.models.base import BaseChatModel, BaseEmbeddingsModel
@@ -37,9 +37,8 @@ from api.schema import (
EmbeddingsUsage, EmbeddingsUsage,
Embedding, Embedding,
) )
from api.setting import DEBUG, AWS_REGION from api.setting import DEBUG, AWS_REGION, ENABLE_CROSS_REGION_INFERENCE
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -50,6 +49,21 @@ bedrock_runtime = boto3.client(
region_name=AWS_REGION, region_name=AWS_REGION,
config=config, config=config,
) )
bedrock_client = boto3.client(
service_name='bedrock',
region_name=AWS_REGION,
config=config,
)
def get_inference_region_prefix():
if AWS_REGION.startswith('ap-'):
return 'apac'
return AWS_REGION[:2]
# https://docs.aws.amazon.com/bedrock/latest/userguide/inference-profiles-support.html
cr_inference_prefix = get_inference_region_prefix()
SUPPORTED_BEDROCK_EMBEDDING_MODELS = { SUPPORTED_BEDROCK_EMBEDDING_MODELS = {
"cohere.embed-multilingual-v3": "Cohere Embed Multilingual", "cohere.embed-multilingual-v3": "Cohere Embed Multilingual",
@@ -62,296 +76,78 @@ SUPPORTED_BEDROCK_EMBEDDING_MODELS = {
ENCODER = tiktoken.get_encoding("cl100k_base") ENCODER = tiktoken.get_encoding("cl100k_base")
class BedrockModel(BaseChatModel): def list_bedrock_models() -> dict:
# https://docs.aws.amazon.com/bedrock/latest/userguide/conversation-inference.html#conversation-inference-supported-models-features """Automatically get a list of supported models.
_supported_models = {
"amazon.titan-text-premier-v1:0": { Returns a model list combines:
"system": True, - ON_DEMAND models.
"multimodal": False, - Cross-Region Inference Profiles (if enabled via Env)
"tool_call": False, """
"stream_tool_call": False, model_list = {}
}, try:
"anthropic.claude-instant-v1": { profile_list = []
"system": True, if ENABLE_CROSS_REGION_INFERENCE:
"multimodal": False, # List system defined inference profile IDs
"tool_call": False, response = bedrock_client.list_inference_profiles(
"stream_tool_call": False, maxResults=1000,
}, typeEquals='SYSTEM_DEFINED'
"anthropic.claude-v2:1": { )
"system": True, profile_list = [p['inferenceProfileId'] for p in response['inferenceProfileSummaries']]
"multimodal": False,
"tool_call": False, # List foundation models, only cares about text outputs here.
"stream_tool_call": False, response = bedrock_client.list_foundation_models(
}, byOutputModality='TEXT'
"anthropic.claude-v2": { )
"system": True,
"multimodal": False, for model in response['modelSummaries']:
"tool_call": False, model_id = model.get('modelId', 'N/A')
"stream_tool_call": False, stream_supported = model.get('responseStreamingSupported', True)
}, status = model['modelLifecycle'].get('status', 'ACTIVE')
"anthropic.claude-3-sonnet-20240229-v1:0": {
"system": True, # currently, use this to filter out rerank models and legacy models
"multimodal": True, if not stream_supported or status != "ACTIVE":
"tool_call": True, continue
"stream_tool_call": True,
}, inference_types = model.get('inferenceTypesSupported', [])
"anthropic.claude-3-opus-20240229-v1:0": { input_modalities = model['inputModalities']
"system": True, # Add on-demand model list
"multimodal": True, if 'ON_DEMAND' in inference_types:
"tool_call": True, model_list[model_id] = {
"stream_tool_call": True, 'modalities': input_modalities
},
"anthropic.claude-3-haiku-20240307-v1:0": {
"system": True,
"multimodal": True,
"tool_call": True,
"stream_tool_call": True,
},
"anthropic.claude-3-5-sonnet-20240620-v1:0": {
"system": True,
"multimodal": True,
"tool_call": True,
"stream_tool_call": True,
},
"anthropic.claude-3-5-sonnet-20241022-v2:0": {
"system": True,
"multimodal": True,
"tool_call": True,
"stream_tool_call": True,
},
"meta.llama2-13b-chat-v1": {
"system": True,
"multimodal": False,
"tool_call": False,
"stream_tool_call": False,
},
"meta.llama2-70b-chat-v1": {
"system": True,
"multimodal": False,
"tool_call": False,
"stream_tool_call": False,
},
"meta.llama3-8b-instruct-v1:0": {
"system": True,
"multimodal": False,
"tool_call": False,
"stream_tool_call": False,
},
"meta.llama3-70b-instruct-v1:0": {
"system": True,
"multimodal": False,
"tool_call": False,
"stream_tool_call": False,
},
# Llama 3.1 8b cross-region inference profile
"us.meta.llama3-1-8b-instruct-v1:0": {
"system": True,
"multimodal": False,
"tool_call": True,
"stream_tool_call": False,
},
"meta.llama3-1-8b-instruct-v1:0": {
"system": True,
"multimodal": False,
"tool_call": True,
"stream_tool_call": False,
},
# Llama 3.1 70b cross-region inference profile
"us.meta.llama3-1-70b-instruct-v1:0": {
"system": True,
"multimodal": False,
"tool_call": True,
"stream_tool_call": False,
},
"meta.llama3-1-70b-instruct-v1:0": {
"system": True,
"multimodal": False,
"tool_call": True,
"stream_tool_call": False,
},
"meta.llama3-1-405b-instruct-v1:0": {
"system": True,
"multimodal": False,
"tool_call": True,
"stream_tool_call": False,
},
# Llama 3.2 1B cross-region inference profile
"us.meta.llama3-2-1b-instruct-v1:0": {
"system": True,
"multimodal": False,
"tool_call": False,
"stream_tool_call": False,
},
# Llama 3.2 3B cross-region inference profile
"us.meta.llama3-2-3b-instruct-v1:0": {
"system": True,
"multimodal": False,
"tool_call": False,
"stream_tool_call": False,
},
# Llama 3.2 11B cross-region inference profile
"us.meta.llama3-2-11b-instruct-v1:0": {
"system": True,
"multimodal": True,
"tool_call": True,
"stream_tool_call": False,
},
# Llama 3.2 90B cross-region inference profile
"us.meta.llama3-2-90b-instruct-v1:0": {
"system": True,
"multimodal": True,
"tool_call": True,
"stream_tool_call": False,
},
"mistral.mistral-7b-instruct-v0:2": {
"system": False,
"multimodal": False,
"tool_call": False,
"stream_tool_call": False,
},
"mistral.mixtral-8x7b-instruct-v0:1": {
"system": False,
"multimodal": False,
"tool_call": False,
"stream_tool_call": False,
},
"mistral.mistral-small-2402-v1:0": {
"system": True,
"multimodal": False,
"tool_call": False,
"stream_tool_call": False,
},
"mistral.mistral-large-2402-v1:0": {
"system": True,
"multimodal": False,
"tool_call": True,
"stream_tool_call": False,
},
"mistral.mistral-large-2407-v1:0": {
"system": True,
"multimodal": False,
"tool_call": True,
"stream_tool_call": False,
},
"cohere.command-r-v1:0": {
"system": True,
"multimodal": False,
"tool_call": True,
"stream_tool_call": False,
},
"cohere.command-r-plus-v1:0": {
"system": True,
"multimodal": False,
"tool_call": True,
"stream_tool_call": False,
},
"apac.anthropic.claude-3-sonnet-20240229-v1:0": {
"system": True,
"multimodal": True,
"tool_call": True,
"stream_tool_call": True,
},
"apac.anthropic.claude-3-haiku-20240307-v1:0": {
"system": True,
"multimodal": True,
"tool_call": True,
"stream_tool_call": True,
},
"apac.anthropic.claude-3-5-sonnet-20240620-v1:0": {
"system": True,
"multimodal": True,
"tool_call": True,
"stream_tool_call": True,
},
# claude 3 Haiku cross-region inference profile
"us.anthropic.claude-3-haiku-20240307-v1:0": {
"system": True,
"multimodal": True,
"tool_call": True,
"stream_tool_call": True,
},
"eu.anthropic.claude-3-haiku-20240307-v1:0": {
"system": True,
"multimodal": True,
"tool_call": True,
"stream_tool_call": True,
},
# claude 3 Opus cross-region inference profile
"us.anthropic.claude-3-opus-20240229-v1:0": {
"system": True,
"multimodal": True,
"tool_call": True,
"stream_tool_call": True,
},
# claude 3 Sonnet cross-region inference profile
"us.anthropic.claude-3-sonnet-20240229-v1:0": {
"system": True,
"multimodal": True,
"tool_call": True,
"stream_tool_call": True,
},
"eu.anthropic.claude-3-sonnet-20240229-v1:0": {
"system": True,
"multimodal": True,
"tool_call": True,
"stream_tool_call": True,
},
# claude 3.5 Sonnet cross-region inference profile
"us.anthropic.claude-3-5-sonnet-20240620-v1:0": {
"system": True,
"multimodal": True,
"tool_call": True,
"stream_tool_call": True,
},
"eu.anthropic.claude-3-5-sonnet-20240620-v1:0": {
"system": True,
"multimodal": True,
"tool_call": True,
"stream_tool_call": True,
},
# claude 3.5 Sonnet v2 cross-region inference profile(Now only us-west-2)
"us.anthropic.claude-3-5-sonnet-20241022-v2:0": {
"system": True,
"multimodal": True,
"tool_call": True,
"stream_tool_call": True,
},
# Amazon Nova models - AWS's proprietary large language models
"us.amazon.nova-lite-v1:0": {
"system": True, # Supports system prompts for context setting
"multimodal": True, # Capable of processing both text and images
"tool_call": True,
"stream_tool_call": True,
},
"us.amazon.nova-micro-v1:0": {
"system": True, # Supports system prompts for context setting
"multimodal": False, # Text-only model, no image processing capabilities
"tool_call": True,
"stream_tool_call": True,
},
"us.amazon.nova-pro-v1:0": {
"system": True, # Supports system prompts for context setting
"multimodal": True, # Capable of processing both text and images
"tool_call": True,
"stream_tool_call": True,
},
} }
# Add cross-region inference model list.
profile_id = cr_inference_prefix + '.' + model_id
if profile_id in profile_list:
model_list[profile_id] = {
'modalities': input_modalities
}
except Exception as e:
logger.error(f"Unable to list models: {str(e)}")
return model_list
# Initialize the model list.
bedrock_model_list = list_bedrock_models()
class BedrockModel(BaseChatModel):
def list_models(self) -> list[str]: def list_models(self) -> list[str]:
return list(self._supported_models.keys()) """Always refresh the latest model list"""
global bedrock_model_list
bedrock_model_list = list_bedrock_models()
return list(bedrock_model_list.keys())
def validate(self, chat_request: ChatRequest): def validate(self, chat_request: ChatRequest):
"""Perform basic validation on requests""" """Perform basic validation on requests"""
error = "" error = ""
# check if model is supported # check if model is supported
if chat_request.model not in self._supported_models.keys(): if chat_request.model not in bedrock_model_list.keys():
error = f"Unsupported model {chat_request.model}, please use models API to get a list of supported models" error = f"Unsupported model {chat_request.model}, please use models API to get a list of supported models"
# check if tool call is supported
elif chat_request.tools and not self._is_tool_call_supported(chat_request.model, stream=chat_request.stream):
tool_call_info = "Tool call with streaming" if chat_request.stream else "Tool call"
error = f"{tool_call_info} is currently not supported by {chat_request.model}"
if error: if error:
raise HTTPException( raise HTTPException(
status_code=400, status_code=400,
@@ -529,7 +325,6 @@ class BedrockModel(BaseChatModel):
continue continue
return self._reframe_multi_payloard(messages) return self._reframe_multi_payloard(messages)
def _reframe_multi_payloard(self, messages: list) -> list: def _reframe_multi_payloard(self, messages: list) -> list:
""" Receive messages and reformat them to comply with the Claude format """ Receive messages and reformat them to comply with the Claude format
@@ -540,20 +335,19 @@ This method searches through the OpenAI format messages in order and reformats t
``` ```
openai_format_messages=[ openai_format_messages=[
{"role": "user", "content": "hogehoge"}, {"role": "user", "content": "Hello"},
{"role": "user", "content": "fugafuga"}, {"role": "user", "content": "Who are you?"},
] ]
bedrock_format_messages=[ bedrock_format_messages=[
{ {
"role": "user", "role": "user",
"content": [ "content": [
{"text": "hogehoge"}, {"text": "Hello"},
{"text": "fugafuga"} {"text": "Who are you?"}
] ]
}, },
] ]
```
""" """
reformatted_messages = [] reformatted_messages = []
current_role = None current_role = None
@@ -590,7 +384,6 @@ bedrock_format_messages=[
return reformatted_messages return reformatted_messages
def _parse_request(self, chat_request: ChatRequest) -> dict: def _parse_request(self, chat_request: ChatRequest) -> dict:
"""Create default converse request body. """Create default converse request body.
@@ -839,7 +632,7 @@ bedrock_format_messages=[
} }
) )
elif isinstance(part, ImageContent): elif isinstance(part, ImageContent):
if not self._is_multimodal_supported(model_id): if not self.is_supported_modality(model_id, modality="IMAGE"):
raise HTTPException( raise HTTPException(
status_code=400, status_code=400,
detail=f"Multimodal message is currently not supported by {model_id}", detail=f"Multimodal message is currently not supported by {model_id}",
@@ -858,23 +651,13 @@ bedrock_format_messages=[
continue continue
return content_parts return content_parts
def _is_tool_call_supported(self, model_id: str, stream: bool = False) -> bool: @staticmethod
feature = self._supported_models.get(model_id) def is_supported_modality(model_id: str, modality: str = "IMAGE") -> bool:
if not feature: model = bedrock_model_list.get(model_id)
modalities = model.get('modalities', [])
if modality in modalities:
return True
return False return False
return feature["stream_tool_call"] if stream else feature["tool_call"]
def _is_multimodal_supported(self, model_id: str) -> bool:
feature = self._supported_models.get(model_id)
if not feature:
return False
return feature["multimodal"]
def _is_system_prompt_supported(self, model_id: str) -> bool:
feature = self._supported_models.get(model_id)
if not feature:
return False
return feature["system"]
def _convert_tool_spec(self, func: Function) -> dict: def _convert_tool_spec(self, func: Function) -> dict:
return { return {

View File

@@ -9,13 +9,6 @@ SUMMARY = "OpenAI-Compatible RESTful APIs for Amazon Bedrock"
VERSION = "0.1.0" VERSION = "0.1.0"
DESCRIPTION = """ DESCRIPTION = """
Use OpenAI-Compatible RESTful APIs for Amazon Bedrock models. Use OpenAI-Compatible RESTful APIs for Amazon Bedrock models.
List of Amazon Bedrock models currently supported:
- Anthropic Claude 2 / 3 /3.5 (Haiku/Sonnet/Opus)
- Meta Llama 2 / 3
- Mistral / Mixtral
- Cohere Command R / R+
- Cohere Embedding
""" """
DEBUG = os.environ.get("DEBUG", "false").lower() != "false" DEBUG = os.environ.get("DEBUG", "false").lower() != "false"
@@ -26,3 +19,4 @@ DEFAULT_MODEL = os.environ.get(
DEFAULT_EMBEDDING_MODEL = os.environ.get( DEFAULT_EMBEDDING_MODEL = os.environ.get(
"DEFAULT_EMBEDDING_MODEL", "cohere.embed-multilingual-v3" "DEFAULT_EMBEDDING_MODEL", "cohere.embed-multilingual-v3"
) )
ENABLE_CROSS_REGION_INFERENCE = os.environ.get("ENABLE_CROSS_REGION_INFERENCE", "true").lower() != "false"

View File

@@ -1,10 +1,9 @@
fastapi==0.111.0 fastapi==0.115.6
pydantic==2.7.1 pydantic==2.7.1
uvicorn==0.29.0 uvicorn==0.29.0
mangum==0.17.0 mangum==0.17.0
tiktoken==0.6.0 tiktoken==0.6.0
requests==2.32.3 requests==2.32.3
numpy==1.26.4 numpy==1.26.4
boto3==1.35.49 boto3==1.35.81
botocore==1.35.49 botocore==1.35.81