From ca53b7c41f4668b80318879cf5eb77d64223f5a1 Mon Sep 17 00:00:00 2001
From: Weston Pace
Date: Wed, 14 Feb 2024 13:58:43 -0800
Subject: [PATCH] feat: add support for filtering to the ShardedBatchSampler (#1950)

The algorithm is described in more detail in the comments.  When there
is a filter and randomness is requested, we do our best to randomize
the stream, but the resulting sequence will not be perfectly
randomized.
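Example usage (a sketch; the dataset path, column name, and filter
expression are illustrative, and the import paths are assumptions):

    from lance.sampler import ShardedBatchSampler
    from lance.torch.data import LanceDataset

    # Rank 3 of a world of 5: this process receives every 5th row of
    # the filtered scan, in batches of 25 rows
    ds = LanceDataset(
        "/path/to/dataset",
        batch_size=25,
        columns=["id"],
        filter="filterme == 0",
        sampler=ShardedBatchSampler(rank=3, world_size=5),
    )
    for batch in ds:
        ...  # each batch is a shard-local batch of filtered rows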
---
 python/python/lance/sampler.py               | 144 ++++++++++++++++---
 python/python/tests/torch_tests/test_data.py |  45 ++++++
 2 files changed, 173 insertions(+), 16 deletions(-)

diff --git a/python/python/lance/sampler.py b/python/python/lance/sampler.py
index 49d07a6f80..949287b058 100644
--- a/python/python/lance/sampler.py
+++ b/python/python/lance/sampler.py
@@ -321,6 +321,21 @@ class ShardedBatchSampler(Sampler):
     """Sharded batch sampler.
 
     Each rank / process will process a subset of the batches.
+
+    The input is subdivided into batches (of size `batch_size`).  Each rank /
+    process takes every Nth batch (where N is the world size).  The order in
+    which batches are loaded is randomized when `randomize` is set.
+
+    When there is no filter, each process only needs to load the rows assigned
+    to it.  This is still slightly less efficient than ShardedFragmentSampler,
+    since it loads rows by range instead of loading all rows for a given
+    fragment.
+
+    If there is a filter, we cannot divide the row ids ahead of time.  Instead,
+    each process loads the entire filtered dataset and discards the rows that
+    are not assigned to it.  The resulting stream is then randomized via a
+    reservoir sampler.  This does not perfectly randomize the stream, but it
+    should be random enough for many use cases.
     """
 
     def __init__(
@@ -343,27 +358,102 @@ def from_torch(randomize: bool = False, seed: int = 0) -> ShardedBatchSampler:
         world_size = torch.distributed.get_world_size()
         return ShardedBatchSampler(rank, world_size, randomize=randomize, seed=seed)
 
-    def __call__(
+    # Performs a filtered scan of the dataset and keeps only this rank's rows
+    # (every Nth row, where N is the world size)
+    def _shard_scan(
         self,
         dataset: lance.LanceDataset,
-        *args,
-        batch_size: int = 128,
-        columns: Optional[List[str]] = None,
-        filter: Optional[str] = None,
-        batch_readahead: int = 16,
-        with_row_id: Optional[bool] = None,
-        **kwargs,
+        batch_size: int,
+        columns: Optional[List[str]],
+        batch_readahead: int,
+        filter: str,
+    ) -> Generator[lance.RecordBatch, None, None]:
+        accumulated_batches = []
+        rows_accumulated = 0
+        rows_to_skip = self._rank
+        for batch in dataset.scanner(
+            columns=columns,
+            batch_readahead=batch_readahead,
+            filter=filter,
+            scan_in_order=True,
+        ).to_batches():
+            # Skip rows at the front of the batch that belong to other ranks
+            batch = batch.slice(rows_to_skip, batch.num_rows - rows_to_skip)
+            # Take every Nth row
+            indices = list(range(0, batch.num_rows, self._world_size))
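+            # Figure out how many rows at the start of the next scanned
+            # batch belong to other ranks, so that the every-Nth stride
+            # continues unbroken across batch boundaries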
+            rows_to_skip = (
+                self._world_size - (batch.num_rows % self._world_size)
+            ) % self._world_size
+            batch = batch.take(indices)
+
+            # Add to our collection
+            rows_accumulated += batch.num_rows
+            accumulated_batches.append(batch)
+
+            # If we have enough to generate 1 or more batches then do so
+            if rows_accumulated > batch_size:
+                big_batch = (
+                    pa.Table.from_batches(accumulated_batches)
+                    .combine_chunks()
+                    .to_batches()[0]
+                )
+                accumulated_batches = []
+                while big_batch.num_rows > batch_size:
+                    next_batch = big_batch.slice(0, batch_size)
+                    big_batch = big_batch.slice(batch_size)
+                    yield next_batch
+                rows_accumulated = big_batch.num_rows
+                if big_batch.num_rows > 0:
+                    accumulated_batches.append(big_batch)
+
+    def _sample_filtered(
+        self,
+        dataset: lance.LanceDataset,
+        batch_size: int,
+        columns: Optional[List[str]],
+        batch_readahead: int,
+        filter: str,
+    ) -> Generator[lance.RecordBatch, None, None]:
+        shard_scan = self._shard_scan(
+            dataset, batch_size, columns, batch_readahead, filter
+        )
+        if not self._randomize:
+            yield from shard_scan
+            return
+
+        random.seed(self._seed)
+        heap = []
+        # We want to randomize the incoming sequence.  The normal approach
+        # is to pull the whole thing into memory and run Fisher-Yates, but
+        # we want to avoid buffering the entire input.  So, as an
+        # approximation, we use a heap + random priorities in a style
+        # similar to reservoir sampling.
+        #
+        # We keep up to k batches in the reservoir.  The higher k is, the
+        # more randomness we get from the reservoir shuffle, but the more
+        # memory we need.
+        #
+        # Picking 256 as a heuristic; with the default batch size of 128
+        # that is 32Ki rows
+        k = 256
+        for batch in shard_scan:
+            priority = random.randint(0, k * 2 - 1)
+            entry = PrioritizedItem(priority, batch)
+            if len(heap) < k:
+                heappush(heap, entry)
+            else:
+                next_batch = heappushpop(heap, entry)
+                yield next_batch.item
+        # Drain whatever remains in the reservoir
+        for batch in heap:
+            yield batch.item
+
+    def _sample_all(
+        self,
+        dataset: lance.LanceDataset,
+        batch_size: int,
+        columns: Optional[List[str]],
+        batch_readahead: int,
     ) -> Generator[lance.RecordBatch, None, None]:
         total = dataset.count_rows()
 
-        if with_row_id is not None:
-            warnings.warn(
-                "with_row_id is not supported for ShardedBatchSampler",
-            )
-
-        if filter is not None:
-            raise ValueError("`filter` is not supported with ShardedBatchSampler")
-
         def _gen_ranges():
             for start in range(
                 self._rank * batch_size,
@@ -382,3 +472,25 @@ def _gen_ranges():
             columns=columns,
             batch_readahead=batch_readahead,
         )
+
+    def __call__(
+        self,
+        dataset: lance.LanceDataset,
+        *args,
+        batch_size: int = 128,
+        columns: Optional[List[str]] = None,
+        filter: Optional[str] = None,
+        batch_readahead: int = 16,
+        with_row_id: Optional[bool] = None,
+        **kwargs,
+    ) -> Generator[lance.RecordBatch, None, None]:
+        if filter is None:
+            if with_row_id is not None:
+                warnings.warn(
+                    "with_row_id is not supported for ShardedBatchSampler",
+                )
+            return self._sample_all(dataset, batch_size, columns, batch_readahead)
+        else:
+            return self._sample_filtered(
+                dataset, batch_size, columns, batch_readahead, filter
+            )
diff --git a/python/python/tests/torch_tests/test_data.py b/python/python/tests/torch_tests/test_data.py
index 42589dfe76..1cd9d94c50 100644
--- a/python/python/tests/torch_tests/test_data.py
+++ b/python/python/tests/torch_tests/test_data.py
@@ -196,3 +196,48 @@ def test_sample_batches(tmp_path: Path):
 
     all_ids = list(chain.from_iterable([batch.cpu().numpy() for batch in ds]))
     assert all_ids == [i for i in range(2000) if i // 25 % 2 == 1]
+
+
+def test_sample_batches_with_filter(tmp_path: Path):
+    NUM_ROWS = 10000
+    tbl = pa.Table.from_pydict({
+        "id": range(NUM_ROWS),
+        "filterme": [i % 2 for i in range(NUM_ROWS)],
+    })
+
+    lance.write_dataset(tbl, tmp_path, max_rows_per_file=2000)
+
+    ds = LanceDataset(
+        tmp_path,
+        batch_size=25,
+        columns=["id"],
+        with_row_id=True,
+        filter="filterme == 0",
+        sampler=ShardedBatchSampler(rank=3, world_size=5),
+    )
+
+    # The filtered sequence is 0, 2, 4, ...
+    #
+    # With rank 3 and world size 5 we should get every 5th element of
+    # that sequence, starting from the 4th (ids 6, 16, 26, ...):
+    #
+    # - - - 6 -
+    # - - - 16 -
+    # - - - 26 -
+    # ...
+    all_ids = list(chain.from_iterable([batch.cpu().numpy() for batch in ds]))
+    assert all_ids == [6 + (10 * i) for i in range(len(all_ids))]
+
+    # Now test with random order
+    ds = LanceDataset(
+        tmp_path,
+        batch_size=25,
+        columns=["id"],
+        with_row_id=True,
+        filter="filterme == 0",
+        sampler=ShardedBatchSampler(rank=3, world_size=5, randomize=True),
+    )
+
+    # Same ids as before, but in a (reservoir-)shuffled order
+    randomized_ids = list(chain.from_iterable([batch.cpu().numpy() for batch in ds]))
+    assert randomized_ids != all_ids
+    randomized_ids.sort()
+    assert randomized_ids == all_ids
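For intuition, the reservoir shuffle that _sample_filtered applies to
record batches can be demonstrated on plain integers.  This is a
standalone sketch, not part of the patch; `reservoir_shuffle` and
`k=4` are illustrative names/values:

    import heapq
    import random

    def reservoir_shuffle(items, k, seed=0):
        # Assign each item a random priority and push it through a
        # bounded min-heap of size k; items leave roughly in priority
        # order, approximating a shuffle with O(k) memory
        rng = random.Random(seed)
        heap = []
        for item in items:
            entry = (rng.randint(0, 2 * k - 1), item)
            if len(heap) < k:
                heapq.heappush(heap, entry)
            else:
                yield heapq.heappushpop(heap, entry)[1]
        # Drain the reservoir once the input is exhausted
        while heap:
            yield heapq.heappop(heap)[1]

    shuffled = list(reservoir_shuffle(range(20), k=4))
    assert sorted(shuffled) == list(range(20))  # a permutation of the input

Because an item can only linger in the k-slot reservoir, early items
tend to come out early; that is the sense in which the resulting
sequence is "random enough" rather than perfectly randomized.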