We provide the following modifications of the dynamax codebase to accommodate continuous-discrete models, i.e., those where observations are not assumed to be regularly sampled.
- A modified version of dynamax's `ssm.py` that incorporates non-regular emission time instants, i.e., the `t_emissions` array:
  - `t_emissions` is an input argument
  - We use `t0` and `t1` to refer to $t_k$ and $t_{k+1}$, not necessarily regularly sampled
  - `t_emissions` is a matrix of size $[\textrm{num observations} \times 1]$, which should facilitate batching
    - For `lax.scan()` operations, we recast it in vector shape (i.e., remove the final dimension)
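As a minimal illustration of this convention (the time values below are made up), `t_emissions` stores possibly irregular observation times as a column matrix, and the trailing dimension is dropped before scanning:

```python
import jax.numpy as jnp

# Hypothetical example: four irregularly sampled observation times, stored
# as a [num observations x 1] matrix to facilitate batching.
t_emissions = jnp.array([[0.0], [0.3], [1.1], [1.5]])

# Before lax.scan operations, drop the trailing dimension to get a vector.
t_vec = t_emissions.squeeze(-1)  # shape (num observations,)

# Consecutive pairs (t0, t1) = (t_k, t_{k+1}) drive each push-forward step.
t0_all, t1_all = t_vec[:-1], t_vec[1:]
```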
- We define a ContDiscreteLinearGaussianSSM model
  - We do not currently provide a ContDiscreteLinearGaussianConjugateSSM model implementation, as continuous-discrete parameter conjugate priors are non-trivial
  - The CD-LGSSM model is based on a continuous-time push-forward operation that computes and returns matrices $A$ and $Q$
  - Continuous-Discrete Kalman filtering and smoothing algorithms are implemented
  - Parameter (point-)estimation is possible via stochastic gradient descent-based MLE, where the marginal log-likelihood is computed based on the CD-Kalman filter
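As a sketch of what such a push-forward can look like for a linear time-invariant SDE (the codebase may compute $A$ and $Q$ differently, e.g., by solving ODEs numerically; the function name here is illustrative), the transition matrix $A$ and process-noise covariance $Q$ over a step $\Delta t$ can be obtained jointly with Van Loan's block matrix-exponential trick:

```python
import jax.numpy as jnp
from jax.scipy.linalg import expm

def lti_pushforward(F, Qc, dt):
    """Discretize the linear SDE dx = F x dt + w(t), with diffusion
    covariance Qc, over a step dt.

    Returns A = expm(F dt) and Q = int_0^dt expm(F s) Qc expm(F s)^T ds,
    computed jointly via Van Loan's block-matrix exponential.
    """
    n = F.shape[0]
    M = jnp.block([[F, Qc], [jnp.zeros((n, n)), -F.T]]) * dt
    Phi = expm(M)
    A = Phi[:n, :n]        # upper-left block: expm(F dt)
    Q = Phi[:n, n:] @ A.T  # upper-right block times A^T: the noise integral
    return A, Q
```

A quick sanity check: for $F = 0$ this reduces to $A = I$ and $Q = Q_c \, \Delta t$.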
- We define a ContDiscreteNonlinearGaussianSSM model
  - The CD-NLGSSM model is based on a continuous-time push-forward operation that solves an SDE forward over the mean $x$ and covariance $P$ of the latent state
    - The parameters of the SDE function are provided in the ParamsCDNLGSSM object, which contains
      - The initial state's prior parameters in ParamsLGSSMInitial, as defined by dynamax
      - The dynamics function in ParamsCDNLGSSMDynamics
      - The emissions function in ParamsCDNLGSSMEmissions
        - These latter two are learnable functions
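The mean/covariance push-forward can be sketched with a simple Euler integrator (illustration only: `pushforward_mean_cov` is a hypothetical name, and the codebase integrates these ODEs with a proper ODE solver):

```python
import jax
import jax.numpy as jnp

def pushforward_mean_cov(f, Qc, x0, P0, t0, t1, num_steps=1000):
    """Euler sketch of the push-forward from t0 to t1: solves
    dx/dt = f(x) and dP/dt = J P + P J^T + Qc, with J = df/dx(x).
    """
    dt = (t1 - t0) / num_steps

    def step(carry, _):
        x, P = carry
        J = jax.jacfwd(f)(x)  # local linearization of the dynamics
        x_new = x + dt * f(x)
        P_new = P + dt * (J @ P + P @ J.T + Qc)
        return (x_new, P_new), None

    (x1, P1), _ = jax.lax.scan(step, (x0, P0), None, length=num_steps)
    return x1, P1
```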
  - Different filtering and smoothing algorithms are implemented
  - Parameter (point-)estimation is possible via stochastic gradient descent-based MLE, where the marginal log-likelihood can be computed according to different implemented filtering methods (EKF, UKF, EnKF)
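A toy sketch of this style of SGD-based MLE (here `marginal_log_lik` and `fit_sgd` are hypothetical stand-ins; in the codebase the log-likelihood would come from an EKF/UKF/EnKF pass):

```python
import jax
import jax.numpy as jnp

# Stand-in for the filter-based marginal log-likelihood: a toy Gaussian
# likelihood with unknown mean keeps the example self-contained.
def marginal_log_lik(params, emissions):
    return jnp.sum(-0.5 * (emissions - params["mu"]) ** 2)

def fit_sgd(params, emissions, lr=0.1, num_steps=100):
    """Plain SGD on the negative marginal log-likelihood."""
    loss = lambda p: -marginal_log_lik(p, emissions)
    grad_fn = jax.jit(jax.grad(loss))
    for _ in range(num_steps):
        grads = grad_fn(params)
        params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
    return params
```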
- Implements a diffrax-based, autodifferentiable ODE solver
- Debugging in jax can be difficult---pre-compilation speedups cause typical usage of in-line python debuggers to fail. To make debugging easier, we implemented a wrapper for `lax.scan` which, with `debug=True`, runs a (slow, but in-line debuggable!) `for` loop instead of `lax.scan`.
- To use this in a particular piece of code, simply add `from utils.debug_utils import lax_scan` and replace an existing `lax.scan` call you wish to debug with `lax_scan(..., debug=True)`.
- This is an experimental feature, so please report any issues that arise from using this tool---we hope it helps ease the transition into using jax!
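The idea behind the wrapper can be sketched as follows (the actual `utils.debug_utils.lax_scan` implementation may differ in details):

```python
import jax
import jax.numpy as jnp

def lax_scan(f, init, xs, debug=False):
    """Sketch of the debug wrapper: with debug=False, defer to lax.scan."""
    if not debug:
        return jax.lax.scan(f, init, xs)
    # debug=True: a slow Python for loop, where breakpoints and prints work
    carry, ys = init, []
    for x in xs:
        carry, y = f(carry, x)
        ys.append(y)
    return carry, jnp.stack(ys)
```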
- Establishes functionality of linear and non-linear filters/smoothers, as well as parameter fitting via SGD.
- Checks that non-linear algorithms applied to linear problems return similar results as linear algorithms.
- Example notebooks, with filtering/smoothing of linear and nonlinear continuous-discrete dynamic models.