Commit 169a190 (0 parents). Showing 16 changed files with 2,011 additions and 0 deletions.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
__pycache__ | ||
*.pyc | ||
.data | ||
.vscode | ||
|
||
logs | ||
lightning_logs | ||
ckpts | ||
wandb | ||
|
||
|
||
UCI_Datasets |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,102 @@ | ||
# Simulation-Free Training of Neural ODEs on Paired Data | ||
This repository contains the official implementation of the paper **"Simulation-Free Training of Neural ODEs on Paired Data (NeurIPS 2024)"** | ||
|
||
Semin Kim*, Jaehoon Yoo*, Jinwoo Kim, Yeonwoo Cha, Saehoon Kim, Seunghoon Hong | ||
|
||
[Paper Link (TODO)](TODO) | ||
|
||
## Setup | ||
To set up the environment, start by installing dependencies listed in `requirements.txt`. You can also use Docker to streamline the setup process. | ||
|
||
1. **Docker Setup:** | ||
``` | ||
docker pull pytorch/pytorch:2.1.0-cuda11.8-cudnn8-devel | ||
docker run -it pytorch/pytorch:2.1.0-cuda11.8-cudnn8-devel bash | ||
``` | ||
|
||
2. **Clone the Repository:** | ||
``` | ||
git clone https://github.com/seminkim/simulation-free-node.git | ||
``` | ||
3. **Install Requirements:** | ||
``` | ||
pip install -r requirements.txt | ||
``` | ||
## Datasets | ||
Place all datasets in the `.data` directory. By default, this code automatically downloads the MNIST, CIFAR-10, and SVHN datasets into the `.data` directory. | ||
|
||
|
||
The UCI benchmark, which comprises 10 tasks (`bostonHousing`, `concrete`, `energy`, `kin8nm`, `naval-propulsion-plant`, `power-plant`, `protein-tertiary-structure`, `wine-quality-red`, `yacht`, and `YearPredictionMSD`), can be downloaded manually by following the **Usage** section of the [CARD](https://github.com/XzwHan/CARD) repository. | ||
|
||
## Training | ||
Scripts for training are available for both classification and regression tasks. | ||
|
||
### Classification | ||
To train a model for a classification task, run: | ||
``` | ||
python main.py fit --config configs/{dataset_name}.yaml --name {exp_name} | ||
``` | ||
|
||
### Regression | ||
For regression tasks (only supported with UCI datasets), use the following command: | ||
|
||
``` | ||
python main.py fit --config configs/uci.yaml --name {exp_name} --data.task {task_name} --data.split_num {split_num} | ||
``` | ||
In this command, specify the UCI task name and the data split number accordingly. | ||
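
To cover several UCI tasks and splits, the command above can be generated programmatically. The following is a minimal sketch and not part of the repository: the helper name `build_fit_command` and the experiment naming scheme are assumptions made for illustration; only the flags and task names are taken from this README.

```python
import shlex

# Task names as listed in the Datasets section of this README.
UCI_TASKS = [
    "bostonHousing", "concrete", "energy", "kin8nm",
    "naval-propulsion-plant", "power-plant",
    "protein-tertiary-structure", "wine-quality-red",
    "yacht", "YearPredictionMSD",
]


def build_fit_command(task: str, split_num: int) -> list[str]:
    """Return the argv list for one UCI regression run (hypothetical helper)."""
    if task not in UCI_TASKS:
        raise ValueError(f"unknown UCI task: {task!r}")
    if split_num < 0:
        raise ValueError("split_num must be non-negative")
    exp_name = f"uci_{task}_split{split_num}"  # naming scheme is an assumption
    return [
        "python", "main.py", "fit",
        "--config", "configs/uci.yaml",
        "--name", exp_name,
        "--data.task", task,
        "--data.split_num", str(split_num),
    ]


if __name__ == "__main__":
    # Print the commands for the first two splits of one task. Each list can
    # also be passed to subprocess.run(cmd, check=True) to launch the run.
    for split in range(2):
        print(shlex.join(build_fit_command("yacht", split)))
```

Returning an argument list rather than a single string avoids shell-quoting issues if the commands are later launched with `subprocess`.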
|
||
## Inference | ||
Use the following commands for model evaluation. | ||
### Classification | ||
``` | ||
python main.py validate --config configs/{dataset_name}.yaml --name {exp_name} --ckpt_path {ckpt_path} | ||
``` | ||
|
||
### Regression | ||
For UCI regression tasks: | ||
|
||
``` | ||
python main.py validate --config configs/uci.yaml --name {exp_name} --data.task {task_name} --data.split_num {split_num} --ckpt_path {ckpt_path} | ||
``` | ||
|
||
### Checkpoints | ||
Trained checkpoints are available in the Releases tab of this repository. | ||
|
||
|Dataset |Dopri Acc. |Link | | ||
|:---: |:---: |:---: | | ||
|MNIST |99.30% |[TODO]()| | ||
|SVHN |96.12% |[TODO]()| | ||
|CIFAR10 |88.89% |[TODO]()| | ||
|
||
|
||
## Additional Notes | ||
### Logging | ||
We use wandb to monitor training progress and inference results. | ||
The wandb run name will match the argument provided for `--name`. | ||
You can also change the project name by modifying `trainer.logger.init_args.project` in the configuration file (default value is `SFNO_exp`). | ||
|
||
### Running Your Own Experiment | ||
Our code is implemented with `LightningCLI`, so you can override any configuration value from the command line to experiment with different settings. | ||
|
||
Examples: | ||
``` | ||
# Run MNIST experiment with batch size 128 | ||
python main.py fit --config configs/mnist.yaml --name mnist_b128 --data.batch_size 128 | ||
# Run SVHN experiment with explicit sampling of t=0 with probability 0.01 | ||
python main.py fit --config configs/svhn.yaml --name svhn_zero_001 --model.init_args.force_zero_prob 0.01 | ||
# Run CIFAR10 experiment with 'concave' dynamics | ||
python main.py fit --config configs/cifar10.yaml --name cifar10_concave --model.init_args.dynamics concave | ||
``` | ||
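
The `force_zero_prob` option used in the SVHN example can be read as mixing a point mass at t=0 into the sampling of the time variable. A minimal sketch of that reading follows. The function name `sample_t` and the uniform distribution on [0, 1) for the remaining draws are assumptions made for illustration, and the repository's actual sampler may differ.

```python
import random


def sample_t(force_zero_prob: float, rng: random.Random) -> float:
    """Draw t, returning exactly 0.0 with probability force_zero_prob.

    Assumption: all other draws are uniform on [0, 1).
    """
    if not 0.0 <= force_zero_prob <= 1.0:
        raise ValueError("force_zero_prob must lie in [0, 1]")
    if rng.random() < force_zero_prob:
        return 0.0
    return rng.random()


if __name__ == "__main__":
    rng = random.Random(0)
    draws = [sample_t(0.1, rng) for _ in range(10_000)]
    zero_fraction = sum(t == 0.0 for t in draws) / len(draws)
    # The empirical fraction should be near force_zero_prob.
    print(f"fraction of t == 0: {zero_fraction:.3f}")
```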
Refer to [Lightning Trainer documentation](https://lightning.ai/docs/pytorch/stable/common/trainer.html) for controlling trainer-related configurations (e.g., training steps or logging frequency). | ||
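
Conceptually, each dotted command-line argument addresses one entry of the nested YAML configuration. The sketch below illustrates that mapping with plain dictionaries. `apply_override` is a hypothetical helper written for this illustration; it is not how `LightningCLI` is implemented internally.

```python
from copy import deepcopy
from typing import Any


def apply_override(config: dict, dotted_key: str, value: Any) -> dict:
    """Return a copy of config with the entry at dotted_key set to value."""
    result = deepcopy(config)
    node = result
    *parents, leaf = dotted_key.split(".")
    for key in parents:
        # Create intermediate sections that the base config does not define.
        node = node.setdefault(key, {})
        if not isinstance(node, dict):
            raise TypeError(f"{key!r} in {dotted_key!r} is not a section")
    node[leaf] = value
    return result


if __name__ == "__main__":
    # A small excerpt of an MNIST-style config, written as a dictionary.
    base = {
        "model": {"init_args": {"dynamics": "linear", "force_zero_prob": 0.1}},
        "data": {"dataset": "mnist", "batch_size": 1024},
    }
    updated = apply_override(base, "data.batch_size", 128)
    updated = apply_override(updated, "model.init_args.dynamics", "concave")
    print(updated["data"]["batch_size"])              # 128
    print(updated["model"]["init_args"]["dynamics"])  # concave
    print(base["data"]["batch_size"])                 # 1024 (base is unchanged)
```

Working on a deep copy keeps the base configuration intact, which mirrors the fact that command-line overrides do not modify the YAML file on disk.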
|
||
## Acknowledgements | ||
This implementation is based on the following repositories: [NeuralODE](https://github.com/rtqichen/torchdiffeq), [ANODE](https://github.com/EmilienDupont/augmented-neural-odes), and [CARD](https://github.com/XzwHan/CARD). | ||
|
||
|
||
|
||
## Citation | ||
``` | ||
(TODO) | ||
``` |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
name: CIFAR10 | ||
|
||
model: | ||
  class_path: models.conv_model.ConvModel | ||
  init_args: | ||
    data_dim: 3 | ||
    emb_res: | ||
      - 7 | ||
      - 7 | ||
    latent_dim: 256 | ||
    hidden_dim: 256 | ||
    in_latent_dim: 64 | ||
    h_add_blocks: 4 | ||
    f_add_blocks: 4 | ||
    g_add_blocks: 0 | ||
    num_classes: 10 | ||
|
||
    method: ours | ||
    force_zero_prob: 0.1 | ||
    metric_type: accuracy | ||
    label_scaler: null | ||
|
||
    scheduler: cos | ||
    lr: 3e-4 | ||
    wd: 0.0 | ||
    task_criterion: ce | ||
    dynamics: linear | ||
    adjoint: false | ||
    label_ae_noise: 10.0 | ||
|
||
trainer: | ||
  val_check_interval: 1960 | ||
  check_val_every_n_epoch: null | ||
  max_steps: 100000 | ||
  log_every_n_steps: 1 | ||
  gradient_clip_val: 0 | ||
  logger: | ||
    class_path: lightning.pytorch.loggers.WandbLogger | ||
    init_args: | ||
      project: SFNO_exp | ||
      log_model: false | ||
      save_dir: ./logs | ||
  callbacks: | ||
    - class_path: lightning.pytorch.callbacks.LearningRateMonitor | ||
      init_args: | ||
        logging_interval: step | ||
    - class_path: lightning.pytorch.callbacks.ModelCheckpoint # log best | ||
      init_args: | ||
        save_last: true | ||
        monitor: 'val/accuracy_dopri' | ||
        save_top_k: 1 | ||
        mode: max | ||
        dirpath: null | ||
    - class_path: lightning.pytorch.callbacks.ModelCheckpoint # log every 24 hours | ||
      init_args: | ||
        save_top_k: -1 | ||
        dirpath: null | ||
        train_time_interval: '24:0:0' | ||
    - class_path: lightning.pytorch.callbacks.RichModelSummary | ||
      init_args: | ||
        max_depth: 10 | ||
data: | ||
  dataset: cifar10 | ||
  batch_size: 1024 | ||
  test_batch_size: 768 | ||
  task_type: classification | ||
|
||
|
||
seed_everything: 0 |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
name: MNIST | ||
|
||
model: | ||
  class_path: models.mlp_model.MLPModel | ||
  init_args: | ||
    data_dim: 784 | ||
    hidden_dim: 2048 | ||
    f_add_blocks: 1 | ||
    h_add_blocks: 0 | ||
    g_add_blocks: 0 | ||
    in_proj: mlp | ||
    out_proj: mlp | ||
    proj_norm: bn | ||
    output_dim: 10 | ||
|
||
    method: ours | ||
    force_zero_prob: 0.1 | ||
    metric_type: accuracy | ||
    label_scaler: none | ||
|
||
    scheduler: none | ||
    lr: 1e-4 | ||
    wd: 0.0 | ||
    task_criterion: ce | ||
    dynamics: linear | ||
    adjoint: false | ||
    label_ae_noise: 3.0 | ||
    total_steps: 100000 | ||
|
||
|
||
trainer: | ||
  check_val_every_n_epoch: null | ||
  max_steps: 500000 | ||
  log_every_n_steps: 1 | ||
  gradient_clip_val: 0 | ||
  logger: | ||
    class_path: lightning.pytorch.loggers.WandbLogger | ||
    init_args: | ||
      project: SFNO_exp | ||
      log_model: false | ||
      save_dir: ./logs | ||
  callbacks: | ||
    - class_path: lightning.pytorch.callbacks.LearningRateMonitor | ||
      init_args: | ||
        logging_interval: step | ||
    - class_path: lightning.pytorch.callbacks.ModelCheckpoint | ||
      init_args: | ||
        monitor: 'val/accuracy_dopri' | ||
        save_top_k: 1 | ||
        mode: max | ||
        dirpath: null | ||
        train_time_interval: null | ||
|
||
data: | ||
  dataset: mnist | ||
  batch_size: 1024 | ||
  test_batch_size: 768 | ||
  task_type: classification | ||
|
||
seed_everything: 0 |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
name: SVHN | ||
|
||
model: | ||
  class_path: models.conv_model.ConvModel | ||
  init_args: | ||
    data_dim: 3 | ||
    emb_res: | ||
      - 7 | ||
      - 7 | ||
    latent_dim: 256 | ||
    hidden_dim: 256 | ||
    in_latent_dim: 64 | ||
    h_add_blocks: 4 | ||
    f_add_blocks: 4 | ||
    g_add_blocks: 0 | ||
    num_classes: 10 | ||
|
||
    method: ours | ||
    force_zero_prob: 0.1 | ||
    metric_type: accuracy | ||
    label_scaler: null | ||
|
||
    scheduler: cos | ||
    lr: 3e-4 | ||
    wd: 0.0 | ||
    task_criterion: ce | ||
    dynamics: linear | ||
    adjoint: false | ||
    label_ae_noise: 7.0 | ||
|
||
trainer: | ||
  val_check_interval: 1960 | ||
  check_val_every_n_epoch: null | ||
  max_steps: 100000 | ||
  log_every_n_steps: 1 | ||
  gradient_clip_val: 0 | ||
  logger: | ||
    class_path: lightning.pytorch.loggers.WandbLogger | ||
    init_args: | ||
      project: SFNO_exp | ||
      log_model: false | ||
      save_dir: ./logs | ||
  callbacks: | ||
    - class_path: lightning.pytorch.callbacks.LearningRateMonitor | ||
      init_args: | ||
        logging_interval: step | ||
    - class_path: lightning.pytorch.callbacks.ModelCheckpoint # log best | ||
      init_args: | ||
        save_last: true | ||
        monitor: 'val/accuracy_dopri' | ||
        save_top_k: 1 | ||
        mode: max | ||
        dirpath: null | ||
    - class_path: lightning.pytorch.callbacks.ModelCheckpoint # log every 24 hours | ||
      init_args: | ||
        save_top_k: -1 | ||
        dirpath: null | ||
        train_time_interval: '24:0:0' | ||
    - class_path: lightning.pytorch.callbacks.RichModelSummary | ||
      init_args: | ||
        max_depth: 10 | ||
data: | ||
  dataset: svhn | ||
  batch_size: 1024 | ||
  test_batch_size: 768 | ||
  task_type: classification | ||
|
||
|
||
seed_everything: 0 |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
name: UCI | ||
|
||
model: | ||
  class_path: models.mlp_model.MLPModel | ||
  init_args: | ||
    data_dim: 13 | ||
    hidden_dim: 64 | ||
    latent_dim: 64 | ||
    f_add_blocks: 0 | ||
    h_add_blocks: 0 | ||
    g_add_blocks: 0 | ||
    in_proj: mlp | ||
    out_proj: linear | ||
|
||
    method: ours | ||
    force_zero_prob: 0.1 | ||
    metric_type: rmse | ||
    label_scaler: true | ||
    scheduler: 'none' | ||
|
||
    lr: 0.003 | ||
    wd: 0.0 | ||
    task_criterion: ce | ||
    dynamics: linear | ||
    adjoint: false | ||
    label_ae_noise: 3.0 | ||
|
||
|
||
trainer: | ||
  check_val_every_n_epoch: null | ||
  max_steps: 500000 | ||
  log_every_n_steps: 1 | ||
  gradient_clip_val: 0 | ||
  logger: | ||
    class_path: lightning.pytorch.loggers.WandbLogger | ||
    init_args: | ||
      project: SFNO_exp | ||
      log_model: false | ||
      save_dir: ./logs | ||
|
||
  callbacks: | ||
    - class_path: lightning.pytorch.callbacks.LearningRateMonitor | ||
      init_args: | ||
        logging_interval: step | ||
    - class_path: lightning.pytorch.callbacks.ModelCheckpoint # log best | ||
      init_args: | ||
        save_last: true | ||
        monitor: 'val/rmse_dopri' | ||
        save_top_k: 1 | ||
        mode: min | ||
        dirpath: null | ||
    - class_path: lightning.pytorch.callbacks.ModelCheckpoint # log every 24 hours | ||
      init_args: | ||
        save_top_k: -1 | ||
        dirpath: null | ||
        train_time_interval: '24:0:0' | ||
    - class_path: lightning.pytorch.callbacks.EarlyStopping | ||
      init_args: | ||
        patience: 100 | ||
        monitor: 'val/rmse_dopri' | ||
        mode: min | ||
|
||
|
||
data: | ||
  dataset: uci | ||
  batch_size: 64 | ||
  test_batch_size: 64 | ||
  val_perc: 0.001 | ||
  task_type: regression | ||
  task: bostonHousing | ||
  split_num: 0 | ||
|
||
seed_everything: 0 |