Commit 9eebcb8

adjust to accept different dataset type

lintangsutawika committed Jun 24, 2024
1 parent 7324c4b commit 9eebcb8
Showing 1 changed file with 30 additions and 2 deletions.

megatron/data/data_utils.py: 30 additions & 2 deletions
@@ -23,6 +23,7 @@
 from megatron.data.indexed_dataset import make_dataset as make_indexed_dataset
 from megatron.data.blendable_dataset import BlendableDataset
 from megatron.data.gpt2_dataset import GPT2Dataset
+from megatron.data.ptok_dataset import PTokDataset
 from megatron.data.samplers import DistributedBatchSampler
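The new import puts PTokDataset alongside GPT2Dataset so the two can be selected interchangeably below. The call sites later in this diff imply that any class returned by get_dataset_obj must accept the same constructor arguments as GPT2Dataset. A rough sketch of that implied contract follows; PTokDataset itself is not part of this commit, and the middle positional parameters are filled in from the upstream GPT2Dataset signature, so treat them as assumptions:

```python
# Hypothetical sketch of the constructor contract implied by the call sites in
# this diff. PTokDataset is not shown in the commit; the parameters marked
# "(assumed)" are taken from upstream GPT2Dataset and may differ in this fork.
class SwappableDataset:
    def __init__(
        self,
        name,                      # split name, e.g. "train_0"
        data_prefix,               # path prefix of the indexed data files
        documents,                 # np.int32 array of document indices
        indexed_dataset,           # memory-mapped token storage (assumed)
        num_samples,               # number of samples to draw (assumed)
        seq_length,
        seed,
        build_index_mappings=True,
        use_shared_fs=True,
        label_dataset=None,
        **extra_configs,           # whatever dataset_configs carries
    ):
        self.name = name
        self.seq_length = seq_length
```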
@@ -51,7 +52,18 @@ def make_data_loader(dataset, neox_args):
     )


+def get_dataset_obj(dataset_type):
+    print_rank_0("> building {} dataset ...".format(dataset_type))
+    # Select and instantiate the dataset.
+    if dataset_type.lower() == "default".lower():
+        dataset_obj = GPT2Dataset
+    elif dataset_type.lower() == "pause".lower():
+        dataset_obj = PTokDataset
+    return dataset_obj
+
+
 def build_the_dataset(
+    dataset_type,
     data_prefix,
     name,
     data_impl,

@@ -61,6 +73,7 @@ def build_the_dataset(
     skip_warmup,
     build_index_mappings=True,
     label_prefix=None,
+    dataset_configs=None,
 ):
     """Build train/valid/test datasets."""
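A note on get_dataset_obj as committed: `"default".lower()` is a no-op on an already-lowercase literal, and when dataset_type matches neither branch nothing is ever bound, so the trailing `return dataset_obj` raises an unhelpful UnboundLocalError. A defensive variant would dispatch through a mapping and fail with a readable message. This is a sketch, not part of the commit; it assumes only the two dataset classes and the print_rank_0 helper already used in this module:

```python
# Sketch of a guarded dispatch; not part of the commit. GPT2Dataset,
# PTokDataset, and print_rank_0 come from the surrounding module.
DATASET_CLASSES = {
    "default": GPT2Dataset,
    "pause": PTokDataset,
}


def get_dataset_obj_guarded(dataset_type):
    print_rank_0("> building {} dataset ...".format(dataset_type))
    try:
        return DATASET_CLASSES[dataset_type.lower()]
    except KeyError:
        raise ValueError(
            "unknown dataset_type {!r}; expected one of {}".format(
                dataset_type, sorted(DATASET_CLASSES)
            )
        )
```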
@@ -75,7 +88,9 @@ def build_the_dataset(
     print_rank_0(" no. of documents:{}".format(total_num_of_documents))
     dataset = None
     documents = np.arange(start=0, stop=total_num_of_documents, step=1, dtype=np.int32)
-    dataset = GPT2Dataset(
+
+    dataset_obj = get_dataset_obj(dataset_type)
+    dataset = dataset_obj(
         name,
         data_prefix,
         documents,

@@ -85,11 +100,13 @@ def build_the_dataset(
         seed,
         build_index_mappings=build_index_mappings,
         label_dataset=label_dataset,
+        **dataset_configs if dataset_configs is not None else {},
     )
     return dataset


 def build_train_valid_test_datasets(
+    dataset_type,
     data_prefix,
     use_shared_fs,
     data_impl,

@@ -98,6 +115,7 @@ def build_train_valid_test_datasets(
     seq_length,
     seed,
     skip_warmup,
+    dataset_configs=None,
 ):
     """Build train, valid, and test datasets."""
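The `**dataset_configs if dataset_configs is not None else {}` lines rely on Python evaluating the conditional expression before unpacking, i.e. they parse as `**(dataset_configs if dataset_configs is not None else {})`, so a None config cleanly contributes no keyword arguments. A quick self-contained illustration (the key name is made up):

```python
def echo_kwargs(**kwargs):
    return kwargs

cfg = None
print(echo_kwargs(**cfg if cfg is not None else {}))  # {} -> no extra kwargs

cfg = {"num_pause_tokens": 4}  # hypothetical key, purely for illustration
print(echo_kwargs(**cfg if cfg is not None else {}))  # {'num_pause_tokens': 4}
```

Writing the parentheses explicitly would arguably read better, but the two spellings are equivalent.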
@@ -130,7 +148,8 @@ def build_dataset(index, name):
             start=splits[index], stop=splits[index + 1], step=1, dtype=np.int32
         )

-        dataset = GPT2Dataset(
+        dataset_obj = get_dataset_obj(dataset_type)
+        dataset = dataset_obj(
             name,
             data_prefix,
             documents,

@@ -139,6 +158,7 @@ def build_dataset(index, name):
             seq_length,
             seed,
             use_shared_fs=use_shared_fs,
+            **dataset_configs if dataset_configs is not None else {},
         )
         return dataset

@@ -215,6 +235,7 @@ def build_weighted_datasets(
     if train_path:
         train_datasets.append(
             build_the_dataset(
+                dataset_type=neox_args.dataset_type,
                 data_prefix=train_path,
                 name=f"train_{i}",
                 data_impl=neox_args.data_impl,

@@ -224,12 +245,14 @@ def build_weighted_datasets(
                 skip_warmup=(not neox_args.mmap_warmup),
                 build_index_mappings=build_index_mappings,
                 label_prefix=label_path,
+                dataset_configs=neox_args.dataset_cfg,
             )
         )

     if valid_path:
         valid_datasets.append(
             build_the_dataset(
+                dataset_type=neox_args.dataset_type,
                 data_prefix=valid_path,
                 name=f"valid_{i}",
                 data_impl=neox_args.data_impl,

@@ -238,12 +261,14 @@ def build_weighted_datasets(
                 seed=neox_args.seed,
                 skip_warmup=(not neox_args.mmap_warmup),
                 build_index_mappings=build_index_mappings,
+                dataset_configs=neox_args.dataset_cfg,
             )
         )

     if test_path:
         test_datasets.append(
             build_the_dataset(
+                dataset_type=neox_args.dataset_type,
                 data_prefix=test_path,
                 name=f"test_{i}",
                 data_impl=neox_args.data_impl,

@@ -252,6 +277,7 @@ def build_weighted_datasets(
                 seed=neox_args.seed,
                 skip_warmup=(not neox_args.mmap_warmup),
                 build_index_mappings=build_index_mappings,
+                dataset_configs=neox_args.dataset_cfg,
             )
         )
     return train_datasets, valid_datasets, test_datasets
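Worth noting in these call sites: the builder parameter is named dataset_configs, but it is populated from neox_args.dataset_cfg, while the class selector comes from neox_args.dataset_type. Both attributes must exist on the NeoX argument object, presumably added in a companion NeoXArgs change not shown in this commit. A minimal stand-in showing just the shape this diff reads, with purely illustrative values:

```python
from types import SimpleNamespace

# Stand-in for the real NeoXArgs object, limited to the two attributes this
# diff consumes. Values are invented for illustration.
neox_args_stub = SimpleNamespace(
    dataset_type="pause",                 # routed through get_dataset_obj
    dataset_cfg={"num_pause_tokens": 4},  # forwarded to the dataset as **kwargs
)
```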
@@ -406,6 +432,7 @@ def build_train_valid_test_data_iterators(neox_args):
         # when just data_path is provided
         # split dataset into train, valid and test from data_path
         train_ds, valid_ds, test_ds = build_train_valid_test_datasets(
+            dataset_type=neox_args.dataset_type,
             data_prefix=neox_args.data_path,
             use_shared_fs=neox_args.use_shared_fs,
             data_impl=neox_args.data_impl,

@@ -414,6 +441,7 @@ def build_train_valid_test_data_iterators(neox_args):
             seq_length=neox_args.seq_length,
             seed=neox_args.seed,
             skip_warmup=(not neox_args.mmap_warmup),
+            dataset_configs=neox_args.dataset_cfg,
         )

         # Build dataloders.
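Since dataset_configs defaults to None (unpacked as an empty dict), passing dataset_type="default" reproduces the pre-commit GPT2Dataset behavior, while "pause" routes extra constructor kwargs to PTokDataset. A hedged sketch of the two call patterns; the parameters collapsed out of the diff (splits_string, train_valid_test_num_samples) are taken from the upstream GPT-NeoX signature, and all values are placeholders:

```python
# Placeholder arguments; splits_string and train_valid_test_num_samples are
# assumed from upstream GPT-NeoX and may differ in this fork.
common = dict(
    data_prefix="data/enwik8_text_document",       # placeholder path
    use_shared_fs=True,
    data_impl="mmap",
    splits_string="949,50,1",                      # assumed collapsed parameter
    train_valid_test_num_samples=[1000, 100, 10],  # assumed collapsed parameter
    seq_length=2048,
    seed=1234,
    skip_warmup=True,
)

# Pre-commit behavior: GPT2Dataset, no extra kwargs.
train_ds, valid_ds, test_ds = build_train_valid_test_datasets(
    dataset_type="default", dataset_configs=None, **common
)

# New path: PTokDataset, with the config dict forwarded verbatim.
train_ds, valid_ds, test_ds = build_train_valid_test_datasets(
    dataset_type="pause",
    dataset_configs={"num_pause_tokens": 4},  # hypothetical PTokDataset kwarg
    **common,
)
```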