Refactor to use new Converse API

Aiden Dai
2024-06-04 14:59:40 +08:00
parent 86e3db7e09
commit 696039053d
9 changed files with 497 additions and 679 deletions
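
For context: the Bedrock Converse API gives every supported model family a single request and response shape, which is what lets this commit delete the per-model classes below. A minimal sketch of a non-streaming call, assuming configured AWS credentials and a Converse-capable model (region and model ID here are illustrative):

import boto3

client = boto3.client("bedrock-runtime", region_name="us-east-1")

response = client.converse(
    modelId="anthropic.claude-3-sonnet-20240229-v1:0",
    system=[{"text": "You are a helpful assistant."}],
    messages=[{"role": "user", "content": [{"text": "Hello!"}]}],
    inferenceConfig={"maxTokens": 512, "temperature": 0.7},
)
# The response shape is the same regardless of the underlying model family.
print(response["output"]["message"]["content"][0]["text"])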

View File

@@ -1,7 +0,0 @@
-from api.models.bedrock import (
-    ClaudeModel,
-    SUPPORTED_BEDROCK_MODELS,
-    SUPPORTED_BEDROCK_EMBEDDING_MODELS,
-    get_model,
-    get_embeddings_model,
-)

View File

@@ -1,3 +1,4 @@
+import time
 import uuid
 from abc import ABC, abstractmethod
 from typing import AsyncIterable
@@ -19,6 +20,14 @@ class BaseChatModel(ABC):
     Currently, only Bedrock model is supported, but may be used for SageMaker models if needed.
     """
 
+    def list_models(self) -> list[str]:
+        """Return a list of supported models"""
+        return []
+
+    def validate(self, chat_request: ChatRequest):
+        """Validate chat completion requests."""
+        pass
+
     @abstractmethod
     def chat(self, chat_request: ChatRequest) -> ChatResponse:
         """Handle a basic chat completion requests."""
@@ -38,7 +47,11 @@ class BaseChatModel(ABC):
         response: ChatStreamResponse | None = None
     ) -> bytes:
         if response:
-            return "data: {}\n\n".format(response.model_dump_json()).encode("utf-8")
+            # to populate other fields when using exclude_unset=True
+            response.system_fingerprint = "fp"
+            response.object = "chat.completion.chunk"
+            response.created = int(time.time())
+            return "data: {}\n\n".format(response.model_dump_json(exclude_unset=True)).encode("utf-8")
         return "data: [DONE]\n\n".encode("utf-8")

File diff suppressed because it is too large
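
The suppressed diff is the core of the refactor: it is most likely the new api/models/bedrock.py (the module the routers below now import from), which funnels every chat model through the converse and converse_stream operations instead of the per-provider request payloads the deleted classes built. A hedged sketch of the streaming half, assuming a bedrock-runtime client; only the text-delta event is handled, while messageStart, messageStop, metadata, and tool-use events are omitted:

import boto3

client = boto3.client("bedrock-runtime")

def stream_text(model_id: str, messages: list[dict]):
    """Yield text deltas from a Converse streaming response."""
    response = client.converse_stream(modelId=model_id, messages=messages)
    for event in response["stream"]:
        # Each event is a dict keyed by its event type.
        if "contentBlockDelta" in event:
            yield event["contentBlockDelta"]["delta"].get("text", "")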

View File

@@ -4,7 +4,7 @@ from fastapi import APIRouter, Depends, Body
 from fastapi.responses import StreamingResponse
 
 from api.auth import api_key_auth
-from api.models import get_model
+from api.models.bedrock import BedrockModel
 from api.schema import ChatRequest, ChatResponse, ChatStreamResponse
 from api.setting import DEFAULT_MODEL
@@ -36,7 +36,8 @@ async def chat_completions(
         chat_request.model = DEFAULT_MODEL
 
     # Exception will be raised if model not supported.
-    model = get_model(chat_request.model)
+    model = BedrockModel()
+    model.validate(chat_request)
     if chat_request.stream:
         return StreamingResponse(
             content=model.chat_stream(chat_request), media_type="text/event-stream"
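
With a single BedrockModel class, per-request model checks move into validate() rather than a get_model() factory, and the wire behavior stays OpenAI-compatible. For example, exercising the streaming path through the OpenAI SDK (base URL, key, and path are deployment-specific placeholders):

from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/api/v1", api_key="bedrock")

stream = client.chat.completions.create(
    model="anthropic.claude-3-haiku-20240307-v1:0",
    messages=[{"role": "user", "content": "Say hi"}],
    stream=True,
)
for chunk in stream:
    print(chunk.choices[0].delta.content or "", end="")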

View File

@@ -3,7 +3,7 @@ from typing import Annotated
 from fastapi import APIRouter, Depends, Body
 
 from api.auth import api_key_auth
-from api.models import get_embeddings_model
+from api.models.bedrock import get_embeddings_model
 from api.schema import EmbeddingsRequest, EmbeddingsResponse
 from api.setting import DEFAULT_EMBEDDING_MODEL

View File

@@ -3,7 +3,7 @@ from typing import Annotated
 from fastapi import APIRouter, Depends, HTTPException, Path
 
 from api.auth import api_key_auth
-from api.models import SUPPORTED_BEDROCK_MODELS, SUPPORTED_BEDROCK_EMBEDDING_MODELS
+from api.models.bedrock import BedrockModel
 from api.schema import Models, Model
 
 router = APIRouter(
@@ -12,16 +12,19 @@ router = APIRouter(
     # responses={404: {"description": "Not found"}},
 )
 
+chat_model = BedrockModel()
+
 
 async def validate_model_id(model_id: str):
-    if model_id not in (SUPPORTED_BEDROCK_MODELS | SUPPORTED_BEDROCK_EMBEDDING_MODELS).keys():
+    if model_id not in chat_model.list_models():
         raise HTTPException(status_code=500, detail="Unsupported Model Id")
 
 
 @router.get("", response_model=Models)
 async def list_models():
-    model_list = [Model(id=model_id) for model_id in
-                  (SUPPORTED_BEDROCK_MODELS | SUPPORTED_BEDROCK_EMBEDDING_MODELS).keys()]
+    model_list = [
+        Model(id=model_id) for model_id in chat_model.list_models()
+    ]
     return Models(data=model_list)
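
The /models payload is now built from chat_model.list_models() instead of the two hard-coded dicts. A quick way to inspect what a running gateway reports (endpoint and key are placeholders):

import requests

resp = requests.get(
    "http://localhost:8000/api/v1/models",
    headers={"Authorization": "Bearer bedrock"},
)
resp.raise_for_status()
for model in resp.json()["data"]:
    print(model["id"])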

View File

@@ -1,5 +1,4 @@
 import time
-import uuid
 from typing import Literal, Iterable
 
 from pydantic import BaseModel, Field
@@ -18,12 +17,12 @@ class Models(BaseModel):
 class ResponseFunction(BaseModel):
-    name: str
+    name: str | None = None
     arguments: str
 
 
 class ToolCall(BaseModel):
-    id: str = Field(default_factory=lambda: str(uuid.uuid4())[:8])
+    id: str | None = None
     type: Literal["function"] = "function"
     function: ResponseFunction
@@ -113,8 +112,8 @@ class ChatResponseMessage(BaseModel):
 class BaseChoice(BaseModel):
-    index: int
-    finish_reason: str | None
+    index: int | None = 0
+    finish_reason: str | None = None
     logprobs: dict | None = None
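
Relaxing these schema fields to optional (tool-call id and function name, defaulted choice index and finish_reason) matches OpenAI-style streaming, where the first tool-call chunk carries the id and name and later chunks carry only argument fragments. Under the revised schema both shapes validate; the classes below are copied from the diff, the values are illustrative:

from typing import Literal
from pydantic import BaseModel

class ResponseFunction(BaseModel):
    name: str | None = None
    arguments: str

class ToolCall(BaseModel):
    id: str | None = None
    type: Literal["function"] = "function"
    function: ResponseFunction

first = ToolCall(id="call_0", function=ResponseFunction(name="get_weather", arguments=""))
delta = ToolCall(function=ResponseFunction(arguments='{"city": "Par'))  # delta-only chunk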

View File

@@ -11,27 +11,11 @@ DESCRIPTION = """
 Use OpenAI-Compatible RESTful APIs for Amazon Bedrock models.
 
 List of Amazon Bedrock models currently supported:
-
-# Chat
-- anthropic.claude-instant-v1
-- anthropic.claude-v2:1
-- anthropic.claude-v2
-- anthropic.claude-3-opus-20240229-v1:0
-- anthropic.claude-3-sonnet-20240229-v1:0
-- anthropic.claude-3-haiku-20240307-v1:0
-- meta.llama2-13b-chat-v1
-- meta.llama2-70b-chat-v1
-- meta.llama3-8b-instruct-v1:0
-- meta.llama3-70b-instruct-v1:0
-- mistral.mistral-7b-instruct-v0:2
-- mistral.mixtral-8x7b-instruct-v0:1
-- mistral.mistral-large-2402-v1:0
-- cohere.command-r-v1:0
-- cohere.command-r-plus-v1:0
-
-# Embeddings
-- cohere.embed-multilingual-v3
-- cohere.embed-english-v3
+- Anthropic Claude 2 / 3 (Haiku/Sonnet/Opus)
+- Meta Llama 2 / 3
+- Mistral / Mixtral
+- Cohere Command R / R+
+- Cohere Embedding
 """
 
 DEBUG = os.environ.get("DEBUG", "false").lower() != "false"

View File

@@ -1,7 +1,9 @@
-fastapi==0.110.2
+fastapi==0.111.0
 pydantic==2.7.1
 uvicorn==0.29.0
 mangum==0.17.0
 tiktoken==0.6.0
-requests==2.32.0
+requests==2.32.3
 numpy==1.26.4
+boto3==1.34.117
+botocore==1.34.117
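
boto3/botocore are pinned at 1.34.117 because the Converse operations only exist in recent releases. A quick sanity check against the installed SDK (region is a placeholder; no credentials are needed just to construct the client):

import boto3

client = boto3.client("bedrock-runtime", region_name="us-east-1")
print(boto3.__version__)
print(hasattr(client, "converse"), hasattr(client, "converse_stream"))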