diff --git a/finetuning/livecell_finetuning.py b/finetuning/livecell_finetuning.py index 1cd942fd..6b63a5a0 100644 --- a/finetuning/livecell_finetuning.py +++ b/finetuning/livecell_finetuning.py @@ -40,7 +40,7 @@ def get_dataloaders(patch_shape, data_path, cell_type=None): def finetune_livecell(args): - """Example code for finetuning SAM on LiveCELL""" + """Example code for finetuning SAM on LIVECell""" # override this (below) if you have some more complex set-up and need to specify the exact gpu device = "cuda" if torch.cuda.is_available() else "cpu" @@ -84,10 +84,10 @@ def finetune_livecell(args): def main(): - parser = argparse.ArgumentParser(description="Finetune Segment Anything for the LiveCELL dataset.") + parser = argparse.ArgumentParser(description="Finetune Segment Anything for the LIVECell dataset.") parser.add_argument( "--input_path", "-i", default="/scratch/projects/nim00007/sam/data/livecell/", - help="The filepath to the LiveCELL data. If the data does not exist yet it will be downloaded." + help="The filepath to the LIVECell data. If the data does not exist yet it will be downloaded." ) parser.add_argument( "--model_type", "-m", default="vit_b", diff --git a/micro_sam/_vendored.py b/micro_sam/_vendored.py index 14446e01..976f8b4a 100644 --- a/micro_sam/_vendored.py +++ b/micro_sam/_vendored.py @@ -6,9 +6,11 @@ The license type of the thrid party software project must be compatible with the software license the micro-sam project is distributed under. """ + from typing import Any, Dict, List import numpy as np + import torch try: diff --git a/micro_sam/bioimageio/model_export.py b/micro_sam/bioimageio/model_export.py index 3b270c58..dfee9612 100644 --- a/micro_sam/bioimageio/model_export.py +++ b/micro_sam/bioimageio/model_export.py @@ -1,26 +1,26 @@ import os import tempfile - from pathlib import Path from typing import Optional, Union -import bioimageio.core -import bioimageio.spec.model.v0_5 as spec -import matplotlib.pyplot as plt +import xarray import numpy as np +import matplotlib.pyplot as plt + import torch -import xarray +import bioimageio.core +import bioimageio.spec.model.v0_5 as spec from bioimageio.spec import save_bioimageio_package from bioimageio.core.digest_spec import create_sample_for_model - from .. import util from ..prompt_generators import PointAndBoxPromptGenerator from ..evaluation.model_comparison import _enhance_image, _overlay_outline, _overlay_box from ..prompt_based_segmentation import _compute_logits_from_mask from .predictor_adaptor import PredictorAdaptor + DEFAULTS = { "authors": [ spec.Author(name="Anwai Archit", affiliation="University Goettingen", github_user="anwai98"), diff --git a/micro_sam/evaluation/evaluation.py b/micro_sam/evaluation/evaluation.py index 4fc76146..a52a1126 100644 --- a/micro_sam/evaluation/evaluation.py +++ b/micro_sam/evaluation/evaluation.py @@ -11,8 +11,8 @@ import numpy as np import pandas as pd import imageio.v3 as imageio - from skimage.measure import label + from elf.evaluation import mean_segmentation_accuracy @@ -88,6 +88,7 @@ def run_evaluation_for_iterative_prompting( prediction_root: The folder with the iterative prompt-based instance segmentations to evaluate. experiment_folder: The folder where all the experiment results are stored. start_with_box_prompt: Whether to evaluate on experiments with iterative prompting starting with box. + overwrite_results: Whether to overwrite the results to update them with the new evaluation run. Returns: A DataFrame that contains the evaluation results. 
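For context, a minimal usage sketch of the evaluation entry point whose docstring gains the overwrite_results argument above. The gt_paths argument and all paths are placeholders/assumptions for illustration; check the definition of run_evaluation_for_iterative_prompting in micro_sam/evaluation/evaluation.py for the exact signature before relying on this.

# Hedged usage sketch; gt_paths and the paths below are assumptions, not part of the diff.
from glob import glob

from micro_sam.evaluation.evaluation import run_evaluation_for_iterative_prompting

gt_paths = sorted(glob("/path/to/ground_truth/*.tif"))  # assumed ground-truth label images
results = run_evaluation_for_iterative_prompting(
    gt_paths=gt_paths,
    prediction_root="/path/to/predictions",   # folder with the iterative prompt-based segmentations
    experiment_folder="/path/to/experiment",  # folder where the experiment results are stored
    start_with_box_prompt=False,              # start the iterative prompting from point prompts
    overwrite_results=True,                   # recompute and overwrite results from a previous run
)
print(results)  # a pandas DataFrame with the evaluation results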
diff --git a/micro_sam/evaluation/experiments.py b/micro_sam/evaluation/experiments.py index 4646af52..5b5b9c76 100644 --- a/micro_sam/evaluation/experiments.py +++ b/micro_sam/evaluation/experiments.py @@ -3,6 +3,7 @@ from typing import Dict, List, Optional + # TODO fully define the dict type ExperimentSetting = Dict ExperimentSettings = List[ExperimentSetting] diff --git a/micro_sam/evaluation/instance_segmentation.py b/micro_sam/evaluation/instance_segmentation.py index 9d6331e6..5e657190 100644 --- a/micro_sam/evaluation/instance_segmentation.py +++ b/micro_sam/evaluation/instance_segmentation.py @@ -3,20 +3,20 @@ import os from glob import glob -from itertools import product +from tqdm import tqdm from pathlib import Path +from itertools import product from typing import Any, Dict, List, Optional, Tuple, Union -import imageio.v3 as imageio import numpy as np import pandas as pd +import imageio.v3 as imageio -from elf.evaluation import mean_segmentation_accuracy from elf.io import open_file -from tqdm import tqdm +from elf.evaluation import mean_segmentation_accuracy -from ..instance_segmentation import AMGBase, InstanceSegmentationWithDecoder, mask_data_to_segmentation from .. import util +from ..instance_segmentation import AMGBase, InstanceSegmentationWithDecoder, mask_data_to_segmentation def _get_range_of_search_values(input_vals, step): diff --git a/micro_sam/evaluation/livecell.py b/micro_sam/evaluation/livecell.py index f0699ab8..c9d75f51 100644 --- a/micro_sam/evaluation/livecell.py +++ b/micro_sam/evaluation/livecell.py @@ -1,6 +1,7 @@ """Inference and evaluation for the [LIVECell dataset](https://www.nature.com/articles/s41592-021-01249-6) and the different cell lines contained in it. """ + import os import json import argparse @@ -422,7 +423,7 @@ def run_livecell_inference() -> None: def run_livecell_evaluation() -> None: - """Run LiveCELL evaluation with command line tool.""" + """Run LIVECell evaluation with command line tool.""" parser = argparse.ArgumentParser() parser.add_argument( "-i", "--input", required=True, help="Provide the data directory for LIVECell Dataset" diff --git a/micro_sam/evaluation/model_comparison.py b/micro_sam/evaluation/model_comparison.py index 4b130aec..6e37528d 100644 --- a/micro_sam/evaluation/model_comparison.py +++ b/micro_sam/evaluation/model_comparison.py @@ -2,23 +2,22 @@ """ import os -from functools import partial from glob import glob +from tqdm import tqdm from pathlib import Path +from functools import partial +from typing import Optional, Union import h5py -import matplotlib.pyplot as plt import numpy as np import pandas as pd -import torch - +import matplotlib.pyplot as plt import skimage.draw as draw -from scipy.ndimage import binary_dilation from skimage import exposure +from scipy.ndimage import binary_dilation from skimage.segmentation import relabel_sequential, find_boundaries -from tqdm import tqdm -from typing import Optional, Union +import torch from .. import util from ..prompt_generators import PointAndBoxPromptGenerator diff --git a/micro_sam/inference.py b/micro_sam/inference.py index 6d67b38e..8725dea9 100644 --- a/micro_sam/inference.py +++ b/micro_sam/inference.py @@ -1,11 +1,12 @@ import os from typing import Optional, Union -import torch import numpy as np -import segment_anything.utils.amg as amg_utils +import torch + from segment_anything import SamPredictor +import segment_anything.utils.amg as amg_utils from segment_anything.utils.transforms import ResizeLongestSide from . 
import util diff --git a/micro_sam/instance_segmentation.py b/micro_sam/instance_segmentation.py index 23d666b9..d86c534e 100644 --- a/micro_sam/instance_segmentation.py +++ b/micro_sam/instance_segmentation.py @@ -6,24 +6,25 @@ import os from abc import ABC -from collections import OrderedDict from copy import deepcopy +from collections import OrderedDict from typing import Any, Dict, List, Optional, Tuple, Union -import numpy as np -import torch -import segment_anything.utils.amg as amg_utils import vigra - -from nifty.tools import blocking -from segment_anything.predictor import SamPredictor - +import numpy as np from skimage.measure import regionprops + +import torch from torchvision.ops.boxes import batched_nms, box_area from torch_em.model import UNETR from torch_em.util.segmentation import watershed_from_center_and_boundary_distances +from nifty.tools import blocking + +import segment_anything.utils.amg as amg_utils +from segment_anything.predictor import SamPredictor + from . import util from ._vendored import batched_mask_to_box, mask_to_rle_pytorch @@ -56,6 +57,7 @@ def mask_data_to_segmentation( object in the output will be mapped to zero (the background value). min_object_size: The minimal size of an object in pixels. max_object_size: The maximal size of an object in pixels. + Returns: The instance segmentation. """ diff --git a/micro_sam/models/build_sam.py b/micro_sam/models/build_sam.py index 8fa6bcc6..901c6c38 100644 --- a/micro_sam/models/build_sam.py +++ b/micro_sam/models/build_sam.py @@ -3,6 +3,7 @@ # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. +# https://github.com/facebookresearch/segment-anything/ # # NOTE: This code has been adapted from Segment Anything. diff --git a/micro_sam/models/peft_sam.py b/micro_sam/models/peft_sam.py index d2eaa987..2bdeed70 100644 --- a/micro_sam/models/peft_sam.py +++ b/micro_sam/models/peft_sam.py @@ -18,6 +18,10 @@ class LoRASurgery(nn.Module): qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) q, k, v = qkv.unbind(0) ``` + + Args: + rank: The rank of the decomposition matrices for updating weights in each attention layer. + block: The chosen attention blocks for implementing lora. """ def __init__( self, diff --git a/micro_sam/models/sam_3d_wrapper.py b/micro_sam/models/sam_3d_wrapper.py index 4a7645d0..3e0b7573 100644 --- a/micro_sam/models/sam_3d_wrapper.py +++ b/micro_sam/models/sam_3d_wrapper.py @@ -3,8 +3,8 @@ import torch import torch.nn as nn -from segment_anything.modeling.image_encoder import window_partition, window_unpartition from segment_anything.modeling import Sam +from segment_anything.modeling.image_encoder import window_partition, window_unpartition from ..util import get_sam_model @@ -42,6 +42,7 @@ def __init__(self, sam_model: Sam, freeze_encoder: bool): Args: sam_model: The Sam model to be wrapped. + freeze_encoder: Whether to freeze the image encoder. """ super().__init__() sam_model.image_encoder = ImageEncoderViT3DWrapper( @@ -64,7 +65,7 @@ def forward( Unlike original SAM this model only supports automatic segmentation and does not support prompts. Args: - batched_input: A list over input images, each a dictionary with the following keys.L + batched_input: A list over input images, each a dictionary with the following keys. 'image': The image as a torch tensor in 3xDxHxW format. Already transformed for the input to the model. 
'original_size': The original size of the image (HxW) before transformation. multimask_output: Wheterh to predict with the multi- or single-mask head of the maks decoder. diff --git a/micro_sam/models/simple_sam_3d_wrapper.py b/micro_sam/models/simple_sam_3d_wrapper.py index cf4ddbcc..6f67caa4 100644 --- a/micro_sam/models/simple_sam_3d_wrapper.py +++ b/micro_sam/models/simple_sam_3d_wrapper.py @@ -16,7 +16,6 @@ def get_simple_sam_3d_model( model_type="vit_b", checkpoint_path=None, ): - _, sam = get_sam_model( model_type=model_type, device=device, diff --git a/micro_sam/multi_dimensional_segmentation.py b/micro_sam/multi_dimensional_segmentation.py index 2c65d4f1..c8747ed8 100644 --- a/micro_sam/multi_dimensional_segmentation.py +++ b/micro_sam/multi_dimensional_segmentation.py @@ -5,23 +5,26 @@ from typing import Optional, Union, Tuple import numpy as np + import nifty -import elf.tracking.tracking_utils as track_utils + import elf.segmentation as seg_utils +import elf.tracking.tracking_utils as track_utils -from segment_anything.predictor import SamPredictor from scipy.ndimage import binary_closing from skimage.measure import label, regionprops from skimage.segmentation import relabel_sequential +from segment_anything.predictor import SamPredictor + try: from napari.utils import progress as tqdm except ImportError: from tqdm import tqdm from . import util -from .instance_segmentation import AMGBase, mask_data_to_segmentation from .prompt_based_segmentation import segment_from_mask +from .instance_segmentation import AMGBase, mask_data_to_segmentation PROJECTION_MODES = ("box", "mask", "points", "points_and_mask", "single_point") diff --git a/micro_sam/precompute_state.py b/micro_sam/precompute_state.py index d07ea1bc..e4a970b7 100644 --- a/micro_sam/precompute_state.py +++ b/micro_sam/precompute_state.py @@ -3,16 +3,17 @@ import os import pickle - -from functools import partial from glob import glob from pathlib import Path +from functools import partial from typing import Optional, Tuple, Union, List import h5py import numpy as np + import torch import torch.nn as nn + from segment_anything.predictor import SamPredictor try: diff --git a/micro_sam/prompt_based_segmentation.py b/micro_sam/prompt_based_segmentation.py index 9de5954d..e2bb1026 100644 --- a/micro_sam/prompt_based_segmentation.py +++ b/micro_sam/prompt_based_segmentation.py @@ -6,15 +6,18 @@ from typing import Optional, Tuple import numpy as np -import torch -from nifty.tools import blocking -from skimage.feature import peak_local_max from skimage.filters import gaussian +from skimage.feature import peak_local_max from skimage.segmentation import find_boundaries from scipy.ndimage import distance_transform_edt +import torch + +from nifty.tools import blocking + from segment_anything.predictor import SamPredictor from segment_anything.utils.transforms import ResizeLongestSide + from . import util diff --git a/micro_sam/sam_annotator/_annotator.py b/micro_sam/sam_annotator/_annotator.py index fa8da43a..974aeccb 100644 --- a/micro_sam/sam_annotator/_annotator.py +++ b/micro_sam/sam_annotator/_annotator.py @@ -1,11 +1,11 @@ import napari import numpy as np -from magicgui.widgets import Widget, Container, FunctionGui from qtpy import QtWidgets +from magicgui.widgets import Widget, Container, FunctionGui -from . import _widgets as widgets from . import util as vutil +from . 
import _widgets as widgets from ._state import AnnotatorState diff --git a/micro_sam/sam_annotator/_state.py b/micro_sam/sam_annotator/_state.py index 639642ea..b57c1f80 100644 --- a/micro_sam/sam_annotator/_state.py +++ b/micro_sam/sam_annotator/_state.py @@ -3,18 +3,19 @@ https://itnext.io/deciding-the-best-singleton-approach-in-python-65c61e90cdc4 """ -from dataclasses import dataclass, field from functools import partial +from dataclasses import dataclass, field from typing import Dict, List, Optional, Tuple +import zarr import numpy as np +from qtpy.QtWidgets import QWidget + import torch.nn as nn -import zarr import micro_sam.util as util from micro_sam.instance_segmentation import AMGBase, get_decoder from micro_sam.precompute_state import cache_amg_state, cache_is_state -from qtpy.QtWidgets import QWidget from segment_anything import SamPredictor diff --git a/micro_sam/sam_annotator/_tooltips.py b/micro_sam/sam_annotator/_tooltips.py index 068dd44d..8ddda0ce 100644 --- a/micro_sam/sam_annotator/_tooltips.py +++ b/micro_sam/sam_annotator/_tooltips.py @@ -6,25 +6,25 @@ "custom_weights": "Select custom model weights. For example for a model you have finetuned", "device": "Select the computational device to use for processing.", "embeddings_save_path": "Select path to save or load the computed image embeddings.", - "halo": "Enter overlap values for computing tiled embeddings. Enter only x-value for quadratic size.\n Only active when tiling is used.", + "halo": "Enter overlap values for computing tiled embeddings. Enter only x-value for quadratic size.\n Only active when tiling is used.", # noqa "image": "Select the napari image layer.", "model": "Select the segment anything model.", - "prefer_decoder": "Choose if the segmentation decoder is used for automatic segmentation. Only if it is available for the selected model..", + "prefer_decoder": "Choose if the segmentation decoder is used for automatic segmentation. Only if it is available for the selected model.", # noqa "run_button": "Compute embeddings or load embeddings if embedding_save_path is specified.", - "tiling": "Enter tile size for computing tiled embeddings. Enter only x-value for quadratic size or both for non-quadratic.", + "tiling": "Enter tile size for computing tiled embeddings. Enter only x-value for quadratic size or both for non-quadratic.", # noqa }, "segmentnd": { - "box_extension": "Enter factor by which box size is increased when projecting to adjacent slices. Larger factors help if object sizes change between slices.", + "box_extension": "Enter factor by which box size is increased when projecting to adjacent slices. Larger factors help if object sizes change between slices.", # noqa "iou_threshold": "Enter the minimal overlap between objects in adjacent slices to continue segmentation.", - "motion_smoothing": "Enter the motion smoothing factor. It is used to follow objects which have a directed movement, higher values help for objects that are moving fast.", - "projection_dropdown": "Choose the projection mode. It determines which prompts are derived from the masks projected to adjacent frames to rerun SAM.", + "motion_smoothing": "Enter the motion smoothing factor. It is used to follow objects which have a directed movement, higher values help for objects that are moving fast.", # noqa + "projection_dropdown": "Choose the projection mode. It determines which prompts are derived from the masks projected to adjacent frames to rerun SAM.", # noqa }, "autosegment": { # General settings.
"apply_to_volume": "Choose if automatic segmentation is run for the full volume or only the current slice.", - "gap_closing": "Enter value for closing gaps across slices for volumetric segmentation. Higher values will reduce artifacts due to missing slices in objects but may lead to wrongly merging objects.", - "min_extent": "Enter the minimal number of slices for objects in volumetric segmentation. To filter out small segmentation artifacts.", - "min_object_size": "Enter the minimal object size in pixels. This refers to the size per slice for volumetric segmentation.", + "gap_closing": "Enter value for closing gaps across slices for volumetric segmentation. Higher values will reduce artifacts due to missing slices in objects but may lead to wrongly merging objects.", # noqa + "min_extent": "Enter the minimal number of slices for objects in volumetric segmentation. To filter out small segmentation artifacts.", # noqa + "min_object_size": "Enter the minimal object size in pixels. This refers to the size per slice for volumetric segmentation.", # noqa "run_button": "Run automatic segmentation.", "with_background": "Choose if your image has a large background area.", # Settings for AIS. @@ -36,28 +36,28 @@ "stability_score_thresh": "Enter the threshold for filtering objects based on the stability score.", }, "prompt_menu": { - "labels": "Choose positive prompts to inlcude regions or negative ones to exclude regions. Toggle between the settings by pressing [t].", + "labels": "Choose positive prompts to include regions or negative ones to exclude regions. Toggle between the settings by pressing [t].", # noqa }, "annotator_tracking": { "track_id": "Select the id of the track you are currently annotating.", - "track_state": "Select the state of the current annotation. Choose 'division' if the object is dviding in the current frame.", + "track_state": "Select the state of the current annotation. Choose 'division' if the object is dividing in the current frame.", # noqa }, "image_series_annotator": { "folder": "Select the folder with the images to annotate.", "output_folder": "Select the folder for saving the segmentation results.", - "pattern": "Select a pattern for selecting files. E.g. '*.tif' to only select tif files. By default all files in the input folder are selected.", + "pattern": "Select a pattern for selecting files. E.g. '*.tif' to only select tif files. By default all files in the input folder are selected.", # noqa "is_volumetric": "Choose if the data you annotate is volumetric.", }, "training": { "checkpoint": "Select a checkpoint (saved model) to resume training from.", "device": "Select the computational device to use for processing.", "initial_model": "Select the model name used as starting point for training.", - "label_key": "Define the key that holds to the segmentation labels. Use a pattern, e.g. \"*.tif\" select multiple files or an internal path for hdf5, zarr or similar formats.", - "label_path": "Specify the path to the segmentaiton labels for training. Can either point to a directory or single file.", - "label_path_val": "Specify the path to the segmentation labels for validation. Can either point to a directory or single file.", + "label_key": "Define the key that holds the segmentation labels. Use a pattern, e.g. \"*.tif\" to select multiple files or an internal path for hdf5, zarr or similar formats.", # noqa + "label_path": "Specify the path to the segmentation labels for training. 
Can either point to a directory or single file.", # noqa + "label_path_val": "Specify the path to the segmentation labels for validation. Can either point to a directory or single file.", # noqa "name": "Enter the name of the model that will be trained.", "patch": "Select the size of image patches used for training.", - "raw_key": "Define the key that holds to the image data. Use a pattern, e.g. \"*.tif\" select multiple files or an internal path for hdf5, zarr or similar formats.", + "raw_key": "Define the key that holds the image data. Use a pattern, e.g. \"*.tif\" to select multiple files or an internal path for hdf5, zarr or similar formats.", # noqa "raw_path": "Specify the path to the image data for training. Can either point to a directory or single file.", "raw_path_val": "Specify the path to the image data for training. Can either point to a directory or single file.", "segmentation_decoder": "Choose whether to train with additional segmentation decoder or not.", diff --git a/micro_sam/sam_annotator/_widgets.py b/micro_sam/sam_annotator/_widgets.py index 4822dc37..8307998f 100644 --- a/micro_sam/sam_annotator/_widgets.py +++ b/micro_sam/sam_annotator/_widgets.py @@ -1,19 +1,20 @@ """Implements the widgets used in the annotation plugins. """ -import json -import multiprocessing as mp import os import pickle from pathlib import Path from typing import Optional +import multiprocessing as mp -import elf.parallel import h5py -import napari -import numpy as np +import json import zarr import z5py +import napari +import numpy as np + +import elf.parallel from qtpy import QtWidgets from qtpy.QtCore import QObject, Signal diff --git a/micro_sam/sam_annotator/annotator_2d.py b/micro_sam/sam_annotator/annotator_2d.py index 6fc01742..fec4d5d8 100644 --- a/micro_sam/sam_annotator/annotator_2d.py +++ b/micro_sam/sam_annotator/annotator_2d.py @@ -2,6 +2,7 @@ import napari import numpy as np + import torch from . 
import _widgets as widgets diff --git a/micro_sam/sam_annotator/annotator_3d.py b/micro_sam/sam_annotator/annotator_3d.py index dfcf12a7..026e222d 100644 --- a/micro_sam/sam_annotator/annotator_3d.py +++ b/micro_sam/sam_annotator/annotator_3d.py @@ -2,6 +2,7 @@ import napari import numpy as np + import torch from ._annotator import _AnnotatorBase diff --git a/micro_sam/sam_annotator/annotator_tracking.py b/micro_sam/sam_annotator/annotator_tracking.py index 183678d5..d82b0923 100644 --- a/micro_sam/sam_annotator/annotator_tracking.py +++ b/micro_sam/sam_annotator/annotator_tracking.py @@ -2,6 +2,7 @@ import napari import numpy as np + import torch from magicgui.widgets import ComboBox, Container diff --git a/micro_sam/sam_annotator/image_series_annotator.py b/micro_sam/sam_annotator/image_series_annotator.py index 561abc55..f4c7ce71 100644 --- a/micro_sam/sam_annotator/image_series_annotator.py +++ b/micro_sam/sam_annotator/image_series_annotator.py @@ -1,14 +1,14 @@ import os - from glob import glob from pathlib import Path from typing import List, Optional, Union, Tuple import numpy as np import imageio.v3 as imageio -import napari + import torch +import napari from magicgui import magicgui from qtpy import QtWidgets diff --git a/micro_sam/sam_annotator/training_ui.py b/micro_sam/sam_annotator/training_ui.py index 0c725584..f36be0c9 100644 --- a/micro_sam/sam_annotator/training_ui.py +++ b/micro_sam/sam_annotator/training_ui.py @@ -1,15 +1,17 @@ import os import warnings -import torch -import torch_em -from napari.qt.threading import thread_worker from qtpy import QtWidgets +from napari.qt.threading import thread_worker + +import torch from torch.utils.data import random_split +import torch_em + import micro_sam.util as util -import micro_sam.sam_annotator._widgets as widgets from ._tooltips import get_tooltip +import micro_sam.sam_annotator._widgets as widgets from micro_sam.training import default_sam_dataset, train_sam_for_configuration, CONFIGURATIONS diff --git a/micro_sam/sam_annotator/util.py b/micro_sam/sam_annotator/util.py index d3d7525f..aae5a75e 100644 --- a/micro_sam/sam_annotator/util.py +++ b/micro_sam/sam_annotator/util.py @@ -1,8 +1,7 @@ -import argparse import os import pickle import warnings - +import argparse from glob import glob from pathlib import Path from typing import List, Optional, Tuple @@ -10,9 +9,8 @@ import h5py import napari import numpy as np - -from scipy.ndimage import shift from skimage import draw +from scipy.ndimage import shift from .. import prompt_based_segmentation, util from .. import _model_settings as model_settings diff --git a/micro_sam/sample_data.py b/micro_sam/sample_data.py index 311d5008..8f636ce9 100644 --- a/micro_sam/sample_data.py +++ b/micro_sam/sample_data.py @@ -16,13 +16,13 @@ from pathlib import Path from typing import Union -import imageio.v3 as imageio -import numpy as np import pooch +import numpy as np +import imageio.v3 as imageio -from skimage.data import binary_blobs from skimage.measure import label from skimage.transform import resize +from skimage.data import binary_blobs from .util import get_cache_directory diff --git a/micro_sam/training/joint_sam_trainer.py b/micro_sam/training/joint_sam_trainer.py index 08ab8c39..db59408e 100644 --- a/micro_sam/training/joint_sam_trainer.py +++ b/micro_sam/training/joint_sam_trainer.py @@ -13,6 +13,19 @@ class JointSamTrainer(SamTrainer): + """Trainer class for jointly training the Segment Anything model with an additional convolutional decoder. 
+ + This class is inherited from `SamTrainer`. + Check out https://github.com/computational-cell-analytics/micro-sam/blob/master/micro_sam/training/sam_trainer.py + for details on its implementation. + + Args: + unetr: The UNet-style model with vision transformer as the image encoder. + Required to perform automatic instance segmentation. + instance_loss: The loss to compare the predictions (for instance segmentation) and the targets. + instance_metric: The metric to compare the predictions and the targets. + kwargs: The keyword arguments of the `SamTrainer` (and `DefaultTrainer`) class. + """ def __init__( self, unetr: torch.nn.Module, @@ -60,6 +73,9 @@ def load_checkpoint(self, checkpoint="best"): return save_dict def _instance_iteration(self, x, y, metric_for_val=False): + """Perform the segmentation of distance maps and + compute the loss (and metric) between the prediction and target. + """ outputs = self.unetr(x.to(self.device)) loss = self.instance_loss(outputs, y.to(self.device)) if metric_for_val: diff --git a/micro_sam/training/sam_trainer.py b/micro_sam/training/sam_trainer.py index 268cca7d..020413e2 100644 --- a/micro_sam/training/sam_trainer.py +++ b/micro_sam/training/sam_trainer.py @@ -5,10 +5,11 @@ from typing import Optional import numpy as np -import torch -import torch_em +import torch from torchvision.utils import make_grid + +import torch_em from torch_em.trainer.logger_base import TorchEmLogger from ..prompt_generators import PromptGeneratorBase, IterativePromptGenerator @@ -32,7 +33,7 @@ class SamTrainer(torch_em.trainer.DefaultTrainer): prompt_generator: The iterative prompt generator which takes care of the iterative prompting logic for training mask_prob: The probability of using the mask inputs in the iterative prompting (per `n_sub_iteration`) mask_loss: The loss to compare the predicted masks and the targets. - **kwargs: The keyword arguments of the DefaultTrainer super class. + kwargs: The keyword arguments of the DefaultTrainer super class. """ def __init__( diff --git a/micro_sam/training/semantic_sam_trainer.py b/micro_sam/training/semantic_sam_trainer.py index cb136c30..46a5f8fa 100644 --- a/micro_sam/training/semantic_sam_trainer.py +++ b/micro_sam/training/semantic_sam_trainer.py @@ -9,6 +9,14 @@ class CustomDiceLoss(nn.Module): + """Loss for computing dice over one-hot labels. + + Expects prediction and target with `num_classes` channels: the number of classes for semantic segmentation. + + Args: + num_classes: The number of classes for semantic segmentation (including background class). + softmax: Whether to use softmax over the predictions. + """ def __init__(self, num_classes: int, softmax: bool = True) -> None: super().__init__() self.num_classes = num_classes @@ -32,7 +40,18 @@ def __call__(self, pred, target): class SemanticSamTrainer(DefaultTrainer): - """ + """Trainer class for training the Segment Anything model for semantic segmentation. + + This class is derived from `torch_em.trainer.DefaultTrainer`. + Check out https://github.com/constantinpape/torch-em/blob/main/torch_em/trainer/default_trainer.py + for details on its usage and implementation. + + Args: + convert_inputs: The class that converts outputs of the dataloader to the expected input format of SAM. + The class `micro_sam.training.util.ConvertToSemanticSamInputs` can be used here. + num_classes: The number of classes for semantic segmentation (including the background class). + dice_weight: The weighting for the dice loss in the combined dice-cross entropy loss function. 
+ kwargs: The keyword arguments of the DefaultTrainer super class. """ def __init__( self, @@ -62,6 +81,8 @@ def __init__( self._kwargs = kwargs def _compute_loss(self, y, masks): + """Compute the combined (weighted) dice loss and cross-entropy loss between the prediction and target. + """ target = y.to(self.device, non_blocking=True) # Compute dice loss for the predictions dice_loss = self.loss(masks, target) @@ -77,6 +98,8 @@ def _compute_loss(self, y, masks): return net_loss def _get_model_outputs(self, batched_inputs): + """Get the predictions from the model. + """ # Precompute the image embeddings if the model exposes it as functionality. if hasattr(self.model, "image_embeddings_oft"): image_embeddings, batched_inputs = self.model.image_embeddings_oft(batched_inputs) diff --git a/micro_sam/training/simple_sam_trainer.py b/micro_sam/training/simple_sam_trainer.py index 984e41fa..a0b06341 100644 --- a/micro_sam/training/simple_sam_trainer.py +++ b/micro_sam/training/simple_sam_trainer.py @@ -5,6 +5,15 @@ class SimpleSamTrainer(SamTrainer): """Trainer class for creating a simple SAM trainer for limited prompt-based segmentation. + + This class is inherited from `SamTrainer`. + Check out https://github.com/computational-cell-analytics/micro-sam/blob/master/micro_sam/training/sam_trainer.py + for details on its implementation. + + Args: + use_points: Whether to use point prompts for interactive segmentation. + use_box: Whether to use box prompts for interactive segmentation. + kwargs: The keyword arguments of the `SamTrainer` (and `DefaultTrainer`) class. """ def __init__( self, @@ -28,20 +37,21 @@ def __init__( assert (self.use_points + self.use_box) != 0, "Please choose at least one of the prompt-based method." def _choose_one_positive_point(self): - "samples only a single positive point per object" + """Samples only a single positive point per object + """ n_pos, n_neg = 1, 0 multimask_output = True return n_pos, n_neg, None, multimask_output def _choose_box(self): - "samples only a single box per object" + """Samples only a single box per object + """ n_pos, n_neg = 0, 0 multimask_output = False get_boxes = True return n_pos, n_neg, get_boxes, multimask_output def _get_prompt_and_multimasking_choices(self, current_iteration): - if self.random_prompt_choice: # both "use_points" and "use_box" are True available_choices = [self._choose_one_positive_point(), self._choose_box()] return random.choice(available_choices) @@ -57,6 +67,11 @@ def _get_prompt_and_multimasking_choices_for_val(self, current_iteration): class MedSAMTrainer(SimpleSamTrainer): """Trainer class for replicating the trainer of MedSAM (https://arxiv.org/abs/2304.12306). + + This class is inherited from `SimpleSamTrainer`. + Check out + https://github.com/computational-cell-analytics/micro-sam/blob/master/micro_sam/training/simple_sam_trainer.py + for details on its implementation. """ def __init__(self, **kwargs): super().__init__( diff --git a/micro_sam/training/trainable_sam.py b/micro_sam/training/trainable_sam.py index 72a3ebe6..dbc206d3 100644 --- a/micro_sam/training/trainable_sam.py +++ b/micro_sam/training/trainable_sam.py @@ -15,10 +15,7 @@ class TrainableSAM(nn.Module): Args: sam: The SegmentAnything Model. 
""" - def __init__( - self, - sam: Sam, - ) -> None: + def __init__(self, sam: Sam) -> None: super().__init__() self.sam = sam self.transform = ResizeLongestSide(sam.image_encoder.img_size) @@ -62,10 +59,7 @@ def image_embeddings_oft(self, batched_inputs): # batched inputs follow the same syntax as the input to sam.forward def forward( - self, - batched_inputs: List[Dict[str, Any]], - image_embeddings: torch.Tensor, - multimask_output: bool = False, + self, batched_inputs: List[Dict[str, Any]], image_embeddings: torch.Tensor, multimask_output: bool = False, ) -> List[Dict[str, Any]]: """Forward pass. diff --git a/micro_sam/training/training.py b/micro_sam/training/training.py index 39314e7f..4e25e595 100644 --- a/micro_sam/training/training.py +++ b/micro_sam/training/training.py @@ -178,11 +178,13 @@ def train_sam( mask_prob: The probability for using a mask as input in a given training sub-iteration. n_iterations: The number of iterations to use for training. This will over-ride n_epochs if given. scheduler_class: The learning rate scheduler to update the learning rate. - By default, ReduceLROnPlateau is used. + By default, torch.optim.lr_scheduler.ReduceLROnPlateau is used. scheduler_kwargs: The learning rate scheduler parameters. If passed None, the chosen default parameters are used in ReduceLROnPlateau. save_every_kth_epoch: Save checkpoints after every kth epoch separately. pbar_signals: Controls for napari progress bar. + optimizer_class: The optimizer class. + By default, torch.optim.AdamW is used. """ _check_loader(train_loader, with_segmentation_decoder) _check_loader(val_loader, with_segmentation_decoder) diff --git a/micro_sam/training/util.py b/micro_sam/training/util.py index 4ba56961..759c905e 100644 --- a/micro_sam/training/util.py +++ b/micro_sam/training/util.py @@ -227,7 +227,8 @@ def __call__(self, x, y, n_pos, n_neg, get_boxes=False, n_samples=None): class ConvertToSemanticSamInputs: - """ + """Convert outputs of data loader to the expected batched inputs of the SegmentAnything model + for semantic segmentation. """ def __call__(self, x, y): """Convert the outputs of dataloader to the batched format of inputs expected by SAM. diff --git a/micro_sam/util.py b/micro_sam/util.py index 09d06c09..45550a49 100644 --- a/micro_sam/util.py +++ b/micro_sam/util.py @@ -2,26 +2,28 @@ Helper functions for downloading Segment Anything models and predicting image embeddings. """ -import hashlib + import os import pickle +import hashlib import warnings -from collections import OrderedDict from pathlib import Path +from collections import OrderedDict from typing import Any, Dict, Iterable, Optional, Tuple, Union -import imageio.v3 as imageio -import numpy as np -import pooch -import torch +import zarr import vigra +import torch +import pooch import xxhash -import zarr +import numpy as np +import imageio.v3 as imageio +from skimage.measure import regionprops +from skimage.segmentation import relabel_sequential from elf.io import open_file + from nifty.tools import blocking -from skimage.measure import regionprops -from skimage.segmentation import relabel_sequential from .__version__ import __version__ from . 
import models as custom_models diff --git a/micro_sam/visualization.py b/micro_sam/visualization.py index c931985f..ad4e9d00 100644 --- a/micro_sam/visualization.py +++ b/micro_sam/visualization.py @@ -4,10 +4,11 @@ from typing import Tuple import numpy as np +from skimage.transform import resize -from elf.segmentation.embeddings import embedding_pca from nifty.tools import blocking -from skimage.transform import resize + +from elf.segmentation.embeddings import embedding_pca from .util import ImageEmbeddings
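As background for the CustomDiceLoss documentation added in micro_sam/training/semantic_sam_trainer.py, a minimal, self-contained sketch of a dice loss computed over one-hot encoded labels with num_classes channels. This illustrates the general technique only; it is not the project's implementation, and the function name and eps smoothing term are choices made here.

import torch
import torch.nn.functional as F


def dice_loss_over_one_hot(pred: torch.Tensor, target: torch.Tensor, num_classes: int, eps: float = 1e-7) -> torch.Tensor:
    # pred: logits of shape (B, C, H, W) with C == num_classes; target: integer labels of shape (B, H, W).
    pred = torch.softmax(pred, dim=1)  # corresponds to softmax=True in the documented loss
    # One-hot encode the targets to (B, C, H, W) so prediction and target both have num_classes channels.
    target_one_hot = F.one_hot(target.long(), num_classes).permute(0, 3, 1, 2).float()
    dims = (0, 2, 3)  # reduce over batch and spatial axes, keep the class axis
    intersection = (pred * target_one_hot).sum(dims)
    denominator = pred.sum(dims) + target_one_hot.sum(dims)
    dice_per_class = (2.0 * intersection + eps) / (denominator + eps)
    return 1.0 - dice_per_class.mean()


# Example: 4 classes (background included), batch of 2 images of size 64x64.
logits = torch.randn(2, 4, 64, 64)
labels = torch.randint(0, 4, (2, 64, 64))
print(dice_loss_over_one_hot(logits, labels, num_classes=4))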