diff --git a/multiscale_spatial_image/__init__.py b/multiscale_spatial_image/__init__.py index e002746..e334b74 100644 --- a/multiscale_spatial_image/__init__.py +++ b/multiscale_spatial_image/__init__.py @@ -7,9 +7,11 @@ "Methods", "to_multiscale", "itk_image_to_multiscale", + "skip_non_dimension_nodes", "__version__", ] from .__about__ import __version__ from .multiscale_spatial_image import MultiscaleSpatialImage from .to_multiscale import Methods, to_multiscale, itk_image_to_multiscale +from .utils import skip_non_dimension_nodes diff --git a/multiscale_spatial_image/utils.py b/multiscale_spatial_image/utils.py new file mode 100644 index 0000000..d9a8d9e --- /dev/null +++ b/multiscale_spatial_image/utils.py @@ -0,0 +1,24 @@ +from typing import Callable, Any +from xarray import Dataset +import functools + + +def skip_non_dimension_nodes( + func: Callable[[Dataset], Dataset], +) -> Callable[[Dataset], Dataset]: + """Skip nodes in Datatree that do not contain dimensions. + + This function implements the workaround of https://github.com/pydata/xarray/issues/9693. In particular, + we need this because of our DataTree representing multiscale image having a root node that does not have + dimensions. Several functions need to be mapped over the datasets in the datatree that depend on having + dimensions, e.g. a transpose. + """ + + @functools.wraps(func) + def _func(ds: Dataset, *args: Any, **kwargs: Any) -> Dataset: + # check if dimensions are present otherwise return verbatim + if len(ds.dims) == 0: + return ds + return func(ds, *args, **kwargs) + + return _func