Commit

Remove hack comments
anwai98 committed Dec 21, 2024
1 parent 598ee6f commit 63314bd
Showing 2 changed files with 2 additions and 4 deletions.
2 changes: 1 addition & 1 deletion finetuning/livecell_finetuning.py
@@ -77,7 +77,7 @@ def finetune_livecell(args):
save_root=args.save_root,
scheduler_kwargs=scheduler_kwargs,
save_every_kth_epoch=args.save_every_kth_epoch,
peft_kwargs={"rank": args.lora_rank} if args.lora_rank is not None else None,
peft_kwargs={"rank": args.lora_rank, "quantize": True} if args.lora_rank is not None else None,
mixed_precision=False,
)

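The change above only adds "quantize": True to the peft_kwargs that are forwarded for LoRA fine-tuning. As a rough illustration of what such a flag typically means, the sketch below pairs a frozen, int8-quantized base weight with trainable low-rank factors (QLoRA-style training). The class name QuantizedLoRALinear and the quantization scheme are assumptions for illustration only, not the micro_sam implementation.

import torch
import torch.nn as nn
import torch.nn.functional as F


class QuantizedLoRALinear(nn.Module):
    # Hypothetical sketch: frozen int8 base weight plus trainable LoRA factors.
    def __init__(self, linear: nn.Linear, rank: int, quantize: bool = True):
        super().__init__()
        weight = linear.weight.data
        self.quantize = quantize
        if quantize:
            # Symmetric per-tensor int8 quantization of the frozen base weight.
            scale = weight.abs().max() / 127.0
            self.register_buffer("weight_q", torch.round(weight / scale).to(torch.int8))
            self.register_buffer("scale", scale)
        else:
            self.register_buffer("weight_q", weight)
            self.register_buffer("scale", torch.tensor(1.0))
        self.register_buffer("bias", linear.bias.data if linear.bias is not None else None)
        # Only the low-rank factors are trainable; B starts at zero so the
        # initial output matches the (de-quantized) base layer.
        self.lora_a = nn.Linear(linear.in_features, rank, bias=False)
        self.lora_b = nn.Linear(rank, linear.out_features, bias=False)
        nn.init.zeros_(self.lora_b.weight)

    def forward(self, x):
        weight = self.weight_q.float() * self.scale if self.quantize else self.weight_q
        return F.linear(x, weight, self.bias) + self.lora_b(self.lora_a(x))

In livecell_finetuning.py the flag simply rides along in peft_kwargs; how it is interpreted is up to the PEFT wrapper inside micro_sam.
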
4 changes: 1 addition & 3 deletions micro_sam/models/peft_sam.py
@@ -22,7 +22,7 @@ class LoRASurgery(nn.Module):
Args:
rank: The rank of the decomposition matrices for updating weights in each attention layer.
-block: The chosen attention blocks for implementing lora.
+block: The chosen attention blocks for implementing LoRA.
"""
def __init__(self, rank: int, block: nn.Module):
super().__init__()
@@ -50,8 +50,6 @@ def forward(self, x):
qkv = self.qkv_proj(x) # B, N, N, 3 * org_C
new_q = self.alpha * self.w_b_linear_q(self.w_a_linear_q(x))
new_v = self.alpha * self.w_b_linear_v(self.w_a_linear_v(x))
-
-# HACK
qkv = torch.cat(
[
qkv[:, :, :, :self.dim] + new_q, # replacing new q values
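
For context, the hunk above shows the pattern used in LoRASurgery.forward: the fused qkv projection is computed once, and the scaled low-rank updates (w_a_* / w_b_*, scaled by alpha) are added only to the q and v thirds of the output. The following is a minimal, self-contained sketch of that pattern; it mirrors the attribute names in the diff but is an illustration, not the actual class from micro_sam/models/peft_sam.py.

import torch
import torch.nn as nn


class LoRAQKV(nn.Module):
    # Illustrative sketch of the low-rank qkv update, not the micro_sam class.
    def __init__(self, qkv_proj: nn.Linear, rank: int, alpha: float = 1.0):
        super().__init__()
        self.qkv_proj = qkv_proj          # frozen fused q/k/v projection
        self.dim = qkv_proj.in_features   # size of one of the three chunks
        self.alpha = alpha
        # Low-rank factors for the q and v chunks only (k is left untouched).
        self.w_a_linear_q = nn.Linear(self.dim, rank, bias=False)
        self.w_b_linear_q = nn.Linear(rank, self.dim, bias=False)
        self.w_a_linear_v = nn.Linear(self.dim, rank, bias=False)
        self.w_b_linear_v = nn.Linear(rank, self.dim, bias=False)
        nn.init.zeros_(self.w_b_linear_q.weight)
        nn.init.zeros_(self.w_b_linear_v.weight)

    def forward(self, x):
        qkv = self.qkv_proj(x)  # (..., 3 * dim)
        new_q = self.alpha * self.w_b_linear_q(self.w_a_linear_q(x))
        new_v = self.alpha * self.w_b_linear_v(self.w_a_linear_v(x))
        # Add the low-rank updates to the q and v thirds of the fused output.
        return torch.cat(
            [
                qkv[..., : self.dim] + new_q,
                qkv[..., self.dim : 2 * self.dim],
                qkv[..., 2 * self.dim :] + new_v,
            ],
            dim=-1,
        )

The commit itself only drops the stray blank line and the "# HACK" comment before the torch.cat call; the computation is unchanged.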

0 comments on commit 63314bd
