Skip to content

A probabilistic model to cluster survival data in a variational deep clustering setting

Notifications You must be signed in to change notification settings

i6092467/vadesc

Repository files navigation

A Deep Variational Approach to Clustering Survival Data

This repository holds the official code for the paper "A Deep Variational Approach to Clustering Survival Data", presented at ICLR 2022. A short explanation of the method is provided in this talk, poster can be viewed here.

[left] Cluster-specific Kaplan–Meier curves and corresponding centroid CT images, computed by averaging all samples assigned to each cluster by VaDeSC on the NSCLC data. [right] CT images generated by (i) sampling latent representations from the Gaussian mixture learnt by VaDeSC and (ii) decoding representations using the decoder network.

🦸‍♀ Motivation

In this work, we study the problem of clustering survival data — a challenging and so far under-explored task. The figure on the right schematically depicts this clustering problem: here, the overall patient population consists of three groups characterised by different associations between the covariates and survival, resulting in disparate clinical conditions. The survival distributions do not need to differ between clusters: compare groups 1 and 3. The identification of such patient subpopulations could, for example, facilitate a better understanding of a disease and a more personalised disease management.

💡 Method

We introduce a novel semi-supervised probabilistic approach, the variational deep survival clustering (VaDeSC), to discovers groups of patients characterised by different generative mechanisms of survival outcome. The proposed method employs a deep generative model to uncover the underlying distribution of both the explanatory variables and censored survival times. The figure on the right provides a schematic summary of the proposed approach: the input vector is mapped to a latent representation using a VAE regularised by a Gaussian mixture prior. The survival density function is given by a mixture of Weibull distributions with cluster-specific parameters . The parameters of the Gaussian mixture and the Weibull distributions are then optimised jointly using both the explanatory input variables and survival outcomes.

📝 Requirements

All required libraries are included in the conda environment specified by environment.yml. To install and activate it, follow the instructions below:

conda env create -f environment.yml         # install dependencies
conda activate Survival_Cluster_Analysis    # activate environment

Note: you might need to reboot after creating the environment!

🔨 Usage

File main.py trains and evaluates the VaDeSC model. It accepts following arguments:

  --data {mnist,sim,support,flchain,hgg,hemo,lung1,nsclc,nsclc_features,basel}
                        the dataset
  --num_epochs NUM_EPOCHS
                        the number of training epochs
  --batch_size BATCH_SIZE
                        the mini-batch size
  --lr LR               the learning rate
  --decay DECAY         the exponential decay of the learning rate
  --weibull_shape WEIBULL_SHAPE
                        the Weibull shape parameter (global)
  --no-survival         specifies if the survival model should not be included
  --dsa                 specifies if the deep survival analysis with k-means should be run
  --dsa_k DSA_K         number of clusters in deep survival analysis with k-means
  --eval-cal EVAL_CAL   specifies if the calibration needs to be evaluated
  --runs RUNS           the number of runs, the results will be averaged
  --results_dir RESULTS_DIR
                        the directory where the results will be saved
  --results_fname RESULTS_FNAME
                        the name of the .txt file with the results
  --pretrain PRETRAIN   specifies if the autoencoder should be pretrained
  --epochs_pretrain EPOCHS_PRETRAIN
                        the number of pretraining epochs
  --save_model SAVE_MODEL
                        specifies if the model should be saved
  --ex_name EX_NAME     the experiment name
  --config_override CONFIG_OVERRIDE
                        the override file name for config.yml
  --seed SEED           random number generator seed
  --eager EAGER         specifies if the TF functions should be run eagerly

Folder /configs contains .yml files which further sepcify the configuration of the model for each experiment. Folder /bin contains shell scripts for running clustering and time-to-event prediction experiments:

The VaDeSC model is implemented in /models/model.py. Encoder and decoder architectures are specified in /models/networks.py. Data loaders are provided in the /datasets folder. HGG, Hemodialysis, and NSCLC datasets are not included in the repository due to the medical confidentiality. Code for post hoc explanations of cluster assignments is available in the /post_hoc_explanations folder.

🆘 Feeling lost? Follow this Jupyter notebook to get started!

🤝 Acknowledgements

📭 Maintainers

📚 References

Below are some references helpful for understanding our method:

  • E. Bair and R. Tibshirani. Semi-supervised methods to predict patient survival from gene expression data. PLoS Biology, 2(4):e108, 2004.
  • N. Dilokthanakul, P. A. M. Mediano, M. Garnelo, M. C. H. Lee, H. Salimbeni, K. Arulkumaran, and M. Shanahan. Deep unsupervised clustering with Gaussian mixture variational autoencoders, 2016. arXiv:1611.02648.
  • Z. Jiang, Y. Zheng, H. Tan, B. Tang, and H. Zhou. Variational deep embedding: An unsupervised and generative approach to clustering. In Proceedings of the 26th International Joint Conference on Artificial Intelligence, pp. 1965–1972. AAAI Press, 2017.
  • S. Liverani, L. Leigh, I. L. Hudson, and J. E. Byles. Clustering method for censored and collinear survival data. Computational Statistics, 2020.
  • P. Chapfuwa, C. Li, N. Mehta, L. Carin, and R. Henao. Survival cluster analysis. In Proceedings of the ACM Conference on Health, Inference, and Learning, pp. 60–68. Association for Computing Machinery, 2020.
  • C. Nagpal, X. R. Li, and A. Dubrawski. Deep survival machines: Fully parametric survival regression and representation learning for censored data with competing risks. IEEE Journal of Biomedical and Health Informatics, 2021.

🙂 Citation

@inproceedings{ManduchiMarcinkevics2022,
  title={A Deep Variational Approach to Clustering Survival Data},
  author={Laura Manduchi and Ri{\v{c}}ards Marcinkevi{\v{c}}s and Michela C. Massi 
          and Thomas Weikert and Alexander Sauter and Verena Gotta 
          and Timothy M{\"u}ller and Flavio Vasella and Marian C. Neidert 
          and Marc Pfister and Bram Stieltjes and Julia E Vogt},
  booktitle={International Conference on Learning Representations},
  year={2022},
}