From 39f5af438e2924c38813faa0eba3762d28ebb242 Mon Sep 17 00:00:00 2001
From: Jerry-Jzy
Date: Fri, 3 Jan 2025 11:12:52 -0500
Subject: [PATCH] rollback

Roll back the (N1, N2, num_outputs) reshaping used by forward-mode AD in
PDEOperatorCartesianProd: the PDE function again receives the raw 2D output
of shape (N1, N2) together with a callable, and the forward-mode Jacobian
bookkeeping returns to plain 2D slicing.
---
 deepxde/data/pde_operator.py           | 22 +++-------------------
 deepxde/gradients/gradients.py         |  2 +-
 deepxde/gradients/gradients_forward.py |  4 ++--
 deepxde/gradients/jacobian.py          |  2 +-
 4 files changed, 7 insertions(+), 23 deletions(-)

diff --git a/deepxde/data/pde_operator.py b/deepxde/data/pde_operator.py
index 4b975fe54..f51888e6a 100644
--- a/deepxde/data/pde_operator.py
+++ b/deepxde/data/pde_operator.py
@@ -3,7 +3,6 @@
 from .data import Data
 from .sampler import BatchSampler
 from .. import backend as bkd
-from ..backend import backend_name
 from .. import config
 from ..utils import run_if_all_none
 
@@ -264,33 +263,18 @@ def _losses(self, outputs, loss_fn, inputs, model, num_func, aux=None):
             # Use stack instead of as_tensor to keep the gradients.
             losses = [bkd.reduce_mean(bkd.stack(loss, 0)) for loss in losses]
         elif config.autodiff == "forward":  # forward mode AD
-            batchsize1, batchsize2 = bkd.shape(outputs)[:2]
-            shape_3d = (batchsize1, batchsize2, model.net.num_outputs)
 
-            # Uniformly reshape the output into the shape (N1, N2, num_outputs),
             def forward_call(trunk_input):
-                output = aux[0]((inputs[0], trunk_input))
-                return bkd.reshape(output, shape_3d)
+                return aux[0]((inputs[0], trunk_input))
 
             f = []
             if self.pde.pde is not None:
-                if backend_name in ["tensorflow.compat.v1"]:
-                    outputs_pde = bkd.reshape(outputs, shape_3d)
-                elif backend_name in ["tensorflow", "pytorch"]:
-                    outputs_pde = (bkd.reshape(outputs, shape_3d), forward_call)
                 # Each f has the shape (N1, N2)
                 f = self.pde.pde(
-                    inputs[1],
-                    outputs_pde,
-                    bkd.reshape(
-                        model.net.auxiliary_vars,
-                        shape_3d,
-                    ),
+                    inputs[1], (outputs, forward_call), model.net.auxiliary_vars
                 )
                 if not isinstance(f, (list, tuple)):
                     f = [f]
-                f = [bkd.reshape(fi, (batchsize1, batchsize2)) for fi in f]
-
             # Each error has the shape (N1, ~N2)
             error_f = [fi[:, bcs_start[-1] :] for fi in f]
             for error in error_f:
@@ -365,4 +349,4 @@ def test(self):
         )
         self.test_x = (func_vals, self.pde.test_x)
         self.test_aux_vars = vx
-        return self.test_x, self.test_y, self.test_aux_vars
+        return self.test_x, self.test_y, self.test_aux_vars
\ No newline at end of file
diff --git a/deepxde/gradients/gradients.py b/deepxde/gradients/gradients.py
index b819405a5..742796b50 100644
--- a/deepxde/gradients/gradients.py
+++ b/deepxde/gradients/gradients.py
@@ -19,7 +19,7 @@ def jacobian(ys, xs, i=None, j=None):
     computation.
 
     Args:
-        ys: Output Tensor of shape (batch_size, dim_y) or (batch_size1, batch_size2, dim_y).
+        ys: Output Tensor of shape (batch_size, dim_y).
         xs: Input Tensor of shape (batch_size, dim_x).
         i (int or None): `i`th row. If `i` is ``None``, returns the `j`th column
             J[:, `j`].
diff --git a/deepxde/gradients/gradients_forward.py b/deepxde/gradients/gradients_forward.py
index 8fc385912..b58aa7621 100644
--- a/deepxde/gradients/gradients_forward.py
+++ b/deepxde/gradients/gradients_forward.py
@@ -87,14 +87,14 @@ def grad_fn(x):
         # Compute J[i, j]
         if (i, j) not in self.J:
             if backend_name == "tensorflow.compat.v1":
-                self.J[i, j] = self.J[j][..., i : i + 1]
+                self.J[i, j] = self.J[j][:, i : i + 1]
             elif backend_name in ["tensorflow", "pytorch", "jax"]:
                 # In backend tensorflow/pytorch/jax, a tuple of a tensor/tensor/array
                 # and a callable is returned, so that it is consistent with the argument,
                 # which is also a tuple. This is useful for further computation, e.g.,
                 # Hessian.
                 self.J[i, j] = (
-                    self.J[j][0][..., i : i + 1],
+                    self.J[j][0][:, i : i + 1],
                     lambda x: self.J[j][1](x)[i : i + 1],
                 )
         return self.J[i, j]
diff --git a/deepxde/gradients/jacobian.py b/deepxde/gradients/jacobian.py
index 5a2f441c2..5b0af016d 100644
--- a/deepxde/gradients/jacobian.py
+++ b/deepxde/gradients/jacobian.py
@@ -28,7 +28,7 @@ def __init__(self, ys, xs):
         elif config.autodiff == "forward":
             # For forward-mode AD, a tuple of a tensor and a callable is passed,
             # similar to backend jax.
-            self.dim_y = ys[0].shape[-1]
+            self.dim_y = ys[0].shape[1]
         elif backend_name == "jax":
             # For backend jax, a tuple of a jax array and a callable is passed as one of
             # the arguments, since jax does not support computational graph explicitly.
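
Note on the restored interface: under forward-mode AD, PDEOperatorCartesianProd
passes the PDE function a tuple (outputs, forward_call), where outputs is the
raw 2D network output of shape (N1, N2), and Jacobian.__init__ reads dim_y from
ys[0].shape[1]. Below is a minimal user-side sketch of that call pattern; the
advection-style equation and all variable names are illustrative only, not part
of this patch:

    import deepxde as dde

    # Select forward-mode AD. With the interface above, `u` arrives in the
    # PDE function as a tuple (outputs, forward_call), with outputs of
    # shape (N1, N2).
    dde.config.set_default_autodiff("forward")


    def pde(x, u, v):
        # Each derivative below also has shape (N1, N2); v holds the
        # auxiliary variables, now passed through without reshaping.
        du_x = dde.grad.jacobian(u, x, j=0)
        du_t = dde.grad.jacobian(u, x, j=1)
        return du_t + du_x - v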
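
The switch from [..., i : i + 1] back to [:, i : i + 1] in gradients_forward.py
is the same simplification: forward-mode Jacobian entries are plain 2D tensors
again, so extracting output column i is an ordinary 2D slice. A NumPy sketch
with made-up shapes:

    import numpy as np

    # J[j] holds the JVP w.r.t. input j; after this patch it is always 2D,
    # (batch_size, dim_y), never (N1, N2, num_outputs).
    J_j = np.ones((8, 3))

    # Caching J[i, j] only needs a 2D column slice; the ellipsis slice
    # [..., i : i + 1] existed solely for the removed 3D layout.
    i = 1
    J_ij = J_j[:, i : i + 1]
    assert J_ij.shape == (8, 1)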