Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Histopathology Generalist Training #248

Merged
merged 6 commits into from
Oct 31, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
158 changes: 158 additions & 0 deletions finetuning/generalists/training/histopathology/obtain_hp_datasets.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
import os
import numpy as np
from math import ceil, floor
from typing import Optional, List

from skimage import measure

import torch
import torch.utils.data as data_util

import torch_em
from torch_em.transform.raw import standardize
from torch_em.data import datasets, MinInstanceSampler, ConcatDataset


"""NOTE: test sets for in-domain histopathology evaluation
- monuseg test split
- monusac test split
- bcss test samples (split intrinsically - in the new PR)

length of individual loaders: @all (3 channel input images)
- lizard: train - 718; val - 179
- bcss: train - 108; val - 28
- monuseg: train - 30; val - 7
- monusac: train - 168; val - 41
- pannuke: train - 1294; val - 680
"""


def _get_train_val_split(ds, val_fraction: float = 0.2):
generator = torch.Generator().manual_seed(42)
train_ds, val_ds = data_util.random_split(ds, [1 - val_fraction, val_fraction], generator=generator)
return train_ds, val_ds


class BCSSLabelTrafo:
def __init__(self, label_choices: Optional[List[int]] = None, do_connected_components: bool = False):
self.label_choices = label_choices
self.do_connected_components = do_connected_components

def __call__(self, labels: np.ndarray) -> np.ndarray:
"""Returns the transformed bcss data labels (use-case for SAM)"""
anwai98 marked this conversation as resolved.
Show resolved Hide resolved
if self.label_choices is not None:
labels[~np.isin(labels, self.label_choices)] = 0

if self.do_connected_components:
segmentation = measure.label(labels)
else:
segmentation = label_consecutive_trafo(labels)

return segmentation


def raw_padding_trafo(raw, desired_shape=(3, 512, 512)):
assert raw.shape[0] == 3, "The input shape isn't channels first, expected: (3, H, W)"
raw = standardize(raw)
tmp_ddim = (desired_shape[1] - raw.shape[1], desired_shape[2] - raw.shape[2])
ddim = (tmp_ddim[0] / 2, tmp_ddim[1] / 2)
raw = np.pad(
raw,
pad_width=((0, 0), (ceil(ddim[0]), floor(ddim[0])), (ceil(ddim[1]), floor(ddim[1]))),
mode="reflect"
)
assert raw.shape == desired_shape
return raw


def label_padding_trafo(labels, desired_shape=(512, 512)):
tmp_ddim = (desired_shape[0] - labels.shape[0], desired_shape[1] - labels.shape[1])
ddim = (tmp_ddim[0] / 2, tmp_ddim[1] / 2)
labels = np.pad(
labels,
pad_width=((ceil(ddim[0]), floor(ddim[0])), (ceil(ddim[1]), floor(ddim[1]))),
mode="reflect"
)
assert labels.shape == desired_shape
labels = label_consecutive_trafo(labels)
return labels


def label_consecutive_trafo(labels):
labels = labels.astype(int)
labels = torch_em.transform.label.label_consecutive(labels) # to ensure consecutive IDs
return labels


def get_concat_hp_datasets(path, patch_shape):
label_dtype = torch.int64
sampler = MinInstanceSampler(min_num_instances=5)

# make lizard dataset splits into fractions
lizard_ds = datasets.get_lizard_dataset(
path=os.path.join(path, "lizard"), patch_shape=patch_shape, sampler=sampler, label_dtype=label_dtype,
raw_transform=raw_padding_trafo, label_transform=label_padding_trafo
anwai98 marked this conversation as resolved.
Show resolved Hide resolved
)
lizard_train_ds, lizard_val_ds = _get_train_val_split(ds=lizard_ds)
lizard_train_ds.ndim = 2
lizard_val_ds.ndim = 2

# get bcss internal splits
bcss_train_ds = datasets.get_bcss_dataset(
path=os.path.join(path, "bcss"), patch_shape=patch_shape, split="train", sampler=MinInstanceSampler(),
label_transform=BCSSLabelTrafo(do_connected_components=True), label_dtype=label_dtype
)
bcss_val_ds = datasets.get_bcss_dataset(
path=os.path.join(path, "bcss"), patch_shape=patch_shape, split="val", sampler=MinInstanceSampler(),
label_transform=BCSSLabelTrafo(do_connected_components=True), label_dtype=label_dtype
)

# make monuseg train dataset splits into fractions
monuseg_ds = datasets.get_monuseg_dataset(
path=os.path.join(path, "monuseg"), patch_shape=patch_shape, split="train", sampler=sampler,
label_transform=label_consecutive_trafo, ndim=2, label_dtype=label_dtype
)
monuseg_train_ds, monuseg_val_ds = _get_train_val_split(ds=monuseg_ds)

# make monusac train dataset splits into fractions
monusac_ds = datasets.get_monusac_dataset(
path=os.path.join(path, "monusac"), patch_shape=patch_shape, split="train", sampler=MinInstanceSampler(),
label_transform=label_consecutive_trafo, ndim=2, label_dtype=label_dtype
)
monusac_train_ds, monusac_val_ds = _get_train_val_split(ds=monusac_ds)

# out of three folds (sets of data) of provided data, we use two for training and 1 for validation
pannuke_train_ds = datasets.get_pannuke_dataset(
path=os.path.join(path, "pannuke"), patch_shape=(1, *patch_shape), sampler=sampler, folds=["fold_1", "fold_2"],
label_transform=label_padding_trafo, raw_transform=raw_padding_trafo, ndim=2, label_dtype=label_dtype
anwai98 marked this conversation as resolved.
Show resolved Hide resolved
)
pannuke_val_ds = datasets.get_pannuke_dataset(
path=os.path.join(path, "pannuke"), patch_shape=(1, *patch_shape), sampler=sampler, folds=["fold_3"],
label_transform=label_padding_trafo, raw_transform=raw_padding_trafo, ndim=2, label_dtype=label_dtype
)

generalist_hp_train_dataset = ConcatDataset(
lizard_train_ds, bcss_train_ds, monuseg_train_ds, monusac_train_ds, pannuke_train_ds
)

generalist_hp_val_dataset = ConcatDataset(
lizard_val_ds, bcss_val_ds, monuseg_val_ds, monusac_val_ds, pannuke_val_ds
)

return generalist_hp_train_dataset, generalist_hp_val_dataset


def get_generalist_hp_loaders(patch_shape, data_path):
"""This returns the concatenated histopathology datasets implemented in `torch_em`:
https://github.com/constantinpape/torch-em/tree/main/torch_em/data/datasets
It will automatically download all the datasets

NOTE: to remove / replace the datasets with another dataset, you need to add the datasets (for train and val splits)
in `get_concat_lm_dataset`. The labels have to be in a label mask instance segmentation format.
i.e. the tensors (inputs & masks) should be of same spatial shape, with each object in the mask having it's own ID.
IMPORTANT: the ID 0 is reserved for background, and the IDs must be consecutive.
"""
generalist_train_dataset, generalist_val_dataset = get_concat_hp_datasets(path=data_path, patch_shape=patch_shape)
train_loader = torch_em.get_data_loader(generalist_train_dataset, batch_size=2, shuffle=True, num_workers=16)
val_loader = torch_em.get_data_loader(generalist_val_dataset, batch_size=1, shuffle=True, num_workers=16)
return train_loader, val_loader
Original file line number Diff line number Diff line change
@@ -1,62 +1,50 @@
import os
import argparse

import micro_sam.training as sam_training
import torch
import torch_em
from torch_em.loss import DiceLoss

import torch.utils.data as data_util
from torch_em.data.datasets import get_lizard_dataset
from torch_em.data.sampler import MinInstanceSampler
import micro_sam.training as sam_training
from micro_sam.util import export_custom_sam_model

from obtain_hp_datasets import get_generalist_hp_loaders

# TODO use other datasets than lizard
def get_dataloaders(patch_shape, data_path):
label_transform = torch_em.transform.label.label_consecutive # to ensure consecutive IDs
sampler = MinInstanceSampler(min_num_instances=5)
dataset = get_lizard_dataset(
path=data_path, download=True, patch_shape=patch_shape, label_transform=label_transform,
sampler=sampler,
)
train_ds, val_ds = data_util.random_split(dataset, [0.9, 0.1])
train_loader = torch_em.get_data_loader(train_ds, batch_size=1)
val_loader = torch_em.get_data_loader(val_ds, batch_size=1)
return train_loader, val_loader


def finetune_histopatho(input_path, export_path, model_type="vit_h", iterations=int(2e4), save_root=None):
"""Example code for finetuning SAM on LiveCELL"""
def finetune_hp_generalist(args):
"""Example code for finetuning SAM on multiple histopathology datasets"""
# override this (below) if you have some more complex set-up and need to specify the exact gpu
device = "cuda" if torch.cuda.is_available() else "cpu"

# training settings:
model_type = args.model_type
checkpoint_path = None # override this to start training from a custom checkpoint
device = "cuda" # override this if you have some more complex set-up and need to specify the exact gpu
patch_shape = (512, 512) # the patch shape for training
n_objects_per_batch = 50 # this is the number of objects per batch that will be sampled

train_loader, val_loader = get_dataloaders(patch_shape=patch_shape, data_path=input_path)
n_objects_per_batch = 25 # this is the number of objects per batch that will be sampled
freeze_parts = None # override this to freeze one or more of these backbones

# get the trainable segment anything model
model = sam_training.get_trainable_sam_model(model_type, checkpoint_path, device=device)
model = sam_training.get_trainable_sam_model(model_type, checkpoint_path, freeze_parts)

# all the stuff we need for training
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.9, patience=10, verbose=True)
train_loader, val_loader = get_generalist_hp_loaders(patch_shape=patch_shape, data_path=args.input_path)

# this class creates all the training data for a batch (inputs, prompts and labels)
convert_inputs = sam_training.ConvertToSamInputs()

checkpoint_name = "sam-histopatho-v1"
checkpoint_name = f"generalist-hp-sam-{args.model_type}"
# the trainer which performs training and validation (implemented using "torch_em")
trainer = sam_training.SamTrainer(
name=checkpoint_name,
save_root=save_root,
save_root=args.save_root,
train_loader=train_loader,
val_loader=val_loader,
model=model,
optimizer=optimizer,
# currently we compute loss batch-wise, else we pass channelwise True
loss=torch_em.loss.DiceLoss(channelwise=False),
metric=torch_em.loss.DiceLoss(),
loss=DiceLoss(channelwise=False),
metric=DiceLoss(),
device=device,
lr_scheduler=scheduler,
logger=sam_training.SamLogger,
Expand All @@ -67,22 +55,42 @@ def finetune_histopatho(input_path, export_path, model_type="vit_h", iterations=
n_sub_iteration=8,
compile_model=False
)
trainer.fit(iterations)
if export_path is not None:
trainer.fit(iterations=args.iterations)
if args.export_path is not None:
checkpoint_path = os.path.join(
"" if save_root is None else save_root, "checkpoints", checkpoint_name, "best.pt"
"" if args.save_root is None else args.save_root, "checkpoints", checkpoint_name, "best.pt"
)
export_custom_sam_model(
checkpoint_path=checkpoint_path,
model_type=model_type,
save_path=export_path,
save_path=args.export_path,
)


def main():
input_path = "/scratch-grete/projects/nim00007/data/lizard"
export_path = "./sam-vith-histopatho-v1.pth"
finetune_histopatho(input_path, export_path)
parser = argparse.ArgumentParser(description="Finetune Segment Anything for the Histopathology datasets.")
parser.add_argument(
"--input_path", "-i", default="/scratch/usr/nimanwai/data/",
help="The filepath to all the respective hp datasets. If the data does not exist yet it will be downloaded"
)
parser.add_argument(
"--model_type", "-m", default="vit_b",
help="The model type to use for fine-tuning. Either vit_h, vit_b or vit_l."
)
parser.add_argument(
"--save_root", "-s",
help="Where to save the checkpoint and logs. By default they will be saved where this script is run from."
)
parser.add_argument(
"--iterations", type=int, default=int(1e5),
help="For how many iterations should the model be trained? By default 100k."
)
parser.add_argument(
"--export_path", "-e",
help="Where to export the finetuned model to. The exported model can be used in the annotation tools."
)
args = parser.parse_args()
finetune_hp_generalist(args)


if __name__ == "__main__":
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
#! /bin/bash
#SBATCH -c 16
#SBATCH --mem 128G
#SBATCH -t 2800
#SBATCH -t 7-00:00:00
#SBATCH -p grete:shared
#SBATCH -G A100:1
#SBATCH -A nim00007
#SBATCH --constraint=80gb
#SBATCH --qos=7d
#SBATCH --job-name=sam_histopathology

source activate sam
source ~/.bashrc
mamba activate sam
python train_histopathology_generalist.py $@