From 193b6fc0a18defc7b1d03debd35d49a624717eff Mon Sep 17 00:00:00 2001 From: Saaketh Narayan Date: Mon, 11 Nov 2024 10:46:31 -0500 Subject: [PATCH] Consistent errors for unused streams in batching methods (#826) --- streaming/base/batching/device_per_stream.py | 5 +++++ streaming/base/batching/per_stream.py | 6 +++--- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/streaming/base/batching/device_per_stream.py b/streaming/base/batching/device_per_stream.py index 282c255b4..63e1e8d9a 100644 --- a/streaming/base/batching/device_per_stream.py +++ b/streaming/base/batching/device_per_stream.py @@ -80,6 +80,11 @@ def generate_work_device_per_stream_batching(dataset: StreamingDataset, world: W raise TypeError(f'Dataset `shuffle_block_size` must be an integer. ' + f'Got {type(dataset.shuffle_block_size)} instead.') shuffle_block_portion = int(dataset.shuffle_block_size * stream.proportion) + if shuffle_block_portion == 0: + raise ValueError(f'Samples from stream {stream_id} are not being used. Please ' + + f'either increase the `shuffle_block_size` from ' + + f'{dataset.shuffle_block_size}, or increase the stream ' + + f'proportion from {stream.proportion}.') stream_shuffle = get_shuffle(dataset.shuffle_algo, shuffle_units, dataset.num_canonical_nodes, dataset.shuffle_seed, epoch, shuffle_block_portion) diff --git a/streaming/base/batching/per_stream.py b/streaming/base/batching/per_stream.py index 70f12a0dd..70c44c047 100644 --- a/streaming/base/batching/per_stream.py +++ b/streaming/base/batching/per_stream.py @@ -98,9 +98,9 @@ def generate_work_per_stream_batching(dataset: StreamingDataset, world: World, e if num_full_batches > 0: batches_from_partitions.append(global_batches_inorder[:num_full_batches]) else: - logger.warning(f'Stream with index {stream_idx} does not have an adequate number of ' + - f'samples to construct a complete global batch. Training will occur ' + - f'without any samples from this stream!') + raise ValueError(f'Stream with index {stream_idx} does not have an adequate number ' + + f'of samples to construct a complete global batch. Training will ' + + f'occur without any samples from this stream.') # Combine all global batches from all streams into one array. all_partition_batches = np.concatenate(batches_from_partitions)