Adding custom chat template / special token support (#54)
Adds support for more than just the currently hard-coded IBM generic special tokens / chat template.

---------

Signed-off-by: Mustafa Eyceoz <[email protected]>
Maxusmusti authored Jun 21, 2024
1 parent 4ff2b4d commit 08647d7
Showing 7 changed files with 119 additions and 52 deletions.
24 changes: 24 additions & 0 deletions src/instructlab/training/chat_templates/ibm_generic_tmpl.py
@@ -0,0 +1,24 @@
# First Party
from instructlab.training.tokenizer_utils import SpecialTokens

SPECIAL_TOKENS = SpecialTokens(
    system="<|system|>",
    user="<|user|>",
    assistant="<|assistant|>",
    eos="<|endoftext|>",
    pad="<|pad|>",
)

CHAT_TEMPLATE = (
    "{% for message in messages %}"
    "{% if message['role'] == 'pretraining' %}"
    "{{'<|endoftext|>' + message['content'] + '<|endoftext|>'}}"
    "{% elif message['role'] == 'system' %}"
    "{{'<|system|>'+ '\n' + message['content'] + '\n'}}"
    "{% elif message['role'] == 'user' %}"
    "{{'<|user|>' + '\n' + message['content'] + '\n'}}"
    "{% elif message['role'] == 'assistant' %}"
    "{{'<|assistant|>' + '\n' + message['content'] + '<|endoftext|>' + ('' if loop.last else '\n')}}"
    "{% endif %}"
    "{% endfor %}"
)
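As a quick preview (not part of this commit), the template can be rendered directly with Jinja2; the messages below are illustrative:

# Sketch: preview the output of ibm_generic_tmpl.py (assumes jinja2 is installed).
from jinja2 import Template

from instructlab.training.chat_templates.ibm_generic_tmpl import CHAT_TEMPLATE

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Hello!"},
    {"role": "assistant", "content": "Hi there."},
]
print(Template(CHAT_TEMPLATE).render(messages=messages))
# <|system|>
# You are a helpful assistant.
# <|user|>
# Hello!
# <|assistant|>
# Hi there.<|endoftext|>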
22 changes: 22 additions & 0 deletions src/instructlab/training/chat_templates/mistral_tmpl.py
@@ -0,0 +1,22 @@
# First Party
from instructlab.training.tokenizer_utils import SpecialTokens

SPECIAL_TOKENS = SpecialTokens(
    bos="<s>",
    eos="</s>",
    user="[INST]",
    assistant="[/INST]",
)

CHAT_TEMPLATE = (
    "{{ '<s>' }}"
    "{% for message in messages %}"
    "{% if message['role'] == 'pretraining' %}"
    "{{ message['content'] + '</s>' }}"
    "{% elif message['role'] == 'user' %}"
    "{{ '[INST] ' + message['content'] + ' [/INST]' }}"
    "{% elif message['role'] == 'assistant' %}"
    "{{ message['content'] + '</s>'}}"
    "{% endif %}"
    "{% endfor %}"
)
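Rendered the same way (again a sketch, not part of the commit), a single-turn exchange collapses onto one line:

# Sketch: preview the output of mistral_tmpl.py with Jinja2.
from jinja2 import Template

from instructlab.training.chat_templates.mistral_tmpl import CHAT_TEMPLATE

messages = [
    {"role": "user", "content": "Hello!"},
    {"role": "assistant", "content": "Hi there."},
]
print(Template(CHAT_TEMPLATE).render(messages=messages))
# <s>[INST] Hello! [/INST]Hi there.</s>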
7 changes: 7 additions & 0 deletions src/instructlab/training/config.py
@@ -4,6 +4,7 @@

# Standard
from enum import Enum
import os

# Third Party
from pydantic import BaseModel, ConfigDict, Field
@@ -42,6 +43,7 @@ class DataProcessArgs(BaseModel):
    data_output_path: str
    max_seq_len: int  # defines the max sequence length of a sample
    model_path: str  # either a HF model name or path to HF model
    chat_tmpl_path: str

    # disable the protected namespace for the model_config field
    model_config = ConfigDict(protected_namespaces=())
@@ -100,6 +102,11 @@ class TrainingArgs(BaseModel):
    # Either the name of a HuggingFace model or a path to a model saved in HuggingFace format.
    model_path: str

    # Specify the chat template / special tokens for training (default is ibm-generic template/tokens)
    chat_tmpl_path: str = os.path.join(
        os.path.dirname(__file__), "chat_templates/ibm_generic_tmpl.py"
    )

    # this field specifies the filepath to the training dataset before processing
    data_path: str
    ckpt_output_dir: str
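Because retrieve_chat_template (added to utils.py below) simply imports the file at this path, any module that exports SPECIAL_TOKENS and CHAT_TEMPLATE can be supplied here. A minimal sketch of a hypothetical user-provided my_tmpl.py:

# Hypothetical custom template module; token strings are illustrative.
from instructlab.training.tokenizer_utils import SpecialTokens

SPECIAL_TOKENS = SpecialTokens(
    user="<|user|>",
    assistant="<|assistant|>",
    eos="<|end|>",
)

CHAT_TEMPLATE = (
    "{% for message in messages %}"
    "{{ '<|user|>' if message['role'] == 'user' else '<|assistant|>' }}"
    "{{ '\n' + message['content'] + '\n' }}"
    "{% endfor %}"
)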
44 changes: 23 additions & 21 deletions src/instructlab/training/data_process.py
@@ -2,6 +2,7 @@
from pathlib import Path
from typing import List
import logging
import os

# Third Party
from datasets import load_dataset
@@ -10,12 +11,8 @@

# First Party
from instructlab.training.config import DataProcessArgs
from instructlab.training.tokenizer_utils import (
    SPECIAL_TOKENS,
    get_sp_token,
    setup_tokenizer,
)
from instructlab.training.utils import log_rank_0, setup_logger
from instructlab.training.tokenizer_utils import get_sp_token, setup_tokenizer
from instructlab.training.utils import log_rank_0, retrieve_chat_template, setup_logger


def check_valid_sample(
@@ -37,18 +34,6 @@ def check_valid_sample(
    if not any(token in whole_sentence_tk for token in special_tokens):
        return True

    # first token should be system_token
    if whole_sentence_tk[0] != system_tk:
        print("\033[91mfirst token is not a system_token\033[0m")
        log_rank_0(tokenizer.decode(whole_sentence_tk), to_print=True)
        return False

    # check there's only one system_token
    if (np.array(whole_sentence_tk) == system_tk).sum() != 1:
        print("\033[91mthere are more than one system_token\033[0m")
        log_rank_0(tokenizer.decode(whole_sentence_tk), to_print=True)
        return False

    whole_sentence_tk = np.array(whole_sentence_tk)
    user_token_index = (whole_sentence_tk == user_tk).nonzero()[0]
    assistant_token_index = (whole_sentence_tk == assistant_tk).nonzero()[0]
@@ -121,7 +106,11 @@ def unmask_only_assistant_responses(
    whole_sentence = chosen_token["input_ids"][:sentence_legth].clone()

    # pre-training mode
    if system_tk not in whole_sentence:
    if not (
        system_tk in whole_sentence
        or user_token in whole_sentence
        or assist_token in whole_sentence
    ):
        return labels

    labels[:sentence_legth] = -100
@@ -204,11 +193,15 @@ def remove_pretrain_system_messages(example: dict):


def main(args: DataProcessArgs):
    tokenizer = setup_tokenizer(args.model_path)
    CHAT_TEMPLATE, SPECIAL_TOKENS = retrieve_chat_template(args.chat_tmpl_path)
    tokenizer = setup_tokenizer(args.model_path, SPECIAL_TOKENS, CHAT_TEMPLATE)

    eos_tk = get_sp_token(tokenizer, SPECIAL_TOKENS.eos)
    pad_tk = get_sp_token(tokenizer, SPECIAL_TOKENS.pad)
    system_tk = get_sp_token(tokenizer, SPECIAL_TOKENS.system)
    if SPECIAL_TOKENS.system:
        system_tk = get_sp_token(tokenizer, SPECIAL_TOKENS.system)
    else:
        system_tk = None
    user_tk = get_sp_token(tokenizer, SPECIAL_TOKENS.user)
    assistant_tk = get_sp_token(tokenizer, SPECIAL_TOKENS.assistant)
    log_rank_0(
    parser.add_argument(
        "--model_name_or_path", type=str, required=True, help="Model name or path"
    )
    parser.add_argument(
        "--chat-tmpl-path",
        type=str,
        default=os.path.join(
            os.path.dirname(__file__), "chat_templates/ibm_generic_tmpl.py"
        ),
        help="Path to desired chat template and special tokens, defaults to IBM generic.",
    )
    args = parser.parse_args()
    setup_logger(args.logging_level)
    data_process_args = DataProcessArgs(
        data_output_path=args.data_output_path,
        data_path=args.data_path,
        max_seq_len=args.max_seq_len,
        model_path=args.model_name_or_path,
        chat_tmpl_path=args.chat_tmpl_path,
    )
    main(data_process_args)

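The processor can also be driven programmatically; a sketch, with illustrative paths and model name:

# Sketch: run data processing with an explicit template module.
from instructlab.training.config import DataProcessArgs
from instructlab.training.data_process import main as process_data

process_data(
    DataProcessArgs(
        data_path="data/train.jsonl",
        data_output_path="data/processed",
        max_seq_len=4096,
        model_path="ibm-granite/granite-7b-base",
        chat_tmpl_path="src/instructlab/training/chat_templates/ibm_generic_tmpl.py",
    )
)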
13 changes: 12 additions & 1 deletion src/instructlab/training/main_ds.py
@@ -37,6 +37,7 @@
    patch_target_module,
    prepare_peft_model,
    prepare_universal_checkpoint_from_latest,
    retrieve_chat_template,
    save_hf_format_ds,
    save_model_ds_native,
    set_random_seed,
@@ -438,7 +439,8 @@ def main(args):
    print(f"\033[38;5;120m{yaml.dump(vars(args), sort_keys=False)}\033[0m")

    setup_logger(args.log_level)
    tokenizer = setup_tokenizer(args.model_name_or_path)
    CHAT_TEMPLATE, SPECIAL_TOKENS = retrieve_chat_template(args.chat_tmpl_path)
    tokenizer = setup_tokenizer(args.model_name_or_path, SPECIAL_TOKENS, CHAT_TEMPLATE)
    # device = torch.device("cuda", args.local_rank)

    #### distributed init #####
@@ -522,6 +524,7 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs):
            model_path=train_args.model_path,
            data_path=train_args.data_path,
            max_seq_len=train_args.max_seq_len,
            chat_tmpl_path=train_args.chat_tmpl_path,
        )
    )

@@ -546,6 +549,7 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs):
f"--log_level=INFO",
f"--max_batch_len={train_args.max_batch_len}",
f"--seed={train_args.random_seed}",
f"--chat-tmpl-path={train_args.chat_tmpl_path}",
]

if train_args.mock_data:
@@ -644,6 +648,13 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs):
help="Offload optimizer to CPU when using DeepSpeed. This configures it to use ZeRO stage 2.",
)
parser.add_argument("--NEFTune_alpha", type=float, default=None)
parser.add_argument(
"--chat-tmpl-path",
type=str,
default=os.path.join(
os.path.dirname(__file__), "chat_templates/ibm_generic_tmpl.py"
),
)
args = parser.parse_args()
set_random_seed(args.seed)
main(args)
48 changes: 18 additions & 30 deletions src/instructlab/training/tokenizer_utils.py
@@ -10,46 +10,34 @@

@dataclass
class SpecialTokens:
system: str = field(default="<|system|>")
system: str = field(default=None)
user: str = field(default="<|user|>")
assistant: str = field(default="<|assistant|>")
eos: str = field(default="<|endoftext|>")
pad: str = field(default="<|pad|>")
pad: str = field(default=None)
bos: str = field(default="<|begginingoftext|>")


SPECIAL_TOKENS = SpecialTokens()

CHAT_TEMPLATE = (
    "{% for message in messages %}"
    "{% if message['role'] == 'pretraining' %}"
    "{{'<|endoftext|>' + message['content'] + '<|endoftext|>'}}"
    "{% elif message['role'] == 'system' %}"
    "{{'<|system|>'+ '\n' + message['content'] + '\n'}}"
    "{% elif message['role'] == 'user' %}"
    "{{'<|user|>' + '\n' + message['content'] + '\n'}}"
    "{% elif message['role'] == 'assistant' %}"
    "{{'<|assistant|>' + '\n' + message['content'] + '<|endoftext|>' + ('' if loop.last else '\n')}}"
    "{% endif %}"
    "{% endfor %}"
)


def setup_tokenizer(
    model_name_or_path, SPECIAL_TOKENS=SPECIAL_TOKENS, CHAT_TEMPLATE=CHAT_TEMPLATE
):
def setup_tokenizer(model_name_or_path, SPECIAL_TOKENS, CHAT_TEMPLATE):
    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, fast_tokenizer=True)
    tokenizer.add_special_tokens(
        {"eos_token": SPECIAL_TOKENS.eos, "pad_token": SPECIAL_TOKENS.pad}
    )

    if not SPECIAL_TOKENS.pad:
        SPECIAL_TOKENS.pad = SPECIAL_TOKENS.eos
    tokenizer.add_special_tokens(
        {
            "additional_special_tokens": [
                SPECIAL_TOKENS.system,
                SPECIAL_TOKENS.user,
                SPECIAL_TOKENS.assistant,
            ]
            "bos_token": SPECIAL_TOKENS.bos,
            "eos_token": SPECIAL_TOKENS.eos,
            "pad_token": SPECIAL_TOKENS.pad,
        }
    )

    if SPECIAL_TOKENS.system:
        add_token_list = [SPECIAL_TOKENS.system]
    else:
        add_token_list = []
    add_token_list.extend([SPECIAL_TOKENS.user, SPECIAL_TOKENS.assistant])

    tokenizer.add_special_tokens({"additional_special_tokens": add_token_list})
    if getattr(tokenizer, "add_bos_token", False) or getattr(
        tokenizer, "add_eos_token", False
    ):
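A sketch of the new call shape: with the Mistral tokens above, which define no pad token, the tokenizer falls back to reusing eos for padding (model name illustrative):

# Sketch: setup_tokenizer now receives tokens and template explicitly.
from instructlab.training.chat_templates.mistral_tmpl import (
    CHAT_TEMPLATE,
    SPECIAL_TOKENS,
)
from instructlab.training.tokenizer_utils import setup_tokenizer

tokenizer = setup_tokenizer("mistralai/Mistral-7B-v0.1", SPECIAL_TOKENS, CHAT_TEMPLATE)
print(tokenizer.pad_token)  # "</s>" -- pad fell back to eos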
13 changes: 13 additions & 0 deletions src/instructlab/training/utils.py
@@ -24,6 +24,19 @@
import torch.nn.functional as F


def retrieve_chat_template(chat_tmpl_path):
    try:
        spec = importlib.util.spec_from_file_location("spcl_chat_tmpl", chat_tmpl_path)
        module = importlib.util.module_from_spec(spec)
        sys.modules["spcl_chat_tmpl"] = module
        spec.loader.exec_module(module)
        SPECIAL_TOKENS = module.SPECIAL_TOKENS
        CHAT_TEMPLATE = module.CHAT_TEMPLATE
    except:
sys.exit(f"Invalid chat template path: {chat_tmpl_path}")
return CHAT_TEMPLATE, SPECIAL_TOKENS


def add_noisy_embeddings(model, noise_alpha=None):
    if not noise_alpha:
        return model
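Loading a template at runtime is then a one-liner; note the return order (template first, tokens second) and that a bad path exits with the error above. Path illustrative:

# Sketch: load a chat template module from disk.
from instructlab.training.utils import retrieve_chat_template

CHAT_TEMPLATE, SPECIAL_TOKENS = retrieve_chat_template(
    "src/instructlab/training/chat_templates/ibm_generic_tmpl.py"
)
print(SPECIAL_TOKENS.eos)  # <|endoftext|>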
