diff --git a/src/api/models/bedrock.py b/src/api/models/bedrock.py
index c345725..28c4b69 100644
--- a/src/api/models/bedrock.py
+++ b/src/api/models/bedrock.py
@@ -74,19 +74,27 @@ class BedrockModel(BaseChatModel, ABC):
         if DEBUG:
             logger.info("Invoke Bedrock Model: " + model_id)
             logger.info("Bedrock request body: " + body)
-        if with_stream:
-            return bedrock_runtime.invoke_model_with_response_stream(
+        try:
+            if with_stream:
+                return bedrock_runtime.invoke_model_with_response_stream(
+                    body=body,
+                    modelId=model_id,
+                    accept=self.accept,
+                    contentType=self.content_type,
+                )
+            return bedrock_runtime.invoke_model(
                 body=body,
                 modelId=model_id,
                 accept=self.accept,
                 contentType=self.content_type,
             )
-        return bedrock_runtime.invoke_model(
-            body=body,
-            modelId=model_id,
-            accept=self.accept,
-            contentType=self.content_type,
-        )
+        except bedrock_runtime.exceptions.ValidationException as e:
+            print("Validation Exception")
+            print(e)
+            raise HTTPException(status_code=400, detail=str(e))
+        except Exception as e:
+            logger.error(e)
+            raise HTTPException(status_code=500, detail=str(e))
 
     @staticmethod
     def merge_message(messages: list[dict]) -> list[dict]:
@@ -185,8 +193,8 @@ class ClaudeModel(BedrockModel):
 {tools}
 
 Please think if you need to use a tool or not for user's question, you must:
-1. Respond Y or N inside a <answer> xml tag first to indicate that.
-2. If a tool is needed, MUST respond a JSON object matching the following schema inside a <json> xml tag:
+1. Respond Y or N within <answer></answer> tags first to indicate that.
+2. If a tool is needed, MUST respond a JSON object matching the following schema within <json></json> tags:
 {{"name": $TOOL_NAME, "arguments": {{"$PARAMETER_NAME": "$PARAMETER_VALUE", ...}}}}
 
 3. If no tools is needed, respond with normal text."""
@@ -201,6 +209,7 @@ Please think if you need to use a tool or not for user's question, you must:
         converted_messages = []
         for message in chat_request.messages:
             if message.role == "system":
+                assert isinstance(message.content, str)
                 system_prompt += message.content + "\n"
             elif message.role == "user" and not isinstance(message.content, str):
                 converted_messages.append(
@@ -243,9 +252,9 @@ Please think if you need to use a tool or not for user's question, you must:
             system_prompt += self.tool_prompt.format(tools=tools_str)
             converted_messages.append({
                 'role': 'assistant',
-                'content': ''
+                'content': '<answer>'
             })
-            args["stop_sequences"] = ['</answer>']
+            args["stop_sequences"] = ['</json>']
         args["messages"] = self.merge_message(converted_messages)
         if system_prompt:
             if DEBUG:
@@ -267,10 +276,10 @@ Please think if you need to use a tool or not for user's question, you must:
         tools = None
         if chat_request.tools:
-            if message.startswith("<answer>Y"):
+            if message.startswith("Y"):
                 tools = self._parse_tool_message(message)
                 message = None
-            elif message.startswith("<answer>N"):
+            elif message.startswith("N"):
                 message = message[8:].lstrip("\n")
         return self._create_response(
             model=chat_request.model,
         )
@@ -283,6 +292,8 @@ Please think if you need to use a tool or not for user's question, you must:
     def chat_stream(self, chat_request: ChatRequest) -> AsyncIterable[bytes]:
+        if DEBUG:
+            logger.info("Raw request: " + chat_request.model_dump_json())
         response = self._invoke_model(
             args=self._parse_args(chat_request),
             model_id=chat_request.model,
             with_stream=True,
         )
@@ -321,7 +332,7 @@ Please think if you need to use a tool or not for user's question, you must:
                     tool_message += chunk_message
                     continue
                 if index < 3:
-                    # Ignore the <answer>N, which is 3 tokens
+                    # Ignore the N</answer>, which is 3 tokens
                     index += 1
                     continue
                 if first_token:
@@ -350,7 +361,7 @@ Please think if you need to use a tool or not for user's question, you must:
         if DEBUG:
             logger.info("Tool message: " + tool_message.replace("\n", " "))
         try:
-            tool_messages = tool_message[tool_message.rindex("<json>") + 6:]
+            tool_messages = tool_message[tool_message.rindex("<json>") + len("<json>"):]
             function = json.loads(tool_messages.replace("\n", " "))
             args = json.dumps(function.get("arguments", {}))
             function = ResponseFunction(
@@ -365,7 +376,7 @@ Please think if you need to use a tool or not for user's question, you must:
 
             ]
         except Exception as e:
-            logger.error("Failed to parse tool response")
+            logger.error("Failed to parse tool response: " + str(e))
             raise HTTPException(status_code=500, detail="Failed to parse tool response")
 
     def _get_base64_image(self, image_url: str) -> tuple[str, str]:
@@ -617,12 +628,20 @@ class BedrockEmbeddingsModel(BaseEmbeddingsModel, ABC):
         if DEBUG:
             logger.info("Invoke Bedrock Model: " + model_id)
             logger.info("Bedrock request body: " + body)
-        return bedrock_runtime.invoke_model(
-            body=body,
-            modelId=model_id,
-            accept=self.accept,
-            contentType=self.content_type,
-        )
+        try:
+            return bedrock_runtime.invoke_model(
+                body=body,
+                modelId=model_id,
+                accept=self.accept,
+                contentType=self.content_type,
+            )
+        except bedrock_runtime.exceptions.ValidationException as e:
+            print("Validation Exception")
+            print(e)
+            raise HTTPException(status_code=400, detail=str(e))
+        except Exception as e:
+            logger.error(e)
+            raise HTTPException(status_code=500, detail=str(e))
 
     def _create_response(
         self,
@@ -739,31 +758,33 @@ def get_model(model_id: str) -> BedrockModel:
     model_name = SUPPORTED_BEDROCK_MODELS.get(model_id, "")
     if DEBUG:
         logger.info("model name is " + model_name)
-    if model_name in ["Claude Instant", "Claude", "Claude 3 Sonnet", "Claude 3 Haiku", "Claude 3 Opus"]:
-        return ClaudeModel()
-    elif model_name in ["Llama 2 Chat 13B", "Llama 2 Chat 70B"]:
-        return Llama2Model()
-    elif model_name in ["Mistral 7B Instruct", "Mixtral 8x7B Instruct", "Mistral Large"]:
-        return MistralModel()
-    else:
-        logger.error("Unsupported model id " + model_id)
-        raise HTTPException(
-            status_code=500,
-            detail="Unsupported model id " + model_id,
-        )
+    # Not using startswith here in case of complex scenarios.
+    # The downside is having to update this every time a new model is supported.
+    match model_name:
+        case "Claude Instant" | "Claude" | "Claude 3 Sonnet" | "Claude 3 Haiku" | "Claude 3 Opus":
+            return ClaudeModel()
+        case "Llama 2 Chat 13B" | "Llama 2 Chat 70B":
+            return Llama2Model()
+        case "Mistral 7B Instruct" | "Mixtral 8x7B Instruct" | "Mistral Large":
+            return MistralModel()
+        case _:
+            logger.error("Unsupported model id " + model_id)
+            raise HTTPException(
+                status_code=400,
+                detail="Unsupported model id " + model_id,
+            )
 
 
 def get_embeddings_model(model_id: str) -> BedrockEmbeddingsModel:
     model_name = SUPPORTED_BEDROCK_EMBEDDING_MODELS.get(model_id, "")
     if DEBUG:
         logger.info("model name is " + model_name)
-    if model_name in ["Cohere Embed Multilingual", "Cohere Embed English"]:
-        return CohereEmbeddingsModel()
-    elif model_name in ["Titan Embeddings G1 - Text", "Titan Multimodal Embeddings G1"]:
-        return TitanEmbeddingsModel()
-    else:
-        logger.error("Unsupported model id " + model_id)
-        raise HTTPException(
-            status_code=500,
-            detail="Unsupported model id " + model_id,
-        )
+    match model_name:
+        case "Cohere Embed Multilingual" | "Cohere Embed English":
+            return CohereEmbeddingsModel()
+        case _:
+            logger.error("Unsupported model id " + model_id)
+            raise HTTPException(
+                status_code=400,
+                detail="Unsupported embedding model id " + model_id,
+            )
diff --git a/src/api/routers/chat.py b/src/api/routers/chat.py
index 3e6d69d..f2758fa 100644
--- a/src/api/routers/chat.py
+++ b/src/api/routers/chat.py
@@ -1,15 +1,13 @@
 from typing import Annotated
 
-from fastapi import APIRouter, Depends, Body, HTTPException
+from fastapi import APIRouter, Depends, Body
 from fastapi.responses import StreamingResponse
 
 from api.auth import api_key_auth
-from api.models import get_model, SUPPORTED_BEDROCK_MODELS
+from api.models import get_model
 from api.schema import ChatRequest, ChatResponse, ChatStreamResponse
 from api.setting import DEFAULT_MODEL
 
-router = APIRouter()
-
 router = APIRouter(
     prefix="/chat",
     dependencies=[Depends(api_key_auth)],
@@ -36,15 +34,11 @@ async def chat_completions(
 ):
     if chat_request.model.lower().startswith("gpt-"):
         chat_request.model = DEFAULT_MODEL
-    if chat_request.model not in SUPPORTED_BEDROCK_MODELS.keys():
-        raise HTTPException(status_code=400, detail="Unsupported Model Id " + chat_request.model)
-    try:
-        model = get_model(chat_request.model)
-
-        if chat_request.stream:
-            return StreamingResponse(
-                content=model.chat_stream(chat_request), media_type="text/event-stream"
-            )
-        return model.chat(chat_request)
-    except ValueError as e:
-        raise HTTPException(status_code=400, detail=str(e))
+
+    # An exception will be raised if the model is not supported.
+    model = get_model(chat_request.model)
+    if chat_request.stream:
+        return StreamingResponse(
+            content=model.chat_stream(chat_request), media_type="text/event-stream"
+        )
+    return model.chat(chat_request)
diff --git a/src/api/routers/embeddings.py b/src/api/routers/embeddings.py
index 8780c79..135fc27 100644
--- a/src/api/routers/embeddings.py
+++ b/src/api/routers/embeddings.py
@@ -1,14 +1,12 @@
 from typing import Annotated
 
-from fastapi import APIRouter, Depends, Body, HTTPException
+from fastapi import APIRouter, Depends, Body
 
 from api.auth import api_key_auth
-from api.models import get_embeddings_model, SUPPORTED_BEDROCK_EMBEDDING_MODELS
+from api.models import get_embeddings_model
 from api.schema import EmbeddingsRequest, EmbeddingsResponse
 from api.setting import DEFAULT_EMBEDDING_MODEL
 
-router = APIRouter()
-
 router = APIRouter(
     prefix="/embeddings",
     dependencies=[Depends(api_key_auth)],
@@ -33,10 +31,6 @@ async def embeddings(
 ):
     if embeddings_request.model.lower().startswith("text-embedding-"):
         embeddings_request.model = DEFAULT_EMBEDDING_MODEL
-    if embeddings_request.model not in SUPPORTED_BEDROCK_EMBEDDING_MODELS.keys():
-        raise HTTPException(status_code=400, detail="Unsupported Model Id " + embeddings_request.model)
-    try:
-        model = get_embeddings_model(embeddings_request.model)
-        return model.embed(embeddings_request)
-    except ValueError as e:
-        raise HTTPException(status_code=400, detail=str(e))
+    # An exception will be raised if the model is not supported.
+    model = get_embeddings_model(embeddings_request.model)
+    return model.embed(embeddings_request)
diff --git a/src/api/routers/model.py b/src/api/routers/model.py
index 1ca5800..2640d7d 100644
--- a/src/api/routers/model.py
+++ b/src/api/routers/model.py
@@ -6,8 +6,6 @@ from api.auth import api_key_auth
 from api.models import SUPPORTED_BEDROCK_MODELS, SUPPORTED_BEDROCK_EMBEDDING_MODELS
 from api.schema import Models, Model
 
-router = APIRouter()
-
 router = APIRouter(
     prefix="/models",
     dependencies=[Depends(api_key_auth)],