Add Llama 3 support
This commit is contained in:
@@ -48,6 +48,8 @@ SUPPORTED_BEDROCK_MODELS = {
|
|||||||
"anthropic.claude-3-haiku-20240307-v1:0": "Claude 3 Haiku",
|
"anthropic.claude-3-haiku-20240307-v1:0": "Claude 3 Haiku",
|
||||||
"meta.llama2-13b-chat-v1": "Llama 2 Chat 13B",
|
"meta.llama2-13b-chat-v1": "Llama 2 Chat 13B",
|
||||||
"meta.llama2-70b-chat-v1": "Llama 2 Chat 70B",
|
"meta.llama2-70b-chat-v1": "Llama 2 Chat 70B",
|
||||||
|
"meta.llama3-8b-instruct-v1:0": "Llama 3 8B Instruct",
|
||||||
|
"meta.llama3-70b-instruct-v1:0": "Llama 3 70B Instruct",
|
||||||
"mistral.mistral-7b-instruct-v0:2": "Mistral 7B Instruct",
|
"mistral.mistral-7b-instruct-v0:2": "Mistral 7B Instruct",
|
||||||
"mistral.mixtral-8x7b-instruct-v0:1": "Mixtral 8x7B Instruct",
|
"mistral.mixtral-8x7b-instruct-v0:1": "Mixtral 8x7B Instruct",
|
||||||
"mistral.mistral-large-2402-v1:0": "Mistral Large",
|
"mistral.mistral-large-2402-v1:0": "Mistral Large",
|
||||||
@@ -433,10 +435,40 @@ Please think if you need to use a tool or not for user's question, you must:
|
|||||||
return content_parts
|
return content_parts
|
||||||
|
|
||||||
|
|
||||||
class Llama2Model(BedrockModel):
|
class LlamaModel(BedrockModel):
|
||||||
|
|
||||||
def _convert_prompt(self, chat_request: ChatRequest) -> str:
|
@staticmethod
|
||||||
"""Create a prompt message follow below example:
|
def create_llama3_prompt(chat_request: ChatRequest) -> str:
|
||||||
|
"""Create a prompt message for Llama 3 following the example below:
|
||||||
|
|
||||||
|
<|begin_of_text|><|start_header_id|>system<|end_header_id|>
|
||||||
|
|
||||||
|
{{ system_prompt }}<|eot_id|><|start_header_id|>user<|end_header_id|>
|
||||||
|
|
||||||
|
{{ user_message_1 }}<|eot_id|><|start_header_id|>assistant<|end_header_id|>
|
||||||
|
|
||||||
|
{{ model_answer_1 }}<|eot_id|><|start_header_id|>user<|end_header_id|>
|
||||||
|
|
||||||
|
{{ user_message_2 }}<|eot_id|><|start_header_id|>assistant<|end_header_id|>
|
||||||
|
"""
|
||||||
|
if DEBUG:
|
||||||
|
logger.info("Convert below messages to prompt for Llama 3: ")
|
||||||
|
for msg in chat_request.messages:
|
||||||
|
logger.info(msg.model_dump_json())
|
||||||
|
bos_token = "<|begin_of_text|>"
|
||||||
|
|
||||||
|
prompt_lines = []
|
||||||
|
for msg in chat_request.messages:
|
||||||
|
prompt_lines.append(f"<|start_header_id|>{msg.role}<|end_header_id|>\n\n{msg.content}<|eot_id|>")
|
||||||
|
prompt_lines.append(f"<|start_header_id|>assistant<|end_header_id|>\n\n")
|
||||||
|
prompt = bos_token + "".join(prompt_lines)
|
||||||
|
if DEBUG:
|
||||||
|
logger.info("Converted prompt: " + prompt.replace("\n", "\\n"))
|
||||||
|
return prompt
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def create_llama2_prompt(chat_request: ChatRequest) -> str:
|
||||||
|
"""Create a prompt message for Llama 2 following the example below:
|
||||||
|
|
||||||
<s>[INST] <<SYS>>\n{your_system_message}\n<</SYS>>\n\n{user_message_1} [/INST] {model_reply_1}</s>
|
<s>[INST] <<SYS>>\n{your_system_message}\n<</SYS>>\n\n{user_message_1} [/INST] {model_reply_1}</s>
|
||||||
<s>[INST] {user_message_2} [/INST]
|
<s>[INST] {user_message_2} [/INST]
|
||||||
@@ -480,7 +512,11 @@ class Llama2Model(BedrockModel):
|
|||||||
return prompt
|
return prompt
|
||||||
|
|
||||||
def _parse_args(self, chat_request: ChatRequest) -> dict:
|
def _parse_args(self, chat_request: ChatRequest) -> dict:
|
||||||
prompt = self._convert_prompt(chat_request)
|
if chat_request.model.startswith("meta.llama2"):
|
||||||
|
prompt = self.create_llama2_prompt(chat_request)
|
||||||
|
else:
|
||||||
|
prompt = self.create_llama3_prompt(chat_request)
|
||||||
|
# Currently, there is no way to set a stop sequence for Llama 3 models.
|
||||||
return {
|
return {
|
||||||
"prompt": prompt,
|
"prompt": prompt,
|
||||||
"max_gen_len": chat_request.max_tokens,
|
"max_gen_len": chat_request.max_tokens,
|
||||||
@@ -763,8 +799,8 @@ def get_model(model_id: str) -> BedrockModel:
|
|||||||
match model_name:
|
match model_name:
|
||||||
case "Claude Instant" | "Claude" | "Claude 3 Sonnet" | "Claude 3 Haiku" | "Claude 3 Opus":
|
case "Claude Instant" | "Claude" | "Claude 3 Sonnet" | "Claude 3 Haiku" | "Claude 3 Opus":
|
||||||
return ClaudeModel()
|
return ClaudeModel()
|
||||||
case "Llama 2 Chat 13B" | "Llama 2 Chat 70B":
|
case "Llama 2 Chat 13B" | "Llama 2 Chat 70B" | "Llama 3 8B Instruct" | "Llama 3 70B Instruct":
|
||||||
return Llama2Model()
|
return LlamaModel()
|
||||||
case "Mistral 7B Instruct" | "Mixtral 8x7B Instruct" | "Mistral Large":
|
case "Mistral 7B Instruct" | "Mixtral 8x7B Instruct" | "Mistral Large":
|
||||||
return MistralModel()
|
return MistralModel()
|
||||||
case _:
|
case _:
|
||||||
|
|||||||
@@ -21,6 +21,8 @@ List of Amazon Bedrock models currently supported:
|
|||||||
- anthropic.claude-3-haiku-20240307-v1:0
|
- anthropic.claude-3-haiku-20240307-v1:0
|
||||||
- meta.llama2-13b-chat-v1
|
- meta.llama2-13b-chat-v1
|
||||||
- meta.llama2-70b-chat-v1
|
- meta.llama2-70b-chat-v1
|
||||||
|
- meta.llama3-8b-instruct-v1:0
|
||||||
|
- meta.llama3-70b-instruct-v1:0
|
||||||
- mistral.mistral-7b-instruct-v0:2
|
- mistral.mistral-7b-instruct-v0:2
|
||||||
- mistral.mixtral-8x7b-instruct-v0:1
|
- mistral.mixtral-8x7b-instruct-v0:1
|
||||||
- mistral.mistral-large-2402-v1:0
|
- mistral.mistral-large-2402-v1:0
|
||||||
|
|||||||
Reference in New Issue
Block a user