Skip to content

Commit

Permalink
Multi-dim AR process, still with scalar sd
Browse files Browse the repository at this point in the history
  • Loading branch information
dylanhmorris committed Sep 11, 2024
1 parent 618a0d2 commit 4d42e5a
Show file tree
Hide file tree
Showing 2 changed files with 91 additions and 79 deletions.
57 changes: 30 additions & 27 deletions pyrenew/process/ar.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from __future__ import annotations

import jax
import jax.numpy as jnp
import numpyro
from jax.typing import ArrayLike
Expand Down Expand Up @@ -40,7 +41,7 @@ def sample(
init_vals : ArrayLike
Array of initial values. Must have the
same first dimension size as the order.
noise_sd : float | ArrayLike
noise_sd : ArrayLike
Standard deviation of the AR
process Normal noise, which by
definition has mean 0.
Expand All @@ -55,30 +56,23 @@ def sample(
if not noise_sd_arr.shape == (1,):
raise ValueError("noise_sd must be a scalar. " f"Got {noise_sd}")
autoreg = jnp.atleast_1d(autoreg)
noise_sd = jnp.atleast_1d(noise_sd)
init_vals = jnp.atleast_1d(init_vals)
order = autoreg.shape[0]

if not autoreg.ndim == 1:
raise ValueError(
"Array of autoregressive coefficients "
"must be no more than 1 dimension",
f"Got {autoreg.ndim}",
)
if not init_vals.ndim == 1:
raise ValueError(
"Array of initial values must be " "no more than 1 dimension",
f"Got {init_vals.ndim}",
)
order = autoreg.size
if not init_vals.size == order:
noise_shape = jax.lax.broadcast_shapes(
autoreg.shape[1:], noise_sd.shape
)

if not init_vals.shape == autoreg.shape:
raise ValueError(
"Array of initial values must be "
"be the same size as the order of "
"the autoregressive process, "
"which is determined by the number "
"of autoregressive coefficients "
"provided. Got {init_vals.size} "
"initial values for a process of "
f"order {order}"
"Initial values array and autoregressive "
"coefficient array must be of the same shape ",
"and must have a first dimension that represents "
"the order of the AR process. Got a shape of "
"{init_vals.shape} for the initial values and "
"a shape of {autoreg.shape} for the autoregressive "
"coefficients",
)

def transition(recent_vals, _): # numpydoc ignore=GL08
Expand All @@ -87,16 +81,25 @@ def transition(recent_vals, _): # numpydoc ignore=GL08
):
next_noise = numpyro.sample(
noise_name,
numpyro.distributions.Normal(loc=0, scale=noise_sd_arr),
numpyro.distributions.Normal(
loc=jnp.zeros(noise_shape), scale=noise_sd
),
)

new_term = jnp.dot(autoreg, recent_vals) + next_noise
new_recent_vals = jnp.concatenate(
[new_term, recent_vals[..., : (order - 1)]]
new_term = (
jnp.tensordot(autoreg, recent_vals, axes=[0, 0]) + next_noise
)
new_recent_vals = jnp.vstack(
[new_term, recent_vals[: (order - 1), ...]]
)
return new_recent_vals, new_term

last, ts = scan(f=transition, init=init_vals, xs=None, length=n)
last, ts = scan(
f=transition,
init=init_vals[..., jnp.newaxis],
xs=None,
length=(n - order),
)
return (
SampledValue(
jnp.squeeze(
Expand Down
113 changes: 61 additions & 52 deletions test/test_ar_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,89 +11,98 @@
def test_ar_can_be_sampled():
"""
Check that an AR process
can be initialized and sampled from
can be initialized and sampled from,
and that output shapes are as expected.
"""
ar1 = ARProcess()
ar = ARProcess()
with numpyro.handlers.seed(rng_seed=62):
# can sample
ar1(
ar(
noise_name="ar1process_noise",
n=3532,
init_vals=jnp.array([50.0]),
autoreg=jnp.array([0.95]),
noise_sd=0.5,
)

ar3 = ARProcess()

with numpyro.handlers.seed(rng_seed=62):
# can sample
ar3(
res1, *_ = ar(
noise_name="ar3process_noise",
n=1230,
init_vals=jnp.array([50.0, 49.9, 48.2]),
autoreg=jnp.array([0.05, 0.025, 0.025]),
noise_sd=0.5,
)
ar3(
res2, *_ = ar(
noise_name="ar3process_noise",
n=1230,
init_vals=jnp.array([50.0, 49.9, 48.2]),
autoreg=jnp.array([0.05, 0.025, 0.025]),
noise_sd=[0.25],
)
ar3(
res3, *_ = ar(
noise_name="ar3process_noise",
n=1230,
init_vals=jnp.array([50.0, 49.9, 48.2]),
autoreg=jnp.array([0.05, 0.025, 0.025]),
noise_sd=jnp.array([0.25]),
)

# vector valued noise raises
# error
with pytest.raises(ValueError, match="must be a scalar"):
ar3(
noise_name="ar3process_noise",
n=1230,
init_vals=jnp.array([50.0, 49.9, 48.2]),
autoreg=jnp.array([0.05, 0.025, 0.025]),
noise_sd=jnp.array([1.0, 2.0]),
)
with pytest.raises(ValueError, match="must be a scalar"):
ar3(
noise_name="ar3process_noise",
n=1230,
init_vals=jnp.array([50.0, 49.9, 48.2]),
autoreg=jnp.array([0.05, 0.025, 0.025]),
noise_sd=[1.0, 2.0],
)
assert jnp.shape(res1.value) == jnp.shape(res2.value)
assert jnp.shape(res2.value) == jnp.shape(res3.value)
assert jnp.shape(res3.value) == (1230,)

# bad dimensionality raises error
with pytest.raises(ValueError, match="Array of autoregressive"):
ar3(
noise_name="ar3process_noise",
n=1230,
init_vals=jnp.array([50.0, 49.9, 48.2]),
autoreg=jnp.array([[0.05, 0.025, 0.025]]),
noise_sd=0.5,
)
with pytest.raises(ValueError, match="Array of initial"):
ar3(
noise_name="ar3process_noise",
n=1230,
init_vals=jnp.array([[50.0, 49.9, 48.2]]),
autoreg=jnp.array([0.05, 0.025, 0.025]),
noise_sd=0.5,
)
with pytest.raises(ValueError, match="same size as the order"):
ar3(
noise_name="ar3process_noise",
n=1230,
init_vals=jnp.array([50.0, 49.9, 1, 1, 1]),
autoreg=jnp.array([0.05, 0.025, 0.025]),
noise_sd=0.5,
)

def test_ar_shape_validation():
"""
Test that AR process sample() method validates
the shapes of its inputs as expected.
"""
# vector valued noise raises
# error
ar = ARProcess()

with pytest.raises(ValueError, match="must be a scalar"):
ar(
noise_name="ar3process_noise",
n=1230,
init_vals=jnp.array([50.0, 49.9, 48.2]),
autoreg=jnp.array([0.05, 0.025, 0.025]),
noise_sd=jnp.array([1.0, 2.0]),
)
with pytest.raises(ValueError, match="must be a scalar"):
ar(
noise_name="ar3process_noise",
n=1230,
init_vals=jnp.array([50.0, 49.9, 48.2]),
autoreg=jnp.array([0.05, 0.025, 0.025]),
noise_sd=[1.0, 2.0],
)
# bad dimensionality raises error
with pytest.raises(ValueError, match="Initial values array"):
ar(
noise_name="ar3process_noise",
n=1230,
init_vals=jnp.array([50.0, 49.9, 48.2]),
autoreg=jnp.array([[0.05, 0.025, 0.025]]),
noise_sd=0.5,
)
with pytest.raises(ValueError, match="Initial values array"):
ar(
noise_name="ar3process_noise",
n=1230,
init_vals=jnp.array([[50.0, 49.9, 48.2]]),
autoreg=jnp.array([0.05, 0.025, 0.025]),
noise_sd=0.5,
)
with pytest.raises(ValueError, match="Initial values array"):
ar(
noise_name="ar3process_noise",
n=1230,
init_vals=jnp.array([50.0, 49.9, 1, 1, 1]),
autoreg=jnp.array([0.05, 0.025, 0.025]),
noise_sd=0.5,
)


def test_ar_samples_correctly_distributed():
Expand Down

0 comments on commit 4d42e5a

Please sign in to comment.