Add multimodal support

This commit is contained in:
Aiden Dai
2024-04-02 13:10:15 +08:00
parent 31ae10a275
commit e49a579a41
3 changed files with 91 additions and 13 deletions

View File

@@ -1,9 +1,12 @@
import base64
import json import json
import logging import logging
from typing import AsyncIterable, Iterable from typing import AsyncIterable, Iterable
import boto3 import boto3
import requests
import tiktoken import tiktoken
from fastapi import HTTPException
from api.models.base import BaseChatModel, BaseEmbeddingsModel from api.models.base import BaseChatModel, BaseEmbeddingsModel
from api.schema import ( from api.schema import (
@@ -20,7 +23,7 @@ from api.schema import (
EmbeddingsRequest, EmbeddingsRequest,
EmbeddingsResponse, EmbeddingsResponse,
EmbeddingsUsage, EmbeddingsUsage,
Embedding, Embedding, TextContent,
) )
from api.setting import DEBUG, AWS_REGION from api.setting import DEBUG, AWS_REGION
@@ -147,6 +150,44 @@ def get_model(model_id: str) -> BedrockModel:
class ClaudeModel(BedrockModel): class ClaudeModel(BedrockModel):
anthropic_version = "bedrock-2023-05-31" anthropic_version = "bedrock-2023-05-31"
def _get_base64_image(self, image_url: str) -> str:
    """Download the image at *image_url* and return it base64-encoded (utf-8 text).

    Raises:
        HTTPException: 500 when the URL cannot be fetched or does not
            return HTTP 200.
    """
    try:
        # A timeout prevents a slow or unresponsive image host from
        # hanging the whole chat request indefinitely.
        response = requests.get(image_url, timeout=10)
    except requests.RequestException:
        # Network-level failures (DNS, connection refused, timeout) must
        # surface as a clean HTTP error, not a raw requests exception.
        raise HTTPException(status_code=500, detail="Unable to access the image url")
    if response.status_code != 200:
        raise HTTPException(status_code=500, detail="Unable to access the image url")
    # Encode the raw bytes as base64 and decode to str for JSON embedding.
    return base64.b64encode(response.content).decode('utf-8')
def _parse_messages(self, messages: list[ChatRequestMessage]) -> list[dict]:
    """Convert OpenAI-style chat messages into the Anthropic Claude format.

    See: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html
    """

    def part_to_dict(part) -> dict:
        # Text parts pass through unchanged; any other part is treated as
        # an image reference, fetched, and inlined as base64.
        if isinstance(part, TextContent):
            return part.model_dump()
        # NOTE: media_type is always reported as jpeg regardless of the
        # actual image content type.
        return {
            "type": "image",
            "source": {
                "type": "base64",
                "media_type": "image/jpeg",
                "data": self._get_base64_image(part.image_url.url),
            },
        }

    result = []
    for message in messages:
        if isinstance(message.content, str):
            # Plain string content maps 1:1.
            result.append({"role": message.role, "content": message.content})
        else:
            # Multimodal content: convert each part individually.
            result.append({
                "role": message.role,
                "content": [part_to_dict(p) for p in message.content],
            })
    return result
def _parse_args(self, chat_request: ChatRequest) -> dict: def _parse_args(self, chat_request: ChatRequest) -> dict:
args = { args = {
"anthropic_version": self.anthropic_version, "anthropic_version": self.anthropic_version,
@@ -154,18 +195,11 @@ class ClaudeModel(BedrockModel):
"top_p": chat_request.top_p, "top_p": chat_request.top_p,
"temperature": chat_request.temperature, "temperature": chat_request.temperature,
} }
start = 0
if chat_request.messages[0].role == "system": if chat_request.messages[0].role == "system":
args["system"] = chat_request.messages[0].content args["system"] = chat_request.messages[0].content
args["messages"] = [ start = 1
{"role": msg.role, "content": msg.content} args["messages"] = self._parse_messages(chat_request.messages[start:])
for msg in chat_request.messages[1:]
]
else:
args["messages"] = [
{"role": msg.role, "content": msg.content}
for msg in chat_request.messages
]
return args return args
def chat(self, chat_request: ChatRequest) -> ChatResponse: def chat(self, chat_request: ChatRequest) -> ChatResponse:
@@ -243,6 +277,8 @@ class Llama2Model(BedrockModel):
# TODO: Add validation # TODO: Add validation
for i in range(start, len(messages)): for i in range(start, len(messages)):
msg = messages[i] msg = messages[i]
if not isinstance(msg.content, str):
raise HTTPException(status_code=400, detail="Content must be a string for Llama 2 model")
if msg.role == "user": if msg.role == "user":
if end_turn: if end_turn:
prompt += bos_token + "[INST] " prompt += bos_token + "[INST] "
@@ -325,6 +361,8 @@ class MistralModel(BedrockModel):
# TODO: Add validation # TODO: Add validation
for i in range(start, len(messages)): for i in range(start, len(messages)):
msg = messages[i] msg = messages[i]
if not isinstance(msg.content, str):
raise HTTPException(status_code=400, detail="Content must be a string for Mistral/Mixtral model")
if msg.role == "user": if msg.role == "user":
if end_turn: if end_turn:
prompt += bos_token + "[INST] " prompt += bos_token + "[INST] "

View File

@@ -17,7 +17,7 @@ router = APIRouter(
) )
@router.post("/completions", response_model=ChatResponse | ChatStreamResponse) @router.post("/completions", response_model=ChatResponse | ChatStreamResponse, response_model_exclude_none=True)
async def chat_completions( async def chat_completions(
chat_request: Annotated[ chat_request: Annotated[
ChatRequest, ChatRequest,

View File

@@ -16,10 +16,36 @@ class Models(BaseModel):
data: list[Model] = [] data: list[Model] = []
class TextContent(BaseModel):
    # A plain-text content part of a chat message (OpenAI-compatible shape).
    type: Literal["text"] = "text"
    text: str
class ImageUrl(BaseModel):
    # Reference to an image by URL (OpenAI-compatible shape).
    url: str
    # Detail-level hint; defaults to "auto". NOTE(review): not consumed by
    # the visible conversion code — presumably kept for API compatibility.
    detail: str | None = "auto"
class ImageContent(BaseModel):
    # An image content part referenced by URL (OpenAI-compatible shape).
    # Fix: the default must be a member of the declared Literal — the
    # previous default "image" was not a valid "image_url" value, so a
    # default-constructed part serialized with an invalid discriminator.
    type: Literal["image_url"] = "image_url"
    image_url: ImageUrl
class ChatRequestMessage(BaseModel): class ChatRequestMessage(BaseModel):
name: str | None = None name: str | None = None
role: Literal["user", "assistant", "system"] role: Literal["user", "assistant", "system"]
content: str content: str | list[TextContent | ImageContent]
class Function(BaseModel):
    # A callable function definition supplied with a tool-enabled request.
    name: str
    description: str | None = None
    # Schema describing the function's parameters — presumably a JSON-Schema
    # object; verify against callers.
    parameters: object
class Tool(BaseModel):
    # A tool the model may invoke; only "function" tools are modeled here.
    type: Literal["function"] = "function"
    function: Function
class ChatRequest(BaseModel): class ChatRequest(BaseModel):
@@ -33,6 +59,8 @@ class ChatRequest(BaseModel):
user: str | None = None # Not used user: str | None = None # Not used
max_tokens: int | None = 2048 max_tokens: int | None = 2048
n: int | None = 1 # Not used n: int | None = 1 # Not used
tools: list[Tool] | None = None
tool_choice: str | object = "auto"
class Usage(BaseModel): class Usage(BaseModel):
@@ -41,10 +69,22 @@ class Usage(BaseModel):
total_tokens: int total_tokens: int
class ResponseFunction(BaseModel):
    # A function invocation emitted by the model in a response.
    name: str
    # Arguments as a string — presumably JSON-encoded; confirm against the
    # response-building code.
    arguments: str
class ToolCall(BaseModel):
    # One tool invocation attached to an assistant response message.
    id: str
    type: Literal["function"] = "function"
    function: ResponseFunction
class ChatResponseMessage(BaseModel): class ChatResponseMessage(BaseModel):
# tool_calls # tool_calls
role: Literal["assistant"] | None = None role: Literal["assistant"] | None = None
content: str | None = None content: str | None = None
tool_calls: list[ToolCall] | None = None
class BaseChoice(BaseModel): class BaseChoice(BaseModel):