From 8685f753a63691d50fefa226b62424a1f3ee8000 Mon Sep 17 00:00:00 2001
From: Carolin
Date: Tue, 8 Oct 2024 14:25:00 +0200
Subject: [PATCH] added implementation of ssf peft method

---
 micro_sam/models/peft_sam.py      | 54 +++++++++++++++++++++++++++----
 test/test_models/test_peft_sam.py | 14 ++++++++
 2 files changed, 62 insertions(+), 6 deletions(-)

diff --git a/micro_sam/models/peft_sam.py b/micro_sam/models/peft_sam.py
index febbccf6b..3d03fe162 100644
--- a/micro_sam/models/peft_sam.py
+++ b/micro_sam/models/peft_sam.py
@@ -1,6 +1,7 @@
 import math
 from typing import List, Union, Optional
 
+import torch
 import torch.nn as nn
 
 from segment_anything.modeling import Sam
@@ -123,7 +124,7 @@ def allow_gradient_update_for_parameters(
         Args:
             prefix: Matches the part of parameter name in front.
             suffix: Matches the part of parameter name at the end.
-            infix: Matches parts of parameter name occuring in between.
+            infix: Matches parts of parameter name occurring in between.
         """
         for k, v in self.block.named_parameters():
             if prefix is not None and k.startswith(tuple(prefix)):
@@ -168,6 +169,28 @@ def __init__(self, block: nn.Module):
         self.allow_gradient_update_for_parameters(infix=["norm1", "norm2"])
 
 
+class SSFSurgery(nn.Module):
+
+    def __init__(self, layer, dim):
+
+        super().__init__()
+        self.layer = layer
+        self.scale = nn.Parameter(torch.normal(mean=1.0, std=0.2, size=(dim,)))
+        self.shift = nn.Parameter(torch.normal(mean=0.0, std=0.2, size=(dim,)))
+        layer = self
+
+    def forward(self, x):
+        x = self.layer(x)
+
+        assert self.scale.shape == self.shift.shape
+        if x.shape[-1] == self.scale.shape[0]:
+            return x * self.scale + self.shift
+        elif x.shape[1] == self.scale.shape[0]:
+            return x * self.scale.view(1, -1, 1, 1) + self.shift.view(1, -1, 1, 1)
+        else:
+            raise ValueError("The input tensor shape does not match the shape of the scale factor.")
+
+
 class PEFT_Sam(nn.Module):
     """Wraps the Segment Anything model's image encoder to different parameter efficient finetuning methods.
 
@@ -189,7 +212,8 @@ def __init__(
 
         super().__init__()
         assert rank > 0
-        assert issubclass(peft_module, Union[LoRASurgery, FacTSurgery, SelectiveSurgery]), "Invalid PEFT module."
+        assert issubclass(peft_module, Union[LoRASurgery, FacTSurgery, SelectiveSurgery, SSFSurgery]), (
+            "Invalid PEFT module")
 
         if attention_layers_to_update:
             self.peft_layers = attention_layers_to_update
@@ -203,21 +227,39 @@ def __init__(
         for param in model.image_encoder.parameters():
             param.requires_grad = False
 
+        # If the PEFT method is SSF, also add scale-shift to the patch embedding layer.
+        if issubclass(self.peft_module, SSFSurgery):
+            self.peft_blocks.append(self.peft_module(model.image_encoder.patch_embed.proj,
+                                                     model.image_encoder.patch_embed.proj.out_channels))
+
         for t_layer_i, blk in enumerate(model.image_encoder.blocks):
             # If we only want specific layers with PEFT instead of all
             if t_layer_i not in self.peft_layers:
                 continue
 
             if issubclass(self.peft_module, SelectiveSurgery):
-                peft_block = self.peft_module(block=blk)
+                self.peft_blocks.append(self.peft_module(block=blk))
+            elif issubclass(self.peft_module, SSFSurgery):
+                self.peft_blocks.extend(self.add_scale_shift(blk))
             else:
-                peft_block = self.peft_module(rank=rank, block=blk, **module_kwargs)
-
-            self.peft_blocks.append(peft_block)
+                self.peft_blocks.append(self.peft_module(rank=rank, block=blk, **module_kwargs))
 
         self.peft_blocks = nn.ModuleList(self.peft_blocks)
 
         self.sam = model
 
+    def add_scale_shift(self, blk):
+        """Add the scale and shift surgery after every operation (qkv, projection, mlp, norm)."""
+        peft_blocks = []
+
+        peft_blocks.append(SSFSurgery(blk.attn.qkv, blk.attn.qkv.out_features))
+        peft_blocks.append(SSFSurgery(blk.attn.proj, blk.attn.proj.out_features))
+        peft_blocks.append(SSFSurgery(blk.mlp.lin1, blk.mlp.lin1.out_features))
+        peft_blocks.append(SSFSurgery(blk.mlp.lin2, blk.mlp.lin2.out_features))
+        peft_blocks.append(SSFSurgery(blk.norm1, blk.norm1.normalized_shape[0]))
+        peft_blocks.append(SSFSurgery(blk.norm2, blk.norm2.normalized_shape[0]))
+
+        return peft_blocks
+
     def forward(self, batched_input, multimask_output):
         return self.sam(batched_input, multimask_output)
diff --git a/test/test_models/test_peft_sam.py b/test/test_models/test_peft_sam.py
index 4461aa9b1..4e62fb733 100644
--- a/test/test_models/test_peft_sam.py
+++ b/test/test_models/test_peft_sam.py
@@ -78,6 +78,20 @@ def test_bias_layer_peft_sam(self):
         masks = output[0]["masks"]
         self.assertEqual(masks.shape, expected_shape)
 
+    def test_ssf_peft_sam(self):
+        from micro_sam.models.peft_sam import PEFT_Sam, SSFSurgery
+
+        _, sam = util.get_sam_model(model_type=self.model_type, return_sam=True, device="cpu")
+        peft_sam = PEFT_Sam(sam, rank=2, peft_module=SSFSurgery)
+
+        shape = (3, 1024, 1024)
+        expected_shape = (1, 3, 1024, 1024)
+        with torch.no_grad():
+            batched_input = [{"image": torch.rand(*shape), "original_size": shape[1:]}]
+            output = peft_sam(batched_input, multimask_output=True)
+            masks = output[0]["masks"]
+            self.assertEqual(masks.shape, expected_shape)
+
 
 if __name__ == "__main__":
     unittest.main()
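Supplementary illustration (not part of the patch): the snippet below is a minimal, standalone sketch of the scale-shift forward logic that the SSFSurgery module above applies, showing the two broadcasting cases it distinguishes (channels-last outputs of linear/norm layers vs. channels-first outputs of the patch embedding convolution). The names ScaleShiftSketch, linear_ssf and conv_ssf are illustrative only and do not exist in micro_sam.

    import torch
    import torch.nn as nn


    class ScaleShiftSketch(nn.Module):
        """Applies a learnable elementwise scale and shift to the output of a wrapped layer."""

        def __init__(self, layer, dim):
            super().__init__()
            self.layer = layer
            # Scale is initialised around 1 and shift around 0, as in the patch above.
            self.scale = nn.Parameter(torch.normal(mean=1.0, std=0.2, size=(dim,)))
            self.shift = nn.Parameter(torch.normal(mean=0.0, std=0.2, size=(dim,)))

        def forward(self, x):
            x = self.layer(x)
            if x.shape[-1] == self.scale.shape[0]:
                # Channels-last outputs, e.g. nn.Linear or nn.LayerNorm.
                return x * self.scale + self.shift
            elif x.shape[1] == self.scale.shape[0]:
                # Channels-first outputs, e.g. the patch embedding convolution (B, C, H, W).
                return x * self.scale.view(1, -1, 1, 1) + self.shift.view(1, -1, 1, 1)
            raise ValueError("The input tensor shape does not match the shape of the scale factor.")


    # Channels-last case: scale/shift sized to the layer's output features.
    linear_ssf = ScaleShiftSketch(nn.Linear(4, 8), dim=8)
    assert linear_ssf(torch.rand(2, 4)).shape == (2, 8)

    # Channels-first case: scale/shift sized to the conv's output channels.
    conv_ssf = ScaleShiftSketch(nn.Conv2d(3, 8, kernel_size=3, padding=1), dim=8)
    assert conv_ssf(torch.rand(1, 3, 16, 16)).shape == (1, 8, 16, 16)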