Skip to content

Commit

Permalink
update code
Browse files Browse the repository at this point in the history
  • Loading branch information
Jerry-Jzy committed Jan 1, 2025
1 parent 5d8ec8b commit 7d36bb4
Showing 1 changed file with 5 additions and 7 deletions.
12 changes: 5 additions & 7 deletions deepxde/data/pde_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,10 +264,8 @@ def _losses(self, outputs, loss_fn, inputs, model, num_func, aux=None):
losses = [bkd.reduce_mean(bkd.stack(loss, 0)) for loss in losses]
elif config.autodiff == "forward": # forward mode AD
if model.net.num_outputs == 1:
is_multi_outputs = False
shape0, shape1 = outputs.shape[0], outputs.shape[1]
else:
is_multi_outputs = True
shape0, shape1, shape2 = (
outputs.shape[0],
outputs.shape[1],
Expand All @@ -276,11 +274,11 @@ def _losses(self, outputs, loss_fn, inputs, model, num_func, aux=None):

def forward_call(trunk_input):
output = aux[0]((inputs[0], trunk_input))
if is_multi_outputs:
return bkd.reshape(output, (shape0 * shape1, shape2))
return bkd.reshape(output, (shape0 * shape1, 1))
if model.net.num_outputs == 1:
return bkd.reshape(output, (shape0 * shape1, 1))
return bkd.reshape(output, (shape0 * shape1, shape2))

if not is_multi_outputs:
if model.net.num_outputs == 1:
outputs = bkd.reshape(outputs, (shape0 * shape1, 1))
auxiliary_vars = bkd.reshape(
model.net.auxiliary_vars, (shape0 * shape1, 1)
Expand All @@ -298,7 +296,7 @@ def forward_call(trunk_input):
if not isinstance(f, (list, tuple)):
f = [f]

if not is_multi_outputs:
if model.net.num_outputs == 1:
outputs = bkd.reshape(outputs, (shape0, shape1))
f = [bkd.reshape(fi, (shape0, shape1)) for fi in f]
else:
Expand Down

0 comments on commit 7d36bb4

Please sign in to comment.