Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update Post-processing #11

Merged
merged 6 commits into from
Jul 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions src/featureforest/_sam_prompt_segmentation_widget.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@
from .utils import (
colormaps, config
)
from .utils.postprocess import (
process_similarity_matrix, postprocess_label,
from .postprocess.postprocess import (
process_similarity_matrix, postprocess_label_mask,
generate_mask_prompts,
)

Expand Down Expand Up @@ -519,7 +519,7 @@ def run_prediction(self, slice_indices, prompts_mask):
high_sim_mask[
sim_mat >= float(self.similarity_threshold_textbox.text())
] = 255
post_high_sim_mask = postprocess_label(high_sim_mask, 0.15)
post_high_sim_mask = postprocess_label_mask(high_sim_mask, 0.15)
positive_prompts = generate_mask_prompts(post_high_sim_mask)

if len(positive_prompts) == 0:
Expand Down
206 changes: 157 additions & 49 deletions src/featureforest/_segmentation_widget.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,10 @@
from .utils import (
colormaps, config
)
from .utils.postprocess import (
postprocess_segmentation,
from .postprocess import (
postprocess,
postprocess_with_sam
)
from .utils.postprocess_with_sam import postprocess_segmentations_with_sam


class SegmentationWidget(QWidget):
Expand All @@ -45,6 +45,8 @@ def __init__(self, napari_viewer: napari.Viewer):
self.image_layer = None
self.gt_layer = None
self.segmentation_layer = None
self.segmentation_result = None
self.postprocess_result = None
self.storage = None
self.rf_model = None
self.model_adapter = None
Expand All @@ -69,6 +71,7 @@ def prepare_widget(self):
self.create_label_stats_ui()
self.create_train_ui()
self.create_prediction_ui()
self.create_postprocessing_ui()

scroll_content = QWidget()
scroll_content.setLayout(self.base_layout)
Expand Down Expand Up @@ -221,31 +224,12 @@ def create_prediction_ui(self):
self.prediction_layer_combo = QComboBox()
self.prediction_layer_combo.setEnabled(False)
self.seg_add_radiobutton = QRadioButton("Add Segmentations")
self.seg_add_radiobutton.setChecked(True)
self.seg_add_radiobutton.setChecked(False)
self.seg_add_radiobutton.setEnabled(False)
self.seg_replace_radiobutton = QRadioButton("Replace Segmentations")
self.seg_replace_radiobutton.setChecked(True)
self.seg_replace_radiobutton.setEnabled(False)

# post-process ui
area_label = QLabel("Area Threshold(%):")
self.area_threshold_textbox = QLineEdit()
self.area_threshold_textbox.setText("15")
self.area_threshold_textbox.setValidator(
QDoubleValidator(
1.000, 100.000, 3, notation=QDoubleValidator.StandardNotation
)
)
self.area_threshold_textbox.setToolTip(
"Keeps regions with area above the threshold percentage."
)
self.area_threshold_textbox.setEnabled(False)

self.sam_post_checkbox = QCheckBox("Use SAM Predictor")
self.sam_post_checkbox.setEnabled(False)

self.postprocess_checkbox = QCheckBox("Postprocess Segmentations")
self.postprocess_checkbox.stateChanged.connect(self.postprocess_checkbox_changed)

predict_button = QPushButton("Predict Slice")
predict_button.setMinimumWidth(150)
predict_button.clicked.connect(lambda: self.predict(whole_stack=False))
Expand Down Expand Up @@ -274,11 +258,6 @@ def create_prediction_ui(self):
hbox.addWidget(self.seg_add_radiobutton)
hbox.addWidget(self.seg_replace_radiobutton)
vbox.addLayout(hbox)
vbox.addWidget(self.postprocess_checkbox)
vbox.addWidget(area_label)
vbox.addWidget(self.area_threshold_textbox)
vbox.addWidget(self.sam_post_checkbox)
vbox.addSpacing(20)
vbox.addWidget(predict_button, alignment=Qt.AlignLeft)
hbox = QHBoxLayout()
hbox.setContentsMargins(0, 0, 0, 0)
Expand All @@ -298,17 +277,84 @@ def create_prediction_ui(self):
gbox.setLayout(layout)
self.base_layout.addWidget(gbox)

def create_postprocessing_ui(self):
smooth_label = QLabel("Smoothing Iterations:")
self.smoothing_iteration_textbox = QLineEdit()
self.smoothing_iteration_textbox.setText("25")
self.smoothing_iteration_textbox.setValidator(
QIntValidator(0, 1000000)
)

area_label = QLabel("Area Threshold:")
self.area_threshold_textbox = QLineEdit()
self.area_threshold_textbox.setText("15")
self.area_threshold_textbox.setValidator(
QIntValidator(0, 2147483647)
)
self.area_threshold_textbox.setToolTip(
"Keeps only regions with area above the threshold."
)
self.area_percent_radiobutton = QRadioButton("percentage")
self.area_percent_radiobutton.setChecked(True)
self.area_abs_radiobutton = QRadioButton("absolute")

self.sam_post_checkbox = QCheckBox("Use SAM Predictor")
sam_label = QLabel(
"This will generate prompts for SAM Predictor using bounding boxes"
" around segmented regions."
)
sam_label.setWordWrap(True)

postprocess_button = QPushButton("Apply")
postprocess_button.setMinimumWidth(150)
self.toggle_postprocess_button = QPushButton("Toggle Off")
self.toggle_postprocess_button.setMinimumWidth(150)
self.toggle_postprocess_button.setCheckable(True)
self.toggle_postprocess_button.setChecked(True)
self.toggle_postprocess_button.clicked.connect(self.toggle_postprocess)
postprocess_button.clicked.connect(self.postprocess_segmentation)
postprocess_all_button = QPushButton("Apply to Stack")
postprocess_all_button.setMinimumWidth(150)
postprocess_all_button.clicked.connect(
lambda: self.postprocess_segmentation(whole_stack=True)
)

layout = QVBoxLayout()
vbox = QVBoxLayout()
vbox.setContentsMargins(0, 0, 0, 0)
vbox.addWidget(smooth_label)
vbox.addWidget(self.smoothing_iteration_textbox)
vbox.addWidget(area_label)
hbox = QHBoxLayout()
hbox.addWidget(self.area_percent_radiobutton)
hbox.addWidget(self.area_abs_radiobutton)
vbox.addLayout(hbox)
vbox.addWidget(self.area_threshold_textbox)
vbox.addSpacing(15)
vbox.addWidget(self.sam_post_checkbox)
vbox.addWidget(sam_label)
vbox.addSpacing(5)
hbox = QHBoxLayout()
hbox.setContentsMargins(0, 0, 0, 0)
hbox.addWidget(postprocess_button)
hbox.addWidget(self.toggle_postprocess_button)
vbox.addLayout(hbox)
vbox.addWidget(postprocess_all_button, alignment=Qt.AlignLeft)
# vbox.addSpacing(20)
layout.addLayout(vbox)

gbox = QGroupBox()
gbox.setTitle("Post-processing")
gbox.setMinimumWidth(100)
gbox.setLayout(layout)
self.base_layout.addWidget(gbox)

def new_layer_checkbox_changed(self):
state = self.new_layer_checkbox.checkState()
self.prediction_layer_combo.setEnabled(state == Qt.Unchecked)
self.seg_add_radiobutton.setEnabled(state == Qt.Unchecked)
self.seg_replace_radiobutton.setEnabled(state == Qt.Unchecked)

def postprocess_checkbox_changed(self):
state = self.postprocess_checkbox.checkState()
self.area_threshold_textbox.setEnabled(state == Qt.Checked)
self.sam_post_checkbox.setEnabled(state == Qt.Checked)

def check_input_layers(self, event: Event):
curr_text = self.image_combo.currentText()
self.image_combo.clear()
Expand Down Expand Up @@ -597,6 +643,8 @@ def run_prediction(self, slice_indices, img_height, img_width):
cm, _ = colormaps.create_colormap(len(np.unique(segmentation_image)))
self.segmentation_layer.colormap = cm
self.segmentation_layer.refresh()
# keep the segmentations before applying postprocessing
self.segmentation_result = self.segmentation_layer.data.copy()

def predict_slice(self, rf_model, slice_index, img_height, img_width):
"""Predict a slice patch by patch"""
Expand Down Expand Up @@ -627,21 +675,6 @@ def predict_slice(self, rf_model, slice_index, img_height, img_width):
# skip paddings
segmentation_image = segmentation_image[:img_height, :img_width]

# check for postprocessing
if self.postprocess_checkbox.checkState() == Qt.Checked:
# apply postprocessing
area_threshold = None
if len(self.area_threshold_textbox.text()) > 0:
area_threshold = float(self.area_threshold_textbox.text()) / 100
if self.sam_post_checkbox.checkState() == Qt.Checked:
segmentation_image = postprocess_segmentations_with_sam(
segmentation_image, area_threshold
)
else:
segmentation_image = postprocess_segmentation(
segmentation_image, area_threshold
)

return segmentation_image

def stop_predict(self):
Expand All @@ -663,6 +696,81 @@ def prediction_is_done(self):
print("Prediction is done!")
notif.show_info("Prediction is done!")

def postprocess_segmentation(self, whole_stack=False):
self.segmentation_layer = get_layer(
self.viewer,
self.prediction_layer_combo.currentText(), config.NAPARI_LABELS_LAYER
)
if self.segmentation_layer is None:
notif.show_error("No segmentation layer is selected!")
return

smoothing_iterations = 25
if len(self.smoothing_iteration_textbox.text()) > 0:
smoothing_iterations = int(self.smoothing_iteration_textbox.text())

area_threshold = 0
if len(self.area_threshold_textbox.text()) > 0:
area_threshold = int(self.area_threshold_textbox.text())
area_is_absolute = False
if self.area_abs_radiobutton.isChecked():
area_is_absolute = True

num_slices, img_height, img_width = get_stack_dims(self.image_layer.data)
slice_indices = []
if not whole_stack:
# only predict the current slice
slice_indices.append(self.viewer.dims.current_step[0])
else:
slice_indices = range(num_slices)

self.postprocess_result = np.zeros(
(num_slices, img_height, img_width), dtype=np.uint8
)
for slice_index in np_progress(slice_indices):
if self.sam_post_checkbox.checkState() == Qt.Checked:
self.postprocess_result[slice_index] = postprocess_with_sam(
self.segmentation_result[slice_index],
smoothing_iterations, area_threshold, area_is_absolute
)
else:
self.postprocess_result[slice_index] = postprocess(
self.segmentation_result[slice_index],
smoothing_iterations, area_threshold, area_is_absolute
)

if self.toggle_postprocess_button.isChecked():
if whole_stack:
self.segmentation_layer.data = self.postprocess_result
else:
curr_slice = self.viewer.dims.current_step[0]
self.segmentation_layer.data[curr_slice] = self.postprocess_result[
curr_slice
]
self.segmentation_layer.refresh()

def toggle_postprocess(self):
self.segmentation_layer = get_layer(
self.viewer,
self.prediction_layer_combo.currentText(), config.NAPARI_LABELS_LAYER
)
if self.segmentation_layer is None:
notif.show_error("No segmentation layer is selected!")
return
if self.postprocess_result is None or self.segmentation_result is None:
notif.show_warning("No postprocessing is done yet!")
return

if self.toggle_postprocess_button.isChecked():
# show the post-processing result
self.toggle_postprocess_button.setText("Toggle Off")
self.segmentation_layer.data = self.postprocess_result
else:
# turn off the post-processing result
self.toggle_postprocess_button.setText("Toggle On")
self.segmentation_layer.data = self.segmentation_result
self.segmentation_layer.refresh()

def export_segmentation(self, out_format="nrrd"):
if self.segmentation_layer is None:
notif.show_error("No segmentation layer is selected!")
Expand Down
8 changes: 8 additions & 0 deletions src/featureforest/postprocess/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from .postprocess import postprocess
from .postprocess_with_sam import postprocess_with_sam


__all__ = [
"postprocess",
"postprocess_with_sam"
]
Loading
Loading