From 9a996a34492483e88c8792be7956146265138152 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?upx3=E2=80=94TM=20=28CFA=29?= <127630341+AFg6K7h4fhy2@users.noreply.github.com> Date: Thu, 3 Oct 2024 15:28:31 -0400 Subject: [PATCH] Further Prune PostProcessing Code, Specifically plot_posterior And spread_draws (#431) * remove mcmcutils file; remove spread from tutorials * remove instances of spread draws * adding first tutorial edited w/ arviz, still need to get proper plot_ppc * further plot updates; difficult to imagine plot_ppc will get equivalent of plot_plosterior * another edit of a less arviz-y plot * post DHM meet edit, switch to plot_ts * day of week effect tutorial ready * modify remaining tutorials * remove mcmcutils * update pyprojec.toml * chain together dow_effect_raw lines * remove float casting for idata * fix rt extraction across chains; fix coloring of plots * final coloring fix --------- Co-authored-by: Dylan H. Morris --- docs/source/tutorials/basic_renewal_model.qmd | 33 +++- docs/source/tutorials/day_of_the_week.qmd | 149 +++++++++++++++-- .../tutorials/hospital_admissions_model.qmd | 43 ++++- pyproject.toml | 2 +- pyrenew/mcmcutils.py | 158 ------------------ pyrenew/metaclass.py | 45 ----- test/test_model_basic_renewal.py | 35 ---- test/test_model_hosp_admissions.py | 45 ----- 8 files changed, 199 insertions(+), 311 deletions(-) delete mode 100644 pyrenew/mcmcutils.py diff --git a/docs/source/tutorials/basic_renewal_model.qmd b/docs/source/tutorials/basic_renewal_model.qmd index f9100da6..b5fff380 100644 --- a/docs/source/tutorials/basic_renewal_model.qmd +++ b/docs/source/tutorials/basic_renewal_model.qmd @@ -252,7 +252,38 @@ Now, let's investigate the output, particularly the posterior distribution of th ```{python} # | label: fig-output-rt # | fig-cap: Rt posterior distribution -out = model1.plot_posterior(var="Rt") +import arviz as az + +# Create arviz inference data object +idata = az.from_numpyro( + posterior=model1.mcmc, +) + +# Extract Rt signal samples across chains +rt = az.extract(idata.posterior["Rt"], num_samples=100)["Rt"].values + + +# Plot Rt signal +fig, ax = plt.subplots(1, 1, figsize=(8, 6)) +ax.plot( + np.arange(rt.shape[0]), + rt, + color="skyblue", + alpha=0.10, +) +ax.plot([], [], color="skyblue", alpha=0.05, label="Rt Posterior Samples") +ax.plot( + np.arange(rt.shape[0]), + rt.mean(axis=1), + color="black", + linewidth=2.0, + linestyle="--", + label="Sample Mean", +) +ax.legend(loc="best") +ax.set_ylabel(r"$\mathscr{R}_t$ Signal", fontsize=20) +ax.set_xlabel("Days", fontsize=20) +plt.show() ``` We can use the `get_samples` method to extract samples from the model ```{python} diff --git a/docs/source/tutorials/day_of_the_week.qmd b/docs/source/tutorials/day_of_the_week.qmd index 9d789b98..f95fd9a2 100644 --- a/docs/source/tutorials/day_of_the_week.qmd +++ b/docs/source/tutorials/day_of_the_week.qmd @@ -171,6 +171,7 @@ obs = observation.NegativeBinomialObservation( ```{python} # | label: init-model +# | code-fold: true hosp_model = model.HospitalAdmissionsModel( latent_infections_rv=latent_inf, latent_hosp_admissions_rv=latent_hosp, @@ -186,6 +187,7 @@ Here is what the model looks like without the day-of-the-week effect: ```{python} # | label: fig-output-admissions-padding-and-weekday # | fig-cap: Hospital Admissions posterior distribution without weekday effect +# | code-fold: true import jax import numpy as np @@ -197,16 +199,56 @@ hosp_model.run( rng_key=jax.random.key(54), mcmc_args=dict(progress_bar=False), ) +``` + +```{python} +# | code-fold: true +import arviz as az +import matplotlib.pyplot as plt + + +# Retrieve the posterior samples from the model +ppc_samples = hosp_model.posterior_predictive( + n_datapoints=daily_hosp_admits.size +) + +# Create an InferenceData object from model +idata = az.from_numpyro( + posterior=hosp_model.mcmc, + posterior_predictive=ppc_samples, +) -# Plotting the posterior -out = hosp_model.plot_posterior( - var="latent_hospital_admissions", - ylab="Hospital Admissions", - obs_signal=daily_hosp_admits.astype(float), +# Use a time series plot (plot_ts) from arviz for plotting +axes = az.plot_ts( + idata, + y="negbinom_rv", + y_hat="negbinom_rv", + num_samples=200, + y_kwargs={ + "color": "blue", + "linewidth": 1.0, + "marker": "o", + "linestyle": "solid", + }, + y_hat_plot_kwargs={"color": "skyblue", "alpha": 0.05}, + y_mean_plot_kwargs={"color": "black", "linestyle": "--", "linewidth": 2.5}, + backend_kwargs={"figsize": (8, 6)}, + textsize=15.0, ) +ax = axes[0][0] +ax.set_xlabel("Time", fontsize=20) +ax.set_ylabel("Hospital Admissions", fontsize=20) +handles, labels = ax.get_legend_handles_labels() +ax.legend( + handles, ["Observed", "Sample Mean", "Posterior Samples"], loc="best" +) +plt.show() ``` + + + ## Round 2: Incorporating day-of-the-week effects We will re-use the infection to admission interval and infection to hospitalization rate from the previous model. But we will also add a day-of-the-week effect. To do this, we will add two additional arguments to the latent hospital admissions random variable: `day_of_the_week_rv` (a `RandomVariable`) and `obs_data_first_day_of_the_week` (an `int` mapping days of the week from 0:6, zero being Monday). The `day_of_the_week_rv`'s sample method should return a vector of length seven; those values are then broadcasted to match the length of the dataset. Moreover, since the observed data may start in a weekday other than Monday, the `obs_data_first_day_of_the_week` argument is used to offset the day-of-the-week effect. @@ -280,12 +322,35 @@ As a result, we can see the posterior distribution of our novel day-of-the-week ```{python} # | label: fig-output-day-of-week # | fig-cap: Day of the week effect -out = hosp_model_dow.plot_posterior( - var="dayofweek_effect_raw", ylab="Day of the Week Effect", samples=500 +# Create an InferenceData object from hosp_model_dow +dow_idata = az.from_numpyro( + posterior=hosp_model_dow.mcmc, ) -sp = hosp_model_dow.spread_draws(["dayofweek_effect_raw"]) -# dayofweek_effect is not recorded +# Extract day of week effect (DOW) +dow_effect_raw = dow_idata.posterior["dayofweek_effect_raw"].squeeze().T +indices = np.random.choice(dow_effect_raw.shape[1], size=200, replace=False) +dow_plot_samples = dow_effect_raw[:, indices] +fig, ax = plt.subplots(1, 1, figsize=(8, 6)) +ax.plot( + np.arange(dow_effect_raw.shape[0]), + dow_plot_samples, + color="skyblue", + alpha=0.10, +) +ax.plot([], [], color="skyblue", alpha=0.10, label="DOW Posterior Samples") +ax.plot( + np.arange(dow_effect_raw.shape[0]), + dow_plot_samples.mean(dim="draw"), + color="black", + linewidth=2.0, + linestyle="--", + label="Sample Mean", +) +ax.legend(loc="best") +ax.set_ylabel("Effect", fontsize=20) +ax.set_xlabel("Day Of Week", fontsize=20) +plt.show() ``` The new model with the day-of-the-week effect can be compared to the previous model without the effect. Finally, let's reproduce the figure without the day-of-the-week effect, and then plot the new model with the effect: @@ -293,21 +358,69 @@ The new model with the day-of-the-week effect can be compared to the previous mo ```{python} # | label: fig-output-admissions-original # | fig-cap: Hospital Admissions posterior distribution without weekday effect -# Figure without weekday effect -out = hosp_model.plot_posterior( - var="latent_hospital_admissions", - ylab="Hospital Admissions", - obs_signal=daily_hosp_admits.astype(float), +# Without weekday effect (from earlier) +axes = az.plot_ts( + idata, + y="negbinom_rv", + y_hat="negbinom_rv", + num_samples=200, + y_kwargs={ + "color": "blue", + "linewidth": 1.0, + "marker": "o", + "linestyle": "solid", + }, + y_hat_plot_kwargs={"color": "skyblue", "alpha": 0.05}, + y_mean_plot_kwargs={"color": "black", "linestyle": "--", "linewidth": 2.5}, + backend_kwargs={"figsize": (8, 6)}, + textsize=15.0, ) +ax = axes[0][0] +ax.set_xlabel("Time", fontsize=20) +ax.set_ylabel("Hospital Admissions", fontsize=20) +handles, labels = ax.get_legend_handles_labels() +ax.legend( + handles, + ["Observed", "Posterior Predictive", "Samples wo/ WDE"], + loc="best", +) +plt.show() ``` ```{python} # | label: fig-output-admissions-wof # | fig-cap: Hospital Admissions posterior distribution with weekday effect # Figure with weekday effect -out_dow = hosp_model_dow.plot_posterior( - var="latent_hospital_admissions", - ylab="Hospital Admissions", - obs_signal=daily_hosp_admits.astype(float), +ppc_samples = hosp_model_dow.posterior_predictive( + n_datapoints=daily_hosp_admits.size +) +idata = az.from_numpyro( + posterior=hosp_model_dow.mcmc, + posterior_predictive=ppc_samples, +) + +axes = az.plot_ts( + idata, + y="negbinom_rv", + y_hat="negbinom_rv", + num_samples=200, + y_kwargs={ + "color": "blue", + "linewidth": 1.0, + "marker": "o", + "linestyle": "solid", + }, + y_hat_plot_kwargs={"color": "skyblue", "alpha": 0.05}, + y_mean_plot_kwargs={"color": "black", "linestyle": "--", "linewidth": 2.5}, + backend_kwargs={"figsize": (8, 6)}, + textsize=15.0, +) +ax = axes[0][0] +ax.set_xlabel("Time", fontsize=20) +ax.set_ylabel("Hospital Admissions", fontsize=20) +handles, labels = ax.get_legend_handles_labels() +ax.legend( + handles, ["Observed", "Posterior Predictive", "Samples w/ WDE"], loc="best" ) +plt.show() ``` diff --git a/docs/source/tutorials/hospital_admissions_model.qmd b/docs/source/tutorials/hospital_admissions_model.qmd index bf11bf3a..0acebc6b 100644 --- a/docs/source/tutorials/hospital_admissions_model.qmd +++ b/docs/source/tutorials/hospital_admissions_model.qmd @@ -297,18 +297,47 @@ hosp_model.run( ) ``` -We can use the `Model` object's `plot_posterior` method to visualize the model fit. Here, we plot the observed values against the inferred latent values (i.e. the mean of the negative binomial observation process)[^capture]: +We can use `arviz` to visualize the model fit. Here, we plot the observed values against the inferred latent values (i.e. the mean of the negative binomial observation process)[^capture]: [^capture]: The output is captured to avoid `quarto` from displaying the output twice. ```{python} # | label: fig-output-hospital-admissions -# | fig-cap: Latent hospital admissions posterior samples (blue) and observed admissions timeseries (black). -out = hosp_model.plot_posterior( - var="latent_hospital_admissions", - ylab="Hospital Admissions", - obs_signal=daily_hosp_admits.astype(float), +# | fig-cap: Latent hospital admissions posterior samples (gray) and observed admissions timeseries (red). +import arviz as az + +ppc_samples = hosp_model.posterior_predictive( + n_datapoints=daily_hosp_admits.size +) +idata = az.from_numpyro( + posterior=hosp_model.mcmc, + posterior_predictive=ppc_samples, +) + +axes = az.plot_ts( + idata, + y="negbinom_rv", + y_hat="negbinom_rv", + num_samples=200, + y_kwargs={ + "color": "blue", + "linewidth": 1.0, + "marker": "o", + "linestyle": "solid", + }, + y_hat_plot_kwargs={"color": "skyblue", "alpha": 0.05}, + y_mean_plot_kwargs={"color": "black", "linestyle": "--", "linewidth": 2.5}, + backend_kwargs={"figsize": (8, 6)}, + textsize=15.0, +) +ax = axes[0][0] +ax.set_xlabel("Time", fontsize=20) +ax.set_ylabel("Hospital Admissions", fontsize=20) +handles, labels = ax.get_legend_handles_labels() +ax.legend( + handles, ["Observed", "Sample Mean", "Posterior Samples"], loc="best" ) +plt.show() ``` ## Results exploration and MCMC diagnostics @@ -317,7 +346,6 @@ To explore further, We can use [ArviZ](https://www.arviz.org/) to visualize the ```{python} # | label: convert-inferenceData -import arviz as az idata = az.from_numpyro(hosp_model.mcmc) ``` @@ -419,7 +447,6 @@ We can use the `Model`'s `posterior_predictive` and `prior_predictive` methods t ```{python} # | label: demonstrate-use-of-predictive-methods -import arviz as az idata = az.from_numpyro( hosp_model.mcmc, diff --git a/pyproject.toml b/pyproject.toml index 41e22f9b..23afcd89 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,7 +16,6 @@ python = "^3.12" jax = ">=0.4.30" numpy = "^2.0.0" polars = "^1.2.1" -matplotlib = "^3.8.3" numpyro = ">=0.15.3" [tool.poetry.group.dev] @@ -30,6 +29,7 @@ deptry = "^0.17.0" optional = true [tool.poetry.group.docs.dependencies] +matplotlib = "^3.8.3" ipykernel = "^6.29.3" pyyaml = "^6.0.0" nbclient = "^0.10.0" diff --git a/pyrenew/mcmcutils.py b/pyrenew/mcmcutils.py deleted file mode 100644 index ee631871..00000000 --- a/pyrenew/mcmcutils.py +++ /dev/null @@ -1,158 +0,0 @@ -""" -Utilities to deal with MCMC outputs -""" - -from __future__ import annotations - -import matplotlib.pyplot as plt -import numpy as np -import polars as pl -from jax.typing import ArrayLike - - -def spread_draws( - posteriors: dict, - variables_names: list[str] | list[tuple], -) -> pl.DataFrame: - """ - Get nicely shaped draws from the posterior - - Given a dictionary of posteriors, return a long-form polars dataframe - indexed by draw, with variable values (equivalent of tidybayes - spread_draws() function). - - Parameters - ---------- - posteriors: dict - A dictionary of posteriors with variable names - as keys and numpy ndarrays as values (with the - first axis corresponding to the posterior - draw number). - variables_names: list[str] | list[tuple] - list of strings or of tuples identifying which - variables to retrieve. - - Returns - ------- - pl.DataFrame - A polars dataframe of draw-indexed posterior samples. - """ - - for i_var, v in enumerate(variables_names): - if isinstance(v, str): - v_dims = None - else: - v_dims = v[1:] - v = v[0] - - post = posteriors.get(v) - long_post = post.flatten()[..., np.newaxis] - - indices = np.array(list(np.ndindex(post.shape))) - n_dims = indices.shape[1] - 1 - if v_dims is None: - dim_names = [ - ("{}_dim_{}_index".format(v, k), pl.Int64) - for k in range(n_dims) - ] - elif len(v_dims) != n_dims: - raise ValueError( - "incorrect number of " - "dimension names " - "provided for variable " - "{}".format(v) - ) - else: - dim_names = [(v_dim, pl.Int64) for v_dim in v_dims] - - p_df = pl.DataFrame( - np.concatenate([indices, long_post], axis=1), - schema=([("draw", pl.Int64)] + dim_names + [(v, pl.Float64)]), - ) - - if i_var == 0: - df = p_df - else: - df = df.join( - p_df, on=[col for col in df.columns if col in p_df.columns] - ) - pass - - return df - - -def plot_posterior( - var: str, - draws: pl.DataFrame, - obs_signal: ArrayLike = None, - ylab: str = None, - xlab: str = "Time", - samples: int = 50, - figsize: list = [4, 5], - draws_col: str = "darkblue", - obs_col: str = "black", -) -> plt.Figure: - """ - Plot the posterior distribution of a variable - - Parameters - ---------- - var : str - Name of the variable to plot - model : Model - Model object - obs_signal : ArrayLike, optional - Observed signal to plot as reference - ylab : str, optional - Label for the y-axis - xlab : str, optional - Label for the x-axis - samples : int, optional - Number of samples to plot - figsize : list, optional - Size of the figure - draws_col : str, optional - Color of the draws - obs_col : str, optional - Color of observations column. - - Returns - ------- - plt.Figure - """ - - if ylab is None: - ylab = var - - fig, ax = plt.subplots(figsize=figsize) - - # Reference signal (if any) - if obs_signal is not None: - ax.plot(obs_signal, color=obs_col) - - samp_ids = np.random.randint(size=samples, low=0, high=999) - - for samp_id in samp_ids: - sub_samps = draws.filter(pl.col("draw") == samp_id).sort( - pl.col("time") - ) - ax.plot( - sub_samps.select("time").to_numpy(), - sub_samps.select(var).to_numpy(), - color=draws_col, - alpha=0.1, - ) - - # Some labels - ax.set_xlabel(xlab) - ax.set_ylabel(ylab) - - # Adding a legend - ax.plot([], [], color=draws_col, alpha=0.9, label="Posterior samples") - - if obs_signal is not None: - ax.plot([], [], color=obs_col, label="Observed signal") - - ax.legend() - - return fig diff --git a/pyrenew/metaclass.py b/pyrenew/metaclass.py index 26e68cb2..d804e63f 100644 --- a/pyrenew/metaclass.py +++ b/pyrenew/metaclass.py @@ -4,16 +4,11 @@ from abc import ABCMeta, abstractmethod -import jax import jax.random as jr -import matplotlib.pyplot as plt import numpy as np -import polars as pl from jax.typing import ArrayLike from numpyro.infer import MCMC, NUTS, Predictive, init_to_sample -from pyrenew.mcmcutils import plot_posterior, spread_draws - def _assert_type(arg_name: str, value, expected_type) -> None: """ @@ -266,46 +261,6 @@ def print_summary( """ return self.mcmc.print_summary(prob, exclude_deterministic) - def spread_draws(self, variables_names: list) -> pl.DataFrame: - """ - A wrapper of :func:`pyrenew.mcmcutils.spread_draws` - - Parameters - ---------- - variables_names : list - A list of variable names to create a table of samples. - - Returns - ------- - pl.DataFrame - """ - - return spread_draws(self.mcmc.get_samples(), variables_names) - - def plot_posterior( - self, - var: list, - obs_signal: jax.typing.ArrayLike = None, - xlab: str = None, - ylab: str = "Signal", - samples: int = 50, - figsize: list = [4, 5], - draws_col: str = "darkblue", - obs_col: str = "black", - ) -> plt.Figure: # numpydoc ignore=RT01 - """A wrapper of pyrenew.mcmcutils.plot_posterior""" - return plot_posterior( - var=var, - draws=self.spread_draws([(var, "time")]), - xlab=xlab, - ylab=ylab, - samples=samples, - obs_signal=obs_signal, - figsize=figsize, - draws_col=draws_col, - obs_col=obs_col, - ) - def posterior_predictive( self, rng_key: ArrayLike | None = None, diff --git a/test/test_model_basic_renewal.py b/test/test_model_basic_renewal.py index 8b1f99f7..973acf61 100644 --- a/test/test_model_basic_renewal.py +++ b/test/test_model_basic_renewal.py @@ -5,7 +5,6 @@ import numpy as np import numpyro import numpyro.distributions as dist -import polars as pl import pytest from pyrenew.deterministic import DeterministicPMF, NullObservation @@ -156,17 +155,6 @@ def test_model_basicrenewal_no_obs_model(): data_observed_infections=model0_samp.latent_infections, ) - inf = model0.spread_draws(["all_latent_infections"]) - inf_mean = ( - inf.group_by("draw") - .agg(pl.col("all_latent_infections").mean()) - .sort(pl.col("draw")) - ) - - # For now the assertion is only about the expected number of rows - # It should be about the MCMC inference. - assert inf_mean.to_numpy().shape[0] == 500 - def test_model_basicrenewal_with_obs_model(): """ @@ -214,17 +202,6 @@ def test_model_basicrenewal_with_obs_model(): data_observed_infections=model1_samp.observed_infections, ) - inf = model1.spread_draws(["all_latent_infections"]) - inf_mean = ( - inf.group_by("draw") - .agg(pl.col("all_latent_infections").mean()) - .sort(pl.col("draw")) - ) - - # For now the assertion is only about the expected number of rows - # It should be about the MCMC inference. - assert inf_mean.to_numpy().shape[0] == 500 - def test_model_basicrenewal_padding() -> None: # numpydoc ignore=GL08 gen_int = DeterministicPMF( @@ -263,15 +240,3 @@ def test_model_basicrenewal_padding() -> None: # numpydoc ignore=GL08 data_observed_infections=model1_samp.observed_infections, padding=pad_size, ) - - inf = model1.spread_draws(["all_latent_infections"]) - - inf_mean = ( - inf.group_by("draw") - .agg(pl.col("all_latent_infections").mean()) - .sort(pl.col("draw")) - ) - - # For now the assertion is only about the expected number of rows - # It should be about the MCMC inference. - assert inf_mean.to_numpy().shape[0] == 500 diff --git a/test/test_model_hosp_admissions.py b/test/test_model_hosp_admissions.py index 68b9a86b..cd36efe7 100644 --- a/test/test_model_hosp_admissions.py +++ b/test/test_model_hosp_admissions.py @@ -5,7 +5,6 @@ import numpy as np import numpyro import numpyro.distributions as dist -import polars as pl import pytest from pyrenew.deterministic import ( @@ -254,17 +253,6 @@ def test_model_hosp_no_obs_model(): data_observed_hosp_admissions=model0_samp.latent_hosp_admissions, ) - inf = model0.spread_draws(["latent_hospital_admissions"]) - inf_mean = ( - inf.group_by("draw") - .agg(pl.col("latent_hospital_admissions").mean()) - .sort(pl.col("draw")) - ) - - # For now the assertion is only about the expected number of rows - # It should be about the MCMC inference. - assert inf_mean.to_numpy().shape[0] == 500 - def test_model_hosp_with_obs_model(): """ @@ -341,17 +329,6 @@ def test_model_hosp_with_obs_model(): data_observed_hosp_admissions=model1_samp.observed_hosp_admissions, ) - inf = model1.spread_draws(["latent_hospital_admissions"]) - inf_mean = ( - inf.group_by("draw") - .agg(pl.col("latent_hospital_admissions").mean()) - .sort(pl.col("draw")) - ) - - # For now the assertion is only about the expected number of rows - # It should be about the MCMC inference. - assert inf_mean.to_numpy().shape[0] == 500 - def test_model_hosp_with_obs_model_weekday_phosp_2(): """ @@ -435,17 +412,6 @@ def test_model_hosp_with_obs_model_weekday_phosp_2(): data_observed_hosp_admissions=model1_samp.observed_hosp_admissions, ) - inf = model1.spread_draws(["latent_hospital_admissions"]) - inf_mean = ( - inf.group_by("draw") - .agg(pl.col("latent_hospital_admissions").mean()) - .sort(pl.col("draw")) - ) - - # For now the assertion is only about the expected number of rows - # It should be about the MCMC inference. - assert inf_mean.to_numpy().shape[0] == 500 - def test_model_hosp_with_obs_model_weekday_phosp(): """ @@ -557,14 +523,3 @@ def test_model_hosp_with_obs_model_weekday_phosp(): data_observed_hosp_admissions=model1_samp.observed_hosp_admissions, padding=pad_size, ) - - inf = model1.spread_draws(["latent_hospital_admissions"]) - inf_mean = ( - inf.group_by("draw") - .agg(pl.col("latent_hospital_admissions").mean()) - .sort(pl.col("draw")) - ) - - # For now the assertion is only about the expected number of rows - # It should be about the MCMC inference. - assert inf_mean.to_numpy().shape[0] == 500