
Commit

Merge pull request #279 from computational-cell-analytics/master
Keep master and dev in sync
constantinpape authored Nov 18, 2023
2 parents c236df3 + 2a55ec5 commit d5dc804
Showing 6 changed files with 213 additions and 44 deletions.
README.md (2 changes: 1 addition & 1 deletion)
@@ -35,7 +35,7 @@ Please check out [the documentation](https://computational-cell-analytics.github

We welcome new contributions!

If you are interested in contributing to micro-sam, please see the [contributing guide](docs/contributing.md) and [developer documentation](docs/development.md). The first step is to [discuss your idea in a new issue](https://github.com/computational-cell-analytics/micro-sam/issues/new) with the current developers.
If you are interested in contributing to micro-sam, please see the [contributing guide](doc/contributing.md) and [developer documentation](doc/development.md). The first step is to [discuss your idea in a new issue](https://github.com/computational-cell-analytics/micro-sam/issues/new) with the current developers.

## Citation

@@ -7,6 +7,7 @@
from skimage.segmentation import watershed

from torch_em import get_data_loader
from torch_em.transform.raw import standardize
from torch_em.transform.label import label_consecutive
from torch_em.data import ConcatDataset, MinInstanceSampler, datasets

@@ -20,8 +21,8 @@ def axondeepseg_label_trafo(labels):
return seg


def raw_trafo_for_padding(raw):
desired_shape = (512, 512)
def raw_trafo_for_padding(raw, desired_shape=(512, 512)):
raw = standardize(raw)
tmp_ddim = (desired_shape[0] - raw.shape[0], desired_shape[1] - raw.shape[1])
ddim = (tmp_ddim[0] / 2, tmp_ddim[1] / 2)
raw = np.pad(raw,
@@ -31,8 +32,7 @@ def raw_trafo_for_padding(raw):
return raw


def label_trafo_for_padding(labels):
desired_shape = (512, 512)
def label_trafo_for_padding(labels, desired_shape=(512, 512)):
labels = label(labels)
labels = label_consecutive(labels)
tmp_ddim = (desired_shape[0] - labels.shape[0], desired_shape[1] - labels.shape[1])
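As an aside on the two padding transforms changed above: the difference between the current and the desired shape is split into a ceil/floor pair, so odd differences still add up to the full padding amount. The following is a minimal, stand-alone sketch of that arithmetic; it is not part of the diff, and the array and shapes are invented for illustration:

```python
from math import ceil, floor

import numpy as np

# Toy input: a 509 x 510 crop that should be reflect-padded to 512 x 512.
raw = np.random.rand(509, 510)
desired_shape = (512, 512)

# Half of the missing extent per axis; the ceil/floor split handles odd differences.
ddim = ((desired_shape[0] - raw.shape[0]) / 2, (desired_shape[1] - raw.shape[1]) / 2)
padded = np.pad(
    raw,
    pad_width=((ceil(ddim[0]), floor(ddim[0])), (ceil(ddim[1]), floor(ddim[1]))),
    mode="reflect",
)
assert padded.shape == desired_shape  # (512, 512)
```

With mode="reflect" the border mirrors the image content instead of adding constant-valued edges to the training crops.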
finetuning/generalists/training/histopathology/obtain_hp_datasets.py (158 changes: 158 additions & 0 deletions)
@@ -0,0 +1,158 @@
import os
import numpy as np
from math import ceil, floor
from typing import Optional, List

from skimage import measure

import torch
import torch.utils.data as data_util

import torch_em
from torch_em.transform.raw import standardize
from torch_em.data import datasets, MinInstanceSampler, ConcatDataset


"""NOTE: test sets for in-domain histopathology evaluation
- monuseg test split
- monusac test split
- bcss test samples (split intrinsically - in the new PR)
Length of the individual loaders (all datasets have 3-channel input images):
- lizard: train - 718; val - 179
- bcss: train - 108; val - 28
- monuseg: train - 30; val - 7
- monusac: train - 168; val - 41
- pannuke: train - 1294; val - 680
"""


def _get_train_val_split(ds, val_fraction: float = 0.2):
generator = torch.Generator().manual_seed(42)
train_ds, val_ds = data_util.random_split(ds, [1 - val_fraction, val_fraction], generator=generator)
return train_ds, val_ds


class BCSSLabelTrafo:
def __init__(self, label_choices: Optional[List[int]] = None, do_connected_components: bool = False):
self.label_choices = label_choices
self.do_connected_components = do_connected_components

def __call__(self, labels: np.ndarray) -> np.ndarray:
"""Returns the transformed bcss data labels (use-case for SAM)"""
if self.label_choices is not None:
labels[~np.isin(labels, self.label_choices)] = 0

if self.do_connected_components:
segmentation = measure.label(labels)
else:
segmentation = label_consecutive_trafo(labels)

return segmentation


def raw_padding_trafo(raw, desired_shape=(3, 512, 512)):
assert raw.shape[0] == 3, "The input shape isn't channels first, expected: (3, H, W)"
raw = standardize(raw)
tmp_ddim = (desired_shape[1] - raw.shape[1], desired_shape[2] - raw.shape[2])
ddim = (tmp_ddim[0] / 2, tmp_ddim[1] / 2)
raw = np.pad(
raw,
pad_width=((0, 0), (ceil(ddim[0]), floor(ddim[0])), (ceil(ddim[1]), floor(ddim[1]))),
mode="reflect"
)
assert raw.shape == desired_shape
return raw


def label_padding_trafo(labels, desired_shape=(512, 512)):
tmp_ddim = (desired_shape[0] - labels.shape[0], desired_shape[1] - labels.shape[1])
ddim = (tmp_ddim[0] / 2, tmp_ddim[1] / 2)
labels = np.pad(
labels,
pad_width=((ceil(ddim[0]), floor(ddim[0])), (ceil(ddim[1]), floor(ddim[1]))),
mode="reflect"
)
assert labels.shape == desired_shape
labels = label_consecutive_trafo(labels)
return labels


def label_consecutive_trafo(labels):
labels = labels.astype(int)
labels = torch_em.transform.label.label_consecutive(labels) # to ensure consecutive IDs
return labels


def get_concat_hp_datasets(path, patch_shape):
label_dtype = torch.int64
sampler = MinInstanceSampler(min_num_instances=5)

# split the lizard dataset into train and val fractions
lizard_ds = datasets.get_lizard_dataset(
path=os.path.join(path, "lizard"), patch_shape=patch_shape, sampler=sampler, label_dtype=label_dtype,
raw_transform=raw_padding_trafo, label_transform=label_padding_trafo
)
lizard_train_ds, lizard_val_ds = _get_train_val_split(ds=lizard_ds)
lizard_train_ds.ndim = 2
lizard_val_ds.ndim = 2

# get bcss internal splits
bcss_train_ds = datasets.get_bcss_dataset(
path=os.path.join(path, "bcss"), patch_shape=patch_shape, split="train", sampler=MinInstanceSampler(),
label_transform=BCSSLabelTrafo(do_connected_components=True), label_dtype=label_dtype
)
bcss_val_ds = datasets.get_bcss_dataset(
path=os.path.join(path, "bcss"), patch_shape=patch_shape, split="val", sampler=MinInstanceSampler(),
label_transform=BCSSLabelTrafo(do_connected_components=True), label_dtype=label_dtype
)

# split the monuseg train dataset into train and val fractions
monuseg_ds = datasets.get_monuseg_dataset(
path=os.path.join(path, "monuseg"), patch_shape=patch_shape, split="train", sampler=sampler,
label_transform=label_consecutive_trafo, ndim=2, label_dtype=label_dtype
)
monuseg_train_ds, monuseg_val_ds = _get_train_val_split(ds=monuseg_ds)

# split the monusac train dataset into train and val fractions
monusac_ds = datasets.get_monusac_dataset(
path=os.path.join(path, "monusac"), patch_shape=patch_shape, split="train", sampler=MinInstanceSampler(),
label_transform=label_consecutive_trafo, ndim=2, label_dtype=label_dtype
)
monusac_train_ds, monusac_val_ds = _get_train_val_split(ds=monusac_ds)

# out of the three provided folds, we use two for training and one for validation
pannuke_train_ds = datasets.get_pannuke_dataset(
path=os.path.join(path, "pannuke"), patch_shape=(1, *patch_shape), sampler=sampler, folds=["fold_1", "fold_2"],
label_transform=label_padding_trafo, raw_transform=raw_padding_trafo, ndim=2, label_dtype=label_dtype
)
pannuke_val_ds = datasets.get_pannuke_dataset(
path=os.path.join(path, "pannuke"), patch_shape=(1, *patch_shape), sampler=sampler, folds=["fold_3"],
label_transform=label_padding_trafo, raw_transform=raw_padding_trafo, ndim=2, label_dtype=label_dtype
)

generalist_hp_train_dataset = ConcatDataset(
lizard_train_ds, bcss_train_ds, monuseg_train_ds, monusac_train_ds, pannuke_train_ds
)

generalist_hp_val_dataset = ConcatDataset(
lizard_val_ds, bcss_val_ds, monuseg_val_ds, monusac_val_ds, pannuke_val_ds
)

return generalist_hp_train_dataset, generalist_hp_val_dataset


def get_generalist_hp_loaders(patch_shape, data_path):
"""This returns the concatenated histopathology datasets implemented in `torch_em`:
https://github.com/constantinpape/torch-em/tree/main/torch_em/data/datasets
It will automatically download all the datasets.
NOTE: to remove or replace a dataset with another one, you need to add the new dataset (for the train and val splits)
in `get_concat_hp_datasets`. The labels have to be in a label mask instance segmentation format,
i.e. the tensors (inputs & masks) should have the same spatial shape, with each object in the mask having its own ID.
IMPORTANT: the ID 0 is reserved for background, and the IDs must be consecutive.
"""
generalist_train_dataset, generalist_val_dataset = get_concat_hp_datasets(path=data_path, patch_shape=patch_shape)
train_loader = torch_em.get_data_loader(generalist_train_dataset, batch_size=2, shuffle=True, num_workers=16)
val_loader = torch_em.get_data_loader(generalist_val_dataset, batch_size=1, shuffle=True, num_workers=16)
return train_loader, val_loader
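To illustrate the label format the docstring above asks for (instance masks with the same spatial shape as the inputs, background fixed to ID 0, object IDs consecutive), here is a small stand-alone sketch. It is not part of the diff: the toy mask is invented, and `skimage.measure.label` plus a NumPy relabelling stand in for `torch_em.transform.label.label_consecutive`:

```python
import numpy as np
from skimage import measure

# A toy instance mask as it might come out of a dataset: 0 is background,
# but the object IDs (3 and 7) are not consecutive.
instances = np.array([
    [0, 3, 3, 0, 0],
    [0, 3, 0, 0, 7],
    [0, 0, 0, 0, 7],
])

# If only a foreground mask were available, connected components (as in
# BCSSLabelTrafo(do_connected_components=True)) would create the instances first.
instances_from_foreground = measure.label(instances > 0)

# Consecutive relabelling maps the IDs onto 0..N without gaps, which is the
# format the loaders above expect; 0 stays reserved for background.
ids = np.unique(instances)                   # [0, 3, 7]
consecutive = np.searchsorted(ids, instances)
assert set(np.unique(consecutive)) == {0, 1, 2}
```

On the splitting side, the fixed-seed `torch.Generator` in `_get_train_val_split` makes the fractional `random_split` reproducible, so every run sees the same train/val partition.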
@@ -1,62 +1,50 @@
import os
import argparse

import micro_sam.training as sam_training
import torch
import torch_em
from torch_em.loss import DiceLoss

import torch.utils.data as data_util
from torch_em.data.datasets import get_lizard_dataset
from torch_em.data.sampler import MinInstanceSampler
import micro_sam.training as sam_training
from micro_sam.util import export_custom_sam_model

from obtain_hp_datasets import get_generalist_hp_loaders

# TODO use other datasets than lizard
def get_dataloaders(patch_shape, data_path):
label_transform = torch_em.transform.label.label_consecutive # to ensure consecutive IDs
sampler = MinInstanceSampler(min_num_instances=5)
dataset = get_lizard_dataset(
path=data_path, download=True, patch_shape=patch_shape, label_transform=label_transform,
sampler=sampler,
)
train_ds, val_ds = data_util.random_split(dataset, [0.9, 0.1])
train_loader = torch_em.get_data_loader(train_ds, batch_size=1)
val_loader = torch_em.get_data_loader(val_ds, batch_size=1)
return train_loader, val_loader


def finetune_histopatho(input_path, export_path, model_type="vit_h", iterations=int(2e4), save_root=None):
"""Example code for finetuning SAM on LiveCELL"""
def finetune_hp_generalist(args):
"""Example code for finetuning SAM on multiple histopathology datasets"""
# override this (below) if you have some more complex set-up and need to specify the exact gpu
device = "cuda" if torch.cuda.is_available() else "cpu"

# training settings:
model_type = args.model_type
checkpoint_path = None # override this to start training from a custom checkpoint
device = "cuda" # override this if you have some more complex set-up and need to specify the exact gpu
patch_shape = (512, 512) # the patch shape for training
n_objects_per_batch = 50 # this is the number of objects per batch that will be sampled

train_loader, val_loader = get_dataloaders(patch_shape=patch_shape, data_path=input_path)
n_objects_per_batch = 25 # this is the number of objects per batch that will be sampled
freeze_parts = None # override this to freeze one or more of these backbones

# get the trainable segment anything model
model = sam_training.get_trainable_sam_model(model_type, checkpoint_path, device=device)
model = sam_training.get_trainable_sam_model(model_type, checkpoint_path, freeze_parts)

# all the stuff we need for training
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.9, patience=10, verbose=True)
train_loader, val_loader = get_generalist_hp_loaders(patch_shape=patch_shape, data_path=args.input_path)

# this class creates all the training data for a batch (inputs, prompts and labels)
convert_inputs = sam_training.ConvertToSamInputs()

checkpoint_name = "sam-histopatho-v1"
checkpoint_name = f"generalist-hp-sam-{args.model_type}"
# the trainer which performs training and validation (implemented using "torch_em")
trainer = sam_training.SamTrainer(
name=checkpoint_name,
save_root=save_root,
save_root=args.save_root,
train_loader=train_loader,
val_loader=val_loader,
model=model,
optimizer=optimizer,
# currently we compute the loss batch-wise; to compute it per channel, pass channelwise=True instead
loss=torch_em.loss.DiceLoss(channelwise=False),
metric=torch_em.loss.DiceLoss(),
loss=DiceLoss(channelwise=False),
metric=DiceLoss(),
device=device,
lr_scheduler=scheduler,
logger=sam_training.SamLogger,
@@ -67,22 +55,42 @@ def finetune_histopatho(input_path, export_path, model_type=
n_sub_iteration=8,
compile_model=False
)
trainer.fit(iterations)
if export_path is not None:
trainer.fit(iterations=args.iterations)
if args.export_path is not None:
checkpoint_path = os.path.join(
"" if save_root is None else save_root, "checkpoints", checkpoint_name, "best.pt"
"" if args.save_root is None else args.save_root, "checkpoints", checkpoint_name, "best.pt"
)
export_custom_sam_model(
checkpoint_path=checkpoint_path,
model_type=model_type,
save_path=export_path,
save_path=args.export_path,
)


def main():
input_path = "/scratch-grete/projects/nim00007/data/lizard"
export_path = "./sam-vith-histopatho-v1.pth"
finetune_histopatho(input_path, export_path)
parser = argparse.ArgumentParser(description="Finetune Segment Anything for the Histopathology datasets.")
parser.add_argument(
"--input_path", "-i", default="/scratch/usr/nimanwai/data/",
help="The filepath to all the respective hp datasets. If the data does not exist yet it will be downloaded"
)
parser.add_argument(
"--model_type", "-m", default="vit_b",
help="The model type to use for fine-tuning. Either vit_h, vit_b or vit_l."
)
parser.add_argument(
"--save_root", "-s",
help="Where to save the checkpoint and logs. By default they will be saved where this script is run from."
)
parser.add_argument(
"--iterations", type=int, default=int(1e5),
help="For how many iterations should the model be trained? By default 100k."
)
parser.add_argument(
"--export_path", "-e",
help="Where to export the finetuned model to. The exported model can be used in the annotation tools."
)
args = parser.parse_args()
finetune_hp_generalist(args)


if __name__ == "__main__":
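For context on the training settings introduced above (Adam at a learning rate of 1e-5 plus `ReduceLROnPlateau` with factor 0.9 and patience 10), the sketch below shows how such a plateau scheduler reacts to a stagnating validation metric. It is not part of the diff; the dummy model and the loss values are invented:

```python
import torch

# A throwaway model so the optimizer has parameters to manage.
model = torch.nn.Linear(4, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode="min", factor=0.9, patience=10
)

# Fake validation losses that stop improving: once `patience` epochs pass
# without a new minimum, the learning rate is multiplied by `factor`.
for val_loss in [0.5] * 15:
    scheduler.step(val_loss)

print(optimizer.param_groups[0]["lr"])  # reduced once: 0.9 * 1e-5
```

In the trainer configured above, the scheduler is presumably stepped with the validation metric by torch_em's training loop, so the learning rate only decays when validation stops improving.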
@@ -1,11 +1,14 @@
#! /bin/bash
#SBATCH -c 16
#SBATCH --mem 128G
#SBATCH -t 2800
#SBATCH -t 7-00:00:00
#SBATCH -p grete:shared
#SBATCH -G A100:1
#SBATCH -A nim00007
#SBATCH --constraint=80gb
#SBATCH --qos=7d
#SBATCH --job-name=sam_histopathology

source activate sam
source ~/.bashrc
mamba activate sam
python train_histopathology_generalist.py $@
micro_sam/evaluation/livecell.py (2 changes: 1 addition & 1 deletion)
@@ -226,7 +226,7 @@ def run_livecell_inference() -> None:
help="Pass the checkpoint-specific model name being used for inference.")

# the experiment type:
# - default settings (p1-n0, p2-n4, box)
# - default settings (p1-n0, p2-n4, p4-n8, box)
# - full experiment (ranges: p:1-16, n:0-16)
# - automatic mask generation (auto)
# if none of the two are active then the prompt setting arguments will be parsed
