cd-dynamax source code description

We provide the following modifications to the dynamax codebase to accommodate continuous-discrete models, i.e., models whose observations are not assumed to be regularly sampled.

  • A modified version of dynamax's ssm.py that incorporates irregular emission time instants via the t_emissions array
    • t_emissions is an input argument
      • We use t0 and t1 to refer to $t_k$ and $t_{k+1}$, which are not necessarily regularly spaced
    • t_emissions is a matrix of size $[\textrm{num observations} \times 1]$
      • This shape is intended to facilitate batching
      • For lax.scan() operations, we recast it to vector shape (i.e., remove the final dimension)
  • diffrax_utils.py

    • Implements a diffrax-based, autodifferentiable ODE solver
  • test_utils.py

  • plotting_utils.py

  • debug_utils.py

    • Debugging in JAX can be difficult: the pre-compilation that speeds up JAX code prevents typical in-line Python debuggers from working. To make debugging easier, we implemented a wrapper for lax.scan that, with debug=True, runs a (slow, but in-line debuggable!) for loop instead of lax.scan.
    • To use it in a particular piece of code, add from utils.debug_utils import lax_scan and replace the lax.scan call you wish to debug with lax_scan(..., debug=True).
    • This is an experimental feature, so please report any issues that arise from using it. We hope it helps ease the transition to JAX!
  • Establishes the functionality of linear and nonlinear filters/smoothers, as well as parameter fitting via SGD.
  • Checks that nonlinear algorithms applied to linear problems return results similar to those of the linear algorithms.
  • Example notebooks demonstrating filtering/smoothing of linear and nonlinear continuous-discrete dynamic models.
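As an illustration of how an irregular t_emissions array can drive a scan, here is a minimal sketch (not the cd-dynamax API). It takes hypothetical observation times with the $[\textrm{num observations} \times 1]$ shape, removes the final dimension, and scans over consecutive pairs (t0, t1), propagating the mean and variance of a scalar Ornstein-Uhlenbeck process whose closed-form transition depends on the gap t1 - t0. The OU model, its parameters theta and sigma, the time values, and the function name predict are all assumptions chosen for illustration.

```python
import jax.numpy as jnp
from jax import lax

# Hypothetical irregular observation times, shaped [num_observations x 1]
# like t_emissions (values made up for illustration).
t_emissions = jnp.array([[0.0], [0.3], [1.1], [1.5], [2.75]])

# Recast to vector shape (remove the final dimension) before scanning.
t_vec = t_emissions[:, 0]          # shape (num_observations,)
t0s, t1s = t_vec[:-1], t_vec[1:]   # t0 = t_k, t1 = t_{k+1}

# Illustrative scalar Ornstein-Uhlenbeck dynamics: dx = -theta * x dt + sigma dW
theta, sigma = 0.5, 1.0

def predict(carry, ts):
    """Propagate mean and variance from t0 to t1 using the OU closed form."""
    mean, var = carry
    t0, t1 = ts
    decay = jnp.exp(-theta * (t1 - t0))  # depends on the irregular gap t1 - t0
    mean = decay * mean
    var = decay**2 * var + sigma**2 / (2 * theta) * (1 - decay**2)
    return (mean, var), var

init = (jnp.array(1.0, dtype=jnp.float32), jnp.array(0.0, dtype=jnp.float32))
(final_mean, final_var), pred_vars = lax.scan(predict, init, (t0s, t1s))
print(pred_vars.shape)  # (4,)
```

Because the transition is recomputed from t1 - t0 at every step, nothing in the scan assumes a fixed sampling interval.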
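To illustrate the idea behind the lax.scan debug wrapper, here is a minimal reimplementation of our own, assuming the wrapper mirrors the (f, init, xs, length) arguments of lax.scan; the actual lax_scan in utils/debug_utils.py may differ, and the name lax_scan_sketch is hypothetical. With debug=False it defers to lax.scan; with debug=True it slices xs step by step in a plain Python loop (so, when not called under jit, a breakpoint inside f sees concrete values) and then stacks the per-step outputs as lax.scan would.

```python
import jax
import jax.numpy as jnp
from jax import lax

def lax_scan_sketch(f, init, xs, length=None, debug=False):
    """Hypothetical stand-in for lax.scan; debug=True runs a plain Python loop."""
    if not debug:
        return lax.scan(f, init, xs, length=length)
    if length is None:
        # Infer the number of steps from the leading axis of the first leaf of xs.
        length = jax.tree_util.tree_leaves(xs)[0].shape[0]
    carry, ys = init, []
    for i in range(length):
        # Slice the i-th element of every leaf (xs may be a pytree or None).
        x = None if xs is None else jax.tree_util.tree_map(lambda leaf: leaf[i], xs)
        carry, y = f(carry, x)  # runs eagerly, so breakpoint() inside f can inspect values
        ys.append(y)
    # Stack per-step outputs along a new leading axis, as lax.scan does.
    stacked = jax.tree_util.tree_map(lambda *leaves: jnp.stack(leaves), *ys)
    return carry, stacked

def cumsum_step(carry, x):
    carry = carry + x
    return carry, carry

xs = jnp.arange(5.0)
fast_carry, fast_ys = lax_scan_sketch(cumsum_step, jnp.zeros(()), xs)
slow_carry, slow_ys = lax_scan_sketch(cumsum_step, jnp.zeros(()), xs, debug=True)
```

Both calls should return the same final carry and stacked outputs; only the debug path can be stepped through with an in-line debugger such as pdb.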