Vectorize ARProcess() #439

Merged · 29 commits · Sep 12, 2024

Commits
3c1f487
Fix docstring and error message typos in distributionalvariable.py
dylanhmorris Sep 4, 2024
3c5926c
Make IID random sequence have arbitrary shape
dylanhmorris Sep 4, 2024
b65b03b
Revert incorrect change
dylanhmorris Sep 4, 2024
3e1ef80
Add test for StandardNormalSequence with configurable shape
dylanhmorris Sep 4, 2024
c0375c3
Draw ar noise within scan
dylanhmorris Sep 10, 2024
618a0d2
Tweak docstrings and function call signature for AR
dylanhmorris Sep 10, 2024
4d42e5a
Multi-dim AR process, still with scalar sd
dylanhmorris Sep 11, 2024
005e383
jnp.flip for init values and appropriate axis dimensions
dylanhmorris Sep 11, 2024
ef16ea5
Update ar sample docstring
dylanhmorris Sep 11, 2024
aa48d6b
Merge branch 'main' into dhm-vectorize-ar
dylanhmorris Sep 11, 2024
0cf6406
tensordot ==> einsum and test improvements
dylanhmorris Sep 11, 2024
def1cd5
Fix test bug, run precommit
dylanhmorris Sep 11, 2024
92005d9
Add switches for order>1
dylanhmorris Sep 11, 2024
987ff0c
Remove switches, remove reshapes, nicer broadcasting
dylanhmorris Sep 12, 2024
fa93215
Style tweak
dylanhmorris Sep 12, 2024
7323e44
restore RNG context accidentally deleted in merge
dylanhmorris Sep 12, 2024
ff9d14e
Restore use of flatten for ks test
dylanhmorris Sep 12, 2024
7cf3e24
Better broadcasting, more tests
dylanhmorris Sep 12, 2024
430396e
Further generalize AR; breaks RtPeriodicDiff
dylanhmorris Sep 12, 2024
00f2dfc
Better failed broadcasting error for inits, and tests
dylanhmorris Sep 12, 2024
ffa597f
Fail to broadcast more clearly, with test that error is raised
dylanhmorris Sep 12, 2024
e63bc76
Squeezes as workaround for rtperiodicdiff ar while we wait for vector…
dylanhmorris Sep 12, 2024
3023a3d
Add clarifying comment on shape coercion
dylanhmorris Sep 12, 2024
f64c839
Remove 'last' output of scan
dylanhmorris Sep 12, 2024
57bd2bc
Update pyrenew/process/rtperiodicdiffar.py
dylanhmorris Sep 12, 2024
461e024
Remove shape assertions
dylanhmorris Sep 12, 2024
d72add5
Document broadcasting
dylanhmorris Sep 12, 2024
932d309
Fix backtick typo
dylanhmorris Sep 12, 2024
4c2561c
Add passing plate test
dylanhmorris Sep 12, 2024
177 changes: 125 additions & 52 deletions pyrenew/process/ar.py
@@ -1,13 +1,18 @@
# numpydoc ignore=GL08
"""
This file defines a RandomVariable subclass for
autoregressive (AR) processes
"""

from __future__ import annotations

import jax
import jax.numpy as jnp
import numpyro
from jax.typing import ArrayLike
from numpyro.contrib.control_flow import scan
from numpyro.infer.reparam import LocScaleReparam

from pyrenew.metaclass import RandomVariable
from pyrenew.process.iidrandomsequence import StandardNormalSequence


class ARProcess(RandomVariable):
@@ -16,32 +21,23 @@ class ARProcess(RandomVariable):
an AR(p) process.
"""

def __init__(self, noise_rv_name: str, *args, **kwargs) -> None:
"""
Default constructor.

Parameters
----------
noise_rv_name : str
A name for the internal RandomVariable
holding the process noise.
"""
super().__init__(*args, **kwargs)
self.noise_rv_ = StandardNormalSequence(element_rv_name=noise_rv_name)

def sample(
self,
noise_name: str,
n: int,
autoreg: ArrayLike,
init_vals: ArrayLike,
noise_sd: float | ArrayLike,
**kwargs,
) -> ArrayLike:
"""
Sample from the AR process

Parameters
----------
noise_name: str
A name for the sample site holding the
Normal(`0`, `noise_sd`) noise for the AR process.
Passed to :func:`numpyro.sample`.
n: int
Length of the sequence.
autoreg: ArrayLike
@@ -52,60 +48,137 @@ def sample(
init_vals : ArrayLike
Array of initial values. Must have the
same first dimension size as the order.
noise_sd : float | ArrayLike
Scalar giving the s.d. of the AR
noise_sd : ArrayLike
Standard deviation of the AR
process Normal noise, which by
definition has mean 0.
**kwargs : dict, optional
Additional keyword arguments passed to
self.noise_rv_.sample()

Returns
-------
ArrayLike
with first dimension of length `n`
and additional dimensions as inferred
from the shapes of `autoreg`,
`init_vals`, and `noise_sd`.

Notes
-----
The first dimension of the return value
will be of length `n` and represents time.
Trailing dimensions follow standard numpy
broadcasting rules and are determined from
the second and subsequent dimensions, if any,
of `autoreg` and `init_vals`, as well as
all dimensions of `noise_sd` (i.e.
:code:`jax.numpy.shape(autoreg)[1:]`,
:code:`jax.numpy.shape(init_vals)[1:]`,
and :code:`jax.numpy.shape(noise_sd)`).

Those shapes must be
broadcastable together via
:func:`jax.lax.broadcast_shapes`. This can
be used to produce multiple AR processes of the
same order but with either shared or different initial
values, AR coefficient vectors, and/or
noise standard deviation values.
"""
noise_sd_arr = jnp.atleast_1d(noise_sd)
if not noise_sd_arr.shape == (1,):
raise ValueError("noise_sd must be a scalar. " f"Got {noise_sd}")
autoreg = jnp.atleast_1d(autoreg)
init_vals = jnp.atleast_1d(init_vals)
noise_sd = jnp.array(noise_sd)
# noise_sd can be a scalar, but
# autoreg and init_vals must have a
# first dimension (time),
# as the order of the process is
# inferred from that first dimension

if not autoreg.ndim == 1:
raise ValueError(
"Array of autoregressive coefficients "
"must be no more than 1 dimension",
f"Got {autoreg.ndim}",
order = autoreg.shape[0]
n_inits = init_vals.shape[0]

try:
noise_shape = jax.lax.broadcast_shapes(
init_vals.shape[1:],
autoreg.shape[1:],
noise_sd.shape,
)
if not init_vals.ndim == 1:
except Exception as e:
raise ValueError(
"Array of initial values must be " "no more than 1 dimension",
f"Got {init_vals.ndim}",
)
order = autoreg.size
if not init_vals.size == order:
"Could not determine a "
"valid shape for the AR process noise "
"from the shapes of the init_vals, "
"autoreg, and noise_sd arrays. "
"See ARProcess.sample() documentation "
"for details."
) from e

if not n_inits == order:
raise ValueError(
"Array of initial values must be "
"be the same size as the order of "
"the autoregressive process, "
"which is determined by the number "
"of autoregressive coefficients "
"provided. Got {init_vals.size} "
"initial values for a process of "
f"order {order}"
"Initial values array must have the same "
"first dimension length as the order p of "
"the AR process. The order is given by "
"the first dimension length of the array "
"of autoregressive coefficients. Got an initial "
f"value array with first dimension {n_inits} for "
f"a process of order {order}"
)

raw_noise = self.noise_rv_(n=n, **kwargs)
noise = noise_sd_arr * raw_noise
history_shape = (order,) + noise_shape

try:
inits_broadcast = jnp.broadcast_to(init_vals, history_shape)
except Exception as e:
raise ValueError(
"Could not broadcast init_vals "
f"(shape {init_vals.shape}) "
"to the expected shape of the process "
f"history (shape {history_shape}). "
"History shape is determined by the "
"shapes of the init_vals, autoreg, and "
"noise_sd arrays. See ARProcess "
"documentation for details"
) from e

inits_flipped = jnp.flip(inits_broadcast, axis=0)

def transition(recent_vals, _): # numpydoc ignore=GL08
with numpyro.handlers.reparam(
config={noise_name: LocScaleReparam(0)}
):
next_noise = numpyro.sample(
noise_name,
numpyro.distributions.Normal(
loc=jnp.zeros(noise_shape), scale=noise_sd
),
)

dot_prod = jnp.einsum("i...,i...->...", autoreg, recent_vals)
new_term = dot_prod + next_noise
new_recent_vals = jnp.concatenate(
[
new_term[jnp.newaxis, ...],
# concatenate as (1 time unit,) + noise_shape
# array
recent_vals,
],
axis=0,
)[:order]

def transition(recent_vals, next_noise): # numpydoc ignore=GL08
new_term = jnp.dot(autoreg, recent_vals) + next_noise
new_recent_vals = jnp.hstack(
[new_term, recent_vals[: (order - 1)]]
)
return new_recent_vals, new_term

last, ts = scan(transition, init_vals, noise)
return jnp.hstack([init_vals, ts])
if n > order:
_, ts = scan(
f=transition,
init=inits_flipped,
xs=None,
length=(n - order),
)

ts_with_inits = jnp.concatenate(
[inits_broadcast, ts],
axis=0,
)
else:
ts_with_inits = inits_broadcast
return ts_with_inits[:n]

@staticmethod
def validate(): # numpydoc ignore=RT01
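For orientation, a minimal usage sketch of the vectorized process (not part of the diff): it draws a batch of three AR(2) processes that share coefficients and initial values but have different noise standard deviations, relying on the broadcasting rules documented above. The import path, the seed handler, the use of the RandomVariable call shorthand for sample(), and all numeric values are assumptions for illustration.

import jax.numpy as jnp
import numpyro

from pyrenew.process import ARProcess

ar = ARProcess()

with numpyro.handlers.seed(rng_seed=0):
    draws = ar(
        noise_name="ar_noise",
        n=50,
        autoreg=jnp.array([0.5, 0.25]),       # AR(2) coefficients, shared across the batch
        init_vals=jnp.zeros((2, 3)),          # (order, batch): shared initial values
        noise_sd=jnp.array([0.1, 0.5, 1.0]),  # one noise sd per batch element
    )

# Time is the leading axis; the trailing axis is the broadcast batch shape.
print(draws.shape)  # expected (50, 3)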
9 changes: 8 additions & 1 deletion pyrenew/process/iidrandomsequence.py
@@ -110,6 +110,7 @@ class StandardNormalSequence(IIDRandomSequence):
def __init__(
self,
element_rv_name: str,
element_shape: tuple = None,
**kwargs,
):
"""
@@ -124,13 +125,19 @@ def __init__(
DistributionalVariable encoding a
standard Normal (mean = 0, sd = 1)
distribution.
element_shape : tuple
Shape for each element in the sequence.
If None, elements are scalars. Default
None.

Returns
-------
None
"""
if element_shape is None:
element_shape = ()
super().__init__(
element_rv=DistributionalVariable(
name=element_rv_name, distribution=dist.Normal(0, 1)
),
).expand_by(element_shape)
)
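A sketch of the new element_shape argument (illustrative, not taken from the diff): the distributional element is expanded so that every draw in the IID sequence has the given shape. The import path and the n= keyword for the sequence length follow their use elsewhere in this PR; the seed value, site name, and shapes are assumptions.

import numpyro

from pyrenew.process import StandardNormalSequence

seq = StandardNormalSequence(element_rv_name="z", element_shape=(3,))

with numpyro.handlers.seed(rng_seed=42):
    draws = seq(n=10)

# Each of the 10 iid elements should be standard normal with shape (3,),
# so the full sequence should have shape (10, 3).
print(draws.shape)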
17 changes: 9 additions & 8 deletions pyrenew/process/rtperiodicdiffar.py
@@ -77,10 +77,10 @@ def __init__(
self.log_rt_rv = log_rt_rv
self.autoreg_rv = autoreg_rv
self.periodic_diff_sd_rv = periodic_diff_sd_rv
self.ar_process_suffix = ar_process_suffix

self.ar_diff = DifferencedProcess(
fundamental_process=ARProcess(
noise_rv_name=f"{name}{ar_process_suffix}"
),
fundamental_process=ARProcess(),
differencing_order=1,
)

@@ -138,22 +138,23 @@ def sample(
"""

# Initial sample
log_rt_rv = self.log_rt_rv.sample(**kwargs)
b = self.autoreg_rv.sample(**kwargs)
s_r = self.periodic_diff_sd_rv.sample(**kwargs)
log_rt_rv = self.log_rt_rv(**kwargs).squeeze()
b = self.autoreg_rv(**kwargs).squeeze()
s_r = self.periodic_diff_sd_rv(**kwargs).squeeze()

# How many periods to sample?
n_periods = (duration + self.period_size - 1) // self.period_size

# Running the process

log_rt = self.ar_diff(
noise_name=f"{self.name}{self.ar_process_suffix}",
n=n_periods,
init_vals=jnp.array([log_rt_rv[0]]),
init_vals=jnp.array(log_rt_rv[0]),
autoreg=b,
noise_sd=s_r,
fundamental_process_init_vals=jnp.array(
[log_rt_rv[1] - log_rt_rv[0]]
log_rt_rv[1] - log_rt_rv[0]
),
)

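The call pattern the updated sample() relies on, shown in isolation as a hedged sketch: a first-order DifferencedProcess driven by the now-vectorized ARProcess, with the AR noise site named at sample time rather than at construction time. The keyword arguments mirror the call in the diff above; import paths, the seed, the site name, and all numeric values are assumptions for illustration.

import jax.numpy as jnp
import numpyro

from pyrenew.process import ARProcess, DifferencedProcess

ar_diff = DifferencedProcess(
    fundamental_process=ARProcess(),
    differencing_order=1,
)

with numpyro.handlers.seed(rng_seed=0):
    log_rt = ar_diff(
        noise_name="Rt_periodic_diff_ar_process_noise",
        n=8,
        init_vals=jnp.array(0.1),                       # undifferenced value at t = 0
        autoreg=jnp.array([0.7]),                       # AR(1) on the first differences
        noise_sd=jnp.array(0.05),
        fundamental_process_init_vals=jnp.array(0.02),  # initial first difference
    )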
4 changes: 2 additions & 2 deletions pyrenew/randomvariable/distributionalvariable.py
@@ -130,7 +130,7 @@ def sample(
def expand_by(self, sample_shape) -> Self:
"""
Expand the distribution by a given
shape_shape, if possible. Returns a
sample_shape, if possible. Returns a
new DynamicDistributionalVariable whose underlying
distribution will be expanded by the given shape
at sample() time.
@@ -326,5 +326,5 @@ def DistributionalVariable(
"(for instantiating a static DistributionalVariable) "
"or a callable that returns a "
"numpyro.distributions.Distribution (for "
"a dynamic DistributionalVariable"
"a dynamic DistributionalVariable)."
)
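Finally, a short sketch of expand_by(), whose docstring typo is corrected above: it returns a distributional variable whose draws are expanded by the given sample shape. The construction mirrors the StandardNormalSequence change earlier in this PR; the import path, variable name, and seed are illustrative assumptions.

import numpyro
import numpyro.distributions as dist

from pyrenew.randomvariable import DistributionalVariable

z = DistributionalVariable(
    name="z", distribution=dist.Normal(0, 1)
).expand_by((3,))

with numpyro.handlers.seed(rng_seed=0):
    draw = z()

# Each draw should now have shape (3,) rather than being a scalar.
print(draw.shape)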