Backend TensorFlow supports forward-mode automatic differentiation (#…
ZongrenZou authored Jan 3, 2024
1 parent d6a217a commit aa69952
Showing 4 changed files with 38 additions and 16 deletions.
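
For orientation, the snippet below sketches how a user would exercise the new code path: select the TensorFlow backend and switch DeepXDE from the default reverse-mode AD to forward mode. The config setter name (set_default_autodiff) is carried over from the existing PyTorch/JAX forward-mode support and the Poisson problem is purely illustrative; neither is part of this commit.

import deepxde as dde  # run with DDE_BACKEND=tensorflow

# Assumed switch: the diff only shows config.autodiff being read, so the exact
# setter name is an assumption based on the PyTorch/JAX forward-mode support.
dde.config.set_default_autodiff("forward")

def pde(x, y):
    # Written exactly as in reverse mode; the gradients module decides
    # internally whether to build a JVP (forward) or a VJP (reverse).
    dy_xx = dde.grad.hessian(y, x, i=0, j=0)
    return dy_xx - 2.0

geom = dde.geometry.Interval(-1, 1)
bc = dde.icbc.DirichletBC(geom, lambda x: x**2, lambda x, on_boundary: on_boundary)
data = dde.data.PDE(geom, pde, bc, num_domain=16, num_boundary=2)
model = dde.Model(data, dde.nn.FNN([1, 32, 32, 1], "tanh", "Glorot normal"))
model.compile("adam", lr=1e-3)
model.train(iterations=100)
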
6 changes: 3 additions & 3 deletions deepxde/data/pde.py
@@ -128,13 +128,13 @@ def __init__(
self.test()

def losses(self, targets, outputs, loss_fn, inputs, model, aux=None):
if backend_name in ["tensorflow.compat.v1", "tensorflow", "paddle"]:
if backend_name in ["tensorflow.compat.v1", "paddle"]:
outputs_pde = outputs
elif backend_name == "pytorch":
elif backend_name in ["tensorflow", "pytorch"]:
if config.autodiff == "reverse":
outputs_pde = outputs
elif config.autodiff == "forward":
# forward-mode AD in PyTorch requires functions
# forward-mode AD requires functions
outputs_pde = (outputs, aux[0])
elif backend_name == "jax":
# JAX requires pure functions
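
The reworded comment above ("forward-mode AD requires functions") is the heart of the change: a reverse-mode gradient can be pulled out of outputs recorded on a tape, but a JVP has to re-run the network along a tangent direction, so the callable itself is threaded through as (outputs, aux[0]). A minimal, self-contained TensorFlow sketch of that asymmetry (the toy net is illustrative; a fuller Jacobian-column example follows the gradients_forward.py diff below):

import tensorflow as tf

net = tf.keras.Sequential([tf.keras.layers.Dense(8, activation="tanh"),
                           tf.keras.layers.Dense(1)])
x = tf.random.normal([4, 2])

# Reverse mode: record one forward pass, then differentiate the stored outputs.
with tf.GradientTape() as tape:
    tape.watch(x)
    y = net(x)
dy_dx = tape.gradient(y, x)  # only the recorded outputs are needed

# Forward mode: the network has to be called *under* the accumulator, so a
# tensor of outputs is not enough -- the function itself must be available.
tangent = tf.ones_like(x)  # any direction; gives a directional derivative
with tf.autodiff.ForwardAccumulator(primals=x, tangents=tangent) as acc:
    y_fwd = net(x)
dy_dir = acc.jvp(y_fwd)  # derivative of y along `tangent`
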
27 changes: 21 additions & 6 deletions deepxde/gradients/gradients_forward.py
@@ -3,7 +3,7 @@
__all__ = ["hessian", "jacobian"]

from .jacobian import Jacobian, Jacobians
from ..backend import backend_name, jax, torch
from ..backend import backend_name, jax, tf, torch


class JacobianForward(Jacobian):
@@ -22,13 +22,27 @@ def __call__(self, i=None, j=None):
if j not in self.J:
if backend_name in [
"tensorflow.compat.v1",
"tensorflow",
"paddle",
]:
# TODO: Other backends
raise NotImplementedError(
"Backend f{backend_name} doesn't support forward-mode autodiff."
)
elif backend_name == "tensorflow":
# We use tensorflow.autodiff.ForwardAccumulator to compute the jvp of
# a function.
# TODO: create the tangent in a smarter way
tangent = tf.one_hot(self.xs.shape[0] * [j], depth=self.xs.shape[1])

def grad_fn(x):
with tf.autodiff.ForwardAccumulator(
primals=x,
tangents=tangent,
) as acc:
u = self.ys[1](x)
return acc.jvp(u)

self.J[j] = (grad_fn(self.xs), grad_fn)
elif backend_name == "pytorch":
# Here we use torch.func.jvp to compute the gradient of a function.
# The implementation is similar to backend JAX. Vectorization is not
@@ -64,10 +78,11 @@ def __call__(self, i=None, j=None):

# Compute J[i, j]
if (i, j) not in self.J:
if backend_name in ["pytorch", "jax"]:
# In backend pytorch/jax, a tuple of a tensor/array and a callable is
# returned, so that it is consistent with the argument, which is also
# a tuple. This is useful for further computation, e.g., Hessian.
if backend_name in ["tensorflow", "pytorch", "jax"]:
# In backend tensorflow/pytorch/jax, a tuple of a tensor/tensor/array
# and a callable is returned, so that it is consistent with the argument,
# which is also a tuple. This is useful for further computation, e.g.,
# Hessian.
self.J[i, j] = (
self.J[j][0][:, i : i + 1],
lambda x: self.J[j][1](x)[i : i + 1],
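
A standalone sketch of the JVP pattern the new tensorflow branch uses: a one-hot tangent selects input column j, ForwardAccumulator evaluates the function under that tangent, and the resulting JVP is exactly the j-th Jacobian column; grad_fn is kept as a callable so higher-order derivatives (e.g., the Hessian) can differentiate it again. The toy function f stands in for the network callable self.ys[1]; everything else mirrors the diff above.

import tensorflow as tf

x = tf.random.normal([5, 3])  # 5 samples, 3 input features

def f(x):
    # Toy vector-valued function standing in for the network callable self.ys[1].
    return tf.stack([tf.reduce_sum(x**2, axis=1), tf.sin(x[:, 0])], axis=1)

j = 1  # which input column to differentiate against
tangent = tf.one_hot([j] * x.shape[0], depth=x.shape[1])  # same tangent for every sample

def grad_fn(x):
    # Mirrors the closure added above: evaluate f under a forward accumulator
    # and return the JVP, i.e. column j of the Jacobian at x.
    with tf.autodiff.ForwardAccumulator(primals=x, tangents=tangent) as acc:
        u = f(x)
    return acc.jvp(u)

column_j = grad_fn(x)                 # shape (5, 2): d f_i / d x_j for every sample
jacobian_entry = (column_j, grad_fn)  # value plus callable, as stored in self.J[j]
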
15 changes: 9 additions & 6 deletions deepxde/gradients/jacobian.py
@@ -19,15 +19,15 @@ def __init__(self, ys, xs):
self.ys = ys
self.xs = xs

if backend_name in ["tensorflow.compat.v1", "tensorflow", "paddle"]:
if backend_name in ["tensorflow.compat.v1", "paddle"]:
self.dim_y = ys.shape[1]
elif backend_name == "pytorch":
elif backend_name in ["tensorflow", "pytorch"]:
if config.autodiff == "reverse":
# For backend pytorch with reverse-mode AD, only a tensor is passed.
# For reverse-mode AD, only a tensor is passed.
self.dim_y = ys.shape[1]
elif config.autodiff == "forward":
# For backend pytorch with forward-mode AD, a tuple of a tensor and
# a callable is passed, similar to backend jax.
# For forward-mode AD, a tuple of a tensor and a callable is passed,
# similar to backend jax.
self.dim_y = ys[0].shape[1]
elif backend_name == "jax":
# For backend jax, a tuple of a jax array and a callable is passed as one of
@@ -115,7 +115,10 @@ def __call__(self, ys, xs, i=None, j=None):
# x.requires_grad_()
# f(x)
if backend_name in ["tensorflow.compat.v1", "tensorflow"]:
key = (ys.ref(), xs.ref())
if config.autodiff == "reverse":
key = (ys.ref(), xs.ref())
elif config.autodiff == "forward":
key = (ys[0].ref(), xs.ref())
elif backend_name in ["pytorch", "paddle"]:
key = (ys, xs)
elif backend_name == "jax":
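
The key construction above changes because tf.Tensor objects are not hashable in TF 2, so the cache is keyed on tensor.ref(), and under forward mode ys arrives as a (tensor, callable) pair rather than a bare tensor. A small illustration (the dict below merely stands in for the internal Jacobians cache):

import tensorflow as tf

xs = tf.zeros([4, 2])
ys = tf.ones([4, 1])

# A tf.Tensor cannot key a dict directly ({ys: ...} raises TypeError);
# tensor.ref() returns a hashable reference with identity semantics.
cache = {}
cache[(ys.ref(), xs.ref())] = "reverse mode: ys is a bare tensor"
assert (ys.ref(), xs.ref()) in cache  # same tensors give the same key

# Forward mode: ys is (tensor, callable), so the key uses ys[0] instead.
ys_forward = (ys, lambda x: x)  # illustrative stand-in for (outputs, net)
forward_key = (ys_forward[0].ref(), xs.ref())
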
6 changes: 5 additions & 1 deletion deepxde/model.py
@@ -208,7 +208,9 @@ def outputs_losses(training, inputs, targets, auxiliary_vars, losses_fn):
# gradient of outputs wrt inputs will be lost here.
outputs_ = self.net(inputs, training=training)
# Data losses
losses = losses_fn(targets, outputs_, loss_fn, inputs, self)
# if forward-mode AD is used, then a forward call needs to be passed
aux = [self.net] if config.autodiff == "forward" else None
losses = losses_fn(targets, outputs_, loss_fn, inputs, self, aux=aux)
if not isinstance(losses, list):
losses = [losses]
# Regularization loss
@@ -907,6 +909,8 @@ def predict(self, x, operator=None, callbacks=None):
@tf.function
def op(inputs):
y = self.net(inputs)
if config.autodiff == "forward":
y = (y, self.net)
return operator(inputs, y)

elif utils.get_num_args(operator) == 3:
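
With the predict change above, a user-supplied operator keeps working under forward mode because the network output is wrapped as (y, self.net) before the operator runs, and dde.grad handles both forms. A hedged usage sketch; `model` is assumed to be an already-trained dde.Model with a one-dimensional input:

import numpy as np
import deepxde as dde

def first_derivative(x, y):
    # y is a tensor under reverse mode and a (tensor, callable) pair under
    # forward mode; dde.grad.jacobian accepts both, so this operator is
    # mode-agnostic.
    return dde.grad.jacobian(y, x, i=0, j=0)

x_test = np.linspace(-1, 1, 101)[:, None]
dy_dx = model.predict(x_test, operator=first_derivative)  # model: trained dde.Model
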
