Skip to content

Commit

Permalink
fix(widget-pred): combine model filters, #308 related
Browse files Browse the repository at this point in the history
Refactor dimensions, modalities, and output types ComboBoxes into a unified component to improve the interface and enhance user experience

Co-authored-by: Talley Lambert <[email protected]>
  • Loading branch information
qin-yu and tlambert03 committed Sep 11, 2024
1 parent b2395ea commit f3e1667
Showing 1 changed file with 50 additions and 49 deletions.
99 changes: 50 additions & 49 deletions plantseg/viewer_napari/widgets/predictions.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import torch.cuda
from magicgui import magicgui
from magicgui.types import Separator
from magicgui.widgets import Container, create_widget
from napari.layers import Image
from napari.types import LayerDataTuple

Expand All @@ -19,7 +20,6 @@
from plantseg.viewer_napari.widgets.segmentation import widget_agglomeration, widget_dt_ws, widget_lifted_multicut
from plantseg.viewer_napari.widgets.utils import schedule_task

ALL = 'All'
ALL_CUDA_DEVICES = [f'cuda:{i}' for i in range(torch.cuda.device_count())]
MPS = ['mps'] if torch.backends.mps.is_available() else []
ALL_DEVICES = ALL_CUDA_DEVICES + MPS + ['cpu']
Expand All @@ -28,6 +28,11 @@
SINGLE_PATCH_MODE = [("Auto", False), ("One (lower VRAM usage)", True)]
ADVANCED_SETTINGS = [("Enable", True), ("Disable", False)]

# Using Enum causes more complexity, stay constant
ALL_DIM = 'All dimensions'
ALL_MOD = 'All modalities'
ALL_TYP = 'All types'


########################################################################################################################
# #
Expand All @@ -45,6 +50,36 @@ def to_choices(cls):
return [(mode.value, mode) for mode in cls]


model_filters = Container(
widgets=[
create_widget(
annotation=str,
name="dimensionality",
label='Dimensionality',
widget_type='ComboBox',
options={'choices': [ALL_DIM] + model_zoo.get_unique_dimensionalities()},
),
create_widget(
annotation=str,
name="modality",
label='Microscopy modality',
widget_type='ComboBox',
options={'choices': [ALL_MOD] + model_zoo.get_unique_modalities()},
),
create_widget(
annotation=str,
name="output_type",
label='Prediction type',
widget_type='ComboBox',
options={'choices': [ALL_TYP] + model_zoo.get_unique_output_types()},
),
],
label='Model filters',
layout="horizontal",
labels=False,
)


@magicgui(
call_button='Run Predictions',
mode={
Expand All @@ -55,25 +90,6 @@ def to_choices(cls):
'choices': UNetPredictionsMode.to_choices(),
},
image={'label': 'Image', 'tooltip': 'Raw image to be processed with a neural network.'},
dimensionality={
'label': 'Dimensionality',
'tooltip': 'Dimensionality of the model (2D or 3D). '
'Any 2D model can be used for 3D data. If unsure, select "All".',
'widget_type': 'ComboBox',
'choices': [ALL] + model_zoo.get_unique_dimensionalities(),
},
modality={
'label': 'Microscopy modality',
'tooltip': 'Modality of the model (e.g. confocal, light-sheet ...). If unsure, select "All".',
'widget_type': 'ComboBox',
'choices': [ALL] + model_zoo.get_unique_modalities(),
},
output_type={
'label': 'Prediction type',
'widget_type': 'ComboBox',
'tooltip': 'Type of prediction (e.g. cell boundaries predictions or nuclei...).' ' If unsure, select "All".',
'choices': [ALL] + model_zoo.get_unique_output_types(),
},
model_name={
'label': 'PlantSeg model',
'tooltip': f'Select a pretrained PlantSeg model. '
Expand Down Expand Up @@ -118,9 +134,6 @@ def widget_unet_predictions(
plantseg_filter: bool = True,
model_name: Optional[str] = None,
model_id: Optional[str] = None,
dimensionality: str = ALL,
modality: str = ALL,
output_type: str = ALL,
device: str = ALL_DEVICES[0],
advanced: bool = False,
patch_size: tuple[int, int, int] = (128, 128, 128),
Expand Down Expand Up @@ -157,6 +170,8 @@ def widget_unet_predictions(
)


widget_unet_predictions.insert(5, model_filters)

advanced_unet_predictions_widgets = [
widget_unet_predictions.patch_size,
widget_unet_predictions.patch_halo,
Expand Down Expand Up @@ -211,13 +226,11 @@ def _on_widget_unet_predictions_advanced_changed(advanced):

@widget_unet_predictions.mode.changed.connect
def _on_widget_unet_predictions_mode_change(mode: UNetPredictionsMode):
widgets_p = [
widgets_p = [ # PlantSeg
widget_unet_predictions.model_name,
widget_unet_predictions.dimensionality,
widget_unet_predictions.modality,
widget_unet_predictions.output_type,
model_filters,
]
widgets_b = [
widgets_b = [ # BioImage.IO
widget_unet_predictions.model_id,
widget_unet_predictions.plantseg_filter,
]
Expand Down Expand Up @@ -256,34 +269,22 @@ def _on_widget_unet_predictions_plantseg_filter_change(plantseg_filter: bool):
# _on_prediction_input_image_change(widget_unet_predictions, image)


def _on_any_metadata_changed(modality, output_type, dimensionality):
modality = [modality] if modality != ALL else None
output_type = [output_type] if output_type != ALL else None
dimensionality = [dimensionality] if dimensionality != ALL else None
@model_filters.changed.connect
def _on_any_metadata_changed(widget):
modality = widget.modality.value
output_type = widget.output_type.value
dimensionality = widget.dimensionality.value

modality = [modality] if modality != ALL_MOD else None
output_type = [output_type] if output_type != ALL_TYP else None
dimensionality = [dimensionality] if dimensionality != ALL_DIM else None
widget_unet_predictions.model_name.choices = model_zoo.list_models(
modality_filter=modality,
output_type_filter=output_type,
dimensionality_filter=dimensionality,
)


widget_unet_predictions.modality.changed.connect(
lambda value: _on_any_metadata_changed(
value, widget_unet_predictions.output_type.value, widget_unet_predictions.dimensionality.value
)
)
widget_unet_predictions.output_type.changed.connect(
lambda value: _on_any_metadata_changed(
widget_unet_predictions.modality.value, value, widget_unet_predictions.dimensionality.value
)
)
widget_unet_predictions.dimensionality.changed.connect(
lambda value: _on_any_metadata_changed(
widget_unet_predictions.modality.value, widget_unet_predictions.output_type.value, value
)
)


@widget_unet_predictions.model_name.changed.connect
def _on_model_name_changed(model_name: str):
description = model_zoo.get_model_description(model_name)
Expand Down

0 comments on commit f3e1667

Please sign in to comment.