From 6e6ec0031f09657e8ef703e0d4bac4ac8d500739 Mon Sep 17 00:00:00 2001 From: William Silversmith Date: Tue, 23 Jul 2024 14:19:09 -0400 Subject: [PATCH] feat: stacked CCL for huge arrays (#127) * refactor: move cc3d C++ logic and pure python logic away from each other * feat: 6-connected stacked ccl * test: finish simple test of stacked ccl * fix: errors in joining logic * feat+test: support 26 connectivity * redesign: support only 6 or 26 connectivity * perf: refit ccl array as it comes out * fix: off-by-one error * fix: dtype error for 0 sized arrays * fixtest: use fastremap.renumber * fix: another off-by-one error * docs: show how to use connected_components_stack * docs: tell users how to fix import errors * install: fix name of "extras_requires" * docs: how to install --- .gitignore | 3 +- MANIFEST.in | 6 +- README.md | 16 + automated_test.py | 31 +- cc3d/__init__.py | 280 ++++++++++++++++++ cc3d.hpp => cc3d/cc3d.hpp | 0 .../cc3d_continuous.hpp | 0 cc3d_graphs.hpp => cc3d/cc3d_graphs.hpp | 0 cc3d.pyx => cc3d/fastcc3d.pyx | 116 +------- setup.py | 9 +- 10 files changed, 341 insertions(+), 120 deletions(-) create mode 100644 cc3d/__init__.py rename cc3d.hpp => cc3d/cc3d.hpp (100%) rename cc3d_continuous.hpp => cc3d/cc3d_continuous.hpp (100%) rename cc3d_graphs.hpp => cc3d/cc3d_graphs.hpp (100%) rename cc3d.pyx => cc3d/fastcc3d.pyx (94%) diff --git a/.gitignore b/.gitignore index 4f4963a..24398e6 100644 --- a/.gitignore +++ b/.gitignore @@ -144,4 +144,5 @@ callgrind* subvol.bin -cc3d.cpp \ No newline at end of file +cc3d.cpp +cc3d/fastcc3d.cpp \ No newline at end of file diff --git a/MANIFEST.in b/MANIFEST.in index 75819df..2afc026 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,5 +1,5 @@ -include cc3d.hpp -include cc3d_graphs.hpp -include cc3d_continuous.hpp +include cc3d/cc3d.hpp +include cc3d/cc3d_graphs.hpp +include cc3d/cc3d_continuous.hpp include LICENSE include COPYING.LESSER \ No newline at end of file diff --git a/README.md b/README.md index 
8688342..33d0ebe 100644 --- a/README.md +++ b/README.md @@ -99,6 +99,22 @@ labels_out = cc3d.connected_components(labels_in, delta=10) labels_in = np.memmap("labels.bin", order="F", dtype=np.uint32, shape=(5000, 5000, 2000)) labels_out = cc3d.connected_components(labels_in, out_file="out.bin") +# Here's another strategy that you can use for huge files that won't even +# take up any disk space. Provide any iterator to this function that produces +# thick z sections of the input array that are in sequential order. +# The output is a highly compressed CrackleArray that is still random access. +# See: https://github.com/seung-lab/crackle +# You need to pip install connected-components-3d[stack] to get the extra modules. +def sections(labels_in): + """ + A generator that produces thick Z slices + of an image + """ + for z in range(0, labels_in.shape[2], 100): + yield labels_in[:,:,z:z+100] + +compressed_labels_out = cc3d.connected_components_stack(sections(labels_in)) + # You can extract the number of labels (which is also the maximum # label value) like so: labels_out, N = cc3d.connected_components(labels_in, return_N=True) # free diff --git a/automated_test.py b/automated_test.py index 10f9cd2..33c7215 100644 --- a/automated_test.py +++ b/automated_test.py @@ -1,6 +1,7 @@ import pytest import sys +import fastremap import cc3d import numpy as np @@ -1490,15 +1491,43 @@ def test_pytorch_integration_ccl_doesnt_crash(): assert isinstance(out, torch.Tensor) assert torch.all(out == labels) +@pytest.mark.parametrize("connectivity", [6, 26]) +def test_connected_components_stack(connectivity): + stack = [ + np.ones([100,100,100], dtype=np.uint32) + for i in range(4) + ] + stack += [ np.zeros([100,100,1], dtype=np.uint32) ] + stack += [ + np.ones([100,100,100], dtype=np.uint32) + for i in range(2) + ] + arr = cc3d.connected_components_stack(stack, connectivity=connectivity) + ans = np.ones([100,100,601], dtype=np.uint32) + ans[:,:,400] = 0 + ans[:,:,401:] = 2 + assert
np.all(np.unique(arr[:]) == [ 0, 1, 2 ]) + assert np.all(arr[:] == ans) + image = np.random.randint(0,100, size=[100,100,11], dtype=np.uint8) + ans = cc3d.connected_components(image, connectivity=connectivity) + ans, _ = fastremap.renumber(ans[:]) + stack = [ + image[:,:,:2], + image[:,:,2:6], + image[:,:,6:9], + image[:,:,9:11], + ] + res = cc3d.connected_components_stack(stack, connectivity=connectivity) + res, _ = fastremap.renumber(res[:]) - + assert np.all(res[:] == ans[:,:,:11]) diff --git a/cc3d/__init__.py b/cc3d/__init__.py new file mode 100644 index 0000000..1a2dfbd --- /dev/null +++ b/cc3d/__init__.py @@ -0,0 +1,280 @@ +from typing import ( + Dict, Union, Tuple, Iterator, + Sequence, Optional, Any, BinaryIO +) + +import fastcc3d +from fastcc3d import ( + connected_components, + statistics, + each, + contacts, + region_graph, + voxel_connectivity_graph, + color_connectivity_graph, + estimate_provisional_labels, +) + +import numpy as np + +def dust( + img:np.ndarray, + threshold:Union[int,float], + connectivity:int = 26, + in_place:bool = False, +) -> np.ndarray: + """ + Remove from the input image connected components + smaller than threshold ("dust"). The name of the function + can be read as a verb "to dust" the image. 
+ + img: 2D or 3D image + threshold: discard components smaller than this in voxels + connectivity: cc3d connectivity to use + in_place: whether to modify the input image or perform + dust + + Returns: dusted image + """ + orig_dtype = img.dtype + img = _view_as_unsigned(img) + + if not in_place: + img = np.copy(img) + + cc_labels, N = connected_components( + img, connectivity=connectivity, return_N=True + ) + stats = statistics(cc_labels) + mask_sizes = stats["voxel_counts"] + del stats + + to_mask = [ + i for i in range(1, N+1) if mask_sizes[i] < threshold + ] + + if len(to_mask) == 0: + return img + + mask = np.isin(cc_labels, to_mask) + del cc_labels + np.logical_not(mask, out=mask) + np.multiply(img, mask, out=img) + return img.view(orig_dtype) + +def largest_k( + img:np.ndarray, + k:int, + connectivity:int = 26, + delta:Union[int,float] = 0, + return_N:bool = False, +) -> np.ndarray: + """ + Returns the k largest connected components + in the image. + """ + assert k >= 0 + + order = "C" if img.flags.c_contiguous else "F" + + if k == 0: + return np.zeros(img.shape, dtype=np.uint16, order=order) + + cc_labels, N = connected_components( + img, connectivity=connectivity, + return_N=True, delta=delta, + ) + if N <= k: + if return_N: + return cc_labels, N + return cc_labels + + cts = statistics(cc_labels)["voxel_counts"] + preserve = [ (i,ct) for i,ct in enumerate(cts) if i > 0 ] + preserve.sort(key=lambda x: x[1]) + preserve = [ x[0] for x in preserve[-k:] ] + + shape, dtype = cc_labels.shape, cc_labels.dtype + rns = fastcc3d.runs(cc_labels) + + order = "C" if cc_labels.flags.c_contiguous else "F" + del cc_labels + + cc_out = np.zeros(shape, dtype=dtype, order=order) + for i, label in enumerate(preserve): + fastcc3d.draw(i+1, rns[label], cc_out) + + if return_N: + return cc_out, len(preserve) + return cc_out + +def _view_as_unsigned(img:np.ndarray) -> np.ndarray: + if np.issubdtype(img.dtype, np.unsignedinteger) or img.dtype == bool: + return img + elif img.dtype 
== np.int8: + return img.view(np.uint8) + elif img.dtype == np.int16: + return img.view(np.uint16) + elif img.dtype == np.int32: + return img.view(np.uint32) + elif img.dtype == np.int64: + return img.view(np.uint64) + + return img + + +class DisjointSet: + def __init__(self): + self.data = {} + def makeset(self, x): + self.data[x] = x + return x + def find(self, x): + if not x in self.data: + return None + i = self.data[x] + while i != self.data[i]: + self.data[i] = self.data[self.data[i]] + i = self.data[i] + return i + def union(self, x, y): + i = self.find(x) + j = self.find(y) + if i is None: + i = self.makeset(x) + if j is None: + j = self.makeset(y) + + if i < j: + self.data[j] = i + else: + self.data[i] = j + +def connected_components_stack( + stacked_images:Sequence[np.ndarray], + connectivity:int = 26, + return_N:bool = False, + out_dtype:Optional[Any] = None, +): + """ + This is for performing connected component labeling + on an array larger than RAM. + + stacked_images is a sequence of 3D images that are of equal + width and height (x,y) and arbitrary depth (z). For example, + you might define a generator that produces a tenth of your + data at a time. The data must be sequenced in z order from + z = 0 to z = depth - 1. + + Each 3D image will have CCL run on it and then compressed + into crackle format (https://github.com/seung-lab/crackle) + which is highly compressed but still usable and randomly + accessible by z-slice. + + The bottom previous slice and top current + slice will be analyzed to produce a merged image. + + The final output will be a CrackleArray. You + can access parts of the image using standard array + operations, write the array data to disk using arr.binary + or fully decompressing the array using arr.decompress() + to obtain a numpy array (but presumably this will blow + out your RAM since the image is so big). 
+ """ + try: + import crackle + import fastremap + except ImportError: + print("You need to pip install connected-components-3d[stack]") + raise + + full_binary = None + bottom_cc_img = None + bottom_cc_labels = None + + if connectivity not in (6,26): + raise ValueError(f"Connectivity must be 6 or 26. Got: {connectivity}") + + offset = 0 + + for image in stacked_images: + cc_labels, N = connected_components( + image, connectivity=connectivity, + return_N=True, out_dtype=np.uint64, + ) + cc_labels[cc_labels != 0] += offset + offset += N + binary = crackle.compress(cc_labels) + + if full_binary is None: + full_binary = binary + bottom_cc_img = image[:,:,-1] + bottom_cc_labels = cc_labels[:,:,-1] + continue + + top_cc_labels = cc_labels[:,:,0] + + equivalences = DisjointSet() + + buniq = fastremap.unique(bottom_cc_labels) + tuniq = fastremap.unique(top_cc_labels) + + for u in buniq: + equivalences.makeset(u) + for u in tuniq: + equivalences.makeset(u) + + if connectivity == 6: + for y in range(image.shape[1]): + for x in range(image.shape[0]): + if bottom_cc_labels[x,y] == 0 or top_cc_labels[x,y] == 0: + continue + if bottom_cc_img[x,y] == image[x,y,0]: + equivalences.union(bottom_cc_labels[x,y], top_cc_labels[x,y]) + else: + for y in range(image.shape[1]): + for x in range(image.shape[0]): + if bottom_cc_labels[x,y] == 0: + continue + + for y0 in range(max(y - 1, 0), min(y + 1, image.shape[1] - 1) + 1): + for x0 in range(max(x - 1, 0), min(x + 1, image.shape[0] - 1) + 1): + if top_cc_labels[x0,y0] == 0: + continue + + if bottom_cc_img[x,y] == image[x0,y0,0]: + equivalences.union( + bottom_cc_labels[x,y], top_cc_labels[x0,y0] + ) + + relabel = {} + for u in buniq: + relabel[int(u)] = int(equivalences.find(u)) + for u in tuniq: + relabel[int(u)] = int(equivalences.find(u)) + + full_binary = crackle.zstack([ + full_binary, + binary, + ]) + full_binary = crackle.remap(full_binary, relabel, preserve_missing_labels=True) + + bottom_cc_img = image[:,:,-1] + bottom_cc_labels 
= cc_labels[:,:,-1] + bottom_cc_labels = fastremap.remap(bottom_cc_labels, relabel, preserve_missing_labels=True) + + if crackle.contains(full_binary, 0): + start = 0 + else: + start = 1 + + full_binary, mapping = crackle.renumber(full_binary, start=start) + arr = crackle.CrackleArray(full_binary).refit() + + if return_N: + return arr, arr.num_labels() + else: + return arr + + + diff --git a/cc3d.hpp b/cc3d/cc3d.hpp similarity index 100% rename from cc3d.hpp rename to cc3d/cc3d.hpp diff --git a/cc3d_continuous.hpp b/cc3d/cc3d_continuous.hpp similarity index 100% rename from cc3d_continuous.hpp rename to cc3d/cc3d_continuous.hpp diff --git a/cc3d_graphs.hpp b/cc3d/cc3d_graphs.hpp similarity index 100% rename from cc3d_graphs.hpp rename to cc3d/cc3d_graphs.hpp diff --git a/cc3d.pyx b/cc3d/fastcc3d.pyx similarity index 94% rename from cc3d.pyx rename to cc3d/fastcc3d.pyx index f53f531..a46a2b0 100644 --- a/cc3d.pyx +++ b/cc3d/fastcc3d.pyx @@ -311,7 +311,10 @@ def connected_components( raise ValueError(f"periodic_boundary is not yet implemented continuous data.") if data.size == 0: - out_labels = np.zeros(shape=(0,), dtype=data.dtype) + dtype = data.dtype + if out_dtype is not None: + dtype = out_dtype + out_labels = np.zeros(shape=(0,), dtype=dtype) if return_N: return (out_labels, 0) return out_labels @@ -1307,114 +1310,3 @@ def each( if in_place: return InPlaceImageIterator() return ImageIterator() - -## The functions below are conveniences for doing -## common tasks efficiently. 
- -def _view_as_unsigned(img:np.ndarray): - if np.issubdtype(img.dtype, np.unsignedinteger) or img.dtype == bool: - return img - elif img.dtype == np.int8: - return img.view(np.uint8) - elif img.dtype == np.int16: - return img.view(np.uint16) - elif img.dtype == np.int32: - return img.view(np.uint32) - elif img.dtype == np.int64: - return img.view(np.uint64) - - return img - -@cython.binding(True) -def dust( - img:np.ndarray, - threshold:Union[int,float], - connectivity:int = 26, - in_place:bool = False, -) -> np.ndarray: - """ - Remove from the input image connected components - smaller than threshold ("dust"). The name of the function - can be read as a verb "to dust" the image. - - img: 2D or 3D image - threshold: discard components smaller than this in voxels - connectivity: cc3d connectivity to use - in_place: whether to modify the input image or perform - dust - - Returns: dusted image - """ - orig_dtype = img.dtype - img = _view_as_unsigned(img) - - if not in_place: - img = np.copy(img) - - cc_labels, N = connected_components( - img, connectivity=connectivity, return_N=True - ) - stats = statistics(cc_labels) - mask_sizes = stats["voxel_counts"] - del stats - - to_mask = [ - i for i in range(1, N+1) if mask_sizes[i] < threshold - ] - - if len(to_mask) == 0: - return img - - mask = np.isin(cc_labels, to_mask) - del cc_labels - np.logical_not(mask, out=mask) - np.multiply(img, mask, out=img) - return img.view(orig_dtype) - -@cython.binding(True) -def largest_k( - img:np.ndarray, - k:int, - connectivity:int = 26, - delta:Union[int,float] = 0, - return_N:bool = False, -) -> np.ndarray: - """ - Returns the k largest connected components - in the image. 
- """ - assert k >= 0 - - order = "C" if img.flags.c_contiguous else "F" - - if k == 0: - return np.zeros(img.shape, dtype=np.uint16, order=order) - - cc_labels, N = connected_components( - img, connectivity=connectivity, - return_N=True, delta=delta, - ) - if N <= k: - if return_N: - return cc_labels, N - return cc_labels - - cts = statistics(cc_labels)["voxel_counts"] - preserve = [ (i,ct) for i,ct in enumerate(cts) if i > 0 ] - preserve.sort(key=lambda x: x[1]) - preserve = [ x[0] for x in preserve[-k:] ] - - shape, dtype = cc_labels.shape, cc_labels.dtype - rns = runs(cc_labels) - - order = "C" if cc_labels.flags.c_contiguous else "F" - del cc_labels - - cc_out = np.zeros(shape, dtype=dtype, order=order) - for i, label in enumerate(preserve): - draw(i+1, rns[label], cc_out) - - if return_N: - return cc_out, len(preserve) - return cc_out - diff --git a/setup.py b/setup.py index ae453b8..5377e5f 100644 --- a/setup.py +++ b/setup.py @@ -42,12 +42,15 @@ def requirements(): setup_requires=['pbr', 'numpy', 'cython'], install_requires=['numpy'], python_requires=">=3.8,<4.0", + extras_requires={ + "stack": [ "crackle-codec", "fastremap" ], + }, ext_modules=[ setuptools.Extension( - 'cc3d', - sources=[ 'cc3d.pyx' ], + 'fastcc3d', + sources=[ 'cc3d/fastcc3d.pyx' ], language='c++', - include_dirs=[ str(NumpyImport()) ], + include_dirs=[ 'cc3d', str(NumpyImport()) ], extra_compile_args=extra_compile_args, ) ],