From e5de29fcec6b0ad6d06407c29d8de9ee6c9f70f6 Mon Sep 17 00:00:00 2001 From: Constantin Pape Date: Thu, 25 Jul 2024 14:32:43 +0200 Subject: [PATCH] Remove disconnected objects after AMG --- micro_sam/instance_segmentation.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/micro_sam/instance_segmentation.py b/micro_sam/instance_segmentation.py index 23d666b9..bf3589b2 100644 --- a/micro_sam/instance_segmentation.py +++ b/micro_sam/instance_segmentation.py @@ -19,6 +19,8 @@ from segment_anything.predictor import SamPredictor from skimage.measure import regionprops +from skimage.measure import label +from skimage.segmentation import relabel_sequential from torchvision.ops.boxes import batched_nms, box_area from torch_em.model import UNETR @@ -46,6 +48,7 @@ def mask_data_to_segmentation( with_background: bool, min_object_size: int = 0, max_object_size: Optional[int] = None, + label_masks: bool = True, ) -> np.ndarray: """Convert the output of the automatic mask generation to an instance segmentation. @@ -56,6 +59,7 @@ def mask_data_to_segmentation( object in the output will be mapped to zero (the background value). min_object_size: The minimal size of an object in pixels. max_object_size: The maximal size of an object in pixels. + label_masks: Whether to apply connected components to the result before remving small objects. Returns: The instance segmentation. """ @@ -79,6 +83,8 @@ def require_numpy(mask): segmentation[require_numpy(mask["segmentation"])] = this_seg_id seg_id = this_seg_id + 1 + if label_masks: + segmentation = label(segmentation) seg_ids, sizes = np.unique(segmentation, return_counts=True) # In some cases objects may be smaller than peviously calculated, @@ -94,7 +100,7 @@ def require_numpy(mask): filter_ids = np.concatenate([filter_ids, [bg_id]]) segmentation[np.isin(segmentation, filter_ids)] = 0 - vigra.analysis.relabelConsecutive(segmentation, out=segmentation) + segmentation = relabel_sequential(segmentation)[0] return segmentation