diff --git a/src/instructlab/training/chat_templates/ibm_generic_tmpl.py b/src/instructlab/training/chat_templates/ibm_generic_tmpl.py new file mode 100644 index 00000000..87bfdb0a --- /dev/null +++ b/src/instructlab/training/chat_templates/ibm_generic_tmpl.py @@ -0,0 +1,24 @@ +# First Party +from instructlab.training.tokenizer_utils import SpecialTokens + +SPECIAL_TOKENS = SpecialTokens( + system="<|system|>", + user="<|user|>", + assistant="<|assistant|>", + eos="<|endoftext|>", + pad="<|pad|>", +) + +CHAT_TEMPLATE = ( + "{% for message in messages %}" + "{% if message['role'] == 'pretraining' %}" + "{{'<|endoftext|>' + message['content'] + '<|endoftext|>'}}" + "{% elif message['role'] == 'system' %}" + "{{'<|system|>'+ '\n' + message['content'] + '\n'}}" + "{% elif message['role'] == 'user' %}" + "{{'<|user|>' + '\n' + message['content'] + '\n'}}" + "{% elif message['role'] == 'assistant' %}" + "{{'<|assistant|>' + '\n' + message['content'] + '<|endoftext|>' + ('' if loop.last else '\n')}}" + "{% endif %}" + "{% endfor %}" +) diff --git a/src/instructlab/training/chat_templates/mistral_tmpl.py b/src/instructlab/training/chat_templates/mistral_tmpl.py new file mode 100644 index 00000000..965823f2 --- /dev/null +++ b/src/instructlab/training/chat_templates/mistral_tmpl.py @@ -0,0 +1,23 @@ +# First Party +from instructlab.training.tokenizer_utils import SpecialTokens + +# Mistral delimits sequences with the literal tokens "<s>" (BOS) and "</s>" (EOS). +SPECIAL_TOKENS = SpecialTokens( + bos="<s>", + eos="</s>", + user="[INST]", + assistant="[/INST]", +) + +CHAT_TEMPLATE = ( + "{{ '<s>' }}" + "{% for message in messages %}" + "{% if message['role'] == 'pretraining' %}" + "{{ message['content'] + '</s>' }}" + "{% elif message['role'] == 'user' %}" + "{{ '[INST] ' + message['content'] + ' [/INST]' }}" + "{% elif message['role'] == 'assistant' %}" + "{{ message['content'] + '</s>'}}" + "{% endif %}" + "{% endfor %}" +) diff --git a/src/instructlab/training/config.py b/src/instructlab/training/config.py index c8733a45..83c7a1f8 100644 --- 
a/src/instructlab/training/config.py +++ b/src/instructlab/training/config.py @@ -4,6 +4,7 @@ # Standard from enum import Enum +import os # Third Party from pydantic import BaseModel, ConfigDict, Field @@ -42,6 +43,7 @@ class DataProcessArgs(BaseModel): data_output_path: str max_seq_len: int # defines the max sequence length of a sample model_path: str # either a HF model name or path to HF model + chat_tmpl_path: str # disable the protected namespace for the model_config field model_config = ConfigDict(protected_namespaces=()) @@ -100,6 +102,11 @@ class TrainingArgs(BaseModel): # Either the name of a HuggingFace model or a path to a model saved in HuggingFace format. model_path: str + # Specify the chat template / special tokens for training (default is ibm-generic template/tokens) + chat_tmpl_path: str = os.path.join( + os.path.dirname(__file__), "chat_templates/ibm_generic_tmpl.py" + ) + # this field specifies the filepath to the training dataset before processing data_path: str ckpt_output_dir: str diff --git a/src/instructlab/training/data_process.py b/src/instructlab/training/data_process.py index f8f43535..9301d185 100644 --- a/src/instructlab/training/data_process.py +++ b/src/instructlab/training/data_process.py @@ -2,6 +2,7 @@ from pathlib import Path from typing import List import logging +import os # Third Party from datasets import load_dataset @@ -10,12 +11,8 @@ # First Party from instructlab.training.config import DataProcessArgs -from instructlab.training.tokenizer_utils import ( - SPECIAL_TOKENS, - get_sp_token, - setup_tokenizer, -) -from instructlab.training.utils import log_rank_0, setup_logger +from instructlab.training.tokenizer_utils import get_sp_token, setup_tokenizer +from instructlab.training.utils import log_rank_0, retrieve_chat_template, setup_logger def check_valid_sample( @@ -37,18 +34,6 @@ def check_valid_sample( if not any(token in whole_sentence_tk for token in special_tokens): return True - # first token should be system_token - 
if whole_sentence_tk[0] != system_tk: - print("\033[91mfirst token is not a system_token\033[0m") - log_rank_0(tokenizer.decode(whole_sentence_tk), to_print=True) - return False - - # check there's only one system_token - if (np.array(whole_sentence_tk) == system_tk).sum() != 1: - print("\033[91mthere are more than one system_token\033[0m") - log_rank_0(tokenizer.decode(whole_sentence_tk), to_print=True) - return False - whole_sentence_tk = np.array(whole_sentence_tk) user_token_index = (whole_sentence_tk == user_tk).nonzero()[0] assistant_token_index = (whole_sentence_tk == assistant_tk).nonzero()[0] @@ -121,7 +106,11 @@ def unmask_only_assistant_responses( whole_sentence = chosen_token["input_ids"][:sentence_legth].clone() # pre-training mode - if system_tk not in whole_sentence: + if not ( + (system_tk is not None and system_tk in whole_sentence) + or user_token in whole_sentence + or assist_token in whole_sentence + ): return labels labels[:sentence_legth] = -100 @@ -204,11 +193,15 @@ def remove_pretrain_system_messages(example: dict): def main(args: DataProcessArgs): - tokenizer = setup_tokenizer(args.model_path) + CHAT_TEMPLATE, SPECIAL_TOKENS = retrieve_chat_template(args.chat_tmpl_path) + tokenizer = setup_tokenizer(args.model_path, SPECIAL_TOKENS, CHAT_TEMPLATE) eos_tk = get_sp_token(tokenizer, SPECIAL_TOKENS.eos) pad_tk = get_sp_token(tokenizer, SPECIAL_TOKENS.pad) - system_tk = get_sp_token(tokenizer, SPECIAL_TOKENS.system) + if SPECIAL_TOKENS.system: + system_tk = get_sp_token(tokenizer, SPECIAL_TOKENS.system) + else: + system_tk = None user_tk = get_sp_token(tokenizer, SPECIAL_TOKENS.user) assistant_tk = get_sp_token(tokenizer, SPECIAL_TOKENS.assistant) log_rank_0( @@ -309,6 +302,14 @@ def main(args: DataProcessArgs): parser.add_argument( "--model_name_or_path", type=str, required=True, help="Model name or path" ) + parser.add_argument( + "--chat-tmpl-path", + type=str, + default=os.path.join( + os.path.dirname(__file__), "chat_templates/ibm_generic_tmpl.py" + ), + help="Path to 
desired chat template and special tokens, defaults to IBM generic.", + ) args = parser.parse_args() setup_logger(args.logging_level) data_process_args = DataProcessArgs( @@ -316,6 +317,7 @@ def main(args: DataProcessArgs): data_path=args.data_path, max_seq_len=args.max_seq_len, model_path=args.model_name_or_path, + chat_tmpl_path=args.chat_tmpl_path, ) main(data_process_args) diff --git a/src/instructlab/training/main_ds.py b/src/instructlab/training/main_ds.py index eeb0c077..1eb21a68 100644 --- a/src/instructlab/training/main_ds.py +++ b/src/instructlab/training/main_ds.py @@ -37,6 +37,7 @@ patch_target_module, prepare_peft_model, prepare_universal_checkpoint_from_latest, + retrieve_chat_template, save_hf_format_ds, save_model_ds_native, set_random_seed, @@ -438,7 +439,8 @@ def main(args): print(f"\033[38;5;120m{yaml.dump(vars(args), sort_keys=False)}\033[0m") setup_logger(args.log_level) - tokenizer = setup_tokenizer(args.model_name_or_path) + CHAT_TEMPLATE, SPECIAL_TOKENS = retrieve_chat_template(args.chat_tmpl_path) + tokenizer = setup_tokenizer(args.model_name_or_path, SPECIAL_TOKENS, CHAT_TEMPLATE) # device = torch.device("cuda", args.local_rank) #### distributed init ##### @@ -522,6 +524,7 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs): model_path=train_args.model_path, data_path=train_args.data_path, max_seq_len=train_args.max_seq_len, + chat_tmpl_path=train_args.chat_tmpl_path, ) ) @@ -546,6 +549,7 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs): f"--log_level=INFO", f"--max_batch_len={train_args.max_batch_len}", f"--seed={train_args.random_seed}", + f"--chat-tmpl-path={train_args.chat_tmpl_path}", ] if train_args.mock_data: @@ -644,6 +648,13 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs): help="Offload optimizer to CPU when using DeepSpeed. 
This configures it to use ZeRO stage 2.", ) parser.add_argument("--NEFTune_alpha", type=float, default=None) + parser.add_argument( + "--chat-tmpl-path", + type=str, + default=os.path.join( + os.path.dirname(__file__), "chat_templates/ibm_generic_tmpl.py" + ), + ) args = parser.parse_args() set_random_seed(args.seed) main(args) diff --git a/src/instructlab/training/tokenizer_utils.py b/src/instructlab/training/tokenizer_utils.py index 7eff0e69..5c789441 100644 --- a/src/instructlab/training/tokenizer_utils.py +++ b/src/instructlab/training/tokenizer_utils.py @@ -10,46 +10,34 @@ @dataclass class SpecialTokens: - system: str = field(default="<|system|>") + system: str = field(default=None) user: str = field(default="<|user|>") assistant: str = field(default="<|assistant|>") eos: str = field(default="<|endoftext|>") - pad: str = field(default="<|pad|>") + pad: str = field(default=None) + bos: str = field(default="<|begginingoftext|>") -SPECIAL_TOKENS = SpecialTokens() - -CHAT_TEMPLATE = ( - "{% for message in messages %}" - "{% if message['role'] == 'pretraining' %}" - "{{'<|endoftext|>' + message['content'] + '<|endoftext|>'}}" - "{% elif message['role'] == 'system' %}" - "{{'<|system|>'+ '\n' + message['content'] + '\n'}}" - "{% elif message['role'] == 'user' %}" - "{{'<|user|>' + '\n' + message['content'] + '\n'}}" - "{% elif message['role'] == 'assistant' %}" - "{{'<|assistant|>' + '\n' + message['content'] + '<|endoftext|>' + ('' if loop.last else '\n')}}" - "{% endif %}" - "{% endfor %}" -) - - -def setup_tokenizer( - model_name_or_path, SPECIAL_TOKENS=SPECIAL_TOKENS, CHAT_TEMPLATE=CHAT_TEMPLATE -): +def setup_tokenizer(model_name_or_path, SPECIAL_TOKENS, CHAT_TEMPLATE): tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, fast_tokenizer=True) - tokenizer.add_special_tokens( - {"eos_token": SPECIAL_TOKENS.eos, "pad_token": SPECIAL_TOKENS.pad} - ) + + if not SPECIAL_TOKENS.pad: + SPECIAL_TOKENS.pad = SPECIAL_TOKENS.eos tokenizer.add_special_tokens( { - 
"additional_special_tokens": [ - SPECIAL_TOKENS.system, - SPECIAL_TOKENS.user, - SPECIAL_TOKENS.assistant, - ] + "bos_token": SPECIAL_TOKENS.bos, + "eos_token": SPECIAL_TOKENS.eos, + "pad_token": SPECIAL_TOKENS.pad, } ) + + if SPECIAL_TOKENS.system: + add_token_list = [SPECIAL_TOKENS.system] + else: + add_token_list = [] + add_token_list.extend([SPECIAL_TOKENS.user, SPECIAL_TOKENS.assistant]) + + tokenizer.add_special_tokens({"additional_special_tokens": add_token_list}) if getattr(tokenizer, "add_bos_token", False) or getattr( tokenizer, "add_eos_token", False ): diff --git a/src/instructlab/training/utils.py b/src/instructlab/training/utils.py index 6feaa548..5eccb5dc 100644 --- a/src/instructlab/training/utils.py +++ b/src/instructlab/training/utils.py @@ -24,6 +24,24 @@ import torch.nn.functional as F +def retrieve_chat_template(chat_tmpl_path): + """Load CHAT_TEMPLATE and SPECIAL_TOKENS from the Python file at chat_tmpl_path. + + The file is executed as a module, so only pass template files you trust. + Exits the process with an error message if the file cannot be loaded. + """ + try: + spec = importlib.util.spec_from_file_location("spcl_chat_tmpl", chat_tmpl_path) + module = importlib.util.module_from_spec(spec) + sys.modules["spcl_chat_tmpl"] = module + spec.loader.exec_module(module) + SPECIAL_TOKENS = module.SPECIAL_TOKENS + CHAT_TEMPLATE = module.CHAT_TEMPLATE + except Exception as err: + sys.exit(f"Invalid chat template path: {chat_tmpl_path} ({err})") + return CHAT_TEMPLATE, SPECIAL_TOKENS + + def add_noisy_embeddings(model, noise_alpha=None): if not noise_alpha: return model