From 312fc4ba102f54d0511f94bff54f7871984700f1 Mon Sep 17 00:00:00 2001 From: anwai98 Date: Thu, 11 Jul 2024 16:04:40 +0200 Subject: [PATCH] Add training scripts for unet and unetr --- scripts/for_benchmarking_ais/common.py | 42 ++++++++++++++ scripts/for_benchmarking_ais/train_unet.py | 56 +++++++++++++++++++ scripts/for_benchmarking_ais/train_unetr.py | 62 +++++++++++++++++++++ 3 files changed, 160 insertions(+) create mode 100644 scripts/for_benchmarking_ais/common.py create mode 100644 scripts/for_benchmarking_ais/train_unet.py create mode 100644 scripts/for_benchmarking_ais/train_unetr.py diff --git a/scripts/for_benchmarking_ais/common.py b/scripts/for_benchmarking_ais/common.py new file mode 100644 index 00000000..9b25db2f --- /dev/null +++ b/scripts/for_benchmarking_ais/common.py @@ -0,0 +1,42 @@ +import argparse + +import torch + +from torch_em.transform.label import PerObjectDistanceTransform +from torch_em.data.datasets.light_microscopy import get_livecell_loader + +import micro_sam.training as sam_training + + +def get_loaders(path, patch_shape, for_sam=False): + kwargs = { + "label_transform": PerObjectDistanceTransform( + distances=True, + boundary_distances=True, + directed_distances=False, + foreground=True, + min_size=25, + ), + "label_dtype": torch.float32, + "num_workers": 16, + "patch_shape": patch_shape + } + + if for_sam: + kwargs["raw_transform"] = sam_training.identity + + train_loader = get_livecell_loader(path=path, split="train", batch_size=2, **kwargs) + val_loader = get_livecell_loader(path=path, split="val", batch_size=1, **kwargs) + + return train_loader, val_loader + + +def get_default_arguments(): + parser = argparse.ArgumentParser() + parser.add_argument("-i", "--input_path", type=str, default="/scratch/projects/nim00007/sam/data/livecell") + parser.add_argument("-s", "--save_root", type=str, default=None) + parser.add_argument("-p", "--phase", type=str, default=None, choices=["train", "predict", "evaluate"]) + parser.add_argument("--iterations", type=str, default=1e5) + parser.add_argument("--sam", action="store_true") + args = parser.parse_args() + return args diff --git a/scripts/for_benchmarking_ais/train_unet.py b/scripts/for_benchmarking_ais/train_unet.py new file mode 100644 index 00000000..f897cbb2 --- /dev/null +++ b/scripts/for_benchmarking_ais/train_unet.py @@ -0,0 +1,56 @@ +from common import get_loaders, get_default_arguments + +import torch + +from torch_em.model import UNet2d +from torch_em.loss import DiceBasedDistanceLoss +from torch_em import default_segmentation_trainer +from torch_em.model.unetr import SingleDeconv2DBlock + + +def run_training_for_livecell(path, save_root, iterations): + # all the necessary stuff for training + device = "cuda" if torch.cuda.is_available() else "cpu" + patch_shape = (512, 512) + train_loader, val_loader = get_loaders(path=path, patch_shape=patch_shape) + loss = DiceBasedDistanceLoss(mask_distances_in_bg=True) + + model = UNet2d( + in_channels=1, + out_channels=3, + initial_features=64, + final_activation="Sigmoid", + sampler_impl=SingleDeconv2DBlock, + ) + model.to(device) + + trainer = default_segmentation_trainer( + name="livecell-unet", + model=model, + train_loader=train_loader, + val_loader=val_loader, + device=device, + learning_rate=1e-4, + loss=loss, + metric=loss, + log_image_interval=50, + save_root=save_root, + compile_model=False, + mixed_precision=True, + scheduler_kwargs={"mode": "min", "factor": 0.9, "patience": 5} + ) + + trainer.fit(int(iterations)) + + +def main(args): + if args.phase == "train": + run_training_for_livecell(path=args.input_path, save_root=args.save_root, iterations=args.iterations) + + else: + raise NotImplementedError + + +if __name__ == "__main__": + args = get_default_arguments() + main(args) diff --git a/scripts/for_benchmarking_ais/train_unetr.py b/scripts/for_benchmarking_ais/train_unetr.py new file mode 100644 index 00000000..79602a8d --- /dev/null +++ b/scripts/for_benchmarking_ais/train_unetr.py @@ -0,0 +1,62 @@ +from common import get_loaders, get_default_arguments + +import torch + +from torch_em.model import UNETR +from torch_em.loss import DiceBasedDistanceLoss +from torch_em import default_segmentation_trainer + + +def run_training_for_livecell(path, save_root, iterations, for_sam): + # all the necessary stuff for training + device = "cuda" if torch.cuda.is_available() else "cpu" + patch_shape = (512, 512) + train_loader, val_loader = get_loaders(path=path, patch_shape=patch_shape, for_sam=for_sam) + loss = DiceBasedDistanceLoss(mask_distances_in_bg=True) + checkpoint_path = "/scratch-grete/share/cidas/cca/models/sam/sam_vit_l_0b3195.pth" if for_sam else None + + model = UNETR( + encoder="vit_l", + out_channels=3, + final_activation="Sigmoid", + use_skip_connection=False, + use_sam_stats=for_sam, + encoder_checkpoint=checkpoint_path, + ) + model.to(device) + + trainer = default_segmentation_trainer( + name="livecell-unetr-sam" if for_sam else "livecell-unetr", + model=model, + train_loader=train_loader, + val_loader=val_loader, + device=device, + learning_rate=1e-4, + loss=loss, + metric=loss, + log_image_interval=50, + save_root=save_root, + compile_model=False, + mixed_precision=True, + scheduler_kwargs={"mode": "min", "factor": 0.9, "patience": 5} + ) + + trainer.fit(int(iterations)) + + +def main(args): + if args.phase == "train": + run_training_for_livecell( + path=args.input_path, + save_root=args.save_root, + iterations=args.iterations, + for_sam=args.sam, + ) + + else: + raise NotImplementedError + + +if __name__ == "__main__": + args = get_default_arguments() + main(args)