Add Llama 3 support
@@ -48,6 +48,8 @@ SUPPORTED_BEDROCK_MODELS = {
     "anthropic.claude-3-haiku-20240307-v1:0": "Claude 3 Haiku",
     "meta.llama2-13b-chat-v1": "Llama 2 Chat 13B",
     "meta.llama2-70b-chat-v1": "Llama 2 Chat 70B",
+    "meta.llama3-8b-instruct-v1:0": "Llama 3 8B Instruct",
+    "meta.llama3-70b-instruct-v1:0": "Llama 3 70B Instruct",
     "mistral.mistral-7b-instruct-v0:2": "Mistral 7B Instruct",
     "mistral.mixtral-8x7b-instruct-v0:1": "Mixtral 8x7B Instruct",
     "mistral.mistral-large-2402-v1:0": "Mistral Large",
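The registry above is the single source of truth for which Bedrock model IDs the proxy accepts; each ID maps to a display name that the get_model factory (further down in this diff) matches on. As a minimal sketch of that lookup, assuming only the SUPPORTED_BEDROCK_MODELS dict from this file (the helper name is hypothetical):

# Hypothetical helper, for illustration only: resolve a raw Bedrock model ID
# to the display name used for dispatch in get_model.
def resolve_model_name(model_id: str) -> str:
    try:
        return SUPPORTED_BEDROCK_MODELS[model_id]
    except KeyError:
        raise ValueError(f"Unsupported Bedrock model: {model_id}")

# resolve_model_name("meta.llama3-8b-instruct-v1:0") -> "Llama 3 8B Instruct"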
@@ -433,10 +435,40 @@ Please think if you need to use a tool or not for user's question, you must:
         return content_parts
 
 
-class Llama2Model(BedrockModel):
+class LlamaModel(BedrockModel):
 
-    def _convert_prompt(self, chat_request: ChatRequest) -> str:
-        """Create a prompt message follow below example:
+    @staticmethod
+    def create_llama3_prompt(chat_request: ChatRequest) -> str:
+        """Create a prompt message for Llama 3 following the example below:
+
+        <|begin_of_text|><|start_header_id|>system<|end_header_id|>
+
+        {{ system_prompt }}<|eot_id|><|start_header_id|>user<|end_header_id|>
+
+        {{ user_message_1 }}<|eot_id|><|start_header_id|>assistant<|end_header_id|>
+
+        {{ model_answer_1 }}<|eot_id|><|start_header_id|>user<|end_header_id|>
+
+        {{ user_message_2 }}<|eot_id|><|start_header_id|>assistant<|end_header_id|>
+        """
+        if DEBUG:
+            logger.info("Converting the messages below to a Llama 3 prompt:")
+            for msg in chat_request.messages:
+                logger.info(msg.model_dump_json())
+        bos_token = "<|begin_of_text|>"
+
+        prompt_lines = []
+        for msg in chat_request.messages:
+            prompt_lines.append(f"<|start_header_id|>{msg.role}<|end_header_id|>\n\n{msg.content}<|eot_id|>")
+        prompt_lines.append("<|start_header_id|>assistant<|end_header_id|>\n\n")
+        prompt = bos_token + "".join(prompt_lines)
+        if DEBUG:
+            logger.info("Converted prompt: " + prompt.replace("\n", "\\n"))
+        return prompt
+
+    @staticmethod
+    def create_llama2_prompt(chat_request: ChatRequest) -> str:
+        """Create a prompt message for Llama 2 following the example below:
 
         <s>[INST] <<SYS>>\n{your_system_message}\n<</SYS>>\n\n{user_message_1} [/INST] {model_reply_1}</s>
         <s>[INST] {user_message_2} [/INST]
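To make the Llama 3 template concrete, here is a self-contained sketch of the prompt that create_llama3_prompt builds for a short conversation; the Msg dataclass is a hypothetical stand-in for the real ChatRequest message type, which is defined outside this diff:

from dataclasses import dataclass

@dataclass
class Msg:  # hypothetical stand-in for the ChatRequest message model
    role: str
    content: str

messages = [
    Msg("system", "You are a helpful assistant."),
    Msg("user", "What is Amazon Bedrock?"),
]

# Same construction as create_llama3_prompt above, inlined for illustration:
# each message becomes a header block closed by <|eot_id|>, and a trailing
# assistant header tells the model where to continue generating.
prompt = "<|begin_of_text|>" + "".join(
    f"<|start_header_id|>{m.role}<|end_header_id|>\n\n{m.content}<|eot_id|>"
    for m in messages
) + "<|start_header_id|>assistant<|end_header_id|>\n\n"

# Result (newlines shown as \n):
# <|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nYou are a helpful
# assistant.<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nWhat is Amazon
# Bedrock?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n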
@@ -480,7 +512,11 @@ class Llama2Model(BedrockModel):
         return prompt
 
     def _parse_args(self, chat_request: ChatRequest) -> dict:
-        prompt = self._convert_prompt(chat_request)
+        if chat_request.model.startswith("meta.llama2"):
+            prompt = self.create_llama2_prompt(chat_request)
+        else:
+            prompt = self.create_llama3_prompt(chat_request)
+        # Currently, there is no way to set a stop sequence for Llama 3 models.
         return {
             "prompt": prompt,
             "max_gen_len": chat_request.max_tokens,
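For context, a rough sketch of how the dict returned by _parse_args would reach Bedrock; the actual invocation path lives outside this diff, and the example assumes the standard boto3 "bedrock-runtime" client:

import json

import boto3

# Sketch only: invoke a Llama 3 model with the args shape produced by
# _parse_args. Error handling and streaming are omitted.
client = boto3.client("bedrock-runtime")
body = {
    "prompt": "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n"
    "Hello!<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n",
    "max_gen_len": 512,
}
response = client.invoke_model(
    modelId="meta.llama3-8b-instruct-v1:0",
    body=json.dumps(body),
)
# Meta models on Bedrock return the completion under the "generation" key.
print(json.loads(response["body"].read())["generation"])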
@@ -763,8 +799,8 @@ def get_model(model_id: str) -> BedrockModel:
     match model_name:
         case "Claude Instant" | "Claude" | "Claude 3 Sonnet" | "Claude 3 Haiku" | "Claude 3 Opus":
             return ClaudeModel()
-        case "Llama 2 Chat 13B" | "Llama 2 Chat 70B":
-            return Llama2Model()
+        case "Llama 2 Chat 13B" | "Llama 2 Chat 70B" | "Llama 3 8B Instruct" | "Llama 3 70B Instruct":
+            return LlamaModel()
         case "Mistral 7B Instruct" | "Mixtral 8x7B Instruct" | "Mistral Large":
             return MistralModel()
         case _:
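With the widened case arm, both Llama generations resolve to the same handler; the per-generation split now happens inside LlamaModel._parse_args rather than at the factory. For example:

model = get_model("meta.llama2-13b-chat-v1")        # LlamaModel, Llama 2 prompt path
model = get_model("meta.llama3-70b-instruct-v1:0")  # LlamaModel, Llama 3 prompt path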
@@ -21,6 +21,8 @@ List of Amazon Bedrock models currently supported:
 - anthropic.claude-3-haiku-20240307-v1:0
 - meta.llama2-13b-chat-v1
 - meta.llama2-70b-chat-v1
+- meta.llama3-8b-instruct-v1:0
+- meta.llama3-70b-instruct-v1:0
 - mistral.mistral-7b-instruct-v0:2
 - mistral.mixtral-8x7b-instruct-v0:1
 - mistral.mistral-large-2402-v1:0