Refactor to use new Converse API
This commit is contained in:
@@ -1,7 +0,0 @@
|
|||||||
from api.models.bedrock import (
|
|
||||||
ClaudeModel,
|
|
||||||
SUPPORTED_BEDROCK_MODELS,
|
|
||||||
SUPPORTED_BEDROCK_EMBEDDING_MODELS,
|
|
||||||
get_model,
|
|
||||||
get_embeddings_model,
|
|
||||||
)
|
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
|
import time
|
||||||
import uuid
|
import uuid
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import AsyncIterable
|
from typing import AsyncIterable
|
||||||
@@ -19,6 +20,14 @@ class BaseChatModel(ABC):
|
|||||||
Currently, only Bedrock models are supported, but this class may be used for SageMaker models if needed.
|
Currently, only Bedrock models are supported, but this class may be used for SageMaker models if needed.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
def list_models(self) -> list[str]:
|
||||||
|
"""Return a list of supported models"""
|
||||||
|
return []
|
||||||
|
|
||||||
|
def validate(self, chat_request: ChatRequest):
|
||||||
|
"""Validate chat completion requests."""
|
||||||
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def chat(self, chat_request: ChatRequest) -> ChatResponse:
|
def chat(self, chat_request: ChatRequest) -> ChatResponse:
|
||||||
"""Handle a basic chat completion request."""
|
"""Handle a basic chat completion request."""
|
||||||
@@ -38,7 +47,11 @@ class BaseChatModel(ABC):
|
|||||||
response: ChatStreamResponse | None = None
|
response: ChatStreamResponse | None = None
|
||||||
) -> bytes:
|
) -> bytes:
|
||||||
if response:
|
if response:
|
||||||
return "data: {}\n\n".format(response.model_dump_json()).encode("utf-8")
|
# to populate other fields when using exclude_unset=True
|
||||||
|
response.system_fingerprint = "fp"
|
||||||
|
response.object = "chat.completion.chunk"
|
||||||
|
response.created = int(time.time())
|
||||||
|
return "data: {}\n\n".format(response.model_dump_json(exclude_unset=True)).encode("utf-8")
|
||||||
return "data: [DONE]\n\n".encode("utf-8")
|
return "data: [DONE]\n\n".encode("utf-8")
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -4,7 +4,7 @@ from fastapi import APIRouter, Depends, Body
|
|||||||
from fastapi.responses import StreamingResponse
|
from fastapi.responses import StreamingResponse
|
||||||
|
|
||||||
from api.auth import api_key_auth
|
from api.auth import api_key_auth
|
||||||
from api.models import get_model
|
from api.models.bedrock import BedrockModel
|
||||||
from api.schema import ChatRequest, ChatResponse, ChatStreamResponse
|
from api.schema import ChatRequest, ChatResponse, ChatStreamResponse
|
||||||
from api.setting import DEFAULT_MODEL
|
from api.setting import DEFAULT_MODEL
|
||||||
|
|
||||||
@@ -36,7 +36,8 @@ async def chat_completions(
|
|||||||
chat_request.model = DEFAULT_MODEL
|
chat_request.model = DEFAULT_MODEL
|
||||||
|
|
||||||
# Exception will be raised if model not supported.
|
# Exception will be raised if model not supported.
|
||||||
model = get_model(chat_request.model)
|
model = BedrockModel()
|
||||||
|
model.validate(chat_request)
|
||||||
if chat_request.stream:
|
if chat_request.stream:
|
||||||
return StreamingResponse(
|
return StreamingResponse(
|
||||||
content=model.chat_stream(chat_request), media_type="text/event-stream"
|
content=model.chat_stream(chat_request), media_type="text/event-stream"
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ from typing import Annotated
|
|||||||
from fastapi import APIRouter, Depends, Body
|
from fastapi import APIRouter, Depends, Body
|
||||||
|
|
||||||
from api.auth import api_key_auth
|
from api.auth import api_key_auth
|
||||||
from api.models import get_embeddings_model
|
from api.models.bedrock import get_embeddings_model
|
||||||
from api.schema import EmbeddingsRequest, EmbeddingsResponse
|
from api.schema import EmbeddingsRequest, EmbeddingsResponse
|
||||||
from api.setting import DEFAULT_EMBEDDING_MODEL
|
from api.setting import DEFAULT_EMBEDDING_MODEL
|
||||||
|
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ from typing import Annotated
|
|||||||
from fastapi import APIRouter, Depends, HTTPException, Path
|
from fastapi import APIRouter, Depends, HTTPException, Path
|
||||||
|
|
||||||
from api.auth import api_key_auth
|
from api.auth import api_key_auth
|
||||||
from api.models import SUPPORTED_BEDROCK_MODELS, SUPPORTED_BEDROCK_EMBEDDING_MODELS
|
from api.models.bedrock import BedrockModel
|
||||||
from api.schema import Models, Model
|
from api.schema import Models, Model
|
||||||
|
|
||||||
router = APIRouter(
|
router = APIRouter(
|
||||||
@@ -12,16 +12,19 @@ router = APIRouter(
|
|||||||
# responses={404: {"description": "Not found"}},
|
# responses={404: {"description": "Not found"}},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
chat_model = BedrockModel()
|
||||||
|
|
||||||
|
|
||||||
async def validate_model_id(model_id: str):
|
async def validate_model_id(model_id: str):
|
||||||
if model_id not in (SUPPORTED_BEDROCK_MODELS | SUPPORTED_BEDROCK_EMBEDDING_MODELS).keys():
|
if model_id not in chat_model.list_models():
|
||||||
raise HTTPException(status_code=500, detail="Unsupported Model Id")
|
raise HTTPException(status_code=500, detail="Unsupported Model Id")
|
||||||
|
|
||||||
|
|
||||||
@router.get("", response_model=Models)
|
@router.get("", response_model=Models)
|
||||||
async def list_models():
|
async def list_models():
|
||||||
model_list = [Model(id=model_id) for model_id in
|
model_list = [
|
||||||
(SUPPORTED_BEDROCK_MODELS | SUPPORTED_BEDROCK_EMBEDDING_MODELS).keys()]
|
Model(id=model_id) for model_id in chat_model.list_models()
|
||||||
|
]
|
||||||
return Models(data=model_list)
|
return Models(data=model_list)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,5 +1,4 @@
|
|||||||
import time
|
import time
|
||||||
import uuid
|
|
||||||
from typing import Literal, Iterable
|
from typing import Literal, Iterable
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
@@ -18,12 +17,12 @@ class Models(BaseModel):
|
|||||||
|
|
||||||
|
|
||||||
class ResponseFunction(BaseModel):
|
class ResponseFunction(BaseModel):
|
||||||
name: str
|
name: str | None = None
|
||||||
arguments: str
|
arguments: str
|
||||||
|
|
||||||
|
|
||||||
class ToolCall(BaseModel):
|
class ToolCall(BaseModel):
|
||||||
id: str = Field(default_factory=lambda: str(uuid.uuid4())[:8])
|
id: str | None = None
|
||||||
type: Literal["function"] = "function"
|
type: Literal["function"] = "function"
|
||||||
function: ResponseFunction
|
function: ResponseFunction
|
||||||
|
|
||||||
@@ -113,8 +112,8 @@ class ChatResponseMessage(BaseModel):
|
|||||||
|
|
||||||
|
|
||||||
class BaseChoice(BaseModel):
|
class BaseChoice(BaseModel):
|
||||||
index: int
|
index: int | None = 0
|
||||||
finish_reason: str | None
|
finish_reason: str | None = None
|
||||||
logprobs: dict | None = None
|
logprobs: dict | None = None
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -11,27 +11,11 @@ DESCRIPTION = """
|
|||||||
Use OpenAI-Compatible RESTful APIs for Amazon Bedrock models.
|
Use OpenAI-Compatible RESTful APIs for Amazon Bedrock models.
|
||||||
|
|
||||||
List of Amazon Bedrock models currently supported:
|
List of Amazon Bedrock models currently supported:
|
||||||
|
- Anthropic Claude 2 / 3 (Haiku/Sonnet/Opus)
|
||||||
# Chat
|
- Meta Llama 2 / 3
|
||||||
- anthropic.claude-instant-v1
|
- Mistral / Mixtral
|
||||||
- anthropic.claude-v2:1
|
- Cohere Command R / R+
|
||||||
- anthropic.claude-v2
|
- Cohere Embedding
|
||||||
- anthropic.claude-3-opus-20240229-v1:0
|
|
||||||
- anthropic.claude-3-sonnet-20240229-v1:0
|
|
||||||
- anthropic.claude-3-haiku-20240307-v1:0
|
|
||||||
- meta.llama2-13b-chat-v1
|
|
||||||
- meta.llama2-70b-chat-v1
|
|
||||||
- meta.llama3-8b-instruct-v1:0
|
|
||||||
- meta.llama3-70b-instruct-v1:0
|
|
||||||
- mistral.mistral-7b-instruct-v0:2
|
|
||||||
- mistral.mixtral-8x7b-instruct-v0:1
|
|
||||||
- mistral.mistral-large-2402-v1:0
|
|
||||||
- cohere.command-r-v1:0
|
|
||||||
- cohere.command-r-plus-v1:0
|
|
||||||
|
|
||||||
# Embeddings
|
|
||||||
- cohere.embed-multilingual-v3
|
|
||||||
- cohere.embed-english-v3
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
DEBUG = os.environ.get("DEBUG", "false").lower() != "false"
|
DEBUG = os.environ.get("DEBUG", "false").lower() != "false"
|
||||||
|
|||||||
@@ -1,7 +1,9 @@
|
|||||||
fastapi==0.110.2
|
fastapi==0.111.0
|
||||||
pydantic==2.7.1
|
pydantic==2.7.1
|
||||||
uvicorn==0.29.0
|
uvicorn==0.29.0
|
||||||
mangum==0.17.0
|
mangum==0.17.0
|
||||||
tiktoken==0.6.0
|
tiktoken==0.6.0
|
||||||
requests==2.32.0
|
requests==2.32.3
|
||||||
numpy==1.26.4
|
numpy==1.26.4
|
||||||
|
boto3==1.34.117
|
||||||
|
botocore==1.34.117
|
||||||
Reference in New Issue
Block a user