This repository contains the official implementation of the paper "Simulation-Free Training of Neural ODEs on Paired Data" (NeurIPS 2024).
Semin Kim*, Jaehoon Yoo*, Jinwoo Kim, Yeonwoo Cha, Saehoon Kim, Seunghoon Hong
To set up the environment, start by installing the dependencies listed in `requirements.txt`. You can also use Docker to streamline the setup.
- Docker Setup:
docker pull pytorch/pytorch:2.1.0-cuda11.8-cudnn8-devel
docker run -it pytorch/pytorch:2.1.0-cuda11.8-cudnn8-devel bash
- Clone the Repository:
git clone https://github.com/seminkim/simulation-free-node.git
cd simulation-free-node
- Install Requirements:
pip install -r requirements.txt
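The `docker run` command above does not explicitly request GPU access or mount any host files. A minimal sketch of a variant that passes all GPUs through and mounts a host directory (for example, a clone made on the host) at `/workspace`, assuming the NVIDIA Container Toolkit is installed on the host. The `start_container` function and the `DRY_RUN` switch are hypothetical helpers, not part of the repository; with the default `DRY_RUN=1` the script only prints the command.

```sh
#!/bin/sh
# Hypothetical helper (not part of the repository): start the PyTorch image
# with GPU access and a host directory mounted at /workspace.
# Assumes the NVIDIA Container Toolkit is installed on the host.
IMAGE=pytorch/pytorch:2.1.0-cuda11.8-cudnn8-devel
DRY_RUN=${DRY_RUN:-1}   # 1 = print the command only, 0 = execute it

# Print the command in dry-run mode, otherwise execute it.
run() {
  if [ "$DRY_RUN" = "1" ]; then
    echo "$*"
  else
    "$@"
  fi
}

# $1 = host directory to mount, e.g. the simulation-free-node checkout.
start_container() {
  run docker run --gpus all -it --rm \
    -v "$1":/workspace -w /workspace \
    "$IMAGE" bash
}

start_container "$PWD"
```

Run it from the cloned `simulation-free-node` directory with `DRY_RUN=0` to actually start the container, then install the requirements inside it.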
Place all datasets in the `.data` directory. By default, the code automatically downloads the MNIST, CIFAR-10, and SVHN datasets into `.data`.
The UCI dataset, comprising 10 tasks (`bostonHousing`, `concrete`, `energy`, `kin8nm`, `naval-propulsion-plant`, `power-plant`, `protein-tertiary-structure`, `wine-quality-red`, `yacht`, and `YearPredictionMSD`), is not downloaded automatically; download it manually by following the Usage section of the CARD repository.
Scripts for training are available for both classification and regression tasks.
To train a model for a classification task, run:
python main.py fit --config configs/{dataset_name}.yaml --name {exp_name}
For regression tasks (only supported with UCI datasets), use the following command:
python main.py fit --config configs/uci.yaml --name {exp_name} --data.task {task_name} --data.split_num {split_num}
Replace `{task_name}` with one of the UCI task names listed above and `{split_num}` with the index of the data split.
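To train on every UCI task over several splits, the command above can be wrapped in a loop. This is a minimal sketch: the `uci_sweep` function, the `uci_<task>_<split>` run names, and the assumption that splits are numbered from 0 to `NUM_SPLITS - 1` are hypothetical, so check the split indices of your copy of the UCI data first. By default it only prints the commands (`DRY_RUN=1`).

```sh
#!/bin/sh
# Hypothetical sweep over all UCI tasks and splits.
# Assumption (not from the repository): split indices run 0..NUM_SPLITS-1.
NUM_SPLITS=${NUM_SPLITS:-1}   # number of splits prepared per task
DRY_RUN=${DRY_RUN:-1}         # 1 = print commands only, 0 = run them

UCI_TASKS="bostonHousing concrete energy kin8nm naval-propulsion-plant power-plant protein-tertiary-structure wine-quality-red yacht YearPredictionMSD"

# Print the command in dry-run mode, otherwise execute it.
run() {
  if [ "$DRY_RUN" = "1" ]; then
    echo "$*"
  else
    "$@"
  fi
}

uci_sweep() {
  for task in $UCI_TASKS; do
    split=0
    while [ "$split" -lt "$NUM_SPLITS" ]; do
      run python main.py fit --config configs/uci.yaml \
        --name "uci_${task}_${split}" \
        --data.task "$task" --data.split_num "$split"
      split=$((split + 1))
    done
  done
}

uci_sweep
```

The runs execute sequentially; launching them in parallel or on a job scheduler is left to the user.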
Use the following commands for model evaluation.
python main.py validate --config configs/{dataset_name}.yaml --name {exp_name} --ckpt_path {ckpt_path}
For UCI regression tasks:
python main.py validate --config configs/uci.yaml --name {exp_name} --data.task {task_name} --data.split_num {split_num} --ckpt_path {ckpt_path}
Trained checkpoints are available in the Releases tab of this repository.
| Dataset | Accuracy (Dopri solver) | Link |
|---|---|---|
| MNIST | 99.30% | Download |
| SVHN | 96.12% | Download |
| CIFAR10 | 88.89% | Download |
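A sketch for evaluating the released classification checkpoints in one pass with the `validate` command, assuming they were saved to a hypothetical `checkpoints/` directory as `mnist.ckpt`, `svhn.ckpt`, and `cifar10.ckpt`. These paths, the `eval_released` function, and the `<dataset>_eval` run names are illustrative; the actual file names on the Releases page may differ.

```sh
#!/bin/sh
# Hypothetical helper: evaluate the released classification checkpoints.
# Assumes each file was saved as $CKPT_DIR/<dataset>.ckpt (illustrative names).
CKPT_DIR=${CKPT_DIR:-checkpoints}
DRY_RUN=${DRY_RUN:-1}   # 1 = print commands only, 0 = run them

# Print the command in dry-run mode, otherwise execute it.
run() {
  if [ "$DRY_RUN" = "1" ]; then
    echo "$*"
  else
    "$@"
  fi
}

eval_released() {
  for dataset in mnist svhn cifar10; do
    ckpt="$CKPT_DIR/$dataset.ckpt"
    # Outside dry-run mode, skip datasets whose checkpoint file is missing.
    if [ "$DRY_RUN" != "1" ] && [ ! -f "$ckpt" ]; then
      echo "skipping $dataset: $ckpt not found" >&2
      continue
    fi
    run python main.py validate --config "configs/$dataset.yaml" \
      --name "${dataset}_eval" --ckpt_path "$ckpt"
  done
}

eval_released
```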
We use wandb to monitor training progress and inference results.
The wandb run name matches the value passed to `--name`.
You can change the project name by modifying `trainer.logger.init_args.project` in the configuration file (default: `SFNO_exp`).
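Because the code is built on LightningCLI, the project should also be settable per run from the command line, using the same dotted key instead of editing the YAML file. A minimal sketch under that assumption; the `train_with_project` function and the project name `my_sfno_project` are hypothetical.

```sh
#!/bin/sh
# Hypothetical helper: train MNIST while logging to a custom wandb project.
# Relies on LightningCLI accepting dotted command-line overrides of config keys.
DRY_RUN=${DRY_RUN:-1}   # 1 = print the command only, 0 = run it

# Print the command in dry-run mode, otherwise execute it.
run() {
  if [ "$DRY_RUN" = "1" ]; then
    echo "$*"
  else
    "$@"
  fi
}

# $1 = wandb project name, $2 = experiment (and wandb run) name
train_with_project() {
  run python main.py fit --config configs/mnist.yaml --name "$2" \
    --trainer.logger.init_args.project "$1"
}

train_with_project my_sfno_project mnist_custom_project
```

To log without a network connection, wandb also honors the `WANDB_MODE=offline` environment variable; export it before running with `DRY_RUN=0`.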
Our code is implemented with `LightningCLI`, so you can override config values via command-line arguments to experiment with different settings.
Examples:
# Run MNIST experiment with batch size 128
python main.py fit --config configs/mnist.yaml --name mnist_b128 --data.batch_size 128
# Run SVHN experiment, explicitly sampling t=0 with probability 0.01
python main.py fit --config configs/svhn.yaml --name svhn_zero_001 --model.init_args.force_zero_prob 0.01
# Run CIFAR10 experiment with 'concave' dynamics
python main.py fit --config configs/cifar10.yaml --name cifar10_concave --model.init_args.dynamics concave
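The same overrides compose naturally into a small hyperparameter sweep. A sketch that repeats the SVHN example for several values of `force_zero_prob`; the values `0.01 0.05 0.1` and the `sweep_force_zero_prob` function are hypothetical, and run names follow the `svhn_zero_001` pattern used above.

```sh
#!/bin/sh
# Hypothetical sweep over the explicit t=0 sampling probability on SVHN.
# The probability values are illustrative, not recommended settings.
DRY_RUN=${DRY_RUN:-1}   # 1 = print commands only, 0 = run them
ZERO_PROBS=${ZERO_PROBS:-"0.01 0.05 0.1"}

# Print the command in dry-run mode, otherwise execute it.
run() {
  if [ "$DRY_RUN" = "1" ]; then
    echo "$*"
  else
    "$@"
  fi
}

sweep_force_zero_prob() {
  for p in $ZERO_PROBS; do
    # Drop the decimal point for the run name, e.g. 0.01 -> svhn_zero_001.
    tag=$(printf '%s' "$p" | tr -d '.')
    run python main.py fit --config configs/svhn.yaml \
      --name "svhn_zero_$tag" \
      --model.init_args.force_zero_prob "$p"
  done
}

sweep_force_zero_prob
```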
Refer to the Lightning Trainer documentation for trainer-related settings (e.g., the number of training steps or the logging frequency).
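For instance, the standard Lightning `Trainer` arguments `max_steps` and `log_every_n_steps` can be passed under the `trainer.` prefix. A minimal sketch with illustrative values (100000 steps, logging every 50 steps) and a hypothetical `fit_with_trainer_overrides` helper:

```sh
#!/bin/sh
# Hypothetical helper: launch training with Trainer-level overrides.
# max_steps and log_every_n_steps are standard Lightning Trainer arguments;
# the default values below are illustrative only.
DRY_RUN=${DRY_RUN:-1}   # 1 = print the command only, 0 = run it
MAX_STEPS=${MAX_STEPS:-100000}
LOG_EVERY=${LOG_EVERY:-50}

# Print the command in dry-run mode, otherwise execute it.
run() {
  if [ "$DRY_RUN" = "1" ]; then
    echo "$*"
  else
    "$@"
  fi
}

# $1 = dataset config name (mnist, svhn, cifar10), $2 = experiment name
fit_with_trainer_overrides() {
  run python main.py fit --config "configs/$1.yaml" --name "$2" \
    --trainer.max_steps "$MAX_STEPS" \
    --trainer.log_every_n_steps "$LOG_EVERY"
}

fit_with_trainer_overrides cifar10 cifar10_long
```

If the config file also sets `max_epochs`, Lightning stops training at whichever limit is reached first.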
This implementation is based on the following repositories: NeuralODE, ANODE, and CARD.
@article{kim2024simfreenode,
title={Simulation-Free Training of Neural ODEs on Paired Data},
author={Semin Kim and
Jaehoon Yoo and
Jinwoo Kim and
Yeonwoo Cha and
Saehoon Kim and
Seunghoon Hong},
journal={arXiv preprint arXiv:2410.22918},
year={2024},
url={https://arxiv.org/abs/2410.22918},
}