Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Diagnostics module #235

Merged
merged 33 commits into from
Nov 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
adf3f5b
Migrate diagnostics module
jerrymhuang Nov 4, 2024
104f726
Add confusion_matrix
jerrymhuang Nov 4, 2024
9f66871
Merge branch 'dev' into diagnostics
jerrymhuang Nov 8, 2024
ae2a6a3
Remove duplicate code for layout configuration
jerrymhuang Nov 9, 2024
bf9066a
Simplify pre-processing for plot_recovery and plot_z_score_contraction
jerrymhuang Nov 9, 2024
b522a7e
Simplify prettify process for z_score and recovery
jerrymhuang Nov 9, 2024
609c89c
Simplify preprocessing for plot_sbc_ecdf and plot_sbc_histograms
jerrymhuang Nov 9, 2024
1beb585
Simplify labeling
jerrymhuang Nov 9, 2024
fd24587
Reformat
jerrymhuang Nov 9, 2024
8fde9e9
Make plot_distribution_2d more compatible
jerrymhuang Nov 9, 2024
a7fd398
Update quickstart notebook with working prior checks
jerrymhuang Nov 9, 2024
04bd262
Improve compatibility for plot_losses and plot_prior_2d
jerrymhuang Nov 11, 2024
a4d4452
Update quickstart notebook with loss trajectory
jerrymhuang Nov 12, 2024
18f0514
Minor fix of plot_utils, start adding tests for diagnostics
jerrymhuang Nov 13, 2024
05a4978
Pre-final version WIP
stefanradev93 Nov 13, 2024
0e956b5
Merge branch 'dev' into diagnostics
jerrymhuang Nov 13, 2024
3090ea7
Minor changes in 2d plots; update and test plot_z_score_contraction
jerrymhuang Nov 14, 2024
5cb5847
Update and test plot_sbc_histograms
jerrymhuang Nov 14, 2024
5462ecf
Update plot_calibration_curves (WIP)
jerrymhuang Nov 14, 2024
2868f96
Minor refactors: change global color schemes, complete type casts, fu…
jerrymhuang Nov 14, 2024
ad59eae
Add detailed callback for loss trajectory
jerrymhuang Nov 14, 2024
15bc4ab
Generalize preprocessing utilities from samples to variables
jerrymhuang Nov 14, 2024
16436c2
Generalize add_metric
jerrymhuang Nov 15, 2024
c25a67d
Remove redundant code segments related to prettify
jerrymhuang Nov 15, 2024
c063c67
Include add_titles and add_titles_and_labels; propagate variables as …
jerrymhuang Nov 15, 2024
d2a742e
Interim cleanup
jerrymhuang Nov 15, 2024
7875661
Add typing and fix plot_sbc_ecdf
jerrymhuang Nov 15, 2024
05ef11b
Minor fix of plot_samples_2d
jerrymhuang Nov 15, 2024
88cc1c4
Minor fix of plot_prior_2d
jerrymhuang Nov 15, 2024
d9ab082
Remove redundant code for axes flattening
jerrymhuang Nov 15, 2024
6683b25
Ensure consistent color scheme; incorporate sequence of labels
jerrymhuang Nov 15, 2024
9de0479
Bug fix for plot_losses
jerrymhuang Nov 15, 2024
9623061
Cleanup
jerrymhuang Nov 15, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ First, install one machine learning backend of choice. Note that BayesFlow **wil
- [Install TensorFlow](https://www.tensorflow.org/install)

If you don't know which backend to use, we recommend JAX to get started.
It is the fastest backend and already works pretty reliably with the current
It is the fastest backend and already works pretty reliably with the current
dev version of bayesflow.

Once installed, [set the backend environment variable as required by keras](https://keras.io/getting_started/#configuring-your-backend). For example, inside your Python script write:
Expand Down
9 changes: 9 additions & 0 deletions bayesflow/diagnostics/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from .plot_losses import plot_losses
from .plot_recovery import plot_recovery
from .plot_sbc_ecdf import plot_sbc_ecdf
from .plot_sbc_histograms import plot_sbc_histograms
from .plot_samples_2d import plot_samples_2d
from .plot_z_score_contraction import plot_z_score_contraction
from .plot_prior_2d import plot_prior_2d
from .plot_posterior_2d import plot_posterior_2d
from .plot_calibration_curves import plot_calibration_curves
113 changes: 113 additions & 0 deletions bayesflow/diagnostics/plot_calibration_curves.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
import numpy as np
import matplotlib.pyplot as plt

from typing import Sequence
from ..utils.comp_utils import expected_calibration_error
from ..utils.plot_utils import preprocess, add_titles_and_labels, add_metric, prettify_subplots


def plot_calibration_curves(
post_model_samples: dict[str, np.ndarray] | np.ndarray,
true_model_samples: dict[str, np.ndarray] | np.ndarray,
names: Sequence[str] = None,
num_bins: int = 10,
label_fontsize: int = 16,
title_fontsize: int = 18,
metric_fontsize: int = 14,
tick_fontsize: int = 12,
epsilon: float = 0.02,
figsize: Sequence[int] = None,
color: str = "#132a70",
num_col: int = None,
num_row: int = None,
) -> plt.Figure:
"""Plots the calibration curves, the ECEs and the marginal histograms of predicted posterior model probabilities
for a model comparison problem. The marginal histograms inform about the fraction of predictions in each bin.
Depends on the ``expected_calibration_error`` function for computing the ECE.

Parameters
----------
true_model_samples : np.ndarray of shape (num_data_sets, num_models)
The one-hot-encoded true model indices per data set.
post_model_samples : np.ndarray of shape (num_data_sets, num_models)
The predicted posterior model probabilities (PMPs) per data set.
names : list or None, optional, default: None
The model names for nice plot titles. Inferred if None.
num_bins : int, optional, default: 10
The number of bins to use for the calibration curves (and marginal histograms).
label_fontsize : int, optional, default: 16
The font size of the y-label and y-label texts
legend_fontsize : int, optional, default: 14
The font size of the legend text (ECE value)
title_fontsize : int, optional, default: 18
The font size of the title text. Only relevant if `stacked=False`
tick_fontsize : int, optional, default: 12
The font size of the axis ticklabels
epsilon : float, optional, default: 0.02
A small amount to pad the [0, 1]-bounded axes from both side.
figsize : tuple or None, optional, default: None
The figure size passed to the ``matplotlib`` constructor. Inferred if ``None``
color : str, optional, default: '#8f2727'
The color of the calibration curves
num_row : int, optional, default: None
The number of rows for the subplots. Dynamically determined if None.
num_col : int, optional, default: None
The number of columns for the subplots. Dynamically determined if None.

Returns
-------
fig : plt.Figure - the figure instance for optional saving
"""

plot_data = preprocess(post_model_samples, true_model_samples, names, num_col, num_row, figsize, context="M")

# Compute calibration
cal_errors, true_probs, pred_probs = expected_calibration_error(
plot_data["prior_samples"], plot_data["post_samples"], num_bins
)

for j, ax in enumerate(plot_data["axes"].flat):
# Plot calibration curve
ax[j].plot(pred_probs[j], true_probs[j], "o-", color=color)

# Plot PMP distribution over bins
uniform_bins = np.linspace(0.0, 1.0, num_bins + 1)
norm_weights = np.ones_like(plot_data["post_samples"]) / len(plot_data["post_samples"])
ax[j].hist(
plot_data["post_samples"][:, j], bins=uniform_bins, weights=norm_weights[:, j], color="grey", alpha=0.3
)

# Plot AB line
ax[j].plot((0, 1), (0, 1), "--", color="black", alpha=0.9)

# Tweak plot
ax[j].set_xlim([0 - epsilon, 1 + epsilon])
ax[j].set_ylim([0 - epsilon, 1 + epsilon])
ax[j].set_xticks([0.0, 0.2, 0.4, 0.6, 0.8, 1.0])
ax[j].set_yticks([0.0, 0.2, 0.4, 0.6, 0.8, 1.0])

# Add ECE label
add_metric(
ax[j],
metric_text=r"$\widehat{{\mathrm{{ECE}}}}$ = {0:.3f}",
metric_value=cal_errors[j],
metric_fontsize=metric_fontsize,
)

# Prettify
prettify_subplots(axes=plot_data["axes"], num_subplots=plot_data["num_variables"], tick_fontsize=tick_fontsize)

# Only add x-labels to the bottom row
add_titles_and_labels(
axes=plot_data["axes"],
num_row=plot_data["num_row"],
num_col=plot_data["num_col"],
title=plot_data["names"],
xlabel="Predicted Probability",
ylabel="True Probability",
title_fontsize=title_fontsize,
label_fontsize=label_fontsize,
)

plot_data["fig"].tight_layout()
return plot_data["fig"]
121 changes: 121 additions & 0 deletions bayesflow/diagnostics/plot_confusion_matrix.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
import matplotlib.colors
import matplotlib.pyplot as plt
import numpy as np

from typing import Sequence

from keras import ops
from sklearn.metrics import confusion_matrix
from matplotlib.colors import LinearSegmentedColormap

from bayesflow.utils.plot_utils import make_figure


def plot_confusion_matrix(
true_models: dict[str, np.ndarray] | np.ndarray,
pred_models: dict[str, np.ndarray] | np.ndarray,
model_names: Sequence[str] = None,
fig_size: tuple = (5, 5),
label_fontsize: int = 16,
title_fontsize: int = 18,
value_fontsize: int = 10,
tick_fontsize: int = 12,
xtick_rotation: int = None,
ytick_rotation: int = None,
normalize: bool = True,
cmap: matplotlib.colors.Colormap | str = None,
title: bool = True,
) -> plt.Figure:
"""Plots a confusion matrix for validating a neural network trained for Bayesian model comparison.

Parameters
----------
true_models : np.ndarray of shape (num_data_sets, num_models)
The one-hot-encoded true model indices per data set.
pred_models : np.ndarray of shape (num_data_sets, num_models)
The predicted posterior model probabilities (PMPs) per data set.
model_names : list or None, optional, default: None
The model names for nice plot titles. Inferred if None.
fig_size : tuple or None, optional, default: (5, 5)
The figure size passed to the ``matplotlib`` constructor. Inferred if ``None``
label_fontsize : int, optional, default: 16
The font size of the y-label and y-label texts
title_fontsize : int, optional, default: 18
The font size of the title text.
value_fontsize : int, optional, default: 10
The font size of the text annotations and the colorbar tick labels.
tick_fontsize : int, optional, default: 12
The font size of the axis label and model name texts.
xtick_rotation: int, optional, default: None
Rotation of x-axis tick labels (helps with long model names).
ytick_rotation: int, optional, default: None
Rotation of y-axis tick labels (helps with long model names).
normalize : bool, optional, default: True
A flag for normalization of the confusion matrix.
If True, each row of the confusion matrix is normalized to sum to 1.
cmap : matplotlib.colors.Colormap or str, optional, default: None
Colormap to be used for the cells. If a str, it should be the name of a registered colormap,
e.g., 'viridis'. Default colormap matches the BayesFlow defaults by ranging from white to red.
title : bool, optional, default True
A flag for adding 'Confusion Matrix' above the matrix.

Returns
-------
fig : plt.Figure - the figure instance for optional saving
"""

if model_names is None:
num_models = true_models.shape[-1]
model_names = [rf"$M_{{{m}}}$" for m in range(1, num_models + 1)]

if cmap is None:
cmap = LinearSegmentedColormap.from_list("", ["white", "#132a70"])

# Flatten input
true_models = ops.argmax(true_models, axis=1)
pred_models = ops.argmax(pred_models, axis=1)

# Compute confusion matrix
cm = confusion_matrix(true_models, pred_models)

# if normalize:
# # Sum along rows and keep dimensions for broadcasting
# cm_sum = ops.sum(cm, axis=1, keepdims=True)
#
# # Broadcast division for normalization
# cm_normalized = cm / cm_sum

# Initialize figure
fig, ax = make_figure(1, 1, figsize=fig_size)
im = ax.imshow(cm, interpolation="nearest", cmap=cmap)
cbar = ax.figure.colorbar(im, ax=ax, shrink=0.75)

cbar.ax.tick_params(labelsize=value_fontsize)

ax.set(xticks=ops.arange(cm.shape[1]), yticks=ops.arange(cm.shape[0]))
ax.set_xticklabels(model_names, fontsize=tick_fontsize)
if xtick_rotation:
plt.xticks(rotation=xtick_rotation, ha="right")
ax.set_yticklabels(model_names, fontsize=tick_fontsize)
if ytick_rotation:
plt.yticks(rotation=ytick_rotation)
ax.set_xlabel("Predicted model", fontsize=label_fontsize)
ax.set_ylabel("True model", fontsize=label_fontsize)

# Loop over data dimensions and create text annotations
fmt = ".2f" if normalize else "d"
thresh = cm.max() / 2.0
for i in range(cm.shape[0]):
for j in range(cm.shape[1]):
ax.text(
j,
i,
format(cm[i, j], fmt),
fontsize=value_fontsize,
ha="center",
va="center",
color="white" if cm[i, j] > thresh else "black",
)
if title:
ax.set_title("Confusion Matrix", fontsize=title_fontsize)
return fig
142 changes: 142 additions & 0 deletions bayesflow/diagnostics/plot_losses.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

from typing import Sequence
from ..utils.plot_utils import make_figure, add_titles_and_labels


def plot_losses(
train_losses: pd.DataFrame | np.ndarray,
val_losses: pd.DataFrame | np.ndarray = None,
moving_average: bool = False,
per_training_step: bool = False,
ma_window_fraction: float = 0.01,
figsize: Sequence[float] = None,
train_color: str = "#132a70",
val_color: str = "black",
lw_train: float = 2.0,
lw_val: float = 3.0,
legend_fontsize: int = 14,
label_fontsize: int = 14,
title_fontsize: int = 16,
) -> plt.Figure:
"""
A generic helper function to plot the losses of a series of training epochs
and runs.

Parameters
----------

train_losses : pd.DataFrame
The (plottable) history as returned by a train_[...] method of a
``Trainer`` instance.
Alternatively, you can just pass a data frame of validation losses
instead of train losses, if you only want to plot the validation loss.
val_losses : pd.DataFrame or None, optional, default: None
The (plottable) validation history as returned by a train_[...] method
of a ``Trainer`` instance.
If left ``None``, only train losses are plotted. Should have the same
number of columns as ``train_losses``.
moving_average : bool, optional, default: False
A flag for adding a moving average line of the train_losses.
per_training_step : bool, optional, default: False
A flag for making loss trajectory detailed (to training steps) rather than per epoch.
ma_window_fraction : int, optional, default: 0.01
Window size for the moving average as a fraction of total
training steps.
figsize : tuple or None, optional, default: None
The figure size passed to the ``matplotlib`` constructor.
Inferred if ``None``
train_color : str, optional, default: '#8f2727'
The color for the train loss trajectory
val_color : str, optional, default: black
The color for the optional validation loss trajectory
lw_train : int, optional, default: 2
The linewidth for the training loss curve
lw_val : int, optional, default: 3
The linewidth for the validation loss curve
legend_fontsize : int, optional, default: 14
The font size of the legend text
label_fontsize : int, optional, default: 14
The font size of the y-label text
title_fontsize : int, optional, default: 16
The font size of the title text

Returns
-------
f : plt.Figure - the figure instance for optional saving

Raises
------
AssertionError
If the number of columns in ``train_losses`` does not match the
number of columns in ``val_losses``.
"""
if isinstance(train_losses, np.ndarray):
train_losses = pd.DataFrame(train_losses)

if isinstance(val_losses, np.ndarray):
val_losses = pd.DataFrame(val_losses)

# Determine the number of rows for plot
num_row = len(train_losses.columns)

# Initialize figure
fig, axes = make_figure(num_row=num_row, num_col=1, figsize=(16, int(4 * num_row) if figsize is None else figsize))

# Get the number of steps as an array
train_step_index = np.arange(1, len(train_losses) + 1)
if val_losses is not None:
val_step = int(np.floor(len(train_losses) / len(val_losses)))
val_step_index = train_step_index[(val_step - 1) :: val_step]

# If unequal length due to some reason, attempt a fix
if val_step_index.shape[0] > val_losses.shape[0]:
val_step_index = val_step_index[: val_losses.shape[0]]

# Loop through loss entries and populate plot
looper = [axes] if num_row == 1 else axes.flat
for i, ax in enumerate(looper):
# Plot train curve
ax.plot(train_step_index, train_losses.iloc[:, i], color=train_color, lw=lw_train, alpha=0.9, label="Training")
if moving_average and train_losses.columns[i] == "Loss":
moving_average_window = int(train_losses.shape[0] * ma_window_fraction)
smoothed_loss = train_losses.iloc[:, i].rolling(window=moving_average_window).mean()
ax.plot(train_step_index, smoothed_loss, color="grey", lw=lw_train, label="Training (Moving Average)")

# Plot optional val curve
if val_losses is not None:
if i < val_losses.shape[1]:
ax.plot(
val_step_index,
val_losses.iloc[:, i],
linestyle="--",
marker="o",
color=val_color,
lw=lw_val,
label="Validation",
)

sns.despine(ax=ax)
ax.grid(alpha=0.5)

# Only add legend if there is a validation curve
if val_losses is not None or moving_average:
ax.legend(fontsize=legend_fontsize)

# Schmuck
add_titles_and_labels(
axes=np.atleast_1d(axes),
num_row=num_row,
num_col=1,
title=train_losses.columns if num_row > 1 else ["Training Loss"],
xlabel="Training step #" if per_training_step else "Training epoch #",
ylabel="Value",
title_fontsize=title_fontsize,
label_fontsize=label_fontsize,
)

fig.tight_layout()
return fig
Loading
Loading