From b0d239b1419bf407fe1c23468a470b5f93b65c4e Mon Sep 17 00:00:00 2001
From: Jerry-Jzy <66828815+Jerry-Jzy@users.noreply.github.com>
Date: Mon, 23 Dec 2024 19:04:01 -0500
Subject: [PATCH] PDEOperatorCartesianProd supports forward-mode AD (#1903)

---
 deepxde/data/pde_operator.py | 70 +++++++++++++++++++++++++++---------
 1 file changed, 54 insertions(+), 16 deletions(-)

diff --git a/deepxde/data/pde_operator.py b/deepxde/data/pde_operator.py
index ce573274d..3b001cf9a 100644
--- a/deepxde/data/pde_operator.py
+++ b/deepxde/data/pde_operator.py
@@ -237,23 +237,59 @@ def __init__(
         self.train_next_batch()
         self.test()
 
-    def _losses(self, outputs, loss_fn, inputs, model, num_func):
+    def _losses(self, outputs, loss_fn, inputs, model, num_func, aux=None):
         bcs_start = np.cumsum([0] + self.pde.num_bcs)
 
         losses = []
-        for i in range(num_func):
-            out = outputs[i]
-            # Single output
-            if bkd.ndim(out) == 1:
-                out = out[:, None]
+        # PDE loss
+        if config.autodiff == "reverse":  # reverse mode AD
+            for i in range(num_func):
+                out = outputs[i]
+                # Single output
+                if bkd.ndim(out) == 1:
+                    out = out[:, None]
+                f = []
+                if self.pde.pde is not None:
+                    f = self.pde.pde(
+                        inputs[1], out, model.net.auxiliary_vars[i][:, None]
+                    )
+                    if not isinstance(f, (list, tuple)):
+                        f = [f]
+                error_f = [fi[bcs_start[-1] :] for fi in f]
+                losses_i = [loss_fn(bkd.zeros_like(error), error) for error in error_f]
+                losses.append(losses_i)
+
+            losses = zip(*losses)
+            # Use stack instead of as_tensor to keep the gradients.
+            losses = [bkd.reduce_mean(bkd.stack(loss, 0)) for loss in losses]
+        elif config.autodiff == "forward":  # forward mode AD
+
+            def forward_call(trunk_input):
+                return aux[0]((inputs[0], trunk_input))
+
             f = []
             if self.pde.pde is not None:
-                f = self.pde.pde(inputs[1], out, model.net.auxiliary_vars[i][:, None])
+                # Each f has the shape (N1, N2)
+                f = self.pde.pde(
+                    inputs[1], (outputs, forward_call), model.net.auxiliary_vars
+                )
                 if not isinstance(f, (list, tuple)):
                     f = [f]
-            error_f = [fi[bcs_start[-1] :] for fi in f]
-            losses_i = [loss_fn(bkd.zeros_like(error), error) for error in error_f]
-
+            # Each error has the shape (N1, ~N2)
+            error_f = [fi[:, bcs_start[-1] :] for fi in f]
+            for error in error_f:
+                error_i = []
+                for i in range(num_func):
+                    error_i.append(loss_fn(bkd.zeros_like(error[i]), error[i]))
+                losses.append(bkd.reduce_mean(bkd.stack(error_i, 0)))
+
+        # BC loss
+        losses_bc = []
+        for i in range(num_func):
+            losses_i = []
+            out = outputs[i]
+            if bkd.ndim(out) == 1:
+                out = out[:, None]
             for j, bc in enumerate(self.pde.bcs):
                 beg, end = bcs_start[j], bcs_start[j + 1]
                 # The same BC points are used for training and testing.
@@ -267,19 +303,21 @@ def _losses(self, outputs, loss_fn, inputs, model, num_func):
                 )
                 losses_i.append(loss_fn(bkd.zeros_like(error), error))
-            losses.append(losses_i)
+            losses_bc.append(losses_i)
 
-        losses = zip(*losses)
-        # Use stack instead of as_tensor to keep the gradients.
-        losses = [bkd.reduce_mean(bkd.stack(loss, 0)) for loss in losses]
+        losses_bc = zip(*losses_bc)
+        losses_bc = [bkd.reduce_mean(bkd.stack(loss, 0)) for loss in losses_bc]
+        losses += losses_bc
 
         return losses
 
     def losses_train(self, targets, outputs, loss_fn, inputs, model, aux=None):
         num_func = self.num_func if self.batch_size is None else self.batch_size
-        return self._losses(outputs, loss_fn, inputs, model, num_func)
+        return self._losses(outputs, loss_fn, inputs, model, num_func, aux=aux)
 
     def losses_test(self, targets, outputs, loss_fn, inputs, model, aux=None):
-        return self._losses(outputs, loss_fn, inputs, model, len(self.test_x[0]))
+        return self._losses(
+            outputs, loss_fn, inputs, model, len(self.test_x[0]), aux=aux
+        )
 
     def train_next_batch(self, batch_size=None):
         if self.train_x is None:
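
Usage note (not part of the patch): below is a minimal sketch of how this change would be exercised, loosely based on the antiderivative PI-DeepONet demo from the DeepXDE documentation. The GRF length scale, layer sizes, and point counts are illustrative assumptions, not mandated by the patch. The patch keys off config.autodiff: after dde.config.set_default_autodiff("forward"), _losses takes the forward branch and passes the tuple (outputs, forward_call) to the user's PDE, which dde.grad.jacobian consumes directly, so the residual is written the same way in both AD modes. Forward-mode AD is only implemented for some backends, so a compatible backend is assumed here.

    import numpy as np
    import deepxde as dde

    # Select forward-mode AD; the default is "reverse". With this set,
    # PDEOperatorCartesianProd._losses takes the forward branch above.
    dde.config.set_default_autodiff("forward")

    def ode(x, u, v):
        # Residual of du/dx = v(x). In forward mode, u arrives as the
        # tuple (outputs, forward_call) built in _losses, and
        # dde.grad.jacobian accepts it unchanged.
        return dde.grad.jacobian(u, x, i=0, j=0) - v

    geom = dde.geometry.TimeDomain(0, 1)
    pde = dde.data.PDE(geom, ode, [], num_domain=20)

    # Branch inputs: random functions sampled from a GRF, evaluated at
    # 50 fixed sensor points (illustrative choices).
    func_space = dde.data.GRF(length_scale=0.2)
    eval_pts = np.linspace(0, 1, 50)[:, None]
    data = dde.data.PDEOperatorCartesianProd(
        pde, func_space, eval_pts, num_function=100, num_test=20
    )

    net = dde.nn.DeepONetCartesianProd(
        [50, 128, 128], [1, 128, 128], "tanh", "Glorot normal"
    )
    model = dde.Model(data, net)
    model.compile("adam", lr=1e-3)
    model.train(iterations=1000)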