Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor resource efficient finetuning scripts #653

Merged
merged 1 commit into from
Jul 9, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
124 changes: 27 additions & 97 deletions finetuning/specialists/resource-efficient/covid_if_finetuning.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,25 @@
import os
import argparse

import torch

from torch_em.model import UNETR
from torch_em.data import MinInstanceSampler
from torch_em.loss import DiceBasedDistanceLoss
from torch_em.transform.raw import normalize
from torch_em.data.datasets import get_covid_if_loader
from torch_em.transform.label import PerObjectDistanceTransform

import micro_sam.training as sam_training
from micro_sam.util import export_custom_sam_model


def covid_if_raw_trafo(raw):
raw = normalize(raw)
raw = raw * 255
return raw


def get_dataloaders(patch_shape, data_path, n_images):
"""This returns the immunofluoroscence data loaders implemented in torch_em:
https://github.com/constantinpape/torch-em/blob/main/torch_em/data/datasets/covid_if.py
It will automatically download the IF data.
It will automatically download the immunofluoroscence data.

Note: to replace this with another data loader you need to return a torch data loader
that retuns `x, y` tensors, where `x` is the image data and `y` are the labels.
Expand All @@ -29,7 +32,7 @@ def get_dataloaders(patch_shape, data_path, n_images):
label_transform = PerObjectDistanceTransform(
distances=True, boundary_distances=True, directed_distances=False, foreground=True, instances=True, min_size=25
)
raw_transform = sam_training.identity # the current workflow avoids rescaling the inputs to [-1, 1]
raw_transform = covid_if_raw_trafo
sampler = MinInstanceSampler()

choice_of_images = [1, 2, 5, 10]
Expand Down Expand Up @@ -67,7 +70,7 @@ def get_dataloaders(patch_shape, data_path, n_images):


def finetune_covid_if(args):
"""Example code for finetuning SAM on Covid-IF"""
"""Code for finetuning SAM on Covid-IF"""
# override this (below) if you have some more complex set-up and need to specify the exact gpu
device = "cuda" if torch.cuda.is_available() else "cpu"

Expand All @@ -77,105 +80,32 @@ def finetune_covid_if(args):
patch_shape = (512, 512) # the patch shape for training
n_objects_per_batch = args.n_objects # the number of objects per batch that will be sampled
freeze_parts = args.freeze # override this to freeze different parts of the model
checkpoint_name = f"{args.model_type}/covid_if_sam"

# HACK: let's convert the model checkpoints to the desired format
if checkpoint_path is not None:
from pathlib import Path
target_checkpoint_path = os.path.join(Path(checkpoint_path).parent, "checkpoint.pt")
if not os.path.exists(target_checkpoint_path):
export_custom_sam_model(
checkpoint_path=checkpoint_path, model_type=model_type, save_path=target_checkpoint_path
)
else:
target_checkpoint_path = checkpoint_path

# get the trainable segment anything model
model = sam_training.get_trainable_sam_model(
model_type=model_type, device=device, checkpoint_path=target_checkpoint_path, freeze=freeze_parts
)
model.to(device)

# let's get the UNETR model for automatic instance segmentation pipeline
unetr = UNETR(
backbone="sam",
encoder=model.sam.image_encoder,
out_channels=3,
use_sam_stats=True,
final_activation="Sigmoid",
use_skip_connection=False,
resize_input=True,
use_conv_transpose=True,
)

# let's initialize the decoder block from the previous fine-tuning, if provided
if checkpoint_path is not None:
import pickle
from micro_sam.util import _CustomUnpickler
custom_unpickle = pickle
custom_unpickle.Unpickler = _CustomUnpickler

decoder_state = torch.load(
checkpoint_path, map_location="cpu", pickle_module=custom_unpickle
)["decoder_state"]
unetr_state_dict = unetr.state_dict()
for k, v in unetr_state_dict.items():
if not k.startswith("encoder"):
unetr_state_dict[k] = decoder_state[k]
unetr.load_state_dict(unetr_state_dict)

unetr.to(device)

# let's get the parameters for SAM and the decoder from UNETR
joint_model_params = [params for params in model.parameters()] # sam parameters
for name, params in unetr.named_parameters(): # unetr's decoder parameters
if not name.startswith("encoder"):
joint_model_params.append(params)

# all the stuff we need for training
optimizer = torch.optim.Adam(joint_model_params, lr=1e-5)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.9, patience=3, verbose=True)
# all stuff we need for training
train_loader, val_loader = get_dataloaders(
patch_shape=patch_shape, data_path=args.input_path, n_images=args.n_images
)
scheduler_kwargs = {"mode": "min", "factor": 0.9, "patience": 3, "verbose": True}

# this class creates all the training data for a batch (inputs, prompts and labels)
convert_inputs = sam_training.ConvertToSamInputs(transform=model.transform, box_distortion_factor=0.025)

checkpoint_name = f"{args.model_type}/covid_if_sam"

# the trainer which performs the joint training and validation (implemented using "torch_em")
trainer = sam_training.JointSamTrainer(
# Run training
sam_training.train_sam(
name=checkpoint_name,
save_root=args.save_root,
model_type=model_type,
train_loader=train_loader,
val_loader=val_loader,
model=model,
optimizer=optimizer,
device=device,
lr_scheduler=scheduler,
logger=sam_training.JointSamLogger,
log_image_interval=100,
mixed_precision=True,
convert_inputs=convert_inputs,
early_stopping=10,
n_objects_per_batch=n_objects_per_batch,
n_sub_iteration=8,
compile_model=False,
mask_prob=0.5, # (optional) overwrite to provide the probability of using mask inputs while training
unetr=unetr,
instance_loss=DiceBasedDistanceLoss(mask_distances_in_bg=True),
instance_metric=DiceBasedDistanceLoss(mask_distances_in_bg=True),
early_stopping=10
)
trainer.fit(epochs=args.epochs, save_every_kth_epoch=args.save_every_kth_epoch)
if args.export_path is not None:
checkpoint_path = os.path.join(
"" if args.save_root is None else args.save_root, "checkpoints", checkpoint_name, "best.pt"
)
export_custom_sam_model(
checkpoint_path=checkpoint_path,
model_type=model_type,
save_path=args.export_path,
)
checkpoint_path=checkpoint_path,
freeze=freeze_parts,
device=device,
lr=1e-5,
n_epochs=args.epochs,
save_root=args.save_root,
scheduler_kwargs=scheduler_kwargs,
save_every_kth_epoch=args.save_every_kth_epoch,

)


def main():
Expand Down
Loading