The primary goal of this codebase is to extend dynamax to a continuous-discrete (CD) state-space-modeling setting: that is, to problems where the underlying dynamics are continuous in time and measurements can arise at arbitrary (i.e., non-regular) discrete times. To this end, cd-dynamax modifies dynamax to accept irregularly sampled data and implements classical algorithms for continuous-discrete filtering and smoothing.
In this repository, we build an expanded toolkit for learning and predicting the dynamical systems that underpin messy real-world time-series data. We move towards this goal by introducing the following flexible mathematical setting.
We assume there exists a (possibly unknown) stochastic dynamical system of the form

$$dx(t) = f(x(t), t; \theta)\,dt + L(x(t), t; \theta)\,d\beta(t),$$

where $x(t)$ is the latent state, $f$ is the drift function, $L$ is the diffusion coefficient, and $\beta(t)$ is Brownian motion.

We further assume that data are available at arbitrary times $t_1 < t_2 < \dots < t_K$:

$$y_k = h(x(t_k); \theta) + \eta(t_k),$$

where $h$ is the observation function and $\eta(t_k) \sim \mathcal{N}(0, R)$ is additive Gaussian observation noise.

We denote the collection of all parameters (e.g., those defining $f$, $L$, $h$, and $R$) as $\theta$.
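To make the setting concrete, here is a minimal, self-contained sketch (plain numpy, not the cd-dynamax API) of simulating such a system: an Euler–Maruyama path of the SDE, observed with additive Gaussian noise at irregular times. The double-well drift, constant diffusion, and all names are illustrative assumptions, not part of the library.

```python
import numpy as np

# Illustrative stand-ins for the (possibly unknown) drift f and diffusion L.
def drift(x, t):
    return x - x**3          # double-well drift (hypothetical example)

def diffusion(x, t):
    return 0.5               # constant diffusion coefficient

rng = np.random.default_rng(0)

# Simulate dx = f(x, t) dt + L(x, t) dbeta(t) with Euler-Maruyama.
dt = 1e-3
T = 5.0
n_steps = int(T / dt)
ts = np.linspace(0.0, T, n_steps + 1)
xs = np.empty(n_steps + 1)
xs[0] = 1.0
for i in range(n_steps):
    dbeta = rng.normal(scale=np.sqrt(dt))   # Brownian increment
    xs[i + 1] = xs[i] + drift(xs[i], ts[i]) * dt + diffusion(xs[i], ts[i]) * dbeta

# Observe at irregular times t_1 < ... < t_K with additive Gaussian noise eta.
obs_times = np.sort(rng.uniform(0.0, T, size=20))
obs_idx = np.searchsorted(ts, obs_times)
R = 0.1                                      # observation noise variance
ys = xs[obs_idx] + rng.normal(scale=np.sqrt(R), size=obs_times.shape)
```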
Note:
- We assume $\eta(t)$ is i.i.d. w.r.t. $t$.
  - This assumption places us in the continuous (dynamics)-discrete (observation) setting.
  - If $\eta(t)$ had temporal correlations, we would likely adopt a mathematical setting that defines the observation process continuously in time via its own SDE.
- Other extensions of the above paradigm include categorical state-spaces and non-additive observation noise distributions.
  - These can fit into our code framework (indeed, some are covered in dynamax), but have not been our focus.
For a given set of observations $Y_K = \{y_1, \dots, y_K\}$, we consider the following tasks:
- Filter: estimate $x(t_K) \ | \ Y_K, \ \theta$
- Smooth: estimate $\{x(t)\}_t \ | \ Y_K, \ \theta$
- Predict: estimate $x(t > t_K) \ | \ Y_K, \ \theta$
- Infer parameters: estimate $\theta \ | \ Y_K$
All of these problems are deeply interconnected.
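For intuition on the linear case, here is a minimal, self-contained sketch of a classical continuous-discrete Kalman filter (plain numpy/scipy, not the cd-dynamax API): between irregular observation times it propagates the Gaussian state exactly via the matrix exponential (with Van Loan's trick for the process noise), and at each observation it applies the standard discrete update. All names and the toy system are assumptions for illustration.

```python
import numpy as np
from scipy.linalg import expm

def cd_kalman_filter(A, Qc, H, R, m0, P0, ts, ys):
    """Continuous-discrete Kalman filter for dx = A x dt + dbeta,
    Cov[dbeta] = Qc dt, with observations y_k = H x(t_k) + eta_k, eta_k ~ N(0, R)."""
    d = A.shape[0]
    m, P, t_prev = m0, P0, 0.0
    loglik = 0.0
    for t, y in zip(ts, ys):
        dt = t - t_prev
        # Predict: exact moment propagation over the (possibly irregular) gap dt.
        # Van Loan's trick: expm of a block matrix yields F = e^{A dt} and Q(dt).
        C = np.block([[-A, Qc], [np.zeros((d, d)), A.T]]) * dt
        E = expm(C)
        F = E[d:, d:].T          # e^{A dt}
        Q = F @ E[:d, d:]        # integrated process noise over dt
        m = F @ m
        P = F @ P @ F.T + Q
        # Update: standard discrete Kalman update at the observation time.
        S = H @ P @ H.T + R
        K = P @ H.T @ np.linalg.inv(S)
        v = y - H @ m
        loglik += -0.5 * (v @ np.linalg.solve(S, v)
                          + np.log(np.linalg.det(2 * np.pi * S)))
        m = m + K @ v
        P = P - K @ S @ K.T
        t_prev = t
    return m, P, loglik

# Toy usage: a damped oscillator observed in position at irregular times.
A = np.array([[0.0, 1.0], [-1.0, -0.1]])
Qc = np.diag([0.0, 0.1])
H = np.array([[1.0, 0.0]])
R = np.array([[0.05]])
ts = np.array([0.3, 0.5, 1.2, 1.3, 2.7])   # non-regular observation times
ys = np.array([[0.9], [0.7], [0.1], [0.0], [-0.6]])
m, P, ll = cd_kalman_filter(A, Qc, H, R, np.zeros(2), np.eye(2), ts, ys)
```

Note how the irregular spacing enters only through `dt`; this is the key difference from the discrete-time filters in dynamax.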
- In cd-dynamax, we enable filtering, smoothing, and parameter inference for a single system under multiple trajectory observations $[Y^{(1)}, \ \dots, \ Y^{(N)}]$.
  - In these cases, we assume that each trajectory represents an independent realization of the same dynamics-data model, which we may be interested in learning, filtering, smoothing, or predicting.
  - In the future, we would like to have options to perform hierarchical inference, where we assume that each trajectory came from a different, yet similar, set of system-defining parameters $\theta^{(n)}$.
- We implement such filtering/smoothing algorithms in a fast, autodifferentiable framework, enabling the use of modern general-purpose tools for parameter inference (e.g., stochastic gradient descent, Hamiltonian Monte Carlo).
- In cd-dynamax, we tackle parameter inference by marginalizing out the unobserved states $\{x(t)\}_t$; this is a design choice of ours, and other alternatives are possible.
  - This marginalization is performed (approximately, in cases of non-linear dynamics) via filtering/smoothing algorithms.
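The combination of marginalization and autodifferentiation can be sketched in a few lines of JAX (again illustrative, not the cd-dynamax API): a scalar continuous-discrete Kalman filter computes the marginal log-likelihood $\log p(Y_K \mid \theta)$ in closed form, and `jax.grad` differentiates it w.r.t. a drift parameter, which is exactly what SGD or HMC would consume. The scalar model and all names are assumptions.

```python
import jax
import jax.numpy as jnp

def neg_log_marginal(theta, ts, ys, qc=0.1, r=0.05, m0=0.0, p0=1.0):
    """Negative marginal log-likelihood -log p(Y_K | theta) for the scalar model
    dx = theta * x dt + dbeta (Cov[dbeta] = qc dt), y_k = x(t_k) + eta_k,
    with the latent path {x(t)} marginalized by a closed-form CD Kalman filter.
    Assumes theta != 0 so the integrated process noise has a closed form."""
    def step(carry, inp):
        m, p, t_prev = carry
        t, y = inp
        dt = t - t_prev
        f = jnp.exp(theta * dt)                               # e^{A dt} in 1-D
        q = qc * (jnp.exp(2 * theta * dt) - 1) / (2 * theta)  # process noise
        m, p = f * m, f * p * f + q                           # predict
        s = p + r
        ll = -0.5 * (jnp.log(2 * jnp.pi * s) + (y - m) ** 2 / s)
        k = p / s
        m, p = m + k * (y - m), p - k * p                     # update
        return (m, p, t), ll

    init = (jnp.asarray(m0, jnp.float32),
            jnp.asarray(p0, jnp.float32),
            jnp.asarray(0.0, jnp.float32))
    _, lls = jax.lax.scan(step, init, (ts, ys))
    return -jnp.sum(lls)

grad_fn = jax.grad(neg_log_marginal)
ts = jnp.array([0.2, 0.9, 1.0, 1.7])    # irregular observation times
ys = jnp.array([0.4, 0.1, 0.2, -0.3])
g = grad_fn(-0.5, ts, ys)               # gradient w.r.t. the drift parameter
```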
- We leverage dynamax code.
  - Currently, this is based on a local copy of dynamax release 0.1.5.
- We have implemented continuous-discrete linear and non-linear models, along with filtering and smoothing algorithms.
  - If you are simulating data from a non-linear SDE, it is recommended to use `model.sample(..., transition_type="path")`, which runs an SDE solver.
  - Default behavior is to perform Gaussian approximations to the SDE.
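The difference between the two behaviors can be sketched as follows (plain numpy; the toy SDE and all function names are illustrative assumptions, not the library API): the Gaussian approximation propagates only a mean and covariance through linearized moment ODEs, while a "path"-style transition simulates the SDE itself with Euler–Maruyama.

```python
import numpy as np

# Toy non-linear SDE: dx = sin(x) dt + dbeta, Cov[dbeta] = qc dt.
qc = 0.1

def f(x):
    return np.sin(x)

def jac_f(x):
    return np.cos(x)

def gaussian_transition(m, p, dt, n_sub=100):
    """Gaussian approximation: Euler steps of the linearized moment ODEs
    dm/dt = f(m), dP/dt = 2 J(m) P + qc (scalar case)."""
    h = dt / n_sub
    for _ in range(n_sub):
        j = jac_f(m)
        m, p = m + h * f(m), p + h * (2 * j * p + qc)
    return m, p

def path_transition(x, dt, rng, n_sub=100):
    """Path-style transition: simulate the SDE itself with Euler-Maruyama."""
    h = dt / n_sub
    for _ in range(n_sub):
        x = x + h * f(x) + rng.normal(scale=np.sqrt(qc * h))
    return x

rng = np.random.default_rng(0)
m, p = gaussian_transition(0.5, 0.0, dt=1.0)
samples = np.array([path_transition(0.5, 1.0, rng) for _ in range(2000)])
# Over short horizons the sample moments roughly track (m, p); for strongly
# non-linear drifts or long gaps, the path samples are the more faithful option.
```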
- For comparison purposes, we provide example notebooks for linear continuous-discrete filtering/smoothing under regular and irregular sampling:
  - Tracking
  - Parameter estimation that marginalizes out un-observed dynamics via auto-differentiable filtering (MLE via SGD; uncertainty quantification via HMC)
- For more interesting continuous-discrete, non-linear models, see our new tutorials for examples of how to use the codebase.
  - We provide a tutorial README describing each of the tutorials.
  - Highlights include a notebook for learning neural-network-based drift functions from partial, noisy, irregularly-spaced observations!
- We provide a working conda environment, with dependencies installed via the pip-based requirements file:
```bash
# For CPU
$ conda create --name cd_dynamax python=3.11.4
$ conda activate cd_dynamax
$ conda install pip
$ pip install -r hduq_cd_dynamax_requirements.txt

# For GPU
$ conda create --name cd_dynamax_GPU python=3.11.4
$ conda activate cd_dynamax_GPU
$ conda install pip
$ pip install -r hduq_cd_dynamax_requirements.txt
$ pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
$ pip install jax==0.4.13 jaxlib==0.4.13+cuda12.cudnn89 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
```