Skip to content

Commit

Permalink
perf(largest_k): more special handling for k=1, precomputed_ccl option
Browse files Browse the repository at this point in the history
  • Loading branch information
william-silversmith committed Nov 16, 2024
1 parent eda1498 commit 442f1bc
Showing 1 changed file with 30 additions and 11 deletions.
41 changes: 30 additions & 11 deletions cc3d/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,39 +71,58 @@ def largest_k(
delta:Union[int,float] = 0,
return_N:bool = False,
binary_image:bool = False,
precomputed_ccl:bool = False,
) -> np.ndarray:
"""
Returns the k largest connected components
in the image.
k: number of components to return (>= 0)
connectivity:
(2d) 4 [edges], 8 [edges+corners]
(3d) 6 [faces], 18 [faces+edges], or 26 [faces+edges+corners]
delta: if using a continuous image, the allowed difference
in adjacent voxel values
return_N: return value is (image, N)
binary_image: treat the input image as a binary image
precomputed_ccl: for performance, avoid computing a CCL
pass since the input is already a CCL output from this
library.
"""
assert k >= 0

order = "C" if img.flags.c_contiguous else "F"

if k == 0:
return np.zeros(img.shape, dtype=np.uint16, order=order)

if precomputed_ccl:
cc_labels = np.copy(img, order="F")
N = np.max(cc_labels)
else:
cc_labels, N = connected_components(
img, connectivity=connectivity,
return_N=True, delta=delta,
binary_image=bool(binary_image),
)

cc_labels, N = connected_components(
img, connectivity=connectivity,
return_N=True, delta=delta,
binary_image=bool(binary_image),
)
if N <= k:
if return_N:
return cc_labels, N
return cc_labels

cts = statistics(cc_labels, no_slice_conversion=True)["voxel_counts"]
preserve = [ (i,ct) for i,ct in enumerate(cts) if i > 0 ]
preserve.sort(key=lambda x: x[1])
preserve = [ x[0] for x in preserve[-k:] ]


if k == 1:
cc_out = (cc_labels == preserve[0])
cc_out = (cc_labels == np.argmax(cts))
if return_N:
return cc_out, 1
return cc_out


preserve = [ (i,ct) for i,ct in enumerate(cts) if i > 0 ]
preserve.sort(key=lambda x: x[1])
preserve = [ x[0] for x in preserve[-k:] ]

try:
import fastremap
cc_out = fastremap.mask_except(cc_labels, preserve, in_place=True)
Expand Down

0 comments on commit 442f1bc

Please sign in to comment.