feat: add support for filtering to the ShardedBatchSampler (#1950)
The algorithm is described in more detail in the comments. When a filter and
randomness are both requested, we do our best to randomize the output, but the
resulting sequence will not be perfectly shuffled.
westonpace authored Feb 14, 2024
1 parent 5100222 commit ca53b7c
Showing 2 changed files with 173 additions and 16 deletions.
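
A minimal sketch of how the filtered sampling added here might be used from the PyTorch data loader, mirroring the new test below. The dataset path, row counts, and import paths are illustrative assumptions, not part of this commit:

```python
import lance
import pyarrow as pa
from lance.sampler import ShardedBatchSampler
from lance.torch.data import LanceDataset

# Illustrative dataset: an "id" column plus a column to filter on.
tbl = pa.Table.from_pydict({
    "id": list(range(1000)),
    "filterme": [i % 2 for i in range(1000)],
})
lance.write_dataset(tbl, "/tmp/sharded_filter_demo.lance")

# Each rank scans the filtered stream and keeps every Nth row (N = world
# size); randomize=True applies the reservoir-style shuffle described above.
ds = LanceDataset(
    "/tmp/sharded_filter_demo.lance",
    batch_size=25,
    columns=["id"],
    filter="filterme == 0",
    sampler=ShardedBatchSampler(rank=0, world_size=4, randomize=True),
)
for batch in ds:
    ...  # each batch holds only the filtered rows assigned to this rank
```
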
144 changes: 128 additions & 16 deletions python/python/lance/sampler.py
@@ -321,6 +321,21 @@ class ShardedBatchSampler(Sampler):
"""Sharded batch sampler.
Each rank / process will process a subset of the batches.
The input is subdivided into batches (of size `batch_size`). Each rank / process
takes every Nth batch (where N is the world size). The order in which batches
are loaded is randomized.
When there is no filter then each process only needs to load the rows assigned to
it but this process is still slightly less efficient than ShardedFragmentSampler
since it requires loading rows by range instead of loading all rows for a
given fragment.
If there is a filter then we cannot divide the row ids ahead of time. Instead,
each process will load the entire filtered dataset and discard the rows that are
not assigned to it. The resulting stream is then randomized via a reservoir
sampler. This does not perfectly randomize the stream but it should generate
a stream that is random enough for many use cases.
"""

    def __init__(
@@ -343,27 +358,102 @@ def from_torch(randomize: bool = False, seed: int = 0) -> ShardedBatchSampler:
        world_size = torch.distributed.get_world_size()
        return ShardedBatchSampler(rank, world_size, randomize=randomize, seed=seed)

-    def __call__(
    # Performs a filtered scan of the dataset and then throws away all but the Nth
    # rows (where N is the world size)
    def _shard_scan(
        self,
        dataset: lance.LanceDataset,
-        *args,
-        batch_size: int = 128,
-        columns: Optional[List[str]] = None,
-        filter: Optional[str] = None,
-        batch_readahead: int = 16,
-        with_row_id: Optional[bool] = None,
-        **kwargs,
        batch_size: int,
        columns: Optional[List[str]],
        batch_readahead: int,
        filter: str,
    ) -> Generator[lance.RecordBatch, None, None]:
        accumulated_batches = []
        rows_accumulated = 0
        rows_to_skip = self._rank
        for batch in dataset.scanner(
            columns=columns,
            batch_readahead=batch_readahead,
            filter=filter,
            scan_in_order=True,
        ).to_batches():
            batch = batch.slice(rows_to_skip, batch.num_rows - rows_to_skip)
            # Take every Nth row
            indices = list(range(0, batch.num_rows, self._world_size))
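            # The last stride may be cut short by the end of this batch, so
            # carry the remainder over; e.g. with a world size of 5 and 12
            # rows left after the initial skip, rows 0, 5 and 10 are taken
            # and the next batch must skip 3 rows to reach this rank's next
            # row.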
            rows_to_skip = (
                self._world_size - (batch.num_rows % self._world_size)
            ) % self._world_size
            batch = batch.take(indices)

            # Add to our collection
            rows_accumulated += batch.num_rows
            accumulated_batches.append(batch)

            # If we have enough to generate 1 or more batches then do so
            if rows_accumulated > batch_size:
                big_batch = (
                    pa.Table.from_batches(accumulated_batches)
                    .combine_chunks()
                    .to_batches()[0]
                )
                accumulated_batches = []
                while big_batch.num_rows > batch_size:
                    next_batch = big_batch.slice(0, batch_size)
                    big_batch = big_batch.slice(batch_size)
                    yield next_batch
                rows_accumulated = big_batch.num_rows
                if big_batch.num_rows > 0:
                    accumulated_batches.append(big_batch)

    def _sample_filtered(
        self,
        dataset: lance.LanceDataset,
        batch_size: int,
        columns: Optional[List[str]],
        batch_readahead: int,
        filter: str,
    ) -> Generator[lance.RecordBatch, None, None]:
        shard_scan = self._shard_scan(
            dataset, batch_size, columns, batch_readahead, filter
        )
        if not self._randomize:
            yield from shard_scan

        random.seed(self._seed)
        heap = []
        # We want to randomize the incoming sequence. The normal approach
        # is to pull the whole thing into memory and run Fisher-Yates, but
        # we want to avoid buffering the entire input. So, as an
        # approximation, we use a heap + random priority in a style similar
        # to reservoir sampling.
        #
        # We keep up to k batches in the reservoir. The higher k is, the
        # more randomness we get from the reservoir shuffle, but the more
        # memory we need.
        #
        # 256 is a heuristic which should give 32Ki rows with the default
        # batch size.
        k = 256
        for batch in shard_scan:
            priority = random.randint(0, k * 2 - 1)
            entry = PrioritizedItem(priority, batch)
            if len(heap) < k:
                heappush(heap, entry)
            else:
                next_batch = heappushpop(heap, entry)
                yield next_batch.item
        for batch in heap:
            yield batch.item

    def _sample_all(
        self,
        dataset: lance.LanceDataset,
        batch_size: int,
        columns: Optional[List[str]],
        batch_readahead: int,
    ) -> Generator[lance.RecordBatch, None, None]:
        total = dataset.count_rows()

-        if with_row_id is not None:
-            warnings.warn(
-                "with_row_id is not supported for ShardedBatchSampler",
-            )

-        if filter is not None:
-            raise ValueError("`filter` is not supported with ShardedBatchSampler")

        def _gen_ranges():
            for start in range(
                self._rank * batch_size,
@@ -382,3 +472,25 @@ def _gen_ranges():
            columns=columns,
            batch_readahead=batch_readahead,
        )

    def __call__(
        self,
        dataset: lance.LanceDataset,
        *args,
        batch_size: int = 128,
        columns: Optional[List[str]] = None,
        filter: Optional[str] = None,
        batch_readahead: int = 16,
        with_row_id: Optional[bool] = None,
        **kwargs,
    ) -> Generator[lance.RecordBatch, None, None]:
        if filter is None:
            if with_row_id is not None:
                warnings.warn(
                    "with_row_id is not supported for ShardedBatchSampler",
                )
            return self._sample_all(dataset, batch_size, columns, batch_readahead)
        else:
            return self._sample_filtered(
                dataset, batch_size, columns, batch_readahead, filter
            )
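
The reservoir-style shuffle used by `_sample_filtered` can be sketched in isolation. The helper below is purely illustrative (the name `reservoir_shuffle` and the integer stream are assumptions, not part of this commit); it shows the same heap + random-priority idea while buffering at most `k` items:

```python
import random
from heapq import heappush, heappushpop

def reservoir_shuffle(items, k=256, seed=0):
    """Approximately shuffle a stream while buffering at most k items."""
    rng = random.Random(seed)
    heap = []
    for item in items:
        # A random priority decides when an item leaves the reservoir.
        entry = (rng.randint(0, k * 2 - 1), item)
        if len(heap) < k:
            heappush(heap, entry)
        else:
            # Reservoir is full: push the new entry and pop (yield) the one
            # with the smallest priority.
            yield heappushpop(heap, entry)[1]
    # Drain whatever remains once the input stream is exhausted.
    for _, item in heap:
        yield item

# With k much smaller than the stream length the result is only
# approximately random, which is the trade-off described above.
print(list(reservoir_shuffle(range(20), k=4, seed=42)))
```

Because at most `k` batches are buffered, memory stays bounded while still breaking up the scan order.
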
45 changes: 45 additions & 0 deletions python/python/tests/torch_tests/test_data.py
@@ -196,3 +196,48 @@ def test_sample_batches(tmp_path: Path):

    all_ids = list(chain.from_iterable([batch.cpu().numpy() for batch in ds]))
    assert all_ids == [i for i in range(2000) if i // 25 % 2 == 1]


def test_sample_batches_with_filter(tmp_path: Path):
    NUM_ROWS = 10000
    tbl = pa.Table.from_pydict({
        "id": range(NUM_ROWS),
        "filterme": [i % 2 for i in range(NUM_ROWS)],
    })

    lance.write_dataset(tbl, tmp_path, max_rows_per_file=2000)

    ds = LanceDataset(
        tmp_path,
        batch_size=25,
        columns=["id"],
        with_row_id=True,
        filter="filterme == 0",
        sampler=ShardedBatchSampler(rank=3, world_size=5),
    )

    # The filtered sequence is 0, 2, 4, ...
    #
    # With rank 3 and world size 5 we should get
    #
    # - - - 6 -
    # - - - 16 -
    # - - - 26 -
    # ...
    all_ids = list(chain.from_iterable([batch.cpu().numpy() for batch in ds]))
    assert all_ids == [6 + (10 * i) for i in range(len(all_ids))]

    # Now test with random order
    ds = LanceDataset(
        tmp_path,
        batch_size=25,
        columns=["id"],
        with_row_id=True,
        filter="filterme == 0",
        sampler=ShardedBatchSampler(rank=3, world_size=5, randomize=True),
    )

    randomized_ids = list(chain.from_iterable([batch.cpu().numpy() for batch in ds]))
    assert randomized_ids != all_ids
    randomized_ids.sort()
    assert randomized_ids == all_ids
