Adding custom chat template / special token support #54

Merged · 9 commits · Jun 21, 2024 · changes shown from all commits
24 changes: 24 additions & 0 deletions src/instructlab/training/chat_templates/ibm_generic_tmpl.py
@@ -0,0 +1,24 @@
# First Party
from instructlab.training.tokenizer_utils import SpecialTokens

SPECIAL_TOKENS = SpecialTokens(
system="<|system|>",
user="<|user|>",
assistant="<|assistant|>",
eos="<|endoftext|>",
pad="<|pad|>",
)

CHAT_TEMPLATE = (
"{% for message in messages %}"
"{% if message['role'] == 'pretraining' %}"
"{{'<|endoftext|>' + message['content'] + '<|endoftext|>'}}"
"{% elif message['role'] == 'system' %}"
"{{'<|system|>'+ '\n' + message['content'] + '\n'}}"
"{% elif message['role'] == 'user' %}"
"{{'<|user|>' + '\n' + message['content'] + '\n'}}"
"{% elif message['role'] == 'assistant' %}"
"{{'<|assistant|>' + '\n' + message['content'] + '<|endoftext|>' + ('' if loop.last else '\n')}}"
"{% endif %}"
"{% endfor %}"
)
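
For reference, a minimal sketch (not part of this PR) of what this template produces, assuming jinja2 is installed and the module above is importable:

# Sketch: render the IBM generic template directly with jinja2.
from jinja2 import Template

from instructlab.training.chat_templates.ibm_generic_tmpl import CHAT_TEMPLATE

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "What is 2 + 2?"},
    {"role": "assistant", "content": "4"},
]
print(Template(CHAT_TEMPLATE).render(messages=messages))
# <|system|>
# You are a helpful assistant.
# <|user|>
# What is 2 + 2?
# <|assistant|>
# 4<|endoftext|>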
22 changes: 22 additions & 0 deletions src/instructlab/training/chat_templates/mistral_tmpl.py
@@ -0,0 +1,22 @@
# First Party
from instructlab.training.tokenizer_utils import SpecialTokens

SPECIAL_TOKENS = SpecialTokens(
bos="<s>",
eos="</s>",
user="[INST]",
assistant="[/INST]",
)

CHAT_TEMPLATE = (
"{{ '<s>' }}"
"{% for message in messages %}"
"{% if message['role'] == 'pretraining' %}"
"{{ message['content'] + '</s>' }}"
"{% elif message['role'] == 'user' %}"
"{{ '[INST] ' + message['content'] + ' [/INST]' }}"
"{% elif message['role'] == 'assistant' %}"
"{{ message['content'] + '</s>'}}"
"{% endif %}"
"{% endfor %}"
)
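
Rendered the same way (again an illustrative sketch, not PR content), a single exchange comes out as one flat [INST]-delimited string rather than the line-per-role layout above:

# Sketch: render the Mistral template with jinja2.
from jinja2 import Template

from instructlab.training.chat_templates.mistral_tmpl import CHAT_TEMPLATE

messages = [
    {"role": "user", "content": "What is 2 + 2?"},
    {"role": "assistant", "content": "4"},
]
print(Template(CHAT_TEMPLATE).render(messages=messages))
# <s>[INST] What is 2 + 2? [/INST]4</s>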
7 changes: 7 additions & 0 deletions src/instructlab/training/config.py
@@ -4,6 +4,7 @@

# Standard
from enum import Enum
import os

# Third Party
from pydantic import BaseModel, ConfigDict, Field
@@ -42,6 +43,7 @@ class DataProcessArgs(BaseModel):
data_output_path: str
max_seq_len: int # defines the max sequence length of a sample
model_path: str # either a HF model name or path to HF model
chat_tmpl_path: str

# disable the protected namespace for the model_config field
model_config = ConfigDict(protected_namespaces=())
@@ -100,6 +102,11 @@ class TrainingArgs(BaseModel):
# Either the name of a HuggingFace model or a path to a model saved in HuggingFace format.
model_path: str

# Specify the chat template / special tokens for training (default is ibm-generic template/tokens)
chat_tmpl_path: str = os.path.join(
os.path.dirname(__file__), "chat_templates/ibm_generic_tmpl.py"
)

# this field specifies the filepath to the training dataset before processing
data_path: str
ckpt_output_dir: str
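A caller selects a non-default template by pointing chat_tmpl_path at one of the modules above (or a custom one). A hedged sketch; the paths and model name are placeholders, and TrainingArgs has required fields beyond those visible in this hunk:

# Sketch: train with the Mistral template instead of the IBM generic default.
from instructlab.training.config import TrainingArgs

train_args = TrainingArgs(
    model_path="mistralai/Mistral-7B-v0.1",  # placeholder model
    data_path="data/train.jsonl",  # placeholder paths
    ckpt_output_dir="checkpoints/",
    chat_tmpl_path="src/instructlab/training/chat_templates/mistral_tmpl.py",
    # ...remaining required fields omitted for brevity...
)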
44 changes: 23 additions & 21 deletions src/instructlab/training/data_process.py
@@ -2,6 +2,7 @@
from pathlib import Path
from typing import List
import logging
import os

# Third Party
from datasets import load_dataset
@@ -10,12 +11,8 @@

# First Party
from instructlab.training.config import DataProcessArgs
from instructlab.training.tokenizer_utils import (
SPECIAL_TOKENS,
get_sp_token,
setup_tokenizer,
)
from instructlab.training.utils import log_rank_0, setup_logger
from instructlab.training.tokenizer_utils import get_sp_token, setup_tokenizer
from instructlab.training.utils import log_rank_0, retrieve_chat_template, setup_logger


def check_valid_sample(
@@ -37,18 +34,6 @@ def check_valid_sample(
if not any(token in whole_sentence_tk for token in special_tokens):
return True

# first token should be system_token
if whole_sentence_tk[0] != system_tk:
print("\033[91mfirst token is not a system_token\033[0m")
log_rank_0(tokenizer.decode(whole_sentence_tk), to_print=True)
return False

# check there's only one system_token
if (np.array(whole_sentence_tk) == system_tk).sum() != 1:
print("\033[91mthere are more than one system_token\033[0m")
log_rank_0(tokenizer.decode(whole_sentence_tk), to_print=True)
return False

whole_sentence_tk = np.array(whole_sentence_tk)
user_token_index = (whole_sentence_tk == user_tk).nonzero()[0]
assistant_token_index = (whole_sentence_tk == assistant_tk).nonzero()[0]
@@ -121,7 +106,11 @@ def unmask_only_assistant_responses(
whole_sentence = chosen_token["input_ids"][:sentence_legth].clone()

# pre-training mode
if system_tk not in whole_sentence:
if not (
system_tk in whole_sentence
or user_token in whole_sentence
or assist_token in whole_sentence
):
return labels

labels[:sentence_legth] = -100
@@ -204,11 +193,15 @@ def remove_pretrain_system_messages(example: dict):


def main(args: DataProcessArgs):
tokenizer = setup_tokenizer(args.model_path)
CHAT_TEMPLATE, SPECIAL_TOKENS = retrieve_chat_template(args.chat_tmpl_path)
tokenizer = setup_tokenizer(args.model_path, SPECIAL_TOKENS, CHAT_TEMPLATE)

eos_tk = get_sp_token(tokenizer, SPECIAL_TOKENS.eos)
pad_tk = get_sp_token(tokenizer, SPECIAL_TOKENS.pad)
system_tk = get_sp_token(tokenizer, SPECIAL_TOKENS.system)
if SPECIAL_TOKENS.system:
system_tk = get_sp_token(tokenizer, SPECIAL_TOKENS.system)
else:
system_tk = None
user_tk = get_sp_token(tokenizer, SPECIAL_TOKENS.user)
assistant_tk = get_sp_token(tokenizer, SPECIAL_TOKENS.assistant)
log_rank_0(
@@ -309,13 +302,22 @@ def main(args: DataProcessArgs):
parser.add_argument(
"--model_name_or_path", type=str, required=True, help="Model name or path"
)
parser.add_argument(
"--chat-tmpl-path",
type=str,
default=os.path.join(
os.path.dirname(__file__), "chat_templates/ibm_generic_tmpl.py"
),
help="Path to desired chat template and special tokens, defaults to IBM generic.",
)
args = parser.parse_args()
setup_logger(args.logging_level)
data_process_args = DataProcessArgs(
data_output_path=args.data_output_path,
data_path=args.data_path,
max_seq_len=args.max_seq_len,
model_path=args.model_name_or_path,
chat_tmpl_path=args.chat_tmpl_path,
)
main(data_process_args)

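Data processing can now be driven with an explicit template either via the new --chat-tmpl-path flag or programmatically. A sketch of the programmatic route, with placeholder paths and model name:

# Sketch: run data processing against a chosen chat template module.
from instructlab.training.config import DataProcessArgs
from instructlab.training.data_process import main

main(
    DataProcessArgs(
        data_path="data/train.jsonl",
        data_output_path="data/processed",
        max_seq_len=4096,
        model_path="ibm-granite/granite-7b-base",  # placeholder model
        chat_tmpl_path="src/instructlab/training/chat_templates/mistral_tmpl.py",
    )
)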
13 changes: 12 additions & 1 deletion src/instructlab/training/main_ds.py
@@ -37,6 +37,7 @@
patch_target_module,
prepare_peft_model,
prepare_universal_checkpoint_from_latest,
retrieve_chat_template,
save_hf_format_ds,
save_model_ds_native,
set_random_seed,
@@ -438,7 +439,8 @@ def main(args):
print(f"\033[38;5;120m{yaml.dump(vars(args), sort_keys=False)}\033[0m")

setup_logger(args.log_level)
tokenizer = setup_tokenizer(args.model_name_or_path)
CHAT_TEMPLATE, SPECIAL_TOKENS = retrieve_chat_template(args.chat_tmpl_path)
tokenizer = setup_tokenizer(args.model_name_or_path, SPECIAL_TOKENS, CHAT_TEMPLATE)
# device = torch.device("cuda", args.local_rank)

#### distributed init #####
@@ -522,6 +524,7 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs):
model_path=train_args.model_path,
data_path=train_args.data_path,
max_seq_len=train_args.max_seq_len,
chat_tmpl_path=train_args.chat_tmpl_path,
)
)

@@ -546,6 +549,7 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs):
f"--log_level=INFO",
f"--max_batch_len={train_args.max_batch_len}",
f"--seed={train_args.random_seed}",
f"--chat-tmpl-path={train_args.chat_tmpl_path}",
]

if train_args.mock_data:
@@ -644,6 +648,13 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs):
help="Offload optimizer to CPU when using DeepSpeed. This configures it to use ZeRO stage 2.",
)
parser.add_argument("--NEFTune_alpha", type=float, default=None)
parser.add_argument(
"--chat-tmpl-path",
type=str,
default=os.path.join(
os.path.dirname(__file__), "chat_templates/ibm_generic_tmpl.py"
),
)
args = parser.parse_args()
set_random_seed(args.seed)
main(args)
48 changes: 18 additions & 30 deletions src/instructlab/training/tokenizer_utils.py
@@ -10,46 +10,34 @@

@dataclass
class SpecialTokens:
system: str = field(default="<|system|>")
system: str = field(default=None)
user: str = field(default="<|user|>")
assistant: str = field(default="<|assistant|>")
eos: str = field(default="<|endoftext|>")
pad: str = field(default="<|pad|>")
pad: str = field(default=None)
bos: str = field(default="<|begginingoftext|>")


SPECIAL_TOKENS = SpecialTokens()

CHAT_TEMPLATE = (
"{% for message in messages %}"
"{% if message['role'] == 'pretraining' %}"
"{{'<|endoftext|>' + message['content'] + '<|endoftext|>'}}"
"{% elif message['role'] == 'system' %}"
"{{'<|system|>'+ '\n' + message['content'] + '\n'}}"
"{% elif message['role'] == 'user' %}"
"{{'<|user|>' + '\n' + message['content'] + '\n'}}"
"{% elif message['role'] == 'assistant' %}"
"{{'<|assistant|>' + '\n' + message['content'] + '<|endoftext|>' + ('' if loop.last else '\n')}}"
"{% endif %}"
"{% endfor %}"
)


def setup_tokenizer(
model_name_or_path, SPECIAL_TOKENS=SPECIAL_TOKENS, CHAT_TEMPLATE=CHAT_TEMPLATE
):
def setup_tokenizer(model_name_or_path, SPECIAL_TOKENS, CHAT_TEMPLATE):
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, fast_tokenizer=True)
tokenizer.add_special_tokens(
{"eos_token": SPECIAL_TOKENS.eos, "pad_token": SPECIAL_TOKENS.pad}
)

if not SPECIAL_TOKENS.pad:
SPECIAL_TOKENS.pad = SPECIAL_TOKENS.eos
tokenizer.add_special_tokens(
{
"additional_special_tokens": [
SPECIAL_TOKENS.system,
SPECIAL_TOKENS.user,
SPECIAL_TOKENS.assistant,
]
"bos_token": SPECIAL_TOKENS.bos,
"eos_token": SPECIAL_TOKENS.eos,
"pad_token": SPECIAL_TOKENS.pad,
}
)

if SPECIAL_TOKENS.system:
add_token_list = [SPECIAL_TOKENS.system]
else:
add_token_list = []
add_token_list.extend([SPECIAL_TOKENS.user, SPECIAL_TOKENS.assistant])

tokenizer.add_special_tokens({"additional_special_tokens": add_token_list})
if getattr(tokenizer, "add_bos_token", False) or getattr(
tokenizer, "add_eos_token", False
):
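The net effect of this refactor: setup_tokenizer no longer carries module-level defaults, pad falls back to eos when a template defines no pad token, and the system token is only registered when present. Callers load a template module first and pass both pieces in, roughly like this sketch (the apply_chat_template call assumes setup_tokenizer attaches CHAT_TEMPLATE to the tokenizer, which the visible hunk does not show):

# Sketch: load a template module, then build a tokenizer from it.
from instructlab.training.tokenizer_utils import setup_tokenizer
from instructlab.training.utils import retrieve_chat_template

chat_template, special_tokens = retrieve_chat_template(
    "src/instructlab/training/chat_templates/mistral_tmpl.py"
)
tokenizer = setup_tokenizer(
    "mistralai/Mistral-7B-v0.1", special_tokens, chat_template  # placeholder model
)
print(tokenizer.apply_chat_template([{"role": "user", "content": "Hi"}], tokenize=False))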
13 changes: 13 additions & 0 deletions src/instructlab/training/utils.py
@@ -24,6 +24,19 @@
import torch.nn.functional as F


def retrieve_chat_template(chat_tmpl_path):
try:
spec = importlib.util.spec_from_file_location("spcl_chat_tmpl", chat_tmpl_path)
module = importlib.util.module_from_spec(spec)
sys.modules["spcl_chat_tmpl"] = module
spec.loader.exec_module(module)
SPECIAL_TOKENS = module.SPECIAL_TOKENS
CHAT_TEMPLATE = module.CHAT_TEMPLATE
    except Exception:
        sys.exit(f"Invalid chat template path: {chat_tmpl_path}")
return CHAT_TEMPLATE, SPECIAL_TOKENS


def add_noisy_embeddings(model, noise_alpha=None):
if not noise_alpha:
return model
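Because retrieve_chat_template simply imports a Python file and reads two module-level attributes, a custom template is any module exposing SPECIAL_TOKENS and CHAT_TEMPLATE. An illustrative user-supplied module (tokens and format invented for the example) that could be passed via --chat-tmpl-path:

# my_chat_tmpl.py — illustrative custom template module; tokens and
# format are made up for this sketch.
from instructlab.training.tokenizer_utils import SpecialTokens

SPECIAL_TOKENS = SpecialTokens(
    user="<human>",
    assistant="<bot>",
    eos="<|end|>",
)

CHAT_TEMPLATE = (
    "{% for message in messages %}"
    "{% if message['role'] == 'user' %}"
    "{{ '<human> ' + message['content'] + '\n' }}"
    "{% elif message['role'] == 'assistant' %}"
    "{{ '<bot> ' + message['content'] + '<|end|>' + '\n' }}"
    "{% endif %}"
    "{% endfor %}"
)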