Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

create randomvariable module #412

Merged
merged 24 commits into from
Aug 26, 2024
Merged
Show file tree
Hide file tree
Changes from 23 commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
65b248e
testing convolve mode
sbidari Aug 20, 2024
8d66ef8
Merge branch 'main' of https://github.com/CDCgov/multisignal-epi-infe…
sbidari Aug 20, 2024
17107c2
Merge branch 'main' into 385-incorrect-convolve-mode-in-hospitaladmis…
sbidari Aug 21, 2024
3a54855
update tutorial to work with convolve mode valid
sbidari Aug 21, 2024
399250f
Merge branch 'main' into 385-incorrect-convolve-mode-in-hospitaladmis…
sbidari Aug 21, 2024
9a7cbb3
update latent admissions test
sbidari Aug 21, 2024
cbff93c
update DOW tutorial for convolve mode valid
sbidari Aug 21, 2024
41b070f
update hosp model tests
sbidari Aug 21, 2024
e9130ce
create helper function for convolve and add tests
sbidari Aug 21, 2024
4737288
forgot to run precommit earlier
sbidari Aug 21, 2024
eb9e168
Merge branch 'main' into 385-incorrect-convolve-mode-in-hospitaladmis…
sbidari Aug 21, 2024
e87d742
update test for model with DOW effect
sbidari Aug 21, 2024
999d124
Merge branch 'main' of https://github.com/CDCgov/PyRenew into 385-inc…
sbidari Aug 22, 2024
b4c5ca2
renaming helper function, add n_initialization_point
sbidari Aug 22, 2024
6840243
Merge branch 'main' of https://github.com/CDCgov/PyRenew into 385-inc…
sbidari Aug 22, 2024
ebc9dd2
create randomvariable module
sbidari Aug 22, 2024
22509cf
make suffixes across variables unifrom
sbidari Aug 22, 2024
da7f428
modify import statements
sbidari Aug 22, 2024
3e4e87d
missed few imports
sbidari Aug 22, 2024
120296f
Merge branch 'main' into sb_create_randomvariable_module
sbidari Aug 22, 2024
547cd7b
Merge branch 'main' of https://github.com/CDCgov/PyRenew into sb_crea…
sbidari Aug 26, 2024
6591dc4
pre-commit changes
sbidari Aug 26, 2024
e1b888a
update metaclass.py
sbidari Aug 26, 2024
b6ea3d5
add randomvariable.rst
sbidari Aug 26, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 7 additions & 10 deletions docs/source/tutorials/basic_renewal_model.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,8 @@ from pyrenew.latent import (
from pyrenew.observation import PoissonObservation
from pyrenew.deterministic import DeterministicPMF
from pyrenew.model import RtInfectionsRenewalModel
from pyrenew.metaclass import (
RandomVariable,
DistributionalRV,
TransformedRandomVariable,
)
from pyrenew.metaclass import RandomVariable
from pyrenew.randomvariable import DistributionalVariable, TransformedVariable
import pyrenew.transformation as t
from numpyro.infer.reparam import LocScaleReparam
```
Expand Down Expand Up @@ -64,7 +61,7 @@ flowchart LR

subgraph latent[Latent module]
inf["latent_infections_rv\n(Infections)"]
i0["I0_rv\n(DistributionalRV)"]
i0["I0_rv\n(DistributionalVariable)"]
end

subgraph process[Process module]
Expand Down Expand Up @@ -126,7 +123,7 @@ gen_int = DeterministicPMF(name="gen_int", value=pmf_array)
# (2) Initial infections (inferred with a prior)
I0 = InfectionInitializationProcess(
"I0_initialization",
DistributionalRV(name="I0", distribution=dist.LogNormal(2.5, 1)),
DistributionalVariable(name="I0", distribution=dist.LogNormal(2.5, 1)),
InitializeInfectionsZeroPad(pmf_array.size),
t_unit=1,
)
Expand All @@ -142,17 +139,17 @@ class MyRt(RandomVariable):
def sample(self, n: int, **kwargs) -> tuple:
sd_rt = numpyro.sample("Rt_random_walk_sd", dist.HalfNormal(0.025))

rt_rv = TransformedRandomVariable(
rt_rv = TransformedVariable(
name="log_rt_random_walk",
base_rv=RandomWalk(
name="log_rt",
step_rv=DistributionalRV(
step_rv=DistributionalVariable(
name="rw_step_rv", distribution=dist.Normal(0, 0.025)
),
),
transforms=t.ExpTransform(),
)
rt_init_rv = DistributionalRV(
rt_init_rv = DistributionalVariable(
name="init_log_rt", distribution=dist.Normal(0, 0.2)
)
init_rt, *_ = rt_init_rv.sample()
Expand Down
24 changes: 12 additions & 12 deletions docs/source/tutorials/day_of_the_week.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -51,15 +51,15 @@ inf_hosp_int_array = inf_hosp_int["probability_mass"].to_numpy()
```{python}
# | label: latent-hosp
# | code-fold: true
from pyrenew import latent, deterministic, metaclass
from pyrenew import latent, deterministic, randomvariable
import jax.numpy as jnp
import numpyro.distributions as dist

inf_hosp_int = deterministic.DeterministicPMF(
name="inf_hosp_int", value=inf_hosp_int_array
)

hosp_rate = metaclass.DistributionalRV(
hosp_rate = randomvariable.DistributionalVariable(
name="IHR", distribution=dist.LogNormal(jnp.log(0.05), jnp.log(1.1))
)

Expand All @@ -81,7 +81,7 @@ n_initialization_points = max(gen_int_array.size, inf_hosp_int_array.size) - 1

I0 = InfectionInitializationProcess(
"I0_initialization",
metaclass.DistributionalRV(
randomvariable.DistributionalVariable(
name="I0",
distribution=dist.LogNormal(loc=jnp.log(100), scale=jnp.log(1.75)),
),
Expand Down Expand Up @@ -113,11 +113,11 @@ class MyRt(metaclass.RandomVariable):
sd_rt, *_ = self.sd_rv()

# Random walk step
step_rv = metaclass.DistributionalRV(
step_rv = randomvariable.DistributionalVariable(
name="rw_step_rv", distribution=dist.Normal(0, sd_rt.value)
)

rt_init_rv = metaclass.DistributionalRV(
rt_init_rv = randomvariable.DistributionalVariable(
name="init_log_rt", distribution=dist.Normal(0, 0.2)
)

Expand All @@ -128,7 +128,7 @@ class MyRt(metaclass.RandomVariable):
)

# Transforming the random walk to the Rt scale
rt_rv = metaclass.TransformedRandomVariable(
rt_rv = randomvariable.TransformedVariable(
name="Rt_rv",
base_rv=base_rv,
transforms=transformation.ExpTransform(),
Expand All @@ -139,7 +139,7 @@ class MyRt(metaclass.RandomVariable):


rtproc = MyRt(
metaclass.DistributionalRV(
randomvariable.DistributionalVariable(
name="Rt_random_walk_sd", distribution=dist.HalfNormal(0.025)
)
)
Expand All @@ -152,9 +152,9 @@ rtproc = MyRt(
# | code-fold: true
# we place a log-Normal prior on the concentration
# parameter of the negative binomial.
nb_conc_rv = metaclass.TransformedRandomVariable(
nb_conc_rv = randomvariable.TransformedVariable(
"concentration",
metaclass.DistributionalRV(
randomvariable.DistributionalVariable(
name="concentration_raw",
distribution=dist.TruncatedNormal(loc=0, scale=1, low=0.01),
),
Expand Down Expand Up @@ -212,16 +212,16 @@ out = hosp_model.plot_posterior(

We will re-use the infection to admission interval and infection to hospitalization rate from the previous model. But we will also add a day-of-the-week effect. To do this, we will add two additional arguments to the latent hospital admissions random variable: `day_of_the_week_rv` (a `RandomVariable`) and `obs_data_first_day_of_the_week` (an `int` mapping days of the week from 0:6, zero being Monday). The `day_of_the_week_rv`'s sample method should return a vector of length seven; those values are then broadcasted to match the length of the dataset. Moreover, since the observed data may start in a weekday other than Monday, the `obs_data_first_day_of_the_week` argument is used to offset the day-of-the-week effect.

For this example, the effect will be passed as a scaled Dirichlet distribution. It will consist of a `TransformedRandomVariable` that samples an array of length seven from numpyro's `distributions.Dirichlet` and applies a `transformation.AffineTransform` to scale it by seven. [^note-other-examples]:
For this example, the effect will be passed as a scaled Dirichlet distribution. It will consist of a `TransformedVariable` that samples an array of length seven from numpyro's `distributions.Dirichlet` and applies a `transformation.AffineTransform` to scale it by seven. [^note-other-examples]:

[^note-other-examples]: A similar weekday effect is implemented in its own module, with example code [here](periodic_effects.html).

```{python}
# | label: weekly-effect
# Instantiating the day-of-the-week effect
dayofweek_effect = metaclass.TransformedRandomVariable(
dayofweek_effect = randomvariable.TransformedVariable(
name="dayofweek_effect",
base_rv=metaclass.DistributionalRV(
base_rv=randomvariable.DistributionalVariable(
name="dayofweek_effect_raw",
distribution=dist.Dirichlet(jnp.ones(7)),
),
Expand Down
15 changes: 6 additions & 9 deletions docs/source/tutorials/extending_pyrenew.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,8 @@ from pyrenew.deterministic import DeterministicPMF, DeterministicVariable
from pyrenew.latent import InfectionsWithFeedback
from pyrenew.model import RtInfectionsRenewalModel
from pyrenew.process import RandomWalk
from pyrenew.metaclass import (
RandomVariable,
DistributionalRV,
TransformedRandomVariable,
)
from pyrenew.metaclass import RandomVariable
from pyrenew.randomvariable import DistributionalVariable, TransformedVariable
from pyrenew.latent import (
InfectionInitializationProcess,
InitializeInfectionsExponentialGrowth,
Expand All @@ -53,7 +50,7 @@ feedback_strength = DeterministicVariable(name="feedback_strength", value=0.01)

I0 = InfectionInitializationProcess(
"I0_initialization",
DistributionalRV(name="I0", distribution=dist.LogNormal(0, 1)),
DistributionalVariable(name="I0", distribution=dist.LogNormal(0, 1)),
InitializeInfectionsExponentialGrowth(
gen_int_array.size,
DeterministicVariable(name="rate", value=0.05),
Expand All @@ -75,17 +72,17 @@ class MyRt(RandomVariable):
def sample(self, n: int, **kwargs) -> tuple:
sd_rt = numpyro.sample("Rt_random_walk_sd", dist.HalfNormal(0.025))

rt_rv = TransformedRandomVariable(
rt_rv = TransformedVariable(
name="log_rt_random_walk",
base_rv=RandomWalk(
name="log_rt",
step_rv=DistributionalRV(
step_rv=DistributionalVariable(
name="rw_step_rv", distribution=dist.Normal(0, 0.025)
),
),
transforms=t.ExpTransform(),
)
rt_init_rv = DistributionalRV(
rt_init_rv = DistributionalVariable(
name="init_log_rt", distribution=dist.Normal(0, 0.2)
)
init_rt, *_ = rt_init_rv.sample()
Expand Down
18 changes: 9 additions & 9 deletions docs/source/tutorials/hospital_admissions_model.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -137,15 +137,15 @@ With these two in hand, we can start building the model. First, we will define t

```{python}
# | label: latent-hosp
from pyrenew import latent, deterministic, metaclass
from pyrenew import latent, deterministic, metaclass, randomvariable
import jax.numpy as jnp
import numpyro.distributions as dist

inf_hosp_int = deterministic.DeterministicPMF(
name="inf_hosp_int", value=inf_hosp_int_array
)

hosp_rate = metaclass.DistributionalRV(
hosp_rate = randomvariable.DistributionalVariable(
name="IHR", distribution=dist.LogNormal(jnp.log(0.05), jnp.log(1.1))
)

Expand All @@ -155,7 +155,7 @@ latent_hosp = latent.HospitalAdmissions(
)
```

The `inf_hosp_int` is a `DeterministicPMF` object that takes the infection to hospital admission interval as input. The `hosp_rate` is a `DistributionalRV` object that takes a numpyro distribution to represent the infection to hospital admission rate. The `HospitalAdmissions` class is a `RandomVariable` that takes two distributions as inputs: the infection to admission interval and the infection to hospital admission rate. Now, we can define the rest of the other components:
The `inf_hosp_int` is a `DeterministicPMF` object that takes the infection to hospital admission interval as input. The `hosp_rate` is a `DistributionalVariable` object that takes a numpyro distribution to represent the infection to hospital admission rate. The `HospitalAdmissions` class is a `RandomVariable` that takes two distributions as inputs: the infection to admission interval and the infection to hospital admission rate. Now, we can define the rest of the other components:

```{python}
# | label: initializing-rest-of-model
Expand All @@ -171,7 +171,7 @@ latent_inf = latent.Infections()
n_initialization_points = max(gen_int_array.size, inf_hosp_int_array.size) - 1
I0 = InfectionInitializationProcess(
"I0_initialization",
metaclass.DistributionalRV(
randomvariable.DistributionalVariable(
name="I0",
distribution=dist.LogNormal(loc=jnp.log(100), scale=jnp.log(1.75)),
),
Expand All @@ -194,17 +194,17 @@ class MyRt(metaclass.RandomVariable):
def sample(self, n: int, **kwargs) -> tuple:
sd_rt = numpyro.sample("Rt_random_walk_sd", dist.HalfNormal(0.025))

rt_rv = metaclass.TransformedRandomVariable(
rt_rv = randomvariable.TransformedVariable(
name="log_rt_random_walk",
base_rv=process.RandomWalk(
name="log_rt",
step_rv=metaclass.DistributionalRV(
step_rv=randomvariable.DistributionalVariable(
name="rw_step_rv", distribution=dist.Normal(0, 0.025)
),
),
transforms=transformation.ExpTransform(),
)
rt_init_rv = metaclass.DistributionalRV(
rt_init_rv = randomvariable.DistributionalVariable(
name="init_log_rt", distribution=dist.Normal(0, 0.2)
)
init_rt, *_ = rt_init_rv.sample()
Expand All @@ -218,9 +218,9 @@ rtproc = MyRt()

# we place a log-Normal prior on the concentration
# parameter of the negative binomial.
nb_conc_rv = metaclass.TransformedRandomVariable(
nb_conc_rv = randomvariable.TransformedVariable(
"concentration",
metaclass.DistributionalRV(
randomvariable.DistributionalVariable(
name="concentration_raw",
distribution=dist.TruncatedNormal(loc=0, scale=1, low=0.01),
),
Expand Down
4 changes: 2 additions & 2 deletions docs/source/tutorials/periodic_effects.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ The `PeriodicBroadcaster` class can also be used to repeat a sequence as a whole

```{python}
import numpyro.distributions as dist
from pyrenew import transformation, metaclass
from pyrenew import transformation, randomvariable

# Building the transformed prior: Dirichlet * 7
mysimplex = dist.TransformedDistribution(
Expand All @@ -76,7 +76,7 @@ mysimplex = dist.TransformedDistribution(
# Constructing the day of week effect
dayofweek = process.DayOfWeekEffect(
offset=0,
quantity_to_broadcast=metaclass.DistributionalRV(
quantity_to_broadcast=randomvariable.DistributionalVariable(
name="simp", distribution=mysimplex
),
t_start=0,
Expand Down
Loading