diff --git a/src/api/models/bedrock.py b/src/api/models/bedrock.py
index 7f6035c..f5bd865 100644
--- a/src/api/models/bedrock.py
+++ b/src/api/models/bedrock.py
@@ -1,9 +1,12 @@
+import base64
 import json
 import logging
 from typing import AsyncIterable, Iterable
 
 import boto3
+import requests
 import tiktoken
+from fastapi import HTTPException
 
 from api.models.base import BaseChatModel, BaseEmbeddingsModel
 from api.schema import (
@@ -20,7 +23,7 @@ from api.schema import (
     EmbeddingsRequest,
     EmbeddingsResponse,
     EmbeddingsUsage,
-    Embedding,
+    Embedding, TextContent,
 )
 
 from api.setting import DEBUG, AWS_REGION
@@ -147,6 +150,44 @@ def get_model(model_id: str) -> BedrockModel:
 class ClaudeModel(BedrockModel):
     anthropic_version = "bedrock-2023-05-31"
 
+    def _get_base64_image(self, image_url: str):
+        # Send a request to the image URL
+        response = requests.get(image_url)
+        # Check if the request was successful
+        if response.status_code == 200:
+            # Get the image content
+            image_content = response.content
+            # Encode the image content as base64
+            base64_image = base64.b64encode(image_content)
+            return base64_image.decode('utf-8')
+        else:
+            raise HTTPException(status_code=500, detail="Unable to access the image url")
+
+    def _parse_messages(self, messages: list[ChatRequestMessage]) -> list[dict]:
+        # Refer to: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html
+        converted_messages = []
+        for msg in messages:
+            if isinstance(msg.content, str):
+                converted_messages.append({"role": msg.role, "content": msg.content})
+                continue
+
+            content_parts = []
+            for part in msg.content:
+                if isinstance(part, TextContent):
+                    content_parts.append(part.model_dump())
+                else:
+                    content_parts.append({
+                        "type": "image",
+                        "source": {
+                            "type": "base64",
+                            "media_type": "image/jpeg",
+                            "data": self._get_base64_image(part.image_url.url)
+                        }
+                    })
+
+            converted_messages.append({"role": msg.role, "content": content_parts})
+        return converted_messages
+
     def _parse_args(self, chat_request: ChatRequest) -> dict:
         args = {
             "anthropic_version": self.anthropic_version,
@@ -154,18 +195,11 @@ class ClaudeModel(BedrockModel):
             "top_p": chat_request.top_p,
             "temperature": chat_request.temperature,
         }
+        start = 0
        if chat_request.messages[0].role == "system":
             args["system"] = chat_request.messages[0].content
-            args["messages"] = [
-                {"role": msg.role, "content": msg.content}
-                for msg in chat_request.messages[1:]
-            ]
-        else:
-            args["messages"] = [
-                {"role": msg.role, "content": msg.content}
-                for msg in chat_request.messages
-            ]
-
+            start = 1
+        args["messages"] = self._parse_messages(chat_request.messages[start:])
         return args
 
     def chat(self, chat_request: ChatRequest) -> ChatResponse:
@@ -243,6 +277,8 @@ class Llama2Model(BedrockModel):
         # TODO: Add validation
         for i in range(start, len(messages)):
             msg = messages[i]
+            if not isinstance(msg.content, str):
+                raise HTTPException(status_code=400, detail="Content must be a string for Llama 2 model")
             if msg.role == "user":
                 if end_turn:
                     prompt += bos_token + "[INST] "
@@ -325,6 +361,8 @@ class MistralModel(BedrockModel):
         # TODO: Add validation
         for i in range(start, len(messages)):
             msg = messages[i]
+            if not isinstance(msg.content, str):
+                raise HTTPException(status_code=400, detail="Content must be a string for Mistral/Mixtral model")
             if msg.role == "user":
                 if end_turn:
                     prompt += bos_token + "[INST] "
diff --git a/src/api/routers/chat.py b/src/api/routers/chat.py
index efc3ced..3e6d69d 100644
--- a/src/api/routers/chat.py
+++ b/src/api/routers/chat.py
@@ -17,7 +17,7 @@ router = APIRouter(
 )
 
 
-@router.post("/completions", response_model=ChatResponse | ChatStreamResponse)
+@router.post("/completions", response_model=ChatResponse | ChatStreamResponse, response_model_exclude_none=True)
 async def chat_completions(
     chat_request: Annotated[
         ChatRequest,
diff --git a/src/api/schema.py b/src/api/schema.py
index 5791441..aa87a58 100644
--- a/src/api/schema.py
+++ b/src/api/schema.py
@@ -16,10 +16,36 @@ class Models(BaseModel):
     data: list[Model] = []
 
 
+class TextContent(BaseModel):
+    type: Literal["text"] = "text"
+    text: str
+
+
+class ImageUrl(BaseModel):
+    url: str
+    detail: str | None = "auto"
+
+
+class ImageContent(BaseModel):
+    type: Literal["image_url"] = "image_url"
+    image_url: ImageUrl
+
+
 class ChatRequestMessage(BaseModel):
     name: str | None = None
     role: Literal["user", "assistant", "system"]
-    content: str
+    content: str | list[TextContent | ImageContent]
+
+
+class Function(BaseModel):
+    name: str
+    description: str | None = None
+    parameters: object
+
+
+class Tool(BaseModel):
+    type: Literal["function"] = "function"
+    function: Function
 
 
 class ChatRequest(BaseModel):
@@ -33,6 +59,8 @@ class ChatRequest(BaseModel):
     user: str | None = None  # Not used
     max_tokens: int | None = 2048
     n: int | None = 1  # Not used
+    tools: list[Tool] | None = None
+    tool_choice: str | object = "auto"
 
 
 class Usage(BaseModel):
@@ -41,10 +69,22 @@ class Usage(BaseModel):
     total_tokens: int
 
 
+class ResponseFunction(BaseModel):
+    name: str
+    arguments: str
+
+
+class ToolCall(BaseModel):
+    id: str
+    type: Literal["function"] = "function"
+    function: ResponseFunction
+
+
 class ChatResponseMessage(BaseModel):
     # tool_calls
     role: Literal["assistant"] | None = None
     content: str | None = None
+    tool_calls: list[ToolCall] | None = None
 
 
 class BaseChoice(BaseModel):