Skip to content

Commit

Permalink
adding a viz for prior distributions
Browse files Browse the repository at this point in the history
  • Loading branch information
arik-shurygin committed Oct 17, 2024
1 parent a7e57b5 commit ddcdd27
Show file tree
Hide file tree
Showing 3 changed files with 93 additions and 9 deletions.
29 changes: 26 additions & 3 deletions shiny_visualizers/azure_visualizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,7 @@
# this will reduce the time it takes to load the azure connection, but only shows
# one experiment worth of data, which may be what you want...
# leave empty ("") to explore all experiments
PRE_FILTER_EXPERIMENTS = (
"example_azure_experiment" # fifty_state_season2_5strain_2202_2404
)
PRE_FILTER_EXPERIMENTS = ""
# when loading the overview timelines csv for each run, columns
# are expected to have names corresponding to the type of plot they create
# vaccination_0_17 specifies the vaccination_ plot type, multiple columns may share
Expand Down Expand Up @@ -170,6 +168,12 @@
"Sample Violin Plots",
output_widget("plot_sample_violins"),
),
ui.nav_panel(
"Config Visualizer",
ui.output_plot(
"plot_prior_distributions", width=1600, height=1600
),
),
),
),
)
Expand Down Expand Up @@ -369,6 +373,25 @@ def plot_sample_correlations():
print("displaying correlations plot")
return fig

@output(id="plot_prior_distributions")
@render.plot
@reactive.event(input.action_button)
def plot_prior_distributions():
exp = input.experiment()
job_id = input.job_id()
states = input.states()
scenario = input.scenario()
theme = input.dark_mode()
theme = sutils.shiny_to_matplotlib_theme(theme)
cache_paths = sutils.get_azure_files(
exp, job_id, states, scenario, azure_client, SHINY_CACHE_PATH
)
# we have the figure, now update the light/dark mode depending on the switch
fig = sutils.load_prior_distributions_plot(cache_paths[0], theme)
# we have the figure, now update the light/dark mode depending on the switch
print("displaying prior distributions")
return fig

@output(id="plot_sample_violins")
@render_widget
@reactive.event(input.action_button)
Expand Down
34 changes: 34 additions & 0 deletions shiny_visualizers/shiny_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from tqdm import tqdm

from mechanistic_azure.azure_utilities import download_directory_from_azure
from resp_ode import Config, vis_utils
from resp_ode.utils import drop_keys_with_substring, flatten_list_parameters


Expand Down Expand Up @@ -309,6 +310,22 @@ def load_checkpoint_inference_chains(
return fig


def load_prior_distributions_plot(cache_path, matplotlib_theme):
path = os.path.join(cache_path, "config_inferer_used.json")
if os.path.exists(path):
config = Config(open(path).read())
styles = ["seaborn-v0_8-colorblind", matplotlib_theme]
fig = vis_utils.plot_prior_distributions(
config.asdict(), matplotlib_style=styles
)
else:
raise FileNotFoundError(
"%s does not exist, either the experiment did "
"not save a config used or loading files failed" % path
)
return fig


def load_checkpoint_inference_correlations(
cache_path,
overview_subplot_size: int,
Expand Down Expand Up @@ -855,3 +872,20 @@ def shiny_to_plotly_theme(shiny_theme: str):
plotly theme as str, used in `fig.update_layout(template=theme)`
"""
return "plotly_%s" % (shiny_theme if shiny_theme == "dark" else "white")


def shiny_to_matplotlib_theme(shiny_theme: str):
"""shiny themes are "dark" and "light", plotly themes are
"plotly_dark" and "plotly_white", this function converts from shiny to plotly theme names
Parameters
----------
shiny_theme : str
shiny theme as str
Returns
-------
str
plotly theme as str, used in `fig.update_layout(template=theme)`
"""
return "dark_background" if shiny_theme == "dark" else "ggplot"
39 changes: 33 additions & 6 deletions src/resp_ode/vis_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""A series of utility functions for generating visualizations for the model"""

import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
Expand All @@ -15,6 +16,10 @@
)


class VisualizationError(Exception):
pass


def _cleanup_and_normalize_timelines(
all_state_timelines: pd.DataFrame,
plot_types: np.ndarray[str],
Expand Down Expand Up @@ -474,6 +479,7 @@ def plot_prior_distributions(
"seaborn-v0_8-colorblind",
],
num_samples=50000,
hist_kwargs={"bins": 50, "density": True},
) -> plt.Figure:
"""Given a dictionary of parameter keys and possibly values of
numpyro.distribution objects, samples them a number of times
Expand All @@ -488,6 +494,11 @@ def plot_prior_distributions(
key will be included in the plot
matplotlib_style : list[str] | str, optional
matplotlib style to plot in by default ["seaborn-v0_8-colorblind"]
num_samples: int, optional
the number of times to sample each distribution, mild impact on
figure performance. By default 50000
hist_kwargs: dict[str: Any]
additional kwargs passed to plt.hist(), by default {"bins": 50}
Returns
-------
Expand All @@ -497,7 +508,6 @@ def plot_prior_distributions(
"""
dist_only = {}
d = identify_distribution_indexes(priors)
print(d)
# filter down to just the distribution objects
for dist_name, locator_dct in d.items():
parameter_name = locator_dct["sample_name"]
Expand All @@ -512,9 +522,12 @@ def plot_prior_distributions(
for i in parameter_idx:
temp = temp[i]
dist_only[dist_name] = temp
print(dist_only)
param_names = list(dist_only.keys())
num_params = len(param_names)
if num_params == 0:
raise VisualizationError(
"Attempted to visualize a config without any distributions"
)
# Calculate the number of rows and columns for a square-ish layout
num_cols = int(np.ceil(np.sqrt(num_params)))
num_rows = int(np.ceil(num_params / num_cols))
Expand All @@ -533,10 +546,24 @@ def plot_prior_distributions(
ax.set_title(param_name)
dist = dist_only[param_name]
samples = dist.sample(PRNGKey(0), sample_shape=(num_samples,))
ax.hist(samples, bins=50)
# Hide x-axis labels except for bottom plots to reduce clutter
# if i < (num_params - num_cols):
# ax.set_xticklabels([])
ax.hist(samples, **hist_kwargs)
ax.axvline(
samples.mean(),
linestyle="dashed",
linewidth=1,
label="mean",
)
ax.axvline(
jnp.median(samples),
linestyle="dotted",
linewidth=3,
label="median",
)
# Turn off any unused subplots
for j in range(i + 1, len(axs_flat)):
axs_flat[j].axis("off")
handles, labels = ax.get_legend_handles_labels()
fig.legend(handles, labels, loc="outside upper right")
fig.suptitle("Prior Distributions Visualized, n=%s" % num_samples)
plt.tight_layout()
return fig

0 comments on commit ddcdd27

Please sign in to comment.