Skip to content

Commit

Permalink
Diagnostics: partitioning filtering and naming (#260)
Browse files Browse the repository at this point in the history
* Migrate diagnostics module

* Add confusion_matrix

* Remove duplicate code for layout configuration

* Simplify pre-processing for plot_recovery and plot_z_score_contraction

* Simplify prettify process for z_score and recovery

* Simplify preprocessing for plot_sbc_ecdf and plot_sbc_histograms

* Simplify labeling

* Reformat

* Make plot_distribution_2d more compatible

* Update quickstart notebook with working prior checks

* Improve compatibility for plot_losses and plot_prior_2d

* Update quickstart notebook with loss trajectory

* Minor fix of plot_utils, start adding tests for diagnostics

* Pre-final version WIP

* Minor changes in 2d plots; update and test plot_z_score_contraction

* Update and test plot_sbc_histograms

* Update plot_calibration_curves (WIP)

* Minor refactors: change global color schemes, complete type casts, further simplify plot_losses

* Add detailed callback for loss trajectory

* Generalize preprocessing utilities from samples to variables

* Generalize add_metric

* Remove redundant code segments related to prettify

* Include add_titles and add_titles_and_labels; propagate variables as samples

* Interim cleanup

* Add typing and fix plot_sbc_ecdf

* Minor fix of plot_samples_2d

* Minor fix of plot_prior_2d

* Remove redundant code for axes flattening

* Ensure consistent color scheme; incorporate sequence of labels

* Bug fix for plot_losses

* Cleanup

* Partition filtering and renaming in dicts_to_arrays

* Propagate filter keys and variable names

* Rename all 'names' to 'variable_names'

* Getting rid of test_diagnostics (for now) to make sure that tests are passing

* Minor bugfix to plot_samples_2d and plot_z_score_contraction based on partitioning

---------

Co-authored-by: stefanradev93 <[email protected]>
  • Loading branch information
jerrymhuang and stefanradev93 authored Dec 2, 2024
1 parent fcc78a8 commit 0c1c9bc
Show file tree
Hide file tree
Showing 9 changed files with 455 additions and 420 deletions.
2 changes: 1 addition & 1 deletion bayesflow/diagnostics/plot_calibration_curves.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def plot_calibration_curves(
axes=plot_data["axes"],
num_row=plot_data["num_row"],
num_col=plot_data["num_col"],
title=plot_data["names"],
title=plot_data["variable_names"],
xlabel="Predicted Probability",
ylabel="True Probability",
title_fontsize=title_fontsize,
Expand Down
5 changes: 3 additions & 2 deletions bayesflow/diagnostics/plot_recovery.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
def plot_recovery(
post_samples: dict[str, np.ndarray] | np.ndarray,
prior_samples: dict[str, np.ndarray] | np.ndarray,
filter_keys: Sequence[str] = None,
variable_names: Sequence[str] = None,
point_agg=np.median,
uncertainty_agg=median_abs_deviation,
Expand Down Expand Up @@ -59,7 +60,7 @@ def plot_recovery(
"""

# Gather plot data and metadata into a dictionary
plot_data = preprocess(post_samples, prior_samples, variable_names, num_col, num_row, figsize)
plot_data = preprocess(post_samples, prior_samples, filter_keys, variable_names, num_col, num_row, figsize)
plot_data["post_samples"] = plot_data.pop("post_variables")
plot_data["prior_samples"] = plot_data.pop("prior_variables")

Expand Down Expand Up @@ -93,7 +94,7 @@ def plot_recovery(
corr = np.corrcoef(plot_data["prior_samples"][:, i], point_estimate[:, i])[0, 1]
add_metric(ax=ax, metric_text="$r$", metric_value=corr, metric_fontsize=metric_fontsize)

ax.set_title(plot_data["names"][i], fontsize=title_fontsize)
ax.set_title(plot_data["variable_names"][i], fontsize=title_fontsize)

# Add custom schmuck
prettify_subplots(plot_data["axes"], num_subplots=plot_data["num_variables"], tick_fontsize=tick_fontsize)
Expand Down
16 changes: 13 additions & 3 deletions bayesflow/diagnostics/plot_samples_2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,14 @@
import seaborn as sns
import pandas as pd

from typing import Sequence
from bayesflow.utils import logging
from bayesflow.utils.dict_utils import dicts_to_arrays


def plot_samples_2d(
samples: np.ndarray = None,
samples: dict[str, np.ndarray] | np.ndarray = None,
filter_keys: Sequence[str] = None,
context: str = None,
variable_names: list = None,
height: float = 2.5,
Expand Down Expand Up @@ -41,7 +44,11 @@ def plot_samples_2d(
Additional keyword arguments passed to the sns.PairGrid constructor
"""

dim = samples.shape[-1]
plot_data = dicts_to_arrays(
post_variables=samples, filter_keys=filter_keys, variable_names=variable_names, context=context
)

dim = plot_data["post_variables"].shape[-1]
if context is None:
context = "Default"

Expand All @@ -52,7 +59,10 @@ def plot_samples_2d(
titles = [f"{context} {p}" for p in variable_names]

# Convert samples to pd.DataFrame
data_to_plot = pd.DataFrame(samples, columns=titles)
if context == "Posterior":
data_to_plot = pd.DataFrame(plot_data["post_variables"][0], columns=titles)
else:
data_to_plot = pd.DataFrame(plot_data["post_variables"], columns=titles)

# Generate plots
artist = sns.PairGrid(data_to_plot, height=height, **kwargs)
Expand Down
7 changes: 5 additions & 2 deletions bayesflow/diagnostics/plot_sbc_ecdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
def plot_sbc_ecdf(
post_samples: dict[str, np.ndarray] | np.ndarray,
prior_samples: dict[str, np.ndarray] | np.ndarray,
filter_keys: Sequence[str] = None,
variable_names: Sequence[str] = None,
difference: bool = False,
stacked: bool = False,
Expand Down Expand Up @@ -92,7 +93,9 @@ def plot_sbc_ecdf(
"""

# Preprocessing
plot_data = preprocess(post_samples, prior_samples, variable_names, num_col, num_row, figsize, stacked=stacked)
plot_data = preprocess(
post_samples, prior_samples, filter_keys, variable_names, num_col, num_row, figsize, stacked=stacked
)
plot_data["post_samples"] = plot_data.pop("post_variables")
plot_data["prior_samples"] = plot_data.pop("prior_variables")

Expand Down Expand Up @@ -129,7 +132,7 @@ def plot_sbc_ecdf(
ylab = "ECDF"

# Add simultaneous bounds
titles = plot_data["names"] if not stacked else ["Stacked ECDFs"]
titles = plot_data["variable_names"] if not stacked else ["Stacked ECDFs"]
for ax, title in zip(plot_data["axes"].flat, titles):
ax.fill_between(z, L, H, color=fill_color, alpha=0.2, label=rf"{int((1-alpha) * 100)}$\%$ Confidence Bands")
ax.legend(fontsize=legend_fontsize)
Expand Down
7 changes: 4 additions & 3 deletions bayesflow/diagnostics/plot_sbc_histograms.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
def plot_sbc_histograms(
post_samples: dict[str, np.ndarray] | np.ndarray,
prior_samples: dict[str, np.ndarray] | np.ndarray,
filter_keys: Sequence[str] = None,
variable_names: Sequence[str] = None,
figsize: Sequence[float] = None,
num_bins: int = 10,
Expand Down Expand Up @@ -71,11 +72,11 @@ def plot_sbc_histograms(
"""

# Preprocessing
plot_data = preprocess(post_samples, prior_samples, num_col, num_row, variable_names, figsize)
plot_data = preprocess(post_samples, prior_samples, filter_keys, variable_names, num_col, num_row, figsize=figsize)
plot_data["post_samples"] = plot_data.pop("post_variables")
plot_data["prior_samples"] = plot_data.pop("prior_variables")

# Determine the ratio of simulations to prior draws
# Determine the ratio of simulations to prior draw
# num_params = plot_data['num_variables']
num_sims = plot_data["post_samples"].shape[0]
num_draws = plot_data["post_samples"].shape[1]
Expand Down Expand Up @@ -119,7 +120,7 @@ def plot_sbc_histograms(
axes=plot_data["axes"],
num_row=plot_data["num_row"],
num_col=plot_data["num_col"],
title=plot_data["names"],
title=plot_data["variable_names"],
xlabel="Rank statistic",
ylabel="",
title_fontsize=title_fontsize,
Expand Down
7 changes: 4 additions & 3 deletions bayesflow/diagnostics/plot_z_score_contraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
def plot_z_score_contraction(
post_samples: dict[str, np.ndarray] | np.ndarray,
prior_samples: dict[str, np.ndarray] | np.ndarray,
filter_keys: Sequence[str] = None,
variable_names: Sequence[str] = None,
figsize: Sequence[int] = None,
label_fontsize: int = 16,
Expand Down Expand Up @@ -84,7 +85,7 @@ def plot_z_score_contraction(
"""

# Preprocessing
plot_data = preprocess(post_samples, prior_samples, variable_names, num_col, num_row, figsize)
plot_data = preprocess(post_samples, prior_samples, filter_keys, variable_names, num_col, num_row, figsize)
plot_data["post_samples"] = plot_data.pop("post_variables")
plot_data["prior_samples"] = plot_data.pop("prior_variables")

Expand All @@ -98,7 +99,7 @@ def plot_z_score_contraction(

# Compute contraction and z-score
contraction = 1 - (post_vars / prior_vars)
z_score = (post_means - prior_samples) / post_stds
z_score = (post_means - plot_data["prior_samples"]) / post_stds

# Loop and plot
for i, ax in enumerate(plot_data["axes"].flat):
Expand All @@ -115,7 +116,7 @@ def plot_z_score_contraction(
axes=plot_data["axes"],
num_row=plot_data["num_row"],
num_col=plot_data["num_col"],
title=plot_data["names"],
title=plot_data["variable_names"],
xlabel="Posterior contraction",
ylabel="Posterior z-score",
title_fontsize=title_fontsize,
Expand Down
60 changes: 40 additions & 20 deletions bayesflow/utils/dict_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,36 +107,56 @@ def split_tensors(data: Mapping[any, Tensor], axis: int = -1) -> Mapping[any, Te

def dicts_to_arrays(
post_variables: dict[str, np.ndarray] | np.ndarray,
prior_variables: dict[str, np.ndarray] | np.ndarray,
names: Sequence[str] = None,
prior_variables: dict[str, np.ndarray] | np.ndarray = None,
filter_keys: Sequence[str] | None = None,
variable_names: Sequence[str] = None,
context: str = None,
):
"""Utility to optionally convert dicts as returned from approximators and adapters into arrays."""
"""
# TODO
"""

if type(post_variables) is not type(prior_variables):
raise ValueError("You should either use dicts or tensors, but not separate types for your inputs.")
# Ensure that posterior and prior variables have the same type
if prior_variables is not None:
if type(post_variables) is not type(prior_variables):
raise ValueError("You should either use dicts or tensors, but not separate types for your inputs.")

# Filtering
if isinstance(post_variables, dict):
if post_variables.keys() != prior_variables.keys():
raise ValueError("Keys in your posterior / prior arrays should match.")
# Ensure that the keys of selected posterior and prior variables match
if prior_variables is not None:
if not (set(post_variables) <= set(prior_variables)):
raise ValueError("Keys in your posterior / prior arrays should match.")

# Use user-provided names instead of inferred ones
names = list(post_variables.keys()) if names is None else names
# If they match, users can further select the variables by using filter keys
filter_keys = list(post_variables.keys()) if filter_keys is None else filter_keys

post_variables = np.concatenate([v for k, v in post_variables.items() if k in names], axis=-1)
prior_variables = np.concatenate([v for k, v in prior_variables.items() if k in names], axis=-1)
# The variables will then be overridden with the filtered keys
post_variables = np.concatenate([v for k, v in post_variables.items() if k in filter_keys], axis=-1)
if prior_variables is not None:
prior_variables = np.concatenate([v for k, v in prior_variables.items() if k in filter_keys], axis=-1)

# Naming or Renaming
elif isinstance(post_variables, np.ndarray):
if names is not None:
if post_variables.shape[-1] != len(names) or prior_variables.shape[-1] != len(names):
raise ValueError("The length of the names list should match the number of target variables.")
else:
if context is not None:
names = [f"${context}_{{{i}}}$" for i in range(post_variables.shape[-1])]
# If there are filter_keys, check if their number is the same as that of the variables.
# If it does, check if there are sufficient variable names.
# If there are, then the variable names are adopted.
if variable_names is not None:
if post_variables.shape[-1] != len(variable_names) or prior_variables.shape[-1] != len(variable_names):
raise ValueError("The number of variable names should match the number of target variables.")

else: # Otherwise, we would assume that all variables are used for plotting.
if context is None:
if variable_names is None:
variable_names = [f"$\\theta_{{{i}}}$" for i in range(post_variables.shape[-1])]
else:
names = [f"$\\theta_{{{i}}}$" for i in range(post_variables.shape[-1])]

variable_names = [f"${context}_{{{i}}}$" for i in range(post_variables.shape[-1])]
else:
raise TypeError("Only dicts and tensors are supported as arguments.")

return dict(post_variables=post_variables, prior_variables=prior_variables, names=names, num_variables=len(names))
return dict(
post_variables=post_variables,
prior_variables=prior_variables,
variable_names=variable_names,
num_variables=len(variable_names),
)
15 changes: 11 additions & 4 deletions bayesflow/utils/plot_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,10 @@


def preprocess(
post_variables: dict[str, np.ndarray],
prior_variables: dict[str, np.ndarray],
names: Sequence[str] = None,
post_variables: dict[str, np.ndarray] | np.ndarray,
prior_variables: dict[str, np.ndarray] | np.ndarray,
filter_keys: Sequence[str] = None,
variable_names: Sequence[str] = None,
context: str = None,
num_col: int = None,
num_row: int = None,
Expand Down Expand Up @@ -43,7 +44,13 @@ def preprocess(
Whether the plots are stacked horizontally
"""

plot_data = dicts_to_arrays(post_variables, prior_variables, names, context)
plot_data = dicts_to_arrays(
post_variables=post_variables,
prior_variables=prior_variables,
filter_keys=filter_keys,
variable_names=variable_names,
context=context,
)
check_posterior_prior_shapes(plot_data["post_variables"], plot_data["prior_variables"])

# Configure layout
Expand Down
Loading

0 comments on commit 0c1c9bc

Please sign in to comment.