Commit 169a190: initial commit
SeminKim committed Oct 30, 2024 (0 parents)
Showing 16 changed files with 2,011 additions and 0 deletions.
12 changes: 12 additions & 0 deletions .gitignore
@@ -0,0 +1,12 @@
__pycache__
*.pyc
.data
.vscode

logs
lightning_logs
ckpts
wandb


UCI_Datasets
102 changes: 102 additions & 0 deletions README.md
@@ -0,0 +1,102 @@
# Simulation-Free Training of Neural ODEs on Paired Data
This repository contains the official implementation of the NeurIPS 2024 paper **"Simulation-Free Training of Neural ODEs on Paired Data"**.

Semin Kim*, Jaehoon Yoo*, Jinwoo Kim, Yeonwoo Cha, Saehoon Kim, Seunghoon Hong

[Paper Link (TODO)](TODO)

## Setup
To set up the environment, start by installing dependencies listed in `requirements.txt`. You can also use Docker to streamline the setup process.

1. **Docker Setup** (a GPU-enabled variant is sketched after this list):
```
docker pull pytorch/pytorch:2.1.0-cuda11.8-cudnn8-devel
docker run -it pytorch/pytorch:2.1.0-cuda11.8-cudnn8-devel bash
```

2. **Clone the Repository:**
```
git clone https://github.com/seminkim/simulation-free-node.git
```
3. **Install Requirements:**
```
pip install -r requirements.txt
```
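Note that the `docker run` command in step 1 starts a bare container. Below is a minimal sketch of a GPU-enabled invocation, assuming the NVIDIA Container Toolkit is installed on the host and the repository has been cloned into the current directory; all flags are standard Docker options, not specific to this project:
```
docker run --gpus all -it \
    -v "$(pwd)/simulation-free-node":/workspace/simulation-free-node \
    -w /workspace/simulation-free-node \
    pytorch/pytorch:2.1.0-cuda11.8-cudnn8-devel bash
```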
## Datasets
Place all datasets in the `.data` directory. By default, this code automatically downloads the MNIST, CIFAR-10, and SVHN datasets into the `.data` directory.


The UCI benchmark comprises 10 tasks (`bostonHousing`, `concrete`, `energy`, `kin8nm`, `naval-propulsion-plant`, `power-plant`, `protein-tertiary-structure`, `wine-quality-red`, `yacht`, and `YearPredictionMSD`); it can be downloaded manually by following the **Usage** section of the [CARD](https://github.com/XzwHan/CARD) repository.
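The repository's `.gitignore` lists a top-level `UCI_Datasets` directory, so a plausible layout after downloading is the following sketch (hypothetical; the exact structure depends on what the CARD download extracts):
```
simulation-free-node/
├── .data/                  # MNIST, CIFAR-10, SVHN (auto-downloaded)
└── UCI_Datasets/
    ├── bostonHousing/
    ├── concrete/
    └── ...                 # remaining UCI tasks
```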

## Training
Scripts for training are available for both classification and regression tasks.

### Classification
To train a model for a classification task, run:
```
python main.py fit --config configs/{dataset_name}.yaml --name {exp_name}
```

### Regression
For regression tasks (supported only for the UCI datasets), use the following command:

```
python main.py fit --config configs/uci.yaml --name {exp_name} --data.task {task_name} --data.split_num {split_num}
```
Here, `{task_name}` is one of the UCI task names listed above, and `{split_num}` selects the data split.
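For instance, to train on the `bostonHousing` task with split 0 (the experiment name here is an arbitrary illustration):
```
python main.py fit --config configs/uci.yaml --name uci_boston_0 --data.task bostonHousing --data.split_num 0
```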

## Inference
Use the following commands for model evaluation.
### Classification
```
python main.py validate --config configs/{dataset_name}.yaml --name {exp_name} --ckpt_path {ckpt_path}
```

### Regression
For UCI regression tasks:

```
python main.py validate --config configs/uci.yaml --name {exp_name} --data.task {task_name} --data.split_num {split_num} --ckpt_path {ckpt_path}
```
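A concrete invocation might look as follows; the checkpoint path is hypothetical and should point at a checkpoint produced by your own training run:
```
python main.py validate --config configs/uci.yaml --name uci_boston_0_eval --data.task bostonHousing --data.split_num 0 --ckpt_path logs/uci_boston_0/checkpoints/last.ckpt
```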

### Checkpoints
Trained checkpoints are available on the Releases tab of this repository.

|Dataset |Dopri Acc. |Link |
|:---: |:---: |:---: |
|MNIST |99.30% |[TODO]()|
|SVHN |96.12% |[TODO]()|
|CIFAR10 |88.89% |[TODO]()|


## Additional Notes
### Logging
We use wandb to monitor training progress and inference results.
The wandb run name will match the argument provided for `--name`.
You can also change the project name by modifying `trainer.logger.init_args.project` in the configuration file (default value is `SFNO_exp`).
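Because the configuration is parsed by `LightningCLI`, the project name can presumably also be overridden on the command line instead of editing the YAML file:
```
python main.py fit --config configs/mnist.yaml --name mnist_test --trainer.logger.init_args.project my_project
```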

### Running Your Own Experiment
Our code is implemented with `LightningCLI`, so you can simply override the config via command-line arguments to experiment with various settings.

Examples:
```
# Run MNIST experiment with batch size 128
python main.py fit --config configs/mnist.yaml --name mnist_b128 --data.batch_size 128
# Run SVHN experiment with explicit sampling of $t=0$ with probability 0.01
python main.py fit --config configs/svhn.yaml --name svhn_zero_001 --model.init_args.force_zero_prob 0.01
# Run CIFAR10 experiment with 'concave' dynamics
python main.py fit --config configs/cifar10.yaml --name cifar10_concave --model.init_args.dynamics concave
```
Refer to [Lightning Trainer documentation](https://lightning.ai/docs/pytorch/stable/common/trainer.html) for controlling trainer-related configurations (e.g., training steps or logging frequency).
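Trainer options can be overridden the same way as model and data options; for example (the values below are arbitrary illustrations, not recommended settings):
```
# Train MNIST for fewer steps and log less frequently
python main.py fit --config configs/mnist.yaml --name mnist_short --trainer.max_steps 20000 --trainer.log_every_n_steps 50
```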

## Acknowledgements
This implementation is based on the following repositories: [NeuralODE](https://github.com/rtqichen/torchdiffeq), [ANODE](https://github.com/EmilienDupont/augmented-neural-odes), and [CARD](https://github.com/XzwHan/CARD).



## Citation
```
(TODO)
```
69 changes: 69 additions & 0 deletions configs/cifar10.yaml
@@ -0,0 +1,69 @@
name: CIFAR10

model:
  class_path: models.conv_model.ConvModel
  init_args:
    data_dim: 3
    emb_res:
    - 7
    - 7
    latent_dim: 256
    hidden_dim: 256
    in_latent_dim: 64
    h_add_blocks: 4
    f_add_blocks: 4
    g_add_blocks: 0
    num_classes: 10

    method: ours
    force_zero_prob: 0.1
    metric_type: accuracy
    label_scaler: null

    scheduler: cos
    lr: 3e-4
    wd: 0.0
    task_criterion: ce
    dynamics: linear
    adjoint: false
    label_ae_noise: 10.0

trainer:
  val_check_interval: 1960
  check_val_every_n_epoch: null
  max_steps: 100000
  log_every_n_steps: 1
  gradient_clip_val: 0
  logger:
    class_path: lightning.pytorch.loggers.WandbLogger
    init_args:
      project: SFNO_exp
      log_model: false
      save_dir: ./logs
  callbacks:
  - class_path: lightning.pytorch.callbacks.LearningRateMonitor
    init_args:
      logging_interval: step
  - class_path: lightning.pytorch.callbacks.ModelCheckpoint # log best
    init_args:
      save_last: true
      monitor: 'val/accuracy_dopri'
      save_top_k: 1
      mode: max
      dirpath: null
  - class_path: lightning.pytorch.callbacks.ModelCheckpoint # log every 24 hours
    init_args:
      save_top_k: -1
      dirpath: null
      train_time_interval: '24:0:0'
  - class_path: lightning.pytorch.callbacks.RichModelSummary
    init_args:
      max_depth: 10

data:
  dataset: cifar10
  batch_size: 1024
  test_batch_size: 768
  task_type: classification

seed_everything: 0
60 changes: 60 additions & 0 deletions configs/mnist.yaml
@@ -0,0 +1,60 @@
name: MNIST

model:
  class_path: models.mlp_model.MLPModel
  init_args:
    data_dim: 784
    hidden_dim: 2048
    f_add_blocks: 1
    h_add_blocks: 0
    g_add_blocks: 0
    in_proj: mlp
    out_proj: mlp
    proj_norm: bn
    output_dim: 10

    method: ours
    force_zero_prob: 0.1
    metric_type: accuracy
    label_scaler: none

    scheduler: none
    lr: 1e-4
    wd: 0.0
    task_criterion: ce
    dynamics: linear
    adjoint: false
    label_ae_noise: 3.0
    total_steps: 100000

trainer:
  check_val_every_n_epoch: null
  max_steps: 500000
  log_every_n_steps: 1
  gradient_clip_val: 0
  logger:
    class_path: lightning.pytorch.loggers.WandbLogger
    init_args:
      project: SFNO_exp
      log_model: false
      save_dir: ./logs
  callbacks:
  - class_path: lightning.pytorch.callbacks.LearningRateMonitor
    init_args:
      logging_interval: step
  - class_path: lightning.pytorch.callbacks.ModelCheckpoint
    init_args:
      monitor: 'val/accuracy_dopri'
      save_top_k: 1
      mode: max
      dirpath: null
      train_time_interval: null

data:
  dataset: mnist
  batch_size: 1024
  test_batch_size: 768
  task_type: classification

seed_everything: 0
69 changes: 69 additions & 0 deletions configs/svhn.yaml
@@ -0,0 +1,69 @@
name: SVHN

model:
  class_path: models.conv_model.ConvModel
  init_args:
    data_dim: 3
    emb_res:
    - 7
    - 7
    latent_dim: 256
    hidden_dim: 256
    in_latent_dim: 64
    h_add_blocks: 4
    f_add_blocks: 4
    g_add_blocks: 0
    num_classes: 10

    method: ours
    force_zero_prob: 0.1
    metric_type: accuracy
    label_scaler: null

    scheduler: cos
    lr: 3e-4
    wd: 0.0
    task_criterion: ce
    dynamics: linear
    adjoint: false
    label_ae_noise: 7.0

trainer:
  val_check_interval: 1960
  check_val_every_n_epoch: null
  max_steps: 100000
  log_every_n_steps: 1
  gradient_clip_val: 0
  logger:
    class_path: lightning.pytorch.loggers.WandbLogger
    init_args:
      project: SFNO_exp
      log_model: false
      save_dir: ./logs
  callbacks:
  - class_path: lightning.pytorch.callbacks.LearningRateMonitor
    init_args:
      logging_interval: step
  - class_path: lightning.pytorch.callbacks.ModelCheckpoint # log best
    init_args:
      save_last: true
      monitor: 'val/accuracy_dopri'
      save_top_k: 1
      mode: max
      dirpath: null
  - class_path: lightning.pytorch.callbacks.ModelCheckpoint # log every 24 hours
    init_args:
      save_top_k: -1
      dirpath: null
      train_time_interval: '24:0:0'
  - class_path: lightning.pytorch.callbacks.RichModelSummary
    init_args:
      max_depth: 10

data:
  dataset: svhn
  batch_size: 1024
  test_batch_size: 768
  task_type: classification

seed_everything: 0
76 changes: 76 additions & 0 deletions configs/uci.yaml
@@ -0,0 +1,76 @@
name: UCI

model:
  class_path: models.mlp_model.MLPModel
  init_args:
    data_dim: 13
    hidden_dim: 64
    latent_dim: 64
    f_add_blocks: 0
    h_add_blocks: 0
    g_add_blocks: 0
    in_proj: mlp
    out_proj: linear

    method: ours
    force_zero_prob: 0.1
    metric_type: rmse
    label_scaler: true
    scheduler: 'none'

    lr: 0.003
    wd: 0.0
    task_criterion: ce
    dynamics: linear
    adjoint: false
    label_ae_noise: 3.0

trainer:
  check_val_every_n_epoch: null
  max_steps: 500000
  log_every_n_steps: 1
  gradient_clip_val: 0
  logger:
    class_path: lightning.pytorch.loggers.WandbLogger
    init_args:
      project: SFNO_exp
      log_model: false
      save_dir: ./logs
  callbacks:
  - class_path: lightning.pytorch.callbacks.LearningRateMonitor
    init_args:
      logging_interval: step
  - class_path: lightning.pytorch.callbacks.ModelCheckpoint # log best
    init_args:
      save_last: true
      monitor: 'val/rmse_dopri'
      save_top_k: 1
      mode: min
      dirpath: null
  - class_path: lightning.pytorch.callbacks.ModelCheckpoint # log every 24 hours
    init_args:
      save_top_k: -1
      dirpath: null
      train_time_interval: '24:0:0'
  - class_path: lightning.pytorch.callbacks.EarlyStopping
    init_args:
      patience: 100
      monitor: 'val/rmse_dopri'
      mode: min

data:
  dataset: uci
  batch_size: 64
  test_batch_size: 64
  val_perc: 0.001
  task_type: regression
  task: bostonHousing
  split_num: 0

seed_everything: 0