Skip to content

Commit

Permalink
fix rt extraction across chains; fix coloring of plots
Browse files Browse the repository at this point in the history
  • Loading branch information
AFg6K7h4fhy2 committed Oct 3, 2024
1 parent 2073c45 commit b902c13
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 34 deletions.
30 changes: 12 additions & 18 deletions docs/source/tutorials/basic_renewal_model.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -254,34 +254,28 @@ Now, let's investigate the output, particularly the posterior distribution of th
# | fig-cap: Rt posterior distribution
import arviz as az
az.style.use("arviz-doc")
# Create arviz inference data object
idata = az.from_numpyro(
posterior=model1.mcmc,
)
idata = az.extract(idata, num_samples=100)
# Extract Rt signal
# Extract Rt signal samples across chains
rt = az.extract(idata.posterior["Rt"], num_samples=100)["Rt"].values
rt = (
az.extract(idata.posterior["Rt"], num_samples=100)
.to_dataarray()
.squeeze()
.T
)
# Plot Rt signal
fig, ax = plt.subplots(1, 1, figsize=(8, 6))
print(rt)
ax.plot(
np.arange(100),
np.arange(rt.shape[0]),
rt,
color="gray",
alpha=0.25,
color="skyblue",
alpha=0.10,
)
ax.plot([], [], color="gray", alpha=0.25, label="Rt Posterior Samples")
ax.plot([], [], color="skyblue", alpha=0.05, label="Rt Posterior Samples")
ax.plot(
np.arange(100),
rt.mean(dim="Rt_dim_0"),
color="red",
np.arange(rt.shape[0]),
rt.mean(axis=1),
color="black",
linewidth=2.0,
linestyle="--",
label="Sample Mean",
Expand Down
36 changes: 24 additions & 12 deletions docs/source/tutorials/day_of_the_week.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -218,17 +218,19 @@ idata = az.from_numpyro(
posterior_predictive=ppc_samples,
)
# Convert hospital admissions from discrete to continuous
# Use a time series plot (plot_ts) from arviz for plotting
axes = az.plot_ts(
idata,
y="negbinom_rv",
y_hat="negbinom_rv",
num_samples=150,
y_kwargs={"color": "red", "linewidth": 2.5},
y_hat_plot_kwargs={"color": "gray"},
num_samples=200,
y_kwargs={
"color": "blue",
"linewidth": 1.0,
"marker": "o",
"linestyle": "solid",
},
y_hat_plot_kwargs={"color": "skyblue", "alpha": 0.05},
y_mean_plot_kwargs={"color": "black", "linestyle": "--", "linewidth": 2.5},
backend_kwargs={"figsize": (8, 6)},
textsize=15.0,
Expand Down Expand Up @@ -361,9 +363,14 @@ axes = az.plot_ts(
idata,
y="negbinom_rv",
y_hat="negbinom_rv",
num_samples=150,
y_kwargs={"color": "red", "linewidth": 2.5},
y_hat_plot_kwargs={"color": "gray"},
num_samples=200,
y_kwargs={
"color": "blue",
"linewidth": 1.0,
"marker": "o",
"linestyle": "solid",
},
y_hat_plot_kwargs={"color": "skyblue", "alpha": 0.05},
y_mean_plot_kwargs={"color": "black", "linestyle": "--", "linewidth": 2.5},
backend_kwargs={"figsize": (8, 6)},
textsize=15.0,
Expand Down Expand Up @@ -396,9 +403,14 @@ axes = az.plot_ts(
idata,
y="negbinom_rv",
y_hat="negbinom_rv",
num_samples=150,
y_kwargs={"color": "red", "linewidth": 2.5},
y_hat_plot_kwargs={"color": "gray"},
num_samples=200,
y_kwargs={
"color": "blue",
"linewidth": 1.0,
"marker": "o",
"linestyle": "solid",
},
y_hat_plot_kwargs={"color": "skyblue", "alpha": 0.05},
y_mean_plot_kwargs={"color": "black", "linestyle": "--", "linewidth": 2.5},
backend_kwargs={"figsize": (8, 6)},
textsize=15.0,
Expand Down
11 changes: 7 additions & 4 deletions docs/source/tutorials/hospital_admissions_model.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -319,8 +319,13 @@ axes = az.plot_ts(
y="negbinom_rv",
y_hat="negbinom_rv",
num_samples=200,
y_kwargs={"color": "red", "linewidth": 2.5},
y_hat_plot_kwargs={"color": "gray"},
y_kwargs={
"color": "blue",
"linewidth": 1.0,
"marker": "o",
"linestyle": "solid",
},
y_hat_plot_kwargs={"color": "skyblue", "alpha": 0.05},
y_mean_plot_kwargs={"color": "black", "linestyle": "--", "linewidth": 2.5},
backend_kwargs={"figsize": (8, 6)},
textsize=15.0,
Expand All @@ -341,7 +346,6 @@ To explore further, We can use [ArviZ](https://www.arviz.org/) to visualize the

```{python}
# | label: convert-inferenceData
import arviz as az
idata = az.from_numpyro(hosp_model.mcmc)
```
Expand Down Expand Up @@ -443,7 +447,6 @@ We can use the `Model`'s `posterior_predictive` and `prior_predictive` methods t

```{python}
# | label: demonstrate-use-of-predictive-methods
import arviz as az
idata = az.from_numpyro(
hosp_model.mcmc,
Expand Down

0 comments on commit b902c13

Please sign in to comment.