From f2cc1090f41d81e1c38f650b57a5c87f5a23171d Mon Sep 17 00:00:00 2001 From: Thomas Zilio Date: Wed, 20 Nov 2024 18:33:39 +0100 Subject: [PATCH] fix: Include dtype information in the Partitioning configuration. --- zcollection/merging/tests/test_merging.py | 13 ++++--- zcollection/partitioning/abc.py | 4 +- zcollection/partitioning/date.py | 7 ++++ zcollection/partitioning/tests/test_date.py | 38 ++++++++++++++----- .../partitioning/tests/test_sequence.py | 26 +++++++++---- 5 files changed, 64 insertions(+), 24 deletions(-) diff --git a/zcollection/merging/tests/test_merging.py b/zcollection/merging/tests/test_merging.py index a8dc080..1e535c9 100644 --- a/zcollection/merging/tests/test_merging.py +++ b/zcollection/merging/tests/test_merging.py @@ -45,12 +45,13 @@ def test_update_fs( """Test the _update_fs function.""" generator = data.create_test_dataset(delayed=False) zds = next(generator) + zds_sc = dask_client.scatter(zds) partition_folder = local_fs.root.joinpath('variable=1') zattrs = str(partition_folder.joinpath('.zattrs')) - future = dask_client.submit(_update_fs, str(partition_folder), - dask_client.scatter(zds), local_fs.fs) + future = dask_client.submit(_update_fs, str(partition_folder), zds_sc, + local_fs.fs) dask_client.gather(future) assert local_fs.exists(zattrs) @@ -60,7 +61,7 @@ def test_update_fs( try: future = dask_client.submit(_update_fs, str(partition_folder), - dask_client.scatter(zds), + zds_sc, local_fs.fs, synchronizer=ThrowError()) dask_client.gather(future) @@ -83,13 +84,13 @@ def test_perform( zds = next(generator) path = str(local_fs.root.joinpath('variable=1')) + zds_sc = dask_client.scatter(zds) - future = dask_client.submit(_update_fs, path, dask_client.scatter(zds), - local_fs.fs) + future = dask_client.submit(_update_fs, path, zds_sc, local_fs.fs) dask_client.gather(future) future = dask_client.submit(perform, - dask_client.scatter(zds), + zds_sc, path, 'time', local_fs.fs, diff --git a/zcollection/partitioning/abc.py 
b/zcollection/partitioning/abc.py index 64be840..8c1dc2c 100644 --- a/zcollection/partitioning/abc.py +++ b/zcollection/partitioning/abc.py @@ -106,7 +106,6 @@ def unique_and_check_monotony(arr: ArrayLike) -> tuple[NDArray, NDArray]: Args: arr: Array of elements. - is_delayed: If True, the array is delayed. Returns: Tuple of unique elements and their indices. """ @@ -331,12 +330,13 @@ def get_config(self) -> dict[str, Any]: Returns: The configuration of the partitioning scheme. """ - config: dict[str, str | None] = {'id': self.ID} + config: dict[str, str | tuple[str, ...] | None] = {'id': self.ID} slots: Generator[tuple[str, ...]] = (getattr( _class, '__slots__', ()) for _class in reversed(self.__class__.__mro__)) config.update((attr, getattr(self, attr)) for _class in slots for attr in _class if not attr.startswith('_')) + config['dtype'] = self._dtype return config @classmethod diff --git a/zcollection/partitioning/date.py b/zcollection/partitioning/date.py index dfde623..1e72d8e 100644 --- a/zcollection/partitioning/date.py +++ b/zcollection/partitioning/date.py @@ -217,3 +217,10 @@ def decode( py_datetime: datetime.datetime = datetime64.astype('M8[s]').item() return tuple((UNITS[ix], getattr(py_datetime, self._attrs[ix])) for ix in self._index) + + def get_config(self) -> dict[str, Any]: + config = super().get_config() + + # The dtype is automatically computed by this partitioning + config.pop('dtype') + return config diff --git a/zcollection/partitioning/tests/test_date.py b/zcollection/partitioning/tests/test_date.py index a35b93a..4a1f54f 100644 --- a/zcollection/partitioning/tests/test_date.py +++ b/zcollection/partitioning/tests/test_date.py @@ -164,23 +164,43 @@ def test_construction() -> None: Date(('dates', ), 'W') -def test_config(): +RESOLUTION_DTYPE_TEST_SET = [ + ('Y', (('year', 'uint16'), )), + ('M', (('year', 'uint16'), ('month', 'uint8'))), + ('D', (('year', 'uint16'), ('month', 'uint8'), ('day', 'uint8'))), + ('h', (('year', 'uint16'), ('month', 
'uint8'), ('day', 'uint8'), + ('hour', 'uint8'))), + ('m', (('year', 'uint16'), ('month', 'uint8'), ('day', 'uint8'), + ('hour', 'uint8'), ('minute', 'uint8'))), + ('s', (('year', 'uint16'), ('month', 'uint8'), ('day', 'uint8'), + ('hour', 'uint8'), ('minute', 'uint8'), ('second', 'uint8'))) +] + + +@pytest.mark.parametrize('resolution, dtype', RESOLUTION_DTYPE_TEST_SET) +def test_config(resolution, dtype): """Test the configuration of the Date class.""" - partitioning = Date(('dates', ), 'D') - assert partitioning.dtype() == (('year', 'uint16'), ('month', 'uint8'), - ('day', 'uint8')) + partitioning = Date(variables=('dates', ), resolution=resolution) + assert partitioning.dtype() == dtype + config = partitioning.get_config() - partitioning = get_codecs(config) - assert isinstance(partitioning, Date) + other = get_codecs(config) + + assert isinstance(other, Date) + assert other.variables == ('dates', ) + assert other.dtype() == dtype -def test_pickle(): +@pytest.mark.parametrize('resolution, dtype', RESOLUTION_DTYPE_TEST_SET) +def test_pickle(resolution, dtype): """Test the pickling of the Date class.""" - partitioning = Date(('dates', ), 'D') + partitioning = Date(('dates', ), resolution=resolution) other = pickle.loads(pickle.dumps(partitioning)) + assert isinstance(other, Date) - assert other.resolution == 'D' + assert other.resolution == resolution assert other.variables == ('dates', ) + assert other.dtype() == dtype @pytest.mark.parametrize('delayed', [False, True]) diff --git a/zcollection/partitioning/tests/test_sequence.py b/zcollection/partitioning/tests/test_sequence.py index 3a56e9d..7905898 100644 --- a/zcollection/partitioning/tests/test_sequence.py +++ b/zcollection/partitioning/tests/test_sequence.py @@ -113,20 +113,32 @@ def test_split_dataset( list(partitioning.split_dataset(zds, 'num_lines')) -def test_config() -> None: +VARIABLES_DTYPE_TEST_SET = [(('a', ), None), (('a', ), ('uint8', )), + (('a', 'b'), None), + (('a', 'b'), ('int8', 'int16'))] + 
+ +@pytest.mark.parametrize('variables, dtype', VARIABLES_DTYPE_TEST_SET) +def test_config(variables, dtype) -> None: """Test the configuration of the Sequence class.""" - partitioning = Sequence(('cycle_number', 'pass_number')) + partitioning = Sequence(variables=variables, dtype=dtype) + config = partitioning.get_config() - partitioning = get_codecs(config) # type: ignore[assignment] - assert isinstance(partitioning, Sequence) + other = get_codecs(config) # type: ignore[assignment] + assert isinstance(other, Sequence) + assert other.dtype() == partitioning.dtype() -def test_pickle() -> None: + +@pytest.mark.parametrize('variables, dtype', VARIABLES_DTYPE_TEST_SET) +def test_pickle(variables, dtype) -> None: """Test the pickling of the Date class.""" - partitioning = Sequence(('cycle_number', 'pass_number')) + partitioning = Sequence(variables=variables, dtype=dtype) + other = pickle.loads(pickle.dumps(partitioning)) + assert isinstance(other, Sequence) - assert other.variables == ('cycle_number', 'pass_number') + assert other.dtype() == partitioning.dtype() # pylint: disable=protected-access