From 3d06ab9aa285024b61f62046b996d12c8d764155 Mon Sep 17 00:00:00 2001 From: "Dylan H. Morris" Date: Thu, 12 Sep 2024 16:31:01 -0400 Subject: [PATCH] Vectorize `ARProcess()` (#439) --- pyrenew/process/ar.py | 177 +++++++--- pyrenew/process/iidrandomsequence.py | 9 +- pyrenew/process/rtperiodicdiffar.py | 17 +- .../randomvariable/distributionalvariable.py | 4 +- test/test_ar_process.py | 304 +++++++++++++----- test/test_differenced_process.py | 2 +- test/test_iid_random_sequence.py | 36 ++- test/test_scan_rv_plate_compatibility.py | 44 +++ 8 files changed, 446 insertions(+), 147 deletions(-) create mode 100644 test/test_scan_rv_plate_compatibility.py diff --git a/pyrenew/process/ar.py b/pyrenew/process/ar.py index c7473395..ec8b88bb 100644 --- a/pyrenew/process/ar.py +++ b/pyrenew/process/ar.py @@ -1,13 +1,18 @@ -# numpydoc ignore=GL08 +""" +This file defines a RandomVariable subclass for +autoregressive (AR) processes +""" from __future__ import annotations +import jax import jax.numpy as jnp +import numpyro from jax.typing import ArrayLike from numpyro.contrib.control_flow import scan +from numpyro.infer.reparam import LocScaleReparam from pyrenew.metaclass import RandomVariable -from pyrenew.process.iidrandomsequence import StandardNormalSequence class ARProcess(RandomVariable): @@ -16,32 +21,23 @@ class ARProcess(RandomVariable): an AR(p) process. """ - def __init__(self, noise_rv_name: str, *args, **kwargs) -> None: - """ - Default constructor. - - Parameters - ---------- - noise_rv_name : str - A name for the internal RandomVariable - holding the process noise. - """ - super().__init__(*args, **kwargs) - self.noise_rv_ = StandardNormalSequence(element_rv_name=noise_rv_name) - def sample( self, + noise_name: str, n: int, autoreg: ArrayLike, init_vals: ArrayLike, noise_sd: float | ArrayLike, - **kwargs, ) -> ArrayLike: """ Sample from the AR process Parameters ---------- + noise_name: str + A name for the sample site holding the + Normal(`0`, `noise_sd`) noise for the AR process. + Passed to :func:`numpyro.sample`. n: int Length of the sequence. autoreg: ArrayLike @@ -52,60 +48,137 @@ def sample( init_vals : ArrayLike Array of initial values. Must have the same first dimension size as the order. - noise_sd : float | ArrayLike - Scalar giving the s.d. of the AR + noise_sd : ArrayLike + Standard deviation of the AR process Normal noise, which by definition has mean 0. - **kwargs : dict, optional - Additional keyword arguments passed to - self.noise_rv_.sample() Returns ------- ArrayLike + with first dimension of length `n` + and additional dimensions as inferred + from the shapes of `autoreg`, + `init_vals`, and `noise_sd`. + + Notes + ----- + The first dimension of the return value + with be of length `n` and represents time. + Trailing dimensions follow standard numpy + broadcasting rules and are determined from + the second through `n` th dimensions, if any, + of `autoreg` and `init_vals`, as well as the + all dimensions of `noise_sd` (i.e. + :code:`jax.numpy.shape(autoreg)[1:]`, + :code:`jax.numpy.shape(init_vals)[1:]` + and :code:`jax.numpy.shape(noise_sd)` + + Those shapes must be + broadcastable together via + :func:`jax.lax.broadcast_shapes`. This can + be used to produce multiple AR processes of the + same order but with either shared or different initial + values, AR coefficient vectors, and/or + and noise standard deviation values. """ - noise_sd_arr = jnp.atleast_1d(noise_sd) - if not noise_sd_arr.shape == (1,): - raise ValueError("noise_sd must be a scalar. 
" f"Got {noise_sd}") autoreg = jnp.atleast_1d(autoreg) init_vals = jnp.atleast_1d(init_vals) + noise_sd = jnp.array(noise_sd) + # noise_sd can be a scalar, but + # autoreg and init_vals must have a + # a first dimension (time), + # as the order of the process is + # inferred from that first dimension - if not autoreg.ndim == 1: - raise ValueError( - "Array of autoregressive coefficients " - "must be no more than 1 dimension", - f"Got {autoreg.ndim}", + order = autoreg.shape[0] + n_inits = init_vals.shape[0] + + try: + noise_shape = jax.lax.broadcast_shapes( + init_vals.shape[1:], + autoreg.shape[1:], + noise_sd.shape, ) - if not init_vals.ndim == 1: + except Exception as e: raise ValueError( - "Array of initial values must be " "no more than 1 dimension", - f"Got {init_vals.ndim}", - ) - order = autoreg.size - if not init_vals.size == order: + "Could not determine a " + "valid shape for the AR process noise " + "from the shapes of the init_vals, " + "autoreg, and noise_sd arrays. " + "See ARProcess.sample() documentation " + "for details." + ) from e + + if not n_inits == order: raise ValueError( - "Array of initial values must be " - "be the same size as the order of " - "the autoregressive process, " - "which is determined by the number " - "of autoregressive coefficients " - "provided. Got {init_vals.size} " - "initial values for a process of " - f"order {order}" + "Initial values array must have the same " + "first dimension length as the order p of " + "the AR process. The order is given by " + "the first dimension length of the array " + "of autoregressive coefficients. Got an initial " + f"value array with first dimension {n_inits} for " + f"a process of order {order}" ) - raw_noise = self.noise_rv_(n=n, **kwargs) - noise = noise_sd_arr * raw_noise + history_shape = (order,) + noise_shape + + try: + inits_broadcast = jnp.broadcast_to(init_vals, history_shape) + except Exception as e: + raise ValueError( + "Could not broadcast init_vals " + f"(shape {init_vals.shape}) " + "to the expected shape of the process " + f"history (shape {history_shape}). " + "History shape is determined by the " + "shapes of the init_vals, autoreg, and " + "noise_sd arrays. 
See ARProcess " + "documentation for details" + ) from e + + inits_flipped = jnp.flip(inits_broadcast, axis=0) + + def transition(recent_vals, _): # numpydoc ignore=GL08 + with numpyro.handlers.reparam( + config={noise_name: LocScaleReparam(0)} + ): + next_noise = numpyro.sample( + noise_name, + numpyro.distributions.Normal( + loc=jnp.zeros(noise_shape), scale=noise_sd + ), + ) + + dot_prod = jnp.einsum("i...,i...->...", autoreg, recent_vals) + new_term = dot_prod + next_noise + new_recent_vals = jnp.concatenate( + [ + new_term[jnp.newaxis, ...], + # concatenate as (1 time unit,) + noise_shape + # array + recent_vals, + ], + axis=0, + )[:order] - def transition(recent_vals, next_noise): # numpydoc ignore=GL08 - new_term = jnp.dot(autoreg, recent_vals) + next_noise - new_recent_vals = jnp.hstack( - [new_term, recent_vals[: (order - 1)]] - ) return new_recent_vals, new_term - last, ts = scan(transition, init_vals, noise) - return jnp.hstack([init_vals, ts]) + if n > order: + _, ts = scan( + f=transition, + init=inits_flipped, + xs=None, + length=(n - order), + ) + + ts_with_inits = jnp.concatenate( + [inits_broadcast, ts], + axis=0, + ) + else: + ts_with_inits = inits_broadcast + return ts_with_inits[:n] @staticmethod def validate(): # numpydoc ignore=RT01 diff --git a/pyrenew/process/iidrandomsequence.py b/pyrenew/process/iidrandomsequence.py index cdc93fae..34df097a 100644 --- a/pyrenew/process/iidrandomsequence.py +++ b/pyrenew/process/iidrandomsequence.py @@ -110,6 +110,7 @@ class StandardNormalSequence(IIDRandomSequence): def __init__( self, element_rv_name: str, + element_shape: tuple = None, **kwargs, ): """ @@ -124,13 +125,19 @@ def __init__( DistributionalVariable encoding a standard Normal (mean = 0, sd = 1) distribution. + element_shape : tuple + Shape for each element in the sequence. + If None, elements are scalars. Default + None. Returns ------- None """ + if element_shape is None: + element_shape = () super().__init__( element_rv=DistributionalVariable( name=element_rv_name, distribution=dist.Normal(0, 1) - ), + ).expand_by(element_shape) ) diff --git a/pyrenew/process/rtperiodicdiffar.py b/pyrenew/process/rtperiodicdiffar.py index 25fb1d9a..8a6fc6f0 100644 --- a/pyrenew/process/rtperiodicdiffar.py +++ b/pyrenew/process/rtperiodicdiffar.py @@ -77,10 +77,10 @@ def __init__( self.log_rt_rv = log_rt_rv self.autoreg_rv = autoreg_rv self.periodic_diff_sd_rv = periodic_diff_sd_rv + self.ar_process_suffix = ar_process_suffix + self.ar_diff = DifferencedProcess( - fundamental_process=ARProcess( - noise_rv_name=f"{name}{ar_process_suffix}" - ), + fundamental_process=ARProcess(), differencing_order=1, ) @@ -138,9 +138,9 @@ def sample( """ # Initial sample - log_rt_rv = self.log_rt_rv.sample(**kwargs) - b = self.autoreg_rv.sample(**kwargs) - s_r = self.periodic_diff_sd_rv.sample(**kwargs) + log_rt_rv = self.log_rt_rv(**kwargs).squeeze() + b = self.autoreg_rv(**kwargs).squeeze() + s_r = self.periodic_diff_sd_rv(**kwargs).squeeze() # How many periods to sample? 
n_periods = (duration + self.period_size - 1) // self.period_size @@ -148,12 +148,13 @@ def sample( # Running the process log_rt = self.ar_diff( + noise_name=f"{self.name}{self.ar_process_suffix}", n=n_periods, - init_vals=jnp.array([log_rt_rv[0]]), + init_vals=jnp.array(log_rt_rv[0]), autoreg=b, noise_sd=s_r, fundamental_process_init_vals=jnp.array( - [log_rt_rv[1] - log_rt_rv[0]] + log_rt_rv[1] - log_rt_rv[0] ), ) diff --git a/pyrenew/randomvariable/distributionalvariable.py b/pyrenew/randomvariable/distributionalvariable.py index 20ec94a2..b612aa93 100644 --- a/pyrenew/randomvariable/distributionalvariable.py +++ b/pyrenew/randomvariable/distributionalvariable.py @@ -130,7 +130,7 @@ def sample( def expand_by(self, sample_shape) -> Self: """ Expand the distribution by a given - shape_shape, if possible. Returns a + sample_shape, if possible. Returns a new DynamicDistributionalVariable whose underlying distribution will be expanded by the given shape at sample() time. @@ -326,5 +326,5 @@ def DistributionalVariable( "(for instantiating a static DistributionalVariable) " "or a callable that returns a " "numpyro.distributions.Distribution (for " - "a dynamic DistributionalVariable" + "a dynamic DistributionalVariable)." ) diff --git a/test/test_ar_process.py b/test/test_ar_process.py index b16ebbc4..45ec4c0e 100755 --- a/test/test_ar_process.py +++ b/test/test_ar_process.py @@ -3,106 +3,256 @@ import jax.numpy as jnp import numpyro import pytest -from numpy.testing import assert_almost_equal +from numpy.testing import assert_array_almost_equal from pyrenew.process import ARProcess -def test_ar_can_be_sampled(): +@pytest.mark.parametrize( + ["init_vals", "autoreg", "noise_sd", "n"], + [ + # AR1, 1D + [jnp.array([50.0]), jnp.array([0.95]), 0.5, 1353], + # AR1, 1D, length 1 + [jnp.array([50.0]), jnp.array([0.95]), 0.5, 1], + # AR1, multi-dim + [ + jnp.array([[43.1, -32.5, 3.2, -0.5]]).reshape((1, 4)), + jnp.array([[0.50, 0.205, 0.232, 0.25]]).reshape((1, 4)), + jnp.array([0.73]), + 5322, + ], + # AR3, one dim + [ + jnp.array([43.1, -32.5, 0.52]), + jnp.array([0.50, 0.205, 0.25]), + jnp.array(0.802), + 6432, + ], + # AR3, one dim but unsqueezed + [ + jnp.array([[43.1, -32.5, 0.52]]).reshape((3, -1)), + jnp.array([[0.50, 0.205, 0.25]]).reshape((3, -1)), + 0.802, + 6432, + ], + # AR3, two sets of inits + # one set of AR coefficients + [ + jnp.array( + [ + [43.1, -32.5], + [0.52, 50.35], + [40.0, 0.3], + ] + ), + jnp.array([0.50, 0.205, 0.25]), + 0.802, + 533, + ], + # AR3, one set of inits and two + # sets of coefficients + [ + jnp.array([50.0, 49.9, 48.2]).reshape((3, -1)), + jnp.array([[0.05, 0.025], [0.25, 0.25], [0.1, 0.1]]), + 0.5, + 1230, + ], + # AR3, twos set of (identical) inits, two + # sets of coefficients, two s.ds + [ + jnp.array( + [ + [50.0, 49.9, 48.2], + [50.0, 49.9, 48.2], + ] + ).reshape((3, -1)), + jnp.array([[0.05, 0.025], [0.25, 0.25], [0.1, 0.1]]), + jnp.array([1, 0.25]), + 1230, + ], + # AR3, twos set of (identical) inits, two + # sets of coefficients, two s.ds, + # n shorter than the order + [ + jnp.array( + [ + [50.0, 49.9, 48.2], + [50.0, 49.9, 48.2], + ] + ).reshape((3, -1)), + jnp.array([[0.05, 0.025], [0.25, 0.25], [0.1, 0.1]]), + jnp.array([1, 0.25]), + 1, + ], + ], +) +def test_ar_can_be_sampled(init_vals, autoreg, noise_sd, n): """ Check that an AR process - can be initialized and sampled from + can be initialized and sampled from, + and that output shapes are as expected. 
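    The first `order` entries of the output
    should equal the initial values, broadcast
    to the shape of the process history.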
""" - ar1 = ARProcess(noise_rv_name="ar1process_noise") + ar = ARProcess() with numpyro.handlers.seed(rng_seed=62): # can sample - ar1( - n=3532, - init_vals=jnp.array([50.0]), - autoreg=jnp.array([0.95]), - noise_sd=0.5, - ) - - ar3 = ARProcess(noise_rv_name="ar3process_noise") - with numpyro.handlers.seed(rng_seed=62): - # can sample - ar3( - n=1230, - init_vals=jnp.array([50.0, 49.9, 48.2]), - autoreg=jnp.array([0.05, 0.025, 0.025]), - noise_sd=0.5, - ) - ar3( - n=1230, - init_vals=jnp.array([50.0, 49.9, 48.2]), - autoreg=jnp.array([0.05, 0.025, 0.025]), - noise_sd=[0.25], + res = ar( + noise_name="ar3process_noise", + n=n, + init_vals=init_vals, + autoreg=autoreg, + noise_sd=noise_sd, ) - ar3( - n=1230, - init_vals=jnp.array([50.0, 49.9, 48.2]), - autoreg=jnp.array([0.05, 0.025, 0.025]), - noise_sd=jnp.array([0.25]), + order = jnp.shape(autoreg)[0] + non_time_dims = jnp.broadcast_shapes( + jnp.atleast_1d(autoreg).shape[1:], + jnp.atleast_1d(init_vals).shape[1:], + jnp.shape(noise_sd), ) - # vector valued noise raises - # error - with pytest.raises(ValueError, match="must be a scalar"): - ar3( - n=1230, - init_vals=jnp.array([50.0, 49.9, 48.2]), - autoreg=jnp.array([0.05, 0.025, 0.025]), - noise_sd=jnp.array([1.0, 2.0]), - ) - with pytest.raises(ValueError, match="must be a scalar"): - ar3( - n=1230, - init_vals=jnp.array([50.0, 49.9, 48.2]), - autoreg=jnp.array([0.05, 0.025, 0.025]), - noise_sd=[1.0, 2.0], - ) + expected_shape = (n,) + non_time_dims + first_entries_broadcast_shape = (order,) + non_time_dims - # bad dimensionality raises error - with pytest.raises(ValueError, match="Array of autoregressive"): - ar3( - n=1230, - init_vals=jnp.array([50.0, 49.9, 48.2]), - autoreg=jnp.array([[0.05, 0.025, 0.025]]), - noise_sd=0.5, - ) - with pytest.raises(ValueError, match="Array of initial"): - ar3( - n=1230, - init_vals=jnp.array([[50.0, 49.9, 48.2]]), - autoreg=jnp.array([0.05, 0.025, 0.025]), - noise_sd=0.5, - ) - with pytest.raises(ValueError, match="same size as the order"): - ar3( - n=1230, - init_vals=jnp.array([50.0, 49.9, 1, 1, 1]), - autoreg=jnp.array([0.05, 0.025, 0.025]), - noise_sd=0.5, + expected_first_entries = jnp.broadcast_to( + init_vals, first_entries_broadcast_shape + )[:n] + + assert jnp.shape(res) == expected_shape + assert_array_almost_equal(res[:order, ...], expected_first_entries) + + +@pytest.mark.parametrize( + ["init_vals", "autoreg", "noise_sd", "n", "error_match"], + [ + # autoreg higher dim than init vals + # and not reshaped appropriately + [ + jnp.array([50.0, 49.9, 48.2]), + jnp.array([[0.05, 0.025, 0.025]]), + 0.5, + 1230, + "Initial values array", + ], + # initial vals higher dim than autoreg + # and not reshaped appropriately + [ + jnp.array([[50.0, 49.9, 48.2]]), + jnp.array([0.05, 0.025, 0.025]), + 0.5, + 1230, + "Initial values array", + ], + # not enough initial values + [ + jnp.array([50.0, 49.9, 48.2]), + jnp.array([0.05, 0.025, 0.025, 0.25]), + 0.5, + 1230, + "Initial values array", + ], + # too many initial values + [ + jnp.array([50.0, 49.9, 48.2, 0.035, 0.523]), + jnp.array([0.05, 0.025, 0.025, 0.25]), + 0.5, + 1230, + "Initial values array", + ], + # unbroadcastable shapes + [ + jnp.array([[50.0, 49.9], [48.2, 0.035]]), + jnp.array([[0.05, 0.025], [0.025, 0.25]]), + jnp.array([0.5, 0.25, 0.3]), + 1230, + "Could not determine a valid shape", + ], + [ + jnp.array([50.0, 49.9]), + jnp.array([0.05, 0.025]), + jnp.array([0.5]), + 1230, + "Could not broadcast init_vals", + ], + # unbroadcastable shapes: + # sd versus AR mismatch + [ + 
jnp.array([50.0, 49.9, 0.25]), + jnp.array([[0.05, 0.025], [0.025, 0.25], [0.01, 0.1]]), + jnp.array([0.25, 0.25]), + 1230, + "Could not broadcast init_vals", + ], + ], +) +def test_ar_shape_validation(init_vals, autoreg, noise_sd, n, error_match): + """ + Test that AR process sample() method validates + the shapes of its inputs as expected. + """ + # vector valued noise raises + # error + ar = ARProcess() + + # bad dimensionality raises error + with pytest.raises(ValueError, match=error_match): + with numpyro.handlers.seed(rng_seed=5): + ar( + noise_name="test_ar_noise", + n=n, + init_vals=init_vals, + autoreg=autoreg, + noise_sd=noise_sd, ) -def test_ar_samples_correctly_distributed(): +@pytest.mark.parametrize( + ["ar_inits", "autoreg", "noise_sd", "n"], + [ + [ + jnp.array([25.0]), + jnp.array([0.75]), + jnp.array([0.5]), + 10000, + ], + [ + jnp.array([-500, -499.0]), + jnp.array([0.5, 0.45]), + jnp.array(1.25), + 10001, + ], + ], +) +def test_ar_process_asymptotics(ar_inits, autoreg, noise_sd, n): """ - Check that AR processes have correctly- - distributed steps. + Check that AR processes can + start away from the stationary + distribution and converge to it. """ - noise_sd = jnp.array([0.5]) - ar_inits = jnp.array([25.0]) - ar = ARProcess("arprocess") + ar = ARProcess() + order = jnp.shape(ar_inits)[0] + non_time_dims = jnp.broadcast_shapes( + jnp.atleast_1d(autoreg).shape[1:], + jnp.atleast_1d(ar_inits).shape[1:], + jnp.shape(noise_sd), + ) + + first_entries_broadcast_shape = (order,) + non_time_dims + + expected_first_entries = jnp.broadcast_to( + ar_inits, first_entries_broadcast_shape + )[:n] + with numpyro.handlers.seed(rng_seed=62): # check it regresses to mean # when started away from it long_ts = ar( - n=10000, + noise_name="arprocess_noise", + n=n, init_vals=ar_inits, - autoreg=jnp.array([0.75]), + autoreg=autoreg, noise_sd=noise_sd, ) - assert_almost_equal(long_ts[0], ar_inits) - assert jnp.abs(long_ts[-1]) < 4 * noise_sd + assert_array_almost_equal(long_ts[:order], expected_first_entries) + + assert jnp.abs(long_ts[-1]) < 3 * noise_sd diff --git a/test/test_differenced_process.py b/test/test_differenced_process.py index 07ad1854..90620ba5 100644 --- a/test/test_differenced_process.py +++ b/test/test_differenced_process.py @@ -122,7 +122,7 @@ def test_integrator_correctness(order, n_diffs): ) result_proc1 = proc.integrate(inits, diffs) assert result_proc1.shape == (n_diffs + order,) - assert_array_almost_equal(result_manual, result_proc1, decimal=5) + assert_array_almost_equal(result_manual, result_proc1, decimal=4) assert result_proc1[0] == inits[0] diff --git a/test/test_iid_random_sequence.py b/test/test_iid_random_sequence.py index 9630a642..07a8e7d1 100755 --- a/test/test_iid_random_sequence.py +++ b/test/test_iid_random_sequence.py @@ -54,24 +54,48 @@ def test_iidrandomsequence_with_dist_rv(distribution, n): assert kstest_out.pvalue > 0.01 -def test_standard_normal_sequence(): +@pytest.mark.parametrize( + ["shape", "n"], + [[None, 352], [(), 72352], [(5,), 5432], [(3, 23, 2), 10352]], +) +def test_standard_normal_sequence(shape, n): """ Test the StandardNormalSequence RandomVariable class. 
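    Elements may be scalar or array-valued;
    samples should have shape (n,) + element_shape.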
""" - norm_seq = StandardNormalSequence("test_norm_elements") + norm_seq = StandardNormalSequence( + "test_norm_elements", element_shape=shape + ) # should be implemented with a DistributionalVariable # that is a standard normal assert isinstance(norm_seq.element_rv, StaticDistributionalVariable) - assert isinstance(norm_seq.element_rv.distribution, dist.Normal) - assert norm_seq.element_rv.distribution.loc == 0.0 - assert norm_seq.element_rv.distribution.scale == 1.0 + if shape is None or shape == (): + assert isinstance(norm_seq.element_rv.distribution, dist.Normal) + el_dist = norm_seq.element_rv.distribution + else: + assert isinstance( + norm_seq.element_rv.distribution, dist.ExpandedDistribution + ) + assert isinstance( + norm_seq.element_rv.distribution.base_dist, dist.Normal + ) + el_dist = norm_seq.element_rv.distribution.base_dist + assert el_dist.loc == 0.0 + assert el_dist.scale == 1.0 # should be sampleable with numpyro.handlers.seed(rng_seed=67): + ans = norm_seq(n=n) + + # samples should have shape (n,) + the element_rv sample shape + expected_sample_shape = (n,) + shape if shape is not None else (n,) + assert jnp.shape(ans) == expected_sample_shape + + with numpyro.handlers.seed(rng_seed=35): ans = norm_seq.sample(n=50000) # samples should be approximately standard normal - kstest_out = kstest(ans, "norm", (0, 1)) + kstest_out = kstest(ans.flatten(), "norm", (0, 1)) + assert kstest_out.pvalue > 0.01 diff --git a/test/test_scan_rv_plate_compatibility.py b/test/test_scan_rv_plate_compatibility.py new file mode 100644 index 00000000..4b60b4bf --- /dev/null +++ b/test/test_scan_rv_plate_compatibility.py @@ -0,0 +1,44 @@ +""" +Test that key :class:`RandomVariable` +classes behave as expected in a +:func:`numpyro.plate` context. +""" +import jax.numpy as jnp +import numpyro +import numpyro.distributions as dist +import pytest + +from pyrenew.process import ARProcess +from pyrenew.randomvariable import DistributionalVariable + + +@pytest.mark.parametrize( + ["random_variable", "constructor_args", "sample_args"], + [ + [ + ARProcess, + dict(), + dict( + noise_name="ar_noise", + n=100, + autoreg=jnp.array([0.25, 0.1]), + init_vals=jnp.array([15.0, 50.2]), + noise_sd=jnp.array([0.5, 1.5]), + ), + ] + ], +) +def test_single_plate_sampling(random_variable, constructor_args, sample_args): + """ + Test that the output of vectorized + scans can be sent into plate contexts + successfully + """ + with numpyro.handlers.seed(rng_seed=5): + scanned_rv = random_variable(**constructor_args) + scanned_output = scanned_rv(**sample_args) + with numpyro.plate("test_plate", jnp.shape(scanned_output)[-1]): + plated_rv = DistributionalVariable("test", dist.Normal(0, 1)) + plated_samp = plated_rv() + output = scanned_output + plated_samp + assert output.shape == scanned_output.shape