2D Latent Diffusion Example

This folder contains an example for training and validating a 2D Latent Diffusion Model on BraTS axial slices. The example supports multi-GPU training with distributed data parallelism.

The workflow of the Latent Diffusion Model is depicted in the figure below. It begins by training an autoencoder in pixel space to encode images into latent features. Following that, it trains a diffusion model in the latent space to denoise the noisy latent features. During inference, it first generates latent features from random noise by applying multiple denoising steps using the trained diffusion model. Finally, it decodes the denoised latent features into images using the trained autoencoder.

[Figure: latent diffusion scheme]
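To make the two stages concrete, here is a toy NumPy sketch of the shape flow only. The "autoencoder" is a hypothetical stand-in (4x4 average pooling to encode, nearest-neighbour upsampling to decode), not MONAI's learned autoencoder, and the downsampling factor of 4 is an assumption. The point is that the diffusion model works on an array with far fewer elements than the image.

```python
import numpy as np

DOWN = 4  # hypothetical spatial downsampling factor of the toy autoencoder


def toy_encode(image: np.ndarray) -> np.ndarray:
    """Stand-in encoder: average-pool non-overlapping DOWN x DOWN blocks."""
    h, w = image.shape
    if h % DOWN or w % DOWN:
        raise ValueError(f"image size {image.shape} must be divisible by {DOWN}")
    return image.reshape(h // DOWN, DOWN, w // DOWN, DOWN).mean(axis=(1, 3))


def toy_decode(latent: np.ndarray) -> np.ndarray:
    """Stand-in decoder: nearest-neighbour upsampling back to pixel space."""
    return np.repeat(np.repeat(latent, DOWN, axis=0), DOWN, axis=1)


rng = np.random.default_rng(0)
image = rng.random((256, 256))   # one hypothetical 256 x 256 axial slice
latent = toy_encode(image)       # stage 1: pixel space -> latent space
# stage 2 (not shown): the diffusion model is trained on / sampled in `latent`
recon = toy_decode(latent)       # final step: latent space -> pixel space
print(image.shape, latent.shape, recon.shape)   # (256, 256) (64, 64) (256, 256)
print("elements per sample reduced by", image.size // latent.size)  # 16
```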

The MONAI latent diffusion model implementation is based on the following paper:

Latent Diffusion: Rombach, Robin, et al. "High-resolution image synthesis with latent diffusion models." CVPR 2022.

This network is designed as a demonstration of how to train this type of network with MONAI. For optimal performance, we recommend a GPU with more than 32G of memory to accommodate larger networks and attention layers.

1. Data

The dataset used in this example is the BraTS 2016 and 2017 data.

BraTS is a public dataset of brain MR images. The goal is to generate images that look similar to those in the BraTS 2016 and 2017 datasets.

The data can be downloaded from the Medical Segmentation Decathlon. Running the following command downloads the BraTS data and extracts it to the $data_base_dir set in ./config/environment.json. You will then see a subfolder Task01_BrainTumour under $data_base_dir; with the default setting, this is ./dataset/Task01_BrainTumour.

python download_brats_data.py -e ./config/environment.json
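After the download, you can check that the data landed where the training scripts expect it. This is a minimal sketch, assuming environment.json is a flat JSON object with a "data_base_dir" string entry; `find_brats_dir` is a hypothetical helper, not part of the example scripts.

```python
import json
from pathlib import Path


def find_brats_dir(env_path: str) -> Path:
    """Return <data_base_dir>/Task01_BrainTumour, raising if it is missing.

    Assumes environment.json holds a top-level "data_base_dir" string.
    """
    with open(env_path) as f:
        env = json.load(f)
    task_dir = Path(env["data_base_dir"]) / "Task01_BrainTumour"
    if not task_dir.is_dir():
        raise FileNotFoundError(f"{task_dir} not found; run download_brats_data.py first")
    return task_dir
```

Usage: `find_brats_dir("./config/environment.json")` returns the task folder or fails early with a readable message.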

Disclaimer: We are not the host of the data. Please make sure to read the requirements and usage policies of the data and give credit to the authors of the dataset!

2. Installation

Please refer to the installation instructions of MONAI Generative Models.

3. Run the example

The network configuration files are ./config/config_train_32g.json for a 32G GPU and ./config/config_train_16g.json for a 16G GPU. You can modify the hyperparameters in these files to suit your requirements.

The training script uses the batch size and patch size defined in the configuration files. If your GPU has a different memory size, adjust the "batch_size" and "patch_size" parameters in the "autoencoder_train" section to match it. Note that "patch_size" must be divisible by 4.
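A quick pre-flight check for the autoencoder stage is sketched below. It assumes the config stores "batch_size" and "patch_size" inside an "autoencoder_train" object, with "patch_size" as a list of per-axis sizes; adjust the lookup if your copy of the config is laid out differently. `check_autoencoder_train` and `load_and_check` are hypothetical helpers.

```python
import json


def check_autoencoder_train(config: dict) -> None:
    """Raise ValueError if the (assumed) autoencoder_train settings break the rules."""
    section = config["autoencoder_train"]
    if int(section["batch_size"]) < 1:
        raise ValueError("batch_size must be a positive integer")
    for size in section["patch_size"]:
        if size % 4 != 0:
            raise ValueError(f"patch_size entry {size} is not divisible by 4")


def load_and_check(config_path: str) -> dict:
    """Load a config file such as ./config/config_train_16g.json and validate it."""
    with open(config_path) as f:
        config = json.load(f)
    check_autoencoder_train(config)
    return config
```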

Before you start training, please set the paths in ./config/environment.json:

  • "model_dir": where the trained models are saved
  • "tfevent_path": where the TensorBoard event files are saved
  • "output_dir": where the images generated during inference are stored
  • "resume_ckpt": whether to resume training from existing checkpoints
  • "data_base_dir": where the BraTS dataset is stored
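One way to produce this file programmatically is sketched below. The key names come from the list above; every value is a placeholder, and treating "resume_ckpt" as a boolean is an assumption based on its description. The training scripts may expect further keys, so treat this as a starting point. `write_environment` is a hypothetical helper.

```python
import json
from pathlib import Path


def write_environment(path: str, root: str, resume: bool = False) -> dict:
    """Write a hypothetical environment.json containing the five documented keys."""
    base = Path(root)
    env = {
        "model_dir": str(base / "models"),        # trained models
        "tfevent_path": str(base / "tfevent"),    # TensorBoard event files
        "output_dir": str(base / "output"),       # images generated at inference
        "resume_ckpt": resume,                    # resume from existing checkpoints?
        "data_base_dir": str(base / "dataset"),   # parent of Task01_BrainTumour
    }
    # Create the directories the scripts will write into.
    for key in ("model_dir", "tfevent_path", "output_dir"):
        Path(env[key]).mkdir(parents=True, exist_ok=True)
    Path(path).write_text(json.dumps(env, indent=4))
    return env
```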

Below is the training command for a single GPU.

python train_autoencoder.py -c ./config/config_train_32g.json -e ./config/environment.json -g 1

The training script also supports multi-GPU training. For instance, if you are using eight 32G GPUs, you can run the training script with the following command:

export NUM_GPUS_PER_NODE=8
torchrun \
    --nproc_per_node=${NUM_GPUS_PER_NODE} \
    --nnodes=1 \
    --master_addr=localhost --master_port=1234 \
    train_autoencoder.py -c ./config/config_train_32g.json -e ./config/environment.json -g ${NUM_GPUS_PER_NODE}
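torchrun starts one process per GPU and tells each process who it is through environment variables such as RANK, LOCAL_RANK and WORLD_SIZE. The sketch below shows how a script can read them and fall back to single-process values when launched with plain `python`. Whether train_autoencoder.py reads them exactly this way is an assumption, and `dist_info` is a hypothetical helper.

```python
import os


def dist_info(env=None) -> dict:
    """Read the process identity exported by torchrun (defaults describe one GPU)."""
    env = os.environ if env is None else env
    world_size = int(env.get("WORLD_SIZE", 1))
    return {
        "rank": int(env.get("RANK", 0)),              # global process index
        "local_rank": int(env.get("LOCAL_RANK", 0)),  # GPU index on this node
        "world_size": world_size,                     # total number of processes
        "distributed": world_size > 1,
    }
```

In a PyTorch script, a distributed run would typically follow this with `torch.distributed.init_process_group(backend="nccl")` and `torch.cuda.set_device(local_rank)` before wrapping the model in DistributedDataParallel.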

[Figure: autoencoder training curve | autoencoder validation curve]

With eight DGX1V 32G GPUs, training the autoencoder for 1000 epochs took around 34 hours.
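For planning, the reported figure works out to roughly two minutes per epoch. The sketch below assumes wall-clock time scales linearly with the number of epochs on the same eight-GPU setup, which ignores effects such as validation frequency and data loading.

```python
REPORTED_HOURS = 34.0     # eight DGX1V 32G GPUs, as reported above
REPORTED_EPOCHS = 1000

minutes_per_epoch = REPORTED_HOURS * 60 / REPORTED_EPOCHS


def estimate_hours(epochs: int) -> float:
    """Linear extrapolation of autoencoder training time (an assumption)."""
    return minutes_per_epoch * epochs / 60


print(f"{minutes_per_epoch:.2f} min/epoch")             # 2.04 min/epoch
print(f"{estimate_hours(300):.1f} h for 300 epochs")    # 10.2 h for 300 epochs
```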

An example reconstruction result is shown below:

[Figure: autoencoder reconstruction result]

The training script uses the batch size and patch size defined in the configuration files. If your GPU has a different memory size, adjust the "batch_size" and "patch_size" parameters in the "diffusion_train" section to match it. Note that "patch_size" must be divisible by 16 and no larger than 256.
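When shrinking the diffusion patch for a smaller GPU, the two stated constraints can be applied mechanically: cap each dimension at 256, then round it down to a multiple of 16. `fit_diffusion_patch` is a hypothetical helper; it only enforces the rules above and knows nothing about how much memory a given size needs.

```python
def fit_diffusion_patch(requested: list[int]) -> list[int]:
    """Largest per-axis patch size <= requested that is divisible by 16 and <= 256."""
    fitted = []
    for req in requested:
        size = min(req, 256)    # no larger than 256
        size -= size % 16       # round down to a multiple of 16
        if size < 16:
            raise ValueError(f"requested size {req} is smaller than 16")
        fitted.append(size)
    return fitted


print(fit_diffusion_patch([200, 300]))  # [192, 256]
```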

To train with a single 32G GPU, run:

python train_diffusion.py -c ./config/config_train_32g.json -e ./config/environment.json -g 1

The training script also supports multi-GPU training. For instance, if you are using eight 32G GPUs, you can run the training script with the following command:

export NUM_GPUS_PER_NODE=8
torchrun \
    --nproc_per_node=${NUM_GPUS_PER_NODE} \
    --nnodes=1 \
    --master_addr=localhost --master_port=1234 \
    train_diffusion.py -c ./config/config_train_32g.json -e ./config/environment.json -g ${NUM_GPUS_PER_NODE}

[Figure: latent diffusion training curve | latent diffusion validation curve]
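Conceptually, each diffusion training step noises a batch of latents at random timesteps and asks the network to predict the added noise. The NumPy sketch below shows the standard DDPM noise-prediction objective with a linear beta schedule. The schedule values (1000 steps, betas from 1e-4 to 0.02) are common DDPM defaults assumed here, not read from this example's configuration, and `denoiser` is a placeholder for the trained UNet.

```python
import numpy as np

T = 1000                                    # assumed number of diffusion steps
betas = np.linspace(1e-4, 0.02, T)          # assumed linear beta schedule
alphas_cumprod = np.cumprod(1.0 - betas)    # abar_t for t = 0..T-1


def add_noise(z0, noise, t):
    """Forward process q(z_t | z_0) = sqrt(abar_t) * z0 + sqrt(1 - abar_t) * noise.

    `t` holds one timestep index per sample in the batch.
    """
    abar = alphas_cumprod[t].reshape((-1,) + (1,) * (z0.ndim - 1))
    return np.sqrt(abar) * z0 + np.sqrt(1.0 - abar) * noise


def training_loss(denoiser, z0, rng):
    """Noise-prediction MSE for one batch of clean latents z0."""
    t = rng.integers(0, T, size=z0.shape[0])   # random timestep per sample
    noise = rng.standard_normal(z0.shape)
    z_t = add_noise(z0, noise, t)
    pred = denoiser(z_t, t)                    # the UNet would predict the noise here
    return float(np.mean((pred - noise) ** 2))


rng = np.random.default_rng(0)
latents = rng.standard_normal((4, 3, 64, 64))   # hypothetical batch of latents
zero_model = lambda z, t: np.zeros_like(z)      # untrained placeholder
print(training_loss(zero_model, latents, rng))  # about 1.0, the variance of the noise
```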

To generate one image during inference, please run the following command:

python inference.py -c ./config/config_train_32g.json -e ./config/environment.json --num 1

--num sets how many images are generated.
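Generating --num images amounts to running the reverse diffusion chain once per image in latent space and then decoding. The sketch below is a plain DDPM ancestral sampler in NumPy under the same assumed linear schedule; `denoiser` and `decode` are placeholders for the trained diffusion UNet and autoencoder decoder, and the latent shape (3, 64, 64) is hypothetical. The example's own inference.py may use a different scheduler or number of steps.

```python
import numpy as np

T = 1000
betas = np.linspace(1e-4, 0.02, T)     # assumed linear beta schedule
alphas = 1.0 - betas
alphas_cumprod = np.cumprod(alphas)


def reverse_step(z_t, eps_pred, t, rng):
    """One DDPM step z_t -> z_{t-1} given the predicted noise eps_pred."""
    mean = (z_t - betas[t] / np.sqrt(1.0 - alphas_cumprod[t]) * eps_pred) / np.sqrt(alphas[t])
    if t == 0:
        return mean                    # no noise is added at the final step
    return mean + np.sqrt(betas[t]) * rng.standard_normal(z_t.shape)


def generate(num, latent_shape, denoiser, decode, rng):
    """Sample `num` latents from pure noise, denoise for T steps, then decode."""
    z = rng.standard_normal((num, *latent_shape))
    for t in reversed(range(T)):
        z = reverse_step(z, denoiser(z, t), t, rng)
    return decode(z)


rng = np.random.default_rng(0)
images = generate(1, (3, 64, 64), lambda z, t: np.zeros_like(z), lambda z: z, rng)
print(images.shape)  # (1, 3, 64, 64)
```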

An example output is shown below.

[Figure: example output]

4. Questions and bugs

  • For questions relating to the use of MONAI, please use our Discussions tab on the main repository of MONAI.
  • For bugs relating to MONAI functionality, please create an issue on the main repository.
  • For bugs relating to the running of a tutorial, please create an issue in this repository.

References

[1] Rombach, Robin, et al. "High-resolution image synthesis with latent diffusion models." CVPR 2022.

[2] Menze, Bjoern H., et al. "The multimodal brain tumor image segmentation benchmark (BRATS)." IEEE Transactions on Medical Imaging 34.10 (2014): 1993-2024.

[3] Pinaya et al. "Brain imaging generation with latent diffusion models."