Skip to content

Commit

Permalink
Minor updates to covid_if norm
Browse files Browse the repository at this point in the history
  • Loading branch information
anwai98 committed Jul 18, 2024
1 parent 683dfe8 commit 1357c42
Showing 1 changed file with 8 additions and 1 deletion.
9 changes: 8 additions & 1 deletion scripts/for_benchmarking_ais/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import torch

import torch_em
from torch_em.transform.raw import normalize
from torch_em.transform.raw import standardize
from torch_em.loss import DiceBasedDistanceLoss
from torch_em.util import segmentation, prediction
Expand All @@ -28,6 +29,12 @@
#


def covid_if_raw_trafo(raw):
raw = normalize(raw)
raw = raw * 255
return raw


def get_loaders(path, patch_shape, dataset, for_sam=False):
kwargs = {
"label_transform": PerObjectDistanceTransform(
Expand All @@ -44,7 +51,7 @@ def get_loaders(path, patch_shape, dataset, for_sam=False):
}

if for_sam:
kwargs["raw_transform"] = sam_training.identity
kwargs["raw_transform"] = sam_training.identity if dataset == "livecell" else covid_if_raw_trafo

if dataset == "livecell":
train_loader = get_livecell_loader(path=os.path.join(path, "livecell"), split="train", batch_size=2, **kwargs)
Expand Down

0 comments on commit 1357c42

Please sign in to comment.