Skip to content

Commit

Permalink
feat: stacked CCL for huge arrays (#127)
Browse files Browse the repository at this point in the history
* refactor: move cc3d C++ logic and pure python logic away from each other

* feat: 6-connected stacked ccl

* test: finish simple test of stacked ccl

* fix: errors in joining logic

* feat+test: support 26 connectivity

* redesign: support only 6 or 26 connectivity

* perf: refit ccl array as it comes out

* fix: off-by-one error

* fix: dtype error for 0 sized arrays

* fixtest: use fastremap.renumber

* fix: another off-by-one error

* docs: show how to use connected_components_stack

* docs: tell users how to fix import errors

* install: fix name of "extras_requires"

* docs: how to install
  • Loading branch information
william-silversmith authored Jul 23, 2024
1 parent 89f7f46 commit 6e6ec00
Show file tree
Hide file tree
Showing 10 changed files with 341 additions and 120 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -144,4 +144,5 @@ callgrind*
subvol.bin


cc3d.cpp
cc3d.cpp
cc3d/fastcc3d.cpp
6 changes: 3 additions & 3 deletions MANIFEST.in
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
include cc3d.hpp
include cc3d_graphs.hpp
include cc3d_continuous.hpp
include cc3d/cc3d.hpp
include cc3d/cc3d_graphs.hpp
include cc3d/cc3d_continuous.hpp
include LICENSE
include COPYING.LESSER
16 changes: 16 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,22 @@ labels_out = cc3d.connected_components(labels_in, delta=10)
labels_in = np.memmap("labels.bin", order="F", dtype=np.uint32, shape=(5000, 5000, 2000))
labels_out = cc3d.connected_components(labels_in, out_file="out.bin")

# Here's another strategy that you can use for huge files that won't even
# take up any disk space. Provide any iterator to this function that produces
# thick z sections of the input array that are in sequential order.
# The output is a highly compressed CrackleArray that is still random access.
# See: https://github.com/seung-lab/crackle
# You need to pip install connected-components-3d[stack] to get the extra modules.
def sections(labels_in):
"""
A generator that produces thick Z slices
of an image
"""
for z in range(0, labels_in.shape[2], 100):
yield labels_in[:,:,z:z+100]

compressed_labels_out = cc3d.connected_components_stack(sections(labels))

# You can extract the number of labels (which is also the maximum
# label value) like so:
labels_out, N = cc3d.connected_components(labels_in, return_N=True) # free
Expand Down
31 changes: 30 additions & 1 deletion automated_test.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import pytest
import sys

import fastremap
import cc3d
import numpy as np

Expand Down Expand Up @@ -1490,15 +1491,43 @@ def test_pytorch_integration_ccl_doesnt_crash():
assert isinstance(out, torch.Tensor)
assert torch.all(out == labels)

@pytest.mark.parametrize("connectivity", [6, 26])
def test_connected_components_stack(connectivity):
stack = [
np.ones([100,100,100], dtype=np.uint32)
for i in range(4)
]

stack += [ np.zeros([100,100,1], dtype=np.uint32) ]

stack += [
np.ones([100,100,100], dtype=np.uint32)
for i in range(2)
]

arr = cc3d.connected_components_stack(stack, connectivity=connectivity)

ans = np.ones([100,100,601], dtype=np.uint32)
ans[:,:,400] = 0
ans[:,:,401:] = 2

assert np.all(np.unique(arr[:]) == [ 0, 1, 2 ])
assert np.all(arr[:] == ans)

image = np.random.randint(0,100, size=[100,100,11], dtype=np.uint8)

ans = cc3d.connected_components(image, connectivity=connectivity)
ans, _ = fastremap.renumber(ans[:])

stack = [
image[:,:,:2],
image[:,:,2:6],
image[:,:,6:9],
image[:,:,9:11],
]

res = cc3d.connected_components_stack(stack, connectivity=connectivity)
res, _ = fastremap.renumber(res[:])


assert np.all(res[:] == ans[:,:,:11])

280 changes: 280 additions & 0 deletions cc3d/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,280 @@
from typing import (
Dict, Union, Tuple, Iterator,
Sequence, Optional, Any, BinaryIO
)

import fastcc3d
from fastcc3d import (
connected_components,
statistics,
each,
contacts,
region_graph,
voxel_connectivity_graph,
color_connectivity_graph,
estimate_provisional_labels,
)

import numpy as np

def dust(
img:np.ndarray,
threshold:Union[int,float],
connectivity:int = 26,
in_place:bool = False,
) -> np.ndarray:
"""
Remove from the input image connected components
smaller than threshold ("dust"). The name of the function
can be read as a verb "to dust" the image.
img: 2D or 3D image
threshold: discard components smaller than this in voxels
connectivity: cc3d connectivity to use
in_place: whether to modify the input image or perform
dust
Returns: dusted image
"""
orig_dtype = img.dtype
img = _view_as_unsigned(img)

if not in_place:
img = np.copy(img)

cc_labels, N = connected_components(
img, connectivity=connectivity, return_N=True
)
stats = statistics(cc_labels)
mask_sizes = stats["voxel_counts"]
del stats

to_mask = [
i for i in range(1, N+1) if mask_sizes[i] < threshold
]

if len(to_mask) == 0:
return img

mask = np.isin(cc_labels, to_mask)
del cc_labels
np.logical_not(mask, out=mask)
np.multiply(img, mask, out=img)
return img.view(orig_dtype)

def largest_k(
img:np.ndarray,
k:int,
connectivity:int = 26,
delta:Union[int,float] = 0,
return_N:bool = False,
) -> np.ndarray:
"""
Returns the k largest connected components
in the image.
"""
assert k >= 0

order = "C" if img.flags.c_contiguous else "F"

if k == 0:
return np.zeros(img.shape, dtype=np.uint16, order=order)

cc_labels, N = connected_components(
img, connectivity=connectivity,
return_N=True, delta=delta,
)
if N <= k:
if return_N:
return cc_labels, N
return cc_labels

cts = statistics(cc_labels)["voxel_counts"]
preserve = [ (i,ct) for i,ct in enumerate(cts) if i > 0 ]
preserve.sort(key=lambda x: x[1])
preserve = [ x[0] for x in preserve[-k:] ]

shape, dtype = cc_labels.shape, cc_labels.dtype
rns = fastcc3d.runs(cc_labels)

order = "C" if cc_labels.flags.c_contiguous else "F"
del cc_labels

cc_out = np.zeros(shape, dtype=dtype, order=order)
for i, label in enumerate(preserve):
fastcc3d.draw(i+1, rns[label], cc_out)

if return_N:
return cc_out, len(preserve)
return cc_out

def _view_as_unsigned(img:np.ndarray) -> np.ndarray:
if np.issubdtype(img.dtype, np.unsignedinteger) or img.dtype == bool:
return img
elif img.dtype == np.int8:
return img.view(np.uint8)
elif img.dtype == np.int16:
return img.view(np.uint16)
elif img.dtype == np.int32:
return img.view(np.uint32)
elif img.dtype == np.int64:
return img.view(np.uint64)

return img


class DisjointSet:
def __init__(self):
self.data = {}
def makeset(self, x):
self.data[x] = x
return x
def find(self, x):
if not x in self.data:
return None
i = self.data[x]
while i != self.data[i]:
self.data[i] = self.data[self.data[i]]
i = self.data[i]
return i
def union(self, x, y):
i = self.find(x)
j = self.find(y)
if i is None:
i = self.makeset(x)
if j is None:
j = self.makeset(y)

if i < j:
self.data[j] = i
else:
self.data[i] = j

def connected_components_stack(
stacked_images:Sequence[np.ndarray],
connectivity:int = 26,
return_N:bool = False,
out_dtype:Optional[Any] = None,
):
"""
This is for performing connected component labeling
on an array larger than RAM.
stacked_images is a sequence of 3D images that are of equal
width and height (x,y) and arbitrary depth (z). For example,
you might define a generator that produces a tenth of your
data at a time. The data must be sequenced in z order from
z = 0 to z = depth - 1.
Each 3D image will have CCL run on it and then compressed
into crackle format (https://github.com/seung-lab/crackle)
which is highly compressed but still usable and randomly
accessible by z-slice.
The bottom previous slice and top current
slice will be analyzed to produce a merged image.
The final output will be a CrackleArray. You
can access parts of the image using standard array
operations, write the array data to disk using arr.binary
or fully decompressing the array using arr.decompress()
to obtain a numpy array (but presumably this will blow
out your RAM since the image is so big).
"""
try:
import crackle
import fastremap
except ImportError:
print("You need to pip install connected-components-3d[stack]")
raise

full_binary = None
bottom_cc_img = None
bottom_cc_labels = None

if connectivity not in (6,26):
raise ValueError(f"Connectivity must be 6 or 26. Got: {connectivity}")

offset = 0

for image in stacked_images:
cc_labels, N = connected_components(
image, connectivity=connectivity,
return_N=True, out_dtype=np.uint64,
)
cc_labels[cc_labels != 0] += offset
offset += N
binary = crackle.compress(cc_labels)

if full_binary is None:
full_binary = binary
bottom_cc_img = image[:,:,-1]
bottom_cc_labels = cc_labels[:,:,-1]
continue

top_cc_labels = cc_labels[:,:,0]

equivalences = DisjointSet()

buniq = fastremap.unique(bottom_cc_labels)
tuniq = fastremap.unique(top_cc_labels)

for u in buniq:
equivalences.makeset(u)
for u in tuniq:
equivalences.makeset(u)

if connectivity == 6:
for y in range(image.shape[1]):
for x in range(image.shape[0]):
if bottom_cc_labels[x,y] == 0 or top_cc_labels[x,y] == 0:
continue
if bottom_cc_img[x,y] == image[x,y,0]:
equivalences.union(bottom_cc_labels[x,y], top_cc_labels[x,y])
else:
for y in range(image.shape[1]):
for x in range(image.shape[0]):
if bottom_cc_labels[x,y] == 0:
continue

for y0 in range(max(y - 1, 0), min(y + 1, image.shape[1] - 1) + 1):
for x0 in range(max(x - 1, 0), min(x + 1, image.shape[0] - 1) + 1):
if top_cc_labels[x0,y0] == 0:
continue

if bottom_cc_img[x,y] == image[x0,y0,0]:
equivalences.union(
bottom_cc_labels[x,y], top_cc_labels[x0,y0]
)

relabel = {}
for u in buniq:
relabel[int(u)] = int(equivalences.find(u))
for u in tuniq:
relabel[int(u)] = int(equivalences.find(u))

full_binary = crackle.zstack([
full_binary,
binary,
])
full_binary = crackle.remap(full_binary, relabel, preserve_missing_labels=True)

bottom_cc_img = image[:,:,-1]
bottom_cc_labels = cc_labels[:,:,-1]
bottom_cc_labels = fastremap.remap(bottom_cc_labels, relabel, preserve_missing_labels=True)

if crackle.contains(full_binary, 0):
start = 0
else:
start = 1

full_binary, mapping = crackle.renumber(full_binary, start=start)
arr = crackle.CrackleArray(full_binary).refit()

if return_N:
return arr, arr.num_labels()
else:
return arr



File renamed without changes.
File renamed without changes.
File renamed without changes.
Loading

0 comments on commit 6e6ec00

Please sign in to comment.