-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcdnlgssm_utils.py
249 lines (199 loc) · 9.67 KB
/
cdnlgssm_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
from typing import NamedTuple, Tuple, Optional, Union
from jaxtyping import Array, Float, PyTree
import jax.numpy as jnp
from dynamax.parameters import ParameterProperties, ParameterSet
import abc
# To avoid redefining code unnecessarily, we first import the parameter
# classes that can be reused from the LGSSM module,
# and define the remaining ones below.
from dynamax.linear_gaussian_ssm.inference import ParamsLGSSMInitial
# TODO: do we need @dataclass(frozen=True)?
class LearnableFunction(NamedTuple):
    """Interface shared by all learnable functions.

    Every learnable function provides
      * a ``params`` field holding its parameters, and
      * a method ``f(x, u, t)`` that evaluates the function at
        state ``x``, inputs ``u`` and time ``t``.

    Classes created with ``typing.NamedTuple`` do not use ``abc.ABCMeta``,
    so ``@abc.abstractmethod`` on its own does not prevent instantiation
    or calls. ``f`` therefore raises ``NotImplementedError`` explicitly
    instead of silently returning ``None``.
    """

    # Parameters of the function, stored as a field of the tuple.
    params: ParameterSet

    # No explicit __init__ is needed: NamedTuple generates the constructor
    # from the annotated fields above.

    @abc.abstractmethod
    def f(self, x, u=None, t=None):
        """Evaluate the function; to be defined by specific classes.

        x: state
        u: inputs
        t: time

        Raises NotImplementedError, because this base definition has no
        function to evaluate.
        """
        raise NotImplementedError(
            f"{type(self).__name__}.f must be implemented by a concrete learnable function"
        )
class LearnableVector(NamedTuple):
    """Constant learnable function whose value is a vector.

    params: the vector itself (or the ParameterProperties describing it)
    f(x, u, t) = params
    """

    params: Union[Float[Array, "dim"], ParameterProperties]

    def f(self, x=None, u=None, t=None):
        """Return the stored vector; ``x``, ``u`` and ``t`` are ignored."""
        return self.params
class LearnableMatrix(NamedTuple):
    """Constant learnable function whose value is a matrix.

    params: the matrix itself (or the ParameterProperties describing it)
    f(x, u, t) = params
    """

    params: Union[Float[Array, "row_dim col_dim"], ParameterProperties]

    def f(self, x=None, u=None, t=None):
        """Return the stored matrix; ``x``, ``u`` and ``t`` are ignored."""
        return self.params
class LearnableLinear(NamedTuple):
    """Linear function with learnable parameters.

    weights: weights of the linear function
    bias: bias of the linear function
    f(x) = weights @ x + bias
    """

    weights: Union[Float[Array, "output_dim input_dim"], ParameterProperties]
    bias: Union[Float[Array, "output_dim"], ParameterProperties]

    def f(self, x, u=None, t=None):
        """Evaluate ``weights @ x + bias`` at state ``x``.

        ``u`` and ``t`` are accepted so that the signature matches the
        other learnable functions; they are not used.
        """
        return self.weights @ x + self.bias
class LearnableLorenz63(NamedTuple):
    """Lorenz63 model with learnable parameters.

    sigma: sigma parameter
    rho: rho parameter
    beta: beta parameter

    With x_0, x_1, x_2 denoting x[0], x[1], x[2], f(x) has components
        sigma * (x_1 - x_0)
        x_0 * (rho - x_2) - x_1
        x_0 * x_1 - beta * x_2
    """

    sigma: Union[Float, ParameterProperties]
    rho: Union[Float, ParameterProperties]
    beta: Union[Float, ParameterProperties]

    def f(self, x, u=None, t=None):
        """Evaluate the Lorenz63 right-hand side at state ``x``.

        Only ``x[0]``, ``x[1]`` and ``x[2]`` are read; ``u`` and ``t`` are
        accepted for signature compatibility and are not used.
        """
        x0, x1, x2 = x[0], x[1], x[2]
        return jnp.array([
            self.sigma * (x1 - x0),
            x0 * (self.rho - x2) - x1,
            x0 * x1 - self.beta * x2,
        ])
# Continuous non-linear Gaussian dynamics
# TODO: function definitions within parameter classes break fit_sgd: where should they be placed?
class ParamsCDNLGSSMDynamics(NamedTuple):
    r"""Parameters of the state dynamics of a CDNLGSSM model.

    This model does not obey an SDE as in Särkkä's equation (3.151):
    the solution to (3.151) is not necessarily a Gaussian process
    (although there are cases where it is).
    We instead assume a zeroth-, first- or second-order approximation to the model.

    The resulting transition and emission distributions are

    $$p(z_1) = N(z_1 | m, S)$$
    $$p(z_t | z_{t-1}, u_t) = N(z_t | m_t, P_t)$$
    $$p(y_t | z_t) = N(y_t | h(z_t, u_t), R_t)$$

    If you have no inputs, the dynamics and emission functions do not need to take $u_t$ as an argument.

    The tuple doubles as a container for the ParameterProperties.

    :param drift: learnable drift function $f$ (the deterministic part of the state's right-hand side)
    :param diffusion_coefficient: learnable coefficient $L$ of the state's diffusion process
    :param diffusion_cov: learnable covariance $Q$ of the state noise process
    :param approx_order: order of the approximation to the dynamics (zeroth, first or second), stored as a float
    """

    # NOTE: an earlier, now disabled, design stored a plain callable
    # `drift_function` (state -> state, or state and input -> state) together
    # with array-valued `diffusion_coefficient` ("state_dim state_dim" or
    # "ntime state_dim state_dim") and `diffusion_cov` (additionally
    # "state_dim_triu"), each alternatively given as ParameterProperties.

    # These are all learnable functions to be initialized
    drift: LearnableFunction
    diffusion_coefficient: LearnableFunction
    diffusion_cov: LearnableFunction

    # Dynamics SDE approximation order, defined as a Float
    # (presumably 0., 1. or 2. -- confirm against the inference code).
    approx_order: Union[Float, ParameterProperties]
'''
# Continuous non-linear dynamic parameters
class ParamsCDNLSSMDynamics(NamedTuple):
r"""Parameters of the state dynamics of a CDNLGSSM model.
This model does obey the SDE as in Sarkaa's equation (3.151):
the solution to 3.151 is not necessarily a Gaussian Process
(note there are cases where that is indeed the case)
If you have no inputs, the dynamics and emission functions do not to take $u_t$ as an argument.
The tuple doubles as a container for the ParameterProperties.
:param drift_function: $f$
:param drift_parameters: parameters $\theta$ of the drift_function
:param diffusion_coefficient: $L$
:param diffusion_cov: $Q$
"""
# the deterministic drift $f$ of the nonlinear RHS of the state
drift_function: Union[FnStateToState, FnStateAndInputToState]
# TODO: How to define learnable parameters for dynamics drift function?
#drift_parameters: Union[Float[Array], ParameterProperties]
# the coefficient matrix L of the state's diffusion process
diffusion_coefficient: Union[Float[Array, "state_dim state_dim"], Float[Array, "ntime state_dim state_dim"], ParameterProperties]
# The covariance matrix Q of the state noise process
diffusion_cov: Union[Float[Array, "state_dim state_dim"], Float[Array, "ntime state_dim state_dim"], Float[Array, "state_dim_triu"], ParameterProperties]
'''
# Discrete non-linear emission parameters
# TODO: function definitions within parameter classes break fit_sgd: where should they be placed?
class ParamsCDNLGSSMEmissions(NamedTuple):
    r"""Parameters of the emission distribution of a CDNLGSSM model

    $$p(y_t \mid z_t) = \mathcal{N}(y_t \mid h(z_t, u_t), R_t)$$

    The tuple doubles as a container for the ParameterProperties.

    :param emission_function: learnable emission function $h$
    :param emission_cov: learnable covariance $R$ of the observation noise process
    """

    # These are all learnable functions to be initialized
    emission_function: LearnableFunction
    emission_cov: LearnableFunction

    # NOTE: an earlier, now disabled, design stored a plain callable
    # `emission_function` (state -> emission, or state and input -> emission)
    # and an array-valued `emission_cov` ("emission_dim emission_dim"),
    # alternatively given as ParameterProperties.
# CDNLGSSM parameters are different to CDLGSSM due to nonlinearities
class ParamsCDNLGSSM(NamedTuple):
    r"""Parameters of a continuous-discrete non-linear Gaussian SSM (CDNLGSSM).

    :param initial: initial distribution parameters
    :param dynamics: dynamics distribution parameters
    :param emissions: emission distribution parameters

    The assumed transition and emission distributions are

    $$p(z_1) = N(z_1 | m, S)$$
    $$p(z_t | z_{t-1}, u_t) = N(z_t | m_t, P_t)$$
    $$p(y_t | z_t) = N(y_t | h(z_t, u_t), R_t)$$
    """

    # The initial-state parameters are reused from the linear Gaussian SSM.
    initial: ParamsLGSSMInitial
    dynamics: ParamsCDNLGSSMDynamics
    emissions: ParamsCDNLGSSMEmissions
'''
# CDNLSSM parameters are different to CDNLGSSM due to non-gaussian transitions
class ParamsCDNLGSSM(NamedTuple):
r"""Parameters of a linear Gaussian SSM.
:param initial: initial distribution parameters
:param dynamics: dynamics distribution parameters
:param emissions: emission distribution parameters
The assumed transition and emission distributions are
$$p(z_1) = N(z_1 | m, S)$$
"""
initial: ParamsLGSSMInitial
dynamics: ParamsCDNLSSMDynamics
emissions: ParamsCDNLGSSMEmissions
'''
# TODO: Move this for linear Gaussian SSM
class GSSMForecast(NamedTuple):
    r"""Object definition used when forecasting.

    Every field is optional and defaults to ``None``, so a forecast can carry
    Gaussian moments, sampled paths, or both.

    # If we forecast Gaussian distributions, based on filtering methods
    :param forecasted_state_means: array of forecasted state means $\mathbb{E}[z_{t+1:t+t_f} \mid y_{1:t}, u_{1:t}, u_{t+1:t+t_f}]$
    :param forecasted_state_covariances: array of forecasted state covariances $\mathrm{Cov}[z_{t+1:t+t_f} \mid y_{1:t}, u_{1:t}, u_{t+1:t+t_f}]$
    :param forecasted_emission_means: array of forecasted emission means $\mathbb{E}[y_{t+1:t+t_f} \mid y_{1:t}, u_{1:t}, u_{t+1:t+t_f}]$
    :param forecasted_emission_covariances: array of forecasted emission covariances $\mathrm{Cov}[y_{t+1:t+t_f} \mid y_{1:t}, u_{1:t}, u_{t+1:t+t_f}]$

    # If we forecast paths, based on solving the SDE
    :param forecasted_state_path: array of forecasted state path $z_{t+1:t+t_f}$
    :param forecasted_emission_path: array of forecasted emission path $y_{t+1:t+t_f}$
    """

    # If we forecast Gaussian distributions, based on filtering methods
    forecasted_state_means: Optional[Float[Array, "ntime state_dim"]] = None
    forecasted_state_covariances: Optional[Float[Array, "ntime state_dim state_dim"]] = None
    forecasted_emission_means: Optional[Float[Array, "ntime emission_dim"]] = None
    forecasted_emission_covariances: Optional[Float[Array, "ntime emission_dim emission_dim"]] = None

    # If we forecast paths, based on solving the SDE
    forecasted_state_path: Optional[Float[Array, "ntime state_dim"]] = None
    forecasted_emission_path: Optional[Float[Array, "ntime emission_dim"]] = None