From 5f513b59938e6a0e431a5fb0061d4831b8dcc1a9 Mon Sep 17 00:00:00 2001 From: Mustafa Eyceoz Date: Wed, 19 Jun 2024 14:52:24 -0400 Subject: [PATCH] First pass chat template (backend only) --- .../chat_templates/ibm_generic_tmpl.py | 23 +++++++++++ .../training/chat_templates/mistral_tmpl.py | 23 +++++++++++ src/instructlab/training/data_process.py | 14 +------ src/instructlab/training/tokenizer_utils.py | 40 ++++++++----------- 4 files changed, 64 insertions(+), 36 deletions(-) create mode 100644 src/instructlab/training/chat_templates/ibm_generic_tmpl.py create mode 100644 src/instructlab/training/chat_templates/mistral_tmpl.py diff --git a/src/instructlab/training/chat_templates/ibm_generic_tmpl.py b/src/instructlab/training/chat_templates/ibm_generic_tmpl.py new file mode 100644 index 00000000..ba54e68d --- /dev/null +++ b/src/instructlab/training/chat_templates/ibm_generic_tmpl.py @@ -0,0 +1,23 @@ +from tokenizer_utils import SpecialTokens + +SPECIAL_TOKENS = SpecialTokens( + system="<|system|>", + user="<|user|>", + assistant="<|assistant|>", + eos="<|endoftext|>", + pad="<|pad|>" +) + +CHAT_TEMPLATE = ( + "{% for message in messages %}" + "{% if message['role'] == 'pretraining' %}" + "{{'<|endoftext|>' + message['content'] + '<|endoftext|>'}}" + "{% elif message['role'] == 'system' %}" + "{{'<|system|>'+ '\n' + message['content'] + '\n'}}" + "{% elif message['role'] == 'user' %}" + "{{'<|user|>' + '\n' + message['content'] + '\n'}}" + "{% elif message['role'] == 'assistant' %}" + "{{'<|assistant|>' + '\n' + message['content'] + '<|endoftext|>' + ('' if loop.last else '\n')}}" + "{% endif %}" + "{% endfor %}" +) \ No newline at end of file diff --git a/src/instructlab/training/chat_templates/mistral_tmpl.py b/src/instructlab/training/chat_templates/mistral_tmpl.py new file mode 100644 index 00000000..ba85e93d --- /dev/null +++ b/src/instructlab/training/chat_templates/mistral_tmpl.py @@ -0,0 +1,23 @@ +from tokenizer_utils import SpecialTokens + +SPECIAL_TOKENS 
= SpecialTokens( + bos="&lt;s&gt;", + eos="&lt;/s&gt;", + user="[INST]", + assistant="[/INST]", + + +) + + CHAT_TEMPLATE = ( + "{{ '&lt;s&gt;' }}" + "{% for message in messages %}" + "{% if message['role'] == 'pretraining' %}" + "{{ message['content'] + '&lt;/s&gt;' }}" + "{% elif message['role'] == 'user' %}" + "{{ '[INST] ' + message['content'] + ' [/INST]' }}" + "{% elif message['role'] == 'assistant' %}" + "{{ message['content'] + '&lt;/s&gt;'}}" + "{% endif %}" + "{% endfor %}" +) diff --git a/src/instructlab/training/data_process.py b/src/instructlab/training/data_process.py index f8f43535..9ea51a38 100644 --- a/src/instructlab/training/data_process.py +++ b/src/instructlab/training/data_process.py @@ -37,18 +37,6 @@ def check_valid_sample( if not any(token in whole_sentence_tk for token in special_tokens): return True - # first token should be system_token - if whole_sentence_tk[0] != system_tk: - print("\033[91mfirst token is not a system_token\033[0m") - log_rank_0(tokenizer.decode(whole_sentence_tk), to_print=True) - return False - - # check there's only one system_token - if (np.array(whole_sentence_tk) == system_tk).sum() != 1: - print("\033[91mthere are more than one system_token\033[0m") - log_rank_0(tokenizer.decode(whole_sentence_tk), to_print=True) - return False - whole_sentence_tk = np.array(whole_sentence_tk) user_token_index = (whole_sentence_tk == user_tk).nonzero()[0] assistant_token_index = (whole_sentence_tk == assistant_tk).nonzero()[0] @@ -121,7 +109,7 @@ def unmask_only_assistant_responses( whole_sentence = chosen_token["input_ids"][:sentence_legth].clone() # pre-training mode - if system_tk not in whole_sentence: + if not (system_tk in whole_sentence or user_token in whole_sentence or assist_token in whole_sentence): return labels labels[:sentence_legth] = -100 diff --git a/src/instructlab/training/tokenizer_utils.py b/src/instructlab/training/tokenizer_utils.py index 7eff0e69..c7e195ac 100644 --- a/src/instructlab/training/tokenizer_utils.py +++ 
b/src/instructlab/training/tokenizer_utils.py @@ -10,44 +10,38 @@ @dataclass class SpecialTokens: - system: str = field(default="<|system|>") + system: str = field(default=None) user: str = field(default="<|user|>") assistant: str = field(default="<|assistant|>") eos: str = field(default="<|endoftext|>") - pad: str = field(default="<|pad|>") + pad: str = field(default=None) + bos: str = field(default="<|beginningoftext|>") -SPECIAL_TOKENS = SpecialTokens() - -CHAT_TEMPLATE = ( - "{% for message in messages %}" - "{% if message['role'] == 'pretraining' %}" - "{{'<|endoftext|>' + message['content'] + '<|endoftext|>'}}" - "{% elif message['role'] == 'system' %}" - "{{'<|system|>'+ '\n' + message['content'] + '\n'}}" - "{% elif message['role'] == 'user' %}" - "{{'<|user|>' + '\n' + message['content'] + '\n'}}" - "{% elif message['role'] == 'assistant' %}" - "{{'<|assistant|>' + '\n' + message['content'] + '<|endoftext|>' + ('' if loop.last else '\n')}}" - "{% endif %}" - "{% endfor %}" -) +#TODO: Replace with specified template path +from instructlab.training.chat_templates.ibm_generic_tmpl import SPECIAL_TOKENS, CHAT_TEMPLATE def setup_tokenizer( model_name_or_path, SPECIAL_TOKENS=SPECIAL_TOKENS, CHAT_TEMPLATE=CHAT_TEMPLATE ): tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, fast_tokenizer=True) + + if not SPECIAL_TOKENS.pad: + SPECIAL_TOKENS.pad = SPECIAL_TOKENS.eos tokenizer.add_special_tokens( - {"eos_token": SPECIAL_TOKENS.eos, "pad_token": SPECIAL_TOKENS.pad} + {"bos_token": SPECIAL_TOKENS.bos, "eos_token": SPECIAL_TOKENS.eos, "pad_token": SPECIAL_TOKENS.pad} ) + + if SPECIAL_TOKENS.system: + add_token_list = [SPECIAL_TOKENS.system] + else: + add_token_list = [] + add_token_list.extend([SPECIAL_TOKENS.user, SPECIAL_TOKENS.assistant]) + tokenizer.add_special_tokens( { - "additional_special_tokens": [ - SPECIAL_TOKENS.system, - SPECIAL_TOKENS.user, - SPECIAL_TOKENS.assistant, - ] + "additional_special_tokens": add_token_list } ) if getattr(tokenizer, 
"add_bos_token", False) or getattr(