Please note: this is research code that is not actively maintained. If you find any issues, please let us know.
Public code for the TMLR 2023 paper Faster Training of Neural ODEs Using Gauß--Legendre Quadrature (Alexander Norcliffe, Marc Deisenroth).
We recommend reading through this documentation fully before using the code.
Neural ODEs demonstrate strong performance in generative and time-series modelling. However, training them via the adjoint method is slow compared to discrete models, due to the requirement of numerically solving ODEs. A common approach to speeding neural ODEs up is to regularise the solutions. However, this may reduce the expressivity of the model, which is particularly important when the trajectory itself matters. In this paper, we propose an alternative way to speed up the training of neural ODEs. The key idea is to speed up the adjoint method by using Gauß--Legendre quadrature to solve integrals faster than ODE-based methods, while remaining memory efficient. We also extend the idea to training SDEs via the Wong--Zakai theorem, by training a corresponding ODE and transferring the parameters. Our approach leads to faster training of neural ODEs, especially for large models. It also presents a new way to train SDE-based models.
To run all the code in the library you will need most of the standard packages for machine learning in Python, including but not limited to:
- numpy
- torch
- matplotlib
- scipy
- sklearn
Importantly, the library also requires these non-standard libraries:
- torchdiffeq
- sdeint
- TorchDiffEqPack
- torchsde
We recommend using a virtual environment to install the packages.
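For example, a minimal setup sketch, assuming the packages above are available on PyPI under these names (sklearn is provided by the scikit-learn package):
$ python -m venv venv
$ source venv/bin/activate
$ pip install numpy torch matplotlib scipy scikit-learn torchdiffeq sdeint TorchDiffEqPack torchsde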
To create the data used in the paper, run the following commands:
$ python -m datasets.make_nested_spheres_data
$ python -m datasets.make_sine_data
$ python -m datasets.make_ou_data
Experiments are run from the command line in the home directory. Each experiment from the paper has its own file in the experiments directory. For example, to run the nested spheres experiment, run the following command:
$ python -m experiments.nested_spheres --adjoint_option adjoint_gq --experiment_no 1 --width 2000 --lr 0.0001
In the above, the adjoint_option argument selects the backpropagation method, experiment_no creates the seed for consistency across methods (increase it to run the experiment multiple times), width sets the width of the model, and lr sets the learning rate. The model hyperparameters need to match those listed in the appendix of the paper.
Each experiment has its own command-line arguments, so we recommend looking at the code for each experiment to see which arguments are required. Another example is the sine experiment, which requires this command:
$ python -m experiments.sines --adjoint_option adjoint_ode --experiment_no 1 --width 1000 --lr 0.00003 --regular_times regular
The difference is that the adjoint option has changed from adjoint_gq to adjoint_ode, so the standard adjoint method is used rather than the GQ method. The regular_times argument selects whether the time series is regularly or irregularly spaced.
After running the experiments, the results are saved in the results directory. The plotting code can then be used to create the figures in the paper (see the note below about the plotting code).
For users' benefit, we have included the plotting code we used to generate the plots in the paper. However, it was not written with the intention of being used by others, so it is not well documented and may be difficult to use. We have not tested it extensively since writing the paper.
Running each experiment individually is not efficient. If you are running all of the experiments with many repeats, we recommend writing bash scripts, as in the sketch below.
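For example, a minimal sketch for the nested spheres experiment (the repeat count and argument values here are illustrative; match them to the appendix of the paper):
#!/bin/bash
# run the nested spheres experiment with three different seeds
for i in 1 2 3; do
    python -m experiments.nested_spheres --adjoint_option adjoint_gq --experiment_no $i --width 2000 --lr 0.0001
done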
The models in this repository are designed to be used in the experiments and have been built as such. They are not designed to be used as a general Neural ODE. We recommend using just the torch_gq_adjoint part of the library, since it has been designed to be used in the same way as the torchdiffeq library.
All nn modules used as the function in the Neural ODE must have an nfe attribute. This is used to count the number of function evaluations, which determines the number of terms in the quadrature calculation. To do this, include the following line in the __init__ of the module:
self.nfe = 0
Then, in the forward method, include:
self.nfe += 1
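Putting these together, a minimal sketch of a compatible ODE function (the two-layer architecture here is purely illustrative):
import torch.nn as nn

class ODEFunc(nn.Module):
    def __init__(self, dim):
        super().__init__()
        # illustrative dynamics network; any architecture works
        self.net = nn.Sequential(nn.Linear(dim, 50), nn.Tanh(), nn.Linear(50, dim))
        self.nfe = 0  # function evaluation counter used by the GQ method

    def forward(self, t, x):
        self.nfe += 1  # increment on every evaluation of the dynamics
        return self.net(x)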
The gtol parameter determines the number of terms in the quadrature calculation. It takes the place of the rtol and atol parameters, which determine the tolerance of the ODE solver: the lower those numbers, the more steps are used in the solve. Please note that gtol does not work inversely like atol and rtol: the higher its value, the more terms are used.
Below we include an example of using the adjoint and the GQ methods to show the difference:
from torch_gq_adjoint import odeint_adjoint, odeint_adjoint_gq
output_adj = odeint_adjoint(model, x, t, rtol=1e-4, atol=1e-6) # adjoint method
output_gq = odeint_adjoint_gq(model, x, t, gtol=0.1, rtol=1e-4, atol=1e-6) # gq method
The difference is that the GQ method requires a gtol parameter, and the model must have an nfe attribute, as mentioned above.
We also include a more comprehensive example in the g1d_example.ipynb
notebook.
As described in the paper, we recommend using the GQ method when the state size is small (for example, a small batch size or a small vector size) and when the model has many parameters.
As described below, we also found that the speed-ups are more pronounced on CPU than on GPU, so we also recommend the GQ method when using limited resources.
To test the accuracy of the gradients produced by the GQ method, we repeat the tests from the torchdiffeq library. To run them, use the following command:
$ python -m gradient_tests.gradient_tests
The following fixed methods are included in the tests:
'euler', 'midpoint', 'rk4', 'explicit_adams', 'implicit_adams'
The following adaptive methods are included in the tests:
'dopri5', 'bosh3', 'adaptive_heun', 'dopri8'
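For illustration, a sketch of selecting one of these solvers, assuming torch_gq_adjoint mirrors torchdiffeq's method keyword (an assumption; check the experiment code for the exact interface):
output = odeint_adjoint_gq(model, x, t, gtol=0.1, rtol=1e-4, atol=1e-6, method='dopri5')  # e.g. the dopri5 adaptive solver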
Newly released torch differentiable ODE solvers are not included in the tests, since this repo is not actively maintained.
Please note: the paper only tests the gradient speed-up of the dopri5 solver; the other solvers are included for completeness. We believe similar speed-ups would be seen with the other solvers, but we have not tested this.
We found that the speed-up of the GQ method becomes more apparent on less advanced hardware. On a CPU, the GQ method was significantly faster than the standard adjoint method, and often faster than direct backpropagation. On the GPU that we used, we found that for large models the GQ method scaled better than the adjoint method. We therefore recommend choosing the method appropriate to your hardware; as mentioned, the GQ method gives a larger advantage over the adjoint method on less advanced hardware.
If you find this code or our paper useful in your own research, please cite our paper:
@article{norcliffe2023gq,
title={{F}aster {T}raining of {N}eural {ODE}s {U}sing {G}au{\ss}{\textendash}{L}egendre {Q}uadrature},
author={Norcliffe, Alexander and Deisenroth, Marc},
journal={{T}ransactions on {M}achine {L}earning {R}esearch},
year={2023}
}
At the time of publishing this work, Alexander Norcliffe is supported by a GlaxoSmithKline grant.
We thank So Takao for his help with understanding the Wong--Zakai theorem for the SDE
adaptation of the GQ method.
We would like to thank the anonymous reviewers and action editor
Kevin Swersky for their time and efforts to review and constructively critique the paper.
A very large amount of this work depended heavily on the torchdiffeq library; we thank the authors Ricky Chen, Yulia Rubanova, Jesse Bettencourt and David Duvenaud for their work on this library.