From b43074a3288a1242ffde7bba8f7c1141890fd3f7 Mon Sep 17 00:00:00 2001 From: Wouter-Michiel Vierdag Date: Fri, 8 Nov 2024 17:56:50 +0100 Subject: [PATCH 1/3] add decorator skipping nodes without dimensions --- multiscale_spatial_image/__init__.py | 2 ++ multiscale_spatial_image/utils.py | 24 ++++++++++++++++++++++++ 2 files changed, 26 insertions(+) create mode 100644 multiscale_spatial_image/utils.py diff --git a/multiscale_spatial_image/__init__.py b/multiscale_spatial_image/__init__.py index e002746..e334b74 100644 --- a/multiscale_spatial_image/__init__.py +++ b/multiscale_spatial_image/__init__.py @@ -7,9 +7,11 @@ "Methods", "to_multiscale", "itk_image_to_multiscale", + "skip_non_dimension_nodes", "__version__", ] from .__about__ import __version__ from .multiscale_spatial_image import MultiscaleSpatialImage from .to_multiscale import Methods, to_multiscale, itk_image_to_multiscale +from .utils import skip_non_dimension_nodes diff --git a/multiscale_spatial_image/utils.py b/multiscale_spatial_image/utils.py new file mode 100644 index 0000000..d9a8d9e --- /dev/null +++ b/multiscale_spatial_image/utils.py @@ -0,0 +1,24 @@ +from typing import Callable, Any +from xarray import Dataset +import functools + + +def skip_non_dimension_nodes( + func: Callable[[Dataset], Dataset], +) -> Callable[[Dataset], Dataset]: + """Skip nodes in Datatree that do not contain dimensions. + + This function implements the workaround of https://github.com/pydata/xarray/issues/9693. In particular, + we need this because of our DataTree representing multiscale image having a root node that does not have + dimensions. Several functions need to be mapped over the datasets in the datatree that depend on having + dimensions, e.g. a transpose. + """ + + @functools.wraps(func) + def _func(ds: Dataset, *args: Any, **kwargs: Any) -> Dataset: + # check if dimensions are present otherwise return verbatim + if len(ds.dims) == 0: + return ds + return func(ds, *args, **kwargs) + + return _func From 066edd2295d251c4f3559e7dbf17c708106bf0ab Mon Sep 17 00:00:00 2001 From: Wouter-Michiel Vierdag Date: Fri, 8 Nov 2024 23:21:37 +0100 Subject: [PATCH 2/3] add test for decorator --- test/test_utils.py | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) create mode 100644 test/test_utils.py diff --git a/test/test_utils.py b/test/test_utils.py new file mode 100644 index 0000000..57828c5 --- /dev/null +++ b/test/test_utils.py @@ -0,0 +1,24 @@ +from multiscale_spatial_image import skip_non_dimension_nodes +import numpy as np +from spatial_image import to_spatial_image +from multiscale_spatial_image import to_multiscale + + +def test_skip_nodes(): + data = np.zeros((2, 200, 200)) + dims = ("c", "y", "x") + scale_factors = [2, 2] + image = to_spatial_image(array_like=data, dims=dims) + multiscale_img = to_multiscale(image, scale_factors=scale_factors) + + @skip_non_dimension_nodes + def transpose(ds, *args, **kwargs): + return ds.transpose(*args, **kwargs) + + for scale in list(multiscale_img.keys()): + assert multiscale_img[scale]["image"].dims == ("c", "y", "x") + + # applying this function without skipping the root node would fail as the root node does not have dimensions. + result = multiscale_img.map_over_datasets(transpose, "y", "x", "c") + for scale in list(result.keys()): + assert result[scale]["image"].dims == ("y", "x", "c") From 5faadba15fb3b5b9615e3e466327e50df111b6bb Mon Sep 17 00:00:00 2001 From: Wouter-Michiel Vierdag Date: Fri, 8 Nov 2024 23:37:11 +0100 Subject: [PATCH 3/3] add documentation --- README.md | 52 ++++++++++++++++++++++++++++++++++++++++++++++ test/test_utils.py | 3 +-- 2 files changed, 53 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 7a27eae..a4df9d9 100644 --- a/README.md +++ b/README.md @@ -83,6 +83,58 @@ DataTree('multiscales', parent=None) image (y, x) uint8 dask.array ``` +Map a function over datasets while skipping nodes that do not contain dimensions + +```python +import numpy as np +from spatial_image import to_spatial_image +from multiscale_spatial_image import skip_non_dimension_nodes, to_multiscale + +data = np.zeros((2, 200, 200)) +dims = ("c", "y", "x") +scale_factors = [2, 2] +image = to_spatial_image(array_like=data, dims=dims) +multiscale = to_multiscale(image, scale_factors=scale_factors) + +@skip_non_dimension_nodes +def transpose(ds, *args, **kwargs): + return ds.transpose(*args, **kwargs) + +multiscale = multiscale.map_over_datasets(transpose, "y", "x", "c") +print(multiscale) +``` + +A transposed MultiscaleSpatialImage. + +``` + +Group: / +├── Group: /scale0 +│ Dimensions: (c: 2, y: 200, x: 200) +│ Coordinates: +│ * c (c) int32 8B 0 1 +│ * y (y) float64 2kB 0.0 1.0 2.0 3.0 4.0 ... 196.0 197.0 198.0 199.0 +│ * x (x) float64 2kB 0.0 1.0 2.0 3.0 4.0 ... 196.0 197.0 198.0 199.0 +│ Data variables: +│ image (y, x, c) float64 640kB dask.array +├── Group: /scale1 +│ Dimensions: (c: 2, y: 100, x: 100) +│ Coordinates: +│ * c (c) int32 8B 0 1 +│ * y (y) float64 800B 0.5 2.5 4.5 6.5 8.5 ... 192.5 194.5 196.5 198.5 +│ * x (x) float64 800B 0.5 2.5 4.5 6.5 8.5 ... 192.5 194.5 196.5 198.5 +│ Data variables: +│ image (y, x, c) float64 160kB dask.array +└── Group: /scale2 + Dimensions: (c: 2, y: 50, x: 50) + Coordinates: + * c (c) int32 8B 0 1 + * y (y) float64 400B 1.5 5.5 9.5 13.5 17.5 ... 185.5 189.5 193.5 197.5 + * x (x) float64 400B 1.5 5.5 9.5 13.5 17.5 ... 185.5 189.5 193.5 197.5 + Data variables: + image (y, x, c) float64 40kB dask.array +``` + Store as an Open Microscopy Environment-Next Generation File Format ([OME-NGFF]) / [netCDF] [Zarr] store. diff --git a/test/test_utils.py b/test/test_utils.py index 57828c5..173a76a 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -1,7 +1,6 @@ -from multiscale_spatial_image import skip_non_dimension_nodes import numpy as np from spatial_image import to_spatial_image -from multiscale_spatial_image import to_multiscale +from multiscale_spatial_image import skip_non_dimension_nodes, to_multiscale def test_skip_nodes():