Skip to content

Commit

Permalink
Merge pull request #13 from juglab/ms/update/export
Browse files Browse the repository at this point in the history
Updated export functionality
  • Loading branch information
mese79 authored Sep 9, 2024
2 parents 4255d84 + 95d13e9 commit 2a43320
Show file tree
Hide file tree
Showing 5 changed files with 135 additions and 29 deletions.
20 changes: 10 additions & 10 deletions LICENSE
Original file line number Diff line number Diff line change
@@ -1,20 +1,20 @@
BSD 3-Clause License

Copyright (c) 2024, Mehdi Seifi
All rights reserved.
Copyright (c) 2024, JugLab

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:

* Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
1. Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.

* Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
2. Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.

* Neither the name of copyright holder nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
3. Neither the name of the copyright holder nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
Expand Down
76 changes: 58 additions & 18 deletions src/featureforest/_segmentation_widget.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
from qtpy.QtGui import QIntValidator, QDoubleValidator

import h5py
import nrrd
import numpy as np
from sklearn.ensemble import RandomForestClassifier

Expand All @@ -39,6 +38,7 @@
postprocess_with_sam_auto,
get_sam_auto_masks
)
from .exports import EXPORTERS


class SegmentationWidget(QWidget):
Expand Down Expand Up @@ -77,6 +77,7 @@ def prepare_widget(self):
self.create_train_ui()
self.create_prediction_ui()
self.create_postprocessing_ui()
self.create_export_ui()

scroll_content = QWidget()
scroll_content.setLayout(self.base_layout)
Expand Down Expand Up @@ -252,10 +253,6 @@ def create_prediction_ui(self):
self.predict_stop_button.setEnabled(False)
self.prediction_progress = QProgressBar()

export_nrrd_button = QPushButton("Export To nrrd")
export_nrrd_button.setMinimumWidth(150)
export_nrrd_button.clicked.connect(lambda: self.export_segmentation("nrrd"))

# layout
layout = QVBoxLayout()
vbox = QVBoxLayout()
Expand All @@ -275,10 +272,6 @@ def create_prediction_ui(self):
hbox.addWidget(self.predict_stop_button, alignment=Qt.AlignLeft)
vbox.addLayout(hbox)
vbox.addWidget(self.prediction_progress)
hbox = QHBoxLayout()
hbox.setContentsMargins(0, 15, 0, 0)
hbox.addWidget(export_nrrd_button, alignment=Qt.AlignLeft)
vbox.addLayout(hbox)
layout.addLayout(vbox)

gbox = QGroupBox()
Expand Down Expand Up @@ -365,8 +358,8 @@ def create_postprocessing_ui(self):
vbox.addSpacing(7)
hbox = QHBoxLayout()
hbox.setContentsMargins(0, 0, 0, 0)
hbox.addWidget(postprocess_button)
hbox.addWidget(postprocess_all_button)
hbox.addWidget(postprocess_button, alignment=Qt.AlignLeft)
hbox.addWidget(postprocess_all_button, alignment=Qt.AlignLeft)
vbox.addLayout(hbox)
# vbox.addSpacing(20)
layout.addLayout(vbox)
Expand All @@ -377,6 +370,41 @@ def create_postprocessing_ui(self):
gbox.setLayout(layout)
self.base_layout.addWidget(gbox)

def create_export_ui(self):
export_label = QLabel("Export Format:")
self.export_format_combo = QComboBox()
for exporter in EXPORTERS:
self.export_format_combo.addItem(exporter)

self.export_postprocess_checkbox = QCheckBox("Export with Post-processing")
self.export_postprocess_checkbox.setChecked(True)
self.export_postprocess_checkbox.setToolTip(
"Export segmentation result with applied post-processing, if available."
)

export_button = QPushButton("Export")
export_button.setMinimumWidth(150)
export_button.clicked.connect(self.export_segmentation)

# layout
layout = QVBoxLayout()
vbox = QVBoxLayout()
vbox.setContentsMargins(0, 0, 0, 0)
hbox = QHBoxLayout()
hbox.setContentsMargins(0, 0, 0, 0)
hbox.addWidget(export_label)
hbox.addWidget(self.export_format_combo)
vbox.addLayout(hbox)
vbox.addWidget(self.export_postprocess_checkbox)
vbox.addWidget(export_button, alignment=Qt.AlignLeft)
layout.addLayout(vbox)

gbox = QGroupBox()
gbox.setTitle("Export")
gbox.setMinimumWidth(100)
gbox.setLayout(layout)
self.base_layout.addWidget(gbox)

def new_layer_checkbox_changed(self):
state = self.new_layer_checkbox.checkState()
self.prediction_layer_combo.setEnabled(state == Qt.Unchecked)
Expand Down Expand Up @@ -809,16 +837,28 @@ def postprocess_segmentation(self, whole_stack=False):

self.postprocess_layer.refresh()

def export_segmentation(self, out_format="nrrd"):
def export_segmentation(self):
if self.segmentation_layer is None:
notif.show_error("No segmentation layer is selected!")
return

exporter = EXPORTERS[self.export_format_combo.currentText()]
# export_format = self.export_format_combo.currentText()
selected_file, _filter = QFileDialog.getSaveFileName(
self, "Jug Lab", ".", "Segmentation(*.nrrd)"
self, "Jug Lab", ".", f"Segmentation(*.{exporter.extension})"
)
if selected_file is not None and len(selected_file) > 0:
if not selected_file.endswith(".nrrd"):
selected_file += ".nrrd"
nrrd.write(selected_file, np.transpose(self.segmentation_layer.data))
notif.show_info("Selected segmentation was exported successfully.")
if selected_file is None or len(selected_file) == 0:
return # user canceled export

if not selected_file.endswith(f".{exporter.extension}"):
selected_file += f".{exporter.extension}"
layer_to_export = self.segmentation_layer
if (
self.export_postprocess_checkbox.isChecked() and
self.postprocess_layer is not None
):
layer_to_export = self.postprocess_layer

exporter.export(layer_to_export, selected_file)

notif.show_info("Selected layer was exported successfully.")
10 changes: 10 additions & 0 deletions src/featureforest/exports/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from .base import (
TiffExporter, NRRDExporter, NumpyExporter
)


EXPORTERS = {
"tiff": TiffExporter(),
"nrrd": NRRDExporter(),
"numpy": NumpyExporter(),
}
56 changes: 56 additions & 0 deletions src/featureforest/exports/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
# from pathlib import Path

import nrrd
import numpy as np
from tifffile import imwrite
from napari.layers import Layer


class BaseExporter:
"""Base Exporter Class: all exporters should be a subclass of this class."""
def __init__(self, name: str = "Base Exporter", extension: str = "bin") -> None:
self.name = name
self.extension = extension

def export(self, layer: Layer, export_file: str) -> None:
"""Export the given layer data
Args:
layer (Layer): layer to export the data from
export_file (str): file path to export
"""
# implement actual export method here
return


class TiffExporter(BaseExporter):
"""Export the layer's data into TIFF format."""
def __init__(self, name: str = "TIFF", extension: str = "tiff") -> None:
super().__init__(name, extension)

def export(self, layer: Layer, export_file: str) -> None:
tiff_data = layer.data.astype(np.uint8)
mask_values = np.unique(tiff_data)
if len(mask_values) == 2:
# this is a binary mask
tiff_data[tiff_data == min(mask_values)] = 0
tiff_data[tiff_data == max(mask_values)] = 255
imwrite(export_file, tiff_data)


class NRRDExporter(BaseExporter):
"""Export the layer's data into NRRD format."""
def __init__(self, name: str = "NRRD", extension: str = "nrrd") -> None:
super().__init__(name, extension)

def export(self, layer: Layer, export_file: str) -> None:
nrrd.write(export_file, np.transpose(layer.data))


class NumpyExporter(BaseExporter):
"""Export the layer's data into a numpy array file."""
def __init__(self, name: str = "Numpy", extension: str = "npy") -> None:
super().__init__(name, extension)

def export(self, layer: Layer, export_file: str) -> None:
return np.save(export_file, layer.data)
2 changes: 1 addition & 1 deletion src/featureforest/utils/colormaps.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

def is_new_napari():
version = napari.__version__.split(".")
return int(version[2]) > 18
return int(version[1]) > 4 or int(version[2]) > 18


def bit_get(val, idx):
Expand Down

0 comments on commit 2a43320

Please sign in to comment.