Skip to content

Commit

Permalink
Further Prune PostProcessing Code, Specifically plot_posterior And sp…
Browse files Browse the repository at this point in the history
…read_draws (#431)

* remove mcmcutils file; remove spread from tutorials

* remove instances of spread draws

* adding first tutorial edited w/ arviz, still need to get proper plot_ppc

* further plot updates; difficult to imagine plot_ppc will get equivalent of plot_plosterior

* another edit of a less arviz-y plot

* post DHM meet edit, switch to plot_ts

* day of week effect tutorial ready

* modify remaining tutorials

* remove mcmcutils

* update pyprojec.toml

* chain together dow_effect_raw lines

* remove float casting for idata

* fix rt extraction across chains; fix coloring of plots

* final coloring fix

---------

Co-authored-by: Dylan H. Morris <[email protected]>
  • Loading branch information
AFg6K7h4fhy2 and dylanhmorris authored Oct 3, 2024
1 parent 20df5b3 commit 9a996a3
Show file tree
Hide file tree
Showing 8 changed files with 199 additions and 311 deletions.
33 changes: 32 additions & 1 deletion docs/source/tutorials/basic_renewal_model.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,38 @@ Now, let's investigate the output, particularly the posterior distribution of th
```{python}
# | label: fig-output-rt
# | fig-cap: Rt posterior distribution
out = model1.plot_posterior(var="Rt")
import arviz as az
# Create arviz inference data object
idata = az.from_numpyro(
posterior=model1.mcmc,
)
# Extract Rt signal samples across chains
rt = az.extract(idata.posterior["Rt"], num_samples=100)["Rt"].values
# Plot Rt signal
fig, ax = plt.subplots(1, 1, figsize=(8, 6))
ax.plot(
np.arange(rt.shape[0]),
rt,
color="skyblue",
alpha=0.10,
)
ax.plot([], [], color="skyblue", alpha=0.05, label="Rt Posterior Samples")
ax.plot(
np.arange(rt.shape[0]),
rt.mean(axis=1),
color="black",
linewidth=2.0,
linestyle="--",
label="Sample Mean",
)
ax.legend(loc="best")
ax.set_ylabel(r"$\mathscr{R}_t$ Signal", fontsize=20)
ax.set_xlabel("Days", fontsize=20)
plt.show()
```
We can use the `get_samples` method to extract samples from the model
```{python}
Expand Down
149 changes: 131 additions & 18 deletions docs/source/tutorials/day_of_the_week.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,7 @@ obs = observation.NegativeBinomialObservation(

```{python}
# | label: init-model
# | code-fold: true
hosp_model = model.HospitalAdmissionsModel(
latent_infections_rv=latent_inf,
latent_hosp_admissions_rv=latent_hosp,
Expand All @@ -186,6 +187,7 @@ Here is what the model looks like without the day-of-the-week effect:
```{python}
# | label: fig-output-admissions-padding-and-weekday
# | fig-cap: Hospital Admissions posterior distribution without weekday effect
# | code-fold: true
import jax
import numpy as np
Expand All @@ -197,16 +199,56 @@ hosp_model.run(
rng_key=jax.random.key(54),
mcmc_args=dict(progress_bar=False),
)
```

```{python}
# | code-fold: true
import arviz as az
import matplotlib.pyplot as plt
# Retrieve the posterior samples from the model
ppc_samples = hosp_model.posterior_predictive(
n_datapoints=daily_hosp_admits.size
)
# Create an InferenceData object from model
idata = az.from_numpyro(
posterior=hosp_model.mcmc,
posterior_predictive=ppc_samples,
)
# Plotting the posterior
out = hosp_model.plot_posterior(
var="latent_hospital_admissions",
ylab="Hospital Admissions",
obs_signal=daily_hosp_admits.astype(float),
# Use a time series plot (plot_ts) from arviz for plotting
axes = az.plot_ts(
idata,
y="negbinom_rv",
y_hat="negbinom_rv",
num_samples=200,
y_kwargs={
"color": "blue",
"linewidth": 1.0,
"marker": "o",
"linestyle": "solid",
},
y_hat_plot_kwargs={"color": "skyblue", "alpha": 0.05},
y_mean_plot_kwargs={"color": "black", "linestyle": "--", "linewidth": 2.5},
backend_kwargs={"figsize": (8, 6)},
textsize=15.0,
)
ax = axes[0][0]
ax.set_xlabel("Time", fontsize=20)
ax.set_ylabel("Hospital Admissions", fontsize=20)
handles, labels = ax.get_legend_handles_labels()
ax.legend(
handles, ["Observed", "Sample Mean", "Posterior Samples"], loc="best"
)
plt.show()
```





## Round 2: Incorporating day-of-the-week effects

We will re-use the infection to admission interval and infection to hospitalization rate from the previous model. But we will also add a day-of-the-week effect. To do this, we will add two additional arguments to the latent hospital admissions random variable: `day_of_the_week_rv` (a `RandomVariable`) and `obs_data_first_day_of_the_week` (an `int` mapping days of the week from 0:6, zero being Monday). The `day_of_the_week_rv`'s sample method should return a vector of length seven; those values are then broadcasted to match the length of the dataset. Moreover, since the observed data may start in a weekday other than Monday, the `obs_data_first_day_of_the_week` argument is used to offset the day-of-the-week effect.
Expand Down Expand Up @@ -280,34 +322,105 @@ As a result, we can see the posterior distribution of our novel day-of-the-week
```{python}
# | label: fig-output-day-of-week
# | fig-cap: Day of the week effect
out = hosp_model_dow.plot_posterior(
var="dayofweek_effect_raw", ylab="Day of the Week Effect", samples=500
# Create an InferenceData object from hosp_model_dow
dow_idata = az.from_numpyro(
posterior=hosp_model_dow.mcmc,
)
sp = hosp_model_dow.spread_draws(["dayofweek_effect_raw"])
# dayofweek_effect is not recorded
# Extract day of week effect (DOW)
dow_effect_raw = dow_idata.posterior["dayofweek_effect_raw"].squeeze().T
indices = np.random.choice(dow_effect_raw.shape[1], size=200, replace=False)
dow_plot_samples = dow_effect_raw[:, indices]
fig, ax = plt.subplots(1, 1, figsize=(8, 6))
ax.plot(
np.arange(dow_effect_raw.shape[0]),
dow_plot_samples,
color="skyblue",
alpha=0.10,
)
ax.plot([], [], color="skyblue", alpha=0.10, label="DOW Posterior Samples")
ax.plot(
np.arange(dow_effect_raw.shape[0]),
dow_plot_samples.mean(dim="draw"),
color="black",
linewidth=2.0,
linestyle="--",
label="Sample Mean",
)
ax.legend(loc="best")
ax.set_ylabel("Effect", fontsize=20)
ax.set_xlabel("Day Of Week", fontsize=20)
plt.show()
```

The new model with the day-of-the-week effect can be compared to the previous model without the effect. Finally, let's reproduce the figure without the day-of-the-week effect, and then plot the new model with the effect:

```{python}
# | label: fig-output-admissions-original
# | fig-cap: Hospital Admissions posterior distribution without weekday effect
# Figure without weekday effect
out = hosp_model.plot_posterior(
var="latent_hospital_admissions",
ylab="Hospital Admissions",
obs_signal=daily_hosp_admits.astype(float),
# Without weekday effect (from earlier)
axes = az.plot_ts(
idata,
y="negbinom_rv",
y_hat="negbinom_rv",
num_samples=200,
y_kwargs={
"color": "blue",
"linewidth": 1.0,
"marker": "o",
"linestyle": "solid",
},
y_hat_plot_kwargs={"color": "skyblue", "alpha": 0.05},
y_mean_plot_kwargs={"color": "black", "linestyle": "--", "linewidth": 2.5},
backend_kwargs={"figsize": (8, 6)},
textsize=15.0,
)
ax = axes[0][0]
ax.set_xlabel("Time", fontsize=20)
ax.set_ylabel("Hospital Admissions", fontsize=20)
handles, labels = ax.get_legend_handles_labels()
ax.legend(
handles,
["Observed", "Posterior Predictive", "Samples wo/ WDE"],
loc="best",
)
plt.show()
```

```{python}
# | label: fig-output-admissions-wof
# | fig-cap: Hospital Admissions posterior distribution with weekday effect
# Figure with weekday effect
out_dow = hosp_model_dow.plot_posterior(
var="latent_hospital_admissions",
ylab="Hospital Admissions",
obs_signal=daily_hosp_admits.astype(float),
ppc_samples = hosp_model_dow.posterior_predictive(
n_datapoints=daily_hosp_admits.size
)
idata = az.from_numpyro(
posterior=hosp_model_dow.mcmc,
posterior_predictive=ppc_samples,
)
axes = az.plot_ts(
idata,
y="negbinom_rv",
y_hat="negbinom_rv",
num_samples=200,
y_kwargs={
"color": "blue",
"linewidth": 1.0,
"marker": "o",
"linestyle": "solid",
},
y_hat_plot_kwargs={"color": "skyblue", "alpha": 0.05},
y_mean_plot_kwargs={"color": "black", "linestyle": "--", "linewidth": 2.5},
backend_kwargs={"figsize": (8, 6)},
textsize=15.0,
)
ax = axes[0][0]
ax.set_xlabel("Time", fontsize=20)
ax.set_ylabel("Hospital Admissions", fontsize=20)
handles, labels = ax.get_legend_handles_labels()
ax.legend(
handles, ["Observed", "Posterior Predictive", "Samples w/ WDE"], loc="best"
)
plt.show()
```
43 changes: 35 additions & 8 deletions docs/source/tutorials/hospital_admissions_model.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -297,18 +297,47 @@ hosp_model.run(
)
```

We can use the `Model` object's `plot_posterior` method to visualize the model fit. Here, we plot the observed values against the inferred latent values (i.e. the mean of the negative binomial observation process)[^capture]:
We can use `arviz` to visualize the model fit. Here, we plot the observed values against the inferred latent values (i.e. the mean of the negative binomial observation process)[^capture]:

[^capture]: The output is captured to avoid `quarto` from displaying the output twice.

```{python}
# | label: fig-output-hospital-admissions
# | fig-cap: Latent hospital admissions posterior samples (blue) and observed admissions timeseries (black).
out = hosp_model.plot_posterior(
var="latent_hospital_admissions",
ylab="Hospital Admissions",
obs_signal=daily_hosp_admits.astype(float),
# | fig-cap: Latent hospital admissions posterior samples (gray) and observed admissions timeseries (red).
import arviz as az
ppc_samples = hosp_model.posterior_predictive(
n_datapoints=daily_hosp_admits.size
)
idata = az.from_numpyro(
posterior=hosp_model.mcmc,
posterior_predictive=ppc_samples,
)
axes = az.plot_ts(
idata,
y="negbinom_rv",
y_hat="negbinom_rv",
num_samples=200,
y_kwargs={
"color": "blue",
"linewidth": 1.0,
"marker": "o",
"linestyle": "solid",
},
y_hat_plot_kwargs={"color": "skyblue", "alpha": 0.05},
y_mean_plot_kwargs={"color": "black", "linestyle": "--", "linewidth": 2.5},
backend_kwargs={"figsize": (8, 6)},
textsize=15.0,
)
ax = axes[0][0]
ax.set_xlabel("Time", fontsize=20)
ax.set_ylabel("Hospital Admissions", fontsize=20)
handles, labels = ax.get_legend_handles_labels()
ax.legend(
handles, ["Observed", "Sample Mean", "Posterior Samples"], loc="best"
)
plt.show()
```

## Results exploration and MCMC diagnostics
Expand All @@ -317,7 +346,6 @@ To explore further, We can use [ArviZ](https://www.arviz.org/) to visualize the

```{python}
# | label: convert-inferenceData
import arviz as az
idata = az.from_numpyro(hosp_model.mcmc)
```
Expand Down Expand Up @@ -419,7 +447,6 @@ We can use the `Model`'s `posterior_predictive` and `prior_predictive` methods t

```{python}
# | label: demonstrate-use-of-predictive-methods
import arviz as az
idata = az.from_numpyro(
hosp_model.mcmc,
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ python = "^3.12"
jax = ">=0.4.30"
numpy = "^2.0.0"
polars = "^1.2.1"
matplotlib = "^3.8.3"
numpyro = ">=0.15.3"

[tool.poetry.group.dev]
Expand All @@ -30,6 +29,7 @@ deptry = "^0.17.0"
optional = true

[tool.poetry.group.docs.dependencies]
matplotlib = "^3.8.3"
ipykernel = "^6.29.3"
pyyaml = "^6.0.0"
nbclient = "^0.10.0"
Expand Down
Loading

0 comments on commit 9a996a3

Please sign in to comment.