Clean up code

This commit is contained in:
Aiden Dai
2024-04-18 16:22:06 +08:00
parent 8340be4660
commit 7416f9a4e2
4 changed files with 81 additions and 74 deletions

View File

@@ -74,19 +74,27 @@ class BedrockModel(BaseChatModel, ABC):
if DEBUG:
logger.info("Invoke Bedrock Model: " + model_id)
logger.info("Bedrock request body: " + body)
if with_stream:
return bedrock_runtime.invoke_model_with_response_stream(
try:
if with_stream:
return bedrock_runtime.invoke_model_with_response_stream(
body=body,
modelId=model_id,
accept=self.accept,
contentType=self.content_type,
)
return bedrock_runtime.invoke_model(
body=body,
modelId=model_id,
accept=self.accept,
contentType=self.content_type,
)
return bedrock_runtime.invoke_model(
body=body,
modelId=model_id,
accept=self.accept,
contentType=self.content_type,
)
except bedrock_runtime.exceptions.ValidationException as e:
print("Validation Exception")
print(e)
raise HTTPException(status_code=400, detail=str(e))
except Exception as e:
logger.error(e)
raise HTTPException(status_code=500, detail=str(e))
@staticmethod
def merge_message(messages: list[dict]) -> list[dict]:
@@ -185,8 +193,8 @@ class ClaudeModel(BedrockModel):
{tools}
Please think if you need to use a tool or not for user's question, you must:
1. Respond Y or N inside a <Tool></Tool> xml tag first to indicate that.
2. If a tool is needed, MUST respond a JSON object matching the following schema inside a <Func></Func> xml tag:
1. Respond Y or N within <tool></tool> tags first to indicate that.
2. If a tool is needed, MUST respond a JSON object matching the following schema within <function></function> tags:
{{"name": $TOOL_NAME, "arguments": {{"$PARAMETER_NAME": "$PARAMETER_VALUE", ...}}}}
3. If no tools is needed, respond with normal text."""
@@ -201,6 +209,7 @@ Please think if you need to use a tool or not for user's question, you must:
converted_messages = []
for message in chat_request.messages:
if message.role == "system":
assert isinstance(message.content, str)
system_prompt += message.content + "\n"
elif message.role == "user" and not isinstance(message.content, str):
converted_messages.append(
@@ -243,9 +252,9 @@ Please think if you need to use a tool or not for user's question, you must:
system_prompt += self.tool_prompt.format(tools=tools_str)
converted_messages.append({
'role': 'assistant',
'content': '<Tool>'
'content': '<tool>'
})
args["stop_sequences"] = ['</Func>']
args["stop_sequences"] = ['</function>']
args["messages"] = self.merge_message(converted_messages)
if system_prompt:
if DEBUG:
@@ -267,10 +276,10 @@ Please think if you need to use a tool or not for user's question, you must:
tools = None
if chat_request.tools:
if message.startswith("Y</Tool>"):
if message.startswith("Y</tool>"):
tools = self._parse_tool_message(message)
message = None
elif message.startswith("N</Tool>"):
elif message.startswith("N</tool>"):
message = message[8:].lstrip("\n")
return self._create_response(
model=chat_request.model,
@@ -283,6 +292,8 @@ Please think if you need to use a tool or not for user's question, you must:
)
def chat_stream(self, chat_request: ChatRequest) -> AsyncIterable[bytes]:
if DEBUG:
logger.info("Raw request: " + chat_request.model_dump_json())
response = self._invoke_model(
args=self._parse_args(chat_request),
model_id=chat_request.model,
@@ -321,7 +332,7 @@ Please think if you need to use a tool or not for user's question, you must:
tool_message += chunk_message
continue
if index < 3:
# Ignore the N</Tool>, which is 3 tokens
# Ignore the N</tool>, which is 3 tokens
index += 1
continue
if first_token:
@@ -350,7 +361,7 @@ Please think if you need to use a tool or not for user's question, you must:
if DEBUG:
logger.info("Tool message: " + tool_message.replace("\n", " "))
try:
tool_messages = tool_message[tool_message.rindex("<Func>") + 6:]
tool_messages = tool_message[tool_message.rindex("<function>") + len("<function>"):]
function = json.loads(tool_messages.replace("\n", " "))
args = json.dumps(function.get("arguments", {}))
function = ResponseFunction(
@@ -365,7 +376,7 @@ Please think if you need to use a tool or not for user's question, you must:
]
except Exception as e:
logger.error("Failed to parse tool response")
logger.error("Failed to parse tool response" + str(e))
raise HTTPException(status_code=500, detail="Failed to parse tool response")
def _get_base64_image(self, image_url: str) -> tuple[str, str]:
@@ -617,12 +628,20 @@ class BedrockEmbeddingsModel(BaseEmbeddingsModel, ABC):
if DEBUG:
logger.info("Invoke Bedrock Model: " + model_id)
logger.info("Bedrock request body: " + body)
return bedrock_runtime.invoke_model(
body=body,
modelId=model_id,
accept=self.accept,
contentType=self.content_type,
)
try:
return bedrock_runtime.invoke_model(
body=body,
modelId=model_id,
accept=self.accept,
contentType=self.content_type,
)
except bedrock_runtime.exceptions.ValidationException as e:
print("Validation Exception")
print(e)
raise HTTPException(status_code=400, detail=str(e))
except Exception as e:
logger.error(e)
raise HTTPException(status_code=500, detail=str(e))
def _create_response(
self,
@@ -739,31 +758,33 @@ def get_model(model_id: str) -> BedrockModel:
model_name = SUPPORTED_BEDROCK_MODELS.get(model_id, "")
if DEBUG:
logger.info("model name is " + model_name)
if model_name in ["Claude Instant", "Claude", "Claude 3 Sonnet", "Claude 3 Haiku", "Claude 3 Opus"]:
return ClaudeModel()
elif model_name in ["Llama 2 Chat 13B", "Llama 2 Chat 70B"]:
return Llama2Model()
elif model_name in ["Mistral 7B Instruct", "Mixtral 8x7B Instruct", "Mistral Large"]:
return MistralModel()
else:
logger.error("Unsupported model id " + model_id)
raise HTTPException(
status_code=500,
detail="Unsupported model id " + model_id,
)
# Not using startswith() here in case of complex scenarios.
# The downside is that this must be changed every time a new model is supported.
match model_name:
case "Claude Instant" | "Claude" | "Claude 3 Sonnet" | "Claude 3 Haiku" | "Claude 3 Opus":
return ClaudeModel()
case "Llama 2 Chat 13B" | "Llama 2 Chat 70B":
return Llama2Model()
case "Mistral 7B Instruct" | "Mixtral 8x7B Instruct" | "Mistral Large":
return MistralModel()
case _:
logger.error("Unsupported model id " + model_id)
raise HTTPException(
status_code=400,
detail="Unsupported model id " + model_id,
)
def get_embeddings_model(model_id: str) -> BedrockEmbeddingsModel:
model_name = SUPPORTED_BEDROCK_EMBEDDING_MODELS.get(model_id, "")
if DEBUG:
logger.info("model name is " + model_name)
if model_name in ["Cohere Embed Multilingual", "Cohere Embed English"]:
return CohereEmbeddingsModel()
elif model_name in ["Titan Embeddings G1 - Text", "Titan Multimodal Embeddings G1"]:
return TitanEmbeddingsModel()
else:
logger.error("Unsupported model id " + model_id)
raise HTTPException(
status_code=500,
detail="Unsupported model id " + model_id,
)
match model_name:
case "Cohere Embed Multilingual" | "Cohere Embed English":
return CohereEmbeddingsModel()
case _:
logger.error("Unsupported model id " + model_id)
raise HTTPException(
status_code=400,
detail="Unsupported embedding model id " + model_id,
)