From d85aeb11bfcdda0dfee6f92c5e63995447f83070 Mon Sep 17 00:00:00 2001 From: Mehdi Seifi Date: Sat, 20 Jul 2024 23:46:16 +0330 Subject: [PATCH 1/6] make post-process ui separate from prediction ui --- src/featureforest/_segmentation_widget.py | 109 +++++++++++++--------- 1 file changed, 63 insertions(+), 46 deletions(-) diff --git a/src/featureforest/_segmentation_widget.py b/src/featureforest/_segmentation_widget.py index 31de5a0..4ad6c5a 100644 --- a/src/featureforest/_segmentation_widget.py +++ b/src/featureforest/_segmentation_widget.py @@ -69,6 +69,7 @@ def prepare_widget(self): self.create_label_stats_ui() self.create_train_ui() self.create_prediction_ui() + self.create_postprocessing_ui() scroll_content = QWidget() scroll_content.setLayout(self.base_layout) @@ -221,31 +222,12 @@ def create_prediction_ui(self): self.prediction_layer_combo = QComboBox() self.prediction_layer_combo.setEnabled(False) self.seg_add_radiobutton = QRadioButton("Add Segmentations") - self.seg_add_radiobutton.setChecked(True) + self.seg_add_radiobutton.setChecked(False) self.seg_add_radiobutton.setEnabled(False) self.seg_replace_radiobutton = QRadioButton("Replace Segmentations") + self.seg_replace_radiobutton.setChecked(True) self.seg_replace_radiobutton.setEnabled(False) - # post-process ui - area_label = QLabel("Area Threshold(%):") - self.area_threshold_textbox = QLineEdit() - self.area_threshold_textbox.setText("15") - self.area_threshold_textbox.setValidator( - QDoubleValidator( - 1.000, 100.000, 3, notation=QDoubleValidator.StandardNotation - ) - ) - self.area_threshold_textbox.setToolTip( - "Keeps regions with area above the threshold percentage." - ) - self.area_threshold_textbox.setEnabled(False) - - self.sam_post_checkbox = QCheckBox("Use SAM Predictor") - self.sam_post_checkbox.setEnabled(False) - - self.postprocess_checkbox = QCheckBox("Postprocess Segmentations") - self.postprocess_checkbox.stateChanged.connect(self.postprocess_checkbox_changed) - predict_button = QPushButton("Predict Slice") predict_button.setMinimumWidth(150) predict_button.clicked.connect(lambda: self.predict(whole_stack=False)) @@ -274,11 +256,6 @@ def create_prediction_ui(self): hbox.addWidget(self.seg_add_radiobutton) hbox.addWidget(self.seg_replace_radiobutton) vbox.addLayout(hbox) - vbox.addWidget(self.postprocess_checkbox) - vbox.addWidget(area_label) - vbox.addWidget(self.area_threshold_textbox) - vbox.addWidget(self.sam_post_checkbox) - vbox.addSpacing(20) vbox.addWidget(predict_button, alignment=Qt.AlignLeft) hbox = QHBoxLayout() hbox.setContentsMargins(0, 0, 0, 0) @@ -298,17 +275,58 @@ def create_prediction_ui(self): gbox.setLayout(layout) self.base_layout.addWidget(gbox) + def create_postprocessing_ui(self): + area_label = QLabel("Area Threshold(%):") + self.area_threshold_textbox = QLineEdit() + self.area_threshold_textbox.setText("15") + self.area_threshold_textbox.setValidator( + QDoubleValidator( + 1.000, 100.000, 3, notation=QDoubleValidator.StandardNotation + ) + ) + self.area_threshold_textbox.setToolTip( + "Keeps only regions with area above the threshold." + ) + + self.sam_post_checkbox = QCheckBox("Use SAM Predictor") + sam_label = QLabel( + "This will generate prompts for SAM Predictor using bounding boxes" + " around segmented regions." + ) + sam_label.setWordWrap(True) + + postprocess_button = QPushButton("Apply") + postprocess_button.setMinimumWidth(150) + # postprocess_button.clicked.connect(lambda: self.predict(whole_stack=False)) + postprocess_all_button = QPushButton("Apply to Stack") + postprocess_all_button.setMinimumWidth(150) + + layout = QVBoxLayout() + vbox = QVBoxLayout() + vbox.setContentsMargins(0, 0, 0, 0) + vbox.addWidget(area_label) + vbox.addWidget(self.area_threshold_textbox) + vbox.addSpacing(10) + vbox.addWidget(self.sam_post_checkbox) + vbox.addWidget(sam_label) + vbox.addSpacing(5) + vbox.addWidget(postprocess_button, alignment=Qt.AlignLeft) + vbox.addWidget(postprocess_all_button, alignment=Qt.AlignLeft) + # vbox.addSpacing(20) + layout.addLayout(vbox) + + gbox = QGroupBox() + gbox.setTitle("Post-processing") + gbox.setMinimumWidth(100) + gbox.setLayout(layout) + self.base_layout.addWidget(gbox) + def new_layer_checkbox_changed(self): state = self.new_layer_checkbox.checkState() self.prediction_layer_combo.setEnabled(state == Qt.Unchecked) self.seg_add_radiobutton.setEnabled(state == Qt.Unchecked) self.seg_replace_radiobutton.setEnabled(state == Qt.Unchecked) - def postprocess_checkbox_changed(self): - state = self.postprocess_checkbox.checkState() - self.area_threshold_textbox.setEnabled(state == Qt.Checked) - self.sam_post_checkbox.setEnabled(state == Qt.Checked) - def check_input_layers(self, event: Event): curr_text = self.image_combo.currentText() self.image_combo.clear() @@ -627,21 +645,6 @@ def predict_slice(self, rf_model, slice_index, img_height, img_width): # skip paddings segmentation_image = segmentation_image[:img_height, :img_width] - # check for postprocessing - if self.postprocess_checkbox.checkState() == Qt.Checked: - # apply postprocessing - area_threshold = None - if len(self.area_threshold_textbox.text()) > 0: - area_threshold = float(self.area_threshold_textbox.text()) / 100 - if self.sam_post_checkbox.checkState() == Qt.Checked: - segmentation_image = postprocess_segmentations_with_sam( - segmentation_image, area_threshold - ) - else: - segmentation_image = postprocess_segmentation( - segmentation_image, area_threshold - ) - return segmentation_image def stop_predict(self): @@ -663,6 +666,20 @@ def prediction_is_done(self): print("Prediction is done!") notif.show_info("Prediction is done!") + def postprocess_slice(self, segmentation_image): + # apply postprocessing + area_threshold = None + if len(self.area_threshold_textbox.text()) > 0: + area_threshold = float(self.area_threshold_textbox.text()) / 100 + if self.sam_post_checkbox.checkState() == Qt.Checked: + segmentation_image = postprocess_segmentations_with_sam( + segmentation_image, area_threshold + ) + else: + segmentation_image = postprocess_segmentation( + segmentation_image, area_threshold + ) + def export_segmentation(self, out_format="nrrd"): if self.segmentation_layer is None: notif.show_error("No segmentation layer is selected!") From 3a5093798f1409b560ecb95d19a79f4b9b6aadfc Mon Sep 17 00:00:00 2001 From: Mehdi Seifi Date: Sun, 21 Jul 2024 13:57:32 +0330 Subject: [PATCH 2/6] postprocess works for a single slice --- src/featureforest/_segmentation_widget.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/src/featureforest/_segmentation_widget.py b/src/featureforest/_segmentation_widget.py index 4ad6c5a..092b4fa 100644 --- a/src/featureforest/_segmentation_widget.py +++ b/src/featureforest/_segmentation_widget.py @@ -297,7 +297,7 @@ def create_postprocessing_ui(self): postprocess_button = QPushButton("Apply") postprocess_button.setMinimumWidth(150) - # postprocess_button.clicked.connect(lambda: self.predict(whole_stack=False)) + postprocess_button.clicked.connect(self.postprocess_slice) postprocess_all_button = QPushButton("Apply to Stack") postprocess_all_button.setMinimumWidth(150) @@ -666,11 +666,14 @@ def prediction_is_done(self): print("Prediction is done!") notif.show_info("Prediction is done!") - def postprocess_slice(self, segmentation_image): - # apply postprocessing + def postprocess_slice(self, whole_stack=False): + curr_slice = self.viewer.dims.current_step[0] + segmentation_image = self.segmentation_layer.data[curr_slice] + area_threshold = None if len(self.area_threshold_textbox.text()) > 0: area_threshold = float(self.area_threshold_textbox.text()) / 100 + if self.sam_post_checkbox.checkState() == Qt.Checked: segmentation_image = postprocess_segmentations_with_sam( segmentation_image, area_threshold @@ -680,6 +683,8 @@ def postprocess_slice(self, segmentation_image): segmentation_image, area_threshold ) + self.segmentation_layer.data[curr_slice] = segmentation_image + def export_segmentation(self, out_format="nrrd"): if self.segmentation_layer is None: notif.show_error("No segmentation layer is selected!") From 98d6cabb84b72c30a2ae525a643d26d059b7e872 Mon Sep 17 00:00:00 2001 From: Mehdi Seifi Date: Thu, 25 Jul 2024 17:48:44 +0200 Subject: [PATCH 3/6] updated to work for whole stack too; toggle button is added --- .../_sam_prompt_segmentation_widget.py | 6 +- src/featureforest/_segmentation_widget.py | 132 +++++++++++++++--- src/featureforest/postprocess/__init__.py | 8 ++ .../postprocess/mean_curvature.py | 86 ++++++++++++ .../{utils => postprocess}/postprocess.py | 82 ++++++----- .../postprocess_with_sam.py | 2 +- 6 files changed, 257 insertions(+), 59 deletions(-) create mode 100644 src/featureforest/postprocess/__init__.py create mode 100644 src/featureforest/postprocess/mean_curvature.py rename src/featureforest/{utils => postprocess}/postprocess.py (59%) rename src/featureforest/{utils => postprocess}/postprocess_with_sam.py (99%) diff --git a/src/featureforest/_sam_prompt_segmentation_widget.py b/src/featureforest/_sam_prompt_segmentation_widget.py index 7ca21e9..6cb7871 100644 --- a/src/featureforest/_sam_prompt_segmentation_widget.py +++ b/src/featureforest/_sam_prompt_segmentation_widget.py @@ -31,8 +31,8 @@ from .utils import ( colormaps, config ) -from .utils.postprocess import ( - process_similarity_matrix, postprocess_label, +from .postprocess.postprocess import ( + process_similarity_matrix, postprocess_label_mask, generate_mask_prompts, ) @@ -519,7 +519,7 @@ def run_prediction(self, slice_indices, prompts_mask): high_sim_mask[ sim_mat >= float(self.similarity_threshold_textbox.text()) ] = 255 - post_high_sim_mask = postprocess_label(high_sim_mask, 0.15) + post_high_sim_mask = postprocess_label_mask(high_sim_mask, 0.15) positive_prompts = generate_mask_prompts(post_high_sim_mask) if len(positive_prompts) == 0: diff --git a/src/featureforest/_segmentation_widget.py b/src/featureforest/_segmentation_widget.py index 092b4fa..a84b3ba 100644 --- a/src/featureforest/_segmentation_widget.py +++ b/src/featureforest/_segmentation_widget.py @@ -32,10 +32,10 @@ from .utils import ( colormaps, config ) -from .utils.postprocess import ( - postprocess_segmentation, +from .postprocess import ( + postprocess, + postprocess_with_sam ) -from .utils.postprocess_with_sam import postprocess_segmentations_with_sam class SegmentationWidget(QWidget): @@ -45,6 +45,8 @@ def __init__(self, napari_viewer: napari.Viewer): self.image_layer = None self.gt_layer = None self.segmentation_layer = None + self.segmentation_result = None + self.postprocess_result = None self.storage = None self.rf_model = None self.model_adapter = None @@ -276,17 +278,25 @@ def create_prediction_ui(self): self.base_layout.addWidget(gbox) def create_postprocessing_ui(self): - area_label = QLabel("Area Threshold(%):") + smooth_label = QLabel("Smoothing Iterations:") + self.smoothing_iteration_textbox = QLineEdit() + self.smoothing_iteration_textbox.setText("25") + self.smoothing_iteration_textbox.setValidator( + QIntValidator(0, 1000000) + ) + + area_label = QLabel("Area Threshold:") self.area_threshold_textbox = QLineEdit() self.area_threshold_textbox.setText("15") self.area_threshold_textbox.setValidator( - QDoubleValidator( - 1.000, 100.000, 3, notation=QDoubleValidator.StandardNotation - ) + QIntValidator(0, 2147483647) ) self.area_threshold_textbox.setToolTip( "Keeps only regions with area above the threshold." ) + self.area_percent_radiobutton = QRadioButton("percentage") + self.area_percent_radiobutton.setChecked(True) + self.area_abs_radiobutton = QRadioButton("absolute") self.sam_post_checkbox = QCheckBox("Use SAM Predictor") sam_label = QLabel( @@ -297,20 +307,38 @@ def create_postprocessing_ui(self): postprocess_button = QPushButton("Apply") postprocess_button.setMinimumWidth(150) - postprocess_button.clicked.connect(self.postprocess_slice) + self.toggle_postprocess_button = QPushButton("Toggle Off") + self.toggle_postprocess_button.setMinimumWidth(150) + self.toggle_postprocess_button.setCheckable(True) + self.toggle_postprocess_button.setChecked(True) + self.toggle_postprocess_button.clicked.connect(self.toggle_postprocess) + postprocess_button.clicked.connect(self.postprocess_segmentation) postprocess_all_button = QPushButton("Apply to Stack") postprocess_all_button.setMinimumWidth(150) + postprocess_all_button.clicked.connect( + lambda: self.postprocess_segmentation(whole_stack=True) + ) layout = QVBoxLayout() vbox = QVBoxLayout() vbox.setContentsMargins(0, 0, 0, 0) + vbox.addWidget(smooth_label) + vbox.addWidget(self.smoothing_iteration_textbox) vbox.addWidget(area_label) + hbox = QHBoxLayout() + hbox.addWidget(self.area_percent_radiobutton) + hbox.addWidget(self.area_abs_radiobutton) + vbox.addLayout(hbox) vbox.addWidget(self.area_threshold_textbox) - vbox.addSpacing(10) + vbox.addSpacing(15) vbox.addWidget(self.sam_post_checkbox) vbox.addWidget(sam_label) vbox.addSpacing(5) - vbox.addWidget(postprocess_button, alignment=Qt.AlignLeft) + hbox = QHBoxLayout() + hbox.setContentsMargins(0, 0, 0, 0) + hbox.addWidget(postprocess_button) + hbox.addWidget(self.toggle_postprocess_button) + vbox.addLayout(hbox) vbox.addWidget(postprocess_all_button, alignment=Qt.AlignLeft) # vbox.addSpacing(20) layout.addLayout(vbox) @@ -615,6 +643,8 @@ def run_prediction(self, slice_indices, img_height, img_width): cm, _ = colormaps.create_colormap(len(np.unique(segmentation_image))) self.segmentation_layer.colormap = cm self.segmentation_layer.refresh() + # keep the segmentations before applying postprocessing + self.segmentation_result = self.segmentation_layer.data.copy() def predict_slice(self, rf_model, slice_index, img_height, img_width): """Predict a slice patch by patch""" @@ -666,24 +696,80 @@ def prediction_is_done(self): print("Prediction is done!") notif.show_info("Prediction is done!") - def postprocess_slice(self, whole_stack=False): - curr_slice = self.viewer.dims.current_step[0] - segmentation_image = self.segmentation_layer.data[curr_slice] + def postprocess_segmentation(self, whole_stack=False): + self.segmentation_layer = get_layer( + self.viewer, + self.prediction_layer_combo.currentText(), config.NAPARI_LABELS_LAYER + ) + if self.segmentation_layer is None: + notif.show_error("No segmentation layer is selected!") + return + + smoothing_iterations = 25 + if len(self.smoothing_iteration_textbox.text()) > 0: + smoothing_iterations = int(self.smoothing_iteration_textbox.text()) - area_threshold = None + area_threshold = 0 if len(self.area_threshold_textbox.text()) > 0: - area_threshold = float(self.area_threshold_textbox.text()) / 100 + area_threshold = int(self.area_threshold_textbox.text()) + area_is_absolute = False + if self.area_abs_radiobutton.isChecked(): + area_is_absolute = True - if self.sam_post_checkbox.checkState() == Qt.Checked: - segmentation_image = postprocess_segmentations_with_sam( - segmentation_image, area_threshold - ) + num_slices, img_height, img_width = get_stack_dims(self.image_layer.data) + slice_indices = [] + if not whole_stack: + # only predict the current slice + slice_indices.append(self.viewer.dims.current_step[0]) else: - segmentation_image = postprocess_segmentation( - segmentation_image, area_threshold - ) + slice_indices = range(num_slices) + + self.postprocess_result = np.zeros( + (num_slices, img_height, img_width), dtype=np.uint8 + ) + for slice_index in np_progress(slice_indices): + if self.sam_post_checkbox.checkState() == Qt.Checked: + # TODO + self.postprocess_result[slice_index] = postprocess_with_sam( + self.segmentation_result[slice_index], area_threshold + ) + else: + self.postprocess_result[slice_index] = postprocess( + self.segmentation_result[slice_index], + smoothing_iterations, area_threshold, area_is_absolute + ) - self.segmentation_layer.data[curr_slice] = segmentation_image + if self.toggle_postprocess_button.isChecked(): + if whole_stack: + self.segmentation_layer.data = self.postprocess_result + else: + curr_slice = self.viewer.dims.current_step[0] + self.segmentation_layer.data[curr_slice] = self.postprocess_result[ + curr_slice + ] + self.segmentation_layer.refresh() + + def toggle_postprocess(self): + self.segmentation_layer = get_layer( + self.viewer, + self.prediction_layer_combo.currentText(), config.NAPARI_LABELS_LAYER + ) + if self.segmentation_layer is None: + notif.show_error("No segmentation layer is selected!") + return + if self.postprocess_result is None or self.segmentation_result is None: + notif.show_warning("No postprocessing is done yet!") + return + + if self.toggle_postprocess_button.isChecked(): + # show the post-processing result + self.toggle_postprocess_button.setText("Toggle Off") + self.segmentation_layer.data = self.postprocess_result + else: + # turn off the post-processing result + self.toggle_postprocess_button.setText("Toggle On") + self.segmentation_layer.data = self.segmentation_result + self.segmentation_layer.refresh() def export_segmentation(self, out_format="nrrd"): if self.segmentation_layer is None: diff --git a/src/featureforest/postprocess/__init__.py b/src/featureforest/postprocess/__init__.py new file mode 100644 index 0000000..7e39dbf --- /dev/null +++ b/src/featureforest/postprocess/__init__.py @@ -0,0 +1,8 @@ +from .postprocess import postprocess +from .postprocess_with_sam import postprocess_with_sam + + +__all__ = [ + "postprocess", + "postprocess_with_sam" +] diff --git a/src/featureforest/postprocess/mean_curvature.py b/src/featureforest/postprocess/mean_curvature.py new file mode 100644 index 0000000..800f946 --- /dev/null +++ b/src/featureforest/postprocess/mean_curvature.py @@ -0,0 +1,86 @@ +import numpy as np + + +def apply_threshold(img, t=127): + img_copy = img.copy() + img_copy[img_copy > t] = 255 + img_copy[img_copy <= t] = 0 + + return img_copy + + +def get_mean_curvature(input_image: np.ndarray) -> np.ndarray: + """Returns the mean curvature of the input image. + + Args: + input_img (np.ndarray): 2D input image + + Returns: + np.ndarray: smoothed output image (float32) + """ + # make float copy of the input + output_image = input_image.copy().astype(np.float32) + # pad image so the algorithm can be applied to the edges too. + padded = np.pad(output_image, 1, mode="edge") + num_rows, num_cols = padded.shape + # for each column there are right and left columns + right_indices = np.arange(2, num_cols) + left_indices = np.arange(0, num_cols - 2) + # for each row there are top and bottom rows + top_indices = np.arange(0, num_rows - 2) + bottom_indices = np.arange(2, num_rows) + # calculate differences on original image(not padded), + # so we use (1:-1) to ignore pads. + # dx: RIGHT(pix) - LEFT(pix) + dx = padded[1:-1, right_indices] - padded[1:-1, left_indices] + dx2 = np.power(dx, 2) + # dy: BOTTOM(pix) - TOP(pix) + dy = padded[bottom_indices, 1:-1] - padded[top_indices, 1:-1] + dy2 = np.power(dy, 2) + # second order differences + # dxx: RIGHT(pix) + LEFT(pix) - 2 * pix + dxx = padded[1:-1, right_indices] + padded[1:-1, left_indices] - 2 * output_image + # dyy: BOTTOM(pix) + TOP(pix) - 2 * pix + dyy = padded[bottom_indices, 1:-1] + padded[top_indices, 1:-1] - 2 * output_image + + # diagonal neighbors + dxy = 0.25 * ( + # BOTTOM-RIGHT(pix) - TOP-RIGHT(pix) - BOTTOM-LEFT(pix) + TOP-LEFT(pix) + padded[2:, 2:] - padded[:-2, 2:] - padded[2:, :-2] + padded[:-2, :-2] + ) + # mean curvature + magnitudes = np.sqrt(dx2 + dy2) # as coefficient of mean curvature + numerator = dx2 * dyy + dy2 * dxx - 2 * dx * dy * dxy + denom = np.sqrt(np.power(dx2 + dy2, 3)) + denom[denom == 0] = 1 # to handle zero division + mean_curvatures = numerator / denom + output_image += 0.25 * magnitudes * mean_curvatures + + return output_image + + +def mean_curvature_smoothing( + input_image: np.ndarray, num_iterations: int = 1 +) -> np.ndarray: + """Smooth the input image by applying mean curvature algorithm in iterations. + + Args: + input_image (np.ndarray): @d input image + num_iterations (int, optional): number of smoothing iterations. Defaults to 1. + + Returns: + np.ndarray: smoothed output image (uint8) + """ + output_image = get_mean_curvature(input_image) + for _ in range(num_iterations - 1): + output_image = get_mean_curvature(output_image) + # scale image in [0, 255] + output_image = (output_image - output_image.min()) * ( + 255 / (output_image.max() - output_image.min()) + ) + output_image = output_image.astype(np.uint8) + # apply a threshold to get the final mask + # TODO: find a proper threshold in a more intelligent way! + output_image = apply_threshold(output_image, 30) + + return output_image diff --git a/src/featureforest/utils/postprocess.py b/src/featureforest/postprocess/postprocess.py similarity index 59% rename from src/featureforest/utils/postprocess.py rename to src/featureforest/postprocess/postprocess.py index e659163..345815c 100644 --- a/src/featureforest/utils/postprocess.py +++ b/src/featureforest/postprocess/postprocess.py @@ -1,56 +1,74 @@ import numpy as np import cv2 - -def process_similarity_matrix(sim_mat): - """Smooth out given similarity matrix.""" - sim_mat_uint8 = cv2.normalize(sim_mat, None, 0, 255, cv2.NORM_MINMAX, cv2.CV_8U) - sim_mat_smoothed = cv2.medianBlur(sim_mat_uint8, 13) / 255. - - return sim_mat_smoothed - - -def postprocess_label(bin_image, area_threshold: float = None): - # image morphology - elipse1 = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (7, 7)) - elipse2 = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5)) - morphed_img = cv2.morphologyEx(bin_image, cv2.MORPH_DILATE, elipse1, iterations=1) - morphed_img = cv2.morphologyEx(morphed_img, cv2.MORPH_ERODE, elipse2, iterations=1) - # remove components with small areas +from .mean_curvature import mean_curvature_smoothing + + +def postprocess_label_mask( + bin_image: np.ndarray, + smoothing_iterations: int, + area_threshold: int, + area_is_abs: bool +): + # image morphology: trying to close small holes + elipse = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3)) + smoothed_mask = cv2.morphologyEx(bin_image, cv2.MORPH_DILATE, elipse, iterations=2) + smoothed_mask = cv2.morphologyEx(smoothed_mask, cv2.MORPH_ERODE, elipse, iterations=2) + # iterative mean curvature smoothing + smoothed_mask = mean_curvature_smoothing(smoothed_mask, smoothing_iterations) + + # remove regions with small areas # stats: left, top, height, width, area num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats( - morphed_img, connectivity=8, ltype=cv2.CV_32S + smoothed_mask, connectivity=8, ltype=cv2.CV_32S ) - # if there is only background or only one component besides it: + # if there is only background or only one region plus bg: # num_labels = 1 -> only bg - # num_labels = 2 -> only one component + # num_labels = 2 -> only one region if num_labels < 3: - return morphed_img + return smoothed_mask - # get not background(0) areas + # get not background areas (not 0) areas = stats[1:, -1] - if area_threshold is None: - area_threshold = np.quantile(areas, 0.5) - else: - area_threshold = np.quantile(areas, area_threshold) + # if given threshold is a percentage then get corresponding area value + if not area_is_abs: + if area_threshold > 100: + area_threshold = 100 + area_threshold = np.percentile(areas, area_threshold) + # eliminate small regions small_parts = np.argwhere(stats[:, -1] <= area_threshold) - morphed_img[np.isin(labels, small_parts)] = 0 + smoothed_mask[np.isin(labels, small_parts)] = 0 - return morphed_img + return smoothed_mask -def postprocess_segmentation(segmentation_image, area_threshold: float = None): - final_image = np.zeros_like(segmentation_image, dtype=np.uint8) +def postprocess( + segmentation_image, + smoothing_iterations: int = 25, + area_threshold: int = 15, + area_is_abs: bool = False +): + final_mask = np.zeros_like(segmentation_image, dtype=np.uint8) # postprocessing gets done for each label's segments. class_labels = [c for c in np.unique(segmentation_image) if c > 0] for label in class_labels: # make a binary image for the label bin_image = (segmentation_image == label).astype(np.uint8) * 255 - processed_image = postprocess_label(bin_image, area_threshold) + processed_mask = postprocess_label_mask( + bin_image, smoothing_iterations, area_threshold, area_is_abs + ) # put the processed image into final result image - final_image[processed_image == 255] = label + final_mask[processed_mask == 255] = label - return final_image + return final_mask + + +def process_similarity_matrix(sim_mat): + """Smooth out given similarity matrix.""" + sim_mat_uint8 = cv2.normalize(sim_mat, None, 0, 255, cv2.NORM_MINMAX, cv2.CV_8U) + sim_mat_smoothed = cv2.medianBlur(sim_mat_uint8, 13) / 255. + + return sim_mat_smoothed def get_furthest_point_from_edge(mask): diff --git a/src/featureforest/utils/postprocess_with_sam.py b/src/featureforest/postprocess/postprocess_with_sam.py similarity index 99% rename from src/featureforest/utils/postprocess_with_sam.py rename to src/featureforest/postprocess/postprocess_with_sam.py index a095171..6617291 100644 --- a/src/featureforest/utils/postprocess_with_sam.py +++ b/src/featureforest/postprocess/postprocess_with_sam.py @@ -190,7 +190,7 @@ def get_sam_mask( return final_mask -def postprocess_segmentations_with_sam( +def postprocess_with_sam( segmentations_image: np.ndarray, area_threshold: float = None ) -> np.ndarray: From 4a201a361b808840294be52c9f7b75cc3696ba47 Mon Sep 17 00:00:00 2001 From: Mehdi Seifi Date: Fri, 26 Jul 2024 11:24:56 +0200 Subject: [PATCH 4/6] updates postprocess using SAM; more doc strings --- src/featureforest/_segmentation_widget.py | 4 +- src/featureforest/postprocess/postprocess.py | 31 +++++++++- .../postprocess/postprocess_with_sam.py | 62 ++++++------------- 3 files changed, 49 insertions(+), 48 deletions(-) diff --git a/src/featureforest/_segmentation_widget.py b/src/featureforest/_segmentation_widget.py index a84b3ba..e4cc9b6 100644 --- a/src/featureforest/_segmentation_widget.py +++ b/src/featureforest/_segmentation_widget.py @@ -729,9 +729,9 @@ def postprocess_segmentation(self, whole_stack=False): ) for slice_index in np_progress(slice_indices): if self.sam_post_checkbox.checkState() == Qt.Checked: - # TODO self.postprocess_result[slice_index] = postprocess_with_sam( - self.segmentation_result[slice_index], area_threshold + self.segmentation_result[slice_index], + smoothing_iterations, area_threshold, area_is_absolute ) else: self.postprocess_result[slice_index] = postprocess( diff --git a/src/featureforest/postprocess/postprocess.py b/src/featureforest/postprocess/postprocess.py index 345815c..2703482 100644 --- a/src/featureforest/postprocess/postprocess.py +++ b/src/featureforest/postprocess/postprocess.py @@ -9,7 +9,18 @@ def postprocess_label_mask( smoothing_iterations: int, area_threshold: int, area_is_abs: bool -): +) -> np.ndarray: + """Post-process a binary mask image (of a class label) + + Args: + bin_image (np.ndarray): input mask image + smoothing_iterations (int): number of smoothing iterations + area_threshold (int): threshold to remove small regions + area_is_abs (bool): False if the threshold is a percentage + + Returns: + np.ndarray: post-processed mask image + """ # image morphology: trying to close small holes elipse = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3)) smoothed_mask = cv2.morphologyEx(bin_image, cv2.MORPH_DILATE, elipse, iterations=2) @@ -43,11 +54,25 @@ def postprocess_label_mask( def postprocess( - segmentation_image, + segmentation_image: np.ndarray, smoothing_iterations: int = 25, area_threshold: int = 15, area_is_abs: bool = False -): +) -> np.ndarray: + """Post-process a segmentation image mask containing multiple classes. + + Args: + segmentation_image (np.ndarray): input segmentation image + smoothing_iterations (int, optional): number of smoothing iterations. + Defaults to 25. + area_threshold (int, optional): threshold to remove small regions. + Defaults to 15. + area_is_abs (bool, optional): False if the threshold is a percentage. + Defaults to False. + + Returns: + np.ndarray: post-processed segmentation image + """ final_mask = np.zeros_like(segmentation_image, dtype=np.uint8) # postprocessing gets done for each label's segments. class_labels = [c for c in np.unique(segmentation_image) if c > 0] diff --git a/src/featureforest/postprocess/postprocess_with_sam.py b/src/featureforest/postprocess/postprocess_with_sam.py index 6617291..795b137 100644 --- a/src/featureforest/postprocess/postprocess_with_sam.py +++ b/src/featureforest/postprocess/postprocess_with_sam.py @@ -12,6 +12,7 @@ from segment_anything_hq import SamPredictor from featureforest.utils.downloader import download_model +from .postprocess import postprocess_label_mask def get_light_hq_sam() -> Sam: @@ -109,39 +110,6 @@ def get_bounding_boxes(image: np.ndarray) -> List[Rect]: return bboxes -def postprocess_label(bin_image: np.ndarray, area_threshold: float = None) -> np.ndarray: - """Postprocess the label segmentation mask - - Args: - bin_image (np.ndarray): input binary image - area_threshold (float, optional): threshold to remove small parts in the mask. - Defaults to None. - - Returns: - np.ndarray: post-processed mask - """ - # remove noises - kernel = np.ones((3, 3), dtype=np.uint8) - morphed_img = cv2.morphologyEx(bin_image, cv2.MORPH_CLOSE, kernel, iterations=1) - morphed_img = cv2.morphologyEx(morphed_img, cv2.MORPH_OPEN, kernel, iterations=2) - # remove components with small areas - # stats: left, top, height, width, area - num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats( - morphed_img, connectivity=8, ltype=cv2.CV_32S - ) - # get not background(0) areas - areas = stats[1:, -1] - if area_threshold is None: - area_threshold = np.quantile(areas, 0.5) - else: - area_threshold = np.quantile(areas, area_threshold) - small_parts = np.argwhere(stats[:, -1] <= area_threshold) - morphed_img[np.isin(labels, small_parts)] = 0 - # print(small_parts.sum(0)) - - return morphed_img - - def get_sam_mask( predictor: SamPredictor, image: np.ndarray, bboxes: List[Rect] ) -> np.ndarray: @@ -191,32 +159,40 @@ def get_sam_mask( def postprocess_with_sam( - segmentations_image: np.ndarray, - area_threshold: float = None + segmentation_image: np.ndarray, + smoothing_iterations: int = 25, + area_threshold: int = 15, + area_is_abs: bool = False ) -> np.ndarray: """Post-processes segmentations using SAM predictor. Args: - segmentations_image (np.ndarray): input segmentation image - area_threshold (float, optional): threshold to remove small parts in the mask. - Defaults to None. + segmentation_image (np.ndarray): input segmentation image + smoothing_iterations (int, optional): number of smoothing iterations. + Defaults to 25. + area_threshold (int, optional): threshold to remove small regions. + Defaults to 15. + area_is_abs (bool, optional): False if the threshold is a percentage. + Defaults to False. Returns: - np.ndarray: _description_ + np.ndarray: post-processed segmentation image """ # init a sam predictor using light hq sam predictor = SamPredictor(get_light_hq_sam()) - final_image = np.zeros_like(segmentations_image, dtype=np.uint8) + final_image = np.zeros_like(segmentation_image, dtype=np.uint8) # postprocessing gets done for each class segmentation. bg_label = 0 - class_labels = [c for c in np.unique(segmentations_image) if c > bg_label] + class_labels = [c for c in np.unique(segmentation_image) if c > bg_label] for label in np_progress( class_labels, desc="Getting SAM masks for each class" ): # make a binary image for the label (class) - bin_image = (segmentations_image == label).astype(np.uint8) * 255 - processed_mask = postprocess_label(bin_image, area_threshold) + bin_image = (segmentation_image == label).astype(np.uint8) * 255 + processed_mask = postprocess_label_mask( + bin_image, smoothing_iterations, area_threshold, area_is_abs + ) # get component bounding boxes w_bboxes = get_watershed_bboxes(processed_mask) bboxes = get_bounding_boxes(processed_mask) From a04cafcf69d07d24b6f75e6b5b7a92509396b38f Mon Sep 17 00:00:00 2001 From: Mehdi Seifi Date: Fri, 26 Jul 2024 11:34:15 +0200 Subject: [PATCH 5/6] mean curvature: use mask mean as the threshold --- src/featureforest/postprocess/mean_curvature.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/featureforest/postprocess/mean_curvature.py b/src/featureforest/postprocess/mean_curvature.py index 800f946..ea10ca0 100644 --- a/src/featureforest/postprocess/mean_curvature.py +++ b/src/featureforest/postprocess/mean_curvature.py @@ -81,6 +81,6 @@ def mean_curvature_smoothing( output_image = output_image.astype(np.uint8) # apply a threshold to get the final mask # TODO: find a proper threshold in a more intelligent way! - output_image = apply_threshold(output_image, 30) + output_image = apply_threshold(output_image, t=output_image.mean()) return output_image From 8cf7f8e85f8af42bb40cd1ffe2deb7ff185f2544 Mon Sep 17 00:00:00 2001 From: Mehdi Seifi Date: Wed, 31 Jul 2024 11:25:06 +0200 Subject: [PATCH 6/6] updated thresholding after mean-curvature using Otsu --- src/featureforest/postprocess/mean_curvature.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/src/featureforest/postprocess/mean_curvature.py b/src/featureforest/postprocess/mean_curvature.py index ea10ca0..9bd7f3a 100644 --- a/src/featureforest/postprocess/mean_curvature.py +++ b/src/featureforest/postprocess/mean_curvature.py @@ -1,10 +1,16 @@ import numpy as np +import cv2 -def apply_threshold(img, t=127): +def apply_threshold(img, t=None): img_copy = img.copy() - img_copy[img_copy > t] = 255 - img_copy[img_copy <= t] = 0 + if t is None: + t, img_copy = cv2.threshold( + img_copy, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU + ) + else: + img_copy[img_copy > t] = 255 + img_copy[img_copy <= t] = 0 return img_copy @@ -80,7 +86,6 @@ def mean_curvature_smoothing( ) output_image = output_image.astype(np.uint8) # apply a threshold to get the final mask - # TODO: find a proper threshold in a more intelligent way! - output_image = apply_threshold(output_image, t=output_image.mean()) + output_image = apply_threshold(output_image) return output_image