Skip to content

Commit

Permalink
lint, fix templating, default version
Browse files Browse the repository at this point in the history
  • Loading branch information
winglian committed Oct 14, 2024
1 parent ff8db56 commit ab41470
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 9 deletions.
4 changes: 3 additions & 1 deletion src/axolotl/cli/preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,9 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
)

if parsed_cfg.chat_template:
register_chat_template(parsed_cfg.chat_template, parsed_cfg.default_system_message)
register_chat_template(
parsed_cfg.chat_template, parsed_cfg.default_system_message
)

if not parsed_cfg.dataset_prepared_path:
msg = (
Expand Down
19 changes: 14 additions & 5 deletions src/axolotl/prompt_strategies/sharegpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,18 @@ def register_chatml_template(system_message=None):
)


def register_llama3x_template(version, system_message=None):
extra_system_content_to_insert = "Cutting Knowledge Date: December 2023\nToday Date: 26 Jul 2024\n\n" if version == "3.1" else ""
def register_llama3x_template(version="", system_message=None):
extra_system_content_to_insert = (
"Cutting Knowledge Date: December 2023\nToday Date: 26 Jul 2024\n\n"
if version == "3.1"
else ""
)
register_conv_template(
Conversation(
name=f"llama{version.replace('.', '')}",
system_template="<|start_header_id|>system<|end_header_id|>\n\n{extra_system_content_to_insert}{system_message}<|eot_id|>",
system_template="<|start_header_id|>system<|end_header_id|>\n\n"
+ extra_system_content_to_insert
+ "{system_message}<|eot_id|>",
system_message=system_message,
roles=("user", "assistant"),
sep_style=SeparatorStyle.LLAMA3,
Expand Down Expand Up @@ -95,18 +101,21 @@ def _load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):

return _load


def register_chat_template(chat_template, default_system_message=None):
    """Register the conversation template selected by ``chat_template``.

    ``"chatml"`` registers the ChatML template; ``"llama3"`` and ``"llama31"``
    register the Llama 3 / 3.1 template. A name that is neither ``"chatml"``
    nor prefixed with ``"llama"`` is ignored and nothing is registered.

    Args:
        chat_template: Name of the chat template to register.
        default_system_message: Optional system message forwarded to the
            template registration helper; logged when truthy.

    Raises:
        ValueError: If ``chat_template`` starts with ``"llama"`` but is not
            exactly ``"llama3"`` or ``"llama31"``.
    """
    if chat_template != "chatml" and not chat_template.startswith("llama"):
        return

    if default_system_message:
        LOG.info(f"Adding default system message: {default_system_message}")

    if chat_template == "chatml":
        register_chatml_template(default_system_message)
        LOG.info("Using ChatML template")
        return

    # Exact-name lookup: splitting on "llama" and taking the second piece
    # would wrongly accept names with trailing text such as "llama3llamafoo".
    llama_versions = {"llama3": "3", "llama31": "3.1"}
    version = llama_versions.get(chat_template)
    if version is None:
        raise ValueError(
            f"Invalid name for Llama 3x template: {chat_template}, only llama3 and llama31 are supported"
        )
    LOG.info(f"Using Llama {version} template")
    register_llama3x_template(version, default_system_message)

Expand Down
Loading

0 comments on commit ab41470

Please sign in to comment.