Commit
Add training scripts for unet and unetr
Showing 8 changed files with 577 additions and 5 deletions.
@@ -0,0 +1,241 @@
import os
import argparse
from glob import glob
from tqdm import tqdm

import h5py
import numpy as np
import pandas as pd
import imageio.v3 as imageio

import torch

import torch_em
from torch_em.transform.raw import normalize
from torch_em.transform.raw import standardize
from torch_em.loss import DiceBasedDistanceLoss
from torch_em.util import segmentation, prediction
from torch_em.transform.label import PerObjectDistanceTransform
from torch_em.data.datasets.light_microscopy import get_livecell_loader, get_covid_if_loader

import micro_sam.training as sam_training
from micro_sam.training.util import ConvertToSemanticSamInputs

from elf.evaluation import mean_segmentation_accuracy


#
# DATALOADERS
#

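
# `normalize` rescales the raw data to [0, 1]; multiplying by 255 brings the Covid-IF images
# into the 8-bit value range used for the SAM-based models (see the `for_sam` branch in
# `get_loaders` below).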
def covid_if_raw_trafo(raw):
    raw = normalize(raw)
    raw = raw * 255
    return raw

def get_loaders(path, patch_shape, dataset, for_sam=False):
    kwargs = {
        "label_transform": PerObjectDistanceTransform(
            distances=True,
            boundary_distances=True,
            directed_distances=False,
            foreground=True,
            min_size=25,
        ),
        "label_dtype": torch.float32,
        "num_workers": 16,
        "patch_shape": patch_shape,
        "shuffle": True,
    }

    if for_sam:
        kwargs["raw_transform"] = sam_training.identity if dataset == "livecell" else covid_if_raw_trafo

    if dataset == "livecell":
        train_loader = get_livecell_loader(path=os.path.join(path, "livecell"), split="train", batch_size=2, **kwargs)
        val_loader = get_livecell_loader(path=os.path.join(path, "livecell"), split="val", batch_size=1, **kwargs)

    elif dataset.startswith("covid_if"):
        data_path = os.path.join(path, "covid_if")

        # Get the number of images to train on.
        n_images = int(dataset.split("-")[-1])
        assert n_images in [1, 2, 5, 10], f"Please choose the number of images from 1, 2, 5 or 10; got {n_images}."

        train_volumes = (None, n_images)
        val_volumes = (10, 13)

        # Get the number of samples extracted, to set the "n_samples" value.
        # This avoids spending a disproportionate amount of time on saving checkpoints
        # when training on only a few images.
        _loader = get_covid_if_loader(
            path=data_path, patch_shape=patch_shape, batch_size=1, sample_range=train_volumes
        )

        print(
            f"Found {len(_loader)} samples for training. "
            f"Hence, we will use {50 if len(_loader) < 50 else len(_loader)} samples for training."
        )

        # Finally, let's get the dataloaders.
        train_loader = get_covid_if_loader(
            path=data_path,
            batch_size=1,
            sample_range=train_volumes,
            n_samples=50 if len(_loader) < 50 else None,
            **kwargs
        )
        val_loader = get_covid_if_loader(
            path=data_path,
            batch_size=1,
            sample_range=val_volumes,
            **kwargs
        )

    else:
        raise ValueError(f"'{dataset}' is not a valid dataset name.")

    return train_loader, val_loader
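
# Example usage (illustrative placeholder path, not part of this commit):
#     train_loader, val_loader = get_loaders(
#         path="./data", patch_shape=(512, 512), dataset="covid_if-5", for_sam=False
#     )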


#
# TRAINING SCRIPTS
#

def run_training(name, path, save_root, iterations, model, device, dataset, for_sam=False):
    # all the necessary stuff for training
    patch_shape = (512, 512)
    train_loader, val_loader = get_loaders(path=path, patch_shape=patch_shape, dataset=dataset, for_sam=for_sam)
    loss = DiceBasedDistanceLoss(mask_distances_in_bg=True)

    trainer = torch_em.default_segmentation_trainer(
        name=name,
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        device=device,
        learning_rate=1e-5,
        loss=loss,
        metric=loss,
        log_image_interval=50,
        save_root=save_root,
        compile_model=False,
        mixed_precision=True,
        scheduler_kwargs={"mode": "min", "factor": 0.9, "patience": 5, "verbose": True},
    )

    trainer.fit(int(iterations))
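
    # torch_em writes the checkpoints for this run under "<save_root>/checkpoints/<name>/"
    # (e.g. "best.pt"); `run_inference` below expects such a path as `checkpoint_path`.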


#
# INFERENCE SCRIPTS
#

def run_inference(
    path, checkpoint_path, model, device, result_path, dataset, for_sam=False, with_semantic_sam=False,
):
    model.load_state_dict(torch.load(checkpoint_path, map_location="cpu")["model_state"])
    model.to(device)
    model.eval()

    if dataset == "livecell":
        # The splits are provided with the LIVECell dataset.
        # To reproduce the results, we run inference on the entire test split as it is.
        test_image_dir = os.path.join(path, "livecell", "images", "livecell_test_images")
        all_test_labels = glob(os.path.join(path, "livecell", "annotations", "livecell_test_images", "*", "*"))

    elif dataset.startswith("covid_if"):
        # We create our own splits for this dataset:
        # - the first 10 images are used for training,
        # - the next 3 images are used for validation,
        # - the remaining images are used for testing.
        all_test_labels = glob(os.path.join(path, "covid_if", "*.h5"))[13:]

    else:
        raise ValueError(f"'{dataset}' is not a valid dataset name.")

    def prediction_fn(net, inp):
        convert_inputs = ConvertToSemanticSamInputs()
        batched_inputs = convert_inputs(inp, torch.zeros_like(inp))
        image_embeddings, batched_inputs = net.image_embeddings_oft(batched_inputs)
        batched_outputs = net(batched_inputs, image_embeddings, multimask_output=True)
        masks = torch.stack([output["masks"] for output in batched_outputs]).squeeze()
        masks = masks[None]
        return masks

    msa_list, sa50_list, sa75_list = [], [], []
    for label_path in tqdm(all_test_labels):
        image_id = os.path.split(label_path)[-1]

        if dataset == "livecell":
            image = imageio.imread(os.path.join(test_image_dir, image_id))
            labels = imageio.imread(label_path)
        else:
            with h5py.File(label_path) as f:
                image = f["raw/serum_IgG/s0"][:]
                labels = f["labels/cells/s0"][:]

        if for_sam:
            image = image.astype("float32")  # Functional interpolation does not work with uint inputs.
            per_tile_pp = covid_if_raw_trafo if dataset.startswith("covid_if") else None
        else:
            per_tile_pp = standardize

        predictions = prediction.predict_with_halo(
            input_=image,
            model=model,
            gpu_ids=[device],
            block_shape=(384, 384),
            halo=(64, 64),
            preprocess=per_tile_pp,
            disable_tqdm=True,
            output=np.zeros((3, *image.shape)) if with_semantic_sam else None,
            prediction_function=prediction_fn if with_semantic_sam else None,
        )
        predictions = predictions.squeeze()

        fg, cdist, bdist = predictions
        instances = segmentation.watershed_from_center_and_boundary_distances(
            cdist, bdist, fg, min_size=50,
            center_distance_threshold=0.5,
            boundary_distance_threshold=0.6,
            distance_smoothing=1.0
        )

        msa, sa_acc = mean_segmentation_accuracy(instances, labels, return_accuracies=True)
        msa_list.append(msa)
        sa50_list.append(sa_acc[0])
        sa75_list.append(sa_acc[5])

    res = {
        "LIVECell" if dataset == "livecell" else "Covid IF": "Metrics",
        "mSA": np.mean(msa_list),
        "SA50": np.mean(sa50_list),
        "SA75": np.mean(sa75_list)
    }

    os.makedirs(result_path, exist_ok=True)
    res_path = os.path.join(result_path, "results.csv")
    df = pd.DataFrame.from_dict([res])
    df.to_csv(res_path)
    print(df)
    print(f"The result is saved at {res_path}")


#
# MISCELLANEOUS
#


def get_default_arguments():
    parser = argparse.ArgumentParser()
    parser.add_argument("-d", "--dataset", type=str, required=True)
    parser.add_argument("-i", "--input_path", type=str, default="/scratch/projects/nim00007/sam/data")
    parser.add_argument("-s", "--save_root", type=str, default=None)
    parser.add_argument("-p", "--phase", type=str, default=None, choices=["train", "predict"])
    parser.add_argument("--iterations", type=int, default=int(1e5))
    parser.add_argument("--sam", action="store_true")
    args = parser.parse_args()
    return args
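

# Hypothetical usage sketch (illustration only, not part of this commit): a training script
# such as `train_unet.py` would presumably combine the helpers above along these lines.
# The `UNet2d` arguments below are assumptions about the torch_em API.
#
#     from torch_em.model import UNet2d
#
#     args = get_default_arguments()
#     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#     model = UNet2d(in_channels=1, out_channels=3, initial_features=64, final_activation="Sigmoid")
#
#     if args.phase == "train":
#         run_training(
#             name=f"unet-{args.dataset}", path=args.input_path, save_root=args.save_root,
#             iterations=args.iterations, model=model, device=device, dataset=args.dataset,
#         )
#     else:
#         run_inference(
#             path=args.input_path, checkpoint_path="<path/to/checkpoint>", model=model, device=device,
#             result_path="./results", dataset=args.dataset,
#         )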
@@ -0,0 +1,92 @@
import os
import shutil
import itertools
import subprocess
from datetime import datetime


def _write_batch_script(script_path, dataset_name, exp_script, save_root, phase, with_sam):
    job_name = exp_script.split("_")[-1] + ("-sam-" if with_sam else "-") + dataset_name

    batch_script = f"""#!/bin/bash
#SBATCH -t 2-00:00:00
#SBATCH --mem 64G
#SBATCH -c 16
#SBATCH --nodes=1
#SBATCH --ntasks=1
#SBATCH -p grete-h100:shared
#SBATCH -G H100:1
#SBATCH -A gzz0001
#SBATCH --job-name={job_name}
source activate sam \n"""

    # python script
    script = f"python {exp_script}.py "

    # all other parameters
    script += f"-d {dataset_name} -s {save_root} -p {phase} "

    # whether the model is trained using SAM pretrained weights
    if with_sam:
        script += "--sam "

    # let's combine both the scripts
    batch_script += script

    output_path = script_path[:-3] + f"_{job_name}.sh"
    with open(output_path, "w") as f:
        f.write(batch_script)

    cmd = ["sbatch", output_path]
    subprocess.run(cmd)
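
    # Worked example: exp_script="train_unet", dataset_name="covid_if-5", with_sam=False
    # gives job_name="unet-covid_if-5", and the batch script is written as
    # "<script_path without .sh>_unet-covid_if-5.sh" before being submitted via sbatch.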


def _get_batch_script(tmp_folder):
    tmp_folder = os.path.expanduser(tmp_folder)
    os.makedirs(tmp_folder, exist_ok=True)

    script_name = "ais_benchmarking"

    dt = datetime.now().strftime("%Y_%m_%d_%H_%M_%S_%f")
    tmp_name = script_name + dt
    batch_script = os.path.join(tmp_folder, f"{tmp_name}.sh")

    return batch_script


def _submit_to_slurm(tmp_folder):
    save_root = "/scratch/share/cidas/cca/models/micro-sam/ais_benchmarking/"
    phase = "predict"  # this can be updated to "train" / "predict" to run the respective scripts.

    scripts = ["train_unet", "train_unetr", "train_semanticsam"]
    datasets = ["livecell", "covid_if-1", "covid_if-2", "covid_if-5", "covid_if-10"]
    sam_combinations = [True, False]

    for (exp_script, dataset_name, with_sam) in itertools.product(scripts, datasets, sam_combinations):
        if exp_script.endswith("_unet") and with_sam:
            continue

        _write_batch_script(
            script_path=_get_batch_script(tmp_folder),
            dataset_name=dataset_name,
            exp_script=exp_script,
            save_root=save_root,
            phase=phase,
            with_sam=with_sam,
        )


def main():
    tmp_folder = "./gpu_jobs"

    try:
        shutil.rmtree(tmp_folder)
    except FileNotFoundError:
        pass

    _submit_to_slurm(tmp_folder)


if __name__ == "__main__":
    main()