diff --git a/src/api/models/bedrock.py b/src/api/models/bedrock.py
index c345725..28c4b69 100644
--- a/src/api/models/bedrock.py
+++ b/src/api/models/bedrock.py
@@ -74,19 +74,27 @@ class BedrockModel(BaseChatModel, ABC):
if DEBUG:
logger.info("Invoke Bedrock Model: " + model_id)
logger.info("Bedrock request body: " + body)
- if with_stream:
- return bedrock_runtime.invoke_model_with_response_stream(
+ try:
+ if with_stream:
+ return bedrock_runtime.invoke_model_with_response_stream(
+ body=body,
+ modelId=model_id,
+ accept=self.accept,
+ contentType=self.content_type,
+ )
+ return bedrock_runtime.invoke_model(
body=body,
modelId=model_id,
accept=self.accept,
contentType=self.content_type,
)
- return bedrock_runtime.invoke_model(
- body=body,
- modelId=model_id,
- accept=self.accept,
- contentType=self.content_type,
- )
+ except bedrock_runtime.exceptions.ValidationException as e:
+ print("Validation Exception")
+ print(e)
+ raise HTTPException(status_code=400, detail=str(e))
+ except Exception as e:
+ logger.error(e)
+ raise HTTPException(status_code=500, detail=str(e))
@staticmethod
def merge_message(messages: list[dict]) -> list[dict]:
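The new `try`/`except` wrapper maps Bedrock client failures onto HTTP status codes instead of letting botocore exceptions bubble up as unhandled 500s. A minimal sketch of the same pattern in isolation (the helper name is illustrative, not part of this codebase; `client` is assumed to be a boto3 `bedrock-runtime` client):

```python
from fastapi import HTTPException


def invoke_with_http_errors(client, **kwargs):
    """Map Bedrock client failures onto HTTP status codes.

    Mirrors the pattern above: validation problems become a 400
    (the caller's fault), anything else becomes a 500.
    """
    try:
        return client.invoke_model(**kwargs)
    except client.exceptions.ValidationException as e:
        # Bad request body, parameters, or model id.
        raise HTTPException(status_code=400, detail=str(e))
    except Exception as e:
        # Anything unexpected surfaces as a server error.
        raise HTTPException(status_code=500, detail=str(e))
```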
@@ -185,8 +193,8 @@ class ClaudeModel(BedrockModel):
{tools}
Please think if you need to use a tool or not for user's question, you must:
-1. Respond Y or N inside a <t> xml tag first to indicate that.
-2. If a tool is needed, MUST respond a JSON object matching the following schema inside a <tool> xml tag:
+1. Respond Y or N within <need></need> tags first to indicate that.
+2. If a tool is needed, MUST respond with a JSON object matching the following schema within <tool></tool> tags:
{{"name": $TOOL_NAME, "arguments": {{"$PARAMETER_NAME": "$PARAMETER_VALUE", ...}}}}
3. If no tools is needed, respond with normal text."""
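To make the tool protocol concrete, here is what well-behaved completions would look like under this prompt, given the assistant prefill and stop sequence added a few hunks below (illustrative sketches, not captured model output):

```python
# Illustrative only. The assistant turn is prefilled with "<need>", so the
# model's reply begins immediately after that opening tag.

# Case 1: a tool is needed. Generation halts at the "</tool>" stop sequence,
# leaving the JSON payload after the last "<tool>" tag.
tool_reply = 'Y</need>\n<tool>\n{"name": "get_weather", "arguments": {"city": "Seattle"}}\n'

# Case 2: no tool is needed. The reply is "N</need>" followed by plain text;
# the handler strips that 8-character prefix before returning the message.
plain_reply = "N</need>\nThe capital of France is Paris."
assert plain_reply[8:].lstrip("\n") == "The capital of France is Paris."
```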
@@ -201,6 +209,7 @@ Please think if you need to use a tool or not for user's question, you must:
converted_messages = []
for message in chat_request.messages:
if message.role == "system":
+ assert isinstance(message.content, str)
system_prompt += message.content + "\n"
elif message.role == "user" and not isinstance(message.content, str):
converted_messages.append(
@@ -243,9 +252,9 @@ Please think if you need to use a tool or not for user's question, you must:
system_prompt += self.tool_prompt.format(tools=tools_str)
converted_messages.append({
'role': 'assistant',
- 'content': '<t>'
+ 'content': '<need>'
})
- args["stop_sequences"] = ['</t>']
+ args["stop_sequences"] = ['</tool>']
args["messages"] = self.merge_message(converted_messages)
if system_prompt:
if DEBUG:
@@ -267,10 +276,10 @@ Please think if you need to use a tool or not for user's question, you must:
tools = None
if chat_request.tools:
- if message.startswith("Y</t>"):
+ if message.startswith("Y"):
tools = self._parse_tool_message(message)
message = None
- elif message.startswith("N</t>"):
+ elif message.startswith("N"):
message = message[8:].lstrip("\n")
return self._create_response(
model=chat_request.model,
@@ -283,6 +292,8 @@ Please think if you need to use a tool or not for user's question, you must:
)
def chat_stream(self, chat_request: ChatRequest) -> AsyncIterable[bytes]:
+ if DEBUG:
+ logger.info("Raw request: " + chat_request.model_dump_json())
response = self._invoke_model(
args=self._parse_args(chat_request),
model_id=chat_request.model,
@@ -321,7 +332,7 @@ Please think if you need to use a tool or not for user's question, you must:
tool_message += chunk_message
continue
if index < 3:
- # Ignore the N</t>, which is 3 tokens
+ # Ignore the N</need>, which is 3 tokens
index += 1
continue
if first_token:
@@ -350,7 +361,7 @@ Please think if you need to use a tool or not for user's question, you must:
if DEBUG:
logger.info("Tool message: " + tool_message.replace("\n", " "))
try:
- tool_messages = tool_message[tool_message.rindex("<tool>") + 6:]
+ tool_messages = tool_message[tool_message.rindex("<tool>") + len("<tool>"):]
function = json.loads(tool_messages.replace("\n", " "))
args = json.dumps(function.get("arguments", {}))
function = ResponseFunction(
@@ -365,7 +376,7 @@ Please think if you need to use a tool or not for user's question, you must:
]
except Exception as e:
- logger.error("Failed to parse tool response")
+ logger.error("Failed to parse tool response" + str(e))
raise HTTPException(status_code=500, detail="Failed to parse tool response")
def _get_base64_image(self, image_url: str) -> tuple[str, str]:
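The parsing step above can be read as a small standalone helper: take everything after the last opening tag and decode it as JSON. A sketch with a hypothetical name, using the same `rindex` + `json.loads` logic:

```python
import json


def extract_tool_call(raw: str, open_tag: str = "<tool>") -> dict:
    """Return the JSON object that follows the last open_tag in raw.

    Raises ValueError if the tag is absent and json.JSONDecodeError
    if the payload is malformed -- the caller turns either into a 500.
    """
    payload = raw[raw.rindex(open_tag) + len(open_tag):]
    return json.loads(payload.replace("\n", " "))


# extract_tool_call('Y</need>\n<tool>\n{"name": "f", "arguments": {}}')
# -> {'name': 'f', 'arguments': {}}
```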
@@ -617,12 +628,20 @@ class BedrockEmbeddingsModel(BaseEmbeddingsModel, ABC):
if DEBUG:
logger.info("Invoke Bedrock Model: " + model_id)
logger.info("Bedrock request body: " + body)
- return bedrock_runtime.invoke_model(
- body=body,
- modelId=model_id,
- accept=self.accept,
- contentType=self.content_type,
- )
+ try:
+ return bedrock_runtime.invoke_model(
+ body=body,
+ modelId=model_id,
+ accept=self.accept,
+ contentType=self.content_type,
+ )
+ except bedrock_runtime.exceptions.ValidationException as e:
+ print("Validation Exception")
+ print(e)
+ raise HTTPException(status_code=400, detail=str(e))
+ except Exception as e:
+ logger.error(e)
+ raise HTTPException(status_code=500, detail=str(e))
def _create_response(
self,
@@ -739,31 +758,33 @@ def get_model(model_id: str) -> BedrockModel:
model_name = SUPPORTED_BEDROCK_MODELS.get(model_id, "")
if DEBUG:
logger.info("model name is " + model_name)
- if model_name in ["Claude Instant", "Claude", "Claude 3 Sonnet", "Claude 3 Haiku", "Claude 3 Opus"]:
- return ClaudeModel()
- elif model_name in ["Llama 2 Chat 13B", "Llama 2 Chat 70B"]:
- return Llama2Model()
- elif model_name in ["Mistral 7B Instruct", "Mixtral 8x7B Instruct", "Mistral Large"]:
- return MistralModel()
- else:
- logger.error("Unsupported model id " + model_id)
- raise HTTPException(
- status_code=500,
- detail="Unsupported model id " + model_id,
- )
+ # Not using str.startswith here, in case of more complex matching scenarios.
+ # The downside is that this must be updated every time a new model is supported.
+ match model_name:
+ case "Claude Instant" | "Claude" | "Claude 3 Sonnet" | "Claude 3 Haiku" | "Claude 3 Opus":
+ return ClaudeModel()
+ case "Llama 2 Chat 13B" | "Llama 2 Chat 70B":
+ return Llama2Model()
+ case "Mistral 7B Instruct" | "Mixtral 8x7B Instruct" | "Mistral Large":
+ return MistralModel()
+ case _:
+ logger.error("Unsupported model id " + model_id)
+ raise HTTPException(
+ status_code=400,
+ detail="Unsupported model id " + model_id,
+ )
def get_embeddings_model(model_id: str) -> BedrockEmbeddingsModel:
model_name = SUPPORTED_BEDROCK_EMBEDDING_MODELS.get(model_id, "")
if DEBUG:
logger.info("model name is " + model_name)
- if model_name in ["Cohere Embed Multilingual", "Cohere Embed English"]:
- return CohereEmbeddingsModel()
- elif model_name in ["Titan Embeddings G1 - Text", "Titan Multimodal Embeddings G1"]:
- return TitanEmbeddingsModel()
- else:
- logger.error("Unsupported model id " + model_id)
- raise HTTPException(
- status_code=500,
- detail="Unsupported model id " + model_id,
- )
+ match model_name:
+ case "Cohere Embed Multilingual" | "Cohere Embed English":
+ return CohereEmbeddingsModel()
+ case _:
+ logger.error("Unsupported model id " + model_id)
+ raise HTTPException(
+ status_code=400,
+ detail="Unsupported embedding model id " + model_id,
+ )
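With dispatch centralized here, callers no longer need their own membership checks: a supported id resolves to a model class and an unknown id fails fast with a 400. A hedged usage sketch (the Claude model id is a typical Bedrock id; the actual keys live in SUPPORTED_BEDROCK_MODELS, which this diff does not show):

```python
from fastapi import HTTPException
from api.models import get_model

# A supported id resolves to a concrete model class, e.g.:
# get_model("anthropic.claude-3-sonnet-20240229-v1:0")  # -> ClaudeModel

# An unknown id now raises a client error instead of a 500:
try:
    get_model("no-such-model")
except HTTPException as e:
    assert e.status_code == 400
```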
diff --git a/src/api/routers/chat.py b/src/api/routers/chat.py
index 3e6d69d..f2758fa 100644
--- a/src/api/routers/chat.py
+++ b/src/api/routers/chat.py
@@ -1,15 +1,13 @@
from typing import Annotated
-from fastapi import APIRouter, Depends, Body, HTTPException
+from fastapi import APIRouter, Depends, Body
from fastapi.responses import StreamingResponse
from api.auth import api_key_auth
-from api.models import get_model, SUPPORTED_BEDROCK_MODELS
+from api.models import get_model
from api.schema import ChatRequest, ChatResponse, ChatStreamResponse
from api.setting import DEFAULT_MODEL
-router = APIRouter()
-
router = APIRouter(
prefix="/chat",
dependencies=[Depends(api_key_auth)],
@@ -36,15 +34,11 @@ async def chat_completions(
):
if chat_request.model.lower().startswith("gpt-"):
chat_request.model = DEFAULT_MODEL
- if chat_request.model not in SUPPORTED_BEDROCK_MODELS.keys():
- raise HTTPException(status_code=400, detail="Unsupported Model Id " + chat_request.model)
- try:
- model = get_model(chat_request.model)
-
- if chat_request.stream:
- return StreamingResponse(
- content=model.chat_stream(chat_request), media_type="text/event-stream"
- )
- return model.chat(chat_request)
- except ValueError as e:
- raise HTTPException(status_code=400, detail=str(e))
+
+ # get_model raises an HTTPException if the model is not supported.
+ model = get_model(chat_request.model)
+ if chat_request.stream:
+ return StreamingResponse(
+ content=model.chat_stream(chat_request), media_type="text/event-stream"
+ )
+ return model.chat(chat_request)
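Because the handler returns `text/event-stream`, a client can consume the stream line by line. A minimal sketch with `requests`; the base URL, mounted path, and API key are placeholders for however the gateway is deployed:

```python
import requests

BASE_URL = "http://localhost:8000/api/v1"  # placeholder deployment URL
headers = {"Authorization": "Bearer YOUR_API_KEY"}  # checked by api_key_auth

payload = {
    "model": "gpt-4",  # rewritten to DEFAULT_MODEL by the handler above
    "messages": [{"role": "user", "content": "Hello"}],
    "stream": True,
}

with requests.post(f"{BASE_URL}/chat/completions", json=payload,
                   headers=headers, stream=True) as resp:
    resp.raise_for_status()
    for line in resp.iter_lines():
        if line:
            print(line.decode())  # each event arrives as a "data: {...}" line
```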
diff --git a/src/api/routers/embeddings.py b/src/api/routers/embeddings.py
index 8780c79..135fc27 100644
--- a/src/api/routers/embeddings.py
+++ b/src/api/routers/embeddings.py
@@ -1,14 +1,12 @@
from typing import Annotated
-from fastapi import APIRouter, Depends, Body, HTTPException
+from fastapi import APIRouter, Depends, Body
from api.auth import api_key_auth
-from api.models import get_embeddings_model, SUPPORTED_BEDROCK_EMBEDDING_MODELS
+from api.models import get_embeddings_model
from api.schema import EmbeddingsRequest, EmbeddingsResponse
from api.setting import DEFAULT_EMBEDDING_MODEL
-router = APIRouter()
-
router = APIRouter(
prefix="/embeddings",
dependencies=[Depends(api_key_auth)],
@@ -33,10 +31,6 @@ async def embeddings(
):
if embeddings_request.model.lower().startswith("text-embedding-"):
embeddings_request.model = DEFAULT_EMBEDDING_MODEL
- if embeddings_request.model not in SUPPORTED_BEDROCK_EMBEDDING_MODELS.keys():
- raise HTTPException(status_code=400, detail="Unsupported Model Id " + embeddings_request.model)
- try:
- model = get_embeddings_model(embeddings_request.model)
- return model.embed(embeddings_request)
- except ValueError as e:
- raise HTTPException(status_code=400, detail=str(e))
+ # get_embeddings_model raises an HTTPException if the model is not supported.
+ model = get_embeddings_model(embeddings_request.model)
+ return model.embed(embeddings_request)
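The embeddings route follows the same shape without streaming. A sketch under the same placeholder assumptions, expecting an OpenAI-style response body:

```python
import requests

payload = {
    "model": "text-embedding-ada-002",  # mapped to DEFAULT_EMBEDDING_MODEL above
    "input": "The quick brown fox",
}
resp = requests.post(
    "http://localhost:8000/api/v1/embeddings",  # placeholder URL
    json=payload,
    headers={"Authorization": "Bearer YOUR_API_KEY"},
)
resp.raise_for_status()
# Assuming an OpenAI-compatible body: data[0].embedding is the vector.
print(len(resp.json()["data"][0]["embedding"]))
```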
diff --git a/src/api/routers/model.py b/src/api/routers/model.py
index 1ca5800..2640d7d 100644
--- a/src/api/routers/model.py
+++ b/src/api/routers/model.py
@@ -6,8 +6,6 @@ from api.auth import api_key_auth
from api.models import SUPPORTED_BEDROCK_MODELS, SUPPORTED_BEDROCK_EMBEDDING_MODELS
from api.schema import Models, Model
-router = APIRouter()
-
router = APIRouter(
prefix="/models",
dependencies=[Depends(api_key_auth)],