Skip to content

Commit

Permalink
Refactor the computation of dim_y
Browse files Browse the repository at this point in the history
  • Loading branch information
Jerry-Jzy committed Dec 29, 2024
1 parent 101bc31 commit af99377
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 8 deletions.
36 changes: 32 additions & 4 deletions deepxde/data/pde_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,18 +263,46 @@ def _losses(self, outputs, loss_fn, inputs, model, num_func, aux=None):
# Use stack instead of as_tensor to keep the gradients.
losses = [bkd.reduce_mean(bkd.stack(loss, 0)) for loss in losses]
elif config.autodiff == "forward": # forward mode AD
if bkd.ndim(outputs) == 2:
is_multi_outputs = False
shape0, shape1 = outputs.shape[0], outputs.shape[1]
elif bkd.ndim(outputs) == 3:
is_multi_outputs = True
shape0, shape1, shape2 = (
outputs.shape[0],
outputs.shape[1],
outputs.shape[2],
)

def forward_call(trunk_input):
return aux[0]((inputs[0], trunk_input))
output = aux[0]((inputs[0], trunk_input))
if not is_multi_outputs:
return output.reshape(shape0 * shape1, 1)
elif is_multi_outputs:
return output.reshape(shape0 * shape1, shape2)

if not is_multi_outputs:
outputs = outputs.reshape(shape0 * shape1, 1)
auxiliary_vars = model.net.auxiliary_vars.reshape(shape0 * shape1, 1)
elif is_multi_outputs:
outputs = outputs.reshape(shape0 * shape1, shape2)
auxiliary_vars = model.net.auxiliary_vars.reshape(
shape0 * shape1, shape2
)

f = []
if self.pde.pde is not None:
# Each f has the shape (N1, N2)
f = self.pde.pde(
inputs[1], (outputs, forward_call), model.net.auxiliary_vars
)
f = self.pde.pde(inputs[1], (outputs, forward_call), auxiliary_vars)
if not isinstance(f, (list, tuple)):
f = [f]

if not is_multi_outputs:
outputs = outputs.reshape(shape0, shape1)
f = [fi.reshape(shape0, shape1) for fi in f]
elif is_multi_outputs:
outputs = outputs.reshape(shape0, shape1, shape2)
f = [fi.reshape(shape0, shape1, shape2) for fi in f]
# Each error has the shape (N1, ~N2)
error_f = [fi[:, bcs_start[-1] :] for fi in f]
for error in error_f:
Expand Down
5 changes: 1 addition & 4 deletions deepxde/gradients/jacobian.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,7 @@ def __init__(self, ys, xs):
elif config.autodiff == "forward":
# For forward-mode AD, a tuple of a tensor and a callable is passed,
# similar to backend jax.
if bkd.ndim(ys[0]) == 2:
self.dim_y = 1
elif bkd.ndim(ys[0]) == 3:
self.dim_y = ys[0].shape[2]
self.dim_y = ys[0].shape[-1]
elif backend_name == "jax":
# For backend jax, a tuple of a jax array and a callable is passed as one of
# the arguments, since jax does not support computational graph explicitly.
Expand Down

0 comments on commit af99377

Please sign in to comment.