diff --git a/finetuning/evaluation/evaluate_amg.py b/finetuning/evaluation/evaluate_amg.py index 8f8f132d..f171b9af 100644 --- a/finetuning/evaluation/evaluate_amg.py +++ b/finetuning/evaluation/evaluate_amg.py @@ -3,11 +3,13 @@ from micro_sam.evaluation.evaluation import run_evaluation from micro_sam.evaluation.inference import run_amg -from util import get_paths # comment this and create a custom function with the same name to run amg on your data -from util import get_pred_paths, get_default_arguments, VANILLA_MODELS +from util import ( + get_paths, # comment this line out and create a custom function with the same name to run amg on your data + get_pred_paths, get_default_arguments +) -def run_amg_inference(dataset_name, model_type, checkpoint, experiment_folder, lora_rank): +def run_amg_inference(dataset_name, model_type, checkpoint, experiment_folder, peft_kwargs): val_image_paths, val_gt_paths = get_paths(dataset_name, split="val") test_image_paths, _ = get_paths(dataset_name, split="test") prediction_folder = run_amg( @@ -17,7 +19,7 @@ def run_amg_inference(dataset_name, model_type, checkpoint, experiment_folder, l val_image_paths=val_image_paths, val_gt_paths=val_gt_paths, test_image_paths=test_image_paths, - lora_rank=lora_rank, + peft_kwargs=peft_kwargs, ) return prediction_folder @@ -33,12 +35,10 @@ def eval_amg(dataset_name, prediction_folder, experiment_folder): def main(): args = get_default_arguments() - if args.checkpoint is None: - ckpt = VANILLA_MODELS[args.model] - else: - ckpt = args.checkpoint - - prediction_folder = run_amg_inference(args.dataset, args.model, ckpt, args.experiment_folder, args.lora_rank) + peft_kwargs = {"rank": args.peft_rank, "module": args.peft_module} + prediction_folder = run_amg_inference( + args.dataset, args.model, args.checkpoint, args.experiment_folder, peft_kwargs + ) eval_amg(args.dataset, prediction_folder, args.experiment_folder) diff --git a/finetuning/evaluation/evaluate_instance_segmentation.py b/finetuning/evaluation/evaluate_instance_segmentation.py index c41e9fb4..bd311d57 100644 --- a/finetuning/evaluation/evaluate_instance_segmentation.py +++ b/finetuning/evaluation/evaluate_instance_segmentation.py @@ -3,12 +3,14 @@ from micro_sam.evaluation.evaluation import run_evaluation from micro_sam.evaluation.inference import run_instance_segmentation_with_decoder -from util import get_paths # comment this and create a custom function with the same name to run ais on your data -from util import get_pred_paths, get_default_arguments +from util import ( + get_paths, # comment this line out and create a custom function with the same name to run ais on your data + get_pred_paths, get_default_arguments +) def run_instance_segmentation_with_decoder_inference( - dataset_name, model_type, checkpoint, experiment_folder, lora_rank + dataset_name, model_type, checkpoint, experiment_folder, peft_kwargs, ): val_image_paths, val_gt_paths = get_paths(dataset_name, split="val") test_image_paths, _ = get_paths(dataset_name, split="test") @@ -19,7 +21,7 @@ def run_instance_segmentation_with_decoder_inference( val_image_paths=val_image_paths, val_gt_paths=val_gt_paths, test_image_paths=test_image_paths, - lora_rank=lora_rank, + peft_kwargs=peft_kwargs, ) return prediction_folder @@ -35,9 +37,9 @@ def eval_instance_segmentation_with_decoder(dataset_name, prediction_folder, exp def main(): args = get_default_arguments() - + peft_kwargs = {"rank": args.peft_rank, "module": args.peft_module} prediction_folder = run_instance_segmentation_with_decoder_inference( 
- args.dataset, args.model, args.checkpoint, args.experiment_folder, args.lora_rank, + args.dataset, args.model, args.checkpoint, args.experiment_folder, peft_kwargs, ) eval_instance_segmentation_with_decoder(args.dataset, prediction_folder, args.experiment_folder) diff --git a/finetuning/evaluation/iterative_prompting.py b/finetuning/evaluation/iterative_prompting.py index eae3f845..b261f4d9 100644 --- a/finetuning/evaluation/iterative_prompting.py +++ b/finetuning/evaluation/iterative_prompting.py @@ -1,10 +1,13 @@ import os +from micro_sam.util import get_sam_model from micro_sam.evaluation import inference from micro_sam.evaluation.evaluation import run_evaluation_for_iterative_prompting -from util import get_paths # comment this and create a custom function with the same name to run int. seg. on your data -from util import get_model, get_default_arguments +from util import ( + get_paths, # comment this line out and create a custom function with the same name to run int. seg. on your data + get_default_arguments +) def _run_iterative_prompting(dataset_name, exp_folder, predictor, start_with_box_prompt, use_masks): @@ -42,7 +45,10 @@ def main(): start_with_box_prompt = args.box # overwrite to start first iters' prompt with box instead of single point # get the predictor to perform inference - predictor = get_model(model_type=args.model, ckpt=args.checkpoint, lora_rank=args.lora_rank) + peft_kwargs = {"rank": args.peft_rank, "module": args.peft_module} + predictor = get_sam_model( + model_type=args.model, checkpoint_path=args.checkpoint, peft_kwargs=peft_kwargs, + ) prediction_root = _run_iterative_prompting( args.dataset, args.experiment_folder, predictor, start_with_box_prompt, args.use_masks diff --git a/finetuning/evaluation/precompute_embeddings.py b/finetuning/evaluation/precompute_embeddings.py index 605627fe..40438906 100644 --- a/finetuning/evaluation/precompute_embeddings.py +++ b/finetuning/evaluation/precompute_embeddings.py @@ -1,15 +1,21 @@ import os +from micro_sam.util import get_sam_model from micro_sam.evaluation import precompute_all_embeddings -from util import get_paths # comment this and create a custom function with the same name to execute on your data -from util import get_model, get_default_arguments +from util import ( + get_paths, # comment this and create a custom function with the same name to execute on your data + get_default_arguments +) def main(): args = get_default_arguments() - predictor = get_model(model_type=args.model, ckpt=args.checkpoint, lora_rank=args.lora_rank) + peft_kwargs = {"rank": args.peft_rank, "module": args.peft_module} + predictor = get_sam_model( + model_type=args.model, checkpoint_path=args.checkpoint, peft_kwargs=peft_kwargs, + ) embedding_dir = os.path.join(args.experiment_folder, "embeddings") os.makedirs(embedding_dir, exist_ok=True) diff --git a/finetuning/evaluation/submit_all_evaluation.py b/finetuning/evaluation/submit_all_evaluation.py index b64549de..465d96df 100644 --- a/finetuning/evaluation/submit_all_evaluation.py +++ b/finetuning/evaluation/submit_all_evaluation.py @@ -221,4 +221,4 @@ def main(args): parser.add_argument("-s", "--specific_experiment", type=str, default=None) args = parser.parse_args() - main(args) + main(args) \ No newline at end of file diff --git a/finetuning/evaluation/util.py b/finetuning/evaluation/util.py index 9780cc70..d55009ee 100644 --- a/finetuning/evaluation/util.py +++ b/finetuning/evaluation/util.py @@ -5,7 +5,6 @@ from torch_em.data import datasets -from micro_sam.util import 
get_sam_model from micro_sam.evaluation.livecell import _get_livecell_paths @@ -80,16 +79,6 @@ def get_dataset_paths(dataset_name, split_choice): return raw_dir, labels_dir -def get_model(model_type, ckpt, lora_rank): - if ckpt is None: - ckpt = VANILLA_MODELS[model_type] - - predictor = get_sam_model( - model_type=model_type, checkpoint_path=ckpt, lora_rank=lora_rank, - ) - return predictor - - def get_paths(dataset_name, split): assert dataset_name in DATASETS, dataset_name @@ -222,14 +211,15 @@ def get_default_arguments(): parser.add_argument( "-m", "--model", type=str, required=True, help="Provide the model type to initialize the predictor" ) - parser.add_argument("-c", "--checkpoint", type=none_or_str, required=True, default=None) + parser.add_argument("-c", "--checkpoint", type=none_or_str, default=None) parser.add_argument("-e", "--experiment_folder", type=str, required=True) parser.add_argument("-d", "--dataset", type=str, default=None) parser.add_argument("--box", action="store_true", help="If passed, starts with first prompt as box") parser.add_argument( "--use_masks", action="store_true", help="To use logits masks for iterative prompting." ) - parser.add_argument("--lora_rank", default=None, type=int, help="The rank for low rank adaptation method.") + parser.add_argument("--peft_rank", default=None, type=int, help="The rank for peft method.") + parser.add_argument("--peft_module", default=None, type=str, help="The module for peft method. (e.g. LoRA or FacT)") args = parser.parse_args() return args diff --git a/finetuning/specialists/resource-efficient/covid_if_finetuning.py b/finetuning/specialists/resource-efficient/covid_if_finetuning.py index 2107e7f1..632f8ba2 100644 --- a/finetuning/specialists/resource-efficient/covid_if_finetuning.py +++ b/finetuning/specialists/resource-efficient/covid_if_finetuning.py @@ -158,4 +158,4 @@ def main(): if __name__ == "__main__": import warnings warnings.filterwarnings("ignore") - main() + main() \ No newline at end of file diff --git a/micro_sam/evaluation/inference.py b/micro_sam/evaluation/inference.py index 8340b408..e1736fa5 100644 --- a/micro_sam/evaluation/inference.py +++ b/micro_sam/evaluation/inference.py @@ -547,12 +547,12 @@ def run_amg( test_image_paths: List[Union[str, os.PathLike]], iou_thresh_values: Optional[List[float]] = None, stability_score_values: Optional[List[float]] = None, - lora_rank: Optional[int] = None, + peft_kwargs: Optional[Dict] = None, ) -> str: embedding_folder = os.path.join(experiment_folder, "embeddings") # where the precomputed embeddings are saved os.makedirs(embedding_folder, exist_ok=True) - predictor = util.get_sam_model(model_type=model_type, checkpoint_path=checkpoint, lora_rank=lora_rank) + predictor = util.get_sam_model(model_type=model_type, checkpoint_path=checkpoint, peft_kwargs=peft_kwargs) amg = AutomaticMaskGenerator(predictor) amg_prefix = "amg" @@ -589,13 +589,13 @@ def run_instance_segmentation_with_decoder( val_image_paths: List[Union[str, os.PathLike]], val_gt_paths: List[Union[str, os.PathLike]], test_image_paths: List[Union[str, os.PathLike]], - lora_rank: Optional[int] = None, + peft_kwargs: Optional[Dict] = None, ) -> str: embedding_folder = os.path.join(experiment_folder, "embeddings") # where the precomputed embeddings are saved os.makedirs(embedding_folder, exist_ok=True) predictor, decoder = get_predictor_and_decoder( - model_type=model_type, checkpoint_path=checkpoint, lora_rank=lora_rank, + model_type=model_type, checkpoint_path=checkpoint, peft_kwargs=peft_kwargs, ) 
     segmenter = InstanceSegmentationWithDecoder(predictor, decoder)
     seg_prefix = "instance_segmentation_with_decoder"
diff --git a/micro_sam/instance_segmentation.py b/micro_sam/instance_segmentation.py
index c8780428..80654e67 100644
--- a/micro_sam/instance_segmentation.py
+++ b/micro_sam/instance_segmentation.py
@@ -800,7 +800,7 @@ def get_predictor_and_decoder(
     model_type: str,
     checkpoint_path: Union[str, os.PathLike],
     device: Optional[Union[str, torch.device]] = None,
-    lora_rank: Optional[int] = None,
+    peft_kwargs: Optional[Dict] = None,
 ) -> Tuple[SamPredictor, DecoderAdapter]:
     """Load the SAM model (predictor) and instance segmentation decoder.
 
@@ -823,7 +823,7 @@ def get_predictor_and_decoder(
         checkpoint_path=checkpoint_path,
         device=device,
         return_state=True,
-        lora_rank=lora_rank,
+        peft_kwargs=peft_kwargs,
     )
     if "decoder_state" not in state:
         raise ValueError(f"The checkpoint at {checkpoint_path} does not contain a decoder state")
diff --git a/micro_sam/models/peft_sam.py b/micro_sam/models/peft_sam.py
index 2bdeed70..59167a1d 100644
--- a/micro_sam/models/peft_sam.py
+++ b/micro_sam/models/peft_sam.py
@@ -1,5 +1,5 @@
 import math
-from typing import List, Union
+from typing import List, Union, Optional
 
 import torch.nn as nn
 
@@ -56,11 +56,65 @@ def forward(self, x):
         return qkv
 
 
+class FacTSurgery(nn.Module):
+    """Operates on the attention layers for performing factorized attention.
+
+    (Inspired by: https://github.com/cchen-cc/MA-SAM/blob/main/MA-SAM/sam_fact_tt_image_encoder.py)
+
+    Args:
+        rank: The rank of the decomposition matrices for updating weights in each attention layer.
+        block: The chosen attention blocks for implementing FacT.
+        dropout: The dropout rate to apply to the factorized attention updates.
+    """
+
+    def __init__(
+        self,
+        rank: int,
+        block: nn.Module,
+        dropout: Optional[float] = None,
+    ):
+        super().__init__()
+        self.qkv_proj = block.attn.qkv
+        self.dim = self.qkv_proj.in_features
+
+        self.q_FacTs = nn.Linear(rank, rank, bias=False)
+        self.v_FacTs = nn.Linear(rank, rank, bias=False)
+
+        self.dropout = dropout
+        if self.dropout is not None:
+            # NOTE: Dropout is not included in the original implementation.
+            self.dp_q = nn.Dropout(self.dropout)
+            self.dp_v = nn.Dropout(self.dropout)
+
+        self.FacTu = nn.Linear(self.dim, rank, bias=False)
+        self.FacTv = nn.Linear(rank, self.dim, bias=False)
+
+        block.attn.qkv = self
+
+    def forward(self, x):
+        qkv = self.qkv_proj(x)  # B, N, N, 3 * org_C
+
+        new_q = self.q_FacTs(self.FacTu(x))
+        new_v = self.v_FacTs(self.FacTu(x))
+
+        if self.dropout is not None:
+            new_q = self.dp_q(new_q)
+            new_v = self.dp_v(new_v)
+
+        new_q = self.FacTv(new_q)
+        new_v = self.FacTv(new_v)
+
+        # NOTE: The scaling factor is fixed to 1, since its effect can equally be tuned via the learning rate.
+        # It could be exposed as a parameter to keep the effective learning rate comparable to the original model.
+        qkv[:, :, :, : self.dim] += new_q
+        qkv[:, :, :, -self.dim:] += new_v
+
+        return qkv
+
+
 class PEFT_Sam(nn.Module):
     """Wraps the Segment Anything model's image encoder to different parameter efficient finetuning methods.
 
-    Inspired by https://github.com/JamesQFreeman/Sam_LoRA/
-
     Args:
         model: The Segment Anything model.
         rank: The rank for low-rank adaptation.
diff --git a/micro_sam/models/sam_3d_wrapper.py b/micro_sam/models/sam_3d_wrapper.py
index 3e0b7573..c6a76d96 100644
--- a/micro_sam/models/sam_3d_wrapper.py
+++ b/micro_sam/models/sam_3d_wrapper.py
@@ -18,8 +18,10 @@ def get_sam_3d_model(
     model_type="vit_b",
     checkpoint_path=None,
 ):
-    # Make sure not to freeze the encoder when using LoRA.
- freeze_encoder_ = freeze_encoder if lora_rank is None else False + peft_kwargs = {} + if lora_rank is not None: + peft_kwargs["rank"] = lora_rank + _, sam = get_sam_model( model_type=model_type, device=device, @@ -28,9 +30,11 @@ def get_sam_3d_model( flexible_load_checkpoint=True, num_multimask_outputs=n_classes, image_size=image_size, - lora_rank=lora_rank, + peft_kwargs=peft_kwargs, ) + # Make sure not to freeze the encoder when using LoRA. + freeze_encoder_ = freeze_encoder if lora_rank is None else False sam_3d = Sam3DWrapper(sam, freeze_encoder=freeze_encoder_) sam_3d.to(device) return sam_3d diff --git a/micro_sam/models/simple_sam_3d_wrapper.py b/micro_sam/models/simple_sam_3d_wrapper.py index 6f67caa4..47d2d60b 100644 --- a/micro_sam/models/simple_sam_3d_wrapper.py +++ b/micro_sam/models/simple_sam_3d_wrapper.py @@ -16,6 +16,10 @@ def get_simple_sam_3d_model( model_type="vit_b", checkpoint_path=None, ): + peft_kwargs = {} + if lora_rank is not None: + peft_kwargs["rank"] = lora_rank + _, sam = get_sam_model( model_type=model_type, device=device, @@ -23,7 +27,7 @@ def get_simple_sam_3d_model( return_sam=True, image_size=image_size, flexible_load_checkpoint=True, - lora_rank=lora_rank, + peft_kwargs=peft_kwargs, ) # Make sure not to freeze the encoder when using LoRA. diff --git a/micro_sam/training/training.py b/micro_sam/training/training.py index bb31fa38..4e0b72a9 100644 --- a/micro_sam/training/training.py +++ b/micro_sam/training/training.py @@ -154,6 +154,7 @@ def train_sam( save_every_kth_epoch: Optional[int] = None, pbar_signals: Optional[QObject] = None, optimizer_class: Optional[Optimizer] = torch.optim.AdamW, + peft_kwargs: Optional[Dict] = None, **model_kwargs, ) -> None: """Run training for a SAM model. @@ -196,7 +197,6 @@ def train_sam( _check_loader(val_loader, with_segmentation_decoder) device = get_device(device) - # Get the trainable segment anything model. model, state = get_trainable_sam_model( model_type=model_type, @@ -204,9 +204,9 @@ def train_sam( freeze=freeze, checkpoint_path=checkpoint_path, return_state=True, + peft_kwargs=peft_kwargs, **model_kwargs ) - # This class creates all the training data for a batch (inputs, prompts and labels). convert_inputs = ConvertToSamInputs(transform=model.transform, box_distortion_factor=0.025) diff --git a/micro_sam/training/util.py b/micro_sam/training/util.py index 759c905e..fb4834c0 100644 --- a/micro_sam/training/util.py +++ b/micro_sam/training/util.py @@ -44,8 +44,7 @@ def get_trainable_sam_model( checkpoint_path: Optional[Union[str, os.PathLike]] = None, freeze: Optional[List[str]] = None, return_state: bool = False, - lora_rank: Optional[int] = None, - lora_kwargs: Optional[Dict] = None, + peft_kwargs: Optional[Dict] = None, flexible_load_checkpoint: bool = False, **model_kwargs ) -> TrainableSAM: @@ -84,8 +83,16 @@ def get_trainable_sam_model( # NOTE: This is done exclusive to "get_sam_model" here to use PEFT's layer-specific initialization on top. # Whether to use Parameter Efficient Finetuning methods to wrap around Segment Anything. # Overwrites the SAM model by freezing the backbone and allow low rank adaption to attention layers. 
-    if lora_rank is not None:
-        sam = custom_models.peft_sam.PEFT_Sam(sam, rank=lora_rank, **({} if lora_kwargs is None else lora_kwargs)).sam
+    if peft_kwargs and isinstance(peft_kwargs, dict):
+        if model_type[:5] == "vit_t":
+            raise ValueError("'micro-sam' does not support parameter efficient finetuning for 'mobile-sam'.")
+
+        peft_module = peft_kwargs.get("peft_module")
+        if peft_module is not None:
+            from micro_sam.models.peft_sam import LoRASurgery, FacTSurgery
+            assert peft_module in [LoRASurgery, FacTSurgery], "Invalid PEFT module."
+
+        sam = custom_models.peft_sam.PEFT_Sam(sam, **peft_kwargs).sam
 
     # freeze components of the model if freeze was passed
     # ideally we would want to add components in such a way that:
@@ -100,7 +107,7 @@
             # we would want to "freeze" all the components in the model if passed a list of parts
             for l_item in freeze:
                 # in case LoRA is switched on, we cannot freeze the image encoder
-                if (lora_rank is not None) and (l_item == "image_encoder"):
+                if (peft_kwargs and peft_kwargs.get("rank") is not None) and (l_item == "image_encoder"):
                     raise ValueError("You cannot use LoRA & freeze the image encoder at the same time.")
 
                 if name.startswith(f"{l_item}"):
diff --git a/micro_sam/util.py b/micro_sam/util.py
index 45550a49..af1d1dc6 100644
--- a/micro_sam/util.py
+++ b/micro_sam/util.py
@@ -273,8 +273,7 @@ def get_sam_model(
     checkpoint_path: Optional[Union[str, os.PathLike]] = None,
     return_sam: bool = False,
     return_state: bool = False,
-    lora_rank: Optional[int] = None,
-    lora_kwargs: Optional[Dict] = None,
+    peft_kwargs: Optional[Dict] = None,
     flexible_load_checkpoint: bool = False,
     **model_kwargs,
 ) -> SamPredictor:
@@ -371,10 +370,16 @@
 
     # Whether to use Parameter Efficient Finetuning methods to wrap around Segment Anything.
     # Overwrites the SAM model by freezing the backbone and allow low rank adaption to attention layers.
-    if lora_rank is not None:
+    if peft_kwargs and isinstance(peft_kwargs, dict):
         if abbreviated_model_type == "vit_t":
-            raise ValueError("Parameter efficient finetuning is not supported for 'mobile-sam'.")
-        sam = custom_models.peft_sam.PEFT_Sam(sam, rank=lora_rank, **({} if lora_kwargs is None else lora_kwargs)).sam
+            raise ValueError("'micro-sam' does not support parameter efficient finetuning for 'mobile-sam'.")
+
+        peft_module = peft_kwargs.get("peft_module")
+        if peft_module is not None:
+            from .models.peft_sam import LoRASurgery, FacTSurgery
+            assert peft_module in [LoRASurgery, FacTSurgery], "Invalid PEFT module."
+
+        sam = custom_models.peft_sam.PEFT_Sam(sam, **peft_kwargs).sam
 
     # In case the model checkpoints have some issues when it is initialized with different parameters than default.
if flexible_load_checkpoint: diff --git a/test/test_models/test_peft_sam.py b/test/test_models/test_peft_sam.py index 1af3ef2c..509a6765 100644 --- a/test/test_models/test_peft_sam.py +++ b/test/test_models/test_peft_sam.py @@ -1,17 +1,32 @@ import unittest import torch + import micro_sam.util as util class TestPEFTSam(unittest.TestCase): model_type = "vit_b" - def test_peft_sam(self): - from micro_sam.models.peft_sam import PEFT_Sam + def test_lora_sam(self): + from micro_sam.models.peft_sam import PEFT_Sam, LoRASurgery + + _, sam = util.get_sam_model(model_type=self.model_type, return_sam=True, device="cpu") + peft_sam = PEFT_Sam(sam, rank=2, peft_module=LoRASurgery) + + shape = (3, 1024, 1024) + expected_shape = (1, 3, 1024, 1024) + with torch.no_grad(): + batched_input = [{"image": torch.rand(*shape), "original_size": shape[1:]}] + output = peft_sam(batched_input, multimask_output=True) + masks = output[0]["masks"] + self.assertEqual(masks.shape, expected_shape) + + def test_fact_sam(self): + from micro_sam.models.peft_sam import PEFT_Sam, FacTSurgery _, sam = util.get_sam_model(model_type=self.model_type, return_sam=True, device="cpu") - peft_sam = PEFT_Sam(sam, rank=2) + peft_sam = PEFT_Sam(sam, rank=2, peft_module=FacTSurgery) shape = (3, 1024, 1024) expected_shape = (1, 3, 1024, 1024)