Skip to content

Commit

Permalink
Add py1e warning when Shuffle block size is smaller than shard size (#…
Browse files Browse the repository at this point in the history
…463)

* py1e warning

* added pytest warning

* Remove print statement in a test file

---------

Co-authored-by: James Knighton <[email protected]>
Co-authored-by: Karan Jariwala <[email protected]>
Co-authored-by: Karan Jariwala <[email protected]>
  • Loading branch information
4 people authored Oct 13, 2023
1 parent 5a5fa6f commit baad3d9
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 0 deletions.
16 changes: 16 additions & 0 deletions streaming/base/shuffle/py1e.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
is determined by ``shuffle_block_size``.
"""

import warnings

import numpy as np
from numpy.typing import NDArray

Expand Down Expand Up @@ -63,6 +65,7 @@ def get_shuffle_py1e(shard_sizes: NDArray[np.int64],
# Populate the global sample ID mapping, shuffling within each span.
ids = np.empty(num_samples, np.int64)
offset = 0
warn_user = False
# Iterate through each canonical node's spans.
# We don't want samples crossing canonical node boundaries.
for cn_begin, cn_end in super_spans:
Expand Down Expand Up @@ -94,6 +97,12 @@ def get_shuffle_py1e(shard_sizes: NDArray[np.int64],
# This ensures that the span samples are only found in a max range of rand_block_size.
cutoff = (rand_block_size - span_size) / 2

# If cutoff is negative, this means span size is greater than rand_block_size, so we set
# cutoff to 0 (no shuffling for this span) and warn the user later.
if cutoff < 0:
cutoff = 0
warn_user = True

# Make sure the lower bound of the range doesn't cross the start of the canonical node.
lower_bound = max(-cutoff, -cn_sample_offset)
# Make sure the upper bound of the range doesn't cross the end of the canonical node.
Expand All @@ -118,4 +127,11 @@ def get_shuffle_py1e(shard_sizes: NDArray[np.int64],

offset += num_cn_samples

# If warn_user is true, this means the block size for shifts was smaller than a span size.
# This will result in no shuffling being done on that span aka shard part, so warn user.
if warn_user:
warnings.warn('Shuffle block size was smaller than shard size for some shards. This \
will result in these shards not being shuffled with other shards. Set \
shuffle_block_size to a larger value for a higher quality shuffle.')

return ids
31 changes: 31 additions & 0 deletions tests/test_streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -510,6 +510,37 @@ def run_one_iter(local: str, remote: str, seed: int) -> None:
assert result == 0


@pytest.mark.parametrize('batch_size', [4])
@pytest.mark.parametrize('size_limit', [256, 512])
@pytest.mark.parametrize('seed', [1111])
@pytest.mark.parametrize('shuffle', [True])
@pytest.mark.parametrize('shuffle_block_size', [50, 100])
@pytest.mark.usefixtures('local_remote_dir')
def test_py1e_shuffle_block_warning(local_remote_dir: Any, batch_size: int, size_limit: int,
                                    seed: int, shuffle: bool, shuffle_block_size: int) -> None:
    """Check that the py1e shuffle warns when the shuffle block size is too small.

    A 1000-sample dataset is written with a per-shard byte limit derived from ``size_limit``,
    then iterated with ``shuffle_algo='py1e'`` and a small ``shuffle_block_size``. Iterating the
    dataloader is expected to emit the "Shuffle block size was smaller than shard size"
    ``UserWarning``.

    Args:
        local_remote_dir (Any): Fixture value unpacked as ``(remote_dir, local_dir)``.
        batch_size (int): Batch size passed to both the dataset and the dataloader.
        size_limit (int): Scale factor for the shard size limit given to ``convert_to_mds``.
        seed (int): Shuffle seed.
        shuffle (bool): Whether the dataset shuffles its samples.
        shuffle_block_size (int): Shuffle block size under test.
    """
    remote_dir, local_dir = local_remote_dir
    # Here, size_limit is in bytes. Each SequenceDataset sample is around 10 bytes, but the header
    # will also take up some space.
    # NOTE(review): the parametrized block sizes (50, 100) are presumably smaller than the
    # resulting number of samples per shard, which is what triggers the warning -- confirm.
    convert_to_mds(out_root=remote_dir,
                   dataset_name='sequencedataset',
                   num_samples=1000,
                   size_limit=(size_limit * 10) + 1000)

    dataset = StreamingDataset(local=local_dir,
                               remote=remote_dir,
                               shuffle=shuffle,
                               batch_size=batch_size,
                               num_canonical_nodes=1,
                               shuffle_seed=seed,
                               shuffle_algo='py1e',
                               shuffle_block_size=shuffle_block_size)
    dataloader = DataLoader(dataset=dataset, batch_size=batch_size)

    # ``match`` is applied with ``re.search``, so a plain prefix of the message is enough; no
    # f-string or trailing wildcard is needed.
    with pytest.warns(UserWarning, match='Shuffle block size was smaller than shard size'):
        for _ in dataloader:
            pass


@pytest.mark.parametrize('batch_size', [128])
@pytest.mark.parametrize('drop_last', [False, True])
@pytest.mark.parametrize('shuffle', [False, True])
Expand Down

0 comments on commit baad3d9

Please sign in to comment.