Skip to content

Commit

Permalink
Add training scripts for unet and unetr
Browse files Browse the repository at this point in the history
  • Loading branch information
anwai98 committed Jul 11, 2024
1 parent 7c6e1a4 commit 312fc4b
Show file tree
Hide file tree
Showing 3 changed files with 160 additions and 0 deletions.
42 changes: 42 additions & 0 deletions scripts/for_benchmarking_ais/common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import argparse

import torch

from torch_em.transform.label import PerObjectDistanceTransform
from torch_em.data.datasets.light_microscopy import get_livecell_loader

import micro_sam.training as sam_training


def get_loaders(path, patch_shape, for_sam=False):
kwargs = {
"label_transform": PerObjectDistanceTransform(
distances=True,
boundary_distances=True,
directed_distances=False,
foreground=True,
min_size=25,
),
"label_dtype": torch.float32,
"num_workers": 16,
"patch_shape": patch_shape
}

if for_sam:
kwargs["raw_transform"] = sam_training.identity

train_loader = get_livecell_loader(path=path, split="train", batch_size=2, **kwargs)
val_loader = get_livecell_loader(path=path, split="val", batch_size=1, **kwargs)

return train_loader, val_loader


def get_default_arguments():
parser = argparse.ArgumentParser()
parser.add_argument("-i", "--input_path", type=str, default="/scratch/projects/nim00007/sam/data/livecell")
parser.add_argument("-s", "--save_root", type=str, default=None)
parser.add_argument("-p", "--phase", type=str, default=None, choices=["train", "predict", "evaluate"])
parser.add_argument("--iterations", type=str, default=1e5)
parser.add_argument("--sam", action="store_true")
args = parser.parse_args()
return args
56 changes: 56 additions & 0 deletions scripts/for_benchmarking_ais/train_unet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
from common import get_loaders, get_default_arguments

import torch

from torch_em.model import UNet2d
from torch_em.loss import DiceBasedDistanceLoss
from torch_em import default_segmentation_trainer
from torch_em.model.unetr import SingleDeconv2DBlock


def run_training_for_livecell(path, save_root, iterations):
# all the necessary stuff for training
device = "cuda" if torch.cuda.is_available() else "cpu"
patch_shape = (512, 512)
train_loader, val_loader = get_loaders(path=path, patch_shape=patch_shape)
loss = DiceBasedDistanceLoss(mask_distances_in_bg=True)

model = UNet2d(
in_channels=1,
out_channels=3,
initial_features=64,
final_activation="Sigmoid",
sampler_impl=SingleDeconv2DBlock,
)
model.to(device)

trainer = default_segmentation_trainer(
name="livecell-unet",
model=model,
train_loader=train_loader,
val_loader=val_loader,
device=device,
learning_rate=1e-4,
loss=loss,
metric=loss,
log_image_interval=50,
save_root=save_root,
compile_model=False,
mixed_precision=True,
scheduler_kwargs={"mode": "min", "factor": 0.9, "patience": 5}
)

trainer.fit(int(iterations))


def main(args):
if args.phase == "train":
run_training_for_livecell(path=args.input_path, save_root=args.save_root, iterations=args.iterations)

else:
raise NotImplementedError


if __name__ == "__main__":
args = get_default_arguments()
main(args)
62 changes: 62 additions & 0 deletions scripts/for_benchmarking_ais/train_unetr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
from common import get_loaders, get_default_arguments

import torch

from torch_em.model import UNETR
from torch_em.loss import DiceBasedDistanceLoss
from torch_em import default_segmentation_trainer


def run_training_for_livecell(path, save_root, iterations, for_sam):
# all the necessary stuff for training
device = "cuda" if torch.cuda.is_available() else "cpu"
patch_shape = (512, 512)
train_loader, val_loader = get_loaders(path=path, patch_shape=patch_shape, for_sam=for_sam)
loss = DiceBasedDistanceLoss(mask_distances_in_bg=True)
checkpoint_path = "/scratch-grete/share/cidas/cca/models/sam/sam_vit_l_0b3195.pth" if for_sam else None

model = UNETR(
encoder="vit_l",
out_channels=3,
final_activation="Sigmoid",
use_skip_connection=False,
use_sam_stats=for_sam,
encoder_checkpoint=checkpoint_path,
)
model.to(device)

trainer = default_segmentation_trainer(
name="livecell-unetr-sam" if for_sam else "livecell-unetr",
model=model,
train_loader=train_loader,
val_loader=val_loader,
device=device,
learning_rate=1e-4,
loss=loss,
metric=loss,
log_image_interval=50,
save_root=save_root,
compile_model=False,
mixed_precision=True,
scheduler_kwargs={"mode": "min", "factor": 0.9, "patience": 5}
)

trainer.fit(int(iterations))


def main(args):
if args.phase == "train":
run_training_for_livecell(
path=args.input_path,
save_root=args.save_root,
iterations=args.iterations,
for_sam=args.sam,
)

else:
raise NotImplementedError


if __name__ == "__main__":
args = get_default_arguments()
main(args)

0 comments on commit 312fc4b

Please sign in to comment.