From e0a53b8903b5b964efe012071ecc69215a052b81 Mon Sep 17 00:00:00 2001
From: anwai98
Date: Sat, 28 Sep 2024 22:58:45 +0200
Subject: [PATCH 1/6] Add selective peft methods

---
 micro_sam/models/peft_sam.py | 78 ++++++++++++++++++++++++++++++------
 micro_sam/training/util.py   | 17 +++-----
 micro_sam/util.py            | 11 +----
 3 files changed, 73 insertions(+), 33 deletions(-)

diff --git a/micro_sam/models/peft_sam.py b/micro_sam/models/peft_sam.py
index 59167a1df..50a182f99 100644
--- a/micro_sam/models/peft_sam.py
+++ b/micro_sam/models/peft_sam.py
@@ -23,11 +23,7 @@ class LoRASurgery(nn.Module):
         rank: The rank of the decomposition matrices for updating weights in each attention layer.
         block: The chosen attention blocks for implementing lora.
     """
-    def __init__(
-        self,
-        rank: int,
-        block: nn.Module,
-    ):
+    def __init__(self, rank: int, block: nn.Module):
         super().__init__()
         self.qkv_proj = block.attn.qkv
         self.dim = self.qkv_proj.in_features
@@ -66,12 +62,7 @@ class FacTSurgery(nn.Module):
         block: The chosen attention blocks for implementing fact.
     """
 
-    def __init__(
-        self,
-        rank: int,
-        block: nn.Module,
-        dropout: Optional[float] = None,
-    ):
+    def __init__(self, rank: int, block: nn.Module, dropout: Optional[float] = None):
         super().__init__()
         self.qkv_proj = block.attn.qkv
         self.dim = self.qkv_proj.in_features
@@ -111,6 +102,64 @@ def forward(self, x):
         return qkv
 
 
+class SelectiveSurgery(nn.Module):
+    """Base class for selectively allowing gradient updates for certain parameters.
+    """
+    def __init__(self, block: nn.Module):
+        super().__init__()
+        self.block = block
+
+    def allow_gradient_update_for_parameters(
+        self,
+        prefix: Optional[List[str]] = None,
+        suffix: Optional[List[str]] = None,
+        infix: Optional[List[str]] = None,
+    ):
+        """
+        """
+        for k, v in self.block.named_parameters():
+            if prefix is not None and k.startswith(tuple(prefix)):
+                v.requires_grad = True
+
+            if suffix is not None and k.endswith(tuple(suffix)):
+                v.requires_grad = True
+
+            if infix is not None:
+                for per_infix in infix:
+                    if k.find(per_infix) != -1:
+                        v.requires_grad = True
+
+    def forward(self, x):
+        return x
+
+
+class AttentionSurgery(SelectiveSurgery):
+    """
+    """
+    def __init__(self, block: nn.Module):
+        super().__init__(block=block)
+        # Allow gradient updates for the attention layers in the image encoder.
+        self.allow_gradient_update_for_parameters(prefix=["attn"])
+
+
+class BiasSurgery(SelectiveSurgery):
+    """
+    """
+    def __init__(self, block: nn.Module):
+        super().__init__(block=block)
+        # Allow gradient updates for the bias parameters in the image encoder.
+        self.allow_gradient_update_for_parameters(suffix=["bias"])
+
+
+class LayerNormSurgery(SelectiveSurgery):
+    """
+    """
+    def __init__(self, block: nn.Module):
+        super().__init__(block=block)
+        # Allow gradient updates for the LayerNorm parameters in the image encoder.
+        self.allow_gradient_update_for_parameters(infix=["norm1", "norm2"])
+
+
 class PEFT_Sam(nn.Module):
     """Wraps the Segment Anything model's image encoder to different parameter efficient finetuning methods.
 
@@ -131,6 +180,7 @@ def __init__(
         super().__init__()
 
         assert rank > 0
+        assert issubclass(peft_module, Union[LoRASurgery, FacTSurgery, SelectiveSurgery]), "Invalid PEFT module."
 
         if attention_layers_to_update:
             self.peft_layers = attention_layers_to_update
@@ -149,7 +199,11 @@ def __init__(
             if t_layer_i not in self.peft_layers:
                 continue
 
-            peft_block = self.peft_module(rank=rank, block=blk)
+            if issubclass(self.peft_module, SelectiveSurgery):
+                peft_block = self.peft_module(block=blk)
+            else:
+                peft_block = self.peft_module(rank=rank, block=blk)
+
             self.peft_blocks.append(peft_block)
 
         self.peft_blocks = nn.ModuleList(self.peft_blocks)
diff --git a/micro_sam/training/util.py b/micro_sam/training/util.py
index fb4834c03..2fbe73a02 100644
--- a/micro_sam/training/util.py
+++ b/micro_sam/training/util.py
@@ -59,9 +59,7 @@ def get_trainable_sam_model(
         freeze: Specify parts of the model that should be frozen, namely: image_encoder, prompt_encoder and mask_decoder
             By default nothing is frozen and the full model is updated.
         return_state: Whether to return the full checkpoint state.
-        lora_rank: The rank of the decomposition matrices for updating weights in each attention layer with lora.
-            If None then LoRA is not used.
-        lora_kwargs: Keyword arguments for the PEFT wrapper class.
+        left_kwargs: Keyword arguments for the PEFT wrapper class.
         flexible_load_checkpoint: Whether to adjust mismatching params while loading pretrained checkpoints.
         model_kwargs: Additional keyword arguments for the `util.get_sam_model`.
 
@@ -82,16 +80,11 @@ def get_trainable_sam_model(
 
     # NOTE: This is done exclusive to "get_sam_model" here to use PEFT's layer-specific initialization on top.
     # Whether to use Parameter Efficient Finetuning methods to wrap around Segment Anything.
-    # Overwrites the SAM model by freezing the backbone and allow low rank adaption to attention layers.
+    # Overwrites the SAM model by freezing the backbone and allowing PEFT methods.
     if peft_kwargs and isinstance(peft_kwargs, dict):
         if model_type[:5] == "vit_t":
             raise ValueError("'micro-sam' does not support parameter efficient finetuning for 'mobile-sam'.")
 
-        peft_module = peft_kwargs.get("peft_module")
-        if peft_module is not None:
-            from micro_sam.models.peft_sam import LoRASurgery, FacTSurgery
-            assert peft_module in [LoRASurgery, FacTSurgery], "Invalid PEFT module."
-
         sam = custom_models.peft_sam.PEFT_Sam(sam, **peft_kwargs).sam
 
     # freeze components of the model if freeze was passed
@@ -106,9 +99,9 @@ def get_trainable_sam_model(
 
             # we would want to "freeze" all the components in the model if passed a list of parts
             for l_item in freeze:
-                # in case LoRA is switched on, we cannot freeze the image encoder
-                if (peft_kwargs['rank'] is not None) and (l_item == "image_encoder"):
-                    raise ValueError("You cannot use LoRA & freeze the image encoder at the same time.")
+                # in case PEFT is switched on, we cannot freeze the image encoder
+                if (peft_kwargs and peft_kwargs.get('rank') is not None) and (l_item == "image_encoder"):
+                    raise ValueError("You cannot use PEFT & freeze the image encoder at the same time.")
 
                 if name.startswith(f"{l_item}"):
                     param.requires_grad = False
diff --git a/micro_sam/util.py b/micro_sam/util.py
index a514f3d0a..07aed9a21 100644
--- a/micro_sam/util.py
+++ b/micro_sam/util.py
@@ -308,9 +308,7 @@ def get_sam_model(
             then `model_type` must be given as "vit_b".
         return_sam: Return the sam model object as well as the predictor.
         return_state: Return the unpickled checkpoint state.
-        lora_rank: The rank of the decomposition matrices for updating weights in each attention layer with lora.
-            If None then LoRA is not used.
-        lora_kwargs: Keyword arguments for th PEFT wrapper class.
+        peft_kwargs: Keyword arguments for the PEFT wrapper class.
         flexible_load_checkpoint: Whether to adjust mismatching params while loading pretrained checkpoints.
 
     Returns:
@@ -369,16 +367,11 @@ def get_sam_model(
         sam = sam_model_registry[abbreviated_model_type]()
 
     # Whether to use Parameter Efficient Finetuning methods to wrap around Segment Anything.
-    # Overwrites the SAM model by freezing the backbone and allow low rank adaption to attention layers.
+    # Overwrites the SAM model by freezing the backbone and allowing PEFT.
     if peft_kwargs and isinstance(peft_kwargs, dict):
         if abbreviated_model_type == "vit_t":
             raise ValueError("'micro-sam' does not support parameter efficient finetuning for 'mobile-sam'.")
 
-        peft_module = peft_kwargs.get("peft_module")
-        if peft_module is not None:
-            from .models.peft_sam import LoRASurgery, FacTSurgery
-            assert peft_module in [LoRASurgery, FacTSurgery], "Invalid PEFT module."
-
         sam = custom_models.peft_sam.PEFT_Sam(sam, **peft_kwargs).sam
 
     # In case the model checkpoints have some issues when it is initialized with different parameters than default.
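To see what the new SelectiveSurgery base class does in isolation, a minimal standalone sketch is enough. The block below is a stand-in module, not a real SAM block, and the prefix/suffix/infix values are arbitrary examples; only the SelectiveSurgery import comes from the patch above.

    import torch.nn as nn

    from micro_sam.models.peft_sam import SelectiveSurgery

    # A stand-in block with parameter names loosely resembling a transformer block.
    block = nn.Module()
    block.attn = nn.Linear(8, 8)
    block.norm1 = nn.LayerNorm(8)
    block.mlp = nn.Linear(8, 8)

    # Freeze everything, then selectively re-enable gradients via the new helper.
    for param in block.parameters():
        param.requires_grad = False

    surgery = SelectiveSurgery(block=block)
    surgery.allow_gradient_update_for_parameters(prefix=["attn"], suffix=["bias"], infix=["norm1"])

    print([name for name, param in block.named_parameters() if param.requires_grad])
    # -> e.g. ['attn.weight', 'attn.bias', 'norm1.weight', 'norm1.bias', 'mlp.bias']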
From 26203242e47d9c1bea07987104ddda2291ed5278 Mon Sep 17 00:00:00 2001
From: Anwai Archit <52396323+anwai98@users.noreply.github.com>
Date: Tue, 1 Oct 2024 15:17:29 +0200
Subject: [PATCH 2/6] Update util.py

---
 micro_sam/training/util.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/micro_sam/training/util.py b/micro_sam/training/util.py
index 2fbe73a02..7ecf41cd0 100644
--- a/micro_sam/training/util.py
+++ b/micro_sam/training/util.py
@@ -59,7 +59,7 @@ def get_trainable_sam_model(
         freeze: Specify parts of the model that should be frozen, namely: image_encoder, prompt_encoder and mask_decoder
             By default nothing is frozen and the full model is updated.
         return_state: Whether to return the full checkpoint state.
-        left_kwargs: Keyword arguments for the PEFT wrapper class.
+        peft_kwargs: Keyword arguments for the PEFT wrapper class.
         flexible_load_checkpoint: Whether to adjust mismatching params while loading pretrained checkpoints.
         model_kwargs: Additional keyword arguments for the `util.get_sam_model`.
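With the corrected name, peft_kwargs is the dictionary that get_trainable_sam_model unpacks into PEFT_Sam. A hedged usage sketch, assuming the usual model_type argument and an arbitrary rank; all other arguments are omitted.

    from micro_sam.models.peft_sam import LoRASurgery
    from micro_sam.training.util import get_trainable_sam_model

    # The keys mirror PEFT_Sam's constructor: "rank" and "peft_module".
    peft_kwargs = {"rank": 4, "peft_module": LoRASurgery}

    # Wraps SAM so the image encoder backbone is frozen and LoRA layers are
    # injected into its attention blocks.
    model = get_trainable_sam_model(model_type="vit_b", peft_kwargs=peft_kwargs)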
From 6fd6e5064d765619b7457350fe522b23062f5bee Mon Sep 17 00:00:00 2001
From: Anwai Archit <52396323+anwai98@users.noreply.github.com>
Date: Tue, 1 Oct 2024 15:31:01 +0200
Subject: [PATCH 3/6] Add docstrings

---
 micro_sam/models/peft_sam.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/micro_sam/models/peft_sam.py b/micro_sam/models/peft_sam.py
index 50a182f99..b955c413b 100644
--- a/micro_sam/models/peft_sam.py
+++ b/micro_sam/models/peft_sam.py
@@ -134,7 +134,7 @@ def forward(self, x):
 
 
 class AttentionSurgery(SelectiveSurgery):
-    """
+    """Child class for allowing gradient updates for parameters in attention layers.
     """
     def __init__(self, block: nn.Module):
         super().__init__(block=block)
@@ -143,7 +143,7 @@ def __init__(self, block: nn.Module):
 
 
 class BiasSurgery(SelectiveSurgery):
-    """
+    """Child class for allowing gradient updates for bias parameters.
    """
     def __init__(self, block: nn.Module):
         super().__init__(block=block)
@@ -152,7 +152,7 @@ def __init__(self, block: nn.Module):
 
 
 class LayerNormSurgery(SelectiveSurgery):
-    """
+    """Child class for allowing gradient updates in normalization layers.
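The three child classes documented here differ only in which parameter names they match inside a block (prefix "attn", suffix "bias", infix "norm1"/"norm2"). A sketch that lists the selection per class for the first image-encoder block; loading "vit_b" and freezing the block beforehand are illustrative choices, not part of the patch.

    from micro_sam.util import get_sam_model
    from micro_sam.models.peft_sam import AttentionSurgery, BiasSurgery, LayerNormSurgery

    _, sam = get_sam_model(model_type="vit_b", return_sam=True, device="cpu")
    block = sam.image_encoder.blocks[0]

    for surgery_cls in (AttentionSurgery, BiasSurgery, LayerNormSurgery):
        # Freeze the whole block, then let the surgery re-enable its selection.
        for param in block.parameters():
            param.requires_grad = False
        surgery_cls(block=block)
        selected = [name for name, param in block.named_parameters() if param.requires_grad]
        print(surgery_cls.__name__, "->", selected)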
""" def __init__(self, block: nn.Module): super().__init__(block=block) From 06deb4e4a2fcb7b1ab64b6c7ff5763cebee2bcde Mon Sep 17 00:00:00 2001 From: Anwai Archit <52396323+anwai98@users.noreply.github.com> Date: Tue, 1 Oct 2024 15:35:49 +0200 Subject: [PATCH 4/6] Add docstring for gradient update fn --- micro_sam/models/peft_sam.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/micro_sam/models/peft_sam.py b/micro_sam/models/peft_sam.py index b955c413b..4acfe9f4a 100644 --- a/micro_sam/models/peft_sam.py +++ b/micro_sam/models/peft_sam.py @@ -115,7 +115,12 @@ def allow_gradient_update_for_parameters( suffix: Optional[List[str]] = None, infix: Optional[List[str]] = None, ): - """ + """This function decides the parameter attributes to match for allowing gradient updates. + + Args: + prefix: Matches the part of parameter name in front. + suffix: Matches the part of parameter name at the end. + infix: Matches parts of parameter name occuring in between. """ for k, v in self.block.named_parameters(): if prefix is not None and k.startswith(tuple(prefix)): From 78a02b2d54a95cfb2b20c1a022b165a78e6ff2e2 Mon Sep 17 00:00:00 2001 From: Anwai Archit <52396323+anwai98@users.noreply.github.com> Date: Tue, 1 Oct 2024 15:58:14 +0200 Subject: [PATCH 5/6] Add tests for selective peft methods (#715) --- test/test_models/test_peft_sam.py | 42 +++++++++++++++++++++++++++++++ 1 file changed, 42 insertions(+) diff --git a/test/test_models/test_peft_sam.py b/test/test_models/test_peft_sam.py index 509a67650..4461aa9b1 100644 --- a/test/test_models/test_peft_sam.py +++ b/test/test_models/test_peft_sam.py @@ -36,6 +36,48 @@ def test_fact_sam(self): masks = output[0]["masks"] self.assertEqual(masks.shape, expected_shape) + def test_attention_layer_peft_sam(self): + from micro_sam.models.peft_sam import PEFT_Sam, AttentionSurgery + + _, sam = util.get_sam_model(model_type=self.model_type, return_sam=True, device="cpu") + peft_sam = PEFT_Sam(sam, rank=2, peft_module=AttentionSurgery) + + shape = (3, 1024, 1024) + expected_shape = (1, 3, 1024, 1024) + with torch.no_grad(): + batched_input = [{"image": torch.rand(*shape), "original_size": shape[1:]}] + output = peft_sam(batched_input, multimask_output=True) + masks = output[0]["masks"] + self.assertEqual(masks.shape, expected_shape) + + def test_norm_layer_peft_sam(self): + from micro_sam.models.peft_sam import PEFT_Sam, LayerNormSurgery + + _, sam = util.get_sam_model(model_type=self.model_type, return_sam=True, device="cpu") + peft_sam = PEFT_Sam(sam, rank=2, peft_module=LayerNormSurgery) + + shape = (3, 1024, 1024) + expected_shape = (1, 3, 1024, 1024) + with torch.no_grad(): + batched_input = [{"image": torch.rand(*shape), "original_size": shape[1:]}] + output = peft_sam(batched_input, multimask_output=True) + masks = output[0]["masks"] + self.assertEqual(masks.shape, expected_shape) + + def test_bias_layer_peft_sam(self): + from micro_sam.models.peft_sam import PEFT_Sam, BiasSurgery + + _, sam = util.get_sam_model(model_type=self.model_type, return_sam=True, device="cpu") + peft_sam = PEFT_Sam(sam, rank=2, peft_module=BiasSurgery) + + shape = (3, 1024, 1024) + expected_shape = (1, 3, 1024, 1024) + with torch.no_grad(): + batched_input = [{"image": torch.rand(*shape), "original_size": shape[1:]}] + output = peft_sam(batched_input, multimask_output=True) + masks = output[0]["masks"] + self.assertEqual(masks.shape, expected_shape) + if __name__ == "__main__": unittest.main() From 2940b8716c411ab933d695fc79fcfd87ff82dbbf Mon Sep 17 00:00:00 
From 2940b8716c411ab933d695fc79fcfd87ff82dbbf Mon Sep 17 00:00:00 2001
From: anwai98
Date: Tue, 1 Oct 2024 16:30:59 +0200
Subject: [PATCH 6/6] Add dropout arg to docstring

---
 micro_sam/models/peft_sam.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/micro_sam/models/peft_sam.py b/micro_sam/models/peft_sam.py
index 838e650ae..634a2fa9b 100644
--- a/micro_sam/models/peft_sam.py
+++ b/micro_sam/models/peft_sam.py
@@ -60,6 +60,7 @@ class FacTSurgery(nn.Module):
     Args:
         rank: The rank of the decomposition matrices for updating weights in each attention layer.
         block: The chosen attention blocks for implementing fact.
+        dropout: The dropout rate for dropout layers.
     """
     def __init__(
         self,
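The documented dropout argument only concerns FacTSurgery; it defaults to None, and passing a float enables the dropout layers. A hedged construction sketch with an arbitrary rank and rate, applied directly to one attention block.

    from micro_sam.util import get_sam_model
    from micro_sam.models.peft_sam import FacTSurgery

    _, sam = get_sam_model(model_type="vit_b", return_sam=True, device="cpu")
    block = sam.image_encoder.blocks[0]

    # Signature from patch 1: rank, block and an optional dropout rate (None disables it).
    fact_block = FacTSurgery(rank=4, block=block, dropout=0.1)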