From 65b248eb4e9d77b00b6784180236f63a6c480cff Mon Sep 17 00:00:00 2001 From: sbidari Date: Tue, 20 Aug 2024 14:31:40 -0400 Subject: [PATCH 01/16] testing convolve mode --- model/src/pyrenew/latent/hospitaladmissions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/model/src/pyrenew/latent/hospitaladmissions.py b/model/src/pyrenew/latent/hospitaladmissions.py index 101ede18..959854f0 100644 --- a/model/src/pyrenew/latent/hospitaladmissions.py +++ b/model/src/pyrenew/latent/hospitaladmissions.py @@ -189,7 +189,7 @@ def sample( latent_hospital_admissions = jnp.convolve( latent_hospital_admissions_raw, infection_to_admission_interval.value, - mode="full", + mode="valid", )[: latent_hospital_admissions_raw.shape[0]] # Applying the day of the week effect From 3a548550ced935144dedf905ede363481aa90c3d Mon Sep 17 00:00:00 2001 From: sbidari Date: Wed, 21 Aug 2024 11:41:57 -0400 Subject: [PATCH 02/16] update tutorial to work with convolve mode valid --- .../tutorials/hospital_admissions_model.qmd | 26 ++++++++----------- src/pyrenew/latent/hospitaladmissions.py | 2 +- 2 files changed, 12 insertions(+), 16 deletions(-) diff --git a/docs/source/tutorials/hospital_admissions_model.qmd b/docs/source/tutorials/hospital_admissions_model.qmd index cbd73e39..9bbfeade 100644 --- a/docs/source/tutorials/hospital_admissions_model.qmd +++ b/docs/source/tutorials/hospital_admissions_model.qmd @@ -118,17 +118,17 @@ inf_hosp_int = datasets.load_infection_admission_interval() # We only need the probability_mass column of each dataset gen_int_array = gen_int["probability_mass"].to_numpy() gen_int = gen_int_array -inf_hosp_int = inf_hosp_int["probability_mass"].to_numpy() +inf_hosp_int_array = inf_hosp_int["probability_mass"].to_numpy() # Taking a peek at the first 5 elements of each -gen_int[:5], inf_hosp_int[:5] +gen_int[:5], inf_hosp_int_array[:5] # Visualizing both quantities side by side fig, axs = plt.subplots(1, 2) axs[0].plot(gen_int) axs[0].set_title("Generation interval") -axs[1].plot(inf_hosp_int) +axs[1].plot(inf_hosp_int_array) axs[1].set_title("Infection to hospital admission interval") plt.show() ``` @@ -142,7 +142,7 @@ import jax.numpy as jnp import numpyro.distributions as dist inf_hosp_int = deterministic.DeterministicPMF( - name="inf_hosp_int", value=inf_hosp_int + name="inf_hosp_int", value=inf_hosp_int_array ) hosp_rate = metaclass.DistributionalRV( @@ -175,7 +175,7 @@ I0 = InfectionInitializationProcess( distribution=dist.LogNormal(loc=jnp.log(100), scale=jnp.log(1.75)), ), InitializeInfectionsExponentialGrowth( - gen_int_array.size, + inf_hosp_int_array.size, deterministic.DeterministicVariable(name="rate", value=0.05), ), t_unit=1, @@ -313,11 +313,7 @@ We can use the `Model` object's `plot_posterior` method to visualize the model f out = hosp_model.plot_posterior( var="latent_hospital_admissions", ylab="Hospital Admissions", - obs_signal=np.pad( - daily_hosp_admits.astype(float), - (gen_int_array.size, 0), - constant_values=np.nan, - ), + obs_signal=daily_hosp_admits.astype(float), ) ``` @@ -504,7 +500,7 @@ def compute_eti(dataset, eti_prob): fig, axes = plt.subplots(figsize=(6, 5)) az.plot_hdi( - idata.prior_predictive["negbinom_rv_dim_0"] + gen_int.size(), + idata.prior_predictive["negbinom_rv_dim_0"], hdi_data=compute_eti(idata.prior_predictive["negbinom_rv"], 0.9), color="C0", smooth=False, @@ -513,7 +509,7 @@ az.plot_hdi( ) az.plot_hdi( - idata.prior_predictive["negbinom_rv_dim_0"] + gen_int.size(), + idata.prior_predictive["negbinom_rv_dim_0"], hdi_data=compute_eti(idata.prior_predictive["negbinom_rv"], 0.5), color="C0", smooth=False, @@ -522,7 +518,7 @@ az.plot_hdi( ) plt.scatter( - idata.observed_data["negbinom_rv_dim_0"] + gen_int.size(), + idata.observed_data["negbinom_rv_dim_0"], idata.observed_data["negbinom_rv"], color="black", ) @@ -538,7 +534,7 @@ And now we plot the posterior predictive distributions with a `{python} n_foreca ```{python} # | label: fig-output-posterior-predictive-forecast # | fig-cap: Posterior predictive admissions, including a forecast. -x_data = idata.posterior_predictive["negbinom_rv_dim_0"] + gen_int.size() +x_data = idata.posterior_predictive["negbinom_rv_dim_0"] y_data = idata.posterior_predictive["negbinom_rv"] fig, axes = plt.subplots(figsize=(6, 5)) az.plot_hdi( @@ -569,7 +565,7 @@ plt.plot( label="Median", ) plt.scatter( - idata.observed_data["negbinom_rv_dim_0"] + gen_int.size(), + idata.observed_data["negbinom_rv_dim_0"], idata.observed_data["negbinom_rv"], color="black", ) diff --git a/src/pyrenew/latent/hospitaladmissions.py b/src/pyrenew/latent/hospitaladmissions.py index 8e01c4d2..11472a13 100644 --- a/src/pyrenew/latent/hospitaladmissions.py +++ b/src/pyrenew/latent/hospitaladmissions.py @@ -191,7 +191,7 @@ def sample( latent_hospital_admissions_raw, infection_to_admission_interval.value, mode="valid", - )[: latent_hospital_admissions_raw.shape[0]] + ) # Applying the day of the week effect latent_hospital_admissions = ( From 9a7cbb3f921e4d080f737059687868f0628605a3 Mon Sep 17 00:00:00 2001 From: sbidari Date: Wed, 21 Aug 2024 12:44:29 -0400 Subject: [PATCH 03/16] update latent admissions test --- src/test/test_latent_admissions.py | 85 +++++++++++++----------------- 1 file changed, 38 insertions(+), 47 deletions(-) diff --git a/src/test/test_latent_admissions.py b/src/test/test_latent_admissions.py index 92acec2b..14333fc9 100644 --- a/src/test/test_latent_admissions.py +++ b/src/test/test_latent_admissions.py @@ -1,20 +1,16 @@ # -*- coding: utf-8 -*- # numpydoc ignore=GL08 +from test.utils import simple_rt + import jax.numpy as jnp import numpy.testing as testing import numpyro import numpyro.distributions as dist -from pyrenew import transformation as t from pyrenew.deterministic import DeterministicPMF, DeterministicVariable from pyrenew.latent import HospitalAdmissions, Infections -from pyrenew.metaclass import ( - DistributionalRV, - SampledValue, - TransformedRandomVariable, -) -from pyrenew.process import SimpleRandomWalkProcess +from pyrenew.metaclass import DistributionalRV, SampledValue def test_admissions_sample(): @@ -25,26 +21,36 @@ def test_admissions_sample(): # Generating Rt and Infections to compute the hospital admissions - rt = TransformedRandomVariable( - name="Rt_rv", - base_rv=SimpleRandomWalkProcess( - name="log_rt", - step_rv=DistributionalRV( - name="rw_step_rv", distribution=dist.Normal(0, 0.025) - ), - init_rv=DistributionalRV( - name="init_log_rt", distribution=dist.Normal(0, 0.2) - ), - ), - transforms=t.ExpTransform(), - ) + rt = simple_rt() + n_steps = 30 with numpyro.handlers.seed(rng_seed=223): - sim_rt = rt(n_steps=30)[0].value + sim_rt = rt(n_steps=n_steps)[0].value gen_int = jnp.array([0.5, 0.1, 0.1, 0.2, 0.1]) - i0 = 10 * jnp.ones_like(gen_int) - + inf_hosp_int_array = jnp.array( + [ + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0.25, + 0.5, + 0.1, + 0.1, + 0.05, + ] + ) + i0 = 10 * jnp.ones_like(inf_hosp_int_array) inf1 = Infections() with numpyro.handlers.seed(rng_seed=223): @@ -53,28 +59,7 @@ def test_admissions_sample(): # Testing the hospital admissions inf_hosp = DeterministicPMF( name="inf_hosp", - value=jnp.array( - [ - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0.25, - 0.5, - 0.1, - 0.1, - 0.05, - ] - ), + value=inf_hosp_int_array, ) hosp1 = HospitalAdmissions( @@ -85,10 +70,16 @@ def test_admissions_sample(): ) with numpyro.handlers.seed(rng_seed=223): - sim_hosp_1 = hosp1(latent_infections=inf_sampled1[0]) + sim_hosp_1 = hosp1( + latent_infections=SampledValue( + value=jnp.hstack( + [i0, inf_sampled1.post_initialization_infections.value] + ) + ) + ) testing.assert_array_less( - sim_hosp_1.latent_hospital_admissions.value, + sim_hosp_1.latent_hospital_admissions.value[-n_steps:], inf_sampled1[0].value, ) From cbff93c060b7265a7a5d315f20ca4882fc3a6c54 Mon Sep 17 00:00:00 2001 From: sbidari Date: Wed, 21 Aug 2024 13:03:34 -0400 Subject: [PATCH 04/16] update DOW tutorial for convolve mode valid --- docs/source/tutorials/day_of_the_week.qmd | 24 ++++++----------------- 1 file changed, 6 insertions(+), 18 deletions(-) diff --git a/docs/source/tutorials/day_of_the_week.qmd b/docs/source/tutorials/day_of_the_week.qmd index ceadc985..976955a9 100644 --- a/docs/source/tutorials/day_of_the_week.qmd +++ b/docs/source/tutorials/day_of_the_week.qmd @@ -43,7 +43,7 @@ inf_hosp_int = datasets.load_infection_admission_interval() # We only need the probability_mass column of each dataset gen_int_array = gen_int["probability_mass"].to_numpy() gen_int = gen_int_array -inf_hosp_int = inf_hosp_int["probability_mass"].to_numpy() +inf_hosp_int_array = inf_hosp_int["probability_mass"].to_numpy() ``` 2. Next, we defined the model's components: @@ -56,7 +56,7 @@ import jax.numpy as jnp import numpyro.distributions as dist inf_hosp_int = deterministic.DeterministicPMF( - name="inf_hosp_int", value=inf_hosp_int + name="inf_hosp_int", value=inf_hosp_int_array ) hosp_rate = metaclass.DistributionalRV( @@ -84,7 +84,7 @@ I0 = InfectionInitializationProcess( distribution=dist.LogNormal(loc=jnp.log(100), scale=jnp.log(1.75)), ), InitializeInfectionsExponentialGrowth( - gen_int_array.size, + inf_hosp_int_array.size, deterministic.DeterministicVariable(name="rate", value=0.05), ), t_unit=1, @@ -201,11 +201,7 @@ hosp_model.run( out = hosp_model.plot_posterior( var="latent_hospital_admissions", ylab="Hospital Admissions", - obs_signal=np.pad( - daily_hosp_admits.astype(float), - (gen_int_array.size, 0), - constant_values=np.nan, - ), + obs_signal=daily_hosp_admits.astype(float), ) ``` @@ -329,11 +325,7 @@ The new model with the day-of-the-week effect can be compared to the previous mo out = hosp_model.plot_posterior( var="latent_hospital_admissions", ylab="Hospital Admissions", - obs_signal=np.pad( - daily_hosp_admits.astype(float), - (gen_int_array.size, 0), - constant_values=np.nan, - ), + obs_signal=daily_hosp_admits.astype(float), ) ``` @@ -344,10 +336,6 @@ out = hosp_model.plot_posterior( out_dow = hosp_model_dow.plot_posterior( var="latent_hospital_admissions", ylab="Hospital Admissions", - obs_signal=np.pad( - daily_hosp_admits.astype(float), - (gen_int_array.size, 0), - constant_values=np.nan, - ), + obs_signal=daily_hosp_admits.astype(float), ) ``` From 41b070f02adb34ae9fd22db38b5d84f03b618ce4 Mon Sep 17 00:00:00 2001 From: sbidari Date: Wed, 21 Aug 2024 13:55:51 -0400 Subject: [PATCH 05/16] update hosp model tests --- src/test/test_model_hosp_admissions.py | 94 +++++++++++++------------- 1 file changed, 47 insertions(+), 47 deletions(-) diff --git a/src/test/test_model_hosp_admissions.py b/src/test/test_model_hosp_admissions.py index fc2e4f57..e752f2c3 100644 --- a/src/test/test_model_hosp_admissions.py +++ b/src/test/test_model_hosp_admissions.py @@ -197,16 +197,6 @@ def test_model_hosp_no_obs_model(): value=jnp.array([0.25, 0.25, 0.25, 0.25]), ) - I0 = InfectionInitializationProcess( - "I0_initialization", - DistributionalRV(name="I0", distribution=dist.LogNormal(0, 1)), - InitializeInfectionsZeroPad(n_timepoints=gen_int.size()), - t_unit=1, - ) - - latent_infections = Infections() - Rt_process = simple_rt() - inf_hosp = DeterministicPMF( name="inf_hosp", value=jnp.array( @@ -233,6 +223,16 @@ def test_model_hosp_no_obs_model(): ), ) + I0 = InfectionInitializationProcess( + "I0_initialization", + DistributionalRV(name="I0", distribution=dist.LogNormal(0, 1)), + InitializeInfectionsZeroPad(n_timepoints=inf_hosp.size()), + t_unit=1, + ) + + latent_infections = Infections() + Rt_process = simple_rt() + latent_admissions = HospitalAdmissions( infection_to_admission_interval_rv=inf_hosp, infection_hospitalization_ratio_rv=DistributionalRV( @@ -307,17 +307,6 @@ def test_model_hosp_with_obs_model(): name="gen_int", value=jnp.array([0.25, 0.25, 0.25, 0.25]) ) - I0 = InfectionInitializationProcess( - "I0_initialization", - DistributionalRV(name="I0", distribution=dist.LogNormal(0, 1)), - InitializeInfectionsZeroPad(n_timepoints=gen_int.size()), - t_unit=1, - ) - - latent_infections = Infections() - Rt_process = simple_rt() - observed_admissions = PoissonObservation("poisson_rv") - inf_hosp = DeterministicPMF( name="inf_hosp", value=jnp.array( @@ -340,10 +329,21 @@ def test_model_hosp_with_obs_model(): 0.1, 0.1, 0.05, - ], + ] ), ) + I0 = InfectionInitializationProcess( + "I0_initialization", + DistributionalRV(name="I0", distribution=dist.LogNormal(0, 1)), + InitializeInfectionsZeroPad(n_timepoints=inf_hosp.size()), + t_unit=1, + ) + + latent_infections = Infections() + Rt_process = simple_rt() + observed_admissions = PoissonObservation("poisson_rv") + latent_admissions = HospitalAdmissions( infection_to_admission_interval_rv=inf_hosp, infection_hospitalization_ratio_rv=DistributionalRV( @@ -394,17 +394,6 @@ def test_model_hosp_with_obs_model_weekday_phosp_2(): value=jnp.array([0.25, 0.25, 0.25, 0.25]), ) - I0 = InfectionInitializationProcess( - "I0_initialization", - DistributionalRV(name="I0", distribution=dist.LogNormal(0, 1)), - InitializeInfectionsZeroPad(n_timepoints=gen_int.size()), - t_unit=1, - ) - - latent_infections = Infections() - Rt_process = simple_rt() - observed_admissions = PoissonObservation("poisson_rv") - inf_hosp = DeterministicPMF( name="inf_hosp", value=jnp.array( @@ -427,10 +416,21 @@ def test_model_hosp_with_obs_model_weekday_phosp_2(): 0.1, 0.1, 0.05, - ], + ] ), ) + I0 = InfectionInitializationProcess( + "I0_initialization", + DistributionalRV(name="I0", distribution=dist.LogNormal(0, 1)), + InitializeInfectionsZeroPad(n_timepoints=inf_hosp.size()), + t_unit=1, + ) + + latent_infections = Infections() + Rt_process = simple_rt() + observed_admissions = PoissonObservation("poisson_rv") + hosp_report_prob_dist = UniformProbForTest(1, "hosp_report_prob_dist") weekday = UniformProbForTest(7, "weekday") @@ -487,18 +487,6 @@ def test_model_hosp_with_obs_model_weekday_phosp(): n_obs_to_generate = 30 pad_size = 5 - I0 = InfectionInitializationProcess( - "I0_initialization", - DistributionalRV(name="I0", distribution=dist.LogNormal(0, 1)), - InitializeInfectionsZeroPad(n_timepoints=gen_int.size()), - t_unit=1, - ) - - latent_infections = Infections() - Rt_process = simple_rt() - - observed_admissions = PoissonObservation("poisson_rv") - inf_hosp = DeterministicPMF( name="inf_hosp", value=jnp.array( @@ -521,10 +509,22 @@ def test_model_hosp_with_obs_model_weekday_phosp(): 0.1, 0.1, 0.05, - ], + ] ), ) + I0 = InfectionInitializationProcess( + "I0_initialization", + DistributionalRV(name="I0", distribution=dist.LogNormal(0, 1)), + InitializeInfectionsZeroPad(n_timepoints=inf_hosp.size()), + t_unit=1, + ) + + latent_infections = Infections() + Rt_process = simple_rt() + + observed_admissions = PoissonObservation("poisson_rv") + # Other random components total_length = n_obs_to_generate + pad_size + gen_int.size() weekday = jnp.array([1, 1, 1, 1, 2, 2, 2]) From e9130cecb69e4e2daf59990d48a5bc1c3f80f046 Mon Sep 17 00:00:00 2001 From: sbidari Date: Wed, 21 Aug 2024 15:28:52 -0400 Subject: [PATCH 06/16] create helper function for convolve and add tests --- src/pyrenew/latent/hospitaladmissions.py | 13 ++++--- src/pyrenew/metaclass.py | 32 +++++++++++++++++ .../test_incidence_observed_with_delay.py | 34 +++++++++++++++++++ 3 files changed, 75 insertions(+), 4 deletions(-) create mode 100644 src/test/test_incidence_observed_with_delay.py diff --git a/src/pyrenew/latent/hospitaladmissions.py b/src/pyrenew/latent/hospitaladmissions.py index 947c4fd5..f8d9a2a8 100644 --- a/src/pyrenew/latent/hospitaladmissions.py +++ b/src/pyrenew/latent/hospitaladmissions.py @@ -10,7 +10,11 @@ import pyrenew.arrayutils as au from pyrenew.deterministic import DeterministicVariable -from pyrenew.metaclass import RandomVariable, SampledValue +from pyrenew.metaclass import ( + RandomVariable, + SampledValue, + compute_incidence_observed_with_delay, +) class HospitalAdmissionsSample(NamedTuple): @@ -210,11 +214,12 @@ def sample( *_, ) = self.infection_to_admission_interval_rv(**kwargs) - latent_hospital_admissions = jnp.convolve( - infection_hosp_rate.value * latent_infections.value, + latent_hospital_admissions = compute_incidence_observed_with_delay( + infection_hosp_rate.value, + latent_infections.value, infection_to_admission_interval.value, - mode="valid", ) + # Applying the day of the week effect. For this we need to: # 1. Get the day of the week effect # 2. Identify the offset of the latent_infections diff --git a/src/pyrenew/metaclass.py b/src/pyrenew/metaclass.py index b5be0a54..aad82ba4 100644 --- a/src/pyrenew/metaclass.py +++ b/src/pyrenew/metaclass.py @@ -126,6 +126,38 @@ def _assert_sample_and_rtype( return None +def compute_incidence_observed_with_delay( + incidence_to_observation_rate: float, + latent_incidence: ArrayLike, + incidence_to_observation_delay_interval: ArrayLike, +) -> ArrayLike: + """ + Computes incidences observed according + to a given observation rate and based + on a delay interval. + + Parameters + ---------- + incidence_to_observation_rate: float + The rate at which latent incidences are observed. + latent_incidence: ArrayLike + Incidence values based on the true underlying process. + incidence_to_observation_delay_interval: ArrayLike + Pmf of delay interval between incidence to observation. + + Returns + -------- + ArrayLike + The incidence after the observation delay. + """ + delay_obs_incidence = jnp.convolve( + incidence_to_observation_rate * latent_incidence, + incidence_to_observation_delay_interval, + mode="valid", + ) + return delay_obs_incidence + + class SampledValue(NamedTuple): """ A container for a value sampled from a RandomVariable. diff --git a/src/test/test_incidence_observed_with_delay.py b/src/test/test_incidence_observed_with_delay.py new file mode 100644 index 00000000..48d67b0c --- /dev/null +++ b/src/test/test_incidence_observed_with_delay.py @@ -0,0 +1,34 @@ +# numpydoc ignore=GL08 + +import jax +import jax.numpy as jnp +import numpy as np +from numpy.testing import assert_array_equal +import pytest + +from pyrenew.metaclass import compute_incidence_observed_with_delay + + +@pytest.mark.parametrize( + ["obs_rate", "latent_incidence", "delay_interval", "expected_output"], + [ + [ + jnp.array([1.0]), + jnp.array([1.0, 2.0, 3.0]), + jnp.array([1.0]), + jnp.array([1.0, 2.0, 3.0]), + ], + ], +) +def test(obs_rate, latent_incidence, delay_interval, expected_output): + """ + Tests for helper function to compute + incidence observed with a delay + """ + result = compute_incidence_observed_with_delay( + obs_rate, + latent_incidence, + delay_interval, + ) + + assert_array_equal(result, expected_output) From 47372886abbecc73b849d5cbffafb837b68e6ded Mon Sep 17 00:00:00 2001 From: sbidari Date: Wed, 21 Aug 2024 15:58:37 -0400 Subject: [PATCH 07/16] forgot to run precommit earlier --- src/pyrenew/metaclass.py | 4 ++-- .../test_incidence_observed_with_delay.py | 23 +++++++++++++++---- 2 files changed, 21 insertions(+), 6 deletions(-) diff --git a/src/pyrenew/metaclass.py b/src/pyrenew/metaclass.py index aad82ba4..e6a19c1c 100644 --- a/src/pyrenew/metaclass.py +++ b/src/pyrenew/metaclass.py @@ -127,7 +127,7 @@ def _assert_sample_and_rtype( def compute_incidence_observed_with_delay( - incidence_to_observation_rate: float, + incidence_to_observation_rate: ArrayLike, latent_incidence: ArrayLike, incidence_to_observation_delay_interval: ArrayLike, ) -> ArrayLike: @@ -138,7 +138,7 @@ def compute_incidence_observed_with_delay( Parameters ---------- - incidence_to_observation_rate: float + incidence_to_observation_rate: ArrayLike The rate at which latent incidences are observed. latent_incidence: ArrayLike Incidence values based on the true underlying process. diff --git a/src/test/test_incidence_observed_with_delay.py b/src/test/test_incidence_observed_with_delay.py index 48d67b0c..180144b9 100644 --- a/src/test/test_incidence_observed_with_delay.py +++ b/src/test/test_incidence_observed_with_delay.py @@ -1,10 +1,8 @@ # numpydoc ignore=GL08 -import jax import jax.numpy as jnp -import numpy as np -from numpy.testing import assert_array_equal import pytest +from numpy.testing import assert_array_equal from pyrenew.metaclass import compute_incidence_observed_with_delay @@ -18,6 +16,24 @@ jnp.array([1.0]), jnp.array([1.0, 2.0, 3.0]), ], + [ + jnp.array([1.0, 0.1, 1.0]), + jnp.array([1.0, 2.0, 3.0]), + jnp.array([1.0]), + jnp.array([1.0, 0.2, 3.0]), + ], + [ + jnp.array([1.0]), + jnp.array([1.0, 2.0, 3.0]), + jnp.array([0.5, 0.5]), + jnp.array([1.5, 2.5]), + ], + [ + jnp.array([1.0]), + jnp.array([0, 2.0, 4.0]), + jnp.array([0.25, 0.5, 0.25]), + jnp.array([2]), + ], ], ) def test(obs_rate, latent_incidence, delay_interval, expected_output): @@ -30,5 +46,4 @@ def test(obs_rate, latent_incidence, delay_interval, expected_output): latent_incidence, delay_interval, ) - assert_array_equal(result, expected_output) From e87d742f8078bd6a2afe8bc5b16df6522fc814cb Mon Sep 17 00:00:00 2001 From: sbidari Date: Wed, 21 Aug 2024 19:31:05 -0400 Subject: [PATCH 08/16] update test for model with DOW effect --- src/test/test_model_hosp_admissions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/test/test_model_hosp_admissions.py b/src/test/test_model_hosp_admissions.py index e752f2c3..da606e97 100644 --- a/src/test/test_model_hosp_admissions.py +++ b/src/test/test_model_hosp_admissions.py @@ -526,7 +526,7 @@ def test_model_hosp_with_obs_model_weekday_phosp(): observed_admissions = PoissonObservation("poisson_rv") # Other random components - total_length = n_obs_to_generate + pad_size + gen_int.size() + total_length = n_obs_to_generate + pad_size + 1 # gen_int.size() weekday = jnp.array([1, 1, 1, 1, 2, 2, 2]) weekday = weekday / weekday.sum() From b4c5ca28b389c418e38eb10431c83b62143586a7 Mon Sep 17 00:00:00 2001 From: sbidari Date: Thu, 22 Aug 2024 15:46:24 -0400 Subject: [PATCH 09/16] renaming helper function, add n_initialization_point --- docs/source/tutorials/day_of_the_week.qmd | 4 +++- .../tutorials/hospital_admissions_model.qmd | 3 ++- src/pyrenew/latent/hospitaladmissions.py | 4 ++-- src/pyrenew/metaclass.py | 3 ++- src/test/test_latent_admissions.py | 3 ++- src/test/test_model_hosp_admissions.py | 23 ++++++++++++++----- 6 files changed, 28 insertions(+), 12 deletions(-) diff --git a/docs/source/tutorials/day_of_the_week.qmd b/docs/source/tutorials/day_of_the_week.qmd index 6cf411f9..d1c47e09 100644 --- a/docs/source/tutorials/day_of_the_week.qmd +++ b/docs/source/tutorials/day_of_the_week.qmd @@ -77,6 +77,8 @@ from pyrenew.latent import ( # Infection process latent_inf = latent.Infections() +n_initialization_points = max(gen_int_array.size, inf_hosp_int_array.size) + I0 = InfectionInitializationProcess( "I0_initialization", metaclass.DistributionalRV( @@ -84,7 +86,7 @@ I0 = InfectionInitializationProcess( distribution=dist.LogNormal(loc=jnp.log(100), scale=jnp.log(1.75)), ), InitializeInfectionsExponentialGrowth( - inf_hosp_int_array.size, + n_initialization_points, deterministic.DeterministicVariable(name="rate", value=0.05), ), t_unit=1, diff --git a/docs/source/tutorials/hospital_admissions_model.qmd b/docs/source/tutorials/hospital_admissions_model.qmd index 259eeb30..d7fb2c6f 100644 --- a/docs/source/tutorials/hospital_admissions_model.qmd +++ b/docs/source/tutorials/hospital_admissions_model.qmd @@ -168,6 +168,7 @@ from pyrenew.latent import ( # Infection process latent_inf = latent.Infections() +n_initialization_points = max(gen_int_array.size, inf_hosp_int_array.size) I0 = InfectionInitializationProcess( "I0_initialization", metaclass.DistributionalRV( @@ -175,7 +176,7 @@ I0 = InfectionInitializationProcess( distribution=dist.LogNormal(loc=jnp.log(100), scale=jnp.log(1.75)), ), InitializeInfectionsExponentialGrowth( - inf_hosp_int_array.size, + n_initialization_points, deterministic.DeterministicVariable(name="rate", value=0.05), ), t_unit=1, diff --git a/src/pyrenew/latent/hospitaladmissions.py b/src/pyrenew/latent/hospitaladmissions.py index 9d900146..382a8522 100644 --- a/src/pyrenew/latent/hospitaladmissions.py +++ b/src/pyrenew/latent/hospitaladmissions.py @@ -13,7 +13,7 @@ from pyrenew.metaclass import ( RandomVariable, SampledValue, - compute_incidence_observed_with_delay, + compute_delay_ascertained_incidence, ) @@ -214,7 +214,7 @@ def sample( *_, ) = self.infection_to_admission_interval_rv(**kwargs) - latent_hospital_admissions = compute_incidence_observed_with_delay( + latent_hospital_admissions = compute_delay_ascertained_incidence( infection_hosp_rate.value, latent_infections.value, infection_to_admission_interval.value, diff --git a/src/pyrenew/metaclass.py b/src/pyrenew/metaclass.py index 98422276..44d8278a 100644 --- a/src/pyrenew/metaclass.py +++ b/src/pyrenew/metaclass.py @@ -8,6 +8,7 @@ from typing import Callable, NamedTuple, Self, get_type_hints import jax +import jax.numpy as jnp import jax.random as jr import matplotlib.pyplot as plt import numpy as np @@ -125,7 +126,7 @@ def _assert_sample_and_rtype( return None -def compute_incidence_observed_with_delay( +def compute_delay_ascertained_incidence( incidence_to_observation_rate: ArrayLike, latent_incidence: ArrayLike, incidence_to_observation_delay_interval: ArrayLike, diff --git a/src/test/test_latent_admissions.py b/src/test/test_latent_admissions.py index c7e0a64b..526fbc31 100644 --- a/src/test/test_latent_admissions.py +++ b/src/test/test_latent_admissions.py @@ -22,9 +22,10 @@ def test_admissions_sample(): # Generating Rt and Infections to compute the hospital admissions rt = SimpleRt() + n_steps = 30 with numpyro.handlers.seed(rng_seed=223): - sim_rt = rt(n_steps=30)[0].value + sim_rt = rt(n=n_steps)[0].value gen_int = jnp.array([0.5, 0.1, 0.1, 0.2, 0.1]) inf_hosp_int_array = jnp.array( diff --git a/src/test/test_model_hosp_admissions.py b/src/test/test_model_hosp_admissions.py index 977e15a7..293c8ea5 100644 --- a/src/test/test_model_hosp_admissions.py +++ b/src/test/test_model_hosp_admissions.py @@ -222,11 +222,12 @@ def test_model_hosp_no_obs_model(): ] ), ) + n_initialization_points = max(gen_int.size(), inf_hosp.size()) I0 = InfectionInitializationProcess( "I0_initialization", DistributionalRV(name="I0", distribution=dist.LogNormal(0, 1)), - InitializeInfectionsZeroPad(n_timepoints=inf_hosp.size()), + InitializeInfectionsZeroPad(n_timepoints=n_initialization_points), t_unit=1, ) @@ -259,7 +260,9 @@ def test_model_hosp_no_obs_model(): with numpyro.handlers.seed(rng_seed=223): model1_samp = model0.sample(n_datapoints=30) - np.testing.assert_array_almost_equal(model0_samp.Rt.value, model1_samp.Rt.value) + np.testing.assert_array_almost_equal( + model0_samp.Rt.value, model1_samp.Rt.value + ) np.testing.assert_array_equal( model0_samp.latent_infections.value, model1_samp.latent_infections.value, @@ -331,10 +334,12 @@ def test_model_hosp_with_obs_model(): ), ) + n_initialization_points = max(gen_int.size(), inf_hosp.size()) + I0 = InfectionInitializationProcess( "I0_initialization", DistributionalRV(name="I0", distribution=dist.LogNormal(0, 1)), - InitializeInfectionsZeroPad(n_timepoints=inf_hosp.size()), + InitializeInfectionsZeroPad(n_timepoints=n_initialization_points), t_unit=1, ) @@ -418,10 +423,12 @@ def test_model_hosp_with_obs_model_weekday_phosp_2(): ), ) + n_initialization_points = max(gen_int.size(), inf_hosp.size()) + I0 = InfectionInitializationProcess( "I0_initialization", DistributionalRV(name="I0", distribution=dist.LogNormal(0, 1)), - InitializeInfectionsZeroPad(n_timepoints=inf_hosp.size()), + InitializeInfectionsZeroPad(n_timepoints=n_initialization_points), t_unit=1, ) @@ -511,10 +518,12 @@ def test_model_hosp_with_obs_model_weekday_phosp(): ), ) + n_initialization_points = max(gen_int.size(), inf_hosp.size()) + I0 = InfectionInitializationProcess( "I0_initialization", DistributionalRV(name="I0", distribution=dist.LogNormal(0, 1)), - InitializeInfectionsZeroPad(n_timepoints=inf_hosp.size()), + InitializeInfectionsZeroPad(n_timepoints=n_initialization_points), t_unit=1, ) @@ -561,7 +570,9 @@ def test_model_hosp_with_obs_model_weekday_phosp(): # Sampling and fitting model 0 (with no obs for infections) with numpyro.handlers.seed(rng_seed=223): - model1_samp = model1.sample(n_datapoints=n_obs_to_generate, padding=pad_size) + model1_samp = model1.sample( + n_datapoints=n_obs_to_generate, padding=pad_size + ) # Showed during merge conflict, but unsure if it will be needed # pad_size = 5 From ebc9dd2c5f6c875fc6958facb88714c755be9da7 Mon Sep 17 00:00:00 2001 From: sbidari Date: Thu, 22 Aug 2024 17:16:59 -0400 Subject: [PATCH 10/16] create randomvariable module --- pyrenew/metaclass.py | 472 +----------------- pyrenew/randomvariable/__init__.py | 17 + .../randomvariable/distributionalvariable.py | 342 +++++++++++++ pyrenew/randomvariable/transformedvariable.py | 140 ++++++ 4 files changed, 500 insertions(+), 471 deletions(-) create mode 100644 pyrenew/randomvariable/__init__.py create mode 100644 pyrenew/randomvariable/distributionalvariable.py create mode 100644 pyrenew/randomvariable/transformedvariable.py diff --git a/pyrenew/metaclass.py b/pyrenew/metaclass.py index 44d8278a..344cf894 100644 --- a/pyrenew/metaclass.py +++ b/pyrenew/metaclass.py @@ -5,22 +5,18 @@ """ from abc import ABCMeta, abstractmethod -from typing import Callable, NamedTuple, Self, get_type_hints +from typing import NamedTuple, get_type_hints import jax import jax.numpy as jnp import jax.random as jr import matplotlib.pyplot as plt import numpy as np -import numpyro -import numpyro.distributions as dist import polars as pl from jax.typing import ArrayLike from numpyro.infer import MCMC, NUTS, Predictive -from numpyro.infer.reparam import Reparam from pyrenew.mcmcutils import plot_posterior, spread_draws -from pyrenew.transformation import Transform def _assert_type(arg_name: str, value, expected_type) -> None: @@ -309,338 +305,6 @@ def __call__(self, **kwargs): return self.sample(**kwargs) -class DynamicDistributionalRV(RandomVariable): - """ - Wrapper class for random variables that sample - from a single :class:`numpyro.distributions.Distribution` - that is parameterized / instantiated at `sample()` time - (rather than at RandomVariable instantiation time). - """ - - def __init__( - self, - name: str, - distribution_constructor: Callable, - reparam: Reparam = None, - expand_by_shape: tuple = None, - ) -> None: - """ - Default constructor for DynamicDistributionalRV. - - Parameters - ---------- - name : str - Name of the random variable. - distribution_constructor : Callable - Callable that returns a concrete parametrized - numpyro.Distributions.distribution instance. - reparam : numpyro.infer.reparam.Reparam - If not None, reparameterize sampling - from the distribution according to the - given numpyro reparameterizer - expand_by_shape : tuple, optional - If not None, call :meth:`expand_by()` on the - underlying distribution once it is instianted - with the given `expand_by_shape`. - Default None. - - Returns - ------- - None - """ - - self.name = name - self.validate(distribution_constructor) - self.distribution_constructor = distribution_constructor - if reparam is not None: - self.reparam_dict = {self.name: reparam} - else: - self.reparam_dict = {} - if not (expand_by_shape is None or isinstance(expand_by_shape, tuple)): - raise ValueError( - "expand_by_shape must be a tuple or be None ", - f"Got {type(expand_by_shape)}", - ) - self.expand_by_shape = expand_by_shape - - return None - - @staticmethod - def validate(distribution_constructor: any) -> None: - """ - Confirm that the distribution_constructor is - callable. - - Parameters - ---------- - distribution_constructor : any - Putative distribution_constructor to validate. - - Returns - ------- - None or raises a ValueError - """ - if not callable(distribution_constructor): - raise ValueError( - "To instantiate a DynamicDistributionalRV, ", - "one must provide a Callable that returns a " - "numpyro.distributions.Distribution as the " - "distribution_constructor argument. " - f"Got {type(distribution_constructor)}, which " - "does not appear to be callable", - ) - return None - - def sample( - self, - *args, - obs: ArrayLike = None, - **kwargs, - ) -> tuple: - """ - Sample from the distributional rv. - - Parameters - ---------- - *args : - Positional arguments passed to self.distribution_constructor - obs : ArrayLike, optional - Observations passed as the `obs` argument to - :meth:`numpyro.sample()`. Default `None`. - **kwargs : dict, optional - Keyword arguments passed to self.distribution_constructor - - Returns - ------- - SampledValue - Containing a sample from the distribution. - """ - distribution = self.distribution_constructor(*args, **kwargs) - if self.expand_by_shape is not None: - distribution = distribution.expand_by(self.expand_by_shape) - with numpyro.handlers.reparam(config=self.reparam_dict): - sample = numpyro.sample( - name=self.name, - fn=distribution, - obs=obs, - ) - return ( - SampledValue( - sample, - t_start=self.t_start, - t_unit=self.t_unit, - ), - ) - - def expand_by(self, sample_shape) -> Self: - """ - Expand the distribution by a given - shape_shape, if possible. Returns a - new DynamicDistributionalRV whose underlying - distribution will be expanded by the given shape - at sample() time. - - Parameters - ---------- - sample_shape : tuple - Sample shape by which to expand the distribution. - Passed to the expand_by() method of - :class:`numpyro.distributions.Distribution` - after the distribution is instantiated. - - Returns - ------- - DynamicDistributionalRV - Whose underlying distribution will be expanded by - the given sample shape at sampling time. - """ - return DynamicDistributionalRV( - name=self.name, - distribution_constructor=self.distribution_constructor, - reparam=self.reparam_dict.get(self.name, None), - expand_by_shape=sample_shape, - ) - - -class StaticDistributionalRV(RandomVariable): - """ - Wrapper class for random variables that sample - from a single :class:`numpyro.distributions.Distribution` - that is parameterized / instantiated at RandomVariable - instantiation time (rather than at `sample()`-ing time). - """ - - def __init__( - self, - name: str, - distribution: numpyro.distributions.Distribution, - reparam: Reparam = None, - ) -> None: - """ - Default constructor for DistributionalRV. - - Parameters - ---------- - name : str - Name of the random variable. - distribution : numpyro.distributions.Distribution - Distribution of the random variable. - reparam : numpyro.infer.reparam.Reparam - If not None, reparameterize sampling - from the distribution according to the - given numpyro reparameterizer - - Returns - ------- - None - """ - - self.name = name - self.validate(distribution) - self.distribution = distribution - if reparam is not None: - self.reparam_dict = {self.name: reparam} - else: - self.reparam_dict = {} - - return None - - @staticmethod - def validate(distribution: any) -> None: - """ - Validation of the distribution. - """ - if not isinstance(distribution, numpyro.distributions.Distribution): - raise ValueError( - "distribution should be an instance of " - "numpyro.distributions.Distribution, got " - "{type(distribution)}" - ) - - return None - - def sample( - self, - obs: ArrayLike | None = None, - **kwargs, - ) -> tuple: - """ - Sample from the distribution. - - Parameters - ---------- - obs : ArrayLike, optional - Observations passed as the `obs` argument to - :meth:`numpyro.sample()`. Default `None`. - **kwargs : dict, optional - Additional keyword arguments passed through - to internal sample calls, should there be any. - - Returns - ------- - SampledValue - Containing a sample from the distribution. - """ - with numpyro.handlers.reparam(config=self.reparam_dict): - sample = numpyro.sample( - name=self.name, - fn=self.distribution, - obs=obs, - ) - return ( - SampledValue( - sample, - t_start=self.t_start, - t_unit=self.t_unit, - ), - ) - - def expand_by(self, sample_shape) -> Self: - """ - Expand the distribution by the given sample_shape, - if possible. Returns a new StaticDistributionalRV - whose underlying distribution has been expanded by - the given sample_shape via - :meth:`~numpyro.distributions.Distribution.expand_by()` - - Parameters - ---------- - sample_shape : tuple - Sample shape for the expansion. Passed to the - :meth:`expand_by()` method of - :class:`numpyro.distributions.Distribution`. - - Returns - ------- - StaticDistributionalRV - Whose underlying distribution has been expanded by - the given sample shape. - """ - if not isinstance(sample_shape, tuple): - raise ValueError( - "sample_shape for expand()-ing " - "a DistributionalRV must be a " - f"tuple. Got {type(sample_shape)}" - ) - return StaticDistributionalRV( - name=self.name, - distribution=self.distribution.expand_by(sample_shape), - reparam=self.reparam_dict.get(self.name, None), - ) - - -def DistributionalRV( - name: str, - distribution: numpyro.distributions.Distribution | Callable, - reparam: Reparam = None, -) -> RandomVariable: - """ - Factory function to generate Distributional RandomVariables, - either static or dynamic. - - Parameters - ---------- - name : str - Name of the random variable. - - distribution: numpyro.distributions.Distribution | Callable - Either numpyro.distributions.Distribution instance - given the static distribution of the random variable or - a callable that returns a parameterized - numpyro.distributions.Distribution when called, which - allows for dynamically-parameterized DistributionalRVs, - e.g. a Normal distribution with an inferred location and - scale. - - reparam : numpyro.infer.reparam.Reparam - If not None, reparameterize sampling - from the distribution according to the - given numpyro reparameterizer - - Returns - ------- - DynamicDistributionalRV | StaticDistributionalRV or - raises a ValueError if a distribution cannot be constructed. - """ - if isinstance(distribution, dist.Distribution): - return StaticDistributionalRV( - name=name, distribution=distribution, reparam=reparam - ) - elif callable(distribution): - return DynamicDistributionalRV( - name=name, distribution_constructor=distribution, reparam=reparam - ) - else: - raise ValueError( - "distribution argument to DistributionalRV " - "must be either a numpyro.distributions.Distribution " - "(for instantiating a static DistributionalRV) " - "or a callable that returns a " - "numpyro.distributions.Distribution (for " - "a dynamic DistributionalRV" - ) - - class Model(metaclass=ABCMeta): """Abstract base class for models""" @@ -924,137 +588,3 @@ def prior_predictive( ) return predictive(rng_key, **kwargs) - - -class TransformedRandomVariable(RandomVariable): - """ - Class to represent RandomVariables defined - by taking the output of another RV's - :meth:`RandomVariable.sample()` method - and transforming it by a given transformation - (typically a :class:`Transform`) - """ - - def __init__( - self, - name: str, - base_rv: RandomVariable, - transforms: Transform | tuple[Transform], - ): - """ - Default constructor - - Parameters - ---------- - name : str - A name for the random variable instance. - base_rv : RandomVariable - The underlying (untransformed) RandomVariable. - transforms : Transform - Transformation or tuple of transformations - to apply to the output of - `base_rv.sample()`; single values will be coerced to - a length-one tuple. If a tuple, should be the same - length as the tuple returned by `base_rv.sample()`. - - Returns - ------- - None - """ - self.name = name - self.base_rv = base_rv - - if not isinstance(transforms, tuple): - transforms = (transforms,) - self.transforms = transforms - self.validate() - - def sample(self, record=False, **kwargs) -> tuple: - """ - Sample method. Call self.base_rv.sample() - and then apply the transforms specified - in self.transforms. - - Parameters - ---------- - record : bool, optional - Whether to record the value of the deterministic - RandomVariable. Defaults to False. - **kwargs : - Keyword arguments passed to self.base_rv.sample() - - Returns - ------- - tuple of the same length as the tuple returned by - self.base_rv.sample() - """ - - untransformed_values = self.base_rv.sample(**kwargs) - transformed_values = tuple( - SampledValue( - t(uv.value), - t_start=self.t_start, - t_unit=self.t_unit, - ) - for t, uv in zip(self.transforms, untransformed_values) - ) - - if record: - if len(untransformed_values) == 1: - numpyro.deterministic(self.name, transformed_values[0].value) - else: - suffixes = ( - untransformed_values._fields - if hasattr(untransformed_values, "_fields") - else range(len(transformed_values)) - ) - for suffix, tv in zip(suffixes, transformed_values): - numpyro.deterministic(f"{self.name}_{suffix}", tv.value) - - return transformed_values - - def sample_length(self): - """ - Sample length for a transformed - random variable must be equal to the - length of self.transforms or - validation will fail. - - Returns - ------- - int - Equal to the length self.transforms - """ - return len(self.transforms) - - def validate(self): - """ - Perform validation checks on a - TransformedRandomVariable instance, - confirming that all transformations - are callable and that the number of - transformations is equal to the sample - length of the base random variable. - - Returns - ------- - None - on successful validation, or raise a ValueError - """ - for t in self.transforms: - if not callable(t): - raise ValueError( - "All entries in self.transforms " "must be callable" - ) - if hasattr(self.base_rv, "sample_length"): - n_transforms = len(self.transforms) - n_entries = self.base_rv.sample_length() - if not n_transforms == n_entries: - raise ValueError( - "There must be exactly as many transformations " - "specified as entries self.transforms as there are " - "entries in the tuple returned by " - "self.base_rv.sample()." - f"Got {n_transforms} transforms and {n_entries} " - "entries" - ) diff --git a/pyrenew/randomvariable/__init__.py b/pyrenew/randomvariable/__init__.py new file mode 100644 index 00000000..4f154b2d --- /dev/null +++ b/pyrenew/randomvariable/__init__.py @@ -0,0 +1,17 @@ +# -*- coding: utf-8 -*- + +# numpydoc ignore=GL08 + +from pyrenew.randomvariable.distributionalvariable import ( + DistributionalVariable, + DynamicDistributionalVariable, + StaticDistributionalVariable, +) +from pyrenew.randomvariable.transformedvariable import TransformedVariable + +__all__ = [ + "DistributionalVariable", + "StaticDistributionalVariable", + "DynamicDistributionalVariable", + "TransformedVariable", +] diff --git a/pyrenew/randomvariable/distributionalvariable.py b/pyrenew/randomvariable/distributionalvariable.py new file mode 100644 index 00000000..671dde08 --- /dev/null +++ b/pyrenew/randomvariable/distributionalvariable.py @@ -0,0 +1,342 @@ +# numpydoc ignore=GL08 + +from typing import Callable, Self + +import numpyro +import numpyro.distributions as dist +from jax.typing import ArrayLike +from numpyro.infer.reparam import Reparam + +from pyrenew.metaclass import RandomVariable, SampledValue + + +class DynamicDistributionalVariable(RandomVariable): + """ + Wrapper class for random variables that sample + from a single :class:`numpyro.distributions.Distribution` + that is parameterized / instantiated at `sample()` time + (rather than at RandomVariable instantiation time). + """ + + def __init__( + self, + name: str, + distribution_constructor: Callable, + reparam: Reparam = None, + expand_by_shape: tuple = None, + ) -> None: + """ + Default constructor for DynamicDistributionalVariable. + + Parameters + ---------- + name : str + Name of the random variable. + distribution_constructor : Callable + Callable that returns a concrete parametrized + numpyro.Distributions.distribution instance. + reparam : numpyro.infer.reparam.Reparam + If not None, reparameterize sampling + from the distribution according to the + given numpyro reparameterizer + expand_by_shape : tuple, optional + If not None, call :meth:`expand_by()` on the + underlying distribution once it is instianted + with the given `expand_by_shape`. + Default None. + + Returns + ------- + None + """ + + self.name = name + self.validate(distribution_constructor) + self.distribution_constructor = distribution_constructor + if reparam is not None: + self.reparam_dict = {self.name: reparam} + else: + self.reparam_dict = {} + if not (expand_by_shape is None or isinstance(expand_by_shape, tuple)): + raise ValueError( + "expand_by_shape must be a tuple or be None ", + f"Got {type(expand_by_shape)}", + ) + self.expand_by_shape = expand_by_shape + + return None + + @staticmethod + def validate(distribution_constructor: any) -> None: + """ + Confirm that the distribution_constructor is + callable. + + Parameters + ---------- + distribution_constructor : any + Putative distribution_constructor to validate. + + Returns + ------- + None or raises a ValueError + """ + if not callable(distribution_constructor): + raise ValueError( + "To instantiate a DynamicDistributionalVariable, ", + "one must provide a Callable that returns a " + "numpyro.distributions.Distribution as the " + "distribution_constructor argument. " + f"Got {type(distribution_constructor)}, which " + "does not appear to be callable", + ) + return None + + def sample( + self, + *args, + obs: ArrayLike = None, + **kwargs, + ) -> tuple: + """ + Sample from the distributional rv. + + Parameters + ---------- + *args : + Positional arguments passed to self.distribution_constructor + obs : ArrayLike, optional + Observations passed as the `obs` argument to + :meth:`numpyro.sample()`. Default `None`. + **kwargs : dict, optional + Keyword arguments passed to self.distribution_constructor + + Returns + ------- + SampledValue + Containing a sample from the distribution. + """ + distribution = self.distribution_constructor(*args, **kwargs) + if self.expand_by_shape is not None: + distribution = distribution.expand_by(self.expand_by_shape) + with numpyro.handlers.reparam(config=self.reparam_dict): + sample = numpyro.sample( + name=self.name, + fn=distribution, + obs=obs, + ) + return ( + SampledValue( + sample, + t_start=self.t_start, + t_unit=self.t_unit, + ), + ) + + def expand_by(self, sample_shape) -> Self: + """ + Expand the distribution by a given + shape_shape, if possible. Returns a + new DynamicDistributionalVariable whose underlying + distribution will be expanded by the given shape + at sample() time. + + Parameters + ---------- + sample_shape : tuple + Sample shape by which to expand the distribution. + Passed to the expand_by() method of + :class:`numpyro.distributions.Distribution` + after the distribution is instantiated. + + Returns + ------- + DynamicDistributionalVariable + Whose underlying distribution will be expanded by + the given sample shape at sampling time. + """ + return DynamicDistributionalVariable( + name=self.name, + distribution_constructor=self.distribution_constructor, + reparam=self.reparam_dict.get(self.name, None), + expand_by_shape=sample_shape, + ) + + +class StaticDistributionalVariable(RandomVariable): + """ + Wrapper class for random variables that sample + from a single :class:`numpyro.distributions.Distribution` + that is parameterized / instantiated at RandomVariable + instantiation time (rather than at `sample()`-ing time). + """ + + def __init__( + self, + name: str, + distribution: numpyro.distributions.Distribution, + reparam: Reparam = None, + ) -> None: + """ + Default constructor for DistributionalVariable. + + Parameters + ---------- + name : str + Name of the random variable. + distribution : numpyro.distributions.Distribution + Distribution of the random variable. + reparam : numpyro.infer.reparam.Reparam + If not None, reparameterize sampling + from the distribution according to the + given numpyro reparameterizer + + Returns + ------- + None + """ + + self.name = name + self.validate(distribution) + self.distribution = distribution + if reparam is not None: + self.reparam_dict = {self.name: reparam} + else: + self.reparam_dict = {} + + return None + + @staticmethod + def validate(distribution: any) -> None: + """ + Validation of the distribution. + """ + if not isinstance(distribution, numpyro.distributions.Distribution): + raise ValueError( + "distribution should be an instance of " + "numpyro.distributions.Distribution, got " + "{type(distribution)}" + ) + + return None + + def sample( + self, + obs: ArrayLike | None = None, + **kwargs, + ) -> tuple: + """ + Sample from the distribution. + + Parameters + ---------- + obs : ArrayLike, optional + Observations passed as the `obs` argument to + :meth:`numpyro.sample()`. Default `None`. + **kwargs : dict, optional + Additional keyword arguments passed through + to internal sample calls, should there be any. + + Returns + ------- + SampledValue + Containing a sample from the distribution. + """ + with numpyro.handlers.reparam(config=self.reparam_dict): + sample = numpyro.sample( + name=self.name, + fn=self.distribution, + obs=obs, + ) + return ( + SampledValue( + sample, + t_start=self.t_start, + t_unit=self.t_unit, + ), + ) + + def expand_by(self, sample_shape) -> Self: + """ + Expand the distribution by the given sample_shape, + if possible. Returns a new StaticDistributionalVariable + whose underlying distribution has been expanded by + the given sample_shape via + :meth:`~numpyro.distributions.Distribution.expand_by()` + + Parameters + ---------- + sample_shape : tuple + Sample shape for the expansion. Passed to the + :meth:`expand_by()` method of + :class:`numpyro.distributions.Distribution`. + + Returns + ------- + StaticDistributionalVariable + Whose underlying distribution has been expanded by + the given sample shape. + """ + if not isinstance(sample_shape, tuple): + raise ValueError( + "sample_shape for expand()-ing " + "a DistributionalVariable must be a " + f"tuple. Got {type(sample_shape)}" + ) + return StaticDistributionalVariable( + name=self.name, + distribution=self.distribution.expand_by(sample_shape), + reparam=self.reparam_dict.get(self.name, None), + ) + + +def DistributionalVariable( + name: str, + distribution: numpyro.distributions.Distribution | Callable, + reparam: Reparam = None, +) -> RandomVariable: + """ + Factory function to generate Distributional RandomVariables, + either static or dynamic. + + Parameters + ---------- + name : str + Name of the random variable. + + distribution: numpyro.distributions.Distribution | Callable + Either numpyro.distributions.Distribution instance + given the static distribution of the random variable or + a callable that returns a parameterized + numpyro.distributions.Distribution when called, which + allows for dynamically-parameterized DistributionalVariables, + e.g. a Normal distribution with an inferred location and + scale. + + reparam : numpyro.infer.reparam.Reparam + If not None, reparameterize sampling + from the distribution according to the + given numpyro reparameterizer + + Returns + ------- + DynamicDistributionalVariable | StaticDistributionalVariable or + raises a ValueError if a distribution cannot be constructed. + """ + if isinstance(distribution, dist.Distribution): + return StaticDistributionalVariable( + name=name, distribution=distribution, reparam=reparam + ) + elif callable(distribution): + return DynamicDistributionalVariable( + name=name, distribution_constructor=distribution, reparam=reparam + ) + else: + raise ValueError( + "distribution argument to DistributionalVariable " + "must be either a numpyro.distributions.Distribution " + "(for instantiating a static DistributionalVariable) " + "or a callable that returns a " + "numpyro.distributions.Distribution (for " + "a dynamic DistributionalVariable" + ) diff --git a/pyrenew/randomvariable/transformedvariable.py b/pyrenew/randomvariable/transformedvariable.py new file mode 100644 index 00000000..36519a24 --- /dev/null +++ b/pyrenew/randomvariable/transformedvariable.py @@ -0,0 +1,140 @@ +# numpydoc ignore=GL08 + +import numpyro + +from pyrenew.metaclass import RandomVariable, SampledValue +from pyrenew.transformation import Transform + + +class TransformedVariable(RandomVariable): + """ + Class to represent RandomVariables defined + by taking the output of another RV's + :meth:`RandomVariable.sample()` method + and transforming it by a given transformation + (typically a :class:`Transform`) + """ + + def __init__( + self, + name: str, + base_rv: RandomVariable, + transforms: Transform | tuple[Transform], + ): + """ + Default constructor + + Parameters + ---------- + name : str + A name for the random variable instance. + base_rv : RandomVariable + The underlying (untransformed) RandomVariable. + transforms : Transform + Transformation or tuple of transformations + to apply to the output of + `base_rv.sample()`; single values will be coerced to + a length-one tuple. If a tuple, should be the same + length as the tuple returned by `base_rv.sample()`. + + Returns + ------- + None + """ + self.name = name + self.base_rv = base_rv + + if not isinstance(transforms, tuple): + transforms = (transforms,) + self.transforms = transforms + self.validate() + + def sample(self, record=False, **kwargs) -> tuple: + """ + Sample method. Call self.base_rv.sample() + and then apply the transforms specified + in self.transforms. + + Parameters + ---------- + record : bool, optional + Whether to record the value of the deterministic + RandomVariable. Defaults to False. + **kwargs : + Keyword arguments passed to self.base_rv.sample() + + Returns + ------- + tuple of the same length as the tuple returned by + self.base_rv.sample() + """ + + untransformed_values = self.base_rv.sample(**kwargs) + transformed_values = tuple( + SampledValue( + t(uv.value), + t_start=self.t_start, + t_unit=self.t_unit, + ) + for t, uv in zip(self.transforms, untransformed_values) + ) + + if record: + if len(untransformed_values) == 1: + numpyro.deterministic(self.name, transformed_values[0].value) + else: + suffixes = ( + untransformed_values._fields + if hasattr(untransformed_values, "_fields") + else range(len(transformed_values)) + ) + for suffix, tv in zip(suffixes, transformed_values): + numpyro.deterministic(f"{self.name}_{suffix}", tv.value) + + return transformed_values + + def sample_length(self): + """ + Sample length for a transformed + random variable must be equal to the + length of self.transforms or + validation will fail. + + Returns + ------- + int + Equal to the length self.transforms + """ + return len(self.transforms) + + def validate(self): + """ + Perform validation checks on a + TransformedVariable instance, + confirming that all transformations + are callable and that the number of + transformations is equal to the sample + length of the base random variable. + + Returns + ------- + None + on successful validation, or raise a ValueError + """ + for t in self.transforms: + if not callable(t): + raise ValueError( + "All entries in self.transforms " "must be callable" + ) + if hasattr(self.base_rv, "sample_length"): + n_transforms = len(self.transforms) + n_entries = self.base_rv.sample_length() + if not n_transforms == n_entries: + raise ValueError( + "There must be exactly as many transformations " + "specified as entries self.transforms as there are " + "entries in the tuple returned by " + "self.base_rv.sample()." + f"Got {n_transforms} transforms and {n_entries} " + "entries" + ) From 22509cfa147f3d41bf61fefebb6244ededbca9ed Mon Sep 17 00:00:00 2001 From: sbidari Date: Thu, 22 Aug 2024 17:26:10 -0400 Subject: [PATCH 11/16] make suffixes across variables unifrom --- docs/source/tutorials/basic_renewal_model.qmd | 14 ++--- docs/source/tutorials/day_of_the_week.qmd | 22 ++++---- docs/source/tutorials/extending_pyrenew.qmd | 12 ++--- .../tutorials/hospital_admissions_model.qmd | 16 +++--- docs/source/tutorials/periodic_effects.qmd | 2 +- pyrenew/process/iidrandomsequence.py | 10 ++-- pyrenew/process/randomwalk.py | 6 +-- test/test_assert_sample_and_rtype.py | 4 +- test/test_assert_type.py | 8 ++- test/test_differenced_process.py | 4 +- test/test_distributional_rv.py | 53 ++++++++++--------- test/test_forecast.py | 4 +- test/test_iid_random_sequence.py | 10 ++-- test/test_infection_initialization_process.py | 6 +-- test/test_latent_admissions.py | 4 +- test/test_model_basic_renewal.py | 14 ++--- test/test_model_hosp_admissions.py | 30 ++++++----- test/test_predictive.py | 4 +- test/test_random_key.py | 4 +- test/test_random_walk.py | 8 +-- test/test_transformed_rv_class.py | 42 ++++++++------- test/utils.py | 10 ++-- 22 files changed, 152 insertions(+), 135 deletions(-) diff --git a/docs/source/tutorials/basic_renewal_model.qmd b/docs/source/tutorials/basic_renewal_model.qmd index e9a3fcba..7d40ed93 100644 --- a/docs/source/tutorials/basic_renewal_model.qmd +++ b/docs/source/tutorials/basic_renewal_model.qmd @@ -27,8 +27,8 @@ from pyrenew.deterministic import DeterministicPMF from pyrenew.model import RtInfectionsRenewalModel from pyrenew.metaclass import ( RandomVariable, - DistributionalRV, - TransformedRandomVariable, + DistributionalVariable, + TransformedVariable, ) import pyrenew.transformation as t from numpyro.infer.reparam import LocScaleReparam @@ -64,7 +64,7 @@ flowchart LR subgraph latent[Latent module] inf["latent_infections_rv\n(Infections)"] - i0["I0_rv\n(DistributionalRV)"] + i0["I0_rv\n(DistributionalVariable)"] end subgraph process[Process module] @@ -126,7 +126,7 @@ gen_int = DeterministicPMF(name="gen_int", value=pmf_array) # (2) Initial infections (inferred with a prior) I0 = InfectionInitializationProcess( "I0_initialization", - DistributionalRV(name="I0", distribution=dist.LogNormal(2.5, 1)), + DistributionalVariable(name="I0", distribution=dist.LogNormal(2.5, 1)), InitializeInfectionsZeroPad(pmf_array.size), t_unit=1, ) @@ -142,17 +142,17 @@ class MyRt(RandomVariable): def sample(self, n: int, **kwargs) -> tuple: sd_rt = numpyro.sample("Rt_random_walk_sd", dist.HalfNormal(0.025)) - rt_rv = TransformedRandomVariable( + rt_rv = TransformedVariable( name="log_rt_random_walk", base_rv=RandomWalk( name="log_rt", - step_rv=DistributionalRV( + step_rv=DistributionalVariable( name="rw_step_rv", distribution=dist.Normal(0, 0.025) ), ), transforms=t.ExpTransform(), ) - rt_init_rv = DistributionalRV( + rt_init_rv = DistributionalVariable( name="init_log_rt", distribution=dist.Normal(0, 0.2) ) init_rt, *_ = rt_init_rv.sample() diff --git a/docs/source/tutorials/day_of_the_week.qmd b/docs/source/tutorials/day_of_the_week.qmd index d1c47e09..4deb88dd 100644 --- a/docs/source/tutorials/day_of_the_week.qmd +++ b/docs/source/tutorials/day_of_the_week.qmd @@ -59,7 +59,7 @@ inf_hosp_int = deterministic.DeterministicPMF( name="inf_hosp_int", value=inf_hosp_int_array ) -hosp_rate = metaclass.DistributionalRV( +hosp_rate = metaclass.DistributionalVariable( name="IHR", distribution=dist.LogNormal(jnp.log(0.05), jnp.log(1.1)) ) @@ -81,7 +81,7 @@ n_initialization_points = max(gen_int_array.size, inf_hosp_int_array.size) I0 = InfectionInitializationProcess( "I0_initialization", - metaclass.DistributionalRV( + metaclass.DistributionalVariable( name="I0", distribution=dist.LogNormal(loc=jnp.log(100), scale=jnp.log(1.75)), ), @@ -113,11 +113,11 @@ class MyRt(metaclass.RandomVariable): sd_rt, *_ = self.sd_rv() # Random walk step - step_rv = metaclass.DistributionalRV( + step_rv = metaclass.DistributionalVariable( name="rw_step_rv", distribution=dist.Normal(0, sd_rt.value) ) - rt_init_rv = metaclass.DistributionalRV( + rt_init_rv = metaclass.DistributionalVariable( name="init_log_rt", distribution=dist.Normal(0, 0.2) ) @@ -128,7 +128,7 @@ class MyRt(metaclass.RandomVariable): ) # Transforming the random walk to the Rt scale - rt_rv = metaclass.TransformedRandomVariable( + rt_rv = metaclass.TransformedVariable( name="Rt_rv", base_rv=base_rv, transforms=transformation.ExpTransform(), @@ -139,7 +139,7 @@ class MyRt(metaclass.RandomVariable): rtproc = MyRt( - metaclass.DistributionalRV( + metaclass.DistributionalVariable( name="Rt_random_walk_sd", distribution=dist.HalfNormal(0.025) ) ) @@ -152,9 +152,9 @@ rtproc = MyRt( # | code-fold: true # we place a log-Normal prior on the concentration # parameter of the negative binomial. -nb_conc_rv = metaclass.TransformedRandomVariable( +nb_conc_rv = metaclass.TransformedVariable( "concentration", - metaclass.DistributionalRV( + metaclass.DistributionalVariable( name="concentration_raw", distribution=dist.TruncatedNormal(loc=0, scale=1, low=0.01), ), @@ -212,16 +212,16 @@ out = hosp_model.plot_posterior( We will re-use the infection to admission interval and infection to hospitalization rate from the previous model. But we will also add a day-of-the-week effect. To do this, we will add two additional arguments to the latent hospital admissions random variable: `day_of_the_week_rv` (a `RandomVariable`) and `obs_data_first_day_of_the_week` (an `int` mapping days of the week from 0:6, zero being Monday). The `day_of_the_week_rv`'s sample method should return a vector of length seven; those values are then broadcasted to match the length of the dataset. Moreover, since the observed data may start in a weekday other than Monday, the `obs_data_first_day_of_the_week` argument is used to offset the day-of-the-week effect. -For this example, the effect will be passed as a scaled Dirichlet distribution. It will consist of a `TransformedRandomVariable` that samples an array of length seven from numpyro's `distributions.Dirichlet` and applies a `transformation.AffineTransform` to scale it by seven. [^note-other-examples]: +For this example, the effect will be passed as a scaled Dirichlet distribution. It will consist of a `TransformedVariable` that samples an array of length seven from numpyro's `distributions.Dirichlet` and applies a `transformation.AffineTransform` to scale it by seven. [^note-other-examples]: [^note-other-examples]: A similar weekday effect is implemented in its own module, with example code [here](periodic_effects.html). ```{python} # | label: weekly-effect # Instantiating the day-of-the-week effect -dayofweek_effect = metaclass.TransformedRandomVariable( +dayofweek_effect = metaclass.TransformedVariable( name="dayofweek_effect", - base_rv=metaclass.DistributionalRV( + base_rv=metaclass.DistributionalVariable( name="dayofweek_effect_raw", distribution=dist.Dirichlet(jnp.ones(7)), ), diff --git a/docs/source/tutorials/extending_pyrenew.qmd b/docs/source/tutorials/extending_pyrenew.qmd index 14615485..18639c55 100644 --- a/docs/source/tutorials/extending_pyrenew.qmd +++ b/docs/source/tutorials/extending_pyrenew.qmd @@ -31,8 +31,8 @@ from pyrenew.model import RtInfectionsRenewalModel from pyrenew.process import RandomWalk from pyrenew.metaclass import ( RandomVariable, - DistributionalRV, - TransformedRandomVariable, + DistributionalVariable, + TransformedVariable, ) from pyrenew.latent import ( InfectionInitializationProcess, @@ -53,7 +53,7 @@ feedback_strength = DeterministicVariable(name="feedback_strength", value=0.01) I0 = InfectionInitializationProcess( "I0_initialization", - DistributionalRV(name="I0", distribution=dist.LogNormal(0, 1)), + DistributionalVariable(name="I0", distribution=dist.LogNormal(0, 1)), InitializeInfectionsExponentialGrowth( gen_int_array.size, DeterministicVariable(name="rate", value=0.05), @@ -75,17 +75,17 @@ class MyRt(RandomVariable): def sample(self, n: int, **kwargs) -> tuple: sd_rt = numpyro.sample("Rt_random_walk_sd", dist.HalfNormal(0.025)) - rt_rv = TransformedRandomVariable( + rt_rv = TransformedVariable( name="log_rt_random_walk", base_rv=RandomWalk( name="log_rt", - step_rv=DistributionalRV( + step_rv=DistributionalVariable( name="rw_step_rv", distribution=dist.Normal(0, 0.025) ), ), transforms=t.ExpTransform(), ) - rt_init_rv = DistributionalRV( + rt_init_rv = DistributionalVariable( name="init_log_rt", distribution=dist.Normal(0, 0.2) ) init_rt, *_ = rt_init_rv.sample() diff --git a/docs/source/tutorials/hospital_admissions_model.qmd b/docs/source/tutorials/hospital_admissions_model.qmd index d7fb2c6f..33999f5c 100644 --- a/docs/source/tutorials/hospital_admissions_model.qmd +++ b/docs/source/tutorials/hospital_admissions_model.qmd @@ -145,7 +145,7 @@ inf_hosp_int = deterministic.DeterministicPMF( name="inf_hosp_int", value=inf_hosp_int_array ) -hosp_rate = metaclass.DistributionalRV( +hosp_rate = metaclass.DistributionalVariable( name="IHR", distribution=dist.LogNormal(jnp.log(0.05), jnp.log(1.1)) ) @@ -155,7 +155,7 @@ latent_hosp = latent.HospitalAdmissions( ) ``` -The `inf_hosp_int` is a `DeterministicPMF` object that takes the infection to hospital admission interval as input. The `hosp_rate` is a `DistributionalRV` object that takes a numpyro distribution to represent the infection to hospital admission rate. The `HospitalAdmissions` class is a `RandomVariable` that takes two distributions as inputs: the infection to admission interval and the infection to hospital admission rate. Now, we can define the rest of the other components: +The `inf_hosp_int` is a `DeterministicPMF` object that takes the infection to hospital admission interval as input. The `hosp_rate` is a `DistributionalVariable` object that takes a numpyro distribution to represent the infection to hospital admission rate. The `HospitalAdmissions` class is a `RandomVariable` that takes two distributions as inputs: the infection to admission interval and the infection to hospital admission rate. Now, we can define the rest of the other components: ```{python} # | label: initializing-rest-of-model @@ -171,7 +171,7 @@ latent_inf = latent.Infections() n_initialization_points = max(gen_int_array.size, inf_hosp_int_array.size) I0 = InfectionInitializationProcess( "I0_initialization", - metaclass.DistributionalRV( + metaclass.DistributionalVariable( name="I0", distribution=dist.LogNormal(loc=jnp.log(100), scale=jnp.log(1.75)), ), @@ -194,17 +194,17 @@ class MyRt(metaclass.RandomVariable): def sample(self, n: int, **kwargs) -> tuple: sd_rt = numpyro.sample("Rt_random_walk_sd", dist.HalfNormal(0.025)) - rt_rv = metaclass.TransformedRandomVariable( + rt_rv = metaclass.TransformedVariable( name="log_rt_random_walk", base_rv=process.RandomWalk( name="log_rt", - step_rv=metaclass.DistributionalRV( + step_rv=metaclass.DistributionalVariable( name="rw_step_rv", distribution=dist.Normal(0, 0.025) ), ), transforms=transformation.ExpTransform(), ) - rt_init_rv = metaclass.DistributionalRV( + rt_init_rv = metaclass.DistributionalVariable( name="init_log_rt", distribution=dist.Normal(0, 0.2) ) init_rt, *_ = rt_init_rv.sample() @@ -218,9 +218,9 @@ rtproc = MyRt() # we place a log-Normal prior on the concentration # parameter of the negative binomial. -nb_conc_rv = metaclass.TransformedRandomVariable( +nb_conc_rv = metaclass.TransformedVariable( "concentration", - metaclass.DistributionalRV( + metaclass.DistributionalVariable( name="concentration_raw", distribution=dist.TruncatedNormal(loc=0, scale=1, low=0.01), ), diff --git a/docs/source/tutorials/periodic_effects.qmd b/docs/source/tutorials/periodic_effects.qmd index bfe3e30d..e0a16847 100644 --- a/docs/source/tutorials/periodic_effects.qmd +++ b/docs/source/tutorials/periodic_effects.qmd @@ -76,7 +76,7 @@ mysimplex = dist.TransformedDistribution( # Constructing the day of week effect dayofweek = process.DayOfWeekEffect( offset=0, - quantity_to_broadcast=metaclass.DistributionalRV( + quantity_to_broadcast=metaclass.DistributionalVariable( name="simp", distribution=mysimplex ), t_start=0, diff --git a/pyrenew/process/iidrandomsequence.py b/pyrenew/process/iidrandomsequence.py index 2f868ada..01c0c8ff 100644 --- a/pyrenew/process/iidrandomsequence.py +++ b/pyrenew/process/iidrandomsequence.py @@ -4,7 +4,11 @@ import numpyro.distributions as dist from numpyro.contrib.control_flow import scan -from pyrenew.metaclass import DistributionalRV, RandomVariable, SampledValue +from pyrenew.metaclass import ( + DistributionalVariable, + RandomVariable, + SampledValue, +) class IIDRandomSequence(RandomVariable): @@ -130,7 +134,7 @@ def __init__( see :class:`IIDRandomSequence`. element_rv_name: str Name for the internal element_rv, here a - DistributionalRV encoding a + DistributionalVariable encoding a standard Normal (mean = 0, sd = 1) distribution. @@ -139,7 +143,7 @@ def __init__( None """ super().__init__( - element_rv=DistributionalRV( + element_rv=DistributionalVariable( name=element_rv_name, distribution=dist.Normal(0, 1) ), ) diff --git a/pyrenew/process/randomwalk.py b/pyrenew/process/randomwalk.py index a9fa472e..035b7cf4 100644 --- a/pyrenew/process/randomwalk.py +++ b/pyrenew/process/randomwalk.py @@ -3,7 +3,7 @@ import numpyro.distributions as dist -from pyrenew.metaclass import DistributionalRV, RandomVariable +from pyrenew.metaclass import DistributionalVariable, RandomVariable from pyrenew.process.differencedprocess import DifferencedProcess from pyrenew.process.iidrandomsequence import IIDRandomSequence @@ -69,7 +69,7 @@ def __init__( Parameters ---------- step_rv_name : - Name for the DistributionalRV + Name for the DistributionalVariable from which the Normal(0, 1) steps are sampled. **kwargs: @@ -80,7 +80,7 @@ def __init__( None """ super().__init__( - step_rv=DistributionalRV( + step_rv=DistributionalVariable( name=step_rv_name, distribution=dist.Normal(0.0, 1.0) ), **kwargs, diff --git a/test/test_assert_sample_and_rtype.py b/test/test_assert_sample_and_rtype.py index 69a59f0f..c0c7b4ec 100644 --- a/test/test_assert_sample_and_rtype.py +++ b/test/test_assert_sample_and_rtype.py @@ -9,7 +9,7 @@ from pyrenew.deterministic import DeterministicVariable, NullObservation from pyrenew.metaclass import ( - DistributionalRV, + DistributionalVariable, RandomVariable, SampledValue, _assert_sample_and_rtype, @@ -93,7 +93,7 @@ def test_input_rv(): # numpydoc ignore=GL08 valid_rv = [ NullObservation(), DeterministicVariable(name="rv1", value=jnp.array([1, 2, 3, 4])), - DistributionalRV(name="rv2", distribution=dist.Normal(0, 1)), + DistributionalVariable(name="rv2", distribution=dist.Normal(0, 1)), ] not_rv = jnp.array([1]) diff --git a/test/test_assert_type.py b/test/test_assert_type.py index 7a41cdc8..0a5cf67b 100644 --- a/test/test_assert_type.py +++ b/test/test_assert_type.py @@ -3,7 +3,11 @@ import numpyro.distributions as dist import pytest -from pyrenew.metaclass import DistributionalRV, RandomVariable, _assert_type +from pyrenew.metaclass import ( + DistributionalVariable, + RandomVariable, + _assert_type, +) def test_valid_assertion_types(): @@ -15,7 +19,7 @@ def test_valid_assertion_types(): 5, "Hello", (1,), - DistributionalRV(name="rv", distribution=dist.Beta(1, 1)), + DistributionalVariable(name="rv", distribution=dist.Beta(1, 1)), ] arg_names = ["input_int", "input_string", "input_tuple", "input_rv"] input_types = [int, str, tuple, RandomVariable] diff --git a/test/test_differenced_process.py b/test/test_differenced_process.py index 63c28073..d4e710eb 100644 --- a/test/test_differenced_process.py +++ b/test/test_differenced_process.py @@ -10,7 +10,7 @@ from numpy.testing import assert_array_almost_equal from pyrenew.deterministic import DeterministicVariable, NullVariable -from pyrenew.metaclass import DistributionalRV +from pyrenew.metaclass import DistributionalVariable from pyrenew.process import ( DifferencedProcess, IIDRandomSequence, @@ -155,7 +155,7 @@ def test_manual_integrator_correctness(diffs, inits, expected_solution): [ [ IIDRandomSequence( - DistributionalRV("element_dist", dist.Cauchy(0.02, 0.3)), + DistributionalVariable("element_dist", dist.Cauchy(0.02, 0.3)), ), 3, jnp.array([0.25, 0.67, 5]), diff --git a/test/test_distributional_rv.py b/test/test_distributional_rv.py index 0a0b4d2c..7ddc3ad8 100644 --- a/test/test_distributional_rv.py +++ b/test/test_distributional_rv.py @@ -1,6 +1,7 @@ """ Tests for the distributional RV classes """ + import jax.numpy as jnp import numpyro import numpyro.distributions as dist @@ -9,16 +10,16 @@ from numpyro.distributions import ExpandedDistribution from pyrenew.metaclass import ( - DistributionalRV, - DynamicDistributionalRV, - StaticDistributionalRV, + DistributionalVariable, + DynamicDistributionalVariable, + StaticDistributionalVariable, ) class NonCallableTestClass: """ Generic non-callable object to test - callable checking for DynamicDistributionalRV. + callable checking for DynamicDistributionalVariable. """ def __init__(self): @@ -37,9 +38,11 @@ def test_invalid_constructor_args(not_a_dist): """ with pytest.raises( - ValueError, match="distribution argument to DistributionalRV" + ValueError, match="distribution argument to DistributionalVariable" ): - DistributionalRV(name="this should fail", distribution=not_a_dist) + DistributionalVariable( + name="this should fail", distribution=not_a_dist + ) with pytest.raises( ValueError, match=( @@ -47,9 +50,9 @@ def test_invalid_constructor_args(not_a_dist): "numpyro.distributions.Distribution" ), ): - StaticDistributionalRV.validate(not_a_dist) + StaticDistributionalVariable.validate(not_a_dist) with pytest.raises(ValueError, match="must provide a Callable"): - DynamicDistributionalRV.validate(not_a_dist) + DynamicDistributionalVariable.validate(not_a_dist) @pytest.mark.parametrize( @@ -63,18 +66,18 @@ def test_invalid_constructor_args(not_a_dist): def test_factory_triage(valid_static_dist_arg, valid_dynamic_dist_arg): """ Test that passing a numpyro.distributions.Distribution - instance to the DistributionalRV factory instaniates - a StaticDistributionalRV, while passing a callable - instaniates a DynamicDistributionalRV + instance to the DistributionalVariable factory instaniates + a StaticDistributionalVariable, while passing a callable + instaniates a DynamicDistributionalVariable """ - static = DistributionalRV( + static = DistributionalVariable( name="test static", distribution=valid_static_dist_arg ) - assert isinstance(static, StaticDistributionalRV) - dynamic = DistributionalRV( + assert isinstance(static, StaticDistributionalVariable) + dynamic = DistributionalVariable( name="test dynamic", distribution=valid_dynamic_dist_arg ) - assert isinstance(dynamic, DynamicDistributionalRV) + assert isinstance(dynamic, DynamicDistributionalVariable) @pytest.mark.parametrize( @@ -97,12 +100,12 @@ def test_expand_by(dist, params, expand_by_shape): Test the expand_by method for static distributional RVs. """ - static = DistributionalRV(name="static", distribution=dist(**params)) - dynamic = DistributionalRV(name="dynamic", distribution=dist) + static = DistributionalVariable(name="static", distribution=dist(**params)) + dynamic = DistributionalVariable(name="dynamic", distribution=dist) expanded_static = static.expand_by(expand_by_shape) expanded_dynamic = dynamic.expand_by(expand_by_shape) - assert isinstance(expanded_dynamic, DynamicDistributionalRV) + assert isinstance(expanded_dynamic, DynamicDistributionalVariable) assert dynamic.expand_by_shape is None assert isinstance(expanded_dynamic.expand_by_shape, tuple) assert expanded_dynamic.expand_by_shape == expand_by_shape @@ -112,7 +115,7 @@ def test_expand_by(dist, params, expand_by_shape): == expanded_dynamic.distribution_constructor ) - assert isinstance(expanded_static, StaticDistributionalRV) + assert isinstance(expanded_static, StaticDistributionalVariable) assert isinstance(expanded_static.distribution, ExpandedDistribution) assert expanded_static.distribution.batch_shape == ( expand_by_shape + static.distribution.batch_shape @@ -140,15 +143,15 @@ def test_expand_by(dist, params, expand_by_shape): ) def test_sampling_equivalent(dist, params): """ - Test that sampling a DynamicDistributionalRV + Test that sampling a DynamicDistributionalVariable with a given parameterization is equivalent to - sampling a StaticDistributionalRV with the + sampling a StaticDistributionalVariable with the same parameterization and the same random seed """ - static = DistributionalRV(name="static", distribution=dist(**params)) - dynamic = DistributionalRV(name="dynamic", distribution=dist) - assert isinstance(static, StaticDistributionalRV) - assert isinstance(dynamic, DynamicDistributionalRV) + static = DistributionalVariable(name="static", distribution=dist(**params)) + dynamic = DistributionalVariable(name="dynamic", distribution=dist) + assert isinstance(static, StaticDistributionalVariable) + assert isinstance(dynamic, DynamicDistributionalVariable) with numpyro.handlers.seed(rng_seed=5): static_samp, *_ = static() with numpyro.handlers.seed(rng_seed=5): diff --git a/test/test_forecast.py b/test/test_forecast.py index beef0273..2191f092 100644 --- a/test/test_forecast.py +++ b/test/test_forecast.py @@ -14,7 +14,7 @@ Infections, InitializeInfectionsZeroPad, ) -from pyrenew.metaclass import DistributionalRV +from pyrenew.metaclass import DistributionalVariable from pyrenew.model import RtInfectionsRenewalModel from pyrenew.observation import PoissonObservation @@ -28,7 +28,7 @@ def test_forecast(): gen_int = DeterministicPMF(name="gen_int", value=pmf_array) I0 = InfectionInitializationProcess( "I0_initialization", - DistributionalRV(name="I0", distribution=dist.LogNormal(0, 1)), + DistributionalVariable(name="I0", distribution=dist.LogNormal(0, 1)), InitializeInfectionsZeroPad(n_timepoints=gen_int.size()), t_unit=1, ) diff --git a/test/test_iid_random_sequence.py b/test/test_iid_random_sequence.py index eb6d943c..320d9603 100755 --- a/test/test_iid_random_sequence.py +++ b/test/test_iid_random_sequence.py @@ -7,9 +7,9 @@ from scipy.stats import kstest from pyrenew.metaclass import ( - DistributionalRV, + DistributionalVariable, SampledValue, - StaticDistributionalRV, + StaticDistributionalVariable, ) from pyrenew.process import IIDRandomSequence, StandardNormalSequence @@ -29,7 +29,7 @@ def test_iidrandomsequence_with_dist_rv(distribution, n): a distributional RV, including with array-valued distributions """ - element_rv = DistributionalRV("el_rv", distribution=distribution) + element_rv = DistributionalVariable("el_rv", distribution=distribution) rseq = IIDRandomSequence(element_rv=element_rv) if distribution.batch_shape == () or distribution.batch_shape == (1,): expected_shape = (n,) @@ -63,9 +63,9 @@ def test_standard_normal_sequence(): """ norm_seq = StandardNormalSequence("test_norm_elements") - # should be implemented with a DistributionalRV + # should be implemented with a DistributionalVariable # that is a standard normal - assert isinstance(norm_seq.element_rv, StaticDistributionalRV) + assert isinstance(norm_seq.element_rv, StaticDistributionalVariable) assert isinstance(norm_seq.element_rv.distribution, dist.Normal) assert norm_seq.element_rv.distribution.loc == 0.0 assert norm_seq.element_rv.distribution.scale == 1.0 diff --git a/test/test_infection_initialization_process.py b/test/test_infection_initialization_process.py index afe91ef6..235f6966 100644 --- a/test/test_infection_initialization_process.py +++ b/test/test_infection_initialization_process.py @@ -11,7 +11,7 @@ InitializeInfectionsFromVec, InitializeInfectionsZeroPad, ) -from pyrenew.metaclass import DistributionalRV +from pyrenew.metaclass import DistributionalVariable def test_infection_initialization_process(): @@ -20,14 +20,14 @@ def test_infection_initialization_process(): zero_pad_model = InfectionInitializationProcess( "zero_pad_model", - DistributionalRV(name="I0", distribution=dist.LogNormal(0, 1)), + DistributionalVariable(name="I0", distribution=dist.LogNormal(0, 1)), InitializeInfectionsZeroPad(n_timepoints), t_unit=1, ) exp_model = InfectionInitializationProcess( "exp_model", - DistributionalRV(name="I0", distribution=dist.LogNormal(0, 1)), + DistributionalVariable(name="I0", distribution=dist.LogNormal(0, 1)), InitializeInfectionsExponentialGrowth( n_timepoints, DeterministicVariable(name="rate", value=0.5) ), diff --git a/test/test_latent_admissions.py b/test/test_latent_admissions.py index 526fbc31..4eabfbd2 100644 --- a/test/test_latent_admissions.py +++ b/test/test_latent_admissions.py @@ -10,7 +10,7 @@ from pyrenew.deterministic import DeterministicPMF, DeterministicVariable from pyrenew.latent import HospitalAdmissions, Infections -from pyrenew.metaclass import DistributionalRV, SampledValue +from pyrenew.metaclass import DistributionalVariable, SampledValue def test_admissions_sample(): @@ -64,7 +64,7 @@ def test_admissions_sample(): hosp1 = HospitalAdmissions( infection_to_admission_interval_rv=inf_hosp, - infection_hospitalization_ratio_rv=DistributionalRV( + infection_hospitalization_ratio_rv=DistributionalVariable( name="IHR", distribution=dist.LogNormal(jnp.log(0.05), 0.05) ), ) diff --git a/test/test_model_basic_renewal.py b/test/test_model_basic_renewal.py index ffe09cd4..210c1972 100644 --- a/test/test_model_basic_renewal.py +++ b/test/test_model_basic_renewal.py @@ -18,7 +18,7 @@ Infections, InitializeInfectionsZeroPad, ) -from pyrenew.metaclass import DistributionalRV +from pyrenew.metaclass import DistributionalVariable from pyrenew.model import RtInfectionsRenewalModel from pyrenew.observation import PoissonObservation @@ -36,7 +36,7 @@ def test_model_basicrenewal_no_timepoints_or_observations(): I0_init_rv = InfectionInitializationProcess( "I0_initialization", - DistributionalRV(name="I0", distribution=dist.LogNormal(0, 1)), + DistributionalVariable(name="I0", distribution=dist.LogNormal(0, 1)), InitializeInfectionsZeroPad(n_timepoints=gen_int.size()), t_unit=1, ) @@ -72,7 +72,7 @@ def test_model_basicrenewal_both_timepoints_and_observations(): I0_init_rv = InfectionInitializationProcess( "I0_initialization", - DistributionalRV(name="I0", distribution=dist.LogNormal(0, 1)), + DistributionalVariable(name="I0", distribution=dist.LogNormal(0, 1)), InitializeInfectionsZeroPad(n_timepoints=gen_int.size()), t_unit=1, ) @@ -111,11 +111,11 @@ def test_model_basicrenewal_no_obs_model(): ) with pytest.raises(ValueError): - _ = DistributionalRV(name="I0", distribution=1) + _ = DistributionalVariable(name="I0", distribution=1) I0_init_rv = InfectionInitializationProcess( "I0_initialization", - DistributionalRV(name="I0", distribution=dist.LogNormal(0, 1)), + DistributionalVariable(name="I0", distribution=dist.LogNormal(0, 1)), InitializeInfectionsZeroPad(n_timepoints=gen_int.size()), t_unit=1, ) @@ -186,7 +186,7 @@ def test_model_basicrenewal_with_obs_model(): I0_init_rv = InfectionInitializationProcess( "I0_initialization", - DistributionalRV(name="I0", distribution=dist.LogNormal(0, 1)), + DistributionalVariable(name="I0", distribution=dist.LogNormal(0, 1)), InitializeInfectionsZeroPad(n_timepoints=gen_int.size()), t_unit=1, ) @@ -240,7 +240,7 @@ def test_model_basicrenewal_padding() -> None: # numpydoc ignore=GL08 I0_init_rv = InfectionInitializationProcess( "I0_initialization", - DistributionalRV(name="I0", distribution=dist.LogNormal(0, 1)), + DistributionalVariable(name="I0", distribution=dist.LogNormal(0, 1)), InitializeInfectionsZeroPad(n_timepoints=gen_int.size()), t_unit=1, ) diff --git a/test/test_model_hosp_admissions.py b/test/test_model_hosp_admissions.py index 293c8ea5..7c99ca13 100644 --- a/test/test_model_hosp_admissions.py +++ b/test/test_model_hosp_admissions.py @@ -23,7 +23,11 @@ Infections, InitializeInfectionsZeroPad, ) -from pyrenew.metaclass import DistributionalRV, RandomVariable, SampledValue +from pyrenew.metaclass import ( + DistributionalVariable, + RandomVariable, + SampledValue, +) from pyrenew.model import HospitalAdmissionsModel from pyrenew.observation import PoissonObservation @@ -91,7 +95,7 @@ def test_model_hosp_no_timepoints_or_observations(): ), ) - I0 = DistributionalRV(name="I0", distribution=dist.LogNormal(0, 1)) + I0 = DistributionalVariable(name="I0", distribution=dist.LogNormal(0, 1)) latent_infections = Infections() Rt_process = SimpleRt() @@ -100,7 +104,7 @@ def test_model_hosp_no_timepoints_or_observations(): latent_admissions = HospitalAdmissions( infection_to_admission_interval_rv=inf_hosp, - infection_hospitalization_ratio_rv=DistributionalRV( + infection_hospitalization_ratio_rv=DistributionalVariable( name="IHR", distribution=dist.LogNormal(jnp.log(0.05), 0.05) ), ) @@ -156,7 +160,7 @@ def test_model_hosp_both_timepoints_and_observations(): ), ) - I0 = DistributionalRV(name="I0", distribution=dist.LogNormal(0, 1)) + I0 = DistributionalVariable(name="I0", distribution=dist.LogNormal(0, 1)) latent_infections = Infections() Rt_process = SimpleRt() @@ -164,7 +168,7 @@ def test_model_hosp_both_timepoints_and_observations(): latent_admissions = HospitalAdmissions( infection_to_admission_interval_rv=inf_hosp, - infection_hospitalization_ratio_rv=DistributionalRV( + infection_hospitalization_ratio_rv=DistributionalVariable( name="IHR", distribution=dist.LogNormal(jnp.log(0.05), 0.05) ), ) @@ -226,7 +230,7 @@ def test_model_hosp_no_obs_model(): I0 = InfectionInitializationProcess( "I0_initialization", - DistributionalRV(name="I0", distribution=dist.LogNormal(0, 1)), + DistributionalVariable(name="I0", distribution=dist.LogNormal(0, 1)), InitializeInfectionsZeroPad(n_timepoints=n_initialization_points), t_unit=1, ) @@ -236,7 +240,7 @@ def test_model_hosp_no_obs_model(): latent_admissions = HospitalAdmissions( infection_to_admission_interval_rv=inf_hosp, - infection_hospitalization_ratio_rv=DistributionalRV( + infection_hospitalization_ratio_rv=DistributionalVariable( name="IHR", distribution=dist.LogNormal(jnp.log(0.05), 0.05), ), @@ -338,7 +342,7 @@ def test_model_hosp_with_obs_model(): I0 = InfectionInitializationProcess( "I0_initialization", - DistributionalRV(name="I0", distribution=dist.LogNormal(0, 1)), + DistributionalVariable(name="I0", distribution=dist.LogNormal(0, 1)), InitializeInfectionsZeroPad(n_timepoints=n_initialization_points), t_unit=1, ) @@ -349,7 +353,7 @@ def test_model_hosp_with_obs_model(): latent_admissions = HospitalAdmissions( infection_to_admission_interval_rv=inf_hosp, - infection_hospitalization_ratio_rv=DistributionalRV( + infection_hospitalization_ratio_rv=DistributionalVariable( name="IHR", distribution=dist.LogNormal(jnp.log(0.05), 0.05), ), @@ -427,7 +431,7 @@ def test_model_hosp_with_obs_model_weekday_phosp_2(): I0 = InfectionInitializationProcess( "I0_initialization", - DistributionalRV(name="I0", distribution=dist.LogNormal(0, 1)), + DistributionalVariable(name="I0", distribution=dist.LogNormal(0, 1)), InitializeInfectionsZeroPad(n_timepoints=n_initialization_points), t_unit=1, ) @@ -443,7 +447,7 @@ def test_model_hosp_with_obs_model_weekday_phosp_2(): infection_to_admission_interval_rv=inf_hosp, day_of_week_effect_rv=weekday, hospitalization_reporting_ratio_rv=hosp_report_prob_dist, - infection_hospitalization_ratio_rv=DistributionalRV( + infection_hospitalization_ratio_rv=DistributionalVariable( name="IHR", distribution=dist.LogNormal(jnp.log(0.05), 0.05) ), ) @@ -522,7 +526,7 @@ def test_model_hosp_with_obs_model_weekday_phosp(): I0 = InfectionInitializationProcess( "I0_initialization", - DistributionalRV(name="I0", distribution=dist.LogNormal(0, 1)), + DistributionalVariable(name="I0", distribution=dist.LogNormal(0, 1)), InitializeInfectionsZeroPad(n_timepoints=n_initialization_points), t_unit=1, ) @@ -553,7 +557,7 @@ def test_model_hosp_with_obs_model_weekday_phosp(): infection_to_admission_interval_rv=inf_hosp, day_of_week_effect_rv=weekday, hospitalization_reporting_ratio_rv=hosp_report_prob_dist, - infection_hospitalization_ratio_rv=DistributionalRV( + infection_hospitalization_ratio_rv=DistributionalVariable( name="IHR", distribution=dist.LogNormal(jnp.log(0.05), 0.05), ), diff --git a/test/test_predictive.py b/test/test_predictive.py index 5c76b98b..8e472f2b 100644 --- a/test/test_predictive.py +++ b/test/test_predictive.py @@ -17,7 +17,7 @@ Infections, InitializeInfectionsZeroPad, ) -from pyrenew.metaclass import DistributionalRV +from pyrenew.metaclass import DistributionalVariable from pyrenew.model import RtInfectionsRenewalModel from pyrenew.observation import PoissonObservation @@ -25,7 +25,7 @@ gen_int = DeterministicPMF(name="gen_int", value=pmf_array) I0 = InfectionInitializationProcess( "I0_initialization", - DistributionalRV(name="I0", distribution=dist.LogNormal(0, 1)), + DistributionalVariable(name="I0", distribution=dist.LogNormal(0, 1)), InitializeInfectionsZeroPad(n_timepoints=gen_int.size()), t_unit=1, ) diff --git a/test/test_random_key.py b/test/test_random_key.py index 6d6cfd43..f8cf90e2 100644 --- a/test/test_random_key.py +++ b/test/test_random_key.py @@ -19,7 +19,7 @@ Infections, InitializeInfectionsZeroPad, ) -from pyrenew.metaclass import DistributionalRV +from pyrenew.metaclass import DistributionalVariable from pyrenew.model import RtInfectionsRenewalModel from pyrenew.observation import PoissonObservation @@ -29,7 +29,7 @@ def create_test_model(): # numpydoc ignore=GL08 gen_int = DeterministicPMF(name="gen_int", value=pmf_array) I0 = InfectionInitializationProcess( "I0_initialization", - DistributionalRV(name="I0", distribution=dist.LogNormal(0, 1)), + DistributionalVariable(name="I0", distribution=dist.LogNormal(0, 1)), InitializeInfectionsZeroPad(n_timepoints=gen_int.size()), t_unit=1, ) diff --git a/test/test_random_walk.py b/test/test_random_walk.py index d7e2cabd..f761fd7e 100755 --- a/test/test_random_walk.py +++ b/test/test_random_walk.py @@ -7,15 +7,15 @@ from numpy.testing import assert_almost_equal, assert_array_almost_equal from pyrenew.deterministic import DeterministicVariable -from pyrenew.metaclass import DistributionalRV, RandomVariable +from pyrenew.metaclass import DistributionalVariable, RandomVariable from pyrenew.process import RandomWalk, StandardNormalRandomWalk @pytest.mark.parametrize( ["element_rv", "init_value"], [ - [DistributionalRV("test_normal", dist.Normal(0.5, 1)), 50.0], - [DistributionalRV("test_cauchy", dist.Cauchy(0.25, 0.25)), -3], + [DistributionalVariable("test_normal", dist.Normal(0.5, 1)), 50.0], + [DistributionalVariable("test_cauchy", dist.Cauchy(0.25, 0.25)), -3], ["test standard normal", jnp.array(3)], ], ) @@ -81,7 +81,7 @@ def test_normal_rw_samples_correctly_distributed(step_mean, step_sd): rw_normal = StandardNormalRandomWalk("test standard normal") else: rw_normal = RandomWalk( - step_rv=DistributionalRV( + step_rv=DistributionalVariable( name="rw_step_dist", distribution=dist.Normal(loc=step_mean, scale=step_sd), ), diff --git a/test/test_transformed_rv_class.py b/test/test_transformed_rv_class.py index 353d59e0..f7910567 100644 --- a/test/test_transformed_rv_class.py +++ b/test/test_transformed_rv_class.py @@ -1,7 +1,7 @@ # -*- coding: utf-8 -*- """ -Tests for TransformedRandomVariable class +Tests for TransformedVariable class """ from typing import NamedTuple @@ -14,11 +14,11 @@ import pyrenew.transformation as t from pyrenew.metaclass import ( - DistributionalRV, + DistributionalVariable, Model, RandomVariable, SampledValue, - TransformedRandomVariable, + TransformedVariable, ) @@ -129,11 +129,11 @@ def sample(self, **kwargs): # numpydoc ignore=GL08 def test_transform_rv_validation(): """ - Test that a TransformedRandomVariable validation + Test that a TransformedVariable validation works as expected. """ - base_rv = DistributionalRV( + base_rv = DistributionalVariable( name="test_normal", distribution=dist.Normal(0, 1) ) base_rv.sample_length = lambda: 1 # numpydoc ignore=GL08 @@ -143,41 +143,41 @@ def test_transform_rv_validation(): test_transforms = [t.IdentityTransform(), t.ExpTransform()] for tr in test_transforms: - my_rv = TransformedRandomVariable("test_transformed_rv", base_rv, tr) + my_rv = TransformedVariable("test_transformed_rv", base_rv, tr) assert isinstance(my_rv.transforms, tuple) assert len(my_rv.transforms) == 1 assert my_rv.sample_length() == 1 not_callable_err = "All entries in self.transforms " "must be callable" sample_length_err = "There must be exactly as many transformations" with pytest.raises(ValueError, match=sample_length_err): - _ = TransformedRandomVariable( + _ = TransformedVariable( "should_error_due_to_too_many_transforms", base_rv, (tr, tr) ) with pytest.raises(ValueError, match=sample_length_err): - _ = TransformedRandomVariable( + _ = TransformedVariable( "should_error_due_to_too_few_transforms", l2_rv, tr ) with pytest.raises(ValueError, match=sample_length_err): - _ = TransformedRandomVariable( + _ = TransformedVariable( "should_also_error_due_to_too_few_transforms", l2_rv, (tr,) ) with pytest.raises(ValueError, match=not_callable_err): - _ = TransformedRandomVariable( + _ = TransformedVariable( "should_error_due_to_not_callable", l2_rv, (1,) ) with pytest.raises(ValueError, match=not_callable_err): - _ = TransformedRandomVariable( + _ = TransformedVariable( "should_error_due_to_not_callable", base_rv, (1,) ) def test_transforms_applied_at_sampling(): """ - Test that TransformedRandomVariable + Test that TransformedVariable instances correctly apply their specified transformations at sampling """ - norm_rv = DistributionalRV( + norm_rv = DistributionalVariable( name="test_normal", distribution=dist.Normal(0, 1) ) norm_rv.sample_length = lambda: 1 @@ -190,9 +190,9 @@ def test_transforms_applied_at_sampling(): t.ExpTransform().inv, t.ScaledLogitTransform(5), ]: - tr_norm = TransformedRandomVariable("transformed_normal", norm_rv, tr) + tr_norm = TransformedVariable("transformed_normal", norm_rv, tr) - tr_l2 = TransformedRandomVariable( + tr_l2 = TransformedVariable( "transformed_length_2", l2_rv, (tr, t.ExpTransform()) ) @@ -217,22 +217,24 @@ def test_transforms_applied_at_sampling(): def test_transforms_variable_naming(): """ - Tests TransformedRandomVariable name + Tests TransformedVariable name recording is as expected. """ - transformed_dist_named_base_rv = TransformedRandomVariable( + transformed_dist_named_base_rv = TransformedVariable( "transformed_rv", NamedBaseRV(), (t.ExpTransform(), t.IdentityTransform()), ) - transformed_dist_unnamed_base_rv = TransformedRandomVariable( + transformed_dist_unnamed_base_rv = TransformedVariable( "transformed_rv", - DistributionalRV(name="my_normal", distribution=dist.Normal(0, 1)), + DistributionalVariable( + name="my_normal", distribution=dist.Normal(0, 1) + ), (t.ExpTransform(), t.IdentityTransform()), ) - transformed_dist_unnamed_base_l2_rv = TransformedRandomVariable( + transformed_dist_unnamed_base_l2_rv = TransformedVariable( "transformed_rv", LengthTwoRV(), (t.ExpTransform(), t.IdentityTransform()), diff --git a/test/utils.py b/test/utils.py index be551dfe..fc596754 100644 --- a/test/utils.py +++ b/test/utils.py @@ -8,10 +8,10 @@ import pyrenew.transformation as t from pyrenew.metaclass import ( - DistributionalRV, + DistributionalVariable, RandomVariable, SampledValue, - TransformedRandomVariable, + TransformedVariable, ) from pyrenew.process import RandomWalk @@ -37,17 +37,17 @@ def __init__(self, name: str = "Rt_rv"): None """ self.name = name - self.rt_rv_ = TransformedRandomVariable( + self.rt_rv_ = TransformedVariable( name=f"{name}_log_rt_random_walk", base_rv=RandomWalk( name="log_rt", - step_rv=DistributionalRV( + step_rv=DistributionalVariable( name="rw_step_rv", distribution=dist.Normal(0, 0.025) ), ), transforms=t.ExpTransform(), ) - self.rt_init_rv_ = DistributionalRV( + self.rt_init_rv_ = DistributionalVariable( name=f"{name}_init_log_rt", distribution=dist.Normal(0, 0.2) ) From da7f4289e7fe1bbd27faaa275674b516b5949215 Mon Sep 17 00:00:00 2001 From: sbidari Date: Thu, 22 Aug 2024 18:13:08 -0400 Subject: [PATCH 12/16] modify import statements --- docs/source/tutorials/basic_renewal_model.qmd | 7 ++----- docs/source/tutorials/extending_pyrenew.qmd | 7 ++----- pyrenew/process/iidrandomsequence.py | 7 ++----- pyrenew/process/randomwalk.py | 3 ++- test/test_assert_sample_and_rtype.py | 2 +- test/test_assert_type.py | 7 ++----- test/test_differenced_process.py | 2 +- test/test_distributional_rv.py | 2 +- test/test_forecast.py | 2 +- test/test_iid_random_sequence.py | 6 +++--- test/test_infection_initialization_process.py | 2 +- test/test_latent_admissions.py | 3 ++- test/test_model_basic_renewal.py | 2 +- test/test_model_hosp_admissions.py | 7 ++----- test/test_predictive.py | 2 +- test/test_random_key.py | 2 +- test/test_random_walk.py | 3 ++- test/test_transformed_rv_class.py | 9 ++------- test/utils.py | 8 ++------ 19 files changed, 31 insertions(+), 52 deletions(-) diff --git a/docs/source/tutorials/basic_renewal_model.qmd b/docs/source/tutorials/basic_renewal_model.qmd index 7d40ed93..bf262094 100644 --- a/docs/source/tutorials/basic_renewal_model.qmd +++ b/docs/source/tutorials/basic_renewal_model.qmd @@ -25,11 +25,8 @@ from pyrenew.latent import ( from pyrenew.observation import PoissonObservation from pyrenew.deterministic import DeterministicPMF from pyrenew.model import RtInfectionsRenewalModel -from pyrenew.metaclass import ( - RandomVariable, - DistributionalVariable, - TransformedVariable, -) +from pyrenew.metaclass import RandomVariable +from pyrenew.randomvariable import DistributionalVariable, TransformedVariable import pyrenew.transformation as t from numpyro.infer.reparam import LocScaleReparam ``` diff --git a/docs/source/tutorials/extending_pyrenew.qmd b/docs/source/tutorials/extending_pyrenew.qmd index 18639c55..f81de653 100644 --- a/docs/source/tutorials/extending_pyrenew.qmd +++ b/docs/source/tutorials/extending_pyrenew.qmd @@ -29,11 +29,8 @@ from pyrenew.deterministic import DeterministicPMF, DeterministicVariable from pyrenew.latent import InfectionsWithFeedback from pyrenew.model import RtInfectionsRenewalModel from pyrenew.process import RandomWalk -from pyrenew.metaclass import ( - RandomVariable, - DistributionalVariable, - TransformedVariable, -) +from pyrenew.metaclass import RandomVariable +from pyrenew.randomvariable import DistributionalVariable, TransformedVariable from pyrenew.latent import ( InfectionInitializationProcess, InitializeInfectionsExponentialGrowth, diff --git a/pyrenew/process/iidrandomsequence.py b/pyrenew/process/iidrandomsequence.py index 01c0c8ff..10adfa9c 100644 --- a/pyrenew/process/iidrandomsequence.py +++ b/pyrenew/process/iidrandomsequence.py @@ -4,11 +4,8 @@ import numpyro.distributions as dist from numpyro.contrib.control_flow import scan -from pyrenew.metaclass import ( - DistributionalVariable, - RandomVariable, - SampledValue, -) +from pyrenew.metaclass import RandomVariable, SampledValue +from pyrenew.randomvariable import DistributionalVariable class IIDRandomSequence(RandomVariable): diff --git a/pyrenew/process/randomwalk.py b/pyrenew/process/randomwalk.py index 035b7cf4..6b0a763d 100644 --- a/pyrenew/process/randomwalk.py +++ b/pyrenew/process/randomwalk.py @@ -3,9 +3,10 @@ import numpyro.distributions as dist -from pyrenew.metaclass import DistributionalVariable, RandomVariable +from pyrenew.metaclass import RandomVariable from pyrenew.process.differencedprocess import DifferencedProcess from pyrenew.process.iidrandomsequence import IIDRandomSequence +from pyrenew.randomvariable import DistributionalVariable class RandomWalk(DifferencedProcess): diff --git a/test/test_assert_sample_and_rtype.py b/test/test_assert_sample_and_rtype.py index c0c7b4ec..d0f9ee8a 100644 --- a/test/test_assert_sample_and_rtype.py +++ b/test/test_assert_sample_and_rtype.py @@ -9,11 +9,11 @@ from pyrenew.deterministic import DeterministicVariable, NullObservation from pyrenew.metaclass import ( - DistributionalVariable, RandomVariable, SampledValue, _assert_sample_and_rtype, ) +from pyrenew.randomvariable import DistributionalVariable class RVreturnsTuple(RandomVariable): diff --git a/test/test_assert_type.py b/test/test_assert_type.py index 0a5cf67b..a885cef3 100644 --- a/test/test_assert_type.py +++ b/test/test_assert_type.py @@ -3,11 +3,8 @@ import numpyro.distributions as dist import pytest -from pyrenew.metaclass import ( - DistributionalVariable, - RandomVariable, - _assert_type, -) +from pyrenew.metaclass import RandomVariable, _assert_type +from pyrenew.randomvariable import DistributionalVariable def test_valid_assertion_types(): diff --git a/test/test_differenced_process.py b/test/test_differenced_process.py index d4e710eb..ba4e95c9 100644 --- a/test/test_differenced_process.py +++ b/test/test_differenced_process.py @@ -10,12 +10,12 @@ from numpy.testing import assert_array_almost_equal from pyrenew.deterministic import DeterministicVariable, NullVariable -from pyrenew.metaclass import DistributionalVariable from pyrenew.process import ( DifferencedProcess, IIDRandomSequence, StandardNormalSequence, ) +from pyrenew.randomvariable import DistributionalVariable @pytest.mark.parametrize( diff --git a/test/test_distributional_rv.py b/test/test_distributional_rv.py index 7ddc3ad8..cebe6f8e 100644 --- a/test/test_distributional_rv.py +++ b/test/test_distributional_rv.py @@ -9,7 +9,7 @@ from numpy.testing import assert_array_equal from numpyro.distributions import ExpandedDistribution -from pyrenew.metaclass import ( +from pyrenew.randomvariable import ( DistributionalVariable, DynamicDistributionalVariable, StaticDistributionalVariable, diff --git a/test/test_forecast.py b/test/test_forecast.py index 2191f092..d8d1d55c 100644 --- a/test/test_forecast.py +++ b/test/test_forecast.py @@ -14,9 +14,9 @@ Infections, InitializeInfectionsZeroPad, ) -from pyrenew.metaclass import DistributionalVariable from pyrenew.model import RtInfectionsRenewalModel from pyrenew.observation import PoissonObservation +from pyrenew.randomvariable import DistributionalVariable def test_forecast(): diff --git a/test/test_iid_random_sequence.py b/test/test_iid_random_sequence.py index 320d9603..73b683aa 100755 --- a/test/test_iid_random_sequence.py +++ b/test/test_iid_random_sequence.py @@ -6,12 +6,12 @@ import pytest from scipy.stats import kstest -from pyrenew.metaclass import ( +from pyrenew.metaclass import SampledValue +from pyrenew.process import IIDRandomSequence, StandardNormalSequence +from pyrenew.randomvariable import ( DistributionalVariable, - SampledValue, StaticDistributionalVariable, ) -from pyrenew.process import IIDRandomSequence, StandardNormalSequence @pytest.mark.parametrize( diff --git a/test/test_infection_initialization_process.py b/test/test_infection_initialization_process.py index 235f6966..069299cd 100644 --- a/test/test_infection_initialization_process.py +++ b/test/test_infection_initialization_process.py @@ -11,7 +11,7 @@ InitializeInfectionsFromVec, InitializeInfectionsZeroPad, ) -from pyrenew.metaclass import DistributionalVariable +from pyrenew.randomvariable import DistributionalVariable def test_infection_initialization_process(): diff --git a/test/test_latent_admissions.py b/test/test_latent_admissions.py index 4eabfbd2..1e82db89 100644 --- a/test/test_latent_admissions.py +++ b/test/test_latent_admissions.py @@ -10,7 +10,8 @@ from pyrenew.deterministic import DeterministicPMF, DeterministicVariable from pyrenew.latent import HospitalAdmissions, Infections -from pyrenew.metaclass import DistributionalVariable, SampledValue +from pyrenew.metaclass import SampledValue +from pyrenew.randomvariable import DistributionalVariable def test_admissions_sample(): diff --git a/test/test_model_basic_renewal.py b/test/test_model_basic_renewal.py index 210c1972..1b0314f8 100644 --- a/test/test_model_basic_renewal.py +++ b/test/test_model_basic_renewal.py @@ -18,9 +18,9 @@ Infections, InitializeInfectionsZeroPad, ) -from pyrenew.metaclass import DistributionalVariable from pyrenew.model import RtInfectionsRenewalModel from pyrenew.observation import PoissonObservation +from pyrenew.randomvariable import DistributionalVariable def test_model_basicrenewal_no_timepoints_or_observations(): diff --git a/test/test_model_hosp_admissions.py b/test/test_model_hosp_admissions.py index 7c99ca13..75e79962 100644 --- a/test/test_model_hosp_admissions.py +++ b/test/test_model_hosp_admissions.py @@ -23,13 +23,10 @@ Infections, InitializeInfectionsZeroPad, ) -from pyrenew.metaclass import ( - DistributionalVariable, - RandomVariable, - SampledValue, -) +from pyrenew.metaclass import RandomVariable, SampledValue from pyrenew.model import HospitalAdmissionsModel from pyrenew.observation import PoissonObservation +from pyrenew.randomvariable import DistributionalVariable class UniformProbForTest(RandomVariable): # numpydoc ignore=GL08 diff --git a/test/test_predictive.py b/test/test_predictive.py index 8e472f2b..636578bb 100644 --- a/test/test_predictive.py +++ b/test/test_predictive.py @@ -17,9 +17,9 @@ Infections, InitializeInfectionsZeroPad, ) -from pyrenew.metaclass import DistributionalVariable from pyrenew.model import RtInfectionsRenewalModel from pyrenew.observation import PoissonObservation +from pyrenew.randomvariable import DistributionalVariable pmf_array = jnp.array([0.25, 0.1, 0.2, 0.45]) gen_int = DeterministicPMF(name="gen_int", value=pmf_array) diff --git a/test/test_random_key.py b/test/test_random_key.py index f8cf90e2..0b99816f 100644 --- a/test/test_random_key.py +++ b/test/test_random_key.py @@ -19,9 +19,9 @@ Infections, InitializeInfectionsZeroPad, ) -from pyrenew.metaclass import DistributionalVariable from pyrenew.model import RtInfectionsRenewalModel from pyrenew.observation import PoissonObservation +from pyrenew.randomvariable import DistributionalVariable def create_test_model(): # numpydoc ignore=GL08 diff --git a/test/test_random_walk.py b/test/test_random_walk.py index f761fd7e..6997d679 100755 --- a/test/test_random_walk.py +++ b/test/test_random_walk.py @@ -7,8 +7,9 @@ from numpy.testing import assert_almost_equal, assert_array_almost_equal from pyrenew.deterministic import DeterministicVariable -from pyrenew.metaclass import DistributionalVariable, RandomVariable +from pyrenew.metaclass import RandomVariable from pyrenew.process import RandomWalk, StandardNormalRandomWalk +from pyrenew.randomvariable import DistributionalVariable @pytest.mark.parametrize( diff --git a/test/test_transformed_rv_class.py b/test/test_transformed_rv_class.py index f7910567..22dd1c2c 100644 --- a/test/test_transformed_rv_class.py +++ b/test/test_transformed_rv_class.py @@ -13,13 +13,8 @@ from numpy.testing import assert_almost_equal import pyrenew.transformation as t -from pyrenew.metaclass import ( - DistributionalVariable, - Model, - RandomVariable, - SampledValue, - TransformedVariable, -) +from pyrenew.metaclass import Model, RandomVariable, SampledValue +from pyrenew.randomvariable import DistributionalVariable, TransformedVariable class LengthTwoRV(RandomVariable): diff --git a/test/utils.py b/test/utils.py index fc596754..ac345b41 100644 --- a/test/utils.py +++ b/test/utils.py @@ -7,13 +7,9 @@ import numpyro.distributions as dist import pyrenew.transformation as t -from pyrenew.metaclass import ( - DistributionalVariable, - RandomVariable, - SampledValue, - TransformedVariable, -) +from pyrenew.metaclass import RandomVariable, SampledValue from pyrenew.process import RandomWalk +from pyrenew.randomvariable import DistributionalVariable, TransformedVariable class SimpleRt(RandomVariable): From 3e4e87dfd2337254acae91ab74a9836f2f1766c7 Mon Sep 17 00:00:00 2001 From: sbidari Date: Thu, 22 Aug 2024 19:18:13 -0400 Subject: [PATCH 13/16] missed few imports --- docs/source/tutorials/day_of_the_week.qmd | 22 +++++++++---------- .../tutorials/hospital_admissions_model.qmd | 16 +++++++------- docs/source/tutorials/periodic_effects.qmd | 4 ++-- 3 files changed, 21 insertions(+), 21 deletions(-) diff --git a/docs/source/tutorials/day_of_the_week.qmd b/docs/source/tutorials/day_of_the_week.qmd index 4deb88dd..0d72a415 100644 --- a/docs/source/tutorials/day_of_the_week.qmd +++ b/docs/source/tutorials/day_of_the_week.qmd @@ -51,7 +51,7 @@ inf_hosp_int_array = inf_hosp_int["probability_mass"].to_numpy() ```{python} # | label: latent-hosp # | code-fold: true -from pyrenew import latent, deterministic, metaclass +from pyrenew import latent, deterministic, randomvariable import jax.numpy as jnp import numpyro.distributions as dist @@ -59,7 +59,7 @@ inf_hosp_int = deterministic.DeterministicPMF( name="inf_hosp_int", value=inf_hosp_int_array ) -hosp_rate = metaclass.DistributionalVariable( +hosp_rate = randomvariable.DistributionalVariable( name="IHR", distribution=dist.LogNormal(jnp.log(0.05), jnp.log(1.1)) ) @@ -81,7 +81,7 @@ n_initialization_points = max(gen_int_array.size, inf_hosp_int_array.size) I0 = InfectionInitializationProcess( "I0_initialization", - metaclass.DistributionalVariable( + randomvariable.DistributionalVariable( name="I0", distribution=dist.LogNormal(loc=jnp.log(100), scale=jnp.log(1.75)), ), @@ -113,11 +113,11 @@ class MyRt(metaclass.RandomVariable): sd_rt, *_ = self.sd_rv() # Random walk step - step_rv = metaclass.DistributionalVariable( + step_rv = randomvariable.DistributionalVariable( name="rw_step_rv", distribution=dist.Normal(0, sd_rt.value) ) - rt_init_rv = metaclass.DistributionalVariable( + rt_init_rv = randomvariable.DistributionalVariable( name="init_log_rt", distribution=dist.Normal(0, 0.2) ) @@ -128,7 +128,7 @@ class MyRt(metaclass.RandomVariable): ) # Transforming the random walk to the Rt scale - rt_rv = metaclass.TransformedVariable( + rt_rv = randomvariable.TransformedVariable( name="Rt_rv", base_rv=base_rv, transforms=transformation.ExpTransform(), @@ -139,7 +139,7 @@ class MyRt(metaclass.RandomVariable): rtproc = MyRt( - metaclass.DistributionalVariable( + randomvariable.DistributionalVariable( name="Rt_random_walk_sd", distribution=dist.HalfNormal(0.025) ) ) @@ -152,9 +152,9 @@ rtproc = MyRt( # | code-fold: true # we place a log-Normal prior on the concentration # parameter of the negative binomial. -nb_conc_rv = metaclass.TransformedVariable( +nb_conc_rv = randomvariable.TransformedVariable( "concentration", - metaclass.DistributionalVariable( + randomvariable.DistributionalVariable( name="concentration_raw", distribution=dist.TruncatedNormal(loc=0, scale=1, low=0.01), ), @@ -219,9 +219,9 @@ For this example, the effect will be passed as a scaled Dirichlet distribution. ```{python} # | label: weekly-effect # Instantiating the day-of-the-week effect -dayofweek_effect = metaclass.TransformedVariable( +dayofweek_effect = randomvariable.TransformedVariable( name="dayofweek_effect", - base_rv=metaclass.DistributionalVariable( + base_rv=randomvariable.DistributionalVariable( name="dayofweek_effect_raw", distribution=dist.Dirichlet(jnp.ones(7)), ), diff --git a/docs/source/tutorials/hospital_admissions_model.qmd b/docs/source/tutorials/hospital_admissions_model.qmd index 33999f5c..9456f828 100644 --- a/docs/source/tutorials/hospital_admissions_model.qmd +++ b/docs/source/tutorials/hospital_admissions_model.qmd @@ -137,7 +137,7 @@ With these two in hand, we can start building the model. First, we will define t ```{python} # | label: latent-hosp -from pyrenew import latent, deterministic, metaclass +from pyrenew import latent, deterministic, metaclass, randomvariable import jax.numpy as jnp import numpyro.distributions as dist @@ -145,7 +145,7 @@ inf_hosp_int = deterministic.DeterministicPMF( name="inf_hosp_int", value=inf_hosp_int_array ) -hosp_rate = metaclass.DistributionalVariable( +hosp_rate = randomvariable.DistributionalVariable( name="IHR", distribution=dist.LogNormal(jnp.log(0.05), jnp.log(1.1)) ) @@ -171,7 +171,7 @@ latent_inf = latent.Infections() n_initialization_points = max(gen_int_array.size, inf_hosp_int_array.size) I0 = InfectionInitializationProcess( "I0_initialization", - metaclass.DistributionalVariable( + randomvariable.DistributionalVariable( name="I0", distribution=dist.LogNormal(loc=jnp.log(100), scale=jnp.log(1.75)), ), @@ -194,17 +194,17 @@ class MyRt(metaclass.RandomVariable): def sample(self, n: int, **kwargs) -> tuple: sd_rt = numpyro.sample("Rt_random_walk_sd", dist.HalfNormal(0.025)) - rt_rv = metaclass.TransformedVariable( + rt_rv = randomvariable.TransformedVariable( name="log_rt_random_walk", base_rv=process.RandomWalk( name="log_rt", - step_rv=metaclass.DistributionalVariable( + step_rv=randomvariable.DistributionalVariable( name="rw_step_rv", distribution=dist.Normal(0, 0.025) ), ), transforms=transformation.ExpTransform(), ) - rt_init_rv = metaclass.DistributionalVariable( + rt_init_rv = randomvariable.DistributionalVariable( name="init_log_rt", distribution=dist.Normal(0, 0.2) ) init_rt, *_ = rt_init_rv.sample() @@ -218,9 +218,9 @@ rtproc = MyRt() # we place a log-Normal prior on the concentration # parameter of the negative binomial. -nb_conc_rv = metaclass.TransformedVariable( +nb_conc_rv = randomvariable.TransformedVariable( "concentration", - metaclass.DistributionalVariable( + randomvariable.DistributionalVariable( name="concentration_raw", distribution=dist.TruncatedNormal(loc=0, scale=1, low=0.01), ), diff --git a/docs/source/tutorials/periodic_effects.qmd b/docs/source/tutorials/periodic_effects.qmd index e0a16847..1603ed59 100644 --- a/docs/source/tutorials/periodic_effects.qmd +++ b/docs/source/tutorials/periodic_effects.qmd @@ -65,7 +65,7 @@ The `PeriodicBroadcaster` class can also be used to repeat a sequence as a whole ```{python} import numpyro.distributions as dist -from pyrenew import transformation, metaclass +from pyrenew import transformation, randomvariable # Building the transformed prior: Dirichlet * 7 mysimplex = dist.TransformedDistribution( @@ -76,7 +76,7 @@ mysimplex = dist.TransformedDistribution( # Constructing the day of week effect dayofweek = process.DayOfWeekEffect( offset=0, - quantity_to_broadcast=metaclass.DistributionalVariable( + quantity_to_broadcast=randomvariable.DistributionalVariable( name="simp", distribution=mysimplex ), t_start=0, From 6591dc4f69e2604a89cd15758582e9834222e40c Mon Sep 17 00:00:00 2001 From: sbidari Date: Mon, 26 Aug 2024 12:42:03 -0400 Subject: [PATCH 14/16] pre-commit changes --- pyrenew/latent/hospitaladmissions.py | 6 +----- test/test_model_hosp_admissions.py | 8 ++++++-- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/pyrenew/latent/hospitaladmissions.py b/pyrenew/latent/hospitaladmissions.py index 2ba6c0f0..55689b8b 100644 --- a/pyrenew/latent/hospitaladmissions.py +++ b/pyrenew/latent/hospitaladmissions.py @@ -11,11 +11,7 @@ import pyrenew.arrayutils as au from pyrenew.convolve import compute_delay_ascertained_incidence from pyrenew.deterministic import DeterministicVariable -from pyrenew.metaclass import ( - RandomVariable, - SampledValue, - compute_delay_ascertained_incidence, -) +from pyrenew.metaclass import RandomVariable, SampledValue class HospitalAdmissionsSample(NamedTuple): diff --git a/test/test_model_hosp_admissions.py b/test/test_model_hosp_admissions.py index 050cbceb..f6d3d3a2 100644 --- a/test/test_model_hosp_admissions.py +++ b/test/test_model_hosp_admissions.py @@ -261,7 +261,9 @@ def test_model_hosp_no_obs_model(): with numpyro.handlers.seed(rng_seed=223): model1_samp = model0.sample(n_datapoints=30) - np.testing.assert_array_almost_equal(model0_samp.Rt.value, model1_samp.Rt.value) + np.testing.assert_array_almost_equal( + model0_samp.Rt.value, model1_samp.Rt.value + ) np.testing.assert_array_equal( model0_samp.latent_infections.value, model1_samp.latent_infections.value, @@ -570,7 +572,9 @@ def test_model_hosp_with_obs_model_weekday_phosp(): # Sampling and fitting model 0 (with no obs for infections) with numpyro.handlers.seed(rng_seed=223): - model1_samp = model1.sample(n_datapoints=n_obs_to_generate, padding=pad_size) + model1_samp = model1.sample( + n_datapoints=n_obs_to_generate, padding=pad_size + ) # Showed during merge conflict, but unsure if it will be needed # pad_size = 5 From e1b888af1989c3d605a79507522ecaa583640958 Mon Sep 17 00:00:00 2001 From: sbidari Date: Mon, 26 Aug 2024 13:37:52 -0400 Subject: [PATCH 15/16] update metaclass.py --- pyrenew/metaclass.py | 33 --------------------------------- 1 file changed, 33 deletions(-) diff --git a/pyrenew/metaclass.py b/pyrenew/metaclass.py index 344cf894..63d72de1 100644 --- a/pyrenew/metaclass.py +++ b/pyrenew/metaclass.py @@ -8,7 +8,6 @@ from typing import NamedTuple, get_type_hints import jax -import jax.numpy as jnp import jax.random as jr import matplotlib.pyplot as plt import numpy as np @@ -122,38 +121,6 @@ def _assert_sample_and_rtype( return None -def compute_delay_ascertained_incidence( - incidence_to_observation_rate: ArrayLike, - latent_incidence: ArrayLike, - incidence_to_observation_delay_interval: ArrayLike, -) -> ArrayLike: - """ - Computes incidences observed according - to a given observation rate and based - on a delay interval. - - Parameters - ---------- - incidence_to_observation_rate: ArrayLike - The rate at which latent incidences are observed. - latent_incidence: ArrayLike - Incidence values based on the true underlying process. - incidence_to_observation_delay_interval: ArrayLike - Pmf of delay interval between incidence to observation. - - Returns - -------- - ArrayLike - The incidence after the observation delay. - """ - delay_obs_incidence = jnp.convolve( - incidence_to_observation_rate * latent_incidence, - incidence_to_observation_delay_interval, - mode="valid", - ) - return delay_obs_incidence - - class SampledValue(NamedTuple): """ A container for a value sampled from a RandomVariable. From b6ea3d50fee425263838cfb643d15eeb4f3e0cc0 Mon Sep 17 00:00:00 2001 From: sbidari Date: Mon, 26 Aug 2024 14:44:09 -0400 Subject: [PATCH 16/16] add randomvariable.rst --- docs/source/msei_reference/index.rst | 1 + docs/source/msei_reference/randomvariable.rst | 7 +++++++ 2 files changed, 8 insertions(+) create mode 100644 docs/source/msei_reference/randomvariable.rst diff --git a/docs/source/msei_reference/index.rst b/docs/source/msei_reference/index.rst index f7fae05d..323874ca 100644 --- a/docs/source/msei_reference/index.rst +++ b/docs/source/msei_reference/index.rst @@ -7,6 +7,7 @@ Reference model latent process + randomvariable observation datasets msei diff --git a/docs/source/msei_reference/randomvariable.rst b/docs/source/msei_reference/randomvariable.rst new file mode 100644 index 00000000..3ffe44d0 --- /dev/null +++ b/docs/source/msei_reference/randomvariable.rst @@ -0,0 +1,7 @@ +Random Variables +=========== + +.. automodule:: pyrenew.randomvariable + :members: + :undoc-members: + :show-inheritance: