From 63314bd74ba60e2af8c2ab248f1b2468927da6ca Mon Sep 17 00:00:00 2001
From: anwai98
Date: Sat, 21 Dec 2024 23:44:31 +0100
Subject: [PATCH] Remove hack comments

---
 finetuning/livecell_finetuning.py | 2 +-
 micro_sam/models/peft_sam.py      | 4 +---
 2 files changed, 2 insertions(+), 4 deletions(-)

diff --git a/finetuning/livecell_finetuning.py b/finetuning/livecell_finetuning.py
index a6594f78..26942112 100644
--- a/finetuning/livecell_finetuning.py
+++ b/finetuning/livecell_finetuning.py
@@ -77,7 +77,7 @@ def finetune_livecell(args):
         save_root=args.save_root,
         scheduler_kwargs=scheduler_kwargs,
         save_every_kth_epoch=args.save_every_kth_epoch,
-        peft_kwargs={"rank": args.lora_rank} if args.lora_rank is not None else None,
+        peft_kwargs={"rank": args.lora_rank, "quantize": True} if args.lora_rank is not None else None,
         mixed_precision=False,
     )
 
diff --git a/micro_sam/models/peft_sam.py b/micro_sam/models/peft_sam.py
index 32e9ee79..7b73ff43 100644
--- a/micro_sam/models/peft_sam.py
+++ b/micro_sam/models/peft_sam.py
@@ -22,7 +22,7 @@ class LoRASurgery(nn.Module):
 
     Args:
         rank: The rank of the decomposition matrices for updating weights in each attention layer.
-        block: The chosen attention blocks for implementing lora.
+        block: The chosen attention blocks for implementing LoRA.
     """
     def __init__(self, rank: int, block: nn.Module):
         super().__init__()
@@ -50,8 +50,6 @@ def forward(self, x):
         qkv = self.qkv_proj(x)  # B, N, N, 3 * org_C
         new_q = self.alpha * self.w_b_linear_q(self.w_a_linear_q(x))
         new_v = self.alpha * self.w_b_linear_v(self.w_a_linear_v(x))
-
-        # HACK
         qkv = torch.cat(
             [
                 qkv[:, :, :, :self.dim] + new_q,  # replacing new q values
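
For context, the second hunk touches LoRASurgery.forward, where low-rank updates new_q and new_v are added to the q and v slices of the fused qkv projection while k is left unchanged. Below is a minimal, self-contained sketch of that pattern; the class name LoRAQKVSketch, the tensor shapes, and the default alpha are illustrative assumptions, not the exact micro_sam implementation.

# Minimal sketch of the LoRA-on-qkv pattern shown in the hunk above.
# Names and shapes here (LoRAQKVSketch, dim, rank, alpha) are assumptions.
import torch
import torch.nn as nn


class LoRAQKVSketch(nn.Module):
    """Adds low-rank updates to the q and v slices of a fused qkv projection."""

    def __init__(self, dim: int, rank: int, alpha: float = 1.0):
        super().__init__()
        self.dim = dim
        self.alpha = alpha
        # Frozen base projection producing concatenated [q | k | v].
        self.qkv_proj = nn.Linear(dim, 3 * dim, bias=True)
        self.qkv_proj.requires_grad_(False)
        # Trainable low-rank factors for q and v (k is left untouched).
        self.w_a_linear_q = nn.Linear(dim, rank, bias=False)
        self.w_b_linear_q = nn.Linear(rank, dim, bias=False)
        self.w_a_linear_v = nn.Linear(dim, rank, bias=False)
        self.w_b_linear_v = nn.Linear(rank, dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        qkv = self.qkv_proj(x)  # (..., 3 * dim)
        new_q = self.alpha * self.w_b_linear_q(self.w_a_linear_q(x))
        new_v = self.alpha * self.w_b_linear_v(self.w_a_linear_v(x))
        # Add the low-rank updates to the q and v slices; k stays unchanged.
        qkv = torch.cat(
            [
                qkv[..., : self.dim] + new_q,
                qkv[..., self.dim: 2 * self.dim],
                qkv[..., 2 * self.dim:] + new_v,
            ],
            dim=-1,
        )
        return qkv


if __name__ == "__main__":
    layer = LoRAQKVSketch(dim=64, rank=4)
    out = layer(torch.randn(2, 16, 16, 64))  # e.g. B, H, W, C as in a ViT block
    print(out.shape)  # torch.Size([2, 16, 16, 192])

Because only the low-rank factors require gradients, the base qkv weights stay fixed during finetuning, which is what makes the adaptation parameter-efficient.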