Skip to content

Commit

Permalink
py1e warning
Browse files Browse the repository at this point in the history
  • Loading branch information
snarayan21 committed Oct 6, 2023
1 parent 6b63851 commit 5b25c1d
Showing 1 changed file with 15 additions and 0 deletions.
15 changes: 15 additions & 0 deletions streaming/base/shuffle/py1e.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
is determined by ``shuffle_block_size``.
"""

import warnings

import numpy as np
from numpy.typing import NDArray

Expand Down Expand Up @@ -63,6 +65,7 @@ def get_shuffle_py1e(shard_sizes: NDArray[np.int64],
# Populate the global sample ID mapping, shuffling within each span.
ids = np.empty(num_samples, np.int64)
offset = 0
warn_user = False
# Iterate through each canonical node's spans.
# We don't want samples crossing canonical node boundaries.
for cn_begin, cn_end in super_spans:
Expand Down Expand Up @@ -94,6 +97,12 @@ def get_shuffle_py1e(shard_sizes: NDArray[np.int64],
# This ensures that the span samples are only found in a max range of rand_block_size.
cutoff = (rand_block_size - span_size) / 2

# If cutoff is negative, the span size is greater than rand_block_size, so we set
# cutoff to 0 (no shuffling for this span) and warn the user later.
if cutoff < 0:
cutoff = 0
warn_user = True

# Make sure the lower bound of the range doesn't cross the start of the canonical node.
lower_bound = max(-cutoff, -cn_sample_offset)
# Make sure the upper bound of the range doesn't cross the end of the canonical node.
Expand All @@ -107,6 +116,12 @@ def get_shuffle_py1e(shard_sizes: NDArray[np.int64],
# Update sample offset for the next shard.
cn_sample_offset += span_size

# If warn_user is true, this means the block size for shifts was smaller than a span size.
# This will result in no shuffling being done on that span aka shard part, so warn user.
if warn_user:
warnings.warn('Shuffle block size was smaller than shard size for some shards. This \
will result in these shards not being shuffled with other shards. Set \
shuffle_block_size to a larger value for a higher quality shuffle.')
# Get indices that would sort the sample_positions array.
sort_indices = np.argsort(sample_positions)

Expand Down

0 comments on commit 5b25c1d

Please sign in to comment.