diff --git a/.github/doc_env.yaml b/.github/doc_env.yaml new file mode 100644 index 00000000..58233dd8 --- /dev/null +++ b/.github/doc_env.yaml @@ -0,0 +1,13 @@ +channels: + - pytorch + - conda-forge +name: + sam +dependencies: + - cpuonly + - pdoc + - python-elf + - pytorch + - torchvision + - tqdm + - zarr diff --git a/.github/workflows/build_docs.yaml b/.github/workflows/build_docs.yaml new file mode 100644 index 00000000..91d80f93 --- /dev/null +++ b/.github/workflows/build_docs.yaml @@ -0,0 +1,65 @@ +name: build_documentation + +# build the documentation for a new release +on: + release: + types: created + +# # for debugging +# on: [push, pull_request] + +# security: restrict permissions for CI jobs. +permissions: + contents: read + +# NOTE: importing of napari fails with CI and I am not quite sure why +# I tried to adjust it based on the napari CI tests, but that didn't seem to help + +# https://github.com/napari/napari/blob/main/.github/workflows/test_comprehensive.yml +jobs: + # Build the documentation and upload the static HTML files as an artifact. + build: + runs-on: ubuntu-latest + steps: + - name: Checkout + uses: actions/checkout@v4 + + - name: Setup miniconda + uses: conda-incubator/setup-miniconda@v2 + with: + activate-environment: sam + mamba-version: "*" + auto-update-conda: true + environment-file: .github/doc_env.yaml + python-version: ${{ matrix.python-version }} + auto-activate-base: false + env: + ACTIONS_ALLOW_UNSECURE_COMMANDS: true + + - name: Install package + shell: bash -l {0} + run: pip install --no-deps -e . + + # We use a custom build script for pdoc itself, ideally you just run `pdoc -o docs/ ...` here. + - name: Run pdoc + shell: bash -l {0} + run: python build_doc.py --out + + - uses: actions/upload-pages-artifact@v1 + with: + path: tmp/ + + # Deploy the artifact to GitHub pages. + # This is a separate job so that only actions/deploy-pages has the necessary permissions. + deploy: + needs: build + runs-on: ubuntu-latest + permissions: + pages: write + id-token: write + environment: + name: github-pages + url: ${{ steps.deployment.outputs.page_url }} + steps: + - id: deployment + uses: actions/deploy-pages@v2 diff --git a/README.md b/README.md index a968e4db..25b41ab6 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,4 @@ +[![DOC](https://shields.mitmproxy.org/badge/docs-pdoc.dev-brightgreen.svg)](https://constantinpape.github.io/torch-em/torch_em.html) [![Build Status](https://github.com/constantinpape/torch-em/workflows/test/badge.svg)](https://github.com/constantinpape/torch-em/actions) [![DOI](https://zenodo.org/badge/DOI/10.5281/zenodo.5108853.svg)](https://doi.org/10.5281/zenodo.5108853) [![Anaconda-Server Badge](https://anaconda.org/conda-forge/torch_em/badges/version.svg)](https://anaconda.org/conda-forge/torch_em) @@ -103,7 +104,7 @@ If you want to use the GPU version, make sure to set the correct CUDA version fo You can set up a conda environment using one of these files like this: ```bash -mamba create -f <environment_file>.yaml -n <env_name> +mamba env create -f <environment_file>.yaml -n <env_name> mamba activate <env_name> pip install -e .
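+# For example, assuming the CPU environment file that ships in the repository root
+# (the file and environment names here are illustrative; adjust them to your setup):
+# mamba env create -f environment_cpu.yaml -n torch-em
+# mamba activate torch-em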
``` diff --git a/build_doc.py b/build_doc.py new file mode 100644 index 00000000..49bd533b --- /dev/null +++ b/build_doc.py @@ -0,0 +1,14 @@ +import argparse +from subprocess import run + +parser = argparse.ArgumentParser() +parser.add_argument("--out", "-o", action="store_true") +args = parser.parse_args() + +cmd = ["pdoc", "--docformat", "google"] + +if args.out: + cmd.extend(["--out", "tmp/"]) +cmd.append("torch_em") + +run(cmd) diff --git a/doc/datasets_and_dataloaders.md b/doc/datasets_and_dataloaders.md index fd629be4..abd9e663 100644 --- a/doc/datasets_and_dataloaders.md +++ b/doc/datasets_and_dataloaders.md @@ -1,54 +1,32 @@ # Datasets in `torch-em` -Available open-source datasets in `torch-em` located at `torch_em/data/datasets/` (see `scripts/datasets` for a quick guide on how to use the dataloaders out-of-the-box): +We provide PyTorch Datasets / DataLoaders for many different biomedical datasets, mostly for segmentation tasks. +They are implemented in `torch_em.data.datasets`. See `scripts/datasets` for examples on how to visualize images from these datasets. -### Microscopy -- ASEM (`asem.py`): Segmentation of organelles in FIB-SEM cells. -- AxonDeepSeg (`axondeepseg.py`): Segmentation of myelinated axons in electron microscopy. -- MitoLab* (`cem.py`): - - CEM MitoLab: Segmentation of mitochondria in electron microscopy. - - CEM Mito Benchmark: Segmentation of mitochondria in 7 benchmark electron microscopy datasets. -- Covid IF (`covidif.py`): Segmentation of cells and nuclei in immunofluoroscence. -- CREMI (`cremi.py`): Segmentation of neurons in electron microscopy. -- Cell Tracking Challenge (`ctc.py`): Segmentation data for cell tracking challenge (consists of 10 datasets). -- DeepBacs (`deepbacs.py`): Segmentation of bacteria in light microscopy. -- DSB (`dsb.py`): Segmentation of nuclei in light microscopy. -- DynamicNuclearNet* (`dynamicnuclearnet.py`): Segmentation of nuclei in fluorescence microscopy. -- HPA (`hpa.py`): Segmentation of cells in light microscopy. -- ISBI (`isbi2012.py`): Segmentation of neurons in electron microscopy. -- Kasthuri (`kasthuri.py`): Segmentation of mitochondria in electron microscopy. -- LIVECell (`livecell.py`): Segmentation of cells in phase-contrast microscopy. -- Lucchi (`lucchi.py`): Segmentation of mitochondria in electron microscopy. -- MitoEM (`mitoem.py`): Segmentation of mitochondria in electron microscopy. -- Mouse Embryo (`mouse_embryo.py`): Segmentation of nuclei in confocal microscopy. -- NeurIPS CellSeg (`neurips_cell_seg.py`): Segmentation of cells in multi-modality light microscopy datasets. -- NucMM (`nuc_mm.py`): Segmentation of nuclei in electron microscopy and micro-CT. -- PlantSeg (`plantseg.py`): Segmentation of cells in confocal and light-sheet microscopy. -- Platynereis (`platynereis.py`): Segmentation of nuclei in electron microscopy. -- PNAS* (`pnas_arabidopsis.py`): TODO -- SNEMI (`snemi.py`): Segmentation of neurons in electron microscopy. -- Sponge EM (`sponge_em.py`): Segmentation of sponge cells and organelles in electron microscopy. -- TissueNet* (`tissuenet.py`): Segmentation of cellls in tissue imaged with light microscopy. -- UroCell (`uro_cell.py`): Segmentation of mitochondria and other organelles in electron microscopy. -- VNC (`vnc.py`): Segmentation of mitochondria in electron microscopy +## Available Datasets -### Histopathology +All datasets in `torch_em.data.datasets` are implemented according to the following logic: +- The function `get_..._data` downloads the respective datasets. 
Note that some datasets cannot be downloaded automatically. In these cases the function will raise an error with a message that explains how to download the data. +- The function `get_..._dataset` returns the PyTorch Dataset for the corresponding dataset. +- The function `get_..._dataloader` returns the PyTorch DataLoader for the corresponding dataset (see the usage example below). -- BCSS (`bcss.py`): Segmentation of breast cancer tissue in histopathology. -- Lizard* (`lizard.py`): Segmentation of nuclei in histopathology. -- MoNuSaC (`monusac.py`): Segmentation of multi-organ nuclei in histopathology. -- MoNuSeg (`monuseg.py`): Segmentation of multi-organ nuclei in histopathology. -- PanNuke (`pannuke.py`): Segmentation of nuclei in histopathology. +### Light Microscopy + +We provide several light microscopy datasets. See `torch_em.data.datasets.light_microscopy` for an overview. + +### Electron Microscopy + +We provide several electron microscopy datasets. See `torch_em.data.datasets.electron_microscopy` for an overview. + +### Histopathology + +We provide several histopathology datasets. See `torch_em.data.datasets.histopathology` for an overview. ### Medical Imaging -- AutoPET* (`medical/autopet.py`): Segmentation of lesions in whole-body FDG-PET/CT. -- BTCV* (`medical/btcv.py`): Segmentation of multiple organs in CT. +We provide several medical imaging datasets. See `torch_em.data.datasets.medical` for an overview. -### NOTE: -- \* - These datasets cannot be used out of the box (mostly because of missing automatic downloading). Please take a look at the scripts and the dataset object for details. ## How to create your own dataloader? @@ -157,3 +135,51 @@ dataset = RawImageCollectionDataset( # there are other optional parameters, see `torch_em.data.raw_image_collection_dataset.py` for details. ) ``` + + + diff --git a/doc/start_page.md b/doc/start_page.md new file mode 100644 index 00000000..1bca94a1 --- /dev/null +++ b/doc/start_page.md @@ -0,0 +1,3 @@ +[torch_em](https://github.com/constantinpape/torch-em) is a library for deep learning applied to microscopy images. It is built on top of [PyTorch](https://pytorch.org/). + +We are working on the documentation and will extend and improve it soon! diff --git a/experiments/unet-segmentation/dsb/train_boundaries.py b/experiments/unet-segmentation/dsb/train_boundaries.py index 645e336b..addd5c79 100644 --- a/experiments/unet-segmentation/dsb/train_boundaries.py +++ b/experiments/unet-segmentation/dsb/train_boundaries.py @@ -9,11 +9,16 @@ def train_boundaries(args): patch_shape = (1, 256, 256) train_loader = get_dsb_loader( - args.input, patch_shape, split="train", + args.input, patch_shape=patch_shape, split="train", download=True, boundaries=True, batch_size=args.batch_size ) + + # Uncomment this for checking the loader. + # from torch_em.util.debug import check_loader + # check_loader(train_loader, 4) + val_loader = get_dsb_loader( - args.input, patch_shape, split="test", + args.input, patch_shape=patch_shape, split="test", boundaries=True, batch_size=args.batch_size ) loss = torch_em.loss.DiceLoss()
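To make the `get_..._data` / `get_..._dataset` / `get_..._dataloader` logic from `doc/datasets_and_dataloaders.md` concrete, here is a minimal usage sketch for the DSB nucleus dataset. The keyword arguments follow the `get_dsb_loader` calls that appear elsewhere in this diff; the data path is a placeholder.

```python
from torch_em.data.datasets import get_dsb_loader

# Build a DataLoader for the DSB nucleus segmentation dataset.
# With download=True the data is fetched on the first call
# (the get_..._data step runs internally).
loader = get_dsb_loader(
    "./data/dsb",            # root folder where the data is stored
    patch_shape=(256, 256),  # shape of the patches sampled per iteration
    split="train",
    batch_size=1,
    download=True,
)

# The loader yields (raw, label) batches like any PyTorch DataLoader.
raw, labels = next(iter(loader))
print(raw.shape, labels.shape)
```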
diff --git a/notebooks/README.md b/notebooks/README.md index b07bf094..3a57bd95 100644 --- a/notebooks/README.md +++ b/notebooks/README.md @@ -1,3 +1,3 @@ # Jupyter Notebooks for `torch-em`-based Implementations and Tutorials -1. `tutorial_create_dataloaders.ipynb`: This notebook gives you a head-start on how to create your custom dataloaders for segmentation for most data structures. It's recommended to checkout `torch_em.data.datasets.README.md` for the first outlook before getting started. \ No newline at end of file +1. `tutorial_create_dataloaders.ipynb`: This notebook gives you a head start on creating custom dataloaders for segmentation for most data structures. It's recommended to check out https://github.com/constantinpape/torch-em/blob/main/doc/datasets_and_dataloaders.md for a first overview before getting started. diff --git a/scripts/datasets/check_axondeepseg.py b/scripts/datasets/check_axondeepseg.py deleted file mode 100644 index 70770da9..00000000 --- a/scripts/datasets/check_axondeepseg.py +++ /dev/null @@ -1,23 +0,0 @@ -from torch_em.data.datasets import get_axondeepseg_loader -from torch_em.util.debug import check_loader - - -ROOT = "/scratch/usr/nimanwai/data/axondeepseg" - - -def check_axondeepseg(): - loader = get_axondeepseg_loader( - ROOT, name="sem", patch_shape=(1024, 1024), batch_size=1, split="train", - one_hot_encoding=True, shuffle=True, download=True, val_fraction=0.1 - ) - check_loader(loader, 5, True, True, False, "sem_loader.png") - - loader = get_axondeepseg_loader( - ROOT, name="tem", patch_shape=(1024, 1024), batch_size=1, split="train", - one_hot_encoding=True, shuffle=True, download=True, val_fraction=0.1 - ) - check_loader(loader, 5, True, True, False, "tem_loader.png") - - -if __name__ == "__main__": - check_axondeepseg() diff --git a/scripts/datasets/check_busi.py b/scripts/datasets/check_busi.py new file mode 100644 index 00000000..a6430ce0 --- /dev/null +++ b/scripts/datasets/check_busi.py @@ -0,0 +1,22 @@ +from torch_em.util.debug import check_loader +from torch_em.data.datasets.medical import get_busi_loader + + +ROOT = "/media/anwai/ANWAI/data/busi" + + +def check_busi(): + loader = get_busi_loader( + path=ROOT, + patch_shape=(512, 512), + batch_size=2, + category=None, + resize_inputs=False, + download=True, + ) + + check_loader(loader, 8) + + +if __name__ == "__main__": + check_busi() diff --git a/scripts/datasets/check_drive.py b/scripts/datasets/check_drive.py new file mode 100644 index 00000000..f75ec5c9 --- /dev/null +++ b/scripts/datasets/check_drive.py @@ -0,0 +1,21 @@ +from torch_em.util.debug import check_loader +from torch_em.data.datasets.medical import get_drive_loader + + +ROOT = "/media/anwai/ANWAI/data/drive" + + +def check_drive(): + loader = get_drive_loader( + path=ROOT, + patch_shape=(256, 256), + batch_size=2, + resize_inputs=True, + download=True, + ) + + check_loader(loader, 8) + + +if __name__ == "__main__": + check_drive() diff --git a/scripts/datasets/check_hpa.py b/scripts/datasets/check_hpa.py deleted file mode 100644 index 770bd56f..00000000 --- a/scripts/datasets/check_hpa.py +++ /dev/null @@ -1,11 +0,0 @@ -from torch_em.data.datasets import get_hpa_segmentation_loader -from torch_em.util.debug import check_loader - - -def check_hpa(): - loader = get_hpa_segmentation_loader("./data/hpa", "train", (512, 512), 1, download=True) - check_loader(loader, 8, instance_labels=True) - - -if __name__ == "__main__": - check_hpa() diff --git a/scripts/datasets/check_kasthuri.py b/scripts/datasets/check_kasthuri.py deleted file mode 100644 index 4dfce7ff..00000000 --- a/scripts/datasets/check_kasthuri.py +++ /dev/null @@ -1,11 +0,0 @@ -from torch_em.data.datasets import get_kasthuri_loader -from torch_em.util.debug import check_loader - - -def check_kasthuri(): - loader = get_kasthuri_loader("./data/kasthuri", "train", (8, 512, 512), 1, download=True) - check_loader(loader, 8, instance_labels=True) - - -if __name__ == "__main__": - check_kasthuri() diff --git a/scripts/datasets/check_lizard.py b/scripts/datasets/check_lizard.py deleted
file mode 100644 index 71bf508d..00000000 --- a/scripts/datasets/check_lizard.py +++ /dev/null @@ -1,11 +0,0 @@ -from torch_em.data.datasets import get_lizard_loader -from torch_em.util.debug import check_loader - - -def check_lizard(): - loader = get_lizard_loader("./data/lizard", (512, 512), 1, download=True) - check_loader(loader, 8, rgb=True, instance_labels=True) - - -if __name__ == "__main__": - check_lizard() diff --git a/scripts/datasets/check_mouse_embryo.py b/scripts/datasets/check_mouse_embryo.py deleted file mode 100644 index 7f7b9b8c..00000000 --- a/scripts/datasets/check_mouse_embryo.py +++ /dev/null @@ -1,11 +0,0 @@ -from torch_em.data.datasets import get_mouse_embryo_loader -from torch_em.util.debug import check_loader - - -def check_mouse_embryo(): - loader = get_mouse_embryo_loader("./data/mouse_embryo", "nuclei", "train", (8, 512, 512), 1, download=True) - check_loader(loader, 8, instance_labels=True) - - -if __name__ == "__main__": - check_mouse_embryo() diff --git a/scripts/datasets/check_nuc_mm.py b/scripts/datasets/check_nuc_mm.py deleted file mode 100644 index 9cfb0467..00000000 --- a/scripts/datasets/check_nuc_mm.py +++ /dev/null @@ -1,20 +0,0 @@ -from torch_em.data.datasets import get_nuc_mm_loader -from torch_em.util.debug import check_loader - -NUC_MM_ROOT = "/scratch/usr/nimanwai/data/nuc_mm" - - -def check_nuc_mm(): - loader = get_nuc_mm_loader( - NUC_MM_ROOT, "mouse", "train", patch_shape=(1, 192, 192), batch_size=1, download=True - ) - check_loader(loader, 5, instance_labels=True, plt=True, save_path="mouse_loader.png") - - loader = get_nuc_mm_loader( - NUC_MM_ROOT, "zebrafish", "train", patch_shape=(1, 64, 64), batch_size=1, download=True - ) - check_loader(loader, 5, instance_labels=True, plt=True, save_path="zebrafish_loader.png") - - -if __name__ == "__main__": - check_nuc_mm() diff --git a/scripts/datasets/check_asem.py b/scripts/datasets/electron_microscopy/check_asem.py similarity index 100% rename from scripts/datasets/check_asem.py rename to scripts/datasets/electron_microscopy/check_asem.py diff --git a/scripts/datasets/electron_microscopy/check_axondeepseg.py b/scripts/datasets/electron_microscopy/check_axondeepseg.py new file mode 100644 index 00000000..9f26991b --- /dev/null +++ b/scripts/datasets/electron_microscopy/check_axondeepseg.py @@ -0,0 +1,35 @@ +import os +import sys + +from torch_em.data.datasets import get_axondeepseg_loader +from torch_em.util.debug import check_loader + +sys.path.append("..") + + +def check_axondeepseg(): + from util import ROOT, USE_NAPARI + + data_root = os.path.join(ROOT, "axondeepseg") + if USE_NAPARI: + use_plt = False + save_path = None, None + else: + use_plt = True + save_path = "sem_data.png", "tem_data.png" + + loader = get_axondeepseg_loader( + data_root, name="sem", patch_shape=(1024, 1024), batch_size=1, split="train", + one_hot_encoding=True, shuffle=True, download=True, val_fraction=0.1 + ) + check_loader(loader, 5, plt=use_plt, save_path=save_path[0]) + + loader = get_axondeepseg_loader( + data_root, name="tem", patch_shape=(1024, 1024), batch_size=1, split="train", + one_hot_encoding=True, shuffle=True, download=True, val_fraction=0.1 + ) + check_loader(loader, 5, plt=use_plt, save_path=save_path[1]) + + +if __name__ == "__main__": + check_axondeepseg() diff --git a/scripts/datasets/check_cem.py b/scripts/datasets/electron_microscopy/check_cem.py similarity index 97% rename from scripts/datasets/check_cem.py rename to scripts/datasets/electron_microscopy/check_cem.py index 1b6ad1ea..3cef55ca 100644 
--- a/scripts/datasets/check_cem.py +++ b/scripts/datasets/electron_microscopy/check_cem.py @@ -1,14 +1,15 @@ import os -import imageio.v3 as imageio +import sys from glob import glob +import imageio.v3 as imageio import numpy as np import torch_em from torch_em.data.datasets import cem from torch_em.util.debug import check_loader -# ROOT = "./data" -ROOT = "/scratch-grete/projects/nim00007/data/mitolab" +sys.path.append("..") +from util import ROOT def get_all_shapes(): diff --git a/scripts/datasets/check_cremi.py b/scripts/datasets/electron_microscopy/check_cremi.py similarity index 56% rename from scripts/datasets/check_cremi.py rename to scripts/datasets/electron_microscopy/check_cremi.py index 809cc32e..ca326090 100644 --- a/scripts/datasets/check_cremi.py +++ b/scripts/datasets/electron_microscopy/check_cremi.py @@ -1,9 +1,16 @@ +import os +import sys + from torch_em.data.datasets import get_cremi_loader from torch_em.util.debug import check_loader +sys.path.append("..") + def check_cremi(): - loader = get_cremi_loader("./data/cremi", (8, 512, 512), 1, download=True) + from util import ROOT + + loader = get_cremi_loader(os.path.join(ROOT, "cremi"), (8, 512, 512), 1, download=True) check_loader(loader, 8, instance_labels=True) diff --git a/scripts/datasets/electron_microscopy/check_deepict.py b/scripts/datasets/electron_microscopy/check_deepict.py new file mode 100644 index 00000000..32bf465d --- /dev/null +++ b/scripts/datasets/electron_microscopy/check_deepict.py @@ -0,0 +1,44 @@ +import os +import sys + +import h5py +import napari + +from torch_em.util.debug import check_loader +from torch_em.data.datasets.electron_microscopy import get_deepict_actin_loader +from torch_em.data import MinForegroundSampler + +sys.path.append("..") + + +def check_deepict_actin_volumes(): + from util import ROOT + + data_root = os.path.join(ROOT, "deepict") + + for dataset in ["00004", "00012"]: + path = os.path.join(data_root, "deepict_actin", f"{dataset}.h5") + with h5py.File(path, "r") as f: + raw = f["raw"][:] + actin_seg = f["/labels/actin"][:] + + v = napari.Viewer() + v.add_image(raw) + v.add_labels(actin_seg) + v.title = dataset + napari.run() + + +def check_deepict_actin_loader(): + from util import ROOT + + data_root = os.path.join(ROOT, "deepict") + loader = get_deepict_actin_loader( + data_root, (96, 398, 398), 1, download=True, sampler=MinForegroundSampler(min_fraction=0.025) + ) + check_loader(loader, 8, instance_labels=True) + + +if __name__ == "__main__": + check_deepict_actin_volumes() + # check_deepict_actin_loader() diff --git a/scripts/datasets/electron_microscopy/check_emneuron.py b/scripts/datasets/electron_microscopy/check_emneuron.py new file mode 100644 index 00000000..d714645e --- /dev/null +++ b/scripts/datasets/electron_microscopy/check_emneuron.py @@ -0,0 +1,23 @@ +import os +import sys + +from torch_em.util.debug import check_loader +from torch_em.data.datasets import get_emneuron_loader + +sys.path.append("..") + + +def check_emneuron(): + from util import ROOT + + loader = get_emneuron_loader( + path=os.path.join(ROOT, "emneuron"), + batch_size=1, + patch_shape=(8, 512, 512), + split="val", + ) + check_loader(loader, 8, instance_labels=True, plt=True, save_path="./test.png") + + +if __name__ == "__main__": + check_emneuron() diff --git a/scripts/datasets/check_isbi.py b/scripts/datasets/electron_microscopy/check_isbi.py similarity index 52% rename from scripts/datasets/check_isbi.py rename to scripts/datasets/electron_microscopy/check_isbi.py index 5fbc3715..4c47735c 
100644 --- a/scripts/datasets/check_isbi.py +++ b/scripts/datasets/electron_microscopy/check_isbi.py @@ -1,9 +1,17 @@ +import os +import sys + from torch_em.data.datasets import get_isbi_loader from torch_em.util.debug import check_loader +sys.path.append("..") + def check_isbi(): - loader = get_isbi_loader("./data/isbi.h5", (8, 512, 512), 1, download=True) + from util import ROOT + + data_path = os.path.join(ROOT, "isbi.h5") + loader = get_isbi_loader(data_path, (8, 512, 512), 1, download=True) check_loader(loader, 8, instance_labels=True) diff --git a/scripts/datasets/electron_microscopy/check_kasthuri.py b/scripts/datasets/electron_microscopy/check_kasthuri.py new file mode 100644 index 00000000..69358822 --- /dev/null +++ b/scripts/datasets/electron_microscopy/check_kasthuri.py @@ -0,0 +1,18 @@ +import os +import sys + +from torch_em.data.datasets import get_kasthuri_loader +from torch_em.util.debug import check_loader + +sys.path.append("..") + + +def check_kasthuri(): + from util import ROOT, USE_NAPARI + + loader = get_kasthuri_loader(os.path.join(ROOT, "kasthuri"), "train", (8, 512, 512), 1, download=True) + check_loader(loader, 8, instance_labels=True, plt=not USE_NAPARI) + + +if __name__ == "__main__": + check_kasthuri() diff --git a/scripts/datasets/check_lucchi.py b/scripts/datasets/electron_microscopy/check_lucchi.py similarity index 55% rename from scripts/datasets/check_lucchi.py rename to scripts/datasets/electron_microscopy/check_lucchi.py index 099fc328..6b511a13 100644 --- a/scripts/datasets/check_lucchi.py +++ b/scripts/datasets/electron_microscopy/check_lucchi.py @@ -1,9 +1,16 @@ +import os +import sys + from torch_em.data.datasets import get_lucchi_loader from torch_em.util.debug import check_loader +sys.path.append("..") + def check_lucchi(): - loader = get_lucchi_loader("./data/lucchi", "train", (8, 512, 512), 1, download=True) + from util import ROOT + + loader = get_lucchi_loader(os.path.join(ROOT, "lucchi"), "train", (8, 512, 512), 1, download=True) check_loader(loader, 8, instance_labels=True) diff --git a/scripts/datasets/electron_microscopy/check_mitoem.py b/scripts/datasets/electron_microscopy/check_mitoem.py new file mode 100644 index 00000000..719ce0a2 --- /dev/null +++ b/scripts/datasets/electron_microscopy/check_mitoem.py @@ -0,0 +1,19 @@ +import os +import sys + +from torch_em.data.datasets import get_mitoem_loader +from torch_em.util.debug import check_loader + +sys.path.append("..") + + +def check_mitoem(): + from util import ROOT + + loader = get_mitoem_loader(os.path.join(ROOT, "mitoem"), splits=["train"], patch_shape=(8, 512, 512), + batch_size=1, download=True) + check_loader(loader, 8, instance_labels=True) + + +if __name__ == "__main__": + check_mitoem() diff --git a/scripts/datasets/electron_microscopy/check_nuc_mm.py b/scripts/datasets/electron_microscopy/check_nuc_mm.py new file mode 100644 index 00000000..9b830458 --- /dev/null +++ b/scripts/datasets/electron_microscopy/check_nuc_mm.py @@ -0,0 +1,33 @@ +import os +import sys + +from torch_em.data.datasets import get_nuc_mm_loader +from torch_em.util.debug import check_loader + +sys.path.append("..") + + +def check_nuc_mm(): + from util import ROOT, USE_NAPARI + + nuc_mm_root = os.path.join(ROOT, "nuc_mm") + if USE_NAPARI: + use_plt = False + save_path = None, None + else: + use_plt = True + save_path = "mouse_data.png", "zebrafish_data.png" + + loader = get_nuc_mm_loader( + nuc_mm_root, "mouse", "train", patch_shape=(1, 192, 192), batch_size=1, download=True + ) + check_loader(loader, 5, 
instance_labels=True, plt=use_plt, save_path=save_path[0]) + + loader = get_nuc_mm_loader( + nuc_mm_root, "zebrafish", "train", patch_shape=(1, 64, 64), batch_size=1, download=True + ) + check_loader(loader, 5, instance_labels=True, plt=use_plt, save_path=save_path[1]) + + +if __name__ == "__main__": + check_nuc_mm() diff --git a/scripts/datasets/check_platynereis.py b/scripts/datasets/electron_microscopy/check_platynereis.py similarity index 50% rename from scripts/datasets/check_platynereis.py rename to scripts/datasets/electron_microscopy/check_platynereis.py index a9bf05de..5de77d74 100644 --- a/scripts/datasets/check_platynereis.py +++ b/scripts/datasets/electron_microscopy/check_platynereis.py @@ -1,27 +1,35 @@ -import torch_em.data.datasets.platynereis as platy +import os +import sys + +import torch_em.data.datasets.electron_microscopy.platynereis as platy from torch_em.util.debug import check_loader +sys.path.append("..") + def check_platynereis(): + from util import ROOT + + data_root = os.path.join(ROOT, "platy") # check nucleus loader print("Check nucleus loader") - loader = platy.get_platynereis_nuclei_loader("./data/platy", (8, 512, 512), 1, download=True) + loader = platy.get_platynereis_nuclei_loader(data_root, (8, 512, 512), 1, download=True) check_loader(loader, 8, instance_labels=True) # check cell loader print("Check cell loader") - loader = platy.get_platynereis_cell_loader("./data/platy", (8, 512, 512), 1, download=True) + loader = platy.get_platynereis_cell_loader(data_root, (8, 512, 512), 1, download=True) check_loader(loader, 8, instance_labels=True) # check cilia loader print("Check cilia loader") - loader = platy.get_platynereis_cilia_loader("./data/platy", (8, 512, 512), 1, download=True) + loader = platy.get_platynereis_cilia_loader(data_root, (8, 512, 512), 1, download=True) check_loader(loader, 8, instance_labels=True) # check cuticle loader print("Check cuticle loader") - loader = platy.get_platynereis_cuticle_loader("./data/platy", (8, 512, 512), 1, download=True) + loader = platy.get_platynereis_cuticle_loader(data_root, (8, 512, 512), 1, download=True) check_loader(loader, 8, instance_labels=True) diff --git a/scripts/datasets/check_snemi.py b/scripts/datasets/electron_microscopy/check_snemi.py similarity index 56% rename from scripts/datasets/check_snemi.py rename to scripts/datasets/electron_microscopy/check_snemi.py index 58e91fc9..54e33bae 100644 --- a/scripts/datasets/check_snemi.py +++ b/scripts/datasets/electron_microscopy/check_snemi.py @@ -1,9 +1,16 @@ +import os +import sys + from torch_em.data.datasets import get_snemi_loader from torch_em.util.debug import check_loader +sys.path.append("..") + def check_snemi(): - loader = get_snemi_loader("./data/snemi", (8, 512, 512), 1, download=True) + from util import ROOT + + loader = get_snemi_loader(os.path.join(ROOT, "snemi"), (8, 512, 512), 1, download=True) check_loader(loader, 8, instance_labels=True) diff --git a/scripts/datasets/check_sponge_em.py b/scripts/datasets/electron_microscopy/check_sponge_em.py similarity index 55% rename from scripts/datasets/check_sponge_em.py rename to scripts/datasets/electron_microscopy/check_sponge_em.py index fb10f631..549b7fca 100644 --- a/scripts/datasets/check_sponge_em.py +++ b/scripts/datasets/electron_microscopy/check_sponge_em.py @@ -1,9 +1,16 @@ +import os +import sys + from torch_em.data.datasets import get_sponge_em_loader from torch_em.util.debug import check_loader +sys.path.append("..") + def check_sponge_em(): - loader = 
get_sponge_em_loader("./data/sponge_em", "instances", (8, 512, 512), 1, download=True) + from util import ROOT + + loader = get_sponge_em_loader(os.path.join(ROOT, "sponge_em"), "instances", (8, 512, 512), 1, download=True) check_loader(loader, 8, instance_labels=True) diff --git a/scripts/datasets/check_uro_cell.py b/scripts/datasets/electron_microscopy/check_uro_cell.py similarity index 55% rename from scripts/datasets/check_uro_cell.py rename to scripts/datasets/electron_microscopy/check_uro_cell.py index 05361c0a..4c9baf96 100644 --- a/scripts/datasets/check_uro_cell.py +++ b/scripts/datasets/electron_microscopy/check_uro_cell.py @@ -1,9 +1,16 @@ +import os +import sys + from torch_em.data.datasets import get_uro_cell_loader from torch_em.util.debug import check_loader +sys.path.append("..") + def check_uro_cell(): - loader = get_uro_cell_loader("./data/uro_cell", "mito", (8, 512, 512), 1, download=True) + from util import ROOT + + loader = get_uro_cell_loader(os.path.join(ROOT, "uro_cell"), "mito", (8, 512, 512), 1, download=True) check_loader(loader, 8, instance_labels=True) diff --git a/scripts/datasets/check_vnc.py b/scripts/datasets/electron_microscopy/check_vnc.py similarity index 56% rename from scripts/datasets/check_vnc.py rename to scripts/datasets/electron_microscopy/check_vnc.py index 57b20075..930619bf 100644 --- a/scripts/datasets/check_vnc.py +++ b/scripts/datasets/electron_microscopy/check_vnc.py @@ -1,9 +1,17 @@ +import os +import sys + + from torch_em.data.datasets import get_vnc_mito_loader from torch_em.util.debug import check_loader +sys.path.append("..") + def check_vnc(): - loader = get_vnc_mito_loader("./data/vnc", (8, 512, 512), 1, download=True) + from util import ROOT + + loader = get_vnc_mito_loader(os.path.join(ROOT, "vnc"), (8, 512, 512), 1, download=True) check_loader(loader, 8, instance_labels=True) diff --git a/scripts/datasets/check_bcss.py b/scripts/datasets/histopathology/check_bcss.py similarity index 100% rename from scripts/datasets/check_bcss.py rename to scripts/datasets/histopathology/check_bcss.py diff --git a/scripts/datasets/histopathology/check_cryonuseg.py b/scripts/datasets/histopathology/check_cryonuseg.py new file mode 100644 index 00000000..1aa1c22b --- /dev/null +++ b/scripts/datasets/histopathology/check_cryonuseg.py @@ -0,0 +1,25 @@ +import os +import sys + +from torch_em.util.debug import check_loader +from torch_em.data.datasets import get_cryonuseg_loader + + +sys.path.append("..") + + +def check_cryonuseg(): + from util import ROOT + + loader = get_cryonuseg_loader( + path=os.path.join(ROOT, "cryonuseg"), + patch_shape=(1, 512, 512), + batch_size=1, + rater="b3", + download=True, + ) + check_loader(loader, 8, rgb=True, instance_labels=True) + + +if __name__ == "__main__": + check_cryonuseg() diff --git a/scripts/datasets/histopathology/check_lizard.py b/scripts/datasets/histopathology/check_lizard.py new file mode 100644 index 00000000..77d571a1 --- /dev/null +++ b/scripts/datasets/histopathology/check_lizard.py @@ -0,0 +1,24 @@ +import os +import sys + +from torch_em.data.datasets import get_lizard_loader +from torch_em.util.debug import check_loader + + +sys.path.append("..") + + +def check_lizard(): + from util import ROOT + + loader = get_lizard_loader( + path=os.path.join(ROOT, "lizard"), + patch_shape=(512, 512), + batch_size=1, + download=True + ) + check_loader(loader, 8, rgb=True, instance_labels=True) + + +if __name__ == "__main__": + check_lizard() diff --git a/scripts/datasets/check_monusac.py 
b/scripts/datasets/histopathology/check_monusac.py similarity index 100% rename from scripts/datasets/check_monusac.py rename to scripts/datasets/histopathology/check_monusac.py diff --git a/scripts/datasets/check_monuseg.py b/scripts/datasets/histopathology/check_monuseg.py similarity index 100% rename from scripts/datasets/check_monuseg.py rename to scripts/datasets/histopathology/check_monuseg.py diff --git a/scripts/datasets/check_pannuke.py b/scripts/datasets/histopathology/check_pannuke.py similarity index 100% rename from scripts/datasets/check_pannuke.py rename to scripts/datasets/histopathology/check_pannuke.py diff --git a/scripts/datasets/light_microscopy/check_cellpose.py b/scripts/datasets/light_microscopy/check_cellpose.py new file mode 100644 index 00000000..2001cca9 --- /dev/null +++ b/scripts/datasets/light_microscopy/check_cellpose.py @@ -0,0 +1,24 @@ +import os +import sys + +from torch_em.util.debug import check_loader +from torch_em.data.datasets import get_cellpose_loader + +sys.path.append("..") + + +def check_cellpose(): + from util import ROOT + + loader = get_cellpose_loader( + path=os.path.join(ROOT, "cellpose"), + split="train", + patch_shape=(512, 512), + batch_size=1, + choice=None, + ) + check_loader(loader, 8, instance_labels=True) + + +if __name__ == "__main__": + check_cellpose() diff --git a/scripts/datasets/light_microscopy/check_cellseg3d.py b/scripts/datasets/light_microscopy/check_cellseg3d.py new file mode 100644 index 00000000..ea635aa1 --- /dev/null +++ b/scripts/datasets/light_microscopy/check_cellseg3d.py @@ -0,0 +1,18 @@ +import os +import sys + +from torch_em.data.datasets import get_cellseg_3d_loader +from torch_em.util.debug import check_loader + +sys.path.append("..") + + +def check_cellseg_3d(): + from util import ROOT + + loader = get_cellseg_3d_loader(os.path.join(ROOT, "cellseg_3d"), (32, 256, 256), batch_size=1, download=True) + check_loader(loader, 8, instance_labels=True) + + +if __name__ == "__main__": + check_cellseg_3d() diff --git a/scripts/datasets/check_covid_if.py b/scripts/datasets/light_microscopy/check_covid_if.py similarity index 57% rename from scripts/datasets/check_covid_if.py rename to scripts/datasets/light_microscopy/check_covid_if.py index ac439865..4496bfd5 100644 --- a/scripts/datasets/check_covid_if.py +++ b/scripts/datasets/light_microscopy/check_covid_if.py @@ -1,9 +1,16 @@ +import os +import sys + from torch_em.data.datasets import get_covid_if_loader from torch_em.util.debug import check_loader +sys.path.append("..") + def check_covid_if(): - loader = get_covid_if_loader("./data/covid-if", (512, 512), 1, download=True) + from util import ROOT + + loader = get_covid_if_loader(os.path.join(ROOT, "covid-if"), (512, 512), 1, download=True) check_loader(loader, 8, instance_labels=True) diff --git a/scripts/datasets/check_ctc.py b/scripts/datasets/light_microscopy/check_ctc.py similarity index 68% rename from scripts/datasets/check_ctc.py rename to scripts/datasets/light_microscopy/check_ctc.py index 0d320c69..c94a91b9 100644 --- a/scripts/datasets/check_ctc.py +++ b/scripts/datasets/light_microscopy/check_ctc.py @@ -1,8 +1,11 @@ -from torch_em.data.datasets.ctc import get_ctc_segmentation_loader, CTC_CHECKSUMS +import os +import sys + +from torch_em.data.datasets.light_microscopy.ctc import get_ctc_segmentation_loader, CTC_CHECKSUMS from torch_em.util.debug import check_loader from torch_em.data.sampler import MinInstanceSampler -ROOT = "/home/anwai/data/ctc/" +sys.path.append("..") # Some of the datasets have 
partial sparse labels: @@ -10,11 +13,14 @@ # - Fluo-N2DL-HeLa # Maybe depends on the split?! def check_ctc_segmentation(split): + from util import ROOT, USE_NAPARI + + data_root = os.path.join(ROOT, "ctc") ctc_dataset_names = list(CTC_CHECKSUMS["train"].keys()) for name in ctc_dataset_names: print("Checking dataset", name) loader = get_ctc_segmentation_loader( - path=ROOT, + path=data_root, dataset_name=name, patch_shape=(1, 512, 512), batch_size=1, @@ -22,7 +28,7 @@ def check_ctc_segmentation(split): split=split, sampler=MinInstanceSampler() ) - check_loader(loader, 8, plt=True) + check_loader(loader, 8, plt=not USE_NAPARI, instance_labels=True) if __name__ == "__main__": diff --git a/scripts/datasets/check_deepbacs.py b/scripts/datasets/light_microscopy/check_deepbacs.py similarity index 63% rename from scripts/datasets/check_deepbacs.py rename to scripts/datasets/light_microscopy/check_deepbacs.py index 287b6530..fd5bab8a 100644 --- a/scripts/datasets/check_deepbacs.py +++ b/scripts/datasets/light_microscopy/check_deepbacs.py @@ -1,9 +1,16 @@ +import os +import sys + from torch_em.data.datasets import get_deepbacs_loader from torch_em.util.debug import check_loader +sys.path.append("..") + def check_deepbacs(): - loader = get_deepbacs_loader("./deepbacs", "test", bac_type="mixed", download=True, + from util import ROOT + + loader = get_deepbacs_loader(os.path.join(ROOT, "deepbacs"), "test", bac_type="mixed", download=True, patch_shape=(256, 256), batch_size=1, shuffle=True) check_loader(loader, 15, instance_labels=True) diff --git a/scripts/datasets/light_microscopy/check_dic_hepg2.py b/scripts/datasets/light_microscopy/check_dic_hepg2.py new file mode 100644 index 00000000..a91d1d40 --- /dev/null +++ b/scripts/datasets/light_microscopy/check_dic_hepg2.py @@ -0,0 +1,25 @@ +import os +import sys + +from torch_em.util.debug import check_loader +from torch_em.data.datasets import get_dic_hepg2_loader + + +sys.path.append("..") + + +def check_dic_hepg2(): + from util import ROOT + + loader = get_dic_hepg2_loader( + path=os.path.join(ROOT, "dic_hepg2"), + split="test", + patch_shape=(512, 512), + batch_size=2, + download=True, + ) + check_loader(loader, 8) + + +if __name__ == "__main__": + check_dic_hepg2() diff --git a/scripts/datasets/check_dsb.py b/scripts/datasets/light_microscopy/check_dsb.py similarity index 55% rename from scripts/datasets/check_dsb.py rename to scripts/datasets/light_microscopy/check_dsb.py index 1bac6f93..c0962247 100644 --- a/scripts/datasets/check_dsb.py +++ b/scripts/datasets/light_microscopy/check_dsb.py @@ -1,9 +1,16 @@ +import os +import sys + from torch_em.data.datasets import get_dsb_loader from torch_em.util.debug import check_loader +sys.path.append("..") + def check_dsb(): - loader = get_dsb_loader("./data/dsb", "train", (256, 256), 1, download=True) + from util import ROOT + + loader = get_dsb_loader(os.path.join(ROOT, "dsb"), "train", (256, 256), 1, download=True) check_loader(loader, 8, instance_labels=True) diff --git a/scripts/datasets/check_dynamicnuclearnet.py b/scripts/datasets/light_microscopy/check_dynamicnuclearnet.py similarity index 100% rename from scripts/datasets/check_dynamicnuclearnet.py rename to scripts/datasets/light_microscopy/check_dynamicnuclearnet.py diff --git a/scripts/datasets/light_microscopy/check_embedseg.py b/scripts/datasets/light_microscopy/check_embedseg.py new file mode 100644 index 00000000..b90aeb71 --- /dev/null +++ b/scripts/datasets/light_microscopy/check_embedseg.py @@ -0,0 +1,29 @@ +import os +import sys + +from 
torch_em.data.datasets import get_embedseg_loader +from torch_em.util.debug import check_loader + +sys.path.append("..") + + +def check_embedseg(): + from util import ROOT + + names = [ + "Mouse-Organoid-Cells-CBG", + "Mouse-Skull-Nuclei-CBG", + "Platynereis-ISH-Nuclei-CBG", + "Platynereis-Nuclei-CBG", + ] + + patch_shape = (32, 384, 384) + for name in names: + loader = get_embedseg_loader( + os.path.join(ROOT, "embedseg"), name=name, patch_shape=patch_shape, batch_size=1, download=True + ) + check_loader(loader, 2, instance_labels=True) + + +if __name__ == "__main__": + check_embedseg() diff --git a/scripts/datasets/light_microscopy/check_go_nuclear.py b/scripts/datasets/light_microscopy/check_go_nuclear.py new file mode 100644 index 00000000..176fdc03 --- /dev/null +++ b/scripts/datasets/light_microscopy/check_go_nuclear.py @@ -0,0 +1,27 @@ +import os +import sys + +from torch_em.data.datasets import get_gonuclear_loader +from torch_em.util.debug import check_loader + +sys.path.append("..") + + +def check_gonuclear(): + from util import ROOT + + patch_shape = (48, 384, 384) + + loader = get_gonuclear_loader( + os.path.join(ROOT, "gonuclear"), patch_shape, segmentation_task="nuclei", batch_size=1, download=True + ) + check_loader(loader, 5, instance_labels=True) + + loader = get_gonuclear_loader( + os.path.join(ROOT, "gonuclear"), patch_shape, segmentation_task="cells", batch_size=1, download=True + ) + check_loader(loader, 5, instance_labels=True) + + +if __name__ == "__main__": + check_gonuclear() diff --git a/scripts/datasets/light_microscopy/check_hpa.py b/scripts/datasets/light_microscopy/check_hpa.py new file mode 100644 index 00000000..20316905 --- /dev/null +++ b/scripts/datasets/light_microscopy/check_hpa.py @@ -0,0 +1,25 @@ +import os +import sys + +from torch_em.data.datasets import get_hpa_segmentation_loader +from torch_em.util.debug import check_loader + +sys.path.append("..") + + +def check_hpa(): + from util import ROOT + + loader = get_hpa_segmentation_loader( + path=os.path.join(ROOT, "hpa"), + split="train", + patch_shape=(1024, 1024), + batch_size=1, + channels=["protein", "er"], + download=True, + ) + check_loader(loader, 8, instance_labels=True) + + +if __name__ == "__main__": + check_hpa() diff --git a/scripts/datasets/check_livecell.py b/scripts/datasets/light_microscopy/check_livecell.py similarity index 51% rename from scripts/datasets/check_livecell.py rename to scripts/datasets/light_microscopy/check_livecell.py index bb2d15c9..63ce6990 100644 --- a/scripts/datasets/check_livecell.py +++ b/scripts/datasets/light_microscopy/check_livecell.py @@ -1,11 +1,17 @@ +import os +import sys + from torch_em.data.datasets import get_livecell_loader from torch_em.util.debug import check_loader -LIVECELL_ROOT = "/home/pape/Work/data/incu_cyte/livecell" +sys.path.append("..") def check_livecell(): - loader = get_livecell_loader(LIVECELL_ROOT, "train", (512, 512), 1) + from util import ROOT + + livecell_root = os.path.join(ROOT, "livecell") + loader = get_livecell_loader(livecell_root, "train", (512, 512), 1, download=True) check_loader(loader, 15, instance_labels=True) diff --git a/scripts/datasets/light_microscopy/check_mouse_embryo.py b/scripts/datasets/light_microscopy/check_mouse_embryo.py new file mode 100644 index 00000000..16148579 --- /dev/null +++ b/scripts/datasets/light_microscopy/check_mouse_embryo.py @@ -0,0 +1,22 @@ +import os +import sys + +from torch_em.data.datasets import get_mouse_embryo_loader +from torch_em.util.debug import check_loader + 
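+# NOTE: the check scripts in these subfolders share a small helper module,
+# scripts/datasets/util.py (not shown in this hunk), which defines the data root
+# ROOT and a USE_NAPARI flag; the sys.path.append below makes it importable.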
+sys.path.append("..") + + +def check_mouse_embryo(): + from util import ROOT + + data_root = os.path.join(ROOT, "mouse_embryo") + loader = get_mouse_embryo_loader(data_root, "nuclei", "train", (8, 512, 512), 1, download=True) + check_loader(loader, 8, instance_labels=True) + + loader = get_mouse_embryo_loader(data_root, "membrane", "train", (8, 512, 512), 1, download=True) + check_loader(loader, 8, instance_labels=True) + + +if __name__ == "__main__": + check_mouse_embryo() diff --git a/scripts/datasets/check_neurips_cellseg.py b/scripts/datasets/light_microscopy/check_neurips_cellseg.py similarity index 50% rename from scripts/datasets/check_neurips_cellseg.py rename to scripts/datasets/light_microscopy/check_neurips_cellseg.py index 32ac516c..620b1267 100644 --- a/scripts/datasets/check_neurips_cellseg.py +++ b/scripts/datasets/light_microscopy/check_neurips_cellseg.py @@ -1,16 +1,23 @@ +import os +import sys + from torch_em.data.datasets import ( get_neurips_cellseg_supervised_loader, get_neurips_cellseg_unsupervised_loader ) from torch_em.util.debug import check_loader -NEURIPS_ROOT = "/home/pape/Work/data/neurips-cell-seg" +sys.path.append("..") def check_neurips(): - loader = get_neurips_cellseg_supervised_loader(NEURIPS_ROOT, "train", (512, 512), 1) + from util import ROOT + + neurips_root = os.path.join(ROOT, "neurips-cell-seg") + + loader = get_neurips_cellseg_supervised_loader(neurips_root, "train", (512, 512), 1, download=True) check_loader(loader, 15, instance_labels=True, rgb=True) - loader = get_neurips_cellseg_unsupervised_loader(NEURIPS_ROOT, (512, 512), 1) + loader = get_neurips_cellseg_unsupervised_loader(neurips_root, (512, 512), 1, download=True) check_loader(loader, 15, rgb=True) diff --git a/scripts/datasets/light_microscopy/check_omnipose.py b/scripts/datasets/light_microscopy/check_omnipose.py new file mode 100644 index 00000000..0f9db7fd --- /dev/null +++ b/scripts/datasets/light_microscopy/check_omnipose.py @@ -0,0 +1,28 @@ +import os +import sys + +from torch_em.util.debug import check_loader +from torch_em.data import MinInstanceSampler +from torch_em.data.datasets import get_omnipose_loader + +sys.path.append("..") + + +def check_omnipose(): + from util import ROOT + + loader = get_omnipose_loader( + path=os.path.join(ROOT, "omnipose"), + batch_size=1, + patch_shape=(1024, 1024), + split="train", + data_choice=None, + sampler=MinInstanceSampler(), + shuffle=True, + download=True, + ) + check_loader(loader, 8, instance_labels=True) + + +if __name__ == "__main__": + check_omnipose() diff --git a/scripts/datasets/light_microscopy/check_orgasegment.py b/scripts/datasets/light_microscopy/check_orgasegment.py new file mode 100644 index 00000000..c49d45a3 --- /dev/null +++ b/scripts/datasets/light_microscopy/check_orgasegment.py @@ -0,0 +1,24 @@ +from torch_em.util.debug import check_loader +from torch_em.data.datasets.light_microscopy import get_orgasegment_loader + + +ROOT = "/media/anwai/ANWAI/data/orgasegment" + + +def check_orgasegment(): + loader = get_orgasegment_loader( + path=ROOT, + split="val", + patch_shape=(512, 512), + batch_size=1, + download=True, + ) + check_loader(loader, 8, instance_labels=True) + + +def main(): + check_orgasegment() + + +if __name__ == "__main__": + main() diff --git a/scripts/datasets/light_microscopy/check_plantseg.py b/scripts/datasets/light_microscopy/check_plantseg.py new file mode 100644 index 00000000..e8683e9b --- /dev/null +++ b/scripts/datasets/light_microscopy/check_plantseg.py @@ -0,0 +1,32 @@ +import os +import sys + 
+from torch_em.data.datasets import get_plantseg_loader +from torch_em.util.debug import check_loader + +sys.path.append("..") + + +def check_plantseg(): + from util import ROOT + + plantseg_root = os.path.join(ROOT, "plantseg") + + loader = get_plantseg_loader( + plantseg_root, name="root", split="train", patch_shape=(8, 512, 512), batch_size=1, download=True + ) + check_loader(loader, 8, instance_labels=True) + + loader = get_plantseg_loader( + plantseg_root, name="ovules", split="train", patch_shape=(8, 512, 512), batch_size=1, download=True + ) + check_loader(loader, 8, instance_labels=True) + + loader = get_plantseg_loader( + plantseg_root, name="nuclei", split="train", patch_shape=(8, 512, 512), batch_size=1, download=True + ) + check_loader(loader, 8, instance_labels=True) + + +if __name__ == "__main__": + check_plantseg() diff --git a/scripts/datasets/check_tissuenet.py b/scripts/datasets/light_microscopy/check_tissuenet.py similarity index 80% rename from scripts/datasets/check_tissuenet.py rename to scripts/datasets/light_microscopy/check_tissuenet.py index bdfaa58f..2f70465b 100644 --- a/scripts/datasets/check_tissuenet.py +++ b/scripts/datasets/light_microscopy/check_tissuenet.py @@ -1,10 +1,13 @@ +import os +import sys + import numpy as np from torch_em.transform.raw import standardize, normalize_percentile from torch_em.data.datasets import get_tissuenet_loader from torch_em.util.debug import check_loader -TISSUENET_ROOT = "/home/pape/Work/data/tissuenet" +sys.path.append("..") def raw_trafo(raw): @@ -17,9 +20,11 @@ def raw_trafo(raw): # NOTE: the tissuenet data cannot be downloaded automatically. # you need to download it yourself from https://datasets.deepcell.org/data def check_tissuenet(): - # set this path to where you have downloaded the tissuenet data + from util import ROOT + + tissuenet_root = os.path.join(ROOT, "tissuenet") loader = get_tissuenet_loader( - TISSUENET_ROOT, "train", raw_channel="rgb", label_channel="cell", + tissuenet_root, "train", raw_channel="rgb", label_channel="cell", patch_shape=(512, 512), batch_size=1, shuffle=True, raw_transform=raw_trafo ) diff --git a/scripts/datasets/light_microscopy/check_vgg_hela.py b/scripts/datasets/light_microscopy/check_vgg_hela.py new file mode 100644 index 00000000..69d7c5c2 --- /dev/null +++ b/scripts/datasets/light_microscopy/check_vgg_hela.py @@ -0,0 +1,18 @@ +import os +import sys + +from torch_em.data.datasets import get_vgg_hela_loader +from torch_em.util.debug import check_loader + +sys.path.append("..") + + +def check_vgg_hela(): + from util import ROOT + + loader = get_vgg_hela_loader(os.path.join(ROOT, "hela-vgg"), "train", (1, 256, 256), 1, download=True) + check_loader(loader, 8, instance_labels=True) + + +if __name__ == "__main__": + check_vgg_hela() diff --git a/scripts/datasets/medical/check_acdc.py b/scripts/datasets/medical/check_acdc.py new file mode 100644 index 00000000..81a073ed --- /dev/null +++ b/scripts/datasets/medical/check_acdc.py @@ -0,0 +1,23 @@ +from torch_em.util.debug import check_loader +from torch_em.data.datasets.medical import get_acdc_loader +from torch_em.data import MinInstanceSampler + + +ROOT = "/media/anwai/ANWAI/data/acdc" + + +def check_acdc(): + loader = get_acdc_loader( + path=ROOT, + patch_shape=(4, 256, 256), + batch_size=2, + split="train", + download=True, + sampler=MinInstanceSampler(min_num_instances=4), + ) + + check_loader(loader, 8) + + +if __name__ == "__main__": + check_acdc() diff --git a/scripts/datasets/medical/check_acouslic_ai.py 
b/scripts/datasets/medical/check_acouslic_ai.py new file mode 100644 index 00000000..ab634be1 --- /dev/null +++ b/scripts/datasets/medical/check_acouslic_ai.py @@ -0,0 +1,29 @@ +import os +import sys + +from torch_em.data import MinInstanceSampler +from torch_em.util.debug import check_loader +from torch_em.data.datasets.medical import get_acouslic_ai_loader + + +sys.path.append("..") + + +def check_acouslic_ai(): + from util import ROOT + + loader = get_acouslic_ai_loader( + path=os.path.join(ROOT, "acouslic_ai"), + patch_shape=(1, 512, 512), + ndim=2, + batch_size=1, + resize_inputs=False, + download=True, + sampler=MinInstanceSampler(), + ) + + check_loader(loader, 8, plt=True, save_path="./test.png") + + +if __name__ == "__main__": + check_acouslic_ai() diff --git a/scripts/datasets/medical/check_amos.py b/scripts/datasets/medical/check_amos.py new file mode 100644 index 00000000..98357854 --- /dev/null +++ b/scripts/datasets/medical/check_amos.py @@ -0,0 +1,24 @@ +from torch_em.util.debug import check_loader +from torch_em.data import MinInstanceSampler +from torch_em.data.datasets.medical import get_amos_loader + +ROOT = "/media/anwai/ANWAI/data/amos" + + +def check_amos(): + loader = get_amos_loader( + path=ROOT, + split="train", + patch_shape=(1, 512, 512), + modality="mri", + ndim=2, + batch_size=2, + download=True, + sampler=MinInstanceSampler(min_num_instances=3), + resize_inputs=False, + ) + check_loader(loader, 8) + + +if __name__ == "__main__": + check_amos() diff --git a/scripts/datasets/check_autopet.py b/scripts/datasets/medical/check_autopet.py similarity index 100% rename from scripts/datasets/check_autopet.py rename to scripts/datasets/medical/check_autopet.py diff --git a/scripts/datasets/check_btcv.py b/scripts/datasets/medical/check_btcv.py similarity index 100% rename from scripts/datasets/check_btcv.py rename to scripts/datasets/medical/check_btcv.py diff --git a/scripts/datasets/medical/check_camus.py b/scripts/datasets/medical/check_camus.py new file mode 100644 index 00000000..79fb7f87 --- /dev/null +++ b/scripts/datasets/medical/check_camus.py @@ -0,0 +1,21 @@ +from torch_em.util.debug import check_loader +from torch_em.data.datasets.medical import get_camus_loader + + +ROOT = "/media/anwai/ANWAI/data/camus" + + +def check_camus(): + loader = get_camus_loader( + path=ROOT, + patch_shape=(1, 512, 512), + batch_size=2, + chamber=2, + resize_inputs=True, + download=True, + ) + check_loader(loader, 8) + + +if __name__ == "__main__": + check_camus() diff --git a/scripts/datasets/medical/check_cbis_ddsm.py b/scripts/datasets/medical/check_cbis_ddsm.py new file mode 100644 index 00000000..ca305183 --- /dev/null +++ b/scripts/datasets/medical/check_cbis_ddsm.py @@ -0,0 +1,24 @@ +from torch_em.data import MinInstanceSampler +from torch_em.util.debug import check_loader +from torch_em.data.datasets.medical import get_cbis_ddsm_loader + + +ROOT = "/media/anwai/ANWAI/data/cbis_ddsm" + + +def check_cbis_ddsm(): + loader = get_cbis_ddsm_loader( + path=ROOT, + patch_shape=(512, 512), + batch_size=2, + split="Train", + task=None, + tumour_type=None, + resize_inputs=True, + sampler=MinInstanceSampler() + ) + check_loader(loader, 8) + + +if __name__ == "__main__": + check_cbis_ddsm() diff --git a/scripts/datasets/medical/check_cholecseg8k.py b/scripts/datasets/medical/check_cholecseg8k.py new file mode 100644 index 00000000..41571080 --- /dev/null +++ b/scripts/datasets/medical/check_cholecseg8k.py @@ -0,0 +1,21 @@ +from torch_em.util.debug import check_loader +from 
torch_em.data.datasets.medical import get_cholecseg8k_loader + + +ROOT = "/media/anwai/ANWAI/data/cholecseg8k" + + +def check_cholecseg8k(): + loader = get_cholecseg8k_loader( + path=ROOT, + patch_shape=(512, 512), + batch_size=2, + split="train", + resize_inputs=True, + download=False, + ) + check_loader(loader, 8) + + +if __name__ == "__main__": + check_cholecseg8k() diff --git a/scripts/datasets/medical/check_covid19_seg.py b/scripts/datasets/medical/check_covid19_seg.py new file mode 100644 index 00000000..92111d71 --- /dev/null +++ b/scripts/datasets/medical/check_covid19_seg.py @@ -0,0 +1,23 @@ +from torch_em.util.debug import check_loader +from torch_em.data import MinInstanceSampler +from torch_em.data.datasets.medical import get_covid19_seg_loader + + +ROOT = "/media/anwai/ANWAI/data/covid19_seg" + + +def check_covid19_seg(): + loader = get_covid19_seg_loader( + path=ROOT, + patch_shape=(32, 512, 512), + batch_size=2, + task="lung", + download=True, + sampler=MinInstanceSampler(), + ) + + check_loader(loader, 8) + + +if __name__ == "__main__": + check_covid19_seg() diff --git a/scripts/datasets/medical/check_curvas.py b/scripts/datasets/medical/check_curvas.py new file mode 100644 index 00000000..9dbd1b38 --- /dev/null +++ b/scripts/datasets/medical/check_curvas.py @@ -0,0 +1,30 @@ +import os +import sys + +from torch_em.data import MinInstanceSampler +from torch_em.util.debug import check_loader +from torch_em.data.datasets.medical import get_curvas_loader + + +sys.path.append("..") + + +def check_curvas(): + from util import ROOT + + loader = get_curvas_loader( + path=os.path.join(ROOT, "curvas"), + patch_shape=(1, 512, 512), + batch_size=2, + ndim=2, + rater="1", + resize_inputs=False, + download=True, + sampler=MinInstanceSampler() + ) + + check_loader(loader, 8, plt=True, save_path="./test.png") + + +if __name__ == "__main__": + check_curvas() diff --git a/scripts/datasets/medical/check_dca1.py b/scripts/datasets/medical/check_dca1.py new file mode 100644 index 00000000..428329cc --- /dev/null +++ b/scripts/datasets/medical/check_dca1.py @@ -0,0 +1,21 @@ +from torch_em.util.debug import check_loader +from torch_em.data.datasets.medical import get_dca1_loader + + +ROOT = "/media/anwai/ANWAI/data/dca1" + + +def check_dca1(): + loader = get_dca1_loader( + path=ROOT, + patch_shape=(512, 512), + batch_size=2, + split="test", + resize_inputs=True, + download=False, + ) + check_loader(loader, 8) + + +if __name__ == "__main__": + check_dca1() diff --git a/scripts/datasets/medical/check_duke_liver.py b/scripts/datasets/medical/check_duke_liver.py new file mode 100644 index 00000000..3668aed8 --- /dev/null +++ b/scripts/datasets/medical/check_duke_liver.py @@ -0,0 +1,23 @@ +from torch_em.util.debug import check_loader +from torch_em.data.datasets.medical import get_duke_liver_loader + + +ROOT = "/media/anwai/ANWAI/data/duke_liver" + + +def check_duke_liver(): + from micro_sam.training import identity + loader = get_duke_liver_loader( + path=ROOT, + patch_shape=(32, 512, 512), + batch_size=2, + split="train", + download=False, + raw_transform=identity, + ) + check_loader(loader, 8) + + +if __name__ == "__main__": + check_duke_liver() diff --git a/scripts/datasets/medical/check_feta24.py b/scripts/datasets/medical/check_feta24.py new file mode 100644 index 00000000..83e70645 --- /dev/null +++ b/scripts/datasets/medical/check_feta24.py @@ -0,0 +1,21 @@ +from torch_em.util.debug import check_loader +from torch_em.data.datasets.medical import get_feta24_loader + + +ROOT =
"/media/anwai/ANWAI/data/feta24" + + +def check_feta24(): + loader = get_feta24_loader( + path=ROOT, + patch_shape=(1, 512, 512), + batch_size=2, + resize_inputs=True, + download=False, + ) + + check_loader(loader, 8) + + +if __name__ == "__main__": + check_feta24() diff --git a/scripts/datasets/medical/check_han_seg.py b/scripts/datasets/medical/check_han_seg.py new file mode 100644 index 00000000..dd12fcad --- /dev/null +++ b/scripts/datasets/medical/check_han_seg.py @@ -0,0 +1,20 @@ +from torch_em.util.debug import check_loader +from torch_em.data.datasets.medical import get_han_seg_loader + + +ROOT = "/media/anwai/ANWAI/data/han-seg/" + + +def check_han_seg(): + loader = get_han_seg_loader( + path=ROOT, + patch_shape=(32, 512, 512), + batch_size=2, + download=False, + ) + + check_loader(loader, 8) + + +if __name__ == "__main__": + check_han_seg() diff --git a/scripts/datasets/medical/check_hil_toothseg.py b/scripts/datasets/medical/check_hil_toothseg.py new file mode 100644 index 00000000..ff016b0b --- /dev/null +++ b/scripts/datasets/medical/check_hil_toothseg.py @@ -0,0 +1,25 @@ +import os +import sys + +from torch_em.util.debug import check_loader +from torch_em.data.datasets.medical import get_hil_toothseg_loader + + +sys.path.append("..") + + +def check_hil_toothseg(): + from util import ROOT + + loader = get_hil_toothseg_loader( + path=os.path.join(ROOT, "hil_toothseg"), + patch_shape=(512, 512), + batch_size=2, + split="train", + resize_inputs=False, + ) + check_loader(loader, 8) + + +if __name__ == "__main__": + check_hil_toothseg() diff --git a/scripts/datasets/medical/check_idrid.py b/scripts/datasets/medical/check_idrid.py new file mode 100644 index 00000000..b2e2d360 --- /dev/null +++ b/scripts/datasets/medical/check_idrid.py @@ -0,0 +1,23 @@ +from torch_em.util.debug import check_loader +from torch_em.data.datasets.medical import get_idrid_loader + + +ROOT = "/media/anwai/ANWAI/data/idrid" + + +def check_idrid(): + loader = get_idrid_loader( + path=ROOT, + patch_shape=(512, 512), + batch_size=2, + split="train", + task="optic_disc", + resize_inputs=True, + download=True, + ) + + check_loader(loader, 8) + + +if __name__ == "__main__": + check_idrid() diff --git a/scripts/datasets/medical/check_isic.py b/scripts/datasets/medical/check_isic.py new file mode 100644 index 00000000..c7a77aee --- /dev/null +++ b/scripts/datasets/medical/check_isic.py @@ -0,0 +1,22 @@ +from torch_em.util.debug import check_loader +from torch_em.data.datasets.medical import get_isic_loader + + +ROOT = "/scratch/share/cidas/cca/data/isic" + + +def check_isic(): + loader = get_isic_loader( + path=ROOT, + patch_shape=(700, 700), + batch_size=2, + split="test", + download=True, + resize_inputs=True, + ) + + check_loader(loader, 8, plt=True, save_path="./isic.png") + + +if __name__ == "__main__": + check_isic() diff --git a/scripts/datasets/medical/check_jnuifm.py b/scripts/datasets/medical/check_jnuifm.py new file mode 100644 index 00000000..33f75447 --- /dev/null +++ b/scripts/datasets/medical/check_jnuifm.py @@ -0,0 +1,21 @@ +from torch_em.util.debug import check_loader +from torch_em.data.datasets.medical import get_jnuifm_loader + + +ROOT = "/media/anwai/ANWAI/data/jnu-ifm" + + +def check_jnuifm(): + loader = get_jnuifm_loader( + path=ROOT, + patch_shape=(512, 512), + batch_size=2, + resize_inputs=True, + download=True, + ) + + check_loader(loader, 8) + + +if __name__ == "__main__": + check_jnuifm() diff --git a/scripts/datasets/medical/check_leg_3d_us.py 
b/scripts/datasets/medical/check_leg_3d_us.py new file mode 100644 index 00000000..aa0ad818 --- /dev/null +++ b/scripts/datasets/medical/check_leg_3d_us.py @@ -0,0 +1,26 @@ +import os +import sys + +from torch_em.util.debug import check_loader +from torch_em.data.datasets import get_leg_3d_us_loader + + +sys.path.append("..") + + +def check_leg_3d_us(): + from util import ROOT + + loader = get_leg_3d_us_loader( + path=os.path.join(ROOT, "leg_3d_us"), + patch_shape=(1, 512, 512), + batch_size=1, + split="train", + ndim=2, + download=True, + ) + check_loader(loader, 8, plt=True, save_path="./test.png") + + +if __name__ == "__main__": + check_leg_3d_us() diff --git a/scripts/datasets/medical/check_lgg_mri.py b/scripts/datasets/medical/check_lgg_mri.py new file mode 100644 index 00000000..4ee64dba --- /dev/null +++ b/scripts/datasets/medical/check_lgg_mri.py @@ -0,0 +1,24 @@ +import os +import sys + +from torch_em.util.debug import check_loader +from torch_em.data.datasets import get_lgg_mri_loader + + +sys.path.append("..") + + +def check_lgg_mri(): + from util import ROOT + + loader = get_lgg_mri_loader( + path=os.path.join(ROOT, "lgg_mri"), + patch_shape=(8, 512, 512), + batch_size=1, + download=True, + ) + check_loader(loader, 8) + + +if __name__ == "__main__": + check_lgg_mri() diff --git a/scripts/datasets/medical/check_m2caiseg.py b/scripts/datasets/medical/check_m2caiseg.py new file mode 100644 index 00000000..9853f773 --- /dev/null +++ b/scripts/datasets/medical/check_m2caiseg.py @@ -0,0 +1,21 @@ +from torch_em.util.debug import check_loader +from torch_em.data.datasets.medical import get_m2caiseg_loader + + +ROOT = "/media/anwai/ANWAI/data/m2caiseg" + + +def check_m2caiseg(): + loader = get_m2caiseg_loader( + path=ROOT, + split="train", + patch_shape=(512, 512), + batch_size=2, + resize_inputs=True, + download=True, + ) + check_loader(loader, 8) + + +if __name__ == "__main__": + check_m2caiseg() diff --git a/scripts/datasets/medical/check_mbh_seg.py b/scripts/datasets/medical/check_mbh_seg.py new file mode 100644 index 00000000..e01aa651 --- /dev/null +++ b/scripts/datasets/medical/check_mbh_seg.py @@ -0,0 +1,29 @@ +import os +import sys + +from torch_em.data import MinInstanceSampler +from torch_em.util.debug import check_loader +from torch_em.data.datasets.medical import get_mbh_seg_loader + + +sys.path.append("..") + + +def check_mbh_seg(): + from util import ROOT + + loader = get_mbh_seg_loader( + path=os.path.join(ROOT, "mbh_seg"), + patch_shape=(1, 512, 512), + ndim=2, + batch_size=2, + resize_inputs=False, + download=True, + sampler=MinInstanceSampler(), + ) + + check_loader(loader, 8) + + +if __name__ == "__main__": + check_mbh_seg() diff --git a/scripts/datasets/medical/check_micro_usp.py b/scripts/datasets/medical/check_micro_usp.py new file mode 100644 index 00000000..9ea3ed56 --- /dev/null +++ b/scripts/datasets/medical/check_micro_usp.py @@ -0,0 +1,24 @@ +from torch_em.data import MinInstanceSampler +from torch_em.util.debug import check_loader +from torch_em.data.datasets.medical import get_micro_usp_loader + + +ROOT = "/media/anwai/ANWAI/data/micro-usp" + + +def check_micro_usp(): + loader = get_micro_usp_loader( + path=ROOT, + patch_shape=(1, 512, 512), + batch_size=2, + split="train", + resize_inputs=True, + download=True, + sampler=MinInstanceSampler(), + ) + + check_loader(loader, 8) + + +if __name__ == "__main__": + check_micro_usp() diff --git a/scripts/datasets/medical/check_montgomery.py b/scripts/datasets/medical/check_montgomery.py new file mode 100644 index 
00000000..19c546eb --- /dev/null +++ b/scripts/datasets/medical/check_montgomery.py @@ -0,0 +1,21 @@ +from torch_em.util.debug import check_loader +from torch_em.data.datasets.medical import get_montgomery_loader + + +ROOT = "/media/anwai/ANWAI/data/montgomery" + + +def check_montgomery(): + loader = get_montgomery_loader( + path=ROOT, + patch_shape=(512, 512), + batch_size=2, + resize_inputs=True, + download=True, + ) + + check_loader(loader, 8) + + +if __name__ == "__main__": + check_montgomery() diff --git a/scripts/datasets/medical/check_msd.py b/scripts/datasets/medical/check_msd.py new file mode 100644 index 00000000..dd483e01 --- /dev/null +++ b/scripts/datasets/medical/check_msd.py @@ -0,0 +1,22 @@ +from torch_em.util.debug import check_loader +from torch_em.data.datasets.medical import get_msd_loader + + +MSD_ROOT = "/media/anwai/ANWAI/data/msd" + + +def check_msd(): + loader = get_msd_loader( + path=MSD_ROOT, + patch_shape=(1, 512, 512), + batch_size=2, + ndim=2, + download=True, + task_names="braintumour", + ) + print(f"Length of the loader: {len(loader)}") + check_loader(loader, 8) + + +if __name__ == "__main__": + check_msd() diff --git a/scripts/datasets/medical/check_oasis.py b/scripts/datasets/medical/check_oasis.py new file mode 100644 index 00000000..d568e550 --- /dev/null +++ b/scripts/datasets/medical/check_oasis.py @@ -0,0 +1,24 @@ +import os +import sys + +from torch_em.util.debug import check_loader +from torch_em.data.datasets import get_oasis_loader + + +sys.path.append("..") + + +def check_oasis(): + from util import ROOT + + loader = get_oasis_loader( + path=os.path.join(ROOT, "oasis"), + patch_shape=(8, 512, 512), + batch_size=1, + download=True, + ) + check_loader(loader, 8) + + +if __name__ == "__main__": + check_oasis() diff --git a/scripts/datasets/medical/check_oimhs.py b/scripts/datasets/medical/check_oimhs.py new file mode 100644 index 00000000..feea6d6f --- /dev/null +++ b/scripts/datasets/medical/check_oimhs.py @@ -0,0 +1,22 @@ +from torch_em.util.debug import check_loader +from torch_em.data.datasets.medical import get_oimhs_loader + + +ROOT = "/scratch/share/cidas/cca/data/oimhs" + + +def check_oimhs(): + loader = get_oimhs_loader( + path=ROOT, + patch_shape=(512, 512), + batch_size=2, + split="test", + download=False, + resize_inputs=True, + ) + + check_loader(loader, 8, plt=True, save_path="./oimhs.png") + + +if __name__ == "__main__": + check_oimhs() diff --git a/scripts/datasets/medical/check_osic_pulmofib.py b/scripts/datasets/medical/check_osic_pulmofib.py new file mode 100644 index 00000000..5ae7ff67 --- /dev/null +++ b/scripts/datasets/medical/check_osic_pulmofib.py @@ -0,0 +1,22 @@ +from torch_em.util.debug import check_loader +from torch_em.data.datasets.medical import get_osic_pulmofib_loader + + +ROOT = "/media/anwai/ANWAI/data/osic_pulmofib" + + +def check_osic_pulmofib(): + loader = get_osic_pulmofib_loader( + path=ROOT, + patch_shape=(4, 256, 256), + ndim=3, + batch_size=2, + resize_inputs=True, + download=True, + ) + + check_loader(loader, 8) + + +if __name__ == "__main__": + check_osic_pulmofib() diff --git a/scripts/datasets/medical/check_panorama.py b/scripts/datasets/medical/check_panorama.py new file mode 100644 index 00000000..899f3d43 --- /dev/null +++ b/scripts/datasets/medical/check_panorama.py @@ -0,0 +1,25 @@ +import os +import sys + +from torch_em.util.debug import check_loader +from torch_em.data.datasets import get_panorama_loader + + +sys.path.append("..") + + +def check_panorama(): + from util import ROOT + + loader = 
get_panorama_loader( + path=os.path.join(ROOT, "panorama"), + patch_shape=(1, 512, 512), + batch_size=1, + ndim=2, + download=True, + ) + check_loader(loader, 8) + + +if __name__ == "__main__": + check_panorama() diff --git a/scripts/datasets/medical/check_papila.py b/scripts/datasets/medical/check_papila.py new file mode 100644 index 00000000..637c6747 --- /dev/null +++ b/scripts/datasets/medical/check_papila.py @@ -0,0 +1,23 @@ +from torch_em.util.debug import check_loader +from torch_em.data.datasets.medical import get_papila_loader + + +ROOT = "/scratch/share/cidas/cca/data/papila" + + +def check_papila(): + loader = get_papila_loader( + path=ROOT, + patch_shape=(256, 256), + batch_size=2, + resize_inputs=True, + task="cup", + expert_choice="exp1", + download=True, + ) + + check_loader(loader, 8, plt=True, save_path="./papila.png") + + +if __name__ == "__main__": + check_papila() diff --git a/scripts/datasets/medical/check_pengwin.py b/scripts/datasets/medical/check_pengwin.py new file mode 100644 index 00000000..6e153cef --- /dev/null +++ b/scripts/datasets/medical/check_pengwin.py @@ -0,0 +1,29 @@ +import os +import sys + +from torch_em.data import MinInstanceSampler +from torch_em.util.debug import check_loader +from torch_em.data.datasets.medical import get_pengwin_loader + + +sys.path.append("..") + + +def check_pengwin(): + from util import ROOT + + loader = get_pengwin_loader( + path=os.path.join(ROOT, "pengwin"), + patch_shape=(1, 512, 512), + batch_size=2, + modality="X-Ray", + resize_inputs=False, + download=True, + sampler=MinInstanceSampler(), + ) + + check_loader(loader, 8) + + +if __name__ == "__main__": + check_pengwin() diff --git a/scripts/datasets/medical/check_piccolo.py b/scripts/datasets/medical/check_piccolo.py new file mode 100644 index 00000000..7d43313f --- /dev/null +++ b/scripts/datasets/medical/check_piccolo.py @@ -0,0 +1,20 @@ +from torch_em.util.debug import check_loader +from torch_em.data.datasets.medical import get_piccolo_loader + + +ROOT = "/media/anwai/ANWAI/data/piccolo" + + +def check_piccolo(): + loader = get_piccolo_loader( + path=ROOT, + patch_shape=(512, 512), + batch_size=2, + split="train", + resize_inputs=True, + ) + check_loader(loader, 8) + + +if __name__ == "__main__": + check_piccolo() diff --git a/scripts/datasets/medical/check_plethora.py b/scripts/datasets/medical/check_plethora.py new file mode 100644 index 00000000..53e4c154 --- /dev/null +++ b/scripts/datasets/medical/check_plethora.py @@ -0,0 +1,24 @@ +from torch_em.data import MinInstanceSampler +from torch_em.util.debug import check_loader +from torch_em.data.datasets.medical import get_plethora_loader + + +ROOT = "/media/anwai/ANWAI/data/plethora" + + +def check_plethora(): + loader = get_plethora_loader( + path=ROOT, + task="thoracic", + patch_shape=(1, 512, 512), + batch_size=2, + resize_inputs=True, + download=True, + sampler=MinInstanceSampler(), + ) + + check_loader(loader, 8) + + +if __name__ == "__main__": + check_plethora() diff --git a/scripts/datasets/medical/check_sa_med2d.py b/scripts/datasets/medical/check_sa_med2d.py new file mode 100644 index 00000000..33f5627a --- /dev/null +++ b/scripts/datasets/medical/check_sa_med2d.py @@ -0,0 +1,25 @@ +from torch_em.util.debug import check_loader +from torch_em.data.datasets.medical import get_sa_med2d_loader + + +ROOT = "/scratch/share/cidas/cca/data/sa-med2d" + + +def check_sa_med2d(): + loader = get_sa_med2d_loader( + path=ROOT, + patch_shape=(512, 512), + split="train", + batch_size=2, + resize_inputs=True, + 
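# Neither sub-datasets nor imaging modalities are excluded here (both arguments are None). +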
exclude_dataset=None, + exclude_modality=None, + download=False, + num_workers=32, + ) + + check_loader(loader, 8, plt=True, save_path="./sa-med2d.png") + + +if __name__ == "__main__": + check_sa_med2d() diff --git a/scripts/datasets/medical/check_sega.py b/scripts/datasets/medical/check_sega.py new file mode 100644 index 00000000..2420f2d2 --- /dev/null +++ b/scripts/datasets/medical/check_sega.py @@ -0,0 +1,23 @@ +from torch_em.data import MinInstanceSampler +from torch_em.util.debug import check_loader +from torch_em.data.datasets.medical import get_sega_loader + + +ROOT = "/media/anwai/ANWAI/data/sega" + + +def check_sega(): + loader = get_sega_loader( + path=ROOT, + patch_shape=(32, 512, 512), + batch_size=2, + data_choice="KiTS", + download=True, + sampler=MinInstanceSampler(), + ) + + check_loader(loader, 8) + + +if __name__ == "__main__": + check_sega() diff --git a/scripts/datasets/medical/check_segthy.py b/scripts/datasets/medical/check_segthy.py new file mode 100644 index 00000000..8433fe12 --- /dev/null +++ b/scripts/datasets/medical/check_segthy.py @@ -0,0 +1,26 @@ +import os +import sys + +from torch_em.util.debug import check_loader +from torch_em.data.datasets import get_segthy_loader + + +sys.path.append("..") + + +def check_segthy(): + from util import ROOT + + loader = get_segthy_loader( + path=os.path.join(ROOT, "segthy"), + patch_shape=(1, 512, 512), + batch_size=1, + source="MRI", + ndim=2, + download=True, + ) + check_loader(loader, 8, plt=True, save_path="./test.png") + + +if __name__ == "__main__": + check_segthy() diff --git a/scripts/datasets/medical/check_siim_acr.py b/scripts/datasets/medical/check_siim_acr.py new file mode 100644 index 00000000..1f3df848 --- /dev/null +++ b/scripts/datasets/medical/check_siim_acr.py @@ -0,0 +1,23 @@ +from torch_em.util.debug import check_loader +from torch_em.data import MinInstanceSampler +from torch_em.data.datasets.medical import get_siim_acr_loader + + +ROOT = "/scratch/share/cidas/cca/data/siim_acr" + + +def check_siim_acr(): + loader = get_siim_acr_loader( + path=ROOT, + split="train", + patch_shape=(512, 512), + batch_size=2, + download=True, + resize_inputs=True, + sampler=MinInstanceSampler() + ) + check_loader(loader, 8, plt=True, save_path="./siim_acr.png") + + +if __name__ == "__main__": + check_siim_acr() diff --git a/scripts/datasets/medical/check_spider.py b/scripts/datasets/medical/check_spider.py new file mode 100644 index 00000000..3a4421fc --- /dev/null +++ b/scripts/datasets/medical/check_spider.py @@ -0,0 +1,21 @@ +from torch_em.util.debug import check_loader +from torch_em.data import MinInstanceSampler +from torch_em.data.datasets.medical import get_spider_loader + + +ROOT = "/media/anwai/ANWAI/data/spider" + + +def check_spider(): + loader = get_spider_loader( + path=ROOT, + patch_shape=(1, 512, 512), + batch_size=2, + sampler=MinInstanceSampler() + ) + + check_loader(loader, 8) + + +if __name__ == "__main__": + check_spider() diff --git a/scripts/datasets/medical/check_tcia.py b/scripts/datasets/medical/check_tcia.py new file mode 100644 index 00000000..466773a1 --- /dev/null +++ b/scripts/datasets/medical/check_tcia.py @@ -0,0 +1,79 @@ +import os +import requests +from glob import glob +from natsort import natsorted + +import numpy as np +import pandas as pd +import nibabel as nib +import pydicom as dicom + +from tcia_utils import nbia + + +ROOT = "/media/anwai/ANWAI/data/tmp/" + +TCIA_URL = "https://wiki.cancerimagingarchive.net/download/attachments/68551327/NSCLC-Radiomics-OriginalCTs.tcia" + + +def check_tcia(download): +
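# Download the manifest for the NSCLC-Radiomics CTs, fetch a few series with 'tcia_utils.nbia', + # and pair each series with its thoracic cavity segmentation for a visual check in napari. +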
trg_path = os.path.join(ROOT, os.path.split(TCIA_URL)[-1]) + if download: + # output = nbia.getSeries(collection="LIDC-IDRI") + # nbia.downloadSeries(output, number=3, path=ROOT) + + manifest = requests.get(TCIA_URL) + with open(trg_path, 'wb') as f: + f.write(manifest.content) + + nbia.downloadSeries( + series_data=trg_path, input_type="manifest", number=3, path=ROOT, csv_filename="save" + ) + + df = pd.read_csv("save.csv") + + all_patient_dirs = glob(os.path.join(ROOT, "*")) + for patient_dir in all_patient_dirs: + patient_id = os.path.split(patient_dir)[-1] + if not patient_id.startswith("1.3"): + continue + + subject_id = pd.Series.to_string(df.loc[df["Series UID"] == patient_id]["Subject ID"])[-9:] + seg_path = glob(os.path.join(ROOT, "Thoracic_Cavities", subject_id, "*_primary_reviewer.nii.gz"))[0] + gt = nib.load(seg_path) + gt = gt.get_fdata() + gt = gt.transpose(2, 1, 0) + gt = np.flip(gt, axis=(0, 1)) + + all_dicom_files = natsorted(glob(os.path.join(patient_dir, "*.dcm"))) + samples = [] + for dcm_fpath in all_dicom_files: + file = dicom.dcmread(dcm_fpath) + img = file.pixel_array + samples.append(img) + + samples = np.stack(samples) + + import napari + + v = napari.Viewer() + v.add_image(samples) + v.add_labels(gt.astype("uint64")) + napari.run() + + +def _test_me(): + data = nbia.getSeries(collection="Soft-tissue-Sarcoma") + print(data) + + nbia.downloadSeries(data, number=3) + + seriesUid = "1.3.6.1.4.1.14519.5.2.1.5168.1900.104193299251798317056218297018" + nbia.viewSeries(seriesUid) + + +if __name__ == "__main__": + # _test_me() + check_tcia(download=True) diff --git a/scripts/datasets/medical/check_toothfairy.py b/scripts/datasets/medical/check_toothfairy.py new file mode 100644 index 00000000..9b08fa93 --- /dev/null +++ b/scripts/datasets/medical/check_toothfairy.py @@ -0,0 +1,29 @@ +import os +import sys + +from torch_em.util.debug import check_loader +from torch_em.data import MinInstanceSampler +from torch_em.data.datasets.medical import get_toothfairy_loader + + +sys.path.append("..") + + +def check_toothfairy(): + from util import ROOT + + loader = get_toothfairy_loader( + path=os.path.join(ROOT, "toothfairy"), + patch_shape=(1, 512, 512), + ndim=2, + batch_size=2, + version="v2", + resize_inputs=False, + sampler=MinInstanceSampler(), + ) + + check_loader(loader, 8, plt=True, save_path="./toothfairy.png") + + +if __name__ == "__main__": + check_toothfairy() diff --git a/scripts/datasets/medical/check_uwaterloo_skin.py b/scripts/datasets/medical/check_uwaterloo_skin.py new file mode 100644 index 00000000..9cae6023 --- /dev/null +++ b/scripts/datasets/medical/check_uwaterloo_skin.py @@ -0,0 +1,21 @@ +from torch_em.util.debug import check_loader +from torch_em.data.datasets.medical import get_uwaterloo_skin_loader + + +ROOT = "/media/anwai/ANWAI/data/uwaterloo_skinseg" + + +def check_uwaterloo_skin(): + loader = get_uwaterloo_skin_loader( + path=ROOT, + patch_shape=(512, 512), + batch_size=2, + resize_inputs=True, + download=True, + ) + + check_loader(loader, 8) + + +if __name__ == "__main__": + check_uwaterloo_skin() diff --git a/scripts/datasets/util.py b/scripts/datasets/util.py new file mode 100644 index 00000000..2b2b2ae4 --- /dev/null +++ b/scripts/datasets/util.py @@ -0,0 +1,8 @@ +import os + +# Change this if you want to store the data somewhere else.
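+# By default, the data is stored in a "data" folder next to this script.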
+ROOT = os.path.join(os.path.split(__file__)[0], "data") + +# TODO try to derive if we can open napari +USE_NAPARI = True diff --git a/scripts/run_multi_gpu_train.py b/scripts/run_multi_gpu_train.py new file mode 100644 index 00000000..7eae7992 --- /dev/null +++ b/scripts/run_multi_gpu_train.py @@ -0,0 +1,37 @@ +from torch_em.model import UNet2d +from torch_em.multi_gpu_training import train_multi_gpu + +from torch_em.data.datasets.light_microscopy.dsb import get_dsb_dataset + + +model_class = UNet2d +model_kwargs = {"in_channels": 1, "out_channels": 1} + +data_root = "./data-dsb" +# Download the data +# get_dsb_data(data_root, "reduced", True) +# quit() + +train_dataset_class = get_dsb_dataset +train_dataset_kwargs = { + "path": data_root, "split": "train", + "patch_shape": (256, 256), "binary": True +} + +val_dataset_class = get_dsb_dataset +val_dataset_kwargs = { + "path": data_root, "split": "test", + "patch_shape": (256, 256), "binary": True +} + +loader_kwargs = {"batch_size": 4, "shuffle": True, "num_workers": 4} + +if __name__ == "__main__": + train_multi_gpu( + model_class, model_kwargs, + train_dataset_class, train_dataset_kwargs, + val_dataset_class, val_dataset_kwargs, + loader_kwargs=loader_kwargs, + iterations=250, name="multi-gpu-test", + compile_model=False, + ) diff --git a/test/self_training/test_fix_match.py b/test/self_training/test_fix_match.py index 6d8a0b12..d375db9e 100644 --- a/test/self_training/test_fix_match.py +++ b/test/self_training/test_fix_match.py @@ -35,7 +35,7 @@ def _test_fix_match( unsupervised_loss_and_metric=None, ): model = UNet2d(in_channels=1, out_channels=1, initial_features=8, depth=3) - optimizer = torch.optim.Adam(model.parameters()) + optimizer = torch.optim.AdamW(model.parameters()) name = "fm-test" trainer = self_training.FixMatchTrainer( diff --git a/test/self_training/test_mean_teacher.py b/test/self_training/test_mean_teacher.py index 110e5d9e..ee907f48 100644 --- a/test/self_training/test_mean_teacher.py +++ b/test/self_training/test_mean_teacher.py @@ -35,7 +35,7 @@ def _test_mean_teacher( unsupervised_loss_and_metric=None, ): model = UNet2d(in_channels=1, out_channels=1, initial_features=8, depth=3) - optimizer = torch.optim.Adam(model.parameters()) + optimizer = torch.optim.AdamW(model.parameters()) name = "mt-test" trainer = self_training.MeanTeacherTrainer( diff --git a/test/trainer/test_default_trainer.py b/test/trainer/test_default_trainer.py index 65251690..52bdbd6f 100644 --- a/test/trainer/test_default_trainer.py +++ b/test/trainer/test_default_trainer.py @@ -47,9 +47,9 @@ def _get_kwargs(self, with_roi=False, compile_model=False): "model": model, "loss": torch_em.loss.DiceLoss(), "metric": torch_em.loss.DiceLoss(), - "optimizer": torch.optim.Adam(model.parameters(), lr=1e-5), + "optimizer": torch.optim.AdamW(model.parameters(), lr=1e-5), "device": torch.device("cpu"), - "mixed_precision": False, + "mixed_precision": True, "compile_model": compile_model, } return kwargs diff --git a/test/trainer/test_spoco_trainer.py b/test/trainer/test_spoco_trainer.py index 195f841f..94c0811d 100644 --- a/test/trainer/test_spoco_trainer.py +++ b/test/trainer/test_spoco_trainer.py @@ -60,7 +60,7 @@ def _get_kwargs(self, with_roi=False): "model": model, "loss": DummySpocoLoss(), "metric": DummySpocoMetric(), - "optimizer": torch.optim.Adam(model.parameters(), lr=1e-5), + "optimizer": torch.optim.AdamW(model.parameters(), lr=1e-5), "device": torch.device("cpu"), "mixed_precision": False, "momentum": 0.95, diff --git 
a/test/transform/test_generic.py b/test/transform/test_generic.py index 09f2d4b5..8777f73b 100644 --- a/test/transform/test_generic.py +++ b/test/transform/test_generic.py @@ -1,9 +1,11 @@ import unittest +import itertools import numpy as np + import torch -from torch_em.transform import Tile +from torch_em.transform import Tile, generic class TestTile(unittest.TestCase): @@ -34,6 +36,19 @@ def _test_tile_impl(ndim, reps): actual = tile_aug(a) assert actual.shape == expected.shape + def test_resize_longest_inputs(self): + input_shapes = [(520, 704), (256, 384), (1040, 1200)] + target_shapes = [(256, 256), (512, 512), (1024, 1024)] + + for (input_shape, target_shape) in itertools.product(input_shapes, target_shapes): + test_image = np.zeros(input_shape, dtype=np.float32) + + raw_transform = generic.ResizeLongestSideInputs(target_shape=target_shape) + resized_image = raw_transform(inputs=test_image) + + assert resized_image.shape == target_shape + assert resized_image.dtype == test_image.dtype + if __name__ == "__main__": unittest.main() diff --git a/test/util/test_imageio.py b/test/util/test_imageio.py index 1032e5d7..8353df97 100644 --- a/test/util/test_imageio.py +++ b/test/util/test_imageio.py @@ -12,7 +12,7 @@ def test_read_memap(self): with tempfile.TemporaryDirectory() as td: tifffile.imwrite(os.path.join(td, "test.tif"), np.zeros((10, 10, 2))) - self.assert_(supports_memmap(os.path.join(td, "test.tif"))) + self.assertTrue(supports_memmap(os.path.join(td, "test.tif"))) data = load_image(os.path.join(td, "test.tif")) self.assertEqual(data.shape, (10, 10, 2)) diff --git a/test/util/test_modelzoo.py b/test/util/test_modelzoo.py index fa0d5f14..e560c094 100644 --- a/test/util/test_modelzoo.py +++ b/test/util/test_modelzoo.py @@ -63,7 +63,7 @@ def _create_checkpoint(self, n_channels): ) model = UNet2d(in_channels=1, out_channels=n_channels, depth=2, initial_features=4, norm="BatchNorm") - optimizer = torch.optim.Adam(model.parameters(), lr=1e-5) + optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5) trainer = DefaultTrainer( name=self.name, train_loader=loader, val_loader=loader, model=model, loss=DiceLoss(), metric=DiceLoss(), diff --git a/torch_em/__init__.py b/torch_em/__init__.py index 0ec89675..0cb4931d 100644 --- a/torch_em/__init__.py +++ b/torch_em/__init__.py @@ -1,3 +1,7 @@ +""" +.. include:: ../doc/start_page.md +.. 
include:: ../doc/datasets_and_dataloaders.md +""" from .segmentation import ( default_segmentation_dataset, default_segmentation_loader, diff --git a/torch_em/__version__.py b/torch_em/__version__.py index 49e0fc1e..4910b9ec 100644 --- a/torch_em/__version__.py +++ b/torch_em/__version__.py @@ -1 +1 @@ -__version__ = "0.7.0" +__version__ = "0.7.3" diff --git a/torch_em/data/datasets/__init__.py b/torch_em/data/datasets/__init__.py index ca841e51..fa04485b 100644 --- a/torch_em/data/datasets/__init__.py +++ b/torch_em/data/datasets/__init__.py @@ -1,39 +1,5 @@ -from .asem import get_asem_loader, get_asem_dataset -from .axondeepseg import get_axondeepseg_loader, get_axondeepseg_dataset -from .bcss import get_bcss_loader, get_bcss_dataset -from .cem import get_mitolab_loader -from .covid_if import get_covid_if_loader, get_covid_if_dataset -from .cremi import get_cremi_loader, get_cremi_dataset -from .ctc import get_ctc_segmentation_loader, get_ctc_segmentation_dataset -from .deepbacs import get_deepbacs_loader, get_deepbacs_dataset -from .dsb import get_dsb_loader, get_dsb_dataset -from .dynamicnuclearnet import get_dynamicnuclearnet_loader, get_dynamicnuclearnet_dataset -from .hpa import get_hpa_segmentation_loader, get_hpa_segmentation_dataset -from .isbi2012 import get_isbi_loader, get_isbi_dataset -from .kasthuri import get_kasthuri_loader, get_kasthuri_dataset -from .livecell import get_livecell_loader, get_livecell_dataset -from .lizard import get_lizard_loader, get_lizard_dataset -from .lucchi import get_lucchi_loader, get_lucchi_dataset -from .mitoem import get_mitoem_loader, get_mitoem_dataset -from .monuseg import get_monuseg_loader, get_monuseg_dataset -from .monusac import get_monusac_loader, get_monusac_dataset -from .mouse_embryo import get_mouse_embryo_loader, get_mouse_embryo_dataset -from .neurips_cell_seg import ( - get_neurips_cellseg_supervised_loader, get_neurips_cellseg_supervised_dataset, - get_neurips_cellseg_unsupervised_loader, get_neurips_cellseg_unsupervised_dataset -) -from .nuc_mm import get_nuc_mm_loader, get_nuc_mm_dataset -from .pannuke import get_pannuke_loader, get_pannuke_dataset -from .plantseg import get_plantseg_loader, get_plantseg_dataset -from .platynereis import ( - get_platynereis_cell_loader, get_platynereis_cell_dataset, - get_platynereis_cilia_loader, get_platynereis_cilia_dataset, - get_platynereis_cuticle_loader, get_platynereis_cuticle_dataset, - get_platynereis_nuclei_loader, get_platynereis_nuclei_dataset -) -from .snemi import get_snemi_loader, get_snemi_dataset -from .sponge_em import get_sponge_em_loader, get_sponge_em_dataset -from .tissuenet import get_tissuenet_loader, get_tissuenet_dataset -from .uro_cell import get_uro_cell_loader, get_uro_cell_dataset +from .electron_microscopy import * +from .histopathology import * +from .light_microscopy import * +from .medical import * from .util import get_bioimageio_dataset_id -from .vnc import get_vnc_mito_loader, get_vnc_mito_dataset diff --git a/torch_em/data/datasets/bcss.py b/torch_em/data/datasets/bcss.py deleted file mode 100644 index adc079a4..00000000 --- a/torch_em/data/datasets/bcss.py +++ /dev/null @@ -1,160 +0,0 @@ -import os -import shutil -from glob import glob -from pathlib import Path - -from sklearn.model_selection import train_test_split - -import torch -import torch_em -from torch_em.data.datasets import util -from torch_em.data import ImageCollectionDataset - - -URL = "https://drive.google.com/drive/folders/1zqbdkQF8i5cEmZOGmbdQm-EP8dRYtvss?usp=sharing" - - -# TODO 
-CHECKSUM = None - - -TEST_LIST = [ - 'TCGA-A2-A0SX-DX1_xmin53791_ymin56683_MPP-0.2500', 'TCGA-BH-A0BG-DX1_xmin64019_ymin24975_MPP-0.2500', - 'TCGA-AR-A1AI-DX1_xmin38671_ymin10616_MPP-0.2500', 'TCGA-E2-A574-DX1_xmin54962_ymin47475_MPP-0.2500', - 'TCGA-GM-A3XL-DX1_xmin29910_ymin15820_MPP-0.2500', 'TCGA-E2-A14X-DX1_xmin88836_ymin66393_MPP-0.2500', - 'TCGA-A2-A04P-DX1_xmin104246_ymin48517_MPP-0.2500', 'TCGA-E2-A14N-DX1_xmin21383_ymin66838_MPP-0.2500', - 'TCGA-EW-A1OV-DX1_xmin126026_ymin65132_MPP-0.2500', 'TCGA-S3-AA15-DX1_xmin55486_ymin28926_MPP-0.2500', - 'TCGA-LL-A5YO-DX1_xmin36631_ymin44396_MPP-0.2500', 'TCGA-GI-A2C9-DX1_xmin20882_ymin11843_MPP-0.2500', - 'TCGA-BH-A0BW-DX1_xmin42346_ymin30843_MPP-0.2500', 'TCGA-E2-A1B6-DX1_xmin16266_ymin50634_MPP-0.2500', - 'TCGA-AO-A0J2-DX1_xmin33561_ymin14515_MPP-0.2500' -] - - -def _download_bcss_dataset(path, download): - """Current recommendation: - - download the folder from URL manually - - use the consortium's git repo to download the dataset (https://github.com/PathologyDataScience/BCSS) - """ - raise NotImplementedError("Please download the dataset using the drive link / git repo directly") - - # FIXME: limitation for the installation below: - # - only downloads first 50 files - due to `gdown`'s download folder function - # - (optional) clone their git repo to download their data - util.download_source_gdrive(path=path, url=URL, download=download, checksum=CHECKSUM, download_type="folder") - - -def _get_image_and_label_paths(path): - # when downloading the files from `URL`, the input images are stored under `rgbs_colorNormalized` - # when getting the files from the git repo's command line feature, the input images are stored under `images` - if os.path.exists(os.path.join(path, "images")): - image_paths = sorted(glob(os.path.join(path, "images", "*"))) - label_paths = sorted(glob(os.path.join(path, "masks", "*"))) - elif os.path.exists(os.path.join(path, "0_Public-data-Amgad2019_0.25MPP", "rgbs_colorNormalized")): - image_paths = sorted(glob(os.path.join(path, "0_Public-data-Amgad2019_0.25MPP", "rgbs_colorNormalized", "*"))) - label_paths = sorted(glob(os.path.join(path, "0_Public-data-Amgad2019_0.25MPP", "masks", "*"))) - else: - raise ValueError("Please check the image directory. 
If downloaded from gdrive, it's named \"rgbs_colorNormalized\", if from github it's named \"images\"") - - return image_paths, label_paths - - -def _assort_bcss_data(path, download): - if download: - _download_bcss_dataset(path, download) - - if os.path.exists(os.path.join(path, "train")) and os.path.exists(os.path.join(path, "test")): - return - - all_image_paths, all_label_paths = _get_image_and_label_paths(path) - - train_img_dir, train_lab_dir = os.path.join(path, "train", "images"), os.path.join(path, "train", "masks") - test_img_dir, test_lab_dir = os.path.join(path, "test", "images"), os.path.join(path, "test", "masks") - os.makedirs(train_img_dir, exist_ok=True) - os.makedirs(train_lab_dir, exist_ok=True) - os.makedirs(test_img_dir, exist_ok=True) - os.makedirs(test_lab_dir, exist_ok=True) - - for image_path, label_path in zip(all_image_paths, all_label_paths): - img_idx, label_idx = os.path.split(image_path)[-1], os.path.split(label_path)[-1] - if Path(image_path).stem in TEST_LIST: - # move image and label to test - dst_img_path, dst_lab_path = os.path.join(test_img_dir, img_idx), os.path.join(test_lab_dir, label_idx) - shutil.copy(src=image_path, dst=dst_img_path) - shutil.copy(src=label_path, dst=dst_lab_path) - else: - # move image and label to train - dst_img_path, dst_lab_path = os.path.join(train_img_dir, img_idx), os.path.join(train_lab_dir, label_idx) - shutil.copy(src=image_path, dst=dst_img_path) - shutil.copy(src=label_path, dst=dst_lab_path) - - -def get_bcss_dataset(path, patch_shape, split=None, val_fraction=0.2, download=False, label_dtype=torch.int64, **kwargs): - """Dataset for breast cancer tissue segmentation in histopathology. - - This dataset is from https://bcsegmentation.grand-challenge.org/BCSS/. - Please cite this paper (https://doi.org/10.1093/bioinformatics/btz083) if you use this dataset for a publication. - - NOTE: There are multiple semantic instances in tissue labels. 
Below mentioned are their respective index details: - - 0: outside_roi (~background) - - 1: tumor - - 2: stroma - - 3: lymphocytic_infiltrate - - 4: necrosis_or_debris - - 5: glandular_secretions - - 6: blood - - 7: exclude - - 8: metaplasia_NOS - - 9: fat - - 10: plasma_cells - - 11: other_immune_infiltrate - - 12: mucoid_material - - 13: normal_acinus_or_duct - - 14: lymphatics - - 15: undetermined - - 16: nerve - - 17: skin_adnexa - - 18: blood_vessel - - 19: angioinvasion - - 20: dcis - - 21: other - """ - _assort_bcss_data(path, download) - - if split is None: - image_paths = sorted(glob(os.path.join(path, "*", "images", "*"))) - label_paths = sorted(glob(os.path.join(path, "*", "masks", "*"))) - else: - assert split in ["train", "val", "test"], "Please choose from the available `train` / `val` / `test` splits" - if split == "test": - image_paths = sorted(glob(os.path.join(path, "test", "images", "*"))) - label_paths = sorted(glob(os.path.join(path, "test", "masks", "*"))) - else: - image_paths = sorted(glob(os.path.join(path, "train", "images", "*"))) - label_paths = sorted(glob(os.path.join(path, "train", "masks", "*"))) - - (train_image_paths, val_image_paths, - train_label_paths, val_label_paths) = train_test_split( - image_paths, label_paths, test_size=val_fraction, random_state=42 - ) - - image_paths = train_image_paths if split == "train" else val_image_paths - label_paths = train_label_paths if split == "train" else val_label_paths - - assert len(image_paths) == len(label_paths) - - dataset = ImageCollectionDataset( - image_paths, label_paths, patch_shape=patch_shape, label_dtype=label_dtype, **kwargs - ) - return dataset - - -def get_bcss_loader( - path, patch_shape, batch_size, split=None, val_fraction=0.2, download=False, label_dtype=torch.int64, **kwargs -): - """Dataloader for breast cancer tissue segmentation in histopathology. See `get_bcss_dataset` for details.""" - ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs) - dataset = get_bcss_dataset( - path, patch_shape, split, val_fraction, download=download, label_dtype=label_dtype, **ds_kwargs - ) - loader = torch_em.get_data_loader(dataset, batch_size, **loader_kwargs) - return loader diff --git a/torch_em/data/datasets/cem.py b/torch_em/data/datasets/cem.py deleted file mode 100644 index 718ccc54..00000000 --- a/torch_em/data/datasets/cem.py +++ /dev/null @@ -1,237 +0,0 @@ -"""Contains datasets and dataloader for the CEM data: -- CEM-MitoLab: annotated 2d data for training mitochondria segmentation models - - https://www.ebi.ac.uk/empiar/EMPIAR-11037/ -- CEM-Mito-Benchmark: 7 Benchmark datasets for mitochondria segmentation - - https://www.ebi.ac.uk/empiar/EMPIAR-10982/ -- CEM-1.5M: unlabeled EM images for pretraining: (Not yet implemented) - - https://www.ebi.ac.uk/empiar/EMPIAR-11035/ - -The data itself can be downloaded from EMPIAR via aspera. -- You can install aspera via mamba. I recommend to do this in a separate environment - to avoid dependency issues: - - `$ mamba create -c conda-forge -c hcc -n aspera aspera-cli` -- After this you can run `$ mamba activate aspera` to have an environment with aspera installed. -- You can then download the data for one of the three datasets like this: - - ascp -QT -l 200m -P33001 -i /etc/asperaweb_id_dsa.openssh emp_ext2@fasp.ebi.ac.uk:/ - - Where is the path to the mamba environment, the id of one of the three datasets - and where you want to download the data. 
-- After this you can use the functions in this file if you use as location for the data. - -Note that I have implemented automatic download, but this leads to issues with -mamba for me so I recommend to download the data manually and then run the loaders -with the correct path. -""" - -import json -import os -from glob import glob - -import imageio.v3 as imageio -import numpy as np -import torch_em -from sklearn.model_selection import train_test_split - -from . import util - -BENCHMARK_DATASETS = { - 1: "mito_benchmarks/c_elegans", - 2: "mito_benchmarks/fly_brain", - 3: "mito_benchmarks/glycolytic_muscle", - 4: "mito_benchmarks/hela_cell", - 5: "mito_benchmarks/lucchi_pp", - 6: "mito_benchmarks/salivary_gland", - 7: "tem_benchmark", -} -BENCHMARK_SHAPES = { - 1: (256, 256, 256), - 2: (256, 255, 255), - 3: (302, 383, 765), - 4: (256, 256, 256), - 5: (165, 768, 1024), - 6: (1260, 1081, 1200), - 7: (224, 224), # NOTE: this is the minimal square shape that fits -} - - -def _get_mitolab_data(path, download): - access_id = "11037" - data_path = util.download_source_empiar(path, access_id, download) - - zip_path = os.path.join(data_path, "data/cem_mitolab.zip") - if os.path.exists(zip_path): - util.unzip(zip_path, data_path, remove=True) - - data_root = os.path.join(data_path, "cem_mitolab") - assert os.path.exists(data_root) - - return data_root - - -def _get_all_images(path): - raw_paths, label_paths = [], [] - folders = glob(os.path.join(path, "*")) - assert all(os.path.isdir(folder) for folder in folders) - for folder in folders: - images = sorted(glob(os.path.join(folder, "images", "*.tiff"))) - assert len(images) > 0 - labels = sorted(glob(os.path.join(folder, "masks", "*.tiff"))) - assert len(images) == len(labels) - raw_paths.extend(images) - label_paths.extend(labels) - return raw_paths, label_paths - - -def _get_non_empty_images(path): - save_path = os.path.join(path, "non_empty_images.json") - - if os.path.exists(save_path): - with open(save_path, "r") as f: - saved_images = json.load(f) - raw_paths, label_paths = saved_images["images"], saved_images["labels"] - raw_paths = [os.path.join(path, rp) for rp in raw_paths] - label_paths = [os.path.join(path, lp) for lp in label_paths] - return raw_paths, label_paths - - folders = glob(os.path.join(path, "*")) - assert all(os.path.isdir(folder) for folder in folders) - - raw_paths, label_paths = [], [] - for folder in folders: - images = sorted(glob(os.path.join(folder, "images", "*.tiff"))) - labels = sorted(glob(os.path.join(folder, "masks", "*.tiff"))) - assert len(images) > 0 - assert len(images) == len(labels) - - for im, lab in zip(images, labels): - n_labels = len(np.unique(imageio.imread(lab))) - if n_labels > 1: - raw_paths.append(im) - label_paths.append(lab) - - raw_paths_rel = [os.path.relpath(rp, path) for rp in raw_paths] - label_paths_rel = [os.path.relpath(lp, path) for lp in label_paths] - - with open(save_path, "w") as f: - json.dump({"images": raw_paths_rel, "labels": label_paths_rel}, f) - - return raw_paths, label_paths - - -def _get_mitolab_paths(path, split, val_fraction, download, discard_empty_images): - data_path = _get_mitolab_data(path, download) - if discard_empty_images: - raw_paths, label_paths = _get_non_empty_images(data_path) - else: - raw_paths, label_paths = _get_all_images(data_path) - - if split is not None: - raw_train, raw_val, labels_train, labels_val = train_test_split( - raw_paths, label_paths, test_size=val_fraction, random_state=42, - ) - if split == "train": - raw_paths, label_paths = raw_train, 
labels_train - else: - raw_paths, label_paths = raw_val, labels_val - - assert len(raw_paths) > 0 - assert len(raw_paths) == len(label_paths) - return raw_paths, label_paths - - -def _get_benchmark_data(path, dataset_id, download): - access_id = "10982" - data_path = util.download_source_empiar(path, access_id, download) - dataset_path = os.path.join(data_path, "data", BENCHMARK_DATASETS[dataset_id]) - - # these are the 3d datasets - if dataset_id in range(1, 7): - dataset_name = os.path.basename(dataset_path) - raw_paths = os.path.join(dataset_path, f"{dataset_name}_em.tif") - label_paths = os.path.join(dataset_path, f"{dataset_name}_mito.tif") - raw_key, label_key = None, None - is_seg_dataset = True - - # this is the 2d dataset - else: - raw_paths = os.path.join(dataset_path, "images") - label_paths = os.path.join(dataset_path, "masks") - raw_key, label_key = "*.tiff", "*.tiff" - is_seg_dataset = False - - return raw_paths, label_paths, raw_key, label_key, is_seg_dataset - - -# -# data sets -# - - -def get_mitolab_dataset( - path, split, patch_shape=(224, 224), val_fraction=0.05, download=False, - discard_empty_images=True, **kwargs -): - assert split in ("train", "val", None) - assert os.path.exists(path) - raw_paths, label_paths = _get_mitolab_paths(path, split, val_fraction, download, discard_empty_images) - return torch_em.default_segmentation_dataset( - raw_paths=raw_paths, raw_key=None, - label_paths=label_paths, label_key=None, - patch_shape=patch_shape, is_seg_dataset=False, ndim=2, **kwargs - ) - - -def get_cem15m_dataset(path): - raise NotImplementedError - - -def get_benchmark_dataset( - path, dataset_id, patch_shape, download=False, **kwargs, -): - """ - ascp -QT -l 200m -P33001 -i /etc/asperaweb_id_dsa.openssh emp_ext2@fasp.ebi.ac.uk:/10982 - """ - if dataset_id not in range(1, 8): - raise ValueError - raw_paths, label_paths, raw_key, label_key, is_seg_dataset = _get_benchmark_data(path, dataset_id, download) - return torch_em.default_segmentation_dataset( - raw_paths=raw_paths, raw_key=raw_key, - label_paths=label_paths, label_key=label_key, - patch_shape=patch_shape, - is_seg_dataset=is_seg_dataset, **kwargs, - ) - - -# -# data loaders -# - - -def get_mitolab_loader( - path, split, batch_size, patch_shape=(224, 224), - discard_empty_images=True, - val_fraction=0.05, download=False, **kwargs -): - ds_kwargs, loader_kwargs = util.split_kwargs( - torch_em.default_segmentation_dataset, **kwargs - ) - dataset = get_mitolab_dataset( - path, split, patch_shape, download=download, discard_empty_images=discard_empty_images, **ds_kwargs - ) - loader = torch_em.get_data_loader(dataset, batch_size, **loader_kwargs) - return loader - - -def get_cem15m_loader(path): - raise NotImplementedError - - -def get_benchmark_loader(path, dataset_id, batch_size, patch_shape, download=False, **kwargs): - ds_kwargs, loader_kwargs = util.split_kwargs( - torch_em.default_segmentation_dataset, **kwargs - ) - dataset = get_benchmark_dataset( - path, dataset_id, - patch_shape=patch_shape, download=download, **ds_kwargs - ) - loader = torch_em.get_data_loader(dataset, batch_size, **loader_kwargs) - return loader diff --git a/torch_em/data/datasets/covid_if.py b/torch_em/data/datasets/covid_if.py deleted file mode 100644 index f9b42e0c..00000000 --- a/torch_em/data/datasets/covid_if.py +++ /dev/null @@ -1,79 +0,0 @@ -import os -from glob import glob - -import torch_em -from . 
import util - -COVID_IF_URL = "https://zenodo.org/record/5092850/files/covid-if-groundtruth.zip?download=1" -CHECKSUM = "d9cd6c85a19b802c771fb4ff928894b19a8fab0e0af269c49235fdac3f7a60e1" - - -def _download_covid_if(path, download): - url = COVID_IF_URL - checksum = CHECKSUM - - if os.path.exists(path): - return - - os.makedirs(path, exist_ok=True) - zip_path = os.path.join(path, "covid-if.zip") - util.download_source(zip_path, url, download, checksum) - util.unzip(zip_path, path, True) - - -def get_covid_if_dataset( - path, patch_shape, sample_range=None, target="cells", download=False, - offsets=None, boundaries=False, binary=False, **kwargs -): - """Dataset for the cells and nuclei in immunofluorescence. - - This dataset is from the publication https://doi.org/10.1002/bies.202000257. - Please cite it if you use this dataset for a publication. - """ - available_targets = ("cells", "nuclei") - # TODO also support infected_cells - # available_targets = ("cells", "nuclei", "infected_cells") - assert target in available_targets, f"{target} not found in {available_targets}" - - if target == "cells": - raw_key = "raw/serum_IgG/s0" - label_key = "labels/cells/s0" - elif target == "nuclei": - raw_key = "raw/nuclei/s0" - label_key = "labels/nuclei/s0" - - _download_covid_if(path, download) - - file_paths = sorted(glob(os.path.join(path, "*.h5"))) - if sample_range is not None: - start, stop = sample_range - if start is None: - start = 0 - if stop is None: - stop = len(file_paths) - file_paths = [os.path.join(path, f"gt_image_{idx:03}.h5") for idx in range(start, stop)] - assert all(os.path.exists(fp) for fp in file_paths), f"Invalid sample range {sample_range}" - - kwargs, _ = util.add_instance_label_transform( - kwargs, add_binary_target=True, binary=binary, boundaries=boundaries, offsets=offsets - ) - kwargs = util.update_kwargs(kwargs, "ndim", 2) - - return torch_em.default_segmentation_dataset( - file_paths, raw_key, file_paths, label_key, patch_shape, **kwargs - ) - - -def get_covid_if_loader( - path, patch_shape, batch_size, sample_range=None, target="cells", download=False, - offsets=None, boundaries=False, binary=False, **kwargs -): - """Dataloader for the segmentation of cells and nuclei in immunofluoroscence. See 'get_covid_if_loader' for details. - """ - ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs) - dataset = get_covid_if_dataset( - path, patch_shape, sample_range=sample_range, target=target, download=download, - offsets=offsets, boundaries=boundaries, binary=binary, **ds_kwargs, - ) - loader = torch_em.get_data_loader(dataset, batch_size=batch_size, **loader_kwargs) - return loader diff --git a/torch_em/data/datasets/cremi.py b/torch_em/data/datasets/cremi.py deleted file mode 100644 index e5e3d24d..00000000 --- a/torch_em/data/datasets/cremi.py +++ /dev/null @@ -1,140 +0,0 @@ -import os -import numpy as np - -import torch_em -from . 
import util - -CREMI_URLS = { - "original": { - "A": "https://cremi.org/static/data/sample_A_20160501.hdf", - "B": "https://cremi.org/static/data/sample_B_20160501.hdf", - "C": "https://cremi.org/static/data/sample_C_20160501.hdf", - }, - "realigned": {}, - "defects": "https://zenodo.org/record/5767036/files/sample_ABC_padded_defects.h5" -} -CHECKSUMS = { - "original": { - "A": "4c563d1b78acb2bcfb3ea958b6fe1533422f7f4a19f3e05b600bfa11430b510d", - "B": "887e85521e00deead18c94a21ad71f278d88a5214c7edeed943130a1f4bb48b8", - "C": "2874496f224d222ebc29d0e4753e8c458093e1d37bc53acd1b69b19ed1ae7052", - }, - "realigned": {}, - "defects": "7b06ffa34733b2c32956ea5005e0cf345e7d3a27477f42f7c905701cdc947bd0" -} - - -# TODO add support for realigned volumes -def get_cremi_dataset( - path, - patch_shape, - samples=("A", "B", "C"), - use_realigned=False, - download=False, - offsets=None, - boundaries=False, - rois={}, - defect_augmentation_kwargs={ - "p_drop_slice": 0.025, - "p_low_contrast": 0.025, - "p_deform_slice": 0.0, - "deformation_mode": "compress", - }, - **kwargs, -): - """Dataset for the segmentation of neurons in EM. - - This dataset is from the CREMI challenge: https://cremi.org/. - """ - assert len(patch_shape) == 3 - if rois is not None: - assert isinstance(rois, dict) - os.makedirs(path, exist_ok=True) - - if use_realigned: - # we need to sample batches in this case - # sampler = torch_em.data.MinForegroundSampler(min_fraction=0.05, p_reject=.75) - raise NotImplementedError - else: - urls = CREMI_URLS["original"] - checksums = CHECKSUMS["original"] - - data_paths = [] - data_rois = [] - for name in samples: - url = urls[name] - checksum = checksums[name] - data_path = os.path.join(path, f"sample{name}.h5") - # CREMI SSL certificates expired, so we need to disable verification - util.download_source(data_path, url, download, checksum, verify=False) - data_paths.append(data_path) - data_rois.append(rois.get(name, np.s_[:, :, :])) - - if defect_augmentation_kwargs is not None and "artifact_source" not in defect_augmentation_kwargs: - # download the defect volume - url = CREMI_URLS["defects"] - checksum = CHECKSUMS["defects"] - defect_path = os.path.join(path, "cremi_defects.h5") - util.download_source(defect_path, url, download, checksum) - defect_patch_shape = (1,) + tuple(patch_shape[1:]) - artifact_source = torch_em.transform.get_artifact_source(defect_path, defect_patch_shape, - min_mask_fraction=0.75, - raw_key="defect_sections/raw", - mask_key="defect_sections/mask") - defect_augmentation_kwargs.update({"artifact_source": artifact_source}) - - raw_key = "volumes/raw" - label_key = "volumes/labels/neuron_ids" - - # defect augmentations - if defect_augmentation_kwargs is not None: - raw_transform = torch_em.transform.get_raw_transform( - augmentation1=torch_em.transform.EMDefectAugmentation(**defect_augmentation_kwargs) - ) - kwargs = util.update_kwargs(kwargs, "raw_transform", raw_transform) - - kwargs, _ = util.add_instance_label_transform( - kwargs, add_binary_target=False, boundaries=boundaries, offsets=offsets - ) - - return torch_em.default_segmentation_dataset( - data_paths, raw_key, data_paths, label_key, patch_shape, rois=data_rois, **kwargs - ) - - -def get_cremi_loader( - path, - patch_shape, - batch_size, - samples=("A", "B", "C"), - use_realigned=False, - download=False, - offsets=None, - boundaries=False, - rois={}, - defect_augmentation_kwargs={ - "p_drop_slice": 0.025, - "p_low_contrast": 0.025, - "p_deform_slice": 0.0, - "deformation_mode": "compress", - }, - **kwargs, -): 
- """Dataset for the segmentation of neurons in EM. See 'get_cremi_dataset' for details. - """ - dataset_kwargs, loader_kwargs = util.split_kwargs( - torch_em.default_segmentation_dataset, **kwargs - ) - ds = get_cremi_dataset( - path=path, - patch_shape=patch_shape, - samples=samples, - use_realigned=use_realigned, - download=download, - offsets=offsets, - boundaries=boundaries, - rois=rois, - defect_augmentation_kwargs=defect_augmentation_kwargs, - **dataset_kwargs, - ) - return torch_em.get_data_loader(ds, batch_size=batch_size, **loader_kwargs) diff --git a/torch_em/data/datasets/dsb.py b/torch_em/data/datasets/dsb.py deleted file mode 100644 index 559d573c..00000000 --- a/torch_em/data/datasets/dsb.py +++ /dev/null @@ -1,77 +0,0 @@ -import os -from shutil import move - -import torch_em -from . import util - -DSB_URLS = { - "full": "", # TODO - "reduced": "https://github.com/stardist/stardist/releases/download/0.1.0/dsb2018.zip" -} -CHECKSUMS = { - "full": None, - "reduced": "e44921950edce378063aa4457e625581ba35b4c2dbd9a07c19d48900129f386f" -} - - -def _download_dsb(path, source, download): - os.makedirs(path, exist_ok=True) - url = DSB_URLS[source] - checksum = CHECKSUMS[source] - - train_out_path = os.path.join(path, "train") - test_out_path = os.path.join(path, "test") - - if os.path.exists(train_out_path) and os.path.exists(test_out_path): - return - - zip_path = os.path.join(path, "dsb.zip") - util.download_source(zip_path, url, download, checksum) - util.unzip(zip_path, path, True) - - move(os.path.join(path, "dsb2018", "train"), train_out_path) - move(os.path.join(path, "dsb2018", "test"), test_out_path) - - -def get_dsb_dataset( - path, split, patch_shape, download=False, - offsets=None, boundaries=False, binary=False, - source="reduced", **kwargs -): - """Dataset for the segmentation of nuclei in light microscopy. - - This dataset is from the publication https://doi.org/10.1038/s41592-019-0612-7. - Please cite it if you use this dataset for a publication. - """ - assert split in ("test", "train"), split - _download_dsb(path, source, download) - - image_path = os.path.join(path, split, "images") - label_path = os.path.join(path, split, "masks") - - kwargs, _ = util.add_instance_label_transform( - kwargs, add_binary_target=True, binary=binary, boundaries=boundaries, offsets=offsets - ) - kwargs = util.update_kwargs(kwargs, "ndim", 2) - return torch_em.default_segmentation_dataset( - image_path, "*.tif", label_path, "*.tif", patch_shape, **kwargs - ) - - -def get_dsb_loader( - path, split, patch_shape, batch_size, download=False, - offsets=None, boundaries=False, binary=False, - source="reduced", **kwargs -): - """Dataloader for the segmentation of nuclei in light microscopy. See 'get_dsb_dataset' for details. - """ - ds_kwargs, loader_kwargs = util.split_kwargs( - torch_em.default_segmentation_dataset, **kwargs - ) - dataset = get_dsb_dataset( - path, split, patch_shape, download=download, - offsets=offsets, boundaries=boundaries, binary=binary, - source=source, **ds_kwargs, - ) - loader = torch_em.get_data_loader(dataset, batch_size, **loader_kwargs) - return loader diff --git a/torch_em/data/datasets/dynamicnuclearnet.py b/torch_em/data/datasets/dynamicnuclearnet.py deleted file mode 100644 index dbe573c5..00000000 --- a/torch_em/data/datasets/dynamicnuclearnet.py +++ /dev/null @@ -1,93 +0,0 @@ -import os -from tqdm import tqdm -from glob import glob - -import z5py -import numpy as np -import pandas as pd - -import torch_em - -from . 
import util - - -# Automatic download is currently not possible, because of authentication -URL = None # TODO: here - https://datasets.deepcell.org/data - - -def _create_split(path, split): - split_file = os.path.join(path, "DynamicNuclearNet-segmentation-v1_0", f"{split}.npz") - split_folder = os.path.join(path, split) - os.makedirs(split_folder, exist_ok=True) - data = np.load(split_file, allow_pickle=True) - - x, y = data["X"], data["y"] - metadata = data["meta"] - metadata = pd.DataFrame(metadata[1:], columns=metadata[0]) - - for i, (im, label) in tqdm(enumerate(zip(x, y)), total=len(x), desc=f"Creating files for {split}-split"): - out_path = os.path.join(split_folder, f"image_{i:04}.zarr") - image_channel = im[..., 0] - label_channel = label[..., 0] - chunks = image_channel.shape - with z5py.File(out_path, "a") as f: - f.create_dataset("raw", data=image_channel, compression="gzip", chunks=chunks) - f.create_dataset("labels", data=label_channel, compression="gzip", chunks=chunks) - - os.remove(split_file) - - -def _create_dataset(path, zip_path): - util.unzip(zip_path, path, remove=False) - splits = ["train", "val", "test"] - assert all( - [os.path.exists(os.path.join(path, "DynamicNuclearNet-segmentation-v1_0", f"{split}.npz")) for split in splits] - ) - for split in splits: - _create_split(path, split) - - -def get_dynamicnuclearnet_dataset( - path, split, patch_shape, download=False, **kwargs -): - """Dataset for the segmentation of cell nuclei imaged with fluorescene microscopy. - - This dataset is from the publication https://doi.org/10.1101/803205. - Please cite it if you use this dataset for a publication.""" - splits = ["train", "val", "test"] - assert split in splits - - # check if the dataset exists already - zip_path = os.path.join(path, "DynamicNuclearNet-segmentation-v1_0.zip") - if all([os.path.exists(os.path.join(path, split)) for split in splits]): # yes it does - pass - elif os.path.exists(zip_path): # no it does not, but we have the zip there and can unpack it - _create_dataset(path, zip_path) - else: - raise RuntimeError( - "We do not support automatic download for the dynamic nuclear net dataset yet. " - f"Please download the dataset from https://datasets.deepcell.org/data and put it here: {zip_path}" - ) - - split_folder = os.path.join(path, split) - assert os.path.exists(split_folder) - data_path = glob(os.path.join(split_folder, "*.zarr")) - assert len(data_path) > 0 - - raw_key, label_key = "raw", "labels" - - return torch_em.default_segmentation_dataset( - data_path, raw_key, data_path, label_key, patch_shape, is_seg_dataset=True, ndim=2, **kwargs - ) - - -def get_dynamicnuclearnet_loader( - path, split, patch_shape, batch_size, download=False, **kwargs -): - """Dataloader for the segmentation of cell nuclei for 5 different cell lines in fluorescence microscopes. - See `get_dynamicnuclearnet_dataset` for details. 
-""" - ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs) - dataset = get_dynamicnuclearnet_dataset(path, split, patch_shape, download, **ds_kwargs) - loader = torch_em.get_data_loader(dataset, batch_size, **loader_kwargs) - return loader diff --git a/torch_em/data/datasets/electron_microscopy/__init__.py b/torch_em/data/datasets/electron_microscopy/__init__.py new file mode 100644 index 00000000..459733f0 --- /dev/null +++ b/torch_em/data/datasets/electron_microscopy/__init__.py @@ -0,0 +1,21 @@ +from .asem import get_asem_loader, get_asem_dataset +from .axondeepseg import get_axondeepseg_loader, get_axondeepseg_dataset +from .cem import get_mitolab_loader +from .cremi import get_cremi_loader, get_cremi_dataset +from .deepict import get_deepict_actin_loader, get_deepict_actin_dataset +from .emneuron import get_emneuron_loader, get_emneuron_dataset +from .isbi2012 import get_isbi_loader, get_isbi_dataset +from .kasthuri import get_kasthuri_loader, get_kasthuri_dataset +from .lucchi import get_lucchi_loader, get_lucchi_dataset +from .mitoem import get_mitoem_loader, get_mitoem_dataset +from .nuc_mm import get_nuc_mm_loader, get_nuc_mm_dataset +from .platynereis import ( + get_platynereis_cell_loader, get_platynereis_cell_dataset, + get_platynereis_cilia_loader, get_platynereis_cilia_dataset, + get_platynereis_cuticle_loader, get_platynereis_cuticle_dataset, + get_platynereis_nuclei_loader, get_platynereis_nuclei_dataset +) +from .snemi import get_snemi_loader, get_snemi_dataset +from .sponge_em import get_sponge_em_loader, get_sponge_em_dataset +from .uro_cell import get_uro_cell_loader, get_uro_cell_dataset +from .vnc import get_vnc_mito_loader, get_vnc_mito_dataset diff --git a/torch_em/data/datasets/asem.py b/torch_em/data/datasets/electron_microscopy/asem.py similarity index 54% rename from torch_em/data/datasets/asem.py rename to torch_em/data/datasets/electron_microscopy/asem.py index 11c6e83b..99d6dfe8 100644 --- a/torch_em/data/datasets/asem.py +++ b/torch_em/data/datasets/electron_microscopy/asem.py @@ -1,12 +1,20 @@ +"""ASEM is a dataset for segmentation of cellular structures in FIB-SEM. + +The dataset was publised in https://doi.org/10.1083/jcb.202208005. +Please cite this publication if you use the dataset in your research. +""" + import os +from typing import Union, Tuple, Optional, List + import numpy as np -import zarr +from torch.utils.data import Dataset, DataLoader import torch_em -from . import util -from .. import ConcatDataset +from .. import util +from ... import ConcatDataset try: import quilt3 as q3 @@ -49,40 +57,63 @@ } -def _download_asem_dataset(path, volume_ids, download): - """https://open.quiltdata.com/b/asem-project""" +def get_asem_data(path: Union[os.PathLike, str], volume_ids: List[str], download: bool = False): + """Download the ASEM dataset. + + The dataset is located at https://open.quiltdata.com/b/asem-project. + + Args: + path: Filepath to a folder where the downloaded data will be saved. + volume_ids: List of volumes to download. + download: Whether to download the data if it is not present. 
+ """ if download and not have_quilt: raise ModuleNotFoundError("Please install quilt3: 'pip install quilt3'.") b = q3.Bucket("s3://asem-project") - volume_paths = [] for volume_id in volume_ids: volume_path = os.path.join(path, VOLUMES[volume_id]) - if not os.path.exists(volume_path): - if not download: - raise FileNotFoundError(f"{VOLUMES[volume_id]} is not found, and 'download' is set to False.") - - print(f"The ASEM dataset for sample '{volume_id}' is not available yet and will be downloaded and created.") - print("Note that this dataset is large, so this step can take several hours (depending on your internet).") - b.fetch( - key=f"datasets/{VOLUMES[volume_id]}/volumes/labels/", - path=os.path.join(volume_path, "volumes", "labels/") - ) - b.fetch( - key=f"datasets/{VOLUMES[volume_id]}/volumes/raw/", - path=os.path.join(volume_path, "volumes", "raw/") - ) - # let's get the group metadata keyfiles - b.fetch(key=f"datasets/{VOLUMES[volume_id]}/.zgroup", path=f"{volume_path}/") - b.fetch(key=f"datasets/{VOLUMES[volume_id]}/volumes/.zgroup", path=f"{volume_path}/volumes/") - - volume_paths.append(volume_path) - + if os.path.exists(volume_path): + continue + + if not download: + raise FileNotFoundError(f"{VOLUMES[volume_id]} is not found, and 'download' is set to False.") + + print(f"The ASEM dataset for sample '{volume_id}' is not available yet and will be downloaded and created.") + print("Note that this dataset is large, so this step can take several hours (depending on your internet).") + b.fetch( + key=f"datasets/{VOLUMES[volume_id]}/volumes/labels/", + path=os.path.join(volume_path, "volumes", "labels/") + ) + b.fetch( + key=f"datasets/{VOLUMES[volume_id]}/volumes/raw/", + path=os.path.join(volume_path, "volumes", "raw/") + ) + # let's get the group metadata keyfiles + b.fetch(key=f"datasets/{VOLUMES[volume_id]}/.zgroup", path=f"{volume_path}/") + b.fetch(key=f"datasets/{VOLUMES[volume_id]}/volumes/.zgroup", path=f"{volume_path}/volumes/") + + +def get_asem_paths(path: Union[os.PathLike, str], volume_ids: List[str], download: bool = False) -> List[str]: + """Get paths to the ASEM data. + + Args: + path: Filepath to a folder where the downloaded data will be saved. + volume_ids: List of volumes to download. + download: Whether to download the data if it is not present. + + Returns: + List of paths for all volume ids. + """ + get_asem_data(path, volume_ids, download) + volume_paths = [os.path.join(path, VOLUMES[vol_id]) for vol_id in volume_ids] return volume_paths def _make_volumes_consistent(volume_path, organelle): + import zarr + have_inconsistent_volumes = False # we shouldn't load the volumes which are already consistent @@ -141,12 +172,25 @@ def _check_input_args(input_arg, default_values): def get_asem_dataset( - path, patch_shape, ndim, download, organelles=None, volume_ids=None, **kwargs -): - """Dataset for the segmentation of organelles in FIB-SEM cells. - - This dataset provides access to 3d images of organelles (mitochondria, golgi, endoplasmic reticulum) - segmented in cells. If you use this data in your research, please cite: https://doi.org/10.1083/jcb.202208005 + path: Union[os.PathLike, str], + patch_shape: Tuple[int, ...], + download: bool = False, + organelles: Optional[Union[List[str], str]] = None, + volume_ids: Optional[Union[List[str], str]] = None, + **kwargs +) -> Dataset: + """Get dataset for segmentation of organelles in FIB-SEM cells. + + Args: + path: Filepath to a folder where the downloaded data will be saved. 
+ patch_shape: The patch shape to use for training. + download: Whether to download the data if it is not present. + organelles: The choice of organelles. + volume_ids: The choice of volumes. + kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`. + + Returns: + The segmentation dataset. """ # let's get the choice of organelles sorted organelles = _check_input_args(organelles, ORGANELLES) @@ -164,16 +208,18 @@ def get_asem_dataset( assert volume_id in ORGANELLES[organelle], \ f"The chosen volume and organelle combination does not match: '{volume_id}' & '{organelle}'" - volume_paths = _download_asem_dataset(path, volume_ids, download) + volume_paths = get_asem_paths(path, volume_ids, download) for volume_path in volume_paths: have_volumes_inconsistent = _make_volumes_consistent(volume_path, organelle) - raw_key = f"volumes/raw_{organelle}" if have_volumes_inconsistent else "volumes/raw" dataset = torch_em.default_segmentation_dataset( - volume_path, raw_key, - volume_path, f"volumes/labels/{organelle}", - patch_shape, ndim=ndim, is_seg_dataset=True, + raw_paths=volume_path, + raw_key=f"volumes/raw_{organelle}" if have_volumes_inconsistent else "volumes/raw", + label_paths=volume_path, + label_key=f"volumes/labels/{organelle}", + patch_shape=patch_shape, + is_seg_dataset=True, **kwargs ) dataset.max_sampling_attempts = 5000 @@ -183,10 +229,29 @@ def get_asem_dataset( def get_asem_loader( - path, patch_shape, batch_size, ndim, download=False, organelles=None, volume_ids=None, **kwargs -): - """Dataloader for organelle segmentation in FIB-SEM cells. See `get_asem_dataset` for details.""" + path: Union[os.PathLike, str], + patch_shape: Tuple[int, ...], + batch_size: int, + download: bool = False, + organelles: Optional[Union[List[str], str]] = None, + volume_ids: Optional[Union[List[str], str]] = None, + **kwargs +) -> DataLoader: + """Get dataloader for the segmentation of organelles in FIB-SEM cells. + + Args: + path: Filepath to a folder where the downloaded data will be saved. + patch_shape: The patch shape to use for training. + batch_size: The batch size for training. + download: Whether to download the data if it is not present. + organelles: The choice of organelles. + volume_ids: The choice of volumes. + kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader. + + Returns: + The DataLoader. + """ ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs) - ds = get_asem_dataset(path, patch_shape, ndim, download, organelles, volume_ids, **ds_kwargs) + ds = get_asem_dataset(path, patch_shape, download, organelles, volume_ids, **ds_kwargs) loader = torch_em.get_data_loader(ds, batch_size=batch_size, **loader_kwargs) return loader diff --git a/torch_em/data/datasets/axondeepseg.py b/torch_em/data/datasets/electron_microscopy/axondeepseg.py similarity index 58% rename from torch_em/data/datasets/axondeepseg.py rename to torch_em/data/datasets/electron_microscopy/axondeepseg.py index 0886f378..4fb386c0 100644 --- a/torch_em/data/datasets/axondeepseg.py +++ b/torch_em/data/datasets/electron_microscopy/axondeepseg.py @@ -1,13 +1,23 @@ +"""AxonDeepSeg is a dataset for the segmentation of myelinated axons in EM. +It contains two different data types: TEM and SEM. + +The dataset was published in https://doi.org/10.1038/s41598-018-22181-4. +Please cite this publication if you use the dataset in your research. 
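A minimal usage sketch for the ASEM loader defined above; `./data/asem` is a placeholder path, and `"mito"` is assumed to be one of the valid keys of the `ORGANELLES` dict:

```python
from torch_em.data.datasets.electron_microscopy import get_asem_loader

# Placeholder path; quilt3 must be installed for download=True.
# "mito" is assumed to be a valid organelle key.
loader = get_asem_loader(
    "./data/asem", patch_shape=(32, 256, 256), batch_size=1,
    download=True, organelles="mito",
)
x, y = next(iter(loader))
print(x.shape, y.shape)
```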
+""" + import os from glob import glob from shutil import rmtree +from typing import Optional, Tuple, Union, Literal, List import imageio -import h5py import numpy as np + +from torch.utils.data import Dataset, DataLoader + import torch_em -from . import util +from .. import util URLS = { "sem": "https://github.com/axondeepseg/data_axondeepseg_sem/archive/refs/heads/master.zip", @@ -20,6 +30,8 @@ def _preprocess_sem_data(out_path): + import h5py + # preprocess the data to get it to a better data format data_root = os.path.join(out_path, "data_axondeepseg_sem-master") assert os.path.exists(data_root) @@ -78,6 +90,8 @@ def _preprocess_sem_data(out_path): def _preprocess_tem_data(out_path): + import h5py + data_root = os.path.join(out_path, "TEM_dataset") folder_names = os.listdir(data_root) folders = [os.path.join(data_root, fname) for fname in folder_names @@ -103,7 +117,17 @@ def _preprocess_tem_data(out_path): rmtree(data_root) -def _require_axondeepseg_data(path, name, download): +def get_axondeepseg_data(path: Union[str, os.PathLike], name: Literal["sem", "tem"], download: bool = False) -> str: + """Download the AxonDeepSeg data. + + Args: + path: Filepath to a folder where the downloaded data will be saved. + name: The name of the dataset to download. Can be either 'sem' or 'tem'. + download: Whether to download the data if it is not present. + + Returns: + The filepath for the downloaded data. + """ # download and unzip the data url, checksum = URLS[name], CHECKSUMS[name] @@ -120,25 +144,34 @@ def _require_axondeepseg_data(path, name, download): _preprocess_sem_data(out_path) elif name == "tem": _preprocess_tem_data(out_path) + else: + raise ValueError(f"Invalid dataset name for axondeepseg, expected 'sem' or 'tem', got {name}.") return out_path -def get_axondeepseg_dataset( - path, name, patch_shape, download=False, one_hot_encoding=False, val_fraction=None, split=None, **kwargs -): - """Dataset for the segmentation of myelinated axons in EM. +def get_axondeepseg_paths( + path: Union[str, os.PathLike], + name: Literal["sem", "tem"], + download: bool = False, + val_fraction: Optional[float] = None, + split: Optional[str] = None, +) -> List[str]: + """Get paths to the AxonDeepSeg data. - This dataset is from the publication https://doi.org/10.1038/s41598-018-22181-4. - Please cite it if you use this dataset for a publication. - """ - if isinstance(name, str): - name = [name] - assert isinstance(name, (tuple, list)) + Args: + path: Filepath to a folder where the downloaded data will be saved. + name: The name of the dataset to download. Can be either 'sem' or 'tem'. + download: Whether to download the data if it is not present. + val_fraction: The fraction of the data to use for validation. + split: The data split. Either 'train' or 'val'. + Returns: + List of paths for all the data. 
+ """ all_paths = [] for nn in name: - data_root = _require_axondeepseg_data(path, nn, download) + data_root = get_axondeepseg_data(path, nn, download) paths = glob(os.path.join(data_root, "*.h5")) paths.sort() if val_fraction is not None: @@ -147,6 +180,40 @@ def get_axondeepseg_dataset( paths = paths[:n_samples] if split == "train" else paths[n_samples:] all_paths.extend(paths) + return all_paths + + +def get_axondeepseg_dataset( + path: Union[str, os.PathLike], + name: Literal["sem", "tem"], + patch_shape: Tuple[int, int], + download: bool = False, + one_hot_encoding: bool = False, + val_fraction: Optional[float] = None, + split: Optional[Literal['train', 'val']] = None, + **kwargs, +) -> Dataset: + """Get dataset for segmentation of myelinated axons. + + Args: + path: Filepath to a folder where the downloaded data will be saved. + name: The name of the dataset to download. Can be either 'sem' or 'tem'. + patch_shape: The patch shape to use for training. + download: Whether to download the data if it is not present. + one_hot_encoding: Whether to return the labels as one hot encoding. + val_fraction: The fraction of the data to use for validation. + split: The data split. Either 'train' or 'val'. + kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`. + + Returns: + The segmentation dataset. + """ + if isinstance(name, str): + name = [name] + assert isinstance(name, (tuple, list)) + + all_paths = get_axondeepseg_paths(path, name, download, val_fraction, split) + if one_hot_encoding: if isinstance(one_hot_encoding, bool): # add transformation to go from [0, 1, 2] to one hot encoding @@ -163,17 +230,42 @@ def get_axondeepseg_dataset( msg = "'one_hot' is set to True, but 'label_transform' is in the kwargs. It will be over-ridden." kwargs = util.update_kwargs(kwargs, "label_transform", label_transform, msg=msg) - raw_key, label_key = "raw", "labels" - return torch_em.default_segmentation_dataset(all_paths, raw_key, all_paths, label_key, patch_shape, **kwargs) + return torch_em.default_segmentation_dataset( + raw_paths=all_paths, + raw_key="raw", + label_paths=all_paths, + label_key="labels", + patch_shape=patch_shape, + **kwargs + ) -# add instance segmentation representations? def get_axondeepseg_loader( - path, name, patch_shape, batch_size, - download=False, one_hot_encoding=False, - val_fraction=None, split=None, **kwargs -): - """Dataloader for the segmentation of myelinated axons. See 'get_axondeepseg_dataset' for details. + path: Union[str, os.PathLike], + name: Literal["sem", "tem"], + patch_shape: Tuple[int, int], + batch_size: int, + download: bool = False, + one_hot_encoding: bool = False, + val_fraction: Optional[float] = None, + split: Optional[Literal["train", "val"]] = None, + **kwargs +) -> DataLoader: + """Get dataloader for the segmentation of myelinated axons. + + Args: + path: Filepath to a folder where the downloaded data will be saved. + name: The name of the dataset to download. Can be either 'sem' or 'tem'. + patch_shape: The patch shape to use for training. + batch_size: The batch size for training. + download: Whether to download the data if it is not present. + one_hot_encoding: Whether to return the labels as one hot encoding. + val_fraction: The fraction of the data to use for validation. + split: The data split. Either 'train' or 'val'. + kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader. + + Returns: + The PyTorch DataLoader. 
""" ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs) dataset = get_axondeepseg_dataset( diff --git a/torch_em/data/datasets/electron_microscopy/cem.py b/torch_em/data/datasets/electron_microscopy/cem.py new file mode 100644 index 00000000..d8342539 --- /dev/null +++ b/torch_em/data/datasets/electron_microscopy/cem.py @@ -0,0 +1,382 @@ +"""The CEM, or MitoLab, dataset is a collection of data for +training mitochondria generalist models. It consists of: +- CEM-MitoLab: annotated 2d data for training mitochondria segmentation models + - https://www.ebi.ac.uk/empiar/EMPIAR-11037/ +- CEM-Mito-Benchmark: 7 Benchmark datasets for mitochondria segmentation + - https://www.ebi.ac.uk/empiar/EMPIAR-10982/ +- CEM-1.5M: unlabeled EM images for pretraining: (Not yet implemented) + - https://www.ebi.ac.uk/empiar/EMPIAR-11035/ + +These datasets are from the publication https://doi.org/10.1016/j.cels.2022.12.006. +Please cite this publication if you use this data in your research. + +The data itself can be downloaded from EMPIAR via aspera. +- You can install aspera via mamba. We recommend to do this in a separate environment + to avoid dependency issues: + - `$ mamba create -c conda-forge -c hcc -n aspera aspera-cli` +- After this you can run `$ mamba activate aspera` to have an environment with aspera installed. +- You can then download the data for one of the three datasets like this: + - ascp -QT -l 200m -P33001 -i /etc/asperaweb_id_dsa.openssh emp_ext2@fasp.ebi.ac.uk:/ + - Where is the path to the mamba environment, the id of one of the three datasets + and where you want to download the data. +- After this you can use the functions in this file if you use as location for the data. + +Note that we have implemented automatic download, but this leads to dependency +issues, so we recommend to download the data manually and then run the loaders with the correct path. +""" + +import os +import json +from glob import glob +from typing import List, Tuple, Union, Literal + +import numpy as np +import imageio.v3 as imageio +from sklearn.model_selection import train_test_split + +from torch.utils.data import Dataset, DataLoader + +import torch_em + +from .. 
import util + + +BENCHMARK_DATASETS = { + 1: "mito_benchmarks/c_elegans", + 2: "mito_benchmarks/fly_brain", + 3: "mito_benchmarks/glycolytic_muscle", + 4: "mito_benchmarks/hela_cell", + 5: "mito_benchmarks/lucchi_pp", + 6: "mito_benchmarks/salivary_gland", + 7: "tem_benchmark", +} +BENCHMARK_SHAPES = { + 1: (256, 256, 256), + 2: (256, 255, 255), + 3: (302, 383, 765), + 4: (256, 256, 256), + 5: (165, 768, 1024), + 6: (1260, 1081, 1200), + 7: (224, 224), # NOTE: this is the minimal square shape that fits +} + + +def _get_all_images(path): + raw_paths, label_paths = [], [] + folders = glob(os.path.join(path, "*")) + assert all(os.path.isdir(folder) for folder in folders) + for folder in folders: + images = sorted(glob(os.path.join(folder, "images", "*.tiff"))) + assert len(images) > 0 + labels = sorted(glob(os.path.join(folder, "masks", "*.tiff"))) + assert len(images) == len(labels) + raw_paths.extend(images) + label_paths.extend(labels) + return raw_paths, label_paths + + +def _get_non_empty_images(path): + save_path = os.path.join(path, "non_empty_images.json") + + if os.path.exists(save_path): + with open(save_path, "r") as f: + saved_images = json.load(f) + raw_paths, label_paths = saved_images["images"], saved_images["labels"] + raw_paths = [os.path.join(path, rp) for rp in raw_paths] + label_paths = [os.path.join(path, lp) for lp in label_paths] + return raw_paths, label_paths + + folders = glob(os.path.join(path, "*")) + assert all(os.path.isdir(folder) for folder in folders) + + raw_paths, label_paths = [], [] + for folder in folders: + images = sorted(glob(os.path.join(folder, "images", "*.tiff"))) + labels = sorted(glob(os.path.join(folder, "masks", "*.tiff"))) + assert len(images) > 0 + assert len(images) == len(labels) + + for im, lab in zip(images, labels): + n_labels = len(np.unique(imageio.imread(lab))) + if n_labels > 1: + raw_paths.append(im) + label_paths.append(lab) + + raw_paths_rel = [os.path.relpath(rp, path) for rp in raw_paths] + label_paths_rel = [os.path.relpath(lp, path) for lp in label_paths] + + with open(save_path, "w") as f: + json.dump({"images": raw_paths_rel, "labels": label_paths_rel}, f) + + return raw_paths, label_paths + + +def get_mitolab_data(path: Union[os.PathLike, str], download: bool = False) -> str: + """Download the MitoLab training data. + + Args: + path: Filepath to a folder where the downloaded data will be saved. + download: Whether to download the data if it is not present. + + Returns: + The filepath for the downloaded data. + """ + access_id = "11037" + data_path = util.download_source_empiar(path, access_id, download) + + zip_path = os.path.join(data_path, "data/cem_mitolab.zip") + if os.path.exists(zip_path): + util.unzip(zip_path, data_path, remove=True) + + data_root = os.path.join(data_path, "cem_mitolab") + assert os.path.exists(data_root) + + return data_root + + +def get_mitolab_paths( + path: Union[os.PathLike, str], + split: Literal['train', 'val'], + val_fraction: float = 0.05, + download: bool = False, + discard_empty_images: bool = True, +) -> Tuple[List[str], List[str]]: + """Get the paths to MitoLab training data. + + Args: + path: Filepath to a folder where the downloaded data will be saved. + split: The data split. Either 'train' or 'val'. + val_fraction: The fraction of the data to use for validation. + download: Whether to download the data if it is not present. + discard_empty_images: Whether to discard images without annotations. + + Returns: + List of the image data paths. + List of the label data paths. 
+ """ + data_path = get_mitolab_data(path, download) + + if discard_empty_images: + raw_paths, label_paths = _get_non_empty_images(data_path) + else: + raw_paths, label_paths = _get_all_images(data_path) + + if split is not None: + raw_train, raw_val, labels_train, labels_val = train_test_split( + raw_paths, label_paths, test_size=val_fraction, random_state=42, + ) + if split == "train": + raw_paths, label_paths = raw_train, labels_train + else: + raw_paths, label_paths = raw_val, labels_val + + assert len(raw_paths) > 0 + assert len(raw_paths) == len(label_paths) + return raw_paths, label_paths + + +def get_benchmark_data(path: Union[os.PathLike, str], dataset_id: int, download: bool = False) -> str: + """Download the MitoLab benchmark data. + + Args: + path: Filepath to a folder where the downloaded data will be saved. + dataset_id: The id of the benchmark dataset to download. + download: Whether to download the data if it is not present. + + Returns: + The filepath for the stored data. + """ + access_id = "10982" + data_path = util.download_source_empiar(path, access_id, download) + dataset_path = os.path.join(data_path, "data", BENCHMARK_DATASETS[dataset_id]) + return dataset_path + + +def get_benchmark_paths( + path: Union[os.PathLike, str], dataset_id: int, download: bool = False +) -> Tuple[List[str], List[str], str, str, bool]: + """Get paths to the MitoLab benchmark data. + + Args: + path: Filepath to a folder where the downloaded data will be saved. + dataset_id: The id of the benchmark dataset to download. + download: Whether to download the data if it is not present. + + Returns: + List of the image data paths. + List of the label data paths. + The image data key. + The label data key. + Whether this is a segmentation dataset. + """ + dataset_path = get_benchmark_data(path, dataset_id, download) + + # these are the 3d datasets + if dataset_id in range(1, 7): + dataset_name = os.path.basename(dataset_path) + raw_paths = os.path.join(dataset_path, f"{dataset_name}_em.tif") + label_paths = os.path.join(dataset_path, f"{dataset_name}_mito.tif") + raw_key, label_key = None, None + is_seg_dataset = True + + # this is the 2d dataset + else: + raw_paths = os.path.join(dataset_path, "images") + label_paths = os.path.join(dataset_path, "masks") + raw_key, label_key = "*.tiff", "*.tiff" + is_seg_dataset = False + + return raw_paths, label_paths, raw_key, label_key, is_seg_dataset + + +# +# Datasets +# + + +def get_mitolab_dataset( + path: Union[os.PathLike, str], + split: Literal['train', 'val'], + patch_shape: Tuple[int, int] = (224, 224), + val_fraction: float = 0.05, + download: bool = False, + discard_empty_images: bool = True, + **kwargs +) -> Dataset: + """Get the dataset for the MitoLab training data. + + Args: + path: Filepath to a folder where the downloaded data will be saved. + split: The data split. Either 'train' or 'val'. + patch_shape: The patch shape to use for training. + val_fraction: The fraction of the data to use for validation. + download: Whether to download the data if it is not present. + discard_empty_images: Whether to discard images without annotations. + kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`. + + Returns: + The segmentation dataset. 
+ """ + assert split in ("train", "val", None) + assert os.path.exists(path) + + raw_paths, label_paths = get_mitolab_paths(path, split, val_fraction, download, discard_empty_images) + + return torch_em.default_segmentation_dataset( + raw_paths=raw_paths, + raw_key=None, + label_paths=label_paths, + label_key=None, + patch_shape=patch_shape, + is_seg_dataset=False, + ndim=2, + **kwargs + ) + + +def get_cem15m_dataset(path): + raise NotImplementedError + + +def get_benchmark_dataset( + path: Union[os.PathLike, str], dataset_id: int, patch_shape: Tuple[int, int], download: bool = False, **kwargs +) -> Dataset: + """Get the dataset for one of the mitolab benchmark datasets. + + Args: + path: Filepath to a folder where the downloaded data will be saved. + dataset_id: The id of the benchmark dataset to download. + patch_shape: The patch shape to use for training. + download: Whether to download the data if it is not present. + kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`. + + Returns: + The segmentation dataset. + """ + if dataset_id not in range(1, 8): + raise ValueError(f"Invalid dataset id {dataset_id}, expected id in range [1, 7].") + + raw_paths, label_paths, raw_key, label_key, is_seg_dataset = get_benchmark_paths(path, dataset_id, download) + + return torch_em.default_segmentation_dataset( + raw_paths=raw_paths, + raw_key=raw_key, + label_paths=label_paths, + label_key=label_key, + patch_shape=patch_shape, + is_seg_dataset=is_seg_dataset, + **kwargs, + ) + + +# +# DataLoaders +# + + +def get_mitolab_loader( + path: Union[os.PathLike, str], + split: str, + batch_size: int, + patch_shape: Tuple[int, int] = (224, 224), + discard_empty_images: bool = True, + val_fraction: float = 0.05, + download: bool = False, + **kwargs +) -> DataLoader: + """Get the dataloader for the MitoLab training data. + + Args: + path: Filepath to a folder where the downloaded data will be saved. + split: The data split. Either 'train' or 'val'. + batch_size: The batch size for training. + patch_shape: The patch shape to use for training. + discard_empty_images: Whether to discard images without annotations. + val_fraction: The fraction of the data to use for validation. + download: Whether to download the data if it is not present. + kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader. + + Returns: + The PyTorch DataLoader. + """ + ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs) + dataset = get_mitolab_dataset( + path=path, + split=split, + patch_shape=patch_shape, + val_fraction=val_fraction, + download=download, + discard_empty_images=discard_empty_images, + **ds_kwargs + ) + return torch_em.get_data_loader(dataset, batch_size, **loader_kwargs) + + +def get_cem15m_loader(path): + raise NotImplementedError + + +def get_benchmark_loader( + path: Union[os.PathLike, str], + dataset_id: int, + batch_size: int, + patch_shape: Tuple[int, int], + download: bool = False, + **kwargs +) -> DataLoader: + """Get the dataloader for one of the MitoLab benchmark datasets. + + Args: + path: Filepath to a folder where the downloaded data will be saved. + dataset_id: The id of the benchmark dataset to download. + batch_size: The batch size for training. + patch_shape: The patch shape to use for training. + download: Whether to download the data if it is not present. + kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader. 
+ + Returns: + The DataLoader. + """ + ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs) + dataset = get_benchmark_dataset(path, dataset_id, patch_shape=patch_shape, download=download, **ds_kwargs) + return torch_em.get_data_loader(dataset, batch_size, **loader_kwargs) diff --git a/torch_em/data/datasets/electron_microscopy/cremi.py b/torch_em/data/datasets/electron_microscopy/cremi.py new file mode 100644 index 00000000..c72f94f9 --- /dev/null +++ b/torch_em/data/datasets/electron_microscopy/cremi.py @@ -0,0 +1,216 @@ +"""CREMI is a dataset for neuron segmentation in EM. + +It contains three annotated volumes from the adult fruit-fly brain. +It was held as a challenge at MICCAI 2016. For details on the dataset check out https://cremi.org/. +Please cite the challenge if you use the dataset in your research. +""" +# TODO add support for realigned volumes + +import os +from typing import Any, Dict, List, Optional, Tuple, Union + +import numpy as np + +from torch.utils.data import Dataset, DataLoader + +import torch_em + +from .. import util + + +CREMI_URLS = { + "original": { + "A": "https://cremi.org/static/data/sample_A_20160501.hdf", + "B": "https://cremi.org/static/data/sample_B_20160501.hdf", + "C": "https://cremi.org/static/data/sample_C_20160501.hdf", + }, + "realigned": {}, + "defects": "https://zenodo.org/record/5767036/files/sample_ABC_padded_defects.h5" +} +CHECKSUMS = { + "original": { + "A": "4c563d1b78acb2bcfb3ea958b6fe1533422f7f4a19f3e05b600bfa11430b510d", + "B": "887e85521e00deead18c94a21ad71f278d88a5214c7edeed943130a1f4bb48b8", + "C": "2874496f224d222ebc29d0e4753e8c458093e1d37bc53acd1b69b19ed1ae7052", + }, + "realigned": {}, + "defects": "7b06ffa34733b2c32956ea5005e0cf345e7d3a27477f42f7c905701cdc947bd0" +} + + +def get_cremi_data(path: Union[os.PathLike, str], samples: Tuple[str], download: bool, use_realigned: bool = False): + """Download the CREMI training data. + + Args: + path: Filepath to a folder where the downloaded data will be saved. + samples: The CREMI samples to use. The available samples are 'A', 'B', 'C'. + download: Whether to download the data if it is not present. + use_realigned: Use the realigned instead of the original training data. + """ + if use_realigned: + # we need to sample batches in this case + # sampler = torch_em.data.MinForegroundSampler(min_fraction=0.05, p_reject=.75) + raise NotImplementedError + else: + urls = CREMI_URLS["original"] + checksums = CHECKSUMS["original"] + + os.makedirs(path, exist_ok=True) + for name in samples: + url = urls[name] + checksum = checksums[name] + data_path = os.path.join(path, f"sample{name}.h5") + # CREMI SSL certificates expired, so we need to disable verification + util.download_source(data_path, url, download, checksum, verify=False) + + +def get_cremi_paths( + path: Union[os.PathLike, str], + samples: Tuple[str, ...] = ("A", "B", "C"), + use_realigned: bool = False, + download: bool = False +) -> List[str]: + """Get paths to the CREMI data. + + Args: + path: Filepath to a folder where the downloaded data will be saved. + samples: The CREMI samples to use. The available samples are 'A', 'B', 'C'. + use_realigned: Use the realigned instead of the original training data. + download: Whether to download the data if it is not present. + + Returns: + The filepaths to the training data. 
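Returning to the CEM benchmark functions defined above, a usage sketch with a placeholder path; `dataset_id=7` is the 2d TEM benchmark:

```python
from torch_em.data.datasets.electron_microscopy.cem import get_benchmark_loader

# The EMPIAR-10982 data is expected under the placeholder path
# (or must be downloadable there).
loader = get_benchmark_loader(
    "./data/mito_benchmarks", dataset_id=7, batch_size=4,
    patch_shape=(224, 224), download=True,
)
```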
+ """ + get_cremi_data(path, samples, download, use_realigned) + data_paths = [os.path.join(path, f"sample{name}.h5") for name in samples] + return data_paths + + +def get_cremi_dataset( + path: Union[os.PathLike, str], + patch_shape: Tuple[int, int, int], + samples: Tuple[str, ...] = ("A", "B", "C"), + use_realigned: bool = False, + download: bool = False, + offsets: Optional[List[List[int]]] = None, + boundaries: bool = False, + rois: Dict[str, Any] = {}, + defect_augmentation_kwargs: Dict[str, Any] = { + "p_drop_slice": 0.025, + "p_low_contrast": 0.025, + "p_deform_slice": 0.0, + "deformation_mode": "compress", + }, + **kwargs, +) -> Dataset: + """Get the CREMI dataset for the segmentation of neurons in EM. + + Args: + path: Filepath to a folder where the downloaded data will be saved. + patch_shape: The patch shape to use for training. + samples: The CREMI samples to use. The available samples are 'A', 'B', 'C'. + use_realigned: Use the realigned instead of the original training data. + download: Whether to download the data if it is not present. + offsets: Offset values for affinity computation used as target. + boundaries: Whether to compute boundaries as the target. + rois: The region of interests to use for the samples. + defect_augmentation_kwargs: Keyword arguments for defect augmentations. + kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`. + + Returns: + The segmentation dataset. + """ + assert len(patch_shape) == 3 + if rois is not None: + assert isinstance(rois, dict) + + data_paths = get_cremi_paths(path, samples, use_realigned, download) + data_rois = [rois.get(name, np.s_[:, :, :]) for name in samples] + + if defect_augmentation_kwargs is not None and "artifact_source" not in defect_augmentation_kwargs: + # download the defect volume + url = CREMI_URLS["defects"] + checksum = CHECKSUMS["defects"] + defect_path = os.path.join(path, "cremi_defects.h5") + util.download_source(defect_path, url, download, checksum) + defect_patch_shape = (1,) + tuple(patch_shape[1:]) + artifact_source = torch_em.transform.get_artifact_source( + defect_path, defect_patch_shape, + min_mask_fraction=0.75, + raw_key="defect_sections/raw", + mask_key="defect_sections/mask" + ) + defect_augmentation_kwargs.update({"artifact_source": artifact_source}) + + # defect augmentations + if defect_augmentation_kwargs is not None: + raw_transform = torch_em.transform.get_raw_transform( + augmentation1=torch_em.transform.EMDefectAugmentation(**defect_augmentation_kwargs) + ) + kwargs = util.update_kwargs(kwargs, "raw_transform", raw_transform) + + kwargs, _ = util.add_instance_label_transform( + kwargs, add_binary_target=False, boundaries=boundaries, offsets=offsets + ) + + return torch_em.default_segmentation_dataset( + raw_paths=data_paths, + raw_key="volumes/raw", + label_paths=data_paths, + label_key="volumes/labels/neuron_ids", + patch_shape=patch_shape, + rois=data_rois, + **kwargs + ) + + +def get_cremi_loader( + path: Union[os.PathLike, str], + patch_shape: Tuple[int, int, int], + batch_size: int, + samples: Tuple[str, ...] = ("A", "B", "C"), + use_realigned: bool = False, + download: bool = False, + offsets: Optional[List[List[int]]] = None, + boundaries: bool = False, + rois: Dict[str, Any] = {}, + defect_augmentation_kwargs: Dict[str, Any] = { + "p_drop_slice": 0.025, + "p_low_contrast": 0.025, + "p_deform_slice": 0.0, + "deformation_mode": "compress", + }, + **kwargs, +) -> DataLoader: + """Get the DataLoader for EM neuron segmentation in the CREMI dataset. 
+
+    Args:
+        path: Filepath to a folder where the downloaded data will be saved.
+        patch_shape: The patch shape to use for training.
+        batch_size: The batch size for training.
+        samples: The CREMI samples to use. The available samples are 'A', 'B', 'C'.
+        use_realigned: Use the realigned instead of the original training data.
+        download: Whether to download the data if it is not present.
+        offsets: Offset values for affinity computation used as target.
+        boundaries: Whether to compute boundaries as the target.
+        rois: The regions of interest to use for the samples.
+        defect_augmentation_kwargs: Keyword arguments for defect augmentations.
+        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader.
+
+    Returns:
+        The DataLoader.
+    """
+    dataset_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
+    ds = get_cremi_dataset(
+        path=path,
+        patch_shape=patch_shape,
+        samples=samples,
+        use_realigned=use_realigned,
+        download=download,
+        offsets=offsets,
+        boundaries=boundaries,
+        rois=rois,
+        defect_augmentation_kwargs=defect_augmentation_kwargs,
+        **dataset_kwargs,
+    )
+    return torch_em.get_data_loader(ds, batch_size=batch_size, **loader_kwargs)
diff --git a/torch_em/data/datasets/electron_microscopy/deepict.py b/torch_em/data/datasets/electron_microscopy/deepict.py
new file mode 100644
index 00000000..b9b9f023
--- /dev/null
+++ b/torch_em/data/datasets/electron_microscopy/deepict.py
@@ -0,0 +1,178 @@
+"""Dataset for segmentation of structures in CryoET.
+The DeePict dataset contains annotations for several structures in CryoET.
+The dataset implemented here currently only provides access to the actin annotations.
+
+The dataset is part of the publication https://doi.org/10.1038/s41592-022-01746-2.
+Please cite it if you use this dataset in your research.
+"""
+
+import os
+from glob import glob
+from shutil import rmtree
+from typing import Tuple, Union, List
+
+from torch.utils.data import Dataset, DataLoader
+
+try:
+    import mrcfile
+except ImportError:
+    mrcfile = None
+
+import torch_em
+
+from .. import util
+
+
+ACTIN_ID = 10002
+
+
+def _process_deepict_actin(input_path, output_path):
+    from elf.io import open_file
+
+    os.makedirs(output_path, exist_ok=True)
+
+    # datasets = ["00004", "00011", "00012"]
+    # There are issues with the 00011 dataset
+    datasets = ["00004", "00012"]
+    for dataset in datasets:
+        ds_folder = os.path.join(input_path, dataset)
+        assert os.path.exists(ds_folder)
+        ds_out = os.path.join(output_path, f"{dataset}.h5")
+        if os.path.exists(ds_out):
+            continue
+
+        assert mrcfile is not None, "Please install mrcfile"
+
+        tomo_folder = glob(os.path.join(ds_folder, "Tomograms", "VoxelSpacing*"))
+        assert len(tomo_folder) == 1
+        tomo_folder = tomo_folder[0]
+
+        annotation_folder = os.path.join(tomo_folder, "Annotations")
+        annotation_files = glob(os.path.join(annotation_folder, "*.zarr"))
+
+        tomo_path = os.path.join(tomo_folder, "CanonicalTomogram", f"{dataset}.mrc")
+        with mrcfile.open(tomo_path, "r") as f:
+            data = f.data[:]
+
+        annotations = {}
+        for annotation in annotation_files:
+            with open_file(annotation, "r") as f:
+                annotation_data = f["0"][:].astype("uint8")
+            assert annotation_data.shape == data.shape
+            annotation_name = os.path.basename(annotation).split("-")[1]
+            annotations[annotation_name] = annotation_data
+
+        with open_file(ds_out, "a") as f:
+            f.create_dataset("raw", data=data, compression="gzip")
+            for name, annotation in annotations.items():
+                f.create_dataset(f"labels/original/{name}", data=annotation, compression="gzip")
+
+            # Create combined annotations for actin
+            actin_seg = annotations["actin_deepict_training_prediction"]
+            actin_seg2 = annotations["actin_ground_truth"]
+            actin_seg[actin_seg2 == 1] = 1
+            f.create_dataset("labels/actin", data=actin_seg, compression="gzip")
+
+
+def get_deepict_actin_data(path: Union[os.PathLike, str], download: bool) -> str:
+    """Download the DeePict actin dataset.
+
+    Args:
+        path: Filepath to a folder where the downloaded data will be saved.
+        download: Whether to download the data if it is not present.
+
+    Returns:
+        The path to the downloaded data.
+    """
+    # Check if the processed data is already present.
+    dataset_path = os.path.join(path, "deepict_actin")
+    if os.path.exists(dataset_path):
+        return dataset_path
+
+    # Otherwise download the data.
+    dl_path = util.download_from_cryo_et_portal(path, ACTIN_ID, download)
+
+    # And then process it.
+    _process_deepict_actin(dl_path, dataset_path)
+
+    # Clean up the original data after processing.
+    rmtree(dl_path)
+
+    return dataset_path
+
+
+def get_deepict_actin_paths(path: Union[os.PathLike, str], download: bool = False) -> List[str]:
+    """Get paths to DeePict actin data.
+
+    Args:
+        path: Filepath to a folder where the downloaded data will be saved.
+        download: Whether to download the data if it is not present.
+
+    Returns:
+        The filepaths to the stored data.
+    """
+    get_deepict_actin_data(path, download)
+    data_paths = sorted(glob(os.path.join(path, "deepict_actin", "*.h5")))
+    return data_paths
+
+
+def get_deepict_actin_dataset(
+    path: Union[os.PathLike, str],
+    patch_shape: Tuple[int, int, int],
+    label_key: str = "labels/actin",
+    download: bool = False,
+    **kwargs
+) -> Dataset:
+    """Get the dataset for actin segmentation in CryoET data.
+
+    Args:
+        path: Filepath to a folder where the downloaded data will be saved.
+        patch_shape: The patch shape to use for training.
+        label_key: The key for the labels to load. By default this uses 'labels/actin',
+            which holds the best version of actin ground-truth images.
+ download: Whether to download the data if it is not present. + kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`. + + Returns: + The segmentation dataset. + """ + assert len(patch_shape) == 3 + + data_paths = get_deepict_actin_paths(path, download) + + return torch_em.default_segmentation_dataset( + raw_paths=data_paths, + raw_key="raw", + label_paths=data_paths, + label_key=label_key, + patch_shape=patch_shape, + is_seg_dataset=True, + **kwargs + ) + + +def get_deepict_actin_loader( + path: Union[os.PathLike, str], + patch_shape: Tuple[int, int, int], + batch_size: int, + label_key: str = "labels/actin", + download: bool = False, + **kwargs +) -> DataLoader: + """Get the DataLoader for actin segmentation in CryoET data. + + Args: + path: Filepath to a folder where the downloaded data will be saved. + patch_shape: The patch shape to use for training. + batch_size: The batch size for training. + label_key: The key for the labels to load. By default this uses 'labels/actin', + which holds the best version of actin ground-truth images. + download: Whether to download the data if it is not present. + kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader. + + Returns: + The DataLoader. + """ + ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs) + dataset = get_deepict_actin_dataset(path, patch_shape, label_key=label_key, download=download, **ds_kwargs) + return torch_em.get_data_loader(dataset, batch_size, **loader_kwargs) diff --git a/torch_em/data/datasets/electron_microscopy/emneuron.py b/torch_em/data/datasets/electron_microscopy/emneuron.py new file mode 100644 index 00000000..bb0b67e1 --- /dev/null +++ b/torch_em/data/datasets/electron_microscopy/emneuron.py @@ -0,0 +1,151 @@ +"""EMNeuron is a dataset for neuron segmentation in EM. +It contains multiple annotated volumes from 16 domain sources. + +The dataset is hosted at https://huggingface.co/datasets/yanchaoz/EMNeuron. +The dataset is published in https://papers.miccai.org/miccai-2024/677-Paper0518.html. +Please cite this publication if you use the dataset in your research. +""" + +import os +import shutil +from glob import glob +from natsort import natsorted +from typing import Union, Tuple, List, Literal + +from torch.utils.data import Dataset, DataLoader + +import torch_em + +from .. import util + + +def _clean_redundant_files(path): + # The "InDistribution" directory is redundant. + target_dir = os.path.join(path, "valid", "InDistribution", "InDistribution") + if os.path.exists(target_dir): + shutil.rmtree(target_dir) + + +def get_emneuron_data(path: Union[os.PathLike, str], split: Literal['train', 'val'], download: bool = False): + """Get the EMNeuron data. + + NOTE: The automatic download feature is currently not supported in `get_emneuron_data`. + You must follow the steps mentioned to download the data: + - Go to the official GitHub repository: https://github.com/yanchaoz/SegNeuron. + - Access the dataset link (hosted at HuggingFace): https://huggingface.co/datasets/yanchaoz/EMNeuron. + - Login / create your account to access the "Dataset Card". + - Go to "Files" in the dataset repo and download a) `labeled.rar` and b) `valid.rar`. + - Finally, provide the filepath to the folder where rar files are stored. + + Args: + path: Filepath to a folder where the downloaded data will be saved. + split: The split of the data to be used for training. 
+        download: Whether to download the data if it is not present.
+    """
+    if download:
+        raise ValueError("Automatic download is not supported for EMNeuron. Please download the data manually as described above.")
+
+    os.makedirs(path, exist_ok=True)
+
+    if split == "train":
+        rar_path = os.path.join(path, "labeled.rar")
+    elif split == "val":
+        rar_path = os.path.join(path, "valid.rar")
+    else:
+        raise ValueError(f"'{split}' is not a valid split. Please choose either 'train' or 'val'.")
+
+    if os.path.exists(os.path.splitext(rar_path)[0]):
+        return
+
+    util.unzip_rarfile(rar_path=rar_path, dst=path, remove=False, use_rarfile=False)
+
+    _clean_redundant_files(path)
+
+
+def get_emneuron_paths(
+    path: Union[os.PathLike, str], split: Literal['train', 'val'], download: bool = False
+) -> Tuple[List[str], List[str]]:
+    """Get paths to the EMNeuron data.
+
+    Args:
+        path: Filepath to a folder where the downloaded data will be saved.
+        split: The split of the data to be used for training.
+        download: Whether to download the data if it is not present.
+
+    Returns:
+        List of filepaths for the image data.
+        List of filepaths for the label data.
+    """
+    get_emneuron_data(path, split, download)
+    if split == "train":
+        label_paths = natsorted(glob(os.path.join(path, "labeled", "*", "*_MaskIns.tif")))
+        raw_paths = [os.path.join(os.path.dirname(p), os.path.basename(p).replace("_MaskIns", "")) for p in label_paths]
+
+    else:  # 'val' split
+        raw_paths = natsorted(glob(os.path.join(path, "valid", "*", "*", "raw.tif")))
+        label_paths = [
+            os.path.join(os.path.dirname(p), "label_0.tif")
+            if os.path.exists(os.path.join(os.path.dirname(p), "label_0.tif"))
+            else os.path.join(os.path.dirname(p), "label.tif") for p in raw_paths
+        ]
+
+    assert len(raw_paths) == len(label_paths)
+    return raw_paths, label_paths
+
+
+def get_emneuron_dataset(
+    path: Union[os.PathLike, str],
+    patch_shape: Tuple[int, ...],
+    split: Literal['train', 'val'],
+    download: bool = False,
+    **kwargs
+) -> Dataset:
+    """Get the EMNeuron dataset for neuron segmentation.
+
+    Args:
+        path: Filepath to a folder where the downloaded data will be saved.
+        patch_shape: The patch shape to use for training.
+        split: The split of the data to be used for training.
+        download: Whether to download the data if it is not present.
+        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.
+
+    Returns:
+        The segmentation dataset.
+    """
+    raw_paths, label_paths = get_emneuron_paths(path, split, download)
+
+    return torch_em.default_segmentation_dataset(
+        raw_paths=raw_paths,
+        raw_key=None,
+        label_paths=label_paths,
+        label_key=None,
+        patch_shape=patch_shape,
+        is_seg_dataset=True,
+        **kwargs
+    )
+
+
+def get_emneuron_loader(
+    path: Union[os.PathLike, str],
+    batch_size: int,
+    patch_shape: Tuple[int, ...],
+    split: Literal['train', 'val'],
+    download: bool = False,
+    **kwargs
+) -> DataLoader:
+    """Get the DataLoader for the EMNeuron dataset.
+
+    Args:
+        path: Filepath to a folder where the downloaded data will be saved.
+        batch_size: The batch size for training.
+        patch_shape: The patch shape to use for training.
+        split: The split of the data to be used for training.
+        download: Whether to download the data if it is not present.
+        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader.
+
+    Returns:
+        The DataLoader.
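A usage sketch for the EMNeuron loader, assuming `labeled.rar` was downloaded manually from HuggingFace into the placeholder folder (see `get_emneuron_data`):

```python
from torch_em.data.datasets.electron_microscopy import get_emneuron_loader

# The rar archive is unpacked on first use; automatic download is not supported.
loader = get_emneuron_loader(
    "./data/emneuron", batch_size=1, patch_shape=(16, 256, 256), split="train",
)
```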
+ """ + ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs) + dataset = get_emneuron_dataset(path=path, patch_shape=patch_shape, split=split, download=download, **ds_kwargs) + return torch_em.get_data_loader(dataset=dataset, batch_size=batch_size, **loader_kwargs) diff --git a/torch_em/data/datasets/electron_microscopy/isbi2012.py b/torch_em/data/datasets/electron_microscopy/isbi2012.py new file mode 100644 index 00000000..b117d76f --- /dev/null +++ b/torch_em/data/datasets/electron_microscopy/isbi2012.py @@ -0,0 +1,121 @@ +"""The ISBI2012 dataset was the first neuron segmentation challenge, held at the ISBI 2012 competition. +It contains a small annotated EM volume from the fruit-fly brain. + +If you use this dataset in your research please cite the following publication: +https://doi.org/10.3389/fnana.2015.00142. +""" + +import os +from typing import List, Optional, Tuple, Union + +from torch.utils.data import Dataset, DataLoader + +import torch_em +from .. import util + + +ISBI_URL = "https://oc.embl.de/index.php/s/h0TkwqxU0PJDdMd/download" +CHECKSUM = "0e10fe909a1243084d91773470856993b7d40126a12e85f0f1345a7a9e512f29" + + +def get_isbi_data(path: Union[os.PathLike, str], download: bool = False): + """Download the ISBI2012 dataset. + + Args: + path: Filepath to a folder where the downloaded data will be saved. + download: Whether to download the data if it is not present. + """ + os.makedirs(path, exist_ok=True) + util.download_source(os.path.join(path, "isbi.h5"), ISBI_URL, download, CHECKSUM) + + +def get_isbi_paths(path: Union[os.PathLike, str], download: bool = False) -> str: + """Get path to ISBI data. + + Args: + path: Filepath to a folder where the downloaded data will be saved. + download: Whether to download the data if it is not present. + + Returns: + The filepath for the stored data. + """ + get_isbi_data(path, download) + volume_path = os.path.join(path, "isbi.h5") + return volume_path + + +def get_isbi_dataset( + path: Union[os.PathLike, str], + patch_shape: Tuple[int, int, int], + download: bool = False, + offsets: Optional[List[List[int]]] = None, + boundaries: bool = False, + use_original_labels: bool = False, + **kwargs +) -> Dataset: + """Get the dataset for EM neuron segmentation in ISBI 2012. + + Args: + path: Filepath to a folder where the downloaded data will be saved. + patch_shape: The patch shape to use for training. + download: Whether to download the data if it is not present. + offsets: Offset values for affinity computation used as target. + boundaries: Whether to compute boundaries as the target. + use_original_labels: Whether to use the original annotations or postprocessed 3d annotations. + kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`. + + Returns: + The segmentation dataset. 
+ """ + assert len(patch_shape) == 3 + + volume_path = get_isbi_paths(path, download) + + ndim = 2 if patch_shape[0] == 1 else 3 + kwargs = util.update_kwargs(kwargs, "ndim", ndim) + + kwargs, _ = util.add_instance_label_transform( + kwargs, add_binary_target=False, boundaries=boundaries, offsets=offsets + ) + + return torch_em.default_segmentation_dataset( + raw_paths=volume_path, + raw_key="raw", + label_paths=volume_path, + label_key="labels/membranes" if use_original_labels else "labels/gt_segmentation", + patch_shape=patch_shape, + **kwargs + ) + + +def get_isbi_loader( + path: Union[os.PathLike, str], + patch_shape: Tuple[int, int, int], + batch_size: int, + download: bool = False, + offsets: Optional[List[List[int]]] = None, + boundaries: bool = False, + use_original_labels: bool = False, + **kwargs +) -> DataLoader: + """Get the DataLoader for EM neuron segmentation in ISBI 2012. + + Args: + path: Filepath to a folder where the downloaded data will be saved. + patch_shape: The patch shape to use for training. + batch_size: The batch size for training. + download: Whether to download the data if it is not present. + offsets: Offset values for affinity computation used as target. + boundaries: Whether to compute boundaries as the target. + use_original_labels: Whether to use the original annotations or postprocessed 3d annotations. + kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader. + + Returns: + The DataLoader. + """ + ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs) + dataset = get_isbi_dataset( + path, patch_shape, download=download, offsets=offsets, + boundaries=boundaries, use_original_labels=use_original_labels, **ds_kwargs + ) + return torch_em.get_data_loader(dataset, batch_size, **loader_kwargs) diff --git a/torch_em/data/datasets/electron_microscopy/kasthuri.py b/torch_em/data/datasets/electron_microscopy/kasthuri.py new file mode 100644 index 00000000..2e391e8f --- /dev/null +++ b/torch_em/data/datasets/electron_microscopy/kasthuri.py @@ -0,0 +1,173 @@ +"""The Kasthuri dataset is a segmentation dataset for mitochondrion segmentation in electron microscopy. + +The dataset was published in https://doi.org/10.48550/arXiv.1812.06024. +Please cite this publication if you use the dataset in your research. +We use the version of the dataset from https://sites.google.com/view/connectomics/. +""" + +import os +from glob import glob +from tqdm import tqdm +from shutil import rmtree +from concurrent import futures +from typing import Tuple, Union + +import imageio +import numpy as np + +import torch_em + +from torch.utils.data import Dataset, DataLoader + +from .. 
import util


URL = "http://www.casser.io/files/kasthuri_pp.zip"
CHECKSUM = "bbb78fd205ec9b57feb8f93ebbdf1666261cbc3e0305e7f11583ab5157a3d792"

# TODO: add sampler for foreground (-1 is empty area)
# TODO: and masking for the empty space


def _load_volume(path):
    files = glob(os.path.join(path, "*.png"))
    files.sort()
    nz = len(files)

    im0 = imageio.imread(files[0])
    out = np.zeros((nz,) + im0.shape, dtype=im0.dtype)
    out[0] = im0

    def _loadz(z):
        im = imageio.imread(files[z])
        out[z] = im

    n_threads = 8
    with futures.ThreadPoolExecutor(n_threads) as tp:
        list(tqdm(
            tp.map(_loadz, range(1, nz)), desc="Load volume", total=nz-1
        ))

    return out


def _create_data(root, inputs, out_path):
    import h5py

    raw = _load_volume(os.path.join(root, inputs[0]))
    labels_argb = _load_volume(os.path.join(root, inputs[1]))
    assert labels_argb.ndim == 4
    labels = np.zeros(raw.shape, dtype="int8")

    fg_mask = (labels_argb == np.array([255, 255, 255])[None, None, None]).all(axis=-1)
    labels[fg_mask] = 1
    bg_mask = (labels_argb == np.array([2, 2, 2])[None, None, None]).all(axis=-1)
    labels[bg_mask] = -1
    assert (np.unique(labels) == np.array([-1, 0, 1])).all()
    assert raw.shape == labels.shape, f"{raw.shape}, {labels.shape}"
    with h5py.File(out_path, "w") as f:
        f.create_dataset("raw", data=raw, compression="gzip")
        f.create_dataset("labels", data=labels, compression="gzip")


def get_kasthuri_data(path: Union[os.PathLike, str], download: bool = False) -> str:
    """Download the Kasthuri dataset.

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        download: Whether to download the data if it is not present.

    Returns:
        The filepath for the downloaded data.
    """
    if os.path.exists(path):
        return path

    os.makedirs(path)
    tmp_path = os.path.join(path, "kasthuri.zip")
    util.download_source(tmp_path, URL, download, checksum=CHECKSUM)
    util.unzip(tmp_path, path, remove=True)

    root = os.path.join(path, "Kasthuri++")
    assert os.path.exists(root), root

    inputs = [["Test_In", "Test_Out"], ["Train_In", "Train_Out"]]
    outputs = ["kasthuri_train.h5", "kasthuri_test.h5"]
    for inp, out in zip(inputs, outputs):
        out_path = os.path.join(path, out)
        _create_data(root, inp, out_path)

    rmtree(root)
    return path


def get_kasthuri_paths(path: Union[os.PathLike, str], split: str, download: bool = False) -> str:
    """Get paths to the Kasthuri data.

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        split: The data split. Either 'train' or 'test'.
        download: Whether to download the data if it is not present.

    Returns:
        The filepath to the stored data.
    """
    get_kasthuri_data(path, download)
    data_path = os.path.join(path, f"kasthuri_{split}.h5")
    assert os.path.exists(data_path), data_path
    return data_path


def get_kasthuri_dataset(
    path: Union[os.PathLike, str], split: str, patch_shape: Tuple[int, int, int], download: bool = False, **kwargs
) -> Dataset:
    """Get dataset for EM mitochondrion segmentation in the Kasthuri dataset.

    Args:
        path: Filepath to a folder where the downloaded data will be saved.
        split: The data split. Either 'train' or 'test'.
        patch_shape: The patch shape to use for training.
        download: Whether to download the data if it is not present.
        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.

    Returns:
        The segmentation dataset.
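A usage sketch for the Kasthuri dataset, assuming a placeholder folder; the download converts the png stacks to h5 on first use:

```python
from torch_em.data.datasets.electron_microscopy import get_kasthuri_dataset

ds = get_kasthuri_dataset(
    "./data/kasthuri", split="train", patch_shape=(8, 512, 512), download=True,
)
```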
+ """ + assert split in ("train", "test") + + data_path = get_kasthuri_paths(path, split, download) + + return torch_em.default_segmentation_dataset( + raw_paths=data_path, + raw_key="raw", + label_paths=data_path, + label_key="labels", + patch_shape=patch_shape, + **kwargs + ) + + +def get_kasthuri_loader( + path: Union[os.PathLike, str], + split: str, + patch_shape: Tuple[int, int, int], + batch_size: int, + download: bool = False, + **kwargs +) -> DataLoader: + """Get dataloader for EM mitochondrion segmentation in the kasthuri dataset. + + Args: + path: Filepath to a folder where the downloaded data will be saved. + split: The data split. Either 'train' or 'test'. + patch_shape: The patch shape to use for training. + batch_size: The batch size for training. + download: Whether to download the data if it is not present. + kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader. + + Returns: + The PyTorch DataLoader. + """ + ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs) + dataset = get_kasthuri_dataset(path, split, patch_shape, download=download, **ds_kwargs) + return torch_em.get_data_loader(dataset, batch_size, **loader_kwargs) diff --git a/torch_em/data/datasets/electron_microscopy/lucchi.py b/torch_em/data/datasets/electron_microscopy/lucchi.py new file mode 100644 index 00000000..e4d5dec8 --- /dev/null +++ b/torch_em/data/datasets/electron_microscopy/lucchi.py @@ -0,0 +1,174 @@ +"""The Lucchi dataset is a segmentation dataset for mitochondrion segmentation in electron microscopy. + +The dataset was published in https://doi.org/10.48550/arXiv.1812.06024. +Please cite this publication if you use the dataset in your research. +We use the version of the dataset from https://sites.google.com/view/connectomics/. +""" + +import os +from glob import glob +from tqdm import tqdm +from shutil import rmtree +from concurrent import futures +from typing import Tuple, Union, Literal + +import imageio +import numpy as np + +import torch_em + +from torch.utils.data import Dataset, DataLoader + +from .. 
import util + + +URL = "http://www.casser.io/files/lucchi_pp.zip" +CHECKSUM = "770ce9e98fc6f29c1b1a250c637e6c5125f2b5f1260e5a7687b55a79e2e8844d" + + +def _load_volume(path, pattern): + nz = len(glob(os.path.join(path, "*.png"))) + im0 = imageio.imread(os.path.join(path, pattern % 0)) + out = np.zeros((nz,) + im0.shape, dtype=im0.dtype) + out[0] = im0 + + def _loadz(z): + im = imageio.imread(os.path.join(path, pattern % z)) + out[z] = im + + n_threads = 8 + with futures.ThreadPoolExecutor(n_threads) as tp: + list(tqdm( + tp.map(_loadz, range(1, nz)), desc="Load volume", total=nz-1 + )) + + return out + + +def _create_data(root, inputs, out_path): + import h5py + + raw = _load_volume(os.path.join(root, inputs[0]), pattern="mask%04i.png") + labels_argb = _load_volume(os.path.join(root, inputs[1]), pattern="%i.png") + if labels_argb.ndim == 4: + labels = np.zeros(raw.shape, dtype="uint8") + fg_mask = (labels_argb == np.array([255, 255, 255, 255])[None, None, None]).all(axis=-1) + labels[fg_mask] = 1 + else: + assert labels_argb.ndim == 3 + labels = labels_argb + labels[labels == 255] = 1 + assert (np.unique(labels) == np.array([0, 1])).all() + assert raw.shape == labels.shape, f"{raw.shape}, {labels.shape}" + with h5py.File(out_path, "w") as f: + f.create_dataset("raw", data=raw, compression="gzip") + f.create_dataset("labels", data=labels.astype("uint8"), compression="gzip") + + +def get_lucchi_data(path: Union[os.PathLike, str], split: Literal["train", "test"], download: bool = False) -> str: + """Download the Lucchi dataset. + + Args: + path: Filepath to a folder where the downloaded data will be saved. + split: The split to download, either 'train' or 'test'. + download: Whether to download the data if it is not present. + + Returns: + The filepath for the downloaded data. + """ + data_path = os.path.join(path, f"lucchi_{split}.h5") + if os.path.exists(data_path): + return data_path + + os.makedirs(path) + tmp_path = os.path.join(path, "lucchi.zip") + util.download_source(tmp_path, URL, download, checksum=CHECKSUM) + util.unzip(tmp_path, path, remove=True) + + root = os.path.join(path, "Lucchi++") + assert os.path.exists(root), root + + inputs = [["Test_In", "Test_Out"], ["Train_In", "Train_Out"]] + outputs = ["lucchi_train.h5", "lucchi_test.h5"] + for inp, out in zip(inputs, outputs): + out_path = os.path.join(path, out) + _create_data(root, inp, out_path) + rmtree(root) + + assert os.path.exists(data_path), data_path + return data_path + + +def get_lucchi_paths(path: Union[os.PathLike, str], split: Literal["train", "test"], download: bool = False) -> str: + """Get paths to the Lucchi data. + + Args: + path: Filepath to a folder where the downloaded data will be saved. + split: The data split. Either 'train' or 'test'. + download: Whether to download the data if it is not present. + + Returns: + The filepath for the stored data. + """ + get_lucchi_data(path, split, download) + data_path = os.path.join(path, f"lucchi_{split}.h5") + return data_path + + +def get_lucchi_dataset( + path: Union[os.PathLike, str], + split: Literal["train", "test"], + patch_shape: Tuple[int, int, int], + download: bool = False, + **kwargs +) -> Dataset: + """Get dataset for EM mitochondrion segmentation in the Lucchi dataset. + + Args: + path: Filepath to a folder where the downloaded data will be saved. + split: The data split. Either 'train' or 'test'. + patch_shape: The patch shape to use for training. + download: Whether to download the data if it is not present. 
+ kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`. + + Returns: + The segmentation dataset. + """ + assert split in ("train", "test") + + data_path = get_lucchi_paths(path, split, download) + + return torch_em.default_segmentation_dataset( + raw_paths=data_path, + raw_key="raw", + label_paths=data_path, + label_key="labels", + patch_shape=patch_shape, + **kwargs + ) + + +def get_lucchi_loader( + path: Union[os.PathLike, str], + split: Literal["train", "test"], + patch_shape: Tuple[int, int, int], + batch_size: int, + download: bool = False, + **kwargs +) -> DataLoader: + """Get dataloader for EM mitochondrion segmentation in the Lucchi dataset. + + Args: + path: Filepath to a folder where the downloaded data will be saved. + split: The data split. Either 'train' or 'test'. + patch_shape: The patch shape to use for training. + batch_size: The batch size for training. + download: Whether to download the data if it is not present. + kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader. + + Returns: + The PyTorch DataLoader. + """ + ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs) + dataset = get_lucchi_dataset(path, split, patch_shape, download=download, **ds_kwargs) + return torch_em.get_data_loader(dataset, batch_size, **loader_kwargs) diff --git a/torch_em/data/datasets/mitoem.py b/torch_em/data/datasets/electron_microscopy/mitoem.py similarity index 51% rename from torch_em/data/datasets/mitoem.py rename to torch_em/data/datasets/electron_microscopy/mitoem.py index 262b583e..04716efb 100644 --- a/torch_em/data/datasets/mitoem.py +++ b/torch_em/data/datasets/electron_microscopy/mitoem.py @@ -1,18 +1,26 @@ +"""MitoEM is a dataset for segmenting mitochondria in electron microscopy. +It contains two large annotated volumes, one from rat cortex, the other from human cortex. +This dataset was used for a segmentation challenge at ISBI 2022. + +If you use it in your research then please cite https://doi.org/10.1007/978-3-030-59722-1_7. +""" + import os +from tqdm import tqdm import multiprocessing -from concurrent import futures from shutil import rmtree +from concurrent import futures +from typing import List, Optional, Sequence, Tuple, Union import imageio import numpy as np + import torch_em -import z5py -from tqdm import tqdm -from . import util +from torch.utils.data import Dataset, DataLoader + +from .. import util -# TODO: update the links to the new host location at huggingface. -# - https://mitoem.grand-challenge.org/ (see `Dataset` for the links) URLS = { "raw": { @@ -21,7 +29,7 @@ }, "labels": { "human": "https://www.dropbox.com/s/dhf89bc14kemw4e/EM30-H-mito-train-val-v2.zip?dl=1", - "rat": "https://www.dropbox.com/s/stncdytayhr8ggz/EM30-R-mito-train-val-v2.zip?dl=1" + "rat": "https://huggingface.co/datasets/pytc/MitoEM/resolve/main/EM30-R-mito-train-val-v2.zip" } } CHECKSUMS = { @@ -72,6 +80,8 @@ def load_slice(z, slice_id): def _create_volume(out_path, im_folder, label_folder=None, z_start=None): + import z5py + if label_folder is None: assert z_start is not None n_slices = len(get_slices(im_folder)) @@ -127,30 +137,21 @@ def _require_mitoem_sample(path, sample, download): rmtree(val_folder) -def get_mitoem_dataset( - path, - splits, - patch_shape, - samples=("human", "rat"), - download=False, - offsets=None, - boundaries=False, - binary=False, - **kwargs, -): - """Dataset for the segmentation of mitochondria in EM. 
+def get_mitoem_data(path: Union[os.PathLike, str], samples: Sequence[str], splits: Sequence[str], download: bool):
+    """Download the MitoEM training data.
 
-    This dataset is from the publication https://doi.org/10.1007/978-3-030-59722-1_7.
-    Please cite it if you use this dataset for a publication.
+    Args:
+        path: Filepath to a folder where the downloaded data will be saved.
+        samples: The samples to download. The available samples are 'human' and 'rat'.
+        splits: The data splits to download. The available splits are 'train' and 'val'.
+        download: Whether to download the data if it is not present.
     """
-    assert len(patch_shape) == 3
     if isinstance(splits, str):
         splits = [splits]
     assert len(set(splits) - {"train", "val"}) == 0, f"{splits}"
     assert len(set(samples) - {"human", "rat"}) == 0, f"{samples}"
     os.makedirs(path, exist_ok=True)
 
-    data_paths = []
     for sample in samples:
         if not _check_data(path, sample):
             print("The MitoEM data for sample", sample, "is not available yet and will be downloaded and created.")
@@ -161,37 +162,107 @@ def get_mitoem_dataset(
         for split in splits:
             split_path = os.path.join(path, f"{sample}_{split}.n5")
             assert os.path.exists(split_path), split_path
-            data_paths.append(split_path)
+
+
+def get_mitoem_paths(
+    path: Union[os.PathLike, str],
+    splits: Sequence[str],
+    samples: Sequence[str] = ("human", "rat"),
+    download: bool = False,
+) -> List[str]:
+    """Get paths for MitoEM data.
+
+    Args:
+        path: Filepath to a folder where the downloaded data will be saved.
+        splits: The data splits to download. The available splits are 'train' and 'val'.
+        samples: The samples to download. The available samples are 'human' and 'rat'.
+        download: Whether to download the data if it is not present.
+
+    Returns:
+        The filepaths for the stored data.
+    """
+    get_mitoem_data(path, samples, splits, download)
+    data_paths = [os.path.join(path, f"{sample}_{split}.n5") for split in splits for sample in samples]
+    return data_paths
+
+
+def get_mitoem_dataset(
+    path: Union[os.PathLike, str],
+    splits: Sequence[str],
+    patch_shape: Tuple[int, int, int],
+    samples: Sequence[str] = ("human", "rat"),
+    download: bool = False,
+    offsets: Optional[List[List[int]]] = None,
+    boundaries: bool = False,
+    binary: bool = False,
+    **kwargs,
+) -> Dataset:
+    """Get the MitoEM dataset for the segmentation of mitochondria in EM.
+
+    Args:
+        path: Filepath to a folder where the downloaded data will be saved.
+        splits: The splits to use for the dataset. Available values are 'train' and 'val'.
+        patch_shape: The patch shape to use for training.
+        samples: The samples to use for the dataset. The available samples are 'human' and 'rat'.
+        download: Whether to download the data if it is not present.
+        offsets: Offset values for affinity computation used as target.
+        boundaries: Whether to compute boundaries as the target.
+        binary: Whether to return a binary segmentation target.
+        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.
+
+    Returns:
+        The segmentation dataset.
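+
+    Example:
+        A minimal usage sketch; the target folder, patch shape and sample choice below
+        are placeholder values, not values required by the dataset:
+
+            ds = get_mitoem_dataset(
+                "./data/mitoem", splits=["train"], patch_shape=(32, 256, 256), samples=("rat",), download=True
+            )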
+    """
+    assert len(patch_shape) == 3
+
+    data_paths = get_mitoem_paths(path, splits, samples, download)
 
     kwargs, _ = util.add_instance_label_transform(
         kwargs, add_binary_target=True, binary=binary, boundaries=boundaries, offsets=offsets
     )
-    raw_key = "raw"
-    label_key = "labels"
-    return torch_em.default_segmentation_dataset(data_paths, raw_key, data_paths, label_key, patch_shape, **kwargs)
+
+    return torch_em.default_segmentation_dataset(
+        raw_paths=data_paths,
+        raw_key="raw",
+        label_paths=data_paths,
+        label_key="labels",
+        patch_shape=patch_shape,
+        **kwargs
+    )
 
 
 def get_mitoem_loader(
-    path,
-    splits,
-    patch_shape,
-    batch_size,
-    samples=("human", "rat"),
-    download=False,
-    offsets=None,
-    boundaries=False,
-    binary=False,
+    path: Union[os.PathLike, str],
+    splits: Sequence[str],
+    patch_shape: Tuple[int, int, int],
+    batch_size: int,
+    samples: Sequence[str] = ("human", "rat"),
+    download: bool = False,
+    offsets: Optional[List[List[int]]] = None,
+    boundaries: bool = False,
+    binary: bool = False,
     **kwargs,
-):
-    """Dataloader for the segmentation of mitochondria in EM. See 'get_mitoem_dataset' for details."""
-    ds_kwargs, loader_kwargs = util.split_kwargs(
-        torch_em.default_segmentation_dataset, **kwargs
-    )
+) -> DataLoader:
+    """Get the MitoEM dataloader for the segmentation of mitochondria in EM.
+
+    Args:
+        path: Filepath to a folder where the downloaded data will be saved.
+        splits: The splits to use for the dataset. Available values are 'train' and 'val'.
+        patch_shape: The patch shape to use for training.
+        batch_size: The batch size for training.
+        samples: The samples to use for the dataset. The available samples are 'human' and 'rat'.
+        download: Whether to download the data if it is not present.
+        offsets: Offset values for affinity computation used as target.
+        boundaries: Whether to compute boundaries as the target.
+        binary: Whether to return a binary segmentation target.
+        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader.
+
+    Returns:
+        The DataLoader.
+    """
+    ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
     dataset = get_mitoem_dataset(
-        path, splits, patch_shape,
-        samples=samples, download=download,
-        offsets=offsets, boundaries=boundaries, binary=binary,
-        **ds_kwargs
+        path, splits, patch_shape, samples=samples, download=download,
+        offsets=offsets, boundaries=boundaries, binary=binary, **ds_kwargs
     )
-    loader = torch_em.get_data_loader(dataset, batch_size, **loader_kwargs)
-    return loader
+    return torch_em.get_data_loader(dataset, batch_size, **loader_kwargs)
diff --git a/torch_em/data/datasets/electron_microscopy/nuc_mm.py b/torch_em/data/datasets/electron_microscopy/nuc_mm.py
new file mode 100644
index 00000000..a9736a3c
--- /dev/null
+++ b/torch_em/data/datasets/electron_microscopy/nuc_mm.py
@@ -0,0 +1,159 @@
+"""NucMM is a dataset for the segmentation of nuclei in EM and X-Ray.
+
+This dataset is from the publication https://doi.org/10.1007/978-3-030-87193-2_16.
+Please cite it if you use this dataset for a publication.
+"""
+
+import os
+from glob import glob
+from typing import Tuple, Union, Literal, List
+
+import torch_em
+
+from torch.utils.data import Dataset, DataLoader
+
+from .. import util
+
+
+URL = "https://drive.google.com/drive/folders/1_4CrlYvzx0ITnGlJOHdgcTRgeSkm9wT8"
+
+
+def _extract_split(image_folder, label_folder, output_folder):
+    import h5py
+
+    os.makedirs(output_folder, exist_ok=True)
+    image_files = sorted(glob(os.path.join(image_folder, "*.h5")))
+    label_files = sorted(glob(os.path.join(label_folder, "*.h5")))
+    assert len(image_files) == len(label_files)
+    for image, label in zip(image_files, label_files):
+        with h5py.File(image, "r") as f:
+            vol = f["main"][:]
+        with h5py.File(label, "r") as f:
+            seg = f["main"][:]
+        assert vol.shape == seg.shape
+        out_path = os.path.join(output_folder, os.path.basename(image))
+        with h5py.File(out_path, "a") as f:
+            f.create_dataset("raw", data=vol, compression="gzip")
+            f.create_dataset("labels", data=seg, compression="gzip")
+
+
+def get_nuc_mm_data(path: Union[os.PathLike, str], sample: Literal['mouse', 'zebrafish'], download: bool) -> str:
+    """Download the NucMM training data.
+
+    Args:
+        path: Filepath to a folder where the downloaded data will be saved.
+        sample: The NucMM sample to use. The available samples are 'mouse' and 'zebrafish'.
+        download: Whether to download the data if it is not present.
+
+    Returns:
+        The filepath to the training data.
+    """
+    assert sample in ("mouse", "zebrafish")
+
+    sample_folder = os.path.join(path, sample)
+    if os.path.exists(sample_folder):
+        return sample_folder
+
+    # Downloading the dataset
+    util.download_source_gdrive(path, URL, download, download_type="folder")
+
+    if sample == "mouse":
+        input_folder = os.path.join(path, "Mouse (NucMM-M)")
+    else:
+        input_folder = os.path.join(path, "Zebrafish (NucMM-Z)")
+    assert os.path.exists(input_folder), input_folder
+
+    sample_folder = os.path.join(path, sample)
+    _extract_split(
+        os.path.join(input_folder, "Image", "train"), os.path.join(input_folder, "Label", "train"),
+        os.path.join(sample_folder, "train")
+    )
+    _extract_split(
+        os.path.join(input_folder, "Image", "val"), os.path.join(input_folder, "Label", "val"),
+        os.path.join(sample_folder, "val")
+    )
+    return sample_folder
+
+
+def get_nuc_mm_paths(
+    path: Union[os.PathLike, str], sample: Literal['mouse', 'zebrafish'], split: str, download: bool = False,
+) -> List[str]:
+    """Get paths to the NucMM data.
+
+    Args:
+        path: Filepath to a folder where the downloaded data will be saved.
+        sample: The NucMM sample to use. The available samples are 'mouse' and 'zebrafish'.
+        split: The split for the dataset, either 'train' or 'val'.
+        download: Whether to download the data if it is not present.
+
+    Returns:
+        The filepaths to the stored data.
+    """
+    get_nuc_mm_data(path, sample, download)
+    split_folder = os.path.join(path, sample, split)
+    paths = sorted(glob(os.path.join(split_folder, "*.h5")))
+    return paths
+
+
+def get_nuc_mm_dataset(
+    path: Union[os.PathLike, str],
+    sample: Literal['mouse', 'zebrafish'],
+    split: str,
+    patch_shape: Tuple[int, int, int],
+    download: bool = False,
+    **kwargs
+) -> Dataset:
+    """Get the NucMM dataset for the segmentation of nuclei in X-Ray and EM.
+
+    Args:
+        path: Filepath to a folder where the downloaded data will be saved.
+        sample: The NucMM sample to use. The available samples are 'mouse' and 'zebrafish'.
+        split: The split for the dataset, either 'train' or 'val'.
+        patch_shape: The patch shape to use for training.
+        download: Whether to download the data if it is not present.
+        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.
+
+    Returns:
+        The segmentation dataset.
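+
+    Example:
+        A minimal usage sketch; the download folder and patch shape are placeholders:
+
+            ds = get_nuc_mm_dataset(
+                "./data/nuc_mm", sample="mouse", split="train", patch_shape=(32, 64, 64), download=True
+            )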
+    """
+    assert split in ("train", "val")
+
+    paths = get_nuc_mm_paths(path, sample, split, download)
+
+    return torch_em.default_segmentation_dataset(
+        raw_paths=paths,
+        raw_key="raw",
+        label_paths=paths,
+        label_key="labels",
+        patch_shape=patch_shape,
+        is_seg_dataset=True,
+        **kwargs
+    )
+
+
+def get_nuc_mm_loader(
+    path: Union[os.PathLike, str],
+    sample: Literal['mouse', 'zebrafish'],
+    split: str,
+    patch_shape: Tuple[int, int, int],
+    batch_size: int,
+    download: bool = False,
+    **kwargs
+) -> DataLoader:
+    """Get the NucMM dataloader for the segmentation of nuclei in X-Ray and EM.
+
+    Args:
+        path: Filepath to a folder where the downloaded data will be saved.
+        sample: The NucMM sample to use. The available samples are 'mouse' and 'zebrafish'.
+        split: The split for the dataset, either 'train' or 'val'.
+        patch_shape: The patch shape to use for training.
+        batch_size: The batch size for training.
+        download: Whether to download the data if it is not present.
+        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader.
+
+    Returns:
+        The DataLoader.
+    """
+    ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
+    ds = get_nuc_mm_dataset(path, sample, split, patch_shape, download, **ds_kwargs)
+    return torch_em.get_data_loader(ds, batch_size=batch_size, **loader_kwargs)
diff --git a/torch_em/data/datasets/electron_microscopy/platynereis.py b/torch_em/data/datasets/electron_microscopy/platynereis.py
new file mode 100644
index 00000000..64a17e7c
--- /dev/null
+++ b/torch_em/data/datasets/electron_microscopy/platynereis.py
@@ -0,0 +1,448 @@
+"""Dataset for the segmentation of different structures in an EM volume of a
+platynereis larva. Contains annotations for the segmentation of:
+- Cuticle
+- Cilia
+- Cells
+- Nuclei
+
+This dataset is from the publication https://doi.org/10.1016/j.cell.2021.07.017.
+Please cite it if you use this dataset for a publication.
+"""
+
+import os
+from glob import glob
+from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
+
+import numpy as np
+
+from torch.utils.data import Dataset, DataLoader
+
+import torch_em
+
+from .. import util
+
+
+URLS = {
+    "cells": "https://zenodo.org/record/3675220/files/membrane.zip",
+    "nuclei": "https://zenodo.org/record/3675220/files/nuclei.zip",
+    "cilia": "https://zenodo.org/record/3675220/files/cilia.zip",
+    "cuticle": "https://zenodo.org/record/3675220/files/cuticle.zip"
+}
+
+CHECKSUMS = {
+    "cells": "30eb50c39e7e9883e1cd96e0df689fac37a56abb11e8ed088907c94a5980d6a3",
+    "nuclei": "a05033c5fbc6a3069479ac6595b0a430070f83f5281f5b5c8913125743cf5510",
+    "cilia": "6d2b47f63d39a671789c02d8b66cad5e4cf30eb14cdb073da1a52b7defcc5e24",
+    "cuticle": "464f75d30133e8864958049647fe3c2216ddf2d4327569738ad72d299c991843"
+}
+
+FILE_TEMPLATES = {
+    "cells": "train_data_membrane_%02i.n5",
+    "nuclei": "train_data_nuclei_%02i.h5",
+    "cilia": "train_data_cilia_%02i.h5",
+    "cuticle": "train_data_%02i.n5",
+}
+
+
+#
+# TODO data-loader for more classes:
+# - mitos
+#
+
+
+def _check_data(path, prefix, extension, n_files):
+    if not os.path.exists(path):
+        return False
+    files = glob(os.path.join(path, f"{prefix}*{extension}"))
+    return len(files) == n_files
+
+
+def get_platynereis_data(path: Union[os.PathLike, str], name: str, download: bool) -> Tuple[str, int]:
+    """Download the platynereis dataset.
+
+    Args:
+        path: Filepath to a folder where the downloaded data will be saved.
+        name: Name of the segmentation task. Available tasks: 'cuticle', 'cilia', 'cells' or 'nuclei'.
+        download: Whether to download the data if it is not present.
+
+    Returns:
+        The path to the folder where the data has been downloaded.
+        The number of files downloaded.
+    """
+    data_root = os.path.join(path, name)
+
+    if name == "cuticle":
+        ext, prefix, n_files = ".n5", "train_data_", 5
+    elif name == "cilia":
+        ext, prefix, n_files = ".h5", "train_data_cilia_", 3
+    elif name == "cells":
+        data_root = os.path.join(path, "membrane")
+        ext, prefix, n_files = ".n5", "train_data_membrane_", 9
+    elif name == "nuclei":
+        ext, prefix, n_files = ".h5", "train_data_nuclei_", 12
+    else:
+        raise ValueError(f"Invalid name {name}. Expect one of 'cuticle', 'cilia', 'cells' or 'nuclei'.")
+
+    data_is_complete = _check_data(data_root, prefix, ext, n_files)
+    if data_is_complete:
+        return data_root, n_files
+
+    os.makedirs(path, exist_ok=True)
+    url = URLS[name]
+    checksum = CHECKSUMS[name]
+
+    zip_path = os.path.join(path, f"data-{name}.zip")
+    util.download_source(zip_path, url, download=download, checksum=checksum)
+    util.unzip(zip_path, path, remove=True)
+
+    return data_root, n_files
+
+
+def get_platynereis_paths(path, sample_ids, name, rois={}, download=False, return_rois=False):
+    """Get paths to the platynereis data.
+
+    Args:
+        path: Filepath to a folder where the downloaded data will be saved.
+        sample_ids: The sample ids to use for the dataset
+        name: Name of the segmentation task. Available tasks: 'cuticle', 'cilia', 'cells' or 'nuclei'.
+        rois: The region of interests to use for the data blocks.
+        download: Whether to download the data if it is not present.
+        return_rois: Whether to return the extracted rois.
+
+    Returns:
+        The filepaths for the stored data.
+    """
+    root, n_files = get_platynereis_data(path, name, download)
+    template = os.path.join(root, FILE_TEMPLATES[name])
+
+    if sample_ids is None:
+        sample_ids = list(range(1, n_files + 1))
+    else:
+        assert min(sample_ids) >= 1 and max(sample_ids) <= n_files
+        sample_ids.sort()
+    paths = [template % sample for sample in sample_ids]
+    data_rois = [rois.get(sample, np.s_[:, :, :]) for sample in sample_ids]
+
+    if return_rois:
+        return paths, data_rois
+    else:
+        return paths
+
+
+def get_platynereis_cuticle_dataset(
+    path: Union[os.PathLike, str],
+    patch_shape: Tuple[int, int, int],
+    sample_ids: Optional[Sequence[int]] = None,
+    download: bool = False,
+    rois: Dict[int, Any] = {},
+    **kwargs
+) -> Dataset:
+    """Get the dataset for cuticle segmentation in platynereis.
+
+    Args:
+        path: Filepath to a folder where the downloaded data will be saved.
+        patch_shape: The patch shape to use for training.
+        sample_ids: The sample ids to use for the dataset
+        download: Whether to download the data if it is not present.
+        rois: The region of interests to use for the data blocks.
+        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.
+
+    Returns:
+        The segmentation dataset.
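+
+    Example:
+        A minimal usage sketch; the download folder, patch shape and sample ids are placeholders:
+
+            ds = get_platynereis_cuticle_dataset(
+                "./data/platynereis", patch_shape=(32, 256, 256), sample_ids=[1, 2], download=True
+            )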
+ """ + paths, data_rois = get_platynereis_paths( + path=path, sample_ids=sample_ids, name="cuticle", rois=rois, download=download, return_rois=True, + ) + return torch_em.default_segmentation_dataset( + raw_paths=paths, + raw_key="volumes/raw", + label_paths=paths, + label_key="volumes/labels/segmentation", + patch_shape=patch_shape, + rois=data_rois, + **kwargs + ) + + +def get_platynereis_cuticle_loader( + path: Union[os.PathLike, str], + patch_shape: Tuple[int, int, int], + batch_size: int, + sample_ids: Optional[Sequence[int]] = None, + download: bool = False, + rois: Dict[int, Any] = {}, + **kwargs +) -> DataLoader: + """Get the dataloader for cuticle segmentation in platynereis. + + Args: + path: Filepath to a folder where the downloaded data will be saved. + patch_shape: The patch shape to use for training. + batch_size: The batch size for training. + sample_ids: The sample ids to use for the dataset + download: Whether to download the data if it is not present. + rois: The region of interests to use for the data blocks. + kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader. + + Returns: + The DataLoader. + """ + ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs) + ds = get_platynereis_cuticle_dataset( + path, patch_shape, sample_ids=sample_ids, download=download, rois=rois, **ds_kwargs, + ) + return torch_em.get_data_loader(ds, batch_size=batch_size, **loader_kwargs) + + +def get_platynereis_cilia_dataset( + path: Union[os.PathLike, str], + patch_shape: Tuple[int, int, int], + sample_ids: Optional[Sequence[int]] = None, + offsets: Optional[List[List[int]]] = None, + boundaries: bool = False, + binary: bool = False, + rois: Dict[int, Any] = {}, + download: bool = False, + **kwargs +) -> Dataset: + """Get the dataset for cilia segmentation in platynereis. + + Args: + path: Filepath to a folder where the downloaded data will be saved. + patch_shape: The patch shape to use for training. + sample_ids: The sample ids to use for the dataset + offsets: Offset values for affinity computation used as target. + boundaries: Whether to compute boundaries as the target. + binary: Whether to use a binary segmentation target. + rois: The region of interests to use for the data blocks. + download: Whether to download the data if it is not present. + kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`. + + Returns: + The segmentation dataset. + """ + paths, rois = get_platynereis_paths( + path=path, sample_ids=sample_ids, name="cilia", rois=rois, download=download, return_rois=True, + ) + kwargs = util.update_kwargs(kwargs, "rois", rois) + kwargs, _ = util.add_instance_label_transform( + kwargs, add_binary_target=True, boundaries=boundaries, offsets=offsets, binary=binary, + ) + return torch_em.default_segmentation_dataset( + raw_paths=paths, + raw_key="volumes/raw", + label_paths=paths, + label_key="volumes/labels/segmentation", + patch_shape=patch_shape, + **kwargs + ) + + +def get_platynereis_cilia_loader( + path: Union[os.PathLike, str], + patch_shape: Tuple[int, int, int], + batch_size: int, + sample_ids: Optional[Sequence[int]] = None, + offsets: Optional[List[List[int]]] = None, + boundaries: bool = False, + binary: bool = False, + rois: Dict[int, Any] = {}, + download: bool = False, + **kwargs +) -> DataLoader: + """Get the dataloader for cilia segmentation in platynereis. + + Args: + path: Filepath to a folder where the downloaded data will be saved. 
+ patch_shape: The patch shape to use for training. + batch_size: The batch size for training. + sample_ids: The sample ids to use for the dataset + offsets: Offset values for affinity computation used as target. + boundaries: Whether to compute boundaries as the target. + binary: Whether to return a binary segmentation target. + rois: The region of interests to use for the data blocks. + download: Whether to download the data if it is not present. + kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader. + + Returns: + The DataLoader. + """ + ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs) + ds = get_platynereis_cilia_dataset( + path, patch_shape, sample_ids=sample_ids, + offsets=offsets, boundaries=boundaries, binary=binary, + rois=rois, download=download, **ds_kwargs, + ) + return torch_em.get_data_loader(ds, batch_size=batch_size, **loader_kwargs) + + +def get_platynereis_cell_dataset( + path: Union[os.PathLike, str], + patch_shape: Tuple[int, int, int], + sample_ids: Optional[Sequence[int]] = None, + offsets: Optional[List[List[int]]] = None, + boundaries: bool = False, + rois: Dict[int, Any] = {}, + download: bool = False, + **kwargs +) -> Dataset: + """Get the dataset for cell segmentation in platynereis. + + Args: + path: Filepath to a folder where the downloaded data will be saved. + patch_shape: The patch shape to use for training. + sample_ids: The sample ids to use for the dataset + offsets: Offset values for affinity computation used as target. + boundaries: Whether to compute boundaries as the target. + rois: The region of interests to use for the data blocks. + download: Whether to download the data if it is not present. + kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`. + + Returns: + The segmentation dataset. + """ + data_paths, data_rois = get_platynereis_paths( + path=path, sample_ids=sample_ids, name="cells", rois=rois, download=download, return_rois=True, + ) + + kwargs = util.update_kwargs(kwargs, "rois", data_rois) + kwargs, _ = util.add_instance_label_transform( + kwargs, add_binary_target=False, boundaries=boundaries, offsets=offsets, + ) + + return torch_em.default_segmentation_dataset( + raw_paths=data_paths, + raw_key="volumes/raw/s1", + label_paths=data_paths, + label_key="volumes/labels/segmentation/s1", + patch_shape=patch_shape, + **kwargs + ) + + +def get_platynereis_cell_loader( + path: Union[os.PathLike, str], + patch_shape: Tuple[int, int, int], + batch_size: int, + sample_ids: Optional[Sequence[int]] = None, + offsets: Optional[List[List[int]]] = None, + boundaries: bool = False, + rois: Dict[int, Any] = {}, + download: bool = False, + **kwargs +) -> DataLoader: + """Get the dataloader for cell segmentation in platynereis. + + Args: + path: Filepath to a folder where the downloaded data will be saved. + patch_shape: The patch shape to use for training. + batch_size: The batch size for training. + sample_ids: The sample ids to use for the dataset + offsets: Offset values for affinity computation used as target. + boundaries: Whether to compute boundaries as the target. + rois: The region of interests to use for the data blocks. + download: Whether to download the data if it is not present. + kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader. + + Returns: + The DataLoader. 
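+
+    Example:
+        A minimal usage sketch; the download folder, patch shape and batch size are placeholders:
+
+            loader = get_platynereis_cell_loader(
+                "./data/platynereis", patch_shape=(32, 256, 256), batch_size=1, boundaries=True, download=True
+            )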
+ """ + ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs) + ds = get_platynereis_cell_dataset( + path, patch_shape, sample_ids, rois=rois, + offsets=offsets, boundaries=boundaries, download=download, + **ds_kwargs, + ) + return torch_em.get_data_loader(ds, batch_size=batch_size, **loader_kwargs) + + +def get_platynereis_nuclei_dataset( + path: Union[os.PathLike, str], + patch_shape: Tuple[int, int, int], + sample_ids: Optional[Sequence[int]] = None, + offsets: Optional[List[List[int]]] = None, + boundaries: bool = False, + binary: bool = False, + rois: Dict[int, Any] = {}, + download: bool = False, + **kwargs +) -> Dataset: + """Get the dataset for nucleus segmentation in platynereis. + + Args: + path: Filepath to a folder where the downloaded data will be saved. + patch_shape: The patch shape to use for training. + sample_ids: The sample ids to use for the dataset + offsets: Offset values for affinity computation used as target. + boundaries: Whether to compute boundaries as the target. + binary: Whether to return a binary segmentation target. + rois: The region of interests to use for the data blocks. + download: Whether to download the data if it is not present. + kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`. + + Returns: + The segmentation dataset. + """ + _, n_files = get_platynereis_data(path, "nuclei", download) + + if sample_ids is None: + sample_ids = list(range(1, n_files + 1)) + assert min(sample_ids) >= 1 and max(sample_ids) <= n_files + sample_ids.sort() + + data_paths, data_rois = get_platynereis_paths( + path=path, sample_ids=sample_ids, name="nuclei", rois=rois, download=download, return_rois=True, + ) + + kwargs = util.update_kwargs(kwargs, "is_seg_dataset", True) + kwargs = util.update_kwargs(kwargs, "rois", data_rois) + kwargs, _ = util.add_instance_label_transform( + kwargs, add_binary_target=True, boundaries=boundaries, offsets=offsets, binary=binary, + ) + + return torch_em.default_segmentation_dataset( + raw_paths=data_paths, + raw_key="volumes/raw", + label_paths=data_paths, + label_key="volumes/labels/nucleus_instance_labels", + patch_shape=patch_shape, + **kwargs + ) + + +def get_platynereis_nuclei_loader( + path: Union[os.PathLike, str], + patch_shape: Tuple[int, int, int], + batch_size: int, + sample_ids: Optional[Sequence[int]] = None, + offsets: Optional[List[List[int]]] = None, + boundaries: bool = False, + binary: bool = False, + rois: Dict[int, Any] = {}, + download: bool = False, + **kwargs +) -> DataLoader: + """Get the dataloader for nucleus segmentation in platynereis. + + Args: + path: Filepath to a folder where the downloaded data will be saved. + patch_shape: The patch shape to use for training. + batch_size: The batch size for training. + sample_ids: The sample ids to use for the dataset + offsets: Offset values for affinity computation used as target. + boundaries: Whether to compute boundaries as the target. + binary: Whether to return a binary segmentation target. + rois: The region of interests to use for the data blocks. + download: Whether to download the data if it is not present. + kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader. + + Returns: + The DataLoader. 
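+
+    Example:
+        A minimal usage sketch; the download folder and patch shape are placeholders:
+
+            loader = get_platynereis_nuclei_loader(
+                "./data/platynereis", patch_shape=(32, 128, 128), batch_size=1, binary=True, download=True
+            )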
+    """
+    ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
+    ds = get_platynereis_nuclei_dataset(
+        path, patch_shape, sample_ids=sample_ids, rois=rois,
+        offsets=offsets, boundaries=boundaries, binary=binary, download=download,
+        **ds_kwargs,
+    )
+    return torch_em.get_data_loader(ds, batch_size=batch_size, **loader_kwargs)
diff --git a/torch_em/data/datasets/electron_microscopy/snemi.py b/torch_em/data/datasets/electron_microscopy/snemi.py
new file mode 100644
index 00000000..05f36638
--- /dev/null
+++ b/torch_em/data/datasets/electron_microscopy/snemi.py
@@ -0,0 +1,135 @@
+"""SNEMI is a dataset for neuron segmentation in EM.
+It contains annotated volumes from the mouse brain.
+
+The data is part of the publication https://doi.org/10.1016/j.cell.2015.06.054.
+Please cite it if you use this dataset for a publication.
+"""
+
+import os
+from typing import List, Optional, Union, Tuple
+
+from torch.utils.data import Dataset, DataLoader
+
+import torch_em
+
+from .. import util
+
+
+SNEMI_URLS = {
+    "train": "https://oc.embl.de/index.php/s/43iMotlXPyAB39z/download",
+    "test": "https://oc.embl.de/index.php/s/aRhphk35H23De2s/download"
+}
+CHECKSUMS = {
+    "train": "5b130a24d9eb23d972fede0f1a403bc05f6808b361cfa22eff23b930b12f0615",
+    "test": "3df3920a0ddec6897105845f842b2665d37a47c2d1b96d4f4565682e315a59fa"
+}
+
+
+def get_snemi_data(path: Union[os.PathLike, str], sample: str, download: bool = False):
+    """Download the SNEMI training data.
+
+    Args:
+        path: Filepath to a folder where the downloaded data will be saved.
+        sample: The sample to download, either 'train' or 'test'.
+        download: Whether to download the data if it is not present.
+    """
+    os.makedirs(path, exist_ok=True)
+    data_path = os.path.join(path, f"snemi_{sample}.h5")
+    util.download_source(data_path, SNEMI_URLS[sample], download, CHECKSUMS[sample])
+
+
+def get_snemi_paths(path: Union[os.PathLike, str], sample: str, download: bool = False) -> str:
+    """Get path to the SNEMI data.
+
+    Args:
+        path: Filepath to a folder where the downloaded data is saved.
+        sample: The sample to download, either 'train' or 'test'.
+        download: Whether to download the data if it is not present.
+
+    Returns:
+        The filepath for the stored data.
+    """
+    get_snemi_data(path, sample, download)
+    data_path = os.path.join(path, f"snemi_{sample}.h5")
+    assert os.path.exists(data_path), data_path
+    return data_path
+
+
+def get_snemi_dataset(
+    path: Union[os.PathLike, str],
+    patch_shape: Tuple[int, int, int],
+    sample: str = "train",
+    download: bool = False,
+    offsets: Optional[List[List[int]]] = None,
+    boundaries: bool = False,
+    **kwargs,
+) -> Dataset:
+    """Get the SNEMI dataset for the segmentation of neurons in EM.
+
+    Args:
+        path: Filepath to a folder where the downloaded data will be saved.
+        patch_shape: The patch shape to use for training.
+        sample: The sample to download, either 'train' or 'test'.
+        download: Whether to download the data if it is not present.
+        offsets: Offset values for affinity computation used as target.
+        boundaries: Whether to compute boundaries as the target.
+        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.
+
+    Returns:
+        The segmentation dataset.
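+
+    Example:
+        A minimal usage sketch; the download folder and patch shape are placeholders:
+
+            ds = get_snemi_dataset("./data/snemi", patch_shape=(32, 256, 256), sample="train", download=True)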
+    """
+    assert len(patch_shape) == 3
+
+    data_path = get_snemi_paths(path, sample, download)
+
+    kwargs = util.update_kwargs(kwargs, "is_seg_dataset", True)
+    kwargs, _ = util.add_instance_label_transform(
+        kwargs, add_binary_target=False, boundaries=boundaries, offsets=offsets
+    )
+
+    return torch_em.default_segmentation_dataset(
+        raw_paths=data_path,
+        raw_key="volumes/raw",
+        label_paths=data_path,
+        label_key="volumes/labels/neuron_ids",
+        patch_shape=patch_shape,
+        **kwargs
+    )
+
+
+def get_snemi_loader(
+    path: Union[os.PathLike, str],
+    patch_shape: Tuple[int, int, int],
+    batch_size: int,
+    sample: str = "train",
+    download: bool = False,
+    offsets: Optional[List[List[int]]] = None,
+    boundaries: bool = False,
+    **kwargs,
+) -> DataLoader:
+    """Get the DataLoader for EM neuron segmentation in the SNEMI dataset.
+
+    Args:
+        path: Filepath to a folder where the downloaded data will be saved.
+        patch_shape: The patch shape to use for training.
+        batch_size: The batch size for training.
+        sample: The sample to download, either 'train' or 'test'.
+        download: Whether to download the data if it is not present.
+        offsets: Offset values for affinity computation used as target.
+        boundaries: Whether to compute boundaries as the target.
+        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader.
+
+    Returns:
+        The DataLoader.
+    """
+    ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
+    ds = get_snemi_dataset(
+        path=path,
+        patch_shape=patch_shape,
+        sample=sample,
+        download=download,
+        offsets=offsets,
+        boundaries=boundaries,
+        **ds_kwargs,
+    )
+    return torch_em.get_data_loader(ds, batch_size=batch_size, **loader_kwargs)
diff --git a/torch_em/data/datasets/electron_microscopy/sponge_em.py b/torch_em/data/datasets/electron_microscopy/sponge_em.py
new file mode 100644
index 00000000..635b916d
--- /dev/null
+++ b/torch_em/data/datasets/electron_microscopy/sponge_em.py
@@ -0,0 +1,137 @@
+"""This dataset contains volume EM data of a sponge chamber with
+segmentation annotations for cells, cilia and microvilli.
+
+It contains three annotated volumes. The dataset is part of the publication
+https://doi.org/10.1126/science.abj2949. Please cite this publication if you use the
+dataset in your research.
+"""
+
+import os
+from glob import glob
+from typing import Optional, Sequence, Tuple, Union, List
+
+from torch.utils.data import Dataset, DataLoader
+
+import torch_em
+
+from .. import util
+
+
+URL = "https://zenodo.org/record/8150818/files/sponge_em_train_data.zip?download=1"
+CHECKSUM = "f1df616cd60f81b91d7642933e9edd74dc6c486b2e546186a7c1e54c67dd32a5"
+
+
+def get_sponge_em_data(path: Union[os.PathLike, str], download: bool) -> Tuple[str, int]:
+    """Download the SpongeEM training data.
+
+    Args:
+        path: Filepath to a folder where the downloaded data will be saved.
+        download: Whether to download the data if it is not present.
+
+    Returns:
+        The path to the downloaded data.
+        The number of downloaded volumes.
+    """
+    n_files = len(glob(os.path.join(path, "*.h5")))
+    if n_files == 3:
+        return path, n_files
+    elif n_files == 0:
+        pass
+    else:
+        raise RuntimeError(
+            f"Invalid number of downloaded files in {path}. Please remove this folder and rerun this function."
+        )
+
+    os.makedirs(path, exist_ok=True)
+    zip_path = os.path.join(path, "data.zip")
+    util.download_source(zip_path, URL, download, CHECKSUM)
+    util.unzip(zip_path, path)
+
+    n_files = len(glob(os.path.join(path, "*.h5")))
+    assert n_files == 3
+    return path, n_files
+
+
+def get_sponge_em_paths(
+    path: Union[os.PathLike, str], sample_ids: Optional[Sequence[int]], download: bool = False
+) -> List[str]:
+    """Get paths to the SpongeEM data.
+
+    Args:
+        path: Filepath to a folder where the downloaded data will be saved.
+        sample_ids: The sample ids to use, valid ids are 1, 2 and 3.
+        download: Whether to download the data if it is not present.
+
+    Returns:
+        The filepaths to the stored data.
+    """
+    data_folder, n_files = get_sponge_em_data(path, download)
+
+    if sample_ids is None:
+        sample_ids = range(1, n_files + 1)
+
+    paths = [os.path.join(data_folder, f"train_data_0{i}.h5") for i in sample_ids]
+    return paths
+
+
+def get_sponge_em_dataset(
+    path: Union[os.PathLike, str],
+    mode: str,
+    patch_shape: Tuple[int, int, int],
+    sample_ids: Optional[Sequence[int]] = None,
+    download: bool = False,
+    **kwargs
+) -> Dataset:
+    """Get the SpongeEM dataset for the segmentation of structures in EM.
+
+    Args:
+        path: Filepath to a folder where the downloaded data will be saved.
+        mode: Choose the segmentation task, either 'semantic' or 'instances'.
+        patch_shape: The patch shape to use for training.
+        sample_ids: The sample ids to use, valid ids are 1, 2 and 3.
+        download: Whether to download the data if it is not present.
+        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.
+
+    Returns:
+        The segmentation dataset.
+    """
+    assert mode in ("semantic", "instances")
+
+    paths = get_sponge_em_paths(path, sample_ids, download)
+
+    return torch_em.default_segmentation_dataset(
+        raw_paths=paths,
+        raw_key="volumes/raw",
+        label_paths=paths,
+        label_key=f"volumes/labels/{mode}",
+        patch_shape=patch_shape,
+        **kwargs
+    )
+
+
+def get_sponge_em_loader(
+    path: Union[os.PathLike, str],
+    mode: str,
+    patch_shape: Tuple[int, int, int],
+    batch_size: int,
+    sample_ids: Optional[Sequence[int]] = None,
+    download: bool = False,
+    **kwargs
+) -> DataLoader:
+    """Get the SpongeEM dataloader for the segmentation of structures in EM.
+
+    Args:
+        path: Filepath to a folder where the downloaded data will be saved.
+        mode: Choose the segmentation task, either 'semantic' or 'instances'.
+        patch_shape: The patch shape to use for training.
+        batch_size: The batch size for training.
+        sample_ids: The sample ids to use, valid ids are 1, 2 and 3.
+        download: Whether to download the data if it is not present.
+        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader.
+
+    Returns:
+        The DataLoader.
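+
+    Example:
+        A minimal usage sketch; the download folder, patch shape and batch size are placeholders:
+
+            loader = get_sponge_em_loader(
+                "./data/sponge_em", mode="instances", patch_shape=(32, 256, 256), batch_size=1, download=True
+            )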
+ """ + ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs) + ds = get_sponge_em_dataset(path, mode, patch_shape, sample_ids=sample_ids, download=download, **ds_kwargs) + return torch_em.get_data_loader(ds, batch_size=batch_size, **loader_kwargs) diff --git a/torch_em/data/datasets/electron_microscopy/uro_cell.py b/torch_em/data/datasets/electron_microscopy/uro_cell.py new file mode 100644 index 00000000..bd8e12ad --- /dev/null +++ b/torch_em/data/datasets/electron_microscopy/uro_cell.py @@ -0,0 +1,226 @@ +"""The UroCell dataset contains segmentation annotations for the following organelles: +- Food Vacuoles +- Golgi Apparatus +- Lysosomes +- Mitochondria +It contains several FIB-SEM volumes with annotations. + +This dataset is from the publication https://doi.org/10.1016/j.compbiomed.2020.103693. +Please cite it if you use this dataset for a publication. +""" + +import os +import warnings +from glob import glob +from shutil import rmtree +from typing import List, Optional, Union, Tuple + +from torch.utils.data import Dataset, DataLoader + +import torch_em + +from .. import util + + +URL = "https://github.com/MancaZerovnikMekuc/UroCell/archive/refs/heads/master.zip" +CHECKSUM = "a48cf31b06114d7def642742b4fcbe76103483c069122abe10f377d71a1acabc" + + +def get_uro_cell_data(path: Union[os.PathLike, str], download: bool = False) -> str: + """Download the UroCell training data. + + Args: + path: Filepath to a folder where the downloaded data will be saved. + download: Whether to download the data if it is not present. + + Returns: + The path to the downloaded data. + """ + import h5py + + if os.path.exists(path): + return path + + try: + import nibabel as nib + except ImportError: + raise RuntimeError("Please install the nibabel package.") + + # Download and unzip the data. + os.makedirs(path) + tmp_path = os.path.join(path, "uro_cell.zip") + util.download_source(tmp_path, URL, download, checksum=CHECKSUM) + util.unzip(tmp_path, path, remove=True) + + root = os.path.join(path, "UroCell-master") + + files = glob(os.path.join(root, "data", "*.nii.gz")) + files.sort() + for data_path in files: + fname = os.path.basename(data_path) + data = nib.load(data_path).get_fdata() + + out_path = os.path.join(path, fname.replace("nii.gz", "h5")) + with h5py.File(out_path, "w") as f: + f.create_dataset("raw", data=data, compression="gzip") + + # Check if we have any of the organelle labels for this volume + # and also copy them if yes. + fv_path = os.path.join(root, "fv", "instance", fname) + if os.path.exists(fv_path): + fv = nib.load(fv_path).get_fdata().astype("uint32") + assert fv.shape == data.shape + f.create_dataset("labels/fv", data=fv, compression="gzip") + + golgi_path = os.path.join(root, "golgi", "precise", fname) + if os.path.exists(golgi_path): + golgi = nib.load(golgi_path).get_fdata().astype("uint32") + assert golgi.shape == data.shape + f.create_dataset("labels/golgi", data=golgi, compression="gzip") + + lyso_path = os.path.join(root, "lyso", "instance", fname) + if os.path.exists(lyso_path): + lyso = nib.load(lyso_path).get_fdata().astype("uint32") + assert lyso.shape == data.shape + f.create_dataset("labels/lyso", data=lyso, compression="gzip") + + mito_path = os.path.join(root, "mito", "instance", fname) + if os.path.exists(mito_path): + mito = nib.load(mito_path).get_fdata().astype("uint32") + assert mito.shape == data.shape + f.create_dataset("labels/mito", data=mito, compression="gzip") + + # Clean Up. 
+    rmtree(root)
+    return path
+
+
+def get_uro_cell_paths(
+    path: Union[os.PathLike, str], target: str, download: bool = False, return_label_key: bool = False,
+) -> List[str]:
+    """Get paths to the UroCell data.
+
+    Args:
+        path: Filepath to a folder where the downloaded data will be saved.
+        target: The segmentation target, corresponding to the organelle to segment.
+            Available organelles are 'fv', 'golgi', 'lyso' and 'mito'.
+        download: Whether to download the data if it is not present.
+        return_label_key: Whether to return the label key.
+
+    Returns:
+        List of filepaths to the stored data.
+    """
+    import h5py
+
+    get_uro_cell_data(path, download)
+
+    label_key = f"labels/{target}"
+    all_paths = glob(os.path.join(path, "*.h5"))
+    all_paths.sort()
+    paths = [path for path in all_paths if label_key in h5py.File(path, "r")]
+
+    if return_label_key:
+        return paths, label_key
+    else:
+        return paths
+
+
+def get_uro_cell_dataset(
+    path: Union[os.PathLike, str],
+    target: str,
+    patch_shape: Tuple[int, int, int],
+    download: bool = False,
+    offsets: Optional[List[List[int]]] = None,
+    boundaries: bool = False,
+    binary: bool = False,
+    **kwargs
+) -> Dataset:
+    """Get the UroCell dataset for organelle segmentation in FIB-SEM.
+
+    Args:
+        path: Filepath to a folder where the downloaded data will be saved.
+        target: The segmentation target, corresponding to the organelle to segment.
+            Available organelles are 'fv', 'golgi', 'lyso' and 'mito'.
+        patch_shape: The patch shape to use for training.
+        download: Whether to download the data if it is not present.
+        offsets: Offset values for affinity computation used as target.
+        boundaries: Whether to compute boundaries as the target.
+        binary: Whether to return a binary segmentation target.
+        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.
+
+    Returns:
+        The segmentation dataset.
+    """
+    assert target in ("fv", "golgi", "lyso", "mito")
+
+    paths, label_key = get_uro_cell_paths(path, target, download, return_label_key=True)
+
+    assert sum((offsets is not None, boundaries, binary)) <= 1, f"{offsets}, {boundaries}, {binary}"
+    if offsets is not None:
+        if target in ("lyso", "golgi"):
+            warnings.warn(
+                f"{target} does not have instance labels, affinities will be computed based on binary segmentation."
+            )
+        # we add a binary target channel for foreground background segmentation
+        label_transform = torch_em.transform.label.AffinityTransform(
+            offsets=offsets, ignore_label=None, add_binary_target=True, add_mask=True
+        )
+        msg = "Offsets are passed, but 'label_transform2' is in the kwargs. It will be over-ridden."
+        kwargs = util.update_kwargs(kwargs, 'label_transform2', label_transform, msg=msg)
+    elif boundaries:
+        if target in ("lyso", "golgi"):
+            warnings.warn(
+                f"{target} does not have instance labels, boundaries will be computed based on binary segmentation."
+            )
+        label_transform = torch_em.transform.label.BoundaryTransform(add_binary_target=True)
+        msg = "Boundaries is set to true, but 'label_transform' is in the kwargs. It will be over-ridden."
+        kwargs = util.update_kwargs(kwargs, 'label_transform', label_transform, msg=msg)
+    elif binary:
+        label_transform = torch_em.transform.label.labels_to_binary
+        msg = "Binary is set to true, but 'label_transform' is in the kwargs. It will be over-ridden."
+        kwargs = util.update_kwargs(kwargs, 'label_transform', label_transform, msg=msg)
+
+    return torch_em.default_segmentation_dataset(
+        raw_paths=paths,
+        raw_key="raw",
+        label_paths=paths,
+        label_key=label_key,
+        patch_shape=patch_shape,
+        is_seg_dataset=True,
+        **kwargs
+    )
+
+
+def get_uro_cell_loader(
+    path: Union[os.PathLike, str],
+    target: str,
+    patch_shape: Tuple[int, int, int],
+    batch_size: int,
+    download: bool = False,
+    offsets: Optional[List[List[int]]] = None,
+    boundaries: bool = False,
+    binary: bool = False,
+    **kwargs
+) -> DataLoader:
+    """Get the UroCell dataloader for organelle segmentation in FIB-SEM.
+
+    Args:
+        path: Filepath to a folder where the downloaded data will be saved.
+        target: The segmentation target, corresponding to the organelle to segment.
+            Available organelles are 'fv', 'golgi', 'lyso' and 'mito'.
+        patch_shape: The patch shape to use for training.
+        batch_size: The batch size for training.
+        download: Whether to download the data if it is not present.
+        offsets: Offset values for affinity computation used as target.
+        boundaries: Whether to compute boundaries as the target.
+        binary: Whether to return a binary segmentation target.
+        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader.
+
+    Returns:
+        The DataLoader.
+    """
+    ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
+    ds = get_uro_cell_dataset(
+        path, target, patch_shape, download=download, offsets=offsets, boundaries=boundaries, binary=binary, **ds_kwargs
+    )
+    return torch_em.get_data_loader(ds, batch_size=batch_size, **loader_kwargs)
diff --git a/torch_em/data/datasets/electron_microscopy/vnc.py b/torch_em/data/datasets/electron_microscopy/vnc.py
new file mode 100644
index 00000000..b6fd8568
--- /dev/null
+++ b/torch_em/data/datasets/electron_microscopy/vnc.py
@@ -0,0 +1,163 @@
+"""The VNC dataset contains segmentation annotations for mitochondria in EM.
+It contains two TEM volumes of the Drosophila ventral nerve cord.
+
+Please cite https://doi.org/10.6084/m9.figshare.856713.v1 if you use this dataset in your publication.
+"""
+
+import os
+from glob import glob
+from shutil import rmtree
+from typing import List, Optional, Union, Tuple
+
+import imageio
+import numpy as np
+from skimage.measure import label
+
+from torch.utils.data import Dataset, DataLoader
+
+import torch_em
+
+from .. import util
+
+
+URL = "https://github.com/unidesigner/groundtruth-drosophila-vnc/archive/refs/heads/master.zip"
+CHECKSUM = "f7bd0db03c86b64440a16b60360ad60c0a4411f89e2c021c7ee2c8d6af3d7e86"
+
+
+def _create_volume(f, key, pattern, process=None):
+    images = glob(pattern)
+    images.sort()
+    data = np.concatenate([imageio.imread(im)[None] for im in images], axis=0)
+    if process is not None:
+        data = process(data)
+    f.create_dataset(key, data=data, compression="gzip")
+
+
+def get_vnc_data(path: Union[os.PathLike, str], download: bool) -> str:
+    """Download the VNC training data.
+
+    Args:
+        path: Filepath to a folder where the downloaded data will be saved.
+        download: Whether to download the data if it is not present.
+
+    Returns:
+        The path to the downloaded data.
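+
+    Example:
+        A minimal usage sketch; the target folder is a placeholder:
+
+            data_root = get_vnc_data("./data/vnc", download=True)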
+    """
+    import h5py
+
+    train_path = os.path.join(path, "vnc_train.h5")
+    test_path = os.path.join(path, "vnc_test.h5")
+    if os.path.exists(train_path) and os.path.exists(test_path):
+        return path
+
+    os.makedirs(path, exist_ok=True)
+    zip_path = os.path.join(path, "vnc.zip")
+    util.download_source(zip_path, URL, download, CHECKSUM)
+    util.unzip(zip_path, path, remove=True)
+
+    root = os.path.join(path, "groundtruth-drosophila-vnc-master")
+    assert os.path.exists(root)
+
+    with h5py.File(train_path, "w") as f:
+        _create_volume(f, "raw", os.path.join(root, "stack1", "raw", "*.tif"))
+        _create_volume(f, "labels/mitochondria", os.path.join(root, "stack1", "mitochondria", "*.png"), process=label)
+        _create_volume(f, "labels/synapses", os.path.join(root, "stack1", "synapses", "*.png"), process=label)
+        # TODO find the post-processing to go from neuron labels to membrane labels
+        # _create_volume(f, "labels/neurons", os.path.join(root, "stack1", "membranes", "*.png"))
+
+    with h5py.File(test_path, "w") as f:
+        _create_volume(f, "raw", os.path.join(root, "stack2", "raw", "*.tif"))
+
+    rmtree(root)
+    return path
+
+
+def get_vnc_mito_paths(path: Union[os.PathLike, str], download: bool = False) -> str:
+    """Get path to the VNC data.
+
+    Args:
+        path: Filepath to a folder where the downloaded data is saved.
+        download: Whether to download the data if it is not present.
+
+    Returns:
+        The filepath to the stored data.
+    """
+    get_vnc_data(path, download)
+    data_path = os.path.join(path, "vnc_train.h5")
+    return data_path
+
+
+def get_vnc_mito_dataset(
+    path: Union[os.PathLike, str],
+    patch_shape: Tuple[int, int, int],
+    offsets: Optional[List[List[int]]] = None,
+    boundaries: bool = False,
+    binary: bool = False,
+    download: bool = False,
+    **kwargs
+) -> Dataset:
+    """Get the VNC dataset for segmenting mitochondria in EM.
+
+    Args:
+        path: Filepath to a folder where the downloaded data will be saved.
+        patch_shape: The patch shape to use for training.
+        offsets: Offset values for affinity computation used as target.
+        boundaries: Whether to compute boundaries as the target.
+        binary: Whether to return a binary segmentation target.
+        download: Whether to download the data if it is not present.
+        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.
+
+    Returns:
+        The segmentation dataset.
+    """
+    data_path = get_vnc_mito_paths(path, download)
+
+    kwargs, _ = util.add_instance_label_transform(
+        kwargs, add_binary_target=True, boundaries=boundaries, offsets=offsets, binary=binary,
+    )
+
+    return torch_em.default_segmentation_dataset(
+        raw_paths=data_path,
+        raw_key="raw",
+        label_paths=data_path,
+        label_key="labels/mitochondria",
+        patch_shape=patch_shape,
+        **kwargs
+    )
+
+
+def get_vnc_mito_loader(
+    path: Union[os.PathLike, str],
+    patch_shape: Tuple[int, int, int],
+    batch_size: int,
+    offsets: Optional[List[List[int]]] = None,
+    boundaries: bool = False,
+    binary: bool = False,
+    download: bool = False,
+    **kwargs
+) -> DataLoader:
+    """Get the VNC dataloader for segmenting mitochondria in EM.
+
+    Args:
+        path: Filepath to a folder where the downloaded data will be saved.
+        patch_shape: The patch shape to use for training.
+        batch_size: The batch size for training.
+        offsets: Offset values for affinity computation used as target.
+        boundaries: Whether to compute boundaries as the target.
+        binary: Whether to return a binary segmentation target.
+        download: Whether to download the data if it is not present.
+ kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader. + + Returns: + The DataLoader. + """ + ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs) + ds = get_vnc_mito_dataset( + path, patch_shape, download=download, offsets=offsets, boundaries=boundaries, binary=binary, **ds_kwargs + ) + return torch_em.get_data_loader(ds, batch_size=batch_size, **loader_kwargs) + + +# TODO implement +def get_vnc_neuron_loader(path, patch_shape, download=False, **kwargs): + raise NotImplementedError diff --git a/torch_em/data/datasets/histopathology/__init__.py b/torch_em/data/datasets/histopathology/__init__.py new file mode 100644 index 00000000..5da41827 --- /dev/null +++ b/torch_em/data/datasets/histopathology/__init__.py @@ -0,0 +1,6 @@ +from .bcss import get_bcss_loader, get_bcss_dataset +from .cryonuseg import get_cryonuseg_loader, get_cryonuseg_dataset +from .lizard import get_lizard_loader, get_lizard_dataset +from .monuseg import get_monuseg_loader, get_monuseg_dataset +from .monusac import get_monusac_loader, get_monusac_dataset +from .pannuke import get_pannuke_loader, get_pannuke_dataset diff --git a/torch_em/data/datasets/histopathology/bcss.py b/torch_em/data/datasets/histopathology/bcss.py new file mode 100644 index 00000000..367a3b50 --- /dev/null +++ b/torch_em/data/datasets/histopathology/bcss.py @@ -0,0 +1,242 @@ +"""This dataset contains annotations for tissue region segmentation in +breast cancer histopathology images. + +NOTE: There are multiple semantic instances in tissue labels. Below mentioned are their respective index details: + - 0: outside_roi (~background) + - 1: tumor + - 2: stroma + - 3: lymphocytic_infiltrate + - 4: necrosis_or_debris + - 5: glandular_secretions + - 6: blood + - 7: exclude + - 8: metaplasia_NOS + - 9: fat + - 10: plasma_cells + - 11: other_immune_infiltrate + - 12: mucoid_material + - 13: normal_acinus_or_duct + - 14: lymphatics + - 15: undetermined + - 16: nerve + - 17: skin_adnexa + - 18: blood_vessel + - 19: angioinvasion + - 20: dcis + - 21: other + +This dataset is from https://bcsegmentation.grand-challenge.org/BCSS/. +Please cite this paper (https://doi.org/10.1093/bioinformatics/btz083) if you use this dataset for a publication. +""" + +import os +import shutil +from glob import glob +from pathlib import Path +from typing import Union, Optional, List, Tuple + +from sklearn.model_selection import train_test_split + +import torch +from torch.utils.data import Dataset, DataLoader + +import torch_em + +from .. 
import util + + +URL = "https://drive.google.com/drive/folders/1zqbdkQF8i5cEmZOGmbdQm-EP8dRYtvss?usp=sharing" + + +# TODO +CHECKSUM = None + + +TEST_LIST = [ + "TCGA-A2-A0SX-DX1_xmin53791_ymin56683_MPP-0.2500", "TCGA-BH-A0BG-DX1_xmin64019_ymin24975_MPP-0.2500", + "TCGA-AR-A1AI-DX1_xmin38671_ymin10616_MPP-0.2500", "TCGA-E2-A574-DX1_xmin54962_ymin47475_MPP-0.2500", + "TCGA-GM-A3XL-DX1_xmin29910_ymin15820_MPP-0.2500", "TCGA-E2-A14X-DX1_xmin88836_ymin66393_MPP-0.2500", + "TCGA-A2-A04P-DX1_xmin104246_ymin48517_MPP-0.2500", "TCGA-E2-A14N-DX1_xmin21383_ymin66838_MPP-0.2500", + "TCGA-EW-A1OV-DX1_xmin126026_ymin65132_MPP-0.2500", "TCGA-S3-AA15-DX1_xmin55486_ymin28926_MPP-0.2500", + "TCGA-LL-A5YO-DX1_xmin36631_ymin44396_MPP-0.2500", "TCGA-GI-A2C9-DX1_xmin20882_ymin11843_MPP-0.2500", + "TCGA-BH-A0BW-DX1_xmin42346_ymin30843_MPP-0.2500", "TCGA-E2-A1B6-DX1_xmin16266_ymin50634_MPP-0.2500", + "TCGA-AO-A0J2-DX1_xmin33561_ymin14515_MPP-0.2500" +] + + +def _download_bcss_dataset(path, download): + """Current recommendation: + - download the folder from URL manually + - use the consortium's git repo to download the dataset (https://github.com/PathologyDataScience/BCSS) + """ + raise NotImplementedError("Please download the dataset using the drive link / git repo directly") + + # FIXME: limitation for the installation below: + # - only downloads first 50 files - due to `gdown`'s download folder function + # - (optional) clone their git repo to download their data + util.download_source_gdrive(path=path, url=URL, download=download, checksum=CHECKSUM, download_type="folder") + + +def _get_image_and_label_paths(path): + # when downloading the files from `URL`, the input images are stored under `rgbs_colorNormalized` + # when getting the files from the git repo's command line feature, the input images are stored under `images` + if os.path.exists(os.path.join(path, "images")): + image_paths = sorted(glob(os.path.join(path, "images", "*"))) + label_paths = sorted(glob(os.path.join(path, "masks", "*"))) + elif os.path.exists(os.path.join(path, "0_Public-data-Amgad2019_0.25MPP", "rgbs_colorNormalized")): + image_paths = sorted(glob(os.path.join(path, "0_Public-data-Amgad2019_0.25MPP", "rgbs_colorNormalized", "*"))) + label_paths = sorted(glob(os.path.join(path, "0_Public-data-Amgad2019_0.25MPP", "masks", "*"))) + else: + raise ValueError( + "Please check the image directory. " + "If downloaded from gdrive, it's named \"rgbs_colorNormalized\", if from github it's named \"images\"" + ) + + return image_paths, label_paths + + +def get_bcss_data(path: Union[os.PathLike, str], download: bool = False): + """Download the BCSS dataset. + + Args: + path: Filepath to a folder where the downloaded data will be saved. + download: Whether to download the data if it is not present. 
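+
+    Example:
+        A minimal usage sketch; automatic download is not implemented for this dataset, so
+        the folder is assumed to already contain the manually downloaded data:
+
+            get_bcss_data("./data/bcss", download=False)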
+    """
+    if download:
+        _download_bcss_dataset(path, download)
+
+    if os.path.exists(os.path.join(path, "train")) and os.path.exists(os.path.join(path, "test")):
+        return
+
+    all_image_paths, all_label_paths = _get_image_and_label_paths(path)
+
+    train_img_dir, train_lab_dir = os.path.join(path, "train", "images"), os.path.join(path, "train", "masks")
+    test_img_dir, test_lab_dir = os.path.join(path, "test", "images"), os.path.join(path, "test", "masks")
+    os.makedirs(train_img_dir, exist_ok=True)
+    os.makedirs(train_lab_dir, exist_ok=True)
+    os.makedirs(test_img_dir, exist_ok=True)
+    os.makedirs(test_lab_dir, exist_ok=True)
+
+    for image_path, label_path in zip(all_image_paths, all_label_paths):
+        img_idx, label_idx = os.path.split(image_path)[-1], os.path.split(label_path)[-1]
+        if Path(image_path).stem in TEST_LIST:
+            # copy image and label to the test folder
+            dst_img_path, dst_lab_path = os.path.join(test_img_dir, img_idx), os.path.join(test_lab_dir, label_idx)
+            shutil.copy(src=image_path, dst=dst_img_path)
+            shutil.copy(src=label_path, dst=dst_lab_path)
+        else:
+            # copy image and label to the train folder
+            dst_img_path, dst_lab_path = os.path.join(train_img_dir, img_idx), os.path.join(train_lab_dir, label_idx)
+            shutil.copy(src=image_path, dst=dst_img_path)
+            shutil.copy(src=label_path, dst=dst_lab_path)
+
+
+def get_bcss_paths(
+    path: Union[os.PathLike, str], split: Optional[str] = None, val_fraction: float = 0.2, download: bool = False
+) -> Tuple[List[str], List[str]]:
+    """Get paths to the BCSS data.
+
+    Args:
+        path: Filepath to a folder where the downloaded data will be saved.
+        split: The split to use for the dataset. Either 'train', 'val' or 'test'.
+        val_fraction: The fraction of the data to use for the validation split.
+        download: Whether to download the data if it is not present.
+
+    Returns:
+        List of filepaths for the image data.
+        List of filepaths for the label data.
+    """
+    get_bcss_data(path, download)
+
+    if split is None:
+        image_paths = sorted(glob(os.path.join(path, "*", "images", "*")))
+        label_paths = sorted(glob(os.path.join(path, "*", "masks", "*")))
+    else:
+        assert split in ["train", "val", "test"], "Please choose from the available `train` / `val` / `test` splits"
+        if split == "test":
+            image_paths = sorted(glob(os.path.join(path, "test", "images", "*")))
+            label_paths = sorted(glob(os.path.join(path, "test", "masks", "*")))
+        else:
+            image_paths = sorted(glob(os.path.join(path, "train", "images", "*")))
+            label_paths = sorted(glob(os.path.join(path, "train", "masks", "*")))
+
+            # Derive deterministic train / val subsets from the train images.
+            (train_image_paths, val_image_paths,
+             train_label_paths, val_label_paths) = train_test_split(
+                image_paths, label_paths, test_size=val_fraction, random_state=42
+            )
+
+            image_paths = train_image_paths if split == "train" else val_image_paths
+            label_paths = train_label_paths if split == "train" else val_label_paths
+
+    assert len(image_paths) == len(label_paths)
+
+    return image_paths, label_paths
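Worth noting how the split logic above behaves in practice. A minimal usage sketch (the local path is hypothetical, and the data must have been fetched manually as described in `_download_bcss_dataset`): since both calls use the fixed `random_state=42`, the derived train and val subsets are deterministic and disjoint:

```python
from torch_em.data.datasets.histopathology.bcss import get_bcss_paths

path = "./data/bcss"  # hypothetical location of the manually downloaded data

train_images, _ = get_bcss_paths(path, split="train", val_fraction=0.2)
val_images, _ = get_bcss_paths(path, split="val", val_fraction=0.2)

# The same seeded split is computed in both calls, so the subsets cannot overlap.
assert not set(train_images) & set(val_images)
```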
+
+
+def get_bcss_dataset(
+    path: Union[os.PathLike, str],
+    patch_shape: Tuple[int, ...],
+    split: Optional[str] = None,
+    val_fraction: float = 0.2,
+    download: bool = False,
+    label_dtype: torch.dtype = torch.int64,
+    **kwargs
+) -> Dataset:
+    """Get the BCSS dataset for breast cancer tissue segmentation in histopathology.
+
+    Args:
+        path: Filepath to a folder where the downloaded data will be saved.
+        patch_shape: The patch shape to use for training.
+        split: The split to use for the dataset. Either 'train', 'val' or 'test'.
+        val_fraction: The fraction of the data to use for the validation split.
+        download: Whether to download the data if it is not present.
+        label_dtype: The datatype of the labels.
+        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.
+
+    Returns:
+        The segmentation dataset.
+    """
+    image_paths, label_paths = get_bcss_paths(path, split, val_fraction, download)
+
+    return torch_em.default_segmentation_dataset(
+        raw_paths=image_paths,
+        raw_key=None,
+        label_paths=label_paths,
+        label_key=None,
+        patch_shape=patch_shape,
+        label_dtype=label_dtype,
+        is_seg_dataset=False,
+        **kwargs
+    )
+
+
+def get_bcss_loader(
+    path: Union[os.PathLike, str],
+    patch_shape: Tuple[int, ...],
+    batch_size: int,
+    split: Optional[str] = None,
+    val_fraction: float = 0.2,
+    download: bool = False,
+    label_dtype: torch.dtype = torch.int64,
+    **kwargs
+) -> DataLoader:
+    """Get the BCSS dataloader for breast cancer tissue segmentation in histopathology.
+
+    Args:
+        path: Filepath to a folder where the downloaded data will be saved.
+        patch_shape: The patch shape to use for training.
+        batch_size: The batch size for training.
+        split: The split to use for the dataset. Either 'train', 'val' or 'test'.
+        val_fraction: The fraction of the data to use for the validation split.
+        download: Whether to download the data if it is not present.
+        label_dtype: The datatype of the labels.
+        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader.
+
+    Returns:
+        The DataLoader.
+    """
+    ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
+    dataset = get_bcss_dataset(
+        path, patch_shape, split, val_fraction, download=download, label_dtype=label_dtype, **ds_kwargs
+    )
+    return torch_em.get_data_loader(dataset, batch_size, **loader_kwargs)
diff --git a/torch_em/data/datasets/histopathology/cryonuseg.py b/torch_em/data/datasets/histopathology/cryonuseg.py
new file mode 100644
index 00000000..7f12dd19
--- /dev/null
+++ b/torch_em/data/datasets/histopathology/cryonuseg.py
@@ -0,0 +1,128 @@
+"""The CryoNuSeg dataset contains annotations for nucleus segmentation
+in cryosectioned H&E stained histological images of 10 different organs.
+
+This dataset is from the publication https://doi.org/10.1016/j.compbiomed.2021.104349.
+Please cite it if you use this dataset for your research.
+"""
+
+import os
+from glob import glob
+from natsort import natsorted
+from typing import Union, Tuple, Literal, List
+
+from torch.utils.data import Dataset, DataLoader
+
+import torch_em
+
+from .. import util
+
+
+def get_cryonuseg_data(path: Union[os.PathLike, str], download: bool = False):
+    """Download the CryoNuSeg dataset for nucleus segmentation.
+
+    Args:
+        path: Filepath to a folder where the downloaded data will be saved.
+        download: Whether to download the data if it is not present.
+    """
+    if os.path.exists(os.path.join(path, r"tissue images")):
+        return
+
+    os.makedirs(path, exist_ok=True)
+    util.download_source_kaggle(
+        path=path, dataset_name="ipateam/segmentation-of-nuclei-in-cryosectioned-he-images", download=download
+    )
+
+    zip_path = os.path.join(path, "segmentation-of-nuclei-in-cryosectioned-he-images.zip")
+    util.unzip(zip_path=zip_path, dst=path)
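`get_cryonuseg_data` fetches the archive via `util.download_source_kaggle`, which goes through the Kaggle API and therefore expects Kaggle credentials (typically `~/.kaggle/kaggle.json`) to be configured beforehand. A minimal sketch, with a hypothetical target folder:

```python
from torch_em.data.datasets.histopathology.cryonuseg import get_cryonuseg_data

# Downloads and unpacks the archive on the first call; subsequent calls return early.
get_cryonuseg_data("./data/cryonuseg", download=True)
```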
+
+
+def get_cryonuseg_paths(
+    path: Union[os.PathLike, str], rater_choice: Literal["b1", "b2", "b3"] = "b1", download: bool = False
+) -> Tuple[List[str], List[str]]:
+    """Get paths to the CryoNuSeg data.
+
+    Args:
+        path: Filepath to a folder where the downloaded data will be saved.
+        rater_choice: The choice of annotator.
+        download: Whether to download the data if it is not present.
+
+    Returns:
+        List of filepaths to the image data.
+        List of filepaths to the label data.
+    """
+    get_cryonuseg_data(path, download)
+
+    if rater_choice == "b1":
+        label_dir = r"Annotator 1 (biologist)/"
+    elif rater_choice == "b2":
+        # The extracted archive nests these folders inside an identically named folder,
+        # hence the path segment is repeated via `* 2`.
+        label_dir = r"Annotator 1 (biologist second round of manual marks up)/" * 2
+    elif rater_choice == "b3":
+        label_dir = r"Annotator 2 (bioinformatician)/" * 2
+    else:
+        raise ValueError(f"'{rater_choice}' is not a valid rater choice.")
+
+    # Point to the instance labels folder
+    label_dir += r"label masks"
+
+    label_paths = natsorted(glob(os.path.join(path, label_dir, "*.tif")))
+    raw_paths = natsorted(glob(os.path.join(path, r"tissue images", "*.tif")))
+
+    return raw_paths, label_paths
+
+
+def get_cryonuseg_dataset(
+    path: Union[os.PathLike, str],
+    patch_shape: Tuple[int, int],
+    rater: Literal["b1", "b2", "b3"] = "b1",
+    download: bool = False,
+    **kwargs
+) -> Dataset:
+    """Get the CryoNuSeg dataset for nucleus segmentation.
+
+    Args:
+        path: Filepath to a folder where the downloaded data will be saved.
+        patch_shape: The patch shape to use for training.
+        rater: The choice of annotator.
+        download: Whether to download the data if it is not present.
+        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.
+
+    Returns:
+        The segmentation dataset.
+    """
+    raw_paths, label_paths = get_cryonuseg_paths(path, rater, download)
+
+    return torch_em.default_segmentation_dataset(
+        raw_paths=raw_paths,
+        raw_key=None,
+        label_paths=label_paths,
+        label_key=None,
+        is_seg_dataset=False,
+        patch_shape=patch_shape,
+        **kwargs
+    )
+
+
+def get_cryonuseg_loader(
+    path: Union[os.PathLike, str],
+    batch_size: int,
+    patch_shape: Tuple[int, int],
+    rater: Literal["b1", "b2", "b3"] = "b1",
+    download: bool = False,
+    **kwargs
+) -> DataLoader:
+    """Get the CryoNuSeg dataloader for nucleus segmentation.
+
+    Args:
+        path: Filepath to a folder where the downloaded data will be saved.
+        batch_size: The batch size for training.
+        patch_shape: The patch shape to use for training.
+        rater: The choice of annotator.
+        download: Whether to download the data if it is not present.
+        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader.
+
+    Returns:
+        The DataLoader.
+    """
+    ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
+    dataset = get_cryonuseg_dataset(path, patch_shape, rater, download, **ds_kwargs)
+    return torch_em.get_data_loader(dataset=dataset, batch_size=batch_size, **loader_kwargs)
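A short usage sketch for the loader defined above (path hypothetical). Note that, unlike most of the other loaders in this PR, `get_cryonuseg_loader` takes `batch_size` before `patch_shape`:

```python
from torch_em.data.datasets.histopathology import get_cryonuseg_loader

# Compare annotations from the biologist (b1) and the bioinformatician (b3).
loader_b1 = get_cryonuseg_loader(
    "./data/cryonuseg", batch_size=2, patch_shape=(512, 512), rater="b1", download=True
)
loader_b3 = get_cryonuseg_loader(
    "./data/cryonuseg", batch_size=2, patch_shape=(512, 512), rater="b3"
)
```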
diff --git a/torch_em/data/datasets/histopathology/lizard.py b/torch_em/data/datasets/histopathology/lizard.py
new file mode 100644
index 00000000..dbfd9e50
--- /dev/null
+++ b/torch_em/data/datasets/histopathology/lizard.py
@@ -0,0 +1,151 @@
+"""The Lizard dataset contains annotations for nucleus segmentation
+in H&E stained histopathology images of colon tissue.
+
+This dataset is from the publication https://doi.org/10.48550/arXiv.2108.11195.
+Please cite it if you use this dataset for your research.
+"""
+
+import os
+from glob import glob
+from tqdm import tqdm
+from shutil import rmtree
+from typing import Tuple, Union, List
+
+import imageio.v3 as imageio
+
+from scipy.io import loadmat
+
+from torch.utils.data import Dataset, DataLoader
+
+import torch_em
+
+from .. import util
+
+
+def _extract_images(image_folder, label_folder, output_dir):
+    import h5py
+
+    image_files = glob(os.path.join(image_folder, "*.png"))
+    for image_file in tqdm(image_files, desc=f"Extract images from {image_folder}"):
+        fname = os.path.basename(image_file)
+        label_file = os.path.join(label_folder, fname.replace(".png", ".mat"))
+        assert os.path.exists(label_file), label_file
+
+        image = imageio.imread(image_file)
+        assert image.ndim == 3 and image.shape[-1] == 3
+
+        labels = loadmat(label_file)
+        segmentation = labels["inst_map"]
+        assert image.shape[:-1] == segmentation.shape
+        classes = labels["class"]
+
+        image = image.transpose((2, 0, 1))
+        assert image.shape[1:] == segmentation.shape
+
+        output_file = os.path.join(output_dir, fname.replace(".png", ".h5"))
+        with h5py.File(output_file, "a") as f:
+            f.create_dataset("image", data=image, compression="gzip")
+            f.create_dataset("labels/segmentation", data=segmentation, compression="gzip")
+            f.create_dataset("labels/classes", data=classes, compression="gzip")
+
+
+def get_lizard_data(path, download):
+    """Download the Lizard dataset for nucleus segmentation.
+
+    Args:
+        path: Filepath to a folder where the downloaded data will be saved.
+        download: Whether to download the data if it is not present.
+    """
+    # Check if the data was already downloaded and converted to h5;
+    # in that case we skip the download and conversion.
+    image_files = glob(os.path.join(path, "*.h5"))
+    if len(image_files) > 0:
+        return
+
+    os.makedirs(path, exist_ok=True)
+    util.download_source_kaggle(path=path, dataset_name="aadimator/lizard-dataset", download=download)
+    zip_path = os.path.join(path, "lizard-dataset.zip")
+    util.unzip(zip_path=zip_path, dst=path)
+
+    image_folder1 = os.path.join(path, "lizard_images1", "Lizard_Images1")
+    image_folder2 = os.path.join(path, "lizard_images2", "Lizard_Images2")
+    label_folder = os.path.join(path, "lizard_labels", "Lizard_Labels")
+
+    assert os.path.exists(image_folder1), image_folder1
+    assert os.path.exists(image_folder2), image_folder2
+    assert os.path.exists(label_folder), label_folder
+
+    _extract_images(image_folder1, os.path.join(label_folder, "Labels"), path)
+    _extract_images(image_folder2, os.path.join(label_folder, "Labels"), path)
+
+    rmtree(os.path.join(path, "lizard_images1"))
+    rmtree(os.path.join(path, "lizard_images2"))
+    rmtree(os.path.join(path, "lizard_labels"))
+    rmtree(os.path.join(path, "overlay"))
+
+
+def get_lizard_paths(path: Union[os.PathLike, str], download: bool = False) -> List[str]:
+    """Get paths to the Lizard data.
+
+    Args:
+        path: Filepath to a folder where the downloaded data will be saved.
+        download: Whether to download the data if it is not present.
+
+    Returns:
+        List of filepaths for the stored data.
+    """
+    get_lizard_data(path, download)
+
+    data_paths = glob(os.path.join(path, "*.h5"))
+    data_paths.sort()
+    return data_paths
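The conversion above writes one HDF5 file per image. A quick inspection sketch (path hypothetical, dataset keys as created by `_extract_images`):

```python
import h5py
from torch_em.data.datasets.histopathology.lizard import get_lizard_paths

data_paths = get_lizard_paths("./data/lizard", download=True)
with h5py.File(data_paths[0], "r") as f:
    print(f["image"].shape)                # (3, H, W), channel-first RGB
    print(f["labels/segmentation"].shape)  # (H, W), instance labels
    print(f["labels/classes"].shape)       # per-instance class ids from the .mat 'class' field
```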
+
+
+def get_lizard_dataset(
+    path: Union[os.PathLike, str], patch_shape: Tuple[int, int], download: bool = False, **kwargs
+) -> Dataset:
+    """Get the Lizard dataset for nucleus segmentation.
+
+    Args:
+        path: Filepath to a folder where the downloaded data will be saved.
+        patch_shape: The patch shape to use for training.
+        download: Whether to download the data if it is not present.
+        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.
+
+    Returns:
+        The segmentation dataset.
+    """
+    data_paths = get_lizard_paths(path, download)
+
+    return torch_em.default_segmentation_dataset(
+        raw_paths=data_paths,
+        raw_key="image",
+        label_paths=data_paths,
+        label_key="labels/segmentation",
+        patch_shape=patch_shape,
+        ndim=2,
+        with_channels=True,
+        **kwargs
+    )
+
+
+# TODO implement loading the classification labels
+# TODO implement selecting different tissue types
+# TODO implement train / val / test split (is pre-defined in a csv)
+def get_lizard_loader(
+    path: Union[os.PathLike, str], patch_shape: Tuple[int, int], batch_size: int, download: bool = False, **kwargs
+) -> DataLoader:
+    """Get the Lizard dataloader for nucleus segmentation.
+
+    Args:
+        path: Filepath to a folder where the downloaded data will be saved.
+        patch_shape: The patch shape to use for training.
+        batch_size: The batch size for training.
+        download: Whether to download the data if it is not present.
+        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader.
+
+    Returns:
+        The DataLoader.
+    """
+    ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
+    ds = get_lizard_dataset(path, patch_shape, download=download, **ds_kwargs)
+    return torch_em.get_data_loader(ds, batch_size=batch_size, **loader_kwargs)
diff --git a/torch_em/data/datasets/monusac.py b/torch_em/data/datasets/histopathology/monusac.py
similarity index 51%
rename from torch_em/data/datasets/monusac.py
rename to torch_em/data/datasets/histopathology/monusac.py
index 6a1f0175..3f3500f9 100644
--- a/torch_em/data/datasets/monusac.py
+++ b/torch_em/data/datasets/histopathology/monusac.py
@@ -1,14 +1,26 @@
+"""This dataset consists of annotations for nucleus segmentation in
+H&E stained tissue images derived from four different organs.
+
+This dataset comes from https://monusac-2020.grand-challenge.org/Data/.
+
+This dataset is from the publication https://doi.org/10.1109/TMI.2021.3085712.
+Please cite it if you use this dataset in your research.
+"""
+
 import os
 import shutil
 from glob import glob
 from tqdm import tqdm
 from pathlib import Path
-from typing import Optional, List
+from typing import Optional, List, Union, Literal, Tuple
 
 import imageio.v3 as imageio
 
+from torch.utils.data import Dataset, DataLoader
+
 import torch_em
-from . import util
+
+from .. 
import util URL = { @@ -25,45 +37,39 @@ # here's the description: https://drive.google.com/file/d/1kdOl3s6uQBRv0nToSIf1dPuceZunzL4N/view ORGAN_SPLITS = { "train": { - "lung": ["TCGA-55-1594", "TCGA-69-7760", "TCGA-69-A59K", "TCGA-73-4668", "TCGA-78-7220", - "TCGA-86-7713", "TCGA-86-8672", "TCGA-L4-A4E5", "TCGA-MP-A4SY", "TCGA-MP-A4T7"], - "kidney": ["TCGA-5P-A9K0", "TCGA-B9-A44B", "TCGA-B9-A8YI", "TCGA-DW-7841", "TCGA-EV-5903", "TCGA-F9-A97G", - "TCGA-G7-A8LD", "TCGA-MH-A560", "TCGA-P4-AAVK", "TCGA-SX-A7SR", "TCGA-UZ-A9PO", "TCGA-UZ-A9PU"], - "breast": ["TCGA-A2-A0CV", "TCGA-A2-A0ES", "TCGA-B6-A0WZ", "TCGA-BH-A18T", "TCGA-D8-A1X5", - "TCGA-E2-A154", "TCGA-E9-A22B", "TCGA-E9-A22G", "TCGA-EW-A6SD", "TCGA-S3-AA11"], - "prostate": ["TCGA-EJ-5495", "TCGA-EJ-5505", "TCGA-EJ-5517", "TCGA-G9-6342", "TCGA-G9-6499", - "TCGA-J4-A67Q", "TCGA-J4-A67T", "TCGA-KK-A59X", "TCGA-KK-A6E0", "TCGA-KK-A7AW", - "TCGA-V1-A8WL", "TCGA-V1-A9O9", "TCGA-X4-A8KQ", "TCGA-YL-A9WY"] + "lung": [ + "TCGA-55-1594", "TCGA-69-7760", "TCGA-69-A59K", "TCGA-73-4668", "TCGA-78-7220", + "TCGA-86-7713", "TCGA-86-8672", "TCGA-L4-A4E5", "TCGA-MP-A4SY", "TCGA-MP-A4T7" + ], + "kidney": [ + "TCGA-5P-A9K0", "TCGA-B9-A44B", "TCGA-B9-A8YI", "TCGA-DW-7841", "TCGA-EV-5903", "TCGA-F9-A97G", + "TCGA-G7-A8LD", "TCGA-MH-A560", "TCGA-P4-AAVK", "TCGA-SX-A7SR", "TCGA-UZ-A9PO", "TCGA-UZ-A9PU" + ], + "breast": [ + "TCGA-A2-A0CV", "TCGA-A2-A0ES", "TCGA-B6-A0WZ", "TCGA-BH-A18T", "TCGA-D8-A1X5", + "TCGA-E2-A154", "TCGA-E9-A22B", "TCGA-E9-A22G", "TCGA-EW-A6SD", "TCGA-S3-AA11" + ], + "prostate": [ + "TCGA-EJ-5495", "TCGA-EJ-5505", "TCGA-EJ-5517", "TCGA-G9-6342", "TCGA-G9-6499", + "TCGA-J4-A67Q", "TCGA-J4-A67T", "TCGA-KK-A59X", "TCGA-KK-A6E0", "TCGA-KK-A7AW", + "TCGA-V1-A8WL", "TCGA-V1-A9O9", "TCGA-X4-A8KQ", "TCGA-YL-A9WY" + ] }, "test": { - "lung": ["TCGA-49-6743", "TCGA-50-6591", "TCGA-55-7570", "TCGA-55-7573", - "TCGA-73-4662", "TCGA-78-7152", "TCGA-MP-A4T7"], - "kidney": ["TCGA-2Z-A9JG", "TCGA-2Z-A9JN", "TCGA-DW-7838", "TCGA-DW-7963", - "TCGA-F9-A8NY", "TCGA-IZ-A6M9", "TCGA-MH-A55W"], + "lung": [ + "TCGA-49-6743", "TCGA-50-6591", "TCGA-55-7570", "TCGA-55-7573", + "TCGA-73-4662", "TCGA-78-7152", "TCGA-MP-A4T7" + ], + "kidney": [ + "TCGA-2Z-A9JG", "TCGA-2Z-A9JN", "TCGA-DW-7838", "TCGA-DW-7963", + "TCGA-F9-A8NY", "TCGA-IZ-A6M9", "TCGA-MH-A55W" + ], "breast": ["TCGA-A2-A04X", "TCGA-A2-A0ES", "TCGA-D8-A3Z6", "TCGA-E2-A108", "TCGA-EW-A6SB"], "prostate": ["TCGA-G9-6356", "TCGA-G9-6367", "TCGA-VP-A87E", "TCGA-VP-A87H", "TCGA-X4-A8KS", "TCGA-YL-A9WL"] }, } -def _download_monusac(path, download, split): - assert split in ["train", "test"], "Please choose from train/test" - - # check if we have extracted the images and labels already - im_path = os.path.join(path, "images", split) - label_path = os.path.join(path, "labels", split) - if os.path.exists(im_path) and os.path.exists(label_path): - return - - os.makedirs(path, exist_ok=True) - zip_path = os.path.join(path, f"monusac_{split}.zip") - util.download_source_gdrive(zip_path, URL[split], download=download, checksum=CHECKSUM[split]) - - _process_monusac(path, split) - - _check_channel_consistency(path, split) - - def _check_channel_consistency(path, split): "The provided tif images have RGBA channels, check and remove the alpha channel" all_image_path = glob(os.path.join(path, "images", split, "*.tif")) @@ -137,13 +143,50 @@ def get_patient_id(path, split_wrt="-01Z-00-"): return patient_id -def get_monusac_dataset( - path, patch_shape, split, organ_type: Optional[List[str]] = None, download=False, - 
offsets=None, boundaries=False, binary=False, **kwargs
-):
-    """Dataset from https://monusac-2020.grand-challenge.org/Data/
+def get_monusac_data(path: Union[os.PathLike, str], split: Literal['train', 'test'], download: bool = False):
+    """Download the MoNuSAC dataset.
+
+    Args:
+        path: Filepath to a folder where the downloaded data will be saved.
+        split: The split to use for the dataset. Either 'train' or 'test'.
+        download: Whether to download the data if it is not present.
+    """
+    assert split in ["train", "test"], "Please choose from train/test"
+
+    # check if we have extracted the images and labels already
+    im_path = os.path.join(path, "images", split)
+    label_path = os.path.join(path, "labels", split)
+    if os.path.exists(im_path) and os.path.exists(label_path):
+        return
+
+    os.makedirs(path, exist_ok=True)
+    zip_path = os.path.join(path, f"monusac_{split}.zip")
+    util.download_source_gdrive(zip_path, URL[split], download=download, checksum=CHECKSUM[split])
+
+    _process_monusac(path, split)
+
+    _check_channel_consistency(path, split)
+
+
+def get_monusac_paths(
+    path: Union[os.PathLike, str],
+    split: Literal['train', 'test'],
+    organ_type: Optional[List[str]] = None,
+    download: bool = False
+) -> Tuple[List[str], List[str]]:
+    """Get paths to MoNuSAC data.
+
+    Args:
+        path: Filepath to a folder where the downloaded data will be saved.
+        split: The split to use for the dataset. Either 'train' or 'test'.
+        organ_type: The choice of organ type.
+        download: Whether to download the data if it is not present.
+
+    Returns:
+        List of filepaths to the image data.
+        List of filepaths to the label data.
     """
-    _download_monusac(path, download, split)
+    get_monusac_data(path, split, download)
 
     image_paths = sorted(glob(os.path.join(path, "images", split, "*")))
     label_paths = sorted(glob(os.path.join(path, "labels", split, "*")))
@@ -157,24 +200,85 @@ def get_monusac_dataset(
 
     assert len(image_paths) == len(label_paths)
 
+    return image_paths, label_paths
+
+
+def get_monusac_dataset(
+    path: Union[os.PathLike, str],
+    patch_shape: Tuple[int, ...],
+    split: Literal['train', 'test'],
+    organ_type: Optional[List[str]] = None,
+    download: bool = False,
+    offsets: Optional[List[List[int]]] = None,
+    boundaries: bool = False,
+    binary: bool = False,
+    **kwargs
+) -> Dataset:
+    """Get the MoNuSAC dataset for nucleus segmentation in H&E stained tissue images.
+
+    Args:
+        path: Filepath to a folder where the downloaded data will be saved.
+        patch_shape: The patch shape to use for training.
+        split: The split to use for the dataset. Either 'train' or 'test'.
+        organ_type: The choice of organ type.
+        download: Whether to download the data if it is not present.
+        offsets: Offset values for affinity computation used as target.
+        boundaries: Whether to compute boundaries as the target.
+        binary: Whether to use a binary segmentation target.
+        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.
+
+    Returns:
+        The segmentation dataset.
+    """
+    image_paths, label_paths = get_monusac_paths(path, split, organ_type, download)
+
     kwargs, _ = util.add_instance_label_transform(
         kwargs, add_binary_target=True, binary=binary, boundaries=boundaries, offsets=offsets
     )
+
     return torch_em.default_segmentation_dataset(
-        image_paths, None, label_paths, None, patch_shape, is_seg_dataset=False, **kwargs
+        raw_paths=image_paths,
+        raw_key=None,
+        label_paths=label_paths,
+        label_key=None,
+        patch_shape=patch_shape,
+        is_seg_dataset=False,
+        **kwargs
     )
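A usage sketch for the refactored MoNuSAC API (path hypothetical); `organ_type` expects a list of organ names, matching the keys of `ORGAN_SPLITS`:

```python
from torch_em.data.datasets.histopathology import get_monusac_loader

loader = get_monusac_loader(
    "./data/monusac", patch_shape=(512, 512), batch_size=2,
    split="train", organ_type=["lung", "kidney"], download=True,
)
```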
 
 
 def get_monusac_loader(
-    path, patch_shape, split, batch_size, organ_type=None, download=False,
-    offsets=None, boundaries=False, binary=False, **kwargs
-):
-    ds_kwargs, loader_kwargs = util.split_kwargs(
-        torch_em.default_segmentation_dataset, **kwargs
-    )
+    path: Union[os.PathLike, str],
+    patch_shape: Tuple[int, ...],
+    batch_size: int,
+    split: Literal['train', 'test'],
+    organ_type: Optional[List[str]] = None,
+    download: bool = False,
+    offsets: Optional[List[List[int]]] = None,
+    boundaries: bool = False,
+    binary: bool = False,
+    **kwargs
+) -> DataLoader:
+    """Get the MoNuSAC dataloader for nucleus segmentation in H&E stained tissue images.
+
+    Args:
+        path: Filepath to a folder where the downloaded data will be saved.
+        patch_shape: The patch shape to use for training.
+        batch_size: The batch size for training.
+        split: The split to use for the dataset. Either 'train' or 'test'.
+        organ_type: The choice of organ type.
+        download: Whether to download the data if it is not present.
+        offsets: Offset values for affinity computation used as target.
+        boundaries: Whether to compute boundaries as the target.
+        binary: Whether to use a binary segmentation target.
+        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader.
+
+    Returns:
+        The DataLoader.
+    """
+    ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
     dataset = get_monusac_dataset(
         path, patch_shape, split, organ_type=organ_type, download=download,
         offsets=offsets, boundaries=boundaries, binary=binary, **ds_kwargs
     )
-    loader = torch_em.get_data_loader(dataset, batch_size, **loader_kwargs)
-    return loader
+    return torch_em.get_data_loader(dataset, batch_size, **loader_kwargs)
diff --git a/torch_em/data/datasets/histopathology/monuseg.py b/torch_em/data/datasets/histopathology/monuseg.py
new file mode 100644
index 00000000..7adc7341
--- /dev/null
+++ b/torch_em/data/datasets/histopathology/monuseg.py
@@ -0,0 +1,237 @@
+"""This dataset contains annotations for nucleus segmentation in
+H&E stained tissue images derived from different organs.
+
+This dataset comes from https://monuseg.grand-challenge.org/Data/.
+
+Please cite the relevant publications from the challenge
+if you use this dataset in your research.
+"""
+
+import os
+import shutil
+from tqdm import tqdm
+from glob import glob
+from pathlib import Path
+from typing import List, Optional, Union, Tuple, Literal
+
+import imageio.v3 as imageio
+
+from torch.utils.data import Dataset, DataLoader
+
+import torch_em
+
+from .. 
import util + + +URL = { + "train": "https://drive.google.com/uc?export=download&id=1ZgqFJomqQGNnsx7w7QBzQQMVA16lbVCA", + "test": "https://drive.google.com/uc?export=download&id=1NKkSQ5T0ZNQ8aUhh0a8Dt2YKYCQXIViw" +} + +CHECKSUM = { + "train": "25d3d3185bb2970b397cafa72eb664c9b4d24294aee382e7e3df9885affce742", + "test": "13e522387ae8b1bcc0530e13ff9c7b4d91ec74959ef6f6e57747368d7ee6f88a" +} + +# Here is the description: https://drive.google.com/file/d/1xYyQ31CHFRnvTCTuuHdconlJCMk2SK7Z/view?usp=sharing +ORGAN_SPLITS = { + "breast": [ + "TCGA-A7-A13E-01Z-00-DX1", "TCGA-A7-A13F-01Z-00-DX1", "TCGA-AR-A1AK-01Z-00-DX1", + "TCGA-AR-A1AS-01Z-00-DX1", "TCGA-E2-A1B5-01Z-00-DX1", "TCGA-E2-A14V-01Z-00-DX1" + ], + "kidney": [ + "TCGA-B0-5711-01Z-00-DX1", "TCGA-HE-7128-01Z-00-DX1", "TCGA-HE-7129-01Z-00-DX1", + "TCGA-HE-7130-01Z-00-DX1", "TCGA-B0-5710-01Z-00-DX1", "TCGA-B0-5698-01Z-00-DX1" + ], + "liver": [ + "TCGA-18-5592-01Z-00-DX1", "TCGA-38-6178-01Z-00-DX1", "TCGA-49-4488-01Z-00-DX1", + "TCGA-50-5931-01Z-00-DX1", "TCGA-21-5784-01Z-00-DX1", "TCGA-21-5786-01Z-00-DX1" + ], + "prostate": [ + "TCGA-G9-6336-01Z-00-DX1", "TCGA-G9-6348-01Z-00-DX1", "TCGA-G9-6356-01Z-00-DX1", + "TCGA-G9-6363-01Z-00-DX1", "TCGA-CH-5767-01Z-00-DX1", "TCGA-G9-6362-01Z-00-DX1" + ], + "bladder": ["TCGA-DK-A2I6-01A-01-TS1", "TCGA-G2-A2EK-01A-02-TSB"], + "colon": ["TCGA-AY-A8YK-01A-01-TS1", "TCGA-NH-A8F7-01A-01-TS1"], + "stomach": ["TCGA-KB-A93J-01A-01-TS1", "TCGA-RD-A8N9-01A-01-TS1"] +} + + +def _process_monuseg(path, split): + util.unzip(os.path.join(path, f"monuseg_{split}.zip"), path) + + # assorting the images into expected dir; + # converting the label xml files to numpy arrays (of same dimension as input images) in the expected dir + root_img_save_dir = os.path.join(path, "images", split) + root_label_save_dir = os.path.join(path, "labels", split) + + os.makedirs(root_img_save_dir, exist_ok=True) + os.makedirs(root_label_save_dir, exist_ok=True) + + if split == "train": + all_img_dir = sorted(glob(os.path.join(path, "*", "Tissue*", "*"))) + all_xml_label_dir = sorted(glob(os.path.join(path, "*", "Annotations", "*"))) + else: + all_img_dir = sorted(glob(os.path.join(path, "MoNuSegTestData", "*.tif"))) + all_xml_label_dir = sorted(glob(os.path.join(path, "MoNuSegTestData", "*.xml"))) + + assert len(all_img_dir) == len(all_xml_label_dir) + + for img_path, xml_label_path in tqdm( + zip(all_img_dir, all_xml_label_dir), + desc=f"Converting {split} split to the expected format", + total=len(all_img_dir) + ): + desired_label_shape = imageio.imread(img_path).shape[:-1] + + img_id = os.path.split(img_path)[-1] + dst = os.path.join(root_img_save_dir, img_id) + shutil.move(src=img_path, dst=dst) + + _label = util.generate_labeled_array_from_xml(shape=desired_label_shape, xml_file=xml_label_path) + _fileid = img_id.split(".")[0] + imageio.imwrite(os.path.join(root_label_save_dir, f"{_fileid}.tif"), _label, compression="zlib") + + shutil.rmtree(glob(os.path.join(path, "MoNuSeg*"))[0]) + if split == "train": + shutil.rmtree(glob(os.path.join(path, "__MACOSX"))[0]) + + +def get_monuseg_data(path: Union[os.PathLike, str], split: Literal['train', 'test'], download: bool = False): + """Download the MoNuSeg dataset. + + Args: + path: Filepath to a folder where the downloaded data will be saved. + split: The split to use for the dataset. Either 'train' or 'test'. + download: Whether to download the data if it is not present. 
+    """
+    assert split in ["train", "test"], "The split choices in the MoNuSeg dataset are train/test, please choose from them"
+
+    # check if we have extracted the images and labels already
+    im_path = os.path.join(path, "images", split)
+    label_path = os.path.join(path, "labels", split)
+    if os.path.exists(im_path) and os.path.exists(label_path):
+        return
+
+    os.makedirs(path, exist_ok=True)
+    zip_path = os.path.join(path, f"monuseg_{split}.zip")
+    util.download_source_gdrive(zip_path, URL[split], download=download, checksum=CHECKSUM[split])
+
+    _process_monuseg(path, split)
+
+
+def get_monuseg_paths(
+    path: Union[os.PathLike, str],
+    split: Literal['train', 'test'],
+    organ_type: Optional[List[str]] = None,
+    download: bool = False
+) -> Tuple[List[str], List[str]]:
+    """Get paths to the MoNuSeg data.
+
+    Args:
+        path: Filepath to a folder where the downloaded data will be saved.
+        split: The split to use for the dataset. Either 'train' or 'test'.
+        organ_type: The choice of organ type.
+        download: Whether to download the data if it is not present.
+
+    Returns:
+        List of filepaths to the image data.
+        List of filepaths to the label data.
+    """
+    get_monuseg_data(path, split, download)
+
+    image_paths = sorted(glob(os.path.join(path, "images", split, "*")))
+    label_paths = sorted(glob(os.path.join(path, "labels", split, "*")))
+
+    if split == "train" and organ_type is not None:
+        # get all patients for multiple organ selection
+        all_organ_splits = sum([ORGAN_SPLITS[_o] for _o in organ_type], [])
+
+        image_paths = [_path for _path in image_paths if Path(_path).stem in all_organ_splits]
+        label_paths = [_path for _path in label_paths if Path(_path).stem in all_organ_splits]
+
+    elif split == "test" and organ_type is not None:
+        # we don't have organ splits in the test dataset
+        raise ValueError("The test split does not have any organ information, please pass `organ_type=None`")
+
+    return image_paths, label_paths
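A usage sketch for organ selection (path hypothetical); `organ_type` restricts the train split to the patient ids listed in `ORGAN_SPLITS` and must be `None` for the test split:

```python
from torch_em.data.datasets.histopathology import get_monuseg_loader

loader = get_monuseg_loader(
    "./data/monuseg", patch_shape=(512, 512), batch_size=2,
    split="train", organ_type=["breast", "liver"], download=True,
)
```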
+
+
+def get_monuseg_dataset(
+    path: Union[os.PathLike, str],
+    patch_shape: Tuple[int, ...],
+    split: Literal['train', 'test'],
+    organ_type: Optional[List[str]] = None,
+    download: bool = False,
+    offsets: Optional[List[List[int]]] = None,
+    boundaries: bool = False,
+    binary: bool = False,
+    **kwargs
+) -> Dataset:
+    """Get the MoNuSeg dataset for nucleus segmentation in H&E stained tissue images.
+
+    Args:
+        path: Filepath to a folder where the downloaded data will be saved.
+        patch_shape: The patch shape to use for training.
+        split: The split to use for the dataset. Either 'train' or 'test'.
+        organ_type: The choice of organ type.
+        download: Whether to download the data if it is not present.
+        offsets: Offset values for affinity computation used as target.
+        boundaries: Whether to compute boundaries as the target.
+        binary: Whether to use a binary segmentation target.
+        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.
+
+    Returns:
+        The segmentation dataset.
+    """
+    image_paths, label_paths = get_monuseg_paths(path, split, organ_type, download)
+
+    kwargs, _ = util.add_instance_label_transform(
+        kwargs, add_binary_target=True, binary=binary, boundaries=boundaries, offsets=offsets
+    )
+    return torch_em.default_segmentation_dataset(
+        raw_paths=image_paths,
+        raw_key=None,
+        label_paths=label_paths,
+        label_key=None,
+        patch_shape=patch_shape,
+        is_seg_dataset=False,
+        **kwargs
+    )
+
+
+def get_monuseg_loader(
+    path: Union[os.PathLike, str],
+    patch_shape: Tuple[int, ...],
+    batch_size: int,
+    split: Literal['train', 'test'],
+    organ_type: Optional[List[str]] = None,
+    download: bool = False,
+    offsets: Optional[List[List[int]]] = None,
+    boundaries: bool = False,
+    binary: bool = False,
+    **kwargs
+) -> DataLoader:
+    """Get the MoNuSeg dataloader for nucleus segmentation in H&E stained tissue images.
+
+    Args:
+        path: Filepath to a folder where the downloaded data will be saved.
+        patch_shape: The patch shape to use for training.
+        batch_size: The batch size for training.
+        split: The split to use for the dataset. Either 'train' or 'test'.
+        organ_type: The choice of organ type.
+        download: Whether to download the data if it is not present.
+        offsets: Offset values for affinity computation used as target.
+        boundaries: Whether to compute boundaries as the target.
+        binary: Whether to use a binary segmentation target.
+        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader.
+
+    Returns:
+        The DataLoader.
+    """
+    ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
+    dataset = get_monuseg_dataset(
+        path, patch_shape, split, organ_type=organ_type, download=download,
+        offsets=offsets, boundaries=boundaries, binary=binary, **ds_kwargs
+    )
+    return torch_em.get_data_loader(dataset, batch_size, **loader_kwargs)
diff --git a/torch_em/data/datasets/pannuke.py b/torch_em/data/datasets/histopathology/pannuke.py
similarity index 63%
rename from torch_em/data/datasets/pannuke.py
rename to torch_em/data/datasets/histopathology/pannuke.py
index df46154f..63243dee 100644
--- a/torch_em/data/datasets/pannuke.py
+++ b/torch_em/data/datasets/histopathology/pannuke.py
@@ -1,13 +1,22 @@
+"""The PanNuke dataset contains annotations for nucleus segmentation
+in histopathology images across different tissue types.
+
+This dataset is from the publication https://doi.org/10.48550/arXiv.2003.10778.
+Please cite it if you use this dataset for your research.
+"""
+
 import os
-import h5py
-import vigra
 import shutil
-import numpy as np
 from glob import glob
-from typing import List
+from typing import List, Union, Dict, Tuple
+
+import numpy as np
+
+from torch.utils.data import Dataset, DataLoader
 
 import torch_em
-from torch_em.data.datasets import util
+
+from .. import util
 
 
 # PanNuke Dataset - https://warwick.ac.uk/fac/cross_fac/tia/data/pannuke
@@ -17,7 +26,6 @@
     "fold_3": "https://warwick.ac.uk/fac/cross_fac/tia/data/pannuke/fold_3.zip"
 }
 
-
 CHECKSUM = {
     "fold_1": "6e19ad380300e8ce9480f9ab6a14cc91fa4b6a511609b40e3d70bdf9c881ed0b",
     "fold_2": "5bc540cc509f64b5f5a274d6e5a245527dbd3e6d3155d43555115c5d54709b07",
@@ -25,16 +33,20 @@
 }
 
 
-def _download_pannuke_dataset(path, download, folds):
-    os.makedirs(path, exist_ok=True)
-
-    checksum = CHECKSUM
+def get_pannuke_data(path, download, folds):
+    """Download the PanNuke data.
 
+    Args:
+        path: Filepath to a folder where the downloaded data will be saved.
+        download: Whether to download the data if it is not present.
+        folds: The data fold(s) of choice to be used.
+    """
+    os.makedirs(path, exist_ok=True)
     for tmp_fold in folds:
         if os.path.exists(os.path.join(path, f"pannuke_{tmp_fold}.h5")):
             return
 
-        util.download_source(os.path.join(path, f"{tmp_fold}.zip"), URLS[tmp_fold], download, checksum[tmp_fold])
+        util.download_source(os.path.join(path, f"{tmp_fold}.zip"), URLS[tmp_fold], download, CHECKSUM[tmp_fold])
 
         print(f"Unzipping the PanNuke dataset in {tmp_fold} directories...")
         util.unzip(os.path.join(path, f"{tmp_fold}.zip"), os.path.join(path, f"{tmp_fold}"), True)
@@ -52,6 +64,8 @@ def _convert_to_hdf5(path, fold):
     (0: Background, 1: Neoplastic cells, 2: Inflammatory,
      3: Connective/Soft tissue cells, 4: Dead Cells, 5: Epithelial)
     """
+    import h5py
+
     if os.path.exists(os.path.join(path, f"pannuke_{fold}.h5")):
         return
 
@@ -95,6 +109,8 @@ def _channels_to_instances(labels):
     Returns:
         - instance labels of dimensions -> (C x H x W)
     """
+    import vigra
+
     labels = labels.transpose(0, 3, 1, 2)  # to access with the shape S x 6 x H x W
     list_of_instances = []
 
@@ -142,17 +158,52 @@ def _channels_to_semantics(labels):
     return f_segmentation
 
 
+def get_pannuke_paths(
+    path: Union[os.PathLike, str], folds: List[str] = ["fold_1", "fold_2", "fold_3"], download: bool = False,
+) -> List[str]:
+    """Get paths to the PanNuke data.
+
+    Args:
+        path: Filepath to a folder where the downloaded data will be saved.
+        folds: The data fold(s) of choice to be used.
+        download: Whether to download the data if it is not present.
+
+    Returns:
+        List of filepaths to the stored data.
+    """
+    get_pannuke_data(path, download, folds)
+
+    data_paths = [os.path.join(path, f"pannuke_{fold}.h5") for fold in folds]
+    return data_paths
+
+
 def get_pannuke_dataset(
-    path,
-    patch_shape,
-    folds: List[str] = ["fold_1", "fold_2", "fold_3"],
-    rois={},
-    download=False,
-    with_channels=True,
-    with_label_channels=False,
-    custom_label_choice: str = "instances",
-    **kwargs
-):
+    path: Union[os.PathLike, str],
+    patch_shape: Tuple[int, ...],
+    folds: List[str] = ["fold_1", "fold_2", "fold_3"],
+    rois: Dict = {},
+    download: bool = False,
+    custom_label_choice: str = "instances",
+    with_channels: bool = True,
+    with_label_channels: bool = False,
+    **kwargs
+) -> Dataset:
+    """Get the PanNuke dataset for nucleus segmentation.
+
+    Args:
+        path: Filepath to a folder where the downloaded data will be saved.
+        patch_shape: The patch shape to use for training.
+        folds: The data fold(s) of choice to be used.
+        download: Whether to download the data if it is not present.
+        rois: The regions of interest, given as a numpy slice per fold, to restrict the data used for training.
+        custom_label_choice: The choice of labels to be used for training.
+        with_channels: Whether the inputs have channels.
+        with_label_channels: Whether the labels have channels.
+        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.
+
+    Returns:
+        The segmentation dataset.
+    """
     assert custom_label_choice in [
         "masks", "instances", "semantic"
     ], "Select the type of labels you want from [masks/instances/semantic] (See `_convert_to_hdf5` for details)"
 
     if rois is not None:
         assert isinstance(rois, dict)
 
-    _download_pannuke_dataset(path, download, folds)
-
-    data_paths = [os.path.join(path, f"pannuke_{fold}.h5") for fold in folds]
-    data_rois = [rois.get(fold, np.s_[:, :, :]) for fold in folds]
-
-    raw_key = "images"
-    label_key = f"labels/{custom_label_choice}"
+    data_paths = get_pannuke_paths(path, folds, download)
 
     return torch_em.default_segmentation_dataset(
-        data_paths, raw_key, data_paths, label_key, patch_shape, rois=data_rois,
-        with_channels=with_channels, with_label_channels=with_label_channels, **kwargs
+        raw_paths=data_paths,
+        raw_key="images",
+        label_paths=data_paths,
+        label_key=f"labels/{custom_label_choice}",
+        patch_shape=patch_shape,
+        rois=[rois.get(fold, np.s_[:, :, :]) for fold in folds],
+        with_channels=with_channels,
+        with_label_channels=with_label_channels,
+        **kwargs
     )
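The `rois` argument maps a fold name to a numpy slice and defaults to the full fold (`np.s_[:, :, :]`). A sketch restricting training to the first 1000 images of fold 1 (the path is hypothetical; PanNuke tiles are 256 x 256):

```python
import numpy as np
from torch_em.data.datasets.histopathology import get_pannuke_loader

loader = get_pannuke_loader(
    "./data/pannuke", patch_shape=(1, 256, 256), batch_size=4,
    folds=["fold_1"], rois={"fold_1": np.s_[:1000, :, :]}, download=True,
)
```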
 
 
 def get_pannuke_loader(
-    path,
-    patch_shape,
-    batch_size,
-    folds=["fold_1", "fold_2", "fold_3"],
-    download=False,
-    rois={},
-    custom_label_choice="instances",
-    **kwargs
-):
-    """TODO
+    path: Union[os.PathLike, str],
+    patch_shape: Tuple[int, ...],
+    batch_size: int,
+    folds: List[str] = ["fold_1", "fold_2", "fold_3"],
+    download: bool = False,
+    rois: Dict = {},
+    custom_label_choice: str = "instances",
+    **kwargs
+) -> DataLoader:
+    """Get the PanNuke dataloader for nucleus segmentation.
+
+    Args:
+        path: Filepath to a folder where the downloaded data will be saved.
+        patch_shape: The patch shape to use for training.
+        batch_size: The batch size for training.
+        folds: The data fold(s) of choice to be used.
+        download: Whether to download the data if it is not present.
+        rois: The regions of interest, given as a numpy slice per fold, to restrict the data used for training.
+        custom_label_choice: The choice of labels to be used for training.
+        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader.
+
+    Returns:
+        The DataLoader.
     """
     dataset_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
-
     ds = get_pannuke_dataset(
         path=path,
         patch_shape=patch_shape,
@@ -195,5 +259,6 @@ def get_pannuke_loader(
         rois=rois,
         download=download,
         custom_label_choice=custom_label_choice,
-        **dataset_kwargs)
+        **dataset_kwargs
+    )
     return torch_em.get_data_loader(ds, batch_size=batch_size, **loader_kwargs)
diff --git a/torch_em/data/datasets/isbi2012.py b/torch_em/data/datasets/isbi2012.py
deleted file mode 100644
index 365d21ff..00000000
--- a/torch_em/data/datasets/isbi2012.py
+++ /dev/null
@@ -1,56 +0,0 @@
-import os
-
-import torch_em
-from . import util
-
-ISBI_URL = "https://oc.embl.de/index.php/s/h0TkwqxU0PJDdMd/download"
-CHECKSUM = "0e10fe909a1243084d91773470856993b7d40126a12e85f0f1345a7a9e512f29"
-
-
-def get_isbi_dataset(
-    path, patch_shape, download=False, offsets=None, boundaries=False,
-    use_original_labels=False, **kwargs
-):
-    """Dataset for the segmentation of neurons in EM.
-
-    This dataset is from the publication https://doi.org/10.3389/fnana.2015.00142.
-    Please cite it if you use this dataset for a publication.
- """ - if path.endswith(".h5"): - volume_path = path - else: - os.makedirs(path, exist_ok=True) - volume_path = os.path.join(path, "isbi.h5") - - assert len(patch_shape) == 3 - util.download_source(volume_path, ISBI_URL, download, CHECKSUM) - ndim = 2 if patch_shape[0] == 1 else 3 - kwargs = util.update_kwargs(kwargs, "ndim", ndim) - - kwargs, _ = util.add_instance_label_transform( - kwargs, add_binary_target=False, boundaries=boundaries, offsets=offsets - ) - - raw_key = "raw" - label_key = "labels/membranes" if use_original_labels else "labels/gt_segmentation" - - return torch_em.default_segmentation_dataset(volume_path, raw_key, volume_path, label_key, patch_shape, **kwargs) - - -def get_isbi_loader( - path, patch_shape, batch_size, download=False, - offsets=None, boundaries=False, - use_original_labels=False, - **kwargs -): - """Dataloader for the segmentation of neurons in EM. See 'get_isbi_dataset' for details.""" - ds_kwargs, loader_kwargs = util.split_kwargs( - torch_em.default_segmentation_dataset, **kwargs - ) - dataset = get_isbi_dataset( - path, patch_shape, download=download, - offsets=offsets, boundaries=boundaries, use_original_labels=use_original_labels, - **ds_kwargs - ) - loader = torch_em.get_data_loader(dataset, batch_size, **loader_kwargs) - return loader diff --git a/torch_em/data/datasets/kasthuri.py b/torch_em/data/datasets/kasthuri.py deleted file mode 100644 index 76b30869..00000000 --- a/torch_em/data/datasets/kasthuri.py +++ /dev/null @@ -1,104 +0,0 @@ -import os -from concurrent import futures -from glob import glob -from shutil import rmtree - -import imageio -import h5py -import numpy as np -import torch_em - -from tqdm import tqdm -from . import util - -URL = "http://www.casser.io/files/kasthuri_pp.zip " -CHECKSUM = "bbb78fd205ec9b57feb8f93ebbdf1666261cbc3e0305e7f11583ab5157a3d792" - -# data from: https://sites.google.com/view/connectomics/ -# TODO: add sampler for foreground (-1 is empty area) -# TODO and masking for the empty space - - -def _load_volume(path): - files = glob(os.path.join(path, "*.png")) - files.sort() - nz = len(files) - - im0 = imageio.imread(files[0]) - out = np.zeros((nz,) + im0.shape, dtype=im0.dtype) - out[0] = im0 - - def _loadz(z): - im = imageio.imread(files[z]) - out[z] = im - - n_threads = 8 - with futures.ThreadPoolExecutor(n_threads) as tp: - list(tqdm( - tp.map(_loadz, range(1, nz)), desc="Load volume", total=nz-1 - )) - - return out - - -def _create_data(root, inputs, out_path): - raw = _load_volume(os.path.join(root, inputs[0])) - labels_argb = _load_volume(os.path.join(root, inputs[1])) - assert labels_argb.ndim == 4 - labels = np.zeros(raw.shape, dtype="int8") - - fg_mask = (labels_argb == np.array([255, 255, 255])[None, None, None]).all(axis=-1) - labels[fg_mask] = 1 - bg_mask = (labels_argb == np.array([2, 2, 2])[None, None, None]).all(axis=-1) - labels[bg_mask] = -1 - assert (np.unique(labels) == np.array([-1, 0, 1])).all() - assert raw.shape == labels.shape, f"{raw.shape}, {labels.shape}" - with h5py.File(out_path, "w") as f: - f.create_dataset("raw", data=raw, compression="gzip") - f.create_dataset("labels", data=labels, compression="gzip") - - -def _require_kasthuri_data(path, download): - # download and unzip the data - if os.path.exists(path): - return path - - os.makedirs(path) - tmp_path = os.path.join(path, "kasthuri.zip") - util.download_source(tmp_path, URL, download, checksum=CHECKSUM) - util.unzip(tmp_path, path, remove=True) - - root = os.path.join(path, "Kasthuri++") - assert os.path.exists(root), root - 
-    inputs = [["Test_In", "Test_Out"], ["Train_In", "Train_Out"]]
-    outputs = ["kasthuri_train.h5", "kasthuri_test.h5"]
-    for inp, out in zip(inputs, outputs):
-        out_path = os.path.join(path, out)
-        _create_data(root, inp, out_path)
-
-    rmtree(root)
-
-
-def get_kasthuri_dataset(path, split, patch_shape, download=False, **kwargs):
-    """Dataset for the segmentation of mitochondria in EM.
-
-    This dataset is from the publication https://doi.org/10.48550/arXiv.1812.06024.
-    Please cite it if you use this dataset for a publication.
-    """
-    assert split in ("train", "test")
-    _require_kasthuri_data(path, download)
-    data_path = os.path.join(path, f"kasthuri_{split}.h5")
-    assert os.path.exists(data_path), data_path
-    raw_key, label_key = "raw", "labels"
-    return torch_em.default_segmentation_dataset(data_path, raw_key, data_path, label_key, patch_shape, **kwargs)
-
-
-def get_kasthuri_loader(path, split, patch_shape, batch_size, download=False, **kwargs):
-    """Dataloader for the segmentation of mitochondria in EM. See 'get_kasthuri_dataset' for details."""
-    ds_kwargs, loader_kwargs = util.split_kwargs(
-        torch_em.default_segmentation_dataset, **kwargs
-    )
-    dataset = get_kasthuri_dataset(path, split, patch_shape, download=download, **ds_kwargs)
-    loader = torch_em.get_data_loader(dataset, batch_size, **loader_kwargs)
-    return loader
diff --git a/torch_em/data/datasets/light_microscopy/__init__.py b/torch_em/data/datasets/light_microscopy/__init__.py
new file mode 100644
index 00000000..216dfd70
--- /dev/null
+++ b/torch_em/data/datasets/light_microscopy/__init__.py
@@ -0,0 +1,23 @@
+from .cellpose import get_cellpose_loader, get_cellpose_dataset
+from .cellseg_3d import get_cellseg_3d_loader, get_cellseg_3d_dataset
+from .covid_if import get_covid_if_loader, get_covid_if_dataset
+from .ctc import get_ctc_segmentation_loader, get_ctc_segmentation_dataset
+from .deepbacs import get_deepbacs_loader, get_deepbacs_dataset
+from .dic_hepg2 import get_dic_hepg2_loader, get_dic_hepg2_dataset
+from .dsb import get_dsb_loader, get_dsb_dataset
+from .dynamicnuclearnet import get_dynamicnuclearnet_loader, get_dynamicnuclearnet_dataset
+from .embedseg_data import get_embedseg_loader, get_embedseg_dataset
+from .gonuclear import get_gonuclear_loader, get_gonuclear_dataset
+from .hpa import get_hpa_segmentation_loader, get_hpa_segmentation_dataset
+from .livecell import get_livecell_loader, get_livecell_dataset
+from .mouse_embryo import get_mouse_embryo_loader, get_mouse_embryo_dataset
+from .neurips_cell_seg import (
+    get_neurips_cellseg_supervised_loader, get_neurips_cellseg_supervised_dataset,
+    get_neurips_cellseg_unsupervised_loader, get_neurips_cellseg_unsupervised_dataset
+)
+from .omnipose import get_omnipose_dataset, get_omnipose_loader
+from .orgasegment import get_orgasegment_dataset, get_orgasegment_loader
+from .organoidnet import get_organoidnet_dataset, get_organoidnet_loader
+from .plantseg import get_plantseg_loader, get_plantseg_dataset
+from .tissuenet import get_tissuenet_loader, get_tissuenet_dataset
+from .vgg_hela import get_vgg_hela_loader, get_vgg_hela_dataset
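With this reorganization the dataset modules move into per-domain subpackages. This part of the diff does not show whether the old flat imports are still re-exported, so the safe import style after the change goes through the subpackages, mirroring the `__init__.py` files added here:

```python
from torch_em.data.datasets.histopathology import get_pannuke_loader
from torch_em.data.datasets.light_microscopy import get_livecell_loader
```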
diff --git a/torch_em/data/datasets/light_microscopy/cellpose.py b/torch_em/data/datasets/light_microscopy/cellpose.py
new file mode 100644
index 00000000..51cca865
--- /dev/null
+++ b/torch_em/data/datasets/light_microscopy/cellpose.py
@@ -0,0 +1,177 @@
+"""This dataset contains annotations for cell segmentation in
+fluorescently labeled microscopy images.
+
+This dataset is from the following publications:
+- https://doi.org/10.1038/s41592-020-01018-x
+- https://doi.org/10.1038/s41592-022-01663-4
+Please cite it if you use this dataset in your research.
+"""
+
+import os
+from glob import glob
+from natsort import natsorted
+from typing import Union, Tuple, Literal, Optional, List
+
+import torch_em
+
+from torch.utils.data import Dataset, DataLoader
+
+from .. import util
+from .neurips_cell_seg import to_rgb
+
+
+AVAILABLE_CHOICES = ["cyto", "cyto2"]
+
+
+def get_cellpose_data(
+    path: Union[os.PathLike, str],
+    split: Literal["train", "test"],
+    choice: Literal["cyto", "cyto2"],
+    download: bool = False,
+) -> str:
+    """Instructions to download the CellPose data.
+
+    NOTE: Please download the dataset from "https://www.cellpose.org/dataset".
+
+    Args:
+        path: Filepath to a folder where the data is downloaded for further processing.
+        split: The data split to use. Either 'train', or 'test'.
+        choice: The choice of dataset. Either 'cyto' or 'cyto2'.
+        download: Whether to download the data if it is not present.
+
+    Returns:
+        The filepath to the data.
+    """
+    if download:
+        raise NotImplementedError(
+            "The dataset cannot be automatically downloaded. Please see 'get_cellpose_data' in "
+            "'torch_em/data/datasets/light_microscopy/cellpose.py' for details."
+        )
+
+    per_choice_dir = os.path.join(path, choice)  # path where the unzipped files will be stored
+    if choice == "cyto":
+        assert split in ["train", "test"], f"'{split}' is not a valid split in '{choice}'."
+        zip_path = os.path.join(path, f"{split}.zip")
+        data_dir = os.path.join(per_choice_dir, split)  # path where the per split images for 'cyto' exist.
+    elif choice == "cyto2":
+        assert split == "train", f"'{split}' is not a valid split in '{choice}'."
+        zip_path = os.path.join(path, "train_cyto2.zip")
+        data_dir = os.path.join(per_choice_dir, "train_cyto2")  # path where 'train' split images for 'cyto2' exist.
+    else:
+        raise ValueError(f"'{choice}' is not a valid dataset choice.")
+
+    if not os.path.exists(data_dir):
+        util.unzip(zip_path=zip_path, dst=per_choice_dir, remove=False)
+
+    return data_dir
+
+
+def get_cellpose_paths(
+    path: Union[os.PathLike, str],
+    split: Literal['train', 'test'],
+    choice: Literal["cyto", "cyto2"],
+    download: bool = False,
+) -> Tuple[List[str], List[str]]:
+    """Get paths to the CellPose data.
+
+    Args:
+        path: Filepath to a folder where the downloaded data will be saved.
+        split: The data split to use. Either 'train', or 'test'.
+        choice: The choice of dataset. Either 'cyto' or 'cyto2'.
+        download: Whether to download the data if it is not present.
+
+    Returns:
+        List of filepaths for the image data.
+        List of filepaths for the label data.
+    """
+    data_dir = get_cellpose_data(path=path, split=split, choice=choice, download=download)
+
+    image_paths = natsorted(glob(os.path.join(data_dir, "*_img.png")))
+    gt_paths = natsorted(glob(os.path.join(data_dir, "*_masks.png")))
+
+    return image_paths, gt_paths
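A usage sketch (path hypothetical): the zip archives (`train.zip` / `test.zip` for 'cyto', `train_cyto2.zip` for 'cyto2') have to be fetched manually from https://www.cellpose.org/dataset and placed under `path`; the helpers above then unpack them on first use:

```python
from torch_em.data.datasets.light_microscopy import get_cellpose_loader

loader = get_cellpose_loader(
    "./data/cellpose", split="train", patch_shape=(256, 256), batch_size=2, choice="cyto",
)
```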
+
+
+def get_cellpose_dataset(
+    path: Union[os.PathLike, str],
+    split: Literal["train", "test"],
+    patch_shape: Tuple[int, int],
+    choice: Optional[Literal["cyto", "cyto2"]] = None,
+    download: bool = False,
+    **kwargs
+) -> Dataset:
+    """Get the CellPose dataset for cell segmentation.
+
+    Args:
+        path: Filepath to a folder where the downloaded data will be saved.
+        split: The data split to use. Either 'train', or 'test'.
+        patch_shape: The patch shape to use for training.
+        choice: The choice of dataset. Either 'cyto' or 'cyto2'.
+        download: Whether to download the data if it is not present.
+        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.
+
+    Returns:
+        The segmentation dataset.
+    """
+    assert split in ["train", "test"]
+
+    if choice is None:
+        choice = AVAILABLE_CHOICES
+    else:
+        if not isinstance(choice, list):
+            choice = [choice]
+
+    image_paths, gt_paths = [], []
+    for per_choice in choice:
+        assert per_choice in AVAILABLE_CHOICES
+        # Collect the paths for each selected sub-dataset.
+        per_image_paths, per_gt_paths = get_cellpose_paths(path, split, per_choice, download)
+        image_paths.extend(per_image_paths)
+        gt_paths.extend(per_gt_paths)
+
+    # Only set the default transforms if the user did not pass their own.
+    if "raw_transform" not in kwargs:
+        kwargs["raw_transform"] = torch_em.transform.get_raw_transform(augmentation2=to_rgb)
+
+    if "transform" not in kwargs:
+        kwargs["transform"] = torch_em.transform.get_augmentations(ndim=2)
+
+    return torch_em.default_segmentation_dataset(
+        raw_paths=image_paths,
+        raw_key=None,
+        label_paths=gt_paths,
+        label_key=None,
+        is_seg_dataset=False,
+        patch_shape=patch_shape,
+        **kwargs
+    )
+
+
+def get_cellpose_loader(
+    path: Union[os.PathLike, str],
+    split: Literal["train", "test"],
+    patch_shape: Tuple[int, int],
+    batch_size: int,
+    choice: Optional[Literal["cyto", "cyto2"]] = None,
+    download: bool = False,
+    **kwargs
+) -> DataLoader:
+    """Get the CellPose dataloader for cell segmentation.
+
+    Args:
+        path: Filepath to a folder where the downloaded data will be saved.
+        split: The data split to use. Either 'train', or 'test'.
+        patch_shape: The patch shape to use for training.
+        batch_size: The batch size for training.
+        choice: The choice of dataset. Either 'cyto' or 'cyto2'.
+        download: Whether to download the data if it is not present.
+        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader.
+
+    Returns:
+        The DataLoader.
+    """
+    ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
+    dataset = get_cellpose_dataset(
+        path=path, split=split, patch_shape=patch_shape, choice=choice, download=download, **ds_kwargs
+    )
+    return torch_em.get_data_loader(dataset=dataset, batch_size=batch_size, **loader_kwargs)
diff --git a/torch_em/data/datasets/light_microscopy/cellseg_3d.py b/torch_em/data/datasets/light_microscopy/cellseg_3d.py
new file mode 100644
index 00000000..69482e89
--- /dev/null
+++ b/torch_em/data/datasets/light_microscopy/cellseg_3d.py
@@ -0,0 +1,126 @@
+"""This dataset contains annotations for nucleus segmentation in 3d fluorescence microscopy,
+acquired with mesoSPIM microscopy.
+
+This dataset is from the publication https://doi.org/10.1101/2024.05.17.594691.
+Please cite it if you use this dataset in your research.
+"""
+
+import os
+from glob import glob
+from typing import Optional, Tuple, Union, List
+
+from torch.utils.data import Dataset, DataLoader
+
+import torch_em
+
+from .. import util
+
+
+URL = "https://zenodo.org/records/11095111/files/DATASET_WITH_GT.zip?download=1"
+CHECKSUM = "6d8e8d778e479000161fdfea70201a6ded95b3958a703f69def63e69bbddf9d6"
+
+
+def get_cellseg_3d_data(path: Union[os.PathLike, str], download: bool = False) -> str:
+    """Download the CellSeg3d training data.
+
+    Args:
+        path: Filepath to a folder where the downloaded data will be saved.
+        download: Whether to download the data if it is not present.
+
+    Returns:
+        The filepath to the training data.
+ """ + url = URL + checksum = CHECKSUM + + data_path = os.path.join(path, "DATASET_WITH_GT") + if os.path.exists(data_path): + return data_path + + os.makedirs(path, exist_ok=True) + zip_path = os.path.join(path, "cellseg3d.zip") + util.download_source(zip_path, url, download, checksum) + util.unzip(zip_path, path, True) + + return data_path + + +def get_cellseg_3d_paths(path: Union[os.PathLike, str], download: bool = False) -> Tuple[List[str], List[str]]: + """Get paths to the CellSeg3d data. + + Args: + path: Filepath to a folder where the downloaded data will be saved. + download: Whether to download the data if it is not present. + + Returns: + List of filepaths for the image data. + List of filepaths for the label data. + """ + data_root = get_cellseg_3d_data(path, download) + + raw_paths = sorted(glob(os.path.join(data_root, "*.tif"))) + label_paths = sorted(glob(os.path.join(data_root, "labels", "*.tif"))) + assert len(raw_paths) == len(label_paths) + + return raw_paths, label_paths + + +def get_cellseg_3d_dataset( + path: Union[os.PathLike, str], + patch_shape: Tuple[int, int], + sample_ids: Optional[Tuple[int, ...]] = None, + download: bool = False, + **kwargs +) -> Dataset: + """Get the CellSeg3d dataset for segmenting nuclei in 3d fluorescence microscopy. + + Args: + path: Filepath to a folder where the downloaded data will be saved. + patch_shape: The patch shape to use for training. + sample_ids: The volume ids to load. + download: Whether to download the data if it is not present. + kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`. + + Returns: + The segmentation dataset. + """ + raw_paths, label_paths = get_cellseg_3d_paths(path, download) + + if sample_ids is not None: + assert all(sid < len(raw_paths) for sid in sample_ids) + raw_paths = [raw_paths[i] for i in sample_ids] + label_paths = [label_paths[i] for i in sample_ids] + + return torch_em.default_segmentation_dataset( + raw_paths=raw_paths, + raw_key=None, + label_paths=label_paths, + label_key=None, + patch_shape=patch_shape, + **kwargs + ) + + +def get_cellseg_3d_loader( + path: Union[os.PathLike, str], + patch_shape: Tuple[int, int], + batch_size: int, + sample_ids: Optional[Tuple[int, ...]] = None, + download: bool = False, + **kwargs +) -> DataLoader: + """Get the CellSeg3d dataloader for segmenting nuclei in 3d fluorescence microscopy. + + Args: + path: Filepath to a folder where the downloaded data will be saved. + patch_shape: The patch shape to use for training. + batch_size: The batch size for training. + sample_ids: The volume ids to load. + download: Whether to download the data if it is not present. + kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader. + + Returns: + The DataLoader. + """ + ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs) + dataset = get_cellseg_3d_dataset(path, patch_shape, sample_ids=sample_ids, download=download, **ds_kwargs) + return torch_em.get_data_loader(dataset, batch_size=batch_size, **loader_kwargs) diff --git a/torch_em/data/datasets/light_microscopy/covid_if.py b/torch_em/data/datasets/light_microscopy/covid_if.py new file mode 100644 index 00000000..59bbb74f --- /dev/null +++ b/torch_em/data/datasets/light_microscopy/covid_if.py @@ -0,0 +1,168 @@ +"""This dataset contains annotation for cell and nucleus segmentation +in immunofluorescence microscopy. + +This dataset is from the publication https://doi.org/10.1002/bies.202000257. 
+Please cite it if you use this dataset in your research. +""" + +import os +from glob import glob +from typing import List, Optional, Tuple, Union + +from torch.utils.data import Dataset, DataLoader + +import torch_em + +from .. import util + + +COVID_IF_URL = "https://zenodo.org/record/5092850/files/covid-if-groundtruth.zip?download=1" +CHECKSUM = "d9cd6c85a19b802c771fb4ff928894b19a8fab0e0af269c49235fdac3f7a60e1" + + +def get_covid_if_data(path: Union[os.PathLike, str], download: bool = False) -> str: + """Download the Covid-IF training data. + + Args: + path: Filepath to a folder where the downloaded data will be saved. + download: Whether to download the data if it is not present. + + Returns: + The filepath to the training data. + """ + url = COVID_IF_URL + checksum = CHECKSUM + + if os.path.exists(path): + return path + + os.makedirs(path, exist_ok=True) + zip_path = os.path.join(path, "covid-if.zip") + util.download_source(zip_path, url, download, checksum) + util.unzip(zip_path, path, True) + + return path + + +def get_covid_if_paths( + path: Union[os.PathLike, str], + sample_range: Optional[Tuple[int, int]] = None, + download: bool = False +) -> List[str]: + """Get paths to the Covid-IF data. + + Args: + path: Filepath to a folder where the downloaded data will be saved. + sample_range: Id range of samples to load from the training dataset. + download: Whether to download the data if it is not present. + + Returns: + List of filepaths to the stored data. + """ + get_covid_if_data(path, download) + + file_paths = sorted(glob(os.path.join(path, "*.h5"))) + if sample_range is not None: + start, stop = sample_range + if start is None: + start = 0 + if stop is None: + stop = len(file_paths) + file_paths = [os.path.join(path, f"gt_image_{idx:03}.h5") for idx in range(start, stop)] + assert all(os.path.exists(fp) for fp in file_paths), f"Invalid sample range {sample_range}" + + return file_paths + + +def get_covid_if_dataset( + path: Union[os.PathLike, str], + patch_shape: Tuple[int, int], + sample_range: Optional[Tuple[int, int]] = None, + target: str = "cells", + download: bool = False, + offsets: Optional[List[List[int]]] = None, + boundaries: bool = False, + binary: bool = False, + **kwargs +) -> Dataset: + """Get the Covid-IF dataset for segmenting nuclei or cells in immunofluorescence microscopy. + + Args: + path: Filepath to a folder where the downloaded data will be saved. + patch_shape: The patch shape to use for training. + sample_range: Id range of samples to load from the training dataset. + target: The segmentation task. Either 'cells' or 'nuclei'. + download: Whether to download the data if it is not present. + offsets: Offset values for affinity computation used as target. + boundaries: Whether to compute boundaries as the target. + binary: Whether to use a binary segmentation target. + kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`. + + Returns: + The segmentation dataset. 
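A hedged sketch for the Covid-IF loader defined below this dataset function; the arguments follow the signatures in this file, while the path and patch shape are placeholders:

```python
from torch_em.data.datasets.light_microscopy.covid_if import get_covid_if_loader

loader = get_covid_if_loader(
    path="./data/covid_if",
    patch_shape=(512, 512),
    batch_size=2,
    sample_range=(0, 10),   # only use the first ten images
    target="cells",         # or "nuclei"
    boundaries=True,        # add a boundary channel to the target
    download=True,
)
```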
+ """ + available_targets = ("cells", "nuclei") + # TODO also support infected_cells + # available_targets = ("cells", "nuclei", "infected_cells") + + if target == "cells": + raw_key = "raw/serum_IgG/s0" + label_key = "labels/cells/s0" + elif target == "nuclei": + raw_key = "raw/nuclei/s0" + label_key = "labels/nuclei/s0" + else: + raise ValueError(f"{target} not found in {available_targets}") + + file_paths = get_covid_if_paths(path, sample_range, download) + + kwargs, _ = util.add_instance_label_transform( + kwargs, add_binary_target=True, binary=binary, boundaries=boundaries, offsets=offsets + ) + kwargs = util.update_kwargs(kwargs, "ndim", 2) + + return torch_em.default_segmentation_dataset( + raw_paths=file_paths, + raw_key=raw_key, + label_paths=file_paths, + label_key=label_key, + patch_shape=patch_shape, + **kwargs + ) + + +def get_covid_if_loader( + path: Union[os.PathLike, str], + patch_shape: Tuple[int, int], + batch_size: int, + sample_range: Optional[Tuple[int, int]] = None, + target: str = "cells", + download: bool = False, + offsets: Optional[List[List[int]]] = None, + boundaries: bool = False, + binary: bool = False, + **kwargs +) -> DataLoader: + """Get the Covid-IF dataloder for segmenting nuclei or cells in immunofluorescence microscopy. + + Args: + path: Filepath to a folder where the downloaded data will be saved. + patch_shape: The patch shape to use for training. + batch_size: The batch size for training. + sample_range: Id range of samples to load from the training dataset. + target: The segmentation task. Either 'cells' or 'nuclei'. + download: Whether to download the data if it is not present. + offsets: Offset values for affinity computation used as target. + boundaries: Whether to compute boundaries as the target. + binary: Whether to use a binary segmentation target. + kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader. + + Returns: + The DataLoader. + """ + ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs) + dataset = get_covid_if_dataset( + path, patch_shape, sample_range=sample_range, target=target, download=download, + offsets=offsets, boundaries=boundaries, binary=binary, **ds_kwargs, + ) + return torch_em.get_data_loader(dataset, batch_size=batch_size, **loader_kwargs) diff --git a/torch_em/data/datasets/ctc.py b/torch_em/data/datasets/light_microscopy/ctc.py similarity index 52% rename from torch_em/data/datasets/ctc.py rename to torch_em/data/datasets/light_microscopy/ctc.py index d5dfe518..d6ed6e16 100644 --- a/torch_em/data/datasets/ctc.py +++ b/torch_em/data/datasets/light_microscopy/ctc.py @@ -1,9 +1,19 @@ +"""The Cell Tracking Challenge contains annotated data for cell segmentation and tracking. +We currently provide the 2d datasets with segmentation annotations. + +If you use this data in your research please cite https://doi.org/10.1038/nmeth.4473. +""" + import os from glob import glob from shutil import copyfile +from typing import Optional, Tuple, Union + +from torch.utils.data import Dataset, DataLoader import torch_em -from . import util + +from .. 
 
 
 CTC_CHECKSUMS = {
@@ -34,7 +44,7 @@
 }
 
 
-def get_ctc_url_and_checksum(dataset_name, split):
+def _get_ctc_url_and_checksum(dataset_name, split):
     if split == "train":
         _link_to_split = "training-datasets"
     else:
@@ -45,7 +55,21 @@ def get_ctc_url_and_checksum(dataset_name, split):
     return url, checksum
 
 
-def _require_ctc_dataset(path, dataset_name, download, split):
+def get_ctc_segmentation_data(
+    path: Union[os.PathLike, str], dataset_name: str, split: str, download: bool = False,
+) -> str:
+    f"""Download training data from the Cell Tracking Challenge.
+
+    Args:
+        path: Filepath to a folder where the downloaded data will be saved.
+        dataset_name: Name of the dataset to be downloaded. The available datasets are:
+            {', '.join(CTC_CHECKSUMS['train'].keys())}
+        split: The split to download. Either 'train' or 'test'.
+        download: Whether to download the data if it is not present.
+
+    Returns:
+        The filepath to the training data.
+    """
     dataset_names = list(CTC_CHECKSUMS["train"].keys())
     if dataset_name not in dataset_names:
         raise ValueError(f"Invalid dataset: {dataset_name}, choose one of {dataset_names}.")
@@ -56,7 +80,7 @@ def _require_ctc_dataset(path, dataset_name, download, split):
         return data_path
 
     os.makedirs(data_path)
-    url, checksum = get_ctc_url_and_checksum(dataset_name, split)
+    url, checksum = _get_ctc_url_and_checksum(dataset_name, split)
     zip_path = os.path.join(path, f"{dataset_name}.zip")
     util.download_source(zip_path, url, download, checksum=checksum)
     util.unzip(zip_path, os.path.join(path, split), remove=True)
@@ -96,24 +120,28 @@ def _require_gt_images(data_path, vol_ids):
     return image_paths, label_paths
 
 
-def get_ctc_segmentation_dataset(
-    path,
-    dataset_name,
-    patch_shape,
-    split="train",
-    vol_id=None,
-    download=False,
-    **kwargs,
-):
-    """Dataset for the cell tracking challenge segmentation data.
-
-    This dataset provides access to the 2d segmentation datsets of the
-    cell tracking challenge. If you use this data in your research please cite
-    https://doi.org/10.1038/nmeth.4473
+def get_ctc_segmentation_paths(
+    path: Union[os.PathLike, str],
+    dataset_name: str,
+    split: str = "train",
+    vol_id: Optional[int] = None,
+    download: bool = False,
+) -> Tuple[str, str]:
+    f"""Get paths to the Cell Tracking Challenge data.
+
+    Args:
+        path: Filepath to a folder where the downloaded data will be saved.
+        dataset_name: Name of the dataset to be downloaded. The available datasets are:
+            {', '.join(CTC_CHECKSUMS['train'].keys())}
+        split: The split to download. Currently only supports 'train'.
+        vol_id: The train id to load.
+        download: Whether to download the data if it is not present.
+
+    Returns:
+        Filepath to the folder where image data is stored.
+        Filepath to the folder where label data is stored.
     """
-    assert split in ["train"]
-
-    data_path = _require_ctc_dataset(path, dataset_name, download, split)
+    data_path = get_ctc_segmentation_data(path, dataset_name, split, download)
 
     if vol_id is None:
         vol_ids = glob(os.path.join(data_path, "*_GT"))
@@ -123,32 +151,78 @@ def get_ctc_segmentation_dataset(
         vol_ids = vol_id
 
     image_path, label_path = _require_gt_images(data_path, vol_ids)
+    return image_path, label_path
+
+
+def get_ctc_segmentation_dataset(
+    path: Union[os.PathLike, str],
+    dataset_name: str,
+    patch_shape: Tuple[int, int, int],
+    split: str = "train",
+    vol_id: Optional[int] = None,
+    download: bool = False,
+    **kwargs,
+) -> Dataset:
+    f"""Get the CTC dataset for cell segmentation.
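A sketch of using the CTC loader defined further down. The dataset name 'DIC-C2DH-HeLa' is an assumption here, since the keys of `CTC_CHECKSUMS` are not shown in this excerpt; any key of `CTC_CHECKSUMS['train']` should work:

```python
from torch_em.data.datasets.light_microscopy.ctc import get_ctc_segmentation_loader

loader = get_ctc_segmentation_loader(
    path="./data/ctc",
    dataset_name="DIC-C2DH-HeLa",   # assumed name; use a key of CTC_CHECKSUMS["train"]
    patch_shape=(1, 512, 512),      # single z-slice patches, since ndim is set to 2
    batch_size=4,
    split="train",
    download=True,
)
```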
+ + Args: + path: Filepath to a folder where the downloaded data will be saved. + dataset_name: Name of the dataset to be downloaded. The available datasets are: + {', '.join(CTC_CHECKSUMS['train'].keys())} + patch_shape: The patch shape to use for training. + split: The split to download. Currently only supports 'train'. + vol_id: The train id to load. + download: Whether to download the data if it is not present. + kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`. + + Returns: + The segmentation dataset. + """ + assert split in ["train"] + + image_path, label_path = get_ctc_segmentation_paths(path, dataset_name, split, vol_id, download) kwargs = util.update_kwargs(kwargs, "ndim", 2) + return torch_em.default_segmentation_dataset( - image_path, "*.tif", label_path, "*.tif", patch_shape, is_seg_dataset=True, **kwargs + raw_paths=image_path, + raw_key="*.tif", + label_paths=label_path, + label_key="*.tif", + patch_shape=patch_shape, + is_seg_dataset=True, + **kwargs ) def get_ctc_segmentation_loader( - path, - dataset_name, - patch_shape, - batch_size, - split="train", - vol_id=None, - download=False, + path: Union[os.PathLike, str], + dataset_name: str, + patch_shape: Tuple[int, int, int], + batch_size: int, + split: str = "train", + vol_id: Optional[int] = None, + download: bool = False, **kwargs, -): - """Dataloader for cell tracking challenge segmentation data. - See 'get_ctc_segmentation_dataset' for details. +) -> DataLoader: + f"""Get the CTC dataloader for cell segmentation. + + Args: + path: Filepath to a folder where the downloaded data will be saved. + dataset_name: Name of the dataset to be downloaded. The available datasets are: + {', '.join(CTC_CHECKSUMS['train'].keys())} + patch_shape: The patch shape to use for training. + batch_size: The batch size for training. + split: The split to download. Currently only supports 'train'. + vol_id: The train id to load. + download: Whether to download the data if it is not present. + kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader. + + Returns: + The DataLoader. """ - ds_kwargs, loader_kwargs = util.split_kwargs( - torch_em.default_segmentation_dataset, **kwargs - ) + ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs) dataset = get_ctc_segmentation_dataset( path, dataset_name, patch_shape, split=split, vol_id=vol_id, download=download, **ds_kwargs, ) - - loader = torch_em.get_data_loader(dataset, batch_size, **loader_kwargs) - return loader + return torch_em.get_data_loader(dataset, batch_size, **loader_kwargs) diff --git a/torch_em/data/datasets/deepbacs.py b/torch_em/data/datasets/light_microscopy/deepbacs.py similarity index 50% rename from torch_em/data/datasets/deepbacs.py rename to torch_em/data/datasets/light_microscopy/deepbacs.py index 4149128e..24be8468 100644 --- a/torch_em/data/datasets/deepbacs.py +++ b/torch_em/data/datasets/light_microscopy/deepbacs.py @@ -1,14 +1,25 @@ +"""DeepBacs is a dataset for segmenting bacteria in label-free light microscopy. + +This dataset is from the publication https://doi.org/10.1038/s42003-022-03634-z. +Please cite it if you use this dataset in your research. +""" + import os import shutil -import numpy as np from glob import glob +from typing import Tuple, Union + +import numpy as np + +from torch.utils.data import Dataset, DataLoader import torch_em -from . import util + +from .. 
import util URLS = { - "s_aureus": "https://zenodo.org/record/5550933/files/DeepBacs_Data_Segmentation_Staph_Aureus_dataset.zip?download=1", - "e_coli": "https://zenodo.org/record/5550935/files/DeepBacs_Data_Segmentation_E.coli_Brightfield_dataset.zip?download=1", + "s_aureus": "https://zenodo.org/record/5550933/files/DeepBacs_Data_Segmentation_Staph_Aureus_dataset.zip?download=1", # noqa + "e_coli": "https://zenodo.org/record/5550935/files/DeepBacs_Data_Segmentation_E.coli_Brightfield_dataset.zip?download=1", # noqa "b_subtilis": "https://zenodo.org/record/5639253/files/Multilabel_U-Net_dataset_B.subtilis.zip?download=1", "mixed": "https://zenodo.org/record/5551009/files/DeepBacs_Data_Segmentation_StarDist_MIXED_dataset.zip?download=1", } @@ -20,18 +31,6 @@ } -def _require_deebacs_dataset(path, bac_type, download): - os.makedirs(path, exist_ok=True) - - zip_path = os.path.join(path, f"{bac_type}.zip") - if not os.path.exists(zip_path): - util.download_source(zip_path, URLS[bac_type], download, checksum=CHECKSUMS[bac_type]) - util.unzip(zip_path, os.path.join(path, bac_type)) - - # let's get a val split for the expected bacteria type - _assort_val_set(path, bac_type) - - def _assort_val_set(path, bac_type): image_paths = glob(os.path.join(path, bac_type, "training", "source", "*")) image_paths = [os.path.split(_path)[-1] for _path in image_paths] @@ -71,7 +70,54 @@ def _assort_val_set(path, bac_type): shutil.move(src_val_label_path, dst_val_label_path) -def _get_paths(path, bac_type, split): +def get_deepbacs_data(path: Union[os.PathLike, str], bac_type: str, download: bool) -> str: + f"""Download the DeepBacs training data. + + Args: + path: Filepath to a folder where the downloaded data will be saved. + bac_type: The bacteria type. The available types are: + {', '.join(URLS.keys())} + download: Whether to download the data if it is not present. + + Returns: + The filepath to the training data. + """ + bac_types = list(URLS.keys()) + assert bac_type in bac_types, f"{bac_type} is not in expected bacteria types: {bac_types}" + + data_folder = os.path.join(path, bac_type) + if os.path.exists(data_folder): + return data_folder + + os.makedirs(path, exist_ok=True) + zip_path = os.path.join(path, f"{bac_type}.zip") + if not os.path.exists(zip_path): + util.download_source(zip_path, URLS[bac_type], download, checksum=CHECKSUMS[bac_type]) + util.unzip(zip_path, os.path.join(path, bac_type)) + + # Get a val split for the expected bacteria type. + _assort_val_set(path, bac_type) + return data_folder + + +def get_deepbacs_paths( + path: Union[os.PathLike, str], bac_type: str, split: str, download: bool = False +) -> Tuple[str, str]: + f"""Get paths to the DeepBacs data. + + Args: + path: Filepath to a folder where the downloaded data will be saved. + split: The split to use for the dataset. Either 'train', 'val' or 'test'. + bac_type: The bacteria type. The available types are: + {', '.join(URLS.keys())} + download: Whether to download the data if it is not present. + + Returns: + Filepath to the folder where image data is stored. + Filepath to the folder where label data is stored. 
+ """ + get_deepbacs_data(path, bac_type, download) + # the bacteria types other than mixed are a bit more complicated so we don't have the dataloaders for them yet # mixed is the combination of all other types if split == "train": @@ -81,39 +127,73 @@ def _get_paths(path, bac_type, split): if bac_type != "mixed": raise NotImplementedError(f"Currently only the bacteria type 'mixed' is supported, not {bac_type}") + image_folder = os.path.join(path, bac_type, dir_choice, "source") label_folder = os.path.join(path, bac_type, dir_choice, "target") + return image_folder, label_folder def get_deepbacs_dataset( - path, split, patch_shape, bac_type="mixed", download=False, **kwargs -): - """Dataset for the segmentation of bacteria in light microscopy. - - This dataset is from the publication https://doi.org/10.1038/s42003-022-03634-z. - Please cite it if you use this dataset for a publication. + path: Union[os.PathLike, str], + split: str, + patch_shape: Tuple[int, int], + bac_type: str = "mixed", + download: bool = False, + **kwargs +) -> Dataset: + f"""Get the DeepBacs dataset for bacteria segmentation. + + Args: + path: Filepath to a folder where the downloaded data will be saved. + split: The split to use for the dataset. Either 'train', 'val' or 'test'. + patch_shape: The patch shape to use for training. + bac_type: The bacteria type. The available types are: + {', '.join(URLS.keys())} + download: Whether to download the data if it is not present. + kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`. + + Returns: + The segmentation dataset. """ assert split in ("train", "val", "test") - bac_types = list(URLS.keys()) - assert bac_type in bac_types, f"{bac_type} is not in expected bacteria types: {bac_types}" - - data_folder = os.path.join(path, bac_type) - if not os.path.exists(data_folder): - _require_deebacs_dataset(path, bac_type, download) - image_folder, label_folder = _get_paths(path, bac_type, split) + image_folder, label_folder = get_deepbacs_paths(path, bac_type, split, download) - dataset = torch_em.default_segmentation_dataset( - image_folder, "*.tif", label_folder, "*.tif", patch_shape=patch_shape, **kwargs + return torch_em.default_segmentation_dataset( + raw_paths=image_folder, + raw_key="*.tif", + label_paths=label_folder, + label_key="*.tif", + patch_shape=patch_shape, + **kwargs ) - return dataset -def get_deepbacs_loader(path, split, patch_shape, batch_size, bac_type="mixed", download=False, **kwargs): - """Dataloader for the segmentation of bacteria in light microscopy. See 'get_deepbacs_dataset' for details. +def get_deepbacs_loader( + path: Union[os.PathLike, str], + split: str, + patch_shape: Tuple[int, int], + batch_size: int, + bac_type: str = "mixed", + download: bool = False, + **kwargs +) -> DataLoader: + f"""Get the DeepBacs dataset for bacteria segmentation. + + Args: + path: Filepath to a folder where the downloaded data will be saved. + split: The split to use for the dataset. Either 'train', 'val' or 'test'. + patch_shape: The patch shape to use for training. + batch_size: The batch size for training. + bac_type: The bacteria type. The available types are: + {', '.join(URLS.keys())} + download: Whether to download the data if it is not present. + kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader. + + Returns: + The DataLoader. 
""" ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs) dataset = get_deepbacs_dataset(path, split, patch_shape, bac_type=bac_type, download=download, **ds_kwargs) - loader = torch_em.get_data_loader(dataset, batch_size, **loader_kwargs) - return loader + return torch_em.get_data_loader(dataset, batch_size, **loader_kwargs) diff --git a/torch_em/data/datasets/light_microscopy/dic_hepg2.py b/torch_em/data/datasets/light_microscopy/dic_hepg2.py new file mode 100644 index 00000000..30b5160e --- /dev/null +++ b/torch_em/data/datasets/light_microscopy/dic_hepg2.py @@ -0,0 +1,190 @@ +"""This dataset ontains annotation for cell segmentation in +differential interference contrast (DIC) microscopy images. + +This dataset is from the publication https://doi.org/10.1016/j.compbiomed.2024.109151. +Please cite it if you use this dataset in your research. +""" + +import os +from tqdm import tqdm +from glob import glob +from pathlib import Path +from natsort import natsorted +from typing import Union, Literal, Tuple, Optional, List + +import imageio.v3 as imageio + +from torch.utils.data import Dataset, DataLoader + +import torch_em + +try: + from pycocotools.coco import COCO +except ImportError: + COCO = None + +from .. import util +from .livecell import _annotations_to_instances + + +URL = "https://zenodo.org/records/13120679/files/2021-11-15_HepG2_Calcein_AM.zip" +CHECKSUM = "42b939d01c5fc2517dc3ad34bde596ac38dbeba2a96173f37e1b6dfe14cbe3a2" + + +def get_dic_hepg2_data(path: Union[str, os.PathLike], download: bool = False) -> str: + """Download the DIC HepG2 dataset. + + Args: + path: Filepath to a folder where the downloaded data will be stored. + download: Whether to download the data if it is not present. + + Returns: + The path to the folder where data is stored. + """ + if os.path.exists(path): + return path + + os.makedirs(path, exist_ok=True) + zip_path = os.path.join(path, "2021-11-15_HepG2_Calcein_AM.zip") + util.download_source(zip_path, URL, download, CHECKSUM) + util.unzip(zip_path, path, True) + + return path + + +def _create_segmentations_from_coco_annotation(path, split): + assert COCO is not None, "pycocotools is required for processing the LiveCELL ground-truth." + + base_dir = os.path.join(path, "2021-11-15_HepG2_Calcein_AM", "coco_format", split) + image_folder = os.path.join(base_dir, "images") + gt_folder = os.path.join(base_dir, "annotations") + if os.path.exists(gt_folder): + return image_folder, gt_folder + + os.makedirs(gt_folder, exist_ok=True) + + ann_file = os.path.join(base_dir, "annotations.json") + assert os.path.exists(ann_file) + coco = COCO(ann_file) + category_ids = coco.getCatIds(catNms=["cell"]) + image_ids = coco.getImgIds(catIds=category_ids) + + for image_id in tqdm( + image_ids, desc="Creating DIC HepG2 segmentations from coco-style annotations" + ): + image_metadata = coco.loadImgs(image_id)[0] + fname = image_metadata["file_name"] + + gt_path = os.path.join(gt_folder, Path(fname).with_suffix(".tif")) + + gt = _annotations_to_instances(coco, image_metadata, category_ids) + imageio.imwrite(gt_path, gt, compression="zlib") + + return image_folder, gt_folder + + +def get_dic_hepg2_paths( + path: Union[os.PathLike, str], split: str, download: bool = False +) -> Tuple[List[str], List[str]]: + """Get paths to DIC HepG2 data. + + Args: + path: Filepath to a folder where the downloaded data will be saved. + split: The data split to use. Either 'train', 'val' or 'test'. 
+        download: Whether to download the data if it is not present.
+
+    Returns:
+        List of filepaths for the image data.
+        List of filepaths for the label data.
+    """
+    path = get_dic_hepg2_data(path=path, download=download)
+
+    image_folder, gt_folder = _create_segmentations_from_coco_annotation(path=path, split=split)
+    gt_paths = natsorted(glob(os.path.join(gt_folder, "*.tif")))
+    image_paths = [os.path.join(image_folder, f"{Path(gt_path).stem}.png") for gt_path in gt_paths]
+
+    return image_paths, gt_paths
+
+
+def get_dic_hepg2_dataset(
+    path: Union[str, os.PathLike],
+    patch_shape: Tuple[int, int],
+    split: Literal["train", "val", "test"],
+    offsets: Optional[List[List[int]]] = None,
+    boundaries: bool = False,
+    binary: bool = False,
+    download: bool = False,
+    **kwargs
+) -> Dataset:
+    """Get the DIC HepG2 dataset for segmenting cells in differential interference contrast microscopy.
+
+    Args:
+        path: Filepath to a folder where the downloaded data will be saved.
+        patch_shape: The patch shape to use for training.
+        split: The data split to use. Either 'train', 'val' or 'test'.
+        offsets: Offset values for affinity computation used as target.
+        boundaries: Whether to compute boundaries as the target.
+        binary: Whether to use a binary segmentation target.
+        download: Whether to download the data if it is not present.
+        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.
+
+    Returns:
+        The segmentation dataset.
+    """
+    image_paths, gt_paths = get_dic_hepg2_paths(path=path, split=split, download=download)
+
+    kwargs = util.ensure_transforms(ndim=2, **kwargs)
+    kwargs, _ = util.add_instance_label_transform(
+        kwargs, add_binary_target=True, offsets=offsets, boundaries=boundaries, binary=binary
+    )
+
+    return torch_em.default_segmentation_dataset(
+        raw_paths=image_paths,
+        raw_key=None,
+        label_paths=gt_paths,
+        label_key=None,
+        patch_shape=patch_shape,
+        is_seg_dataset=False,
+        **kwargs
+    )
+
+
+def get_dic_hepg2_loader(
+    path: Union[str, os.PathLike],
+    split: Literal['train', 'val', 'test'],
+    patch_shape: Tuple[int, int],
+    batch_size: int,
+    offsets: Optional[List[List[int]]] = None,
+    boundaries: bool = False,
+    binary: bool = False,
+    download: bool = False,
+    **kwargs
+) -> DataLoader:
+    """Get the DIC HepG2 dataloader for segmenting cells in differential interference contrast microscopy.
+
+    Args:
+        path: Filepath to a folder where the downloaded data will be saved.
+        split: The data split to use. Either 'train', 'val' or 'test'.
+        patch_shape: The patch shape to use for training.
+        batch_size: The batch size for training.
+        offsets: Offset values for affinity computation used as target.
+        boundaries: Whether to compute boundaries as the target.
+        binary: Whether to use a binary segmentation target.
+        download: Whether to download the data if it is not present.
+        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader.
+
+    Returns:
+        The DataLoader.
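A hedged sketch for the DIC HepG2 loader above; note that pycocotools must be installed so that the ground-truth can be created from the coco-style annotations:

```python
from torch_em.data.datasets.light_microscopy.dic_hepg2 import get_dic_hepg2_loader

loader = get_dic_hepg2_loader(
    path="./data/dic_hepg2",   # placeholder directory
    split="train",
    patch_shape=(512, 512),
    batch_size=4,
    boundaries=True,           # optional boundary target
    download=True,
)
```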
+ """ + ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs) + dataset = get_dic_hepg2_dataset( + path=path, + patch_shape=patch_shape, + split=split, + offsets=offsets, + boundaries=boundaries, + binary=binary, + download=download, + **ds_kwargs + ) + return torch_em.get_data_loader(dataset=dataset, batch_size=batch_size, **loader_kwargs) diff --git a/torch_em/data/datasets/light_microscopy/dsb.py b/torch_em/data/datasets/light_microscopy/dsb.py new file mode 100644 index 00000000..be18fd86 --- /dev/null +++ b/torch_em/data/datasets/light_microscopy/dsb.py @@ -0,0 +1,165 @@ +"""This Dataset was used in a Kaggle Data Science Bowl. It contains light microscopy +images with annotations for nucleus segmentation. + +The dataset is described in the publication https://doi.org/10.1038/s41592-019-0612-7. +Please cite it if you use this dataset in your research. +""" + +import os +from shutil import move +from typing import List, Optional, Tuple, Union + +from torch.utils.data import Dataset, DataLoader + +import torch_em + +from .. import util + + +DSB_URLS = { + "full": "", # TODO + "reduced": "https://github.com/stardist/stardist/releases/download/0.1.0/dsb2018.zip" +} +CHECKSUMS = { + "full": None, + "reduced": "e44921950edce378063aa4457e625581ba35b4c2dbd9a07c19d48900129f386f" +} + + +def get_dsb_data(path: Union[os.PathLike, str], source: str, download: bool) -> str: + """Download the DSB training data. + + Args: + path: Filepath to a folder where the downloaded data will be saved. + source: The source of the dataset. Can either be 'full' for the complete dataset, + or 'reduced' for the dataset excluding histopathology images. + download: Whether to download the data if it is not present. + + Returns: + The filepath to the training data. + """ + os.makedirs(path, exist_ok=True) + url = DSB_URLS[source] + checksum = CHECKSUMS[source] + + train_out_path = os.path.join(path, "train") + test_out_path = os.path.join(path, "test") + + if os.path.exists(train_out_path) and os.path.exists(test_out_path): + return path + + zip_path = os.path.join(path, "dsb.zip") + util.download_source(zip_path, url, download, checksum) + util.unzip(zip_path, path, True) + + move(os.path.join(path, "dsb2018", "train"), train_out_path) + move(os.path.join(path, "dsb2018", "test"), test_out_path) + return path + + +def get_dsb_paths(path: Union[os.PathLike, str], split: str, source: str, download: bool = False) -> Tuple[str, str]: + """Get paths to the DSB data. + + Args: + path: Filepath to a folder where the downloaded data will be saved. + split: The split to use for the dataset. Either 'train' or 'test'. + source: The source of the dataset. Can either be 'full' for the complete dataset, + or 'reduced' for the dataset excluding histopathology images. + download: Whether to download the data if it is not present. + + Returns: + Filepath for the folder where the images are stored. + Filepath for the folder where the labels are stored. + """ + get_dsb_data(path, source, download) + + image_path = os.path.join(path, split, "images") + label_path = os.path.join(path, split, "masks") + + return image_path, label_path + + +def get_dsb_dataset( + path: Union[os.PathLike, str], + split: str, + patch_shape: Tuple[int, int], + download: bool = False, + offsets: Optional[List[List[int]]] = None, + boundaries: bool = False, + binary: bool = False, + source: str = "reduced", + **kwargs +) -> Dataset: + """Get the DSB dataset for nucleus segmentation. 
+ + Args: + path: Filepath to a folder where the downloaded data will be saved. + split: The split to use for the dataset. Either 'train' or 'test'. + patch_shape: The patch shape to use for training. + download: Whether to download the data if it is not present. + offsets: Offset values for affinity computation used as target. + boundaries: Whether to compute boundaries as the target. + binary: Whether to use a binary segmentation target. + source: The source of the dataset. Can either be 'full' for the complete dataset, + or 'reduced' for the dataset excluding histopathology images. + kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`. + + Returns: + The segmentation dataset. + """ + assert split in ("test", "train"), split + + image_path, label_path = get_dsb_paths(path, split, source, download) + + kwargs, _ = util.add_instance_label_transform( + kwargs, add_binary_target=True, binary=binary, boundaries=boundaries, offsets=offsets + ) + kwargs = util.update_kwargs(kwargs, "ndim", 2) + + return torch_em.default_segmentation_dataset( + raw_paths=image_path, + raw_key="*.tif", + label_paths=label_path, + label_key="*.tif", + patch_shape=patch_shape, + **kwargs + ) + + +def get_dsb_loader( + path: Union[os.PathLike, str], + split: str, + patch_shape: Tuple[int, int], + batch_size: int, + download: bool = False, + offsets: Optional[List[List[int]]] = None, + boundaries: bool = False, + binary: bool = False, + source: str = "reduced", + **kwargs +) -> DataLoader: + """Get the DSB dataloader for nucleus segmentation. + + Args: + path: Filepath to a folder where the downloaded data will be saved. + split: The split to use for the dataset. Either 'train' or 'test'. + patch_shape: The patch shape to use for training. + batch_size: The batch size for training. + download: Whether to download the data if it is not present. + offsets: Offset values for affinity computation used as target. + boundaries: Whether to compute boundaries as the target. + binary: Whether to use a binary segmentation target. + source: The source of the dataset. Can either be 'full' for the complete dataset, + or 'reduced' for the dataset excluding histopathology images. + kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader. + + Returns: + The DataLoader. + """ + ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs) + dataset = get_dsb_dataset( + path, split, patch_shape, download=download, + offsets=offsets, boundaries=boundaries, binary=binary, + source=source, **ds_kwargs, + ) + return torch_em.get_data_loader(dataset, batch_size, **loader_kwargs) diff --git a/torch_em/data/datasets/light_microscopy/dynamicnuclearnet.py b/torch_em/data/datasets/light_microscopy/dynamicnuclearnet.py new file mode 100644 index 00000000..e88fa085 --- /dev/null +++ b/torch_em/data/datasets/light_microscopy/dynamicnuclearnet.py @@ -0,0 +1,172 @@ +"""The DynamicNuclearNet dataset contains annotations for nucleus segmentation +and tracking in fluorescence light microscopy, for five different cell lines. + +This dataset is from the publication https://doi.org/10.1101/803205. +Please cite it if you use this dataset for your research. + +This dataset cannot be downloaded automatically, please visit https://datasets.deepcell.org/data +and download it yourself. 
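A sketch for the DSB loader above. Only `source="reduced"` has a download URL in this PR; the URL for 'full' is still marked TODO:

```python
from torch_em.data.datasets.light_microscopy.dsb import get_dsb_loader

loader = get_dsb_loader(
    path="./data/dsb",    # placeholder directory
    split="train",
    patch_shape=(256, 256),
    batch_size=8,
    source="reduced",     # "full" has no download URL yet
    download=True,
)
```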
+""" + +import os +from tqdm import tqdm +from glob import glob +from typing import Tuple, Union, List + +import numpy as np +import pandas as pd + +from torch.utils.data import Dataset, DataLoader + +import torch_em + +from .. import util + + +def _create_split(path, split): + import z5py + + split_file = os.path.join(path, "DynamicNuclearNet-segmentation-v1_0", f"{split}.npz") + split_folder = os.path.join(path, split) + os.makedirs(split_folder, exist_ok=True) + data = np.load(split_file, allow_pickle=True) + + x, y = data["X"], data["y"] + metadata = data["meta"] + metadata = pd.DataFrame(metadata[1:], columns=metadata[0]) + + for i, (im, label) in tqdm(enumerate(zip(x, y)), total=len(x), desc=f"Creating files for {split}-split"): + out_path = os.path.join(split_folder, f"image_{i:04}.zarr") + image_channel = im[..., 0] + label_channel = label[..., 0] + chunks = image_channel.shape + with z5py.File(out_path, "a") as f: + f.create_dataset("raw", data=image_channel, compression="gzip", chunks=chunks) + f.create_dataset("labels", data=label_channel, compression="gzip", chunks=chunks) + + os.remove(split_file) + + +def _create_dataset(path, zip_path): + util.unzip(zip_path, path, remove=False) + splits = ["train", "val", "test"] + assert all( + [os.path.exists(os.path.join(path, "DynamicNuclearNet-segmentation-v1_0", f"{split}.npz")) for split in splits] + ) + for split in splits: + _create_split(path, split) + + +def get_dynamicnuclearnet_data( + path: Union[os.PathLike, str], + split: str, + download: bool = False, +) -> str: + """Download the DynamicNuclearNet dataset. + + NOTE: Automatic download is not supported for DynamicNuclearnet dataset. + Please download the dataset from https://datasets.deepcell.org/data. + + Args: + path: Filepath to a folder where the downloaded data will be saved. + split: The split to use for the dataset. Either 'train', 'val' or 'test'. + download: Whether to download the data if it is not present. + + Returns: + The path where inputs are stored per split. + """ + splits = ["train", "val", "test"] + assert split in splits + + # check if the dataset exists already + zip_path = os.path.join(path, "DynamicNuclearNet-segmentation-v1_0.zip") + if all([os.path.exists(os.path.join(path, split)) for split in splits]): # yes it does + pass + elif os.path.exists(zip_path): # no it does not, but we have the zip there and can unpack it + _create_dataset(path, zip_path) + else: + raise RuntimeError( + "We do not support automatic download for the dynamic nuclear net dataset yet. " + f"Please download the dataset from https://datasets.deepcell.org/data and put it here: {zip_path}" + ) + + split_folder = os.path.join(path, split) + return split_folder + + +def get_dynamicnuclearnet_paths(path: Union[os.PathLike, str], split: str, download: bool = False) -> List[str]: + """Get paths to the DynamicNuclearNet data. + + Args: + path: Filepath to a folder where the downloaded data will be saved. + split: The split to use for the dataset. Either 'train', 'val' or 'test'. + download: Whether to download the data if it is not present. + + Returns: + List of filepaths for the stored data. 
+ """ + split_folder = get_dynamicnuclearnet_data(path, split, download) + assert os.path.exists(split_folder) + data_paths = glob(os.path.join(split_folder, "*.zarr")) + assert len(data_paths) > 0 + + return data_paths + + +def get_dynamicnuclearnet_dataset( + path: Union[os.PathLike, str], + split: str, + patch_shape: Tuple[int, int], + download: bool = False, + **kwargs +) -> Dataset: + """Get the DynamicNuclearNet dataset for nucleus segmentation. + + Args: + path: Filepath to a folder where the downloaded data will be saved. + split: The split to use for the dataset. Either 'train', 'val' or 'test'. + patch_shape: The patch shape to use for training. + download: Whether to download the data if it is not present. + kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`. + + Returns: + The segmentation dataset. + """ + data_paths = get_dynamicnuclearnet_paths(path, split, download) + + return torch_em.default_segmentation_dataset( + raw_paths=data_paths, + raw_key="raw", + label_paths=data_paths, + label_key="labels", + patch_shape=patch_shape, + is_seg_dataset=True, + ndim=2, + **kwargs + ) + + +def get_dynamicnuclearnet_loader( + path: Union[os.PathLike, str], + split: str, + patch_shape: Tuple[int, int], + batch_size: int, + download: bool = False, + **kwargs +) -> DataLoader: + """Get the DynamicNuclearNet dataloader for nucleus segmentation. + + Args: + path: Filepath to a folder where the downloaded data will be saved. + split: The split to use for the dataset. Either 'train', 'val' or 'test'. + patch_shape: The patch shape to use for training. + batch_size: The batch size for training. + download: Whether to download the data if it is not present. + kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader. + + Returns: + The DataLoader. + """ + ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs) + dataset = get_dynamicnuclearnet_dataset(path, split, patch_shape, download, **ds_kwargs) + return torch_em.get_data_loader(dataset, batch_size, **loader_kwargs) diff --git a/torch_em/data/datasets/light_microscopy/embedseg_data.py b/torch_em/data/datasets/light_microscopy/embedseg_data.py new file mode 100644 index 00000000..28268389 --- /dev/null +++ b/torch_em/data/datasets/light_microscopy/embedseg_data.py @@ -0,0 +1,147 @@ +"""This dataset contains annotation for 3d fluorescence microscopy segmentation +that were introduced by the EmbedSeg publication. + +This dataset is from the publication https://proceedings.mlr.press/v143/lalit21a.html. +Please cite it if you use this dataset in your research. +""" + +import os +from glob import glob +from typing import Tuple, Union, List + +from torch.utils.data import Dataset, DataLoader + +import torch_em + +from .. 
import util
+
+
+URLS = {
+    "Mouse-Organoid-Cells-CBG": "https://github.com/juglab/EmbedSeg/releases/download/v0.1.0/Mouse-Organoid-Cells-CBG.zip",  # noqa
+    "Mouse-Skull-Nuclei-CBG": "https://github.com/juglab/EmbedSeg/releases/download/v0.1.0/Mouse-Skull-Nuclei-CBG.zip",
+    "Platynereis-ISH-Nuclei-CBG": "https://github.com/juglab/EmbedSeg/releases/download/v0.1.0/Platynereis-ISH-Nuclei-CBG.zip",  # noqa
+    "Platynereis-Nuclei-CBG": "https://github.com/juglab/EmbedSeg/releases/download/v0.1.0/Platynereis-Nuclei-CBG.zip",
+}
+CHECKSUMS = {
+    "Mouse-Organoid-Cells-CBG": "3695ac340473900ace8c37fd7f3ae0d37217de9f2b86c2341f36b1727825e48b",
+    "Mouse-Skull-Nuclei-CBG": "3600ec261a48bf953820e0536cacd0bb8a5141be6e7435a4cb0fffeb0caf594e",
+    "Platynereis-ISH-Nuclei-CBG": "bc9284df6f6d691a8e81b47310d95617252cc98ebf7daeab55801b330ba921e0",
+    "Platynereis-Nuclei-CBG": "448cb7b46f2fe7d472795e05c8d7dfb40f259d94595ad2cfd256bc2aa4ab3be7",
+}
+
+
+def get_embedseg_data(path: Union[os.PathLike, str], name: str, download: bool) -> str:
+    """Download the EmbedSeg training data.
+
+    Args:
+        path: Filepath to a folder where the downloaded data will be saved.
+        name: Name of the dataset to download.
+        download: Whether to download the data if it is not present.
+
+    Returns:
+        The filepath to the training data.
+    """
+    if name not in URLS:
+        raise ValueError(f"The dataset name must be in {list(URLS.keys())}. You provided {name}.")
+
+    url = URLS[name]
+    checksum = CHECKSUMS[name]
+
+    data_path = os.path.join(path, name)
+    if os.path.exists(data_path):
+        return data_path
+
+    os.makedirs(path, exist_ok=True)
+    zip_path = os.path.join(path, f"{name}.zip")
+    util.download_source(zip_path, url, download, checksum)
+    util.unzip(zip_path, path, True)
+
+    return data_path
+
+
+def get_embedseg_paths(
+    path: Union[os.PathLike, str], name: str, split: str, download: bool = False
+) -> Tuple[List[str], List[str]]:
+    """Get paths to the EmbedSeg data.
+
+    Args:
+        path: Filepath to a folder where the downloaded data will be saved.
+        name: Name of the dataset to download.
+        split: The split to use for the dataset.
+        download: Whether to download the data if it is not present.
+
+    Returns:
+        List of filepaths for the image data.
+        List of filepaths for the label data.
+    """
+    data_root = get_embedseg_data(path, name, download)
+
+    raw_paths = sorted(glob(os.path.join(data_root, split, "images", "*.tif")))
+    label_paths = sorted(glob(os.path.join(data_root, split, "masks", "*.tif")))
+    assert len(raw_paths) > 0
+    assert len(raw_paths) == len(label_paths)
+
+    return raw_paths, label_paths
+
+
+def get_embedseg_dataset(
+    path: Union[os.PathLike, str],
+    patch_shape: Tuple[int, int],
+    name: str,
+    split: str = "train",
+    download: bool = False,
+    **kwargs
+) -> Dataset:
+    """Get the EmbedSeg dataset for 3d fluorescence microscopy segmentation.
+
+    Args:
+        path: Filepath to a folder where the downloaded data will be saved.
+        patch_shape: The patch shape to use for training.
+        name: Name of the dataset to download.
+        split: The split to use for the dataset.
+        download: Whether to download the data if it is not present.
+        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.
+
+    Returns:
+        The segmentation dataset.
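A sketch for the EmbedSeg loader defined below; the dataset name is one of the `URLS` keys above, while the patch shape is an assumed 3d patch:

```python
from torch_em.data.datasets.light_microscopy.embedseg_data import get_embedseg_loader

loader = get_embedseg_loader(
    path="./data/embedseg",
    patch_shape=(32, 256, 256),      # assumed z, y, x patch
    batch_size=1,
    name="Mouse-Skull-Nuclei-CBG",   # any key of URLS works
    split="train",
    download=True,
)
```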
+ """ + raw_paths, label_paths = get_embedseg_paths(path, name, split, download) + + return torch_em.default_segmentation_dataset( + raw_paths=raw_paths, + raw_key=None, + label_paths=label_paths, + label_key=None, + patch_shape=patch_shape, + **kwargs + ) + + +def get_embedseg_loader( + path: Union[os.PathLike, str], + patch_shape: Tuple[int, int], + batch_size: int, + name: str, + split: str = "train", + download: bool = False, + **kwargs +) -> DataLoader: + """Get the EmbedSeg dataloader for 3d fluorescence microscopy segmentation. + + Args: + path: Filepath to a folder where the downloaded data will be saved. + patch_shape: The patch shape to use for training. + batch_size: The batch size for training. + name: Name of the dataset to download. + split: The split to use for the dataset. + download: Whether to download the data if it is not present. + kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader. + + Returns: + The DataLoader. + """ + ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs) + dataset = get_embedseg_dataset( + path, name=name, split=split, patch_shape=patch_shape, download=download, **ds_kwargs, + ) + return torch_em.get_data_loader(dataset, batch_size=batch_size, **loader_kwargs) diff --git a/torch_em/data/datasets/light_microscopy/gonuclear.py b/torch_em/data/datasets/light_microscopy/gonuclear.py new file mode 100644 index 00000000..b4378112 --- /dev/null +++ b/torch_em/data/datasets/light_microscopy/gonuclear.py @@ -0,0 +1,245 @@ +"""This dataset contains annotation for nucleus segmentation in 3d fluorescence microscopy. + +This dataset is from the publication https://doi.org/10.1242/dev.202800. +Please cite it if you use this dataset in your research. +""" + +import os +from glob import glob +from shutil import rmtree +from typing import Optional, Tuple, Union, List + +import numpy as np +import imageio.v3 as imageio + +from torch.utils.data import Dataset, DataLoader + +import torch_em + +from .. 
import util
+
+
+URL = "https://www.ebi.ac.uk/biostudies/files/S-BIAD1026/Nuclei_training_segmentation/Training%20image%20dataset_Tiff%20Files.zip"  # noqa
+CHECKSUM = "b103388a4aed01c7aadb2d5f49392d2dd08dd7cbeb2357b0c56355384ebb93a9"
+
+
+def _load_tif(path):
+    vol = None
+
+    path_tif = path + ".tif"
+    if os.path.exists(path_tif):
+        vol = imageio.imread(path_tif)
+
+    path_tiff = path + ".tiff"
+    if os.path.exists(path_tiff):
+        vol = imageio.imread(path_tiff)
+
+    if vol is None:
+        raise RuntimeError(f"Can't find a tif or tiff file for {path}.")
+
+    return vol
+
+
+def _clip_shape(raw, labels):
+    shape = raw.shape
+    labels = labels[:shape[0], :shape[1], :shape[2]]
+
+    shape = labels.shape
+    raw = raw[:shape[0], :shape[1], :shape[2]]
+
+    assert labels.shape == raw.shape, f"{labels.shape}, {raw.shape}"
+    return raw, labels
+
+
+def _process_data(in_folder, out_folder):
+    import h5py
+
+    os.makedirs(out_folder, exist_ok=True)
+
+    sample_folders = glob(os.path.join(in_folder, "*"))
+    for folder in sample_folders:
+        sample = os.path.basename(folder)
+        out_path = os.path.join(out_folder, f"{sample}.h5")
+
+        cell_raw = _load_tif(os.path.join(folder, f"{sample}_cellwall"))
+        cell_labels = _load_tif(os.path.join(folder, f"{sample}_cellseg"))
+        cell_labels = cell_labels[:, ::-1]
+        cell_raw, cell_labels = _clip_shape(cell_raw, cell_labels)
+
+        nucleus_raw = _load_tif(os.path.join(folder, f"{sample}_n_H2BtdTomato"))
+        nucleus_labels = _load_tif(os.path.join(folder, f"{sample}_n_stain_StarDist_goldGT"))
+        nucleus_labels = nucleus_labels[:, ::-1]
+        nucleus_raw, nucleus_labels = _clip_shape(nucleus_raw, nucleus_labels)
+
+        # Remove last frames with artifacts for two volumes (1137 and 1170).
+        if sample in ["1137", "1170"]:
+            nucleus_raw, nucleus_labels = nucleus_raw[:-1], nucleus_labels[:-1]
+            cell_raw, cell_labels = cell_raw[:-1], cell_labels[:-1]
+
+        # Fix the cell labels for one volume (1136), which are misaligned.
+        if sample == "1136":
+            cell_labels = np.fliplr(cell_labels)
+
+        with h5py.File(out_path, "w") as f:
+            f.create_dataset("raw/cells", data=cell_raw, compression="gzip")
+            f.create_dataset("raw/nuclei", data=nucleus_raw, compression="gzip")
+
+            f.create_dataset("labels/cells", data=cell_labels, compression="gzip")
+            f.create_dataset("labels/nuclei", data=nucleus_labels, compression="gzip")
+
+
+def get_gonuclear_data(path: Union[os.PathLike, str], download: bool) -> str:
+    """Download the GoNuclear training data.
+
+    Args:
+        path: Filepath to a folder where the downloaded data will be saved.
+        download: Whether to download the data if it is not present.
+
+    Returns:
+        The filepath to the training data.
+    """
+    data_path = os.path.join(path, "gonuclear_datasets")
+    if os.path.exists(data_path):
+        return data_path
+
+    os.makedirs(path, exist_ok=True)
+    zip_path = os.path.join(path, "gonuclear.zip")
+    util.download_source(zip_path, URL, download, CHECKSUM)
+    util.unzip(zip_path, path, True)
+
+    extracted_path = os.path.join(path, "Training image dataset_Tiff Files")
+    assert os.path.exists(extracted_path), extracted_path
+    _process_data(extracted_path, data_path)
+    assert os.path.exists(data_path)
+
+    rmtree(extracted_path)
+    return data_path
+
+
+def get_gonuclear_paths(
+    path: Union[os.PathLike, str],
+    sample_ids: Optional[Union[int, Tuple[int, ...]]] = None,
+    download: bool = False
+) -> List[str]:
+    """Get paths to the GoNuclear data.
+
+    Args:
+        path: Filepath to a folder where the downloaded data will be saved.
+        sample_ids: The sample ids to load.
The valid sample ids are: + 1135, 1136, 1137, 1139, 1170. If none is given all samples will be loaded. + download: Whether to download the data if it is not present. + + Returns: + List of filepaths for the stored data. + """ + data_root = get_gonuclear_data(path, download) + + if sample_ids is None: + paths = sorted(glob(os.path.join(data_root, "*.h5"))) + else: + paths = [] + for sample_id in sample_ids: + sample_path = os.path.join(data_root, f"{sample_id}.h5") + if not os.path.exists(sample_path): + raise ValueError(f"Invalid sample id {sample_id}.") + paths.append(sample_path) + + return paths + + +def get_gonuclear_dataset( + path: Union[os.PathLike, str], + patch_shape: Tuple[int, int], + segmentation_task: str = "nuclei", + sample_ids: Optional[Union[int, Tuple[int, ...]]] = None, + offsets: Optional[List[List[int]]] = None, + boundaries: bool = False, + binary: bool = False, + download: bool = False, + **kwargs +) -> Dataset: + """Get the GoNuclear dataset for segmenting nuclei in 3d fluorescence microscopy. + + Args: + path: Filepath to a folder where the downloaded data will be saved. + patch_shape: The patch shape to use for training. + segmentation_task: The segmentation task. Either 'nuclei' or 'cells'. + sample_ids: The sample ids to load. The valid sample ids are: + 1135, 1136, 1137, 1139, 1170. If none is given all samples will be loaded. + offsets: Offset values for affinity computation used as target. + boundaries: Whether to compute boundaries as the target. + binary: Whether to use a binary segmentation target. + download: Whether to download the data if it is not present. + kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`. + + Returns: + The segmentation dataset. + """ + paths = get_gonuclear_paths(path, sample_ids, download) + + if segmentation_task == "nuclei": + raw_key = "raw/nuclei" + label_key = "labels/nuclei" + elif segmentation_task == "cells": + raw_key = "raw/cells" + label_key = "labels/cells" + else: + raise ValueError(f"Invalid segmentation task {segmentation_task}, expect one of 'cells' or 'nuclei'.") + + kwargs, _ = util.add_instance_label_transform( + kwargs, add_binary_target=True, binary=binary, boundaries=boundaries, offsets=offsets, + ) + + return torch_em.default_segmentation_dataset( + raw_paths=paths, + raw_key=raw_key, + label_paths=paths, + label_key=label_key, + patch_shape=patch_shape, + **kwargs + ) + + +def get_gonuclear_loader( + path: Union[os.PathLike, str], + patch_shape: Tuple[int, int], + batch_size: int, + segmentation_task: str = "nuclei", + sample_ids: Optional[Union[int, Tuple[int, ...]]] = None, + offsets: Optional[List[List[int]]] = None, + boundaries: bool = False, + binary: bool = False, + download: bool = False, + **kwargs +) -> DataLoader: + """Get the GoNuclear dataloader for segmenting nuclei in 3d fluorescence microscopy. + + Args: + path: Filepath to a folder where the downloaded data will be saved. + patch_shape: The patch shape to use for training. + batch_size: The batch size for training. + segmentation_task: The segmentation task. Either 'nuclei' or 'cells'. + sample_ids: The sample ids to load. The valid sample ids are: + 1135, 1136, 1137, 1139, 1170. If none is given all samples will be loaded. + offsets: Offset values for affinity computation used as target. + boundaries: Whether to compute boundaries as the target. + binary: Whether to use a binary segmentation target. + download: Whether to download the data if it is not present. 
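A hedged sketch for the GoNuclear loader; the sample ids come from the list of valid ids given above, while the patch shape is an assumption:

```python
from torch_em.data.datasets.light_microscopy.gonuclear import get_gonuclear_loader

loader = get_gonuclear_loader(
    path="./data/gonuclear",
    patch_shape=(32, 128, 128),   # assumed 3d patch shape
    batch_size=1,
    segmentation_task="nuclei",   # or "cells"
    sample_ids=(1135, 1136),
    download=True,
)
```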
+ kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader. + + Returns: + The DataLoader. + """ + ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs) + dataset = get_gonuclear_dataset( + path=path, + patch_shape=patch_shape, + segmentation_task=segmentation_task, + sample_ids=sample_ids, + offsets=offsets, + boundaries=boundaries, + binary=binary, + download=download, + **ds_kwargs, + ) + return torch_em.get_data_loader(dataset, batch_size=batch_size, **loader_kwargs) diff --git a/torch_em/data/datasets/hpa.py b/torch_em/data/datasets/light_microscopy/hpa.py similarity index 62% rename from torch_em/data/datasets/hpa.py rename to torch_em/data/datasets/light_microscopy/hpa.py index dc54062f..bd713166 100644 --- a/torch_em/data/datasets/hpa.py +++ b/torch_em/data/datasets/light_microscopy/hpa.py @@ -1,20 +1,30 @@ +"""This dataset was part of the HPA Kaggle challenge for protein identification. +It contains confocal microscopy images and annotations for cell segmentation. + +The dataset is described in the publication https://doi.org/10.1038/s41592-019-0658-6. +Please cite it if you use this dataset in your research. +""" + import os import json import shutil +from glob import glob +from tqdm import tqdm from concurrent import futures from functools import partial -from glob import glob +from typing import List, Optional, Sequence, Tuple, Union import imageio -import h5py import numpy as np +from skimage import morphology from PIL import Image, ImageDraw from skimage import draw as skimage_draw -from skimage import morphology -from tqdm import tqdm + +from torch.utils.data import Dataset, DataLoader import torch_em -from . import util + +from .. import util URLS = { @@ -23,6 +33,7 @@ CHECKSUMS = { "segmentation": "dcd6072293d88d49c71376d3d99f3f4f102e4ee83efb0187faa89c95ec49faa9" } +VALID_CHANNELS = ["microtubules", "protein", "nuclei", "er"] def _download_hpa_data(path, name, download): @@ -74,10 +85,10 @@ def _generate_binary_masks(annot_dict, shape, erose_size=5, obj_size_rem=500, sa if save_indiv is True: mask_edge_indiv = np.zeros( - (shape[0], shape[1], len(annot_dict)), dtype=np.bool + (shape[0], shape[1], len(annot_dict)), dtype="bool" ) mask_fill_indiv = np.zeros( - (shape[0], shape[1], len(annot_dict)), dtype=np.bool + (shape[0], shape[1], len(annot_dict)), dtype="bool" ) # Image used to draw lines - for edge mask for freelines @@ -249,7 +260,9 @@ def _get_labels(annotation_file, shape, label="*"): raise RuntimeError -def _process_image(in_folder, out_path, channels, with_labels): +def _process_image(in_folder, out_path, with_labels): + import h5py + # TODO double check the default order and color matching # correspondence to the HPA kaggle data: # microtubules: red @@ -257,14 +270,9 @@ def _process_image(in_folder, out_path, channels, with_labels): # er: yellow # protein: green # default order: rgby = micro, prot, nuclei, er - all_channels = {"microtubules", "protein", "nuclei", "er"} - assert len(list(set(channels) - all_channels)) == 0 - raw = [] - for chan in channels: - im_path = os.path.join(in_folder, f"{chan}.png") - assert os.path.exists(im_path), im_path - raw.append(imageio.imread(im_path)[None]) - raw = np.concatenate(raw, axis=0) + raw = np.concatenate([ + imageio.imread(os.path.join(in_folder, f"{chan}.png"))[None] for chan in VALID_CHANNELS + ], axis=0) if with_labels: annotation_file = os.path.join(in_folder, "annotation.json") @@ -273,28 +281,35 @@ def 
_process_image(in_folder, out_path, channels, with_labels): assert labels.shape == raw.shape[1:] with h5py.File(out_path, "w") as f: - f.create_dataset("raw", data=raw, compression="gzip") + f.create_dataset("raw/microtubules", data=raw[0], compression="gzip") + f.create_dataset("raw/protein", data=raw[1], compression="gzip") + f.create_dataset("raw/nuclei", data=raw[2], compression="gzip") + f.create_dataset("raw/er", data=raw[3], compression="gzip") if with_labels: f.create_dataset("labels", data=labels, compression="gzip") -def _process_split(root_in, root_out, channels, n_workers, with_labels): +def _process_split(root_in, root_out, n_workers, with_labels): os.makedirs(root_out, exist_ok=True) inputs = glob(os.path.join(root_in, "*")) outputs = [os.path.join(root_out, f"{os.path.split(inp)[1]}.h5") for inp in inputs] - process = partial(_process_image, channels=channels, with_labels=with_labels) + process = partial(_process_image, with_labels=with_labels) with futures.ProcessPoolExecutor(n_workers) as pp: list(tqdm(pp.map(process, inputs, outputs), total=len(inputs), desc=f"Process data in {root_in}")) -# save data as h5 with 4 channel raw data and labels extracted from the geo json -def _process_hpa_data(path, channels, n_workers, remove): +# save data as h5 in 4 separate channel raw data and labels extracted from the geo json +def _process_hpa_data(path, n_workers, remove): in_path = os.path.join(path, "hpa_dataset_v2") assert os.path.exists(in_path), in_path for split in ("train", "test", "valid"): out_split = "val" if split == "valid" else split - _process_split(os.path.join(in_path, split), os.path.join(path, out_split), - channels=channels, n_workers=n_workers, with_labels=split != "test") + _process_split( + root_in=os.path.join(in_path, split), + root_out=os.path.join(path, out_split), + n_workers=n_workers, + with_labels=(split != "test") + ) if remove: shutil.rmtree(in_path) @@ -306,21 +321,77 @@ def _check_data(path): return have_train and have_test and have_val -def get_hpa_segmentation_dataset( - path, split, patch_shape, - offsets=None, boundaries=False, binary=False, - channels=["microtubules", "protein", "nuclei", "er"], - download=False, n_workers_preproc=8, **kwargs -): - """Dataset for the segmentation of cells in light microscopy. - - This dataset is from the publication https://doi.org/10.1038/s41592-019-0658-6. - Please cite it if you use this dataset for a publication. +def get_hpa_segmentation_data(path: Union[os.PathLike, str], download: bool, n_workers_preproc: int = 8) -> str: + """Download the HPA training data. + + Args: + path: Filepath to a folder where the downloaded data will be saved. + download: Whether to download the data if it is not present. + n_workers_preproc: The number of workers to use for preprocessing. + + Returns: + The filepath to the training data. """ data_is_complete = _check_data(path) if not data_is_complete: _download_hpa_data(path, "segmentation", download) - _process_hpa_data(path, channels, n_workers_preproc, remove=True) + _process_hpa_data(path, n_workers_preproc, remove=True) + return path + + +def get_hpa_segmentation_paths( + path: Union[os.PathLike, str], split: str, download: bool = False, n_workers_preproc: int = 8, +) -> List[str]: + """Get paths to the HPA data. + + Args: + path: Filepath to a folder where the downloaded data will be saved. + split: The split for the dataset. Available splits are 'train', 'val' or 'test'. + download: Whether to download the data if it is not present. 
+ n_workers_preproc: The number of workers to use for preprocessing. + + Returns: + List of filepaths to the stored data. + """ + get_hpa_segmentation_data(path, download, n_workers_preproc) + paths = glob(os.path.join(path, split, "*.h5")) + return paths + + +def get_hpa_segmentation_dataset( + path: Union[os.PathLike, str], + split: str, + patch_shape: Tuple[int, int], + offsets: Optional[List[List[int]]] = None, + boundaries: bool = False, + binary: bool = False, + channels: Sequence[str] = ["microtubules", "protein", "nuclei", "er"], + download: bool = False, + n_workers_preproc: int = 8, + **kwargs +) -> Dataset: + """Get the HPA dataset for segmenting cells in confocal microscopy. + + Args: + path: Filepath to a folder where the downloaded data will be saved. + split: The split for the dataset. Available splits are 'train', 'val' or 'test'. + patch_shape: The patch shape to use for training. + offsets: Offset values for affinity computation used as target. + boundaries: Whether to compute boundaries as the target. + binary: Whether to use a binary segmentation target. + channels: The image channels to extract. Available channels are + 'microtubules', 'protein', 'nuclei' or 'er'. + download: Whether to download the data if it is not present. + n_workers_preproc: The number of workers to use for preprocessing. + kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`. + + Returns: + The segmentation dataset. + """ + assert isinstance(channels, list), "The 'channels' argument expects the desired channel(s) in a list." + for chan in channels: + if chan not in VALID_CHANNELS: + raise ValueError(f"'{chan}' is not a valid channel for HPA dataset.") kwargs, _ = util.add_instance_label_transform( kwargs, add_binary_target=True, binary=binary, boundaries=boundaries, offsets=offsets @@ -328,29 +399,55 @@ def get_hpa_segmentation_dataset( kwargs = util.update_kwargs(kwargs, "ndim", 2) kwargs = util.update_kwargs(kwargs, "with_channels", True) - paths = glob(os.path.join(path, split, "*.h5")) - raw_key = "raw" - label_key = "labels" + paths = get_hpa_segmentation_paths(path, split, download, n_workers_preproc) - return torch_em.default_segmentation_dataset(paths, raw_key, paths, label_key, patch_shape, **kwargs) + return torch_em.default_segmentation_dataset( + raw_paths=paths, + raw_key=[f"raw/{chan}" for chan in channels], + label_paths=paths, + label_key="labels", + patch_shape=patch_shape, + **kwargs + ) def get_hpa_segmentation_loader( - path, split, patch_shape, batch_size, - offsets=None, boundaries=False, binary=False, - channels=["microtubules", "protein", "nuclei", "er"], - download=False, n_workers_preproc=8, **kwargs -): - """Dataloader for the segmentation of cells in light microscopy. See 'get_hpa_segmentation_dataset' for details. + path: Union[os.PathLike, str], + split: str, + patch_shape: Tuple[int, int], + batch_size: int, + offsets: Optional[List[List[int]]] = None, + boundaries: bool = False, + binary: bool = False, + channels: Sequence[str] = ["microtubules", "protein", "nuclei", "er"], + download: bool = False, + n_workers_preproc: int = 8, + **kwargs +) -> DataLoader: + """Get the HPA dataloader for segmenting cells in confocal microscopy. + + Args: + path: Filepath to a folder where the downloaded data will be saved. + split: The split for the dataset. Available splits are 'train', 'val' or 'test'. + patch_shape: The patch shape to use for training. + batch_size: The batch size for training. 
+        offsets: Offset values for affinity computation used as target.
+        boundaries: Whether to compute boundaries as the target.
+        binary: Whether to use a binary segmentation target.
+        channels: The image channels to extract. Available channels are
+            'microtubules', 'protein', 'nuclei' or 'er'.
+        download: Whether to download the data if it is not present.
+        n_workers_preproc: The number of workers to use for preprocessing.
+        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader.
+
+    Returns:
+        The DataLoader.
     """
-    ds_kwargs, loader_kwargs = util.split_kwargs(
-        torch_em.default_segmentation_dataset, **kwargs
-    )
+    ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
     dataset = get_hpa_segmentation_dataset(
         path, split, patch_shape,
         offsets=offsets, boundaries=boundaries, binary=binary, channels=channels,
         download=download, n_workers_preproc=n_workers_preproc,
         **ds_kwargs
     )
-    loader = torch_em.get_data_loader(dataset, batch_size, **loader_kwargs)
-    return loader
+    return torch_em.get_data_loader(dataset, batch_size, **loader_kwargs)
diff --git a/torch_em/data/datasets/livecell.py b/torch_em/data/datasets/light_microscopy/livecell.py
similarity index 56%
rename from torch_em/data/datasets/livecell.py
rename to torch_em/data/datasets/light_microscopy/livecell.py
index 86cdd2fd..d0a79449 100644
--- a/torch_em/data/datasets/livecell.py
+++ b/torch_em/data/datasets/light_microscopy/livecell.py
@@ -1,15 +1,26 @@
+"""The LIVECell dataset contains phase-contrast microscopy images
+and annotations for cell segmentation for 8 different cell lines.
+
+This dataset is described in the publication https://doi.org/10.1038/s41592-021-01249-6.
+Please cite it if you use this dataset in your research.
+"""
+
 import os
+import requests
+from tqdm import tqdm
 from shutil import copyfileobj
+from typing import List, Optional, Sequence, Tuple, Union
 
-import imageio
 import numpy as np
-import requests
-import vigra
-from tqdm import tqdm
+import imageio.v3 as imageio
+
+import torch
+from torch.utils.data import Dataset, DataLoader
 
 import torch_em
-import torch.utils.data
-from . import util
+
+from .. import util
+from ... import ImageCollectionDataset
 
 try:
     from pycocotools.coco import COCO
@@ -29,20 +40,6 @@
 CHECKSUM = None
 
 
-def _download_livecell_images(path, download):
-    os.makedirs(path, exist_ok=True)
-    image_path = os.path.join(path, "images")
-
-    if os.path.exists(image_path):
-        return
-
-    url = URLS["images"]
-    checksum = CHECKSUM
-    zip_path = os.path.join(path, "livecell.zip")
-    util.download_source(zip_path, url, download, checksum)
-    util.unzip(zip_path, path, True)
-
-
 # TODO use download flag
 def _download_annotation_file(path, split, download):
     annotation_file = os.path.join(path, f"{split}.json")
@@ -56,6 +53,8 @@
 
 
 def _annotations_to_instances(coco, image_metadata, category_ids):
+    import vigra
+
     # create and save the segmentation
     annotation_ids = coco.getAnnIds(imgIds=image_metadata["id"], catIds=category_ids)
     annotations = coco.loadAnns(annotation_ids)
@@ -123,7 +122,7 @@
         imageio.imwrite(seg_path, seg)
 
     assert len(image_paths) == len(seg_paths)
-    assert len(image_paths) > 0,\
-        f"No matching image paths were found. Did you pass invalid cell type naems ({cell_types})?"
+    assert len(image_paths) > 0, \
+        f"No matching image paths were found. Did you pass invalid cell type names ({cell_types})?"
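+    # At this point 'image_paths' and 'seg_paths' are aligned, non-empty lists
+    # (checked by the asserts above): each image file has a corresponding
+    # instance segmentation written out by this function.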
     return image_paths, seg_paths
 
 
@@ -144,46 +143,139 @@
     return _create_segmentations_from_annotations(annotation_file, image_folder, seg_folder, cell_types)
 
 
+def get_livecell_data(path: Union[os.PathLike, str], download: bool = False):
+    """Download the LIVECell dataset.
+
+    Args:
+        path: Filepath to a folder where the downloaded data will be saved.
+        download: Whether to download the data if it is not present.
+    """
+    os.makedirs(path, exist_ok=True)
+    image_path = os.path.join(path, "images")
+
+    if os.path.exists(image_path):
+        return
+
+    url = URLS["images"]
+    checksum = CHECKSUM
+    zip_path = os.path.join(path, "livecell.zip")
+    util.download_source(zip_path, url, download, checksum)
+    util.unzip(zip_path, path, True)
+
+
+def get_livecell_paths(
+    path: Union[os.PathLike, str],
+    split: str,
+    download: bool = False,
+    cell_types: Optional[Sequence[str]] = None,
+    label_path: Optional[Union[os.PathLike, str]] = None
+) -> Tuple[List[str], List[str]]:
+    """Get paths to the LIVECell data.
+
+    Args:
+        path: Filepath to a folder where the downloaded data will be saved.
+        split: The data split to use. Either 'train', 'val' or 'test'.
+        download: Whether to download the data if it is not present.
+        cell_types: The cell types for which to get the data paths.
+        label_path: Optional path for loading the label data.
+
+    Returns:
+        List of filepaths for the image data.
+        List of filepaths for the label data.
+    """
+    get_livecell_data(path, download)
+    image_paths, seg_paths = _download_livecell_annotations(path, split, download, cell_types, label_path)
+    return image_paths, seg_paths
+
+
 def get_livecell_dataset(
-    path, split, patch_shape, download=False,
-    offsets=None, boundaries=False, binary=False,
-    cell_types=None, label_path=None, label_dtype=torch.int64, **kwargs
-):
-    """Dataset for the segmentation of cells in phase-contrast microscopy.
-
-    This dataset is from the publication https://doi.org/10.1038/s41592-021-01249-6.
-    Please cite it if you use this dataset for a publication.
+    path: Union[os.PathLike, str],
+    split: str,
+    patch_shape: Tuple[int, int],
+    download: bool = False,
+    offsets: Optional[List[List[int]]] = None,
+    boundaries: bool = False,
+    binary: bool = False,
+    cell_types: Optional[Sequence[str]] = None,
+    label_path: Optional[Union[os.PathLike, str]] = None,
+    label_dtype=torch.int64,
+    **kwargs
+) -> Dataset:
+    """Get the LIVECell dataset for segmenting cells in phase-contrast microscopy.
+
+    Args:
+        path: Filepath to a folder where the downloaded data will be saved.
+        split: The data split to use. Either 'train', 'val' or 'test'.
+        patch_shape: The patch shape to use for training.
+        download: Whether to download the data if it is not present.
+        offsets: Offset values for affinity computation used as target.
+        boundaries: Whether to compute boundaries as the target.
+        binary: Whether to use a binary segmentation target.
+        cell_types: The cell types for which to get the data paths.
+        label_path: Optional path for loading the label data.
+        label_dtype: The datatype of the label data.
+        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.
+
+    Returns:
+        The segmentation dataset.
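+
+    Example:
+        A minimal usage sketch; the download folder and patch shape are illustrative choices:
+
+        >>> ds = get_livecell_dataset("./data/livecell", split="train", patch_shape=(512, 512), download=True)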
""" assert split in ("train", "val", "test") if cell_types is not None: - assert isinstance(cell_types, (list, tuple)),\ + assert isinstance(cell_types, (list, tuple)), \ f"cell_types must be passed as a list or tuple instead of {cell_types}" - _download_livecell_images(path, download) - image_paths, seg_paths = _download_livecell_annotations(path, split, download, cell_types, label_path) + image_paths, seg_paths = get_livecell_paths(path, split, download, cell_types, label_path) kwargs = util.ensure_transforms(ndim=2, **kwargs) kwargs, label_dtype = util.add_instance_label_transform( - kwargs, add_binary_target=True, label_dtype=label_dtype, - offsets=offsets, boundaries=boundaries, binary=binary + kwargs, add_binary_target=True, label_dtype=label_dtype, offsets=offsets, boundaries=boundaries, binary=binary ) - dataset = torch_em.data.ImageCollectionDataset( - image_paths, seg_paths, patch_shape=patch_shape, label_dtype=label_dtype, **kwargs + return ImageCollectionDataset( + raw_image_paths=image_paths, + label_image_paths=seg_paths, + patch_shape=patch_shape, + label_dtype=label_dtype, + **kwargs ) - return dataset def get_livecell_loader( - path, split, patch_shape, batch_size, download=False, - offsets=None, boundaries=False, binary=False, - cell_types=None, label_path=None, label_dtype=torch.int64, **kwargs -): - """Dataloader for the segmentation of cells in phase-contrast microscopy. See 'get_livecell_dataset' for details.""" + path: Union[os.PathLike, str], + split: str, + patch_shape: Tuple[int, int], + batch_size: int, + download: bool = False, + offsets: Optional[List[List[int]]] = None, + boundaries: bool = False, + binary: bool = False, + cell_types: Optional[Sequence[str]] = None, + label_path: Optional[Union[os.PathLike, str]] = None, + label_dtype=torch.int64, + **kwargs +) -> DataLoader: + """Get the LIVECell dataloader for segmenting cells in phase-contrast microscopy. + + Args: + path: Filepath to a folder where the downloaded data will be saved. + split: The data split to use. Either 'train', 'val' or 'test'. + patch_shape: The patch shape to use for training. + batch_size: The batch size for training. + download: Whether to download the data if it is not present. + offsets: Offset values for affinity computation used as target. + boundaries: Whether to compute boundaries as the target. + binary: Whether to use a binary segmentation target. + cell_types: The cell types for which to get the data paths. + label_path: Optional path for loading the label data. + label_dtype: The datatype of the label data. + kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader. + + Returns: + The DataLoader. 
+    """
     ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
     dataset = get_livecell_dataset(
         path, split, patch_shape, download=download, offsets=offsets, boundaries=boundaries, binary=binary,
         cell_types=cell_types, label_path=label_path, label_dtype=label_dtype, **ds_kwargs
     )
-    loader = torch_em.get_data_loader(dataset, batch_size, **loader_kwargs)
-    return loader
+    return torch_em.get_data_loader(dataset, batch_size, **loader_kwargs)
diff --git a/torch_em/data/datasets/light_microscopy/mouse_embryo.py b/torch_em/data/datasets/light_microscopy/mouse_embryo.py
new file mode 100644
index 00000000..1259c1ad
--- /dev/null
+++ b/torch_em/data/datasets/light_microscopy/mouse_embryo.py
@@ -0,0 +1,153 @@
+"""This dataset contains confocal microscopy stacks of a mouse embryo
+with annotations for cell and nucleus segmentation.
+
+This dataset is part of the publication https://doi.org/10.15252/embj.2022113280.
+Please cite it if you use this data in your research.
+"""
+
+import os
+from glob import glob
+from typing import List, Optional, Tuple, Union
+
+from torch.utils.data import Dataset, DataLoader
+
+import torch_em
+
+from .. import util
+
+
+URL = "https://zenodo.org/record/6546550/files/MouseEmbryos.zip?download=1"
+CHECKSUM = "bf24df25e5f919489ce9e674876ff27e06af84445c48cf2900f1ab590a042622"
+
+
+def get_mouse_embryo_data(path: Union[os.PathLike, str], download: bool) -> str:
+    """Download the mouse embryo dataset.
+
+    Args:
+        path: Filepath to a folder where the downloaded data will be saved.
+        download: Whether to download the data if it is not present.
+
+    Returns:
+        The filepath for the downloaded data.
+    """
+    if os.path.exists(path):
+        return path
+    os.makedirs(path, exist_ok=True)
+    tmp_path = os.path.join(path, "mouse_embryo.zip")
+    util.download_source(tmp_path, URL, download, CHECKSUM)
+    util.unzip(tmp_path, path, remove=True)
+    # Remove empty volume.
+    os.remove(os.path.join(path, "Membrane", "train", "fused_paral_stack0_chan2_tp00073_raw_crop_bg_noise.h5"))
+    return path
+
+
+def get_mouse_embryo_paths(path: Union[os.PathLike, str], name: str, split: str, download: bool = False) -> List[str]:
+    """Get paths to the Mouse Embryo data.
+
+    Args:
+        path: Filepath to a folder where the downloaded data will be saved.
+        name: The name of the segmentation task. Either 'membrane' or 'nuclei'.
+        split: The split to use for the dataset. Either 'train' or 'val'.
+        download: Whether to download the data if it is not present.
+
+    Returns:
+        List of filepaths for the stored data.
+    """
+    get_mouse_embryo_data(path, download)
+
+    # The naming of the data is inconsistent: membrane has val, nuclei has test;
+    # we treat nuclei:test as val.
+    split_ = "test" if name == "nuclei" and split == "val" else split
+    file_paths = glob(os.path.join(path, name.capitalize(), split_, "*.h5"))
+    file_paths.sort()
+
+    return file_paths
+
+
+def get_mouse_embryo_dataset(
+    path: Union[os.PathLike, str],
+    name: str,
+    split: str,
+    patch_shape: Tuple[int, int, int],
+    download: bool = False,
+    offsets: Optional[List[List[int]]] = None,
+    boundaries: bool = False,
+    binary: bool = False,
+    **kwargs,
+) -> Dataset:
+    """Get the mouse embryo dataset for cell or nucleus segmentation.
+
+    Args:
+        path: Filepath to a folder where the downloaded data will be saved.
+        name: The name of the segmentation task. Either 'membrane' or 'nuclei'.
+        split: The split to use for the dataset. Either 'train' or 'val'.
+        patch_shape: The patch shape to use for training.
+        download: Whether to download the data if it is not present.
+        offsets: Offset values for affinity computation used as target.
+        boundaries: Whether to compute boundaries as the target.
+        binary: Whether to use a binary segmentation target.
+        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.
+
+    Returns:
+        The segmentation dataset.
+    """
+    assert name in ("membrane", "nuclei")
+    assert split in ("train", "val")
+    assert len(patch_shape) == 3
+
+    file_paths = get_mouse_embryo_paths(path, name, split, download)
+
+    kwargs, _ = util.add_instance_label_transform(
+        kwargs,
+        add_binary_target=binary,
+        binary=binary,
+        boundaries=boundaries,
+        offsets=offsets,
+        binary_is_exclusive=False
+    )
+
+    return torch_em.default_segmentation_dataset(
+        raw_paths=file_paths,
+        raw_key="raw",
+        label_paths=file_paths,
+        label_key="label",
+        patch_shape=patch_shape,
+        **kwargs
+    )
+
+
+def get_mouse_embryo_loader(
+    path: Union[os.PathLike, str],
+    name: str,
+    split: str,
+    patch_shape: Tuple[int, int, int],
+    batch_size: int,
+    download: bool = False,
+    offsets: Optional[List[List[int]]] = None,
+    boundaries: bool = False,
+    binary: bool = False,
+    **kwargs,
+) -> DataLoader:
+    """Get the mouse embryo dataloader for cell or nucleus segmentation.
+
+    Args:
+        path: Filepath to a folder where the downloaded data will be saved.
+        name: The name of the segmentation task. Either 'membrane' or 'nuclei'.
+        split: The split to use for the dataset. Either 'train' or 'val'.
+        patch_shape: The patch shape to use for training.
+        batch_size: The batch size for training.
+        download: Whether to download the data if it is not present.
+        offsets: Offset values for affinity computation used as target.
+        boundaries: Whether to compute boundaries as the target.
+        binary: Whether to use a binary segmentation target.
+        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader.
+
+    Returns:
+        The DataLoader.
+    """
+    ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
+    dataset = get_mouse_embryo_dataset(
+        path, name, split, patch_shape, download=download, offsets=offsets,
+        boundaries=boundaries, binary=binary, **ds_kwargs
+    )
+    return torch_em.get_data_loader(dataset, batch_size, **loader_kwargs)
diff --git a/torch_em/data/datasets/neurips_cell_seg.py b/torch_em/data/datasets/light_microscopy/neurips_cell_seg.py
similarity index 57%
rename from torch_em/data/datasets/neurips_cell_seg.py
rename to torch_em/data/datasets/light_microscopy/neurips_cell_seg.py
index e3ae7694..89be5e1b 100644
--- a/torch_em/data/datasets/neurips_cell_seg.py
+++ b/torch_em/data/datasets/light_microscopy/neurips_cell_seg.py
@@ -1,13 +1,26 @@
+"""This dataset comes from the NeurIPS Cell Segmentation Challenge,
+which collects microscopy images and annotations for cell segmentation.
+
+The dataset contains both images with annotations for cell segmentation
+and unlabeled images for self-supervised or semi-supervised learning.
+See also the challenge website for details: https://neurips22-cellseg.grand-challenge.org/.
+The dataset is described in the publication https://doi.org/10.1038/s41592-024-02233-6.
+Please cite it if you use the dataset in your research.
+"""
+
 import os
-import numpy as np
 from glob import glob
-from typing import Union, Tuple, Any, Optional
+from typing import Union, Tuple, Any, Optional, List
+
+import numpy as np
 
 import torch
+from torch.utils.data import Dataset, DataLoader
 
 import torch_em
-from . import util
-from .. import ImageCollectionDataset, RawImageCollectionDataset, ConcatDataset
+
+from .. import util
+from ... import ImageCollectionDataset, RawImageCollectionDataset, ConcatDataset
 
 
 URL = {
@@ -50,7 +63,18 @@
     return image
 
 
-def _download_dataset(root, split, download):
+def get_neurips_cellseg_data(root: Union[os.PathLike, str], split: str, download: bool) -> str:
+    """Download the NeurIPS CellSeg training data.
+
+    Args:
+        root: Filepath to a folder where the downloaded data will be saved.
+        split: The data split to download. The available splits are given by the keys of `URL`.
+        download: Whether to download the data if it is not present.
+
+    Returns:
+        The filepath to the training data.
+    """
     os.makedirs(root, exist_ok=True)
     target_dir = os.path.join(root, DIR_NAMES[split])
@@ -63,8 +87,22 @@
     return target_dir
 
 
-def _get_image_and_label_paths(root, split, download):
-    path = _download_dataset(root, split, download)
+def get_neurips_cellseg_paths(
+    root: Union[os.PathLike, str], split: str, download: bool = False
+) -> Tuple[List[str], List[str]]:
+    """Get paths to the NeurIPS CellSeg Challenge data.
+
+    Args:
+        root: Filepath to a folder where the downloaded data will be saved.
+        split: The data split to download. The available splits are given by the keys of `URL`.
+        download: Whether to download the data if it is not present.
+
+    Returns:
+        List of filepaths for the image data.
+        List of filepaths for the label data.
+    """
+    path = get_neurips_cellseg_data(root, split, download)
 
     image_folder = os.path.join(path, "images")
     assert os.path.exists(image_folder)
@@ -93,21 +131,37 @@
     n_samples: Optional[int] = None,
     sampler: Optional[Any] = None,
     download: bool = False,
-):
-    """Dataset for the segmentation of cells in light microscopy.
-
-    This dataset is part of the NeurIPS Cell Segmentation challenge: https://neurips22-cellseg.grand-challenge.org/.
+) -> Dataset:
+    """Get the dataset for cell segmentation from the NeurIPS CellSeg Challenge.
+
+    Args:
+        root: Filepath to a folder where the downloaded data will be saved.
+        split: The data split to use. Either 'train', 'val' or 'test'.
+        patch_shape: The patch shape to use for training.
+        make_rgb: Whether to map all data to RGB or treat it as grayscale.
+        label_transform: Transformation of labels, applied before data augmentation.
+        label_transform2: Transformation of labels, applied after data augmentation.
+        raw_transform: Transformation of the raw data.
+        label_dtype: The data type of the label data.
+        n_samples: Number of samples per epoch from this dataset.
+        sampler: Sampler for rejecting batches.
+        download: Whether to download the data if it is not present.
+
+    Returns:
+        The segmentation dataset.
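+
+    Example:
+        A minimal usage sketch; the root folder is an illustrative choice and the remaining arguments keep their defaults:
+
+        >>> ds = get_neurips_cellseg_supervised_dataset(root="./neurips_cellseg", split="train", patch_shape=(512, 512), download=True)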
""" assert split in ("train", "val", "test"), split - image_paths, label_paths = _get_image_and_label_paths(root, split, download) + image_paths, label_paths = get_neurips_cellseg_paths(root, split, download) if raw_transform is None: trafo = to_rgb if make_rgb else None raw_transform = torch_em.transform.get_raw_transform(augmentation2=trafo) + if transform is None: transform = torch_em.transform.get_augmentations(ndim=2) - ds = ImageCollectionDataset( + return ImageCollectionDataset( raw_image_paths=image_paths, label_image_paths=label_paths, patch_shape=patch_shape, @@ -119,7 +173,6 @@ def get_neurips_cellseg_supervised_dataset( n_samples=n_samples, sampler=sampler ) - return ds def get_neurips_cellseg_supervised_loader( @@ -137,8 +190,29 @@ def get_neurips_cellseg_supervised_loader( sampler: Optional[Any] = None, download: bool = False, **loader_kwargs -): - """Dataloader for the segmentation of cells in light microscopy. See `get_neurips_cellseg_supervised_dataset`.""" +) -> DataLoader: + f"""Get the dataset for cell segmentation from the NeurIPS Cell Seg Challenge. + + Args: + root: Filepath to a folder where the downloaded data will be saved. + split: The data split to download. Available splits are: + {', '.join(URL.keys())} + patch_shape: The patch shape to use for training. + batch_size: The batch size for training. + make_rgb: Whether to map all data to RGB or treat it as grayscale. + label_transform: Transformation of labels, applied before data augmentation. + label_transform2: Transformation of labels, applied after data augmentation. + raw_transform: Transformation of the raw data. + transform: Transformation applied to raw and label data. + label_dtype: The data type of the label data. + n_samples: Number of samples per epoch from this dataset. + sampler: Sampler for rejecting batches. + download: Whether to download the data if it is not present. + loader_kwargs: Keyword arguments for the PyTorch DataLoader. + + Returns: + The DataLoader. + """ ds = get_neurips_cellseg_supervised_dataset( root=root, split=split, @@ -157,14 +231,14 @@ def get_neurips_cellseg_supervised_loader( def _get_image_paths(root, download): - path = _download_dataset(root, "unlabeled", download) + path = get_neurips_cellseg_data(root, "unlabeled", download) image_paths = glob(os.path.join(path, "*")) image_paths.sort() return image_paths def _get_wholeslide_paths(root, patch_shape, download): - path = _download_dataset(root, "unlabeled_wsi", download) + path = get_neurips_cellseg_data(root, "unlabeled_wsi", download) image_paths = glob(os.path.join(path, "*")) image_paths.sort() @@ -193,10 +267,23 @@ def get_neurips_cellseg_unsupervised_dataset( use_images: bool = True, use_wholeslide: bool = True, download: bool = False, -): - """Dataset for the segmentation of cells in light microscopy. - - This dataset is part of the NeurIPS Cell Segmentation challenge: https://neurips22-cellseg.grand-challenge.org/. +) -> Dataset: + """Get the unsupervised dataset from the NeurIPS Cell Seg Challenge. + + Args: + root: Filepath to a folder where the downloaded data will be saved. + patch_shape: The patch shape to use for training. + make_rgb: Whether to map all data to RGB or treat it as grayscale. + raw_transform: Transformation of the raw data. + transform: Transformation applied to raw and label data. + dtype: The data type of the image data. + sampler: Sampler for rejecting batches. + use_images: Whether to use the normal image data. + use_wholeslide: Whether to use the wholeslide image data. 
+        download: Whether to download the data if it is not present.
+
+    Returns:
+        The segmentation dataset.
     """
     if raw_transform is None:
         trafo = to_rgb if make_rgb else None
@@ -247,8 +334,25 @@
     use_wholeslide: bool = True,
     download: bool = False,
     **loader_kwargs,
-):
-    """Dataloader for the segmentation of cells in light microscopy. See `get_neurips_cellseg_unsupervised_dataset`.
+) -> DataLoader:
+    """Get the unsupervised dataloader from the NeurIPS CellSeg Challenge.
+
+    Args:
+        root: Filepath to a folder where the downloaded data will be saved.
+        patch_shape: The patch shape to use for training.
+        batch_size: The batch size for training.
+        make_rgb: Whether to map all data to RGB or treat it as grayscale.
+        raw_transform: Transformation of the raw data.
+        transform: Transformation applied to raw and label data.
+        dtype: The data type of the image data.
+        sampler: Sampler for rejecting batches.
+        use_images: Whether to use the normal image data.
+        use_wholeslide: Whether to use the wholeslide image data.
+        download: Whether to download the data if it is not present.
+        loader_kwargs: Keyword arguments for the PyTorch DataLoader.
+
+    Returns:
+        The DataLoader.
     """
     ds = get_neurips_cellseg_unsupervised_dataset(
         root=root, patch_shape=patch_shape, make_rgb=make_rgb, raw_transform=raw_transform, transform=transform,
diff --git a/torch_em/data/datasets/light_microscopy/omnipose.py b/torch_em/data/datasets/light_microscopy/omnipose.py
new file mode 100644
index 00000000..ecb49ac9
--- /dev/null
+++ b/torch_em/data/datasets/light_microscopy/omnipose.py
@@ -0,0 +1,173 @@
+"""The OmniPose dataset contains phase-contrast and fluorescence microscopy images
+with annotations for bacteria segmentation, as well as brightfield microscopy images
+with annotations for worm segmentation.
+
+This dataset is described in the publication https://doi.org/10.1038/s41592-022-01639-4.
+Please cite it if you use this dataset in your research.
+"""
+
+
+import os
+from glob import glob
+from natsort import natsorted
+from typing import Union, Tuple, Literal, Optional, List
+
+from torch.utils.data import Dataset, DataLoader
+
+import torch_em
+
+from .. import util
+
+
+URL = "https://files.osf.io/v1/resources/xmury/providers/osfstorage/62f56c035775130690f25481/?zip="
+
+# NOTE: the checksums are not reliable from the osf project downloads.
+# CHECKSUM = "7ae943ff5003b085a4cde7337bd9c69988b034cfe1a6d3f252b5268f1f4c0af7"
+CHECKSUM = None
+
+DATA_CHOICES = ["bact_fluor", "bact_phase", "worm", "worm_high_res"]
+
+
+def get_omnipose_data(path: Union[os.PathLike, str], download: bool = False) -> str:
+    """Download the OmniPose dataset.
+
+    Args:
+        path: Filepath to the folder where the downloaded data will be saved.
+        download: Whether to download the data if it is not present.
+
+    Returns:
+        The filepath where the data is downloaded.
+    """
+    os.makedirs(path, exist_ok=True)
+
+    data_dir = os.path.join(path, "data")
+    if os.path.exists(data_dir):
+        return data_dir
+
+    zip_path = os.path.join(path, "datasets.zip")
+    util.download_source(path=zip_path, url=URL, download=download, checksum=CHECKSUM)
+    util.unzip(zip_path=zip_path, dst=data_dir)
+
+    return data_dir
+
+
+def get_omnipose_paths(
+    path: Union[os.PathLike, str],
+    split: str,
+    data_choice: Optional[Union[str, List[str]]] = None,
+    download: bool = False
+) -> Tuple[List[str], List[str]]:
+    """Get paths to the OmniPose data.
+
+    Args:
+        path: Filepath to a folder where the downloaded data will be saved.
+ split: The data split to use. Either 'train' or 'test'. + data_choice: The choice of specific data. + Either 'bact_fluor', 'bact_phase', 'worm' or 'worm_high_res'. + download: Whether to download the data if it is not present. + + Returns: + List of filepaths for the image data. + List of filepaths for the label data. + """ + data_dir = get_omnipose_data(path=path, download=download) + + if split not in ["train", "test"]: + raise ValueError(f"'{split}' is not a valid split.") + + if data_choice is None: + data_choice = DATA_CHOICES + else: + if not isinstance(data_choice, list): + data_choice = [data_choice] + + all_image_paths, all_gt_paths = [], [] + for _chosen_data in data_choice: + if _chosen_data not in DATA_CHOICES: + raise ValueError(f"'{_chosen_data}' is not a valid choice of data.") + + if _chosen_data.startswith("bact"): + base_dir = os.path.join(data_dir, _chosen_data, f"{split}_sorted", "*") + gt_paths = glob(os.path.join(base_dir, "*_masks.tif")) + image_paths = glob(os.path.join(base_dir, "*.tif")) + + else: + base_dir = os.path.join(data_dir, _chosen_data, split) + gt_paths = glob(os.path.join(base_dir, "*_masks.*")) + image_paths = glob(os.path.join(base_dir, "*")) + + for _path in image_paths.copy(): + # NOTE: Removing the masks and flows from the image paths. + if _path.endswith("_masks.tif") or _path.endswith("_masks.png") or _path.endswith("_flows.tif"): + image_paths.remove(_path) + + all_image_paths.extend(natsorted(image_paths)) + all_gt_paths.extend(natsorted(gt_paths)) + + return all_image_paths, all_gt_paths + + +def get_omnipose_dataset( + path: Union[os.PathLike, str], + patch_shape: Tuple[int, int], + split: Literal["train", "test"], + data_choice: Optional[Union[str, List[str]]] = None, + download: bool = False, + **kwargs +) -> Dataset: + """Get the OmniPose dataset for segmenting bacteria and worms in microscopy images. + + Args: + path: Filepath to a folder where the downloaded data will be saved. + patch_shape: The patch shape to use for training. + split: The data split to use. Either 'train' or 'test'. + data_choice: The choice of specific data. + Either 'bact_fluor', 'bact_phase', 'worm' or 'worm_high_res'. + download: Whether to download the data if it is not present. + kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`. + + Returns: + The segmentation dataset. + """ + image_paths, gt_paths = get_omnipose_paths(path, split, data_choice, download) + + return torch_em.default_segmentation_dataset( + raw_paths=image_paths, + raw_key=None, + label_paths=gt_paths, + label_key=None, + is_seg_dataset=False, + patch_shape=patch_shape, + **kwargs + ) + + +def get_omnipose_loader( + path: Union[os.PathLike, str], + patch_shape: Tuple[int, int], + batch_size: int, + split: Literal["train", "test"], + data_choice: Optional[Union[str, List[str]]] = None, + download: bool = False, + **kwargs +) -> DataLoader: + """Get the OmniPose dataloader for segmenting bacteria and worms in microscopy images. + + Args: + path: Filepath to a folder where the downloaded data will be saved. + patch_shape: The patch shape to use for training. + batch_size: The batch size for training. + split: The data split to use. Either 'train' or 'test'. + data_choice: The choice of specific data. + Either 'bact_fluor', 'bact_phase', 'worm' or 'worm_high_res'. + download: Whether to download the data if it is not present. + kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader. 
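+
+    Example:
+        A minimal usage sketch; the folder, patch shape and batch size are illustrative choices:
+
+        >>> loader = get_omnipose_loader("./omnipose", patch_shape=(512, 512), batch_size=1, split="train", data_choice="bact_phase", download=True)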
+
+    Returns:
+        The DataLoader.
+    """
+    ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
+    dataset = get_omnipose_dataset(
+        path=path, patch_shape=patch_shape, split=split, data_choice=data_choice, download=download, **ds_kwargs
+    )
+    return torch_em.get_data_loader(dataset=dataset, batch_size=batch_size, **loader_kwargs)
diff --git a/torch_em/data/datasets/light_microscopy/organoidnet.py b/torch_em/data/datasets/light_microscopy/organoidnet.py
new file mode 100644
index 00000000..a48191fc
--- /dev/null
+++ b/torch_em/data/datasets/light_microscopy/organoidnet.py
@@ -0,0 +1,148 @@
+"""The OrganoIDNet dataset contains annotations of pancreatic organoids.
+
+This dataset is from the publication https://doi.org/10.1007/s13402-024-00958-2.
+Please cite it if you use this dataset for a publication.
+"""
+
+
+import os
+import shutil
+import zipfile
+from glob import glob
+from typing import Tuple, Union, List
+
+from torch.utils.data import Dataset, DataLoader
+
+import torch_em
+
+from .. import util
+
+
+URL = "https://zenodo.org/records/10643410/files/OrganoIDNetData.zip?download=1"
+CHECKSUM = "3cd9239bf74bda096ecb5b7bdb95f800c7fa30b9937f9aba6ddf98d754cbfa3d"
+
+
+def get_organoidnet_data(path: Union[os.PathLike, str], split: str, download: bool = False) -> str:
+    """Download the OrganoIDNet dataset.
+
+    Args:
+        path: Filepath to the folder where the downloaded data will be saved.
+        split: The data split to use.
+        download: Whether to download the data if it is not present.
+
+    Returns:
+        The filepath where the data is downloaded.
+    """
+    splits = ["Training", "Validation", "Test"]
+    assert split in splits
+
+    os.makedirs(path, exist_ok=True)
+
+    data_dir = os.path.join(path, split)
+    if os.path.exists(data_dir):
+        return data_dir
+
+    # Download and extraction.
+    zip_path = os.path.join(path, "OrganoIDNetData.zip")
+    util.download_source(path=zip_path, url=URL, download=download, checksum=CHECKSUM)
+
+    # Only "Training", "Test", "Validation" from the zip are relevant and need to be extracted.
+    # They are in "/OrganoIDNetData/Dataset/"
+    prefix = "OrganoIDNetData/Dataset/"
+    for dl_split in splits:
+
+        dl_prefix = prefix + dl_split
+
+        with zipfile.ZipFile(zip_path) as archive:
+            for ff in archive.namelist():
+                if ff.startswith(dl_prefix):
+                    archive.extract(ff, path)
+
+    for dl_split in splits:
+        shutil.move(
+            os.path.join(path, "OrganoIDNetData/Dataset", dl_split),
+            os.path.join(path, dl_split)
+        )
+
+    assert os.path.exists(data_dir)
+
+    os.remove(zip_path)
+    return data_dir
+
+
+def get_organoidnet_paths(
+    path: Union[os.PathLike, str], split: str, download: bool = False
+) -> Tuple[List[str], List[str]]:
+    """Get paths to the OrganoIDNet data.
+
+    Args:
+        path: Filepath to the folder where the downloaded data will be saved.
+        split: The data split to use.
+        download: Whether to download the data if it is not present.
+
+    Returns:
+        List of filepaths for the image data.
+        List of filepaths for the label data.
+    """
+    data_dir = get_organoidnet_data(path=path, split=split, download=download)
+
+    image_paths = sorted(glob(os.path.join(data_dir, "Images", "*.tif")))
+    label_paths = sorted(glob(os.path.join(data_dir, "Masks", "*.tif")))
+
+    return image_paths, label_paths
+
+
+def get_organoidnet_dataset(
+    path: Union[os.PathLike, str], split: str, patch_shape: Tuple[int, int], download: bool = False, **kwargs
+) -> Dataset:
+    """Get the OrganoIDNet dataset for organoid segmentation in microscopy images.
+
+    Args:
+        path: Filepath to a folder where the downloaded data will be saved.
+        split: The data split to use.
+        patch_shape: The patch shape to use for training.
+        download: Whether to download the data if it is not present.
+        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.
+
+    Returns:
+        The segmentation dataset.
+    """
+    image_paths, label_paths = get_organoidnet_paths(path, split, download)
+
+    return torch_em.default_segmentation_dataset(
+        raw_paths=image_paths,
+        raw_key=None,
+        label_paths=label_paths,
+        label_key=None,
+        patch_shape=patch_shape,
+        is_seg_dataset=False,
+        **kwargs
+    )
+
+
+def get_organoidnet_loader(
+    path: Union[os.PathLike, str],
+    split: str,
+    patch_shape: Tuple[int, int],
+    batch_size: int,
+    download: bool = False,
+    **kwargs
+) -> DataLoader:
+    """Get the OrganoIDNet dataloader for organoid segmentation in microscopy images.
+
+    Args:
+        path: Filepath to a folder where the downloaded data will be saved.
+        split: The data split to use.
+        patch_shape: The patch shape to use for training.
+        batch_size: The batch size for training.
+        download: Whether to download the data if it is not present.
+        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader.
+
+    Returns:
+        The DataLoader.
+    """
+    ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
+    dataset = get_organoidnet_dataset(
+        path=path, split=split, patch_shape=patch_shape, download=download, **ds_kwargs
+    )
+    return torch_em.get_data_loader(dataset=dataset, batch_size=batch_size, **loader_kwargs)
diff --git a/torch_em/data/datasets/light_microscopy/orgasegment.py b/torch_em/data/datasets/light_microscopy/orgasegment.py
new file mode 100644
index 00000000..f6b3f2d7
--- /dev/null
+++ b/torch_em/data/datasets/light_microscopy/orgasegment.py
@@ -0,0 +1,155 @@
+"""The OrgaSegment dataset contains annotations for organoid segmentation
+of intestinal patient-derived organoids in bright field images.
+
+This dataset is from the publication https://doi.org/10.1038/s42003-024-05966-4.
+Please cite it if you use this dataset for your research.
+"""
+
+import os
+import shutil
+from glob import glob
+from typing import Tuple, Union, Literal, List
+
+from torch.utils.data import Dataset, DataLoader
+
+import torch_em
+
+from .. import util
+
+
+URL = "https://zenodo.org/records/10278229/files/OrganoidBasic_v20211206.zip"
+CHECKSUM = "d067124d734108e46e18f65daaf17c89cb0a40bdacc6f6031815a6839e472798"
+
+
+def get_orgasegment_data(
+    path: Union[os.PathLike, str],
+    split: Literal["train", "val", "eval"],
+    download: bool = False
+) -> str:
+    """Download the OrgaSegment dataset for organoid segmentation.
+
+    Args:
+        path: Filepath to a folder where the downloaded data will be saved.
+        split: The split to download. Either 'train', 'val' or 'eval'.
+        download: Whether to download the data if it is not present.
+
+    Returns:
+        The filepath to the training data.
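+
+    Example:
+        A minimal usage sketch; the download folder is an illustrative choice:
+
+        >>> data_dir = get_orgasegment_data("./orgasegment", split="train", download=True)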
+    """
+    os.makedirs(path, exist_ok=True)
+
+    data_dir = os.path.join(path, split)
+    if os.path.exists(data_dir):
+        return data_dir
+
+    zip_path = os.path.join(path, "OrganoidBasic_v20211206.zip")
+    util.download_source(path=zip_path, url=URL, download=download, checksum=CHECKSUM)
+    util.unzip(zip_path=zip_path, dst=path, remove=True)
+
+    shutil.move(os.path.join(path, "OrganoidBasic_v20211206", "train"), os.path.join(path, "train"))
+    shutil.move(os.path.join(path, "OrganoidBasic_v20211206", "val"), os.path.join(path, "val"))
+    shutil.move(os.path.join(path, "OrganoidBasic_v20211206", "eval"), os.path.join(path, "eval"))
+    shutil.rmtree(os.path.join(path, "OrganoidBasic_v20211206"))
+
+    return data_dir
+
+
+def get_orgasegment_paths(
+    path: Union[os.PathLike, str],
+    split: Literal["train", "val", "eval"],
+    download: bool = False
+) -> Tuple[List[str], List[str]]:
+    """Get paths for the OrgaSegment data.
+
+    Args:
+        path: Filepath to a folder where the downloaded data will be saved.
+        split: The split to download. Either 'train', 'val' or 'eval'.
+        download: Whether to download the data if it is not present.
+
+    Returns:
+        List of filepaths to the image data.
+        List of filepaths to the label data.
+    """
+    data_dir = get_orgasegment_data(path=path, split=split, download=download)
+
+    image_paths = sorted(glob(os.path.join(data_dir, "*_img.jpg")))
+    label_paths = sorted(glob(os.path.join(data_dir, "*_masks_organoid.png")))
+
+    return image_paths, label_paths
+
+
+def get_orgasegment_dataset(
+    path: Union[os.PathLike, str],
+    split: Literal["train", "val", "eval"],
+    patch_shape: Tuple[int, int],
+    boundaries: bool = False,
+    binary: bool = False,
+    download: bool = False,
+    **kwargs
+) -> Dataset:
+    """Get the OrgaSegment dataset for organoid segmentation.
+
+    Args:
+        path: Filepath to a folder where the downloaded data will be saved.
+        split: The split to download. Either 'train', 'val' or 'eval'.
+        patch_shape: The patch shape to use for training.
+        boundaries: Whether to compute boundaries as the target.
+        binary: Whether to use a binary segmentation target.
+        download: Whether to download the data if it is not present.
+        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.
+
+    Returns:
+        The segmentation dataset.
+    """
+    assert split in ["train", "val", "eval"]
+
+    image_paths, label_paths = get_orgasegment_paths(path=path, split=split, download=download)
+
+    kwargs, _ = util.add_instance_label_transform(kwargs, add_binary_target=True, binary=binary, boundaries=boundaries)
+
+    return torch_em.default_segmentation_dataset(
+        raw_paths=image_paths,
+        raw_key=None,
+        label_paths=label_paths,
+        label_key=None,
+        patch_shape=patch_shape,
+        is_seg_dataset=False,
+        **kwargs
+    )
+
+
+def get_orgasegment_loader(
+    path: Union[os.PathLike, str],
+    split: Literal["train", "val", "eval"],
+    patch_shape: Tuple[int, int],
+    batch_size: int,
+    boundaries: bool = False,
+    binary: bool = False,
+    download: bool = False,
+    **kwargs
+) -> DataLoader:
+    """Get the OrgaSegment dataloader for organoid segmentation.
+
+    Args:
+        path: Filepath to a folder where the downloaded data will be saved.
+        split: The split to download. Either 'train', 'val' or 'eval'.
+        patch_shape: The patch shape to use for training.
+        boundaries: Whether to compute boundaries as the target.
+        binary: Whether to use a binary segmentation target.
+        download: Whether to download the data if it is not present.
+        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader.
+
+    Returns:
+        The DataLoader.
+    """
+    ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
+    dataset = get_orgasegment_dataset(
+        path=path,
+        split=split,
+        patch_shape=patch_shape,
+        boundaries=boundaries,
+        binary=binary,
+        download=download,
+        **ds_kwargs
+    )
+    return torch_em.get_data_loader(dataset=dataset, batch_size=batch_size, **loader_kwargs)
diff --git a/torch_em/data/datasets/light_microscopy/plantseg.py b/torch_em/data/datasets/light_microscopy/plantseg.py
new file mode 100644
index 00000000..2db4f603
--- /dev/null
+++ b/torch_em/data/datasets/light_microscopy/plantseg.py
@@ -0,0 +1,244 @@
+"""This dataset contains confocal and lightsheet microscopy images of plant cells
+with annotations for cell and nucleus segmentation.
+
+The dataset is part of the publication https://doi.org/10.7554/eLife.57613.
+Please cite it if you use this dataset in your research.
+"""
+
+import os
+from glob import glob
+from tqdm import tqdm
+from typing import List, Optional, Tuple, Union
+
+from torch.utils.data import Dataset, DataLoader
+
+import torch_em
+
+from .. import util
+
+
+URLS = {
+    "root": {
+        "train": "https://files.de-1.osf.io/v1/resources/9x3g2/providers/osfstorage/?zip=",
+        "val": "https://files.de-1.osf.io/v1/resources/vs6gb/providers/osfstorage/?zip=",
+        "test": "https://files.de-1.osf.io/v1/resources/tn4xj/providers/osfstorage/?zip=",
+    },
+    "nuclei": {
+        "train": "https://files.de-1.osf.io/v1/resources/thxzn/providers/osfstorage/?zip=",
+    },
+    "ovules": {
+        "train": "https://files.de-1.osf.io/v1/resources/x9yns/providers/osfstorage/?zip=",
+        "val": "https://files.de-1.osf.io/v1/resources/xp5uf/providers/osfstorage/?zip=",
+        "test": "https://files.de-1.osf.io/v1/resources/8jz7e/providers/osfstorage/?zip=",
+    }
+}
+
+# FIXME somehow the checksums are not reliable, this is a bit weird.
+CHECKSUMS = {
+    "root": {
+        "train": None, "val": None, "test": None
+        # "train": "f72e9525ff716ef14b70ab1318efd4bf303bbf9e0772bf2981a2db6e22a75794",
+        # "val": "987280d9a56828c840e508422786431dcc3603e0ba4814aa06e7bf4424efcd9e",
+        # "test": "ad71b8b9d20effba85fb5e1b42594ae35939d1a0cf905f3403789fc9e6afbc58",
+    },
+    "nuclei": {
+        "train": None
+        # "train": "9d19ddb61373e2a97effb6cf8bd8baae5f8a50f87024273070903ea8b1160396",
+    },
+    "ovules": {
+        "train": None, "val": None, "test": None
+        # "train": "70379673f1ab1866df6eb09d5ce11db7d3166d6d15b53a9c8b47376f04bae413",
+        # "val": "872f516cb76879c30782d9a76d52df95236770a866f75365902c60c37b14fa36",
+        # "test": "a7272f6ad1d765af6d121e20f436ac4f3609f1a90b1cb2346aa938d8c52800b9",
+    }
+}
+
+CROPPING_VOLUMES = {
+    # root (train)
+    "Movie2_T00006_crop_gt.h5": slice(4, None),
+    "Movie2_T00008_crop_gt.h5": slice(None, -18),
+    "Movie2_T00010_crop_gt.h5": slice(None, -32),
+    "Movie2_T00012_crop_gt.h5": slice(None, -39),
+    "Movie2_T00014_crop_gt.h5": slice(None, -40),
+    "Movie2_T00016_crop_gt.h5": slice(None, -42),
+    # root (test)
+    "Movie2_T00020_crop_gt.h5": slice(None, -50),
+    # ovules (train)
+    "N_487_ds2x.h5": slice(17, None),
+    "N_535_ds2x.h5": slice(None, -1),
+    "N_534_ds2x.h5": slice(None, -1),
+    "N_451_ds2x.h5": slice(None, -1),
+    "N_425_ds2x.h5": slice(None, -1),
+    # ovules (val)
+    "N_420_ds2x.h5": slice(None, -1),
+}
+
+# The resolution previously used for the resizing.
+# I have removed this feature since it was not reliable,
+# but leaving this here for reference
+# (also implementing resizing would be a good idea,
+# but more general and not for each dataset individually)
+# NATIVE_RESOLUTION = (0.235, 0.075, 0.075)
+
+
+def _fix_inconsistent_volumes(data_path, name, split):
+    import h5py
+
+    file_paths = glob(os.path.join(data_path, "*.h5"))
+    if name not in ["root", "ovules"] and split not in ["train", "val"]:
+        return
+
+    for vol_path in tqdm(file_paths, desc="Fixing inconsistencies in volumes"):
+        fname = os.path.basename(vol_path)
+
+        # avoid duplicated volumes in 'train' and 'test'.
+        if fname == "Movie1_t00045_crop_gt.h5" and (name == "root" and split == "train"):
+            os.remove(vol_path)
+            continue
+
+        if fname not in CROPPING_VOLUMES:
+            continue
+
+        with h5py.File(vol_path, "r+") as f:
+            raw, labels = f["raw"], f["label"]
+
+            crop_slices = CROPPING_VOLUMES[fname]
+            resized_raw, resized_labels = raw[:][crop_slices], labels[:][crop_slices]
+
+            cropped_shape = resized_raw.shape
+            raw.resize(cropped_shape)
+            labels.resize(cropped_shape)
+
+            raw[...] = resized_raw
+            labels[...] = resized_labels
+
+
+def get_plantseg_data(path: Union[os.PathLike, str], name: str, split: str, download: bool = False) -> str:
+    """Download the PlantSeg training data.
+
+    Args:
+        path: Filepath to a folder where the downloaded data will be saved.
+        name: The name of the data to load. Either 'root', 'nuclei' or 'ovules'.
+        split: The split to download. Either 'train', 'val' or 'test'.
+        download: Whether to download the data if it is not present.
+
+    Returns:
+        The filepath to the training data.
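+
+    Example:
+        A minimal usage sketch; the download folder is an illustrative choice:
+
+        >>> out_path = get_plantseg_data("./plantseg", name="root", split="train", download=True)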
+    """
+    url = URLS[name][split]
+    checksum = CHECKSUMS[name][split]
+    os.makedirs(path, exist_ok=True)
+    out_path = os.path.join(path, f"{name}_{split}")
+    if os.path.exists(out_path):
+        return out_path
+    tmp_path = os.path.join(path, f"{name}_{split}.zip")
+    util.download_source(tmp_path, url, download, checksum)
+    util.unzip(tmp_path, out_path, remove=True)
+    _fix_inconsistent_volumes(out_path, name, split)
+    return out_path
+
+
+def get_plantseg_paths(
+    path: Union[os.PathLike, str],
+    name: str,
+    split: str,
+    download: bool = False
+) -> List[str]:
+    """Get paths to the PlantSeg data.
+
+    Args:
+        path: Filepath to a folder where the downloaded data will be saved.
+        name: The name of the data to load. Either 'root', 'nuclei' or 'ovules'.
+        split: The split to download. Either 'train', 'val' or 'test'.
+        download: Whether to download the data if it is not present.
+
+    Returns:
+        List of filepaths for the data.
+    """
+    data_path = get_plantseg_data(path, name, split, download)
+    file_paths = sorted(glob(os.path.join(data_path, "*.h5")))
+    return file_paths
+
+
+def get_plantseg_dataset(
+    path: Union[os.PathLike, str],
+    name: str,
+    split: str,
+    patch_shape: Tuple[int, int, int],
+    download: bool = False,
+    offsets: Optional[List[List[int]]] = None,
+    boundaries: bool = False,
+    binary: bool = False,
+    **kwargs,
+) -> Dataset:
+    """Get the PlantSeg dataset for segmenting nuclei or cells.
+
+    Args:
+        path: Filepath to a folder where the downloaded data will be saved.
+        name: The name of the data to load. Either 'root', 'nuclei' or 'ovules'.
+        split: The split to download. Either 'train', 'val' or 'test'.
+        patch_shape: The patch shape to use for training.
+        download: Whether to download the data if it is not present.
+        offsets: Offset values for affinity computation used as target.
+        boundaries: Whether to compute boundaries as the target.
+        binary: Whether to use a binary segmentation target.
+        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.
+
+    Returns:
+        The segmentation dataset.
+    """
+    assert len(patch_shape) == 3
+
+    file_paths = get_plantseg_paths(path, name, split, download)
+
+    kwargs, _ = util.add_instance_label_transform(
+        kwargs, add_binary_target=binary, binary=binary, boundaries=boundaries,
+        offsets=offsets, binary_is_exclusive=False
+    )
+
+    return torch_em.default_segmentation_dataset(
+        raw_paths=file_paths,
+        raw_key="raw",
+        label_paths=file_paths,
+        label_key="label",
+        patch_shape=patch_shape,
+        **kwargs
+    )
+
+
+# TODO add support for ignore label, key: "/label_with_ignore"
+def get_plantseg_loader(
+    path: Union[os.PathLike, str],
+    name: str,
+    split: str,
+    patch_shape: Tuple[int, int, int],
+    batch_size: int,
+    download: bool = False,
+    offsets: Optional[List[List[int]]] = None,
+    boundaries: bool = False,
+    binary: bool = False,
+    **kwargs,
+) -> DataLoader:
+    """Get the PlantSeg dataloader for segmenting nuclei or cells.
+
+    Args:
+        path: Filepath to a folder where the downloaded data will be saved.
+        name: The name of the data to load. Either 'root', 'nuclei' or 'ovules'.
+        split: The split to download. Either 'train', 'val' or 'test'.
+        patch_shape: The patch shape to use for training.
+        batch_size: The batch size for training.
+        download: Whether to download the data if it is not present.
+        offsets: Offset values for affinity computation used as target.
+        boundaries: Whether to compute boundaries as the target.
+        binary: Whether to use a binary segmentation target.
+        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader.
+
+    Returns:
+        The DataLoader.
+    """
+    ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
+    dataset = get_plantseg_dataset(
+        path, name, split, patch_shape, download=download, offsets=offsets,
+        boundaries=boundaries, binary=binary, **ds_kwargs
+    )
+    return torch_em.get_data_loader(dataset, batch_size, **loader_kwargs)
diff --git a/torch_em/data/datasets/pnas_arabidopsis.py b/torch_em/data/datasets/light_microscopy/pnas_arabidopsis.py
similarity index 100%
rename from torch_em/data/datasets/pnas_arabidopsis.py
rename to torch_em/data/datasets/light_microscopy/pnas_arabidopsis.py
diff --git a/torch_em/data/datasets/light_microscopy/tissuenet.py b/torch_em/data/datasets/light_microscopy/tissuenet.py
new file mode 100644
index 00000000..52e88c88
--- /dev/null
+++ b/torch_em/data/datasets/light_microscopy/tissuenet.py
@@ -0,0 +1,188 @@
+"""The TissueNet dataset contains annotations for cell segmentation in microscopy images of different tissue types.
+
+This dataset is from the publication https://doi.org/10.1038/s41587-021-01094-0.
+Please cite it if you use this dataset for your research.
+
+This dataset cannot be downloaded automatically, please visit https://datasets.deepcell.org/data
+and download it yourself.
+"""
+
+import os
+from glob import glob
+from tqdm import tqdm
+from typing import Tuple, Union, List
+
+import numpy as np
+import pandas as pd
+
+from torch.utils.data import Dataset, DataLoader
+
+import torch_em
+
+from .. import util
+
+
+def _create_split(path, split):
+    import z5py
+
+    split_file = os.path.join(path, f"tissuenet_v1.1_{split}.npz")
+    split_folder = os.path.join(path, split)
+    os.makedirs(split_folder, exist_ok=True)
+    data = np.load(split_file, allow_pickle=True)
+
+    x, y = data["X"], data["y"]
+    metadata = data["meta"]
+    metadata = pd.DataFrame(metadata[1:], columns=metadata[0])
+
+    for i, (im, label) in tqdm(enumerate(zip(x, y)), total=len(x), desc=f"Creating files for {split}-split"):
+        out_path = os.path.join(split_folder, f"image_{i:04}.zarr")
+        nucleus_channel = im[..., 0]
+        cell_channel = im[..., 1]
+        rgb = np.stack([np.zeros_like(nucleus_channel), cell_channel, nucleus_channel])
+        chunks = cell_channel.shape
+        with z5py.File(out_path, "a") as f:
+
+            f.create_dataset("raw/nucleus", data=im[..., 0], compression="gzip", chunks=chunks)
+            f.create_dataset("raw/cell", data=cell_channel, compression="gzip", chunks=chunks)
+            f.create_dataset("raw/rgb", data=rgb, compression="gzip", chunks=(3,) + chunks)
+
+            # the switch 0<->1 is intentional, the data format is chaotic...
+            f.create_dataset("labels/nucleus", data=label[..., 1], compression="gzip", chunks=chunks)
+            f.create_dataset("labels/cell", data=label[..., 0], compression="gzip", chunks=chunks)
+    os.remove(split_file)
+
+
+def _create_dataset(path, zip_path):
+    util.unzip(zip_path, path, remove=False)
+    splits = ["train", "val", "test"]
+    assert all([os.path.exists(os.path.join(path, f"tissuenet_v1.1_{split}.npz")) for split in splits])
+    for split in splits:
+        _create_split(path, split)
+
+
+def get_tissuenet_data(path: Union[os.PathLike, str], split: str, download: bool = False) -> str:
+    """Download the TissueNet dataset.
+
+    NOTE: Automatic download is not supported for the TissueNet dataset.
+    Please download the dataset from https://datasets.deepcell.org/data.
+
+    Args:
+        path: Filepath to a folder where the downloaded data will be saved.
+        split: The data split to use. Either 'train', 'val' or 'test'.
+        download: Whether to download the data if it is not present.
+
+    Returns:
+        The path where inputs are stored per split.
+    """
+    splits = ["train", "val", "test"]
+    assert split in splits
+
+    # check if the dataset exists already
+    zip_path = os.path.join(path, "tissuenet_v1.1.zip")
+    if all([os.path.exists(os.path.join(path, split)) for split in splits]):  # yes it does
+        pass
+    elif os.path.exists(zip_path):  # no it does not, but we have the zip there and can unpack it
+        _create_dataset(path, zip_path)
+    else:
+        raise RuntimeError(
+            "We do not support automatic download for the TissueNet dataset yet. "
+            f"Please download the dataset from https://datasets.deepcell.org/data and put it here: {zip_path}"
+        )
+
+    split_folder = os.path.join(path, split)
+    return split_folder
+
+
+def get_tissuenet_paths(path: Union[os.PathLike, str], split: str, download: bool = False) -> List[str]:
+    """Get paths to the TissueNet data.
+
+    Args:
+        path: Filepath to a folder where the downloaded data will be saved.
+        split: The data split to use. Either 'train', 'val' or 'test'.
+        download: Whether to download the data if it is not present.
+
+    Returns:
+        List of filepaths for the data.
+    """
+    split_folder = get_tissuenet_data(path, split, download)
+    assert os.path.exists(split_folder)
+    data_paths = glob(os.path.join(split_folder, "*.zarr"))
+    assert len(data_paths) > 0
+
+    return data_paths
+
+
+def get_tissuenet_dataset(
+    path: Union[os.PathLike, str],
+    split: str,
+    patch_shape: Tuple[int, int],
+    raw_channel: str,
+    label_channel: str,
+    download: bool = False,
+    **kwargs
+) -> Dataset:
+    """Get the TissueNet dataset for segmenting cells in microscopy tissue images.
+
+    Args:
+        path: Filepath to a folder where the downloaded data will be saved.
+        split: The data split to use. Either 'train', 'val' or 'test'.
+        patch_shape: The patch shape to use for training.
+        raw_channel: The channel to load for the raw data. Either 'nucleus', 'cell' or 'rgb'.
+        label_channel: The channel to load for the label data. Either 'nucleus' or 'cell'.
+        download: Whether to download the data if it is not present.
+        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.
+
+    Returns:
+        The segmentation dataset.
+    """
+    assert raw_channel in ("nucleus", "cell", "rgb")
+    assert label_channel in ("nucleus", "cell")
+
+    data_paths = get_tissuenet_paths(path, split, download)
+
+    with_channels = True if raw_channel == "rgb" else False
+    kwargs = util.update_kwargs(kwargs, "with_channels", with_channels)
+    kwargs = util.update_kwargs(kwargs, "is_seg_dataset", True)
+    kwargs = util.update_kwargs(kwargs, "ndim", 2)
+
+    return torch_em.default_segmentation_dataset(
+        raw_paths=data_paths,
+        raw_key=f"raw/{raw_channel}",
+        label_paths=data_paths,
+        label_key=f"labels/{label_channel}",
+        patch_shape=patch_shape,
+        **kwargs
+    )
+
+
+# TODO enable loading specific tissue types etc. (from the 'meta' attributes)
+def get_tissuenet_loader(
+    path: Union[os.PathLike, str],
+    split: str,
+    patch_shape: Tuple[int, int],
+    batch_size: int,
+    raw_channel: str,
+    label_channel: str,
+    download: bool = False,
+    **kwargs
+) -> DataLoader:
+    """Get the TissueNet dataloader for segmenting cells in microscopy tissue images.
+
+    Args:
+        path: Filepath to a folder where the downloaded data will be saved.
+        split: The data split to use. Either 'train', 'val' or 'test'.
+        patch_shape: The patch shape to use for training.
+        batch_size: The batch size for training.
+        raw_channel: The channel to load for the raw data. Either 'nucleus', 'cell' or 'rgb'.
+        label_channel: The channel to load for the label data. Either 'nucleus' or 'cell'.
+        download: Whether to download the data if it is not present.
+        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader.
+
+    Returns:
+        The DataLoader.
+    """
+    ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
+    dataset = get_tissuenet_dataset(
+        path, split, patch_shape, raw_channel, label_channel, download, **ds_kwargs
+    )
+    return torch_em.get_data_loader(dataset, batch_size, **loader_kwargs)
diff --git a/torch_em/data/datasets/light_microscopy/vgg_hela.py b/torch_em/data/datasets/light_microscopy/vgg_hela.py
new file mode 100644
index 00000000..75500241
--- /dev/null
+++ b/torch_em/data/datasets/light_microscopy/vgg_hela.py
@@ -0,0 +1,160 @@
+"""This is a dataset for counting HeLa cells in phase-contrast microscopy.
+
+It is described in the publication https://www.robots.ox.ac.uk/~vgg/publications/2012/Arteta12/.
+Please cite it if you use this dataset in your research.
+"""
+
+import os
+from glob import glob
+from shutil import rmtree
+from typing import Tuple, Union
+
+import numpy as np
+import imageio.v3 as imageio
+from scipy.io import loadmat
+
+from torch.utils.data import Dataset, DataLoader
+
+import torch_em
+
+from .. import util
+
+
+URL = "https://www.robots.ox.ac.uk/~vgg/software/cell_detection/downloads/CellDetect_v1.0.tar.gz"
+CHECKSUM = "09825d6a8e287ddf2c4b1ef3d2f62585ec6876e3bfcd4b9bbcd3dd300e4be282"
+
+
+def get_vgg_hela_data(path: Union[os.PathLike, str], download: bool) -> str:
+    """Download the HeLa VGG dataset.
+
+    Args:
+        path: Filepath to a folder where the downloaded data will be saved.
+        download: Whether to download the data if it is not present.
+
+    Returns:
+        The filepath to the folder where the data is stored, containing 'train' and 'test' subfolders.
+    """
+    os.makedirs(path, exist_ok=True)
+
+    train_path = os.path.join(path, "train")
+    test_path = os.path.join(path, "test")
+
+    if os.path.exists(train_path) and os.path.exists(test_path):
+        return path
+
+    dl_path = os.path.join(path, "cell_detect.tar.gz")
+    util.download_source(dl_path, URL, download, CHECKSUM)
+    util.unzip_tarfile(dl_path, path, True)
+
+    extracted_path = os.path.join(path, "CellDetect_v1.0")
+    assert os.path.exists(extracted_path), extracted_path
+
+    splits_in = ["trainPhasecontrast", "testPhasecontrast"]
+    splits_out = [train_path, test_path]
+
+    for split_in, out_folder in zip(splits_in, splits_out):
+        out_im_folder = os.path.join(out_folder, "images")
+        os.makedirs(out_im_folder, exist_ok=True)
+
+        out_label_folder = os.path.join(out_folder, "labels")
+        os.makedirs(out_label_folder, exist_ok=True)
+
+        split_root = os.path.join(extracted_path, "phasecontrast", split_in)
+        image_files = sorted(glob(os.path.join(split_root, "*.pgm")))
+        mat_files = sorted(glob(os.path.join(split_root, "*.mat")))
+
+        for ii, (im, mat) in enumerate(zip(image_files, mat_files), 1):
+            im = imageio.imread(im)
+            # the .mat files store 1-based cell center coordinates (MATLAB convention)
+            coordinates = loadmat(mat)["gt"] - 1
+            coordinates = (coordinates[:, 1], coordinates[:, 0])
+
+            out_im = os.path.join(out_im_folder, f"im{ii:02}.tif")
+            imageio.imwrite(out_im, im, compression="zlib")
+
+            # create point masks: a single foreground pixel per annotated cell center
+            labels = np.zeros(im.shape, dtype="uint8")
+            labels[coordinates] = 1
+            out_labels = os.path.join(out_label_folder, f"im{ii:02}.tif")
+            imageio.imwrite(out_labels, labels, compression="zlib")
+
+    rmtree(extracted_path)
+    return path
+
+
+def get_vgg_hela_paths(path: Union[os.PathLike, str], split: str, download: bool = False) -> Tuple[str, str]:
+    """Get paths for HeLa VGG data.
+
+    Args:
+        path: Filepath to a folder where the downloaded data will be saved.
+        split: The split to use for the dataset. Either 'train' or 'test'.
+        download: Whether to download the data if it is not present.
+
+    Returns:
+        Filepath to the folder where image data is stored.
+        Filepath to the folder where label data is stored.
+    """
+    get_vgg_hela_data(path, download)
+
+    image_path = os.path.join(path, split, "images")
+    label_path = os.path.join(path, split, "labels")
+
+    return image_path, label_path
+
+
+def get_vgg_hela_dataset(
+    path: Union[os.PathLike, str], split: str, patch_shape: Tuple[int, int], download: bool = False, **kwargs
+) -> Dataset:
+    """Get the HeLa VGG dataset for cell counting.
+
+    Args:
+        path: Filepath to a folder where the downloaded data will be saved.
+        split: The split to use for the dataset. Either 'train' or 'test'.
+        patch_shape: The patch shape to use for training.
+        download: Whether to download the data if it is not present.
+        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.
+
+    Returns:
+        The segmentation dataset.
+    """
+    assert split in ("test", "train"), split
+
+    image_path, label_path = get_vgg_hela_paths(path, split, download)
+
+    kwargs = util.update_kwargs(kwargs, "ndim", 2)
+    kwargs = util.update_kwargs(kwargs, "is_seg_dataset", True)
+
+    return torch_em.default_segmentation_dataset(
+        raw_paths=image_path,
+        raw_key="*.tif",
+        label_paths=label_path,
+        label_key="*.tif",
+        patch_shape=patch_shape,
+        **kwargs
+    )
+
+
+def get_vgg_hela_loader(
+    path: Union[os.PathLike, str],
+    split: str,
+    patch_shape: Tuple[int, int],
+    batch_size: int,
+    download: bool = False,
+    **kwargs
+) -> DataLoader:
+    """Get the HeLa VGG dataloader for cell counting.
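+
+    The labels are point masks: each annotated cell center is marked by a single foreground pixel.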
+ + Args: + path: Filepath to a folder where the downloaded data will be saved. + split: The split to use for the dataset. Either 'train' or 'test'. + patch_shape: The patch shape to use for training. + batch_size: The batch size for training. + download: Whether to download the data if it is not present. + kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader. + + Returns: + The DataLoader. + """ + ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs) + dataset = get_vgg_hela_dataset(path, split, patch_shape, download=download, **ds_kwargs) + return torch_em.get_data_loader(dataset, batch_size, **loader_kwargs) diff --git a/torch_em/data/datasets/lizard.py b/torch_em/data/datasets/lizard.py deleted file mode 100644 index a0af3785..00000000 --- a/torch_em/data/datasets/lizard.py +++ /dev/null @@ -1,117 +0,0 @@ -import os -import warnings -from glob import glob -from shutil import rmtree - -import h5py -import imageio.v3 as imageio -import torch_em - -from scipy.io import loadmat -from tqdm import tqdm -from .import util - -# TODO: the links don't work anymore (?) -# workaround to still make this work (kaggle still has the dataset in the same structure): -# - download the zip files manually from here - https://www.kaggle.com/datasets/aadimator/lizard-dataset -# - Kaggle API (TODO) - `kaggle datasets download -d aadimator/lizard-dataset` -URL1 = "https://warwick.ac.uk/fac/cross_fac/tia/data/lizard/lizard_images1.zip" -URL2 = "https://warwick.ac.uk/fac/cross_fac/tia/data/lizard/lizard_images2.zip" -LABEL_URL = "https://warwick.ac.uk/fac/cross_fac/tia/data/lizard/lizard_labels.zip" - -CHECKSUM1 = "d2c4e7c83dff634624c9c14d4a1a0b821d4e9ac41e05e3b36303d8f0c510113d" -CHECKSUM2 = "9f529f30d9de66587167991a8bf75aaad07ce1d518b72e825c868ac7c33015ed" -LABEL_CHECKSUM = "79f22ca83ca535682fba340cbc8bb66b74abd1ead4151ffc8593f204fcb97dec" - - -def _extract_images(image_folder, label_folder, output_dir): - image_files = glob(os.path.join(image_folder, "*.png")) - for image_file in tqdm(image_files, desc=f"Extract images from {image_folder}"): - fname = os.path.basename(image_file) - label_file = os.path.join(label_folder, fname.replace(".png", ".mat")) - assert os.path.exists(label_file), label_file - - image = imageio.imread(image_file) - assert image.ndim == 3 and image.shape[-1] == 3 - - labels = loadmat(label_file) - segmentation = labels["inst_map"] - assert image.shape[:-1] == segmentation.shape - classes = labels["class"] - - image = image.transpose((2, 0, 1)) - assert image.shape[1:] == segmentation.shape - - output_file = os.path.join(output_dir, fname.replace(".png", ".h5")) - with h5py.File(output_file, "a") as f: - f.create_dataset("image", data=image, compression="gzip") - f.create_dataset("labels/segmentation", data=segmentation, compression="gzip") - f.create_dataset("labels/classes", data=classes, compression="gzip") - - -def _require_lizard_data(path, download): - image_files = glob(os.path.join(path, "*.h5")) - if len(image_files) > 0: - return - - os.makedirs(path, exist_ok=True) - - zip_path = os.path.join(path, "lizard_images1.zip") - util.download_source(zip_path, URL1, download=download, checksum=CHECKSUM1) - util.unzip(zip_path, path, remove=True) - - zip_path = os.path.join(path, "lizard_images2.zip") - util.download_source(zip_path, URL2, download=download, checksum=CHECKSUM2) - util.unzip(zip_path, path, remove=True) - - zip_path = os.path.join(path, "lizard_labels.zip") - 
util.download_source(zip_path, LABEL_URL, download=download, checksum=LABEL_CHECKSUM) - util.unzip(zip_path, path, remove=True) - - image_folder1 = os.path.join(path, "Lizard_Images1") - image_folder2 = os.path.join(path, "Lizard_Images2") - label_folder = os.path.join(path, "Lizard_Labels") - - assert os.path.exists(image_folder1), image_folder1 - assert os.path.exists(image_folder2), image_folder2 - assert os.path.exists(label_folder), label_folder - - _extract_images(image_folder1, os.path.join(label_folder, "Labels"), path) - _extract_images(image_folder2, os.path.join(label_folder, "Labels"), path) - - rmtree(image_folder1) - rmtree(image_folder2) - rmtree(label_folder) - - -def get_lizard_dataset(path, patch_shape, download=False, **kwargs): - """Dataset for the segmentation of nuclei in histopathology. - - This dataset is from the publication https://doi.org/10.48550/arXiv.2108.11195. - Please cite it if you use this dataset for a publication. - """ - if download: - warnings.warn(f"The download link does not work right now. Please manually download the zip files to {path} from https://www.kaggle.com/datasets/aadimator/lizard-dataset") - - _require_lizard_data(path, download) - - data_paths = glob(os.path.join(path, "*.h5")) - data_paths.sort() - - raw_key = "image" - label_key = "labels/segmentation" - return torch_em.default_segmentation_dataset( - data_paths, raw_key, data_paths, label_key, patch_shape, ndim=2, with_channels=True, **kwargs - ) - - -# TODO implement loading the classification labels -# TODO implement selecting different tissue types -# TODO implement train / val / test split (is pre-defined in a csv) -def get_lizard_loader(path, patch_shape, batch_size, download=False, **kwargs): - """Dataloader for the segmentation of nuclei in histopathology. See 'get_lizard_dataset' for details.""" - ds_kwargs, loader_kwargs = util.split_kwargs( - torch_em.default_segmentation_dataset, **kwargs - ) - ds = get_lizard_dataset(path, patch_shape, download=download, **ds_kwargs) - return torch_em.get_data_loader(ds, batch_size=batch_size, **loader_kwargs) diff --git a/torch_em/data/datasets/lucchi.py b/torch_em/data/datasets/lucchi.py deleted file mode 100644 index 981c4eef..00000000 --- a/torch_em/data/datasets/lucchi.py +++ /dev/null @@ -1,101 +0,0 @@ -import os -from concurrent import futures -from glob import glob -from shutil import rmtree - -import imageio -import h5py -import numpy as np -import torch_em - -from tqdm import tqdm -from . 
import util - -URL = "http://www.casser.io/files/lucchi_pp.zip" -CHECKSUM = "770ce9e98fc6f29c1b1a250c637e6c5125f2b5f1260e5a7687b55a79e2e8844d" - -# data from: https://sites.google.com/view/connectomics/ -# TODO: add sampler for foreground to avoid empty batches - - -def _load_volume(path, pattern): - nz = len(glob(os.path.join(path, "*.png"))) - im0 = imageio.imread(os.path.join(path, pattern % 0)) - out = np.zeros((nz,) + im0.shape, dtype=im0.dtype) - out[0] = im0 - - def _loadz(z): - im = imageio.imread(os.path.join(path, pattern % z)) - out[z] = im - - n_threads = 8 - with futures.ThreadPoolExecutor(n_threads) as tp: - list(tqdm( - tp.map(_loadz, range(1, nz)), desc="Load volume", total=nz-1 - )) - - return out - - -def _create_data(root, inputs, out_path): - raw = _load_volume(os.path.join(root, inputs[0]), pattern="mask%04i.png") - labels_argb = _load_volume(os.path.join(root, inputs[1]), pattern="%i.png") - if labels_argb.ndim == 4: - labels = np.zeros(raw.shape, dtype="uint8") - fg_mask = (labels_argb == np.array([255, 255, 255, 255])[None, None, None]).all(axis=-1) - labels[fg_mask] = 1 - else: - assert labels_argb.ndim == 3 - labels = labels_argb - labels[labels == 255] = 1 - assert (np.unique(labels) == np.array([0, 1])).all() - assert raw.shape == labels.shape, f"{raw.shape}, {labels.shape}" - with h5py.File(out_path, "w") as f: - f.create_dataset("raw", data=raw, compression="gzip") - f.create_dataset("labels", data=labels.astype("uint8"), compression="gzip") - - -def _require_lucchi_data(path, download): - # download and unzip the data - if os.path.exists(path): - return path - - os.makedirs(path) - tmp_path = os.path.join(path, "lucchi.zip") - util.download_source(tmp_path, URL, download, checksum=CHECKSUM) - util.unzip(tmp_path, path, remove=True) - - root = os.path.join(path, "Lucchi++") - assert os.path.exists(root), root - - inputs = [["Test_In", "Test_Out"], ["Train_In", "Train_Out"]] - outputs = ["lucchi_train.h5", "lucchi_test.h5"] - for inp, out in zip(inputs, outputs): - out_path = os.path.join(path, out) - _create_data(root, inp, out_path) - - rmtree(root) - - -def get_lucchi_dataset(path, split, patch_shape, download=False, **kwargs): - """Dataset for the segmentation of mitochondria in EM. - - This dataset is from the publication https://doi.org/10.48550/arXiv.1812.06024. - Please cite it if you use this dataset for a publication. - """ - assert split in ("train", "test") - _require_lucchi_data(path, download) - data_path = os.path.join(path, f"lucchi_{split}.h5") - assert os.path.exists(data_path), data_path - raw_key, label_key = "raw", "labels" - return torch_em.default_segmentation_dataset(data_path, raw_key, data_path, label_key, patch_shape, **kwargs) - - -def get_lucchi_loader(path, split, patch_shape, batch_size, download=False, **kwargs): - """Dataloader for the segmentation of mitochondria in EM. 
See 'get_lucchi_dataset' for details""" - ds_kwargs, loader_kwargs = util.split_kwargs( - torch_em.default_segmentation_dataset, **kwargs - ) - dataset = get_lucchi_dataset(path, split, patch_shape, download=download, **ds_kwargs) - loader = torch_em.get_data_loader(dataset, batch_size, **loader_kwargs) - return loader diff --git a/torch_em/data/datasets/medical/__init__.py b/torch_em/data/datasets/medical/__init__.py index fc1ce8ce..0993e076 100644 --- a/torch_em/data/datasets/medical/__init__.py +++ b/torch_em/data/datasets/medical/__init__.py @@ -1,3 +1,43 @@ +from .acdc import get_acdc_dataset, get_acdc_loader +from .acouslic_ai import get_acouslic_ai_dataset, get_acouslic_ai_loader from .autopet import get_autopet_loader +from .amos import get_amos_dataset, get_amos_loader from .btcv import get_btcv_dataset, get_btcv_loader +from .busi import get_busi_dataset, get_busi_loader +from .camus import get_camus_dataset, get_camus_loader +from .cbis_ddsm import get_cbis_ddsm_dataset, get_cbis_ddsm_loader +from .cholecseg8k import get_cholecseg8k_dataset, get_cholecseg8k_loader +from .covid19_seg import get_covid19_seg_dataset, get_covid19_seg_loader +from .curvas import get_curvas_dataset, get_curvas_loader +from .dca1 import get_dca1_dataset, get_dca1_loader +from .drive import get_drive_dataset, get_drive_loader +from .duke_liver import get_duke_liver_dataset, get_duke_liver_loader +from .feta24 import get_feta24_dataset, get_feta24_loader +from .han_seg import get_han_seg_dataset, get_han_seg_loader +from .hil_toothseg import get_hil_toothseg_dataset, get_hil_toothseg_loader +from .idrid import get_idrid_dataset, get_idrid_loader +from .isic import get_isic_dataset, get_isic_loader from .isles import get_isles_dataset, get_isles_loader +from .jnuifm import get_jnuifm_dataset, get_jnuifm_loader +from .leg_3d_us import get_leg_3d_us_dataset, get_leg_3d_us_loader +from .lgg_mri import get_lgg_mri_dataset, get_lgg_mri_loader +from .m2caiseg import get_m2caiseg_dataset, get_m2caiseg_loader +from .mbh_seg import get_mbh_seg_dataset, get_mbh_seg_loader +from .micro_usp import get_micro_usp_dataset, get_micro_usp_loader +from .montgomery import get_montgomery_dataset, get_montgomery_loader +from .msd import get_msd_dataset, get_msd_loader +from .oasis import get_oasis_dataset, get_oasis_loader +from .oimhs import get_oimhs_dataset, get_oimhs_loader +from .osic_pulmofib import get_osic_pulmofib_dataset, get_osic_pulmofib_loader +from .panorama import get_panorama_dataset, get_panorama_loader +from .papila import get_papila_dataset, get_papila_loader +from .pengwin import get_pengwin_dataset, get_pengwin_loader +from .piccolo import get_piccolo_dataset, get_piccolo_loader +from .plethora import get_plethora_dataset, get_plethora_loader +from .sa_med2d import get_sa_med2d_dataset, get_sa_med2d_loader +from .sega import get_sega_dataset, get_sega_loader +from .segthy import get_segthy_dataset, get_segthy_loader +from .siim_acr import get_siim_acr_dataset, get_siim_acr_loader +from .spider import get_spider_dataset, get_spider_loader +from .toothfairy import get_toothfairy_dataset, get_toothfairy_loader +from .uwaterloo_skin import get_uwaterloo_skin_dataset, get_uwaterloo_skin_loader diff --git a/torch_em/data/datasets/medical/acdc.py b/torch_em/data/datasets/medical/acdc.py new file mode 100644 index 00000000..448708bb --- /dev/null +++ b/torch_em/data/datasets/medical/acdc.py @@ -0,0 +1,106 @@ +import os +from glob import glob +from natsort import natsorted +from typing import Union, Tuple + 
+import torch_em
+
+from .. import util
+from ... import ConcatDataset
+
+
+URL = "https://humanheart-project.creatis.insa-lyon.fr/database/api/v1/collection/637218c173e9f0047faa00fb/download"
+CHECKSUM = "2787e08b0d3525cbac710fc3bdf69ee7c5fd7446472e49db8bc78548802f6b5e"
+
+
+def get_acdc_data(path, download):
+    os.makedirs(path, exist_ok=True)
+
+    zip_path = os.path.join(path, "ACDC.zip")
+    trg_dir = os.path.join(path, "ACDC")
+    if os.path.exists(trg_dir):
+        return trg_dir
+
+    util.download_source(path=zip_path, url=URL, download=download, checksum=CHECKSUM)
+    util.unzip(zip_path=zip_path, dst=path, remove=False)
+    return trg_dir
+
+
+def _get_acdc_paths(path, split, download):
+    root_dir = get_acdc_data(path=path, download=download)
+
+    if split == "train":
+        input_dir = os.path.join(root_dir, "database", "training")
+    else:
+        input_dir = os.path.join(root_dir, "database", "testing")
+
+    all_patient_dirs = natsorted(glob(os.path.join(input_dir, "patient*")))
+
+    image_paths, gt_paths = [], []
+    for per_patient_dir in all_patient_dirs:
+        # the volumes with "frame" in their name correspond to particular time frames:
+        # end diastole (ED) and end systole (ES), which are the phase instances with manual segmentations.
+        all_volumes = glob(os.path.join(per_patient_dir, "*frame*.nii.gz"))
+        for vol_path in all_volumes:
+            sres = vol_path.find("gt")
+            if sres == -1:  # "gt" does not occur in the filename, hence it's the mri volume
+                image_paths.append(vol_path)
+            else:  # "gt" occurs in the filename, hence it's the ground truth volume
+                gt_paths.append(vol_path)
+
+    return natsorted(image_paths), natsorted(gt_paths)
+
+
+def get_acdc_dataset(
+    path: Union[os.PathLike, str],
+    split: str,
+    patch_shape: Tuple[int, int],
+    download: bool = False,
+    **kwargs
+):
+    """Dataset for multi-structure segmentation in cardiac MRI.
+
+    The labels have the following mapping:
+    - 0 (background), 1 (right ventricle cavity), 2 (myocardium), 3 (left ventricle cavity)
+
+    The database is located at
+    https://humanheart-project.creatis.insa-lyon.fr/database/#collection/637218c173e9f0047faa00fb
+
+    The dataset is from the publication https://doi.org/10.1109/TMI.2018.2837502
+
+    Please cite it if you use this dataset for a publication.
+    """
+    assert split in ["train", "test"], f"{split} is not a valid split."
+
+    image_paths, gt_paths = _get_acdc_paths(path=path, split=split, download=download)
+
+    all_datasets = []
+    for image_path, gt_path in zip(image_paths, gt_paths):
+        per_vol_ds = torch_em.default_segmentation_dataset(
+            raw_paths=image_path,
+            raw_key="data",
+            label_paths=gt_path,
+            label_key="data",
+            patch_shape=patch_shape,
+            is_seg_dataset=True,
+            **kwargs
+        )
+        all_datasets.append(per_vol_ds)
+
+    return ConcatDataset(*all_datasets)
+
+
+def get_acdc_loader(
+    path: Union[os.PathLike, str],
+    split: str,
+    patch_shape: Tuple[int, int],
+    batch_size: int,
+    download: bool = False,
+    **kwargs
+):
+    """Dataloader for multi-structure segmentation in cardiac MRI. See `get_acdc_dataset` for details.
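+
+    Example:
+        A minimal usage sketch; the path and patch shape below are illustrative placeholders:
+
+        >>> loader = get_acdc_loader("./data/acdc", split="train", patch_shape=(4, 256, 256), batch_size=2, download=True)
+        >>> raw, labels = next(iter(loader))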
+    """
+    ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
+    dataset = get_acdc_dataset(path=path, split=split, patch_shape=patch_shape, download=download, **ds_kwargs)
+    loader = torch_em.get_data_loader(dataset=dataset, batch_size=batch_size, **loader_kwargs)
+    return loader
diff --git a/torch_em/data/datasets/medical/acouslic_ai.py b/torch_em/data/datasets/medical/acouslic_ai.py
new file mode 100644
index 00000000..26c6db45
--- /dev/null
+++ b/torch_em/data/datasets/medical/acouslic_ai.py
@@ -0,0 +1,129 @@
+"""The Acouslic AI dataset contains annotations for fetal segmentation
+in ultrasound images.
+
+This dataset is from the challenge: https://acouslic-ai.grand-challenge.org/.
+Please cite the challenge if you use this dataset for your publication.
+"""
+
+import os
+from glob import glob
+from natsort import natsorted
+from typing import Tuple, Union, List
+
+from torch.utils.data import Dataset, DataLoader
+
+import torch_em
+
+from .. import util
+
+
+URL = "https://zenodo.org/records/11005384/files/acouslic-ai-train-set.zip"
+CHECKSUM = "187602dd243a3a872502b57b8ea56e28c67a9ded547b6e816b00c6d41f8b8767"
+
+
+def get_acouslic_ai_data(path: Union[os.PathLike, str], download: bool = False) -> str:
+    """Download the Acouslic AI dataset.
+
+    Args:
+        path: Filepath to a folder where the data is downloaded for further processing.
+        download: Whether to download the data if it is not present.
+
+    Returns:
+        Filepath where the data is downloaded.
+    """
+    os.makedirs(path, exist_ok=True)
+
+    data_dir = os.path.join(path, "data")
+    if os.path.exists(data_dir):
+        return data_dir
+
+    zip_path = os.path.join(path, "acouslic-ai-train-set.zip")
+    util.download_source(path=zip_path, url=URL, download=download, checksum=CHECKSUM)
+    util.unzip(zip_path=zip_path, dst=data_dir, remove=False)
+
+    return data_dir
+
+
+def get_acouslic_ai_paths(path: Union[os.PathLike, str], download: bool = False) -> Tuple[List[str], List[str]]:
+    """Get paths to the Acouslic AI data.
+
+    Args:
+        path: Filepath to a folder where the data is downloaded for further processing.
+        download: Whether to download the data if it is not present.
+
+    Returns:
+        List of filepaths for the image data.
+        List of filepaths for the label data.
+    """
+    data_dir = get_acouslic_ai_data(path=path, download=download)
+
+    image_paths = natsorted(glob(os.path.join(data_dir, "images", "stacked_fetal_ultrasound", "*.mha")))
+    gt_paths = natsorted(glob(os.path.join(data_dir, "masks", "stacked_fetal_abdomen", "*.mha")))
+
+    return image_paths, gt_paths
+
+
+def get_acouslic_ai_dataset(
+    path: Union[os.PathLike, str],
+    patch_shape: Tuple[int, ...],
+    resize_inputs: bool = False,
+    download: bool = False,
+    **kwargs
+) -> Dataset:
+    """Get the Acouslic AI dataset for fetal segmentation.
+
+    Args:
+        path: Filepath to a folder where the data is downloaded for further processing.
+        patch_shape: The patch shape to use for training.
+        resize_inputs: Whether to resize inputs to the desired patch shape.
+        download: Whether to download the data if it is not present.
+        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.
+
+    Returns:
+        The segmentation dataset.
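+
+    Example:
+        A minimal usage sketch; the path and patch shape below are illustrative placeholders:
+
+        >>> dataset = get_acouslic_ai_dataset("./data/acouslic_ai", patch_shape=(1, 512, 512), download=True)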
+ """ + image_paths, gt_paths = get_acouslic_ai_paths(path=path, download=download) + + if resize_inputs: + resize_kwargs = {"patch_shape": patch_shape, "is_rgb": False} + kwargs, patch_shape = util.update_kwargs_for_resize_trafo( + kwargs=kwargs, patch_shape=patch_shape, resize_inputs=resize_inputs, resize_kwargs=resize_kwargs + ) + + return torch_em.default_segmentation_dataset( + raw_paths=image_paths, + raw_key=None, + label_paths=gt_paths, + label_key=None, + patch_shape=patch_shape, + **kwargs + ) + + +def get_acouslic_ai_loader( + path: Union[os.PathLike, str], + patch_shape: Tuple[int, ...], + batch_size: int, + resize_inputs: bool = False, + download: bool = False, + **kwargs +) -> DataLoader: + """Get the Acouslic AI dataloader for fetal segmentation. + + Args: + path: Filepath to a folder where the data is downloaded for further processing. + patch_shape: The patch shape to use for training. + batch_size: The batch size for training. + resize_inputs: Whether to resize inputs to the desired patch shape. + download: Whether to download the data if it is not present. + kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader. + + Returns: + The DataLoader. + """ + ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs) + dataset = get_acouslic_ai_dataset( + path=path, patch_shape=patch_shape, resize_inputs=resize_inputs, download=download, **ds_kwargs + ) + loader = torch_em.get_data_loader(dataset=dataset, batch_size=batch_size, **loader_kwargs) + return loader diff --git a/torch_em/data/datasets/medical/amos.py b/torch_em/data/datasets/medical/amos.py new file mode 100644 index 00000000..299617f9 --- /dev/null +++ b/torch_em/data/datasets/medical/amos.py @@ -0,0 +1,132 @@ +import os +from glob import glob +from pathlib import Path +from typing import Union, Tuple, Optional + +import torch_em + +from .. 
import util + + +URL = "https://zenodo.org/records/7155725/files/amos22.zip" +CHECKSUM = "d2fbf2c31abba9824d183f05741ce187b17905b8cca64d1078eabf1ba96775c2" + + +def get_amos_data(path, download, remove_zip=False): + os.makedirs(path, exist_ok=True) + + data_dir = os.path.join(path, "amos22") + if os.path.exists(data_dir): + return data_dir + + zip_path = os.path.join(path, "amos22.zip") + + util.download_source(path=zip_path, url=URL, download=download, checksum=CHECKSUM) + util.unzip(zip_path=zip_path, dst=path, remove=remove_zip) + + return data_dir + + +def _get_amos_paths(path, split, modality, download): + data_dir = get_amos_data(path=path, download=download) + + if split == "train": + im_dir, gt_dir = "imagesTr", "labelsTr" + elif split == "val": + im_dir, gt_dir = "imagesVa", "labelsVa" + elif split == "test": + im_dir, gt_dir = "imagesTs", "labelsTs" + else: + raise ValueError(f"'{split}' is not a valid split.") + + image_paths = sorted(glob(os.path.join(data_dir, im_dir, "*.nii.gz"))) + gt_paths = sorted(glob(os.path.join(data_dir, gt_dir, "*.nii.gz"))) + + if modality is None: + chosen_image_paths, chosen_gt_paths = image_paths, gt_paths + else: + ct_image_paths, ct_gt_paths = [], [] + mri_image_paths, mri_gt_paths = [], [] + for image_path, gt_path in zip(image_paths, gt_paths): + patient_id = Path(image_path.split(".")[0]).stem + id_value = int(patient_id.split("_")[-1]) + + is_ct = id_value < 500 + + if is_ct: + ct_image_paths.append(image_path) + ct_gt_paths.append(gt_path) + else: + mri_image_paths.append(image_path) + mri_gt_paths.append(gt_path) + + if modality.upper() == "CT": + chosen_image_paths, chosen_gt_paths = ct_image_paths, ct_gt_paths + elif modality.upper() == "MRI": + chosen_image_paths, chosen_gt_paths = mri_image_paths, mri_gt_paths + else: + raise ValueError(f"'{modality}' is not a valid modality.") + + return chosen_image_paths, chosen_gt_paths + + +def get_amos_dataset( + path: Union[os.PathLike, str], + split: str, + patch_shape: Tuple[int, ...], + modality: Optional[str] = None, + resize_inputs: bool = False, + download: bool = False, + **kwargs +): + """Dataset for abdominal multi-organ segmentation in CT and MRI scans. + + The database is located at https://doi.org/10.5281/zenodo.7155725 + The dataset is from AMOS 2022 Challenge - https://doi.org/10.48550/arXiv.2206.08023 + Please cite it if you use this dataset for publication. + """ + image_paths, gt_paths = _get_amos_paths(path=path, split=split, modality=modality, download=download) + + if resize_inputs: + resize_kwargs = {"patch_shape": patch_shape, "is_rgb": False} + kwargs, patch_shape = util.update_kwargs_for_resize_trafo( + kwargs=kwargs, patch_shape=patch_shape, resize_inputs=resize_inputs, resize_kwargs=resize_kwargs + ) + + dataset = torch_em.default_segmentation_dataset( + raw_paths=image_paths, + raw_key="data", + label_paths=gt_paths, + label_key="data", + patch_shape=patch_shape, + is_seg_dataset=True, + **kwargs + ) + + return dataset + + +def get_amos_loader( + path: Union[os.PathLike, str], + split: str, + patch_shape: Tuple[int, ...], + batch_size: int, + modality: Optional[str] = None, + resize_inputs: bool = False, + download: bool = False, + **kwargs +): + """Dataloader for abdominal multi-organ segmentation in CT and MRI scans. See `get_amos_dataset` for details. 
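+
+    Example:
+        A minimal usage sketch; the path and patch shape below are illustrative placeholders:
+
+        >>> loader = get_amos_loader("./data/amos", split="train", patch_shape=(32, 256, 256), batch_size=1, modality="CT", download=True)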
+    """
+    ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
+    dataset = get_amos_dataset(
+        path=path,
+        split=split,
+        patch_shape=patch_shape,
+        modality=modality,
+        resize_inputs=resize_inputs,
+        download=download,
+        **ds_kwargs
+    )
+    loader = torch_em.get_data_loader(dataset=dataset, batch_size=batch_size, **loader_kwargs)
+    return loader
diff --git a/torch_em/data/datasets/medical/busi.py b/torch_em/data/datasets/medical/busi.py
new file mode 100644
index 00000000..f529df1a
--- /dev/null
+++ b/torch_em/data/datasets/medical/busi.py
@@ -0,0 +1,105 @@
+import os
+from glob import glob
+from typing import Union, Tuple, Optional
+
+import torch_em
+from torch_em.transform.generic import ResizeInputs
+
+from .. import util
+from ... import ImageCollectionDataset
+
+
+URL = "https://scholar.cu.edu.eg/Dataset_BUSI.zip"
+CHECKSUM = "b2ce09f6063a31a73f628b6a6ee1245187cbaec225e93e563735691d68654de7"
+
+
+def get_busi_data(path, download):
+    os.makedirs(path, exist_ok=True)
+
+    data_dir = os.path.join(path, "Dataset_BUSI_with_GT")
+    if os.path.exists(data_dir):
+        return data_dir
+
+    zip_path = os.path.join(path, "Dataset_BUSI.zip")
+    util.download_source(path=zip_path, url=URL, download=download, checksum=CHECKSUM, verify=False)
+    util.unzip(zip_path=zip_path, dst=path)
+
+    return data_dir
+
+
+def _get_busi_paths(path, category, download):
+    data_dir = get_busi_data(path=path, download=download)
+
+    if category is None:
+        category = "*"
+
+    data_dir = os.path.join(data_dir, category)
+
+    image_paths = sorted(glob(os.path.join(data_dir, r"*).png")))
+    gt_paths = sorted(glob(os.path.join(data_dir, r"*)_mask.png")))
+
+    return image_paths, gt_paths
+
+
+def get_busi_dataset(
+    path: Union[os.PathLike, str],
+    patch_shape: Tuple[int, int],
+    category: Optional[str] = None,
+    resize_inputs: bool = False,
+    download: bool = False,
+    **kwargs
+):
+    """Dataset for segmentation of breast cancer in ultrasound images.
+
+    This database is located at https://scholar.cu.edu.eg/?q=afahmy/pages/dataset
+
+    The dataset is from Al-Dhabyani et al. - https://doi.org/10.1016/j.dib.2019.104863
+    Please cite it if you use this dataset for a publication.
+    """
+    if category is not None:
+        assert category in ["normal", "benign", "malignant"]
+
+    image_paths, gt_paths = _get_busi_paths(path=path, category=category, download=download)
+
+    if resize_inputs:
+        raw_trafo = ResizeInputs(target_shape=patch_shape, is_rgb=True)
+        label_trafo = ResizeInputs(target_shape=patch_shape, is_label=True)
+        patch_shape = None
+    else:
+        raw_trafo, label_trafo = None, None
+
+    dataset = ImageCollectionDataset(
+        raw_image_paths=image_paths,
+        label_image_paths=gt_paths,
+        patch_shape=patch_shape,
+        raw_transform=raw_trafo,
+        label_transform=label_trafo,
+        **kwargs
+    )
+
+    return dataset
+
+
+def get_busi_loader(
+    path: Union[os.PathLike, str],
+    patch_shape: Tuple[int, int],
+    batch_size: int,
+    category: Optional[str] = None,
+    resize_inputs: bool = False,
+    download: bool = False,
+    **kwargs
+):
+    """Dataloader for segmentation of breast cancer in ultrasound images. See `get_busi_dataset` for details.
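+
+    Example:
+        A minimal usage sketch; the path and patch shape below are illustrative placeholders:
+
+        >>> loader = get_busi_loader("./data/busi", patch_shape=(512, 512), batch_size=4, category="malignant", download=True)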
+ """ + ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs) + dataset = get_busi_dataset( + path=path, + patch_shape=patch_shape, + category=category, + resize_inputs=resize_inputs, + download=download, + **ds_kwargs + ) + loader = torch_em.get_data_loader(dataset=dataset, batch_size=batch_size, **loader_kwargs) + return loader diff --git a/torch_em/data/datasets/medical/camus.py b/torch_em/data/datasets/medical/camus.py new file mode 100644 index 00000000..ff360a9e --- /dev/null +++ b/torch_em/data/datasets/medical/camus.py @@ -0,0 +1,97 @@ +import os +from glob import glob +from typing import Union, Tuple, Optional + +import torch_em + +from .. import util + + +URL = "https://humanheart-project.creatis.insa-lyon.fr/database/api/v1/folder/63fde55f73e9f004868fb7ac/download" + +# TODO: the checksums are different with each download, not sure why +# CHECKSUM = "43745d640db5d979332bda7f00f4746747a2591b46efc8f1966b573ce8d65655" + + +def get_camus_data(path, download): + os.makedirs(path, exist_ok=True) + + data_dir = os.path.join(path, "database_nifti") + if os.path.exists(data_dir): + return data_dir + + zip_path = os.path.join(path, "CAMUS.zip") + util.download_source(path=zip_path, url=URL, download=download, checksum=None) + util.unzip(zip_path=zip_path, dst=path) + + return data_dir + + +def _get_camus_paths(path, chamber, download): + data_dir = get_camus_data(path=path, download=download) + + if chamber is None: + chamber = "*" # 2CH / 4CH + else: + assert chamber in [2, 4], f"{chamber} is not a valid chamber choice for the acquisitions." + chamber = f"{chamber}CH" + + image_paths = sorted(glob(os.path.join(data_dir, "patient*", f"patient*_{chamber}_half_sequence.nii.gz"))) + gt_paths = sorted(glob(os.path.join(data_dir, "patient*", f"patient*_{chamber}_half_sequence_gt.nii.gz"))) + + return image_paths, gt_paths + + +def get_camus_dataset( + path: Union[os.PathLike, str], + patch_shape: Tuple[int, int], + chamber: Optional[int] = None, + resize_inputs: bool = False, + download: bool = False, + **kwargs +): + """Dataset for segmenting cardiac structures in 2d echocardiography images. + + The database is located at: + https://humanheart-project.creatis.insa-lyon.fr/database/#collection/6373703d73e9f0047faa1bc8 + + This dataset is from the CAMUS challenge - https://doi.org/10.1109/TMI.2019.2900516. + Please cite it if you use this dataset for a publication. + """ + image_paths, gt_paths = _get_camus_paths(path=path, chamber=chamber, download=download) + + if resize_inputs: + resize_kwargs = {"patch_shape": patch_shape, "is_rgb": False} + kwargs, patch_shape = util.update_kwargs_for_resize_trafo( + kwargs=kwargs, patch_shape=patch_shape, resize_inputs=resize_inputs, resize_kwargs=resize_kwargs + ) + + dataset = torch_em.default_segmentation_dataset( + raw_paths=image_paths, + raw_key="data", + label_paths=gt_paths, + label_key="data", + patch_shape=patch_shape, + **kwargs + ) + + return dataset + + +def get_camus_loader( + path: Union[os.PathLike, str], + patch_shape: Tuple[int, int], + batch_size: int, + chamber: Optional[int] = None, + resize_inputs: bool = False, + download: bool = False, + **kwargs +): + """Dataloader for segmenting cardiac structures in 2d echocardiography images. 
See `get_camus_dataset` for details + """ + ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs) + dataset = get_camus_dataset( + path=path, patch_shape=patch_shape, chamber=chamber, resize_inputs=resize_inputs, download=download, **ds_kwargs + ) + loader = torch_em.get_data_loader(dataset=dataset, batch_size=batch_size, **loader_kwargs) + return loader diff --git a/torch_em/data/datasets/medical/cbis_ddsm.py b/torch_em/data/datasets/medical/cbis_ddsm.py new file mode 100644 index 00000000..81a5830b --- /dev/null +++ b/torch_em/data/datasets/medical/cbis_ddsm.py @@ -0,0 +1,120 @@ +import os +from glob import glob +from natsort import natsorted +from typing import Union, Tuple, Literal, Optional + +import torch_em + +from .. import util + + +def get_cbis_ddsm_data(path, split, task, tumour_type, download): + os.makedirs(path, exist_ok=True) + + assert split in ["Train", "Test"] + + if task is None: + task = "*" + else: + assert task in ["Calc", "Mass"] + + if tumour_type is None: + tumour_type = "*" + else: + assert tumour_type in ["MALIGNANT", "BENIGN"] + + data_dir = os.path.join(path, "DATA") + if os.path.exists(data_dir): + return os.path.join(path, "DATA", task, split, tumour_type) + + util.download_source_kaggle( + path=path, dataset_name="mohamedbenticha/cbis-ddsm/", download=download, + ) + zip_path = os.path.join(path, "cbis-ddsm.zip") + util.unzip(zip_path=zip_path, dst=path) + return os.path.join(path, "DATA", task, split, tumour_type) + + +def _get_cbis_ddsm_paths(path, split, task, tumour_type, download): + data_dir = get_cbis_ddsm_data( + path=path, + split=split, + task=task, + tumour_type=tumour_type, + download=download + ) + + image_paths = natsorted(glob(os.path.join(data_dir, "*_FULL_*.png"))) + gt_paths = natsorted(glob(os.path.join(data_dir, "*_MASK_*.png"))) + + assert len(image_paths) == len(gt_paths) + + return image_paths, gt_paths + + +def get_cbis_ddsm_dataset( + path: Union[os.PathLike, str], + patch_shape: Tuple[int, int], + split: Literal["Train", "Test"], + task: Optional[Literal["Calc", "Mass"]] = None, + tumour_type: Optional[Literal["MALIGNANT", "BENIGN"]] = None, + resize_inputs: bool = False, + download: bool = False, + **kwargs +): + """Dataset for segmentation of calcification and mass in mammography. + + This dataset is a preprocessed version of https://www.cancerimagingarchive.net/collection/cbis-ddsm/ available + at https://www.kaggle.com/datasets/mohamedbenticha/cbis-ddsm/data. The related publication is: + - https://doi.org/10.1038/sdata.2017.177 + + Please cite it if you use this dataset in a publication. 
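+
+    Example:
+        A minimal usage sketch (downloading requires Kaggle API credentials; the path and patch shape are placeholders):
+
+        >>> dataset = get_cbis_ddsm_dataset("./data/cbis_ddsm", patch_shape=(512, 512), split="Train", task="Mass", download=True)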
+ """ + image_paths, gt_paths = _get_cbis_ddsm_paths( + path=path, split=split, task=task, tumour_type=tumour_type, download=download + ) + + if resize_inputs: + resize_kwargs = {"patch_shape": patch_shape, "is_rgb": False} + kwargs, patch_shape = util.update_kwargs_for_resize_trafo( + kwargs=kwargs, patch_shape=patch_shape, resize_inputs=resize_inputs, resize_kwargs=resize_kwargs + ) + + dataset = torch_em.default_segmentation_dataset( + raw_paths=image_paths, + raw_key=None, + label_paths=gt_paths, + label_key=None, + patch_shape=patch_shape, + is_seg_dataset=False, + **kwargs + ) + return dataset + + +def get_cbis_ddsm_loader( + path: Union[os.PathLike, str], + patch_shape: Tuple[int, int], + batch_size: int, + split: Literal["Train", "Test"], + task: Optional[Literal["Calc", "Mass"]] = None, + tumour_type: Optional[Literal["MALIGNANT", "BENIGN"]] = None, + resize_inputs: bool = False, + download: bool = False, + **kwargs +): + """Dataloader for segmentation of calcification and mass in mammography. See `get_cbis_ddsm_dataset` for details. + """ + ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs) + dataset = get_cbis_ddsm_dataset( + path=path, + patch_shape=patch_shape, + split=split, + task=task, + tumour_type=tumour_type, + resize_inputs=resize_inputs, + download=download, + **ds_kwargs + ) + loader = torch_em.get_data_loader(dataset=dataset, batch_size=batch_size, **loader_kwargs) + return loader diff --git a/torch_em/data/datasets/medical/cholecseg8k.py b/torch_em/data/datasets/medical/cholecseg8k.py new file mode 100644 index 00000000..f0dfaa2e --- /dev/null +++ b/torch_em/data/datasets/medical/cholecseg8k.py @@ -0,0 +1,159 @@ +import os +import shutil +from glob import glob +from tqdm import tqdm +from pathlib import Path +from natsort import natsorted +from typing import Tuple, Union, Literal + +import numpy as np +import imageio.v3 as imageio + +import torch_em + +from .. 
import util
+
+
+LABEL_MAPS = {
+    (255, 255, 255): 0,  # small white frame around the image
+    (50, 50, 50): 0,  # background
+    (11, 11, 11): 1,  # abdominal wall
+    (21, 21, 21): 2,  # liver
+    (13, 13, 13): 3,  # gastrointestinal tract
+    (12, 12, 12): 4,  # fat
+    (31, 31, 31): 5,  # grasper
+    (23, 23, 23): 6,  # connective tissue
+    (24, 24, 24): 7,  # blood
+    (25, 25, 25): 8,  # cystic duct
+    (32, 32, 32): 9,  # L-hook electrocautery
+    (22, 22, 22): 10,  # gallbladder
+    (33, 33, 33): 11,  # hepatic vein
+    (5, 5, 5): 12  # liver ligament
+}
+
+
+def get_cholecseg8k_data(path, download):
+    os.makedirs(path, exist_ok=True)
+
+    data_dir = os.path.join(path, "data")
+    if os.path.exists(data_dir):
+        return data_dir
+
+    zip_path = os.path.join(path, "cholecseg8k.zip")
+    util.download_source_kaggle(path=zip_path, dataset_name="newslab/cholecseg8k", download=download)
+    util.unzip(zip_path=zip_path, dst=data_dir)
+    return data_dir
+
+
+def _get_cholecseg8k_paths(path, split, download):
+    data_dir = get_cholecseg8k_data(path=path, download=download)
+
+    video_dirs = natsorted(glob(os.path.join(data_dir, "video*")))
+    if split == "train":
+        video_dirs = video_dirs[2:-2]
+    elif split == "val":
+        video_dirs = [video_dirs[1], video_dirs[-2]]
+    elif split == "test":
+        video_dirs = [video_dirs[0], video_dirs[-1]]
+    else:
+        raise ValueError(f"'{split}' is not a valid split.")
+
+    ppdir = os.path.join(data_dir, "preprocessed", split)
+    if os.path.exists(ppdir):
+        _image_paths = natsorted(glob(os.path.join(ppdir, "images", "*")))
+        _gt_paths = natsorted(glob(os.path.join(ppdir, "masks", "*")))
+        return _image_paths, _gt_paths
+
+    os.makedirs(os.path.join(ppdir, "images"), exist_ok=True)
+    os.makedirs(os.path.join(ppdir, "masks"), exist_ok=True)
+
+    image_paths, gt_paths = [], []
+    for video_dir in tqdm(video_dirs):
+        org_image_paths = natsorted(glob(os.path.join(video_dir, "video*", "*_endo.png")))
+        org_gt_paths = natsorted(glob(os.path.join(video_dir, "video*", "*_endo_watershed_mask.png")))
+
+        for org_image_path, org_gt_path in zip(org_image_paths, org_gt_paths):
+            image_id = os.path.split(org_image_path)[-1]
+
+            image_path = os.path.join(ppdir, "images", image_id)
+            gt_path = os.path.join(ppdir, "masks", Path(image_id).with_suffix(".tif"))
+
+            image_paths.append(image_path)
+            gt_paths.append(gt_path)
+
+            if os.path.exists(image_path) and os.path.exists(gt_path):
+                continue
+
+            gt = imageio.imread(org_gt_path)
+            assert gt.ndim == 3
+            if gt.shape[-1] != 3:  # some labels have a 4th channel which has all values as 255
+                print("Found a label with inconsistent format.")
+                # let's verify the case
+                assert np.unique(gt[..., -1]) == 255
+                gt = gt[..., :3]
+
+            instances = np.zeros(gt.shape[:2])
+            for lmap in LABEL_MAPS:
+                binary_map = (gt == lmap).all(axis=2)
+                instances[binary_map > 0] = LABEL_MAPS[lmap]
+
+            shutil.copy(src=org_image_path, dst=image_path)
+            imageio.imwrite(gt_path, instances, compression="zlib")
+
+    return image_paths, gt_paths
+
+
+def get_cholecseg8k_dataset(
+    path: Union[str, os.PathLike],
+    patch_shape: Tuple[int, int],
+    split: Literal["train", "val", "test"],
+    resize_inputs: bool = False,
+    download: bool = False,
+    **kwargs
+):
+    """Dataset for segmentation of organs and instruments in endoscopy.
+
+    This dataset is from Twinanda et al. - https://doi.org/10.48550/arXiv.1602.03012
+
+    This dataset is located at https://www.kaggle.com/datasets/newslab/cholecseg8k/data
+
+    Please cite it if you use this data in a publication.
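+
+    Example:
+        A minimal usage sketch (downloading requires Kaggle API credentials; the path and patch shape are placeholders):
+
+        >>> dataset = get_cholecseg8k_dataset("./data/cholecseg8k", patch_shape=(480, 854), split="train", download=True)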
+    """
+    image_paths, gt_paths = _get_cholecseg8k_paths(path=path, split=split, download=download)
+
+    if resize_inputs:
+        resize_kwargs = {"patch_shape": patch_shape, "is_rgb": True}
+        kwargs, patch_shape = util.update_kwargs_for_resize_trafo(
+            kwargs=kwargs, patch_shape=patch_shape, resize_inputs=resize_inputs, resize_kwargs=resize_kwargs
+        )
+
+    dataset = torch_em.default_segmentation_dataset(
+        raw_paths=image_paths,
+        raw_key=None,
+        label_paths=gt_paths,
+        label_key=None,
+        is_seg_dataset=False,
+        patch_shape=patch_shape,
+        **kwargs
+    )
+
+    return dataset
+
+
+def get_cholecseg8k_loader(
+    path: Union[str, os.PathLike],
+    patch_shape: Tuple[int, int],
+    batch_size: int,
+    split: Literal["train", "val", "test"],
+    resize_inputs: bool = False,
+    download: bool = False,
+    **kwargs
+):
+    """Dataloader for segmentation of organs and instruments in endoscopy. See `get_cholecseg8k_dataset` for details.
+    """
+    ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
+    dataset = get_cholecseg8k_dataset(
+        path=path, patch_shape=patch_shape, split=split, resize_inputs=resize_inputs, download=download, **ds_kwargs
+    )
+    loader = torch_em.get_data_loader(dataset=dataset, batch_size=batch_size, **loader_kwargs)
+    return loader
diff --git a/torch_em/data/datasets/medical/covid19_seg.py b/torch_em/data/datasets/medical/covid19_seg.py
new file mode 100644
index 00000000..65ecc354
--- /dev/null
+++ b/torch_em/data/datasets/medical/covid19_seg.py
@@ -0,0 +1,118 @@
+import os
+from glob import glob
+from pathlib import Path
+from typing import Union, Tuple, Optional
+
+import torch_em
+
+from .. import util
+
+
+URL = {
+    "images": "https://zenodo.org/records/3757476/files/COVID-19-CT-Seg_20cases.zip",
+    "lung_and_infection": "https://zenodo.org/records/3757476/files/Lung_and_Infection_Mask.zip",
+    "lung": "https://zenodo.org/records/3757476/files/Lung_Mask.zip",
+    "infection": "https://zenodo.org/records/3757476/files/Infection_Mask.zip"
+}
+
+CHECKSUM = {
+    "images": "a5060480eff9315b069b086312dac4872777901fb80d268a5a83edd9f4e7b440",
+    "lung_and_infection": "34f5a573cb8fb53cb15abe81868395d9addf436854826a6fd6e70c2b294f19c3",
+    "lung": "f060b0d0299939a6d95ddefdbfa281de1a779c4d230a5adbd32414711d6d8187",
+    "infection": "87901c73fdd2230260e61d2dbc57bf56026efc28264006b8ea2bf411453c1694"
+}
+
+ZIP_FNAMES = {
+    "images": "COVID-19-CT-Seg_20cases.zip",
+    "lung_and_infection": "Lung_and_Infection_Mask.zip",
+    "lung": "Lung_Mask.zip",
+    "infection": "Infection_Mask.zip"
+}
+
+
+def get_covid19_seg_data(path, task, download):
+    os.makedirs(path, exist_ok=True)
+
+    im_dir = os.path.join(path, "images", Path(ZIP_FNAMES["images"]).stem)
+    gt_dir = os.path.join(path, "gt", Path(ZIP_FNAMES[task]).stem)
+
+    if os.path.exists(im_dir) and os.path.exists(gt_dir):
+        return im_dir, gt_dir
+
+    im_zip_path = os.path.join(path, ZIP_FNAMES["images"])
+    gt_zip_path = os.path.join(path, ZIP_FNAMES[task])
+
+    # download the images
+    util.download_source(
+        path=im_zip_path, url=URL["images"], download=download, checksum=CHECKSUM["images"]
+    )
+    util.unzip(zip_path=im_zip_path, dst=im_dir, remove=False)
+
+    # download the gt
+    util.download_source(
+        path=gt_zip_path, url=URL[task], download=download, checksum=CHECKSUM[task]
+    )
+    util.unzip(zip_path=gt_zip_path, dst=gt_dir)
+
+    return im_dir, gt_dir
+
+
+def _get_covid19_seg_paths(path, task, download):
+    image_dir, gt_dir = get_covid19_seg_data(path=path, task=task, download=download)
+
+    image_paths = sorted(glob(os.path.join(image_dir, "*.nii.gz")))
"*.nii.gz"))) + gt_paths = sorted(glob(os.path.join(gt_dir, "*.nii.gz"))) + + return image_paths, gt_paths + + +def get_covid19_seg_dataset( + path: Union[os.PathLike, str], + patch_shape: Tuple[int, int], + task: Optional[str] = None, + download: bool = False, + **kwargs +): + """Dataset for lung and covid infection segmentation in CT scans. + + The database is located at https://doi.org/10.5281/zenodo.3757476. + + This dataset is from Ma et al. - https://doi.org/10.1002/mp.14676. + Please cite it if you use this dataset for a publication. + """ + if task is None: + task = "lung_and_infection" + else: + assert task in ["lung", "infection", "lung_and_infection"], f"{task} is not a valid task." + + image_paths, gt_paths = _get_covid19_seg_paths(path, task, download) + + dataset = torch_em.default_segmentation_dataset( + raw_paths=image_paths, + raw_key="data", + label_paths=gt_paths, + label_key="data", + patch_shape=patch_shape, + is_seg_dataset=True, + **kwargs + ) + + return dataset + + +def get_covid19_seg_loader( + path: Union[os.PathLike, str], + batch_size: int, + patch_shape: Tuple[int, int], + task: Optional[str] = None, + download: bool = False, + **kwargs +): + """Dataloader for lung and covid infection segmentation in CT scans. See `get_covid19_seg_dataset` for details. + """ + ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs) + dataset = get_covid19_seg_dataset( + path=path, patch_shape=patch_shape, task=task, download=download, **ds_kwargs + ) + loader = torch_em.get_data_loader(dataset=dataset, batch_size=batch_size, **loader_kwargs) + return loader diff --git a/torch_em/data/datasets/medical/curvas.py b/torch_em/data/datasets/medical/curvas.py new file mode 100644 index 00000000..521e16e1 --- /dev/null +++ b/torch_em/data/datasets/medical/curvas.py @@ -0,0 +1,146 @@ +"""The CURVAS dataset contains annotations for pancreas, kidney and liver +in abdominal CT scans. + +This dataset is from the challenge: https://curvas.grand-challenge.org. +The dataset is located at: https://zenodo.org/records/12687192. +Please cite tem if you use this dataset for your research. +""" + +import os +import subprocess +from glob import glob +from natsort import natsorted +from typing import Tuple, Union, Literal, List + +import torch_em + +from .. import util + + +URL = "https://zenodo.org/records/12687192/files/training_set.zip" +CHECKSUM = "1126a2205553ae1d4fe5fbaee7ea732aacc4f5a92b96504ed521c23e5a0e3f89" + + +def get_curvas_data(path: Union[os.PathLike, str], download: bool = False) -> str: + """Download the CURVAS dataset. + + Args: + path: Filepath to a folder where the data is downloaded for further processing. + download: Whether to download the data if it is not present. + + Returns: + Filepath where the data is downloaded. + """ + data_dir = os.path.join(path, "training_set") + if os.path.exists(data_dir): + return data_dir + + os.makedirs(path, exist_ok=True) + + zip_path = os.path.join(path, "training_set.zip") + util.download_source(path=zip_path, url=URL, download=download, checksum=CHECKSUM) + + # HACK: The zip file is broken. We fix it using the following script. 
+    fixed_zip_path = os.path.join(path, "training_set_fixed.zip")
+    subprocess.run(["zip", "-FF", zip_path, "--out", fixed_zip_path])
+    subprocess.run(["unzip", fixed_zip_path, "-d", path])
+
+    return data_dir
+
+
+def get_curvas_paths(
+    path: Union[os.PathLike, str], rater: Literal["1"] = "1", download: bool = False
+) -> Tuple[List[str], List[str]]:
+    """Get paths to the CURVAS data.
+
+    Args:
+        path: Filepath to a folder where the data is downloaded for further processing.
+        rater: The choice of rater providing the annotations.
+        download: Whether to download the data if it is not present.
+
+    Returns:
+        List of filepaths for the image data.
+        List of filepaths for the label data.
+    """
+    data_dir = get_curvas_data(path=path, download=download)
+
+    if not isinstance(rater, list):
+        rater = [rater]
+
+    assert len(rater) == 1, "Segmentations from multiple raters are not supported at the moment."
+
+    image_paths = natsorted(glob(os.path.join(data_dir, "*", "image.nii.gz")))
+    gt_paths = []
+    for _rater in rater:
+        gt_paths.extend(natsorted(glob(os.path.join(data_dir, "*", f"annotation_{_rater}.nii.gz"))))
+
+    return image_paths, gt_paths
+
+
+def get_curvas_dataset(
+    path: Union[os.PathLike, str],
+    patch_shape: Tuple[int, ...],
+    rater: Literal["1"] = "1",
+    resize_inputs: bool = False,
+    download: bool = False,
+    **kwargs
+):
+    """Get the CURVAS dataset for pancreas, kidney and liver segmentation.
+
+    Args:
+        path: Filepath to a folder where the data is downloaded for further processing.
+        patch_shape: The patch shape to use for training.
+        rater: The choice of rater providing the annotations.
+        resize_inputs: Whether to resize inputs to the desired patch shape.
+        download: Whether to download the data if it is not present.
+        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.
+
+    Returns:
+        The segmentation dataset.
+    """
+    image_paths, gt_paths = get_curvas_paths(path, rater, download)
+
+    if resize_inputs:
+        resize_kwargs = {"patch_shape": patch_shape, "is_rgb": False}
+        kwargs, patch_shape = util.update_kwargs_for_resize_trafo(
+            kwargs=kwargs, patch_shape=patch_shape, resize_inputs=resize_inputs, resize_kwargs=resize_kwargs
+        )
+
+    dataset = torch_em.default_segmentation_dataset(
+        raw_paths=image_paths,
+        raw_key="data",
+        label_paths=gt_paths,
+        label_key="data",
+        patch_shape=patch_shape,
+        **kwargs
+    )
+
+    return dataset
+
+
+def get_curvas_loader(
+    path: Union[os.PathLike, str],
+    patch_shape: Tuple[int, ...],
+    batch_size: int,
+    rater: Literal["1"] = "1",
+    resize_inputs: bool = False,
+    download: bool = False,
+    **kwargs
+):
+    """Get the CURVAS dataloader for pancreas, kidney and liver segmentation.
+
+    Args:
+        path: Filepath to a folder where the data is downloaded for further processing.
+        patch_shape: The patch shape to use for training.
+        batch_size: The batch size for training.
+        rater: The choice of rater providing the annotations.
+        resize_inputs: Whether to resize inputs to the desired patch shape.
+        download: Whether to download the data if it is not present.
+        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch dataloader.
+
+    Returns:
+        The DataLoader.
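+
+    Example:
+        A minimal usage sketch; the path and patch shape below are illustrative placeholders:
+
+        >>> loader = get_curvas_loader("./data/curvas", patch_shape=(32, 256, 256), batch_size=1, download=True)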
+    """
+    ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
+    dataset = get_curvas_dataset(path, patch_shape, rater, resize_inputs, download, **ds_kwargs)
+    return torch_em.get_data_loader(dataset=dataset, batch_size=batch_size, **loader_kwargs)
diff --git a/torch_em/data/datasets/medical/dca1.py b/torch_em/data/datasets/medical/dca1.py
new file mode 100644
index 00000000..df06b88c
--- /dev/null
+++ b/torch_em/data/datasets/medical/dca1.py
@@ -0,0 +1,105 @@
+import os
+from glob import glob
+from natsort import natsorted
+from typing import Union, Tuple, Literal
+
+import torch_em
+
+from .. import util
+
+
+URL = "http://personal.cimat.mx:8181/~ivan.cruz/DB_Angiograms_files/DB_Angiograms_134.zip"
+CHECKSUM = "7161638a6e92c6a6e47a747db039292c8a1a6bad809aac0d1fd16a10a6f22a11"
+
+
+def get_dca1_data(path, download):
+    os.makedirs(path, exist_ok=True)
+
+    data_dir = os.path.join(path, "Database_134_Angiograms")
+    if os.path.exists(data_dir):
+        return data_dir
+
+    zip_path = os.path.join(path, "DB_Angiograms_134.zip")
+    util.download_source(path=zip_path, url=URL, download=download, checksum=CHECKSUM)
+    util.unzip(zip_path=zip_path, dst=path)
+    return data_dir
+
+
+def _get_dca1_paths(path, split, download):
+    data_dir = get_dca1_data(path=path, download=download)
+
+    image_paths, gt_paths = [], []
+    for image_path in natsorted(glob(os.path.join(data_dir, "*.pgm"))):
+        if image_path.endswith("_gt.pgm"):
+            gt_paths.append(image_path)
+        else:
+            image_paths.append(image_path)
+
+    image_paths, gt_paths = natsorted(image_paths), natsorted(gt_paths)
+
+    if split == "train":  # first 85 images
+        image_paths, gt_paths = image_paths[:-49], gt_paths[:-49]
+    elif split == "val":  # 15 images
+        image_paths, gt_paths = image_paths[-49:-34], gt_paths[-49:-34]
+    elif split == "test":  # last 34 images
+        image_paths, gt_paths = image_paths[-34:], gt_paths[-34:]
+    else:
+        raise ValueError(f"'{split}' is not a valid split.")
+
+    return image_paths, gt_paths
+
+
+def get_dca1_dataset(
+    path: Union[os.PathLike, str],
+    patch_shape: Tuple[int, int],
+    split: Literal["train", "val", "test"],
+    resize_inputs: bool = False,
+    download: bool = False,
+    **kwargs
+):
+    """Dataset for segmentation of coronary arteries in x-ray angiography.
+
+    This dataset is from Cervantes-Sanchez et al. - https://doi.org/10.3390/app9245507.
+
+    The database is located at http://personal.cimat.mx:8181/~ivan.cruz/DB_Angiograms.html.
+
+    Please cite it if you use this dataset in a publication.
+    """
+    image_paths, gt_paths = _get_dca1_paths(path=path, split=split, download=download)
+
+    if resize_inputs:
+        resize_kwargs = {"patch_shape": patch_shape, "is_rgb": False}
+        kwargs, patch_shape = util.update_kwargs_for_resize_trafo(
+            kwargs=kwargs, patch_shape=patch_shape, resize_inputs=resize_inputs, resize_kwargs=resize_kwargs
+        )
+
+    dataset = torch_em.default_segmentation_dataset(
+        raw_paths=image_paths,
+        raw_key=None,
+        label_paths=gt_paths,
+        label_key=None,
+        patch_shape=patch_shape,
+        is_seg_dataset=False,
+        **kwargs
+    )
+
+    return dataset
+
+
+def get_dca1_loader(
+    path: Union[os.PathLike, str],
+    patch_shape: Tuple[int, int],
+    batch_size: int,
+    split: Literal["train", "val", "test"],
+    resize_inputs: bool = False,
+    download: bool = False,
+    **kwargs
+):
+    """Dataloader for segmentation of coronary arteries in x-ray angiography. See `get_dca1_dataset` for details.
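+
+    Example:
+        A minimal usage sketch; the path and patch shape below are illustrative placeholders:
+
+        >>> loader = get_dca1_loader("./data/dca1", patch_shape=(256, 256), batch_size=4, split="train", download=True)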
+ """ + ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs) + dataset = get_dca1_dataset( + path=path, patch_shape=patch_shape, split=split, resize_inputs=resize_inputs, download=download, **ds_kwargs + ) + loader = torch_em.get_data_loader(dataset=dataset, batch_size=batch_size, **loader_kwargs) + return loader diff --git a/torch_em/data/datasets/medical/drive.py b/torch_em/data/datasets/medical/drive.py new file mode 100644 index 00000000..da06a2a1 --- /dev/null +++ b/torch_em/data/datasets/medical/drive.py @@ -0,0 +1,132 @@ +import os +from glob import glob +from pathlib import Path +from typing import Union, Tuple + +import imageio.v3 as imageio + +import torch_em + +from .. import util + + +URL = { + "train": "https://www.dropbox.com/sh/z4hbbzqai0ilqht/AADp_8oefNFs2bjC2kzl2_Fqa/training.zip?dl=1", + "test": "https://www.dropbox.com/sh/z4hbbzqai0ilqht/AABuUJQJ5yG5oCuziYzYu8jWa/test.zip?dl=1" +} + +CHECKSUM = { + "train": "7101e19598e2b7aacdbd5e6e7575057b9154a4aaec043e0f4e28902bf4e2e209", + "test": "d76c95c98a0353487ffb63b3bb2663c00ed1fde7d8fdfd8c3282c6e310a02731" +} + + +def get_drive_data(path, download): + os.makedirs(path, exist_ok=True) + + data_dir = os.path.join(path, "training") + if os.path.exists(data_dir): + return data_dir + + zip_path = os.path.join(path, "training.zip") + util.download_source_gdrive( + path=zip_path, url=URL["train"], download=download, checksum=CHECKSUM["train"], download_type="zip", + ) + util.unzip(zip_path=zip_path, dst=path) + + return data_dir + + +def _get_drive_ground_truth(data_dir): + gt_paths = sorted(glob(os.path.join(data_dir, "1st_manual", "*.gif"))) + + neu_gt_dir = os.path.join(data_dir, "gt") + if os.path.exists(neu_gt_dir): + return sorted(glob(os.path.join(neu_gt_dir, "*.tif"))) + else: + os.makedirs(neu_gt_dir, exist_ok=True) + + neu_gt_paths = [] + for gt_path in gt_paths: + gt = imageio.imread(gt_path).squeeze() + neu_gt_path = os.path.join( + neu_gt_dir, Path(os.path.split(gt_path)[-1]).with_suffix(".tif") + ) + imageio.imwrite(neu_gt_path, (gt > 0).astype("uint8")) + neu_gt_paths.append(neu_gt_path) + + return neu_gt_paths + + +def _get_drive_paths(path, split, download): + data_dir = get_drive_data(path=path, download=download) + + image_paths = sorted(glob(os.path.join(data_dir, "images", "*.tif"))) + gt_paths = _get_drive_ground_truth(data_dir) + + if split == "train": + image_paths, gt_paths = image_paths[:10], gt_paths[:10] + elif split == "val": + image_paths, gt_paths = image_paths[10:14], gt_paths[10:14] + elif split == "test": + image_paths, gt_paths = image_paths[14:], gt_paths[14:] + else: + raise ValueError(f"'{split}' is not a valid split.") + + return image_paths, gt_paths + + +def get_drive_dataset( + path: Union[os.PathLike, str], + patch_shape: Tuple[int, int], + split: str, + resize_inputs: bool = False, + download: bool = False, + **kwargs +): + """Dataset for segmentation of retinal blood vessels in fundus images. + + This dataset is from the "DRIVE" challenge: + - https://drive.grand-challenge.org/ + - https://doi.org/10.1109/TMI.2004.825627 + + Please cite it if you use this dataset for a publication. 
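+
+    Example:
+        A minimal usage sketch; the folder and patch shape are illustrative assumptions:
+
+            ds = get_drive_dataset("./drive", patch_shape=(512, 512), split="train", download=True)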
+    """
+    image_paths, gt_paths = _get_drive_paths(path=path, split=split, download=download)
+
+    if resize_inputs:
+        resize_kwargs = {"patch_shape": patch_shape, "is_rgb": True}
+        kwargs, patch_shape = util.update_kwargs_for_resize_trafo(
+            kwargs=kwargs, patch_shape=patch_shape, resize_inputs=resize_inputs, resize_kwargs=resize_kwargs
+        )
+
+    dataset = torch_em.default_segmentation_dataset(
+        raw_paths=image_paths,
+        raw_key=None,
+        label_paths=gt_paths,
+        label_key=None,
+        patch_shape=patch_shape,
+        is_seg_dataset=False,
+        **kwargs
+    )
+
+    return dataset
+
+
+def get_drive_loader(
+    path: Union[os.PathLike, str],
+    patch_shape: Tuple[int, int],
+    split: str,
+    batch_size: int,
+    resize_inputs: bool = False,
+    download: bool = False,
+    **kwargs
+):
+    """Dataloader for segmentation of retinal blood vessels in fundus images. See `get_drive_dataset` for details.
+    """
+    ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
+    dataset = get_drive_dataset(
+        path=path, patch_shape=patch_shape, split=split, resize_inputs=resize_inputs, download=download, **ds_kwargs
+    )
+    loader = torch_em.get_data_loader(dataset=dataset, batch_size=batch_size, **loader_kwargs)
+    return loader
diff --git a/torch_em/data/datasets/medical/duke_liver.py b/torch_em/data/datasets/medical/duke_liver.py
new file mode 100644
index 00000000..3cc897c9
--- /dev/null
+++ b/torch_em/data/datasets/medical/duke_liver.py
@@ -0,0 +1,162 @@
+import os
+from glob import glob
+from tqdm import tqdm
+from natsort import natsorted
+
+import numpy as np
+
+import torch_em
+
+from .. import util
+
+
+def get_duke_liver_data(path, download):
+    """The dataset is located at https://doi.org/10.5281/zenodo.7774566.
+
+    Follow the instructions below to get access to the dataset.
+    - Visit the zenodo site attached above.
+    - Send a request message along with some details to get access to the dataset.
+    - Once the authors accept the request, you can access the dataset.
+    - Next, download the `Segmentation.zip` file and provide the path where the zip file is stored.
+    """
+    if download:
+        raise NotImplementedError(
+            "Automatic download for Duke Liver dataset is not possible. See `get_duke_liver_data` for details."
+        )
+
+    data_dir = os.path.join(path, "data", "Segmentation")
+    if os.path.exists(data_dir):
+        return data_dir
+
+    zip_path = os.path.join(path, "Segmentation.zip")
+    util.unzip(zip_path=zip_path, dst=os.path.join(path, "data"), remove=False)
+    return data_dir
+
+
+def _preprocess_data(path, data_dir):
+    preprocess_dir = os.path.join(path, "data", "preprocessed")
+
+    if os.path.exists(preprocess_dir):
+        _image_paths = natsorted(glob(os.path.join(preprocess_dir, "images", "*.nii.gz")))
+        _gt_paths = natsorted(glob(os.path.join(preprocess_dir, "masks", "*.nii.gz")))
+        return _image_paths, _gt_paths
+
+    os.makedirs(os.path.join(preprocess_dir, "images"), exist_ok=True)
+    os.makedirs(os.path.join(preprocess_dir, "masks"), exist_ok=True)
+
+    import pydicom as dicom
+    import nibabel as nib
+
+    image_paths, gt_paths = [], []
+    for patient_dir in tqdm(glob(os.path.join(data_dir, "00*"))):
+        patient_id = os.path.split(patient_dir)[-1]
+
+        for sub_id_dir in glob(os.path.join(patient_dir, "*")):
+            sub_id = os.path.split(sub_id_dir)[-1]
+
+            image_path = os.path.join(preprocess_dir, "images", f"{patient_id}_{sub_id}.nii.gz")
+            gt_path = os.path.join(preprocess_dir, "masks", f"{patient_id}_{sub_id}.nii.gz")
+
+            image_paths.append(image_path)
+            gt_paths.append(gt_path)
+
+            if os.path.exists(image_path) and os.path.exists(gt_path):
+                continue
+
+            image_slice_paths = natsorted(glob(os.path.join(sub_id_dir, "images", "*.dicom")))
+            gt_slice_paths = natsorted(glob(os.path.join(sub_id_dir, "masks", "*.dicom")))
+
+            images, gts = [], []
+            for image_slice_path, gt_slice_path in zip(image_slice_paths, gt_slice_paths):
+                image_slice = dicom.dcmread(image_slice_path).pixel_array
+                gt_slice = dicom.dcmread(gt_slice_path).pixel_array
+
+                images.append(image_slice)
+                gts.append(gt_slice)
+
+            image = np.stack(images).transpose(1, 2, 0)
+            gt = np.stack(gts).transpose(1, 2, 0)
+
+            assert image.shape == gt.shape
+
+            image = nib.Nifti2Image(image, np.eye(4))
+            gt = nib.Nifti2Image(gt, np.eye(4))
+
+            nib.save(image, image_path)
+            nib.save(gt, gt_path)
+
+    return natsorted(image_paths), natsorted(gt_paths)
+
+
+def _get_duke_liver_paths(path, split, download):
+    data_dir = get_duke_liver_data(path=path, download=download)
+
+    image_paths, gt_paths = _preprocess_data(path=path, data_dir=data_dir)
+
+    if split == "train":
+        image_paths, gt_paths = image_paths[:250], gt_paths[:250]
+    elif split == "val":
+        image_paths, gt_paths = image_paths[250:260], gt_paths[250:260]
+    elif split == "test":
+        image_paths, gt_paths = image_paths[260:], gt_paths[260:]
+    else:
+        raise ValueError(f"'{split}' is not a valid split.")
+
+    return image_paths, gt_paths
+
+
+def get_duke_liver_dataset(
+    path,
+    patch_shape,
+    split,
+    resize_inputs=False,
+    download=False,
+    **kwargs
+):
+    """Dataset for segmentation of liver in MRI.
+
+    This dataset is from Macdonald et al. - https://doi.org/10.1148/ryai.220275.
+
+    The dataset is located at https://doi.org/10.5281/zenodo.7774566 (see `get_duke_liver_data` for details).
+
+    Please cite it if you use it in a publication.
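+
+    Example:
+        A minimal usage sketch; since automatic download is not possible, this assumes
+        `Segmentation.zip` has already been placed in the given (illustrative) folder:
+
+            ds = get_duke_liver_dataset("./duke_liver", patch_shape=(32, 256, 256), split="train")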
+    """
+
+    image_paths, gt_paths = _get_duke_liver_paths(path=path, split=split, download=download)
+
+    if resize_inputs:
+        resize_kwargs = {"patch_shape": patch_shape, "is_rgb": False}
+        kwargs, patch_shape = util.update_kwargs_for_resize_trafo(
+            kwargs=kwargs, patch_shape=patch_shape, resize_inputs=resize_inputs, resize_kwargs=resize_kwargs
+        )
+
+    dataset = torch_em.default_segmentation_dataset(
+        raw_paths=image_paths,
+        raw_key="data",
+        label_paths=gt_paths,
+        label_key="data",
+        is_seg_dataset=True,
+        patch_shape=patch_shape,
+        **kwargs
+    )
+
+    return dataset
+
+
+def get_duke_liver_loader(
+    path,
+    patch_shape,
+    batch_size,
+    split,
+    resize_inputs=False,
+    download=False,
+    **kwargs
+):
+    """Dataloader for segmentation of liver in MRI. See `get_duke_liver_dataset` for details.
+    """
+    ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
+    dataset = get_duke_liver_dataset(
+        path=path, patch_shape=patch_shape, split=split, resize_inputs=resize_inputs, download=download, **ds_kwargs
+    )
+    loader = torch_em.get_data_loader(dataset=dataset, batch_size=batch_size, **loader_kwargs)
+    return loader
diff --git a/torch_em/data/datasets/medical/feta24.py b/torch_em/data/datasets/medical/feta24.py
new file mode 100644
index 00000000..be76600a
--- /dev/null
+++ b/torch_em/data/datasets/medical/feta24.py
@@ -0,0 +1,109 @@
+import os
+from glob import glob
+from natsort import natsorted
+from typing import Union, Tuple
+
+import torch_em
+
+from .. import util
+
+
+def get_feta24_data(path, download):
+    """This function describes the download functionality and ensures your data is downloaded in the expected format.
+
+    The dataset is from the FeTA Challenge 2024 - https://fetachallenge.github.io/ (Task 1: Segmentation).
+    A detailed description of the dataset is provided here: https://fetachallenge.github.io/pages/Data_description.
+    To download the dataset, please follow the steps mentioned below:
+    - Go to the section `1. Request access and download the FeTA 2024 data from the University Children's Hospital
+    Zurich` at `https://fetachallenge.github.io/pages/Data_download`, which explains the steps to become a registered
+    user on the Synapse platform and expects the user to agree with the mentioned conditions.
+    - During registration, users are expected to provide some information
+    (see https://fetachallenge.github.io/pages/Data_download for details).
+    - Next, you can proceed with requesting access (by following the provided instructions) at
+    https://www.synapse.org/#!Synapse:syn25649159/wiki/610007.
+
+    Once you have access to the dataset, you can use the synapse client or the platform download option to get
+    the zipped files. It contains 80 scans paired with their segmentations (more details on the challenge website).
+
+    Finally, you should provide the path to the parent directory where the zipfile is stored.
+    """
+    if download:
+        print("Download is not supported due to the challenge's setup. See 'get_feta24_data' for details.")
+
+    data_dir = os.path.join(path, "feta_2.3")
+    if os.path.exists(data_dir):
+        return data_dir
+
+    zip_path = os.path.join(path, "feta_2.3.zip")
+    if not os.path.exists(zip_path):
+        raise FileNotFoundError(f"The downloaded zip file was not found.
Please download it and place it at '{path}'.")
+
+    util.unzip(zip_path=zip_path, dst=path)
+
+    return data_dir
+
+
+def _get_feta24_paths(path, download):
+    data_dir = get_feta24_data(path=path, download=download)
+
+    base_dir = os.path.join(data_dir, "sub-*", "anat")
+    image_paths = natsorted(glob(os.path.join(base_dir, "sub-*_rec-*_T2w.nii.gz")))
+    gt_paths = natsorted(glob(os.path.join(base_dir, "sub-*_rec-*_dseg.nii.gz")))
+
+    return image_paths, gt_paths
+
+
+def get_feta24_dataset(
+    path: Union[os.PathLike, str],
+    patch_shape: Tuple[int, ...],
+    resize_inputs: bool = False,
+    download: bool = False,
+    **kwargs
+):
+    """Dataset for segmentation of fetal brain tissues in MRI.
+
+    The dataset cannot be downloaded automatically. See `get_feta24_data` for details.
+
+    This dataset is from the FeTA 2024 Challenge:
+    - https://doi.org/10.5281/zenodo.11192452
+    - Payette et al. - https://doi.org/10.1038/s41597-021-00946-3
+
+    Please cite it if you use this dataset in your publication.
+    """
+    image_paths, gt_paths = _get_feta24_paths(path=path, download=download)
+
+    if resize_inputs:
+        resize_kwargs = {"patch_shape": patch_shape, "is_rgb": False}
+        kwargs, patch_shape = util.update_kwargs_for_resize_trafo(
+            kwargs=kwargs, patch_shape=patch_shape, resize_inputs=resize_inputs, resize_kwargs=resize_kwargs
+        )
+
+    dataset = torch_em.default_segmentation_dataset(
+        raw_paths=image_paths,
+        raw_key="data",
+        label_paths=gt_paths,
+        label_key="data",
+        patch_shape=patch_shape,
+        **kwargs
+    )
+
+    return dataset
+
+
+def get_feta24_loader(
+    path: Union[os.PathLike, str],
+    patch_shape: Tuple[int, ...],
+    batch_size: int,
+    resize_inputs: bool = False,
+    download: bool = False,
+    **kwargs
+):
+    """Dataloader for segmentation of fetal brain tissues in MRI.
+    See `get_feta24_dataset` for details.
+    """
+    ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
+    dataset = get_feta24_dataset(
+        path=path, patch_shape=patch_shape, resize_inputs=resize_inputs, download=download, **ds_kwargs
+    )
+    loader = torch_em.get_data_loader(dataset=dataset, batch_size=batch_size, **loader_kwargs)
+    return loader
diff --git a/torch_em/data/datasets/medical/han_seg.py b/torch_em/data/datasets/medical/han_seg.py
new file mode 100644
index 00000000..cc322ab2
--- /dev/null
+++ b/torch_em/data/datasets/medical/han_seg.py
@@ -0,0 +1,128 @@
+import os
+from glob import glob
+from tqdm import tqdm
+from pathlib import Path
+from natsort import natsorted
+from typing import Union, Tuple
+
+import torch_em
+
+from ..
import util
+
+
+URL = "https://zenodo.org/records/7442914/files/HaN-Seg.zip"
+CHECKSUM = "20226dd717f334dc1b1afe961b3375f946fa56b64a80bf5349128f90c0bbfa5f"
+
+
+def get_han_seg_data(path, download):
+    os.makedirs(path, exist_ok=True)
+
+    data_dir = os.path.join(path, "HaN-Seg")
+    if os.path.exists(data_dir):
+        return data_dir
+
+    zip_path = os.path.join(path, "HaN-Seg.zip")
+    util.download_source(
+        path=zip_path, url=URL, download=download, checksum=CHECKSUM
+    )
+    util.unzip(zip_path=zip_path, dst=path, remove=False)
+
+    return data_dir
+
+
+def _get_han_seg_paths(path, download):
+    import nrrd
+    import numpy as np
+    import nibabel as nib
+
+    data_dir = get_han_seg_data(path=path, download=download)
+
+    image_dir = os.path.join(data_dir, "set_1", "preprocessed", "images")
+    gt_dir = os.path.join(data_dir, "set_1", "preprocessed", "ground_truth")
+    os.makedirs(image_dir, exist_ok=True)
+    os.makedirs(gt_dir, exist_ok=True)
+
+    image_paths, gt_paths = [], []
+    all_case_dirs = natsorted(glob(os.path.join(data_dir, "set_1", "case_*")))
+    for case_dir in tqdm(all_case_dirs):
+        image_path = os.path.join(image_dir, f"{os.path.split(case_dir)[-1]}_ct.nii.gz")
+        gt_path = os.path.join(gt_dir, f"{os.path.split(case_dir)[-1]}.nii.gz")
+        image_paths.append(image_path)
+        gt_paths.append(gt_path)
+        if os.path.exists(image_path) and os.path.exists(gt_path):
+            continue
+
+        all_nrrd_paths = natsorted(glob(os.path.join(case_dir, "*.nrrd")))
+        all_volumes, all_volume_ids = [], []
+        for nrrd_path in all_nrrd_paths:
+            image_id = Path(nrrd_path).stem
+
+            # we skip the MRI volumes
+            if image_id.endswith("_MR_T1"):
+                continue
+
+            data, header = nrrd.read(nrrd_path)
+            all_volumes.append(data)
+            all_volume_ids.append(image_id)
+
+        raw = all_volumes[0]
+        raw = nib.Nifti2Image(raw, np.eye(4))
+        nib.save(raw, image_path)
+
+        gt = np.zeros(raw.shape)
+        for idx, per_organ in enumerate(all_volumes[1:], 1):
+            gt[per_organ > 0] = idx
+        gt = nib.Nifti2Image(gt, np.eye(4))
+        nib.save(gt, gt_path)
+
+    return image_paths, gt_paths
+
+
+def get_han_seg_dataset(
+    path: Union[os.PathLike, str],
+    patch_shape: Tuple[int, ...],
+    resize_inputs: bool = False,
+    download: bool = False,
+    **kwargs
+):
+    """Dataset for head and neck organ-at-risk segmentation in CT scans.
+
+    This dataset is from Podobnik et al. - https://doi.org/10.1002/mp.16197
+    Please cite it if you use it in a publication.
+    """
+    image_paths, gt_paths = _get_han_seg_paths(path=path, download=download)
+
+    if resize_inputs:
+        resize_kwargs = {"patch_shape": patch_shape, "is_rgb": False}
+        kwargs, patch_shape = util.update_kwargs_for_resize_trafo(
+            kwargs=kwargs, patch_shape=patch_shape, resize_inputs=resize_inputs, resize_kwargs=resize_kwargs,
+        )
+
+    dataset = torch_em.default_segmentation_dataset(
+        raw_paths=image_paths,
+        raw_key="data",
+        label_paths=gt_paths,
+        label_key="data",
+        patch_shape=patch_shape,
+        **kwargs
+    )
+
+    return dataset
+
+
+def get_han_seg_loader(
+    path: Union[os.PathLike, str],
+    patch_shape: Tuple[int, ...],
+    batch_size: int,
+    resize_inputs: bool = False,
+    download: bool = False,
+    **kwargs
+):
+    """Dataloader for head and neck organ-at-risk segmentation in CT scans. See `get_han_seg_dataset` for details.
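+
+    Example:
+        A minimal usage sketch; the folder and the 3d patch shape are illustrative assumptions:
+
+            loader = get_han_seg_loader("./han_seg", patch_shape=(32, 256, 256), batch_size=1, download=True)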
+ """ + ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs) + dataset = get_han_seg_dataset( + path=path, patch_shape=patch_shape, resize_inputs=resize_inputs, download=download, **ds_kwargs + ) + loader = torch_em.get_data_loader(dataset=dataset, batch_size=batch_size, **loader_kwargs) + return loader diff --git a/torch_em/data/datasets/medical/hil_toothseg.py b/torch_em/data/datasets/medical/hil_toothseg.py new file mode 100644 index 00000000..e6eb9d4c --- /dev/null +++ b/torch_em/data/datasets/medical/hil_toothseg.py @@ -0,0 +1,176 @@ +"""The HIL ToothSeg dataset contains annotations for teeth segmentation +in panoramic dental radiographs. + +This dataset is from the publication https://www.mdpi.com/1424-8220/21/9/3110. +Please cite it if you use this dataset for your research. +""" + +import os +from glob import glob +from tqdm import tqdm +from pathlib import Path +from natsort import natsorted +from typing import Union, Literal, Tuple, List + +import numpy as np +import imageio.v3 as imageio + +from torch.utils.data import Dataset, DataLoader + +import torch_em + +from .. import util + + +URL = "https://hitl-public-datasets.s3.eu-central-1.amazonaws.com/Teeth+Segmentation.zip" +CHECKSUM = "3b628165a218a5e8d446d1313e6ecbe7cfc599a3d6418cd60b4fb78745becc2e" + + +def get_hil_toothseg_data(path: Union[os.PathLike, str], download: bool = False): + """Download the HIL ToothSeg dataset. + + Args: + path: Filepath to a folder where the data is downloaded for further processing. + download: Whether to download the data if it is not present. + """ + data_dir = os.path.join(path, r"Teeth Segmentation PNG") + if os.path.exists(data_dir): + return data_dir + + os.makedirs(path, exist_ok=True) + + zip_path = os.path.join(path, "Teeth_Segmentation.zip") + util.download_source(path=zip_path, url=URL, download=download, checksum=CHECKSUM) + util.unzip(zip_path=zip_path, dst=path) + + return data_dir + + +def get_hil_toothseg_paths( + path: Union[os.PathLike, str], split: Literal['train', 'val', 'test'], download: bool = False +) -> Tuple[List[str], List[str]]: + """Get paths to the HIL ToothSeg data. + + Args: + path: Filepath to a folder where the data is downloaded for further processing. + split: The data split to use. Either 'train', 'val' or 'test'. + download: Whether to download the data if it is not present. + + Returns: + List of filepaths for the image data. + List of filepaths for the label data. 
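+
+    Example:
+        A minimal sketch of fetching the path pairs; the target folder is an
+        illustrative choice:
+
+            image_paths, gt_paths = get_hil_toothseg_paths("./hil_toothseg", split="train", download=True)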
+    """
+    import cv2 as cv
+
+    data_dir = get_hil_toothseg_data(path=path, download=download)
+
+    image_paths = natsorted(glob(os.path.join(data_dir, "d2", "img", "*")))
+    gt_paths = natsorted(glob(os.path.join(data_dir, "d2", "masks_machine", "*")))
+
+    neu_gt_dir = os.path.join(data_dir, "preprocessed", "gt")
+    os.makedirs(neu_gt_dir, exist_ok=True)
+
+    neu_gt_paths = []
+    for gt_path in tqdm(gt_paths):
+        neu_gt_path = os.path.join(neu_gt_dir, f"{Path(gt_path).stem}.tif")
+        neu_gt_paths.append(neu_gt_path)
+        if os.path.exists(neu_gt_path):
+            continue
+
+        rgb_gt = cv.imread(gt_path)
+        rgb_gt = cv.cvtColor(rgb_gt, cv.COLOR_BGR2RGB)
+        incolors = np.unique(rgb_gt.reshape(-1, rgb_gt.shape[2]), axis=0)
+
+        # the first id is always background, let's remove it
+        if np.array_equal(incolors[0], np.array([0, 0, 0])):
+            incolors = incolors[1:]
+
+        instances = np.zeros(rgb_gt.shape[:2])
+
+        color_to_id = {tuple(cvalue): i for i, cvalue in enumerate(incolors, start=1)}
+        for cvalue, idx in color_to_id.items():
+            binary_map = (rgb_gt == cvalue).all(axis=2)
+            instances[binary_map] = idx
+
+        imageio.imwrite(neu_gt_path, instances)
+
+    if split == "train":
+        image_paths, neu_gt_paths = image_paths[:450], neu_gt_paths[:450]
+    elif split == "val":
+        image_paths, neu_gt_paths = image_paths[425:475], neu_gt_paths[425:475]
+    elif split == "test":
+        image_paths, neu_gt_paths = image_paths[475:], neu_gt_paths[475:]
+    else:
+        raise ValueError(f"{split} is not a valid split.")
+
+    return image_paths, neu_gt_paths
+
+
+def get_hil_toothseg_dataset(
+    path: Union[os.PathLike, str],
+    patch_shape: Tuple[int, int],
+    split: Literal["train", "val", "test"],
+    resize_inputs: bool = False,
+    download: bool = False,
+    **kwargs
+) -> Dataset:
+    """Get the HIL ToothSeg dataset for teeth segmentation.
+
+    Args:
+        path: Filepath to a folder where the data is downloaded for further processing.
+        patch_shape: The patch shape to use for training.
+        split: The data split to use. Either 'train', 'val' or 'test'.
+        resize_inputs: Whether to resize the inputs to the patch shape.
+        download: Whether to download the data if it is not present.
+        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.
+
+    Returns:
+        The segmentation dataset.
+    """
+    image_paths, gt_paths = get_hil_toothseg_paths(path=path, split=split, download=download)
+
+    if resize_inputs:
+        resize_kwargs = {"patch_shape": patch_shape, "is_rgb": True}
+        kwargs, patch_shape = util.update_kwargs_for_resize_trafo(
+            kwargs=kwargs, patch_shape=patch_shape, resize_inputs=resize_inputs, resize_kwargs=resize_kwargs
+        )
+
+    return torch_em.default_segmentation_dataset(
+        raw_paths=image_paths,
+        raw_key=None,
+        label_paths=gt_paths,
+        label_key=None,
+        is_seg_dataset=False,
+        patch_shape=patch_shape,
+        **kwargs
+    )
+
+
+def get_hil_toothseg_loader(
+    path: Union[os.PathLike, str],
+    batch_size: int,
+    patch_shape: Tuple[int, int],
+    split: Literal["train", "val", "test"],
+    resize_inputs: bool = False,
+    download: bool = False,
+    **kwargs
+) -> DataLoader:
+    """Get the HIL ToothSeg dataloader for teeth segmentation.
+
+    Args:
+        path: Filepath to a folder where the data is downloaded for further processing.
+        batch_size: The batch size for training.
+        patch_shape: The patch shape to use for training.
+        split: The data split to use. Either 'train', 'val' or 'test'.
+        resize_inputs: Whether to resize the inputs to the patch shape.
+        download: Whether to download the data if it is not present.
+ kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader. + + Returns: + The DataLoader. + """ + ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs) + dataset = get_hil_toothseg_dataset( + path=path, split=split, patch_shape=patch_shape, resize_inputs=resize_inputs, download=download, **ds_kwargs + ) + return torch_em.get_data_loader(dataset=dataset, batch_size=batch_size, **loader_kwargs) diff --git a/torch_em/data/datasets/medical/idrid.py b/torch_em/data/datasets/medical/idrid.py new file mode 100644 index 00000000..0f65e91d --- /dev/null +++ b/torch_em/data/datasets/medical/idrid.py @@ -0,0 +1,124 @@ +import os +from glob import glob +from pathlib import Path +from typing import Union, Tuple + +import torch_em + +from .. import util + + +TASKS = { + "microaneurysms": r"1. Microaneurysms", + "haemorrhages": r"2. Haemorrhages", + "hard_exudates": r"3. Hard Exudates", + "soft_exudates": r"4. Soft Exudates", + "optic_disc": r"5. Optic Disc" +} + + +def get_idrid_data(path, download): + os.makedirs(path, exist_ok=True) + + data_dir = os.path.join(path, "data", "A.%20Segmentation") + if os.path.exists(data_dir): + return data_dir + + util.download_source_kaggle( + path=path, dataset_name="aaryapatel98/indian-diabetic-retinopathy-image-dataset", download=download, + ) + zip_path = os.path.join(path, "indian-diabetic-retinopathy-image-dataset.zip") + util.unzip(zip_path=zip_path, dst=os.path.join(path, "data")) + return data_dir + + +def _get_idrid_paths(path, split, task, download): + data_dir = get_idrid_data(path=path, download=download) + + split = r"a. Training Set" if split == "train" else r"b. Testing Set" + + gt_paths = sorted( + glob( + os.path.join(data_dir, r"A. Segmentation", r"2. All Segmentation Groundtruths", split, TASKS[task], "*.tif") + ) + ) + + image_dir = os.path.join(data_dir, r"A. Segmentation", r"1. Original Images", split) + image_paths = [] + for gt_path in gt_paths: + gt_id = Path(gt_path).stem[:-3] + image_path = os.path.join(image_dir, f"{gt_id}.jpg") + image_paths.append(image_path) + + return image_paths, gt_paths + + +def get_idrid_dataset( + path: Union[os.PathLike, str], + patch_shape: Tuple[int, int], + split: str, + task: str = "optic_disc", + resize_inputs: bool = False, + download: bool = False, + **kwargs +): + """Dataset for segmentation of retinal lesions and optic disc in fundus images. + + The database is located at https://ieee-dataport.org/open-access/indian-diabetic-retinopathy-image-dataset-idrid + The dataloader makes use of an open-source version of the original dataset hosted on Kaggle. + + The dataset is from the IDRiD challenge: + - https://idrid.grand-challenge.org/ + - Porwal et al. - https://doi.org/10.1016/j.media.2019.101561 + + Please cite it if you use this dataset for a publication. 
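+
+    Example:
+        A minimal usage sketch; the folder and patch shape are illustrative assumptions:
+
+            ds = get_idrid_dataset(
+                "./idrid", patch_shape=(512, 512), split="train", task="optic_disc", download=True
+            )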
+ """ + assert split in ["train", "test"] + assert task in list(TASKS.keys()) + + image_paths, gt_paths = _get_idrid_paths(path=path, split=split, task=task, download=download) + + if resize_inputs: + resize_kwargs = {"patch_shape": patch_shape, "is_rgb": True} + kwargs, patch_shape = util.update_kwargs_for_resize_trafo( + kwargs=kwargs, patch_shape=patch_shape, resize_inputs=resize_inputs, resize_kwargs=resize_kwargs + ) + + dataset = torch_em.default_segmentation_dataset( + raw_paths=image_paths, + raw_key=None, + label_paths=gt_paths, + label_key=None, + patch_shape=patch_shape, + is_seg_dataset=False, + **kwargs + ) + + return dataset + + +def get_idrid_loader( + path: Union[os.PathLike, str], + patch_shape: Tuple[int, int], + batch_size: int, + split: str, + task: str = "optic_disc", + resize_inputs: bool = False, + download: bool = False, + **kwargs +): + """ + Dataloader for segmentation of retinal lesions and optic disc in fundus images. See `get_idrid_dataset` for details. + """ + ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs) + dataset = get_idrid_dataset( + path=path, + patch_shape=patch_shape, + split=split, + task=task, + resize_inputs=resize_inputs, + download=download, + **ds_kwargs + ) + loader = torch_em.get_data_loader(dataset=dataset, batch_size=batch_size, **loader_kwargs) + return loader diff --git a/torch_em/data/datasets/medical/isic.py b/torch_em/data/datasets/medical/isic.py new file mode 100644 index 00000000..efabba2c --- /dev/null +++ b/torch_em/data/datasets/medical/isic.py @@ -0,0 +1,147 @@ +import os +from glob import glob +from pathlib import Path +from typing import Union, Tuple + +import torch_em + +from .. import util +from ..light_microscopy.neurips_cell_seg import to_rgb + + +URL = { + "images": { + "train": "https://isic-challenge-data.s3.amazonaws.com/2018/ISIC2018_Task1-2_Training_Input.zip", + "val": "https://isic-challenge-data.s3.amazonaws.com/2018/ISIC2018_Task1-2_Validation_Input.zip", + "test": "https://isic-challenge-data.s3.amazonaws.com/2018/ISIC2018_Task1-2_Test_Input.zip", + }, + "gt": { + "train": "https://isic-challenge-data.s3.amazonaws.com/2018/ISIC2018_Task1_Training_GroundTruth.zip", + "val": "https://isic-challenge-data.s3.amazonaws.com/2018/ISIC2018_Task1_Validation_GroundTruth.zip", + "test": "https://isic-challenge-data.s3.amazonaws.com/2018/ISIC2018_Task1_Test_GroundTruth.zip", + }, +} + +CHECKSUM = { + "images": { + "train": "80f98572347a2d7a376227fa9eb2e4f7459d317cb619865b8b9910c81446675f", + "val": "0ea920fcfe512d12a6e620b50b50233c059f67b10146e1479c82be58ff15a797", + "test": "e59ae1f69f4ed16f09db2cb1d76c2a828487b63d28f6ab85997f5616869b127d", + }, + "gt": { + "train": "99f8b2bb3c4d6af483362010715f7e7d5d122d9f6c02cac0e0d15bef77c7604c", + "val": "f6911e9c0a64e6d687dd3ca466ca927dd5e82145cb2163b7a1e5b37d7a716285", + "test": "2e8f6edce454a5bdee52485e39f92bd6eddf357e81f39018d05512175238ef82", + } +} + + +def get_isic_data(path, split, download): + os.makedirs(path, exist_ok=True) + + im_url = URL["images"][split] + im_checksum = CHECKSUM["images"][split] + + gt_url = URL["gt"][split] + gt_checksum = CHECKSUM["gt"][split] + + im_zipfile = os.path.split(im_url)[-1] + gt_zipfile = os.path.split(gt_url)[-1] + + imdir = os.path.join(path, Path(im_zipfile).stem) + gtdir = os.path.join(path, Path(gt_zipfile).stem) + + im_zip_path = os.path.join(path, im_zipfile) + gt_zip_path = os.path.join(path, gt_zipfile) + + if os.path.exists(imdir) and os.path.exists(gtdir): + return imdir, gtdir + + # 
download the images + util.download_source(path=im_zip_path, url=im_url, download=download, checksum=im_checksum) + util.unzip(zip_path=im_zip_path, dst=path, remove=False) + # download the ground-truth + util.download_source(path=gt_zip_path, url=gt_url, download=download, checksum=gt_checksum) + util.unzip(zip_path=gt_zip_path, dst=path, remove=False) + + return imdir, gtdir + + +def _get_isic_paths(path, split, download): + image_dir, gt_dir = get_isic_data(path=path, split=split, download=download) + + image_paths = sorted(glob(os.path.join(image_dir, "*.jpg"))) + gt_paths = sorted(glob(os.path.join(gt_dir, "*.png"))) + + return image_paths, gt_paths + + +def get_isic_dataset( + path: Union[os.PathLike, str], + patch_shape: Tuple[int, int], + split: str, + download: bool = False, + make_rgb: bool = True, + resize_inputs: bool = False, + **kwargs +): + """Dataset for the segmentation of skin lesion in dermoscopy images. + + This dataset is related to the following publication(s): + - https://doi.org/10.1038/sdata.2018.161 + - https://doi.org/10.48550/arXiv.1710.05006 + - https://doi.org/10.48550/arXiv.1902.03368 + + The database is located at https://challenge.isic-archive.com/data/#2018 + + Please cite it if you use this dataset for a publication. + """ + assert split in list(URL["images"].keys()), f"{split} is not a valid split." + + image_paths, gt_paths = _get_isic_paths(path=path, split=split, download=download) + + if make_rgb: + kwargs["raw_transform"] = to_rgb + + if resize_inputs: + resize_kwargs = {"patch_shape": patch_shape, "is_rgb": True} + kwargs, patch_shape = util.update_kwargs_for_resize_trafo( + kwargs=kwargs, patch_shape=patch_shape, resize_inputs=resize_inputs, resize_kwargs=resize_kwargs + ) + + dataset = torch_em.default_segmentation_dataset( + raw_paths=image_paths, + raw_key=None, + label_paths=gt_paths, + label_key=None, + patch_shape=patch_shape, + is_seg_dataset=False, + **kwargs + ) + return dataset + + +def get_isic_loader( + path: Union[os.PathLike, str], + patch_shape: Tuple[int, int], + batch_size: int, + split: str, + download: bool = False, + make_rgb: bool = True, + resize_inputs: bool = False, + **kwargs +): + """Dataloader for the segmentation of skin lesion in dermoscopy images. See `get_isic_dataset` for details. + """ + ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs) + dataset = get_isic_dataset( + path=path, + patch_shape=patch_shape, + split=split, + download=download, + make_rgb=make_rgb, + resize_inputs=resize_inputs, + **ds_kwargs + ) + loader = torch_em.get_data_loader(dataset=dataset, batch_size=batch_size, **loader_kwargs) + return loader diff --git a/torch_em/data/datasets/medical/jnuifm.py b/torch_em/data/datasets/medical/jnuifm.py new file mode 100644 index 00000000..3de8933d --- /dev/null +++ b/torch_em/data/datasets/medical/jnuifm.py @@ -0,0 +1,96 @@ +import os +from glob import glob +from natsort import natsorted +from typing import Union, Tuple +from urllib.parse import urljoin + +import torch_em + +from .. 
import util
+
+
+BASE_URL = "https://zenodo.org/records/7851339/files/"
+URL = urljoin(BASE_URL, "Pubic%20Symphysis-Fetal%20Head%20Segmentation%20and%20Angle%20of%20Progression.zip")
+CHECKSUM = "2b14d1c78e11cfb799d74951b0b985b90777c195f7a456ccd00528bf02802e21"
+
+
+def get_jnuifm_data(path, download):
+    os.makedirs(path, exist_ok=True)
+
+    data_dir = os.path.join(path, r"Pubic Symphysis-Fetal Head Segmentation and Angle of Progression")
+    if os.path.exists(data_dir):
+        return data_dir
+
+    zip_path = os.path.join(path, "JNU-IFM.zip")
+    util.download_source(path=zip_path, url=URL, download=download, checksum=CHECKSUM)
+    util.unzip(zip_path=zip_path, dst=path)
+
+    return data_dir
+
+
+def _get_jnuifm_paths(path, download):
+    data_dir = get_jnuifm_data(path=path, download=download)
+
+    image_paths = natsorted(glob(os.path.join(data_dir, "image_mha", "*.mha")))
+    gt_paths = natsorted(glob(os.path.join(data_dir, "label_mha", "*.mha")))
+
+    return image_paths, gt_paths
+
+
+def get_jnuifm_dataset(
+    path: Union[os.PathLike, str],
+    patch_shape: Tuple[int, int],
+    resize_inputs: bool = False,
+    download: bool = False,
+    **kwargs
+):
+    """Dataset for segmentation of pubic symphysis and fetal head in ultrasound images.
+
+    The label pixels are - 0: background, 1: pubic symphysis, 2: fetal head.
+
+    The database is located at https://doi.org/10.5281/zenodo.7851339
+
+    The dataset is from Lu et al. - https://doi.org/10.1016/j.dib.2022.107904
+    Please cite it if you use this dataset for a publication.
+    """
+    image_paths, gt_paths = _get_jnuifm_paths(path=path, download=download)
+
+    if resize_inputs:
+        resize_kwargs = {"patch_shape": patch_shape, "is_rgb": True}
+        kwargs, patch_shape = util.update_kwargs_for_resize_trafo(
+            kwargs=kwargs, patch_shape=patch_shape, resize_inputs=resize_inputs, resize_kwargs=resize_kwargs
+        )
+
+    dataset = torch_em.default_segmentation_dataset(
+        raw_paths=image_paths,
+        raw_key=None,
+        label_paths=gt_paths,
+        label_key=None,
+        patch_shape=patch_shape,
+        ndim=2,
+        with_channels=True,
+        is_seg_dataset=False,
+        **kwargs
+    )
+
+    return dataset
+
+
+def get_jnuifm_loader(
+    path: Union[os.PathLike, str],
+    patch_shape: Tuple[int, int],
+    batch_size: int,
+    resize_inputs: bool = False,
+    download: bool = False,
+    **kwargs
+):
+    """
+    Dataloader for segmentation of pubic symphysis and fetal head in ultrasound images.
+    See `get_jnuifm_dataset` for details.
+    """
+    ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
+    dataset = get_jnuifm_dataset(
+        path=path, patch_shape=patch_shape, resize_inputs=resize_inputs, download=download, **ds_kwargs
+    )
+    loader = torch_em.get_data_loader(dataset=dataset, batch_size=batch_size, **loader_kwargs)
+    return loader
diff --git a/torch_em/data/datasets/medical/leg_3d_us.py b/torch_em/data/datasets/medical/leg_3d_us.py
new file mode 100644
index 00000000..25c419c4
--- /dev/null
+++ b/torch_em/data/datasets/medical/leg_3d_us.py
@@ -0,0 +1,145 @@
+"""The LEG 3D US dataset contains annotations for leg muscle segmentation
+in 3d ultrasound scans.
+
+NOTE: The label legends are described as follows:
+- background: 0
+- soleus (SOL): 100
+- gastrocnemius medialis (GM): 150
+- gastrocnemius lateralis (GL): 200
+
+The dataset is located at https://www.cs.cit.tum.de/camp/publications/leg-3d-us-dataset/.
+
+This dataset is from the article: https://doi.org/10.1371/journal.pone.0268550.
+Please cite it if you use this dataset in your research.
+"""
+
+import os
+from glob import glob
+from natsort import natsorted
+from typing import Union, Tuple, Literal, List
+
+from torch.utils.data import Dataset, DataLoader
+
+import torch_em
+
+from .. import util
+
+
+URLS = {
+    "train": "https://www.campar.in.tum.de/public_datasets/2024_IPCAI_Vanessa/leg_train_data.zip",
+    "val": "https://www.campar.in.tum.de/public_datasets/2024_IPCAI_Vanessa/leg_validation_data.zip",
+    "test": "https://www.campar.in.tum.de/public_datasets/2024_IPCAI_Vanessa/leg_test_data.zip",
+}
+
+CHECKSUMS = {
+    "train": "747e9ada7135979218d93022ac46d40a3a85119e2ea7aebcda4b13f7dfda70d6",
+    "val": "c204fa0759dd279de722a423401da60657bc0d1ab5f57d135cd0ad55c32af70f",
+    "test": "42ad341e8133f827d35f9cb3afde3ffbe5ae97dc2af448b6f9af6d4ea6ac99f0",
+}
+
+
+def get_leg_3d_us_data(
+    path: Union[os.PathLike, str], split: Literal['train', 'val', 'test'], download: bool = False
+):
+    """Download the LEG 3D US data.
+
+    Args:
+        path: Filepath to a folder where the data is downloaded for further processing.
+        split: The data split to use. Either 'train', 'val' or 'test'.
+        download: Whether to download the data if it is not present.
+    """
+    data_dir = os.path.join(path, split)
+    if os.path.exists(data_dir):
+        return
+
+    os.makedirs(path, exist_ok=True)
+
+    if split not in URLS:
+        raise ValueError(f"'{split}' is not a valid split choice.")
+
+    zip_name = "validation" if split == "val" else split
+    zip_path = os.path.join(path, f"leg_{zip_name}_data.zip")
+    util.download_source(path=zip_path, url=URLS[split], download=download, checksum=CHECKSUMS[split])
+    util.unzip(zip_path=zip_path, dst=path)
+
+
+def get_leg_3d_us_paths(
+    path: Union[os.PathLike, str], split: Literal['train', 'val', 'test'], download: bool = False
+) -> Tuple[List[str], List[str]]:
+    """Get paths to the LEG 3D US data.
+
+    Args:
+        path: Filepath to a folder where the data is downloaded for further processing.
+        split: The data split to use. Either 'train', 'val' or 'test'.
+        download: Whether to download the data if it is not present.
+
+    Returns:
+        List of filepaths for the image data.
+        List of filepaths for the label data.
+    """
+    get_leg_3d_us_data(path, split, download)
+
+    raw_paths = natsorted(glob(os.path.join(path, split, "*", "x*.mha")))
+    label_paths = [fpath.replace("x", "masksX") for fpath in raw_paths]
+
+    assert len(raw_paths) == len(label_paths)
+
+    return raw_paths, label_paths
+
+
+def get_leg_3d_us_dataset(
+    path: Union[os.PathLike, str],
+    patch_shape: Tuple[int, ...],
+    split: Literal['train', 'val', 'test'],
+    download: bool = False,
+    **kwargs
+) -> Dataset:
+    """Get the LEG 3D US dataset for leg muscle segmentation.
+
+    Args:
+        path: Filepath to a folder where the data is downloaded for further processing.
+        patch_shape: The patch shape to use for training.
+        split: The data split to use. Either 'train', 'val' or 'test'.
+        download: Whether to download the data if it is not present.
+        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.
+
+    Returns:
+        The segmentation dataset.
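+
+    Example:
+        A minimal usage sketch; the folder and the 3d patch shape are illustrative assumptions:
+
+            ds = get_leg_3d_us_dataset("./leg_3d_us", patch_shape=(32, 256, 256), split="train", download=True)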
+    """
+    raw_paths, label_paths = get_leg_3d_us_paths(path, split, download)
+
+    return torch_em.default_segmentation_dataset(
+        raw_paths=raw_paths,
+        raw_key=None,
+        label_paths=label_paths,
+        label_key=None,
+        patch_shape=patch_shape,
+        is_seg_dataset=True,
+        **kwargs
+    )
+
+
+def get_leg_3d_us_loader(
+    path: Union[os.PathLike, str],
+    batch_size: int,
+    patch_shape: Tuple[int, ...],
+    split: Literal['train', 'val', 'test'],
+    download: bool = False,
+    **kwargs
+) -> DataLoader:
+    """Get the LEG 3D US dataloader for leg muscle segmentation.
+
+    Args:
+        path: Filepath to a folder where the data is downloaded for further processing.
+        batch_size: The batch size for training.
+        patch_shape: The patch shape to use for training.
+        split: The data split to use. Either 'train', 'val' or 'test'.
+        download: Whether to download the data if it is not present.
+        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader.
+
+    Returns:
+        The DataLoader.
+    """
+    ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
+    dataset = get_leg_3d_us_dataset(path, patch_shape, split, download, **ds_kwargs)
+    return torch_em.get_data_loader(dataset=dataset, batch_size=batch_size, **loader_kwargs)
diff --git a/torch_em/data/datasets/medical/lgg_mri.py b/torch_em/data/datasets/medical/lgg_mri.py
new file mode 100644
index 00000000..06db0d30
--- /dev/null
+++ b/torch_em/data/datasets/medical/lgg_mri.py
@@ -0,0 +1,141 @@
+"""The LGG MRI dataset contains annotations for low grade glioma segmentation
+in FLAIR MRI scans.
+
+The dataset is located at https://www.kaggle.com/datasets/mateuszbuda/lgg-mri-segmentation.
+
+This dataset is from the publication https://www.nejm.org/doi/full/10.1056/NEJMoa1402121.
+Please cite it if you use this dataset in your research.
+"""
+
+import os
+import shutil
+from glob import glob
+from tqdm import tqdm
+from natsort import natsorted
+from typing import Union, Tuple, List
+
+import numpy as np
+import imageio.v3 as imageio
+
+from torch.utils.data import Dataset, DataLoader
+
+import torch_em
+
+from .. import util
+
+
+def _merge_slices_to_volumes(path):
+    import nibabel as nib
+
+    raw_dir = os.path.join(path, "data", "raw")
+    label_dir = os.path.join(path, "data", "labels")
+    os.makedirs(raw_dir, exist_ok=True)
+    os.makedirs(label_dir, exist_ok=True)
+
+    patient_dirs = glob(os.path.join(path, "kaggle_3m", "TCGA_*"))
+    for patient_dir in tqdm(patient_dirs, desc="Preprocessing inputs"):
+        label_slice_paths = natsorted(glob(os.path.join(patient_dir, "*_mask.tif")))
+        raw_slice_paths = [lpath.replace("_mask.tif", ".tif") for lpath in label_slice_paths]
+
+        raw = [imageio.imread(rpath) for rpath in raw_slice_paths]
+        labels = [imageio.imread(lpath) for lpath in label_slice_paths]
+
+        raw, labels = np.stack(raw, axis=2), np.stack(labels, axis=2)
+
+        raw_nifti = nib.Nifti2Image(raw, np.eye(4))
+        label_nifti = nib.Nifti2Image(labels, np.eye(4))
+
+        nib.save(raw_nifti, os.path.join(raw_dir, f"{os.path.basename(patient_dir)}.nii.gz"))
+        nib.save(label_nifti, os.path.join(label_dir, f"{os.path.basename(patient_dir)}.nii.gz"))
+
+    shutil.rmtree(os.path.join(path, "kaggle_3m"))
+
+
+def get_lgg_mri_data(path: Union[os.PathLike, str], download: bool = False):
+    """Download the LGG MRI data.
+
+    Args:
+        path: Filepath to a folder where the data is downloaded for further processing.
+        download: Whether to download the data if it is not present.
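+
+    Example:
+        A minimal sketch; downloading via the Kaggle helper assumes configured
+        Kaggle API credentials, and the folder below is an illustrative choice:
+
+            get_lgg_mri_data("./lgg_mri", download=True)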
+ """ + data_dir = os.path.join(path, "data") + if os.path.exists(data_dir): + return + + os.makedirs(path, exist_ok=True) + + util.download_source_kaggle(path=path, dataset_name="mateuszbuda/lgg-mri-segmentation", download=download) + zip_path = os.path.join(path, "lgg-mri-segmentation.zip") + util.unzip(zip_path=zip_path, dst=path) + + # Remove redundant volumes + shutil.rmtree(os.path.join(path, "lgg-mri-segmentation")) + + _merge_slices_to_volumes(path) + + +def get_lgg_mri_paths(path: Union[os.PathLike, str], download: bool = False) -> Tuple[List[str], List[str]]: + """Get paths to the LGG MRI data. + + Args: + path: Filepath to a folder where the data is downloaded for further processing. + download: Whether to download the data if it is not present. + + Returns: + List of filepaths for the image data. + List of filepaths for the label data. + """ + get_lgg_mri_data(path, download) + + raw_paths = natsorted(glob(os.path.join(path, "data", "raw", "*.nii.gz"))) + label_paths = natsorted(glob(os.path.join(path, "data", "labels", "*.nii.gz"))) + + return raw_paths, label_paths + + +def get_lgg_mri_dataset( + path: Union[os.PathLike, str], patch_shape: Tuple[int, ...], download: bool, **kwargs +) -> Dataset: + """Get the LGG MRI dataset for glioma segmentation. + + Args: + path: Filepath to a folder where the data is downloaded for further processing. + patch_shape: The patch shape to use for training. + download: Whether to download the data if it is not present. + kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`. + + Returns: + The segmentation dataset. + """ + raw_paths, label_paths = get_lgg_mri_paths(path, download) + + return torch_em.default_segmentation_dataset( + raw_paths=raw_paths, + raw_key="data", + label_paths=label_paths, + label_key="data", + patch_shape=patch_shape, + is_seg_dataset=True, + with_channels=True, + **kwargs + ) + + +def get_lgg_mri_loader( + path: Union[os.PathLike, str], batch_size: int, patch_shape: Tuple[int, ...], download: bool, **kwargs +) -> DataLoader: + """Get the LGG MRI dataloader for glioma segmentation. + + Args: + path: Filepath to a folder where the data is downloaded for further processing. + batch_size: The batch size for training. + patch_shape: The patch shape to use for training. + download: Whether to download the data if it is not present. + kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader. + + Returns: + The DataLoader. + """ + ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs) + dataset = get_lgg_mri_dataset(path, patch_shape, download, **ds_kwargs) + return torch_em.get_data_loader(dataset=dataset, batch_size=batch_size, **loader_kwargs) diff --git a/torch_em/data/datasets/medical/m2caiseg.py b/torch_em/data/datasets/medical/m2caiseg.py new file mode 100644 index 00000000..51f92710 --- /dev/null +++ b/torch_em/data/datasets/medical/m2caiseg.py @@ -0,0 +1,172 @@ +import os +from glob import glob +from tqdm import tqdm +from pathlib import Path +from natsort import natsorted +from typing import Union, Tuple + +import numpy as np +import imageio.v3 as imageio + +import torch_em + +from .. 
import util
+
+
+LABEL_MAPS = {
+    (0, 0, 0): 0,  # out of frame
+    (0, 85, 170): 1,  # grasper
+    (0, 85, 255): 2,  # bipolar
+    (0, 170, 255): 3,  # hook
+    (0, 255, 85): 4,  # scissors
+    (0, 255, 170): 5,  # clipper
+    (85, 0, 170): 6,  # irrigator
+    (85, 0, 255): 7,  # specimen bag
+    (170, 85, 85): 8,  # trocars
+    (170, 170, 170): 9,  # clip
+    (85, 170, 0): 10,  # liver
+    (85, 170, 255): 11,  # gall bladder
+    (85, 255, 0): 12,  # fat
+    (85, 255, 170): 13,  # upper wall
+    (170, 0, 255): 14,  # artery
+    (255, 0, 255): 15,  # intestine
+    (255, 255, 0): 16,  # bile
+    (255, 0, 0): 17,  # blood
+    (170, 0, 85): 18,  # unknown
+}
+
+
+def get_m2caiseg_data(path, download):
+    os.makedirs(path, exist_ok=True)
+
+    data_dir = os.path.join(path, r"m2caiSeg dataset")
+    if os.path.exists(data_dir):
+        return data_dir
+
+    util.download_source_kaggle(path=path, dataset_name="salmanmaq/m2caiseg", download=download)
+    zip_path = os.path.join(path, "m2caiseg.zip")
+    util.unzip(zip_path=zip_path, dst=path)
+
+    return data_dir
+
+
+def _get_m2caiseg_paths(path, split, download):
+    data_dir = get_m2caiseg_data(path=path, download=download)
+
+    if split == "val":
+        impaths = natsorted(glob(os.path.join(data_dir, "train", "images", "*.jpg")))
+        gpaths = natsorted(glob(os.path.join(data_dir, "train", "groundtruth", "*.png")))
+
+        imids = [os.path.split(_p)[-1] for _p in impaths]
+        gids = [os.path.split(_p)[-1] for _p in gpaths]
+
+        image_paths = [
+            _p for _p in natsorted(
+                glob(os.path.join(data_dir, "trainval", "images", "*.jpg"))
+            ) if os.path.split(_p)[-1] not in imids
+        ]
+        gt_paths = [
+            _p for _p in natsorted(
+                glob(os.path.join(data_dir, "trainval", "groundtruth", "*.png"))
+            ) if os.path.split(_p)[-1] not in gids
+        ]
+
+    else:
+        image_paths = natsorted(glob(os.path.join(data_dir, split, "images", "*.jpg")))
+        gt_paths = natsorted(glob(os.path.join(data_dir, split, "groundtruth", "*.png")))
+
+    images_dir = os.path.join(data_dir, "preprocessed", split, "images")
+    mask_dir = os.path.join(data_dir, "preprocessed", split, "masks")
+    if os.path.exists(images_dir) and os.path.exists(mask_dir):
+        return natsorted(glob(os.path.join(images_dir, "*"))), natsorted(glob(os.path.join(mask_dir, "*")))
+
+    os.makedirs(images_dir, exist_ok=True)
+    os.makedirs(mask_dir, exist_ok=True)
+
+    fimage_paths, fgt_paths = [], []
+    for image_path, gt_path in tqdm(zip(image_paths, gt_paths), total=len(image_paths)):
+        image = imageio.imread(image_path)
+        gt = imageio.imread(gt_path)
+
+        image_id = Path(image_path).stem
+        gt_id = Path(gt_path).stem
+
+        if image.shape != gt.shape:
+            print("The image and the corresponding labels do not match in shape; skipping this pair.")
+            continue
+
+        dst_image_path = os.path.join(images_dir, f"{image_id}.tif")
+        dst_gt_path = os.path.join(mask_dir, f"{gt_id}.tif")
+
+        # we return the preprocessed image paths, consistent with the early return above
+        fimage_paths.append(dst_image_path)
+        fgt_paths.append(dst_gt_path)
+        if os.path.exists(dst_gt_path) and os.path.exists(dst_image_path):
+            continue
+
+        instances = np.zeros(gt.shape[:2])
+        for lmap in LABEL_MAPS:
+            binary_map = (gt == lmap).all(axis=2)
+            instances[binary_map > 0] = LABEL_MAPS[lmap]
+
+        imageio.imwrite(dst_image_path, image, compression="zlib")
+        imageio.imwrite(dst_gt_path, instances, compression="zlib")
+
+    return fimage_paths, fgt_paths
+
+
+def get_m2caiseg_dataset(
+    path: Union[os.PathLike, str],
+    split: str,
+    patch_shape: Tuple[int, int],
+    resize_inputs: bool = False,
+    download: bool = False,
+    **kwargs
+):
+    """Dataset for segmentation of organs and instruments in endoscopy.
+
+    This data is from Maqbool et al. 
- https://doi.org/10.48550/arXiv.2008.10134 + + This dataset is located at https://www.kaggle.com/datasets/salmanmaq/m2caiseg + + Please cite it if you use this data in a publication. + """ + assert split in ["train", "val", "test"] + + image_paths, gt_paths = _get_m2caiseg_paths(path=path, split=split, download=download) + + if resize_inputs: + resize_kwargs = {"patch_shape": patch_shape, "is_rgb": True} + kwargs, patch_shape = util.update_kwargs_for_resize_trafo( + kwargs=kwargs, patch_shape=patch_shape, resize_inputs=resize_inputs, resize_kwargs=resize_kwargs + ) + + dataset = torch_em.default_segmentation_dataset( + raw_paths=image_paths, + raw_key=None, + label_paths=gt_paths, + label_key=None, + patch_shape=patch_shape, + is_seg_dataset=False, + **kwargs + ) + + return dataset + + +def get_m2caiseg_loader( + path: Union[os.PathLike, str], + split: str, + patch_shape: Tuple[int, int], + batch_size: int, + resize_inputs: bool = False, + download: bool = False, + **kwargs +): + """Dataloader for segmentation of organs and instruments in endoscopy. See `get_m2caiseg_dataset` for details. + """ + ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs) + dataset = get_m2caiseg_dataset( + path=path, split=split, patch_shape=patch_shape, resize_inputs=resize_inputs, download=download, **ds_kwargs + ) + loader = torch_em.get_data_loader(dataset=dataset, batch_size=batch_size, **loader_kwargs) + return loader diff --git a/torch_em/data/datasets/medical/mbh_seg.py b/torch_em/data/datasets/medical/mbh_seg.py new file mode 100644 index 00000000..1b7a4c91 --- /dev/null +++ b/torch_em/data/datasets/medical/mbh_seg.py @@ -0,0 +1,125 @@ +"""The MBH Seg dataset contains annotations for intracranial hemorrhages +in non-contrast CT scans. + +This dataset is from the MBH-Seg challenge: https://mbh-seg.com +- original scans: https://kaggle.com/competitions/rsna-intracranial-hemorrhage-detection +Please cite these if you use this dataset for your publication. +""" + +import os +from glob import glob +from natsort import natsorted +from typing import Union, Tuple, List + +from torch.utils.data import Dataset, DataLoader + +import torch_em + +from .. import util + + +URL = "https://huggingface.co/datasets/WuBiao/BHSD/resolve/main/label_192.zip" +CHECKSUM = "582bf184af993541a4958a4d209a6a44e3bbe702a5daefaf9fb1733a4e7a6e39" + + +def get_mbh_seg_data(path: Union[os.PathLike, str], download: bool = False) -> str: + """Download the MBH Seg dataset. + + Args: + path: Filepath to a folder where the data is downloaded for further processing. + download: Whether to download the data if it is not present. + """ + data_dir = os.path.join(path, "label_192") + if os.path.exists(data_dir): + return data_dir + + os.makedirs(path, exist_ok=True) + + zip_path = os.path.join(path, "label_192.zip") + util.download_source(path=zip_path, url=URL, download=download, checksum=CHECKSUM) + util.unzip(zip_path=zip_path, dst=path) + + return data_dir + + +def get_mbh_seg_paths(path: Union[os.PathLike, str], download: bool = False) -> Tuple[List[str], List[str]]: + """Get paths to the MBH Seg data. + + Args: + path: Filepath to a folder where the data is downloaded for further processing. + download: Whether to download the data if it is not present. + + Returns: + List of filepaths for the image data. + List of filepaths for the label data. 
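+
+    Example:
+        A minimal sketch of fetching the path pairs; the target folder is an
+        illustrative choice:
+
+            image_paths, gt_paths = get_mbh_seg_paths("./mbh_seg", download=True)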
+ """ + data_dir = get_mbh_seg_data(path=path, download=download) + + image_paths = natsorted(glob(os.path.join(data_dir, "images", "*.nii.gz"))) + gt_paths = natsorted(glob(os.path.join(data_dir, r"ground truths", "*.nii.gz"))) + + return image_paths, gt_paths + + +def get_mbh_seg_dataset( + path: Union[os.PathLike, str], + patch_shape: Tuple[int, ...], + resize_inputs: bool = False, + download: bool = False, + **kwargs +) -> Dataset: + """Get the MBH Seg dataset for intracranial hemorrhage segmentation. + + Args: + path: Filepath to a folder where the data is downloaded for further processing. + patch_shape: The patch shape to use for training. + resize_inputs: Whether to resize inputs to the desired patch shape. + download: Whether to download the data if it is not present. + kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`. + + Returns: + The segmentation dataset. + """ + image_paths, gt_paths = get_mbh_seg_paths(path=path, download=download) + + if resize_inputs: + resize_kwargs = {"patch_shape": patch_shape, "is_rgb": False} + kwargs, patch_shape = util.update_kwargs_for_resize_trafo( + kwargs=kwargs, patch_shape=patch_shape, resize_inputs=resize_inputs, resize_kwargs=resize_kwargs + ) + + return torch_em.default_segmentation_dataset( + raw_paths=image_paths, + raw_key="data", + label_paths=gt_paths, + label_key="data", + patch_shape=patch_shape, + **kwargs + ) + + +def get_mbh_seg_loader( + path: Union[os.PathLike, str], + patch_shape: Tuple[int, ...], + batch_size: int, + resize_inputs: bool = False, + download: bool = False, + **kwargs +) -> DataLoader: + """Get the MBH Seg dataloader for intracranial hemorrhage segmentation. + + Args: + path: Filepath to a folder where the data is downloaded for further processing. + patch_shape: The patch shape to use for training. + resize_inputs: Whether to resize inputs to the desired patch shape. + download: Whether to download the data if it is not present. + kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader. + + Returns: + The DataLoader. + """ + ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs) + dataset = get_mbh_seg_dataset( + path=path, patch_shape=patch_shape, resize_inputs=resize_inputs, download=download, **ds_kwargs + ) + return torch_em.get_data_loader(dataset=dataset, batch_size=batch_size, **loader_kwargs) diff --git a/torch_em/data/datasets/medical/micro_usp.py b/torch_em/data/datasets/medical/micro_usp.py new file mode 100644 index 00000000..eb22ca7f --- /dev/null +++ b/torch_em/data/datasets/medical/micro_usp.py @@ -0,0 +1,89 @@ +import os +from glob import glob +from pathlib import Path +from natsort import natsorted +from typing import Union, Tuple + +import torch_em + +from .. import util + + +URL = "https://zenodo.org/records/10475293/files/Micro_Ultrasound_Prostate_Segmentation_Dataset.zip?" 
+CHECKSUM = "031645dc30948314e379d0a0a7d54bad1cd4e1f3f918b77455d69810aa05dce3" + + +def get_micro_usp_data(path, download): + os.makedirs(path, exist_ok=True) + + fname = Path(URL).stem + data_dir = os.path.join(path, fname) + if os.path.exists(data_dir): + return data_dir + + zip_path = os.path.join(path, f"{fname}.zip") + util.download_source(path=zip_path, url=URL, download=download, checksum=CHECKSUM) + util.unzip(zip_path=zip_path, dst=path) + + return data_dir + + +def _get_micro_usp_paths(path, split, download): + data_dir = get_micro_usp_data(path=path, download=download) + + image_paths = natsorted(glob(os.path.join(data_dir, split, "micro_ultrasound_scans", "*.nii.gz"))) + gt_paths = natsorted(glob(os.path.join(data_dir, split, "expert_annotations", "*.nii.gz"))) + + return image_paths, gt_paths + + +def get_micro_usp_dataset( + path: Union[os.PathLike, str], + patch_shape: Tuple[int, ...], + split: str, + resize_inputs: bool = False, + download: bool = False, + **kwargs +): + """Dataset for segmentation of prostate in micro-ultrasound scans. + + This dataset is from Jiang et al. - https://doi.org/10.1016/j.compmedimag.2024.102326. + Please cite it if you use this dataset for a publication. + """ + image_paths, gt_paths = _get_micro_usp_paths(path=path, split=split, download=download) + + if resize_inputs: + resize_kwargs = {"patch_shape": patch_shape, "is_rgb": False} + kwargs, patch_shape = util.update_kwargs_for_resize_trafo( + kwargs=kwargs, patch_shape=patch_shape, resize_inputs=resize_inputs, resize_kwargs=resize_kwargs + ) + + dataset = torch_em.default_segmentation_dataset( + raw_paths=image_paths, + raw_key="data", + label_paths=gt_paths, + label_key="data", + patch_shape=patch_shape, + **kwargs + ) + + return dataset + + +def get_micro_usp_loader( + path: Union[os.PathLike, str], + patch_shape: Tuple[int, ...], + batch_size: int, + split: str, + resize_inputs: bool = False, + download: bool = False, + **kwargs +): + """Dataloader for segmentation of prostate in micro-ultrasound scans. See `get_micro_usp_dataset` for details. + """ + ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs) + dataset = get_micro_usp_dataset( + path=path, patch_shape=patch_shape, split=split, resize_inputs=resize_inputs, download=download, **ds_kwargs + ) + loader = torch_em.get_data_loader(dataset=dataset, batch_size=batch_size, **loader_kwargs) + return loader diff --git a/torch_em/data/datasets/medical/montgomery.py b/torch_em/data/datasets/medical/montgomery.py new file mode 100644 index 00000000..2cdb3919 --- /dev/null +++ b/torch_em/data/datasets/medical/montgomery.py @@ -0,0 +1,116 @@ +import os +from glob import glob +from tqdm import tqdm +from typing import Union, Tuple + +import imageio.v3 as imageio + +import torch_em + +from .. 
import util + + +URL = "http://openi.nlm.nih.gov/imgs/collections/NLM-MontgomeryCXRSet.zip" +CHECKSUM = "54601e952315d8f67383e9202a6e145997ade429f54f7e0af44b4e158714f424" + + +def get_montgomery_data(path, download): + os.makedirs(path, exist_ok=True) + + data_dir = os.path.join(path, "MontgomerySet") + if os.path.exists(data_dir): + return data_dir + + zip_path = os.path.join(path, "NLM-MontgomeryCXRSet.zip") + util.download_source(path=zip_path, url=URL, download=download, checksum=CHECKSUM) + util.unzip(zip_path=zip_path, dst=path) + + return data_dir + + +def _get_montgomery_paths(path, download): + data_dir = get_montgomery_data(path=path, download=download) + gt_dir = os.path.join(data_dir, "ManualMask", "gt") + + image_paths = sorted(glob(os.path.join(data_dir, "CXR_png", "*.png"))) + + if os.path.exists(gt_dir): + gt_paths = sorted(glob(os.path.join(gt_dir, "*.png"))) + if len(image_paths) == len(gt_paths): + return image_paths, gt_paths + + else: + os.makedirs(gt_dir, exist_ok=True) + + lmask_dir = os.path.join(data_dir, "ManualMask", "leftMask") + rmask_dir = os.path.join(data_dir, "ManualMask", "rightMask") + gt_paths = [] + for image_path in tqdm(image_paths, desc="Merging left and right lung halves"): + image_id = os.path.split(image_path)[-1] + + # merge the left and right lung halves into one gt file + gt = imageio.imread(os.path.join(lmask_dir, image_id)) + gt += imageio.imread(os.path.join(rmask_dir, image_id)) + gt = gt.astype("uint8") + + gt_path = os.path.join(gt_dir, image_id) + + imageio.imwrite(gt_path, gt) + gt_paths.append(gt_path) + + return image_paths, gt_paths + + +def get_montgomery_dataset( + path: Union[os.PathLike, str], + patch_shape: Tuple[int, int], + resize_inputs: bool = True, + download: bool = False, + **kwargs +): + """Dataset for the segmentation of lungs in x-ray. + + This dataset is from the publication: + - https://doi.org/10.1109/TMI.2013.2284099 + - https://doi.org/10.1109/tmi.2013.2290491 + + The database is located at + https://data.lhncbc.nlm.nih.gov/public/Tuberculosis-Chest-X-ray-Datasets/Montgomery-County-CXR-Set/MontgomerySet/index.html. + + Please cite it if you use this dataset for a publication. + """ + image_paths, gt_paths = _get_montgomery_paths(path=path, download=download) + + if resize_inputs: + resize_kwargs = {"patch_shape": patch_shape, "is_rgb": False} + kwargs, patch_shape = util.update_kwargs_for_resize_trafo( + kwargs=kwargs, patch_shape=patch_shape, resize_inputs=resize_inputs, resize_kwargs=resize_kwargs + ) + + dataset = torch_em.default_segmentation_dataset( + raw_paths=image_paths, + raw_key=None, + label_paths=gt_paths, + label_key=None, + patch_shape=patch_shape, + **kwargs + ) + return dataset + + +def get_montgomery_loader( + path: Union[os.PathLike, str], + patch_shape: Tuple[int, int], + batch_size: int, + resize_inputs: bool = True, + download: bool = False, + **kwargs +): + """Dataloader for the segmentation of lungs in x-ray. See 'get_montgomery_dataset' for details. 
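+
+    A minimal usage sketch (the path below is a placeholder, the data is downloaded there):
+    ```python
+    loader = get_montgomery_loader("./montgomery", patch_shape=(512, 512), batch_size=4, download=True)
+    ```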
+    """
+    ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
+    dataset = get_montgomery_dataset(
+        path=path, patch_shape=patch_shape, resize_inputs=resize_inputs, download=download, **ds_kwargs
+    )
+    loader = torch_em.get_data_loader(dataset=dataset, batch_size=batch_size, **loader_kwargs)
+    return loader
diff --git a/torch_em/data/datasets/medical/msd.py b/torch_em/data/datasets/medical/msd.py
new file mode 100644
index 00000000..82a3b2a5
--- /dev/null
+++ b/torch_em/data/datasets/medical/msd.py
@@ -0,0 +1,148 @@
+import os
+from glob import glob
+from pathlib import Path
+from typing import Tuple, List, Union
+
+import torch_em
+
+from .. import util
+from ....data import ConcatDataset
+
+
+URL = {
+    "braintumour": "https://msd-for-monai.s3-us-west-2.amazonaws.com/Task01_BrainTumour.tar",
+    "heart": "https://msd-for-monai.s3-us-west-2.amazonaws.com/Task02_Heart.tar",
+    "liver": "https://msd-for-monai.s3-us-west-2.amazonaws.com/Task03_Liver.tar",
+    "hippocampus": "https://msd-for-monai.s3-us-west-2.amazonaws.com/Task04_Hippocampus.tar",
+    "prostate": "https://msd-for-monai.s3-us-west-2.amazonaws.com/Task05_Prostate.tar",
+    "lung": "https://msd-for-monai.s3-us-west-2.amazonaws.com/Task06_Lung.tar",
+    "pancreas": "https://msd-for-monai.s3-us-west-2.amazonaws.com/Task07_Pancreas.tar",
+    "hepaticvessel": "https://msd-for-monai.s3-us-west-2.amazonaws.com/Task08_HepaticVessel.tar",
+    "spleen": "https://msd-for-monai.s3-us-west-2.amazonaws.com/Task09_Spleen.tar",
+    "colon": "https://msd-for-monai.s3-us-west-2.amazonaws.com/Task10_Colon.tar",
+}
+
+CHECKSUM = {
+    "braintumour": "d423911308d2ae5396d9c6bf4fad2b68cfde2dd09044269da9c0d639c22753c4",
+    "heart": "4277dc6dfe100142aa8060e895f6ff0f81c5b733703ea250bd294df8f820bcba",
+    "liver": "4007d9db1acda850d57a6ceb2b3998b7a0d43f8ad5a3f740dc38bc0cb8b7a2c5",
+    "hippocampus": "282d808a3e84e5a52f090d9dd4c0b0057b94a6bd51ad41569aef5ff303287771",
+    "prostate": "8cbbd7147691109b880ff8774eb6ab26704b1be0935482e7996a36a4ed31ec79",
+    "lung": "f782cd09da9cf7a3128475d4a53650d371db10f0427aa76e166fccfcb2654161",
+    "pancreas": "e40181a0229ca85c2588d6ebb90fa6674f84eb1e66f0f968cda088d011769732",
+    "hepaticvessel": "ee880799f12e3b6e1ef2f8645f6626c5b39de77a4f1eae6f496c25fbf306ba04",
+    "spleen": "dfeba347daae4fb08c38f4d243ab606b28b91b206ffc445ec55c35489fa65e60",
+    "colon": "a26bfd23faf2de703f5a51a262cd4e2b9774c47e7fb86f0e0a854f8446ec2325",
+}
+
+FILENAMES = {
+    "braintumour": "Task01_BrainTumour.tar",
+    "heart": "Task02_Heart.tar",
+    "liver": "Task03_Liver.tar",
+    "hippocampus": "Task04_Hippocampus.tar",
+    "prostate": "Task05_Prostate.tar",
+    "lung": "Task06_Lung.tar",
+    "pancreas": "Task07_Pancreas.tar",
+    "hepaticvessel": "Task08_HepaticVessel.tar",
+    "spleen": "Task09_Spleen.tar",
+    "colon": "Task10_Colon.tar",
+}
+
+
+def get_msd_data(path, task_name, download):
+    os.makedirs(path, exist_ok=True)
+
+    data_dir = os.path.join(path, "data", task_name)
+    if os.path.exists(data_dir):
+        return data_dir
+
+    fpath = os.path.join(path, FILENAMES[task_name])
+
+    # validate the download against the pre-computed checksum for the respective task
+    util.download_source(path=fpath, url=URL[task_name], download=download, checksum=CHECKSUM[task_name])
+    util.unzip_tarfile(tar_path=fpath, dst=data_dir, remove=False)
+
+    return data_dir
+
+
+def get_msd_dataset(
+    path: str,
+    patch_shape: Tuple[int, ...],
+    ndim: int,
+    task_names: Union[str, List[str]],
+    download: bool = False,
+    **kwargs
+):
+    """Dataset for semantic segmentation in 10 medical imaging datasets.
+
+    This dataset is from the Medical Segmentation Decathlon Challenge:
+    - Antonelli et al. - https://doi.org/10.1038/s41467-022-30695-9
+    - Link - http://medicaldecathlon.com/
+
+    Please cite it if you use this dataset for a publication.
+
+    Args:
+        path: The path to prepare the dataset.
+        patch_shape: The patch shape (for 2d or 3d patches).
+        ndim: The dimensionality of the inputs (use `2` for 2d patches and `3` for 3d patches).
+        task_names: The names for the 10 different segmentation tasks (see the challenge website for further details):
+            1. tasks with single-modality inputs: heart, liver, hippocampus, lung, pancreas, hepaticvessel, spleen, colon
+            2. tasks with multi-modality inputs:
+                - braintumour: with 4 modality (channel) inputs
+                - prostate: with 2 modality (channel) inputs
+        download: Whether to download the data if it is not present.
+
+    Here is an example of how to pass different tasks:
+    ```python
+    # we want to get the dataset for one task, e.g. "heart"
+    task_names = ["heart"]
+
+    # we want to get datasets for multiple tasks
+    # NOTE: only datasets with the same number of modalities (channels) can be combined.
+    # If you combine different datasets, use "raw_transform" to bring the inputs per dataset
+    # to the desired patch shapes per batch.
+    # Example 1: "heart", "liver" and "lung" all have single-modality inputs.
+    task_names = ["heart", "lung", "liver"]
+
+    # Example 2: "braintumour" and "prostate" have multi-modality inputs, but their numbers of
+    # modalities differ, hence only one of them can be used at a time.
+    task_names = ["prostate"]
+    ```
+    """
+    if isinstance(task_names, str):
+        task_names = [task_names]
+
+    _datasets = []
+    for task_name in task_names:
+        data_dir = get_msd_data(path, task_name, download)
+        image_paths = glob(os.path.join(data_dir, Path(FILENAMES[task_name]).stem, "imagesTr", "*.nii.gz"))
+        label_paths = glob(os.path.join(data_dir, Path(FILENAMES[task_name]).stem, "labelsTr", "*.nii.gz"))
+
+        if task_name in ["braintumour", "prostate"]:
+            kwargs["with_channels"] = True
+
+        this_dataset = torch_em.default_segmentation_dataset(
+            raw_paths=image_paths,
+            raw_key="data",
+            label_paths=label_paths,
+            label_key="data",
+            patch_shape=patch_shape,
+            ndim=ndim,
+            **kwargs
+        )
+        _datasets.append(this_dataset)
+
+    return ConcatDataset(*_datasets)
+
+
+def get_msd_loader(
+    path, patch_shape, batch_size, ndim, task_names, download=False, **kwargs
+):
+    """Dataloader for semantic segmentation from 10 highly variable medical segmentation tasks.
+    See `get_msd_dataset` for details.
+    """
+    ds_kwargs, loader_kwargs = util.split_kwargs(
+        torch_em.default_segmentation_dataset, **kwargs
+    )
+    ds = get_msd_dataset(path, patch_shape, ndim, task_names, download, **ds_kwargs)
+    loader = torch_em.get_data_loader(ds, batch_size=batch_size, **loader_kwargs)
+    return loader
diff --git a/torch_em/data/datasets/medical/oasis.py b/torch_em/data/datasets/medical/oasis.py
new file mode 100644
index 00000000..e18187cf
--- /dev/null
+++ b/torch_em/data/datasets/medical/oasis.py
@@ -0,0 +1,135 @@
+"""The OASIS dataset contains two sets of annotations for brain T1 MRI:
+one for 4-class tissue segmentation and one for 35-class anatomical segmentation.
+
+The dataset comes from https://github.com/adalca/medical-datasets/blob/master/neurite-oasis.md.
+
+This dataset is from the following publications:
+- https://doi.org/10.59275/j.melba.2022-74f1
+- https://doi.org/10.1162/jocn.2007.19.9.1498
+
+Please cite them if you use this dataset for your research.
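+
+A minimal usage sketch (the filepath below is a placeholder, the data will be downloaded there):
+```python
+from torch_em.data.datasets.medical.oasis import get_oasis_loader
+
+loader = get_oasis_loader(
+    path="./oasis", batch_size=2, patch_shape=(8, 128, 128),  # 3d patches from the brain volumes
+    source="orig", label_annotations="4", download=True,
+)
+```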
+""" + +import os +from glob import glob +from typing import Union, Tuple, Literal, List + +from torch.utils.data import Dataset, DataLoader + +import torch_em + +from .. import util + + +URL = "https://surfer.nmr.mgh.harvard.edu/ftp/data/neurite/data/neurite-oasis.v1.0.tar" +CHECKSUM = "86dd117dda17f736ade8a4088d7e98e066e1181950fe8b406f1a35f7fb743e78" + + +def get_oasis_data(path: Union[os.PathLike, str], download: bool = False): + """Download the OASIS dataset. + + Args: + path: Filepath to a folder where the data is downloaded for further processing. + download: Whether to download the data if it is not present. + """ + data_path = os.path.join(path, "data") + if os.path.exists(data_path): + return + + os.makedirs(path, exist_ok=True) + util.download_source(path=path, url=URL, download=download, checksum=CHECKSUM) + tar_path = os.path.join(path, "neurite-oasis.v1.0.tar") + util.unzip_tarfile(tar_path=tar_path, dst=data_path, remove=False) + + +def get_oasis_paths( + path: Union[os.PathLike, str], + source: Literal['orig', 'norm'] = "orig", + label_annotations: Literal['4', '35'] = "4", + download: bool = False +) -> Tuple[List[int], List[int]]: + """Get paths to the OASIS data. + + Args: + path: Filepath to a folder where the data is downloaded for further processing. + source: The source of inputs. Either 'orig' (original brain scans) or 'norm' (skull stripped). + label_annotations: The set of annotations. Either '4' (for tissues) or '35' (for anatomy). + download: Whether to download the data if it is not present. + + Returns: + List of filepaths for the image data. + List of filepaths for the label data. + """ + get_oasis_data(path, download) + + patient_dirs = glob(os.path.join(path, "data", "OASIS_*")) + raw_paths, label_paths = [], [] + for pdir in patient_dirs: + raw_paths.append(os.path.join(pdir, f"seg{label_annotations}.nii.gz")) + label_paths.append(os.path.join(pdir, f"{source}.nii.gz")) + + assert len(raw_paths) == len(label_paths) + + return raw_paths, label_paths + + +def get_oasis_dataset( + path: Union[os.PathLike, str], + patch_shape: Tuple[int, ...], + source: Literal['orig', 'norm'] = "orig", + label_annotations: Literal['4', '35'] = "4", + download: bool = False, + **kwargs +) -> Dataset: + """Get the OASIS dataset for tissue / anatomical segmentation. + + Args: + path: Filepath to a folder where the data is downloaded for further processing. + patch_shape: The patch shape to use for training. + source: The source of inputs. Either 'orig' (original brain scans) or 'norm' (skull stripped). + label_annotations: The set of annotations. Either '4' (for tissues) or '35' (for anatomy). + download: Whether to download the data if it is not present. + kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`. + + Returns: + The segmentation dataset. + """ + raw_paths, label_paths = get_oasis_paths(path, source, label_annotations, download) + + return torch_em.default_segmentation_dataset( + raw_paths=raw_paths, + raw_key="data", + label_paths=label_paths, + label_key="data", + patch_shape=patch_shape, + is_seg_dataset=True, + **kwargs + ) + + +def get_oasis_loader( + path: Union[os.PathLike, str], + batch_size: int, + patch_shape: Tuple[int, ...], + source: Literal['orig', 'norm'] = "orig", + label_annotations: Literal['4', '35'] = "4", + download: bool = False, + **kwargs +) -> DataLoader: + """Get the OASIS dataloader for tissue / anatomical segmentation. + + Args: + path: Filepath to a folder where the data is downloaded for further processing. 
+        batch_size: The batch size for training.
+        patch_shape: The patch shape to use for training.
+        source: The source of inputs. Either 'orig' (original brain scans) or 'norm' (skull stripped).
+        label_annotations: The set of annotations. Either '4' (for tissues) or '35' (for anatomy).
+        download: Whether to download the data if it is not present.
+        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader.
+
+    Returns:
+        The DataLoader.
+    """
+    ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
+    dataset = get_oasis_dataset(path, patch_shape, source, label_annotations, download, **ds_kwargs)
+    return torch_em.get_data_loader(dataset=dataset, batch_size=batch_size, **loader_kwargs)
diff --git a/torch_em/data/datasets/medical/oimhs.py b/torch_em/data/datasets/medical/oimhs.py
new file mode 100644
index 00000000..24faada2
--- /dev/null
+++ b/torch_em/data/datasets/medical/oimhs.py
@@ -0,0 +1,159 @@
+import os
+from glob import glob
+from tqdm import tqdm
+from pathlib import Path
+from natsort import natsorted
+from typing import Union, Tuple, Literal
+
+import json
+import numpy as np
+import imageio.v3 as imageio
+from sklearn.model_selection import train_test_split
+
+import torch_em
+
+from .. import util
+
+
+URL = "https://springernature.figshare.com/ndownloader/files/42522673"
+CHECKSUM = "d93ba18964614eb9b0ba4b8dfee269efbb94ff27142e4b5ecf7cc86f3a1f9d80"
+
+LABEL_MAPS = {
+    (255, 255, 0): 1,  # choroid
+    (0, 255, 0): 2,  # retina
+    (0, 0, 255): 3,  # intraretinal cysts
+    (255, 0, 0): 4  # macular hole
+}
+
+
+def get_oimhs_data(path, download):
+    os.makedirs(path, exist_ok=True)
+
+    data_dir = os.path.join(path, "data")
+    if os.path.exists(data_dir):
+        return data_dir
+
+    zip_path = os.path.join(path, "oimhs_dataset.zip")
+    util.download_source(path=zip_path, url=URL, download=download, checksum=CHECKSUM)
+    util.unzip(zip_path=zip_path, dst=data_dir)
+
+    return data_dir
+
+
+def _create_splits(data_dir, split_file, test_fraction=0.2):
+    eye_dirs = natsorted(glob(os.path.join(data_dir, "Images", "*")))
+
+    # let's split the data
+    main_split, test_split = train_test_split(eye_dirs, test_size=test_fraction)
+    train_split, val_split = train_test_split(main_split, test_size=0.1)
+
+    decided_splits = {"train": train_split, "val": val_split, "test": test_split}
+
+    with open(split_file, "w") as f:
+        json.dump(decided_splits, f)
+
+
+def _get_per_split_dirs(split_file, split):
+    with open(split_file, "r") as f:
+        data = json.load(f)
+
+    return data[split]
+
+
+def _get_oimhs_paths(path, split, download):
+    data_dir = get_oimhs_data(path=path, download=download)
+
+    image_dir = os.path.join(data_dir, "preprocessed", "images")
+    gt_dir = os.path.join(data_dir, "preprocessed", "gt")
+    os.makedirs(image_dir, exist_ok=True)
+    os.makedirs(gt_dir, exist_ok=True)
+
+    split_file = os.path.join(path, "split_file.json")
+    if not os.path.exists(split_file):
+        _create_splits(data_dir, split_file)
+
+    eye_dirs = _get_per_split_dirs(split_file=split_file, split=split)
+
+    image_paths, gt_paths = [], []
+    for eye_dir in tqdm(eye_dirs):
+        eye_id = os.path.split(eye_dir)[-1]
+        all_oct_scan_paths = natsorted(glob(os.path.join(eye_dir, "*.png")))
+        for per_scan_path in all_oct_scan_paths:
+            scan_id = Path(per_scan_path).stem
+
+            image_path = os.path.join(image_dir, f"{eye_id}_{scan_id}.tif")
+            gt_path = os.path.join(gt_dir, f"{eye_id}_{scan_id}.tif")
+
+            image_paths.append(image_path)
+            gt_paths.append(gt_path)
+
+            if os.path.exists(image_path) and os.path.exists(gt_path):
+                continue
+
+            scan = imageio.imread(per_scan_path)
+            image, gt = scan[:, :512, :], scan[:, 512:, :]
+
+            # use an integer dtype, so the semantic ids are stored correctly in the tif
+            instances = np.zeros(image.shape[:2], dtype="uint8")
+            for lmap in LABEL_MAPS:
+                binary_map = (gt == lmap).all(axis=2)
+                instances[binary_map > 0] = LABEL_MAPS[lmap]
+
+            imageio.imwrite(image_path, image, compression="zlib")
+            imageio.imwrite(gt_path, instances, compression="zlib")
+
+    return image_paths, gt_paths
+
+
+def get_oimhs_dataset(
+    path: Union[os.PathLike, str],
+    patch_shape: Tuple[int, int],
+    split: Literal["train", "val", "test"],
+    resize_inputs: bool = False,
+    download: bool = False,
+    **kwargs
+):
+    """Dataset for segmentation of macular hole and retinal regions in OCT scans.
+
+    The dataset is from Ye et al. - https://doi.org/10.1038/s41597-023-02675-1.
+
+    Please cite it if you use this dataset for your publication.
+    """
+    image_paths, gt_paths = _get_oimhs_paths(path=path, split=split, download=download)
+
+    if resize_inputs:
+        resize_kwargs = {"patch_shape": patch_shape, "is_rgb": True}
+        kwargs, patch_shape = util.update_kwargs_for_resize_trafo(
+            kwargs=kwargs, patch_shape=patch_shape, resize_inputs=resize_inputs, resize_kwargs=resize_kwargs
+        )
+
+    dataset = torch_em.default_segmentation_dataset(
+        raw_paths=image_paths,
+        raw_key=None,
+        label_paths=gt_paths,
+        label_key=None,
+        patch_shape=patch_shape,
+        is_seg_dataset=False,
+        **kwargs
+    )
+
+    return dataset
+
+
+def get_oimhs_loader(
+    path: Union[os.PathLike, str],
+    patch_shape: Tuple[int, int],
+    batch_size: int,
+    split: Literal["train", "val", "test"],
+    resize_inputs: bool = False,
+    download: bool = False,
+    **kwargs
+):
+    """Dataloader for segmentation of macular hole and retinal regions in OCT scans.
+    See `get_oimhs_dataset` for details.
+    """
+    ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
+    dataset = get_oimhs_dataset(
+        path=path, patch_shape=patch_shape, split=split, resize_inputs=resize_inputs, download=download, **ds_kwargs
+    )
+    loader = torch_em.get_data_loader(dataset=dataset, batch_size=batch_size, **loader_kwargs)
+    return loader
diff --git a/torch_em/data/datasets/medical/osic_pulmofib.py b/torch_em/data/datasets/medical/osic_pulmofib.py
new file mode 100644
index 00000000..762f025e
--- /dev/null
+++ b/torch_em/data/datasets/medical/osic_pulmofib.py
@@ -0,0 +1,170 @@
+import os
+from glob import glob
+from tqdm import tqdm
+from pathlib import Path
+from natsort import natsorted
+from typing import Union, Tuple
+
+import json
+import numpy as np
+
+import torch_em
+
+from ..
import util
+
+
+ORGAN_IDS = {"heart": 1, "lung": 2, "trachea": 3}
+
+
+def get_osic_pulmofib_data(path, download):
+    os.makedirs(path, exist_ok=True)
+
+    data_dir = os.path.join(path, "data")
+    if os.path.exists(data_dir):
+        return data_dir
+
+    # download the data first
+    zip_path = os.path.join(path, "osic-pulmonary-fibrosis-progression.zip")
+    util.download_source_kaggle(
+        path=path, dataset_name="osic-pulmonary-fibrosis-progression", download=download, competition=True
+    )
+    util.unzip(zip_path=zip_path, dst=data_dir, remove=False)
+
+    # download the ground truth next
+    zip_path = os.path.join(path, "ct-lung-heart-trachea-segmentation.zip")
+    util.download_source_kaggle(
+        path=path, dataset_name="sandorkonya/ct-lung-heart-trachea-segmentation", download=download
+    )
+    util.unzip(zip_path=zip_path, dst=data_dir)
+
+    return data_dir
+
+
+def _get_osic_pulmofib_paths(path, download):
+    import nrrd
+    import nibabel as nib
+    import pydicom as dicom
+
+    data_dir = get_osic_pulmofib_data(path=path, download=download)
+
+    image_dir = os.path.join(data_dir, "preprocessed", "images")
+    gt_dir = os.path.join(data_dir, "preprocessed", "ground_truth")
+
+    os.makedirs(image_dir, exist_ok=True)
+    os.makedirs(gt_dir, exist_ok=True)
+
+    cpath = os.path.join(data_dir, "preprocessed", "confirmer.json")
+    _completed_preproc = os.path.exists(cpath)
+
+    image_paths, gt_paths = [], []
+    uid_paths = natsorted(glob(os.path.join(data_dir, "train", "*")))
+    for uid_path in tqdm(uid_paths):
+        uid = uid_path.split("/")[-1]
+
+        image_path = os.path.join(image_dir, f"{uid}.nii.gz")
+        gt_path = os.path.join(gt_dir, f"{uid}.nii.gz")
+
+        if _completed_preproc:
+            if os.path.exists(image_path) and os.path.exists(gt_path):
+                image_paths.append(image_path)
+                gt_paths.append(gt_path)
+
+            continue
+
+        # creating the volume out of individual dicom slices
+        all_slices = []
+        for slice_path in natsorted(glob(os.path.join(uid_path, "*.dcm"))):
+            per_slice = dicom.dcmread(slice_path)
+            per_slice = per_slice.pixel_array
+            all_slices.append(per_slice)
+        all_slices = np.stack(all_slices).transpose(1, 2, 0)
+
+        # next, combining the semantic organ annotations into one ground-truth volume with specific semantic labels
+        all_gt = np.zeros(all_slices.shape, dtype="uint8")
+        for ann_path in glob(os.path.join(data_dir, "*", "*", f"{uid}_*.nrrd")):
+            ann_organ = Path(ann_path).stem.split("_")[-1]
+            if ann_organ == "noisy":
+                continue
+
+            per_gt, _ = nrrd.read(ann_path)
+            per_gt = per_gt.transpose(1, 0, 2)
+
+            # some organ annotations have mismatching dimensions; we do not consider them for simplicity
+            if per_gt.shape == all_slices.shape:
+                all_gt[per_gt > 0] = ORGAN_IDS[ann_organ]
+
+        # we only save the raw and gt volumes if the volume has any labels (some volumes do not have segmentations)
+        if len(np.unique(all_gt)) > 1:
+            all_gt = np.flip(all_gt, axis=2)
+
+            image_nifti = nib.Nifti2Image(all_slices, np.eye(4))
+            gt_nifti = nib.Nifti2Image(all_gt, np.eye(4))
+
+            nib.save(image_nifti, image_path)
+            nib.save(gt_nifti, gt_path)
+
+            image_paths.append(image_path)
+            gt_paths.append(gt_path)
+
+    if not _completed_preproc:
+        # since we do not have segmentations for all volumes, we store a file which confirms that the dataset was created
+        confirm_msg = "The dataset has been preprocessed. "
+        confirm_msg += f"It has {len(image_paths)} volumes and {len(gt_paths)} respective ground-truths."
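+
+        # print the confirmation and persist it, so that repeated calls can skip the expensive preprocessing above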
+        print(confirm_msg)
+
+        with open(cpath, "w") as f:
+            json.dump(confirm_msg, f)
+
+    return image_paths, gt_paths
+
+
+def get_osic_pulmofib_dataset(
+    path: Union[os.PathLike, str],
+    patch_shape: Tuple[int, ...],
+    resize_inputs: bool = False,
+    download: bool = False,
+    **kwargs
+):
+    """Dataset for segmentation of lung, heart and trachea in CT scans.
+
+    This dataset is from the OSIC Pulmonary Fibrosis Progression Challenge:
+    - https://www.kaggle.com/c/osic-pulmonary-fibrosis-progression/data (dataset source)
+    - https://www.kaggle.com/datasets/sandorkonya/ct-lung-heart-trachea-segmentation (segmentation source)
+    Please cite it if you use this dataset for a publication.
+    """
+    image_paths, gt_paths = _get_osic_pulmofib_paths(path=path, download=download)
+
+    if resize_inputs:
+        resize_kwargs = {"patch_shape": patch_shape, "is_rgb": False}
+        kwargs, patch_shape = util.update_kwargs_for_resize_trafo(
+            kwargs=kwargs, patch_shape=patch_shape, resize_inputs=resize_inputs, resize_kwargs=resize_kwargs
+        )
+
+    dataset = torch_em.default_segmentation_dataset(
+        raw_paths=image_paths,
+        raw_key="data",
+        label_paths=gt_paths,
+        label_key="data",
+        patch_shape=patch_shape,
+        **kwargs
+    )
+
+    return dataset
+
+
+def get_osic_pulmofib_loader(
+    path: Union[os.PathLike, str],
+    patch_shape: Tuple[int, ...],
+    batch_size: int,
+    resize_inputs: bool = False,
+    download: bool = False,
+    **kwargs
+):
+    """Dataloader for segmentation of lung, heart and trachea in CT scans. See `get_osic_pulmofib_dataset` for details.
+    """
+    ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
+    dataset = get_osic_pulmofib_dataset(
+        path=path, patch_shape=patch_shape, resize_inputs=resize_inputs, download=download, **ds_kwargs
+    )
+    loader = torch_em.get_data_loader(dataset=dataset, batch_size=batch_size, **loader_kwargs)
+    return loader
diff --git a/torch_em/data/datasets/medical/panorama.py b/torch_em/data/datasets/medical/panorama.py
new file mode 100644
index 00000000..8016885f
--- /dev/null
+++ b/torch_em/data/datasets/medical/panorama.py
@@ -0,0 +1,182 @@
+"""The PANORAMA dataset contains annotations for the segmentation of PDAC lesions, veins, arteries,
+pancreas parenchyma, pancreatic duct and common bile duct in CT scans.
+
+The dataset is from the PANORAMA challenge: https://panorama.grand-challenge.org/.
+
+NOTE: The latest information on the label legend is located at:
+https://github.com/DIAGNijmegen/panorama_labels#label-legend.
+The label legend is described as follows:
+- background: 0
+- PDAC lesion: 1
+- veins: 2
+- arteries: 3
+- pancreas parenchyma: 4
+- pancreatic duct: 5
+- common bile duct: 6
+
+This dataset is from the article: https://doi.org/10.5281/zenodo.10599559
+Please cite it if you use this dataset in your research.
+"""
+
+import os
+import shutil
+import subprocess
+from glob import glob
+from natsort import natsorted
+from typing import Union, Tuple, Optional, Literal, List
+
+from torch.utils.data import Dataset, DataLoader
+
+import torch_em
+
+from ..
import util
+
+
+URLS = {
+    "batch_1": "https://zenodo.org/records/13715870/files/batch_1.zip",
+    "batch_2": "https://zenodo.org/records/13742336/files/batch_2.zip",
+    "batch_3": "https://zenodo.org/records/11034011/files/batch_3.zip",
+    "batch_4": "https://zenodo.org/records/10999754/files/batch_4.zip",
+}
+
+CHECKSUMS = {
+    "batch_1": "aff39b6347650d6c7457adf7a04bfb0a651ab6ecd33676ff109bdab17bc41cff",
+    "batch_2": "db6353a2c1c565c8bf084bd4fe1512fd6020b7675a1c9ab61b9a13d72a9fe76c",
+    "batch_3": "c1d71b40948edc36f795a7801cc79000082df8d365c48574af50b36516d64cee",
+    "batch_4": "3b5341af79c2cc8b8a9fa3ab7a6cfa8fedf694538a3d6be97c18e5c82be4d9d8",
+}
+
+
+def get_panorama_data(path: Union[os.PathLike, str], download: bool = False):
+    """Download the PANORAMA data.
+
+    Args:
+        path: Filepath to a folder where the data is downloaded for further processing.
+        download: Whether to download the data if it is not present.
+    """
+    os.makedirs(path, exist_ok=True)
+
+    data_path = os.path.join(path, "volumes")
+    label_path = os.path.join(path, "labels")
+    if os.path.exists(data_path) and os.path.exists(label_path):
+        return
+
+    print("PANORAMA is a large dataset. It might take a while to download the volumes and respective labels.")
+
+    # Download the label volumes.
+    subprocess.call(
+        ["git", "clone", "--quiet", "https://github.com/DIAGNijmegen/panorama_labels", label_path]
+    )
+
+    def _move_batch_data_to_root(batch):
+        if batch in ["batch_3", "batch_4"]:
+            batch_dir = os.path.join(data_path, batch)
+
+            for fpath in glob(os.path.join(batch_dir, "*.nii.gz")):
+                shutil.move(src=fpath, dst=data_path)
+
+            if os.path.exists(batch_dir):
+                shutil.rmtree(batch_dir)
+
+    # Download the input volumes.
+    for batch in URLS.keys():
+        zip_path = os.path.join(path, f"{batch}.zip")
+        util.download_source(path=zip_path, url=URLS[batch], download=download, checksum=CHECKSUMS[batch])
+        util.unzip(zip_path=zip_path, dst=data_path)
+        _move_batch_data_to_root(batch)
+
+
+def get_panorama_paths(
+    path: Union[os.PathLike, str],
+    annotation_choice: Optional[Literal["manual", "automatic"]] = None,
+    download: bool = False
+) -> Tuple[List[str], List[str]]:
+    """Get paths to the PANORAMA data.
+
+    Args:
+        path: Filepath to a folder where the downloaded data will be saved.
+        annotation_choice: The source of annotation.
+        download: Whether to download the data if it is not present.
+
+    Returns:
+        List of filepaths for the image data.
+        List of filepaths for the label data.
+    """
+    get_panorama_data(path, download)
+
+    if annotation_choice is None:
+        annotation_choice = "*"
+    label_paths = natsorted(glob(os.path.join(path, "labels", f"{annotation_choice}_labels", "*.nii.gz")))
+    raw_dir = os.path.join(path, "volumes")
+    raw_paths = [
+        os.path.join(raw_dir, os.path.basename(fpath).replace(".nii.gz", "_0000.nii.gz")) for fpath in label_paths
+    ]
+
+    # NOTE: the label "100051_00001.nii.gz" returns the error: 'nibabel.filebasedimages.ImageFileError: Empty file'
+    # We simply do not consider the sample (and corresponding labels) for the dataset.
+    # (filter instead of removing while iterating, so that no entries are skipped)
+    raw_paths = [rpath for rpath in raw_paths if rpath.find("100051_00001") == -1]
+    label_paths = [lpath for lpath in label_paths if lpath.find("100051_00001") == -1]
+
+    assert len(raw_paths) == len(label_paths)
+
+    return raw_paths, label_paths
+
+
+def get_panorama_dataset(
+    path: Union[os.PathLike, str],
+    patch_shape: Tuple[int, ...],
+    annotation_choice: Optional[Literal["manual", "automatic"]] = None,
+    download: bool = False, **kwargs
+) -> Dataset:
+    """Get the PANORAMA dataset for pancreatic lesion (and other structures) segmentation.
+
+    Args:
+        path: Filepath to a folder where the downloaded data will be saved.
+        patch_shape: The patch shape to use for training.
+        annotation_choice: The source of annotation.
+        download: Whether to download the data if it is not present.
+        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.
+
+    Returns:
+        The segmentation dataset.
+    """
+    raw_paths, label_paths = get_panorama_paths(path, annotation_choice, download)
+
+    return torch_em.default_segmentation_dataset(
+        raw_paths=raw_paths,
+        raw_key="data",
+        label_paths=label_paths,
+        label_key="data",
+        is_seg_dataset=True,
+        patch_shape=patch_shape,
+        **kwargs
+    )
+
+
+def get_panorama_loader(
+    path: Union[os.PathLike, str],
+    batch_size: int,
+    patch_shape: Tuple[int, ...],
+    annotation_choice: Optional[Literal["manual", "automatic"]] = None,
+    download: bool = False, **kwargs
+) -> DataLoader:
+    """Get the PANORAMA dataloader for pancreatic lesion (and other structures) segmentation.
+
+    Args:
+        path: Filepath to a folder where the downloaded data will be saved.
+        batch_size: The batch size for training.
+        patch_shape: The patch shape to use for training.
+        annotation_choice: The source of annotation.
+        download: Whether to download the data if it is not present.
+        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader.
+
+    Returns:
+        The DataLoader.
+    """
+    ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
+    dataset = get_panorama_dataset(path, patch_shape, annotation_choice, download, **ds_kwargs)
+    return torch_em.get_data_loader(dataset=dataset, batch_size=batch_size, **loader_kwargs)
diff --git a/torch_em/data/datasets/medical/papila.py b/torch_em/data/datasets/medical/papila.py
new file mode 100644
index 00000000..8a030b8f
--- /dev/null
+++ b/torch_em/data/datasets/medical/papila.py
@@ -0,0 +1,138 @@
+import os
+from glob import glob
+from tqdm import tqdm
+from pathlib import Path
+from typing import Union, Tuple, Literal
+
+import numpy as np
+from skimage import draw
+import imageio.v3 as imageio
+
+import torch_em
+
+from ..
import util + + +URL = "https://figshare.com/ndownloader/files/35013982" +CHECKSUM = "15b053dff496bc8e53eb8a8d0707ef73ba3d56c988eea92b65832c9c82852a7d" + + +def get_papila_data(path, download): + os.makedirs(path, exist_ok=True) + + data_dir = os.path.join(path, "PapilaDB-PAPILA-17f8fa7746adb20275b5b6a0d99dc9dfe3007e9f") + if os.path.exists(data_dir): + return data_dir + + zip_path = os.path.join(path, "papila.zip") + util.download_source(path=zip_path, url=URL, download=download, checksum=CHECKSUM) + util.unzip(zip_path=zip_path, dst=path) + + return data_dir + + +# contour_to_mask() functions taken from https://github.com/matterport/Mask_RCNN +def contour_to_mask(cont, img_shape): + """Return mask given a contour and the shape of image + """ + c = np.loadtxt(cont) + mask = np.zeros(img_shape[:-1], dtype=np.uint8) + rr, cc = draw.polygon(c[:, 1], c[:, 0]) + mask[rr, cc] = 1 + return mask + + +def _get_papila_paths(path, task, expert_choice, download): + data_dir = get_papila_data(path=path, download=download) + + image_paths = sorted(glob(os.path.join(data_dir, "FundusImages", "*.jpg"))) + + gt_dir = os.path.join(data_dir, "ground_truth") + os.makedirs(gt_dir, exist_ok=True) + + patient_ids = [Path(image_path).stem for image_path in image_paths] + + input_shape = (1934, 2576, 3) # shape of the input images + gt_paths = [] + for patient_id in tqdm(patient_ids, desc=f"Converting contours to segmentations for '{expert_choice}'"): + gt_contours = sorted( + glob(os.path.join(data_dir, "ExpertsSegmentations", "Contours", f"{patient_id}_{task}_{expert_choice}.txt")) + ) + + for gt_contour in gt_contours: + tmp_task = Path(gt_contour).stem.split("_")[1] + gt_path = os.path.join(gt_dir, f"{patient_id}_{tmp_task}_{expert_choice}.tif") + gt_paths.append(gt_path) + if os.path.exists(gt_path): + continue + + semantic_labels = contour_to_mask(cont=gt_contour, img_shape=input_shape) + imageio.imwrite(gt_path, semantic_labels) + + return image_paths, gt_paths + + +def get_papila_dataset( + path: Union[os.PathLike, str], + patch_shape: Tuple[int, int], + task: Literal["cup", "disc"] = "disc", + expert_choice: Literal["exp1", "exp2"] = "exp1", + resize_inputs: bool = False, + download: bool = False, + **kwargs +): + """Dataset for segmentation of optic cup and optic disc in fundus images. + + The database is located at https://figshare.com/articles/dataset/PAPILA/14798004/2 + + The dataset is from Kovalyk et al. - https://doi.org/10.1038/s41597-022-01388-1. + Please cite it if you use this dataset for a publication. + """ + assert expert_choice in ["exp1", "exp2"], f"'{expert_choice}' is not a valid expert choice." + assert task in ["cup", "disc"], f"'{task}' is not a valid task." 
+
+    image_paths, gt_paths = _get_papila_paths(path=path, task=task, expert_choice=expert_choice, download=download)
+
+    if resize_inputs:
+        resize_kwargs = {"patch_shape": patch_shape, "is_rgb": True}
+        kwargs, patch_shape = util.update_kwargs_for_resize_trafo(
+            kwargs=kwargs, patch_shape=patch_shape, resize_inputs=resize_inputs, resize_kwargs=resize_kwargs
+        )
+
+    dataset = torch_em.default_segmentation_dataset(
+        raw_paths=image_paths,
+        raw_key=None,
+        label_paths=gt_paths,
+        label_key=None,
+        patch_shape=patch_shape,
+        is_seg_dataset=False,
+        **kwargs
+    )
+
+    return dataset
+
+
+def get_papila_loader(
+    path: Union[os.PathLike, str],
+    patch_shape: Tuple[int, int],
+    batch_size: int,
+    task: Literal["cup", "disc"] = "disc",
+    expert_choice: Literal["exp1", "exp2"] = "exp1",
+    resize_inputs: bool = False,
+    download: bool = False,
+    **kwargs
+):
+    """Dataloader for segmentation of optic cup and optic disc in fundus images. See `get_papila_dataset` for details.
+    """
+    ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs)
+    dataset = get_papila_dataset(
+        path=path,
+        patch_shape=patch_shape,
+        task=task,
+        expert_choice=expert_choice,
+        resize_inputs=resize_inputs,
+        download=download,
+        **ds_kwargs
+    )
+    loader = torch_em.get_data_loader(dataset=dataset, batch_size=batch_size, **loader_kwargs)
+    return loader
diff --git a/torch_em/data/datasets/medical/pengwin.py b/torch_em/data/datasets/medical/pengwin.py
new file mode 100644
index 00000000..c940af63
--- /dev/null
+++ b/torch_em/data/datasets/medical/pengwin.py
@@ -0,0 +1,162 @@
+"""The PENGWIN dataset contains annotations for pelvic bone fractures and
+fragments in CT and X-Ray images.
+
+This dataset is from the challenge: https://pengwin.grand-challenge.org/pengwin/.
+This dataset is related to the publication: https://doi.org/10.1007/978-3-031-43996-4_30.
+Please cite them if you use this dataset for your publication.
+"""
+
+import os
+from glob import glob
+from natsort import natsorted
+from typing import Union, Tuple, Literal, List
+
+from torch.utils.data import Dataset, DataLoader
+
+import torch_em
+
+from .. import util
+
+
+URLS = {
+    "CT": [
+        "https://zenodo.org/records/10927452/files/PENGWIN_CT_train_images_part1.zip",  # inputs part 1
+        "https://zenodo.org/records/10927452/files/PENGWIN_CT_train_images_part2.zip",  # inputs part 2
+        "https://zenodo.org/records/10927452/files/PENGWIN_CT_train_labels.zip",  # labels
+    ],
+    "X-Ray": ["https://zenodo.org/records/10913196/files/train.zip"]
+}
+
+CHECKSUMS = {
+    "CT": [
+        "e2e9f99798960607ffced1fbdeee75a626c41bf859eaf4125029a38fac6b7609",  # inputs part 1
+        "19f3cdc5edd1daf9324c70f8ba683eed054f6ed8f2b1cc59dbd80724f8f0bbb2",  # inputs part 2
+        "c4d3857e02d3ee5d0df6c8c918dd3cf5a7c9419135f1ec089b78215f37c6665c"  # labels
+    ],
+    "X-Ray": ["48d107979eb929a3c61da4e75566306a066408954cf132907bda570f2a7de725"]
+}
+
+TARGET_DIRS = {
+    "CT": ["CT/images", "CT/images", "CT/labels"],
+    "X-Ray": ["X-Ray"]
+}
+
+MODALITIES = ["CT", "X-Ray"]
+
+
+def get_pengwin_data(
+    path: Union[os.PathLike, str], modality: Literal["CT", "X-Ray"], download: bool = False
+) -> str:
+    """Download the PENGWIN dataset.
+
+    Args:
+        path: Filepath to a folder where the data is downloaded for further processing.
+        modality: The choice of modality for inputs.
+        download: Whether to download the data if it is not present.
+
+    Returns:
+        Filepath where the data is downloaded.
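+
+    Example (a sketch; the path is a placeholder): calling
+    `get_pengwin_data("./pengwin", modality="CT", download=True)` fetches both image parts
+    and the labels for the CT task.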
+    """
+    if not (isinstance(modality, str) and modality in MODALITIES):
+        raise ValueError(f"'{modality}' is not a valid modality. Please choose from {MODALITIES}.")
+
+    data_dir = os.path.join(path, "data")
+    if os.path.exists(os.path.join(data_dir, modality)):
+        return data_dir
+
+    os.makedirs(path, exist_ok=True)
+
+    for url, checksum, dst_dir in zip(URLS[modality], CHECKSUMS[modality], TARGET_DIRS[modality]):
+        zip_path = os.path.join(path, os.path.split(url)[-1])
+        util.download_source(path=zip_path, url=url, download=download, checksum=checksum)
+        util.unzip(zip_path=zip_path, dst=os.path.join(data_dir, dst_dir))
+
+    return data_dir
+
+
+def get_pengwin_paths(
+    path: Union[os.PathLike, str], modality: Literal["CT", "X-Ray"], download: bool = False
+) -> Tuple[List[str], List[str]]:
+    """Get paths to the PENGWIN data.
+
+    Args:
+        path: Filepath to a folder where the data is downloaded for further processing.
+        modality: The choice of modality for inputs.
+        download: Whether to download the data if it is not present.
+
+    Returns:
+        List of filepaths for the image data.
+        List of filepaths for the label data.
+    """
+    data_dir = get_pengwin_data(path=path, modality=modality, download=download)
+
+    if modality == "CT":
+        image_paths = natsorted(glob(os.path.join(data_dir, modality, "images", "*.mha")))
+        gt_paths = natsorted(glob(os.path.join(data_dir, modality, "labels", "*.mha")))
+    else:  # X-Ray
+        base_dir = os.path.join(data_dir, modality, "train")
+        image_paths = natsorted(glob(os.path.join(base_dir, "input", "images", "*.tif")))
+        gt_paths = natsorted(glob(os.path.join(base_dir, "output", "images", "*.tif")))
+
+    return image_paths, gt_paths
+
+
+def get_pengwin_dataset(
+    path: Union[os.PathLike, str],
+    patch_shape: Tuple[int, ...],
+    modality: Literal["CT", "X-Ray"],
+    resize_inputs: bool = False,
+    download: bool = False,
+    **kwargs
+) -> Dataset:
+    """Get the PENGWIN dataset for pelvic fracture segmentation.
+
+    Args:
+        path: Filepath to a folder where the data is downloaded for further processing.
+        patch_shape: The patch shape to use for training.
+        modality: The choice of modality for inputs.
+        resize_inputs: Whether to resize inputs to the desired patch shape.
+        download: Whether to download the data if it is not present.
+        kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`.
+
+    Returns:
+        The segmentation dataset.
+    """
+    image_paths, gt_paths = get_pengwin_paths(path=path, modality=modality, download=download)
+
+    if resize_inputs:
+        resize_kwargs = {"patch_shape": patch_shape, "is_rgb": False}
+        kwargs, patch_shape = util.update_kwargs_for_resize_trafo(
+            kwargs=kwargs, patch_shape=patch_shape, resize_inputs=resize_inputs, resize_kwargs=resize_kwargs
+        )
+
+    return torch_em.default_segmentation_dataset(
+        raw_paths=image_paths, raw_key=None, label_paths=gt_paths, label_key=None, patch_shape=patch_shape, **kwargs
+    )
+
+
+def get_pengwin_loader(
+    path: Union[os.PathLike, str],
+    patch_shape: Tuple[int, ...],
+    batch_size: int,
+    modality: Literal["CT", "X-Ray"],
+    resize_inputs: bool = False,
+    download: bool = False,
+    **kwargs
+) -> DataLoader:
+    """Get the PENGWIN dataloader for pelvic fracture segmentation.
+
+    Args:
+        path: Filepath to a folder where the data is downloaded for further processing.
+        patch_shape: The patch shape to use for training.
+        batch_size: The batch size for training.
+        modality: The choice of modality for inputs.
+        resize_inputs: Whether to resize inputs to the desired patch shape.
+        download: Whether to download the data if it is not present.
+ kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader. + + Returns: + The DataLoader. + """ + ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs) + dataset = get_pengwin_dataset(path, patch_shape, modality, resize_inputs, download, **ds_kwargs) + return torch_em.get_data_loader(dataset=dataset, batch_size=batch_size, **loader_kwargs) diff --git a/torch_em/data/datasets/medical/piccolo.py b/torch_em/data/datasets/medical/piccolo.py new file mode 100644 index 00000000..0b8738f5 --- /dev/null +++ b/torch_em/data/datasets/medical/piccolo.py @@ -0,0 +1,105 @@ +import os +from glob import glob +from natsort import natsorted +from typing import Union, Tuple, Literal + +import torch_em + +from .. import util + + +def get_piccolo_data(path, download): + """The database is located at: + - https://www.biobancovasco.bioef.eus/en/Sample-and-data-e-catalog/Databases/PD178-PICCOLO-EN1.html + + Follow the instructions below to get access to the dataset. + - Visit the attached website above + - Fill up the access request form: https://labur.eus/EzJUN + - Send an email to Basque Biobank at solicitudes.biobancovasco@bioef.eus, requesting access to the dataset. + - The team will request you to follow-up with some formalities. + - Then, you will gain access to the ".rar" file. + - Finally, provide the path where the rar file is stored, and you should be good to go. + """ + if download: + raise NotImplementedError( + "Automatic download is not possible for this dataset. See 'get_piccolo_data' for details." + ) + + data_dir = os.path.join(path, r"piccolo dataset-release0.1") + if os.path.exists(data_dir): + return data_dir + + rar_file = os.path.join(path, r"piccolo dataset_widefield-release0.1.rar") + if not os.path.exists(rar_file): + raise FileNotFoundError( + "You must download the PICCOLO dataset from the Basque Biobank, see 'get_piccolo_data' for details." + ) + + util.unzip_rarfile(rar_path=rar_file, dst=path, remove=False) + return data_dir + + +def _get_piccolo_paths(path, split, download): + data_dir = get_piccolo_data(path=path, download=download) + + split_dir = os.path.join(data_dir, split) + + image_paths = natsorted(glob(os.path.join(split_dir, "polyps", "*"))) + gt_paths = natsorted(glob(os.path.join(split_dir, "masks", "*"))) + + return image_paths, gt_paths + + +def get_piccolo_dataset( + path: Union[os.PathLike, str], + patch_shape: Tuple[int, int], + split: Literal["train", "validation", "test"], + resize_inputs: bool = False, + download: bool = False, + **kwargs +): + """Dataset for polyp segmentation in narrow band imaging colonoscopy images. + + This dataset is from Sánchez-Peralta et al. - https://doi.org/10.3390/app10238501. + To access the dataset, see `get_piccolo_data` for details. + + Please cite it if you use this data in a publication. 
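+
+    A minimal usage sketch (assumes the extracted data lives under the placeholder path):
+    ```python
+    dataset = get_piccolo_dataset("./piccolo", patch_shape=(512, 512), split="train")
+    ```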
+ """ + image_paths, gt_paths = _get_piccolo_paths(path=path, split=split, download=download) + + if resize_inputs: + resize_kwargs = {"patch_shape": patch_shape, "is_rgb": True} + kwargs, patch_shape = util.update_kwargs_for_resize_trafo( + kwargs=kwargs, patch_shape=patch_shape, resize_inputs=resize_inputs, resize_kwargs=resize_kwargs + ) + + dataset = torch_em.default_segmentation_dataset( + raw_paths=image_paths, + raw_key=None, + label_paths=gt_paths, + label_key=None, + patch_shape=patch_shape, + is_seg_dataset=False, + **kwargs + ) + return dataset + + +def get_piccolo_loader( + path: Union[os.PathLike, str], + patch_shape: Tuple[int, int], + batch_size: int, + split: Literal["train", "validation", "test"], + resize_inputs: bool = False, + download: bool = False, + **kwargs +): + """Dataloader for polyp segmentation in narrow band imaging colonoscopy images. + See `get_piccolo_dataset` for details. + """ + ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs) + dataset = get_piccolo_dataset( + path=path, patch_shape=patch_shape, split=split, resize_inputs=resize_inputs, download=download, **ds_kwargs + ) + loader = torch_em.get_data_loader(dataset=dataset, batch_size=batch_size, **loader_kwargs) + return loader diff --git a/torch_em/data/datasets/medical/plethora.py b/torch_em/data/datasets/medical/plethora.py new file mode 100644 index 00000000..ea8301e1 --- /dev/null +++ b/torch_em/data/datasets/medical/plethora.py @@ -0,0 +1,183 @@ +import os +from glob import glob +from tqdm import tqdm +from pathlib import Path +from natsort import natsorted +from typing import Union, Tuple +from urllib.parse import urljoin + +import numpy as np +import pandas as pd + +import torch_em + +from .. import util + + +BASE_URL = "https://wiki.cancerimagingarchive.net/download/attachments/68551327/" + + +URL = { + "image": urljoin(BASE_URL, "NSCLC-Radiomics-OriginalCTs.tcia"), + "gt": { + "thoracic": urljoin( + BASE_URL, "PleThora%20Thoracic_Cavities%20June%202020.zip?version=1&modificationDate=1593202695428&api=v2" + ), + "pleural_effusion": urljoin( + BASE_URL, "PleThora%20Effusions%20June%202020.zip?version=1&modificationDate=1593202778373&api=v2" + ) + } +} + + +CHECKSUMS = { + "image": None, + "gt": { + "thoracic": "6dfcb60e46c7b0ccf240bc5d13acb1c45c8d2f4922223f7b2fbd5e37acff2be0", + "pleural_effusion": "5dd07c327fb5723c5bbb48f2a02d7f365513d3ad136811fbe4def330ef2d7f6a" + } +} + + +ZIPFILES = { + "thoracic": "thoracic.zip", + "pleural_effusion": "pleural_effusion.zip" +} + + +def get_plethora_data(path, task, download): + os.makedirs(path, exist_ok=True) + + image_dir = os.path.join(path, "data", "images") + gt_dir = os.path.join(path, "data", "gt", "Thoracic_Cavities" if task == "thoracic" else "Effusions") + csv_path = os.path.join(path, "plethora_images") + if os.path.exists(image_dir) and os.path.exists(gt_dir): + return image_dir, gt_dir, Path(csv_path).with_suffix(".csv") + + # let's download dicom files from the tcia manifest + tcia_path = os.path.join(path, "NSCLC-Radiomics-OriginalCTs.tcia") + util.download_source_tcia(path=tcia_path, url=URL["image"], dst=image_dir, csv_filename=csv_path, download=download) + + # let's download the segmentations from zipfiles + zip_path = os.path.join(path, ZIPFILES[task]) + util.download_source( + path=zip_path, url=URL["gt"][task], download=download, checksum=CHECKSUMS["gt"][task] + ) + util.unzip(zip_path=zip_path, dst=os.path.join(path, "data", "gt")) + + return image_dir, gt_dir, 
Path(csv_path).with_suffix(".csv") + + +def _assort_plethora_inputs(image_dir, gt_dir, task, csv_path): + import nibabel as nib + import pydicom as dicom + + df = pd.read_csv(csv_path) + + task_gt_dir = os.path.join(gt_dir, ) + + os.makedirs(os.path.join(image_dir, "preprocessed"), exist_ok=True) + os.makedirs(os.path.join(task_gt_dir, "preprocessed"), exist_ok=True) + + # let's get all the series uid of the volumes downloaded and spot their allocated subject id + all_series_uid_dirs = glob(os.path.join(image_dir, "1.3*")) + image_paths, gt_paths = [], [] + for series_uid_dir in tqdm(all_series_uid_dirs): + series_uid = os.path.split(series_uid_dir)[-1] + subject_id = pd.Series.to_string(df.loc[df["Series UID"] == series_uid]["Subject ID"])[-9:] + + try: + gt_path = glob(os.path.join(task_gt_dir, subject_id, "*.nii.gz"))[0] + except IndexError: + # - some patients do not have "Thoracic_Cavities" segmentation + print(f"The ground truth is missing for subject '{subject_id}'") + continue + + assert os.path.exists(gt_path) + + vol_path = os.path.join(image_dir, "preprocessed", f"{subject_id}.nii.gz") + neu_gt_path = os.path.join(task_gt_dir, "preprocessed", os.path.split(gt_path)[-1]) + + image_paths.append(vol_path) + gt_paths.append(neu_gt_path) + if os.path.exists(vol_path) and os.path.exists(neu_gt_path): + continue + + # the individual slices for the inputs need to be merged into one volume. + if not os.path.exists(vol_path): + all_dcm_slices = natsorted(glob(os.path.join(series_uid_dir, "*.dcm"))) + all_slices = [] + for dcm_path in all_dcm_slices: + dcmfile = dicom.dcmread(dcm_path) + img = dcmfile.pixel_array + all_slices.append(img) + + volume = np.stack(all_slices) + volume = volume.transpose(1, 2, 0) + nii_vol = nib.Nifti1Image(volume, np.eye(4)) + nii_vol.header.get_xyzt_units() + nii_vol.to_filename(vol_path) + + # the ground truth needs to be aligned as the inputs, let's take care of that. 
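+            # (the masks are stored in a different axis order than the dicom volume assembled above,
+            # hence the transposes and flips below reorient them to match the inputs)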
+ gt = nib.load(gt_path) + gt = gt.get_fdata() + gt = gt.transpose(2, 1, 0) # aligning w.r.t the inputs + gt = np.flip(gt, axis=(0, 1)) + + gt = gt.transpose(1, 2, 0) + gt_nii_vol = nib.Nifti1Image(gt, np.eye(4)) + gt_nii_vol.header.get_xyzt_units() + gt_nii_vol.to_filename(neu_gt_path) + + return image_paths, gt_paths + + +def _get_plethora_paths(path, task, download): + image_dir, gt_dir, csv_path = get_plethora_data(path=path, task=task, download=download) + image_paths, gt_paths = _assort_plethora_inputs(image_dir=image_dir, gt_dir=gt_dir, task=task, csv_path=csv_path) + return image_paths, gt_paths + + +def get_plethora_dataset( + path: Union[os.PathLike, str], + task: str, + patch_shape: Tuple[int, ...], + resize_inputs: bool = False, + download: bool = False, + **kwargs +): + image_paths, gt_paths = _get_plethora_paths(path=path, task=task, download=download) + + if resize_inputs: + resize_kwargs = {"patch_shape": patch_shape, "is_rgb": False} + kwargs, patch_shape = util.update_kwargs_for_resize_trafo( + kwargs=kwargs, patch_shape=patch_shape, resize_inputs=resize_inputs, resize_kwargs=resize_kwargs + ) + + dataset = torch_em.default_segmentation_dataset( + raw_paths=image_paths, + raw_key="data", + label_paths=gt_paths, + label_key="data", + patch_shape=patch_shape, + **kwargs + ) + + return dataset + + +def get_plethora_loader( + path: Union[os.PathLike, str], + task: str, + patch_shape: Tuple[int, ...], + batch_size: int, + resize_inputs: bool = False, + download: bool = False, + **kwargs +): + ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs) + dataset = get_plethora_dataset( + path=path, task=task, patch_shape=patch_shape, resize_inputs=resize_inputs, download=download, **ds_kwargs + ) + loader = torch_em.get_data_loader(dataset=dataset, batch_size=batch_size, **loader_kwargs) + return loader diff --git a/torch_em/data/datasets/medical/sa_med2d.py b/torch_em/data/datasets/medical/sa_med2d.py new file mode 100644 index 00000000..7e2a6eda --- /dev/null +++ b/torch_em/data/datasets/medical/sa_med2d.py @@ -0,0 +1,436 @@ +import os +import random +from tqdm import tqdm +from pathlib import Path +from typing import Union, Tuple, Optional + +import json +import numpy as np +import imageio.v3 as imageio +from skimage.segmentation import relabel_sequential +from sklearn.model_selection import train_test_split + +import torch_em + +from .. 
import util
+from ..light_microscopy.neurips_cell_seg import to_rgb
+
+
+DATASET_NAMES = [
+    "ACDC",  # cardiac structures in MRI
+    "AMOS2022",  # multi-organ in CT
+    "ATM2022",  # pulmonary airway in CT
+    "AbdomenCT1K",  # abdominal organ in CT
+    "ASC18",  # left atrium in LGE-MRI
+    "COSMOS2022",  # carotid vessel wall in MRI
+    "BTCV",  # organs in CT
+    "BTCV_Cervix",  # cervical organs in CT
+    "BraTS2013",  # brain tumour in MRI
+    "BraTS2015",  # brain tumour in MRI
+    "BraTS2018",  # brain tumour in MRI
+    "BraTS2019",  # brain tumour in MRI
+    "BraTS2020",  # brain tumour in MRI
+    "BraTS2021",  # brain tumour in MRI
+    "Brain_PTM",  # white matter tracts in brain MRI
+    "CAD_PE",  # pulmonary embolism in CTPA
+    "CHAOS_Task_4",  # liver, kidney and spleen in T1W-MR
+    "CMRxMotions",  # cardiac structures in CMR
+    "COVID19CTscans",  # lung and covid infection in CT
+    "COVID-19-20",  # covid infection in CT
+    "covid_19_ct_cxr",  # lung in CXR
+    "crass",  # clavicle in CXR
+    "CTPelvic1k",  # pelvic bones in CT
+    "CTSpine1K_Full",  # spinal vertebrae in CT
+    "cvc_clinicdb",  # polyp in colonoscopy
+    "Chest_Image_Pneum",  # pneumonia in CXR
+    "cranium",  # cranial segmentation in CT
+    "CrossMoDA21",  # vestibular schwannoma and cochlea segmentation in T1-CE and T2-HR MRI
+    "CrossMoDA22",  # vestibular schwannoma and cochlea segmentation in T1-CE and T2-HR MRI
+    "EMIDEC",  # cardiac structures in MRI
+    "endovis15",  # polyp in endoscopy
+    "FLARE21",  # abdominal organs in CT
+    "FLARE22",  # abdominal organs in CT
+    "fusc2021",  # skin lesion in dermoscopy
+    "hvsmr_2016",  # blood pool and ventricular myocardium in CMR
+    "Heart_Seg_MRI",  # heart in MRI
+    "ichallenge_adam_task2",  # optic disc in fundus images
+    "PALM19",  # optic disc in fundus images
+    "gamma",  # optic disk, optic cup and ring in fundus images
+    "gamma3",  # optic disk, optic cup and ring in fundus images
+    "ISLES_SPES",  # ischemic stroke lesion in brain MRI
+    "ISLES_SISS",  # ischemic stroke lesion in brain MRI
+    "ISLES2016",  # ischemic stroke lesion in brain MRI
+    "ISLES2017",  # ischemic stroke lesion in brain MRI
+    "ISLES2018",  # ischemic stroke in brain CT
+    "ISLES2022",  # ischemic stroke in brain MRI
+    "Instance22",  # intracranial hemorrhage in nc-ct
+    "KiTS",  # kidney and kidney tumor in CT
+    "KiTS2021",  # kidney and kidney tumor in CT
+    "LNDb",  # lung nodules in thoracic CT
+    "LUNA16",  # lung and trachea in thoracic CT
+    "LongitudinalMultipleSclerosisLesionSegmentation",  # MS lesion in FLAIR-MRI
+    "mnms2",  # cardiac structures in MRI
+    "MMWHS",  # whole heart in CT
+    "BrainTumour",  # brain tumor in MRI
+    "MSD_Heart",  # heart in MRI
+    "MSD_Liver",  # liver in CT
+    "MSD_Prostate",  # prostate in ADC-MRI
+    "MSD_Lung",  # lung tumour in CT
+    "MSD_Pancreas",  # pancreas in CT
+    "MSD_HepaticVessel",  # hepatic vessel in CT
+    "MSD_Spleen",  # spleen in CT
+    "MSD_Colon",  # colon in CT
+    "CT_ORG",  # multiple organ in CT
+    "picai_baseline",  # prostate cancer in MRI
+    "picai_semi",  # prostate cancer in MRI
+    "Promise09",  # prostate in MRI
+    "PROMISE12",  # prostate in MRI
+    "Parse22",  # pulmonary artery in CT
+    "chest_x_ray_images_with_pneumothorax_masks",  # pneumothorax in CXR
+    "Prostate_MRI_Segmentation_Dataset",  # prostate in MRI
+    "Pulmonary_Chest_X-Ray_Abnormalities_seg",  # lung in CXR
+    "QUBIQ2020",  # kidney in CT
+    "StructSeg2019_subtask1",  # OAR in H&N CT
+    "StructSeg2019_subtask2",  # OAR in chest CT
+    "Totalsegmentator_dataset",  # organ in CT
+    "ultrasound_nerve_segmentation",  # nerve in US
+    "VESSEL2012",  # lung in CT
+    "VerSe20",  # vertebrae in CT
+    "VerSe19",  #
vertebrae in CT
+    "WORD",  # abdominal organs in CT
+    "autoPET",  # lesions in PET and CT
+    "braimMRI",  # brain lesions in MRI
+    "breast_ultrasound_images_dataset",  # breast cancer in US
+    "kvasircapsule_seg",  # polyp in endoscopy
+    "sz_cxr",  # lungs in CXR
+    "EndoVis_2017_RIS",  # instruments in endoscopy
+    "kvasir_seg",  # polyp in endoscopy
+    "isic2018_task1",  # skin lesions in dermoscopy
+    "isic2017_task1",  # skin lesions in dermoscopy
+    "isic2016_task1",  # skin lesions in dermoscopy
+]
+
+MODALITY_NAMES = [
+    # CT modalities
+    'ct_00', 'ct_cbf', 'ct_cbv', 'ct_mtt', 'ct_tmax',
+    # RGB-image modalities
+    'dermoscopy_00', 'endoscopy_00', 'fundus_photography',
+    # MRI modalities
+    'mr_00', 'mr_adc', 'mr_cbf', 'mr_cbv', 'mr_cmr', 'mr_dwi',
+    'mr_flair', 'mr_hbv', 'mr_lge', 'mr_mprage', 'mr_mtt',
+    'mr_pd', 'mr_rcbf', 'mr_rcbv', 'mr_t1', 'mr_t1c', 'mr_t1ce',
+    'mr_t1gd', 'mr_t1w', 'mr_t2', 'mr_t2w', 'mr_tmax', 'mr_ttp',
+    # mono-channel modalities
+    'pet_00', 'ultrasound_00', 'x_ray'
+]
+
+
+# datasets under 1000 samples
+SMALL_DATASETS = [
+    "crass", "covid_19_ct_cxr", "cvc_clinicdb", "cranium", "CrossMoDA21", "EMIDEC",
+    "endovis15", "fusc2021", "Heart_Seg_MRI", "ichallenge_adam_task2", "gamma", "gamma3",
+    "Instance22", "LNDb", "MSD_Heart", "MSD_Prostate", "MSD_Spleen", "MSD_Colon",
+    "picai_baseline", "picai_semi", "Promise09", "PROMISE12", "Pulmonary_Chest_X-Ray_Abnormalities_seg",
+    "QUBIQ2020", "breast_ultrasound_images_dataset", "kvasircapsule_seg", "sz_cxr", "kvasir_seg"
+]
+
+
+def get_sa_med2d_data(path, download):
+    """This function describes the download functionality and ensures your data has been downloaded in the expected format.
+
+    The dataset is located at https://huggingface.co/datasets/OpenGVLab/SA-Med2D-20M.
+
+    There are two ways of downloading the dataset:
+    1. wget (Recommended):
+        - There are 10 `z.*` files and 1 `.zip` file, which need to be downloaded together.
+        - Go to `Files` -> download each file individually using `wget `. These are the links:
+            - https://huggingface.co/datasets/OpenGVLab/SA-Med2D-20M/resolve/main/raw/SA-Med2D-16M.z01
+            - https://huggingface.co/datasets/OpenGVLab/SA-Med2D-20M/resolve/main/raw/SA-Med2D-16M.z02
+            - https://huggingface.co/datasets/OpenGVLab/SA-Med2D-20M/resolve/main/raw/SA-Med2D-16M.z03
+            - https://huggingface.co/datasets/OpenGVLab/SA-Med2D-20M/resolve/main/raw/SA-Med2D-16M.z04
+            - https://huggingface.co/datasets/OpenGVLab/SA-Med2D-20M/resolve/main/raw/SA-Med2D-16M.z05
+            - https://huggingface.co/datasets/OpenGVLab/SA-Med2D-20M/resolve/main/raw/SA-Med2D-16M.z06
+            - https://huggingface.co/datasets/OpenGVLab/SA-Med2D-20M/resolve/main/raw/SA-Med2D-16M.z07
+            - https://huggingface.co/datasets/OpenGVLab/SA-Med2D-20M/resolve/main/raw/SA-Med2D-16M.z08
+            - https://huggingface.co/datasets/OpenGVLab/SA-Med2D-20M/resolve/main/raw/SA-Med2D-16M.z09
+            - https://huggingface.co/datasets/OpenGVLab/SA-Med2D-20M/resolve/main/raw/SA-Med2D-16M.z10
+            - https://huggingface.co/datasets/OpenGVLab/SA-Med2D-20M/resolve/main/raw/SA-Med2D-16M.zip
+
+    2. Using Git Large File Storage (lfs):
+        - `git lfs install` (Make sure you have git-lfs installed (https://git-lfs.com))
+        - `git clone https://huggingface.co/datasets/OpenGVLab/SA-Med2D-20M`
+        - This step takes several hours, make sure you have a stable internet connection and sufficient space.
+
+    Once you have downloaded the archives, you need to unzip the split zip files:
+    - For Windows: decompress SA-Med2D-16M.zip to automatically extract the other volumes together.
+ - For Linux: + - `zip SA-Med2D-16M.zip SA-Med2D-16M.z0* SA-Med2D-16M.z10 -s=0 --out {full}.zip` + - NOTE: this merges the split archives into a single zip; make sure you have ~1.5TB of free space. + - `unzip {full}.zip` + - NOTE: there are >4M images paired with >19M ground-truth masks; unzipping takes a lot of inodes and time. + """ + if download: + print("Download is not supported, as the data is huge and takes quite a while to download and extract.") + + data_dir = os.path.join(path, "SAMed2Dv1") + + # first, ensure that the data has been unzipped into the expected data directory + msg = "The data directory is not found. " + msg += "Please ensure that you provide the path to the parent directory where the unzip operation took place. " + msg += "For example: `unzip <zipfile> -d /path/to/dir/`. Hence, the argument 'path' expects '/path/to/dir/'." + assert os.path.exists(data_dir), msg + + # next, check that the expected json files are present + json_file = "SAMed2D_v1.json" + assert os.path.exists(os.path.join(data_dir, json_file)), f"The json file '{json_file}' is missing." + + json_file = "SAMed2D_v1_class_mapping_id.json" + assert os.path.exists(os.path.join(data_dir, json_file)), f"The json file '{json_file}' is missing." + + print("Looks like the dataset is ready to use.") + + return data_dir + + +def _assort_sa_med2d_data(data_dir): + with open(os.path.join(data_dir, "SAMed2D_v1.json")) as f: + data = json.load(f) + + image_files = list(data.keys()) + + gt_instances_dir = os.path.join(data_dir, "preprocessed_instances") + os.makedirs(gt_instances_dir, exist_ok=True) + + skipped_files = [] + for ifile in tqdm(image_files): + image_path = os.path.join(data_dir, ifile) + image_id = Path(image_path).stem + + gt_path = os.path.join(gt_instances_dir, f"{image_id}.tif") + if os.path.exists(gt_path): + continue + + # split the filename into its components + splits = image_id.split("--") + dataset = splits[1] + + # HACK: (SKIP) there are some known images which are pretty weird (binary brain masks as inputs) + if splits[2].find("brain-growth") != -1: + skipped_files.append(ifile) + continue + + # let's get the shape of the image + image = imageio.imread(image_path) + shape = image.shape if image.ndim == 2 else image.shape[:-1] + + # HACK: (SKIP) there are weird images which appear to be whole brain binary masks + if dataset == "Brain_PTM": + if len(np.unique(image)) == 2: # easy check for binary values in the input image + skipped_files.append(ifile) + continue + + # let's create an empty array and merge all segmentations into one + instances = np.zeros(shape, dtype="uint8") + for idx, gfile in enumerate(sorted(data[ifile]), start=1): + # HACK: (SKIP) we remove the segmentation of the entire ventricular cavity in ACDC + if dataset == "ACDC": + if gfile.find("0003_000") != -1 and len(data[ifile]) > 1: # to avoid whole ventricular rois + continue + + per_gt = imageio.imread(os.path.join(data_dir, gfile)) + + # HACK: need to see if we can resize these inputs + if per_gt.shape != shape: + print("Skipping these images with mismatching ground-truth shapes.") + continue + + # HACK: (UPDATE) optic disc is mapped as 0, and background as 1 + if dataset == "ichallenge_adam_task2": + per_gt = (per_gt == 0).astype("uint8") # simply reversing the binary optic disc masks + + instances[per_gt > 0] = idx + + instances = relabel_sequential(instances)[0] + imageio.imwrite(gt_path, instances, compression="zlib") + + return skipped_files + + +def _create_splits_per_dataset(data_dir, json_file, skipped_files, 
val_fraction=0.1): + with open(os.path.join(data_dir, "SAMed2D_v1.json")) as f: + data = json.load(f) + + image_files = list(data.keys()) + + # now, let's group the files per dataset and make splits for each dataset + data_dict = {} + for image_file in image_files: + if image_file in skipped_files: + print("Skipping this file:", image_file) + continue + + _image_file = os.path.split(image_file)[-1] + splits = _image_file.split("--") + dataset = splits[1] + + if dataset in data_dict: + data_dict[dataset].append(_image_file) + else: + data_dict[dataset] = [_image_file] + + # next, let's make a train-val split for each dataset and write the splits to a json file + train_dict, val_dict = {}, {} + for dataset, dfiles in data_dict.items(): + tr_split, val_split = train_test_split(dfiles, test_size=val_fraction) + train_dict[dataset] = tr_split + val_dict[dataset] = val_split + + fdict = {"train": train_dict, "val": val_dict} + with open(json_file, "w") as f: + json.dump(fdict, f) + + +def _get_split_wise_paths(data_dir, json_file, split, exclude_dataset, exclude_modality, n_fraction_per_dataset): + with open(json_file, "r") as f: + data = json.load(f) + + if exclude_dataset is not None and not isinstance(exclude_dataset, list): + exclude_dataset = [exclude_dataset] + + if exclude_modality is not None and not isinstance(exclude_modality, list): + exclude_modality = [exclude_modality] + + image_files = data[split] + image_paths, gt_paths = [], [] + for dfiles in image_files.values(): + splits = dfiles[0].split("--") + modality = splits[0] + dataset = splits[1] + + if exclude_dataset is not None and dataset in exclude_dataset: + continue + + if exclude_modality is not None and modality in exclude_modality: + continue + + if n_fraction_per_dataset is not None and dataset not in SMALL_DATASETS: + dfiles = random.sample(dfiles, k=int(n_fraction_per_dataset * len(dfiles))) + + per_dataset_ipaths = [os.path.join(data_dir, "images", fname) for fname in dfiles] + per_dataset_gpaths = [ + os.path.join(data_dir, "preprocessed_instances", f"{Path(fname).stem}.tif") for fname in dfiles + ] + + image_paths.extend(per_dataset_ipaths) + gt_paths.extend(per_dataset_gpaths) + + return image_paths, gt_paths + + +def _get_sa_med2d_paths(path, split, exclude_dataset, exclude_modality, n_fraction_per_dataset, download): + data_dir = get_sa_med2d_data(path=path, download=download) + + json_file = os.path.join(data_dir, "preprocessed_inputs.json") + if not os.path.exists(json_file): + skipped_files = _assort_sa_med2d_data(data_dir=data_dir) + _create_splits_per_dataset(data_dir=data_dir, json_file=json_file, skipped_files=skipped_files) + + image_paths, gt_paths = _get_split_wise_paths( + data_dir=data_dir, + json_file=json_file, + split=split, + exclude_dataset=exclude_dataset, + exclude_modality=exclude_modality, + n_fraction_per_dataset=n_fraction_per_dataset + ) + + return image_paths, gt_paths + + +def get_sa_med2d_dataset( + path: Union[os.PathLike, str], + patch_shape: Tuple[int, int], + split: str, + resize_inputs: bool = False, + exclude_dataset: Optional[Union[str, list]] = None, + exclude_modality: Optional[Union[str, list]] = None, + n_fraction_per_dataset: Optional[float] = None, + download: bool = False, + **kwargs +): + """Dataset for segmentation of various organs and structures in multiple medical imaging modalities. + + You should download the dataset yourself. See `get_sa_med2d_data` for details. + + The dataset is from Ye et al. - https://doi.org/10.48550/arXiv.2311.11969. 
+ The dataset is curated in alignment with Cheng et al. - https://doi.org/10.48550/arXiv.2308.16184. + + Please cite it if you use it in a publication. + """ + image_paths, gt_paths = _get_sa_med2d_paths( + path=path, + split=split, + exclude_dataset=exclude_dataset, + exclude_modality=exclude_modality, + n_fraction_per_dataset=n_fraction_per_dataset, + download=download, + ) + + if resize_inputs: + resize_kwargs = {"patch_shape": patch_shape, "is_rgb": True} + kwargs, patch_shape = util.update_kwargs_for_resize_trafo( + kwargs=kwargs, + patch_shape=patch_shape, + resize_inputs=resize_inputs, + resize_kwargs=resize_kwargs, + ensure_rgb=to_rgb, + ) + + print("Creating the dataset for the SA-Med2D-20M dataset. This takes a bit of time.") + + dataset = torch_em.default_segmentation_dataset( + raw_paths=image_paths, + raw_key=None, + label_paths=gt_paths, + label_key=None, + patch_shape=patch_shape, + ndim=2, + with_channels=True, + is_seg_dataset=False, + verify_paths=False, + **kwargs + ) + + return dataset + + +def get_sa_med2d_loader( + path: Union[os.PathLike, str], + patch_shape: Tuple[int, int], + batch_size: int, + split: str, + resize_inputs: bool = False, + exclude_dataset: Optional[Union[str, list]] = None, + exclude_modality: Optional[Union[str, list]] = None, + n_fraction_per_dataset: Optional[float] = None, + download: bool = False, + **kwargs +): + """Dataloader for segmentation of various organs and structures in multiple medical imaging modalities. + See `get_sa_med2d_dataset` for details. + """ + ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs) + dataset = get_sa_med2d_dataset( + path=path, + patch_shape=patch_shape, + split=split, + resize_inputs=resize_inputs, + exclude_dataset=exclude_dataset, + exclude_modality=exclude_modality, + n_fraction_per_dataset=n_fraction_per_dataset, + download=download, + **ds_kwargs + ) + print("Creating the dataloader for the SA-Med2D-20M dataset. This takes a bit of time.") + loader = torch_em.get_data_loader(dataset=dataset, batch_size=batch_size, **loader_kwargs) + return loader diff --git a/torch_em/data/datasets/medical/sega.py b/torch_em/data/datasets/medical/sega.py new file mode 100644 index 00000000..3c82763b --- /dev/null +++ b/torch_em/data/datasets/medical/sega.py @@ -0,0 +1,158 @@ +import os +from glob import glob +from pathlib import Path +from natsort import natsorted +from typing import Union, Tuple, Optional, Literal + +import torch_em + +from .. 
import util + + +URL = { + "kits": "https://figshare.com/ndownloader/files/30950821", + "rider": "https://figshare.com/ndownloader/files/30950914", + "dongyang": "https://figshare.com/ndownloader/files/30950971" +} + +CHECKSUMS = { + "kits": "6c9c2ea31e5998348acf1c4f6683ae07041bd6c8caf309dd049adc7f222de26e", + "rider": "7244038a6a4f70ae70b9288a2ce874d32128181de2177c63a7612d9ab3c4f5fa", + "dongyang": "0187e90038cba0564e6304ef0182969ff57a31b42c5969d2b9188a27219da541" +} + +ZIPFILES = { + "kits": "KiTS.zip", + "rider": "Rider.zip", + "dongyang": "Dongyang.zip" +} + + +def get_sega_data(path, data_choice, download): + os.makedirs(path, exist_ok=True) + + data_choice = data_choice.lower() + + zip_fid = ZIPFILES[data_choice] + + data_dir = os.path.join(path, Path(zip_fid).stem) + if os.path.exists(data_dir): + return data_dir + + zip_path = os.path.join(path, zip_fid) + util.download_source( + path=zip_path, url=URL[data_choice], download=download, checksum=CHECKSUMS[data_choice], + ) + util.unzip(zip_path=zip_path, dst=path) + + return data_dir + + +def _get_sega_paths(path, data_choice, download): + if data_choice is None: + data_choices = URL.keys() + else: + # wrap a single choice into a list; a list of choices is used as-is + data_choices = [data_choice] if isinstance(data_choice, str) else data_choice + + data_dirs = [get_sega_data(path=path, data_choice=data_choice, download=download) for data_choice in data_choices] + + image_paths, gt_paths = [], [] + for data_dir in data_dirs: + all_volumes_paths = glob(os.path.join(data_dir, "*", "*.nrrd")) + for volume_path in all_volumes_paths: + if volume_path.endswith(".seg.nrrd"): + gt_paths.append(volume_path) + else: + image_paths.append(volume_path) + + # now let's convert the volumes to nifti format + fimage_dir = os.path.join(path, "data", "images") + fgt_dir = os.path.join(path, "data", "labels") + + os.makedirs(fimage_dir, exist_ok=True) + os.makedirs(fgt_dir, exist_ok=True) + + fimage_paths, fgt_paths = [], [] + for image_path, gt_path in zip(natsorted(image_paths), natsorted(gt_paths)): + fimage_path = os.path.join(fimage_dir, f"{Path(image_path).stem}.nii.gz") + fgt_path = os.path.join(fgt_dir, f"{Path(image_path).stem}.nii.gz") + + fimage_paths.append(fimage_path) + fgt_paths.append(fgt_path) + + if os.path.exists(fimage_path) and os.path.exists(fgt_path): + continue + + import nrrd + import numpy as np + import nibabel as nib + + image = nrrd.read(image_path)[0] + gt = nrrd.read(gt_path)[0] + + image_nifti = nib.Nifti2Image(image, np.eye(4)) + gt_nifti = nib.Nifti2Image(gt, np.eye(4)) + + nib.save(image_nifti, fimage_path) + nib.save(gt_nifti, fgt_path) + + return natsorted(fimage_paths), natsorted(fgt_paths) + + +def get_sega_dataset( + path: Union[os.PathLike, str], + patch_shape: Tuple[int, ...], + data_choice: Optional[Literal["KiTS", "Rider", "Dongyang"]] = None, + resize_inputs: bool = False, + download: bool = False, + **kwargs +): + """Dataset for segmentation of the aorta in computed tomography angiography (CTA) scans. + + This dataset is from Pepe et al. - https://doi.org/10.1007/978-3-031-53241-2 + Please cite it if you use this dataset for a publication. 
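+ + A minimal usage sketch (the path and patch shape below are hypothetical values for illustration, not defaults): + + ds = get_sega_dataset("/data/sega", patch_shape=(32, 256, 256), data_choice="KiTS", download=True)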
+ """ + image_paths, gt_paths = _get_sega_paths(path=path, data_choice=data_choice, download=download) + + if resize_inputs: + resize_kwargs = {"patch_shape": patch_shape, "is_rgb": False} + kwargs, patch_shape = util.update_kwargs_for_resize_trafo( + kwargs=kwargs, patch_shape=patch_shape, resize_inputs=resize_inputs, resize_kwargs=resize_kwargs, + ) + + dataset = torch_em.default_segmentation_dataset( + raw_paths=image_paths, + raw_key="data", + label_paths=gt_paths, + label_key="data", + patch_shape=patch_shape, + is_seg_dataset=True, + **kwargs + ) + + return dataset + + +def get_sega_loader( + path: Union[os.PathLike, str], + patch_shape: Tuple[int, ...], + batch_size: int, + data_choice: Optional[Literal["KiTS", "Rider", "Dongyang"]] = None, + resize_inputs: bool = False, + download: bool = False, + **kwargs +): + """Dataloader for segmentation of aorta in CTA scans. See `get_sega_dataset` for details. + """ + ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs) + dataset = get_sega_dataset( + path=path, + patch_shape=patch_shape, + data_choice=data_choice, + resize_inputs=resize_inputs, + download=download, + **ds_kwargs + ) + loader = torch_em.get_data_loader(dataset=dataset, batch_size=batch_size, **loader_kwargs) + return loader diff --git a/torch_em/data/datasets/medical/segthy.py b/torch_em/data/datasets/medical/segthy.py new file mode 100644 index 00000000..1064ca81 --- /dev/null +++ b/torch_em/data/datasets/medical/segthy.py @@ -0,0 +1,165 @@ +"""The SegThy dataset contains annotations for thyroid segmentation in MRI and US scans, +and additional annotations for vein and artery segmentation in MRI. + +NOTE: The label legends are described as following: +1: For thyroid-only labels: (at 'MRI_thyroid' or 'US_thyroid') +- background: 0 and thyroid: 1 +2: For thyroid, jugular veins and carotid arteries (at 'MRI_thyroid+jugular+carotid_label') +- background: 0, thyroid: 1, jugular vein: 3 and 5, carotid artery: 2 and 4. + +The dataset is located at https://www.cs.cit.tum.de/camp/publications/segthy-dataset/. + +This dataset is from the publication https://doi.org/10.1371/journal.pone.0268550. +Please cite it if you use this dataset in your research. +""" + +import os +from glob import glob +from natsort import natsorted +from typing import Union, Tuple, Literal, List + +import numpy as np + +from torch.utils.data import Dataset, DataLoader + +import torch_em + +from .. import util + + +URLS = { + "MRI": "https://www.campar.in.tum.de/public_datasets/2022_plosone_eilers/MRI_data.zip", + "US": "https://www.campar.in.tum.de/public_datasets/2022_plosone_eilers/US_data.zip", +} + +CHECKSUMS = { + "MRI": "e9d0599b305dfe36795c45282a8495d3bfb4a872851c221b321d59ed0b11e7eb", + "US": "52c59ef4db08adfa0e6ea562c7fe747c612f2064e01f907a78b170b02fb459bb", +} + + +def get_segthy_data(path: Union[os.PathLike, str], source: Literal['MRI', 'US'], download: bool = False): + """Download the SegThy dataset. + + Args: + path: Filepath to a folder where the data is downloaded for further processing. + download: Whether to download the data if it is not present. + """ + data_dir = os.path.join(path, f"{source}_volunteer_dataset") + if os.path.exists(data_dir): + return + + os.makedirs(path, exist_ok=True) + + zip_path = os.path.join(path, f"{source}_data.zip") + util.download_source(path=zip_path, url=URLS[source], download=download, checksum=CHECKSUMS[source]) + util.unzip(zip_path=zip_path, dst=path) + + # NOTE: There is one label with an empty channel. 
+ if source == "MRI": + lpath = os.path.join(data_dir, "MRI_thyroid_label", "005_MRI_thyroid_label.nii.gz") + + import nibabel as nib + # Load the label volume and remove the empty channel. + label = nib.load(lpath).get_fdata() + label = label[..., 0] + + # Store the updated label. + label_nifti = nib.Nifti2Image(label, np.eye(4)) + nib.save(label_nifti, lpath) + + +def get_segthy_paths( + path: Union[os.PathLike, str], + source: Literal['MRI', 'US'], + region: Literal['thyroid', 'thyroid_and_vessels'] = "thyroid", + download: bool = False +) -> Tuple[List[str], List[str]]: + """Get paths to the SegThy data. + + Args: + path: Filepath to a folder where the data is downloaded for further processing. + source: The source of dataset. Either 'MRI' or 'US. + region: The labeled regions for the corresponding volumes. Either 'thyroid' or 'thyroid_and_vessels'. + download: Whether to download the data if it is not present. + + Returns: + List of filepaths for the image data. + List of filepaths for the label data. + """ + get_segthy_data(path, source, download) + + if source == "MRI": + rdir, ldir = "MRI", "MRI_thyroid_label" if region == "thyroid" else "MRI_thyroid+jugular+carotid_label" + fext = "*.nii.gz" + else: # US data + assert region != "thyroid_and_vessels", "US source does not have labels for both thyroid and vessels." + rdir, ldir = "ground_truth_data/US", "ground_truth_data/US_thyroid_label" + fext = "*.nii" + + raw_paths = natsorted(glob(os.path.join(path, f"{source}_volunteer_dataset", rdir, fext))) + label_paths = natsorted(glob(os.path.join(path, f"{source}_volunteer_dataset", ldir, fext))) + + return raw_paths, label_paths + + +def get_segthy_dataset( + path: Union[os.PathLike, str], + patch_shape: Tuple[int, ...], + source: Literal['MRI', 'US'], + region: Literal['thyroid', 'thyroid_and_vessels'] = "thyroid", + download: bool = False, + **kwargs +) -> Dataset: + """Get the SegThy dataset for thyroid (and vessel) segmentation. + + Args: + path: Filepath to a folder where the data is downloaded for further processing. + patch_shape: The patch shape to use for training. + source: The source of dataset. Either 'MRI' or 'US. + region: The labeled regions for the corresponding volumes. Either 'thyroid' or 'thyroid_and_vessels'. + download: Whether to download the data if it is not present. + kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`. + + Returns: + The segmentation dataset. + """ + raw_paths, label_paths = get_segthy_paths(path, source, region, download) + + return torch_em.default_segmentation_dataset( + raw_paths=raw_paths, + raw_key="data", + label_paths=label_paths, + label_key="data", + patch_shape=patch_shape, + is_seg_dataset=True, + **kwargs + ) + + +def get_segthy_loader( + path: Union[os.PathLike, str], + batch_size: int, + patch_shape: Tuple[int, ...], + source: Literal['MRI', 'US'], + region: Literal['thyroid', 'thyroid_and_vessels'] = "thyroid", + download: bool = False, + **kwargs +) -> DataLoader: + """Get the SegThy dataloader for thyroid (and vessel) segmentation. + + Args: + path: Filepath to a folder where the data is downloaded for further processing. + batch_size: The batch size for training. + patch_shape: The patch shape to use for training. + source: The source of dataset. Either 'MRI' or 'US. + region: The labeled regions for the corresponding volumes. Either 'thyroid' or 'thyroid_and_vessels'. + download: Whether to download the data if it is not present. 
+ kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader. + + Returns: + The DataLoader. + """ + ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs) + dataset = get_segthy_dataset(path, patch_shape, source, region, download, **ds_kwargs) + return torch_em.get_data_loader(dataset=dataset, batch_size=batch_size, **loader_kwargs) diff --git a/torch_em/data/datasets/medical/siim_acr.py b/torch_em/data/datasets/medical/siim_acr.py new file mode 100644 index 00000000..7326c7e1 --- /dev/null +++ b/torch_em/data/datasets/medical/siim_acr.py @@ -0,0 +1,96 @@ +import os +from glob import glob +from typing import Union, Tuple, Literal + +import torch_em + +from .. import util + + +KAGGLE_DATASET_NAME = "vbookshelf/pneumothorax-chest-xray-images-and-masks" +CHECKSUM = "1ade68d31adb996c531bb686fb9d02fe11876ddf6f25594ab725e18c69d81538" + + +def get_siim_acr_data(path, download): + os.makedirs(path, exist_ok=True) + + data_dir = os.path.join(path, "siim-acr-pneumothorax") + if os.path.exists(data_dir): + return data_dir + + util.download_source_kaggle(path=path, dataset_name=KAGGLE_DATASET_NAME, download=download) + + zip_path = os.path.join(path, "pneumothorax-chest-xray-images-and-masks.zip") + util._check_checksum(path=zip_path, checksum=CHECKSUM) + util.unzip(zip_path=zip_path, dst=path) + + return data_dir + + +def _get_siim_acr_paths(path, split, download): + data_dir = get_siim_acr_data(path=path, download=download) + + assert split in ["train", "test"], f"'{split}' is not a valid split." + + image_paths = sorted(glob(os.path.join(data_dir, "png_images", f"*_{split}_*.png"))) + gt_paths = sorted(glob(os.path.join(data_dir, "png_masks", f"*_{split}_*.png"))) + + return image_paths, gt_paths + + +def get_siim_acr_dataset( + path: Union[os.PathLike, str], + split: Literal["train", "test"], + patch_shape: Tuple[int, int], + download: bool = False, + resize_inputs: bool = False, + **kwargs +): + """Dataset for pneumothorax segmentation in CXR. + + The database is located at https://www.kaggle.com/datasets/vbookshelf/pneumothorax-chest-xray-images-and-masks/data + + This dataset is from the "SIIM-ACR Pneumothorax Segmentation" competition: + https://kaggle.com/competitions/siim-acr-pneumothorax-segmentation + + Please cite it if you use this dataset for a publication. + """ + image_paths, gt_paths = _get_siim_acr_paths(path=path, split=split, download=download) + + if resize_inputs: + resize_kwargs = {"patch_shape": patch_shape, "is_rgb": False} + kwargs, patch_shape = util.update_kwargs_for_resize_trafo( + kwargs=kwargs, patch_shape=patch_shape, resize_inputs=resize_inputs, resize_kwargs=resize_kwargs + ) + + dataset = torch_em.default_segmentation_dataset( + raw_paths=image_paths, + raw_key=None, + label_paths=gt_paths, + label_key=None, + patch_shape=patch_shape, + is_seg_dataset=False, + **kwargs + ) + dataset.max_sampling_attempts = 5000 + + return dataset + + +def get_siim_acr_loader( + path: Union[os.PathLike, str], + split: Literal["train", "test"], + patch_shape: Tuple[int, int], + batch_size: int, + download: bool = False, + resize_inputs: bool = False, + **kwargs +): + """Dataloader for pneumothorax segmentation in CXR. See `get_siim_acr_dataset` for details. 
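+ + A minimal usage sketch (hypothetical path and patch shape; download=True fetches the data via the Kaggle API, which assumes that Kaggle credentials are configured): + + loader = get_siim_acr_loader("/data/siim_acr", split="train", patch_shape=(512, 512), batch_size=4, download=True)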
+ """ + ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs) + dataset = get_siim_acr_dataset( + path=path, split=split, patch_shape=patch_shape, download=download, resize_inputs=resize_inputs, **ds_kwargs + ) + loader = torch_em.get_data_loader(dataset=dataset, batch_size=batch_size, **loader_kwargs) + return loader diff --git a/torch_em/data/datasets/medical/spider.py b/torch_em/data/datasets/medical/spider.py new file mode 100644 index 00000000..0dddc2dd --- /dev/null +++ b/torch_em/data/datasets/medical/spider.py @@ -0,0 +1,79 @@ +import os +from glob import glob +from natsort import natsorted + +import torch_em + +from .. import util + + +URL = { + "images": "https://zenodo.org/records/10159290/files/images.zip?download=1", + "masks": "https://zenodo.org/records/10159290/files/masks.zip?download=1" +} + +CHECKSUMS = { + "images": "a54cba2905284ff6cc9999f1dd0e4d871c8487187db7cd4b068484eac2f50f17", + "masks": "13a6e25a8c0d74f507e16ebb2edafc277ceeaf2598474f1fed24fdf59cb7f18f" +} + + +def get_spider_data(path, download): + os.makedirs(path, exist_ok=True) + + data_dir = os.path.join(path, "data") + if os.path.exists(data_dir): + return data_dir + + zip_path = os.path.join(path, "images.zip") + util.download_source(path=zip_path, url=URL["images"], download=download, checksum=CHECKSUMS["images"]) + util.unzip(zip_path=zip_path, dst=data_dir) + + zip_path = os.path.join(path, "masks.zip") + util.download_source(path=zip_path, url=URL["masks"], download=download, checksum=CHECKSUMS["masks"]) + util.unzip(zip_path=zip_path, dst=data_dir) + + return data_dir + + +def _get_spider_paths(path, download): + data_dir = get_spider_data(path, download) + + image_paths = natsorted(glob(os.path.join(data_dir, "images", "*.mha"))) + gt_paths = natsorted(glob(os.path.join(data_dir, "masks", "*.mha"))) + + return image_paths, gt_paths + + +def get_spider_dataset(path, patch_shape, download=False, **kwargs): + """Dataset for segmentation of vertebrae, intervertebral discs and spinal canal in T1 and T2 MRI series. + + https://zenodo.org/records/10159290 + https://www.nature.com/articles/s41597-024-03090-w + + Please cite it if you use this data in a publication. + """ + # TODO: expose the choice to choose specific MRI modality, for now this works for our interests. + image_paths, gt_paths = _get_spider_paths(path, download) + + dataset = torch_em.default_segmentation_dataset( + raw_paths=image_paths, + raw_key=None, + label_paths=gt_paths, + label_key=None, + is_seg_dataset=True, + patch_shape=patch_shape, + **kwargs + ) + + return dataset + + +def get_spider_loader(path, patch_shape, batch_size, download=False, **kwargs): + """Dataloader for segmentation of vertebrae, intervertebral discs and spinal canal in T1 and T2 MRI series. + See `get_spider_dataset` for details. + """ + ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs) + dataset = get_spider_dataset(path=path, patch_shape=patch_shape, download=download, **ds_kwargs) + loader = torch_em.get_data_loader(dataset=dataset, batch_size=batch_size, **loader_kwargs) + return loader diff --git a/torch_em/data/datasets/medical/toothfairy.py b/torch_em/data/datasets/medical/toothfairy.py new file mode 100644 index 00000000..987647aa --- /dev/null +++ b/torch_em/data/datasets/medical/toothfairy.py @@ -0,0 +1,194 @@ +"""The ToothSeg data contains annotations for mandibular canal (v1) and multiple structures (v2) +segmentation in CBCT scans. 
+ +NOTE: The dataset is located at https://ditto.ing.unimore.it/ +To download the dataset, please follow these steps: +- Choose either v1 (https://ditto.ing.unimore.it/toothfairy) or v2 (https://ditto.ing.unimore.it/toothfairy2). +- Visit the website, scroll down to the 'Download' section, which expects you to sign up. +- After signing up, use your credentials to log in to the dataset home page. +- Click on the blue icon stating: 'Download Dataset' to download the zipped files to the desired path. + +The relevant links for the dataset are: +- ToothFairy Challenge: https://toothfairy.grand-challenge.org/ +- ToothFairy2 Challenge: https://toothfairy2.grand-challenge.org/ +- Publication 1: https://doi.org/10.1109/ACCESS.2022.3144840 +- Publication 2: https://doi.org/10.1109/CVPR52688.2022.02046 + +Please cite them if you use this dataset for your research. +""" + +import os +from glob import glob +from tqdm import tqdm +from natsort import natsorted +from typing import Union, Tuple, Literal, List + +import numpy as np + +from torch.utils.data import Dataset, DataLoader + +import torch_em + +from .. import util + + +def get_toothfairy_data( + path: Union[os.PathLike, str], version: Literal["v1", "v2"] = "v2", download: bool = False +) -> str: + """Obtain the ToothFairy datasets. + + Args: + path: Filepath to a folder where the data is downloaded for further processing. + version: The version of the dataset. Either v1 (ToothFairy) or v2 (ToothFairy2). + download: Whether to download the data if it is not present. + + Returns: + Filepath to the already downloaded dataset. + """ + data_dir = os.path.join(path, "ToothFairy_Dataset/Dataset" if version == "v1" else "Dataset112_ToothFairy2") + if os.path.exists(data_dir): + return data_dir + + if download: + msg = "Download is set to True, but 'torch_em' cannot download this dataset. " + msg += "See `get_toothfairy_data` for details."
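+ # (The raise below is intentional: the portal requires a login, so the zipfile + # has to be fetched manually and placed at 'path' as described in the module docstring.)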
+ raise NotImplementedError(msg) + + if version == "v1": + zip_path = os.path.join(path, "ToothFairy_Dataset.zip") + elif version == "v2": + zip_path = os.path.join(path, "ToothFairy2_Dataset.zip") + else: + raise ValueError(f"'{version}' is not a valid version.") + + if not os.path.exists(zip_path): + raise FileNotFoundError(f"The downloaded ToothFairy zipfile is expected to be placed at '{path}'.") + + util.unzip(zip_path=zip_path, dst=path, remove=False) + + return data_dir + + +def _preprocess_toothfairy_inputs(path, data_dir): + import nibabel as nib + + images_dir = os.path.join(path, "data", "images") + gt_dir = os.path.join(path, "data", "dense_labels") + if os.path.exists(images_dir) and os.path.exists(gt_dir): + return natsorted(glob(os.path.join(images_dir, "*.nii.gz"))), natsorted(glob(os.path.join(gt_dir, "*.nii.gz"))) + + os.makedirs(images_dir, exist_ok=True) + os.makedirs(gt_dir, exist_ok=True) + + image_paths, gt_paths = [], [] + for patient_dir in tqdm(glob(os.path.join(data_dir, "P*"))): + dense_anns_path = os.path.join(patient_dir, "gt_alpha.npy") + if not os.path.exists(dense_anns_path): + continue + + image_path = os.path.join(patient_dir, "data.npy") + image, gt = np.load(image_path), np.load(dense_anns_path) + image_nifti, gt_nifti = nib.Nifti2Image(image, np.eye(4)), nib.Nifti2Image(gt, np.eye(4)) + + patient_id = os.path.split(patient_dir)[-1] + trg_image_path = os.path.join(images_dir, f"{patient_id}.nii.gz") + trg_gt_path = os.path.join(gt_dir, f"{patient_id}.nii.gz") + + nib.save(image_nifti, trg_image_path) + nib.save(gt_nifti, trg_gt_path) + + image_paths.append(trg_image_path) + gt_paths.append(trg_gt_path) + + return image_paths, gt_paths + + +def get_toothfairy_paths( + path: Union[os.PathLike, str], version: Literal["v1", "v2"] = "v2", download: bool = False +) -> Tuple[List[str], List[str]]: + """Get paths to the ToothFairy data. + + Args: + path: Filepath to a folder where the data is downloaded for further processing. + version: The version of the dataset. Either v1 (ToothFairy) or v2 (ToothFairy2). + download: Whether to download the data if it is not present. + + Returns: + List of filepaths for the image data. + List of filepaths for the label data. + """ + data_dir = get_toothfairy_data(path, version, download) + + if version == "v1": + image_paths, gt_paths = _preprocess_toothfairy_inputs(path, data_dir) + else: + image_paths = natsorted(glob(os.path.join(data_dir, "imagesTr", "*.mha"))) + gt_paths = natsorted(glob(os.path.join(data_dir, "labelsTr", "*.mha"))) + + return image_paths, gt_paths + + +def get_toothfairy_dataset( + path: Union[os.PathLike, str], + patch_shape: Tuple[int, ...], + version: Literal["v1", "v2"] = "v2", + resize_inputs: bool = False, + download: bool = False, + **kwargs +) -> Dataset: + """Get the ToothFairy dataset for canal and teeth segmentation. + + Args: + path: Filepath to a folder where the data is downloaded for further processing. + patch_shape: The patch shape to use for training. + version: The version of the dataset. Either v1 (ToothFairy) or v2 (ToothFairy2). + resize_inputs: Whether to resize inputs to the desired patch shape. + download: Whether to download the data if it is not present. + kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset`. + + Returns: + The segmentation dataset. 
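+ + A minimal usage sketch (hypothetical path and patch shape; the zipfile must already be placed at the given path as described in the module docstring): + + ds = get_toothfairy_dataset("/data/toothfairy", patch_shape=(32, 256, 256), version="v2")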
+ """ + image_paths, gt_paths = get_toothfairy_paths(path, version, download) + + if resize_inputs: + resize_kwargs = {"patch_shape": patch_shape, "is_rgb": False} + kwargs, patch_shape = util.update_kwargs_for_resize_trafo( + kwargs=kwargs, patch_shape=patch_shape, resize_inputs=resize_inputs, resize_kwargs=resize_kwargs + ) + + return torch_em.default_segmentation_dataset( + raw_paths=image_paths, + raw_key="data" if version == "v1" else None, + label_paths=gt_paths, + label_key="data" if version == "v1" else None, + is_seg_dataset=True, + patch_shape=patch_shape, + **kwargs + ) + + +def get_toothfairy_loader( + path: Union[os.PathLike, str], + batch_size: int, + patch_shape: Tuple[int, ...], + version: Literal["v1", "v2"] = "v2", + resize_inputs: bool = False, + download: bool = False, + **kwargs +) -> DataLoader: + """Get the ToothFairy dataloader for canal and teeth segmentation. + + Args: + path: Filepath to a folder where the data is downloaded for further processing. + batch_size: The batch size for training. + patch_shape: The patch shape to use for training. + resize_inputs: Whether to resize inputs to the desired patch shape. + download: Whether to download the data if it is not present. + kwargs: Additional keyword arguments for `torch_em.default_segmentation_dataset` or for the PyTorch DataLoader. + + Returns: + The DataLoader. + """ + ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs) + dataset = get_toothfairy_dataset(path, patch_shape, version, resize_inputs, download, **ds_kwargs) + return torch_em.get_data_loader(dataset=dataset, batch_size=batch_size, **loader_kwargs) diff --git a/torch_em/data/datasets/medical/uwaterloo_skin.py b/torch_em/data/datasets/medical/uwaterloo_skin.py new file mode 100644 index 00000000..12d56021 --- /dev/null +++ b/torch_em/data/datasets/medical/uwaterloo_skin.py @@ -0,0 +1,118 @@ +import os +import shutil +from glob import glob +from typing import Tuple, Union +from urllib.parse import urljoin +from urllib3.exceptions import ProtocolError + +import torch_em + +from .. import util + + +BASE_URL = "https://uwaterloo.ca/vision-image-processing-lab/sites/ca.vision-image-processing-lab/files/uploads/files/" + + +ZIPFILES = { + "set1": "skin_image_data_set-1.zip", # patients with melanoma + "set2": "skin_image_data_set-2.zip" # patients without melanoma +} + +CHECKSUMS = { + "set1": "1788cd3eb7a4744012aad9a154e514fc5b82b9f3b19e31cc1b6ded5fc6bed297", + "set2": "108a818baf20b36ef4544ebda10a8075dad99e335f0535c9533bb14cb02b5c53" +} + + +def get_uwaterloo_skin_data(path, chosen_set, download): + os.makedirs(path, exist_ok=True) + + assert chosen_set in ZIPFILES.keys(), f"'{chosen_set}' is not a valid set." + + data_dir = os.path.join(path, f"{chosen_set}_Data") + if os.path.exists(data_dir): + return data_dir + + zip_path = os.path.join(path, ZIPFILES[chosen_set]) + url = urljoin(BASE_URL, ZIPFILES[chosen_set]) + + try: + util.download_source(path=zip_path, url=url, download=download, checksum=CHECKSUMS[chosen_set]) + except ProtocolError: # the 'uwaterloo.ca' quite randomly times out of connections, pretty weird. + msg = "The server seems to be unreachable at the moment. " + msg += f"We recommend downloading the data manually, from '{url}' at '{path}'. 
" + print(msg) + quit() + + util.unzip(zip_path=zip_path, dst=path) + + setnum = chosen_set[-1] + tmp_dir = os.path.join(path, fr"Skin Image Data Set-{setnum}") + shutil.move(src=tmp_dir, dst=data_dir) + + return data_dir + + +def _get_uwaterloo_skin_paths(path, download): + data_dir = get_uwaterloo_skin_data(path=path, chosen_set="set1", download=download) + + image_paths = sorted(glob(os.path.join(data_dir, "skin_data", "melanoma", "*", "*_orig.jpg"))) + gt_paths = sorted(glob(os.path.join(data_dir, "skin_data", "melanoma", "*", "*_contour.png"))) + + data_dir = get_uwaterloo_skin_data(path=path, chosen_set="set2", download=download) + + image_paths.extend(sorted(glob(os.path.join(data_dir, "skin_data", "notmelanoma", "*", "*_orig.jpg")))) + gt_paths.extend(sorted(glob(os.path.join(data_dir, "skin_data", "notmelanoma", "*", "*_contour.png")))) + + return image_paths, gt_paths + + +def get_uwaterloo_skin_dataset( + path: Union[os.PathLike, str], + patch_shape: Tuple[int, int], + resize_inputs: bool = False, + download: bool = False, + **kwargs +): + """Dataset for skin lesion segmentation in dermoscopy images. + + The database is located at https://uwaterloo.ca/vision-image-processing-lab/research-demos/skin-cancer-detection. + Please cite it if you use this dataset for a publication. + """ + image_paths, gt_paths = _get_uwaterloo_skin_paths(path=path, download=download) + + if resize_inputs: + resize_kwargs = {"patch_shape": patch_shape, "is_rgb": True} + kwargs, patch_shape = util.update_kwargs_for_resize_trafo( + kwargs=kwargs, patch_shape=patch_shape, resize_inputs=resize_inputs, resize_kwargs=resize_kwargs + ) + + dataset = torch_em.default_segmentation_dataset( + raw_paths=image_paths, + raw_key=None, + label_paths=gt_paths, + label_key=None, + is_seg_dataset=False, + patch_shape=patch_shape, + **kwargs + ) + + return dataset + + +def get_uwaterloo_skin_loader( + path: Union[os.PathLike, str], + patch_shape: Tuple[int, int], + batch_size: int, + resize_inputs: bool = False, + download: bool = False, + **kwargs +): + """Dataset for skin lesion segmentation in dermoscopy images. See `get_uwaterloo_skin_dataset` for details. 
+ """ + ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs) + dataset = get_uwaterloo_skin_dataset( + path=path, patch_shape=patch_shape, resize_inputs=resize_inputs, download=download, **ds_kwargs + ) + loader = torch_em.get_data_loader(dataset=dataset, batch_size=batch_size, **loader_kwargs) + return loader diff --git a/torch_em/data/datasets/monuseg.py b/torch_em/data/datasets/monuseg.py deleted file mode 100644 index 2f2556d3..00000000 --- a/torch_em/data/datasets/monuseg.py +++ /dev/null @@ -1,136 +0,0 @@ -import os -import shutil -from tqdm import tqdm -from glob import glob -from pathlib import Path -from typing import List, Optional - -import imageio.v3 as imageio - -import torch_em -from torch_em.data.datasets import util - - -URL = { - "train": "https://drive.google.com/uc?export=download&id=1ZgqFJomqQGNnsx7w7QBzQQMVA16lbVCA", - "test": "https://drive.google.com/uc?export=download&id=1NKkSQ5T0ZNQ8aUhh0a8Dt2YKYCQXIViw" -} - -CHECKSUM = { - "train": "25d3d3185bb2970b397cafa72eb664c9b4d24294aee382e7e3df9885affce742", - "test": "13e522387ae8b1bcc0530e13ff9c7b4d91ec74959ef6f6e57747368d7ee6f88a" -} - -# here's the description: https://drive.google.com/file/d/1xYyQ31CHFRnvTCTuuHdconlJCMk2SK7Z/view?usp=sharing -ORGAN_SPLITS = { - "breast": ["TCGA-A7-A13E-01Z-00-DX1", "TCGA-A7-A13F-01Z-00-DX1", "TCGA-AR-A1AK-01Z-00-DX1", - "TCGA-AR-A1AS-01Z-00-DX1", "TCGA-E2-A1B5-01Z-00-DX1", "TCGA-E2-A14V-01Z-00-DX1"], - "kidney": ["TCGA-B0-5711-01Z-00-DX1", "TCGA-HE-7128-01Z-00-DX1", "TCGA-HE-7129-01Z-00-DX1", - "TCGA-HE-7130-01Z-00-DX1", "TCGA-B0-5710-01Z-00-DX1", "TCGA-B0-5698-01Z-00-DX1"], - "liver": ["TCGA-18-5592-01Z-00-DX1", "TCGA-38-6178-01Z-00-DX1", "TCGA-49-4488-01Z-00-DX1", - "TCGA-50-5931-01Z-00-DX1", "TCGA-21-5784-01Z-00-DX1", "TCGA-21-5786-01Z-00-DX1"], - "prostate": ["TCGA-G9-6336-01Z-00-DX1", "TCGA-G9-6348-01Z-00-DX1", "TCGA-G9-6356-01Z-00-DX1", - "TCGA-G9-6363-01Z-00-DX1", "TCGA-CH-5767-01Z-00-DX1", "TCGA-G9-6362-01Z-00-DX1"], - "bladder": ["TCGA-DK-A2I6-01A-01-TS1", "TCGA-G2-A2EK-01A-02-TSB"], - "colon": ["TCGA-AY-A8YK-01A-01-TS1", "TCGA-NH-A8F7-01A-01-TS1"], - "stomach": ["TCGA-KB-A93J-01A-01-TS1", "TCGA-RD-A8N9-01A-01-TS1"] -} - - -def _download_monuseg(path, download, split): - assert split in ["train", "test"], "The split choices in MoNuSeg datset are train/test, please choose from them" - - # check if we have extracted the images and labels already - im_path = os.path.join(path, "images", split) - label_path = os.path.join(path, "labels", split) - if os.path.exists(im_path) and os.path.exists(label_path): - return - - os.makedirs(path, exist_ok=True) - zip_path = os.path.join(path, f"monuseg_{split}.zip") - util.download_source_gdrive(zip_path, URL[split], download=download, checksum=CHECKSUM[split]) - - _process_monuseg(path, split) - - -def _process_monuseg(path, split): - util.unzip(os.path.join(path, f"monuseg_{split}.zip"), path) - - # assorting the images into expected dir; - # converting the label xml files to numpy arrays (of same dimension as input images) in the expected dir - root_img_save_dir = os.path.join(path, "images", split) - root_label_save_dir = os.path.join(path, "labels", split) - - os.makedirs(root_img_save_dir, exist_ok=True) - os.makedirs(root_label_save_dir, exist_ok=True) - - if split == "train": - all_img_dir = sorted(glob(os.path.join(path, "*", "Tissue*", "*"))) - all_xml_label_dir = sorted(glob(os.path.join(path, "*", "Annotations", "*"))) - else: - all_img_dir = sorted(glob(os.path.join(path, 
"MoNuSegTestData", "*.tif"))) - all_xml_label_dir = sorted(glob(os.path.join(path, "MoNuSegTestData", "*.xml"))) - - assert len(all_img_dir) == len(all_xml_label_dir) - - for img_path, xml_label_path in tqdm(zip(all_img_dir, all_xml_label_dir), - desc=f"Converting {split} split to the expected format", - total=len(all_img_dir)): - desired_label_shape = imageio.imread(img_path).shape[:-1] - - img_id = os.path.split(img_path)[-1] - dst = os.path.join(root_img_save_dir, img_id) - shutil.move(src=img_path, dst=dst) - - _label = util.generate_labeled_array_from_xml(shape=desired_label_shape, xml_file=xml_label_path) - _fileid = img_id.split(".")[0] - imageio.imwrite(os.path.join(root_label_save_dir, f"{_fileid}.tif"), _label) - - shutil.rmtree(glob(os.path.join(path, "MoNuSeg*"))[0]) - if split == "train": - shutil.rmtree(glob(os.path.join(path, "__MACOSX"))[0]) - - -def get_monuseg_dataset( - path, patch_shape, split, organ_type: Optional[List[str]] = None, download=False, - offsets=None, boundaries=False, binary=False, **kwargs -): - """Dataset from https://monuseg.grand-challenge.org/Data/ - """ - _download_monuseg(path, download, split) - - image_paths = sorted(glob(os.path.join(path, "images", split, "*"))) - label_paths = sorted(glob(os.path.join(path, "labels", split, "*"))) - - if split == "train" and organ_type is not None: - # get all patients for multiple organ selection - all_organ_splits = sum([ORGAN_SPLITS[_o] for _o in organ_type], []) - - image_paths = [_path for _path in image_paths if Path(_path).stem in all_organ_splits] - label_paths = [_path for _path in label_paths if Path(_path).stem in all_organ_splits] - - elif split == "test" and organ_type is not None: - # we don't have organ splits in the test dataset - raise ValueError("The test split does not have any organ informations, please pass `organ_type=None`") - - kwargs, _ = util.add_instance_label_transform( - kwargs, add_binary_target=True, binary=binary, boundaries=boundaries, offsets=offsets - ) - return torch_em.default_segmentation_dataset( - image_paths, None, label_paths, None, patch_shape, is_seg_dataset=False, **kwargs - ) - - -def get_monuseg_loader( - path, patch_shape, batch_size, split, organ_type=None, download=False, offsets=None, boundaries=False, binary=False, - **kwargs -): - ds_kwargs, loader_kwargs = util.split_kwargs( - torch_em.default_segmentation_dataset, **kwargs - ) - dataset = get_monuseg_dataset( - path, patch_shape, split, organ_type=organ_type, download=download, - offsets=offsets, boundaries=boundaries, binary=binary, **ds_kwargs - ) - loader = torch_em.get_data_loader(dataset, batch_size, **loader_kwargs) - return loader diff --git a/torch_em/data/datasets/mouse_embryo.py b/torch_em/data/datasets/mouse_embryo.py deleted file mode 100644 index fb6839dd..00000000 --- a/torch_em/data/datasets/mouse_embryo.py +++ /dev/null @@ -1,79 +0,0 @@ -import os -from glob import glob - -import torch_em -from . 
import util - -URL = "https://zenodo.org/record/6546550/files/MouseEmbryos.zip?download=1" -CHECKSUM = "bf24df25e5f919489ce9e674876ff27e06af84445c48cf2900f1ab590a042622" - - -def _require_embryo_data(path, download): - if os.path.exists(path): - return - os.makedirs(path, exist_ok=True) - tmp_path = os.path.join(path, "mouse_embryo.zip") - util.download_source(tmp_path, URL, download, CHECKSUM) - util.unzip(tmp_path, path, remove=True) - # remove empty volume - os.remove(os.path.join(path, "Membrane", "train", "fused_paral_stack0_chan2_tp00073_raw_crop_bg_noise.h5")) - - -def get_mouse_embryo_dataset( - path, - name, - split, - patch_shape, - download=False, - offsets=None, - boundaries=False, - binary=False, - **kwargs, -): - """Dataset for the segmentation of nuclei in confocal microscopy. - - This dataset is stored on zenodo: https://zenodo.org/record/6546550. - """ - assert name in ("membrane", "nuclei") - assert split in ("train", "val") - assert len(patch_shape) == 3 - _require_embryo_data(path, download) - - # the naming of the data is inconsistent: membrane has val, nuclei has test; - # we treat nuclei:test as val - split_ = "test" if name == "nuclei" and split == "val" else split - file_paths = glob(os.path.join(path, name.capitalize(), split_, "*.h5")) - file_paths.sort() - - kwargs, _ = util.add_instance_label_transform( - kwargs, add_binary_target=binary, binary=binary, boundaries=boundaries, - offsets=offsets, binary_is_exclusive=False - ) - - raw_key, label_key = "raw", "label" - return torch_em.default_segmentation_dataset(file_paths, raw_key, file_paths, label_key, patch_shape, **kwargs) - - -def get_mouse_embryo_loader( - path, - name, - split, - patch_shape, - batch_size, - download=False, - offsets=None, - boundaries=False, - binary=False, - **kwargs, -): - """Dataloader for the segmentation of nuclei in confocal microscopy. See 'get_mouse_embryo_dataset' for details.""" - ds_kwargs, loader_kwargs = util.split_kwargs( - torch_em.default_segmentation_dataset, **kwargs - ) - dataset = get_mouse_embryo_dataset( - path, name, split, patch_shape, - download=download, offsets=offsets, boundaries=boundaries, binary=binary, - **ds_kwargs - ) - loader = torch_em.get_data_loader(dataset, batch_size, **loader_kwargs) - return loader diff --git a/torch_em/data/datasets/nuc_mm.py b/torch_em/data/datasets/nuc_mm.py deleted file mode 100644 index 2e257c51..00000000 --- a/torch_em/data/datasets/nuc_mm.py +++ /dev/null @@ -1,78 +0,0 @@ -import os -from glob import glob - -import h5py -import torch_em - -from . 
import util - -URL = "https://drive.google.com/drive/folders/1_4CrlYvzx0ITnGlJOHdgcTRgeSkm9wT8" - - -def _extract_split(image_folder, label_folder, output_folder): - os.makedirs(output_folder, exist_ok=True) - image_files = sorted(glob(os.path.join(image_folder, "*.h5"))) - label_files = sorted(glob(os.path.join(label_folder, "*.h5"))) - assert len(image_files) == len(label_files) - for image, label in zip(image_files, label_files): - with h5py.File(image, "r") as f: - vol = f["main"][:] - with h5py.File(label, "r") as f: - seg = f["main"][:] - assert vol.shape == seg.shape - out_path = os.path.join(output_folder, os.path.basename(image)) - with h5py.File(out_path, "a") as f: - f.create_dataset("raw", data=vol, compression="gzip") - f.create_dataset("labels", data=seg, compression="gzip") - - -def _require_dataset(path, sample, download): - # downloading the dataset - util.download_source_gdrive(path, URL, download, download_type="folder") - - if sample == "mouse": - input_folder = os.path.join(path, "Mouse (NucMM-M)") - else: - input_folder = os.path.join(path, "Zebrafish (NucMM-Z)") - assert os.path.exists(input_folder), input_folder - - sample_folder = os.path.join(path, sample) - _extract_split( - os.path.join(input_folder, "Image", "train"), os.path.join(input_folder, "Label", "train"), - os.path.join(sample_folder, "train") - ) - _extract_split( - os.path.join(input_folder, "Image", "val"), os.path.join(input_folder, "Label", "val"), - os.path.join(sample_folder, "val") - ) - - -def get_nuc_mm_dataset(path, sample, split, patch_shape, download=False, **kwargs): - """Dataset for the segmentation of nuclei in EM and X-Ray. - - This dataset is from the publication https://doi.org/10.1007/978-3-030-87193-2_16. - Please cite it if you use this dataset for a publication. - """ - assert sample in ("mouse", "zebrafish") - assert split in ("train", "val") - - sample_folder = os.path.join(path, sample) - if not os.path.exists(sample_folder): - _require_dataset(path, sample, download) - - split_folder = os.path.join(sample_folder, split) - paths = sorted(glob(os.path.join(split_folder, "*.h5"))) - - raw_key, label_key = "raw", "labels" - return torch_em.default_segmentation_dataset( - paths, raw_key, paths, label_key, patch_shape, is_seg_dataset=True, **kwargs - ) - - -def get_nuc_mm_loader(path, sample, split, patch_shape, batch_size, download=False, **kwargs): - """Dataset for the segmentation of nuclei in EM and X-Ray. See 'get_nuc_mm_dataset' for details.""" - ds_kwargs, loader_kwargs = util.split_kwargs( - torch_em.default_segmentation_dataset, **kwargs - ) - ds = get_nuc_mm_dataset(path, sample, split, patch_shape, download, **ds_kwargs) - return torch_em.get_data_loader(ds, batch_size=batch_size, **loader_kwargs) diff --git a/torch_em/data/datasets/plantseg.py b/torch_em/data/datasets/plantseg.py deleted file mode 100644 index 1a85f1f2..00000000 --- a/torch_em/data/datasets/plantseg.py +++ /dev/null @@ -1,117 +0,0 @@ -import os -from glob import glob - -import torch_em -from . 
import util - -URLS = { - "root": { - "train": "https://files.de-1.osf.io/v1/resources/9x3g2/providers/osfstorage/?zip=", - "val": "https://files.de-1.osf.io/v1/resources/vs6gb/providers/osfstorage/?zip=", - "test": "https://files.de-1.osf.io/v1/resources/tn4xj/providers/osfstorage/?zip=", - }, - "nuclei": { - "train": "https://files.de-1.osf.io/v1/resources/thxzn/providers/osfstorage/?zip=", - }, - "ovules": { - "train": "https://files.de-1.osf.io/v1/resources/x9yns/providers/osfstorage/?zip=", - "val": "https://files.de-1.osf.io/v1/resources/xp5uf/providers/osfstorage/?zip=", - "test": "https://files.de-1.osf.io/v1/resources/8jz7e/providers/osfstorage/?zip=", - } -} - -# FIXME somehow the checksums are not reliably, this is quite worrying... -CHECKSUMS = { - "root": { - "train": None, "val": None, "test": None - # "train": "f72e9525ff716ef14b70ab1318efd4bf303bbf9e0772bf2981a2db6e22a75794", - # "val": "987280d9a56828c840e508422786431dcc3603e0ba4814aa06e7bf4424efcd9e", - # "test": "ad71b8b9d20effba85fb5e1b42594ae35939d1a0cf905f3403789fc9e6afbc58", - }, - "nuclei": { - "train": None - # "train": "9d19ddb61373e2a97effb6cf8bd8baae5f8a50f87024273070903ea8b1160396", - }, - "ovules": { - "train": None, "val": None, "test": None - # "train": "70379673f1ab1866df6eb09d5ce11db7d3166d6d15b53a9c8b47376f04bae413", - # "val": "872f516cb76879c30782d9a76d52df95236770a866f75365902c60c37b14fa36", - # "test": "a7272f6ad1d765af6d121e20f436ac4f3609f1a90b1cb2346aa938d8c52800b9", - } -} -# The resolution previous used for the resizing -# I have removed this feature since it was not reliable, -# but leaving this here for reference -# (also implementing resizing would be a good idea, -# but more general and not for each dataset individually) -# NATIVE_RESOLUTION = (0.235, 0.075, 0.075) - - -def _require_plantseg_data(path, download, name, split): - url = URLS[name][split] - checksum = CHECKSUMS[name][split] - os.makedirs(path, exist_ok=True) - out_path = os.path.join(path, f"{name}_{split}") - if os.path.exists(out_path): - return out_path - tmp_path = os.path.join(path, f"{name}_{split}.zip") - util.download_source(tmp_path, url, download, checksum) - util.unzip(tmp_path, out_path, remove=True) - return out_path - - -def get_plantseg_dataset( - path, - name, - split, - patch_shape, - download=False, - offsets=None, - boundaries=False, - binary=False, - **kwargs, -): - """Dataset for the segmentation of plant cells in confocal and light-sheet microscopy. - - This dataset is from the publication https://doi.org/10.7554/eLife.57613. - Please cite it if you use this dataset for a publication. - """ - assert len(patch_shape) == 3 - data_path = _require_plantseg_data(path, download, name, split) - - file_paths = glob(os.path.join(data_path, "*.h5")) - file_paths.sort() - - kwargs, _ = util.add_instance_label_transform( - kwargs, add_binary_target=binary, binary=binary, boundaries=boundaries, - offsets=offsets, binary_is_exclusive=False - ) - - raw_key, label_key = "raw", "label" - return torch_em.default_segmentation_dataset(file_paths, raw_key, file_paths, label_key, patch_shape, **kwargs) - - -# TODO add support for ignore label, key: "/label_with_ignore" -def get_plantseg_loader( - path, - name, - split, - patch_shape, - batch_size, - download=False, - offsets=None, - boundaries=False, - binary=False, - **kwargs, -): - """Dataloader for the segmentation of cells in confocal and light-sheet microscopy. 
See 'get_plantseg_dataset'.""" - ds_kwargs, loader_kwargs = util.split_kwargs( - torch_em.default_segmentation_dataset, **kwargs - ) - dataset = get_plantseg_dataset( - path, name, split, patch_shape, - download=download, offsets=offsets, boundaries=boundaries, binary=binary, - **ds_kwargs - ) - loader = torch_em.get_data_loader(dataset, batch_size, **loader_kwargs) - return loader diff --git a/torch_em/data/datasets/platynereis.py b/torch_em/data/datasets/platynereis.py deleted file mode 100644 index 87a46e53..00000000 --- a/torch_em/data/datasets/platynereis.py +++ /dev/null @@ -1,238 +0,0 @@ -import os -from glob import glob - -import numpy as np -import torch_em -from . import util - -URLS = { - "cells": "https://zenodo.org/record/3675220/files/membrane.zip", - "nuclei": "https://zenodo.org/record/3675220/files/nuclei.zip", - "cilia": "https://zenodo.org/record/3675220/files/cilia.zip", - "cuticle": "https://zenodo.org/record/3675220/files/cuticle.zip" -} - -CHECKSUMS = { - "cells": "30eb50c39e7e9883e1cd96e0df689fac37a56abb11e8ed088907c94a5980d6a3", - "nuclei": "a05033c5fbc6a3069479ac6595b0a430070f83f5281f5b5c8913125743cf5510", - "cilia": "6d2b47f63d39a671789c02d8b66cad5e4cf30eb14cdb073da1a52b7defcc5e24", - "cuticle": "464f75d30133e8864958049647fe3c2216ddf2d4327569738ad72d299c991843" -} - - -# -# TODO data-loader for more classes: -# - mitos -# - - -def _require_platy_data(path, name, download): - os.makedirs(path, exist_ok=True) - url = URLS[name] - checksum = CHECKSUMS[name] - - zip_path = os.path.join(path, f"data-{name}.zip") - util.download_source(zip_path, url, download=download, checksum=checksum) - util.unzip(zip_path, path, remove=True) - - -def _check_data(path, prefix, extension, n_files): - if not os.path.exists(path): - return False - files = glob(os.path.join(path, f"{prefix}*{extension}")) - return len(files) == n_files - - -def _get_paths_and_rois(sample_ids, n_files, template, rois): - if sample_ids is None: - sample_ids = list(range(1, n_files + 1)) - else: - assert min(sample_ids) >= 1 and max(sample_ids) <= n_files - sample_ids.sort() - paths = [template % sample for sample in sample_ids] - data_rois = [rois.get(sample, np.s_[:, :, :]) for sample in sample_ids] - return paths, data_rois - - -def get_platynereis_cuticle_dataset(path, patch_shape, sample_ids=None, download=False, rois={}, **kwargs): - """Dataset for the segmentation of cuticle in EM. - - This dataset is from the publication https://doi.org/10.1016/j.cell.2021.07.017. - Please cite it if you use this dataset for a publication. - """ - cuticle_root = os.path.join(path, "cuticle") - - ext = ".n5" - prefix, n_files = "train_data_", 5 - data_is_complete = _check_data(cuticle_root, prefix, ext, n_files) - if not data_is_complete: - _require_platy_data(path, "cuticle", download) - - paths, data_rois = _get_paths_and_rois(sample_ids, n_files, os.path.join(cuticle_root, "train_data_%02i.n5"), rois) - raw_key, label_key = "volumes/raw", "volumes/labels/segmentation" - return torch_em.default_segmentation_dataset( - paths, raw_key, paths, label_key, patch_shape, rois=data_rois, **kwargs - ) - - -def get_platynereis_cuticle_loader( - path, patch_shape, batch_size, sample_ids=None, download=False, rois={}, **kwargs -): - """Dataloader for the segmentation of cuticle in EM. 
See 'get_platynereis_cuticle_loader'.""" - ds_kwargs, loader_kwargs = util.split_kwargs( - torch_em.default_segmentation_dataset, **kwargs - ) - ds = get_platynereis_cuticle_dataset( - path, patch_shape, sample_ids=sample_ids, download=download, rois=rois, **ds_kwargs, - ) - return torch_em.get_data_loader(ds, batch_size=batch_size, **loader_kwargs) - - -def get_platynereis_cilia_dataset( - path, patch_shape, sample_ids=None, - offsets=None, boundaries=False, binary=False, - rois={}, download=False, **kwargs -): - """Dataset for the segmentation of cilia in EM. - - This dataset is from the publication https://doi.org/10.1016/j.cell.2021.07.017. - Please cite it if you use this dataset for a publication. - """ - cilia_root = os.path.join(path, "cilia") - - ext = ".h5" - prefix, n_files = "train_data_cilia_", 3 - data_is_complete = _check_data(cilia_root, prefix, ext, n_files) - if not data_is_complete: - _require_platy_data(path, "cilia", download) - - paths, rois = _get_paths_and_rois(sample_ids, n_files, os.path.join(cilia_root, "train_data_cilia_%02i.h5"), rois) - raw_key = "volumes/raw" - label_key = "volumes/labels/segmentation" - - kwargs, _ = util.add_instance_label_transform( - kwargs, add_binary_target=True, boundaries=boundaries, offsets=offsets, binary=binary, - ) - return torch_em.default_segmentation_dataset(paths, raw_key, paths, label_key, patch_shape, **kwargs) - - -def get_platynereis_cilia_loader( - path, patch_shape, batch_size, sample_ids=None, - offsets=None, boundaries=False, binary=False, - rois={}, download=False, **kwargs -): - """Dataloader for the segmentation of cilia in EM. See 'get_platynereis_cilia_dataset'.""" - ds_kwargs, loader_kwargs = util.split_kwargs( - torch_em.default_segmentation_dataset, **kwargs - ) - ds = get_platynereis_cilia_dataset( - path, patch_shape, sample_ids=sample_ids, - offsets=offsets, boundaries=boundaries, binary=binary, - rois=rois, download=download, **ds_kwargs, - ) - return torch_em.get_data_loader(ds, batch_size=batch_size, **loader_kwargs) - - -def get_platynereis_cell_dataset( - path, patch_shape, - sample_ids=None, rois={}, - offsets=None, boundaries=False, - download=False, **kwargs -): - """Dataset for the segmentation of cells in EM. - - This dataset is from the publication https://doi.org/10.1016/j.cell.2021.07.017. - Please cite it if you use this dataset for a publication. - """ - cell_root = os.path.join(path, "membrane") - - prefix = "train_data_membrane_" - ext = ".n5" - n_files = 9 - data_is_complete = _check_data(cell_root, prefix, ext, n_files) - if not data_is_complete: - _require_platy_data(path, "cells", download) - - template = os.path.join(cell_root, "train_data_membrane_%02i.n5") - data_paths, data_rois = _get_paths_and_rois(sample_ids, n_files, template, rois) - - kwargs = util.update_kwargs(kwargs, "rois", data_rois) - kwargs, _ = util.add_instance_label_transform( - kwargs, add_binary_target=False, boundaries=boundaries, offsets=offsets, - ) - - raw_key = "volumes/raw/s1" - label_key = "volumes/labels/segmentation/s1" - return torch_em.default_segmentation_dataset(data_paths, raw_key, data_paths, label_key, patch_shape, **kwargs) - - -def get_platynereis_cell_loader( - path, patch_shape, batch_size, - sample_ids=None, rois={}, - offsets=None, boundaries=False, - download=False, **kwargs -): - """Dataloader for the segmentation of cells in EM. 
See 'get_platynereis_cell_dataset'.""" - ds_kwargs, loader_kwargs = util.split_kwargs( - torch_em.default_segmentation_dataset, **kwargs - ) - ds = get_platynereis_cell_dataset( - path, patch_shape, sample_ids, rois=rois, - offsets=offsets, boundaries=boundaries, download=download, - **ds_kwargs, - ) - return torch_em.get_data_loader(ds, batch_size=batch_size, **loader_kwargs) - - -def get_platynereis_nuclei_dataset( - path, patch_shape, sample_ids=None, rois={}, - offsets=None, boundaries=False, binary=False, - download=False, **kwargs, -): - """Dataset for the segmentation of nuclei in EM. - - This dataset is from the publication https://doi.org/10.1016/j.cell.2021.07.017. - Please cite it if you use this dataset for a publication. - """ - nuc_root = os.path.join(path, "nuclei") - prefix, ext = "train_data_nuclei_", ".h5" - n_files = 12 - data_is_complete = _check_data(nuc_root, prefix, ext, n_files) - if not data_is_complete: - _require_platy_data(path, "nuclei", download) - - if sample_ids is None: - sample_ids = list(range(1, n_files + 1)) - assert min(sample_ids) >= 1 and max(sample_ids) <= n_files - sample_ids.sort() - - template = os.path.join(nuc_root, "train_data_nuclei_%02i.h5") - data_paths, data_rois = _get_paths_and_rois(sample_ids, n_files, template, rois) - - kwargs = util.update_kwargs(kwargs, "is_seg_dataset", True) - kwargs = util.update_kwargs(kwargs, "rois", data_rois) - kwargs, _ = util.add_instance_label_transform( - kwargs, add_binary_target=True, boundaries=boundaries, offsets=offsets, binary=binary, - ) - - raw_key = "volumes/raw" - label_key = "volumes/labels/nucleus_instance_labels" - return torch_em.default_segmentation_dataset(data_paths, raw_key, data_paths, label_key, patch_shape, **kwargs) - - -def get_platynereis_nuclei_loader( - path, patch_shape, batch_size, - sample_ids=None, rois={}, - offsets=None, boundaries=False, binary=False, - download=False, **kwargs -): - """Dataloader for the segmentation of nuclei in EM. See 'get_platynereis_nuclei_dataset'.""" - ds_kwargs, loader_kwargs = util.split_kwargs( - torch_em.default_segmentation_dataset, **kwargs - ) - ds = get_platynereis_nuclei_dataset( - path, patch_shape, sample_ids=sample_ids, rois=rois, - offsets=offsets, boundaries=boundaries, binary=binary, download=download, - **ds_kwargs, - ) - return torch_em.get_data_loader(ds, batch_size=batch_size, **loader_kwargs) diff --git a/torch_em/data/datasets/snemi.py b/torch_em/data/datasets/snemi.py deleted file mode 100644 index ca1255f2..00000000 --- a/torch_em/data/datasets/snemi.py +++ /dev/null @@ -1,70 +0,0 @@ -import os - -import torch_em -from . import util - -SNEMI_URLS = { - "train": "https://oc.embl.de/index.php/s/43iMotlXPyAB39z/download", - "test": "https://oc.embl.de/index.php/s/aRhphk35H23De2s/download" -} -CHECKSUMS = { - "train": "5b130a24d9eb23d972fede0f1a403bc05f6808b361cfa22eff23b930b12f0615", - "test": "3df3920a0ddec6897105845f842b2665d37a47c2d1b96d4f4565682e315a59fa" -} - - -def get_snemi_dataset( - path, - patch_shape, - sample="train", - download=False, - offsets=None, - boundaries=False, - **kwargs, -): - """Dataset for the segmentation of neurons in EM. - - This dataset is from the publication https://doi.org/10.1016/j.cell.2015.06.054. - Please cite it if you use this dataset for a publication. 
- """ - assert len(patch_shape) == 3 - os.makedirs(path, exist_ok=True) - - data_path = os.path.join(path, f"snemi_{sample}.h5") - util.download_source(data_path, SNEMI_URLS[sample], download, CHECKSUMS[sample]) - assert os.path.exists(data_path), data_path - - kwargs = util.update_kwargs(kwargs, "is_seg_dataset", True) - kwargs, _ = util.add_instance_label_transform( - kwargs, add_binary_target=False, boundaries=boundaries, offsets=offsets - ) - - raw_key = "volumes/raw" - label_key = "volumes/labels/neuron_ids" - return torch_em.default_segmentation_dataset(data_path, raw_key, data_path, label_key, patch_shape, **kwargs) - - -def get_snemi_loader( - path, - patch_shape, - batch_size, - sample="train", - download=False, - offsets=None, - boundaries=False, - **kwargs, -): - """Dataloader for the segmentation of neurons in EM. See 'get_snemi_dataset'.""" - ds_kwargs, loader_kwargs = util.split_kwargs( - torch_em.default_segmentation_dataset, **kwargs - ) - ds = get_snemi_dataset( - path=path, - patch_shape=patch_shape, - sample=sample, - download=download, - offsets=offsets, - boundaries=boundaries, - **ds_kwargs, - ) - return torch_em.get_data_loader(ds, batch_size=batch_size, **loader_kwargs) diff --git a/torch_em/data/datasets/sponge_em.py b/torch_em/data/datasets/sponge_em.py deleted file mode 100644 index f5553584..00000000 --- a/torch_em/data/datasets/sponge_em.py +++ /dev/null @@ -1,47 +0,0 @@ -import os -from glob import glob - -import torch_em -from . import util - -URL = "https://zenodo.org/record/8150818/files/sponge_em_train_data.zip?download=1" -CHECKSUM = "f1df616cd60f81b91d7642933e9edd74dc6c486b2e546186a7c1e54c67dd32a5" - - -def _require_sponge_em_data(path, download): - os.makedirs(path, exist_ok=True) - zip_path = os.path.join(path, "data.zip") - util.download_source(zip_path, URL, download, CHECKSUM) - util.unzip(zip_path, path) - - -def get_sponge_em_dataset(path, mode, patch_shape, sample_ids=None, download=False, **kwargs): - """Dataset for the segmentation of sponge cells and organelles in EM. - - This dataset is from the publication https://doi.org/10.1126/science.abj2949. - Please cite it if you use this dataset for a publication. - """ - assert mode in ("semantic", "instances") - - n_files = len(glob(os.path.join(path, "*.h5"))) - if n_files == 0: - _require_sponge_em_data(path, download) - n_files = len(glob(os.path.join(path, "*.h5"))) - assert n_files == 3 - - if sample_ids is None: - sample_ids = range(1, n_files + 1) - paths = [os.path.join(path, f"train_data_0{i}.h5") for i in sample_ids] - - raw_key = "volumes/raw" - label_key = f"volumes/labels/{mode}" - return torch_em.default_segmentation_dataset(paths, raw_key, paths, label_key, patch_shape, **kwargs) - - -def get_sponge_em_loader(path, mode, patch_shape, batch_size, sample_ids=None, download=False, **kwargs): - """Dataloader for the segmentation of sponge cells and organelles in EM. 
See 'get_sponge_em_dataset'.""" - ds_kwargs, loader_kwargs = util.split_kwargs( - torch_em.default_segmentation_dataset, **kwargs - ) - ds = get_sponge_em_dataset(path, mode, patch_shape, sample_ids=sample_ids, download=download, **ds_kwargs) - return torch_em.get_data_loader(ds, batch_size=batch_size, **loader_kwargs) diff --git a/torch_em/data/datasets/tissuenet.py b/torch_em/data/datasets/tissuenet.py deleted file mode 100644 index f6098eee..00000000 --- a/torch_em/data/datasets/tissuenet.py +++ /dev/null @@ -1,106 +0,0 @@ -import os -from glob import glob - -import numpy as np -import pandas as pd -import torch_em -import z5py - -from tqdm import tqdm -from . import util - - -# Automated download is currently not possible, because of authentication -URL = None # TODO: here - https://datasets.deepcell.org/data - - -def _create_split(path, split): - split_file = os.path.join(path, f"tissuenet_v1.1_{split}.npz") - split_folder = os.path.join(path, split) - os.makedirs(split_folder, exist_ok=True) - data = np.load(split_file, allow_pickle=True) - - x, y = data["X"], data["y"] - metadata = data["meta"] - metadata = pd.DataFrame(metadata[1:], columns=metadata[0]) - - for i, (im, label) in tqdm(enumerate(zip(x, y)), total=len(x), desc=f"Creating files for {split}-split"): - out_path = os.path.join(split_folder, f"image_{i:04}.zarr") - nucleus_channel = im[..., 0] - cell_channel = im[..., 1] - rgb = np.stack([np.zeros_like(nucleus_channel), cell_channel, nucleus_channel]) - chunks = cell_channel.shape - with z5py.File(out_path, "a") as f: - - f.create_dataset("raw/nucleus", data=im[..., 0], compression="gzip", chunks=chunks) - f.create_dataset("raw/cell", data=cell_channel, compression="gzip", chunks=chunks) - f.create_dataset("raw/rgb", data=rgb, compression="gzip", chunks=(3,) + chunks) - - # the switch 0<->1 is intentional, the data format is chaotic... - f.create_dataset("labels/nucleus", data=label[..., 1], compression="gzip", chunks=chunks) - f.create_dataset("labels/cell", data=label[..., 0], compression="gzip", chunks=chunks) - os.remove(split_file) - - -def _create_dataset(path, zip_path): - util.unzip(zip_path, path, remove=False) - splits = ["train", "val", "test"] - assert all([os.path.exists(os.path.join(path, f"tissuenet_v1.1_{split}.npz")) for split in splits]) - for split in splits: - _create_split(path, split) - - -def get_tissuenet_dataset( - path, split, patch_shape, raw_channel, label_channel, download=False, **kwargs -): - """Dataset for the segmentation of cells in tissue imaged with light microscopy. - - This dataset is from the publication https://doi.org/10.1038/s41587-021-01094-0. - Please cite it if you use this dataset for a publication. - """ - assert raw_channel in ("nucleus", "cell", "rgb") - assert label_channel in ("nucleus", "cell") - - splits = ["train", "val", "test"] - assert split in splits - - # check if the dataset exists already - zip_path = os.path.join(path, "tissuenet_v1.1.zip") - if all([os.path.exists(os.path.join(path, split)) for split in splits]): # yes it does - pass - elif os.path.exists(zip_path): # no it does not, but we have the zip there and can unpack it - _create_dataset(path, zip_path) - else: - raise RuntimeError( - "We do not support automatic download for the tissuenet datasets yet." 
- f"Please download the dataset from https://datasets.deepcell.org/data and put it here: {zip_path}" - ) - - split_folder = os.path.join(path, split) - assert os.path.exists(split_folder) - data_path = glob(os.path.join(split_folder, "*.zarr")) - assert len(data_path) > 0 - - raw_key, label_key = f"raw/{raw_channel}", f"labels/{label_channel}" - - with_channels = True if raw_channel == "rgb" else False - kwargs = util.update_kwargs(kwargs, "with_channels", with_channels) - kwargs = util.update_kwargs(kwargs, "is_seg_dataset", True) - kwargs = util.update_kwargs(kwargs, "ndim", 2) - - return torch_em.default_segmentation_dataset(data_path, raw_key, data_path, label_key, patch_shape, **kwargs) - - -# TODO enable loading specific tissue types etc. (from the 'meta' attributes) -def get_tissuenet_loader( - path, split, patch_shape, batch_size, raw_channel, label_channel, download=False, **kwargs -): - """Dataloader for the segmentation of cells in tissue imaged with light microscopy. - See 'get_tissuenet_dataset' for details. - """ - ds_kwargs, loader_kwargs = util.split_kwargs(torch_em.default_segmentation_dataset, **kwargs) - dataset = get_tissuenet_dataset( - path, split, patch_shape, raw_channel, label_channel, download, **ds_kwargs - ) - loader = torch_em.get_data_loader(dataset, batch_size, **loader_kwargs) - return loader diff --git a/torch_em/data/datasets/uro_cell.py b/torch_em/data/datasets/uro_cell.py deleted file mode 100644 index fb8e92ad..00000000 --- a/torch_em/data/datasets/uro_cell.py +++ /dev/null @@ -1,147 +0,0 @@ -import os -import warnings -from glob import glob -from shutil import rmtree - -import h5py -import torch_em -from . import util - - -URL = "https://github.com/MancaZerovnikMekuc/UroCell/archive/refs/heads/master.zip" -CHECKSUM = "a48cf31b06114d7def642742b4fcbe76103483c069122abe10f377d71a1acabc" - - -def _require_urocell_data(path, download): - if os.path.exists(path): - return path - - # add nifti file format support in elf by wrapping nibabel? 
- import nibabel as nib - - # download and unzip the data - os.makedirs(path) - tmp_path = os.path.join(path, "uro_cell.zip") - util.download_source(tmp_path, URL, download, checksum=CHECKSUM) - util.unzip(tmp_path, path, remove=True) - - root = os.path.join(path, "UroCell-master") - - files = glob(os.path.join(root, "data", "*.nii.gz")) - files.sort() - for data_path in files: - fname = os.path.basename(data_path) - data = nib.load(data_path).get_fdata() - - out_path = os.path.join(path, fname.replace("nii.gz", "h5")) - with h5py.File(out_path, "w") as f: - f.create_dataset("raw", data=data, compression="gzip") - - # check if we have any of the organelle labels for this volume - # and also copy them if yes - fv_path = os.path.join(root, "fv", "instance", fname) - if os.path.exists(fv_path): - fv = nib.load(fv_path).get_fdata().astype("uint32") - assert fv.shape == data.shape - f.create_dataset("labels/fv", data=fv, compression="gzip") - - golgi_path = os.path.join(root, "golgi", "precise", fname) - if os.path.exists(golgi_path): - golgi = nib.load(golgi_path).get_fdata().astype("uint32") - assert golgi.shape == data.shape - f.create_dataset("labels/golgi", data=golgi, compression="gzip") - - lyso_path = os.path.join(root, "lyso", "instance", fname) - if os.path.exists(lyso_path): - lyso = nib.load(lyso_path).get_fdata().astype("uint32") - assert lyso.shape == data.shape - f.create_dataset("labels/lyso", data=lyso, compression="gzip") - - mito_path = os.path.join(root, "mito", "instance", fname) - if os.path.exists(mito_path): - mito = nib.load(mito_path).get_fdata().astype("uint32") - assert mito.shape == data.shape - f.create_dataset("labels/mito", data=mito, compression="gzip") - - # clean up - rmtree(root) - - -def _get_paths(path, target): - label_key = f"labels/{target}" - all_paths = glob(os.path.join(path, "*.h5")) - all_paths.sort() - paths = [path for path in all_paths if label_key in h5py.File(path, "r")] - return paths, label_key - - -def get_uro_cell_dataset( - path, - target, - patch_shape, - download=False, - offsets=None, - boundaries=False, - binary=False, - **kwargs -): - """Dataset for the segmentation of mitochondria and other organelles in EM. - - This dataset is from the publication https://doi.org/10.1016/j.compbiomed.2020.103693. - Please cite it if you use this dataset for a publication. - """ - assert target in ("fv", "golgi", "lyso", "mito") - _require_urocell_data(path, download) - paths, label_key = _get_paths(path, target) - - assert sum((offsets is not None, boundaries, binary)) <= 1, f"{offsets}, {boundaries}, {binary}" - if offsets is not None: - if target in ("lyso", "golgi"): - warnings.warn( - f"{target} does not have instance labels, affinities will be computed based on binary segmentation." - ) - # we add a binary target channel for foreground background segmentation - label_transform = torch_em.transform.label.AffinityTransform(offsets=offsets, - ignore_label=None, - add_binary_target=True, - add_mask=True) - msg = "Offsets are passed, but 'label_transform2' is in the kwargs. It will be over-ridden." - kwargs = util.update_kwargs(kwargs, 'label_transform2', label_transform, msg=msg) - elif boundaries: - if target in ("lyso", "golgi"): - warnings.warn( - f"{target} does not have instance labels, boundaries will be computed based on binary segmentation." - ) - label_transform = torch_em.transform.label.BoundaryTransform(add_binary_target=True) - msg = "Boundaries is set to true, but 'label_transform' is in the kwargs. It will be over-ridden." 
-        kwargs = util.update_kwargs(kwargs, 'label_transform', label_transform, msg=msg)
-    elif binary:
-        label_transform = torch_em.transform.label.labels_to_binary
-        msg = "Binary is set to true, but 'label_transform' is in the kwargs. It will be over-ridden."
-        kwargs = util.update_kwargs(kwargs, 'label_transform', label_transform, msg=msg)
-
-    raw_key = "raw"
-    return torch_em.default_segmentation_dataset(
-        paths, raw_key, paths, label_key, patch_shape, is_seg_dataset=True, **kwargs
-    )
-
-
-def get_uro_cell_loader(
-    path,
-    target,
-    patch_shape,
-    batch_size,
-    download=False,
-    offsets=None,
-    boundaries=False,
-    binary=False,
-    **kwargs
-):
-    """Dataloader for the segmentation of mitochondria and other organelles in EM. See 'get_uro_cell_dataset'."""
-    ds_kwargs, loader_kwargs = util.split_kwargs(
-        torch_em.default_segmentation_dataset, **kwargs
-    )
-    ds = get_uro_cell_dataset(
-        path, target, patch_shape, download=download, offsets=offsets, boundaries=boundaries, binary=binary, **ds_kwargs
-    )
-    return torch_em.get_data_loader(ds, batch_size=batch_size, **loader_kwargs)
diff --git a/torch_em/data/datasets/util.py b/torch_em/data/datasets/util.py
index e6a38236..07d53d52 100644
--- a/torch_em/data/datasets/util.py
+++ b/torch_em/data/datasets/util.py
@@ -1,26 +1,39 @@
-import inspect
 import os
 import hashlib
-import zipfile
-import numpy as np
+import inspect
+import requests
 from tqdm import tqdm
 from warnings import warn
-from xml.dom import minidom
-from shutil import copyfileobj, which
 from subprocess import run
 from packaging import version
+from shutil import copyfileobj, which
+import zipfile
+import numpy as np
+from xml.dom import minidom
 from skimage.draw import polygon
 import torch
+
 import torch_em
-import requests
+from torch_em.transform import get_raw_transform
+from torch_em.transform.generic import ResizeLongestSideInputs, Compose
 try:
     import gdown
 except ImportError:
     gdown = None
+try:
+    from tcia_utils import nbia
+except ModuleNotFoundError:
+    nbia = None
+
+try:
+    from cryoet_data_portal import Client, Dataset
+except ImportError:
+    Client, Dataset = None, None
+
 BIOIMAGEIO_IDS = {
     "covid_if": "ilastik/covid_if_training_data",
@@ -146,6 +159,44 @@ def download_source_empiar(path, access_id, download):
     return download_path
+
+def download_source_kaggle(path, dataset_name, download, competition=False):
+    if not download:
+        raise RuntimeError(f"Cannot find the data at {path}, but download was set to False.")
+
+    try:
+        from kaggle.api.kaggle_api_extended import KaggleApi
+    except ModuleNotFoundError:
+        msg = "Please install the Kaggle API. You can do this using 'pip install kaggle'. "
+        msg += "After you have installed kaggle, you will need an API token. "
+        msg += "Follow the instructions at https://www.kaggle.com/docs/api."
+        raise ModuleNotFoundError(msg)
+
+    api = KaggleApi()
+    api.authenticate()
+
+    if competition:
+        api.competition_download_files(competition=dataset_name, path=path, quiet=False)
+    else:
+        api.dataset_download_files(dataset=dataset_name, path=path, quiet=False)
+
+
+def download_source_tcia(path, url, dst, csv_filename, download):
+    if not download:
+        raise RuntimeError(f"Cannot find the data at {path}, but download was set to False.")
+
+    assert url.endswith(".tcia"), f"{url} is not a TCIA manifest."
+
+    # downloads the manifest file from the collection page
+    manifest = requests.get(url=url)
+    with open(path, "wb") as f:
+        f.write(manifest.content)
+
+    # this part extracts the UIDs from the manifest and downloads them.
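+    # NOTE: 'nbia' comes from the optional 'tcia_utils' dependency (see the guarded import above);
+    # if it is not installed, 'nbia' is None and the call below will fail.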
+    nbia.downloadSeries(
+        series_data=path, input_type="manifest", path=dst, csv_filename=csv_filename,
+    )
+
+
 def update_kwargs(kwargs, key, value, msg=None):
     if key in kwargs:
         msg = f"{key} will be over-ridden in loader kwargs." if msg is None else msg
@@ -154,6 +205,42 @@ def update_kwargs(kwargs, key, value, msg=None):
     return kwargs
+
+def unzip_tarfile(tar_path, dst, remove=True):
+    import tarfile
+
+    if tar_path.endswith(".tar.gz"):
+        access_mode = "r:gz"
+    elif tar_path.endswith(".tar"):
+        access_mode = "r:"
+    else:
+        raise ValueError(
+            "The provided file isn't a supported archive to unpack. "
+            f"Please check the file: {tar_path}"
+        )
+
+    tar = tarfile.open(tar_path, access_mode)
+    tar.extractall(dst)
+    tar.close()
+
+    if remove:
+        os.remove(tar_path)
+
+
+def unzip_rarfile(rar_path, dst, remove=True, use_rarfile=True):
+    import rarfile
+
+    if use_rarfile:
+        with rarfile.RarFile(rar_path) as f:
+            f.extractall(path=dst)
+    else:
+        # only import the aspose fallback when it is actually used
+        import aspose.zip as az
+        with az.rar.RarArchive(rar_path) as archive:
+            archive.extract_to_directory(dst)
+
+    if remove:
+        os.remove(rar_path)
+
+
 def unzip(zip_path, dst, remove=True):
     with zipfile.ZipFile(zip_path, "r") as f:
         f.extractall(dst)
@@ -208,6 +295,49 @@ def add_instance_label_transform(
     return kwargs, label_dtype
+
+def update_kwargs_for_resize_trafo(kwargs, patch_shape, resize_inputs, resize_kwargs=None, ensure_rgb=None):
+    """Checks whether 'raw_transform' and / or 'label_transform' are passed in the kwargs.
+    If so, they are automatically composed with the resize transform, so that all transforms are applied together.
+    """
+    if resize_inputs:
+        assert isinstance(resize_kwargs, dict)
+
+        target_shape = resize_kwargs.get("patch_shape")
+        if len(resize_kwargs["patch_shape"]) == 3:
+            # we only need the XY dimensions to reshape the inputs along them.
+            target_shape = target_shape[1:]
+            # we provide the Z dimension value to return the desired number of slices and not the whole volume
+            kwargs["z_ext"] = resize_kwargs["patch_shape"][0]
+
+        raw_trafo = ResizeLongestSideInputs(target_shape=target_shape, is_rgb=resize_kwargs["is_rgb"])
+        label_trafo = ResizeLongestSideInputs(target_shape=target_shape, is_label=True)
+
+        # The patch shape provided to the dataset. Here, "None" means that the entire volume will be loaded.
+        patch_shape = None
+
+        if ensure_rgb is None:
+            raw_trafos = []
+        else:
+            assert not isinstance(ensure_rgb, bool), "'ensure_rgb' is expected to be a function."
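+            # e.g. a function that stacks a grayscale image to three channels; it runs before the resize transform.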
+ raw_trafos = [ensure_rgb] + + if "raw_transform" in kwargs: + raw_trafos.extend([raw_trafo, kwargs["raw_transform"]]) + else: + raw_trafos.extend([raw_trafo, get_raw_transform()]) + + kwargs["raw_transform"] = Compose(*raw_trafos, is_multi_tensor=False) + + if "label_transform" in kwargs: + trafo = Compose(label_trafo, kwargs["label_transform"], is_multi_tensor=False) + kwargs["label_transform"] = trafo + else: + kwargs["label_transform"] = label_trafo + + return kwargs, patch_shape + + def generate_labeled_array_from_xml(shape, xml_file): """Function taken from: https://github.com/rshwndsz/hover-net/blob/master/lightning_hovernet.ipynb @@ -278,3 +408,21 @@ def convert_svs_to_array(path, location=(0, 0), level=0, img_size=None): img_arr = _slide.read_region(location=location, level=level, size=img_size, as_array=True) return img_arr + + +def download_from_cryo_et_portal(path, dataset_id, download): + if Client is None or Dataset is None: + raise RuntimeError("Please install CryoETDataPortal via 'pip install cryoet-data-portal'") + + output_path = os.path.join(path, str(dataset_id)) + if os.path.exists(output_path): + return output_path + + if not download: + raise RuntimeError(f"Cannot find the data at {path}, but download was set to False") + + client = Client() + dataset = Dataset.get_by_id(client, dataset_id) + dataset.download_everything(dest_path=path) + + return output_path diff --git a/torch_em/data/datasets/vnc.py b/torch_em/data/datasets/vnc.py deleted file mode 100644 index 676dc132..00000000 --- a/torch_em/data/datasets/vnc.py +++ /dev/null @@ -1,102 +0,0 @@ -import os -from glob import glob -from shutil import rmtree - -import imageio -import h5py -import numpy as np -import torch_em -from skimage.measure import label - -from . import util - -URL = "https://github.com/unidesigner/groundtruth-drosophila-vnc/archive/refs/heads/master.zip" -CHECKSUM = "f7bd0db03c86b64440a16b60360ad60c0a4411f89e2c021c7ee2c8d6af3d7e86" - - -def _create_volume(f, key, pattern, process=None): - images = glob(pattern) - images.sort() - data = np.concatenate([imageio.imread(im)[None] for im in images], axis=0) - if process is not None: - data = process(data) - f.create_dataset(key, data=data, compression="gzip") - - -def _get_vnc_data(path, download): - train_path = os.path.join(path, "vnc_train.h5") - test_path = os.path.join(path, "vnc_test.h5") - if os.path.exists(train_path) and os.path.exists(test_path): - return - - os.makedirs(path, exist_ok=True) - zip_path = os.path.join(path, "vnc.zip") - util.download_source(zip_path, URL, download, CHECKSUM) - util.unzip(zip_path, path, remove=True) - - root = os.path.join(path, "groundtruth-drosophila-vnc-master") - assert os.path.exists(root) - - with h5py.File(train_path, "w") as f: - _create_volume(f, "raw", os.path.join(root, "stack1", "raw", "*.tif")) - _create_volume(f, "labels/mitochondria", os.path.join(root, "stack1", "mitochondria", "*.png"), process=label) - _create_volume(f, "labels/synapses", os.path.join(root, "stack1", "synapses", "*.png"), process=label) - # TODO find the post-processing to go from neuron labels to membrane labels - # _create_volume(f, "labels/neurons", os.path.join(root, "stack1", "membranes", "*.png")) - - with h5py.File(test_path, "w") as f: - _create_volume(f, "raw", os.path.join(root, "stack2", "raw", "*.tif")) - - rmtree(root) - - -def get_vnc_mito_dataset( - path, - patch_shape, - offsets=None, - boundaries=False, - binary=False, - download=False, - **kwargs -): - """Dataset for the segmentation of mitochondria in 
EM. - - This dataset is from https://doi.org/10.6084/m9.figshare.856713.v1. - Please cite it if you use this dataset for a publication. - """ - _get_vnc_data(path, download) - data_path = os.path.join(path, "vnc_train.h5") - - kwargs, _ = util.add_instance_label_transform( - kwargs, add_binary_target=True, boundaries=boundaries, offsets=offsets, binary=binary, - ) - - raw_key = "raw" - label_key = "labels/mitochondria" - return torch_em.default_segmentation_dataset(data_path, raw_key, data_path, label_key, patch_shape, **kwargs) - - -def get_vnc_mito_loader( - path, - patch_shape, - batch_size, - offsets=None, - boundaries=False, - binary=False, - download=False, - **kwargs -): - """Dataloader for the segmentation of mitochondria in EM. See 'get_vnc_mito_loader'.""" - ds_kwargs, loader_kwargs = util.split_kwargs( - torch_em.default_segmentation_dataset, **kwargs - ) - ds = get_vnc_mito_dataset( - path, patch_shape, download=download, offsets=offsets, boundaries=boundaries, binary=binary, **kwargs - ) - return torch_em.get_data_loader(ds, batch_size=batch_size, **loader_kwargs) - - -# TODO implement -# TODO extra kwargs for binary / boundaries / affinities -def get_vnc_neuron_loader(path, patch_shape, download=False, **kwargs): - raise NotImplementedError diff --git a/torch_em/data/image_collection_dataset.py b/torch_em/data/image_collection_dataset.py index e005d15a..2b66ac77 100644 --- a/torch_em/data/image_collection_dataset.py +++ b/torch_em/data/image_collection_dataset.py @@ -4,8 +4,9 @@ import torch -from ..util import (ensure_spatial_array, ensure_tensor_with_channels, - load_image, supports_memmap) +from ..util import ( + ensure_spatial_array, ensure_tensor_with_channels, load_image, supports_memmap, ensure_patch_shape +) class ImageCollectionDataset(torch.utils.data.Dataset): @@ -62,13 +63,15 @@ def __init__( n_samples: Optional[int] = None, sampler=None, full_check: bool = False, + with_padding=True, ): self._check_inputs(raw_image_paths, label_image_paths, full_check=full_check) self.raw_images = raw_image_paths self.label_images = label_image_paths self._ndim = 2 - assert len(patch_shape) == self._ndim + if patch_shape is not None: + assert len(patch_shape) == self._ndim self.patch_shape = patch_shape self.raw_transform = raw_transform @@ -76,6 +79,7 @@ def __init__( self.label_transform2 = label_transform2 self.transform = transform self.sampler = sampler + self.with_padding = with_padding self.dtype = dtype self.label_dtype = label_dtype @@ -95,30 +99,17 @@ def ndim(self): return self._ndim def _sample_bounding_box(self, shape): - bb_start = [ - np.random.randint(0, sh - psh) if sh - psh > 0 else 0 - for sh, psh in zip(shape, self.patch_shape) - ] - return tuple(slice(start, start + psh) for start, psh in zip(bb_start, self.patch_shape)) - - def _ensure_patch_shape(self, raw, labels, have_raw_channels, have_label_channels, channel_first): - shape = raw.shape - if have_raw_channels and channel_first: - shape = shape[1:] - if any(sh < psh for sh, psh in zip(shape, self.patch_shape)): - pw = [(0, max(0, psh - sh)) for sh, psh in zip(shape, self.patch_shape)] - - if have_raw_channels and channel_first: - pw_raw = [(0, 0), *pw] - elif have_raw_channels and not channel_first: - pw_raw = [*pw, (0, 0)] - else: - pw_raw = pw - - # TODO: ensure padding for labels with channels, when supported (see `_get_sample` below) + if self.patch_shape is None: + patch_shape_for_bb = shape + bb_start = [0] * len(shape) + else: + patch_shape_for_bb = self.patch_shape + bb_start = [ + 
np.random.randint(0, sh - psh) if sh - psh > 0 else 0 + for sh, psh in zip(shape, patch_shape_for_bb) + ] - raw, labels = np.pad(raw, pw_raw), np.pad(labels, pw) - return raw, labels + return tuple(slice(start, start + psh) for start, psh in zip(bb_start, patch_shape_for_bb)) def _load_data(self, raw_path, label_path): raw = load_image(raw_path, memmap=False) @@ -137,7 +128,16 @@ def _load_data(self, raw_path, label_path): if have_raw_channels: channel_first = raw.shape[-1] > 16 - raw, label = self._ensure_patch_shape(raw, label, have_raw_channels, have_label_channels, channel_first) + if self.patch_shape is not None and self.with_padding: + raw, label = ensure_patch_shape( + raw=raw, + labels=label, + patch_shape=self.patch_shape, + have_raw_channels=have_raw_channels, + have_label_channels=have_label_channels, + channel_first=channel_first + ) + shape = raw.shape prefix_box = tuple() @@ -173,7 +173,7 @@ def _get_sample(self, index): label_patch = np.array(label[bb]) sample_id += 1 - # We need to avoid sampling from the same image over and over agagin, + # We need to avoid sampling from the same image over and over again, # otherwise this will fail just because of one or a few empty images. # Hence we update the image from which we sample sometimes. if sample_id % self.max_sampling_attempts_image == 0: diff --git a/torch_em/data/raw_dataset.py b/torch_em/data/raw_dataset.py index ec410718..402659d1 100644 --- a/torch_em/data/raw_dataset.py +++ b/torch_em/data/raw_dataset.py @@ -7,7 +7,7 @@ from elf.wrapper import RoiWrapper -from ..util import ensure_tensor_with_channels, load_data +from ..util import ensure_tensor_with_channels, ensure_patch_shape, load_data class RawDataset(torch.utils.data.Dataset): @@ -107,7 +107,13 @@ def _get_sample(self, index): sample_id += 1 if sample_id > self.max_sampling_attempts: raise RuntimeError(f"Could not sample a valid batch in {self.max_sampling_attempts} attempts") - + if self.patch_shape is not None: + raw = ensure_patch_shape( + raw=raw, + labels=None, + patch_shape=self.patch_shape, + have_raw_channels=self._with_channels + ) # squeeze the singleton spatial axis if we have a spatial shape that is larger by one than self._ndim if len(self.patch_shape) == self._ndim + 1: raw = raw.squeeze(1 if self._with_channels else 0) diff --git a/torch_em/data/sampler.py b/torch_em/data/sampler.py index 7a47fb56..930aad63 100644 --- a/torch_em/data/sampler.py +++ b/torch_em/data/sampler.py @@ -1,5 +1,5 @@ import numpy as np -from typing import List +from typing import List, Optional class MinForegroundSampler: @@ -85,13 +85,18 @@ class MinInstanceSampler: def __init__( self, min_num_instances: int = 2, - p_reject: float = 1.0 + p_reject: float = 1.0, + min_size: Optional[int] = None ): self.min_num_instances = min_num_instances self.p_reject = p_reject + self.min_size = min_size def __call__(self, x, y): - uniques = np.unique(y) + uniques, sizes = np.unique(y, return_counts=True) + if self.min_size is not None: + filter_ids = uniques[sizes >= self.min_size] + uniques = filter_ids if len(uniques) >= self.min_num_instances: return True else: diff --git a/torch_em/data/segmentation_dataset.py b/torch_em/data/segmentation_dataset.py index 275a04ce..821b5829 100644 --- a/torch_em/data/segmentation_dataset.py +++ b/torch_em/data/segmentation_dataset.py @@ -7,7 +7,7 @@ from elf.wrapper import RoiWrapper -from ..util import ensure_spatial_array, ensure_tensor_with_channels, load_data +from ..util import ensure_spatial_array, ensure_tensor_with_channels, load_data, 
ensure_patch_shape class SegmentationDataset(torch.utils.data.Dataset): @@ -17,8 +17,11 @@ class SegmentationDataset(torch.utils.data.Dataset): @staticmethod def compute_len(shape, patch_shape): - n_samples = int(np.prod([float(sh / csh) for sh, csh in zip(shape, patch_shape)])) - return n_samples + if patch_shape is None: + return 1 + else: + n_samples = int(np.prod([float(sh / csh) for sh, csh in zip(shape, patch_shape)])) + return n_samples def __init__( self, @@ -39,6 +42,8 @@ def __init__( ndim: Optional[int] = None, with_channels: bool = False, with_label_channels: bool = False, + with_padding: bool = True, + z_ext: Optional[int] = None, ): self.raw_path = raw_path self.raw_key = raw_key @@ -67,7 +72,10 @@ def __init__( self._ndim = len(shape_raw) if ndim is None else ndim assert self._ndim in (2, 3, 4), f"Invalid data dimensions: {self._ndim}. Only 2d, 3d or 4d data is supported" - assert len(patch_shape) in (self._ndim, self._ndim + 1), f"{patch_shape}, {self._ndim}" + + if patch_shape is not None: + assert len(patch_shape) in (self._ndim, self._ndim + 1), f"{patch_shape}, {self._ndim}" + self.patch_shape = patch_shape self.raw_transform = raw_transform @@ -75,12 +83,15 @@ def __init__( self.label_transform2 = label_transform2 self.transform = transform self.sampler = sampler + self.with_padding = with_padding self.dtype = dtype self.label_dtype = label_dtype self._len = self.compute_len(self.shape, self.patch_shape) if n_samples is None else n_samples + self.z_ext = z_ext + self.sample_shape = patch_shape self.trafo_halo = None # TODO add support for trafo halo: asking for a bigger bounding box before applying the trafo, @@ -102,33 +113,56 @@ def ndim(self): return self._ndim def _sample_bounding_box(self): - bb_start = [ - np.random.randint(0, sh - psh) if sh - psh > 0 else 0 - for sh, psh in zip(self.shape, self.sample_shape) - ] - return tuple(slice(start, start + psh) for start, psh in zip(bb_start, self.sample_shape)) - - def _get_sample(self, index): - if self.raw is None or self.labels is None: - raise RuntimeError("SegmentationDataset has not been properly deserialized.") + if self.sample_shape is None: + if self.z_ext is None: + bb_start = [0] * len(self.shape) + patch_shape_for_bb = self.shape + else: + z_diff = self.shape[0] - self.z_ext + bb_start = [np.random.randint(0, z_diff) if z_diff > 0 else 0] + [0] * len(self.shape[1:]) + patch_shape_for_bb = (self.z_ext, *self.shape[1:]) + + else: + bb_start = [ + np.random.randint(0, sh - psh) if sh - psh > 0 else 0 + for sh, psh in zip(self.shape, self.sample_shape) + ] + patch_shape_for_bb = self.sample_shape + + return tuple(slice(start, start + psh) for start, psh in zip(bb_start, patch_shape_for_bb)) + + def _get_desired_raw_and_labels(self): bb = self._sample_bounding_box() bb_raw = (slice(None),) + bb if self._with_channels else bb bb_labels = (slice(None),) + bb if self._with_label_channels else bb raw, labels = self.raw[bb_raw], self.labels[bb_labels] + return raw, labels + + def _get_sample(self, index): + if self.raw is None or self.labels is None: + raise RuntimeError("SegmentationDataset has not been properly deserialized.") + + raw, labels = self._get_desired_raw_and_labels() if self.sampler is not None: sample_id = 0 while not self.sampler(raw, labels): - bb = self._sample_bounding_box() - bb_raw = (slice(None),) + bb if self._with_channels else bb - bb_labels = (slice(None),) + bb if self._with_label_channels else bb - raw, labels = self.raw[bb_raw], self.labels[bb_labels] + raw, labels = 
self._get_desired_raw_and_labels() sample_id += 1 if sample_id > self.max_sampling_attempts: raise RuntimeError(f"Could not sample a valid batch in {self.max_sampling_attempts} attempts") + # Padding the patch to match the expected input shape. + if self.patch_shape is not None and self.with_padding: + raw, labels = ensure_patch_shape( + raw=raw, + labels=labels, + patch_shape=self.patch_shape, + have_raw_channels=self._with_channels, + have_label_channels=self._with_label_channels, + ) # squeeze the singleton spatial axis if we have a spatial shape that is larger by one than self._ndim - if len(self.patch_shape) == self._ndim + 1: + if self.patch_shape is not None and len(self.patch_shape) == self._ndim + 1: raw = raw.squeeze(1 if self._with_channels else 0) labels = labels.squeeze(1 if self._with_label_channels else 0) @@ -184,7 +218,7 @@ def __setstate__(self, state): except Exception: msg = f"SegmentationDataset could not be deserialized because of missing {raw_path}, {raw_key}.\n" msg += "The dataset is deserialized in order to allow loading trained models from a checkpoint.\n" - msg += "But it cannot be used for further training and wil throw an error." + msg += "But it cannot be used for further training and will throw an error." warnings.warn(msg) state["raw"] = None @@ -197,7 +231,7 @@ def __setstate__(self, state): except Exception: msg = f"SegmentationDataset could not be deserialized because of missing {label_path}, {label_key}.\n" msg += "The dataset is deserialized in order to allow loading trained models from a checkpoint.\n" - msg += "But it cannot be used for further training and wil throw an error." + msg += "But it cannot be used for further training and will throw an error." warnings.warn(msg) state["labels"] = None diff --git a/torch_em/model/unet.py b/torch_em/model/unet.py index fdc45f3b..fc6c255d 100644 --- a/torch_em/model/unet.py +++ b/torch_em/model/unet.py @@ -353,10 +353,10 @@ def get_norm_layer(norm, dim, channels, n_groups=32): if norm is None: return None if norm == "InstanceNorm": + return nn.InstanceNorm2d(channels) if dim == 2 else nn.InstanceNorm3d(channels) + elif norm == "InstanceNormTrackStats": kwargs = {"affine": True, "track_running_stats": True, "momentum": 0.01} return nn.InstanceNorm2d(channels, **kwargs) if dim == 2 else nn.InstanceNorm3d(channels, **kwargs) - elif norm == "OldDefault": - return nn.InstanceNorm2d(channels) if dim == 2 else nn.InstanceNorm3d(channels) elif norm == "GroupNorm": return nn.GroupNorm(min(n_groups, channels), channels) elif norm == "BatchNorm": diff --git a/torch_em/model/unetr.py b/torch_em/model/unetr.py index d86f5357..06c225c2 100644 --- a/torch_em/model/unetr.py +++ b/torch_em/model/unetr.py @@ -127,7 +127,6 @@ def __init__( scale_factors=scale_factors[::-1], conv_block_impl=ConvBlock2d, sampler_impl=_upsampler, - norm="OldDefault", ) else: self.decoder = decoder @@ -143,14 +142,14 @@ def __init__( Deconv2DBlock(features_decoder[0], features_decoder[1]), Deconv2DBlock(features_decoder[1], features_decoder[2]) ) - self.deconv4 = ConvBlock2d(in_chans, features_decoder[-1], norm="OldDefault") + self.deconv4 = ConvBlock2d(in_chans, features_decoder[-1]) else: self.deconv1 = Deconv2DBlock(embed_dim, features_decoder[0]) self.deconv2 = Deconv2DBlock(features_decoder[0], features_decoder[1]) self.deconv3 = Deconv2DBlock(features_decoder[1], features_decoder[2]) self.deconv4 = Deconv2DBlock(features_decoder[2], features_decoder[3]) - self.base = ConvBlock2d(embed_dim, features_decoder[0], norm="OldDefault") + self.base 
= ConvBlock2d(embed_dim, features_decoder[0])
         self.out_conv = nn.Conv2d(features_decoder[-1], out_channels, 1)
@@ -158,7 +157,7 @@ def __init__(
             scale_factor=2, in_channels=features_decoder[-1], out_channels=features_decoder[-1]
         )
-        self.decoder_head = ConvBlock2d(2 * features_decoder[-1], features_decoder[-1], norm="OldDefault")
+        self.decoder_head = ConvBlock2d(2 * features_decoder[-1], features_decoder[-1])
         self.final_activation = self._get_activation(final_activation)
diff --git a/torch_em/multi_gpu_training.py b/torch_em/multi_gpu_training.py
new file mode 100644
index 00000000..ae8765d2
--- /dev/null
+++ b/torch_em/multi_gpu_training.py
@@ -0,0 +1,141 @@
+import os
+from functools import partial
+
+import torch
+import torch_em
+import torch.distributed as dist
+from torch.nn.parallel import DistributedDataParallel
+
+
+def setup(rank, world_size):
+    os.environ["MASTER_ADDR"] = "localhost"
+    os.environ["MASTER_PORT"] = "12355"
+    dist.init_process_group("nccl", rank=rank, world_size=world_size)
+
+
+def cleanup():
+    dist.destroy_process_group()
+
+
+def _create_data_loader(ds_callable, ds_kwargs, loader_kwargs, world_size, rank):
+    # Create the dataset.
+    ds = ds_callable(**ds_kwargs)
+
+    # Create the sampler; shuffling is set on the sampler instead of on the loader.
+    shuffle = loader_kwargs.pop("shuffle", False)
+    sampler = torch.utils.data.distributed.DistributedSampler(ds, num_replicas=world_size, rank=rank, shuffle=shuffle)
+
+    # Create the loader.
+    loader = torch.utils.data.DataLoader(ds, sampler=sampler, **loader_kwargs)
+    loader.shuffle = shuffle
+
+    return loader
+
+
+class DDP(DistributedDataParallel):
+    """Wrapper around the DistributedDataParallel class that overrides the `__getattr__` method,
+    so that attribute access is forwarded to the model (`self.module`) wrapped by DDP.
+    """
+    def __getattr__(self, name):
+        try:
+            return super().__getattr__(name)
+        except AttributeError:
+            return getattr(self.module, name)
+
+
+def _train_impl(
+    rank,
+    world_size,
+    model_callable,
+    model_kwargs,
+    train_dataset_callable,
+    train_dataset_kwargs,
+    val_dataset_callable,
+    val_dataset_kwargs,
+    loader_kwargs,
+    iterations,
+    find_unused_parameters=True,
+    optimizer_callable=None,
+    optimizer_kwargs=None,
+    lr_scheduler_callable=None,
+    lr_scheduler_kwargs=None,
+    trainer_callable=None,
+    **kwargs
+):
+    assert "device" not in kwargs
+    print(f"Running DDP on rank {rank}.")
+    setup(rank, world_size)
+
+    model = model_callable(**model_kwargs).to(rank)
+    ddp_model = DDP(model, device_ids=[rank], find_unused_parameters=find_unused_parameters)
+
+    if optimizer_callable is not None:
+        optimizer = optimizer_callable(model.parameters(), **optimizer_kwargs)
+        kwargs["optimizer"] = optimizer
+    if lr_scheduler_callable is not None:
+        lr_scheduler = lr_scheduler_callable(optimizer, **lr_scheduler_kwargs)
+        kwargs["lr_scheduler"] = lr_scheduler
+
+    train_loader = _create_data_loader(train_dataset_callable, train_dataset_kwargs, loader_kwargs, world_size, rank)
+    val_loader = _create_data_loader(val_dataset_callable, val_dataset_kwargs, loader_kwargs, world_size, rank)
+
+    if trainer_callable is None:
+        trainer_callable = torch_em.default_segmentation_trainer
+
+    trainer = trainer_callable(
+        model=ddp_model,
+        train_loader=train_loader,
+        val_loader=val_loader,
+        device=rank,
+        rank=rank,
+        **kwargs
+    )
+    trainer.fit(iterations=iterations)
+
+    cleanup()
+
+
+def train_multi_gpu(
+    model_callable,
+    model_kwargs,
+    train_dataset_callable,
+    train_dataset_kwargs,
+    val_dataset_callable,
+    val_dataset_kwargs,
+    loader_kwargs,
+    iterations,
+    find_unused_parameters=True,
+    optimizer_callable=None,
+    optimizer_kwargs=None,
+    lr_scheduler_callable=None,
+    lr_scheduler_kwargs=None,
+    trainer_callable=None,
+    **kwargs
+) -> None:
+    """Run training with distributed data parallelism on all locally available GPUs.
+
+    Args:
+        model_callable: Callable that returns the PyTorch model to be trained.
+        kwargs: Keyword arguments for `torch_em.segmentation.default_segmentation_trainer`.
+ """ + world_size = torch.cuda.device_count() + train = partial( + _train_impl, + model_callable=model_callable, + model_kwargs=model_kwargs, + train_dataset_callable=train_dataset_callable, + train_dataset_kwargs=train_dataset_kwargs, + val_dataset_callable=val_dataset_callable, + val_dataset_kwargs=val_dataset_kwargs, + loader_kwargs=loader_kwargs, + iterations=iterations, + find_unused_parameters=find_unused_parameters, + optimizer_callable=optimizer_callable, + optimizer_kwargs=optimizer_kwargs, + lr_scheduler_callable=lr_scheduler_callable, + lr_scheduler_kwargs=lr_scheduler_kwargs, + trainer_callable=trainer_callable, + **kwargs + ) + torch.multiprocessing.spawn(train, args=(world_size,), nprocs=world_size, join=True) diff --git a/torch_em/segmentation.py b/torch_em/segmentation.py index 9ea62efc..dccab67e 100644 --- a/torch_em/segmentation.py +++ b/torch_em/segmentation.py @@ -138,11 +138,12 @@ def _get_paths(rpath, rkey, lpath, lkey, this_roi): return rpath, lpath patch_shape = kwargs.pop("patch_shape") - if len(patch_shape) == 3: - if patch_shape[0] != 1: - raise ValueError(f"Image collection dataset expects 2d patch shape, got {patch_shape}") - patch_shape = patch_shape[1:] - assert len(patch_shape) == 2 + if patch_shape is not None: + if len(patch_shape) == 3: + if patch_shape[0] != 1: + raise ValueError(f"Image collection dataset expects 2d patch shape, got {patch_shape}") + patch_shape = patch_shape[1:] + assert len(patch_shape) == 2 if isinstance(raw_paths, str): raw_paths, label_paths = _get_paths(raw_paths, raw_key, label_paths, label_key, roi) @@ -205,6 +206,7 @@ def default_segmentation_loader( is_seg_dataset=None, with_channels=False, with_label_channels=False, + verify_paths=True, **loader_kwargs, ): ds = default_segmentation_dataset( @@ -226,6 +228,7 @@ def default_segmentation_loader( is_seg_dataset=is_seg_dataset, with_channels=with_channels, with_label_channels=with_label_channels, + verify_paths=verify_paths, ) return get_data_loader(ds, batch_size=batch_size, **loader_kwargs) @@ -249,8 +252,13 @@ def default_segmentation_dataset( is_seg_dataset=None, with_channels=False, with_label_channels=False, + verify_paths=True, + with_padding=True, + z_ext=None, ): - check_paths(raw_paths, label_paths) + if verify_paths: + check_paths(raw_paths, label_paths) + if is_seg_dataset is None: is_seg_dataset = is_segmentation_dataset(raw_paths, raw_key, label_paths, label_key) @@ -283,6 +291,8 @@ def default_segmentation_dataset( label_dtype=label_dtype, with_channels=with_channels, with_label_channels=with_label_channels, + with_padding=with_padding, + z_ext=z_ext, ) else: ds = _load_image_collection_dataset( @@ -300,13 +310,15 @@ def default_segmentation_dataset( sampler=sampler, dtype=dtype, label_dtype=label_dtype, + with_padding=with_padding, ) return ds -def get_data_loader(dataset: torch.utils.data.Dataset, batch_size, **loader_kwargs) -> torch.utils.data.DataLoader: - loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, **loader_kwargs) +def get_data_loader(dataset: torch.utils.data.Dataset, batch_size: int, **loader_kwargs) -> torch.utils.data.DataLoader: + pin_memory = loader_kwargs.pop("pin_memory", True) + loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, pin_memory=pin_memory, **loader_kwargs) # monkey patch shuffle attribute to the loader loader.shuffle = loader_kwargs.get("shuffle", False) return loader @@ -337,8 +349,9 @@ def default_segmentation_trainer( id_=None, save_root=None, compile_model=None, + rank=None, ): - optimizer = 
torch.optim.Adam(model.parameters(), lr=learning_rate, **optimizer_kwargs) + optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, **optimizer_kwargs) scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, **scheduler_kwargs) loss = DiceLoss() if loss is None else loss @@ -371,5 +384,6 @@ def default_segmentation_trainer( id_=id_, save_root=save_root, compile_model=compile_model, + rank=rank, ) return trainer diff --git a/torch_em/self_training/fix_match.py b/torch_em/self_training/fix_match.py index c92a7176..03ce7b12 100644 --- a/torch_em/self_training/fix_match.py +++ b/torch_em/self_training/fix_match.py @@ -183,7 +183,7 @@ def _train_epoch_unsupervised(self, progress, forward_context, backprop): # Sample from both the supervised and unsupervised loader. for xu1, xu2 in self.unsupervised_train_loader: - xu1, xu2 = xu1.to(self.device), xu2.to(self.device) + xu1, xu2 = xu1.to(self.device, non_blocking=True), xu2.to(self.device, non_blocking=True) teacher_input, model_input = xu1, xu2 @@ -231,8 +231,8 @@ def _train_epoch_semisupervised(self, progress, forward_context, backprop): # Sample from both the supervised and unsupervised loader. for (xs, ys), (xu1, xu2) in zip(self.supervised_train_loader, self.unsupervised_train_loader): - xs, ys = xs.to(self.device), ys.to(self.device) - xu1, xu2 = xu1.to(self.device), xu2.to(self.device) + xs, ys = xs.to(self.device, non_blocking=True), ys.to(self.device, non_blocking=True) + xu1, xu2 = xu1.to(self.device, non_blocking=True), xu2.to(self.device, non_blocking=True) # Perform supervised training. self.optimizer.zero_grad() @@ -289,7 +289,7 @@ def _validate_supervised(self, forward_context): loss_val = 0.0 for x, y in self.supervised_val_loader: - x, y = x.to(self.device), y.to(self.device) + x, y = x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) with forward_context(): loss, metric = self.supervised_loss_and_metric(self.model, x, y) loss_val += loss.item() @@ -310,7 +310,7 @@ def _validate_unsupervised(self, forward_context): loss_val = 0.0 for x1, x2 in self.unsupervised_val_loader: - x1, x2 = x1.to(self.device), x2.to(self.device) + x1, x2 = x1.to(self.device, non_blocking=True), x2.to(self.device, non_blocking=True) teacher_input, model_input = x1, x2 with forward_context(): pseudo_labels, label_filter = self.pseudo_labeler(self.model, teacher_input) diff --git a/torch_em/self_training/mean_teacher.py b/torch_em/self_training/mean_teacher.py index 3820ae2f..03a32562 100644 --- a/torch_em/self_training/mean_teacher.py +++ b/torch_em/self_training/mean_teacher.py @@ -213,7 +213,7 @@ def _train_epoch_unsupervised(self, progress, forward_context, backprop): # Sample from both the supervised and unsupervised loader. for xu1, xu2 in self.unsupervised_train_loader: - xu1, xu2 = xu1.to(self.device), xu2.to(self.device) + xu1, xu2 = xu1.to(self.device, non_blocking=True), xu2.to(self.device, non_blocking=True) teacher_input, model_input = xu1, xu2 @@ -256,8 +256,8 @@ def _train_epoch_semisupervised(self, progress, forward_context, backprop): # Sample from both the supervised and unsupervised loader. for (xs, ys), (xu1, xu2) in zip(self.supervised_train_loader, self.unsupervised_train_loader): - xs, ys = xs.to(self.device), ys.to(self.device) - xu1, xu2 = xu1.to(self.device), xu2.to(self.device) + xs, ys = xs.to(self.device, non_blocking=True), ys.to(self.device, non_blocking=True) + xu1, xu2 = xu1.to(self.device, non_blocking=True), xu2.to(self.device, non_blocking=True) # Perform supervised training. 
self.optimizer.zero_grad() @@ -310,7 +310,7 @@ def _validate_supervised(self, forward_context): loss_val = 0.0 for x, y in self.supervised_val_loader: - x, y = x.to(self.device), y.to(self.device) + x, y = x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) with forward_context(): loss, metric = self.supervised_loss_and_metric(self.model, x, y) loss_val += loss.item() @@ -331,7 +331,7 @@ def _validate_unsupervised(self, forward_context): loss_val = 0.0 for x1, x2 in self.unsupervised_val_loader: - x1, x2 = x1.to(self.device), x2.to(self.device) + x1, x2 = x1.to(self.device, non_blocking=True), x2.to(self.device, non_blocking=True) teacher_input, model_input = x1, x2 with forward_context(): pseudo_labels, label_filter = self.pseudo_labeler(self.teacher, teacher_input) diff --git a/torch_em/self_training/probabilistic_unet_trainer.py b/torch_em/self_training/probabilistic_unet_trainer.py index 2af1f1e7..e0941dae 100644 --- a/torch_em/self_training/probabilistic_unet_trainer.py +++ b/torch_em/self_training/probabilistic_unet_trainer.py @@ -59,7 +59,7 @@ def _train_epoch_impl(self, progress, forward_context, backprop): t_per_iter = time.time() for x, y in self.train_loader: - x, y = x.to(self.device), y.to(self.device) + x, y = x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) self.optimizer.zero_grad() @@ -96,7 +96,7 @@ def _validate_impl(self, forward_context): with torch.no_grad(): for x, y in self.val_loader: - x, y = x.to(self.device), y.to(self.device) + x, y = x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) with forward_context(): loss, metric = self.loss_and_metric(self.model, x, y) diff --git a/torch_em/trainer/default_trainer.py b/torch_em/trainer/default_trainer.py index 06f9582b..b07c83d5 100644 --- a/torch_em/trainer/default_trainer.py +++ b/torch_em/trainer/default_trainer.py @@ -6,12 +6,12 @@ import time import warnings from collections import OrderedDict +from functools import partial from importlib import import_module from typing import Any, Callable, Dict, Optional, Union import numpy as np import torch -import torch.cuda.amp as amp from tqdm import tqdm from .tensorboard_logger import TensorboardLogger @@ -41,6 +41,7 @@ def __init__( id_: Optional[str] = None, save_root: Optional[str] = None, compile_model: Optional[Union[bool, str]] = None, + rank: Optional[int] = None, ): if name is None and not issubclass(logger, WandbLogger): raise TypeError("Name cannot be None if not using the WandbLogger") @@ -57,11 +58,12 @@ def __init__( self.loss = loss self.optimizer = optimizer self.metric = metric - self.device = device + self.device = torch.device(device) self.lr_scheduler = lr_scheduler self.log_image_interval = log_image_interval self.save_root = save_root self.compile_model = compile_model + self.rank = rank self._iteration = 0 self._epoch = 0 @@ -71,7 +73,10 @@ def __init__( self.early_stopping = early_stopping self.train_time = 0.0 - self.scaler = amp.GradScaler() if mixed_precision else None + if mixed_precision: + self.scaler = torch.GradScaler("cpu" if self.device.type == "cpu" else "cuda") + else: + self.scaler = None self.logger_class = logger self.logger_kwargs = logger_kwargs @@ -477,7 +482,10 @@ def save_checkpoint(self, name, current_metric, best_metric, train_time=0.0, **e save_dict.update({"scaler_state": self.scaler.state_dict()}) if self.lr_scheduler is not None: save_dict.update({"scheduler_state": self.lr_scheduler.state_dict()}) - torch.save(save_dict, save_path) + + rank = 
getattr(self, "rank", None) + if rank is None or rank == 0: + torch.save(save_dict, save_path) def load_checkpoint(self, checkpoint="best"): if isinstance(checkpoint, str): @@ -566,9 +574,15 @@ def fit( msg = "Epoch %i: average [s/it]: %f, current metric: %f, best metric: %f" train_epochs = self.max_epoch - self._epoch t_start = time.time() - for _ in range(train_epochs): + for epoch in range(train_epochs): - # run training and validation for this epoch + # Ensure data is shuffled differently at each epoch. + try: + self.train_loader.sampler.set_epoch(epoch) + except AttributeError: + pass + + # Run training and validation for this epoch t_per_iter = train_epoch(progress) current_metric = validate() @@ -633,7 +647,10 @@ def _train_epoch(self, progress): return self._train_epoch_impl(progress, contextlib.nullcontext, self._backprop) def _train_epoch_mixed(self, progress): - return self._train_epoch_impl(progress, amp.autocast, self._backprop_mixed) + return self._train_epoch_impl( + progress, partial(torch.autocast, device_type="cpu" if self.device.type == "cpu" else "cuda"), + self._backprop_mixed + ) def _forward_and_loss(self, x, y): pred = self.model(x) @@ -650,7 +667,7 @@ def _train_epoch_impl(self, progress, forward_context, backprop: Callable[[torch n_iter = 0 t_per_iter = time.time() for x, y in self.train_loader: - x, y = x.to(self.device), y.to(self.device) + x, y = x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) self.optimizer.zero_grad() @@ -676,7 +693,9 @@ def _validate(self): return self._validate_impl(contextlib.nullcontext) def _validate_mixed(self): - return self._validate_impl(amp.autocast) + return self._validate_impl( + partial(torch.autocast, device_type="cpu" if self.device.type == "cpu" else "cuda") + ) def _validate_impl(self, forward_context): self.model.eval() @@ -686,7 +705,7 @@ def _validate_impl(self, forward_context): with torch.no_grad(): for x, y in self.val_loader: - x, y = x.to(self.device), y.to(self.device) + x, y = x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) with forward_context(): pred, loss = self._forward_and_loss(x, y) metric = self.metric(pred, y) diff --git a/torch_em/trainer/spoco_trainer.py b/torch_em/trainer/spoco_trainer.py index 68c125f1..b2b5fdc8 100644 --- a/torch_em/trainer/spoco_trainer.py +++ b/torch_em/trainer/spoco_trainer.py @@ -56,7 +56,7 @@ def _train_epoch_semisupervised(self, progress, forward_context, backprop): ) for x in self.semisupervised_loader: - x = x.to(self.device) + x = x.to(self.device, non_blocking=True) self.optimizer.zero_grad() with forward_context(): @@ -76,7 +76,7 @@ def _train_epoch_impl(self, progress, forward_context, backprop): n_iter = 0 t_per_iter = time.time() for x, y in self.train_loader: - x, y = x.to(self.device), y.to(self.device) + x, y = x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) self.optimizer.zero_grad() @@ -120,7 +120,7 @@ def _validate_impl(self, forward_context): with torch.no_grad(): for x, y in self.val_loader: - x, y = x.to(self.device), y.to(self.device) + x, y = x.to(self.device, non_blocking=True), y.to(self.device, non_blocking=True) with forward_context(): prediction = self.model(x) prediction2 = self.model2(x) diff --git a/torch_em/transform/generic.py b/torch_em/transform/generic.py index b74c5be2..d1b1f49b 100644 --- a/torch_em/transform/generic.py +++ b/torch_em/transform/generic.py @@ -1,9 +1,10 @@ +from math import ceil, floor from typing import Any, Dict, Optional, Sequence, Union import 
numpy as np
-import torch
+from skimage.transform import rescale, resize
-from skimage.transform import rescale
+import torch
 class Tile(torch.nn.Module):
@@ -34,13 +35,18 @@ def forward(self, input: Union[torch.Tensor, np.ndarray], params: Optional[Dict[
 # a simple way to compose transforms
 class Compose:
-    def __init__(self, *transforms):
+    def __init__(self, *transforms, is_multi_tensor=True):
         self.transforms = transforms
+        self.is_multi_tensor = is_multi_tensor
 
     def __call__(self, *inputs):
        outputs = self.transforms[0](*inputs)
         for trafo in self.transforms[1:]:
-            outputs = trafo(*outputs)
+            if self.is_multi_tensor:
+                outputs = trafo(*outputs)
+            else:
+                outputs = trafo(outputs)
+
         return outputs
@@ -72,6 +78,112 @@ def __call__(self, *inputs):
         return outputs
+
+class ResizeInputs:
+    def __init__(self, target_shape, is_label=False, is_rgb=False):
+        self.target_shape = target_shape
+        self.is_label = is_label
+        self.is_rgb = is_rgb
+
+    def __call__(self, inputs):
+        if self.is_label:  # kwargs needed for int data
+            kwargs = {"order": 0, "anti_aliasing": False}
+        else:  # we use the default settings for float data
+            kwargs = {}
+
+        if self.is_rgb:
+            assert inputs.ndim == 3 and inputs.shape[0] == 3
+            patch_shape = (3, *self.target_shape)
+        else:
+            patch_shape = self.target_shape
+
+        inputs = resize(
+            image=inputs,
+            output_shape=patch_shape,
+            preserve_range=True,
+            **kwargs
+        ).astype(inputs.dtype)
+
+        return inputs
+
+
+class ResizeLongestSideInputs:
+    def __init__(self, target_shape, is_label=False, is_rgb=False):
+        self.target_shape = target_shape
+        self.is_label = is_label
+        self.is_rgb = is_rgb
+
+        h, w = self.target_shape[-2], self.target_shape[-1]
+        if h != w:  # we currently only support square-shaped target shapes.
+            raise ValueError("'ResizeLongestSideInputs' does not support non-square shaped target shapes.")
+
+        self.target_length = self.target_shape[-1]
+
+        if self.is_label:  # kwargs needed for int data
+            self.kwargs = {"order": 0, "anti_aliasing": False}
+        else:  # we use the default settings for float data
+            self.kwargs = {}
+
+    def _get_preprocess_shape(self, oldh, oldw):
+        """Inspired by Segment Anything.
+
+        - https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/utils/transforms.py
+        """
+        scale = self.target_length * 1.0 / max(oldh, oldw)
+        newh, neww = oldh * scale, oldw * scale
+        neww = int(neww + 0.5)
+        newh = int(newh + 0.5)
+        return (newh, neww)
+
+    def convert_transformed_inputs_to_original_shape(self, resized_inputs):
+        if not hasattr(self, "pre_pad_shape"):
+            raise RuntimeError(
+                "'convert_transformed_inputs_to_original_shape' is only valid after the '__call__' method has run."
+            )
+
+        # First step is to remove the padded region.
+        inputs = resized_inputs[tuple(self.pre_pad_shape)]
+        # Next, we resize the inputs to the original shape.
+        inputs = resize(
+            image=inputs, output_shape=self.original_shape, preserve_range=True, **self.kwargs
+        )
+        return inputs
+
+    def __call__(self, inputs):
+        # NOTE: We store this in case we would like to transform the inputs back to the original shape.
+        self.original_shape = inputs.shape
+
+        # Let's get the new shape with the longest side equal to the target length.
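+        # For example, an input of shape (512, 256) with target length 256 is mapped to (256, 128).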
+class ResizeLongestSideInputs:
+    def __init__(self, target_shape, is_label=False, is_rgb=False):
+        self.target_shape = target_shape
+        self.is_label = is_label
+        self.is_rgb = is_rgb
+
+        h, w = self.target_shape[-2], self.target_shape[-1]
+        if h != w:  # We currently support resizing only for square target shapes.
+            raise ValueError("'ResizeLongestSideInputs' does not support non-square shaped target shapes.")
+
+        self.target_length = self.target_shape[-1]
+
+        if self.is_label:  # kwargs needed for int data
+            self.kwargs = {"order": 0, "anti_aliasing": False}
+        else:  # we use the default settings for float data
+            self.kwargs = {}
+
+    def _get_preprocess_shape(self, oldh, oldw):
+        """Inspired by Segment Anything:
+
+        - https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/utils/transforms.py
+        """
+        scale = self.target_length * 1.0 / max(oldh, oldw)
+        newh, neww = oldh * scale, oldw * scale
+        neww = int(neww + 0.5)
+        newh = int(newh + 0.5)
+        return (newh, neww)
+
+    def convert_transformed_inputs_to_original_shape(self, resized_inputs):
+        if not hasattr(self, "pre_pad_shape"):
+            raise RuntimeError(
+                "'convert_transformed_inputs_to_original_shape' is only valid after the '__call__' method has run."
+            )
+
+        # The first step is to remove the padded region.
+        inputs = resized_inputs[tuple(self.pre_pad_shape)]
+        # Next, we resize the inputs back to the original shape.
+        inputs = resize(
+            image=inputs, output_shape=self.original_shape, preserve_range=True, **self.kwargs
+        )
+        return inputs
+
+    def __call__(self, inputs):
+        # NOTE: We store this in case we would like to transform the inputs back to the original shape.
+        self.original_shape = inputs.shape
+
+        # Let's get the new shape with the longest side equal to the target length.
+        new_shape = self._get_preprocess_shape(inputs.shape[-2], inputs.shape[-1])
+
+        if self.is_rgb:  # for rgb inputs, we assume channels first
+            assert inputs.ndim == 3 and inputs.shape[0] == 3
+            patch_shape = (3, *new_shape)
+        elif inputs.ndim == 3:  # for 3d inputs, we do not resize along the first (=z) axis
+            patch_shape = (inputs.shape[0], *new_shape)
+        else:
+            patch_shape = new_shape
+
+        # Next, we resize the input image along the longest side.
+        inputs = resize(
+            image=inputs, output_shape=patch_shape, preserve_range=True, **self.kwargs
+        ).astype(inputs.dtype)
+
+        # Finally, we pad the remaining borders to match the expected target shape.
+        pad_width = [(sh - dsh) / 2 for sh, dsh in zip(self.target_shape, new_shape)]
+        pad_width = (
+            (ceil(pad_width[0]), floor(pad_width[0])), (ceil(pad_width[1]), floor(pad_width[1]))
+        )
+        # We do not pad across the first axis (= channel or z-axis) for rgb or 3d inputs.
+        if self.is_rgb or inputs.ndim == 3:
+            pad_width = ((0, 0), *pad_width)
+
+        # NOTE: We store this in case we would like to unpad the inputs.
+        self.pre_pad_shape = [slice(pw[0], -pw[1] if pw[1] > 0 else None) for pw in pad_width]
+
+        inputs = np.pad(array=inputs, pad_width=pad_width, mode="constant")
+        return inputs
+
+
 class PadIfNecessary:
     def __init__(self, shape):
         self.shape = tuple(shape)
diff --git a/torch_em/transform/label.py b/torch_em/transform/label.py
index d13233e2..cd73b8ca 100644
--- a/torch_em/transform/label.py
+++ b/torch_em/transform/label.py
@@ -37,6 +37,22 @@ def label_consecutive(labels, with_background=True):
     return seg
 
 
+class MinSizeLabelTransform:
+    def __init__(self, min_size=None, ndim=None, ensure_zero=False):
+        self.min_size = min_size
+        self.ndim = ndim
+        self.ensure_zero = ensure_zero
+
+    def __call__(self, labels):
+        components = connected_components(labels, ndim=self.ndim, ensure_zero=self.ensure_zero)
+        if self.min_size is not None:
+            ids, sizes = np.unique(components, return_counts=True)
+            filter_ids = ids[sizes < self.min_size]
+            components[np.isin(components, filter_ids)] = 0
+            components, _, _ = skimage.segmentation.relabel_sequential(components)
+        return components
+
+
 # TODO smoothing
 class BoundaryTransform:
     def __init__(self, mode="thick", add_binary_target=False, ndim=None):
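The intended round trip of `ResizeLongestSideInputs` — scale the longest side to the target length, pad the shorter side to square, and later undo both steps — can be sketched like this (the input shape is made up):

```python
import numpy as np
from torch_em.transform.generic import ResizeLongestSideInputs

trafo = ResizeLongestSideInputs(target_shape=(256, 256))
image = np.random.rand(100, 200).astype("float32")

# The longest side (200) is scaled to 256, giving (128, 256);
# the height is then padded symmetrically up to (256, 256).
resized = trafo(image)
assert resized.shape == (256, 256)

# Undo the padding and resize back to the original shape.
restored = trafo.convert_transformed_inputs_to_original_shape(resized)
assert restored.shape == (100, 200)
```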
 
 # NOTE: we don't import the modelzoo convenience functions here.
 # In order to avoid importing bioimageio.core (which is quite massive) when importing torch_em
diff --git a/torch_em/util/__init__.py b/torch_em/util/__init__.py
index dc92a521..7b7bbf9e 100644
--- a/torch_em/util/__init__.py
+++ b/torch_em/util/__init__.py
@@ -1,10 +1,10 @@
 from .image import load_data, load_image, supports_memmap
 from .reporting import get_training_summary
 from .training import parser_helper
-from .util import (auto_compile, ensure_array, ensure_spatial_array,
-                   ensure_tensor, ensure_tensor_with_channels,
-                   get_constructor_arguments, get_trainer,
-                   is_compiled, load_model, model_is_equal)
+from .util import (
+    auto_compile, ensure_array, ensure_spatial_array, ensure_tensor, ensure_tensor_with_channels,
+    get_constructor_arguments, get_trainer, is_compiled, load_model, model_is_equal, ensure_patch_shape
+)
diff --git a/torch_em/util/debug.py b/torch_em/util/debug.py
index ca7bfa90..0a2c3a43 100644
--- a/torch_em/util/debug.py
+++ b/torch_em/util/debug.py
@@ -106,6 +106,7 @@ def _check_napari(loader, n_samples, instance_labels, model=None, device=None, r
             v.add_image(y)
         if pred is not None:
             v.add_image(pred)
+        napari.run()
 
diff --git a/torch_em/util/image.py b/torch_em/util/image.py
index 17e34cfc..5a6efc2b 100644
--- a/torch_em/util/image.py
+++ b/torch_em/util/image.py
@@ -4,6 +4,7 @@
 import numpy as np
 from elf.io import open_file
+
 try:
     import imageio.v3 as imageio
 except ImportError:
@@ -14,6 +15,7 @@
 except ImportError:
     tifffile = None
+
 TIF_EXTS = (".tif", ".tiff")
 
@@ -35,6 +37,13 @@ def load_image(image_path, memmap=True):
         return tifffile.memmap(image_path, mode="r")
     elif tifffile is not None and os.path.splitext(image_path)[1].lower() in (".tiff", ".tif"):
         return tifffile.imread(image_path)
+    elif os.path.splitext(image_path)[1].lower() == ".nrrd":
+        import nrrd
+        return nrrd.read(image_path)[0]
+    elif os.path.splitext(image_path)[1].lower() == ".mha":
+        import SimpleITK as sitk
+        image = sitk.ReadImage(image_path)
+        return sitk.GetArrayFromImage(image)
     else:
         return imageio.imread(image_path)
 
@@ -61,11 +70,19 @@ def __getitem__(self, index):
 
 def load_data(path, key, mode="r"):
     have_single_file = isinstance(path, str)
-    if key is None and have_single_file:
-        return load_image(path)
-    elif key is None and not have_single_file:
-        return np.stack([load_image(p) for p in path])
-    elif key is not None and have_single_file:
-        return open_file(path, mode=mode)[key]
-    elif key is not None and not have_single_file:
-        return MultiDatasetWrapper(*[open_file(p, mode=mode)[key] for p in path])
+    have_single_key = isinstance(key, str)
+
+    if key is None:
+        if have_single_file:
+            return load_image(path)
+        else:
+            return np.stack([load_image(p) for p in path])
+    else:
+        if have_single_key and have_single_file:
+            return open_file(path, mode=mode)[key]
+        elif have_single_key and not have_single_file:
+            return MultiDatasetWrapper(*[open_file(p, mode=mode)[key] for p in path])
+        elif not have_single_key and have_single_file:
+            return MultiDatasetWrapper(*[open_file(path, mode=mode)[k] for k in key])
+        else:  # have multiple keys and multiple files
+            return MultiDatasetWrapper(*[open_file(p, mode=mode)[k] for k in key for p in path])
diff --git a/torch_em/util/prediction.py b/torch_em/util/prediction.py
index e4ae86ed..eb6ba53d 100644
--- a/torch_em/util/prediction.py
+++ b/torch_em/util/prediction.py
@@ -182,7 +182,7 @@ def predict_block(block_id):
 
         if mask is not None:
             mask_block, _ = _load_block(mask, offset, block_shape, halo, with_channels=False)
-            mask_block = mask_block[inner_bb]
+            mask_block = mask_block[inner_bb].astype("bool")
             if mask_block.sum() == 0:
                 return
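The reworked `load_data` now dispatches on four path/key combinations; with hypothetical file names and keys, the cases look like this:

```python
from torch_em.util import load_data

# File names and dataset keys below are made up for illustration.
vol = load_data("data.h5", key="raw")                      # single file, single key
stack = load_data(["im0.tif", "im1.tif"], key=None)        # multiple images, stacked
multi = load_data("data.h5", key=["raw/s0", "raw/s1"])     # single file, multiple keys
both = load_data(["a.h5", "b.h5"], key=["raw", "labels"])  # multiple files and keys
```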
diff --git a/torch_em/util/util.py b/torch_em/util/util.py
index 8b1f8e0b..e4062e56 100644
--- a/torch_em/util/util.py
+++ b/torch_em/util/util.py
@@ -5,7 +5,6 @@
 import numpy as np
 import torch
 import torch_em
-import matplotlib.pyplot as plt
 from matplotlib import colors
 
 # this is a fairly brittle way to check if a module is compiled.
@@ -69,7 +68,14 @@ def ensure_tensor(tensor, dtype=None):
     if isinstance(tensor, np.ndarray):
         if np.dtype(tensor.dtype) in DTYPE_MAP:
             tensor = tensor.astype(DTYPE_MAP[tensor.dtype])
-        tensor = torch.from_numpy(tensor)
+        # Try to convert the tensor, even if it has the wrong byte order.
+        try:
+            tensor = torch.from_numpy(tensor)
+        except ValueError:
+            tensor = tensor.view(tensor.dtype.newbyteorder())
+            if np.dtype(tensor.dtype) in DTYPE_MAP:
+                tensor = tensor.astype(DTYPE_MAP[tensor.dtype])
+            tensor = torch.from_numpy(tensor)
 
     assert torch.is_tensor(tensor), f"Cannot convert {type(tensor)} to torch"
     if dtype is not None:
@@ -139,6 +145,52 @@ def ensure_spatial_array(array, ndim, dtype=None):
     return array
 
 
+def ensure_patch_shape(
+    raw, labels, patch_shape, have_raw_channels=False, have_label_channels=False, channel_first=True
+):
+    raw_shape = raw.shape
+    if labels is not None:
+        labels_shape = labels.shape
+
+    # In case the inputs have channels and they are channels first.
+    # IMPORTANT: for ImageCollectionDataset
+    if have_raw_channels and channel_first:
+        raw_shape = raw_shape[1:]
+
+    if have_label_channels and channel_first and labels is not None:
+        labels_shape = labels_shape[1:]
+
+    # Extract the pad width and pad the raw inputs.
+    if any(sh < psh for sh, psh in zip(raw_shape, patch_shape)):
+        pw = [(0, max(0, psh - sh)) for sh, psh in zip(raw_shape, patch_shape)]
+
+        if have_raw_channels and channel_first:
+            pad_width = [(0, 0), *pw]
+        elif have_raw_channels and not channel_first:
+            pad_width = [*pw, (0, 0)]
+        else:
+            pad_width = pw
+
+        raw = np.pad(array=raw, pad_width=pad_width)
+
+    # Extract the pad width and pad the label inputs.
+    if labels is not None and any(sh < psh for sh, psh in zip(labels_shape, patch_shape)):
+        pw = [(0, max(0, psh - sh)) for sh, psh in zip(labels_shape, patch_shape)]
+
+        if have_label_channels and channel_first:
+            pad_width = [(0, 0), *pw]
+        elif have_label_channels and not channel_first:
+            pad_width = [*pw, (0, 0)]
+        else:
+            pad_width = pw
+
+        labels = np.pad(array=labels, pad_width=pad_width)
+
+    if labels is None:
+        return raw
+    else:
+        return raw, labels
+
+
 def get_constructor_arguments(obj):
 
     # all relevant torch_em classes have 'init_kwargs' to
@@ -169,9 +221,9 @@ def _get_args(obj, param_names):
 
     # TODO support common torch losses (e.g. CrossEntropy, BCE)
     warnings.warn(
-        f"Constructor arguments for {type(obj)} cannot be deduced." +
-        "For this object, empty constructor arguments will be used." +
-        "Hence, the trainer can probably not be correctly deserialized via 'DefaultTrainer.from_checkpoint'."
+        f"Constructor arguments for {type(obj)} cannot be deduced.\n" +
+        "For this object, empty constructor arguments will be used.\n" +
+        "The trainer can probably not be correctly deserialized via 'DefaultTrainer.from_checkpoint'."
     )
     return {}
 
@@ -228,7 +280,11 @@ def load_model(checkpoint, model=None, name="best", state_key="model_state", dev
         model = get_trainer(checkpoint, name=name, device=device).model
     else:
         # load the model state from the checkpoint
-        ckpt = os.path.join(checkpoint, f"{name}.pt")
+        if os.path.isdir(checkpoint):
+            ckpt = os.path.join(checkpoint, f"{name}.pt")
+        else:
+            ckpt = checkpoint
+
         state = torch.load(ckpt, map_location=device)[state_key]
         # to enable loading compiled models
         compiled_prefix = "_orig_mod."
@@ -250,7 +306,6 @@ def model_is_equal(model1, model2):
 
     return True
 
-
 def get_random_colors(labels):
     """Function to generate a random color map for a label image
     """
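And a small sketch of the new `ensure_patch_shape` helper, which zero-pads the raw data (and the labels, if given) up to the requested patch shape; the shapes here are made up:

```python
import numpy as np
from torch_em.util import ensure_patch_shape

raw = np.random.rand(100, 120).astype("float32")
labels = np.zeros((100, 120), dtype="uint32")

# Both arrays are padded on the right/bottom up to the patch shape;
# inputs that already match or exceed it are returned unchanged.
raw, labels = ensure_patch_shape(raw, labels, patch_shape=(128, 128))
assert raw.shape == labels.shape == (128, 128)
```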