Add selective peft methods #708

Merged · 8 commits · Oct 1, 2024
78 changes: 71 additions & 7 deletions micro_sam/models/peft_sam.py
@@ -23,11 +23,7 @@ class LoRASurgery(nn.Module):
rank: The rank of the decomposition matrices for updating weights in each attention layer.
block: The chosen attention blocks for implementing lora.
"""
def __init__(
self,
rank: int,
block: nn.Module,
):
def __init__(self, rank: int, block: nn.Module):
super().__init__()
self.qkv_proj = block.attn.qkv
self.dim = self.qkv_proj.in_features
@@ -64,8 +60,8 @@ class FacTSurgery(nn.Module):
Args:
rank: The rank of the decomposition matrices for updating weights in each attention layer.
block: The chosen attention blocks for implementing fact.
dropout: The dropout rate for dropout layers.
"""

def __init__(
self,
rank: int,
@@ -110,6 +106,69 @@ def forward(self, x):
return qkv


class SelectiveSurgery(nn.Module):
"""Base class for selectively allowing gradient updates for certain parameters.
"""
def __init__(self, block: nn.Module):
super().__init__()
self.block = block

def allow_gradient_update_for_parameters(
self,
prefix: Optional[List[str]] = None,
suffix: Optional[List[str]] = None,
infix: Optional[List[str]] = None,
):
"""This function decides the parameter attributes to match for allowing gradient updates.

Args:
prefix: Matches the part of parameter name in front.
suffix: Matches the part of parameter name at the end.
infix: Matches parts of the parameter name occurring in between.
"""
for k, v in self.block.named_parameters():
if prefix is not None and k.startswith(tuple(prefix)):
v.requires_grad = True

if suffix is not None and k.endswith(tuple(suffix)):
v.requires_grad = True

if infix is not None:
for per_infix in infix:
if k.find(per_infix) != -1:
v.requires_grad = True

def forward(self, x):
return x


class AttentionSurgery(SelectiveSurgery):
"""Child class for allowing gradient updates for parameters in attention layers.
"""
def __init__(self, block: nn.Module):
super().__init__(block=block)
# Allow gradient updates for the attention layers in the image encoder.
self.allow_gradient_update_for_parameters(prefix=["attn"])


class BiasSurgery(SelectiveSurgery):
"""Child class for allowing gradient updates for bias parameters.
"""
def __init__(self, block: nn.Module):
super().__init__(block=block)
# Allow gradient updates for the bias parameters in the image encoder.
self.allow_gradient_update_for_parameters(suffix=["bias"])


class LayerNormSurgery(SelectiveSurgery):
"""Child class for allowing gradient updates in normalization layers.
"""
def __init__(self, block: nn.Module):
super().__init__(block=block)
# Allow gradient updates for the LayerNorm parameters in the image encoder.
self.allow_gradient_update_for_parameters(infix=["norm1", "norm2"])
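
To make the matching rule above concrete, here is a minimal, self-contained sketch of the same prefix/suffix/infix selection applied to a toy block (the toy module and its parameter names are stand-ins, not the actual SAM image-encoder block, and the three criteria are combined here for illustration, whereas each child class above applies only one):

import torch.nn as nn

# Toy stand-in for a transformer block; the real SAM block has more structure.
class ToyBlock(nn.Module):
    def __init__(self, dim=16):
        super().__init__()
        self.attn = nn.Linear(dim, dim)   # parameter names start with "attn"
        self.mlp = nn.Linear(dim, dim)    # its bias parameter ends with "bias"
        self.norm1 = nn.LayerNorm(dim)    # parameter names contain "norm1" / "norm2"
        self.norm2 = nn.LayerNorm(dim)

block = ToyBlock()
for p in block.parameters():  # freeze everything first, as the PEFT wrapper does for the encoder
    p.requires_grad = False

prefix, suffix, infix = ["attn"], ["bias"], ["norm1", "norm2"]
for name, param in block.named_parameters():
    if name.startswith(tuple(prefix)) or name.endswith(tuple(suffix)) or any(s in name for s in infix):
        param.requires_grad = True

print([n for n, p in block.named_parameters() if p.requires_grad])
# ['attn.weight', 'attn.bias', 'mlp.bias', 'norm1.weight', 'norm1.bias', 'norm2.weight', 'norm2.bias']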


class PEFT_Sam(nn.Module):
"""Wraps the Segment Anything model's image encoder to different parameter efficient finetuning methods.

@@ -130,6 +189,7 @@ def __init__(
super().__init__()

assert rank > 0
assert issubclass(peft_module, Union[LoRASurgery, FacTSurgery, SelectiveSurgery]), "Invalid PEFT module."

if attention_layers_to_update:
self.peft_layers = attention_layers_to_update
@@ -148,7 +208,11 @@ def __init__(
if t_layer_i not in self.peft_layers:
continue

peft_block = self.peft_module(rank=rank, block=blk)
if issubclass(self.peft_module, SelectiveSurgery):
peft_block = self.peft_module(block=blk)
else:
peft_block = self.peft_module(rank=rank, block=blk)

self.peft_blocks.append(peft_block)

self.peft_blocks = nn.ModuleList(self.peft_blocks)
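
For reference, a usage sketch of the dispatch above, following the pattern in the tests further down (the model_type value is an assumption, util.get_sam_model fetches a checkpoint on first use, and rank must still be a positive integer even for the selective modules, which ignore it):

from micro_sam import util
from micro_sam.models.peft_sam import PEFT_Sam, LoRASurgery, BiasSurgery

# Rank-based PEFT: low-rank update matrices are attached to the attention blocks.
_, sam = util.get_sam_model(model_type="vit_b", return_sam=True, device="cpu")
lora_sam = PEFT_Sam(sam, rank=4, peft_module=LoRASurgery)

# Selective PEFT: no new parameters, only the bias terms of the blocks are unfrozen.
_, sam = util.get_sam_model(model_type="vit_b", return_sam=True, device="cpu")
bias_sam = PEFT_Sam(sam, rank=2, peft_module=BiasSurgery)

Per the hunk above, both calls should also accept attention_layers_to_update to restrict the surgery to specific block indices.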
17 changes: 5 additions & 12 deletions micro_sam/training/util.py
@@ -59,9 +59,7 @@ def get_trainable_sam_model(
freeze: Specify parts of the model that should be frozen, namely: image_encoder, prompt_encoder and mask_decoder
By default nothing is frozen and the full model is updated.
return_state: Whether to return the full checkpoint state.
lora_rank: The rank of the decomposition matrices for updating weights in each attention layer with lora.
If None then LoRA is not used.
lora_kwargs: Keyword arguments for the PEFT wrapper class.
peft_kwargs: Keyword arguments for the PEFT wrapper class.
flexible_load_checkpoint: Whether to adjust mismatching params while loading pretrained checkpoints.
model_kwargs: Additional keyword arguments for the `util.get_sam_model`.

@@ -82,16 +80,11 @@

# NOTE: This is done exclusive to "get_sam_model" here to use PEFT's layer-specific initialization on top.
# Whether to use Parameter Efficient Finetuning methods to wrap around Segment Anything.
# Overwrites the SAM model by freezing the backbone and allow low rank adaption to attention layers.
# Overwrites the SAM model by freezing the backbone and allowing PEFT methods.
if peft_kwargs and isinstance(peft_kwargs, dict):
if model_type[:5] == "vit_t":
raise ValueError("'micro-sam' does not support parameter efficient finetuning for 'mobile-sam'.")

peft_module = peft_kwargs.get("peft_module")
if peft_module is not None:
from micro_sam.models.peft_sam import LoRASurgery, FacTSurgery
assert peft_module in [LoRASurgery, FacTSurgery], "Invalid PEFT module."

sam = custom_models.peft_sam.PEFT_Sam(sam, **peft_kwargs).sam

# freeze components of the model if freeze was passed
@@ -106,9 +99,9 @@

# we would want to "freeze" all the components in the model if passed a list of parts
for l_item in freeze:
# in case LoRA is switched on, we cannot freeze the image encoder
if (peft_kwargs['rank'] is not None) and (l_item == "image_encoder"):
raise ValueError("You cannot use LoRA & freeze the image encoder at the same time.")
# in case PEFT is switched on, we cannot freeze the image encoder
if (peft_kwargs and peft_kwargs.get('rank') is not None) and (l_item == "image_encoder"):
raise ValueError("You cannot use PEFT & freeze the image encoder at the same time.")

if name.startswith(f"{l_item}"):
param.requires_grad = False
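
A short usage sketch of the new peft_kwargs path in get_trainable_sam_model (the model_type value and the import location are assumptions based on the file paths above; the dictionary is forwarded unchanged to PEFT_Sam):

from micro_sam.training.util import get_trainable_sam_model
from micro_sam.models.peft_sam import LoRASurgery, LayerNormSurgery

# LoRA on the attention layers of the image encoder.
model = get_trainable_sam_model(
    model_type="vit_b",  # assumption: any SAM variant except "vit_t"
    peft_kwargs={"rank": 4, "peft_module": LoRASurgery},
)

# Selective finetuning of the LayerNorm parameters instead; rank is still required by the wrapper.
model = get_trainable_sam_model(
    model_type="vit_b",
    peft_kwargs={"rank": 1, "peft_module": LayerNormSurgery},
)

Note that, per the check above, combining peft_kwargs with freeze=["image_encoder"] raises a ValueError.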
11 changes: 2 additions & 9 deletions micro_sam/util.py
@@ -308,9 +308,7 @@ def get_sam_model(
then `model_type` must be given as "vit_b".
return_sam: Return the sam model object as well as the predictor.
return_state: Return the unpickled checkpoint state.
lora_rank: The rank of the decomposition matrices for updating weights in each attention layer with lora.
If None then LoRA is not used.
lora_kwargs: Keyword arguments for th PEFT wrapper class.
peft_kwargs: Keyword arguments for the PEFT wrapper class.
flexible_load_checkpoint: Whether to adjust mismatching params while loading pretrained checkpoints.

Returns:
@@ -369,16 +367,11 @@
sam = sam_model_registry[abbreviated_model_type]()

# Whether to use Parameter Efficient Finetuning methods to wrap around Segment Anything.
# Overwrites the SAM model by freezing the backbone and allow low rank adaption to attention layers.
# Overwrites the SAM model by freezing the backbone and allowing PEFT.
if peft_kwargs and isinstance(peft_kwargs, dict):
if abbreviated_model_type == "vit_t":
raise ValueError("'micro-sam' does not support parameter efficient finetuning for 'mobile-sam'.")

peft_module = peft_kwargs.get("peft_module")
if peft_module is not None:
from .models.peft_sam import LoRASurgery, FacTSurgery
assert peft_module in [LoRASurgery, FacTSurgery], "Invalid PEFT module."

sam = custom_models.peft_sam.PEFT_Sam(sam, **peft_kwargs).sam

# In case the model checkpoints have some issues when it is initialized with different parameters than default.
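
The corresponding inference-side call goes through get_sam_model; a sketch, e.g. for loading a model that was finetuned with one of the PEFT wrappers (the model_type value is an assumption, and a finetuned checkpoint would be passed via the function's usual checkpoint argument):

from micro_sam.util import get_sam_model
from micro_sam.models.peft_sam import AttentionSurgery

# Returns a predictor whose underlying SAM model was wrapped with the chosen PEFT method.
predictor = get_sam_model(
    model_type="vit_b",  # assumption; "vit_t" / mobile-sam is rejected for PEFT
    peft_kwargs={"rank": 2, "peft_module": AttentionSurgery},
)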
42 changes: 42 additions & 0 deletions test/test_models/test_peft_sam.py
@@ -36,6 +36,48 @@ def test_fact_sam(self):
masks = output[0]["masks"]
self.assertEqual(masks.shape, expected_shape)

def test_attention_layer_peft_sam(self):
from micro_sam.models.peft_sam import PEFT_Sam, AttentionSurgery

_, sam = util.get_sam_model(model_type=self.model_type, return_sam=True, device="cpu")
peft_sam = PEFT_Sam(sam, rank=2, peft_module=AttentionSurgery)

shape = (3, 1024, 1024)
expected_shape = (1, 3, 1024, 1024)
with torch.no_grad():
batched_input = [{"image": torch.rand(*shape), "original_size": shape[1:]}]
output = peft_sam(batched_input, multimask_output=True)
masks = output[0]["masks"]
self.assertEqual(masks.shape, expected_shape)

def test_norm_layer_peft_sam(self):
from micro_sam.models.peft_sam import PEFT_Sam, LayerNormSurgery

_, sam = util.get_sam_model(model_type=self.model_type, return_sam=True, device="cpu")
peft_sam = PEFT_Sam(sam, rank=2, peft_module=LayerNormSurgery)

shape = (3, 1024, 1024)
expected_shape = (1, 3, 1024, 1024)
with torch.no_grad():
batched_input = [{"image": torch.rand(*shape), "original_size": shape[1:]}]
output = peft_sam(batched_input, multimask_output=True)
masks = output[0]["masks"]
self.assertEqual(masks.shape, expected_shape)

def test_bias_layer_peft_sam(self):
from micro_sam.models.peft_sam import PEFT_Sam, BiasSurgery

_, sam = util.get_sam_model(model_type=self.model_type, return_sam=True, device="cpu")
peft_sam = PEFT_Sam(sam, rank=2, peft_module=BiasSurgery)

shape = (3, 1024, 1024)
expected_shape = (1, 3, 1024, 1024)
with torch.no_grad():
batched_input = [{"image": torch.rand(*shape), "original_size": shape[1:]}]
output = peft_sam(batched_input, multimask_output=True)
masks = output[0]["masks"]
self.assertEqual(masks.shape, expected_shape)


if __name__ == "__main__":
unittest.main()
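
The new tests validate output shapes; a complementary sketch for checking what a selective module actually unfreezes is to inspect the requires_grad flags on the wrapped encoder (the model_type value is an assumption; with BiasSurgery the trainable encoder parameters are expected to be bias terms only, assuming the wrapper freezes the full encoder before applying the surgery):

from micro_sam import util
from micro_sam.models.peft_sam import PEFT_Sam, BiasSurgery

_, sam = util.get_sam_model(model_type="vit_b", return_sam=True, device="cpu")
peft_sam = PEFT_Sam(sam, rank=2, peft_module=BiasSurgery)

trainable = [
    name for name, param in peft_sam.sam.image_encoder.named_parameters()
    if param.requires_grad
]
print(trainable)  # expected: only parameter names ending in "bias"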