Refactor to use new Converse API
This commit is contained in:
@@ -1,7 +0,0 @@
|
||||
from api.models.bedrock import (
|
||||
ClaudeModel,
|
||||
SUPPORTED_BEDROCK_MODELS,
|
||||
SUPPORTED_BEDROCK_EMBEDDING_MODELS,
|
||||
get_model,
|
||||
get_embeddings_model,
|
||||
)
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import time
|
||||
import uuid
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import AsyncIterable
|
||||
@@ -19,6 +20,14 @@ class BaseChatModel(ABC):
|
||||
Currently, only Bedrock models are supported, but this class may be used for SageMaker models if needed.
|
||||
"""
|
||||
|
||||
def list_models(self) -> list[str]:
|
||||
"""Return a list of supported models"""
|
||||
return []
|
||||
|
||||
def validate(self, chat_request: ChatRequest):
|
||||
"""Validate chat completion requests."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def chat(self, chat_request: ChatRequest) -> ChatResponse:
|
||||
"""Handle a basic chat completion request."""
|
||||
@@ -38,7 +47,11 @@ class BaseChatModel(ABC):
|
||||
response: ChatStreamResponse | None = None
|
||||
) -> bytes:
|
||||
if response:
|
||||
return "data: {}\n\n".format(response.model_dump_json()).encode("utf-8")
|
||||
# to populate other fields when using exclude_unset=True
|
||||
response.system_fingerprint = "fp"
|
||||
response.object = "chat.completion.chunk"
|
||||
response.created = int(time.time())
|
||||
return "data: {}\n\n".format(response.model_dump_json(exclude_unset=True)).encode("utf-8")
|
||||
return "data: [DONE]\n\n".encode("utf-8")
|
||||
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -4,7 +4,7 @@ from fastapi import APIRouter, Depends, Body
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
||||
from api.auth import api_key_auth
|
||||
from api.models import get_model
|
||||
from api.models.bedrock import BedrockModel
|
||||
from api.schema import ChatRequest, ChatResponse, ChatStreamResponse
|
||||
from api.setting import DEFAULT_MODEL
|
||||
|
||||
@@ -36,7 +36,8 @@ async def chat_completions(
|
||||
chat_request.model = DEFAULT_MODEL
|
||||
|
||||
# An exception will be raised if the model is not supported.
|
||||
model = get_model(chat_request.model)
|
||||
model = BedrockModel()
|
||||
model.validate(chat_request)
|
||||
if chat_request.stream:
|
||||
return StreamingResponse(
|
||||
content=model.chat_stream(chat_request), media_type="text/event-stream"
|
||||
|
||||
@@ -3,7 +3,7 @@ from typing import Annotated
|
||||
from fastapi import APIRouter, Depends, Body
|
||||
|
||||
from api.auth import api_key_auth
|
||||
from api.models import get_embeddings_model
|
||||
from api.models.bedrock import get_embeddings_model
|
||||
from api.schema import EmbeddingsRequest, EmbeddingsResponse
|
||||
from api.setting import DEFAULT_EMBEDDING_MODEL
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@ from typing import Annotated
|
||||
from fastapi import APIRouter, Depends, HTTPException, Path
|
||||
|
||||
from api.auth import api_key_auth
|
||||
from api.models import SUPPORTED_BEDROCK_MODELS, SUPPORTED_BEDROCK_EMBEDDING_MODELS
|
||||
from api.models.bedrock import BedrockModel
|
||||
from api.schema import Models, Model
|
||||
|
||||
router = APIRouter(
|
||||
@@ -12,16 +12,19 @@ router = APIRouter(
|
||||
# responses={404: {"description": "Not found"}},
|
||||
)
|
||||
|
||||
chat_model = BedrockModel()
|
||||
|
||||
|
||||
async def validate_model_id(model_id: str):
|
||||
if model_id not in (SUPPORTED_BEDROCK_MODELS | SUPPORTED_BEDROCK_EMBEDDING_MODELS).keys():
|
||||
if model_id not in chat_model.list_models():
|
||||
raise HTTPException(status_code=500, detail="Unsupported Model Id")
|
||||
|
||||
|
||||
@router.get("", response_model=Models)
|
||||
async def list_models():
|
||||
model_list = [Model(id=model_id) for model_id in
|
||||
(SUPPORTED_BEDROCK_MODELS | SUPPORTED_BEDROCK_EMBEDDING_MODELS).keys()]
|
||||
model_list = [
|
||||
Model(id=model_id) for model_id in chat_model.list_models()
|
||||
]
|
||||
return Models(data=model_list)
|
||||
|
||||
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import time
|
||||
import uuid
|
||||
from typing import Literal, Iterable
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
@@ -18,12 +17,12 @@ class Models(BaseModel):
|
||||
|
||||
|
||||
class ResponseFunction(BaseModel):
|
||||
name: str
|
||||
name: str | None = None
|
||||
arguments: str
|
||||
|
||||
|
||||
class ToolCall(BaseModel):
|
||||
id: str = Field(default_factory=lambda: str(uuid.uuid4())[:8])
|
||||
id: str | None = None
|
||||
type: Literal["function"] = "function"
|
||||
function: ResponseFunction
|
||||
|
||||
@@ -113,8 +112,8 @@ class ChatResponseMessage(BaseModel):
|
||||
|
||||
|
||||
class BaseChoice(BaseModel):
|
||||
index: int
|
||||
finish_reason: str | None
|
||||
index: int | None = 0
|
||||
finish_reason: str | None = None
|
||||
logprobs: dict | None = None
|
||||
|
||||
|
||||
|
||||
@@ -11,27 +11,11 @@ DESCRIPTION = """
|
||||
Use OpenAI-Compatible RESTful APIs for Amazon Bedrock models.
|
||||
|
||||
List of Amazon Bedrock models currently supported:
|
||||
|
||||
# Chat
|
||||
- anthropic.claude-instant-v1
|
||||
- anthropic.claude-v2:1
|
||||
- anthropic.claude-v2
|
||||
- anthropic.claude-3-opus-20240229-v1:0
|
||||
- anthropic.claude-3-sonnet-20240229-v1:0
|
||||
- anthropic.claude-3-haiku-20240307-v1:0
|
||||
- meta.llama2-13b-chat-v1
|
||||
- meta.llama2-70b-chat-v1
|
||||
- meta.llama3-8b-instruct-v1:0
|
||||
- meta.llama3-70b-instruct-v1:0
|
||||
- mistral.mistral-7b-instruct-v0:2
|
||||
- mistral.mixtral-8x7b-instruct-v0:1
|
||||
- mistral.mistral-large-2402-v1:0
|
||||
- cohere.command-r-v1:0
|
||||
- cohere.command-r-plus-v1:0
|
||||
|
||||
# Embeddings
|
||||
- cohere.embed-multilingual-v3
|
||||
- cohere.embed-english-v3
|
||||
- Anthropic Claude 2 / 3 (Haiku/Sonnet/Opus)
|
||||
- Meta Llama 2 / 3
|
||||
- Mistral / Mixtral
|
||||
- Cohere Command R / R+
|
||||
- Cohere Embedding
|
||||
"""
|
||||
|
||||
DEBUG = os.environ.get("DEBUG", "false").lower() != "false"
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
fastapi==0.110.2
|
||||
fastapi==0.111.0
|
||||
pydantic==2.7.1
|
||||
uvicorn==0.29.0
|
||||
mangum==0.17.0
|
||||
tiktoken==0.6.0
|
||||
requests==2.32.0
|
||||
requests==2.32.3
|
||||
numpy==1.26.4
|
||||
boto3==1.34.117
|
||||
botocore==1.34.117
|
||||
Reference in New Issue
Block a user