Deep-learning based semantic and instance segmentation for 3D electron microscopy and other bioimage analysis problems, built on PyTorch. Any feedback is highly appreciated; just open an issue!
Highlights:
- Functional API with sensible defaults to train a state-of-the-art segmentation model with a few lines of code.
- Differentiable augmentations on GPU and CPU thanks to kornia.
- Off-the-shelf logging with tensorboard or wandb.
- Export trained models to bioimage.io model format with one function call to deploy them in ilastik or deepimageJ.
Design:
- All parameters are specified in code, no configuration files.
- No callback logic; to extend the core functionality inherit from `trainer.DefaultTrainer` instead.
- All data loading is lazy to support training on large datasets.
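To illustrate the subclass-over-callbacks design, here is a minimal, library-agnostic sketch. The class and method names (`_train_iteration`, `LoggingTrainer`) are illustrative stand-ins, not torch_em's actual API; consult `trainer.DefaultTrainer` for the real extension points.

```python
# Library-agnostic sketch of extending a trainer by inheritance instead of
# registering callbacks. Names are hypothetical, not torch_em's API.

class DefaultTrainer:
    """Minimal stand-in for a trainer with a fixed training loop."""
    def __init__(self, name):
        self.name = name
        self.iteration = 0

    def _train_iteration(self):
        # the core step; subclasses override this to change behavior
        self.iteration += 1

    def fit(self, iterations):
        for _ in range(iterations):
            self._train_iteration()


class LoggingTrainer(DefaultTrainer):
    """Extends the core loop by overriding a method, not via callback hooks."""
    def __init__(self, name, log_every=2):
        super().__init__(name)
        self.log_every = log_every
        self.logged = []

    def _train_iteration(self):
        super()._train_iteration()
        if self.iteration % self.log_every == 0:
            self.logged.append(self.iteration)


trainer = LoggingTrainer("demo", log_every=2)
trainer.fit(iterations=5)
print(trainer.iteration)  # 5
print(trainer.logged)     # [2, 4]
```

Because all behavior lives in overridable methods, there is no separate callback registry to learn; extension points are just the methods of the base class.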
```python
# train a 2d U-Net for foreground and boundary segmentation of nuclei
# using data from https://github.com/mpicbg-csbd/stardist/releases/download/0.1.0/dsb2018.zip
import torch
import torch_em
from torch_em.model import UNet2d
from torch_em.data.datasets import get_dsb_loader

model = UNet2d(in_channels=1, out_channels=2)

# transform to go from instance segmentation labels
# to foreground/background and boundary channel
label_transform = torch_em.transform.BoundaryTransform(
    add_binary_target=True, ndim=2
)

# training and validation data loader
data_path = "./dsb"  # the training data will be downloaded and saved here
train_loader = get_dsb_loader(
    data_path,
    patch_shape=(1, 256, 256),
    batch_size=8,
    split="train",
    download=True,
    label_transform=label_transform
)
val_loader = get_dsb_loader(
    data_path,
    patch_shape=(1, 256, 256),
    batch_size=8,
    split="test",
    label_transform=label_transform
)

# the trainer object that handles the training details
# the model checkpoints will be saved in "checkpoints/dsb-boundary-model"
# the tensorboard logs will be saved in "logs/dsb-boundary-model"
trainer = torch_em.default_segmentation_trainer(
    name="dsb-boundary-model",
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    learning_rate=1e-4,
    device=torch.device("cuda")
)
trainer.fit(iterations=5000)

# export the trained model in bioimage.io model format
from glob import glob
import imageio
from torch_em.util import export_bioimageio_model

# load one of the test images to use as reference image
# and crop it to a shape that is guaranteed to fit the network
test_im = imageio.imread(glob(f"{data_path}/test/images/*.tif")[0])[:256, :256]
export_bioimageio_model("./checkpoints/dsb-boundary-model", "./bioimageio-model", test_im)
```
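Conceptually, the `label_transform` above turns an instance label image into a two-channel target: a binary foreground mask plus an instance boundary map. Here is a minimal numpy sketch of that idea; it is a toy illustration, not the actual `BoundaryTransform` implementation.

```python
import numpy as np

def instance_to_fg_and_boundaries(labels):
    """Toy foreground/boundary label transform:
    channel 0 = binary foreground, channel 1 = instance boundaries."""
    foreground = (labels != 0).astype("float32")
    # mark a pixel as boundary if any 4-neighbor carries a different label
    boundaries = np.zeros(labels.shape, dtype="float32")
    padded = np.pad(labels, 1, mode="edge")
    for dy, dx in [(-1, 0), (1, 0), (0, -1), (0, 1)]:
        shifted = padded[1 + dy:labels.shape[0] + 1 + dy,
                         1 + dx:labels.shape[1] + 1 + dx]
        boundaries[labels != shifted] = 1.0
    return np.stack([foreground, boundaries])

# two touching instances: the boundary runs along the contact line
labels = np.array([
    [1, 1, 2, 2],
    [1, 1, 2, 2],
])
target = instance_to_fg_and_boundaries(labels)
print(target.shape)  # (2, 2, 4)
```

Training the two-channel U-Net against such a target lets you recover instances afterwards, e.g. by thresholding the boundary channel and running a connected-component or watershed post-processing step.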
For a more in-depth example, check out one of the example notebooks:
- 2D-UNet: train a 2d UNet for a segmentation task. Available on Google Colab.
- 3D-UNet: train a 3d UNet for a segmentation task. Available on Google Colab.
`mamba` is a drop-in replacement for conda, but much faster. While the steps below may also work with `conda`, we highly recommend using `mamba`. You can follow the instructions here to install `mamba`.
You can install `torch_em` from conda-forge:
```shell
mamba install -c conda-forge torch_em
```
Please check out pytorch.org for more information on how to install a PyTorch version compatible with your system.
It's recommended to set up a conda environment for using `torch_em`.
Two conda environment files are provided: `environment_cpu.yaml` for a pure CPU set-up and `environment_gpu.yaml` for a GPU set-up.
If you want to use the GPU version, make sure to set the correct CUDA version for your system in the environment file by modifying this line.
You can set up a conda environment using one of these files like this:
```shell
mamba env create -f <ENV>.yaml -n <ENV_NAME>
mamba activate <ENV_NAME>
pip install -e .
```
where `<ENV>.yaml` is either `environment_cpu.yaml` or `environment_gpu.yaml`.
- Training of 2d U-Nets and 3d U-Nets for various segmentation tasks.
- Random forest based domain adaptation from Shallow2Deep.
- Training models for embedding prediction with sparse instance labels from SPOCO.
- Training of UNETR for various 2d segmentation tasks, with a flexible choice of vision transformer backbone from Segment Anything or Masked Autoencoder.
- Training of ViM-UNet for various 2d segmentation tasks.
A command line interface for training, prediction and conversion to the bioimage.io modelzoo format will be installed with `torch_em`:
- `torch_em.train_unet_2d`: train a 2D U-Net.
- `torch_em.train_unet_3d`: train a 3D U-Net.
- `torch_em.predict`: run prediction with a trained model.
- `torch_em.predict_with_tiling`: run prediction with tiling.
- `torch_em.export_bioimageio_model`: export a model to the modelzoo format.
For more details run `<COMMAND> -h` for any of these commands.
The folder `scripts/cli` contains examples of how to use the CLI.
Note: this functionality was recently added and is not fully tested.