Skip to content

Commit

Permalink
Modify plot_loss_history to handle irregular array sizes (#1571)
Browse files Browse the repository at this point in the history
  • Loading branch information
cozy-hn authored Nov 25, 2023
1 parent 26860b1 commit c55e5d3
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions deepxde/utils/external.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,8 +192,10 @@ def plot_loss_history(loss_history, fname=None):
fname (string): If `fname` is a string (e.g., 'loss_history.png'), then save the
figure to the file of the file name `fname`.
"""
loss_train = np.sum(loss_history.loss_train, axis=1)
loss_test = np.sum(loss_history.loss_test, axis=1)
# np.sum(loss_history.loss_train, axis=1) is error-prone for arrays of varying lengths.
# Handle irregular array sizes.
loss_train = np.array([np.sum(loss) for loss in loss_history.loss_train])
loss_test = np.array([np.sum(loss) for loss in loss_history.loss_test])

plt.figure()
plt.semilogy(loss_history.steps, loss_train, label="Train loss")
Expand Down

0 comments on commit c55e5d3

Please sign in to comment.