From 78a02b2d54a95cfb2b20c1a022b165a78e6ff2e2 Mon Sep 17 00:00:00 2001
From: Anwai Archit <52396323+anwai98@users.noreply.github.com>
Date: Tue, 1 Oct 2024 15:58:14 +0200
Subject: [PATCH] Add tests for selective peft methods (#715)

---
 test/test_models/test_peft_sam.py | 42 +++++++++++++++++++++++++++++++
 1 file changed, 42 insertions(+)

diff --git a/test/test_models/test_peft_sam.py b/test/test_models/test_peft_sam.py
index 509a6765..4461aa9b 100644
--- a/test/test_models/test_peft_sam.py
+++ b/test/test_models/test_peft_sam.py
@@ -36,6 +36,48 @@ def test_fact_sam(self):
             masks = output[0]["masks"]
             self.assertEqual(masks.shape, expected_shape)
 
+    def test_attention_layer_peft_sam(self):
+        from micro_sam.models.peft_sam import PEFT_Sam, AttentionSurgery
+
+        _, sam = util.get_sam_model(model_type=self.model_type, return_sam=True, device="cpu")
+        peft_sam = PEFT_Sam(sam, rank=2, peft_module=AttentionSurgery)
+
+        shape = (3, 1024, 1024)
+        expected_shape = (1, 3, 1024, 1024)
+        with torch.no_grad():
+            batched_input = [{"image": torch.rand(*shape), "original_size": shape[1:]}]
+            output = peft_sam(batched_input, multimask_output=True)
+            masks = output[0]["masks"]
+            self.assertEqual(masks.shape, expected_shape)
+
+    def test_norm_layer_peft_sam(self):
+        from micro_sam.models.peft_sam import PEFT_Sam, LayerNormSurgery
+
+        _, sam = util.get_sam_model(model_type=self.model_type, return_sam=True, device="cpu")
+        peft_sam = PEFT_Sam(sam, rank=2, peft_module=LayerNormSurgery)
+
+        shape = (3, 1024, 1024)
+        expected_shape = (1, 3, 1024, 1024)
+        with torch.no_grad():
+            batched_input = [{"image": torch.rand(*shape), "original_size": shape[1:]}]
+            output = peft_sam(batched_input, multimask_output=True)
+            masks = output[0]["masks"]
+            self.assertEqual(masks.shape, expected_shape)
+
+    def test_bias_layer_peft_sam(self):
+        from micro_sam.models.peft_sam import PEFT_Sam, BiasSurgery
+
+        _, sam = util.get_sam_model(model_type=self.model_type, return_sam=True, device="cpu")
+        peft_sam = PEFT_Sam(sam, rank=2, peft_module=BiasSurgery)
+
+        shape = (3, 1024, 1024)
+        expected_shape = (1, 3, 1024, 1024)
+        with torch.no_grad():
+            batched_input = [{"image": torch.rand(*shape), "original_size": shape[1:]}]
+            output = peft_sam(batched_input, multimask_output=True)
+            masks = output[0]["masks"]
+            self.assertEqual(masks.shape, expected_shape)
+
 
 if __name__ == "__main__":
     unittest.main()
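
For reference outside the unittest harness, below is a minimal standalone sketch that exercises the same three selective PEFT modules covered by the new tests, reusing only the calls that appear in the diff above. The `from micro_sam import util` import and the "vit_b" model type are assumptions; the test class resolves these through its own imports and `self.model_type`.

# Standalone sketch (not part of the patch): run one forward pass per selective PEFT module.
import torch

from micro_sam import util  # assumed import; the test file provides `util` outside this hunk
from micro_sam.models.peft_sam import PEFT_Sam, AttentionSurgery, BiasSurgery, LayerNormSurgery

for surgery in (AttentionSurgery, LayerNormSurgery, BiasSurgery):
    # Wrap a fresh SAM instance per module, mirroring how each test builds its own model.
    _, sam = util.get_sam_model(model_type="vit_b", return_sam=True, device="cpu")
    peft_sam = PEFT_Sam(sam, rank=2, peft_module=surgery)

    batched_input = [{"image": torch.rand(3, 1024, 1024), "original_size": (1024, 1024)}]
    with torch.no_grad():
        output = peft_sam(batched_input, multimask_output=True)

    # The tests expect masks of shape (1, 3, 1024, 1024) for multimask output.
    print(surgery.__name__, output[0]["masks"].shape)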