From 169a19046d5df5b9b01483008d671cc4c9dec64a Mon Sep 17 00:00:00 2001 From: seminkim Date: Wed, 30 Oct 2024 19:47:41 +0900 Subject: [PATCH] initial commit --- .gitignore | 12 ++ README.md | 102 +++++++++ configs/cifar10.yaml | 69 ++++++ configs/mnist.yaml | 60 ++++++ configs/svhn.yaml | 69 ++++++ configs/uci.yaml | 76 +++++++ data/classification.py | 151 +++++++++++++ data/load_data.py | 50 +++++ data/regression.py | 178 ++++++++++++++++ main.py | 87 ++++++++ models/base.py | 474 +++++++++++++++++++++++++++++++++++++++++ models/conv_model.py | 152 +++++++++++++ models/dynamics.py | 86 ++++++++ models/mlp_model.py | 216 +++++++++++++++++++ requirements.txt | 207 ++++++++++++++++++ utils.py | 22 ++ 16 files changed, 2011 insertions(+) create mode 100644 .gitignore create mode 100644 README.md create mode 100644 configs/cifar10.yaml create mode 100644 configs/mnist.yaml create mode 100644 configs/svhn.yaml create mode 100644 configs/uci.yaml create mode 100644 data/classification.py create mode 100644 data/load_data.py create mode 100644 data/regression.py create mode 100644 main.py create mode 100644 models/base.py create mode 100644 models/conv_model.py create mode 100644 models/dynamics.py create mode 100644 models/mlp_model.py create mode 100644 requirements.txt create mode 100644 utils.py diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..c64ec7c --- /dev/null +++ b/.gitignore @@ -0,0 +1,12 @@ +__pycache__ +*.pyc +.data +.vscode + +logs +lightning_logs +ckpts +wandb + + +UCI_Datasets \ No newline at end of file diff --git a/README.md b/README.md new file mode 100644 index 0000000..8536d46 --- /dev/null +++ b/README.md @@ -0,0 +1,102 @@ +# Simulation-Free Training of Neural ODEs on Paired Data +This repository contains the official implementation of the paper **"Simulation-Free Training of Neural ODEs on Paired Data (NeurIPS 2024)"** + +Semin Kim*, Jaehoon Yoo*, Jinwoo Kim, Yeonwoo Cha, Saehoon Kim, Seunghoon Hong + +[Paper Link (TODO)](TODO) + +## Setup +To set up the environment, start by installing dependencies listed in `requirements.txt`. You can also use Docker to streamline the setup process. + +1. **Docker Setup:** +``` +docker pull pytorch/pytorch:2.1.0-cuda11.8-cudnn8-devel +docker run -it pytorch/pytorch:2.1.0-cuda11.8-cudnn8-devel bash +``` + +2. **Clone the Repository:** +``` +git clone https://github.com/seminkim/simulation-free-node.git +``` +3. **Install Requirements:** +``` +pip install -r requirements.txt +``` +## Datasets +Place all datasets in the `.data` directory. By default, this code automatically downloads the MNIST, CIFAR-10, and SVHN datasets into the `.data` directory. + + +The UCI dataset, composed of 10 tasks (`bostonHousing`, `concrete`, `energy`, `kin8nm`, `naval-propulsion-plant`, `power-plant`, `protein-tertiary-structure`, `wine-quality-red`, `yacht`, and `YearPredictionMSD`), can be manually downloaded from the **Usage** part of the following repository: [CARD](https://github.com/XzwHan/CARD). + +## Training +Scripts for training are available for both classification and regression tasks. 
+
+### Classification
+To train a model for a classification task, run:
+```
+python main.py fit --config configs/{dataset_name}.yaml --name {exp_name}
+```
+
+### Regression
+For regression tasks (currently only supported for the UCI datasets), use the following command:
+
+```
+python main.py fit --config configs/uci.yaml --name {exp_name} --data.task {task_name} --data.split_num {split_num}
+```
+In this command, specify the UCI task name and the data split number accordingly.
+
+## Inference
+Use the following commands for model evaluation.
+### Classification
+```
+python main.py validate --config configs/{dataset_name}.yaml --name {exp_name} --ckpt_path {ckpt_path}
+```
+
+### Regression
+For UCI regression tasks:
+
+```
+python main.py validate --config configs/uci.yaml --name {exp_name} --data.task {task_name} --data.split_num {split_num} --ckpt_path {ckpt_path}
+```
+
+### Checkpoints
+Trained checkpoints are available in the Releases tab of this repository.
+
+|Dataset |Dopri Acc. |Link |
+|:---: |:---: |:---: |
+|MNIST |99.30% |[TODO]()|
+|SVHN |96.12% |[TODO]()|
+|CIFAR10 |88.89% |[TODO]()|
+
+
+## Additional Notes
+### Logging
+We use wandb to monitor training progress and inference results.
+The wandb run name will match the argument provided for `--name`.
+You can also change the project name by modifying `trainer.logger.init_args.project` in the configuration file (the default value is `SFNO_exp`).
+
+### Running Your Own Experiment
+Our code is implemented with `LightningCLI`, so you can simply override the config via command-line arguments to experiment with various settings.
+
+Examples:
+```
+# Run MNIST experiment with batch size 128
+python main.py fit --config configs/mnist.yaml --name mnist_b128 --data.batch_size 128
+
+# Run SVHN experiment with explicit sampling of $t=0$ with probability 0.01
+python main.py fit --config configs/svhn.yaml --name svhn_zero_001 --model.init_args.force_zero_prob 0.01
+
+# Run CIFAR10 experiment with 'concave' dynamics
+python main.py fit --config configs/cifar10.yaml --name cifar10_concave --model.init_args.dynamics concave
+```
+Refer to the [Lightning Trainer documentation](https://lightning.ai/docs/pytorch/stable/common/trainer.html) for controlling trainer-related configurations (e.g., training steps or logging frequency). A short sketch of the dynamics schedules behind the `dynamics` option is given after the Acknowledgements section.
+
+## Acknowledgements
+This implementation is based on the following repositories: [NeuralODE](https://github.com/rtqichen/torchdiffeq), [ANODE](https://github.com/EmilienDupont/augmented-neural-odes), and [CARD](https://github.com/XzwHan/CARD).
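+
+### Dynamics Schedules
+The `dynamics` option above selects one of the interpolation schedules defined in `models/dynamics.py` (e.g. `linear`, `concave`). Each schedule specifies `alpha(t)` and `beta(t)` for the interpolant `z_t = alpha(t) * z0 + beta(t) * z1`, and the time derivatives used as regression targets are derived automatically with `torch.func.jacfwd`. As a rough, hypothetical illustration only, the snippet below sketches how an additional schedule could be written against that interface; the class name is made up, and a new schedule would also have to be registered in `get_dynamics` before it could be selected from the command line.
+```
+# Hypothetical example, not part of this repository.
+from models.dynamics import BaseDynamics
+
+class PolynomialDynamics(BaseDynamics):
+    # z_t = alpha(t) * z0 + beta(t) * z1, with alpha(0) = beta(1) = 1 and
+    # alpha(1) = beta(0) = 0; alpha_dot/beta_dot come from torch.func.jacfwd.
+    def alpha(self, t):
+        return (1 - t) ** 2
+
+    def beta(self, t):
+        return 1 - (1 - t) ** 2
+```
+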
+ + + +## Citation +``` +(TODO) +``` \ No newline at end of file diff --git a/configs/cifar10.yaml b/configs/cifar10.yaml new file mode 100644 index 0000000..3186dc2 --- /dev/null +++ b/configs/cifar10.yaml @@ -0,0 +1,69 @@ +name: CIFAR10 + +model: + class_path: models.conv_model.ConvModel + init_args: + data_dim: 3 + emb_res: + - 7 + - 7 + latent_dim: 256 + hidden_dim: 256 + in_latent_dim: 64 + h_add_blocks: 4 + f_add_blocks: 4 + g_add_blocks: 0 + num_classes: 10 + + method: ours + force_zero_prob: 0.1 + metric_type: accuracy + label_scaler: null + + scheduler: cos + lr: 3e-4 + wd: 0.0 + task_criterion: ce + dynamics: linear + adjoint: false + label_ae_noise: 10.0 + +trainer: + val_check_interval: 1960 + check_val_every_n_epoch: null + max_steps: 100000 + log_every_n_steps: 1 + gradient_clip_val: 0 + logger: + class_path: lightning.pytorch.loggers.WandbLogger + init_args: + project: SFNO_exp + log_model: false + save_dir: ./logs + callbacks: + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: step + - class_path: lightning.pytorch.callbacks.ModelCheckpoint # log best + init_args: + save_last: true + monitor: 'val/accuracy_dopri' + save_top_k: 1 + mode: max + dirpath: null + - class_path: lightning.pytorch.callbacks.ModelCheckpoint # log every 24 hours + init_args: + save_top_k: -1 + dirpath: null + train_time_interval: '24:0:0' + - class_path: lightning.pytorch.callbacks.RichModelSummary + init_args: + max_depth: 10 +data: + dataset: cifar10 + batch_size: 1024 + test_batch_size: 768 + task_type: classification + + +seed_everything: 0 diff --git a/configs/mnist.yaml b/configs/mnist.yaml new file mode 100644 index 0000000..2b1492b --- /dev/null +++ b/configs/mnist.yaml @@ -0,0 +1,60 @@ +name: MNIST + +model: + class_path: models.mlp_model.MLPModel + init_args: + data_dim: 784 + hidden_dim: 2048 + f_add_blocks: 1 + h_add_blocks: 0 + g_add_blocks: 0 + in_proj: mlp + out_proj: mlp + proj_norm: bn + output_dim: 10 + + method: ours + force_zero_prob: 0.1 + metric_type: accuracy + label_scaler: none + + scheduler: none + lr: 1e-4 + wd: 0.0 + task_criterion: ce + dynamics: linear + adjoint: false + label_ae_noise: 3.0 + total_steps: 100000 + + +trainer: + check_val_every_n_epoch: null + max_steps: 500000 + log_every_n_steps: 1 + gradient_clip_val: 0 + logger: + class_path: lightning.pytorch.loggers.WandbLogger + init_args: + project: SFNO_exp + log_model: false + save_dir: ./logs + callbacks: + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: step + - class_path: lightning.pytorch.callbacks.ModelCheckpoint + init_args: + monitor: 'val/accuracy_dopri' + save_top_k: 1 + mode: max + dirpath: null + train_time_interval: null + +data: + dataset: mnist + batch_size: 1024 + test_batch_size: 768 + task_type: classification + +seed_everything: 0 \ No newline at end of file diff --git a/configs/svhn.yaml b/configs/svhn.yaml new file mode 100644 index 0000000..c9657e6 --- /dev/null +++ b/configs/svhn.yaml @@ -0,0 +1,69 @@ +name: SVHN + +model: + class_path: models.conv_model.ConvModel + init_args: + data_dim: 3 + emb_res: + - 7 + - 7 + latent_dim: 256 + hidden_dim: 256 + in_latent_dim: 64 + h_add_blocks: 4 + f_add_blocks: 4 + g_add_blocks: 0 + num_classes: 10 + + method: ours + force_zero_prob: 0.1 + metric_type: accuracy + label_scaler: null + + scheduler: cos + lr: 3e-4 + wd: 0.0 + task_criterion: ce + dynamics: linear + adjoint: false + label_ae_noise: 7.0 + +trainer: + val_check_interval: 1960 + 
check_val_every_n_epoch: null + max_steps: 100000 + log_every_n_steps: 1 + gradient_clip_val: 0 + logger: + class_path: lightning.pytorch.loggers.WandbLogger + init_args: + project: SFNO_exp + log_model: false + save_dir: ./logs + callbacks: + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: step + - class_path: lightning.pytorch.callbacks.ModelCheckpoint # log best + init_args: + save_last: true + monitor: 'val/accuracy_dopri' + save_top_k: 1 + mode: max + dirpath: null + - class_path: lightning.pytorch.callbacks.ModelCheckpoint # log every 24 hours + init_args: + save_top_k: -1 + dirpath: null + train_time_interval: '24:0:0' + - class_path: lightning.pytorch.callbacks.RichModelSummary + init_args: + max_depth: 10 +data: + dataset: svhn + batch_size: 1024 + test_batch_size: 768 + task_type: classification + + +seed_everything: 0 diff --git a/configs/uci.yaml b/configs/uci.yaml new file mode 100644 index 0000000..0fc1ee7 --- /dev/null +++ b/configs/uci.yaml @@ -0,0 +1,76 @@ +name: UCI + +model: + class_path: models.mlp_model.MLPModel + init_args: + data_dim: 13 + hidden_dim: 64 + latent_dim: 64 + f_add_blocks: 0 + h_add_blocks: 0 + g_add_blocks: 0 + in_proj: mlp + out_proj: linear + + method: ours + force_zero_prob: 0.1 + metric_type: rmse + label_scaler: true + scheduler: 'none' + + lr: 0.003 + wd: 0.0 + task_criterion: ce + dynamics: linear + adjoint: false + label_ae_noise: 3.0 + + +trainer: + check_val_every_n_epoch: null + max_steps: 10000 + log_every_n_steps: 10 + max_steps: 500000 + log_every_n_steps: 1 + gradient_clip_val: 0 + logger: + class_path: lightning.pytorch.loggers.WandbLogger + init_args: + project: SFNO_exp + log_model: false + save_dir: ./logs + + callbacks: + - class_path: lightning.pytorch.callbacks.LearningRateMonitor + init_args: + logging_interval: step + - class_path: lightning.pytorch.callbacks.ModelCheckpoint # log best + init_args: + save_last: true + monitor: 'val/rmse_dopri' + save_top_k: 1 + mode: min + mode: max + dirpath: null + - class_path: lightning.pytorch.callbacks.ModelCheckpoint # log every 24 hours + init_args: + save_top_k: -1 + dirpath: null + train_time_interval: '24:0:0' + - class_path: lightning.pytorch.callbacks.EarlyStopping + init_args: + patience: 100 + monitor: 'val/rmse_dopri' + mode: min + + +data: + dataset: uci + batch_size: 64 + test_batch_size: 64 + val_perc: 0.001 + task_type: regression + task: bostonHousing + split_num: 0 + +seed_everything: 0 \ No newline at end of file diff --git a/data/classification.py b/data/classification.py new file mode 100644 index 0000000..8851bf8 --- /dev/null +++ b/data/classification.py @@ -0,0 +1,151 @@ +import os +import torch +from torch.utils.data import DataLoader +import torchvision.datasets as datasets +import torchvision.transforms as transforms + + +class SVHNWrapper(datasets.SVHN): + ''' + Simple wrapper to handle split argument + ''' + + def __init__(self, *args, train=True, **kwargs): + split = 'train' if train else 'test' + super().__init__(*args, split=split, **kwargs) + + +# Constants +DATASET_CLASSES = { + 'mnist': datasets.MNIST, + 'cifar10': datasets.CIFAR10, + 'cifar100': datasets.CIFAR100, + 'svhn': SVHNWrapper, +} + +DATASET_NUM_CLASSES = { + 'mnist': 10, + 'cifar10': 10, + 'cifar100': 100, + 'svhn': 10, +} + +DATASET_IMG_SIZE = { + 'mnist': 28, + 'cifar10': 32, + 'cifar100': 32, + 'svhn': 32, +} + +DATASET_PATHNAME = { + 'mnist': 'mnist', + 'cifar10': 'cifar', + 'cifar100': 'cifar', + 'svhn': 'svhn', +} + + +def 
get_transform(data_aug=True, dataset='mnist', aug_type='basic'): + ''' + Get the data augmentation and normalization transformations for the dataset + ''' + if dataset == 'mnist': + transform_test = transforms.Compose([ + transforms.ToTensor(), + ]) + if data_aug: + transform_train = transforms.Compose([ + transforms.RandomCrop(28, padding=4), + transforms.ToTensor(), + ]) + else: + transform_train = transform_test + + elif dataset in ['cifar10', 'cifar100']: + transform_test = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), + ]) + if data_aug: + transform_train = transforms.Compose([ + transforms.RandomCrop(32, padding=4), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), + ]) + else: + transform_train = transform_test + + elif dataset == 'svhn': + transform_test = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize((0.4377, 0.4438, 0.4728), (0.1980, 0.2010, 0.1970)), + ]) + if data_aug: + transform_train = transforms.Compose([ + transforms.RandomCrop(32, padding=4), + transforms.ToTensor(), + transforms.Normalize((0.4377, 0.4438, 0.4728), (0.1980, 0.2010, 0.1970)), + ]) + else: + transform_train = transform_test + + return transform_train, transform_test + + +def to_one_hot(num_classes): + ''' + Convert labels to one-hot encoding + ''' + def to_one_hot_fn(x): + return torch.nn.functional.one_hot(torch.tensor(x), num_classes).float() + return to_one_hot_fn + + +def get_datasets(root='.data', name='mnist', data_aug=True, perc=None, stride=None, onehot=True, verbose=True, aug_type='basic'): + # download flag + base_path = os.path.join(root, DATASET_PATHNAME[name]) + download = not os.path.exists(base_path) + # stride for train subset evaluation + assert perc is None or stride is None, 'Only one of perc or stride can be set' + if perc is not None: + stride = int(1 / perc) + if stride is None: + stride = 1 + # get transforms + transform_train, transform_test = get_transform(data_aug, name, aug_type) + num_classes = DATASET_NUM_CLASSES[name] + target_transform = to_one_hot(num_classes) if onehot else None + + if name in ['mnist', 'cifar10', 'cifar100', 'svhn']: + dataset_class = DATASET_CLASSES[name] + train_dataset = dataset_class(root=base_path, train=True, download=download, transform=transform_train, + target_transform=target_transform) + eval_dataset = dataset_class(root=base_path, train=True, download=download, transform=transform_test, + target_transform=target_transform) + eval_dataset = torch.utils.data.Subset(eval_dataset, list(range(0, len(eval_dataset), stride))) + test_dataset = dataset_class(root=base_path, train=False, download=download, transform=transform_test, + target_transform=target_transform) + + if verbose: + print(f'Initialized {name} dataset with {len(train_dataset)} training samples, ' + f'{len(eval_dataset)} evaluation samples, and {len(test_dataset)} test samples') + + return train_dataset, eval_dataset, test_dataset + + +def get_dataloaders(train_set, val_set, test_set, batch_size=128, test_batch_size=1000, nw=4, **kwargs): + train_loader = DataLoader( + train_set, batch_size=batch_size, shuffle=True, num_workers=nw, drop_last=True, pin_memory=True, + persistent_workers=True, **kwargs, + ) + val_loader = DataLoader( + val_set, batch_size=test_batch_size, shuffle=False, num_workers=nw, drop_last=False, pin_memory=True, + persistent_workers=True, **kwargs, + ) + test_loader = DataLoader( + 
test_set, batch_size=test_batch_size, shuffle=False, num_workers=nw, drop_last=False, pin_memory=True, + persistent_workers=True, **kwargs, + ) + + return train_loader, val_loader, test_loader diff --git a/data/load_data.py b/data/load_data.py new file mode 100644 index 0000000..b0fd9e4 --- /dev/null +++ b/data/load_data.py @@ -0,0 +1,50 @@ +from torch.utils.data import DataLoader +import torch +from . import classification +from . import regression +import lightning as L + + +class Data(L.LightningDataModule): + def __init__(self, dataset: str, batch_size: int, test_batch_size: int, task=None, split_num=0., task_type=None, + val_perc=0.01, nw=8, data_aug=True, aug_type='basic', root='.data', **kwargs): + super().__init__() + self.dataset_name = dataset + self.batch_size = batch_size + self.test_batch_size = test_batch_size + self.task = task + self.val_perc = val_perc + self.split_num = split_num + self.task_type = task_type + self.data_aug = data_aug + self.aug_type = aug_type + + if self.dataset_name in ['cifar10', 'cifar100', 'mnist', 'svhn']: + self.train_dataset, self.eval_dataset, self.test_dataset = classification.get_datasets( + root=root, data_aug=self.data_aug, name=self.dataset_name, perc=self.val_perc, aug_type=aug_type) + self.train_loader, self.val_loader, self.test_loader = classification.get_dataloaders( + self.train_dataset, self.eval_dataset, self.test_dataset, batch_size=self.batch_size, test_batch_size=self.test_batch_size, nw=nw) + + elif self.dataset_name == 'uci': + self.train_dataset, self.val_dataset, self.test_dataset = regression.get_UCI_datasets( + root+'/UCI_Datasets', self.task, self.split_num, keys=('train', 'val', 'test') + ) + + self.train_loader = DataLoader(self.train_dataset, batch_size=self.batch_size, + shuffle=True, num_workers=0, pin_memory=False, drop_last=True) + self.val_loader = DataLoader(self.val_dataset, batch_size=self.batch_size, shuffle=False, num_workers=0, + pin_memory=False) + self.test_loader = DataLoader(self.test_dataset, batch_size=self.batch_size, shuffle=False, num_workers=0, + pin_memory=False) + + def setup(self, stage: str = None): + pass + + def train_dataloader(self) -> DataLoader: + return self.train_loader + + def val_dataloader(self) -> DataLoader: + return [self.test_loader, self.val_loader] + + def test_dataloader(self) -> DataLoader: + return self.test_loader diff --git a/data/regression.py b/data/regression.py new file mode 100644 index 0000000..58194db --- /dev/null +++ b/data/regression.py @@ -0,0 +1,178 @@ +import os +import torch +import torch.nn.functional as F +import numpy as np +from torch.utils.data import Dataset +from sklearn.preprocessing import StandardScaler +import pandas as pd + + +def onehot_encode_cat_feature(X, cat_var_idx_list): + """ + Apply one-hot encoding to the categorical variable(s) in the feature set, + specified by the index list. 
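+    Returns the feature matrix with the one-hot encoded columns appended after
+    the numerical columns, together with the total width of that one-hot block
+    (dim_cat).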
+ """ + # select numerical features + X_num = np.delete(arr=X, obj=cat_var_idx_list, axis=1) + # select categorical features + X_cat = X[:, cat_var_idx_list] + X_onehot_cat = [] + for col in range(X_cat.shape[1]): + X_onehot_cat.append(pd.get_dummies(X_cat[:, col], drop_first=True)) + X_onehot_cat = np.concatenate(X_onehot_cat, axis=1).astype(np.float32) + dim_cat = X_onehot_cat.shape[1] # number of categorical feature(s) + X = np.concatenate([X_num, X_onehot_cat], axis=1) + return X, dim_cat + + +def _get_index_train_test_path(data_directory_path, split_num, train=True): + """ + Method to generate the path containing the training/test split for the given + split number (generally from 1 to 20). + @param split_num Split number for which the data has to be generated + @param train Is true if the data is training data. Else false. + @return path Path of the file containing the requried data + """ + if train: + return os.path.join(data_directory_path, "index_train_" + str(split_num) + ".txt") + else: + return os.path.join(data_directory_path, "index_test_" + str(split_num) + ".txt") + + +def preprocess_uci_feature_set(X, data_path): + """ + Obtain preprocessed UCI feature set X (one-hot encoding applied for categorical variable) + and dimension of one-hot encoded categorical variables. + """ + dim_cat = 0 + task_name = data_path.split('/')[-1] + if task_name == 'bostonHousing': + X, dim_cat = onehot_encode_cat_feature(X, [3]) + elif task_name == 'energy': + X, dim_cat = onehot_encode_cat_feature(X, [4, 6, 7]) + elif task_name == 'naval-propulsion-plant': + X, dim_cat = onehot_encode_cat_feature(X, [0, 1, 8, 11]) + return X, dim_cat + + +class UCI(Dataset): + def __init__(self, data_path, task, split=0, train_split='train', normalize_x=True, normalize_y=True, train_ratio=0.6): + data_dir = os.path.join(data_path, task, 'data') + data_file = os.path.join(data_dir, 'data.txt') + index_feature_file = os.path.join(data_dir, 'index_features.txt') + index_target_file = os.path.join(data_dir, 'index_target.txt') + n_splits_file = os.path.join(data_dir, 'n_splits.txt') + + data = np.loadtxt(data_file) + index_features = np.loadtxt(index_feature_file) + index_target = np.loadtxt(index_target_file) + + X = data[:, [int(i) for i in index_features.tolist()]].astype(np.float32) + y = data[:, int(index_target.tolist())].astype(np.float32) + + X, dim_cat = preprocess_uci_feature_set(X=X, data_path=data_path) + self.dim_cat = dim_cat + + # load the indices of the train and test sets + index_train = np.loadtxt(_get_index_train_test_path(data_dir, split, train=True)) + index_test = np.loadtxt(_get_index_train_test_path(data_dir, split, train=False)) + + # read in data files with indices + x_train = X[[int(i) for i in index_train.tolist()]] + y_train = y[[int(i) for i in index_train.tolist()]].reshape(-1, 1) + x_test = X[[int(i) for i in index_test.tolist()]] + y_test = y[[int(i) for i in index_test.tolist()]].reshape(-1, 1) + + # split train set further into train and validation set for hyperparameter tuning + num_training_examples = int(train_ratio * x_train.shape[0]) + x_val = x_train[num_training_examples:, :] + y_val = y_train[num_training_examples:] + x_train = x_train[0:num_training_examples, :] + y_train = y_train[0:num_training_examples] + + self.x_train = x_train if type(x_train) is torch.Tensor else torch.from_numpy(x_train) + self.y_train = y_train if type(y_train) is torch.Tensor else torch.from_numpy(y_train) + self.x_test = x_test if type(x_test) is torch.Tensor else torch.from_numpy(x_test) + 
self.y_test = y_test if type(y_test) is torch.Tensor else torch.from_numpy(y_test) + self.x_val = x_val if type(x_val) is torch.Tensor else torch.from_numpy(x_val) + self.y_val = y_val if type(y_val) is torch.Tensor else torch.from_numpy(y_val) + + self.train_n_samples = x_train.shape[0] + self.train_dim_x = self.x_train.shape[1] # dimension of training data input + self.train_dim_y = self.y_train.shape[1] # dimension of training regression output + + self.test_n_samples = x_test.shape[0] + self.test_dim_x = self.x_test.shape[1] # dimension of testing data input + self.test_dim_y = self.y_test.shape[1] # dimension of testing regression output + + self.normalize_x = normalize_x + self.normalize_y = normalize_y + self.scaler_x, self.scaler_y = None, None + + if self.normalize_x: + self.normalize_train_test_x() + if self.normalize_y: + self.normalize_train_test_y() + + self.return_dataset(train_split) + + def normalize_train_test_x(self): + """ + When self.dim_cat > 0, we have one-hot encoded number of categorical variables, + on which we don't conduct standardization. They are arranged as the last + columns of the feature set. + """ + self.scaler_x = StandardScaler(with_mean=True, with_std=True) + if self.dim_cat == 0: + self.x_train = torch.from_numpy( + self.scaler_x.fit_transform(self.x_train).astype(np.float32)) + self.x_test = torch.from_numpy( + self.scaler_x.transform(self.x_test).astype(np.float32)) + self.x_val = torch.from_numpy( + self.scaler_x.transform(self.x_val).astype(np.float32)) + else: # self.dim_cat > 0 + x_train_num, x_train_cat = self.x_train[:, :-self.dim_cat], self.x_train[:, -self.dim_cat:] + x_test_num, x_test_cat = self.x_test[:, :-self.dim_cat], self.x_test[:, -self.dim_cat:] + x_val_num, x_val_cat = self.x_val[:, :-self.dim_cat], self.x_val[:, -self.dim_cat:] + x_train_num = torch.from_numpy( + self.scaler_x.fit_transform(x_train_num).astype(np.float32)) + x_test_num = torch.from_numpy( + self.scaler_x.transform(x_test_num).astype(np.float32)) + x_val_num = torch.from_numpy( + self.scaler_x.transform(x_val_num).astype(np.float32)) + self.x_train = torch.from_numpy(np.concatenate([x_train_num, x_train_cat], axis=1)) + self.x_test = torch.from_numpy(np.concatenate([x_test_num, x_test_cat], axis=1)) + self.x_val = torch.from_numpy(np.concatenate([x_val_num, x_val_cat], axis=1)) + + def normalize_train_test_y(self): + self.scaler_y = StandardScaler(with_mean=True, with_std=True) + self.y_train = torch.from_numpy( + self.scaler_y.fit_transform(self.y_train).astype(np.float32) + ) + self.y_test = torch.from_numpy( + self.scaler_y.transform(self.y_test).astype(np.float32) + ) + self.y_val = torch.from_numpy( + self.scaler_y.transform(self.y_val).astype(np.float32) + ) + + def return_dataset(self, split="train"): + if split == "train": + self.data = self.x_train.cuda() + self.target = self.y_train.cuda() + elif split == "val": + self.data = self.x_val.cuda() + self.target = self.y_val.cuda() + else: + self.data = self.x_test.cuda() + self.target = self.y_test.cuda() + + def __len__(self): + return len(self.data) + + def __getitem__(self, idx): + return self.data[idx], self.target[idx] + + +def get_UCI_datasets(data_path, task, split, keys=('train', 'val', 'test')): + return (UCI(data_path, task, split, train_split=key) for key in keys) diff --git a/main.py b/main.py new file mode 100644 index 0000000..29bd435 --- /dev/null +++ b/main.py @@ -0,0 +1,87 @@ +import os +import torch +from data.load_data import Data +from models.base import BaseModel +from 
lightning.pytorch.cli import LightningCLI +from utils import * + +# The flag below controls whether to allow TF32 on matmul. This flag defaults to False in PyTorch 1.12 and later. +torch.backends.cuda.matmul.allow_tf32 = True +# The flag below controls whether to allow TF32 on cuDNN. This flag defaults to True. +torch.backends.cudnn.allow_tf32 = True + + +class CustomLightningCLI(LightningCLI): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def before_fit(self): + self.trainer.logger.log_hyperparams(self.config.fit.as_dict()) + + def add_arguments_to_parser(self, parser): + parser.add_argument("--name", type=str, required=True) # experiment name should be provided + parser.link_arguments("trainer.max_steps", "model.init_args.total_steps") + # automatic linking of arguments + parser.link_arguments("name", "trainer.logger.init_args.name") + parser.link_arguments("name", "trainer.default_root_dir", compute_fn=lambda x: os.path.join("logs", x)) + parser.link_arguments("data.task_type", "model.init_args.task_criterion", compute_fn=compute_task_criterion) + parser.link_arguments("data.task_type", "model.init_args.metric_type", compute_fn=compute_metric_type) + + def before_instantiate_classes(self): + mode = getattr(self.config, 'subcommand', None) + if mode is None: + return # not running subcommand + + self.config[mode]['trainer']['gradient_clip_val'] = 1.0 if self.config[mode]['model'][ + 'init_args']['method'] == 'node' or self.config[mode]['data']['dataset'] == 'uci' else 0. + + # set hidden_dim = latent_dim for uci dataset + if self.config[mode]['data']['task_type'] == 'regression': + assert self.config[mode]['data']['dataset'] == 'uci' + if self.config[mode]['data']['dataset'] == 'uci': + self.config[mode]['model']['init_args']['hidden_dim'] = self.config[mode]['model']['init_args']['latent_dim'] + + # set save_dir of logger + name = self.config[mode]['trainer']['logger']['init_args']['name'] + self.config[mode]['trainer']['logger']['init_args']['save_dir'] = os.path.join( + self.config[mode]['trainer']['logger']['init_args']['save_dir'], name) + os.makedirs(self.config[mode]['trainer']['logger']['init_args']['save_dir'], exist_ok=True) + + def instantiate_classes(self): + ''' + Hacks to set the data_dim, output_dim and label scaler for UCI dataset. + Overrides the instantiate_classes method in LightningCLI. 
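+        The datamodule is instantiated first so that the UCI input/output dimensions
+        can be read from the training split before the model is built; the label scaler
+        fitted on the training targets is then attached to the model.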
+ ''' + mode = getattr(self.config, 'subcommand', None) + if mode is None: # not running subcommand + assert self.config['data']['task_type'] != 'regression', 'debug code not implemented for uci' + return super().instantiate_classes() + + self.config_init = self.parser.instantiate_classes(self.config) + self.datamodule = self._get(self.config_init, "data") + + if self.config[mode]['data']['dataset'] == 'uci': + self.config[mode]['model']['init_args']['data_dim'] = self.datamodule.train_dataset.train_dim_x + self.config[mode]['model']['init_args']['output_dim'] = self.datamodule.train_dataset.train_dim_y + self.config_init = self.parser.instantiate_classes(self.config) + self.model = self._get(self.config_init, "model") + self._add_configure_optimizers_method_to_model(self.subcommand) + self.trainer = self.instantiate_trainer() + if self.config[mode]['model']['init_args']['label_scaler'] is True: + self.datamodule.normalize_y = True + self.model.label_scaler = self.datamodule.train_dataset.scaler_y + + # print total batch size + per_gpu = self.config[mode]['data']['batch_size'] + num_gpus = self.trainer.world_size + total = per_gpu * num_gpus + print(f"Using total batch size {total} = {num_gpus} x {per_gpu}") + + +if __name__ == '__main__': + cli = CustomLightningCLI(model_class=BaseModel, + subclass_mode_model=True, + datamodule_class=Data, + save_config_kwargs={"overwrite": True}, + run=True,) + print(f'Done.') diff --git a/models/base.py b/models/base.py new file mode 100644 index 0000000..d8fd609 --- /dev/null +++ b/models/base.py @@ -0,0 +1,474 @@ +import os +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import lightning as L +from torchdiffeq import odeint, odeint_adjoint +from utils import append_dims +from typing import Dict, List +from .dynamics import get_dynamics + +MAX_NUM_STEPS = 1000 # Maximum number of steps for ODE solver + + +class AppendRepeat(nn.Module): + ''' + append and apply repeat {rep_dims} + e.g. rep_dims=(H,W) for (B, C) -> (B, C, H, W) + ''' + + def __init__(self, rep_dims): + super(AppendRepeat, self).__init__() + self.rep_dims = rep_dims + + def forward(self, x): + ori_dim = x.ndim + for _ in range(len(self.rep_dims)): + x = x.unsqueeze(-1) + return x.repeat(*[1 for _ in range(ori_dim)], *self.rep_dims) + + +class ODEBlock(nn.Module): + """Solves ODE defined by odefunc. + + Parameters + ---------- + device : torch.device + + odefunc : ODEFunc instance or anode.conv_models.ConvODEFunc instance + Function defining dynamics of system. + + is_conv : bool + If True, treats odefunc as a convolutional model. + + tol : float + Error tolerance. + + adjoint : bool + If True calculates gradient with adjoint method, otherwise + backpropagates directly through operations of ODE solver. + """ + + def __init__(self, device, odefunc, is_conv=False, tol=1e-3, adjoint=False): + super(ODEBlock, self).__init__() + self.adjoint = adjoint + self.device = device + self.is_conv = is_conv + self.odefunc = odefunc + self.tol = tol + + def forward(self, x, eval_times=None, method='dopri5'): + """Solves ODE starting from x. + + Parameters + ---------- + x : torch.Tensor + Shape (batch_size, self.odefunc.data_dim) + + eval_times : None or torch.Tensor + If None, returns solution of ODE at final time t=1. If torch.Tensor + then returns full ODE trajectory evaluated at points in eval_times. 
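+
+        Returns
+        -------
+        torch.Tensor
+            The state at t=1 (same shape as x) if eval_times is None, otherwise
+            the solution stacked along a leading time dimension.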
+ """ + self.odefunc.nfe = 0 + + if eval_times is None: + integration_time = torch.tensor([0, 1.0]).float().type_as(x) + else: + integration_time = eval_times.type_as(x) + + x_aug = x + + options = None if method == 'euler' else {'max_num_steps': MAX_NUM_STEPS} + odeint_fn = odeint_adjoint if self.adjoint else odeint + + out = odeint_fn(self.odefunc, x_aug, integration_time, + rtol=self.tol, atol=self.tol, method=method, + options=options) + + if eval_times is None: + return out[1] # Return only final time + else: + return out + + def trajectory(self, x, timesteps, method='dopri5'): + """Returns ODE trajectory. + + Parameters + ---------- + x : torch.Tensor + Shape (batch_size, self.odefunc.data_dim) + + timesteps : int + Number of timesteps in trajectory. + """ + if isinstance(timesteps, int): + integration_time = torch.linspace(0., 1., timesteps) + elif isinstance(timesteps, torch.Tensor): + integration_time = timesteps + else: + raise ValueError('timesteps should be int or torch.Tensor') + return self.forward(x, eval_times=integration_time, method=method) + + +class BaseModel(L.LightningModule): + ''' + A base model for training and inference. + ''' + + def __init__(self, method='ours', force_zero_prob=0.1, metric_type='accuracy', label_scaler=None, scheduler='none', + lr=1e-4, wd=0., task_criterion='ce', dynamics=None, adjoint=False, label_ae_noise=0.0, + total_steps=None, resume_path=None, resume_keys=(), freeze_keys=(), + **kwargs): + + super().__init__() + # unused param warning + if kwargs: + print(f'[!] unused kwargs: {kwargs}') + + self.method = method + self.force_zero_prob = force_zero_prob + self.metric_type = metric_type + self.label_scaler = label_scaler + self.scheduler = scheduler + self.lr = lr + self.wd = wd + self.adjoint = adjoint + self.label_ae_noise = label_ae_noise + self.total_steps = total_steps + self.resume_path = resume_path + self.resume_keys = resume_keys + self.freeze_keys = freeze_keys + + if task_criterion == 'mse': + self.task_criterion = torch.nn.MSELoss() + elif task_criterion == 'ce': + self.task_criterion = torch.nn.CrossEntropyLoss() + else: + raise ValueError(f'Unknown task_criterion: {task_criterion}') + + self.label_ae_criterion = torch.nn.MSELoss() + + self.dynamics = dynamics + if isinstance(dynamics, str): + self.dynamics = get_dynamics(dynamics, return_v=False) + + if self.metric_type == 'accuracy': + self.best_metric = { + f'val/best_acc': 0., + f'test/best_acc': 0., + } + elif self.metric_type == 'rmse': + self.best_metric = { + f'val/best_rmse': np.inf, + f'test/best_rmse': np.inf, + } + else: + raise ValueError(f'Unknown metric type: {self.metric_type}') + + # should be defined in the child class + self.in_projection = None + self.out_projection = None + self.label_projection = None + + def setup(self, stage): + ''' + Delete the label_projection if not needed. + This is to avoid unused param error with DDP. + Additionally load from ckpt and freeze it. + ''' + if self.method != 'ours': + self.label_projection = nn.Identity() + self.load_ckpt() + if stage == 'fit': + self.freeze_weights() + super().setup(stage) + + def load_ckpt(self): + ''' + Load checkpoint at `self.resume_path`, filtered by key prefixes in `self.resume_keys`. 
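+        Keys missing from the checkpoint are simply left at their current values
+        (strict=False); keys in the filtered checkpoint that do not exist in the
+        model raise an assertion.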
+ ''' + if self.resume_path is not None: + ckpt = torch.load(self.resume_path) + ckpt = ckpt.get('state_dict', ckpt) # unwrap if needed + ckpt = {k: v for k, v in ckpt.items() if any(k.startswith(prefix) for prefix in self.resume_keys)} + missing, unexpected = self.load_state_dict(ckpt, strict=False) + assert len(unexpected) == 0, f'Unexpected keys: {unexpected}' + print(f'[!] Loaded checkpoint from {self.resume_path}. Loaded keys: {len(ckpt)}') + + def freeze_weights(self): + ''' + Freeze weights based on the keys in `self.freeze_keys`. + ''' + assert all(k in ['in_projection', 'pos_embed', 'out_projection', 'label_projection', 'odeblock'] for k in self.freeze_keys), \ + f'Unknown keys in freeze_keys: {self.freeze_keys}' + param_count = 0 + for key in self.freeze_keys: + target = getattr(self, key) + if isinstance(target, nn.Module): + for param in target.parameters(): + param.requires_grad = False + param_count += param.numel() + elif isinstance(target, nn.Parameter): # e.g. pos_embed + target.requires_grad = False + param_count += target.numel() + else: + raise ValueError(f'Unknown target type: {type(target)}') + print(f'[!] {param_count} parameters are freezed.') + + def forward(self, x, return_features=False, method='dopri5'): + x = self.in_projection(x) + eval_t = torch.tensor([0., 1.], device=x.device) + features = self.odeblock(x, method=method, eval_times=eval_t)[-1] + + pred = self.out_projection(features) + + if return_features: + return features, pred + + return pred + + @torch.inference_mode() + def inference(self, X, method='dopri5', num_timesteps=1+1, return_feat=False): + ''' + Do inference and return the prediction (optionally also return the prediction in the latent space). + If one need a whole trajectory in the latent space, use get_traj instead. + ''' + if self.method == 'onestep': + z0 = self.in_projection(X) + feat = self.odeblock.odefunc(0, z0) + pred = self.out_projection(feat) + elif method == 'dopri5': + feat, pred = self(X, return_features=True, method='dopri5') + else: + traj, pred = self.get_traj(X, method=method, timesteps=num_timesteps) # use get_traj to enforce n-step euler. + feat = traj[-1] + if return_feat: + return feat, pred + return pred + + def evaluate(self, batch, method='dopri5', num_timesteps=1+1, dataloader_type='val', save_mse=False): + ''' + Do inference and save metric (with data count) + Also save nfe for dopri. 
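+        Metrics are accumulated into self.metrics as batch-size weighted sums and
+        normalized by the per-dataloader sample counts in on_validation_epoch_end.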
+ ''' + X, Y = batch + bs = X.size(0) + try: + pred = self.inference(X, method=method, num_timesteps=num_timesteps, return_feat=False) + except AssertionError: + # prevent shutdown for dopri error + print(f'Error in inference with method {method}') + pred = Y * torch.nan + if method == 'dopri5': + self.metrics[f'{dataloader_type}/dopri_nfe'] += self.odeblock.odefunc.nfe * bs + + if self.metric_type == 'accuracy': + if method == 'dopri5': + self.metrics[f'{dataloader_type}/accuracy_dopri'] += (pred.argmax(dim=-1) == Y.argmax(dim=-1)).float().sum().item() + else: + self.metrics[f'{dataloader_type}/accuracy_{num_timesteps-1}'] += (pred.argmax(dim=-1) == Y.argmax(dim=-1)).float().sum().item() + elif self.metric_type == 'rmse': + if self.label_scaler is not None: + Y_unnorm = self.label_scaler.inverse_transform(Y.cpu().numpy()) + pred_unnorm = self.label_scaler.inverse_transform(pred.cpu().numpy()) + rmse = np.mean((Y_unnorm - pred_unnorm)**2) + else: + rmse = F.mse_loss(pred, Y).item() + if method == 'dopri5': + self.metrics[f'{dataloader_type}/rmse_dopri'] += rmse * bs + else: + self.metrics[f'{dataloader_type}/rmse_{num_timesteps-1}'] += rmse * bs + else: + raise ValueError(f'Unknown metric type: {self.metric_type}') + + def get_traj(self, x, timesteps=100+1, method='dopri5'): + ''' + note: should +1 to timesteps since it is both start & end inclusive. + ''' + x = self.in_projection(x) + out = self.odeblock.trajectory(x, timesteps, method=method) + return out, self.out_projection(out[-1]) + + def pred_v(self, z, t): + self.odeblock.odefunc.nfe = 0 + return self.odeblock.odefunc(t, z) + + def sample_timestep(self, z0): + t = torch.rand(z0.size(0), device=self.device) + t = append_dims(t, z0.ndim) + # make some portion of sampled t to zero + if self.force_zero_prob > 0.: + mask = (torch.rand_like(t) < self.force_zero_prob).float() + t = t * (1. - mask) + return t + + def configure_optimizers(self): + optimizer = torch.optim.AdamW(self.parameters(), lr=self.lr, weight_decay=self.wd) + + # lr scheduler + if self.scheduler == 'none': + return optimizer + elif self.scheduler == 'cos': + scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=self.total_steps, eta_min=0) + else: + raise ValueError(f'Unknown scheduler: {self.scheduler}') + + return { + 'optimizer': optimizer, + 'lr_scheduler': { + "scheduler": scheduler, + "interval": "step", + "frequency": 1, + } + } + + def on_after_backward(self): + ''' + This function is called after the backward pass and before the optimizer step. 
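+        Here it is used to log the number of ODE function evaluations made during
+        the backward pass and to reset the counter for the next step.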
+ ''' + backward_nfe = int(self.odeblock.odefunc.nfe) + self.log('train/backward_nfe', backward_nfe) + self.odeblock.odefunc.nfe = 0 + + # Training + def node_training_step(self, batch, batch_idx): + X, Y = batch + + pred = self(X) + loss = self.task_criterion(pred, Y) + forward_nfe = int(self.odeblock.odefunc.nfe) + self.log_dict({ + 'train/loss': loss.item(), + 'train/nfe': forward_nfe, + }) + self.odeblock.odefunc.nfe = 0 + return loss + + def ours_training_step(self, batch, batch_idx): + X, Y = batch + z0 = self.in_projection(X) + z1 = self.label_projection(Y) + + # sample timestep and construct zt, vt + t = self.sample_timestep(z0) + zt = self.dynamics.get_zt(z0, z1, t) + v_target = self.dynamics.get_vt(z0, z1, t).squeeze() + + # flow loss + v_pred = self.pred_v(zt, t) + flow_loss = F.mse_loss(v_pred, v_target) + + # label autoencoding loss + z1_noised = z1 + if self.label_ae_noise > 0.: + z1_noised = z1 + self.label_ae_noise * torch.randn_like(z1) + y_pred = self.out_projection(z1_noised) + label_ae_loss = self.label_ae_criterion(y_pred, Y) + + # total loss + loss = flow_loss + label_ae_loss + + # logging + self.log_dict({ + 'train/loss': loss.item(), + 'train/flow_loss': flow_loss.item(), + 'train/label_ae_loss': label_ae_loss.item(), + }) + + return loss + + def training_step(self, batch, batch_idx): + if self.method == 'node': + return self.node_training_step(batch, batch_idx) + elif self.method == 'ours': + return self.ours_training_step(batch, batch_idx) + else: + raise ValueError(f'Method {self.method} not supported') + + def on_validation_epoch_start(self): + self.metrics = {} + for dataloader_type in ['val', 'test']: + for num_euler_steps in [1, 2, 10, 20]: + self.metrics[f'{dataloader_type}/{self.metric_type}_{num_euler_steps}'] = torch.tensor(0.) + self.metrics[f'{dataloader_type}/{self.metric_type}_dopri'] = 0. + self.metrics[f'{dataloader_type}/dopri_nfe'] = 0. + + self.data_count = { + 'val': 0., + 'test': 0., + } + # for label autoencoding accuracy + self.num_classes = None + + def on_validation_model_zero_grad(self): + ''' + Small hack to avoid validation step on resume. + This will NOT work if the gradient accumulation step should be performed at this point. + We raise StopIteration Exception to make training_epoch_loop.run() stop, just before val_loop.run(). + See training_epoch_loop.run(), and ~.on_advance_end(). + ''' + super().on_validation_model_zero_grad() + if self.trainer.ckpt_path is not None and getattr(self, '_restarting_skip_val_flag', True): + self._restarting_skip_val_flag = False + raise StopIteration + + def validation_step(self, batch, batch_idx, dataloader_idx=0): + if batch_idx == 0: + self.num_classes = batch[1].size(-1) + + if dataloader_idx == 0: # test set + dataloader_type = 'test' + elif dataloader_idx == 1: # optional train-subset or held-out validation set. 
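+            # Dataloader order follows Data.val_dataloader(), which returns
+            # [test_loader, val_loader].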
+ dataloader_type = 'val' + else: + raise ValueError(f'Unknown dataloader_idx: {dataloader_idx}') + + X, Y = batch + bs = X.size(0) + self.data_count[dataloader_type] += bs + # final metric and data/latent mse + self.evaluate(batch, method='dopri5', num_timesteps=1+1, dataloader_type=dataloader_type, save_mse=True) + for num_euler_steps in [1, 2, 10, 20]: + self.evaluate(batch, method='euler', num_timesteps=num_euler_steps+1, dataloader_type=dataloader_type, save_mse=False) + + def on_validation_epoch_end(self): + gathered_metrics: Dict[str, List[torch.Tensor]] = self.all_gather(self.metrics) + summed_metrics = {k: v.sum().item() for k, v in gathered_metrics.items()} + gathered_data_count: Dict[str, List[torch.Tensor]] = self.all_gather(self.data_count) + total_data_count = {k: v.sum().item() for k, v in gathered_data_count.items()} + self.metrics = summed_metrics + self.data_count = total_data_count + + # calculate average + for k, v in self.metrics.items(): + dataloader_type = k.split('/')[0] + self.metrics[k] /= max(self.data_count[dataloader_type], 1) + + if self.metric_type == 'rmse': + for k, v in self.metrics.items(): + if 'rmse' in k: + self.metrics[k] = v ** 0.5 + + # update best metric + if self.metric_type == 'accuracy': + if self.metrics['val/accuracy_dopri'] > self.best_metric['val/best_acc']: + self.best_metric['val/best_acc'] = self.metrics['val/accuracy_dopri'] + + if self.metrics['test/accuracy_dopri'] > self.best_metric['test/best_acc']: + self.best_metric['test/best_acc'] = self.metrics['test/accuracy_dopri'] + elif self.metric_type == 'rmse': + if self.metrics['val/rmse_dopri'] < self.best_metric['val/best_rmse']: + self.best_metric['val/best_rmse'] = self.metrics['val/rmse_dopri'] + + if self.metrics['test/rmse_dopri'] < self.best_metric['test/best_rmse']: + self.best_metric['test/best_rmse'] = self.metrics['test/rmse_dopri'] + + # log all metrics + self.log_dict(self.metrics, sync_dist=True) + self.log_dict(self.best_metric, sync_dist=True, prog_bar=True) + + def on_fit_end(self): + save_path = os.path.join(self.logger.save_dir, f'last_step={self.trainer.global_step}.ckpt') + self.trainer.save_checkpoint(save_path) + print(f'[!] 
Saved last checkpoint at {save_path}') diff --git a/models/conv_model.py b/models/conv_model.py new file mode 100644 index 0000000..b450570 --- /dev/null +++ b/models/conv_model.py @@ -0,0 +1,152 @@ +import torch +import torch.nn as nn +from utils import * +from .base import * + + +class ConcatConv2d(nn.Module): + def __init__(self, dim_in, dim_out, ksize=3, stride=1, padding=0, dilation=1, groups=1, bias=True, transpose=False): + super(ConcatConv2d, self).__init__() + module = nn.ConvTranspose2d if transpose else nn.Conv2d + self._layer = module( + dim_in + 1, dim_out, kernel_size=ksize, stride=stride, padding=padding, dilation=dilation, groups=groups, + bias=bias + ) + + def forward(self, t, x): + tt = torch.ones_like(x[:, :1, :, :]) * t + ttx = torch.cat([tt, x], 1) + return self._layer(ttx) + + +class UnitBlock(nn.Module): + def __init__(self, dim, conv_t_dependent=True): + super(UnitBlock, self).__init__() + + self.act_func = nn.ReLU(inplace=False) + self.conv_t_dependent = conv_t_dependent + if conv_t_dependent: + self.conv1 = ConcatConv2d(dim, dim, 3, 1, 1) + else: + self.conv1 = nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, bias=False) + + def forward(self, t, x): + t_vec = torch.ones(x.shape[0], 1, device=x.device) * t.squeeze().reshape(-1, 1) + out = self.act_func(x) + if self.conv_t_dependent: + out = self.conv1(t, out) + else: + out = self.conv1(out) + return out + + +class ConvBlock(nn.Module): + def __init__(self, dim): + super().__init__() + self.conv1 = nn.Conv2d(dim, dim, 3, 1, 1) + self.act_func = nn.ReLU(inplace=False) + + def forward(self, x): + out = self.conv1(x) + out = self.act_func(out) + return out + + +class ConvODEfunc(nn.Module): + + def __init__(self, dim, hidden_dim=0, add_blocks=0, conv_t_dependent=True): + super(ConvODEfunc, self).__init__() + + self.act_func = nn.ReLU(inplace=False) + if hidden_dim == 0: + hidden_dim = dim + + self.conv_t_dependent = conv_t_dependent + conv_cls = ConcatConv2d if conv_t_dependent else nn.Conv2d + self.conv1 = conv_cls(dim, hidden_dim, 3, 1, 1) + self.conv2 = conv_cls(hidden_dim, dim, 3, 1, 1) + self.nfe = 0 + + blocks = [UnitBlock( + hidden_dim, conv_t_dependent=conv_t_dependent) for _ in range(add_blocks)] + self.blocks = nn.ModuleList(blocks) + + def forward(self, t, x): + self.nfe += 1 + if t.ndim == 0: + t = append_dims(torch.ones(x.shape[0], device=x.device), x.ndim) * t + assert t.ndim == x.ndim, (t.ndim, x.ndim) + + out = self.act_func(x) + if self.conv_t_dependent: + out = self.conv1(t, out) + else: + out = self.conv1(out) + # add_blocks + for block in self.blocks: + out = block(t, out) + + out = self.act_func(out) + + if self.conv_t_dependent: + out = self.conv2(t, out) + else: + out = self.conv2(out) + + return out + + +class ConvModel(BaseModel): + def __init__(self, *args, data_dim=3, emb_res=(7, 7), latent_dim=64, + in_latent_dim=64, hidden_dim=0, h_add_blocks=0, f_add_blocks=0, + g_add_blocks=0, num_classes=10, **kwargs): + ''' + params: + - data_dim: input dimension + - emb_res: spatial resolution at embedding space + ''' + super().__init__(*args, **kwargs) + in_proj_layer = [ + nn.Conv2d(data_dim, in_latent_dim, 3, 1), + nn.ReLU(inplace=False), + ] + + for _ in range(f_add_blocks): + in_proj_layer += [ + nn.Conv2d(in_latent_dim, in_latent_dim, 3, 1, 1), + nn.ReLU(inplace=False), + ] + + in_proj_layer += [ + nn.Conv2d(in_latent_dim, in_latent_dim, 4, 2, 1), + nn.ReLU(inplace=False), + nn.Conv2d(in_latent_dim, latent_dim, 4, 2, 1), + ] + self.in_projection = nn.Sequential( + *in_proj_layer + ) + + # 
out projection + out_proj_layer = [ + ConvBlock(latent_dim) for _ in range(g_add_blocks) + ] + out_proj_layer += [ + nn.ReLU(inplace=False) if g_add_blocks == 0 else nn.Identity(), + nn.AdaptiveAvgPool2d((1, 1)), + nn.Flatten(), + nn.Linear(latent_dim, num_classes), + ] + self.out_projection = nn.Sequential( + *out_proj_layer + ) + # label projection + label_proj_layer = [ + nn.Linear(num_classes, latent_dim), + AppendRepeat(emb_res), + ] + self.label_projection = nn.Sequential( + *label_proj_layer + ) + + odefunc = ConvODEfunc(latent_dim, hidden_dim=hidden_dim, add_blocks=h_add_blocks) + self.odeblock = ODEBlock(self.device, odefunc, is_conv=True, adjoint=self.adjoint) diff --git a/models/dynamics.py b/models/dynamics.py new file mode 100644 index 0000000..ba6a4b2 --- /dev/null +++ b/models/dynamics.py @@ -0,0 +1,86 @@ +import torch +import numpy as np +from torch.func import vmap, jacfwd +from utils import append_dims + + +class BaseDynamics: + def __init__(self, auto_init=True): + if auto_init: + self._auto_init() + + def _auto_init(self): + self._alpha_dot = vmap(jacfwd(self.alpha, argnums=0)) + self._beta_dot = vmap(jacfwd(self.beta, argnums=0)) + + def get_zt(self, z0, z1, t): + t = t.squeeze().reshape(z0.shape[0]) + return append_dims(self.alpha(t), z0.ndim) * z0 + append_dims(self.beta(t), z1.ndim) * z1 + + def get_vt(self, z0, z1, t): + t = t.squeeze().reshape(z0.shape[0]) + return append_dims(self.alpha_dot(t), z0.ndim) * z0 + append_dims(self.beta_dot(t), z1.ndim) * z1 + + def alpha(self, t): + raise NotImplementedError + + def beta(self, t): + raise NotImplementedError + + def _alpha_dot(self, t): + raise NotImplementedError + + def _beta_dot(self, t): + raise NotImplementedError + + def alpha_dot(self, t): + # can be auto computed + return self._alpha_dot(t).to(t.dtype) + + def beta_dot(self, t): + # can be auto computed + return self._beta_dot(t).to(t.dtype) + + +class LinearDynamics(BaseDynamics): + def alpha(self, t): + return 1 - t + + def beta(self, t): + return t + + +class ConcaveDynamics(BaseDynamics): + def alpha(self, t): + return torch.cos(t * np.pi / 2) + + def beta(self, t): + return torch.sin(t * np.pi / 2) + + +class ConvexDynamics(BaseDynamics): + def alpha(self, t): + return 1 - torch.sin(t * np.pi / 2) + + def beta(self, t): + return 1 - torch.cos(t * np.pi / 2) + + +def get_dynamics(name, return_v=True): + ''' + just for backward compatibility + ''' + assert name in [ + 'linear', + 'concave', + 'covnex', + ] + assert not return_v + if name == 'linear': + return LinearDynamics() + elif name == 'concave': + return ConcaveDynamics() + elif name == 'convex': + return ConvexDynamics() + else: + raise ValueError(f'Unknown dynamics name: {name}') diff --git a/models/mlp_model.py b/models/mlp_model.py new file mode 100644 index 0000000..25c8615 --- /dev/null +++ b/models/mlp_model.py @@ -0,0 +1,216 @@ + +import torch +import torch.nn as nn +from .base import * + + +class tMLPBlock(nn.Module): + def __init__(self, t_dim, hidden_dim): + super().__init__() + + self.fc1 = nn.Linear(hidden_dim+t_dim, hidden_dim) + + def forward(self, x, t): + out = self.fc1(x) + return out + + +class MLPBlock(nn.Module): + def __init__(self, in_dim, out_dim): + super().__init__() + self.fc1 = nn.Linear(in_dim, out_dim) + self.relu = nn.ReLU() + + def forward(self, x): + out = self.fc1(x) + out = self.relu(out) + return out + + +class MLPODEFunc(nn.Module): + """MLP modeling the derivative of ODE system. 
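+    Given the current state (and, if time_dependent, the current time), the
+    network returns dz/dt with the same dimensionality as the state.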
+ + Parameters + ---------- + device : torch.device + + data_dim : int + Dimension of data. + + hidden_dim : int + Dimension of hidden layers. + + time_dependent : bool + If True adds time as input, making ODE time dependent. + """ + + def __init__(self, device, data_dim, hidden_dim, time_dependent=True, h_add_blocks=0): + super(MLPODEFunc, self).__init__() + self.device = device + self.data_dim = data_dim + self.input_dim = data_dim + self.hidden_dim = hidden_dim + self.nfe = 0 # Number of function evaluations + t_dim = 0 + + self.time_dependent = time_dependent + + self.h_add_blocks = h_add_blocks + if self.h_add_blocks == -1: + print('[!] Using identity ODEFunc..') + return + + if time_dependent: + t_dim = 1 + + self.fc1 = nn.Linear(self.input_dim + t_dim, hidden_dim) + self.fc2 = nn.Linear(hidden_dim + t_dim, hidden_dim) + self.fc3 = nn.Linear(hidden_dim + t_dim, self.input_dim) + + self.non_linearity = nn.ReLU(inplace=False) + + self.h_add_blocks = h_add_blocks + if h_add_blocks > 0: + tmlp_layers = [tMLPBlock(t_dim, hidden_dim) for _ in range(h_add_blocks)] + self.tmlp_blocks = nn.ModuleList(tmlp_layers) + + def forward(self, t, x): + """ + Parameters + ---------- + t : torch.Tensor + Current time. Shape (1,). + + x : torch.Tensor + Shape (batch_size, input_dim) + """ + # Forward pass of model corresponds to one function evaluation, so + # increment counter + self.nfe += 1 + if self.h_add_blocks == -1: + return x # identity + + if self.time_dependent: + t_vec = torch.ones(x.shape[0], 1, device=x.device) * t + out = self.fc1(torch.cat([t_vec, x], 1)) + out = self.non_linearity(out) + out = self.fc2(torch.cat([t_vec, out], 1)) + out = self.non_linearity(out) + + if self.h_add_blocks > 0: + for i in range(self.h_add_blocks): + out = self.tmlp_blocks[i](torch.cat([t_vec, out], 1), t_vec) + out = self.non_linearity(out) + + out = self.fc3(torch.cat([t_vec, out], 1)) + + else: + out = self.fc1(x) + out = self.non_linearity(out) + out = self.fc2(out) + out = self.non_linearity(out) + + if self.h_add_blocks > 0: + for i in range(self.h_add_blocks): + out = self.tmlp_blocks[i](out, t=None) + out = self.non_linearity(out) + + out = self.fc3(out) + + return out + + +class MLPModel(BaseModel): + """An ODEBlock followed by a Linear layer. + + Parameters + ---------- + device : torch.device + + data_dim : int + Dimension of data. + + hidden_dim : int + Dimension of hidden layers. + + output_dim : int + Dimension of output after hidden layer. Should be 1 for regression or + num_classes for classification. + + time_dependent : bool + If True adds time as input, making ODE time dependent. + + tol : float + Error tolerance. + + adjoint : bool + If True calculates gradient with adjoint method, otherwise + backpropagates directly through operations of ODE solver. 
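+
+    in_proj, out_proj : str or bool
+        How raw inputs are lifted into the latent space ('identity'/False or
+        'mlp'/True) and how the final latent state is mapped to predictions
+        ('identity'/False, 'linear'/True, or 'mlp').
+
+    f_add_blocks, h_add_blocks, g_add_blocks : int
+        Number of extra blocks in the input projection, the ODE function, and
+        the output projection, respectively.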
+ """ + + def __init__(self, *args, data_dim, hidden_dim, latent_dim=None, output_dim=1, + time_dependent=True, tol=1e-3, in_proj=False, out_proj=False, proj_norm='none', + f_add_blocks=0, h_add_blocks=0, g_add_blocks=0, **kwargs): + super().__init__(*args, **kwargs) + + self.data_dim = data_dim + self.hidden_dim = hidden_dim + self.output_dim = output_dim + self.time_dependent = time_dependent + self.tol = tol + + if latent_dim is None: + latent_dim = data_dim + self.latent_dim = latent_dim + + odefunc = MLPODEFunc(self.device, self.latent_dim, hidden_dim, + time_dependent, h_add_blocks=h_add_blocks) + + self.odeblock = ODEBlock(self.device, odefunc, tol=tol, adjoint=self.adjoint) + + if in_proj == 'identity' or in_proj is False: + self.in_projection = nn.Flatten() + elif in_proj == 'mlp' or in_proj is True: + latent_dim = self.odeblock.odefunc.input_dim + in_projection = [ + nn.Flatten(), + nn.Linear(data_dim, latent_dim), + nn.BatchNorm1d(latent_dim) if proj_norm == 'bn' else nn.Identity(), + nn.ReLU(), + nn.Linear(latent_dim, latent_dim), + nn.BatchNorm1d(latent_dim) if proj_norm == 'bn' else nn.Identity(), + ] + for _ in range(f_add_blocks): + in_projection.extend([ + nn.ReLU(), + nn.Linear(latent_dim, latent_dim), + ]) + self.in_projection = nn.Sequential( + *in_projection + ) + else: + raise ValueError(f'Invalid in_proj {type(in_proj)} {in_proj}') + + if out_proj == 'identity' or out_proj is False: + self.out_projection = nn.Identity() + elif out_proj == 'linear' or out_proj is True: + self.out_projection = nn.Linear(self.odeblock.odefunc.input_dim, + self.output_dim,) + elif out_proj == 'mlp': + latent_dim = self.odeblock.odefunc.input_dim + self.out_projection = nn.Sequential( + nn.Linear(latent_dim, latent_dim), + nn.BatchNorm1d(latent_dim) if proj_norm == 'bn' else nn.Identity(), + nn.ReLU(), + nn.Linear(latent_dim, self.output_dim) + ) + else: + raise ValueError(f'Invalid out_proj {type(out_proj)} {out_proj}') + + if g_add_blocks > 0: + out_proj = [MLPBlock(latent_dim, latent_dim) for _ in range(g_add_blocks)] + out_proj += [nn.Linear(latent_dim, self.output_dim)] + self.out_projection = nn.Sequential(*out_proj) + + latent_dim = self.odeblock.odefunc.input_dim + self.label_projection = nn.Linear(self.output_dim, latent_dim) diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..6c0c1b1 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,207 @@ +absl-py==2.0.0 +accelerate==0.17.0 +aiohttp==3.9.1 +aiosignal==1.3.1 +antlr4-python3-runtime==4.9.3 +anyio==4.1.0 +appdirs==1.4.4 +argon2-cffi==23.1.0 +argon2-cffi-bindings==21.2.0 +arrow==1.3.0 +astroid==3.0.1 +astunparse==1.6.3 +async-lru==2.0.4 +async-timeout==4.0.3 +attrs==23.1.0 +autopep8==2.0.4 +Babel==2.13.1 +bitsandbytes==0.42.0 +bleach==6.1.0 +blessed==1.20.0 +brotlipy==0.7.0 +cachetools==5.3.2 +comm==0.2.0 +contextlib2==21.6.0 +contourpy==1.2.0 +cycler==0.12.1 +debugpy==1.8.0 +defusedxml==0.7.1 +diffusers==0.18.0 +dill==0.3.7 +dlib==19.24.2 +dnspython==2.4.2 +docker-pycreds==0.4.0 +docstring-parser==0.15 +docstring-to-markdown==0.13 +easydict==1.11 +easyocr==1.7.1 +einops==0.7.0 +expecttest==0.1.6 +fastjsonschema==2.19.0 +flake8==6.1.0 +fonttools==4.46.0 +fqdn==1.5.1 +frozenlist==1.4.1 +fsspec==2023.9.2 +ftfy==6.1.3 +gitdb==4.0.11 +GitPython==3.1.40 +google-auth==2.25.1 +google-auth-oauthlib==1.1.0 +gpustat==1.1.1 +grpcio==1.59.3 +huggingface-hub==0.19.4 +hydra-core==1.3.2 +hypothesis==6.87.1 +imageio==2.33.0 +importlib-metadata==7.0.0 +importlib-resources==6.1.1 +inflect==6.0.4 +ipdb==0.13.13 
+ipykernel==6.27.1 +ipywidgets==8.1.1 +isoduration==20.11.0 +isort==5.12.0 +jedi==0.17.2 +joblib==1.3.2 +json5==0.9.14 +jsonargparse==4.27.3 +jsonpointer==2.1 +jsonschema==4.20.0 +jsonschema-specifications==2023.11.2 +jupyter==1.0.0 +jupyter-console==6.6.3 +jupyter-events==0.9.0 +jupyter-lsp==2.2.1 +jupyter_client==8.6.0 +jupyter_core==5.5.0 +jupyter_server==2.12.1 +jupyter_server_terminals==0.4.4 +jupyterlab==4.0.9 +jupyterlab-lsp==5.0.1 +jupyterlab-widgets==3.0.9 +jupyterlab_pygments==0.3.0 +jupyterlab_server==2.25.2 +kaleido==0.2.1 +kiwisolver==1.4.5 +lazy_loader==0.3 +lightning==2.1.3 +lightning-utilities==0.10.1 +lmdb==1.4.1 +lpips==0.1.4 +Markdown==3.5.1 +markdown-it-py==3.0.0 +matplotlib==3.8.2 +mccabe==0.7.0 +mdurl==0.1.2 +mistune==3.0.2 +mkl-service==2.4.0 +ml-collections==0.1.1 +multidict==6.0.4 +nbclient==0.9.0 +nbconvert==7.12.0 +nbformat==5.9.2 +nest-asyncio==1.5.8 +ninja==1.11.1.1 +notebook==7.0.6 +notebook_shim==0.2.3 +nvidia-ml-py==12.535.133 +oauthlib==3.2.2 +omegaconf==2.3.0 +opencv-python==4.8.1.78 +opencv-python-headless==4.8.1.78 +overrides==7.4.0 +pandas==2.1.3 +pandocfilters==1.5.0 +parso==0.7.1 +platformdirs==4.1.0 +plotly==5.16.1 +prometheus-client==0.19.0 +protobuf==4.23.4 +pyarrow==15.0.0 +pyasn1==0.5.1 +pyasn1-modules==0.3.0 +pyclipper==1.3.0.post5 +pycodestyle==2.11.1 +pydantic==1.10.9 +pydocstyle==6.3.0 +pyflakes==3.1.0 +pylint==3.0.2 +pyparsing==3.1.1 +python-bidi==0.4.2 +python-dateutil==2.8.2 +python-etcd==0.4.5 +python-json-logger==2.0.7 +python-jsonrpc-server==0.4.0 +python-lsp-jsonrpc==1.1.2 +python-lsp-server==1.9.0 +pytoolconfig==1.2.6 +pytorch-fid==0.3.0 +pytorch-lightning==2.1.3 +pyzmq==25.1.2 +qtconsole==5.5.1 +QtPy==2.4.1 +referencing==0.31.1 +regex==2023.10.3 +requests-oauthlib==1.3.1 +rfc3339-validator==0.1.4 +rfc3986-validator==0.1.1 +rich==13.7.0 +rope==1.11.0 +rpds-py==0.13.2 +rsa==4.9 +safetensors==0.4.1 +scikit-image==0.22.0 +scikit-learn==1.3.2 +scipy==1.11.4 +seaborn==0.13.0 +Send2Trash==1.8.2 +sentry-sdk==1.38.0 +setproctitle==1.3.3 +shapely==2.0.2 +smmap==5.0.1 +sniffio==1.3.0 +snowballstemmer==2.2.0 +sortedcontainers==2.4.0 +tenacity==8.2.3 +tensorboard==2.15.1 +tensorboard-data-server==0.7.2 +tensorboardX==2.6.2.2 +terminado==0.18.0 +theme-darcula==4.0.0 +threadpoolctl==3.2.0 +tifffile==2023.9.26 +timm==0.9.12 +tinycss2==1.2.1 +tokenizers==0.13.3 +tomlkit==0.12.3 +torch==2.1.0 +torchaudio==2.1.0 +torchdiffeq==0.2.3 +torchelastic==0.2.2 +torchmetrics==1.3.0.post0 +torchode==0.2.0 +torchsummary==1.5.1 +torchtyping==0.1.4 +torchvision==0.16.0 +tornado==6.4 +transformers==4.30.2 +triton==2.1.0 +typeguard==4.1.5 +types-dataclasses==0.6.6 +types-python-dateutil==2.8.19.14 +typeshed-client==2.4.0 +tzdata==2023.3 +ujson==5.8.0 +uri-template==1.3.0 +wandb==0.16.1 +wcwidth==0.2.12 +webcolors==1.13 +webencodings==0.5.1 +websocket-client==1.7.0 +Werkzeug==3.0.1 +whatthepatch==1.0.5 +widgetsnbextension==4.0.9 +yapf==0.40.2 +yarl==1.9.4 +zipp==3.17.0 diff --git a/utils.py b/utils.py new file mode 100644 index 0000000..dd9831f --- /dev/null +++ b/utils.py @@ -0,0 +1,22 @@ +def append_dims(x, target_dims): + """Appends dimensions to the end of a tensor until it has target_dims dimensions.""" + dims_to_append = target_dims - x.ndim + if dims_to_append < 0: + raise ValueError( + f"input has {x.ndim} dims but target_dims is {target_dims}, which is less" + ) + return x[(...,) + (None,) * dims_to_append] + + +def compute_task_criterion(task_type): + if task_type == 'classification': + return 'ce' + else: + return 'mse' + + +def 
compute_metric_type(task_type):
+    if task_type == 'classification':
+        return 'accuracy'
+    else:
+        return 'rmse'
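
For reference, below is a minimal, self-contained sketch of how the simulation-free objective in `BaseModel.ours_training_step` combines with the linear schedule from `models/dynamics.py`. The small linear networks are toy stand-ins for `in_projection`, `label_projection`, `out_projection`, and the ODE function; only the loss structure is taken from the code in this patch, and the dimensions and noise scale are illustrative.
```
import torch
import torch.nn as nn
import torch.nn.functional as F

latent_dim, num_classes, data_dim = 8, 10, 20
in_projection = nn.Linear(data_dim, latent_dim)           # stand-in for the input encoder
label_projection = nn.Linear(num_classes, latent_dim)     # stand-in for the label embedding
out_projection = nn.Linear(latent_dim, num_classes)       # stand-in for the decoder
velocity = nn.Sequential(                                  # stand-in for the ODE function
    nn.Linear(latent_dim + 1, 64), nn.ReLU(), nn.Linear(64, latent_dim))

x = torch.randn(32, data_dim)                              # toy input batch
y = F.one_hot(torch.randint(0, num_classes, (32,)), num_classes).float()

z0 = in_projection(x)                                      # latent endpoint at t = 0
z1 = label_projection(y)                                   # latent endpoint at t = 1
t = torch.rand(32, 1)                                      # per-sample time in [0, 1]

zt = (1 - t) * z0 + t * z1                                 # LinearDynamics interpolant
v_target = z1 - z0                                         # d z_t / d t for the linear schedule
v_pred = velocity(torch.cat([zt, t], dim=1))               # time-conditioned velocity prediction

flow_loss = F.mse_loss(v_pred, v_target)                   # simulation-free flow loss
y_pred = out_projection(z1 + 3.0 * torch.randn_like(z1))   # noisy label autoencoding (label_ae_noise analogue)
label_ae_loss = F.mse_loss(y_pred, y)
loss = flow_loss + label_ae_loss                           # total loss, as in ours_training_step
```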