Add multimodal support
This commit is contained in:
@@ -1,9 +1,12 @@
|
||||
import base64
|
||||
import json
|
||||
import logging
|
||||
from typing import AsyncIterable, Iterable
|
||||
|
||||
import boto3
|
||||
import requests
|
||||
import tiktoken
|
||||
from fastapi import HTTPException
|
||||
|
||||
from api.models.base import BaseChatModel, BaseEmbeddingsModel
|
||||
from api.schema import (
|
||||
@@ -20,7 +23,7 @@ from api.schema import (
|
||||
EmbeddingsRequest,
|
||||
EmbeddingsResponse,
|
||||
EmbeddingsUsage,
|
||||
Embedding,
|
||||
Embedding, TextContent,
|
||||
)
|
||||
from api.setting import DEBUG, AWS_REGION
|
||||
|
||||
@@ -147,6 +150,44 @@ def get_model(model_id: str) -> BedrockModel:
|
||||
class ClaudeModel(BedrockModel):
|
||||
anthropic_version = "bedrock-2023-05-31"
|
||||
|
||||
def _get_base64_image(self, image_url: str) -> str:
    """Fetch an image over HTTP and return its contents base64-encoded.

    Args:
        image_url: Publicly reachable URL of the image.

    Returns:
        The raw image bytes encoded as a base64 ASCII string, suitable for
        the Anthropic Claude "source.data" field.

    Raises:
        HTTPException: 500 if the URL cannot be reached or does not return
            HTTP 200.
    """
    try:
        # A bounded timeout prevents a slow or unreachable host from
        # hanging the whole chat request indefinitely; connection-level
        # failures are mapped to the same 500 as a bad status code.
        response = requests.get(image_url, timeout=10)
    except requests.RequestException:
        raise HTTPException(status_code=500, detail="Unable to access the image url")
    if response.status_code == 200:
        # Encode the image content as base64 and return it as text.
        return base64.b64encode(response.content).decode('utf-8')
    raise HTTPException(status_code=500, detail="Unable to access the image url")
|
||||
|
||||
def _parse_messages(self, messages: list[ChatRequestMessage]) -> list[dict]:
    """Convert chat messages into the Anthropic Claude messages format.

    String content passes through unchanged; multimodal content lists are
    expanded part-by-part: text parts are dumped as-is, image parts are
    downloaded and embedded as base64.

    Refer to: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html
    """
    result = []
    for message in messages:
        content = message.content
        if isinstance(content, str):
            # Plain-text message: forward the string without wrapping.
            result.append({"role": message.role, "content": content})
        else:
            # Multimodal message: build Claude's content-part schema.
            parts = []
            for item in content:
                if isinstance(item, TextContent):
                    parts.append(item.model_dump())
                else:
                    # NOTE(review): media_type is hardcoded to JPEG even for
                    # PNG/GIF/WebP sources — confirm callers only send JPEGs.
                    parts.append(
                        {
                            "type": "image",
                            "source": {
                                "type": "base64",
                                "media_type": "image/jpeg",
                                "data": self._get_base64_image(item.image_url.url),
                            },
                        }
                    )
            result.append({"role": message.role, "content": parts})
    return result
|
||||
|
||||
def _parse_args(self, chat_request: ChatRequest) -> dict:
|
||||
args = {
|
||||
"anthropic_version": self.anthropic_version,
|
||||
@@ -154,18 +195,11 @@ class ClaudeModel(BedrockModel):
|
||||
"top_p": chat_request.top_p,
|
||||
"temperature": chat_request.temperature,
|
||||
}
|
||||
start = 0
|
||||
if chat_request.messages[0].role == "system":
|
||||
args["system"] = chat_request.messages[0].content
|
||||
args["messages"] = [
|
||||
{"role": msg.role, "content": msg.content}
|
||||
for msg in chat_request.messages[1:]
|
||||
]
|
||||
else:
|
||||
args["messages"] = [
|
||||
{"role": msg.role, "content": msg.content}
|
||||
for msg in chat_request.messages
|
||||
]
|
||||
|
||||
start = 1
|
||||
args["messages"] = self._parse_messages(chat_request.messages[start:])
|
||||
return args
|
||||
|
||||
def chat(self, chat_request: ChatRequest) -> ChatResponse:
|
||||
@@ -243,6 +277,8 @@ class Llama2Model(BedrockModel):
|
||||
# TODO: Add validation
|
||||
for i in range(start, len(messages)):
|
||||
msg = messages[i]
|
||||
if not isinstance(msg.content, str):
|
||||
raise HTTPException(status_code=400, detail="Content must be a string for Llama 2 model")
|
||||
if msg.role == "user":
|
||||
if end_turn:
|
||||
prompt += bos_token + "[INST] "
|
||||
@@ -325,6 +361,8 @@ class MistralModel(BedrockModel):
|
||||
# TODO: Add validation
|
||||
for i in range(start, len(messages)):
|
||||
msg = messages[i]
|
||||
if not isinstance(msg.content, str):
|
||||
raise HTTPException(status_code=400, detail="Content must be a string for Mistral/Mixtral model")
|
||||
if msg.role == "user":
|
||||
if end_turn:
|
||||
prompt += bos_token + "[INST] "
|
||||
|
||||
Reference in New Issue
Block a user