
Commit

update code
Jerry-Jzy committed Jan 5, 2025
1 parent 7034cbe commit de89c12
Showing 3 changed files with 7 additions and 7 deletions.
2 changes: 1 addition & 1 deletion deepxde/gradients/gradients.py
@@ -19,7 +19,7 @@ def jacobian(ys, xs, i=None, j=None):
           computation.
 
     Args:
-        ys: Output Tensor of shape (batch_size, dim_y).
+        ys: Output Tensor of shape (batch_size, dim_y) or (batch_size1, batch_size2, dim_y).
         xs: Input Tensor of shape (batch_size, dim_x).
         i (int or None): `i`th row. If `i` is ``None``, returns the `j`th column
             J[:, `j`].
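
For context, a minimal usage sketch of the documented call. This assumes the DeepXDE pytorch backend is selected; the tiny linear "network" and the concrete shapes are hypothetical and only mirror the shapes named in the docstring above.

import deepxde as dde
import torch

net = torch.nn.Linear(2, 3)                 # dim_x = 2, dim_y = 3 (illustrative)
xs = torch.rand(8, 2, requires_grad=True)   # (batch_size, dim_x)
ys = net(xs)                                # (batch_size, dim_y)

# J[i, j] = d ys_i / d xs_j; with i=0, j=1 the result has shape (batch_size, 1).
dy0_dx1 = dde.grad.jacobian(ys, xs, i=0, j=1)

# Per the updated docstring, ys may also carry an extra leading batch
# dimension, i.e. shape (batch_size1, batch_size2, dim_y).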
4 changes: 2 additions & 2 deletions deepxde/gradients/gradients_forward.py
@@ -87,14 +87,14 @@ def grad_fn(x):
         # Compute J[i, j]
         if (i, j) not in self.J:
             if backend_name == "tensorflow.compat.v1":
-                self.J[i, j] = self.J[j][:, i : i + 1]
+                self.J[i, j] = self.J[j][..., i : i + 1]
             elif backend_name in ["tensorflow", "pytorch", "jax"]:
                 # In backend tensorflow/pytorch/jax, a tuple of a tensor/tensor/array
                 # and a callable is returned, so that it is consistent with the argument,
                 # which is also a tuple. This is useful for further computation, e.g.,
                 # Hessian.
                 self.J[i, j] = (
-                    self.J[j][0][:, i : i + 1],
+                    self.J[j][0][..., i : i + 1],
                     lambda x: self.J[j][1](x)[i : i + 1],
                 )
         return self.J[i, j]
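
The two changes above swap ":" for "..." in the column slice. A quick plain-PyTorch check (shapes hypothetical) of why that matters: ellipsis indexing always selects along the last axis, so slicing column i of a cached Jacobian works whether or not an extra leading batch dimension is present.

import torch

i = 1
J2d = torch.rand(8, 3)      # (batch_size, dim_y)
J3d = torch.rand(4, 8, 3)   # (batch_size1, batch_size2, dim_y)

print(J2d[..., i : i + 1].shape)  # torch.Size([8, 1])
print(J3d[..., i : i + 1].shape)  # torch.Size([4, 8, 1])
# J3d[:, i : i + 1] would instead slice the second axis and give shape (4, 1, 3).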
8 changes: 4 additions & 4 deletions deepxde/gradients/jacobian.py
@@ -20,22 +20,22 @@ def __init__(self, ys, xs):
         self.xs = xs
 
         if backend_name in ["tensorflow.compat.v1", "paddle"]:
-            self.dim_y = ys.shape[1]
+            self.dim_y = ys.shape[-1]
         elif backend_name in ["tensorflow", "pytorch"]:
             if config.autodiff == "reverse":
                 # For reverse-mode AD, only a tensor is passed.
-                self.dim_y = ys.shape[1]
+                self.dim_y = ys.shape[-1]
             elif config.autodiff == "forward":
                 # For forward-mode AD, a tuple of a tensor and a callable is passed,
                 # similar to backend jax.
-                self.dim_y = ys[0].shape[1]
+                self.dim_y = ys[0].shape[-1]
         elif backend_name == "jax":
             # For backend jax, a tuple of a jax array and a callable is passed as one of
             # the arguments, since jax does not support computational graph explicitly.
             # The array is used to control the dimensions and the callable is used to
             # obtain the derivative function, which can be used to compute the
             # derivatives.
-            self.dim_y = ys[0].shape[1]
+            self.dim_y = ys[0].shape[-1]
         self.dim_x = xs.shape[1]
 
         self.J = {}
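
Same idea for the output dimension: shape[-1] reads dim_y from the last axis for both supported ys layouts, whereas shape[1] is only dim_y in the 2-D case. A small plain-PyTorch check with hypothetical shapes:

import torch

ys_2d = torch.rand(8, 3)     # (batch_size, dim_y)
ys_3d = torch.rand(4, 8, 3)  # (batch_size1, batch_size2, dim_y)

print(ys_2d.shape[-1])  # 3 (same as ys_2d.shape[1])
print(ys_3d.shape[-1])  # 3 (ys_3d.shape[1] is 8, the second batch size, not dim_y)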
