Skip to content

Commit

Permalink
added implementation of ssf peft method
Browse files Browse the repository at this point in the history
  • Loading branch information
caroteu committed Oct 8, 2024
1 parent cda4f66 commit 8685f75
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 6 deletions.
54 changes: 48 additions & 6 deletions micro_sam/models/peft_sam.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import math
from typing import List, Union, Optional

import torch
import torch.nn as nn

from segment_anything.modeling import Sam
Expand Down Expand Up @@ -123,7 +124,7 @@ def allow_gradient_update_for_parameters(
Args:
prefix: Matches the part of parameter name in front.
suffix: Matches the part of parameter name at the end.
infix: Matches parts of parameter name occuring in between.
infix: Matches parts of parameter name occuring in between.
"""
for k, v in self.block.named_parameters():
if prefix is not None and k.startswith(tuple(prefix)):
Expand Down Expand Up @@ -168,6 +169,28 @@ def __init__(self, block: nn.Module):
self.allow_gradient_update_for_parameters(infix=["norm1", "norm2"])


class SSFSurgery(nn.Module):

def __init__(self, layer, dim):

super().__init__()
self.layer = layer
self.scale = nn.Parameter(torch.normal(mean=1.0, std=0.2, size=(dim,)))
self.shift = nn.Parameter(torch.normal(mean=0.0, std=0.2, size=(dim,)))
layer = self

def forward(self, x):
x = self.layer(x)

assert self.scale.shape == self.shift.shape
if x.shape[-1] == self.ssf_scale.shape[0]:
return x * self.scale + self.shift
elif x.shape[1] == self.scale.shape[0]:
return x * self.scale.view(1, -1, 1, 1) + self.shift.view(1, -1, 1, 1)
else:
raise ValueError('the input tensor shape does not match the shape of the scale factor.')


class PEFT_Sam(nn.Module):
"""Wraps the Segment Anything model's image encoder to different parameter efficient finetuning methods.
Expand All @@ -189,7 +212,8 @@ def __init__(
super().__init__()

assert rank > 0
assert issubclass(peft_module, Union[LoRASurgery, FacTSurgery, SelectiveSurgery]), "Invalid PEFT module."
assert issubclass(peft_module, Union[LoRASurgery, FacTSurgery, SelectiveSurgery, SSFSurgery]), (
"Invalid PEFT module")

if attention_layers_to_update:
self.peft_layers = attention_layers_to_update
Expand All @@ -203,21 +227,39 @@ def __init__(
for param in model.image_encoder.parameters():
param.requires_grad = False

# if peft method is SSF, add SSF to the embedding layers
if issubclass(self.peft_module, SSFSurgery):
self.peft_blocks.append(self.peft_module(model.image_encoder.patch_embed.proj,
model.image_encoder.patch_embed.proj.out_channels))

for t_layer_i, blk in enumerate(model.image_encoder.blocks):
# If we only want specific layers with PEFT instead of all
if t_layer_i not in self.peft_layers:
continue

if issubclass(self.peft_module, SelectiveSurgery):
peft_block = self.peft_module(block=blk)
self.peft_blocks.append(self.peft_module(block=blk))
elif issubclass(self.peft_module, SSFSurgery):
self.peft_blocks.extend(self.add_scale_shift(blk))
else:
peft_block = self.peft_module(rank=rank, block=blk, **module_kwargs)

self.peft_blocks.append(peft_block)
self.peft_blocks.append(self.peft_module(rank=rank, block=blk, **module_kwargs))

self.peft_blocks = nn.ModuleList(self.peft_blocks)

self.sam = model

def add_scale_shift(self, blk):
"""Add the scale an shift surgery after every operation (qkv, projection, mlp, norm)"""
peft_blocks = []

peft_blocks.append(SSFSurgery(blk.attn.qkv, blk.attn.qkv.in_features))
peft_blocks.append(SSFSurgery(blk.attn.proj, blk.attn.proj.in_features))
peft_blocks.append(SSFSurgery(blk.mlp.lin1, blk.mlp.lin1.in_features))
peft_blocks.append(SSFSurgery(blk.mlp.lin2, blk.mlp.lin2.in_features))
peft_blocks.append(SSFSurgery(blk.norm1, blk.norm1.normalized_shape[0]))
peft_blocks.append(SSFSurgery(blk.norm2, blk.norm2.normalized_shape[0]))

return peft_blocks

def forward(self, batched_input, multimask_output):
return self.sam(batched_input, multimask_output)
14 changes: 14 additions & 0 deletions test/test_models/test_peft_sam.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,20 @@ def test_bias_layer_peft_sam(self):
masks = output[0]["masks"]
self.assertEqual(masks.shape, expected_shape)

def test_ssf_peft_sam(self):
from micro_sam.models.peft_sam import PEFT_Sam, SSFSurgery

_, sam = util.get_sam_model(model_type=self.model_type, return_sam=True, device="cpu")
peft_sam = PEFT_Sam(sam, rank=2, peft_module=SSFSurgery)

shape = (3, 1024, 1024)
expected_shape = (1, 3, 1024, 1024)
with torch.no_grad():
batched_input = [{"image": torch.rand(*shape), "original_size": shape[1:]}]
output = peft_sam(batched_input, multimask_output=True)
masks = output[0]["masks"]
self.assertEqual(masks.shape, expected_shape)


if __name__ == "__main__":
unittest.main()

0 comments on commit 8685f75

Please sign in to comment.