diff --git a/README.md b/README.md index 9851ced55..0d4e2939f 100644 --- a/README.md +++ b/README.md @@ -679,7 +679,9 @@ We support profiling with Nsight Systems, the PyTorch Profiler, and PyTorch Memo ## Nsight Systems Profiling -To use the Nsight Systems profiling, set config options `profile`, `profile_step_start`, and `profile_step_stop`. Launch training with: +To use the Nsight Systems profiling, set config options `profile`, `profile_step_start`, and `profile_step_stop` (see [here](https://github.com/EleutherAI/gpt-neox/blob/main/configs/neox_arguments.md) for argument usage, and [here](https://github.com/EleutherAI/gpt-neox/blob/main/configs/prof.yml) for a sample config). + +To populate nsys metrics, launch training with: ``` nsys profile -s none -t nvtx,cuda -o --force-overwrite true \ @@ -689,22 +691,22 @@ $TRAIN_PATH/train.py --conf_dir configs The generated output file can then by viewed with the Nsight Systems GUI: -![Alt text](images/nsight_profiling.png) +![nsight-prof](images/nsight_profiling.png) ## PyTorch Profiling -To use the built-in PyTorch profiler, set config options `profile`, `profile_step_start`, and `profile_step_stop`. +To use the built-in PyTorch profiler, set config options `profile`, `profile_step_start`, and `profile_step_stop` (see [here](https://github.com/EleutherAI/gpt-neox/blob/main/configs/neox_arguments.md) for argument usage, and [here](https://github.com/EleutherAI/gpt-neox/blob/main/configs/prof.yml) for a sample config). The PyTorch profiler will save traces to your `tensorboard` log directory. You can view these traces within TensorBoard by following the steps [here](https://pytorch.org/tutorials/intermediate/tensorboard_profiler_tutorial.html). -![Alt text](images/pytorch_profiling.png) +![torch-prof](images/pytorch_profiling.png) ## PyTorch Memory Profiling -To use PyTorch Memory Profiling, set config options `memory_profiling` and `memory_profiling_path`. +To use PyTorch Memory Profiling, set config options `memory_profiling` and `memory_profiling_path` (see [here](https://github.com/EleutherAI/gpt-neox/blob/main/configs/neox_arguments.md) for argument usage, and [here](https://github.com/EleutherAI/gpt-neox/blob/main/configs/prof.yml) for a sample config). -![Alt text](images/memory_profiling.png) +![mem-prof](images/memory_profiling.png) View the generated profile with the [memory_viz.py](https://github.com/pytorch/pytorch/blob/main/torch/cuda/_memory_viz.py) script. Run with: diff --git a/configs/README.md b/configs/README.md index 71a09ebea..ac20ed89b 100644 --- a/configs/README.md +++ b/configs/README.md @@ -124,6 +124,8 @@ These can be set to any integer between `0` and `num_gpus`, and `num_gpus` must # this should provide some speedup but takes a while to build, set to true if desired "scaled_upper_triang_masked_softmax_fusion": false, "train_iters": 320000, + # alternatively, use train_epochs to automatically determine the number of training iterations + #"train_epochs": 1, ``` An example of some basic settings used to configure your model's architecture and number of training steps. diff --git a/configs/neox_arguments.md b/configs/neox_arguments.md index efceee6e1..84f5ed97b 100644 --- a/configs/neox_arguments.md +++ b/configs/neox_arguments.md @@ -14,14 +14,19 @@ LR Scheduler Arguments Learning rate decay function. Choose from 'constant', 'linear', 'cosine', 'exponential'. 
- - **lr_decay_iters**: int Default = None - Number of iterations to decay learning rate over, If None defaults to --train-iters + Number of iterations to decay learning rate over. If None, defaults to + --train-iters or the equivalent inferred value from train_epochs. + +- **lr_decay_fraction**: float + Default = None + Effective fraction of training over which to decay lr. Overrides lr_decay_iters. + Useful when specifying train_epochs. - **min_lr**: float @@ -111,7 +116,7 @@ Logging Arguments - **git_hash**: str - Default = 217b4c5 + Default = 62c9738a current git hash of repository @@ -133,6 +138,54 @@ Logging Arguments +- **use_comet**: bool + + Default = None + + Flag indicating if comet is to be used. + + + +- **comet_workspace**: Optional + + Default = None + + Comet workspace name, if not configured Comet Experiments will be created in the user configured default workspace. + + + +- **comet_project**: Optional + + Default = None + + Comet project name, if not configured Comet Experiments will be created in the Uncategorized Experiments project. + + + +- **comet_experiment_name**: Optional + + Default = None + + Custom name for the Comet experiment. If not provided, a random name is used. + + + +- **comet_tags**: Optional + + Default = None + + List of tags to attach to the created Comet Experiment. + + + +- **comet_others**: Optional + + Default = None + + Custom metadata to attach to the created Comet Experiment. + + + - **log_interval**: int Default = 100 @@ -282,9 +335,23 @@ Model Arguments Default = None - Transformer intermediate size. Currently only used for "mlp_type": "llama". + Transformer intermediate size. Default = 4h - If not passed, will be set to a reasonable default. + + +- **mlp_multiple_of**: int + + Default = 1 + + force mlp size to be a multiple of this value + + + +- **expansion_factor**: float + + Default = None + + Transformer intermediate size. Default = 4 @@ -352,6 +419,14 @@ Model Arguments +- **rmsnorm_fusion**: bool + + Default = False + + Use fused RMS norm kernel (if `norm` is `rmsnorm`). + + + - **use_qk_layernorm**: bool Default = False @@ -498,11 +573,19 @@ Model Arguments -- **activation**: typing.Literal['gelu', 'geglu', 'relu', 'softsign', 'swish', 'mish', 'silu'] +- **activation**: typing.Literal['gelu', 'geglu', 'relu', 'softsign', 'swish', 'mish', 'silu', 'reglu', 'swiglu', 'bilinear', 'glu'] Default = gelu - Activation function to use - choose from ["gelu", "geglu", "relu", "softsign", "swish", "mish", "silu"] + Activation function to use - choose from ["gelu", "geglu", "relu", "softsign", "swish", "mish", "silu", "reglu", "swiglu", "bilinear", "glu"] + + + +- **use_flashattn_swiglu**: bool + + Default = False + + Use flash attention's version of swiglu @@ -683,13 +766,11 @@ Model Arguments -- **mlp_type**: str +- **use_bias_in_mlp**: bool - Default = regular + Default = True - Types: - regular: Megatron implementation - llama: LLaMA MLP (SiLU-gated MLP) + If false, mlps will not have bias terms @@ -764,6 +845,29 @@ Model Arguments +- **dim_att**: int + + Default = None + + Total dimension of the attention mechanism for RWKV. If not set, defaults to hidden_size. + + + +- **head_size**: int + + Default = None + + Size of each attention head for RWKV. Calculated as dim_att // num_attention_heads. + + + +- **ffn_dim**: int + + Default = None + + Dimension of the feed-forward network for RWKV. If not set, calculated based on hidden_size and expansion_factor. 
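The RWKV sizing arguments above interact with the generic width options (`hidden_size`, `num_attention_heads`, `expansion_factor`). As a minimal sketch of how the documented defaults resolve, assuming the documented fallbacks; the exact formula and rounding NeoX uses for `ffn_dim` is not spelled out here, so that part is an assumption rather than the actual implementation:

```python
def resolve_rwkv_dims(hidden_size, num_attention_heads,
                      dim_att=None, ffn_dim=None, expansion_factor=None):
    """Sketch of the documented defaults above; not the NeoX source."""
    # dim_att falls back to hidden_size when unset (per the docs above).
    dim_att = dim_att if dim_att is not None else hidden_size
    # head_size is documented as dim_att // num_attention_heads.
    head_size = dim_att // num_attention_heads
    # ffn_dim is "calculated based on hidden_size and expansion_factor";
    # the multiplier and rounding used here are assumptions.
    if ffn_dim is None:
        factor = expansion_factor if expansion_factor is not None else 4
        ffn_dim = int(hidden_size * factor)
    return dim_att, head_size, ffn_dim
```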
+ + ## NeoXArgsOptimizer Optimizer Arguments @@ -1094,7 +1198,15 @@ Text Generation arguments Default = None How to generate text/sample the model. - Options: `unconditional`, `input-file`, `interactive` + Options: `unconditional`, `input-file`, `interactive`, `precompute` + + + +- **precompute_model_name**: str + + Default = None + + Model name to use for saving precomputed logprobs @@ -1381,11 +1493,19 @@ Training Arguments -- **label_data_paths**: list +- **train_label_data_paths**: list + + Default = None + + List of paths to train label datasets (not shifted by 1 yet!). + + + +- **train_reward_data_paths**: list Default = None - List of paths to label datasets (not shifted by 1 yet!). + List of paths to train reward datasets @@ -1397,6 +1517,22 @@ Training Arguments +- **test_label_data_paths**: list + + Default = None + + List of paths to test label datasets (not shifted by 1 yet!). + + + +- **test_reward_data_paths**: list + + Default = None + + List of paths to test reward datasets + + + - **valid_data_paths**: list Default = None @@ -1405,6 +1541,118 @@ Training Arguments +- **valid_label_data_paths**: list + + Default = None + + List of paths to validation label datasets (not shifted by 1 yet!). + + + +- **valid_reward_data_paths**: list + + Default = None + + List of paths to validation reward datasets + + + +- **pos_train_data_paths**: list + + Default = None + + + + + +- **neg_train_data_paths**: list + + Default = None + + List of paths to positive and negative training datasets. + + + +- **pos_train_label_data_paths**: list + + Default = None + + + + + +- **neg_train_label_data_paths**: list + + Default = None + + List of paths to positive and negative training label datasets (not shifted by 1 yet!). + + + +- **pos_valid_data_paths**: list + + Default = None + + + + + +- **neg_valid_data_paths**: list + + Default = None + + List of paths to positive and negative validation datasets. + + + +- **pos_valid_label_data_paths**: list + + Default = None + + + + + +- **neg_valid_label_data_paths**: list + + Default = None + + List of paths to positive and negative validation label datasets (not shifted by 1 yet!). + + + +- **pos_test_data_paths**: list + + Default = None + + + + + +- **neg_test_data_paths**: list + + Default = None + + List of paths to positive and negative test datasets. + + + +- **pos_test_label_data_paths**: list + + Default = None + + + + + +- **neg_test_label_data_paths**: list + + Default = None + + List of paths to positive and negative test label datasets (not shifted by 1 yet!). + + + - **train_data_weights**: list Default = None @@ -1472,6 +1720,99 @@ Training Arguments +- **pack_impl**: typing.Literal['packed', 'pack_until_overflow', 'unpacked'] + + Default = packed + + Packing implementation, can be one of "packed", "pack_until_overflow", or "unpacked". + + warning: pack_until_overflow is very naive and will likely have issues with pretraining scale datasets + + + +- **dataset_impl**: typing.Literal['gpt2', 'pairwise'] + + Default = gpt2 + + Dataset implementation, can be one of "gpt2" or "pairwise" + + + +- **train_impl**: typing.Literal['normal', 'dpo', 'rm', 'kto'] + + Default = normal + + Training implementation, can be one of "normal", "dpo", "kto", or "rm" + + + +- **dpo_fp32**: bool + + Default = True + + Whether to cast logits to fp32 for DPO loss calculation. + + + +- **dpo_reference_free**: bool + + Default = False + + Whether to use reference-free DPO. 
+ + + +- **dpo_beta**: float + + Default = 0.1 + + Beta value for DPO + + + +- **kto_fp32**: bool + + Default = True + + Whether to cast logits to fp32 for KTO loss calculation. + + + +- **kto_desirable_weight**: float + + Default = 1.0 + + Weight for desirable loss in KTO. Might help if you have unbalanced desirable and undesirable classes. + + + +- **kto_undesirable_weight**: float + + Default = 1.0 + + Weight for undesirable loss in KTO. Might help if you have unbalanced desirable and undesirable classes. + + + +- **kto_beta**: float + + Default = 0.1 + + Beta value for KTO + + + +- **allow_chopped**: bool + + Default = True + + WARNING: if your packing impl is packed, this is ignored. + + Allow chopped samples in the dataset. + (e.g if your sequence length is 1024 and you have a sample of length 1026, it will be chopped to 1024) + + + - **mmap_warmup**: bool Default = False @@ -1618,6 +1959,15 @@ Training Arguments +- **train_epochs**: int + + Default = None + + Number of epochs to run for training. Do not specify both train_epochs and train_iters. + Not currently compatible with data reweighing, pairwise datasets, and packing other than 'packed' + + + - **eval_iters**: int Default = 100 diff --git a/configs/prof.yml b/configs/prof.yml new file mode 100644 index 000000000..c2f2ee118 --- /dev/null +++ b/configs/prof.yml @@ -0,0 +1,17 @@ +# Sample profiling config +{ + # Turns on nsys and pytorch profiling + "profile": true, + + # pytorch profiler options + "profile_step_start": 10, + "profile_step_stop": 12, + + # pytorch memory profiler options + "memory_profiling": true, + "memory_profiling_path": tensorboard, + + + # All trace files (pytorch, nsys, tensorboard, etc) will be written here + "tensorboard_dir": "tensorboard", +} diff --git a/megatron/data/data_utils.py b/megatron/data/data_utils.py index 335bda061..c08b60151 100644 --- a/megatron/data/data_utils.py +++ b/megatron/data/data_utils.py @@ -16,7 +16,7 @@ import torch import numpy as np from typing import List, Tuple -from itertools import zip_longest +from itertools import zip_longest, cycle from functools import partial from megatron import mpu, print_rank_0 @@ -62,6 +62,7 @@ def build_the_dataset( dataset_impl, allow_chopped, num_samples, + num_epochs, seq_length, seed, skip_warmup, @@ -141,6 +142,7 @@ def build_the_dataset( documents, indexed_dataset, num_samples, + num_epochs, seq_length, seed, pack_impl=pack_impl, @@ -179,6 +181,7 @@ def build_train_valid_test_datasets( allow_chopped, splits_string, train_valid_test_num_samples, + train_valid_test_epochs, seq_length, seed, skip_warmup, @@ -219,6 +222,7 @@ def build_dataset(index, name): documents, indexed_dataset, train_valid_test_num_samples[index], + train_valid_test_epochs[index], seq_length, seed, pack_impl=pack_impl, @@ -268,12 +272,15 @@ def get_normalized_weights_and_num_samples( weight_sum = sum(weights) assert weight_sum > 0.0 weights = [weight / weight_sum for weight in weights] - # Add 0.5% (the 1.005 factor) so in case the blending dataset does - # not uniformly distribute the number of samples, we still have - # samples left to feed to the network. - weighted_num_samples = [] - for weight in weights: - weighted_num_samples.append(int(math.ceil(num_samples * weight * 1.005))) + if num_samples is not None: + # Add 0.5% (the 1.005 factor) so in case the blending dataset does + # not uniformly distribute the number of samples, we still have + # samples left to feed to the network. 
+ weighted_num_samples = [] + for weight in weights: + weighted_num_samples.append(int(math.ceil(num_samples * weight * 1.005))) + else: + weighted_num_samples = [None for _ in weights] return weights, weighted_num_samples @@ -282,9 +289,9 @@ def build_weighted_datasets( train_num_samples, valid_num_samples, test_num_samples, - train_weights, - valid_weights, - test_weights, + train_epochs, + valid_epochs, + test_epochs, build_index_mappings=True, ): # build individual datasets @@ -367,6 +374,7 @@ def build_weighted_datasets( pack_impl=neox_args.pack_impl, allow_chopped=neox_args.allow_chopped, num_samples=train_num_samples[i], + num_epochs=train_epochs, seq_length=neox_args.seq_length, seed=neox_args.seed, skip_warmup=(not neox_args.mmap_warmup), @@ -391,6 +399,7 @@ def build_weighted_datasets( pack_impl=neox_args.pack_impl, allow_chopped=neox_args.allow_chopped, num_samples=valid_num_samples[i], + num_epochs=valid_epochs, seq_length=neox_args.seq_length, seed=neox_args.seed, skip_warmup=(not neox_args.mmap_warmup), @@ -415,6 +424,7 @@ def build_weighted_datasets( pack_impl=neox_args.pack_impl, allow_chopped=neox_args.allow_chopped, num_samples=test_num_samples[i], + num_epochs=test_epochs, seq_length=neox_args.seq_length, seed=neox_args.seed, skip_warmup=(not neox_args.mmap_warmup), @@ -469,9 +479,44 @@ def weights_by_num_docs(l: list, alpha=0.3): return weights -def build_train_valid_test_data_iterators(neox_args): +def validate_train_epochs(neox_args): + """Check for unsupported neox_args when using train_epochs instead of train_iters""" + if neox_args.train_epochs is None: + return + + if neox_args.train_epochs and neox_args.train_iters: + raise ValueError( + "Cannot specify both train epochs and train iters simultaneously" + ) + + if neox_args.pack_impl != "packed": + raise ValueError( + "Packing implementations other than 'packed' are currently unsupported with train_epochs" + ) + + if neox_args.weight_by_num_documents: + raise ValueError( + "Weighting by number of documents is currently unsupported with train_epochs" + ) + + if neox_args.train_data_weights and ( + not all(weight == 1.0 for weight in neox_args.train_data_weights) + ): + raise ValueError( + "train_data_weights != None is currently unsupported with train_epochs" + ) + + if neox_args.dataset_impl != "gpt2": + raise ValueError( + "non gpt2 datasets are not currently unsupported with train_epochs" + ) + + +def build_train_valid_test_data_loaders(neox_args): """XXX""" + validate_train_epochs(neox_args) + (train_dataloader, valid_dataloader, test_dataloader) = (None, None, None) print_rank_0("> building train, validation, and test datasets ...") @@ -489,14 +534,21 @@ def build_train_valid_test_data_iterators(neox_args): # Data loader only on rank 0 of each model parallel group. if mpu.get_model_parallel_rank() == 0 and pipe_load: # Number of train/valid/test samples. 
- train_iters = neox_args.train_iters - eval_iters = (train_iters // neox_args.eval_interval + 1) * neox_args.eval_iters - test_iters = neox_args.eval_iters - train_val_test_num_samples = [ - train_iters * neox_args.train_batch_size, - eval_iters * neox_args.train_batch_size, - test_iters * neox_args.train_batch_size, - ] + if neox_args.train_iters is not None: + train_iters = neox_args.train_iters + eval_iters = ( + train_iters // neox_args.eval_interval + 1 + ) * neox_args.eval_iters + test_iters = neox_args.eval_iters + train_val_test_num_samples = [ + train_iters * neox_args.train_batch_size, + eval_iters * neox_args.train_batch_size, + test_iters * neox_args.train_batch_size, + ] + train_val_test_epochs = [None, None, None] + elif neox_args.train_epochs is not None: + train_val_test_num_samples = [None, None, None] + train_val_test_epochs = [1, 1, 1] if (neox_args.train_data_paths) or (neox_args.pos_train_data_paths): # when individual train / valid / test data paths are provided @@ -517,9 +569,9 @@ def build_train_valid_test_data_iterators(neox_args): train_num_samples, valid_num_samples, test_num_samples, - train_weights, - valid_weights, - test_weights, + train_val_test_epochs[0], + train_val_test_epochs[1], + train_val_test_epochs[2], build_index_mappings=not neox_args.weight_by_num_documents, ) @@ -565,9 +617,9 @@ def build_train_valid_test_data_iterators(neox_args): train_num_samples, valid_num_samples, test_num_samples, - train_weights, - valid_weights, - test_weights, + train_val_test_epochs[0], + train_val_test_epochs[1], + train_val_test_epochs[2], ) if train_datasets: @@ -585,6 +637,7 @@ def build_train_valid_test_data_iterators(neox_args): data_impl=neox_args.data_impl, splits_string=neox_args.split, train_valid_test_num_samples=train_val_test_num_samples, + train_valid_test_epochs=train_val_test_epochs, seq_length=neox_args.seq_length, seed=neox_args.seed, skip_warmup=(not neox_args.mmap_warmup), @@ -598,9 +651,15 @@ def build_train_valid_test_data_iterators(neox_args): test_dataloader = make_data_loader(test_ds, neox_args=neox_args) # Flags to know if we need to do training/validation/testing. - do_train = train_dataloader is not None and neox_args.train_iters > 0 - do_valid = valid_dataloader is not None and neox_args.eval_iters > 0 - do_test = test_dataloader is not None and neox_args.eval_iters > 0 + if neox_args.train_epochs: + do_train = train_dataloader is not None + do_valid = valid_dataloader is not None + do_test = test_dataloader is not None + else: + do_train = train_dataloader is not None and neox_args.train_iters > 0 + do_valid = valid_dataloader is not None and neox_args.eval_iters > 0 + do_test = test_dataloader is not None and neox_args.eval_iters > 0 + # Need to broadcast num_tokens and num_type_tokens. flags = torch.cuda.LongTensor([int(do_train), int(do_valid), int(do_test)]) else: @@ -620,6 +679,19 @@ def build_train_valid_test_data_iterators(neox_args): neox_args.do_train = flags[0].item() neox_args.do_valid = flags[1].item() neox_args.do_test = flags[2].item() + data_loaders = { + "train": train_dataloader, + "valid": valid_dataloader, + "test": test_dataloader, + } + return data_loaders + + +def shift_and_wrap_data_loaders(neox_args, data_loaders, loop=True): + """Shift start iteration and wrap data_loaders in iterators""" + train_dataloader = data_loaders["train"] + valid_dataloader = data_loaders["valid"] + test_dataloader = data_loaders["test"] # Shift the start iterations. 
if train_dataloader is not None: @@ -645,19 +717,34 @@ def build_train_valid_test_data_iterators(neox_args): ) ) + def loop_iterator(data_loader): + while True: + for x in data_loader: + yield x + data_loader.start_iter = 0 + # Build iterators. if train_dataloader is not None: - train_data_iterator = iter(train_dataloader) + if loop: + train_data_iterator = cycle(train_dataloader) + else: + train_data_iterator = iter(train_dataloader) else: train_data_iterator = None if valid_dataloader is not None: - valid_data_iterator = iter(valid_dataloader) + if loop: + valid_data_iterator = cycle(valid_dataloader) + else: + valid_data_iterator = iter(valid_dataloader) else: valid_data_iterator = None if test_dataloader is not None: - test_data_iterator = iter(test_dataloader) + if loop: + test_data_iterator = cycle(test_dataloader) + else: + test_data_iterator = iter(test_dataloader) else: test_data_iterator = None diff --git a/megatron/data/gpt2_dataset.py b/megatron/data/gpt2_dataset.py index c4729cc3e..73c21bebd 100644 --- a/megatron/data/gpt2_dataset.py +++ b/megatron/data/gpt2_dataset.py @@ -34,6 +34,7 @@ def __init__( documents, indexed_dataset, num_samples, + num_epochs, seq_length, seed, pack_impl="packed", @@ -70,6 +71,7 @@ def __init__( self.indexed_dataset.sizes, self.label_dataset, num_samples, + num_epochs, seq_length, seed, self.pack_impl, @@ -203,6 +205,7 @@ def _build_index_mappings( sizes, label_dataset, num_samples, + num_epochs, seq_length, seed, packing_impl, @@ -217,7 +220,8 @@ def _build_index_mappings( """ # Number of tokens in each epoch and number of required epochs. tokens_per_epoch = _num_tokens(documents, sizes) - num_epochs = _num_epochs(tokens_per_epoch, seq_length, num_samples) + if not num_epochs: + num_epochs = _num_epochs(tokens_per_epoch, seq_length, num_samples) # rng state np_rng = np.random.RandomState(seed=seed) diff --git a/megatron/data/samplers.py b/megatron/data/samplers.py index a9428e41c..5206636d8 100644 --- a/megatron/data/samplers.py +++ b/megatron/data/samplers.py @@ -100,7 +100,11 @@ class DistributedBatchSampler(data.sampler.BatchSampler): specifying True will result in the following samples for each gpu: GPU0: [0,2,4,6] GPU1: [1,3,5,7] specifying False will result in the following samples: - GPU0: [0,1,2,3] GPU1: [4,5,6,7]""" + GPU0: [0,1,2,3] GPU1: [4,5,6,7] + + The `infinite_loop` parameter allows the sampler to yield batches indefinitely, + restarting from the beginning of the dataset when all samples have been iterated over. + """ def __init__( self, diff --git a/megatron/logging.py b/megatron/logging.py index 05945fdda..af8a41fe5 100644 --- a/megatron/logging.py +++ b/megatron/logging.py @@ -23,6 +23,7 @@ from megatron import mpu, print_rank_0 from megatron.utils import report_memory +import math class Tee: @@ -106,6 +107,38 @@ def get_flops(neox_args, iter_time_s) -> float: + 18 * hidden_size * hidden_size * num_layers / num_heads ) ) + elif "mamba" in neox_args.attention_config: + # from https://github.com/Zyphra/zcookbook/blob/main/calc/calc_mamba_flops.py + if neox_args.expansion_factor: + d_inner = neox_args.hidden_size * neox_args.expansion_factor + elif neox_args.intermediate_size: + d_inner = neox_args.intermediate_size + else: + d_inner = neox_args.hidden_size * 2 # default expansion factor + d_state = 16 # TODO make d_state an arg. Currently hardcoded in neox mamba definition and here + conv_dimension = 4 # TODO make conv_dimension an arg. 
Currently hardcoded in neox mamba definition and here + dt_rank = math.ceil(neox_args.hidden_size / 16) + ssm_flops = ( + ckpt_activations_factor + * d_inner + * seq_len + * batch_size + * (11 * d_state + 4 * dt_rank + 1) + ) + mamba_projectors_flops = ( + ckpt_activations_factor * seq_len * batch_size * 6 * d_inner * hidden_size + ) + mamba_conv_flops = ( + ckpt_activations_factor + * seq_len + * batch_size + * 2 + * d_inner + * conv_dimension + ) + mamba_flops = ssm_flops + mamba_projectors_flops + mamba_conv_flops + embedding_flops = 6 * seq_len * batch_size * hidden_size * vocab_size + flops_per_iteration = mamba_flops * num_layers + embedding_flops else: flops_per_iteration = ( 24 diff --git a/megatron/neox_arguments/arguments.py b/megatron/neox_arguments/arguments.py index 9811cba72..4c1f39bbe 100644 --- a/megatron/neox_arguments/arguments.py +++ b/megatron/neox_arguments/arguments.py @@ -952,30 +952,6 @@ def calculate_derived(self): } ) - # derive steps where checkpoint should be saved - if self.checkpoint_factor or self.extra_save_iters: - if self.extra_save_iters: - save_iters = set(self.extra_save_iters) - else: - save_iters = set() - - step = self.checkpoint_factor # don't save step 0 or 1 - while step < self.train_iters: - save_iters.add(step) - if self.checkpoint_scale == "log": - step *= self.checkpoint_factor - elif self.checkpoint_scale == "linear": - step += self.checkpoint_factor - - save_iters = list(save_iters) - save_iters.sort() - - self.update_values( - { - "save_iters": save_iters, - } - ) - # derive precision fp16_conflict = "DeepSpeed fp16 field was set but precision conflicts" if self.fp16 and self.fp16.get("enabled", False): @@ -1065,6 +1041,10 @@ def calculate_derived(self): ) if self.optimizer_type.lower() == "onebitadam": + assert ( + self.train_iters is not None + ), "OneBitAdam requires train_iters to be specified" + # onebitadam needs to instantiated by deepspeed, and so we need to pass deepspeed scheduler args # for all other optimizers, the scheduling is handled by megatron self.scheduler = { diff --git a/megatron/neox_arguments/neox_args.py b/megatron/neox_arguments/neox_args.py index f8fdb9410..d596daffe 100644 --- a/megatron/neox_arguments/neox_args.py +++ b/megatron/neox_arguments/neox_args.py @@ -21,7 +21,7 @@ from template import NeoXArgsTemplate try: - from typing import List, Literal, Union, Optional + from typing import List, Literal, Union, Optional, Any except ImportError: from typing_extensions import List, Literal, Union, Optional @@ -46,7 +46,7 @@ def get_git_commit_hash(): try: git_hash = subprocess.check_output(["git", "describe", "--always"]).strip() git_hash = git_hash.decode() - except subprocess.CalledProcessError: + except (subprocess.CalledProcessError, FileNotFoundError): git_hash = None return git_hash @@ -504,6 +504,21 @@ class NeoXArgsModel(NeoXArgsTemplate): Parameter controlling whether the output layer is parallelized over the hidden dim (row) or the vocab dim (column) """ + dim_att: int = None + """ + Total dimension of the attention mechanism for RWKV. If not set, defaults to hidden_size. + """ + + head_size: int = None + """ + Size of each attention head for RWKV. Calculated as dim_att // num_attention_heads. + """ + + ffn_dim: int = None + """ + Dimension of the feed-forward network for RWKV. If not set, calculated based on hidden_size and expansion_factor. 
+ """ + @dataclass class NeoXArgsOptimizer(NeoXArgsTemplate): @@ -576,7 +591,13 @@ class NeoXArgsLRScheduler(NeoXArgsTemplate): lr_decay_iters: int = None """ - Number of iterations to decay learning rate over, If None defaults to --train-iters + Number of iterations to decay learning rate over, If None defaults to + --train-iters or the equivalent inferred valued from train_epochs. + """ + + lr_decay_fraction: float = None + """ + Effective fraction of training over which to decay lr, overrides lr_decay_iters, useful when specifying train_epochs """ min_lr: float = 0.0 @@ -670,7 +691,7 @@ class NeoXArgsLogging(NeoXArgsTemplate): Custom metadata to attach to the created Comet Experiment. """ - comet_experiment = None + comet_experiment: Any = None """ Initialized comet experiment object used to log data """ @@ -729,8 +750,8 @@ class NeoXArgsLogging(NeoXArgsTemplate): profile: bool = False """ - Enable nsys profiling. When using this option, - nsys options should be specified in commandline. + Enable nsys and pytorch profiling. When using this option with nsys, + nsys options should be directly specified in commandline. An example nsys commandline is ``` nsys profile -s none -t nvtx,cuda -o @@ -855,11 +876,6 @@ class NeoXArgsOther(NeoXArgsTemplate): Set during training """ - save_iters: list = None - """ - Set during training - """ - global_num_gpus: int = None """ Set during launching @@ -1090,6 +1106,13 @@ class NeoXArgsTraining(NeoXArgsTemplate): Weight for undesirable loss in KTO. Might help if you have unbalanced desirable and undesirable classes. """ + z_loss: float = 0.0 + """ + Z-loss parameter, only implemented for RM training currently. + https://arxiv.org/pdf/2204.02311 + https://arxiv.org/pdf/2309.10305 + """ + kto_beta: float = 0.1 """ Beta value for KTO @@ -1144,7 +1167,7 @@ class NeoXArgsTraining(NeoXArgsTemplate): while "log" implies that the number of steps between each checkpoint will be multiplied by `checkpoint-factor` at each step, starting from step 1. """ - checkpoint_factor: int = None + checkpoint_factor: Union[int, float] = None """ Acts as a multiplier on either the "log" or "linear" checkpoint spacing. @@ -1198,6 +1221,12 @@ class NeoXArgsTraining(NeoXArgsTemplate): Number of iterations to run for training. """ + train_epochs: int = None + """ + Number of epochs to run for training. Do not specify both train_epochs and train_iters. + Not currently compatible with data reweighing, pairwise datasets, and packing other than 'packed' + """ + eval_iters: int = 100 """ Number of iterations to run for evaluation validation/test for. 
diff --git a/megatron/training.py b/megatron/training.py index 3e72197ef..f553cc832 100644 --- a/megatron/training.py +++ b/megatron/training.py @@ -50,7 +50,10 @@ ) from megatron.mpu.mappings import gather_from_model_parallel_region from megatron.checkpointing import load_checkpoint, save_checkpoint -from megatron.data.data_utils import build_train_valid_test_data_iterators +from megatron.data.data_utils import ( + build_train_valid_test_data_loaders, + shift_and_wrap_data_loaders, +) from megatron.initialize import initialize_megatron from megatron.learning_rates import AnnealingLR from megatron.logging import tb_wandb_log, training_log @@ -114,7 +117,7 @@ def _plot_data(df, activation, graph_name_prefix): else: for activation in activation_list: _plot_data(df, activation, graph_name_prefix) - + return 0 @@ -170,14 +173,54 @@ def gen(): return 0 +def update_iterations(neox_args, data_loaders): + """ + Compute the number of train iterations if not specified and num_epochs, updates the neox_args object. + Note that if len(train_dataloader) % gradient_accumulation_steps != 0, this will configure neox + to do as many iterations as possible while ensuring that each example is seen *at most* train_epochs + times. + """ + if (not neox_args.do_train) or (neox_args.train_iters is not None): + pass + elif neox_args.train_iters is None and neox_args.train_epochs is None: + print_rank_0( + "ERROR:Failed to specify either train_epochs or train_iters in config file" + ) + else: + global_rank = torch.distributed.get_rank() + + if global_rank == 0: + train_dataloader = data_loaders["train"] + train_epochs = neox_args.train_epochs + gradient_accumulation_steps = neox_args.gradient_accumulation_steps + + train_dataloader_len = len(train_dataloader) + train_iterations = ( + train_dataloader_len * train_epochs + ) // gradient_accumulation_steps + + train_iters_tensor = torch.cuda.LongTensor([train_iterations]) + else: + train_iters_tensor = torch.cuda.LongTensor([0]) + + torch.distributed.broadcast(train_iters_tensor, src=0) + + neox_args.train_iters = train_iters_tensor[0].item() + + print_rank_0( + f"Training for a total of {neox_args.train_iters} iterations, corresponding to {neox_args.train_epochs} epochs." + ) + + def pretrain(neox_args): """Main training program. This function will run the following in the order provided: 1) initialize Megatron. - 2) setup model, optimizer and lr schedule - 3) call train_val_test_data_provider to get train/val/test datasets. - 4) train the model. + 2) get train/val/test datasets. + 3) setup model, optimizer and lr schedule. + 4) configure data loading + 5) train the model. Arguments: neox_args: an instance of NeoXArgs containing the configuration for pretrain @@ -194,21 +237,26 @@ def pretrain(neox_args): # Initialize and get arguments, timers, and Tensorboard writer. initialize_megatron(neox_args=neox_args) + # Normally we initialize the model first, but if we're running a coord check we build datasets first. 
if neox_args.coord_check: print_rank_0("---- Do Coord Check ----") # Data stuff neox_args.iteration = 0 - timers("train/valid/test data iterators").start() - ( - train_data_iterator, - valid_data_iterator, - test_data_iterator, - ) = build_train_valid_test_data_iterators(neox_args=neox_args) - timers("train/valid/test data iterators").stop() + # Create data loaders + timers("train/valid/test data loaders").start() + data_loaders = build_train_valid_test_data_loaders(neox_args=neox_args) + update_iterations(neox_args=neox_args, data_loaders=data_loaders) + timers("train/valid/test data loaders").stop() - coord_check(neox_args, timers, train_data_iterator) + coord_check(neox_args, timers, data_loaders["train"]) sys.exit() + # Create data loaders + timers("train/valid/test data loaders").start() + data_loaders = build_train_valid_test_data_loaders(neox_args=neox_args) + update_iterations(neox_args=neox_args, data_loaders=data_loaders) + timers("train/valid/test data loaders").stop() + # Model, optimizer, and learning rate. timers("model and optimizer").start() model, optimizer, lr_scheduler, reference_model = setup_model_and_optimizer( @@ -216,23 +264,34 @@ def pretrain(neox_args): ) timers("model and optimizer").stop() - # Data stuff. + # Make and configure iterators timers("train/valid/test data iterators").start() ( train_data_iterator, valid_data_iterator, test_data_iterator, - ) = build_train_valid_test_data_iterators(neox_args=neox_args) + ) = shift_and_wrap_data_loaders(neox_args=neox_args, data_loaders=data_loaders) timers("train/valid/test data iterators").stop() # Print setup timing. print_rank_0("done with setups ...") - timers.log(["model and optimizer", "train/valid/test data iterators"]) + timers.log( + [ + "train/valid/test data loaders", + "model and optimizer", + "train/valid/test data iterators", + ] + ) print_rank_0("training ...") iteration = neox_args.iteration # edge case: save step 0 checkpoint if requested and we're starting from step 0 - if neox_args.save and 0 in neox_args.save_iters and iteration == 0: + if ( + neox_args.save + and neox_args.extra_save_iters + and 0 in neox_args.extra_save_iters + and iteration == 0 + ): save_checkpoint( neox_args=neox_args, iteration=iteration, @@ -333,7 +392,7 @@ def get_batch(neox_args, data_iterator): # Items and their type. if neox_args.train_impl in ["normal", "kto"]: keys = ["text", "label"] if neox_args.train_label_data_paths else ["text"] - elif neox_args.train_impl == "dpo": + elif neox_args.train_impl in ["dpo", "rm"]: keys = ( [["pos", "pos_label"], ["neg", "neg_label"]] if neox_args.pos_train_label_data_paths @@ -355,6 +414,9 @@ def get_batch(neox_args, data_iterator): datatype=datatype, ) elif neox_args.train_impl == "kto": + assert ( + neox_args.train_micro_batch_size_per_gpu > 1 + ), "For KTO training, the train_micro_batch_size_per_gpu must be greater than 1." tup = _get_batch( neox_args=neox_args, tokenizer=neox_args.tokenizer, @@ -408,6 +470,13 @@ def get_batch(neox_args, data_iterator): def get_batch_pipe(data, neox_args, curr_scheduler=None): """A modification of get_batch() to work with the latest batch instead of an iterator.""" + + assert neox_args.train_impl not in [ + "kto", + "dpo", + "rm", + ], "Pipeline parallel is currently unsupported when using any of kto, dpo, rm. Set pipe_parallel_size to 0" + # Items and their type. 
keys = ["text", "label"] if neox_args.train_label_data_paths else ["text"] datatype = torch.int64 @@ -525,7 +594,7 @@ def forward_step( return model.eval_batch(data_iterator, return_logits=return_logits) # Get the batch. - if neox_args.memory_profiling and neox_args.it: + if neox_args.memory_profiling and neox_args.iteration: torch.cuda.nvtx.range_push(f"Get batch") if timers is not None: timers("batch generator").start() @@ -584,6 +653,32 @@ def forward_step( else: moe_loss = 0.0 loss = main_loss + moe_loss + elif neox_args.train_impl == "rm": + maybe_tuple = model((tokens, position_ids, attention_mask), neox_args=neox_args) + if type(maybe_tuple) is tuple: + outputs, _ = maybe_tuple + else: + outputs = maybe_tuple + pos, neg = torch.chunk(outputs, 2, 0) + pos_loss_mask, neg_loss_mask = torch.chunk(loss_mask, 2, 0) + # We assume that each pos, neg pair occur in the same order + # e.g. second nonzero pos is the corresponding second nonzero neg + # and that there are also an equal number of pos and neg in each sequence. + pos_indx = pos_loss_mask.nonzero() + neg_indx = neg_loss_mask.nonzero() + # indx[:, 0] is the batch index, indx[:, 1] is the token index, we only care about the token index. + pos_indx = pos_indx[:, 1].unsqueeze(1) + neg_indx = neg_indx[:, 1].unsqueeze(1) + pos = torch.gather(pos.squeeze(), dim=1, index=pos_indx) + neg = torch.gather(neg.squeeze(), dim=1, index=neg_indx) + with torch.no_grad(): + metrics["pos_values"] = pos.clone().detach().mean() + metrics["neg_values"] = neg.clone().detach().mean() + metrics["margin"] = (pos - neg).clone().detach().mean() + metrics["accuracy"] = ((pos - neg) > 0).clone().detach().float().mean() + loss = (-F.logsigmoid(pos - neg).mean()) + ( + (neox_args.z_loss * (pos**2 + neg**2)).mean() + ) elif neox_args.train_impl == "dpo": # Based on https://github.com/eric-mitchell/direct-preference-optimization/blob/main/trainers.py#L90 with torch.inference_mode(): @@ -790,7 +885,7 @@ def get_model(neox_args, use_cache=False): model = GPT2ModelPipe( neox_args=neox_args, num_tokentypes=0, - parallel_output=True, + parallel_output=True if neox_args.train_impl != "rm" else False, topology=mpu.get_topology(), use_cache=use_cache, ) @@ -994,6 +1089,8 @@ def get_learning_rate_scheduler(optimizer, neox_args): # Add linear learning rate scheduler. 
if neox_args.lr_decay_iters is not None: num_iters = neox_args.lr_decay_iters + elif neox_args.lr_decay_fraction is not None: + num_iters = math.floor(neox_args.train_iters * neox_args.lr_decay_fraction) else: num_iters = neox_args.train_iters num_iters = max(1, num_iters) @@ -1278,6 +1375,29 @@ def train_step_pipe(neox_args, timers, model, data_iterator): return loss_dict +def is_save_iter(neox_args, iteration): + if neox_args.extra_save_iters and iteration in neox_args.extra_save_iters: + return True + + if neox_args.checkpoint_factor: + if neox_args.checkpoint_scale == "linear": + assert float( + neox_args.checkpoint_factor + ).is_integer(), "checkpoint_factor must be a whole number when using linear checkpoint_scale" + return iteration % neox_args.checkpoint_factor == 0 + elif neox_args.checkpoint_scale == "log": + # Check if iteration is a power of checkpoint_factor + assert neox_args.checkpoint_factor > 1 + power = 1 + while power < iteration + 1: + if int(power) == iteration: + return True + power *= neox_args.checkpoint_factor + return False + + return False + + def train( neox_args, timers, @@ -1374,7 +1494,7 @@ def train( ) # Checkpointing - if neox_args.save and iteration in neox_args.save_iters: + if neox_args.save and is_save_iter(neox_args, iteration): save_checkpoint( neox_args=neox_args, iteration=iteration, diff --git a/post-training/README.md b/post-training/README.md index e6be7d931..930ad0e31 100644 --- a/post-training/README.md +++ b/post-training/README.md @@ -34,12 +34,18 @@ python tools/datasets/preprocess_data_with_chat_template.py --input data/pairwis ## SFT data ```bash -python post-training/llama_dpo_data.py python tools/datasets/preprocess_data_with_chat_template.py --input data/sft/llama3_sft_train_filtered.jsonl --output-prefix data/sft/llama3_train --tokenizer-path checkpoints/neox_converted/llama3-8b-instruct/tokenizer --jsonl-keys messages python tools/datasets/preprocess_data_with_chat_template.py --input data/sft/llama3_sft_test_filtered.jsonl --output-prefix data/sft/llama3_test --tokenizer-path checkpoints/neox_converted/llama3-8b-instruct/tokenizer --jsonl-keys messages python tools/datasets/preprocess_data_with_chat_template.py --input data/sft/llama3_sft_train_filtered.jsonl --output-prefix data/sft/llama3_val --tokenizer-path checkpoints/neox_converted/llama3-8b-instruct/tokenizer --jsonl-keys messages ``` +## KTO data +```bash +python tools/datasets/preprocess_data_with_chat_template.py --input data/kto/llama3_sft_train_filtered.jsonl --output-prefix data/kto/llama3_train --tokenizer-path checkpoints/neox_converted/llama3-8b-instruct/tokenizer --jsonl-keys messages --reward-key reward +python tools/datasets/preprocess_data_with_chat_template.py --input data/kto/llama3_sft_test_filtered.jsonl --output-prefix data/kto/llama3_test --tokenizer-path checkpoints/neox_converted/llama3-8b-instruct/tokenizer --jsonl-keys messages --reward-key reward +python tools/datasets/preprocess_data_with_chat_template.py --input data/kto/llama3_sft_train_filtered.jsonl --output-prefix data/kto/llama3_val --tokenizer-path checkpoints/neox_converted/llama3-8b-instruct/tokenizer --jsonl-keys messages --reward-key reward +``` + ## Converting back to hf ```bash diff --git a/post-training/configs/llama3-8b-kto.yml b/post-training/configs/llama3-8b-kto.yml new file mode 100644 index 000000000..e819d37cb --- /dev/null +++ b/post-training/configs/llama3-8b-kto.yml @@ -0,0 +1,120 @@ +{ + "pipe_parallel_size": 0, + "model_parallel_size": 4, + "make_vocab_size_divisible_by": 1, + 
+ # model settings + "num_layers": 32, + "hidden_size": 4096, + "num_attention_heads": 32, + "num_kv_heads": 8, + # llama3 supports more than this but this is just for testing. + "seq_length": 1024, + "max_position_embeddings": 1024, + "pos_emb": "rotary", + "rotary_pct": 1, + "rotary_emb_base": 500000, + "rope_fusion": true, + "no_weight_tying": true, + "gpt_j_residual": false, + "output_layer_parallelism": "column", + "norm": "rmsnorm", + "rms_norm_epsilon": 1.0e-5, + + "attention_config": [[["flash"], 32]], + + "scaled_upper_triang_masked_softmax_fusion": true, + "bias_gelu_fusion": false, + "use_bias_in_norms": false, + "use_bias_in_attn_linear": false, + "use_bias_in_mlp": false, + "use_flashattn_swiglu": true, + "activation": "swiglu", + "intermediate_size": 14336, + "mlp_multiple_of": 14336, + + "optimizer": { + "type": "Adam", + "params": { + "lr": 0.00001, + "betas": [0.9, 0.95], + "eps": 1.0e-8 + } + }, + "min_lr": 0.000001, + + "zero_optimization": { + "stage": 1, + "allgather_partitions": true, + "allgather_bucket_size": 1260000000, + "overlap_comm": true, + "reduce_scatter": true, + "reduce_bucket_size": 1260000000, + "contiguous_gradients": true, + "cpu_offload": false + }, + + + "train_impl": "kto", + "kto_fp32": true, + "kto_beta": 0.1, + "allow_chopped": false, + "train_label_data_paths": [ "data/kto/llama3_train_messages_label_document" ], + "test_label_data_paths": [ "data/kto/llama3_test_messages_label_document" ], + "valid_label_data_paths": [ "data/kto/llama3_train_messages_label_document" ], + "train_data_paths": [ "data/kto/llama3_train_messages_document" ], + "test_data_paths": [ "data/kto/llama3_test_messages_document" ], + "valid_data_paths": [ "data/kto/llama3_train_messages_document" ], + "train_reward_data_paths": [ "data/kto/llama3_train_messages_reward_document" ], + "test_reward_data_paths": [ "data/kto/llama3_test_messages_reward_document" ], + "valid_reward_data_paths": [ "data/kto/llama3_train_messages_reward_document" ], + + "train_micro_batch_size_per_gpu": 32, + "gradient_accumulation_steps": 2, + "data_impl": "mmap", + "pack_impl": "unpacked", + "num_workers": 1, + + "checkpoint_activations": true, + "checkpoint_num_layers": 1, + "partition_activations": true, + "synchronize_each_layer": true, + + "gradient_clipping": 1.0, + "weight_decay": 0.1, + "hidden_dropout": 0, + "attention_dropout": 0, + + "precision": "bfloat16", + "fp32_allreduce": true, + "bf16": { + "enabled": true + }, + "data_types": { + "grad_accum_dtype": "fp32" + }, + + "train_iters": 477, + "lr_decay_iters": 477, + "distributed_backend": "nccl", + "lr_decay_style": "cosine", + "warmup": 0.1, + "checkpoint_factor": 1000, + "eval_interval": 100, + "eval_iters": 10, + + "log_interval": 1, + "steps_per_print": 1, + "wall_clock_breakdown": true, + + + "save": "checkpoints/kto/llama3/llama3-8b-instruct", + #"load": "", # once run is started, to restart from intermediate ckpt use "load" = "save" + "load": "checkpoints/neox_converted/llama3-8b-instruct", + "vocab-file": "checkpoints/neox_converted/llama3-8b-instruct/tokenizer/tokenizer.json", + "use_wandb": true, + "wandb_group": "llama3-8b-instruct", + "wandb_project": "ultrafeedback-kto", + "finetune": true, # set to false once resuming from intermediate finetuning step + "tokenizer_type": "HFTokenizer", +} diff --git a/post-training/llama_data.py b/post-training/llama_data.py index 5eef8c2d3..eab6ac9f1 100644 --- a/post-training/llama_data.py +++ b/post-training/llama_data.py @@ -22,7 +22,6 @@ ) as f: writer = jsonlines.Writer(f) for 
item in raw_datasets[split]: - # add empty system messages item["chosen"] = item["chosen"] item["rejected"] = item["rejected"] writer.write(item) @@ -33,6 +32,18 @@ ) as f: writer = jsonlines.Writer(f) for item in raw_datasets[split]: - # add empty system messages item["messages"] = item["chosen"] writer.write(item) +os.makedirs(os.path.join("data", "kto"), exist_ok=True) +for split in ["train", "test"]: + with open( + os.path.join("data", "kto", f"llama3_kto_{split}_filtered.jsonl"), "w" + ) as f: + writer = jsonlines.Writer(f) + for item in raw_datasets[split]: + item["messages"] = item["chosen"] + item["reward"] = 1 + writer.write(item) + item["messages"] = item["rejected"] + item["reward"] = -1 + writer.write(item) diff --git a/tests/neox_args/test_neoxargs_usage.py b/tests/neox_args/test_neoxargs_usage.py index 176151c2a..5f8ba7bd2 100644 --- a/tests/neox_args/test_neoxargs_usage.py +++ b/tests/neox_args/test_neoxargs_usage.py @@ -66,7 +66,9 @@ def test_neoxargs_usage(): # find args matches matches = list( - re.findall(r"(?<=args\.).{2,}?(?=[\s\n(){}+-/*;:,=,[,\]])", file_contents) + re.findall( + r"(?<=neox_args\.).{2,}?(?=[\s\n(){}+-/*;:,=,[,\]])", file_contents + ) ) if len(matches) == 0: continue diff --git a/tests/unit/test_format_conversion_scripts.py b/tests/unit/test_format_conversion_scripts.py index e0801434c..6935e480a 100644 --- a/tests/unit/test_format_conversion_scripts.py +++ b/tests/unit/test_format_conversion_scripts.py @@ -4,8 +4,12 @@ from megatron.neox_arguments.neox_args import NeoXArgsTokenizer +@pytest.mark.skip( + reason="Conversion test is skipped until we fix the CUDA + torch multiprocessing issue." +) def test_gpt_neox_to_huggingface(monkeypatch, tmpdir, tmp_path): # Generate random GPT-NEOX model, check we can convert to hf format + model_dir = str(tmpdir) input_args = ["train.py", "tests/config/test_setup.yml"] deepspeed_main_args = simulate_deepy_env(monkeypatch, input_args) diff --git a/tools/datasets/preprocess_data_with_chat_template.py b/tools/datasets/preprocess_data_with_chat_template.py index 767b42afc..ee2b983b6 100644 --- a/tools/datasets/preprocess_data_with_chat_template.py +++ b/tools/datasets/preprocess_data_with_chat_template.py @@ -408,6 +408,8 @@ def main(): for key in update_keys: builders[key].finalize(output_idx_files[key]) builders[key + "_label"].finalize(output_idx_files[key + "_label"]) + if args.reward_key is not None: + builders[key + "_reward"].finalize(output_idx_files[key + "_reward"]) if __name__ == "__main__":
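For reference on the KTO data path added above, `post-training/llama_data.py` writes each preference pair as two JSONL records in `data/kto/llama3_kto_{split}_filtered.jsonl`, with the chosen completion given reward 1 and the rejected one reward -1, and `preprocess_data_with_chat_template.py` is then run with `--reward-key reward`. The message structure shown below is a hypothetical chat-formatted example, not taken from the dataset:

```python
# Hypothetical shape of two adjacent JSONL lines emitted by llama_data.py:
# the "chosen" conversation gets reward 1, the "rejected" conversation gets reward -1.
chosen_record = {
    "messages": [
        {"role": "user", "content": "Summarize the paper."},        # assumed structure
        {"role": "assistant", "content": "Preferred summary ..."},
    ],
    "reward": 1,
}
rejected_record = {
    "messages": [
        {"role": "user", "content": "Summarize the paper."},
        {"role": "assistant", "content": "Dispreferred summary ..."},
    ],
    "reward": -1,
}
```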