Add Llama 3 support

This commit is contained in:
Aiden Dai
2024-04-23 10:30:57 +08:00
parent 7416f9a4e2
commit 0512a9b8cc
2 changed files with 44 additions and 6 deletions

View File

@@ -48,6 +48,8 @@ SUPPORTED_BEDROCK_MODELS = {
"anthropic.claude-3-haiku-20240307-v1:0": "Claude 3 Haiku",
"meta.llama2-13b-chat-v1": "Llama 2 Chat 13B",
"meta.llama2-70b-chat-v1": "Llama 2 Chat 70B",
"meta.llama3-8b-instruct-v1:0": "Llama 3 8B Instruct",
"meta.llama3-70b-instruct-v1:0": "Llama 3 70B Instruct",
"mistral.mistral-7b-instruct-v0:2": "Mistral 7B Instruct",
"mistral.mixtral-8x7b-instruct-v0:1": "Mixtral 8x7B Instruct",
"mistral.mistral-large-2402-v1:0": "Mistral Large",
@@ -433,10 +435,40 @@ Please think if you need to use a tool or not for user's question, you must:
return content_parts
class Llama2Model(BedrockModel):
class LlamaModel(BedrockModel):
def _convert_prompt(self, chat_request: ChatRequest) -> str:
"""Create a prompt message follow below example:
@staticmethod
def create_llama3_prompt(chat_request: ChatRequest) -> str:
    """Create a prompt string for Llama 3 instruct models.

    Follows the Llama 3 chat template, e.g.:

    <|begin_of_text|><|start_header_id|>system<|end_header_id|>

    {{ system_prompt }}<|eot_id|><|start_header_id|>user<|end_header_id|>

    {{ user_message_1 }}<|eot_id|><|start_header_id|>assistant<|end_header_id|>

    {{ model_answer_1 }}<|eot_id|><|start_header_id|>user<|end_header_id|>

    {{ user_message_2 }}<|eot_id|><|start_header_id|>assistant<|end_header_id|>

    Args:
        chat_request: request whose ``messages`` (each with ``role`` and
            ``content``) are rendered into the template in order.

    Returns:
        The fully assembled prompt, ending with an open assistant header so
        the model generates the next reply.
    """
    if DEBUG:
        logger.info("Convert below messages to prompt for Llama 3: ")
        for msg in chat_request.messages:
            logger.info(msg.model_dump_json())
    bos_token = "<|begin_of_text|>"
    # One "<|start_header_id|>{role}<|end_header_id|>\n\n{content}<|eot_id|>"
    # turn per message, in the order they appear in the request.
    prompt_lines = [
        f"<|start_header_id|>{msg.role}<|end_header_id|>\n\n{msg.content}<|eot_id|>"
        for msg in chat_request.messages
    ]
    # Trailing open assistant header (no <|eot_id|>) cues the model to answer.
    # Plain literal: the original used an f-string with no placeholders.
    prompt_lines.append("<|start_header_id|>assistant<|end_header_id|>\n\n")
    prompt = bos_token + "".join(prompt_lines)
    if DEBUG:
        # Lazy %-args per logging convention instead of eager "+" concatenation;
        # newlines escaped so the prompt stays on one log line.
        logger.info("Converted prompt: %s", prompt.replace("\n", "\\n"))
    return prompt
@staticmethod
def create_llama2_prompt(chat_request: ChatRequest) -> str:
"""Create a prompt message for Llama 2 following below example:
<s>[INST] <<SYS>>\n{your_system_message}\n<</SYS>>\n\n{user_message_1} [/INST] {model_reply_1}</s>
<s>[INST] {user_message_2} [/INST]
@@ -480,7 +512,11 @@ class Llama2Model(BedrockModel):
return prompt
def _parse_args(self, chat_request: ChatRequest) -> dict:
prompt = self._convert_prompt(chat_request)
if chat_request.model.startswith("meta.llama2"):
prompt = self.create_llama2_prompt(chat_request)
else:
prompt = self.create_llama3_prompt(chat_request)
# Currently, there is no way to set stop sequence for Llama 3 models.
return {
"prompt": prompt,
"max_gen_len": chat_request.max_tokens,
@@ -763,8 +799,8 @@ def get_model(model_id: str) -> BedrockModel:
match model_name:
case "Claude Instant" | "Claude" | "Claude 3 Sonnet" | "Claude 3 Haiku" | "Claude 3 Opus":
return ClaudeModel()
case "Llama 2 Chat 13B" | "Llama 2 Chat 70B":
return Llama2Model()
case "Llama 2 Chat 13B" | "Llama 2 Chat 70B" | "Llama 3 8B Instruct" | "Llama 3 70B Instruct":
return LlamaModel()
case "Mistral 7B Instruct" | "Mixtral 8x7B Instruct" | "Mistral Large":
return MistralModel()
case _: