diff --git a/scripts/for_benchmarking_ais/common.py b/scripts/for_benchmarking_ais/common.py index e5168480..79d88b4c 100644 --- a/scripts/for_benchmarking_ais/common.py +++ b/scripts/for_benchmarking_ais/common.py @@ -11,6 +11,7 @@ import torch import torch_em +from torch_em.transform.raw import normalize from torch_em.transform.raw import standardize from torch_em.loss import DiceBasedDistanceLoss from torch_em.util import segmentation, prediction @@ -28,6 +29,12 @@ # +def covid_if_raw_trafo(raw): + raw = normalize(raw) + raw = raw * 255 + return raw + + def get_loaders(path, patch_shape, dataset, for_sam=False): kwargs = { "label_transform": PerObjectDistanceTransform( @@ -44,7 +51,7 @@ def get_loaders(path, patch_shape, dataset, for_sam=False): } if for_sam: - kwargs["raw_transform"] = sam_training.identity + kwargs["raw_transform"] = sam_training.identity if dataset == "livecell" else covid_if_raw_trafo if dataset == "livecell": train_loader = get_livecell_loader(path=os.path.join(path, "livecell"), split="train", batch_size=2, **kwargs)