Commit 3fc0bae
lint, fix templating, default version
winglian committed Sep 10, 2024
1 parent 5822e18 commit 3fc0bae
Showing 5 changed files with 24 additions and 12 deletions.
7 changes: 5 additions & 2 deletions src/axolotl/cli/preprocess.py
@@ -24,6 +24,7 @@
from axolotl.common.cli import PreprocessCliArgs
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
from axolotl.prompt_strategies.sharegpt import register_chat_template

LOG = logging.getLogger("axolotl.cli.preprocess")


@@ -38,9 +39,11 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
return_remaining_strings=True
)

if parsed_cfg.chat_template:
register_chat_template(parsed_cfg.chat_template, parsed_cfg.default_system_message)
register_chat_template(
parsed_cfg.chat_template, parsed_cfg.default_system_message
)

if not parsed_cfg.dataset_prepared_path:
msg = (
2 changes: 1 addition & 1 deletion src/axolotl/cli/train.py
@@ -40,7 +40,7 @@ def do_train(cfg, cli_args) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
print_axolotl_text_art()
check_accelerate_default_config()
check_user_token()

if cfg.chat_template:
register_chat_template(cfg.chat_template, cfg.default_system_message)

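Both CLI entry points now follow the same guard-then-register pattern: the conversation template is registered before any dataset preprocessing or training starts, and only when a chat_template is configured. A minimal sketch of that call flow, with SimpleNamespace standing in for the parsed axolotl config (illustrative only, not the actual CLI code):

from types import SimpleNamespace

from axolotl.prompt_strategies.sharegpt import register_chat_template

# Stand-in for the parsed config; in axolotl these values come from the YAML config file.
cfg = SimpleNamespace(chat_template="llama31", default_system_message=None)

if cfg.chat_template:
    # Mirrors the guard in preprocess.py and train.py: register the conversation
    # template up front, before datasets are loaded and tokenized.
    register_chat_template(cfg.chat_template, cfg.default_system_message)
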
21 changes: 15 additions & 6 deletions src/axolotl/prompt_strategies/sharegpt.py
@@ -39,12 +39,18 @@ def register_chatml_template(system_message=None):
)


def register_llama3x_template(version, system_message=None):
extra_system_content_to_insert = "Cutting Knowledge Date: December 2023\nToday Date: 26 Jul 2024\n\n" if version == "3.1" else ""
def register_llama3x_template(version="", system_message=None):
extra_system_content_to_insert = (
"Cutting Knowledge Date: December 2023\nToday Date: 26 Jul 2024\n\n"
if version == "3.1"
else ""
)
register_conv_template(
Conversation(
name=f"llama{version.replace('.', '')}",
system_template="<|start_header_id|>system<|end_header_id|>\n\n{extra_system_content_to_insert}{system_message}<|eot_id|>",
system_template="<|start_header_id|>system<|end_header_id|>\n\n"
+ extra_system_content_to_insert
+ "{system_message}<|eot_id|>",
system_message=system_message,
roles=("user", "assistant"),
sep_style=SeparatorStyle.LLAMA3,
@@ -53,7 +59,7 @@ def register_llama3x_template(version, system_message=None):
stop_token_ids=[128001, 128009],
)
)


def build_loader(
tokenization_strategy_cls: Type["ShareGPTPromptTokenizingStrategy"],
@@ -92,18 +98,21 @@ def _load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):

return _load


def register_chat_template(chat_template, default_system_message=None):
if chat_template == "chatml" or chat_template.startswith("llama"):
if default_system_message:
LOG.info(f"Adding default system message: {default_system_message}")

if chat_template == "chatml":
register_chatml_template(default_system_message)
LOG.info(f"Using ChatML template")
LOG.info("Using ChatML template")
else:
version = ".".join(list(chat_template.split("llama"))[1])
if version not in ["3", "3.1"]:
raise ValueError(f"Invalid name for Llama 3x template: {chat_template}, only llama3 and llama31 are supported")
raise ValueError(
f"Invalid name for Llama 3x template: {chat_template}, only llama3 and llama31 are supported"
)
LOG.info(f"Using Llama {version} template")
register_llama3x_template(version, default_system_message)

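The sharegpt.py changes carry the "fix templating, default version" parts of the commit message: version now defaults to an empty string, and the extra system content is concatenated into system_template at registration time rather than left as a literal {extra_system_content_to_insert} placeholder inside a plain (non-f) string. A short sketch of how the version string is derived from the configured chat_template name, using values taken from the diff above (the asserts are illustrative only):

# How register_chat_template derives the Llama version from the template name.
chat_template = "llama31"
version = ".".join(list(chat_template.split("llama"))[1])  # ['', '31'][1] -> "31" -> "3.1"
assert version == "3.1"

# register_llama3x_template then registers the conversation under this name and,
# because version == "3.1", prepends the knowledge-cutoff/date header ahead of
# "{system_message}" in the system template.
name = f"llama{version.replace('.', '')}"
assert name == "llama31"

chat_template = "llama3"
version = ".".join(list(chat_template.split("llama"))[1])  # ['', '3'][1] -> "3"
assert version == "3"
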
(Diffs for the remaining 2 changed files are not shown here.)
