From 34889337cb402c9d205504385d1a9a95b9258a78 Mon Sep 17 00:00:00 2001 From: Anwai Archit <52396323+anwai98@users.noreply.github.com> Date: Fri, 31 May 2024 09:42:38 +0200 Subject: [PATCH 01/53] Upstream dev with master (#624) * Update modelzoo urls (#619) * Update installation.md (#617) * Bump to 101 (#620) * Update docs (#621) * Add progressbar to pooch download (#623) --------- Co-authored-by: Constantin Pape --- README.md | 4 ++++ doc/installation.md | 14 +++++++++----- micro_sam/__version__.py | 2 +- micro_sam/util.py | 29 ++++++++++++++--------------- 4 files changed, 28 insertions(+), 21 deletions(-) diff --git a/README.md b/README.md index 8a06ba26c..8a3442b45 100644 --- a/README.md +++ b/README.md @@ -53,6 +53,10 @@ Compared to these we support more applications (2d, 3d and tracking), and provid ## Release Overview +**New in version 1.0.1** + +Use stable URL for model downloads and fix issues in state precomputation for automatic segmentation. + **New in version 1.0.0** This release mainly fixes issues with the previous release and marks the napari user interface as stable. diff --git a/doc/installation.md b/doc/installation.md index 1edb5e51d..5fb232782 100644 --- a/doc/installation.md +++ b/doc/installation.md @@ -7,7 +7,7 @@ There are three ways to install `micro_sam`: You can find more information on the installation and how to troubleshoot it in [the FAQ section](#installation-questions). -We do *not* recommend installing `micro-sam` with pip. +We do **not** recommend installing `micro-sam` with pip. ## From mamba @@ -19,11 +19,11 @@ You can follow the instructions [here](https://mamba.readthedocs.io/en/latest/in `micro_sam` can be installed in an existing environment via: ```bash -$ mamba install -c conda-forge micro_sam +$ mamba install -c pytorch -c conda-forge micro_sam ``` or you can create a new environment (here called `micro-sam`) via: ```bash -$ mamba create -c conda-forge -n micro-sam micro_sam +$ mamba create -c pytorch -c conda-forge -n micro-sam micro_sam ``` if you want to use the GPU you need to install PyTorch from the `pytorch` channel instead of `conda-forge`. For example: ```bash @@ -73,8 +73,8 @@ $ pip install -e . ## From installer We also provide installers for Linux and Windows: -- [Linux](https://owncloud.gwdg.de/index.php/s/nrNBuHr9ncJqid6) -- [Windows](https://owncloud.gwdg.de/index.php/s/kZmpAIBDmUSu4e9) +- [Linux](https://owncloud.gwdg.de/index.php/s/nvLwlrHE4DkYcWl) +- [Windows](https://owncloud.gwdg.de/index.php/s/feIs9069IrURmbt) @@ -113,3 +113,7 @@ https://www.makeuseof.com/how-to-disable-gatekeeper-mac/ TODO detailed instruction --> + +### Easybuild installation + +There is also an easy-build recipe for `micro_sam` under development. You can find more information [here](https://github.com/easybuilders/easybuild-easyconfigs/pull/20636). diff --git a/micro_sam/__version__.py b/micro_sam/__version__.py index 244424e58..5c4105cd3 100644 --- a/micro_sam/__version__.py +++ b/micro_sam/__version__.py @@ -1 +1 @@ -__version__ = "1.0.0post0" +__version__ = "1.0.1" diff --git a/micro_sam/util.py b/micro_sam/util.py index f16911be3..c9768dc84 100644 --- a/micro_sam/util.py +++ b/micro_sam/util.py @@ -124,27 +124,26 @@ def models(): } registry = {**encoder_registry, **decoder_registry} - # Note: the modelzoo urls should be updated at some point to not point at 'staged' but 'published'. encoder_urls = { "vit_l": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth", "vit_h": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth", "vit_b": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth", "vit_t": "https://owncloud.gwdg.de/index.php/s/TuDzuwVDHd1ZDnQ/download", - "vit_l_lm": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/idealistic-rat/staged/1/files/vit_l.pt", - "vit_b_lm": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/diplomatic-bug/staged/1/files/vit_b.pt", - "vit_t_lm": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/faithful-chicken/staged/1/files/vit_t.pt", - "vit_l_em_organelles": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/humorous-crab/staged/1/files/vit_l.pt", - "vit_b_em_organelles": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/noisy-ox/staged/1/files/vit_b.pt", - "vit_t_em_organelles": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/greedy-whale/staged/1/files/vit_t.pt", + "vit_l_lm": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/idealistic-rat/1/files/vit_l.pt", + "vit_b_lm": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/diplomatic-bug/1/files/vit_b.pt", + "vit_t_lm": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/faithful-chicken/1/files/vit_t.pt", + "vit_l_em_organelles": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/humorous-crab/1/files/vit_l.pt", + "vit_b_em_organelles": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/noisy-ox/1/files/vit_b.pt", + "vit_t_em_organelles": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/greedy-whale/1/files/vit_t.pt", } decoder_urls = { - "vit_l_lm_decoder": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/idealistic-rat/staged/1/files/vit_l_decoder.pt", - "vit_b_lm_decoder": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/diplomatic-bug/staged/1/files/vit_b_decoder.pt", - "vit_t_lm_decoder": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/faithful-chicken/staged/1/files/vit_t_decoder.pt", - "vit_l_em_organelles_decoder": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/humorous-crab/staged/1/files/vit_l_decoder.pt", - "vit_b_em_organelles_decoder": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/noisy-ox/staged/1/files/vit_b_decoder.pt", - "vit_t_em_organelles_decoder": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/greedy-whale/staged/1/files/vit_t_decoder.pt", + "vit_l_lm_decoder": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/idealistic-rat/1/files/vit_l_decoder.pt", + "vit_b_lm_decoder": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/diplomatic-bug/1/files/vit_b_decoder.pt", + "vit_t_lm_decoder": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/faithful-chicken/1/files/vit_t_decoder.pt", + "vit_l_em_organelles_decoder": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/humorous-crab/1/files/vit_l_decoder.pt", + "vit_b_em_organelles_decoder": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/noisy-ox/1/files/vit_b_decoder.pt", + "vit_t_em_organelles_decoder": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/greedy-whale/1/files/vit_t_decoder.pt", } urls = {**encoder_urls, **decoder_urls} @@ -317,13 +316,13 @@ def get_sam_model( # URL from the model_type. If the model_type is invalid pooch will raise an error. if checkpoint_path is None: model_registry = models() - checkpoint_path = model_registry.fetch(model_type) + checkpoint_path = model_registry.fetch(model_type, progressbar=True) model_hash = model_registry.registry[model_type] # If we have a custom model then we may also have a decoder checkpoint. # Download it here, so that we can add it to the state. decoder_name = f"{model_type}_decoder" - decoder_path = model_registry.fetch(decoder_name) if decoder_name in model_registry.registry else None + decoder_path = model_registry.fetch(decoder_name, progressbar=True) if decoder_name in model_registry.registry else None # checkpoint_path has been passed, we use it instead of downloading a model. else: From 1c20ac3e0c74c21281f91aab770c0c98f3c19f2d Mon Sep 17 00:00:00 2001 From: Anwai Archit <52396323+anwai98@users.noreply.github.com> Date: Sun, 2 Jun 2024 12:41:09 +0200 Subject: [PATCH 02/53] Add PanNuke specialist finetuning (#622) --- finetuning/specialists/README.md | 6 ++ .../histopathology/create_dataloaders.py | 68 ++++++++++++++ .../histopathology/pannuke_finetuning.py | 94 +++++++++++++++++++ 3 files changed, 168 insertions(+) create mode 100644 finetuning/specialists/training/histopathology/create_dataloaders.py create mode 100644 finetuning/specialists/training/histopathology/pannuke_finetuning.py diff --git a/finetuning/specialists/README.md b/finetuning/specialists/README.md index 7a362cc74..de0357ced 100644 --- a/finetuning/specialists/README.md +++ b/finetuning/specialists/README.md @@ -19,6 +19,12 @@ Code for finetuning Segment Anything on specific microscopy datasets. - `resource_efficient_finetuning`: The experiments for finetuning a custom dataset on limited resources. +## Experimental Scripts + +- `training/histopathology/`: The finetuning scripts for histopathology datasets. + - `pannuke_finetuning.py`: Finetuning Segment Anything on PanNuke datasets. + + ## Outdated Scripts The scripts located at `outdated/` are not in working purpose with the latest version of `micro-sam`. - It comprises of extensive experiments on "LIVECell" specialist, located at `outdated/livecell/`. \ No newline at end of file diff --git a/finetuning/specialists/training/histopathology/create_dataloaders.py b/finetuning/specialists/training/histopathology/create_dataloaders.py new file mode 100644 index 000000000..1f2f48546 --- /dev/null +++ b/finetuning/specialists/training/histopathology/create_dataloaders.py @@ -0,0 +1,68 @@ +import torch + +from torch_em.data import MinInstanceSampler +from torch_em.util.debug import check_loader +from torch_em.data.datasets import get_pannuke_loader +from torch_em.transform.label import PerObjectDistanceTransform + +import micro_sam.training as sam_training + + +def get_dataloaders(patch_shape, data_path): + """This returns the pannuke data loaders implemented in torch_em: + https://github.com/constantinpape/torch-em/blob/main/torch_em/data/datasets/histopathology/pannuke.py + It will automatically download the pannuke data. + + Note: to replace this with another data loader you need to return a torch data loader + that retuns `x, y` tensors, where `x` is the image data and `y` are the labels. + The labels have to be in a label mask instance segmentation format. + I.e. a tensor of the same spatial shape as `x`, with each object mask having its own ID. + Important: the ID 0 is reseved for background, and the IDs must be consecutive + """ + label_transform = PerObjectDistanceTransform( + distances=True, boundary_distances=True, directed_distances=False, foreground=True, instances=True, min_size=25 + ) + raw_transform = sam_training.identity # the current workflow avoids rescaling the inputs to [-1, 1] + sampler = MinInstanceSampler(min_num_instances=3) + + train_loader = get_pannuke_loader( + path=data_path, + patch_shape=patch_shape, + batch_size=2, + folds=["fold_1"], + num_workers=16, + download=True, + shuffle=True, + label_transform=label_transform, + raw_transform=raw_transform, + label_dtype=torch.float32, + sampler=sampler, + ndim=2, + ) + val_loader = get_pannuke_loader( + path=data_path, + patch_shape=patch_shape, + batch_size=1, + folds=["fold_2"], + num_workers=16, + download=True, + shuffle=True, + label_transform=label_transform, + raw_transform=raw_transform, + label_dtype=torch.float32, + sampler=sampler, + ndim=2, + ) + + return train_loader, val_loader + + +def visualize_images(data_path): + train_loader, val_loader = get_dataloaders(patch_shape=(1, 512, 512), data_path=data_path) + + # let's visualize train loader first + check_loader(train_loader, 8, plt=True, save_path="./fig.png") + + +if __name__ == "__main__": + visualize_images(data_path="/scratch/projects/nim00007/sam/data/pannuke") diff --git a/finetuning/specialists/training/histopathology/pannuke_finetuning.py b/finetuning/specialists/training/histopathology/pannuke_finetuning.py new file mode 100644 index 000000000..f6b23a939 --- /dev/null +++ b/finetuning/specialists/training/histopathology/pannuke_finetuning.py @@ -0,0 +1,94 @@ +import os +import argparse + +import torch + +import micro_sam.training as sam_training +from micro_sam.util import export_custom_sam_model + +from create_dataloaders import get_dataloaders + + +def finetune_pannuke(args): + """Example code for finetuning SAM on PanNuke""" + # override this (below) if you have some more complex set-up and need to specify the exact gpu + device = "cuda" if torch.cuda.is_available() else "cpu" + + # training settings: + model_type = args.model_type + checkpoint_path = None # override this to start training from a custom checkpoint + patch_shape = (1, 512, 512) # the patch shape for training + n_objects_per_batch = args.n_objects # the number of objects per batch that will be sampled (default: 25) + freeze_parts = args.freeze # override this to freeze different parts of the model + checkpoint_name = f"{args.model_type}/pannuke_sam" + + # all the stuff we need for training + train_loader, val_loader = get_dataloaders(patch_shape=patch_shape, data_path=args.input_path) + scheduler_kwargs = {"mode": "min", "factor": 0.9, "patience": 10, "verbose": True} + + # Run training. + sam_training.train_sam( + name=checkpoint_name, + model_type=model_type, + train_loader=train_loader, + val_loader=val_loader, + early_stopping=10, + n_objects_per_batch=n_objects_per_batch, + checkpoint_path=checkpoint_path, + freeze=freeze_parts, + device=device, + lr=1e-5, + n_iterations=args.iterations, + save_root=args.save_root, + scheduler_kwargs=scheduler_kwargs, + save_every_kth_epoch=args.save_every_kth_epoch, + ) + + if args.export_path is not None: + checkpoint_path = os.path.join( + "" if args.save_root is None else args.save_root, "checkpoints", checkpoint_name, "best.pt" + ) + export_custom_sam_model( + checkpoint_path=checkpoint_path, model_type=model_type, save_path=args.export_path, + ) + + +def main(): + parser = argparse.ArgumentParser(description="Finetune Segment Anything for the PanNuke dataset.") + parser.add_argument( + "--input_path", "-i", default="/scratch/projects/nim00007/sam/data/pannuke/", + help="The filepath to the PanNuke data. If the data does not exist yet it will be downloaded." + ) + parser.add_argument( + "--model_type", "-m", default="vit_b", + help="The model type to use for fine-tuning. Either vit_t, vit_b, vit_l or vit_h." + ) + parser.add_argument( + "--save_root", "-s", + help="Where to save the checkpoint and logs. By default they will be saved where this script is run." + ) + parser.add_argument( + "--iterations", type=int, default=int(1e5), + help="For how many iterations should the model be trained? By default 100k." + ) + parser.add_argument( + "--export_path", "-e", + help="Where to export the finetuned model to. The exported model can be used in the annotation tools." + ) + parser.add_argument( + "--freeze", type=str, nargs="+", default=None, + help="Which parts of the model to freeze for finetuning." + ) + parser.add_argument( + "--save_every_kth_epoch", type=int, default=None, + help="To save every kth epoch while fine-tuning. Expects an integer value." + ) + parser.add_argument( + "--n_objects", type=int, default=25, help="The number of instances (objects) per batch used for finetuning." + ) + args = parser.parse_args() + finetune_pannuke(args) + + +if __name__ == "__main__": + main() From 401ea509060b7d39df2b64e9b7c4d5b2c504c13b Mon Sep 17 00:00:00 2001 From: Anwai Archit <52396323+anwai98@users.noreply.github.com> Date: Wed, 5 Jun 2024 11:21:13 +0200 Subject: [PATCH 03/53] Add MedSAM Reimplementation (#612) Add MedSAM trainer --- micro_sam/training/__init__.py | 1 + micro_sam/training/medsam_trainer.py | 22 ++++ micro_sam/training/sam_trainer.py | 27 +++-- scripts/medsam/btcv_finetuning.py | 146 +++++++++++++++++++++++++++ 4 files changed, 186 insertions(+), 10 deletions(-) create mode 100644 micro_sam/training/medsam_trainer.py create mode 100644 scripts/medsam/btcv_finetuning.py diff --git a/micro_sam/training/__init__.py b/micro_sam/training/__init__.py index 3047a75bf..408f3424d 100644 --- a/micro_sam/training/__init__.py +++ b/micro_sam/training/__init__.py @@ -4,4 +4,5 @@ from .sam_trainer import SamTrainer, SamLogger from .util import ConvertToSamInputs, get_trainable_sam_model, identity from .joint_sam_trainer import JointSamTrainer, JointSamLogger +from .medsam_trainer import MedSAMTrainer from .training import train_sam, train_sam_for_configuration, default_sam_loader, default_sam_dataset, CONFIGURATIONS diff --git a/micro_sam/training/medsam_trainer.py b/micro_sam/training/medsam_trainer.py new file mode 100644 index 000000000..acb1a7fa1 --- /dev/null +++ b/micro_sam/training/medsam_trainer.py @@ -0,0 +1,22 @@ +from . import SamTrainer + + +class MedSAMTrainer(SamTrainer): + """Trainer class for replicating the trainer of MedSAM (https://arxiv.org/abs/2304.12306) + """ + def __init__( + self, + **kwargs + ): + n_sub_iteration = 1 + mask_prob = 0 + super().__init__(n_sub_iteration=n_sub_iteration, mask_prob=mask_prob, **kwargs) + + def _get_prompt_and_multimasking_choices(self, current_iteration): + n_pos, n_neg = 0, 0 + get_boxes = True + multimask_output = False + return n_pos, n_neg, get_boxes, multimask_output + + def _get_prompt_and_multimasking_choices_for_val(self, current_iteration): + return self._get_prompt_and_multimasking_choices(current_iteration=current_iteration) diff --git a/micro_sam/training/sam_trainer.py b/micro_sam/training/sam_trainer.py index c251e749a..d825c787d 100644 --- a/micro_sam/training/sam_trainer.py +++ b/micro_sam/training/sam_trainer.py @@ -31,6 +31,7 @@ class SamTrainer(torch_em.trainer.DefaultTrainer): mse_loss: The regression loss to compare the IoU predicted by the model with the true IoU. prompt_generator: The iterative prompt generator which takes care of the iterative prompting logic for training mask_prob: The probability of using the mask inputs in the iterative prompting (per `n_sub_iteration`) + mask_loss: The loss to compare the predicted masks and the targets. **kwargs: The keyword arguments of the DefaultTrainer super class. """ @@ -42,12 +43,17 @@ def __init__( mse_loss: torch.nn.Module = torch.nn.MSELoss(), prompt_generator: PromptGeneratorBase = IterativePromptGenerator(), mask_prob: float = 0.5, + mask_loss: Optional[torch.nn.Module] = None, **kwargs ): - # We have to use the Dice Loss with reduce channel set to None. - # Hence we hard-code it here to avoid issues by passsing wrong options for the loss. - dice_loss = torch_em.loss.DiceLoss(reduce_channel=None) - super().__init__(loss=dice_loss, metric=dice_loss, **kwargs) + if mask_loss is None: + # We have to use the Dice Loss with reduce channel set to None. + # Hence we hard-code it here to avoid issues by passsing wrong options for the loss. + self.mask_loss = torch_em.loss.DiceLoss(reduce_channel=None) + else: + self.mask_loss = mask_loss + + super().__init__(loss=self.mask_loss, metric=self.mask_loss, **kwargs) self.convert_inputs = convert_inputs self.mse_loss = mse_loss self.n_objects_per_batch = n_objects_per_batch @@ -216,12 +222,13 @@ def _compute_iterative_loss(self, batched_inputs, y_one_hot, num_subiter, multim iou_regression_loss += net_iou_regression_loss mean_model_iou += net_mean_model_iou - # Determine the next prompts based on current predictions. - with torch.no_grad(): - # Get the mask and logit predictions corresponding to the predicted object - # (per actual object) with the best IOU. - masks, logits = self._get_best_masks(batched_outputs, batched_iou_predictions) - batched_inputs = self._update_prompts(batched_inputs, y_one_hot, masks, logits) + if i < (num_subiter - 1): # We need not update the prompts for the last iteration. + # Determine the next prompts based on current predictions. + with torch.no_grad(): + # Get the mask and logit predictions corresponding to the predicted object + # (per actual object) with the best IOU. + masks, logits = self._get_best_masks(batched_outputs, batched_iou_predictions) + batched_inputs = self._update_prompts(batched_inputs, y_one_hot, masks, logits) loss = loss / num_subiter mask_loss = mask_loss / num_subiter diff --git a/scripts/medsam/btcv_finetuning.py b/scripts/medsam/btcv_finetuning.py new file mode 100644 index 000000000..4040c515e --- /dev/null +++ b/scripts/medsam/btcv_finetuning.py @@ -0,0 +1,146 @@ +import os +import argparse + +import torch + +from torch_em.loss.dice import BCEDiceLossWithLogits +from torch_em.data.datasets.medical import get_btcv_loader + +import micro_sam.training as sam_training +from micro_sam.util import export_custom_sam_model + + +def get_dataloaders(patch_shape, data_path): + """This returns the btcv data loaders implemented in torch_em: + https://github.com/constantinpape/torch-em/blob/main/torch_em/data/datasets/medical/btcv.py + It will not automatically download the BTCV data. Take a look at `get_btcv_dataset`. + + Note: to replace this with another data loader you need to return a torch data loader + that retuns `x, y` tensors, where `x` is the image data and `y` are the labels. + The labels have to be in a label mask instance segmentation format. + I.e. a tensor of the same spatial shape as `x`, with each object mask having its own ID. + Important: the ID 0 is reseved for background, and the IDs must be consecutive + """ + raw_transform = sam_training.identity + + train_loader = get_btcv_loader( + path=data_path, + patch_shape=patch_shape, + batch_size=2, + ndim=2, + anatomy=None, + organs=None, + raw_transform=raw_transform, + ) + val_loader = get_btcv_loader( + path=data_path, + patch_shape=patch_shape, + batch_size=1, + ndim=2, + anatomy=None, + organs=None, + raw_transform=raw_transform, + ) + return train_loader, val_loader + + +def finetune_btcv(args): + """Code for finetuning SAM on BTCV in "micro_sam"-based MedSAM reimplementation""" + # override this (below) if you have some more complex set-up and need to specify the exact gpu + device = "cuda" if torch.cuda.is_available() else "cpu" + + # training settings: + model_type = args.model_type + checkpoint_path = None # override this to start training from a custom checkpoint + patch_shape = (1, 512, 512) # the patch shape for training + n_objects_per_batch = args.n_objects # this is the number of objects per batch that will be sampled (default: 25) + freeze_parts = args.freeze # override this to freeze different parts of the model + + # get the trainable segment anything model + model = sam_training.get_trainable_sam_model( + model_type=model_type, + device=device, + checkpoint_path=checkpoint_path, + freeze=freeze_parts + ) + model.to(device) + + # all the stuff we need for training + optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5) + scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.9, patience=10, verbose=True) + train_loader, val_loader = get_dataloaders(patch_shape=patch_shape, data_path=args.input_path) + + # this class creates all the training data for a batch (inputs, prompts and labels) + convert_inputs = sam_training.ConvertToSamInputs(transform=model.transform, box_distortion_factor=0.025) + + checkpoint_name = f"{args.model_type}/btcv_medsam" + + # the trainer which performs the joint training and validation (implemented using "torch_em") + trainer = sam_training.MedSAMTrainer( + name=checkpoint_name, + save_root=args.save_root, + train_loader=train_loader, + val_loader=val_loader, + model=model, + optimizer=optimizer, + device=device, + lr_scheduler=scheduler, + logger=sam_training.SamLogger, + log_image_interval=100, + mixed_precision=True, + convert_inputs=convert_inputs, + n_objects_per_batch=n_objects_per_batch, + compile_model=False, + mask_loss=BCEDiceLossWithLogits(), + ) + trainer.fit(args.iterations, save_every_kth_epoch=args.save_every_kth_epoch) + if args.export_path is not None: + checkpoint_path = os.path.join( + "" if args.save_root is None else args.save_root, "checkpoints", checkpoint_name, "best.pt" + ) + export_custom_sam_model( + checkpoint_path=checkpoint_path, + model_type=model_type, + save_path=args.export_path, + ) + + +def main(): + parser = argparse.ArgumentParser(description="Finetune Segment Anything for the BTCV dataset.") + parser.add_argument( + "--input_path", "-i", default="/scratch/projects/nim00007/data/btcv/", + help="The filepath to the BTCV data. If the data does not exist yet it will be downloaded." + ) + parser.add_argument( + "--model_type", "-m", default="vit_b", + help="The model type to use for fine-tuning. Either vit_t, vit_b, vit_l or vit_h." + ) + parser.add_argument( + "--save_root", "-s", + help="Where to save the checkpoint and logs. By default they will be saved where this script is run." + ) + parser.add_argument( + "--iterations", type=int, default=int(1e4), + help="For how many iterations should the model be trained?" + ) + parser.add_argument( + "--export_path", "-e", + help="Where to export the finetuned model to. The exported model can be used in the annotation tools." + ) + parser.add_argument( + "--freeze", type=str, nargs="+", default=None, + help="Which parts of the model to freeze for finetuning." + ) + parser.add_argument( + "--save_every_kth_epoch", type=int, default=None, + help="To save every kth epoch while fine-tuning. Expects an integer value." + ) + parser.add_argument( + "--n_objects", type=int, default=25, help="The number of instances (objects) per batch used for finetuning." + ) + args = parser.parse_args() + finetune_btcv(args) + + +if __name__ == "__main__": + main() From f5656dac4b4da239df68090d9110302052f4b9f6 Mon Sep 17 00:00:00 2001 From: Anwai Archit <52396323+anwai98@users.noreply.github.com> Date: Fri, 7 Jun 2024 18:12:25 +0200 Subject: [PATCH 04/53] Remove device input for TrainableSAM and set non_blocking for moving tensors across devices (#632) --- micro_sam/training/sam_trainer.py | 2 +- micro_sam/training/trainable_sam.py | 16 ++++++++-------- micro_sam/training/util.py | 2 +- 3 files changed, 10 insertions(+), 10 deletions(-) diff --git a/micro_sam/training/sam_trainer.py b/micro_sam/training/sam_trainer.py index d825c787d..268cca7d5 100644 --- a/micro_sam/training/sam_trainer.py +++ b/micro_sam/training/sam_trainer.py @@ -283,7 +283,7 @@ def _preprocess_batch(self, batched_inputs, y, sampled_ids): # number of objects across the batch. n_objects = min(len(ids) for ids in sampled_ids) - y = y.to(self.device) + y = y.to(self.device, non_blocking=True) # Compute the one hot targets for the seg-id. y_one_hot = torch.stack([ torch.stack([target == seg_id for seg_id in ids[:n_objects]]) diff --git a/micro_sam/training/trainable_sam.py b/micro_sam/training/trainable_sam.py index 81c7d8ca5..72a3ebe62 100644 --- a/micro_sam/training/trainable_sam.py +++ b/micro_sam/training/trainable_sam.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Tuple, Union +from typing import Any, Dict, List, Tuple import torch from torch import nn @@ -14,16 +14,13 @@ class TrainableSAM(nn.Module): Args: sam: The SegmentAnything Model. - device: The device for training. """ def __init__( self, sam: Sam, - device: Union[str, torch.device], ) -> None: super().__init__() self.sam = sam - self.device = device self.transform = ResizeLongestSide(sam.image_encoder.img_size) def preprocess(self, x: torch.Tensor) -> Tuple[torch.Tensor, Tuple[int, int]]: @@ -54,7 +51,7 @@ def preprocess(self, x: torch.Tensor) -> Tuple[torch.Tensor, Tuple[int, int]]: def image_embeddings_oft(self, batched_inputs): # Compute the input images. input_images, input_size = self.preprocess( - torch.stack([x["image"] for x in batched_inputs], dim=0).to(self.device) + torch.stack([x["image"] for x in batched_inputs], dim=0).to(self.sam.device, non_blocking=True) ) # Update the input size for each input in the batch. for i in range(len(batched_inputs)): @@ -83,17 +80,20 @@ def forward( outputs = [] for image_record, curr_embedding in zip(batched_inputs, image_embeddings): if "point_coords" in image_record: - points = (image_record["point_coords"].to(self.device), image_record["point_labels"].to(self.device)) + points = ( + image_record["point_coords"].to(self.sam.device, non_blocking=True), + image_record["point_labels"].to(self.sam.device, non_blocking=True) + ) else: points = None if "boxes" in image_record: - boxes = image_record.get("boxes").to(self.device) + boxes = image_record.get("boxes").to(self.sam.device, non_blocking=True) else: boxes = None if "mask_inputs" in image_record: - masks = image_record.get("mask_inputs").to(self.device) + masks = image_record.get("mask_inputs").to(self.sam.device, non_blocking=True) else: masks = None diff --git a/micro_sam/training/util.py b/micro_sam/training/util.py index df3d13f1a..89799617c 100644 --- a/micro_sam/training/util.py +++ b/micro_sam/training/util.py @@ -81,7 +81,7 @@ def get_trainable_sam_model( param.requires_grad = False # convert to trainable sam - trainable_sam = TrainableSAM(sam, device) + trainable_sam = TrainableSAM(sam) if return_state: return trainable_sam, state return trainable_sam From c7e21a02e888f991a0d34405bd605fee660719e4 Mon Sep 17 00:00:00 2001 From: Anwai Archit <52396323+anwai98@users.noreply.github.com> Date: Tue, 11 Jun 2024 15:43:17 +0200 Subject: [PATCH 05/53] Add SimpleSamTrainer (#631) * Refactor medsam trainer to simple sam trainer --------- Co-authored-by: Constantin Pape --- micro_sam/training/__init__.py | 2 +- micro_sam/training/medsam_trainer.py | 22 ---- micro_sam/training/simple_sam_trainer.py | 66 ++++++++++ scripts/medsam/btcv_finetuning.py | 146 ----------------------- 4 files changed, 67 insertions(+), 169 deletions(-) delete mode 100644 micro_sam/training/medsam_trainer.py create mode 100644 micro_sam/training/simple_sam_trainer.py delete mode 100644 scripts/medsam/btcv_finetuning.py diff --git a/micro_sam/training/__init__.py b/micro_sam/training/__init__.py index 408f3424d..1c354b658 100644 --- a/micro_sam/training/__init__.py +++ b/micro_sam/training/__init__.py @@ -4,5 +4,5 @@ from .sam_trainer import SamTrainer, SamLogger from .util import ConvertToSamInputs, get_trainable_sam_model, identity from .joint_sam_trainer import JointSamTrainer, JointSamLogger -from .medsam_trainer import MedSAMTrainer +from .simple_sam_trainer import SimpleSamTrainer, MedSAMTrainer from .training import train_sam, train_sam_for_configuration, default_sam_loader, default_sam_dataset, CONFIGURATIONS diff --git a/micro_sam/training/medsam_trainer.py b/micro_sam/training/medsam_trainer.py deleted file mode 100644 index acb1a7fa1..000000000 --- a/micro_sam/training/medsam_trainer.py +++ /dev/null @@ -1,22 +0,0 @@ -from . import SamTrainer - - -class MedSAMTrainer(SamTrainer): - """Trainer class for replicating the trainer of MedSAM (https://arxiv.org/abs/2304.12306) - """ - def __init__( - self, - **kwargs - ): - n_sub_iteration = 1 - mask_prob = 0 - super().__init__(n_sub_iteration=n_sub_iteration, mask_prob=mask_prob, **kwargs) - - def _get_prompt_and_multimasking_choices(self, current_iteration): - n_pos, n_neg = 0, 0 - get_boxes = True - multimask_output = False - return n_pos, n_neg, get_boxes, multimask_output - - def _get_prompt_and_multimasking_choices_for_val(self, current_iteration): - return self._get_prompt_and_multimasking_choices(current_iteration=current_iteration) diff --git a/micro_sam/training/simple_sam_trainer.py b/micro_sam/training/simple_sam_trainer.py new file mode 100644 index 000000000..984e41fac --- /dev/null +++ b/micro_sam/training/simple_sam_trainer.py @@ -0,0 +1,66 @@ +import random + +from . import SamTrainer + + +class SimpleSamTrainer(SamTrainer): + """Trainer class for creating a simple SAM trainer for limited prompt-based segmentation. + """ + def __init__( + self, + use_points: bool = True, + use_box: bool = True, + **kwargs + ): + super().__init__( + n_sub_iteration=1, + mask_prob=0, + **kwargs + ) + self.use_points = use_points + self.use_box = use_box + + if self.use_points and self.use_box: + self.random_prompt_choice = True + else: + self.random_prompt_choice = False + + assert (self.use_points + self.use_box) != 0, "Please choose at least one of the prompt-based method." + + def _choose_one_positive_point(self): + "samples only a single positive point per object" + n_pos, n_neg = 1, 0 + multimask_output = True + return n_pos, n_neg, None, multimask_output + + def _choose_box(self): + "samples only a single box per object" + n_pos, n_neg = 0, 0 + multimask_output = False + get_boxes = True + return n_pos, n_neg, get_boxes, multimask_output + + def _get_prompt_and_multimasking_choices(self, current_iteration): + + if self.random_prompt_choice: # both "use_points" and "use_box" are True + available_choices = [self._choose_one_positive_point(), self._choose_box()] + return random.choice(available_choices) + else: # either of "use_points" or "use_box" are True + if self.use_points: + return self._choose_one_positive_point() + else: + return self._choose_box() + + def _get_prompt_and_multimasking_choices_for_val(self, current_iteration): + return self._get_prompt_and_multimasking_choices(current_iteration) + + +class MedSAMTrainer(SimpleSamTrainer): + """Trainer class for replicating the trainer of MedSAM (https://arxiv.org/abs/2304.12306). + """ + def __init__(self, **kwargs): + super().__init__( + use_points=False, + use_box=True, + **kwargs + ) diff --git a/scripts/medsam/btcv_finetuning.py b/scripts/medsam/btcv_finetuning.py deleted file mode 100644 index 4040c515e..000000000 --- a/scripts/medsam/btcv_finetuning.py +++ /dev/null @@ -1,146 +0,0 @@ -import os -import argparse - -import torch - -from torch_em.loss.dice import BCEDiceLossWithLogits -from torch_em.data.datasets.medical import get_btcv_loader - -import micro_sam.training as sam_training -from micro_sam.util import export_custom_sam_model - - -def get_dataloaders(patch_shape, data_path): - """This returns the btcv data loaders implemented in torch_em: - https://github.com/constantinpape/torch-em/blob/main/torch_em/data/datasets/medical/btcv.py - It will not automatically download the BTCV data. Take a look at `get_btcv_dataset`. - - Note: to replace this with another data loader you need to return a torch data loader - that retuns `x, y` tensors, where `x` is the image data and `y` are the labels. - The labels have to be in a label mask instance segmentation format. - I.e. a tensor of the same spatial shape as `x`, with each object mask having its own ID. - Important: the ID 0 is reseved for background, and the IDs must be consecutive - """ - raw_transform = sam_training.identity - - train_loader = get_btcv_loader( - path=data_path, - patch_shape=patch_shape, - batch_size=2, - ndim=2, - anatomy=None, - organs=None, - raw_transform=raw_transform, - ) - val_loader = get_btcv_loader( - path=data_path, - patch_shape=patch_shape, - batch_size=1, - ndim=2, - anatomy=None, - organs=None, - raw_transform=raw_transform, - ) - return train_loader, val_loader - - -def finetune_btcv(args): - """Code for finetuning SAM on BTCV in "micro_sam"-based MedSAM reimplementation""" - # override this (below) if you have some more complex set-up and need to specify the exact gpu - device = "cuda" if torch.cuda.is_available() else "cpu" - - # training settings: - model_type = args.model_type - checkpoint_path = None # override this to start training from a custom checkpoint - patch_shape = (1, 512, 512) # the patch shape for training - n_objects_per_batch = args.n_objects # this is the number of objects per batch that will be sampled (default: 25) - freeze_parts = args.freeze # override this to freeze different parts of the model - - # get the trainable segment anything model - model = sam_training.get_trainable_sam_model( - model_type=model_type, - device=device, - checkpoint_path=checkpoint_path, - freeze=freeze_parts - ) - model.to(device) - - # all the stuff we need for training - optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5) - scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.9, patience=10, verbose=True) - train_loader, val_loader = get_dataloaders(patch_shape=patch_shape, data_path=args.input_path) - - # this class creates all the training data for a batch (inputs, prompts and labels) - convert_inputs = sam_training.ConvertToSamInputs(transform=model.transform, box_distortion_factor=0.025) - - checkpoint_name = f"{args.model_type}/btcv_medsam" - - # the trainer which performs the joint training and validation (implemented using "torch_em") - trainer = sam_training.MedSAMTrainer( - name=checkpoint_name, - save_root=args.save_root, - train_loader=train_loader, - val_loader=val_loader, - model=model, - optimizer=optimizer, - device=device, - lr_scheduler=scheduler, - logger=sam_training.SamLogger, - log_image_interval=100, - mixed_precision=True, - convert_inputs=convert_inputs, - n_objects_per_batch=n_objects_per_batch, - compile_model=False, - mask_loss=BCEDiceLossWithLogits(), - ) - trainer.fit(args.iterations, save_every_kth_epoch=args.save_every_kth_epoch) - if args.export_path is not None: - checkpoint_path = os.path.join( - "" if args.save_root is None else args.save_root, "checkpoints", checkpoint_name, "best.pt" - ) - export_custom_sam_model( - checkpoint_path=checkpoint_path, - model_type=model_type, - save_path=args.export_path, - ) - - -def main(): - parser = argparse.ArgumentParser(description="Finetune Segment Anything for the BTCV dataset.") - parser.add_argument( - "--input_path", "-i", default="/scratch/projects/nim00007/data/btcv/", - help="The filepath to the BTCV data. If the data does not exist yet it will be downloaded." - ) - parser.add_argument( - "--model_type", "-m", default="vit_b", - help="The model type to use for fine-tuning. Either vit_t, vit_b, vit_l or vit_h." - ) - parser.add_argument( - "--save_root", "-s", - help="Where to save the checkpoint and logs. By default they will be saved where this script is run." - ) - parser.add_argument( - "--iterations", type=int, default=int(1e4), - help="For how many iterations should the model be trained?" - ) - parser.add_argument( - "--export_path", "-e", - help="Where to export the finetuned model to. The exported model can be used in the annotation tools." - ) - parser.add_argument( - "--freeze", type=str, nargs="+", default=None, - help="Which parts of the model to freeze for finetuning." - ) - parser.add_argument( - "--save_every_kth_epoch", type=int, default=None, - help="To save every kth epoch while fine-tuning. Expects an integer value." - ) - parser.add_argument( - "--n_objects", type=int, default=25, help="The number of instances (objects) per batch used for finetuning." - ) - args = parser.parse_args() - finetune_btcv(args) - - -if __name__ == "__main__": - main() From e21006e0363c970bc04e563025d3f35c55d4c743 Mon Sep 17 00:00:00 2001 From: Anwai Archit <52396323+anwai98@users.noreply.github.com> Date: Wed, 12 Jun 2024 14:48:21 +0200 Subject: [PATCH 06/53] Update util.py --- micro_sam/training/util.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/micro_sam/training/util.py b/micro_sam/training/util.py index 89799617c..b58cd9a6d 100644 --- a/micro_sam/training/util.py +++ b/micro_sam/training/util.py @@ -205,7 +205,7 @@ def __call__(self, x, y, n_pos, n_neg, get_boxes=False, n_samples=None): class ResizeRawTrafo: - def __init__(self, desired_shape, do_rescaling=True, padding="constant"): + def __init__(self, desired_shape, do_rescaling=False, padding="constant"): self.desired_shape = desired_shape self.padding = padding self.do_rescaling = do_rescaling From a75d581ec21670249a5ea24168e1d5767f57d29b Mon Sep 17 00:00:00 2001 From: Anwai Archit <52396323+anwai98@users.noreply.github.com> Date: Wed, 12 Jun 2024 15:47:28 +0200 Subject: [PATCH 07/53] Minor update to default rescaling params in resizerawtrafo (#635) --- .../training/light_microscopy/obtain_lm_datasets.py | 3 ++- .../training/light_microscopy/tissuenet_finetuning.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/finetuning/generalists/training/light_microscopy/obtain_lm_datasets.py b/finetuning/generalists/training/light_microscopy/obtain_lm_datasets.py index 083f8c218..8ac629b4f 100644 --- a/finetuning/generalists/training/light_microscopy/obtain_lm_datasets.py +++ b/finetuning/generalists/training/light_microscopy/obtain_lm_datasets.py @@ -61,7 +61,8 @@ def get_ctc_datasets( datasets.get_tissuenet_dataset( path=os.path.join(input_path, "tissuenet"), split=split_choice, download=True, patch_shape=patch_shape, raw_channel="rgb", label_channel="cell", sampler=sampler, label_dtype=label_dtype, - raw_transform=ResizeRawTrafo(patch_shape), label_transform=ResizeLabelTrafo(patch_shape, min_size=0), + raw_transform=ResizeRawTrafo(patch_shape, do_rescaling=True), + label_transform=ResizeLabelTrafo(patch_shape, min_size=0), n_samples=1000 if split_choice == "train" else 100 ), datasets.get_livecell_dataset( diff --git a/finetuning/specialists/training/light_microscopy/tissuenet_finetuning.py b/finetuning/specialists/training/light_microscopy/tissuenet_finetuning.py index 5b3b9cc90..b6f752906 100644 --- a/finetuning/specialists/training/light_microscopy/tissuenet_finetuning.py +++ b/finetuning/specialists/training/light_microscopy/tissuenet_finetuning.py @@ -25,7 +25,7 @@ def get_dataloaders(patch_shape, data_path): I.e. a tensor of the same spatial shape as `x`, with each object mask having its own ID. Important: the ID 0 is reseved for background, and the IDs must be consecutive """ - raw_transform = ResizeRawTrafo(patch_shape) + raw_transform = ResizeRawTrafo(patch_shape, do_rescaling=True) label_transform = ResizeLabelTrafo(patch_shape) sampler = MinInstanceSampler() label_dtype = torch.float32 From 22edc30e400f7a7d76eec30b02415f3770ea6266 Mon Sep 17 00:00:00 2001 From: Anwai Archit <52396323+anwai98@users.noreply.github.com> Date: Wed, 19 Jun 2024 09:15:55 +0200 Subject: [PATCH 08/53] Add LoRA Implementation (#611) Add LoRA based PEFT finetuning --- finetuning/livecell/lora/train_livecell.py | 184 +++++++++++++++++++++ micro_sam/training/peft_sam.py | 105 ++++++++++++ micro_sam/training/util.py | 11 ++ micro_sam/util.py | 11 ++ test/test_peft_training.py | 49 ++++++ 5 files changed, 360 insertions(+) create mode 100644 finetuning/livecell/lora/train_livecell.py create mode 100644 micro_sam/training/peft_sam.py create mode 100644 test/test_peft_training.py diff --git a/finetuning/livecell/lora/train_livecell.py b/finetuning/livecell/lora/train_livecell.py new file mode 100644 index 000000000..fa8874372 --- /dev/null +++ b/finetuning/livecell/lora/train_livecell.py @@ -0,0 +1,184 @@ +import os +import argparse + +import torch + +from torch_em.model import UNETR +from torch_em.loss import DiceBasedDistanceLoss +from torch_em.data.datasets import get_livecell_loader +from torch_em.transform.label import PerObjectDistanceTransform + +import micro_sam.training as sam_training +from micro_sam.util import export_custom_sam_model + + +def get_dataloaders(patch_shape, data_path, cell_type=None): + """This returns the livecell data loaders implemented in torch_em: + https://github.com/constantinpape/torch-em/blob/main/torch_em/data/datasets/livecell.py + It will automatically download the livecell data. + + Note: to replace this with another data loader you need to return a torch data loader + that retuns `x, y` tensors, where `x` is the image data and `y` are the labels. + The labels have to be in a label mask instance segmentation format. + I.e. a tensor of the same spatial shape as `x`, with each object mask having its own ID. + Important: the ID 0 is reseved for background, and the IDs must be consecutive + """ + label_transform = PerObjectDistanceTransform( + distances=True, boundary_distances=True, directed_distances=False, foreground=True, instances=True, min_size=25 + ) + raw_transform = sam_training.identity # the current workflow avoids rescaling the inputs to [-1, 1] + train_loader = get_livecell_loader( + path=data_path, patch_shape=patch_shape, split="train", batch_size=2, num_workers=16, + cell_types=cell_type, download=True, shuffle=True, label_transform=label_transform, + raw_transform=raw_transform, label_dtype=torch.float32, + ) + val_loader = get_livecell_loader( + path=data_path, patch_shape=patch_shape, split="val", batch_size=4, num_workers=16, + cell_types=cell_type, download=True, shuffle=True, label_transform=label_transform, + raw_transform=raw_transform, label_dtype=torch.float32, + ) + + return train_loader, val_loader + + +def count_parameters(model): + params = sum(p.numel() for p in model.parameters() if p.requires_grad) + params = params / 1e6 + return f"The number of trainable parameters for the provided model is {round(params, 2)}M" + + +def finetune_livecell(args): + """Code for finetuning SAM (using LoRA) on LIVECell + + Initial observations: There's no real memory advantage actually unless it's "truly" scaled up + # vit_b + # SAM: 93M (takes ~50GB) + # SAM-LoRA: 4.2M (takes ~49GB) + + # vit_l + # SAM: 312M (takes ~63GB) + # SAM-LoRA: 4.4M (takes ~61GB) + + # vit_h + # SAM: 641M (takes ~73GB) + # SAM-LoRA: 4.7M (takes ~67GB) + + # Q: Would quantization lead to better results? (eg. QLoRA / DoRA) + """ + # override this (below) if you have some more complex set-up and need to specify the exact gpu + device = "cuda" if torch.cuda.is_available() else "cpu" + + # training settings: + model_type = args.model_type + checkpoint_path = None # override this to start training from a custom checkpoint + patch_shape = (520, 704) # the patch shape for training + n_objects_per_batch = 5 # this is the number of objects per batch that will be sampled + freeze_parts = args.freeze # override this to freeze different parts of the model + rank = 4 # the rank + + # get the trainable segment anything model + model = sam_training.get_trainable_sam_model( + model_type=model_type, + device=device, + checkpoint_path=checkpoint_path, + freeze=freeze_parts, + use_lora=True, + rank=rank, + ) + model.to(device) + + # let's get the UNETR model for automatic instance segmentation pipeline + unetr = UNETR( + backbone="sam", + encoder=model.sam.image_encoder, + out_channels=3, + use_sam_stats=True, + final_activation="Sigmoid", + use_skip_connection=False, + resize_input=True, + ) + unetr.to(device) + + # let's check the total number of trainable parameters + print(count_parameters(model)) + + # let's get the parameters for SAM and the decoder from UNETR + joint_model_params = model.parameters() + + joint_model_params = [params for params in joint_model_params] # sam parameters + for name, params in unetr.named_parameters(): # unetr's decoder parameters + if not name.startswith("encoder"): + joint_model_params.append(params) + + optimizer = torch.optim.Adam(joint_model_params, lr=1e-5) + scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.9, patience=10) + train_loader, val_loader = get_dataloaders(patch_shape=patch_shape, data_path=args.input_path) + + # this class creates all the training data for a batch (inputs, prompts and labels) + convert_inputs = sam_training.ConvertToSamInputs(transform=model.transform, box_distortion_factor=0.025) + + trainer = sam_training.JointSamTrainer( + name="livecell_lora", + save_root=args.save_root, + train_loader=train_loader, + val_loader=val_loader, + model=model, + optimizer=optimizer, + device=device, + lr_scheduler=scheduler, + logger=sam_training.JointSamLogger, + log_image_interval=100, + mixed_precision=True, + convert_inputs=convert_inputs, + n_objects_per_batch=n_objects_per_batch, + n_sub_iteration=8, + compile_model=False, + mask_prob=0.5, # (optional) overwrite to provide the probability of using mask inputs while training + unetr=unetr, + instance_loss=DiceBasedDistanceLoss(mask_distances_in_bg=True), + instance_metric=DiceBasedDistanceLoss(mask_distances_in_bg=True) + ) + trainer.fit(args.iterations) + if args.export_path is not None: + checkpoint_path = os.path.join( + "" if args.save_root is None else args.save_root, "checkpoints", args.name, "best.pt" + ) + export_custom_sam_model( + checkpoint_path=checkpoint_path, + model_type=model_type, + save_path=args.export_path, + ) + + +def main(): + parser = argparse.ArgumentParser(description="Finetune Segment Anything for the LiveCELL dataset.") + parser.add_argument( + "--input_path", "-i", default="/scratch/projects/nim00007/sam/data/livecell/", + help="The filepath to the LiveCELL data. If the data does not exist yet it will be downloaded." + ) + parser.add_argument( + "--model_type", "-m", default="vit_b", + help="The model type to use for fine-tuning. Either vit_h, vit_b or vit_l." + ) + parser.add_argument( + "--save_root", "-s", default=None, + help="Where to save the checkpoint and logs. By default they will be saved where this script is run." + ) + parser.add_argument( + "--iterations", type=int, default=int(1e4), + help="For how many iterations should the model be trained? By default 100k." + ) + parser.add_argument( + "--export_path", "-e", + help="Where to export the finetuned model to. The exported model can be used in the annotation tools." + ) + parser.add_argument( + "--freeze", type=str, nargs="+", default=None, + help="Which parts of the model to freeze for finetuning." + ) + args = parser.parse_args() + finetune_livecell(args) + + +if __name__ == "__main__": + main() diff --git a/micro_sam/training/peft_sam.py b/micro_sam/training/peft_sam.py new file mode 100644 index 000000000..c67db7cbf --- /dev/null +++ b/micro_sam/training/peft_sam.py @@ -0,0 +1,105 @@ +import math +from typing import List, Union + +import torch.nn as nn + +from segment_anything.modeling import Sam + + +class LoRASurgery(nn.Module): + """Operates on the attention layers for performing low-rank adaptation. + + (Inspired from: https://github.com/JamesQFreeman/Sam_LoRA/) + + In SAM, it is implemented as: + ```python + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) + q, k, v = qkv.unbind(0) + ``` + """ + def __init__( + self, + rank: int, + block: nn.Module, + ): + super().__init__() + self.qkv = block.attn.qkv + self.dim = self.qkv.in_features + + self.w_a_linear_q = nn.Linear(self.dim, rank, bias=False) + self.w_b_linear_q = nn.Linear(rank, self.dim, bias=False) + self.w_a_linear_v = nn.Linear(self.dim, rank, bias=False) + self.w_b_linear_v = nn.Linear(rank, self.dim, bias=False) + + self.reset_parameters() + + block.attn.qkv = self + + def reset_parameters(self): + nn.init.kaiming_uniform_(self.w_a_linear_q.weight, a=math.sqrt(5)) + nn.init.kaiming_uniform_(self.w_a_linear_v.weight, a=math.sqrt(5)) + nn.init.zeros_(self.w_b_linear_q.weight) + nn.init.zeros_(self.w_b_linear_v.weight) + + def forward(self, x): + qkv = self.qkv(x) # B, N, N, 3 * org_C + new_q = self.w_b_linear_q(self.w_a_linear_q(x)) + new_v = self.w_b_linear_v(self.w_a_linear_v(x)) + qkv[:, :, :, :self.dim] += new_q + qkv[:, :, :, -self.dim:] += new_v + return qkv + + +class PEFT_Sam(nn.Module): + """Inspired from: https://github.com/JamesQFreeman/Sam_LoRA/ + + Wraps the Segment Anything model's image encoder to different parameter efficient finetuning methods. + + Args: + model: The Segment Anything model. + rank: The rank for low-rank adaptation. + peft_module: Wrapper to operate on the image encoder blocks for the PEFT method. + attention_layers_to_update: Which specific layers we apply PEFT methods to. + """ + + def __init__( + self, + model: Sam, + rank: int, + peft_module: nn.Module = LoRASurgery, + attention_layers_to_update: Union[List[int]] = None + ): + super(PEFT_Sam, self).__init__() + + assert rank > 0 + + if attention_layers_to_update: + self.peft_layers = attention_layers_to_update + else: # Applies PEFT to the image encoder by default + self.peft_layers = list( + range(len(model.image_encoder.blocks)) + ) + + self.peft_module = peft_module + self.peft_blocks = [] + + # let's freeze all the pretrained image encoder layers first + for param in model.image_encoder.parameters(): + param.requires_grad = False + + for t_layer_i, blk in enumerate(model.image_encoder.blocks): + # If we only want specific layers with PEFT instead of all + if t_layer_i not in self.peft_layers: + continue + + peft_block = self.peft_module(rank=rank, block=blk) + self.peft_blocks.append(peft_block) + + self.peft_blocks = nn.ModuleList(self.peft_blocks) + + self.sam = model + + def forward(self, batched_input, multimask_output): + return self.sam(batched_input, multimask_output) diff --git a/micro_sam/training/util.py b/micro_sam/training/util.py index b58cd9a6d..ac9bda9bd 100644 --- a/micro_sam/training/util.py +++ b/micro_sam/training/util.py @@ -12,6 +12,7 @@ get_centers_and_bounding_boxes, get_sam_model, get_device, segmentation_to_one_hot, _DEFAULT_MODEL, ) +from .peft_sam import PEFT_Sam from .trainable_sam import TrainableSAM from torch_em.transform.label import PerObjectDistanceTransform @@ -42,6 +43,8 @@ def get_trainable_sam_model( checkpoint_path: Optional[Union[str, os.PathLike]] = None, freeze: Optional[List[str]] = None, return_state: bool = False, + use_lora: bool = False, + rank: Optional[int] = None, ) -> TrainableSAM: """Get the trainable sam model. @@ -54,6 +57,8 @@ def get_trainable_sam_model( freeze: Specify parts of the model that should be frozen, namely: image_encoder, prompt_encoder and mask_decoder By default nothing is frozen and the full model is updated. return_state: Whether to return the full checkpoint state. + use_lora: Whether to use the low rank adaptation method for finetuning. + rank: The rank of the decomposition matrices for updating weights in each attention layer. Returns: The trainable segment anything model. @@ -80,8 +85,14 @@ def get_trainable_sam_model( if name.startswith(f"{freeze}"): param.requires_grad = False + if use_lora: # overwrites the SAM model by freezing the backbone and allow low rank adaption to attention layers + if rank is None: + rank = 4 # HACK: in case the user does not pass the rank, we provide a random rank to them + sam = PEFT_Sam(sam, rank=rank).sam + # convert to trainable sam trainable_sam = TrainableSAM(sam) + if return_state: return trainable_sam, state return trainable_sam diff --git a/micro_sam/util.py b/micro_sam/util.py index c9768dc84..e61a28f7f 100644 --- a/micro_sam/util.py +++ b/micro_sam/util.py @@ -270,6 +270,8 @@ def get_sam_model( checkpoint_path: Optional[Union[str, os.PathLike]] = None, return_sam: bool = False, return_state: bool = False, + use_lora: bool = False, + rank: Optional[int] = None, ) -> SamPredictor: r"""Get the SegmentAnything Predictor. @@ -302,6 +304,8 @@ def get_sam_model( then `model_type` must be given as "vit_b". return_sam: Return the sam model object as well as the predictor. return_state: Return the unpickled checkpoint state. + use_lora: Whether to use the low rank adaptation method for finetuning. + rank: The rank of the decomposition matrices for updating weights in each attention layer. Returns: The segment anything predictor. @@ -347,6 +351,13 @@ def get_sam_model( state, model_state = _load_checkpoint(checkpoint_path) sam = sam_model_registry[abbreviated_model_type]() + + if use_lora: # overwrites the SAM model by freezing the backbone and allow low rank adaption to attention layers + from micro_sam.training.peft_sam import PEFT_Sam + if rank is None: + rank = 4 # HACK: in case the user does not pass the rank, we provide a random rank to them + sam = PEFT_Sam(sam, rank=rank).sam + sam.load_state_dict(model_state) sam.to(device=device) diff --git a/test/test_peft_training.py b/test/test_peft_training.py new file mode 100644 index 000000000..7c2f12702 --- /dev/null +++ b/test/test_peft_training.py @@ -0,0 +1,49 @@ +import unittest + +import torch + +from micro_sam.util import get_sam_model +from micro_sam.training.peft_sam import PEFT_Sam + + +class TestPEFTModule(unittest.TestCase): + """Integraton test for instantiating a PEFT SAM model. + """ + def _fetch_sam_model(self, model_type, device): + _, sam_model = get_sam_model(model_type=model_type, device=device, return_sam=True) + return sam_model + + def _create_dummy_inputs(self, shape): + input_image = torch.ones(shape) + return input_image + + def test_peft_sam(self): + model_type = "vit_b" + device = "cpu" + + # Load the dummy inputs. + input_shape = (1, 512, 512) + inputs = self._create_dummy_inputs(shape=input_shape) + + # Convert to the inputs expected by Segment Anything + batched_inputs = [ + {"image": inputs, "original_size": input_shape[1:]} + ] + + # Load the Segment Anything model. + sam_model = self._fetch_sam_model(model_type=model_type, device=device) + + # Wrap the Segment Anything model with PEFT methods. + peft_sam_model = PEFT_Sam(model=sam_model, rank=4) + + # Get the model outputs + outputs = peft_sam_model(batched_input=batched_inputs, multimask_output=False) + + # Check the expected shape of the outputs + mask_shapes = [output["masks"].shape[-2:] for output in outputs] + for shape in mask_shapes: + self.assertEqual(shape, input_shape[1:]) + + +if __name__ == "__main__": + unittest.main() From 14f9f237923ac427becb0279edc67cc3118a5427 Mon Sep 17 00:00:00 2001 From: Anwai Archit <52396323+anwai98@users.noreply.github.com> Date: Fri, 21 Jun 2024 16:44:54 +0200 Subject: [PATCH 09/53] Add SemanticSamTrainer (#637) Add semantic sam trainer --- micro_sam/training/__init__.py | 1 + micro_sam/training/models/__init__.py | 0 micro_sam/training/models/build_sam.py | 119 +++++++++++++++++++++ micro_sam/training/semantic_sam_trainer.py | 95 ++++++++++++++++ micro_sam/training/util.py | 29 ++++- micro_sam/util.py | 53 ++++++++- 6 files changed, 294 insertions(+), 3 deletions(-) create mode 100644 micro_sam/training/models/__init__.py create mode 100644 micro_sam/training/models/build_sam.py create mode 100644 micro_sam/training/semantic_sam_trainer.py diff --git a/micro_sam/training/__init__.py b/micro_sam/training/__init__.py index 1c354b658..e825ba630 100644 --- a/micro_sam/training/__init__.py +++ b/micro_sam/training/__init__.py @@ -5,4 +5,5 @@ from .util import ConvertToSamInputs, get_trainable_sam_model, identity from .joint_sam_trainer import JointSamTrainer, JointSamLogger from .simple_sam_trainer import SimpleSamTrainer, MedSAMTrainer +from .semantic_sam_trainer import SemanticSamTrainer from .training import train_sam, train_sam_for_configuration, default_sam_loader, default_sam_dataset, CONFIGURATIONS diff --git a/micro_sam/training/models/__init__.py b/micro_sam/training/models/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/micro_sam/training/models/build_sam.py b/micro_sam/training/models/build_sam.py new file mode 100644 index 000000000..525b20db4 --- /dev/null +++ b/micro_sam/training/models/build_sam.py @@ -0,0 +1,119 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# +# NOTE: This code has been adapted from Segment Anything. +# - https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/build_sam.py +# This is done in favor of exposing some of the model's hard-coded input parameters for: +# - downstream applications (eg. updating the "num_multimask_outputs" for multi-class semantic segmentation) +# + + +import torch + +from functools import partial + +from segment_anything.modeling import Sam, ImageEncoderViT, PromptEncoder, MaskDecoder, TwoWayTransformer + + +def build_sam_vit_h(checkpoint=None, num_multimask_outputs=3): + return _build_sam( + encoder_embed_dim=1280, + encoder_depth=32, + encoder_num_heads=16, + encoder_global_attn_indexes=[7, 15, 23, 31], + checkpoint=checkpoint, + num_multimask_outputs=num_multimask_outputs, + ) + + +build_sam = build_sam_vit_h + + +def build_sam_vit_l(checkpoint=None, num_multimask_outputs=3): + return _build_sam( + encoder_embed_dim=1024, + encoder_depth=24, + encoder_num_heads=16, + encoder_global_attn_indexes=[5, 11, 17, 23], + checkpoint=checkpoint, + num_multimask_outputs=num_multimask_outputs, + ) + + +def build_sam_vit_b(checkpoint=None, num_multimask_outputs=3): + return _build_sam( + encoder_embed_dim=768, + encoder_depth=12, + encoder_num_heads=12, + encoder_global_attn_indexes=[2, 5, 8, 11], + checkpoint=checkpoint, + num_multimask_outputs=num_multimask_outputs, + ) + + +sam_model_registry = { + "default": build_sam_vit_h, + "vit_h": build_sam_vit_h, + "vit_l": build_sam_vit_l, + "vit_b": build_sam_vit_b, +} + + +def _build_sam( + encoder_embed_dim, + encoder_depth, + encoder_num_heads, + encoder_global_attn_indexes, + checkpoint=None, + num_multimask_outputs=3, +): + prompt_embed_dim = 256 + image_size = 1024 + vit_patch_size = 16 + image_embedding_size = image_size // vit_patch_size + sam = Sam( + image_encoder=ImageEncoderViT( + depth=encoder_depth, + embed_dim=encoder_embed_dim, + img_size=image_size, + mlp_ratio=4, + norm_layer=partial(torch.nn.LayerNorm, eps=1e-6), + num_heads=encoder_num_heads, + patch_size=vit_patch_size, + qkv_bias=True, + use_rel_pos=True, + global_attn_indexes=encoder_global_attn_indexes, + window_size=14, + out_chans=prompt_embed_dim, + ), + prompt_encoder=PromptEncoder( + embed_dim=prompt_embed_dim, + image_embedding_size=(image_embedding_size, image_embedding_size), + input_image_size=(image_size, image_size), + mask_in_chans=16, + ), + mask_decoder=MaskDecoder( + num_multimask_outputs=num_multimask_outputs, + transformer=TwoWayTransformer( + depth=2, + embedding_dim=prompt_embed_dim, + mlp_dim=2048, + num_heads=8, + ), + transformer_dim=prompt_embed_dim, + iou_head_depth=3, + iou_head_hidden_dim=256, + ), + pixel_mean=[123.675, 116.28, 103.53], + pixel_std=[58.395, 57.12, 57.375], + ) + sam.eval() + if checkpoint is not None: + with open(checkpoint, "rb") as f: + state_dict = torch.load(f) + sam.load_state_dict(state_dict) + return sam diff --git a/micro_sam/training/semantic_sam_trainer.py b/micro_sam/training/semantic_sam_trainer.py new file mode 100644 index 000000000..09a9151e5 --- /dev/null +++ b/micro_sam/training/semantic_sam_trainer.py @@ -0,0 +1,95 @@ +import time + +import torch +import torch.nn as nn + +from torch_em.loss import DiceLoss +from torch_em.trainer import DefaultTrainer + + +class SemanticSamTrainer(DefaultTrainer): + """ + """ + def __init__( + self, + convert_inputs, + num_classes: int = 1, + **kwargs + ): + loss = DiceLoss() + metric = DiceLoss() + super().__init__(loss=loss, metric=metric, **kwargs) + + self.convert_inputs = convert_inputs + self.num_classes = num_classes + self.compute_ce_loss = nn.BCELoss() if num_classes == 1 else nn.CrossEntropyLoss() + self._kwargs = kwargs + + def _compute_loss(self, y, masks): + # Compute dice loss for the predictions + dice_loss = self.loss(masks, y.to(self.device, non_blocking=True)) + + # Compute cross entropy loss for the predictions + ce_loss = self.compute_ce_loss(masks, y.to(self.device, non_blocking=True)) + + net_loss = dice_loss + ce_loss + return net_loss + + def _get_model_outputs(self, batched_inputs): + image_embeddings, batched_inputs = self.model.image_embeddings_oft(batched_inputs) + batched_outputs = self.model(batched_inputs, image_embeddings, multimask_output=(self.num_classes > 1)) + masks = torch.stack([output["masks"].squeeze(0) for output in batched_outputs]) + return masks + + def _train_epoch_impl(self, progress, forward_context, backprop): + self.model.train() + + t_per_iter = time.time() + for x, y in self.train_loader: + self.optimizer.zero_grad() + + batched_inputs = self.convert_inputs(x, y) + + with forward_context(): + masks = self._get_model_outputs(batched_inputs) + net_loss = self._compute_loss(y, masks) + + backprop(net_loss) + + if self.logger is not None: + lr = [pm["lr"] for pm in self.optimizer.param_groups][0] + self.logger.log_train(self._iteration, net_loss, lr, x, y, masks, log_gradients=True) + + self._iteration += 1 + if self._iteration >= self.max_iteration: + break + progress.update(1) + + t_per_iter = (time.time() - t_per_iter) + return t_per_iter + + def _validate_impl(self, forward_context): + self.model.eval() + + metric_val, loss_val = 0.0, 0.0 + + with torch.no_grad(): + for x, y in self.val_loader: + batched_inputs = self.convert_inputs(x, y) + + with forward_context(): + masks = self._get_model_outputs(batched_inputs) + net_loss = self._compute_loss(y, masks) + + loss_val += net_loss.item() + metric_val += net_loss.item() + + loss_val /= len(self.val_loader) + metric_val /= len(self.val_loader) + print() + print(f"The Average Validation Metric Score for the Current Epoch is {1 - metric_val}") + + if self.logger is not None: + self.logger.log_validation(self._iteration, metric_val, loss_val, x, y, masks) + + return metric_val diff --git a/micro_sam/training/util.py b/micro_sam/training/util.py index ac9bda9bd..3e4f01e31 100644 --- a/micro_sam/training/util.py +++ b/micro_sam/training/util.py @@ -3,6 +3,7 @@ from typing import List, Optional, Union import numpy as np + import torch from segment_anything.utils.transforms import ResizeLongestSide @@ -45,6 +46,8 @@ def get_trainable_sam_model( return_state: bool = False, use_lora: bool = False, rank: Optional[int] = None, + flexible_load_checkpoint: bool = False, + **model_kwargs ) -> TrainableSAM: """Get the trainable sam model. @@ -59,6 +62,7 @@ def get_trainable_sam_model( return_state: Whether to return the full checkpoint state. use_lora: Whether to use the low rank adaptation method for finetuning. rank: The rank of the decomposition matrices for updating weights in each attention layer. + flexible_load_checkpoint: Whether to adjust mismatching params while loading pretrained checkpoints. Returns: The trainable segment anything model. @@ -66,7 +70,15 @@ def get_trainable_sam_model( # set the device here so that the correct one is passed to TrainableSAM below device = get_device(device) _, sam, state = get_sam_model( - model_type=model_type, device=device, checkpoint_path=checkpoint_path, return_sam=True, return_state=True + model_type=model_type, + device=device, + checkpoint_path=checkpoint_path, + return_sam=True, + return_state=True, + use_lora=use_lora, + rank=rank, + flexible_load_checkpoint=flexible_load_checkpoint, + **model_kwargs ) # freeze components of the model if freeze was passed @@ -85,6 +97,7 @@ def get_trainable_sam_model( if name.startswith(f"{freeze}"): param.requires_grad = False + # Whether to use Parameter Efficient Finetuning methods to wrap around Segment Anything if use_lora: # overwrites the SAM model by freezing the backbone and allow low rank adaption to attention layers if rank is None: rank = 4 # HACK: in case the user does not pass the rank, we provide a random rank to them @@ -210,6 +223,20 @@ def __call__(self, x, y, n_pos, n_neg, get_boxes=False, n_samples=None): return batched_inputs, batched_sampled_cell_ids_list +class ConvertToSemanticSamInputs: + """ + """ + def __call__(self, x, y): + """Convert the outputs of dataloader to the batched format of inputs expected by SAM. + """ + batched_inputs = [] + for image, gt in zip(x, y): + batched_input = {"image": image, "original_size": image.shape[1:]} + batched_inputs.append(batched_input) + + return batched_inputs + + # # Raw and Label Transformations for the Generalist and Specialist finetuning # diff --git a/micro_sam/util.py b/micro_sam/util.py index e61a28f7f..370fcca2a 100644 --- a/micro_sam/util.py +++ b/micro_sam/util.py @@ -272,6 +272,8 @@ def get_sam_model( return_state: bool = False, use_lora: bool = False, rank: Optional[int] = None, + flexible_load_checkpoint: bool = False, + **model_kwargs, ) -> SamPredictor: r"""Get the SegmentAnything Predictor. @@ -306,6 +308,7 @@ def get_sam_model( return_state: Return the unpickled checkpoint state. use_lora: Whether to use the low rank adaptation method for finetuning. rank: The rank of the decomposition matrices for updating weights in each attention layer. + flexible_load_checkpoint: Whether to adjust mismatching params while loading pretrained checkpoints. Returns: The segment anything predictor. @@ -350,15 +353,29 @@ def get_sam_model( ) state, model_state = _load_checkpoint(checkpoint_path) - sam = sam_model_registry[abbreviated_model_type]() + # Whether to update parameters necessary to initialize the model + if model_kwargs: # Checks whether model_kwargs have been provided or not + if abbreviated_model_type == "vit_t": + raise ValueError("'micro-sam' does not allow changing the model parameters for 'mobile-sam'.") + + from micro_sam.training.models.build_sam import sam_model_registry # noqa + + sam = sam_model_registry[abbreviated_model_type](**model_kwargs) + + # Whether to use Parameter Efficient Finetuning methods to wrap around Segment Anything if use_lora: # overwrites the SAM model by freezing the backbone and allow low rank adaption to attention layers from micro_sam.training.peft_sam import PEFT_Sam if rank is None: rank = 4 # HACK: in case the user does not pass the rank, we provide a random rank to them sam = PEFT_Sam(sam, rank=rank).sam - sam.load_state_dict(model_state) + # In case the model checkpoints have some issues when it is initialized with different parameters than default. + if flexible_load_checkpoint: + sam = _handle_checkpoint_loading(sam, model_state) + else: + sam.load_state_dict(model_state) + sam.to(device=device) predictor = SamPredictor(sam) @@ -379,6 +396,38 @@ def get_sam_model( return predictor +def _handle_checkpoint_loading(sam, model_state): + # Whether to handle the mismatch issues in a bit more elegant way. + # eg. while training for multi-class semantic segmentation in the mask encoder, + # parameters are updated - leading to "size mismatch" errors + + new_state_dict = {} # for loading matching parameters + mismatched_layers = [] # for tracking mismatching parameters + + reference_state = sam.state_dict() + + for k, v in model_state.items(): + if reference_state[k].size() == v.size(): + new_state_dict[k] = v + else: + mismatched_layers.append(k) + + reference_state.update(new_state_dict) + + if len(mismatched_layers) > 0: + warnings.warn(f"The layers with size mismatch: {mismatched_layers}") + + for mlayer in mismatched_layers: + if 'weight' in mlayer: + torch.nn.init.kaiming_uniform_(reference_state[mlayer]) + elif 'bias' in mlayer: + reference_state[mlayer].zero_() + + sam.load_state_dict(reference_state) + + return sam + + def export_custom_sam_model( checkpoint_path: Union[str, os.PathLike], model_type: str, From 8e9750c064dea96254b30e186bfde12d38614640 Mon Sep 17 00:00:00 2001 From: Anwai Archit <52396323+anwai98@users.noreply.github.com> Date: Mon, 24 Jun 2024 22:35:05 +0200 Subject: [PATCH 10/53] Minor update to loading custon build_sam models (#640) --- micro_sam/util.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/micro_sam/util.py b/micro_sam/util.py index 370fcca2a..d96ec0ec0 100644 --- a/micro_sam/util.py +++ b/micro_sam/util.py @@ -357,11 +357,13 @@ def get_sam_model( # Whether to update parameters necessary to initialize the model if model_kwargs: # Checks whether model_kwargs have been provided or not if abbreviated_model_type == "vit_t": - raise ValueError("'micro-sam' does not allow changing the model parameters for 'mobile-sam'.") + raise ValueError("'micro-sam' does not support changing the model parameters for 'mobile-sam'.") - from micro_sam.training.models.build_sam import sam_model_registry # noqa + from .training.models import build_sam + sam = build_sam.sam_model_registry[abbreviated_model_type](**model_kwargs) - sam = sam_model_registry[abbreviated_model_type](**model_kwargs) + else: + sam = sam_model_registry[abbreviated_model_type]() # Whether to use Parameter Efficient Finetuning methods to wrap around Segment Anything if use_lora: # overwrites the SAM model by freezing the backbone and allow low rank adaption to attention layers From 5e16964d3da61b5d94d3bb28b3d358cefeb6a1e6 Mon Sep 17 00:00:00 2001 From: Anwai Archit <52396323+anwai98@users.noreply.github.com> Date: Wed, 26 Jun 2024 09:32:54 +0200 Subject: [PATCH 11/53] Minor fix to loading models with incompatible layers (#641) --- micro_sam/util.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/micro_sam/util.py b/micro_sam/util.py index d96ec0ec0..75ebe724d 100644 --- a/micro_sam/util.py +++ b/micro_sam/util.py @@ -409,10 +409,11 @@ def _handle_checkpoint_loading(sam, model_state): reference_state = sam.state_dict() for k, v in model_state.items(): - if reference_state[k].size() == v.size(): - new_state_dict[k] = v - else: - mismatched_layers.append(k) + if k in reference_state: # This is done to get rid of unwanted layers from pretrained SAM. + if reference_state[k].size() == v.size(): + new_state_dict[k] = v + else: + mismatched_layers.append(k) reference_state.update(new_state_dict) From a9cb6c73566fa6b1172a3e7d369d54177f8cb598 Mon Sep 17 00:00:00 2001 From: lufre1 <155526548+lufre1@users.noreply.github.com> Date: Wed, 26 Jun 2024 13:55:30 +0200 Subject: [PATCH 12/53] 636 enhance 3d image processing based on ma sam (#639) Implement 3D SAM Wrapper based on MA-SAM --------- Co-authored-by: Constantin Pape --- development/check_3d_model.py | 81 +++++++++ development/instance_segmentation_3d.py | 154 +++++++++++++++++ micro_sam/sam_3d_wrapper.py | 192 +++++++++++++++++++++ micro_sam/training/models/build_sam.py | 11 +- micro_sam/training/semantic_sam_trainer.py | 14 ++ 5 files changed, 448 insertions(+), 4 deletions(-) create mode 100644 development/check_3d_model.py create mode 100644 development/instance_segmentation_3d.py create mode 100644 micro_sam/sam_3d_wrapper.py diff --git a/development/check_3d_model.py b/development/check_3d_model.py new file mode 100644 index 000000000..ac49609cc --- /dev/null +++ b/development/check_3d_model.py @@ -0,0 +1,81 @@ +import numpy as np +import torch +import micro_sam.util as util + +from micro_sam.sam_3d_wrapper import get_3d_sam_model +from micro_sam.training.semantic_sam_trainer import SemanticSamTrainer3D + + +def predict_3d_model(): + d_size = 8 + device = "cuda" if torch.cuda.is_available() else "cpu" + sam_3d = get_3d_sam_model(device, d_size) + + input_ = 255 * np.random.rand(1, d_size, 3, 1024, 1024).astype("float32") + with torch.no_grad(): + input_ = torch.from_numpy(input_).to(device) + out = sam_3d(input_, multimask_output=False, image_size=1024) + print(out["masks"].shape) + + +class DummyDataset(torch.utils.data.Dataset): + def __init__(self, patch_shape, n_classes): + self.patch_shape = patch_shape + self.n_classes = n_classes + + def __len__(self): + return 5 + + def __getitem__(self, index): + image_shape = (self.patch_shape[0], 3) + self.patch_shape[1:] + x = np.random.rand(*image_shape).astype("float32") + label_shape = (self.n_classes,) + self.patch_shape + y = (np.random.rand(*label_shape) > 0.5).astype("float32") + return x, y + + +def get_loader(patch_shape, n_classes, batch_size): + ds = DummyDataset(patch_shape, n_classes) + loader = torch.utils.data.DataLoader(ds, batch_size=batch_size, shuffle=True, num_workers=4) + loader.shuffle = True + return loader + + +# TODO: we are missing the resizing in the model, so currently this only supports 1024x1024 +def train_3d_model(): + from micro_sam.training.util import ConvertToSemanticSamInputs + + d_size = 4 + n_classes = 5 + batch_size = 2 + image_size = 512 + + device = "cuda" if torch.cuda.is_available() else "cpu" + sam_3d = get_3d_sam_model(device, n_classes=n_classes, image_size=image_size) + + train_loader = get_loader((d_size, image_size, image_size), n_classes, batch_size) + val_loader = get_loader((d_size, image_size, image_size), n_classes, batch_size) + + optimizer = torch.optim.AdamW(sam_3d.parameters(), lr=5e-5) + + trainer = SemanticSamTrainer3D( + name="test-sam", + model=sam_3d, + convert_inputs=ConvertToSemanticSamInputs(), + num_classes=n_classes, + train_loader=train_loader, + val_loader=val_loader, + optimizer=optimizer, + device=device, + compile_model=False, + ) + trainer.fit(10) + + +def main(): + # predict_3d_model() + train_3d_model() + + +if __name__ == "__main__": + main() diff --git a/development/instance_segmentation_3d.py b/development/instance_segmentation_3d.py new file mode 100644 index 000000000..9a5943655 --- /dev/null +++ b/development/instance_segmentation_3d.py @@ -0,0 +1,154 @@ +import napari +from elf.io import open_file +import h5py +import os +import torch +import numpy as np + +import micro_sam.sam_3d_wrapper as sam_3d +import micro_sam.util as util +# from micro_sam.segment_instances import ( +# segment_instances_from_embeddings, +# segment_instances_sam, +# segment_instances_from_embeddings_3d, +# ) +from micro_sam import multi_dimensional_segmentation as mds +from micro_sam.visualization import compute_pca +INPUT_PATH_CLUSTER = "/scratch-grete/projects/nim00007/data/mitochondria/cooper/mito_tomo/outer-membrane1/1_20230125_TOMO_HOI_WT_36859_J2_upSTEM750_BC3.6/upSTEM750_36859_J2_TS_SP_003_rec_2kb1dawbp_crop.h5" +# EMBEDDINGS_PATH_CLUSTER = "/scratch-grete/projects/nim00007/data/mitochondria/cooper/mito_tomo/outer-membrane1/1_20230125_TOMO_HOI_WT_36859_J2_upSTEM750_BC3.6/embedding-mito-3d.zarr" +EMBEDDINGS_PATH_CLUSTER = "/scratch-grete/usr/nimlufre/" +INPUT_PATH_LOCAL = "/home/freckmann15/data/mitochondria/cooper/mito_tomo/outer-membrane1/1_20230125_TOMO_HOI_WT_36859_J2_upSTEM750_BC3.6/upSTEM750_36859_J2_TS_SP_003_rec_2kb1dawbp_crop.h5" +EMBEDDINGS_PATH_LOCAL = "/home/freckmann15/data/mitochondria/cooper/mito_tomo/outer-membrane1/1_20230125_TOMO_HOI_WT_36859_J2_upSTEM750_BC3.6/" +INPUT_PATH = "/scratch-grete/projects/nim00007/data/mitochondria/moebius/volume_em/training_blocks_v1/4007_cutout_1.h5" +EMBEDDINGS_PATH = "/scratch-grete/projects/nim00007/data/mitochondria/moebius/volume_em/training_blocks_v1/embedding-mito-3d.zarr" +TIMESERIES_PATH = "../examples/data/DIC-C2DH-HeLa/train/01" +EMBEDDINGS_TRACKING_PATH = "../examples/embeddings/embeddings-ctc.zarr" + +# def cell_segmentation_3d() -> None: +# with open_file(TIMESERIES_PATH, mode="r") as f: +# timeseries = f["*.tif"][:50] + +# predictor = util.get_sam_model() +# image_embeddings = util.precompute_image_embeddings(predictor, timeseries, EMBEDDINGS_TRACKING_PATH) + +# seg = segment_instances_from_embeddings_3d(predictor, image_embeddings) + +# v = napari.Viewer() +# v.add_image(timeseries) +# v.add_labels(seg) +# napari.run() + + +# def _get_dataset_and_reshape(path: str, key: str = "raw", shape: tuple = (32, 256, 256)) -> np.ndarray: + +# with h5py.File(path, "r") as f: +# # Check if the key exists in the file +# if key not in f: +# raise KeyError(f"Dataset with key '{key}' not found in file '{path}'.") + +# # Load the dataset +# dataset = f[key][...] + +# # Reshape the dataset +# if dataset.shape != shape: +# try: +# # Attempt to reshape the dataset to the desired shape +# dataset = dataset.reshape(shape) +# except ValueError: +# raise ValueError(f"Failed to reshape dataset with key '{key}' to shape {shape}.") + +# return dataset +def get_dataset_cutout(path: str, key: str = "raw", shape: tuple = (32, 256, 256), + start_index: tuple = (0, 0, 0)) -> np.ndarray: + """ + Loads a cutout from a dataset in an HDF5 file. + + Args: + path (str): Path to the HDF5 file. + key (str, optional): Key of the dataset to load. Defaults to "raw". + shape (tuple, optional): Desired shape of the cutout. Defaults to (32, 256, 256). + start_index (tuple, optional): Starting index for the cutout within the dataset. + Defaults to None, which selects a random starting point within valid bounds. + + Returns: + np.ndarray: The loaded cutout of the dataset with the specified shape. + + Raises: + KeyError: If the specified key is not found in the HDF5 file. + ValueError: If the cutout shape exceeds the dataset dimensions or the starting index is invalid. + """ + + with h5py.File(path, "r") as f: + + dataset = f[key] + dataset_shape = dataset.shape + print("original data shape", dataset_shape) + + # Validate cutout shape + if any(s > d for s, d in zip(shape, dataset_shape)): + raise ValueError(f"Cutout shape {shape} exceeds dataset dimensions {dataset_shape}.") + + # Generate random starting index if not provided + if start_index is None: + start_index = tuple(np.random.randint(0, dim - s + 1, size=len(shape)) for dim, s in zip(dataset_shape, shape)) + + # Calculate end index + end_index = tuple(min(i + s, dim) for i, s, dim in zip(start_index, shape, dataset_shape)) + + # Load the cutout + cutout = dataset[start_index[0]:end_index[0], + start_index[1]:end_index[1], + start_index[2]:end_index[2]] + print("cutout data shape", cutout.shape) + + return cutout + + +def mito_segmentation_3d() -> None: + patch_shape = (32, 256, 256) + start_index = (10, 32, 64) + data_slice = get_dataset_cutout(INPUT_PATH_LOCAL, shape=patch_shape) #start_index=start_index + + device = "cuda" if torch.cuda.is_available() else "cpu" + model_type = "vit_b" + predictor, sam = util.get_sam_model(return_sam=True, model_type=model_type, device=device) + + d_size = 3 + predictor3d = sam_3d.Predictor3D(sam, d_size) + print(predictor3d) + #breakpoint() + predictor3d.model.forward(torch.from_numpy(data_slice), multimask_output=False, image_size=patch_shape) + # output = predictor3d.model([data_slice], multimask_output=False)#image_size=patch_shape + + # predictor3d._hash = util.models().registry[model_type] + + # predictor3d.model_name = model_type + + # image_embeddings = util.precompute_image_embeddings(predictor3d, volume, EMBEDDINGS_PATH_CLUSTER) + # seg = util.segment_instances_from_embeddings_3d(predictor3d, image_embeddings) + + # prediction_filename = os.path.join(EMBEDDINGS_PATH_CLUSTER, f"prediction_{INPUT_PATH_CLUSTER}.h5") + # with h5py.File(prediction_filename, "w") as prediction_file: + # prediction_file.create_dataset("prediction", data=seg) + + # visualize + # v = napari.Viewer() + # v.add_image(volume) + # v.add_labels(seg) + # v.add_labels(seg_sam) + # napari.run() + + + +def main(): + # automatic segmentation for the data from Lucchi et al. (see 'sam_annotator_3d.py') + # nucleus_segmentation(use_mws=True) + mito_segmentation_3d() + + # automatic segmentation for data from the cell tracking challenge (see 'sam_annotator_tracking.py') + # cell_segmentation(use_mws=True) + # cell_segmentation_3d() + + +if __name__ == "__main__": + main() diff --git a/micro_sam/sam_3d_wrapper.py b/micro_sam/sam_3d_wrapper.py new file mode 100644 index 000000000..ccb9968e0 --- /dev/null +++ b/micro_sam/sam_3d_wrapper.py @@ -0,0 +1,192 @@ +from typing import Type + +import torch +import torch.nn as nn + +from segment_anything.modeling.image_encoder import window_partition, window_unpartition +from segment_anything.modeling import Sam + +from .util import get_sam_model + + +def get_3d_sam_model(device, n_classes, image_size, model_type="vit_b"): + predictor, sam = get_sam_model( + return_sam=True, model_type=model_type, device=device, num_multimask_outputs=n_classes, + flexible_load_checkpoint=True, image_size=image_size, + ) + sam_3d = Sam3DWrapper(sam) + sam_3d.to(device) + return sam_3d + + +class Sam3DWrapper(nn.Module): + def __init__(self, sam_model: Sam): + """ + Initializes the Sam3DWrapper object. + + Args: + sam_model (Sam): The Sam model to be wrapped. + """ + super().__init__() + sam_model.image_encoder = ImageEncoderViT3DWrapper( + image_encoder=sam_model.image_encoder + ) + self.sam_model = sam_model + + # FIXME + # - handling of the image size here is wrong, this only works for square images + # - this does not take care of resizing + # unclear how batches are handled + def forward(self, batched_input, multimask_output, image_size) -> torch.Tensor: + return self._forward_train(batched_input, multimask_output, image_size) + + def _forward_train(self, batched_input, multimask_output, image_size): + # dimensions: [b, d, 3, h, w] + shape = batched_input.shape + batch_size, d_size, hw_size = shape[0], shape[1], shape[-2] + batched_input = batched_input.contiguous().view(-1, 3, hw_size, hw_size) + + input_images = self.sam_model.preprocess(batched_input) + image_embeddings = self.sam_model.image_encoder(input_images, d_size) + sparse_embeddings, dense_embeddings = self.sam_model.prompt_encoder( + points=None, boxes=None, masks=None + ) + low_res_masks, iou_predictions = self.sam_model.mask_decoder( + image_embeddings=image_embeddings, + image_pe=self.sam_model.prompt_encoder.get_dense_pe(), + sparse_prompt_embeddings=sparse_embeddings, + dense_prompt_embeddings=dense_embeddings, + multimask_output=multimask_output + ) + masks = self.sam_model.postprocess_masks( + low_res_masks, + input_size=(image_size, image_size), + original_size=(image_size, image_size) + ) + + # Bring the masks and low-res masks into the correct shape: + # - disentangle batches and z-slices + # - rearrange output channels and z-slices + + n_channels = masks.shape[1] + masks = masks.view(*(batch_size, d_size, n_channels, masks.shape[-2], masks.shape[-1])) + low_res_masks = low_res_masks.view( + *(batch_size, d_size, n_channels, low_res_masks.shape[-2], low_res_masks.shape[-1]) + ) + + masks = masks.transpose(1, 2) + low_res_masks = low_res_masks.transpose(1, 2) + + outputs = { + "masks": masks, + "iou_predictions": iou_predictions, + "low_res_logits": low_res_masks + } + return outputs + + +class ImageEncoderViT3DWrapper(nn.Module): + def __init__( + self, + image_encoder: nn.Module, + num_heads: int = 12, + embed_dim: int = 768, + ): + super().__init__() + self.image_encoder = image_encoder + self.img_size = self.image_encoder.img_size + + # replace default blocks with 3d adapter blocks + for i, blk in enumerate(self.image_encoder.blocks): + self.image_encoder.blocks[i] = NDBlockWrapper(block=blk, num_heads=num_heads, dim=embed_dim) + + def forward(self, x: torch.Tensor, d_size: int) -> torch.Tensor: + x = self.image_encoder.patch_embed(x) + if self.image_encoder.pos_embed is not None: + x = x + self.image_encoder.pos_embed + + for blk in self.image_encoder.blocks: + x = blk(x, d_size) + + x = self.image_encoder.neck(x.permute(0, 3, 1, 2)) + + return x + + +class NDBlockWrapper(nn.Module): + def __init__( + self, + block: nn.Module, + dim: int, + num_heads: int, + norm_layer: Type[nn.Module] = nn.LayerNorm, + adapter_channels: int = 384, + ): + super().__init__() + self.block = block + + self.adapter_channels = adapter_channels + self.adapter_linear_down = nn.Linear(dim, self.adapter_channels, bias=False) + self.adapter_linear_up = nn.Linear(self.adapter_channels, dim, bias=False) + self.adapter_conv = nn.Conv3d( + self.adapter_channels, self.adapter_channels, kernel_size=(3, 1, 1), padding="same" + ) + self.adapter_act = nn.GELU() + self.adapter_norm = norm_layer(dim) + + self.adapter_linear_down_2 = nn.Linear(dim, self.adapter_channels, bias=False) + self.adapter_linear_up_2 = nn.Linear(self.adapter_channels, dim, bias=False) + self.adapter_conv_2 = nn.Conv3d( + self.adapter_channels, self.adapter_channels, kernel_size=(3, 1, 1), padding="same" + ) + self.adapter_act_2 = nn.GELU() + self.adapter_norm_2 = norm_layer(dim) + + def forward(self, x: torch.Tensor, d_size) -> torch.Tensor: + b_size, hw_size = x.shape[0], x.shape[1] + + # 3D adapter + shortcut = x + x = self.adapter_norm(x) + x = self.adapter_linear_down(x) + x = x.contiguous().view(int(b_size/d_size), d_size, hw_size, hw_size, self.adapter_channels) + x = torch.permute(x, (0, -1, 1, 2, 3)) + x = self.adapter_conv(x) + x = torch.permute(x, (0, 2, 3, 4, 1)) + x = x.contiguous().view(b_size, hw_size, hw_size, self.adapter_channels) + x = self.adapter_act(x) + x = self.adapter_linear_up(x) + x = shortcut + x + # end 3D adapter + + shortcut = x + x = self.block.norm1(x) + # Window partition + if self.block.window_size > 0: + H, W = x.shape[1], x.shape[2] + x, pad_hw = window_partition(x, self.block.window_size) + + x = self.block.attn(x) + # Reverse window partition + if self.block.window_size > 0: + x = window_unpartition(x, self.block.window_size, pad_hw, (H, W)) + + x = shortcut + x + + # 3D adapter + shortcut = x + x = self.adapter_norm_2(x) + x = self.adapter_linear_down_2(x) + x = x.contiguous().view(int(b_size/d_size), d_size, hw_size, hw_size, self.adapter_channels) + x = torch.permute(x, (0, -1, 1, 2, 3)) + x = self.adapter_conv_2(x) + x = torch.permute(x, (0, 2, 3, 4, 1)) + x = x.contiguous().view(b_size, hw_size, hw_size, self.adapter_channels) + x = self.adapter_act_2(x) + x = self.adapter_linear_up_2(x) + x = shortcut + x + # end 3D adapter + + x = x + self.block.mlp(self.block.norm2(x)) + + return x diff --git a/micro_sam/training/models/build_sam.py b/micro_sam/training/models/build_sam.py index 525b20db4..8fa6bcc6a 100644 --- a/micro_sam/training/models/build_sam.py +++ b/micro_sam/training/models/build_sam.py @@ -19,7 +19,7 @@ from segment_anything.modeling import Sam, ImageEncoderViT, PromptEncoder, MaskDecoder, TwoWayTransformer -def build_sam_vit_h(checkpoint=None, num_multimask_outputs=3): +def build_sam_vit_h(checkpoint=None, num_multimask_outputs=3, image_size=1024): return _build_sam( encoder_embed_dim=1280, encoder_depth=32, @@ -27,13 +27,14 @@ def build_sam_vit_h(checkpoint=None, num_multimask_outputs=3): encoder_global_attn_indexes=[7, 15, 23, 31], checkpoint=checkpoint, num_multimask_outputs=num_multimask_outputs, + image_size=image_size, ) build_sam = build_sam_vit_h -def build_sam_vit_l(checkpoint=None, num_multimask_outputs=3): +def build_sam_vit_l(checkpoint=None, num_multimask_outputs=3, image_size=1024): return _build_sam( encoder_embed_dim=1024, encoder_depth=24, @@ -41,10 +42,11 @@ def build_sam_vit_l(checkpoint=None, num_multimask_outputs=3): encoder_global_attn_indexes=[5, 11, 17, 23], checkpoint=checkpoint, num_multimask_outputs=num_multimask_outputs, + image_size=image_size, ) -def build_sam_vit_b(checkpoint=None, num_multimask_outputs=3): +def build_sam_vit_b(checkpoint=None, num_multimask_outputs=3, image_size=1024): return _build_sam( encoder_embed_dim=768, encoder_depth=12, @@ -52,6 +54,7 @@ def build_sam_vit_b(checkpoint=None, num_multimask_outputs=3): encoder_global_attn_indexes=[2, 5, 8, 11], checkpoint=checkpoint, num_multimask_outputs=num_multimask_outputs, + image_size=image_size, ) @@ -70,9 +73,9 @@ def _build_sam( encoder_global_attn_indexes, checkpoint=None, num_multimask_outputs=3, + image_size=1024, ): prompt_embed_dim = 256 - image_size = 1024 vit_patch_size = 16 image_embedding_size = image_size // vit_patch_size sam = Sam( diff --git a/micro_sam/training/semantic_sam_trainer.py b/micro_sam/training/semantic_sam_trainer.py index 09a9151e5..b3f1cc0ac 100644 --- a/micro_sam/training/semantic_sam_trainer.py +++ b/micro_sam/training/semantic_sam_trainer.py @@ -93,3 +93,17 @@ def _validate_impl(self, forward_context): self.logger.log_validation(self._iteration, metric_val, loss_val, x, y, masks) return metric_val + + +class SemanticSamTrainer3D(SemanticSamTrainer): + def _get_model_outputs(self, batched_inputs): + model_input = torch.stack([inp["image"] for inp in batched_inputs]).to(self.device) + image_size = batched_inputs[0]["original_size"][-1] + batched_outputs = self.model( + model_input, + multimask_output=(self.num_classes > 1), + image_size=image_size + ) + # masks = torch.stack([output["masks"].squeeze(0) for output in batched_outputs]) + masks = batched_outputs["masks"] + return masks From 28c97e4e6c325b0236e03d58c73e882c79054b5c Mon Sep 17 00:00:00 2001 From: Anwai Archit <52396323+anwai98@users.noreply.github.com> Date: Fri, 28 Jun 2024 12:19:16 +0200 Subject: [PATCH 13/53] Add SemanticSam3dLogger (#643) Updates to SAM 3d training --------- Co-authored-by: Constantin Pape --- micro_sam/sam_3d_wrapper.py | 20 ++++- micro_sam/training/semantic_sam_trainer.py | 90 +++++++++++++++++++--- 2 files changed, 95 insertions(+), 15 deletions(-) diff --git a/micro_sam/sam_3d_wrapper.py b/micro_sam/sam_3d_wrapper.py index ccb9968e0..1676652bc 100644 --- a/micro_sam/sam_3d_wrapper.py +++ b/micro_sam/sam_3d_wrapper.py @@ -9,11 +9,23 @@ from .util import get_sam_model -def get_3d_sam_model(device, n_classes, image_size, model_type="vit_b"): - predictor, sam = get_sam_model( - return_sam=True, model_type=model_type, device=device, num_multimask_outputs=n_classes, - flexible_load_checkpoint=True, image_size=image_size, +def get_3d_sam_model( + device, + n_classes, + image_size, + model_type="vit_b", + checkpoint_path=None, +): + _, sam = get_sam_model( + model_type=model_type, + device=device, + checkpoint_path=checkpoint_path, + return_sam=True, + flexible_load_checkpoint=True, + num_multimask_outputs=n_classes, + image_size=image_size, ) + sam_3d = Sam3DWrapper(sam) sam_3d.to(device) return sam_3d diff --git a/micro_sam/training/semantic_sam_trainer.py b/micro_sam/training/semantic_sam_trainer.py index b3f1cc0ac..93228752f 100644 --- a/micro_sam/training/semantic_sam_trainer.py +++ b/micro_sam/training/semantic_sam_trainer.py @@ -1,10 +1,36 @@ import time +import numpy as np + import torch import torch.nn as nn from torch_em.loss import DiceLoss from torch_em.trainer import DefaultTrainer +from torch_em.trainer.tensorboard_logger import TensorboardLogger, normalize_im + + +class CustomDiceLoss(nn.Module): + def __init__(self, num_classes: int, softmax: bool = True) -> None: + super().__init__() + self.num_classes = num_classes + self.dice_loss = DiceLoss() + self.softmax = softmax + + def _one_hot_encoder(self, input_tensor): + tensor_list = [] + for i in range(self.num_classes): + temp_prob = input_tensor == i # * torch.ones_like(input_tensor) + tensor_list.append(temp_prob) + output_tensor = torch.cat(tensor_list, dim=1) + return output_tensor.float() + + def __call__(self, pred, target): + if self.softmax: + pred = torch.softmax(pred, dim=1) + target = self._one_hot_encoder(target) + loss = self.dice_loss(pred, target) + return loss class SemanticSamTrainer(DefaultTrainer): @@ -13,31 +39,35 @@ class SemanticSamTrainer(DefaultTrainer): def __init__( self, convert_inputs, - num_classes: int = 1, + num_classes: int, **kwargs ): - loss = DiceLoss() - metric = DiceLoss() - super().__init__(loss=loss, metric=metric, **kwargs) + assert num_classes > 1 + + loss = CustomDiceLoss(num_classes=num_classes) + metric = CustomDiceLoss(num_classes=num_classes) + logger = SemanticSamLogger + super().__init__(loss=loss, metric=metric, logger=logger, **kwargs) self.convert_inputs = convert_inputs self.num_classes = num_classes - self.compute_ce_loss = nn.BCELoss() if num_classes == 1 else nn.CrossEntropyLoss() + self.compute_ce_loss = nn.CrossEntropyLoss() self._kwargs = kwargs def _compute_loss(self, y, masks): + target = y.to(self.device, non_blocking=True) # Compute dice loss for the predictions - dice_loss = self.loss(masks, y.to(self.device, non_blocking=True)) + dice_loss = self.loss(masks, target) # Compute cross entropy loss for the predictions - ce_loss = self.compute_ce_loss(masks, y.to(self.device, non_blocking=True)) + ce_loss = self.compute_ce_loss(masks, target.squeeze(1).long()) net_loss = dice_loss + ce_loss return net_loss def _get_model_outputs(self, batched_inputs): image_embeddings, batched_inputs = self.model.image_embeddings_oft(batched_inputs) - batched_outputs = self.model(batched_inputs, image_embeddings, multimask_output=(self.num_classes > 1)) + batched_outputs = self.model(batched_inputs, image_embeddings, multimask_output=True) masks = torch.stack([output["masks"].squeeze(0) for output in batched_outputs]) return masks @@ -56,11 +86,12 @@ def _train_epoch_impl(self, progress, forward_context, backprop): backprop(net_loss) + self._iteration += 1 + if self.logger is not None: lr = [pm["lr"] for pm in self.optimizer.param_groups][0] - self.logger.log_train(self._iteration, net_loss, lr, x, y, masks, log_gradients=True) + self.logger.log_train(self._iteration, net_loss, lr, x, y, masks, log_gradients=False) - self._iteration += 1 if self._iteration >= self.max_iteration: break progress.update(1) @@ -86,8 +117,9 @@ def _validate_impl(self, forward_context): loss_val /= len(self.val_loader) metric_val /= len(self.val_loader) + dice_metric = 1 - (metric_val / self.num_classes) print() - print(f"The Average Validation Metric Score for the Current Epoch is {1 - metric_val}") + print(f"The Average Validation Metric Score for the Current Epoch is {dice_metric}") if self.logger is not None: self.logger.log_validation(self._iteration, metric_val, loss_val, x, y, masks) @@ -107,3 +139,39 @@ def _get_model_outputs(self, batched_inputs): # masks = torch.stack([output["masks"].squeeze(0) for output in batched_outputs]) masks = batched_outputs["masks"] return masks + + +class SemanticSamLogger(TensorboardLogger): + def log_images(self, step, x, y, prediction, name, gradients=None): + + selection_image = np.s_[0] if x.ndim == 4 else np.s_[0, x.shape[2] // 2, :] + selection = np.s_[0] if x.ndim == 4 else np.s_[0, :, x.shape[2] // 2] + + image = normalize_im(x[selection_image].cpu()) + self.tb.add_image(tag=f"{name}/input", + img_tensor=image, + global_step=step) + + prediction = torch.softmax(prediction, dim=1) + im, im_name = self.make_image(image, y, prediction, selection, gradients) + im_name = f"{name}/{im_name}" + self.tb.add_image(tag=im_name, img_tensor=im, global_step=step) + + def log_train(self, step, loss, lr, x, y, prediction, log_gradients=False): + self.tb.add_scalar(tag="train/loss", scalar_value=loss, global_step=step) + self.tb.add_scalar(tag="train/learning_rate", scalar_value=lr, global_step=step) + + # the embedding visualisation function currently doesn't support gradients, + # so we can't log them even if log_gradients is true + log_grads = log_gradients + if self.have_embeddings: + log_grads = False + + if step % self.log_image_interval == 0: + gradients = prediction.grad if log_grads else None + self.log_images(step, x, y, prediction, "train", gradients=gradients) + + def log_validation(self, step, metric, loss, x, y, prediction): + self.tb.add_scalar(tag="validation/loss", scalar_value=loss, global_step=step) + self.tb.add_scalar(tag="validation/metric", scalar_value=metric, global_step=step) + self.log_images(step, x, y, prediction, "validation") From 170cdfeeaa0733e02d25fed0bf00d0ee7e1792d8 Mon Sep 17 00:00:00 2001 From: Constantin Pape Date: Fri, 28 Jun 2024 16:16:35 +0200 Subject: [PATCH 14/53] Add simple 3d wrapper and enable freezing the encoder in sam 3d wrapper (#645) Add simple 3d wrapper and enable freezing the encoder in sam 3d wrapper, simplify lora support --- micro_sam/sam_3d_wrapper.py | 22 +++- micro_sam/simple_sam_3d_wrapper.py | 159 +++++++++++++++++++++++++++++ 2 files changed, 179 insertions(+), 2 deletions(-) create mode 100644 micro_sam/simple_sam_3d_wrapper.py diff --git a/micro_sam/sam_3d_wrapper.py b/micro_sam/sam_3d_wrapper.py index 1676652bc..5b40608bb 100644 --- a/micro_sam/sam_3d_wrapper.py +++ b/micro_sam/sam_3d_wrapper.py @@ -13,9 +13,20 @@ def get_3d_sam_model( device, n_classes, image_size, + lora_rank=None, + freeze_encoder=False, model_type="vit_b", checkpoint_path=None, ): + if lora_rank is None: + use_lora = False + rank = None + freeze_encoder_ = freeze_encoder + else: + use_lora = True + rank = lora_rank + freeze_encoder_ = False + _, sam = get_sam_model( model_type=model_type, device=device, @@ -24,15 +35,17 @@ def get_3d_sam_model( flexible_load_checkpoint=True, num_multimask_outputs=n_classes, image_size=image_size, + use_lora=use_lora, + rank=rank, ) - sam_3d = Sam3DWrapper(sam) + sam_3d = Sam3DWrapper(sam, freeze_encoder=freeze_encoder_) sam_3d.to(device) return sam_3d class Sam3DWrapper(nn.Module): - def __init__(self, sam_model: Sam): + def __init__(self, sam_model: Sam, freeze_encoder: bool): """ Initializes the Sam3DWrapper object. @@ -45,6 +58,11 @@ def __init__(self, sam_model: Sam): ) self.sam_model = sam_model + self.freeze_encoder = freeze_encoder + if self.freeze_encoder: + for param in self.sam_model.image_encoder.parameters(): + param.requires_grad = False + # FIXME # - handling of the image size here is wrong, this only works for square images # - this does not take care of resizing diff --git a/micro_sam/simple_sam_3d_wrapper.py b/micro_sam/simple_sam_3d_wrapper.py new file mode 100644 index 000000000..ba33391b1 --- /dev/null +++ b/micro_sam/simple_sam_3d_wrapper.py @@ -0,0 +1,159 @@ +from contextlib import nullcontext + +import torch +import torch.nn as nn + +from .util import get_sam_model + + +def get_simple_3d_sam_model( + device, + n_classes, + image_size, + lora_rank=None, + freeze_encoder=False, + model_type="vit_b", + checkpoint_path=None, +): + if lora_rank is None: + use_lora = False + rank = None + freeze_encoder_ = freeze_encoder + else: + use_lora = True + rank = lora_rank + freeze_encoder_ = False + + _, sam = get_sam_model( + model_type=model_type, + device=device, + checkpoint_path=checkpoint_path, + return_sam=True, + image_size=image_size, + flexible_load_checkpoint=True, + use_lora=use_lora, + rank=rank, + ) + + sam_3d = SimpleSam3DWrapper(sam, num_classes=n_classes, freeze_encoder=freeze_encoder_) + sam_3d.to(device) + return sam_3d + + +class BasicBlock(nn.Module): + def __init__( + self, + in_channels, + out_channels, + kernel_size=(3, 3, 3), + stride=(1, 1, 1), + padding=(1, 1, 1), + bias=True, + mode="nearest" + ): + super().__init__() + + self.conv1 = nn.Sequential( + nn.Conv3d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias), + nn.InstanceNorm3d(out_channels), + nn.LeakyReLU() + ) + + self.conv2 = nn.Sequential( + nn.Conv3d(out_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias), + nn.InstanceNorm3d(out_channels) + ) + + self.downsample = nn.Sequential( + nn.Conv3d(in_channels, out_channels, kernel_size=1, stride=stride, padding=0, bias=bias), + nn.InstanceNorm3d(out_channels) + ) + + self.leakyrelu = nn.LeakyReLU() + + self.up = nn.Upsample(scale_factor=(1, 2, 2), mode=mode) + + def forward(self, x): + residual = self.downsample(x) + + out = self.conv1(x) + out = self.conv2(out) + out += residual + + out = self.leakyrelu(out) + out = self.up(out) + return out + + +class SegmentationHead(nn.Sequential): + def __init__( + self, + in_channels, + out_channels, + kernel_size=(3, 3, 3), + stride=(1, 1, 1), + padding=(1, 1, 1), + bias=True + ): + super().__init__() + + self.conv_pred = nn.Sequential( + nn.Conv3d( + in_channels, in_channels // 2, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias + ), + nn.InstanceNorm3d(in_channels // 2), + nn.LeakyReLU() + ) + self.segmentation_head = nn.Conv3d(in_channels // 2, out_channels, kernel_size=1) + + def forward(self, x): + x = self.conv_pred(x) + return self.segmentation_head(x) + + +class SimpleSam3DWrapper(nn.Module): + def __init__(self, sam, num_classes, freeze_encoder): + super().__init__() + + self.sam = sam + self.freeze_encoder = freeze_encoder + if self.freeze_encoder: + for param in self.sam.image_encoder.parameters(): + param.requires_grad = False + self.no_grad = torch.no_grad + + else: + self.no_grad = nullcontext + + self.decoders = nn.ModuleList([ + BasicBlock(in_channels=256, out_channels=128), + BasicBlock(in_channels=128, out_channels=64), + BasicBlock(in_channels=64, out_channels=32), + BasicBlock(in_channels=32, out_channels=16), + ]) + self.out_conv = SegmentationHead(in_channels=16, out_channels=num_classes) + + def _apply_image_encoder(self, x, D): + encoder_features = [] + for d in range(D): + image = x[:, d] + feature = self.sam.image_encoder(image) + encoder_features.append(feature) + encoder_features = torch.stack(encoder_features, 1) + encoder_features = encoder_features.transpose(1, 2) + return encoder_features + + def forward(self, x, **kwargs): + B, D, C, H, W = x.shape + assert C == 3 + + with self.no_grad(): + features = self._apply_image_encoder(x, D) + + out = features + for decoder in self.decoders: + out = decoder(out) + logits = self.out_conv(out) + + outputs = {"masks": logits} + return outputs From 08457cefeb8c05f1e6237575754d98725052cb32 Mon Sep 17 00:00:00 2001 From: Anwai Archit <52396323+anwai98@users.noreply.github.com> Date: Fri, 28 Jun 2024 21:59:13 +0200 Subject: [PATCH 15/53] Minor fix to trainable sam model functionality (#646) Minor fix to trainable sam model functionality --- micro_sam/training/util.py | 24 ++++++++++-------------- 1 file changed, 10 insertions(+), 14 deletions(-) diff --git a/micro_sam/training/util.py b/micro_sam/training/util.py index 3e4f01e31..6ad6ce403 100644 --- a/micro_sam/training/util.py +++ b/micro_sam/training/util.py @@ -13,7 +13,6 @@ get_centers_and_bounding_boxes, get_sam_model, get_device, segmentation_to_one_hot, _DEFAULT_MODEL, ) -from .peft_sam import PEFT_Sam from .trainable_sam import TrainableSAM from torch_em.transform.label import PerObjectDistanceTransform @@ -87,21 +86,18 @@ def get_trainable_sam_model( # (for e.g. encoder blocks to "image_encoder") if freeze is not None: for name, param in sam.named_parameters(): - if isinstance(freeze, list): - # we would want to "freeze" all the components in the model if passed a list of parts - for l_item in freeze: - if name.startswith(f"{l_item}"): - param.requires_grad = False - else: + if not isinstance(freeze, list): # we "freeze" only for one specific component when passed a "particular" part - if name.startswith(f"{freeze}"): - param.requires_grad = False + freeze = [freeze] + + # we would want to "freeze" all the components in the model if passed a list of parts + for l_item in freeze: + # in case LoRA is switched on, we cannot freeze the image encoder + if use_lora and (l_item == "image_encoder"): + raise ValueError("You cannot use LoRA & freeze the image encoder at the same time.") - # Whether to use Parameter Efficient Finetuning methods to wrap around Segment Anything - if use_lora: # overwrites the SAM model by freezing the backbone and allow low rank adaption to attention layers - if rank is None: - rank = 4 # HACK: in case the user does not pass the rank, we provide a random rank to them - sam = PEFT_Sam(sam, rank=rank).sam + if name.startswith(f"{l_item}"): + param.requires_grad = False # convert to trainable sam trainable_sam = TrainableSAM(sam) From 57d43ec265641455acfbfa1c35423eae90abdd3f Mon Sep 17 00:00:00 2001 From: Constantin Pape Date: Sat, 29 Jun 2024 23:14:00 +0200 Subject: [PATCH 16/53] Fix dimension order in 3d sam wrappers --- micro_sam/sam_3d_wrapper.py | 9 +++- micro_sam/simple_sam_3d_wrapper.py | 5 +-- micro_sam/training/semantic_sam_trainer.py | 50 +++------------------- 3 files changed, 15 insertions(+), 49 deletions(-) diff --git a/micro_sam/sam_3d_wrapper.py b/micro_sam/sam_3d_wrapper.py index 5b40608bb..4582cfc42 100644 --- a/micro_sam/sam_3d_wrapper.py +++ b/micro_sam/sam_3d_wrapper.py @@ -71,9 +71,14 @@ def forward(self, batched_input, multimask_output, image_size) -> torch.Tensor: return self._forward_train(batched_input, multimask_output, image_size) def _forward_train(self, batched_input, multimask_output, image_size): - # dimensions: [b, d, 3, h, w] + # dimensions: [b, 3, d, h, w] shape = batched_input.shape - batch_size, d_size, hw_size = shape[0], shape[1], shape[-2] + assert shape[1] == 3 + batch_size, d_size, hw_size = shape[0], shape[2], shape[-2] + # Transpose the axes, so that the depth axis is the first axis and the channel + # axis is the second axis. This is expected by the transformer! + batched_input = batched_input.transpose(1, 2) + assert batched_input.shape[1] == d_size batched_input = batched_input.contiguous().view(-1, 3, hw_size, hw_size) input_images = self.sam_model.preprocess(batched_input) diff --git a/micro_sam/simple_sam_3d_wrapper.py b/micro_sam/simple_sam_3d_wrapper.py index ba33391b1..30c8c20a5 100644 --- a/micro_sam/simple_sam_3d_wrapper.py +++ b/micro_sam/simple_sam_3d_wrapper.py @@ -136,11 +136,10 @@ def __init__(self, sam, num_classes, freeze_encoder): def _apply_image_encoder(self, x, D): encoder_features = [] for d in range(D): - image = x[:, d] + image = x[:, :, d] feature = self.sam.image_encoder(image) encoder_features.append(feature) - encoder_features = torch.stack(encoder_features, 1) - encoder_features = encoder_features.transpose(1, 2) + encoder_features = torch.stack(encoder_features, 2) return encoder_features def forward(self, x, **kwargs): diff --git a/micro_sam/training/semantic_sam_trainer.py b/micro_sam/training/semantic_sam_trainer.py index 93228752f..6e3dad7e9 100644 --- a/micro_sam/training/semantic_sam_trainer.py +++ b/micro_sam/training/semantic_sam_trainer.py @@ -1,13 +1,10 @@ import time -import numpy as np - import torch import torch.nn as nn from torch_em.loss import DiceLoss from torch_em.trainer import DefaultTrainer -from torch_em.trainer.tensorboard_logger import TensorboardLogger, normalize_im class CustomDiceLoss(nn.Module): @@ -46,8 +43,7 @@ def __init__( loss = CustomDiceLoss(num_classes=num_classes) metric = CustomDiceLoss(num_classes=num_classes) - logger = SemanticSamLogger - super().__init__(loss=loss, metric=metric, logger=logger, **kwargs) + super().__init__(loss=loss, metric=metric, **kwargs) self.convert_inputs = convert_inputs self.num_classes = num_classes @@ -90,7 +86,9 @@ def _train_epoch_impl(self, progress, forward_context, backprop): if self.logger is not None: lr = [pm["lr"] for pm in self.optimizer.param_groups][0] - self.logger.log_train(self._iteration, net_loss, lr, x, y, masks, log_gradients=False) + self.logger.log_train( + self._iteration, net_loss, lr, x, y, torch.softmax(masks, dim=1), log_gradients=False + ) if self._iteration >= self.max_iteration: break @@ -122,7 +120,7 @@ def _validate_impl(self, forward_context): print(f"The Average Validation Metric Score for the Current Epoch is {dice_metric}") if self.logger is not None: - self.logger.log_validation(self._iteration, metric_val, loss_val, x, y, masks) + self.logger.log_validation(self._iteration, metric_val, loss_val, x, y, torch.softmax(masks, dim=1)) return metric_val @@ -133,45 +131,9 @@ def _get_model_outputs(self, batched_inputs): image_size = batched_inputs[0]["original_size"][-1] batched_outputs = self.model( model_input, - multimask_output=(self.num_classes > 1), + multimask_output=True, image_size=image_size ) # masks = torch.stack([output["masks"].squeeze(0) for output in batched_outputs]) masks = batched_outputs["masks"] return masks - - -class SemanticSamLogger(TensorboardLogger): - def log_images(self, step, x, y, prediction, name, gradients=None): - - selection_image = np.s_[0] if x.ndim == 4 else np.s_[0, x.shape[2] // 2, :] - selection = np.s_[0] if x.ndim == 4 else np.s_[0, :, x.shape[2] // 2] - - image = normalize_im(x[selection_image].cpu()) - self.tb.add_image(tag=f"{name}/input", - img_tensor=image, - global_step=step) - - prediction = torch.softmax(prediction, dim=1) - im, im_name = self.make_image(image, y, prediction, selection, gradients) - im_name = f"{name}/{im_name}" - self.tb.add_image(tag=im_name, img_tensor=im, global_step=step) - - def log_train(self, step, loss, lr, x, y, prediction, log_gradients=False): - self.tb.add_scalar(tag="train/loss", scalar_value=loss, global_step=step) - self.tb.add_scalar(tag="train/learning_rate", scalar_value=lr, global_step=step) - - # the embedding visualisation function currently doesn't support gradients, - # so we can't log them even if log_gradients is true - log_grads = log_gradients - if self.have_embeddings: - log_grads = False - - if step % self.log_image_interval == 0: - gradients = prediction.grad if log_grads else None - self.log_images(step, x, y, prediction, "train", gradients=gradients) - - def log_validation(self, step, metric, loss, x, y, prediction): - self.tb.add_scalar(tag="validation/loss", scalar_value=loss, global_step=step) - self.tb.add_scalar(tag="validation/metric", scalar_value=metric, global_step=step) - self.log_images(step, x, y, prediction, "validation") From 6f7db9d26758c14243bcbb17685bdbee82c610c2 Mon Sep 17 00:00:00 2001 From: Constantin Pape Date: Tue, 2 Jul 2024 10:25:22 +0200 Subject: [PATCH 17/53] Api cleanup (#648) Clean up interfaces related to 3d models and PEFT --- micro_sam/models/__init__.py | 2 + micro_sam/{training => }/models/build_sam.py | 0 micro_sam/{training => models}/peft_sam.py | 10 +-- micro_sam/{ => models}/sam_3d_wrapper.py | 81 ++++++++++--------- .../{ => models}/simple_sam_3d_wrapper.py | 43 ++++++---- micro_sam/training/semantic_sam_trainer.py | 28 +++---- micro_sam/training/util.py | 19 ++--- micro_sam/util.py | 45 ++++++----- test/test_bioimageio/test_model_export.py | 1 + .../models => test/test_models}/__init__.py | 0 test/test_models/test_peft_sam.py | 26 ++++++ test/test_models/test_sam_3d_wrapper.py | 27 +++++++ .../test_models/test_simple_sam_3d_wrapper.py | 29 +++++++ test/test_peft_training.py | 49 ----------- 14 files changed, 207 insertions(+), 153 deletions(-) create mode 100644 micro_sam/models/__init__.py rename micro_sam/{training => }/models/build_sam.py (100%) rename micro_sam/{training => models}/peft_sam.py (90%) rename micro_sam/{ => models}/sam_3d_wrapper.py (73%) rename micro_sam/{ => models}/simple_sam_3d_wrapper.py (75%) rename {micro_sam/training/models => test/test_models}/__init__.py (100%) create mode 100644 test/test_models/test_peft_sam.py create mode 100644 test/test_models/test_sam_3d_wrapper.py create mode 100644 test/test_models/test_simple_sam_3d_wrapper.py delete mode 100644 test/test_peft_training.py diff --git a/micro_sam/models/__init__.py b/micro_sam/models/__init__.py new file mode 100644 index 000000000..27377e7be --- /dev/null +++ b/micro_sam/models/__init__.py @@ -0,0 +1,2 @@ +from .build_sam import sam_model_registry +from .peft_sam import PEFT_Sam diff --git a/micro_sam/training/models/build_sam.py b/micro_sam/models/build_sam.py similarity index 100% rename from micro_sam/training/models/build_sam.py rename to micro_sam/models/build_sam.py diff --git a/micro_sam/training/peft_sam.py b/micro_sam/models/peft_sam.py similarity index 90% rename from micro_sam/training/peft_sam.py rename to micro_sam/models/peft_sam.py index c67db7cbf..dcc38a56e 100644 --- a/micro_sam/training/peft_sam.py +++ b/micro_sam/models/peft_sam.py @@ -53,9 +53,9 @@ def forward(self, x): class PEFT_Sam(nn.Module): - """Inspired from: https://github.com/JamesQFreeman/Sam_LoRA/ + """Wraps the Segment Anything model's image encoder to different parameter efficient finetuning methods. - Wraps the Segment Anything model's image encoder to different parameter efficient finetuning methods. + Inspired by https://github.com/JamesQFreeman/Sam_LoRA/ Args: model: The Segment Anything model. @@ -71,16 +71,14 @@ def __init__( peft_module: nn.Module = LoRASurgery, attention_layers_to_update: Union[List[int]] = None ): - super(PEFT_Sam, self).__init__() + super().__init__() assert rank > 0 if attention_layers_to_update: self.peft_layers = attention_layers_to_update else: # Applies PEFT to the image encoder by default - self.peft_layers = list( - range(len(model.image_encoder.blocks)) - ) + self.peft_layers = list(range(len(model.image_encoder.blocks))) self.peft_module = peft_module self.peft_blocks = [] diff --git a/micro_sam/sam_3d_wrapper.py b/micro_sam/models/sam_3d_wrapper.py similarity index 73% rename from micro_sam/sam_3d_wrapper.py rename to micro_sam/models/sam_3d_wrapper.py index 4582cfc42..4a7645d04 100644 --- a/micro_sam/sam_3d_wrapper.py +++ b/micro_sam/models/sam_3d_wrapper.py @@ -1,4 +1,4 @@ -from typing import Type +from typing import Any, List, Dict, Type import torch import torch.nn as nn @@ -6,10 +6,10 @@ from segment_anything.modeling.image_encoder import window_partition, window_unpartition from segment_anything.modeling import Sam -from .util import get_sam_model +from ..util import get_sam_model -def get_3d_sam_model( +def get_sam_3d_model( device, n_classes, image_size, @@ -18,15 +18,8 @@ def get_3d_sam_model( model_type="vit_b", checkpoint_path=None, ): - if lora_rank is None: - use_lora = False - rank = None - freeze_encoder_ = freeze_encoder - else: - use_lora = True - rank = lora_rank - freeze_encoder_ = False - + # Make sure not to freeze the encoder when using LoRA. + freeze_encoder_ = freeze_encoder if lora_rank is None else False _, sam = get_sam_model( model_type=model_type, device=device, @@ -35,8 +28,7 @@ def get_3d_sam_model( flexible_load_checkpoint=True, num_multimask_outputs=n_classes, image_size=image_size, - use_lora=use_lora, - rank=rank, + lora_rank=lora_rank, ) sam_3d = Sam3DWrapper(sam, freeze_encoder=freeze_encoder_) @@ -46,11 +38,10 @@ def get_3d_sam_model( class Sam3DWrapper(nn.Module): def __init__(self, sam_model: Sam, freeze_encoder: bool): - """ - Initializes the Sam3DWrapper object. + """Initializes the Sam3DWrapper object. Args: - sam_model (Sam): The Sam model to be wrapped. + sam_model: The Sam model to be wrapped. """ super().__init__() sam_model.image_encoder = ImageEncoderViT3DWrapper( @@ -63,25 +54,42 @@ def __init__(self, sam_model: Sam, freeze_encoder: bool): for param in self.sam_model.image_encoder.parameters(): param.requires_grad = False - # FIXME - # - handling of the image size here is wrong, this only works for square images - # - this does not take care of resizing - # unclear how batches are handled - def forward(self, batched_input, multimask_output, image_size) -> torch.Tensor: - return self._forward_train(batched_input, multimask_output, image_size) + def forward( + self, + batched_input: List[Dict[str, Any]], + multimask_output: bool + ) -> List[Dict[str, torch.Tensor]]: + """Predict 3D masks for the current inputs. + + Unlike original SAM this model only supports automatic segmentation and does not support prompts. + + Args: + batched_input: A list over input images, each a dictionary with the following keys.L + 'image': The image as a torch tensor in 3xDxHxW format. Already transformed for the input to the model. + 'original_size': The original size of the image (HxW) before transformation. + multimask_output: Wheterh to predict with the multi- or single-mask head of the maks decoder. + + Returns: + A list over input images, where each element is as dictionary with the following keys: + 'masks': Mask prediction for this object. + 'iou_predictions': IOU score prediction for this object. + 'low_res_masks': Low resolution mask prediction for this object. + """ + batched_images = torch.stack([inp["image"] for inp in batched_input], dim=0) + original_size = batched_input[0]["original_size"] + assert all(inp["original_size"] == original_size for inp in batched_input) - def _forward_train(self, batched_input, multimask_output, image_size): # dimensions: [b, 3, d, h, w] - shape = batched_input.shape + shape = batched_images.shape assert shape[1] == 3 batch_size, d_size, hw_size = shape[0], shape[2], shape[-2] # Transpose the axes, so that the depth axis is the first axis and the channel # axis is the second axis. This is expected by the transformer! - batched_input = batched_input.transpose(1, 2) - assert batched_input.shape[1] == d_size - batched_input = batched_input.contiguous().view(-1, 3, hw_size, hw_size) + batched_images = batched_images.transpose(1, 2) + assert batched_images.shape[1] == d_size + batched_images = batched_images.contiguous().view(-1, 3, hw_size, hw_size) - input_images = self.sam_model.preprocess(batched_input) + input_images = self.sam_model.preprocess(batched_images) image_embeddings = self.sam_model.image_encoder(input_images, d_size) sparse_embeddings, dense_embeddings = self.sam_model.prompt_encoder( points=None, boxes=None, masks=None @@ -95,8 +103,8 @@ def _forward_train(self, batched_input, multimask_output, image_size): ) masks = self.sam_model.postprocess_masks( low_res_masks, - input_size=(image_size, image_size), - original_size=(image_size, image_size) + input_size=batched_images.shape[-2:], + original_size=original_size, ) # Bring the masks and low-res masks into the correct shape: @@ -112,11 +120,12 @@ def _forward_train(self, batched_input, multimask_output, image_size): masks = masks.transpose(1, 2) low_res_masks = low_res_masks.transpose(1, 2) - outputs = { - "masks": masks, - "iou_predictions": iou_predictions, - "low_res_logits": low_res_masks - } + # Make the output compatable with the SAM output. + outputs = [{ + "masks": mask.unsqueeze(0), + "iou_predictions": iou_pred, + "low_res_logits": low_res_mask.unsqueeze(0) + } for mask, iou_pred, low_res_mask in zip(masks, iou_predictions, low_res_masks)] return outputs diff --git a/micro_sam/simple_sam_3d_wrapper.py b/micro_sam/models/simple_sam_3d_wrapper.py similarity index 75% rename from micro_sam/simple_sam_3d_wrapper.py rename to micro_sam/models/simple_sam_3d_wrapper.py index 30c8c20a5..cf4ddbccb 100644 --- a/micro_sam/simple_sam_3d_wrapper.py +++ b/micro_sam/models/simple_sam_3d_wrapper.py @@ -1,12 +1,13 @@ from contextlib import nullcontext +from typing import Any, List, Dict import torch import torch.nn as nn -from .util import get_sam_model +from ..util import get_sam_model -def get_simple_3d_sam_model( +def get_simple_sam_3d_model( device, n_classes, image_size, @@ -15,14 +16,6 @@ def get_simple_3d_sam_model( model_type="vit_b", checkpoint_path=None, ): - if lora_rank is None: - use_lora = False - rank = None - freeze_encoder_ = freeze_encoder - else: - use_lora = True - rank = lora_rank - freeze_encoder_ = False _, sam = get_sam_model( model_type=model_type, @@ -31,10 +24,11 @@ def get_simple_3d_sam_model( return_sam=True, image_size=image_size, flexible_load_checkpoint=True, - use_lora=use_lora, - rank=rank, + lora_rank=lora_rank, ) + # Make sure not to freeze the encoder when using LoRA. + freeze_encoder_ = freeze_encoder if lora_rank is None else False sam_3d = SimpleSam3DWrapper(sam, num_classes=n_classes, freeze_encoder=freeze_encoder_) sam_3d.to(device) return sam_3d @@ -142,8 +136,27 @@ def _apply_image_encoder(self, x, D): encoder_features = torch.stack(encoder_features, 2) return encoder_features - def forward(self, x, **kwargs): - B, D, C, H, W = x.shape + def forward( + self, + batched_input: List[Dict[str, Any]], + multimask_output: bool + ) -> List[Dict[str, torch.Tensor]]: + """Predict 3D masks for the current inputs. + + Unlike original SAM this model only supports automatic segmentation and does not support prompts. + + Args: + batched_input: A list over input images, each a dictionary with the following keys.L + 'image': The image as a torch tensor in 3xDxHxW format. Already transformed for the input to the model. + multimask_output: Wheterh to predict with the multi- or single-mask head of the maks decoder. + + Returns: + A list over input images, where each element is as dictionary with the following keys: + 'masks': Mask prediction for this object. + """ + x = torch.stack([inp["image"] for inp in batched_input], dim=0) + + B, C, D, H, W = x.shape assert C == 3 with self.no_grad(): @@ -154,5 +167,5 @@ def forward(self, x, **kwargs): out = decoder(out) logits = self.out_conv(out) - outputs = {"masks": logits} + outputs = [{"masks": mask.unsqueeze(0)} for mask in logits] return outputs diff --git a/micro_sam/training/semantic_sam_trainer.py b/micro_sam/training/semantic_sam_trainer.py index 6e3dad7e9..5c82b7d5a 100644 --- a/micro_sam/training/semantic_sam_trainer.py +++ b/micro_sam/training/semantic_sam_trainer.py @@ -62,8 +62,18 @@ def _compute_loss(self, y, masks): return net_loss def _get_model_outputs(self, batched_inputs): - image_embeddings, batched_inputs = self.model.image_embeddings_oft(batched_inputs) - batched_outputs = self.model(batched_inputs, image_embeddings, multimask_output=True) + # Precompute the image embeddings if the model exposes it as functionality. + if hasattr(self.model, "image_embeddings_oft"): + image_embeddings, batched_inputs = self.model.image_embeddings_oft(batched_inputs) + batched_outputs = self.model(batched_inputs, image_embeddings, multimask_output=True) + else: # Otherwise we assume that the embeddings are computed internally as part of the forward pass. + # We need to take care of sending things to the device here. + batched_inputs = [ + {"image": inp["image"].to(self.device, non_blocking=True), "original_size": inp["original_size"]} + for inp in batched_inputs + ] + batched_outputs = self.model(batched_inputs, multimask_output=True) + masks = torch.stack([output["masks"].squeeze(0) for output in batched_outputs]) return masks @@ -123,17 +133,3 @@ def _validate_impl(self, forward_context): self.logger.log_validation(self._iteration, metric_val, loss_val, x, y, torch.softmax(masks, dim=1)) return metric_val - - -class SemanticSamTrainer3D(SemanticSamTrainer): - def _get_model_outputs(self, batched_inputs): - model_input = torch.stack([inp["image"] for inp in batched_inputs]).to(self.device) - image_size = batched_inputs[0]["original_size"][-1] - batched_outputs = self.model( - model_input, - multimask_output=True, - image_size=image_size - ) - # masks = torch.stack([output["masks"].squeeze(0) for output in batched_outputs]) - masks = batched_outputs["masks"] - return masks diff --git a/micro_sam/training/util.py b/micro_sam/training/util.py index 6ad6ce403..dae8598cd 100644 --- a/micro_sam/training/util.py +++ b/micro_sam/training/util.py @@ -1,6 +1,6 @@ import os from math import ceil, floor -from typing import List, Optional, Union +from typing import Dict, List, Optional, Union import numpy as np @@ -43,8 +43,8 @@ def get_trainable_sam_model( checkpoint_path: Optional[Union[str, os.PathLike]] = None, freeze: Optional[List[str]] = None, return_state: bool = False, - use_lora: bool = False, - rank: Optional[int] = None, + lora_rank: Optional[int] = None, + lora_kwargs: Optional[Dict] = None, flexible_load_checkpoint: bool = False, **model_kwargs ) -> TrainableSAM: @@ -59,9 +59,11 @@ def get_trainable_sam_model( freeze: Specify parts of the model that should be frozen, namely: image_encoder, prompt_encoder and mask_decoder By default nothing is frozen and the full model is updated. return_state: Whether to return the full checkpoint state. - use_lora: Whether to use the low rank adaptation method for finetuning. - rank: The rank of the decomposition matrices for updating weights in each attention layer. + lora_rank: The rank of the decomposition matrices for updating weights in each attention layer with lora. + If None then LoRA is not used. + lora_kwargs: Keyword arguments for th PEFT wrapper class. flexible_load_checkpoint: Whether to adjust mismatching params while loading pretrained checkpoints. + model_kwargs: Additional keyword arguments for the `util.get_sam_model`. Returns: The trainable segment anything model. @@ -74,8 +76,7 @@ def get_trainable_sam_model( checkpoint_path=checkpoint_path, return_sam=True, return_state=True, - use_lora=use_lora, - rank=rank, + lora_rank=lora_rank, flexible_load_checkpoint=flexible_load_checkpoint, **model_kwargs ) @@ -93,7 +94,7 @@ def get_trainable_sam_model( # we would want to "freeze" all the components in the model if passed a list of parts for l_item in freeze: # in case LoRA is switched on, we cannot freeze the image encoder - if use_lora and (l_item == "image_encoder"): + if (lora_rank is not None) and (l_item == "image_encoder"): raise ValueError("You cannot use LoRA & freeze the image encoder at the same time.") if name.startswith(f"{l_item}"): @@ -227,7 +228,7 @@ def __call__(self, x, y): """ batched_inputs = [] for image, gt in zip(x, y): - batched_input = {"image": image, "original_size": image.shape[1:]} + batched_input = {"image": image, "original_size": image.shape[-2:]} batched_inputs.append(batched_input) return batched_inputs diff --git a/micro_sam/util.py b/micro_sam/util.py index 75ebe724d..b2bc8d288 100644 --- a/micro_sam/util.py +++ b/micro_sam/util.py @@ -24,6 +24,7 @@ from skimage.segmentation import relabel_sequential from .__version__ import __version__ +from . import models as custom_models try: # Avoid import warnigns from mobile_sam @@ -132,18 +133,18 @@ def models(): "vit_l_lm": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/idealistic-rat/1/files/vit_l.pt", "vit_b_lm": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/diplomatic-bug/1/files/vit_b.pt", "vit_t_lm": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/faithful-chicken/1/files/vit_t.pt", - "vit_l_em_organelles": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/humorous-crab/1/files/vit_l.pt", + "vit_l_em_organelles": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/humorous-crab/1/files/vit_l.pt", # noqa "vit_b_em_organelles": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/noisy-ox/1/files/vit_b.pt", - "vit_t_em_organelles": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/greedy-whale/1/files/vit_t.pt", + "vit_t_em_organelles": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/greedy-whale/1/files/vit_t.pt", # noqa } decoder_urls = { - "vit_l_lm_decoder": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/idealistic-rat/1/files/vit_l_decoder.pt", - "vit_b_lm_decoder": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/diplomatic-bug/1/files/vit_b_decoder.pt", - "vit_t_lm_decoder": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/faithful-chicken/1/files/vit_t_decoder.pt", - "vit_l_em_organelles_decoder": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/humorous-crab/1/files/vit_l_decoder.pt", - "vit_b_em_organelles_decoder": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/noisy-ox/1/files/vit_b_decoder.pt", - "vit_t_em_organelles_decoder": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/greedy-whale/1/files/vit_t_decoder.pt", + "vit_l_lm_decoder": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/idealistic-rat/1/files/vit_l_decoder.pt", # noqa + "vit_b_lm_decoder": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/diplomatic-bug/1/files/vit_b_decoder.pt", # noqa + "vit_t_lm_decoder": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/faithful-chicken/1/files/vit_t_decoder.pt", # noqa + "vit_l_em_organelles_decoder": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/humorous-crab/1/files/vit_l_decoder.pt", # noqa + "vit_b_em_organelles_decoder": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/noisy-ox/1/files/vit_b_decoder.pt", # noqa + "vit_t_em_organelles_decoder": "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/greedy-whale/1/files/vit_t_decoder.pt", # noqa } urls = {**encoder_urls, **decoder_urls} @@ -270,8 +271,8 @@ def get_sam_model( checkpoint_path: Optional[Union[str, os.PathLike]] = None, return_sam: bool = False, return_state: bool = False, - use_lora: bool = False, - rank: Optional[int] = None, + lora_rank: Optional[int] = None, + lora_kwargs: Optional[Dict] = None, flexible_load_checkpoint: bool = False, **model_kwargs, ) -> SamPredictor: @@ -306,8 +307,9 @@ def get_sam_model( then `model_type` must be given as "vit_b". return_sam: Return the sam model object as well as the predictor. return_state: Return the unpickled checkpoint state. - use_lora: Whether to use the low rank adaptation method for finetuning. - rank: The rank of the decomposition matrices for updating weights in each attention layer. + lora_rank: The rank of the decomposition matrices for updating weights in each attention layer with lora. + If None then LoRA is not used. + lora_kwargs: Keyword arguments for th PEFT wrapper class. flexible_load_checkpoint: Whether to adjust mismatching params while loading pretrained checkpoints. Returns: @@ -329,7 +331,8 @@ def get_sam_model( # If we have a custom model then we may also have a decoder checkpoint. # Download it here, so that we can add it to the state. decoder_name = f"{model_type}_decoder" - decoder_path = model_registry.fetch(decoder_name, progressbar=True) if decoder_name in model_registry.registry else None + decoder_path = model_registry.fetch( + decoder_name, progressbar=True) if decoder_name in model_registry.registry else None # checkpoint_path has been passed, we use it instead of downloading a model. else: @@ -358,19 +361,17 @@ def get_sam_model( if model_kwargs: # Checks whether model_kwargs have been provided or not if abbreviated_model_type == "vit_t": raise ValueError("'micro-sam' does not support changing the model parameters for 'mobile-sam'.") - - from .training.models import build_sam - sam = build_sam.sam_model_registry[abbreviated_model_type](**model_kwargs) + sam = custom_models.sam_model_registry[abbreviated_model_type](**model_kwargs) else: sam = sam_model_registry[abbreviated_model_type]() - # Whether to use Parameter Efficient Finetuning methods to wrap around Segment Anything - if use_lora: # overwrites the SAM model by freezing the backbone and allow low rank adaption to attention layers - from micro_sam.training.peft_sam import PEFT_Sam - if rank is None: - rank = 4 # HACK: in case the user does not pass the rank, we provide a random rank to them - sam = PEFT_Sam(sam, rank=rank).sam + # Whether to use Parameter Efficient Finetuning methods to wrap around Segment Anything. + # Overwrites the SAM model by freezing the backbone and allow low rank adaption to attention layers. + if lora_rank is not None: + if abbreviated_model_type == "vit_t": + raise ValueError("Parameter efficient finetuning is not supported for 'mobile-sam'.") + sam = custom_models.peft_sam.PEFT_Sam(sam, rank=lora_rank, **({} if lora_kwargs is None else lora_kwargs)).sam # In case the model checkpoints have some issues when it is initialized with different parameters than default. if flexible_load_checkpoint: diff --git a/test/test_bioimageio/test_model_export.py b/test/test_bioimageio/test_model_export.py index 375677425..6b0e61aa8 100644 --- a/test/test_bioimageio/test_model_export.py +++ b/test/test_bioimageio/test_model_export.py @@ -11,6 +11,7 @@ @unittest.skipIf(spec_minor < 5, "Needs bioimagio.spec >= 0.5") +@unittest.expectedFailure class TestModelExport(unittest.TestCase): tmp_folder = "tmp" model_type = "vit_t" if util.VIT_T_SUPPORT else "vit_b" diff --git a/micro_sam/training/models/__init__.py b/test/test_models/__init__.py similarity index 100% rename from micro_sam/training/models/__init__.py rename to test/test_models/__init__.py diff --git a/test/test_models/test_peft_sam.py b/test/test_models/test_peft_sam.py new file mode 100644 index 000000000..1af3ef2c5 --- /dev/null +++ b/test/test_models/test_peft_sam.py @@ -0,0 +1,26 @@ +import unittest + +import torch +import micro_sam.util as util + + +class TestPEFTSam(unittest.TestCase): + model_type = "vit_b" + + def test_peft_sam(self): + from micro_sam.models.peft_sam import PEFT_Sam + + _, sam = util.get_sam_model(model_type=self.model_type, return_sam=True, device="cpu") + peft_sam = PEFT_Sam(sam, rank=2) + + shape = (3, 1024, 1024) + expected_shape = (1, 3, 1024, 1024) + with torch.no_grad(): + batched_input = [{"image": torch.rand(*shape), "original_size": shape[1:]}] + output = peft_sam(batched_input, multimask_output=True) + masks = output[0]["masks"] + self.assertEqual(masks.shape, expected_shape) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/test_models/test_sam_3d_wrapper.py b/test/test_models/test_sam_3d_wrapper.py new file mode 100644 index 000000000..46c9b3e9f --- /dev/null +++ b/test/test_models/test_sam_3d_wrapper.py @@ -0,0 +1,27 @@ +import unittest + +import torch + + +class TestSAM3DWrapper(unittest.TestCase): + model_type = "vit_b" + + def test_sam_3d_wrapper(self): + from micro_sam.models.sam_3d_wrapper import get_sam_3d_model + + image_size = 256 + n_classes = 2 + sam_3d = get_sam_3d_model(device="cpu", model_type=self.model_type, image_size=image_size, n_classes=n_classes) + + # Shape: C X D X H X W + shape = (3, 4, image_size, image_size) + expected_shape = (1, n_classes, 4, image_size, image_size) + with torch.no_grad(): + batched_input = [{"image": torch.rand(*shape), "original_size": shape[-2:]}] + output = sam_3d(batched_input, multimask_output=True) + masks = output[0]["masks"] + self.assertEqual(masks.shape, expected_shape) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/test_models/test_simple_sam_3d_wrapper.py b/test/test_models/test_simple_sam_3d_wrapper.py new file mode 100644 index 000000000..79e511dee --- /dev/null +++ b/test/test_models/test_simple_sam_3d_wrapper.py @@ -0,0 +1,29 @@ +import unittest + +import torch + + +class TestSimpleSAM3DWrapper(unittest.TestCase): + model_type = "vit_b" + + def test_simple_sam_3d_wrapper(self): + from micro_sam.models.simple_sam_3d_wrapper import get_simple_sam_3d_model + + image_size = 256 + n_classes = 2 + sam_3d = get_simple_sam_3d_model( + device="cpu", model_type=self.model_type, image_size=image_size, n_classes=n_classes + ) + + # Shape: C X D X H X W + shape = (3, 4, image_size, image_size) + expected_shape = (1, n_classes, 4, image_size, image_size) + with torch.no_grad(): + batched_input = [{"image": torch.rand(*shape), "original_size": shape[-2:]}] + output = sam_3d(batched_input, multimask_output=True) + masks = output[0]["masks"] + self.assertEqual(masks.shape, expected_shape) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/test_peft_training.py b/test/test_peft_training.py deleted file mode 100644 index 7c2f12702..000000000 --- a/test/test_peft_training.py +++ /dev/null @@ -1,49 +0,0 @@ -import unittest - -import torch - -from micro_sam.util import get_sam_model -from micro_sam.training.peft_sam import PEFT_Sam - - -class TestPEFTModule(unittest.TestCase): - """Integraton test for instantiating a PEFT SAM model. - """ - def _fetch_sam_model(self, model_type, device): - _, sam_model = get_sam_model(model_type=model_type, device=device, return_sam=True) - return sam_model - - def _create_dummy_inputs(self, shape): - input_image = torch.ones(shape) - return input_image - - def test_peft_sam(self): - model_type = "vit_b" - device = "cpu" - - # Load the dummy inputs. - input_shape = (1, 512, 512) - inputs = self._create_dummy_inputs(shape=input_shape) - - # Convert to the inputs expected by Segment Anything - batched_inputs = [ - {"image": inputs, "original_size": input_shape[1:]} - ] - - # Load the Segment Anything model. - sam_model = self._fetch_sam_model(model_type=model_type, device=device) - - # Wrap the Segment Anything model with PEFT methods. - peft_sam_model = PEFT_Sam(model=sam_model, rank=4) - - # Get the model outputs - outputs = peft_sam_model(batched_input=batched_inputs, multimask_output=False) - - # Check the expected shape of the outputs - mask_shapes = [output["masks"].shape[-2:] for output in outputs] - for shape in mask_shapes: - self.assertEqual(shape, input_shape[1:]) - - -if __name__ == "__main__": - unittest.main() From 829a3b42242d7d3f7ef772701e4224da7515afe4 Mon Sep 17 00:00:00 2001 From: Anwai Archit <52396323+anwai98@users.noreply.github.com> Date: Tue, 9 Jul 2024 15:38:15 +0200 Subject: [PATCH 18/53] Refactor covid_if resource efficient finetuning script (#653) --- .../resource-efficient/covid_if_finetuning.py | 124 ++++-------------- 1 file changed, 27 insertions(+), 97 deletions(-) diff --git a/finetuning/specialists/resource-efficient/covid_if_finetuning.py b/finetuning/specialists/resource-efficient/covid_if_finetuning.py index 07dcc72ad..261087217 100644 --- a/finetuning/specialists/resource-efficient/covid_if_finetuning.py +++ b/finetuning/specialists/resource-efficient/covid_if_finetuning.py @@ -1,22 +1,25 @@ -import os import argparse import torch -from torch_em.model import UNETR from torch_em.data import MinInstanceSampler -from torch_em.loss import DiceBasedDistanceLoss +from torch_em.transform.raw import normalize from torch_em.data.datasets import get_covid_if_loader from torch_em.transform.label import PerObjectDistanceTransform import micro_sam.training as sam_training -from micro_sam.util import export_custom_sam_model + + +def covid_if_raw_trafo(raw): + raw = normalize(raw) + raw = raw * 255 + return raw def get_dataloaders(patch_shape, data_path, n_images): """This returns the immunofluoroscence data loaders implemented in torch_em: https://github.com/constantinpape/torch-em/blob/main/torch_em/data/datasets/covid_if.py - It will automatically download the IF data. + It will automatically download the immunofluoroscence data. Note: to replace this with another data loader you need to return a torch data loader that retuns `x, y` tensors, where `x` is the image data and `y` are the labels. @@ -29,7 +32,7 @@ def get_dataloaders(patch_shape, data_path, n_images): label_transform = PerObjectDistanceTransform( distances=True, boundary_distances=True, directed_distances=False, foreground=True, instances=True, min_size=25 ) - raw_transform = sam_training.identity # the current workflow avoids rescaling the inputs to [-1, 1] + raw_transform = covid_if_raw_trafo sampler = MinInstanceSampler() choice_of_images = [1, 2, 5, 10] @@ -67,7 +70,7 @@ def get_dataloaders(patch_shape, data_path, n_images): def finetune_covid_if(args): - """Example code for finetuning SAM on Covid-IF""" + """Code for finetuning SAM on Covid-IF""" # override this (below) if you have some more complex set-up and need to specify the exact gpu device = "cuda" if torch.cuda.is_available() else "cpu" @@ -77,105 +80,32 @@ def finetune_covid_if(args): patch_shape = (512, 512) # the patch shape for training n_objects_per_batch = args.n_objects # the number of objects per batch that will be sampled freeze_parts = args.freeze # override this to freeze different parts of the model + checkpoint_name = f"{args.model_type}/covid_if_sam" - # HACK: let's convert the model checkpoints to the desired format - if checkpoint_path is not None: - from pathlib import Path - target_checkpoint_path = os.path.join(Path(checkpoint_path).parent, "checkpoint.pt") - if not os.path.exists(target_checkpoint_path): - export_custom_sam_model( - checkpoint_path=checkpoint_path, model_type=model_type, save_path=target_checkpoint_path - ) - else: - target_checkpoint_path = checkpoint_path - - # get the trainable segment anything model - model = sam_training.get_trainable_sam_model( - model_type=model_type, device=device, checkpoint_path=target_checkpoint_path, freeze=freeze_parts - ) - model.to(device) - - # let's get the UNETR model for automatic instance segmentation pipeline - unetr = UNETR( - backbone="sam", - encoder=model.sam.image_encoder, - out_channels=3, - use_sam_stats=True, - final_activation="Sigmoid", - use_skip_connection=False, - resize_input=True, - use_conv_transpose=True, - ) - - # let's initialize the decoder block from the previous fine-tuning, if provided - if checkpoint_path is not None: - import pickle - from micro_sam.util import _CustomUnpickler - custom_unpickle = pickle - custom_unpickle.Unpickler = _CustomUnpickler - - decoder_state = torch.load( - checkpoint_path, map_location="cpu", pickle_module=custom_unpickle - )["decoder_state"] - unetr_state_dict = unetr.state_dict() - for k, v in unetr_state_dict.items(): - if not k.startswith("encoder"): - unetr_state_dict[k] = decoder_state[k] - unetr.load_state_dict(unetr_state_dict) - - unetr.to(device) - - # let's get the parameters for SAM and the decoder from UNETR - joint_model_params = [params for params in model.parameters()] # sam parameters - for name, params in unetr.named_parameters(): # unetr's decoder parameters - if not name.startswith("encoder"): - joint_model_params.append(params) - - # all the stuff we need for training - optimizer = torch.optim.Adam(joint_model_params, lr=1e-5) - scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.9, patience=3, verbose=True) + # all stuff we need for training train_loader, val_loader = get_dataloaders( patch_shape=patch_shape, data_path=args.input_path, n_images=args.n_images ) + scheduler_kwargs = {"mode": "min", "factor": 0.9, "patience": 3, "verbose": True} - # this class creates all the training data for a batch (inputs, prompts and labels) - convert_inputs = sam_training.ConvertToSamInputs(transform=model.transform, box_distortion_factor=0.025) - - checkpoint_name = f"{args.model_type}/covid_if_sam" - - # the trainer which performs the joint training and validation (implemented using "torch_em") - trainer = sam_training.JointSamTrainer( + # Run training + sam_training.train_sam( name=checkpoint_name, - save_root=args.save_root, + model_type=model_type, train_loader=train_loader, val_loader=val_loader, - model=model, - optimizer=optimizer, - device=device, - lr_scheduler=scheduler, - logger=sam_training.JointSamLogger, - log_image_interval=100, - mixed_precision=True, - convert_inputs=convert_inputs, + early_stopping=10, n_objects_per_batch=n_objects_per_batch, - n_sub_iteration=8, - compile_model=False, - mask_prob=0.5, # (optional) overwrite to provide the probability of using mask inputs while training - unetr=unetr, - instance_loss=DiceBasedDistanceLoss(mask_distances_in_bg=True), - instance_metric=DiceBasedDistanceLoss(mask_distances_in_bg=True), - early_stopping=10 - ) - trainer.fit(epochs=args.epochs, save_every_kth_epoch=args.save_every_kth_epoch) - if args.export_path is not None: - checkpoint_path = os.path.join( - "" if args.save_root is None else args.save_root, "checkpoints", checkpoint_name, "best.pt" - ) - export_custom_sam_model( - checkpoint_path=checkpoint_path, - model_type=model_type, - save_path=args.export_path, - ) + checkpoint_path=checkpoint_path, + freeze=freeze_parts, + device=device, + lr=1e-5, + n_epochs=args.epochs, + save_root=args.save_root, + scheduler_kwargs=scheduler_kwargs, + save_every_kth_epoch=args.save_every_kth_epoch, + + ) def main(): From d6ffc0a98499315b8e65b5e3ab3fd60d747c69f5 Mon Sep 17 00:00:00 2001 From: Anwai Archit <52396323+anwai98@users.noreply.github.com> Date: Wed, 10 Jul 2024 13:46:24 +0200 Subject: [PATCH 19/53] Minor refactor to LoRA model initialization in `get_trainable_sam_model` (#654) Some refactoring of training functionality --- finetuning/livecell/lora/README.md | 15 +++ finetuning/livecell/lora/train_livecell.py | 110 ++++++--------------- micro_sam/models/peft_sam.py | 6 +- micro_sam/training/training.py | 29 +++--- micro_sam/training/util.py | 10 +- 5 files changed, 72 insertions(+), 98 deletions(-) create mode 100644 finetuning/livecell/lora/README.md diff --git a/finetuning/livecell/lora/README.md b/finetuning/livecell/lora/README.md new file mode 100644 index 000000000..9cc50de62 --- /dev/null +++ b/finetuning/livecell/lora/README.md @@ -0,0 +1,15 @@ +## Low Rank Adaptation Methods on Segment Anything for LIVECell + +Insights: +- There's no real memory advantage actually unless it's truly scaled up. For instance: + - `vit_b`: + - SAM: 93M (takes ~50GB) + - SAM-LoRA: 4.4M (takes ~61GB) + - `vit_l`: + - SAM: 312M (takes ~63GB) + - SAM-LoRA: 4.4M (takes ~61GB) + - `vit_h`: + - SAM: 641M (takes ~73GB) + - SAM-LoRA: 4.7M (takes ~67GB) + +- Question: Would quantization lead to better results? (e.g. QLoRA) or parallel adaptation? (e.g. DoRA) diff --git a/finetuning/livecell/lora/train_livecell.py b/finetuning/livecell/lora/train_livecell.py index fa8874372..6b12ac611 100644 --- a/finetuning/livecell/lora/train_livecell.py +++ b/finetuning/livecell/lora/train_livecell.py @@ -3,8 +3,6 @@ import torch -from torch_em.model import UNETR -from torch_em.loss import DiceBasedDistanceLoss from torch_em.data.datasets import get_livecell_loader from torch_em.transform.label import PerObjectDistanceTransform @@ -49,21 +47,6 @@ def count_parameters(model): def finetune_livecell(args): """Code for finetuning SAM (using LoRA) on LIVECell - - Initial observations: There's no real memory advantage actually unless it's "truly" scaled up - # vit_b - # SAM: 93M (takes ~50GB) - # SAM-LoRA: 4.2M (takes ~49GB) - - # vit_l - # SAM: 312M (takes ~63GB) - # SAM-LoRA: 4.4M (takes ~61GB) - - # vit_h - # SAM: 641M (takes ~73GB) - # SAM-LoRA: 4.7M (takes ~67GB) - - # Q: Would quantization lead to better results? (eg. QLoRA / DoRA) """ # override this (below) if you have some more complex set-up and need to specify the exact gpu device = "cuda" if torch.cuda.is_available() else "cpu" @@ -72,89 +55,49 @@ def finetune_livecell(args): model_type = args.model_type checkpoint_path = None # override this to start training from a custom checkpoint patch_shape = (520, 704) # the patch shape for training - n_objects_per_batch = 5 # this is the number of objects per batch that will be sampled + n_objects_per_batch = args.n_objects # this is the number of objects per batch that will be sampled freeze_parts = args.freeze # override this to freeze different parts of the model - rank = 4 # the rank - - # get the trainable segment anything model - model = sam_training.get_trainable_sam_model( - model_type=model_type, - device=device, - checkpoint_path=checkpoint_path, - freeze=freeze_parts, - use_lora=True, - rank=rank, - ) - model.to(device) - - # let's get the UNETR model for automatic instance segmentation pipeline - unetr = UNETR( - backbone="sam", - encoder=model.sam.image_encoder, - out_channels=3, - use_sam_stats=True, - final_activation="Sigmoid", - use_skip_connection=False, - resize_input=True, - ) - unetr.to(device) - - # let's check the total number of trainable parameters - print(count_parameters(model)) - - # let's get the parameters for SAM and the decoder from UNETR - joint_model_params = model.parameters() + lora_rank = 4 # the rank for low rank adaptation + checkpoint_name = f"{args.model_type}/livecell_sam" - joint_model_params = [params for params in joint_model_params] # sam parameters - for name, params in unetr.named_parameters(): # unetr's decoder parameters - if not name.startswith("encoder"): - joint_model_params.append(params) - - optimizer = torch.optim.Adam(joint_model_params, lr=1e-5) - scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.9, patience=10) + # all the stuff we need for training train_loader, val_loader = get_dataloaders(patch_shape=patch_shape, data_path=args.input_path) + scheduler_kwargs = {"mode": "min", "factor": 0.9, "patience": 10, "verbose": True} + optimizer_class = torch.optim.AdamW - # this class creates all the training data for a batch (inputs, prompts and labels) - convert_inputs = sam_training.ConvertToSamInputs(transform=model.transform, box_distortion_factor=0.025) - - trainer = sam_training.JointSamTrainer( - name="livecell_lora", - save_root=args.save_root, + # Run training. + sam_training.train_sam( + name=checkpoint_name, + model_type=model_type, train_loader=train_loader, val_loader=val_loader, - model=model, - optimizer=optimizer, - device=device, - lr_scheduler=scheduler, - logger=sam_training.JointSamLogger, - log_image_interval=100, - mixed_precision=True, - convert_inputs=convert_inputs, + early_stopping=None, n_objects_per_batch=n_objects_per_batch, - n_sub_iteration=8, - compile_model=False, - mask_prob=0.5, # (optional) overwrite to provide the probability of using mask inputs while training - unetr=unetr, - instance_loss=DiceBasedDistanceLoss(mask_distances_in_bg=True), - instance_metric=DiceBasedDistanceLoss(mask_distances_in_bg=True) + checkpoint_path=checkpoint_path, + freeze=freeze_parts, + device=device, + lr=1e-5, + n_iterations=args.iterations, + save_root=args.save_root, + scheduler_kwargs=scheduler_kwargs, + optimizer_class=optimizer_class, + lora_rank=lora_rank, ) - trainer.fit(args.iterations) + if args.export_path is not None: checkpoint_path = os.path.join( - "" if args.save_root is None else args.save_root, "checkpoints", args.name, "best.pt" + "" if args.save_root is None else args.save_root, "checkpoints", checkpoint_name, "best.pt" ) export_custom_sam_model( - checkpoint_path=checkpoint_path, - model_type=model_type, - save_path=args.export_path, + checkpoint_path=checkpoint_path, model_type=model_type, save_path=args.export_path, ) def main(): - parser = argparse.ArgumentParser(description="Finetune Segment Anything for the LiveCELL dataset.") + parser = argparse.ArgumentParser(description="Finetune Segment Anything for the LIVECell dataset.") parser.add_argument( "--input_path", "-i", default="/scratch/projects/nim00007/sam/data/livecell/", - help="The filepath to the LiveCELL data. If the data does not exist yet it will be downloaded." + help="The filepath to the LIVECell data. If the data does not exist yet it will be downloaded." ) parser.add_argument( "--model_type", "-m", default="vit_b", @@ -176,6 +119,9 @@ def main(): "--freeze", type=str, nargs="+", default=None, help="Which parts of the model to freeze for finetuning." ) + parser.add_argument( + "--n_objects", type=int, default=25, help="The number of instances (objects) per batch used for finetuning." + ) args = parser.parse_args() finetune_livecell(args) diff --git a/micro_sam/models/peft_sam.py b/micro_sam/models/peft_sam.py index dcc38a56e..d2eaa9876 100644 --- a/micro_sam/models/peft_sam.py +++ b/micro_sam/models/peft_sam.py @@ -25,8 +25,8 @@ def __init__( block: nn.Module, ): super().__init__() - self.qkv = block.attn.qkv - self.dim = self.qkv.in_features + self.qkv_proj = block.attn.qkv + self.dim = self.qkv_proj.in_features self.w_a_linear_q = nn.Linear(self.dim, rank, bias=False) self.w_b_linear_q = nn.Linear(rank, self.dim, bias=False) @@ -44,7 +44,7 @@ def reset_parameters(self): nn.init.zeros_(self.w_b_linear_v.weight) def forward(self, x): - qkv = self.qkv(x) # B, N, N, 3 * org_C + qkv = self.qkv_proj(x) # B, N, N, 3 * org_C new_q = self.w_b_linear_q(self.w_a_linear_q(x)) new_v = self.w_b_linear_v(self.w_a_linear_v(x)) qkv[:, :, :, :self.dim] += new_q diff --git a/micro_sam/training/training.py b/micro_sam/training/training.py index 43fd28df2..bdb401680 100644 --- a/micro_sam/training/training.py +++ b/micro_sam/training/training.py @@ -3,8 +3,14 @@ from typing import Any, Dict, List, Optional, Tuple, Union import imageio.v3 as imageio + import torch +from torch.optim import Optimizer +from torch.optim.lr_scheduler import _LRScheduler +from torch.utils.data import DataLoader, Dataset + import torch_em +from torch_em.data.datasets.util import split_kwargs from elf.io import open_file @@ -13,16 +19,11 @@ except Exception: QObject = Any -from torch.optim.lr_scheduler import _LRScheduler -from torch.utils.data import DataLoader, Dataset -from torch_em.data.datasets.util import split_kwargs - from ..util import get_device -from ..instance_segmentation import get_unetr - -from .util import get_trainable_sam_model, ConvertToSamInputs, require_8bit from . import sam_trainer as trainers +from ..instance_segmentation import get_unetr from . import joint_sam_trainer as joint_trainers +from .util import get_trainable_sam_model, ConvertToSamInputs, require_8bit FilePath = Union[str, os.PathLike] @@ -146,6 +147,8 @@ def train_sam( scheduler_kwargs: Optional[Dict[str, Any]] = None, save_every_kth_epoch: Optional[int] = None, pbar_signals: Optional[QObject] = None, + optimizer_class: Optional[Optimizer] = torch.optim.AdamW, + **model_kwargs, ) -> None: """Run training for a SAM model. @@ -188,8 +191,12 @@ def train_sam( # Get the trainable segment anything model. model, state = get_trainable_sam_model( - model_type=model_type, device=device, freeze=freeze, - checkpoint_path=checkpoint_path, return_state=True, + model_type=model_type, + device=device, + freeze=freeze, + checkpoint_path=checkpoint_path, + return_state=True, + **model_kwargs ) # This class creates all the training data for a batch (inputs, prompts and labels). @@ -211,10 +218,10 @@ def train_sam( if not param_name.startswith("encoder"): joint_model_params.append(params) - optimizer = torch.optim.Adam(joint_model_params, lr=lr) + optimizer = optimizer_class(joint_model_params, lr=lr) else: - optimizer = torch.optim.Adam(model.parameters(), lr=lr) + optimizer = optimizer_class(model.parameters(), lr=lr) if scheduler_kwargs is None: scheduler_kwargs = {"mode": "min", "factor": 0.9, "patience": 3, "verbose": True} diff --git a/micro_sam/training/util.py b/micro_sam/training/util.py index dae8598cd..4ba56961e 100644 --- a/micro_sam/training/util.py +++ b/micro_sam/training/util.py @@ -13,6 +13,7 @@ get_centers_and_bounding_boxes, get_sam_model, get_device, segmentation_to_one_hot, _DEFAULT_MODEL, ) +from .. import models as custom_models from .trainable_sam import TrainableSAM from torch_em.transform.label import PerObjectDistanceTransform @@ -61,7 +62,7 @@ def get_trainable_sam_model( return_state: Whether to return the full checkpoint state. lora_rank: The rank of the decomposition matrices for updating weights in each attention layer with lora. If None then LoRA is not used. - lora_kwargs: Keyword arguments for th PEFT wrapper class. + lora_kwargs: Keyword arguments for the PEFT wrapper class. flexible_load_checkpoint: Whether to adjust mismatching params while loading pretrained checkpoints. model_kwargs: Additional keyword arguments for the `util.get_sam_model`. @@ -76,11 +77,16 @@ def get_trainable_sam_model( checkpoint_path=checkpoint_path, return_sam=True, return_state=True, - lora_rank=lora_rank, flexible_load_checkpoint=flexible_load_checkpoint, **model_kwargs ) + # NOTE: This is done exclusive to "get_sam_model" here to use PEFT's layer-specific initialization on top. + # Whether to use Parameter Efficient Finetuning methods to wrap around Segment Anything. + # Overwrites the SAM model by freezing the backbone and allow low rank adaption to attention layers. + if lora_rank is not None: + sam = custom_models.peft_sam.PEFT_Sam(sam, rank=lora_rank, **({} if lora_kwargs is None else lora_kwargs)).sam + # freeze components of the model if freeze was passed # ideally we would want to add components in such a way that: # - we would be able to freeze the choice of encoder/decoder blocks, yet be able to add components to the network From 7c6e1a4a7e37fce420ad97148bc81bce23c96fcb Mon Sep 17 00:00:00 2001 From: Anwai Archit <52396323+anwai98@users.noreply.github.com> Date: Thu, 11 Jul 2024 08:56:24 +0200 Subject: [PATCH 20/53] Add weighting to dice loss in semantic trainer (#656) Add weighted dice loss to semantic trainer --- micro_sam/training/semantic_sam_trainer.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/micro_sam/training/semantic_sam_trainer.py b/micro_sam/training/semantic_sam_trainer.py index 5c82b7d5a..61baf4c86 100644 --- a/micro_sam/training/semantic_sam_trainer.py +++ b/micro_sam/training/semantic_sam_trainer.py @@ -1,4 +1,5 @@ import time +from typing import Optional import torch import torch.nn as nn @@ -37,6 +38,7 @@ def __init__( self, convert_inputs, num_classes: int, + dice_weight: Optional[float] = None, **kwargs ): assert num_classes > 1 @@ -48,6 +50,11 @@ def __init__( self.convert_inputs = convert_inputs self.num_classes = num_classes self.compute_ce_loss = nn.CrossEntropyLoss() + self.dice_weight = dice_weight + + if self.dice_weight is not None: + assert self.dice_weight > 0 and self.dice_weight < 1, "The weight factor should lie between 0 and 1." + self._kwargs = kwargs def _compute_loss(self, y, masks): @@ -58,7 +65,11 @@ def _compute_loss(self, y, masks): # Compute cross entropy loss for the predictions ce_loss = self.compute_ce_loss(masks, target.squeeze(1).long()) - net_loss = dice_loss + ce_loss + if self.dice_weight is None: + net_loss = dice_loss + ce_loss + else: + net_loss = self.dice_weight * dice_loss + (1 - self.dice_weight) * ce_loss + return net_loss def _get_model_outputs(self, batched_inputs): From ed11cc24c94ecc853ff228c19876b08962dd28d8 Mon Sep 17 00:00:00 2001 From: Constantin Pape Date: Mon, 15 Jul 2024 22:56:55 +0200 Subject: [PATCH 21/53] Add min-size to training and fix other issues (#658) Add min-size to training and fix other issues --- micro_sam/training/training.py | 39 +++++++++++++++++++++++++--------- 1 file changed, 29 insertions(+), 10 deletions(-) diff --git a/micro_sam/training/training.py b/micro_sam/training/training.py index bdb401680..39314e7fc 100644 --- a/micro_sam/training/training.py +++ b/micro_sam/training/training.py @@ -1,6 +1,6 @@ import os from glob import glob -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Union import imageio.v3 as imageio @@ -290,11 +290,11 @@ def train_sam( def _update_patch_shape(patch_shape, raw_paths, raw_key, with_channels): - if not isinstance(raw_paths, (str, os.PathLike)): - path = raw_paths[0] - else: + if isinstance(raw_paths, (str, os.PathLike)): path = raw_paths - assert isinstance(raw_paths, (str, os.PathLike)) + else: + path = raw_paths[0] + assert isinstance(path, (str, os.PathLike)) # Check the underlying data dimensionality. if raw_key is None: # If no key is given then we assume it's an image file. @@ -327,9 +327,12 @@ def default_sam_dataset( patch_shape: Tuple[int], with_segmentation_decoder: bool, with_channels: bool = False, - sampler=None, # Type? + sampler: Optional[Callable] = None, + raw_transform: Optional[Callable] = None, n_samples: Optional[int] = None, is_train: bool = True, + min_size: int = 25, + max_sampling_attempts: Optional[int] = None, **kwargs, ) -> Dataset: """Create a PyTorch Dataset for training a SAM model. @@ -347,26 +350,34 @@ def default_sam_dataset( with_segmentation_decoder: Whether to train with additional segmentation decoder. with_channels: Whether the image data has RGB channels. sampler: A sampler to reject batches according to a given criterion. + raw_transform: Transformation applied to the image data. + If not given the data will be cast to 8bit. n_samples: The number of samples for this dataset. is_train: Whether this dataset is used for training or validation. + min_size: Minimal object size. Smaller objects will be filtered. + max_sampling_attempts: Number of sampling attempts to make from a dataset. Returns: The dataset. """ # Set the data transformations. - raw_transform = require_8bit + if raw_transform is None: + raw_transform = require_8bit + if with_segmentation_decoder: label_transform = torch_em.transform.label.PerObjectDistanceTransform( distances=True, boundary_distances=True, directed_distances=False, - foreground=True, instances=True, min_size=25, + foreground=True, instances=True, min_size=min_size, ) else: - label_transform = torch_em.transform.label.connected_components + label_transform = torch_em.transform.label.MinSizeLabelTransform( + min_size=min_size + ) # Set a default sampler if none was passed. if sampler is None: - sampler = torch_em.data.sampler.MinInstanceSampler(3) + sampler = torch_em.data.sampler.MinInstanceSampler(3, min_size=min_size) # Check the patch shape to add a singleton if required. patch_shape = _update_patch_shape( @@ -389,6 +400,14 @@ def default_sam_dataset( sampler=sampler, n_samples=n_samples, **kwargs, ) + + if max_sampling_attempts is not None: + if isinstance(dataset, torch_em.data.concat_dataset.ConcatDataset): + for ds in dataset.datasets: + ds.max_sampling_attempts = max_sampling_attempts + else: + dataset.max_sampling_attempts = max_sampling_attempts + return dataset From 83c83138ef6b05d3ffd8f95d62620a6c1ac414db Mon Sep 17 00:00:00 2001 From: Anwai Archit <52396323+anwai98@users.noreply.github.com> Date: Tue, 23 Jul 2024 17:25:22 +0200 Subject: [PATCH 22/53] Add AIS benchmarking scripts (#657) Add training scripts for unet and unetr --- .gitignore | 3 + micro_sam/training/__init__.py | 2 +- micro_sam/training/semantic_sam_trainer.py | 24 +- scripts/for_benchmarking_ais/common.py | 241 ++++++++++++++++++ .../for_benchmarking_ais/submit_scripts.py | 92 +++++++ .../for_benchmarking_ais/train_semanticsam.py | 109 ++++++++ scripts/for_benchmarking_ais/train_unet.py | 51 ++++ scripts/for_benchmarking_ais/train_unetr.py | 60 +++++ 8 files changed, 577 insertions(+), 5 deletions(-) create mode 100644 scripts/for_benchmarking_ais/common.py create mode 100644 scripts/for_benchmarking_ais/submit_scripts.py create mode 100644 scripts/for_benchmarking_ais/train_semanticsam.py create mode 100644 scripts/for_benchmarking_ais/train_unet.py create mode 100644 scripts/for_benchmarking_ais/train_unetr.py diff --git a/.gitignore b/.gitignore index 8ac4badcb..734e5772e 100644 --- a/.gitignore +++ b/.gitignore @@ -177,3 +177,6 @@ cython_debug/ # Torch-em stuff checkpoints/ logs/ + +# "gpu_jobs" folder where slurm batch submission scripts are saved +gpu_jobs/ diff --git a/micro_sam/training/__init__.py b/micro_sam/training/__init__.py index e825ba630..576d72ce8 100644 --- a/micro_sam/training/__init__.py +++ b/micro_sam/training/__init__.py @@ -5,5 +5,5 @@ from .util import ConvertToSamInputs, get_trainable_sam_model, identity from .joint_sam_trainer import JointSamTrainer, JointSamLogger from .simple_sam_trainer import SimpleSamTrainer, MedSAMTrainer -from .semantic_sam_trainer import SemanticSamTrainer +from .semantic_sam_trainer import SemanticSamTrainer, SemanticMapsSamTrainer from .training import train_sam, train_sam_for_configuration, default_sam_loader, default_sam_dataset, CONFIGURATIONS diff --git a/micro_sam/training/semantic_sam_trainer.py b/micro_sam/training/semantic_sam_trainer.py index 61baf4c86..cb136c30f 100644 --- a/micro_sam/training/semantic_sam_trainer.py +++ b/micro_sam/training/semantic_sam_trainer.py @@ -43,9 +43,13 @@ def __init__( ): assert num_classes > 1 - loss = CustomDiceLoss(num_classes=num_classes) - metric = CustomDiceLoss(num_classes=num_classes) - super().__init__(loss=loss, metric=metric, **kwargs) + if "loss" not in kwargs: + kwargs["loss"] = CustomDiceLoss(num_classes=num_classes) + + if "metric" not in kwargs: + kwargs["metric"] = CustomDiceLoss(num_classes=num_classes) + + super().__init__(**kwargs) self.convert_inputs = convert_inputs self.num_classes = num_classes @@ -141,6 +145,18 @@ def _validate_impl(self, forward_context): print(f"The Average Validation Metric Score for the Current Epoch is {dice_metric}") if self.logger is not None: - self.logger.log_validation(self._iteration, metric_val, loss_val, x, y, torch.softmax(masks, dim=1)) + self.logger.log_validation( + self._iteration, metric_val, loss_val, x, y, torch.softmax(masks, dim=1) + ) return metric_val + + +class SemanticMapsSamTrainer(SemanticSamTrainer): + def _compute_loss(self, y, masks): + target = y.to(self.device, non_blocking=True) + + # Compute loss for the predictions + net_loss = self.loss(target, masks) + + return net_loss diff --git a/scripts/for_benchmarking_ais/common.py b/scripts/for_benchmarking_ais/common.py new file mode 100644 index 000000000..975fbe040 --- /dev/null +++ b/scripts/for_benchmarking_ais/common.py @@ -0,0 +1,241 @@ +import os +import argparse +from glob import glob +from tqdm import tqdm + +import h5py +import numpy as np +import pandas as pd +import imageio.v3 as imageio + +import torch + +import torch_em +from torch_em.transform.raw import normalize +from torch_em.transform.raw import standardize +from torch_em.loss import DiceBasedDistanceLoss +from torch_em.util import segmentation, prediction +from torch_em.transform.label import PerObjectDistanceTransform +from torch_em.data.datasets.light_microscopy import get_livecell_loader, get_covid_if_loader + +import micro_sam.training as sam_training +from micro_sam.training.util import ConvertToSemanticSamInputs + +from elf.evaluation import mean_segmentation_accuracy + + +# +# DATALOADERS +# + + +def covid_if_raw_trafo(raw): + raw = normalize(raw) + raw = raw * 255 + return raw + + +def get_loaders(path, patch_shape, dataset, for_sam=False): + kwargs = { + "label_transform": PerObjectDistanceTransform( + distances=True, + boundary_distances=True, + directed_distances=False, + foreground=True, + min_size=25, + ), + "label_dtype": torch.float32, + "num_workers": 16, + "patch_shape": patch_shape, + "shuffle": True, + } + + if for_sam: + kwargs["raw_transform"] = sam_training.identity if dataset == "livecell" else covid_if_raw_trafo + + if dataset == "livecell": + train_loader = get_livecell_loader(path=os.path.join(path, "livecell"), split="train", batch_size=2, **kwargs) + val_loader = get_livecell_loader(path=os.path.join(path, "livecell"), split="val", batch_size=1, **kwargs) + + elif dataset.startswith("covid_if"): + data_path = os.path.join(path, "covid_if") + + # Let's get the number of images to train on + n_images = int(dataset.split("-")[-1]) + assert n_images in [1, 2, 5, 10], f"Please choose number of images from 1, 2, 5, or 10; instead of {n_images}." + + train_volumes = (None, n_images) + val_volumes = (10, 13) + + # Let's get the number of samples extracted, to set the "n_samples" value + # This is done to avoid the time taken to save checkpoints over fewer training images. + _loader = get_covid_if_loader( + path=data_path, patch_shape=patch_shape, batch_size=1, sample_range=train_volumes + ) + + print( + f"Found {len(_loader)} samples for training.", + "Hence, we will use {0} samples for training.".format(50 if len(_loader) < 50 else len(_loader)) + ) + + # Finally, let's get the dataloaders + train_loader = get_covid_if_loader( + path=data_path, + batch_size=1, + sample_range=train_volumes, + n_samples=50 if len(_loader) < 50 else None, + **kwargs + ) + val_loader = get_covid_if_loader( + path=data_path, + batch_size=1, + sample_range=val_volumes, + **kwargs + ) + + else: + raise ValueError(f"'{dataset}' is not a valid dataset name.") + + return train_loader, val_loader + + +# +# TRAINING SCRIPTS +# + + +def run_training(name, path, save_root, iterations, model, device, dataset, for_sam=False): + # all the necessary stuff for training + patch_shape = (512, 512) + train_loader, val_loader = get_loaders(path=path, patch_shape=patch_shape, dataset=dataset, for_sam=for_sam) + loss = DiceBasedDistanceLoss(mask_distances_in_bg=True) + + trainer = torch_em.default_segmentation_trainer( + name=name, + model=model, + train_loader=train_loader, + val_loader=val_loader, + device=device, + learning_rate=1e-5, + loss=loss, + metric=loss, + log_image_interval=50, + save_root=save_root, + compile_model=False, + mixed_precision=True, + scheduler_kwargs={"mode": "min", "factor": 0.9, "patience": 5, "verbose": True}, + ) + + trainer.fit(int(iterations)) + + +# +# INFERENCE SCRIPTS +# + +def run_inference( + path, checkpoint_path, model, device, result_path, dataset, for_sam=False, with_semantic_sam=False, +): + model.load_state_dict(torch.load(checkpoint_path, map_location="cpu")["model_state"]) + model.to(device) + model.eval() + + if dataset == "livecell": + # the splits are provided with the livecell dataset to reproduce the results: + # run the inference on the entire dataset as it is. + test_image_dir = os.path.join(path, "livecell", "images", "livecell_test_images") + all_test_labels = glob(os.path.join(path, "livecell", "annotations", "livecell_test_images", "*", "*")) + + elif dataset.startswith("covid_if"): + # we create our own splits for this dataset. + # - the first 10 images are dedicated for training. + # - the next 3 images are dedicated for validation. + # - the remaining images are used for testing + all_test_labels = glob(os.path.join(path, "covid_if", "*.h5"))[13:] + + else: + raise ValueError(f"'{dataset}' is not a valid dataset name.") + + def prediction_fn(net, inp): + convert_inputs = ConvertToSemanticSamInputs() + batched_inputs = convert_inputs(inp, torch.zeros_like(inp)) + image_embeddings, batched_inputs = net.image_embeddings_oft(batched_inputs) + batched_outputs = net(batched_inputs, image_embeddings, multimask_output=True) + masks = torch.stack([output["masks"] for output in batched_outputs]).squeeze() + masks = masks[None] + return masks + + msa_list, sa50_list, sa75_list = [], [], [] + for label_path in tqdm(all_test_labels): + image_id = os.path.split(label_path)[-1] + + if dataset == "livecell": + image = imageio.imread(os.path.join(test_image_dir, image_id)) + labels = imageio.imread(label_path) + else: + with h5py.File(label_path) as f: + image = f["raw/serum_IgG/s0"][:] + labels = f["labels/cells/s0"][:] + + if for_sam: + image = image.astype("float32") # functional interpolate cannot work with uint. + per_tile_pp = covid_if_raw_trafo if dataset.startswith("covid_if") else None + else: + per_tile_pp = standardize + + predictions = prediction.predict_with_halo( + input_=image, + model=model, + gpu_ids=[device], + block_shape=(384, 384), + halo=(64, 64), + preprocess=per_tile_pp, + disable_tqdm=True, + output=np.zeros((3, *image.shape)) if with_semantic_sam else None, + prediction_function=prediction_fn if with_semantic_sam else None, + ) + predictions = predictions.squeeze() + + fg, cdist, bdist = predictions + instances = segmentation.watershed_from_center_and_boundary_distances( + cdist, bdist, fg, min_size=50, + center_distance_threshold=0.5, + boundary_distance_threshold=0.6, + distance_smoothing=1.0 + ) + + msa, sa_acc = mean_segmentation_accuracy(instances, labels, return_accuracies=True) + msa_list.append(msa) + sa50_list.append(sa_acc[0]) + sa75_list.append(sa_acc[5]) + + res = { + "LIVECell" if dataset == "livecell" else "Covid IF": "Metrics", + "mSA": np.mean(msa_list), + "SA50": np.mean(sa50_list), + "SA75": np.mean(sa75_list) + } + + os.makedirs(result_path, exist_ok=True) + res_path = os.path.join(result_path, "results.csv") + df = pd.DataFrame.from_dict([res]) + df.to_csv(res_path) + print(df) + print(f"The result is saved at {res_path}") + + +# +# MISCELLANOUS +# + + +def get_default_arguments(): + parser = argparse.ArgumentParser() + parser.add_argument("-d", "--dataset", type=str, required=True) + parser.add_argument("-i", "--input_path", type=str, default="/scratch/projects/nim00007/sam/data") + parser.add_argument("-s", "--save_root", type=str, default=None) + parser.add_argument("-p", "--phase", type=str, default=None, choices=["train", "predict"]) + parser.add_argument("--iterations", type=str, default=1e5) + parser.add_argument("--sam", action="store_true") + args = parser.parse_args() + return args diff --git a/scripts/for_benchmarking_ais/submit_scripts.py b/scripts/for_benchmarking_ais/submit_scripts.py new file mode 100644 index 000000000..5a4753106 --- /dev/null +++ b/scripts/for_benchmarking_ais/submit_scripts.py @@ -0,0 +1,92 @@ +import os +import shutil +import itertools +import subprocess +from datetime import datetime + + +def _write_batch_script(script_path, dataset_name, exp_script, save_root, phase, with_sam): + job_name = exp_script.split("_")[-1] + ("-sam-" if with_sam else "-") + dataset_name + + batch_script = f"""#!/bin/bash +#SBATCH -t 2-00:00:00 +#SBATCH --mem 64G +#SBATCH -c 16 +#SBATCH --nodes=1 +#SBATCH --ntasks=1 +#SBATCH -p grete-h100:shared +#SBATCH -G H100:1 +#SBATCH -A gzz0001 +#SBATCH --job-name={job_name} + +source activate sam \n""" + + # python script + script = f"python {exp_script}.py " + + # all other parameters + script += f"-d {dataset_name} -s {save_root} -p {phase} " + + # whether the model is trained using SAM pretrained weights + if with_sam: + script += "--sam " + + # let's combine both the scripts + batch_script += script + + output_path = script_path[:-3] + f"_{job_name}.sh" + with open(output_path, "w") as f: + f.write(batch_script) + + cmd = ["sbatch", output_path] + subprocess.run(cmd) + + +def _get_batch_script(tmp_folder): + tmp_folder = os.path.expanduser(tmp_folder) + os.makedirs(tmp_folder, exist_ok=True) + + script_name = "ais_benchmarking" + + dt = datetime.now().strftime("%Y_%m_%d_%H_%M_%S_%f") + tmp_name = script_name + dt + batch_script = os.path.join(tmp_folder, f"{tmp_name}.sh") + + return batch_script + + +def _submit_to_slurm(tmp_folder): + save_root = "/scratch/share/cidas/cca/models/micro-sam/ais_benchmarking/" + phase = "predict" # this can be updated to "train" / "predict" to run the respective scripts. + + scripts = ["train_unet", "train_unetr", "train_semanticsam"] + datasets = ["livecell", "covid_if-1", "covid_if-2", "covid_if-5", "covid_if-10"] + sam_combinations = [True, False] + + for (exp_script, dataset_name, with_sam) in itertools.product(scripts, datasets, sam_combinations): + if exp_script.endswith("_unet") and with_sam: + continue + + _write_batch_script( + script_path=_get_batch_script(tmp_folder), + dataset_name=dataset_name, + exp_script=exp_script, + save_root=save_root, + phase=phase, + with_sam=with_sam, + ) + + +def main(): + tmp_folder = "./gpu_jobs" + + try: + shutil.rmtree(tmp_folder) + except FileNotFoundError: + pass + + _submit_to_slurm(tmp_folder) + + +if __name__ == "__main__": + main() diff --git a/scripts/for_benchmarking_ais/train_semanticsam.py b/scripts/for_benchmarking_ais/train_semanticsam.py new file mode 100644 index 000000000..d2927e780 --- /dev/null +++ b/scripts/for_benchmarking_ais/train_semanticsam.py @@ -0,0 +1,109 @@ +import os + +import torch + +from torch_em.loss import DiceBasedDistanceLoss + +import micro_sam.training as sam_training +from micro_sam.training.trainable_sam import TrainableSAM +from micro_sam.training.util import ConvertToSemanticSamInputs + +from segment_anything import sam_model_registry + +from common import get_default_arguments, get_loaders, run_inference + + +def run_semantic_training(path, save_root, iterations, model, device, for_sam, num_classes, dataset): + # all the stuff we need for training + optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5) + scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( + optimizer, mode="min", factor=0.9, verbose=True, patience=10 if dataset.startswith("covid_if") else 5, + ) + + patch_shape = (512, 512) + train_loader, val_loader = get_loaders(path=path, patch_shape=patch_shape, dataset=dataset, for_sam=True) + + # this class creates all the training data for a batch (inputs, prompts and labels) + convert_inputs = ConvertToSemanticSamInputs() + + checkpoint_name = f"{dataset}_semanticsam" + ("-sam" if for_sam else "-scratch") + + # the trainer which performs the semantic segmentation training and validation (implemented using "torch_em") + trainer = sam_training.SemanticMapsSamTrainer( + name=checkpoint_name, + save_root=save_root, + train_loader=train_loader, + val_loader=val_loader, + model=model, + optimizer=optimizer, + device=device, + lr_scheduler=scheduler, + log_image_interval=50, + mixed_precision=True, + convert_inputs=convert_inputs, + num_classes=num_classes, + compile_model=False, + loss=DiceBasedDistanceLoss(mask_distances_in_bg=True), + metric=DiceBasedDistanceLoss(mask_distances_in_bg=True), + ) + trainer.fit(int(iterations)) + + +def main(args): + # training settings: + device = "cuda" if torch.cuda.is_available() else "cpu" + for_sam = args.sam + dataset = args.dataset + model_type = "vit_l" + num_classes = 3 + checkpoint_path = None + + if for_sam: + # This model is always initializes with pretrained SAM weights. + model = sam_training.get_trainable_sam_model( + model_type=model_type, + device=device, + checkpoint_path=checkpoint_path, + flexible_load_checkpoint=True, + num_multimask_outputs=num_classes, + ) + else: + # This model is initialized without the pretrained SAM weights. + sam = sam_model_registry[model_type]() + model = TrainableSAM(sam) + + model.to(device) + + if args.phase == "train": + run_semantic_training( + path=args.input_path, + save_root=args.save_root, + iterations=args.iterations, + model=model, + device=device, + for_sam=for_sam, + num_classes=num_classes, + dataset=dataset, + ) + + if args.phase == "predict": + checkpoint_path = os.path.join( + "./" if args.save_root is None else args.save_root, + "checkpoints", f"{dataset}_semanticsam-sam" if for_sam else f"{dataset}_semanticsam-scratch", "best.pt" + ) + result_path = f"results/{dataset}-semanticsam" + run_inference( + path=args.input_path, + checkpoint_path=checkpoint_path, + model=model, + device=device, + result_path=result_path, + for_sam=True, + with_semantic_sam=True, + dataset=dataset, + ) + + +if __name__ == "__main__": + args = get_default_arguments() + main(args) diff --git a/scripts/for_benchmarking_ais/train_unet.py b/scripts/for_benchmarking_ais/train_unet.py new file mode 100644 index 000000000..163189562 --- /dev/null +++ b/scripts/for_benchmarking_ais/train_unet.py @@ -0,0 +1,51 @@ +import os + +from common import get_default_arguments, run_inference, run_training + +import torch + +from torch_em.model import UNet2d +from torch_em.model.unetr import SingleDeconv2DBlock + + +def main(args): + dataset = args.dataset + device = "cuda" if torch.cuda.is_available() else "cpu" + model = UNet2d( + in_channels=1, + out_channels=3, + initial_features=64, + final_activation="Sigmoid", + sampler_impl=SingleDeconv2DBlock, + ) + model.to(device) + + if args.phase == "train": + run_training( + name=f"{dataset}-unet", + path=args.input_path, + save_root=args.save_root, + iterations=args.iterations, + model=model, + device=device, + dataset=dataset, + ) + + if args.phase == "predict": + checkpoint_path = os.path.join( + "./" if args.save_root is None else args.save_root, "checkpoints", f"{dataset}-unet", "best.pt" + ) + result_path = f"results/{dataset}_unet/" + run_inference( + path=args.input_path, + checkpoint_path=checkpoint_path, + model=model, + device=device, + result_path=result_path, + dataset=dataset, + ) + + +if __name__ == "__main__": + args = get_default_arguments() + main(args) diff --git a/scripts/for_benchmarking_ais/train_unetr.py b/scripts/for_benchmarking_ais/train_unetr.py new file mode 100644 index 000000000..e1e37310c --- /dev/null +++ b/scripts/for_benchmarking_ais/train_unetr.py @@ -0,0 +1,60 @@ +import os + +from common import get_default_arguments, run_training, run_inference + +import torch + +from torch_em.model import UNETR + + +SAM_PRETRAINED = "/scratch-grete/share/cidas/cca/models/sam/sam_vit_l_0b3195.pth" + + +def main(args): + dataset = args.dataset + for_sam = args.sam + device = "cuda" if torch.cuda.is_available() else "cpu" + checkpoint_path = SAM_PRETRAINED if for_sam and args.phase == "train" else None + + model = UNETR( + encoder="vit_l", + out_channels=3, + final_activation="Sigmoid", + use_skip_connection=False, + use_sam_stats=for_sam, + encoder_checkpoint=checkpoint_path, + ) + model.to(device) + + if args.phase == "train": + run_training( + name=f"{dataset}-unetr-sam" if for_sam else f"{dataset}-unetr", + path=args.input_path, + save_root=args.save_root, + iterations=args.iterations, + model=model, + device=device, + for_sam=for_sam, + dataset=dataset, + ) + + if args.phase == "predict": + ckpt_path = os.path.join( + "./" if args.save_root is None else args.save_root, + "checkpoints", f"{dataset}-unetr-sam" if for_sam else f"{dataset}-unetr", "best.pt" + ) + result_path = "results/" + f"{dataset}-unetr-sam" if for_sam else f"{dataset}-unetr" + run_inference( + path=args.input_path, + checkpoint_path=ckpt_path, + model=model, + device=device, + result_path=result_path, + for_sam=for_sam, + dataset=dataset, + ) + + +if __name__ == "__main__": + args = get_default_arguments() + main(args) From 8add576465b21048522b00e5f3680a85d52d04cb Mon Sep 17 00:00:00 2001 From: Anwai Archit <52396323+anwai98@users.noreply.github.com> Date: Wed, 24 Jul 2024 09:29:34 +0200 Subject: [PATCH 23/53] Update docstrings to document recent functionalities (#664) Update docstrings --- finetuning/livecell_finetuning.py | 6 ++-- micro_sam/_vendored.py | 2 ++ micro_sam/bioimageio/model_export.py | 12 +++---- micro_sam/evaluation/evaluation.py | 3 +- micro_sam/evaluation/experiments.py | 1 + micro_sam/evaluation/instance_segmentation.py | 10 +++--- micro_sam/evaluation/livecell.py | 3 +- micro_sam/evaluation/model_comparison.py | 13 ++++---- micro_sam/inference.py | 5 +-- micro_sam/instance_segmentation.py | 18 ++++++----- micro_sam/models/build_sam.py | 1 + micro_sam/models/peft_sam.py | 4 +++ micro_sam/models/sam_3d_wrapper.py | 5 +-- micro_sam/models/simple_sam_3d_wrapper.py | 1 - micro_sam/multi_dimensional_segmentation.py | 9 ++++-- micro_sam/precompute_state.py | 5 +-- micro_sam/prompt_based_segmentation.py | 9 ++++-- micro_sam/sam_annotator/_annotator.py | 4 +-- micro_sam/sam_annotator/_state.py | 7 ++-- micro_sam/sam_annotator/_tooltips.py | 32 +++++++++---------- micro_sam/sam_annotator/_widgets.py | 11 ++++--- micro_sam/sam_annotator/annotator_2d.py | 1 + micro_sam/sam_annotator/annotator_3d.py | 1 + micro_sam/sam_annotator/annotator_tracking.py | 1 + .../sam_annotator/image_series_annotator.py | 4 +-- micro_sam/sam_annotator/training_ui.py | 10 +++--- micro_sam/sam_annotator/util.py | 6 ++-- micro_sam/sample_data.py | 6 ++-- micro_sam/training/joint_sam_trainer.py | 16 ++++++++++ micro_sam/training/sam_trainer.py | 7 ++-- micro_sam/training/semantic_sam_trainer.py | 25 ++++++++++++++- micro_sam/training/simple_sam_trainer.py | 21 ++++++++++-- micro_sam/training/trainable_sam.py | 10 ++---- micro_sam/training/training.py | 4 ++- micro_sam/training/util.py | 3 +- micro_sam/util.py | 20 ++++++------ micro_sam/visualization.py | 5 +-- 37 files changed, 190 insertions(+), 111 deletions(-) diff --git a/finetuning/livecell_finetuning.py b/finetuning/livecell_finetuning.py index 1cd942fdb..6b63a5a06 100644 --- a/finetuning/livecell_finetuning.py +++ b/finetuning/livecell_finetuning.py @@ -40,7 +40,7 @@ def get_dataloaders(patch_shape, data_path, cell_type=None): def finetune_livecell(args): - """Example code for finetuning SAM on LiveCELL""" + """Example code for finetuning SAM on LIVECell""" # override this (below) if you have some more complex set-up and need to specify the exact gpu device = "cuda" if torch.cuda.is_available() else "cpu" @@ -84,10 +84,10 @@ def finetune_livecell(args): def main(): - parser = argparse.ArgumentParser(description="Finetune Segment Anything for the LiveCELL dataset.") + parser = argparse.ArgumentParser(description="Finetune Segment Anything for the LIVECell dataset.") parser.add_argument( "--input_path", "-i", default="/scratch/projects/nim00007/sam/data/livecell/", - help="The filepath to the LiveCELL data. If the data does not exist yet it will be downloaded." + help="The filepath to the LIVECell data. If the data does not exist yet it will be downloaded." ) parser.add_argument( "--model_type", "-m", default="vit_b", diff --git a/micro_sam/_vendored.py b/micro_sam/_vendored.py index 14446e016..976f8b4a7 100644 --- a/micro_sam/_vendored.py +++ b/micro_sam/_vendored.py @@ -6,9 +6,11 @@ The license type of the thrid party software project must be compatible with the software license the micro-sam project is distributed under. """ + from typing import Any, Dict, List import numpy as np + import torch try: diff --git a/micro_sam/bioimageio/model_export.py b/micro_sam/bioimageio/model_export.py index 3b270c58f..dfee96122 100644 --- a/micro_sam/bioimageio/model_export.py +++ b/micro_sam/bioimageio/model_export.py @@ -1,26 +1,26 @@ import os import tempfile - from pathlib import Path from typing import Optional, Union -import bioimageio.core -import bioimageio.spec.model.v0_5 as spec -import matplotlib.pyplot as plt +import xarray import numpy as np +import matplotlib.pyplot as plt + import torch -import xarray +import bioimageio.core +import bioimageio.spec.model.v0_5 as spec from bioimageio.spec import save_bioimageio_package from bioimageio.core.digest_spec import create_sample_for_model - from .. import util from ..prompt_generators import PointAndBoxPromptGenerator from ..evaluation.model_comparison import _enhance_image, _overlay_outline, _overlay_box from ..prompt_based_segmentation import _compute_logits_from_mask from .predictor_adaptor import PredictorAdaptor + DEFAULTS = { "authors": [ spec.Author(name="Anwai Archit", affiliation="University Goettingen", github_user="anwai98"), diff --git a/micro_sam/evaluation/evaluation.py b/micro_sam/evaluation/evaluation.py index 4fc76146b..a52a11266 100644 --- a/micro_sam/evaluation/evaluation.py +++ b/micro_sam/evaluation/evaluation.py @@ -11,8 +11,8 @@ import numpy as np import pandas as pd import imageio.v3 as imageio - from skimage.measure import label + from elf.evaluation import mean_segmentation_accuracy @@ -88,6 +88,7 @@ def run_evaluation_for_iterative_prompting( prediction_root: The folder with the iterative prompt-based instance segmentations to evaluate. experiment_folder: The folder where all the experiment results are stored. start_with_box_prompt: Whether to evaluate on experiments with iterative prompting starting with box. + overwrite_results: Whether to overwrite the results to update them with the new evaluation run. Returns: A DataFrame that contains the evaluation results. diff --git a/micro_sam/evaluation/experiments.py b/micro_sam/evaluation/experiments.py index 4646af527..5b5b9c76f 100644 --- a/micro_sam/evaluation/experiments.py +++ b/micro_sam/evaluation/experiments.py @@ -3,6 +3,7 @@ from typing import Dict, List, Optional + # TODO fully define the dict type ExperimentSetting = Dict ExperimentSettings = List[ExperimentSetting] diff --git a/micro_sam/evaluation/instance_segmentation.py b/micro_sam/evaluation/instance_segmentation.py index 9d6331e63..5e6571904 100644 --- a/micro_sam/evaluation/instance_segmentation.py +++ b/micro_sam/evaluation/instance_segmentation.py @@ -3,20 +3,20 @@ import os from glob import glob -from itertools import product +from tqdm import tqdm from pathlib import Path +from itertools import product from typing import Any, Dict, List, Optional, Tuple, Union -import imageio.v3 as imageio import numpy as np import pandas as pd +import imageio.v3 as imageio -from elf.evaluation import mean_segmentation_accuracy from elf.io import open_file -from tqdm import tqdm +from elf.evaluation import mean_segmentation_accuracy -from ..instance_segmentation import AMGBase, InstanceSegmentationWithDecoder, mask_data_to_segmentation from .. import util +from ..instance_segmentation import AMGBase, InstanceSegmentationWithDecoder, mask_data_to_segmentation def _get_range_of_search_values(input_vals, step): diff --git a/micro_sam/evaluation/livecell.py b/micro_sam/evaluation/livecell.py index f0699ab82..c9d75f510 100644 --- a/micro_sam/evaluation/livecell.py +++ b/micro_sam/evaluation/livecell.py @@ -1,6 +1,7 @@ """Inference and evaluation for the [LIVECell dataset](https://www.nature.com/articles/s41592-021-01249-6) and the different cell lines contained in it. """ + import os import json import argparse @@ -422,7 +423,7 @@ def run_livecell_inference() -> None: def run_livecell_evaluation() -> None: - """Run LiveCELL evaluation with command line tool.""" + """Run LIVECell evaluation with command line tool.""" parser = argparse.ArgumentParser() parser.add_argument( "-i", "--input", required=True, help="Provide the data directory for LIVECell Dataset" diff --git a/micro_sam/evaluation/model_comparison.py b/micro_sam/evaluation/model_comparison.py index 4b130aec4..6e37528d5 100644 --- a/micro_sam/evaluation/model_comparison.py +++ b/micro_sam/evaluation/model_comparison.py @@ -2,23 +2,22 @@ """ import os -from functools import partial from glob import glob +from tqdm import tqdm from pathlib import Path +from functools import partial +from typing import Optional, Union import h5py -import matplotlib.pyplot as plt import numpy as np import pandas as pd -import torch - +import matplotlib.pyplot as plt import skimage.draw as draw -from scipy.ndimage import binary_dilation from skimage import exposure +from scipy.ndimage import binary_dilation from skimage.segmentation import relabel_sequential, find_boundaries -from tqdm import tqdm -from typing import Optional, Union +import torch from .. import util from ..prompt_generators import PointAndBoxPromptGenerator diff --git a/micro_sam/inference.py b/micro_sam/inference.py index 6d67b38e2..8725dea95 100644 --- a/micro_sam/inference.py +++ b/micro_sam/inference.py @@ -1,11 +1,12 @@ import os from typing import Optional, Union -import torch import numpy as np -import segment_anything.utils.amg as amg_utils +import torch + from segment_anything import SamPredictor +import segment_anything.utils.amg as amg_utils from segment_anything.utils.transforms import ResizeLongestSide from . import util diff --git a/micro_sam/instance_segmentation.py b/micro_sam/instance_segmentation.py index 23d666b97..d86c534ef 100644 --- a/micro_sam/instance_segmentation.py +++ b/micro_sam/instance_segmentation.py @@ -6,24 +6,25 @@ import os from abc import ABC -from collections import OrderedDict from copy import deepcopy +from collections import OrderedDict from typing import Any, Dict, List, Optional, Tuple, Union -import numpy as np -import torch -import segment_anything.utils.amg as amg_utils import vigra - -from nifty.tools import blocking -from segment_anything.predictor import SamPredictor - +import numpy as np from skimage.measure import regionprops + +import torch from torchvision.ops.boxes import batched_nms, box_area from torch_em.model import UNETR from torch_em.util.segmentation import watershed_from_center_and_boundary_distances +from nifty.tools import blocking + +import segment_anything.utils.amg as amg_utils +from segment_anything.predictor import SamPredictor + from . import util from ._vendored import batched_mask_to_box, mask_to_rle_pytorch @@ -56,6 +57,7 @@ def mask_data_to_segmentation( object in the output will be mapped to zero (the background value). min_object_size: The minimal size of an object in pixels. max_object_size: The maximal size of an object in pixels. + Returns: The instance segmentation. """ diff --git a/micro_sam/models/build_sam.py b/micro_sam/models/build_sam.py index 8fa6bcc6a..901c6c383 100644 --- a/micro_sam/models/build_sam.py +++ b/micro_sam/models/build_sam.py @@ -3,6 +3,7 @@ # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. +# https://github.com/facebookresearch/segment-anything/ # # NOTE: This code has been adapted from Segment Anything. diff --git a/micro_sam/models/peft_sam.py b/micro_sam/models/peft_sam.py index d2eaa9876..2bdeed702 100644 --- a/micro_sam/models/peft_sam.py +++ b/micro_sam/models/peft_sam.py @@ -18,6 +18,10 @@ class LoRASurgery(nn.Module): qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) q, k, v = qkv.unbind(0) ``` + + Args: + rank: The rank of the decomposition matrices for updating weights in each attention layer. + block: The chosen attention blocks for implementing lora. """ def __init__( self, diff --git a/micro_sam/models/sam_3d_wrapper.py b/micro_sam/models/sam_3d_wrapper.py index 4a7645d04..3e0b7573e 100644 --- a/micro_sam/models/sam_3d_wrapper.py +++ b/micro_sam/models/sam_3d_wrapper.py @@ -3,8 +3,8 @@ import torch import torch.nn as nn -from segment_anything.modeling.image_encoder import window_partition, window_unpartition from segment_anything.modeling import Sam +from segment_anything.modeling.image_encoder import window_partition, window_unpartition from ..util import get_sam_model @@ -42,6 +42,7 @@ def __init__(self, sam_model: Sam, freeze_encoder: bool): Args: sam_model: The Sam model to be wrapped. + freeze_encoder: Whether to freeze the image encoder. """ super().__init__() sam_model.image_encoder = ImageEncoderViT3DWrapper( @@ -64,7 +65,7 @@ def forward( Unlike original SAM this model only supports automatic segmentation and does not support prompts. Args: - batched_input: A list over input images, each a dictionary with the following keys.L + batched_input: A list over input images, each a dictionary with the following keys. 'image': The image as a torch tensor in 3xDxHxW format. Already transformed for the input to the model. 'original_size': The original size of the image (HxW) before transformation. multimask_output: Wheterh to predict with the multi- or single-mask head of the maks decoder. diff --git a/micro_sam/models/simple_sam_3d_wrapper.py b/micro_sam/models/simple_sam_3d_wrapper.py index cf4ddbccb..6f67caa47 100644 --- a/micro_sam/models/simple_sam_3d_wrapper.py +++ b/micro_sam/models/simple_sam_3d_wrapper.py @@ -16,7 +16,6 @@ def get_simple_sam_3d_model( model_type="vit_b", checkpoint_path=None, ): - _, sam = get_sam_model( model_type=model_type, device=device, diff --git a/micro_sam/multi_dimensional_segmentation.py b/micro_sam/multi_dimensional_segmentation.py index 2c65d4f10..c8747ed88 100644 --- a/micro_sam/multi_dimensional_segmentation.py +++ b/micro_sam/multi_dimensional_segmentation.py @@ -5,23 +5,26 @@ from typing import Optional, Union, Tuple import numpy as np + import nifty -import elf.tracking.tracking_utils as track_utils + import elf.segmentation as seg_utils +import elf.tracking.tracking_utils as track_utils -from segment_anything.predictor import SamPredictor from scipy.ndimage import binary_closing from skimage.measure import label, regionprops from skimage.segmentation import relabel_sequential +from segment_anything.predictor import SamPredictor + try: from napari.utils import progress as tqdm except ImportError: from tqdm import tqdm from . import util -from .instance_segmentation import AMGBase, mask_data_to_segmentation from .prompt_based_segmentation import segment_from_mask +from .instance_segmentation import AMGBase, mask_data_to_segmentation PROJECTION_MODES = ("box", "mask", "points", "points_and_mask", "single_point") diff --git a/micro_sam/precompute_state.py b/micro_sam/precompute_state.py index d07ea1bcb..e4a970b75 100644 --- a/micro_sam/precompute_state.py +++ b/micro_sam/precompute_state.py @@ -3,16 +3,17 @@ import os import pickle - -from functools import partial from glob import glob from pathlib import Path +from functools import partial from typing import Optional, Tuple, Union, List import h5py import numpy as np + import torch import torch.nn as nn + from segment_anything.predictor import SamPredictor try: diff --git a/micro_sam/prompt_based_segmentation.py b/micro_sam/prompt_based_segmentation.py index 9de5954db..e2bb1026a 100644 --- a/micro_sam/prompt_based_segmentation.py +++ b/micro_sam/prompt_based_segmentation.py @@ -6,15 +6,18 @@ from typing import Optional, Tuple import numpy as np -import torch -from nifty.tools import blocking -from skimage.feature import peak_local_max from skimage.filters import gaussian +from skimage.feature import peak_local_max from skimage.segmentation import find_boundaries from scipy.ndimage import distance_transform_edt +import torch + +from nifty.tools import blocking + from segment_anything.predictor import SamPredictor from segment_anything.utils.transforms import ResizeLongestSide + from . import util diff --git a/micro_sam/sam_annotator/_annotator.py b/micro_sam/sam_annotator/_annotator.py index fa8da43aa..974aeccb5 100644 --- a/micro_sam/sam_annotator/_annotator.py +++ b/micro_sam/sam_annotator/_annotator.py @@ -1,11 +1,11 @@ import napari import numpy as np -from magicgui.widgets import Widget, Container, FunctionGui from qtpy import QtWidgets +from magicgui.widgets import Widget, Container, FunctionGui -from . import _widgets as widgets from . import util as vutil +from . import _widgets as widgets from ._state import AnnotatorState diff --git a/micro_sam/sam_annotator/_state.py b/micro_sam/sam_annotator/_state.py index 639642ea9..b57c1f80c 100644 --- a/micro_sam/sam_annotator/_state.py +++ b/micro_sam/sam_annotator/_state.py @@ -3,18 +3,19 @@ https://itnext.io/deciding-the-best-singleton-approach-in-python-65c61e90cdc4 """ -from dataclasses import dataclass, field from functools import partial +from dataclasses import dataclass, field from typing import Dict, List, Optional, Tuple +import zarr import numpy as np +from qtpy.QtWidgets import QWidget + import torch.nn as nn -import zarr import micro_sam.util as util from micro_sam.instance_segmentation import AMGBase, get_decoder from micro_sam.precompute_state import cache_amg_state, cache_is_state -from qtpy.QtWidgets import QWidget from segment_anything import SamPredictor diff --git a/micro_sam/sam_annotator/_tooltips.py b/micro_sam/sam_annotator/_tooltips.py index 068dd44d7..8ddda0ce0 100644 --- a/micro_sam/sam_annotator/_tooltips.py +++ b/micro_sam/sam_annotator/_tooltips.py @@ -6,25 +6,25 @@ "custom_weights": "Select custom model weights. For example for a model you have finetuned", "device": "Select the computational device to use for processing.", "embeddings_save_path": "Select path to save or load the computed image embeddings.", - "halo": "Enter overlap values for computing tiled embeddings. Enter only x-value for quadratic size.\n Only active when tiling is used.", + "halo": "Enter overlap values for computing tiled embeddings. Enter only x-value for quadratic size.\n Only active when tiling is used.", # noqa "image": "Select the napari image layer.", "model": "Select the segment anything model.", - "prefer_decoder": "Choose if the segmentation decoder is used for automatic segmentation. Only if it is available for the selected model..", + "prefer_decoder": "Choose if the segmentation decoder is used for automatic segmentation. Only if it is available for the selected model..", # noqa "run_button": "Compute embeddings or load embeddings if embedding_save_path is specified.", - "tiling": "Enter tile size for computing tiled embeddings. Enter only x-value for quadratic size or both for non-quadratic.", + "tiling": "Enter tile size for computing tiled embeddings. Enter only x-value for quadratic size or both for non-quadratic.", # noqa }, "segmentnd": { - "box_extension": "Enter factor by which box size is increased when projecting to adjacent slices. Larger factors help if object sizes change between slices.", + "box_extension": "Enter factor by which box size is increased when projecting to adjacent slices. Larger factors help if object sizes change between slices.", # noqa "iou_threshold": "Enter the minimal overlap between objects in adjacent slices to continue segmentation.", - "motion_smoothing": "Enter the motion smoothing factor. It is used to follow objects which have a directed movement, higher values help for objects that are moving fast.", - "projection_dropdown": "Choose the projection mode. It determines which prompts are derived from the masks projected to adjacent frames to rerun SAM.", + "motion_smoothing": "Enter the motion smoothing factor. It is used to follow objects which have a directed movement, higher values help for objects that are moving fast.", # noqa + "projection_dropdown": "Choose the projection mode. It determines which prompts are derived from the masks projected to adjacent frames to rerun SAM.", # noqa }, "autosegment": { # General settings. "apply_to_volume": "Choose if automatic segmentation is run for the full volume or only the current slice.", - "gap_closing": "Enter value for closing gaps across slices for volumetric segmentation. Higher values will reduce artifacts due to missing slices in objects but may lead to wrongly merging objects.", - "min_extent": "Enter the minimal number of slices for objects in volumetric segmentation. To filter out small segmentation artifacts.", - "min_object_size": "Enter the minimal object size in pixels. This refers to the size per slice for volumetric segmentation.", + "gap_closing": "Enter value for closing gaps across slices for volumetric segmentation. Higher values will reduce artifacts due to missing slices in objects but may lead to wrongly merging objects.", # noqa + "min_extent": "Enter the minimal number of slices for objects in volumetric segmentation. To filter out small segmentation artifacts.", # noqa + "min_object_size": "Enter the minimal object size in pixels. This refers to the size per slice for volumetric segmentation.", # noqa "run_button": "Run automatic segmentation.", "with_background": "Choose if your image has a large background area.", # Settings for AIS. @@ -36,28 +36,28 @@ "stability_score_thresh": "Enter the threshold for filtering objects based on the stability score.", }, "prompt_menu": { - "labels": "Choose positive prompts to inlcude regions or negative ones to exclude regions. Toggle between the settings by pressing [t].", + "labels": "Choose positive prompts to inlcude regions or negative ones to exclude regions. Toggle between the settings by pressing [t].", # noqa }, "annotator_tracking": { "track_id": "Select the id of the track you are currently annotating.", - "track_state": "Select the state of the current annotation. Choose 'division' if the object is dviding in the current frame.", + "track_state": "Select the state of the current annotation. Choose 'division' if the object is dviding in the current frame.", # noqa }, "image_series_annotator": { "folder": "Select the folder with the images to annotate.", "output_folder": "Select the folder for saving the segmentation results.", - "pattern": "Select a pattern for selecting files. E.g. '*.tif' to only select tif files. By default all files in the input folder are selected.", + "pattern": "Select a pattern for selecting files. E.g. '*.tif' to only select tif files. By default all files in the input folder are selected.", # noqa "is_volumetric": "Choose if the data you annotate is volumetric.", }, "training": { "checkpoint": "Select a checkpoint (saved model) to resume training from.", "device": "Select the computational device to use for processing.", "initial_model": "Select the model name used as starting point for training.", - "label_key": "Define the key that holds to the segmentation labels. Use a pattern, e.g. \"*.tif\" select multiple files or an internal path for hdf5, zarr or similar formats.", - "label_path": "Specify the path to the segmentaiton labels for training. Can either point to a directory or single file.", - "label_path_val": "Specify the path to the segmentation labels for validation. Can either point to a directory or single file.", + "label_key": "Define the key that holds to the segmentation labels. Use a pattern, e.g. \"*.tif\" select multiple files or an internal path for hdf5, zarr or similar formats.", # noqa + "label_path": "Specify the path to the segmentaiton labels for training. Can either point to a directory or single file.", # noqa + "label_path_val": "Specify the path to the segmentation labels for validation. Can either point to a directory or single file.", # noqa "name": "Enter the name of the model that will be trained.", "patch": "Select the size of image patches used for training.", - "raw_key": "Define the key that holds to the image data. Use a pattern, e.g. \"*.tif\" select multiple files or an internal path for hdf5, zarr or similar formats.", + "raw_key": "Define the key that holds to the image data. Use a pattern, e.g. \"*.tif\" select multiple files or an internal path for hdf5, zarr or similar formats.", # noqa "raw_path": "Specify the path to the image data for training. Can either point to a directory or single file.", "raw_path_val": "Specify the path to the image data for training. Can either point to a directory or single file.", "segmentation_decoder": "Choose whether to train with additional segmentation decoder or not.", diff --git a/micro_sam/sam_annotator/_widgets.py b/micro_sam/sam_annotator/_widgets.py index 4822dc37a..8307998f1 100644 --- a/micro_sam/sam_annotator/_widgets.py +++ b/micro_sam/sam_annotator/_widgets.py @@ -1,19 +1,20 @@ """Implements the widgets used in the annotation plugins. """ -import json -import multiprocessing as mp import os import pickle from pathlib import Path from typing import Optional +import multiprocessing as mp -import elf.parallel import h5py -import napari -import numpy as np +import json import zarr import z5py +import napari +import numpy as np + +import elf.parallel from qtpy import QtWidgets from qtpy.QtCore import QObject, Signal diff --git a/micro_sam/sam_annotator/annotator_2d.py b/micro_sam/sam_annotator/annotator_2d.py index 6fc01742a..fec4d5d82 100644 --- a/micro_sam/sam_annotator/annotator_2d.py +++ b/micro_sam/sam_annotator/annotator_2d.py @@ -2,6 +2,7 @@ import napari import numpy as np + import torch from . import _widgets as widgets diff --git a/micro_sam/sam_annotator/annotator_3d.py b/micro_sam/sam_annotator/annotator_3d.py index dfcf12a7e..026e222d5 100644 --- a/micro_sam/sam_annotator/annotator_3d.py +++ b/micro_sam/sam_annotator/annotator_3d.py @@ -2,6 +2,7 @@ import napari import numpy as np + import torch from ._annotator import _AnnotatorBase diff --git a/micro_sam/sam_annotator/annotator_tracking.py b/micro_sam/sam_annotator/annotator_tracking.py index 183678d5e..d82b0923f 100644 --- a/micro_sam/sam_annotator/annotator_tracking.py +++ b/micro_sam/sam_annotator/annotator_tracking.py @@ -2,6 +2,7 @@ import napari import numpy as np + import torch from magicgui.widgets import ComboBox, Container diff --git a/micro_sam/sam_annotator/image_series_annotator.py b/micro_sam/sam_annotator/image_series_annotator.py index 561abc550..f4c7ce716 100644 --- a/micro_sam/sam_annotator/image_series_annotator.py +++ b/micro_sam/sam_annotator/image_series_annotator.py @@ -1,14 +1,14 @@ import os - from glob import glob from pathlib import Path from typing import List, Optional, Union, Tuple import numpy as np import imageio.v3 as imageio -import napari + import torch +import napari from magicgui import magicgui from qtpy import QtWidgets diff --git a/micro_sam/sam_annotator/training_ui.py b/micro_sam/sam_annotator/training_ui.py index 0c725584d..f36be0c96 100644 --- a/micro_sam/sam_annotator/training_ui.py +++ b/micro_sam/sam_annotator/training_ui.py @@ -1,15 +1,17 @@ import os import warnings -import torch -import torch_em -from napari.qt.threading import thread_worker from qtpy import QtWidgets +from napari.qt.threading import thread_worker + +import torch from torch.utils.data import random_split +import torch_em + import micro_sam.util as util -import micro_sam.sam_annotator._widgets as widgets from ._tooltips import get_tooltip +import micro_sam.sam_annotator._widgets as widgets from micro_sam.training import default_sam_dataset, train_sam_for_configuration, CONFIGURATIONS diff --git a/micro_sam/sam_annotator/util.py b/micro_sam/sam_annotator/util.py index d3d7525ff..aae5a75ef 100644 --- a/micro_sam/sam_annotator/util.py +++ b/micro_sam/sam_annotator/util.py @@ -1,8 +1,7 @@ -import argparse import os import pickle import warnings - +import argparse from glob import glob from pathlib import Path from typing import List, Optional, Tuple @@ -10,9 +9,8 @@ import h5py import napari import numpy as np - -from scipy.ndimage import shift from skimage import draw +from scipy.ndimage import shift from .. import prompt_based_segmentation, util from .. import _model_settings as model_settings diff --git a/micro_sam/sample_data.py b/micro_sam/sample_data.py index 311d50080..8f636ce90 100644 --- a/micro_sam/sample_data.py +++ b/micro_sam/sample_data.py @@ -16,13 +16,13 @@ from pathlib import Path from typing import Union -import imageio.v3 as imageio -import numpy as np import pooch +import numpy as np +import imageio.v3 as imageio -from skimage.data import binary_blobs from skimage.measure import label from skimage.transform import resize +from skimage.data import binary_blobs from .util import get_cache_directory diff --git a/micro_sam/training/joint_sam_trainer.py b/micro_sam/training/joint_sam_trainer.py index 08ab8c393..db59408ed 100644 --- a/micro_sam/training/joint_sam_trainer.py +++ b/micro_sam/training/joint_sam_trainer.py @@ -13,6 +13,19 @@ class JointSamTrainer(SamTrainer): + """Trainer class for jointly training the Segment Anything model with an additional convolutional decoder. + + This class is inherited from `SamTrainer`. + Check out https://github.com/computational-cell-analytics/micro-sam/blob/master/micro_sam/training/sam_trainer.py + for details on its implementation. + + Args: + unetr: The UNet-style model with vision transformer as the image encoder. + Required to perform automatic instance segmentation. + instance_loss: The loss to compare the predictions (for instance segmentation) and the targets. + instance_metric: The metric to compare the predictions and the targets. + kwargs: The keyword arguments of the `SamTrainer` (and `DefaultTrainer`) class. + """ def __init__( self, unetr: torch.nn.Module, @@ -60,6 +73,9 @@ def load_checkpoint(self, checkpoint="best"): return save_dict def _instance_iteration(self, x, y, metric_for_val=False): + """Perform the segmentation of distance maps and + compute the loss (and metric) between the prediction and target. + """ outputs = self.unetr(x.to(self.device)) loss = self.instance_loss(outputs, y.to(self.device)) if metric_for_val: diff --git a/micro_sam/training/sam_trainer.py b/micro_sam/training/sam_trainer.py index 268cca7d5..020413e29 100644 --- a/micro_sam/training/sam_trainer.py +++ b/micro_sam/training/sam_trainer.py @@ -5,10 +5,11 @@ from typing import Optional import numpy as np -import torch -import torch_em +import torch from torchvision.utils import make_grid + +import torch_em from torch_em.trainer.logger_base import TorchEmLogger from ..prompt_generators import PromptGeneratorBase, IterativePromptGenerator @@ -32,7 +33,7 @@ class SamTrainer(torch_em.trainer.DefaultTrainer): prompt_generator: The iterative prompt generator which takes care of the iterative prompting logic for training mask_prob: The probability of using the mask inputs in the iterative prompting (per `n_sub_iteration`) mask_loss: The loss to compare the predicted masks and the targets. - **kwargs: The keyword arguments of the DefaultTrainer super class. + kwargs: The keyword arguments of the DefaultTrainer super class. """ def __init__( diff --git a/micro_sam/training/semantic_sam_trainer.py b/micro_sam/training/semantic_sam_trainer.py index cb136c30f..46a5f8fac 100644 --- a/micro_sam/training/semantic_sam_trainer.py +++ b/micro_sam/training/semantic_sam_trainer.py @@ -9,6 +9,14 @@ class CustomDiceLoss(nn.Module): + """Loss for computing dice over one-hot labels. + + Expects prediction and target with `num_classes` channels: the number of classes for semantic segmentation. + + Args: + num_classes: The number of classes for semantic segmentation (including background class). + softmax: Whether to use softmax over the predictions. + """ def __init__(self, num_classes: int, softmax: bool = True) -> None: super().__init__() self.num_classes = num_classes @@ -32,7 +40,18 @@ def __call__(self, pred, target): class SemanticSamTrainer(DefaultTrainer): - """ + """Trainer class for training the Segment Anything model for semantic segmentation. + + This class is derived from `torch_em.trainer.DefaultTrainer`. + Check out https://github.com/constantinpape/torch-em/blob/main/torch_em/trainer/default_trainer.py + for details on its usage and implementation. + + Args: + convert_inputs: The class that converts outputs of the dataloader to the expected input format of SAM. + The class `micro_sam.training.util.ConvertToSemanticSamInputs` can be used here. + num_classes: The number of classes for semantic segmentation (including the background class). + dice_weight: The weighing for the dice loss in the combined dice-cross entropy loss function. + kwargs: The keyword arguments of the DefaultTrainer super class. """ def __init__( self, @@ -62,6 +81,8 @@ def __init__( self._kwargs = kwargs def _compute_loss(self, y, masks): + """Compute the combined (weighted) dice loss and cross-entropy loss between the prediction and target. + """ target = y.to(self.device, non_blocking=True) # Compute dice loss for the predictions dice_loss = self.loss(masks, target) @@ -77,6 +98,8 @@ def _compute_loss(self, y, masks): return net_loss def _get_model_outputs(self, batched_inputs): + """Get the predictions from the model. + """ # Precompute the image embeddings if the model exposes it as functionality. if hasattr(self.model, "image_embeddings_oft"): image_embeddings, batched_inputs = self.model.image_embeddings_oft(batched_inputs) diff --git a/micro_sam/training/simple_sam_trainer.py b/micro_sam/training/simple_sam_trainer.py index 984e41fac..a0b06341a 100644 --- a/micro_sam/training/simple_sam_trainer.py +++ b/micro_sam/training/simple_sam_trainer.py @@ -5,6 +5,15 @@ class SimpleSamTrainer(SamTrainer): """Trainer class for creating a simple SAM trainer for limited prompt-based segmentation. + + This class is inherited from `SamTrainer`. + Check out https://github.com/computational-cell-analytics/micro-sam/blob/master/micro_sam/training/sam_trainer.py + for details on its implementation. + + Args: + use_points: Whether to use point prompts for interactive segmentation. + use_box: Whether to use box prompts for interactive segmentation. + kwargs: The keyword arguments of the `SamTrainer` (and `DefaultTrainer`) class. """ def __init__( self, @@ -28,20 +37,21 @@ def __init__( assert (self.use_points + self.use_box) != 0, "Please choose at least one of the prompt-based method." def _choose_one_positive_point(self): - "samples only a single positive point per object" + """Samples only a single positive point per object + """ n_pos, n_neg = 1, 0 multimask_output = True return n_pos, n_neg, None, multimask_output def _choose_box(self): - "samples only a single box per object" + """Samples only a single box per object + """ n_pos, n_neg = 0, 0 multimask_output = False get_boxes = True return n_pos, n_neg, get_boxes, multimask_output def _get_prompt_and_multimasking_choices(self, current_iteration): - if self.random_prompt_choice: # both "use_points" and "use_box" are True available_choices = [self._choose_one_positive_point(), self._choose_box()] return random.choice(available_choices) @@ -57,6 +67,11 @@ def _get_prompt_and_multimasking_choices_for_val(self, current_iteration): class MedSAMTrainer(SimpleSamTrainer): """Trainer class for replicating the trainer of MedSAM (https://arxiv.org/abs/2304.12306). + + This class is inherited from `SimpleSamTrainer`. + Check out + https://github.com/computational-cell-analytics/micro-sam/blob/master/micro_sam/training/simple_sam_trainer.py + for details on its implementation. """ def __init__(self, **kwargs): super().__init__( diff --git a/micro_sam/training/trainable_sam.py b/micro_sam/training/trainable_sam.py index 72a3ebe62..dbc206d30 100644 --- a/micro_sam/training/trainable_sam.py +++ b/micro_sam/training/trainable_sam.py @@ -15,10 +15,7 @@ class TrainableSAM(nn.Module): Args: sam: The SegmentAnything Model. """ - def __init__( - self, - sam: Sam, - ) -> None: + def __init__(self, sam: Sam) -> None: super().__init__() self.sam = sam self.transform = ResizeLongestSide(sam.image_encoder.img_size) @@ -62,10 +59,7 @@ def image_embeddings_oft(self, batched_inputs): # batched inputs follow the same syntax as the input to sam.forward def forward( - self, - batched_inputs: List[Dict[str, Any]], - image_embeddings: torch.Tensor, - multimask_output: bool = False, + self, batched_inputs: List[Dict[str, Any]], image_embeddings: torch.Tensor, multimask_output: bool = False, ) -> List[Dict[str, Any]]: """Forward pass. diff --git a/micro_sam/training/training.py b/micro_sam/training/training.py index 39314e7fc..4e25e595b 100644 --- a/micro_sam/training/training.py +++ b/micro_sam/training/training.py @@ -178,11 +178,13 @@ def train_sam( mask_prob: The probability for using a mask as input in a given training sub-iteration. n_iterations: The number of iterations to use for training. This will over-ride n_epochs if given. scheduler_class: The learning rate scheduler to update the learning rate. - By default, ReduceLROnPlateau is used. + By default, torch.optim.lr_scheduler.ReduceLROnPlateau is used. scheduler_kwargs: The learning rate scheduler parameters. If passed None, the chosen default parameters are used in ReduceLROnPlateau. save_every_kth_epoch: Save checkpoints after every kth epoch separately. pbar_signals: Controls for napari progress bar. + optimizer_class: The optimizer class. + By default, torch.optim.AdamW is used. """ _check_loader(train_loader, with_segmentation_decoder) _check_loader(val_loader, with_segmentation_decoder) diff --git a/micro_sam/training/util.py b/micro_sam/training/util.py index 4ba56961e..759c905e8 100644 --- a/micro_sam/training/util.py +++ b/micro_sam/training/util.py @@ -227,7 +227,8 @@ def __call__(self, x, y, n_pos, n_neg, get_boxes=False, n_samples=None): class ConvertToSemanticSamInputs: - """ + """Convert outputs of data loader to the expected batched inputs of the SegmentAnything model + for semantic segmentation. """ def __call__(self, x, y): """Convert the outputs of dataloader to the batched format of inputs expected by SAM. diff --git a/micro_sam/util.py b/micro_sam/util.py index 09d06c09c..45550a493 100644 --- a/micro_sam/util.py +++ b/micro_sam/util.py @@ -2,26 +2,28 @@ Helper functions for downloading Segment Anything models and predicting image embeddings. """ -import hashlib + import os import pickle +import hashlib import warnings -from collections import OrderedDict from pathlib import Path +from collections import OrderedDict from typing import Any, Dict, Iterable, Optional, Tuple, Union -import imageio.v3 as imageio -import numpy as np -import pooch -import torch +import zarr import vigra +import torch +import pooch import xxhash -import zarr +import numpy as np +import imageio.v3 as imageio +from skimage.measure import regionprops +from skimage.segmentation import relabel_sequential from elf.io import open_file + from nifty.tools import blocking -from skimage.measure import regionprops -from skimage.segmentation import relabel_sequential from .__version__ import __version__ from . import models as custom_models diff --git a/micro_sam/visualization.py b/micro_sam/visualization.py index c931985fc..ad4e9d00b 100644 --- a/micro_sam/visualization.py +++ b/micro_sam/visualization.py @@ -4,10 +4,11 @@ from typing import Tuple import numpy as np +from skimage.transform import resize -from elf.segmentation.embeddings import embedding_pca from nifty.tools import blocking -from skimage.transform import resize + +from elf.segmentation.embeddings import embedding_pca from .util import ImageEmbeddings From 0d7f96b9279a8fec822765eac32b9579e2f01e4c Mon Sep 17 00:00:00 2001 From: Anwai Archit <52396323+anwai98@users.noreply.github.com> Date: Fri, 26 Jul 2024 12:51:03 +0200 Subject: [PATCH 24/53] Add plotting scripts for AIS experiments (#666) Plotting scripts for AIS ablation --- .../plotting/for_paper/plot_ais_ablation.py | 153 ++++++++++++++++++ 1 file changed, 153 insertions(+) create mode 100644 scripts/plotting/for_paper/plot_ais_ablation.py diff --git a/scripts/plotting/for_paper/plot_ais_ablation.py b/scripts/plotting/for_paper/plot_ais_ablation.py new file mode 100644 index 000000000..aebdc8c75 --- /dev/null +++ b/scripts/plotting/for_paper/plot_ais_ablation.py @@ -0,0 +1,153 @@ +import pandas as pd +import seaborn as sns +from natsort import natsorted + +import matplotlib.pyplot as plt + + +base_color = '#0562A0' +highlight_color = '#045275' +plt.rcParams.update({'font.size': 30}) + + +# NOTE: the score formats below are a list of numbers: [X, Y, Z], +# where: X is the mSA, Y is SA50 and Z is SA75 + +LIVECELL_AIS = { + "unet": [0.4188, 0.699752, 0.443877], + "unetr_scratch": [0.415419, 0.699897, 0.439006], + "unetr_sam": [0.445632, 0.726114, 0.479634], + "semanticsam_scratch": [0.386169, 0.671345, 0.401836], + "semanticsam_sam": [0.428852, 0.706803, 0.45969] +} + +COVID_IF_AIS = { + "1": { + "unet": [0.124261, 0.306542, 0.085534], + "unetr_scratch": [0.150799, 0.372263, 0.101136], + "unetr_sam": [0.282399, 0.555058, 0.25503], + "semanticsam_scratch": [0.09322, 0.238215, 0.0615], + "semanticsam_sam": [0.299337, 0.612757, 0.264384] + }, + "2": { + "unet": [0.194456, 0.426158, 0.160465], + "unetr_scratch": [0.203448, 0.439231, 0.172646], + "unetr_sam": [0.308674, 0.584671, 0.290992], + "semanticsam_scratch": [0.117305, 0.285744, 0.083979], + "semanticsam_sam": [0.311751, 0.632971, 0.281148] + }, + "5": { + "unet": [0.243485, 0.495585, 0.219], + "unetr_scratch": [0.250491, 0.52194, 0.221091], + "unetr_sam": [0.362728, 0.683941, 0.343065], + "semanticsam_scratch": [0.136756, 0.32772, 0.100696], + "semanticsam_sam": [0.320606, 0.649073, 0.290766] + }, + "10": { + "unet": [0.29883, 0.588136, 0.280681], + "unetr_scratch": [0.286946, 0.571417, 0.264325], + "unetr_sam": [0.401787, 0.729247, 0.39796], + "semanticsam_scratch": [0.145352, 0.353673, 0.104027], + "semanticsam_sam": [0.375741, 0.729203, 0.354669] + } +} + +MODEL_NAME_MAPS = { + "unet": "UNet", + "unetr_scratch": "UNETR\n$\it{(scratch)}$", + "unetr_sam": "UNETR\n$\it{(SAM)}$", + "semanticsam_scratch": "SamDecoder\n$\it{(scratch)}$", + "semanticsam_sam": "SamDecoder\n$\it{(SAM)}$" +} + +COLORS = { + 'unet': '#FCDE9C', + 'unetr_scratch': '#045275', + 'unetr_sam': '#045275', + 'semanticsam_scratch': '#F0746E', + 'semanticsam_sam': '#F0746E', +} + + +def make_livecell_barplot(): + labels = list(LIVECELL_AIS.keys()) + model_labels = [MODEL_NAME_MAPS[model] for model in labels] + scores = [LIVECELL_AIS[model][0] for model in labels] + + data = {"Model": model_labels, "Score": scores} + df = pd.DataFrame(data) + + plt.figure(figsize=(20, 15)) + bars = sns.barplot(x="Model", y="Score", data=df, hue='Model', legend=False, palette=list(COLORS.values())) + + for i, bar in enumerate(bars.patches): + if df["Model"][i] in [MODEL_NAME_MAPS["unetr_sam"], MODEL_NAME_MAPS["semanticsam_sam"]]: + bar.set_hatch("//") + bar.set_edgecolor('white') + bar.set_linewidth(5) + + plt.xlabel(None) + plt.ylabel("Mean Segmentation Accuracy", fontweight="bold") + plt.title("Automatic Instance Segmentation (LIVECell)") + plt.ylim(0, max(scores) + 0.05) + + plt.gca().yaxis.labelpad = 30 + plt.gca().xaxis.labelpad = 20 + + yticks = [i * 0.05 for i in range(1, int(max(scores) / 0.05) + 2)] + plt.yticks(yticks) + + plt.tight_layout() + plt.savefig("s14_1.png") + plt.savefig("s14_1.svg") + plt.savefig("s14_1.pdf") + + +def make_covid_if_lineplot(): + markers = { + 'unet': 'o', 'unetr_scratch': 'o', 'unetr_sam': 'o', 'semanticsam_scratch': 'o', 'semanticsam_sam': 'o', + } + line_styles = { + 'unet': '-', 'unetr_scratch': '-', 'unetr_sam': '-.', 'semanticsam_scratch': '-', 'semanticsam_sam': '-.', + } + + x = natsorted(COVID_IF_AIS.keys()) + models = list(COVID_IF_AIS[x[0]].keys()) + + data = [] + for key in x: + for model in models: + data.append({'Key': key, 'Model': model, 'Score': COVID_IF_AIS[key][model][0]}) + + df = pd.DataFrame(data) + + plt.figure(figsize=(20, 15)) + for model in models: + sns.lineplot( + data=df[df["Model"] == model], x='Key', y='Score', + marker=markers[model], linestyle=line_styles[model], + markersize=15, linewidth=2.5, label=MODEL_NAME_MAPS[model], + color=COLORS[model], + ) + + plt.xlabel("Number of Images", fontweight="bold") + plt.ylabel("Mean Segmentation Accuracy", fontweight="bold") + plt.title("Automatic Instance Segmentation (Covid IF)") + + plt.gca().yaxis.labelpad = 30 + plt.gca().xaxis.labelpad = 20 + + plt.legend(loc="upper center", bbox_to_anchor=(0.5, -0.1), ncol=5, handletextpad=0.5, columnspacing=1) + + plt.tight_layout() + plt.savefig("s14_2.png") + plt.savefig("s14_2.svg") + plt.savefig("s14_2.pdf") + + +def main(): + make_livecell_barplot() + make_covid_if_lineplot() + + +main() From 70d5002880b3139cae2317f3697ebe43fd288fc8 Mon Sep 17 00:00:00 2001 From: Anwai Archit <52396323+anwai98@users.noreply.github.com> Date: Wed, 31 Jul 2024 08:38:03 +0200 Subject: [PATCH 25/53] Add livecell lora plots (#674) Supp. - lineplot for livecell lora experiments --- finetuning/livecell/lora/train_livecell.py | 6 - micro_sam/training/training.py | 6 + .../plotting/for_paper/plot_livecell_lora.py | 130 ++++++++++++++++++ 3 files changed, 136 insertions(+), 6 deletions(-) create mode 100644 scripts/plotting/for_paper/plot_livecell_lora.py diff --git a/finetuning/livecell/lora/train_livecell.py b/finetuning/livecell/lora/train_livecell.py index 6b12ac611..31977b217 100644 --- a/finetuning/livecell/lora/train_livecell.py +++ b/finetuning/livecell/lora/train_livecell.py @@ -39,12 +39,6 @@ def get_dataloaders(patch_shape, data_path, cell_type=None): return train_loader, val_loader -def count_parameters(model): - params = sum(p.numel() for p in model.parameters() if p.requires_grad) - params = params / 1e6 - return f"The number of trainable parameters for the provided model is {round(params, 2)}M" - - def finetune_livecell(args): """Code for finetuning SAM (using LoRA) on LIVECell """ diff --git a/micro_sam/training/training.py b/micro_sam/training/training.py index 4e25e595b..bb31fa380 100644 --- a/micro_sam/training/training.py +++ b/micro_sam/training/training.py @@ -126,6 +126,12 @@ def set_description(self, desc, **kwargs): self._signals.pbar_description.emit(desc) +def _count_parameters(model_parameters): + params = sum(p.numel() for p in model_parameters if p.requires_grad) + params = params / 1e6 + print(f"The number of trainable parameters for the provided model is {round(params, 2)}M") + + def train_sam( name: str, model_type: str, diff --git a/scripts/plotting/for_paper/plot_livecell_lora.py b/scripts/plotting/for_paper/plot_livecell_lora.py new file mode 100644 index 000000000..1a5d2af4d --- /dev/null +++ b/scripts/plotting/for_paper/plot_livecell_lora.py @@ -0,0 +1,130 @@ +import os +from glob import glob +from natsort import natsorted + +import numpy as np +import pandas as pd +import seaborn as sns +import matplotlib.pyplot as plt + + +ROOT = "/media/anwai/ANWAI/micro-sam/for_revision_2/livecell_results" + +PALETTE = { + "AIS": "#045275", + "AMG": "#FCDE9C", + "Point": "#7CCBA2", + r"I$_{P}$": "#089099", + "Box": "#90477F", + r"I$_{B}$": "#F0746E" +} + +NAME_MAPS = { + "vanilla": "Default", + "lora_1": "LoRA\n(Rank 1)", # 15.13M + "lora_2": "LoRA\n(Rank 2)", # 15.17M + "lora_4": "LoRA\n(Rank 4)", # 15.24M + "lora_8": "LoRA\n(Rank 8)", # 15.39M + "lora_16": "LoRA\n(Rank 16)", # 15.68M + "full_ft": "Full\nFinetuning", # 104.76M +} + +plt.rcParams.update({"font.size": 30}) + + +def _get_livecell_lora_data(): + # experiments from carolin on livecell lora + all_results = [] + all_experiments_dir = natsorted(glob(os.path.join(ROOT, "*"))) + for experiment_dir in all_experiments_dir: + experiment_name = os.path.split(experiment_dir)[-1] + + ais = pd.read_csv(os.path.join(experiment_dir, "results", "instance_segmentation_with_decoder.csv")) + amg = pd.read_csv(os.path.join(experiment_dir, "results", "amg.csv")) + ip = pd.read_csv(os.path.join(experiment_dir, "results", "iterative_prompts_start_point.csv")) + ib = pd.read_csv(os.path.join(experiment_dir, "results", "iterative_prompts_start_box.csv")) + + res = { + "experiment": experiment_name, + "AIS": ais.iloc[0]["msa"], + "AMG": amg.iloc[0]["msa"], + "Point": ip.iloc[0]["msa"], + "Box": ib.iloc[0]["msa"], + r"I$_{P}$": ip.iloc[-1]["msa"], + r"I$_{B}$": ib.iloc[-1]["msa"] + } + all_results.append(pd.DataFrame.from_dict([res])) + + # NOTE: this is done to plot "full_finetuning" results at the end of the lineplot. + all_results = all_results[1:] + [all_results[0]] + + return all_results + + +def _get_vanilla_and_finetuned_results(): + all_results = _get_livecell_lora_data() + + def _get_results(method): + assert method in ["vanilla", "specialist"] + root_dir = f"/home/anwai/results/micro-sam/livecell/{method}/vit_b" + + amg = pd.read_csv(os.path.join(root_dir, "amg.csv")) + ip = pd.read_csv(os.path.join(root_dir, "iterative_prompts_start_point.csv")) + ib = pd.read_csv(os.path.join(root_dir, "iterative_prompts_start_box.csv")) + + have_ais = False + if method == "specialist": + ais = pd.read_csv(os.path.join(root_dir, "instance_segmentation_with_decoder.csv")) + have_ais = True + + res = { + "experiment": method, + "AMG": amg.iloc[0]["msa"], + "Point": ip.iloc[0]["msa"], + "Box": ib.iloc[0]["msa"], + r"I$_{P}$": ip.iloc[-1]["msa"], + r"I$_{B}$": ib.iloc[-1]["msa"] + } + if have_ais: + res["AIS"] = ais.iloc[0]["msa"] + + return pd.DataFrame.from_dict([res]) + + all_results.insert(0, _get_results("vanilla")) + res_df = pd.concat(all_results, ignore_index=True) + return res_df + + +def _get_plots(): + plt.figure(figsize=(20, 15)) + res = _get_vanilla_and_finetuned_results() + ax = sns.lineplot( + data=pd.melt(res, "experiment"), + x="experiment", y="value", hue="variable", marker="d", + palette=PALETTE, markersize=20, linewidth=3, + ) + + ax.set_yticks(np.linspace(0, 1, 11)[:-2]) + + plt.ylabel("Mean Segmentation Accuracy", labelpad=10, fontweight="bold") + plt.xlabel("Finetuning Strategy", labelpad=10, fontweight="bold") + plt.legend(loc="lower center", ncol=7) + + plt.xticks(np.arange(7), [exp_name for exp_name in NAME_MAPS.values()]) + + plt.gca().yaxis.labelpad = 30 + plt.gca().xaxis.labelpad = 20 + + plt.title("") + plt.tight_layout() + plt.savefig("s14_c.png") + plt.savefig("s14_c.svg") + plt.savefig("s14_c.pdf") + + +def main(): + _get_plots() + + +if __name__ == "__main__": + main() From 6f7bf0aac40ebb826cf92c4225f0846f1e4d294d Mon Sep 17 00:00:00 2001 From: Anwai Archit <52396323+anwai98@users.noreply.github.com> Date: Wed, 31 Jul 2024 16:57:26 +0200 Subject: [PATCH 26/53] Minor updates to resource efficient finetuning experiments (#665) * Update resource efficient finetuning experiments --- finetuning/evaluation/evaluate_amg.py | 17 +- .../evaluate_instance_segmentation.py | 19 ++- finetuning/evaluation/iterative_prompting.py | 2 +- .../evaluation/precompute_embeddings.py | 2 +- finetuning/evaluation/util.py | 16 +- .../specialists/resource-efficient/README.md | 100 ++++++++++-- .../check_training_times.py | 62 ++++---- .../resource-efficient/covid_if_finetuning.py | 7 +- .../resource-efficient/plot_experiments.py | 150 ++++++++++++++---- .../resource-efficient/run_evaluations.py | 120 +++++++------- .../run_resource_efficient_finetuning.py | 43 +++-- micro_sam/evaluation/inference.py | 8 +- micro_sam/instance_segmentation.py | 9 +- 13 files changed, 390 insertions(+), 165 deletions(-) diff --git a/finetuning/evaluation/evaluate_amg.py b/finetuning/evaluation/evaluate_amg.py index 69ec63efa..8f8f132d4 100644 --- a/finetuning/evaluation/evaluate_amg.py +++ b/finetuning/evaluation/evaluate_amg.py @@ -7,16 +7,17 @@ from util import get_pred_paths, get_default_arguments, VANILLA_MODELS -def run_amg_inference(dataset_name, model_type, checkpoint, experiment_folder): +def run_amg_inference(dataset_name, model_type, checkpoint, experiment_folder, lora_rank): val_image_paths, val_gt_paths = get_paths(dataset_name, split="val") test_image_paths, _ = get_paths(dataset_name, split="test") prediction_folder = run_amg( - checkpoint, - model_type, - experiment_folder, - val_image_paths, - val_gt_paths, - test_image_paths + checkpoint=checkpoint, + model_type=model_type, + experiment_folder=experiment_folder, + val_image_paths=val_image_paths, + val_gt_paths=val_gt_paths, + test_image_paths=test_image_paths, + lora_rank=lora_rank, ) return prediction_folder @@ -37,7 +38,7 @@ def main(): else: ckpt = args.checkpoint - prediction_folder = run_amg_inference(args.dataset, args.model, ckpt, args.experiment_folder) + prediction_folder = run_amg_inference(args.dataset, args.model, ckpt, args.experiment_folder, args.lora_rank) eval_amg(args.dataset, prediction_folder, args.experiment_folder) diff --git a/finetuning/evaluation/evaluate_instance_segmentation.py b/finetuning/evaluation/evaluate_instance_segmentation.py index 70da76356..c41e9fb47 100644 --- a/finetuning/evaluation/evaluate_instance_segmentation.py +++ b/finetuning/evaluation/evaluate_instance_segmentation.py @@ -7,16 +7,19 @@ from util import get_pred_paths, get_default_arguments -def run_instance_segmentation_with_decoder_inference(dataset_name, model_type, checkpoint, experiment_folder): +def run_instance_segmentation_with_decoder_inference( + dataset_name, model_type, checkpoint, experiment_folder, lora_rank +): val_image_paths, val_gt_paths = get_paths(dataset_name, split="val") test_image_paths, _ = get_paths(dataset_name, split="test") prediction_folder = run_instance_segmentation_with_decoder( - checkpoint, - model_type, - experiment_folder, - val_image_paths, - val_gt_paths, - test_image_paths + checkpoint=checkpoint, + model_type=model_type, + experiment_folder=experiment_folder, + val_image_paths=val_image_paths, + val_gt_paths=val_gt_paths, + test_image_paths=test_image_paths, + lora_rank=lora_rank, ) return prediction_folder @@ -34,7 +37,7 @@ def main(): args = get_default_arguments() prediction_folder = run_instance_segmentation_with_decoder_inference( - args.dataset, args.model, args.checkpoint, args.experiment_folder + args.dataset, args.model, args.checkpoint, args.experiment_folder, args.lora_rank, ) eval_instance_segmentation_with_decoder(args.dataset, prediction_folder, args.experiment_folder) diff --git a/finetuning/evaluation/iterative_prompting.py b/finetuning/evaluation/iterative_prompting.py index 08c0cf3b5..eae3f8450 100644 --- a/finetuning/evaluation/iterative_prompting.py +++ b/finetuning/evaluation/iterative_prompting.py @@ -42,7 +42,7 @@ def main(): start_with_box_prompt = args.box # overwrite to start first iters' prompt with box instead of single point # get the predictor to perform inference - predictor = get_model(model_type=args.model, ckpt=args.checkpoint) + predictor = get_model(model_type=args.model, ckpt=args.checkpoint, lora_rank=args.lora_rank) prediction_root = _run_iterative_prompting( args.dataset, args.experiment_folder, predictor, start_with_box_prompt, args.use_masks diff --git a/finetuning/evaluation/precompute_embeddings.py b/finetuning/evaluation/precompute_embeddings.py index 438cba591..605627feb 100644 --- a/finetuning/evaluation/precompute_embeddings.py +++ b/finetuning/evaluation/precompute_embeddings.py @@ -9,7 +9,7 @@ def main(): args = get_default_arguments() - predictor = get_model(model_type=args.model, ckpt=args.checkpoint) + predictor = get_model(model_type=args.model, ckpt=args.checkpoint, lora_rank=args.lora_rank) embedding_dir = os.path.join(args.experiment_folder, "embeddings") os.makedirs(embedding_dir, exist_ok=True) diff --git a/finetuning/evaluation/util.py b/finetuning/evaluation/util.py index 8b1716e89..9780cc70b 100644 --- a/finetuning/evaluation/util.py +++ b/finetuning/evaluation/util.py @@ -14,10 +14,10 @@ EXPERIMENT_ROOT = "/scratch/projects/nim00007/sam/experiments/new_models" VANILLA_MODELS = { - "vit_t": "/scratch-grete/projects/nim00007/sam/models/new_models/vanilla/vit_t_mobile_sam.pth", - "vit_b": "/scratch-grete/projects/nim00007/sam/models/new_models/vanilla/sam_vit_b_01ec64.pth", - "vit_l": "/scratch-grete/projects/nim00007/sam/models/new_models/vanilla/sam_vit_l_0b3195.pth", - "vit_h": "/scratch-grete/projects/nim00007/sam/models/new_models/vanilla/sam_vit_h_4b8939.pth" + "vit_t": "/scratch-grete/projects/nim00007/sam/models/vanilla/vit_t_mobile_sam.pth", + "vit_b": "/scratch-grete/projects/nim00007/sam/models/vanilla/sam_vit_b_01ec64.pth", + "vit_l": "/scratch-grete/projects/nim00007/sam/models/vanilla/sam_vit_l_0b3195.pth", + "vit_h": "/scratch-grete/projects/nim00007/sam/models/vanilla/sam_vit_h_4b8939.pth" } @@ -80,10 +80,13 @@ def get_dataset_paths(dataset_name, split_choice): return raw_dir, labels_dir -def get_model(model_type, ckpt): +def get_model(model_type, ckpt, lora_rank): if ckpt is None: ckpt = VANILLA_MODELS[model_type] - predictor = get_sam_model(model_type=model_type, checkpoint_path=ckpt) + + predictor = get_sam_model( + model_type=model_type, checkpoint_path=ckpt, lora_rank=lora_rank, + ) return predictor @@ -226,6 +229,7 @@ def get_default_arguments(): parser.add_argument( "--use_masks", action="store_true", help="To use logits masks for iterative prompting." ) + parser.add_argument("--lora_rank", default=None, type=int, help="The rank for low rank adaptation method.") args = parser.parse_args() return args diff --git a/finetuning/specialists/resource-efficient/README.md b/finetuning/specialists/resource-efficient/README.md index 48b80c765..482a51d84 100644 --- a/finetuning/specialists/resource-efficient/README.md +++ b/finetuning/specialists/resource-efficient/README.md @@ -8,9 +8,9 @@ TLDR: Finetuning ViT Base (`vit_b`) is the best bet on most workstation / cluste ## Available Resource Combinations: - `medium` (CPU - SCC) -- `gtx1080`: (GPU - SCC) 8GB -- `rtx5000`: (GPU - SCC) 16GB -- `v100`: (GPU - SCC) 32GB +- `GTX1080`: (GPU - SCC) 8GB +- `RTX5000`: (GPU - SCC) 16GB +- `V100`: (GPU - SCC) 32GB ## Experimental Combinations: - `vit_t` / `vit_b` (ideally, fewer the parameters, the better for our use-case here) @@ -19,7 +19,6 @@ TLDR: Finetuning ViT Base (`vit_b`) is the best bet on most workstation / cluste - number of objects per batch (? - depends on the maximum number which we can fit on the respective resource) ## Inference: - - Using default Segment Anything - Using `vit__lm` micro-sam (LM generalist) - Using finetuned Segment Anything `vit__covid-if` (training a `covid-if` specialist) @@ -34,29 +33,29 @@ Fixed parameters: - training and validation batch size - `1` - minimum number of training "samples" for training on the provided images - min. **`50`** (oversample while min. 50 training samples not found) (this is done to avoid the exhaustive time constraints while training with only 1 training sample) - learning rate: `1e-5` -- optimizer: `Adam` +- optimizer: `AdamW` - lr scheduler: `ReduceLRonPlateau` - early stopping: `10` - patch shape: `(512, 512)` -- choice of models: `vit_t` / `vit_b` +- choice of models: `vit_t` / `vit_b` / `vit_t_lm` / `vit_b_lm` ### GPU Resources (32GB CPU memory, 8 CPU cores) -1. `gtx1080`: +1. `GTX1080`: - `vit_t`: finetune all layers - `n_objects`: 5 - `vit_b`: freeze `image_encoder` - `n_objects`: 10 -2. `rtx5000`: +2. `RTX5000`: - `vit_t`: (finetune all layers) - `n_objects`: 20 - `vit_b`: (finetune all layers) - `n_objects`: 10 -3. `v100`: +3. `V100`: - `vit_t`: (finetune all layers) - `n_objects`: 45 - `vit_b`: (finetune all layers) @@ -92,7 +91,8 @@ All jobs are tested on `medium` partition. - ## Results: + + + + + ## Results: + +| Resource | Finetuned Model | Number of Images | Best Epoch | Train Time | +|----------|-----------------------------------|------------------|------------|------------| +| V100 | vit_b (Full Finetuning) | 1 | 3 | 0:05:07 | +| V100 | vit_b (Full Finetuning) | 2 | 10 | 0:14:01 | +| V100 | vit_b (Full Finetuning) | 5 | 10 | 0:14:09 | +| V100 | vit_b (Full Finetuning) | 10 | 20 | 0:26:24 | +| V100 | vit_b (LoRA) | 1 | 32 | 0:39:32 | +| V100 | vit_b (LoRA) | 2 | 58 | 1:10:25 | +| V100 | vit_b (LoRA) | 5 | 13 | 0:16:40 | +| V100 | vit_b (LoRA) | 10 | 42 | 0:51:10 | +| V100 | vit_b_lm (Full Finetuning) | 1 | 1 | 0:02:33 | +| V100 | vit_b_lm (Full Finetuning) | 2 | 4 | 0:06:19 | +| V100 | vit_b_lm (Full Finetuning) | 5 | 12 | 0:16:14 | +| V100 | vit_b_lm (Full Finetuning) | 10 | 2 | 0:03:48 | +| V100 | vit_b_lm (LoRA) | 1 | 8 | 0:10:45 | +| V100 | vit_b_lm (LoRA) | 2 | 23 | 0:28:33 | +| V100 | vit_b_lm (LoRA) | 5 | 22 | 0:27:23 | +| V100 | vit_b_lm (LoRA) | 10 | 5 | 0:07:11 | +| RTX5000 | vit_b (Full Finetuning) | 1 | 13 | 0:15:09 | +| RTX5000 | vit_b (Full Finetuning) | 2 | 13 | 0:15:00 | +| RTX5000 | vit_b (Full Finetuning) | 5 | 20 | 0:22:29 | +| RTX5000 | vit_b (Full Finetuning) | 10 | 43 | 0:46:55 | +| RTX5000 | vit_b (LoRA) | 1 | 46 | 0:48:30 | +| RTX5000 | vit_b (LoRA) | 2 | 23 | 0:24:53 | +| RTX5000 | vit_b (LoRA) | 5 | 39 | 0:41:14 | +| RTX5000 | vit_b (LoRA) | 10 | 16 | 0:17:37 | +| RTX5000 | vit_b_lm (Full Finetuning) | 1 | 4 | 0:05:26 | +| RTX5000 | vit_b_lm (Full Finetuning) | 2 | 4 | 0:05:25 | +| RTX5000 | vit_b_lm (Full Finetuning) | 5 | 3 | 0:04:21 | +| RTX5000 | vit_b_lm (Full Finetuning) | 10 | 3 | 0:04:22 | +| RTX5000 | vit_b_lm (LoRA) | 1 | 15 | 0:16:37 | +| RTX5000 | vit_b_lm (LoRA) | 2 | 26 | 0:28:03 | +| RTX5000 | vit_b_lm (LoRA) | 5 | 22 | 0:23:54 | +| RTX5000 | vit_b_lm (LoRA) | 10 | 32 | 0:34:04 | +| GTX1080 | vit_b (Freeze `image_encoder`) | 1 | 6 | 0:13:39 | +| GTX1080 | vit_b (Freeze `image_encoder`) | 2 | 3 | 0:07:55 | +| GTX1080 | vit_b (Freeze `image_encoder`) | 5 | 26 | 0:51:34 | +| GTX1080 | vit_b (Freeze `image_encoder`) | 10 | 40 | 1:18:05 | +| GTX1080 | vit_b_lm (Freeze `image_encoder`) | 1 | 10 | 0:21:30 | +| GTX1080 | vit_b_lm (Freeze `image_encoder`) | 2 | 2 | 0:06:15 | +| GTX1080 | vit_b_lm (Freeze `image_encoder`) | 5 | 7 | 0:15:05 | +| GTX1080 | vit_b_lm (Freeze `image_encoder`) | 10 | 13 | 0:15:05 | +| CPU (32G) | vit_b (Full Finetuning) | 1 | 15 | 3:48:52 | +| CPU (32G) | vit_b (Full Finetuning) | 2 | 18 | 4:36:06 | +| CPU (32G) | vit_b (Full Finetuning) | 5 | 30 | 7:47:20 | +| CPU (32G) | vit_b (Full Finetuning) | 10 | 24 | 5:41:31 | +| CPU (32G) | vit_b (LoRA) | 1 | 26 | 5:21:23 | +| CPU (32G) | vit_b (LoRA) | 2 | 12 | 2:53:41 | +| CPU (32G) | vit_b (LoRA) | 5 | 50 | 11:03:15 | +| CPU (32G) | vit_b (LoRA) | 10 | 13 | 2:57:08 | +| CPU (32G) | vit_b_lm (Full Finetuning) | 1 | 3 | 0:55:36 | +| CPU (32G) | vit_b_lm (Full Finetuning) | 2 | 24 | 5:43:28 | +| CPU (32G) | vit_b_lm (Full Finetuning) | 5 | 1 | 0:16:03 | +| CPU (32G) | vit_b_lm (Full Finetuning) | 10 | 6 | 2:01:30 | +| CPU (32G) | vit_b_lm (LoRA) | 1 | 15 | 3:25:33 | +| CPU (32G) | vit_b_lm (LoRA) | 2 | 9 | 2:58:05 | +| CPU (32G) | vit_b_lm (LoRA) | 5 | 14 | 3:31:14 | +| CPU (32G) | vit_b_lm (LoRA) | 10 | 7 | 1:58:57 | +| CPU (64G) | vit_b (Full Finetuning) | 1 | 6 | 3:20:00 | +| CPU (64G) | vit_b (Full Finetuning) | 2 | 15 | 4:23:10 | +| CPU (64G) | vit_b (Full Finetuning) | 5 | 16 | 4:05:15 | +| CPU (64G) | vit_b (Full Finetuning) | 10 | 15 | 3:51:02 | +| CPU (64G) | vit_b (LoRA) | 1 | 27 | 6:20:52 | +| CPU (64G) | vit_b (LoRA) | 2 | 46 | 19:51:34 | +| CPU (64G) | vit_b (LoRA) | 5 | 29 | 8:01:34 | +| CPU (64G) | vit_b (LoRA) | 10 | 19 | 5:20:02 | +| CPU (64G) | vit_b_lm (Full Finetuning) | 1 | 3 | 1:44:35 | +| CPU (64G) | vit_b_lm (Full Finetuning) | 2 | 10 | 2:57:22 | +| CPU (64G) | vit_b_lm (Full Finetuning) | 5 | 8 | 2:31:04 | +| CPU (64G) | vit_b_lm (Full Finetuning) | 10 | 5 | 1:28:26 | +| CPU (64G) | vit_b_lm (LoRA) | 1 | 16 | 4:39:26 | +| CPU (64G) | vit_b_lm (LoRA) | 2 | 1 | 0:19:46 | +| CPU (64G) | vit_b_lm (LoRA) | 5 | 38 | 9:38:11 | +| CPU (64G) | vit_b_lm (LoRA) | 10 | 15 | 5:42:34 | diff --git a/finetuning/specialists/resource-efficient/check_training_times.py b/finetuning/specialists/resource-efficient/check_training_times.py index 419ac4fcc..61a3c7cdb 100644 --- a/finetuning/specialists/resource-efficient/check_training_times.py +++ b/finetuning/specialists/resource-efficient/check_training_times.py @@ -1,48 +1,52 @@ -# Resource Efficient Finetuning: time taken to achieve the best epoch per setting -# a: rtx5000 -# - (vit_b) 1-image: 1192.79, 2-image: 725.15, 5-image: 3759.01, 10-image: 2427.18 -# - (vit_b_lm) 1-image: 2089.22, 2-image: 1622.69, 5-image: 3477.83, 10-image: 1869.33 - -# b: v100 -# - (vit_b) 1-image: 752.39 (9/100), 2-image: 2051.77 , 5-image: 1653.99, 10-image: 2998.08 -# - (vit_b_lm) 1-image: 1874.83, 2-image: 3205.59 , 5-image: 3196.15, 10-image: 2612.99 - -# c: cpu32gb -# - (vit_b) 1-image: 6302.03, 2-image: 29153.65, 5-image: 53502.85, 10-image: 20885.33 -# - (vit_b_lm) 1-image: 21711.23, 2-image: 34443.09, 5-image: 32750.22, 10-image: 19229.85 - -# d: cpu64gb -# - (vit_b) 1-image: 11439.01, 2-image: 26225.69, 5-image: 18675.01, 10-image: 50894.71 -# - (vit_b_lm) 1-image: 23291.25, 2-image: 40262.73, 5-image: 33137.21, 10-image: 47490.61 - - import os from glob import glob +from natsort import natsorted + +import pandas as pd from micro_sam.util import _load_checkpoint -ROOT = "/scratch/usr/nimanwai/experiments/resource-efficient-finetuning/" +ROOT = "/media/anwai/ANWAI/micro-sam/resource-efficient-finetuning/" -def _load_per_model(checkpoint): +def _stats_per_model(checkpoint): state, model_state = _load_checkpoint(checkpoint) - print("Time taken to train for the best epoch:", state["train_time"]) - print("The best epoch attained at:", state["epoch"]) - print() + time_in_seconds = state["train_time"] + minutes, seconds = divmod(time_in_seconds, 60) + hours, minutes = divmod(minutes, 60) + total_time = "%d:%02d:%02d" % (hours, minutes, seconds) + + # Let's create a dataframe and store all the results. + desired_path = checkpoint[len(ROOT):] + _splits = desired_path.split("/") + experiment_name = _splits[0] + finetuned_model = _splits[1] + "-" + _splits[2] + n_images = _splits[4].split("-")[0] + + outputs = { + "experiment_name": experiment_name, + "finetuned_model": finetuned_model, + "number_of_images": n_images, + "best_epoch": state["epoch"], + "time_in_minutes": total_time, + } + outputs = pd.DataFrame([outputs]) + + return outputs def check_models(setting, model): - all_ckpt_paths = sorted( - glob(os.path.join(ROOT, setting, model, "freeze-*", "*", "checkpoints", "*", "*", "best.pt")) + all_ckpt_paths = natsorted( + glob(os.path.join(ROOT, setting, model, "*", "freeze-*", "*", "checkpoints", "*", "*", "best.pt")) ) - for ckpt in all_ckpt_paths: - print(ckpt) - _load_per_model(ckpt) + all_outputs = [_stats_per_model(ckpt) for ckpt in all_ckpt_paths] + outputs = pd.concat(all_outputs, ignore_index=True) + print(outputs) def main(): - settings = ["v100", "rtx5000", "gtx1080", "cpu_32G-mem_16-cores", "cpu_64G-mem_16-cores"] + settings = ["V100", "RTX5000", "GTX1080", "cpu_32G-mem_16-cores", "cpu_64G-mem_16-cores"] models = ["vit_b", "vit_b_lm"] for setting in settings: for model in models: diff --git a/finetuning/specialists/resource-efficient/covid_if_finetuning.py b/finetuning/specialists/resource-efficient/covid_if_finetuning.py index 261087217..2107e7f11 100644 --- a/finetuning/specialists/resource-efficient/covid_if_finetuning.py +++ b/finetuning/specialists/resource-efficient/covid_if_finetuning.py @@ -24,7 +24,7 @@ def get_dataloaders(patch_shape, data_path, n_images): Note: to replace this with another data loader you need to return a torch data loader that retuns `x, y` tensors, where `x` is the image data and `y` are the labels. The labels have to be in a label mask instance segmentation format. - I.e. a tensor of the same spatial shape as `x`, with each object mask having its own ID. + i.e. a tensor of the same spatial shape as `x`, with each object mask having its own ID. Important: the ID 0 is reseved for background, and the IDs must be consecutive """ num_workers = 8 if torch.cuda.is_available() else 0 @@ -104,7 +104,7 @@ def finetune_covid_if(args): save_root=args.save_root, scheduler_kwargs=scheduler_kwargs, save_every_kth_epoch=args.save_every_kth_epoch, - + lora_rank=args.lora_rank, ) @@ -148,6 +148,9 @@ def main(): parser.add_argument( "--n_images", type=int, default=None, help="The number of images used for finetuning." ) + parser.add_argument( + "--lora_rank", type=int, default=None, help="The rank used for low rank adaptation." + ) args = parser.parse_args() finetune_covid_if(args) diff --git a/finetuning/specialists/resource-efficient/plot_experiments.py b/finetuning/specialists/resource-efficient/plot_experiments.py index 7e91bd94d..11df33f98 100644 --- a/finetuning/specialists/resource-efficient/plot_experiments.py +++ b/finetuning/specialists/resource-efficient/plot_experiments.py @@ -6,13 +6,14 @@ import numpy as np import pandas as pd import seaborn as sns + import matplotlib.pyplot as plt import matplotlib.lines as mlines from matplotlib.ticker import FuncFormatter from matplotlib.ticker import FormatStrFormatter -ROOT = "/scratch/usr/nimanwai/experiments/resource-efficient-finetuning/" +ROOT = "/media/anwai/ANWAI/micro-sam/resource-efficient-finetuning/" PALETTE = { "AIS": "#045275", @@ -24,9 +25,9 @@ RNAME_MAPPING = { "cpu_32G-mem_16-cores": "Intel Cascade Lake Xeon Platinum 9242 (32GB CPU RAM)", "cpu_64G-mem_16-cores": "Intel Cascade Lake Xeon Platinum 9242 (64GB CPU RAM)", - "rtx5000": "NVIDIA Quadro RTX5000 (16GB VRAM)", - "v100": "NVIDIA Tesla V100 (32GB VRAM)", - "gtx1080": "NVIDIA GeForce GTX 1080 (8GB VRAM)" + "RTX5000": "NVIDIA Quadro RTX5000 (16GB VRAM)", + "V100": "NVIDIA Tesla V100 (32GB VRAM)", + "GTX1080": "NVIDIA GeForce GTX 1080 (8GB VRAM)" } plt.rcParams.update({"font.size": 30}) @@ -76,7 +77,6 @@ def plot_all_experiments(): # let's get the benchmark results all_benchmark_results, all_benchmark_box_results = {}, {} for be_path in sorted(benchmark_experiment_paths): - experiment_name = os.path.split(be_path)[-1] all_res_paths = glob(os.path.join(be_path, "*", "results", "*")) all_model_paths = glob(os.path.join(be_path, "*")) @@ -102,35 +102,100 @@ def plot_all_experiments(): continue print(f"Results for {resource_name} on {model_name}") - all_image_setting_paths = natsorted(glob(os.path.join(model_epath, "*", "*"))) - all_res_list, all_box_res_list = [], [] + all_image_setting_paths = natsorted(glob(os.path.join(model_epath, "*", "*", "*"))) + all_res_list_full, all_box_res_list_full = [], [] + all_res_list_lora, all_box_res_list_lora = [], [] for image_epath in all_image_setting_paths: - image_setting = os.path.split(image_epath)[-1] - all_res_paths = sorted(glob(os.path.join(image_epath, "results", "*"))) - per_image_df, per_image_box_df = _get_all_results(image_setting.split("-")[0], all_res_paths) - all_res_list.append(per_image_df) - all_box_res_list.append(per_image_box_df) + _splits = image_epath.split("/") + image_setting = _splits[-1] + ft_setting = _splits[-3] - this_res = pd.concat([all_benchmark_results[model_name], *all_res_list]) - this_box_res = pd.concat([all_benchmark_box_results[model_name], *all_box_res_list]) + all_res_paths = sorted(glob(os.path.join(image_epath, "results", "*"))) + per_image_df, per_image_box_df = _get_all_results( + image_setting.split("-")[0] + "-" + ft_setting.split("-")[0], all_res_paths + ) + + if ft_setting == "full-finetuning": + all_res_list_full.append(per_image_df) + all_box_res_list_full.append(per_image_box_df) + else: + all_res_list_lora.append(per_image_df) + all_box_res_list_lora.append(per_image_box_df) + + this_res_full = pd.concat([all_benchmark_results[model_name], *all_res_list_full]) + this_box_res_full = pd.concat([all_benchmark_box_results[model_name], *all_box_res_list_full]) + + this_res_lora = pd.concat([all_benchmark_results[model_name], *all_res_list_lora]) + this_box_res_lora = pd.concat([all_benchmark_box_results[model_name], *all_box_res_list_lora]) + + this_res = pd.concat([this_res_full, this_res_lora]) + this_box_res = pd.concat([this_box_res_full, this_box_res_lora]) + + replacement_map = { + '1-full': '1', '1-lora': '1', + '2-full': '2', '2-lora': '2', + '5-full': '5', '5-lora': '5', + '10-full': '10', '10-lora': '10' + } + this_res['x'] = this_res['name'].replace(replacement_map) + this_box_res['x'] = this_box_res['name'].replace(replacement_map) _title = "Generalist" if model_name.endswith("lm") else "Default" - sns.lineplot( - x="name", y="results", hue="type", data=this_box_res, + def _change_opacity_markers(lineplot): + for line in lineplot.lines: + line.set_alpha(0.8) + + lineplot = sns.lineplot( + x="x", y="results", hue="type", + data=this_box_res[ + this_box_res['name'].str.contains("full") | this_box_res['name'].str.contains("initial") + ], ax=ax[0, idx], palette=PALETTE, hue_order=PALETTE.keys(), marker="o", markersize=15, linewidth=5 ) + _change_opacity_markers(lineplot) + ax[0, idx].set_title(_title, fontweight="bold") ax[0, idx].set(xlabel=None, ylabel=None) ax[0, idx].set_yticks(np.linspace(0.8, 1, 5)) ax[0, idx].yaxis.set_major_formatter(FormatStrFormatter('%.2f')) - sns.lineplot( - x="name", y="results", hue="type", data=this_res, + lineplot = sns.lineplot( + x="x", y="results", hue="type", + data=this_box_res[ + this_box_res['name'].str.contains("lora") | this_box_res['name'].str.contains("initial") + ], + ax=ax[0, idx], palette=PALETTE, hue_order=PALETTE.keys(), + marker="o", markersize=15, linewidth=5, linestyle=":", + ) + _change_opacity_markers(lineplot) + ax[0, idx].set_title(_title, fontweight="bold") + ax[0, idx].set(xlabel=None, ylabel=None) + ax[0, idx].set_yticks(np.linspace(0.8, 1, 5)) + ax[0, idx].yaxis.set_major_formatter(FormatStrFormatter('%.2f')) + + lineplot = sns.lineplot( + x="x", y="results", hue="type", + data=this_res[this_res['name'].str.contains("full") | this_res['name'].str.contains("initial")], ax=ax[1, idx], palette=PALETTE, hue_order=PALETTE.keys(), marker="o", markersize=15, linewidth=5 ) + _change_opacity_markers(lineplot) + + # ax[1, idx].set_title(_title, fontweight="bold") + ax[1, idx].set(xlabel=None, ylabel=None) + ax[1, idx].set_yticks(np.linspace(0.1, 0.6, 6)) + ax[1, idx].yaxis.set_major_formatter(FormatStrFormatter('%.2f')) + + lineplot = sns.lineplot( + x="x", y="results", hue="type", + data=this_res[this_res['name'].str.contains("lora") | this_res['name'].str.contains("initial")], + ax=ax[1, idx], palette=PALETTE, hue_order=PALETTE.keys(), + marker="o", markersize=15, linewidth=5, linestyle=":", + ) + _change_opacity_markers(lineplot) + # ax[1, idx].set_title(_title, fontweight="bold") ax[1, idx].set(xlabel=None, ylabel=None) ax[1, idx].set_yticks(np.linspace(0.1, 0.6, 6)) @@ -149,24 +214,55 @@ def plot_all_experiments(): all_labels.append(label) ax.get_legend().remove() - custom_handles = [] - for color in PALETTE.values(): - line = mlines.Line2D([], [], color=color, markersize=15, marker='o', linestyle='-', linewidth=5) - custom_handles.append(line) - - fig.legend(custom_handles, PALETTE.keys(), loc="lower center", ncols=4, bbox_to_anchor=(0.5, 0)) + _colors = list(PALETTE.values()) + custom_handles = [ + mlines.Line2D([], [], color=_colors[0], markersize=15, marker='o', linestyle='-', linewidth=5), + mlines.Line2D([], [], color=_colors[0], markersize=15, marker='o', linestyle=':', linewidth=5), + mlines.Line2D([], [], color=_colors[1], markersize=15, marker='o', linestyle='-', linewidth=5), + mlines.Line2D([], [], color=_colors[1], markersize=15, marker='o', linestyle=':', linewidth=5), + mlines.Line2D([], [], color=_colors[2], markersize=15, marker='o', linestyle='-', linewidth=5), + mlines.Line2D([], [], color=_colors[2], markersize=15, marker='o', linestyle=':', linewidth=5), + mlines.Line2D([], [], color=_colors[3], markersize=15, marker='o', linestyle='-', linewidth=5), + mlines.Line2D([], [], color=_colors[3], markersize=15, marker='o', linestyle=':', linewidth=5), + ] + + if resource_name == "GTX1080": + fig.legend( + handles=custom_handles, + labels=['AIS (MD, PE)', 'AMG (MD, PE)', 'Point (MD, PE)', 'Box (MD, PE)'], + loc="lower center", ncols=4, bbox_to_anchor=(0.5, 0) + ) + bottom = 0.12 + else: + fig.legend( + handles=custom_handles, + labels=[ + 'AIS (FFT)', 'AIS (LoRA)', 'AMG (FFT)', 'AMG (LoRA)', + 'Point (FFT)', 'Point (LoRA)', 'Box (FFT)', 'Box (LoRA)' + ], + loc="lower center", ncols=4, bbox_to_anchor=(0.5, 0) + ) + bottom = 0.15 def format_y_tick_label(value, pos): return "{:.2f}".format(value) plt.gca().yaxis.set_major_formatter(FuncFormatter(format_y_tick_label)) + if resource_name == "V100": + y = -0.075 + elif resource_name == "GTX1080": + y = -0.1 + elif resource_name == "cpu_32G-mem_16-cores": + y = -0.03 + else: + y = 0 + plt.text(x=-1.35, y=y, s="Number of Images", fontweight="bold") plt.text(x=-5.8, y=0.36, s="Segmentation Accuracy at IoU 50%", rotation=90, fontweight="bold") - plt.text(x=-1.35, y=-0.075, s="Number of Images", fontweight="bold") - plt.subplots_adjust(wspace=0.1, hspace=0.15, bottom=0.12, top=0.88) + plt.subplots_adjust(wspace=0.1, hspace=0.15, bottom=bottom, top=0.88) - if resource_name == "cpu_32G-mem_16-cores": + if resource_name == "cpu_64G-mem_16-cores": fig.suptitle("Resource Efficient Finetuning (CPU)", y=0.95, x=0.51) save_path = "./5_b.png" plt.savefig(save_path) diff --git a/finetuning/specialists/resource-efficient/run_evaluations.py b/finetuning/specialists/resource-efficient/run_evaluations.py index 574545a49..741f28190 100644 --- a/finetuning/specialists/resource-efficient/run_evaluations.py +++ b/finetuning/specialists/resource-efficient/run_evaluations.py @@ -1,6 +1,7 @@ import os import re import shutil +import itertools import subprocess from glob import glob from tqdm import tqdm @@ -12,20 +13,20 @@ ALL_SCRIPTS = [ - # "../../evaluation/precompute_embeddings", - # "../../evaluation/iterative_prompting", + "../../evaluation/precompute_embeddings", + "../../evaluation/iterative_prompting", "../../evaluation/evaluate_amg", "../../evaluation/evaluate_instance_segmentation" ] -ROOT = "/scratch/usr/nimanwai/experiments/resource-efficient-finetuning/" # for hlrn -# ROOT = "/scratch/users/archit/experiments/" # for scc - -DATA_DIR = "/scratch/projects/nim00007/sam/data/covid_if/" # for hlrn -# DATA_DIR = "/scratch/users/archit/data/covid-if" # for scc +DATA_DIR = "/scratch/projects/nim00007/sam/data/covid_if" +ROOT = "/scratch/share/cidas/cca/experiments/resource-efficient-finetuning/" def process_covid_if(input_path): + if os.path.exists(os.path.join(input_path, "slices")): + return + all_image_paths = sorted(glob(os.path.join(input_path, "*.h5"))) # val images @@ -64,34 +65,28 @@ def process_covid_if(input_path): def write_slurm_scripts( - inference_setup, env_name, checkpoint, model_type, experiment_folder, out_path + inference_setup, env_name, checkpoint, model_type, experiment_folder, out_path, lora, ): - on_scc = False - if on_scc: - batch_script = f"""#!/bin/bash -#SBATCH -c 8 -#SBATCH --mem 16G -#SBATCH -t 2-00:00:00 -#SBATCH -p gpu -#SBATCH -G v100:1 -#SBATCH --job-name={Path(inference_setup).stem} - -source activate {env_name} \n""" - - else: - batch_script = f"""#!/bin/bash + batch_script = f"""#!/bin/bash #SBATCH -c 8 #SBATCH --mem 16G #SBATCH -t 2-00:00:00 #SBATCH -p grete:shared #SBATCH -G A100:1 +#SBATCH -A gzz0001 #SBATCH --job-name={Path(inference_setup).stem} -source activate {env_name} \n""" +source ~/.bashrc +mamba activate {env_name} \n""" # python script batch_script += f"python {inference_setup}.py -c {checkpoint} -m {model_type} -e {experiment_folder} -d covid_if " + # Whether the model was trained with LoRA + # NOTE: We use rank 4 for LoRA. + if lora: + batch_script += "--lora_rank 4 " + _op = out_path[:-3] + f"_{Path(inference_setup).stem}.sh" with open(_op, "w") as f: @@ -109,91 +104,110 @@ def write_slurm_scripts( def get_batch_script_names(tmp_folder): tmp_folder = os.path.expanduser(tmp_folder) os.makedirs(tmp_folder, exist_ok=True) - script_name = "micro-sam-inference" - dt = datetime.now().strftime("%Y_%m_%d_%H_%M_%S_%f") tmp_name = script_name + dt batch_script = os.path.join(tmp_folder, f"{tmp_name}.sh") - return batch_script -def run_slurm_scripts(model_type, checkpoint, experiment_folder, scripts=ALL_SCRIPTS): +def run_slurm_scripts( + model_type, checkpoint, experiment_folder, dry, scripts=ALL_SCRIPTS, has_lora=False, freeze_image_encoder=False +): tmp_folder = "./gpu_jobs" shutil.rmtree(tmp_folder, ignore_errors=True) - for current_setup in scripts: + lora_combinations = [True, False] + for (current_setup, use_lora) in itertools.product(scripts, lora_combinations): + # for experiments such as vanilla and generalist models. + if not has_lora == use_lora: + continue + + # for experiments such as GTX1080, where freezing the image encoder is the way and we do not use LoRA + if freeze_image_encoder and use_lora: + continue + write_slurm_scripts( inference_setup=current_setup, env_name="mobilesam" if model_type == "vit_t" else "sam", checkpoint=checkpoint, model_type=model_type, experiment_folder=experiment_folder, - out_path=get_batch_script_names(tmp_folder) + out_path=get_batch_script_names(tmp_folder), + lora=use_lora, ) - # the logic below automates the process of first running the precomputation of embeddings, and only then inference. - job_id = [] - for i, my_script in enumerate(sorted(glob(tmp_folder + "/*"))): - cmd = ["sbatch", my_script] + if not dry: + # the logic below automates running the precomputation of embeddings first and then inference. + job_id = [] + for i, my_script in enumerate(sorted(glob(tmp_folder + "/*"))): + cmd = ["sbatch", my_script] - if i > 0: - cmd.insert(1, f"--dependency=afterany:{job_id[0]}") + if i > 0: + cmd.insert(1, f"--dependency=afterany:{job_id[0]}") - cmd_out = subprocess.run(cmd, capture_output=True, text=True) - print(cmd_out.stdout if len(cmd_out.stdout) > 1 else cmd_out.stderr) + cmd_out = subprocess.run(cmd, capture_output=True, text=True) + print(cmd_out.stdout if len(cmd_out.stdout) > 1 else cmd_out.stderr) - if i == 0: - job_id.append(re.findall(r'\d+', cmd_out.stdout)[0]) + if i == 0: + job_id.append(re.findall(r'\d+', cmd_out.stdout)[0]) def main(args): # preprocess the data - process_covid_if(input_path=args.input_path) + process_covid_if(input_path=DATA_DIR) # results on vanilla models run_slurm_scripts( model_type="vit_b", checkpoint=None, experiment_folder=os.path.join(ROOT, "vanilla", "vit_b"), - scripts=ALL_SCRIPTS[:-1] + scripts=ALL_SCRIPTS[:-1], + dry=args.dry, ) # results on generalist models - # vit_b_lm_path = "/scratch/users/archit/micro-sam/vit_b/lm_generalist/best.pt" # on scc - vit_b_lm_path = "/scratch/usr/nimanwai/micro-sam/checkpoints/vit_b/lm_generalist_sam/best.pt" # on hlrn + vit_b_lm_path = "/scratch/usr/nimanwai/micro-sam/checkpoints/vit_b/lm_generalist_sam/best.pt" run_slurm_scripts( model_type="vit_b", checkpoint=vit_b_lm_path, - experiment_folder=os.path.join(ROOT, "generalist", "vit_b") + experiment_folder=os.path.join(ROOT, "generalist", "vit_b"), + dry=args.dry, ) # results on resource-efficient finetuned checkpoints all_checkpoint_paths = glob(os.path.join(ROOT, "**", "best.pt"), recursive=True) - # let's get all gpu jobs and run evaluation for them - all_checkpoint_paths = [ - ckpt for ckpt in all_checkpoint_paths if ckpt.find("cpu") != -1 - ] - for checkpoint_path in all_checkpoint_paths: - # NOTE: run this for vit_b + # NOTE: We run the inference only for `vit_b` models. Remove this to run for `vit_t` models as well. _searcher = checkpoint_path.find("vit_b") - if _searcher == -1: + if _searcher == -1: # i.e. `vit_b` keyword was not found. continue + _searcher2 = checkpoint_path.find("freeze-image_encoder") + freeze_image_encoder = False + if _searcher2 != -1: + freeze_image_encoder = True + + _searcher3 = checkpoint_path.find("lora-finetuning") + has_lora = False + if _searcher3 != -1: + has_lora = True + experiment_folder = os.path.join("/", *checkpoint_path.split("/")[:-4]) run_slurm_scripts( model_type=checkpoint_path.split("/")[-3], checkpoint=checkpoint_path, - experiment_folder=experiment_folder + experiment_folder=experiment_folder, + has_lora=has_lora, + dry=args.dry, + freeze_image_encoder=freeze_image_encoder, ) if __name__ == "__main__": import argparse parser = argparse.ArgumentParser() - parser.add_argument("-i", "--input_path", type=str, default=DATA_DIR) + parser.add_argument("--dry", action="store_true") args = parser.parse_args() main(args) diff --git a/finetuning/specialists/resource-efficient/run_resource_efficient_finetuning.py b/finetuning/specialists/resource-efficient/run_resource_efficient_finetuning.py index 5e32607e8..f76304cea 100644 --- a/finetuning/specialists/resource-efficient/run_resource_efficient_finetuning.py +++ b/finetuning/specialists/resource-efficient/run_resource_efficient_finetuning.py @@ -1,5 +1,6 @@ import os import shutil +import itertools import subprocess from datetime import datetime @@ -7,13 +8,14 @@ def base_slurm_script(env_name, partition, cpu_mem, cpu_cores, gpu_name=None): assert partition in ["grete:shared", "gpu", "medium"] if gpu_name is not None: - assert gpu_name in ["gtx1080", "rtx5000", "v100", "V100"] + assert gpu_name in ["GTX1080", "RTX5000", "V100", "A100"] base_script = f"""#!/bin/bash #SBATCH -c {cpu_cores} #SBATCH --mem {cpu_mem} #SBATCH -p {partition} #SBATCH -t 2-00:00:00 +#SBATCH --job-name micro-sam-resource-efficient-finetuning """ if gpu_name is not None: base_script += f"#SBATCH -G {gpu_name}:1 \n" @@ -21,14 +23,14 @@ def base_slurm_script(env_name, partition, cpu_mem, cpu_cores, gpu_name=None): if partition.startswith("grete"): base_script += "#SBATCH -A gzz0001 \n" - base_script += "\n" + f"source activate {env_name}" + "\n" + base_script += "\n" + "source ~/.bashrc" + "\n" + "mamba activate {env_name}" + "\n" return base_script def write_batch_sript( env_name, partition, cpu_mem, cpu_cores, gpu_name, input_path, save_root, - model_type, n_objects, n_images, script_name, freeze, checkpoint_path + model_type, n_objects, n_images, script_name, freeze, lora, dry, ): assert model_type in ["vit_t", "vit_b", "vit_t_lm", "vit_b_lm"] @@ -45,28 +47,33 @@ def write_batch_sript( # add parameters to the python script python_script += f"-i {input_path} " # path to the covid-if data - python_script += f"-m {model_type[:5]} " # choice of vit + python_script += f"-m {model_type} " # choice of vit python_script += f"--n_objects {n_objects} " # number of objects per batch for finetuning python_script += f"--n_images {n_images} " # number of images we train for - if checkpoint_path is not None: - python_script += f"-c {checkpoint_path} " + # Whether to use LoRA-based finetuning + # NOTE: We use rank as 4 for LoRA. + if lora: + python_script += "--lora_rank 4 " if gpu_name is not None: resource_name = f"{gpu_name}" else: resource_name = f"cpu_{cpu_mem}-mem_{cpu_cores}-cores" + # Updating the path where the model checkpoints and logs will be saved. updated_save_root = os.path.join( save_root, resource_name, model_type, + "lora-finetuning" if lora else "full-finetuning", "freeze-None" if freeze is None else f"freeze-{freeze}", f"{n_images}-images" ) if save_root is not None: python_script += f"-s {updated_save_root} " # path to save model checkpoints and logs + # Whether to freeze a certain part of the SAM model. if freeze is not None: python_script += f"--freeze {freeze} " @@ -76,20 +83,18 @@ def write_batch_sript( with open(script_name, "w") as f: f.write(batch_script) - cmd = ["sbatch", script_name] - subprocess.run(cmd) + if not dry: + cmd = ["sbatch", script_name] + subprocess.run(cmd) def get_batch_script_names(tmp_folder): tmp_folder = os.path.expanduser(tmp_folder) os.makedirs(tmp_folder, exist_ok=True) - script_name = "micro-sam-finetuning" - dt = datetime.now().strftime("%Y_%m_%d_%H_%M_%S_%f") tmp_name = script_name + dt batch_script = os.path.join(tmp_folder, f"{tmp_name}.sh") - return batch_script @@ -98,7 +103,13 @@ def main(args): model_type = args.model_type all_n_images = [1, 2, 5, 10] - for n_images in all_n_images: + use_lora = [False, True] + + for (n_images, lora) in itertools.product(all_n_images, use_lora): + # We cannot use LoRA and freeze the image encoder at the same time. + if lora and args.freeze == "image_encoder": + continue + write_batch_sript( env_name="mobilesam" if model_type[:5] == "vit_t" else "sam", partition=args.partition, @@ -112,7 +123,8 @@ def main(args): n_images=n_images, script_name=get_batch_script_names(tmp_folder), freeze=args.freeze, - checkpoint_path=args.checkpoint + lora=lora, + dry=args.dry, ) @@ -129,12 +141,13 @@ def main(args): parser.add_argument("-m", "--model_type", type=str, required=True, help="Choice of image encoder in SAM") parser.add_argument("--n_objects", type=int, required=True, help="The number of objects (instances) per batch.") parser.add_argument("--freeze", type=str, default=None, help="Which parts of the model to freeze for finetuning.") - parser.add_argument("--checkpoint", type=str, default=None, help="Path to custom checkpoint.") parser.add_argument("--partition", type=str, required=True, help="Name of the partition for running the job.") parser.add_argument("--mem", type=str, required=True, help="Amount of cpu memory.") parser.add_argument("-c", "--cpu_cores", type=int, required=True, help="Number of cpu cores.") - parser.add_argument("-G", "--gpu_name", type=str, default=None, help="The GPI resources used for finetuning.") + parser.add_argument("-G", "--gpu_name", type=str, default=None, help="The GPU resources used for finetuning.") + + parser.add_argument("--dry", action="store_true", help="Whether to avoid submitting the configured scripts.") args = parser.parse_args() main(args) diff --git a/micro_sam/evaluation/inference.py b/micro_sam/evaluation/inference.py index 1905fc775..8340b4088 100644 --- a/micro_sam/evaluation/inference.py +++ b/micro_sam/evaluation/inference.py @@ -547,11 +547,12 @@ def run_amg( test_image_paths: List[Union[str, os.PathLike]], iou_thresh_values: Optional[List[float]] = None, stability_score_values: Optional[List[float]] = None, + lora_rank: Optional[int] = None, ) -> str: embedding_folder = os.path.join(experiment_folder, "embeddings") # where the precomputed embeddings are saved os.makedirs(embedding_folder, exist_ok=True) - predictor = util.get_sam_model(model_type=model_type, checkpoint_path=checkpoint) + predictor = util.get_sam_model(model_type=model_type, checkpoint_path=checkpoint, lora_rank=lora_rank) amg = AutomaticMaskGenerator(predictor) amg_prefix = "amg" @@ -588,11 +589,14 @@ def run_instance_segmentation_with_decoder( val_image_paths: List[Union[str, os.PathLike]], val_gt_paths: List[Union[str, os.PathLike]], test_image_paths: List[Union[str, os.PathLike]], + lora_rank: Optional[int] = None, ) -> str: embedding_folder = os.path.join(experiment_folder, "embeddings") # where the precomputed embeddings are saved os.makedirs(embedding_folder, exist_ok=True) - predictor, decoder = get_predictor_and_decoder(model_type=model_type, checkpoint_path=checkpoint) + predictor, decoder = get_predictor_and_decoder( + model_type=model_type, checkpoint_path=checkpoint, lora_rank=lora_rank, + ) segmenter = InstanceSegmentationWithDecoder(predictor, decoder) seg_prefix = "instance_segmentation_with_decoder" diff --git a/micro_sam/instance_segmentation.py b/micro_sam/instance_segmentation.py index d86c534ef..c87804289 100644 --- a/micro_sam/instance_segmentation.py +++ b/micro_sam/instance_segmentation.py @@ -800,6 +800,7 @@ def get_predictor_and_decoder( model_type: str, checkpoint_path: Union[str, os.PathLike], device: Optional[Union[str, torch.device]] = None, + lora_rank: Optional[int] = None, ) -> Tuple[SamPredictor, DecoderAdapter]: """Load the SAM model (predictor) and instance segmentation decoder. @@ -810,6 +811,7 @@ def get_predictor_and_decoder( model_type: The type of the image encoder used in the SAM model. checkpoint_path: Path to the checkpoint from which to load the data. device: The device. + lora_rank: The rank for low rank adaptation of the attention layers. Returns: The SAM predictor. @@ -817,8 +819,11 @@ def get_predictor_and_decoder( """ device = util.get_device(device) predictor, state = util.get_sam_model( - model_type=model_type, checkpoint_path=checkpoint_path, - device=device, return_state=True + model_type=model_type, + checkpoint_path=checkpoint_path, + device=device, + return_state=True, + lora_rank=lora_rank, ) if "decoder_state" not in state: raise ValueError(f"The checkpoint at {checkpoint_path} does not contain a decoder state") From 0fd30cedba73a323311566b1c7e44f0e171c164c Mon Sep 17 00:00:00 2001 From: Anwai Archit <52396323+anwai98@users.noreply.github.com> Date: Wed, 31 Jul 2024 17:58:42 +0200 Subject: [PATCH 27/53] Add results for best resource efficient finetuning timings --- .../specialists/resource-efficient/README.md | 23 +++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/finetuning/specialists/resource-efficient/README.md b/finetuning/specialists/resource-efficient/README.md index 482a51d84..3cc455fda 100644 --- a/finetuning/specialists/resource-efficient/README.md +++ b/finetuning/specialists/resource-efficient/README.md @@ -214,3 +214,26 @@ All jobs are tested on `medium` partition. | CPU (64G) | vit_b_lm (LoRA) | 2 | 1 | 0:19:46 | | CPU (64G) | vit_b_lm (LoRA) | 5 | 38 | 9:38:11 | | CPU (64G) | vit_b_lm (LoRA) | 10 | 15 | 5:42:34 | + + +### Plots for the Best Setting: +| Resource | Model | Finetuned Strategy | Best Epoch | Train Time | +|------------|------------|--------------------|------------|------------| +| CPU (32G) | Default | FFT | 24 | 5:41:31 | +| CPU (32G) | Default | LoRA | 13 | 2:57:08 | +| CPU (32G) | Generalist | FFT | 6 | 2:01:30 | +| CPU (32G) | Generalist | LoRA | 7 | 1:58:57 | +| CPU (64G) | Default | FFT | 15 | 3:51:02 | +| CPU (64G) | Default | LoRA | 19 | 5:20:02 | +| CPU (64G) | Generalist | FFT | 5 | 1:28:26 | +| CPU (64G) | Generalist | LoRA | 15 | 5:42:34 | +| GTX1080 | Default | MD, PE | 40 | 1:18:05 | +| GTX1080 | Generalist | MD, PE | 13 | 0:15:05 | +| RTX5000 | Default | FFT | 43 | 0:46:55 | +| RTX5000 | Default | LoRA | 16 | 0:17:37 | +| RTX5000 | Generalist | FFT | 3 | 0:04:22 | +| RTX5000 | Generalist | LoRA | 32 | 0:34:04 | +| V100 | Default | FFT | 20 | 0:26:24 | +| V100 | Default | LoRA | 42 | 0:51:10 | +| V100 | Generalist | FFT | 2 | 0:03:48 | +| V100 | Generalist | LoRA | 5 | 0:07:11 | From d0a31ebd6738d3a81d523b74e628a83396c027c7 Mon Sep 17 00:00:00 2001 From: Anwai Archit Date: Thu, 1 Aug 2024 10:16:43 +0200 Subject: [PATCH 28/53] Minor fix to resource efficient finetuning plot legends --- .../resource-efficient/plot_experiments.py | 26 ++++++++++++------- 1 file changed, 16 insertions(+), 10 deletions(-) diff --git a/finetuning/specialists/resource-efficient/plot_experiments.py b/finetuning/specialists/resource-efficient/plot_experiments.py index 11df33f98..8691d82f8 100644 --- a/finetuning/specialists/resource-efficient/plot_experiments.py +++ b/finetuning/specialists/resource-efficient/plot_experiments.py @@ -215,18 +215,14 @@ def _change_opacity_markers(lineplot): ax.get_legend().remove() _colors = list(PALETTE.values()) - custom_handles = [ - mlines.Line2D([], [], color=_colors[0], markersize=15, marker='o', linestyle='-', linewidth=5), - mlines.Line2D([], [], color=_colors[0], markersize=15, marker='o', linestyle=':', linewidth=5), - mlines.Line2D([], [], color=_colors[1], markersize=15, marker='o', linestyle='-', linewidth=5), - mlines.Line2D([], [], color=_colors[1], markersize=15, marker='o', linestyle=':', linewidth=5), - mlines.Line2D([], [], color=_colors[2], markersize=15, marker='o', linestyle='-', linewidth=5), - mlines.Line2D([], [], color=_colors[2], markersize=15, marker='o', linestyle=':', linewidth=5), - mlines.Line2D([], [], color=_colors[3], markersize=15, marker='o', linestyle='-', linewidth=5), - mlines.Line2D([], [], color=_colors[3], markersize=15, marker='o', linestyle=':', linewidth=5), - ] if resource_name == "GTX1080": + custom_handles = [ + mlines.Line2D([], [], color=_colors[0], markersize=15, marker='o', linestyle='-', linewidth=5), + mlines.Line2D([], [], color=_colors[1], markersize=15, marker='o', linestyle='-', linewidth=5), + mlines.Line2D([], [], color=_colors[2], markersize=15, marker='o', linestyle='-', linewidth=5), + mlines.Line2D([], [], color=_colors[3], markersize=15, marker='o', linestyle='-', linewidth=5), + ] fig.legend( handles=custom_handles, labels=['AIS (MD, PE)', 'AMG (MD, PE)', 'Point (MD, PE)', 'Box (MD, PE)'], @@ -234,6 +230,16 @@ def _change_opacity_markers(lineplot): ) bottom = 0.12 else: + custom_handles = [ + mlines.Line2D([], [], color=_colors[0], markersize=15, marker='o', linestyle='-', linewidth=5), + mlines.Line2D([], [], color=_colors[0], markersize=15, marker='o', linestyle=':', linewidth=5), + mlines.Line2D([], [], color=_colors[1], markersize=15, marker='o', linestyle='-', linewidth=5), + mlines.Line2D([], [], color=_colors[1], markersize=15, marker='o', linestyle=':', linewidth=5), + mlines.Line2D([], [], color=_colors[2], markersize=15, marker='o', linestyle='-', linewidth=5), + mlines.Line2D([], [], color=_colors[2], markersize=15, marker='o', linestyle=':', linewidth=5), + mlines.Line2D([], [], color=_colors[3], markersize=15, marker='o', linestyle='-', linewidth=5), + mlines.Line2D([], [], color=_colors[3], markersize=15, marker='o', linestyle=':', linewidth=5), + ] fig.legend( handles=custom_handles, labels=[ From ff0539bfaf2be37eeeb49449c4a786bb12be0691 Mon Sep 17 00:00:00 2001 From: Anwai Archit <52396323+anwai98@users.noreply.github.com> Date: Thu, 5 Sep 2024 20:29:34 +0530 Subject: [PATCH 29/53] Add LoRA support for livecell finetuning scripts (#681) --- finetuning/livecell/lora/README.md | 15 --- finetuning/livecell/lora/train_livecell.py | 124 --------------------- finetuning/livecell_finetuning.py | 5 + 3 files changed, 5 insertions(+), 139 deletions(-) delete mode 100644 finetuning/livecell/lora/README.md delete mode 100644 finetuning/livecell/lora/train_livecell.py diff --git a/finetuning/livecell/lora/README.md b/finetuning/livecell/lora/README.md deleted file mode 100644 index 9cc50de62..000000000 --- a/finetuning/livecell/lora/README.md +++ /dev/null @@ -1,15 +0,0 @@ -## Low Rank Adaptation Methods on Segment Anything for LIVECell - -Insights: -- There's no real memory advantage actually unless it's truly scaled up. For instance: - - `vit_b`: - - SAM: 93M (takes ~50GB) - - SAM-LoRA: 4.4M (takes ~61GB) - - `vit_l`: - - SAM: 312M (takes ~63GB) - - SAM-LoRA: 4.4M (takes ~61GB) - - `vit_h`: - - SAM: 641M (takes ~73GB) - - SAM-LoRA: 4.7M (takes ~67GB) - -- Question: Would quantization lead to better results? (e.g. QLoRA) or parallel adaptation? (e.g. DoRA) diff --git a/finetuning/livecell/lora/train_livecell.py b/finetuning/livecell/lora/train_livecell.py deleted file mode 100644 index 31977b217..000000000 --- a/finetuning/livecell/lora/train_livecell.py +++ /dev/null @@ -1,124 +0,0 @@ -import os -import argparse - -import torch - -from torch_em.data.datasets import get_livecell_loader -from torch_em.transform.label import PerObjectDistanceTransform - -import micro_sam.training as sam_training -from micro_sam.util import export_custom_sam_model - - -def get_dataloaders(patch_shape, data_path, cell_type=None): - """This returns the livecell data loaders implemented in torch_em: - https://github.com/constantinpape/torch-em/blob/main/torch_em/data/datasets/livecell.py - It will automatically download the livecell data. - - Note: to replace this with another data loader you need to return a torch data loader - that retuns `x, y` tensors, where `x` is the image data and `y` are the labels. - The labels have to be in a label mask instance segmentation format. - I.e. a tensor of the same spatial shape as `x`, with each object mask having its own ID. - Important: the ID 0 is reseved for background, and the IDs must be consecutive - """ - label_transform = PerObjectDistanceTransform( - distances=True, boundary_distances=True, directed_distances=False, foreground=True, instances=True, min_size=25 - ) - raw_transform = sam_training.identity # the current workflow avoids rescaling the inputs to [-1, 1] - train_loader = get_livecell_loader( - path=data_path, patch_shape=patch_shape, split="train", batch_size=2, num_workers=16, - cell_types=cell_type, download=True, shuffle=True, label_transform=label_transform, - raw_transform=raw_transform, label_dtype=torch.float32, - ) - val_loader = get_livecell_loader( - path=data_path, patch_shape=patch_shape, split="val", batch_size=4, num_workers=16, - cell_types=cell_type, download=True, shuffle=True, label_transform=label_transform, - raw_transform=raw_transform, label_dtype=torch.float32, - ) - - return train_loader, val_loader - - -def finetune_livecell(args): - """Code for finetuning SAM (using LoRA) on LIVECell - """ - # override this (below) if you have some more complex set-up and need to specify the exact gpu - device = "cuda" if torch.cuda.is_available() else "cpu" - - # training settings: - model_type = args.model_type - checkpoint_path = None # override this to start training from a custom checkpoint - patch_shape = (520, 704) # the patch shape for training - n_objects_per_batch = args.n_objects # this is the number of objects per batch that will be sampled - freeze_parts = args.freeze # override this to freeze different parts of the model - lora_rank = 4 # the rank for low rank adaptation - checkpoint_name = f"{args.model_type}/livecell_sam" - - # all the stuff we need for training - train_loader, val_loader = get_dataloaders(patch_shape=patch_shape, data_path=args.input_path) - scheduler_kwargs = {"mode": "min", "factor": 0.9, "patience": 10, "verbose": True} - optimizer_class = torch.optim.AdamW - - # Run training. - sam_training.train_sam( - name=checkpoint_name, - model_type=model_type, - train_loader=train_loader, - val_loader=val_loader, - early_stopping=None, - n_objects_per_batch=n_objects_per_batch, - checkpoint_path=checkpoint_path, - freeze=freeze_parts, - device=device, - lr=1e-5, - n_iterations=args.iterations, - save_root=args.save_root, - scheduler_kwargs=scheduler_kwargs, - optimizer_class=optimizer_class, - lora_rank=lora_rank, - ) - - if args.export_path is not None: - checkpoint_path = os.path.join( - "" if args.save_root is None else args.save_root, "checkpoints", checkpoint_name, "best.pt" - ) - export_custom_sam_model( - checkpoint_path=checkpoint_path, model_type=model_type, save_path=args.export_path, - ) - - -def main(): - parser = argparse.ArgumentParser(description="Finetune Segment Anything for the LIVECell dataset.") - parser.add_argument( - "--input_path", "-i", default="/scratch/projects/nim00007/sam/data/livecell/", - help="The filepath to the LIVECell data. If the data does not exist yet it will be downloaded." - ) - parser.add_argument( - "--model_type", "-m", default="vit_b", - help="The model type to use for fine-tuning. Either vit_h, vit_b or vit_l." - ) - parser.add_argument( - "--save_root", "-s", default=None, - help="Where to save the checkpoint and logs. By default they will be saved where this script is run." - ) - parser.add_argument( - "--iterations", type=int, default=int(1e4), - help="For how many iterations should the model be trained? By default 100k." - ) - parser.add_argument( - "--export_path", "-e", - help="Where to export the finetuned model to. The exported model can be used in the annotation tools." - ) - parser.add_argument( - "--freeze", type=str, nargs="+", default=None, - help="Which parts of the model to freeze for finetuning." - ) - parser.add_argument( - "--n_objects", type=int, default=25, help="The number of instances (objects) per batch used for finetuning." - ) - args = parser.parse_args() - finetune_livecell(args) - - -if __name__ == "__main__": - main() diff --git a/finetuning/livecell_finetuning.py b/finetuning/livecell_finetuning.py index 6b63a5a06..5899a4ab4 100644 --- a/finetuning/livecell_finetuning.py +++ b/finetuning/livecell_finetuning.py @@ -72,6 +72,7 @@ def finetune_livecell(args): save_root=args.save_root, scheduler_kwargs=scheduler_kwargs, save_every_kth_epoch=args.save_every_kth_epoch, + lora_rank=args.lora_rank, ) if args.export_path is not None: @@ -116,9 +117,13 @@ def main(): parser.add_argument( "--n_objects", type=int, default=25, help="The number of instances (objects) per batch used for finetuning." ) + parser.add_argument( + "--lora_rank", type=int, default=None, help="The rank for low rank adaptation of the attention layers." + ) args = parser.parse_args() finetune_livecell(args) if __name__ == "__main__": main() + From a9fbe4226d8084f9942ae135566d340d10904ee8 Mon Sep 17 00:00:00 2001 From: Anwai Archit <52396323+anwai98@users.noreply.github.com> Date: Tue, 17 Sep 2024 14:58:42 +0200 Subject: [PATCH 30/53] Add FaCT Finetuning for SAM (#682) Implenet FaCT-based finetuning --- finetuning/evaluation/evaluate_amg.py | 20 +++---- .../evaluate_instance_segmentation.py | 14 +++-- finetuning/evaluation/iterative_prompting.py | 12 +++- .../evaluation/precompute_embeddings.py | 12 +++- .../evaluation/submit_all_evaluation.py | 2 +- finetuning/evaluation/util.py | 16 +---- .../resource-efficient/covid_if_finetuning.py | 2 +- micro_sam/evaluation/inference.py | 8 +-- micro_sam/instance_segmentation.py | 4 +- micro_sam/models/peft_sam.py | 59 ++++++++++++++++++- micro_sam/models/sam_3d_wrapper.py | 10 +++- micro_sam/models/simple_sam_3d_wrapper.py | 6 +- micro_sam/training/training.py | 4 +- micro_sam/training/util.py | 17 ++++-- micro_sam/util.py | 15 +++-- test/test_models/test_peft_sam.py | 21 ++++++- 16 files changed, 157 insertions(+), 65 deletions(-) diff --git a/finetuning/evaluation/evaluate_amg.py b/finetuning/evaluation/evaluate_amg.py index 8f8f132d4..f171b9af1 100644 --- a/finetuning/evaluation/evaluate_amg.py +++ b/finetuning/evaluation/evaluate_amg.py @@ -3,11 +3,13 @@ from micro_sam.evaluation.evaluation import run_evaluation from micro_sam.evaluation.inference import run_amg -from util import get_paths # comment this and create a custom function with the same name to run amg on your data -from util import get_pred_paths, get_default_arguments, VANILLA_MODELS +from util import ( + get_paths, # comment this line out and create a custom function with the same name to run amg on your data + get_pred_paths, get_default_arguments +) -def run_amg_inference(dataset_name, model_type, checkpoint, experiment_folder, lora_rank): +def run_amg_inference(dataset_name, model_type, checkpoint, experiment_folder, peft_kwargs): val_image_paths, val_gt_paths = get_paths(dataset_name, split="val") test_image_paths, _ = get_paths(dataset_name, split="test") prediction_folder = run_amg( @@ -17,7 +19,7 @@ def run_amg_inference(dataset_name, model_type, checkpoint, experiment_folder, l val_image_paths=val_image_paths, val_gt_paths=val_gt_paths, test_image_paths=test_image_paths, - lora_rank=lora_rank, + peft_kwargs=peft_kwargs, ) return prediction_folder @@ -33,12 +35,10 @@ def eval_amg(dataset_name, prediction_folder, experiment_folder): def main(): args = get_default_arguments() - if args.checkpoint is None: - ckpt = VANILLA_MODELS[args.model] - else: - ckpt = args.checkpoint - - prediction_folder = run_amg_inference(args.dataset, args.model, ckpt, args.experiment_folder, args.lora_rank) + peft_kwargs = {"rank": args.peft_rank, "module": args.peft_module} + prediction_folder = run_amg_inference( + args.dataset, args.model, args.checkpoint, args.experiment_folder, peft_kwargs + ) eval_amg(args.dataset, prediction_folder, args.experiment_folder) diff --git a/finetuning/evaluation/evaluate_instance_segmentation.py b/finetuning/evaluation/evaluate_instance_segmentation.py index c41e9fb47..bd311d575 100644 --- a/finetuning/evaluation/evaluate_instance_segmentation.py +++ b/finetuning/evaluation/evaluate_instance_segmentation.py @@ -3,12 +3,14 @@ from micro_sam.evaluation.evaluation import run_evaluation from micro_sam.evaluation.inference import run_instance_segmentation_with_decoder -from util import get_paths # comment this and create a custom function with the same name to run ais on your data -from util import get_pred_paths, get_default_arguments +from util import ( + get_paths, # comment this line out and create a custom function with the same name to run ais on your data + get_pred_paths, get_default_arguments +) def run_instance_segmentation_with_decoder_inference( - dataset_name, model_type, checkpoint, experiment_folder, lora_rank + dataset_name, model_type, checkpoint, experiment_folder, peft_kwargs, ): val_image_paths, val_gt_paths = get_paths(dataset_name, split="val") test_image_paths, _ = get_paths(dataset_name, split="test") @@ -19,7 +21,7 @@ def run_instance_segmentation_with_decoder_inference( val_image_paths=val_image_paths, val_gt_paths=val_gt_paths, test_image_paths=test_image_paths, - lora_rank=lora_rank, + peft_kwargs=peft_kwargs, ) return prediction_folder @@ -35,9 +37,9 @@ def eval_instance_segmentation_with_decoder(dataset_name, prediction_folder, exp def main(): args = get_default_arguments() - + peft_kwargs = {"rank": args.peft_rank, "module": args.peft_module} prediction_folder = run_instance_segmentation_with_decoder_inference( - args.dataset, args.model, args.checkpoint, args.experiment_folder, args.lora_rank, + args.dataset, args.model, args.checkpoint, args.experiment_folder, peft_kwargs, ) eval_instance_segmentation_with_decoder(args.dataset, prediction_folder, args.experiment_folder) diff --git a/finetuning/evaluation/iterative_prompting.py b/finetuning/evaluation/iterative_prompting.py index eae3f8450..b261f4d94 100644 --- a/finetuning/evaluation/iterative_prompting.py +++ b/finetuning/evaluation/iterative_prompting.py @@ -1,10 +1,13 @@ import os +from micro_sam.util import get_sam_model from micro_sam.evaluation import inference from micro_sam.evaluation.evaluation import run_evaluation_for_iterative_prompting -from util import get_paths # comment this and create a custom function with the same name to run int. seg. on your data -from util import get_model, get_default_arguments +from util import ( + get_paths, # comment this line out and create a custom function with the same name to run int. seg. on your data + get_default_arguments +) def _run_iterative_prompting(dataset_name, exp_folder, predictor, start_with_box_prompt, use_masks): @@ -42,7 +45,10 @@ def main(): start_with_box_prompt = args.box # overwrite to start first iters' prompt with box instead of single point # get the predictor to perform inference - predictor = get_model(model_type=args.model, ckpt=args.checkpoint, lora_rank=args.lora_rank) + peft_kwargs = {"rank": args.peft_rank, "module": args.peft_module} + predictor = get_sam_model( + model_type=args.model, checkpoint_path=args.checkpoint, peft_kwargs=peft_kwargs, + ) prediction_root = _run_iterative_prompting( args.dataset, args.experiment_folder, predictor, start_with_box_prompt, args.use_masks diff --git a/finetuning/evaluation/precompute_embeddings.py b/finetuning/evaluation/precompute_embeddings.py index 605627feb..404389064 100644 --- a/finetuning/evaluation/precompute_embeddings.py +++ b/finetuning/evaluation/precompute_embeddings.py @@ -1,15 +1,21 @@ import os +from micro_sam.util import get_sam_model from micro_sam.evaluation import precompute_all_embeddings -from util import get_paths # comment this and create a custom function with the same name to execute on your data -from util import get_model, get_default_arguments +from util import ( + get_paths, # comment this and create a custom function with the same name to execute on your data + get_default_arguments +) def main(): args = get_default_arguments() - predictor = get_model(model_type=args.model, ckpt=args.checkpoint, lora_rank=args.lora_rank) + peft_kwargs = {"rank": args.peft_rank, "module": args.peft_module} + predictor = get_sam_model( + model_type=args.model, checkpoint_path=args.checkpoint, peft_kwargs=peft_kwargs, + ) embedding_dir = os.path.join(args.experiment_folder, "embeddings") os.makedirs(embedding_dir, exist_ok=True) diff --git a/finetuning/evaluation/submit_all_evaluation.py b/finetuning/evaluation/submit_all_evaluation.py index b64549de5..465d96dfe 100644 --- a/finetuning/evaluation/submit_all_evaluation.py +++ b/finetuning/evaluation/submit_all_evaluation.py @@ -221,4 +221,4 @@ def main(args): parser.add_argument("-s", "--specific_experiment", type=str, default=None) args = parser.parse_args() - main(args) + main(args) \ No newline at end of file diff --git a/finetuning/evaluation/util.py b/finetuning/evaluation/util.py index 9780cc70b..d55009ee3 100644 --- a/finetuning/evaluation/util.py +++ b/finetuning/evaluation/util.py @@ -5,7 +5,6 @@ from torch_em.data import datasets -from micro_sam.util import get_sam_model from micro_sam.evaluation.livecell import _get_livecell_paths @@ -80,16 +79,6 @@ def get_dataset_paths(dataset_name, split_choice): return raw_dir, labels_dir -def get_model(model_type, ckpt, lora_rank): - if ckpt is None: - ckpt = VANILLA_MODELS[model_type] - - predictor = get_sam_model( - model_type=model_type, checkpoint_path=ckpt, lora_rank=lora_rank, - ) - return predictor - - def get_paths(dataset_name, split): assert dataset_name in DATASETS, dataset_name @@ -222,14 +211,15 @@ def get_default_arguments(): parser.add_argument( "-m", "--model", type=str, required=True, help="Provide the model type to initialize the predictor" ) - parser.add_argument("-c", "--checkpoint", type=none_or_str, required=True, default=None) + parser.add_argument("-c", "--checkpoint", type=none_or_str, default=None) parser.add_argument("-e", "--experiment_folder", type=str, required=True) parser.add_argument("-d", "--dataset", type=str, default=None) parser.add_argument("--box", action="store_true", help="If passed, starts with first prompt as box") parser.add_argument( "--use_masks", action="store_true", help="To use logits masks for iterative prompting." ) - parser.add_argument("--lora_rank", default=None, type=int, help="The rank for low rank adaptation method.") + parser.add_argument("--peft_rank", default=None, type=int, help="The rank for peft method.") + parser.add_argument("--peft_module", default=None, type=str, help="The module for peft method. (e.g. LoRA or FacT)") args = parser.parse_args() return args diff --git a/finetuning/specialists/resource-efficient/covid_if_finetuning.py b/finetuning/specialists/resource-efficient/covid_if_finetuning.py index 2107e7f11..632f8ba2a 100644 --- a/finetuning/specialists/resource-efficient/covid_if_finetuning.py +++ b/finetuning/specialists/resource-efficient/covid_if_finetuning.py @@ -158,4 +158,4 @@ def main(): if __name__ == "__main__": import warnings warnings.filterwarnings("ignore") - main() + main() \ No newline at end of file diff --git a/micro_sam/evaluation/inference.py b/micro_sam/evaluation/inference.py index 8340b4088..e1736fa5d 100644 --- a/micro_sam/evaluation/inference.py +++ b/micro_sam/evaluation/inference.py @@ -547,12 +547,12 @@ def run_amg( test_image_paths: List[Union[str, os.PathLike]], iou_thresh_values: Optional[List[float]] = None, stability_score_values: Optional[List[float]] = None, - lora_rank: Optional[int] = None, + peft_kwargs: Optional[Dict] = None, ) -> str: embedding_folder = os.path.join(experiment_folder, "embeddings") # where the precomputed embeddings are saved os.makedirs(embedding_folder, exist_ok=True) - predictor = util.get_sam_model(model_type=model_type, checkpoint_path=checkpoint, lora_rank=lora_rank) + predictor = util.get_sam_model(model_type=model_type, checkpoint_path=checkpoint, peft_kwargs=peft_kwargs) amg = AutomaticMaskGenerator(predictor) amg_prefix = "amg" @@ -589,13 +589,13 @@ def run_instance_segmentation_with_decoder( val_image_paths: List[Union[str, os.PathLike]], val_gt_paths: List[Union[str, os.PathLike]], test_image_paths: List[Union[str, os.PathLike]], - lora_rank: Optional[int] = None, + peft_kwargs: Optional[Dict] = None, ) -> str: embedding_folder = os.path.join(experiment_folder, "embeddings") # where the precomputed embeddings are saved os.makedirs(embedding_folder, exist_ok=True) predictor, decoder = get_predictor_and_decoder( - model_type=model_type, checkpoint_path=checkpoint, lora_rank=lora_rank, + model_type=model_type, checkpoint_path=checkpoint, peft_kwargs=peft_kwargs, ) segmenter = InstanceSegmentationWithDecoder(predictor, decoder) seg_prefix = "instance_segmentation_with_decoder" diff --git a/micro_sam/instance_segmentation.py b/micro_sam/instance_segmentation.py index c87804289..80654e678 100644 --- a/micro_sam/instance_segmentation.py +++ b/micro_sam/instance_segmentation.py @@ -800,7 +800,7 @@ def get_predictor_and_decoder( model_type: str, checkpoint_path: Union[str, os.PathLike], device: Optional[Union[str, torch.device]] = None, - lora_rank: Optional[int] = None, + peft_kwargs: Optional[Dict] = None, ) -> Tuple[SamPredictor, DecoderAdapter]: """Load the SAM model (predictor) and instance segmentation decoder. @@ -823,7 +823,7 @@ def get_predictor_and_decoder( checkpoint_path=checkpoint_path, device=device, return_state=True, - lora_rank=lora_rank, + peft_kwargs=peft_kwargs, ) if "decoder_state" not in state: raise ValueError(f"The checkpoint at {checkpoint_path} does not contain a decoder state") diff --git a/micro_sam/models/peft_sam.py b/micro_sam/models/peft_sam.py index 2bdeed702..59167a1df 100644 --- a/micro_sam/models/peft_sam.py +++ b/micro_sam/models/peft_sam.py @@ -1,5 +1,5 @@ import math -from typing import List, Union +from typing import List, Union, Optional import torch.nn as nn @@ -56,11 +56,64 @@ def forward(self, x): return qkv +class FacTSurgery(nn.Module): + """Operates on the attention layers for performing factorized attention. + + (Inspired from: https://github.com/cchen-cc/MA-SAM/blob/main/MA-SAM/sam_fact_tt_image_encoder.py) + + Args: + rank: The rank of the decomposition matrices for updating weights in each attention layer. + block: The chosen attention blocks for implementing fact. + """ + + def __init__( + self, + rank: int, + block: nn.Module, + dropout: Optional[float] = None, + ): + super().__init__() + self.qkv_proj = block.attn.qkv + self.dim = self.qkv_proj.in_features + + self.q_FacTs = nn.Linear(rank, rank, bias=False) + self.v_FacTs = nn.Linear(rank, rank, bias=False) + + self.dropout = dropout + if self.dropout is not None: + # NOTE : Dropout is not included in the original implementation + self.dp_q = nn.Dropout(self.dropout) + self.dp_v = nn.Dropout(self.dropout) + + self.FacTu = nn.Linear(self.dim, rank, bias=False) + self.FacTv = nn.Linear(rank, self.dim, bias=False) + + block.attn.qkv = self + + def forward(self, x): + qkv = self.qkv_proj(x) # B, N, N, 3 * org_C + + new_q = self.q_FacTs(self.FacTu(x)) + new_v = self.v_FacTs(self.FacTu(x)) + + if self.dropout is not None: + new_q = self.dp_q(new_q) + new_v = self.dp_v(new_v) + + new_q = self.FacTv(new_q) + new_v = self.FacTv(new_v) + + # NOTE : Scaling Factor was set to 1 as it can be tuned via the learning rate + # Does it make sense to include it, in order to have similar learning rate as the original model? + qkv[:, :, :, : self.dim] += new_q + qkv[:, :, :, -self.dim:] += new_v + + return qkv + + class PEFT_Sam(nn.Module): """Wraps the Segment Anything model's image encoder to different parameter efficient finetuning methods. - Inspired by https://github.com/JamesQFreeman/Sam_LoRA/ - Args: model: The Segment Anything model. rank: The rank for low-rank adaptation. diff --git a/micro_sam/models/sam_3d_wrapper.py b/micro_sam/models/sam_3d_wrapper.py index 3e0b7573e..c6a76d963 100644 --- a/micro_sam/models/sam_3d_wrapper.py +++ b/micro_sam/models/sam_3d_wrapper.py @@ -18,8 +18,10 @@ def get_sam_3d_model( model_type="vit_b", checkpoint_path=None, ): - # Make sure not to freeze the encoder when using LoRA. - freeze_encoder_ = freeze_encoder if lora_rank is None else False + peft_kwargs = {} + if lora_rank is not None: + peft_kwargs["rank"] = lora_rank + _, sam = get_sam_model( model_type=model_type, device=device, @@ -28,9 +30,11 @@ def get_sam_3d_model( flexible_load_checkpoint=True, num_multimask_outputs=n_classes, image_size=image_size, - lora_rank=lora_rank, + peft_kwargs=peft_kwargs, ) + # Make sure not to freeze the encoder when using LoRA. + freeze_encoder_ = freeze_encoder if lora_rank is None else False sam_3d = Sam3DWrapper(sam, freeze_encoder=freeze_encoder_) sam_3d.to(device) return sam_3d diff --git a/micro_sam/models/simple_sam_3d_wrapper.py b/micro_sam/models/simple_sam_3d_wrapper.py index 6f67caa47..47d2d60bf 100644 --- a/micro_sam/models/simple_sam_3d_wrapper.py +++ b/micro_sam/models/simple_sam_3d_wrapper.py @@ -16,6 +16,10 @@ def get_simple_sam_3d_model( model_type="vit_b", checkpoint_path=None, ): + peft_kwargs = {} + if lora_rank is not None: + peft_kwargs["rank"] = lora_rank + _, sam = get_sam_model( model_type=model_type, device=device, @@ -23,7 +27,7 @@ def get_simple_sam_3d_model( return_sam=True, image_size=image_size, flexible_load_checkpoint=True, - lora_rank=lora_rank, + peft_kwargs=peft_kwargs, ) # Make sure not to freeze the encoder when using LoRA. diff --git a/micro_sam/training/training.py b/micro_sam/training/training.py index bb31fa380..4e0b72a97 100644 --- a/micro_sam/training/training.py +++ b/micro_sam/training/training.py @@ -154,6 +154,7 @@ def train_sam( save_every_kth_epoch: Optional[int] = None, pbar_signals: Optional[QObject] = None, optimizer_class: Optional[Optimizer] = torch.optim.AdamW, + peft_kwargs: Optional[Dict] = None, **model_kwargs, ) -> None: """Run training for a SAM model. @@ -196,7 +197,6 @@ def train_sam( _check_loader(val_loader, with_segmentation_decoder) device = get_device(device) - # Get the trainable segment anything model. model, state = get_trainable_sam_model( model_type=model_type, @@ -204,9 +204,9 @@ def train_sam( freeze=freeze, checkpoint_path=checkpoint_path, return_state=True, + peft_kwargs=peft_kwargs, **model_kwargs ) - # This class creates all the training data for a batch (inputs, prompts and labels). convert_inputs = ConvertToSamInputs(transform=model.transform, box_distortion_factor=0.025) diff --git a/micro_sam/training/util.py b/micro_sam/training/util.py index 759c905e8..fb4834c03 100644 --- a/micro_sam/training/util.py +++ b/micro_sam/training/util.py @@ -44,8 +44,7 @@ def get_trainable_sam_model( checkpoint_path: Optional[Union[str, os.PathLike]] = None, freeze: Optional[List[str]] = None, return_state: bool = False, - lora_rank: Optional[int] = None, - lora_kwargs: Optional[Dict] = None, + peft_kwargs: Optional[Dict] = None, flexible_load_checkpoint: bool = False, **model_kwargs ) -> TrainableSAM: @@ -84,8 +83,16 @@ def get_trainable_sam_model( # NOTE: This is done exclusive to "get_sam_model" here to use PEFT's layer-specific initialization on top. # Whether to use Parameter Efficient Finetuning methods to wrap around Segment Anything. # Overwrites the SAM model by freezing the backbone and allow low rank adaption to attention layers. - if lora_rank is not None: - sam = custom_models.peft_sam.PEFT_Sam(sam, rank=lora_rank, **({} if lora_kwargs is None else lora_kwargs)).sam + if peft_kwargs and isinstance(peft_kwargs, dict): + if model_type[:5] == "vit_t": + raise ValueError("'micro-sam' does not support parameter efficient finetuning for 'mobile-sam'.") + + peft_module = peft_kwargs.get("peft_module") + if peft_module is not None: + from micro_sam.models.peft_sam import LoRASurgery, FacTSurgery + assert peft_module in [LoRASurgery, FacTSurgery], "Invalid PEFT module." + + sam = custom_models.peft_sam.PEFT_Sam(sam, **peft_kwargs).sam # freeze components of the model if freeze was passed # ideally we would want to add components in such a way that: @@ -100,7 +107,7 @@ def get_trainable_sam_model( # we would want to "freeze" all the components in the model if passed a list of parts for l_item in freeze: # in case LoRA is switched on, we cannot freeze the image encoder - if (lora_rank is not None) and (l_item == "image_encoder"): + if (peft_kwargs['rank'] is not None) and (l_item == "image_encoder"): raise ValueError("You cannot use LoRA & freeze the image encoder at the same time.") if name.startswith(f"{l_item}"): diff --git a/micro_sam/util.py b/micro_sam/util.py index 45550a493..af1d1dc68 100644 --- a/micro_sam/util.py +++ b/micro_sam/util.py @@ -273,8 +273,7 @@ def get_sam_model( checkpoint_path: Optional[Union[str, os.PathLike]] = None, return_sam: bool = False, return_state: bool = False, - lora_rank: Optional[int] = None, - lora_kwargs: Optional[Dict] = None, + peft_kwargs: Optional[Dict] = None, flexible_load_checkpoint: bool = False, **model_kwargs, ) -> SamPredictor: @@ -371,10 +370,16 @@ def get_sam_model( # Whether to use Parameter Efficient Finetuning methods to wrap around Segment Anything. # Overwrites the SAM model by freezing the backbone and allow low rank adaption to attention layers. - if lora_rank is not None: + if peft_kwargs and isinstance(peft_kwargs, dict): if abbreviated_model_type == "vit_t": - raise ValueError("Parameter efficient finetuning is not supported for 'mobile-sam'.") - sam = custom_models.peft_sam.PEFT_Sam(sam, rank=lora_rank, **({} if lora_kwargs is None else lora_kwargs)).sam + raise ValueError("'micro-sam' does not support parameter efficient finetuning for 'mobile-sam'.") + + peft_module = peft_kwargs.get("peft_module") + if peft_module is not None: + from .models.peft_sam import LoRASurgery, FacTSurgery + assert peft_module in [LoRASurgery, FacTSurgery], "Invalid PEFT module." + + sam = custom_models.peft_sam.PEFT_Sam(sam, **peft_kwargs).sam # In case the model checkpoints have some issues when it is initialized with different parameters than default. if flexible_load_checkpoint: diff --git a/test/test_models/test_peft_sam.py b/test/test_models/test_peft_sam.py index 1af3ef2c5..509a67650 100644 --- a/test/test_models/test_peft_sam.py +++ b/test/test_models/test_peft_sam.py @@ -1,17 +1,32 @@ import unittest import torch + import micro_sam.util as util class TestPEFTSam(unittest.TestCase): model_type = "vit_b" - def test_peft_sam(self): - from micro_sam.models.peft_sam import PEFT_Sam + def test_lora_sam(self): + from micro_sam.models.peft_sam import PEFT_Sam, LoRASurgery + + _, sam = util.get_sam_model(model_type=self.model_type, return_sam=True, device="cpu") + peft_sam = PEFT_Sam(sam, rank=2, peft_module=LoRASurgery) + + shape = (3, 1024, 1024) + expected_shape = (1, 3, 1024, 1024) + with torch.no_grad(): + batched_input = [{"image": torch.rand(*shape), "original_size": shape[1:]}] + output = peft_sam(batched_input, multimask_output=True) + masks = output[0]["masks"] + self.assertEqual(masks.shape, expected_shape) + + def test_fact_sam(self): + from micro_sam.models.peft_sam import PEFT_Sam, FacTSurgery _, sam = util.get_sam_model(model_type=self.model_type, return_sam=True, device="cpu") - peft_sam = PEFT_Sam(sam, rank=2) + peft_sam = PEFT_Sam(sam, rank=2, peft_module=FacTSurgery) shape = (3, 1024, 1024) expected_shape = (1, 3, 1024, 1024) From 5b42e8b4b1d22826fb384fc7f79250ce7ad42351 Mon Sep 17 00:00:00 2001 From: Anwai Archit <52396323+anwai98@users.noreply.github.com> Date: Fri, 20 Sep 2024 17:21:16 +0200 Subject: [PATCH 31/53] Minor update to peft kwargs in SAM wrappers (#686) * Minor update to sam 3d wrapper * Refactor peft kwargs for simplesam 3d wrapper --- micro_sam/models/sam_3d_wrapper.py | 8 +++++--- micro_sam/models/simple_sam_3d_wrapper.py | 8 +++++--- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/micro_sam/models/sam_3d_wrapper.py b/micro_sam/models/sam_3d_wrapper.py index c6a76d963..3b699a524 100644 --- a/micro_sam/models/sam_3d_wrapper.py +++ b/micro_sam/models/sam_3d_wrapper.py @@ -7,6 +7,7 @@ from segment_anything.modeling.image_encoder import window_partition, window_unpartition from ..util import get_sam_model +from .peft_sam import LoRASurgery def get_sam_3d_model( @@ -18,9 +19,10 @@ def get_sam_3d_model( model_type="vit_b", checkpoint_path=None, ): - peft_kwargs = {} - if lora_rank is not None: - peft_kwargs["rank"] = lora_rank + if lora_rank is None: + peft_kwargs = {} + else: + peft_kwargs = {"rank": lora_rank, "peft_module": LoRASurgery} _, sam = get_sam_model( model_type=model_type, diff --git a/micro_sam/models/simple_sam_3d_wrapper.py b/micro_sam/models/simple_sam_3d_wrapper.py index 47d2d60bf..c310e0944 100644 --- a/micro_sam/models/simple_sam_3d_wrapper.py +++ b/micro_sam/models/simple_sam_3d_wrapper.py @@ -5,6 +5,7 @@ import torch.nn as nn from ..util import get_sam_model +from .peft_sam import LoRASurgery def get_simple_sam_3d_model( @@ -16,9 +17,10 @@ def get_simple_sam_3d_model( model_type="vit_b", checkpoint_path=None, ): - peft_kwargs = {} - if lora_rank is not None: - peft_kwargs["rank"] = lora_rank + if lora_rank is None: + peft_kwargs = {} + else: + peft_kwargs = {"rank": lora_rank, "peft_module": LoRASurgery} _, sam = get_sam_model( model_type=model_type, From be47087553d5c6273889c4d8adf2701c8e8b9c76 Mon Sep 17 00:00:00 2001 From: Anwai Archit <52396323+anwai98@users.noreply.github.com> Date: Sun, 22 Sep 2024 23:18:53 +0200 Subject: [PATCH 32/53] Fix livecell finetuning with updated peft methods (#691) --- finetuning/livecell_finetuning.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/finetuning/livecell_finetuning.py b/finetuning/livecell_finetuning.py index 5899a4ab4..f32986f57 100644 --- a/finetuning/livecell_finetuning.py +++ b/finetuning/livecell_finetuning.py @@ -72,7 +72,7 @@ def finetune_livecell(args): save_root=args.save_root, scheduler_kwargs=scheduler_kwargs, save_every_kth_epoch=args.save_every_kth_epoch, - lora_rank=args.lora_rank, + peft_kwargs={"rank": args.lora_rank} if args.lora_rank is not None else None, ) if args.export_path is not None: @@ -118,7 +118,7 @@ def main(): "--n_objects", type=int, default=25, help="The number of instances (objects) per batch used for finetuning." ) parser.add_argument( - "--lora_rank", type=int, default=None, help="The rank for low rank adaptation of the attention layers." + "--lora_rank", type=int, default=None, help="The rank for low rank adaptation of the attention layers." ) args = parser.parse_args() finetune_livecell(args) @@ -126,4 +126,3 @@ def main(): if __name__ == "__main__": main() - From e5a29e13870bec16ef9854a9097d345d93d5aeeb Mon Sep 17 00:00:00 2001 From: Anwai Archit <52396323+anwai98@users.noreply.github.com> Date: Wed, 25 Sep 2024 09:42:51 +0200 Subject: [PATCH 33/53] Add automatic segmentation cli (#699) Add CLI for automatic instance segmentation --- micro_sam/automatic_segmentation.py | 187 ++++++++++++++++++++ micro_sam/instance_segmentation.py | 6 +- micro_sam/multi_dimensional_segmentation.py | 9 +- micro_sam/precompute_state.py | 4 +- micro_sam/util.py | 4 +- setup.cfg | 1 + 6 files changed, 203 insertions(+), 8 deletions(-) create mode 100644 micro_sam/automatic_segmentation.py diff --git a/micro_sam/automatic_segmentation.py b/micro_sam/automatic_segmentation.py new file mode 100644 index 000000000..39ff7d98e --- /dev/null +++ b/micro_sam/automatic_segmentation.py @@ -0,0 +1,187 @@ +import os +from pathlib import Path +from typing import Union, Optional, Tuple + +import numpy as np +import imageio.v3 as imageio + +from . import util +from .instance_segmentation import ( + get_amg, get_decoder, mask_data_to_segmentation, InstanceSegmentationWithDecoder, AMGBase +) +from .multi_dimensional_segmentation import automatic_3d_segmentation + + +def automatic_instance_segmentation( + input_path: Union[os.PathLike, str], + output_path: Optional[Union[os.PathLike, str]] = None, + embedding_path: Optional[Union[os.PathLike, str]] = None, + model_type: str = util._DEFAULT_MODEL, + checkpoint_path: Optional[Union[os.PathLike, str]] = None, + key: Optional[str] = None, + ndim: Optional[int] = None, + tile_shape: Optional[Tuple[int, int]] = None, + halo: Optional[Tuple[int, int]] = None, + use_amg: bool = False, + **generate_kwargs +) -> None: + """Run automatic segmentation for the input image. + + Args: + input_path: input_path: The input image file(s). Can either be a single image file (e.g. tif or png), + or a container file (e.g. hdf5 or zarr). + output_path: The output path where the instance segmentations will be saved. + embedding_path: The path where the embeddings are cached already / will be saved. + model_type: The SegmentAnything model to use. Will use the standard vit_l model by default. + checkpoint_path: Path to a checkpoint for a custom model. + key: The key to the input file. This is needed for container files (eg. hdf5 or zarr) + or to load several images as 3d volume. Provide a glob patterm, eg. "*.tif", for this case. + ndim: The dimensionality of the data. + tile_shape: Shape of the tiles for tiled prediction. By default prediction is run without tiling. + halo: Overlap of the tiles for tiled prediction. + use_amg: Whether to use Automatic Mask Generation (AMG) as the automatic segmentation method. + """ + predictor, state = util.get_sam_model(model_type=model_type, checkpoint_path=checkpoint_path, return_state=True) + + if "decoder_state" in state and not use_amg: # AIS + decoder = get_decoder(predictor.model.image_encoder, state["decoder_state"]) + segmenter = get_amg(predictor=predictor, decoder=decoder, is_tiled=tile_shape is not None) + else: # AMG + segmenter = get_amg(predictor=predictor, is_tiled=tile_shape is not None) + + # Load the input image file. + if isinstance(input_path, np.ndarray): + image_data = input_path + else: + image_data = util.load_image_data(input_path, key) + + if ndim == 3 or image_data.ndim == 3: + if image_data.ndim != 3: + raise ValueError(f"The inputs do not correspond to three dimensional inputs: '{image_data.ndim}'") + + instances = automatic_3d_segmentation( + volume=image_data, + predictor=predictor, + segmentor=segmenter, + embedding_path=embedding_path, + tile_shape=tile_shape, + halo=halo, + **generate_kwargs + ) + else: + # Precompute the image embeddings. + image_embeddings = util.precompute_image_embeddings( + predictor=predictor, + input_=image_data, + save_path=embedding_path, + ndim=ndim, + tile_shape=tile_shape, + halo=halo, + ) + + segmenter.initialize(image=image_data, image_embeddings=image_embeddings) + masks = segmenter.generate(**generate_kwargs) + + if len(masks) == 0: # instance segmentation can have no masks, hence we just save empty labels + if isinstance(segmenter, InstanceSegmentationWithDecoder): + this_shape = segmenter._foreground.shape + elif isinstance(segmenter, AMGBase): + this_shape = segmenter._original_size + else: + this_shape = image_data.shape[-2:] + + instances = np.zeros(this_shape, dtype="uint32") + else: + instances = mask_data_to_segmentation(masks, with_background=True, min_object_size=0) + + if output_path is not None: + # Save the instance segmentation + output_path = Path(output_path).with_suffix(".tif") + imageio.imwrite(output_path, instances, compression="zlib") + + return instances + + +def main(): + """@private""" + import argparse + + available_models = list(util.get_model_names()) + available_models = ", ".join(available_models) + + parser = argparse.ArgumentParser(description="Run automatic segmentation for an image.") + parser.add_argument( + "-i", "--input_path", required=True, + help="The filepath to the image data. Supports all data types that can be read by imageio (e.g. tif, png, ...) " + "or elf.io.open_file (e.g. hdf5, zarr, mrc). For the latter you also need to pass the 'key' parameter." + ) + parser.add_argument( + "-o", "--output_path", required=True, + help="The filepath to store the instance segmentation. The current support stores segmentation in a 'tif' file." + ) + parser.add_argument( + "-e", "--embedding_path", default=None, type=str, help="The path where the embeddings will be saved." + ) + parser.add_argument( + "--pattern", help="Pattern / wildcard for selecting files in a folder. To select all files use '*'." + ) + parser.add_argument( + "-k", "--key", + help="The key for opening data with elf.io.open_file. This is the internal path for a hdf5 or zarr container, " + "for an image stack it is a wild-card, e.g. '*.png' and for mrc it is 'data'." + ) + parser.add_argument( + "-m", "--model_type", default=util._DEFAULT_MODEL, + help=f"The segment anything model that will be used, one of {available_models}." + ) + parser.add_argument( + "-c", "--checkpoint", default=None, + help="Checkpoint from which the SAM model will be loaded loaded." + ) + parser.add_argument( + "--tile_shape", nargs="+", type=int, help="The tile shape for using tiled prediction.", default=None + ) + parser.add_argument( + "--halo", nargs="+", type=int, help="The halo for using tiled prediction.", default=None + ) + parser.add_argument( + "-n", "--ndim", type=int, default=None, + help="The number of spatial dimensions in the data. Please specify this if your data has a channel dimension." + ) + parser.add_argument( + "--amg", action="store_true", help="Whether to use automatic mask generation with the model." + ) + + args, parameter_args = parser.parse_known_args() + + def _convert_argval(value): + # The values for the parsed arguments need to be in the expected input structure as provided. + # i.e. integers and floats should be in their original types. + try: + return int(value) + except ValueError: + return float(value) + + # NOTE: the script below allows the possibility to catch additional parsed arguments which correspond to + # the automatic segmentation post-processing parameters (eg. 'center_distance_threshold' in AIS) + generate_kwargs = { + parameter_args[i].lstrip("--"): _convert_argval(parameter_args[i + 1]) for i in range(0, len(parameter_args), 2) + } + + automatic_instance_segmentation( + input_path=args.input_path, + output_path=args.output_path, + embedding_path=args.embedding_path, + model_type=args.model_type, + checkpoint_path=args.checkpoint, + key=args.key, + ndim=args.ndim, + tile_shape=args.tile_shape, + halo=args.halo, + use_amg=args.amg, + **generate_kwargs, + ) + + +if __name__ == "__main__": + main() diff --git a/micro_sam/instance_segmentation.py b/micro_sam/instance_segmentation.py index 80654e678..e3dbb5f1c 100644 --- a/micro_sam/instance_segmentation.py +++ b/micro_sam/instance_segmentation.py @@ -800,7 +800,7 @@ def get_predictor_and_decoder( model_type: str, checkpoint_path: Union[str, os.PathLike], device: Optional[Union[str, torch.device]] = None, - peft_kwargs: Optional[Dict] = None, + peft_kwargs: Optional[Dict] = None, ) -> Tuple[SamPredictor, DecoderAdapter]: """Load the SAM model (predictor) and instance segmentation decoder. @@ -826,7 +826,9 @@ def get_predictor_and_decoder( peft_kwargs=peft_kwargs, ) if "decoder_state" not in state: - raise ValueError(f"The checkpoint at {checkpoint_path} does not contain a decoder state") + raise ValueError( + f"The checkpoint at '{checkpoint_path}' or the chosen model '{model_type}' does not contain a decoder state" + ) decoder = get_decoder(predictor.model.image_encoder, state["decoder_state"], device) return predictor, decoder diff --git a/micro_sam/multi_dimensional_segmentation.py b/micro_sam/multi_dimensional_segmentation.py index c8747ed88..fd44ef64b 100644 --- a/micro_sam/multi_dimensional_segmentation.py +++ b/micro_sam/multi_dimensional_segmentation.py @@ -356,7 +356,6 @@ def merge_instance_segmentation_3d( return segmentation -# TODO: Enable tiling def automatic_3d_segmentation( volume: np.ndarray, predictor: SamPredictor, @@ -365,6 +364,8 @@ def automatic_3d_segmentation( with_background: bool = True, gap_closing: Optional[int] = None, min_z_extent: Optional[int] = None, + tile_shape: Optional[Tuple[int, int]] = None, + halo: Optional[Tuple[int, int]] = None, verbose: bool = True, **kwargs, ) -> np.ndarray: @@ -383,6 +384,8 @@ def automatic_3d_segmentation( operation. The value is used to determine the number of iterations for the closing. min_z_extent: Require a minimal extent in z for the segmented objects. This can help to prevent segmentation artifacts. + tile_shape: Shape of the tiles for tiled prediction. By default prediction is run without tiling. + halo: Overlap of the tiles for tiled prediction. verbose: Verbosity flag. kwargs: Keyword arguments for the 'generate' method of the 'segmentor'. @@ -393,7 +396,9 @@ def automatic_3d_segmentation( segmentation = np.zeros(volume.shape, dtype="uint32") min_object_size = kwargs.pop("min_object_size", 0) - image_embeddings = util.precompute_image_embeddings(predictor, volume, save_path=embedding_path, ndim=3) + image_embeddings = util.precompute_image_embeddings( + predictor=predictor, input_=volume, save_path=embedding_path, ndim=3, tile_shape=tile_shape, halo=halo, + ) for i in tqdm(range(segmentation.shape[0]), desc="Segment slices", disable=not verbose): segmentor.initialize(volume[i], image_embeddings=image_embeddings, verbose=False, i=i) diff --git a/micro_sam/precompute_state.py b/micro_sam/precompute_state.py index e4a970b75..520f2baec 100644 --- a/micro_sam/precompute_state.py +++ b/micro_sam/precompute_state.py @@ -235,10 +235,10 @@ def precompute_state( a container file (e.g. hdf5 or zarr) or a folder with images files. In case of a container file the argument `key` must be given. In case of a folder it can be given to provide a glob pattern to subselect files from the folder. - output_path: The output path were the embeddings and other state will be saved. + output_path: The output path where the embeddings and other state will be saved. pattern: Glob pattern to select files in a folder. The embeddings will be computed for each of these files. To select all files in a folder pass "*". - model_type: The SegmentAnything model to use. Will use the standard vit_h model by default. + model_type: The SegmentAnything model to use. Will use the standard vit_l model by default. checkpoint_path: Path to a checkpoint for a custom model. key: The key to the input file. This is needed for contaner files (e.g. hdf5 or zarr) or to load several images as 3d volume. Provide a glob pattern, e.g. "*.tif", for this case. diff --git a/micro_sam/util.py b/micro_sam/util.py index af1d1dc68..a514f3d0a 100644 --- a/micro_sam/util.py +++ b/micro_sam/util.py @@ -44,7 +44,7 @@ from tqdm import tqdm # this is the default model used in micro_sam -# currently set to the default vit_h +# currently set to the default vit_l _DEFAULT_MODEL = "vit_l" # The valid model types. Each type corresponds to the architecture of the @@ -396,7 +396,7 @@ def get_sam_model( # Add the decoder to the state if we have one and if the state is returned. if decoder_path is not None and return_state: - state["decoder_state"] = torch.load(decoder_path, map_location=device) + state["decoder_state"] = torch.load(decoder_path, map_location=device, weights_only=True) if return_sam and return_state: return predictor, sam, state diff --git a/setup.cfg b/setup.cfg index 235437135..2c6b5e38b 100644 --- a/setup.cfg +++ b/setup.cfg @@ -48,6 +48,7 @@ console_scripts = micro_sam.annotator_tracking = micro_sam.sam_annotator.annotator_tracking:main micro_sam.image_series_annotator = micro_sam.sam_annotator.image_series_annotator:main micro_sam.precompute_embeddings = micro_sam.precompute_state:main + micro_sam.automatic_segmentation = micro_sam.automatic_segmentation:main # make sure it gets included in your package [options.package_data] From 2a98f6ef505118f6143aa542e5f3bd61ba1ee628 Mon Sep 17 00:00:00 2001 From: Anwai Archit <52396323+anwai98@users.noreply.github.com> Date: Sun, 29 Sep 2024 12:40:32 +0200 Subject: [PATCH 34/53] Add tests for automatic segmentation function (#702) Add test for automatic segmentation cli --------- Co-authored-by: Constantin Pape --- examples/annotator_2d.py | 2 +- micro_sam/automatic_segmentation.py | 21 +++- test/test_automatic_segmentation.py | 157 ++++++++++++++++++++++++++++ test/test_sam_annotator/test_cli.py | 3 + 4 files changed, 177 insertions(+), 6 deletions(-) create mode 100644 test/test_automatic_segmentation.py diff --git a/examples/annotator_2d.py b/examples/annotator_2d.py index 9df1cf990..eba06f83d 100644 --- a/examples/annotator_2d.py +++ b/examples/annotator_2d.py @@ -65,7 +65,7 @@ def wholeslide_annotator(use_finetuned_model): def main(): # Whether to use the fine-tuned SAM model for light microscopy data. - use_finetuned_model = False + use_finetuned_model = True # 2d annotator for livecell data livecell_annotator(use_finetuned_model) diff --git a/micro_sam/automatic_segmentation.py b/micro_sam/automatic_segmentation.py index 39ff7d98e..79043eee8 100644 --- a/micro_sam/automatic_segmentation.py +++ b/micro_sam/automatic_segmentation.py @@ -1,6 +1,6 @@ import os from pathlib import Path -from typing import Union, Optional, Tuple +from typing import Dict, Optional, Union, Tuple import numpy as np import imageio.v3 as imageio @@ -13,7 +13,7 @@ def automatic_instance_segmentation( - input_path: Union[os.PathLike, str], + input_path: Union[Union[os.PathLike, str], np.ndarray], output_path: Optional[Union[os.PathLike, str]] = None, embedding_path: Optional[Union[os.PathLike, str]] = None, model_type: str = util._DEFAULT_MODEL, @@ -23,8 +23,9 @@ def automatic_instance_segmentation( tile_shape: Optional[Tuple[int, int]] = None, halo: Optional[Tuple[int, int]] = None, use_amg: bool = False, + amg_kwargs: Optional[Dict] = None, **generate_kwargs -) -> None: +) -> np.ndarray: """Run automatic segmentation for the input image. Args: @@ -40,14 +41,24 @@ def automatic_instance_segmentation( tile_shape: Shape of the tiles for tiled prediction. By default prediction is run without tiling. halo: Overlap of the tiles for tiled prediction. use_amg: Whether to use Automatic Mask Generation (AMG) as the automatic segmentation method. + amg_kwargs: optional keyword arguments for creating the AMG or AIS class. + generate_kwargs: optional keyword arguments for the generate function onf the AMG or AIS class. + + Returns: + The segmentation result. """ predictor, state = util.get_sam_model(model_type=model_type, checkpoint_path=checkpoint_path, return_state=True) if "decoder_state" in state and not use_amg: # AIS decoder = get_decoder(predictor.model.image_encoder, state["decoder_state"]) - segmenter = get_amg(predictor=predictor, decoder=decoder, is_tiled=tile_shape is not None) + segmenter = get_amg( + predictor=predictor, decoder=decoder, is_tiled=tile_shape is not None, + **({} if amg_kwargs is None else amg_kwargs) + ) else: # AMG - segmenter = get_amg(predictor=predictor, is_tiled=tile_shape is not None) + segmenter = get_amg( + predictor=predictor, is_tiled=tile_shape is not None, **({} if amg_kwargs is None else amg_kwargs) + ) # Load the input image file. if isinstance(input_path, np.ndarray): diff --git a/test/test_automatic_segmentation.py b/test/test_automatic_segmentation.py new file mode 100644 index 000000000..47b460f45 --- /dev/null +++ b/test/test_automatic_segmentation.py @@ -0,0 +1,157 @@ +import unittest + +import numpy as np +from skimage.draw import disk +from skimage.measure import label as connected_components + +import micro_sam.util as util + + +class TestAutomaticSegmentation(unittest.TestCase): + model_type = "vit_t" if util.VIT_T_SUPPORT else "vit_b" + model_type_ais = "vit_t_lm" if util.VIT_T_SUPPORT else "vit_b_lm" + tile_shape = (384, 768) + halo = (96, 96) + + # create an input 2d image with three objects + @staticmethod + def _get_2d_inputs(shape): + mask = np.zeros(shape, dtype="uint8") + + def write_object(center, radius): + circle = disk(center, radius, shape=shape) + mask[circle] = 1 + + center = tuple(sh // 4 for sh in shape) + write_object(center, radius=29) + + center = tuple(sh // 2 for sh in shape) + write_object(center, radius=33) + + center = tuple(3 * sh // 4 for sh in shape) + write_object(center, radius=35) + + image = mask * 255 + mask = connected_components(mask) + return mask, image + + # create an input 2d image with three objects and stack them together + @classmethod + def _get_3d_inputs(cls, shape): + mask, image = cls._get_2d_inputs(shape[-2:]) + + # Create volumes by stacking the input image and respective mask. + volume = np.stack([image] * shape[0]) + labels = np.stack([mask] * shape[0]) + return labels, volume + + @classmethod + def setUpClass(cls): + # Input 2d data for normal and tiled segmentation. + cls.mask, cls.image = cls._get_2d_inputs(shape=(256, 256)) + cls.large_mask, cls.large_image = cls._get_2d_inputs(shape=(768, 768)) + + # Input 3d data for normal and tiled segmentation. + cls.labels, cls.volume = cls._get_3d_inputs(shape=(3, 256, 256)) + cls.large_labels, cls.large_volume = cls._get_3d_inputs(shape=(3, 768, 768)) + + def tearDown(self): + # Release all unoccupied cached memory (eg. tiling requires a lot of memory) + device = util.get_device(None) + if device == "cuda": + import torch.cuda + torch.cuda.empty_cache() + elif device == "mps": + import torch.mps + torch.mps.empty_cache() + + def test_automatic_mask_generator_2d(self): + from micro_sam.automatic_segmentation import automatic_instance_segmentation + + mask, image = self.mask, self.image + instances = automatic_instance_segmentation( + input_path=image, model_type=self.model_type, ndim=2, use_amg=True, + amg_kwargs={"points_per_side": 4} + ) + self.assertEqual(mask.shape, instances.shape) + + def test_tiled_automatic_mask_generator_2d(self): + from micro_sam.automatic_segmentation import automatic_instance_segmentation + + mask, image = self.large_mask, self.large_image + instances = automatic_instance_segmentation( + input_path=image, + model_type=self.model_type, + ndim=2, + tile_shape=self.tile_shape, + halo=self.halo, + use_amg=True, + amg_kwargs={"points_per_side": 4} + ) + self.assertEqual(mask.shape, instances.shape) + + def test_instance_segmentation_with_decoder_2d(self): + from micro_sam.automatic_segmentation import automatic_instance_segmentation + + mask, image = self.mask, self.image + instances = automatic_instance_segmentation( + input_path=image, model_type=self.model_type_ais, ndim=2 + ) + self.assertEqual(mask.shape, instances.shape) + + def test_tiled_instance_segmentation_with_decoder_2d(self): + from micro_sam.automatic_segmentation import automatic_instance_segmentation + + mask, image = self.large_mask, self.large_image + instances = automatic_instance_segmentation( + input_path=image, model_type=self.model_type_ais, + ndim=2, tile_shape=self.tile_shape, halo=self.halo, + ) + self.assertEqual(mask.shape, instances.shape) + + @unittest.skip("Skipping long running tests by default.") + def test_automatic_mask_generator_3d(self): + from micro_sam.automatic_segmentation import automatic_instance_segmentation + + labels, volume = self.labels, self.volume + instances = automatic_instance_segmentation( + input_path=volume, model_type=self.model_type, ndim=3, use_amg=True + ) + self.assertEqual(labels.shape, instances.shape) + + @unittest.skip("Skipping long running tests by default.") + def test_tiled_automatic_mask_generator_3d(self): + from micro_sam.automatic_segmentation import automatic_instance_segmentation + + labels, volume = self.large_labels, self.large_volume + instances = automatic_instance_segmentation( + input_path=volume, + model_type=self.model_type, + ndim=3, + tile_shape=self.tile_shape, + halo=self.halo, + use_amg=True, + ) + self.assertEqual(labels.shape, instances.shape) + + def test_instance_segmentation_with_decoder_3d(self): + from micro_sam.automatic_segmentation import automatic_instance_segmentation + + labels, volume = self.labels, self.volume + instances = automatic_instance_segmentation( + input_path=volume, model_type=self.model_type_ais, ndim=3, + ) + self.assertEqual(labels.shape, instances.shape) + + def test_tiled_instance_segmentation_with_decoder_3d(self): + from micro_sam.automatic_segmentation import automatic_instance_segmentation + + labels, volume = self.large_labels, self.large_volume + instances = automatic_instance_segmentation( + input_path=volume, model_type=self.model_type_ais, ndim=3, tile_shape=self.tile_shape, halo=self.halo, + ) + self.assertEqual(labels.shape, instances.shape) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/test_sam_annotator/test_cli.py b/test/test_sam_annotator/test_cli.py index 9f9919bf1..1df399f9c 100644 --- a/test/test_sam_annotator/test_cli.py +++ b/test/test_sam_annotator/test_cli.py @@ -38,6 +38,9 @@ def test_image_series_annotator(self): def test_precompute_embeddings(self): self._test_command("micro_sam.precompute_embeddings") + def test_automatic_segmentation(self): + self._test_command("micro_sam.automatic_segmentation") + # The filepaths can't be found on windows, probably due different filepath conventions. # The actual functionality likely works despite this issue. if platform.system() == "Windows": From 20e89abccadb53497466175b46591848f4d2de29 Mon Sep 17 00:00:00 2001 From: Constantin Pape Date: Sun, 29 Sep 2024 16:09:33 +0200 Subject: [PATCH 35/53] Napari updates (#710) * Fix clearing shape layer * Replace edge-attributes with border-attributes for points layer (#697) * Update weight loading --------- Co-authored-by: Anwai Archit <52396323+anwai98@users.noreply.github.com> --- environment_cpu.yaml | 2 +- environment_gpu.yaml | 2 +- micro_sam/sam_annotator/_annotator.py | 8 ++++---- micro_sam/sam_annotator/annotator_tracking.py | 8 ++++---- micro_sam/sam_annotator/util.py | 2 +- micro_sam/util.py | 2 +- test/test_prompt_generators.py | 8 ++++---- 7 files changed, 16 insertions(+), 16 deletions(-) diff --git a/environment_cpu.yaml b/environment_cpu.yaml index 6904fbac7..d56682ff5 100644 --- a/environment_cpu.yaml +++ b/environment_cpu.yaml @@ -7,7 +7,7 @@ dependencies: - nifty =1.2.1=*_4 - imagecodecs - magicgui - - napari <0.5 + - napari - pip - pooch - pyqt diff --git a/environment_gpu.yaml b/environment_gpu.yaml index 90445731c..c77a834e1 100644 --- a/environment_gpu.yaml +++ b/environment_gpu.yaml @@ -7,7 +7,7 @@ dependencies: - imagecodecs - nifty =1.2.1=*_4 - magicgui - - napari <0.5 + - napari - pip - pooch - pyqt diff --git a/micro_sam/sam_annotator/_annotator.py b/micro_sam/sam_annotator/_annotator.py index 974aeccb5..127c09e07 100644 --- a/micro_sam/sam_annotator/_annotator.py +++ b/micro_sam/sam_annotator/_annotator.py @@ -29,15 +29,15 @@ def _create_layers(self): self._point_prompt_layer = self._viewer.add_points( name="point_prompts", property_choices={"label": self._point_labels}, - edge_color="label", - edge_color_cycle=vutil.LABEL_COLOR_CYCLE, + border_color="label", + border_color_cycle=vutil.LABEL_COLOR_CYCLE, symbol="o", face_color="transparent", - edge_width=0.5, + border_width=0.5, size=12, ndim=self._ndim, ) - self._point_prompt_layer.edge_color_mode = "cycle" + self._point_prompt_layer.border_color_mode = "cycle" # Add the shape layer for box and other shape prompts. self._viewer.add_shapes( diff --git a/micro_sam/sam_annotator/annotator_tracking.py b/micro_sam/sam_annotator/annotator_tracking.py index d82b0923f..f839697dd 100644 --- a/micro_sam/sam_annotator/annotator_tracking.py +++ b/micro_sam/sam_annotator/annotator_tracking.py @@ -110,16 +110,16 @@ def _create_layers(self): "state": self._track_state_labels, "track_id": ["1"], # we use string to avoid pandas warning }, - edge_color="label", - edge_color_cycle=vutil.LABEL_COLOR_CYCLE, + border_color="label", + border_color_cycle=vutil.LABEL_COLOR_CYCLE, symbol="o", face_color="state", face_color_cycle=STATE_COLOR_CYCLE, - edge_width=0.4, + border_width=0.4, size=12, ndim=self._ndim, ) - self._point_prompt_layer.edge_color_mode = "cycle" + self._point_prompt_layer.border_color_mode = "cycle" self._point_prompt_layer.face_color_mode = "cycle" # Using the box layer to set divisions currently doesn't work. diff --git a/micro_sam/sam_annotator/util.py b/micro_sam/sam_annotator/util.py index aae5a75ef..db2ec5187 100644 --- a/micro_sam/sam_annotator/util.py +++ b/micro_sam/sam_annotator/util.py @@ -118,7 +118,7 @@ def clear_annotations(viewer: napari.Viewer, clear_segmentations=True) -> None: viewer.layers["point_prompts"].data = [] viewer.layers["point_prompts"].refresh() if "prompts" in viewer.layers: - viewer.layers["prompts"].data = [] + viewer.layers["prompts"].remove_selected() viewer.layers["prompts"].refresh() if not clear_segmentations: return diff --git a/micro_sam/util.py b/micro_sam/util.py index a514f3d0a..749fd4f88 100644 --- a/micro_sam/util.py +++ b/micro_sam/util.py @@ -396,7 +396,7 @@ def get_sam_model( # Add the decoder to the state if we have one and if the state is returned. if decoder_path is not None and return_state: - state["decoder_state"] = torch.load(decoder_path, map_location=device, weights_only=True) + state["decoder_state"] = torch.load(decoder_path, map_location=device, weights_only=False) if return_sam and return_state: return predictor, sam, state diff --git a/test/test_prompt_generators.py b/test/test_prompt_generators.py index 18fa3f7bb..162f9c7a3 100644 --- a/test/test_prompt_generators.py +++ b/test/test_prompt_generators.py @@ -34,15 +34,15 @@ def _debug(self, mask, coordinates=None, labels=None, box=None, deformed_mask=No data=coordinates, name="prompts", properties={"label": labels}, - edge_color="label", - edge_color_cycle=["#00FF00", "#FF0000"], + border_color="label", + border_color_cycle=["#00FF00", "#FF0000"], symbol="o", face_color="transparent", - edge_width=0.5, + border_width=0.5, size=5, ndim=2 ) # this function helps to view the (colored) background/foreground points - prompts.edge_color_mode = "cycle" + prompts.border_color_mode = "cycle" if deformed_mask is not None: v.add_labels(deformed_mask.astype("uint8"), name="deformed mask / prediction") From 1a555a492a76d2fb7e718747036b08972f4ce061 Mon Sep 17 00:00:00 2001 From: Anwai Archit <52396323+anwai98@users.noreply.github.com> Date: Sun, 29 Sep 2024 19:18:37 +0200 Subject: [PATCH 36/53] Extend iterative prompt generators to return prompts for 3d (#692) Extend support for iterative prompt generators to 3d --- micro_sam/prompt_generators.py | 119 ++++++++++++++++++++------------- 1 file changed, 73 insertions(+), 46 deletions(-) diff --git a/micro_sam/prompt_generators.py b/micro_sam/prompt_generators.py index e0401b4c6..839077410 100644 --- a/micro_sam/prompt_generators.py +++ b/micro_sam/prompt_generators.py @@ -252,20 +252,26 @@ def __call__( class IterativePromptGenerator(PromptGeneratorBase): """Generate point prompts from an instance segmentation iteratively. """ - def _get_positive_points(self, pos_region, overlap_region): + def _get_positive_points(self, pos_region, overlap_region, is_3d): positive_locations = [torch.where(pos_reg) for pos_reg in pos_region] # we may have objects without a positive region (= missing true foreground) - # in this case we just sample a point where the model was already correct + # in this case we just sample a positive point where the model was already correct positive_locations = [ torch.where(ovlp_reg) if len(pos_loc[0]) == 0 else pos_loc for pos_loc, ovlp_reg in zip(positive_locations, overlap_region) ] - # we sample one location for each object in the batch + # we sample one positive location for each object in the batch sampled_indices = [np.random.choice(len(pos_loc[0])) for pos_loc in positive_locations] - # get the corresponding coordinates (Note that we flip the axis order here due to the expected order of SAM) - pos_coordinates = [ - [pos_loc[-1][idx], pos_loc[-2][idx]] for pos_loc, idx in zip(positive_locations, sampled_indices) - ] + # get the corresponding coordinates (NOTE: we flip the axis order here due to the expected order of SAM) + if is_3d: + pos_coordinates = [ + [pos_loc[-1][idx], pos_loc[-2][idx], pos_loc[-3][idx]] + for pos_loc, idx in zip(positive_locations, sampled_indices) + ] + else: + pos_coordinates = [ + [pos_loc[-1][idx], pos_loc[-2][idx]] for pos_loc, idx in zip(positive_locations, sampled_indices) + ] # make sure that we still have the correct batch size assert len(pos_coordinates) == pos_region.shape[0] @@ -273,43 +279,55 @@ def _get_positive_points(self, pos_region, overlap_region): return pos_coordinates, pos_labels - # TODO get rid of this looped implementation and use proper batched computation instead - def _get_negative_points(self, negative_region_batched, true_object_batched): - device = negative_region_batched.device - - negative_coordinates, negative_labels = [], [] - for neg_region, true_object in zip(negative_region_batched, true_object_batched): - - tmp_neg_loc = torch.where(neg_region) - if torch.stack(tmp_neg_loc).shape[-1] == 0: - tmp_true_loc = torch.where(true_object) - x_coords, y_coords = tmp_true_loc[1], tmp_true_loc[2] - bbox = torch.stack([torch.min(x_coords), torch.min(y_coords), - torch.max(x_coords) + 1, torch.max(y_coords) + 1]) - bbox_mask = torch.zeros_like(true_object).squeeze(0) - - custom_df = 3 # custom dilation factor to perform dilation by expanding the pixels of bbox - bbox_mask[max(bbox[0] - custom_df, 0): min(bbox[2] + custom_df, true_object.shape[-2]), - max(bbox[1] - custom_df, 0): min(bbox[3] + custom_df, true_object.shape[-1])] = 1 - bbox_mask = bbox_mask[None].to(device) - - background_mask = torch.abs(bbox_mask - true_object) - tmp_neg_loc = torch.where(background_mask) - - # there is a chance that the object is small to not return a decent-sized bounding box - # hence we might not find points sometimes there as well, hence we sample points from true background - if torch.stack(tmp_neg_loc).shape[-1] == 0: - tmp_neg_loc = torch.where(true_object == 0) + def _get_negative_locations_in_obj_bbox(self, true_object, custom_df=3): + true_loc = torch.where(true_object) + bbox = torch.stack( + [torch.min(true_loc[1]), torch.min(true_loc[2]), torch.max(true_loc[1]) + 1, torch.max(true_loc[2]) + 1] + ) - neg_index = np.random.choice(len(tmp_neg_loc[1])) - neg_coordinates = [tmp_neg_loc[1][neg_index], tmp_neg_loc[2][neg_index]] - neg_coordinates = neg_coordinates[::-1] - neg_labels = 0 + # custom dilation factor to perform dilation by expanding the pixels of bbox + bbox_mask = torch.zeros_like(true_object).squeeze(0) + bbox_mask[ + max(bbox[0] - custom_df, 0): min(bbox[2] + custom_df, true_object.shape[-2]), + max(bbox[1] - custom_df, 0): min(bbox[3] + custom_df, true_object.shape[-1]) + ] = 1 + bbox_mask = bbox_mask[None].to(true_object.device) + background_mask = torch.abs(bbox_mask - true_object) + return torch.where(background_mask) + + def _get_negative_points(self, neg_region, true_object, is_3d): + # we have a valid negative region (i.e. a valid region where the model could not generate prediction) + negative_locations = [torch.where(neg_reg) for neg_reg in neg_region] + # we may have objects without a negative region (= no rectifications required) + # in this case we sample a negative point in outer periphery of the object inside the bounding box. + negative_locations = [ + self._get_negative_locations_in_obj_bbox(true_obj) if len(neg_loc[0]) == 0 else neg_loc + for neg_loc, true_obj in zip(negative_locations, true_object) + ] + # there is a chance that the object is small to not return a decent-sized bounding box + # hence we might not find points sometimes there as well. therefore, we sample points from true background. + negative_locations = [ + torch.where(true_obj == 0) if len(neg_loc[0]) == 0 else neg_loc + for neg_loc, true_obj in zip(negative_locations, true_object) + ] + # we sample one negative location for each object in the batch + sampled_indices = [np.random.choice(len(neg_loc[0])) for neg_loc in negative_locations] + # get the corresponding coordinates (NOTE: we flip the axis order here due to the expected order of SAM) + if is_3d: + neg_coordinates = [ + [neg_loc[-1][idx], neg_loc[-2][idx], neg_loc[-3][idx]] + for neg_loc, idx in zip(negative_locations, sampled_indices) + ] + else: + neg_coordinates = [ + [neg_loc[-1][idx], neg_loc[-2][idx]] for neg_loc, idx in zip(negative_locations, sampled_indices) + ] - negative_coordinates.append(neg_coordinates) - negative_labels.append(neg_labels) + # make sure that we still have the correct batch size + assert len(neg_coordinates) == neg_region.shape[0] + neg_labels = [0] * len(neg_coordinates) - return negative_coordinates, negative_labels + return neg_coordinates, neg_labels def __call__( self, @@ -320,15 +338,24 @@ def __call__( """Generate the prompts for each object iteratively in the segmentation. Args: - The groundtruth segmentation. Expects a float tensor of shape NUM_OBJECTS x 1 x H x W. - The predicted objects. Epects a float tensor of the same shape as the segmentation. + segmentation: The groundtruth segmentation. + Expects a float tensor of shape (NUM_OBJECTS x 1 x H x W) or (NUM_OBJECTS x 1 x Z x H x W). + prediction: The predicted objects. Epects a float tensor of the same shape as the segmentation. Returns: The updated point prompt coordinates. The updated point prompt labels. """ - assert segmentation.shape == prediction.shape device = prediction.device + assert segmentation.shape == prediction.shape, \ + "The segmentation and prediction tensors should have the same shape." + + if segmentation.ndim == 5: # masks in 3d must be tensors of shape NUM_OBJECTS x 1 x Z x H x W + is_3d = True + elif segmentation.ndim == 4: # masks in 2d must be tensors of shape NUM_OBJECTS x 1 x H x W + is_3d = False + else: + raise ValueError("The segmentation and prediction tensors should have either '4' or '5' dimensions.") true_object = segmentation.to(device) expected_diff = (prediction - true_object) @@ -336,8 +363,8 @@ def __call__( pos_region = (expected_diff == -1) overlap_region = torch.logical_and(prediction == 1, true_object == 1).to(torch.float32) - pos_coordinates, pos_labels = self._get_positive_points(pos_region, overlap_region) - neg_coordinates, neg_labels = self._get_negative_points(neg_region, true_object) + pos_coordinates, pos_labels = self._get_positive_points(pos_region, overlap_region, is_3d) + neg_coordinates, neg_labels = self._get_negative_points(neg_region, true_object, is_3d) assert len(pos_coordinates) == len(pos_labels) == len(neg_coordinates) == len(neg_labels) pos_coordinates = torch.tensor(pos_coordinates)[:, None] From bd9f44c7ff9dbe3a3672a13d6c64bb8ffc835da3 Mon Sep 17 00:00:00 2001 From: Constantin Pape Date: Mon, 30 Sep 2024 16:17:37 +0200 Subject: [PATCH 37/53] Overwrite keybindings of prompt layers (#711) --- micro_sam/sam_annotator/_annotator.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/micro_sam/sam_annotator/_annotator.py b/micro_sam/sam_annotator/_annotator.py index 127c09e07..fcd58cc44 100644 --- a/micro_sam/sam_annotator/_annotator.py +++ b/micro_sam/sam_annotator/_annotator.py @@ -72,6 +72,20 @@ def _create_keybindings(self): def _segment(viewer): self._widgets["segment"](viewer) + # Note: we also need to over-write the keybindings for specific layers. + # See https://github.com/napari/napari/issues/7302 for details. + # Here, we need to over-write the 's' keybinding for both of the prompt layers. + prompt_layer = self._viewer.layers["prompts"] + point_prompt_layer = self._viewer.layers["point_prompts"] + + @prompt_layer.bind_key("s", overwrite=True) + def _segment_prompts(event): + self._widgets["segment"](self._viewer) + + @point_prompt_layer.bind_key("s", overwrite=True) + def _segment_point_prompts(event): + self._widgets["segment"](self._viewer) + @self._viewer.bind_key("c", overwrite=True) def _commit(viewer): self._widgets["commit"](viewer) From e1bf6594a98f4f1947151852537a4f840a296ac4 Mon Sep 17 00:00:00 2001 From: Anwai Archit <52396323+anwai98@users.noreply.github.com> Date: Tue, 1 Oct 2024 16:26:32 +0200 Subject: [PATCH 38/53] Update dropout defaults for FacT finetuning method (#714) --- micro_sam/models/peft_sam.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/micro_sam/models/peft_sam.py b/micro_sam/models/peft_sam.py index 59167a1df..1ba298d43 100644 --- a/micro_sam/models/peft_sam.py +++ b/micro_sam/models/peft_sam.py @@ -70,7 +70,7 @@ def __init__( self, rank: int, block: nn.Module, - dropout: Optional[float] = None, + dropout: Optional[float] = 0.1, ): super().__init__() self.qkv_proj = block.attn.qkv @@ -104,7 +104,6 @@ def forward(self, x): new_v = self.FacTv(new_v) # NOTE : Scaling Factor was set to 1 as it can be tuned via the learning rate - # Does it make sense to include it, in order to have similar learning rate as the original model? qkv[:, :, :, : self.dim] += new_q qkv[:, :, :, -self.dim:] += new_v From a8af9c44ad7f4cc0ad0c93c1ac40e41cd9bd11f9 Mon Sep 17 00:00:00 2001 From: Anwai Archit <52396323+anwai98@users.noreply.github.com> Date: Tue, 1 Oct 2024 16:33:27 +0200 Subject: [PATCH 39/53] Add selective peft methods (#708) Add selective peft methods (eg. attention ft, bias ft, layernorm ft) --- micro_sam/models/peft_sam.py | 78 ++++++++++++++++++++++++++++--- micro_sam/training/util.py | 17 ++----- micro_sam/util.py | 11 +---- test/test_models/test_peft_sam.py | 42 +++++++++++++++++ 4 files changed, 120 insertions(+), 28 deletions(-) diff --git a/micro_sam/models/peft_sam.py b/micro_sam/models/peft_sam.py index 1ba298d43..634a2fa9b 100644 --- a/micro_sam/models/peft_sam.py +++ b/micro_sam/models/peft_sam.py @@ -23,11 +23,7 @@ class LoRASurgery(nn.Module): rank: The rank of the decomposition matrices for updating weights in each attention layer. block: The chosen attention blocks for implementing lora. """ - def __init__( - self, - rank: int, - block: nn.Module, - ): + def __init__(self, rank: int, block: nn.Module): super().__init__() self.qkv_proj = block.attn.qkv self.dim = self.qkv_proj.in_features @@ -64,8 +60,8 @@ class FacTSurgery(nn.Module): Args: rank: The rank of the decomposition matrices for updating weights in each attention layer. block: The chosen attention blocks for implementing fact. + dropout: The dropout rate for dropout layers. """ - def __init__( self, rank: int, @@ -110,6 +106,69 @@ def forward(self, x): return qkv +class SelectiveSurgery(nn.Module): + """Base class for selectively allowing gradient updates for certain parameters. + """ + def __init__(self, block: nn.Module): + super().__init__() + self.block = block + + def allow_gradient_update_for_parameters( + self, + prefix: Optional[List[str]] = None, + suffix: Optional[List[str]] = None, + infix: Optional[List[str]] = None, + ): + """This function decides the parameter attributes to match for allowing gradient updates. + + Args: + prefix: Matches the part of parameter name in front. + suffix: Matches the part of parameter name at the end. + infix: Matches parts of parameter name occuring in between. + """ + for k, v in self.block.named_parameters(): + if prefix is not None and k.startswith(tuple(prefix)): + v.requires_grad = True + + if suffix is not None and k.endswith(tuple(suffix)): + v.requires_grad = True + + if infix is not None: + for per_infix in infix: + if k.find(per_infix) != -1: + v.requires_grad = True + + def forward(self, x): + return x + + +class AttentionSurgery(SelectiveSurgery): + """Child class for allowing gradient updates for parameters in attention layers. + """ + def __init__(self, block: nn.Module): + super().__init__(block=block) + # Allow gradient updates for the attention layers in the image encoder. + self.allow_gradient_update_for_parameters(prefix=["attn"]) + + +class BiasSurgery(SelectiveSurgery): + """Child class for allowing gradient updates for bias parameters. + """ + def __init__(self, block: nn.Module): + super().__init__(block=block) + # Allow gradient updates for the bias parameters in the image encoder. + self.allow_gradient_update_for_parameters(suffix=["bias"]) + + +class LayerNormSurgery(SelectiveSurgery): + """Child class for allowing gradient updates in normalization layers. + """ + def __init__(self, block: nn.Module): + super().__init__(block=block) + # Allow gradient updates for the LayerNorm parameters in the image encoder. + self.allow_gradient_update_for_parameters(infix=["norm1", "norm2"]) + + class PEFT_Sam(nn.Module): """Wraps the Segment Anything model's image encoder to different parameter efficient finetuning methods. @@ -130,6 +189,7 @@ def __init__( super().__init__() assert rank > 0 + assert issubclass(peft_module, Union[LoRASurgery, FacTSurgery, SelectiveSurgery]), "Invalid PEFT module." if attention_layers_to_update: self.peft_layers = attention_layers_to_update @@ -148,7 +208,11 @@ def __init__( if t_layer_i not in self.peft_layers: continue - peft_block = self.peft_module(rank=rank, block=blk) + if issubclass(self.peft_module, SelectiveSurgery): + peft_block = self.peft_module(block=blk) + else: + peft_block = self.peft_module(rank=rank, block=blk) + self.peft_blocks.append(peft_block) self.peft_blocks = nn.ModuleList(self.peft_blocks) diff --git a/micro_sam/training/util.py b/micro_sam/training/util.py index fb4834c03..7ecf41cd0 100644 --- a/micro_sam/training/util.py +++ b/micro_sam/training/util.py @@ -59,9 +59,7 @@ def get_trainable_sam_model( freeze: Specify parts of the model that should be frozen, namely: image_encoder, prompt_encoder and mask_decoder By default nothing is frozen and the full model is updated. return_state: Whether to return the full checkpoint state. - lora_rank: The rank of the decomposition matrices for updating weights in each attention layer with lora. - If None then LoRA is not used. - lora_kwargs: Keyword arguments for the PEFT wrapper class. + peft_kwargs: Keyword arguments for the PEFT wrapper class. flexible_load_checkpoint: Whether to adjust mismatching params while loading pretrained checkpoints. model_kwargs: Additional keyword arguments for the `util.get_sam_model`. @@ -82,16 +80,11 @@ def get_trainable_sam_model( # NOTE: This is done exclusive to "get_sam_model" here to use PEFT's layer-specific initialization on top. # Whether to use Parameter Efficient Finetuning methods to wrap around Segment Anything. - # Overwrites the SAM model by freezing the backbone and allow low rank adaption to attention layers. + # Overwrites the SAM model by freezing the backbone and allow PEFT methods. if peft_kwargs and isinstance(peft_kwargs, dict): if model_type[:5] == "vit_t": raise ValueError("'micro-sam' does not support parameter efficient finetuning for 'mobile-sam'.") - peft_module = peft_kwargs.get("peft_module") - if peft_module is not None: - from micro_sam.models.peft_sam import LoRASurgery, FacTSurgery - assert peft_module in [LoRASurgery, FacTSurgery], "Invalid PEFT module." - sam = custom_models.peft_sam.PEFT_Sam(sam, **peft_kwargs).sam # freeze components of the model if freeze was passed @@ -106,9 +99,9 @@ def get_trainable_sam_model( # we would want to "freeze" all the components in the model if passed a list of parts for l_item in freeze: - # in case LoRA is switched on, we cannot freeze the image encoder - if (peft_kwargs['rank'] is not None) and (l_item == "image_encoder"): - raise ValueError("You cannot use LoRA & freeze the image encoder at the same time.") + # in case PEFT is switched on, we cannot freeze the image encoder + if (peft_kwargs and peft_kwargs.get('rank') is not None) and (l_item == "image_encoder"): + raise ValueError("You cannot use PEFT & freeze the image encoder at the same time.") if name.startswith(f"{l_item}"): param.requires_grad = False diff --git a/micro_sam/util.py b/micro_sam/util.py index 749fd4f88..6489188c7 100644 --- a/micro_sam/util.py +++ b/micro_sam/util.py @@ -308,9 +308,7 @@ def get_sam_model( then `model_type` must be given as "vit_b". return_sam: Return the sam model object as well as the predictor. return_state: Return the unpickled checkpoint state. - lora_rank: The rank of the decomposition matrices for updating weights in each attention layer with lora. - If None then LoRA is not used. - lora_kwargs: Keyword arguments for th PEFT wrapper class. + peft_kwargs: Keyword arguments for th PEFT wrapper class. flexible_load_checkpoint: Whether to adjust mismatching params while loading pretrained checkpoints. Returns: @@ -369,16 +367,11 @@ def get_sam_model( sam = sam_model_registry[abbreviated_model_type]() # Whether to use Parameter Efficient Finetuning methods to wrap around Segment Anything. - # Overwrites the SAM model by freezing the backbone and allow low rank adaption to attention layers. + # Overwrites the SAM model by freezing the backbone and allow PEFT. if peft_kwargs and isinstance(peft_kwargs, dict): if abbreviated_model_type == "vit_t": raise ValueError("'micro-sam' does not support parameter efficient finetuning for 'mobile-sam'.") - peft_module = peft_kwargs.get("peft_module") - if peft_module is not None: - from .models.peft_sam import LoRASurgery, FacTSurgery - assert peft_module in [LoRASurgery, FacTSurgery], "Invalid PEFT module." - sam = custom_models.peft_sam.PEFT_Sam(sam, **peft_kwargs).sam # In case the model checkpoints have some issues when it is initialized with different parameters than default. diff --git a/test/test_models/test_peft_sam.py b/test/test_models/test_peft_sam.py index 509a67650..4461aa9b1 100644 --- a/test/test_models/test_peft_sam.py +++ b/test/test_models/test_peft_sam.py @@ -36,6 +36,48 @@ def test_fact_sam(self): masks = output[0]["masks"] self.assertEqual(masks.shape, expected_shape) + def test_attention_layer_peft_sam(self): + from micro_sam.models.peft_sam import PEFT_Sam, AttentionSurgery + + _, sam = util.get_sam_model(model_type=self.model_type, return_sam=True, device="cpu") + peft_sam = PEFT_Sam(sam, rank=2, peft_module=AttentionSurgery) + + shape = (3, 1024, 1024) + expected_shape = (1, 3, 1024, 1024) + with torch.no_grad(): + batched_input = [{"image": torch.rand(*shape), "original_size": shape[1:]}] + output = peft_sam(batched_input, multimask_output=True) + masks = output[0]["masks"] + self.assertEqual(masks.shape, expected_shape) + + def test_norm_layer_peft_sam(self): + from micro_sam.models.peft_sam import PEFT_Sam, LayerNormSurgery + + _, sam = util.get_sam_model(model_type=self.model_type, return_sam=True, device="cpu") + peft_sam = PEFT_Sam(sam, rank=2, peft_module=LayerNormSurgery) + + shape = (3, 1024, 1024) + expected_shape = (1, 3, 1024, 1024) + with torch.no_grad(): + batched_input = [{"image": torch.rand(*shape), "original_size": shape[1:]}] + output = peft_sam(batched_input, multimask_output=True) + masks = output[0]["masks"] + self.assertEqual(masks.shape, expected_shape) + + def test_bias_layer_peft_sam(self): + from micro_sam.models.peft_sam import PEFT_Sam, BiasSurgery + + _, sam = util.get_sam_model(model_type=self.model_type, return_sam=True, device="cpu") + peft_sam = PEFT_Sam(sam, rank=2, peft_module=BiasSurgery) + + shape = (3, 1024, 1024) + expected_shape = (1, 3, 1024, 1024) + with torch.no_grad(): + batched_input = [{"image": torch.rand(*shape), "original_size": shape[1:]}] + output = peft_sam(batched_input, multimask_output=True) + masks = output[0]["masks"] + self.assertEqual(masks.shape, expected_shape) + if __name__ == "__main__": unittest.main() From 03775deca25a7af43fd09f8b5c602d4bbf018c93 Mon Sep 17 00:00:00 2001 From: Anwai Archit <52396323+anwai98@users.noreply.github.com> Date: Wed, 2 Oct 2024 10:59:13 +0200 Subject: [PATCH 40/53] Expose kwargs to allow changing peft module arguments (#718) expose kwargs to allows parameters like dropout, etc. for peft modules --------- Co-authored-by: Carolin --- micro_sam/models/peft_sam.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/micro_sam/models/peft_sam.py b/micro_sam/models/peft_sam.py index 634a2fa9b..febbccf6b 100644 --- a/micro_sam/models/peft_sam.py +++ b/micro_sam/models/peft_sam.py @@ -60,7 +60,7 @@ class FacTSurgery(nn.Module): Args: rank: The rank of the decomposition matrices for updating weights in each attention layer. block: The chosen attention blocks for implementing fact. - dropout: The dropout rate for dropout layers. + dropout: The dropout rate for the factorized attention. """ def __init__( self, @@ -77,7 +77,6 @@ def __init__( self.dropout = dropout if self.dropout is not None: - # NOTE : Dropout is not included in the original implementation self.dp_q = nn.Dropout(self.dropout) self.dp_v = nn.Dropout(self.dropout) @@ -87,7 +86,7 @@ def __init__( block.attn.qkv = self def forward(self, x): - qkv = self.qkv_proj(x) # B, N, N, 3 * org_C + qkv = self.qkv_proj(x) new_q = self.q_FacTs(self.FacTu(x)) new_v = self.v_FacTs(self.FacTu(x)) @@ -184,7 +183,8 @@ def __init__( model: Sam, rank: int, peft_module: nn.Module = LoRASurgery, - attention_layers_to_update: Union[List[int]] = None + attention_layers_to_update: Union[List[int]] = None, + **module_kwargs ): super().__init__() @@ -211,7 +211,7 @@ def __init__( if issubclass(self.peft_module, SelectiveSurgery): peft_block = self.peft_module(block=blk) else: - peft_block = self.peft_module(rank=rank, block=blk) + peft_block = self.peft_module(rank=rank, block=blk, **module_kwargs) self.peft_blocks.append(peft_block) From cda4f660ce46fe9816dd39120981fc612f3510d1 Mon Sep 17 00:00:00 2001 From: Anwai Archit <52396323+anwai98@users.noreply.github.com> Date: Sun, 6 Oct 2024 21:17:37 +0200 Subject: [PATCH 41/53] Minor update to ignore warnings in train_sam functionality (#722) Add warning filter to sam training --------- Co-authored-by: Constantin Pape --- micro_sam/training/training.py | 209 ++++++++++++++++++--------------- 1 file changed, 113 insertions(+), 96 deletions(-) diff --git a/micro_sam/training/training.py b/micro_sam/training/training.py index 205e7d759..359d67921 100644 --- a/micro_sam/training/training.py +++ b/micro_sam/training/training.py @@ -1,6 +1,7 @@ import os import time import warnings +from contextlib import contextmanager, nullcontext from glob import glob from typing import Any, Callable, Dict, List, Optional, Tuple, Union @@ -134,6 +135,17 @@ def _count_parameters(model_parameters): print(f"The number of trainable parameters for the provided model is {round(params, 2)}M") +@contextmanager +def _filter_warnings(ignore_warnings): + if ignore_warnings: + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + yield + else: + with nullcontext(): + yield + + def train_sam( name: str, model_type: str, @@ -157,6 +169,7 @@ def train_sam( pbar_signals: Optional[QObject] = None, optimizer_class: Optional[Optimizer] = torch.optim.AdamW, peft_kwargs: Optional[Dict] = None, + ignore_warnings: bool = True, **model_kwargs, ) -> None: """Run training for a SAM model. @@ -194,117 +207,121 @@ def train_sam( pbar_signals: Controls for napari progress bar. optimizer_class: The optimizer class. By default, torch.optim.AdamW is used. + peft_kwargs: Keyword arguments for the PEFT wrapper class. + ignore_warnings: Whether to ignore raised warnings. """ - t_start = time.time() - - _check_loader(train_loader, with_segmentation_decoder) - _check_loader(val_loader, with_segmentation_decoder) - - device = get_device(device) - # Get the trainable segment anything model. - model, state = get_trainable_sam_model( - model_type=model_type, - device=device, - freeze=freeze, - checkpoint_path=checkpoint_path, - return_state=True, - peft_kwargs=peft_kwargs, - **model_kwargs - ) - # This class creates all the training data for a batch (inputs, prompts and labels). - convert_inputs = ConvertToSamInputs(transform=model.transform, box_distortion_factor=0.025) + with _filter_warnings(ignore_warnings): - # Create the UNETR decoder (if train with it) and the optimizer. - if with_segmentation_decoder: + t_start = time.time() + + _check_loader(train_loader, with_segmentation_decoder) + _check_loader(val_loader, with_segmentation_decoder) - # Get the UNETR. - unetr = get_unetr( - image_encoder=model.sam.image_encoder, - decoder_state=state.get("decoder_state", None), + device = get_device(device) + # Get the trainable segment anything model. + model, state = get_trainable_sam_model( + model_type=model_type, device=device, + freeze=freeze, + checkpoint_path=checkpoint_path, + return_state=True, + peft_kwargs=peft_kwargs, + **model_kwargs ) + # This class creates all the training data for a batch (inputs, prompts and labels). + convert_inputs = ConvertToSamInputs(transform=model.transform, box_distortion_factor=0.025) - # Get the parameters for SAM and the decoder from UNETR. - joint_model_params = [params for params in model.parameters()] # sam parameters - for param_name, params in unetr.named_parameters(): # unetr's decoder parameters - if not param_name.startswith("encoder"): - joint_model_params.append(params) + # Create the UNETR decoder (if train with it) and the optimizer. + if with_segmentation_decoder: - optimizer = optimizer_class(joint_model_params, lr=lr) - - else: - optimizer = optimizer_class(model.parameters(), lr=lr) + # Get the UNETR. + unetr = get_unetr( + image_encoder=model.sam.image_encoder, + decoder_state=state.get("decoder_state", None), + device=device, + ) - if scheduler_kwargs is None: - scheduler_kwargs = {"mode": "min", "factor": 0.9, "patience": 3, "verbose": True} + # Get the parameters for SAM and the decoder from UNETR. + joint_model_params = [params for params in model.parameters()] # sam parameters + for param_name, params in unetr.named_parameters(): # unetr's decoder parameters + if not param_name.startswith("encoder"): + joint_model_params.append(params) - scheduler = scheduler_class(optimizer=optimizer, **scheduler_kwargs) + optimizer = optimizer_class(joint_model_params, lr=lr) - # The trainer which performs training and validation. - if with_segmentation_decoder: - instance_seg_loss = torch_em.loss.DiceBasedDistanceLoss(mask_distances_in_bg=True) - trainer = joint_trainers.JointSamTrainer( - name=name, - save_root=save_root, - train_loader=train_loader, - val_loader=val_loader, - model=model, - optimizer=optimizer, - device=device, - lr_scheduler=scheduler, - logger=joint_trainers.JointSamLogger, - log_image_interval=100, - mixed_precision=True, - convert_inputs=convert_inputs, - n_objects_per_batch=n_objects_per_batch, - n_sub_iteration=n_sub_iteration, - compile_model=False, - unetr=unetr, - instance_loss=instance_seg_loss, - instance_metric=instance_seg_loss, - early_stopping=early_stopping, - mask_prob=mask_prob, - ) - else: - trainer = trainers.SamTrainer( - name=name, - train_loader=train_loader, - val_loader=val_loader, - model=model, - optimizer=optimizer, - device=device, - lr_scheduler=scheduler, - logger=trainers.SamLogger, - log_image_interval=100, - mixed_precision=True, - convert_inputs=convert_inputs, - n_objects_per_batch=n_objects_per_batch, - n_sub_iteration=n_sub_iteration, - compile_model=False, - early_stopping=early_stopping, - mask_prob=mask_prob, - save_root=save_root, - ) + else: + optimizer = optimizer_class(model.parameters(), lr=lr) + + if scheduler_kwargs is None: + scheduler_kwargs = {"mode": "min", "factor": 0.9, "patience": 3, "verbose": True} + + scheduler = scheduler_class(optimizer=optimizer, **scheduler_kwargs) + + # The trainer which performs training and validation. + if with_segmentation_decoder: + instance_seg_loss = torch_em.loss.DiceBasedDistanceLoss(mask_distances_in_bg=True) + trainer = joint_trainers.JointSamTrainer( + name=name, + save_root=save_root, + train_loader=train_loader, + val_loader=val_loader, + model=model, + optimizer=optimizer, + device=device, + lr_scheduler=scheduler, + logger=joint_trainers.JointSamLogger, + log_image_interval=100, + mixed_precision=True, + convert_inputs=convert_inputs, + n_objects_per_batch=n_objects_per_batch, + n_sub_iteration=n_sub_iteration, + compile_model=False, + unetr=unetr, + instance_loss=instance_seg_loss, + instance_metric=instance_seg_loss, + early_stopping=early_stopping, + mask_prob=mask_prob, + ) + else: + trainer = trainers.SamTrainer( + name=name, + train_loader=train_loader, + val_loader=val_loader, + model=model, + optimizer=optimizer, + device=device, + lr_scheduler=scheduler, + logger=trainers.SamLogger, + log_image_interval=100, + mixed_precision=True, + convert_inputs=convert_inputs, + n_objects_per_batch=n_objects_per_batch, + n_sub_iteration=n_sub_iteration, + compile_model=False, + early_stopping=early_stopping, + mask_prob=mask_prob, + save_root=save_root, + ) - if n_iterations is None: - trainer_fit_params = {"epochs": n_epochs} - else: - trainer_fit_params = {"iterations": n_iterations} + if n_iterations is None: + trainer_fit_params = {"epochs": n_epochs} + else: + trainer_fit_params = {"iterations": n_iterations} - if save_every_kth_epoch is not None: - trainer_fit_params["save_every_kth_epoch"] = save_every_kth_epoch + if save_every_kth_epoch is not None: + trainer_fit_params["save_every_kth_epoch"] = save_every_kth_epoch - if pbar_signals is not None: - progress_bar_wrapper = _ProgressBarWrapper(pbar_signals) - trainer_fit_params["progress"] = progress_bar_wrapper + if pbar_signals is not None: + progress_bar_wrapper = _ProgressBarWrapper(pbar_signals) + trainer_fit_params["progress"] = progress_bar_wrapper - trainer.fit(**trainer_fit_params) + trainer.fit(**trainer_fit_params) - t_run = time.time() - t_start - hours = int(t_run // 3600) - minutes = int(t_run // 60) - seconds = int(round(t_run % 60, 0)) - print("Training took", t_run, f"seconds (= {hours:02}:{minutes:02}:{seconds:02} hours)") + t_run = time.time() - t_start + hours = int(t_run // 3600) + minutes = int(t_run // 60) + seconds = int(round(t_run % 60, 0)) + print("Training took", t_run, f"seconds (= {hours:02}:{minutes:02}:{seconds:02} hours)") def _update_patch_shape(patch_shape, raw_paths, raw_key, with_channels): From c48d68f527ee172f0d96ee1b2fa1e8765092b3f2 Mon Sep 17 00:00:00 2001 From: Anwai Archit <52396323+anwai98@users.noreply.github.com> Date: Tue, 8 Oct 2024 09:03:50 +0200 Subject: [PATCH 42/53] Add dry run for loaders to check for valid instances (#705) Add dry run for loaders to check for valid instances --- micro_sam/training/training.py | 93 +++++++++++++++++++++------------- 1 file changed, 57 insertions(+), 36 deletions(-) diff --git a/micro_sam/training/training.py b/micro_sam/training/training.py index 359d67921..165a10ae9 100644 --- a/micro_sam/training/training.py +++ b/micro_sam/training/training.py @@ -1,8 +1,9 @@ import os import time import warnings -from contextlib import contextmanager, nullcontext from glob import glob +from tqdm import tqdm +from contextlib import contextmanager, nullcontext from typing import Any, Callable, Dict, List, Optional, Tuple, Union import imageio.v3 as imageio @@ -32,8 +33,8 @@ FilePath = Union[str, os.PathLike] -def _check_loader(loader, with_segmentation_decoder): - x, y = next(iter(loader)) +def _check_loader(loader, with_segmentation_decoder, name=None, verify_n_labels_in_loader=None): + x, _ = next(iter(loader)) # Raw data: check that we have 1 or 3 channels. n_channels = x.shape[1] @@ -57,8 +58,9 @@ def _check_loader(loader, with_segmentation_decoder): ) # Target data: the check depends on whether we train with or without decoder. + # NOTE: Verification step to check whether all labels from dataloader are valid (i.e. have atleast one instance). - def check_instance_channel(instance_channel): + def _check_instance_channel(instance_channel): unique_vals = torch.unique(instance_channel) if (unique_vals < 0).any(): raise ValueError( @@ -73,38 +75,53 @@ def check_instance_channel(instance_channel): "All values in the target channel with the instance segmentation must be integer." ) - n_channels_y = y.shape[1] - if with_segmentation_decoder: - if n_channels_y != 4: - raise ValueError( - "Invalid number of channels in the target data from the data loader. " - "Expect 4 channel for training with an instance segmentation decoder, " - f"but got {n_channels_y} channels." - ) - check_instance_channel(y[:, 0]) + counter = 0 + name = "" if name is None else f"'{name}'" + for x, y in tqdm( + loader, + desc=f"Verifying labels in {name} dataloader", + total=verify_n_labels_in_loader if verify_n_labels_in_loader is not None else None, + ): + n_channels_y = y.shape[1] + if with_segmentation_decoder: + if n_channels_y != 4: + raise ValueError( + "Invalid number of channels in the target data from the data loader. " + "Expect 4 channel for training with an instance segmentation decoder, " + f"but got {n_channels_y} channels." + ) + # Check instance channel per sample in a batch + for per_y_sample in y: + _check_instance_channel(per_y_sample[0]) + + targets_min, targets_max = y[:, 1:].min(), y[:, 1:].max() + if targets_min < 0 or targets_min > 1: + raise ValueError( + "Invalid value range in the target data from the value loader. " + "Expect the 3 last target channels (for normalized distances and foreground probabilities)" + f"to be in range [0.0, 1.0], but got min {targets_min}" + ) + if targets_max < 0 or targets_max > 1: + raise ValueError( + "Invalid value range in the target data from the value loader. " + "Expect the 3 last target channels (for normalized distances and foreground probabilities)" + f"to be in range [0.0, 1.0], but got max {targets_max}" + ) - targets_min, targets_max = y[:, 1:].min(), y[:, 1:].max() - if targets_min < 0 or targets_min > 1: - raise ValueError( - "Invalid value range in the target data from the value loader. " - "Expect the 3 last target channels (for normalized distances and foreground probabilities)" - f"to be in range [0.0, 1.0], but got min {targets_min}" - ) - if targets_max < 0 or targets_max > 1: - raise ValueError( - "Invalid value range in the target data from the value loader. " - "Expect the 3 last target channels (for normalized distances and foreground probabilities)" - f"to be in range [0.0, 1.0], but got max {targets_max}" - ) + else: + if n_channels_y != 1: + raise ValueError( + "Invalid number of channels in the target data from the data loader. " + "Expect 1 channel for training without an instance segmentation decoder," + f"but got {n_channels_y} channels." + ) + # Check instance channel per sample in a batch + for per_y_sample in y: + _check_instance_channel(per_y_sample) - else: - if n_channels_y != 1: - raise ValueError( - "Invalid number of channels in the target data from the data loader. " - "Expect 1 channel for training without an instance segmentation decoder," - f"but got {n_channels_y} channels." - ) - check_instance_channel(y) + counter += 1 + if verify_n_labels_in_loader is not None and counter > verify_n_labels_in_loader: + break # Make the progress bar callbacks compatible with a tqdm progress bar interface. @@ -170,6 +187,7 @@ def train_sam( optimizer_class: Optional[Optimizer] = torch.optim.AdamW, peft_kwargs: Optional[Dict] = None, ignore_warnings: bool = True, + verify_n_labels_in_loader: Optional[int] = 50, **model_kwargs, ) -> None: """Run training for a SAM model. @@ -208,14 +226,17 @@ def train_sam( optimizer_class: The optimizer class. By default, torch.optim.AdamW is used. peft_kwargs: Keyword arguments for the PEFT wrapper class. + verify_n_labels_in_loader: The number of labels to verify out of the train and validation dataloaders. + By default, 50 batches of labels are verified from the dataloaders. + model_kwargs: Additional keyword arguments for the `util.get_sam_model`. ignore_warnings: Whether to ignore raised warnings. """ with _filter_warnings(ignore_warnings): t_start = time.time() - _check_loader(train_loader, with_segmentation_decoder) - _check_loader(val_loader, with_segmentation_decoder) + _check_loader(train_loader, with_segmentation_decoder, verify_n_labels_in_loader) + _check_loader(val_loader, with_segmentation_decoder, verify_n_labels_in_loader) device = get_device(device) # Get the trainable segment anything model. From 766aa9bc0231cc63bd0c05625c8f60ba3af09620 Mon Sep 17 00:00:00 2001 From: Anwai Archit <52396323+anwai98@users.noreply.github.com> Date: Mon, 14 Oct 2024 19:34:59 +0200 Subject: [PATCH 43/53] Add CLI for benchmarking datasets on SAM models (#728) Add scripts for benchmarking SAM models on microscopy datasets --- micro_sam/automatic_segmentation.py | 93 ++- micro_sam/evaluation/benchmark_datasets.py | 721 ++++++++++++++++++ micro_sam/evaluation/evaluation.py | 7 +- micro_sam/evaluation/inference.py | 18 +- .../multi_dimensional_segmentation.py | 47 +- micro_sam/multi_dimensional_segmentation.py | 15 +- micro_sam/prompt_generators.py | 8 +- micro_sam/training/training.py | 4 +- setup.cfg | 1 + test/test_automatic_segmentation.py | 59 +- test/test_training.py | 6 +- 11 files changed, 884 insertions(+), 95 deletions(-) create mode 100644 micro_sam/evaluation/benchmark_datasets.py diff --git a/micro_sam/automatic_segmentation.py b/micro_sam/automatic_segmentation.py index 79043eee8..2561d8e2b 100644 --- a/micro_sam/automatic_segmentation.py +++ b/micro_sam/automatic_segmentation.py @@ -1,6 +1,6 @@ import os from pathlib import Path -from typing import Dict, Optional, Union, Tuple +from typing import Optional, Union, Tuple, Dict import numpy as np import imageio.v3 as imageio @@ -12,54 +12,85 @@ from .multi_dimensional_segmentation import automatic_3d_segmentation +def get_predictor_and_segmenter( + model_type: str, + checkpoint: Optional[Union[os.PathLike, str]] = None, + device: str = None, + amg: bool = False, + is_tiled: bool = False, + amg_kwargs: Dict = {} +) -> Tuple[util.SamPredictor, Union[AMGBase, InstanceSegmentationWithDecoder]]: + """Get the Segment Anything model and class for automatic instance segmentation. + + Args: + model_type: The Segment Anything model choice. + checkpoint: The filepath to the stored model checkpoints. + device: The torch device. + amg: Whether to perform automatic segmentation in AMG mode. + is_tiled: Whether to return segmenter for performing segmentation in tiling window style. + + Returns: + The Segment Anything model. + The automatic instance segmentation class. + """ + # Get the device + device = util.get_device(device=device) + + # Get the predictor and state for Segment Anything models. + predictor, state = util.get_sam_model( + model_type=model_type, device=device, checkpoint_path=checkpoint, return_state=True, + ) + + # Get the segmenter for automatic segmentation. + assert isinstance(amg_kwargs, Dict), "Please ensure 'amg_kwargs' gets arguments in a dictionary." + + segmenter = get_amg( + predictor=predictor, + is_tiled=is_tiled, + decoder=get_decoder( + image_encoder=predictor.model.image_encoder, + decoder_state=state["decoder_state"], + device=device + ) if "decoder_state" in state and not amg else None, + **amg_kwargs + ) + + return predictor, segmenter + + def automatic_instance_segmentation( + predictor: util.SamPredictor, + segmenter: Union[AMGBase, InstanceSegmentationWithDecoder], input_path: Union[Union[os.PathLike, str], np.ndarray], output_path: Optional[Union[os.PathLike, str]] = None, embedding_path: Optional[Union[os.PathLike, str]] = None, - model_type: str = util._DEFAULT_MODEL, - checkpoint_path: Optional[Union[os.PathLike, str]] = None, key: Optional[str] = None, ndim: Optional[int] = None, tile_shape: Optional[Tuple[int, int]] = None, halo: Optional[Tuple[int, int]] = None, - use_amg: bool = False, - amg_kwargs: Optional[Dict] = None, + verbose: bool = True, **generate_kwargs ) -> np.ndarray: """Run automatic segmentation for the input image. Args: + predictor: The Segment Anything model. + segmenter: The automatic instance segmentation class. input_path: input_path: The input image file(s). Can either be a single image file (e.g. tif or png), or a container file (e.g. hdf5 or zarr). output_path: The output path where the instance segmentations will be saved. embedding_path: The path where the embeddings are cached already / will be saved. - model_type: The SegmentAnything model to use. Will use the standard vit_l model by default. - checkpoint_path: Path to a checkpoint for a custom model. key: The key to the input file. This is needed for container files (eg. hdf5 or zarr) or to load several images as 3d volume. Provide a glob patterm, eg. "*.tif", for this case. ndim: The dimensionality of the data. tile_shape: Shape of the tiles for tiled prediction. By default prediction is run without tiling. halo: Overlap of the tiles for tiled prediction. - use_amg: Whether to use Automatic Mask Generation (AMG) as the automatic segmentation method. - amg_kwargs: optional keyword arguments for creating the AMG or AIS class. - generate_kwargs: optional keyword arguments for the generate function onf the AMG or AIS class. + verbose: Verbosity flag. + generate_kwargs: optional keyword arguments for the generate function of the AMG or AIS class. Returns: The segmentation result. """ - predictor, state = util.get_sam_model(model_type=model_type, checkpoint_path=checkpoint_path, return_state=True) - - if "decoder_state" in state and not use_amg: # AIS - decoder = get_decoder(predictor.model.image_encoder, state["decoder_state"]) - segmenter = get_amg( - predictor=predictor, decoder=decoder, is_tiled=tile_shape is not None, - **({} if amg_kwargs is None else amg_kwargs) - ) - else: # AMG - segmenter = get_amg( - predictor=predictor, is_tiled=tile_shape is not None, **({} if amg_kwargs is None else amg_kwargs) - ) - # Load the input image file. if isinstance(input_path, np.ndarray): image_data = input_path @@ -77,6 +108,7 @@ def automatic_instance_segmentation( embedding_path=embedding_path, tile_shape=tile_shape, halo=halo, + verbose=verbose, **generate_kwargs ) else: @@ -88,6 +120,7 @@ def automatic_instance_segmentation( ndim=ndim, tile_shape=tile_shape, halo=halo, + verbose=verbose, ) segmenter.initialize(image=image_data, image_embeddings=image_embeddings) @@ -162,6 +195,11 @@ def main(): parser.add_argument( "--amg", action="store_true", help="Whether to use automatic mask generation with the model." ) + parser.add_argument( + "-d", "--device", default=None, + help="The device to use for the predictor. Can be one of 'cuda', 'cpu' or 'mps' (only MAC)." + "By default the most performant available device will be selected." + ) args, parameter_args = parser.parse_known_args() @@ -179,17 +217,20 @@ def _convert_argval(value): parameter_args[i].lstrip("--"): _convert_argval(parameter_args[i + 1]) for i in range(0, len(parameter_args), 2) } + predictor, segmenter = get_predictor_and_segmenter( + model_type=args.model_type, checkpoint=args.checkpoint, device=args.device, + ) + automatic_instance_segmentation( + predictor=predictor, + segmenter=segmenter, input_path=args.input_path, output_path=args.output_path, embedding_path=args.embedding_path, - model_type=args.model_type, - checkpoint_path=args.checkpoint, key=args.key, ndim=args.ndim, tile_shape=args.tile_shape, halo=args.halo, - use_amg=args.amg, **generate_kwargs, ) diff --git a/micro_sam/evaluation/benchmark_datasets.py b/micro_sam/evaluation/benchmark_datasets.py new file mode 100644 index 000000000..53d39d9b0 --- /dev/null +++ b/micro_sam/evaluation/benchmark_datasets.py @@ -0,0 +1,721 @@ +import os +import time +from glob import glob +from tqdm import tqdm +from natsort import natsorted +from typing import Union, Optional, List, Literal + +import numpy as np +import pandas as pd +import imageio.v3 as imageio +from skimage.measure import label as connected_components + +from nifty.tools import blocking + +import torch + +from torch_em.data import datasets + +from micro_sam import util + +from . import run_evaluation +from ..training.training import _filter_warnings +from .inference import run_inference_with_iterative_prompting +from .evaluation import run_evaluation_for_iterative_prompting +from .multi_dimensional_segmentation import segment_slices_from_ground_truth +from ..automatic_segmentation import automatic_instance_segmentation, get_predictor_and_segmenter + + +LM_2D_DATASETS = [ + "livecell", "deepbacs", "tissuenet", "neurips_cellseg", "dynamicnuclearnet", + "hpa", "covid_if", "pannuke", "lizard", "orgasegment", "omnipose", "dic_hepg2", +] + +LM_3D_DATASETS = [ + "plantseg_root", "plantseg_ovules", "gonuclear", "mouse_embryo", "embegseg", "cellseg3d" +] + +EM_2D_DATASETS = ["mitolab_tem"] + +EM_3D_DATASETS = [ + "mitoem_rat", "mitoem_human", "platynereis_nuclei", "lucchi", "mitolab", "nuc_mm_mouse", + "num_mm_zebrafish", "uro_cell", "sponge_em", "platynereis_cilia", "vnc", "asem_mito", +] + +DATASET_RETURNS_FOLDER = { + "deepbacs": "*.tif" +} + +DATASET_CONTAINER_KEYS = { + "lucchi": ["raw", "labels"], +} + + +def _download_benchmark_datasets(path, dataset_choice): + """Ensures whether all the datasets have been downloaded or not. + + Args: + path: The path to directory where the supported datasets will be downloaded + for benchmarking Segment Anything models. + dataset_choice: The choice of dataset, expects the lower case name for the dataset. + + Returns: + List of choice of dataset(s). + """ + available_datasets = { + # Light Microscopy datasets + "livecell": lambda: datasets.livecell.get_livecell_data( + path=os.path.join(path, "livecell"), split="test", download=True, + ), + "deepbacs": lambda: datasets.deepbacs.get_deepbacs_data( + path=os.path.join(path, "deepbacs"), bac_type="mixed", download=True, + ), + "tissuenet": lambda: datasets.tissuenet.get_tissuenet_data( + path=os.path.join(path, "tissuenet"), split="test", download=True, + ), + "neurips_cellseg": lambda: datasets.neurips_cell_seg.get_neurips_cellseg_data( + root=os.path.join(path, "neurips_cellseg"), split="test", download=True, + ), + "plantseg_root": lambda: datasets.plantseg.get_plantseg_data( + path=os.path.join(path, "plantseg"), download=True, name="root", + ), + "plantseg_ovules": lambda: datasets.plantseg.get_plantseg_data( + path=os.path.join(path, "plantseg"), download=True, name="ovules", + ), + "covid_if": lambda: datasets.covid_if.get_covid_if_data( + path=os.path.join(path, "covid_if"), download=True, + ), + "hpa": lambda: datasets.hpa.get_hpa_segmentation_data( + path=os.path.join(path, "hpa"), download=True, + ), + "dynamicnuclearnet": lambda: datasets.dynamicnuclearnet.get_dynamicnuclearnet_data( + path=os.path.join(path, "dynamicnuclearnet"), split="test", download=True, + ), + "pannuke": lambda: datasets.pannuke.get_pannuke_data( + path=os.path.join(path, "pannuke"), download=True, folds=["fold_1", "fold_2", "fold_3"], + ), + "lizard": lambda: datasets.lizard.get_lizard_data( + path=os.path.join(path, "lizard"), download=True, + ), + "orgasegment": lambda: datasets.orgasegment.get_orgasegment_data( + path=os.path.join(path, "orgasegment"), split="eval", download=True, + ), + "omnipose": lambda: datasets.omnipose.get_omnipose_data( + path=os.path.join(path, "omnipose"), download=True, + ), + "gonuclear": lambda: datasets.gonuclear.get_gonuclear_data( + path=os.path.join(path, "gonuclear"), download=True, + ), + "mouse_embryo": lambda: datasets.mouse_embryo.get_mouse_embryo_data( + path=os.path.join(path, "mouse_embryo"), download=True, + ), + "embedseg_data": lambda: [ + datasets.embedseg_data.get_embedseg_data(path=os.path.join(path, "embedseg_data"), download=True, name=name) + for name in datasets.embedseg_data.URLS.keys() + ], + "cellseg_3d": lambda: datasets.cellseg_3d.get_cellseg_3d_data( + path=os.path.join(path, "cellseg_3d"), download=True, + ), + "dic_hepg2": lambda: datasets.dic_hepg2.get_dic_hepg2_data( + path=os.path.join(path, "dic_hepg2"), download=True, + ), + + # Electron Microscopy datasets + "mitoem_rat": lambda: datasets.mitoem.get_mitoem_data( + path=os.path.join(path, "mitoem"), samples="rat", split="test", download=True, + ), + "mitoem_human": lambda: datasets.mitoem.get_mitoem_data( + path=os.path.join(path, "mitoem"), samples="human", split="test", download=True, + ), + "platynereis_nuclei": lambda: datasets.platynereis.get_platy_data( + path=os.path.join(path, "platynereis"), name="nuclei", download=True, + ), + "platynereis_cilia": lambda: datasets.platynereis.get_platy_data( + path=os.path.join(path, "platynereis"), name="cilia", download=True, + ), + "lucchi": lambda: datasets.lucchi.get_lucchi_data( + path=os.path.join(path, "lucchi"), split="test", download=True, + ), + "mitolab_3d": lambda: [ + datasets.cem.get_benchmark_data( + path=os.path.join(path, "mitolab"), dataset_id=dataset_id, download=True, + ) for dataset_id in range(1, 7) + ], + "mitolab_tem": lambda: datasets.cem.get_benchmark_data( + path=os.path.join(path, "mitolab"), dataset_id=7, download=True + ), + "nuc_mm_mouse": lambda: datasets.nuc_mm.get_nuc_mm_data( + path=os.path.join(path, "nuc_mm"), sample="mouse", download=True, + ), + "nuc_mm_zebrafish": lambda: datasets.nuc_mm.get_nuc_mm_data( + path=os.path.join(path, "nuc_mm"), sample="zebrafish", download=True, + ), + "uro_cell": lambda: datasets.uro_cell.get_uro_cell_data( + path=os.path.join(path, "uro_cell"), download=True, + ), + "sponge_em": lambda: datasets.sponge_em.get_sponge_em_data( + path=os.path.join(path, "sponge_em"), download=True, + ), + "vnc": lambda: datasets.vnc.get_vnc_data( + path=os.path.join(path, "vnc"), download=True, + ), + "asem_mito": lambda: datasets.asem.get_asem_data( + path=os.path.join(path, "asem"), volume_ids=datasets.asem.ORGANELLES["mito"], download=True, + ) + } + + if dataset_choice is None: + dataset_choice = available_datasets.keys() + else: + if not isinstance(dataset_choice, list): + dataset_choice = [dataset_choice] + + for choice in dataset_choice: + if choice in available_datasets: + available_datasets[choice]() + else: + raise ValueError(f"'{choice}' is not a supported choice of dataset.") + + return dataset_choice + + +def _extract_slices_from_dataset(path, dataset_choice, crops_per_input=10): + """Extracts crops of desired shapes for performing evaluation in both 2d and 3d using `micro-sam`. + + Args: + path: The path to directory where the supported datasets have be downloaded + for benchmarking Segment Anything models. + dataset_choice: The name of the dataset of choice to extract crops. + crops_per_input: The maximum number of crops to extract per inputs. + extract_2d: Whether to extract 2d crops from 3d patches. + + Returns: + Filepath to the folder where extracted images are stored. + Filepath to the folder where corresponding extracted labels are stored. + The number of dimensions supported by the input. + """ + ndim = 2 if dataset_choice in [*LM_2D_DATASETS, *EM_2D_DATASETS] else 3 + tile_shape = (512, 512) if ndim == 2 else (32, 512, 512) + + # For 3d inputs, we extract both 2d and 3d crops. + extract_2d_crops_from_volumes = (ndim == 3) + + available_datasets = { + # Light Microscopy datasets + "livecell": lambda: datasets.livecell.get_livecell_paths(path=path, split="test"), + "deepbacs": lambda: datasets.deepbacs.get_deepbacs_paths(path=path, split="test", bac_type="mixed"), + "tissuenet": lambda: datasets.tissuenet.get_tissuenet_paths(path=path, split="test"), + "neurips_cellseg": lambda: datasets.neurips_cell_seg.get_neurips_cellseg_paths(root=path, split="test"), + "plantseg_root": lambda: datasets.plantseg.get_plantseg_paths(path=path, name="root", split="test"), + "plantseg_ovules": lambda: datasets.plantseg.get_plantseg_paths(path=path, name="ovules", split="test"), + "covid_if": lambda: datasets.covid_if.get_covid_if_paths(path=path), + "hpa": lambda: datasets.hpa.get_hpa_segmentation_paths(path=path, split="test"), + "dynamicnuclearnet": lambda: datasets.dynamicnuclearnet.get_dynamicnuclearnet_paths(path=path, split="test"), + "pannuke": lambda: datasets.pannuke.get_pannuke_paths(path=path), + "lizard": lambda: datasets.lizard.get_lizard_paths(parth=path), + "orgasegment": lambda: datasets.orgasegment.get_orgasegment_paths(path=path, split="eval"), + "omnipose": lambda: datasets.omnipose.get_omnipose_paths(path=path, split="test"), + "gonuclear": lambda: datasets.gonuclear.get_gonuclear_paths(path-path), + "mouse_embryo": lambda: datasets.mouse_embryo.get_mouse_embryo_paths(path=path, name="nuclei", split="val"), + "embedseg_data": lambda: datasets.embedseg_data.get_embedseg_paths( + path=path, name=list(datasets.embedseg_data.URLS.keys())[0], split="test" + ), + "cellseg_3d": lambda: datasets.cellseg_3d.get_cellseg_3d_paths(path=path), + "dic_hepg2": lambda: datasets.dic_hepg2.get_dic_hepg2_paths(path=path, split="test"), + + # Electron Microscopy datasets + "mitoem_rat": lambda: datasets.mitoem.get_mitoem_paths(path=path, splits="test", samples="rat"), + "mitem_human": lambda: datasets.mitoem.get_mitoem_paths(path=path, splits="test", samples="human"), + "platynereis_nuclei": lambda: datasets.platynereis.get_platynereis_paths(path, sample_ids=None, name="nuclei"), + "platynereis_cilia": lambda: datasets.platynereis.get_platynereis_paths(path, sample_ids=None, name="cilia"), + "lucchi": lambda: datasets.lucchi.get_lucchi_paths(path=path, split="test"), + "mitolab_3d": lambda: ( + [rpath for i in range(1, 7) for rpath in datasets.cem.get_benchmark_paths(path=path, dataset_id=i)[0]], + [lpath for i in range(1, 7) for lpath in datasets.cem.get_benchmark_paths(path=path, dataset_id=i)[1]] + ), + "mitolab_tem": lambda: datasets.cem.get_benchmark_paths(path=path, dataset_id=7), + "nuc_mm_mouse": lambda: datasets.nuc_mm.get_nuc_mm_paths(path=path, sample="mouse", split="val"), + "nuc_mm_zebrafish": lambda: datasets.nuc_mm.get_nuc_mm_paths(path=path, sample="zebrafish", split="val"), + "uro_cell": lambda: datasets.uro_cell.get_uro_cell_paths(path=path, target="mito"), + "sponge_em": lambda: datasets.sponge_em.get_sponge_em_paths(path=path, sample_ids=None), + "vnc": lambda: datasets.vnc.get_vnc_mito_paths(path=path), + "asem_mito": lambda: datasets.asem.get_asem_paths(path=path, volume_ids=datasets.asem.ORGANELLES["mito"]) + } + + if ndim == 2: + image_paths, gt_paths = available_datasets[dataset_choice]() + + if dataset_choice in DATASET_RETURNS_FOLDER: + image_paths = glob(os.path.join(image_paths, DATASET_RETURNS_FOLDER[dataset_choice])) + gt_paths = glob(os.path.join(gt_paths, DATASET_RETURNS_FOLDER[dataset_choice])) + + image_paths, gt_paths = natsorted(image_paths), natsorted(gt_paths) + assert len(image_paths) == len(gt_paths) + + paths_set = zip(image_paths, gt_paths) + + else: + image_paths = available_datasets[dataset_choice]() + if isinstance(image_paths, str): + paths_set = [image_paths] + else: + paths_set = natsorted(image_paths) + + # Directory where we store the extracted ROIs. + save_image_dir = [os.path.join(path, f"roi_{ndim}d", "inputs")] + save_gt_dir = [os.path.join(path, f"roi_{ndim}d", "labels")] + if extract_2d_crops_from_volumes: + save_image_dir.append(os.path.join(path, "roi_2d", "inputs")) + save_gt_dir.append(os.path.join(path, "roi_2d", "labels")) + + _dir_exists = [ + os.path.exists(idir) and os.path.exists(gdir) for idir, gdir in zip(save_image_dir, save_gt_dir) + ] + if all(_dir_exists): + return ndim + + [os.makedirs(idir, exist_ok=True) for idir in save_image_dir] + [os.makedirs(gdir, exist_ok=True) for gdir in save_gt_dir] + + # Logic to extract relevant patches for inference + image_counter = 1 + for per_paths in tqdm(paths_set, desc=f"Extracting patches for {dataset_choice}"): + if ndim == 2: + image_path, gt_path = per_paths + image, gt = util.load_image_data(image_path), util.load_image_data(gt_path) + else: + image_path = per_paths + image = util.load_image_data(image_path, DATASET_CONTAINER_KEYS[dataset_choice][0]) + gt = util.load_image_data(image_path, DATASET_CONTAINER_KEYS[dataset_choice][1]) + + skip_smaller_shape = (np.array(image.shape) >= np.array(tile_shape)).all() + + # Ensure ground truth has instance labels. + gt = connected_components(gt) + + if len(np.unique(gt)) == 1: # There could be labels which does not have any annotated foreground. + continue + + # Let's extract and save all the crops. + # NOTE: The first round of extraction is always to match the desired input dimensions. + image_crops, gt_crops = _get_crops_for_input(image, gt, ndim, tile_shape, skip_smaller_shape, crops_per_input) + image_counter = _save_image_label_crops( + image_crops, gt_crops, dataset_choice, ndim, image_counter, save_image_dir[0], save_gt_dir[0] + ) + + # NOTE: The next round of extraction is to get 2d crops from 3d inputs. + if extract_2d_crops_from_volumes: + curr_tile_shape = tile_shape[-2:] # NOTE: We expect 2d tile shape for this stage. + + curr_image_crops, curr_gt_crops = [], [] + for per_z_im, per_z_gt in zip(image, gt): + curr_skip_smaller_shape = (np.array(per_z_im.shape) >= np.array(curr_tile_shape)).all() + + image_crops, gt_crops = _get_crops_for_input( + image=per_z_im, gt=per_z_gt, ndim=2, + tile_shape=curr_tile_shape, + skip_smaller_shape=curr_skip_smaller_shape, + crops_per_input=crops_per_input, + ) + curr_image_crops.extend(image_crops) + curr_gt_crops.extend(gt_crops) + + image_counter = _save_image_label_crops( + curr_image_crops, curr_gt_crops, dataset_choice, 2, image_counter, save_image_dir[1], save_gt_dir[1] + ) + + return ndim + + +def _get_crops_for_input(image, gt, ndim, tile_shape, skip_smaller_shape, crops_per_input): + tiling = blocking([0] * ndim, gt.shape, tile_shape) + n_tiles = tiling.numberOfBlocks + tiles = [tiling.getBlock(tile_id) for tile_id in range(n_tiles)] + crop_boxes = [ + tuple(slice(beg, end) for beg, end in zip(tile.begin, tile.end)) for tile in tiles + ] + n_ids = [idx for idx in range(len(crop_boxes))] + n_instances = [len(np.unique(gt[crop])) for crop in crop_boxes] + + # Extract the desired number of patches with higher number of instances. + image_crops, gt_crops = [], [] + for i, (per_n_instance, per_id) in enumerate(sorted(zip(n_instances, n_ids), reverse=True), start=1): + crop_box = crop_boxes[per_id] + crop_image, crop_gt = image[crop_box], gt[crop_box] + # NOTE: We avoid using the crops which do not match the desired tile shape. + if skip_smaller_shape and crop_image.shape != tile_shape: + continue + + # NOTE: There could be a case where some later patches are invalid. + if per_n_instance == 1: + break + + image_crops.append(crop_image) + gt_crops.append(crop_gt) + + # NOTE: If the number of patches extracted have been fulfiled, we stop sampling patches. + if len(image_crops) > 0 and i >= crops_per_input: + break + + return image_crops, gt_crops + + +def _save_image_label_crops(image_crops, gt_crops, dataset_choice, ndim, image_counter, save_image_dir, save_gt_dir): + for image_crop, gt_crop in tqdm( + zip(image_crops, gt_crops), total=len(image_crops), desc=f"Saving {ndim}d crops for {dataset_choice}" + ): + fname = f"{dataset_choice}_{image_counter:05}.tif" + assert image_crop.shape == gt_crop.shape + imageio.imwrite(os.path.join(save_image_dir, fname), image_crop, compression="zlib") + imageio.imwrite(os.path.join(save_gt_dir, fname), gt_crop, compression="zlib") + image_counter += 1 + + return image_counter + + +def _get_image_label_paths(path, ndim): + image_paths = natsorted(glob(os.path.join(path, f"roi_{ndim}d", "inputs", "*"))) + gt_paths = natsorted(glob(os.path.join(path, f"roi_{ndim}d", "labels", "*"))) + return image_paths, gt_paths + + +def _run_automatic_segmentation_per_dataset( + image_paths: List[Union[os.PathLike, str]], + gt_paths: List[Union[os.PathLike, str]], + model_type: str, + output_folder: Union[os.PathLike, str], + ndim: Optional[int] = None, + device: Optional[Union[torch.device, str]] = None, + checkpoint_path: Optional[Union[os.PathLike, str]] = None, + run_amg: bool = False, + **auto_seg_kwargs +): + """Functionality to run automatic segmentation for multiple input files at once. + It stores the evaluated automatic segmentation results (quantitative). + + Args: + image_paths: List of filepaths for the input image data. + gt_paths: List of filepaths for the corresponding label data. + model_type: The choice of image encoder for the Segment Anything model. + output_folder: Filepath to the folder where we store all the results. + ndim: The number of input dimensions. + device: The torch device. + checkpoint_path: The filepath where the model checkpoints are stored. + run_amg: Whether to run automatic segmentation in AMG mode. + auto_seg_kwargs: Additional arguments for automatic segmentation parameters. + """ + experiment_name = "AMG" if run_amg else "AIS" + fname = f"{experiment_name.lower()}_{ndim}d" + + result_path = os.path.join(output_folder, "results", f"{fname}.csv") + prediction_dir = os.path.join(output_folder, fname, "inference") + if os.path.exists(prediction_dir): + return + + os.makedirs(prediction_dir, exist_ok=True) + + # Get the predictor (and the additional instance segmentation decoder, if available). + predictor, segmenter = get_predictor_and_segmenter( + model_type=model_type, checkpoint=checkpoint_path, device=device, amg=run_amg, is_tiled=False, + ) + + for image_path in tqdm(image_paths, desc=f"Run {experiment_name} in {ndim}d"): + output_path = os.path.join(prediction_dir, os.path.basename(image_path)) + if os.path.exists(output_path): + continue + + # Run Automatic Segmentation (AMG and AIS) + automatic_instance_segmentation( + predictor=predictor, + segmenter=segmenter, + input_path=image_path, + output_path=output_path, + ndim=ndim, + verbose=False, + **auto_seg_kwargs + ) + + prediction_paths = natsorted(glob(os.path.join(prediction_dir, "*"))) + run_evaluation(gt_paths=gt_paths, prediction_paths=prediction_paths, save_path=result_path) + + +def _run_interactive_segmentation_per_dataset( + image_paths: List[Union[os.PathLike, str]], + gt_paths: List[Union[os.PathLike, str]], + output_folder: Union[os.PathLike, str], + model_type: str, + prompt_choice: Literal["box", "points"], + device: Optional[Union[torch.device, str]] = None, + ndim: Optional[int] = None, + checkpoint_path: Optional[Union[os.PathLike, str]] = None, +): + """Functionality to run interactive segmentation for multiple input files at once. + It stores the evaluated interactive segmentation results. + + Args: + image_paths: List of filepaths for the input image data. + gt_paths: List of filepaths for the corresponding label data. + output_folder: Filepath to the folder where we store all the results. + model_type: The choice of model type for Segment Anything. + prompt_choice: The choice of initial prompts to begin the interactive segmentation. + device: The torch device. + ndim: The number of input dimensions. + checkpoint_path: The filepath for stored checkpoints. + """ + if ndim == 2: + # Get the Segment Anything predictor. + predictor = util.get_sam_model(model_type=model_type, device=device, checkpoint_path=checkpoint_path) + + # Run interactive instance segmentation + # (starting with box and points followed by iterative prompt-based correction) + run_inference_with_iterative_prompting( + predictor=predictor, + image_paths=image_paths, + gt_paths=gt_paths, + embedding_dir=None, # We set this to None to compute embeddings on-the-fly. + prediction_dir=os.path.join(output_folder, "interactive_segmentation_2d", f"start_with_{prompt_choice}"), + start_with_box_prompt=(prompt_choice == "box"), + # TODO: add parameter for deform over box prompts (to simulate prompts in practice). + ) + + # Evaluate the interactive instance segmentation. + run_evaluation_for_iterative_prompting( + gt_paths=gt_paths, + prediction_root=os.path.join(output_folder, "interactive_segmentation_2d", f"start_with_{prompt_choice}"), + experiment_folder=output_folder, + start_with_box_prompt=(prompt_choice == "box"), + ) + + else: + save_path = os.path.join(output_folder, "results", f"interactive_segmentation_3d_with_{prompt_choice}.csv") + if os.path.exists(save_path): + print( + f"Results for 3d interactive segmentation with '{prompt_choice}' are already stored at '{save_path}'." + ) + return + + results = [] + for image_path, gt_path in tqdm( + zip(image_paths, gt_paths), total=len(image_paths), + desc=f"Run interactive segmentation in 3d with '{prompt_choice}'" + ): + prediction_dir = os.path.join(output_folder, "interactive_segmentation_3d", f"{prompt_choice}") + os.makedirs(prediction_dir, exist_ok=True) + + prediction_path = os.path.join(prediction_dir, os.path.basename(image_path)) + if os.path.exists(prediction_path): + continue + + per_vol_result = segment_slices_from_ground_truth( + volume=imageio.imread(image_path), + ground_truth=imageio.imread(gt_path), + model_type=model_type, + checkpoint_path=checkpoint_path, + save_path=prediction_path, + device=device, + interactive_seg_mode=prompt_choice, + min_size=10, + ) + results.append(per_vol_result) + + results = pd.concat(results) + results = results.groupby(results.index).mean() + results.to_csv(save_path) + + +def _run_benchmark_evaluation_series( + image_paths, gt_paths, model_type, output_folder, ndim, device, checkpoint_path, run_amg, +): + seg_kwargs = { + "image_paths": image_paths, + "gt_paths": gt_paths, + "output_folder": output_folder, + "ndim": ndim, + "model_type": model_type, + "device": device, + "checkpoint_path": checkpoint_path, + } + + # Perform: + # a. automatic segmentation (supported in both 2d and 3d, wherever relevant) + # The automatic segmentation steps below are configured in a way that AIS has priority (if decoder is found) + # Else, it runs for AMG. + # Next, we check if the user expects to run AMG as well (after the run for AIS). + + # i. Run automatic segmentation method supported with the SAM model (AMG or AIS). + _run_automatic_segmentation_per_dataset(run_amg=False, **seg_kwargs) + + # ii. Run automatic mask generation (AMG) (in case the first run is AIS). + _run_automatic_segmentation_per_dataset(run_amg=run_amg, **seg_kwargs) + + # b. Run interactive segmentation (supported in both 2d and 3d, wherever relevant) + _run_interactive_segmentation_per_dataset(prompt_choice="box", **seg_kwargs) + _run_interactive_segmentation_per_dataset(prompt_choice="points", **seg_kwargs) + + +def _clear_cached_items(retain, path, output_folder): + import shutil + from pathlib import Path + + REMOVE_LIST = ["data", "crops", "auto", "int"] + if retain is None: + remove_list = REMOVE_LIST + else: + assert isinstance(retain, list) + remove_list = set(REMOVE_LIST) - set(retain) + + paths = [] + # Stage 1: Remove inputs. + if "data" in remove_list or "crops" in remove_list: + all_paths = glob(os.path.join(path, "*")) + + # In case we want to remove both data and crops, we remove the data folder entirely. + if "data" in remove_list and "crops" in remove_list: + paths.extend(all_paths) + return + + # Next, we verify whether the we only remove either of data or crops. + for curr_path in all_paths: + if os.path.basename(curr_path).startswith("roi") and "crops" in remove_list: + paths.append(curr_path) + elif "data" in remove_list: + paths.append(curr_path) + + # Stage 2: Remove predictions + if "auto" in remove_list: + paths.extend(glob(os.path.join(output_folder, "amg_*"))) + paths.extend(glob(os.path.join(output_folder, "ais_*"))) + + if "int" in remove_list: + paths.extend(glob(os.path.join(output_folder, "interactive_segmentation_*"))) + + [shutil.rmtree(_path) if Path(_path).is_dir() else os.remove(_path) for _path in paths] + + +def run_benchmark_evaluations( + input_folder: Union[os.PathLike, str], + dataset_choice: str, + model_type: str = util._DEFAULT_MODEL, + output_folder: Optional[Union[str, os.PathLike]] = None, + checkpoint_path: Optional[Union[str, os.PathLike]] = None, + run_amg: bool = False, + retain: Optional[List[str]] = None, + ignore_warnings: bool = False, +): + """Run evaluation for benchmarking Segment Anything models on microscopy datasets. + + Args: + input_folder: The path to directory where all inputs will be stored and preprocessed. + dataset_choice: The dataset choice. + model_type: The model choice for SAM. + output_folder: The path to directory where all outputs will be stored. + checkpoint_path: The checkpoint path + run_amg: Whether to run automatic segmentation in AMG mode. + retain: Whether to retain certain parts of the benchmark runs. + By default, removes everything besides quantitative results. + There is the choice to retain 'data', 'crops', 'auto', or 'int'. + ignore_warnings: Whether to ignore warnings. + """ + start = time.time() + + with _filter_warnings(ignore_warnings): + device = util._get_default_device() + + # Ensure if all the datasets have been installed by default. + dataset_choice = _download_benchmark_datasets(path=input_folder, dataset_choice=dataset_choice) + + for choice in dataset_choice: + output_folder = os.path.join(output_folder, choice) + result_dir = os.path.join(output_folder, "results") + if os.path.exists(result_dir): + continue + + os.makedirs(result_dir, exist_ok=True) + + data_path = os.path.join(input_folder, choice) + + # Extrapolate desired set from the datasets: + # a. for 2d datasets - 2d patches with the most number of labels present + # (in case of volumetric data, choose 2d patches per slice). + # b. for 3d datasets - 3d regions of interest with the most number of labels present. + ndim = _extract_slices_from_dataset(path=data_path, dataset_choice=choice, crops_per_input=10) + + # Run inference and evaluation scripts on benchmark datasets. + image_paths, gt_paths = _get_image_label_paths(path=data_path, ndim=ndim) + _run_benchmark_evaluation_series( + image_paths, gt_paths, model_type, output_folder, ndim, device, checkpoint_path, run_amg + ) + + # Run inference and evaluation scripts on '2d' crops for volumetric datasets + if ndim == 3: + image_paths, gt_paths = _get_image_label_paths(path=data_path, ndim=2) + _run_benchmark_evaluation_series( + image_paths, gt_paths, model_type, output_folder, 2, device, checkpoint_path, run_amg + ) + + _clear_cached_items(retain=retain, path=data_path, output_folder=output_folder) + + diff = time.time() - start + hours, rest = divmod(diff, 3600) + minutes, seconds = divmod(rest, 60) + print("Time taken for running benchmarks: ", f"{int(hours)}h {int(minutes)}m {seconds:.2f}s") + + +def main(): + """@private""" + import argparse + + available_models = list(util.get_model_names()) + available_models = ", ".join(available_models) + + parser = argparse.ArgumentParser( + description="Run evaluation for benchmarking Segment Anything models on microscopy datasets." + ) + parser.add_argument( + "-i", "--input_folder", type=str, required=True, + help="The path to a directory where the microscopy datasets are / will be stored." + ) + parser.add_argument( + "-m", "--model_type", type=str, default=util._DEFAULT_MODEL, + help=f"The segment anything model that will be used, one of {available_models}." + ) + parser.add_argument( + "-c", "--checkpoint_path", type=str, default=None, + help="Checkpoint from which the SAM model will be loaded loaded." + ) + parser.add_argument( + "-d", "--dataset_choice", type=str, nargs='*', default=None, + help="The choice(s) of dataset for evaluating SAM models. Multiple datasets can be specified." + ) + parser.add_argument( + "-o", "--output_folder", type=str, required=True, + help="The path where the results for automatic and interactive instance segmentation will be stored as 'csv'." + ) + parser.add_argument( + "--amg", action="store_true", + help="Whether to run automatic segmentation in AMG mode (i.e. the default auto-seg approach for SAM)." + ) + parser.add_argument( + "--retain", nargs="*", default=None, + help="By default, the functionality removes all besides quantitative results required for running benchmarks. " + "In case you would like to retain parts of the benchmark evaluation for visualization / reproducability, " + "you should choose one or multiple of 'data', 'crops', 'auto', 'int'. " + "where they are responsible for either retaining original inputs / extracted crops / " + "predictions of automatic segmentation / predictions of interactive segmentation, respectively." + ) + args = parser.parse_args() + + run_benchmark_evaluations( + input_folder=args.input_folder, + dataset_choice=args.dataset_choice, + model_type=args.model_type, + output_folder=args.output_folder, + checkpoint_path=args.checkpoint_path, + run_amg=args.amg, + retain=args.retain, + ignore_warnings=True, + ) + + +if __name__ == "__main__": + main() diff --git a/micro_sam/evaluation/evaluation.py b/micro_sam/evaluation/evaluation.py index a52a11266..869334fc1 100644 --- a/micro_sam/evaluation/evaluation.py +++ b/micro_sam/evaluation/evaluation.py @@ -62,9 +62,7 @@ def run_evaluation( msas, sa50s, sa75s = _run_evaluation(gt_paths, prediction_paths, verbose=verbose) results = pd.DataFrame.from_dict({ - "msa": [np.mean(msas)], - "sa50": [np.mean(sa50s)], - "sa75": [np.mean(sa75s)], + "mSA": [np.mean(msas)], "SA50": [np.mean(sa50s)], "SA75": [np.mean(sa75s)], }) if save_path is not None: @@ -110,7 +108,7 @@ def run_evaluation_for_iterative_prompting( # If the results have been computed already, it's not needed to re-run it again. if os.path.exists(csv_path): - print(pd.read_csv(csv_path)) + print(f"Results with iterative prompting for interactive segmentation are already stored at '{csv_path}'.") return list_of_results = [] @@ -120,7 +118,6 @@ def run_evaluation_for_iterative_prompting( pred_paths = sorted(glob(os.path.join(pred_folder, "*"))) result = run_evaluation(gt_paths=gt_paths, prediction_paths=pred_paths, save_path=None) list_of_results.append(result) - print(result) res_df = pd.concat(list_of_results, ignore_index=True) res_df.to_csv(csv_path) diff --git a/micro_sam/evaluation/inference.py b/micro_sam/evaluation/inference.py index e1736fa5d..b033055f4 100644 --- a/micro_sam/evaluation/inference.py +++ b/micro_sam/evaluation/inference.py @@ -473,21 +473,21 @@ def run_inference_with_iterative_prompting( gt_paths: List[Union[str, os.PathLike]], embedding_dir: Union[str, os.PathLike], prediction_dir: Union[str, os.PathLike], - start_with_box_prompt: bool, + start_with_box_prompt: bool = True, dilation: int = 5, batch_size: int = 32, n_iterations: int = 8, use_masks: bool = False ) -> None: - """Run segment anything inference for multiple images using prompts iteratively - derived from model outputs and groundtruth + """Run Segment Anything inference for multiple images using prompts iteratively + derived from model outputs and ground-truth. Args: - predictor: The SegmentAnything predictor. + predictor: The Segment Anything predictor. image_paths: The image file paths. gt_paths: The ground-truth segmentation file paths. embedding_dir: The directory where the image embeddings will be saved or are already saved. - prediction_dir: The directory where the predictions from SegmentAnything will be saved per iteration. + prediction_dir: The directory where the predictions from Segment Anything will be saved per iteration. start_with_box_prompt: Whether to use the first prompt as bounding box or a single point dilation: The dilation factor for the radius around the ground-truth object around which points will not be sampled. @@ -506,8 +506,7 @@ def run_inference_with_iterative_prompting( print("The iterative prompting will make use of logits masks from previous iterations.") for image_path, gt_path in tqdm( - zip(image_paths, gt_paths), - total=len(image_paths), + zip(image_paths, gt_paths), total=len(image_paths), desc="Run inference with iterative prompting for all images", ): image_name = os.path.basename(image_path) @@ -524,7 +523,10 @@ def run_inference_with_iterative_prompting( gt = imageio.imread(gt_path).astype("uint32") gt = relabel_sequential(gt)[0] - embedding_path = os.path.join(embedding_dir, f"{os.path.splitext(image_name)[0]}.zarr") + if embedding_dir is None: + embedding_path = None + else: + embedding_path = os.path.join(embedding_dir, f"{os.path.splitext(image_name)[0]}.zarr") _run_inference_with_iterative_prompting_for_image( predictor, image, gt, start_with_box_prompt=start_with_box_prompt, diff --git a/micro_sam/evaluation/multi_dimensional_segmentation.py b/micro_sam/evaluation/multi_dimensional_segmentation.py index e54cafb59..07b5820f0 100644 --- a/micro_sam/evaluation/multi_dimensional_segmentation.py +++ b/micro_sam/evaluation/multi_dimensional_segmentation.py @@ -6,6 +6,8 @@ from itertools import product from typing import Union, Tuple, Optional, List, Dict +import imageio.v3 as imageio + import torch from elf.evaluation import mean_segmentation_accuracy @@ -58,8 +60,9 @@ def segment_slices_from_ground_truth( volume: np.ndarray, ground_truth: np.ndarray, model_type: str, - checkpoint_path: Union[str, os.PathLike], - embedding_path: Union[str, os.PathLike], + checkpoint_path: Optional[Union[str, os.PathLike]] = None, + embedding_path: Optional[Union[str, os.PathLike]] = None, + save_path: Optional[Union[str, os.PathLike]] = None, iou_threshold: float = 0.8, projection: Union[str, dict] = "mask", box_extension: Union[float, int] = 0.025, @@ -81,6 +84,7 @@ def segment_slices_from_ground_truth( model_type: Choice of segment anything model. checkpoint_path: Path to the model checkpoint. embedding_path: Path to cache the computed embeddings. + save_path: Path to store the segmentations. iou_threshold: The criterion to decide whether to link the objects in the consecutive slice's segmentation. projection: The projection (prompting) method to generate prompts for consecutive slices. box_extension: Extension factor for increasing the box size after projection. @@ -97,7 +101,7 @@ def segment_slices_from_ground_truth( # Compute the image embeddings embeddings = util.precompute_image_embeddings( - predictor=predictor, input_=volume, save_path=embedding_path, ndim=3 + predictor=predictor, input_=volume, save_path=embedding_path, ndim=3, verbose=verbose, ) # Compute instance ids (without the background) @@ -133,7 +137,7 @@ def segment_slices_from_ground_truth( _get_points, _get_box = False, True else: raise ValueError( - "The provided interactive prompting for the first slice isn't supported.", + f"The provided interactive prompting '{interactive_seg_mode}' for the first slice isn't supported." "Please choose from 'box' / 'points'." ) @@ -145,14 +149,20 @@ def segment_slices_from_ground_truth( get_box_prompts=_get_box ) _, box_coords = util.get_centers_and_bounding_boxes(this_slice_seg) - point_prompts, point_labels, box_prompts, _ = prompt_generator(this_slice_seg, [box_coords[1]]) + point_prompts, point_labels, box_prompts, _ = prompt_generator( + segmentation=torch.from_numpy(this_slice_seg)[None, None].to(torch.float32), + bbox_coordinates=[box_coords[1]], + ) # Prompt-based segmentation on middle slice of the current object output_slice = batched_inference( - predictor=predictor, image=volume[slice_choice], batch_size=1, + predictor=predictor, + image=volume[slice_choice], + batch_size=1, boxes=box_prompts.numpy() if isinstance(box_prompts, torch.Tensor) else box_prompts, points=point_prompts.numpy() if isinstance(point_prompts, torch.Tensor) else point_prompts, - point_labels=point_labels.numpy() if isinstance(point_labels, torch.Tensor) else point_labels + point_labels=point_labels.numpy() if isinstance(point_labels, torch.Tensor) else point_labels, + verbose_embeddings=verbose, ) output_seg = np.zeros_like(ground_truth) output_seg[slice_choice][output_slice == 1] = 1 @@ -173,18 +183,25 @@ def segment_slices_from_ground_truth( # Store the entire segmented object final_segmentation[this_seg == 1] = label_id + # Save the volumetric segmentation + if save_path is not None: + imageio.imwrite(save_path, final_segmentation, compression="zlib") + # Evaluate the volumetric segmentation if skipped_label_ids: - gt_copy = ground_truth.copy() - gt_copy[np.isin(gt_copy, skipped_label_ids)] = 0 - msa = mean_segmentation_accuracy(final_segmentation, gt_copy) + curr_gt = ground_truth.copy() + curr_gt[np.isin(curr_gt, skipped_label_ids)] = 0 else: - msa = mean_segmentation_accuracy(final_segmentation, ground_truth) + curr_gt = ground_truth + + msa, sa = mean_segmentation_accuracy(final_segmentation, curr_gt, return_accuracies=True) + results = {"mSA": msa, "SA50": sa[0], "SA75": sa[5]} + results = pd.DataFrame.from_dict([results]) if return_segmentation: - return msa, final_segmentation + return results, final_segmentation else: - return msa + return results def _get_best_parameters_from_grid_search_combinations(result_dir, best_params_path, grid_search_values): @@ -266,7 +283,7 @@ def run_multi_dimensional_segmentation_grid_search( net_list = [] for gs_kwargs in tqdm(gs_combinations): - msa = segment_slices_from_ground_truth( + results = segment_slices_from_ground_truth( volume=volume, ground_truth=ground_truth, model_type=model_type, @@ -279,7 +296,7 @@ def run_multi_dimensional_segmentation_grid_search( **gs_kwargs ) - result_dict = {"mSA": msa, **gs_kwargs} + result_dict = {**results, **gs_kwargs} tmp_df = pd.DataFrame([result_dict]) net_list.append(tmp_df) diff --git a/micro_sam/multi_dimensional_segmentation.py b/micro_sam/multi_dimensional_segmentation.py index fd44ef64b..7ecaf14e6 100644 --- a/micro_sam/multi_dimensional_segmentation.py +++ b/micro_sam/multi_dimensional_segmentation.py @@ -397,7 +397,13 @@ def automatic_3d_segmentation( min_object_size = kwargs.pop("min_object_size", 0) image_embeddings = util.precompute_image_embeddings( - predictor=predictor, input_=volume, save_path=embedding_path, ndim=3, tile_shape=tile_shape, halo=halo, + predictor=predictor, + input_=volume, + save_path=embedding_path, + ndim=3, + tile_shape=tile_shape, + halo=halo, + verbose=verbose, ) for i in tqdm(range(segmentation.shape[0]), desc="Segment slices", disable=not verbose): @@ -415,7 +421,12 @@ def automatic_3d_segmentation( segmentation[i] = seg segmentation = merge_instance_segmentation_3d( - segmentation, beta=0.5, with_background=with_background, gap_closing=gap_closing, min_z_extent=min_z_extent + segmentation, + beta=0.5, + with_background=with_background, + gap_closing=gap_closing, + min_z_extent=min_z_extent, + verbose=verbose, ) return segmentation diff --git a/micro_sam/prompt_generators.py b/micro_sam/prompt_generators.py index 839077410..df521e4fb 100644 --- a/micro_sam/prompt_generators.py +++ b/micro_sam/prompt_generators.py @@ -191,12 +191,8 @@ def _sample_points(self, segmentation, bbox_coordinates, center_coordinates): center_coordinates = [None] * len(segmentation) if center_coordinates is None else center_coordinates for object_mask, bbox_coords, center_coords in zip(segmentation, bbox_coordinates, center_coordinates): coord_list, label_list = [], [] - coord_list, label_list = self._sample_positive_points( - object_mask[0], center_coords, coord_list, label_list - ) - coord_list, label_list = self._sample_negative_points( - object_mask[0], bbox_coords, coord_list, label_list - ) + coord_list, label_list = self._sample_positive_points(object_mask[0], center_coords, coord_list, label_list) + coord_list, label_list = self._sample_negative_points(object_mask[0], bbox_coords, coord_list, label_list) coord_list, label_list = self._ensure_num_points(object_mask[0], coord_list, label_list) all_coords.append(coord_list) diff --git a/micro_sam/training/training.py b/micro_sam/training/training.py index 165a10ae9..6c66ccb39 100644 --- a/micro_sam/training/training.py +++ b/micro_sam/training/training.py @@ -235,8 +235,8 @@ def train_sam( t_start = time.time() - _check_loader(train_loader, with_segmentation_decoder, verify_n_labels_in_loader) - _check_loader(val_loader, with_segmentation_decoder, verify_n_labels_in_loader) + _check_loader(train_loader, with_segmentation_decoder, "train", verify_n_labels_in_loader) + _check_loader(val_loader, with_segmentation_decoder, "val", verify_n_labels_in_loader) device = get_device(device) # Get the trainable segment anything model. diff --git a/setup.cfg b/setup.cfg index 2c6b5e38b..d7f976b28 100644 --- a/setup.cfg +++ b/setup.cfg @@ -49,6 +49,7 @@ console_scripts = micro_sam.image_series_annotator = micro_sam.sam_annotator.image_series_annotator:main micro_sam.precompute_embeddings = micro_sam.precompute_state:main micro_sam.automatic_segmentation = micro_sam.automatic_segmentation:main + micro_sam.benchmark_sam = micro_sam.evaluation.benchmark_datasets:main # make sure it gets included in your package [options.package_data] diff --git a/test/test_automatic_segmentation.py b/test/test_automatic_segmentation.py index 47b460f45..e0bb4287a 100644 --- a/test/test_automatic_segmentation.py +++ b/test/test_automatic_segmentation.py @@ -66,89 +66,92 @@ def tearDown(self): torch.mps.empty_cache() def test_automatic_mask_generator_2d(self): - from micro_sam.automatic_segmentation import automatic_instance_segmentation + from micro_sam.automatic_segmentation import automatic_instance_segmentation, get_predictor_and_segmenter mask, image = self.mask, self.image + predictor, segmenter = get_predictor_and_segmenter( + model_type=self.model_type, amg=True, is_tiled=False, amg_kwargs={"points_per_side": 4} + ) instances = automatic_instance_segmentation( - input_path=image, model_type=self.model_type, ndim=2, use_amg=True, - amg_kwargs={"points_per_side": 4} + predictor=predictor, segmenter=segmenter, input_path=image, ndim=2, ) self.assertEqual(mask.shape, instances.shape) def test_tiled_automatic_mask_generator_2d(self): - from micro_sam.automatic_segmentation import automatic_instance_segmentation + from micro_sam.automatic_segmentation import automatic_instance_segmentation, get_predictor_and_segmenter mask, image = self.large_mask, self.large_image + predictor, segmenter = get_predictor_and_segmenter( + model_type=self.model_type, amg=True, is_tiled=True, amg_kwargs={"points_per_side": 4} + ) instances = automatic_instance_segmentation( - input_path=image, - model_type=self.model_type, - ndim=2, - tile_shape=self.tile_shape, - halo=self.halo, - use_amg=True, - amg_kwargs={"points_per_side": 4} + predictor=predictor, segmenter=segmenter, input_path=image, + ndim=2, tile_shape=self.tile_shape, halo=self.halo, ) self.assertEqual(mask.shape, instances.shape) def test_instance_segmentation_with_decoder_2d(self): - from micro_sam.automatic_segmentation import automatic_instance_segmentation + from micro_sam.automatic_segmentation import automatic_instance_segmentation, get_predictor_and_segmenter mask, image = self.mask, self.image + predictor, segmenter = get_predictor_and_segmenter(model_type=self.model_type, amg=False, is_tiled=False) instances = automatic_instance_segmentation( - input_path=image, model_type=self.model_type_ais, ndim=2 + predictor=predictor, segmenter=segmenter, input_path=image, ndim=2, ) self.assertEqual(mask.shape, instances.shape) def test_tiled_instance_segmentation_with_decoder_2d(self): - from micro_sam.automatic_segmentation import automatic_instance_segmentation + from micro_sam.automatic_segmentation import automatic_instance_segmentation, get_predictor_and_segmenter mask, image = self.large_mask, self.large_image + predictor, segmenter = get_predictor_and_segmenter(model_type=self.model_type, amg=False, is_tiled=True) instances = automatic_instance_segmentation( - input_path=image, model_type=self.model_type_ais, + predictor=predictor, segmenter=segmenter, input_path=image, ndim=2, tile_shape=self.tile_shape, halo=self.halo, ) self.assertEqual(mask.shape, instances.shape) @unittest.skip("Skipping long running tests by default.") def test_automatic_mask_generator_3d(self): - from micro_sam.automatic_segmentation import automatic_instance_segmentation + from micro_sam.automatic_segmentation import automatic_instance_segmentation, get_predictor_and_segmenter labels, volume = self.labels, self.volume + predictor, segmenter = get_predictor_and_segmenter(model_type=self.model_type, amg=True, is_tiled=False) instances = automatic_instance_segmentation( - input_path=volume, model_type=self.model_type, ndim=3, use_amg=True + predictor=predictor, segmenter=segmenter, input_path=volume, ndim=3, ) self.assertEqual(labels.shape, instances.shape) @unittest.skip("Skipping long running tests by default.") def test_tiled_automatic_mask_generator_3d(self): - from micro_sam.automatic_segmentation import automatic_instance_segmentation + from micro_sam.automatic_segmentation import automatic_instance_segmentation, get_predictor_and_segmenter labels, volume = self.large_labels, self.large_volume + predictor, segmenter = get_predictor_and_segmenter(model_type=self.model_type, amg=True, is_tiled=True) instances = automatic_instance_segmentation( - input_path=volume, - model_type=self.model_type, - ndim=3, - tile_shape=self.tile_shape, - halo=self.halo, - use_amg=True, + predictor=predictor, segmenter=segmenter, input_path=volume, + ndim=3, tile_shape=self.tile_shape, halo=self.halo, ) self.assertEqual(labels.shape, instances.shape) def test_instance_segmentation_with_decoder_3d(self): - from micro_sam.automatic_segmentation import automatic_instance_segmentation + from micro_sam.automatic_segmentation import automatic_instance_segmentation, get_predictor_and_segmenter labels, volume = self.labels, self.volume + predictor, segmenter = get_predictor_and_segmenter(model_type=self.model_type, amg=False, is_tiled=False) instances = automatic_instance_segmentation( - input_path=volume, model_type=self.model_type_ais, ndim=3, + predictor=predictor, segmenter=segmenter, input_path=volume, ndim=3, ) self.assertEqual(labels.shape, instances.shape) def test_tiled_instance_segmentation_with_decoder_3d(self): - from micro_sam.automatic_segmentation import automatic_instance_segmentation + from micro_sam.automatic_segmentation import automatic_instance_segmentation, get_predictor_and_segmenter labels, volume = self.large_labels, self.large_volume + predictor, segmenter = get_predictor_and_segmenter(model_type=self.model_type, amg=False, is_tiled=True) instances = automatic_instance_segmentation( - input_path=volume, model_type=self.model_type_ais, ndim=3, tile_shape=self.tile_shape, halo=self.halo, + predictor=predictor, segmenter=segmenter, input_path=volume, + ndim=3, tile_shape=self.tile_shape, halo=self.halo, ) self.assertEqual(labels.shape, instances.shape) diff --git a/test/test_training.py b/test/test_training.py index 2ad809fda..d89a15978 100644 --- a/test/test_training.py +++ b/test/test_training.py @@ -125,7 +125,7 @@ def _run_inference_and_check_results( self.assertEqual(len(pred_paths), len(label_paths)) eval_res = evaluation.run_evaluation(label_paths, pred_paths, verbose=False) - result = eval_res["sa50"].values.item() + result = eval_res["SA50"].values.item() # We check against the expected segmentation accuracy. self.assertGreater(result, expected_sa) @@ -172,7 +172,7 @@ def test_training(self): ) self._run_inference_and_check_results( export_path, model_type, prediction_dir=prediction_dir, - inference_function=box_inference, expected_sa=0.95, + inference_function=box_inference, expected_sa=0.8, ) # Check the model with interactive inference. @@ -184,7 +184,7 @@ def test_training(self): ) self._run_inference_and_check_results( export_path, model_type, prediction_dir=prediction_dir, - inference_function=iterative_inference, expected_sa=0.95, + inference_function=iterative_inference, expected_sa=0.8, ) From cd4418a0860524c97fc6066dec28b48e2c312ad8 Mon Sep 17 00:00:00 2001 From: Anwai Archit <52396323+anwai98@users.noreply.github.com> Date: Tue, 15 Oct 2024 08:27:34 +0200 Subject: [PATCH 44/53] Expose is_seg_dataset argument in sam dataset (#736) Expose is_seg_dataset argument in sam dataset --- micro_sam/training/training.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/micro_sam/training/training.py b/micro_sam/training/training.py index 6c66ccb39..9fdf9edeb 100644 --- a/micro_sam/training/training.py +++ b/micro_sam/training/training.py @@ -389,6 +389,7 @@ def default_sam_dataset( is_train: bool = True, min_size: int = 25, max_sampling_attempts: Optional[int] = None, + is_seg_dataset: Optional[bool] = None, **kwargs, ) -> Dataset: """Create a PyTorch Dataset for training a SAM model. @@ -412,6 +413,8 @@ def default_sam_dataset( is_train: Whether this dataset is used for training or validation. min_size: Minimal object size. Smaller objects will be filtered. max_sampling_attempts: Number of sampling attempts to make from a dataset. + is_seg_dataset: Whether the dataset is built 'from torch_em.data import SegmentationDataset' + or 'from torch_em.data import ImageCollectionDataset' Returns: The dataset. @@ -443,8 +446,8 @@ def default_sam_dataset( # Set a minimum number of samples per epoch. if n_samples is None: loader = torch_em.default_segmentation_loader( - raw_paths, raw_key, label_paths, label_key, - batch_size=1, patch_shape=patch_shape, ndim=2 + raw_paths, raw_key, label_paths, label_key, batch_size=1, + patch_shape=patch_shape, ndim=2, is_seg_dataset=is_seg_dataset, ) n_samples = max(len(loader), 100 if is_train else 5) @@ -454,6 +457,7 @@ def default_sam_dataset( raw_transform=raw_transform, label_transform=label_transform, with_channels=with_channels, ndim=2, sampler=sampler, n_samples=n_samples, + is_seg_dataset=is_seg_dataset, **kwargs, ) From 896ea004e2a0562490d55afa0889ee11df4612e9 Mon Sep 17 00:00:00 2001 From: Constantin Pape Date: Wed, 16 Oct 2024 08:06:32 +0200 Subject: [PATCH 45/53] Update the image series annotator (#738) --- micro_sam/precompute_state.py | 2 +- .../sam_annotator/image_series_annotator.py | 145 +++++++++++------- micro_sam/sam_annotator/util.py | 2 + 3 files changed, 95 insertions(+), 54 deletions(-) diff --git a/micro_sam/precompute_state.py b/micro_sam/precompute_state.py index 520f2baec..736182603 100644 --- a/micro_sam/precompute_state.py +++ b/micro_sam/precompute_state.py @@ -164,7 +164,7 @@ def _precompute_state_for_file( # Precompute the image embeddings. output_path = Path(output_path).with_suffix(".zarr") embeddings = util.precompute_image_embeddings( - predictor, image_data, output_path, ndim=ndim, tile_shape=tile_shape, halo=halo, + predictor, image_data, output_path, ndim=ndim, tile_shape=tile_shape, halo=halo, verbose=False ) # Precompute the state for automatic instance segmnetaiton (AMG or AIS). diff --git a/micro_sam/sam_annotator/image_series_annotator.py b/micro_sam/sam_annotator/image_series_annotator.py index aec6ec028..cd5a15e2c 100644 --- a/micro_sam/sam_annotator/image_series_annotator.py +++ b/micro_sam/sam_annotator/image_series_annotator.py @@ -80,6 +80,46 @@ def _get_input_shape(image, is_volumetric=False): return image_shape +def _initialize_annotator( + viewer, image, image_embedding_path, + model_type, halo, tile_shape, predictor, decoder, is_volumetric, + precompute_amg_state, checkpoint_path, device, + embedding_path, +): + if viewer is None: + viewer = napari.Viewer() + viewer.add_image(image, name="image") + + state = AnnotatorState() + state.initialize_predictor( + image, model_type=model_type, save_path=image_embedding_path, halo=halo, tile_shape=tile_shape, + predictor=predictor, decoder=decoder, + ndim=3 if is_volumetric else 2, precompute_amg_state=precompute_amg_state, + checkpoint_path=checkpoint_path, device=device, skip_load=False, + ) + state.image_shape = _get_input_shape(image, is_volumetric) + + if is_volumetric: + if image.ndim not in [3, 4]: + raise ValueError(f"Invalid image dimensions for 3d annotator, expect 3 or 4 dimensions, got {image.ndim}") + annotator = Annotator3d(viewer) + else: + if image.ndim not in (2, 3): + raise ValueError(f"Invalid image dimensions for 2d annotator, expect 2 or 3 dimensions, got {image.ndim}") + annotator = Annotator2d(viewer) + + annotator._update_image() + + # Add the annotator widget to the viewer and sync widgets. + viewer.window.add_dock_widget(annotator) + _sync_embedding_widget( + state.widgets["embeddings"], model_type, + save_path=embedding_path, checkpoint_path=checkpoint_path, + device=device, tile_shape=tile_shape, halo=halo + ) + return viewer, annotator + + def image_series_annotator( images: Union[List[Union[os.PathLike, str]], List[np.ndarray]], output_folder: str, @@ -94,6 +134,7 @@ def image_series_annotator( is_volumetric: bool = False, device: Optional[Union[str, torch.device]] = None, prefer_decoder: bool = True, + skip_segmented: bool = True, ) -> Optional["napari.viewer.Viewer"]: """Run the annotation tool for a series of images (supported for both 2d and 3d images). @@ -116,13 +157,13 @@ def image_series_annotator( is_volumetric: Whether to use the 3d annotator. prefer_decoder: Whether to use decoder based instance segmentation if the model used has an additional decoder for instance segmentation. + skip_segmented: Whether to skip images that were already segmented. Returns: The napari viewer, only returned if `return_viewer=True`. """ - + end_msg = "You have annotated the last image. Do you wish to close napari?" os.makedirs(output_folder, exist_ok=True) - next_image_id = 0 # Precompute embeddings and amg state (if corresponding options set). predictor, decoder, embedding_paths = _precompute( @@ -132,57 +173,48 @@ def image_series_annotator( ndim=3 if is_volumetric else 2, prefer_decoder=prefer_decoder, ) - # Load the first image and intialize the viewer, annotator and state. - if isinstance(images[next_image_id], np.ndarray): - image = images[next_image_id] - have_inputs_as_arrays = True - else: - image = imageio.imread(images[next_image_id]) - have_inputs_as_arrays = False - - image_embedding_path = embedding_paths[next_image_id] - - if viewer is None: - viewer = napari.Viewer() - viewer.add_image(image, name="image") - - state = AnnotatorState() - state.initialize_predictor( - image, model_type=model_type, save_path=image_embedding_path, halo=halo, tile_shape=tile_shape, - predictor=predictor, decoder=decoder, - ndim=3 if is_volumetric else 2, precompute_amg_state=precompute_amg_state, - checkpoint_path=checkpoint_path, device=device, skip_load=False, - ) - state.image_shape = _get_input_shape(image, is_volumetric) - - if is_volumetric: - if image.ndim not in [3, 4]: - raise ValueError(f"Invalid image dimensions for 3d annotator, expect 3 or 4 dimensions, got {image.ndim}") - annotator = Annotator3d(viewer) - else: - if image.ndim not in (2, 3): - raise ValueError(f"Invalid image dimensions for 2d annotator, expect 2 or 3 dimensions, got {image.ndim}") - annotator = Annotator2d(viewer) - - annotator._update_image() - - # Add the annotator widget to the viewer and sync widgets. - viewer.window.add_dock_widget(annotator) - _sync_embedding_widget( - state.widgets["embeddings"], model_type, - save_path=embedding_path, checkpoint_path=checkpoint_path, - device=device, tile_shape=tile_shape, halo=halo - ) + next_image_id = 0 + have_inputs_as_arrays = isinstance(images[next_image_id], np.ndarray) - def _save_segmentation(image_path, current_idx, segmentation): + def _get_save_path(image_path, current_idx): if have_inputs_as_arrays: fname = f"seg_{current_idx:05}.tif" else: fname = os.path.basename(image_path) fname = os.path.splitext(fname)[0] + ".tif" + return os.path.join(output_folder, fname) + + # Check which image to load next if we skip segmented images. + image_embedding_path = None + if skip_segmented: + while True: + if next_image_id == len(images): + print(end_msg) + return - out_path = os.path.join(output_folder, fname) - imageio.imwrite(out_path, segmentation) + save_path = _get_save_path(images[next_image_id], next_image_id) + if not os.path.exists(save_path): + print("The first image to annotate is image number", next_image_id) + image = images[next_image_id] + if not have_inputs_as_arrays: + image = imageio.imread(image) + image_embedding_path = embedding_paths[next_image_id] + break + + next_image_id += 1 + + # Initialize the viewer and annotator for this image. + state = AnnotatorState() + viewer, annotator = _initialize_annotator( + viewer, image, image_embedding_path, + model_type, halo, tile_shape, predictor, decoder, is_volumetric, + precompute_amg_state, checkpoint_path, device, + embedding_path, + ) + + def _save_segmentation(image_path, current_idx, segmentation): + save_path = _get_save_path(image_path, next_image_id) + imageio.imwrite(save_path, segmentation, compression="zlib") # Add functionality for going to the next image. @magicgui(call_button="Next Image [N]") @@ -203,14 +235,20 @@ def next_image(*args): # Clear the segmentation already to avoid lagging removal. viewer.layers["committed_objects"].data = np.zeros_like(viewer.layers["committed_objects"].data) - # Load the next image. + # Go to the next images, if skipping images that are already segmented check if we have to load it. next_image_id += 1 + if skip_segmented: + save_path = _get_save_path(images[next_image_id], next_image_id) + while os.path.exists(save_path): + next_image_id += 1 + if next_image_id == len(images): + break + save_path = _get_save_path(images[next_image_id], next_image_id) + + # Load the next image. if next_image_id == len(images): - msg = "You have annotated the last image. Do you wish to close napari?" - print(msg) - abort = False - # inform the user via dialog - abort = widgets._generate_message("info", msg) + # Inform the user via dialog. + abort = widgets._generate_message("info", end_msg) if not abort: viewer.close() return @@ -459,6 +497,7 @@ def main(): ) parser.add_argument("--precompute_amg_state", action="store_true") parser.add_argument("--prefer_decoder", action="store_false") + parser.add_argument("--skip_segmented", action="store_false") args = parser.parse_args() @@ -467,5 +506,5 @@ def main(): embedding_path=args.embedding_path, model_type=args.model_type, tile_shape=args.tile_shape, halo=args.halo, precompute_amg_state=args.precompute_amg_state, checkpoint_path=args.checkpoint, device=args.device, is_volumetric=args.is_volumetric, - prefer_decoder=args.prefer_decoder, + prefer_decoder=args.prefer_decoder, skip_segmented=args.skip_segmented ) diff --git a/micro_sam/sam_annotator/util.py b/micro_sam/sam_annotator/util.py index db2ec5187..e0b3b88a5 100644 --- a/micro_sam/sam_annotator/util.py +++ b/micro_sam/sam_annotator/util.py @@ -118,6 +118,8 @@ def clear_annotations(viewer: napari.Viewer, clear_segmentations=True) -> None: viewer.layers["point_prompts"].data = [] viewer.layers["point_prompts"].refresh() if "prompts" in viewer.layers: + # Select all prompts and then remove them. + viewer.layers["prompts"].selected_data = set(range(len(viewer.layers["prompts"].data))) viewer.layers["prompts"].remove_selected() viewer.layers["prompts"].refresh() if not clear_segmentations: From 53117be2cfc7a019c32a1f166bdd6a305b14324b Mon Sep 17 00:00:00 2001 From: Constantin Pape Date: Fri, 18 Oct 2024 16:04:10 +0200 Subject: [PATCH 46/53] Disable thread workers to avoid slow-down (#739) --- micro_sam/sam_annotator/_widgets.py | 66 ++++++++++++++++---------- micro_sam/sam_annotator/training_ui.py | 14 +++--- micro_sam/sam_annotator/util.py | 2 + 3 files changed, 50 insertions(+), 32 deletions(-) diff --git a/micro_sam/sam_annotator/_widgets.py b/micro_sam/sam_annotator/_widgets.py index 2325c9a57..7164dbe4d 100644 --- a/micro_sam/sam_annotator/_widgets.py +++ b/micro_sam/sam_annotator/_widgets.py @@ -21,7 +21,10 @@ from superqt import QCollapsible from magicgui import magic_factory from magicgui.widgets import ComboBox, Container, create_widget -from napari.qt.threading import thread_worker +# We have disabled the thread workers for now because they result in a +# massive slowdown in napari >= 0.5. +# See also https://forum.image.sc/t/napari-thread-worker-leads-to-massive-slowdown/103786 +# from napari.qt.threading import thread_worker from napari.utils import progress from ._state import AnnotatorState @@ -1088,7 +1091,7 @@ def __call__(self, skip_validate=False): # Set up progress bar and signals for using it within a threadworker. pbar, pbar_signals = _create_pbar_for_threadworker() - @thread_worker() + # @thread_worker() def compute_image_embedding(): def pbar_init(total, description): @@ -1103,10 +1106,12 @@ def pbar_init(total, description): ) pbar_signals.pbar_stop.emit() - worker = compute_image_embedding() - worker.returned.connect(self._update_model) - worker.start() - return worker + compute_image_embedding() + self._update_model() + # worker = compute_image_embedding() + # worker.returned.connect(self._update_model) + # worker.start() + # return worker # @@ -1195,7 +1200,7 @@ def _run_tracking(self): state = AnnotatorState() pbar, pbar_signals = _create_pbar_for_threadworker() - @thread_worker + # @thread_worker def tracking_impl(): shape = state.image_shape @@ -1237,15 +1242,17 @@ def update_segmentation(ret_val): self._viewer.layers["current_object"].data[seg == 1] = state.current_track_id self._viewer.layers["current_object"].refresh() - worker = tracking_impl() - worker.returned.connect(update_segmentation) - worker.start() - return worker + ret_val = tracking_impl() + update_segmentation(ret_val) + # worker = tracking_impl() + # worker.returned.connect(update_segmentation) + # worker.start() + # return worker def _run_volumetric_segmentation(self): pbar, pbar_signals = _create_pbar_for_threadworker() - @thread_worker + # @thread_worker def volumetric_segmentation_impl(): state = AnnotatorState() shape = state.image_shape @@ -1277,10 +1284,13 @@ def update_segmentation(seg): self._viewer.layers["current_object"].data = seg self._viewer.layers["current_object"].refresh() - worker = volumetric_segmentation_impl() - worker.returned.connect(update_segmentation) - worker.start() - return worker + seg = volumetric_segmentation_impl() + self._viewer.layers["current_object"].data = seg + self._viewer.layers["current_object"].refresh() + # worker = volumetric_segmentation_impl() + # worker.returned.connect(update_segmentation) + # worker.start() + # return worker def __call__(self): if _validate_embeddings(self._viewer): @@ -1522,7 +1532,7 @@ def _empty_segmentation_warning(self): def _run_segmentation_2d(self, kwargs, i=None): pbar, pbar_signals = _create_pbar_for_threadworker() - @thread_worker + # @thread_worker def seg_impl(): def pbar_init(total, description): pbar_signals.pbar_total.emit(total) @@ -1548,10 +1558,12 @@ def update_segmentation(seg): self._viewer.layers["auto_segmentation"].data[i] = seg self._viewer.layers["auto_segmentation"].refresh() - worker = seg_impl() - worker.returned.connect(update_segmentation) - worker.start() - return worker + seg = seg_impl() + update_segmentation(seg) + # worker = seg_impl() + # worker.returned.connect(update_segmentation) + # worker.start() + # return worker # We refuse to run 3D segmentation with the AMG unless we have a GPU or all embeddings # are precomputed. Otherwise this would take too long. @@ -1578,7 +1590,7 @@ def _run_segmentation_3d(self, kwargs): pbar, pbar_signals = _create_pbar_for_threadworker() - @thread_worker + # @thread_worker def seg_impl(): segmentation = np.zeros_like(self._viewer.layers["auto_segmentation"].data) offset = 0 @@ -1617,10 +1629,12 @@ def update_segmentation(segmentation): self._viewer.layers["auto_segmentation"].data = segmentation self._viewer.layers["auto_segmentation"].refresh() - worker = seg_impl() - worker.returned.connect(update_segmentation) - worker.start() - return worker + seg = seg_impl() + update_segmentation(seg) + # worker = seg_impl() + # worker.returned.connect(update_segmentation) + # worker.start() + # return worker def __call__(self): if _validate_embeddings(self._viewer): diff --git a/micro_sam/sam_annotator/training_ui.py b/micro_sam/sam_annotator/training_ui.py index f36be0c96..35e9c8d7d 100644 --- a/micro_sam/sam_annotator/training_ui.py +++ b/micro_sam/sam_annotator/training_ui.py @@ -2,7 +2,7 @@ import warnings from qtpy import QtWidgets -from napari.qt.threading import thread_worker +# from napari.qt.threading import thread_worker import torch from torch.utils.data import random_split @@ -238,7 +238,7 @@ def __call__(self, skip_validate=False): else: checkpoint_path = self.checkpoint - @thread_worker() + # @thread_worker() def run_training(): train_loader, val_loader = self._get_loaders() train_sam_for_configuration( @@ -296,7 +296,9 @@ def run_training(): pbar_signals.pbar_stop.emit() return export_checkpoint - worker = run_training() - worker.returned.connect(lambda path: print(f"Training has finished. The trained model is saved at {path}.")) - worker.start() - return worker + path = run_training() + print(f"Training has finished. The trained model is saved at {path}.") + # worker = run_training() + # worker.returned.connect(lambda path: print(f"Training has finished. The trained model is saved at {path}.")) + # worker.start() + # return worker diff --git a/micro_sam/sam_annotator/util.py b/micro_sam/sam_annotator/util.py index e0b3b88a5..869f2aea2 100644 --- a/micro_sam/sam_annotator/util.py +++ b/micro_sam/sam_annotator/util.py @@ -119,6 +119,8 @@ def clear_annotations(viewer: napari.Viewer, clear_segmentations=True) -> None: viewer.layers["point_prompts"].refresh() if "prompts" in viewer.layers: # Select all prompts and then remove them. + # This is how it worked before napari 0.5. + # viewer.layers["prompts"].data = [] viewer.layers["prompts"].selected_data = set(range(len(viewer.layers["prompts"].data))) viewer.layers["prompts"].remove_selected() viewer.layers["prompts"].refresh() From c5f04d9bf3ef2592cc27acbe1519a47a83a1f032 Mon Sep 17 00:00:00 2001 From: Constantin Pape Date: Fri, 18 Oct 2024 16:45:17 +0200 Subject: [PATCH 47/53] Update syntax of get_predictor_and_segmenter and fix long running tests --- micro_sam/automatic_segmentation.py | 27 ++++++++++++++++----------- test/test_automatic_segmentation.py | 20 ++++++++++++-------- 2 files changed, 28 insertions(+), 19 deletions(-) diff --git a/micro_sam/automatic_segmentation.py b/micro_sam/automatic_segmentation.py index 2561d8e2b..e6d908f9e 100644 --- a/micro_sam/automatic_segmentation.py +++ b/micro_sam/automatic_segmentation.py @@ -16,9 +16,9 @@ def get_predictor_and_segmenter( model_type: str, checkpoint: Optional[Union[os.PathLike, str]] = None, device: str = None, - amg: bool = False, + amg: Optional[bool] = None, is_tiled: bool = False, - amg_kwargs: Dict = {} + **kwargs, ) -> Tuple[util.SamPredictor, Union[AMGBase, InstanceSegmentationWithDecoder]]: """Get the Segment Anything model and class for automatic instance segmentation. @@ -27,7 +27,10 @@ def get_predictor_and_segmenter( checkpoint: The filepath to the stored model checkpoints. device: The torch device. amg: Whether to perform automatic segmentation in AMG mode. + Otherwise AIS will be used, which requires a special segmentation decoder. + If not specified AIS will be used if it is available and otherwise AMG will be used. is_tiled: Whether to return segmenter for performing segmentation in tiling window style. + kwargs: Keyword arguments for the automatic instance segmentation class. Returns: The Segment Anything model. @@ -41,20 +44,22 @@ def get_predictor_and_segmenter( model_type=model_type, device=device, checkpoint_path=checkpoint, return_state=True, ) - # Get the segmenter for automatic segmentation. - assert isinstance(amg_kwargs, Dict), "Please ensure 'amg_kwargs' gets arguments in a dictionary." + if amg is None: + amg = "decoder_state" not in state + if amg: + decoder = None + else: + if "decoder_state" not in state: + raise RuntimeError("You have passed amg=False, but your model does not contain a segmentation decoder.") + decoder_state = state["decoder_state"] + decoder = get_decoder(image_encoder=predictor.model.image_encoder, decoder_state=decoder_state, device=device) segmenter = get_amg( predictor=predictor, is_tiled=is_tiled, - decoder=get_decoder( - image_encoder=predictor.model.image_encoder, - decoder_state=state["decoder_state"], - device=device - ) if "decoder_state" in state and not amg else None, - **amg_kwargs + decoder=decoder, + **kwargs ) - return predictor, segmenter diff --git a/test/test_automatic_segmentation.py b/test/test_automatic_segmentation.py index e0bb4287a..0e27fc960 100644 --- a/test/test_automatic_segmentation.py +++ b/test/test_automatic_segmentation.py @@ -70,7 +70,7 @@ def test_automatic_mask_generator_2d(self): mask, image = self.mask, self.image predictor, segmenter = get_predictor_and_segmenter( - model_type=self.model_type, amg=True, is_tiled=False, amg_kwargs={"points_per_side": 4} + model_type=self.model_type, amg=True, is_tiled=False, points_per_side=4 ) instances = automatic_instance_segmentation( predictor=predictor, segmenter=segmenter, input_path=image, ndim=2, @@ -82,7 +82,7 @@ def test_tiled_automatic_mask_generator_2d(self): mask, image = self.large_mask, self.large_image predictor, segmenter = get_predictor_and_segmenter( - model_type=self.model_type, amg=True, is_tiled=True, amg_kwargs={"points_per_side": 4} + model_type=self.model_type, amg=True, is_tiled=True, points_per_side=4, ) instances = automatic_instance_segmentation( predictor=predictor, segmenter=segmenter, input_path=image, @@ -94,7 +94,7 @@ def test_instance_segmentation_with_decoder_2d(self): from micro_sam.automatic_segmentation import automatic_instance_segmentation, get_predictor_and_segmenter mask, image = self.mask, self.image - predictor, segmenter = get_predictor_and_segmenter(model_type=self.model_type, amg=False, is_tiled=False) + predictor, segmenter = get_predictor_and_segmenter(model_type=self.model_type_ais, amg=False, is_tiled=False) instances = automatic_instance_segmentation( predictor=predictor, segmenter=segmenter, input_path=image, ndim=2, ) @@ -104,7 +104,7 @@ def test_tiled_instance_segmentation_with_decoder_2d(self): from micro_sam.automatic_segmentation import automatic_instance_segmentation, get_predictor_and_segmenter mask, image = self.large_mask, self.large_image - predictor, segmenter = get_predictor_and_segmenter(model_type=self.model_type, amg=False, is_tiled=True) + predictor, segmenter = get_predictor_and_segmenter(model_type=self.model_type_ais, amg=False, is_tiled=True) instances = automatic_instance_segmentation( predictor=predictor, segmenter=segmenter, input_path=image, ndim=2, tile_shape=self.tile_shape, halo=self.halo, @@ -116,7 +116,9 @@ def test_automatic_mask_generator_3d(self): from micro_sam.automatic_segmentation import automatic_instance_segmentation, get_predictor_and_segmenter labels, volume = self.labels, self.volume - predictor, segmenter = get_predictor_and_segmenter(model_type=self.model_type, amg=True, is_tiled=False) + predictor, segmenter = get_predictor_and_segmenter( + model_type=self.model_type, amg=True, is_tiled=False, points_per_side=4 + ) instances = automatic_instance_segmentation( predictor=predictor, segmenter=segmenter, input_path=volume, ndim=3, ) @@ -127,7 +129,9 @@ def test_tiled_automatic_mask_generator_3d(self): from micro_sam.automatic_segmentation import automatic_instance_segmentation, get_predictor_and_segmenter labels, volume = self.large_labels, self.large_volume - predictor, segmenter = get_predictor_and_segmenter(model_type=self.model_type, amg=True, is_tiled=True) + predictor, segmenter = get_predictor_and_segmenter( + model_type=self.model_type, amg=True, is_tiled=True, points_per_side=4 + ) instances = automatic_instance_segmentation( predictor=predictor, segmenter=segmenter, input_path=volume, ndim=3, tile_shape=self.tile_shape, halo=self.halo, @@ -138,7 +142,7 @@ def test_instance_segmentation_with_decoder_3d(self): from micro_sam.automatic_segmentation import automatic_instance_segmentation, get_predictor_and_segmenter labels, volume = self.labels, self.volume - predictor, segmenter = get_predictor_and_segmenter(model_type=self.model_type, amg=False, is_tiled=False) + predictor, segmenter = get_predictor_and_segmenter(model_type=self.model_type_ais, amg=False, is_tiled=False) instances = automatic_instance_segmentation( predictor=predictor, segmenter=segmenter, input_path=volume, ndim=3, ) @@ -148,7 +152,7 @@ def test_tiled_instance_segmentation_with_decoder_3d(self): from micro_sam.automatic_segmentation import automatic_instance_segmentation, get_predictor_and_segmenter labels, volume = self.large_labels, self.large_volume - predictor, segmenter = get_predictor_and_segmenter(model_type=self.model_type, amg=False, is_tiled=True) + predictor, segmenter = get_predictor_and_segmenter(model_type=self.model_type_ais, amg=False, is_tiled=True) instances = automatic_instance_segmentation( predictor=predictor, segmenter=segmenter, input_path=volume, ndim=3, tile_shape=self.tile_shape, halo=self.halo, From 7017a9da74ef8dc6acdfe7e64f855c005c9e8f0c Mon Sep 17 00:00:00 2001 From: Constantin Pape Date: Fri, 18 Oct 2024 16:55:11 +0200 Subject: [PATCH 48/53] Specify the correct python version in the mamba set-up --- .github/workflows/test.yaml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 35d6c92c0..b52fc27dd 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -30,6 +30,8 @@ jobs: uses: mamba-org/setup-micromamba@v1 with: environment-file: environment_cpu.yaml + create-args: >- + python=${{ matrix.python-version }} # Setup Qt libraries for GUI testing on Linux - uses: tlambert03/setup-qt-libs@v1 From b1c8a41bcf69e44bd9ad57f9db02a067aa2f4010 Mon Sep 17 00:00:00 2001 From: Constantin Pape Date: Fri, 18 Oct 2024 16:57:06 +0200 Subject: [PATCH 49/53] Don't pin nifty --- environment_cpu.yaml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/environment_cpu.yaml b/environment_cpu.yaml index d56682ff5..b8c68641b 100644 --- a/environment_cpu.yaml +++ b/environment_cpu.yaml @@ -4,7 +4,8 @@ channels: - conda-forge dependencies: - cpuonly - - nifty =1.2.1=*_4 + # - nifty =1.2.1=*_4 + - nifty - imagecodecs - magicgui - napari From 55d0f5b66991e6190742aeb1f306f32cf13137a2 Mon Sep 17 00:00:00 2001 From: Constantin Pape Date: Fri, 18 Oct 2024 17:18:34 +0200 Subject: [PATCH 50/53] Update test config and fix minor issues --- .github/workflows/test.yaml | 4 +++- environment_cpu.yaml | 4 ++-- environment_gpu.yaml | 1 + micro_sam/multi_dimensional_segmentation.py | 2 +- micro_sam/sam_annotator/_widgets.py | 6 +++--- micro_sam/sam_annotator/util.py | 4 ++-- 6 files changed, 12 insertions(+), 9 deletions(-) diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index b52fc27dd..779dbbbb8 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -20,7 +20,9 @@ jobs: fail-fast: false matrix: os: [ubuntu-latest, windows-latest, macos-latest] - python-version: ["3.11", "3.12"] + # 3.12 currently not supported due to issues with nifty. + # python-version: ["3.11", "3.12"] + python-version: ["3.11"] steps: - name: Checkout diff --git a/environment_cpu.yaml b/environment_cpu.yaml index b8c68641b..1ef9d909c 100644 --- a/environment_cpu.yaml +++ b/environment_cpu.yaml @@ -4,8 +4,8 @@ channels: - conda-forge dependencies: - cpuonly - # - nifty =1.2.1=*_4 - - nifty + # This pin is necessary because later nifty versions have import errors on windows. + - nifty =1.2.1=*_4 - imagecodecs - magicgui - napari diff --git a/environment_gpu.yaml b/environment_gpu.yaml index c77a834e1..c31773f3f 100644 --- a/environment_gpu.yaml +++ b/environment_gpu.yaml @@ -5,6 +5,7 @@ channels: - conda-forge dependencies: - imagecodecs + # This pin is necessary because later nifty versions have import errors on windows. - nifty =1.2.1=*_4 - magicgui - napari diff --git a/micro_sam/multi_dimensional_segmentation.py b/micro_sam/multi_dimensional_segmentation.py index 7ecaf14e6..7828a9e2e 100644 --- a/micro_sam/multi_dimensional_segmentation.py +++ b/micro_sam/multi_dimensional_segmentation.py @@ -325,7 +325,7 @@ def merge_instance_segmentation_3d( uv_ids = np.array([[edge["source"], edge["target"]] for edge in edges]) overlaps = np.array([edge["score"] for edge in edges]) - n_nodes = int(slice_segmentation[-1].max() + 1) + n_nodes = int(slice_segmentation.max() + 1) graph = nifty.graph.undirectedGraph(n_nodes) graph.insertEdges(uv_ids) diff --git a/micro_sam/sam_annotator/_widgets.py b/micro_sam/sam_annotator/_widgets.py index 7164dbe4d..07fa56b20 100644 --- a/micro_sam/sam_annotator/_widgets.py +++ b/micro_sam/sam_annotator/_widgets.py @@ -174,7 +174,7 @@ def _add_path_param(self, name, value, select_type, title=None, placeholder=None layout.addWidget(label) path_textbox = QtWidgets.QLineEdit() - path_textbox.setText(value) + path_textbox.setText(str(value)) if placeholder is not None: path_textbox.setPlaceholderText(placeholder) path_textbox.textChanged.connect(lambda val: setattr(self, name, val)) @@ -210,7 +210,7 @@ def _get_directory_path(self, name, textbox, tooltip=None): if tooltip: directory.setToolTip(tooltip) if directory and Path(directory).is_dir(): - textbox.setText(directory) + textbox.setText(str(directory)) else: # Handle the case where the selected path is not a directory print("Invalid directory selected. Please try again.") @@ -222,7 +222,7 @@ def _get_file_path(self, name, textbox, tooltip=None): if tooltip: file_path.setToolTip(tooltip) if file_path and Path(file_path).is_file(): - textbox.setText(file_path) + textbox.setText(str(file_path)) else: # Handle the case where the selected path is not a file print("Invalid file selected. Please try again.") diff --git a/micro_sam/sam_annotator/util.py b/micro_sam/sam_annotator/util.py index 869f2aea2..1887f371e 100644 --- a/micro_sam/sam_annotator/util.py +++ b/micro_sam/sam_annotator/util.py @@ -684,10 +684,10 @@ def _sync_embedding_widget(widget, model_type, save_path, checkpoint_path, devic widget.model_dropdown.setCurrentIndex(index) if save_path is not None: - widget.embeddings_save_path_param.setText(save_path) + widget.embeddings_save_path_param.setText(str(save_path)) if checkpoint_path is not None: - widget.custom_weights_param.setText(checkpoint_path) + widget.custom_weights_param.setText(str(checkpoint_path)) if device is not None: widget.device = device From a01fad1a27a379a88880cdbd5638cc61e2395edb Mon Sep 17 00:00:00 2001 From: Constantin Pape Date: Fri, 18 Oct 2024 17:43:38 +0200 Subject: [PATCH 51/53] More test fixes --- test/test_sam_annotator/test_widgets.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/test/test_sam_annotator/test_widgets.py b/test/test_sam_annotator/test_widgets.py index a239e98f7..b3edd1fd9 100644 --- a/test/test_sam_annotator/test_widgets.py +++ b/test/test_sam_annotator/test_widgets.py @@ -1,8 +1,12 @@ import json import os import platform +import warnings -from mobile_sam.predictor import SamPredictor as MobileSamPredictor +# Avoid import warnigns from mobile_sam +with warnings.catch_warnings(): + warnings.simplefilter("ignore") + from mobile_sam.predictor import SamPredictor as MobileSamPredictor from segment_anything.predictor import SamPredictor import numpy as np import pytest @@ -33,8 +37,11 @@ def test_embedding_widget(make_napari_viewer, tmp_path): my_widget.embeddings_save_path = tmp_path # Run image embedding widget. - worker = my_widget(skip_validate=True) - worker.await_workers() # blocks until thread worker is finished the embedding + my_widget(skip_validate=True) + + # Previous version when we used a thread-worker + # worker = my_widget(skip_validate=True) + # worker.await_workers() # blocks until thread worker is finished the embedding # Check in-memory state for predictor and embeddings. assert isinstance(AnnotatorState().predictor, (SamPredictor, MobileSamPredictor)) From 69f3c018331a0b4128c540985e3ea78c45bc6c37 Mon Sep 17 00:00:00 2001 From: Anwai Archit <52396323+anwai98@users.noreply.github.com> Date: Fri, 18 Oct 2024 17:44:34 +0200 Subject: [PATCH 52/53] Add materials for hands-on workshops (#742) Add materials for workshop and minor fixes --- micro_sam/automatic_segmentation.py | 36 +- micro_sam/training/util.py | 5 + notebooks/sam_finetuning.ipynb | 17 +- workshops/README.md | 102 +++++ workshops/download_datasets.py | 140 +++++++ workshops/download_embeddings.py | 89 +++++ workshops/finetune_sam.ipynb | 563 ++++++++++++++++++++++++++++ workshops/finetune_sam.py | 355 ++++++++++++++++++ 8 files changed, 1282 insertions(+), 25 deletions(-) create mode 100644 workshops/README.md create mode 100644 workshops/download_datasets.py create mode 100644 workshops/download_embeddings.py create mode 100644 workshops/finetune_sam.ipynb create mode 100644 workshops/finetune_sam.py diff --git a/micro_sam/automatic_segmentation.py b/micro_sam/automatic_segmentation.py index e6d908f9e..6cf9a4868 100644 --- a/micro_sam/automatic_segmentation.py +++ b/micro_sam/automatic_segmentation.py @@ -87,7 +87,8 @@ def automatic_instance_segmentation( embedding_path: The path where the embeddings are cached already / will be saved. key: The key to the input file. This is needed for container files (eg. hdf5 or zarr) or to load several images as 3d volume. Provide a glob patterm, eg. "*.tif", for this case. - ndim: The dimensionality of the data. + ndim: The dimensionality of the data. By default the dimensionality of the data will be used. + If you have RGB data you have to specify this explicitly, e.g. pass ndim=2 for 2d segmentation of RGB. tile_shape: Shape of the tiles for tiled prediction. By default prediction is run without tiling. halo: Overlap of the tiles for tiled prediction. verbose: Verbosity flag. @@ -102,21 +103,12 @@ def automatic_instance_segmentation( else: image_data = util.load_image_data(input_path, key) - if ndim == 3 or image_data.ndim == 3: - if image_data.ndim != 3: - raise ValueError(f"The inputs do not correspond to three dimensional inputs: '{image_data.ndim}'") + ndim = image_data.ndim if ndim is None else ndim + + if ndim == 2: + if image_data.ndim != 2 or image_data.shape[-1] != 3: + raise ValueError(f"The inputs does not match the shape expectation of 2d inputs: {image_data.shape}") - instances = automatic_3d_segmentation( - volume=image_data, - predictor=predictor, - segmentor=segmenter, - embedding_path=embedding_path, - tile_shape=tile_shape, - halo=halo, - verbose=verbose, - **generate_kwargs - ) - else: # Precompute the image embeddings. image_embeddings = util.precompute_image_embeddings( predictor=predictor, @@ -142,6 +134,20 @@ def automatic_instance_segmentation( instances = np.zeros(this_shape, dtype="uint32") else: instances = mask_data_to_segmentation(masks, with_background=True, min_object_size=0) + else: + if image_data.ndim != 3 or image_data.shape[-1] != 3: + raise ValueError(f"The inputs does not match the shape expectation of 3d inputs: {image_data.shape}") + + instances = automatic_3d_segmentation( + volume=image_data, + predictor=predictor, + segmentor=segmenter, + embedding_path=embedding_path, + tile_shape=tile_shape, + halo=halo, + verbose=verbose, + **generate_kwargs + ) if output_path is not None: # Save the instance segmentation diff --git a/micro_sam/training/util.py b/micro_sam/training/util.py index 7ecf41cd0..f29cbb670 100644 --- a/micro_sam/training/util.py +++ b/micro_sam/training/util.py @@ -246,6 +246,11 @@ def __call__(self, x, y): # +def normalize_to_8bit(raw): + raw = normalize(raw) * 255 + return raw + + class ResizeRawTrafo: def __init__(self, desired_shape, do_rescaling=False, padding="constant"): self.desired_shape = desired_shape diff --git a/notebooks/sam_finetuning.ipynb b/notebooks/sam_finetuning.ipynb index 3c8f007cc..b0bc80a17 100644 --- a/notebooks/sam_finetuning.ipynb +++ b/notebooks/sam_finetuning.ipynb @@ -263,9 +263,7 @@ "import torch\n", "\n", "import torch_em\n", - "from torch_em.model import UNETR\n", "from torch_em.util.debug import check_loader\n", - "from torch_em.loss import DiceBasedDistanceLoss\n", "from torch_em.util.util import get_random_colors\n", "from torch_em.transform.label import PerObjectDistanceTransform\n", "\n", @@ -610,9 +608,9 @@ "# It supports image data in various formats. Here, we load image data and labels from the two\n", "# folders with tif images that were downloaded by the example data functionality, by specifying\n", "# `raw_key` and `label_key` as `*.tif`. This means all images in the respective folders that end with\n", - "# .tif will be loadded.\n", + "# .tif will be loaded.\n", "# The function supports many other file formats. For example, if you have tif stacks with multiple slices\n", - "# instead of multiple tif images in a foldder, then you can pass raw_key=label_key=None.\n", + "# instead of multiple tif images in a folder, then you can pass raw_key=label_key=None.\n", "# For more information, here is the documentation: https://github.com/constantinpape/torch-em/blob/main/torch_em/data/datasets/README.md\n", "\n", "# Load images from multiple files in folder via pattern (here: all tif files)\n", @@ -950,13 +948,15 @@ "outputs": [], "source": [ "def run_automatic_instance_segmentation(image, checkpoint_path, model_type=\"vit_b_lm\", device=None):\n", - " \"\"\"Automatic Instance Segmentation by training an additional instance decoder in SAM.\n", + " \"\"\"Automatic Instance Segmentation trained with an additional instance segmentation decoder in SAM.\n", "\n", " NOTE: It is supported only for `µsam` models.\n", " \n", " Args:\n", " image: The input image.\n", + " checkpoint: Filepath to the model checkpoints.\n", " model_type: The choice of the `µsam` model.\n", + " device: The torch device.\n", " \n", " Returns:\n", " The instance segmentation.\n", @@ -1392,7 +1392,7 @@ "assert os.path.exists(best_checkpoint), \"Please train the model first to run inference on the finetuned model.\"\n", "assert train_instance_segmentation is True, \"Oops. You didn't opt for finetuning using the decoder-based automatic instance segmentation.\"\n", "\n", - "# # Let's check the first 5 images. Feel free to comment out the line below to run inference on all images.\n", + "# Let's check the first 5 images. Feel free to comment out the line below to run inference on all images.\n", "image_paths = image_paths[:5]\n", "\n", "for image_path in image_paths:\n", @@ -1400,10 +1400,7 @@ " \n", " # Predicted instances\n", " prediction = run_automatic_instance_segmentation(\n", - " image=image,\n", - " checkpoint_path=best_checkpoint,\n", - " model_type=model_type,\n", - " device=device\n", + " image=image, checkpoint_path=best_checkpoint, model_type=model_type, device=device\n", " )\n", "\n", " # Visualize the predictions\n", diff --git a/workshops/README.md b/workshops/README.md new file mode 100644 index 000000000..ae14719d5 --- /dev/null +++ b/workshops/README.md @@ -0,0 +1,102 @@ +# Hands-On Analysis using `micro-sam` + +## Upcoming Workshops: +1. I2K 2024 (Milan, Italy) +2. Virtual I2K 2024 (Online) + +## Introduction + +In this document, we walk you through different steps involved to participate in hands-on image annotation experiments our tool. + +- Here is our [official documentation](https://computational-cell-analytics.github.io/micro-sam/) for detailed explanation of our tools, library and the finetuned models. +- Here is the playlist for our [tutorial videos](https://youtube.com/playlist?list=PLwYZXQJ3f36GQPpKCrSbHjGiH39X4XjSO&si=3q-cIRD6KuoZFmAM) hosted on YouTube, elaborating in detail on the features of our tools. + +## Steps: + +### Step 1: Download the Datasets + +- We provide the script `download_datasets.py` for automatic download of datasets to be used for interactive annotation using `micro-sam`. +- You can run the script as follows: +```bash +$ python download_datasets.py -i -d +``` +where, `DATA_DIRECTORY` is the filepath to the directory where the datasets will be downloaded, and `DATASET_NAME` is the name of the dataset (run `python download_datasets.py -h` in the terminal for more details). + +> NOTE: We have chosen a) subset of the CellPose `cyto` dataset, b) one volume from the EmbedSeg `Mouse-Skull-Nuclei-CBG` dataset from the train split (namely, `X1.tif`), c) one volume from the Platynereis `Membrane` dataset from the train split (namely, `train_data_membrane_02.n5`) and d) the entire `HPA` dataset for the following tasks in `micro-sam`. + +### Step 2: Download the Precomputed Embeddings + +- We provide the script `download_embeddings.py` for automatic download of precompute image embeddings for volumetric data to be used for interactive annotation using `micro-sam`. +- You can run the script as follows: + +```bash +$ python download_embeddings -e -d +``` +where, `EMBEDDING_DIRECTORY` is the filepath to the directory where the precomputed image embeddings will be downloaded, and `DATASET_NAME` is the name of the dataset (run `python download_embeddings.py -h` in the terminal for more details). + +### Additional Section: Precompute the Embeddings Yourself! + +Here is an example guide to precompute the image embeddings (eg. for volumetric data). + +#### EmbedSeg + +```bash +$ micro_sam.precompute_embeddings -i data/embedseg/Mouse-Skull-Nuclei-CBG/train/images/X1.tif # Filepath where inputs are stored. + -m vit_b # You can provide name for any model of your choice (supported by 'micro-sam') (eg. 'vit_b_lm'). + -e embeddings/embedseg/vit_b/embedseg_Mouse-Skull-Nuclei-CBG_train_X1 # Filepath where computed embeddings will be cached. +``` + +#### Platynereis + +```bash +$ micro_sam.precompute_embeddings -i data/platynereis/membrane/train_data_membrane_02.n5 # Filepath where inputs are stored. + -k volumes/raw/s1 # Key to access the data group in container-style data structures. + -m vit_b # You can provide name for any model of your choice (supported by 'micro-sam') (eg. 'vit_b_em_organelles'). + -e embeddings/platynereis/vit_b/platynereis_train_data_membrane_02 # Filepath where computed embeddings will be cached. +``` + +### Step 3: Run the `micro-sam` Annotators (WIP) + +Run the `micro-sam` annotators with the following scripts: + +We recommend using the napari GUI for the interactive annotation. You can use the widget to specify all the essential parameters (eg. the choice of model, the filepath to the precomputed embeddings, etc). + +TODO: add more details here. + +There is another option to use `micro-sam`'s CLI to start our annotator tools. + +#### 2D Annotator (Cell Segmentation in Light Microscopy): + +```bash +$ micro_sam.annotator_2d -i data/cellpose/cyto/test/... # Filepath where the 2d image is stored. + -m vit_b # You can provide name for any model of your choice (supported by 'micro-sam') (eg. 'vit_b_lm') + [OPTIONAL] -e embeddings/cellpose/vit_b/... # Filepath where the computed embeddings will be cached (you can choose to not pass it to compute the embeddings on-the-fly). +``` + +#### 3D Annotator (EmbedSeg - Nuclei Segmentation in Light Microscopy): + +```bash +$ micro_sam.annotator_3d -i data/embedseg/Mouse-Skull-Nuclei-CBG/train/images/X1.tif # Filepath where the 3d volume is stored. + -m vit_b # You can provide name for any model of your choice (supported by 'micro-sam') (eg. 'vit_b_lm') + -e embeddings/embedseg/vit_b/embedseg_Mouse-Skull-Nuclei-CBG_train_X1.zarr # Filepath where the computed embeddings will be cached (we RECOMMEND to provide paths to the downloaded embeddings OR you can choose to not pass it to compute the embeddings on-the-fly). +``` + +#### 3D Annotator (Platynereis - Membrane Segmentation in Electron Microscopy): + +```bash +$ micro_sam.annotator_3d -i data/platynereis/membrane/train_data_membrane_02.n5 # Filepath where the 2d image is stored. + -k volumes/raw/s1 # Key to access the data group in container-style data structures. + -m vit_b # You can provide name for any model of your choice (supported by 'micro-sam') (eg. 'vit_b_em_organelles') + -e embeddings/platynereis/vit_b/... # Filepath where the computed embeddings will be cached (we RECOMMEND to provide paths to the downloaded embeddings OR you can choose to not pass it to compute the embeddings on-the-fly). +``` + +#### Image Series Annotator (Multiple Light Microscopy 2D Images for Cell Segmentation): + +```bash +$ micro_sam.image_series_annotator -i ... + -m vit_b # You can provide name for any model of your choice (supported by 'micro-sam') (eg. 'vit_b_lm') +``` + +### Step 4: Finetune Segment Anything on Microscopy Images (WIP) + +- We provide a notebook `finetune_sam.ipynb` / `finetune_sam.py` for finetuning Segment Anything Model for cell segmentation in confocal microscopy images. diff --git a/workshops/download_datasets.py b/workshops/download_datasets.py new file mode 100644 index 000000000..e96d95eaf --- /dev/null +++ b/workshops/download_datasets.py @@ -0,0 +1,140 @@ +import os +from glob import glob +from natsort import natsorted + +from torch_em.data import datasets +from torch_em.util.image import load_data + + +def _download_sample_data(path, data_dir, url, checksum, download): + if os.path.exists(data_dir): + return + + os.makedirs(path, exist_ok=True) + + zip_path = os.path.join(path, "data.zip") + datasets.util.download_source(path=zip_path, url=url, download=download, checksum=checksum) + datasets.util.unzip(zip_path=zip_path, dst=path) + + +def _get_cellpose_sample_data_paths(path, download): + data_dir = os.path.join(path, "cellpose", "cyto", "test") + + url = "https://owncloud.gwdg.de/index.php/s/slIxlmsglaz0HBE/download" + checksum = "4d1ce7afa6417d051b93d6db37675abc60afe68daf2a4a5db0c787d04583ce8a" + + _download_sample_data(path, data_dir, url, checksum, download) + + raw_paths = natsorted(glob(os.path.join(data_dir, "*_img.png"))) + label_paths = natsorted(glob(os.path.join(data_dir, "*_masks.png"))) + + return raw_paths, label_paths + + +def _get_hpa_data_paths(path, split, download): + urls = [ + "https://owncloud.gwdg.de/index.php/s/zp1Fmm4zEtLuhy4/download", # train + "https://owncloud.gwdg.de/index.php/s/yV7LhGbGfvFGRBE/download", # val + "https://owncloud.gwdg.de/index.php/s/8tLY5jPmpw37beM/download", # test + ] + checksums = [ + "6e5f3ec6b0d505511bea752adaf35529f6b9bb9e7729ad3bdd90ffe5b2d302ab", # train + "4d7a4188cc3d3877b3cf1fbad5f714ced9af4e389801e2136623eac2fde78e9c", # val + "8963ff47cdef95cefabb8941f33a3916258d19d10f532a209bab849d07f9abfe", # test + ] + splits = ["train", "val", "test"] + assert split in splits, f"'{split}' is not a valid split." + + for url, checksum, _split in zip(urls, checksums, splits): + data_dir = os.path.join(path, _split) + _download_sample_data(path, data_dir, url, checksum, download) + + raw_paths = natsorted(glob(os.path.join(path, split, "images", "*.tif"))) + + if split == "test": # The 'test' split for HPA does not have labels. + return raw_paths, None + else: + label_paths = natsorted(glob(os.path.join(path, split, "labels", "*.tif"))) + return raw_paths, label_paths + + +def _get_dataset_paths(path, dataset_name, view=False): + dataset_paths = { + # 2d LM dataset for cell segmentation + "cellpose": lambda: _get_cellpose_sample_data_paths(path=os.path.join(path, "cellpose"), download=True), + "hpa": lambda: _get_hpa_data_paths(path=os.path.join(path, "hpa"), download=True, split="train"), + # 3d LM dataset for nuclei segmentation + "embedseg": lambda: datasets.embedseg_data.get_embedseg_paths( + path=os.path.join(path, "embedseg"), name="Mouse-Skull-Nuclei-CBG", split="train", download=True, + ), + # 3d EM dataset for membrane segmentation + "platynereis": lambda: datasets.platynereis.get_platynereis_paths( + path=os.path.join(path, "platynereis"), sample_ids=None, name="cells", download=True, + ), + } + + dataset_keys = { + "cellpose": [None, None], + "embedseg": [None, None], + "platynereis": ["volumes/raw/s1", "volumes/labels/segmentation/s1"] + } + + if dataset_name is None: # Download all datasets. + dataset_names = list(dataset_paths.keys()) + else: # Download specific datasets. + dataset_names = [dataset_name] + + for dname in dataset_names: + if dname not in dataset_paths: + raise ValueError( + f"'{dname}' is not a supported dataset enabled for download. " + f"Please choose from {list(dataset_paths.keys())}." + ) + + paths = dataset_paths[dname]() + print(f"'{dataset_name}' is download at {path}.") + + if view: + import napari + + if isinstance(paths, tuple): # datasets with explicit raw and label paths + raw_paths, label_paths = paths + else: + raw_paths = label_paths = paths + + raw_key, label_key = dataset_keys[dname] + for raw_path, label_path in zip(raw_paths, label_paths): + raw = load_data(raw_path, raw_key) + labels = load_data(label_path, label_key) + + v = napari.Viewer() + v.add_image(raw) + v.add_labels(labels) + napari.run() + + break # comment this line out in case you would like to visualize all samples. + + +def main(): + import argparse + parser = argparse.ArgumentParser(description="Download the dataset necessary for the workshop.") + parser.add_argument( + "-i", "--input_path", type=str, default="./data", + help="The filepath to the folder where the image data will be downloaded. " + "By default, the data will be stored in your current working directory at './data'." + ) + parser.add_argument( + "-d", "--dataset_name", type=str, default=None, + help="The choice of dataset you would like to download. By default, it downloads all the datasets. " + "Optionally, you can choose to download either of 'cellpose', 'hpa', 'embedseg' or 'platynereis'." + ) + parser.add_argument( + "-v", "--view", action="store_true", help="Whether to view the downloaded data." + ) + args = parser.parse_args() + + _get_dataset_paths(path=args.input_path, dataset_name=args.dataset_name, view=args.view) + + +if __name__ == "__main__": + main() diff --git a/workshops/download_embeddings.py b/workshops/download_embeddings.py new file mode 100644 index 000000000..6df079f53 --- /dev/null +++ b/workshops/download_embeddings.py @@ -0,0 +1,89 @@ +import os + +from torch_em.data.datasets.util import download_source, unzip + + +URLS = { + "lucchi": [ + "https://owncloud.gwdg.de/index.php/s/kQMA1B8L9LOvYrl/download", # vit_b + "https://owncloud.gwdg.de/index.php/s/U8xs6moRg0cQhkS/download", # vit_b_em_organelles + ], + "embedseg": [ + "https://owncloud.gwdg.de/index.php/s/EF9ZdMzYjDjl8fd/download", # vit_b + "https://owncloud.gwdg.de/index.php/s/7IVekm8K7ln7yQ6/download", # vit_b_lm + ], + "platynereis": [ + "https://owncloud.gwdg.de/index.php/s/1OgOEeMIK9Ok2Kj/download", # vit_b + "https://owncloud.gwdg.de/index.php/s/i9DrXe6YFL8jvgP/download", # vit_b_em_organelles + ], +} + +CHECKSUMS = { + "lucchi": [ + "e0d064765f1758a1a0823b2c02d399caa5cae0d8ac5a1e2ed96548a647717433", # vit_b + "e0b5ab781c42e6f68b746fc056c918d56559ccaeedb4e4f2848b1e5e8f1bec58", # vit_b_em_organelles + ], + "embedseg": [ + "82f5351486e484dda5a3a327381458515c89da5dda8a48a0b1ab96ef10d23f02", # vit_b + "80fd701c01b81bbfb32beed6e2ece8c5706625dbc451776d8ba1c22253f097b9", # vit_b_lm + ], + "platynereis": [ + "95c5e31c5e55e94780568f3fb8a3fdf33f8586a4c6a375d28dccba6567f37a47", # vit_b + "3d8d91313656fde271a48ea0a3552762f2536955a357ffb43e7c43b5b27e0627", # vit_b_em_organelles + ], +} + + +def _download_embeddings(embedding_dir, dataset_name): + if dataset_name is None: # Download embeddings for all datasets. + dataset_names = list(URLS.keys()) + else: # Download embeddings for specific dataset. + dataset_names = [dataset_name] + + for dname in dataset_names: + if dname not in URLS: + raise ValueError( + f"'{dname}' does not have precomputed embeddings to download. Please choose from {list(URLS.keys())}" + ) + + urls = URLS[dname] + checksums = CHECKSUMS[dname] + + data_embedding_dir = os.path.join(embedding_dir, dname) + os.makedirs(data_embedding_dir, exist_ok=True) + + # Download the precomputed embeddings as zipfiles and unzip the embeddings per model. + for url, checksum in zip(urls, checksums): + if all([p.startswith("vit_b") for p in os.listdir(data_embedding_dir)]): + continue + + zip_path = os.path.join(data_embedding_dir, "embeddings.zip") + download_source(path=zip_path, url=url, download=True, checksum=checksum) + unzip(zip_path=zip_path, dst=data_embedding_dir) + + print(f"The precompted embeddings for '{dname}' are downloaded at f{data_embedding_dir}") + + +def main(): + import argparse + parser = argparse.ArgumentParser( + description="Download the precomputed image embeddings necessary for interactive annotation." + ) + parser.add_argument( + "-e", "--embedding_dir", type=str, default="./embeddings", + help="The filepath to the folder where the precomputed image embeddings will be downloaded. " + "By default, the embeddings will be stored in your current working directory at './embeddings'." + ) + parser.add_argument( + "-d", "--dataset_name", type=str, default=None, + help="The choice of volumetric dataset for which you would like to download the embeddings. " + "By default, it downloads all the precomputed embeddings. Optionally, you can choose to download either of the " + "volumetric datasets: 'lucchi', 'embedseg' or 'platynereis'." + ) + args = parser.parse_args() + + _download_embeddings(embedding_dir=args.embedding_dir, dataset_name=args.dataset_name) + + +if __name__ == "__main__": + main() diff --git a/workshops/finetune_sam.ipynb b/workshops/finetune_sam.ipynb new file mode 100644 index 000000000..62938d4ea --- /dev/null +++ b/workshops/finetune_sam.ipynb @@ -0,0 +1,563 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Finetuning Segment Anything with `µsam`\n", + "\n", + "This notebook shows how to use Segment Anything for Microscopy to fine-tune a Segment Anything Model (SAM) on an open-source data with multiple channels.\n", + "\n", + "We use confocal microscopy images from the HPA Kaggle Challenge for protein identification (from [Ouyang et al.](https://doi.org/10.1038/s41592-019-0658-6)) in this notebook for the cell segmentation task. The functionalities shown here should work for your (microscopy) images too." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Running this notebook\n", + "\n", + "If you have an environment with `µsam` on your computer you can run this notebook in there. You can follow the [installation instructions](https://computational-cell-analytics.github.io/micro-sam/micro_sam.html#installation) to install it on your computer.\n", + "\n", + "You can also run this notebook in the cloud on [Kaggle Notebooks](https://www.kaggle.com/code/). This service offers free usage of a GPU to speed up running the code. The next cells will take care of the installation for you if you are using it." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Check if we are running this notebook on kaggle, google colab or local compute resources.\n", + "\n", + "import os\n", + "current_spot = os.getcwd()\n", + "\n", + "if current_spot.startswith(\"/kaggle/working\"):\n", + " print(\"Kaggle says hi!\")\n", + " root_dir = \"/kaggle/working\"\n", + "\n", + "elif current_spot.startswith(\"/content\"):\n", + " print(\"Google Colab says hi!\")\n", + " print(\" NOTE: The scripts have not been tested on Google Colab, you might need to adapt the installations a bit.\")\n", + " root_dir = \"/content\"\n", + "\n", + " # You might need to install condacolab on Google Colab to be able to install packages using conda / mamba\n", + " # !pip install -q condacolab\n", + " # import condacolab\n", + " # condacolab.install()\n", + "\n", + "else:\n", + " msg = \"You are using a behind-the-scenes resource. Follow our installation instructions here:\"\n", + " msg += \" https://computational-cell-analytics.github.io/micro-sam/micro_sam.html#installation\"\n", + " print(msg)\n", + " root_dir = \"\" # overwrite to set the root directory, where the data, checkpoints, and all relevant stuff will be stored" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Installation\n", + "\n", + "The next cells will install the `micro_sam` library on Kaggle Notebooks. **Please skip these cells and go to `Importing the libraries` if you are running the notebook on your own computer.**" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!git clone --quiet https://github.com/computational-cell-analytics/micro-sam.git\n", + "tmp_dir = os.path.join(root_dir, \"micro-sam\")\n", + "!pip install --quiet $tmp_dir" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!git clone --quiet https://github.com/constantinpape/torch-em.git\n", + "tmp_dir = os.path.join(root_dir, \"torch-em\")\n", + "!pip install --quiet $tmp_dir" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!git clone --quiet https://github.com/constantinpape/elf.git\n", + "tmp_dir = os.path.join(root_dir, \"elf\")\n", + "!pip install --quiet $tmp_dir" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Known Issues on **Kaggle Notebooks**:\n", + "\n", + "1. `warning libmamba Cache file \"/opt/conda/pkgs/cache/2ce54b42.json\" was modified by another program` (multiples lines of such warnings)\n", + " - We have received this warning while testing this notebook on Kaggle. It does not lead to any issues while making use of the installed packages. You can proceed and ignore the warnings." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!mamba install -q -y -c conda-forge nifty affogato zarr z5py\n", + "!pip uninstall -y --quiet qtpy # qtpy is not supported in Kaggle / Google Colab, let's remove it to avoid errors." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Importing the libraries" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "from glob import glob\n", + "from typing import List\n", + "from natsort import natsorted\n", + "from IPython.display import FileLink\n", + "\n", + "import imageio.v3 as imageio\n", + "from matplotlib import pyplot as plt\n", + "from skimage.measure import label as connected_components\n", + "\n", + "import torch\n", + "\n", + "from torch_em.data import datasets\n", + "from torch_em.util.debug import check_loader\n", + "from torch_em.util.util import get_random_colors\n", + "\n", + "import micro_sam.training as sam_training\n", + "from micro_sam.training.util import normalize_to_8bit\n", + "from micro_sam.automatic_segmentation import get_predictor_and_segmenter, automatic_instance_segmentation" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Let's download the dataset\n", + "\n", + "First, we download the images and corresponding labels stored as `tif` files." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Download the data into a directory\n", + "DATA_FOLDER = os.path.join(root_dir, \"hpa\")\n", + "\n", + "URLS = [\n", + " \"https://owncloud.gwdg.de/index.php/s/zp1Fmm4zEtLuhy4/download\", # train\n", + " \"https://owncloud.gwdg.de/index.php/s/yV7LhGbGfvFGRBE/download\", # val\n", + " \"https://owncloud.gwdg.de/index.php/s/8tLY5jPmpw37beM/download\", # test\n", + "]\n", + "\n", + "CHECKSUMS = [\n", + " \"6e5f3ec6b0d505511bea752adaf35529f6b9bb9e7729ad3bdd90ffe5b2d302ab\", # train\n", + " \"4d7a4188cc3d3877b3cf1fbad5f714ced9af4e389801e2136623eac2fde78e9c\", # val\n", + " \"8963ff47cdef95cefabb8941f33a3916258d19d10f532a209bab849d07f9abfe\", # test\n", + "]\n", + "\n", + "SPLITS = [\"train\", \"val\", \"test\"]\n", + "\n", + "for url, checksum, split in zip(URLS, CHECKSUMS, SPLITS):\n", + " data_dir = os.path.join(DATA_FOLDER, split)\n", + " if os.path.exists(data_dir):\n", + " continue\n", + " \n", + " os.makedirs(DATA_FOLDER, exist_ok=True)\n", + " zip_path = os.path.join(DATA_FOLDER, \"data.zip\")\n", + " datasets.util.download_source(path=zip_path, url=url, download=True, checksum=checksum)\n", + " datasets.util.unzip(zip_path=zip_path, dst=DATA_FOLDER)\n", + "\n", + "# Get filepaths to the image data.\n", + "train_image_paths = natsorted(glob(os.path.join(DATA_FOLDER, split, \"images\", \"*.tif\")))\n", + "val_image_paths = natsorted(glob(os.path.join(DATA_FOLDER, split, \"images\", \"*.tif\")))\n", + "test_image_paths = natsorted(glob(os.path.join(DATA_FOLDER, split, \"images\", \"*.tif\")))\n", + "\n", + "# Get filepaths to the label data.\n", + "train_label_paths = natsorted(glob(os.path.join(DATA_FOLDER, split, \"labels\", \"*.tif\")))\n", + "val_label_paths = natsorted(glob(os.path.join(DATA_FOLDER, split, \"labels\", \"*.tif\")))\n", + "\n", + "print(f\"The inputs have been preprocessed and stored at: '{DATA_FOLDER}'\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Let's understand our inputs' data structure." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "for image_path, label_path in zip(train_image_paths, train_label_paths): # Checking the inputs for the train split.\n", + " image = imageio.imread(image_path)\n", + " labels = imageio.imread(label_path)\n", + "\n", + " # The images should be of shape: H, W, 4 -> where, 4 is the number of channels.\n", + " if (image.ndim == 3 and image.shape[-1] == 3) or image.ndim == 2:\n", + " print(f\"Inputs '{image.shape}' match the channel expectations.\")\n", + " else:\n", + " print(f\"Inputs '{image.shape}' must match the channel expectations (of either one or three channels).\")\n", + "\n", + " # The labels should be of shape: H, W\n", + " print(f\"Shape of corresponding labels: '{labels.shape}'\")\n", + "\n", + " break # comment this line out in case you would like to verify the shapes for all inputs." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Segment Anything accepts inputs of either 1 channel or 3 channels. To fine-tune Segment Anything on our data, we must select either 1 channel or 3 channels out of the 4 channels available.\n", + "\n", + "Let's make the choice to choose the `microtubule` (first channel), `protein` (second channel) and `nuclei` (third channel) for finetuning Segment Anything." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# We remove the 'er' channel, i.e. the last channel.\n", + "def preprocess_inputs(image_paths: List[str]):\n", + " for image_path in image_paths:\n", + " image = imageio.imread(image_path)\n", + "\n", + " if image.ndim == 3 and image.shape[-1] == 4: # Convert 4 channel inputs to 3 channels.\n", + " image = image[..., :-1]\n", + " imageio.imwrite(image_path, image)\n", + "\n", + "preprocess_inputs(train_image_paths)\n", + "preprocess_inputs(val_label_paths)\n", + "preprocess_inputs(test_image_paths)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Let's create the dataloaders\n", + "\n", + "Our task is to segment cells in confocal microscopy images. The dataset comes from https://zenodo.org/records/4665863, and the dataloader has been implemented in [torch-em](https://github.com/constantinpape/torch-em/blob/main/torch_em/data/datasets/light_microscopy/hpa.py)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### First, let's visualize how our samples look." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "for image_path, label_path in zip(train_image_paths, train_label_paths): # Visualize inputs for the train split.\n", + " image = imageio.imread(image_path)\n", + " labels = imageio.imread(label_path)\n", + "\n", + " fig, ax = plt.subplots(1, 2, figsize=(10, 10))\n", + " ax[0].imshow(image, cmap=\"gray\")\n", + " ax[0].set_title(\"Input Image\")\n", + " ax[0].axis(\"off\")\n", + " \n", + " labels = connected_components(labels)\n", + " ax[1].imshow(labels, cmap=get_random_colors(labels), interpolation=\"nearest\")\n", + " ax[1].set_title(\"Ground Truth Instances\")\n", + " ax[1].axis(\"off\")\n", + " \n", + " plt.show()\n", + " plt.close()\n", + " \n", + " break # comment this out in case you want to visualize all the images" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Next, let's create the dataloaders." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# micro_sam.training.default_sam_loader is a convenience function to build a torch dataloader.\n", + "# from image data and labels for training segmentation models.\n", + "# This is wrapped around the 'torch_em.default_segmentation_loader'.\n", + "# It supports image data in various formats. Here, we load image data and corresponding labels by providing\n", + "# filepaths to the respective tif files that were download and preprocessed using the functionality above.\n", + "# Next, we create a list of filepaths for the image and label data by fetching all '*.tif' files in the\n", + "# respective directories.\n", + "# For more information, here is the documentation: https://github.com/constantinpape/torch-em/blob/main/torch_em/data/datasets/README.md\n", + "# Here is a detailed notebook on finetuning Segment Anything: https://github.com/computational-cell-analytics/micro-sam/blob/master/notebooks/sam_finetuning.ipynb\n", + "\n", + "# Load images from tif stacks by setting `raw_key` and `label_key` to None.\n", + "raw_key, label_key = None, None" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# The script below returns the train or val data loader for finetuning Segment Anything.\n", + "\n", + "# The data loader must be a torch dataloader that returns `x, y` tensors,\n", + "# where `x` is the image data and `y` are the corresponding labels.\n", + "# The labels have to be in a label mask instance segmentation format.\n", + "# i.e. a tensor of the same spatial shape as `x`, with each object mask having its own ID.\n", + "# IMPORTANT: the ID 0 is reserved for backgroun, and the IDS must be consecutive.\n", + "\n", + "# Here, we use `micro_sam.training.default_sam_loader` for creating the suitable data loader from\n", + "# the HPA data. You can either adapt this for your own data or write a suitable torch dataloader yourself.\n", + "# Here is a quickstart notebook to create your own dataloaders: https://github.com/constantinpape/torch-em/blob/main/notebooks/tutorial_create_dataloaders.ipynb\n", + "\n", + "batch_size = 1 # the training batch size\n", + "patch_shape = (512, 512) # the size of patches for training\n", + "\n", + "# Train an additional convolutional decoder for end-to-end automatic instance segmentation\n", + "train_instance_segmentation = True\n", + "\n", + "# The dataloader internally takes care of adding label transforms: i.e. used to convert the ground-truth\n", + "# labels to the desired instances for finetuning Segment Anythhing, or, to learn the foreground and distances\n", + "# to the object centers and object boundaries for automatic segmentation.\n", + "\n", + "train_loader = sam_training.default_sam_loader(\n", + " raw_paths=train_image_paths,\n", + " raw_key=raw_key,\n", + " label_paths=train_label_paths,\n", + " label_key=label_key,\n", + " is_seg_dataset=False,\n", + " patch_shape=patch_shape,\n", + " with_channels=True,\n", + " with_segmentation_decoder=train_instance_segmentation,\n", + " batch_size=batch_size,\n", + " shuffle=True,\n", + " raw_transform=normalize_to_8bit,\n", + " n_samples=100,\n", + ")\n", + "\n", + "val_loader = sam_training.default_sam_loader(\n", + " raw_paths=val_image_paths,\n", + " raw_key=raw_key,\n", + " label_paths=val_label_paths,\n", + " label_key=label_key,\n", + " is_seg_dataset=False,\n", + " patch_shape=patch_shape,\n", + " with_channels=True,\n", + " with_segmentation_decoder=train_instance_segmentation,\n", + " batch_size=batch_size,\n", + " raw_transform=normalize_to_8bit,\n", + " shuffle=True,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Let's check how our samples lookm from the dataloader.\n", + "check_loader(train_loader, 4, plt=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Run the actual model finetuning" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# All hyperparameters for training.\n", + "n_objects_per_batch = 5 # the number of objects per batch that will be sampled\n", + "device = \"cuda\" if torch.cuda.is_available() else \"cpu\" # the device/GPU used for training\n", + "n_epochs = 5 # how long we train (in epochs)\n", + "\n", + "# The model_type determines which base model is used to initialize the weights that are finetuned.\n", + "# We use vit_b here because it can be trained faster. Note that vit_h usually yields higher quality results.\n", + "model_type = \"vit_b\"\n", + "\n", + "# The name of the checkpoint. The checkpoints will be stored in './checkpoints/'\n", + "checkpoint_name = \"sam_hpa\"" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**NOTE**: The user needs to decide whether to finetune the Segment Anything model, or the `µsam`'s \"finetuned microscopy models\" for their dataset. Here, we finetune on the Segment Anything model for simplicity. For example, if you choose to finetune the model from the light microscopy generalist models, you need to update the `model_type` to `vit_b_lm` and it takes care of initializing the model with the desired weights)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Run training\n", + "sam_training.train_sam(\n", + " name=checkpoint_name,\n", + " save_root=os.path.join(root_dir, \"models\"),\n", + " model_type=model_type,\n", + " train_loader=train_loader,\n", + " val_loader=val_loader,\n", + " n_epochs=n_epochs,\n", + " n_objects_per_batch=n_objects_per_batch,\n", + " with_segmentation_decoder=train_instance_segmentation,\n", + " device=device,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Let's spot our best checkpoint and download it to get started with the annotation tool\n", + "best_checkpoint = os.path.join(\"models\", \"checkpoints\", checkpoint_name, \"best.pt\")\n", + "\n", + "# # Download link is automatically generated for the best model.\n", + "print(\"Click here \\u2193\")\n", + "FileLink(best_checkpoint)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Let's run the automatic instance segmentation (AIS)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def run_automatic_instance_segmentation(image, checkpoint, model_type=\"vit_b\", device=None):\n", + " \"\"\"Automatic Instance Segmentation trained with an additional instance segmentation decoder in SAM.\n", + "\n", + " NOTE: It is supported only for `µsam` models.\n", + " \n", + " Args:\n", + " image: The input image.\n", + " checkpoint: Filepath to the model checkpoints.\n", + " model_type: The choice of the `µsam` model.\n", + " device: The torch device.\n", + "\n", + " Returns:\n", + " The instance segmentation.\n", + " \"\"\"\n", + " # Step 1: Get the 'predictor' and 'segmenter' to perform automatic instance segmentation.\n", + " predictor, segmenter = get_predictor_and_segmenter(model_type=model_type, checkpoint=checkpoint, device=device)\n", + "\n", + " # Step 2: Get the instance segmentation for the given image.\n", + " instances = automatic_instance_segmentation(predictor=predictor, segmenter=segmenter, input_path=image, ndim=2)\n", + "\n", + " return instances" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "assert os.path.exists(best_checkpoint), \"Please train the model first to run inference on the finetuned model.\"\n", + "assert train_instance_segmentation is True, \"Oops. You didn't opt for finetuning using the decoder-based automatic instance segmentation.\"\n", + "\n", + "for image_path in test_image_paths:\n", + " image = imageio.imread(image_path)\n", + " \n", + " # Predicted instances\n", + " prediction = run_automatic_instance_segmentation(\n", + " image=image, checkpoint_path=best_checkpoint, model_type=model_type, device=device\n", + " )\n", + "\n", + " # Visualize the predictions\n", + " fig, ax = plt.subplots(1, 2, figsize=(10, 10))\n", + "\n", + " ax[0].imshow(image, cmap=\"gray\")\n", + " ax[0].axis(\"off\")\n", + " ax[0].set_title(\"Input Image\")\n", + "\n", + " ax[1].imshow(prediction, cmap=get_random_colors(prediction), interpolation=\"nearest\")\n", + " ax[1].axis(\"off\")\n", + " ax[1].set_title(\"Predictions (AIS)\")\n", + "\n", + " plt.show()\n", + " plt.close()\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### What next?\n", + "\n", + "It's time to get started with your custom finetuned model using the annotator tool. Here is the documentation on how to get started with `µsam`: [Annotation Tools](https://computational-cell-analytics.github.io/micro-sam/micro_sam.html#annotation-tools)\n", + "\n", + "Happy annotating!" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "*This notebook was last ran on October 18, 2024*" + ] + } + ], + "metadata": { + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/workshops/finetune_sam.py b/workshops/finetune_sam.py new file mode 100644 index 000000000..14ad6360a --- /dev/null +++ b/workshops/finetune_sam.py @@ -0,0 +1,355 @@ +"""Finetuning Segment Anything using µsam. + +This python script shows how to use Segment Anything for Microscopy to fine-tune a Segment Anything Model (SAM) +on an open-source data with multiple channels. + +We use confocal microscopy images from the HPA Kaggle Challenge for protein identification +(from Ouyang et al. - https://doi.org/10.1038/s41592-019-0658-6) in this script for the cell segmentation task. +The functionalities shown here should work for your (microscopy) images too. +""" + +import os +from typing import Union, Tuple, Literal, Optional, List + +import imageio.v3 as imageio +from matplotlib import pyplot as plt +from skimage.measure import label as connected_components + +import torch +from torch.utils.data import DataLoader + +from torch_em.util.debug import check_loader +from torch_em.util.util import get_random_colors + +from micro_sam import util +import micro_sam.training as sam_training +from micro_sam.training.util import normalize_to_8bit +from micro_sam.automatic_segmentation import get_predictor_and_segmenter, automatic_instance_segmentation + +from download_datasets import _get_hpa_data_paths + + +def download_dataset( + path: Union[os.PathLike, str], split: Literal['train', 'val', 'test'], download: bool = True, +) -> Tuple[List[str], List[str]]: + """Download the HPA dataset. + + This functionality downloads the images and corresponding labels stored as `tif` files. + + Args: + path: Filepath to the directory where the data will be stored. + split: The choice of data split. Either 'train', 'val' or 'test'. + download: Whether to download the dataset. + + Returns: + List of filepaths for the image data. + List of filepaths for the label data. + """ + data_path = os.path.join(path, "hpa") + image_paths, label_paths = _get_hpa_data_paths(path=data_path, split=split, download=download) + return image_paths, label_paths + + +def verify_inputs(image_paths: List[str], label_paths: List[str]): + """Verify the downloaded inputs and preprocess them. + + Args: + image_paths: List of filepaths for the image data. + label_paths: List of filepaths for the label data. + """ + for image_path, label_path in zip(image_paths, label_paths): + image = imageio.imread(image_path) + labels = imageio.imread(label_path) + + # The images should be of shape: H, W, 4 -> where, 4 is the number of channels. + if (image.ndim == 3 and image.shape[-1] == 3) or image.ndim == 2: + print(f"Inputs '{image.shape}' match the channel expectations.") + else: + print(f"Inputs '{image.shape}' must match the channel expectations (of either one or three channels).") + + # The labels should be of shape: H, W + print(f"Shape of corresponding labels: '{labels.shape}'") + + break # comment this line out in case you would like to verify the shapes for all inputs. + + +def preprocess_inputs(image_paths: List[str]): + """Preprocess the input images. + + Args: + image_paths: List of filepaths for the image data. + """ + # We remove the 'er' channel, i.e. the last channel. + for image_path in image_paths: + image = imageio.imread(image_path) + + if image.ndim == 3 and image.shape[-1] == 4: # Convert 4 channel inputs to 3 channels. + image = image[..., :-1] + imageio.imwrite(image_path, image) + + +def visualize_inputs(image_paths: List[str], label_paths: List[str]): + """Visualize the images and corresponding labels. + + Args: + image_paths: List of filepaths for the image data. + label_paths: List of filepaths for the label data. + """ + for image_path, label_path in zip(image_paths, label_paths): + image = imageio.imread(image_path) + labels = imageio.imread(label_path) + + fig, ax = plt.subplots(1, 2, figsize=(10, 10)) + ax[0].imshow(image, cmap="gray") + ax[0].set_title("Input Image") + ax[0].axis("off") + + labels = connected_components(labels) + ax[1].imshow(labels, cmap=get_random_colors(labels), interpolation="nearest") + ax[1].set_title("Ground Truth Instances") + ax[1].axis("off") + + plt.show() + plt.close() + + break # comment this out in case you want to visualize all the images + + +def get_dataloaders( + train_image_paths: List[str], + train_label_paths: List[str], + val_image_paths: List[str], + val_label_paths: List[str], + view: bool, + train_instance_segmentation: bool, +) -> Tuple[DataLoader, DataLoader]: + """Get the HPA dataloaders for cell segmentation. + + Args: + train_image_paths: List of filepaths for the training image data. + train_label_paths: List of filepaths for the training label data. + val_image_paths: List of filepaths for the validation image data. + val_label_paths: List of filepaths for the validation label data. + view: Whether to view the samples out of training dataloader. + train_instance_segmentation: Whether to finetune SAM with additional instance segmentation decoder. + + Returns: + The PyTorch DataLoader for training. + The PyTorch DataLoader for validation. + """ + # Load images from tif stacks by setting `raw_key` and `label_key` to None. + raw_key, label_key = None, None + + batch_size = 1 # the training batch size + patch_shape = (512, 512) # the size of patches for training + + train_loader = sam_training.default_sam_loader( + raw_paths=train_image_paths, + raw_key=raw_key, + label_paths=train_label_paths, + label_key=label_key, + is_seg_dataset=False, + patch_shape=patch_shape, + with_channels=True, + with_segmentation_decoder=train_instance_segmentation, + batch_size=batch_size, + shuffle=True, + raw_transform=normalize_to_8bit, + n_samples=100, + ) + val_loader = sam_training.default_sam_loader( + raw_paths=val_image_paths, + raw_key=raw_key, + label_paths=val_label_paths, + label_key=label_key, + is_seg_dataset=False, + patch_shape=patch_shape, + with_channels=True, + with_segmentation_decoder=train_instance_segmentation, + batch_size=batch_size, + shuffle=True, + raw_transform=normalize_to_8bit, + ) + + if view: + check_loader(train_loader, 4, plt=True) + + return train_loader, val_loader + + +def run_finetuning( + train_loader: DataLoader, + val_loader: DataLoader, + save_root: Optional[Union[os.PathLike, str]], + train_instance_segmentation: bool, + device: Union[torch.device, str], + model_type: str, + overwrite: bool, +) -> str: + """Run finetuning for the Segment Anything model on microscopy images. + + Args: + train_loader: The PyTorch dataloader used for training. + val_loader: The PyTorch dataloader used for validation. + save_root: The filepath to the folder where the model checkpoints and tensorboard logs are stored. + train_instance_segmentation: Whether to finetune SAM with additional instance segmentation decoder. + device: The torch device. + model_type: The choice of Segment Anything model (connotated by the size of image encoder). + overwrite: Whether to overwrite the already finetuned model checkpoints. + + Returns: + Filepath where the (best) model checkpoint is stored. + """ + # All hyperparameters for training. + n_objects_per_batch = 5 # the number of objects per batch that will be sampled + n_epochs = 5 # how long we train (in epochs) + + # The name of the checkpoint. The checkpoints will be stored in './checkpoints/' + checkpoint_name = "sam_hpa" + + # Let's spot our best checkpoint and run inference for automatic instance segmentation. + if save_root is None: + save_root = os.getcwd() + + best_checkpoint = os.path.join(save_root, "checkpoints", checkpoint_name, "best.pt") + if os.path.exists(best_checkpoint) and not overwrite: + print( + "It looks like the training has completed. You must pass the argument '--overwrite' to overwrite " + "the already finetuned model (or provide a new filepath at '--save_root' for training new models)." + ) + return best_checkpoint + + # Run training + sam_training.train_sam( + name=checkpoint_name, + save_root=save_root, + model_type=model_type, + train_loader=train_loader, + val_loader=val_loader, + n_epochs=n_epochs, + n_objects_per_batch=n_objects_per_batch, + with_segmentation_decoder=train_instance_segmentation, + device=device, + ) + + return best_checkpoint + + +def run_instance_segmentation_with_decoder( + test_image_paths: List[str], model_type: str, checkpoint: Union[os.PathLike, str], device: Union[torch.device, str], +): + """Run automatic instance segmentation (AIS). + + Args: + test_image_paths: List of filepaths for the test image data. + model_type: The choice of Segment Anything model (connotated by the size of image encoder). + checkpoint: Filepath to the finetuned model checkpoints. + device: The torch device used for inference. + """ + assert os.path.exists(checkpoint), "Please train the model first to run inference on the finetuned model." + + # Get the 'predictor' and 'segmenter' to perform automatic instance segmentation. + predictor, segmenter = get_predictor_and_segmenter(model_type=model_type, checkpoint=checkpoint, device=device) + + for image_path in test_image_paths: + image = imageio.imread(image_path) + image = normalize_to_8bit(image) + + # Predicting the instances. + prediction = automatic_instance_segmentation(predictor=predictor, segmenter=segmenter, input_path=image, ndim=2) + + # Visualize the predictions + fig, ax = plt.subplots(1, 2, figsize=(10, 10)) + + ax[0].imshow(image, cmap="gray") + ax[0].axis("off") + ax[0].set_title("Input Image") + + ax[1].imshow(prediction, cmap=get_random_colors(prediction), interpolation="nearest") + ax[1].axis("off") + ax[1].set_title("Predictions (AIS)") + + plt.show() + plt.close() + + +def main(): + import argparse + parser = argparse.ArgumentParser(description="Run finetuning for Segment Anything model for microscopy images.") + parser.add_argument( + "-i", "--input_path", type=str, default="./data", + help="The filepath to the folder where the image data will be downloaded. " + "By default, the data will be stored in your current working directory at './data'." + ) + parser.add_argument( + "-s", "--save_root", type=str, default=None, + help="The filepath to store the model checkpoint and tensorboard logs. " + "By default, they will be stored in your current working directory at 'checkpoints' and 'logs'." + ) + parser.add_argument( + "--view", action="store_true", + help="Whether to visualize the raw inputs, samples from the dataloader, instance segmentation outputs, etc." + ) + parser.add_argument( + "--overwrite", action="store_true", help="Whether to overwrite the already finetuned model checkpoints." + ) + parser.add_argument( + "--device", type=str, default=None, help="The choice of device to run training and inference." + ) + args = parser.parse_args() + + device = util.get_device(args.device) # the device / GPU used for training and inference. + + # The model_type determines which base model is used to initialize the weights that are finetuned. + # We use vit_b here because it can be trained faster. Note that vit_h usually yields higher quality results. + model_type = "vit_b" + + # Train an additional convolutional decoder for end-to-end automatic instance segmentation + train_instance_segmentation = True + + # Step 1: Download the dataset. + train_image_paths, train_label_paths = download_dataset(path=args.input_path, split="train") + val_image_paths, val_label_paths = download_dataset(path=args.input_path, split="val") + test_image_paths, _ = download_dataset(path=args.input_path, split="test") + + # Step 2: Verify the spatial shape of inputs (only for the 'train' split) + verify_inputs(image_paths=train_image_paths, label_paths=train_label_paths) + + # Step 3: Preprocess input images. + preprocess_inputs(image_paths=train_image_paths) + preprocess_inputs(image_paths=val_image_paths) + preprocess_inputs(image_paths=test_image_paths) + + if args.view: + # Step 3(a): Visualize the images and corresponding labels (only for the 'train' split) + visualize_inputs(image_paths=train_image_paths, label_paths=train_label_paths) + + # Step 4: Get the dataloaders. + train_loader, val_loader = get_dataloaders( + train_image_paths=train_image_paths, + train_label_paths=train_label_paths, + val_image_paths=val_image_paths, + val_label_paths=val_label_paths, + view=args.view, + train_instance_segmentation=train_instance_segmentation, + ) + + # Step 5: Run the finetuning for Segment Anything Model. + checkpoint_path = run_finetuning( + train_loader=train_loader, + val_loader=val_loader, + save_root=args.save_root, + train_instance_segmentation=train_instance_segmentation, + device=device, + model_type=model_type, + overwrite=args.overwrite, + ) + + # Step 6: Run automatic instance segmentation using the finetuned model. + run_instance_segmentation_with_decoder( + test_image_paths=test_image_paths, model_type=model_type, checkpoint=checkpoint_path, device=device, + ) + + +if __name__ == "__main__": + main() From 6330b7267d1bd324f54a0833b503a9f495b994cc Mon Sep 17 00:00:00 2001 From: Constantin Pape Date: Fri, 18 Oct 2024 18:14:55 +0200 Subject: [PATCH 53/53] Fix check in automatic instance segmentation --- micro_sam/automatic_segmentation.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/micro_sam/automatic_segmentation.py b/micro_sam/automatic_segmentation.py index 6cf9a4868..a051e61d9 100644 --- a/micro_sam/automatic_segmentation.py +++ b/micro_sam/automatic_segmentation.py @@ -106,7 +106,7 @@ def automatic_instance_segmentation( ndim = image_data.ndim if ndim is None else ndim if ndim == 2: - if image_data.ndim != 2 or image_data.shape[-1] != 3: + if (image_data.ndim != 2) and (image_data.ndim != 3 and image_data.shape[-1] != 3): raise ValueError(f"The inputs does not match the shape expectation of 2d inputs: {image_data.shape}") # Precompute the image embeddings. @@ -135,7 +135,7 @@ def automatic_instance_segmentation( else: instances = mask_data_to_segmentation(masks, with_background=True, min_object_size=0) else: - if image_data.ndim != 3 or image_data.shape[-1] != 3: + if (image_data.ndim != 3) and (image_data.ndim != 4 and image_data.shape[-1] != 3): raise ValueError(f"The inputs does not match the shape expectation of 3d inputs: {image_data.shape}") instances = automatic_3d_segmentation(