diff --git a/_version.py b/_version.py deleted file mode 100644 index ec48c6c..0000000 --- a/_version.py +++ /dev/null @@ -1,2 +0,0 @@ -# Versioning bound to TARDIS-em versioning -version = "0.0.1" diff --git a/napari.yaml b/napari.yaml deleted file mode 100644 index 959f72e..0000000 --- a/napari.yaml +++ /dev/null @@ -1,97 +0,0 @@ -####################################################################### -# TARDIS - Transformer And Rapid Dimensionless Instance Segmentation # -# # -# New York Structural Biology Center # -# Simons Machine Learning Center # -# # -# Robert Kiewisz, Tristan Bepler # -# MIT License 2024 # -####################################################################### -name: napari-tardis-em -display_name: TARDIS-em napari plugin -contributions: - commands: - # Read/Write files Images -# - id: napari-tardis_em.import_data -# python_name: napari-tardis_em.utils.import_data:load_images -# title: Open micrograph|tomogram -# -# - id: napari-tardis_em.export_image -# python_name: napari-tardis_em.utils.export_data:export_images -# title: Save micrograph|tomogram -# -# - id: napari-tardis_em.export_coord -# python_name: napari-tardis_em.utils.export_coord:export_coord -# title: Save instance segmentation - - # Widgets - # General - - id: napari-tardis-em.viewer_train - python_name: napari_tardis_em.viewers.viewer_train:TardisWidget - title: Train TARDIS CNN - - - id: napari-tardis-em.viewer_predict - python_name: napari_tardis_em.viewers.viewer_predict:TardisWidget - title: Predict TARDIS CNN - - # Microtubules - - id: napari-tardis-em.viewer_mt_3d - python_name: napari_tardis_em.viewers.viewer_mt_3d:TardisWidget - title: Predict Microtubules 3D - -# - id: napari-tardis-em.viewer_mt_2d -# python_name: napari_tardis_em.viewers.viewer_mt_3d:TardisWidget -# title: Predict Microtubules 2D - - # Membrane - - id: napari-tardis-em.viewer_mem_3d - python_name: napari_tardis_em.viewers.viewer_mem_3d:TardisWidget - title: Predict Membrane 3D - - - id: 
napari-tardis-em.viewer_mem_2d - python_name: napari_tardis_em.viewers.viewer_mt_2d:TardisWidget - title: Predict Microtubules 2D - -# # Actin - - id: napari-tardis-em.viewer_actin_3d - python_name: napari_tardis_em.viewers.viewer_actin_3d:TardisWidget - title: Predict Actin 3D - -# readers: -# - command: napari_tardis_em.import_data -# accepts_directories: False -# filename_patterns: ['*.mrc', '*.rec', '*.tiff', '*.tif', '*.am'] - -# writers: -# - command: napari_tardis_em.export_image -# layer_types: ['image*'] -# filename_extensions: ['*.mrc', '*.rec', '*.tiff', '*.tif', '*.am'] -# - command: napari_tardis_em.export_coord -# layer_types: ['points*'] -# filename_extensions: ['*.csv', '*.npy', '*.am'] - - widgets: - # General - - command: napari-tardis-em.viewer_train - display_name: Train TARDIS CNN - - - command: napari-tardis-em.viewer_predict - display_name: Predict TARDIS CNN - - # Microtubules - - command: napari-tardis-em.viewer_mt_3d - display_name: Predict Microtubules 3D - -# - command: napari-tardis-em.viewer_mt_2d -# display_name: Predict Microtubules 2D - - # Membrane - - command: napari-tardis-em.viewer_mem_3d - display_name: Predict Membrane 3D - - - command: napari-tardis-em.viewer_mem_2d - display_name: Predict Membrane 2D - -# # Actin - - command: napari-tardis-em.viewer_actin_3d - display_name: Predict Actin 3D \ No newline at end of file diff --git a/src/napari_tardis_em/napari.yaml b/src/napari_tardis_em/napari.yaml index 959f72e..e219aa0 100644 --- a/src/napari_tardis_em/napari.yaml +++ b/src/napari_tardis_em/napari.yaml @@ -40,7 +40,7 @@ contributions: title: Predict Microtubules 3D # - id: napari-tardis-em.viewer_mt_2d -# python_name: napari_tardis_em.viewers.viewer_mt_3d:TardisWidget +# python_name: napari_tardis_em.viewers.viewer_mt_2d:TardisWidget # title: Predict Microtubules 2D # Membrane @@ -49,7 +49,7 @@ contributions: title: Predict Membrane 3D - id: napari-tardis-em.viewer_mem_2d - python_name: 
napari_tardis_em.viewers.viewer_mt_2d:TardisWidget + python_name: napari_tardis_em.viewers.viewer_mem_2d:TardisWidget title: Predict Microtubules 2D # # Actin diff --git a/src/napari_tardis_em/viewers/utils.py b/src/napari_tardis_em/viewers/utils.py index 9567d82..214b64c 100644 --- a/src/napari_tardis_em/viewers/utils.py +++ b/src/napari_tardis_em/viewers/utils.py @@ -73,9 +73,12 @@ def update_viewer_prediction(viewer, image: np.ndarray, position: dict): position["x"][1] = position["x"][0] + diff[1] img.data[ - position["y"][0]: position["y"][1], - position["x"][0]: position["x"][1], - ] = image[: diff[0], : diff[1],] + position["y"][0] : position["y"][1], + position["x"][0] : position["x"][1], + ] = image[ + : diff[0], + : diff[1], + ] viewer.layers["Prediction"].visible = False viewer.layers["Prediction"].visible = True @@ -101,7 +104,7 @@ def create_point_layer( pass point_features = { - "confidence": tuple(points[:, 0].flatten()*np.random.randint(100)), + "confidence": tuple(points[:, 0].flatten() * np.random.randint(100)), } points = np.array(points[:, 1:]) @@ -111,11 +114,7 @@ def create_point_layer( points = np.hstack((points, z)) # Convert xyz to zyx - points = np.vstack(( - points[:, 2], - points[:, 1], - points[:, 0] - )).T + points = np.vstack((points[:, 2], points[:, 1], points[:, 0])).T viewer.layers.select_all() viewer.layers.toggle_selected_visibility() @@ -198,7 +197,9 @@ def create_image_layer( viewer.layers[name].visible = False -def setup_environment_and_dataset(dir_, mask_size, pixel_size, patch_size, correct_pixel_size=None): +def setup_environment_and_dataset( + dir_, mask_size, pixel_size, patch_size, correct_pixel_size=None +): """Set environment""" TRAIN_IMAGE_DIR = join(dir_, "train", "imgs") TRAIN_MASK_DIR = join(dir_, "train", "masks") @@ -213,7 +214,13 @@ def setup_environment_and_dataset(dir_, mask_size, pixel_size, patch_size, corre img_format=IMG_FORMAT, test_img=TEST_IMAGE_DIR, test_mask=TEST_MASK_DIR, - mask_format=("_mask.am", 
".CorrelationLines.am", "_mask.mrc", "_mask.tif", "_mask.csv"), + mask_format=( + "_mask.am", + ".CorrelationLines.am", + "_mask.mrc", + "_mask.tif", + "_mask.csv", + ), ) """Optionally: Set-up environment if not existing""" diff --git a/src/napari_tardis_em/viewers/viewer_actin_3d.py b/src/napari_tardis_em/viewers/viewer_actin_3d.py index 0b48c34..57c3255 100644 --- a/src/napari_tardis_em/viewers/viewer_actin_3d.py +++ b/src/napari_tardis_em/viewers/viewer_actin_3d.py @@ -39,6 +39,7 @@ from tardis_em.utils.normalization import adaptive_threshold from tardis_em.utils.predictor import GeneralPredictor from tardis_em.utils.setup_envir import clean_up +from tardis_em.utils.spline_metric import sort_by_length from napari_tardis_em.viewers.styles import border_style from napari_tardis_em.utils.utils import get_list_of_device @@ -133,8 +134,12 @@ def __init__(self, viewer_actin_3d: Viewer): self.cnn_type.setToolTip("Select type of CNN you would like to train.") self.cnn_type.currentIndexChanged.connect(self.update_versions) - self.checkpoint = QLineEdit("None") - self.checkpoint.setToolTip("Optional, directory to CNN checkpoint.") + self.checkpoint = QPushButton("None") + self.checkpoint.setToolTip( + "Optional, directory to CNN checkpoint to restart training." + ) + self.checkpoint.clicked.connect(self.update_checkpoint_dir) + self.checkpoint_dir = None self.patch_size = QComboBox() self.patch_size.addItems( @@ -184,6 +189,7 @@ def __init__(self, viewer_actin_3d: Viewer): "false/positives. Higher value will result in cleaner output but may \n" "reduce recall." ) + self.dist_threshold.valueChanged.connect(self.update_dist_graph) self.device = QComboBox() self.device.addItems(get_list_of_device()) @@ -240,6 +246,7 @@ def __init__(self, viewer_actin_3d: Viewer): "length in angstrom. All filaments shorter then this length \n" "will be deleted." 
) + self.filter_by_length.textChanged.connect(self.update_dist_graph) self.connect_splines = QLineEdit("2500") self.connect_splines.setValidator(QIntValidator(0, 10000)) @@ -252,6 +259,7 @@ def __init__(self, viewer_actin_3d: Viewer): "determines how far apart two actin can be, while still being considered \n" "as a single unit if they are oriented in the same direction." ) + self.connect_splines.textChanged.connect(self.update_dist_graph) self.connect_cylinder = QLineEdit("250") self.connect_cylinder.setValidator(QIntValidator(0, 10000)) @@ -263,6 +271,7 @@ def __init__(self, viewer_actin_3d: Viewer): "The ends of these filaments must be located within this cylinder \n" "to be considered connected." ) + self.connect_cylinder.textChanged.connect(self.update_dist_graph) """""" """""" """ UI Setup @@ -302,6 +311,117 @@ def __init__(self, viewer_actin_3d: Viewer): self.setLayout(layout) + def update_checkpoint_dir(self): + filename, _ = QFileDialog.getOpenFileName( + caption="Open File", + directory=getcwd(), + ) + self.checkpoint.setText(filename[-30:]) + self.checkpoint_dir = filename + + def update_versions(self): + for i in range(self.model_version.count()): + self.model_version.removeItem(0) + + versions = get_all_version_aws(self.cnn_type.currentText(), "32", "actin_3d") + + if len(versions) == 0: + self.model_version.addItems(["None"]) + else: + self.model_version.addItems(["None"] + [i.split("_")[-1] for i in versions]) + + def update_cnn_threshold(self): + if self.img is not None: + self.viewer.layers[self.dir.split("/")[-1]].visible = True + + if float(self.cnn_threshold.text()) == 1.0: + self.img_threshold = adaptive_threshold(self.img).astype(np.uint8) + elif float(self.cnn_threshold.text()) == 0.0: + self.img_threshold = np.copy(self.img) + else: + self.img_threshold = np.where( + self.img >= float(self.cnn_threshold.text()), 1, 0 + ).astype(np.uint8) + + create_image_layer( + self.viewer, + image=self.img_threshold, + name="Prediction", + 
transparency=True, + range_=(0, 1), + ) + + self.predictor.image = self.img_threshold + self.predictor.save_semantic_mask(self.dir.split("/")[-1]) + + def update_dist_layer(self): + if self.predictor.segments is not None: + create_point_layer( + viewer=self.viewer, + points=self.predictor.segments, + name="Predicted_Instances", + visibility=True, + ) + else: + return + + def update_dist_graph(self): + if self.predictor is not None: + if self.predictor.graphs is not None: + if bool(self.filament.checkState()): + sort = True + prune = 5 + else: + sort = False + prune = 15 + + try: + self.predictor.segments = ( + self.predictor.GraphToSegment.patch_to_segment( + graph=self.predictor.graphs, + coord=self.predictor.pc_ld, + idx=self.predictor.output_idx, + sort=sort, + prune=prune, + ) + ) + self.predictor.segments = sort_by_length(self.predictor.segments) + except: + self.predictor.segments = None + + if self.predictor.segments is None: + show_info("TARDIS-em could not find any instances :(") + return + else: + show_info( + f"TARDIS-em found {int(np.max(self.predictor.segments[:, 0]))} instances :)" + ) + self.predictor.save_instance_PC(self.dir.split("/")[-1]) + + def calculate_position(self, name): + patch_size = int(self.patch_size.currentText()) + name = name.split("_") + name = { + "z": int(name[1]), + "y": int(name[2]), + "x": int(name[3]), + "stride": int(name[4]), + } + + x_start = (name["x"] * patch_size) - (name["x"] * name["stride"]) + x_end = x_start + patch_size + name["x"] = [x_start, x_end] + + y_start = (name["y"] * patch_size) - (name["y"] * name["stride"]) + y_end = y_start + patch_size + name["y"] = [y_start, y_end] + + z_start = (name["z"] * patch_size) - (name["z"] * name["stride"]) + z_end = z_start + patch_size + name["z"] = [z_start, z_end] + + return name + def load_directory(self): filename, _ = QFileDialog.getOpenFileName( caption="Open File", @@ -387,7 +507,7 @@ def predict_semantic(self): correct_px=correct_px, 
convolution_nn=self.cnn_type.currentText(), checkpoint=[ - None if self.checkpoint.text() == "None" else self.checkpoint.text(), + None if self.checkpoint.text() == "None" else self.checkpoint_dir, None, ], model_version=model_version, @@ -486,17 +606,6 @@ def predict_dataset(img_dataset_, predictor): else: return - def update_dist_layer(self): - if self.predictor.segments is not None: - create_point_layer( - viewer=self.viewer, - points=self.predictor.segments, - name="Predicted_Instances", - visibility=True, - ) - else: - return - def predict_instance(self): if self.predictor is None: show_error(f"Please initialize with 'Predict Semantic' button") @@ -579,7 +688,7 @@ def show_command(self): ch = ( "" if self.checkpoint.text() == "None" - else f"-ch {self.checkpoint.text()}_None " + else f"-ch {self.checkpoint_dir}_None " ) mv = ( @@ -643,62 +752,3 @@ def show_command(self): f"-pv {int(self.points_in_patch.text())} " f"-dv {self.device.currentText()}" ) - - def update_cnn_threshold(self): - if self.img is not None: - self.viewer.layers[self.dir.split("/")[-1]].visible = True - - if float(self.cnn_threshold.text()) == 1.0: - self.img_threshold = adaptive_threshold(self.img).astype(np.uint8) - elif float(self.cnn_threshold.text()) == 0.0: - self.img_threshold = np.copy(self.img) - else: - self.img_threshold = np.where( - self.img >= float(self.cnn_threshold.text()), 1, 0 - ).astype(np.uint8) - - create_image_layer( - self.viewer, - image=self.img_threshold, - name="Prediction", - transparency=True, - range_=(0, 1), - ) - - self.predictor.image = self.img_threshold - self.predictor.save_semantic_mask(self.dir.split("/")[-1]) - - def update_versions(self): - for i in range(self.model_version.count()): - self.model_version.removeItem(0) - - versions = get_all_version_aws(self.cnn_type.currentText(), "32", "actin_3d") - - if len(versions) == 0: - self.model_version.addItems(["None"]) - else: - self.model_version.addItems(["None"] + [i.split("_")[-1] for i in versions]) 
- - def calculate_position(self, name): - patch_size = int(self.patch_size.currentText()) - name = name.split("_") - name = { - "z": int(name[1]), - "y": int(name[2]), - "x": int(name[3]), - "stride": int(name[4]), - } - - x_start = (name["x"] * patch_size) - (name["x"] * name["stride"]) - x_end = x_start + patch_size - name["x"] = [x_start, x_end] - - y_start = (name["y"] * patch_size) - (name["y"] * name["stride"]) - y_end = y_start + patch_size - name["y"] = [y_start, y_end] - - z_start = (name["z"] * patch_size) - (name["z"] * name["stride"]) - z_end = z_start + patch_size - name["z"] = [z_start, z_end] - - return name diff --git a/src/napari_tardis_em/viewers/viewer_mem_2d.py b/src/napari_tardis_em/viewers/viewer_mem_2d.py index 5a4733c..516ba3c 100644 --- a/src/napari_tardis_em/viewers/viewer_mem_2d.py +++ b/src/napari_tardis_em/viewers/viewer_mem_2d.py @@ -39,6 +39,7 @@ from tardis_em.utils.normalization import adaptive_threshold from tardis_em.utils.predictor import GeneralPredictor from tardis_em.utils.setup_envir import clean_up +from tardis_em.utils.spline_metric import sort_by_length from napari_tardis_em.viewers.styles import border_style from napari_tardis_em.utils.utils import get_list_of_device @@ -133,8 +134,12 @@ def __init__(self, viewer_mem_2d: Viewer): self.cnn_type.setToolTip("Select type of CNN you would like to train.") self.cnn_type.currentIndexChanged.connect(self.update_versions) - self.checkpoint = QLineEdit("None") - self.checkpoint.setToolTip("Optional, directory to CNN checkpoint.") + self.checkpoint = QPushButton("None") + self.checkpoint.setToolTip( + "Optional, directory to CNN checkpoint to restart training." + ) + self.checkpoint.clicked.connect(self.update_checkpoint_dir) + self.checkpoint_dir = None self.patch_size = QComboBox() self.patch_size.addItems( @@ -184,6 +189,7 @@ def __init__(self, viewer_mem_2d: Viewer): "false/positives. Higher value will result in cleaner output but may \n" "reduce recall." 
) + self.dist_threshold.valueChanged.connect(self.update_dist_graph) self.device = QComboBox() self.device.addItems(get_list_of_device()) @@ -244,6 +250,7 @@ def __init__(self, viewer_mem_2d: Viewer): "determines how far apart two membranes can be, while still being considered \n" "as a single unit if they are oriented in the same direction." ) + self.connect_membranes.textChanged.connect(self.update_dist_graph) self.connect_cylinder = QLineEdit("250") self.connect_cylinder.setValidator(QIntValidator(0, 10000)) @@ -255,6 +262,7 @@ def __init__(self, viewer_mem_2d: Viewer): "The ends of these filaments must be located within this cylinder \n" "to be considered connected." ) + self.connect_cylinder.textChanged.connect(self.update_dist_graph) """""" """""" """ UI Setup @@ -292,6 +300,119 @@ def __init__(self, viewer_mem_2d: Viewer): self.setLayout(layout) + def update_checkpoint_dir(self): + filename, _ = QFileDialog.getOpenFileName( + caption="Open File", + directory=getcwd(), + ) + self.checkpoint.setText(filename[-30:]) + self.checkpoint_dir = filename + + def update_versions(self): + for i in range(self.model_version.count()): + self.model_version.removeItem(0) + + versions = get_all_version_aws(self.cnn_type.currentText(), "32", "membrane_2d") + + if len(versions) == 0: + self.model_version.addItems(["None"]) + else: + self.model_version.addItems(["None"] + [i.split("_")[-1] for i in versions]) + + def update_cnn_threshold(self): + if self.img is not None: + self.viewer.layers[self.dir.split("/")[-1]].visible = True + + if float(self.cnn_threshold.text()) == 1.0: + self.img_threshold = adaptive_threshold(self.img).astype(np.uint8) + elif float(self.cnn_threshold.text()) == 0.0: + self.img_threshold = np.copy(self.img) + else: + self.img_threshold = np.where( + self.img >= float(self.cnn_threshold.text()), 1, 0 + ).astype(np.uint8) + + create_image_layer( + self.viewer, + image=self.img_threshold, + name="Prediction", + transparency=True, + range_=(0, 1), + ) + + 
self.predictor.image = self.img_threshold + self.predictor.save_semantic_mask(self.dir.split("/")[-1]) + + def update_dist_layer(self): + self.predictor.image = self.img_threshold + + if self.predictor.segments is not None: + create_point_layer( + viewer=self.viewer, + points=self.predictor.segments, + name="Predicted_Instances", + visibility=True, + ) + else: + return + + def update_dist_graph(self): + if self.predictor is not None: + if self.predictor.graphs is not None: + if bool(self.filament.checkState()): + sort = True + prune = 5 + else: + sort = False + prune = 15 + + try: + self.predictor.segments = ( + self.predictor.GraphToSegment.patch_to_segment( + graph=self.predictor.graphs, + coord=self.predictor.pc_ld, + idx=self.predictor.output_idx, + sort=sort, + prune=prune, + ) + ) + self.predictor.segments = sort_by_length(self.predictor.segments) + except: + self.predictor.segments = None + + if self.predictor.segments is None: + show_info("TARDIS-em could not find any instances :(") + return + else: + show_info( + f"TARDIS-em found {int(np.max(self.predictor.segments[:, 0]))} instances :)" + ) + self.predictor.save_instance_PC(self.dir.split("/")[-1]) + + def calculate_position(self, name): + patch_size = int(self.patch_size.currentText()) + name = name.split("_") + name = { + "z": int(name[1]), + "y": int(name[2]), + "x": int(name[3]), + "stride": int(name[4]), + } + + x_start = (name["x"] * patch_size) - (name["x"] * name["stride"]) + x_end = x_start + patch_size + name["x"] = [x_start, x_end] + + y_start = (name["y"] * patch_size) - (name["y"] * name["stride"]) + y_end = y_start + patch_size + name["y"] = [y_start, y_end] + + z_start = (name["z"] * patch_size) - (name["z"] * name["stride"]) + z_end = z_start + patch_size + name["z"] = [z_start, z_end] + + return name + def load_directory(self): filename, _ = QFileDialog.getOpenFileName( caption="Open File", @@ -377,7 +498,7 @@ def predict_semantic(self): correct_px=correct_px, 
convolution_nn=self.cnn_type.currentText(), checkpoint=[ - None if self.checkpoint.text() == "None" else self.checkpoint.text(), + None if self.checkpoint.text() == "None" else self.checkpoint_dir, None, ], model_version=model_version, @@ -473,19 +594,6 @@ def predict_dataset(img_dataset_, predictor): else: return - def update_dist_layer(self): - self.predictor.image = self.img_threshold - - if self.predictor.segments is not None: - create_point_layer( - viewer=self.viewer, - points=self.predictor.segments, - name="Predicted_Instances", - visibility=True, - ) - else: - return - def predict_instance(self): if self.predictor is None: show_error(f"Please initialize with 'Predict Semantic' button") @@ -568,7 +676,7 @@ def show_command(self): ch = ( "" if self.checkpoint.text() == "None" - else f"-ch {self.checkpoint.text()}_None " + else f"-ch {self.checkpoint_dir}_None " ) mv = ( @@ -626,62 +734,3 @@ def show_command(self): f"-pv {int(self.points_in_patch.text())} " f"-dv {self.device.currentText()}" ) - - def update_cnn_threshold(self): - if self.img is not None: - self.viewer.layers[self.dir.split("/")[-1]].visible = True - - if float(self.cnn_threshold.text()) == 1.0: - self.img_threshold = adaptive_threshold(self.img).astype(np.uint8) - elif float(self.cnn_threshold.text()) == 0.0: - self.img_threshold = np.copy(self.img) - else: - self.img_threshold = np.where( - self.img >= float(self.cnn_threshold.text()), 1, 0 - ).astype(np.uint8) - - create_image_layer( - self.viewer, - image=self.img_threshold, - name="Prediction", - transparency=True, - range_=(0, 1), - ) - - self.predictor.image = self.img_threshold - self.predictor.save_semantic_mask(self.dir.split("/")[-1]) - - def update_versions(self): - for i in range(self.model_version.count()): - self.model_version.removeItem(0) - - versions = get_all_version_aws(self.cnn_type.currentText(), "32", "membrane_2d") - - if len(versions) == 0: - self.model_version.addItems(["None"]) - else: - 
self.model_version.addItems(["None"] + [i.split("_")[-1] for i in versions]) - - def calculate_position(self, name): - patch_size = int(self.patch_size.currentText()) - name = name.split("_") - name = { - "z": int(name[1]), - "y": int(name[2]), - "x": int(name[3]), - "stride": int(name[4]), - } - - x_start = (name["x"] * patch_size) - (name["x"] * name["stride"]) - x_end = x_start + patch_size - name["x"] = [x_start, x_end] - - y_start = (name["y"] * patch_size) - (name["y"] * name["stride"]) - y_end = y_start + patch_size - name["y"] = [y_start, y_end] - - z_start = (name["z"] * patch_size) - (name["z"] * name["stride"]) - z_end = z_start + patch_size - name["z"] = [z_start, z_end] - - return name diff --git a/src/napari_tardis_em/viewers/viewer_mem_3d.py b/src/napari_tardis_em/viewers/viewer_mem_3d.py index aba286c..52e9521 100644 --- a/src/napari_tardis_em/viewers/viewer_mem_3d.py +++ b/src/napari_tardis_em/viewers/viewer_mem_3d.py @@ -39,6 +39,7 @@ from tardis_em.utils.normalization import adaptive_threshold from tardis_em.utils.predictor import GeneralPredictor from tardis_em.utils.setup_envir import clean_up +from tardis_em.utils.spline_metric import sort_by_length from napari_tardis_em.viewers.styles import border_style from napari_tardis_em.utils.utils import get_list_of_device @@ -133,8 +134,12 @@ def __init__(self, viewer_mem_3d: Viewer): self.cnn_type.setToolTip("Select type of CNN you would like to train.") self.cnn_type.currentIndexChanged.connect(self.update_versions) - self.checkpoint = QLineEdit("None") - self.checkpoint.setToolTip("Optional, directory to CNN checkpoint.") + self.checkpoint = QPushButton("None") + self.checkpoint.setToolTip( + "Optional, directory to CNN checkpoint to restart training." + ) + self.checkpoint.clicked.connect(self.update_checkpoint_dir) + self.checkpoint_dir = None self.patch_size = QComboBox() self.patch_size.addItems( @@ -184,6 +189,7 @@ def __init__(self, viewer_mem_3d: Viewer): "false/positives. 
Higher value will result in cleaner output but may \n" "reduce recall." ) + self.dist_threshold.valueChanged.connect(self.update_dist_graph) self.device = QComboBox() self.device.addItems(get_list_of_device()) @@ -259,6 +265,117 @@ def __init__(self, viewer_mem_3d: Viewer): self.setLayout(layout) + def update_checkpoint_dir(self): + filename, _ = QFileDialog.getOpenFileName( + caption="Open File", + directory=getcwd(), + ) + self.checkpoint.setText(filename[-30:]) + self.checkpoint_dir = filename + + def update_versions(self): + for i in range(self.model_version.count()): + self.model_version.removeItem(0) + + versions = get_all_version_aws(self.cnn_type.currentText(), "32", "membrane_3d") + + if len(versions) == 0: + self.model_version.addItems(["None"]) + else: + self.model_version.addItems(["None"] + [i.split("_")[-1] for i in versions]) + + def update_cnn_threshold(self): + if self.img is not None: + self.viewer.layers[self.dir.split("/")[-1]].visible = True + + if float(self.cnn_threshold.text()) == 1.0: + self.img_threshold = adaptive_threshold(self.img).astype(np.uint8) + elif float(self.cnn_threshold.text()) == 0.0: + self.img_threshold = np.copy(self.img) + else: + self.img_threshold = np.where( + self.img >= float(self.cnn_threshold.text()), 1, 0 + ).astype(np.uint8) + + create_image_layer( + self.viewer, + image=self.img_threshold, + name="Prediction", + transparency=True, + range_=(0, 1), + ) + + self.predictor.image = self.img_threshold + self.predictor.save_semantic_mask(self.dir.split("/")[-1]) + + def update_dist_layer(self): + if self.predictor.segments is not None: + create_point_layer( + viewer=self.viewer, + points=self.predictor.segments, + name="Predicted_Instances", + visibility=True, + ) + else: + return + + def update_dist_graph(self): + if self.predictor is not None: + if self.predictor.graphs is not None: + if bool(self.filament.checkState()): + sort = True + prune = 5 + else: + sort = False + prune = 15 + + try: + self.predictor.segments 
= ( + self.predictor.GraphToSegment.patch_to_segment( + graph=self.predictor.graphs, + coord=self.predictor.pc_ld, + idx=self.predictor.output_idx, + sort=sort, + prune=prune, + ) + ) + self.predictor.segments = sort_by_length(self.predictor.segments) + except: + self.predictor.segments = None + + if self.predictor.segments is None: + show_info("TARDIS-em could not find any instances :(") + return + else: + show_info( + f"TARDIS-em found {int(np.max(self.predictor.segments[:, 0]))} instances :)" + ) + self.predictor.save_instance_PC(self.dir.split("/")[-1]) + + def calculate_position(self, name): + patch_size = int(self.patch_size.currentText()) + name = name.split("_") + name = { + "z": int(name[1]), + "y": int(name[2]), + "x": int(name[3]), + "stride": int(name[4]), + } + + x_start = (name["x"] * patch_size) - (name["x"] * name["stride"]) + x_end = x_start + patch_size + name["x"] = [x_start, x_end] + + y_start = (name["y"] * patch_size) - (name["y"] * name["stride"]) + y_end = y_start + patch_size + name["y"] = [y_start, y_end] + + z_start = (name["z"] * patch_size) - (name["z"] * name["stride"]) + z_end = z_start + patch_size + name["z"] = [z_start, z_end] + + return name + def load_directory(self): filename, _ = QFileDialog.getOpenFileName( caption="Open File", @@ -344,7 +461,7 @@ def predict_semantic(self): correct_px=correct_px, convolution_nn=self.cnn_type.currentText(), checkpoint=[ - None if self.checkpoint.text() == "None" else self.checkpoint.text(), + None if self.checkpoint.text() == "None" else self.checkpoint_dir, None, ], model_version=model_version, @@ -437,17 +554,6 @@ def predict_dataset(img_dataset_, predictor): else: return - def update_dist_layer(self): - if self.predictor.segments is not None: - create_point_layer( - viewer=self.viewer, - points=self.predictor.segments, - name="Predicted_Instances", - visibility=True, - ) - else: - return - def predict_instance(self): if self.predictor is None: show_error(f"Please initialize with 'Predict 
Semantic' button") @@ -530,7 +636,7 @@ def show_command(self): ch = ( "" if self.checkpoint.text() == "None" - else f"-ch {self.checkpoint.text()}_None " + else f"-ch {self.checkpoint_dir}_None " ) mv = ( @@ -575,62 +681,3 @@ def show_command(self): f"-pv {int(self.points_in_patch.text())} " f"-dv {self.device.currentText()}" ) - - def update_cnn_threshold(self): - if self.img is not None: - self.viewer.layers[self.dir.split("/")[-1]].visible = True - - if float(self.cnn_threshold.text()) == 1.0: - self.img_threshold = adaptive_threshold(self.img).astype(np.uint8) - elif float(self.cnn_threshold.text()) == 0.0: - self.img_threshold = np.copy(self.img) - else: - self.img_threshold = np.where( - self.img >= float(self.cnn_threshold.text()), 1, 0 - ).astype(np.uint8) - - create_image_layer( - self.viewer, - image=self.img_threshold, - name="Prediction", - transparency=True, - range_=(0, 1), - ) - - self.predictor.image = self.img_threshold - self.predictor.save_semantic_mask(self.dir.split("/")[-1]) - - def update_versions(self): - for i in range(self.model_version.count()): - self.model_version.removeItem(0) - - versions = get_all_version_aws(self.cnn_type.currentText(), "32", "membrane_3d") - - if len(versions) == 0: - self.model_version.addItems(["None"]) - else: - self.model_version.addItems(["None"] + [i.split("_")[-1] for i in versions]) - - def calculate_position(self, name): - patch_size = int(self.patch_size.currentText()) - name = name.split("_") - name = { - "z": int(name[1]), - "y": int(name[2]), - "x": int(name[3]), - "stride": int(name[4]), - } - - x_start = (name["x"] * patch_size) - (name["x"] * name["stride"]) - x_end = x_start + patch_size - name["x"] = [x_start, x_end] - - y_start = (name["y"] * patch_size) - (name["y"] * name["stride"]) - y_end = y_start + patch_size - name["y"] = [y_start, y_end] - - z_start = (name["z"] * patch_size) - (name["z"] * name["stride"]) - z_end = z_start + patch_size - name["z"] = [z_start, z_end] - - return name diff 
--git a/src/napari_tardis_em/viewers/viewer_mt_3d.py b/src/napari_tardis_em/viewers/viewer_mt_3d.py index 20f1136..7517823 100644 --- a/src/napari_tardis_em/viewers/viewer_mt_3d.py +++ b/src/napari_tardis_em/viewers/viewer_mt_3d.py @@ -39,6 +39,7 @@ from tardis_em.utils.normalization import adaptive_threshold from tardis_em.utils.predictor import GeneralPredictor from tardis_em.utils.setup_envir import clean_up +from tardis_em.utils.spline_metric import sort_by_length from napari_tardis_em.viewers.styles import border_style from napari_tardis_em.utils.utils import get_list_of_device @@ -133,8 +134,12 @@ def __init__(self, viewer_mt_3d: Viewer): self.cnn_type.setToolTip("Select type of CNN you would like to train.") self.cnn_type.currentIndexChanged.connect(self.update_versions) - self.checkpoint = QLineEdit("None") - self.checkpoint.setToolTip("Optional, directory to CNN checkpoint.") + self.checkpoint = QPushButton("None") + self.checkpoint.setToolTip( + "Optional, directory to CNN checkpoint to restart training." + ) + self.checkpoint.clicked.connect(self.update_checkpoint_dir) + self.checkpoint_dir = None self.patch_size = QComboBox() self.patch_size.addItems( @@ -184,6 +189,7 @@ def __init__(self, viewer_mt_3d: Viewer): "false/positives. Higher value will result in cleaner output but may \n" "reduce recall." ) + self.dist_threshold.valueChanged.connect(self.update_dist_graph) self.device = QComboBox() self.device.addItems(get_list_of_device()) @@ -268,6 +274,8 @@ def __init__(self, viewer_mt_3d: Viewer): "length in angstrom. All filaments shorter then this length \n" "will be deleted." 
) + self.filter_by_length.textChanged.connect(self.update_dist_graph) + self.connect_splines = QLineEdit("2500") self.connect_splines.setValidator(QIntValidator(0, 10000)) self.connect_splines.setToolTip( @@ -279,6 +287,7 @@ def __init__(self, viewer_mt_3d: Viewer): "determines how far apart two microtubules can be, while still being considered \n" "as a single unit if they are oriented in the same direction." ) + self.connect_splines.textChanged.connect(self.update_dist_graph) self.connect_cylinder = QLineEdit("250") self.connect_cylinder.setValidator(QIntValidator(0, 10000)) @@ -290,6 +299,7 @@ def __init__(self, viewer_mt_3d: Viewer): "The ends of these filaments must be located within this cylinder \n" "to be considered connected." ) + self.connect_cylinder.textChanged.connect(self.update_dist_graph) """""" """""" """ UI Setup @@ -332,6 +342,129 @@ def __init__(self, viewer_mt_3d: Viewer): self.setLayout(layout) + def update_checkpoint_dir(self): + filename, _ = QFileDialog.getOpenFileName( + caption="Open File", + directory=getcwd(), + ) + self.checkpoint.setText(filename[-30:]) + self.checkpoint_dir = filename + + def update_cnn_threshold(self): + if self.img is not None: + self.viewer.layers[self.dir.split("/")[-1]].visible = True + + if float(self.cnn_threshold.text()) == 1.0: + self.img_threshold = adaptive_threshold(self.img).astype(np.uint8) + elif float(self.cnn_threshold.text()) == 0.0: + self.img_threshold = np.copy(self.img) + else: + self.img_threshold = np.where( + self.img >= float(self.cnn_threshold.text()), 1, 0 + ).astype(np.uint8) + + create_image_layer( + self.viewer, + image=self.img_threshold, + name="Prediction", + transparency=True, + range_=(0, 1), + ) + + self.predictor.image = self.img_threshold + self.predictor.save_semantic_mask(self.dir.split("/")[-1]) + + def update_dist_layer(self): + if self.predictor.segments is not None: + create_point_layer( + viewer=self.viewer, + points=self.predictor.segments, + name="Predicted_Instances", 
+ visibility=True, + ) + else: + return + + if self.predictor.segments_filter is not None: + create_point_layer( + viewer=self.viewer, + points=self.predictor.segments_filter, + name="Predicted_Instances_filter", + visibility=True, + ) + else: + return + + def update_dist_graph(self): + if self.predictor is not None: + if self.predictor.graphs is not None: + if bool(self.filament.checkState()): + sort = True + prune = 5 + else: + sort = False + prune = 15 + + try: + self.predictor.segments = ( + self.predictor.GraphToSegment.patch_to_segment( + graph=self.predictor.graphs, + coord=self.predictor.pc_ld, + idx=self.predictor.output_idx, + sort=sort, + prune=prune, + ) + ) + self.predictor.segments = sort_by_length(self.predictor.segments) + except: + self.predictor.segments = None + + if self.predictor.segments is None: + show_info("TARDIS-em could not find any instances :(") + return + else: + show_info( + f"TARDIS-em found {int(np.max(self.predictor.segments[:, 0]))} instances :)" + ) + self.predictor.save_instance_PC(self.dir.split("/")[-1]) + + def update_versions(self): + for i in range(self.model_version.count()): + self.model_version.removeItem(0) + + versions = get_all_version_aws( + self.cnn_type.currentText(), "32", "microtubules_3d" + ) + + if len(versions) == 0: + self.model_version.addItems(["None"]) + else: + self.model_version.addItems(["None"] + [i.split("_")[-1] for i in versions]) + + def calculate_position(self, name): + patch_size = int(self.patch_size.currentText()) + name = name.split("_") + name = { + "z": int(name[1]), + "y": int(name[2]), + "x": int(name[3]), + "stride": int(name[4]), + } + + x_start = (name["x"] * patch_size) - (name["x"] * name["stride"]) + x_end = x_start + patch_size + name["x"] = [x_start, x_end] + + y_start = (name["y"] * patch_size) - (name["y"] * name["stride"]) + y_end = y_start + patch_size + name["y"] = [y_start, y_end] + + z_start = (name["z"] * patch_size) - (name["z"] * name["stride"]) + z_end = z_start + 
patch_size + name["z"] = [z_start, z_end] + + return name + def load_directory(self): filename, _ = QFileDialog.getOpenFileName( caption="Open File", @@ -417,7 +550,7 @@ def predict_semantic(self): correct_px=correct_px, convolution_nn=self.cnn_type.currentText(), checkpoint=[ - None if self.checkpoint.text() == "None" else self.checkpoint.text(), + None if self.checkpoint.text() == "None" else self.checkpoint_dir, None, ], model_version=model_version, @@ -516,27 +649,6 @@ def predict_dataset(img_dataset_, predictor): else: return - def update_dist_layer(self): - if self.predictor.segments is not None: - create_point_layer( - viewer=self.viewer, - points=self.predictor.segments, - name="Predicted_Instances", - visibility=True, - ) - else: - return - - if self.predictor.segments_filter is not None: - create_point_layer( - viewer=self.viewer, - points=self.predictor.segments_filter, - name="Predicted_Instances_filter", - visibility=True, - ) - else: - return - def predict_instance(self): if self.predictor is None: show_error(f"Please initialize with 'Predict Semantic' button") @@ -619,7 +731,7 @@ def show_command(self): ch = ( "" if self.checkpoint.text() == "None" - else f"-ch {self.checkpoint.text()}_None " + else f"-ch {self.checkpoint_dir}_None " ) mv = ( @@ -702,64 +814,3 @@ def show_command(self): f"-pv {int(self.points_in_patch.text())} " f"-dv {self.device.currentText()}" ) - - def update_cnn_threshold(self): - if self.img is not None: - self.viewer.layers[self.dir.split("/")[-1]].visible = True - - if float(self.cnn_threshold.text()) == 1.0: - self.img_threshold = adaptive_threshold(self.img).astype(np.uint8) - elif float(self.cnn_threshold.text()) == 0.0: - self.img_threshold = np.copy(self.img) - else: - self.img_threshold = np.where( - self.img >= float(self.cnn_threshold.text()), 1, 0 - ).astype(np.uint8) - - create_image_layer( - self.viewer, - image=self.img_threshold, - name="Prediction", - transparency=True, - range_=(0, 1), - ) - - 
self.predictor.image = self.img_threshold - self.predictor.save_semantic_mask(self.dir.split("/")[-1]) - - def update_versions(self): - for i in range(self.model_version.count()): - self.model_version.removeItem(0) - - versions = get_all_version_aws( - self.cnn_type.currentText(), "32", "microtubules_3d" - ) - - if len(versions) == 0: - self.model_version.addItems(["None"]) - else: - self.model_version.addItems(["None"] + [i.split("_")[-1] for i in versions]) - - def calculate_position(self, name): - patch_size = int(self.patch_size.currentText()) - name = name.split("_") - name = { - "z": int(name[1]), - "y": int(name[2]), - "x": int(name[3]), - "stride": int(name[4]), - } - - x_start = (name["x"] * patch_size) - (name["x"] * name["stride"]) - x_end = x_start + patch_size - name["x"] = [x_start, x_end] - - y_start = (name["y"] * patch_size) - (name["y"] * name["stride"]) - y_end = y_start + patch_size - name["y"] = [y_start, y_end] - - z_start = (name["z"] * patch_size) - (name["z"] * name["stride"]) - z_end = z_start + patch_size - name["z"] = [z_start, z_end] - - return name diff --git a/src/napari_tardis_em/viewers/viewer_predict.py b/src/napari_tardis_em/viewers/viewer_predict.py index 67fbd9f..b02ced5 100644 --- a/src/napari_tardis_em/viewers/viewer_predict.py +++ b/src/napari_tardis_em/viewers/viewer_predict.py @@ -39,7 +39,7 @@ from tardis_em.utils.normalization import adaptive_threshold from tardis_em.utils.predictor import GeneralPredictor from tardis_em.utils.setup_envir import clean_up - +from tardis_em.utils.spline_metric import sort_by_length from napari_tardis_em.viewers.styles import border_style from napari_tardis_em.utils.utils import get_list_of_device from napari_tardis_em.viewers.utils import ( @@ -201,6 +201,7 @@ def __init__(self, viewer_predict: Viewer): "false/positives. Higher value will result in cleaner output but may \n" "reduce recall." 
) + self.dist_threshold.valueChanged.connect(self.update_dist_graph) self.device = QComboBox() self.device.addItems(get_list_of_device()) @@ -259,6 +260,8 @@ def __init__(self, viewer_predict: Viewer): "length in angstrom. All filaments shorter then this length \n" "will be deleted." ) + self.filter_by_length.textChanged.connect(self.update_dist_graph) + self.connect_splines = QLineEdit("None") self.connect_splines.setValidator(QIntValidator(0, 10000)) self.connect_splines.setToolTip( @@ -270,6 +273,7 @@ def __init__(self, viewer_predict: Viewer): "determines how far apart two filaments can be, while still being considered \n" "as a single unit if they are oriented in the same direction." ) + self.connect_splines.textChanged.connect(self.update_dist_graph) self.connect_cylinder = QLineEdit("None") self.connect_cylinder.setValidator(QIntValidator(0, 10000)) @@ -281,6 +285,7 @@ def __init__(self, viewer_predict: Viewer): "The ends of these filaments must be located within this cylinder \n" "to be considered connected." 
) + self.connect_cylinder.textChanged.connect(self.update_dist_graph) """""" """""" """ UI Setup @@ -327,6 +332,118 @@ def update_filament_setting(self): self.connect_splines.setText("2500") self.connect_cylinder.setText("250") + def update_checkpoint_dir(self): + filename, _ = QFileDialog.getOpenFileName( + caption="Open File", + directory=getcwd(), + ) + self.checkpoint.setText(filename[-30:]) + self.checkpoint_dir = filename + + def update_cnn_threshold(self): + if self.img is not None: + self.viewer.layers[self.dir.split("/")[-1]].visible = True + + if float(self.cnn_threshold.text()) == 1.0: + self.img_threshold = adaptive_threshold(self.img).astype(np.uint8) + elif float(self.cnn_threshold.text()) == 0.0: + self.img_threshold = np.copy(self.img) + else: + self.img_threshold = np.where( + self.img >= float(self.cnn_threshold.text()), 1, 0 + ).astype(np.uint8) + + create_image_layer( + self.viewer, + image=self.img_threshold, + name="Prediction", + transparency=True, + range_=(0, 1), + ) + + self.predictor.image = self.img_threshold + self.predictor.save_semantic_mask(self.dir.split("/")[-1]) + + def update_dist_layer(self): + self.predictor.image = self.img_threshold + + if self.predictor.segments is not None: + create_point_layer( + viewer=self.viewer, + points=self.predictor.segments, + name="Predicted_Instances", + visibility=True, + ) + else: + return + + if self.predictor.segments_filter is not None: + create_point_layer( + viewer=self.viewer, + points=self.predictor.segments_filter, + name="Predicted_Instances_filter", + visibility=True, + ) + else: + return + + def update_dist_graph(self): + if self.predictor is not None: + if self.predictor.graphs is not None: + if bool(self.filament.checkState()): + sort = True + prune = 5 + else: + sort = False + prune = 15 + + try: + self.predictor.segments = ( + self.predictor.GraphToSegment.patch_to_segment( + graph=self.predictor.graphs, + coord=self.predictor.pc_ld, + idx=self.predictor.output_idx, + sort=sort, 
+ prune=prune, + ) + ) + self.predictor.segments = sort_by_length(self.predictor.segments) + except: + self.predictor.segments = None + + if self.predictor.segments is None: + show_info("TARDIS-em could not find any instances :(") + return + else: + show_info( + f"TARDIS-em found {int(np.max(self.predictor.segments[:, 0]))} instances :)" + ) + self.predictor.save_instance_PC(self.dir.split("/")[-1]) + + def calculate_position(self, name): + patch_size = int(self.patch_size.currentText()) + name = name.split("_") + name = { + "z": int(name[1]), + "y": int(name[2]), + "x": int(name[3]), + "stride": int(name[4]), + } + + x_start = (name["x"] * patch_size) - (name["x"] * name["stride"]) + x_end = x_start + patch_size + name["x"] = [x_start, x_end] + + y_start = (name["y"] * patch_size) - (name["y"] * name["stride"]) + y_end = y_start + patch_size + name["y"] = [y_start, y_end] + + z_start = (name["z"] * patch_size) - (name["z"] * name["stride"]) + z_end = z_start + patch_size + name["z"] = [z_start, z_end] + + return name + def load_directory(self): filename, _ = QFileDialog.getOpenFileName( caption="Open File", @@ -525,10 +642,10 @@ def predict_dataset(img_dataset_, predictor): self.img = self.predictor.image_stitcher( image_dir=self.predictor.output, mask=False, dtype=np.float32 )[ - : self.predictor.scale_shape[0], - : self.predictor.scale_shape[1], - : self.predictor.scale_shape[2], - ] + : self.predictor.scale_shape[0], + : self.predictor.scale_shape[1], + : self.predictor.scale_shape[2], + ] self.img, _ = scale_image( image=self.img, scale=self.predictor.org_shape ) @@ -542,29 +659,6 @@ def predict_dataset(img_dataset_, predictor): else: return - def update_dist_layer(self): - self.predictor.image = self.img_threshold - - if self.predictor.segments is not None: - create_point_layer( - viewer=self.viewer, - points=self.predictor.segments, - name="Predicted_Instances", - visibility=True, - ) - else: - return - - if self.predictor.segments_filter is not None: - 
create_point_layer( - viewer=self.viewer, - points=self.predictor.segments_filter, - name="Predicted_Instances_filter", - visibility=True, - ) - else: - return - def predict_instance(self): if self.predictor is None: show_error(f"Please initialize with 'Predict Semantic' button") @@ -650,17 +744,9 @@ def show_command(self): else f"-ch {self.checkpoint_dir}_None " ) - fi = ( - "-fi True " - if bool(self.filament.checkState()) - else "" - ) + fi = "-fi True " if bool(self.filament.checkState()) else "" - it = ( - "-it 2d " - if self.image_type.currentText() == "2D" - else f"-it 3d " - ) + it = "-it 2d " if self.image_type.currentText() == "2D" else f"-it 3d " cnn = ( "" @@ -706,59 +792,3 @@ def show_command(self): f"-pv {int(self.points_in_patch.text())} " f"-dv {self.device.currentText()}" ) - - def update_cnn_threshold(self): - if self.img is not None: - self.viewer.layers[self.dir.split("/")[-1]].visible = True - - if float(self.cnn_threshold.text()) == 1.0: - self.img_threshold = adaptive_threshold(self.img).astype(np.uint8) - elif float(self.cnn_threshold.text()) == 0.0: - self.img_threshold = np.copy(self.img) - else: - self.img_threshold = np.where( - self.img >= float(self.cnn_threshold.text()), 1, 0 - ).astype(np.uint8) - - create_image_layer( - self.viewer, - image=self.img_threshold, - name="Prediction", - transparency=True, - range_=(0, 1), - ) - - self.predictor.image = self.img_threshold - self.predictor.save_semantic_mask(self.dir.split("/")[-1]) - - def calculate_position(self, name): - patch_size = int(self.patch_size.currentText()) - name = name.split("_") - name = { - "z": int(name[1]), - "y": int(name[2]), - "x": int(name[3]), - "stride": int(name[4]), - } - - x_start = (name["x"] * patch_size) - (name["x"] * name["stride"]) - x_end = x_start + patch_size - name["x"] = [x_start, x_end] - - y_start = (name["y"] * patch_size) - (name["y"] * name["stride"]) - y_end = y_start + patch_size - name["y"] = [y_start, y_end] - - z_start = (name["z"] * 
patch_size) - (name["z"] * name["stride"]) - z_end = z_start + patch_size - name["z"] = [z_start, z_end] - - return name - - def update_checkpoint_dir(self): - filename, _ = QFileDialog.getOpenFileName( - caption="Open File", - directory=getcwd(), - ) - self.checkpoint.setText(filename[-30:]) - self.checkpoint_dir = filename diff --git a/src/napari_tardis_em/viewers/viewer_train.py b/src/napari_tardis_em/viewers/viewer_train.py index 15d07c0..d1b788c 100644 --- a/src/napari_tardis_em/viewers/viewer_train.py +++ b/src/napari_tardis_em/viewers/viewer_train.py @@ -426,7 +426,7 @@ def trainer(self): "prediction": False, } else: - model_dict['img_size'] = patch_size + model_dict["img_size"] = patch_size self.structure = model_dict """Build CNN model""" diff --git a/utils/__init__.py b/utils/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/utils/export_coord.py b/utils/export_coord.py deleted file mode 100644 index 40f6323..0000000 --- a/utils/export_coord.py +++ /dev/null @@ -1,9 +0,0 @@ -####################################################################### -# TARDIS - Transformer And Rapid Dimensionless Instance Segmentation # -# # -# New York Structural Biology Center # -# Simons Machine Learning Center # -# # -# Robert Kiewisz, Tristan Bepler # -# MIT License 2024 # -####################################################################### diff --git a/utils/export_image.py b/utils/export_image.py deleted file mode 100644 index 40f6323..0000000 --- a/utils/export_image.py +++ /dev/null @@ -1,9 +0,0 @@ -####################################################################### -# TARDIS - Transformer And Rapid Dimensionless Instance Segmentation # -# # -# New York Structural Biology Center # -# Simons Machine Learning Center # -# # -# Robert Kiewisz, Tristan Bepler # -# MIT License 2024 # -####################################################################### diff --git a/utils/import_data.py b/utils/import_data.py deleted file mode 100644 index 
40f6323..0000000 --- a/utils/import_data.py +++ /dev/null @@ -1,9 +0,0 @@ -####################################################################### -# TARDIS - Transformer And Rapid Dimensionless Instance Segmentation # -# # -# New York Structural Biology Center # -# Simons Machine Learning Center # -# # -# Robert Kiewisz, Tristan Bepler # -# MIT License 2024 # -####################################################################### diff --git a/utils/styles.py b/utils/styles.py deleted file mode 100644 index e7361ad..0000000 --- a/utils/styles.py +++ /dev/null @@ -1,20 +0,0 @@ -####################################################################### -# TARDIS - Transformer And Rapid Dimensionless Instance Segmentation # -# # -# New York Structural Biology Center # -# Simons Machine Learning Center # -# # -# Robert Kiewisz, Tristan Bepler # -# MIT License 2024 # -####################################################################### - - -def border_style(color="blue", direction="bottom", border=3, padding=4): - assert direction in ["top", "bottom", "right", "left"] - style = ( - f"border-{direction}: {border}px solid {color};" - f"padding-{direction}: {padding}px;" - "background-color:none;" - ) - - return style diff --git a/utils/utils.py b/utils/utils.py deleted file mode 100644 index 7b2248b..0000000 --- a/utils/utils.py +++ /dev/null @@ -1,25 +0,0 @@ -####################################################################### -# TARDIS - Transformer And Rapid Dimensionless Instance Segmentation # -# # -# New York Structural Biology Center # -# Simons Machine Learning Center # -# # -# Robert Kiewisz, Tristan Bepler # -# MIT License 2024 # -####################################################################### -import torch - - -def get_list_of_device(): - devices = ["cpu"] - - # Check if CUDA (NVIDIA GPU) is available and list all available CUDA devices - if torch.cuda.is_available(): - for i in range(torch.cuda.device_count()): - devices.append(f"cuda:{i}") - - # 
Check for MPS (Apple's Metal Performance Shaders) availability - if torch.backends.mps.is_available(): - devices.append("mps") - - return devices diff --git a/viewers/__init__.py b/viewers/__init__.py deleted file mode 100644 index ba33225..0000000 --- a/viewers/__init__.py +++ /dev/null @@ -1 +0,0 @@ -colormap_for_display = "Spectral" diff --git a/viewers/utils.py b/viewers/utils.py deleted file mode 100644 index 1e6b28e..0000000 --- a/viewers/utils.py +++ /dev/null @@ -1,96 +0,0 @@ -####################################################################### -# TARDIS - Transformer And Rapid Dimensionless Instance Segmentation # -# # -# New York Structural Biology Center # -# Simons Machine Learning Center # -# # -# Robert Kiewisz, Tristan Bepler # -# MIT License 2024 # -####################################################################### -import numpy as np - -from napari_tardis_em.viewers import colormap_for_display - - -def update_viewer_prediction(viewer, image: np.ndarray, position: dict): - img = viewer.layers["Prediction"] - - try: - img.data[ - position["z"][0] : position["z"][1], - position["y"][0] : position["y"][1], - position["x"][0] : position["x"][1], - ] = image - except ValueError: - shape_ = img.data.shape - diff = [ - image.shape[0] - (position["z"][1] - shape_[0]), - image.shape[0] - (position["y"][1] - shape_[1]), - image.shape[0] - (position["x"][1] - shape_[2]), - ] - diff = [ - diff[0] if 0 < diff[0] < image.shape[0] else image.shape[0], - diff[1] if 0 < diff[1] < image.shape[0] else image.shape[0], - diff[2] if 0 < diff[2] < image.shape[0] else image.shape[0], - ] - position["z"][1] = position["z"][0] + diff[0] - position["y"][1] = position["y"][0] + diff[1] - position["x"][1] = position["x"][0] + diff[2] - - img.data[ - position["z"][0] : position["z"][1], - position["y"][0] : position["y"][1], - position["x"][0] : position["x"][1], - ] = image[: diff[0], : diff[1], : diff[2]] - - -def create_image_layer( - viewer, - image: np.ndarray, - 
name: str, - transparency=False, - visibility=True, - range_=None, -): - """ - Create an image layer in napari. - - Args: - image (np.ndarray): Image array to display - name (str): Layer name - transparency (bool): If True, show image as transparent layer - visibility (bool): - range_(tuple): - """ - try: - viewer.layers.remove(name) - except Exception as e: - pass - - if transparency: - viewer.add_image(image, name=name, colormap=colormap_for_display, opacity=0.5) - else: - viewer.add_image(image, name=name, colormap="gray", opacity=1.0) - - if range_ is not None: - try: - viewer.layers[name].contrast_limits = ( - range_[0], - range_[1], - ) - except Exception as e: - pass - else: - try: - viewer.layers[name].contrast_limits = ( - image.min(), - image.max(), - ) - except Exception as e: - pass - - if visibility: - # set layer as not visible - viewer.layers[name].visible = True - else: - viewer.layers[name].visible = False diff --git a/viewers/viewer_actin_3d.py b/viewers/viewer_actin_3d.py deleted file mode 100644 index 8c76115..0000000 --- a/viewers/viewer_actin_3d.py +++ /dev/null @@ -1,59 +0,0 @@ -####################################################################### -# TARDIS - Transformer And Rapid Dimensionless Instance Segmentation # -# # -# New York Structural Biology Center # -# Simons Machine Learning Center # -# # -# Robert Kiewisz, Tristan Bepler # -# MIT License 2024 # -####################################################################### -from os import getcwd -from os.path import join - -import numpy as np -import torch - -from PyQt5.QtGui import QIntValidator, QDoubleValidator -from PyQt5.QtWidgets import ( - QPushButton, - QFormLayout, - QLineEdit, - QComboBox, - QLabel, - QCheckBox, - QDoubleSpinBox, - QFileDialog, -) -from napari import Viewer -from napari.utils.progress import progress -from napari.qt.threading import thread_worker - -from qtpy.QtWidgets import QWidget -from PyQt5.QtCore import Qt - -from tardis_em.utils.predictor import 
GeneralPredictor -from tardis_em.utils.load_data import load_image -from tardis_em.cnn.data_processing.trim import trim_with_stride -from tardis_em.utils.aws import get_all_version_aws -from tardis_em.cnn.datasets.dataloader import PredictionDataset -from tardis_em.cnn.data_processing.scaling import scale_image -from tardis_em.utils.normalization import adaptive_threshold - -from napari.utils.notifications import show_info, show_error -from napari_tardis_em.utils.styles import border_style -from napari_tardis_em.utils.utils import get_list_of_device -from napari_tardis_em.viewers.utils import create_image_layer, update_viewer_prediction - - -class TardisWidget(QWidget): - """ - Easy to use plugin for general Actin prediction. - - Plugin integrate TARDIS-em and allow to easily set up training. To make it more - user-friendly, this plugin guid user what to do, and during training display - results from validation loop. - """ - - def __init__(self, viewer_actin_3d: Viewer): - super().__init__() - self.viewer = viewer_actin_3d diff --git a/viewers/viewer_mem_2d.py b/viewers/viewer_mem_2d.py deleted file mode 100644 index 37513c8..0000000 --- a/viewers/viewer_mem_2d.py +++ /dev/null @@ -1,59 +0,0 @@ -####################################################################### -# TARDIS - Transformer And Rapid Dimensionless Instance Segmentation # -# # -# New York Structural Biology Center # -# Simons Machine Learning Center # -# # -# Robert Kiewisz, Tristan Bepler # -# MIT License 2024 # -####################################################################### -from os import getcwd -from os.path import join - -import numpy as np -import torch - -from PyQt5.QtGui import QIntValidator, QDoubleValidator -from PyQt5.QtWidgets import ( - QPushButton, - QFormLayout, - QLineEdit, - QComboBox, - QLabel, - QCheckBox, - QDoubleSpinBox, - QFileDialog, -) -from napari import Viewer -from napari.utils.progress import progress -from napari.qt.threading import thread_worker - -from 
qtpy.QtWidgets import QWidget -from PyQt5.QtCore import Qt - -from tardis_em.utils.predictor import GeneralPredictor -from tardis_em.utils.load_data import load_image -from tardis_em.cnn.data_processing.trim import trim_with_stride -from tardis_em.utils.aws import get_all_version_aws -from tardis_em.cnn.datasets.dataloader import PredictionDataset -from tardis_em.cnn.data_processing.scaling import scale_image -from tardis_em.utils.normalization import adaptive_threshold - -from napari.utils.notifications import show_info, show_error -from napari_tardis_em.utils.styles import border_style -from napari_tardis_em.utils.utils import get_list_of_device -from napari_tardis_em.viewers.utils import create_image_layer, update_viewer_prediction - - -class TardisWidget(QWidget): - """ - Easy to use plugin for general Membrane prediction from micrographs. - - Plugin integrate TARDIS-em and allow to easily set up training. To make it more - user-friendly, this plugin guid user what to do, and during training display - results from validation loop. 
- """ - - def __init__(self, viewer_mem_2d: Viewer): - super().__init__() - self.viewer = viewer_mem_2d diff --git a/viewers/viewer_mem_3d.py b/viewers/viewer_mem_3d.py deleted file mode 100644 index faba891..0000000 --- a/viewers/viewer_mem_3d.py +++ /dev/null @@ -1,59 +0,0 @@ -####################################################################### -# TARDIS - Transformer And Rapid Dimensionless Instance Segmentation # -# # -# New York Structural Biology Center # -# Simons Machine Learning Center # -# # -# Robert Kiewisz, Tristan Bepler # -# MIT License 2024 # -####################################################################### -from os import getcwd -from os.path import join - -import numpy as np -import torch - -from PyQt5.QtGui import QIntValidator, QDoubleValidator -from PyQt5.QtWidgets import ( - QPushButton, - QFormLayout, - QLineEdit, - QComboBox, - QLabel, - QCheckBox, - QDoubleSpinBox, - QFileDialog, -) -from napari import Viewer -from napari.utils.progress import progress -from napari.qt.threading import thread_worker - -from qtpy.QtWidgets import QWidget -from PyQt5.QtCore import Qt - -from tardis_em.utils.predictor import GeneralPredictor -from tardis_em.utils.load_data import load_image -from tardis_em.cnn.data_processing.trim import trim_with_stride -from tardis_em.utils.aws import get_all_version_aws -from tardis_em.cnn.datasets.dataloader import PredictionDataset -from tardis_em.cnn.data_processing.scaling import scale_image -from tardis_em.utils.normalization import adaptive_threshold - -from napari.utils.notifications import show_info, show_error -from napari_tardis_em.utils.styles import border_style -from napari_tardis_em.utils.utils import get_list_of_device -from napari_tardis_em.viewers.utils import create_image_layer, update_viewer_prediction - - -class TardisWidget(QWidget): - """ - Easy to use plugin for general Membrane prediction from tomograms. - - Plugin integrate TARDIS-em and allow to easily set up training. 
To make it more - user-friendly, this plugin guid user what to do, and during training display - results from validation loop. - """ - - def __init__(self, viewer_mem_3d: Viewer): - super().__init__() - self.viewer = viewer_mem_3d diff --git a/viewers/viewer_mt_2d.py b/viewers/viewer_mt_2d.py deleted file mode 100644 index 40f6323..0000000 --- a/viewers/viewer_mt_2d.py +++ /dev/null @@ -1,9 +0,0 @@ -####################################################################### -# TARDIS - Transformer And Rapid Dimensionless Instance Segmentation # -# # -# New York Structural Biology Center # -# Simons Machine Learning Center # -# # -# Robert Kiewisz, Tristan Bepler # -# MIT License 2024 # -####################################################################### diff --git a/viewers/viewer_mt_3d.py b/viewers/viewer_mt_3d.py deleted file mode 100644 index 9c3a784..0000000 --- a/viewers/viewer_mt_3d.py +++ /dev/null @@ -1,696 +0,0 @@ -####################################################################### -# TARDIS - Transformer And Rapid Dimensionless Instance Segmentation # -# # -# New York Structural Biology Center # -# Simons Machine Learning Center # -# # -# Robert Kiewisz, Tristan Bepler # -# MIT License 2024 # -####################################################################### -from os import getcwd -from os.path import join - -import numpy as np -import torch - -from PyQt5.QtGui import QIntValidator, QDoubleValidator -from PyQt5.QtWidgets import ( - QPushButton, - QFormLayout, - QLineEdit, - QComboBox, - QLabel, - QCheckBox, - QDoubleSpinBox, - QFileDialog, -) -from napari import Viewer -from napari.utils.progress import progress -from napari.qt.threading import thread_worker - -from qtpy.QtWidgets import QWidget -from PyQt5.QtCore import Qt - -from tardis_em.utils.predictor import GeneralPredictor -from tardis_em.utils.load_data import load_image -from tardis_em.cnn.data_processing.trim import trim_with_stride -from tardis_em.utils.aws import get_all_version_aws 
-from tardis_em.cnn.datasets.dataloader import PredictionDataset -from tardis_em.cnn.data_processing.scaling import scale_image -from tardis_em.utils.normalization import adaptive_threshold - -from napari.utils.notifications import show_info, show_error -from napari_tardis_em.utils.styles import border_style -from napari_tardis_em.utils.utils import get_list_of_device -from napari_tardis_em.viewers.utils import create_image_layer, update_viewer_prediction - - -class TardisWidget(QWidget): - """ - Easy to use plugin for general Microtubule prediction. - - Plugin integrate TARDIS-em and allow to easily set up training. To make it more - user-friendly, this plugin guid user what to do, and during training display - results from validation loop. - """ - - def __init__(self, viewer_mt_3d: Viewer): - super().__init__() - - self.viewer = viewer_mt_3d - - self.img, self.px = None, None - self.out_ = getcwd() - - """""" """""" """ - UI Elements - """ """""" """""" - self.directory = QPushButton(f"...{getcwd()[-30:]}") - self.directory.setToolTip( - "Select directory with image or single file you would like to predict. \n " - "\n" - "Supported formats:\n" - "images: *.mrc, *.rec, *.map, *.am, *.tif" - ) - self.directory.clicked.connect(self.load_directory) - self.dir = getcwd() - - self.output = QPushButton(f"...{getcwd()[-17:]}/Predictions/") - self.output.setToolTip( - "Select directory in which plugin will save train model, checkpoints and training logs." 
- ) - self.output_folder = f"...{getcwd()[-17:]}/Predictions/" - - ############################## - # Setting user should change # - ############################## - label_2 = QLabel("Setting user should change") - label_2.setStyleSheet(border_style("green")) - - self.output_semantic = QComboBox() - self.output_semantic.addItems(["mrc", "tif", "npy", "am"]) - self.output_semantic.setToolTip("Select semantic output format file.") - - self.output_instance = QComboBox() - self.output_instance.addItems(["None", "csv", "npy", "amSG"]) - self.output_instance.setToolTip("Select instance output format file.") - - self.output_formats = ( - f"{self.output_semantic.currentText()}_{self.output_instance.currentText()}" - ) - - ########################### - # Setting user may change # - ########################### - label_3 = QLabel("Setting user may change") - label_3.setStyleSheet(border_style("yellow")) - - self.mask = QCheckBox() - self.mask.setCheckState(Qt.CheckState.Unchecked) - self.mask.setToolTip( - "Define if you input tomograms images or binary mask \n" - "with pre segmented microtubules." - ) - - self.correct_px = QLineEdit("None") - self.correct_px.setValidator(QDoubleValidator(0.00, 100.00, 3)) - self.correct_px.setToolTip( - "Set correct pixel size value, if image header \n" - "do not contain or stores incorrect information." 
- ) - - self.cnn_type = QComboBox() - self.cnn_type.addItems(["unet", "resnet", "unet3plus", "fnet", "fnet_attn"]) - self.cnn_type.setCurrentIndex(4) - self.cnn_type.setToolTip("Select type of CNN you would like to train.") - self.cnn_type.currentIndexChanged.connect(self.update_versions) - - self.checkpoint = QLineEdit("None") - self.checkpoint.setToolTip("Optional, directory to CNN checkpoint.") - - self.patch_size = QComboBox() - self.patch_size.addItems( - ["32", "64", "96", "128", "160", "192", "256", "512", "1024"] - ) - self.patch_size.setCurrentIndex(2) - self.patch_size.setToolTip( - "Select patch size value that will be used to split \n" - "all images into smaller patches." - ) - - self.rotate = QCheckBox() - self.rotate.setCheckState(Qt.CheckState.Checked) - self.rotate.setToolTip( - "Select if you want to switch on/of rotation during the prediction. \n" - "If selected, during CNN prediction image is rotate 4x by 90 degrees.\n" - "This will increase prediction time 4x. \n" - "However may lead to more cleaner output." - ) - - self.cnn_threshold = QDoubleSpinBox() - self.cnn_threshold.setDecimals(2) - self.cnn_threshold.setMinimum(0) - self.cnn_threshold.setMaximum(1) - self.cnn_threshold.setSingleStep(0.05) - self.cnn_threshold.setValue(0.25) - self.cnn_threshold.setToolTip( - "Threshold value for binary prediction. Lower value will increase \n" - "recall [retrieve more of predicted object] but also may increase \n" - "false/positives. Higher value will result in cleaner output but may \n" - "reduce recall.\n" - "\n" - "If selected 0.0 - Output probability mask \n" - "If selected 1.0 - Use adaptive threshold." 
- ) - self.cnn_threshold.valueChanged.connect(self.update_cnn_threshold) - - self.dist_threshold = QDoubleSpinBox() - self.dist_threshold.setDecimals(2) - self.dist_threshold.setMinimum(0) - self.dist_threshold.setMaximum(1) - self.dist_threshold.setSingleStep(0.05) - self.dist_threshold.setValue(0.50) - self.dist_threshold.setToolTip( - "Threshold value for instance prediction. Lower value will increase \n" - "recall [retrieve more of predicted object] but also may increase \n" - "false/positives. Higher value will result in cleaner output but may \n" - "reduce recall." - ) - - self.device = QComboBox() - self.device.addItems(get_list_of_device()) - self.device.setCurrentIndex(0) - self.device.setToolTip( - "Select available device on which you want to train your model." - ) - - ######################################## - # Setting user is not advice to change # - ######################################## - label_4 = QLabel("Setting user is not advice to change") - label_4.setStyleSheet(border_style("red")) - - self.points_in_patch = QLineEdit("600") - self.points_in_patch.setValidator(QDoubleValidator(100, 10000, 1)) - self.points_in_patch.setToolTip( - "Number of point in patch. Higher number will increase how may points \n" - "DIST model will process at the time. This is usually only the memory GPU constrain." 
- ) - - self.model_version = QComboBox() - self.model_version.addItems(["None"]) - self.update_versions() - self.model_version.setToolTip("Optional version of the model from 1 to inf.") - - self.predict_1_button = QPushButton("Predict Semantic...") - self.predict_1_button.setMinimumWidth(225) - self.predict_1_button.clicked.connect(self.predict_semantic) - - self.predict_2_button = QPushButton("Predict Instances...") - self.predict_2_button.setMinimumWidth(225) - self.predict_2_button.clicked.connect(self.predict_instance) - - self.export_command = QPushButton("Export command for high-throughput") - self.export_command.setMinimumWidth(225) - self.export_command.clicked.connect(self.show_command) - - ################################# - # Optional Microtubules Filters # - ################################# - label_5 = QLabel("Optional Microtubules Filters") - label_5.setStyleSheet(border_style("orange")) - - self.amira_prefix = QLineEdit(".CorrelationLines") - self.amira_prefix.setToolTip( - "If dir/amira foldr exist, TARDIS will search for files with \n" - "given prefix (e.g. file_name.CorrelationLines.am). If the correct \n" - "file is found, TARDIS will use its instance segmentation with \n" - "ZiB Amira prediction, and output additional file called \n" - "file_name_AmiraCompare.am." - ) - - self.amira_compare_distance = QLineEdit("175") - self.amira_compare_distance.setValidator(QIntValidator(0, 10000)) - self.amira_compare_distance.setToolTip( - "The comparison with Amira prediction is done by evaluating \n" - "filaments distance between Amira and TARDIS. This parameter defines the maximum \n" - "distance to the similarity between two splines. Value given in Angstrom [A]." 
- ) - self.amira_inter_probability = QDoubleSpinBox() - self.amira_inter_probability.setDecimals(2) - self.amira_inter_probability.setMinimum(0) - self.amira_inter_probability.setMaximum(1) - self.amira_inter_probability.setSingleStep(0.05) - self.amira_inter_probability.setValue(0.25) - self.amira_inter_probability.setToolTip( - "This parameter define normalize between 0 and 1 overlap \n" - "between filament from TARDIS na Amira sufficient to identifies microtubule as \n" - "a match between both software." - ) - - self.filter_by_length = QLineEdit("1000") - self.filter_by_length.setValidator(QIntValidator(0, 10000)) - self.filter_by_length.setToolTip( - "Filtering parameters for microtubules, defining maximum microtubule \n" - "length in angstrom. All filaments shorter then this length \n" - "will be deleted." - ) - self.connect_splines = QLineEdit("2500") - self.connect_splines.setValidator(QIntValidator(0, 10000)) - self.connect_splines.setToolTip( - "To address the issue where microtubules are mistakenly \n" - "identified as two different filaments, we use a filtering technique. \n" - "This involves identifying the direction each filament end points towards and then \n" - "linking any filaments that are facing the same direction and are within \n" - "a certain distance from each other, measured in angstroms. This distance threshold \n" - "determines how far apart two microtubules can be, while still being considered \n" - "as a single unit if they are oriented in the same direction." - ) - - self.connect_cylinder = QLineEdit("250") - self.connect_cylinder.setValidator(QIntValidator(0, 10000)) - self.connect_cylinder.setToolTip( - "To minimize false positives when linking microtubules, we limit \n" - "the search area to a cylindrical radius specified in angstroms. \n" - "For each spline, we find the direction the filament end is pointing in \n" - "and look for another filament that is oriented in the same direction. 
\n" - "The ends of these filaments must be located within this cylinder \n" - "to be considered connected." - ) - - """""" """""" """ - UI Setup - """ """""" """""" - layout = QFormLayout() - layout.addRow("Select Directory", self.directory) - layout.addRow("Output Directory", self.output) - - layout.addRow("---- CNN Options ----", label_2) - layout.addRow("Semantic output", self.output_semantic) - layout.addRow("Instance output", self.output_instance) - - layout.addRow("----- Extra --------", label_3) - layout.addRow("Input as a mask", self.mask) - layout.addRow("Correct pixel size", self.correct_px) - layout.addRow("CNN type", self.cnn_type) - layout.addRow("Checkpoint", self.checkpoint) - layout.addRow("Patch size", self.patch_size) - layout.addRow("Rotation", self.rotate) - layout.addRow("CNN threshold", self.cnn_threshold) - layout.addRow("DIST threshold", self.dist_threshold) - layout.addRow("Device", self.device) - - layout.addRow("---- MT Filters -----", label_5) - layout.addRow("Amira file prefix", self.amira_prefix) - layout.addRow("Compare distance with Amira [A]", self.amira_compare_distance) - layout.addRow("Compare similarity probability", self.amira_inter_probability) - layout.addRow("Filter MT length [A]", self.filter_by_length) - layout.addRow("Connect splines within distance [A]", self.connect_splines) - layout.addRow("Connect splines within diameter [A]", self.connect_cylinder) - - layout.addRow("---- Advance -------", label_4) - layout.addRow("No. 
of points [DIST]", self.points_in_patch) - layout.addRow("Model Version", self.model_version) - - layout.addRow("", self.predict_1_button) - layout.addRow("", self.predict_2_button) - layout.addRow("", self.export_command) - - self.setLayout(layout) - - def load_directory(self): - filename, _ = QFileDialog.getOpenFileName( - caption="Open File", - directory=getcwd(), - filter="Image Files (*.mrc *.rec *.map, *.tif, *.tiff, *.am)", - ) - - out_ = [ - i - for i in filename.split("/") - if not i.endswith((".mrc", ".rec", ".map", ".tif", ".tiff", ".am")) - ] - self.out_ = "/".join(out_) - - self.output.setText(f"...{self.out_[-17:]}/Predictions/") - self.output_folder = f"...{self.out_}/Predictions/" - - self.directory.setText(filename[-30:]) - self.dir = filename - - self.img, self.px = load_image(self.dir) - - if self.correct_px.text() == "None" and self.px >= 0.0 or self.px != 1.0: - self.correct_px.setText(f"{self.px}") - - create_image_layer( - self.viewer, - image=self.img, - name=self.dir.split("/")[-1], - range_=(np.min(self.img), np.max(self.img)), - ) - - def predict_semantic(self): - """Pre-settings""" - - if self.correct_px.text() == "None": - correct_px = None - else: - correct_px = float(self.correct_px.text()) - - msg = ( - f"Predicted file is without pixel size metadate {correct_px}." - "Please correct correct_px argument with a correct pixel size value." 
- ) - if correct_px is None: - show_error(msg) - return - - self.output_formats = ( - f"{self.output_semantic.currentText()}_{self.output_instance.currentText()}" - ) - - if self.output_instance.currentText() == "None": - instances = False - else: - instances = True - - cnn_threshold = ( - "auto" - if float(self.cnn_threshold.text()) == 1.0 - else self.cnn_threshold.text() - ) - - if self.model_version.currentText() == "None": - model_version = None - else: - model_version = int(self.model_version.currentText()) - - self.predictor = GeneralPredictor( - predict="Microtubule", - dir_=self.dir, - binary_mask=bool(self.mask.checkState()), - correct_px=correct_px, - convolution_nn=self.cnn_type.currentText(), - checkpoint=( - None if self.checkpoint.text() == "None" else self.checkpoint.text(), - None, - ), - model_version=model_version, - output_format=self.output_formats, - patch_size=int(self.patch_size.currentText()), - cnn_threshold=cnn_threshold, - dist_threshold=float(self.dist_threshold.text()), - points_in_patch=int(self.points_in_patch.text()), - predict_with_rotation=bool(self.rotate.checkState()), - amira_prefix=self.amira_prefix.text(), - filter_by_length=int(self.filter_by_length.text()), - connect_splines=int(self.connect_splines.text()), - connect_cylinder=int(self.connect_cylinder.text()), - amira_compare_distance=int(self.amira_compare_distance.text()), - amira_inter_probability=float(self.amira_inter_probability.text()), - instances=instances, - device_=self.device.currentText(), - debug=False, - tardis_logo=False, - ) - - self.predictor.get_file_list() - self.predictor.create_headers() - self.predictor.load_data(id_name=self.predictor.predict_list[0]) - - if not bool(self.mask.checkState()): - trim_with_stride( - image=self.predictor.image, - scale=self.predictor.scale_shape, - trim_size_xy=self.predictor.patch_size, - trim_size_z=self.predictor.patch_size, - output=join(self.predictor.dir, "temp", "Patches"), - image_counter=0, - clean_empty=False, 
- stride=round(self.predictor.patch_size * 0.125), - ) - - create_image_layer( - self.viewer, - image=self.predictor.image, - name=self.dir.split("/")[-1], - range_=(np.min(self.predictor.image), np.max(self.predictor.image)), - visibility=False, - ) - - create_image_layer( - self.viewer, - image=np.zeros(self.predictor.scale_shape, dtype=np.float32), - name="Prediction", - transparency=True, - ) - - self.predictor.image = None - self.scale_shape = self.predictor.scale_shape - - img_dataset = PredictionDataset( - join(self.predictor.dir, "temp", "Patches", "imgs") - ) - worker = self.predict_dataset(img_dataset, self.predictor) - worker.start() - - def cnn_postprocess(self): - self.img = self.predictor.image_stitcher( - image_dir=self.predictor.output, mask=False, dtype=np.float32 - )[ - : self.predictor.scale_shape[0], - : self.predictor.scale_shape[1], - : self.predictor.scale_shape[2], - ] - self.img, _ = scale_image(image=self.img, scale=self.predictor.org_shape) - self.img = torch.sigmoid(torch.from_numpy(self.img)).cpu().detach().numpy() - - self.img_threshold = None - if float(self.cnn_threshold.text()) == 1.0: - self.img_threshold = adaptive_threshold(self.img).astype(np.uint8) - elif float(self.cnn_threshold.text()) == 0.0: - self.img_threshold = np.copy(self.img) - else: - self.img_threshold = np.where( - self.img >= float(self.cnn_threshold.text()), 1, 0 - ).astype(np.uint8) - - create_image_layer( - self.viewer, - image=self.img_threshold, - name="Prediction", - transparency=True, - range_=(0, 1), - ) - self.predictor.image = self.img_threshold - self.predictor.save_semantic_mask(self.dir.split("/")[-1]) - - def predict_instance(self): - self.output_formats = ( - f"{self.output_semantic.currentText()}_{self.output_instance.currentText()}" - ) - - if not self.output_formats.endswith("None"): - if self.predictor.dist is None: - self.predictor.output_format = self.output_formats - self.predictor.build_NN("Microtubule") - - self.segments = np.zeros((0, 4)) 
- - if not self.img_threshold.min() == 0 and not self.img_threshold.max() == 1: - show_error("You need to first select CNN threshold greater then 0.0") - return - - self.predictor.preprocess_DIST(self.dir.split("/")[-1]) - - if len(self.predictor.pc_ld) > 0: - # Build patches dataset - ( - self.predictor.coords_df, - _, - self.predictor.output_idx, - _, - ) = self.predictor.patch_pc.patched_dataset(coord=self.predictor.pc_ld) - - self.predictor.graphs = self.predictor.predict_DIST( - id_=0, id_name=self.dir.split("/")[-1] - ) - self.predictor.postprocess_DIST(id_=0, id_name=self.dir.split("/")[-1]) - - if self.predictor.segments is None: - show_info("TARDIS-em could not find any instances :(") - return - - self.predictor.save_instance_PC(self.dir.split("/")[-1]) - self.predictor.clean_up(dir_=self.dir) - - @thread_worker - def predict_dataset(self, img_dataset, predictor): - for j in range(len(img_dataset)): - input_, name = img_dataset.__getitem__(j) - - input_ = predictor.predict_cnn_napari(input_, name) - update_viewer_prediction(self.viewer, input_, self.calculate_position(name)) - - self.cnn_postprocess() - - def show_command(self): - mask = "" if not bool(self.mask.checkState()) else "-ms True" - - correct_px = ( - "" - if self.correct_px.text() == "None" - else f"-px {float(self.correct_px.text())} " - ) - if self.px is not None: - correct_px = ( - "" - if self.px == float(self.correct_px.text()) - else f"-px {float(self.correct_px.text())} " - ) - - px = "" if not bool(self.mask.checkState()) else "-ms True " - - ch = ( - "" - if self.checkpoint.text() == "None" - else f"-ch {self.checkpoint.text()}_None " - ) - - mv = ( - "" - if self.model_version.currentText() == "None" - else f"-mv {int(self.model_version.currentText())} " - ) - - cnn = ( - "" - if self.cnn_type.currentText() == "fnet_attn" - else f"-cnn {self.cnn_type.currentText()} " - ) - - rt = "" if bool(self.rotate.checkState()) else "-rt False " - - ct = ( - "-ct auto " - if 
float(self.cnn_threshold.text()) == 1.0 - else f"-ct {self.cnn_threshold.text()} " - ) - - dt = ( - f"-dt {float(self.dist_threshold.text())} " - if not self.output_formats.endswith("None") - else "" - ) - - ap = ( - "" - if self.amira_prefix.text() == ".CorrelationLines" - else f"-ap {self.amira_prefix.text()} " - ) - acd = ( - "" - if self.amira_compare_distance.text() == "175" - else f"-acd {self.amira_compare_distance.text()} " - ) - aip = ( - "" - if self.amira_inter_probability.text() == "0.25" - else f"-aip {self.amira_inter_probability.text()} " - ) - - fl = ( - "" - if self.filter_by_length.text() == "1000" - else f"-fl {int(self.filter_by_length.text())} " - ) - cs = ( - "" - if self.connect_splines.text() == "2500" - else f"-fl {int(self.connect_splines.text())} " - ) - cc = ( - "" - if self.connect_cylinder.text() == "250" - else f"-fl {int(self.connect_cylinder.text())} " - ) - - show_info( - f"tardis_mt " - f"-dir {self.out_} " - f"{mask}" - f"{px}" - f"{ch}" - f"{mv}" - f"{cnn}" - f"-out {self.output_formats} " - f"-ps {int(self.patch_size.currentText())} " - f"{rt}" - f"{ct}" - f"{dt}" - f"{ap}" - f"{acd}" - f"{aip}" - f"{fl}" - f"{cs}" - f"{cc}" - f"-pv {int(self.points_in_patch.text())} " - f"-dv {self.device.currentText()}" - ) - - def update_cnn_threshold(self): - if self.img is not None: - if float(self.cnn_threshold.text()) == 1.0: - self.img_threshold = adaptive_threshold(self.img).astype(np.uint8) - elif float(self.cnn_threshold.text()) == 0.0: - self.img_threshold = np.copy(self.img) - else: - self.img_threshold = np.where( - self.img >= float(self.cnn_threshold.text()), 1, 0 - ).astype(np.uint8) - - create_image_layer( - self.viewer, - image=self.img_threshold, - name="Prediction", - transparency=True, - range_=(0, 1), - ) - - def update_versions(self): - for i in range(self.model_version.count()): - self.model_version.removeItem(0) - - versions = get_all_version_aws( - self.cnn_type.currentText(), "32", "microtubules_3d" - ) - - if 
len(versions) == 0: - self.model_version.addItems(["None"]) - else: - self.model_version.addItems(["None"] + [i.split("_")[-1] for i in versions]) - - def calculate_position(self, name): - patch_size = int(self.patch_size.currentText()) - name = name.split("_") - name = { - "z": int(name[1]), - "y": int(name[2]), - "x": int(name[3]), - "stride": int(name[4]), - } - - x_start = (name["x"] * patch_size) - (name["x"] * name["stride"]) - x_end = x_start + patch_size - name["x"] = [x_start, x_end] - - y_start = (name["y"] * patch_size) - (name["y"] * name["stride"]) - y_end = y_start + patch_size - name["y"] = [y_start, y_end] - - z_start = (name["z"] * patch_size) - (name["z"] * name["stride"]) - z_end = z_start + patch_size - name["z"] = [z_start, z_end] - - return name diff --git a/viewers/viewer_predict.py b/viewers/viewer_predict.py deleted file mode 100644 index 7874457..0000000 --- a/viewers/viewer_predict.py +++ /dev/null @@ -1,178 +0,0 @@ -####################################################################### -# TARDIS - Transformer And Rapid Dimensionless Instance Segmentation # -# # -# New York Structural Biology Center # -# Simons Machine Learning Center # -# # -# Robert Kiewisz, Tristan Bepler # -# MIT License 2024 # -####################################################################### -import os - -from PyQt5.QtCore import Qt -from PyQt5.QtGui import QIntValidator, QDoubleValidator -from PyQt5.QtWidgets import ( - QPushButton, - QFormLayout, - QLineEdit, - QComboBox, - QLabel, - QCheckBox, -) -from napari import Viewer - -from qtpy.QtWidgets import QWidget - -from napari_tardis_em.utils.styles import border_style -from napari_tardis_em.utils.utils import get_list_of_device - - -class TardisWidget(QWidget): - """ - Easy to use plugin for CNN prediction. 
- - Plugin integrates TARDIS-em and allows to easily set up prediction on a pre-trained - model - """ - - def __init__(self, viewer_predict: Viewer): - super().__init__() - - self.viewer = viewer_predict - - """""" """""" """ - UI Elements - """ """""" """""" - directory = QPushButton(f"...{os.getcwd()[-30:]}") - directory.setToolTip( - "Select directory with image or single file you would like to predict. \n " - "\n" - "Supported formats:\n" - "images: *.mrc, *.rec, *.map, *.am" - ) - output = QPushButton(f"...{os.getcwd()[-17:]}/Predictions/") - output.setToolTip( - "Select directory in which plugin will save train model, checkpoints and training logs." - ) - - ############################## - # Setting user should change # - ############################## - label_2 = QLabel("Setting user should change") - label_2.setStyleSheet(border_style("green")) - - output_semantic = QComboBox() - output_semantic.addItems(["mrc", "tif", "am"]) - output_semantic.setToolTip("Select semantic output format file.") - - output_instance = QComboBox() - output_instance.addItems(["None", "csv", "npy", "amSG"]) - output_instance.setToolTip("Select instance output format file.") - - ########################### - # Setting user may change # - ########################### - label_3 = QLabel("Setting user may change") - label_3.setStyleSheet(border_style("yellow")) - - correct_px = QLineEdit("None") - correct_px.setToolTip( - "Set correct pixel size value, if image header \n" - "do not contain or stores incorrect information." 
- ) - - cnn_type = QComboBox() - cnn_type.addItems(["unet", "resnet", "unet3plus", "fnet", "fnet_attn"]) - cnn_type.setCurrentIndex(0) - cnn_type.setToolTip("Select type of CNN you would like to train.") - - checkpoint = QLineEdit("None") - checkpoint.setToolTip("Optional, directory to CNN checkpoint.") - - patch_size = QComboBox() - patch_size.addItems( - ["32", "64", "96", "128", "160", "192", "256", "512", "1024"] - ) - patch_size.setCurrentIndex(1) - patch_size.setToolTip( - "Select patch size value that will be used to split \n" - "all images into smaller patches." - ) - - rotate = QCheckBox() - rotate.setCheckState(Qt.CheckState.Checked) - rotate.setToolTip( - "Select if you want to switch on/of rotation during the prediction. \n" - "If selected, during CNN prediction image is rotate 4x by 90 degrees.\n" - "This will increase prediction time 4x. \n" - "However may lead to more cleaner output." - ) - - cnn_threshold = QLineEdit("0.25") - cnn_threshold.setValidator(QDoubleValidator(0.0, 1.0, 3)) - cnn_threshold.setToolTip( - "Threshold value for binary prediction. Lower value will increase \n" - "recall [retrieve more of predicted object] but also may increase \n" - "false/positives. Higher value will result in cleaner output but may \n" - "reduce recall." - ) - - dist_threshold = QLineEdit("0.5") - dist_threshold.setValidator(QDoubleValidator(0.0, 1.0, 3)) - dist_threshold.setToolTip( - "Threshold value for instance prediction. Lower value will increase \n" - "recall [retrieve more of predicted object] but also may increase \n" - "false/positives. Higher value will result in cleaner output but may \n" - "reduce recall." - ) - - device = QComboBox() - device.addItems(get_list_of_device()) - device.setCurrentIndex(0) - device.setToolTip( - "Select available device on which you want to train your model." 
- ) - - ######################################## - # Setting user is not advice to change # - ######################################## - label_4 = QLabel("Setting user is not advice to change") - label_4.setStyleSheet(border_style("red")) - - points_in_patch = QLineEdit("900") - points_in_patch.setValidator(QDoubleValidator(100, 10000, 1)) - points_in_patch.setToolTip( - "Number of point in patch. Higher number will increase how may points \n" - "DIST model will process at the time. This is usually only the memory GPU constrain." - ) - - predict_button = QPushButton("Predict ...") - predict_button.setMinimumWidth(225) - - """""" """""" """ - UI Setup - """ """""" """""" - layout = QFormLayout() - layout.addRow("Select Directory", directory) - layout.addRow("Output Directory", output) - - layout.addRow("---- CNN Options ----", label_2) - layout.addRow("Semantic output", output_semantic) - layout.addRow("Instance output", output_instance) - - layout.addRow("----- Extra --------", label_3) - layout.addRow("Correct pixel size", correct_px) - layout.addRow("CNN type", cnn_type) - layout.addRow("Checkpoint", checkpoint) - layout.addRow("Patch size", patch_size) - layout.addRow("Rotation", rotate) - layout.addRow("CNN threshold", cnn_threshold) - layout.addRow("DIST threshold", dist_threshold) - layout.addRow("Device", device) - - layout.addRow("---- Advance -------", label_4) - layout.addRow("No. 
of points [DIST]", points_in_patch) - - layout.addRow("", predict_button) - - self.setLayout(layout) diff --git a/viewers/viewer_train.py b/viewers/viewer_train.py deleted file mode 100644 index dfef9fe..0000000 --- a/viewers/viewer_train.py +++ /dev/null @@ -1,268 +0,0 @@ -####################################################################### -# TARDIS - Transformer And Rapid Dimensionless Instance Segmentation # -# # -# New York Structural Biology Center # -# Simons Machine Learning Center # -# # -# Robert Kiewisz, Tristan Bepler # -# MIT License 2024 # -####################################################################### -import os - -from PyQt5.QtGui import QIntValidator, QDoubleValidator -from PyQt5.QtWidgets import QPushButton, QFormLayout, QLineEdit, QComboBox, QLabel -from napari import Viewer - -from qtpy.QtWidgets import QWidget - -from napari_tardis_em.utils.styles import border_style -from napari_tardis_em.utils.utils import get_list_of_device - - -class TardisWidget(QWidget): - """ - Easy to use plugin for CNN training. - - Plugin integrate TARDIS-em and allow to easily set up training. To make it more - user-friendly, this plugin guid user what to do, and during training display - results from validation loop. - """ - - def __init__(self, viewer_train: Viewer): - super().__init__() - - self.viewer = viewer_train - - """""" """""" """ - UI Elements - """ """""" """""" - directory = QPushButton(f"...{os.getcwd()[-30:]}") - directory.setToolTip( - "Select directory with image and mask files. \n " - "Image and file files should have the following naming: \n" - " - image: name.*\n" - " - mask: name_maks.*\n" - "\n" - "Supported formats:\n" - "images: *.mrc, *.rec, *.map, *.am\n" - "masks: *.CorrelationLines.am, *_mask.am, *_mask.mrc, *_mask.rec, *_mask.csv, *_mask.tif" - ) - output = QPushButton(f"...{os.getcwd()[-17:]}/tardis-em_training/") - output.setToolTip( - "Select directory in which plugin will save train model, checkpoints and training logs." 
- ) - - ############################## - # Setting user should change # - ############################## - label_2 = QLabel("Setting user should change") - label_2.setStyleSheet(border_style("green")) - - patch_size = QComboBox() - patch_size.addItems( - ["32", "64", "96", "128", "160", "192", "256", "512", "1024"] - ) - patch_size.setCurrentIndex(1) - patch_size.setToolTip( - "Select patch size value that will be used to split \n" - "all images into smaller patches." - ) - - cnn_type = QComboBox() - cnn_type.addItems(["unet", "resnet", "unet3plus", "fnet", "fnet_attn"]) - cnn_type.setCurrentIndex(0) - cnn_type.setToolTip("Select type of CNN you would like to train.") - - image_type = QComboBox() - image_type.addItems(["2D", "3D"]) - image_type.setCurrentIndex(1) - image_type.setToolTip( - "Select type of images you would like to train CNN model on." - ) - - cnn_in_channel = QLineEdit("1") - cnn_in_channel.setValidator(QIntValidator(1, 100)) - cnn_in_channel.setToolTip( - "Select how many input channels the CNN network should expect." - ) - - device = QComboBox() - device.addItems(get_list_of_device()) - device.setCurrentIndex(0) - device.setToolTip( - "Select available device on which you want to train your model." - ) - - checkpoint = QLineEdit("None") - checkpoint.setToolTip( - "Optional, directory to CNN checkpoint to restart training." - ) - - ########################### - # Setting user may change # - ########################### - label_3 = QLabel("Setting user may change") - label_3.setStyleSheet(border_style("yellow")) - - pixel_size = QLineEdit("None") - pixel_size.setValidator(QDoubleValidator(0.1, 50.0, 2)) - pixel_size.setToolTip( - "Optionally, select pixel size value that will be \n" - "used to normalize all images fixed resolution." - ) - - mask_size = QLineEdit("150") - mask_size.setValidator(QIntValidator(5, 250)) - mask_size.setToolTip( - "Select mask size in Angstrom. 
The mask size is used \n" - "to draw mask/labels based on coordinates if name_maks.* \n" - "files is a *.csv file with coordinates." - ) - - batch_size = QLineEdit("24") - batch_size.setValidator(QIntValidator(5, 50)) - batch_size.setToolTip( - "Select number of batches. The batch refers to a set of multiple data \n" - "samples processed together. This setting will heavy imply how much GPU memory \n" - "CNN training will require. Reduce this number if needed." - ) - - cnn_layers = QLineEdit("5") - cnn_layers.setValidator(QIntValidator(2, 6)) - cnn_layers.setToolTip("Select number of convolution layer for CNN.") - - cnn_scaler = QComboBox() - cnn_scaler.addItems(["16", "32", "64"]) - cnn_scaler.setCurrentIndex(1) - cnn_scaler.setToolTip( - "Convolution multiplayer for CNN layers. This mean what is the CNN layer \n" - "multiplayer at each layer.\n" - "For example:\n" - "If we have 5 layers and 32 multiplayer at the last layer model will have 512 channels." - ) - - loss_function = QComboBox() - loss_function.addItems( - [ - "AdaptiveDiceLoss", - "BCELoss", - "WBCELoss", - "BCEDiceLoss", - "CELoss", - "DiceLoss", - "ClDiceLoss", - "ClBCELoss", - "SigmoidFocalLoss", - "LaplacianEigenmapsLoss", - "BCEMSELoss", - ] - ) - loss_function.setCurrentIndex(1) - loss_function.setToolTip("Select one of the pre-build loss functions.") - - learning_rate = QLineEdit("0.0005") - learning_rate.setValidator(QDoubleValidator(0.0000001, 0.1, 7)) - learning_rate.setToolTip( - "Select learning rate.\n" - "The learning rate is a hyperparameter that controls how much to adjust \n" - "the model’s weights with respect to the loss gradient during training" - ) - - epoch = QLineEdit("1000") - epoch.setValidator(QIntValidator(10, 100000)) - epoch.setToolTip( - "Select maximum number of epoches for which CNN model should train." 
- ) - - early_stop = QLineEdit("100") - early_stop.setValidator(QIntValidator(5, 10000)) - early_stop.setToolTip( - "Early stopping in CNN training is a regularization technique that halts \n" - "training when the model’s performance on a validation set stops improving, \n" - "preventing overfitting. This ensures the model retains optimal generalization \n" - "capabilities by terminating training at the point of best validation performance. \n" - "It's recommended to use 10% value of epoch size." - ) - - dropout_rate = QLineEdit("0.5") - dropout_rate.setValidator(QDoubleValidator(0.00, 1.00, 3)) - dropout_rate.setToolTip( - "In machine learning, dropout is a regularization technique that randomly \n" - "omits a fraction of neurons during training to prevent overfitting, \n" - "while in education, dropout refers to a student who leaves school \n" - "before completing their program." - ) - - ######################################## - # Setting user is not advice to change # - ######################################## - label_4 = QLabel("Setting user is not advice to change") - label_4.setStyleSheet(border_style("red")) - - cnn_out_channel = QLineEdit("1") - cnn_out_channel.setValidator(QIntValidator(1, 100)) - cnn_out_channel.setToolTip( - "Select how many output channels the CNN network should return." - ) - - cnn_structure = QLineEdit("gcl") - cnn_structure.setToolTip( - "Define structure order of the convolution block." 
- "c - convolution" - "g - group normalization" - "b - batch normalization" - "r - ReLU" - "l - LeakyReLU" - "e - GeLu" - "p - PReLu", - ) - - cnn_kernel = QLineEdit("3") - cnn_kernel.setToolTip("Select convolution kernel size.") - - cnn_padding = QLineEdit("1") - cnn_padding.setToolTip("Select convolution padding size.") - - cnn_max_pool = QLineEdit("2") - cnn_max_pool.setToolTip("Select convolution max pool size.") - - train_button = QPushButton("Train ...") - train_button.setMinimumWidth(225) - - """""" """""" """ - UI Setup - """ """""" """""" - layout = QFormLayout() - layout.addRow("Select Directory", directory) - layout.addRow("Output Directory", output) - - layout.addRow("---- CNN Options ----", label_2) - layout.addRow("Patch Size", patch_size) - layout.addRow("CNN type", cnn_type) - layout.addRow("Image type", image_type) - layout.addRow("No. of input channel", cnn_in_channel) - layout.addRow("Checkpoint", checkpoint) - layout.addRow("Device", device) - - layout.addRow("----- Extra --------", label_3) - layout.addRow("Pixel Size", pixel_size) - layout.addRow("Mask Size", mask_size) - layout.addRow("Batch Size", batch_size) - layout.addRow("No. of CNN layers", cnn_layers) - layout.addRow("Channel scaler size", cnn_scaler) - layout.addRow("Loss function", loss_function) - layout.addRow("Learning rate", learning_rate) - layout.addRow("No. of Epoches", epoch) - layout.addRow("Early stop", early_stop) - layout.addRow("Dropout rate", dropout_rate) - - layout.addRow("---- Advance -------", label_4) - layout.addRow("No. of output channel", cnn_out_channel) - layout.addRow("Define CNN structure", cnn_structure) - layout.addRow("CNN kernel size", cnn_kernel) - layout.addRow("CNN padding size", cnn_padding) - layout.addRow("CNN max_pool size", cnn_max_pool) - - layout.addRow("", train_button) - - self.setLayout(layout)