Add multimodal support
This commit is contained in:
@@ -1,9 +1,12 @@
|
||||
import base64
|
||||
import json
|
||||
import logging
|
||||
from typing import AsyncIterable, Iterable
|
||||
|
||||
import boto3
|
||||
import requests
|
||||
import tiktoken
|
||||
from fastapi import HTTPException
|
||||
|
||||
from api.models.base import BaseChatModel, BaseEmbeddingsModel
|
||||
from api.schema import (
|
||||
@@ -20,7 +23,7 @@ from api.schema import (
|
||||
EmbeddingsRequest,
|
||||
EmbeddingsResponse,
|
||||
EmbeddingsUsage,
|
||||
Embedding,
|
||||
Embedding, TextContent,
|
||||
)
|
||||
from api.setting import DEBUG, AWS_REGION
|
||||
|
||||
@@ -147,6 +150,44 @@ def get_model(model_id: str) -> BedrockModel:
|
||||
class ClaudeModel(BedrockModel):
|
||||
anthropic_version = "bedrock-2023-05-31"
|
||||
|
||||
def _get_base64_image(self, image_url: str) -> str:
    """Download the image at *image_url* and return it as a base64 string.

    Args:
        image_url: HTTP(S) location of the image to fetch.

    Returns:
        The raw response body, base64-encoded and decoded to a UTF-8 ``str``.

    Raises:
        HTTPException: status 500 when the URL cannot be fetched (network
            error, timeout, or a non-200 response).
    """
    # NOTE(review): image_url comes from the chat request content (see
    # _parse_messages), so this is a server-side fetch of a caller-chosen
    # URL. Consider restricting schemes/hosts if this service can reach
    # internal endpoints.
    try:
        # requests never times out by default; bound the wait so a stalled
        # image host cannot hang the chat request indefinitely.
        response = requests.get(image_url, timeout=30)
    except requests.RequestException as err:
        # Report transport failures the same way as a non-200 response.
        raise HTTPException(status_code=500, detail="Unable to access the image url") from err

    if response.status_code != 200:
        raise HTTPException(status_code=500, detail="Unable to access the image url")

    return base64.b64encode(response.content).decode('utf-8')
|
||||
|
||||
def _parse_messages(self, messages: list[ChatRequestMessage]) -> list[dict]:
    """Convert chat request messages into Claude Messages API dictionaries.

    String content is passed through unchanged. List content is converted
    part by part: ``TextContent`` parts are dumped as-is, and every other
    part is treated as an image whose URL is downloaded and embedded as a
    base64 ``image`` block.

    Args:
        messages: Messages to convert.

    Returns:
        One ``{"role": ..., "content": ...}`` dict per input message.

    Raises:
        HTTPException: propagated from ``_get_base64_image`` when an image
            URL cannot be fetched.
    """
    # Refer to: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html
    from pathlib import PurePosixPath
    from urllib.parse import urlparse

    # Media type by URL file extension. Anything unrecognised keeps the
    # previous behaviour of being declared as JPEG.
    media_type_by_suffix = {
        ".jpg": "image/jpeg",
        ".jpeg": "image/jpeg",
        ".png": "image/png",
        ".gif": "image/gif",
        ".webp": "image/webp",
    }
    default_media_type = "image/jpeg"

    converted_messages = []
    for msg in messages:
        if isinstance(msg.content, str):
            converted_messages.append({"role": msg.role, "content": msg.content})
            continue

        content_parts = []
        for part in msg.content:
            if isinstance(part, TextContent):
                content_parts.append(part.model_dump())
                continue

            url = part.image_url.url
            # Fetch first so an unreachable URL fails exactly as before.
            data = self._get_base64_image(url)
            try:
                # Only the path is inspected, so query strings are ignored.
                suffix = PurePosixPath(urlparse(url).path).suffix.lower()
            except ValueError:
                suffix = ""
            content_parts.append({
                "type": "image",
                "source": {
                    "type": "base64",
                    "media_type": media_type_by_suffix.get(suffix, default_media_type),
                    "data": data,
                },
            })

        converted_messages.append({"role": msg.role, "content": content_parts})
    return converted_messages
|
||||
|
||||
def _parse_args(self, chat_request: ChatRequest) -> dict:
|
||||
args = {
|
||||
"anthropic_version": self.anthropic_version,
|
||||
@@ -154,18 +195,11 @@ class ClaudeModel(BedrockModel):
|
||||
"top_p": chat_request.top_p,
|
||||
"temperature": chat_request.temperature,
|
||||
}
|
||||
start = 0
|
||||
if chat_request.messages[0].role == "system":
|
||||
args["system"] = chat_request.messages[0].content
|
||||
args["messages"] = [
|
||||
{"role": msg.role, "content": msg.content}
|
||||
for msg in chat_request.messages[1:]
|
||||
]
|
||||
else:
|
||||
args["messages"] = [
|
||||
{"role": msg.role, "content": msg.content}
|
||||
for msg in chat_request.messages
|
||||
]
|
||||
|
||||
start = 1
|
||||
args["messages"] = self._parse_messages(chat_request.messages[start:])
|
||||
return args
|
||||
|
||||
def chat(self, chat_request: ChatRequest) -> ChatResponse:
|
||||
@@ -243,6 +277,8 @@ class Llama2Model(BedrockModel):
|
||||
# TODO: Add validation
|
||||
for i in range(start, len(messages)):
|
||||
msg = messages[i]
|
||||
if not isinstance(msg.content, str):
|
||||
raise HTTPException(status_code=400, detail="Content must be a string for Llama 2 model")
|
||||
if msg.role == "user":
|
||||
if end_turn:
|
||||
prompt += bos_token + "[INST] "
|
||||
@@ -325,6 +361,8 @@ class MistralModel(BedrockModel):
|
||||
# TODO: Add validation
|
||||
for i in range(start, len(messages)):
|
||||
msg = messages[i]
|
||||
if not isinstance(msg.content, str):
|
||||
raise HTTPException(status_code=400, detail="Content must be a string for Mistral/Mixtral model")
|
||||
if msg.role == "user":
|
||||
if end_turn:
|
||||
prompt += bos_token + "[INST] "
|
||||
|
||||
@@ -17,7 +17,7 @@ router = APIRouter(
|
||||
)
|
||||
|
||||
|
||||
@router.post("/completions", response_model=ChatResponse | ChatStreamResponse)
|
||||
@router.post("/completions", response_model=ChatResponse | ChatStreamResponse, response_model_exclude_none=True)
|
||||
async def chat_completions(
|
||||
chat_request: Annotated[
|
||||
ChatRequest,
|
||||
|
||||
@@ -16,10 +16,36 @@ class Models(BaseModel):
|
||||
data: list[Model] = []
|
||||
|
||||
|
||||
# Text part of a multi-part chat message.
class TextContent(BaseModel):
    # Part discriminator; "text" is the only accepted value and the default.
    type: Literal["text"] = "text"

    # The text carried by this part.
    text: str
|
||||
|
||||
|
||||
# Location of an image referenced from a chat message part.
class ImageUrl(BaseModel):
    # Address of the image.
    url: str

    # Optional detail hint, "auto" when omitted.
    # NOTE(review): no code visible here reads this field - confirm whether it is used.
    detail: str | None = "auto"
|
||||
|
||||
|
||||
# Image part of a multi-part chat message.
class ImageContent(BaseModel):
    # Part discriminator. The default must be the one value the Literal
    # allows; the previous default "image" contradicted the annotation, so
    # an instance built without an explicit type carried an invalid value.
    type: Literal["image_url"] = "image_url"

    # Where the image can be found.
    image_url: ImageUrl
|
||||
|
||||
|
||||
# A single message of an incoming chat request.
class ChatRequestMessage(BaseModel):
    # Optional name attached to the message.
    name: str | None = None

    # Author of the message.
    role: Literal["user", "assistant", "system"]

    # Either plain text or a list of text/image parts. The earlier duplicate
    # `content: str` annotation was superseded by this one and is removed.
    content: str | list[TextContent | ImageContent]
|
||||
|
||||
|
||||
# Description of a function offered to the model as a tool.
class Function(BaseModel):
    # Function name.
    name: str

    # Optional free-text description.
    description: str | None = None

    # Parameter specification; typed as a bare object, so any value is accepted.
    # NOTE(review): presumably a JSON Schema document - confirm against callers.
    parameters: object
|
||||
|
||||
|
||||
# A tool entry in a chat request.
class Tool(BaseModel):
    # Tool kind; "function" is the only accepted value and the default.
    type: Literal["function"] = "function"

    # The function this tool exposes.
    function: Function
|
||||
|
||||
|
||||
class ChatRequest(BaseModel):
|
||||
@@ -33,6 +59,8 @@ class ChatRequest(BaseModel):
|
||||
user: str | None = None # Not used
|
||||
max_tokens: int | None = 2048
|
||||
n: int | None = 1 # Not used
|
||||
tools: list[Tool] | None = None
|
||||
tool_choice: str | object = "auto"
|
||||
|
||||
|
||||
class Usage(BaseModel):
|
||||
@@ -41,10 +69,22 @@ class Usage(BaseModel):
|
||||
total_tokens: int
|
||||
|
||||
|
||||
# Function invocation carried inside a tool call of a chat response.
class ResponseFunction(BaseModel):
    # Name of the function to invoke.
    name: str

    # Arguments for the invocation, held as a string.
    # NOTE(review): presumably JSON-encoded - confirm against the producer.
    arguments: str
|
||||
|
||||
|
||||
# A tool call emitted in a chat response message.
class ToolCall(BaseModel):
    # Identifier of this call.
    id: str

    # Call kind; "function" is the only accepted value and the default.
    type: Literal["function"] = "function"

    # The function and arguments being invoked.
    function: ResponseFunction
|
||||
|
||||
|
||||
# Message payload of a chat response; every field is optional and defaults to None.
class ChatResponseMessage(BaseModel):
    # tool_calls
    # Author of the message; "assistant" when present.
    role: Literal["assistant"] | None = None

    # Text content of the message, if any.
    content: str | None = None

    # Tool calls requested by the model, if any.
    tool_calls: list[ToolCall] | None = None
|
||||
|
||||
|
||||
class BaseChoice(BaseModel):
|
||||
|
||||
Reference in New Issue
Block a user