Skip to content

Commit

Permalink
Merge pull request #81 from giovp/giovp/datatree_accessor
Browse files Browse the repository at this point in the history
datatree accessor to multiscale spatial image
  • Loading branch information
thewtex authored Oct 17, 2023
2 parents bc66a50 + 7388a4c commit 0f3d295
Show file tree
Hide file tree
Showing 6 changed files with 225 additions and 170 deletions.
215 changes: 112 additions & 103 deletions examples/ConvertPyImageJDataset.ipynb

Large diffs are not rendered by default.

118 changes: 78 additions & 40 deletions examples/ConvertTiffFile.ipynb

Large diffs are not rendered by default.

43 changes: 24 additions & 19 deletions multiscale_spatial_image/multiscale_spatial_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,12 @@
from collections.abc import MutableMapping
from pathlib import Path
from zarr.storage import BaseStore
import xarray as xr
from datatree import register_datatree_accessor


class MultiscaleSpatialImage(DataTree):
@register_datatree_accessor("msi")
class MultiscaleSpatialImage:
"""A multi-scale representation of a spatial image.
This is an xarray DataTree, with content compatible with the Open Microscopy Environment-
Expand All @@ -24,17 +27,16 @@ class MultiscaleSpatialImage(DataTree):
scale2
"""

def __init__(
def __init__(self, xarray_obj: DataTree):
self._dt = xarray_obj

def to_zarr(
self,
name: str = "multiscales",
data: Union[xr.Dataset, xr.DataArray] = None,
parent: TreeNode = None,
children: List[TreeNode] = None,
store: Union[MutableMapping, str, Path, BaseStore],
mode: str = "w",
encoding=None,
**kwargs,
):
"""DataTree with a root name of *multiscales*."""
super().__init__(data=data, name=name, parent=parent, children=children)

def to_zarr(self, store: Union[MutableMapping, str, Path, BaseStore], mode: str = "w", encoding=None, **kwargs):
"""
Write multi-scale spatial image contents to a Zarr store.
Expand All @@ -57,29 +59,32 @@ def to_zarr(self, store: Union[MutableMapping, str, Path, BaseStore], mode: str
"""

multiscales = []
scale0 = self[self.groups[1]]
scale0 = self._dt[self._dt.groups[1]]
for name in scale0.ds.data_vars.keys():

ngff_datasets = []
for child in self.children:
image = self[child].ds
for child in self._dt.children:
image = self._dt[child].ds
scale_transform = []
translate_transform = []
for dim in image.dims:
if len(image.coords[dim]) > 1 and np.issubdtype(image.coords[dim].dtype, np.number):
if len(image.coords[dim]) > 1 and np.issubdtype(
image.coords[dim].dtype, np.number
):
scale_transform.append(
float(image.coords[dim][1] - image.coords[dim][0])
)
else:
scale_transform.append(1.0)
if len(image.coords[dim]) > 0 and np.issubdtype(image.coords[dim].dtype, np.number):
if len(image.coords[dim]) > 0 and np.issubdtype(
image.coords[dim].dtype, np.number
):
translate_transform.append(float(image.coords[dim][0]))
else:
translate_transform.append(0.0)

ngff_datasets.append(
{
"path": f"{self[child].name}/{name}",
"path": f"{self._dt[child].name}/{name}",
"coordinateTransformations": [
{
"type": "scale",
Expand Down Expand Up @@ -117,6 +122,6 @@ def to_zarr(self, store: Union[MutableMapping, str, Path, BaseStore], mode: str

# NGFF v0.4 metadata
ngff_metadata = {"multiscales": multiscales, "multiscaleSpatialImageVersion": 1}
self.ds = self.ds.assign_attrs(**ngff_metadata)
self._dt.ds = self._dt.ds.assign_attrs(**ngff_metadata)

super().to_zarr(store, **kwargs)
self._dt.to_zarr(store, **kwargs)
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from spatial_image import to_spatial_image

from .to_multiscale import to_multiscale, Methods
from ..multiscale_spatial_image import MultiscaleSpatialImage
from datatree import DataTree

def itk_image_to_multiscale(
image,
Expand All @@ -20,7 +20,7 @@ def itk_image_to_multiscale(
Tuple[Tuple[int, ...], ...],
Mapping[Any, Union[None, int, Tuple[int, ...]]],
]
] = None) -> MultiscaleSpatialImage:
] = None) -> DataTree:

import itk
import numpy as np
Expand Down
8 changes: 4 additions & 4 deletions multiscale_spatial_image/to_multiscale/to_multiscale.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from dask.array import map_blocks, map_overlap
import numpy as np

from ..multiscale_spatial_image import MultiscaleSpatialImage
from datatree import DataTree

from ._xarray import _downsample_xarray_coarsen
from ._itk import _downsample_itk_bin_shrink, _downsample_itk_gaussian, _downsample_itk_label
Expand Down Expand Up @@ -35,7 +35,7 @@ def to_multiscale(
Mapping[Any, Union[None, int, Tuple[int, ...]]],
]
] = None,
) -> MultiscaleSpatialImage:
) -> DataTree:
"""\
Generate a multiscale representation of a spatial image.
Expand Down Expand Up @@ -67,7 +67,7 @@ def to_multiscale(
Returns
-------
result : MultiscaleSpatialImage
result : DataTree
Multiscale representation. An xarray DataTree where each node is a SpatialImage Dataset
named by the integer scale. Increasing scales are downscaled versions of the input image.
"""
Expand Down Expand Up @@ -120,7 +120,7 @@ def to_multiscale(
elif method is Methods.DASK_IMAGE_MODE:
data_objects = _downsample_dask_image(current_input, default_chunks, out_chunks, scale_factors, data_objects, image, label='mode')

multiscale = MultiscaleSpatialImage.from_dict(
multiscale = DataTree.from_dict(
d=data_objects
)

Expand Down
7 changes: 5 additions & 2 deletions test/test_ngff_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
from referencing.jsonschema import DRAFT202012
from jsonschema import Draft202012Validator

from datatree import DataTree

from multiscale_spatial_image import to_multiscale, MultiscaleSpatialImage
from spatial_image import to_spatial_image
import numpy as np
Expand All @@ -23,9 +25,10 @@ def load_schema(version: str = "0.4", strict: bool = False) -> Dict:
schema = json.loads(response.data.decode())
return schema

def check_valid_ngff(multiscale: MultiscaleSpatialImage):
def check_valid_ngff(multiscale: DataTree):
store = zarr.storage.MemoryStore(dimension_separator="/")
multiscale.to_zarr(store, compute=True)
assert isinstance(multiscale.msi, MultiscaleSpatialImage)
multiscale.msi.to_zarr(store, compute=True)
zarr.convenience.consolidate_metadata(store)
metadata = json.loads(store.get(".zmetadata"))["metadata"]
ngff = metadata[".zattrs"]
Expand Down

0 comments on commit 0f3d295

Please sign in to comment.