From ce1f82a143be521da7a421127441bdb030d7044c Mon Sep 17 00:00:00 2001
From: anwai98
Date: Sun, 22 Dec 2024 11:45:36 +0100
Subject: [PATCH] Make qlora with mixed precision work

---
 finetuning/livecell_finetuning.py | 5 ++---
 micro_sam/training/training.py    | 4 ----
 2 files changed, 2 insertions(+), 7 deletions(-)

diff --git a/finetuning/livecell_finetuning.py b/finetuning/livecell_finetuning.py
index 26942112..95274773 100644
--- a/finetuning/livecell_finetuning.py
+++ b/finetuning/livecell_finetuning.py
@@ -28,12 +28,12 @@ def get_dataloaders(patch_shape, data_path, cell_type=None):
     train_loader = get_livecell_loader(
         path=data_path, patch_shape=patch_shape, split="train", batch_size=2, num_workers=16,
         cell_types=cell_type, download=True, shuffle=True, label_transform=label_transform,
-        raw_transform=raw_transform, label_dtype=torch.float32,
+        raw_transform=raw_transform, label_dtype=torch.float32
     )
     val_loader = get_livecell_loader(
         path=data_path, patch_shape=patch_shape, split="val", batch_size=1, num_workers=16,
         cell_types=cell_type, download=True, shuffle=True, label_transform=label_transform,
-        raw_transform=raw_transform, label_dtype=torch.float32,
+        raw_transform=raw_transform, label_dtype=torch.float32
     )
 
     return train_loader, val_loader
 
@@ -78,7 +78,6 @@ def finetune_livecell(args):
         scheduler_kwargs=scheduler_kwargs,
         save_every_kth_epoch=args.save_every_kth_epoch,
         peft_kwargs={"rank": args.lora_rank, "quantize": True} if args.lora_rank is not None else None,
-        mixed_precision=False,
     )
 
     if args.export_path is not None:
diff --git a/micro_sam/training/training.py b/micro_sam/training/training.py
index 3411bf40..bfebf458 100644
--- a/micro_sam/training/training.py
+++ b/micro_sam/training/training.py
@@ -271,10 +271,6 @@ def train_sam(
     else:
         model_params = model.parameters()
 
-    if peft_kwargs and "quantize" in peft_kwargs:
-        import bitsandbytes as bnb
-        optimizer_class = bnb.optim.AdamW8bit
-
     optimizer = optimizer_class(model_params, lr=lr)
 
     if scheduler_kwargs is None:
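
For reference, below is a minimal usage sketch of how a caller could run QLoRA fine-tuning after this patch, with mixed precision left at the trainer's default instead of being forced off. It is not part of the patch: the peft_kwargs shape is taken from the diff above, while the remaining argument names, the model type, epoch count, and loader setup are assumptions for illustration only.

    # Hypothetical usage sketch (assumptions noted above), Python:
    import micro_sam.training as sam_training

    # get_dataloaders is the helper defined in the patched livecell_finetuning.py;
    # the patch shape and data path here are placeholders.
    train_loader, val_loader = get_dataloaders(patch_shape=(520, 704), data_path="./livecell")

    sam_training.train_sam(
        name="livecell_sam_qlora",          # assumed run name
        model_type="vit_b",                  # assumed model type
        train_loader=train_loader,
        val_loader=val_loader,
        n_epochs=25,                         # assumed epoch count
        # QLoRA: low-rank adapters on a quantized base model, as in the diff.
        # The bitsandbytes AdamW8bit override removed in training.py no longer
        # applies, and mixed_precision=False is no longer passed by the caller.
        peft_kwargs={"rank": 4, "quantize": True},
    )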