# Pre-training scPRINT

scPRINT is a large model that can be pre-trained on a large dataset of single-cell data.

Pre-training is quite efficient for scPRINT, and smaller models can be pre-trained on any hardware with a 20GB NVIDIA GPU.

## Setup of the database

To perform pre-training you will need a large dataset. We recommend using laminDB to assemble such a large database of datasets, and our scdataloader package to load the data into the model.
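
As a rough sketch of what this looks like in Python (the exact DataModule signature may differ between scdataloader versions; the argument names below are taken from the data section of the example config further down, and the collection name is a placeholder):

```python
# a minimal sketch, not the exact API: scdataloader's DataModule wraps a laminDB
# collection and yields batches of expression profiles for the model
from scdataloader import DataModule

datamodule = DataModule(
    collection_name="preprocessed dataset",  # name of your laminDB collection (placeholder)
    organisms=["NCBITaxon:9606", "NCBITaxon:10090"],  # human and mouse
    how="random expr",  # collate randomly sampled expressed genes
    max_len=2200,  # number of genes in the model context
)
```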

In addition, you will need to preprocess your datasets to make sure that all required fields are present, the genes are in the right format, raw counts are used, etc. We recommend using the Preprocessor class of the scdataloader package.
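
A minimal sketch of this step, assuming the Preprocessor can be applied directly to an AnnData object (the exact arguments may differ between versions):

```python
# run scdataloader's Preprocessor on one dataset before adding it to the training database
import anndata as ad
from scdataloader import Preprocessor

adata = ad.read_h5ad("my_dataset.h5ad")  # placeholder path
preprocessor = Preprocessor()
adata = preprocessor(adata)  # checks fields, gene format, raw counts, etc.
```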

Moreover, scdataloader works with a set of ontologies. To install these, use the populate_my_ontologies function from the scdataloader package.
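
Something along these lines should work (the import path and call signature are assumptions; only the function name comes from the scdataloader documentation):

```python
# assumed import path and call signature; run once to set up the ontologies
from scdataloader.utils import populate_my_ontologies

populate_my_ontologies()  # downloads and registers the ontologies scdataloader relies on
```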

If you do not have your own database of AnnData files, we recommend the cellxgene database and our associated helper function, which downloads and preprocesses all of cellxgene in a single command with scdataloader.

Finally, you might want to generate gene embeddings to use with scPRINT instead of learning these tokens from scratch. For this you can use the gene_embedders module of scPRINT, whose usage is detailed in the notebooks/generate_gene_embeddings.ipynb notebook.
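
The resulting embeddings file is later passed to the datamodule (see the gene_embeddings entry in the example config below). As a quick, hedged sanity check, assuming the output is a parquet table with one row per gene:

```python
# inspect the generated gene embeddings (path taken from the example config below)
import pandas as pd

emb = pd.read_parquet("./data/main/gene_embeddings.parquet")
print(emb.shape)      # (n_genes, embedding_dim)
print(emb.index[:5])  # gene identifiers, e.g. Ensembl ids
```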

## Pre-training

To pre-train scPRINT we strongly recommend using the command line, since a run can take multiple days (and using an HPC platform like SLURM or similar). If you are on your own machine, at least use something like screen 😉.

Most of the pre-training usage follows PyTorch Lightning: with scprint fit you launch a training run. It will populate the datamodule (see scdataloader), the model (see model.py), the trainer (see PyTorch Lightning), and the various callbacks.

You might also want to pass additional parameters. For this, you can use the config folder and the YAML files in it. These files store the main hyperparameters of the model and the training scheme.

More hyperparameters are passed to the scPRINT model via a Trainer callback I created (see trainer/trainer.py). It is used to specify parameters that are used solely during training and are not part of the model definition itself, like the learning rate, schedulers, optimizers, etc. I use a callback because that is how PyTorch Lightning requires us to send training parameters to the model.
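
For illustration only (this is not the actual scPRINT callback in trainer/trainer.py), such a callback could hand a training-only hyperparameter to the module like this:

```python
# illustrative sketch, not the actual scPRINT training callback: a Lightning
# callback that passes a training-only hyperparameter (here a hypothetical lr)
# to the LightningModule before training starts
from lightning.pytorch.callbacks import Callback


class TrainingParams(Callback):
    def __init__(self, lr: float = 1e-4):
        self.lr = lr

    def setup(self, trainer, pl_module, stage=None):
        # the module can then read this attribute in its configure_optimizers()
        pl_module.lr = self.lr
```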

Thus, a full command line to train scPRINT on a SLURM cluster might look like this:

```bash
conda activate scprint

### slurm-level setup
module load cuda/11.7

# srun options: -p / -q select the gpu partition and queue,
# --gres requests the gpu type (here 4 A40s with 40GB of GPU memory each),
# --mem-per-gpu sets the RAM per GPU.
# scprint options: the two --config files are layered (the base config, see below,
# plus the differences when training a large model); --model.nhead 8 overrides
# that parameter of the large model directly on the command line (cannot do
# 4 heads of 128 dim with A40 GPUs...); --scprint_training.name gives the model
# an id (not needed but useful); --trainer.strategy auto will here resolve to
# "ddp_find_unused_parameters_true".
srun \
  -p gpu \
  -q gpu \
  --gres=gpu:A40:4,gmem:40G \
  --cpus-per-task 16 \
  --mem-per-gpu 90G \
  --ntasks-per-node=1 \
  scprint fit \
    --config config/base.yml \
    --config config/pretrain_large.yml \
    --model.nhead 8 \
    --scprint_training.name o2uniqsx \
    --trainer.strategy auto
```

with the base YAML file containing:

```yaml
# general params
project: scprint_scale # project name for saving data and wandb
seed_everything: 42
ckpt_path: null # no checkpoint weights since we train from scratch
set_float32_matmul_precision: True
wandblog: all # we use wandb here
log_freq: 200
log_graph: True
trainer: # training-level params
  precision: 16-mixed # we use 16-bit mixed precision for training
  gradient_clip_val: 100 # needed
  log_every_n_steps: 100
  # ...
  logger: # we can add multiple loggers (see below)
    - class_path: lightning.pytorch.loggers.WandbLogger
  callbacks: # you can create your own callback and add it here or use lightning's callbacks
    - class_path: lightning.pytorch.callbacks.StochasticWeightAveraging
      init_args:
        swa_lrs: 0.03
    # ...
model: # model params
  dropout: 0.1
  transformer: flash # flashattention is used
  mvc_decoder: inner product
  residual_in_fp32: True
  # ...
data: # datamodule params
  organisms: # we will use these 2 species
    - NCBITaxon:9606
    - NCBITaxon:10090
  gene_position_tolerance: 10_000 # gene location: genes closer than 10kb are considered the same location
  gene_embeddings: ./data/main/gene_embeddings.parquet # the gene embeddings (see above)
  collection_name: all no zhang13M # the name of the laminDB collection we will use
  how: random expr # how we collate the expression data (here random expressed genes)
  max_len: 2200 # how many genes we use in the model context during training
  weight_scaler: 50 # how we scale the weighted random sampling procedure (see our manuscript)
  # ...
```

We use wandb in our case, and our previous wandb training runs are available here; however, scPRINT and PyTorch Lightning support a breadth of logging tools: see loggers.

We use SLURM in our use case here, but scPRINT and PyTorch Lightning have been made to work in a breadth of environments (see the PyTorch Lightning documentation for examples).

## Fine-tuning

For now, scPRINT doesn't have a fine-tuning script, but PRs on using LoRA and its alternatives to fine-tune scPRINT on novel tasks are very welcome!