merged train routine updated
lufre1 committed Jul 12, 2024
2 parents a550893 + ad76f2e commit 908e1c1
Showing 7 changed files with 127 additions and 203 deletions.
15 changes: 15 additions & 0 deletions finetuning/livecell/lora/README.md
@@ -0,0 +1,15 @@
## Low Rank Adaptation Methods on Segment Anything for LIVECell

Insights:
- There's no real memory advantage in practice unless the setup is truly scaled up (trainable parameters vs. approximate GPU memory during finetuning). For instance:
- `vit_b`:
- SAM: 93M (takes ~50GB)
- SAM-LoRA: 4.4M (takes ~61GB)
- `vit_l`:
- SAM: 312M (takes ~63GB)
- SAM-LoRA: 4.4M (takes ~61GB)
- `vit_h`:
- SAM: 641M (takes ~73GB)
- SAM-LoRA: 4.7M (takes ~67GB)

- Question: Would quantization (e.g. QLoRA) or weight-decomposed adaptation (e.g. DoRA) lead to better results?
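
The parameter counts above come from checking how many parameters remain trainable after the LoRA wrapping. The training script defines a `count_parameters` helper for this (its body is not shown in this diff); a minimal sketch of such a helper, in plain PyTorch, could look like this:

```python
import torch.nn as nn


def count_parameters(model: nn.Module) -> str:
    """Return the number of trainable parameters in millions, e.g. '4.4M'."""
    n_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return f"{n_trainable / 1e6:.1f}M"


# Stand-in model for illustration; in the script the (LoRA-wrapped) SAM model is passed instead.
print(count_parameters(nn.Linear(1024, 1024)))  # '1.0M'
```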
110 changes: 28 additions & 82 deletions finetuning/livecell/lora/train_livecell.py
@@ -3,8 +3,6 @@

import torch

from torch_em.model import UNETR
from torch_em.loss import DiceBasedDistanceLoss
from torch_em.data.datasets import get_livecell_loader
from torch_em.transform.label import PerObjectDistanceTransform

@@ -49,21 +47,6 @@ def count_parameters(model):

def finetune_livecell(args):
"""Code for finetuning SAM (using LoRA) on LIVECell
Initial observations: There's no real memory advantage actually unless it's "truly" scaled up
# vit_b
# SAM: 93M (takes ~50GB)
# SAM-LoRA: 4.2M (takes ~49GB)
# vit_l
# SAM: 312M (takes ~63GB)
# SAM-LoRA: 4.4M (takes ~61GB)
# vit_h
# SAM: 641M (takes ~73GB)
# SAM-LoRA: 4.7M (takes ~67GB)
# Q: Would quantization lead to better results? (eg. QLoRA / DoRA)
"""
# override this (below) if you have some more complex set-up and need to specify the exact gpu
device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -72,89 +55,49 @@ def finetune_livecell(args):
model_type = args.model_type
checkpoint_path = None # override this to start training from a custom checkpoint
patch_shape = (520, 704) # the patch shape for training
n_objects_per_batch = 5 # this is the number of objects per batch that will be sampled
n_objects_per_batch = args.n_objects # this is the number of objects per batch that will be sampled
freeze_parts = args.freeze # override this to freeze different parts of the model
rank = 4 # the rank

# get the trainable segment anything model
model = sam_training.get_trainable_sam_model(
model_type=model_type,
device=device,
checkpoint_path=checkpoint_path,
freeze=freeze_parts,
use_lora=True,
rank=rank,
)
model.to(device)

# let's get the UNETR model for automatic instance segmentation pipeline
unetr = UNETR(
backbone="sam",
encoder=model.sam.image_encoder,
out_channels=3,
use_sam_stats=True,
final_activation="Sigmoid",
use_skip_connection=False,
resize_input=True,
)
unetr.to(device)

# let's check the total number of trainable parameters
print(count_parameters(model))

# let's get the parameters for SAM and the decoder from UNETR
joint_model_params = model.parameters()
lora_rank = 4 # the rank for low rank adaptation
checkpoint_name = f"{args.model_type}/livecell_sam"

joint_model_params = [params for params in joint_model_params] # sam parameters
for name, params in unetr.named_parameters(): # unetr's decoder parameters
if not name.startswith("encoder"):
joint_model_params.append(params)

optimizer = torch.optim.Adam(joint_model_params, lr=1e-5)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.9, patience=10)
# all the stuff we need for training
train_loader, val_loader = get_dataloaders(patch_shape=patch_shape, data_path=args.input_path)
scheduler_kwargs = {"mode": "min", "factor": 0.9, "patience": 10, "verbose": True}
optimizer_class = torch.optim.AdamW

# this class creates all the training data for a batch (inputs, prompts and labels)
convert_inputs = sam_training.ConvertToSamInputs(transform=model.transform, box_distortion_factor=0.025)

trainer = sam_training.JointSamTrainer(
name="livecell_lora",
save_root=args.save_root,
# Run training.
sam_training.train_sam(
name=checkpoint_name,
model_type=model_type,
train_loader=train_loader,
val_loader=val_loader,
model=model,
optimizer=optimizer,
device=device,
lr_scheduler=scheduler,
logger=sam_training.JointSamLogger,
log_image_interval=100,
mixed_precision=True,
convert_inputs=convert_inputs,
early_stopping=None,
n_objects_per_batch=n_objects_per_batch,
n_sub_iteration=8,
compile_model=False,
mask_prob=0.5, # (optional) overwrite to provide the probability of using mask inputs while training
unetr=unetr,
instance_loss=DiceBasedDistanceLoss(mask_distances_in_bg=True),
instance_metric=DiceBasedDistanceLoss(mask_distances_in_bg=True)
checkpoint_path=checkpoint_path,
freeze=freeze_parts,
device=device,
lr=1e-5,
n_iterations=args.iterations,
save_root=args.save_root,
scheduler_kwargs=scheduler_kwargs,
optimizer_class=optimizer_class,
lora_rank=lora_rank,
)
trainer.fit(args.iterations)

if args.export_path is not None:
checkpoint_path = os.path.join(
"" if args.save_root is None else args.save_root, "checkpoints", args.name, "best.pt"
"" if args.save_root is None else args.save_root, "checkpoints", checkpoint_name, "best.pt"
)
export_custom_sam_model(
checkpoint_path=checkpoint_path,
model_type=model_type,
save_path=args.export_path,
checkpoint_path=checkpoint_path, model_type=model_type, save_path=args.export_path,
)


def main():
parser = argparse.ArgumentParser(description="Finetune Segment Anything for the LiveCELL dataset.")
parser = argparse.ArgumentParser(description="Finetune Segment Anything for the LIVECell dataset.")
parser.add_argument(
"--input_path", "-i", default="/scratch/projects/nim00007/sam/data/livecell/",
help="The filepath to the LiveCELL data. If the data does not exist yet it will be downloaded."
help="The filepath to the LIVECell data. If the data does not exist yet it will be downloaded."
)
parser.add_argument(
"--model_type", "-m", default="vit_b",
@@ -176,6 +119,9 @@ def main():
"--freeze", type=str, nargs="+", default=None,
help="Which parts of the model to freeze for finetuning."
)
parser.add_argument(
"--n_objects", type=int, default=25, help="The number of instances (objects) per batch used for finetuning."
)
args = parser.parse_args()
finetune_livecell(args)

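For orientation, the merged training routine above boils down to a single `sam_training.train_sam` call. The sketch below condenses it using only arguments that appear in the diff; the data loaders are placeholders for the script's `get_dataloaders` helper, and the iteration count is an assumption, since the default for `--iterations` is not shown here.

```python
import torch
import micro_sam.training as sam_training

# Placeholders: the script builds these via get_dataloaders(patch_shape=(520, 704), data_path=...).
train_loader, val_loader = None, None

sam_training.train_sam(
    name="vit_b/livecell_sam",
    model_type="vit_b",
    train_loader=train_loader,
    val_loader=val_loader,
    n_objects_per_batch=25,          # the --n_objects default
    checkpoint_path=None,            # start from the default SAM weights
    freeze=None,                     # or a list of model parts to freeze
    device="cuda" if torch.cuda.is_available() else "cpu",
    lr=1e-5,
    n_iterations=100_000,            # assumption; the script passes args.iterations
    save_root=None,
    scheduler_kwargs={"mode": "min", "factor": 0.9, "patience": 10, "verbose": True},
    optimizer_class=torch.optim.AdamW,
    lora_rank=4,                     # enables low-rank adaptation of the image encoder
)
```

Compared to the removed code, the UNETR decoder, the joint parameter list, and the optimizer and scheduler construction are now handled inside `train_sam`.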
124 changes: 27 additions & 97 deletions finetuning/specialists/resource-efficient/covid_if_finetuning.py
@@ -1,22 +1,25 @@
import os
import argparse

import torch

from torch_em.model import UNETR
from torch_em.data import MinInstanceSampler
from torch_em.loss import DiceBasedDistanceLoss
from torch_em.transform.raw import normalize
from torch_em.data.datasets import get_covid_if_loader
from torch_em.transform.label import PerObjectDistanceTransform

import micro_sam.training as sam_training
from micro_sam.util import export_custom_sam_model


def covid_if_raw_trafo(raw):
raw = normalize(raw)
raw = raw * 255
return raw


def get_dataloaders(patch_shape, data_path, n_images):
"""This returns the immunofluoroscence data loaders implemented in torch_em:
https://github.com/constantinpape/torch-em/blob/main/torch_em/data/datasets/covid_if.py
It will automatically download the IF data.
It will automatically download the immunofluorescence data.
Note: to replace this with another data loader you need to return a torch data loader
that returns `x, y` tensors, where `x` is the image data and `y` are the labels.
@@ -29,7 +32,7 @@ def get_dataloaders(patch_shape, data_path, n_images):
label_transform = PerObjectDistanceTransform(
distances=True, boundary_distances=True, directed_distances=False, foreground=True, instances=True, min_size=25
)
raw_transform = sam_training.identity # the current workflow avoids rescaling the inputs to [-1, 1]
raw_transform = covid_if_raw_trafo
sampler = MinInstanceSampler()

choice_of_images = [1, 2, 5, 10]
@@ -67,7 +70,7 @@ def get_dataloaders(patch_shape, data_path, n_images):


def finetune_covid_if(args):
"""Example code for finetuning SAM on Covid-IF"""
"""Code for finetuning SAM on Covid-IF"""
# override this (below) if you have some more complex set-up and need to specify the exact gpu
device = "cuda" if torch.cuda.is_available() else "cpu"

@@ -77,105 +80,32 @@ def finetune_covid_if(args):
patch_shape = (512, 512) # the patch shape for training
n_objects_per_batch = args.n_objects # the number of objects per batch that will be sampled
freeze_parts = args.freeze # override this to freeze different parts of the model
checkpoint_name = f"{args.model_type}/covid_if_sam"

# HACK: let's convert the model checkpoints to the desired format
if checkpoint_path is not None:
from pathlib import Path
target_checkpoint_path = os.path.join(Path(checkpoint_path).parent, "checkpoint.pt")
if not os.path.exists(target_checkpoint_path):
export_custom_sam_model(
checkpoint_path=checkpoint_path, model_type=model_type, save_path=target_checkpoint_path
)
else:
target_checkpoint_path = checkpoint_path

# get the trainable segment anything model
model = sam_training.get_trainable_sam_model(
model_type=model_type, device=device, checkpoint_path=target_checkpoint_path, freeze=freeze_parts
)
model.to(device)

# let's get the UNETR model for automatic instance segmentation pipeline
unetr = UNETR(
backbone="sam",
encoder=model.sam.image_encoder,
out_channels=3,
use_sam_stats=True,
final_activation="Sigmoid",
use_skip_connection=False,
resize_input=True,
use_conv_transpose=True,
)

# let's initialize the decoder block from the previous fine-tuning, if provided
if checkpoint_path is not None:
import pickle
from micro_sam.util import _CustomUnpickler
custom_unpickle = pickle
custom_unpickle.Unpickler = _CustomUnpickler

decoder_state = torch.load(
checkpoint_path, map_location="cpu", pickle_module=custom_unpickle
)["decoder_state"]
unetr_state_dict = unetr.state_dict()
for k, v in unetr_state_dict.items():
if not k.startswith("encoder"):
unetr_state_dict[k] = decoder_state[k]
unetr.load_state_dict(unetr_state_dict)

unetr.to(device)

# let's get the parameters for SAM and the decoder from UNETR
joint_model_params = [params for params in model.parameters()] # sam parameters
for name, params in unetr.named_parameters(): # unetr's decoder parameters
if not name.startswith("encoder"):
joint_model_params.append(params)

# all the stuff we need for training
optimizer = torch.optim.Adam(joint_model_params, lr=1e-5)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.9, patience=3, verbose=True)
# all stuff we need for training
train_loader, val_loader = get_dataloaders(
patch_shape=patch_shape, data_path=args.input_path, n_images=args.n_images
)
scheduler_kwargs = {"mode": "min", "factor": 0.9, "patience": 3, "verbose": True}

# this class creates all the training data for a batch (inputs, prompts and labels)
convert_inputs = sam_training.ConvertToSamInputs(transform=model.transform, box_distortion_factor=0.025)

checkpoint_name = f"{args.model_type}/covid_if_sam"

# the trainer which performs the joint training and validation (implemented using "torch_em")
trainer = sam_training.JointSamTrainer(
# Run training
sam_training.train_sam(
name=checkpoint_name,
save_root=args.save_root,
model_type=model_type,
train_loader=train_loader,
val_loader=val_loader,
model=model,
optimizer=optimizer,
device=device,
lr_scheduler=scheduler,
logger=sam_training.JointSamLogger,
log_image_interval=100,
mixed_precision=True,
convert_inputs=convert_inputs,
early_stopping=10,
n_objects_per_batch=n_objects_per_batch,
n_sub_iteration=8,
compile_model=False,
mask_prob=0.5, # (optional) overwrite to provide the probability of using mask inputs while training
unetr=unetr,
instance_loss=DiceBasedDistanceLoss(mask_distances_in_bg=True),
instance_metric=DiceBasedDistanceLoss(mask_distances_in_bg=True),
early_stopping=10
)
trainer.fit(epochs=args.epochs, save_every_kth_epoch=args.save_every_kth_epoch)
if args.export_path is not None:
checkpoint_path = os.path.join(
"" if args.save_root is None else args.save_root, "checkpoints", checkpoint_name, "best.pt"
)
export_custom_sam_model(
checkpoint_path=checkpoint_path,
model_type=model_type,
save_path=args.export_path,
)
checkpoint_path=checkpoint_path,
freeze=freeze_parts,
device=device,
lr=1e-5,
n_epochs=args.epochs,
save_root=args.save_root,
scheduler_kwargs=scheduler_kwargs,
save_every_kth_epoch=args.save_every_kth_epoch,

)


def main():
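The `get_dataloaders` docstring above spells out the only contract a replacement loader has to fulfil: it must yield `(x, y)` batches of image data and labels. A minimal, purely hypothetical loader satisfying that contract (toy data, not part of this commit; real training additionally uses the raw and label transforms shown in the script) might look like:

```python
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader


class ToyInstanceDataset(Dataset):
    """Yields (image, instance-label) pairs with the shapes the training code expects."""

    def __init__(self, n_samples=8, patch_shape=(512, 512)):
        self.n_samples = n_samples
        self.patch_shape = patch_shape

    def __len__(self):
        return self.n_samples

    def __getitem__(self, idx):
        x = np.random.rand(1, *self.patch_shape).astype("float32") * 255  # image, here scaled to [0, 255]
        y = np.zeros((1, *self.patch_shape), dtype="int64")               # instance segmentation labels
        y[:, 100:200, 100:200] = 1                                        # a single dummy object
        return torch.from_numpy(x), torch.from_numpy(y)


train_loader = DataLoader(ToyInstanceDataset(), batch_size=1, shuffle=True)
x, y = next(iter(train_loader))
print(x.shape, y.shape)  # torch.Size([1, 1, 512, 512]) torch.Size([1, 1, 512, 512])
```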
6 changes: 3 additions & 3 deletions micro_sam/models/peft_sam.py
@@ -25,8 +25,8 @@ def __init__(
block: nn.Module,
):
super().__init__()
self.qkv = block.attn.qkv
self.dim = self.qkv.in_features
self.qkv_proj = block.attn.qkv
self.dim = self.qkv_proj.in_features

self.w_a_linear_q = nn.Linear(self.dim, rank, bias=False)
self.w_b_linear_q = nn.Linear(rank, self.dim, bias=False)
@@ -44,7 +44,7 @@ def reset_parameters(self):
nn.init.zeros_(self.w_b_linear_v.weight)

def forward(self, x):
qkv = self.qkv(x) # B, N, N, 3 * org_C
qkv = self.qkv_proj(x) # B, N, N, 3 * org_C
new_q = self.w_b_linear_q(self.w_a_linear_q(x))
new_v = self.w_b_linear_v(self.w_a_linear_v(x))
qkv[:, :, :, :self.dim] += new_q
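To make the renamed `qkv_proj` surgery easier to follow, here is a self-contained sketch of the same idea: a fused qkv projection is wrapped so that rank-`r` updates are added to the q and v outputs only, while k and the original weights stay untouched. Names and shapes are illustrative and independent of micro_sam.

```python
import torch
import torch.nn as nn


class LoRAQKV(nn.Module):
    """Wrap a fused qkv linear layer and add low-rank updates to q and v only."""

    def __init__(self, qkv_proj: nn.Linear, rank: int):
        super().__init__()
        self.qkv_proj = qkv_proj
        self.dim = qkv_proj.in_features
        self.a_q = nn.Linear(self.dim, rank, bias=False)
        self.b_q = nn.Linear(rank, self.dim, bias=False)
        self.a_v = nn.Linear(self.dim, rank, bias=False)
        self.b_v = nn.Linear(rank, self.dim, bias=False)
        # Zero-initialize the B matrices so the wrapped layer starts out
        # identical to the original projection.
        nn.init.zeros_(self.b_q.weight)
        nn.init.zeros_(self.b_v.weight)

    def forward(self, x):
        qkv = self.qkv_proj(x)                            # (..., 3 * dim)
        qkv[..., : self.dim] += self.b_q(self.a_q(x))     # low-rank update on q
        qkv[..., -self.dim:] += self.b_v(self.a_v(x))     # low-rank update on v
        return qkv


# Only the four small A/B matrices (and whatever else is left unfrozen) need gradients.
layer = LoRAQKV(nn.Linear(768, 3 * 768), rank=4)
out = layer(torch.randn(2, 16, 768))
print(out.shape)  # torch.Size([2, 16, 2304])
```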
