From 4fee915ebba827e3421f5498d6d1ed996dd48d92 Mon Sep 17 00:00:00 2001 From: Joshua Gould Date: Fri, 12 Jul 2024 08:43:04 -0400 Subject: [PATCH 1/5] Support trailing channel axis for labeled_comprehension --- dask_image/ndmeasure/__init__.py | 10 ++++++---- dask_image/ndmeasure/_utils/__init__.py | 3 +-- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/dask_image/ndmeasure/__init__.py b/dask_image/ndmeasure/__init__.py index 5f4eaf48..061a2f1f 100644 --- a/dask_image/ndmeasure/__init__.py +++ b/dask_image/ndmeasure/__init__.py @@ -4,12 +4,12 @@ import functools import operator import warnings -from dask import compute, delayed import dask.array as da import dask.bag as db import dask.dataframe as dd import numpy as np +from dask import compute, delayed from . import _utils from ._utils import _label @@ -378,7 +378,6 @@ def label(image, structure=None, wrap_axes=None): relabeled = _label.relabel_blocks(block_labeled, new_labeling) n = da.max(relabeled) - return (relabeled, n) @@ -402,7 +401,9 @@ def labeled_comprehension(image, Parameters ---------- image : ndarray - N-D image data + Intensity image with same size as ``label_image``, plus optionally + an extra dimension for multichannel data. The extra channel dimension, + if present, must be the last axis. label_image : ndarray, optional Image features noted by integers. If None (default), all values. index : int or sequence of ints, optional @@ -448,7 +449,8 @@ def labeled_comprehension(image, result = np.empty(index.shape, dtype=object) for i in np.ndindex(index.shape): lbl_mtch_i = (label_image == index[i]) - args_lbl_mtch_i = tuple(e[lbl_mtch_i] for e in args) + args_lbl_mtch_i = tuple( + e[lbl_mtch_i] if e.ndim == 2 else e.reshape(-1, e.shape[2])[lbl_mtch_i.reshape(-1)] for e in args) result[i] = _utils._labeled_comprehension_func( func, out_dtype, default_1d, *args_lbl_mtch_i ) diff --git a/dask_image/ndmeasure/_utils/__init__.py b/dask_image/ndmeasure/_utils/__init__.py index bcb5b70f..33461cf3 100644 --- a/dask_image/ndmeasure/_utils/__init__.py +++ b/dask_image/ndmeasure/_utils/__init__.py @@ -30,8 +30,7 @@ def _norm_input_labels_index(image, label_image=None, index=None): "Having index with dimensionality greater than 1 is undefined.", FutureWarning ) - - if image.shape != label_image.shape: + if image.shape[:2] != label_image.shape: # allow trailing channel raise ValueError( "The image and label_image arrays must be the same shape." ) From 62b58701a43db219eebb02b7f3ae36ccff356964 Mon Sep 17 00:00:00 2001 From: Joshua Gould Date: Fri, 12 Jul 2024 08:56:47 -0400 Subject: [PATCH 2/5] fixed ndim check --- dask_image/ndmeasure/__init__.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/dask_image/ndmeasure/__init__.py b/dask_image/ndmeasure/__init__.py index 061a2f1f..d9498936 100644 --- a/dask_image/ndmeasure/__init__.py +++ b/dask_image/ndmeasure/__init__.py @@ -450,7 +450,8 @@ def labeled_comprehension(image, for i in np.ndindex(index.shape): lbl_mtch_i = (label_image == index[i]) args_lbl_mtch_i = tuple( - e[lbl_mtch_i] if e.ndim == 2 else e.reshape(-1, e.shape[2])[lbl_mtch_i.reshape(-1)] for e in args) + e[lbl_mtch_i] if e.ndim == lbl_mtch_i.ndim else e.reshape(-1, e.shape[-1])[lbl_mtch_i.reshape(-1)] for e in + args) result[i] = _utils._labeled_comprehension_func( func, out_dtype, default_1d, *args_lbl_mtch_i ) From 09aabef692507f93905866cbd31933ea94bd0402 Mon Sep 17 00:00:00 2001 From: Joshua Gould Date: Fri, 12 Jul 2024 08:57:22 -0400 Subject: [PATCH 3/5] fixed shape check --- dask_image/ndmeasure/_utils/__init__.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/dask_image/ndmeasure/_utils/__init__.py b/dask_image/ndmeasure/_utils/__init__.py index 33461cf3..29a4e11d 100644 --- a/dask_image/ndmeasure/_utils/__init__.py +++ b/dask_image/ndmeasure/_utils/__init__.py @@ -30,9 +30,11 @@ def _norm_input_labels_index(image, label_image=None, index=None): "Having index with dimensionality greater than 1 is undefined.", FutureWarning ) - if image.shape[:2] != label_image.shape: # allow trailing channel + + image_shape = image.shape if image.ndim == label_image.ndim else image.shape[:-1] + if image_shape != label_image.shape: # allow trailing channel raise ValueError( - "The image and label_image arrays must be the same shape." + f"The image and label_image arrays must be the same shape. {image_shape} != {label_image.shape}" ) return (image, label_image, index) From ea976ab98d722131623a824368dfac25e7da6330 Mon Sep 17 00:00:00 2001 From: Joshua Gould Date: Fri, 12 Jul 2024 13:47:12 -0400 Subject: [PATCH 4/5] add skip_trailing_dim to _ravel_shape_indices --- dask_image/ndmeasure/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dask_image/ndmeasure/__init__.py b/dask_image/ndmeasure/__init__.py index d9498936..396d6ebf 100644 --- a/dask_image/ndmeasure/__init__.py +++ b/dask_image/ndmeasure/__init__.py @@ -442,7 +442,7 @@ def labeled_comprehension(image, args = (image,) if pass_positions: positions = _utils._ravel_shape_indices( - image.shape, chunks=image.chunks + image.shape, chunks=image.chunks, skip_trailing_dim=image.ndim != label_image.ndim ) args = (image, positions) From 6c8a31b27adec4556cee3de990179c31770d48e9 Mon Sep 17 00:00:00 2001 From: Joshua Gould Date: Fri, 12 Jul 2024 13:53:31 -0400 Subject: [PATCH 5/5] add skip_trailing_dim to _ravel_shape_indices --- dask_image/ndmeasure/_utils/__init__.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/dask_image/ndmeasure/_utils/__init__.py b/dask_image/ndmeasure/_utils/__init__.py index 29a4e11d..28b0e1dd 100644 --- a/dask_image/ndmeasure/_utils/__init__.py +++ b/dask_image/ndmeasure/_utils/__init__.py @@ -48,7 +48,7 @@ def _ravel_shape_indices_kernel(*args): return sum(args2) -def _ravel_shape_indices(dimensions, dtype=int, chunks=None): +def _ravel_shape_indices(dimensions, dtype=int, chunks=None, skip_trailing_dim:bool=False): """ Gets the raveled indices shaped like input. """ @@ -61,7 +61,7 @@ def _ravel_shape_indices(dimensions, dtype=int, chunks=None): dtype=dtype, chunks=c ) - for i, c in enumerate(chunks) + for i, c in enumerate(chunks[:-1] if skip_trailing_dim else chunks) ] indices = da.blockwise(