From 85bf0e429b06ade6bb4385ee3d28c4d12c7c96b1 Mon Sep 17 00:00:00 2001 From: Saaketh Date: Thu, 12 Oct 2023 17:44:01 -0700 Subject: [PATCH] added pytest warning --- streaming/base/shuffle/py1e.py | 13 +++++++------ tests/test_streaming.py | 31 +++++++++++++++++++++++++++++++ 2 files changed, 38 insertions(+), 6 deletions(-) diff --git a/streaming/base/shuffle/py1e.py b/streaming/base/shuffle/py1e.py index b9ecc32dc..3583caa22 100644 --- a/streaming/base/shuffle/py1e.py +++ b/streaming/base/shuffle/py1e.py @@ -116,12 +116,6 @@ def get_shuffle_py1e(shard_sizes: NDArray[np.int64], # Update sample offset for the next shard. cn_sample_offset += span_size - # If warn_user is true, this means the block size for shifts was smaller than a span size. - # This will result in no shuffling being done on that span aka shard part, so warn user. - if warn_user: - warnings.warn('Shuffle block size was smaller than shard size for some shards. This \ - will result in these shards not being shuffled with other shards. Set \ - shuffle_block_size to a larger value for a higher quality shuffle.') # Get incides that would sort the sample_positions array. sort_indices = np.argsort(sample_positions) @@ -133,4 +127,11 @@ def get_shuffle_py1e(shard_sizes: NDArray[np.int64], offset += num_cn_samples + # If warn_user is true, this means the block size for shifts was smaller than a span size. + # This will result in no shuffling being done on that span aka shard part, so warn user. + if warn_user: + warnings.warn('Shuffle block size was smaller than shard size for some shards. This \ + will result in these shards not being shuffled with other shards. Set \ + shuffle_block_size to a larger value for a higher quality shuffle.') + return ids diff --git a/tests/test_streaming.py b/tests/test_streaming.py index 6bef12839..0252ccc98 100644 --- a/tests/test_streaming.py +++ b/tests/test_streaming.py @@ -510,6 +510,37 @@ def run_one_iter(local: str, remote: str, seed: int) -> None: assert result == 0 +@pytest.mark.parametrize('batch_size', [4]) +@pytest.mark.parametrize('size_limit', [256, 512]) +@pytest.mark.parametrize('seed', [1111]) +@pytest.mark.parametrize('shuffle', [True]) +@pytest.mark.parametrize('shuffle_block_size', [50, 100]) +@pytest.mark.usefixtures('local_remote_dir') +def test_py1e_shuffle_block_warning(local_remote_dir: Any, batch_size: int, size_limit: int, + seed: int, shuffle: bool, shuffle_block_size: int): + remote_dir, local_dir = local_remote_dir + # Here, size_limit is in bytes. Each SequenceDataset sample is around 10 bytes, but the header + # will also take up some space. + convert_to_mds(out_root=remote_dir, + dataset_name='sequencedataset', + num_samples=1000, + size_limit=(size_limit * 10) + 1000) + + dataset = StreamingDataset(local=local_dir, + remote=remote_dir, + shuffle=shuffle, + batch_size=batch_size, + num_canonical_nodes=1, + shuffle_seed=seed, + shuffle_algo='py1e', + shuffle_block_size=shuffle_block_size) + dataloader = DataLoader(dataset=dataset, batch_size=batch_size) + + with pytest.warns(UserWarning, match=f'Shuffle block size was smaller than shard size*'): + for _ in dataloader: + print('yo') + + @pytest.mark.parametrize('batch_size', [128]) @pytest.mark.parametrize('drop_last', [False, True]) @pytest.mark.parametrize('shuffle', [False, True])