Skip to content

Commit

Permalink
Add AIS benchmarking scripts (#657)
Browse files Browse the repository at this point in the history
Add training scripts for unet and unetr
  • Loading branch information
anwai98 authored Jul 23, 2024
1 parent ed11cc2 commit 83c8313
Show file tree
Hide file tree
Showing 8 changed files with 577 additions and 5 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -177,3 +177,6 @@ cython_debug/
# Torch-em stuff
checkpoints/
logs/

# "gpu_jobs" folder where slurm batch submission scripts are saved
gpu_jobs/
2 changes: 1 addition & 1 deletion micro_sam/training/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,5 @@
from .util import ConvertToSamInputs, get_trainable_sam_model, identity
from .joint_sam_trainer import JointSamTrainer, JointSamLogger
from .simple_sam_trainer import SimpleSamTrainer, MedSAMTrainer
from .semantic_sam_trainer import SemanticSamTrainer
from .semantic_sam_trainer import SemanticSamTrainer, SemanticMapsSamTrainer
from .training import train_sam, train_sam_for_configuration, default_sam_loader, default_sam_dataset, CONFIGURATIONS
24 changes: 20 additions & 4 deletions micro_sam/training/semantic_sam_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,13 @@ def __init__(
):
assert num_classes > 1

loss = CustomDiceLoss(num_classes=num_classes)
metric = CustomDiceLoss(num_classes=num_classes)
super().__init__(loss=loss, metric=metric, **kwargs)
if "loss" not in kwargs:
kwargs["loss"] = CustomDiceLoss(num_classes=num_classes)

if "metric" not in kwargs:
kwargs["metric"] = CustomDiceLoss(num_classes=num_classes)

super().__init__(**kwargs)

self.convert_inputs = convert_inputs
self.num_classes = num_classes
Expand Down Expand Up @@ -141,6 +145,18 @@ def _validate_impl(self, forward_context):
print(f"The Average Validation Metric Score for the Current Epoch is {dice_metric}")

if self.logger is not None:
self.logger.log_validation(self._iteration, metric_val, loss_val, x, y, torch.softmax(masks, dim=1))
self.logger.log_validation(
self._iteration, metric_val, loss_val, x, y, torch.softmax(masks, dim=1)
)

return metric_val


class SemanticMapsSamTrainer(SemanticSamTrainer):
    """Semantic SAM trainer with a simplified loss computation.

    Overrides `_compute_loss` so that the targets are only moved to the
    trainer's device and then handed to `self.loss` together with the
    predictions.
    """

    def _compute_loss(self, y, masks):
        """Compute the loss between the predictions and the targets.

        Args:
            y: The target tensor. It is moved to `self.device` before use.
            masks: The predictions of the model.

        Returns:
            The value returned by `self.loss`.
        """
        target = y.to(self.device, non_blocking=True)

        # Loss callables follow the (prediction, target) argument order.
        # Passing the target first is harmless for symmetric losses but gives
        # wrong results for losses that treat the two inputs differently
        # (e.g. losses that mask the prediction based on a target channel).
        # NOTE(review): confirm that custom losses passed via `loss=` follow
        # the same convention.
        net_loss = self.loss(masks, target)

        return net_loss
241 changes: 241 additions & 0 deletions scripts/for_benchmarking_ais/common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,241 @@
import os
import argparse
from glob import glob
from tqdm import tqdm

import h5py
import numpy as np
import pandas as pd
import imageio.v3 as imageio

import torch

import torch_em
from torch_em.transform.raw import normalize
from torch_em.transform.raw import standardize
from torch_em.loss import DiceBasedDistanceLoss
from torch_em.util import segmentation, prediction
from torch_em.transform.label import PerObjectDistanceTransform
from torch_em.data.datasets.light_microscopy import get_livecell_loader, get_covid_if_loader

import micro_sam.training as sam_training
from micro_sam.training.util import ConvertToSemanticSamInputs

from elf.evaluation import mean_segmentation_accuracy


#
# DATALOADERS
#


def covid_if_raw_trafo(raw):
    """Apply `normalize` to the raw data and multiply the result by 255.

    NOTE(review): presumably this maps the data to the [0, 255] range that
    SAM expects; this depends on the output range of `normalize`.
    """
    return normalize(raw) * 255


def get_loaders(path, patch_shape, dataset, for_sam=False):
    """Create the training and validation dataloaders for a benchmark dataset.

    Args:
        path: Root folder that contains the "livecell" and / or "covid_if" data folders.
        patch_shape: The patch shape passed on to the dataloaders.
        dataset: Either "livecell" or "covid_if-<N>", where N (1, 2, 5 or 10) is the
            number of images used for training.
        for_sam: Whether to set the raw transform used for the SAM-based models
            (`sam_training.identity` for livecell, `covid_if_raw_trafo` otherwise).
            If False, no raw transform is passed and the loader defaults apply.

    Returns:
        The tuple (train_loader, val_loader).

    Raises:
        ValueError: If the dataset name is unknown or the number of training images is invalid.
    """
    kwargs = {
        "label_transform": PerObjectDistanceTransform(
            distances=True,
            boundary_distances=True,
            directed_distances=False,
            foreground=True,
            min_size=25,
        ),
        "label_dtype": torch.float32,
        "num_workers": 16,
        "patch_shape": patch_shape,
        "shuffle": True,
    }

    if for_sam:
        kwargs["raw_transform"] = sam_training.identity if dataset == "livecell" else covid_if_raw_trafo

    if dataset == "livecell":
        livecell_path = os.path.join(path, "livecell")
        train_loader = get_livecell_loader(path=livecell_path, split="train", batch_size=2, **kwargs)
        val_loader = get_livecell_loader(path=livecell_path, split="val", batch_size=1, **kwargs)

    elif dataset.startswith("covid_if"):
        data_path = os.path.join(path, "covid_if")

        # The number of training images is encoded as the suffix of the dataset name.
        # Validate it explicitly (instead of via 'assert', which is stripped under -O).
        try:
            n_images = int(dataset.split("-")[-1])
        except ValueError:
            raise ValueError(
                f"'{dataset}' is not a valid dataset name. Expected 'covid_if-<N>' with N in 1, 2, 5, or 10."
            ) from None

        if n_images not in (1, 2, 5, 10):
            raise ValueError(f"Please choose number of images from 1, 2, 5, or 10; instead of {n_images}.")

        train_volumes = (None, n_images)
        val_volumes = (10, 13)

        # Let's get the number of samples extracted, to set the "n_samples" value
        # This is done to avoid the time taken to save checkpoints over fewer training images.
        _loader = get_covid_if_loader(
            path=data_path, patch_shape=patch_shape, batch_size=1, sample_range=train_volumes
        )
        n_found = len(_loader)
        oversample = n_found < 50

        print(
            f"Found {n_found} samples for training.",
            f"Hence, we will use {50 if oversample else n_found} samples for training."
        )

        # Finally, let's get the dataloaders
        train_loader = get_covid_if_loader(
            path=data_path,
            batch_size=1,
            sample_range=train_volumes,
            n_samples=50 if oversample else None,
            **kwargs
        )
        val_loader = get_covid_if_loader(
            path=data_path,
            batch_size=1,
            sample_range=val_volumes,
            **kwargs
        )

    else:
        raise ValueError(f"'{dataset}' is not a valid dataset name.")

    return train_loader, val_loader


#
# TRAINING SCRIPTS
#


def run_training(name, path, save_root, iterations, model, device, dataset, for_sam=False):
    """Train a model on the chosen benchmark dataset with the torch_em default trainer.

    Args:
        name: The name passed to the trainer.
        path: Root folder of the datasets (forwarded to `get_loaders`).
        save_root: The `save_root` passed to the trainer.
        iterations: Number of training iterations; converted with `int`.
        model: The model to train.
        device: The device passed to the trainer.
        dataset: The dataset name (forwarded to `get_loaders`).
        for_sam: Forwarded to `get_loaders`.
    """
    loaders = get_loaders(path=path, patch_shape=(512, 512), dataset=dataset, for_sam=for_sam)
    train_loader, val_loader = loaders

    # The same loss object is used both as training loss and as validation metric.
    distance_loss = DiceBasedDistanceLoss(mask_distances_in_bg=True)

    trainer_settings = {
        "name": name,
        "model": model,
        "train_loader": train_loader,
        "val_loader": val_loader,
        "device": device,
        "learning_rate": 1e-5,
        "loss": distance_loss,
        "metric": distance_loss,
        "log_image_interval": 50,
        "save_root": save_root,
        "compile_model": False,
        "mixed_precision": True,
        "scheduler_kwargs": {"mode": "min", "factor": 0.9, "patience": 5, "verbose": True},
    }
    trainer = torch_em.default_segmentation_trainer(**trainer_settings)

    n_iterations = int(iterations)
    trainer.fit(n_iterations)


#
# INFERENCE SCRIPTS
#

def run_inference(
    path, checkpoint_path, model, device, result_path, dataset, for_sam=False, with_semantic_sam=False,
):
    """Run tiled prediction and instance segmentation on the test images and save the scores.

    Args:
        path: Root folder that contains the "livecell" and / or "covid_if" data folders.
        checkpoint_path: Checkpoint file; its "model_state" entry is loaded into the model.
        model: The model used for prediction.
        device: The device the model is moved to and that is used for prediction.
        result_path: Folder where "results.csv" is written (created if missing).
        dataset: Either "livecell" or a name starting with "covid_if".
        for_sam: Whether the model expects SAM-style inputs. This changes the image dtype
            and the per-tile preprocessing.
        with_semantic_sam: Whether to predict with the custom prediction function
            for the semantic SAM model.

    Raises:
        ValueError: If the dataset name is unknown.
    """
    model.load_state_dict(torch.load(checkpoint_path, map_location="cpu")["model_state"])
    model.to(device)
    model.eval()

    if dataset == "livecell":
        # the splits are provided with the livecell dataset to reproduce the results:
        # run the inference on the entire dataset as it is.
        test_image_dir = os.path.join(path, "livecell", "images", "livecell_test_images")
        all_test_labels = sorted(
            glob(os.path.join(path, "livecell", "annotations", "livecell_test_images", "*", "*"))
        )

    elif dataset.startswith("covid_if"):
        # we create our own splits for this dataset.
        # - the first 10 images are dedicated for training.
        # - the next 3 images are dedicated for validation.
        # - the remaining images are used for testing
        # 'glob' returns the files in arbitrary order, so they must be sorted before slicing.
        # Otherwise the test split changes between runs / file systems and may overlap
        # with the images used for training and validation.
        # NOTE(review): this assumes that the dataloader also selects its 'sample_range'
        # from the sorted file list - confirm against the covid_if loader.
        all_test_labels = sorted(glob(os.path.join(path, "covid_if", "*.h5")))[13:]

    else:
        raise ValueError(f"'{dataset}' is not a valid dataset name.")

    def prediction_fn(net, inp):
        # Custom forward pass for the semantic SAM model: convert the tile to batched
        # inputs, compute the image embeddings and predict the masks from them.
        convert_inputs = ConvertToSemanticSamInputs()
        batched_inputs = convert_inputs(inp, torch.zeros_like(inp))
        image_embeddings, batched_inputs = net.image_embeddings_oft(batched_inputs)
        batched_outputs = net(batched_inputs, image_embeddings, multimask_output=True)
        masks = torch.stack([output["masks"] for output in batched_outputs]).squeeze()
        masks = masks[None]
        return masks

    msa_list, sa50_list, sa75_list = [], [], []
    for label_path in tqdm(all_test_labels):
        image_id = os.path.split(label_path)[-1]

        if dataset == "livecell":
            image = imageio.imread(os.path.join(test_image_dir, image_id))
            labels = imageio.imread(label_path)
        else:
            # Open read-only explicitly: the data is only read here.
            with h5py.File(label_path, "r") as f:
                image = f["raw/serum_IgG/s0"][:]
                labels = f["labels/cells/s0"][:]

        if for_sam:
            image = image.astype("float32")  # functional interpolate cannot work with uint.
            per_tile_pp = covid_if_raw_trafo if dataset.startswith("covid_if") else None
        else:
            per_tile_pp = standardize

        predictions = prediction.predict_with_halo(
            input_=image,
            model=model,
            gpu_ids=[device],
            block_shape=(384, 384),
            halo=(64, 64),
            preprocess=per_tile_pp,
            disable_tqdm=True,
            output=np.zeros((3, *image.shape)) if with_semantic_sam else None,
            prediction_function=prediction_fn if with_semantic_sam else None,
        )
        predictions = predictions.squeeze()

        # The three predicted channels are unpacked as foreground, center distances
        # and boundary distances.
        fg, cdist, bdist = predictions
        instances = segmentation.watershed_from_center_and_boundary_distances(
            cdist, bdist, fg, min_size=50,
            center_distance_threshold=0.5,
            boundary_distance_threshold=0.6,
            distance_smoothing=1.0
        )

        msa, sa_acc = mean_segmentation_accuracy(instances, labels, return_accuracies=True)
        msa_list.append(msa)
        # NOTE(review): indices 0 and 5 presumably correspond to the IoU thresholds
        # 0.5 and 0.75 - confirm against 'mean_segmentation_accuracy'.
        sa50_list.append(sa_acc[0])
        sa75_list.append(sa_acc[5])

    res = {
        "LIVECell" if dataset == "livecell" else "Covid IF": "Metrics",
        "mSA": np.mean(msa_list),
        "SA50": np.mean(sa50_list),
        "SA75": np.mean(sa75_list)
    }

    os.makedirs(result_path, exist_ok=True)
    res_path = os.path.join(result_path, "results.csv")
    df = pd.DataFrame.from_dict([res])
    df.to_csv(res_path)
    print(df)
    print(f"The result is saved at {res_path}")


#
# MISCELLANEOUS
#


def get_default_arguments():
    """Parse the command line arguments shared by the benchmarking scripts.

    Returns:
        The parsed arguments with the attributes `dataset`, `input_path`, `save_root`,
        `phase`, `iterations` (an int) and `sam`.
    """

    def _iteration_count(value):
        # Accept plain integers ("100000") as well as float notation ("1e5").
        # The value is later converted via 'int(iterations)', which fails for
        # strings such as "1e5", so it is converted to an int right away.
        try:
            return int(value)
        except ValueError:
            pass

        try:
            number = float(value)
        except ValueError:
            raise argparse.ArgumentTypeError(f"'{value}' is not a valid number of iterations.") from None

        if not number.is_integer():
            raise argparse.ArgumentTypeError(f"'{value}' is not a whole number of iterations.")

        return int(number)

    parser = argparse.ArgumentParser()
    parser.add_argument("-d", "--dataset", type=str, required=True)
    parser.add_argument("-i", "--input_path", type=str, default="/scratch/projects/nim00007/sam/data")
    parser.add_argument("-s", "--save_root", type=str, default=None)
    parser.add_argument("-p", "--phase", type=str, default=None, choices=["train", "predict"])
    parser.add_argument(
        "--iterations", type=_iteration_count, default=int(1e5),
        help="The number of training iterations, e.g. '100000' or '1e5'."
    )
    parser.add_argument("--sam", action="store_true")
    args = parser.parse_args()
    return args
92 changes: 92 additions & 0 deletions scripts/for_benchmarking_ais/submit_scripts.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
import os
import shutil
import itertools
import subprocess
from datetime import datetime


def _write_batch_script(script_path, dataset_name, exp_script, save_root, phase, with_sam):
job_name = exp_script.split("_")[-1] + ("-sam-" if with_sam else "-") + dataset_name

batch_script = f"""#!/bin/bash
#SBATCH -t 2-00:00:00
#SBATCH --mem 64G
#SBATCH -c 16
#SBATCH --nodes=1
#SBATCH --ntasks=1
#SBATCH -p grete-h100:shared
#SBATCH -G H100:1
#SBATCH -A gzz0001
#SBATCH --job-name={job_name}
source activate sam \n"""

# python script
script = f"python {exp_script}.py "

# all other parameters
script += f"-d {dataset_name} -s {save_root} -p {phase} "

# whether the model is trained using SAM pretrained weights
if with_sam:
script += "--sam "

# let's combine both the scripts
batch_script += script

output_path = script_path[:-3] + f"_{job_name}.sh"
with open(output_path, "w") as f:
f.write(batch_script)

cmd = ["sbatch", output_path]
subprocess.run(cmd)


def _get_batch_script(tmp_folder):
tmp_folder = os.path.expanduser(tmp_folder)
os.makedirs(tmp_folder, exist_ok=True)

script_name = "ais_benchmarking"

dt = datetime.now().strftime("%Y_%m_%d_%H_%M_%S_%f")
tmp_name = script_name + dt
batch_script = os.path.join(tmp_folder, f"{tmp_name}.sh")

return batch_script


def _submit_to_slurm(tmp_folder):
    """Submit one slurm job per combination of experiment script, dataset and SAM setting.

    Args:
        tmp_folder: The folder where the batch scripts are written.
    """
    save_root = "/scratch/share/cidas/cca/models/micro-sam/ais_benchmarking/"
    phase = "predict"  # this can be updated to "train" / "predict" to run the respective scripts.

    experiments = itertools.product(
        ["train_unet", "train_unetr", "train_semanticsam"],
        ["livecell", "covid_if-1", "covid_if-2", "covid_if-5", "covid_if-10"],
        [True, False],
    )

    for exp_script, dataset_name, with_sam in experiments:
        # The "--sam" variant is skipped for the plain unet script.
        is_unet_with_sam = with_sam and exp_script.endswith("_unet")
        if not is_unet_with_sam:
            _write_batch_script(
                script_path=_get_batch_script(tmp_folder),
                dataset_name=dataset_name,
                exp_script=exp_script,
                save_root=save_root,
                phase=phase,
                with_sam=with_sam,
            )


def main():
    """Remove the batch scripts of previous runs and submit all jobs to slurm."""
    jobs_folder = "./gpu_jobs"

    try:
        shutil.rmtree(jobs_folder)
    except FileNotFoundError:
        # There is nothing to clean up if the folder does not exist yet.
        pass

    _submit_to_slurm(jobs_folder)


if __name__ == "__main__":
    main()
Loading

0 comments on commit 83c8313

Please sign in to comment.