From 8e93813313297ba6a7ba74fbed7b34b7fdce6789 Mon Sep 17 00:00:00 2001
From: coufon
Date: Sun, 28 Jan 2024 19:57:15 +0000
Subject: [PATCH] Read added data in parallel read streams, and transform in parallel

---
 python/src/space/core/ops/change_data.py  | 25 +++++---
 python/src/space/core/options.py          |  5 +-
 python/src/space/core/runners.py          | 15 ++++-
 python/src/space/core/views.py            |  4 ++
 python/src/space/ray/ops/change_data.py   | 51 +++++++++++----
 python/src/space/ray/runners.py           | 75 +++++++++++++++++------
 python/tests/core/ops/test_change_data.py | 22 ++-----
 python/tests/ray/test_runners.py          | 30 +++------
 8 files changed, 145 insertions(+), 82 deletions(-)

diff --git a/python/src/space/core/ops/change_data.py b/python/src/space/core/ops/change_data.py
index 9da626d..2f08099 100644
--- a/python/src/space/core/ops/change_data.py
+++ b/python/src/space/core/ops/change_data.py
@@ -14,9 +14,11 @@
 #
 """Change data feed that computes delta between two snapshots."""
 
+import copy
 from dataclasses import dataclass
 from enum import Enum
-from typing import Any, Iterable, Iterator, List
+from typing import TYPE_CHECKING
+from typing import Iterable, Iterator, List, Union
 
 import pyarrow as pa
 
@@ -29,6 +31,9 @@ from space.core.utils import errors
 from space.core.utils.paths import StoragePathsMixin
 
+if TYPE_CHECKING:
+  from space.core.utils.lazy_imports_utils import ray
+
 
 class ChangeType(Enum):
   """Type of data changes."""
@@ -50,9 +55,8 @@ class ChangeData:
   # The change type.
   type_: ChangeType
 
-  # The change data (pa.Table or ray.data.Dataset).
-  # NOTE: type annotation not used, because of Ray lazy import.
-  data: Any
+  # The change data.
+  data: Union[pa.Table, List["ray.data.Dataset"]]
 
 
 def ordered_snapshot_ids(storage: Storage, start_snapshot_id: int,
@@ -108,6 +112,9 @@ def __init__(self, storage: Storage, snapshot_id: int,
     self._snapshot_id = snapshot_id
     self._read_options = read_options
 
+    self._pk_only_read_option = copy.deepcopy(read_options)
+    self._pk_only_read_option.fields = self._storage.primary_keys
+
     if snapshot_id not in self._metadata.snapshots:
       raise errors.VersionNotFoundError(
           f"Change data read can't find snapshot ID {snapshot_id}")
@@ -121,18 +128,20 @@ def __iter__(self) -> Iterator[ChangeData]:
     # deletions and additions, it may delete newly added data.
     # TODO: to enforce this check upstream, or merge deletion+addition as a
     # update.
-    for data in self._read_op(self._change_log.deleted_rows):
+    for data in self._read_op(self._change_log.deleted_rows,
+                              self._pk_only_read_option):
       yield ChangeData(self._snapshot_id, ChangeType.DELETE, data)
 
-    for data in self._read_op(self._change_log.added_rows):
+    for data in self._read_op(self._change_log.added_rows, self._read_options):
       yield ChangeData(self._snapshot_id, ChangeType.ADD, data)
 
-  def _read_op(self, bitmaps: Iterable[meta.RowBitmap]) -> Iterator[pa.Table]:
+  def _read_op(self, bitmaps: Iterable[meta.RowBitmap],
+               read_options: ReadOptions) -> Iterator[pa.Table]:
     return iter(
         FileSetReadOp(self._storage.location, self._metadata,
                       self._bitmaps_to_file_set(bitmaps),
-                      options=self._read_options))
+                      options=read_options))
 
   @classmethod
   def _bitmaps_to_file_set(cls,
diff --git a/python/src/space/core/options.py b/python/src/space/core/options.py
index 1952656..b0f79d2 100644
--- a/python/src/space/core/options.py
+++ b/python/src/space/core/options.py
@@ -19,6 +19,9 @@
 
 import pyarrow.compute as pc
 
+# Default number of rows per batch in a read result.
+DEFAULT_READ_BATCH_SIZE = 16
+
 
 @dataclass
 class ReadOptions:
@@ -50,7 +53,7 @@ class ReadOptions:
   batch_size: Optional[int] = None
 
   def __post_init__(self):
-    self.batch_size = self.batch_size or 16
+    self.batch_size = self.batch_size or DEFAULT_READ_BATCH_SIZE
 
 
 @dataclass
diff --git a/python/src/space/core/runners.py b/python/src/space/core/runners.py
index 80e101b..855e459 100644
--- a/python/src/space/core/runners.py
+++ b/python/src/space/core/runners.py
@@ -81,8 +81,19 @@ def read_all(
   def diff(self, start_version: Union[int],
           end_version: Union[int]) -> Iterator[ChangeData]:
     """Read the change data between two versions.
-
-    start_version is excluded; end_version is included.
+
+    NOTE: this method has limitations:
+    - For the DELETE change type, only primary keys are returned.
+    - DELETE changes are not processed by the UDFs in transforms, so a `filter`
+      transform may return extra rows that are deleted in the source but would
+      be filtered out in the target. This does not affect sync correctness.
+
+    Args:
+      start_version: the start version, excluded from the result
+      end_version: the end version, included in the result
+
+    Returns:
+      An iterator of change data.
     """
diff --git a/python/src/space/core/views.py b/python/src/space/core/views.py
index e17fbd6..fe4a980 100644
--- a/python/src/space/core/views.py
+++ b/python/src/space/core/views.py
@@ -155,6 +155,10 @@ def filter(self,
             fn: Callable,
             input_fields: Optional[List[str]] = None) -> View:
     """Filter rows by the provided user defined function.
+
+    TODO: this filter is not applied to the deleted rows returned by diff(),
+    so it may return more rows than expected. This does not affect correctness
+    when syncing deletions to the target MV, as those rows don't exist there.
 
     Args:
       fn: a user defined function on batches.
diff --git a/python/src/space/ray/ops/change_data.py b/python/src/space/ray/ops/change_data.py
index 5e2a841..3cfb02a 100644
--- a/python/src/space/ray/ops/change_data.py
+++ b/python/src/space/ray/ops/change_data.py
@@ -14,6 +14,7 @@
 #
 """Change data feed that computes delta between two snapshots by Ray."""
 
+import math
 from typing import Iterable, Iterator
 
 import ray
@@ -55,16 +56,42 @@ def __iter__(self) -> Iterator[ChangeData]:
     # deletions and additions, it may delete newly added data.
     # TODO: to enforce this check upstream, or merge deletion+addition as a
     # update.
-    yield ChangeData(self._snapshot_id, ChangeType.DELETE,
-                     self._ray_dataset(self._change_log.deleted_rows))
-    yield ChangeData(self._snapshot_id, ChangeType.ADD,
-                     self._ray_dataset(self._change_log.added_rows))
+    if self._change_log.deleted_rows:
+      # Only read primary keys for deletions. The data to read is relatively
+      # small. In addition, deletion currently has to aggregate the primary
+      # keys to delete (two sets of keys to delete can't be parallelized), so
+      # we don't split it into parallel read streams.
+      ds = self._ray_dataset(self._change_log.deleted_rows,
+                             self._pk_only_read_option,
+                             self._ray_options.max_parallelism)
+      yield ChangeData(self._snapshot_id, ChangeType.DELETE, [ds])
 
-  def _ray_dataset(self, bitmaps: Iterable[meta.RowBitmap]) -> ray.data.Dataset:
-    return ray.data.read_datasource(
-        ray_data_sources.SpaceDataSource(),
-        storage=self._storage,
-        ray_options=self._ray_options,
-        read_options=self._read_options,
-        file_set=self._bitmaps_to_file_set(bitmaps),
-        parallelism=self._ray_options.max_parallelism)
+    if self._change_log.added_rows:
+      # Split added data into parallel read streams.
+      num_files = len(self._change_log.added_rows)
+      num_streams = self._ray_options.max_parallelism
+      shard_size = math.ceil(num_files / num_streams)
+
+      shards = []
+      for i in range(num_streams):
+        start = i * shard_size
+        end = min((i + 1) * shard_size, num_files)
+        shards.append(self._change_log.added_rows[start:end])
+
+      # Parallelism 1 means one reader for each read stream.
+      # There are `ray_options.max_parallelism` read streams.
+      # TODO: to measure performance and adjust.
+      yield ChangeData(self._snapshot_id, ChangeType.ADD, [
+          self._ray_dataset(s, self._read_options, parallelism=1)
+          for s in shards
+      ])
+
+  def _ray_dataset(self, bitmaps: Iterable[meta.RowBitmap],
+                   read_options: ReadOptions,
+                   parallelism: int) -> ray.data.Dataset:
+    return ray.data.read_datasource(ray_data_sources.SpaceDataSource(),
+                                    storage=self._storage,
+                                    ray_options=self._ray_options,
+                                    read_options=read_options,
+                                    file_set=self._bitmaps_to_file_set(bitmaps),
+                                    parallelism=parallelism)
diff --git a/python/src/space/ray/runners.py b/python/src/space/ray/runners.py
index b5aecc6..f1021e5 100644
--- a/python/src/space/ray/runners.py
+++ b/python/src/space/ray/runners.py
@@ -16,11 +16,13 @@
 from __future__ import annotations
 
 import copy
+from functools import partial
 from typing import TYPE_CHECKING
 from typing import Iterator, List, Optional, Union
 
 import pyarrow as pa
 import pyarrow.compute as pc
+import ray
 
 from space.core.jobs import JobResult
 from space.core.loaders.array_record import ArrayRecordIndexFn
@@ -28,7 +30,6 @@
 from space.core.runners import StorageMixin
 from space.core.ops import utils
 from space.core.ops.utils import FileOptions
-from space.core.ops.append import LocalAppendOp
 from space.core.ops.base import InputData, InputIteratorFn
 from space.core.ops.change_data import ChangeData, ChangeType
 from space.core.ops.delete import FileSetDeleteOp
@@ -89,6 +90,18 @@ def diff(self,
            start_version: Union[Version],
            end_version: Union[Version],
            batch_size: Optional[int] = None) -> Iterator[ChangeData]:
+    for change in self.diff_ray(start_version, end_version, batch_size):
+      assert isinstance(change.data, list)
+      for ds in change.data:
+        assert isinstance(ds, ray.data.Dataset)
+        for data in iter_batches(ds):
+          yield ChangeData(change.snapshot_id, change.type_, data)
+
+  def diff_ray(self,
+               start_version: Union[Version],
+               end_version: Union[Version],
+               batch_size: Optional[int] = None) -> Iterator[ChangeData]:
+    """Return diff data in the form of a list of Ray datasets."""
     self._source_storage.reload()
     source_changes = read_change_data(
         self._source_storage,
@@ -97,11 +110,22 @@ def diff(self,
         self._source_storage.version_to_snapshot_id(start_version),
         self._source_storage.version_to_snapshot_id(end_version),
         self._ray_options, ReadOptions(batch_size=batch_size))
     for change in source_changes:
-      # TODO: skip processing the data for deletions; the caller is usually
-      # only interested at deleted primary keys.
-      # TODO: to split change data into chunks for parallel processing.
-      for data in iter_batches(self._view.process_source(change.data)):
-        yield ChangeData(change.snapshot_id, change.type_, data)
+      if change.type_ == ChangeType.DELETE:
+        yield change
+
+      elif change.type_ == ChangeType.ADD:
+        # Change data is a list of Ray datasets because of parallel read
+        # streams, which lets us run the transforms in parallel here.
+        assert isinstance(change.data, list)
+        processed_data: List[ray.data.Dataset] = []
+        for ds in change.data:
+          assert isinstance(ds, ray.data.Dataset)
+          processed_data.append(self._view.process_source(ds))
+
+        yield ChangeData(change.snapshot_id, change.type_, processed_data)
+
+      else:
+        raise NotImplementedError(f"Change type {change.type_} not supported")
 
   @property
   def _source_storage(self) -> Storage:
@@ -170,7 +194,9 @@ def refresh(self,
     previous_snapshot_id: Optional[int] = None
 
     txn = self._start_txn()
-    for change in self.diff(start_snapshot_id, end_snapshot_id, batch_size):
+    for change in self.diff_ray(start_snapshot_id, end_snapshot_id, batch_size):
+      assert isinstance(change.data, list)
+
       # Commit when changes from the same snapshot end.
       if (previous_snapshot_id is not None
           and change.snapshot_id != previous_snapshot_id):
@@ -206,8 +232,13 @@ def refresh(self,
 
     return job_results
 
-  def _process_delete(self, data: pa.Table) -> Optional[rt.Patch]:
-    filter_ = utils.primary_key_filter(self._storage.primary_keys, data)
+  def _process_delete(self, data: List[ray.data.Dataset]) -> Optional[rt.Patch]:
+    # Deletion does not use parallel read streams.
+    assert len(data) == 1
+    arrow_data = pa.concat_tables(iter_batches(
+        data[0]))  # type: ignore[arg-type]
+
+    filter_ = utils.primary_key_filter(self._storage.primary_keys, arrow_data)
     if filter_ is None:
       return None
 
@@ -216,12 +247,10 @@ def _process_delete(self, data: pa.Table) -> Optional[rt.Patch]:
                          self._file_options)
     return op.delete()
 
-  def _process_append(self, data: pa.Table) -> Optional[rt.Patch]:
-    # TODO: to use RayAppendOp.
-    op = LocalAppendOp(self._storage.location, self._storage.metadata,
-                       self._file_options)
-    op.write(data)
-    return op.finish()
+  def _process_append(self, data: List[ray.data.Dataset]) -> Optional[rt.Patch]:
+    return _append_from(self._storage,
+                        [partial(iter_batches, ds) for ds in data],
+                        self._ray_options, self._file_options)
 
   def _start_txn(self) -> Transaction:
     with self._storage.transaction() as txn:
@@ -257,11 +286,8 @@ def append_from(
     ray_options.max_parallelism = min(len(source_fns),
                                       ray_options.max_parallelism)
 
-    op = RayAppendOp(self._storage.location, self._storage.metadata,
-                     ray_options, self._file_options)
-    op.write_from(source_fns)
-
-    return op.finish()
+    return _append_from(self._storage, source_fns, ray_options,
+                        self._file_options)
 
   @StorageMixin.transactional
   def append_array_record(self, pattern: str,
@@ -284,3 +310,12 @@ def _insert(self, data: InputData,
   def delete(self, filter_: pc.Expression) -> Optional[rt.Patch]:
     op = RayDeleteOp(self._storage, filter_, self._file_options)
     return op.delete()
+
+
+def _append_from(storage: Storage, source_fns: List[InputIteratorFn],
+                 ray_options: RayOptions,
+                 file_options: FileOptions) -> Optional[rt.Patch]:
+  op = RayAppendOp(storage.location, storage.metadata, ray_options,
+                   file_options)
+  op.write_from(source_fns)
+  return op.finish()
diff --git a/python/tests/core/ops/test_change_data.py b/python/tests/core/ops/test_change_data.py
index 766069f..4e79e67 100644
--- a/python/tests/core/ops/test_change_data.py
+++ b/python/tests/core/ops/test_change_data.py
@@ -45,14 +45,9 @@ def test_read_change_data(tmp_path, all_types_schema, all_types_input_data):
   runner.delete((pc.field("string") == "a") | (pc.field("string") == "A"))
   changes = list(runner.diff(1, 2))
   assert len(changes) == 1
-  expected_change1 = ChangeData(
-      ds.storage.metadata.current_snapshot_id, ChangeType.DELETE,
-      pa.Table.from_pydict({
-          "int64": [1, 0],
-          "float64": [0.1, -0.1],
-          "bool": [True, False],
-          "string": ["a", "A"]
-      }))
+  expected_change1 = ChangeData(ds.storage.metadata.current_snapshot_id,
+                                ChangeType.DELETE,
+                                pa.Table.from_pydict({"int64": [1, 0]}))
   assert changes[0] == expected_change1
 
   # Validate Upsert operation's changes.
@@ -65,14 +60,9 @@ def test_read_change_data(tmp_path, all_types_schema, all_types_input_data):
   runner.upsert(upsert_data)
   changes = list(runner.diff(2, 3))
   assert len(changes) == 2
-  expected_change2 = ChangeData(
-      ds.storage.metadata.current_snapshot_id, ChangeType.DELETE,
-      pa.Table.from_pydict({
-          "int64": [2, 3],
-          "float64": [0.2, 0.3],
-          "bool": [False, False],
-          "string": ["b", "c"]
-      }))
+  expected_change2 = ChangeData(ds.storage.metadata.current_snapshot_id,
+                                ChangeType.DELETE,
+                                pa.Table.from_pydict({"int64": [2, 3]}))
 
   expected_change3 = ChangeData(ds.storage.metadata.current_snapshot_id,
                                 ChangeType.ADD,
                                 pa.Table.from_pydict(upsert_data))
diff --git a/python/tests/ray/test_runners.py b/python/tests/ray/test_runners.py
index 0dd4dbc..5fa3253 100644
--- a/python/tests/ray/test_runners.py
+++ b/python/tests/ray/test_runners.py
@@ -208,12 +208,9 @@ def test_diff_map_batches(self, tmp_path, sample_dataset, refresh_batch_size):
 
     # Test deletion.
     ds_runner.delete(pc.field("int64") == 2)
-    expected_change1 = ChangeData(
-        ds.storage.metadata.current_snapshot_id, ChangeType.DELETE,
-        pa.Table.from_pydict({
-            "int64": [2],
-            "float64": [1.2]
-        }))
+    expected_change1 = ChangeData(ds.storage.metadata.current_snapshot_id,
+                                  ChangeType.DELETE,
+                                  pa.Table.from_pydict({"int64": [2]}))
     assert list(view_runner.diff(1, 2)) == [expected_change1]
 
     # Test that diff supports tags.
@@ -257,13 +254,7 @@ def test_diff_map_batches(self, tmp_path, sample_dataset, refresh_batch_size):
       "binary": [b"b3", b"b4"]
     })
     assert list(ds_runner.diff(2, 3)) == [
-        ChangeData(
-            3, ChangeType.DELETE,
-            pa.Table.from_pydict({
-                "int64": [3],
-                "float64": [0.3],
-                "binary": [b"b3"]
-            })),
+        ChangeData(3, ChangeType.DELETE, pa.Table.from_pydict({"int64": [3]})),
         ChangeData(
             3, ChangeType.ADD,
             pa.Table.from_pydict({
@@ -317,12 +308,8 @@ def test_diff_batch_size(self, tmp_path, sample_dataset):
     ray_runner.refresh(batch_size=2)
     assert list(ray_runner.read()) == [
         pa.Table.from_pydict({
-            "int64": [1, 2],
-            "float64": [1.1, 1.2],
-        }),
-        pa.Table.from_pydict({
-            "int64": [3],
-            "float64": [1.3],
+            "int64": [1, 2, 3],
+            "float64": [1.1, 1.2, 1.3],
         })
     ]
@@ -356,10 +343,7 @@ def _sample_filter_udf(row: Dict[str, Any]) -> Dict[str, Any]:
     ds_runner.delete(pc.field("int64") == 2)
     expected_change1 = ChangeData(
         sample_dataset.storage.metadata.current_snapshot_id, ChangeType.DELETE,
-        pa.Table.from_pydict({
-            "int64": [2],
-            "float64": [0.2]
-        }))
+        pa.Table.from_pydict({"int64": [2]}))
     assert list(view_runner.diff(1, 2)) == [expected_change1]
 
     # Test several changes.
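
A sketch for reference, not part of the patch itself: the parallel-read-stream
split in RayChangeDataReadOp.__iter__ above reduces to the following standalone
Python, with illustrative names. As in the patch, math.ceil sizing means
trailing shards may be empty whenever there are fewer added files than
`ray_options.max_parallelism` read streams.

import math
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def split_into_shards(items: Sequence[T], num_streams: int) -> List[List[T]]:
  """Split `items` into at most `num_streams` contiguous shards."""
  shard_size = math.ceil(len(items) / num_streams)
  shards = []
  for i in range(num_streams):
    start = i * shard_size
    end = min((i + 1) * shard_size, len(items))
    shards.append(list(items[start:end]))
  return shards


# Five added files across two read streams:
assert split_into_shards(list(range(5)), 2) == [[0, 1, 2], [3, 4]]
# More streams than files leaves trailing shards empty:
assert split_into_shards(list(range(2)), 4) == [[0], [1], [], []]

Each shard then becomes one Ray dataset created with parallelism=1 (one reader
per stream), and the diff() wrapper above shows how a caller consumes the
resulting list of datasets batch by batch.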