Clean up code
This commit is contained in:
@@ -74,19 +74,27 @@ class BedrockModel(BaseChatModel, ABC):
|
||||
if DEBUG:
|
||||
logger.info("Invoke Bedrock Model: " + model_id)
|
||||
logger.info("Bedrock request body: " + body)
|
||||
if with_stream:
|
||||
return bedrock_runtime.invoke_model_with_response_stream(
|
||||
try:
|
||||
if with_stream:
|
||||
return bedrock_runtime.invoke_model_with_response_stream(
|
||||
body=body,
|
||||
modelId=model_id,
|
||||
accept=self.accept,
|
||||
contentType=self.content_type,
|
||||
)
|
||||
return bedrock_runtime.invoke_model(
|
||||
body=body,
|
||||
modelId=model_id,
|
||||
accept=self.accept,
|
||||
contentType=self.content_type,
|
||||
)
|
||||
return bedrock_runtime.invoke_model(
|
||||
body=body,
|
||||
modelId=model_id,
|
||||
accept=self.accept,
|
||||
contentType=self.content_type,
|
||||
)
|
||||
except bedrock_runtime.exceptions.ValidationException as e:
|
||||
print("Validation Exception")
|
||||
print(e)
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@staticmethod
|
||||
def merge_message(messages: list[dict]) -> list[dict]:
|
||||
@@ -185,8 +193,8 @@ class ClaudeModel(BedrockModel):
|
||||
{tools}
|
||||
|
||||
Please think if you need to use a tool or not for user's question, you must:
|
||||
1. Respond Y or N inside a <Tool></Tool> xml tag first to indicate that.
|
||||
2. If a tool is needed, MUST respond a JSON object matching the following schema inside a <Func></Func> xml tag:
|
||||
1. Respond Y or N within <tool></tool> tags first to indicate that.
|
||||
2. If a tool is needed, MUST respond a JSON object matching the following schema within <function></function> tags:
|
||||
{{"name": $TOOL_NAME, "arguments": {{"$PARAMETER_NAME": "$PARAMETER_VALUE", ...}}}}
|
||||
3. If no tools is needed, respond with normal text."""
|
||||
|
||||
@@ -201,6 +209,7 @@ Please think if you need to use a tool or not for user's question, you must:
|
||||
converted_messages = []
|
||||
for message in chat_request.messages:
|
||||
if message.role == "system":
|
||||
assert isinstance(message.content, str)
|
||||
system_prompt += message.content + "\n"
|
||||
elif message.role == "user" and not isinstance(message.content, str):
|
||||
converted_messages.append(
|
||||
@@ -243,9 +252,9 @@ Please think if you need to use a tool or not for user's question, you must:
|
||||
system_prompt += self.tool_prompt.format(tools=tools_str)
|
||||
converted_messages.append({
|
||||
'role': 'assistant',
|
||||
'content': '<Tool>'
|
||||
'content': '<tool>'
|
||||
})
|
||||
args["stop_sequences"] = ['</Func>']
|
||||
args["stop_sequences"] = ['</function>']
|
||||
args["messages"] = self.merge_message(converted_messages)
|
||||
if system_prompt:
|
||||
if DEBUG:
|
||||
@@ -267,10 +276,10 @@ Please think if you need to use a tool or not for user's question, you must:
|
||||
|
||||
tools = None
|
||||
if chat_request.tools:
|
||||
if message.startswith("Y</Tool>"):
|
||||
if message.startswith("Y</tool>"):
|
||||
tools = self._parse_tool_message(message)
|
||||
message = None
|
||||
elif message.startswith("N</Tool>"):
|
||||
elif message.startswith("N</tool>"):
|
||||
message = message[8:].lstrip("\n")
|
||||
return self._create_response(
|
||||
model=chat_request.model,
|
||||
@@ -283,6 +292,8 @@ Please think if you need to use a tool or not for user's question, you must:
|
||||
)
|
||||
|
||||
def chat_stream(self, chat_request: ChatRequest) -> AsyncIterable[bytes]:
|
||||
if DEBUG:
|
||||
logger.info("Raw request: " + chat_request.model_dump_json())
|
||||
response = self._invoke_model(
|
||||
args=self._parse_args(chat_request),
|
||||
model_id=chat_request.model,
|
||||
@@ -321,7 +332,7 @@ Please think if you need to use a tool or not for user's question, you must:
|
||||
tool_message += chunk_message
|
||||
continue
|
||||
if index < 3:
|
||||
# Ignore the N</Tool>, which is 3 tokens
|
||||
# Ignore the N</tool>, which is 3 tokens
|
||||
index += 1
|
||||
continue
|
||||
if first_token:
|
||||
@@ -350,7 +361,7 @@ Please think if you need to use a tool or not for user's question, you must:
|
||||
if DEBUG:
|
||||
logger.info("Tool message: " + tool_message.replace("\n", " "))
|
||||
try:
|
||||
tool_messages = tool_message[tool_message.rindex("<Func>") + 6:]
|
||||
tool_messages = tool_message[tool_message.rindex("<function>") + len("<function>"):]
|
||||
function = json.loads(tool_messages.replace("\n", " "))
|
||||
args = json.dumps(function.get("arguments", {}))
|
||||
function = ResponseFunction(
|
||||
@@ -365,7 +376,7 @@ Please think if you need to use a tool or not for user's question, you must:
|
||||
]
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to parse tool response")
|
||||
logger.error("Failed to parse tool response" + str(e))
|
||||
raise HTTPException(status_code=500, detail="Failed to parse tool response")
|
||||
|
||||
def _get_base64_image(self, image_url: str) -> tuple[str, str]:
|
||||
@@ -617,12 +628,20 @@ class BedrockEmbeddingsModel(BaseEmbeddingsModel, ABC):
|
||||
if DEBUG:
|
||||
logger.info("Invoke Bedrock Model: " + model_id)
|
||||
logger.info("Bedrock request body: " + body)
|
||||
return bedrock_runtime.invoke_model(
|
||||
body=body,
|
||||
modelId=model_id,
|
||||
accept=self.accept,
|
||||
contentType=self.content_type,
|
||||
)
|
||||
try:
|
||||
return bedrock_runtime.invoke_model(
|
||||
body=body,
|
||||
modelId=model_id,
|
||||
accept=self.accept,
|
||||
contentType=self.content_type,
|
||||
)
|
||||
except bedrock_runtime.exceptions.ValidationException as e:
|
||||
print("Validation Exception")
|
||||
print(e)
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
def _create_response(
|
||||
self,
|
||||
@@ -739,31 +758,33 @@ def get_model(model_id: str) -> BedrockModel:
|
||||
model_name = SUPPORTED_BEDROCK_MODELS.get(model_id, "")
|
||||
if DEBUG:
|
||||
logger.info("model name is " + model_name)
|
||||
if model_name in ["Claude Instant", "Claude", "Claude 3 Sonnet", "Claude 3 Haiku", "Claude 3 Opus"]:
|
||||
return ClaudeModel()
|
||||
elif model_name in ["Llama 2 Chat 13B", "Llama 2 Chat 70B"]:
|
||||
return Llama2Model()
|
||||
elif model_name in ["Mistral 7B Instruct", "Mixtral 8x7B Instruct", "Mistral Large"]:
|
||||
return MistralModel()
|
||||
else:
|
||||
logger.error("Unsupported model id " + model_id)
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Unsupported model id " + model_id,
|
||||
)
|
||||
# Not using startswith() here in case of complex scenarios.
|
||||
# The downside is that this must be updated every time a new model is supported.
|
||||
match model_name:
|
||||
case "Claude Instant" | "Claude" | "Claude 3 Sonnet" | "Claude 3 Haiku" | "Claude 3 Opus":
|
||||
return ClaudeModel()
|
||||
case "Llama 2 Chat 13B" | "Llama 2 Chat 70B":
|
||||
return Llama2Model()
|
||||
case "Mistral 7B Instruct" | "Mixtral 8x7B Instruct" | "Mistral Large":
|
||||
return MistralModel()
|
||||
case _:
|
||||
logger.error("Unsupported model id " + model_id)
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Unsupported model id " + model_id,
|
||||
)
|
||||
|
||||
|
||||
def get_embeddings_model(model_id: str) -> BedrockEmbeddingsModel:
|
||||
model_name = SUPPORTED_BEDROCK_EMBEDDING_MODELS.get(model_id, "")
|
||||
if DEBUG:
|
||||
logger.info("model name is " + model_name)
|
||||
if model_name in ["Cohere Embed Multilingual", "Cohere Embed English"]:
|
||||
return CohereEmbeddingsModel()
|
||||
elif model_name in ["Titan Embeddings G1 - Text", "Titan Multimodal Embeddings G1"]:
|
||||
return TitanEmbeddingsModel()
|
||||
else:
|
||||
logger.error("Unsupported model id " + model_id)
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Unsupported model id " + model_id,
|
||||
)
|
||||
match model_name:
|
||||
case "Cohere Embed Multilingual" | "Cohere Embed English":
|
||||
return CohereEmbeddingsModel()
|
||||
case _:
|
||||
logger.error("Unsupported model id " + model_id)
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Unsupported embedding model id " + model_id,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user