diff --git a/megatron/data/data_utils.py b/megatron/data/data_utils.py
index bc5754cdb..7e4dbdb37 100644
--- a/megatron/data/data_utils.py
+++ b/megatron/data/data_utils.py
@@ -55,6 +55,8 @@ def build_the_dataset(
     data_prefix,
     name,
     data_impl,
+    pack_impl,
+    allow_chopped,
     num_samples,
     seq_length,
     seed,
@@ -83,6 +85,8 @@ def build_the_dataset(
         num_samples,
         seq_length,
         seed,
+        pack_impl=pack_impl,
+        allow_chopped=allow_chopped,
         build_index_mappings=build_index_mappings,
         label_dataset=label_dataset,
     )
@@ -93,6 +97,8 @@ def build_train_valid_test_datasets(
     data_prefix,
     use_shared_fs,
     data_impl,
+    pack_impl,
+    allow_chopped,
     splits_string,
     train_valid_test_num_samples,
     seq_length,
@@ -138,6 +144,8 @@ def build_dataset(index, name):
                 train_valid_test_num_samples[index],
                 seq_length,
                 seed,
+                pack_impl=pack_impl,
+                allow_chopped=allow_chopped,
                 use_shared_fs=use_shared_fs,
             )
         return dataset
@@ -204,12 +212,25 @@ def build_weighted_datasets(
 ):
     # build individual datasets
     train_datasets, valid_datasets, test_datasets = [], [], []
-    for i, (train_path, label_path, valid_path, test_path) in enumerate(
+    for i, (
+        train_path,
+        train_label_path,
+        valid_path,
+        valid_label_path,
+        test_path,
+        test_label_path,
+    ) in enumerate(
         zip_longest(
             neox_args.train_data_paths,
-            neox_args.label_data_paths if neox_args.label_data_paths else [],
+            neox_args.train_label_data_paths
+            if neox_args.train_label_data_paths
+            else [],
             neox_args.valid_data_paths,
+            neox_args.valid_label_data_paths
+            if neox_args.valid_label_data_paths
+            else [],
             neox_args.test_data_paths,
+            neox_args.test_label_data_paths if neox_args.test_label_data_paths else [],
         )
     ):
         if train_path:
@@ -218,12 +239,14 @@ def build_weighted_datasets(
                     data_prefix=train_path,
                     name=f"train_{i}",
                     data_impl=neox_args.data_impl,
+                    pack_impl=neox_args.pack_impl,
+                    allow_chopped=neox_args.allow_chopped,
                     num_samples=train_num_samples[i],
                     seq_length=neox_args.seq_length,
                     seed=neox_args.seed,
                     skip_warmup=(not neox_args.mmap_warmup),
                     build_index_mappings=build_index_mappings,
-                    label_prefix=label_path,
+                    label_prefix=train_label_path,
                 )
             )

@@ -233,11 +256,14 @@ def build_weighted_datasets(
                     data_prefix=valid_path,
                     name=f"valid_{i}",
                     data_impl=neox_args.data_impl,
+                    pack_impl=neox_args.pack_impl,
+                    allow_chopped=neox_args.allow_chopped,
                     num_samples=valid_num_samples[i],
                     seq_length=neox_args.seq_length,
                     seed=neox_args.seed,
                     skip_warmup=(not neox_args.mmap_warmup),
                     build_index_mappings=build_index_mappings,
+                    label_prefix=valid_label_path,
                 )
             )

@@ -247,11 +273,14 @@ def build_weighted_datasets(
                     data_prefix=test_path,
                     name=f"test_{i}",
                     data_impl=neox_args.data_impl,
+                    pack_impl=neox_args.pack_impl,
+                    allow_chopped=neox_args.allow_chopped,
                     num_samples=test_num_samples[i],
                     seq_length=neox_args.seq_length,
                     seed=neox_args.seed,
                     skip_warmup=(not neox_args.mmap_warmup),
                     build_index_mappings=build_index_mappings,
+                    label_prefix=test_label_path,
                 )
             )
     return train_datasets, valid_datasets, test_datasets
@@ -414,6 +443,8 @@ def build_train_valid_test_data_iterators(neox_args):
                 seq_length=neox_args.seq_length,
                 seed=neox_args.seed,
                 skip_warmup=(not neox_args.mmap_warmup),
+                pack_impl=neox_args.pack_impl,
+                allow_chopped=neox_args.allow_chopped,
             )

         # Build dataloders.
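Illustrative sketch (not part of the patch): the data_utils.py changes above thread the new pack_impl and allow_chopped options, plus the renamed per-split label paths, from neox_args into build_the_dataset. The call below only mirrors the keyword signature used by build_weighted_datasets in this diff; the paths, sample count, and option values are hypothetical placeholders.

    # Hypothetical call site after this change; all values are placeholders.
    from megatron.data.data_utils import build_the_dataset

    train_ds = build_the_dataset(
        data_prefix="data/mydataset_text_document",    # placeholder path
        name="train_0",
        data_impl="mmap",
        pack_impl="pack_until_overflow",               # new argument in this patch
        allow_chopped=False,                           # new argument in this patch
        num_samples=10000,                             # placeholder
        seq_length=2048,
        seed=1234,
        skip_warmup=True,
        build_index_mappings=True,
        label_prefix="data/mydataset_label_document",  # placeholder path
    )
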
diff --git a/megatron/data/gpt2_dataset.py b/megatron/data/gpt2_dataset.py
index 75e601fda..edba57df2 100644
--- a/megatron/data/gpt2_dataset.py
+++ b/megatron/data/gpt2_dataset.py
@@ -36,14 +36,19 @@ def __init__(
         num_samples,
         seq_length,
         seed,
+        pack_impl="packed",
+        allow_chopped=True,
         build_index_mappings=True,
         use_shared_fs=True,
         label_dataset=None,
     ):

         self.name = name
+        self.pack_impl = pack_impl
+        self.allow_chopped = allow_chopped
         self.indexed_dataset = indexed_dataset
         self.label_dataset = label_dataset
+        self.seq_length = seq_length

         # Checks
         assert np.min(documents) >= 0
@@ -56,10 +61,13 @@ def __init__(
                 data_prefix,
                 documents,
                 self.indexed_dataset.sizes,
+                self.label_dataset,
                 num_samples,
                 seq_length,
                 seed,
+                self.pack_impl,
                 use_shared_fs=use_shared_fs,
+                allow_chopped=self.allow_chopped,
             )
             self.shuffle_idx_len = self.shuffle_idx.shape[0] - 1
             self.sample_idx_len = self.sample_idx.shape[0] - 1
@@ -113,8 +121,38 @@ def __getitem__(self, idx):
                     samples.append(np.concatenate(sample_list))

             if len(datasets) == 1:
+                if len(samples[0]) < (self.seq_length + 1):
+                    # Pad with -100s so the masking function can ignore these.
+                    samples[0] = np.pad(
+                        samples[0],
+                        (0, (self.seq_length + 1) - len(samples[0])),
+                        mode="constant",
+                        constant_values=-100,
+                    )
+                elif len(samples[0]) > (self.seq_length + 1):
+                    # Check for overflow and truncate.
+                    samples[0] = samples[0][: (self.seq_length + 1)]
                 return {"text": np.array(samples[0], dtype=np.int64)}
             else:
+                if len(samples[0]) < (self.seq_length + 1):
+                    # Pad with 0s, can use any number since it's masked.
+                    samples[0] = np.pad(
+                        samples[0],
+                        (0, (self.seq_length + 1) - len(samples[0])),
+                        mode="constant",
+                        constant_values=0,
+                    )
+                    # pad with -100s so we can mask it out
+                    samples[1] = np.pad(
+                        samples[1],
+                        (0, (self.seq_length + 1) - len(samples[1])),
+                        mode="constant",
+                        constant_values=-100,
+                    )
+                elif len(samples[0]) > (self.seq_length + 1):
+                    # Check for overflow and truncate.
+                    samples[0] = samples[0][: (self.seq_length + 1)]
+                    samples[1] = samples[1][: (self.seq_length + 1)]
                 return {
                     "text": np.array(samples[0], dtype=np.int64),
                     "label": np.array(samples[1], dtype=np.int64),
@@ -132,10 +170,13 @@ def _build_index_mappings(
     data_prefix,
     documents,
     sizes,
+    label_dataset,
     num_samples,
     seq_length,
     seed,
+    packing_impl,
     use_shared_fs=True,
+    allow_chopped=True,
 ):
     """Build doc-idx, sample-idx, and shuffle-idx.
     doc-idx: is an array (ordered) of documents to be used in training.
@@ -155,6 +196,9 @@ def _build_index_mappings(
     _filename += "_{}ns".format(num_samples)
     _filename += "_{}sl".format(seq_length)
     _filename += "_{}s".format(seed)
+    _filename += "_{}pi".format(packing_impl)
+    if allow_chopped:
+        _filename += "_ac"
     doc_idx_filename = _filename + "_doc_idx.npy"
     sample_idx_filename = _filename + "_sample_idx.npy"
     shuffle_idx_filename = _filename + "_shuffle_idx.npy"
@@ -177,44 +221,116 @@ def _build_index_mappings(
             )
             # doc-idx.
             start_time = time.time()
-            doc_idx = _build_doc_idx(documents, num_epochs, np_rng)
-            np.save(doc_idx_filename, doc_idx, allow_pickle=True)
-            print_rank_0(
-                " > elapsed time to build and save doc-idx mapping "
-                "(seconds): {:4f}".format(time.time() - start_time)
-            )
-            # sample-idx.
-            start_time = time.time()
-            # Use C++ implementation for speed.
-            from megatron.data import helpers
-
-            assert doc_idx.dtype == np.int32
-            assert sizes.dtype == np.int32
-
-            num_samples = (num_epochs * tokens_per_epoch - 1) / seq_length
-            if 2 * (num_samples + 1) < np.iinfo(np.int32).max:
-                sample_idx = helpers.build_sample_idx_int32(
-                    sizes, doc_idx, seq_length, num_epochs, tokens_per_epoch
+            if packing_impl == "packed":
+                doc_idx = _build_doc_idx(documents, num_epochs, np_rng)
+                np.save(doc_idx_filename, doc_idx, allow_pickle=True)
+                print_rank_0(
+                    " > elapsed time to build and save doc-idx mapping "
+                    "(seconds): {:4f}".format(time.time() - start_time)
                 )
-            else:
-                sample_idx = helpers.build_sample_idx_int64(
-                    sizes, doc_idx, seq_length, num_epochs, tokens_per_epoch
+                # sample-idx.
+                start_time = time.time()
+                # Use C++ implementation for speed.
+                from megatron.data import helpers
+
+                assert doc_idx.dtype == np.int32
+                assert sizes.dtype == np.int32
+
+                num_samples = (num_epochs * tokens_per_epoch - 1) / seq_length
+                if 2 * (num_samples + 1) < np.iinfo(np.int32).max:
+                    sample_idx = helpers.build_sample_idx_int32(
+                        sizes, doc_idx, seq_length, num_epochs, tokens_per_epoch
+                    )
+                else:
+                    sample_idx = helpers.build_sample_idx_int64(
+                        sizes, doc_idx, seq_length, num_epochs, tokens_per_epoch
+                    )
+                np.save(sample_idx_filename, sample_idx, allow_pickle=True)
+                print_rank_0(
+                    " > elapsed time to build and save sample-idx mapping "
+                    "(seconds): {:4f}".format(time.time() - start_time)
                 )
-            np.save(sample_idx_filename, sample_idx, allow_pickle=True)
-            print_rank_0(
-                " > elapsed time to build and save sample-idx mapping "
-                "(seconds): {:4f}".format(time.time() - start_time)
-            )
-            # shuffle-idx.
-            start_time = time.time()
-            # -1 is due to data structure used to retrieve the index:
-            # sample i --> [sample_idx[i], sample_idx[i+1])
-            shuffle_idx = _build_shuffle_idx(sample_idx.shape[0] - 1, np_rng)
-            np.save(shuffle_idx_filename, shuffle_idx, allow_pickle=True)
-            print_rank_0(
-                " > elapsed time to build and save shuffle-idx mapping"
-                " (seconds): {:4f}".format(time.time() - start_time)
-            )
+                # shuffle-idx.
+                start_time = time.time()
+                # -1 is due to data structure used to retrieve the index:
+                # sample i --> [sample_idx[i], sample_idx[i+1])
+                shuffle_idx = _build_shuffle_idx(sample_idx.shape[0] - 1, np_rng)
+                np.save(shuffle_idx_filename, shuffle_idx, allow_pickle=True)
+                print_rank_0(
+                    " > elapsed time to build and save shuffle-idx mapping"
+                    " (seconds): {:4f}".format(time.time() - start_time)
+                )
+            elif packing_impl == "pack_until_overflow":
+                # Naively pack data until it overflows, then roll it over to a new one instead.
+                shuffle_idx = np.arange(num_samples)  # Shuffle index around epochs
+                np_rng.shuffle(shuffle_idx)
+                sample_idx = []
+                doc_idx = []
+                # Iterate over files until we have enough samples.
+                temp_shuffle_idx = np.arange(len(documents))
+                np_rng.shuffle(temp_shuffle_idx)
+                running_length = 0
+                curr_shuffle_idx = 0
+                while len(sample_idx) < num_samples:
+                    if not allow_chopped:
+                        # +1 since we shift left/right by 1
+                        if sizes[temp_shuffle_idx[curr_shuffle_idx]] > seq_length + 1:
+                            curr_shuffle_idx += 1
+                            continue
+                    # First, check if we need to skip this item...
+                    if label_dataset is not None:
+                        if np.all(
+                            label_dataset.get(temp_shuffle_idx[curr_shuffle_idx])[
+                                : seq_length + 1
+                            ]
+                            == -100
+                        ):
+                            curr_shuffle_idx += 1
+                            continue
+                    doc_length = sizes[temp_shuffle_idx[curr_shuffle_idx]]
+                    if running_length == 0:
+                        sample_idx.append(np.array([len(doc_idx), 0]))
+                        doc_idx.append(temp_shuffle_idx[curr_shuffle_idx])
+                        running_length += doc_length
+                    else:
+                        if running_length + doc_length > (seq_length + 1):
+                            running_length = doc_length
+                            sample_idx.append(np.array([len(doc_idx), 0]))
+                        else:
+                            running_length += doc_length
+                        doc_idx.append(temp_shuffle_idx[curr_shuffle_idx])
+                    curr_shuffle_idx += 1
+                    if curr_shuffle_idx == len(documents):
+                        curr_shuffle_idx = 0
+                        np_rng.shuffle(temp_shuffle_idx)
+                sample_idx.append(np.array([len(doc_idx), 0]))
+                np.save(doc_idx_filename, doc_idx, allow_pickle=True)
+                np.save(sample_idx_filename, sample_idx, allow_pickle=True)
+                np.save(shuffle_idx_filename, shuffle_idx, allow_pickle=True)
+            elif packing_impl == "unpacked":
+                # Unpacked data, one sample per document.
+                shuffle_idx = np.arange(num_samples)  # Shuffle index around epochs
+                np_rng.shuffle(shuffle_idx)
+                sample_idx = np.zeros((num_samples + 1, 2), dtype=np.int64)
+                sample_idx[:, 0] = np.array([i for i in range(num_samples + 1)])
+                sample_idx[:, 1] = 0
+                doc_idx = list()
+                doc_i = 0
+                while len(doc_idx) <= num_samples:
+                    if not allow_chopped:
+                        # +1 since we shift left/right by 1
+                        if sizes[doc_i] > seq_length + 1:
+                            doc_i = (doc_i + 1) % len(documents)
+                            continue
+                    # Just in case we have bad data in the loop...
+                    if np.all(label_dataset.get(doc_i)[:seq_length] == -100):
+                        doc_i = (doc_i + 1) % len(documents)
+                        continue
+                    doc_idx.append(doc_i)
+                    doc_i = (doc_i + 1) % len(documents)
+                np.save(doc_idx_filename, doc_idx, allow_pickle=True)
+                np.save(sample_idx_filename, sample_idx, allow_pickle=True)
+                np.save(shuffle_idx_filename, shuffle_idx, allow_pickle=True)

     # This should be a barrier but nccl barrier assumes
     # device_index=rank which is not the case for model
diff --git a/megatron/neox_arguments/arguments.py b/megatron/neox_arguments/arguments.py
index fb26fb4aa..327639454 100644
--- a/megatron/neox_arguments/arguments.py
+++ b/megatron/neox_arguments/arguments.py
@@ -1121,10 +1121,8 @@ def calculate_derived(self):
         if self.test_data_paths and (self.test_data_weights is None):
             self.test_data_weights = [1.0] * len(self.test_data_paths)

-        if self.label_data_paths:
-            err_str = (
-                "Must use `label_data_paths` with `train_data_paths`, not `data_path`"
-            )
+        if self.train_label_data_paths:
+            err_str = "Must use `train_label_data_paths` with `train_data_paths`, not `data_path`"
             assert self.train_data_paths and not self.data_path, err_str

         # if a sample input file is provided, default text_gen_type type to input-file
diff --git a/megatron/neox_arguments/neox_args.py b/megatron/neox_arguments/neox_args.py
index 7993f785f..dd51c7778 100644
--- a/megatron/neox_arguments/neox_args.py
+++ b/megatron/neox_arguments/neox_args.py
@@ -855,9 +855,9 @@ class NeoXArgsTraining(NeoXArgsTemplate):
     List of paths to train datasets.
     """

-    label_data_paths: list = None
+    train_label_data_paths: list = None
     """
-    List of paths to label datasets (not shifted by 1 yet!).
+    List of paths to train label datasets (not shifted by 1 yet!).
     """

     test_data_paths: list = None
@@ -865,11 +865,21 @@ class NeoXArgsTraining(NeoXArgsTemplate):
     List of paths to test datasets.
     """

+    test_label_data_paths: list = None
+    """
+    List of paths to test label datasets (not shifted by 1 yet!).
+ """ + valid_data_paths: list = None """ List of paths to validation datasets. """ + valid_label_data_paths: list = None + """ + List of paths to validation label datasets (not shifted by 1 yet!). + """ + train_data_weights: list = None """ List of 'weights' that decide how often to sample from each training dataset when blending datasets. If None, defaults to equal weighting. @@ -919,6 +929,21 @@ class NeoXArgsTraining(NeoXArgsTemplate): Implementation of indexed datasets, can be one of "infer", "cached", or "mmap" """ + pack_impl: Literal["packed", "pack_until_overflow", "unpacked"] = "packed" + """ + Packing implementation, can be one of "packed", "pack_until_overflow", or "unpacked". + + warning: pack_until_overflow is very naive and will likely have issues with pretraining scale datasets + """ + + allow_chopped: bool = True + """ + WARNING: if your packing impl is packed, this is ignored. + + Allow chopped samples in the dataset. + (e.g if your sequence length is 1024 and you have a sample of length 1026, it will be chopped to 1024) + """ + mmap_warmup: bool = False """ Warm up mmap files. diff --git a/megatron/training.py b/megatron/training.py index ce59b242a..fc3d9e129 100644 --- a/megatron/training.py +++ b/megatron/training.py @@ -278,16 +278,19 @@ def pretrain(neox_args): def _get_batch(neox_args, tokenizer, keys, data, datatype): """Support function for get_batch / get_batch pipe (to avoid code repetition)""" data_b = mpu.broadcast_data(keys, data, datatype) - + token_key = keys[0] + label_key = keys[1] if len(keys) > 1 else None # Unpack. - tokens_ = data_b["text"].long() + tokens_ = data_b[token_key].long() if "label" in data_b: + label_mask = (data_b[label_key].long() >= 0)[:, 1:].contiguous() labels = torch.where( - data_b["label"].long() >= 0, - data_b["label"].long(), + data_b[label_key].long() >= 0, + data_b[label_key].long(), torch.zeros_like(data_b["label"].long()), )[:, 1:].contiguous() else: + label_mask = (tokens_.long() >= 0)[:, 1:].contiguous() labels = tokens_[:, 1:].contiguous() tokens = tokens_[:, :-1].contiguous() @@ -298,9 +301,9 @@ def _get_batch(neox_args, tokenizer, keys, data, datatype): eod_mask_loss=neox_args.eod_mask_loss, sliding_window_width=neox_args.sliding_window_width, ) - # If `label` is present, any token < 0 (e.g., -100, the default for torch) skips the loss computation - if "label" in data_b: - loss_mask = (data_b["label"][:, 1:] >= 0).to(loss_mask.dtype) + + # combine loss masks from get_ltor_masks_and_position_ids with loss masks from data + loss_mask = label_mask.to(loss_mask.dtype) * loss_mask return tokens, labels, loss_mask, attention_mask, position_ids @@ -308,7 +311,7 @@ def get_batch(neox_args, data_iterator): """Generate a batch""" # Items and their type. - keys = ["text", "label"] if neox_args.label_data_paths else ["text"] + keys = ["text", "label"] if neox_args.train_label_data_paths else ["text"] datatype = torch.int64 # Broadcast data. @@ -328,7 +331,7 @@ def get_batch(neox_args, data_iterator): def get_batch_pipe(data, neox_args, curr_scheduler=None): """A modification of get_batch() to work with the latest batch instead of an iterator.""" # Items and their type. - keys = ["text", "label"] if neox_args.label_data_paths else ["text"] + keys = ["text", "label"] if neox_args.train_label_data_paths else ["text"] datatype = torch.int64 tokens, labels, loss_mask, attention_mask, position_ids = _get_batch(