Add multimodal support

This commit is contained in:
Aiden Dai
2024-04-02 13:10:15 +08:00
parent 31ae10a275
commit e49a579a41
3 changed files with 91 additions and 13 deletions

View File

@@ -1,9 +1,12 @@
import base64
import json
import logging
from typing import AsyncIterable, Iterable
import boto3
import requests
import tiktoken
from fastapi import HTTPException
from api.models.base import BaseChatModel, BaseEmbeddingsModel
from api.schema import (
@@ -20,7 +23,7 @@ from api.schema import (
EmbeddingsRequest,
EmbeddingsResponse,
EmbeddingsUsage,
Embedding,
Embedding, TextContent,
)
from api.setting import DEBUG, AWS_REGION
@@ -147,6 +150,44 @@ def get_model(model_id: str) -> BedrockModel:
class ClaudeModel(BedrockModel):
anthropic_version = "bedrock-2023-05-31"
def _get_base64_image(self, image_url: str, timeout: float = 10.0) -> str:
    """Download an image over HTTP and return its bytes base64-encoded as text.

    Args:
        image_url: HTTP(S) URL of the image to fetch.
        timeout: Seconds to wait for the HTTP response. The original call had
            no timeout, so a stalled server could hang the request forever.

    Returns:
        The image content encoded as a base64 string (UTF-8 decoded).

    Raises:
        HTTPException: status 500 when the URL cannot be reached or does not
            answer with HTTP 200.
    """
    try:
        # Bound the request: without a timeout, requests waits indefinitely.
        response = requests.get(image_url, timeout=timeout)
    except requests.RequestException:
        # Surface connection/timeout failures as the same 500 the function
        # already uses for non-200 responses, instead of leaking a raw
        # requests exception to the caller.
        raise HTTPException(status_code=500, detail="Unable to access the image url")
    if response.status_code != 200:
        raise HTTPException(status_code=500, detail="Unable to access the image url")
    # b64encode returns bytes; decode so the value can go into a JSON payload.
    return base64.b64encode(response.content).decode("utf-8")
def _parse_messages(self, messages: list[ChatRequestMessage]) -> list[dict]:
    """Convert OpenAI-style chat messages to the Anthropic messages format.

    Refer to: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html
    """

    def _convert_part(part) -> dict:
        # Text parts map straight through; any other part is treated as an
        # image whose URL is fetched and inlined as base64 data.
        if isinstance(part, TextContent):
            return part.model_dump()
        return {
            "type": "image",
            "source": {
                "type": "base64",
                # NOTE(review): media type is hard-coded to JPEG regardless of
                # the actual image format — confirm callers only send JPEG.
                "media_type": "image/jpeg",
                "data": self._get_base64_image(part.image_url.url),
            },
        }

    result = []
    for message in messages:
        if isinstance(message.content, str):
            # Plain string content needs no part-level conversion.
            result.append({"role": message.role, "content": message.content})
        else:
            result.append(
                {
                    "role": message.role,
                    "content": [_convert_part(p) for p in message.content],
                }
            )
    return result
def _parse_args(self, chat_request: ChatRequest) -> dict:
args = {
"anthropic_version": self.anthropic_version,
@@ -154,18 +195,11 @@ class ClaudeModel(BedrockModel):
"top_p": chat_request.top_p,
"temperature": chat_request.temperature,
}
start = 0
if chat_request.messages[0].role == "system":
args["system"] = chat_request.messages[0].content
args["messages"] = [
{"role": msg.role, "content": msg.content}
for msg in chat_request.messages[1:]
]
else:
args["messages"] = [
{"role": msg.role, "content": msg.content}
for msg in chat_request.messages
]
start = 1
args["messages"] = self._parse_messages(chat_request.messages[start:])
return args
def chat(self, chat_request: ChatRequest) -> ChatResponse:
@@ -243,6 +277,8 @@ class Llama2Model(BedrockModel):
# TODO: Add validation
for i in range(start, len(messages)):
msg = messages[i]
if not isinstance(msg.content, str):
raise HTTPException(status_code=400, detail="Content must be a string for Llama 2 model")
if msg.role == "user":
if end_turn:
prompt += bos_token + "[INST] "
@@ -325,6 +361,8 @@ class MistralModel(BedrockModel):
# TODO: Add validation
for i in range(start, len(messages)):
msg = messages[i]
if not isinstance(msg.content, str):
raise HTTPException(status_code=400, detail="Content must be a string for Mistral/Mixtral model")
if msg.role == "user":
if end_turn:
prompt += bos_token + "[INST] "