diff --git a/src/api/models/bedrock.py b/src/api/models/bedrock.py
index 28c4b69..e49c82c 100644
--- a/src/api/models/bedrock.py
+++ b/src/api/models/bedrock.py
@@ -48,6 +48,8 @@ SUPPORTED_BEDROCK_MODELS = {
     "anthropic.claude-3-haiku-20240307-v1:0": "Claude 3 Haiku",
     "meta.llama2-13b-chat-v1": "Llama 2 Chat 13B",
     "meta.llama2-70b-chat-v1": "Llama 2 Chat 70B",
+    "meta.llama3-8b-instruct-v1:0": "Llama 3 8B Instruct",
+    "meta.llama3-70b-instruct-v1:0": "Llama 3 70B Instruct",
     "mistral.mistral-7b-instruct-v0:2": "Mistral 7B Instruct",
     "mistral.mixtral-8x7b-instruct-v0:1": "Mixtral 8x7B Instruct",
     "mistral.mistral-large-2402-v1:0": "Mistral Large",
@@ -433,10 +435,40 @@ Please think if you need to use a tool or not for user's question, you must:
         return content_parts
 
 
-class Llama2Model(BedrockModel):
+class LlamaModel(BedrockModel):
 
-    def _convert_prompt(self, chat_request: ChatRequest) -> str:
-        """Create a prompt message follow below example:
+    @staticmethod
+    def create_llama3_prompt(chat_request: ChatRequest) -> str:
+        """Create a prompt message for Llama 3, following the example below:
+
+        <|begin_of_text|><|start_header_id|>system<|end_header_id|>
+
+        {{ system_prompt }}<|eot_id|><|start_header_id|>user<|end_header_id|>
+
+        {{ user_message_1 }}<|eot_id|><|start_header_id|>assistant<|end_header_id|>
+
+        {{ model_answer_1 }}<|eot_id|><|start_header_id|>user<|end_header_id|>
+
+        {{ user_message_2 }}<|eot_id|><|start_header_id|>assistant<|end_header_id|>
+        """
+        if DEBUG:
+            logger.info("Converting the following messages to a Llama 3 prompt:")
+            for msg in chat_request.messages:
+                logger.info(msg.model_dump_json())
+        bos_token = "<|begin_of_text|>"
+
+        prompt_lines = []
+        for msg in chat_request.messages:
+            prompt_lines.append(f"<|start_header_id|>{msg.role}<|end_header_id|>\n\n{msg.content}<|eot_id|>")
+        prompt_lines.append("<|start_header_id|>assistant<|end_header_id|>\n\n")
+        prompt = bos_token + "".join(prompt_lines)
+        if DEBUG:
+            logger.info("Converted prompt: " + prompt.replace("\n", "\\n"))
+        return prompt
+
+    @staticmethod
+    def create_llama2_prompt(chat_request: ChatRequest) -> str:
+        """Create a prompt message for Llama 2, following the example below:
 
         [INST] <<SYS>>\n{your_system_message}\n<</SYS>>\n\n{user_message_1} [/INST] {model_reply_1} [INST] {user_message_2} [/INST]
 
@@ -480,7 +512,11 @@ class Llama2Model(BedrockModel):
         return prompt
 
     def _parse_args(self, chat_request: ChatRequest) -> dict:
-        prompt = self._convert_prompt(chat_request)
+        if chat_request.model.startswith("meta.llama2"):
+            prompt = self.create_llama2_prompt(chat_request)
+        else:
+            prompt = self.create_llama3_prompt(chat_request)
+            # Currently, there is no way to set a stop sequence for Llama 3 models.
         return {
             "prompt": prompt,
             "max_gen_len": chat_request.max_tokens,
@@ -763,8 +799,8 @@ def get_model(model_id: str) -> BedrockModel:
     match model_name:
         case "Claude Instant" | "Claude" | "Claude 3 Sonnet" | "Claude 3 Haiku" | "Claude 3 Opus":
             return ClaudeModel()
-        case "Llama 2 Chat 13B" | "Llama 2 Chat 70B":
-            return Llama2Model()
+        case "Llama 2 Chat 13B" | "Llama 2 Chat 70B" | "Llama 3 8B Instruct" | "Llama 3 70B Instruct":
+            return LlamaModel()
         case "Mistral 7B Instruct" | "Mixtral 8x7B Instruct" | "Mistral Large":
             return MistralModel()
         case _:
diff --git a/src/api/setting.py b/src/api/setting.py
index 7924245..bf056d8 100644
--- a/src/api/setting.py
+++ b/src/api/setting.py
@@ -21,6 +21,8 @@ List of Amazon Bedrock models currently supported:
 - anthropic.claude-3-haiku-20240307-v1:0
 - meta.llama2-13b-chat-v1
 - meta.llama2-70b-chat-v1
+- meta.llama3-8b-instruct-v1:0
+- meta.llama3-70b-instruct-v1:0
 - mistral.mistral-7b-instruct-v0:2
 - mistral.mixtral-8x7b-instruct-v0:1
 - mistral.mistral-large-2402-v1:0
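
For reference, here is a minimal standalone sketch of the Llama 3 prompt conversion introduced above. It mirrors LlamaModel.create_llama3_prompt with the DEBUG logging omitted; the SimpleNamespace messages are illustrative stand-ins for the repo's pydantic message models, not the actual types.

# Standalone sketch of the Llama 3 prompt conversion (logging omitted).
from types import SimpleNamespace

def create_llama3_prompt(messages) -> str:
    parts = ["<|begin_of_text|>"]
    for msg in messages:
        # Each turn: role header, blank line, content, then the <|eot_id|> terminator.
        parts.append(f"<|start_header_id|>{msg.role}<|end_header_id|>\n\n{msg.content}<|eot_id|>")
    # Open the assistant turn so the model generates the next reply.
    parts.append("<|start_header_id|>assistant<|end_header_id|>\n\n")
    return "".join(parts)

messages = [
    SimpleNamespace(role="system", content="You are a helpful assistant."),
    SimpleNamespace(role="user", content="What is Amazon Bedrock?"),
]
print(create_llama3_prompt(messages).replace("\n", "\\n"))
# <|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nYou are a helpful
# assistant.<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nWhat is Amazon
# Bedrock?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n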
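
And a sketch of how the request body assembled in _parse_args can be sent directly to Bedrock with boto3. The temperature/top_p values and the "generation" response field follow the documented Meta Llama schema on Bedrock; treat this as an illustration rather than the gateway's actual call path.

# Sketch: invoking a Llama 3 model on Bedrock with the converted prompt.
import json
import boto3

client = boto3.client("bedrock-runtime")  # region/credentials come from your environment

body = {
    "prompt": create_llama3_prompt(messages),
    "max_gen_len": 512,   # maps from chat_request.max_tokens in _parse_args
    "temperature": 0.7,
    "top_p": 0.9,
}
response = client.invoke_model(
    modelId="meta.llama3-8b-instruct-v1:0",
    body=json.dumps(body),
)
print(json.loads(response["body"].read())["generation"])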