From c42c14482f7cad52d9ee21ccae79d69be1c1ba9e Mon Sep 17 00:00:00 2001 From: anwai98 Date: Sun, 5 Jan 2025 20:30:23 +0100 Subject: [PATCH 1/2] Minor update to prompt-based segmentation --- micro_sam/prompt_based_segmentation.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/micro_sam/prompt_based_segmentation.py b/micro_sam/prompt_based_segmentation.py index 3569d741..9062dd82 100644 --- a/micro_sam/prompt_based_segmentation.py +++ b/micro_sam/prompt_based_segmentation.py @@ -383,7 +383,7 @@ def _to_tile(prompts, shape, tile_shape, halo): raise ValueError("If points are passed you also need to pass labels.") point_coords, point_labels = points, labels - elif use_points: + elif use_points and len(np.unique(mask)) > 1: point_coords, point_labels = _compute_points_from_mask( mask, original_size=original_size, box_extension=box_extension, use_single_point=use_single_point, @@ -395,7 +395,7 @@ def _to_tile(prompts, shape, tile_shape, halo): if box is None: box = _compute_box_from_mask( mask, original_size=original_size, box_extension=box_extension - ) if use_box else None + ) if use_box and len(np.unique(mask)) > 1 else None else: box = _process_box(box, mask.shape, original_size=original_size, box_extension=box_extension) From 112f7e3831dde6cfa20f0f408b1a8c2a4fffb8f8 Mon Sep 17 00:00:00 2001 From: anwai98 Date: Sun, 5 Jan 2025 22:41:03 +0100 Subject: [PATCH 2/2] Make foreground check efficient --- micro_sam/prompt_based_segmentation.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/micro_sam/prompt_based_segmentation.py b/micro_sam/prompt_based_segmentation.py index 9062dd82..a5396703 100644 --- a/micro_sam/prompt_based_segmentation.py +++ b/micro_sam/prompt_based_segmentation.py @@ -383,7 +383,7 @@ def _to_tile(prompts, shape, tile_shape, halo): raise ValueError("If points are passed you also need to pass labels.") point_coords, point_labels = points, labels - elif use_points and len(np.unique(mask)) > 1: + elif use_points and mask.sum() != 0: point_coords, point_labels = _compute_points_from_mask( mask, original_size=original_size, box_extension=box_extension, use_single_point=use_single_point, @@ -395,7 +395,7 @@ def _to_tile(prompts, shape, tile_shape, halo): if box is None: box = _compute_box_from_mask( mask, original_size=original_size, box_extension=box_extension - ) if use_box and len(np.unique(mask)) > 1 else None + ) if use_box and mask.sum() != 0 else None else: box = _process_box(box, mask.shape, original_size=original_size, box_extension=box_extension)