Add selective peft methods (#708)
Add selective PEFT methods (e.g. attention FT, bias FT, LayerNorm FT)
anwai98 authored Oct 1, 2024
1 parent e1bf659 commit a8af9c4
Showing 4 changed files with 120 additions and 28 deletions.
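
For context, the sketch below shows how the new selective modules plug into PEFT_Sam. It mirrors the added tests in test/test_models/test_peft_sam.py and is an illustrative usage sketch rather than part of the commit; "vit_b" is chosen here only as an example model type.

# Illustrative usage sketch mirroring the added tests (not part of this commit).
from micro_sam import util
from micro_sam.models.peft_sam import PEFT_Sam, AttentionSurgery

# Load a plain SAM model, then wrap it so that, within the image encoder,
# only the selected parameters (here: the attention layers) remain trainable.
_, sam = util.get_sam_model(model_type="vit_b", return_sam=True, device="cpu")
peft_sam = PEFT_Sam(sam, rank=2, peft_module=AttentionSurgery)

# BiasSurgery and LayerNormSurgery are drop-in alternatives, e.g.:
# peft_sam = PEFT_Sam(sam, rank=2, peft_module=BiasSurgery)

Note that PEFT_Sam still asserts rank > 0, so the tests below pass rank=2 even for the selective modules, which do not use the rank themselves.
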
78 changes: 71 additions & 7 deletions micro_sam/models/peft_sam.py
@@ -23,11 +23,7 @@ class LoRASurgery(nn.Module):
rank: The rank of the decomposition matrices for updating weights in each attention layer.
block: The chosen attention blocks for implementing LoRA.
"""
def __init__(
self,
rank: int,
block: nn.Module,
):
def __init__(self, rank: int, block: nn.Module):
super().__init__()
self.qkv_proj = block.attn.qkv
self.dim = self.qkv_proj.in_features
@@ -64,8 +60,8 @@ class FacTSurgery(nn.Module):
Args:
rank: The rank of the decomposition matrices for updating weights in each attention layer.
block: The chosen attention blocks for implementing FacT.
dropout: The dropout rate for dropout layers.
"""

def __init__(
self,
rank: int,
@@ -110,6 +106,69 @@ def forward(self, x):
return qkv


class SelectiveSurgery(nn.Module):
"""Base class for selectively allowing gradient updates for certain parameters.
"""
def __init__(self, block: nn.Module):
super().__init__()
self.block = block

def allow_gradient_update_for_parameters(
self,
prefix: Optional[List[str]] = None,
suffix: Optional[List[str]] = None,
infix: Optional[List[str]] = None,
):
"""This function decides the parameter attributes to match for allowing gradient updates.
Args:
prefix: Matches the part of parameter name in front.
suffix: Matches the part of parameter name at the end.
infix: Matches parts of parameter name occuring in between.
"""
for k, v in self.block.named_parameters():
if prefix is not None and k.startswith(tuple(prefix)):
v.requires_grad = True

if suffix is not None and k.endswith(tuple(suffix)):
v.requires_grad = True

if infix is not None:
for per_infix in infix:
if k.find(per_infix) != -1:
v.requires_grad = True

def forward(self, x):
return x


class AttentionSurgery(SelectiveSurgery):
"""Child class for allowing gradient updates for parameters in attention layers.
"""
def __init__(self, block: nn.Module):
super().__init__(block=block)
# Allow gradient updates for the attention layers in the image encoder.
self.allow_gradient_update_for_parameters(prefix=["attn"])


class BiasSurgery(SelectiveSurgery):
"""Child class for allowing gradient updates for bias parameters.
"""
def __init__(self, block: nn.Module):
super().__init__(block=block)
# Allow gradient updates for the bias parameters in the image encoder.
self.allow_gradient_update_for_parameters(suffix=["bias"])


class LayerNormSurgery(SelectiveSurgery):
"""Child class for allowing gradient updates in normalization layers.
"""
def __init__(self, block: nn.Module):
super().__init__(block=block)
# Allow gradient updates for the LayerNorm parameters in the image encoder.
self.allow_gradient_update_for_parameters(infix=["norm1", "norm2"])


class PEFT_Sam(nn.Module):
"""Wraps the Segment Anything model's image encoder to different parameter efficient finetuning methods.
@@ -130,6 +189,7 @@ def __init__(
super().__init__()

assert rank > 0
assert issubclass(peft_module, Union[LoRASurgery, FacTSurgery, SelectiveSurgery]), "Invalid PEFT module."

if attention_layers_to_update:
self.peft_layers = attention_layers_to_update
@@ -148,7 +208,11 @@ def __init__(
if t_layer_i not in self.peft_layers:
continue

peft_block = self.peft_module(rank=rank, block=blk)
if issubclass(self.peft_module, SelectiveSurgery):
peft_block = self.peft_module(block=blk)
else:
peft_block = self.peft_module(rank=rank, block=blk)

self.peft_blocks.append(peft_block)

self.peft_blocks = nn.ModuleList(self.peft_blocks)
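
To make the prefix/suffix/infix matching in SelectiveSurgery.allow_gradient_update_for_parameters concrete, here is a small standalone sketch; ToyBlock and its parameter names are made up for illustration, and only BiasSurgery from the diff above is real.

# Standalone sketch of the name matching performed by the selective surgeries.
# ToyBlock is a hypothetical stand-in for a ViT block.
import torch.nn as nn

from micro_sam.models.peft_sam import BiasSurgery


class ToyBlock(nn.Module):
    def __init__(self, dim=8):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)   # would match LayerNormSurgery's infix "norm1"
        self.attn = nn.Linear(dim, dim)  # would match AttentionSurgery's prefix "attn"
        self.norm2 = nn.LayerNorm(dim)   # would match LayerNormSurgery's infix "norm2"


block = ToyBlock()

# Freeze everything first, as PEFT_Sam does for the image encoder.
for param in block.parameters():
    param.requires_grad = False

# BiasSurgery re-enables every parameter whose name ends with "bias".
BiasSurgery(block=block)

for name, param in block.named_parameters():
    print(name, param.requires_grad)  # only the *.bias parameters are trainable again

Inside PEFT_Sam, the same mechanism is applied once per selected transformer block of the frozen image encoder.
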
17 changes: 5 additions & 12 deletions micro_sam/training/util.py
@@ -59,9 +59,7 @@ def get_trainable_sam_model(
freeze: Specify parts of the model that should be frozen, namely: image_encoder, prompt_encoder and mask_decoder.
By default nothing is frozen and the full model is updated.
return_state: Whether to return the full checkpoint state.
lora_rank: The rank of the decomposition matrices for updating weights in each attention layer with lora.
If None then LoRA is not used.
lora_kwargs: Keyword arguments for the PEFT wrapper class.
peft_kwargs: Keyword arguments for the PEFT wrapper class.
flexible_load_checkpoint: Whether to adjust mismatching params while loading pretrained checkpoints.
model_kwargs: Additional keyword arguments for the `util.get_sam_model`.
@@ -82,16 +80,11 @@

# NOTE: This is done exclusive to "get_sam_model" here to use PEFT's layer-specific initialization on top.
# Whether to use Parameter Efficient Finetuning methods to wrap around Segment Anything.
# Overwrites the SAM model by freezing the backbone and allow low rank adaption to attention layers.
# Overwrites the SAM model by freezing the backbone and allowing PEFT methods.
if peft_kwargs and isinstance(peft_kwargs, dict):
if model_type[:5] == "vit_t":
raise ValueError("'micro-sam' does not support parameter efficient finetuning for 'mobile-sam'.")

peft_module = peft_kwargs.get("peft_module")
if peft_module is not None:
from micro_sam.models.peft_sam import LoRASurgery, FacTSurgery
assert peft_module in [LoRASurgery, FacTSurgery], "Invalid PEFT module."

sam = custom_models.peft_sam.PEFT_Sam(sam, **peft_kwargs).sam

# freeze components of the model if freeze was passed
@@ -106,9 +99,9 @@

# we would want to "freeze" all the components in the model if passed a list of parts
for l_item in freeze:
# in case LoRA is switched on, we cannot freeze the image encoder
if (peft_kwargs['rank'] is not None) and (l_item == "image_encoder"):
raise ValueError("You cannot use LoRA & freeze the image encoder at the same time.")
# in case PEFT is switched on, we cannot freeze the image encoder
if (peft_kwargs and peft_kwargs.get('rank') is not None) and (l_item == "image_encoder"):
raise ValueError("You cannot use PEFT & freeze the image encoder at the same time.")

if name.startswith(f"{l_item}"):
param.requires_grad = False
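
A quick sketch of how the reworked peft_kwargs argument of get_trainable_sam_model can be passed from the training side; only the keys visible in the diff ("rank" and "peft_module") are assumed, and "vit_b" is an example model type.

# Training-side sketch with a selective PEFT module (keyword names taken from the diff above).
from micro_sam.models.peft_sam import LayerNormSurgery
from micro_sam.training.util import get_trainable_sam_model

model = get_trainable_sam_model(
    model_type="vit_b",
    peft_kwargs={"rank": 2, "peft_module": LayerNormSurgery},
)
# Within the image encoder only the parameters selected by the PEFT module stay
# trainable; other components can still be frozen via the `freeze` argument,
# except the image encoder itself (see the ValueError above).
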
11 changes: 2 additions & 9 deletions micro_sam/util.py
@@ -308,9 +308,7 @@ def get_sam_model(
then `model_type` must be given as "vit_b".
return_sam: Return the sam model object as well as the predictor.
return_state: Return the unpickled checkpoint state.
lora_rank: The rank of the decomposition matrices for updating weights in each attention layer with lora.
If None then LoRA is not used.
lora_kwargs: Keyword arguments for th PEFT wrapper class.
peft_kwargs: Keyword arguments for the PEFT wrapper class.
flexible_load_checkpoint: Whether to adjust mismatching params while loading pretrained checkpoints.
Returns:
@@ -369,16 +367,11 @@
sam = sam_model_registry[abbreviated_model_type]()

# Whether to use Parameter Efficient Finetuning methods to wrap around Segment Anything.
# Overwrites the SAM model by freezing the backbone and allow low rank adaption to attention layers.
# Overwrites the SAM model by freezing the backbone and allowing PEFT.
if peft_kwargs and isinstance(peft_kwargs, dict):
if abbreviated_model_type == "vit_t":
raise ValueError("'micro-sam' does not support parameter efficient finetuning for 'mobile-sam'.")

peft_module = peft_kwargs.get("peft_module")
if peft_module is not None:
from .models.peft_sam import LoRASurgery, FacTSurgery
assert peft_module in [LoRASurgery, FacTSurgery], "Invalid PEFT module."

sam = custom_models.peft_sam.PEFT_Sam(sam, **peft_kwargs).sam

# In case the model checkpoint has issues when it is initialized with parameters different from the defaults.
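
Similarly, a minimal sketch for the predictor-side path through micro_sam.util.get_sam_model; again only the peft_kwargs keys shown in the diff are assumed.

# Predictor-side sketch (micro_sam/util.py); illustrative only.
from micro_sam.models.peft_sam import BiasSurgery
from micro_sam.util import get_sam_model

predictor = get_sam_model(
    model_type="vit_b",
    peft_kwargs={"rank": 2, "peft_module": BiasSurgery},
)
# Passing peft_kwargs wraps the underlying SAM model with PEFT_Sam, mirroring
# the training-side path above; 'vit_t' / mobile-sam is rejected with a ValueError.
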
42 changes: 42 additions & 0 deletions test/test_models/test_peft_sam.py
@@ -36,6 +36,48 @@ def test_fact_sam(self):
masks = output[0]["masks"]
self.assertEqual(masks.shape, expected_shape)

def test_attention_layer_peft_sam(self):
from micro_sam.models.peft_sam import PEFT_Sam, AttentionSurgery

_, sam = util.get_sam_model(model_type=self.model_type, return_sam=True, device="cpu")
peft_sam = PEFT_Sam(sam, rank=2, peft_module=AttentionSurgery)

shape = (3, 1024, 1024)
expected_shape = (1, 3, 1024, 1024)
with torch.no_grad():
batched_input = [{"image": torch.rand(*shape), "original_size": shape[1:]}]
output = peft_sam(batched_input, multimask_output=True)
masks = output[0]["masks"]
self.assertEqual(masks.shape, expected_shape)

def test_norm_layer_peft_sam(self):
from micro_sam.models.peft_sam import PEFT_Sam, LayerNormSurgery

_, sam = util.get_sam_model(model_type=self.model_type, return_sam=True, device="cpu")
peft_sam = PEFT_Sam(sam, rank=2, peft_module=LayerNormSurgery)

shape = (3, 1024, 1024)
expected_shape = (1, 3, 1024, 1024)
with torch.no_grad():
batched_input = [{"image": torch.rand(*shape), "original_size": shape[1:]}]
output = peft_sam(batched_input, multimask_output=True)
masks = output[0]["masks"]
self.assertEqual(masks.shape, expected_shape)

def test_bias_layer_peft_sam(self):
from micro_sam.models.peft_sam import PEFT_Sam, BiasSurgery

_, sam = util.get_sam_model(model_type=self.model_type, return_sam=True, device="cpu")
peft_sam = PEFT_Sam(sam, rank=2, peft_module=BiasSurgery)

shape = (3, 1024, 1024)
expected_shape = (1, 3, 1024, 1024)
with torch.no_grad():
batched_input = [{"image": torch.rand(*shape), "original_size": shape[1:]}]
output = peft_sam(batched_input, multimask_output=True)
masks = output[0]["masks"]
self.assertEqual(masks.shape, expected_shape)


if __name__ == "__main__":
unittest.main()
