This example implements training of discrete flow matching models on text data and provides the tools and scripts needed to train and evaluate them.
Note: this example was tested only with PyTorch 2.5 on a single H100 node (8 GPUs). With this setup, we achieved approximately 380k training steps in 24 hours.
To get started with this project, follow these steps to set up your environment:
```
conda env create -f environment.yml
conda activate discrete_flow_matching
```
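If you want to confirm that your setup matches the tested configuration, a quick check like the following can help (a minimal sketch, not part of the repo; the version and GPU count are the tested values from the note above, not hard requirements):

```python
import torch

# This example was tested with PyTorch 2.5 on a single 8-GPU H100 node.
print(torch.__version__)          # expect 2.5.x
print(torch.cuda.is_available())  # expect True on a GPU node
print(torch.cuda.device_count())  # expect 8 on a single H100 node
```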
Specify the data cache and checkpoint directories. Data will automatically be downloaded into the cache directory.
```
CACHE_DIR=...
HYDRA_RUN_DIR=...
```
To train a discrete flow matching model on FineWeb-Edu, run:

```
python run_train.py data.cache_dir=${CACHE_DIR}
```
To use SLURM, modify the `slurm` config according to the cluster you are working on, and run (the `-m` flag enables Hydra multirun, which dispatches the job through the configured launcher):

```
python run_train.py data.cache_dir=${CACHE_DIR} hydra_dir=${HYDRA_RUN_DIR} -m &
```
We trained models with a linear scheduler (`PolynomialConvexScheduler(n=1.0)`) for one million steps on FineWeb-Edu.
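For reference, this is roughly how the linear schedule plugs into a discrete probability path (a minimal sketch assuming the `flow_matching` package's `PolynomialConvexScheduler` and `MixtureDiscreteProbPath` API; the vocabulary size, mask token, and tensors below are toy placeholders, not the training configuration):

```python
import torch
from flow_matching.path import MixtureDiscreteProbPath
from flow_matching.path.scheduler import PolynomialConvexScheduler

vocab_size, mask_token = 100, 99              # toy vocabulary (assumption)
scheduler = PolynomialConvexScheduler(n=1.0)  # kappa(t) = t^n; n=1.0 is linear
path = MixtureDiscreteProbPath(scheduler=scheduler)

x_1 = torch.randint(0, vocab_size - 1, (4, 16))  # stand-in target tokens
x_0 = torch.full_like(x_1, mask_token)           # mask source distribution
t = torch.rand(4)
sample = path.sample(t=t, x_0=x_0, x_1=x_1)      # x_t interpolates x_0 -> x_1
print(sample.x_t.shape)
```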
To evaluate a trained model, run:

```
PYTHONPATH="." python scripts/run_eval.py --work_dir "/path/to/exp/folder" --ngpus 8 --eval_elbo --eval_perplexity
```
| Scheduler | Source distribution | Loss | Generative perplexity | ELBO |
|---|---|---|---|---|
| Linear | Mask | Cross-entropy | 128.9 | 53.2 |
| Linear | Mask | Generalized KL | 132.2 | 47.9 |
| Linear | Uniform | Cross-entropy | 90.9 | 71.7 |
| Linear | Uniform | Generalized KL | 82.1 | 71.3 |
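As a rough illustration of the generative perplexity metric (a hedged sketch, not the repo's evaluation code): sampled text is typically scored under a pretrained judge model, and perplexity is the exponential of its mean token NLL. GPT-2 as the judge is an assumption here; `scripts/run_eval.py` may use a different model.

```python
import torch
from transformers import GPT2LMHeadModel, GPT2TokenizerFast

# Judge model (assumption: GPT-2; the actual evaluation may differ).
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2").eval()

sample = "Text sampled from the trained flow matching model goes here."
ids = tokenizer(sample, return_tensors="pt").input_ids
with torch.no_grad():
    nll = model(ids, labels=ids).loss  # mean next-token NLL under the judge
print(torch.exp(nll).item())           # generative perplexity of the sample
```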
```
.
├── configs          # Train configs
│   └── ...
├── data             # Data loading and preprocessing
│   └── ...
├── logic            # Logic components, such as flow-related classes
│   └── ...
├── model            # Transformer implementation
│   └── ...
├── scripts          # Evaluation script
│   └── ...
├── utils            # Utility functions
│   └── ...
├── README.md
├── environment.yml
├── train.py
└── run_train.py     # Run training script
```
This repository implements the following papers:
- Discrete Flow Matching
- Flow Matching with General Discrete Paths: A Kinetic-Optimal Perspective
- Generative Flows on Discrete State-Spaces: Enabling Multimodal Flows with Applications to Protein Co-Design
- Simplified and Generalized Masked Diffusion for Discrete Data
This example partially uses code from:
- FlashAttention
- Discrete Diffusion Modeling by Estimating the Ratios of the Data Distribution
- GLIDE: Towards Photorealistic Image Generation and Editing with Text-Guided Diffusion Models
- TorchData
The majority of the code in this example is licensed under CC-BY-NC; however, portions of the project are available under separate license terms:
- FlashAttention and TorchData are licensed under the BSD 3-Clause license.
- Discrete Diffusion Modeling by Estimating the Ratios of the Data Distribution and GLIDE: Towards Photorealistic Image Generation and Editing with Text-Guided Diffusion Models are licensed under the MIT license.