Skip to content

Commit

Permalink
make sure that plots work with only one parameter
Browse files Browse the repository at this point in the history
  • Loading branch information
Kucharssim committed Feb 21, 2024
1 parent fe63666 commit df9180c
Showing 1 changed file with 9 additions and 5 deletions.
14 changes: 9 additions & 5 deletions bayesflow/diagnostics.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,12 +147,13 @@ def plot_recovery(
if fig_size is None:
fig_size = (int(4 * n_col), int(4 * n_row))
f, axarr = plt.subplots(n_row, n_col, figsize=fig_size)

# turn axarr into 1D list
axarr = np.atleast_1d(axarr)
if n_col > 1 or n_row > 1:
axarr_it = axarr.flat
else:
# for 1x1, axarr is not a list -> turn it into one for use with enumerate
axarr_it = [axarr]
axarr_it = axarr

for i, ax in enumerate(axarr_it):
if i >= n_params:
Expand Down Expand Up @@ -337,12 +338,13 @@ def plot_z_score_contraction(
if fig_size is None:
fig_size = (int(4 * n_col), int(4 * n_row))
f, axarr = plt.subplots(n_row, n_col, figsize=fig_size)

# turn axarr into 1D list
axarr = np.atleast_1d(axarr)
if n_col > 1 or n_row > 1:
axarr_it = axarr.flat
else:
# for 1x1, axarr is not a list -> turn it into one for use with enumerate
axarr_it = [axarr]
axarr_it = axarr

# Loop and plot
for i, ax in enumerate(axarr_it):
Expand Down Expand Up @@ -480,6 +482,7 @@ def plot_sbc_ecdf(

# Initialize figure
f, ax = plt.subplots(n_row, n_col, figsize=fig_size)
ax = np.atleast_1d(ax)

# Plot individual ecdf of parameters
for j in range(ranks.shape[-1]):
Expand Down Expand Up @@ -657,7 +660,8 @@ def plot_sbc_histograms(
if fig_size is None:
fig_size = (int(5 * n_col), int(5 * n_row))
f, axarr = plt.subplots(n_row, n_col, figsize=fig_size)

axarr = np.atleast_1d(axarr)

# Compute ranks (using broadcasting)
ranks = np.sum(post_samples < prior_samples[:, np.newaxis, :], axis=1)

Expand Down

0 comments on commit df9180c

Please sign in to comment.