From 676b7ed067fd7118e4bd3a8fe3501a386f0b0e28 Mon Sep 17 00:00:00 2001 From: anwai98 Date: Wed, 25 Oct 2023 13:33:09 +0200 Subject: [PATCH 1/6] Refactor histopatho --- .../train_histopathology_generalist.py | 80 ++++++++++--------- .../histopathology/train_model.sbatch | 5 +- 2 files changed, 47 insertions(+), 38 deletions(-) diff --git a/finetuning/generalists/training/histopathology/train_histopathology_generalist.py b/finetuning/generalists/training/histopathology/train_histopathology_generalist.py index 1fda098d..d0198982 100644 --- a/finetuning/generalists/training/histopathology/train_histopathology_generalist.py +++ b/finetuning/generalists/training/histopathology/train_histopathology_generalist.py @@ -1,62 +1,50 @@ import os +import argparse -import micro_sam.training as sam_training import torch -import torch_em +from torch_em.loss import DiceLoss -import torch.utils.data as data_util -from torch_em.data.datasets import get_lizard_dataset -from torch_em.data.sampler import MinInstanceSampler +import micro_sam.training as sam_training from micro_sam.util import export_custom_sam_model +from obtain_hp_datasets import get_generalist_hp_loaders -# TODO use other datasets than lizard -def get_dataloaders(patch_shape, data_path): - label_transform = torch_em.transform.label.label_consecutive # to ensure consecutive IDs - sampler = MinInstanceSampler(min_num_instances=5) - dataset = get_lizard_dataset( - path=data_path, download=True, patch_shape=patch_shape, label_transform=label_transform, - sampler=sampler, - ) - train_ds, val_ds = data_util.random_split(dataset, [0.9, 0.1]) - train_loader = torch_em.get_data_loader(train_ds, batch_size=1) - val_loader = torch_em.get_data_loader(val_ds, batch_size=1) - return train_loader, val_loader - -def finetune_histopatho(input_path, export_path, model_type="vit_h", iterations=int(2e4), save_root=None): - """Example code for finetuning SAM on LiveCELL""" +def finetune_hp_generalist(args): + """Example code for finetuning SAM on multiple histopathology datasets""" + # override this (below) if you have some more complex set-up and need to specify the exact gpu + device = "cuda" if torch.cuda.is_available() else "cpu" # training settings: + model_type = args.model_type checkpoint_path = None # override this to start training from a custom checkpoint - device = "cuda" # override this if you have some more complex set-up and need to specify the exact gpu patch_shape = (512, 512) # the patch shape for training - n_objects_per_batch = 50 # this is the number of objects per batch that will be sampled - - train_loader, val_loader = get_dataloaders(patch_shape=patch_shape, data_path=input_path) + n_objects_per_batch = 25 # this is the number of objects per batch that will be sampled + freeze_parts = None # override this to freeze one or more of these backbones # get the trainable segment anything model - model = sam_training.get_trainable_sam_model(model_type, checkpoint_path, device=device) + model = sam_training.get_trainable_sam_model(model_type, checkpoint_path, freeze_parts) # all the stuff we need for training optimizer = torch.optim.Adam(model.parameters(), lr=1e-5) scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.9, patience=10, verbose=True) + train_loader, val_loader = get_generalist_hp_loaders(patch_shape=patch_shape, data_path=args.input_path) # this class creates all the training data for a batch (inputs, prompts and labels) convert_inputs = sam_training.ConvertToSamInputs() - checkpoint_name = 
"sam-histopatho-v1" + checkpoint_name = "generalist-hp-sam" # the trainer which performs training and validation (implemented using "torch_em") trainer = sam_training.SamTrainer( name=checkpoint_name, - save_root=save_root, + save_root=args.save_root, train_loader=train_loader, val_loader=val_loader, model=model, optimizer=optimizer, # currently we compute loss batch-wise, else we pass channelwise True - loss=torch_em.loss.DiceLoss(channelwise=False), - metric=torch_em.loss.DiceLoss(), + loss=DiceLoss(channelwise=False), + metric=DiceLoss(), device=device, lr_scheduler=scheduler, logger=sam_training.SamLogger, @@ -67,22 +55,42 @@ def finetune_histopatho(input_path, export_path, model_type="vit_h", iterations= n_sub_iteration=8, compile_model=False ) - trainer.fit(iterations) - if export_path is not None: + trainer.fit(iterations=args.iterations) + if args.export_path is not None: checkpoint_path = os.path.join( - "" if save_root is None else save_root, "checkpoints", checkpoint_name, "best.pt" + "" if args.save_root is None else args.save_root, "checkpoints", checkpoint_name, "best.pt" ) export_custom_sam_model( checkpoint_path=checkpoint_path, model_type=model_type, - save_path=export_path, + save_path=args.export_path, ) def main(): - input_path = "/scratch-grete/projects/nim00007/data/lizard" - export_path = "./sam-vith-histopatho-v1.pth" - finetune_histopatho(input_path, export_path) + parser = argparse.ArgumentParser(description="Finetune Segment Anything for the Histopathology datasets.") + parser.add_argument( + "--input_path", "-i", default="/scratch/usr/nimanwai/data/", + help="The filepath to all the respective hp datasets. If the data does not exist yet it will be downloaded" + ) + parser.add_argument( + "--model_type", "-m", default="vit_b", + help="The model type to use for fine-tuning. Either vit_h, vit_b or vit_l." + ) + parser.add_argument( + "--save_root", "-s", + help="Where to save the checkpoint and logs. By default they will be saved where this script is run from." + ) + parser.add_argument( + "--iterations", type=int, default=int(1e5), + help="For how many iterations should the model be trained? By default 100k." + ) + parser.add_argument( + "--export_path", "-e", + help="Where to export the finetuned model to. The exported model can be used in the annotation tools." + ) + args = parser.parse_args() + finetune_hp_generalist(args) if __name__ == "__main__": diff --git a/finetuning/generalists/training/histopathology/train_model.sbatch b/finetuning/generalists/training/histopathology/train_model.sbatch index 3a588465..2abe4042 100755 --- a/finetuning/generalists/training/histopathology/train_model.sbatch +++ b/finetuning/generalists/training/histopathology/train_model.sbatch @@ -1,11 +1,12 @@ #! 
/bin/bash
#SBATCH -c 16
#SBATCH --mem 128G
-#SBATCH -t 2800
+#SBATCH -t 4-00:00:00
#SBATCH -p grete:shared
#SBATCH -G A100:1
#SBATCH -A nim00007
#SBATCH --constraint=80gb
+#SBATCH --qos=96h

-source activate sam
+mamba activate sam
python train_histopathology_generalist.py $@

From 0fc1e3b7890e2c47ac7a54f64784f3a4acdbc348 Mon Sep 17 00:00:00 2001
From: anwai98
Date: Wed, 25 Oct 2023 16:50:56 +0200
Subject: [PATCH 2/6] Adapt lizard training

---
 .../histopathology/obtain_hp_datasets.py      | 21 +++++++++++++++++++
 micro_sam/training/sam_trainer.py             |  2 +-
 2 files changed, 22 insertions(+), 1 deletion(-)
 create mode 100644 finetuning/generalists/training/histopathology/obtain_hp_datasets.py

diff --git a/finetuning/generalists/training/histopathology/obtain_hp_datasets.py b/finetuning/generalists/training/histopathology/obtain_hp_datasets.py
new file mode 100644
index 00000000..49b7b994
--- /dev/null
+++ b/finetuning/generalists/training/histopathology/obtain_hp_datasets.py
@@ -0,0 +1,21 @@
+import os
+
+import torch_em
+from torch_em.data import datasets
+import torch.utils.data as data_util
+from torch_em.data.sampler import MinInstanceSampler
+
+
+# TODO use other datasets than lizard
+# need to add: pannuke, bcss, monuseg, monusac
+def get_generalist_hp_loaders(patch_shape, data_path):
+    label_transform = torch_em.transform.label.label_consecutive  # to ensure consecutive IDs
+    sampler = MinInstanceSampler(min_num_instances=5)
+    dataset = datasets.get_lizard_dataset(
+        path=os.path.join(data_path, "lizard"), download=False, patch_shape=patch_shape,
+        label_transform=label_transform, sampler=sampler
+    )
+    train_ds, val_ds = data_util.random_split(dataset, [0.9, 0.1])
+    train_loader = torch_em.get_data_loader(train_ds, batch_size=1)
+    val_loader = torch_em.get_data_loader(val_ds, batch_size=1)
+    return train_loader, val_loader
diff --git a/micro_sam/training/sam_trainer.py b/micro_sam/training/sam_trainer.py
index 177c6b02..9b5e6766 100644
--- a/micro_sam/training/sam_trainer.py
+++ b/micro_sam/training/sam_trainer.py
@@ -258,7 +258,7 @@ def _get_updated_points_per_mask_per_subiter(self, masks, sampled_binary_y, batc
     def _update_samples_for_gt_instances(self, y, n_samples):
         num_instances_gt = torch.amax(y, dim=(1, 2, 3))
-        num_instances_gt = num_instances_gt.numpy()
+        num_instances_gt = num_instances_gt.numpy().astype(int)  # FIXME: remove this as this is already taken care of by em generalist pr
         n_samples = min(num_instances_gt) if n_samples > min(num_instances_gt) else n_samples
         return n_samples

From 51f5898590de97ce41b94a977f0bb86b8befef53 Mon Sep 17 00:00:00 2001
From: anwai98
Date: Thu, 26 Oct 2023 17:02:06 +0200
Subject: [PATCH 3/6] Add histopathology datasets

---
 .../histopathology/obtain_hp_datasets.py | 155 ++++++++++++++++--
 1 file changed, 142 insertions(+), 13 deletions(-)

diff --git a/finetuning/generalists/training/histopathology/obtain_hp_datasets.py b/finetuning/generalists/training/histopathology/obtain_hp_datasets.py
index 49b7b994..75bec977 100644
--- a/finetuning/generalists/training/histopathology/obtain_hp_datasets.py
+++ b/finetuning/generalists/training/histopathology/obtain_hp_datasets.py
@@ -1,21 +1,150 @@
 import os
+import numpy as np
+from math import ceil, floor
+from typing import Optional, List
 
-import torch_em
-from torch_em.data import datasets
+import torch
 import torch.utils.data as data_util
-from torch_em.data.sampler import MinInstanceSampler
+
+import torch_em
+from torch_em.data import datasets, MinInstanceSampler, ConcatDataset
 
 
-# TODO use other datasets than 
lizard -# need to add: pannuke, bcss, monuseg, monusac -def get_generalist_hp_loaders(patch_shape, data_path): - label_transform = torch_em.transform.label.label_consecutive # to ensure consecutive IDs + +"""NOTE: test sets for in-domain histopathology evaluation + - monuseg test split + - monusac test split + - bcss test samples (split intrinsically - in the new PR) + +length of individual loaders: @all (3 channel input images) + - lizard: train - 359*2; val - 179 + - bcss: train - 54*2; val - 28 + - monuseg: train - 15*2; val - 7 + - monusac: train - 84*2; val - 41 + - pannuke: train - 647*2; val - 680 +""" + + +def _get_train_val_split(ds, val_fraction: float = 0.2): + generator = torch.Generator().manual_seed(42) + train_ds, val_ds = data_util.random_split(ds, [1 - val_fraction, val_fraction], generator=generator) + return train_ds, val_ds + + +class BCSSLabelTrafo: + def __init__(self, label_choices: Optional[List[int]] = None): + self.label_choices = label_choices + + def __call__(self, labels: np.ndarray) -> np.ndarray: + """Returns the transformed bcss data labels (use-case for SAM)""" + if self.label_choices is not None: + labels[~np.isin(labels, self.label_choices)] = 0 + segmentation = label_padding_trafo(labels) + else: + segmentation = label_padding_trafo(labels) + return segmentation + + +def raw_padding_trafo(raw, desired_shape=(3, 512, 512)): + assert raw.shape[0] == 3, "The input shape isn't channels first, expected: (3, H, W)" + tmp_ddim = (desired_shape[1] - raw.shape[1], desired_shape[2] - raw.shape[2]) + ddim = (tmp_ddim[0] / 2, tmp_ddim[1] / 2) + raw = np.pad( + raw, + pad_width=((0, 0), (ceil(ddim[0]), floor(ddim[0])), (ceil(ddim[1]), floor(ddim[1]))), + mode="reflect" + ) + assert raw.shape == desired_shape + return raw + + +def label_padding_trafo(labels, desired_shape=(512, 512)): + tmp_ddim = (desired_shape[0] - labels.shape[0], desired_shape[1] - labels.shape[1]) + ddim = (tmp_ddim[0] / 2, tmp_ddim[1] / 2) + labels = np.pad( + labels, + pad_width=((ceil(ddim[0]), floor(ddim[0])), (ceil(ddim[1]), floor(ddim[1]))), + mode="reflect" + ) + assert labels.shape == desired_shape + labels = label_consecutive_trafo(labels) + return labels + + +def label_consecutive_trafo(labels): + labels = labels.astype(int) + labels = torch_em.transform.label.label_consecutive(labels) # to ensure consecutive IDs + return labels + + +def get_concat_hp_datasets(path, patch_shape): + label_dtype = torch.int64 sampler = MinInstanceSampler(min_num_instances=5) - dataset = datasets.get_lizard_dataset( - path=os.path.join(data_path, "lizard"), download=False, patch_shape=patch_shape, - label_transform=label_transform, sampler=sampler + + # make lizard dataset splits into fractions + lizard_ds = datasets.get_lizard_dataset( + path=os.path.join(path, "lizard"), patch_shape=patch_shape, sampler=sampler, label_dtype=label_dtype, + raw_transform=raw_padding_trafo, label_transform=label_padding_trafo + ) + lizard_train_ds, lizard_val_ds = _get_train_val_split(ds=lizard_ds) + lizard_train_ds.ndim = 2 + lizard_val_ds.ndim = 2 + + # get bcss internal splits + bcss_train_ds = datasets.get_bcss_dataset( + path=os.path.join(path, "bcss"), patch_shape=patch_shape, split="train", sampler=MinInstanceSampler(), + label_transform=BCSSLabelTrafo(), label_dtype=label_dtype + ) + bcss_val_ds = datasets.get_bcss_dataset( + path=os.path.join(path, "bcss"), patch_shape=patch_shape, split="val", sampler=MinInstanceSampler(), + label_transform=BCSSLabelTrafo(), label_dtype=label_dtype + ) + + # make monuseg train 
dataset splits into fractions
+    monuseg_ds = datasets.get_monuseg_dataset(
+        path=os.path.join(path, "monuseg"), patch_shape=patch_shape, split="train", sampler=sampler,
+        label_transform=label_consecutive_trafo, ndim=2, label_dtype=label_dtype
+    )
+    monuseg_train_ds, monuseg_val_ds = _get_train_val_split(ds=monuseg_ds)
+
+    # make monusac train dataset splits into fractions
+    monusac_ds = datasets.get_monusac_dataset(
+        path=os.path.join(path, "monusac"), patch_shape=patch_shape, split="train", sampler=MinInstanceSampler(),
+        label_transform=label_consecutive_trafo, ndim=2, label_dtype=label_dtype
+    )
+    monusac_train_ds, monusac_val_ds = _get_train_val_split(ds=monusac_ds)
+
+    # out of the three provided folds (sets of data), we use two for training and one for validation
+    pannuke_train_ds = datasets.get_pannuke_dataset(
+        path=os.path.join(path, "pannuke"), patch_shape=(1, *patch_shape), sampler=sampler, folds=["fold_1", "fold_2"],
+        label_transform=label_padding_trafo, raw_transform=raw_padding_trafo, ndim=2, label_dtype=label_dtype
     )
-    train_ds, val_ds = data_util.random_split(dataset, [0.9, 0.1])
-    train_loader = torch_em.get_data_loader(train_ds, batch_size=1)
-    val_loader = torch_em.get_data_loader(val_ds, batch_size=1)
+    pannuke_val_ds = datasets.get_pannuke_dataset(
+        path=os.path.join(path, "pannuke"), patch_shape=(1, *patch_shape), sampler=sampler, folds=["fold_3"],
+        label_transform=label_padding_trafo, raw_transform=raw_padding_trafo, ndim=2, label_dtype=label_dtype
+    )
+
+    generalist_hp_train_dataset = ConcatDataset(
+        lizard_train_ds, bcss_train_ds, monuseg_train_ds, monusac_train_ds, pannuke_train_ds
+    )
+
+    generalist_hp_val_dataset = ConcatDataset(
+        lizard_val_ds, bcss_val_ds, monuseg_val_ds, monusac_val_ds, pannuke_val_ds
+    )
+
+    return generalist_hp_train_dataset, generalist_hp_val_dataset
+
+
 def get_generalist_hp_loaders(patch_shape, data_path):
+    """This returns the concatenated histopathology datasets implemented in `torch_em`:
+    https://github.com/constantinpape/torch-em/tree/main/torch_em/data/datasets
+    It will automatically download all the datasets.
+
+    NOTE: to remove a dataset or replace it with another one, you need to adapt the respective datasets
+    (for the train and val splits) in `get_concat_hp_datasets`. The labels have to be in a label mask
+    instance segmentation format, i.e. the tensors (inputs & masks) should have the same spatial shape,
+    with each object in the mask having its own ID.
+    IMPORTANT: the ID 0 is reserved for background, and the IDs must be consecutive.
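+
+    A minimal sketch of what this means in practice (the array values below are made up for illustration):
+    a valid label patch contains 0 for background and consecutive integer IDs for the objects, and
+    non-consecutive IDs can be fixed with `label_consecutive_trafo` (defined above in this file):
+
+        labels = np.array([[0, 3, 3], [0, 0, 7], [7, 7, 0]])  # hypothetical mask with gaps in the IDs
+        labels = label_consecutive_trafo(labels)  # relabels the objects to consecutive IDs; background stays 0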
+ """ + generalist_train_dataset, generalist_val_dataset = get_concat_hp_datasets(path=data_path, patch_shape=patch_shape) + train_loader = torch_em.get_data_loader(generalist_train_dataset, batch_size=2, shuffle=True, num_workers=16) + val_loader = torch_em.get_data_loader(generalist_val_dataset, batch_size=1, shuffle=True, num_workers=16) return train_loader, val_loader From 3fc4308b4cdcd0b1fc867aa0121190c52eaaff5d Mon Sep 17 00:00:00 2001 From: anwai98 Date: Fri, 27 Oct 2023 15:00:25 +0200 Subject: [PATCH 4/6] Update training submission --- .../histopathology/train_histopathology_generalist.py | 2 +- .../generalists/training/histopathology/train_model.sbatch | 6 ++++-- micro_sam/training/sam_trainer.py | 2 +- 3 files changed, 6 insertions(+), 4 deletions(-) diff --git a/finetuning/generalists/training/histopathology/train_histopathology_generalist.py b/finetuning/generalists/training/histopathology/train_histopathology_generalist.py index d0198982..893196c4 100644 --- a/finetuning/generalists/training/histopathology/train_histopathology_generalist.py +++ b/finetuning/generalists/training/histopathology/train_histopathology_generalist.py @@ -33,7 +33,7 @@ def finetune_hp_generalist(args): # this class creates all the training data for a batch (inputs, prompts and labels) convert_inputs = sam_training.ConvertToSamInputs() - checkpoint_name = "generalist-hp-sam" + checkpoint_name = f"generalist-hp-sam-{args.model_type}" # the trainer which performs training and validation (implemented using "torch_em") trainer = sam_training.SamTrainer( name=checkpoint_name, diff --git a/finetuning/generalists/training/histopathology/train_model.sbatch b/finetuning/generalists/training/histopathology/train_model.sbatch index 2abe4042..7834d1a2 100755 --- a/finetuning/generalists/training/histopathology/train_model.sbatch +++ b/finetuning/generalists/training/histopathology/train_model.sbatch @@ -1,12 +1,14 @@ #! 
/bin/bash #SBATCH -c 16 #SBATCH --mem 128G -#SBATCH -t 4-00:00:00 +#SBATCH -t 7-00:00:00 #SBATCH -p grete:shared #SBATCH -G A100:1 #SBATCH -A nim00007 #SBATCH --constraint=80gb -#SBATCH --qos=96h +#SBATCH --qos=7d +#SBATCH --job-name=sam_histopathology +source ~/.bashrc mamba activate sam python train_histopathology_generalist.py $@ diff --git a/micro_sam/training/sam_trainer.py b/micro_sam/training/sam_trainer.py index 9b5e6766..177c6b02 100644 --- a/micro_sam/training/sam_trainer.py +++ b/micro_sam/training/sam_trainer.py @@ -258,7 +258,7 @@ def _get_updated_points_per_mask_per_subiter(self, masks, sampled_binary_y, batc def _update_samples_for_gt_instances(self, y, n_samples): num_instances_gt = torch.amax(y, dim=(1, 2, 3)) - num_instances_gt = num_instances_gt.numpy().astype(int) # FIXME: remove this as this is already taken care of by em generalist pr + num_instances_gt = num_instances_gt.numpy() n_samples = min(num_instances_gt) if n_samples > min(num_instances_gt) else n_samples return n_samples From 3f132762202ccde0a5fa41ecc34a08fca45572ec Mon Sep 17 00:00:00 2001 From: anwai98 Date: Mon, 30 Oct 2023 13:20:08 +0100 Subject: [PATCH 5/6] Add connected components to bcss label trafo --- .../histopathology/obtain_hp_datasets.py | 24 ++++++++++++------- 1 file changed, 15 insertions(+), 9 deletions(-) diff --git a/finetuning/generalists/training/histopathology/obtain_hp_datasets.py b/finetuning/generalists/training/histopathology/obtain_hp_datasets.py index 75bec977..8869547d 100644 --- a/finetuning/generalists/training/histopathology/obtain_hp_datasets.py +++ b/finetuning/generalists/training/histopathology/obtain_hp_datasets.py @@ -3,6 +3,8 @@ from math import ceil, floor from typing import Optional, List +from skimage import measure + import torch import torch.utils.data as data_util @@ -16,11 +18,11 @@ - bcss test samples (split intrinsically - in the new PR) length of individual loaders: @all (3 channel input images) - - lizard: train - 359*2; val - 179 - - bcss: train - 54*2; val - 28 - - monuseg: train - 15*2; val - 7 - - monusac: train - 84*2; val - 41 - - pannuke: train - 647*2; val - 680 + - lizard: train - 718; val - 179 + - bcss: train - 108; val - 28 + - monuseg: train - 30; val - 7 + - monusac: train - 168; val - 41 + - pannuke: train - 1294; val - 680 """ @@ -31,16 +33,20 @@ def _get_train_val_split(ds, val_fraction: float = 0.2): class BCSSLabelTrafo: - def __init__(self, label_choices: Optional[List[int]] = None): + def __init__(self, label_choices: Optional[List[int]] = None, do_connected_components: bool = False): self.label_choices = label_choices + self.do_connected_components = do_connected_components def __call__(self, labels: np.ndarray) -> np.ndarray: """Returns the transformed bcss data labels (use-case for SAM)""" if self.label_choices is not None: labels[~np.isin(labels, self.label_choices)] = 0 - segmentation = label_padding_trafo(labels) + + if self.do_connected_components: + segmentation = measure.label(labels) else: segmentation = label_padding_trafo(labels) + return segmentation @@ -92,11 +98,11 @@ def get_concat_hp_datasets(path, patch_shape): # get bcss internal splits bcss_train_ds = datasets.get_bcss_dataset( path=os.path.join(path, "bcss"), patch_shape=patch_shape, split="train", sampler=MinInstanceSampler(), - label_transform=BCSSLabelTrafo(), label_dtype=label_dtype + label_transform=BCSSLabelTrafo(do_connected_components=True), label_dtype=label_dtype ) bcss_val_ds = datasets.get_bcss_dataset( path=os.path.join(path, "bcss"), 
patch_shape=patch_shape, split="val", sampler=MinInstanceSampler(), - label_transform=BCSSLabelTrafo(), label_dtype=label_dtype + label_transform=BCSSLabelTrafo(do_connected_components=True), label_dtype=label_dtype ) # make monuseg train dataset splits into fractions From 1071e0208ceda3356a9159d1d8cf2e62f658d402 Mon Sep 17 00:00:00 2001 From: anwai98 Date: Mon, 30 Oct 2023 15:03:21 +0100 Subject: [PATCH 6/6] Fix raw trafo for segmentation datasets --- .../generalists/training/histopathology/obtain_hp_datasets.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/finetuning/generalists/training/histopathology/obtain_hp_datasets.py b/finetuning/generalists/training/histopathology/obtain_hp_datasets.py index 8869547d..835417cb 100644 --- a/finetuning/generalists/training/histopathology/obtain_hp_datasets.py +++ b/finetuning/generalists/training/histopathology/obtain_hp_datasets.py @@ -9,6 +9,7 @@ import torch.utils.data as data_util import torch_em +from torch_em.transform.raw import standardize from torch_em.data import datasets, MinInstanceSampler, ConcatDataset @@ -45,13 +46,14 @@ def __call__(self, labels: np.ndarray) -> np.ndarray: if self.do_connected_components: segmentation = measure.label(labels) else: - segmentation = label_padding_trafo(labels) + segmentation = label_consecutive_trafo(labels) return segmentation def raw_padding_trafo(raw, desired_shape=(3, 512, 512)): assert raw.shape[0] == 3, "The input shape isn't channels first, expected: (3, H, W)" + raw = standardize(raw) tmp_ddim = (desired_shape[1] - raw.shape[1], desired_shape[2] - raw.shape[2]) ddim = (tmp_ddim[0] / 2, tmp_ddim[1] / 2) raw = np.pad(