Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Diagnostics: partitioning filtering and naming #260

Merged
merged 43 commits into from
Dec 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
43 commits
Select commit Hold shift + click to select a range
adf3f5b
Migrate diagnostics module
jerrymhuang Nov 4, 2024
104f726
Add confusion_matrix
jerrymhuang Nov 4, 2024
9f66871
Merge branch 'dev' into diagnostics
jerrymhuang Nov 8, 2024
ae2a6a3
Remove duplicate code for layout configuration
jerrymhuang Nov 9, 2024
bf9066a
Simplify pre-processing for plot_recovery and plot_z_score_contraction
jerrymhuang Nov 9, 2024
b522a7e
Simplify prettify process for z_score and recovery
jerrymhuang Nov 9, 2024
609c89c
Simplify preprocessing for plot_sbc_ecdf and plot_sbc_histograms
jerrymhuang Nov 9, 2024
1beb585
Simplify labeling
jerrymhuang Nov 9, 2024
fd24587
Reformat
jerrymhuang Nov 9, 2024
8fde9e9
Make plot_distribution_2d more compatible
jerrymhuang Nov 9, 2024
a7fd398
Update quickstart notebook with working prior checks
jerrymhuang Nov 9, 2024
04bd262
Improve compatibility for plot_losses and plot_prior_2d
jerrymhuang Nov 11, 2024
a4d4452
Update quickstart notebook with loss trajectory
jerrymhuang Nov 12, 2024
18f0514
Minor fix of plot_utils, start adding tests for diagnostics
jerrymhuang Nov 13, 2024
05a4978
Pre-final version WIP
stefanradev93 Nov 13, 2024
0e956b5
Merge branch 'dev' into diagnostics
jerrymhuang Nov 13, 2024
3090ea7
Minor changes in 2d plots; update and test plot_z_score_contraction
jerrymhuang Nov 14, 2024
5cb5847
Update and test plot_sbc_histograms
jerrymhuang Nov 14, 2024
5462ecf
Update plot_calibration_curves (WIP)
jerrymhuang Nov 14, 2024
2868f96
Minor refactors: change global color schemes, complete type casts, fu…
jerrymhuang Nov 14, 2024
ad59eae
Add detailed callback for loss trajectory
jerrymhuang Nov 14, 2024
15bc4ab
Generalize preprocessing utilities from samples to variables
jerrymhuang Nov 14, 2024
16436c2
Generalize add_metric
jerrymhuang Nov 15, 2024
c25a67d
Remove redundant code segments related to prettify
jerrymhuang Nov 15, 2024
c063c67
Include add_titles and add_titles_and_labels; propagate variables as …
jerrymhuang Nov 15, 2024
d2a742e
Interim cleanup
jerrymhuang Nov 15, 2024
7875661
Add typing and fix plot_sbc_ecdf
jerrymhuang Nov 15, 2024
05ef11b
Minor fix of plot_samples_2d
jerrymhuang Nov 15, 2024
88cc1c4
Minor fix of plot_prior_2d
jerrymhuang Nov 15, 2024
d9ab082
Remove redundant code for axes flattening
jerrymhuang Nov 15, 2024
6683b25
Ensure consistent color scheme; incorporate sequence of labels
jerrymhuang Nov 15, 2024
9de0479
Bug fix for plot_losses
jerrymhuang Nov 15, 2024
9623061
Cleanup
jerrymhuang Nov 15, 2024
896caad
Merge branch 'dev' into diagnostics
jerrymhuang Nov 23, 2024
dec4190
Merge branch 'dev' into diagnostics
jerrymhuang Nov 25, 2024
bb3cd83
Partition filtering and renaming in dicts_to_arrays
jerrymhuang Nov 26, 2024
4fdc1d7
Merge branch 'dev' into diagnostics
jerrymhuang Nov 26, 2024
2e3b0a3
Merge branch 'dev' into diagnostics
jerrymhuang Nov 27, 2024
8dd669f
Propagate filter keys and variable names
jerrymhuang Nov 29, 2024
dca0ba0
Rename all 'names' to 'variable_names'
jerrymhuang Nov 29, 2024
a0e0ed0
Getting rid of test_diagnostics (for now) to make sure that tests are…
jerrymhuang Nov 30, 2024
78a84b7
Merge branch 'dev' into diagnostics
jerrymhuang Nov 30, 2024
9a0e9ff
Minor bugfix to plot_samples_2d and plot_z_score_contraction based on…
jerrymhuang Dec 1, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion bayesflow/diagnostics/plot_calibration_curves.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def plot_calibration_curves(
axes=plot_data["axes"],
num_row=plot_data["num_row"],
num_col=plot_data["num_col"],
title=plot_data["names"],
title=plot_data["variable_names"],
xlabel="Predicted Probability",
ylabel="True Probability",
title_fontsize=title_fontsize,
Expand Down
5 changes: 3 additions & 2 deletions bayesflow/diagnostics/plot_recovery.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
def plot_recovery(
post_samples: dict[str, np.ndarray] | np.ndarray,
prior_samples: dict[str, np.ndarray] | np.ndarray,
filter_keys: Sequence[str] = None,
variable_names: Sequence[str] = None,
point_agg=np.median,
uncertainty_agg=median_abs_deviation,
Expand Down Expand Up @@ -59,7 +60,7 @@ def plot_recovery(
"""

# Gather plot data and metadata into a dictionary
plot_data = preprocess(post_samples, prior_samples, variable_names, num_col, num_row, figsize)
plot_data = preprocess(post_samples, prior_samples, filter_keys, variable_names, num_col, num_row, figsize)
plot_data["post_samples"] = plot_data.pop("post_variables")
plot_data["prior_samples"] = plot_data.pop("prior_variables")

Expand Down Expand Up @@ -93,7 +94,7 @@ def plot_recovery(
corr = np.corrcoef(plot_data["prior_samples"][:, i], point_estimate[:, i])[0, 1]
add_metric(ax=ax, metric_text="$r$", metric_value=corr, metric_fontsize=metric_fontsize)

ax.set_title(plot_data["names"][i], fontsize=title_fontsize)
ax.set_title(plot_data["variable_names"][i], fontsize=title_fontsize)

# Add custom schmuck
prettify_subplots(plot_data["axes"], num_subplots=plot_data["num_variables"], tick_fontsize=tick_fontsize)
Expand Down
16 changes: 13 additions & 3 deletions bayesflow/diagnostics/plot_samples_2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,14 @@
import seaborn as sns
import pandas as pd

from typing import Sequence
from bayesflow.utils import logging
from bayesflow.utils.dict_utils import dicts_to_arrays


def plot_samples_2d(
samples: np.ndarray = None,
samples: dict[str, np.ndarray] | np.ndarray = None,
filter_keys: Sequence[str] = None,
context: str = None,
variable_names: list = None,
height: float = 2.5,
Expand Down Expand Up @@ -41,7 +44,11 @@ def plot_samples_2d(
Additional keyword arguments passed to the sns.PairGrid constructor
"""

dim = samples.shape[-1]
plot_data = dicts_to_arrays(
post_variables=samples, filter_keys=filter_keys, variable_names=variable_names, context=context
)

dim = plot_data["post_variables"].shape[-1]
if context is None:
context = "Default"

Expand All @@ -52,7 +59,10 @@ def plot_samples_2d(
titles = [f"{context} {p}" for p in variable_names]

# Convert samples to pd.DataFrame
data_to_plot = pd.DataFrame(samples, columns=titles)
if context == "Posterior":
data_to_plot = pd.DataFrame(plot_data["post_variables"][0], columns=titles)
else:
data_to_plot = pd.DataFrame(plot_data["post_variables"], columns=titles)

# Generate plots
artist = sns.PairGrid(data_to_plot, height=height, **kwargs)
Expand Down
7 changes: 5 additions & 2 deletions bayesflow/diagnostics/plot_sbc_ecdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
def plot_sbc_ecdf(
post_samples: dict[str, np.ndarray] | np.ndarray,
prior_samples: dict[str, np.ndarray] | np.ndarray,
filter_keys: Sequence[str] = None,
variable_names: Sequence[str] = None,
difference: bool = False,
stacked: bool = False,
Expand Down Expand Up @@ -92,7 +93,9 @@ def plot_sbc_ecdf(
"""

# Preprocessing
plot_data = preprocess(post_samples, prior_samples, variable_names, num_col, num_row, figsize, stacked=stacked)
plot_data = preprocess(
post_samples, prior_samples, filter_keys, variable_names, num_col, num_row, figsize, stacked=stacked
)
plot_data["post_samples"] = plot_data.pop("post_variables")
plot_data["prior_samples"] = plot_data.pop("prior_variables")

Expand Down Expand Up @@ -129,7 +132,7 @@ def plot_sbc_ecdf(
ylab = "ECDF"

# Add simultaneous bounds
titles = plot_data["names"] if not stacked else ["Stacked ECDFs"]
titles = plot_data["variable_names"] if not stacked else ["Stacked ECDFs"]
for ax, title in zip(plot_data["axes"].flat, titles):
ax.fill_between(z, L, H, color=fill_color, alpha=0.2, label=rf"{int((1-alpha) * 100)}$\%$ Confidence Bands")
ax.legend(fontsize=legend_fontsize)
Expand Down
7 changes: 4 additions & 3 deletions bayesflow/diagnostics/plot_sbc_histograms.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
def plot_sbc_histograms(
post_samples: dict[str, np.ndarray] | np.ndarray,
prior_samples: dict[str, np.ndarray] | np.ndarray,
filter_keys: Sequence[str] = None,
variable_names: Sequence[str] = None,
figsize: Sequence[float] = None,
num_bins: int = 10,
Expand Down Expand Up @@ -71,11 +72,11 @@ def plot_sbc_histograms(
"""

# Preprocessing
plot_data = preprocess(post_samples, prior_samples, num_col, num_row, variable_names, figsize)
plot_data = preprocess(post_samples, prior_samples, filter_keys, variable_names, num_col, num_row, figsize=figsize)
plot_data["post_samples"] = plot_data.pop("post_variables")
plot_data["prior_samples"] = plot_data.pop("prior_variables")

# Determine the ratio of simulations to prior draws
# Determine the ratio of simulations to prior draw
# num_params = plot_data['num_variables']
num_sims = plot_data["post_samples"].shape[0]
num_draws = plot_data["post_samples"].shape[1]
Expand Down Expand Up @@ -119,7 +120,7 @@ def plot_sbc_histograms(
axes=plot_data["axes"],
num_row=plot_data["num_row"],
num_col=plot_data["num_col"],
title=plot_data["names"],
title=plot_data["variable_names"],
xlabel="Rank statistic",
ylabel="",
title_fontsize=title_fontsize,
Expand Down
7 changes: 4 additions & 3 deletions bayesflow/diagnostics/plot_z_score_contraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
def plot_z_score_contraction(
post_samples: dict[str, np.ndarray] | np.ndarray,
prior_samples: dict[str, np.ndarray] | np.ndarray,
filter_keys: Sequence[str] = None,
variable_names: Sequence[str] = None,
figsize: Sequence[int] = None,
label_fontsize: int = 16,
Expand Down Expand Up @@ -84,7 +85,7 @@ def plot_z_score_contraction(
"""

# Preprocessing
plot_data = preprocess(post_samples, prior_samples, variable_names, num_col, num_row, figsize)
plot_data = preprocess(post_samples, prior_samples, filter_keys, variable_names, num_col, num_row, figsize)
plot_data["post_samples"] = plot_data.pop("post_variables")
plot_data["prior_samples"] = plot_data.pop("prior_variables")

Expand All @@ -98,7 +99,7 @@ def plot_z_score_contraction(

# Compute contraction and z-score
contraction = 1 - (post_vars / prior_vars)
z_score = (post_means - prior_samples) / post_stds
z_score = (post_means - plot_data["prior_samples"]) / post_stds

# Loop and plot
for i, ax in enumerate(plot_data["axes"].flat):
Expand All @@ -115,7 +116,7 @@ def plot_z_score_contraction(
axes=plot_data["axes"],
num_row=plot_data["num_row"],
num_col=plot_data["num_col"],
title=plot_data["names"],
title=plot_data["variable_names"],
xlabel="Posterior contraction",
ylabel="Posterior z-score",
title_fontsize=title_fontsize,
Expand Down
60 changes: 40 additions & 20 deletions bayesflow/utils/dict_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,36 +107,56 @@ def split_tensors(data: Mapping[any, Tensor], axis: int = -1) -> Mapping[any, Te

def dicts_to_arrays(
post_variables: dict[str, np.ndarray] | np.ndarray,
prior_variables: dict[str, np.ndarray] | np.ndarray,
names: Sequence[str] = None,
prior_variables: dict[str, np.ndarray] | np.ndarray = None,
filter_keys: Sequence[str] | None = None,
variable_names: Sequence[str] = None,
context: str = None,
):
"""Utility to optionally convert dicts as returned from approximators and adapters into arrays."""
"""
# TODO
"""

if type(post_variables) is not type(prior_variables):
raise ValueError("You should either use dicts or tensors, but not separate types for your inputs.")
# Ensure that posterior and prior variables have the same type
if prior_variables is not None:
if type(post_variables) is not type(prior_variables):
raise ValueError("You should either use dicts or tensors, but not separate types for your inputs.")

# Filtering
if isinstance(post_variables, dict):
if post_variables.keys() != prior_variables.keys():
raise ValueError("Keys in your posterior / prior arrays should match.")
# Ensure that the keys of selected posterior and prior variables match
if prior_variables is not None:
if not (set(post_variables) <= set(prior_variables)):
raise ValueError("Keys in your posterior / prior arrays should match.")

# Use user-provided names instead of inferred ones
names = list(post_variables.keys()) if names is None else names
# If they match, users can further select the variables by using filter keys
filter_keys = list(post_variables.keys()) if filter_keys is None else filter_keys

post_variables = np.concatenate([v for k, v in post_variables.items() if k in names], axis=-1)
prior_variables = np.concatenate([v for k, v in prior_variables.items() if k in names], axis=-1)
# The variables will then be overridden with the filtered keys
post_variables = np.concatenate([v for k, v in post_variables.items() if k in filter_keys], axis=-1)
if prior_variables is not None:
prior_variables = np.concatenate([v for k, v in prior_variables.items() if k in filter_keys], axis=-1)

# Naming or Renaming
elif isinstance(post_variables, np.ndarray):
if names is not None:
if post_variables.shape[-1] != len(names) or prior_variables.shape[-1] != len(names):
raise ValueError("The length of the names list should match the number of target variables.")
else:
if context is not None:
names = [f"${context}_{{{i}}}$" for i in range(post_variables.shape[-1])]
# If there are filter_keys, check if their number is the same as that of the variables.
# If it does, check if there are sufficient variable names.
# If there are, then the variable names are adopted.
if variable_names is not None:
if post_variables.shape[-1] != len(variable_names) or prior_variables.shape[-1] != len(variable_names):
raise ValueError("The number of variable names should match the number of target variables.")

else: # Otherwise, we would assume that all variables are used for plotting.
if context is None:
if variable_names is None:
variable_names = [f"$\\theta_{{{i}}}$" for i in range(post_variables.shape[-1])]
else:
names = [f"$\\theta_{{{i}}}$" for i in range(post_variables.shape[-1])]

variable_names = [f"${context}_{{{i}}}$" for i in range(post_variables.shape[-1])]
else:
raise TypeError("Only dicts and tensors are supported as arguments.")

return dict(post_variables=post_variables, prior_variables=prior_variables, names=names, num_variables=len(names))
return dict(
post_variables=post_variables,
prior_variables=prior_variables,
variable_names=variable_names,
num_variables=len(variable_names),
)
15 changes: 11 additions & 4 deletions bayesflow/utils/plot_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,10 @@


def preprocess(
post_variables: dict[str, np.ndarray],
prior_variables: dict[str, np.ndarray],
names: Sequence[str] = None,
post_variables: dict[str, np.ndarray] | np.ndarray,
prior_variables: dict[str, np.ndarray] | np.ndarray,
filter_keys: Sequence[str] = None,
variable_names: Sequence[str] = None,
context: str = None,
num_col: int = None,
num_row: int = None,
Expand Down Expand Up @@ -43,7 +44,13 @@ def preprocess(
Whether the plots are stacked horizontally
"""

plot_data = dicts_to_arrays(post_variables, prior_variables, names, context)
plot_data = dicts_to_arrays(
post_variables=post_variables,
prior_variables=prior_variables,
filter_keys=filter_keys,
variable_names=variable_names,
context=context,
)
check_posterior_prior_shapes(plot_data["post_variables"], plot_data["prior_variables"])

# Configure layout
Expand Down
Loading
Loading