Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add decorator for skipping nodes without dimension #101

Merged
merged 3 commits into from
Nov 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 52 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,58 @@ DataTree('multiscales', parent=None)
image (y, x) uint8 dask.array<chunksize=(16, 16), meta=np.ndarray>
```

Map a function over datasets while skipping nodes that do not contain dimensions

```python
import numpy as np
from spatial_image import to_spatial_image
from multiscale_spatial_image import skip_non_dimension_nodes, to_multiscale

data = np.zeros((2, 200, 200))
dims = ("c", "y", "x")
scale_factors = [2, 2]
image = to_spatial_image(array_like=data, dims=dims)
multiscale = to_multiscale(image, scale_factors=scale_factors)

@skip_non_dimension_nodes
def transpose(ds, *args, **kwargs):
return ds.transpose(*args, **kwargs)

multiscale = multiscale.map_over_datasets(transpose, "y", "x", "c")
print(multiscale)
```

A transposed MultiscaleSpatialImage.

```
<xarray.DataTree>
Group: /
├── Group: /scale0
│ Dimensions: (c: 2, y: 200, x: 200)
│ Coordinates:
│ * c (c) int32 8B 0 1
│ * y (y) float64 2kB 0.0 1.0 2.0 3.0 4.0 ... 196.0 197.0 198.0 199.0
│ * x (x) float64 2kB 0.0 1.0 2.0 3.0 4.0 ... 196.0 197.0 198.0 199.0
│ Data variables:
│ image (y, x, c) float64 640kB dask.array<chunksize=(200, 200, 2), meta=np.ndarray>
├── Group: /scale1
│ Dimensions: (c: 2, y: 100, x: 100)
│ Coordinates:
│ * c (c) int32 8B 0 1
│ * y (y) float64 800B 0.5 2.5 4.5 6.5 8.5 ... 192.5 194.5 196.5 198.5
│ * x (x) float64 800B 0.5 2.5 4.5 6.5 8.5 ... 192.5 194.5 196.5 198.5
│ Data variables:
│ image (y, x, c) float64 160kB dask.array<chunksize=(100, 100, 2), meta=np.ndarray>
└── Group: /scale2
Dimensions: (c: 2, y: 50, x: 50)
Coordinates:
* c (c) int32 8B 0 1
* y (y) float64 400B 1.5 5.5 9.5 13.5 17.5 ... 185.5 189.5 193.5 197.5
* x (x) float64 400B 1.5 5.5 9.5 13.5 17.5 ... 185.5 189.5 193.5 197.5
Data variables:
image (y, x, c) float64 40kB dask.array<chunksize=(50, 50, 2), meta=np.ndarray>
```

Store as an Open Microscopy Environment-Next Generation File Format ([OME-NGFF])
/ [netCDF] [Zarr] store.

Expand Down
2 changes: 2 additions & 0 deletions multiscale_spatial_image/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,11 @@
"Methods",
"to_multiscale",
"itk_image_to_multiscale",
"skip_non_dimension_nodes",
"__version__",
]

from .__about__ import __version__
from .multiscale_spatial_image import MultiscaleSpatialImage
from .to_multiscale import Methods, to_multiscale, itk_image_to_multiscale
from .utils import skip_non_dimension_nodes
24 changes: 24 additions & 0 deletions multiscale_spatial_image/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from typing import Callable, Any
from xarray import Dataset
import functools


def skip_non_dimension_nodes(
func: Callable[[Dataset], Dataset],
) -> Callable[[Dataset], Dataset]:
"""Skip nodes in Datatree that do not contain dimensions.

This function implements the workaround of https://github.com/pydata/xarray/issues/9693. In particular,
we need this because of our DataTree representing multiscale image having a root node that does not have
dimensions. Several functions need to be mapped over the datasets in the datatree that depend on having
dimensions, e.g. a transpose.
"""

@functools.wraps(func)
def _func(ds: Dataset, *args: Any, **kwargs: Any) -> Dataset:
# check if dimensions are present otherwise return verbatim
if len(ds.dims) == 0:
return ds
return func(ds, *args, **kwargs)

return _func
23 changes: 23 additions & 0 deletions test/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import numpy as np
from spatial_image import to_spatial_image
from multiscale_spatial_image import skip_non_dimension_nodes, to_multiscale


def test_skip_nodes():
data = np.zeros((2, 200, 200))
dims = ("c", "y", "x")
scale_factors = [2, 2]
image = to_spatial_image(array_like=data, dims=dims)
multiscale_img = to_multiscale(image, scale_factors=scale_factors)

@skip_non_dimension_nodes
def transpose(ds, *args, **kwargs):
return ds.transpose(*args, **kwargs)

for scale in list(multiscale_img.keys()):
assert multiscale_img[scale]["image"].dims == ("c", "y", "x")

# applying this function without skipping the root node would fail as the root node does not have dimensions.
result = multiscale_img.map_over_datasets(transpose, "y", "x", "c")
for scale in list(result.keys()):
assert result[scale]["image"].dims == ("y", "x", "c")
Loading