Add multimodal support

Aiden Dai
2024-04-02 13:10:15 +08:00
parent 31ae10a275
commit e49a579a41
3 changed files with 91 additions and 13 deletions

View File

@@ -1,9 +1,12 @@
+import base64
 import json
 import logging
 from typing import AsyncIterable, Iterable

 import boto3
+import requests
 import tiktoken
+from fastapi import HTTPException

 from api.models.base import BaseChatModel, BaseEmbeddingsModel
 from api.schema import (
@@ -20,7 +23,7 @@ from api.schema import (
     EmbeddingsRequest,
     EmbeddingsResponse,
     EmbeddingsUsage,
-    Embedding,
+    Embedding, TextContent,
 )

 from api.setting import DEBUG, AWS_REGION
@@ -147,6 +150,44 @@ def get_model(model_id: str) -> BedrockModel:

 class ClaudeModel(BedrockModel):
     anthropic_version = "bedrock-2023-05-31"

+    def _get_base64_image(self, image_url: str):
+        # Send a request to the image URL
+        response = requests.get(image_url)
+        # Check if the request was successful
+        if response.status_code == 200:
+            # Get the image content
+            image_content = response.content
+            # Encode the image content as base64
+            base64_image = base64.b64encode(image_content)
+            return base64_image.decode('utf-8')
+        else:
+            raise HTTPException(status_code=500, detail="Unable to access the image url")
+
+    def _parse_messages(self, messages: list[ChatRequestMessage]) -> list[dict]:
+        # Refer to: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html
+        converted_messages = []
+        for msg in messages:
+            if isinstance(msg.content, str):
+                # Plain-text message: pass through unchanged.
+                converted_messages.append({"role": msg.role, "content": msg.content})
+                continue
+            # Multimodal message: convert each OpenAI-style content part.
+            content_parts = []
+            for part in msg.content:
+                if isinstance(part, TextContent):
+                    content_parts.append(part.model_dump())
+                else:
+                    content_parts.append({
+                        "type": "image",
+                        "source": {
+                            "type": "base64",
+                            "media_type": "image/jpeg",
+                            "data": self._get_base64_image(part.image_url.url),
+                        },
+                    })
+            converted_messages.append({"role": msg.role, "content": content_parts})
+        return converted_messages
+
     def _parse_args(self, chat_request: ChatRequest) -> dict:
         args = {
             "anthropic_version": self.anthropic_version,
@@ -154,18 +195,11 @@ class ClaudeModel(BedrockModel):
             "top_p": chat_request.top_p,
             "temperature": chat_request.temperature,
         }
+        start = 0
         if chat_request.messages[0].role == "system":
             args["system"] = chat_request.messages[0].content
-            args["messages"] = [
-                {"role": msg.role, "content": msg.content}
-                for msg in chat_request.messages[1:]
-            ]
-        else:
-            args["messages"] = [
-                {"role": msg.role, "content": msg.content}
-                for msg in chat_request.messages
-            ]
+            start = 1
+        args["messages"] = self._parse_messages(chat_request.messages[start:])
         return args

     def chat(self, chat_request: ChatRequest) -> ChatResponse:
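
For reference, a minimal sketch of the conversion _parse_messages performs, assuming a JPEG at an invented URL; the base64 payload is shown as a placeholder:

# OpenAI-style multimodal message received by the gateway:
openai_message = {
    "role": "user",
    "content": [
        {"type": "text", "text": "What is in this picture?"},
        {"type": "image_url", "image_url": {"url": "https://example.com/cat.jpg"}},
    ],
}

# Shape produced for the Bedrock Anthropic Messages API; the image bytes are
# fetched over HTTP and inlined as base64, with media_type fixed to image/jpeg:
claude_message = {
    "role": "user",
    "content": [
        {"type": "text", "text": "What is in this picture?"},
        {
            "type": "image",
            "source": {
                "type": "base64",
                "media_type": "image/jpeg",
                "data": "<base64-encoded image bytes>",
            },
        },
    ],
}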
@@ -243,6 +277,8 @@ class Llama2Model(BedrockModel):
         # TODO: Add validation
         for i in range(start, len(messages)):
             msg = messages[i]
+            if not isinstance(msg.content, str):
+                raise HTTPException(status_code=400, detail="Content must be a string for Llama 2 model")
             if msg.role == "user":
                 if end_turn:
                     prompt += bos_token + "[INST] "
@@ -325,6 +361,8 @@ class MistralModel(BedrockModel):
         # TODO: Add validation
         for i in range(start, len(messages)):
             msg = messages[i]
+            if not isinstance(msg.content, str):
+                raise HTTPException(status_code=400, detail="Content must be a string for Mistral/Mixtral model")
             if msg.role == "user":
                 if end_turn:
                     prompt += bos_token + "[INST] "
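
Only ClaudeModel handles list-type content, so the Llama 2 and Mistral prompt builders now reject multimodal input up front. A hedged sketch of what a caller would see (schema names are from this diff; the request values are invented):

from api.schema import ChatRequestMessage, TextContent

# Plain-string content is accepted by every model:
ChatRequestMessage(role="user", content="Hello")

# List-type (multimodal) content validates at the schema level, but the
# Llama 2 / Mistral prompt builders above raise HTTP 400 when they meet it:
ChatRequestMessage(role="user", content=[TextContent(text="Hello")])
# -> HTTPException(status_code=400, detail="Content must be a string for Llama 2 model")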

View File

@@ -17,7 +17,7 @@ router = APIRouter(
 )


-@router.post("/completions", response_model=ChatResponse | ChatStreamResponse)
+@router.post("/completions", response_model=ChatResponse | ChatStreamResponse, response_model_exclude_none=True)
 async def chat_completions(
     chat_request: Annotated[
         ChatRequest,
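
With response_model_exclude_none=True, FastAPI omits None-valued fields (such as the tool_calls field added in this commit) from the serialized JSON, matching how the OpenAI API leaves absent fields out. A minimal Pydantic v2 sketch of the effect, using a simplified stand-in model:

from pydantic import BaseModel

class Message(BaseModel):  # simplified stand-in for ChatResponseMessage
    role: str | None = None
    content: str | None = None
    tool_calls: list[dict] | None = None

msg = Message(role="assistant", content="Hello")
print(msg.model_dump())                   # includes 'tool_calls': None
print(msg.model_dump(exclude_none=True))  # {'role': 'assistant', 'content': 'Hello'}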

View File

@@ -16,10 +16,36 @@ class Models(BaseModel):
     data: list[Model] = []


+class TextContent(BaseModel):
+    type: Literal["text"] = "text"
+    text: str
+
+
+class ImageUrl(BaseModel):
+    url: str
+    detail: str | None = "auto"
+
+
+class ImageContent(BaseModel):
+    type: Literal["image_url"] = "image_url"
+    image_url: ImageUrl
+
+
 class ChatRequestMessage(BaseModel):
     name: str | None = None
     role: Literal["user", "assistant", "system"]
-    content: str
+    content: str | list[TextContent | ImageContent]


+class Function(BaseModel):
+    name: str
+    description: str | None = None
+    parameters: object
+
+
+class Tool(BaseModel):
+    type: Literal["function"] = "function"
+    function: Function
+
+
 class ChatRequest(BaseModel):
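
With these schema additions, an OpenAI-style multimodal message validates directly; a brief sketch (the image URL is invented for illustration):

from api.schema import ChatRequestMessage, ImageContent, ImageUrl, TextContent

message = ChatRequestMessage(
    role="user",
    content=[
        TextContent(text="Describe this image."),
        ImageContent(image_url=ImageUrl(url="https://example.com/photo.jpg")),
    ],
)

# Plain-string content still validates, so existing text-only clients keep working:
legacy = ChatRequestMessage(role="user", content="Hello")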
@@ -33,6 +59,8 @@ class ChatRequest(BaseModel):
     user: str | None = None  # Not used
     max_tokens: int | None = 2048
     n: int | None = 1  # Not used
+    tools: list[Tool] | None = None
+    tool_choice: str | object = "auto"


 class Usage(BaseModel):
@@ -41,10 +69,22 @@ class Usage(BaseModel):
     total_tokens: int


+class ResponseFunction(BaseModel):
+    name: str
+    arguments: str
+
+
+class ToolCall(BaseModel):
+    id: str
+    type: Literal["function"] = "function"
+    function: ResponseFunction
+
+
 class ChatResponseMessage(BaseModel):
     # tool_calls
     role: Literal["assistant"] | None = None
     content: str | None = None
+    tool_calls: list[ToolCall] | None = None


 class BaseChoice(BaseModel):
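
These response-side models mirror OpenAI's tool-calling shape; a hedged sketch of an assistant message carrying one call (all values invented for illustration):

from api.schema import ChatResponseMessage, ResponseFunction, ToolCall

reply = ChatResponseMessage(
    role="assistant",
    content=None,  # dropped from the JSON response via response_model_exclude_none=True
    tool_calls=[
        ToolCall(
            id="call_0",
            function=ResponseFunction(name="get_weather", arguments='{"city": "Seattle"}'),
        )
    ],
)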