Clean up code
This commit is contained in:
@@ -74,6 +74,7 @@ class BedrockModel(BaseChatModel, ABC):
|
|||||||
if DEBUG:
|
if DEBUG:
|
||||||
logger.info("Invoke Bedrock Model: " + model_id)
|
logger.info("Invoke Bedrock Model: " + model_id)
|
||||||
logger.info("Bedrock request body: " + body)
|
logger.info("Bedrock request body: " + body)
|
||||||
|
try:
|
||||||
if with_stream:
|
if with_stream:
|
||||||
return bedrock_runtime.invoke_model_with_response_stream(
|
return bedrock_runtime.invoke_model_with_response_stream(
|
||||||
body=body,
|
body=body,
|
||||||
@@ -87,6 +88,13 @@ class BedrockModel(BaseChatModel, ABC):
|
|||||||
accept=self.accept,
|
accept=self.accept,
|
||||||
contentType=self.content_type,
|
contentType=self.content_type,
|
||||||
)
|
)
|
||||||
|
except bedrock_runtime.exceptions.ValidationException as e:
|
||||||
|
print("Validation Exception")
|
||||||
|
print(e)
|
||||||
|
raise HTTPException(status_code=400, detail=str(e))
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(e)
|
||||||
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def merge_message(messages: list[dict]) -> list[dict]:
|
def merge_message(messages: list[dict]) -> list[dict]:
|
||||||
@@ -185,8 +193,8 @@ class ClaudeModel(BedrockModel):
|
|||||||
{tools}
|
{tools}
|
||||||
|
|
||||||
Please think if you need to use a tool or not for user's question, you must:
|
Please think if you need to use a tool or not for user's question, you must:
|
||||||
1. Respond Y or N inside a <Tool></Tool> xml tag first to indicate that.
|
1. Respond Y or N within <tool></tool> tags first to indicate that.
|
||||||
2. If a tool is needed, MUST respond a JSON object matching the following schema inside a <Func></Func> xml tag:
|
2. If a tool is needed, MUST respond a JSON object matching the following schema within <function></function> tags:
|
||||||
{{"name": $TOOL_NAME, "arguments": {{"$PARAMETER_NAME": "$PARAMETER_VALUE", ...}}}}
|
{{"name": $TOOL_NAME, "arguments": {{"$PARAMETER_NAME": "$PARAMETER_VALUE", ...}}}}
|
||||||
3. If no tools is needed, respond with normal text."""
|
3. If no tools is needed, respond with normal text."""
|
||||||
|
|
||||||
@@ -201,6 +209,7 @@ Please think if you need to use a tool or not for user's question, you must:
|
|||||||
converted_messages = []
|
converted_messages = []
|
||||||
for message in chat_request.messages:
|
for message in chat_request.messages:
|
||||||
if message.role == "system":
|
if message.role == "system":
|
||||||
|
assert isinstance(message.content, str)
|
||||||
system_prompt += message.content + "\n"
|
system_prompt += message.content + "\n"
|
||||||
elif message.role == "user" and not isinstance(message.content, str):
|
elif message.role == "user" and not isinstance(message.content, str):
|
||||||
converted_messages.append(
|
converted_messages.append(
|
||||||
@@ -243,9 +252,9 @@ Please think if you need to use a tool or not for user's question, you must:
|
|||||||
system_prompt += self.tool_prompt.format(tools=tools_str)
|
system_prompt += self.tool_prompt.format(tools=tools_str)
|
||||||
converted_messages.append({
|
converted_messages.append({
|
||||||
'role': 'assistant',
|
'role': 'assistant',
|
||||||
'content': '<Tool>'
|
'content': '<tool>'
|
||||||
})
|
})
|
||||||
args["stop_sequences"] = ['</Func>']
|
args["stop_sequences"] = ['</function>']
|
||||||
args["messages"] = self.merge_message(converted_messages)
|
args["messages"] = self.merge_message(converted_messages)
|
||||||
if system_prompt:
|
if system_prompt:
|
||||||
if DEBUG:
|
if DEBUG:
|
||||||
@@ -267,10 +276,10 @@ Please think if you need to use a tool or not for user's question, you must:
|
|||||||
|
|
||||||
tools = None
|
tools = None
|
||||||
if chat_request.tools:
|
if chat_request.tools:
|
||||||
if message.startswith("Y</Tool>"):
|
if message.startswith("Y</tool>"):
|
||||||
tools = self._parse_tool_message(message)
|
tools = self._parse_tool_message(message)
|
||||||
message = None
|
message = None
|
||||||
elif message.startswith("N</Tool>"):
|
elif message.startswith("N</tool>"):
|
||||||
message = message[8:].lstrip("\n")
|
message = message[8:].lstrip("\n")
|
||||||
return self._create_response(
|
return self._create_response(
|
||||||
model=chat_request.model,
|
model=chat_request.model,
|
||||||
@@ -283,6 +292,8 @@ Please think if you need to use a tool or not for user's question, you must:
|
|||||||
)
|
)
|
||||||
|
|
||||||
def chat_stream(self, chat_request: ChatRequest) -> AsyncIterable[bytes]:
|
def chat_stream(self, chat_request: ChatRequest) -> AsyncIterable[bytes]:
|
||||||
|
if DEBUG:
|
||||||
|
logger.info("Raw request: " + chat_request.model_dump_json())
|
||||||
response = self._invoke_model(
|
response = self._invoke_model(
|
||||||
args=self._parse_args(chat_request),
|
args=self._parse_args(chat_request),
|
||||||
model_id=chat_request.model,
|
model_id=chat_request.model,
|
||||||
@@ -321,7 +332,7 @@ Please think if you need to use a tool or not for user's question, you must:
|
|||||||
tool_message += chunk_message
|
tool_message += chunk_message
|
||||||
continue
|
continue
|
||||||
if index < 3:
|
if index < 3:
|
||||||
# Ignore the N</Tool>, which is 3 tokens
|
# Ignore the N</tool>, which is 3 tokens
|
||||||
index += 1
|
index += 1
|
||||||
continue
|
continue
|
||||||
if first_token:
|
if first_token:
|
||||||
@@ -350,7 +361,7 @@ Please think if you need to use a tool or not for user's question, you must:
|
|||||||
if DEBUG:
|
if DEBUG:
|
||||||
logger.info("Tool message: " + tool_message.replace("\n", " "))
|
logger.info("Tool message: " + tool_message.replace("\n", " "))
|
||||||
try:
|
try:
|
||||||
tool_messages = tool_message[tool_message.rindex("<Func>") + 6:]
|
tool_messages = tool_message[tool_message.rindex("<function>") + len("<function>"):]
|
||||||
function = json.loads(tool_messages.replace("\n", " "))
|
function = json.loads(tool_messages.replace("\n", " "))
|
||||||
args = json.dumps(function.get("arguments", {}))
|
args = json.dumps(function.get("arguments", {}))
|
||||||
function = ResponseFunction(
|
function = ResponseFunction(
|
||||||
@@ -365,7 +376,7 @@ Please think if you need to use a tool or not for user's question, you must:
|
|||||||
]
|
]
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error("Failed to parse tool response")
|
logger.error("Failed to parse tool response" + str(e))
|
||||||
raise HTTPException(status_code=500, detail="Failed to parse tool response")
|
raise HTTPException(status_code=500, detail="Failed to parse tool response")
|
||||||
|
|
||||||
def _get_base64_image(self, image_url: str) -> tuple[str, str]:
|
def _get_base64_image(self, image_url: str) -> tuple[str, str]:
|
||||||
@@ -617,12 +628,20 @@ class BedrockEmbeddingsModel(BaseEmbeddingsModel, ABC):
|
|||||||
if DEBUG:
|
if DEBUG:
|
||||||
logger.info("Invoke Bedrock Model: " + model_id)
|
logger.info("Invoke Bedrock Model: " + model_id)
|
||||||
logger.info("Bedrock request body: " + body)
|
logger.info("Bedrock request body: " + body)
|
||||||
|
try:
|
||||||
return bedrock_runtime.invoke_model(
|
return bedrock_runtime.invoke_model(
|
||||||
body=body,
|
body=body,
|
||||||
modelId=model_id,
|
modelId=model_id,
|
||||||
accept=self.accept,
|
accept=self.accept,
|
||||||
contentType=self.content_type,
|
contentType=self.content_type,
|
||||||
)
|
)
|
||||||
|
except bedrock_runtime.exceptions.ValidationException as e:
|
||||||
|
print("Validation Exception")
|
||||||
|
print(e)
|
||||||
|
raise HTTPException(status_code=400, detail=str(e))
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(e)
|
||||||
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
def _create_response(
|
def _create_response(
|
||||||
self,
|
self,
|
||||||
@@ -739,16 +758,19 @@ def get_model(model_id: str) -> BedrockModel:
|
|||||||
model_name = SUPPORTED_BEDROCK_MODELS.get(model_id, "")
|
model_name = SUPPORTED_BEDROCK_MODELS.get(model_id, "")
|
||||||
if DEBUG:
|
if DEBUG:
|
||||||
logger.info("model name is " + model_name)
|
logger.info("model name is " + model_name)
|
||||||
if model_name in ["Claude Instant", "Claude", "Claude 3 Sonnet", "Claude 3 Haiku", "Claude 3 Opus"]:
|
# Not using startswith() here, in case of more complex scenarios.
|
||||||
|
# The downside is that this must be changed every time a new model is supported.
|
||||||
|
match model_name:
|
||||||
|
case "Claude Instant" | "Claude" | "Claude 3 Sonnet" | "Claude 3 Haiku" | "Claude 3 Opus":
|
||||||
return ClaudeModel()
|
return ClaudeModel()
|
||||||
elif model_name in ["Llama 2 Chat 13B", "Llama 2 Chat 70B"]:
|
case "Llama 2 Chat 13B" | "Llama 2 Chat 70B":
|
||||||
return Llama2Model()
|
return Llama2Model()
|
||||||
elif model_name in ["Mistral 7B Instruct", "Mixtral 8x7B Instruct", "Mistral Large"]:
|
case "Mistral 7B Instruct" | "Mixtral 8x7B Instruct" | "Mistral Large":
|
||||||
return MistralModel()
|
return MistralModel()
|
||||||
else:
|
case _:
|
||||||
logger.error("Unsupported model id " + model_id)
|
logger.error("Unsupported model id " + model_id)
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=500,
|
status_code=400,
|
||||||
detail="Unsupported model id " + model_id,
|
detail="Unsupported model id " + model_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -757,13 +779,12 @@ def get_embeddings_model(model_id: str) -> BedrockEmbeddingsModel:
|
|||||||
model_name = SUPPORTED_BEDROCK_EMBEDDING_MODELS.get(model_id, "")
|
model_name = SUPPORTED_BEDROCK_EMBEDDING_MODELS.get(model_id, "")
|
||||||
if DEBUG:
|
if DEBUG:
|
||||||
logger.info("model name is " + model_name)
|
logger.info("model name is " + model_name)
|
||||||
if model_name in ["Cohere Embed Multilingual", "Cohere Embed English"]:
|
match model_name:
|
||||||
|
case "Cohere Embed Multilingual" | "Cohere Embed English":
|
||||||
return CohereEmbeddingsModel()
|
return CohereEmbeddingsModel()
|
||||||
elif model_name in ["Titan Embeddings G1 - Text", "Titan Multimodal Embeddings G1"]:
|
case _:
|
||||||
return TitanEmbeddingsModel()
|
|
||||||
else:
|
|
||||||
logger.error("Unsupported model id " + model_id)
|
logger.error("Unsupported model id " + model_id)
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=500,
|
status_code=400,
|
||||||
detail="Unsupported model id " + model_id,
|
detail="Unsupported embedding model id " + model_id,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1,15 +1,13 @@
|
|||||||
from typing import Annotated
|
from typing import Annotated
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, Body, HTTPException
|
from fastapi import APIRouter, Depends, Body
|
||||||
from fastapi.responses import StreamingResponse
|
from fastapi.responses import StreamingResponse
|
||||||
|
|
||||||
from api.auth import api_key_auth
|
from api.auth import api_key_auth
|
||||||
from api.models import get_model, SUPPORTED_BEDROCK_MODELS
|
from api.models import get_model
|
||||||
from api.schema import ChatRequest, ChatResponse, ChatStreamResponse
|
from api.schema import ChatRequest, ChatResponse, ChatStreamResponse
|
||||||
from api.setting import DEFAULT_MODEL
|
from api.setting import DEFAULT_MODEL
|
||||||
|
|
||||||
router = APIRouter()
|
|
||||||
|
|
||||||
router = APIRouter(
|
router = APIRouter(
|
||||||
prefix="/chat",
|
prefix="/chat",
|
||||||
dependencies=[Depends(api_key_auth)],
|
dependencies=[Depends(api_key_auth)],
|
||||||
@@ -36,15 +34,11 @@ async def chat_completions(
|
|||||||
):
|
):
|
||||||
if chat_request.model.lower().startswith("gpt-"):
|
if chat_request.model.lower().startswith("gpt-"):
|
||||||
chat_request.model = DEFAULT_MODEL
|
chat_request.model = DEFAULT_MODEL
|
||||||
if chat_request.model not in SUPPORTED_BEDROCK_MODELS.keys():
|
|
||||||
raise HTTPException(status_code=400, detail="Unsupported Model Id " + chat_request.model)
|
|
||||||
try:
|
|
||||||
model = get_model(chat_request.model)
|
|
||||||
|
|
||||||
|
# Exception will be raised if model not supported.
|
||||||
|
model = get_model(chat_request.model)
|
||||||
if chat_request.stream:
|
if chat_request.stream:
|
||||||
return StreamingResponse(
|
return StreamingResponse(
|
||||||
content=model.chat_stream(chat_request), media_type="text/event-stream"
|
content=model.chat_stream(chat_request), media_type="text/event-stream"
|
||||||
)
|
)
|
||||||
return model.chat(chat_request)
|
return model.chat(chat_request)
|
||||||
except ValueError as e:
|
|
||||||
raise HTTPException(status_code=400, detail=str(e))
|
|
||||||
|
|||||||
@@ -1,14 +1,12 @@
|
|||||||
from typing import Annotated
|
from typing import Annotated
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, Body, HTTPException
|
from fastapi import APIRouter, Depends, Body
|
||||||
|
|
||||||
from api.auth import api_key_auth
|
from api.auth import api_key_auth
|
||||||
from api.models import get_embeddings_model, SUPPORTED_BEDROCK_EMBEDDING_MODELS
|
from api.models import get_embeddings_model
|
||||||
from api.schema import EmbeddingsRequest, EmbeddingsResponse
|
from api.schema import EmbeddingsRequest, EmbeddingsResponse
|
||||||
from api.setting import DEFAULT_EMBEDDING_MODEL
|
from api.setting import DEFAULT_EMBEDDING_MODEL
|
||||||
|
|
||||||
router = APIRouter()
|
|
||||||
|
|
||||||
router = APIRouter(
|
router = APIRouter(
|
||||||
prefix="/embeddings",
|
prefix="/embeddings",
|
||||||
dependencies=[Depends(api_key_auth)],
|
dependencies=[Depends(api_key_auth)],
|
||||||
@@ -33,10 +31,6 @@ async def embeddings(
|
|||||||
):
|
):
|
||||||
if embeddings_request.model.lower().startswith("text-embedding-"):
|
if embeddings_request.model.lower().startswith("text-embedding-"):
|
||||||
embeddings_request.model = DEFAULT_EMBEDDING_MODEL
|
embeddings_request.model = DEFAULT_EMBEDDING_MODEL
|
||||||
if embeddings_request.model not in SUPPORTED_BEDROCK_EMBEDDING_MODELS.keys():
|
# Exception will be raised if model not supported.
|
||||||
raise HTTPException(status_code=400, detail="Unsupported Model Id " + embeddings_request.model)
|
|
||||||
try:
|
|
||||||
model = get_embeddings_model(embeddings_request.model)
|
model = get_embeddings_model(embeddings_request.model)
|
||||||
return model.embed(embeddings_request)
|
return model.embed(embeddings_request)
|
||||||
except ValueError as e:
|
|
||||||
raise HTTPException(status_code=400, detail=str(e))
|
|
||||||
|
|||||||
@@ -6,8 +6,6 @@ from api.auth import api_key_auth
|
|||||||
from api.models import SUPPORTED_BEDROCK_MODELS, SUPPORTED_BEDROCK_EMBEDDING_MODELS
|
from api.models import SUPPORTED_BEDROCK_MODELS, SUPPORTED_BEDROCK_EMBEDDING_MODELS
|
||||||
from api.schema import Models, Model
|
from api.schema import Models, Model
|
||||||
|
|
||||||
router = APIRouter()
|
|
||||||
|
|
||||||
router = APIRouter(
|
router = APIRouter(
|
||||||
prefix="/models",
|
prefix="/models",
|
||||||
dependencies=[Depends(api_key_auth)],
|
dependencies=[Depends(api_key_auth)],
|
||||||
|
|||||||
Reference in New Issue
Block a user