Clean up code

Author: Aiden Dai
Date: 2024-04-18 16:22:06 +08:00
parent 8340be4660
commit 7416f9a4e2
4 changed files with 81 additions and 74 deletions

View File

@@ -74,19 +74,27 @@ class BedrockModel(BaseChatModel, ABC):
         if DEBUG:
             logger.info("Invoke Bedrock Model: " + model_id)
             logger.info("Bedrock request body: " + body)
-        if with_stream:
-            return bedrock_runtime.invoke_model_with_response_stream(
-                body=body,
-                modelId=model_id,
-                accept=self.accept,
-                contentType=self.content_type,
-            )
-        return bedrock_runtime.invoke_model(
-            body=body,
-            modelId=model_id,
-            accept=self.accept,
-            contentType=self.content_type,
-        )
+        try:
+            if with_stream:
+                return bedrock_runtime.invoke_model_with_response_stream(
+                    body=body,
+                    modelId=model_id,
+                    accept=self.accept,
+                    contentType=self.content_type,
+                )
+            return bedrock_runtime.invoke_model(
+                body=body,
+                modelId=model_id,
+                accept=self.accept,
+                contentType=self.content_type,
+            )
+        except bedrock_runtime.exceptions.ValidationException as e:
+            print("Validation Exception")
+            print(e)
+            raise HTTPException(status_code=400, detail=str(e))
+        except Exception as e:
+            logger.error(e)
+            raise HTTPException(status_code=500, detail=str(e))
 
     @staticmethod
     def merge_message(messages: list[dict]) -> list[dict]:
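For context, the pattern introduced above maps boto3's modeled ValidationException to an HTTP 400 and any other failure to a 500. A minimal standalone sketch of the same pattern, where the client setup and helper name are illustrative rather than the gateway's own:

import boto3
from fastapi import HTTPException

# Region and credentials are assumed to be configured in the environment.
bedrock_runtime = boto3.client("bedrock-runtime")

def safe_invoke(body: str, model_id: str):
    # Translate SDK errors into HTTP errors the API layer can return.
    try:
        return bedrock_runtime.invoke_model(body=body, modelId=model_id)
    except bedrock_runtime.exceptions.ValidationException as e:
        # Malformed request body or bad parameters: the caller's fault.
        raise HTTPException(status_code=400, detail=str(e))
    except Exception as e:
        # Throttling, auth, transport, etc.: a server-side failure.
        raise HTTPException(status_code=500, detail=str(e))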
@@ -185,8 +193,8 @@ class ClaudeModel(BedrockModel):
 {tools}
 
 Please think if you need to use a tool or not for user's question, you must:
-1. Respond Y or N inside a <Tool></Tool> xml tag first to indicate that.
-2. If a tool is needed, MUST respond a JSON object matching the following schema inside a <Func></Func> xml tag:
+1. Respond Y or N within <tool></tool> tags first to indicate that.
+2. If a tool is needed, MUST respond a JSON object matching the following schema within <function></function> tags:
 {{"name": $TOOL_NAME, "arguments": {{"$PARAMETER_NAME": "$PARAMETER_VALUE", ...}}}}
 3. If no tools is needed, respond with normal text."""
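Note the doubled braces in the schema line: this template is rendered with str.format(tools=...), so literal braces in the JSON example must be escaped as {{ and }}. A quick illustration:

tool_prompt = 'Tools: {tools}. Reply with {{"name": $TOOL_NAME}}.'
print(tool_prompt.format(tools='["get_weather"]'))
# Tools: ["get_weather"]. Reply with {"name": $TOOL_NAME}.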
@@ -201,6 +209,7 @@ Please think if you need to use a tool or not for user's question, you must:
         converted_messages = []
         for message in chat_request.messages:
             if message.role == "system":
+                assert isinstance(message.content, str)
                 system_prompt += message.content + "\n"
             elif message.role == "user" and not isinstance(message.content, str):
                 converted_messages.append(
@@ -243,9 +252,9 @@ Please think if you need to use a tool or not for user's question, you must:
             system_prompt += self.tool_prompt.format(tools=tools_str)
             converted_messages.append({
                 'role': 'assistant',
-                'content': '<Tool>'
+                'content': '<tool>'
             })
-            args["stop_sequences"] = ['</Func>']
+            args["stop_sequences"] = ['</function>']
         args["messages"] = self.merge_message(converted_messages)
         if system_prompt:
             if DEBUG:
@@ -267,10 +276,10 @@ Please think if you need to use a tool or not for user's question, you must:
         tools = None
         if chat_request.tools:
-            if message.startswith("Y</Tool>"):
+            if message.startswith("Y</tool>"):
                 tools = self._parse_tool_message(message)
                 message = None
-            elif message.startswith("N</Tool>"):
+            elif message.startswith("N</tool>"):
                 message = message[8:].lstrip("\n")
         return self._create_response(
             model=chat_request.model,
@@ -283,6 +292,8 @@ Please think if you need to use a tool or not for user's question, you must:
         )
 
     def chat_stream(self, chat_request: ChatRequest) -> AsyncIterable[bytes]:
+        if DEBUG:
+            logger.info("Raw request: " + chat_request.model_dump_json())
         response = self._invoke_model(
             args=self._parse_args(chat_request),
             model_id=chat_request.model,
@@ -321,7 +332,7 @@ Please think if you need to use a tool or not for user's question, you must:
                     tool_message += chunk_message
                     continue
                 if index < 3:
-                    # Ignore the N</Tool>, which is 3 tokens
+                    # Ignore the N</tool>, which is 3 tokens
                     index += 1
                     continue
                 if first_token:
@@ -350,7 +361,7 @@ Please think if you need to use a tool or not for user's question, you must:
             if DEBUG:
                 logger.info("Tool message: " + tool_message.replace("\n", " "))
             try:
-                tool_messages = tool_message[tool_message.rindex("<Func>") + 6:]
+                tool_messages = tool_message[tool_message.rindex("<function>") + len("<function>"):]
                 function = json.loads(tool_messages.replace("\n", " "))
                 args = json.dumps(function.get("arguments", {}))
                 function = ResponseFunction(
@@ -365,7 +376,7 @@ Please think if you need to use a tool or not for user's question, you must:
             ]
         except Exception as e:
-            logger.error("Failed to parse tool response")
+            logger.error("Failed to parse tool response" + str(e))
             raise HTTPException(status_code=500, detail="Failed to parse tool response")
 
     def _get_base64_image(self, image_url: str) -> tuple[str, str]:
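For reference, the rindex call above grabs everything after the last <function> tag; the closing </function> never appears in the text because it is registered as a stop sequence. A standalone sketch of that parsing step, with a made-up model reply:

import json

reply = 'Y</tool>\n<function>{"name": "get_weather", "arguments": {"city": "Paris"}}'
payload = reply[reply.rindex("<function>") + len("<function>"):]
function = json.loads(payload.replace("\n", " "))
print(function["name"], function.get("arguments", {}))
# get_weather {'city': 'Paris'}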
@@ -617,12 +628,20 @@ class BedrockEmbeddingsModel(BaseEmbeddingsModel, ABC):
         if DEBUG:
             logger.info("Invoke Bedrock Model: " + model_id)
             logger.info("Bedrock request body: " + body)
-        return bedrock_runtime.invoke_model(
-            body=body,
-            modelId=model_id,
-            accept=self.accept,
-            contentType=self.content_type,
-        )
+        try:
+            return bedrock_runtime.invoke_model(
+                body=body,
+                modelId=model_id,
+                accept=self.accept,
+                contentType=self.content_type,
+            )
+        except bedrock_runtime.exceptions.ValidationException as e:
+            print("Validation Exception")
+            print(e)
+            raise HTTPException(status_code=400, detail=str(e))
+        except Exception as e:
+            logger.error(e)
+            raise HTTPException(status_code=500, detail=str(e))
 
     def _create_response(
         self,
@@ -739,31 +758,33 @@ def get_model(model_id: str) -> BedrockModel:
     model_name = SUPPORTED_BEDROCK_MODELS.get(model_id, "")
     if DEBUG:
         logger.info("model name is " + model_name)
-    if model_name in ["Claude Instant", "Claude", "Claude 3 Sonnet", "Claude 3 Haiku", "Claude 3 Opus"]:
-        return ClaudeModel()
-    elif model_name in ["Llama 2 Chat 13B", "Llama 2 Chat 70B"]:
-        return Llama2Model()
-    elif model_name in ["Mistral 7B Instruct", "Mixtral 8x7B Instruct", "Mistral Large"]:
-        return MistralModel()
-    else:
-        logger.error("Unsupported model id " + model_id)
-        raise HTTPException(
-            status_code=500,
-            detail="Unsupported model id " + model_id,
-        )
+    # Not using startswith here in case of complex scenarios.
+    # The downside is having to change this every time a new model is supported.
+    match model_name:
+        case "Claude Instant" | "Claude" | "Claude 3 Sonnet" | "Claude 3 Haiku" | "Claude 3 Opus":
+            return ClaudeModel()
+        case "Llama 2 Chat 13B" | "Llama 2 Chat 70B":
+            return Llama2Model()
+        case "Mistral 7B Instruct" | "Mixtral 8x7B Instruct" | "Mistral Large":
+            return MistralModel()
+        case _:
+            logger.error("Unsupported model id " + model_id)
+            raise HTTPException(
+                status_code=400,
+                detail="Unsupported model id " + model_id,
+            )
 
 
 def get_embeddings_model(model_id: str) -> BedrockEmbeddingsModel:
     model_name = SUPPORTED_BEDROCK_EMBEDDING_MODELS.get(model_id, "")
     if DEBUG:
         logger.info("model name is " + model_name)
-    if model_name in ["Cohere Embed Multilingual", "Cohere Embed English"]:
-        return CohereEmbeddingsModel()
-    elif model_name in ["Titan Embeddings G1 - Text", "Titan Multimodal Embeddings G1"]:
-        return TitanEmbeddingsModel()
-    else:
-        logger.error("Unsupported model id " + model_id)
-        raise HTTPException(
-            status_code=500,
-            detail="Unsupported model id " + model_id,
-        )
+    match model_name:
+        case "Cohere Embed Multilingual" | "Cohere Embed English":
+            return CohereEmbeddingsModel()
+        case _:
+            logger.error("Unsupported model id " + model_id)
+            raise HTTPException(
+                status_code=400,
+                detail="Unsupported embedding model id " + model_id,
+            )
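The match blocks above require Python 3.10+; string literals joined with | form an or-pattern, so a single case arm covers several model names, and case _ is the fallback. A toy example with hypothetical names:

def family(model_name: str) -> str:
    match model_name:
        case "Claude" | "Claude 3 Sonnet" | "Claude 3 Haiku":
            return "anthropic"
        case "Mistral Large":
            return "mistral"
        case _:
            raise ValueError("Unsupported model: " + model_name)

print(family("Claude 3 Haiku"))  # anthropic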

View File

@@ -1,15 +1,13 @@
 from typing import Annotated
 
-from fastapi import APIRouter, Depends, Body, HTTPException
+from fastapi import APIRouter, Depends, Body
 from fastapi.responses import StreamingResponse
 
 from api.auth import api_key_auth
-from api.models import get_model, SUPPORTED_BEDROCK_MODELS
+from api.models import get_model
 from api.schema import ChatRequest, ChatResponse, ChatStreamResponse
 from api.setting import DEFAULT_MODEL
 
-router = APIRouter()
-
 router = APIRouter(
     prefix="/chat",
     dependencies=[Depends(api_key_auth)],
@@ -36,15 +34,11 @@ async def chat_completions(
): ):
if chat_request.model.lower().startswith("gpt-"): if chat_request.model.lower().startswith("gpt-"):
chat_request.model = DEFAULT_MODEL chat_request.model = DEFAULT_MODEL
if chat_request.model not in SUPPORTED_BEDROCK_MODELS.keys():
raise HTTPException(status_code=400, detail="Unsupported Model Id " + chat_request.model) # Exception will be raised if model not supported.
try: model = get_model(chat_request.model)
model = get_model(chat_request.model) if chat_request.stream:
return StreamingResponse(
if chat_request.stream: content=model.chat_stream(chat_request), media_type="text/event-stream"
return StreamingResponse( )
content=model.chat_stream(chat_request), media_type="text/event-stream" return model.chat(chat_request)
)
return model.chat(chat_request)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
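With the per-route try/except gone, validation lives in get_model, whose HTTPException propagates out of the handler and is converted by FastAPI into the matching error response. A minimal sketch of that flow, where the route and model id are hypothetical:

from fastapi import FastAPI, HTTPException

app = FastAPI()

def lookup(model_id: str) -> str:
    # Stand-in for get_model: raising here is all the route needs.
    if model_id != "known-model":
        raise HTTPException(status_code=400, detail="Unsupported model id " + model_id)
    return model_id

@app.get("/demo/{model_id}")
def demo(model_id: str):
    return {"model": lookup(model_id)}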

View File

@@ -1,14 +1,12 @@
from typing import Annotated from typing import Annotated
from fastapi import APIRouter, Depends, Body, HTTPException from fastapi import APIRouter, Depends, Body
from api.auth import api_key_auth from api.auth import api_key_auth
from api.models import get_embeddings_model, SUPPORTED_BEDROCK_EMBEDDING_MODELS from api.models import get_embeddings_model
from api.schema import EmbeddingsRequest, EmbeddingsResponse from api.schema import EmbeddingsRequest, EmbeddingsResponse
from api.setting import DEFAULT_EMBEDDING_MODEL from api.setting import DEFAULT_EMBEDDING_MODEL
router = APIRouter()
router = APIRouter( router = APIRouter(
prefix="/embeddings", prefix="/embeddings",
dependencies=[Depends(api_key_auth)], dependencies=[Depends(api_key_auth)],
@@ -33,10 +31,6 @@ async def embeddings(
 ):
     if embeddings_request.model.lower().startswith("text-embedding-"):
         embeddings_request.model = DEFAULT_EMBEDDING_MODEL
-    if embeddings_request.model not in SUPPORTED_BEDROCK_EMBEDDING_MODELS.keys():
-        raise HTTPException(status_code=400, detail="Unsupported Model Id " + embeddings_request.model)
-    try:
-        model = get_embeddings_model(embeddings_request.model)
-        return model.embed(embeddings_request)
-    except ValueError as e:
-        raise HTTPException(status_code=400, detail=str(e))
+    # Exception will be raised if model not supported.
+    model = get_embeddings_model(embeddings_request.model)
+    return model.embed(embeddings_request)

View File

@@ -6,8 +6,6 @@ from api.auth import api_key_auth
 from api.models import SUPPORTED_BEDROCK_MODELS, SUPPORTED_BEDROCK_EMBEDDING_MODELS
 from api.schema import Models, Model
 
-router = APIRouter()
-
 router = APIRouter(
     prefix="/models",
     dependencies=[Depends(api_key_auth)],