add more skeleton code
haileyschoelkopf committed Nov 29, 2023
1 parent b52a2aa commit 4422b0d
Showing 1 changed file with 17 additions and 8 deletions.
25 changes: 17 additions & 8 deletions megatron/data/data_utils.py
@@ -490,6 +490,12 @@ def build_train_valid_test_data_iterators(neox_args):

def build_train_valid_test_data_iterators_streaming(neox_args):
"""as above, but builds Mosaic StreamingDatasets instead"""

    try:
        from streaming import StreamingDataset
    except ModuleNotFoundError:
        raise Exception(
            "Must install the `streaming` package (`pip install mosaicml-streaming`) to use StreamingDatasets!"
        )

    (train_dataloader, valid_dataloader, test_dataloader) = (None, None, None)

@@ -531,19 +537,22 @@ def build_train_valid_test_data_iterators_streaming(neox_args):
        # neox_args.test_data_weights, train_val_test_num_samples[2]
        # )


        for split, data_path in zip(
            ["train", "valid", "test"],
            [neox_args.train_data_paths, neox_args.valid_data_paths, neox_args.test_data_paths],
        ):  # TODO: assumes only one data source per split
            # Remote directory (S3 or local filesystem) where dataset is stored
            remote_dir = f's3://{data_path[0]}'
            # Local directory where dataset is cached during operation
            local_dir = f'/tmp/cache-{data_path[0]}/{split}'
            dataset = StreamingDataset(
                local=local_dir, remote=remote_dir, split=None, shuffle=True
            )  # TODO: sampler from megatron handles shuffle, right? check this

        # Load Mosaic streaming datasets from train_data_paths, valid_data_paths, test_data_paths
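
        # A possible continuation of this skeleton (hypothetical, not in this
        # commit): append each split's StreamingDataset to the per-split lists
        # that BlendableDataset consumes below, e.g. inside the loop body
        # (assumes the lists are initialized earlier in the function):
        #
        #     streaming_datasets_by_split = {
        #         "train": train_datasets,
        #         "valid": valid_datasets,
        #         "test": test_datasets,
        #     }
        #     streaming_datasets_by_split[split].append(dataset)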


        # next, make a blended dataset out of the ones we built (estimate how many docs we need from each?)
        # TODO: pull up how you do the sampling-proportional-to-weights from Mosaic dataset
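
        # One possible answer to the TODO above (a sketch, not part of this
        # commit, assuming Mosaic's `Stream` API): the `streaming` library can
        # weight sources natively via `Stream` objects with a `proportion`
        # argument, which would replace BlendableDataset on this code path:
        #
        #     from streaming import Stream, StreamingDataset
        #
        #     streams = [
        #         Stream(remote=f's3://{path}', local=f'/tmp/cache-{path}/train', proportion=weight)
        #         for path, weight in zip(neox_args.train_data_paths, neox_args.train_data_weights)
        #     ]
        #     train_ds = StreamingDataset(streams=streams, shuffle=True)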

        if train_datasets:
            train_ds = BlendableDataset(train_datasets, train_weights)
        if valid_datasets:
            valid_ds = BlendableDataset(valid_datasets, valid_weights)
        if test_datasets:
            test_ds = BlendableDataset(test_datasets, test_weights)
    else:
        raise ValueError(
            "Tried to use StreamingDataset, but `data_path` was set in config. "
            "Please pass data via `train_data_paths`, `valid_data_paths`, and `test_data_paths` instead."
        )
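
    # Downstream wiring (a sketch under standard PyTorch assumptions, not part
    # of this commit): StreamingDataset shuffles internally when shuffle=True,
    # so the resulting datasets would typically be driven by a plain DataLoader
    # rather than a shuffling Megatron sampler:
    #
    #     from torch.utils.data import DataLoader
    #
    #     train_dataloader = DataLoader(
    #         train_ds,
    #         batch_size=neox_args.train_micro_batch_size_per_gpu,
    #         num_workers=neox_args.num_workers,
    #         pin_memory=True,
    #     )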
