Add multimodal support
This commit is contained in:
@@ -1,9 +1,12 @@
|
|||||||
|
import base64
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
from typing import AsyncIterable, Iterable
|
from typing import AsyncIterable, Iterable
|
||||||
|
|
||||||
import boto3
|
import boto3
|
||||||
|
import requests
|
||||||
import tiktoken
|
import tiktoken
|
||||||
|
from fastapi import HTTPException
|
||||||
|
|
||||||
from api.models.base import BaseChatModel, BaseEmbeddingsModel
|
from api.models.base import BaseChatModel, BaseEmbeddingsModel
|
||||||
from api.schema import (
|
from api.schema import (
|
||||||
@@ -20,7 +23,7 @@ from api.schema import (
|
|||||||
EmbeddingsRequest,
|
EmbeddingsRequest,
|
||||||
EmbeddingsResponse,
|
EmbeddingsResponse,
|
||||||
EmbeddingsUsage,
|
EmbeddingsUsage,
|
||||||
Embedding,
|
Embedding, TextContent,
|
||||||
)
|
)
|
||||||
from api.setting import DEBUG, AWS_REGION
|
from api.setting import DEBUG, AWS_REGION
|
||||||
|
|
||||||
@@ -147,6 +150,44 @@ def get_model(model_id: str) -> BedrockModel:
|
|||||||
class ClaudeModel(BedrockModel):
|
class ClaudeModel(BedrockModel):
|
||||||
anthropic_version = "bedrock-2023-05-31"
|
anthropic_version = "bedrock-2023-05-31"
|
||||||
|
|
||||||
|
def _get_base64_image(self, image_url: str):
|
||||||
|
# Send a request to the image URL
|
||||||
|
response = requests.get(image_url)
|
||||||
|
# Check if the request was successful
|
||||||
|
if response.status_code == 200:
|
||||||
|
# Get the image content
|
||||||
|
image_content = response.content
|
||||||
|
# Encode the image content as base64
|
||||||
|
base64_image = base64.b64encode(image_content)
|
||||||
|
return base64_image.decode('utf-8')
|
||||||
|
else:
|
||||||
|
raise HTTPException(status_code=500, detail="Unable to access the image url")
|
||||||
|
|
||||||
|
def _parse_messages(self, messages: list[ChatRequestMessage]) -> list[dict]:
|
||||||
|
# Refer to: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html
|
||||||
|
converted_messages = []
|
||||||
|
for msg in messages:
|
||||||
|
if isinstance(msg.content, str):
|
||||||
|
converted_messages.append({"role": msg.role, "content": msg.content})
|
||||||
|
continue
|
||||||
|
|
||||||
|
content_parts = []
|
||||||
|
for part in msg.content:
|
||||||
|
if isinstance(part, TextContent):
|
||||||
|
content_parts.append(part.model_dump())
|
||||||
|
else:
|
||||||
|
content_parts.append({
|
||||||
|
"type": "image",
|
||||||
|
"source": {
|
||||||
|
"type": "base64",
|
||||||
|
"media_type": "image/jpeg",
|
||||||
|
"data": self._get_base64_image(part.image_url.url)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
converted_messages.append({"role": msg.role, "content": content_parts})
|
||||||
|
return converted_messages
|
||||||
|
|
||||||
def _parse_args(self, chat_request: ChatRequest) -> dict:
|
def _parse_args(self, chat_request: ChatRequest) -> dict:
|
||||||
args = {
|
args = {
|
||||||
"anthropic_version": self.anthropic_version,
|
"anthropic_version": self.anthropic_version,
|
||||||
@@ -154,18 +195,11 @@ class ClaudeModel(BedrockModel):
|
|||||||
"top_p": chat_request.top_p,
|
"top_p": chat_request.top_p,
|
||||||
"temperature": chat_request.temperature,
|
"temperature": chat_request.temperature,
|
||||||
}
|
}
|
||||||
|
start = 0
|
||||||
if chat_request.messages[0].role == "system":
|
if chat_request.messages[0].role == "system":
|
||||||
args["system"] = chat_request.messages[0].content
|
args["system"] = chat_request.messages[0].content
|
||||||
args["messages"] = [
|
start = 1
|
||||||
{"role": msg.role, "content": msg.content}
|
args["messages"] = self._parse_messages(chat_request.messages[start:])
|
||||||
for msg in chat_request.messages[1:]
|
|
||||||
]
|
|
||||||
else:
|
|
||||||
args["messages"] = [
|
|
||||||
{"role": msg.role, "content": msg.content}
|
|
||||||
for msg in chat_request.messages
|
|
||||||
]
|
|
||||||
|
|
||||||
return args
|
return args
|
||||||
|
|
||||||
def chat(self, chat_request: ChatRequest) -> ChatResponse:
|
def chat(self, chat_request: ChatRequest) -> ChatResponse:
|
||||||
@@ -243,6 +277,8 @@ class Llama2Model(BedrockModel):
|
|||||||
# TODO: Add validation
|
# TODO: Add validation
|
||||||
for i in range(start, len(messages)):
|
for i in range(start, len(messages)):
|
||||||
msg = messages[i]
|
msg = messages[i]
|
||||||
|
if not isinstance(msg.content, str):
|
||||||
|
raise HTTPException(status_code=400, detail="Content must be a string for Llama 2 model")
|
||||||
if msg.role == "user":
|
if msg.role == "user":
|
||||||
if end_turn:
|
if end_turn:
|
||||||
prompt += bos_token + "[INST] "
|
prompt += bos_token + "[INST] "
|
||||||
@@ -325,6 +361,8 @@ class MistralModel(BedrockModel):
|
|||||||
# TODO: Add validation
|
# TODO: Add validation
|
||||||
for i in range(start, len(messages)):
|
for i in range(start, len(messages)):
|
||||||
msg = messages[i]
|
msg = messages[i]
|
||||||
|
if not isinstance(msg.content, str):
|
||||||
|
raise HTTPException(status_code=400, detail="Content must be a string for Mistral/Mixtral model")
|
||||||
if msg.role == "user":
|
if msg.role == "user":
|
||||||
if end_turn:
|
if end_turn:
|
||||||
prompt += bos_token + "[INST] "
|
prompt += bos_token + "[INST] "
|
||||||
|
|||||||
@@ -17,7 +17,7 @@ router = APIRouter(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@router.post("/completions", response_model=ChatResponse | ChatStreamResponse)
|
@router.post("/completions", response_model=ChatResponse | ChatStreamResponse, response_model_exclude_none=True)
|
||||||
async def chat_completions(
|
async def chat_completions(
|
||||||
chat_request: Annotated[
|
chat_request: Annotated[
|
||||||
ChatRequest,
|
ChatRequest,
|
||||||
|
|||||||
@@ -16,10 +16,36 @@ class Models(BaseModel):
|
|||||||
data: list[Model] = []
|
data: list[Model] = []
|
||||||
|
|
||||||
|
|
||||||
|
class TextContent(BaseModel):
    """Text part of a list-valued chat message ``content``."""

    # Discriminator; the only accepted value is "text".
    type: Literal["text"] = "text"
    # The text carried by this part.
    text: str
|
||||||
|
|
||||||
|
|
||||||
|
class ImageUrl(BaseModel):
    """URL of an image together with an optional ``detail`` setting."""

    # Where the image can be found.
    url: str
    # Optional; defaults to "auto" when omitted.
    detail: str | None = "auto"
|
||||||
|
|
||||||
|
|
||||||
|
class ImageContent(BaseModel):
    """Image part of a list-valued chat message ``content``."""

    # Discriminator; the only accepted value is "image_url". The default has
    # to be that same value: pydantic does not validate defaults unless asked
    # to, so the earlier default of "image" yielded instances whose `type`
    # was outside the declared Literal.
    type: Literal["image_url"] = "image_url"
    # Location (and detail setting) of the image.
    image_url: ImageUrl
|
||||||
|
|
||||||
|
|
||||||
class ChatRequestMessage(BaseModel):
|
class ChatRequestMessage(BaseModel):
|
||||||
name: str | None = None
|
name: str | None = None
|
||||||
role: Literal["user", "assistant", "system"]
|
role: Literal["user", "assistant", "system"]
|
||||||
content: str
|
content: str | list[TextContent | ImageContent]
|
||||||
|
|
||||||
|
|
||||||
|
class Function(BaseModel):
    """Function declaration referenced by a ``Tool``."""

    # Name of the function.
    name: str
    # Optional free-text description.
    description: str | None = None
    # Parameter specification; accepted as any object without further typing.
    parameters: object
|
||||||
|
|
||||||
|
|
||||||
|
class Tool(BaseModel):
    """Tool entry of a chat request; wraps a ``Function`` declaration."""

    # Only "function" is accepted, and it is also the default.
    type: Literal["function"] = "function"
    # The declared function.
    function: Function
|
||||||
|
|
||||||
|
|
||||||
class ChatRequest(BaseModel):
|
class ChatRequest(BaseModel):
|
||||||
@@ -33,6 +59,8 @@ class ChatRequest(BaseModel):
|
|||||||
user: str | None = None # Not used
|
user: str | None = None # Not used
|
||||||
max_tokens: int | None = 2048
|
max_tokens: int | None = 2048
|
||||||
n: int | None = 1 # Not used
|
n: int | None = 1 # Not used
|
||||||
|
tools: list[Tool] | None = None
|
||||||
|
tool_choice: str | object = "auto"
|
||||||
|
|
||||||
|
|
||||||
class Usage(BaseModel):
|
class Usage(BaseModel):
|
||||||
@@ -41,10 +69,22 @@ class Usage(BaseModel):
|
|||||||
total_tokens: int
|
total_tokens: int
|
||||||
|
|
||||||
|
|
||||||
|
class ResponseFunction(BaseModel):
    """Function invocation carried by a ``ToolCall``."""

    # Name of the function.
    name: str
    # Arguments for the call, held as a string.
    arguments: str
|
||||||
|
|
||||||
|
|
||||||
|
class ToolCall(BaseModel):
    """Single tool call entry; wraps a ``ResponseFunction``."""

    # Identifier of this call.
    id: str
    # Only "function" is accepted, and it is also the default.
    type: Literal["function"] = "function"
    # The function name and arguments for this call.
    function: ResponseFunction
|
||||||
|
|
||||||
|
|
||||||
class ChatResponseMessage(BaseModel):
|
class ChatResponseMessage(BaseModel):
|
||||||
# tool_calls
|
# tool_calls
|
||||||
role: Literal["assistant"] | None = None
|
role: Literal["assistant"] | None = None
|
||||||
content: str | None = None
|
content: str | None = None
|
||||||
|
tool_calls: list[ToolCall] | None = None
|
||||||
|
|
||||||
|
|
||||||
class BaseChoice(BaseModel):
|
class BaseChoice(BaseModel):
|
||||||
|
|||||||
Reference in New Issue
Block a user