
Commit

Make QLoRA with mixed precision work
anwai98 committed Dec 22, 2024
1 parent 63314bd commit ce1f82a
Showing 2 changed files with 2 additions and 7 deletions.
5 changes: 2 additions & 3 deletions finetuning/livecell_finetuning.py
@@ -28,12 +28,12 @@ def get_dataloaders(patch_shape, data_path, cell_type=None):
     train_loader = get_livecell_loader(
         path=data_path, patch_shape=patch_shape, split="train", batch_size=2, num_workers=16,
         cell_types=cell_type, download=True, shuffle=True, label_transform=label_transform,
-        raw_transform=raw_transform, label_dtype=torch.float32,
+        raw_transform=raw_transform, label_dtype=torch.float32
     )
     val_loader = get_livecell_loader(
         path=data_path, patch_shape=patch_shape, split="val", batch_size=1, num_workers=16,
         cell_types=cell_type, download=True, shuffle=True, label_transform=label_transform,
-        raw_transform=raw_transform, label_dtype=torch.float32,
+        raw_transform=raw_transform, label_dtype=torch.float32
     )

     return train_loader, val_loader
@@ -78,7 +78,6 @@ def finetune_livecell(args):
         scheduler_kwargs=scheduler_kwargs,
         save_every_kth_epoch=args.save_every_kth_epoch,
         peft_kwargs={"rank": args.lora_rank, "quantize": True} if args.lora_rank is not None else None,
-        mixed_precision=False,
     )

     if args.export_path is not None:
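With these two edits, the finetuning script requests QLoRA purely through peft_kwargs and no longer passes mixed_precision=False, so train_sam runs with its default mixed-precision setting. A condensed sketch of such a call, reusing get_dataloaders from this script (the import alias, run name, model type, rank, patch shape, and data path below are illustrative assumptions, not values taken from this commit):

import micro_sam.training as sam_training

train_loader, val_loader = get_dataloaders(patch_shape=(520, 704), data_path="./data/livecell")

sam_training.train_sam(
    name="livecell_qlora",                      # illustrative checkpoint name
    model_type="vit_b",                         # illustrative SAM backbone
    train_loader=train_loader,
    val_loader=val_loader,
    peft_kwargs={"rank": 4, "quantize": True},  # QLoRA: LoRA adapters on a quantized base model
    # mixed_precision is left at its default (enabled) rather than set to False.
)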
4 changes: 0 additions & 4 deletions micro_sam/training/training.py
@@ -271,10 +271,6 @@ def train_sam(
     else:
         model_params = model.parameters()

-    if peft_kwargs and "quantize" in peft_kwargs:
-        import bitsandbytes as bnb
-        optimizer_class = bnb.optim.AdamW8bit
-
     optimizer = optimizer_class(model_params, lr=lr)

     if scheduler_kwargs is None:
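The deleted branch used to swap in bitsandbytes' 8-bit AdamW whenever peft_kwargs requested quantization. With it gone, train_sam builds the optimizer from the regular optimizer class for quantized (QLoRA) models too, which together with the change above lets QLoRA training go through the default mixed-precision path, as the commit title says. A minimal self-contained sketch of the optimizer setup that remains (the AdamW default, learning rate, and tiny stand-in model are assumptions for illustration):

import torch

model = torch.nn.Linear(4, 4)            # stand-in for the (possibly quantized) SAM model
optimizer_class = torch.optim.AdamW      # no bnb.optim.AdamW8bit override for quantized models anymore
model_params = model.parameters()        # in the real code, only the trainable (LoRA) parameters when PEFT is used
optimizer = optimizer_class(model_params, lr=1e-5)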
