From ddcdd27d9da0c9fb5f10f2d9aac2cc34b778cb35 Mon Sep 17 00:00:00 2001 From: Ariel Shurygin Date: Thu, 17 Oct 2024 17:55:27 +0000 Subject: [PATCH] adding a viz for prior distributions --- shiny_visualizers/azure_visualizer.py | 29 +++++++++++++++++--- shiny_visualizers/shiny_utils.py | 34 +++++++++++++++++++++++ src/resp_ode/vis_utils.py | 39 ++++++++++++++++++++++----- 3 files changed, 93 insertions(+), 9 deletions(-) diff --git a/shiny_visualizers/azure_visualizer.py b/shiny_visualizers/azure_visualizer.py index c2ddca0..9324ec0 100644 --- a/shiny_visualizers/azure_visualizer.py +++ b/shiny_visualizers/azure_visualizer.py @@ -22,9 +22,7 @@ # this will reduce the time it takes to load the azure connection, but only shows # one experiment worth of data, which may be what you want... # leave empty ("") to explore all experiments -PRE_FILTER_EXPERIMENTS = ( - "example_azure_experiment" # fifty_state_season2_5strain_2202_2404 -) +PRE_FILTER_EXPERIMENTS = "" # when loading the overview timelines csv for each run, columns # are expected to have names corresponding to the type of plot they create # vaccination_0_17 specifies the vaccination_ plot type, multiple columns may share @@ -170,6 +168,12 @@ "Sample Violin Plots", output_widget("plot_sample_violins"), ), + ui.nav_panel( + "Config Visualizer", + ui.output_plot( + "plot_prior_distributions", width=1600, height=1600 + ), + ), ), ), ) @@ -369,6 +373,25 @@ def plot_sample_correlations(): print("displaying correlations plot") return fig + @output(id="plot_prior_distributions") + @render.plot + @reactive.event(input.action_button) + def plot_prior_distributions(): + exp = input.experiment() + job_id = input.job_id() + states = input.states() + scenario = input.scenario() + theme = input.dark_mode() + theme = sutils.shiny_to_matplotlib_theme(theme) + cache_paths = sutils.get_azure_files( + exp, job_id, states, scenario, azure_client, SHINY_CACHE_PATH + ) + # we have the figure, now update the light/dark mode depending on 
the switch + fig = sutils.load_prior_distributions_plot(cache_paths[0], theme) + # we have the figure, now update the light/dark mode depending on the switch + print("displaying prior distributions") + return fig + @output(id="plot_sample_violins") @render_widget @reactive.event(input.action_button) diff --git a/shiny_visualizers/shiny_utils.py b/shiny_visualizers/shiny_utils.py index 872cba6..717c2fc 100644 --- a/shiny_visualizers/shiny_utils.py +++ b/shiny_visualizers/shiny_utils.py @@ -19,6 +19,7 @@ from tqdm import tqdm from mechanistic_azure.azure_utilities import download_directory_from_azure +from resp_ode import Config, vis_utils from resp_ode.utils import drop_keys_with_substring, flatten_list_parameters @@ -309,6 +310,22 @@ def load_checkpoint_inference_chains( return fig +def load_prior_distributions_plot(cache_path, matplotlib_theme): + path = os.path.join(cache_path, "config_inferer_used.json") + if os.path.exists(path): + config = Config(open(path).read()) + styles = ["seaborn-v0_8-colorblind", matplotlib_theme] + fig = vis_utils.plot_prior_distributions( + config.asdict(), matplotlib_style=styles + ) + else: + raise FileNotFoundError( + "%s does not exist, either the experiment did " + "not save a config used or loading files failed" % path + ) + return fig + + def load_checkpoint_inference_correlations( cache_path, overview_subplot_size: int, @@ -855,3 +872,20 @@ def shiny_to_plotly_theme(shiny_theme: str): plotly theme as str, used in `fig.update_layout(template=theme)` """ return "plotly_%s" % (shiny_theme if shiny_theme == "dark" else "white") + + +def shiny_to_matplotlib_theme(shiny_theme: str): + """shiny themes are "dark" and "light", the matching matplotlib styles are + "dark_background" and "ggplot", this function converts from shiny theme to matplotlib style names + + Parameters + ---------- + shiny_theme : str + shiny theme as str + + Returns + ------- + str + matplotlib style as str, passed as `matplotlib_style` to `vis_utils.plot_prior_distributions` + """ + return "dark_background" if 
shiny_theme == "dark" else "ggplot" diff --git a/src/resp_ode/vis_utils.py b/src/resp_ode/vis_utils.py index 4f004ab..0a3a10f 100644 --- a/src/resp_ode/vis_utils.py +++ b/src/resp_ode/vis_utils.py @@ -1,5 +1,6 @@ """A series of utility functions for generating visualizations for the model""" +import jax.numpy as jnp import matplotlib.pyplot as plt import numpy as np import pandas as pd @@ -15,6 +16,10 @@ ) +class VisualizationError(Exception): + pass + + def _cleanup_and_normalize_timelines( all_state_timelines: pd.DataFrame, plot_types: np.ndarray[str], @@ -474,6 +479,7 @@ def plot_prior_distributions( "seaborn-v0_8-colorblind", ], num_samples=50000, + hist_kwargs={"bins": 50, "density": True}, ) -> plt.Figure: """Given a dictionary of parameter keys and possibly values of numpyro.distribution objects, samples them a number of times @@ -488,6 +494,11 @@ def plot_prior_distributions( key will be included in the plot matplotlib_style : list[str] | str, optional matplotlib style to plot in by default ["seaborn-v0_8-colorblind"] + num_samples: int, optional + the number of times to sample each distribution, mild impact on + figure performance. 
By default 50000 + hist_kwargs: dict[str, Any], optional + additional kwargs passed to ax.hist(), by default {"bins": 50, "density": True} Returns ------- @@ -497,7 +508,6 @@ def plot_prior_distributions( """ dist_only = {} d = identify_distribution_indexes(priors) - print(d) # filter down to just the distribution objects for dist_name, locator_dct in d.items(): parameter_name = locator_dct["sample_name"] @@ -512,9 +522,12 @@ def plot_prior_distributions( for i in parameter_idx: temp = temp[i] dist_only[dist_name] = temp - print(dist_only) param_names = list(dist_only.keys()) num_params = len(param_names) + if num_params == 0: + raise VisualizationError( + "Attempted to visualize a config without any distributions" + ) # Calculate the number of rows and columns for a square-ish layout num_cols = int(np.ceil(np.sqrt(num_params))) num_rows = int(np.ceil(num_params / num_cols)) @@ -533,10 +546,24 @@ def plot_prior_distributions( ax.set_title(param_name) dist = dist_only[param_name] samples = dist.sample(PRNGKey(0), sample_shape=(num_samples,)) - ax.hist(samples, bins=50) - # Hide x-axis labels except for bottom plots to reduce clutter - # if i < (num_params - num_cols): - # ax.set_xticklabels([]) + ax.hist(samples, **hist_kwargs) + ax.axvline( + samples.mean(), + linestyle="dashed", + linewidth=1, + label="mean", + ) + ax.axvline( + jnp.median(samples), + linestyle="dotted", + linewidth=3, + label="median", + ) + # Turn off any unused subplots + for j in range(i + 1, len(axs_flat)): + axs_flat[j].axis("off") + handles, labels = ax.get_legend_handles_labels() + fig.legend(handles, labels, loc="outside upper right") fig.suptitle("Prior Distributions Visualized, n=%s" % num_samples) plt.tight_layout() return fig