Phyloformer: towards fast, accurate and versatile phylogenetic reconstruction with deep neural networks
- Luca Nesterenko
- Luc Blassel
- Philippe Veber
- Bastien Boussau
- Laurent Jacob
This repository contains the scripts for the paper:
@article{Nesterenko2024phyloformer,
author={Nesterenko, Luca and Blassel, Luc and Veber, Philippe and Boussau, Bastien and Jacob, Laurent},
title={Phyloformer: Fast, accurate and versatile phylogenetic reconstruction with deep neural networks},
doi={10.1101/2024.06.17.599404},
url={https://www.biorxiv.org/content/10.1101/2024.06.17.599404v1},
year={2024},
journal={bioRxiv}
}
Phyloformer is a fast deep neural network-based method to infer evolutionary distances from a multiple sequence alignment. It can be used on alignments evolved under a selection of evolutionary models: LG+GC, LG+GC with indels, the CherryML co-evolution model, and SelReg with selection.
You can read below for some example usage and explanations, but if you just want the CLI reference for the available scripts, see the cli_reference.md file.
The easiest way to install the software is by creating a virtual environment using conda/mamba and then installing dependencies in it:
# Install mamba if you want to use it instead of conda
conda install -n base -c conda-forge mamba
# Clone the phyloformer repo
git clone https://github.com/lucanest/Phyloformer.git && cd Phyloformer
# Create the virtual env and install the phyloformer package inside
conda create -n phylo python=3.9 -c defaults && conda activate phylo
pip install -r requirements.txt
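To check that the environment is set up correctly, you can print the help message of the main scripts (assuming they expose the standard argparse --help flag):
# Print the CLI help of the main scripts as a quick environment check
# (assumes the scripts expose the standard argparse --help flag)
python infer_alns.py --help
python train_distributed.py --help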
Some pre-built binaries are included in this repo for both Linux AMD64 and macOS ARM64. These include:
- IQTree: for inferring maximum likelihood (ML) trees and simulating alignments (for the alignment simulation to work you should use IQTree v2.0.0)
- FastTree: for inferring ML-like trees
- FastME: for inferring trees from distance matrices (such as the ones produced by Phyloformer)
- goalign: for manipulating alignments
- phylotree: for manipulating newick formatted phylogenetic trees
- phylocompare: for batch comparison of newick formatted phylogenetic trees
If any of these executables do not run on your platform, you can find more information, as well as builds and build instructions, via the links to each tool's repository.
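As a quick check that the bundled binaries run on your machine, you can for instance print the IQTree version (path shown for Linux; use bin_macos on an ARM Mac):
# Print the version of the bundled IQTree binary
# (swap bin_linux for bin_macos on an ARM Mac)
./bin/bin_linux/iqtree_2.2.0 --version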
All the named phyloformer models in the manuscript are given in the models directory:
- PF_Base: trained with an MAE loss on LG+GC data
- PF: fine-tuned from PF_Base with an MRE loss on LG+GC data
- PF_Indel: fine-tuned from PF_Base with an MAE loss on LG+GC+Indels data
- PF_Cherry: fine-tuned from PF_Base with an MAE loss on CherryML data
- PF_SelReg: fine-tuned from PF_Base with an MAE loss on SelReg data
Use the infer_alns.py script to infer distance matrices from alignments using a trained Phyloformer model.
Let's use the small test set given along with this repo to try out Phyloformer (if you're on a macOS ARM chip, replace bin_linux with bin_macos).
# First make sure you are in the repo and have the correct conda env
cd Phyloformer && conda activate phylo
# Infer distance matrices using the LG+GC PF model
# (This will automatically use a CUDA GPU if available, otherwise it will use the CPU)
python infer_alns.py -o data/testdata/pf_matrices models/pf.ckpt data/testdata/msas
# Infer trees with FastME
mkdir data/testdata/pf_trees
for file in data/testdata/pf_matrices/*; do
base="${file##*/}"
stem="${base%%.*}"
./bin/bin_linux/fastme -i "${file}" -o "data/testdata/pf_trees/${stem}.nwk" --nni --spr
done
# Compare trees
./bin/bin_linux/phylocompare -t -n -o data/cmp data/testdata/trees data/testdata/pf_trees
# Compute the average KF distance
# It should output '0.333'
cat data/cmp_topo.csv | awk 'BEGIN {FS=","} NR>1{sum += $5; n+=1} END {printf "%.3f\n", sum/n}'
Simulate trees with simulate_trees.py; if you want to simulate LG+GC alignments, use alisim.py.
If you want to simulate alignments under the CherryML model use TODO, and for SelReg use TODO.
Let's simulate a small test set with different tree sizes:
# Create output directory
mkdir data/test_set
# Simulate 20 trees for each number of tips from 10 to 80 with a step size of 10
for i in $(seq 10 10 80); do
python simulate_trees.py --ntips "$i" --ntrees 20 --output data/test_set/trees --type birth-death
done
# Simulate 250-AA long alignments using LG+GC from the simulated trees
# here we specify the iqtree binary given in this repo and allow duplicate sequences
# in the MSAs we get as output
python alisim.py \
--outdir data/test_set/alignments \
--substitution LG \
--gamma GC \
--iqtree ./bin/bin_linux/iqtree_2.2.0 \
--length 250 \
--allow-duplicates \
--max-attempts 1 \
data/test_set/trees
Use the train_distributed script to train or fine-tune a PF model on some data.
Training is done using a lightning wrapper, so make sure it is installed before attempting it.
This script is made to run on a CUDA GPU within a SLURM environment, but it should run fine on a personal computer.
If you wish to train on modern ARM-architecture Apple machines, the script will run on the CPU instead of the GPU because of a bug in the MPS implementation of Elu().
For this I will assume that you have your training data organized as follows:
data/
├── train/
│ ├── msas/
│ └── trees/
└── val/
├── msas/
└── trees/
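You can create this layout with, for example:
# Create the expected directory layout (adjust the paths to your setup)
mkdir -p data/train/msas data/train/trees data/val/msas data/val/trees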
Note that the training script supports auto-splitting the data into training and validation sets; however, for reproducibility purposes this is not recommended, especially if you intend to save checkpoints and resume training later.
Important: so that the data loader knows which simulated alignment comes from which simulated tree, corresponding tree and alignment files must share the same file name and differ only in extension (.nwk/.newick for trees and .fa/.fasta for alignments). This means that if you have a data/train/trees/0_20_tips.nwk tree file you must have a corresponding data/train/msas/0_20_tips.fa alignment file.
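A quick way to check the pairing, sketched here assuming .nwk trees and .fa alignments:
# List tree files that have no alignment with a matching stem
# (assumes .nwk trees and .fa alignments; adjust the extensions if needed)
for tree in data/train/trees/*.nwk; do
    stem="$(basename "${tree%.nwk}")"
    [ -f "data/train/msas/${stem}.fa" ] || echo "Missing alignment for ${stem}"
done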
OK, now that we have the correct data layout we can train a Phyloformer instance on our data. We want to run this for 20 epochs, with 300 warmup steps on our learning-rate schedule, at the end of which we reach the target starting learning rate of 1e-4:
python train_distributed.py \
--train-trees data/train/trees \
--train-alignments data/train/msas \
--val-trees data/val/trees \
--val-alignments data/val/msas \
--warmup-steps 300 \
--learning-rate 1e-4 \
--nb-epochs 20 \
--batch-size 4 \
--check-val-every 1000
We also specify that we want to check the validation loss (on the whole validation set) every 1000 steps. Every time the validation loss is estimated, a model checkpoint is saved. Checkpoints are saved in a directory named checkpoints_LR_... that encodes the specifics of the current training run. In parallel, training logs are handled by wandb and saved to the wandb directory. Since this script was designed to run on an offline SLURM cluster node, wandb logs are not automatically synced to the online platform: users must run wandb sync on the desired logs from a machine that has an internet connection (e.g. the cluster's head node) to upload them.
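For example, from a machine with internet access (the offline run directory names will depend on your runs):
# Upload offline wandb logs once you are on a machine with internet access
wandb sync wandb/offline-run-*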
wandb parameters like the project name and the run name can be user-specified with the --project-name and --run-name flags respectively.
You can also specify the directory in which to save the checkpoints and the logs by using the --output-dir flag (by default it is the current directory).
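For example (the project, run and directory names below are placeholders):
# Name the wandb project and run, and redirect checkpoints and logs
# (project/run/directory names are placeholders; keep the other training flags as above)
python train_distributed.py \
    ... \
    --project-name my-phyloformer-project \
    --run-name lggc-baseline \
    --output-dir runs/lggc-baseline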
In the training command above we trained a Phyloformer instance with the default architecture (6 attention blocks, 64 embedding dimensions, 4 attention heads and 0.0 dropout). Each of these parameters can be user-specified, e.g.:
python train_distributed.py \
--nb-blocks 3 \
--embed-dim 128 \
--nb-heads 2 \
--dropout 0.5 \
...
Since training a Phyloformer instance can be costly, it is usually a good idea to set up some guardrails before launching training on several GPUs in parallel for several hours.
The first thing to do is to set up the training schedule by defining the number of epochs, target learning rate and number of warmup steps (see above). Once that is done you can control when the model stops training either:
- After a set number of steps with the --max-steps flag. This is useful if you want to stop training early but have an LR schedule that is defined over more steps. By default this is None, meaning that training will continue until the specified number of epochs is reached.
- With early stopping conditions (see the sketch after this list):
  - the --no-improvement-stop/-n flag sets the number of validation checks without improvement after which training stops (e.g. with -n 5, training stops once the validation loss has been measured 5 times without improving at least once).
  - the --hard-loss-ceiling/-L flag sets a value above which, if the training or validation loss ever rises, training stops immediately.
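A sketch of how these stopping flags can be combined (the values are illustrative, not recommendations):
# Stop after at most 50000 steps, after 5 validation checks without improvement,
# or immediately if the loss ever exceeds 10 (illustrative values)
python train_distributed.py \
    ... \
    --max-steps 50000 \
    --no-improvement-stop 5 \
    --hard-loss-ceiling 10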
Resuming training from a saved checkpoint is meant to be intuitive: you just need to specify the data paths, the checkpoint path and the validation check interval. All model architecture, LR scheduling and early stopping parameters are contained in the saved checkpoint and will be applied automatically:
python train_distributed.py \
--train-trees data/train/trees \
--train-alignments data/train/msas \
--val-trees data/val/trees \
--val-alignments data/val/msas \
--check-val-every 1000 \
--load-checkpoint checkpoints_.../last.ckpt
The ... in the checkpoint directory name should be replaced with your own checkpoint directory, whose name depends on the previous run's training parameters.
Fine-tuning a previous model is much like training a model from scratch, except that we do not need to specify architecture parameters and we do need to specify the path to the pre-trained base model that we want to fine-tune:
python train_distributed.py \
--train-trees data/train/trees \
--train-alignments data/train/msas \
--val-trees data/val/trees \
--val-alignments data/val/msas \
--warmup-steps 300 \
--learning-rate 1e-4 \
--nb-epochs 20 \
--batch-size 4 \
--check-val-every 1000 \
--base-model path/to/pretrained/model.ckpt
This should cover all use cases for training; please refer to the cli_reference.md file for all the flags, or open an issue if you believe that something is missing.
Use the make_plots.py script to reproduce all the paper figures.
# Download the results (This might take a little time since the file is quite large)
curl 'https://zenodo.org/records/11930296/files/results.tar.gz?download=1' -o results.tar.gz
# Extract results file (make sure you are in the repo root)
tar xzvf results.tar.gz
# Run figure producing script (this should take 5 to 10 minutes)
python make_plots.py