diff --git a/streaming/base/shuffle/py1e.py b/streaming/base/shuffle/py1e.py index 6127341e1..3583caa22 100644 --- a/streaming/base/shuffle/py1e.py +++ b/streaming/base/shuffle/py1e.py @@ -8,6 +8,8 @@ is determined by ``shuffle_block_size``. """ +import warnings + import numpy as np from numpy.typing import NDArray @@ -63,6 +65,7 @@ def get_shuffle_py1e(shard_sizes: NDArray[np.int64], # Populate the global sample ID mapping, shuffling within each span. ids = np.empty(num_samples, np.int64) offset = 0 + warn_user = False # Iterate through each canonical node's spans. # We don't want samples crossing canonical node boundaries. for cn_begin, cn_end in super_spans: @@ -94,6 +97,12 @@ def get_shuffle_py1e(shard_sizes: NDArray[np.int64], # This ensures that the span samples are only found in a max range of rand_block_size. cutoff = (rand_block_size - span_size) / 2 + # if cutoff is negative, this means span size is less than rand_block_size, so we set + # cutoff to 0 (no shuffling for this span) and warn the user later. + if cutoff < 0: + cutoff = 0 + warn_user = True + # Make sure the lower bound of the range doesn't cross the start of the canonical node. lower_bound = max(-cutoff, -cn_sample_offset) # Make sure the upper bound of the range doesn't cross the end of the canonical node. @@ -118,4 +127,11 @@ def get_shuffle_py1e(shard_sizes: NDArray[np.int64], offset += num_cn_samples + # If warn_user is true, this means the block size for shifts was smaller than a span size. + # This will result in no shuffling being done on that span aka shard part, so warn user. + if warn_user: + warnings.warn('Shuffle block size was smaller than shard size for some shards. This \ + will result in these shards not being shuffled with other shards. Set \ + shuffle_block_size to a larger value for a higher quality shuffle.') + return ids diff --git a/tests/test_streaming.py b/tests/test_streaming.py index 6bef12839..086f894a7 100644 --- a/tests/test_streaming.py +++ b/tests/test_streaming.py @@ -510,6 +510,37 @@ def run_one_iter(local: str, remote: str, seed: int) -> None: assert result == 0 +@pytest.mark.parametrize('batch_size', [4]) +@pytest.mark.parametrize('size_limit', [256, 512]) +@pytest.mark.parametrize('seed', [1111]) +@pytest.mark.parametrize('shuffle', [True]) +@pytest.mark.parametrize('shuffle_block_size', [50, 100]) +@pytest.mark.usefixtures('local_remote_dir') +def test_py1e_shuffle_block_warning(local_remote_dir: Any, batch_size: int, size_limit: int, + seed: int, shuffle: bool, shuffle_block_size: int): + remote_dir, local_dir = local_remote_dir + # Here, size_limit is in bytes. Each SequenceDataset sample is around 10 bytes, but the header + # will also take up some space. + convert_to_mds(out_root=remote_dir, + dataset_name='sequencedataset', + num_samples=1000, + size_limit=(size_limit * 10) + 1000) + + dataset = StreamingDataset(local=local_dir, + remote=remote_dir, + shuffle=shuffle, + batch_size=batch_size, + num_canonical_nodes=1, + shuffle_seed=seed, + shuffle_algo='py1e', + shuffle_block_size=shuffle_block_size) + dataloader = DataLoader(dataset=dataset, batch_size=batch_size) + + with pytest.warns(UserWarning, match=f'Shuffle block size was smaller than shard size*'): + for _ in dataloader: + pass + + @pytest.mark.parametrize('batch_size', [128]) @pytest.mark.parametrize('drop_last', [False, True]) @pytest.mark.parametrize('shuffle', [False, True])