From 41dc41b57c70a704278afa5961b83b1c926a4d8e Mon Sep 17 00:00:00 2001 From: George Vega Yon Date: Wed, 10 Jul 2024 09:56:27 -0600 Subject: [PATCH 01/41] Adding information about future features for retrieving timeinfo --- model/docs/time.qmd | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/model/docs/time.qmd b/model/docs/time.qmd index c7961519..86c13651 100644 --- a/model/docs/time.qmd +++ b/model/docs/time.qmd @@ -10,7 +10,7 @@ The fundamental time unit should represent a period of fixed (or approximately f For many infectious disease renewal models of interest, the fundamental time unit will be days, and we will proceed with this tutorial treating days as our fundamental unit. - `pyrenew` deals with time having `RandomVariable`s carry information about (i) their own time unit expressed relative to the fundamental unit (`t_unit`) and (ii) the starting time, `t_start`, measured relative to `t = 0` in model time in fundamental time units. + `pyrenew` deals with time having `RandomVariable`s carry information about (i) their own time unit expressed relative to the fundamental unit (`t_unit`) and (ii) the starting time, `t_start`, measured relative to `t = 0` in model time in fundamental time units. Moreover, return values from `RandomVariable.sample()` are namedtuples with `TimeArray` objects that carry the same information. The tuple `(t_unit, t_start)` can encode different types of time series data. For example: @@ -31,10 +31,6 @@ The `PeriodicBroadcaster()` class provides a way of tiling and repeating data ac The following section describes some preliminary design principles that may be included in future versions of `pyrenew`. -### Validation - -With random variables possibly spanning different time scales, *e.g.*, weekly, daily, hourly, the metaclass `Model` should ensure random variables within the model share the same time unit. - ### Array alignment Using `t_unit` and `t_start`, random variables should be able to align input and output data. For example, in the case of the `RtInfectionsRenewalModel.sample()`, the computed values of `Rt` and `infections` are padded left with `nan` values to account for the seeding process. Instead, we expect to either pre-process the padding leveraging the `t_start` information of the involved variables or simplify the process via a function call that aligns the arrays. A possible implementation could be a method `align()` that takes a list of random variables and aligns them based on the `t_unit` and `t_start` information, e.g.: @@ -42,3 +38,7 @@ Using `t_unit` and `t_start`, random variables should be able to align input and ```python Rt_aligned, infections_aligned = align([Rt, infections]) ``` + +### Retrieving time information from sites + +Since numpyro only stores Jax arrays, we cannot store the time information in the arrays themselves. Next iterations of `pyrenew` should include a way to retrieve the time information from the sites of the model after running them. \ No newline at end of file From c4eef2003f8f160ea2811d4f8e4287374dd5f006 Mon Sep 17 00:00:00 2001 From: George Vega Yon Date: Wed, 10 Jul 2024 17:05:29 -0600 Subject: [PATCH 02/41] Updating to return TimeArray (WIP) --- .../pyrenew/deterministic/deterministic.py | 18 +++++++-- .../pyrenew/deterministic/deterministicpmf.py | 18 +++++++-- model/src/pyrenew/deterministic/nullrv.py | 13 ++++--- model/src/pyrenew/deterministic/process.py | 22 ++++++++--- .../src/pyrenew/latent/hospitaladmissions.py | 26 +++++++------ .../latent/infection_seeding_method.py | 3 ++ .../latent/infection_seeding_process.py | 12 ++++-- model/src/pyrenew/latent/infections.py | 8 ++-- .../pyrenew/latent/infectionswithfeedback.py | 11 ++++-- model/src/pyrenew/metaclass.py | 38 +++++++++++++++++-- model/src/pyrenew/model/admissionsmodel.py | 10 ++--- .../pyrenew/model/rtinfectionsrenewalmodel.py | 27 +++++++------ .../pyrenew/observation/negativebinomial.py | 6 +-- model/src/pyrenew/observation/poisson.py | 6 +-- model/src/pyrenew/process/ar.py | 6 ++- .../src/pyrenew/process/firstdifferencear.py | 4 +- model/src/pyrenew/process/rtperiodicdiff.py | 4 +- model/src/pyrenew/process/rtrandomwalk.py | 6 +-- model/src/pyrenew/process/simplerandomwalk.py | 4 +- model/src/test/test_ar_process.py | 5 ++- model/src/test/test_deterministic.py | 14 ++++--- model/src/test/test_first_difference_ar.py | 4 +- .../src/test/test_infection_seeding_method.py | 11 ++++-- model/src/test/test_infectionsrtfeedback.py | 12 +++--- model/src/test/test_latent_admissions.py | 8 ++-- model/src/test/test_latent_infections.py | 2 +- model/src/test/test_model_basic_renewal.py | 10 ++--- .../test/test_observation_negativebinom.py | 8 ++-- model/src/test/test_random_walk.py | 5 ++- model/src/test/test_rtperiodicdiff.py | 2 +- 30 files changed, 211 insertions(+), 112 deletions(-) diff --git a/model/src/pyrenew/deterministic/deterministic.py b/model/src/pyrenew/deterministic/deterministic.py index cd602212..092c793d 100644 --- a/model/src/pyrenew/deterministic/deterministic.py +++ b/model/src/pyrenew/deterministic/deterministic.py @@ -6,7 +6,7 @@ import jax.numpy as jnp import numpyro as npro from jax.typing import ArrayLike -from pyrenew.metaclass import RandomVariable +from pyrenew.metaclass import RandomVariable, TimeArray class DeterministicVariable(RandomVariable): @@ -19,6 +19,8 @@ def __init__( self, vars: ArrayLike, name: str, + t_start: int | None = None, + t_unit: int | None = None, ) -> None: """Default constructor @@ -28,6 +30,10 @@ def __init__( A tuple with arraylike objects. name : str, optional A name to assign to the process. + t_start : int, optional + The start time of the process. + t_unit : int, optional + The unit of time relative to the model's fundamental (smallest) time unit. Returns ------- @@ -35,6 +41,7 @@ def __init__( """ self.validate(vars) + self.set_timeseries(t_start, t_unit) self.vars = jnp.atleast_1d(vars) self.name = name @@ -83,8 +90,13 @@ def sample( Returns ------- tuple - Containing the stored values during construction. + Containing the stored values during construction wrapped in a TimeArrayß. """ if record: npro.deterministic(self.name, self.vars) - return (self.vars,) + return ( + TimeArray( + array = self.vars, + t_start=self.t_start, + t_unit=self.t_unit, + ),) diff --git a/model/src/pyrenew/deterministic/deterministicpmf.py b/model/src/pyrenew/deterministic/deterministicpmf.py index f0f2aa0c..00bf4e1b 100644 --- a/model/src/pyrenew/deterministic/deterministicpmf.py +++ b/model/src/pyrenew/deterministic/deterministicpmf.py @@ -18,6 +18,8 @@ def __init__( vars: ArrayLike, name: str, tol: float = 1e-5, + t_start: int | None = None, + t_unit: int | None = None, ) -> None: """ Default constructor @@ -36,7 +38,12 @@ def __init__( tol : float, optional Passed to pyrenew.distutil.validate_discrete_dist_vector. Defaults to 1e-5. - + t_start : int, optional + The start time of the process. + t_unit : int, optional + The unit of time relative to the model's fundamental (smallest) + time unit. + Returns ------- None @@ -46,7 +53,12 @@ def __init__( tol=tol, ) - self.basevar = DeterministicVariable(vars, name) + self.basevar = DeterministicVariable( + vars=vars, + name=name, + t_start=t_start, + t_unit=t_unit, + ) return None @@ -82,7 +94,7 @@ def sample( Returns ------- tuple - Containing the stored values during construction. + Containing the stored values during construction wrapped in a TimeArray. """ return self.basevar.sample(**kwargs) diff --git a/model/src/pyrenew/deterministic/nullrv.py b/model/src/pyrenew/deterministic/nullrv.py index 435c68b6..03e06232 100644 --- a/model/src/pyrenew/deterministic/nullrv.py +++ b/model/src/pyrenew/deterministic/nullrv.py @@ -3,6 +3,7 @@ from __future__ import annotations from jax.typing import ArrayLike +from pyrenew.metaclass import TimeArray from pyrenew.deterministic.deterministic import DeterministicVariable @@ -46,10 +47,10 @@ def sample( Returns ------- tuple - Containing None. + Containing a TimeArray with None. """ - return (None,) + return (TimeArray(None),) class NullProcess(NullVariable): @@ -95,10 +96,10 @@ def sample( Returns ------- tuple - Containing None. + Containing a TimeArray with None. """ - return (None,) + return (TimeArray(None),) class NullObservation(NullVariable): @@ -151,7 +152,7 @@ def sample( Returns ------- tuple - Containing None. + Containing a TimeArray with None. """ - return (None,) + return (TimeArray(None),) diff --git a/model/src/pyrenew/deterministic/process.py b/model/src/pyrenew/deterministic/process.py index 64f5a514..9ac2cb7b 100644 --- a/model/src/pyrenew/deterministic/process.py +++ b/model/src/pyrenew/deterministic/process.py @@ -1,6 +1,7 @@ # numpydoc ignore=GL08 import jax.numpy as jnp +from pyrenew.metaclass import TimeArray from pyrenew.deterministic.deterministic import DeterministicVariable @@ -29,14 +30,25 @@ def sample( Returns ------- tuple - Containing the stored values during construction. + Containing the stored values during construction wrapped in a TimeArray. """ res, *_ = super().sample(**kwargs) - dif = duration - res.shape[0] + dif = duration - res.array.shape[0] if dif > 0: - return (jnp.hstack([res, jnp.repeat(res[-1], dif)]),) - - return (res[:duration],) + return ( + TimeArray( + jnp.hstack([res.array, jnp.repeat(res.array[-1], dif)]), t_start=self.t_start, + t_unit=self.t_unit, + ), + ) + + return ( + TimeArray( + array=res.array[:duration], + t_start=self.t_start, + t_unit=self.t_unit + ), + ) diff --git a/model/src/pyrenew/latent/hospitaladmissions.py b/model/src/pyrenew/latent/hospitaladmissions.py index 3c46c47f..95b7c8e0 100644 --- a/model/src/pyrenew/latent/hospitaladmissions.py +++ b/model/src/pyrenew/latent/hospitaladmissions.py @@ -9,7 +9,7 @@ import numpyro as npro from jax.typing import ArrayLike from pyrenew.deterministic import DeterministicVariable -from pyrenew.metaclass import RandomVariable +from pyrenew.metaclass import RandomVariable, TimeArray class HospitalAdmissionsSample(NamedTuple): @@ -20,12 +20,12 @@ class HospitalAdmissionsSample(NamedTuple): ---------- infection_hosp_rate : float, optional The infection-to-hospitalization rate. Defaults to None. - latent_hospital_admissions : ArrayLike or None + latent_hospital_admissions : TimeArray or None The computed number of hospital admissions. Defaults to None. """ infection_hosp_rate: float | None = None - latent_hospital_admissions: ArrayLike | None = None + latent_hospital_admissions: TimeArray | None = None def __repr__(self): return f"HospitalAdmissionsSample(infection_hosp_rate={self.infection_hosp_rate}, latent_hospital_admissions={self.latent_hospital_admissions})" @@ -162,7 +162,7 @@ def sample( Parameters ---------- - latent : ArrayLike + latent : ArrayLike or TimeArray Latent infections. **kwargs : dict, optional Additional keyword arguments passed through to internal `sample()` @@ -171,11 +171,10 @@ def sample( Returns ------- HospitalAdmissionsSample - """ - + """ infection_hosp_rate, *_ = self.infect_hosp_rate_rv.sample(**kwargs) - infection_hosp_rate_t = infection_hosp_rate * latent_infections + infection_hosp_rate_t = infection_hosp_rate.array * latent_infections ( infection_to_admission_interval, @@ -184,20 +183,20 @@ def sample( latent_hospital_admissions = jnp.convolve( infection_hosp_rate_t, - infection_to_admission_interval, + infection_to_admission_interval.array, mode="full", )[: infection_hosp_rate_t.shape[0]] # Applying the day of the week effect latent_hospital_admissions = ( latent_hospital_admissions - * self.day_of_week_effect_rv.sample(**kwargs)[0] + * self.day_of_week_effect_rv.sample(**kwargs)[0].array ) # Applying probability of hospitalization effect latent_hospital_admissions = ( latent_hospital_admissions - * self.hosp_report_prob_rv.sample(**kwargs)[0] + * self.hosp_report_prob_rv.sample(**kwargs)[0].array ) npro.deterministic( @@ -205,5 +204,10 @@ def sample( ) return HospitalAdmissionsSample( - infection_hosp_rate, latent_hospital_admissions + infection_hosp_rate=infection_hosp_rate, + latent_hospital_admissions=TimeArray( + array=latent_hospital_admissions, + t_start=self.infection_to_admission_interval_rv.t_start, + t_unit=self.infection_to_admission_interval_rv.t_unit, + ) ) diff --git a/model/src/pyrenew/latent/infection_seeding_method.py b/model/src/pyrenew/latent/infection_seeding_method.py index ce1cfcfc..f79fcf1b 100644 --- a/model/src/pyrenew/latent/infection_seeding_method.py +++ b/model/src/pyrenew/latent/infection_seeding_method.py @@ -176,7 +176,10 @@ def seed_infections(self, I_pre_seed: ArrayLike): raise ValueError( f"I_pre_seed must be an array of size 1. Got size {I_pre_seed.size}." ) + (rate,) = self.rate.sample() + rate = rate.array + if rate.size != 1: raise ValueError( f"rate must be an array of size 1. Got size {rate.size}." diff --git a/model/src/pyrenew/latent/infection_seeding_process.py b/model/src/pyrenew/latent/infection_seeding_process.py index 67110a30..4ef403b6 100644 --- a/model/src/pyrenew/latent/infection_seeding_process.py +++ b/model/src/pyrenew/latent/infection_seeding_process.py @@ -2,7 +2,7 @@ # numpydoc ignore=GL08 import numpyro as npro from pyrenew.latent.infection_seeding_method import InfectionSeedMethod -from pyrenew.metaclass import RandomVariable +from pyrenew.metaclass import RandomVariable, TimeArray class InfectionSeedingProcess(RandomVariable): @@ -90,7 +90,13 @@ def sample(self) -> tuple: a tuple where the only element is an array with the number of seeded infections at each time point. """ (I_pre_seed,) = self.I_pre_seed_rv.sample() - infection_seeding = self.infection_seed_method(I_pre_seed) + infection_seeding = self.infection_seed_method(I_pre_seed.array) npro.deterministic(self.name, infection_seeding) - return (infection_seeding,) + return ( + TimeArray( + array=infection_seeding, + t_start=self.t_start, + t_unit=self.t_unit, + ), + ) diff --git a/model/src/pyrenew/latent/infections.py b/model/src/pyrenew/latent/infections.py index 400e6886..e81c0873 100644 --- a/model/src/pyrenew/latent/infections.py +++ b/model/src/pyrenew/latent/infections.py @@ -8,7 +8,7 @@ import jax.numpy as jnp import pyrenew.latent.infection_functions as inf from jax.typing import ArrayLike -from pyrenew.metaclass import RandomVariable +from pyrenew.metaclass import RandomVariable, TimeArray class InfectionsSample(NamedTuple): @@ -17,11 +17,11 @@ class InfectionsSample(NamedTuple): Attributes ---------- - post_seed_infections : ArrayLike | None, optional + post_seed_infections : TimeArray | None, optional The estimated latent infections. Defaults to None. """ - post_seed_infections: ArrayLike | None = None + post_seed_infections: TimeArray | None = None def __repr__(self): return f"InfectionsSample(post_seed_infections={self.post_seed_infections})" @@ -97,4 +97,4 @@ def sample( reversed_generation_interval_pmf=gen_int_rev, ) - return InfectionsSample(post_seed_infections) + return InfectionsSample(TimeArray(post_seed_infections)) diff --git a/model/src/pyrenew/latent/infectionswithfeedback.py b/model/src/pyrenew/latent/infectionswithfeedback.py index c6be2ec3..90f1c92d 100644 --- a/model/src/pyrenew/latent/infectionswithfeedback.py +++ b/model/src/pyrenew/latent/infectionswithfeedback.py @@ -8,7 +8,9 @@ import pyrenew.arrayutils as au import pyrenew.latent.infection_functions as inf from numpy.typing import ArrayLike -from pyrenew.metaclass import RandomVariable, _assert_sample_and_rtype +from pyrenew.metaclass import ( + RandomVariable, _assert_sample_and_rtype, TimeArray +) class InfectionsRtFeedbackSample(NamedTuple): @@ -159,6 +161,7 @@ def sample( inf_feedback_strength, *_ = self.infection_feedback_strength.sample( **kwargs, ) + inf_feedback_strength = inf_feedback_strength.array # Making sure inf_feedback_strength spans the Rt length if inf_feedback_strength.size == 1: @@ -177,7 +180,7 @@ def sample( # Sampling inf feedback pmf inf_feedback_pmf, *_ = self.infection_feedback_pmf.sample(**kwargs) - inf_fb_pmf_rev = jnp.flip(inf_feedback_pmf) + inf_fb_pmf_rev = jnp.flip(inf_feedback_pmf.array) ( post_seed_infections, @@ -195,6 +198,6 @@ def sample( npro.deterministic("Rt_adjusted", Rt_adj) return InfectionsRtFeedbackSample( - post_seed_infections=post_seed_infections, - rt=Rt_adj, + post_seed_infections=TimeArray(post_seed_infections), + rt=TimeArray(Rt_adj), ) diff --git a/model/src/pyrenew/metaclass.py b/model/src/pyrenew/metaclass.py index 08ff8360..b07d31d4 100644 --- a/model/src/pyrenew/metaclass.py +++ b/model/src/pyrenew/metaclass.py @@ -92,6 +92,29 @@ def _assert_sample_and_rtype( return None +class TimeArray(NamedTuple): + """ + A container for a time-aware array. + + Attributes + ---------- + array: ArrayLike + The data array. + t_start: int, optional + The start time of the data.. + t_unit: int, optional + The unit of time relative to the model's fundamental (smallest) time unit. + """ + array: ArrayLike + t_start: int | None = None + t_unit: int | None = None + + @staticmethod + def to_array(array: ArrayLike | "TimeArray") -> ArrayLike: + if isinstance(array, TimeArray): + return array.array + return array + class RandomVariable(metaclass=ABCMeta): """ Abstract base class for latent and observed random variables. @@ -152,6 +175,16 @@ def set_timeseries( ------- None """ + + # Either both values are None or both are not None + assert \ + (t_unit is not None and t_start is not None) or \ + (t_unit is None and t_start is None), \ + "Both t_start and t_unit should be None or not None." + + if t_unit is None and t_start is None: + return None + # Timeseries unit should be a positive integer assert isinstance( t_unit, int @@ -288,13 +321,12 @@ def sample( DistributionalRVSample """ return DistributionalRVSample( - value=jnp.atleast_1d( + value=TimeArray(jnp.atleast_1d( npro.sample( name=self.name, fn=self.dist, obs=obs, - ) - ), + ))), ) diff --git a/model/src/pyrenew/model/admissionsmodel.py b/model/src/pyrenew/model/admissionsmodel.py index a190bcfc..c5ab0b0b 100644 --- a/model/src/pyrenew/model/admissionsmodel.py +++ b/model/src/pyrenew/model/admissionsmodel.py @@ -196,7 +196,7 @@ def sample( latent_hosp_admissions, *_, ) = self.latent_hosp_admissions_rv.sample( - latent_infections=basic_model.latent_infections, + latent_infections=basic_model.latent_infections.array, **kwargs, ) i0_size = len(latent_hosp_admissions) - n_timepoints @@ -208,14 +208,14 @@ def sample( observed_hosp_admissions, *_, ) = self.hosp_admission_obs_process_rv.sample( - mu=latent_hosp_admissions[i0_size + padding :], + mu=latent_hosp_admissions.array[i0_size + padding :], obs=data_observed_hosp_admissions, **kwargs, ) else: data_observed_hosp_admissions = au.pad_x_to_match_y( - data_observed_hosp_admissions, - latent_hosp_admissions, + data_observed_hosp_admissions.array, + latent_hosp_admissions.array, jnp.nan, pad_direction="start", ) @@ -224,7 +224,7 @@ def sample( observed_hosp_admissions, *_, ) = self.hosp_admission_obs_process_rv.sample( - mu=latent_hosp_admissions[i0_size + padding :], + mu=latent_hosp_admissions.array[i0_size + padding :], obs=data_observed_hosp_admissions[i0_size + padding :], **kwargs, ) diff --git a/model/src/pyrenew/model/rtinfectionsrenewalmodel.py b/model/src/pyrenew/model/rtinfectionsrenewalmodel.py index 2b5498f8..8289556e 100644 --- a/model/src/pyrenew/model/rtinfectionsrenewalmodel.py +++ b/model/src/pyrenew/model/rtinfectionsrenewalmodel.py @@ -10,7 +10,9 @@ import pyrenew.arrayutils as au from numpy.typing import ArrayLike from pyrenew.deterministic import NullObservation -from pyrenew.metaclass import Model, RandomVariable, _assert_sample_and_rtype +from pyrenew.metaclass import ( + Model, RandomVariable, _assert_sample_and_rtype, TimeArray +) # Output class of the RtInfectionsRenewalModel @@ -204,9 +206,9 @@ def sample( I0, *_ = self.I0_rv.sample(**kwargs) # Sampling from the latent process post_seed_latent_infections, *_ = self.latent_infections_rv.sample( - Rt=Rt, - gen_int=gen_int, - I0=I0, + Rt=Rt.array, + gen_int=gen_int.array, + I0=I0.array, **kwargs, ) @@ -214,29 +216,32 @@ def sample( data_observed_infections = data_observed_infections[padding:] observed_infections, *_ = self.infection_obs_process_rv.sample( - mu=post_seed_latent_infections[padding:], + mu=post_seed_latent_infections.array[padding:], obs=data_observed_infections, **kwargs, ) - all_latent_infections = jnp.hstack([I0, post_seed_latent_infections]) + all_latent_infections = jnp.hstack( + [I0.array, post_seed_latent_infections.array], + ) + npro.deterministic("all_latent_infections", all_latent_infections) observed_infections = au.pad_x_to_match_y( - observed_infections, + observed_infections.array, all_latent_infections, jnp.nan, pad_direction="start", ) Rt = au.pad_x_to_match_y( - Rt, + Rt.array, all_latent_infections, jnp.nan, pad_direction="start", ) return RtInfectionsRenewalSample( - Rt=Rt, - latent_infections=all_latent_infections, - observed_infections=observed_infections, + Rt=TimeArray(Rt), + latent_infections=TimeArray(all_latent_infections), + observed_infections=TimeArray(observed_infections), ) diff --git a/model/src/pyrenew/observation/negativebinomial.py b/model/src/pyrenew/observation/negativebinomial.py index 48710592..064dfa22 100644 --- a/model/src/pyrenew/observation/negativebinomial.py +++ b/model/src/pyrenew/observation/negativebinomial.py @@ -8,7 +8,7 @@ import numpyro import numpyro.distributions as dist from jax.typing import ArrayLike -from pyrenew.metaclass import RandomVariable +from pyrenew.metaclass import RandomVariable, TimeArray class NegativeBinomialObservation(RandomVariable): @@ -92,14 +92,14 @@ def sample( name = self.parameter_name return ( - numpyro.sample( + TimeArray(numpyro.sample( name=name, fn=dist.NegativeBinomial2( mean=mu + self.eps, concentration=concentration, ), obs=obs, - ), + )), ) @staticmethod diff --git a/model/src/pyrenew/observation/poisson.py b/model/src/pyrenew/observation/poisson.py index c641cf76..cc4b9ae0 100644 --- a/model/src/pyrenew/observation/poisson.py +++ b/model/src/pyrenew/observation/poisson.py @@ -6,7 +6,7 @@ import numpyro import numpyro.distributions as dist from jax.typing import ArrayLike -from pyrenew.metaclass import RandomVariable +from pyrenew.metaclass import RandomVariable, TimeArray class PoissonObservation(RandomVariable): @@ -70,11 +70,11 @@ def sample( name = self.parameter_name return ( - numpyro.sample( + TimeArray(numpyro.sample( name=name, fn=dist.Poisson(rate=mu + self.eps), obs=obs, - ), + )), ) @staticmethod diff --git a/model/src/pyrenew/process/ar.py b/model/src/pyrenew/process/ar.py index 9b7e7e31..8e108c75 100644 --- a/model/src/pyrenew/process/ar.py +++ b/model/src/pyrenew/process/ar.py @@ -8,7 +8,7 @@ import numpyro.distributions as dist from jax import lax from jax.typing import ArrayLike -from pyrenew.metaclass import RandomVariable +from pyrenew.metaclass import RandomVariable, TimeArray class ARProcess(RandomVariable): @@ -91,7 +91,9 @@ def _ar_scanner(carry, next): # numpydoc ignore=GL08 ) last, ts = lax.scan(_ar_scanner, inits - self.mean, noise) - return (jnp.hstack([inits, self.mean + ts.flatten()]),) + return ( + TimeArray(jnp.hstack([inits, self.mean + ts.flatten()])), + ) @staticmethod def validate(): # numpydoc ignore=RT01 diff --git a/model/src/pyrenew/process/firstdifferencear.py b/model/src/pyrenew/process/firstdifferencear.py index ce8e9d0e..f8bace09 100644 --- a/model/src/pyrenew/process/firstdifferencear.py +++ b/model/src/pyrenew/process/firstdifferencear.py @@ -5,7 +5,7 @@ import jax.numpy as jnp from jax.typing import ArrayLike -from pyrenew.metaclass import RandomVariable +from pyrenew.metaclass import RandomVariable, TimeArray from pyrenew.process import ARProcess @@ -72,7 +72,7 @@ def sample( inits=jnp.atleast_1d(init_rate_of_change), name=name + "_rate_of_change", ) - return (init_val + jnp.cumsum(rates_of_change.flatten()),) + return (TimeArray(init_val + jnp.cumsum(rates_of_change.array.flatten())),) @staticmethod def validate(): diff --git a/model/src/pyrenew/process/rtperiodicdiff.py b/model/src/pyrenew/process/rtperiodicdiff.py index f67bd60b..40b2709f 100644 --- a/model/src/pyrenew/process/rtperiodicdiff.py +++ b/model/src/pyrenew/process/rtperiodicdiff.py @@ -4,7 +4,7 @@ import jax.numpy as jnp from jax.typing import ArrayLike from pyrenew.arrayutils import PeriodicBroadcaster -from pyrenew.metaclass import RandomVariable, _assert_sample_and_rtype +from pyrenew.metaclass import RandomVariable, _assert_sample_and_rtype, TimeArray from pyrenew.process.firstdifferencear import FirstDifferenceARProcess @@ -188,7 +188,7 @@ def sample( )[0] return RtPeriodicDiffProcessSample( - rt=self.broadcaster(jnp.exp(log_rt.flatten()), duration), + rt=TimeArray(self.broadcaster(jnp.exp(log_rt.flatten()), duration)), ) diff --git a/model/src/pyrenew/process/rtrandomwalk.py b/model/src/pyrenew/process/rtrandomwalk.py index 17d48209..44fb5e87 100644 --- a/model/src/pyrenew/process/rtrandomwalk.py +++ b/model/src/pyrenew/process/rtrandomwalk.py @@ -4,7 +4,7 @@ import numpyro as npro import numpyro.distributions as dist import pyrenew.transformation as t -from pyrenew.metaclass import RandomVariable +from pyrenew.metaclass import RandomVariable, TimeArray from pyrenew.process.simplerandomwalk import SimpleRandomWalkProcess @@ -120,6 +120,6 @@ def sample( init=Rt0_trans, ) - Rt = npro.deterministic("Rt", self.Rt_transform.inv(Rt_trans_ts)) + Rt = npro.deterministic("Rt", self.Rt_transform.inv(Rt_trans_ts.array)) - return (Rt,) + return (TimeArray(Rt),) diff --git a/model/src/pyrenew/process/simplerandomwalk.py b/model/src/pyrenew/process/simplerandomwalk.py index d2c233b3..fb9b404f 100644 --- a/model/src/pyrenew/process/simplerandomwalk.py +++ b/model/src/pyrenew/process/simplerandomwalk.py @@ -4,7 +4,7 @@ import jax.numpy as jnp import numpyro as npro import numpyro.distributions as dist -from pyrenew.metaclass import RandomVariable +from pyrenew.metaclass import RandomVariable, TimeArray class SimpleRandomWalkProcess(RandomVariable): @@ -67,7 +67,7 @@ def sample( self.error_distribution.expand((n_timepoints - 1,)), ) - return (init + jnp.cumsum(jnp.pad(diffs, [1, 0], constant_values=0)),) + return (TimeArray(init + jnp.cumsum(jnp.pad(diffs, [1, 0], constant_values=0))),) @staticmethod def validate(): diff --git a/model/src/test/test_ar_process.py b/model/src/test/test_ar_process.py index 3910b1f7..5b10cc8b 100755 --- a/model/src/test/test_ar_process.py +++ b/model/src/test/test_ar_process.py @@ -37,5 +37,6 @@ def test_ar_samples_correctly_distributed(): # check it regresses to mean # when started away from it long_ts, *_ = ar1.sample(10000, inits=ar_inits) - assert_almost_equal(long_ts[0], ar_inits) - assert jnp.abs(long_ts[-1] - ar_mean) < 4 * noise_sd + assert_almost_equal(long_ts.array[0], ar_inits) + assert jnp.abs(long_ts.array[-1] - ar_mean) < 4 * noise_sd + diff --git a/model/src/test/test_deterministic.py b/model/src/test/test_deterministic.py index ec458c8a..d51e5f52 100644 --- a/model/src/test/test_deterministic.py +++ b/model/src/test/test_deterministic.py @@ -31,7 +31,7 @@ def test_deterministic(): var5 = NullProcess() testing.assert_array_equal( - var1.sample()[0], + var1.sample()[0].array, jnp.array( [ 1, @@ -39,16 +39,16 @@ def test_deterministic(): ), ) testing.assert_array_equal( - var2.sample()[0], + var2.sample()[0].array, jnp.array([0.25, 0.25, 0.2, 0.3]), ) testing.assert_array_equal( - var3.sample(duration=5)[0], + var3.sample(duration=5)[0].array, jnp.array([1, 2, 3, 4, 4]), ) testing.assert_array_equal( - var3.sample(duration=3)[0], + var3.sample(duration=3)[0].array, jnp.array( [ 1, @@ -58,5 +58,7 @@ def test_deterministic(): ), ) - testing.assert_equal(var4.sample()[0], None) - testing.assert_equal(var5.sample(duration=1)[0], None) + testing.assert_equal(var4.sample()[0].array, None) + testing.assert_equal(var5.sample(duration=1)[0].array, None) + +test_deterministic() \ No newline at end of file diff --git a/model/src/test/test_first_difference_ar.py b/model/src/test/test_first_difference_ar.py index b6378f87..c2e6671d 100755 --- a/model/src/test/test_first_difference_ar.py +++ b/model/src/test/test_first_difference_ar.py @@ -26,5 +26,5 @@ def test_fd_ar_can_be_sampled(): ) # Checking proper shape - assert ans0[0].shape == (3532,) - assert ans1[0].shape == (3532,) + assert ans0[0].array.shape == (3532,) + assert ans1[0].array.shape == (3532,) diff --git a/model/src/test/test_infection_seeding_method.py b/model/src/test/test_infection_seeding_method.py index 527ac96f..e0697fe4 100644 --- a/model/src/test/test_infection_seeding_method.py +++ b/model/src/test/test_infection_seeding_method.py @@ -20,6 +20,9 @@ def test_seed_infections_exponential(): (I_pre_seed,) = I_pre_seed_RV.sample() (rate,) = rate_RV.sample() + I_pre_seed = I_pre_seed.array + rate = rate.array + infections_default_t_pre_seed = SeedInfectionsExponentialGrowth( n_timepoints, rate=rate_RV ).seed_infections(I_pre_seed) @@ -49,7 +52,7 @@ def test_seed_infections_exponential(): with pytest.raises(ValueError): SeedInfectionsExponentialGrowth( n_timepoints, rate=rate_RV - ).seed_infections(I_pre_seed_2) + ).seed_infections(I_pre_seed_2.array) # test non-default t_pre_seed t_pre_seed = 6 @@ -88,16 +91,16 @@ def test_seed_infections_zero_pad(): (I_pre_seed_2,) = I_pre_seed_RV_2.sample() infections_2 = SeedInfectionsZeroPad(n_timepoints).seed_infections( - I_pre_seed_2 + I_pre_seed_2.array ) testing.assert_array_equal( infections_2, - np.pad(I_pre_seed_2, (n_timepoints - I_pre_seed_2.size, 0)), + np.pad(I_pre_seed_2.array, (n_timepoints - I_pre_seed_2.array.size, 0)), ) # Check that the SeedInfectionsZeroPad class raises an error when the length of I_pre_seed is greater than n_timepoints. with pytest.raises(ValueError): - SeedInfectionsZeroPad(1).seed_infections(I_pre_seed_2) + SeedInfectionsZeroPad(1).seed_infections(I_pre_seed_2.array) def test_seed_infections_from_vec(): diff --git a/model/src/test/test_infectionsrtfeedback.py b/model/src/test/test_infectionsrtfeedback.py index 9467c116..742ced9c 100644 --- a/model/src/test/test_infectionsrtfeedback.py +++ b/model/src/test/test_infectionsrtfeedback.py @@ -95,7 +95,7 @@ def test_infectionsrtfeedback(): ) assert_array_equal(samp1.post_seed_infections, samp2.post_seed_infections) - assert_array_equal(samp1.rt, Rt) + assert_array_equal(samp1.rt.array, Rt) return None @@ -139,16 +139,16 @@ def test_infectionsrtfeedback_feedback(): gen_int=gen_int, Rt=Rt, I0=I0, - inf_feedback_strength=inf_feed_strength.sample()[0], - inf_feedback_pmf=inf_feedback_pmf.sample()[0], + inf_feedback_strength=inf_feed_strength.sample()[0].array, + inf_feedback_pmf=inf_feedback_pmf.sample()[0].array, ) assert not jnp.array_equal( - samp1.post_seed_infections, samp2.post_seed_infections + samp1.post_seed_infections.array, samp2.post_seed_infections.array ) assert_array_almost_equal( - samp1.post_seed_infections, res["post_seed_infections"] + samp1.post_seed_infections.array, res["post_seed_infections"] ) - assert_array_almost_equal(samp1.rt, res["rt"]) + assert_array_almost_equal(samp1.rt.array, res["rt"]) return None diff --git a/model/src/test/test_latent_admissions.py b/model/src/test/test_latent_admissions.py index e591e083..fb3ff571 100644 --- a/model/src/test/test_latent_admissions.py +++ b/model/src/test/test_latent_admissions.py @@ -36,7 +36,7 @@ def test_admissions_sample(): inf1 = Infections() with npro.handlers.seed(rng_seed=np.random.randint(1, 600)): - inf_sampled1 = inf1.sample(Rt=sim_rt, gen_int=gen_int, I0=i0) + inf_sampled1 = inf1.sample(Rt=sim_rt.array, gen_int=gen_int, I0=i0) # Testing the hospital admissions inf_hosp = DeterministicPMF( @@ -73,9 +73,9 @@ def test_admissions_sample(): ) with npro.handlers.seed(rng_seed=np.random.randint(1, 600)): - sim_hosp_1 = hosp1.sample(latent_infections=inf_sampled1[0]) + sim_hosp_1 = hosp1.sample(latent_infections=inf_sampled1[0].array) testing.assert_array_less( - sim_hosp_1.latent_hospital_admissions, - inf_sampled1[0], + sim_hosp_1.latent_hospital_admissions.array, + inf_sampled1[0].array, ) diff --git a/model/src/test/test_latent_infections.py b/model/src/test/test_latent_infections.py index 1df985d9..cdce7fef 100755 --- a/model/src/test/test_latent_infections.py +++ b/model/src/test/test_latent_infections.py @@ -41,7 +41,7 @@ def test_infections_as_deterministic(): inf_sampled2 = inf1.sample(**obs) testing.assert_array_equal( - inf_sampled1.post_seed_infections, inf_sampled2.post_seed_infections + inf_sampled1.post_seed_infections.array, inf_sampled2.post_seed_infections.array ) # Check that Initial infections vector must be at least as long as the generation interval. diff --git a/model/src/test/test_model_basic_renewal.py b/model/src/test/test_model_basic_renewal.py index c8c53ec3..cb7fb15b 100644 --- a/model/src/test/test_model_basic_renewal.py +++ b/model/src/test/test_model_basic_renewal.py @@ -148,13 +148,13 @@ def test_model_basicrenewal_no_obs_model(): with npro.handlers.seed(rng_seed=np.random.randint(1, 600)): model1_samp = model0.sample(n_timepoints_to_simulate=30) - np.testing.assert_array_equal(model0_samp.Rt, model1_samp.Rt) + np.testing.assert_array_equal(model0_samp.Rt.array, model1_samp.Rt.array) np.testing.assert_array_equal( - model0_samp.latent_infections, model1_samp.latent_infections + model0_samp.latent_infections.array, model1_samp.latent_infections.array ) np.testing.assert_array_equal( - model0_samp.observed_infections, - model1_samp.observed_infections, + model0_samp.observed_infections.array, + model1_samp.observed_infections.array, ) model0.run( @@ -271,7 +271,7 @@ def test_model_basicrenewal_padding() -> None: # numpydoc ignore=GL08 model1_samp = model1.sample(n_timepoints_to_simulate=30) new_obs = jnp.hstack( - [jnp.repeat(jnp.nan, 5), model1_samp.observed_infections[5:]], + [jnp.repeat(jnp.nan, 5), model1_samp.observed_infections.array[5:]], ) model1.run( diff --git a/model/src/test/test_observation_negativebinom.py b/model/src/test/test_observation_negativebinom.py index a90408b5..e6862823 100644 --- a/model/src/test/test_observation_negativebinom.py +++ b/model/src/test/test_observation_negativebinom.py @@ -21,8 +21,8 @@ def test_negativebinom_deterministic_obs(): sim_pois2 = negb.sample(mu=rates, obs=rates) testing.assert_array_equal( - sim_pois1, - sim_pois2, + sim_pois1.array, + sim_pois2.array, ) @@ -40,7 +40,7 @@ def test_negativebinom_random_obs(): sim_pois2 = negb.sample(mu=rates) testing.assert_array_almost_equal( - np.mean(sim_pois1), - np.mean(sim_pois2), + np.mean(sim_pois1.array), + np.mean(sim_pois2.array), decimal=1, ) diff --git a/model/src/test/test_random_walk.py b/model/src/test/test_random_walk.py index 9f1335e1..d8f94dce 100755 --- a/model/src/test/test_random_walk.py +++ b/model/src/test/test_random_walk.py @@ -20,8 +20,8 @@ def test_rw_can_be_sampled(): ans1 = rw_normal.sample(5023) # check that the samples are of the right shape - assert ans0[0].shape == (3532,) - assert ans1[0].shape == (5023,) + assert ans0[0].array.shape == (3532,) + assert ans1[0].array.shape == (5023,) def test_rw_samples_correctly_distributed(): @@ -38,6 +38,7 @@ def test_rw_samples_correctly_distributed(): init_arr = jnp.array([532.0]) with numpyro.handlers.seed(rng_seed=62): samples, *_ = rw_normal.sample(n_samples, init=init_arr) + samples = samples.array # Checking the shape assert samples.shape == (n_samples,) diff --git a/model/src/test/test_rtperiodicdiff.py b/model/src/test/test_rtperiodicdiff.py index 1b4fedf9..886f8320 100644 --- a/model/src/test/test_rtperiodicdiff.py +++ b/model/src/test/test_rtperiodicdiff.py @@ -66,7 +66,7 @@ def test_rtweeklydiff() -> None: np.random.seed(223) with npro.handlers.seed(rng_seed=np.random.randint(1, 600)): - rt = rtwd.sample(duration=duration).rt + rt = rtwd.sample(duration=duration).rt.array # Checking that the shape of the sampled Rt is correct assert rt.shape == (duration,) From 997812c8ed78f6788dedeeddf8d21cdd48eecd20 Mon Sep 17 00:00:00 2001 From: George Vega Yon Date: Wed, 10 Jul 2024 17:20:03 -0600 Subject: [PATCH 03/41] Working on tests (still 16 fails) [skip ci] expected to fail --- .../src/test/test_infection_seeding_method.py | 8 +++++--- model/src/test/test_infectionsrtfeedback.py | 2 +- model/src/test/test_latent_infections.py | 2 +- model/src/test/test_model_basic_renewal.py | 2 +- model/src/test/test_model_hospitalizations.py | 18 +++++++++--------- 5 files changed, 17 insertions(+), 15 deletions(-) diff --git a/model/src/test/test_infection_seeding_method.py b/model/src/test/test_infection_seeding_method.py index e0697fe4..aaad4f5e 100644 --- a/model/src/test/test_infection_seeding_method.py +++ b/model/src/test/test_infection_seeding_method.py @@ -77,6 +77,7 @@ def test_seed_infections_zero_pad(): n_timepoints = 10 I_pre_seed_RV = DeterministicVariable(10.0, name="I_pre_seed_RV") (I_pre_seed,) = I_pre_seed_RV.sample() + I_pre_seed = I_pre_seed.array infections = SeedInfectionsZeroPad(n_timepoints).seed_infections( I_pre_seed @@ -89,18 +90,19 @@ def test_seed_infections_zero_pad(): np.array([10.0, 10.0]), name="I_pre_seed_RV" ) (I_pre_seed_2,) = I_pre_seed_RV_2.sample() + I_pre_seed_2 = I_pre_seed_2.array infections_2 = SeedInfectionsZeroPad(n_timepoints).seed_infections( - I_pre_seed_2.array + I_pre_seed_2 ) testing.assert_array_equal( infections_2, - np.pad(I_pre_seed_2.array, (n_timepoints - I_pre_seed_2.array.size, 0)), + np.pad(I_pre_seed_2, (n_timepoints - I_pre_seed_2.size, 0)), ) # Check that the SeedInfectionsZeroPad class raises an error when the length of I_pre_seed is greater than n_timepoints. with pytest.raises(ValueError): - SeedInfectionsZeroPad(1).seed_infections(I_pre_seed_2.array) + SeedInfectionsZeroPad(1).seed_infections(I_pre_seed_2) def test_seed_infections_from_vec(): diff --git a/model/src/test/test_infectionsrtfeedback.py b/model/src/test/test_infectionsrtfeedback.py index 742ced9c..d1e5acc8 100644 --- a/model/src/test/test_infectionsrtfeedback.py +++ b/model/src/test/test_infectionsrtfeedback.py @@ -94,7 +94,7 @@ def test_infectionsrtfeedback(): I0=I0, ) - assert_array_equal(samp1.post_seed_infections, samp2.post_seed_infections) + assert_array_equal(samp1.post_seed_infections.array, samp2.post_seed_infections.array) assert_array_equal(samp1.rt.array, Rt) return None diff --git a/model/src/test/test_latent_infections.py b/model/src/test/test_latent_infections.py index cdce7fef..322c8932 100755 --- a/model/src/test/test_latent_infections.py +++ b/model/src/test/test_latent_infections.py @@ -32,7 +32,7 @@ def test_infections_as_deterministic(): inf1 = Infections() obs = dict( - Rt=sim_rt, + Rt=sim_rt.array, I0=jnp.zeros(gen_int.size), gen_int=gen_int, ) diff --git a/model/src/test/test_model_basic_renewal.py b/model/src/test/test_model_basic_renewal.py index cb7fb15b..4fd2e787 100644 --- a/model/src/test/test_model_basic_renewal.py +++ b/model/src/test/test_model_basic_renewal.py @@ -220,7 +220,7 @@ def test_model_basicrenewal_with_obs_model(): num_warmup=500, num_samples=500, rng_key=jr.key(22), - data_observed_infections=model1_samp.observed_infections, + data_observed_infections=model1_samp.observed_infections.array, ) inf = model1.spread_draws(["all_latent_infections"]) diff --git a/model/src/test/test_model_hospitalizations.py b/model/src/test/test_model_hospitalizations.py index d6d5c023..74a333ed 100644 --- a/model/src/test/test_model_hospitalizations.py +++ b/model/src/test/test_model_hospitalizations.py @@ -259,26 +259,26 @@ def test_model_hosp_no_obs_model(): with npro.handlers.seed(rng_seed=np.random.randint(1, 600)): model1_samp = model0.sample(n_timepoints_to_simulate=30) - np.testing.assert_array_equal(model0_samp.Rt, model1_samp.Rt) + np.testing.assert_array_equal(model0_samp.Rt.array, model1_samp.Rt.array) np.testing.assert_array_equal( - model0_samp.latent_infections, model1_samp.latent_infections + model0_samp.latent_infections.array, model1_samp.latent_infections.array ) np.testing.assert_array_equal( - model0_samp.infection_hosp_rate, model1_samp.infection_hosp_rate + model0_samp.infection_hosp_rate.array, model1_samp.infection_hosp_rate.array ) np.testing.assert_array_equal( - model0_samp.latent_hosp_admissions, model1_samp.latent_hosp_admissions + model0_samp.latent_hosp_admissions.array, model1_samp.latent_hosp_admissions.array ) np.testing.assert_array_equal( - model0_samp.observed_hosp_admissions, - model1_samp.observed_hosp_admissions, + model0_samp.observed_hosp_admissions.array, + model1_samp.observed_hosp_admissions.array, ) model0.run( num_warmup=500, num_samples=500, rng_key=jr.key(272), - data_observed_hosp_admissions=model0_samp.latent_hosp_admissions, + data_observed_hosp_admissions=model0_samp.latent_hosp_admissions.array, ) inf = model0.spread_draws(["latent_hospital_admissions"]) @@ -368,7 +368,7 @@ def test_model_hosp_with_obs_model(): num_warmup=500, num_samples=500, rng_key=jr.key(272), - data_observed_hosp_admissions=model1_samp.observed_hosp_admissions, + data_observed_hosp_admissions=model1_samp.observed_hosp_admissions.array, ) inf = model1.spread_draws(["latent_hospital_admissions"]) @@ -469,7 +469,7 @@ def test_model_hosp_with_obs_model_weekday_phosp_2(): num_warmup=500, num_samples=500, rng_key=jr.key(272), - data_observed_hosp_admissions=model1_samp.observed_hosp_admissions, + data_observed_hosp_admissions=model1_samp.observed_hosp_admissions.array, ) inf = model1.spread_draws(["latent_hospital_admissions"]) From 0ddd03ea0c669f6969716663425c7074131beafa Mon Sep 17 00:00:00 2001 From: George Vega Yon Date: Fri, 12 Jul 2024 09:48:05 -0600 Subject: [PATCH 04/41] Down to 5 errors --- model/src/pyrenew/metaclass.py | 4 ++-- model/src/pyrenew/model/admissionsmodel.py | 2 +- model/src/pyrenew/process/rtperiodicdiff.py | 8 ++++---- model/src/test/test_model_hospitalizations.py | 14 +++++++------- model/src/test/test_observation_negativebinom.py | 10 +++++----- model/src/test/test_rtperiodicdiff.py | 10 +++++----- 6 files changed, 24 insertions(+), 24 deletions(-) diff --git a/model/src/pyrenew/metaclass.py b/model/src/pyrenew/metaclass.py index b07d31d4..0fbb561c 100644 --- a/model/src/pyrenew/metaclass.py +++ b/model/src/pyrenew/metaclass.py @@ -242,11 +242,11 @@ class DistributionalRVSample(NamedTuple): Attributes ---------- - value : ArrayLike + value : TimeArray Sampled value from the distribution. """ - value: ArrayLike | None = None + value: TimeArray | None = None def __repr__(self) -> str: """ diff --git a/model/src/pyrenew/model/admissionsmodel.py b/model/src/pyrenew/model/admissionsmodel.py index c5ab0b0b..45d465e8 100644 --- a/model/src/pyrenew/model/admissionsmodel.py +++ b/model/src/pyrenew/model/admissionsmodel.py @@ -214,7 +214,7 @@ def sample( ) else: data_observed_hosp_admissions = au.pad_x_to_match_y( - data_observed_hosp_admissions.array, + data_observed_hosp_admissions, latent_hosp_admissions.array, jnp.nan, pad_direction="start", diff --git a/model/src/pyrenew/process/rtperiodicdiff.py b/model/src/pyrenew/process/rtperiodicdiff.py index 40b2709f..bc684119 100644 --- a/model/src/pyrenew/process/rtperiodicdiff.py +++ b/model/src/pyrenew/process/rtperiodicdiff.py @@ -172,9 +172,9 @@ def sample( """ # Initial sample - log_rt_prior = self.log_rt_prior.sample(**kwargs)[0] - b = self.autoreg.sample(**kwargs)[0] - s_r = self.periodic_diff_sd.sample(**kwargs)[0] + log_rt_prior = self.log_rt_prior.sample(**kwargs)[0].array + b = self.autoreg.sample(**kwargs)[0].array + s_r = self.periodic_diff_sd.sample(**kwargs)[0].array # How many periods to sample? n_periods = int(jnp.ceil(duration / self.period_size)) @@ -188,7 +188,7 @@ def sample( )[0] return RtPeriodicDiffProcessSample( - rt=TimeArray(self.broadcaster(jnp.exp(log_rt.flatten()), duration)), + rt=TimeArray(self.broadcaster(jnp.exp(log_rt.array.flatten()), duration)), ) diff --git a/model/src/test/test_model_hospitalizations.py b/model/src/test/test_model_hospitalizations.py index 74a333ed..abccf230 100644 --- a/model/src/test/test_model_hospitalizations.py +++ b/model/src/test/test_model_hospitalizations.py @@ -21,7 +21,7 @@ InfectionSeedingProcess, SeedInfectionsZeroPad, ) -from pyrenew.metaclass import DistributionalRV, RandomVariable +from pyrenew.metaclass import DistributionalRV, RandomVariable, TimeArray from pyrenew.model import HospitalAdmissionsModel from pyrenew.observation import PoissonObservation from pyrenew.process import RtRandomWalkProcess @@ -39,7 +39,7 @@ def validate(self): # numpydoc ignore=GL08 def sample(self, **kwargs): # numpydoc ignore=GL08 return ( - npro.sample(name=self.name, fn=dist.Uniform(high=0.99, low=0.01)), + TimeArray(npro.sample(name=self.name, fn=dist.Uniform(high=0.99, low=0.01))), ) @@ -269,10 +269,10 @@ def test_model_hosp_no_obs_model(): np.testing.assert_array_equal( model0_samp.latent_hosp_admissions.array, model1_samp.latent_hosp_admissions.array ) - np.testing.assert_array_equal( - model0_samp.observed_hosp_admissions.array, - model1_samp.observed_hosp_admissions.array, - ) + + # These are supposed to be none, both + assert model0_samp.observed_hosp_admissions is None + assert model1_samp.observed_hosp_admissions is None model0.run( num_warmup=500, @@ -383,6 +383,7 @@ def test_model_hosp_with_obs_model(): assert inf_mean.to_numpy().shape[0] == 500 + def test_model_hosp_with_obs_model_weekday_phosp_2(): """ Checks that the random Hospitalization model runs @@ -483,7 +484,6 @@ def test_model_hosp_with_obs_model_weekday_phosp_2(): # It should be about the MCMC inference. assert inf_mean.to_numpy().shape[0] == 500 - def test_model_hosp_with_obs_model_weekday_phosp(): """ Checks that the random Hospitalization model runs diff --git a/model/src/test/test_observation_negativebinom.py b/model/src/test/test_observation_negativebinom.py index e6862823..0fbc94f9 100644 --- a/model/src/test/test_observation_negativebinom.py +++ b/model/src/test/test_observation_negativebinom.py @@ -17,8 +17,8 @@ def test_negativebinom_deterministic_obs(): np.random.seed(223) rates = np.random.randint(1, 5, size=10) with npro.handlers.seed(rng_seed=np.random.randint(1, 600)): - sim_pois1 = negb.sample(mu=rates, obs=rates) - sim_pois2 = negb.sample(mu=rates, obs=rates) + sim_pois1, *_ = negb.sample(mu=rates, obs=rates) + sim_pois2, *_ = negb.sample(mu=rates, obs=rates) testing.assert_array_equal( sim_pois1.array, @@ -36,11 +36,11 @@ def test_negativebinom_random_obs(): np.random.seed(223) rates = np.repeat(5, 20000) with npro.handlers.seed(rng_seed=np.random.randint(1, 600)): - sim_pois1 = negb.sample(mu=rates) - sim_pois2 = negb.sample(mu=rates) + sim_pois1, *_ = negb.sample(mu=rates) + sim_pois2, *_ = negb.sample(mu=rates) testing.assert_array_almost_equal( np.mean(sim_pois1.array), np.mean(sim_pois2.array), decimal=1, - ) + ) \ No newline at end of file diff --git a/model/src/test/test_rtperiodicdiff.py b/model/src/test/test_rtperiodicdiff.py index 886f8320..a530498d 100644 --- a/model/src/test/test_rtperiodicdiff.py +++ b/model/src/test/test_rtperiodicdiff.py @@ -81,7 +81,7 @@ def test_rtweeklydiff() -> None: params["offset"] = 5 rtwd = RtWeeklyDiffProcess(**params) with npro.handlers.seed(rng_seed=np.random.randint(1, 600)): - rt2 = rtwd.sample(duration=duration).rt + rt2 = rtwd.sample(duration=duration).rt.array # Checking that the shape of the sampled Rt is correct assert rt2.shape == (duration,) @@ -114,7 +114,7 @@ def test_rtweeklydiff_no_autoregressive() -> None: np.random.seed(223) duration = 1000 with npro.handlers.seed(rng_seed=np.random.randint(1, 600)): - rt = rtwd.sample(duration=duration).rt + rt = rtwd.sample(duration=duration).rt.array # Checking that the shape of the sampled Rt is correct assert rt.shape == (duration,) @@ -153,12 +153,12 @@ def test_rtweeklydiff_manual_reconstruction() -> None: _, ans0 = lax.scan( f=rtwd.autoreg_process, - init=np.hstack([params["log_rt_prior"].sample()[0], b]), + init=np.hstack([params["log_rt_prior"].sample()[0].array, b]), xs=noise, ) ans1 = _manual_rt_weekly_diff( - log_seed=params["log_rt_prior"].sample()[0], sd=noise, b=b + log_seed=params["log_rt_prior"].sample()[0].array, sd=noise, b=b ) assert_array_equal(ans0, ans1) @@ -185,7 +185,7 @@ def test_rtperiodicdiff_smallsample(): np.random.seed(223) with npro.handlers.seed(rng_seed=np.random.randint(1, 600)): - rt = rtwd.sample(duration=6).rt + rt = rtwd.sample(duration=6).rt.array # Checking that the shape of the sampled Rt is correct assert rt.shape == (6,) From 015e38196ece96e46604eae0a735caea376a90d5 Mon Sep 17 00:00:00 2001 From: George Vega Yon Date: Fri, 12 Jul 2024 10:04:27 -0600 Subject: [PATCH 05/41] Down to 1 error --- model/src/pyrenew/process/periodiceffect.py | 10 +++++----- model/src/test/test_observation_poisson.py | 2 +- model/src/test/test_periodiceffect.py | 11 +++++++---- model/src/test/test_random_key.py | 2 +- 4 files changed, 14 insertions(+), 11 deletions(-) diff --git a/model/src/pyrenew/process/periodiceffect.py b/model/src/pyrenew/process/periodiceffect.py index ae3a570e..e12065cc 100644 --- a/model/src/pyrenew/process/periodiceffect.py +++ b/model/src/pyrenew/process/periodiceffect.py @@ -4,7 +4,7 @@ import jax.numpy as jnp import pyrenew.arrayutils as au -from pyrenew.metaclass import RandomVariable, _assert_sample_and_rtype +from pyrenew.metaclass import RandomVariable, _assert_sample_and_rtype, TimeArray class PeriodicEffectSample(NamedTuple): @@ -14,7 +14,7 @@ class PeriodicEffectSample(NamedTuple): Attributes ---------- - value: jnp.ndarray + value: TimeArray The sampled value. """ @@ -110,10 +110,10 @@ def sample(self, duration: int, **kwargs): """ return PeriodicEffectSample( - value=self.broadcaster( - data=self.quantity_to_broadcast.sample(**kwargs)[0], + value=TimeArray(self.broadcaster( + data=self.quantity_to_broadcast.sample(**kwargs)[0].array, n_timepoints=duration, - ) + )) ) diff --git a/model/src/test/test_observation_poisson.py b/model/src/test/test_observation_poisson.py index fee5bbac..28f4fbbc 100644 --- a/model/src/test/test_observation_poisson.py +++ b/model/src/test/test_observation_poisson.py @@ -20,4 +20,4 @@ def test_poisson_obs(): with npro.handlers.seed(rng_seed=np.random.randint(1, 600)): sim_pois, *_ = pois.sample(mu=rates) - testing.assert_array_equal(sim_pois, jnp.ceil(sim_pois)) + testing.assert_array_equal(sim_pois.array, jnp.ceil(sim_pois.array)) diff --git a/model/src/test/test_periodiceffect.py b/model/src/test/test_periodiceffect.py index e78e1c43..688dd01b 100644 --- a/model/src/test/test_periodiceffect.py +++ b/model/src/test/test_periodiceffect.py @@ -30,7 +30,7 @@ def test_periodiceffect() -> None: np.random.seed(223) with npro.handlers.seed(rng_seed=np.random.randint(1, 600)): - ans = pe.sample(duration=duration).value + ans = pe.sample(duration=duration).value.array # Checking that the shape of the sampled Rt is correct assert ans.shape == (duration,) @@ -44,7 +44,7 @@ def test_periodiceffect() -> None: params["offset"] = 5 pe = PeriodicEffect(**params) with npro.handlers.seed(rng_seed=np.random.randint(1, 600)): - ans2 = pe.sample(duration=duration).value + ans2 = pe.sample(duration=duration).value.array # Checking that the shape of the sampled Rt is correct assert ans2.shape == (duration,) @@ -81,9 +81,12 @@ def test_weeklyeffect() -> None: pe = PeriodicEffect(**params) pe2 = DayOfWeekEffect(**params2) - ans1 = pe.sample(duration=duration).value - ans2 = pe2.sample(duration=duration).value + ans1 = pe.sample(duration=duration).value.array + ans2 = pe2.sample(duration=duration).value.array assert_array_equal(ans1, ans2) return None + +test_periodiceffect() +test_weeklyeffect() \ No newline at end of file diff --git a/model/src/test/test_random_key.py b/model/src/test/test_random_key.py index f173012c..bc23ef79 100644 --- a/model/src/test/test_random_key.py +++ b/model/src/test/test_random_key.py @@ -100,7 +100,7 @@ def test_rng_keys_produce_correct_samples(): model_sample = models[0].sample( n_timepoints_to_simulate=n_timepoints_to_simulate[0] ) - obs_infections = [model_sample.observed_infections] * len(models) + obs_infections = [model_sample.observed_infections.array] * len(models) rng_keys = [jr.key(54), jr.key(54), None, None, jr.key(74)] # run test models with the different keys From e07329420ae17b873df57494ec49b5fb17602d45 Mon Sep 17 00:00:00 2001 From: George Vega Yon Date: Wed, 17 Jul 2024 09:54:06 -0600 Subject: [PATCH 06/41] Fixing test. Next: merge conflicts --- model/src/pyrenew/model/admissionsmodel.py | 2 +- model/src/test/test_model_hospitalizations.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/model/src/pyrenew/model/admissionsmodel.py b/model/src/pyrenew/model/admissionsmodel.py index 45d465e8..7c01ebe4 100644 --- a/model/src/pyrenew/model/admissionsmodel.py +++ b/model/src/pyrenew/model/admissionsmodel.py @@ -199,7 +199,7 @@ def sample( latent_infections=basic_model.latent_infections.array, **kwargs, ) - i0_size = len(latent_hosp_admissions) - n_timepoints + i0_size = len(latent_hosp_admissions.array) - n_timepoints if self.hosp_admission_obs_process_rv is None: observed_hosp_admissions = None else: diff --git a/model/src/test/test_model_hospitalizations.py b/model/src/test/test_model_hospitalizations.py index abccf230..5a4772a9 100644 --- a/model/src/test/test_model_hospitalizations.py +++ b/model/src/test/test_model_hospitalizations.py @@ -580,7 +580,7 @@ def test_model_hosp_with_obs_model_weekday_phosp(): obs = jnp.hstack( [ jnp.repeat(jnp.nan, pad_size), - model1_samp.observed_hosp_admissions[pad_size:], + model1_samp.observed_hosp_admissions.array[pad_size:], ] ) # Running with padding From fb8588591a324c55324c0e52791aeb96a9772877 Mon Sep 17 00:00:00 2001 From: "George G. Vega Yon" Date: Wed, 17 Jul 2024 10:40:18 -0600 Subject: [PATCH 07/41] Typo [skip ci] --- model/src/pyrenew/deterministic/deterministic.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/model/src/pyrenew/deterministic/deterministic.py b/model/src/pyrenew/deterministic/deterministic.py index 092c793d..956104d1 100644 --- a/model/src/pyrenew/deterministic/deterministic.py +++ b/model/src/pyrenew/deterministic/deterministic.py @@ -90,7 +90,7 @@ def sample( Returns ------- tuple - Containing the stored values during construction wrapped in a TimeArrayß. + Containing the stored values during construction wrapped in a TimeArray. """ if record: npro.deterministic(self.name, self.vars) From 68f086704e268439bf618b3bf756d45a451cfeae Mon Sep 17 00:00:00 2001 From: George Vega Yon Date: Wed, 17 Jul 2024 12:11:31 -0600 Subject: [PATCH 08/41] Fixing tutorials --- docs/source/tutorials/basic_renewal_model.qmd | 6 ++--- docs/source/tutorials/extending_pyrenew.qmd | 22 ++++++++++--------- .../tutorials/hospital_admissions_model.qmd | 4 ++-- docs/source/tutorials/periodic_effects.qmd | 4 ++-- docs/source/tutorials/pyrenew_demo.qmd | 12 +++++----- 5 files changed, 25 insertions(+), 23 deletions(-) diff --git a/docs/source/tutorials/basic_renewal_model.qmd b/docs/source/tutorials/basic_renewal_model.qmd index 2e324a08..2ac6c128 100644 --- a/docs/source/tutorials/basic_renewal_model.qmd +++ b/docs/source/tutorials/basic_renewal_model.qmd @@ -190,11 +190,11 @@ import matplotlib.pyplot as plt fig, axs = plt.subplots(1, 2) # Rt plot -axs[0].plot(sim_data.Rt) +axs[0].plot(sim_data.Rt.array) axs[0].set_ylabel("Rt") # Infections plot -axs[1].plot(sim_data.observed_infections) +axs[1].plot(sim_data.observed_infections.array) axs[1].set_ylabel("Infections") fig.suptitle("Basic renewal model") @@ -212,7 +212,7 @@ import jax model1.run( num_warmup=2000, num_samples=1000, - data_observed_infections=sim_data.observed_infections, + data_observed_infections=sim_data.observed_infections.array, rng_key=jax.random.PRNGKey(54), mcmc_args=dict( progress_bar=False, num_chains=2, chain_method="sequential" diff --git a/docs/source/tutorials/extending_pyrenew.qmd b/docs/source/tutorials/extending_pyrenew.qmd index 468bd1b9..db887b3b 100644 --- a/docs/source/tutorials/extending_pyrenew.qmd +++ b/docs/source/tutorials/extending_pyrenew.qmd @@ -43,14 +43,14 @@ The following code-chunk defines the model components. Notice that for both the # | label: model-components gen_int_array = jnp.array([0.25, 0.5, 0.15, 0.1]) gen_int = DeterministicPMF(gen_int_array, name="gen_int") -feedback_strength = DeterministicVariable(0.05, name="feedback_strength") +feedback_strength = DeterministicVariable(0.01, name="feedback_strength") I0 = InfectionInitializationProcess( "I0_initialization", DistributionalRV(dist=dist.LogNormal(0, 1), name="I0"), InitializeInfectionsExponentialGrowth( gen_int_array.size, - DeterministicVariable(0.5, name="rate"), + DeterministicVariable(0.05, name="rate"), ), t_unit=1, ) @@ -96,7 +96,7 @@ with npro.handlers.seed(rng_seed=np.random.randint(1, 600)): import matplotlib.pyplot as plt fig, ax = plt.subplots() -ax.plot(model0_samp.latent_infections) +ax.plot(model0_samp.latent_infections.array) ax.set_xlabel("Time") ax.set_ylabel("Infections") plt.show() @@ -153,7 +153,7 @@ The next step is to create the actual class. The bulk of its implementation lies # | label: new-model-def # | code-line-numbers: true # Creating the class -from pyrenew.metaclass import RandomVariable +from pyrenew.metaclass import RandomVariable, TimeArray from pyrenew.latent import compute_infections_from_rt_with_feedback from pyrenew import arrayutils as au from jax.typing import ArrayLike @@ -201,12 +201,14 @@ class InfFeedback(RandomVariable): **kwargs, ) inf_feedback_strength = au.pad_x_to_match_y( - x=inf_feedback_strength, y=Rt, fill_value=inf_feedback_strength[0] + x=inf_feedback_strength.array, + y=Rt, + fill_value=inf_feedback_strength.array[0] ) # Sampling inf feedback and adjusting the shape inf_feedback_pmf, *_ = self.infection_feedback_pmf(**kwargs) - inf_fb_pmf_rev = jnp.flip(inf_feedback_pmf) + inf_fb_pmf_rev = jnp.flip(inf_feedback_pmf.array) # Generating the infections with feedback all_infections, Rt_adj = compute_infections_from_rt_with_feedback( @@ -223,8 +225,8 @@ class InfFeedback(RandomVariable): # Preparing theoutput return InfFeedbackSample( - infections=all_infections, - rt=Rt_adj, + infections=TimeArray(all_infections), + rt=TimeArray(Rt_adj), ) ``` @@ -267,8 +269,8 @@ Comparing `model0` with `model1`, these two should match: import matplotlib.pyplot as plt fig, ax = plt.subplots(ncols=2) -ax[0].plot(model0_samp.latent_infections) -ax[1].plot(model1_samp.latent_infections) +ax[0].plot(model0_samp.latent_infections.array) +ax[1].plot(model1_samp.latent_infections.array) ax[0].set_xlabel("Time (model 0)") ax[1].set_xlabel("Time (model 1)") ax[0].set_ylabel("Infections") diff --git a/docs/source/tutorials/hospital_admissions_model.qmd b/docs/source/tutorials/hospital_admissions_model.qmd index f550ea98..6d6094ee 100644 --- a/docs/source/tutorials/hospital_admissions_model.qmd +++ b/docs/source/tutorials/hospital_admissions_model.qmd @@ -223,11 +223,11 @@ import matplotlib.pyplot as plt fig, axs = plt.subplots(1, 2) # Rt plot -axs[0].plot(sim_data.Rt) +axs[0].plot(sim_data.Rt.array) axs[0].set_ylabel("Rt") # Admissions plot -axs[1].plot(sim_data.observed_hosp_admissions) +axs[1].plot(sim_data.observed_hosp_admissions.array) axs[1].set_ylabel("Admissions") axs[1].set_yscale("log") diff --git a/docs/source/tutorials/periodic_effects.qmd b/docs/source/tutorials/periodic_effects.qmd index 116bc75f..e21d7ec9 100644 --- a/docs/source/tutorials/periodic_effects.qmd +++ b/docs/source/tutorials/periodic_effects.qmd @@ -45,7 +45,7 @@ with npro.handlers.seed(rng_seed=20): # Plotting the Rt values import matplotlib.pyplot as plt -plt.step(np.arange(len(sim_data.rt)), sim_data.rt, where="post") +plt.step(np.arange(len(sim_data.rt.array)), sim_data.rt.array, where="post") plt.xlabel("Time") plt.ylabel("Rt") plt.title("Simulated Rt values") @@ -89,7 +89,7 @@ with npro.handlers.seed(rng_seed=20): # Plotting the effect values import matplotlib.pyplot as plt -plt.step(np.arange(len(sim_data.value)), sim_data.value, where="post") +plt.step(np.arange(len(sim_data.value.array)), sim_data.value.array, where="post") plt.xlabel("Time") plt.ylabel("Effect size") plt.title("Simulated Day of Week Effect values") diff --git a/docs/source/tutorials/pyrenew_demo.qmd b/docs/source/tutorials/pyrenew_demo.qmd index 858b41a6..93c85750 100644 --- a/docs/source/tutorials/pyrenew_demo.qmd +++ b/docs/source/tutorials/pyrenew_demo.qmd @@ -47,7 +47,7 @@ q = SimpleRandomWalkProcess(dist.Normal(0, 0.001)) with seed(rng_seed=np.random.randint(0, 1000)): q_samp = q(n_timepoints=100) -plt.plot(np.exp(q_samp[0])) +plt.plot(np.exp(q_samp[0].array)) ``` Next, import several additional functions from the `latent` module of the `pyrenew` package to model infections and hospital admissions. @@ -163,10 +163,10 @@ Visualizations of the single model output show (top) infections over the 30 time # | label: fig-hosp # | fig-cap: Infections fig, ax = plt.subplots(nrows=3, sharex=True) -ax[0].plot(x.latent_infections) +ax[0].plot(x.latent_infections.array) ax[0].set_ylim([1 / 5, 5]) -ax[1].plot(x.latent_hosp_admissions) -ax[2].plot(x.observed_hosp_admissions, "o") +ax[1].plot(x.latent_hosp_admissions.array) +ax[2].plot(x.observed_hosp_admissions.array, "o") for axis in ax[:-1]: axis.set_yscale("log") ``` @@ -178,7 +178,7 @@ To fit the `hospmodel` to the simulated data, we call `hospmodel.run()`, an MCMC hospmodel.run( num_warmup=1000, num_samples=1000, - data_observed_hosp_admissions=x.observed_hosp_admissions, + data_observed_hosp_admissions=x.observed_hosp_admissions.array, rng_key=jax.random.PRNGKey(54), mcmc_args=dict(progress_bar=False), ) @@ -208,7 +208,7 @@ import polars as pl fig, ax = plt.subplots(figsize=[4, 5]) -ax.plot(x[0]) +ax.plot(x[0].array) samp_ids = np.random.randint(size=25, low=0, high=999) for samp_id in samp_ids: sub_samps = samps.filter(pl.col("draw") == samp_id).sort(pl.col("time")) From eba63e9ed26d5a84808a89acc90e1bf443b9a644 Mon Sep 17 00:00:00 2001 From: George Vega Yon Date: Wed, 17 Jul 2024 15:11:11 -0600 Subject: [PATCH 09/41] Working on hosp admin tutorial (expected to fail) --- docs/source/tutorials/hospital_admissions_model.qmd | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/tutorials/hospital_admissions_model.qmd b/docs/source/tutorials/hospital_admissions_model.qmd index 6d6094ee..b98b0733 100644 --- a/docs/source/tutorials/hospital_admissions_model.qmd +++ b/docs/source/tutorials/hospital_admissions_model.qmd @@ -499,7 +499,7 @@ class DayOfWeekEffect(metaclass.RandomVariable): sample_shape=(7,), ) - return jnp.tile(ans, self.nweeks)[: self.len] + return (metaclass.TimeArray(jnp.tile(ans, self.nweeks)[: self.len]),) # Initializing the RV From b6c0ff07ca0b4b27b866ee729efee08b0150cf6f Mon Sep 17 00:00:00 2001 From: George Vega Yon Date: Wed, 17 Jul 2024 18:01:06 -0600 Subject: [PATCH 10/41] Patching tutorial --- .../tutorials/hospital_admissions_model.qmd | 23 +++++++------------ .../src/pyrenew/latent/hospitaladmissions.py | 5 +++- 2 files changed, 12 insertions(+), 16 deletions(-) diff --git a/docs/source/tutorials/hospital_admissions_model.qmd b/docs/source/tutorials/hospital_admissions_model.qmd index b98b0733..9fa81d82 100644 --- a/docs/source/tutorials/hospital_admissions_model.qmd +++ b/docs/source/tutorials/hospital_admissions_model.qmd @@ -476,21 +476,12 @@ import numpyro as npro class DayOfWeekEffect(metaclass.RandomVariable): """Day of the week effect""" - def __init__(self, len: int): - """Initialize the day of the week effect distribution - Parameters - ---------- - len : int - The number of observations - """ - self.nweeks = int(jnp.ceil(len / 7)) - self.len = len - @staticmethod def validate(): return None - def sample(self, **kwargs): + def sample(self, n_timepoints: int, **kwargs): + ans = npro.sample( name="dayofweek_effect", fn=npro.distributions.TruncatedNormal( @@ -499,14 +490,16 @@ class DayOfWeekEffect(metaclass.RandomVariable): sample_shape=(7,), ) - return (metaclass.TimeArray(jnp.tile(ans, self.nweeks)[: self.len]),) + return (metaclass.TimeArray( + jnp.tile(ans, 100)[: n_timepoints]), + ) -# Initializing the RV -dayofweek_effect = DayOfWeekEffect(dat.shape[0]) +# Initializing the RV. +dayofweek_effect = DayOfWeekEffect() ``` -Notice that the instance's `nweeks` and `len` members are passed during construction. Trying to compute the number of weeks and the length of the dataset in the `validate` method will raise a `jit` error in `jax` as the shape and size of elements are not known during the validation step, which happens before the model is run. With the new effect, we can rebuild the latent hospitalization model: +Notice that the instance's `nweeks` and `len` members are defined during construction. Trying to compute the number of weeks and the length of the dataset in the `validate` method will raise a `jit` error in `jax` as the shape and size of elements are not known during the validation step, which happens before the model is run. With the new effect, we can rebuild the latent hospitalization model: ```{python} # | label: latent-hosp-weekday diff --git a/model/src/pyrenew/latent/hospitaladmissions.py b/model/src/pyrenew/latent/hospitaladmissions.py index 1b301479..906332cd 100644 --- a/model/src/pyrenew/latent/hospitaladmissions.py +++ b/model/src/pyrenew/latent/hospitaladmissions.py @@ -191,7 +191,10 @@ def sample( # Applying the day of the week effect latent_hospital_admissions = ( latent_hospital_admissions - * self.day_of_week_effect_rv(**kwargs)[0].array + * self.day_of_week_effect_rv( + n_timepoints = latent_hospital_admissions.size, + **kwargs + )[0].array ) # Applying probability of hospitalization effect From 69d56c144993752fd3830f8f9bf0a551d84834e8 Mon Sep 17 00:00:00 2001 From: George Vega Yon Date: Thu, 18 Jul 2024 15:53:07 -0600 Subject: [PATCH 11/41] Making pre-commit happy --- docs/source/tutorials/extending_pyrenew.qmd | 2 +- .../tutorials/hospital_admissions_model.qmd | 4 +- docs/source/tutorials/periodic_effects.qmd | 4 +- docs/source/tutorials/time.qmd | 2 +- .../pyrenew/deterministic/deterministic.py | 9 +++-- .../pyrenew/deterministic/deterministicpmf.py | 6 +-- model/src/pyrenew/deterministic/nullrv.py | 2 +- model/src/pyrenew/deterministic/process.py | 11 +++--- .../src/pyrenew/latent/hospitaladmissions.py | 11 +++--- .../infection_initialization_process.py | 2 +- .../pyrenew/latent/infectionswithfeedback.py | 8 +++- model/src/pyrenew/metaclass.py | 37 +++++++++---------- .../pyrenew/model/rtinfectionsrenewalmodel.py | 5 ++- .../pyrenew/observation/negativebinomial.py | 19 +++++----- model/src/pyrenew/process/ar.py | 4 +- .../src/pyrenew/process/firstdifferencear.py | 4 +- model/src/pyrenew/process/periodiceffect.py | 16 +++++--- model/src/pyrenew/process/rtperiodicdiff.py | 10 ++++- model/src/test/test_model_basic_renewal.py | 3 +- model/src/test/test_model_hospitalizations.py | 21 ++++++++--- model/src/test/test_periodiceffect.py | 3 +- 21 files changed, 106 insertions(+), 77 deletions(-) diff --git a/docs/source/tutorials/extending_pyrenew.qmd b/docs/source/tutorials/extending_pyrenew.qmd index db887b3b..ff7a6ab9 100644 --- a/docs/source/tutorials/extending_pyrenew.qmd +++ b/docs/source/tutorials/extending_pyrenew.qmd @@ -203,7 +203,7 @@ class InfFeedback(RandomVariable): inf_feedback_strength = au.pad_x_to_match_y( x=inf_feedback_strength.array, y=Rt, - fill_value=inf_feedback_strength.array[0] + fill_value=inf_feedback_strength.array[0], ) # Sampling inf feedback and adjusting the shape diff --git a/docs/source/tutorials/hospital_admissions_model.qmd b/docs/source/tutorials/hospital_admissions_model.qmd index 9fa81d82..4aaf8690 100644 --- a/docs/source/tutorials/hospital_admissions_model.qmd +++ b/docs/source/tutorials/hospital_admissions_model.qmd @@ -490,9 +490,7 @@ class DayOfWeekEffect(metaclass.RandomVariable): sample_shape=(7,), ) - return (metaclass.TimeArray( - jnp.tile(ans, 100)[: n_timepoints]), - ) + return (metaclass.TimeArray(jnp.tile(ans, 100)[:n_timepoints]),) # Initializing the RV. diff --git a/docs/source/tutorials/periodic_effects.qmd b/docs/source/tutorials/periodic_effects.qmd index e21d7ec9..30b747bf 100644 --- a/docs/source/tutorials/periodic_effects.qmd +++ b/docs/source/tutorials/periodic_effects.qmd @@ -89,7 +89,9 @@ with npro.handlers.seed(rng_seed=20): # Plotting the effect values import matplotlib.pyplot as plt -plt.step(np.arange(len(sim_data.value.array)), sim_data.value.array, where="post") +plt.step( + np.arange(len(sim_data.value.array)), sim_data.value.array, where="post" +) plt.xlabel("Time") plt.ylabel("Effect size") plt.title("Simulated Day of Week Effect values") diff --git a/docs/source/tutorials/time.qmd b/docs/source/tutorials/time.qmd index 457c53e0..c76334c4 100644 --- a/docs/source/tutorials/time.qmd +++ b/docs/source/tutorials/time.qmd @@ -41,4 +41,4 @@ Rt_aligned, infections_aligned = align([Rt, infections]) ### Retrieving time information from sites -Since numpyro only stores Jax arrays, we cannot store the time information in the arrays themselves. Next iterations of `pyrenew` should include a way to retrieve the time information from the sites of the model after running them. \ No newline at end of file +Since numpyro only stores Jax arrays, we cannot store the time information in the arrays themselves. Next iterations of `pyrenew` should include a way to retrieve the time information from the sites of the model after running them. diff --git a/model/src/pyrenew/deterministic/deterministic.py b/model/src/pyrenew/deterministic/deterministic.py index 956104d1..d673870e 100644 --- a/model/src/pyrenew/deterministic/deterministic.py +++ b/model/src/pyrenew/deterministic/deterministic.py @@ -96,7 +96,8 @@ def sample( npro.deterministic(self.name, self.vars) return ( TimeArray( - array = self.vars, - t_start=self.t_start, - t_unit=self.t_unit, - ),) + array=self.vars, + t_start=self.t_start, + t_unit=self.t_unit, + ), + ) diff --git a/model/src/pyrenew/deterministic/deterministicpmf.py b/model/src/pyrenew/deterministic/deterministicpmf.py index 00bf4e1b..834c5b7f 100644 --- a/model/src/pyrenew/deterministic/deterministicpmf.py +++ b/model/src/pyrenew/deterministic/deterministicpmf.py @@ -42,8 +42,8 @@ def __init__( The start time of the process. t_unit : int, optional The unit of time relative to the model's fundamental (smallest) - time unit. - + time unit. + Returns ------- None @@ -58,7 +58,7 @@ def __init__( name=name, t_start=t_start, t_unit=t_unit, - ) + ) return None diff --git a/model/src/pyrenew/deterministic/nullrv.py b/model/src/pyrenew/deterministic/nullrv.py index 03e06232..97e11a2d 100644 --- a/model/src/pyrenew/deterministic/nullrv.py +++ b/model/src/pyrenew/deterministic/nullrv.py @@ -3,8 +3,8 @@ from __future__ import annotations from jax.typing import ArrayLike -from pyrenew.metaclass import TimeArray from pyrenew.deterministic.deterministic import DeterministicVariable +from pyrenew.metaclass import TimeArray class NullVariable(DeterministicVariable): diff --git a/model/src/pyrenew/deterministic/process.py b/model/src/pyrenew/deterministic/process.py index 9ac2cb7b..98cc4300 100644 --- a/model/src/pyrenew/deterministic/process.py +++ b/model/src/pyrenew/deterministic/process.py @@ -1,8 +1,8 @@ # numpydoc ignore=GL08 import jax.numpy as jnp -from pyrenew.metaclass import TimeArray from pyrenew.deterministic.deterministic import DeterministicVariable +from pyrenew.metaclass import TimeArray class DeterministicProcess(DeterministicVariable): @@ -40,15 +40,16 @@ def sample( if dif > 0: return ( TimeArray( - jnp.hstack([res.array, jnp.repeat(res.array[-1], dif)]), t_start=self.t_start, + jnp.hstack([res.array, jnp.repeat(res.array[-1], dif)]), + t_start=self.t_start, t_unit=self.t_unit, - ), - ) + ), + ) return ( TimeArray( array=res.array[:duration], t_start=self.t_start, - t_unit=self.t_unit + t_unit=self.t_unit, ), ) diff --git a/model/src/pyrenew/latent/hospitaladmissions.py b/model/src/pyrenew/latent/hospitaladmissions.py index 906332cd..cf061538 100644 --- a/model/src/pyrenew/latent/hospitaladmissions.py +++ b/model/src/pyrenew/latent/hospitaladmissions.py @@ -192,15 +192,14 @@ def sample( latent_hospital_admissions = ( latent_hospital_admissions * self.day_of_week_effect_rv( - n_timepoints = latent_hospital_admissions.size, - **kwargs - )[0].array + n_timepoints=latent_hospital_admissions.size, **kwargs + )[0].array ) # Applying probability of hospitalization effect latent_hospital_admissions = ( - latent_hospital_admissions * - self.hosp_report_prob_rv(**kwargs)[0].array + latent_hospital_admissions + * self.hosp_report_prob_rv(**kwargs)[0].array ) npro.deterministic( @@ -213,5 +212,5 @@ def sample( array=latent_hospital_admissions, t_start=self.infection_to_admission_interval_rv.t_start, t_unit=self.infection_to_admission_interval_rv.t_unit, - ) + ), ) diff --git a/model/src/pyrenew/latent/infection_initialization_process.py b/model/src/pyrenew/latent/infection_initialization_process.py index 7d832307..836d28e6 100644 --- a/model/src/pyrenew/latent/infection_initialization_process.py +++ b/model/src/pyrenew/latent/infection_initialization_process.py @@ -105,4 +105,4 @@ def sample(self) -> tuple: t_start=self.t_start, t_unit=self.t_unit, ), - ) + ) diff --git a/model/src/pyrenew/latent/infectionswithfeedback.py b/model/src/pyrenew/latent/infectionswithfeedback.py index 52bf9a0f..001834ca 100644 --- a/model/src/pyrenew/latent/infectionswithfeedback.py +++ b/model/src/pyrenew/latent/infectionswithfeedback.py @@ -9,7 +9,9 @@ import pyrenew.latent.infection_functions as inf from numpy.typing import ArrayLike from pyrenew.metaclass import ( - RandomVariable, _assert_sample_and_rtype, TimeArray + RandomVariable, + TimeArray, + _assert_sample_and_rtype, ) @@ -198,6 +200,8 @@ def sample( npro.deterministic("Rt_adjusted", Rt_adj) return InfectionsRtFeedbackSample( - post_initialization_infections=TimeArray(post_initialization_infections), + post_initialization_infections=TimeArray( + post_initialization_infections + ), rt=TimeArray(Rt_adj), ) diff --git a/model/src/pyrenew/metaclass.py b/model/src/pyrenew/metaclass.py index 9b5c95d2..71abbcfb 100644 --- a/model/src/pyrenew/metaclass.py +++ b/model/src/pyrenew/metaclass.py @@ -97,15 +97,9 @@ class TimeArray: A container for a time-aware array. """ - @staticmethod - def to_array(array: ArrayLike | "TimeArray") -> ArrayLike: - if isinstance(array, TimeArray): - return array.array - return array - def __init__( self, - array : ArrayLike | None, + array: ArrayLike | None, t_start: int | None = None, t_unit: int | None = None, ) -> None: @@ -127,7 +121,9 @@ def __init__( """ if array is not None: - assert isinstance(array, ArrayLike), "array should be an array-like object." + assert isinstance( + array, ArrayLike + ), "array should be an array-like object." self.array = array self.t_start = t_start @@ -135,6 +131,7 @@ def __init__( return None + class RandomVariable(metaclass=ABCMeta): """ Abstract base class for latent and observed random variables. @@ -197,14 +194,13 @@ def set_timeseries( """ # Either both values are None or both are not None - assert \ - (t_unit is not None and t_start is not None) or \ - (t_unit is None and t_start is None), \ - "Both t_start and t_unit should be None or not None." + assert (t_unit is not None and t_start is not None) or ( + t_unit is None and t_start is None + ), "Both t_start and t_unit should be None or not None." if t_unit is None and t_start is None: return None - + # Timeseries unit should be a positive integer assert isinstance( t_unit, int @@ -347,12 +343,15 @@ def sample( DistributionalRVSample """ return DistributionalRVSample( - value=TimeArray(jnp.atleast_1d( - npro.sample( - name=self.name, - fn=self.dist, - obs=obs, - ))), + value=TimeArray( + jnp.atleast_1d( + npro.sample( + name=self.name, + fn=self.dist, + obs=obs, + ) + ) + ), ) diff --git a/model/src/pyrenew/model/rtinfectionsrenewalmodel.py b/model/src/pyrenew/model/rtinfectionsrenewalmodel.py index da4770ac..72c20ca4 100644 --- a/model/src/pyrenew/model/rtinfectionsrenewalmodel.py +++ b/model/src/pyrenew/model/rtinfectionsrenewalmodel.py @@ -11,7 +11,10 @@ from numpy.typing import ArrayLike from pyrenew.deterministic import NullObservation from pyrenew.metaclass import ( - Model, RandomVariable, _assert_sample_and_rtype, TimeArray + Model, + RandomVariable, + TimeArray, + _assert_sample_and_rtype, ) diff --git a/model/src/pyrenew/observation/negativebinomial.py b/model/src/pyrenew/observation/negativebinomial.py index 83bd1163..40f4e7e7 100644 --- a/model/src/pyrenew/observation/negativebinomial.py +++ b/model/src/pyrenew/observation/negativebinomial.py @@ -89,14 +89,13 @@ def sample( """ concentration, *_ = self.concentration_rv.sample() - negative_binomial_sample = ( - numpyro.sample( - name=self.name, - fn=dist.NegativeBinomial2( - mean=mu + self.eps, - concentration=concentration.array, - ), - obs=obs, - )) - + negative_binomial_sample = numpyro.sample( + name=self.name, + fn=dist.NegativeBinomial2( + mean=mu + self.eps, + concentration=concentration.array, + ), + obs=obs, + ) + return (TimeArray(negative_binomial_sample),) diff --git a/model/src/pyrenew/process/ar.py b/model/src/pyrenew/process/ar.py index 8e108c75..98e04e14 100644 --- a/model/src/pyrenew/process/ar.py +++ b/model/src/pyrenew/process/ar.py @@ -91,9 +91,7 @@ def _ar_scanner(carry, next): # numpydoc ignore=GL08 ) last, ts = lax.scan(_ar_scanner, inits - self.mean, noise) - return ( - TimeArray(jnp.hstack([inits, self.mean + ts.flatten()])), - ) + return (TimeArray(jnp.hstack([inits, self.mean + ts.flatten()])),) @staticmethod def validate(): # numpydoc ignore=RT01 diff --git a/model/src/pyrenew/process/firstdifferencear.py b/model/src/pyrenew/process/firstdifferencear.py index 3680ed61..c6cde5f3 100644 --- a/model/src/pyrenew/process/firstdifferencear.py +++ b/model/src/pyrenew/process/firstdifferencear.py @@ -72,7 +72,9 @@ def sample( inits=jnp.atleast_1d(init_rate_of_change), name=name + "_rate_of_change", ) - return (TimeArray(init_val + jnp.cumsum(rates_of_change.array.flatten())),) + return ( + TimeArray(init_val + jnp.cumsum(rates_of_change.array.flatten())), + ) @staticmethod def validate(): diff --git a/model/src/pyrenew/process/periodiceffect.py b/model/src/pyrenew/process/periodiceffect.py index 01739a00..b380478c 100644 --- a/model/src/pyrenew/process/periodiceffect.py +++ b/model/src/pyrenew/process/periodiceffect.py @@ -4,7 +4,11 @@ import jax.numpy as jnp import pyrenew.arrayutils as au -from pyrenew.metaclass import RandomVariable, _assert_sample_and_rtype, TimeArray +from pyrenew.metaclass import ( + RandomVariable, + TimeArray, + _assert_sample_and_rtype, +) class PeriodicEffectSample(NamedTuple): @@ -110,10 +114,12 @@ def sample(self, duration: int, **kwargs): """ return PeriodicEffectSample( - value=TimeArray(self.broadcaster( - data=self.quantity_to_broadcast.sample(**kwargs)[0].array, - n_timepoints=duration, - )) + value=TimeArray( + self.broadcaster( + data=self.quantity_to_broadcast.sample(**kwargs)[0].array, + n_timepoints=duration, + ) + ) ) diff --git a/model/src/pyrenew/process/rtperiodicdiff.py b/model/src/pyrenew/process/rtperiodicdiff.py index 2848c662..f70d08a7 100644 --- a/model/src/pyrenew/process/rtperiodicdiff.py +++ b/model/src/pyrenew/process/rtperiodicdiff.py @@ -4,7 +4,11 @@ import jax.numpy as jnp from jax.typing import ArrayLike from pyrenew.arrayutils import PeriodicBroadcaster -from pyrenew.metaclass import RandomVariable, _assert_sample_and_rtype, TimeArray +from pyrenew.metaclass import ( + RandomVariable, + TimeArray, + _assert_sample_and_rtype, +) from pyrenew.process.firstdifferencear import FirstDifferenceARProcess @@ -188,7 +192,9 @@ def sample( )[0] return RtPeriodicDiffProcessSample( - rt=TimeArray(self.broadcaster(jnp.exp(log_rt.array.flatten()), duration)), + rt=TimeArray( + self.broadcaster(jnp.exp(log_rt.array.flatten()), duration) + ), ) diff --git a/model/src/test/test_model_basic_renewal.py b/model/src/test/test_model_basic_renewal.py index d1fa7546..c5089232 100644 --- a/model/src/test/test_model_basic_renewal.py +++ b/model/src/test/test_model_basic_renewal.py @@ -150,7 +150,8 @@ def test_model_basicrenewal_no_obs_model(): np.testing.assert_array_equal(model0_samp.Rt.array, model1_samp.Rt.array) np.testing.assert_array_equal( - model0_samp.latent_infections.array, model1_samp.latent_infections.array + model0_samp.latent_infections.array, + model1_samp.latent_infections.array, ) np.testing.assert_array_equal( model0_samp.observed_infections.array, diff --git a/model/src/test/test_model_hospitalizations.py b/model/src/test/test_model_hospitalizations.py index 3d0570fa..5888fd63 100644 --- a/model/src/test/test_model_hospitalizations.py +++ b/model/src/test/test_model_hospitalizations.py @@ -39,7 +39,11 @@ def validate(self): # numpydoc ignore=GL08 def sample(self, **kwargs): # numpydoc ignore=GL08 return ( - TimeArray(npro.sample(name=self.name, fn=dist.Uniform(high=0.99, low=0.01))), + TimeArray( + npro.sample( + name=self.name, fn=dist.Uniform(high=0.99, low=0.01) + ) + ), ) @@ -259,15 +263,20 @@ def test_model_hosp_no_obs_model(): with npro.handlers.seed(rng_seed=np.random.randint(1, 600)): model1_samp = model0.sample(n_timepoints_to_simulate=30) - np.testing.assert_array_almost_equal(model0_samp.Rt.array, model1_samp.Rt.array) + np.testing.assert_array_almost_equal( + model0_samp.Rt.array, model1_samp.Rt.array + ) np.testing.assert_array_equal( - model0_samp.latent_infections.array, model1_samp.latent_infections.array + model0_samp.latent_infections.array, + model1_samp.latent_infections.array, ) np.testing.assert_array_equal( - model0_samp.infection_hosp_rate.array, model1_samp.infection_hosp_rate.array + model0_samp.infection_hosp_rate.array, + model1_samp.infection_hosp_rate.array, ) np.testing.assert_array_equal( - model0_samp.latent_hosp_admissions.array, model1_samp.latent_hosp_admissions.array + model0_samp.latent_hosp_admissions.array, + model1_samp.latent_hosp_admissions.array, ) # These are supposed to be none, both @@ -383,7 +392,6 @@ def test_model_hosp_with_obs_model(): assert inf_mean.to_numpy().shape[0] == 500 - def test_model_hosp_with_obs_model_weekday_phosp_2(): """ Checks that the random Hospitalization model runs @@ -484,6 +492,7 @@ def test_model_hosp_with_obs_model_weekday_phosp_2(): # It should be about the MCMC inference. assert inf_mean.to_numpy().shape[0] == 500 + def test_model_hosp_with_obs_model_weekday_phosp(): """ Checks that the random Hospitalization model runs diff --git a/model/src/test/test_periodiceffect.py b/model/src/test/test_periodiceffect.py index 4e3835eb..481ecad3 100644 --- a/model/src/test/test_periodiceffect.py +++ b/model/src/test/test_periodiceffect.py @@ -88,5 +88,6 @@ def test_weeklyeffect() -> None: return None + test_periodiceffect() -test_weeklyeffect() \ No newline at end of file +test_weeklyeffect() From 0fad7e484507614da545b41a14f3cfb88b63aaf0 Mon Sep 17 00:00:00 2001 From: George Vega Yon Date: Mon, 22 Jul 2024 13:12:19 -0600 Subject: [PATCH 12/41] Addressing points by @damonbayer and @dylanhmorris --- docs/source/tutorials/hospital_admissions_model.qmd | 4 +++- model/src/pyrenew/metaclass.py | 3 ++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/docs/source/tutorials/hospital_admissions_model.qmd b/docs/source/tutorials/hospital_admissions_model.qmd index 4aaf8690..5a8443ae 100644 --- a/docs/source/tutorials/hospital_admissions_model.qmd +++ b/docs/source/tutorials/hospital_admissions_model.qmd @@ -482,6 +482,8 @@ class DayOfWeekEffect(metaclass.RandomVariable): def sample(self, n_timepoints: int, **kwargs): + nweeks = (n_timepoints // 7) + 1 + ans = npro.sample( name="dayofweek_effect", fn=npro.distributions.TruncatedNormal( @@ -490,7 +492,7 @@ class DayOfWeekEffect(metaclass.RandomVariable): sample_shape=(7,), ) - return (metaclass.TimeArray(jnp.tile(ans, 100)[:n_timepoints]),) + return (metaclass.TimeArray(jnp.tile(ans, nweeks)[:n_timepoints]),) # Initializing the RV. diff --git a/model/src/pyrenew/metaclass.py b/model/src/pyrenew/metaclass.py index 71abbcfb..55720fce 100644 --- a/model/src/pyrenew/metaclass.py +++ b/model/src/pyrenew/metaclass.py @@ -196,7 +196,8 @@ def set_timeseries( # Either both values are None or both are not None assert (t_unit is not None and t_start is not None) or ( t_unit is None and t_start is None - ), "Both t_start and t_unit should be None or not None." + ), "Both t_start and t_unit should be None or not None. " \ + "Currently, t_start is {t_start} and t_unit is {t_unit}." if t_unit is None and t_start is None: return None From 6952683a5439fd9a0de9cac92acee0f77ab50bd9 Mon Sep 17 00:00:00 2001 From: George Vega Yon Date: Tue, 23 Jul 2024 10:21:12 -0600 Subject: [PATCH 13/41] Renaming TimeArray to SampledValue --- docs/source/tutorials/extending_pyrenew.qmd | 6 +++--- .../source/tutorials/hospital_admissions_model.qmd | 2 +- docs/source/tutorials/time.qmd | 2 +- model/src/pyrenew/deterministic/deterministic.py | 6 +++--- .../src/pyrenew/deterministic/deterministicpmf.py | 2 +- model/src/pyrenew/deterministic/nullrv.py | 14 +++++++------- model/src/pyrenew/deterministic/process.py | 8 ++++---- model/src/pyrenew/latent/hospitaladmissions.py | 10 +++++----- .../latent/infection_initialization_process.py | 4 ++-- model/src/pyrenew/latent/infections.py | 6 +++--- model/src/pyrenew/latent/infectionswithfeedback.py | 6 +++--- model/src/pyrenew/metaclass.py | 10 +++++----- .../src/pyrenew/model/rtinfectionsrenewalmodel.py | 8 ++++---- model/src/pyrenew/observation/negativebinomial.py | 4 ++-- model/src/pyrenew/observation/poisson.py | 4 ++-- model/src/pyrenew/process/ar.py | 4 ++-- model/src/pyrenew/process/firstdifferencear.py | 4 ++-- model/src/pyrenew/process/periodiceffect.py | 6 +++--- model/src/pyrenew/process/rtperiodicdiff.py | 4 ++-- model/src/pyrenew/process/rtrandomwalk.py | 4 ++-- model/src/pyrenew/process/simplerandomwalk.py | 4 ++-- model/src/test/test_model_hospitalizations.py | 4 ++-- 22 files changed, 61 insertions(+), 61 deletions(-) diff --git a/docs/source/tutorials/extending_pyrenew.qmd b/docs/source/tutorials/extending_pyrenew.qmd index ff7a6ab9..126d9a7c 100644 --- a/docs/source/tutorials/extending_pyrenew.qmd +++ b/docs/source/tutorials/extending_pyrenew.qmd @@ -153,7 +153,7 @@ The next step is to create the actual class. The bulk of its implementation lies # | label: new-model-def # | code-line-numbers: true # Creating the class -from pyrenew.metaclass import RandomVariable, TimeArray +from pyrenew.metaclass import RandomVariable, SampledValue from pyrenew.latent import compute_infections_from_rt_with_feedback from pyrenew import arrayutils as au from jax.typing import ArrayLike @@ -225,8 +225,8 @@ class InfFeedback(RandomVariable): # Preparing theoutput return InfFeedbackSample( - infections=TimeArray(all_infections), - rt=TimeArray(Rt_adj), + infections=SampledValue(all_infections), + rt=SampledValue(Rt_adj), ) ``` diff --git a/docs/source/tutorials/hospital_admissions_model.qmd b/docs/source/tutorials/hospital_admissions_model.qmd index 48f13b5e..0ce74c2e 100644 --- a/docs/source/tutorials/hospital_admissions_model.qmd +++ b/docs/source/tutorials/hospital_admissions_model.qmd @@ -496,7 +496,7 @@ class DayOfWeekEffect(metaclass.RandomVariable): sample_shape=(7,), ) - return (metaclass.TimeArray(jnp.tile(ans, nweeks)[:n_timepoints]),) + return (metaclass.SampledValue(jnp.tile(ans, nweeks)[:n_timepoints]),) # Initializing the RV. diff --git a/docs/source/tutorials/time.qmd b/docs/source/tutorials/time.qmd index c76334c4..1a7c1de8 100644 --- a/docs/source/tutorials/time.qmd +++ b/docs/source/tutorials/time.qmd @@ -10,7 +10,7 @@ The fundamental time unit should represent a period of fixed (or approximately f For many infectious disease renewal models of interest, the fundamental time unit will be days, and we will proceed with this tutorial treating days as our fundamental unit. - `pyrenew` deals with time having `RandomVariable`s carry information about (i) their own time unit expressed relative to the fundamental unit (`t_unit`) and (ii) the starting time, `t_start`, measured relative to `t = 0` in model time in fundamental time units. Moreover, return values from `RandomVariable.sample()` are namedtuples with `TimeArray` objects that carry the same information. + `pyrenew` deals with time having `RandomVariable`s carry information about (i) their own time unit expressed relative to the fundamental unit (`t_unit`) and (ii) the starting time, `t_start`, measured relative to `t = 0` in model time in fundamental time units. Moreover, return values from `RandomVariable.sample()` are namedtuples with `SampledValue` objects that carry the same information. The tuple `(t_unit, t_start)` can encode different types of time series data. For example: diff --git a/model/src/pyrenew/deterministic/deterministic.py b/model/src/pyrenew/deterministic/deterministic.py index d673870e..32da4714 100644 --- a/model/src/pyrenew/deterministic/deterministic.py +++ b/model/src/pyrenew/deterministic/deterministic.py @@ -6,7 +6,7 @@ import jax.numpy as jnp import numpyro as npro from jax.typing import ArrayLike -from pyrenew.metaclass import RandomVariable, TimeArray +from pyrenew.metaclass import RandomVariable, SampledValue class DeterministicVariable(RandomVariable): @@ -90,12 +90,12 @@ def sample( Returns ------- tuple - Containing the stored values during construction wrapped in a TimeArray. + Containing the stored values during construction wrapped in a SampledValue. """ if record: npro.deterministic(self.name, self.vars) return ( - TimeArray( + SampledValue( array=self.vars, t_start=self.t_start, t_unit=self.t_unit, diff --git a/model/src/pyrenew/deterministic/deterministicpmf.py b/model/src/pyrenew/deterministic/deterministicpmf.py index 834c5b7f..a44224ae 100644 --- a/model/src/pyrenew/deterministic/deterministicpmf.py +++ b/model/src/pyrenew/deterministic/deterministicpmf.py @@ -94,7 +94,7 @@ def sample( Returns ------- tuple - Containing the stored values during construction wrapped in a TimeArray. + Containing the stored values during construction wrapped in a SampledValue. """ return self.basevar.sample(**kwargs) diff --git a/model/src/pyrenew/deterministic/nullrv.py b/model/src/pyrenew/deterministic/nullrv.py index 97e11a2d..119531d0 100644 --- a/model/src/pyrenew/deterministic/nullrv.py +++ b/model/src/pyrenew/deterministic/nullrv.py @@ -4,7 +4,7 @@ from jax.typing import ArrayLike from pyrenew.deterministic.deterministic import DeterministicVariable -from pyrenew.metaclass import TimeArray +from pyrenew.metaclass import SampledValue class NullVariable(DeterministicVariable): @@ -47,10 +47,10 @@ def sample( Returns ------- tuple - Containing a TimeArray with None. + Containing a SampledValue with None. """ - return (TimeArray(None),) + return (SampledValue(None),) class NullProcess(NullVariable): @@ -96,10 +96,10 @@ def sample( Returns ------- tuple - Containing a TimeArray with None. + Containing a SampledValue with None. """ - return (TimeArray(None),) + return (SampledValue(None),) class NullObservation(NullVariable): @@ -152,7 +152,7 @@ def sample( Returns ------- tuple - Containing a TimeArray with None. + Containing a SampledValue with None. """ - return (TimeArray(None),) + return (SampledValue(None),) diff --git a/model/src/pyrenew/deterministic/process.py b/model/src/pyrenew/deterministic/process.py index 98cc4300..29e04658 100644 --- a/model/src/pyrenew/deterministic/process.py +++ b/model/src/pyrenew/deterministic/process.py @@ -2,7 +2,7 @@ import jax.numpy as jnp from pyrenew.deterministic.deterministic import DeterministicVariable -from pyrenew.metaclass import TimeArray +from pyrenew.metaclass import SampledValue class DeterministicProcess(DeterministicVariable): @@ -30,7 +30,7 @@ def sample( Returns ------- tuple - Containing the stored values during construction wrapped in a TimeArray. + Containing the stored values during construction wrapped in a SampledValue. """ res, *_ = super().sample(**kwargs) @@ -39,7 +39,7 @@ def sample( if dif > 0: return ( - TimeArray( + SampledValue( jnp.hstack([res.array, jnp.repeat(res.array[-1], dif)]), t_start=self.t_start, t_unit=self.t_unit, @@ -47,7 +47,7 @@ def sample( ) return ( - TimeArray( + SampledValue( array=res.array[:duration], t_start=self.t_start, t_unit=self.t_unit, diff --git a/model/src/pyrenew/latent/hospitaladmissions.py b/model/src/pyrenew/latent/hospitaladmissions.py index cf061538..1d6a2776 100644 --- a/model/src/pyrenew/latent/hospitaladmissions.py +++ b/model/src/pyrenew/latent/hospitaladmissions.py @@ -9,7 +9,7 @@ import numpyro as npro from jax.typing import ArrayLike from pyrenew.deterministic import DeterministicVariable -from pyrenew.metaclass import RandomVariable, TimeArray +from pyrenew.metaclass import RandomVariable, SampledValue class HospitalAdmissionsSample(NamedTuple): @@ -20,12 +20,12 @@ class HospitalAdmissionsSample(NamedTuple): ---------- infection_hosp_rate : float, optional The infection-to-hospitalization rate. Defaults to None. - latent_hospital_admissions : TimeArray or None + latent_hospital_admissions : SampledValue or None The computed number of hospital admissions. Defaults to None. """ infection_hosp_rate: float | None = None - latent_hospital_admissions: TimeArray | None = None + latent_hospital_admissions: SampledValue | None = None def __repr__(self): return f"HospitalAdmissionsSample(infection_hosp_rate={self.infection_hosp_rate}, latent_hospital_admissions={self.latent_hospital_admissions})" @@ -162,7 +162,7 @@ def sample( Parameters ---------- - latent : ArrayLike or TimeArray + latent : ArrayLike or SampledValue Latent infections. **kwargs : dict, optional Additional keyword arguments passed through to internal `sample()` @@ -208,7 +208,7 @@ def sample( return HospitalAdmissionsSample( infection_hosp_rate=infection_hosp_rate, - latent_hospital_admissions=TimeArray( + latent_hospital_admissions=SampledValue( array=latent_hospital_admissions, t_start=self.infection_to_admission_interval_rv.t_start, t_unit=self.infection_to_admission_interval_rv.t_unit, diff --git a/model/src/pyrenew/latent/infection_initialization_process.py b/model/src/pyrenew/latent/infection_initialization_process.py index 836d28e6..2247334b 100644 --- a/model/src/pyrenew/latent/infection_initialization_process.py +++ b/model/src/pyrenew/latent/infection_initialization_process.py @@ -4,7 +4,7 @@ from pyrenew.latent.infection_initialization_method import ( InfectionInitializationMethod, ) -from pyrenew.metaclass import RandomVariable, TimeArray +from pyrenew.metaclass import RandomVariable, SampledValue class InfectionInitializationProcess(RandomVariable): @@ -100,7 +100,7 @@ def sample(self) -> tuple: npro.deterministic(self.name, infection_seeding) return ( - TimeArray( + SampledValue( array=infection_seeding, t_start=self.t_start, t_unit=self.t_unit, diff --git a/model/src/pyrenew/latent/infections.py b/model/src/pyrenew/latent/infections.py index 6dc2c7a8..142b25bc 100644 --- a/model/src/pyrenew/latent/infections.py +++ b/model/src/pyrenew/latent/infections.py @@ -8,7 +8,7 @@ import jax.numpy as jnp import pyrenew.latent.infection_functions as inf from jax.typing import ArrayLike -from pyrenew.metaclass import RandomVariable, TimeArray +from pyrenew.metaclass import RandomVariable, SampledValue class InfectionsSample(NamedTuple): @@ -17,7 +17,7 @@ class InfectionsSample(NamedTuple): Attributes ---------- - post_initialization_infections : TimeArray | None, optional + post_initialization_infections : SampledValue | None, optional The estimated latent infections. Defaults to None. """ @@ -97,4 +97,4 @@ def sample( reversed_generation_interval_pmf=gen_int_rev, ) - return InfectionsSample(TimeArray(post_initialization_infections)) + return InfectionsSample(SampledValue(post_initialization_infections)) diff --git a/model/src/pyrenew/latent/infectionswithfeedback.py b/model/src/pyrenew/latent/infectionswithfeedback.py index 001834ca..a8dccdc8 100644 --- a/model/src/pyrenew/latent/infectionswithfeedback.py +++ b/model/src/pyrenew/latent/infectionswithfeedback.py @@ -10,7 +10,7 @@ from numpy.typing import ArrayLike from pyrenew.metaclass import ( RandomVariable, - TimeArray, + SampledValue, _assert_sample_and_rtype, ) @@ -200,8 +200,8 @@ def sample( npro.deterministic("Rt_adjusted", Rt_adj) return InfectionsRtFeedbackSample( - post_initialization_infections=TimeArray( + post_initialization_infections=SampledValue( post_initialization_infections ), - rt=TimeArray(Rt_adj), + rt=SampledValue(Rt_adj), ) diff --git a/model/src/pyrenew/metaclass.py b/model/src/pyrenew/metaclass.py index c61dcdb0..c4a33d74 100644 --- a/model/src/pyrenew/metaclass.py +++ b/model/src/pyrenew/metaclass.py @@ -92,7 +92,7 @@ def _assert_sample_and_rtype( return None -class TimeArray: +class SampledValue: """ A container for a time-aware array. """ @@ -104,7 +104,7 @@ def __init__( t_unit: int | None = None, ) -> None: """ - Default constructor for TimeArray + Default constructor for SampledValue Parameters ---------- @@ -265,11 +265,11 @@ class DistributionalRVSample(NamedTuple): Attributes ---------- - value : TimeArray + value : SampledValue Sampled value from the distribution. """ - value: TimeArray | None = None + value: SampledValue | None = None def __repr__(self) -> str: """ @@ -344,7 +344,7 @@ def sample( DistributionalRVSample """ return DistributionalRVSample( - value=TimeArray( + value=SampledValue( jnp.atleast_1d( npro.sample( name=self.name, diff --git a/model/src/pyrenew/model/rtinfectionsrenewalmodel.py b/model/src/pyrenew/model/rtinfectionsrenewalmodel.py index 72c20ca4..0b8247b8 100644 --- a/model/src/pyrenew/model/rtinfectionsrenewalmodel.py +++ b/model/src/pyrenew/model/rtinfectionsrenewalmodel.py @@ -13,7 +13,7 @@ from pyrenew.metaclass import ( Model, RandomVariable, - TimeArray, + SampledValue, _assert_sample_and_rtype, ) @@ -245,7 +245,7 @@ def sample( ) return RtInfectionsRenewalSample( - Rt=TimeArray(Rt), - latent_infections=TimeArray(all_latent_infections), - observed_infections=TimeArray(observed_infections), + Rt=SampledValue(Rt), + latent_infections=SampledValue(all_latent_infections), + observed_infections=SampledValue(observed_infections), ) diff --git a/model/src/pyrenew/observation/negativebinomial.py b/model/src/pyrenew/observation/negativebinomial.py index 40f4e7e7..de614f3a 100644 --- a/model/src/pyrenew/observation/negativebinomial.py +++ b/model/src/pyrenew/observation/negativebinomial.py @@ -6,7 +6,7 @@ import numpyro import numpyro.distributions as dist from jax.typing import ArrayLike -from pyrenew.metaclass import RandomVariable, TimeArray +from pyrenew.metaclass import RandomVariable, SampledValue class NegativeBinomialObservation(RandomVariable): @@ -98,4 +98,4 @@ def sample( obs=obs, ) - return (TimeArray(negative_binomial_sample),) + return (SampledValue(negative_binomial_sample),) diff --git a/model/src/pyrenew/observation/poisson.py b/model/src/pyrenew/observation/poisson.py index 2c180fed..28d92c50 100644 --- a/model/src/pyrenew/observation/poisson.py +++ b/model/src/pyrenew/observation/poisson.py @@ -6,7 +6,7 @@ import numpyro import numpyro.distributions as dist from jax.typing import ArrayLike -from pyrenew.metaclass import RandomVariable, TimeArray +from pyrenew.metaclass import RandomVariable, SampledValue class PoissonObservation(RandomVariable): @@ -72,4 +72,4 @@ def sample( fn=dist.Poisson(rate=mu + self.eps), obs=obs, ) - return (TimeArray(poisson_sample),) + return (SampledValue(poisson_sample),) diff --git a/model/src/pyrenew/process/ar.py b/model/src/pyrenew/process/ar.py index 98e04e14..c393c54d 100644 --- a/model/src/pyrenew/process/ar.py +++ b/model/src/pyrenew/process/ar.py @@ -8,7 +8,7 @@ import numpyro.distributions as dist from jax import lax from jax.typing import ArrayLike -from pyrenew.metaclass import RandomVariable, TimeArray +from pyrenew.metaclass import RandomVariable, SampledValue class ARProcess(RandomVariable): @@ -91,7 +91,7 @@ def _ar_scanner(carry, next): # numpydoc ignore=GL08 ) last, ts = lax.scan(_ar_scanner, inits - self.mean, noise) - return (TimeArray(jnp.hstack([inits, self.mean + ts.flatten()])),) + return (SampledValue(jnp.hstack([inits, self.mean + ts.flatten()])),) @staticmethod def validate(): # numpydoc ignore=RT01 diff --git a/model/src/pyrenew/process/firstdifferencear.py b/model/src/pyrenew/process/firstdifferencear.py index c6cde5f3..cfc7b16e 100644 --- a/model/src/pyrenew/process/firstdifferencear.py +++ b/model/src/pyrenew/process/firstdifferencear.py @@ -5,7 +5,7 @@ import jax.numpy as jnp from jax.typing import ArrayLike -from pyrenew.metaclass import RandomVariable, TimeArray +from pyrenew.metaclass import RandomVariable, SampledValue from pyrenew.process import ARProcess @@ -73,7 +73,7 @@ def sample( name=name + "_rate_of_change", ) return ( - TimeArray(init_val + jnp.cumsum(rates_of_change.array.flatten())), + SampledValue(init_val + jnp.cumsum(rates_of_change.array.flatten())), ) @staticmethod diff --git a/model/src/pyrenew/process/periodiceffect.py b/model/src/pyrenew/process/periodiceffect.py index b380478c..1dcffec6 100644 --- a/model/src/pyrenew/process/periodiceffect.py +++ b/model/src/pyrenew/process/periodiceffect.py @@ -6,7 +6,7 @@ import pyrenew.arrayutils as au from pyrenew.metaclass import ( RandomVariable, - TimeArray, + SampledValue, _assert_sample_and_rtype, ) @@ -18,7 +18,7 @@ class PeriodicEffectSample(NamedTuple): Attributes ---------- - value: TimeArray + value: SampledValue The sampled value. """ @@ -114,7 +114,7 @@ def sample(self, duration: int, **kwargs): """ return PeriodicEffectSample( - value=TimeArray( + value=SampledValue( self.broadcaster( data=self.quantity_to_broadcast.sample(**kwargs)[0].array, n_timepoints=duration, diff --git a/model/src/pyrenew/process/rtperiodicdiff.py b/model/src/pyrenew/process/rtperiodicdiff.py index f70d08a7..ec41ab82 100644 --- a/model/src/pyrenew/process/rtperiodicdiff.py +++ b/model/src/pyrenew/process/rtperiodicdiff.py @@ -6,7 +6,7 @@ from pyrenew.arrayutils import PeriodicBroadcaster from pyrenew.metaclass import ( RandomVariable, - TimeArray, + SampledValue, _assert_sample_and_rtype, ) from pyrenew.process.firstdifferencear import FirstDifferenceARProcess @@ -192,7 +192,7 @@ def sample( )[0] return RtPeriodicDiffProcessSample( - rt=TimeArray( + rt=SampledValue( self.broadcaster(jnp.exp(log_rt.array.flatten()), duration) ), ) diff --git a/model/src/pyrenew/process/rtrandomwalk.py b/model/src/pyrenew/process/rtrandomwalk.py index 9dede211..c9fab7b0 100644 --- a/model/src/pyrenew/process/rtrandomwalk.py +++ b/model/src/pyrenew/process/rtrandomwalk.py @@ -4,7 +4,7 @@ import numpyro as npro import numpyro.distributions as dist import pyrenew.transformation as t -from pyrenew.metaclass import RandomVariable, TimeArray +from pyrenew.metaclass import RandomVariable, SampledValue from pyrenew.process.simplerandomwalk import SimpleRandomWalkProcess @@ -122,4 +122,4 @@ def sample( Rt = npro.deterministic("Rt", self.Rt_transform.inv(Rt_trans_ts.array)) - return (TimeArray(Rt),) + return (SampledValue(Rt),) diff --git a/model/src/pyrenew/process/simplerandomwalk.py b/model/src/pyrenew/process/simplerandomwalk.py index 6e5bfa65..2a12e377 100644 --- a/model/src/pyrenew/process/simplerandomwalk.py +++ b/model/src/pyrenew/process/simplerandomwalk.py @@ -5,7 +5,7 @@ import numpyro as npro import numpyro.distributions as dist from numpyro.contrib.control_flow import scan -from pyrenew.metaclass import RandomVariable, TimeArray +from pyrenew.metaclass import RandomVariable, SampledValue class SimpleRandomWalkProcess(RandomVariable): @@ -76,7 +76,7 @@ def transition(x_prev, _): xs=jnp.arange(n_timepoints - 1), ) - return (TimeArray(jnp.hstack([init, x])),) + return (SampledValue(jnp.hstack([init, x])),) @staticmethod def validate(): diff --git a/model/src/test/test_model_hospitalizations.py b/model/src/test/test_model_hospitalizations.py index 5888fd63..af9c9a5b 100644 --- a/model/src/test/test_model_hospitalizations.py +++ b/model/src/test/test_model_hospitalizations.py @@ -21,7 +21,7 @@ Infections, InitializeInfectionsZeroPad, ) -from pyrenew.metaclass import DistributionalRV, RandomVariable, TimeArray +from pyrenew.metaclass import DistributionalRV, RandomVariable, SampledValue from pyrenew.model import HospitalAdmissionsModel from pyrenew.observation import PoissonObservation from pyrenew.process import RtRandomWalkProcess @@ -39,7 +39,7 @@ def validate(self): # numpydoc ignore=GL08 def sample(self, **kwargs): # numpydoc ignore=GL08 return ( - TimeArray( + SampledValue( npro.sample( name=self.name, fn=dist.Uniform(high=0.99, low=0.01) ) From e4a13cefd1c75284126c94fab415fa3b5dd93868 Mon Sep 17 00:00:00 2001 From: George Vega Yon Date: Tue, 23 Jul 2024 10:43:52 -0600 Subject: [PATCH 14/41] Accessing values instead of arrays in SampledValue --- docs/source/tutorials/basic_renewal_model.qmd | 6 ++-- docs/source/tutorials/extending_pyrenew.qmd | 12 ++++---- .../tutorials/hospital_admissions_model.qmd | 4 +-- docs/source/tutorials/periodic_effects.qmd | 4 +-- docs/source/tutorials/pyrenew_demo.qmd | 12 ++++---- .../pyrenew/deterministic/deterministic.py | 2 +- model/src/pyrenew/deterministic/process.py | 6 ++-- .../src/pyrenew/latent/hospitaladmissions.py | 10 +++---- .../latent/infection_initialization_method.py | 2 +- .../infection_initialization_process.py | 4 +-- .../pyrenew/latent/infectionswithfeedback.py | 4 +-- model/src/pyrenew/metaclass.py | 16 +++++------ model/src/pyrenew/model/admissionsmodel.py | 4 +-- .../pyrenew/model/rtinfectionsrenewalmodel.py | 14 +++++----- .../pyrenew/observation/negativebinomial.py | 2 +- .../src/pyrenew/process/firstdifferencear.py | 2 +- model/src/pyrenew/process/periodiceffect.py | 2 +- model/src/pyrenew/process/rtperiodicdiff.py | 8 +++--- model/src/pyrenew/process/rtrandomwalk.py | 2 +- model/src/test/test_ar_process.py | 4 +-- model/src/test/test_deterministic.py | 12 ++++---- model/src/test/test_first_difference_ar.py | 4 +-- model/src/test/test_forecast.py | 2 +- .../src/test/test_infection_seeding_method.py | 10 +++---- model/src/test/test_infectionsrtfeedback.py | 18 ++++++------ model/src/test/test_latent_admissions.py | 8 +++--- model/src/test/test_latent_infections.py | 6 ++-- model/src/test/test_model_basic_renewal.py | 16 +++++------ model/src/test/test_model_hospitalizations.py | 28 +++++++++---------- .../test/test_observation_negativebinom.py | 16 +++++------ model/src/test/test_observation_poisson.py | 2 +- model/src/test/test_periodiceffect.py | 8 +++--- model/src/test/test_random_key.py | 2 +- model/src/test/test_random_walk.py | 6 ++-- model/src/test/test_rtperiodicdiff.py | 12 ++++---- 35 files changed, 135 insertions(+), 135 deletions(-) diff --git a/docs/source/tutorials/basic_renewal_model.qmd b/docs/source/tutorials/basic_renewal_model.qmd index 2c0b6676..24ba5f19 100644 --- a/docs/source/tutorials/basic_renewal_model.qmd +++ b/docs/source/tutorials/basic_renewal_model.qmd @@ -192,11 +192,11 @@ import matplotlib.pyplot as plt fig, axs = plt.subplots(1, 2) # Rt plot -axs[0].plot(sim_data.Rt.array) +axs[0].plot(sim_data.Rt.value) axs[0].set_ylabel("Rt") # Infections plot -axs[1].plot(sim_data.observed_infections.array) +axs[1].plot(sim_data.observed_infections.value) axs[1].set_ylabel("Infections") fig.suptitle("Basic renewal model") @@ -214,7 +214,7 @@ import jax model1.run( num_warmup=2000, num_samples=1000, - data_observed_infections=sim_data.observed_infections.array, + data_observed_infections=sim_data.observed_infections.value, rng_key=jax.random.PRNGKey(54), mcmc_args=dict(progress_bar=False, num_chains=2), ) diff --git a/docs/source/tutorials/extending_pyrenew.qmd b/docs/source/tutorials/extending_pyrenew.qmd index 126d9a7c..f914c36a 100644 --- a/docs/source/tutorials/extending_pyrenew.qmd +++ b/docs/source/tutorials/extending_pyrenew.qmd @@ -96,7 +96,7 @@ with npro.handlers.seed(rng_seed=np.random.randint(1, 600)): import matplotlib.pyplot as plt fig, ax = plt.subplots() -ax.plot(model0_samp.latent_infections.array) +ax.plot(model0_samp.latent_infections.value) ax.set_xlabel("Time") ax.set_ylabel("Infections") plt.show() @@ -201,14 +201,14 @@ class InfFeedback(RandomVariable): **kwargs, ) inf_feedback_strength = au.pad_x_to_match_y( - x=inf_feedback_strength.array, + x=inf_feedback_strength.value, y=Rt, - fill_value=inf_feedback_strength.array[0], + fill_value=inf_feedback_strength.value[0], ) # Sampling inf feedback and adjusting the shape inf_feedback_pmf, *_ = self.infection_feedback_pmf(**kwargs) - inf_fb_pmf_rev = jnp.flip(inf_feedback_pmf.array) + inf_fb_pmf_rev = jnp.flip(inf_feedback_pmf.value) # Generating the infections with feedback all_infections, Rt_adj = compute_infections_from_rt_with_feedback( @@ -269,8 +269,8 @@ Comparing `model0` with `model1`, these two should match: import matplotlib.pyplot as plt fig, ax = plt.subplots(ncols=2) -ax[0].plot(model0_samp.latent_infections.array) -ax[1].plot(model1_samp.latent_infections.array) +ax[0].plot(model0_samp.latent_infections.value) +ax[1].plot(model1_samp.latent_infections.value) ax[0].set_xlabel("Time (model 0)") ax[1].set_xlabel("Time (model 1)") ax[0].set_ylabel("Infections") diff --git a/docs/source/tutorials/hospital_admissions_model.qmd b/docs/source/tutorials/hospital_admissions_model.qmd index 0ce74c2e..362af91d 100644 --- a/docs/source/tutorials/hospital_admissions_model.qmd +++ b/docs/source/tutorials/hospital_admissions_model.qmd @@ -230,11 +230,11 @@ import matplotlib.pyplot as plt fig, axs = plt.subplots(1, 2) # Rt plot -axs[0].plot(sim_data.Rt.array) +axs[0].plot(sim_data.Rt.value) axs[0].set_ylabel("Rt") # Admissions plot -axs[1].plot(sim_data.observed_hosp_admissions.array) +axs[1].plot(sim_data.observed_hosp_admissions.value) axs[1].set_ylabel("Admissions") axs[1].set_yscale("log") diff --git a/docs/source/tutorials/periodic_effects.qmd b/docs/source/tutorials/periodic_effects.qmd index 30b747bf..0dabf79d 100644 --- a/docs/source/tutorials/periodic_effects.qmd +++ b/docs/source/tutorials/periodic_effects.qmd @@ -45,7 +45,7 @@ with npro.handlers.seed(rng_seed=20): # Plotting the Rt values import matplotlib.pyplot as plt -plt.step(np.arange(len(sim_data.rt.array)), sim_data.rt.array, where="post") +plt.step(np.arange(len(sim_data.rt.value)), sim_data.rt.value, where="post") plt.xlabel("Time") plt.ylabel("Rt") plt.title("Simulated Rt values") @@ -90,7 +90,7 @@ with npro.handlers.seed(rng_seed=20): import matplotlib.pyplot as plt plt.step( - np.arange(len(sim_data.value.array)), sim_data.value.array, where="post" + np.arange(len(sim_data.value.value)), sim_data.value.value, where="post" ) plt.xlabel("Time") plt.ylabel("Effect size") diff --git a/docs/source/tutorials/pyrenew_demo.qmd b/docs/source/tutorials/pyrenew_demo.qmd index 0de767be..b863022a 100644 --- a/docs/source/tutorials/pyrenew_demo.qmd +++ b/docs/source/tutorials/pyrenew_demo.qmd @@ -50,7 +50,7 @@ q = SimpleRandomWalkProcess(dist.Normal(0, 0.001)) with seed(rng_seed=np.random.randint(0, 1000)): q_samp = q(n_timepoints=100) -plt.plot(np.exp(q_samp[0].array)) +plt.plot(np.exp(q_samp[0].value)) ``` Next, import several additional functions from the `latent` module of the `pyrenew` package to model infections and hospital admissions. @@ -166,10 +166,10 @@ Visualizations of the single model output show (top) infections over the 30 time # | label: fig-hosp # | fig-cap: Infections fig, ax = plt.subplots(nrows=3, sharex=True) -ax[0].plot(x.latent_infections.array) +ax[0].plot(x.latent_infections.value) ax[0].set_ylim([1 / 5, 5]) -ax[1].plot(x.latent_hosp_admissions.array) -ax[2].plot(x.observed_hosp_admissions.array, "o") +ax[1].plot(x.latent_hosp_admissions.value) +ax[2].plot(x.observed_hosp_admissions.value, "o") for axis in ax[:-1]: axis.set_yscale("log") ``` @@ -181,7 +181,7 @@ To fit the `hospmodel` to the simulated data, we call `hospmodel.run()`, an MCMC hospmodel.run( num_warmup=1000, num_samples=1000, - data_observed_hosp_admissions=x.observed_hosp_admissions.array, + data_observed_hosp_admissions=x.observed_hosp_admissions.value, rng_key=jax.random.PRNGKey(54), mcmc_args=dict(progress_bar=False, num_chains=2), ) @@ -211,7 +211,7 @@ import polars as pl fig, ax = plt.subplots(figsize=[4, 5]) -ax.plot(x[0].array) +ax.plot(x[0].value) samp_ids = np.random.randint(size=25, low=0, high=999) for samp_id in samp_ids: sub_samps = samps.filter(pl.col("draw") == samp_id).sort(pl.col("time")) diff --git a/model/src/pyrenew/deterministic/deterministic.py b/model/src/pyrenew/deterministic/deterministic.py index 32da4714..e61e8ebb 100644 --- a/model/src/pyrenew/deterministic/deterministic.py +++ b/model/src/pyrenew/deterministic/deterministic.py @@ -96,7 +96,7 @@ def sample( npro.deterministic(self.name, self.vars) return ( SampledValue( - array=self.vars, + value=self.vars, t_start=self.t_start, t_unit=self.t_unit, ), diff --git a/model/src/pyrenew/deterministic/process.py b/model/src/pyrenew/deterministic/process.py index 29e04658..7210356b 100644 --- a/model/src/pyrenew/deterministic/process.py +++ b/model/src/pyrenew/deterministic/process.py @@ -35,12 +35,12 @@ def sample( res, *_ = super().sample(**kwargs) - dif = duration - res.array.shape[0] + dif = duration - res.value.shape[0] if dif > 0: return ( SampledValue( - jnp.hstack([res.array, jnp.repeat(res.array[-1], dif)]), + jnp.hstack([res.value, jnp.repeat(res.value[-1], dif)]), t_start=self.t_start, t_unit=self.t_unit, ), @@ -48,7 +48,7 @@ def sample( return ( SampledValue( - array=res.array[:duration], + value=res.value[:duration], t_start=self.t_start, t_unit=self.t_unit, ), diff --git a/model/src/pyrenew/latent/hospitaladmissions.py b/model/src/pyrenew/latent/hospitaladmissions.py index 1d6a2776..5f467ce4 100644 --- a/model/src/pyrenew/latent/hospitaladmissions.py +++ b/model/src/pyrenew/latent/hospitaladmissions.py @@ -175,7 +175,7 @@ def sample( infection_hosp_rate, *_ = self.infect_hosp_rate_rv(**kwargs) - infection_hosp_rate_t = infection_hosp_rate.array * latent_infections + infection_hosp_rate_t = infection_hosp_rate.value * latent_infections ( infection_to_admission_interval, @@ -184,7 +184,7 @@ def sample( latent_hospital_admissions = jnp.convolve( infection_hosp_rate_t, - infection_to_admission_interval.array, + infection_to_admission_interval.value, mode="full", )[: infection_hosp_rate_t.shape[0]] @@ -193,13 +193,13 @@ def sample( latent_hospital_admissions * self.day_of_week_effect_rv( n_timepoints=latent_hospital_admissions.size, **kwargs - )[0].array + )[0].value ) # Applying probability of hospitalization effect latent_hospital_admissions = ( latent_hospital_admissions - * self.hosp_report_prob_rv(**kwargs)[0].array + * self.hosp_report_prob_rv(**kwargs)[0].value ) npro.deterministic( @@ -209,7 +209,7 @@ def sample( return HospitalAdmissionsSample( infection_hosp_rate=infection_hosp_rate, latent_hospital_admissions=SampledValue( - array=latent_hospital_admissions, + value=latent_hospital_admissions, t_start=self.infection_to_admission_interval_rv.t_start, t_unit=self.infection_to_admission_interval_rv.t_unit, ), diff --git a/model/src/pyrenew/latent/infection_initialization_method.py b/model/src/pyrenew/latent/infection_initialization_method.py index ced8b4c9..1a723f1e 100644 --- a/model/src/pyrenew/latent/infection_initialization_method.py +++ b/model/src/pyrenew/latent/infection_initialization_method.py @@ -177,7 +177,7 @@ def seed_infections(self, I_pre_seed: ArrayLike): f"I_pre_seed must be an array of size 1. Got size {I_pre_seed.size}." ) (rate,) = self.rate() - rate = rate.array + rate = rate.value if rate.size != 1: raise ValueError( f"rate must be an array of size 1. Got size {rate.size}." diff --git a/model/src/pyrenew/latent/infection_initialization_process.py b/model/src/pyrenew/latent/infection_initialization_process.py index 2247334b..3e72a7d0 100644 --- a/model/src/pyrenew/latent/infection_initialization_process.py +++ b/model/src/pyrenew/latent/infection_initialization_process.py @@ -96,12 +96,12 @@ def sample(self) -> tuple: a tuple where the only element is an array with the number of seeded infections at each time point. """ (I_pre_seed,) = self.I_pre_seed_rv() - infection_seeding = self.infection_seed_method(I_pre_seed.array) + infection_seeding = self.infection_seed_method(I_pre_seed.value) npro.deterministic(self.name, infection_seeding) return ( SampledValue( - array=infection_seeding, + value=infection_seeding, t_start=self.t_start, t_unit=self.t_unit, ), diff --git a/model/src/pyrenew/latent/infectionswithfeedback.py b/model/src/pyrenew/latent/infectionswithfeedback.py index a8dccdc8..7da74778 100644 --- a/model/src/pyrenew/latent/infectionswithfeedback.py +++ b/model/src/pyrenew/latent/infectionswithfeedback.py @@ -163,7 +163,7 @@ def sample( inf_feedback_strength, *_ = self.infection_feedback_strength( **kwargs, ) - inf_feedback_strength = inf_feedback_strength.array + inf_feedback_strength = inf_feedback_strength.value # Making sure inf_feedback_strength spans the Rt length if inf_feedback_strength.size == 1: @@ -182,7 +182,7 @@ def sample( # Sampling inf feedback pmf inf_feedback_pmf, *_ = self.infection_feedback_pmf(**kwargs) - inf_fb_pmf_rev = jnp.flip(inf_feedback_pmf.array) + inf_fb_pmf_rev = jnp.flip(inf_feedback_pmf.value) ( post_initialization_infections, diff --git a/model/src/pyrenew/metaclass.py b/model/src/pyrenew/metaclass.py index c4a33d74..5b1fbe21 100644 --- a/model/src/pyrenew/metaclass.py +++ b/model/src/pyrenew/metaclass.py @@ -94,12 +94,12 @@ def _assert_sample_and_rtype( class SampledValue: """ - A container for a time-aware array. + A container for a sampled value from a RandomVariable. """ def __init__( self, - array: ArrayLike | None, + value: ArrayLike | None, t_start: int | None = None, t_unit: int | None = None, ) -> None: @@ -108,8 +108,8 @@ def __init__( Parameters ---------- - array : ArrayLike - The data array. + value : ArrayLike + The sampled value. t_start : int, optional The start time of the data. t_unit : int, optional @@ -120,12 +120,12 @@ def __init__( None """ - if array is not None: + if value is not None: assert isinstance( - array, ArrayLike - ), "array should be an array-like object." + value, ArrayLike + ), "value should be an array-like object." - self.array = array + self.value = value self.t_start = t_start self.t_unit = t_unit diff --git a/model/src/pyrenew/model/admissionsmodel.py b/model/src/pyrenew/model/admissionsmodel.py index 8fa32780..480c5922 100644 --- a/model/src/pyrenew/model/admissionsmodel.py +++ b/model/src/pyrenew/model/admissionsmodel.py @@ -198,7 +198,7 @@ def sample( latent_hosp_admissions, *_, ) = self.latent_hosp_admissions_rv( - latent_infections=basic_model.latent_infections.array, + latent_infections=basic_model.latent_infections.value, **kwargs, ) @@ -206,7 +206,7 @@ def sample( observed_hosp_admissions, *_, ) = self.hosp_admission_obs_process_rv( - mu=latent_hosp_admissions.array[-n_datapoints:], + mu=latent_hosp_admissions.value[-n_datapoints:], obs=data_observed_hosp_admissions, **kwargs, ) diff --git a/model/src/pyrenew/model/rtinfectionsrenewalmodel.py b/model/src/pyrenew/model/rtinfectionsrenewalmodel.py index 0b8247b8..e395ffdc 100644 --- a/model/src/pyrenew/model/rtinfectionsrenewalmodel.py +++ b/model/src/pyrenew/model/rtinfectionsrenewalmodel.py @@ -212,33 +212,33 @@ def sample( post_initialization_latent_infections, *_, ) = self.latent_infections_rv( - Rt=Rt.array, - gen_int=gen_int.array, - I0=I0.array, + Rt=Rt.value, + gen_int=gen_int.value, + I0=I0.value, **kwargs, ) observed_infections, *_ = self.infection_obs_process_rv( - mu=post_initialization_latent_infections.array[padding:], + mu=post_initialization_latent_infections.value[padding:], obs=data_observed_infections, **kwargs, ) all_latent_infections = jnp.hstack( - [I0.array, post_initialization_latent_infections.array] + [I0.value, post_initialization_latent_infections.value] ) npro.deterministic("all_latent_infections", all_latent_infections) if observed_infections is not None: observed_infections = au.pad_x_to_match_y( - observed_infections.array, + observed_infections.value, all_latent_infections, jnp.nan, pad_direction="start", ) Rt = au.pad_x_to_match_y( - Rt.array, + Rt.value, all_latent_infections, jnp.nan, pad_direction="start", diff --git a/model/src/pyrenew/observation/negativebinomial.py b/model/src/pyrenew/observation/negativebinomial.py index de614f3a..bd06d45a 100644 --- a/model/src/pyrenew/observation/negativebinomial.py +++ b/model/src/pyrenew/observation/negativebinomial.py @@ -93,7 +93,7 @@ def sample( name=self.name, fn=dist.NegativeBinomial2( mean=mu + self.eps, - concentration=concentration.array, + concentration=concentration.value, ), obs=obs, ) diff --git a/model/src/pyrenew/process/firstdifferencear.py b/model/src/pyrenew/process/firstdifferencear.py index cfc7b16e..15c27e13 100644 --- a/model/src/pyrenew/process/firstdifferencear.py +++ b/model/src/pyrenew/process/firstdifferencear.py @@ -73,7 +73,7 @@ def sample( name=name + "_rate_of_change", ) return ( - SampledValue(init_val + jnp.cumsum(rates_of_change.array.flatten())), + SampledValue(init_val + jnp.cumsum(rates_of_change.value.flatten())), ) @staticmethod diff --git a/model/src/pyrenew/process/periodiceffect.py b/model/src/pyrenew/process/periodiceffect.py index 1dcffec6..e0b81dad 100644 --- a/model/src/pyrenew/process/periodiceffect.py +++ b/model/src/pyrenew/process/periodiceffect.py @@ -116,7 +116,7 @@ def sample(self, duration: int, **kwargs): return PeriodicEffectSample( value=SampledValue( self.broadcaster( - data=self.quantity_to_broadcast.sample(**kwargs)[0].array, + data=self.quantity_to_broadcast.sample(**kwargs)[0].value, n_timepoints=duration, ) ) diff --git a/model/src/pyrenew/process/rtperiodicdiff.py b/model/src/pyrenew/process/rtperiodicdiff.py index ec41ab82..c6b7306b 100644 --- a/model/src/pyrenew/process/rtperiodicdiff.py +++ b/model/src/pyrenew/process/rtperiodicdiff.py @@ -176,9 +176,9 @@ def sample( """ # Initial sample - log_rt_prior = self.log_rt_prior.sample(**kwargs)[0].array - b = self.autoreg.sample(**kwargs)[0].array - s_r = self.periodic_diff_sd.sample(**kwargs)[0].array + log_rt_prior = self.log_rt_prior.sample(**kwargs)[0].value + b = self.autoreg.sample(**kwargs)[0].value + s_r = self.periodic_diff_sd.sample(**kwargs)[0].value # How many periods to sample? n_periods = int(jnp.ceil(duration / self.period_size)) @@ -193,7 +193,7 @@ def sample( return RtPeriodicDiffProcessSample( rt=SampledValue( - self.broadcaster(jnp.exp(log_rt.array.flatten()), duration) + self.broadcaster(jnp.exp(log_rt.value.flatten()), duration) ), ) diff --git a/model/src/pyrenew/process/rtrandomwalk.py b/model/src/pyrenew/process/rtrandomwalk.py index c9fab7b0..a9432562 100644 --- a/model/src/pyrenew/process/rtrandomwalk.py +++ b/model/src/pyrenew/process/rtrandomwalk.py @@ -120,6 +120,6 @@ def sample( init=Rt0_trans, ) - Rt = npro.deterministic("Rt", self.Rt_transform.inv(Rt_trans_ts.array)) + Rt = npro.deterministic("Rt", self.Rt_transform.inv(Rt_trans_ts.value)) return (SampledValue(Rt),) diff --git a/model/src/test/test_ar_process.py b/model/src/test/test_ar_process.py index 72f5bbcc..552a0777 100755 --- a/model/src/test/test_ar_process.py +++ b/model/src/test/test_ar_process.py @@ -37,5 +37,5 @@ def test_ar_samples_correctly_distributed(): # check it regresses to mean # when started away from it long_ts, *_ = ar1(duration=10000, inits=ar_inits) - assert_almost_equal(long_ts.array[0], ar_inits) - assert jnp.abs(long_ts.array[-1] - ar_mean) < 4 * noise_sd + assert_almost_equal(long_ts.value[0], ar_inits) + assert jnp.abs(long_ts.value[-1] - ar_mean) < 4 * noise_sd diff --git a/model/src/test/test_deterministic.py b/model/src/test/test_deterministic.py index 9a22e784..37c3072f 100644 --- a/model/src/test/test_deterministic.py +++ b/model/src/test/test_deterministic.py @@ -31,7 +31,7 @@ def test_deterministic(): var5 = NullProcess() testing.assert_array_equal( - var1()[0].array, + var1()[0].value, jnp.array( [ 1, @@ -39,16 +39,16 @@ def test_deterministic(): ), ) testing.assert_array_equal( - var2()[0].array, + var2()[0].value, jnp.array([0.25, 0.25, 0.2, 0.3]), ) testing.assert_array_equal( - var3(duration=5)[0].array, + var3(duration=5)[0].value, jnp.array([1, 2, 3, 4, 4]), ) testing.assert_array_equal( - var3(duration=3)[0].array, + var3(duration=3)[0].value, jnp.array( [ 1, @@ -58,5 +58,5 @@ def test_deterministic(): ), ) - testing.assert_equal(var4()[0].array, None) - testing.assert_equal(var5(duration=1)[0].array, None) + testing.assert_equal(var4()[0].value, None) + testing.assert_equal(var5(duration=1)[0].value, None) diff --git a/model/src/test/test_first_difference_ar.py b/model/src/test/test_first_difference_ar.py index 6fb67b78..605b336f 100755 --- a/model/src/test/test_first_difference_ar.py +++ b/model/src/test/test_first_difference_ar.py @@ -26,5 +26,5 @@ def test_fd_ar_can_be_sampled(): ) # Checking proper shape - assert ans0[0].array.shape == (3532,) - assert ans1[0].array.shape == (3532,) + assert ans0[0].value.shape == (3532,) + assert ans1[0].value.shape == (3532,) diff --git a/model/src/test/test_forecast.py b/model/src/test/test_forecast.py index 57b01c9a..4c2a3fb2 100644 --- a/model/src/test/test_forecast.py +++ b/model/src/test/test_forecast.py @@ -54,7 +54,7 @@ def test_forecast(): model.run( num_warmup=5, num_samples=5, - data_observed_infections=model_sample.observed_infections.array, + data_observed_infections=model_sample.observed_infections.value, rng_key=jr.key(54), ) diff --git a/model/src/test/test_infection_seeding_method.py b/model/src/test/test_infection_seeding_method.py index 21c502cc..e5e37cc9 100644 --- a/model/src/test/test_infection_seeding_method.py +++ b/model/src/test/test_infection_seeding_method.py @@ -20,8 +20,8 @@ def test_seed_infections_exponential(): (I_pre_seed,) = I_pre_seed_RV() (rate,) = rate_RV() - I_pre_seed = I_pre_seed.array - rate = rate.array + I_pre_seed = I_pre_seed.value + rate = rate.value infections_default_t_pre_seed = InitializeInfectionsExponentialGrowth( n_timepoints, rate=rate_RV ).seed_infections(I_pre_seed) @@ -51,7 +51,7 @@ def test_seed_infections_exponential(): with pytest.raises(ValueError): InitializeInfectionsExponentialGrowth( n_timepoints, rate=rate_RV - ).seed_infections(I_pre_seed_2.array) + ).seed_infections(I_pre_seed_2.value) # test non-default t_pre_seed t_pre_seed = 6 @@ -76,7 +76,7 @@ def test_seed_infections_zero_pad(): n_timepoints = 10 I_pre_seed_RV = DeterministicVariable(10.0, name="I_pre_seed_RV") (I_pre_seed,) = I_pre_seed_RV() - I_pre_seed = I_pre_seed.array + I_pre_seed = I_pre_seed.value infections = InitializeInfectionsZeroPad(n_timepoints).seed_infections( I_pre_seed @@ -89,7 +89,7 @@ def test_seed_infections_zero_pad(): np.array([10.0, 10.0]), name="I_pre_seed_RV" ) (I_pre_seed_2,) = I_pre_seed_RV_2() - I_pre_seed_2 = I_pre_seed_2.array + I_pre_seed_2 = I_pre_seed_2.value infections_2 = InitializeInfectionsZeroPad(n_timepoints).seed_infections( I_pre_seed_2 diff --git a/model/src/test/test_infectionsrtfeedback.py b/model/src/test/test_infectionsrtfeedback.py index 5e07aa25..ef1b564f 100644 --- a/model/src/test/test_infectionsrtfeedback.py +++ b/model/src/test/test_infectionsrtfeedback.py @@ -95,10 +95,10 @@ def test_infectionsrtfeedback(): ) assert_array_equal( - samp1.post_initialization_infections.array, - samp2.post_initialization_infections.array, + samp1.post_initialization_infections.value, + samp2.post_initialization_infections.value, ) - assert_array_equal(samp1.rt.array, Rt) + assert_array_equal(samp1.rt.value, Rt) return None @@ -142,18 +142,18 @@ def test_infectionsrtfeedback_feedback(): gen_int=gen_int, Rt=Rt, I0=I0, - inf_feedback_strength=inf_feed_strength()[0].array, - inf_feedback_pmf=inf_feedback_pmf()[0].array, + inf_feedback_strength=inf_feed_strength()[0].value, + inf_feedback_pmf=inf_feedback_pmf()[0].value, ) assert not jnp.array_equal( - samp1.post_initialization_infections.array, - samp2.post_initialization_infections.array, + samp1.post_initialization_infections.value, + samp2.post_initialization_infections.value, ) assert_array_almost_equal( - samp1.post_initialization_infections.array, + samp1.post_initialization_infections.value, res["post_initialization_infections"], ) - assert_array_almost_equal(samp1.rt.array, res["rt"]) + assert_array_almost_equal(samp1.rt.value, res["rt"]) return None diff --git a/model/src/test/test_latent_admissions.py b/model/src/test/test_latent_admissions.py index 908d867e..283b8b7f 100644 --- a/model/src/test/test_latent_admissions.py +++ b/model/src/test/test_latent_admissions.py @@ -36,7 +36,7 @@ def test_admissions_sample(): inf1 = Infections() with npro.handlers.seed(rng_seed=np.random.randint(1, 600)): - inf_sampled1 = inf1(Rt=sim_rt.array, gen_int=gen_int, I0=i0) + inf_sampled1 = inf1(Rt=sim_rt.value, gen_int=gen_int, I0=i0) # Testing the hospital admissions inf_hosp = DeterministicPMF( @@ -73,9 +73,9 @@ def test_admissions_sample(): ) with npro.handlers.seed(rng_seed=np.random.randint(1, 600)): - sim_hosp_1 = hosp1(latent_infections=inf_sampled1[0].array) + sim_hosp_1 = hosp1(latent_infections=inf_sampled1[0].value) testing.assert_array_less( - sim_hosp_1.latent_hospital_admissions.array, - inf_sampled1[0].array, + sim_hosp_1.latent_hospital_admissions.value, + inf_sampled1[0].value, ) diff --git a/model/src/test/test_latent_infections.py b/model/src/test/test_latent_infections.py index c730e6b9..28972d0f 100755 --- a/model/src/test/test_latent_infections.py +++ b/model/src/test/test_latent_infections.py @@ -32,7 +32,7 @@ def test_infections_as_deterministic(): inf1 = Infections() obs = dict( - Rt=sim_rt.array, + Rt=sim_rt.value, I0=jnp.zeros(gen_int.size), gen_int=gen_int, ) @@ -41,8 +41,8 @@ def test_infections_as_deterministic(): inf_sampled2 = inf1(**obs) testing.assert_array_equal( - inf_sampled1.post_initialization_infections.array, - inf_sampled2.post_initialization_infections.array, + inf_sampled1.post_initialization_infections.value, + inf_sampled2.post_initialization_infections.value, ) # Check that Initial infections vector must be at least as long as the generation interval. diff --git a/model/src/test/test_model_basic_renewal.py b/model/src/test/test_model_basic_renewal.py index c5089232..a16f9295 100644 --- a/model/src/test/test_model_basic_renewal.py +++ b/model/src/test/test_model_basic_renewal.py @@ -148,21 +148,21 @@ def test_model_basicrenewal_no_obs_model(): with npro.handlers.seed(rng_seed=np.random.randint(1, 600)): model1_samp = model0.sample(n_timepoints_to_simulate=30) - np.testing.assert_array_equal(model0_samp.Rt.array, model1_samp.Rt.array) + np.testing.assert_array_equal(model0_samp.Rt.value, model1_samp.Rt.value) np.testing.assert_array_equal( - model0_samp.latent_infections.array, - model1_samp.latent_infections.array, + model0_samp.latent_infections.value, + model1_samp.latent_infections.value, ) np.testing.assert_array_equal( - model0_samp.observed_infections.array, - model1_samp.observed_infections.array, + model0_samp.observed_infections.value, + model1_samp.observed_infections.value, ) model0.run( num_warmup=500, num_samples=500, rng_key=jr.key(272), - data_observed_infections=model0_samp.latent_infections.array, + data_observed_infections=model0_samp.latent_infections.value, ) inf = model0.spread_draws(["all_latent_infections"]) @@ -221,7 +221,7 @@ def test_model_basicrenewal_with_obs_model(): num_warmup=500, num_samples=500, rng_key=jr.key(22), - data_observed_infections=model1_samp.observed_infections.array, + data_observed_infections=model1_samp.observed_infections.value, ) inf = model1.spread_draws(["all_latent_infections"]) @@ -278,7 +278,7 @@ def test_model_basicrenewal_padding() -> None: # numpydoc ignore=GL08 num_warmup=500, num_samples=500, rng_key=jr.key(22), - data_observed_infections=model1_samp.observed_infections.array, + data_observed_infections=model1_samp.observed_infections.value, padding=5, ) diff --git a/model/src/test/test_model_hospitalizations.py b/model/src/test/test_model_hospitalizations.py index af9c9a5b..ff7ea64c 100644 --- a/model/src/test/test_model_hospitalizations.py +++ b/model/src/test/test_model_hospitalizations.py @@ -264,30 +264,30 @@ def test_model_hosp_no_obs_model(): model1_samp = model0.sample(n_timepoints_to_simulate=30) np.testing.assert_array_almost_equal( - model0_samp.Rt.array, model1_samp.Rt.array + model0_samp.Rt.value, model1_samp.Rt.value ) np.testing.assert_array_equal( - model0_samp.latent_infections.array, - model1_samp.latent_infections.array, + model0_samp.latent_infections.value, + model1_samp.latent_infections.value, ) np.testing.assert_array_equal( - model0_samp.infection_hosp_rate.array, - model1_samp.infection_hosp_rate.array, + model0_samp.infection_hosp_rate.value, + model1_samp.infection_hosp_rate.value, ) np.testing.assert_array_equal( - model0_samp.latent_hosp_admissions.array, - model1_samp.latent_hosp_admissions.array, + model0_samp.latent_hosp_admissions.value, + model1_samp.latent_hosp_admissions.value, ) # These are supposed to be none, both - assert model0_samp.observed_hosp_admissions.array is None - assert model1_samp.observed_hosp_admissions.array is None + assert model0_samp.observed_hosp_admissions.value is None + assert model1_samp.observed_hosp_admissions.value is None model0.run( num_warmup=500, num_samples=500, rng_key=jr.key(272), - data_observed_hosp_admissions=model0_samp.latent_hosp_admissions.array, + data_observed_hosp_admissions=model0_samp.latent_hosp_admissions.value, ) inf = model0.spread_draws(["latent_hospital_admissions"]) @@ -377,7 +377,7 @@ def test_model_hosp_with_obs_model(): num_warmup=500, num_samples=500, rng_key=jr.key(272), - data_observed_hosp_admissions=model1_samp.observed_hosp_admissions.array, + data_observed_hosp_admissions=model1_samp.observed_hosp_admissions.value, ) inf = model1.spread_draws(["latent_hospital_admissions"]) @@ -478,7 +478,7 @@ def test_model_hosp_with_obs_model_weekday_phosp_2(): num_warmup=500, num_samples=500, rng_key=jr.key(272), - data_observed_hosp_admissions=model1_samp.observed_hosp_admissions.array, + data_observed_hosp_admissions=model1_samp.observed_hosp_admissions.value, ) inf = model1.spread_draws(["latent_hospital_admissions"]) @@ -594,7 +594,7 @@ def test_model_hosp_with_obs_model_weekday_phosp(): # obs = jnp.hstack( # [ # jnp.repeat(jnp.nan, pad_size), - # model1_samp.observed_hosp_admissions.array[pad_size:], + # model1_samp.observed_hosp_admissions.value[pad_size:], # ] # ) # Running with padding @@ -602,7 +602,7 @@ def test_model_hosp_with_obs_model_weekday_phosp(): num_warmup=500, num_samples=500, rng_key=jr.key(272), - data_observed_hosp_admissions=model1_samp.observed_hosp_admissions.array, + data_observed_hosp_admissions=model1_samp.observed_hosp_admissions.value, padding=pad_size, ) diff --git a/model/src/test/test_observation_negativebinom.py b/model/src/test/test_observation_negativebinom.py index b3686d41..f2bfd894 100644 --- a/model/src/test/test_observation_negativebinom.py +++ b/model/src/test/test_observation_negativebinom.py @@ -27,12 +27,12 @@ def test_negativebinom_deterministic_obs(): assert isinstance(sim_nb1, tuple) assert isinstance(sim_nb2, tuple) - assert isinstance(sim_nb1[0].array, ArrayLike) - assert isinstance(sim_nb2[0].array, ArrayLike) + assert isinstance(sim_nb1[0].value, ArrayLike) + assert isinstance(sim_nb2[0].value, ArrayLike) testing.assert_array_equal( - sim_nb1[0].array, - sim_nb2[0].array, + sim_nb1[0].value, + sim_nb2[0].value, ) @@ -53,11 +53,11 @@ def test_negativebinom_random_obs(): sim_nb2 = negb(mu=rates) assert isinstance(sim_nb1, tuple) assert isinstance(sim_nb2, tuple) - assert isinstance(sim_nb1[0].array, ArrayLike) - assert isinstance(sim_nb2[0].array, ArrayLike) + assert isinstance(sim_nb1[0].value, ArrayLike) + assert isinstance(sim_nb2[0].value, ArrayLike) testing.assert_array_almost_equal( - np.mean(sim_nb1[0].array), - np.mean(sim_nb2[0].array), + np.mean(sim_nb1[0].value), + np.mean(sim_nb2[0].value), decimal=1, ) diff --git a/model/src/test/test_observation_poisson.py b/model/src/test/test_observation_poisson.py index ea622e8d..5ca9c3bc 100644 --- a/model/src/test/test_observation_poisson.py +++ b/model/src/test/test_observation_poisson.py @@ -20,4 +20,4 @@ def test_poisson_obs(): with npro.handlers.seed(rng_seed=np.random.randint(1, 600)): sim_pois, *_ = pois(mu=rates) - testing.assert_array_equal(sim_pois.array, jnp.ceil(sim_pois.array)) + testing.assert_array_equal(sim_pois.value, jnp.ceil(sim_pois.value)) diff --git a/model/src/test/test_periodiceffect.py b/model/src/test/test_periodiceffect.py index 481ecad3..d5014a9e 100644 --- a/model/src/test/test_periodiceffect.py +++ b/model/src/test/test_periodiceffect.py @@ -30,7 +30,7 @@ def test_periodiceffect() -> None: np.random.seed(223) with npro.handlers.seed(rng_seed=np.random.randint(1, 600)): - ans = pe(duration=duration).value.array + ans = pe(duration=duration).value.value # Checking that the shape of the sampled Rt is correct assert ans.shape == (duration,) @@ -44,7 +44,7 @@ def test_periodiceffect() -> None: params["offset"] = 5 pe = PeriodicEffect(**params) with npro.handlers.seed(rng_seed=np.random.randint(1, 600)): - ans2 = pe(duration=duration).value.array + ans2 = pe(duration=duration).value.value # Checking that the shape of the sampled Rt is correct assert ans2.shape == (duration,) @@ -81,8 +81,8 @@ def test_weeklyeffect() -> None: pe = PeriodicEffect(**params) pe2 = DayOfWeekEffect(**params2) - ans1 = pe(duration=duration).value.array - ans2 = pe2(duration=duration).value.array + ans1 = pe(duration=duration).value.value + ans2 = pe2(duration=duration).value.value assert_array_equal(ans1, ans2) diff --git a/model/src/test/test_random_key.py b/model/src/test/test_random_key.py index 00068fcb..9d7fcacb 100644 --- a/model/src/test/test_random_key.py +++ b/model/src/test/test_random_key.py @@ -103,7 +103,7 @@ def test_rng_keys_produce_correct_samples(): model_sample = models[0].sample( n_timepoints_to_simulate=n_timepoints_to_simulate[0] ) - obs_infections = [model_sample.observed_infections.array] * len(models) + obs_infections = [model_sample.observed_infections.value] * len(models) rng_keys = [jr.key(54), jr.key(54), None, None, jr.key(74)] # run test models with the different keys diff --git a/model/src/test/test_random_walk.py b/model/src/test/test_random_walk.py index d564342f..19baafcf 100755 --- a/model/src/test/test_random_walk.py +++ b/model/src/test/test_random_walk.py @@ -20,8 +20,8 @@ def test_rw_can_be_sampled(): ans1 = rw_normal(n_timepoints=5023) # check that the samples are of the right shape - assert ans0[0].array.shape == (3532,) - assert ans1[0].array.shape == (5023,) + assert ans0[0].value.shape == (3532,) + assert ans1[0].value.shape == (5023,) def test_rw_samples_correctly_distributed(): @@ -38,7 +38,7 @@ def test_rw_samples_correctly_distributed(): rw_init = 532.0 with numpyro.handlers.seed(rng_seed=62): samples, *_ = rw_normal(n_timepoints=n_samples, init=rw_init) - samples = samples.array + samples = samples.value # Checking the shape assert samples.shape == (n_samples,) diff --git a/model/src/test/test_rtperiodicdiff.py b/model/src/test/test_rtperiodicdiff.py index 750c493d..ca8c13ce 100644 --- a/model/src/test/test_rtperiodicdiff.py +++ b/model/src/test/test_rtperiodicdiff.py @@ -66,7 +66,7 @@ def test_rtweeklydiff() -> None: np.random.seed(223) with npro.handlers.seed(rng_seed=np.random.randint(1, 600)): - rt = rtwd(duration=duration).rt.array + rt = rtwd(duration=duration).rt.value # Checking that the shape of the sampled Rt is correct assert rt.shape == (duration,) @@ -81,7 +81,7 @@ def test_rtweeklydiff() -> None: params["offset"] = 5 rtwd = RtWeeklyDiffProcess(**params) with npro.handlers.seed(rng_seed=np.random.randint(1, 600)): - rt2 = rtwd(duration=duration).rt.array + rt2 = rtwd(duration=duration).rt.value # Checking that the shape of the sampled Rt is correct assert rt2.shape == (duration,) @@ -114,7 +114,7 @@ def test_rtweeklydiff_no_autoregressive() -> None: np.random.seed(223) duration = 1000 with npro.handlers.seed(rng_seed=np.random.randint(1, 600)): - rt = rtwd(duration=duration).rt.array + rt = rtwd(duration=duration).rt.value # Checking that the shape of the sampled Rt is correct assert rt.shape == (duration,) @@ -153,12 +153,12 @@ def test_rtweeklydiff_manual_reconstruction() -> None: _, ans0 = lax.scan( f=rtwd.autoreg_process, - init=np.hstack([params["log_rt_prior"]()[0].array, b]), + init=np.hstack([params["log_rt_prior"]()[0].value, b]), xs=noise, ) ans1 = _manual_rt_weekly_diff( - log_seed=params["log_rt_prior"]()[0].array, sd=noise, b=b + log_seed=params["log_rt_prior"]()[0].value, sd=noise, b=b ) assert_array_equal(ans0, ans1) @@ -185,7 +185,7 @@ def test_rtperiodicdiff_smallsample(): np.random.seed(223) with npro.handlers.seed(rng_seed=np.random.randint(1, 600)): - rt = rtwd(duration=6).rt.array + rt = rtwd(duration=6).rt.value # Checking that the shape of the sampled Rt is correct assert rt.shape == (6,) From 6d44a700d626159b22aa78196ca193b52ca85ce8 Mon Sep 17 00:00:00 2001 From: George Vega Yon Date: Tue, 23 Jul 2024 11:22:55 -0600 Subject: [PATCH 15/41] Making pre-commit happy --- model/src/pyrenew/metaclass.py | 10 ++++++---- model/src/pyrenew/process/firstdifferencear.py | 4 +++- model/src/test/test_model_hosp_admissions.py | 2 +- model/src/test/test_random_walk.py | 1 + model/src/test/test_transformed_rv_class.py | 7 +++++-- 5 files changed, 16 insertions(+), 8 deletions(-) diff --git a/model/src/pyrenew/metaclass.py b/model/src/pyrenew/metaclass.py index 08903d27..1a2aa010 100644 --- a/model/src/pyrenew/metaclass.py +++ b/model/src/pyrenew/metaclass.py @@ -198,8 +198,10 @@ def set_timeseries( # Either both values are None or both are not None assert (t_unit is not None and t_start is not None) or ( t_unit is None and t_start is None - ), "Both t_start and t_unit should be None or not None. " \ + ), ( + "Both t_start and t_unit should be None or not None. " "Currently, t_start is {t_start} and t_unit is {t_unit}." + ) if t_unit is None and t_start is None: return None @@ -699,9 +701,9 @@ def sample(self, **kwargs) -> tuple: untransformed_values = self.base_rv.sample(**kwargs) return tuple( - SampledValue(t(uv.value)) for t, uv in zip(self.transforms, untransformed_values) - ) - + SampledValue(t(uv.value)) + for t, uv in zip(self.transforms, untransformed_values) + ) def sample_length(self): """ diff --git a/model/src/pyrenew/process/firstdifferencear.py b/model/src/pyrenew/process/firstdifferencear.py index 15c27e13..ed9ac575 100644 --- a/model/src/pyrenew/process/firstdifferencear.py +++ b/model/src/pyrenew/process/firstdifferencear.py @@ -73,7 +73,9 @@ def sample( name=name + "_rate_of_change", ) return ( - SampledValue(init_val + jnp.cumsum(rates_of_change.value.flatten())), + SampledValue( + init_val + jnp.cumsum(rates_of_change.value.flatten()) + ), ) @staticmethod diff --git a/model/src/test/test_model_hosp_admissions.py b/model/src/test/test_model_hosp_admissions.py index 30a8a228..5f3f3694 100644 --- a/model/src/test/test_model_hosp_admissions.py +++ b/model/src/test/test_model_hosp_admissions.py @@ -24,8 +24,8 @@ from pyrenew.metaclass import ( DistributionalRV, RandomVariable, - TransformedRandomVariable, SampledValue, + TransformedRandomVariable, ) from pyrenew.model import HospitalAdmissionsModel from pyrenew.observation import PoissonObservation diff --git a/model/src/test/test_random_walk.py b/model/src/test/test_random_walk.py index 0d395e1c..7a93a094 100755 --- a/model/src/test/test_random_walk.py +++ b/model/src/test/test_random_walk.py @@ -41,6 +41,7 @@ def test_rw_can_be_sampled(): assert_almost_equal(ans_fixed[0].value[0], init_rv_fixed.vars) assert ans_rand[0].value[0] != init_rv_fixed.vars + def test_rw_samples_correctly_distributed(): """ Check that a simple random walk has steps diff --git a/model/src/test/test_transformed_rv_class.py b/model/src/test/test_transformed_rv_class.py index e7837377..30fa2821 100644 --- a/model/src/test/test_transformed_rv_class.py +++ b/model/src/test/test_transformed_rv_class.py @@ -12,8 +12,8 @@ from pyrenew.metaclass import ( DistributionalRV, RandomVariable, - TransformedRandomVariable, SampledValue, + TransformedRandomVariable, ) @@ -134,6 +134,9 @@ def test_transforms_applied_at_sampling(): tr(norm_base_sample[0].value), norm_transformed_sample[0].value ) assert_almost_equal( - (tr(l2_base_sample[0].value), t.ExpTransform()(l2_base_sample[1].value)), + ( + tr(l2_base_sample[0].value), + t.ExpTransform()(l2_base_sample[1].value), + ), (l2_transformed_sample[0].value, l2_transformed_sample[1].value), ) From 2bba85a9b30f7e311b599fedf1e7425ed94b120c Mon Sep 17 00:00:00 2001 From: George Vega Yon Date: Tue, 23 Jul 2024 11:47:14 -0600 Subject: [PATCH 16/41] Fixing tutorial --- docs/source/tutorials/hospital_admissions_model.qmd | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/tutorials/hospital_admissions_model.qmd b/docs/source/tutorials/hospital_admissions_model.qmd index 29923c20..9b534f90 100644 --- a/docs/source/tutorials/hospital_admissions_model.qmd +++ b/docs/source/tutorials/hospital_admissions_model.qmd @@ -197,7 +197,7 @@ class MyRt(metaclass.RandomVariable): base_rv=process.SimpleRandomWalkProcess( name="log_rt", step_rv=metaclass.DistributionalRV( - dist.Normal(0, sd_rt), "rw_step_rv" + dist.Normal(0, sd_rt.value), "rw_step_rv" ), init_rv=metaclass.DistributionalRV( dist.Normal(0, 0.2), "init_log_Rt_rv" From 2220fb0cb2ba657e7d062ee28801492e066f9083 Mon Sep 17 00:00:00 2001 From: George Vega Yon Date: Wed, 24 Jul 2024 08:45:24 -0600 Subject: [PATCH 17/41] Fixing last merge --- .../latent/infection_initialization_process.py | 13 ++++++++++--- model/src/test/test_infection_seeding_method.py | 8 ++++---- 2 files changed, 14 insertions(+), 7 deletions(-) diff --git a/model/src/pyrenew/latent/infection_initialization_process.py b/model/src/pyrenew/latent/infection_initialization_process.py index 699b9ec6..ba9651e8 100644 --- a/model/src/pyrenew/latent/infection_initialization_process.py +++ b/model/src/pyrenew/latent/infection_initialization_process.py @@ -97,8 +97,15 @@ def sample(self) -> tuple: """ (I_pre_init,) = self.I_pre_init_rv() - infection_initialization = self.infection_init_method(I_pre_init) + infection_initialization = self.infection_init_method( + I_pre_init.value, + ) npro.deterministic(self.name, infection_initialization) - return (SampledValue(infection_initialization, t_start=self.t_start, t_unit=self.t_unit),) - + return ( + SampledValue( + infection_initialization, + t_start=self.t_start, + t_unit=self.t_unit, + ), + ) diff --git a/model/src/test/test_infection_seeding_method.py b/model/src/test/test_infection_seeding_method.py index 5d46668c..4c1d711b 100644 --- a/model/src/test/test_infection_seeding_method.py +++ b/model/src/test/test_infection_seeding_method.py @@ -20,7 +20,7 @@ def test_initialize_infections_exponential(): (I_pre_init,) = I_pre_init_RV() (rate,) = rate_RV() - I_pre_seed = I_pre_seed.value + I_pre_init = I_pre_init.value rate = rate.value infections_default_t_pre_init = InitializeInfectionsExponentialGrowth( n_timepoints, rate=rate_RV @@ -51,7 +51,7 @@ def test_initialize_infections_exponential(): with pytest.raises(ValueError): InitializeInfectionsExponentialGrowth( n_timepoints, rate=rate_RV - ).initialize_infections(I_pre_init_2) + ).initialize_infections(I_pre_init_2.value) # test non-default t_pre_init t_pre_init = 6 @@ -76,7 +76,7 @@ def test_initialize_infections_zero_pad(): n_timepoints = 10 I_pre_init_RV = DeterministicVariable(10.0, name="I_pre_init_RV") (I_pre_init,) = I_pre_init_RV() - I_pre_seed = I_pre_seed.value + I_pre_init = I_pre_init.value infections = InitializeInfectionsZeroPad( n_timepoints @@ -90,7 +90,7 @@ def test_initialize_infections_zero_pad(): ) (I_pre_init_2,) = I_pre_init_RV_2() - I_pre_seed_2 = I_pre_seed_2.value + I_pre_init_2 = I_pre_init_2.value infections_2 = InitializeInfectionsZeroPad( n_timepoints From a2098534f6c4c93ca8c028adef339d2f10131e18 Mon Sep 17 00:00:00 2001 From: George Vega Yon Date: Wed, 24 Jul 2024 15:08:15 -0600 Subject: [PATCH 18/41] Addressing comments by @damonbayer. --- docs/source/tutorials/time.qmd | 2 +- .../pyrenew/latent/infectionswithfeedback.py | 8 +++---- model/src/pyrenew/metaclass.py | 2 +- model/src/pyrenew/model/admissionsmodel.py | 22 +++++++++---------- .../pyrenew/model/rtinfectionsrenewalmodel.py | 12 +++++----- model/src/pyrenew/process/periodiceffect.py | 2 +- model/src/pyrenew/process/rtperiodicdiff.py | 4 ++-- 7 files changed, 26 insertions(+), 26 deletions(-) diff --git a/docs/source/tutorials/time.qmd b/docs/source/tutorials/time.qmd index e80b3e23..6ebc4608 100644 --- a/docs/source/tutorials/time.qmd +++ b/docs/source/tutorials/time.qmd @@ -10,7 +10,7 @@ The fundamental time unit should represent a period of fixed (or approximately f For many infectious disease renewal models of interest, the fundamental time unit will be days, and we will proceed with this tutorial treating days as our fundamental unit. - `pyrenew` deals with time having `RandomVariable`s carry information about (i) their own time unit expressed relative to the fundamental unit (`t_unit`) and (ii) the starting time, `t_start`, measured relative to `t = 0` in model time in fundamental time units. Moreover, return values from `RandomVariable.sample()` are namedtuples with `SampledValue` objects that carry the same information. + `pyrenew` deals with time having `RandomVariable`s carry information about (i) their own time unit expressed relative to the fundamental unit (`t_unit`) and (ii) the starting time, `t_start`, measured relative to `t = 0` in model time in fundamental time units. Moreover, return values from `RandomVariable.sample()` are either tuples or namedtuples with `SampledValue` objects that carry the same information. The tuple `(t_unit, t_start)` can encode different types of time series data. For example: diff --git a/model/src/pyrenew/latent/infectionswithfeedback.py b/model/src/pyrenew/latent/infectionswithfeedback.py index 7da74778..090c4b98 100644 --- a/model/src/pyrenew/latent/infectionswithfeedback.py +++ b/model/src/pyrenew/latent/infectionswithfeedback.py @@ -21,14 +21,14 @@ class InfectionsRtFeedbackSample(NamedTuple): Attributes ---------- - post_initialization_infections : ArrayLike | None, optional + post_initialization_infections : SampledValue | None, optional The estimated latent infections. Defaults to None. - rt : ArrayLike | None, optional + rt : SampledValue | None, optional The adjusted reproduction number. Defaults to None. """ - post_initialization_infections: ArrayLike | None = None - rt: ArrayLike | None = None + post_initialization_infections: SampledValue | None = None + rt: SampledValue | None = None def __repr__(self): return f"InfectionsSample(post_initialization_infections={self.post_initialization_infections}, rt={self.rt})" diff --git a/model/src/pyrenew/metaclass.py b/model/src/pyrenew/metaclass.py index 1a2aa010..12158ec3 100644 --- a/model/src/pyrenew/metaclass.py +++ b/model/src/pyrenew/metaclass.py @@ -113,7 +113,7 @@ def __init__( value : ArrayLike The sampled value. t_start : int, optional - The start time of the data. + The start time of the value. t_unit : int, optional The unit of time relative to the model's fundamental (smallest) time unit. diff --git a/model/src/pyrenew/model/admissionsmodel.py b/model/src/pyrenew/model/admissionsmodel.py index 09938cc6..1a90112c 100644 --- a/model/src/pyrenew/model/admissionsmodel.py +++ b/model/src/pyrenew/model/admissionsmodel.py @@ -7,7 +7,7 @@ from jax.typing import ArrayLike from pyrenew.deterministic import NullObservation -from pyrenew.metaclass import Model, RandomVariable, _assert_sample_and_rtype +from pyrenew.metaclass import Model, RandomVariable, _assert_sample_and_rtype, SampledValue from pyrenew.model.rtinfectionsrenewalmodel import RtInfectionsRenewalModel @@ -17,23 +17,23 @@ class HospModelSample(NamedTuple): Attributes ---------- - Rt : float | None, optional + Rt : SampledValue | None, optional The reproduction number over time. Defaults to None. - latent_infections : ArrayLike | None, optional + latent_infections : SampledValue | None, optional The estimated number of new infections over time. Defaults to None. - infection_hosp_rate : float | None, optional + infection_hosp_rate : SampledValue | None, optional The infected hospitalization rate. Defaults to None. - latent_hosp_admissions : ArrayLike | None, optional + latent_hosp_admissions : SampledValue | None, optional The estimated latent hospitalizations. Defaults to None. - observed_hosp_admissions : ArrayLike | None, optional + observed_hosp_admissions : SampledValue | None, optional The sampled or observed hospital admissions. Defaults to None. """ - Rt: float | None = None - latent_infections: ArrayLike | None = None - infection_hosp_rate: float | None = None - latent_hosp_admissions: ArrayLike | None = None - observed_hosp_admissions: ArrayLike | None = None + Rt: SampledValue | None = None + latent_infections: SampledValue | None = None + infection_hosp_rate: SampledValue | None = None + latent_hosp_admissions: SampledValue | None = None + observed_hosp_admissions: SampledValue | None = None def __repr__(self): return ( diff --git a/model/src/pyrenew/model/rtinfectionsrenewalmodel.py b/model/src/pyrenew/model/rtinfectionsrenewalmodel.py index e563a1cb..e8b46ae5 100644 --- a/model/src/pyrenew/model/rtinfectionsrenewalmodel.py +++ b/model/src/pyrenew/model/rtinfectionsrenewalmodel.py @@ -25,17 +25,17 @@ class RtInfectionsRenewalSample(NamedTuple): Attributes ---------- - Rt : ArrayLike | None, optional + Rt : SampledValue | None, optional The reproduction number over time. Defaults to None. - latent_infections : ArrayLike | None, optional + latent_infections : SampledValue | None, optional The estimated latent infections. Defaults to None. - observed_infections : ArrayLike | None, optional + observed_infections : SampledValue | None, optional The sampled infections. Defaults to None. """ - Rt: ArrayLike | None = None - latent_infections: ArrayLike | None = None - observed_infections: ArrayLike | None = None + Rt: SampledValue | None = None + latent_infections: SampledValue | None = None + observed_infections: SampledValue | None = None def __repr__(self): return ( diff --git a/model/src/pyrenew/process/periodiceffect.py b/model/src/pyrenew/process/periodiceffect.py index e0b81dad..8d7b30b2 100644 --- a/model/src/pyrenew/process/periodiceffect.py +++ b/model/src/pyrenew/process/periodiceffect.py @@ -22,7 +22,7 @@ class PeriodicEffectSample(NamedTuple): The sampled value. """ - value: jnp.ndarray + value: SampledValue def __repr__(self): return f"PeriodicEffectSample(value={self.value})" diff --git a/model/src/pyrenew/process/rtperiodicdiff.py b/model/src/pyrenew/process/rtperiodicdiff.py index c6b7306b..7955d511 100644 --- a/model/src/pyrenew/process/rtperiodicdiff.py +++ b/model/src/pyrenew/process/rtperiodicdiff.py @@ -18,11 +18,11 @@ class RtPeriodicDiffProcessSample(NamedTuple): Attributes ---------- - rt : ArrayLike + rt : SampledValue, optional The sampled Rt. """ - rt: ArrayLike | None = None + rt: SampledValue | None = None def __repr__(self): return f"RtPeriodicDiffProcessSample(rt={self.rt})" From 68f1d22de78394368a15ca070ff74f6691d9070f Mon Sep 17 00:00:00 2001 From: George Vega Yon Date: Wed, 24 Jul 2024 15:10:23 -0600 Subject: [PATCH 19/41] Making pre-commit happy --- model/src/pyrenew/model/admissionsmodel.py | 7 ++++++- model/src/pyrenew/process/periodiceffect.py | 1 - 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/model/src/pyrenew/model/admissionsmodel.py b/model/src/pyrenew/model/admissionsmodel.py index 1a90112c..3bf3aa50 100644 --- a/model/src/pyrenew/model/admissionsmodel.py +++ b/model/src/pyrenew/model/admissionsmodel.py @@ -7,7 +7,12 @@ from jax.typing import ArrayLike from pyrenew.deterministic import NullObservation -from pyrenew.metaclass import Model, RandomVariable, _assert_sample_and_rtype, SampledValue +from pyrenew.metaclass import ( + Model, + RandomVariable, + SampledValue, + _assert_sample_and_rtype, +) from pyrenew.model.rtinfectionsrenewalmodel import RtInfectionsRenewalModel diff --git a/model/src/pyrenew/process/periodiceffect.py b/model/src/pyrenew/process/periodiceffect.py index 8d7b30b2..5fa8a539 100644 --- a/model/src/pyrenew/process/periodiceffect.py +++ b/model/src/pyrenew/process/periodiceffect.py @@ -2,7 +2,6 @@ from typing import NamedTuple -import jax.numpy as jnp import pyrenew.arrayutils as au from pyrenew.metaclass import ( RandomVariable, From b5fc9d6a74d773e18e2a9d6af5366d0e0127852b Mon Sep 17 00:00:00 2001 From: George Vega Yon Date: Wed, 24 Jul 2024 15:59:23 -0600 Subject: [PATCH 20/41] SampledValues are an instance of NamedTuple --- model/src/pyrenew/metaclass.py | 50 +++++++++++----------------------- 1 file changed, 16 insertions(+), 34 deletions(-) diff --git a/model/src/pyrenew/metaclass.py b/model/src/pyrenew/metaclass.py index 12158ec3..d2705af1 100644 --- a/model/src/pyrenew/metaclass.py +++ b/model/src/pyrenew/metaclass.py @@ -5,7 +5,7 @@ """ from abc import ABCMeta, abstractmethod -from typing import get_type_hints +from typing import NamedTuple, get_type_hints import jax import jax.numpy as jnp @@ -94,44 +94,26 @@ def _assert_sample_and_rtype( return None -class SampledValue: +class SampledValue(NamedTuple): """ A container for a sampled value from a RandomVariable. - """ - - def __init__( - self, - value: ArrayLike | None, - t_start: int | None = None, - t_unit: int | None = None, - ) -> None: - """ - Default constructor for SampledValue - - Parameters - ---------- - value : ArrayLike - The sampled value. - t_start : int, optional - The start time of the value. - t_unit : int, optional - The unit of time relative to the model's fundamental (smallest) time unit. - - Returns - ------- - None - """ - if value is not None: - assert isinstance( - value, ArrayLike - ), "value should be an array-like object." + Attributes + ---------- + value : ArrayLike, optional + The sampled value. + t_start : int, optional + The start time of the value. + t_unit : int, optional + The unit of time relative to the model's fundamental (smallest) time unit. + """ - self.value = value - self.t_start = t_start - self.t_unit = t_unit + value: ArrayLike | None = None + t_start: int | None = None + t_unit: int | None = None - return None + def __repr__(self): + return f"SampledValue(value={self.value}, t_start={self.t_start}, t_unit={self.t_unit})" class RandomVariable(metaclass=ABCMeta): From ff258d3275dbd0b22c74d22169bf62840b6a2286 Mon Sep 17 00:00:00 2001 From: George Vega Yon Date: Wed, 24 Jul 2024 16:30:49 -0600 Subject: [PATCH 21/41] Adding suggestions by @dylanhmorris --- docs/source/tutorials/time.qmd | 4 ++-- model/src/pyrenew/deterministic/deterministic.py | 5 +++-- model/src/pyrenew/latent/hospitaladmissions.py | 6 +++--- model/src/pyrenew/latent/infection_initialization_method.py | 3 +-- 4 files changed, 9 insertions(+), 9 deletions(-) diff --git a/docs/source/tutorials/time.qmd b/docs/source/tutorials/time.qmd index 6ebc4608..f171b896 100644 --- a/docs/source/tutorials/time.qmd +++ b/docs/source/tutorials/time.qmd @@ -12,7 +12,7 @@ For many infectious disease renewal models of interest, the fundamental time uni `pyrenew` deals with time having `RandomVariable`s carry information about (i) their own time unit expressed relative to the fundamental unit (`t_unit`) and (ii) the starting time, `t_start`, measured relative to `t = 0` in model time in fundamental time units. Moreover, return values from `RandomVariable.sample()` are either tuples or namedtuples with `SampledValue` objects that carry the same information. -The tuple `(t_unit, t_start)` can encode different types of time series data. For example: +The `t_unit, t_start` pair can encode different types of time series data. For example: | Description | `t_unit` | `t_start` | |:-----------------|----------------:|-----------------:| @@ -41,4 +41,4 @@ Rt_aligned, infections_aligned = align([Rt, infections]) ### Retrieving time information from sites -Since numpyro only stores Jax arrays, we cannot store the time information in the arrays themselves. Next iterations of `pyrenew` should include a way to retrieve the time information from the sites of the model after running them. +Future versions of `pyrenew` could include a way to retrieve the time information for sites keyed by site name the model. diff --git a/model/src/pyrenew/deterministic/deterministic.py b/model/src/pyrenew/deterministic/deterministic.py index e61e8ebb..9cecbf45 100644 --- a/model/src/pyrenew/deterministic/deterministic.py +++ b/model/src/pyrenew/deterministic/deterministic.py @@ -89,8 +89,9 @@ def sample( Returns ------- - tuple - Containing the stored values during construction wrapped in a SampledValue. + tuple[SampledValue] + A length-one tuple whose single entry is a :class:`SampledValue` + instance with `value=self.vars`, `t_start=self.t_start`, and `t_unit=self.t_unit`. """ if record: npro.deterministic(self.name, self.vars) diff --git a/model/src/pyrenew/latent/hospitaladmissions.py b/model/src/pyrenew/latent/hospitaladmissions.py index 5f467ce4..a37d6e5a 100644 --- a/model/src/pyrenew/latent/hospitaladmissions.py +++ b/model/src/pyrenew/latent/hospitaladmissions.py @@ -162,7 +162,7 @@ def sample( Parameters ---------- - latent : ArrayLike or SampledValue + latent_infections : ArrayLike Latent infections. **kwargs : dict, optional Additional keyword arguments passed through to internal `sample()` @@ -210,7 +210,7 @@ def sample( infection_hosp_rate=infection_hosp_rate, latent_hospital_admissions=SampledValue( value=latent_hospital_admissions, - t_start=self.infection_to_admission_interval_rv.t_start, - t_unit=self.infection_to_admission_interval_rv.t_unit, + t_start=self.t_start, + t_unit=self.t_unit, ), ) diff --git a/model/src/pyrenew/latent/infection_initialization_method.py b/model/src/pyrenew/latent/infection_initialization_method.py index 6bd8c2f2..3f58d93e 100644 --- a/model/src/pyrenew/latent/infection_initialization_method.py +++ b/model/src/pyrenew/latent/infection_initialization_method.py @@ -176,8 +176,7 @@ def initialize_infections(self, I_pre_init: ArrayLike): raise ValueError( f"I_pre_init must be an array of size 1. Got size {I_pre_init.size}." ) - (rate,) = self.rate() - rate = rate.value + rate = self.rate()[0].value if rate.size != 1: raise ValueError( f"rate must be an array of size 1. Got size {rate.size}." From 771f6b5adf530d5cbc3d75660df0ec4a908aafcc Mon Sep 17 00:00:00 2001 From: George Vega Yon Date: Wed, 24 Jul 2024 16:34:00 -0600 Subject: [PATCH 22/41] Missing comment --- model/src/pyrenew/latent/infectionswithfeedback.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/model/src/pyrenew/latent/infectionswithfeedback.py b/model/src/pyrenew/latent/infectionswithfeedback.py index 090c4b98..55a752d9 100644 --- a/model/src/pyrenew/latent/infectionswithfeedback.py +++ b/model/src/pyrenew/latent/infectionswithfeedback.py @@ -162,8 +162,7 @@ def sample( # Sampling inf feedback strength inf_feedback_strength, *_ = self.infection_feedback_strength( **kwargs, - ) - inf_feedback_strength = inf_feedback_strength.value + )[0].value # Making sure inf_feedback_strength spans the Rt length if inf_feedback_strength.size == 1: From 7a8bf3b7dc767c6bb2706f31281328f0119dd1ad Mon Sep 17 00:00:00 2001 From: George Vega Yon Date: Wed, 24 Jul 2024 16:34:47 -0600 Subject: [PATCH 23/41] Making pre-commit happy --- model/src/pyrenew/deterministic/deterministic.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/model/src/pyrenew/deterministic/deterministic.py b/model/src/pyrenew/deterministic/deterministic.py index 9cecbf45..34f4c229 100644 --- a/model/src/pyrenew/deterministic/deterministic.py +++ b/model/src/pyrenew/deterministic/deterministic.py @@ -90,7 +90,7 @@ def sample( Returns ------- tuple[SampledValue] - A length-one tuple whose single entry is a :class:`SampledValue` + A length-one tuple whose single entry is a :class:`SampledValue` instance with `value=self.vars`, `t_start=self.t_start`, and `t_unit=self.t_unit`. """ if record: From 460993a0752220ca4b7237412f4befa058ca43f4 Mon Sep 17 00:00:00 2001 From: George Vega Yon Date: Wed, 24 Jul 2024 16:46:49 -0600 Subject: [PATCH 24/41] Fixing infections with feedback --- docs/source/tutorials/time.qmd | 2 +- model/src/pyrenew/latent/infectionswithfeedback.py | 2 +- model/src/test/test_infectionsrtfeedback.py | 2 ++ 3 files changed, 4 insertions(+), 2 deletions(-) diff --git a/docs/source/tutorials/time.qmd b/docs/source/tutorials/time.qmd index f171b896..5698faf2 100644 --- a/docs/source/tutorials/time.qmd +++ b/docs/source/tutorials/time.qmd @@ -10,7 +10,7 @@ The fundamental time unit should represent a period of fixed (or approximately f For many infectious disease renewal models of interest, the fundamental time unit will be days, and we will proceed with this tutorial treating days as our fundamental unit. - `pyrenew` deals with time having `RandomVariable`s carry information about (i) their own time unit expressed relative to the fundamental unit (`t_unit`) and (ii) the starting time, `t_start`, measured relative to `t = 0` in model time in fundamental time units. Moreover, return values from `RandomVariable.sample()` are either tuples or namedtuples with `SampledValue` objects that carry the same information. +`pyrenew` deals with time having `RandomVariable`s carry information about (i) their own time unit expressed relative to the fundamental unit (`t_unit`) and (ii) the starting time, `t_start`, measured relative to `t = 0` in model time in fundamental time units. Return values from `RandomVariable.sample()` are `tuples` or `namedtuple`s of `SampledValue` objects. Each such `SampledValue` is optionally time-aware with specifiable `t_start` and `t_unit` attributes. The `t_unit, t_start` pair can encode different types of time series data. For example: diff --git a/model/src/pyrenew/latent/infectionswithfeedback.py b/model/src/pyrenew/latent/infectionswithfeedback.py index 55a752d9..d8d0477b 100644 --- a/model/src/pyrenew/latent/infectionswithfeedback.py +++ b/model/src/pyrenew/latent/infectionswithfeedback.py @@ -160,7 +160,7 @@ def sample( I0 = I0[-gen_int_rev.size :] # Sampling inf feedback strength - inf_feedback_strength, *_ = self.infection_feedback_strength( + inf_feedback_strength = self.infection_feedback_strength( **kwargs, )[0].value diff --git a/model/src/test/test_infectionsrtfeedback.py b/model/src/test/test_infectionsrtfeedback.py index ef1b564f..f3e17a3d 100644 --- a/model/src/test/test_infectionsrtfeedback.py +++ b/model/src/test/test_infectionsrtfeedback.py @@ -157,3 +157,5 @@ def test_infectionsrtfeedback_feedback(): assert_array_almost_equal(samp1.rt.value, res["rt"]) return None + +test_infectionsrtfeedback() \ No newline at end of file From 8f807dc7b628eea9ff350dad5923944d21f44202 Mon Sep 17 00:00:00 2001 From: George Vega Yon Date: Wed, 24 Jul 2024 16:49:29 -0600 Subject: [PATCH 25/41] Removing explicit test call --- model/src/test/test_infectionsrtfeedback.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/model/src/test/test_infectionsrtfeedback.py b/model/src/test/test_infectionsrtfeedback.py index f3e17a3d..ef1b564f 100644 --- a/model/src/test/test_infectionsrtfeedback.py +++ b/model/src/test/test_infectionsrtfeedback.py @@ -157,5 +157,3 @@ def test_infectionsrtfeedback_feedback(): assert_array_almost_equal(samp1.rt.value, res["rt"]) return None - -test_infectionsrtfeedback() \ No newline at end of file From fc174887103889490573ad904482cb7ad45e2a83 Mon Sep 17 00:00:00 2001 From: "Dylan H. Morris" Date: Thu, 25 Jul 2024 08:09:19 -0700 Subject: [PATCH 26/41] Fix pre-commit issues and remaining numpyro/npro conflicts --- model/src/pyrenew/deterministic/deterministic.py | 12 ++++++++---- .../latent/infection_initialization_process.py | 6 +++--- model/src/test/test_latent_admissions.py | 1 - 3 files changed, 11 insertions(+), 8 deletions(-) diff --git a/model/src/pyrenew/deterministic/deterministic.py b/model/src/pyrenew/deterministic/deterministic.py index c66ea88a..c59c3ae1 100644 --- a/model/src/pyrenew/deterministic/deterministic.py +++ b/model/src/pyrenew/deterministic/deterministic.py @@ -82,7 +82,8 @@ def sample( Parameters ---------- record : bool, optional - Whether to record the value of the deterministic RandomVariable. Defaults to True. + Whether to record the value of the deterministic + RandomVariable. Defaults to True. **kwargs : dict, optional Additional keyword arguments passed through to internal sample calls, should there be any. @@ -90,11 +91,14 @@ def sample( Returns ------- tuple[SampledValue] - A length-one tuple whose single entry is a :class:`SampledValue` - instance with `value=self.vars`, `t_start=self.t_start`, and `t_unit=self.t_unit`. + A length-one tuple whose single entry is a + :class:`SampledValue` + instance with `value=self.vars`, + `t_start=self.t_start`, and + `t_unit=self.t_unit`. """ if record: - npro.deterministic(self.name, self.vars) + numpyro.deterministic(self.name, self.vars) return ( SampledValue( value=self.vars, diff --git a/model/src/pyrenew/latent/infection_initialization_process.py b/model/src/pyrenew/latent/infection_initialization_process.py index e3e129d7..8e8c62e2 100644 --- a/model/src/pyrenew/latent/infection_initialization_process.py +++ b/model/src/pyrenew/latent/infection_initialization_process.py @@ -93,7 +93,8 @@ def sample(self) -> tuple: Returns ------- tuple - a tuple where the only element is an array with the number of initialized infections at each time point. + a tuple where the only element is an array with + the number of initialized infections at each time point. """ (I_pre_init,) = self.I_pre_init_rv() @@ -101,8 +102,7 @@ def sample(self) -> tuple: infection_initialization = self.infection_init_method( I_pre_init.value, ) - npro.deterministic(self.name, infection_initialization) - + numpyro.deterministic(self.name, infection_initialization) return ( SampledValue( diff --git a/model/src/test/test_latent_admissions.py b/model/src/test/test_latent_admissions.py index c61ba915..dbf5ab23 100644 --- a/model/src/test/test_latent_admissions.py +++ b/model/src/test/test_latent_admissions.py @@ -40,7 +40,6 @@ def test_admissions_sample(): inf1 = Infections() - with numpyro.handlers.seed(rng_seed=np.random.randint(1, 600)): inf_sampled1 = inf1(Rt=sim_rt, gen_int=gen_int, I0=i0) From 353fbc202245c6b64a49ea16b67b4abb0b353ca5 Mon Sep 17 00:00:00 2001 From: George Vega Yon Date: Thu, 25 Jul 2024 09:23:01 -0600 Subject: [PATCH 27/41] Addressing @dylanhmorris' comments on defaults for t_start and t_unit --- model/src/pyrenew/deterministic/nullrv.py | 6 +++--- model/src/pyrenew/deterministic/process.py | 18 ++++++++++-------- model/src/pyrenew/latent/hospitaladmissions.py | 2 +- model/src/pyrenew/latent/infections.py | 8 +++++++- .../pyrenew/latent/infectionswithfeedback.py | 2 +- model/src/pyrenew/metaclass.py | 14 ++++++++++++-- .../pyrenew/observation/negativebinomial.py | 8 +++++++- model/src/pyrenew/observation/poisson.py | 8 +++++++- model/src/pyrenew/process/ar.py | 8 +++++++- model/src/pyrenew/process/firstdifferencear.py | 4 +++- model/src/pyrenew/process/periodiceffect.py | 4 +++- model/src/pyrenew/process/rtperiodicdiff.py | 4 +++- model/src/pyrenew/process/simplerandomwalk.py | 8 +++++++- model/src/test/test_periodiceffect.py | 12 ++++-------- model/src/test/test_transformed_rv_class.py | 7 +++++-- 15 files changed, 80 insertions(+), 33 deletions(-) diff --git a/model/src/pyrenew/deterministic/nullrv.py b/model/src/pyrenew/deterministic/nullrv.py index 119531d0..ffd78307 100644 --- a/model/src/pyrenew/deterministic/nullrv.py +++ b/model/src/pyrenew/deterministic/nullrv.py @@ -50,7 +50,7 @@ def sample( Containing a SampledValue with None. """ - return (SampledValue(None),) + return (SampledValue(None, t_start=self.t_start, t_unit=self.t_unit),) class NullProcess(NullVariable): @@ -99,7 +99,7 @@ def sample( Containing a SampledValue with None. """ - return (SampledValue(None),) + return (SampledValue(None, t_start=self.t_start, t_unit=self.t_unit),) class NullObservation(NullVariable): @@ -155,4 +155,4 @@ def sample( Containing a SampledValue with None. """ - return (SampledValue(None),) + return (SampledValue(None, t_start=self.t_start, t_unit=self.t_unit),) diff --git a/model/src/pyrenew/deterministic/process.py b/model/src/pyrenew/deterministic/process.py index 7210356b..93ff69f2 100644 --- a/model/src/pyrenew/deterministic/process.py +++ b/model/src/pyrenew/deterministic/process.py @@ -38,18 +38,20 @@ def sample( dif = duration - res.value.shape[0] if dif > 0: - return ( + res = ( SampledValue( jnp.hstack([res.value, jnp.repeat(res.value[-1], dif)]), t_start=self.t_start, t_unit=self.t_unit, ), ) + else: + res = ( + SampledValue( + value=res.value[:duration], + t_start=self.t_start, + t_unit=self.t_unit, + ), + ) - return ( - SampledValue( - value=res.value[:duration], - t_start=self.t_start, - t_unit=self.t_unit, - ), - ) + return res diff --git a/model/src/pyrenew/latent/hospitaladmissions.py b/model/src/pyrenew/latent/hospitaladmissions.py index a37d6e5a..65723410 100644 --- a/model/src/pyrenew/latent/hospitaladmissions.py +++ b/model/src/pyrenew/latent/hospitaladmissions.py @@ -196,7 +196,7 @@ def sample( )[0].value ) - # Applying probability of hospitalization effect + # Applying reporting probability latent_hospital_admissions = ( latent_hospital_admissions * self.hosp_report_prob_rv(**kwargs)[0].value diff --git a/model/src/pyrenew/latent/infections.py b/model/src/pyrenew/latent/infections.py index 142b25bc..e5da11d6 100644 --- a/model/src/pyrenew/latent/infections.py +++ b/model/src/pyrenew/latent/infections.py @@ -97,4 +97,10 @@ def sample( reversed_generation_interval_pmf=gen_int_rev, ) - return InfectionsSample(SampledValue(post_initialization_infections)) + return InfectionsSample( + SampledValue( + post_initialization_infections, + t_start=self.t_start, + t_unit=self.t_unit, + ) + ) diff --git a/model/src/pyrenew/latent/infectionswithfeedback.py b/model/src/pyrenew/latent/infectionswithfeedback.py index d8d0477b..04c83974 100644 --- a/model/src/pyrenew/latent/infectionswithfeedback.py +++ b/model/src/pyrenew/latent/infectionswithfeedback.py @@ -202,5 +202,5 @@ def sample( post_initialization_infections=SampledValue( post_initialization_infections ), - rt=SampledValue(Rt_adj), + rt=SampledValue(Rt_adj, t_start=self.t_start, t_unit=self.t_unit), ) diff --git a/model/src/pyrenew/metaclass.py b/model/src/pyrenew/metaclass.py index d2705af1..fc913441 100644 --- a/model/src/pyrenew/metaclass.py +++ b/model/src/pyrenew/metaclass.py @@ -329,7 +329,13 @@ def sample( fn=self.dist, obs=obs, ) - return (SampledValue(jnp.atleast_1d(sample)),) + return ( + SampledValue( + jnp.atleast_1d(sample), + t_start=self.t_start, + t_unit=self.t_unit, + ), + ) class Model(metaclass=ABCMeta): @@ -683,7 +689,11 @@ def sample(self, **kwargs) -> tuple: untransformed_values = self.base_rv.sample(**kwargs) return tuple( - SampledValue(t(uv.value)) + SampledValue( + t(uv.value), + t_start=self.t_start, + t_unit=self.t_unit, + ) for t, uv in zip(self.transforms, untransformed_values) ) diff --git a/model/src/pyrenew/observation/negativebinomial.py b/model/src/pyrenew/observation/negativebinomial.py index bd06d45a..cf583021 100644 --- a/model/src/pyrenew/observation/negativebinomial.py +++ b/model/src/pyrenew/observation/negativebinomial.py @@ -98,4 +98,10 @@ def sample( obs=obs, ) - return (SampledValue(negative_binomial_sample),) + return ( + SampledValue( + negative_binomial_sample, + t_start=self.t_start, + t_unit=self.t_unit, + ), + ) diff --git a/model/src/pyrenew/observation/poisson.py b/model/src/pyrenew/observation/poisson.py index 28d92c50..12e38c57 100644 --- a/model/src/pyrenew/observation/poisson.py +++ b/model/src/pyrenew/observation/poisson.py @@ -72,4 +72,10 @@ def sample( fn=dist.Poisson(rate=mu + self.eps), obs=obs, ) - return (SampledValue(poisson_sample),) + return ( + SampledValue( + poisson_sample, + t_start=self.t_start, + t_unit=self.t_unit, + ), + ) diff --git a/model/src/pyrenew/process/ar.py b/model/src/pyrenew/process/ar.py index c393c54d..0764ff8a 100644 --- a/model/src/pyrenew/process/ar.py +++ b/model/src/pyrenew/process/ar.py @@ -91,7 +91,13 @@ def _ar_scanner(carry, next): # numpydoc ignore=GL08 ) last, ts = lax.scan(_ar_scanner, inits - self.mean, noise) - return (SampledValue(jnp.hstack([inits, self.mean + ts.flatten()])),) + return ( + SampledValue( + jnp.hstack([inits, self.mean + ts.flatten()]), + t_start=self.t_start, + t_unit=self.t_unit, + ), + ) @staticmethod def validate(): # numpydoc ignore=RT01 diff --git a/model/src/pyrenew/process/firstdifferencear.py b/model/src/pyrenew/process/firstdifferencear.py index ed9ac575..d08eec2e 100644 --- a/model/src/pyrenew/process/firstdifferencear.py +++ b/model/src/pyrenew/process/firstdifferencear.py @@ -74,7 +74,9 @@ def sample( ) return ( SampledValue( - init_val + jnp.cumsum(rates_of_change.value.flatten()) + init_val + jnp.cumsum(rates_of_change.value.flatten()), + t_start=self.t_start, + t_unit=self.t_unit, ), ) diff --git a/model/src/pyrenew/process/periodiceffect.py b/model/src/pyrenew/process/periodiceffect.py index 5fa8a539..ffb2f183 100644 --- a/model/src/pyrenew/process/periodiceffect.py +++ b/model/src/pyrenew/process/periodiceffect.py @@ -117,7 +117,9 @@ def sample(self, duration: int, **kwargs): self.broadcaster( data=self.quantity_to_broadcast.sample(**kwargs)[0].value, n_timepoints=duration, - ) + ), + t_start=self.t_start, + t_unit=self.t_unit, ) ) diff --git a/model/src/pyrenew/process/rtperiodicdiff.py b/model/src/pyrenew/process/rtperiodicdiff.py index 7955d511..52a57b48 100644 --- a/model/src/pyrenew/process/rtperiodicdiff.py +++ b/model/src/pyrenew/process/rtperiodicdiff.py @@ -193,7 +193,9 @@ def sample( return RtPeriodicDiffProcessSample( rt=SampledValue( - self.broadcaster(jnp.exp(log_rt.value.flatten()), duration) + self.broadcaster(jnp.exp(log_rt.value.flatten()), duration), + t_start=self.t_start, + t_unit=self.t_unit, ), ) diff --git a/model/src/pyrenew/process/simplerandomwalk.py b/model/src/pyrenew/process/simplerandomwalk.py index dd5655f4..a88ea0d7 100644 --- a/model/src/pyrenew/process/simplerandomwalk.py +++ b/model/src/pyrenew/process/simplerandomwalk.py @@ -86,7 +86,13 @@ def transition(x_prev, _): xs=jnp.arange(n_steps - 1), ) - return (SampledValue(jnp.hstack([init.value, x.flatten()])),) + return ( + SampledValue( + jnp.hstack([init.value, x.flatten()]), + t_start=self.t_start, + t_unit=self.t_unit, + ), + ) @staticmethod def validate(): diff --git a/model/src/test/test_periodiceffect.py b/model/src/test/test_periodiceffect.py index d5014a9e..1fd736f7 100644 --- a/model/src/test/test_periodiceffect.py +++ b/model/src/test/test_periodiceffect.py @@ -44,9 +44,9 @@ def test_periodiceffect() -> None: params["offset"] = 5 pe = PeriodicEffect(**params) with npro.handlers.seed(rng_seed=np.random.randint(1, 600)): - ans2 = pe(duration=duration).value.value + ans2 = pe(duration=duration)[0].value - # Checking that the shape of the sampled Rt is correct + ans2 = pe(duration=duration)[0].value assert ans2.shape == (duration,) # This time series should be the same as the previous one, but shifted by @@ -81,13 +81,9 @@ def test_weeklyeffect() -> None: pe = PeriodicEffect(**params) pe2 = DayOfWeekEffect(**params2) - ans1 = pe(duration=duration).value.value - ans2 = pe2(duration=duration).value.value + ans1 = pe(duration=duration)[0].value + ans2 = pe2(duration=duration)[0].value assert_array_equal(ans1, ans2) return None - - -test_periodiceffect() -test_weeklyeffect() diff --git a/model/src/test/test_transformed_rv_class.py b/model/src/test/test_transformed_rv_class.py index 30fa2821..1d1aa879 100644 --- a/model/src/test/test_transformed_rv_class.py +++ b/model/src/test/test_transformed_rv_class.py @@ -32,9 +32,12 @@ def sample(self, **kwargs): Returns ------- tuple - (SampledValue(1), SampledValue(5)) + (SampledValue(1, t_start=self.t_start, t_unit=self.t_unit), SampledValue(5, t_start=self.t_start, t_unit=self.t_unit)) """ - return (SampledValue(1), SampledValue(5)) + return ( + SampledValue(1, t_start=self.t_start, t_unit=self.t_unit), + SampledValue(5, t_start=self.t_start, t_unit=self.t_unit), + ) def sample_length(self): """ From b9a87f8156823be5eb3769da6acbc2cf83f919ec Mon Sep 17 00:00:00 2001 From: George Vega Yon Date: Thu, 25 Jul 2024 09:42:09 -0600 Subject: [PATCH 28/41] Forgot to remove a test call --- model/src/test/test_latent_admissions.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/model/src/test/test_latent_admissions.py b/model/src/test/test_latent_admissions.py index 96134611..d020f637 100644 --- a/model/src/test/test_latent_admissions.py +++ b/model/src/test/test_latent_admissions.py @@ -84,5 +84,3 @@ def test_admissions_sample(): sim_hosp_1.latent_hospital_admissions.value, inf_sampled1[0].value, ) - -test_admissions_sample() \ No newline at end of file From 6026b19e857d05a4ec2b09e16678ac84887187c7 Mon Sep 17 00:00:00 2001 From: "Dylan H. Morris" Date: Thu, 25 Jul 2024 16:06:13 -0700 Subject: [PATCH 29/41] Autoformat files, fix typo caught by typos hook --- model/src/pyrenew/deterministic/deterministic.py | 3 +-- model/src/test/test_periodiceffect.py | 1 - 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/model/src/pyrenew/deterministic/deterministic.py b/model/src/pyrenew/deterministic/deterministic.py index 422c646d..5e2f8f9f 100644 --- a/model/src/pyrenew/deterministic/deterministic.py +++ b/model/src/pyrenew/deterministic/deterministic.py @@ -97,7 +97,7 @@ def sample( `t_unit=self.t_unit`. """ if record: - numpyro.deterministic(self.name, self.vallue) + numpyro.deterministic(self.name, self.value) return ( SampledValue( value=self.value, @@ -105,4 +105,3 @@ def sample( t_unit=self.t_unit, ), ) - \ No newline at end of file diff --git a/model/src/test/test_periodiceffect.py b/model/src/test/test_periodiceffect.py index 3f3ebdcc..7173efa4 100644 --- a/model/src/test/test_periodiceffect.py +++ b/model/src/test/test_periodiceffect.py @@ -44,7 +44,6 @@ def test_periodiceffect() -> None: with numpyro.handlers.seed(rng_seed=223): ans2 = pe(duration=duration)[0].value - ans2 = pe(duration=duration)[0].value assert ans2.shape == (duration,) From ee817e596a961830c8e3cf08d204d26196ba34d7 Mon Sep 17 00:00:00 2001 From: "Dylan H. Morris" Date: Thu, 25 Jul 2024 19:06:33 -0400 Subject: [PATCH 30/41] Update model/src/pyrenew/latent/infectionswithfeedback.py --- model/src/pyrenew/latent/infectionswithfeedback.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/model/src/pyrenew/latent/infectionswithfeedback.py b/model/src/pyrenew/latent/infectionswithfeedback.py index 9169214c..f3934d4b 100644 --- a/model/src/pyrenew/latent/infectionswithfeedback.py +++ b/model/src/pyrenew/latent/infectionswithfeedback.py @@ -200,7 +200,9 @@ def sample( return InfectionsRtFeedbackSample( post_initialization_infections=SampledValue( - post_initialization_infections + value=post_initialization_infections, + t_start=self.t_start, + t_unit=self.t_unit) ), rt=SampledValue(Rt_adj, t_start=self.t_start, t_unit=self.t_unit), ) From 302ca2edb29c5e7f9b9d597d79bcc889b2e593fd Mon Sep 17 00:00:00 2001 From: "Dylan H. Morris" Date: Thu, 25 Jul 2024 16:12:27 -0700 Subject: [PATCH 31/41] Fix typo in infectionswithfeedback.py that caused ill-formed code and consequent test failures --- model/src/pyrenew/latent/infectionswithfeedback.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/model/src/pyrenew/latent/infectionswithfeedback.py b/model/src/pyrenew/latent/infectionswithfeedback.py index f3934d4b..fffa7307 100644 --- a/model/src/pyrenew/latent/infectionswithfeedback.py +++ b/model/src/pyrenew/latent/infectionswithfeedback.py @@ -202,7 +202,7 @@ def sample( post_initialization_infections=SampledValue( value=post_initialization_infections, t_start=self.t_start, - t_unit=self.t_unit) + t_unit=self.t_unit, ), rt=SampledValue(Rt_adj, t_start=self.t_start, t_unit=self.t_unit), ) From 32cb24d9dcf0cab1fcbaef9263b7cffee34f8c71 Mon Sep 17 00:00:00 2001 From: "Dylan H. Morris" Date: Thu, 25 Jul 2024 16:19:08 -0700 Subject: [PATCH 32/41] Fix tutorial bug introduced in merge conflict resolution --- docs/source/tutorials/hospital_admissions_model.qmd | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/tutorials/hospital_admissions_model.qmd b/docs/source/tutorials/hospital_admissions_model.qmd index d1e79888..32bab32d 100644 --- a/docs/source/tutorials/hospital_admissions_model.qmd +++ b/docs/source/tutorials/hospital_admissions_model.qmd @@ -199,7 +199,7 @@ class MyRt(metaclass.RandomVariable): base_rv=process.SimpleRandomWalkProcess( name="log_rt", step_rv=metaclass.DistributionalRV( - name="rw_step_rv", dist=dist.Normal(0, sd_rt) + name="rw_step_rv", dist=dist.Normal(0, sd_rt.value) ), init_rv=metaclass.DistributionalRV( name="init_log_Rt_rv", dist=dist.Normal(0, 0.2) From 43012ad0690ce13da7166b3e3904aed7c7e4c7a6 Mon Sep 17 00:00:00 2001 From: "Dylan H. Morris" Date: Thu, 25 Jul 2024 16:20:40 -0700 Subject: [PATCH 33/41] Equality assertion => almost equality assertion --- model/src/test/test_rtperiodicdiff.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/model/src/test/test_rtperiodicdiff.py b/model/src/test/test_rtperiodicdiff.py index 00ae6e94..ca34dfa4 100644 --- a/model/src/test/test_rtperiodicdiff.py +++ b/model/src/test/test_rtperiodicdiff.py @@ -87,9 +87,9 @@ def test_rtweeklydiff() -> None: # Checking that the shape of the sampled Rt is correct assert rt2.shape == (duration,) - # This time series should be the same as the previous one, but shifted by - # 5 days - assert_array_equal(rt[5:], rt2[:-5]) + # This time series should be the same as the previous one, + # but shifted by 5 days + assert_array_almost_equal(rt[5:], rt2[:-5]) return None From 4173b75b53914de530293de5671648603d0fe524 Mon Sep 17 00:00:00 2001 From: "Dylan H. Morris" Date: Thu, 25 Jul 2024 16:25:45 -0700 Subject: [PATCH 34/41] Set identical seeds --- model/src/test/test_rtperiodicdiff.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/model/src/test/test_rtperiodicdiff.py b/model/src/test/test_rtperiodicdiff.py index ca34dfa4..4fb2cbfb 100644 --- a/model/src/test/test_rtperiodicdiff.py +++ b/model/src/test/test_rtperiodicdiff.py @@ -81,7 +81,7 @@ def test_rtweeklydiff() -> None: params["offset"] = 5 rtwd = RtWeeklyDiffProcess(**params) - with numpyro.handlers.seed(rng_seed=121): + with numpyro.handlers.seed(rng_seed=223): rt2 = rtwd(duration=duration).rt.value # Checking that the shape of the sampled Rt is correct @@ -89,7 +89,7 @@ def test_rtweeklydiff() -> None: # This time series should be the same as the previous one, # but shifted by 5 days - assert_array_almost_equal(rt[5:], rt2[:-5]) + assert_array_equal(rt[5:], rt2[:-5]) return None From f635c0c3a31a6ce0f9380355b70080a4b7af2fa0 Mon Sep 17 00:00:00 2001 From: "Dylan H. Morris" Date: Thu, 25 Jul 2024 16:37:21 -0700 Subject: [PATCH 35/41] Fix another tutorial typo --- docs/source/tutorials/extending_pyrenew.qmd | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/tutorials/extending_pyrenew.qmd b/docs/source/tutorials/extending_pyrenew.qmd index c9dc0da6..664cd3bd 100644 --- a/docs/source/tutorials/extending_pyrenew.qmd +++ b/docs/source/tutorials/extending_pyrenew.qmd @@ -52,7 +52,7 @@ I0 = InfectionInitializationProcess( DistributionalRV(name="I0", dist=dist.LogNormal(0, 1)), InitializeInfectionsExponentialGrowth( gen_int_array.size, - DeterministicVariable(name="rate", value=0.05, name="rate"), + DeterministicVariable(name="rate", value=0.05), ), t_unit=1, ) From ee9dbe220550363142ec307d7d5164f4fa17c8a2 Mon Sep 17 00:00:00 2001 From: "Dylan H. Morris" Date: Fri, 26 Jul 2024 15:28:33 -0400 Subject: [PATCH 36/41] Update model/src/pyrenew/deterministic/process.py Co-authored-by: Damon Bayer --- model/src/pyrenew/deterministic/process.py | 26 +++++++++------------- 1 file changed, 10 insertions(+), 16 deletions(-) diff --git a/model/src/pyrenew/deterministic/process.py b/model/src/pyrenew/deterministic/process.py index 93ff69f2..7ff45b35 100644 --- a/model/src/pyrenew/deterministic/process.py +++ b/model/src/pyrenew/deterministic/process.py @@ -38,20 +38,14 @@ def sample( dif = duration - res.value.shape[0] if dif > 0: - res = ( - SampledValue( - jnp.hstack([res.value, jnp.repeat(res.value[-1], dif)]), - t_start=self.t_start, - t_unit=self.t_unit, - ), - ) + value = jnp.hstack([res.value, jnp.repeat(res.value[-1], dif)]) else: - res = ( - SampledValue( - value=res.value[:duration], - t_start=self.t_start, - t_unit=self.t_unit, - ), - ) - - return res + value = res.value[:duration] + + res = SampledValue( + value, + t_start=self.t_start, + t_unit=self.t_unit, + ) + + return (res,) From 2177215733e501b4492b591fffee0884464bdf9e Mon Sep 17 00:00:00 2001 From: "Dylan H. Morris" Date: Fri, 26 Jul 2024 12:34:15 -0700 Subject: [PATCH 37/41] Fix type hinting for HospitalAdmissionsSample --- model/src/pyrenew/latent/hospitaladmissions.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/model/src/pyrenew/latent/hospitaladmissions.py b/model/src/pyrenew/latent/hospitaladmissions.py index 59e030f1..a6ad5cb1 100644 --- a/model/src/pyrenew/latent/hospitaladmissions.py +++ b/model/src/pyrenew/latent/hospitaladmissions.py @@ -18,13 +18,13 @@ class HospitalAdmissionsSample(NamedTuple): Attributes ---------- - infection_hosp_rate : float, optional + infection_hosp_rate : SampledValue, optional The infection-to-hospitalization rate. Defaults to None. latent_hospital_admissions : SampledValue or None The computed number of hospital admissions. Defaults to None. """ - infection_hosp_rate: float | None = None + infection_hosp_rate: SampledValue | None = None latent_hospital_admissions: SampledValue | None = None def __repr__(self): From 9ff9deaf3a4c5f1fac94c7bf8fa9da8f9f202197 Mon Sep 17 00:00:00 2001 From: "Dylan H. Morris" Date: Fri, 26 Jul 2024 12:37:32 -0700 Subject: [PATCH 38/41] Update determinsiticprocess docstring --- model/src/pyrenew/deterministic/process.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/model/src/pyrenew/deterministic/process.py b/model/src/pyrenew/deterministic/process.py index 7ff45b35..1f9bff53 100644 --- a/model/src/pyrenew/deterministic/process.py +++ b/model/src/pyrenew/deterministic/process.py @@ -29,8 +29,9 @@ def sample( Returns ------- - tuple - Containing the stored values during construction wrapped in a SampledValue. + tuple[SampledValue] + containing the deterministic value(s) provided + at construction as a series of length `duration`. """ res, *_ = super().sample(**kwargs) From b298bc287308ea7edf53682a8924ca35aaaf7877 Mon Sep 17 00:00:00 2001 From: "Dylan H. Morris" Date: Fri, 26 Jul 2024 12:39:54 -0700 Subject: [PATCH 39/41] Update vars => value in docstring for DeterministicVariable --- model/src/pyrenew/deterministic/deterministic.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/model/src/pyrenew/deterministic/deterministic.py b/model/src/pyrenew/deterministic/deterministic.py index 5e2f8f9f..2bb03333 100644 --- a/model/src/pyrenew/deterministic/deterministic.py +++ b/model/src/pyrenew/deterministic/deterministic.py @@ -92,7 +92,7 @@ def sample( tuple[SampledValue] A length-one tuple whose single entry is a :class:`SampledValue` - instance with `value=self.vars`, + instance with `value=self.value`, `t_start=self.t_start`, and `t_unit=self.t_unit`. """ From 859dbc96712a887bf2a5daa98156bebbedbc8dab Mon Sep 17 00:00:00 2001 From: "Dylan H. Morris" Date: Fri, 26 Jul 2024 12:46:40 -0700 Subject: [PATCH 40/41] Clarify relationship between t_start/t_unit of a RandomVariable and of a corresponding SampledValue --- docs/source/tutorials/time.qmd | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/docs/source/tutorials/time.qmd b/docs/source/tutorials/time.qmd index 5698faf2..ac551af4 100644 --- a/docs/source/tutorials/time.qmd +++ b/docs/source/tutorials/time.qmd @@ -10,7 +10,9 @@ The fundamental time unit should represent a period of fixed (or approximately f For many infectious disease renewal models of interest, the fundamental time unit will be days, and we will proceed with this tutorial treating days as our fundamental unit. -`pyrenew` deals with time having `RandomVariable`s carry information about (i) their own time unit expressed relative to the fundamental unit (`t_unit`) and (ii) the starting time, `t_start`, measured relative to `t = 0` in model time in fundamental time units. Return values from `RandomVariable.sample()` are `tuples` or `namedtuple`s of `SampledValue` objects. Each such `SampledValue` is optionally time-aware with specifiable `t_start` and `t_unit` attributes. +`pyrenew` deals with time having `RandomVariable`s carry information about (i) their own time unit expressed relative to the fundamental unit (`t_unit`) and (ii) the starting time, `t_start`, measured relative to `t = 0` in model time in fundamental time units. Return values from `RandomVariable.sample()` are `tuples` or `namedtuple`s of `SampledValue` objects. `SampledValue` objects can have `t_start` and `t_unit` attributes. + +By default, `SampledValue` objects carry the `t_start` and `t_unit` of the `RandomVariable` from which they are `sample()`-d. One might override this default to allow a `RandomVariable.sample()` call to produce multiple `SampledValue`s with different time-units, or with different start-points relative to the `RandomVariable`'s own `t_start`. The `t_unit, t_start` pair can encode different types of time series data. For example: From c1111feaf5a30d33c266231104f44e253f35920f Mon Sep 17 00:00:00 2001 From: "Dylan H. Morris" Date: Fri, 26 Jul 2024 16:03:16 -0400 Subject: [PATCH 41/41] Update docs/source/tutorials/time.qmd Co-authored-by: Damon Bayer --- docs/source/tutorials/time.qmd | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/docs/source/tutorials/time.qmd b/docs/source/tutorials/time.qmd index ac551af4..9a2263fb 100644 --- a/docs/source/tutorials/time.qmd +++ b/docs/source/tutorials/time.qmd @@ -10,7 +10,12 @@ The fundamental time unit should represent a period of fixed (or approximately f For many infectious disease renewal models of interest, the fundamental time unit will be days, and we will proceed with this tutorial treating days as our fundamental unit. -`pyrenew` deals with time having `RandomVariable`s carry information about (i) their own time unit expressed relative to the fundamental unit (`t_unit`) and (ii) the starting time, `t_start`, measured relative to `t = 0` in model time in fundamental time units. Return values from `RandomVariable.sample()` are `tuples` or `namedtuple`s of `SampledValue` objects. `SampledValue` objects can have `t_start` and `t_unit` attributes. +`pyrenew` deals with time by having `RandomVariable`s carry information about + +1. their own time unit expressed relative to the fundamental unit (`t_unit`) and +2. the starting time, `t_start`, measured relative to `t = 0` in model time in fundamental time units. + +Return values from `RandomVariable.sample()` are `tuples` or `namedtuple`s of `SampledValue` objects. `SampledValue` objects can have `t_start` and `t_unit` attributes. By default, `SampledValue` objects carry the `t_start` and `t_unit` of the `RandomVariable` from which they are `sample()`-d. One might override this default to allow a `RandomVariable.sample()` call to produce multiple `SampledValue`s with different time-units, or with different start-points relative to the `RandomVariable`'s own `t_start`.