From 7f2ec8ef4b3f559b77abc05cc3742f3ef3163a17 Mon Sep 17 00:00:00 2001 From: Anwai Archit Date: Tue, 16 Jul 2024 16:10:15 +0200 Subject: [PATCH] Revert "Merge branch 'main' into verify-norm" This reverts commit 1e09cb086757247a5ce74fc499935ea7edac73a0, reversing changes made to be3257607521af609ea6568b62b9e9786d2921a1. --- .../unet-segmentation/dsb/train_boundaries.py | 9 +- .../datasets/{medical => }/check_papila.py | 4 +- .../datasets/{medical => }/check_siim_acr.py | 6 +- scripts/datasets/medical/check_acdc.py | 23 --- scripts/datasets/medical/check_amos.py | 24 --- scripts/datasets/medical/check_cbis_ddsm.py | 24 --- scripts/datasets/medical/check_cholecseg8k.py | 21 --- scripts/datasets/medical/check_covid19_seg.py | 23 --- scripts/datasets/medical/check_dca1.py | 21 --- scripts/datasets/medical/check_duke_liver.py | 23 --- scripts/datasets/medical/check_han_seg.py | 20 -- scripts/datasets/medical/check_isic.py | 22 --- scripts/datasets/medical/check_m2caiseg.py | 21 --- scripts/datasets/medical/check_oimhs.py | 5 +- .../datasets/medical/check_osic_pulmofib.py | 19 +- scripts/datasets/medical/check_piccolo.py | 20 -- scripts/datasets/medical/check_sega.py | 4 +- scripts/datasets/medical/check_spider.py | 20 -- scripts/datasets/medical/check_toothfairy.py | 21 --- torch_em/__version__.py | 2 +- .../datasets/light_microscopy/__init__.py | 3 - .../datasets/light_microscopy/cellseg_3d.py | 106 ----------- .../light_microscopy/embedseg_data.py | 122 ------------- .../datasets/light_microscopy/organoidnet.py | 109 ----------- torch_em/data/datasets/medical/__init__.py | 13 -- torch_em/data/datasets/medical/acdc.py | 106 ----------- torch_em/data/datasets/medical/amos.py | 132 -------------- torch_em/data/datasets/medical/cbis_ddsm.py | 120 ------------ torch_em/data/datasets/medical/cholecseg8k.py | 159 ---------------- torch_em/data/datasets/medical/covid19_seg.py | 118 ------------ torch_em/data/datasets/medical/dca1.py | 105 ----------- torch_em/data/datasets/medical/drive.py | 40 ++-- torch_em/data/datasets/medical/duke_liver.py | 162 ----------------- torch_em/data/datasets/medical/han_seg.py | 128 ------------- torch_em/data/datasets/medical/isic.py | 147 --------------- torch_em/data/datasets/medical/m2caiseg.py | 172 ------------------ torch_em/data/datasets/medical/oimhs.py | 48 +---- .../data/datasets/medical/osic_pulmofib.py | 6 - torch_em/data/datasets/medical/papila.py | 50 +++-- torch_em/data/datasets/medical/piccolo.py | 105 ----------- torch_em/data/datasets/medical/sega.py | 38 +--- torch_em/data/datasets/medical/siim_acr.py | 31 ++-- torch_em/data/datasets/medical/spider.py | 79 -------- torch_em/data/datasets/medical/toothfairy.py | 88 --------- torch_em/data/datasets/util.py | 9 - torch_em/data/sampler.py | 11 +- torch_em/model/unet.py | 4 +- torch_em/model/unetr.py | 7 +- torch_em/transform/label.py | 16 -- torch_em/util/debug.py | 1 - torch_em/util/prediction.py | 2 +- torch_em/util/util.py | 9 +- 52 files changed, 120 insertions(+), 2458 deletions(-) rename scripts/datasets/{medical => }/check_papila.py (78%) rename scripts/datasets/{medical => }/check_siim_acr.py (74%) delete mode 100644 scripts/datasets/medical/check_acdc.py delete mode 100644 scripts/datasets/medical/check_amos.py delete mode 100644 scripts/datasets/medical/check_cbis_ddsm.py delete mode 100644 scripts/datasets/medical/check_cholecseg8k.py delete mode 100644 scripts/datasets/medical/check_covid19_seg.py delete mode 100644 scripts/datasets/medical/check_dca1.py delete mode 100644 scripts/datasets/medical/check_duke_liver.py delete mode 100644 scripts/datasets/medical/check_han_seg.py delete mode 100644 scripts/datasets/medical/check_isic.py delete mode 100644 scripts/datasets/medical/check_m2caiseg.py delete mode 100644 scripts/datasets/medical/check_piccolo.py delete mode 100644 scripts/datasets/medical/check_spider.py delete mode 100644 scripts/datasets/medical/check_toothfairy.py delete mode 100644 torch_em/data/datasets/light_microscopy/cellseg_3d.py delete mode 100644 torch_em/data/datasets/light_microscopy/embedseg_data.py delete mode 100644 torch_em/data/datasets/light_microscopy/organoidnet.py delete mode 100644 torch_em/data/datasets/medical/acdc.py delete mode 100644 torch_em/data/datasets/medical/amos.py delete mode 100644 torch_em/data/datasets/medical/cbis_ddsm.py delete mode 100644 torch_em/data/datasets/medical/cholecseg8k.py delete mode 100644 torch_em/data/datasets/medical/covid19_seg.py delete mode 100644 torch_em/data/datasets/medical/dca1.py delete mode 100644 torch_em/data/datasets/medical/duke_liver.py delete mode 100644 torch_em/data/datasets/medical/han_seg.py delete mode 100644 torch_em/data/datasets/medical/isic.py delete mode 100644 torch_em/data/datasets/medical/m2caiseg.py delete mode 100644 torch_em/data/datasets/medical/piccolo.py delete mode 100644 torch_em/data/datasets/medical/spider.py delete mode 100644 torch_em/data/datasets/medical/toothfairy.py diff --git a/experiments/unet-segmentation/dsb/train_boundaries.py b/experiments/unet-segmentation/dsb/train_boundaries.py index addd5c79..645e336b 100644 --- a/experiments/unet-segmentation/dsb/train_boundaries.py +++ b/experiments/unet-segmentation/dsb/train_boundaries.py @@ -9,16 +9,11 @@ def train_boundaries(args): patch_shape = (1, 256, 256) train_loader = get_dsb_loader( - args.input, patch_shape=patch_shape, split="train", + args.input, patch_shape, split="train", download=True, boundaries=True, batch_size=args.batch_size ) - - # Uncomment this for checking the loader. - # from torch_em.util.debug import check_loader - # check_loader(train_loader, 4) - val_loader = get_dsb_loader( - args.input, patch_shape=patch_shape, split="test", + args.input, patch_shape, split="test", boundaries=True, batch_size=args.batch_size ) loss = torch_em.loss.DiceLoss() diff --git a/scripts/datasets/medical/check_papila.py b/scripts/datasets/check_papila.py similarity index 78% rename from scripts/datasets/medical/check_papila.py rename to scripts/datasets/check_papila.py index 637c6747..dd170959 100644 --- a/scripts/datasets/medical/check_papila.py +++ b/scripts/datasets/check_papila.py @@ -2,7 +2,7 @@ from torch_em.data.datasets.medical import get_papila_loader -ROOT = "/scratch/share/cidas/cca/data/papila" +ROOT = "/media/anwai/ANWAI/data/papila" def check_papila(): @@ -16,7 +16,7 @@ def check_papila(): download=True, ) - check_loader(loader, 8, plt=True, save_path="./papila.png") + check_loader(loader, 8) if __name__ == "__main__": diff --git a/scripts/datasets/medical/check_siim_acr.py b/scripts/datasets/check_siim_acr.py similarity index 74% rename from scripts/datasets/medical/check_siim_acr.py rename to scripts/datasets/check_siim_acr.py index 1f3df848..ebceb51a 100644 --- a/scripts/datasets/medical/check_siim_acr.py +++ b/scripts/datasets/check_siim_acr.py @@ -3,7 +3,7 @@ from torch_em.data.datasets.medical import get_siim_acr_loader -ROOT = "/scratch/share/cidas/cca/data/siim_acr" +ROOT = "/media/anwai/ANWAI/data/siim_acr" def check_siim_acr(): @@ -13,10 +13,10 @@ def check_siim_acr(): patch_shape=(512, 512), batch_size=2, download=True, - resize_inputs=True, + resize_inputs=False, sampler=MinInstanceSampler() ) - check_loader(loader, 8, plt=True, save_path="./siim_acr.png") + check_loader(loader, 8) if __name__ == "__main__": diff --git a/scripts/datasets/medical/check_acdc.py b/scripts/datasets/medical/check_acdc.py deleted file mode 100644 index 81a073ed..00000000 --- a/scripts/datasets/medical/check_acdc.py +++ /dev/null @@ -1,23 +0,0 @@ -from torch_em.util.debug import check_loader -from torch_em.data.datasets.medical import get_acdc_loader -from torch_em.data import MinInstanceSampler - - -ROOT = "/media/anwai/ANWAI/data/acdc" - - -def check_acdc(): - loader = get_acdc_loader( - path=ROOT, - patch_shape=(4, 256, 256), - batch_size=2, - split="train", - download=True, - sampler=MinInstanceSampler(min_num_instances=4), - ) - - check_loader(loader, 8) - - -if __name__ == "__main__": - check_acdc() diff --git a/scripts/datasets/medical/check_amos.py b/scripts/datasets/medical/check_amos.py deleted file mode 100644 index 98357854..00000000 --- a/scripts/datasets/medical/check_amos.py +++ /dev/null @@ -1,24 +0,0 @@ -from torch_em.util.debug import check_loader -from torch_em.data import MinInstanceSampler -from torch_em.data.datasets.medical import get_amos_loader - -ROOT = "/media/anwai/ANWAI/data/amos" - - -def check_amos(): - loader = get_amos_loader( - path=ROOT, - split="train", - patch_shape=(1, 512, 512), - modality="mri", - ndim=2, - batch_size=2, - download=True, - sampler=MinInstanceSampler(min_num_instances=3), - resize_inputs=False, - ) - check_loader(loader, 8) - - -if __name__ == "__main__": - check_amos() diff --git a/scripts/datasets/medical/check_cbis_ddsm.py b/scripts/datasets/medical/check_cbis_ddsm.py deleted file mode 100644 index ca305183..00000000 --- a/scripts/datasets/medical/check_cbis_ddsm.py +++ /dev/null @@ -1,24 +0,0 @@ -from torch_em.data import MinInstanceSampler -from torch_em.util.debug import check_loader -from torch_em.data.datasets.medical import get_cbis_ddsm_loader - - -ROOT = "/media/anwai/ANWAI/data/cbis_ddsm" - - -def check_cbis_ddsm(): - loader = get_cbis_ddsm_loader( - path=ROOT, - patch_shape=(512, 512), - batch_size=2, - split="Train", - task=None, - tumour_type=None, - resize_inputs=True, - sampler=MinInstanceSampler() - ) - check_loader(loader, 8) - - -if __name__ == "__main__": - check_cbis_ddsm() diff --git a/scripts/datasets/medical/check_cholecseg8k.py b/scripts/datasets/medical/check_cholecseg8k.py deleted file mode 100644 index 41571080..00000000 --- a/scripts/datasets/medical/check_cholecseg8k.py +++ /dev/null @@ -1,21 +0,0 @@ -from torch_em.util.debug import check_loader -from torch_em.data.datasets.medical import get_cholecseg8k_loader - - -ROOT = "/media/anwai/ANWAI/data/cholecseg8k" - - -def get_cholecseg8k(): - loader = get_cholecseg8k_loader( - path=ROOT, - patch_shape=(512, 512), - batch_size=2, - split="train", - resize_inputs=True, - download=False, - ) - check_loader(loader, 8) - - -if __name__ == "__main__": - get_cholecseg8k() diff --git a/scripts/datasets/medical/check_covid19_seg.py b/scripts/datasets/medical/check_covid19_seg.py deleted file mode 100644 index 92111d71..00000000 --- a/scripts/datasets/medical/check_covid19_seg.py +++ /dev/null @@ -1,23 +0,0 @@ -from torch_em.util.debug import check_loader -from torch_em.data import MinInstanceSampler -from torch_em.data.datasets.medical import get_covid19_seg_loader - - -ROOT = "/media/anwai/ANWAI/data/covid19_seg" - - -def check_covid19_seg(): - loader = get_covid19_seg_loader( - path=ROOT, - patch_shape=(32, 512, 512), - batch_size=2, - task="lung", - download=True, - sampler=MinInstanceSampler(), - ) - - check_loader(loader, 8) - - -if __name__ == "__main__": - check_covid19_seg() diff --git a/scripts/datasets/medical/check_dca1.py b/scripts/datasets/medical/check_dca1.py deleted file mode 100644 index 428329cc..00000000 --- a/scripts/datasets/medical/check_dca1.py +++ /dev/null @@ -1,21 +0,0 @@ -from torch_em.util.debug import check_loader -from torch_em.data.datasets.medical import get_dca1_loader - - -ROOT = "/media/anwai/ANWAI/data/dca1" - - -def check_dca1(): - loader = get_dca1_loader( - path=ROOT, - patch_shape=(512, 512), - batch_size=2, - split="test", - resize_inputs=True, - download=False, - ) - check_loader(loader, 8) - - -if __name__ == "__main__": - check_dca1() diff --git a/scripts/datasets/medical/check_duke_liver.py b/scripts/datasets/medical/check_duke_liver.py deleted file mode 100644 index 3668aed8..00000000 --- a/scripts/datasets/medical/check_duke_liver.py +++ /dev/null @@ -1,23 +0,0 @@ -from torch_em.util.debug import check_loader -from torch_em.data.datasets.medical import get_duke_liver_loader - - -ROOT = "/media/anwai/ANWAI/data/duke_liver" - - -def check_duke_liver(): - from micro_sam.training import identity - loader = get_duke_liver_loader( - path=ROOT, - patch_shape=(32, 512, 512), - batch_size=2, - split="train", - download=False, - raw_transform=identity, - - ) - check_loader(loader, 8) - - -if __name__ == "__main__": - check_duke_liver() diff --git a/scripts/datasets/medical/check_han_seg.py b/scripts/datasets/medical/check_han_seg.py deleted file mode 100644 index dd12fcad..00000000 --- a/scripts/datasets/medical/check_han_seg.py +++ /dev/null @@ -1,20 +0,0 @@ -from torch_em.util.debug import check_loader -from torch_em.data.datasets.medical import get_han_seg_loader - - -ROOT = "/media/anwai/ANWAI/data/han-seg/" - - -def check_han_seg(): - loader = get_han_seg_loader( - path=ROOT, - patch_shape=(32, 512, 512), - batch_size=2, - download=False, - ) - - check_loader(loader, 8) - - -if __name__ == "__main__": - check_han_seg() diff --git a/scripts/datasets/medical/check_isic.py b/scripts/datasets/medical/check_isic.py deleted file mode 100644 index c7a77aee..00000000 --- a/scripts/datasets/medical/check_isic.py +++ /dev/null @@ -1,22 +0,0 @@ -from torch_em.util.debug import check_loader -from torch_em.data.datasets.medical import get_isic_loader - - -ROOT = "/scratch/share/cidas/cca/data/isic" - - -def check_isic(): - loader = get_isic_loader( - path=ROOT, - patch_shape=(700, 700), - batch_size=2, - split="test", - download=True, - resize_inputs=True, - ) - - check_loader(loader, 8, plt=True, save_path="./isic.png") - - -if __name__ == "__main__": - check_isic() diff --git a/scripts/datasets/medical/check_m2caiseg.py b/scripts/datasets/medical/check_m2caiseg.py deleted file mode 100644 index 9853f773..00000000 --- a/scripts/datasets/medical/check_m2caiseg.py +++ /dev/null @@ -1,21 +0,0 @@ -from torch_em.util.debug import check_loader -from torch_em.data.datasets.medical import get_m2caiseg_loader - - -ROOT = "/media/anwai/ANWAI/data/m2caiseg" - - -def check_m2caiseg(): - loader = get_m2caiseg_loader( - path=ROOT, - split="train", - patch_shape=(512, 512), - batch_size=2, - resize_inputs=True, - download=True, - ) - check_loader(loader, 8) - - -if __name__ == "__main__": - check_m2caiseg() diff --git a/scripts/datasets/medical/check_oimhs.py b/scripts/datasets/medical/check_oimhs.py index feea6d6f..63543625 100644 --- a/scripts/datasets/medical/check_oimhs.py +++ b/scripts/datasets/medical/check_oimhs.py @@ -2,7 +2,7 @@ from torch_em.data.datasets.medical import get_oimhs_loader -ROOT = "/scratch/share/cidas/cca/data/oimhs" +ROOT = "/media/anwai/ANWAI/data/oimhs" def check_oimhs(): @@ -10,12 +10,11 @@ def check_oimhs(): path=ROOT, patch_shape=(512, 512), batch_size=2, - split="test", download=False, resize_inputs=True, ) - check_loader(loader, 8, plt=True, save_path="./oimhs.png") + check_loader(loader, 8) if __name__ == "__main__": diff --git a/scripts/datasets/medical/check_osic_pulmofib.py b/scripts/datasets/medical/check_osic_pulmofib.py index 9ef934c6..beef0bcf 100644 --- a/scripts/datasets/medical/check_osic_pulmofib.py +++ b/scripts/datasets/medical/check_osic_pulmofib.py @@ -11,11 +11,28 @@ def check_osic_pulmofib(): patch_shape=(1, 512, 512), batch_size=2, resize_inputs=False, - download=True, + download=False, ) check_loader(loader, 8) +def visualize_data(): + import os + from glob import glob + + import nrrd + import napari + + all_volume_paths = sorted(glob(os.path.join(ROOT, "nrrd_heart", "*", "*"))) + for vol_path in all_volume_paths: + vol, header = nrrd.read(vol_path) + + v = napari.Viewer() + v.add_image(vol.transpose(2, 0, 1)) + napari.run() + + if __name__ == "__main__": + # visualize_data() check_osic_pulmofib() diff --git a/scripts/datasets/medical/check_piccolo.py b/scripts/datasets/medical/check_piccolo.py deleted file mode 100644 index 7d43313f..00000000 --- a/scripts/datasets/medical/check_piccolo.py +++ /dev/null @@ -1,20 +0,0 @@ -from torch_em.util.debug import check_loader -from torch_em.data.datasets.medical import get_piccolo_loader - - -ROOT = "/media/anwai/ANWAI/data/piccolo" - - -def check_piccolo(): - loader = get_piccolo_loader( - path=ROOT, - patch_shape=(512, 512), - batch_size=2, - split="train", - resize_inputs=True, - ) - check_loader(loader, 8) - - -if __name__ == "__main__": - check_piccolo() diff --git a/scripts/datasets/medical/check_sega.py b/scripts/datasets/medical/check_sega.py index 2420f2d2..83528e3c 100644 --- a/scripts/datasets/medical/check_sega.py +++ b/scripts/datasets/medical/check_sega.py @@ -9,9 +9,11 @@ def check_sega(): loader = get_sega_loader( path=ROOT, - patch_shape=(32, 512, 512), + patch_shape=(1, 512, 512), batch_size=2, + ndim=2, data_choice="KiTS", + resize_inputs=True, download=True, sampler=MinInstanceSampler(), ) diff --git a/scripts/datasets/medical/check_spider.py b/scripts/datasets/medical/check_spider.py deleted file mode 100644 index 3a4421fc..00000000 --- a/scripts/datasets/medical/check_spider.py +++ /dev/null @@ -1,20 +0,0 @@ -from torch_em.util.debug import check_loader -from torch_em.data import MinInstanceSampler -from torch_em.data.datasets.medical import get_spider_loader - - -ROOT = "/media/anwai/ANWAI/data/spider" - - -def check_spider(): - loader = get_spider_loader( - path=ROOT, - patch_shape=(1, 512, 512), - batch_size=2, - sampler=MinInstanceSampler() - ) - - check_loader(loader, 8) - - -check_spider() diff --git a/scripts/datasets/medical/check_toothfairy.py b/scripts/datasets/medical/check_toothfairy.py deleted file mode 100644 index dc0fbcf3..00000000 --- a/scripts/datasets/medical/check_toothfairy.py +++ /dev/null @@ -1,21 +0,0 @@ -from torch_em.util.debug import check_loader -from torch_em.data import MinInstanceSampler -from torch_em.data.datasets.medical import get_toothfairy_loader - - -ROOT = "/scratch/share/cidas/cca/data/toothfairy/" - - -def check_toothfairy(): - loader = get_toothfairy_loader( - path=ROOT, - patch_shape=(1, 512, 512), - ndim=2, - batch_size=2, - sampler=MinInstanceSampler() - ) - - check_loader(loader, 8, plt=True, save_path="./toothfairy.png") - - -check_toothfairy() diff --git a/torch_em/__version__.py b/torch_em/__version__.py index 4910b9ec..49e0fc1e 100644 --- a/torch_em/__version__.py +++ b/torch_em/__version__.py @@ -1 +1 @@ -__version__ = "0.7.3" +__version__ = "0.7.0" diff --git a/torch_em/data/datasets/light_microscopy/__init__.py b/torch_em/data/datasets/light_microscopy/__init__.py index 53d51066..89f9bd25 100644 --- a/torch_em/data/datasets/light_microscopy/__init__.py +++ b/torch_em/data/datasets/light_microscopy/__init__.py @@ -1,10 +1,8 @@ -from .cellseg_3d import get_cellseg_3d_loader, get_cellseg_3d_dataset from .covid_if import get_covid_if_loader, get_covid_if_dataset from .ctc import get_ctc_segmentation_loader, get_ctc_segmentation_dataset from .deepbacs import get_deepbacs_loader, get_deepbacs_dataset from .dsb import get_dsb_loader, get_dsb_dataset from .dynamicnuclearnet import get_dynamicnuclearnet_loader, get_dynamicnuclearnet_dataset -from .embedseg_data import get_embedseg_loader, get_embedseg_dataset from .hpa import get_hpa_segmentation_loader, get_hpa_segmentation_dataset from .livecell import get_livecell_loader, get_livecell_dataset from .mouse_embryo import get_mouse_embryo_loader, get_mouse_embryo_dataset @@ -13,6 +11,5 @@ get_neurips_cellseg_unsupervised_loader, get_neurips_cellseg_unsupervised_dataset ) from .orgasegment import get_orgasegment_dataset, get_orgasegment_loader -from .organoidnet import get_organoidnet_dataset, get_organoidnet_loader from .plantseg import get_plantseg_loader, get_plantseg_dataset from .tissuenet import get_tissuenet_loader, get_tissuenet_dataset diff --git a/torch_em/data/datasets/light_microscopy/cellseg_3d.py b/torch_em/data/datasets/light_microscopy/cellseg_3d.py deleted file mode 100644 index 60e5b8ea..00000000 --- a/torch_em/data/datasets/light_microscopy/cellseg_3d.py +++ /dev/null @@ -1,106 +0,0 @@ -"""This dataset contains annotation for nucleus segmentation in 3d fluorescence microscopy from mesoSPIM microscopy. - -This dataset is from the publication https://doi.org/10.1101/2024.05.17.594691 . -Please cite it if you use this dataset in your research. -""" - -import os -from glob import glob -from typing import Optional, Tuple, Union - -import torch_em -from torch.utils.data import Dataset, DataLoader -from .. import util - -URL = "https://zenodo.org/records/11095111/files/DATASET_WITH_GT.zip?download=1" -CHECKSUM = "6d8e8d778e479000161fdfea70201a6ded95b3958a703f69def63e69bbddf9d6" - - -def get_cellseg_3d_data(path: Union[os.PathLike, str], download: bool) -> str: - """Download the CellSeg3d training data. - - Args: - path: Filepath to a folder where the downloaded data will be saved. - download: Whether to download the data if it is not present. - - Returns: - The filepath to the training data. - """ - url = URL - checksum = CHECKSUM - - data_path = os.path.join(path, "DATASET_WITH_GT") - if os.path.exists(data_path): - return data_path - - os.makedirs(path, exist_ok=True) - zip_path = os.path.join(path, "cellseg3d.zip") - util.download_source(zip_path, url, download, checksum) - util.unzip(zip_path, path, True) - - return data_path - - -def get_cellseg_3d_dataset( - path: Union[os.PathLike, str], - patch_shape: Tuple[int, int], - sample_ids: Optional[Tuple[int, ...]] = None, - download: bool = False, - **kwargs -) -> Dataset: - """Get the CellSeg3d dataset for segmenting nuclei in 3d fluorescence microscopy. - - Args: - path: Filepath to a folder where the downloaded data will be saved. - patch_shape: The patch shape to use for training. - sample_ids: The volume ids to load. - download: Whether to download the data if it is not present. - kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`. - - Returns: - The segmentation dataset. - """ - data_root = get_cellseg_3d_data(path, download) - - raw_paths = sorted(glob(os.path.join(data_root, "*.tif"))) - label_paths = sorted(glob(os.path.join(data_root, "labels", "*.tif"))) - assert len(raw_paths) == len(label_paths) - if sample_ids is not None: - assert all(sid < len(raw_paths) for sid in sample_ids) - raw_paths = [raw_paths[i] for i in sample_ids] - label_paths = [label_paths[i] for i in sample_ids] - - raw_key, label_key = None, None - - return torch_em.default_segmentation_dataset( - raw_paths, raw_key, label_paths, label_key, patch_shape, **kwargs - ) - - -def get_cellseg_3d_loader( - path: Union[os.PathLike, str], - patch_shape: Tuple[int, int], - batch_size: int, - sample_ids: Optional[Tuple[int, ...]] = None, - download: bool = False, - **kwargs -) -> DataLoader: - """Get the CellSeg3d dataloder for segmenting nuclei in 3d fluorescence microscopy. - - Args: - path: Filepath to a folder where the downloaded data will be saved. - patch_shape: The patch shape to use for training. - batch_size: The batch size for training. - sample_ids: The volume ids to load. - download: Whether to download the data if it is not present. - kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader. - - Returns: - The DataLoader. - """ - ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs) - dataset = get_cellseg_3d_dataset( - path, patch_shape, sample_ids=sample_ids, download=download, **ds_kwargs, - ) - loader = torch_em.get_data_loader(dataset, batch_size=batch_size, **loader_kwargs) - return loader diff --git a/torch_em/data/datasets/light_microscopy/embedseg_data.py b/torch_em/data/datasets/light_microscopy/embedseg_data.py deleted file mode 100644 index 38e32ce1..00000000 --- a/torch_em/data/datasets/light_microscopy/embedseg_data.py +++ /dev/null @@ -1,122 +0,0 @@ -"""This dataset contains annotation for 3d fluorescence microscopy segmentation -that were introduced by the EmbedSeg publication. - -This dataset is from the publication https://proceedings.mlr.press/v143/lalit21a.html. -Please cite it if you use this dataset in your research. -""" - -import os -from glob import glob -from typing import Tuple, Union - -import torch_em -from torch.utils.data import Dataset, DataLoader -from .. import util - -URLS = { - "Mouse-Organoid-Cells-CBG": "https://github.com/juglab/EmbedSeg/releases/download/v0.1.0/Mouse-Organoid-Cells-CBG.zip", # noqa - "Mouse-Skull-Nuclei-CBG": "https://github.com/juglab/EmbedSeg/releases/download/v0.1.0/Mouse-Skull-Nuclei-CBG.zip", - "Platynereis-ISH-Nuclei-CBG": "https://github.com/juglab/EmbedSeg/releases/download/v0.1.0/Platynereis-ISH-Nuclei-CBG.zip", # noqa - "Platynereis-Nuclei-CBG": "https://github.com/juglab/EmbedSeg/releases/download/v0.1.0/Platynereis-Nuclei-CBG.zip", -} -CHECKSUMS = { - "Mouse-Organoid-Cells-CBG": "3695ac340473900ace8c37fd7f3ae0d37217de9f2b86c2341f36b1727825e48b", - "Mouse-Skull-Nuclei-CBG": "3600ec261a48bf953820e0536cacd0bb8a5141be6e7435a4cb0fffeb0caf594e", - "Platynereis-ISH-Nuclei-CBG": "bc9284df6f6d691a8e81b47310d95617252cc98ebf7daeab55801b330ba921e0", - "Platynereis-Nuclei-CBG": "448cb7b46f2fe7d472795e05c8d7dfb40f259d94595ad2cfd256bc2aa4ab3be7", -} - - -def get_embedseg_data(path: Union[os.PathLike, str], name: str, download: bool) -> str: - """Download the EmbedSeg training data. - - Args: - path: Filepath to a folder where the downloaded data will be saved. - name: Name of the dataset to download. - download: Whether to download the data if it is not present. - - Returns: - The filepath to the training data. - """ - if name not in URLS: - raise ValueError(f"The dataset name must be in {list(URLS.keys())}. You provided {name}.") - - url = URLS[name] - checksum = CHECKSUMS[name] - - data_path = os.path.join(path, name) - if os.path.exists(data_path): - return data_path - - os.makedirs(path, exist_ok=True) - zip_path = os.path.join(path, f"{name}.zip") - util.download_source(zip_path, url, download, checksum) - util.unzip(zip_path, path, True) - - return data_path - - -def get_embedseg_dataset( - path: Union[os.PathLike, str], - patch_shape: Tuple[int, int], - name: str, - split: str = "train", - download: bool = False, - **kwargs -) -> Dataset: - """Get an EmbedSeg dataset for 3d fluorescence microscopy segmentation. - - Args: - path: Filepath to a folder where the downloaded data will be saved. - patch_shape: The patch shape to use for training. - name: Name of the dataset to download. - split: The split to use for the dataset. - download: Whether to download the data if it is not present. - kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`. - - Returns: - The segmentation dataset. - """ - data_root = get_embedseg_data(path, name, download) - - raw_paths = sorted(glob(os.path.join(data_root, split, "images", "*.tif"))) - label_paths = sorted(glob(os.path.join(data_root, split, "masks", "*.tif"))) - assert len(raw_paths) > 0 - assert len(raw_paths) == len(label_paths) - - raw_key, label_key = None, None - - return torch_em.default_segmentation_dataset( - raw_paths, raw_key, label_paths, label_key, patch_shape, **kwargs - ) - - -def get_embedseg_loader( - path: Union[os.PathLike, str], - patch_shape: Tuple[int, int], - batch_size: int, - name: str, - split: str = "train", - download: bool = False, - **kwargs -) -> DataLoader: - """Get an EmbedSeg dataloader for 3d fluorescence microscopy segmentation. - - Args: - path: Filepath to a folder where the downloaded data will be saved. - patch_shape: The patch shape to use for training. - batch_size: The batch size for training. - name: Name of the dataset to download. - split: The split to use for the dataset. - download: Whether to download the data if it is not present. - kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader. - - Returns: - The DataLoader. - """ - ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs) - dataset = get_embedseg_dataset( - path, name=name, split=split, patch_shape=patch_shape, download=download, **ds_kwargs, - ) - loader = torch_em.get_data_loader(dataset, batch_size=batch_size, **loader_kwargs) - return loader diff --git a/torch_em/data/datasets/light_microscopy/organoidnet.py b/torch_em/data/datasets/light_microscopy/organoidnet.py deleted file mode 100644 index 765444da..00000000 --- a/torch_em/data/datasets/light_microscopy/organoidnet.py +++ /dev/null @@ -1,109 +0,0 @@ -import os -import shutil -import zipfile - -from glob import glob -from typing import Tuple, Union - -import torch_em - -from .. import util - - -URL = "https://zenodo.org/records/10643410/files/OrganoIDNetData.zip?download=1" -CHECKSUM = "3cd9239bf74bda096ecb5b7bdb95f800c7fa30b9937f9aba6ddf98d754cbfa3d" - - -def get_organoidnet_data(path, split, download): - splits = ["Training", "Validation", "Test"] - assert split in splits - - os.makedirs(path, exist_ok=True) - - data_dir = os.path.join(path, split) - if os.path.exists(data_dir): - return data_dir - - # Download and extraction. - zip_path = os.path.join(path, "OrganoIDNetData.zip") - util.download_source(path=zip_path, url=URL, download=download, checksum=CHECKSUM) - - # Only "Training", "Test", "Validation" from the zip are relevant and need to be extracted. - # They are in "/OrganoIDNetData/Dataset/" - prefix = "OrganoIDNetData/Dataset/" - for dl_split in splits: - - dl_prefix = prefix + dl_split - - with zipfile.ZipFile(zip_path) as archive: - for ff in archive.namelist(): - if ff.startswith(dl_prefix): - archive.extract(ff, path) - - for dl_split in splits: - shutil.move( - os.path.join(path, "OrganoIDNetData/Dataset", dl_split), - os.path.join(path, dl_split) - ) - - assert os.path.exists(data_dir) - - os.remove(zip_path) - return data_dir - - -def _get_data_paths(path, split, download): - data_dir = get_organoidnet_data(path=path, split=split, download=download) - - image_paths = sorted(glob(os.path.join(data_dir, "Images", "*.tif"))) - label_paths = sorted(glob(os.path.join(data_dir, "Masks", "*.tif"))) - - return image_paths, label_paths - - -def get_organoidnet_dataset( - path: Union[os.PathLike, str], - split: str, - patch_shape: Tuple[int, int], - download: bool = False, - **kwargs -): - """Dataset for the segmentation of panceratic organoids. - - This dataset is from the publication https://doi.org/10.1007/s13402-024-00958-2. - Please cite it if you use this dataset for a publication. - """ - image_paths, label_paths = _get_data_paths(path=path, split=split, download=download) - - return torch_em.default_segmentation_dataset( - raw_paths=image_paths, - raw_key=None, - label_paths=label_paths, - label_key=None, - patch_shape=patch_shape, - is_seg_dataset=False, - **kwargs - ) - - -def get_organoidnet_loader( - path: Union[os.PathLike, str], - split: str, - patch_shape: Tuple[int, int], - batch_size: int, - download: bool = False, - **kwargs -): - """Dataloader for the segmentation of pancreatic organoids in brightfield images. - See `get_organoidnet_dataset` for details. - """ - ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs) - dataset = get_organoidnet_dataset( - path=path, - split=split, - patch_shape=patch_shape, - download=download, - **ds_kwargs - ) - loader = torch_em.get_data_loader(dataset=dataset, batch_size=batch_size, **loader_kwargs) - return loader diff --git a/torch_em/data/datasets/medical/__init__.py b/torch_em/data/datasets/medical/__init__.py index 8895b82f..84c77d76 100644 --- a/torch_em/data/datasets/medical/__init__.py +++ b/torch_em/data/datasets/medical/__init__.py @@ -1,32 +1,19 @@ -from .acdc import get_acdc_dataset, get_acdc_loader from .autopet import get_autopet_loader -from .amos import get_amos_dataset, get_amos_loader from .btcv import get_btcv_dataset, get_btcv_loader from .busi import get_busi_dataset, get_busi_loader from .camus import get_camus_dataset, get_camus_loader -from .cbis_ddsm import get_cbis_ddsm_dataset, get_cbis_ddsm_loader -from .cholecseg8k import get_cholecseg8k_dataset, get_cholecseg8k_loader -from .covid19_seg import get_covid19_seg_dataset, get_covid19_seg_loader -from .dca1 import get_dca1_dataset, get_dca1_loader from .drive import get_drive_dataset, get_drive_loader -from .duke_liver import get_duke_liver_dataset, get_duke_liver_loader from .feta24 import get_feta24_dataset, get_feta24_loader -from .han_seg import get_han_seg_dataset, get_han_seg_loader from .idrid import get_idrid_dataset, get_idrid_loader -from .isic import get_isic_dataset, get_isic_loader from .jnuifm import get_jnuifm_dataset, get_jnuifm_loader -from .m2caiseg import get_m2caiseg_dataset, get_m2caiseg_loader from .micro_usp import get_micro_usp_dataset, get_micro_usp_loader from .montgomery import get_montgomery_dataset, get_montgomery_loader from .msd import get_msd_dataset, get_msd_loader from .oimhs import get_oimhs_dataset, get_oimhs_loader from .osic_pulmofib import get_osic_pulmofib_dataset, get_osic_pulmofib_loader from .papila import get_papila_dataset, get_papila_loader -from .piccolo import get_piccolo_dataset, get_piccolo_loader from .plethora import get_plethora_dataset, get_plethora_loader from .sa_med2d import get_sa_med2d_dataset, get_sa_med2d_loader from .sega import get_sega_dataset, get_sega_loader from .siim_acr import get_siim_acr_dataset, get_siim_acr_loader -from .spider import get_spider_dataset, get_spider_loader -from .toothfairy import get_toothfairy_dataset, get_toothfairy_loader from .uwaterloo_skin import get_uwaterloo_skin_dataset, get_uwaterloo_skin_loader diff --git a/torch_em/data/datasets/medical/acdc.py b/torch_em/data/datasets/medical/acdc.py deleted file mode 100644 index 448708bb..00000000 --- a/torch_em/data/datasets/medical/acdc.py +++ /dev/null @@ -1,106 +0,0 @@ -import os -from glob import glob -from natsort import natsorted -from typing import Union, Tuple - -import torch_em - -from .. import util -from ... import ConcatDataset - - -URL = "https://humanheart-project.creatis.insa-lyon.fr/database/api/v1/collection/637218c173e9f0047faa00fb/download" -CHECKSUM = "2787e08b0d3525cbac710fc3bdf69ee7c5fd7446472e49db8bc78548802f6b5e" - - -def get_acdc_data(path, download): - os.makedirs(path, exist_ok=True) - - zip_path = os.path.join(path, "ACDC.zip") - trg_dir = os.path.join(path, "ACDC") - if os.path.exists(trg_dir): - return trg_dir - - util.download_source(path=zip_path, url=URL, download=download, checksum=CHECKSUM) - util.unzip(zip_path=zip_path, dst=path, remove=False) - return trg_dir - - -def _get_acdc_paths(path, split, download): - root_dir = get_acdc_data(path=path, download=download) - - if split == "train": - input_dir = os.path.join(root_dir, "database", "training") - else: - input_dir = os.path.join(root_dir, "database", "testing") - - all_patient_dirs = natsorted(glob(os.path.join(input_dir, "patient*"))) - - image_paths, gt_paths = [], [] - for per_patient_dir in all_patient_dirs: - # the volumes with frames are for particular time frames (end diastole (ED) and end systole (ES)) - # the "frames" denote - ED and ES phase instances, which have manual segmentations. - all_volumes = glob(os.path.join(per_patient_dir, "*frame*.nii.gz")) - for vol_path in all_volumes: - sres = vol_path.find("gt") - if sres == -1: # this means the search was invalid, hence it's the mri volume - image_paths.append(vol_path) - else: # this means that the search went through, hence it's the ground truth volume - gt_paths.append(vol_path) - - return natsorted(image_paths), natsorted(gt_paths) - - -def get_acdc_dataset( - path: Union[os.PathLike, str], - split: str, - patch_shape: Tuple[int, int], - download: bool = False, - **kwargs -): - """Dataset fir multi-structure segmentation in cardiac MRI. - - The labels have the following mapping: - - 0 (background), 1 (right ventricle cavity),2 (myocardium), 3 (left ventricle cavity) - - The database is located at - https://humanheart-project.creatis.insa-lyon.fr/database/#collection/637218c173e9f0047faa00fb - - The dataset is from the publication https://doi.org/10.1109/TMI.2018.2837502 - - Please cite it if you use this dataset for a publication. - """ - assert split in ["train", "test"], f"{split} is not a valid split." - - image_paths, gt_paths = _get_acdc_paths(path=path, split=split, download=download) - - all_datasets = [] - for image_path, gt_path in zip(image_paths, gt_paths): - per_vol_ds = torch_em.default_segmentation_dataset( - raw_paths=image_path, - raw_key="data", - label_paths=gt_path, - label_key="data", - patch_shape=patch_shape, - is_seg_dataset=True, - **kwargs - ) - all_datasets.append(per_vol_ds) - - return ConcatDataset(*all_datasets) - - -def get_acdc_loader( - path: Union[os.PathLike, str], - split: str, - patch_shape: Tuple[int, int], - batch_size: int, - download: bool = False, - **kwargs -): - """Dataloader for multi-structure segmentation in cardiac MRI, See `get_acdc_dataset` for details. - """ - ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs) - dataset = get_acdc_dataset(path=path, split=split, patch_shape=patch_shape, download=download, **ds_kwargs) - loader = torch_em.get_data_loader(dataset=dataset, batch_size=batch_size, **loader_kwargs) - return loader diff --git a/torch_em/data/datasets/medical/amos.py b/torch_em/data/datasets/medical/amos.py deleted file mode 100644 index 299617f9..00000000 --- a/torch_em/data/datasets/medical/amos.py +++ /dev/null @@ -1,132 +0,0 @@ -import os -from glob import glob -from pathlib import Path -from typing import Union, Tuple, Optional - -import torch_em - -from .. import util - - -URL = "https://zenodo.org/records/7155725/files/amos22.zip" -CHECKSUM = "d2fbf2c31abba9824d183f05741ce187b17905b8cca64d1078eabf1ba96775c2" - - -def get_amos_data(path, download, remove_zip=False): - os.makedirs(path, exist_ok=True) - - data_dir = os.path.join(path, "amos22") - if os.path.exists(data_dir): - return data_dir - - zip_path = os.path.join(path, "amos22.zip") - - util.download_source(path=zip_path, url=URL, download=download, checksum=CHECKSUM) - util.unzip(zip_path=zip_path, dst=path, remove=remove_zip) - - return data_dir - - -def _get_amos_paths(path, split, modality, download): - data_dir = get_amos_data(path=path, download=download) - - if split == "train": - im_dir, gt_dir = "imagesTr", "labelsTr" - elif split == "val": - im_dir, gt_dir = "imagesVa", "labelsVa" - elif split == "test": - im_dir, gt_dir = "imagesTs", "labelsTs" - else: - raise ValueError(f"'{split}' is not a valid split.") - - image_paths = sorted(glob(os.path.join(data_dir, im_dir, "*.nii.gz"))) - gt_paths = sorted(glob(os.path.join(data_dir, gt_dir, "*.nii.gz"))) - - if modality is None: - chosen_image_paths, chosen_gt_paths = image_paths, gt_paths - else: - ct_image_paths, ct_gt_paths = [], [] - mri_image_paths, mri_gt_paths = [], [] - for image_path, gt_path in zip(image_paths, gt_paths): - patient_id = Path(image_path.split(".")[0]).stem - id_value = int(patient_id.split("_")[-1]) - - is_ct = id_value < 500 - - if is_ct: - ct_image_paths.append(image_path) - ct_gt_paths.append(gt_path) - else: - mri_image_paths.append(image_path) - mri_gt_paths.append(gt_path) - - if modality.upper() == "CT": - chosen_image_paths, chosen_gt_paths = ct_image_paths, ct_gt_paths - elif modality.upper() == "MRI": - chosen_image_paths, chosen_gt_paths = mri_image_paths, mri_gt_paths - else: - raise ValueError(f"'{modality}' is not a valid modality.") - - return chosen_image_paths, chosen_gt_paths - - -def get_amos_dataset( - path: Union[os.PathLike, str], - split: str, - patch_shape: Tuple[int, ...], - modality: Optional[str] = None, - resize_inputs: bool = False, - download: bool = False, - **kwargs -): - """Dataset for abdominal multi-organ segmentation in CT and MRI scans. - - The database is located at https://doi.org/10.5281/zenodo.7155725 - The dataset is from AMOS 2022 Challenge - https://doi.org/10.48550/arXiv.2206.08023 - Please cite it if you use this dataset for publication. - """ - image_paths, gt_paths = _get_amos_paths(path=path, split=split, modality=modality, download=download) - - if resize_inputs: - resize_kwargs = {"patch_shape": patch_shape, "is_rgb": False} - kwargs, patch_shape = util.update_kwargs_for_resize_trafo( - kwargs=kwargs, patch_shape=patch_shape, resize_inputs=resize_inputs, resize_kwargs=resize_kwargs - ) - - dataset = torch_em.default_segmentation_dataset( - raw_paths=image_paths, - raw_key="data", - label_paths=gt_paths, - label_key="data", - patch_shape=patch_shape, - is_seg_dataset=True, - **kwargs - ) - - return dataset - - -def get_amos_loader( - path: Union[os.PathLike, str], - split: str, - patch_shape: Tuple[int, ...], - batch_size: int, - modality: Optional[str] = None, - resize_inputs: bool = False, - download: bool = False, - **kwargs -): - """Dataloader for abdominal multi-organ segmentation in CT and MRI scans. See `get_amos_dataset` for details. - """ - ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs) - dataset = get_amos_dataset( - path=path, - split=split, - patch_shape=patch_shape, - modality=modality, - resize_inputs=resize_inputs, - download=download, - **ds_kwargs - ) - loader = torch_em.get_data_loader(dataset=dataset, batch_size=batch_size, **loader_kwargs) - return loader diff --git a/torch_em/data/datasets/medical/cbis_ddsm.py b/torch_em/data/datasets/medical/cbis_ddsm.py deleted file mode 100644 index 81a5830b..00000000 --- a/torch_em/data/datasets/medical/cbis_ddsm.py +++ /dev/null @@ -1,120 +0,0 @@ -import os -from glob import glob -from natsort import natsorted -from typing import Union, Tuple, Literal, Optional - -import torch_em - -from .. import util - - -def get_cbis_ddsm_data(path, split, task, tumour_type, download): - os.makedirs(path, exist_ok=True) - - assert split in ["Train", "Test"] - - if task is None: - task = "*" - else: - assert task in ["Calc", "Mass"] - - if tumour_type is None: - tumour_type = "*" - else: - assert tumour_type in ["MALIGNANT", "BENIGN"] - - data_dir = os.path.join(path, "DATA") - if os.path.exists(data_dir): - return os.path.join(path, "DATA", task, split, tumour_type) - - util.download_source_kaggle( - path=path, dataset_name="mohamedbenticha/cbis-ddsm/", download=download, - ) - zip_path = os.path.join(path, "cbis-ddsm.zip") - util.unzip(zip_path=zip_path, dst=path) - return os.path.join(path, "DATA", task, split, tumour_type) - - -def _get_cbis_ddsm_paths(path, split, task, tumour_type, download): - data_dir = get_cbis_ddsm_data( - path=path, - split=split, - task=task, - tumour_type=tumour_type, - download=download - ) - - image_paths = natsorted(glob(os.path.join(data_dir, "*_FULL_*.png"))) - gt_paths = natsorted(glob(os.path.join(data_dir, "*_MASK_*.png"))) - - assert len(image_paths) == len(gt_paths) - - return image_paths, gt_paths - - -def get_cbis_ddsm_dataset( - path: Union[os.PathLike, str], - patch_shape: Tuple[int, int], - split: Literal["Train", "Test"], - task: Optional[Literal["Calc", "Mass"]] = None, - tumour_type: Optional[Literal["MALIGNANT", "BENIGN"]] = None, - resize_inputs: bool = False, - download: bool = False, - **kwargs -): - """Dataset for segmentation of calcification and mass in mammography. - - This dataset is a preprocessed version of https://www.cancerimagingarchive.net/collection/cbis-ddsm/ available - at https://www.kaggle.com/datasets/mohamedbenticha/cbis-ddsm/data. The related publication is: - - https://doi.org/10.1038/sdata.2017.177 - - Please cite it if you use this dataset in a publication. - """ - image_paths, gt_paths = _get_cbis_ddsm_paths( - path=path, split=split, task=task, tumour_type=tumour_type, download=download - ) - - if resize_inputs: - resize_kwargs = {"patch_shape": patch_shape, "is_rgb": False} - kwargs, patch_shape = util.update_kwargs_for_resize_trafo( - kwargs=kwargs, patch_shape=patch_shape, resize_inputs=resize_inputs, resize_kwargs=resize_kwargs - ) - - dataset = torch_em.default_segmentation_dataset( - raw_paths=image_paths, - raw_key=None, - label_paths=gt_paths, - label_key=None, - patch_shape=patch_shape, - is_seg_dataset=False, - **kwargs - ) - return dataset - - -def get_cbis_ddsm_loader( - path: Union[os.PathLike, str], - patch_shape: Tuple[int, int], - batch_size: int, - split: Literal["Train", "Test"], - task: Optional[Literal["Calc", "Mass"]] = None, - tumour_type: Optional[Literal["MALIGNANT", "BENIGN"]] = None, - resize_inputs: bool = False, - download: bool = False, - **kwargs -): - """Dataloader for segmentation of calcification and mass in mammography. See `get_cbis_ddsm_dataset` for details. - """ - ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs) - dataset = get_cbis_ddsm_dataset( - path=path, - patch_shape=patch_shape, - split=split, - task=task, - tumour_type=tumour_type, - resize_inputs=resize_inputs, - download=download, - **ds_kwargs - ) - loader = torch_em.get_data_loader(dataset=dataset, batch_size=batch_size, **loader_kwargs) - return loader diff --git a/torch_em/data/datasets/medical/cholecseg8k.py b/torch_em/data/datasets/medical/cholecseg8k.py deleted file mode 100644 index f0dfaa2e..00000000 --- a/torch_em/data/datasets/medical/cholecseg8k.py +++ /dev/null @@ -1,159 +0,0 @@ -import os -import shutil -from glob import glob -from tqdm import tqdm -from pathlib import Path -from natsort import natsorted -from typing import Tuple, Union, Literal - -import numpy as np -import imageio.v3 as imageio - -import torch_em - -from .. import util - - -LABEL_MAPS = { - (255, 255, 255): 0, # small white frame around the image - (50, 50, 50): 0, # background - (11, 11, 11): 1, # abdominal wall - (21, 21, 21): 2, # liver - (13, 13, 13): 3, # gastrointestinal tract - (12, 12, 12): 4, # fat - (31, 31, 31): 5, # grasper - (23, 23, 23): 6, # connective tissue - (24, 24, 24): 7, # blood - (25, 25, 25): 8, # cystic dust - (32, 32, 32): 9, # l-hook electrocautery - (22, 22, 22): 10, # gallbladder - (33, 33, 33): 11, # hepatic vein - (5, 5, 5): 12 # liver ligament -} - - -def get_cholecseg8k_data(path, download): - os.makedirs(path, exist_ok=True) - - data_dir = os.path.join(path, "data") - if os.path.exists(data_dir): - return data_dir - - zip_path = os.path.join(path, "cholecseg8k.zip") - util.download_source_kaggle(path=zip_path, dataset_name="newslab/cholecseg8k", download=download) - util.unzip(zip_path=zip_path, dst=data_dir) - return data_dir - - -def _get_cholecseg8k_paths(path, split, download): - data_dir = get_cholecseg8k_data(path=path, download=download) - - video_dirs = natsorted(glob(os.path.join(data_dir, "video*"))) - if split == "train": - video_dirs = video_dirs[2:-2] - elif split == "val": - video_dirs = [video_dirs[1], video_dirs[-2]] - elif split == "test": - video_dirs = [video_dirs[0], video_dirs[-1]] - else: - raise ValueError(f"'{split}' is not a valid split.") - - ppdir = os.path.join(data_dir, "preprocessed", split) - if os.path.exists(ppdir): - _image_paths = natsorted(glob(os.path.join(ppdir, "images", "*"))) - _gt_paths = natsorted(glob(os.path.join(ppdir, "masks", "*"))) - return _image_paths, _gt_paths - - os.makedirs(os.path.join(ppdir, "images"), exist_ok=True) - os.makedirs(os.path.join(ppdir, "masks"), exist_ok=True) - - image_paths, gt_paths = [], [] - for video_dir in tqdm(video_dirs): - org_image_paths = natsorted(glob(os.path.join(video_dir, "video*", "*_endo.png"))) - org_gt_paths = natsorted(glob(os.path.join(video_dir, "video*", "*_endo_watershed_mask.png"))) - - for org_image_path, org_gt_path in zip(org_image_paths, org_gt_paths): - image_id = os.path.split(org_image_path)[-1] - - image_path = os.path.join(ppdir, "images", image_id) - gt_path = os.path.join(ppdir, "masks", Path(image_id).with_suffix(".tif")) - - image_paths.append(image_path) - gt_paths.append(gt_path) - - if os.path.exists(image_path) and os.path.exists(gt_path): - continue - - gt = imageio.imread(org_gt_path) - assert gt.ndim == 3 - if gt.shape[-1] != 3: # some labels have a 4th channel which has all values as 255 - print("Found a label with inconsistent format.") - # let's verify the case - assert np.unique(gt[..., -1]) == 255 - gt = gt[..., :3] - - instances = np.zeros(gt.shape[:2]) - for lmap in LABEL_MAPS: - binary_map = (gt == lmap).all(axis=2) - instances[binary_map > 0] = LABEL_MAPS[lmap] - - shutil.copy(src=org_image_path, dst=image_path) - imageio.imwrite(gt_path, instances, compression="zlib") - - return image_paths, gt_paths - - -def get_cholecseg8k_dataset( - path: Union[str, os.PathLike], - patch_shape: Tuple[int, int], - split: Literal["train", "val", "test"], - resize_inputs: bool = False, - download: bool = False, - **kwargs -): - """Dataset for segmentation of organs and instruments in endoscopy. - - This dataset is from Twinanda et al. - https://doi.org/10.48550/arXiv.1602.03012 - - This dataset is located at https://www.kaggle.com/datasets/newslab/cholecseg8k/data - - Please cite it if you use this data in a publication. - """ - image_paths, gt_paths = _get_cholecseg8k_paths(path=path, split=split, download=download) - - if resize_inputs: - resize_kwargs = {"patch_shape": patch_shape, "is_rgb": True} - kwargs, patch_shape = util.update_kwargs_for_resize_trafo( - kwargs=kwargs, patch_shape=patch_shape, resize_inputs=resize_inputs, resize_kwargs=resize_kwargs - ) - - dataset = torch_em.default_segmentation_dataset( - raw_paths=image_paths, - raw_key=None, - label_paths=gt_paths, - label_key=None, - is_seg_dataset=False, - patch_shape=patch_shape, - **kwargs - ) - - return dataset - - -def get_cholecseg8k_loader( - path: Union[str, os.PathLike], - patch_shape: Tuple[int, int], - batch_size: int, - split: Literal["train", "val", "test"], - resize_inputs: bool = False, - download: bool = False, - **kwargs -): - """Dataloader for segmentation of organs and instruments in endoscopy. See `get_cholecseg_dataset` for details. - """ - ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs) - dataset = get_cholecseg8k_dataset( - path=path, patch_shape=patch_shape, split=split, resize_inputs=resize_inputs, download=download, **ds_kwargs - ) - loader = torch_em.get_data_loader(dataset=dataset, batch_size=batch_size, **loader_kwargs) - return loader diff --git a/torch_em/data/datasets/medical/covid19_seg.py b/torch_em/data/datasets/medical/covid19_seg.py deleted file mode 100644 index 65ecc354..00000000 --- a/torch_em/data/datasets/medical/covid19_seg.py +++ /dev/null @@ -1,118 +0,0 @@ -import os -from glob import glob -from pathlib import Path -from typing import Union, Tuple, Optional - -import torch_em - -from .. import util - - -URL = { - "images": "https://zenodo.org/records/3757476/files/COVID-19-CT-Seg_20cases.zip", - "lung_and_infection": "https://zenodo.org/records/3757476/files/Lung_and_Infection_Mask.zip", - "lung": "https://zenodo.org/records/3757476/files/Lung_Mask.zip", - "infection": "https://zenodo.org/records/3757476/files/Infection_Mask.zip" -} - -CHECKSUM = { - "images": "a5060480eff9315b069b086312dac4872777901fb80d268a5a83edd9f4e7b440", - "lung_and_infection": "34f5a573cb8fb53cb15abe81868395d9addf436854826a6fd6e70c2b294f19c3", - "lung": "f060b0d0299939a6d95ddefdbfa281de1a779c4d230a5adbd32414711d6d8187", - "infection": "87901c73fdd2230260e61d2dbc57bf56026efc28264006b8ea2bf411453c1694" -} - -ZIP_FNAMES = { - "images": "COVID-19-CT-Seg_20cases.zip", - "lung_and_infection": "Lung_and_Infection_Mask.zip", - "lung": "Lung_Mask.zip", - "infection": "Infection_Mask.zip" -} - - -def get_covid19_seg_data(path, task, download): - os.makedirs(path, exist_ok=True) - - im_dir = os.path.join(path, "images", Path(ZIP_FNAMES["images"]).stem) - gt_dir = os.path.join(path, "gt", Path(ZIP_FNAMES[task]).stem) - - if os.path.exists(im_dir) and os.path.exists(gt_dir): - return im_dir, gt_dir - - im_zip_path = os.path.join(path, ZIP_FNAMES["images"]) - gt_zip_path = os.path.join(path, ZIP_FNAMES[task]) - - # download the images - util.download_source( - path=im_zip_path, url=URL["images"], download=download, checksum=CHECKSUM["images"] - ) - util.unzip(zip_path=im_zip_path, dst=im_dir, remove=False) - - # download the gt - util.download_source( - path=gt_zip_path, url=URL[task], download=download, checksum=CHECKSUM[task] - ) - util.unzip(zip_path=gt_zip_path, dst=gt_dir) - - return im_dir, gt_dir - - -def _get_covid19_seg_paths(path, task, download): - image_dir, gt_dir = get_covid19_seg_data(path=path, task=task, download=download) - - image_paths = sorted(glob(os.path.join(image_dir, "*.nii.gz"))) - gt_paths = sorted(glob(os.path.join(gt_dir, "*.nii.gz"))) - - return image_paths, gt_paths - - -def get_covid19_seg_dataset( - path: Union[os.PathLike, str], - patch_shape: Tuple[int, int], - task: Optional[str] = None, - download: bool = False, - **kwargs -): - """Dataset for lung and covid infection segmentation in CT scans. - - The database is located at https://doi.org/10.5281/zenodo.3757476. - - This dataset is from Ma et al. - https://doi.org/10.1002/mp.14676. - Please cite it if you use this dataset for a publication. - """ - if task is None: - task = "lung_and_infection" - else: - assert task in ["lung", "infection", "lung_and_infection"], f"{task} is not a valid task." - - image_paths, gt_paths = _get_covid19_seg_paths(path, task, download) - - dataset = torch_em.default_segmentation_dataset( - raw_paths=image_paths, - raw_key="data", - label_paths=gt_paths, - label_key="data", - patch_shape=patch_shape, - is_seg_dataset=True, - **kwargs - ) - - return dataset - - -def get_covid19_seg_loader( - path: Union[os.PathLike, str], - batch_size: int, - patch_shape: Tuple[int, int], - task: Optional[str] = None, - download: bool = False, - **kwargs -): - """Dataloader for lung and covid infection segmentation in CT scans. See `get_covid19_seg_dataset` for details. - """ - ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs) - dataset = get_covid19_seg_dataset( - path=path, patch_shape=patch_shape, task=task, download=download, **ds_kwargs - ) - loader = torch_em.get_data_loader(dataset=dataset, batch_size=batch_size, **loader_kwargs) - return loader diff --git a/torch_em/data/datasets/medical/dca1.py b/torch_em/data/datasets/medical/dca1.py deleted file mode 100644 index df06b88c..00000000 --- a/torch_em/data/datasets/medical/dca1.py +++ /dev/null @@ -1,105 +0,0 @@ -import os -from glob import glob -from natsort import natsorted -from typing import Union, Tuple, Literal - -import torch_em - -from .. import util - - -URL = "http://personal.cimat.mx:8181/~ivan.cruz/DB_Angiograms_files/DB_Angiograms_134.zip" -CHECKSUM = "7161638a6e92c6a6e47a747db039292c8a1a6bad809aac0d1fd16a10a6f22a11" - - -def get_dca1_data(path, download): - os.makedirs(path, exist_ok=True) - - data_dir = os.path.join(path, "Database_134_Angiograms") - if os.path.exists(data_dir): - return data_dir - - zip_path = os.path.join(path, "DB_Angiograms_134.zip") - util.download_source(path=zip_path, url=URL, download=download, checksum=CHECKSUM) - util.unzip(zip_path=zip_path, dst=path) - return data_dir - - -def _get_dca1_paths(path, split, download): - data_dir = get_dca1_data(path=path, download=download) - - image_paths, gt_paths = [], [] - for image_path in natsorted(glob(os.path.join(data_dir, "*.pgm"))): - if image_path.endswith("_gt.pgm"): - gt_paths.append(image_path) - else: - image_paths.append(image_path) - - image_paths, gt_paths = natsorted(image_paths), natsorted(gt_paths) - - if split == "train": # first 85 images - image_paths, gt_paths = image_paths[:-49], gt_paths[:-49] - elif split == "val": # 15 images - image_paths, gt_paths = image_paths[-49:-34], gt_paths[-49:-34] - elif split == "test": # last 34 images - image_paths, gt_paths = image_paths[-34:], gt_paths[-34:] - else: - raise ValueError(f"'{split}' is not a valid split.") - - return image_paths, gt_paths - - -def get_dca1_dataset( - path: Union[os.PathLike, str], - patch_shape: Tuple[int, int], - split: Literal["train", "val", "test"], - resize_inputs: bool = False, - download: bool = False, - **kwargs -): - """Dataset for segmentation of coronary ateries in x-ray angiography. - - This dataset is from Cervantes-Sanchez et al. - https://doi.org/10.3390/app9245507. - - The database is located at http://personal.cimat.mx:8181/~ivan.cruz/DB_Angiograms.html. - - Please cite it if you use this dataset in a publication. - """ - image_paths, gt_paths = _get_dca1_paths(path=path, split=split, download=download) - - if resize_inputs: - resize_kwargs = {"patch_shape": patch_shape, "is_rgb": False} - kwargs, patch_shape = util.update_kwargs_for_resize_trafo( - kwargs=kwargs, patch_shape=patch_shape, resize_inputs=resize_inputs, resize_kwargs=resize_kwargs - ) - - dataset = torch_em.default_segmentation_dataset( - raw_paths=image_paths, - raw_key=None, - label_paths=gt_paths, - label_key=None, - patch_shape=patch_shape, - is_seg_dataset=False, - **kwargs - ) - - return dataset - - -def get_dca1_loader( - path: Union[os.PathLike, str], - patch_shape: Tuple[int, int], - batch_size: int, - split: Literal["train", "val", "test"], - resize_inputs: bool = False, - download: bool = False, - **kwargs -): - """Dataloader for segmentation of coronary ateries in x-ray angiography. See `get_dca1_dataset` for details. - """ - ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs) - dataset = get_dca1_dataset( - path=path, patch_shape=patch_shape, split=split, resize_inputs=resize_inputs, download=download, **ds_kwargs - ) - loader = torch_em.get_data_loader(dataset=dataset, batch_size=batch_size, **loader_kwargs) - return loader diff --git a/torch_em/data/datasets/medical/drive.py b/torch_em/data/datasets/medical/drive.py index da06a2a1..4c097260 100644 --- a/torch_em/data/datasets/medical/drive.py +++ b/torch_em/data/datasets/medical/drive.py @@ -6,8 +6,10 @@ import imageio.v3 as imageio import torch_em +from torch_em.transform.generic import ResizeInputs from .. import util +from ... import ImageCollectionDataset URL = { @@ -58,28 +60,18 @@ def _get_drive_ground_truth(data_dir): return neu_gt_paths -def _get_drive_paths(path, split, download): +def _get_drive_paths(path, download): data_dir = get_drive_data(path=path, download=download) image_paths = sorted(glob(os.path.join(data_dir, "images", "*.tif"))) gt_paths = _get_drive_ground_truth(data_dir) - if split == "train": - image_paths, gt_paths = image_paths[:10], gt_paths[:10] - elif split == "val": - image_paths, gt_paths = image_paths[10:14], gt_paths[10:14] - elif split == "test": - image_paths, gt_paths = image_paths[14:], gt_paths[14:] - else: - raise ValueError(f"'{split}' is not a valid split.") - return image_paths, gt_paths def get_drive_dataset( path: Union[os.PathLike, str], patch_shape: Tuple[int, int], - split: str, resize_inputs: bool = False, download: bool = False, **kwargs @@ -92,21 +84,22 @@ def get_drive_dataset( Please cite it if you use this dataset for a publication. """ - image_paths, gt_paths = _get_drive_paths(path=path, split=split, download=download) + image_paths, gt_paths = _get_drive_paths(path=path, download=download) if resize_inputs: - resize_kwargs = {"patch_shape": patch_shape, "is_rgb": True} - kwargs, patch_shape = util.update_kwargs_for_resize_trafo( - kwargs=kwargs, patch_shape=patch_shape, resize_inputs=resize_inputs, resize_kwargs=resize_kwargs - ) + raw_trafo = ResizeInputs(target_shape=patch_shape, is_rgb=True) + label_trafo = ResizeInputs(target_shape=patch_shape, is_label=True) + patch_shape = None + else: + patch_shape = patch_shape + raw_trafo, label_trafo = None, None - dataset = torch_em.default_segmentation_dataset( - raw_paths=image_paths, - raw_key=None, - label_paths=gt_paths, - label_key=None, + dataset = ImageCollectionDataset( + raw_image_paths=image_paths, + label_image_paths=gt_paths, patch_shape=patch_shape, - is_seg_dataset=False, + raw_transform=raw_trafo, + label_transform=label_trafo, **kwargs ) @@ -116,7 +109,6 @@ def get_drive_dataset( def get_drive_loader( path: Union[os.PathLike, str], patch_shape: Tuple[int, int], - split: str, batch_size: int, resize_inputs: bool = False, download: bool = False, @@ -126,7 +118,7 @@ def get_drive_loader( """ ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs) dataset = get_drive_dataset( - path=path, patch_shape=patch_shape, split=split, resize_inputs=resize_inputs, download=download, **ds_kwargs + path=path, patch_shape=patch_shape, resize_inputs=resize_inputs, download=download, **ds_kwargs ) loader = torch_em.get_data_loader(dataset=dataset, batch_size=batch_size, **loader_kwargs) return loader diff --git a/torch_em/data/datasets/medical/duke_liver.py b/torch_em/data/datasets/medical/duke_liver.py deleted file mode 100644 index 3cc897c9..00000000 --- a/torch_em/data/datasets/medical/duke_liver.py +++ /dev/null @@ -1,162 +0,0 @@ -import os -from glob import glob -from tqdm import tqdm -from natsort import natsorted - -import numpy as np - -import torch_em - -from .. import util - - -def get_duke_liver_data(path, download): - """The dataset is located at https://doi.org/10.5281/zenodo.7774566. - - Follow the instructions below to get access to the dataset. - - Visit the zenodo site attached above. - - Send a request message alongwith some details to get access to the dataset. - - The authors would accept the request, then you can access the dataset. - - Next, download the `Segmentation.zip` file and provide the path where the zip file is stored. - """ - if download: - raise NotImplementedError( - "Automatic download for Duke Liver dataset is not possible. See `get_duke_liver_data` for details." - ) - - data_dir = os.path.join(path, "data", "Segmentation") - if os.path.exists(data_dir): - return data_dir - - zip_path = os.path.join(path, "Segmentation.zip") - util.unzip(zip_path=zip_path, dst=os.path.join(path, "data"), remove=False) - return data_dir - - -def _preprocess_data(path, data_dir): - preprocess_dir = os.path.join(path, "data", "preprocessed") - - if os.path.exists(preprocess_dir): - _image_paths = natsorted(glob(os.path.join(preprocess_dir, "images", "*.nii.gz"))) - _gt_paths = natsorted(glob(os.path.join(preprocess_dir, "masks", "*.nii.gz"))) - return _image_paths, _gt_paths - - os.makedirs(os.path.join(preprocess_dir, "images"), exist_ok=True) - os.makedirs(os.path.join(preprocess_dir, "masks"), exist_ok=True) - - import pydicom as dicom - import nibabel as nib - - image_paths, gt_paths = [], [] - for patient_dir in tqdm(glob(os.path.join(data_dir, "00*"))): - patient_id = os.path.split(patient_dir)[-1] - - for sub_id_dir in glob(os.path.join(patient_dir, "*")): - sub_id = os.path.split(sub_id_dir)[-1] - - image_path = os.path.join(preprocess_dir, "images", f"{patient_id}_{sub_id}.nii.gz") - gt_path = os.path.join(preprocess_dir, "masks", f"{patient_id}_{sub_id}.nii.gz") - - image_paths.append(image_path) - gt_paths.append(gt_path) - - if os.path.exists(image_path) and os.path.exists(gt_path): - continue - - image_slice_paths = natsorted(glob(os.path.join(sub_id_dir, "images", "*.dicom"))) - gt_slice_paths = natsorted(glob(os.path.join(sub_id_dir, "masks", "*.dicom"))) - - images, gts = [], [] - for image_slice_path, gt_slice_path in zip(image_slice_paths, gt_slice_paths): - image_slice = dicom.dcmread(image_slice_path).pixel_array - gt_slice = dicom.dcmread(gt_slice_path).pixel_array - - images.append(image_slice) - gts.append(gt_slice) - - image = np.stack(images).transpose(1, 2, 0) - gt = np.stack(gts).transpose(1, 2, 0) - - assert image.shape == gt.shape - - image = nib.Nifti2Image(image, np.eye(4)) - gt = nib.Nifti2Image(gt, np.eye(4)) - - nib.save(image, image_path) - nib.save(gt, gt_path) - - return natsorted(image_paths), natsorted(gt_paths) - - -def _get_duke_liver_paths(path, split, download): - data_dir = get_duke_liver_data(path=path, download=download) - - image_paths, gt_paths = _preprocess_data(path=path, data_dir=data_dir) - - if split == "train": - image_paths, gt_paths = image_paths[:250], gt_paths[:250] - elif split == "val": - image_paths, gt_paths = image_paths[250:260], gt_paths[250:260] - elif split == "test": - image_paths, gt_paths = image_paths[260:], gt_paths[260:] - else: - raise ValueError(f"'{split}' is not a valid split.") - - return image_paths, gt_paths - - -def get_duke_liver_dataset( - path, - patch_shape, - split, - resize_inputs=False, - download=False, - **kwargs -): - """Dataset for segmentation of liver in MRI. - - This dataset is from Macdonald et al. - https://doi.org/10.1148/ryai.220275. - - The dataset is located at https://doi.org/10.5281/zenodo.7774566 (see `get_duke_liver_dataset` for details). - - Please cite it if you use it in a publication. - """ - - image_paths, gt_paths = _get_duke_liver_paths(path=path, split=split, download=download) - - if resize_inputs: - resize_kwargs = {"patch_shape": patch_shape, "is_rgb": False} - kwargs, patch_shape = util.update_kwargs_for_resize_trafo( - kwargs=kwargs, patch_shape=patch_shape, resize_inputs=resize_inputs, resize_kwargs=resize_kwargs - ) - - dataset = torch_em.default_segmentation_dataset( - raw_paths=image_paths, - raw_key="data", - label_paths=gt_paths, - label_key="data", - is_seg_dataset=True, - patch_shape=patch_shape, - **kwargs - ) - - return dataset - - -def get_duke_liver_loader( - path, - patch_shape, - batch_size, - split, - resize_inputs=False, - download=False, - **kwargs -): - """Dataloader for segmentation of liver in MRI. See `get_duke_liver_dataset` for details. - """ - ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs) - dataset = get_duke_liver_dataset( - path=path, patch_shape=patch_shape, split=split, resize_inputs=resize_inputs, download=download, **ds_kwargs - ) - loader = torch_em.get_data_loader(dataset=dataset, batch_size=batch_size, **loader_kwargs) - return loader diff --git a/torch_em/data/datasets/medical/han_seg.py b/torch_em/data/datasets/medical/han_seg.py deleted file mode 100644 index 0f558bb8..00000000 --- a/torch_em/data/datasets/medical/han_seg.py +++ /dev/null @@ -1,128 +0,0 @@ -import os -from glob import glob -from tqdm import tqdm -from pathlib import Path -from natsort import natsorted -from typing import Union, Tuple - -import nrrd -import numpy as np -import nibabel as nib - -import torch_em - -from .. import util - - -URL = "https://zenodo.org/records/7442914/files/HaN-Seg.zip" -CHECKSUM = "20226dd717f334dc1b1afe961b3375f946fa56b64a80bf5349128f90c0bbfa5f" - - -def get_han_seg_data(path, download): - os.makedirs(path, exist_ok=True) - - data_dir = os.path.join(path, "HaN-Seg") - if os.path.exists(data_dir): - return data_dir - - zip_path = os.path.join(path, "HaN-Seg.zip") - util.download_source( - path=zip_path, url=URL, download=download, checksum=CHECKSUM - ) - util.unzip(zip_path=zip_path, dst=path, remove=False) - - return data_dir - - -def _get_han_seg_paths(path, download): - data_dir = get_han_seg_data(path=path, download=download) - - image_dir = os.path.join(data_dir, "set_1", "preprocessed", "images") - gt_dir = os.path.join(data_dir, "set_1", "preprocessed", "ground_truth") - os.makedirs(image_dir, exist_ok=True) - os.makedirs(gt_dir, exist_ok=True) - - image_paths, gt_paths = [], [] - all_case_dirs = natsorted(glob(os.path.join(data_dir, "set_1", "case_*"))) - for case_dir in tqdm(all_case_dirs): - image_path = os.path.join(image_dir, f"{os.path.split(case_dir)[-1]}_ct.nii.gz") - gt_path = os.path.join(gt_dir, f"{os.path.split(case_dir)[-1]}.nii.gz") - image_paths.append(image_path) - gt_paths.append(gt_path) - if os.path.exists(image_path) and os.path.exists(gt_path): - continue - - all_nrrd_paths = natsorted(glob(os.path.join(case_dir, "*.nrrd"))) - all_volumes, all_volume_ids = [], [] - for nrrd_path in all_nrrd_paths: - image_id = Path(nrrd_path).stem - - # we skip the MRI volumes - if image_id.endswith("_MR_T1"): - continue - - data, header = nrrd.read(nrrd_path) - all_volumes.append(data) - all_volume_ids.append(image_id) - - raw = all_volumes[0] - raw = nib.Nifti2Image(raw, np.eye(4)) - nib.save(raw, image_path) - - gt = np.zeros(raw.shape) - for idx, per_organ in enumerate(all_volumes[1:], 1): - gt[per_organ > 0] = idx - gt = nib.Nifti2Image(gt, np.eye(4)) - nib.save(gt, gt_path) - - return image_paths, gt_paths - - -def get_han_seg_dataset( - path: Union[os.PathLike, str], - patch_shape: Tuple[int, ...], - resize_inputs: bool = False, - download: bool = False, - **kwargs -): - """Dataset for head and neck organ-at-rish segmentation in CT scans. - - This dataset is from Podobnik et al. - https://doi.org/10.1002/mp.16197 - Please cite it if you use it in a publication. - """ - image_paths, gt_paths = _get_han_seg_paths(path=path, download=download) - - if resize_inputs: - resize_kwargs = {"patch_shape": patch_shape, "is_rgb": False} - kwargs, patch_shape = util.update_kwargs_for_resize_trafo( - kwargs=kwargs, patch_shape=patch_shape, resize_inputs=resize_inputs, resize_kwargs=resize_kwargs, - ) - - dataset = torch_em.default_segmentation_dataset( - raw_paths=image_paths, - raw_key="data", - label_paths=gt_paths, - label_key="data", - patch_shape=patch_shape, - **kwargs - ) - - return dataset - - -def get_han_seg_loader( - path: Union[os.PathLike, str], - patch_shape: Tuple[int, ...], - batch_size: int, - resize_inputs: bool = False, - download: bool = False, - **kwargs -): - """Dataloader for for head and neck organ-at-rish segmentation in CT scans. See `get_han_seg_dataset` for details. - """ - ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs) - dataset = get_han_seg_dataset( - path=path, patch_shape=patch_shape, resize_inputs=resize_inputs, download=download, **ds_kwargs - ) - loader = torch_em.get_data_loader(dataset=dataset, batch_size=batch_size, **loader_kwargs) - return loader diff --git a/torch_em/data/datasets/medical/isic.py b/torch_em/data/datasets/medical/isic.py deleted file mode 100644 index efabba2c..00000000 --- a/torch_em/data/datasets/medical/isic.py +++ /dev/null @@ -1,147 +0,0 @@ -import os -from glob import glob -from pathlib import Path -from typing import Union, Tuple - -import torch_em - -from .. import util -from ..light_microscopy.neurips_cell_seg import to_rgb - - -URL = { - "images": { - "train": "https://isic-challenge-data.s3.amazonaws.com/2018/ISIC2018_Task1-2_Training_Input.zip", - "val": "https://isic-challenge-data.s3.amazonaws.com/2018/ISIC2018_Task1-2_Validation_Input.zip", - "test": "https://isic-challenge-data.s3.amazonaws.com/2018/ISIC2018_Task1-2_Test_Input.zip", - }, - "gt": { - "train": "https://isic-challenge-data.s3.amazonaws.com/2018/ISIC2018_Task1_Training_GroundTruth.zip", - "val": "https://isic-challenge-data.s3.amazonaws.com/2018/ISIC2018_Task1_Validation_GroundTruth.zip", - "test": "https://isic-challenge-data.s3.amazonaws.com/2018/ISIC2018_Task1_Test_GroundTruth.zip", - }, -} - -CHECKSUM = { - "images": { - "train": "80f98572347a2d7a376227fa9eb2e4f7459d317cb619865b8b9910c81446675f", - "val": "0ea920fcfe512d12a6e620b50b50233c059f67b10146e1479c82be58ff15a797", - "test": "e59ae1f69f4ed16f09db2cb1d76c2a828487b63d28f6ab85997f5616869b127d", - }, - "gt": { - "train": "99f8b2bb3c4d6af483362010715f7e7d5d122d9f6c02cac0e0d15bef77c7604c", - "val": "f6911e9c0a64e6d687dd3ca466ca927dd5e82145cb2163b7a1e5b37d7a716285", - "test": "2e8f6edce454a5bdee52485e39f92bd6eddf357e81f39018d05512175238ef82", - } -} - - -def get_isic_data(path, split, download): - os.makedirs(path, exist_ok=True) - - im_url = URL["images"][split] - im_checksum = CHECKSUM["images"][split] - - gt_url = URL["gt"][split] - gt_checksum = CHECKSUM["gt"][split] - - im_zipfile = os.path.split(im_url)[-1] - gt_zipfile = os.path.split(gt_url)[-1] - - imdir = os.path.join(path, Path(im_zipfile).stem) - gtdir = os.path.join(path, Path(gt_zipfile).stem) - - im_zip_path = os.path.join(path, im_zipfile) - gt_zip_path = os.path.join(path, gt_zipfile) - - if os.path.exists(imdir) and os.path.exists(gtdir): - return imdir, gtdir - - # download the images - util.download_source(path=im_zip_path, url=im_url, download=download, checksum=im_checksum) - util.unzip(zip_path=im_zip_path, dst=path, remove=False) - # download the ground-truth - util.download_source(path=gt_zip_path, url=gt_url, download=download, checksum=gt_checksum) - util.unzip(zip_path=gt_zip_path, dst=path, remove=False) - - return imdir, gtdir - - -def _get_isic_paths(path, split, download): - image_dir, gt_dir = get_isic_data(path=path, split=split, download=download) - - image_paths = sorted(glob(os.path.join(image_dir, "*.jpg"))) - gt_paths = sorted(glob(os.path.join(gt_dir, "*.png"))) - - return image_paths, gt_paths - - -def get_isic_dataset( - path: Union[os.PathLike, str], - patch_shape: Tuple[int, int], - split: str, - download: bool = False, - make_rgb: bool = True, - resize_inputs: bool = False, - **kwargs -): - """Dataset for the segmentation of skin lesion in dermoscopy images. - - This dataset is related to the following publication(s): - - https://doi.org/10.1038/sdata.2018.161 - - https://doi.org/10.48550/arXiv.1710.05006 - - https://doi.org/10.48550/arXiv.1902.03368 - - The database is located at https://challenge.isic-archive.com/data/#2018 - - Please cite it if you use this dataset for a publication. - """ - assert split in list(URL["images"].keys()), f"{split} is not a valid split." - - image_paths, gt_paths = _get_isic_paths(path=path, split=split, download=download) - - if make_rgb: - kwargs["raw_transform"] = to_rgb - - if resize_inputs: - resize_kwargs = {"patch_shape": patch_shape, "is_rgb": True} - kwargs, patch_shape = util.update_kwargs_for_resize_trafo( - kwargs=kwargs, patch_shape=patch_shape, resize_inputs=resize_inputs, resize_kwargs=resize_kwargs - ) - - dataset = torch_em.default_segmentation_dataset( - raw_paths=image_paths, - raw_key=None, - label_paths=gt_paths, - label_key=None, - patch_shape=patch_shape, - is_seg_dataset=False, - **kwargs - ) - return dataset - - -def get_isic_loader( - path: Union[os.PathLike, str], - patch_shape: Tuple[int, int], - batch_size: int, - split: str, - download: bool = False, - make_rgb: bool = True, - resize_inputs: bool = False, - **kwargs -): - """Dataloader for the segmentation of skin lesion in dermoscopy images. See `get_isic_dataset` for details. - """ - ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs) - dataset = get_isic_dataset( - path=path, - patch_shape=patch_shape, - split=split, - download=download, - make_rgb=make_rgb, - resize_inputs=resize_inputs, - **ds_kwargs - ) - loader = torch_em.get_data_loader(dataset=dataset, batch_size=batch_size, **loader_kwargs) - return loader diff --git a/torch_em/data/datasets/medical/m2caiseg.py b/torch_em/data/datasets/medical/m2caiseg.py deleted file mode 100644 index 51f92710..00000000 --- a/torch_em/data/datasets/medical/m2caiseg.py +++ /dev/null @@ -1,172 +0,0 @@ -import os -from glob import glob -from tqdm import tqdm -from pathlib import Path -from natsort import natsorted -from typing import Union, Tuple - -import numpy as np -import imageio.v3 as imageio - -import torch_em - -from .. import util - - -LABEL_MAPS = { - (0, 0, 0): 0, # out of frame - (0, 85, 170): 1, # grasper - (0, 85, 255): 2, # bipolar - (0, 170, 255): 3, # hook - (0, 255, 85): 4, # scissors - (0, 255, 170): 5, # clipper - (85, 0, 170): 6, # irrigator - (85, 0, 255): 7, # specimen bag - (170, 85, 85): 8, # trocars - (170, 170, 170): 9, # clip - (85, 170, 0): 10, # liver - (85, 170, 255): 11, # gall bladder - (85, 255, 0): 12, # fat - (85, 255, 170): 13, # upper wall - (170, 0, 255): 14, # artery - (255, 0, 255): 15, # intestine - (255, 255, 0): 16, # bile - (255, 0, 0): 17, # blood - (170, 0, 85): 18, # unknown -} - - -def get_m2caiseg_data(path, download): - os.makedirs(path, exist_ok=True) - - data_dir = os.path.join(path, r"m2caiSeg dataset") - if os.path.exists(data_dir): - return data_dir - - util.download_source_kaggle(path=path, dataset_name="salmanmaq/m2caiseg", download=download) - zip_path = os.path.join(path, "m2caiseg.zip") - util.unzip(zip_path=zip_path, dst=path) - - return data_dir - - -def _get_m2caiseg_paths(path, split, download): - data_dir = get_m2caiseg_data(path=path, download=download) - - if split == "val": - impaths = natsorted(glob(os.path.join(data_dir, "train", "images", "*.jpg"))) - gpaths = natsorted(glob(os.path.join(data_dir, "train", "groundtruth", "*.png"))) - - imids = [os.path.split(_p)[-1] for _p in impaths] - gids = [os.path.split(_p)[-1] for _p in gpaths] - - image_paths = [ - _p for _p in natsorted( - glob(os.path.join(data_dir, "trainval", "images", "*.jpg")) - ) if os.path.split(_p)[-1] not in imids - ] - gt_paths = [ - _p for _p in natsorted( - glob(os.path.join(data_dir, "trainval", "groundtruth", "*.png")) - ) if os.path.split(_p)[-1] not in gids - ] - - else: - image_paths = natsorted(glob(os.path.join(data_dir, split, "images", "*.jpg"))) - gt_paths = natsorted(glob(os.path.join(data_dir, split, "groundtruth", "*.png"))) - - images_dir = os.path.join(data_dir, "preprocessed", split, "images") - mask_dir = os.path.join(data_dir, "preprocessed", split, "masks") - if os.path.exists(images_dir) and os.path.exists(mask_dir): - return natsorted(glob(os.path.join(images_dir, "*"))), natsorted(glob(os.path.join(mask_dir, "*"))) - - os.makedirs(images_dir, exist_ok=True) - os.makedirs(mask_dir, exist_ok=True) - - fimage_paths, fgt_paths = [], [] - for image_path, gt_path in tqdm(zip(image_paths, gt_paths), total=len(image_paths)): - image = imageio.imread(image_path) - gt = imageio.imread(gt_path) - - image_id = Path(image_path).stem - gt_id = Path(gt_path).stem - - if image.shape != gt.shape: - print("This pair of image and labels mismatch.") - continue - - dst_image_path = os.path.join(images_dir, f"{image_id}.tif") - dst_gt_path = os.path.join(mask_dir, f"{gt_id}.tif") - - fimage_paths.append(image_path) - fgt_paths.append(dst_gt_path) - if os.path.exists(dst_gt_path) and os.path.exists(dst_image_path): - continue - - instances = np.zeros(gt.shape[:2]) - for lmap in LABEL_MAPS: - binary_map = (gt == lmap).all(axis=2) - instances[binary_map > 0] = LABEL_MAPS[lmap] - - imageio.imwrite(dst_image_path, image, compression="zlib") - imageio.imwrite(dst_gt_path, instances, compression="zlib") - - return fimage_paths, fgt_paths - - -def get_m2caiseg_dataset( - path: Union[os.PathLike, str], - split: str, - patch_shape: Tuple[int, int], - resize_inputs: bool = False, - download: bool = False, - **kwargs -): - """Dataset for segmentation of organs and instruments in endoscopy. - - This data is from Maqbool et al. - https://doi.org/10.48550/arXiv.2008.10134 - - This dataset is located at https://www.kaggle.com/datasets/salmanmaq/m2caiseg - - Please cite it if you use this data in a publication. - """ - assert split in ["train", "val", "test"] - - image_paths, gt_paths = _get_m2caiseg_paths(path=path, split=split, download=download) - - if resize_inputs: - resize_kwargs = {"patch_shape": patch_shape, "is_rgb": True} - kwargs, patch_shape = util.update_kwargs_for_resize_trafo( - kwargs=kwargs, patch_shape=patch_shape, resize_inputs=resize_inputs, resize_kwargs=resize_kwargs - ) - - dataset = torch_em.default_segmentation_dataset( - raw_paths=image_paths, - raw_key=None, - label_paths=gt_paths, - label_key=None, - patch_shape=patch_shape, - is_seg_dataset=False, - **kwargs - ) - - return dataset - - -def get_m2caiseg_loader( - path: Union[os.PathLike, str], - split: str, - patch_shape: Tuple[int, int], - batch_size: int, - resize_inputs: bool = False, - download: bool = False, - **kwargs -): - """Dataloader for segmentation of organs and instruments in endoscopy. See `get_m2caiseg_dataset` for details. - """ - ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs) - dataset = get_m2caiseg_dataset( - path=path, split=split, patch_shape=patch_shape, resize_inputs=resize_inputs, download=download, **ds_kwargs - ) - loader = torch_em.get_data_loader(dataset=dataset, batch_size=batch_size, **loader_kwargs) - return loader diff --git a/torch_em/data/datasets/medical/oimhs.py b/torch_em/data/datasets/medical/oimhs.py index 24faada2..bcc5082c 100644 --- a/torch_em/data/datasets/medical/oimhs.py +++ b/torch_em/data/datasets/medical/oimhs.py @@ -3,12 +3,10 @@ from tqdm import tqdm from pathlib import Path from natsort import natsorted -from typing import Union, Tuple, Literal +from typing import Union, Tuple -import json import numpy as np import imageio.v3 as imageio -from sklearn.model_selection import train_test_split import torch_em @@ -40,27 +38,7 @@ def get_oimhs_data(path, download): return data_dir -def _create_splits(data_dir, split_file, test_fraction=0.2): - eye_dirs = natsorted(glob(os.path.join(data_dir, "Images", "*"))) - - # let's split the data - main_split, test_split = train_test_split(eye_dirs, test_size=test_fraction) - train_split, val_split = train_test_split(main_split, test_size=0.1) - - decided_splits = {"train": train_split, "val": val_split, "test": test_split} - - with open(split_file, "w") as f: - json.dump(decided_splits, f) - - -def _get_per_split_dirs(split_file, split): - with open(split_file, "r") as f: - data = json.load(f) - - return data[split] - - -def _get_oimhs_paths(path, split, download): +def _get_oimhs_paths(path, download): data_dir = get_oimhs_data(path=path, download=download) image_dir = os.path.join(data_dir, "preprocessed", "images") @@ -68,13 +46,8 @@ def _get_oimhs_paths(path, split, download): os.makedirs(image_dir, exist_ok=True) os.makedirs(gt_dir, exist_ok=True) - split_file = os.path.join(path, "split_file.json") - if not os.path.exists(split_file): - _create_splits(data_dir, split_file) - - eye_dirs = _get_per_split_dirs(split_file=split_file, split=split) - image_paths, gt_paths = [], [] + eye_dirs = natsorted(glob(os.path.join(data_dir, "Images", "*"))) for eye_dir in tqdm(eye_dirs): eye_id = os.path.split(eye_dir)[-1] all_oct_scan_paths = natsorted(glob(os.path.join(eye_dir, "*.png"))) @@ -83,11 +56,9 @@ def _get_oimhs_paths(path, split, download): image_path = os.path.join(image_dir, f"{eye_id}_{scan_id}.tif") gt_path = os.path.join(gt_dir, f"{eye_id}_{scan_id}.tif") - - image_paths.append(image_path) - gt_paths.append(gt_path) - if os.path.exists(image_path) and os.path.exists(gt_path): + image_paths.append(image_path) + gt_paths.append(gt_path) continue scan = imageio.imread(per_scan_path) @@ -101,13 +72,15 @@ def _get_oimhs_paths(path, split, download): imageio.imwrite(image_path, image, compression="zlib") imageio.imwrite(gt_path, instances, compression="zlib") + image_paths.append(image_path) + gt_paths.append(gt_path) + return image_paths, gt_paths def get_oimhs_dataset( path: Union[os.PathLike, str], patch_shape: Tuple[int, int], - split: Literal["train", "val", "test"], resize_inputs: bool = False, download: bool = False, **kwargs @@ -118,7 +91,7 @@ def get_oimhs_dataset( Please cite it if you use this dataset for your publication. """ - image_paths, gt_paths = _get_oimhs_paths(path=path, split=split, download=download) + image_paths, gt_paths = _get_oimhs_paths(path=path, download=download) if resize_inputs: resize_kwargs = {"patch_shape": patch_shape, "is_rgb": True} @@ -143,7 +116,6 @@ def get_oimhs_loader( path: Union[os.PathLike, str], patch_shape: Tuple[int, int], batch_size: int, - split: Literal["train", "val", "test"], resize_inputs: bool = False, download: bool = False, **kwargs @@ -153,7 +125,7 @@ def get_oimhs_loader( """ ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs) dataset = get_oimhs_dataset( - path=path, patch_shape=patch_shape, split=split, resize_inputs=resize_inputs, download=download, **ds_kwargs + path=path, patch_shape=patch_shape, resize_inputs=resize_inputs, download=download, **ds_kwargs ) loader = torch_em.get_data_loader(dataset=dataset, batch_size=batch_size, **loader_kwargs) return loader diff --git a/torch_em/data/datasets/medical/osic_pulmofib.py b/torch_em/data/datasets/medical/osic_pulmofib.py index 581cd911..1a26615b 100644 --- a/torch_em/data/datasets/medical/osic_pulmofib.py +++ b/torch_em/data/datasets/medical/osic_pulmofib.py @@ -133,12 +133,6 @@ def get_osic_pulmofib_dataset( """ image_paths, gt_paths = _get_osic_pulmofib_paths(path=path, download=download) - if resize_inputs: - resize_kwargs = {"patch_shape": patch_shape, "is_rgb": False} - kwargs, patch_shape = util.update_kwargs_for_resize_trafo( - kwargs=kwargs, patch_shape=patch_shape, resize_inputs=resize_inputs, resize_kwargs=resize_kwargs - ) - dataset = torch_em.default_segmentation_dataset( raw_paths=image_paths, raw_key="data", diff --git a/torch_em/data/datasets/medical/papila.py b/torch_em/data/datasets/medical/papila.py index 8a030b8f..5ecf1f81 100644 --- a/torch_em/data/datasets/medical/papila.py +++ b/torch_em/data/datasets/medical/papila.py @@ -2,15 +2,17 @@ from glob import glob from tqdm import tqdm from pathlib import Path -from typing import Union, Tuple, Literal +from typing import Union, Tuple import numpy as np from skimage import draw import imageio.v3 as imageio import torch_em +from torch_em.transform.generic import ResizeInputs from .. import util +from ... import ImageCollectionDataset URL = "https://figshare.com/ndownloader/files/35013982" @@ -48,8 +50,15 @@ def _get_papila_paths(path, task, expert_choice, download): image_paths = sorted(glob(os.path.join(data_dir, "FundusImages", "*.jpg"))) gt_dir = os.path.join(data_dir, "ground_truth") + if os.path.exists(gt_dir): + gt_paths = sorted(glob(os.path.join(gt_dir, f"*_{task}.tif"))) + return image_paths, gt_paths + os.makedirs(gt_dir, exist_ok=True) + if task is None: # we get the binary segmentations for both disc and cup + task = "*" + patient_ids = [Path(image_path).stem for image_path in image_paths] input_shape = (1934, 2576, 3) # shape of the input images @@ -59,9 +68,11 @@ def _get_papila_paths(path, task, expert_choice, download): glob(os.path.join(data_dir, "ExpertsSegmentations", "Contours", f"{patient_id}_{task}_{expert_choice}.txt")) ) + assert len(gt_contours) == (4 if task is None else 2) + for gt_contour in gt_contours: tmp_task = Path(gt_contour).stem.split("_")[1] - gt_path = os.path.join(gt_dir, f"{patient_id}_{tmp_task}_{expert_choice}.tif") + gt_path = os.path.join(gt_dir, f"{patient_id}_{tmp_task}.tif") gt_paths.append(gt_path) if os.path.exists(gt_path): continue @@ -75,8 +86,8 @@ def _get_papila_paths(path, task, expert_choice, download): def get_papila_dataset( path: Union[os.PathLike, str], patch_shape: Tuple[int, int], - task: Literal["cup", "disc"] = "disc", - expert_choice: Literal["exp1", "exp2"] = "exp1", + task: str = "disc", + expert_choice: str = "exp1", resize_inputs: bool = False, download: bool = False, **kwargs @@ -89,23 +100,26 @@ def get_papila_dataset( Please cite it if you use this dataset for a publication. """ assert expert_choice in ["exp1", "exp2"], f"'{expert_choice}' is not a valid expert choice." - assert task in ["cup", "disc"], f"'{task}' is not a valid task." + + if task is not None: + assert task in ["cup", "disc"], f"'{task}' is not a valid task." image_paths, gt_paths = _get_papila_paths(path=path, task=task, expert_choice=expert_choice, download=download) if resize_inputs: - resize_kwargs = {"patch_shape": patch_shape, "is_rgb": True} - kwargs, patch_shape = util.update_kwargs_for_resize_trafo( - kwargs=kwargs, patch_shape=patch_shape, resize_inputs=resize_inputs, resize_kwargs=resize_kwargs - ) - - dataset = torch_em.default_segmentation_dataset( - raw_paths=image_paths, - raw_key=None, - label_paths=gt_paths, - label_key=None, + raw_trafo = ResizeInputs(target_shape=patch_shape, is_rgb=True) + label_trafo = ResizeInputs(target_shape=patch_shape, is_label=True) + patch_shape = None + else: + patch_shape = patch_shape + raw_trafo, label_trafo = None, None + + dataset = ImageCollectionDataset( + raw_image_paths=image_paths, + label_image_paths=gt_paths, patch_shape=patch_shape, - is_seg_dataset=False, + raw_transform=raw_trafo, + label_transform=label_trafo, **kwargs ) @@ -116,8 +130,8 @@ def get_papila_loader( path: Union[os.PathLike, str], patch_shape: Tuple[int, int], batch_size: int, - task: Literal["cup", "disc"] = "disc", - expert_choice: Literal["exp1", "exp2"] = "exp1", + task: str = "disc", + expert_choice: str = "exp1", resize_inputs: bool = False, download: bool = False, **kwargs diff --git a/torch_em/data/datasets/medical/piccolo.py b/torch_em/data/datasets/medical/piccolo.py deleted file mode 100644 index 0b8738f5..00000000 --- a/torch_em/data/datasets/medical/piccolo.py +++ /dev/null @@ -1,105 +0,0 @@ -import os -from glob import glob -from natsort import natsorted -from typing import Union, Tuple, Literal - -import torch_em - -from .. import util - - -def get_piccolo_data(path, download): - """The database is located at: - - https://www.biobancovasco.bioef.eus/en/Sample-and-data-e-catalog/Databases/PD178-PICCOLO-EN1.html - - Follow the instructions below to get access to the dataset. - - Visit the attached website above - - Fill up the access request form: https://labur.eus/EzJUN - - Send an email to Basque Biobank at solicitudes.biobancovasco@bioef.eus, requesting access to the dataset. - - The team will request you to follow-up with some formalities. - - Then, you will gain access to the ".rar" file. - - Finally, provide the path where the rar file is stored, and you should be good to go. - """ - if download: - raise NotImplementedError( - "Automatic download is not possible for this dataset. See 'get_piccolo_data' for details." - ) - - data_dir = os.path.join(path, r"piccolo dataset-release0.1") - if os.path.exists(data_dir): - return data_dir - - rar_file = os.path.join(path, r"piccolo dataset_widefield-release0.1.rar") - if not os.path.exists(rar_file): - raise FileNotFoundError( - "You must download the PICCOLO dataset from the Basque Biobank, see 'get_piccolo_data' for details." - ) - - util.unzip_rarfile(rar_path=rar_file, dst=path, remove=False) - return data_dir - - -def _get_piccolo_paths(path, split, download): - data_dir = get_piccolo_data(path=path, download=download) - - split_dir = os.path.join(data_dir, split) - - image_paths = natsorted(glob(os.path.join(split_dir, "polyps", "*"))) - gt_paths = natsorted(glob(os.path.join(split_dir, "masks", "*"))) - - return image_paths, gt_paths - - -def get_piccolo_dataset( - path: Union[os.PathLike, str], - patch_shape: Tuple[int, int], - split: Literal["train", "validation", "test"], - resize_inputs: bool = False, - download: bool = False, - **kwargs -): - """Dataset for polyp segmentation in narrow band imaging colonoscopy images. - - This dataset is from Sánchez-Peralta et al. - https://doi.org/10.3390/app10238501. - To access the dataset, see `get_piccolo_data` for details. - - Please cite it if you use this data in a publication. - """ - image_paths, gt_paths = _get_piccolo_paths(path=path, split=split, download=download) - - if resize_inputs: - resize_kwargs = {"patch_shape": patch_shape, "is_rgb": True} - kwargs, patch_shape = util.update_kwargs_for_resize_trafo( - kwargs=kwargs, patch_shape=patch_shape, resize_inputs=resize_inputs, resize_kwargs=resize_kwargs - ) - - dataset = torch_em.default_segmentation_dataset( - raw_paths=image_paths, - raw_key=None, - label_paths=gt_paths, - label_key=None, - patch_shape=patch_shape, - is_seg_dataset=False, - **kwargs - ) - return dataset - - -def get_piccolo_loader( - path: Union[os.PathLike, str], - patch_shape: Tuple[int, int], - batch_size: int, - split: Literal["train", "validation", "test"], - resize_inputs: bool = False, - download: bool = False, - **kwargs -): - """Dataloader for polyp segmentation in narrow band imaging colonoscopy images. - See `get_piccolo_dataset` for details. - """ - ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs) - dataset = get_piccolo_dataset( - path=path, patch_shape=patch_shape, split=split, resize_inputs=resize_inputs, download=download, **ds_kwargs - ) - loader = torch_em.get_data_loader(dataset=dataset, batch_size=batch_size, **loader_kwargs) - return loader diff --git a/torch_em/data/datasets/medical/sega.py b/torch_em/data/datasets/medical/sega.py index 3c82763b..c101db2c 100644 --- a/torch_em/data/datasets/medical/sega.py +++ b/torch_em/data/datasets/medical/sega.py @@ -66,38 +66,7 @@ def _get_sega_paths(path, data_choice, download): else: image_paths.append(volume_path) - # now let's wrap the volumes to nifti format - fimage_dir = os.path.join(path, "data", "images") - fgt_dir = os.path.join(path, "data", "labels") - - os.makedirs(fimage_dir, exist_ok=True) - os.makedirs(fgt_dir, exist_ok=True) - - fimage_paths, fgt_paths = [], [] - for image_path, gt_path in zip(natsorted(image_paths), natsorted(gt_paths)): - fimage_path = os.path.join(fimage_dir, f"{Path(image_path).stem}.nii.gz") - fgt_path = os.path.join(fgt_dir, f"{Path(image_path).stem}.nii.gz") - - fimage_paths.append(fimage_path) - fgt_paths.append(fgt_path) - - if os.path.exists(fimage_path) and os.path.exists(fgt_path): - continue - - import nrrd - import numpy as np - import nibabel as nib - - image = nrrd.read(image_path)[0] - gt = nrrd.read(gt_path)[0] - - image_nifti = nib.Nifti2Image(image, np.eye(4)) - gt_nifti = nib.Nifti2Image(gt, np.eye(4)) - - nib.save(image_nifti, fimage_path) - nib.save(gt_nifti, fgt_path) - - return natsorted(fimage_paths), natsorted(fgt_paths) + return natsorted(image_paths), natsorted(gt_paths) def get_sega_dataset( @@ -123,11 +92,10 @@ def get_sega_dataset( dataset = torch_em.default_segmentation_dataset( raw_paths=image_paths, - raw_key="data", + raw_key=None, label_paths=gt_paths, - label_key="data", + label_key=None, patch_shape=patch_shape, - is_seg_dataset=True, **kwargs ) diff --git a/torch_em/data/datasets/medical/siim_acr.py b/torch_em/data/datasets/medical/siim_acr.py index 7326c7e1..f0033dcb 100644 --- a/torch_em/data/datasets/medical/siim_acr.py +++ b/torch_em/data/datasets/medical/siim_acr.py @@ -1,10 +1,12 @@ import os from glob import glob -from typing import Union, Tuple, Literal +from typing import Union, Tuple import torch_em +from torch_em.transform.generic import ResizeInputs from .. import util +from ... import ImageCollectionDataset KAGGLE_DATASET_NAME = "vbookshelf/pneumothorax-chest-xray-images-and-masks" @@ -40,7 +42,7 @@ def _get_siim_acr_paths(path, split, download): def get_siim_acr_dataset( path: Union[os.PathLike, str], - split: Literal["train", "test"], + split: str, patch_shape: Tuple[int, int], download: bool = False, resize_inputs: bool = False, @@ -58,18 +60,19 @@ def get_siim_acr_dataset( image_paths, gt_paths = _get_siim_acr_paths(path=path, split=split, download=download) if resize_inputs: - resize_kwargs = {"patch_shape": patch_shape, "is_rgb": False} - kwargs, patch_shape = util.update_kwargs_for_resize_trafo( - kwargs=kwargs, patch_shape=patch_shape, resize_inputs=resize_inputs, resize_kwargs=resize_kwargs - ) - - dataset = torch_em.default_segmentation_dataset( - raw_paths=image_paths, - raw_key=None, - label_paths=gt_paths, - label_key=None, + raw_trafo = ResizeInputs(target_shape=patch_shape, is_label=False) + label_trafo = ResizeInputs(target_shape=patch_shape, is_label=True) + patch_shape = None + else: + patch_shape = patch_shape + raw_trafo, label_trafo = None, None + + dataset = ImageCollectionDataset( + raw_image_paths=image_paths, + label_image_paths=gt_paths, patch_shape=patch_shape, - is_seg_dataset=False, + raw_transform=raw_trafo, + label_transform=label_trafo, **kwargs ) dataset.max_sampling_attempts = 5000 @@ -79,7 +82,7 @@ def get_siim_acr_dataset( def get_siim_acr_loader( path: Union[os.PathLike, str], - split: Literal["train", "test"], + split: str, patch_shape: Tuple[int, int], batch_size: int, download: bool = False, diff --git a/torch_em/data/datasets/medical/spider.py b/torch_em/data/datasets/medical/spider.py deleted file mode 100644 index 6735b97f..00000000 --- a/torch_em/data/datasets/medical/spider.py +++ /dev/null @@ -1,79 +0,0 @@ -import os -from glob import glob -from natsort import natsorted - -import torch_em - -from .. import util - - -URL = { - "images": "https://zenodo.org/records/10159290/files/images.zip?download=1", - "masks": "https://zenodo.org/records/10159290/files/masks.zip?download=1" -} - -CHECKSUMS = { - "images": "a54cba2905284ff6cc9999f1dd0e4d871c8487187db7cd4b068484eac2f50f17", - "masks": "13a6e25a8c0d74f507e16ebb2edafc277ceeaf2598474f1fed24fdf59cb7f18f" -} - - -def get_spider_data(path, download): - os.makedirs(path, exist_ok=True) - - data_dir = os.path.join(path, "data") - if os.path.exists(data_dir): - return data_dir - - zip_path = os.path.join(path, "images.zip") - util.download_source(path=zip_path, url=URL["images"], download=download, checksum=CHECKSUMS["images"]) - util.unzip(zip_path=zip_path, dst=data_dir) - - zip_path = os.path.join(path, "masks.zip") - util.download_source(path=zip_path, url=URL["images"], download=download, checksum=CHECKSUMS["images"]) - util.unzip(zip_path=zip_path, dst=data_dir) - - return data_dir - - -def _get_spider_paths(path, download): - data_dir = get_spider_data(path, download) - - image_paths = natsorted(glob(os.path.join(data_dir, "images", "*.mha"))) - gt_paths = natsorted(glob(os.path.join(data_dir, "masks", "*.mha"))) - - return image_paths, gt_paths - - -def get_spider_dataset(path, patch_shape, download=False, **kwargs): - """Dataset for segmentation of vertebrae, intervertebral discs and spinal canal in T1 and T2 MRI series. - - https://zenodo.org/records/10159290 - https://www.nature.com/articles/s41597-024-03090-w - - Please cite it if you use this data in a publication. - """ - # TODO: expose the choice to choose specific MRI modality, for now this works for our interests. - image_paths, gt_paths = _get_spider_paths(path, download) - - dataset = torch_em.default_segmentation_dataset( - raw_paths=image_paths, - raw_key=None, - label_paths=gt_paths, - label_key=None, - is_seg_dataset=True, - patch_shape=patch_shape, - **kwargs - ) - - return dataset - - -def get_spider_loader(path, patch_shape, batch_size, download=False, **kwargs): - """Dataloader for segmentation of vertebrae, intervertebral discs and spinal canal in T1 and T2 MRI series. - See `get_spider_dataset` for details. - """ - ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs) - dataset = get_spider_dataset(path=path, patch_shape=patch_shape, download=download, **ds_kwargs) - loader = torch_em.get_data_loader(dataset=dataset, batch_size=batch_size, **loader_kwargs) - return loader diff --git a/torch_em/data/datasets/medical/toothfairy.py b/torch_em/data/datasets/medical/toothfairy.py deleted file mode 100644 index 5713f6b5..00000000 --- a/torch_em/data/datasets/medical/toothfairy.py +++ /dev/null @@ -1,88 +0,0 @@ -import os -from glob import glob -from tqdm import tqdm -from natsort import natsorted - -import numpy as np -import nibabel as nib - -import torch_em - -from .. import util - - -def get_toothfairy_data(path, download): - """Automatic download is not possible. - """ - if download: - raise NotImplementedError - - data_dir = os.path.join(path, "ToothFairy_Dataset", "Dataset") - return data_dir - - -def _get_toothfairy_paths(path, download): - data_dir = get_toothfairy_data(path, download) - - images_dir = os.path.join(path, "data", "images") - gt_dir = os.path.join(path, "data", "dense_labels") - if os.path.exists(images_dir) and os.path.exists(gt_dir): - return natsorted(glob(os.path.join(images_dir, "*.nii.gz"))), natsorted(glob(os.path.join(gt_dir, "*.nii.gz"))) - - os.makedirs(images_dir, exist_ok=True) - os.makedirs(gt_dir, exist_ok=True) - - image_paths, gt_paths = [], [] - for patient_dir in tqdm(glob(os.path.join(data_dir, "P*"))): - patient_id = os.path.split(patient_dir)[-1] - - dense_anns_path = os.path.join(patient_dir, "gt_alpha.npy") - if not os.path.exists(dense_anns_path): - continue - - image_path = os.path.join(patient_dir, "data.npy") - - image = np.load(image_path) - gt = np.load(dense_anns_path) - - image_nifti = nib.Nifti2Image(image, np.eye(4)) - gt_nifti = nib.Nifti2Image(gt, np.eye(4)) - - trg_image_path = os.path.join(images_dir, f"{patient_id}.nii.gz") - trg_gt_path = os.path.join(gt_dir, f"{patient_id}.nii.gz") - - nib.save(image_nifti, trg_image_path) - nib.save(gt_nifti, trg_gt_path) - - image_paths.append(trg_image_path) - gt_paths.append(trg_gt_path) - - return image_paths, gt_paths - - -def get_toothfairy_dataset(path, patch_shape, download=False, **kwargs): - """Canal segmentation in CBCT - https://toothfairy.grand-challenge.org/ - """ - image_paths, gt_paths = _get_toothfairy_paths(path, download) - - dataset = torch_em.default_segmentation_dataset( - raw_paths=image_paths, - raw_key="data", - label_paths=gt_paths, - label_key="data", - is_seg_dataset=True, - patch_shape=patch_shape, - **kwargs - ) - - return dataset - - -def get_toothfairy_loader(path, patch_shape, batch_size, download=False, **kwargs): - """ - """ - ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs) - dataset = get_toothfairy_dataset(path, patch_shape, download, **ds_kwargs) - loader = torch_em.get_data_loader(dataset=dataset, batch_size=batch_size, **loader_kwargs) - return loader diff --git a/torch_em/data/datasets/util.py b/torch_em/data/datasets/util.py index 8a0d2c35..2164b660 100644 --- a/torch_em/data/datasets/util.py +++ b/torch_em/data/datasets/util.py @@ -221,15 +221,6 @@ def unzip_tarfile(tar_path, dst, remove=True): os.remove(tar_path) -def unzip_rarfile(rar_path, dst, remove=True): - import rarfile - with rarfile.RarFile(rar_path) as f: - f.extractall(path=dst) - - if remove: - os.remove(rar_path) - - def unzip(zip_path, dst, remove=True): with zipfile.ZipFile(zip_path, "r") as f: f.extractall(dst) diff --git a/torch_em/data/sampler.py b/torch_em/data/sampler.py index 930aad63..7a47fb56 100644 --- a/torch_em/data/sampler.py +++ b/torch_em/data/sampler.py @@ -1,5 +1,5 @@ import numpy as np -from typing import List, Optional +from typing import List class MinForegroundSampler: @@ -85,18 +85,13 @@ class MinInstanceSampler: def __init__( self, min_num_instances: int = 2, - p_reject: float = 1.0, - min_size: Optional[int] = None + p_reject: float = 1.0 ): self.min_num_instances = min_num_instances self.p_reject = p_reject - self.min_size = min_size def __call__(self, x, y): - uniques, sizes = np.unique(y, return_counts=True) - if self.min_size is not None: - filter_ids = uniques[sizes >= self.min_size] - uniques = filter_ids + uniques = np.unique(y) if len(uniques) >= self.min_num_instances: return True else: diff --git a/torch_em/model/unet.py b/torch_em/model/unet.py index fc6c255d..fdc45f3b 100644 --- a/torch_em/model/unet.py +++ b/torch_em/model/unet.py @@ -353,10 +353,10 @@ def get_norm_layer(norm, dim, channels, n_groups=32): if norm is None: return None if norm == "InstanceNorm": - return nn.InstanceNorm2d(channels) if dim == 2 else nn.InstanceNorm3d(channels) - elif norm == "InstanceNormTrackStats": kwargs = {"affine": True, "track_running_stats": True, "momentum": 0.01} return nn.InstanceNorm2d(channels, **kwargs) if dim == 2 else nn.InstanceNorm3d(channels, **kwargs) + elif norm == "OldDefault": + return nn.InstanceNorm2d(channels) if dim == 2 else nn.InstanceNorm3d(channels) elif norm == "GroupNorm": return nn.GroupNorm(min(n_groups, channels), channels) elif norm == "BatchNorm": diff --git a/torch_em/model/unetr.py b/torch_em/model/unetr.py index 06c225c2..d86f5357 100644 --- a/torch_em/model/unetr.py +++ b/torch_em/model/unetr.py @@ -127,6 +127,7 @@ def __init__( scale_factors=scale_factors[::-1], conv_block_impl=ConvBlock2d, sampler_impl=_upsampler, + norm="OldDefault", ) else: self.decoder = decoder @@ -142,14 +143,14 @@ def __init__( Deconv2DBlock(features_decoder[0], features_decoder[1]), Deconv2DBlock(features_decoder[1], features_decoder[2]) ) - self.deconv4 = ConvBlock2d(in_chans, features_decoder[-1]) + self.deconv4 = ConvBlock2d(in_chans, features_decoder[-1], norm="OldDefault") else: self.deconv1 = Deconv2DBlock(embed_dim, features_decoder[0]) self.deconv2 = Deconv2DBlock(features_decoder[0], features_decoder[1]) self.deconv3 = Deconv2DBlock(features_decoder[1], features_decoder[2]) self.deconv4 = Deconv2DBlock(features_decoder[2], features_decoder[3]) - self.base = ConvBlock2d(embed_dim, features_decoder[0]) + self.base = ConvBlock2d(embed_dim, features_decoder[0], norm="OldDefault") self.out_conv = nn.Conv2d(features_decoder[-1], out_channels, 1) @@ -157,7 +158,7 @@ def __init__( scale_factor=2, in_channels=features_decoder[-1], out_channels=features_decoder[-1] ) - self.decoder_head = ConvBlock2d(2 * features_decoder[-1], features_decoder[-1]) + self.decoder_head = ConvBlock2d(2 * features_decoder[-1], features_decoder[-1], norm="OldDefault") self.final_activation = self._get_activation(final_activation) diff --git a/torch_em/transform/label.py b/torch_em/transform/label.py index cd73b8ca..d13233e2 100644 --- a/torch_em/transform/label.py +++ b/torch_em/transform/label.py @@ -37,22 +37,6 @@ def label_consecutive(labels, with_background=True): return seg -class MinSizeLabelTransform: - def __init__(self, min_size=None, ndim=None, ensure_zero=False): - self.min_size = min_size - self.ndim = ndim - self.ensure_zero = ensure_zero - - def __call__(self, labels): - components = connected_components(labels, ndim=self.ndim, ensure_zero=self.ensure_zero) - if self.min_size is not None: - ids, sizes = np.unique(components, return_counts=True) - filter_ids = ids[sizes < self.min_size] - components[np.isin(components, filter_ids)] = 0 - components, _, _ = skimage.segmentation.relabel_sequential(components) - return components - - # TODO smoothing class BoundaryTransform: def __init__(self, mode="thick", add_binary_target=False, ndim=None): diff --git a/torch_em/util/debug.py b/torch_em/util/debug.py index 0a2c3a43..ca7bfa90 100644 --- a/torch_em/util/debug.py +++ b/torch_em/util/debug.py @@ -106,7 +106,6 @@ def _check_napari(loader, n_samples, instance_labels, model=None, device=None, r v.add_image(y) if pred is not None: v.add_image(pred) - napari.run() diff --git a/torch_em/util/prediction.py b/torch_em/util/prediction.py index eb6ba53d..e4ae86ed 100644 --- a/torch_em/util/prediction.py +++ b/torch_em/util/prediction.py @@ -182,7 +182,7 @@ def predict_block(block_id): if mask is not None: mask_block, _ = _load_block(mask, offset, block_shape, halo, with_channels=False) - mask_block = mask_block[inner_bb].astype("bool") + mask_block = mask_block[inner_bb] if mask_block.sum() == 0: return diff --git a/torch_em/util/util.py b/torch_em/util/util.py index d2fb9147..2b5b51b1 100644 --- a/torch_em/util/util.py +++ b/torch_em/util/util.py @@ -68,14 +68,7 @@ def ensure_tensor(tensor, dtype=None): if isinstance(tensor, np.ndarray): if np.dtype(tensor.dtype) in DTYPE_MAP: tensor = tensor.astype(DTYPE_MAP[tensor.dtype]) - # Try to convert the tensor, even if it has wrong byte-order - try: - tensor = torch.from_numpy(tensor) - except ValueError: - tensor = tensor.view(tensor.dtype.newbyteorder()) - if np.dtype(tensor.dtype) in DTYPE_MAP: - tensor = tensor.astype(DTYPE_MAP[tensor.dtype]) - tensor = torch.from_numpy(tensor) + tensor = torch.from_numpy(tensor) assert torch.is_tensor(tensor), f"Cannot convert {type(tensor)} to torch" if dtype is not None: