Skip to content

Commit

Permalink
add decorator skipping nodes without dimensions
Browse files Browse the repository at this point in the history
  • Loading branch information
melonora committed Nov 8, 2024
1 parent f3f6c03 commit b43074a
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 0 deletions.
2 changes: 2 additions & 0 deletions multiscale_spatial_image/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,11 @@
"Methods",
"to_multiscale",
"itk_image_to_multiscale",
"skip_non_dimension_nodes",
"__version__",
]

from .__about__ import __version__
from .multiscale_spatial_image import MultiscaleSpatialImage
from .to_multiscale import Methods, to_multiscale, itk_image_to_multiscale
from .utils import skip_non_dimension_nodes
24 changes: 24 additions & 0 deletions multiscale_spatial_image/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from typing import Callable, Any
from xarray import Dataset
import functools


def skip_non_dimension_nodes(
func: Callable[[Dataset], Dataset],
) -> Callable[[Dataset], Dataset]:
"""Skip nodes in Datatree that do not contain dimensions.
This function implements the workaround of https://github.com/pydata/xarray/issues/9693. In particular,
we need this because of our DataTree representing multiscale image having a root node that does not have
dimensions. Several functions need to be mapped over the datasets in the datatree that depend on having
dimensions, e.g. a transpose.
"""

@functools.wraps(func)
def _func(ds: Dataset, *args: Any, **kwargs: Any) -> Dataset:
# check if dimensions are present otherwise return verbatim
if len(ds.dims) == 0:
return ds
return func(ds, *args, **kwargs)

return _func

0 comments on commit b43074a

Please sign in to comment.