Refactor to use new Converse API
This commit is contained in:
@@ -1,7 +0,0 @@
|
||||
from api.models.bedrock import (
|
||||
ClaudeModel,
|
||||
SUPPORTED_BEDROCK_MODELS,
|
||||
SUPPORTED_BEDROCK_EMBEDDING_MODELS,
|
||||
get_model,
|
||||
get_embeddings_model,
|
||||
)
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import time
|
||||
import uuid
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import AsyncIterable
|
||||
@@ -19,6 +20,14 @@ class BaseChatModel(ABC):
|
||||
Currently, only Bedrock model is supported, but may be used for SageMaker models if needed.
|
||||
"""
|
||||
|
||||
def list_models(self) -> list[str]:
|
||||
"""Return a list of supported models"""
|
||||
return []
|
||||
|
||||
def validate(self, chat_request: ChatRequest):
|
||||
"""Validate chat completion requests."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def chat(self, chat_request: ChatRequest) -> ChatResponse:
|
||||
"""Handle a basic chat completion requests."""
|
||||
@@ -38,7 +47,11 @@ class BaseChatModel(ABC):
|
||||
response: ChatStreamResponse | None = None
|
||||
) -> bytes:
|
||||
if response:
|
||||
return "data: {}\n\n".format(response.model_dump_json()).encode("utf-8")
|
||||
# to populate other fields when using exclude_unset=True
|
||||
response.system_fingerprint = "fp"
|
||||
response.object = "chat.completion.chunk"
|
||||
response.created = int(time.time())
|
||||
return "data: {}\n\n".format(response.model_dump_json(exclude_unset=True)).encode("utf-8")
|
||||
return "data: [DONE]\n\n".encode("utf-8")
|
||||
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user