Skip to content

Commit

Permalink
keep the sam auto-seg masks until input image changes
Browse files Browse the repository at this point in the history
  • Loading branch information
mese79 committed Aug 7, 2024
1 parent 7cf9cb7 commit 7105aca
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 31 deletions.
13 changes: 11 additions & 2 deletions src/featureforest/_segmentation_widget.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@
from .postprocess import (
postprocess,
postprocess_with_sam,
postprocess_with_sam_auto
postprocess_with_sam_auto,
get_sam_auto_masks
)


Expand All @@ -50,6 +51,7 @@ def __init__(self, napari_viewer: napari.Viewer):
self.storage = None
self.rf_model = None
self.model_adapter = None
self.sam_auto_masks = None
self.patch_size = 512 # default values
self.overlap = 384
self.stride = self.patch_size - self.overlap
Expand Down Expand Up @@ -98,6 +100,7 @@ def create_input_ui(self):
# input layer
input_label = QLabel("Input Layer:")
self.image_combo = QComboBox()
self.image_combo.currentIndexChanged.connect(self.clear_sam_auto_masks)
# sam storage
storage_label = QLabel("SAM Embeddings Storage:")
self.storage_textbox = QLineEdit()
Expand Down Expand Up @@ -416,6 +419,9 @@ def check_label_layers(self, event: Event):
if index > -1:
self.prediction_layer_combo.setCurrentIndex(index)

def clear_sam_auto_masks(self):
self.sam_auto_masks = None

def postprocess_layer_removed(self, event: Event):
"""Fires when current postprocess layer is removed."""
if (
Expand Down Expand Up @@ -782,9 +788,12 @@ def postprocess_segmentation(self, whole_stack=False):
if num_slices > 1:
input_image = self.image_layer.data[slice_index]
iou_threshold = float(self.sam_auto_threshold_textbox.text())
# get sam auto-segmentation masks
if self.sam_auto_masks is None:
self.sam_auto_masks = get_sam_auto_masks(input_image)
# postprocess
self.postprocess_layer.data[slice_index] = postprocess_with_sam_auto(
input_image,
self.sam_auto_masks,
prediction,
smoothing_iterations, iou_threshold,
area_threshold, area_is_absolute
Expand Down
6 changes: 4 additions & 2 deletions src/featureforest/postprocess/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
from .postprocess import postprocess
from .postprocess_with_sam import (
postprocess_with_sam,
postprocess_with_sam_auto
postprocess_with_sam_auto,
get_sam_auto_masks
)


__all__ = [
"postprocess",
"postprocess_with_sam",
"postprocess_with_sam_auto"
"postprocess_with_sam_auto",
"get_sam_auto_masks"
]
69 changes: 42 additions & 27 deletions src/featureforest/postprocess/postprocess_with_sam.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List
from typing import List, Tuple

from napari.utils import progress as np_progress
# import napari.utils.notifications as notif
Expand Down Expand Up @@ -207,6 +207,43 @@ def postprocess_with_sam(
return final_mask


def get_sam_auto_masks(input_image: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
"""Returns masks generated by SamAutomaticMaskGenerator
Args:
input_image (np.ndarray): input image
Returns:
Tuple[np.ndarray, np.ndarray]: a tuple of (masks, areas)
"""
if not is_image_rgb(input_image):
input_image = np.repeat(
input_image[..., np.newaxis],
3, axis=-1
)
assert is_image_rgb(input_image)
# init a sam auto-segmentation mask generator
mask_generator = SamAutomaticMaskGenerator(
model=get_light_hq_sam(),
points_per_side=64,
pred_iou_thresh=0.8,
stability_score_thresh=0.85,
stability_score_offset=0.9,
crop_n_layers=1,
crop_n_points_downscale_factor=2,
# crop_nms_thresh=0.7,
min_mask_region_area=20
)
# generate SAM masks
print("generating masks using SamAutomaticMaskGenerator...")
with np_progress(range(1), desc="Generating masks using SamAutomaticMaskGenerator"):
sam_generated_masks = mask_generator.generate(input_image)
sam_masks = np.array([mask["segmentation"] for mask in sam_generated_masks])
sam_areas = np.array([mask["area"] for mask in sam_generated_masks])

return sam_masks, sam_areas


def get_ious(mask: np.ndarray, sam_masks: np.ndarray) -> np.ndarray:
"""Calculate IOU between prediction mask and all SAM generated masks.
Expand All @@ -226,7 +263,7 @@ def get_ious(mask: np.ndarray, sam_masks: np.ndarray) -> np.ndarray:


def postprocess_with_sam_auto(
input_image: np.ndarray,
sam_auto_masks: Tuple[np.ndarray, np.ndarray],
segmentation_image: np.ndarray,
smoothing_iterations: int = 20,
iou_threshold: float = 0.45,
Expand All @@ -250,33 +287,11 @@ def postprocess_with_sam_auto(
Returns:
np.ndarray: post-processed segmentation image
"""
if not is_image_rgb(input_image):
input_image = np.repeat(
input_image[..., np.newaxis],
3, axis=-1
)
assert is_image_rgb(input_image)
sam_masks, sam_areas = sam_auto_masks
print(f"generated masks: {len(sam_masks)}")

if iou_threshold > 1.0:
iou_threshold = 1.0
# init a sam auto-segmentation mask generator
mask_generator = SamAutomaticMaskGenerator(
model=get_light_hq_sam(),
points_per_side=64,
pred_iou_thresh=0.8,
stability_score_thresh=0.85,
stability_score_offset=0.9,
crop_n_layers=1,
crop_n_points_downscale_factor=2,
# crop_nms_thresh=0.7,
min_mask_region_area=20
)
# generate SAM masks
print("generating masks using SamAutomaticMaskGenerator...")
with np_progress(range(1), desc="Generating masks using SamAutomaticMaskGenerator"):
sam_generated_masks = mask_generator.generate(input_image)
sam_masks = np.array([mask["segmentation"] for mask in sam_generated_masks])
sam_areas = np.array([mask["area"] for mask in sam_generated_masks])
print(f"generated masks: {len(sam_generated_masks)}")

final_mask = np.zeros_like(segmentation_image, dtype=np.uint8)
# postprocessing gets done for each class segmentation.
Expand Down

0 comments on commit 7105aca

Please sign in to comment.