Skip to content

Commit

Permalink
rename diagnostics plots (#265)
Browse files Browse the repository at this point in the history
* rename diagnostics plots

* Slight name change

---------

Co-authored-by: stefanradev93 <[email protected]>
  • Loading branch information
paul-buerkner and stefanradev93 authored Dec 3, 2024
1 parent 6fb8b83 commit cddbce8
Show file tree
Hide file tree
Showing 18 changed files with 475 additions and 348 deletions.
20 changes: 11 additions & 9 deletions bayesflow/diagnostics/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from .plot_losses import plot_losses
from .plot_recovery import plot_recovery
from .plot_sbc_ecdf import plot_sbc_ecdf
from .plot_sbc_histograms import plot_sbc_histograms
from .plot_samples_2d import plot_samples_2d
from .plot_z_score_contraction import plot_z_score_contraction
from .plot_prior_2d import plot_prior_2d
from .plot_posterior_2d import plot_posterior_2d
from .plot_calibration_curves import plot_calibration_curves
from .plots import calibration_ecdf
from .plots import calibration_histogram
from .plots import loss
from .plots import mc_calibration
from .plots import mc_confusion_matrix
from .plots import mmd_hypothesis_test
from .plots import pairs_posterior
from .plots import pairs_prior
from .plots import pairs_samples
from .plots import recovery
from .plots import z_score_contraction
11 changes: 11 additions & 0 deletions bayesflow/diagnostics/plots/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from .calibration_ecdf import calibration_ecdf
from .calibration_histogram import calibration_histogram
from .loss import loss
from .mc_calibration import mc_calibration
from .mc_confusion_matrix import mc_confusion_matrix
from .mmd_hypothesis_test import mmd_hypothesis_test
from .pairs_posterior import pairs_posterior
from .pairs_prior import pairs_prior
from .pairs_samples import pairs_samples
from .recovery import recovery
from .z_score_contraction import z_score_contraction
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@
import matplotlib.pyplot as plt

from typing import Sequence
from ..utils.plot_utils import preprocess, add_titles_and_labels, prettify_subplots
from ..utils.ecdf import simultaneous_ecdf_bands
from ..utils.ecdf.ranks import fractional_ranks, distance_ranks
from ...utils.plot_utils import preprocess, add_titles_and_labels, prettify_subplots
from ...utils.ecdf import simultaneous_ecdf_bands
from ...utils.ecdf.ranks import fractional_ranks, distance_ranks


def plot_sbc_ecdf(
def calibration_ecdf(
post_samples: dict[str, np.ndarray] | np.ndarray,
prior_samples: dict[str, np.ndarray] | np.ndarray,
filter_keys: Sequence[str] = None,
Expand Down Expand Up @@ -61,12 +61,15 @@ def plot_sbc_ecdf(
stacked : bool, optional, default: False
If `True`, all ECDFs will be plotted on the same plot.
If `False`, each ECDF will have its own subplot,
similar to the behavior of `plot_sbc_histograms`.
similar to the behavior of `calibration_histogram`.
rank_type : str, optional, default: 'fractional'
If `fractional` (default), the ranks are computed as the fraction of posterior samples that are smaller than
the prior. If `distance`, the ranks are computed as the fraction of posterior samples that are closer to
a reference points (default here is the origin). You can pass a reference array in the same shape as the
`prior_samples` array by setting `references` in the ``ranks_kwargs``. This is motivated by [2].
If `fractional` (default), the ranks are computed as the fraction
of posterior samples that are smaller than the prior.
If `distance`, the ranks are computed as the fraction of posterior
samples that are closer to a reference points (default here is the origin).
You can pass a reference array in the same shape as the
`prior_samples` array by setting `references` in the ``ranks_kwargs``.
This is motivated by [2].
variable_names : list or None, optional, default: None
The parameter names for nice plot titles.
Inferred if None. Only relevant if `stacked=False`.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from bayesflow.utils import preprocess, add_titles_and_labels, prettify_subplots


def plot_sbc_histograms(
def calibration_histogram(
post_samples: dict[str, np.ndarray] | np.ndarray,
prior_samples: dict[str, np.ndarray] | np.ndarray,
filter_keys: Sequence[str] = None,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@
import matplotlib.pyplot as plt


from ..utils.plot_utils import make_figure, add_titles_and_labels
from ...utils.plot_utils import make_figure, add_titles_and_labels


def plot_losses(
def loss(
train_losses: pd.DataFrame | np.ndarray,
val_losses: pd.DataFrame | np.ndarray = None,
moving_average: bool = False,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@
from bayesflow.utils import expected_calibration_error, preprocess, add_titles_and_labels, add_metric, prettify_subplots


def plot_calibration_curves(
post_model_samples: dict[str, np.ndarray] | np.ndarray,
true_model_samples: dict[str, np.ndarray] | np.ndarray,
def mc_calibration(
pred_models: dict[str, np.ndarray] | np.ndarray,
true_models: dict[str, np.ndarray] | np.ndarray,
names: Sequence[str] = None,
num_bins: int = 10,
label_fontsize: int = 16,
Expand All @@ -28,11 +28,11 @@ def plot_calibration_curves(
Parameters
----------
true_model_samples : np.ndarray of shape (num_data_sets, num_models)
true_models : np.ndarray of shape (num_data_sets, num_models)
The one-hot-encoded true model indices per data set.
post_model_samples : np.ndarray of shape (num_data_sets, num_models)
pred_models : np.ndarray of shape (num_data_sets, num_models)
The predicted posterior model probabilities (PMPs) per data set.
names : list or None, optional, default: None
names : list or None, optional, default: None
The model names for nice plot titles. Inferred if None.
num_bins : int, optional, default: 10
The number of bins to use for the calibration curves (and marginal histograms).
Expand Down Expand Up @@ -60,7 +60,7 @@ def plot_calibration_curves(
fig : plt.Figure - the figure instance for optional saving
"""

plot_data = preprocess(post_model_samples, true_model_samples, names, num_col, num_row, figsize, context="M")
plot_data = preprocess(pred_models, true_models, names, num_col, num_row, figsize, context="M")

# Compute calibration
cal_errors, true_probs, pred_probs = expected_calibration_error(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from bayesflow.utils.plot_utils import make_figure


def plot_confusion_matrix(
def mc_confusion_matrix(
true_models: dict[str, np.ndarray] | np.ndarray,
pred_models: dict[str, np.ndarray] | np.ndarray,
model_names: Sequence[str] = None,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from keras import ops


def plot_mmd_hypothesis_test(
def mmd_hypothesis_test(
mmd_null: np.ndarray,
mmd_observed: float = None,
alpha_level: float = 0.05,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@

from matplotlib.lines import Line2D

from .plot_samples_2d import plot_samples_2d
from .pairs_samples import pairs_samples


def plot_posterior_2d(
def pairs_posterior(
post_samples: np.ndarray,
prior_samples: np.ndarray = None,
prior=None,
Expand Down Expand Up @@ -70,7 +70,7 @@ def plot_posterior_2d(

# Plot posterior first
context = ""
g = plot_samples_2d(
g = pairs_samples(
post_samples, context=context, variable_names=variable_names, render=False, height=height, **kwargs
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@
import seaborn as sns

from bayesflow.simulators import Simulator
from .plot_samples_2d import plot_samples_2d
from .pairs_samples import pairs_samples


def plot_prior_2d(
def pairs_prior(
simulator: Simulator,
variable_names: Sequence[str] | str = None,
num_samples: int = 2000,
Expand Down Expand Up @@ -43,6 +43,6 @@ def plot_prior_2d(
if isinstance(samples, dict):
samples = samples["theta"]

return plot_samples_2d(
return pairs_samples(
samples, context="Prior", height=height, color=color, param_names=variable_names, render=True, **kwargs
)
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from bayesflow.utils.dict_utils import dicts_to_arrays


def plot_samples_2d(
def pairs_samples(
samples: dict[str, np.ndarray] | np.ndarray = None,
filter_keys: Sequence[str] = None,
context: str = None,
Expand All @@ -27,7 +27,8 @@ def plot_samples_2d(
samples : dict[str, Tensor], default: None
Sample draws from any dataset
context : str, default: None
The context that the sample represents
The context that the sample represents. If specified,
should usually either be `Prior` or `Posterior`.
height : float, optional, default: 2.5
The height of the pair plot
color : str, optional, default : '#8f2727'
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from bayesflow.utils import preprocess, prettify_subplots, make_quadratic, add_titles_and_labels, add_metric


def plot_recovery(
def recovery(
post_samples: dict[str, np.ndarray] | np.ndarray,
prior_samples: dict[str, np.ndarray] | np.ndarray,
filter_keys: Sequence[str] = None,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from bayesflow.utils import preprocess, add_titles_and_labels, prettify_subplots


def plot_z_score_contraction(
def z_score_contraction(
post_samples: dict[str, np.ndarray] | np.ndarray,
prior_samples: dict[str, np.ndarray] | np.ndarray,
filter_keys: Sequence[str] = None,
Expand Down
Loading

0 comments on commit cddbce8

Please sign in to comment.