Skip to content

Commit

Permalink
Add tests for selective peft methods
Browse files Browse the repository at this point in the history
  • Loading branch information
anwai98 committed Oct 1, 2024
1 parent 6bc6b0d commit 73f77a5
Showing 1 changed file with 42 additions and 0 deletions.
42 changes: 42 additions & 0 deletions test/test_models/test_peft_sam.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,48 @@ def test_fact_sam(self):
masks = output[0]["masks"]
self.assertEqual(masks.shape, expected_shape)

def test_attention_layer_peft_sam(self):
from micro_sam.models.peft_sam import PEFT_Sam, AttentionSurgery

_, sam = util.get_sam_model(model_type=self.model_type, return_sam=True, device="cpu")
peft_sam = PEFT_Sam(sam, rank=2, peft_module=AttentionSurgery)

shape = (3, 1024, 1024)
expected_shape = (1, 3, 1024, 1024)
with torch.no_grad():
batched_input = [{"image": torch.rand(*shape), "original_size": shape[1:]}]
output = peft_sam(batched_input, multimask_output=True)
masks = output[0]["masks"]
self.assertEqual(masks.shape, expected_shape)

def test_norm_layer_peft_sam(self):
from micro_sam.models.peft_sam import PEFT_Sam, LayerNormSurgery

_, sam = util.get_sam_model(model_type=self.model_type, return_sam=True, device="cpu")
peft_sam = PEFT_Sam(sam, rank=2, peft_module=LayerNormSurgery)

shape = (3, 1024, 1024)
expected_shape = (1, 3, 1024, 1024)
with torch.no_grad():
batched_input = [{"image": torch.rand(*shape), "original_size": shape[1:]}]
output = peft_sam(batched_input, multimask_output=True)
masks = output[0]["masks"]
self.assertEqual(masks.shape, expected_shape)

def test_bias_layer_peft_sam(self):
from micro_sam.models.peft_sam import PEFT_Sam, BiasSurgery

_, sam = util.get_sam_model(model_type=self.model_type, return_sam=True, device="cpu")
peft_sam = PEFT_Sam(sam, rank=2, peft_module=BiasSurgery)

shape = (3, 1024, 1024)
expected_shape = (1, 3, 1024, 1024)
with torch.no_grad():
batched_input = [{"image": torch.rand(*shape), "original_size": shape[1:]}]
output = peft_sam(batched_input, multimask_output=True)
masks = output[0]["masks"]
self.assertEqual(masks.shape, expected_shape)


if __name__ == "__main__":
unittest.main()

0 comments on commit 73f77a5

Please sign in to comment.