working but non-unique across ranks
haileyschoelkopf committed Nov 29, 2023
1 parent ccd004b commit afd6483
Showing 5 changed files with 66 additions and 30 deletions.
configs/49M.yml (4 changes: 2 additions & 2 deletions)
@@ -1,6 +1,6 @@
 {
   # parallelism settings
-  "pipe_parallel_size": 1,
+  "pipe_parallel_size": 0,
   "model_parallel_size": 1,

   # model settings
@@ -49,7 +49,7 @@
   "train_micro_batch_size_per_gpu": 4,
   "gas": 1,
   "data_impl": "mmap",
-  "num_workers": 1,
+  "num_workers": 8,

   # activation checkpointing
   "checkpoint_activations": true,
megatron/data/data_utils.py (78 changes: 52 additions & 26 deletions)
@@ -489,6 +489,31 @@ def build_train_valid_test_data_iterators(neox_args):
     return train_data_iterator, valid_data_iterator, test_data_iterator


+def make_streaming_data_loader(dataset, neox_args):
+    """Build dataloader given an input StreamingDataset. (IterableDataset)"""
+    if dataset is None:
+        return None
+    # Data parallel arguments.
+    world_size = mpu.get_data_parallel_world_size()
+    rank = mpu.get_data_parallel_rank()
+    global_batch_size = neox_args.batch_size * world_size
+    num_workers = neox_args.num_workers
+
+    # Use a simple sampler with distributed batch sampler.
+    sampler = torch.utils.data.SequentialSampler(dataset)
+    # batch_sampler = DistributedBatchSampler(
+    #     sampler=sampler,
+    #     batch_size=global_batch_size,
+    #     drop_last=True,
+    #     rank=rank,
+    #     world_size=world_size,
+    # )
+    # Torch dataloader.
+    return torch.utils.data.DataLoader(
+        dataset, num_workers=num_workers, pin_memory=True, #
+    )


 def build_train_valid_test_data_iterators_streaming(neox_args):
     """as above, but builds Mosaic StreamingDatasets instead"""
@@ -558,9 +583,9 @@ def build_train_valid_test_data_iterators_streaming(neox_args):

     # Build dataloders.
     # TODO: confirm this gives right non-duplicated contents at each batch item
-    train_dataloader = make_data_loader(train_ds, neox_args=neox_args)
-    valid_dataloader = make_data_loader(valid_ds, neox_args=neox_args)
-    test_dataloader = make_data_loader(test_ds, neox_args=neox_args)
+    train_dataloader = make_streaming_data_loader(train_ds, neox_args=neox_args)
+    valid_dataloader = make_streaming_data_loader(valid_ds, neox_args=neox_args)
+    test_dataloader = make_streaming_data_loader(test_ds, neox_args=neox_args)

     # Flags to know if we need to do training/validation/testing.
     do_train = train_dataloader is not None and neox_args.train_iters > 0
@@ -586,29 +611,30 @@ def build_train_valid_test_data_iterators_streaming(neox_args):
     neox_args.do_valid = flags[1].item()
     neox_args.do_test = flags[2].item()

-    # Shift the start iterations. TODO: how to do this with streamingdatasets? might be same if we still use our megatron sampler
-    if train_dataloader is not None:
-        train_dataloader.batch_sampler.start_iter = (
-            neox_args.iteration * neox_args.gradient_accumulation_steps
-        ) % len(train_dataloader)
-        print_rank_0(
-            "setting training data start iteration to {}".format(
-                train_dataloader.batch_sampler.start_iter
-            )
-        )
-    if valid_dataloader is not None:
-        start_iter_val = (
-            (neox_args.iteration * neox_args.gradient_accumulation_steps)
-            // neox_args.eval_interval
-        ) * neox_args.eval_iters
-        valid_dataloader.batch_sampler.start_iter = start_iter_val % len(
-            valid_dataloader
-        )
-        print_rank_0(
-            "setting validation data start iteration to {}".format(
-                valid_dataloader.batch_sampler.start_iter
-            )
-        )
+    # Shift the start iterations.
+    # TODO: how to do this with streamingdatasets?
+    # if train_dataloader is not None:
+    #     train_dataloader.batch_sampler.start_iter = (
+    #         neox_args.iteration * neox_args.gradient_accumulation_steps
+    #     ) % len(train_dataloader)
+    #     print_rank_0(
+    #         "setting training data start iteration to {}".format(
+    #             train_dataloader.batch_sampler.start_iter
+    #         )
+    #     )
+    # if valid_dataloader is not None:
+    #     start_iter_val = (
+    #         (neox_args.iteration * neox_args.gradient_accumulation_steps)
+    #         // neox_args.eval_interval
+    #     ) * neox_args.eval_iters
+    #     valid_dataloader.batch_sampler.start_iter = start_iter_val % len(
+    #         valid_dataloader
+    #     )
+    #     print_rank_0(
+    #         "setting validation data start iteration to {}".format(
+    #             valid_dataloader.batch_sampler.start_iter
+    #         )
+    #     )

     # Build iterators.
     if train_dataloader is not None:
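A note on the new loader: because the dataset here is an IterableDataset, the SequentialSampler built above is never actually passed to the DataLoader (a sampler cannot be combined with an IterableDataset in torch), so the call falls back to PyTorch's default batch_size=1 and has no rank-aware sampling, which is consistent with the commit title. Below is a minimal sketch, not part of this commit, of the same helper with the micro-batch size passed through; per-rank de-duplication would still have to come from the dataset's own __iter__ (see the next file).

# Sketch only (assumption, not the committed code): pass the micro-batch size
# through, since torch.utils.data.DataLoader defaults to batch_size=1 otherwise.
import torch


def make_streaming_data_loader_sketch(dataset, neox_args):
    if dataset is None:
        return None
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=neox_args.batch_size,  # per-GPU micro batch, as implied by global_batch_size above
        num_workers=neox_args.num_workers,
        pin_memory=True,
        drop_last=True,  # keep batch shapes uniform across iterations
    )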
megatron/data/streaming_dataset.py (9 changes: 8 additions & 1 deletion)
@@ -164,6 +164,12 @@ def __getitem__(self,
         )
         return token_sample

+    def __iter__(self):
+        idx = 0
+        while True:
+            yield self.__getitem__(idx)
+            idx += 1
+

 def build_streaming_dataset(split, neox_args=None):
     """build a StreamingTextDataset"""
@@ -210,6 +216,7 @@ def build_streaming_dataset(split, neox_args=None):
         max_seq_len=neox_args.seq_length + 1,
         streams=streams,
         split=None,
-        epoch_size=train_val_test_num_samples[split]
+        epoch_size=train_val_test_num_samples[split],
+
     )
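The __iter__ added above always starts at index 0 and advances by 1, so every data-parallel rank (and every dataloader worker) yields the same sample stream; that is presumably the non-uniqueness named in the commit title. A rank-strided variant, shown only as a sketch and assuming megatron's mpu helpers are initialized and importable here (data_utils.py already uses them), could look like:

# Hypothetical replacement body for __iter__, not part of this commit.
from megatron import mpu  # assumed import


def __iter__(self):
    rank = mpu.get_data_parallel_rank()
    world_size = mpu.get_data_parallel_world_size()
    idx = rank  # each rank starts at its own offset
    while True:
        yield self.__getitem__(idx)
        idx += world_size  # stride by world size so ranks see disjoint samples

Multiple dataloader workers per rank would still duplicate samples unless torch.utils.data.get_worker_info() is folded into the offset and stride as well.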

megatron/training.py (4 changes: 3 additions & 1 deletion)
@@ -315,9 +315,11 @@ def get_batch(neox_args, data_iterator):

     # Broadcast data.
     if data_iterator is not None:
-        data = next(data_iterator)
+        data = {}
+        data["text"] = next(data_iterator)
     else:
         data = None
+    print(data)
     return _get_batch(
         neox_args=neox_args,
         tokenizer=neox_args.tokenizer,
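This change wraps whatever the streaming iterator yields into a dict under the "text" key, presumably so the tokens show up under the key that the downstream _get_batch call looks for; the added print(data) reads like a temporary debug aid for inspecting what each rank receives. A shape sketch under assumptions (sizes and vocab are illustrative, not from the commit):

# Illustrative only: each item from the streaming dataloader is assumed to be a
# LongTensor of token ids of length seq_length + 1, with a leading batch
# dimension added by the DataLoader, which get_batch now wraps as {"text": ...}.
import torch

seq_length = 2048  # hypothetical neox_args.seq_length
tokens = torch.randint(0, 50_000, (1, seq_length + 1), dtype=torch.long)
data = {"text": tokens}  # the dict handed on to _get_batch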
requirements/requirements-streaming.txt (1 change: 1 addition & 0 deletions)
@@ -0,0 +1 @@
+mosaicml-streaming
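The new requirements file adds the dependency unpinned; installing it provides the importable package that the StreamingTextDataset code above builds on. A quick, assumed sanity check:

# Assumed check after: pip install -r requirements/requirements-streaming.txt
# The PyPI project mosaicml-streaming installs as the importable package "streaming".
from streaming import StreamingDataset

print(StreamingDataset.__module__)  # expect a module path under "streaming"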
