diff --git a/sam2/utils/misc.py b/sam2/utils/misc.py index 525e8cb3..b1f63ce4 100644 --- a/sam2/utils/misc.py +++ b/sam2/utils/misc.py @@ -58,9 +58,16 @@ def get_connected_components(mask): - counts: A tensor of shape (N, 1, H, W) containing the area of the connected components for foreground pixels and 0 for background pixels. """ - from sam2 import _C + try: + from sam2 import _C + + return _C.get_connected_componnets(mask.to(torch.uint8).contiguous()) + except Exception: + import skimage.measure - return _C.get_connected_componnets(mask.to(torch.uint8).contiguous()) + return skimage.measure.label( + mask.to(torch.uint8).contiguous().cpu().numpy(), return_num=True + ) def mask_to_box(masks: torch.Tensor):