Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add experiments for different minibatch normalizations #264

Open
wants to merge 62 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
62 commits
Select commit Hold shift + click to select a range
2c948a3
Add experiments for different minibatch normalizations
anwai98 May 16, 2024
52a0601
Update patch shapes
anwai98 May 16, 2024
0c4f17d
Add script to submit all batch jobs
anwai98 May 16, 2024
303b8a4
Add inference scripts
anwai98 May 17, 2024
76ca13a
Add livecell results
anwai98 May 17, 2024
c3a61d7
Add napari visualization
anwai98 May 17, 2024
e90a7f2
Merge branch 'verify-norm' of https://github.com/anwai98/torch-em int…
anwai98 May 17, 2024
92501f2
Minor fix to boundary evaluation
anwai98 May 17, 2024
91b38d3
Add boundaries result
anwai98 May 17, 2024
5ec72f2
Add mouse embryo training
anwai98 May 17, 2024
587b9ed
Add inference for all other datasets
anwai98 May 21, 2024
7da423f
Merge branch 'main' into verify-norm
anwai98 May 21, 2024
9cd5135
Add visualization scripts
anwai98 May 21, 2024
97fe427
Split up inference-eval schemes
anwai98 May 21, 2024
45c8734
Update prediction saving mode
anwai98 May 21, 2024
0c1bfa2
Update paths
anwai98 May 22, 2024
9b669e3
Compress predictions
anwai98 May 22, 2024
6955d3e
Merge branch 'main' into verify-norm
anwai98 Jun 11, 2024
dce0f7a
Add all remaining results
anwai98 Jun 12, 2024
089fd62
Add new analysis scripts
anwai98 Jun 12, 2024
479b352
Comment out visualization scripts
anwai98 Jun 12, 2024
3f85870
Update normalization for inference
anwai98 Jun 14, 2024
0bdbcbb
Add new results
anwai98 Jun 14, 2024
61f9627
Merge branch 'main' into verify-norm
anwai98 Jun 17, 2024
be32576
Add labels to visible layers
anwai98 Jun 19, 2024
1e09cb0
Merge branch 'main' into verify-norm
anwai98 Jul 16, 2024
8847ab4
Merge branch 'main' into verify-norm
anwai98 Sep 16, 2024
45dbd85
Replace mouse embryo with gonuclear
anwai98 Sep 16, 2024
369af03
Update GoNuclear dataset for returning different target types
anwai98 Sep 16, 2024
b7a4136
Merge branch 'main' into verify-norm
anwai98 Sep 16, 2024
8e934d4
Refactor loader kwargs
anwai98 Sep 16, 2024
93afcda
Update patch shapes
anwai98 Sep 17, 2024
01b5659
Change rescaling at inference
anwai98 Sep 17, 2024
510c148
Increase compression opts
anwai98 Sep 17, 2024
1a75bed
Remove old evaluation scripts
anwai98 Sep 17, 2024
68bf472
Minor refactor to visual analysis scripts
anwai98 Sep 17, 2024
63ca4b1
Explore noise augmentation parameters for livecell
anwai98 Sep 18, 2024
be17391
Make per dataset script submission flexible
anwai98 Sep 18, 2024
56c383f
Add more results
anwai98 Sep 18, 2024
d20ebaa
Merge branch 'verify-norm' of https://github.com/anwai98/torch-em int…
anwai98 Sep 18, 2024
bfcdbba
Uncomment storing files to segmentation
anwai98 Sep 18, 2024
15a17af
Replace volume-level watershed with elf seeded watershed:
anwai98 Sep 18, 2024
614bdd0
Show instance labels per experiment
anwai98 Sep 19, 2024
31e5c45
Add scripts to extract patch-wise max intensity and skip empty patches
anwai98 Sep 19, 2024
7841071
Update function to skip empty blocks
anwai98 Sep 19, 2024
aed7b05
Change intensity threshold for filtering empty slices
anwai98 Sep 19, 2024
674b65a
Update storing data to h5 files
anwai98 Sep 19, 2024
0d2fb10
Merge branch 'main' into verify-norm
anwai98 Sep 19, 2024
f1cd70f
Add multicut for plantseg boundary segmentations
anwai98 Sep 19, 2024
5b2b195
Add plantseg instance segmentation results
anwai98 Sep 20, 2024
742b228
Filter out older results
anwai98 Sep 20, 2024
9866d6a
Refactor visualization scripts for plantseg
anwai98 Sep 20, 2024
c317d4d
Merge branch 'main' into verify-norm
anwai98 Sep 23, 2024
9749ab0
Refactor augmentations
anwai98 Sep 23, 2024
2bce50a
Move multicut segmentation to torch_em module
anwai98 Sep 23, 2024
eed8765
Merge remote-tracking branch 'anwai/verify-norm' into verify-norm
anwai98 Sep 23, 2024
52e30ac
Refactor analysis script to adapt with plantseg
anwai98 Sep 23, 2024
c6d6274
Add find boundaries to plantseg labels
anwai98 Sep 23, 2024
573aae9
Add plantseg quantitative results
anwai98 Sep 23, 2024
af0bf25
Merge branch 'main' into verify-norm
anwai98 Oct 1, 2024
82d9fa8
Make eval-train mode in inference modular for experiments
anwai98 Oct 1, 2024
7ab7e28
Merge branch 'main' into verify-norm
anwai98 Oct 7, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions experiments/misc/normalization/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
*.out
gpu_jobs
*.png
64 changes: 64 additions & 0 deletions experiments/misc/normalization/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
# Experiments for Mini-Batch Normalzation Schemes

> NOTE 1: The discussions below are for semantic segmentation task for predicting binary foreground and instance segmentation task for predicting foreground and boundary.

> NOTE 2: The chosen evaluation metric is `Dice Score (DSC)` for semantic segmentation,
and `Segmentation Accuracy` for instance segmentation.

## Quantitative Results:

### LIVECell:

| Targets | Mini-Batch Norm. | Input Norm. | Dice | mSA | SA50 |
|----------|------------------|-----------------|--------------------|---------------------|---------------------|
| Binary | OldDefault | Per Tile | 0.928867041453946 | - | - |
| Binary | InstanceNorm | Per Tile | 0.9139200411471561 | - | - |
| Boundary | OldDefault | Per Tile | - | 0.3596574029856674 | 0.5706313962830515 |
| Boundary | InstanceNorm | Per Tile | - | 0.21136391380220415 | 0.32975049087380925 |
| Binary | OldDefault | Whole Volume | 0.9288671102955196 | - | - |
| Binary | InstanceNorm | Whole Volume | 0.9139609056661331 | - | - |
| Boundary | OldDefault | Whole Volume | - | 0.3596633257789534 | 0.5706479532532667 |
| Boundary | InstanceNorm | Whole Volume | - | 0.21294970907224356 | 0.33214134838916115 |



### Mouse Embryo (Nuclei):

| Targets | Mini-Batch Norm. | Input Norm. | Dice | mSA | SA50 |
|----------|------------------|-----------------|--------------------|---------------------|---------------------|
| Binary | OldDefault | Per Tile | 0.773449732035369 | - | - |
| Binary | InstanceNorm | Per Tile | 0.7413715897174049 | - | - |
| Boundary | OldDefault | Per Tile | - | 0.2089494345163684 | 0.396692895729518 |
| Boundary | InstanceNorm | Per Tile | - | 0.15872795687689387 | 0.3200512111903994 |
| Binary | OldDefault | Whole Volume | 0.7734503020445563 | - | - |
| Binary | InstanceNorm | Whole Volume | 0.767180298383585 | - | - |
| Boundary | OldDefault | Whole Volume | - | 0.20882369840046283 | 0.39638644753570157 |
| Boundary | InstanceNorm | Whole Volume | - | 0.15714229749213487 | 0.3231149929195066 |


## PlantSeg (Root):

| Targets | Mini-Batch Norm. | Input Norm. | Dice | mSA | SA50 |
|----------|------------------|-----------------|--------------------|-----------------------|-----------------------|
| Binary | OldDefault | Per Tile | 0.9928809913563437 | - | - |
| Binary | InstanceNorm | Per Tile | 0.981061366019915 | - | - |
| Boundary | OldDefault | Per Tile | - | 0.014372028994123383 | 0.018482085449422186 |
| Boundary | InstanceNorm | Per Tile | - | 0.0014261363636363636 | 0.004753787878787879 |
| Binary | OldDefault | Whole Volume | 0.9928820421405989 | - | - |
| Binary | InstanceNorm | Whole Volume | 0.9810579911503293 | - | - |
| Boundary | OldDefault | Whole Volume | - | 0.014372028994123383 | 0.018482085449422186 |
| Boundary | InstanceNorm | Whole Volume | - | 0.0014396498771498771 | 0.0047988329238329245 |


## MitoEM (Human and Rat):

| Targets | Mini-Batch Norm. | Input Norm. | Dice | mSA | SA50 |
|----------|------------------|-----------------|--------------------|---------------------|---------------------|
| Binary | OldDefault | Per Tile | 0.9206916296864652 | - | - |
| Binary | InstanceNorm | Per Tile | 0.9247380570327044 | - | - |
| Boundary | OldDefault | Per Tile | - | 0.4452180959378787 | 0.5943295393094057 |
| Boundary | InstanceNorm | Per Tile | - | 0.23307025211315505 | 0.32876735818930347 |
| Binary | OldDefault | Whole Volume | 0.9206915268071336 | - | - |
| Binary | InstanceNorm | Whole Volume | 0.9241492985508706 | - | - |
| Boundary | OldDefault | Whole Volume | - | 0.44530557897744855 | 0.594423580706029 |
| Boundary | InstanceNorm | Whole Volume | - | 0.24493949819392608 | 0.3480087025599067 |
anwai98 marked this conversation as resolved.
Show resolved Hide resolved
217 changes: 217 additions & 0 deletions experiments/misc/normalization/common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,217 @@
import os
from glob import glob

import h5py
import z5py
import numpy as np
import imageio.v3 as imageio

from torchvision import transforms

from torch_em.model import UNet2d, UNet3d
from torch_em.data import MinTwoInstanceSampler, datasets
from torch_em.transform import raw as fetch_transforms

from micro_sam.evaluation.livecell import _get_livecell_paths


if os.path.exists("/scratch/share/cidas"):
ROOT = "/scratch/share/cidas/cca/data"
SAVE_DIR = "/scratch/share/cidas/cca/test/verify_normalization"
else:
ROOT = "/media/anwai/ANWAI/data/"
SAVE_DIR = "/media/anwai/ANWAI/test/verify_normalization"


def get_model(dataset, task, norm):
out_chans = 1 if task == "binary" or dataset == "plantseg" else 2
_model_class = UNet2d if dataset == "livecell" else UNet3d

model = _model_class(
in_channels=1,
out_channels=out_chans,
initial_features=64,
depth=4,
final_activation="Sigmoid",
norm=norm,
)
return model


def get_experiment_name(dataset, task, norm, model_choice):
cfg = "2d" if dataset == "livecell" else "3d"
name = f"{dataset}_{model_choice}{cfg}_{norm}_{task}"
return name


class MultipleRawTransforms:
def __init__(
self, p=0.3, norm=None, blur_kwargs={}, gaussian_kwargs={},
poisson_kwargs={}, additive_poisson_kwargs={}, contrast_kwargs={}
):
self.norm = fetch_transforms.normalize_percentile if norm is None else norm
augs = [self.norm]

if gaussian_kwargs is not None:
augs.append(transforms.RandomApply([fetch_transforms.GaussianBlur(**blur_kwargs)], p=p))
if poisson_kwargs is not None:
augs.append(transforms.RandomApply([fetch_transforms.PoissonNoise(**poisson_kwargs)], p=p/2))
if additive_poisson_kwargs is not None:
augs.append(
transforms.RandomApply([fetch_transforms.AdditivePoissonNoise(**additive_poisson_kwargs)], p=p/2)
)
if gaussian_kwargs is not None:
augs.append(transforms.RandomApply([fetch_transforms.AdditiveGaussianNoise(**gaussian_kwargs)], p=p/2))
if contrast_kwargs is not None:
aug2 = transforms.RandomApply([fetch_transforms.RandomContrast(**contrast_kwargs)], p)

self.raw_transform = fetch_transforms.get_raw_transform(
normalizer=self.norm,
augmentation1=transforms.Compose(augs),
augmentation2=aug2
)

def __call__(self, raw):
raw = raw[None] # NOTE: reason for doing this is to add an empty dimension to work with torch transforms.
raw = self.raw_transform(raw)
return raw


def get_dataloaders(dataset, task):
assert task in ["binary", "boundaries"]
sampler = MinTwoInstanceSampler()

loader_kwargs = {
"num_workers": 16, "download": True, "sampler": sampler,
"raw_transform": MultipleRawTransforms(
p=0.3,
poisson_kwargs=None if dataset == "livecell" else {},
additive_poisson_kwargs={"lam": (0.0, 0.2)} if dataset == "livecell" else {},
),
}

if dataset == "livecell":
train_loader = datasets.get_livecell_loader(
path=os.path.join(ROOT, "livecell"),
split="train",
patch_shape=(512, 512),
batch_size=2,
binary=True if task == "binary" else False,
boundaries=True if task == "boundaries" else False,
**loader_kwargs
)
val_loader = datasets.get_livecell_loader(
path=os.path.join(ROOT, "livecell"),
split="val",
patch_shape=(512, 512),
batch_size=1,
binary=True if task == "binary" else False,
boundaries=True if task == "boundaries" else False,
**loader_kwargs
)
elif dataset == "plantseg":
train_loader = datasets.get_plantseg_loader(
path=os.path.join(ROOT, "plantseg"),
name="root",
split="train",
patch_shape=(32, 512, 512),
batch_size=2,
boundaries=True if task == "boundaries" else False,
**loader_kwargs
)

val_loader = datasets.get_plantseg_loader(
path=os.path.join(ROOT, "plantseg"),
name="root",
split="val",
patch_shape=(32, 512, 512),
batch_size=1,
boundaries=True if task == "boundaries" else False,
**loader_kwargs
)
elif dataset == "mitoem":
train_loader = datasets.get_mitoem_loader(
path=os.path.join(ROOT, "mitoem"),
splits="train",
patch_shape=(32, 512, 512),
batch_size=2,
binary=True if task == "binary" else False,
boundaries=True if task == "boundaries" else False,
**loader_kwargs
)

val_loader = datasets.get_mitoem_loader(
path=os.path.join(ROOT, "mitoem"),
splits="val",
patch_shape=(32, 512, 512),
batch_size=1,
binary=True if task == "binary" else False,
boundaries=True if task == "boundaries" else False,
**loader_kwargs
)
elif dataset == "gonuclear":
train_loader = datasets.get_gonuclear_loader(
path=os.path.join(ROOT, "gonuclear"),
patch_shape=(32, 512, 512),
batch_size=2,
segmentation_task="nuclei",
binary=True if task == "binary" else False,
boundaries=True if task == "boundaries" else False,
sample_ids=[1135, 1136, 1137],
**loader_kwargs
)
val_loader = datasets.get_gonuclear_loader(
path=os.path.join(ROOT, "gonuclear"),
patch_shape=(32, 512, 512),
batch_size=1,
segmentation_task="nuclei",
binary=True if task == "binary" else False,
boundaries=True if task == "boundaries" else False,
sample_ids=[1139],
**loader_kwargs
)
else:
raise ValueError(f"{dataset} is not a valid dataset choice for this experiment.")

return train_loader, val_loader


def dice_score(gt, seg, eps=1e-7):
nom = 2 * np.sum(gt * seg)
denom = np.sum(gt) + np.sum(seg)
score = float(nom) / float(denom + eps)
return score


def get_test_images(dataset):
if dataset == "livecell":
image_paths, gt_paths = _get_livecell_paths(input_folder=os.path.join(ROOT, "livecell"), split="test")
return image_paths, gt_paths
else:
if dataset == "gonuclear":
datasets.gonuclear.get_gonuclear_data(path=os.path.join(ROOT, "gonuclear"), download=True)
volume_paths = [os.path.join(ROOT, "gonuclear", "gonuclear_datasets", "1170.h5")]
elif dataset == "plantseg":
datasets.plantseg.get_plantseg_data(
path=os.path.join(ROOT, "plantseg"), download=True, name="root", split="test"
)
volume_paths = sorted(glob(os.path.join(ROOT, "plantseg", "root_test", "*.h5")))
elif dataset == "mitoem":
volume_paths = sorted(glob(os.path.join(ROOT, "mitoem", "*_val.n5")))

return volume_paths, volume_paths


def _load_image(input_path, key=None):
if key is None:
image = imageio.imread(input_path)

else:
if input_path.endswith(".h5"):
with h5py.File(input_path, "r") as f:
image = f[key][:]
else:
with z5py.File(input_path, "r") as f:
image = f[key][:]

return image
101 changes: 101 additions & 0 deletions experiments/misc/normalization/run_all_scripts.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
import os
import shutil
import itertools
import subprocess
from datetime import datetime


def write_batch_script(out_path, _name, dataset, phase, task, norm, dry):
"Writing scripts for different norm experiments."
batch_script = f"""#!/bin/bash
#SBATCH -t 2-00:00:00
#SBATCH --mem 128G
#SBATCH --nodes=1
#SBATCH --ntasks=1
#SBATCH -p grete-h100:shared
#SBATCH -G H100:1
#SBATCH -A gzz0001
#SBATCH -c 16
#SBATCH --job-name=unet-{dataset}

source ~/.bashrc
micromamba activate sam \n"""

# python script
python_script = "python run_unet.py "

# dataset choice (livecell / plantseg / mitoem)
python_script += f"-d {dataset} "

# phase of code execution (train / predict)
python_script += f"-p {phase} "

# nature of task (binary / boundaries)
python_script += f"-t {task} "

# normalization scheme (InstanceNorm/OldDefault)
python_script += f"-n {norm} "

# let's add the python script to the bash script
batch_script += python_script

_op = out_path[:-3] + f"_{os.path.split(_name)[-1]}.sh"
with open(_op, "w") as f:
f.write(batch_script)

if not dry:
subprocess.run(["sbatch", _op])


def get_batch_script_names(tmp_folder):
tmp_folder = os.path.expanduser(tmp_folder)
os.makedirs(tmp_folder, exist_ok=True)
script_name = "unet-norm"
dt = datetime.now().strftime("%Y_%m_%d_%H_%M_%S_%f")
tmp_name = script_name + dt
batch_script = os.path.join(tmp_folder, f"{tmp_name}.sh")
return batch_script


def submit_slurm(args):
"Submit python script that needs gpus with given inputs on a slurm node."
if args.dataset is None:
datasets = ["livecell", "plantseg", "mitoem", "gonuclear"]
else:
datasets = [args.dataset]

tasks = ["binary", "boundaries"]
norms = ["InstanceNormTrackStats", "InstanceNorm"]

for (dataset, task, norm) in itertools.product(datasets, tasks, norms):
if dataset == "plantseg" and task == "binary": # for plantseg: binary is just all pixels as foreground
continue

write_batch_script(
out_path=get_batch_script_names("./gpu_jobs"),
_name="unet-norm",
dataset=dataset,
phase=args.phase,
task=task,
norm=norm,
dry=args.dry,
)


def main(args):
try:
shutil.rmtree("./gpu_jobs")
except FileNotFoundError:
pass

submit_slurm(args)


if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("-d", "--dataset", type=str, default=None)
parser.add_argument("-p", "--phase", required=True, type=str)
parser.add_argument("--dry", action="store_true")
args = parser.parse_args()
main(args)
Loading