Skip to content

Commit

Permalink
added pytest warning
Browse files Browse the repository at this point in the history
  • Loading branch information
snarayan21 committed Oct 13, 2023
1 parent 5b25c1d commit 85bf0e4
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 6 deletions.
13 changes: 7 additions & 6 deletions streaming/base/shuffle/py1e.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,12 +116,6 @@ def get_shuffle_py1e(shard_sizes: NDArray[np.int64],
# Update sample offset for the next shard.
cn_sample_offset += span_size

# If warn_user is true, this means the block size for shifts was smaller than a span size.
# This will result in no shuffling being done on that span aka shard part, so warn user.
if warn_user:
warnings.warn('Shuffle block size was smaller than shard size for some shards. This \
will result in these shards not being shuffled with other shards. Set \
shuffle_block_size to a larger value for a higher quality shuffle.')
# Get incides that would sort the sample_positions array.
sort_indices = np.argsort(sample_positions)

Expand All @@ -133,4 +127,11 @@ def get_shuffle_py1e(shard_sizes: NDArray[np.int64],

offset += num_cn_samples

# If warn_user is true, this means the block size for shifts was smaller than a span size.
# This will result in no shuffling being done on that span aka shard part, so warn user.
if warn_user:
warnings.warn('Shuffle block size was smaller than shard size for some shards. This \
will result in these shards not being shuffled with other shards. Set \
shuffle_block_size to a larger value for a higher quality shuffle.')

return ids
31 changes: 31 additions & 0 deletions tests/test_streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -510,6 +510,37 @@ def run_one_iter(local: str, remote: str, seed: int) -> None:
assert result == 0


@pytest.mark.parametrize('batch_size', [4])
@pytest.mark.parametrize('size_limit', [256, 512])
@pytest.mark.parametrize('seed', [1111])
@pytest.mark.parametrize('shuffle', [True])
@pytest.mark.parametrize('shuffle_block_size', [50, 100])
@pytest.mark.usefixtures('local_remote_dir')
def test_py1e_shuffle_block_warning(local_remote_dir: Any, batch_size: int, size_limit: int,
seed: int, shuffle: bool, shuffle_block_size: int):
remote_dir, local_dir = local_remote_dir
# Here, size_limit is in bytes. Each SequenceDataset sample is around 10 bytes, but the header
# will also take up some space.
convert_to_mds(out_root=remote_dir,
dataset_name='sequencedataset',
num_samples=1000,
size_limit=(size_limit * 10) + 1000)

dataset = StreamingDataset(local=local_dir,
remote=remote_dir,
shuffle=shuffle,
batch_size=batch_size,
num_canonical_nodes=1,
shuffle_seed=seed,
shuffle_algo='py1e',
shuffle_block_size=shuffle_block_size)
dataloader = DataLoader(dataset=dataset, batch_size=batch_size)

with pytest.warns(UserWarning, match=f'Shuffle block size was smaller than shard size*'):
for _ in dataloader:
print('yo')


@pytest.mark.parametrize('batch_size', [128])
@pytest.mark.parametrize('drop_last', [False, True])
@pytest.mark.parametrize('shuffle', [False, True])
Expand Down

0 comments on commit 85bf0e4

Please sign in to comment.