diff --git a/deepxde/data/pde_operator.py b/deepxde/data/pde_operator.py index 308b8f088..1bb97181f 100644 --- a/deepxde/data/pde_operator.py +++ b/deepxde/data/pde_operator.py @@ -275,9 +275,13 @@ def forward_call(trunk_input): ) if not isinstance(f, (list, tuple)): f = [f] - error_f = [fi[:, bcs_start[-1] :] for fi in f] + error_f = [fi[:, bcs_start[-1]:] for fi in f] # Each error has the shape (N1, ~N2) - losses = [loss_fn(bkd.zeros_like(error), error) for error in error_f] + for error in error_f: + error_i = [] + for i in range(error.shape[0]): + error_i.append(loss_fn(bkd.zeros_like(error[i]), error[i])) + losses.append(bkd.reduce_mean(bkd.stack(error_i, 0))) # BC loss error_bc = []