From df9180c7f1f9fd1231bfb87b3f49d39088c130c6 Mon Sep 17 00:00:00 2001 From: Kucharssim Date: Wed, 21 Feb 2024 09:46:41 +0100 Subject: [PATCH] make sure that plots work with only one parameter --- bayesflow/diagnostics.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/bayesflow/diagnostics.py b/bayesflow/diagnostics.py index ea40e90dc..a0dc5a8cd 100644 --- a/bayesflow/diagnostics.py +++ b/bayesflow/diagnostics.py @@ -147,12 +147,13 @@ def plot_recovery( if fig_size is None: fig_size = (int(4 * n_col), int(4 * n_row)) f, axarr = plt.subplots(n_row, n_col, figsize=fig_size) + # turn axarr into 1D list + axarr = np.atleast_1d(axarr) if n_col > 1 or n_row > 1: axarr_it = axarr.flat else: - # for 1x1, axarr is not a list -> turn it into one for use with enumerate - axarr_it = [axarr] + axarr_it = axarr for i, ax in enumerate(axarr_it): if i >= n_params: @@ -337,12 +338,13 @@ def plot_z_score_contraction( if fig_size is None: fig_size = (int(4 * n_col), int(4 * n_row)) f, axarr = plt.subplots(n_row, n_col, figsize=fig_size) + # turn axarr into 1D list + axarr = np.atleast_1d(axarr) if n_col > 1 or n_row > 1: axarr_it = axarr.flat else: - # for 1x1, axarr is not a list -> turn it into one for use with enumerate - axarr_it = [axarr] + axarr_it = axarr # Loop and plot for i, ax in enumerate(axarr_it): @@ -480,6 +482,7 @@ def plot_sbc_ecdf( # Initialize figure f, ax = plt.subplots(n_row, n_col, figsize=fig_size) + ax = np.atleast_1d(ax) # Plot individual ecdf of parameters for j in range(ranks.shape[-1]): @@ -657,7 +660,8 @@ def plot_sbc_histograms( if fig_size is None: fig_size = (int(5 * n_col), int(5 * n_row)) f, axarr = plt.subplots(n_row, n_col, figsize=fig_size) - + axarr = np.atleast_1d(axarr) + # Compute ranks (using broadcasting) ranks = np.sum(post_samples < prior_samples[:, np.newaxis, :], axis=1)