From ed4d5c9449e4ba468ea993563a60cfa31bd9871c Mon Sep 17 00:00:00 2001 From: Carolin Teuber Date: Mon, 15 Jul 2024 17:35:23 +0200 Subject: [PATCH 01/15] include lora in the evaluation and allow to load models without giving a checkpoint --- finetuning/evaluation/evaluate_amg.py | 11 ++++------- .../evaluate_instance_segmentation.py | 7 ++++--- finetuning/evaluation/iterative_prompting.py | 3 ++- finetuning/evaluation/precompute_embeddings.py | 3 ++- finetuning/evaluation/util.py | 5 ++++- .../resource-efficient/covid_if_finetuning.py | 18 ++++++++++++++---- micro_sam/evaluation/inference.py | 7 +++++-- micro_sam/instance_segmentation.py | 3 ++- micro_sam/training/training.py | 5 +++-- 9 files changed, 40 insertions(+), 22 deletions(-) diff --git a/finetuning/evaluation/evaluate_amg.py b/finetuning/evaluation/evaluate_amg.py index 69ec63ef..df158a67 100644 --- a/finetuning/evaluation/evaluate_amg.py +++ b/finetuning/evaluation/evaluate_amg.py @@ -7,7 +7,7 @@ from util import get_pred_paths, get_default_arguments, VANILLA_MODELS -def run_amg_inference(dataset_name, model_type, checkpoint, experiment_folder): +def run_amg_inference(dataset_name, model_type, checkpoint, experiment_folder, lora_rank=None): val_image_paths, val_gt_paths = get_paths(dataset_name, split="val") test_image_paths, _ = get_paths(dataset_name, split="test") prediction_folder = run_amg( @@ -16,7 +16,8 @@ def run_amg_inference(dataset_name, model_type, checkpoint, experiment_folder): experiment_folder, val_image_paths, val_gt_paths, - test_image_paths + test_image_paths, + lora_rank=lora_rank, ) return prediction_folder @@ -32,12 +33,8 @@ def eval_amg(dataset_name, prediction_folder, experiment_folder): def main(): args = get_default_arguments() - if args.checkpoint is None: - ckpt = VANILLA_MODELS[args.model] - else: - ckpt = args.checkpoint - prediction_folder = run_amg_inference(args.dataset, args.model, ckpt, args.experiment_folder) + prediction_folder = run_amg_inference(args.dataset, args.model, args.checkpoint, args.experiment_folder, args.lora_rank) eval_amg(args.dataset, prediction_folder, args.experiment_folder) diff --git a/finetuning/evaluation/evaluate_instance_segmentation.py b/finetuning/evaluation/evaluate_instance_segmentation.py index 70da7635..0fce1799 100644 --- a/finetuning/evaluation/evaluate_instance_segmentation.py +++ b/finetuning/evaluation/evaluate_instance_segmentation.py @@ -7,7 +7,7 @@ from util import get_pred_paths, get_default_arguments -def run_instance_segmentation_with_decoder_inference(dataset_name, model_type, checkpoint, experiment_folder): +def run_instance_segmentation_with_decoder_inference(dataset_name, model_type, checkpoint, experiment_folder, lora_rank): val_image_paths, val_gt_paths = get_paths(dataset_name, split="val") test_image_paths, _ = get_paths(dataset_name, split="test") prediction_folder = run_instance_segmentation_with_decoder( @@ -16,7 +16,8 @@ def run_instance_segmentation_with_decoder_inference(dataset_name, model_type, c experiment_folder, val_image_paths, val_gt_paths, - test_image_paths + test_image_paths, + lora_rank=lora_rank, ) return prediction_folder @@ -34,7 +35,7 @@ def main(): args = get_default_arguments() prediction_folder = run_instance_segmentation_with_decoder_inference( - args.dataset, args.model, args.checkpoint, args.experiment_folder + args.dataset, args.model, args.checkpoint, args.experiment_folder, args.lora_rank ) eval_instance_segmentation_with_decoder(args.dataset, prediction_folder, args.experiment_folder) diff --git 
a/finetuning/evaluation/iterative_prompting.py b/finetuning/evaluation/iterative_prompting.py index 08c0cf3b..0187e245 100644 --- a/finetuning/evaluation/iterative_prompting.py +++ b/finetuning/evaluation/iterative_prompting.py @@ -5,6 +5,7 @@ from util import get_paths # comment this and create a custom function with the same name to run int. seg. on your data from util import get_model, get_default_arguments +from micro_sam.util import get_sam_model def _run_iterative_prompting(dataset_name, exp_folder, predictor, start_with_box_prompt, use_masks): @@ -42,7 +43,7 @@ def main(): start_with_box_prompt = args.box # overwrite to start first iters' prompt with box instead of single point # get the predictor to perform inference - predictor = get_model(model_type=args.model, ckpt=args.checkpoint) + predictor = get_sam_model(model_type=args.model, checkpoint_path=args.checkpoint, lora_rank=args.lora_rank) prediction_root = _run_iterative_prompting( args.dataset, args.experiment_folder, predictor, start_with_box_prompt, args.use_masks diff --git a/finetuning/evaluation/precompute_embeddings.py b/finetuning/evaluation/precompute_embeddings.py index 438cba59..0356c015 100644 --- a/finetuning/evaluation/precompute_embeddings.py +++ b/finetuning/evaluation/precompute_embeddings.py @@ -4,12 +4,13 @@ from util import get_paths # comment this and create a custom function with the same name to execute on your data from util import get_model, get_default_arguments +from micro_sam.util import get_sam_model def main(): args = get_default_arguments() - predictor = get_model(model_type=args.model, ckpt=args.checkpoint) + predictor = get_sam_model(model_type=args.model, checkpoint_path=args.checkpoint, lora_rank=args.lora_rank) embedding_dir = os.path.join(args.experiment_folder, "embeddings") os.makedirs(embedding_dir, exist_ok=True) diff --git a/finetuning/evaluation/util.py b/finetuning/evaluation/util.py index 8b1716e8..a3c93291 100644 --- a/finetuning/evaluation/util.py +++ b/finetuning/evaluation/util.py @@ -219,13 +219,16 @@ def get_default_arguments(): parser.add_argument( "-m", "--model", type=str, required=True, help="Provide the model type to initialize the predictor" ) - parser.add_argument("-c", "--checkpoint", type=none_or_str, required=True, default=None) + parser.add_argument("-c", "--checkpoint", type=none_or_str, default=None) parser.add_argument("-e", "--experiment_folder", type=str, required=True) parser.add_argument("-d", "--dataset", type=str, default=None) parser.add_argument("--box", action="store_true", help="If passed, starts with first prompt as box") parser.add_argument( "--use_masks", action="store_true", help="To use logits masks for iterative prompting." ) + parser.add_argument( + "--lora_rank", type=int, default=None, help="The rank of the LoRA if LoRA model is used for inference." 
+ ) args = parser.parse_args() return args diff --git a/finetuning/specialists/resource-efficient/covid_if_finetuning.py b/finetuning/specialists/resource-efficient/covid_if_finetuning.py index 26108721..016970bb 100644 --- a/finetuning/specialists/resource-efficient/covid_if_finetuning.py +++ b/finetuning/specialists/resource-efficient/covid_if_finetuning.py @@ -80,13 +80,13 @@ def finetune_covid_if(args): patch_shape = (512, 512) # the patch shape for training n_objects_per_batch = args.n_objects # the number of objects per batch that will be sampled freeze_parts = args.freeze # override this to freeze different parts of the model - checkpoint_name = f"{args.model_type}/covid_if_sam" - + checkpoint_name = f"{model_type}/{args.checkpoint_name}" # all stuff we need for training train_loader, val_loader = get_dataloaders( patch_shape=patch_shape, data_path=args.input_path, n_images=args.n_images ) scheduler_kwargs = {"mode": "min", "factor": 0.9, "patience": 3, "verbose": True} + optimizer_class = torch.optim.AdamW # Run training sam_training.train_sam( @@ -99,12 +99,13 @@ def finetune_covid_if(args): checkpoint_path=checkpoint_path, freeze=freeze_parts, device=device, - lr=1e-5, + lr=args.lr, n_epochs=args.epochs, save_root=args.save_root, scheduler_kwargs=scheduler_kwargs, save_every_kth_epoch=args.save_every_kth_epoch, - + optimizer_class=optimizer_class, + lora_rank=args.lora_rank ) @@ -148,6 +149,15 @@ def main(): parser.add_argument( "--n_images", type=int, default=None, help="The number of images used for finetuning." ) + parser.add_argument( + "--lora_rank", type=int, default=None, help="The rank of the LoRA model." + ) + parser.add_argument( + "--lr", type=float, default=5e-5, help="The learning rate for the optimizer. Default is 5e-5." + ) + parser.add_argument( + "--checkpoint_name", type=str, default="covid_if_sam", + ) args = parser.parse_args() finetune_covid_if(args) diff --git a/micro_sam/evaluation/inference.py b/micro_sam/evaluation/inference.py index 1905fc77..afe47e05 100644 --- a/micro_sam/evaluation/inference.py +++ b/micro_sam/evaluation/inference.py @@ -547,11 +547,12 @@ def run_amg( test_image_paths: List[Union[str, os.PathLike]], iou_thresh_values: Optional[List[float]] = None, stability_score_values: Optional[List[float]] = None, + lora_rank: Optional[int] = None, ) -> str: embedding_folder = os.path.join(experiment_folder, "embeddings") # where the precomputed embeddings are saved os.makedirs(embedding_folder, exist_ok=True) - predictor = util.get_sam_model(model_type=model_type, checkpoint_path=checkpoint) + predictor = util.get_sam_model(model_type=model_type, checkpoint_path=checkpoint, lora_rank=lora_rank) amg = AutomaticMaskGenerator(predictor) amg_prefix = "amg" @@ -588,11 +589,13 @@ def run_instance_segmentation_with_decoder( val_image_paths: List[Union[str, os.PathLike]], val_gt_paths: List[Union[str, os.PathLike]], test_image_paths: List[Union[str, os.PathLike]], + lora_rank: Optional[int] = None, ) -> str: + embedding_folder = os.path.join(experiment_folder, "embeddings") # where the precomputed embeddings are saved os.makedirs(embedding_folder, exist_ok=True) - predictor, decoder = get_predictor_and_decoder(model_type=model_type, checkpoint_path=checkpoint) + predictor, decoder = get_predictor_and_decoder(model_type=model_type, checkpoint_path=checkpoint, lora_rank=lora_rank) segmenter = InstanceSegmentationWithDecoder(predictor, decoder) seg_prefix = "instance_segmentation_with_decoder" diff --git a/micro_sam/instance_segmentation.py 
b/micro_sam/instance_segmentation.py index 23d666b9..6947303f 100644 --- a/micro_sam/instance_segmentation.py +++ b/micro_sam/instance_segmentation.py @@ -798,6 +798,7 @@ def get_predictor_and_decoder( model_type: str, checkpoint_path: Union[str, os.PathLike], device: Optional[Union[str, torch.device]] = None, + lora_rank: Optional[int] = None, ) -> Tuple[SamPredictor, DecoderAdapter]: """Load the SAM model (predictor) and instance segmentation decoder. @@ -816,7 +817,7 @@ def get_predictor_and_decoder( device = util.get_device(device) predictor, state = util.get_sam_model( model_type=model_type, checkpoint_path=checkpoint_path, - device=device, return_state=True + device=device, return_state=True, lora_rank=lora_rank ) if "decoder_state" not in state: raise ValueError(f"The checkpoint at {checkpoint_path} does not contain a decoder state") diff --git a/micro_sam/training/training.py b/micro_sam/training/training.py index bdb40168..abcaf7f0 100644 --- a/micro_sam/training/training.py +++ b/micro_sam/training/training.py @@ -148,6 +148,7 @@ def train_sam( save_every_kth_epoch: Optional[int] = None, pbar_signals: Optional[QObject] = None, optimizer_class: Optional[Optimizer] = torch.optim.AdamW, + lora_rank: Optional[int] = None, **model_kwargs, ) -> None: """Run training for a SAM model. @@ -183,12 +184,12 @@ def train_sam( If passed None, the chosen default parameters are used in ReduceLROnPlateau. save_every_kth_epoch: Save checkpoints after every kth epoch separately. pbar_signals: Controls for napari progress bar. + lora_rank: The rank of the LoRA Training """ _check_loader(train_loader, with_segmentation_decoder) _check_loader(val_loader, with_segmentation_decoder) device = get_device(device) - # Get the trainable segment anything model. model, state = get_trainable_sam_model( model_type=model_type, @@ -196,9 +197,9 @@ def train_sam( freeze=freeze, checkpoint_path=checkpoint_path, return_state=True, + lora_rank=lora_rank, **model_kwargs ) - # This class creates all the training data for a batch (inputs, prompts and labels). 
convert_inputs = ConvertToSamInputs(transform=model.transform, box_distortion_factor=0.025) From a523bc7c20d9ad1fc9324601d494c11f0f381a05 Mon Sep 17 00:00:00 2001 From: Carolin Teuber Date: Tue, 16 Jul 2024 14:52:35 +0200 Subject: [PATCH 02/15] minor changes in livecell training script and adaptation of evaluation script to include lora functionality --- finetuning/evaluation/submit_all_evaluation.py | 12 +++++++++--- finetuning/livecell/lora/train_livecell.py | 12 ++++++++---- 2 files changed, 17 insertions(+), 7 deletions(-) diff --git a/finetuning/evaluation/submit_all_evaluation.py b/finetuning/evaluation/submit_all_evaluation.py index b64549de..a481d8f4 100644 --- a/finetuning/evaluation/submit_all_evaluation.py +++ b/finetuning/evaluation/submit_all_evaluation.py @@ -14,7 +14,7 @@ def write_batch_script( env_name, out_path, inference_setup, checkpoint, model_type, - experiment_folder, dataset_name, delay=None, use_masks=False + experiment_folder, dataset_name, delay=None, use_masks=False, lora_rank=NotImplementedError ): "Writing scripts with different fold-trainings for micro-sam evaluation" batch_script = f"""#!/bin/bash @@ -23,7 +23,7 @@ def write_batch_script( #SBATCH -t 4-00:00:00 #SBATCH -p grete:shared #SBATCH -G A100:1 -#SBATCH -A gzz0001 +#SBATCH -A nim00007 #SBATCH --constraint=80gb #SBATCH --qos=96h #SBATCH --job-name={inference_setup} @@ -55,9 +55,13 @@ def write_batch_script( # use logits for iterative prompting if inference_setup == "iterative_prompting" and use_masks: python_script += "--use_masks " + + if lora_rank is not None: + python_script += f"--lora_rank {lora_rank} " # let's add the python script to the bash script batch_script += python_script + print(batch_script) with open(_op, "w") as f: f.write(batch_script) @@ -175,7 +179,8 @@ def submit_slurm(args): experiment_folder=experiment_folder, dataset_name=dataset_name, delay=None if current_setup == "precompute_embeddings" else make_delay, - use_masks=args.use_masks + use_masks=args.use_masks, + lora_rank=args.lora_rank ) # the logic below automates the process of first running the precomputation of embeddings, and only then inference. 
@@ -219,6 +224,7 @@ def main(args): # ask for a specific experiment parser.add_argument("-s", "--specific_experiment", type=str, default=None) + parser.add_argument("--lora_rank", type=int, default=None) args = parser.parse_args() main(args) diff --git a/finetuning/livecell/lora/train_livecell.py b/finetuning/livecell/lora/train_livecell.py index 6b12ac61..0bce2287 100644 --- a/finetuning/livecell/lora/train_livecell.py +++ b/finetuning/livecell/lora/train_livecell.py @@ -57,12 +57,13 @@ def finetune_livecell(args): patch_shape = (520, 704) # the patch shape for training n_objects_per_batch = args.n_objects # this is the number of objects per batch that will be sampled freeze_parts = args.freeze # override this to freeze different parts of the model - lora_rank = 4 # the rank for low rank adaptation - checkpoint_name = f"{args.model_type}/livecell_sam" + lora_rank = args.lora_rank # the rank for low rank adaptation + checkpoint_ending = f"{lora_rank}" if lora_rank is not None else "full_ft" + checkpoint_name = f"{args.model_type}/livecell_sam_{checkpoint_ending}" # all the stuff we need for training train_loader, val_loader = get_dataloaders(patch_shape=patch_shape, data_path=args.input_path) - scheduler_kwargs = {"mode": "min", "factor": 0.9, "patience": 10, "verbose": True} + scheduler_kwargs = {"mode": "min", "factor": 0.9, "patience": 3, "verbose": True} optimizer_class = torch.optim.AdamW # Run training. @@ -71,7 +72,7 @@ def finetune_livecell(args): model_type=model_type, train_loader=train_loader, val_loader=val_loader, - early_stopping=None, + early_stopping=10, n_objects_per_batch=n_objects_per_batch, checkpoint_path=checkpoint_path, freeze=freeze_parts, @@ -122,6 +123,9 @@ def main(): parser.add_argument( "--n_objects", type=int, default=25, help="The number of instances (objects) per batch used for finetuning." ) + parser.add_argument( + "--lora_rank", type=int, default=None, help="The rank for low rank adaptation." + ) args = parser.parse_args() finetune_livecell(args) From 8dd7556998c6fc89713f0c0fb85b0fe5a867dbf5 Mon Sep 17 00:00:00 2001 From: Carolin Teuber Date: Thu, 25 Jul 2024 08:33:13 +0200 Subject: [PATCH 03/15] changed default training to 100k epochs with early stopping for lora --- finetuning/livecell/lora/train_livecell.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/finetuning/livecell/lora/train_livecell.py b/finetuning/livecell/lora/train_livecell.py index 0bce2287..72fff5cc 100644 --- a/finetuning/livecell/lora/train_livecell.py +++ b/finetuning/livecell/lora/train_livecell.py @@ -109,7 +109,7 @@ def main(): help="Where to save the checkpoint and logs. By default they will be saved where this script is run." ) parser.add_argument( - "--iterations", type=int, default=int(1e4), + "--iterations", type=int, default=int(1e5), help="For how many iterations should the model be trained? By default 100k." 
) parser.add_argument( From 4af78b155136516afa4d520cff42fe4c5f006f98 Mon Sep 17 00:00:00 2001 From: GOESTERN-0886323 Date: Fri, 2 Aug 2024 15:37:47 +0200 Subject: [PATCH 04/15] removed necessity of checkpoint --- finetuning/evaluation/evaluate_amg.py | 2 +- finetuning/evaluation/util.py | 2 -- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/finetuning/evaluation/evaluate_amg.py b/finetuning/evaluation/evaluate_amg.py index 36ff8fb3..931a5b73 100644 --- a/finetuning/evaluation/evaluate_amg.py +++ b/finetuning/evaluation/evaluate_amg.py @@ -34,7 +34,7 @@ def eval_amg(dataset_name, prediction_folder, experiment_folder): def main(): args = get_default_arguments() - prediction_folder = run_amg_inference(args.dataset, args.model, ckpt, args.experiment_folder, args.lora_rank) + prediction_folder = run_amg_inference(args.dataset, args.model, args.checkpoint, args.experiment_folder, args.lora_rank) eval_amg(args.dataset, prediction_folder, args.experiment_folder) diff --git a/finetuning/evaluation/util.py b/finetuning/evaluation/util.py index 8a313298..5afde58d 100644 --- a/finetuning/evaluation/util.py +++ b/finetuning/evaluation/util.py @@ -81,8 +81,6 @@ def get_dataset_paths(dataset_name, split_choice): def get_model(model_type, ckpt, lora_rank): - if ckpt is None: - ckpt = VANILLA_MODELS[model_type] predictor = get_sam_model( model_type=model_type, checkpoint_path=ckpt, lora_rank=lora_rank, From c80377d324e75f0a8f888bd8eb60dc0b23018f56 Mon Sep 17 00:00:00 2001 From: Carolin Date: Tue, 13 Aug 2024 14:54:48 +0200 Subject: [PATCH 05/15] first draft of FacTSurgery --- micro_sam/models/peft_sam.py | 52 +++++++++++++++++++++++++++++++++++- 1 file changed, 51 insertions(+), 1 deletion(-) diff --git a/micro_sam/models/peft_sam.py b/micro_sam/models/peft_sam.py index 2bdeed70..e035e4de 100644 --- a/micro_sam/models/peft_sam.py +++ b/micro_sam/models/peft_sam.py @@ -56,6 +56,50 @@ def forward(self, x): return qkv +class FacT_Surgery(nn.Module): + """Operates on the attention layers for performing factorized attention. + + (Inspired from: https://github.com/cchen-cc/MA-SAM/blob/main/MA-SAM/sam_fact_tt_image_encoder.py) + + """ + + def __init__( + self, + rank: int, + block: nn.Module, + FacTu: nn.Module, + FacTv: nn.Module, + ): + super().__init__() + self.qkv_proj = block.attn.qkv + self.dim = self.qkv_proj.in_features + + self.q_FacTs = nn.Linear(rank, rank, bias=False) + self.v_FacTs = nn.Linear(rank, rank, bias=False) + + self.dp_q = nn.Dropout(0.1) + self.dp_v = nn.Dropout(0.1) + + self.FacTu = FacTu + self.FacTv = FacTv + + block.attn.qkv = self + + + def forward(self, x): + + qkv = self.qkv_proj(x) # B,N,N,3*org_C + new_q = self.FacTv(self.dp_q(self.q_FacTs(self.FacTu(x)))) + new_v = self.FacTv(self.dp_v(self.v_FacTs(self.FacTu(x)))) + # NOTE : Scaling Factor was set to 1 as it can be tuned via the learning rate + # Does it make sense to include it, in order to have similar learning rate as the original model? + qkv[:, :, :, : self.dim] += new_q + qkv[:, :, :, -self.dim:] += new_v + return qkv + + + + class PEFT_Sam(nn.Module): """Wraps the Segment Anything model's image encoder to different parameter efficient finetuning methods. 
@@ -79,6 +123,9 @@ def __init__( assert rank > 0 + self.FacTu = nn.Linear(model.dim, rank, bias=False) + self.FacTv = nn.Linear(rank, model.dim, bias=False) + if attention_layers_to_update: self.peft_layers = attention_layers_to_update else: # Applies PEFT to the image encoder by default @@ -95,8 +142,11 @@ def __init__( # If we only want specific layers with PEFT instead of all if t_layer_i not in self.peft_layers: continue + if peft_module == LoRASurgery: + peft_block = self.peft_module(rank=rank, block=blk) + else: + peft_block = self.peft_module(rank=rank, block=blk, FacTu=self.FacTu, FacTv=self.FacTv) - peft_block = self.peft_module(rank=rank, block=blk) self.peft_blocks.append(peft_block) self.peft_blocks = nn.ModuleList(self.peft_blocks) From 12d76b0afbcf29e8e100f0ed8fb2a902b8aa1b3a Mon Sep 17 00:00:00 2001 From: Carolin Date: Tue, 13 Aug 2024 15:28:33 +0200 Subject: [PATCH 06/15] added test for FacT and made minor change in PEFT_Sam --- micro_sam/models/peft_sam.py | 7 ++++--- test/test_models/test_peft_sam.py | 19 +++++++++++++++++-- 2 files changed, 21 insertions(+), 5 deletions(-) diff --git a/micro_sam/models/peft_sam.py b/micro_sam/models/peft_sam.py index e035e4de..ef16ac4d 100644 --- a/micro_sam/models/peft_sam.py +++ b/micro_sam/models/peft_sam.py @@ -56,7 +56,7 @@ def forward(self, x): return qkv -class FacT_Surgery(nn.Module): +class FacTSurgery(nn.Module): """Operates on the attention layers for performing factorized attention. (Inspired from: https://github.com/cchen-cc/MA-SAM/blob/main/MA-SAM/sam_fact_tt_image_encoder.py) @@ -123,8 +123,9 @@ def __init__( assert rank > 0 - self.FacTu = nn.Linear(model.dim, rank, bias=False) - self.FacTv = nn.Linear(rank, model.dim, bias=False) + dim = model.image_encoder.blocks[0].attn.qkv.in_features + self.FacTu = nn.Linear(dim, rank, bias=False) + self.FacTv = nn.Linear(rank, dim, bias=False) if attention_layers_to_update: self.peft_layers = attention_layers_to_update diff --git a/test/test_models/test_peft_sam.py b/test/test_models/test_peft_sam.py index 1af3ef2c..4c66702a 100644 --- a/test/test_models/test_peft_sam.py +++ b/test/test_models/test_peft_sam.py @@ -1,5 +1,6 @@ import unittest +from micro_sam.models.peft_sam import FacTSurgery, LoRASurgery import torch import micro_sam.util as util @@ -7,11 +8,25 @@ class TestPEFTSam(unittest.TestCase): model_type = "vit_b" - def test_peft_sam(self): + def test_lora_sam(self): from micro_sam.models.peft_sam import PEFT_Sam _, sam = util.get_sam_model(model_type=self.model_type, return_sam=True, device="cpu") - peft_sam = PEFT_Sam(sam, rank=2) + peft_sam = PEFT_Sam(sam, rank=2, peft_module=LoRASurgery) + + shape = (3, 1024, 1024) + expected_shape = (1, 3, 1024, 1024) + with torch.no_grad(): + batched_input = [{"image": torch.rand(*shape), "original_size": shape[1:]}] + output = peft_sam(batched_input, multimask_output=True) + masks = output[0]["masks"] + self.assertEqual(masks.shape, expected_shape) + + def test_fact_sam(self): + from micro_sam.models.peft_sam import PEFT_Sam + + _, sam = util.get_sam_model(model_type=self.model_type, return_sam=True, device="cpu") + peft_sam = PEFT_Sam(sam, rank=2, peft_module=FacTSurgery) shape = (3, 1024, 1024) expected_shape = (1, 3, 1024, 1024) From e6031a44e712a03eea56e9a413e53d18c8dbf432 Mon Sep 17 00:00:00 2001 From: Carolin Date: Wed, 14 Aug 2024 14:37:51 +0200 Subject: [PATCH 07/15] changed implementation in peft arguments --- micro_sam/models/peft_sam.py | 3 +-- micro_sam/training/training.py | 4 ++-- micro_sam/training/util.py | 8 ++++---- 
micro_sam/util.py | 8 ++++---- 4 files changed, 11 insertions(+), 12 deletions(-) diff --git a/micro_sam/models/peft_sam.py b/micro_sam/models/peft_sam.py index ef16ac4d..4378501f 100644 --- a/micro_sam/models/peft_sam.py +++ b/micro_sam/models/peft_sam.py @@ -77,6 +77,7 @@ def __init__( self.q_FacTs = nn.Linear(rank, rank, bias=False) self.v_FacTs = nn.Linear(rank, rank, bias=False) + # NOTE : Dropout is not included in the original implementation self.dp_q = nn.Dropout(0.1) self.dp_v = nn.Dropout(0.1) @@ -98,8 +99,6 @@ def forward(self, x): return qkv - - class PEFT_Sam(nn.Module): """Wraps the Segment Anything model's image encoder to different parameter efficient finetuning methods. diff --git a/micro_sam/training/training.py b/micro_sam/training/training.py index 065d53ff..4e0b72a9 100644 --- a/micro_sam/training/training.py +++ b/micro_sam/training/training.py @@ -154,7 +154,7 @@ def train_sam( save_every_kth_epoch: Optional[int] = None, pbar_signals: Optional[QObject] = None, optimizer_class: Optional[Optimizer] = torch.optim.AdamW, - lora_rank: Optional[int] = None, + peft_kwargs: Optional[Dict] = None, **model_kwargs, ) -> None: """Run training for a SAM model. @@ -204,7 +204,7 @@ def train_sam( freeze=freeze, checkpoint_path=checkpoint_path, return_state=True, - lora_rank=lora_rank, + peft_kwargs=peft_kwargs, **model_kwargs ) # This class creates all the training data for a batch (inputs, prompts and labels). diff --git a/micro_sam/training/util.py b/micro_sam/training/util.py index 759c905e..7650c2c3 100644 --- a/micro_sam/training/util.py +++ b/micro_sam/training/util.py @@ -44,8 +44,7 @@ def get_trainable_sam_model( checkpoint_path: Optional[Union[str, os.PathLike]] = None, freeze: Optional[List[str]] = None, return_state: bool = False, - lora_rank: Optional[int] = None, - lora_kwargs: Optional[Dict] = None, + peft_kwargs: Optional[Dict] = None, flexible_load_checkpoint: bool = False, **model_kwargs ) -> TrainableSAM: @@ -84,8 +83,9 @@ def get_trainable_sam_model( # NOTE: This is done exclusive to "get_sam_model" here to use PEFT's layer-specific initialization on top. # Whether to use Parameter Efficient Finetuning methods to wrap around Segment Anything. # Overwrites the SAM model by freezing the backbone and allow low rank adaption to attention layers. - if lora_rank is not None: - sam = custom_models.peft_sam.PEFT_Sam(sam, rank=lora_rank, **({} if lora_kwargs is None else lora_kwargs)).sam + if peft_kwargs is not None: + assert peft_kwargs['module'] in ['LoRASurgery', 'FacTSurgery'], "Invalid PEFT module." + sam = custom_models.peft_sam.PEFT_Sam(sam, rank=peft_kwargs['rank'], peft_module=peft_kwargs['module']).sam # freeze components of the model if freeze was passed # ideally we would want to add components in such a way that: diff --git a/micro_sam/util.py b/micro_sam/util.py index 45550a49..b5bee771 100644 --- a/micro_sam/util.py +++ b/micro_sam/util.py @@ -273,8 +273,7 @@ def get_sam_model( checkpoint_path: Optional[Union[str, os.PathLike]] = None, return_sam: bool = False, return_state: bool = False, - lora_rank: Optional[int] = None, - lora_kwargs: Optional[Dict] = None, + peft_kwargs: Optional[Dict] = None, flexible_load_checkpoint: bool = False, **model_kwargs, ) -> SamPredictor: @@ -371,10 +370,11 @@ def get_sam_model( # Whether to use Parameter Efficient Finetuning methods to wrap around Segment Anything. # Overwrites the SAM model by freezing the backbone and allow low rank adaption to attention layers. 
- if lora_rank is not None: + if peft_kwargs is not None: + assert peft_kwargs["module"] in ["LoRASurgery", "FacTSurgery"], "Invalid PEFT module." if abbreviated_model_type == "vit_t": raise ValueError("Parameter efficient finetuning is not supported for 'mobile-sam'.") - sam = custom_models.peft_sam.PEFT_Sam(sam, rank=lora_rank, **({} if lora_kwargs is None else lora_kwargs)).sam + sam = custom_models.peft_sam.PEFT_Sam(sam, rank=peft_kwargs['rank'], peft_module=peft_kwargs['module']).sam # In case the model checkpoints have some issues when it is initialized with different parameters than default. if flexible_load_checkpoint: From db64f572728532d1c1c4003692d02bb89d0ef225 Mon Sep 17 00:00:00 2001 From: Carolin Date: Sat, 31 Aug 2024 09:47:55 +0200 Subject: [PATCH 08/15] branch cleanup --- finetuning/evaluation/iterative_prompting.py | 1 - finetuning/evaluation/precompute_embeddings.py | 1 - finetuning/evaluation/submit_all_evaluation.py | 14 ++++---------- finetuning/livecell/lora/train_livecell.py | 16 ++++++---------- .../resource-efficient/covid_if_finetuning.py | 8 ++++---- micro_sam/evaluation/inference.py | 1 - 6 files changed, 14 insertions(+), 27 deletions(-) diff --git a/finetuning/evaluation/iterative_prompting.py b/finetuning/evaluation/iterative_prompting.py index 3909babf..eae3f845 100644 --- a/finetuning/evaluation/iterative_prompting.py +++ b/finetuning/evaluation/iterative_prompting.py @@ -5,7 +5,6 @@ from util import get_paths # comment this and create a custom function with the same name to run int. seg. on your data from util import get_model, get_default_arguments -from micro_sam.util import get_sam_model def _run_iterative_prompting(dataset_name, exp_folder, predictor, start_with_box_prompt, use_masks): diff --git a/finetuning/evaluation/precompute_embeddings.py b/finetuning/evaluation/precompute_embeddings.py index 7a85dda1..605627fe 100644 --- a/finetuning/evaluation/precompute_embeddings.py +++ b/finetuning/evaluation/precompute_embeddings.py @@ -4,7 +4,6 @@ from util import get_paths # comment this and create a custom function with the same name to execute on your data from util import get_model, get_default_arguments -from micro_sam.util import get_sam_model def main(): diff --git a/finetuning/evaluation/submit_all_evaluation.py b/finetuning/evaluation/submit_all_evaluation.py index a481d8f4..465d96df 100644 --- a/finetuning/evaluation/submit_all_evaluation.py +++ b/finetuning/evaluation/submit_all_evaluation.py @@ -14,7 +14,7 @@ def write_batch_script( env_name, out_path, inference_setup, checkpoint, model_type, - experiment_folder, dataset_name, delay=None, use_masks=False, lora_rank=NotImplementedError + experiment_folder, dataset_name, delay=None, use_masks=False ): "Writing scripts with different fold-trainings for micro-sam evaluation" batch_script = f"""#!/bin/bash @@ -23,7 +23,7 @@ def write_batch_script( #SBATCH -t 4-00:00:00 #SBATCH -p grete:shared #SBATCH -G A100:1 -#SBATCH -A nim00007 +#SBATCH -A gzz0001 #SBATCH --constraint=80gb #SBATCH --qos=96h #SBATCH --job-name={inference_setup} @@ -55,13 +55,9 @@ def write_batch_script( # use logits for iterative prompting if inference_setup == "iterative_prompting" and use_masks: python_script += "--use_masks " - - if lora_rank is not None: - python_script += f"--lora_rank {lora_rank} " # let's add the python script to the bash script batch_script += python_script - print(batch_script) with open(_op, "w") as f: f.write(batch_script) @@ -179,8 +175,7 @@ def submit_slurm(args): experiment_folder=experiment_folder, 
dataset_name=dataset_name, delay=None if current_setup == "precompute_embeddings" else make_delay, - use_masks=args.use_masks, - lora_rank=args.lora_rank + use_masks=args.use_masks ) # the logic below automates the process of first running the precomputation of embeddings, and only then inference. @@ -224,7 +219,6 @@ def main(args): # ask for a specific experiment parser.add_argument("-s", "--specific_experiment", type=str, default=None) - parser.add_argument("--lora_rank", type=int, default=None) args = parser.parse_args() - main(args) + main(args) \ No newline at end of file diff --git a/finetuning/livecell/lora/train_livecell.py b/finetuning/livecell/lora/train_livecell.py index 369c4f13..450a640c 100644 --- a/finetuning/livecell/lora/train_livecell.py +++ b/finetuning/livecell/lora/train_livecell.py @@ -51,13 +51,12 @@ def finetune_livecell(args): patch_shape = (520, 704) # the patch shape for training n_objects_per_batch = args.n_objects # this is the number of objects per batch that will be sampled freeze_parts = args.freeze # override this to freeze different parts of the model - lora_rank = args.lora_rank # the rank for low rank adaptation - checkpoint_ending = f"{lora_rank}" if lora_rank is not None else "full_ft" - checkpoint_name = f"{args.model_type}/livecell_sam_{checkpoint_ending}" + lora_rank = 4 # the rank for low rank adaptation + checkpoint_name = f"{args.model_type}/livecell_sam" # all the stuff we need for training train_loader, val_loader = get_dataloaders(patch_shape=patch_shape, data_path=args.input_path) - scheduler_kwargs = {"mode": "min", "factor": 0.9, "patience": 3, "verbose": True} + scheduler_kwargs = {"mode": "min", "factor": 0.9, "patience": 10, "verbose": True} optimizer_class = torch.optim.AdamW # Run training. @@ -66,7 +65,7 @@ def finetune_livecell(args): model_type=model_type, train_loader=train_loader, val_loader=val_loader, - early_stopping=10, + early_stopping=None, n_objects_per_batch=n_objects_per_batch, checkpoint_path=checkpoint_path, freeze=freeze_parts, @@ -103,7 +102,7 @@ def main(): help="Where to save the checkpoint and logs. By default they will be saved where this script is run." ) parser.add_argument( - "--iterations", type=int, default=int(1e5), + "--iterations", type=int, default=int(1e4), help="For how many iterations should the model be trained? By default 100k." ) parser.add_argument( @@ -117,12 +116,9 @@ def main(): parser.add_argument( "--n_objects", type=int, default=25, help="The number of instances (objects) per batch used for finetuning." ) - parser.add_argument( - "--lora_rank", type=int, default=None, help="The rank for low rank adaptation." 
- ) args = parser.parse_args() finetune_livecell(args) if __name__ == "__main__": - main() + main() \ No newline at end of file diff --git a/finetuning/specialists/resource-efficient/covid_if_finetuning.py b/finetuning/specialists/resource-efficient/covid_if_finetuning.py index d33046b5..632f8ba2 100644 --- a/finetuning/specialists/resource-efficient/covid_if_finetuning.py +++ b/finetuning/specialists/resource-efficient/covid_if_finetuning.py @@ -80,13 +80,13 @@ def finetune_covid_if(args): patch_shape = (512, 512) # the patch shape for training n_objects_per_batch = args.n_objects # the number of objects per batch that will be sampled freeze_parts = args.freeze # override this to freeze different parts of the model - checkpoint_name = f"{model_type}/{args.checkpoint_name}" + checkpoint_name = f"{args.model_type}/covid_if_sam" + # all stuff we need for training train_loader, val_loader = get_dataloaders( patch_shape=patch_shape, data_path=args.input_path, n_images=args.n_images ) scheduler_kwargs = {"mode": "min", "factor": 0.9, "patience": 3, "verbose": True} - optimizer_class = torch.optim.AdamW # Run training sam_training.train_sam( @@ -99,7 +99,7 @@ def finetune_covid_if(args): checkpoint_path=checkpoint_path, freeze=freeze_parts, device=device, - lr=args.lr, + lr=1e-5, n_epochs=args.epochs, save_root=args.save_root, scheduler_kwargs=scheduler_kwargs, @@ -158,4 +158,4 @@ def main(): if __name__ == "__main__": import warnings warnings.filterwarnings("ignore") - main() + main() \ No newline at end of file diff --git a/micro_sam/evaluation/inference.py b/micro_sam/evaluation/inference.py index f5c8a13a..8340b408 100644 --- a/micro_sam/evaluation/inference.py +++ b/micro_sam/evaluation/inference.py @@ -591,7 +591,6 @@ def run_instance_segmentation_with_decoder( test_image_paths: List[Union[str, os.PathLike]], lora_rank: Optional[int] = None, ) -> str: - embedding_folder = os.path.join(experiment_folder, "embeddings") # where the precomputed embeddings are saved os.makedirs(embedding_folder, exist_ok=True) From fee07c3845bf244cab7fddfe43c49640163f25ee Mon Sep 17 00:00:00 2001 From: Carolin Date: Sun, 1 Sep 2024 10:19:27 +0200 Subject: [PATCH 09/15] added flexible implementation of peft methods in the evaluation scripts --- finetuning/evaluation/evaluate_amg.py | 8 ++++---- finetuning/evaluation/evaluate_instance_segmentation.py | 4 ++-- finetuning/evaluation/iterative_prompting.py | 3 ++- finetuning/evaluation/util.py | 7 ++++--- micro_sam/evaluation/inference.py | 8 ++++---- micro_sam/instance_segmentation.py | 4 ++-- 6 files changed, 18 insertions(+), 16 deletions(-) diff --git a/finetuning/evaluation/evaluate_amg.py b/finetuning/evaluation/evaluate_amg.py index 931a5b73..e1e66b7b 100644 --- a/finetuning/evaluation/evaluate_amg.py +++ b/finetuning/evaluation/evaluate_amg.py @@ -7,7 +7,7 @@ from util import get_pred_paths, get_default_arguments, VANILLA_MODELS -def run_amg_inference(dataset_name, model_type, checkpoint, experiment_folder, lora_rank): +def run_amg_inference(dataset_name, model_type, checkpoint, experiment_folder, peft_kwargs): val_image_paths, val_gt_paths = get_paths(dataset_name, split="val") test_image_paths, _ = get_paths(dataset_name, split="test") prediction_folder = run_amg( @@ -17,7 +17,7 @@ def run_amg_inference(dataset_name, model_type, checkpoint, experiment_folder, l val_image_paths=val_image_paths, val_gt_paths=val_gt_paths, test_image_paths=test_image_paths, - lora_rank=lora_rank, + peft_kwargs=peft_kwargs, ) return prediction_folder @@ -33,8 +33,8 @@ def 
eval_amg(dataset_name, prediction_folder, experiment_folder): def main(): args = get_default_arguments() - - prediction_folder = run_amg_inference(args.dataset, args.model, args.checkpoint, args.experiment_folder, args.lora_rank) + peft_kwargs = {"rank": args.peft_rank, "module": args.peft_module} + prediction_folder = run_amg_inference(args.dataset, args.model, args.checkpoint, args.experiment_folder, peft_kwargs) eval_amg(args.dataset, prediction_folder, args.experiment_folder) diff --git a/finetuning/evaluation/evaluate_instance_segmentation.py b/finetuning/evaluation/evaluate_instance_segmentation.py index c41e9fb4..49f4b717 100644 --- a/finetuning/evaluation/evaluate_instance_segmentation.py +++ b/finetuning/evaluation/evaluate_instance_segmentation.py @@ -35,9 +35,9 @@ def eval_instance_segmentation_with_decoder(dataset_name, prediction_folder, exp def main(): args = get_default_arguments() - + peft_kwargs = {"rank": args.peft_rank, "module": args.peft_module} prediction_folder = run_instance_segmentation_with_decoder_inference( - args.dataset, args.model, args.checkpoint, args.experiment_folder, args.lora_rank, + args.dataset, args.model, args.checkpoint, args.experiment_folder, peft_kwargs, ) eval_instance_segmentation_with_decoder(args.dataset, prediction_folder, args.experiment_folder) diff --git a/finetuning/evaluation/iterative_prompting.py b/finetuning/evaluation/iterative_prompting.py index eae3f845..2dd09930 100644 --- a/finetuning/evaluation/iterative_prompting.py +++ b/finetuning/evaluation/iterative_prompting.py @@ -42,7 +42,8 @@ def main(): start_with_box_prompt = args.box # overwrite to start first iters' prompt with box instead of single point # get the predictor to perform inference - predictor = get_model(model_type=args.model, ckpt=args.checkpoint, lora_rank=args.lora_rank) + peft_kwargs = {"rank": args.peft_rank, "module": args.peft_module} + predictor = get_model(model_type=args.model, ckpt=args.checkpoint, peft_kwargs=peft_kwargs) prediction_root = _run_iterative_prompting( args.dataset, args.experiment_folder, predictor, start_with_box_prompt, args.use_masks diff --git a/finetuning/evaluation/util.py b/finetuning/evaluation/util.py index 5afde58d..3dedaa02 100644 --- a/finetuning/evaluation/util.py +++ b/finetuning/evaluation/util.py @@ -80,10 +80,10 @@ def get_dataset_paths(dataset_name, split_choice): return raw_dir, labels_dir -def get_model(model_type, ckpt, lora_rank): +def get_model(model_type, ckpt, peft_kwargs): predictor = get_sam_model( - model_type=model_type, checkpoint_path=ckpt, lora_rank=lora_rank, + model_type=model_type, checkpoint_path=ckpt, peft_kwargs=peft_kwargs, ) return predictor @@ -227,7 +227,8 @@ def get_default_arguments(): parser.add_argument( "--use_masks", action="store_true", help="To use logits masks for iterative prompting." ) - parser.add_argument("--lora_rank", default=None, type=int, help="The rank for low rank adaptation method.") + parser.add_argument("--peft_rank", default=None, type=int, help="The rank for peft method.") + parser.add_argument("--peft_module", default=None, type=int, help="The module for peft method. (e.g. 
LoRA or FacT)") args = parser.parse_args() return args diff --git a/micro_sam/evaluation/inference.py b/micro_sam/evaluation/inference.py index 8340b408..e1736fa5 100644 --- a/micro_sam/evaluation/inference.py +++ b/micro_sam/evaluation/inference.py @@ -547,12 +547,12 @@ def run_amg( test_image_paths: List[Union[str, os.PathLike]], iou_thresh_values: Optional[List[float]] = None, stability_score_values: Optional[List[float]] = None, - lora_rank: Optional[int] = None, + peft_kwargs: Optional[Dict] = None, ) -> str: embedding_folder = os.path.join(experiment_folder, "embeddings") # where the precomputed embeddings are saved os.makedirs(embedding_folder, exist_ok=True) - predictor = util.get_sam_model(model_type=model_type, checkpoint_path=checkpoint, lora_rank=lora_rank) + predictor = util.get_sam_model(model_type=model_type, checkpoint_path=checkpoint, peft_kwargs=peft_kwargs) amg = AutomaticMaskGenerator(predictor) amg_prefix = "amg" @@ -589,13 +589,13 @@ def run_instance_segmentation_with_decoder( val_image_paths: List[Union[str, os.PathLike]], val_gt_paths: List[Union[str, os.PathLike]], test_image_paths: List[Union[str, os.PathLike]], - lora_rank: Optional[int] = None, + peft_kwargs: Optional[Dict] = None, ) -> str: embedding_folder = os.path.join(experiment_folder, "embeddings") # where the precomputed embeddings are saved os.makedirs(embedding_folder, exist_ok=True) predictor, decoder = get_predictor_and_decoder( - model_type=model_type, checkpoint_path=checkpoint, lora_rank=lora_rank, + model_type=model_type, checkpoint_path=checkpoint, peft_kwargs=peft_kwargs, ) segmenter = InstanceSegmentationWithDecoder(predictor, decoder) seg_prefix = "instance_segmentation_with_decoder" diff --git a/micro_sam/instance_segmentation.py b/micro_sam/instance_segmentation.py index c8780428..80654e67 100644 --- a/micro_sam/instance_segmentation.py +++ b/micro_sam/instance_segmentation.py @@ -800,7 +800,7 @@ def get_predictor_and_decoder( model_type: str, checkpoint_path: Union[str, os.PathLike], device: Optional[Union[str, torch.device]] = None, - lora_rank: Optional[int] = None, + peft_kwargs: Optional[Dict] = None, ) -> Tuple[SamPredictor, DecoderAdapter]: """Load the SAM model (predictor) and instance segmentation decoder. 
@@ -823,7 +823,7 @@ def get_predictor_and_decoder( checkpoint_path=checkpoint_path, device=device, return_state=True, - lora_rank=lora_rank, + peft_kwargs=peft_kwargs, ) if "decoder_state" not in state: raise ValueError(f"The checkpoint at {checkpoint_path} does not contain a decoder state") From cfb7b4353d70d2661676406aae4495fd38d517a3 Mon Sep 17 00:00:00 2001 From: Carolin Date: Sun, 1 Sep 2024 10:21:30 +0200 Subject: [PATCH 10/15] changed lora_rank to peft_rank --- micro_sam/training/util.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/micro_sam/training/util.py b/micro_sam/training/util.py index 7650c2c3..a06e7502 100644 --- a/micro_sam/training/util.py +++ b/micro_sam/training/util.py @@ -100,7 +100,7 @@ def get_trainable_sam_model( # we would want to "freeze" all the components in the model if passed a list of parts for l_item in freeze: # in case LoRA is switched on, we cannot freeze the image encoder - if (lora_rank is not None) and (l_item == "image_encoder"): + if (peft_kwargs['rank'] is not None) and (l_item == "image_encoder"): raise ValueError("You cannot use LoRA & freeze the image encoder at the same time.") if name.startswith(f"{l_item}"): From b770399e154f5090e539f90582bf8a4e8b1c4b43 Mon Sep 17 00:00:00 2001 From: Anwai Archit Date: Thu, 5 Sep 2024 22:12:12 +0530 Subject: [PATCH 11/15] Refactor peft methods and fix tests --- finetuning/evaluation/evaluate_amg.py | 10 +- .../evaluate_instance_segmentation.py | 10 +- finetuning/evaluation/iterative_prompting.py | 11 +- .../evaluation/precompute_embeddings.py | 12 +- finetuning/evaluation/util.py | 11 +- finetuning/livecell/lora/train_livecell.py | 124 ------------------ micro_sam/models/peft_sam.py | 45 ++++--- micro_sam/models/sam_3d_wrapper.py | 2 +- micro_sam/models/simple_sam_3d_wrapper.py | 2 +- micro_sam/util.py | 13 +- 10 files changed, 64 insertions(+), 176 deletions(-) delete mode 100644 finetuning/livecell/lora/train_livecell.py diff --git a/finetuning/evaluation/evaluate_amg.py b/finetuning/evaluation/evaluate_amg.py index e1e66b7b..f171b9af 100644 --- a/finetuning/evaluation/evaluate_amg.py +++ b/finetuning/evaluation/evaluate_amg.py @@ -3,8 +3,10 @@ from micro_sam.evaluation.evaluation import run_evaluation from micro_sam.evaluation.inference import run_amg -from util import get_paths # comment this and create a custom function with the same name to run amg on your data -from util import get_pred_paths, get_default_arguments, VANILLA_MODELS +from util import ( + get_paths, # comment this line out and create a custom function with the same name to run amg on your data + get_pred_paths, get_default_arguments +) def run_amg_inference(dataset_name, model_type, checkpoint, experiment_folder, peft_kwargs): @@ -34,7 +36,9 @@ def eval_amg(dataset_name, prediction_folder, experiment_folder): def main(): args = get_default_arguments() peft_kwargs = {"rank": args.peft_rank, "module": args.peft_module} - prediction_folder = run_amg_inference(args.dataset, args.model, args.checkpoint, args.experiment_folder, peft_kwargs) + prediction_folder = run_amg_inference( + args.dataset, args.model, args.checkpoint, args.experiment_folder, peft_kwargs + ) eval_amg(args.dataset, prediction_folder, args.experiment_folder) diff --git a/finetuning/evaluation/evaluate_instance_segmentation.py b/finetuning/evaluation/evaluate_instance_segmentation.py index 49f4b717..bd311d57 100644 --- a/finetuning/evaluation/evaluate_instance_segmentation.py +++ b/finetuning/evaluation/evaluate_instance_segmentation.py @@ -3,12 +3,14 @@ 
from micro_sam.evaluation.evaluation import run_evaluation from micro_sam.evaluation.inference import run_instance_segmentation_with_decoder -from util import get_paths # comment this and create a custom function with the same name to run ais on your data -from util import get_pred_paths, get_default_arguments +from util import ( + get_paths, # comment this line out and create a custom function with the same name to run ais on your data + get_pred_paths, get_default_arguments +) def run_instance_segmentation_with_decoder_inference( - dataset_name, model_type, checkpoint, experiment_folder, lora_rank + dataset_name, model_type, checkpoint, experiment_folder, peft_kwargs, ): val_image_paths, val_gt_paths = get_paths(dataset_name, split="val") test_image_paths, _ = get_paths(dataset_name, split="test") @@ -19,7 +21,7 @@ def run_instance_segmentation_with_decoder_inference( val_image_paths=val_image_paths, val_gt_paths=val_gt_paths, test_image_paths=test_image_paths, - lora_rank=lora_rank, + peft_kwargs=peft_kwargs, ) return prediction_folder diff --git a/finetuning/evaluation/iterative_prompting.py b/finetuning/evaluation/iterative_prompting.py index 2dd09930..b261f4d9 100644 --- a/finetuning/evaluation/iterative_prompting.py +++ b/finetuning/evaluation/iterative_prompting.py @@ -1,10 +1,13 @@ import os +from micro_sam.util import get_sam_model from micro_sam.evaluation import inference from micro_sam.evaluation.evaluation import run_evaluation_for_iterative_prompting -from util import get_paths # comment this and create a custom function with the same name to run int. seg. on your data -from util import get_model, get_default_arguments +from util import ( + get_paths, # comment this line out and create a custom function with the same name to run int. seg. 
on your data + get_default_arguments +) def _run_iterative_prompting(dataset_name, exp_folder, predictor, start_with_box_prompt, use_masks): @@ -43,7 +46,9 @@ def main(): # get the predictor to perform inference peft_kwargs = {"rank": args.peft_rank, "module": args.peft_module} - predictor = get_model(model_type=args.model, ckpt=args.checkpoint, peft_kwargs=peft_kwargs) + predictor = get_sam_model( + model_type=args.model, checkpoint_path=args.checkpoint, peft_kwargs=peft_kwargs, + ) prediction_root = _run_iterative_prompting( args.dataset, args.experiment_folder, predictor, start_with_box_prompt, args.use_masks diff --git a/finetuning/evaluation/precompute_embeddings.py b/finetuning/evaluation/precompute_embeddings.py index 605627fe..40438906 100644 --- a/finetuning/evaluation/precompute_embeddings.py +++ b/finetuning/evaluation/precompute_embeddings.py @@ -1,15 +1,21 @@ import os +from micro_sam.util import get_sam_model from micro_sam.evaluation import precompute_all_embeddings -from util import get_paths # comment this and create a custom function with the same name to execute on your data -from util import get_model, get_default_arguments +from util import ( + get_paths, # comment this and create a custom function with the same name to execute on your data + get_default_arguments +) def main(): args = get_default_arguments() - predictor = get_model(model_type=args.model, ckpt=args.checkpoint, lora_rank=args.lora_rank) + peft_kwargs = {"rank": args.peft_rank, "module": args.peft_module} + predictor = get_sam_model( + model_type=args.model, checkpoint_path=args.checkpoint, peft_kwargs=peft_kwargs, + ) embedding_dir = os.path.join(args.experiment_folder, "embeddings") os.makedirs(embedding_dir, exist_ok=True) diff --git a/finetuning/evaluation/util.py b/finetuning/evaluation/util.py index 3dedaa02..d55009ee 100644 --- a/finetuning/evaluation/util.py +++ b/finetuning/evaluation/util.py @@ -5,7 +5,6 @@ from torch_em.data import datasets -from micro_sam.util import get_sam_model from micro_sam.evaluation.livecell import _get_livecell_paths @@ -80,14 +79,6 @@ def get_dataset_paths(dataset_name, split_choice): return raw_dir, labels_dir -def get_model(model_type, ckpt, peft_kwargs): - - predictor = get_sam_model( - model_type=model_type, checkpoint_path=ckpt, peft_kwargs=peft_kwargs, - ) - return predictor - - def get_paths(dataset_name, split): assert dataset_name in DATASETS, dataset_name @@ -228,7 +219,7 @@ def get_default_arguments(): "--use_masks", action="store_true", help="To use logits masks for iterative prompting." ) parser.add_argument("--peft_rank", default=None, type=int, help="The rank for peft method.") - parser.add_argument("--peft_module", default=None, type=int, help="The module for peft method. (e.g. LoRA or FacT)") + parser.add_argument("--peft_module", default=None, type=str, help="The module for peft method. (e.g. 
LoRA or FacT)") args = parser.parse_args() return args diff --git a/finetuning/livecell/lora/train_livecell.py b/finetuning/livecell/lora/train_livecell.py deleted file mode 100644 index 31977b21..00000000 --- a/finetuning/livecell/lora/train_livecell.py +++ /dev/null @@ -1,124 +0,0 @@ -import os -import argparse - -import torch - -from torch_em.data.datasets import get_livecell_loader -from torch_em.transform.label import PerObjectDistanceTransform - -import micro_sam.training as sam_training -from micro_sam.util import export_custom_sam_model - - -def get_dataloaders(patch_shape, data_path, cell_type=None): - """This returns the livecell data loaders implemented in torch_em: - https://github.com/constantinpape/torch-em/blob/main/torch_em/data/datasets/livecell.py - It will automatically download the livecell data. - - Note: to replace this with another data loader you need to return a torch data loader - that retuns `x, y` tensors, where `x` is the image data and `y` are the labels. - The labels have to be in a label mask instance segmentation format. - I.e. a tensor of the same spatial shape as `x`, with each object mask having its own ID. - Important: the ID 0 is reseved for background, and the IDs must be consecutive - """ - label_transform = PerObjectDistanceTransform( - distances=True, boundary_distances=True, directed_distances=False, foreground=True, instances=True, min_size=25 - ) - raw_transform = sam_training.identity # the current workflow avoids rescaling the inputs to [-1, 1] - train_loader = get_livecell_loader( - path=data_path, patch_shape=patch_shape, split="train", batch_size=2, num_workers=16, - cell_types=cell_type, download=True, shuffle=True, label_transform=label_transform, - raw_transform=raw_transform, label_dtype=torch.float32, - ) - val_loader = get_livecell_loader( - path=data_path, patch_shape=patch_shape, split="val", batch_size=4, num_workers=16, - cell_types=cell_type, download=True, shuffle=True, label_transform=label_transform, - raw_transform=raw_transform, label_dtype=torch.float32, - ) - - return train_loader, val_loader - - -def finetune_livecell(args): - """Code for finetuning SAM (using LoRA) on LIVECell - """ - # override this (below) if you have some more complex set-up and need to specify the exact gpu - device = "cuda" if torch.cuda.is_available() else "cpu" - - # training settings: - model_type = args.model_type - checkpoint_path = None # override this to start training from a custom checkpoint - patch_shape = (520, 704) # the patch shape for training - n_objects_per_batch = args.n_objects # this is the number of objects per batch that will be sampled - freeze_parts = args.freeze # override this to freeze different parts of the model - lora_rank = 4 # the rank for low rank adaptation - checkpoint_name = f"{args.model_type}/livecell_sam" - - # all the stuff we need for training - train_loader, val_loader = get_dataloaders(patch_shape=patch_shape, data_path=args.input_path) - scheduler_kwargs = {"mode": "min", "factor": 0.9, "patience": 10, "verbose": True} - optimizer_class = torch.optim.AdamW - - # Run training. 
- sam_training.train_sam( - name=checkpoint_name, - model_type=model_type, - train_loader=train_loader, - val_loader=val_loader, - early_stopping=None, - n_objects_per_batch=n_objects_per_batch, - checkpoint_path=checkpoint_path, - freeze=freeze_parts, - device=device, - lr=1e-5, - n_iterations=args.iterations, - save_root=args.save_root, - scheduler_kwargs=scheduler_kwargs, - optimizer_class=optimizer_class, - lora_rank=lora_rank, - ) - - if args.export_path is not None: - checkpoint_path = os.path.join( - "" if args.save_root is None else args.save_root, "checkpoints", checkpoint_name, "best.pt" - ) - export_custom_sam_model( - checkpoint_path=checkpoint_path, model_type=model_type, save_path=args.export_path, - ) - - -def main(): - parser = argparse.ArgumentParser(description="Finetune Segment Anything for the LIVECell dataset.") - parser.add_argument( - "--input_path", "-i", default="/scratch/projects/nim00007/sam/data/livecell/", - help="The filepath to the LIVECell data. If the data does not exist yet it will be downloaded." - ) - parser.add_argument( - "--model_type", "-m", default="vit_b", - help="The model type to use for fine-tuning. Either vit_h, vit_b or vit_l." - ) - parser.add_argument( - "--save_root", "-s", default=None, - help="Where to save the checkpoint and logs. By default they will be saved where this script is run." - ) - parser.add_argument( - "--iterations", type=int, default=int(1e4), - help="For how many iterations should the model be trained? By default 100k." - ) - parser.add_argument( - "--export_path", "-e", - help="Where to export the finetuned model to. The exported model can be used in the annotation tools." - ) - parser.add_argument( - "--freeze", type=str, nargs="+", default=None, - help="Which parts of the model to freeze for finetuning." - ) - parser.add_argument( - "--n_objects", type=int, default=25, help="The number of instances (objects) per batch used for finetuning." - ) - args = parser.parse_args() - finetune_livecell(args) - - -if __name__ == "__main__": - main() diff --git a/micro_sam/models/peft_sam.py b/micro_sam/models/peft_sam.py index 4378501f..1639049c 100644 --- a/micro_sam/models/peft_sam.py +++ b/micro_sam/models/peft_sam.py @@ -61,18 +61,19 @@ class FacTSurgery(nn.Module): (Inspired from: https://github.com/cchen-cc/MA-SAM/blob/main/MA-SAM/sam_fact_tt_image_encoder.py) + Args: + rank: The rank of the decomposition matrices for updating weights in each attention layer. + block: The chosen attention blocks for implementing fact. 
""" def __init__( - self, - rank: int, - block: nn.Module, - FacTu: nn.Module, - FacTv: nn.Module, + self, + rank: int, + block: nn.Module, ): super().__init__() self.qkv_proj = block.attn.qkv - self.dim = self.qkv_proj.in_features + self.dim = self.qkv_proj.in_features self.q_FacTs = nn.Linear(rank, rank, bias=False) self.v_FacTs = nn.Linear(rank, rank, bias=False) @@ -81,29 +82,34 @@ def __init__( self.dp_q = nn.Dropout(0.1) self.dp_v = nn.Dropout(0.1) - self.FacTu = FacTu - self.FacTv = FacTv + self.FacTu = nn.Linear(self.dim, rank, bias=False) + self.FacTv = nn.Linear(rank, self.dim, bias=False) block.attn.qkv = self - def forward(self, x): + qkv = self.qkv_proj(x) # B, N, N, 3 * org_C + + new_q = self.q_FacTs(self.FacTu(x)) + new_v = self.v_FacTs(self.FacTu(x)) + + new_q = self.dp_q(new_q) + new_v = self.dp_v(new_v) + + new_q = self.FacTv(new_q) + new_v = self.FacTv(new_v) - qkv = self.qkv_proj(x) # B,N,N,3*org_C - new_q = self.FacTv(self.dp_q(self.q_FacTs(self.FacTu(x)))) - new_v = self.FacTv(self.dp_v(self.v_FacTs(self.FacTu(x)))) # NOTE : Scaling Factor was set to 1 as it can be tuned via the learning rate # Does it make sense to include it, in order to have similar learning rate as the original model? qkv[:, :, :, : self.dim] += new_q qkv[:, :, :, -self.dim:] += new_v + return qkv - + class PEFT_Sam(nn.Module): """Wraps the Segment Anything model's image encoder to different parameter efficient finetuning methods. - Inspired by https://github.com/JamesQFreeman/Sam_LoRA/ - Args: model: The Segment Anything model. rank: The rank for low-rank adaptation. @@ -122,10 +128,6 @@ def __init__( assert rank > 0 - dim = model.image_encoder.blocks[0].attn.qkv.in_features - self.FacTu = nn.Linear(dim, rank, bias=False) - self.FacTv = nn.Linear(rank, dim, bias=False) - if attention_layers_to_update: self.peft_layers = attention_layers_to_update else: # Applies PEFT to the image encoder by default @@ -142,11 +144,8 @@ def __init__( # If we only want specific layers with PEFT instead of all if t_layer_i not in self.peft_layers: continue - if peft_module == LoRASurgery: - peft_block = self.peft_module(rank=rank, block=blk) - else: - peft_block = self.peft_module(rank=rank, block=blk, FacTu=self.FacTu, FacTv=self.FacTv) + peft_block = self.peft_module(rank=rank, block=blk) self.peft_blocks.append(peft_block) self.peft_blocks = nn.ModuleList(self.peft_blocks) diff --git a/micro_sam/models/sam_3d_wrapper.py b/micro_sam/models/sam_3d_wrapper.py index 3e0b7573..ac5bb7bb 100644 --- a/micro_sam/models/sam_3d_wrapper.py +++ b/micro_sam/models/sam_3d_wrapper.py @@ -28,7 +28,7 @@ def get_sam_3d_model( flexible_load_checkpoint=True, num_multimask_outputs=n_classes, image_size=image_size, - lora_rank=lora_rank, + peft_kwargs={"rank": lora_rank}, ) sam_3d = Sam3DWrapper(sam, freeze_encoder=freeze_encoder_) diff --git a/micro_sam/models/simple_sam_3d_wrapper.py b/micro_sam/models/simple_sam_3d_wrapper.py index 6f67caa4..6b7032a9 100644 --- a/micro_sam/models/simple_sam_3d_wrapper.py +++ b/micro_sam/models/simple_sam_3d_wrapper.py @@ -23,7 +23,7 @@ def get_simple_sam_3d_model( return_sam=True, image_size=image_size, flexible_load_checkpoint=True, - lora_rank=lora_rank, + peft_kwargs={"rank": lora_rank}, ) # Make sure not to freeze the encoder when using LoRA. 
diff --git a/micro_sam/util.py b/micro_sam/util.py
index b5bee771..91f283da 100644
--- a/micro_sam/util.py
+++ b/micro_sam/util.py
@@ -370,11 +370,16 @@ def get_sam_model(
     # Whether to use Parameter Efficient Finetuning methods to wrap around Segment Anything.
     # Overwrites the SAM model by freezing the backbone and allow low rank adaption to attention layers.
-    if peft_kwargs is not None:
-        assert peft_kwargs["module"] in ["LoRASurgery", "FacTSurgery"], "Invalid PEFT module."
+    if peft_kwargs is not None and isinstance(peft_kwargs, dict):
         if abbreviated_model_type == "vit_t":
-            raise ValueError("Parameter efficient finetuning is not supported for 'mobile-sam'.")
-        sam = custom_models.peft_sam.PEFT_Sam(sam, rank=peft_kwargs['rank'], peft_module=peft_kwargs['module']).sam
+            raise ValueError("'micro-sam' does not support parameter efficient finetuning for 'mobile-sam'.")
+
+        peft_module = peft_kwargs.get("peft_module")
+        if peft_module is not None:
+            from .models.peft_sam import LoRASurgery, FacTSurgery
+            assert peft_module in [LoRASurgery, FacTSurgery], "Invalid PEFT module."
+
+        sam = custom_models.peft_sam.PEFT_Sam(sam, **peft_kwargs).sam

     # In case the model checkpoints have some issues when it is initialized with different parameters than default.
     if flexible_load_checkpoint:

From 9d0fb5d14c85e4e46d4a0215cf9136014b447918 Mon Sep 17 00:00:00 2001
From: anwai98
Date: Sat, 7 Sep 2024 08:52:10 +0200
Subject: [PATCH 12/15] Fix tests

---
 micro_sam/models/sam_3d_wrapper.py        | 10 +++++++---
 micro_sam/models/simple_sam_3d_wrapper.py |  6 +++++-
 micro_sam/util.py                         |  2 +-
 3 files changed, 13 insertions(+), 5 deletions(-)

diff --git a/micro_sam/models/sam_3d_wrapper.py b/micro_sam/models/sam_3d_wrapper.py
index ac5bb7bb..c6a76d96 100644
--- a/micro_sam/models/sam_3d_wrapper.py
+++ b/micro_sam/models/sam_3d_wrapper.py
@@ -18,8 +18,10 @@ def get_sam_3d_model(
     model_type="vit_b",
     checkpoint_path=None,
 ):
-    # Make sure not to freeze the encoder when using LoRA.
-    freeze_encoder_ = freeze_encoder if lora_rank is None else False
+    peft_kwargs = {}
+    if lora_rank is not None:
+        peft_kwargs["rank"] = lora_rank
+
     _, sam = get_sam_model(
         model_type=model_type,
         device=device,
@@ -28,9 +30,11 @@ def get_sam_3d_model(
         flexible_load_checkpoint=True,
         num_multimask_outputs=n_classes,
         image_size=image_size,
-        peft_kwargs={"rank": lora_rank},
+        peft_kwargs=peft_kwargs,
     )

+    # Make sure not to freeze the encoder when using LoRA.
+    freeze_encoder_ = freeze_encoder if lora_rank is None else False
     sam_3d = Sam3DWrapper(sam, freeze_encoder=freeze_encoder_)
     sam_3d.to(device)
     return sam_3d
diff --git a/micro_sam/models/simple_sam_3d_wrapper.py b/micro_sam/models/simple_sam_3d_wrapper.py
index 6b7032a9..47d2d60b 100644
--- a/micro_sam/models/simple_sam_3d_wrapper.py
+++ b/micro_sam/models/simple_sam_3d_wrapper.py
@@ -16,6 +16,10 @@ def get_simple_sam_3d_model(
     model_type="vit_b",
     checkpoint_path=None,
 ):
+    peft_kwargs = {}
+    if lora_rank is not None:
+        peft_kwargs["rank"] = lora_rank
+
     _, sam = get_sam_model(
         model_type=model_type,
         device=device,
@@ -23,7 +27,7 @@ def get_simple_sam_3d_model(
         return_sam=True,
         image_size=image_size,
         flexible_load_checkpoint=True,
-        peft_kwargs={"rank": lora_rank},
+        peft_kwargs=peft_kwargs,
     )

     # Make sure not to freeze the encoder when using LoRA.
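
[Editor's note, not part of the patch] With the change above, the 3D wrappers only populate peft_kwargs when a LoRA rank is given, and get_sam_model unpacks the dict straight into PEFT_Sam. A hedged usage sketch follows; the full get_sam_model signature is not shown in this patch series, so the calls below are illustrative rather than authoritative.

    # Sketch of driving the new peft_kwargs plumbing from caller code.
    from micro_sam.util import get_sam_model
    from micro_sam.models.peft_sam import LoRASurgery

    # No PEFT: an empty dict is falsy, so the PEFT_Sam wrapping is skipped entirely.
    predictor = get_sam_model(model_type="vit_b", peft_kwargs={})

    # LoRA with rank 4: the dict is unpacked into PEFT_Sam(sam, **peft_kwargs), so its
    # keys must match PEFT_Sam's parameters (rank, peft_module, ...).
    predictor = get_sam_model(
        model_type="vit_b",
        peft_kwargs={"rank": 4, "peft_module": LoRASurgery},
    )
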
diff --git a/micro_sam/util.py b/micro_sam/util.py
index 91f283da..af1d1dc6 100644
--- a/micro_sam/util.py
+++ b/micro_sam/util.py
@@ -370,7 +370,7 @@ def get_sam_model(
     # Whether to use Parameter Efficient Finetuning methods to wrap around Segment Anything.
     # Overwrites the SAM model by freezing the backbone and allow low rank adaption to attention layers.
-    if peft_kwargs is not None and isinstance(peft_kwargs, dict):
+    if peft_kwargs and isinstance(peft_kwargs, dict):
         if abbreviated_model_type == "vit_t":
             raise ValueError("'micro-sam' does not support parameter efficient finetuning for 'mobile-sam'.")

From 5ef1835e83405ddd58f5d14e9bd94b6ccaa569de Mon Sep 17 00:00:00 2001
From: Anwai Archit <52396323+anwai98@users.noreply.github.com>
Date: Mon, 9 Sep 2024 14:14:03 +0530
Subject: [PATCH 13/15] Update micro_sam/training/util.py

Co-authored-by: Carolin Teuber <115626873+caroteu@users.noreply.github.com>
---
 micro_sam/training/util.py | 13 ++++++++++---
 1 file changed, 10 insertions(+), 3 deletions(-)

diff --git a/micro_sam/training/util.py b/micro_sam/training/util.py
index a06e7502..fb4834c0 100644
--- a/micro_sam/training/util.py
+++ b/micro_sam/training/util.py
@@ -83,9 +83,16 @@ def get_trainable_sam_model(
     # NOTE: This is done exclusive to "get_sam_model" here to use PEFT's layer-specific initialization on top.
     # Whether to use Parameter Efficient Finetuning methods to wrap around Segment Anything.
     # Overwrites the SAM model by freezing the backbone and allow low rank adaption to attention layers.
-    if peft_kwargs is not None:
-        assert peft_kwargs['module'] in ['LoRASurgery', 'FacTSurgery'], "Invalid PEFT module."
-        sam = custom_models.peft_sam.PEFT_Sam(sam, rank=peft_kwargs['rank'], peft_module=peft_kwargs['module']).sam
+    if peft_kwargs and isinstance(peft_kwargs, dict):
+        if model_type[:5] == "vit_t":
+            raise ValueError("'micro-sam' does not support parameter efficient finetuning for 'mobile-sam'.")
+
+        peft_module = peft_kwargs.get("peft_module")
+        if peft_module is not None:
+            from micro_sam.models.peft_sam import LoRASurgery, FacTSurgery
+            assert peft_module in [LoRASurgery, FacTSurgery], "Invalid PEFT module."
+
+        sam = custom_models.peft_sam.PEFT_Sam(sam, **peft_kwargs).sam

     # freeze components of the model if freeze was passed
     # ideally we would want to add components in such a way that:

From 9aaf2ba091277156492f96f530ca56e4adcf01cf Mon Sep 17 00:00:00 2001
From: Anwai Archit <52396323+anwai98@users.noreply.github.com>
Date: Tue, 17 Sep 2024 13:23:49 +0200
Subject: [PATCH 14/15] Make dropout optional

---
 micro_sam/models/peft_sam.py | 16 ++++++++++------
 1 file changed, 10 insertions(+), 6 deletions(-)

diff --git a/micro_sam/models/peft_sam.py b/micro_sam/models/peft_sam.py
index 1639049c..59167a1d 100644
--- a/micro_sam/models/peft_sam.py
+++ b/micro_sam/models/peft_sam.py
@@ -1,5 +1,5 @@
 import math
-from typing import List, Union
+from typing import List, Union, Optional

 import torch.nn as nn

@@ -70,6 +70,7 @@ def __init__(
         self,
         rank: int,
         block: nn.Module,
+        dropout: Optional[float] = None,
     ):
         super().__init__()
         self.qkv_proj = block.attn.qkv
@@ -78,9 +79,11 @@ def __init__(
         self.q_FacTs = nn.Linear(rank, rank, bias=False)
         self.v_FacTs = nn.Linear(rank, rank, bias=False)

-        # NOTE : Dropout is not included in the original implementation
-        self.dp_q = nn.Dropout(0.1)
-        self.dp_v = nn.Dropout(0.1)
+        self.dropout = dropout
+        if self.dropout is not None:
+            # NOTE : Dropout is not included in the original implementation
+            self.dp_q = nn.Dropout(self.dropout)
+            self.dp_v = nn.Dropout(self.dropout)

         self.FacTu = nn.Linear(self.dim, rank, bias=False)
         self.FacTv = nn.Linear(rank, self.dim, bias=False)
@@ -93,8 +96,9 @@ def forward(self, x):
         new_q = self.q_FacTs(self.FacTu(x))
         new_v = self.v_FacTs(self.FacTu(x))

-        new_q = self.dp_q(new_q)
-        new_v = self.dp_v(new_v)
+        if self.dropout is not None:
+            new_q = self.dp_q(new_q)
+            new_v = self.dp_v(new_v)

         new_q = self.FacTv(new_q)
         new_v = self.FacTv(new_v)

From 1de0cbdf7b1ecea276332c42b2ecc566d12e3882 Mon Sep 17 00:00:00 2001
From: Anwai Archit <52396323+anwai98@users.noreply.github.com>
Date: Tue, 17 Sep 2024 14:23:05 +0200
Subject: [PATCH 15/15] Update test_peft_sam.py

---
 test/test_models/test_peft_sam.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/test/test_models/test_peft_sam.py b/test/test_models/test_peft_sam.py
index 4c66702a..509a6765 100644
--- a/test/test_models/test_peft_sam.py
+++ b/test/test_models/test_peft_sam.py
@@ -1,7 +1,7 @@
 import unittest

-from micro_sam.models.peft_sam import FacTSurgery, LoRASurgery
 import torch
+
 import micro_sam.util as util

@@ -9,7 +9,7 @@ class TestPEFTSam(unittest.TestCase):
     model_type = "vit_b"

     def test_lora_sam(self):
-        from micro_sam.models.peft_sam import PEFT_Sam
+        from micro_sam.models.peft_sam import PEFT_Sam, LoRASurgery

         _, sam = util.get_sam_model(model_type=self.model_type, return_sam=True, device="cpu")
         peft_sam = PEFT_Sam(sam, rank=2, peft_module=LoRASurgery)
@@ -23,7 +23,7 @@ def test_lora_sam(self):
         self.assertEqual(masks.shape, expected_shape)

     def test_fact_sam(self):
-        from micro_sam.models.peft_sam import PEFT_Sam
+        from micro_sam.models.peft_sam import PEFT_Sam, FacTSurgery

         _, sam = util.get_sam_model(model_type=self.model_type, return_sam=True, device="cpu")
         peft_sam = PEFT_Sam(sam, rank=2, peft_module=FacTSurgery)
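
[Editor's note, not part of the patch] To complement the updated tests, here is a sketch of exercising the optional dropout added in PATCH 14. The tests only show PEFT_Sam(sam, rank, peft_module), so whether extra keyword arguments such as dropout are forwarded through PEFT_Sam is an assumption; the sketch therefore applies FacTSurgery to individual blocks directly, mirroring the test setup otherwise.

    # Sketch: FacTSurgery with and without the new optional dropout.
    import micro_sam.util as util
    from micro_sam.models.peft_sam import FacTSurgery

    _, sam = util.get_sam_model(model_type="vit_b", return_sam=True, device="cpu")

    # Dropout disabled (default): dp_q/dp_v are never created and the forward pass skips them.
    no_dropout = FacTSurgery(rank=2, block=sam.image_encoder.blocks[0])

    # Dropout enabled: dp_q/dp_v become nn.Dropout(0.1), applied between the FacT projections.
    with_dropout = FacTSurgery(rank=2, block=sam.image_encoder.blocks[1], dropout=0.1)
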