Skip to content

Commit

Permalink
Merge pull request #101 from melonora/decorator
Browse files Browse the repository at this point in the history
Add decorator for skipping nodes without dimension
  • Loading branch information
melonora authored Nov 8, 2024
2 parents f3f6c03 + 5faadba commit c47c68c
Show file tree
Hide file tree
Showing 4 changed files with 101 additions and 0 deletions.
52 changes: 52 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,58 @@ DataTree('multiscales', parent=None)
image (y, x) uint8 dask.array<chunksize=(16, 16), meta=np.ndarray>
```

Map a function over the datasets of a multiscale image while skipping nodes that do not contain dimensions, such as the root node:

```python
import numpy as np
from spatial_image import to_spatial_image
from multiscale_spatial_image import skip_non_dimension_nodes, to_multiscale

# Build a three-level multiscale image from a (c, y, x) array of zeros.
image = to_spatial_image(array_like=np.zeros((2, 200, 200)), dims=("c", "y", "x"))
multiscale = to_multiscale(image, scale_factors=[2, 2])


# The decorator returns datasets without dimensions unchanged, so the wrapped
# function is only called on nodes that actually have dimensions to transpose.
@skip_non_dimension_nodes
def transpose(ds, *args, **kwargs):
    return ds.transpose(*args, **kwargs)


multiscale = multiscale.map_over_datasets(transpose, "y", "x", "c")
print(multiscale)
```

A transposed MultiscaleSpatialImage.

```
<xarray.DataTree>
Group: /
├── Group: /scale0
│ Dimensions: (c: 2, y: 200, x: 200)
│ Coordinates:
│ * c (c) int32 8B 0 1
│ * y (y) float64 2kB 0.0 1.0 2.0 3.0 4.0 ... 196.0 197.0 198.0 199.0
│ * x (x) float64 2kB 0.0 1.0 2.0 3.0 4.0 ... 196.0 197.0 198.0 199.0
│ Data variables:
│ image (y, x, c) float64 640kB dask.array<chunksize=(200, 200, 2), meta=np.ndarray>
├── Group: /scale1
│ Dimensions: (c: 2, y: 100, x: 100)
│ Coordinates:
│ * c (c) int32 8B 0 1
│ * y (y) float64 800B 0.5 2.5 4.5 6.5 8.5 ... 192.5 194.5 196.5 198.5
│ * x (x) float64 800B 0.5 2.5 4.5 6.5 8.5 ... 192.5 194.5 196.5 198.5
│ Data variables:
│ image (y, x, c) float64 160kB dask.array<chunksize=(100, 100, 2), meta=np.ndarray>
└── Group: /scale2
Dimensions: (c: 2, y: 50, x: 50)
Coordinates:
* c (c) int32 8B 0 1
* y (y) float64 400B 1.5 5.5 9.5 13.5 17.5 ... 185.5 189.5 193.5 197.5
* x (x) float64 400B 1.5 5.5 9.5 13.5 17.5 ... 185.5 189.5 193.5 197.5
Data variables:
image (y, x, c) float64 40kB dask.array<chunksize=(50, 50, 2), meta=np.ndarray>
```

Store as an Open Microscopy Environment-Next Generation File Format ([OME-NGFF])
/ [netCDF] [Zarr] store.

Expand Down
2 changes: 2 additions & 0 deletions multiscale_spatial_image/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,11 @@
"Methods",
"to_multiscale",
"itk_image_to_multiscale",
"skip_non_dimension_nodes",
"__version__",
]

from .__about__ import __version__
from .multiscale_spatial_image import MultiscaleSpatialImage
from .to_multiscale import Methods, to_multiscale, itk_image_to_multiscale
from .utils import skip_non_dimension_nodes
24 changes: 24 additions & 0 deletions multiscale_spatial_image/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from typing import Callable, Any
from xarray import Dataset
import functools


def skip_non_dimension_nodes(
    func: Callable[..., Dataset],
) -> Callable[..., Dataset]:
    """Skip nodes in a DataTree that do not contain dimensions.

    This decorator implements the workaround for
    https://github.com/pydata/xarray/issues/9693. It is needed because the
    DataTree representing a multiscale image has a root node that does not
    have dimensions, while several functions that are mapped over the datasets
    in the tree depend on dimensions being present, e.g. a transpose.

    Parameters
    ----------
    func
        Function that takes a dataset as its first argument, optionally
        followed by further positional and keyword arguments, and returns a
        dataset.

    Returns
    -------
    Wrapper around ``func`` that returns a dataset without dimensions
    unchanged and otherwise calls ``func`` with the dataset and all extra
    arguments.
    """

    @functools.wraps(func)
    def _func(ds: Dataset, *args: Any, **kwargs: Any) -> Dataset:
        # Nodes without dimensions are returned verbatim: ``func`` is never
        # called on them, so it may safely assume dimensions exist.
        if not ds.dims:
            return ds
        return func(ds, *args, **kwargs)

    return _func
23 changes: 23 additions & 0 deletions test/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import numpy as np
from spatial_image import to_spatial_image
from multiscale_spatial_image import skip_non_dimension_nodes, to_multiscale


def test_skip_nodes():
    """Check that a decorated function can be mapped over a multiscale image."""
    original_order = ("c", "y", "x")
    transposed_order = ("y", "x", "c")

    image = to_spatial_image(array_like=np.zeros((2, 200, 200)), dims=original_order)
    multiscale_img = to_multiscale(image, scale_factors=[2, 2])

    @skip_non_dimension_nodes
    def transpose(ds, *args, **kwargs):
        return ds.transpose(*args, **kwargs)

    # Every scale starts out in the order the image was created with.
    for scale in multiscale_img.keys():
        assert multiscale_img[scale]["image"].dims == original_order

    # Applying this function without skipping the root node would fail, as the
    # root node does not have dimensions.
    result = multiscale_img.map_over_datasets(transpose, *transposed_order)
    for scale in result.keys():
        assert result[scale]["image"].dims == transposed_order

0 comments on commit c47c68c

Please sign in to comment.