Refactor to use new Converse API

This commit is contained in:
Aiden Dai
2024-06-04 14:59:40 +08:00
parent 86e3db7e09
commit 696039053d
9 changed files with 497 additions and 679 deletions

View File

@@ -1,7 +0,0 @@
from api.models.bedrock import (
ClaudeModel,
SUPPORTED_BEDROCK_MODELS,
SUPPORTED_BEDROCK_EMBEDDING_MODELS,
get_model,
get_embeddings_model,
)

View File

@@ -1,3 +1,4 @@
import time
import uuid
from abc import ABC, abstractmethod
from typing import AsyncIterable
@@ -19,6 +20,14 @@ class BaseChatModel(ABC):
Currently, only Bedrock model is supported, but may be used for SageMaker models if needed.
"""
def list_models(self) -> list[str]:
"""Return a list of supported model ids.

The base implementation reports no models and returns an empty list.
"""
return []
def validate(self, chat_request: ChatRequest):
"""Validate a chat completion request.

The base implementation performs no checks and accepts every request.
"""
# NOTE(review): presumably concrete models signal an invalid request by
# raising; confirm against the Bedrock implementation.
pass
# Abstract: every concrete chat model must provide its own implementation.
@abstractmethod
def chat(self, chat_request: ChatRequest) -> ChatResponse:
"""Handle a basic chat completion request."""
@@ -38,7 +47,11 @@ class BaseChatModel(ABC):
response: ChatStreamResponse | None = None
) -> bytes:
if response:
return "data: {}\n\n".format(response.model_dump_json()).encode("utf-8")
# to populate other fields when using exclude_unset=True
response.system_fingerprint = "fp"
response.object = "chat.completion.chunk"
response.created = int(time.time())
return "data: {}\n\n".format(response.model_dump_json(exclude_unset=True)).encode("utf-8")
return "data: [DONE]\n\n".encode("utf-8")

File diff suppressed because it is too large Load Diff

View File

@@ -4,7 +4,7 @@ from fastapi import APIRouter, Depends, Body
from fastapi.responses import StreamingResponse
from api.auth import api_key_auth
from api.models import get_model
from api.models.bedrock import BedrockModel
from api.schema import ChatRequest, ChatResponse, ChatStreamResponse
from api.setting import DEFAULT_MODEL
@@ -34,9 +34,10 @@ async def chat_completions(
):
if chat_request.model.lower().startswith("gpt-"):
chat_request.model = DEFAULT_MODEL
# Exception will be raised if model not supported.
model = get_model(chat_request.model)
model = BedrockModel()
model.validate(chat_request)
if chat_request.stream:
return StreamingResponse(
content=model.chat_stream(chat_request), media_type="text/event-stream"

View File

@@ -3,7 +3,7 @@ from typing import Annotated
from fastapi import APIRouter, Depends, Body
from api.auth import api_key_auth
from api.models import get_embeddings_model
from api.models.bedrock import get_embeddings_model
from api.schema import EmbeddingsRequest, EmbeddingsResponse
from api.setting import DEFAULT_EMBEDDING_MODEL

View File

@@ -3,7 +3,7 @@ from typing import Annotated
from fastapi import APIRouter, Depends, HTTPException, Path
from api.auth import api_key_auth
from api.models import SUPPORTED_BEDROCK_MODELS, SUPPORTED_BEDROCK_EMBEDDING_MODELS
from api.models.bedrock import BedrockModel
from api.schema import Models, Model
router = APIRouter(
@@ -12,16 +12,19 @@ router = APIRouter(
# responses={404: {"description": "Not found"}},
)
chat_model = BedrockModel()
async def validate_model_id(model_id: str):
if model_id not in (SUPPORTED_BEDROCK_MODELS | SUPPORTED_BEDROCK_EMBEDDING_MODELS).keys():
if model_id not in chat_model.list_models():
raise HTTPException(status_code=500, detail="Unsupported Model Id")
@router.get("", response_model=Models)
async def list_models():
model_list = [Model(id=model_id) for model_id in
(SUPPORTED_BEDROCK_MODELS | SUPPORTED_BEDROCK_EMBEDDING_MODELS).keys()]
model_list = [
Model(id=model_id) for model_id in chat_model.list_models()
]
return Models(data=model_list)

View File

@@ -1,5 +1,4 @@
import time
import uuid
from typing import Literal, Iterable
from pydantic import BaseModel, Field
@@ -18,12 +17,12 @@ class Models(BaseModel):
class ResponseFunction(BaseModel):
name: str
name: str | None = None
arguments: str
class ToolCall(BaseModel):
id: str = Field(default_factory=lambda: str(uuid.uuid4())[:8])
id: str | None = None
type: Literal["function"] = "function"
function: ResponseFunction
@@ -113,8 +112,8 @@ class ChatResponseMessage(BaseModel):
class BaseChoice(BaseModel):
index: int
finish_reason: str | None
index: int | None = 0
finish_reason: str | None = None
logprobs: dict | None = None

View File

@@ -11,27 +11,11 @@ DESCRIPTION = """
Use OpenAI-Compatible RESTful APIs for Amazon Bedrock models.
List of Amazon Bedrock models currently supported:
# Chat
- anthropic.claude-instant-v1
- anthropic.claude-v2:1
- anthropic.claude-v2
- anthropic.claude-3-opus-20240229-v1:0
- anthropic.claude-3-sonnet-20240229-v1:0
- anthropic.claude-3-haiku-20240307-v1:0
- meta.llama2-13b-chat-v1
- meta.llama2-70b-chat-v1
- meta.llama3-8b-instruct-v1:0
- meta.llama3-70b-instruct-v1:0
- mistral.mistral-7b-instruct-v0:2
- mistral.mixtral-8x7b-instruct-v0:1
- mistral.mistral-large-2402-v1:0
- cohere.command-r-v1:0
- cohere.command-r-plus-v1:0
# Embeddings
- cohere.embed-multilingual-v3
- cohere.embed-english-v3
- Anthropic Claude 2 / 3 (Haiku/Sonnet/Opus)
- Meta Llama 2 / 3
- Mistral / Mixtral
- Cohere Command R / R+
- Cohere Embedding
"""
DEBUG = os.environ.get("DEBUG", "false").lower() != "false"