From 9eebcb88e86fb7eb0988bc9870f5c01fb4aab019 Mon Sep 17 00:00:00 2001
From: lintangsutawika
Date: Mon, 24 Jun 2024 13:54:50 +0000
Subject: [PATCH] adjust to accept different dataset type

---
 megatron/data/data_utils.py | 32 ++++++++++++++++++++++++++++++--
 1 file changed, 30 insertions(+), 2 deletions(-)

diff --git a/megatron/data/data_utils.py b/megatron/data/data_utils.py
index bc5754cdb..681357605 100644
--- a/megatron/data/data_utils.py
+++ b/megatron/data/data_utils.py
@@ -23,6 +23,7 @@
 from megatron.data.indexed_dataset import make_dataset as make_indexed_dataset
 from megatron.data.blendable_dataset import BlendableDataset
 from megatron.data.gpt2_dataset import GPT2Dataset
+from megatron.data.ptok_dataset import PTokDataset
 from megatron.data.samplers import DistributedBatchSampler
 
 
@@ -51,7 +52,18 @@
     )
 
 
+def get_dataset_obj(dataset_type):
+    print_rank_0("> building {} dataset ...".format(dataset_type))
+    # Select the dataset class; the caller instantiates it.
+    if dataset_type.lower() == "default":
+        return GPT2Dataset
+    elif dataset_type.lower() == "pause":
+        return PTokDataset
+    raise ValueError("unknown dataset_type: {}".format(dataset_type))
+
+
 def build_the_dataset(
+    dataset_type,
     data_prefix,
     name,
     data_impl,
@@ -61,6 +73,7 @@
     skip_warmup,
     build_index_mappings=True,
     label_prefix=None,
+    dataset_configs=None,
 ):
     """Build train/valid/test datasets."""
 
@@ -75,7 +88,9 @@
     print_rank_0("    no. of documents:{}".format(total_num_of_documents))
     dataset = None
     documents = np.arange(start=0, stop=total_num_of_documents, step=1, dtype=np.int32)
-    dataset = GPT2Dataset(
+
+    dataset_obj = get_dataset_obj(dataset_type)
+    dataset = dataset_obj(
         name,
         data_prefix,
         documents,
@@ -85,11 +100,13 @@
         seed,
         build_index_mappings=build_index_mappings,
         label_dataset=label_dataset,
+        **(dataset_configs if dataset_configs is not None else {}),
     )
     return dataset
 
 
 def build_train_valid_test_datasets(
+    dataset_type,
     data_prefix,
     use_shared_fs,
     data_impl,
@@ -98,6 +115,7 @@
     seq_length,
     seed,
     skip_warmup,
+    dataset_configs=None,
 ):
     """Build train, valid, and test datasets."""
 
@@ -130,7 +148,8 @@ def build_dataset(index, name):
             documents = np.arange(
                 start=splits[index], stop=splits[index + 1], step=1, dtype=np.int32
             )
-            dataset = GPT2Dataset(
+            dataset_obj = get_dataset_obj(dataset_type)
+            dataset = dataset_obj(
                 name,
                 data_prefix,
                 documents,
@@ -139,6 +158,7 @@ def build_dataset(index, name):
                 seq_length,
                 seed,
                 use_shared_fs=use_shared_fs,
+                **(dataset_configs if dataset_configs is not None else {}),
             )
 
     return dataset
@@ -215,6 +235,7 @@ def build_weighted_datasets(
         if train_path:
             train_datasets.append(
                 build_the_dataset(
+                    dataset_type=neox_args.dataset_type,
                     data_prefix=train_path,
                     name=f"train_{i}",
                     data_impl=neox_args.data_impl,
@@ -224,12 +245,14 @@
                     skip_warmup=(not neox_args.mmap_warmup),
                     build_index_mappings=build_index_mappings,
                     label_prefix=label_path,
+                    dataset_configs=neox_args.dataset_cfg,
                 )
             )
 
         if valid_path:
             valid_datasets.append(
                 build_the_dataset(
+                    dataset_type=neox_args.dataset_type,
                     data_prefix=valid_path,
                     name=f"valid_{i}",
                     data_impl=neox_args.data_impl,
@@ -238,12 +261,14 @@
                     seed=neox_args.seed,
                     skip_warmup=(not neox_args.mmap_warmup),
                     build_index_mappings=build_index_mappings,
+                    dataset_configs=neox_args.dataset_cfg,
                 )
             )
 
         if test_path:
             test_datasets.append(
                 build_the_dataset(
+                    dataset_type=neox_args.dataset_type,
                     data_prefix=test_path,
name=f"test_{i}", data_impl=neox_args.data_impl, @@ -252,6 +277,7 @@ def build_weighted_datasets( seed=neox_args.seed, skip_warmup=(not neox_args.mmap_warmup), build_index_mappings=build_index_mappings, + dataset_configs=neox_args.dataset_cfg, ) ) return train_datasets, valid_datasets, test_datasets @@ -406,6 +432,7 @@ def build_train_valid_test_data_iterators(neox_args): # when just data_path is provided # split dataset into train, valid and test from data_path train_ds, valid_ds, test_ds = build_train_valid_test_datasets( + dataset_type=neox_args.dataset_type, data_prefix=neox_args.data_path, use_shared_fs=neox_args.use_shared_fs, data_impl=neox_args.data_impl, @@ -414,6 +441,7 @@ def build_train_valid_test_data_iterators(neox_args): seq_length=neox_args.seq_length, seed=neox_args.seed, skip_warmup=(not neox_args.mmap_warmup), + dataset_configs=neox_args.dataset_cfg, ) # Build dataloders.