diff --git a/configs/README.md b/configs/README.md
index e14274b56..3102a34d1 100644
--- a/configs/README.md
+++ b/configs/README.md
@@ -235,6 +235,33 @@ Additional DeepSpeed settings besides those mentioned above should be wrapped in
   "eval_iters": 10,
 ```

+However, if you want to use DPO-style training you'll need to set positive/negative data paths instead of a single data path, e.g.
+
+```yaml
+  "dataset_impl": "pairwise",
+  "train_impl": "dpo",
+  "pack_impl": "unpacked",
+  "dpo_beta": 0.1,
+  "dpo_fp32": true,
+  "pos_train_data_path": "data/enwik8/enwik8_text_pos_document",
+  "pos_valid_data_path": "data/enwik8/enwik8_text_pos_document",
+  "pos_test_data_path": "data/enwik8/enwik8_text_pos_document",
+  "neg_train_data_path": "data/enwik8/enwik8_text_neg_document",
+  "neg_valid_data_path": "data/enwik8/enwik8_text_neg_document",
+  "neg_test_data_path": "data/enwik8/enwik8_text_neg_document",
+  ## If you have labels (e.g. to mask out user turns)...
+  "pos_train_label_data_path": "data/enwik8/enwik8_text_pos_label_document",
+  "pos_valid_label_data_path": "data/enwik8/enwik8_text_pos_label_document",
+  "pos_test_label_data_path": "data/enwik8/enwik8_text_pos_label_document",
+  "neg_train_label_data_path": "data/enwik8/enwik8_text_neg_label_document",
+  "neg_valid_label_data_path": "data/enwik8/enwik8_text_neg_label_document",
+  "neg_test_label_data_path": "data/enwik8/enwik8_text_neg_label_document",
+  ## If you want to precompute the logits over your dataset...
+  "precompute_model_name": "gpt2",
+  ## Needed for the generate.py step, if precomputing
+  "text_gen_type": "precompute"
+```
+
 ### LR Scheduler settings

 ```yaml
diff --git a/generate.py b/generate.py
index 743e350d0..e19ef2e0e 100755
--- a/generate.py
+++ b/generate.py
@@ -23,6 +23,7 @@
     generate_samples_from_prompt,
     generate_samples_unconditional,
     generate_samples_interactive,
+    precompute_logits,
 )


@@ -83,6 +84,8 @@ def main(input_args=None, overwrite_values=None):
             top_p=neox_args.top_p,
         )
+    elif neox_args.text_gen_type == "precompute":
+        precompute_logits(neox_args=neox_args, model=model)
     else:
         raise ValueError(
             f"`text_gen_type` either not specified or not recognised: {neox_args.text_gen_type}"
diff --git a/megatron/data/data_utils.py b/megatron/data/data_utils.py
index 7e4dbdb37..7c13131ad 100644
--- a/megatron/data/data_utils.py
+++ b/megatron/data/data_utils.py
@@ -23,6 +23,7 @@
 from megatron.data.indexed_dataset import make_dataset as make_indexed_dataset
 from megatron.data.blendable_dataset import BlendableDataset
 from megatron.data.gpt2_dataset import GPT2Dataset
+from megatron.data.pairwise_dataset import PairwiseDataset
 from megatron.data.samplers import DistributedBatchSampler


@@ -53,9 +54,12 @@ def make_data_loader(dataset, neox_args):
 def build_the_dataset(
     data_prefix,
+    pos_data_prefix,
+    neg_data_prefix,
     name,
     data_impl,
     pack_impl,
+    dataset_impl,
     allow_chopped,
     num_samples,
     seq_length,
@@ -63,33 +67,100 @@
     skip_warmup,
     build_index_mappings=True,
     label_prefix=None,
+    pos_label_prefix=None,
+    neg_label_prefix=None,
+    precompute_model_name=None,
 ):
     """Build train/valid/test datasets."""
-
-    indexed_dataset = make_indexed_dataset(data_prefix, data_impl, skip_warmup)
-    if label_prefix is None:
-        label_dataset = None
+    if dataset_impl == "gpt2":
+        indexed_dataset = make_indexed_dataset(data_prefix, data_impl, skip_warmup)
+        if label_prefix is None:
+            label_dataset = None
+        else:
+            label_dataset = make_indexed_dataset(label_prefix, data_impl, skip_warmup)
+        if precompute_model_name is not None:
+            # If we have the name, assume it exists. If it doesn't, it will just be None, which is fine.
+            precompute_indexed_dataset = make_indexed_dataset(
+                data_prefix + "_" + precompute_model_name, data_impl, skip_warmup
+            )
+    elif dataset_impl == "pairwise":
+        pos_indexed_dataset = make_indexed_dataset(
+            pos_data_prefix, data_impl, skip_warmup
+        )
+        neg_indexed_dataset = make_indexed_dataset(
+            neg_data_prefix, data_impl, skip_warmup
+        )
+        if pos_label_prefix is None:
+            pos_label_dataset = None
+            # The label datasets must be both present or both absent
+            assert neg_label_prefix is None
+            neg_label_dataset = None
+        else:
+            pos_label_dataset = make_indexed_dataset(
+                pos_label_prefix, data_impl, skip_warmup
+            )
+            # The label datasets must be both present or both absent
+            assert neg_label_prefix is not None
+            neg_label_dataset = make_indexed_dataset(
+                neg_label_prefix, data_impl, skip_warmup
+            )
+        if precompute_model_name is None:
+            pos_ref_dataset = None
+            neg_ref_dataset = None
+        else:
+            pos_ref_dataset = make_indexed_dataset(
+                pos_data_prefix + "_" + precompute_model_name, data_impl, skip_warmup
+            )
+            neg_ref_dataset = make_indexed_dataset(
+                neg_data_prefix + "_" + precompute_model_name, data_impl, skip_warmup
+            )
     else:
-        label_dataset = make_indexed_dataset(label_prefix, data_impl, skip_warmup)
+        raise NotImplementedError(f"dataset_impl={dataset_impl} not implemented")

-    total_num_of_documents = indexed_dataset.sizes.shape[0]
+    total_num_of_documents = (
+        indexed_dataset.sizes.shape[0]
+        if dataset_impl == "gpt2"
+        else pos_indexed_dataset.sizes.shape[0]
+    )
     print_rank_0(" {}:".format(name))
     print_rank_0(" no. of documents:{}".format(total_num_of_documents))
     dataset = None
     documents = np.arange(start=0, stop=total_num_of_documents, step=1, dtype=np.int32)
-    dataset = GPT2Dataset(
-        name,
-        data_prefix,
-        documents,
-        indexed_dataset,
-        num_samples,
-        seq_length,
-        seed,
-        pack_impl=pack_impl,
-        allow_chopped=allow_chopped,
-        build_index_mappings=build_index_mappings,
-        label_dataset=label_dataset,
-    )
+
+    if dataset_impl == "gpt2":
+        dataset = GPT2Dataset(
+            name,
+            data_prefix,
+            documents,
+            indexed_dataset,
+            num_samples,
+            seq_length,
+            seed,
+            pack_impl=pack_impl,
+            allow_chopped=allow_chopped,
+            build_index_mappings=build_index_mappings,
+            label_dataset=label_dataset,
+        )
+    elif dataset_impl == "pairwise":
+        dataset = PairwiseDataset(
+            name,
+            pos_data_prefix,
+            documents,
+            pos_indexed_dataset,
+            neg_indexed_dataset,
+            num_samples,
+            seq_length,
+            seed,
+            pack_impl=pack_impl,
+            allow_chopped=allow_chopped,
+            build_index_mappings=build_index_mappings,
+            pos_label_dataset=pos_label_dataset,
+            neg_label_dataset=neg_label_dataset,
+            pos_ref_dataset=pos_ref_dataset,
+            neg_ref_dataset=neg_ref_dataset,
+        )
+
     return dataset

@@ -135,7 +206,6 @@ def build_dataset(index, name):
         documents = np.arange(
             start=splits[index], stop=splits[index + 1], step=1, dtype=np.int32
         )
-
         dataset = GPT2Dataset(
             name,
             data_prefix,
@@ -219,21 +289,57 @@ def build_weighted_datasets(
         valid_label_path,
         test_path,
         test_label_path,
+        pos_train_path,
+        neg_train_path,
+        pos_train_label_path,
+        neg_train_label_path,
+        pos_valid_path,
+        neg_valid_path,
+        pos_valid_label_path,
+        neg_valid_label_path,
+        pos_test_path,
+        neg_test_path,
+        pos_test_label_path,
+        neg_test_label_path,
     ) in enumerate(
         zip_longest(
-            neox_args.train_data_paths if neox_args.train_data_paths else [],
+            neox_args.train_data_paths if neox_args.train_data_paths else [],
             neox_args.train_label_data_paths
             if neox_args.train_label_data_paths
            else [],
-
neox_args.valid_data_paths, + neox_args.valid_data_paths if neox_args.valid_data_paths else [], neox_args.valid_label_data_paths if neox_args.valid_label_data_paths else [], - neox_args.test_data_paths, + neox_args.test_data_paths if neox_args.test_data_paths else [], neox_args.test_label_data_paths if neox_args.test_label_data_paths else [], + neox_args.pos_train_data_paths if neox_args.pos_train_data_paths else [], + neox_args.neg_train_data_paths if neox_args.neg_train_data_paths else [], + neox_args.pos_train_label_data_paths + if neox_args.pos_train_label_data_paths + else [], + neox_args.neg_train_label_data_paths + if neox_args.neg_train_label_data_paths + else [], + neox_args.pos_valid_data_paths if neox_args.pos_valid_data_paths else [], + neox_args.neg_valid_data_paths if neox_args.neg_valid_data_paths else [], + neox_args.pos_valid_label_data_paths + if neox_args.pos_valid_label_data_paths + else [], + neox_args.neg_valid_label_data_paths + if neox_args.neg_valid_label_data_paths + else [], + neox_args.pos_test_data_paths if neox_args.pos_test_data_paths else [], + neox_args.neg_test_data_paths if neox_args.neg_test_data_paths else [], + neox_args.pos_test_label_data_paths + if neox_args.pos_test_label_data_paths + else [], + neox_args.neg_test_label_data_paths + if neox_args.neg_test_label_data_paths + else [], ) ): - if train_path: + if train_path or pos_train_path: train_datasets.append( build_the_dataset( data_prefix=train_path, @@ -247,10 +353,16 @@ def build_weighted_datasets( skip_warmup=(not neox_args.mmap_warmup), build_index_mappings=build_index_mappings, label_prefix=train_label_path, + dataset_impl=neox_args.dataset_impl, + pos_data_prefix=pos_train_path, + neg_data_prefix=neg_train_path, + pos_label_prefix=pos_train_label_path, + neg_label_prefix=neg_train_label_path, + precompute_model_name=neox_args.precompute_model_name, ) ) - if valid_path: + if valid_path or pos_valid_path: valid_datasets.append( build_the_dataset( data_prefix=valid_path, @@ -264,10 +376,16 @@ def build_weighted_datasets( skip_warmup=(not neox_args.mmap_warmup), build_index_mappings=build_index_mappings, label_prefix=valid_label_path, + dataset_impl=neox_args.dataset_impl, + pos_data_prefix=pos_valid_path, + neg_data_prefix=neg_valid_path, + pos_label_prefix=pos_valid_label_path, + neg_label_prefix=neg_valid_label_path, + precompute_model_name=neox_args.precompute_model_name, ) ) - if test_path: + if test_path or pos_test_path: test_datasets.append( build_the_dataset( data_prefix=test_path, @@ -281,6 +399,12 @@ def build_weighted_datasets( skip_warmup=(not neox_args.mmap_warmup), build_index_mappings=build_index_mappings, label_prefix=test_label_path, + dataset_impl=neox_args.dataset_impl, + pos_data_prefix=pos_test_path, + neg_data_prefix=neg_test_path, + pos_label_prefix=pos_test_label_path, + neg_label_prefix=neg_test_label_path, + precompute_model_name=neox_args.precompute_model_name, ) ) return train_datasets, valid_datasets, test_datasets @@ -352,7 +476,7 @@ def build_train_valid_test_data_iterators(neox_args): test_iters * neox_args.train_batch_size, ] - if neox_args.train_data_paths: + if (neox_args.train_data_paths) or (neox_args.pos_train_data_paths): # when individual train / valid / test data paths are provided # normalize weight values and get num samples for each dataset train_weights, train_num_samples = get_normalized_weights_and_num_samples( diff --git a/megatron/data/pairwise_dataset.py b/megatron/data/pairwise_dataset.py new file mode 100644 index 000000000..e39b4d626 --- 
/dev/null +++ b/megatron/data/pairwise_dataset.py @@ -0,0 +1,457 @@ +# Copyright (c) 2024, EleutherAI +# This file is based on code by the authors denoted below and has been modified from its original version. +# +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Pairwise style dataset.""" + +import os +import time + +import numpy as np +import torch + +from megatron import mpu, print_rank_0 + + +class PairwiseDataset(torch.utils.data.Dataset): + def __init__( + self, + name, + pos_data_prefix, # Don't need neg since it's assumed you have paired the data already. + documents, + pos_indexed_dataset, + neg_indexed_dataset, + num_samples, + seq_length, + seed, + pack_impl="unpacked", + build_index_mappings=True, + use_shared_fs=True, + pos_label_dataset=None, + pos_ref_dataset=None, + neg_label_dataset=None, + neg_ref_dataset=None, + allow_chopped=True, + ): + + self.name = name + self.pos_indexed_dataset = pos_indexed_dataset + self.pos_label_dataset = pos_label_dataset + self.pos_ref_dataset = pos_ref_dataset + self.neg_indexed_dataset = neg_indexed_dataset + self.neg_label_dataset = neg_label_dataset + self.neg_ref_dataset = neg_ref_dataset + self.pack_impl = pack_impl + self.seq_length = seq_length + # Checks + assert np.min(documents) >= 0 + assert (neg_label_dataset is not None and pos_label_dataset is not None) or ( + neg_label_dataset is None and pos_label_dataset is None + ), "Label datasets must be both None or both not None" + assert np.max(documents) < pos_indexed_dataset.sizes.shape[0] + assert pos_indexed_dataset.sizes.shape[0] == neg_indexed_dataset.sizes.shape[0] + assert ( + pack_impl != "packed" + ), "Packed implementation not supported for pairwise dataset" + + if build_index_mappings: + # Build index mappings. + self.doc_idx, self.sample_idx, self.shuffle_idx = _build_index_mappings( + self.name, + pos_data_prefix, + documents, + self.pos_indexed_dataset.sizes, + self.neg_indexed_dataset.sizes, + self.pos_label_dataset, + self.neg_label_dataset, + num_samples, + seq_length, + seed, + pack_impl, + use_shared_fs=use_shared_fs, + allow_chopped=allow_chopped, + ) + self.shuffle_idx_len = self.shuffle_idx.shape[0] - 1 + self.sample_idx_len = self.sample_idx.shape[0] - 1 + + if self.shuffle_idx_len != self.sample_idx_len - 1: + print( + f"WARNING: shuffle index length ({self.shuffle_idx_len}) is not equal to sample index length ({self.sample_idx_len})" + ) + + def __len__(self): + return min(self.shuffle_idx_len, self.sample_idx_len) + + def __getitem__(self, idx): + try: + # Get the shuffled index. + idx = self.shuffle_idx[idx] + # Start and end documents and offsets. + doc_index_f = self.sample_idx[idx][0] + doc_index_l = self.sample_idx[idx + 1][0] + offset_f = self.sample_idx[idx][1] + offset_l = self.sample_idx[idx + 1][1] + # Labels and texts are supposed to be fully in sync. 
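+            # The index mappings are built once from the positive dataset, and the
+            # same (document, offset) window is applied to every dataset gathered
+            # below, which is why the pos/neg (and label/ref) datasets must be
+            # aligned row for row.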
+ datasets = [self.pos_indexed_dataset, self.neg_indexed_dataset] + + if self.pos_label_dataset is not None: + datasets += [ + self.pos_label_dataset, + self.neg_label_dataset, + ] + if self.pos_ref_dataset is not None: + datasets += [ + self.pos_ref_dataset, + self.neg_ref_dataset, + ] + samples = [] + pos_ref_samples = [] + neg_ref_samples = [] + # If we are within the same document, just extract the chunk. + for n, dataset in enumerate(datasets): + if doc_index_f == doc_index_l: + samples.append( + dataset.get( + self.doc_idx[doc_index_f], + offset=offset_f, + length=offset_l - offset_f + 1, + ) + ) + else: + # Otherwise, get the rest of the initial document. + sample_list = [ + dataset.get(self.doc_idx[doc_index_f], offset=offset_f) + ] + # Loop over all in between documents and add the entire document. + for i in range(doc_index_f + 1, doc_index_l): + sample_list.append(dataset.get(self.doc_idx[i])) + # And finally add the relevant portion of last document. + sample_list.append( + dataset.get(self.doc_idx[doc_index_l], length=offset_l + 1) + ) + samples.append(np.concatenate(sample_list)) + for i in range(len(samples)): + if len(samples[i]) < (self.seq_length + 1): + if ((i == 2) or (i == 3)) and self.pos_label_dataset is not None: + # Labels... So pad with -100 + samples[i] = np.pad( + samples[i], + (0, (self.seq_length + 1) - len(samples[i])), + mode="constant", + constant_values=-100, + ) + else: + # Pad with 0s, can use any number since it's masked. + samples[i] = np.pad( + samples[i], + (0, (self.seq_length + 1) - len(samples[i])), + mode="constant", + constant_values=0, + ) + elif len(samples[i]) > (self.seq_length + 1): + # Check for overflow and truncate. + samples[i] = samples[i][: (self.seq_length + 1)] + ret = {} + ret["pos"] = np.array(samples[0], dtype=np.int64) + ret["neg"] = np.array(samples[1], dtype=np.int64) + if self.pos_label_dataset is not None: + ret["pos_label"] = np.array(samples[2], dtype=np.int64) + ret["neg_label"] = np.array(samples[3], dtype=np.int64) + if self.pos_ref_dataset is not None: + ret["pos_ref"] = np.array(samples[4], dtype=np.float32) + ret["neg_ref"] = np.array(samples[5], dtype=np.float32) + elif self.pos_ref_dataset is not None: + # Don't have labels... + ret["pos_ref"] = np.array(samples[2], dtype=np.float32) + ret["neg_ref"] = np.array(samples[3], dtype=np.float32) + return ret + except IndexError: + new_idx = idx % len(self) + print( + f"WARNING: Got index out of bounds error with index {idx} - taking modulo of index instead ({new_idx})" + ) + return self[new_idx] + + +def _build_index_mappings( + name, + pos_data_prefix, + documents, + pos_sizes, + neg_sizes, + pos_label_dataset, + neg_label_dataset, + num_samples, + seq_length, + seed, + packing_impl, + use_shared_fs=True, + allow_chopped=True, +): + """Build doc-idx, sample-idx, and shuffle-idx. + doc-idx: is an array (ordered) of documents to be used in training. + sample-idx: is the start document index and document offset for each + training sample. + shuffle-idx: maps the sample index into a random index into sample-idx. + """ + # Number of tokens in each epoch and number of required epochs. + tokens_per_epoch = _num_tokens(documents, pos_sizes) + num_epochs = _num_epochs(tokens_per_epoch, seq_length, num_samples) + # rng state + np_rng = np.random.RandomState(seed=seed) + + # Filename of the index mappings. 
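+    # The cache filename encodes every argument that affects the mappings, so
+    # changing num_samples, seq_length, seed, or the packing impl forces a
+    # rebuild instead of silently reusing stale index files.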
+ _filename = pos_data_prefix + _filename += "_{}_indexmap".format(name) + _filename += "_{}ns".format(num_samples) + _filename += "_{}sl".format(seq_length) + _filename += "_{}s".format(seed) + _filename += "_{}pi".format(packing_impl) + doc_idx_filename = _filename + "_doc_idx.npy" + sample_idx_filename = _filename + "_sample_idx.npy" + shuffle_idx_filename = _filename + "_shuffle_idx.npy" + + if not use_shared_fs: + should_process_dataset = int(os.environ["LOCAL_RANK"]) == 0 + else: + should_process_dataset = torch.distributed.get_rank() == 0 + + # Build the indexed mapping if not exist. + if should_process_dataset: + if ( + (not os.path.isfile(doc_idx_filename)) + or (not os.path.isfile(sample_idx_filename)) + or (not os.path.isfile(shuffle_idx_filename)) + ): + print_rank_0( + " > WARNING: could not find index map files, building " + "the indices on rank 0 ..." + ) + # doc-idx. + start_time = time.time() + if packing_impl == "pack_until_overflow": + # Naively pack data until it overflows, then roll it over to a new one instead. + shuffle_idx = np.arange(num_samples) # Shuffle index around epochs + np_rng.shuffle(shuffle_idx) + sample_idx = [] + doc_idx = [] + # Iterate over files until we have enough samples. + temp_shuffle_idx = np.arange(len(documents)) + np_rng.shuffle(temp_shuffle_idx) + running_length = 0 + curr_shuffle_idx = 0 + while len(sample_idx) < num_samples: + # If not allow_chopped, skip this item if it's chopped. + if not allow_chopped: + if ( + pos_sizes[temp_shuffle_idx[curr_shuffle_idx]] + < seq_length + 1 + ): + curr_shuffle_idx += 1 + continue + if ( + neg_sizes[temp_shuffle_idx[curr_shuffle_idx]] + < seq_length + 1 + ): + curr_shuffle_idx += 1 + continue + # Then, check if we need to skip this item... + if pos_label_dataset is not None: + if np.all( + pos_label_dataset.get(temp_shuffle_idx[curr_shuffle_idx])[ + : seq_length + 1 + ] + == -100 + ): + curr_shuffle_idx += 1 + continue + if np.all( + neg_label_dataset.get(temp_shuffle_idx[curr_shuffle_idx])[ + : seq_length + 1 + ] + == -100 + ): + curr_shuffle_idx += 1 + continue + doc_length = max( + pos_sizes[temp_shuffle_idx[curr_shuffle_idx]], + neg_sizes[temp_shuffle_idx[curr_shuffle_idx]], + ) + if running_length == 0: + sample_idx.append(np.array([len(doc_idx), 0])) + doc_idx.append(temp_shuffle_idx[curr_shuffle_idx]) + running_length += doc_length + else: + if running_length + doc_length > (seq_length + 1): + running_length = doc_length + sample_idx.append(np.array([len(doc_idx), 0])) + else: + running_length += doc_length + doc_idx.append(temp_shuffle_idx[curr_shuffle_idx]) + curr_shuffle_idx += 1 + if curr_shuffle_idx == len(documents): + curr_shuffle_idx = 0 + np_rng.shuffle(temp_shuffle_idx) + sample_idx.append(np.array([len(doc_idx), 0])) + np.save(doc_idx_filename, doc_idx, allow_pickle=True) + np.save(sample_idx_filename, sample_idx, allow_pickle=True) + np.save(shuffle_idx_filename, shuffle_idx, allow_pickle=True) + elif packing_impl == "unpacked": + # Unpacked data, one sample per document. + shuffle_idx = np.array([i % len(documents) for i in range(num_samples)]) + np_rng.shuffle(shuffle_idx) + sample_idx = np.zeros((num_samples + 1, 2), dtype=np.int64) + sample_idx[:, 0] = np.array([i for i in range(num_samples + 1)]) + sample_idx[:, 1] = 0 + doc_idx = list() + doc_i = 0 + while len(doc_idx) <= num_samples: + # Check if we need to skip this item... 
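+                    # (e.g. documents that would have to be chopped to fit
+                    # seq_length + 1, or whose labels are entirely masked)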
+                    if not allow_chopped:
+                        # +1 since we shift left/right by 1
+                        if pos_sizes[doc_i] > seq_length + 1:
+                            doc_i = (doc_i + 1) % len(documents)
+                            continue
+                        if neg_sizes[doc_i] > seq_length + 1:
+                            doc_i = (doc_i + 1) % len(documents)
+                            continue
+                    # Skip label sequences that are masked everywhere; in theory
+                    # this is redundant when allow_chopped is False, but it is a
+                    # cheap extra check.
+                    if pos_label_dataset is not None:
+                        if np.all(pos_label_dataset.get(doc_i)[:seq_length] == -100):
+                            doc_i = (doc_i + 1) % len(documents)
+                            continue
+                        if np.all(neg_label_dataset.get(doc_i)[:seq_length] == -100):
+                            doc_i = (doc_i + 1) % len(documents)
+                            continue
+                    doc_idx.append(doc_i)
+                    doc_i = (doc_i + 1) % len(documents)
+                np.save(doc_idx_filename, doc_idx, allow_pickle=True)
+                np.save(sample_idx_filename, sample_idx, allow_pickle=True)
+                np.save(shuffle_idx_filename, shuffle_idx, allow_pickle=True)
+
+    # This should be a barrier but nccl barrier assumes
+    # device_index=rank which is not the case for model
+    # parallel case
+    counts = torch.cuda.LongTensor([1])
+    torch.distributed.all_reduce(counts, group=mpu.get_io_parallel_group())
+    assert counts[0].item() == torch.distributed.get_world_size(
+        group=mpu.get_io_parallel_group()
+    )
+
+    # Load mappings.
+    start_time = time.time()
+    print_rank_0(" > loading doc-idx mapping from {}".format(doc_idx_filename))
+    doc_idx = np.load(doc_idx_filename, allow_pickle=True, mmap_mode="r")
+    print_rank_0(" > loading sample-idx mapping from {}".format(sample_idx_filename))
+    sample_idx = np.load(sample_idx_filename, allow_pickle=True, mmap_mode="r")
+    print_rank_0(" > loading shuffle-idx mapping from {}".format(shuffle_idx_filename))
+    shuffle_idx = np.load(shuffle_idx_filename, allow_pickle=True, mmap_mode="r")
+    print_rank_0(
+        " loaded indexed file in {:3.3f} seconds".format(time.time() - start_time)
+    )
+    print_rank_0(" total number of samples: {}".format(sample_idx.shape[0]))
+    print_rank_0(" total number of epochs: {}".format(num_epochs))
+
+    return doc_idx, sample_idx, shuffle_idx
+
+
+def _num_tokens(documents, sizes):
+    """Total number of tokens in the dataset."""
+    return np.sum(sizes[documents])
+
+
+def _num_epochs(tokens_per_epoch, seq_length, num_samples):
+    """Based on number of samples and sequence length, calculate how many
+    epochs will be needed."""
+    num_epochs = 0
+    total_tokens = 0
+    while True:
+        num_epochs += 1
+        total_tokens += tokens_per_epoch
+        # -1 is because we need to retrieve seq_length + 1 token each time
+        # but the last token will overlap with the first token of the next
+        # sample except for the last sample.
+        if ((total_tokens - 1) // seq_length) >= num_samples:
+            return num_epochs
+
+
+def _build_doc_idx(documents, num_epochs, np_rng):
+    """Build an array with length = number-of-epochs * number-of-documents.
+    Each index is mapped to a corresponding document."""
+    doc_idx = np.mgrid[0:num_epochs, 0 : len(documents)][1]
+    doc_idx[:] = documents
+    doc_idx = doc_idx.reshape(-1)
+    doc_idx = doc_idx.astype(np.int32)
+    np_rng.shuffle(doc_idx)
+    return doc_idx
+
+
+def _build_sample_idx(sizes, doc_idx, seq_length, num_epochs, tokens_per_epoch):
+    """Sample index mapping is a 2D array with sizes
+    [number-of-samples + 1, 2] where [..., 0] contains
+    the index into `doc_idx` and [..., 1] is the
+    starting offset in that document."""
+
+    # Total number of samples. For -1 see comments in `_num_epochs`.
+    num_samples = (num_epochs * tokens_per_epoch - 1) // seq_length
+    sample_idx = np.zeros([num_samples + 1, 2], dtype=np.int64)
+
+    # Index into sample_idx.
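+    # Each row of sample_idx is (index into doc_idx, starting token offset);
+    # a sample spans from row i to row i + 1 and covers seq_length tokens.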
+ sample_index = 0 + # Index into doc_idx. + doc_idx_index = 0 + # Beginning offset for each document. + doc_offset = 0 + # Start with first document and no offset. + sample_idx[sample_index][0] = doc_idx_index + sample_idx[sample_index][1] = doc_offset + sample_index += 1 + while sample_index <= num_samples: + # Start with a fresh sequence. + remaining_seq_length = seq_length + 1 + while remaining_seq_length != 0: + # Get the document length. + doc_id = doc_idx[doc_idx_index] + doc_length = sizes[doc_id] - doc_offset + # And add it to the current sequence. + remaining_seq_length -= doc_length + # If we have more than a full sequence, adjust offset and set + # remaining length to zero so we return from the while loop. + # Note that -1 here is for the same reason we have -1 in + # `_num_epochs` calculations. + if remaining_seq_length <= 0: + doc_offset += remaining_seq_length + doc_length - 1 + remaining_seq_length = 0 + else: + # Otherwise, start from the beginning of the next document. + doc_idx_index += 1 + doc_offset = 0 + # Record the sequence. + sample_idx[sample_index][0] = doc_idx_index + sample_idx[sample_index][1] = doc_offset + sample_index += 1 + + return sample_idx + + +def _build_shuffle_idx(size, np_rng): + """Build the range [0, size) and shuffle.""" + dtype_ = np.uint32 + if size >= (np.iinfo(np.uint32).max - 1): + dtype_ = np.int64 + shuffle_idx = np.arange(start=0, stop=size, step=1, dtype=dtype_) + np_rng.shuffle(shuffle_idx) + return shuffle_idx diff --git a/megatron/neox_arguments/arguments.py b/megatron/neox_arguments/arguments.py index 8fbe045bb..1677bf072 100644 --- a/megatron/neox_arguments/arguments.py +++ b/megatron/neox_arguments/arguments.py @@ -1116,10 +1116,16 @@ def calculate_derived(self): # Adding equal dataset weights if none are provided if self.train_data_paths and (self.train_data_weights is None): self.train_data_weights = [1.0] * len(self.train_data_paths) + elif self.pos_train_data_paths and (self.train_data_weights is None): + self.train_data_weights = [1.0] * len(self.pos_train_data_paths) if self.valid_data_paths and (self.valid_data_weights is None): self.valid_data_weights = [1.0] * len(self.valid_data_paths) + elif self.pos_valid_data_paths and (self.valid_data_weights is None): + self.valid_data_weights = [1.0] * len(self.pos_valid_data_paths) if self.test_data_paths and (self.test_data_weights is None): self.test_data_weights = [1.0] * len(self.test_data_paths) + elif self.pos_test_data_paths and (self.test_data_weights is None): + self.test_data_weights = [1.0] * len(self.pos_test_data_paths) if self.train_label_data_paths: err_str = "Must use `train_label_data_paths` with `train_data_paths`, not `data_path`" diff --git a/megatron/neox_arguments/neox_args.py b/megatron/neox_arguments/neox_args.py index 818c86d31..814622a5b 100644 --- a/megatron/neox_arguments/neox_args.py +++ b/megatron/neox_arguments/neox_args.py @@ -893,6 +893,42 @@ class NeoXArgsTraining(NeoXArgsTemplate): List of paths to validation label datasets (not shifted by 1 yet!). """ + pos_train_data_paths: list = None + neg_train_data_paths: list = None + """ + List of paths to positive and negative training datasets. + """ + + pos_train_label_data_paths: list = None + neg_train_label_data_paths: list = None + """ + List of paths to positive and negative training label datasets (not shifted by 1 yet!). + """ + + pos_valid_data_paths: list = None + neg_valid_data_paths: list = None + """ + List of paths to positive and negative validation datasets. 
+ """ + + pos_valid_label_data_paths: list = None + neg_valid_label_data_paths: list = None + """ + List of paths to positive and negative validation label datasets (not shifted by 1 yet!). + """ + + pos_test_data_paths: list = None + neg_test_data_paths: list = None + """ + List of paths to positive and negative test datasets. + """ + + pos_test_label_data_paths: list = None + neg_test_label_data_paths: list = None + """ + List of paths to positive and negative test label datasets (not shifted by 1 yet!). + """ + train_data_weights: list = None """ List of 'weights' that decide how often to sample from each training dataset when blending datasets. If None, defaults to equal weighting. @@ -949,6 +985,26 @@ class NeoXArgsTraining(NeoXArgsTemplate): warning: pack_until_overflow is very naive and will likely have issues with pretraining scale datasets """ + dataset_impl: Literal["gpt2", "pairwise"] = "gpt2" + """ + Dataset implementation, can be one of "gpt2" or "pairwise" + """ + + train_impl: Literal["normal", "dpo"] = "normal" + """ + Training implementation, can be one of "normal" or "dpo" + """ + + dpo_fp32: bool = True + """ + Whether to cast logits to fp32 for DPO loss calculation. + """ + + dpo_beta: float = 0.1 + """ + Beta value for DPO + """ + allow_chopped: bool = True """ WARNING: if your packing impl is packed, this is ignored. @@ -1245,7 +1301,12 @@ class NeoXArgsTextgen(NeoXArgsTemplate): text_gen_type: str = None """ How to generate text/sample the model. - Options: `unconditional`, `input-file`, `interactive` + Options: `unconditional`, `input-file`, `interactive`, `precompute` + """ + + precompute_model_name: str = None + """ + Model name to use for saving precomputed logprobs """ temperature: float = 0.0 diff --git a/megatron/text_generation_utils.py b/megatron/text_generation_utils.py index 7b7a390ab..02926c2c3 100644 --- a/megatron/text_generation_utils.py +++ b/megatron/text_generation_utils.py @@ -23,12 +23,15 @@ import time from typing import List, Union +import numpy as np import torch import torch.nn.functional as F from megatron import print_rank_0 from megatron import mpu from megatron.utils import get_ltor_masks_and_position_ids, is_mp_rank_0 +from megatron.data.indexed_dataset import make_builder, make_dataset +from megatron.mpu.mappings import gather_from_model_parallel_region def get_batch(neox_args, context_tokens: torch.Tensor): @@ -52,7 +55,9 @@ def get_batch(neox_args, context_tokens: torch.Tensor): return tokens, attention_mask, position_ids -def pad_batch(context_tokens: List[List[int]], pad_id: int, pad_len: int): +def pad_batch( + context_tokens: List[List[int]], pad_id: int, pad_len: int, truncate: bool = False +): """ pads context lengths in context_tokens with pad_id to equal neox_args.seq_length, and returns the padded batch and the new lengths. 
@@ -60,17 +65,21 @@ def pad_batch(context_tokens: List[List[int]], pad_id: int, pad_len: int): context_tokens: list of lists of tokens pad_id: int, integer to use as padding token pad_len: int, context length to be padded; all batch items will be padded to the same length + truncate: bool, if True, truncate context tokens to pad_len if they are longer than pad_len returns: tuple of padded context tokens and a list of unpadded token count """ context_lengths = [] - for tokens in context_tokens: + for i, tokens in enumerate(context_tokens): context_length = len(tokens) if context_length < pad_len: tokens.extend([pad_id] * (pad_len - context_length)) elif context_length > pad_len: - raise ValueError("context_length is bigger than to be padded length") + if not truncate: + raise ValueError("context_length is bigger than to be padded length") + context_tokens[i] = tokens[:pad_len] + context_length = pad_len context_lengths.append(context_length) return context_tokens, context_lengths @@ -807,3 +816,180 @@ def generate_samples_interactive( print_rank_0("Generated Text: " + generated_text) if torch.distributed.is_initialized() and torch.distributed.get_rank() == 0: _ = input("\n") + + +def get_logp(logits, labels, force_fp32=False): + if force_fp32: + logits = logits.float() + logp = logits.log_softmax(dim=-1) + return torch.gather(logp, dim=2, index=labels.unsqueeze(2)).squeeze(2) + + +def precompute_logits(neox_args, model): + """ + Precomputes logprobs from training/testing/validation datasets + + Saves it to the same directory as the dataset with the model name appended to it + + neox_args: NeoXArgs. + model: a Megatron model + + """ + if neox_args.precompute_model_name is None: + mdl_name = str(hash(neox_args.load)) + else: + mdl_name = neox_args.precompute_model_name + print_rank_0("Precomputing logprobs...") + model.eval() + data_paths = list() + if neox_args.train_data_paths is not None: + for path in neox_args.train_data_paths: + data_paths.append(path) + for path in neox_args.test_data_paths: + data_paths.append(path) + for path in neox_args.valid_data_paths: + data_paths.append(path) + elif neox_args.pos_train_data_paths is not None: + # Pairwise data... + for path in neox_args.pos_train_data_paths: + data_paths.append(path) + for path in neox_args.neg_train_data_paths: + data_paths.append(path) + for path in neox_args.pos_valid_data_paths: + data_paths.append(path) + for path in neox_args.neg_valid_data_paths: + data_paths.append(path) + for path in neox_args.pos_test_data_paths: + data_paths.append(path) + for path in neox_args.neg_test_data_paths: + data_paths.append(path) + for path in data_paths: + print_rank_0(f"Precomputing logits for {path}") + # Add hash to path... 
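+        # The suffix must match what build_the_dataset() looks for at training
+        # time: data_prefix + "_" + precompute_model_name.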
+ out_path = path + f"_{mdl_name}" + if os.path.exists(out_path + ".idx"): + continue + dataset = make_dataset(path, neox_args.data_impl, not neox_args.mmap_warmup) + if is_mp_rank_0(): + out_dataset = make_builder(out_path + ".bin", neox_args.data_impl) + out_dataset._dtype = np.float32 + i = 0 + while i < len(dataset): + start = time.time() + model.module.clear_cache() # clear kv cache between batches + if is_mp_rank_0(): + offset = ( + mpu.get_data_parallel_rank() + * neox_args.train_micro_batch_size_per_gpu + ) + context_tokens = [ + [int(x) for x in dataset.get(j % len(dataset)).tolist()] + for j in range( + i + offset, + i + (neox_args.train_micro_batch_size_per_gpu + offset), + ) + ] + # grab microbatch + # pad batch in order to allow conversion to tensor + context_tokens, context_lengths = pad_batch( + copy.deepcopy(context_tokens), + pad_id=0, + pad_len=neox_args.seq_length + 1, + truncate=True, + ) + # print(context_tokens) + label_tokens = [tokens[1:] for tokens in context_tokens] + context_tokens = [tokens[:-1] for tokens in context_tokens] + else: + context_tokens = [ + [0 for _ in range(neox_args.seq_length)] + for _ in range(neox_args.batch_size) + ] + label_tokens = [ + [0 for _ in range(neox_args.seq_length)] + for _ in range(neox_args.batch_size) + ] + context_lengths = [0 for _ in range(neox_args.batch_size)] + i += ( + neox_args.train_micro_batch_size_per_gpu + * mpu.get_data_parallel_world_size() + ) + # print(context_tokens) + # convert to tensor and broadcast + context_tokens = torch.cuda.LongTensor(context_tokens) + label_tokens = torch.cuda.LongTensor(label_tokens) + # Make sure context tokens + start tokens are the same across all ranks + token_generation_start_index = torch.cuda.LongTensor(context_lengths) + torch.distributed.broadcast( + context_tokens, + mpu.get_model_parallel_src_rank(), + group=mpu.get_model_parallel_group(), + ) + torch.distributed.broadcast( + token_generation_start_index, + mpu.get_model_parallel_src_rank(), + group=mpu.get_model_parallel_group(), + ) + torch.distributed.broadcast( + label_tokens, + mpu.get_model_parallel_src_rank(), + group=mpu.get_model_parallel_group(), + ) + # context_tokens = context_tokens[:, :chop_len].contiguous() + # label_tokens = label_tokens[:, :chop_len].contiguous() + with torch.no_grad(): + # get attention mask / position ids + context_tokens, attention_mask, position_ids = get_batch( + neox_args, context_tokens + ) + model_inputs = ( + context_tokens, + position_ids, + attention_mask, + ) + maybe_tuple = forward_model( + model, model_inputs, neox_args.is_pipe_parallel + ) + if isinstance(maybe_tuple, tuple): + logits, _ = maybe_tuple + else: + logits = maybe_tuple + if logits is not None: # if pipe parallel, not all ranks return logits + logits = gather_from_model_parallel_region(logits) + logp = get_logp(logits, label_tokens, True).squeeze() + if neox_args.is_pipe_parallel: + # broadcast generated tokens to pipe parallel group + src_rank = model.grid.stage_to_global(model.num_stages - 1) + logp = ( + logp + if logits is not None + else torch.zeros( + neox_args.batch_size, dtype=torch.float32 + ).cuda() + ) + torch.distributed.broadcast( + tensor=logp, + src=src_rank, + group=mpu.get_pipe_parallel_group(), + ) + logp = logp.squeeze() + logp_list = [ + torch.zeros_like(logp) + for _ in range(mpu.get_data_parallel_world_size()) + ] + torch.distributed.all_gather( + logp_list, logp, group=mpu.get_data_parallel_group() + ) + logp = torch.cat(logp_list, dim=0).cpu().numpy() + if (mpu.get_model_parallel_rank() == 
0) and ( + mpu.get_data_parallel_rank() == 0 + ): + for j in range(logp.shape[0]): + out_dataset.add_item(logp[j]) + out_dataset.end_document() + print_rank_0(f"Processed {i} / {len(dataset)} in {time.time() - start}") + if is_mp_rank_0(): + out_dataset.finalize( + out_path + ".idx", + ) + torch.distributed.barrier() diff --git a/megatron/training.py b/megatron/training.py index fc3d9e129..d9932483a 100644 --- a/megatron/training.py +++ b/megatron/training.py @@ -21,12 +21,14 @@ """Pretrain utilities.""" from datetime import datetime from functools import partial +from collections import defaultdict import math import sys from contextlib import nullcontext import torch +import torch.nn.functional as F import deepspeed from deepspeed.runtime.data_pipeline.curriculum_scheduler import CurriculumScheduler import numpy as np @@ -45,6 +47,7 @@ get_params_for_weight_decay_optimization, mark_norms_for_sequence_parallel_grad_sync, ) +from megatron.mpu.mappings import gather_from_model_parallel_region from megatron.checkpointing import load_checkpoint, save_checkpoint from megatron.data.data_utils import build_train_valid_test_data_iterators from megatron.initialize import initialize_megatron @@ -137,7 +140,7 @@ def gen(): old_hidden_size = neox_args.hidden_size neox_args.hidden_size = hidden_size - model, optimizer, _ = setup_model_and_optimizer( + model, optimizer, _, _ = setup_model_and_optimizer( neox_args=neox_args, use_cache=False ) @@ -193,7 +196,7 @@ def pretrain(neox_args): # Model, optimizer, and learning rate. timers("model and optimizer").start() - model, optimizer, lr_scheduler = setup_model_and_optimizer( + model, optimizer, lr_scheduler, reference_model = setup_model_and_optimizer( neox_args=neox_args, use_cache=False, iteration=neox_args.iteration ) timers("model and optimizer").stop() @@ -231,6 +234,7 @@ def pretrain(neox_args): neox_args=neox_args, timers=timers, model=model, + reference_model=reference_model, optimizer=optimizer, lr_scheduler=lr_scheduler, train_data_iterator=train_data_iterator, @@ -282,12 +286,12 @@ def _get_batch(neox_args, tokenizer, keys, data, datatype): label_key = keys[1] if len(keys) > 1 else None # Unpack. tokens_ = data_b[token_key].long() - if "label" in data_b: + if label_key in data_b: label_mask = (data_b[label_key].long() >= 0)[:, 1:].contiguous() labels = torch.where( data_b[label_key].long() >= 0, data_b[label_key].long(), - torch.zeros_like(data_b["label"].long()), + torch.zeros_like(data_b[label_key].long()), )[:, 1:].contiguous() else: label_mask = (tokens_.long() >= 0)[:, 1:].contiguous() @@ -311,7 +315,14 @@ def get_batch(neox_args, data_iterator): """Generate a batch""" # Items and their type. - keys = ["text", "label"] if neox_args.train_label_data_paths else ["text"] + if neox_args.train_impl == "normal": + keys = ["text", "label"] if neox_args.train_label_data_paths else ["text"] + elif neox_args.train_impl == "dpo": + keys = ( + [["pos", "pos_label"], ["neg", "neg_label"]] + if neox_args.pos_train_label_data_paths + else [["pos"], ["neg"]] + ) datatype = torch.int64 # Broadcast data. 
@@ -319,13 +330,43 @@ def get_batch(neox_args, data_iterator): data = next(data_iterator) else: data = None - return _get_batch( - neox_args=neox_args, - tokenizer=neox_args.tokenizer, - keys=keys, - data=data, - datatype=datatype, - ) + if neox_args.train_impl == "normal": + return _get_batch( + neox_args=neox_args, + tokenizer=neox_args.tokenizer, + keys=keys, + data=data, + datatype=datatype, + ) + elif neox_args.train_impl == "dpo": + pos_tup = _get_batch( + neox_args=neox_args, + tokenizer=neox_args.tokenizer, + keys=keys[0], + data=data, + datatype=datatype, + ) + neg_tup = _get_batch( + neox_args=neox_args, + tokenizer=neox_args.tokenizer, + keys=keys[1], + data=data, + datatype=datatype, + ) + if neox_args.precompute_model_name: + ref_data = mpu.broadcast_data(["pos_ref", "neg_ref"], data, torch.float) + else: + ref_data = {"pos_ref": None} + return [ + torch.cat((pos_item, neg_item), dim=0) + for pos_item, neg_item in zip(pos_tup, neg_tup) + ] + [ + torch.cat((ref_data["pos_ref"], ref_data["neg_ref"]), dim=0)[ + :, :-1 + ].contiguous() + if ref_data["pos_ref"] is not None + else None + ] def get_batch_pipe(data, neox_args, curr_scheduler=None): @@ -419,8 +460,23 @@ def mb_moe_loss_func(args, loss_mask, output_tensor=None): return averaged_lbl, loss_dict +def get_pos_neg_logp(logits, labels, force_fp32=False): + if force_fp32: + logits = logits.float() + logp = logits.log_softmax(dim=-1) + per_token_logp = torch.gather(logp, dim=2, index=labels.unsqueeze(2)).squeeze(2) + # Split to pos/neg... + return torch.chunk(per_token_logp, 2, 0) + + def forward_step( - data_iterator, model, neox_args, timers, return_logits=False, is_train=False + data_iterator, + model, + neox_args, + timers, + return_logits=False, + is_train=False, + reference_model=None, ): """Forward step.""" if neox_args.is_pipe_parallel: @@ -431,9 +487,14 @@ def forward_step( torch.cuda.nvtx.range_push(f"Get batch") if timers is not None: timers("batch generator").start() - tokens, labels, loss_mask, attention_mask, position_ids = get_batch( - neox_args=neox_args, data_iterator=data_iterator - ) + if neox_args.train_impl == "normal": + tokens, labels, loss_mask, attention_mask, position_ids = get_batch( + neox_args=neox_args, data_iterator=data_iterator + ) + if neox_args.train_impl == "dpo": + tokens, labels, loss_mask, attention_mask, position_ids, ref_logp = get_batch( + neox_args=neox_args, data_iterator=data_iterator + ) if timers is not None: timers("batch generator").stop() @@ -442,38 +503,100 @@ def forward_step( if neox_args.memory_profiling: torch.cuda.nvtx.range_push(f"Forward pass") - # Sequential returns moe_losses, but this is not yet supported by pipe parallel - maybe_tuple = model((tokens, position_ids, attention_mask), neox_args=neox_args) - if type(maybe_tuple) is tuple: - outputs, moe_losses = maybe_tuple - else: - outputs = maybe_tuple - moe_losses = [] - if ( - is_train - and neox_args.curriculum_learning - and neox_args.curriculum_seqlen < neox_args.seq_length - ): - loss_mask = loss_mask[:, : neox_args.curriculum_seqlen].contiguous() - labels = labels[:, : neox_args.curriculum_seqlen].contiguous() - main_loss = cross_entropy( - outputs, (labels, loss_mask), _fp16=neox_args.fp16_lm_cross_entropy - ) - if neox_args.moe_num_experts > 1: - if neox_args.moe_type == "deepspeed": - moe_loss = neox_args.moe_loss_coeff * sum(m.item() for m in moe_losses) - elif neox_args.moe_type == "megablocks": - moe_loss = mb_moe_loss_func(neox_args, loss_mask, outputs)[0] + metrics = {} + if neox_args.train_impl == 
"normal": + # Sequential returns moe_losses, but this is not yet supported by pipe parallel + maybe_tuple = model((tokens, position_ids, attention_mask), neox_args=neox_args) + if type(maybe_tuple) is tuple: + outputs, moe_losses = maybe_tuple else: - raise ValueError(f"Unsupported moe_type: {neox_args.moe_type}") - else: - moe_loss = 0.0 - loss = main_loss + moe_loss + outputs = maybe_tuple + moe_losses = [] + if ( + is_train + and neox_args.curriculum_learning + and neox_args.curriculum_seqlen < neox_args.seq_length + ): + loss_mask = loss_mask[:, : neox_args.curriculum_seqlen].contiguous() + labels = labels[:, : neox_args.curriculum_seqlen].contiguous() + main_loss = cross_entropy( + outputs, (labels, loss_mask), _fp16=neox_args.fp16_lm_cross_entropy + ) + if neox_args.moe_num_experts > 1: + if neox_args.moe_type == "deepspeed": + moe_loss = neox_args.moe_loss_coeff * sum(m.item() for m in moe_losses) + elif neox_args.moe_type == "megablocks": + moe_loss = mb_moe_loss_func(neox_args, loss_mask, outputs)[0] + else: + raise ValueError(f"Unsupported moe_type: {neox_args.moe_type}") + else: + moe_loss = 0.0 + loss = main_loss + moe_loss + elif neox_args.train_impl == "dpo": + # Based on https://github.com/eric-mitchell/direct-preference-optimization/blob/main/trainers.py#L90 + with torch.no_grad(): + # So we can gather token logps... + token_logp_labels = labels.clone() + token_logp_labels[token_logp_labels == -100] = 0 + pos_loss_mask, neg_loss_mask = torch.chunk(loss_mask, 2, 0) + if ref_logp is None: + ref_maybe_tuple = reference_model( + (tokens, position_ids, attention_mask), neox_args=neox_args + ) + if type(ref_maybe_tuple) is tuple: + # We should ignore MoE losses yeah? + ref_outputs, _ = ref_maybe_tuple + else: + ref_outputs = ref_maybe_tuple + # gather across tensor parallel group + ref_outputs = gather_from_model_parallel_region(ref_outputs) + ref_pos, ref_neg = get_pos_neg_logp( + ref_outputs, token_logp_labels, neox_args.dpo_fp32 + ) + else: + ref_pos, ref_neg = torch.chunk(ref_logp, 2, 0) + ref_pos = (ref_pos * pos_loss_mask).sum(-1) + ref_neg = (ref_neg * neg_loss_mask).sum(-1) + chosen_maybe_tuple = model( + (tokens, position_ids, attention_mask), neox_args=neox_args + ) + if type(chosen_maybe_tuple) is tuple: + # We should ignore MoE losses yeah? + chosen_outputs, _ = chosen_maybe_tuple + else: + chosen_outputs = chosen_maybe_tuple + chosen_outputs = gather_from_model_parallel_region(chosen_outputs) + chosen_pos, chosen_neg = get_pos_neg_logp( + chosen_outputs, token_logp_labels, neox_args.dpo_fp32 + ) + chosen_pos = (chosen_pos * pos_loss_mask).sum(-1) + chosen_neg = (chosen_neg * neg_loss_mask).sum(-1) + with torch.no_grad(): + # Collect metrics... 
+            metrics["ref_neg"] = ref_neg.clone().detach().mean()
+            metrics["ref_pos"] = ref_pos.clone().detach().mean()
+            metrics["chosen_neg"] = chosen_neg.clone().detach().mean()
+            metrics["chosen_pos"] = chosen_pos.clone().detach().mean()
+            chosen_rewards = neox_args.dpo_beta * (
+                chosen_pos.clone().detach() - ref_pos.clone().detach()
+            )
+            rejected_rewards = neox_args.dpo_beta * (
+                chosen_neg.clone().detach() - ref_neg.clone().detach()
+            )
+            reward_acc = (chosen_rewards > rejected_rewards).float()
+            metrics["reward_acc"] = reward_acc.mean()
+            metrics["chosen_rewards"] = chosen_rewards.mean()
+            metrics["rejected_rewards"] = rejected_rewards.mean()
+            metrics["margins"] = (chosen_rewards - rejected_rewards).mean()
+        pi_logratios = chosen_pos - chosen_neg
+        ref_logratios = ref_pos - ref_neg
+        logits = pi_logratios - ref_logratios
+        loss = -F.logsigmoid(neox_args.dpo_beta * logits).mean()
     if neox_args.memory_profiling:
         torch.cuda.nvtx.range_pop()
     if return_logits:
-        return loss, outputs
-    return loss
+        return loss, outputs, metrics
+    return loss, metrics


 def get_model(neox_args, use_cache=False):
@@ -548,9 +671,14 @@
     raise ValueError("Must be using deepspeed to run neox")


-def get_optimizer(model, neox_args):
+def get_optimizer(model, neox_args, dummy=False):
     """Set up the optimizer."""
-    if neox_args.no_load_optim:
+    if neox_args.no_load_optim and neox_args.deepspeed:
+        # DeepSpeed still expects an optimizer, so build a minimal dummy one.
+        dummy = True
+        neox_args.optimizer = {"params": {"lr": 0.0}}
+        neox_args.optimizer_type = "adam"
+    elif neox_args.no_load_optim:
         return None, None

     if neox_args.optimizer is None:
@@ -584,8 +712,13 @@
     _param_groups = []
     for param_group in param_groups:
         trainable_params = [p for p in param_group["params"] if p.requires_grad]
+        if dummy:
+            trainable_params = [trainable_params[0]]  # just take the first one
         param_group["params"] = trainable_params
         _param_groups.append(param_group)
+        if dummy:
+            # Only need one param group.
+            break
     param_groups = _param_groups

     # If we're using mup, then the optimizer must be adam or sgd
@@ -699,7 +832,7 @@
 def get_learning_rate_scheduler(optimizer, neox_args):
     """Build the learning rate scheduler."""
-    if neox_args.no_load_optim:
+    if neox_args.no_load_optim and not neox_args.deepspeed:
         # TODO: this should be configured as a separate arg
         return None
     if neox_args.deepspeed and neox_args.optimizer_type.lower() == "onebitadam":
@@ -744,19 +877,30 @@ def setup_model_and_optimizer(neox_args, use_cache=False, iteration=None):
     )

     """Setup model and optimizer."""
+    needs_reference_model = (neox_args.train_impl == "dpo") and (
+        neox_args.precompute_model_name is None
+    )
     model = get_model(neox_args=neox_args, use_cache=use_cache)
+    if needs_reference_model:
+        reference_model = get_model(neox_args=neox_args, use_cache=use_cache)
+    else:
+        reference_model = None
     optimizer, param_groups = get_optimizer(model=model, neox_args=neox_args)
     lr_scheduler = get_learning_rate_scheduler(optimizer=optimizer, neox_args=neox_args)
-
+    if neox_args.deepspeed and needs_reference_model:
+        # DeepSpeed also needs an optimizer & lr_scheduler for the frozen
+        # reference model, so build a minimal dummy pair to keep it happy.
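+        # get_optimizer(dummy=True) returns an lr=0.0 Adam over a single
+        # parameter; it exists only so that deepspeed.initialize() accepts
+        # the reference model.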
+ ref_optimizer, ref_param_groups = get_optimizer( + model=reference_model, neox_args=neox_args, dummy=True + ) + ref_lr_scheduler = get_learning_rate_scheduler( + optimizer=ref_optimizer, neox_args=neox_args + ) + else: + ref_optimizer, ref_param_groups, ref_lr_scheduler = None, None, None if neox_args.deepspeed: print_rank_0("DeepSpeed is enabled.") - if neox_args.no_load_optim: - assert optimizer is None - _model_params = None - _lr_scheduler = None - else: - _model_params = param_groups if optimizer is None else None - _lr_scheduler = lr_scheduler + _model_params = param_groups if optimizer is None else None + _lr_scheduler = lr_scheduler model, optimizer, _, lr_scheduler = deepspeed.initialize( model=model, @@ -769,6 +913,16 @@ def setup_model_and_optimizer(neox_args, use_cache=False, iteration=None): # config_params=neox_args.deepspeed_config, mpu=mpu if not neox_args.is_pipe_parallel else None, ) + if needs_reference_model: + reference_model, _, _, _ = deepspeed.initialize( + model=reference_model, + optimizer=ref_optimizer, + args=neox_args, + lr_scheduler=ref_lr_scheduler, + dist_init_required=False, + model_parameters=ref_param_groups, + mpu=mpu if not neox_args.is_pipe_parallel else None, + ) mark_norms_for_sequence_parallel_grad_sync(model, neox_args) if neox_args.moe_num_experts > 1 and neox_args.moe_type == "megablocks": # We need to additionally set this flag to ensure DS parallelism properly handles this foreign MoE. @@ -805,6 +959,14 @@ def setup_model_and_optimizer(neox_args, use_cache=False, iteration=None): lr_scheduler=lr_scheduler, iteration=iteration, ) + if needs_reference_model: + _ = load_checkpoint( + neox_args=neox_args, + model=reference_model, + optimizer=ref_optimizer, + lr_scheduler=ref_lr_scheduler, + iteration=iteration, + ) print_rank_0( f"Loading checkpoint and starting from iteration {neox_args.iteration}" ) @@ -816,7 +978,7 @@ def setup_model_and_optimizer(neox_args, use_cache=False, iteration=None): if lr_scheduler is not None: lr_scheduler.optimizer = model.optimizer - return model, optimizer, lr_scheduler + return model, optimizer, lr_scheduler, reference_model def backward_step(neox_args, timers, optimizer, model, loss): @@ -838,7 +1000,15 @@ def backward_step(neox_args, timers, optimizer, model, loss): raise ValueError("Must be using deepspeed to run neox") -def train_step(neox_args, timers, data_iterator, model, optimizer, lr_scheduler): +def train_step( + neox_args, + timers, + data_iterator, + model, + optimizer, + lr_scheduler, + reference_model=None, +): """Single training step.""" # Pipeline parallelism schedules forward/backward/step @@ -846,6 +1016,7 @@ def train_step(neox_args, timers, data_iterator, model, optimizer, lr_scheduler) reduced_loss = train_step_pipe( neox_args=neox_args, timers=timers, model=model, data_iterator=data_iterator ) + reduce_metrics = reduced_loss if ( neox_args.memory_profiling and neox_args.iteration >= neox_args.profile_step_start @@ -855,18 +1026,22 @@ def train_step(neox_args, timers, data_iterator, model, optimizer, lr_scheduler) save_snapshot(neox_args) else: losses = [] + metric_dicts = defaultdict(list) for _ in range(neox_args.gradient_accumulation_steps): # Forward model for one step. 
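+            # In DPO mode, forward_step also runs the frozen reference model
+            # whenever reference logprobs were not precomputed.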
timers("forward").start() - loss = forward_step( + loss, metric_dict = forward_step( neox_args=neox_args, timers=timers, data_iterator=data_iterator, model=model, is_train=True, + reference_model=reference_model, ) timers("forward").stop() losses.append(loss) + for key in metric_dict.keys(): + metric_dicts[key].append(metric_dict[key]) # Calculate gradients, reduce across processes, and clip. if ( neox_args.profile @@ -916,17 +1091,19 @@ def train_step(neox_args, timers, data_iterator, model, optimizer, lr_scheduler) and torch.distributed.get_rank() == 0 ): save_snapshot(neox_args) - reduced_loss = { - "lm_loss": reduce_losses(losses).mean() - } # reduces losses across machines for logging + # reduces metrics across machines for logging + reduce_metrics = { + key: reduce_losses(metric_dicts[key]).mean() for key in metric_dicts.keys() + } + reduce_metrics["lm_loss"] = reduce_losses(losses).mean() if neox_args.precision == "fp16" and model.optimizer.overflow: skipped_iter = 1 else: skipped_iter = 0 - collect_loss_for_unit_test(reduced_loss["lm_loss"]) - return reduced_loss, skipped_iter + collect_loss_for_unit_test(reduce_metrics["lm_loss"]) + return reduce_metrics, skipped_iter def train_step_pipe(neox_args, timers, model, data_iterator): @@ -952,6 +1129,7 @@ def train( neox_args, timers, model, + reference_model, optimizer, lr_scheduler, train_data_iterator, @@ -1007,6 +1185,7 @@ def train( model=model, optimizer=optimizer, lr_scheduler=lr_scheduler, + reference_model=reference_model, ) if neox_args.profile and iteration == neox_args.profile_step_stop: torch.cuda.cudart().cudaProfilerStop() @@ -1097,6 +1276,7 @@ def evaluate( # Turn on evaluation mode which disables dropout. model.eval() losses = [] + metric_dicts = defaultdict(list) if neox_args.char_level_ppl: data_iterator = CharCounter(data_iterator, neox_args.tokenizer) @@ -1118,14 +1298,15 @@ def evaluate( else neox_args.gradient_accumulation_steps ): # Forward evaluation - loss = forward_step_fn( + loss, metric_dict = forward_step_fn( model=model, data_iterator=data_iterator, neox_args=neox_args, timers=timers, ) losses.append(loss) - + for key in metric_dict.keys(): + metric_dicts[key].append(metric_dict[key]) # When contiguous memory optimizations are enabled, the buffers # allocated by the optimizations are deallocated during backward pass # in the absence of backward pass the buffers should be reset after each @@ -1135,6 +1316,8 @@ def evaluate( # reduces losses across processes for logging & run eval harness tasks eval_results = {"lm_loss": reduce_losses(losses).mean().item()} + for key in metric_dicts.keys(): + eval_results[key] = reduce_losses(metric_dicts[key]).mean().item() eval_results["lm_loss_ppl"] = math.exp(eval_results["lm_loss"]) if neox_args.char_level_ppl: diff --git a/megatron/utils.py b/megatron/utils.py index 26b4439bd..a64a8ba6c 100644 --- a/megatron/utils.py +++ b/megatron/utils.py @@ -449,7 +449,7 @@ def setup_for_inference_or_eval(use_cache=True, overwrite_values=None, input_arg initialize_megatron(neox_args) # set up model and load checkpoint. 
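+    # setup_model_and_optimizer() now also returns the DPO reference model,
+    # which is not needed for inference or eval.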
-    model, _, _ = setup_model_and_optimizer(
+    model, _, _, _ = setup_model_and_optimizer(
         neox_args=neox_args,
         use_cache=use_cache,
         iteration=neox_args.iteration,
diff --git a/tools/datasets/preprocess_data_with_chat_template.py b/tools/datasets/preprocess_data_with_chat_template.py
index 55623b303..4e101ea5a 100644
--- a/tools/datasets/preprocess_data_with_chat_template.py
+++ b/tools/datasets/preprocess_data_with_chat_template.py
@@ -105,6 +105,7 @@ def build_chat(
         chat_tokens = tokenizer.apply_chat_template(
             chat[: i + 1], add_generation_prompt=add_gen
         )[len(tokens) :]
+        # apply_chat_template re-tokenizes the whole conversation; the slice
+        # above drops the already-seen prefix so only this turn's tokens are appended.
         tokens.extend(chat_tokens)
         if only_last_turn and (i != len(chat) - 1):