Skip to content

Commit

Permalink
update code
Browse files Browse the repository at this point in the history
  • Loading branch information
Jerry-Jzy committed Jan 1, 2025
1 parent 38b81b4 commit d261b57
Showing 1 changed file with 13 additions and 7 deletions.
20 changes: 13 additions & 7 deletions deepxde/data/pde_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,28 +263,34 @@ def _losses(self, outputs, loss_fn, inputs, model, num_func, aux=None):
# Use stack instead of as_tensor to keep the gradients.
losses = [bkd.reduce_mean(bkd.stack(loss, 0)) for loss in losses]
elif config.autodiff == "forward": # forward mode AD
outputs_shape = bkd.shape(outputs)
shape0, shape1 = outputs_shape[0], outputs_shape[1]
shape2 = model.net.num_outputs
shape0, shape1 = bkd.shape(outputs)[:2]

def forward_call(trunk_input):
output = aux[0]((inputs[0], trunk_input))
return bkd.reshape(output, (shape0 * shape1, shape2))
return bkd.reshape(output, (shape0 * shape1, model.net.num_outputs))

f = []
if self.pde.pde is not None:
# Each f has the shape (N1, N2)
f = self.pde.pde(
inputs[1],
(bkd.reshape(outputs, (shape0 * shape1, shape2)), forward_call),
bkd.reshape(model.net.auxiliary_vars, (shape0 * shape1, shape2)),
(
bkd.reshape(outputs, (shape0 * shape1, model.net.num_outputs)),
forward_call,
),
bkd.reshape(
model.net.auxiliary_vars,
(shape0 * shape1, model.net.num_outputs),
),
)
if not isinstance(f, (list, tuple)):
f = [f]
f = (
[bkd.reshape(fi, (shape0, shape1)) for fi in f]
if model.net.num_outputs == 1
else [bkd.reshape(fi, (shape0, shape1, shape2)) for fi in f]
else [
bkd.reshape(fi, (shape0, shape1, model.net.num_outputs)) for fi in f
]
)

# Each error has the shape (N1, ~N2)
Expand Down

0 comments on commit d261b57

Please sign in to comment.