add pseudocode for building streamingdatasets dataloaders
haileyschoelkopf committed Nov 29, 2023
1 parent 10bf788 commit b52a2aa
Showing 1 changed file with 134 additions and 0 deletions.
megatron/data/data_utils.py: 134 additions & 0 deletions
@@ -488,6 +488,140 @@ def build_train_valid_test_data_iterators(neox_args):
    return train_data_iterator, valid_data_iterator, test_data_iterator


def build_train_valid_test_data_iterators_streaming(neox_args):
    """As above, but builds Mosaic StreamingDatasets instead."""

    (train_dataloader, valid_dataloader, test_dataloader) = (None, None, None)

    print_rank_0("> building train, validation, and test datasets ...")

    # Ensure only the first/last pipeline stages have data loaders
    if neox_args.is_pipe_parallel:
        is_first_stage = mpu.get_pipe_parallel_rank() == 0
        is_last_stage = (
            mpu.get_pipe_parallel_rank() == mpu.get_pipe_parallel_world_size() - 1
        )
        pipe_load = is_first_stage or is_last_stage
    else:
        pipe_load = True

    # Data loader only on rank 0 of each model parallel group.
    if mpu.get_model_parallel_rank() == 0 and pipe_load:
        # Number of train/valid/test samples.
        train_iters = neox_args.train_iters
        eval_iters = (train_iters // neox_args.eval_interval + 1) * neox_args.eval_iters
        test_iters = neox_args.eval_iters
        train_val_test_num_samples = [
            train_iters * neox_args.train_batch_size,
            eval_iters * neox_args.train_batch_size,
            test_iters * neox_args.train_batch_size,
        ]
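        # Worked example (illustrative numbers, not from any real config): with
        # train_iters=143000, eval_interval=1000, eval_iters=10, and train_batch_size=32,
        # this gives eval_iters = (143000 // 1000 + 1) * 10 = 1440 total validation
        # iterations and train_val_test_num_samples = [4576000, 46080, 320].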

        if neox_args.train_data_paths:
            # when individual train / valid / test data paths are provided,
            # normalize weight values and get num samples for each dataset
            # (same as in the non-streaming builder above)
            train_weights, train_num_samples = get_normalized_weights_and_num_samples(
                neox_args.train_data_weights, train_val_test_num_samples[0]
            )
            valid_weights, valid_num_samples = get_normalized_weights_and_num_samples(
                neox_args.valid_data_weights, train_val_test_num_samples[1]
            )
            test_weights, test_num_samples = get_normalized_weights_and_num_samples(
                neox_args.test_data_weights, train_val_test_num_samples[2]
            )

            # Load Mosaic streaming datasets from train_data_paths, valid_data_paths,
            # and test_data_paths.

            # Next, make a blended dataset out of the ones we built (estimate how many
            # docs we need from each?).
            # TODO: pull up how you do the sampling-proportional-to-weights from Mosaic dataset
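            # Illustrative sketch only (not part of this commit): one way the TODOs above
            # might be filled in with mosaicml-streaming, assuming each entry of
            # *_data_paths is a directory of MDS shards and that
            # `from streaming import Stream, StreamingDataset` is added at module level.
            #
            # train_datasets = [
            #     StreamingDataset(
            #         local=path,
            #         shuffle=True,
            #         batch_size=neox_args.train_micro_batch_size_per_gpu,
            #     )
            #     for path in neox_args.train_data_paths
            # ]
            # (valid_datasets / test_datasets would be built the same way from
            # valid_data_paths / test_data_paths, with shuffle=False.)
            #
            # For the sampling-proportional-to-weights question, mosaic-streaming can also
            # do the weighted mixing itself (in place of BlendableDataset) via one dataset
            # built from weighted Streams:
            #
            # train_streams = [
            #     Stream(local=path, proportion=weight)
            #     for path, weight in zip(neox_args.train_data_paths, train_weights)
            # ]
            # train_ds = StreamingDataset(
            #     streams=train_streams,
            #     shuffle=True,
            #     batch_size=neox_args.train_micro_batch_size_per_gpu,
            # )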

            if train_datasets:
                train_ds = BlendableDataset(train_datasets, train_weights)
            if valid_datasets:
                valid_ds = BlendableDataset(valid_datasets, valid_weights)
            if test_datasets:
                test_ds = BlendableDataset(test_datasets, test_weights)
        else:
            raise ValueError(
                "tried to use StreamingDataset, but data_path was set in config. "
                "Please pass data via train_data_paths, valid_data_paths, and test_data_paths"
            )


        # Build dataloaders.
        # TODO: confirm this gives the right, non-duplicated contents for each batch item
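        # Note (assumption, not verified here): a StreamingDataset already partitions
        # samples across nodes, ranks, and dataloader workers internally, so wrapping it
        # in make_data_loader's DistributedBatchSampler may re-shard and skip data; a
        # plain torch DataLoader (or streaming's StreamingDataLoader) over the dataset
        # may be the safer choice once real StreamingDatasets are used here.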
        train_dataloader = make_data_loader(train_ds, neox_args=neox_args)
        valid_dataloader = make_data_loader(valid_ds, neox_args=neox_args)
        test_dataloader = make_data_loader(test_ds, neox_args=neox_args)

        # Flags to know if we need to do training/validation/testing.
        do_train = train_dataloader is not None and neox_args.train_iters > 0
        do_valid = valid_dataloader is not None and neox_args.eval_iters > 0
        do_test = test_dataloader is not None and neox_args.eval_iters > 0

        flags = torch.cuda.LongTensor([int(do_train), int(do_valid), int(do_test)])
    else:
        flags = torch.cuda.LongTensor([0, 0, 0])

    # Broadcast the do-train / do-valid / do-test flags.
    if neox_args.is_pipe_parallel:
        # Only first/last pipeline stages have data loaders, so pipeline parallelism should
        # broadcast globally instead of just the model parallel group.
        torch.distributed.broadcast(flags, src=0)
    else:
        torch.distributed.broadcast(
            flags,
            mpu.get_model_parallel_src_rank(),
            group=mpu.get_model_parallel_group(),
        )
    neox_args.do_train = flags[0].item()
    neox_args.do_valid = flags[1].item()
    neox_args.do_test = flags[2].item()

    # Shift the start iterations. TODO: how to do this with StreamingDatasets?
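    # Note (assumption): mosaicml-streaming resumes differently; a StreamingDataset
    # restores its position via dataset.load_state_dict(...), typically driven by a
    # StreamingDataLoader.state_dict() saved in the checkpoint, rather than by
    # fast-forwarding a batch sampler, so the start_iter shifting below may not apply
    # once real StreamingDatasets are wired in.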
    if train_dataloader is not None:
        train_dataloader.batch_sampler.start_iter = (
            neox_args.iteration * neox_args.gradient_accumulation_steps
        ) % len(train_dataloader)
        print_rank_0(
            "setting training data start iteration to {}".format(
                train_dataloader.batch_sampler.start_iter
            )
        )
    if valid_dataloader is not None:
        start_iter_val = (
            (neox_args.iteration * neox_args.gradient_accumulation_steps)
            // neox_args.eval_interval
        ) * neox_args.eval_iters
        valid_dataloader.batch_sampler.start_iter = start_iter_val % len(
            valid_dataloader
        )
        print_rank_0(
            "setting validation data start iteration to {}".format(
                valid_dataloader.batch_sampler.start_iter
            )
        )

    # Build iterators.
    if train_dataloader is not None:
        train_data_iterator = iter(train_dataloader)
    else:
        train_data_iterator = None

    if valid_dataloader is not None:
        valid_data_iterator = iter(valid_dataloader)
    else:
        valid_data_iterator = None

    if test_dataloader is not None:
        test_data_iterator = iter(test_dataloader)
    else:
        test_data_iterator = None

    return train_data_iterator, valid_data_iterator, test_data_iterator



def compile_helper():
"""Compile helper function at runtime. Make sure this
is invoked on a single process."""
