Update on "[WIP] Compute forward grads for saved_mean and saved_var when input requires grad"


We want to avoid recomputing saved_mean and saved_invstd in batch_norm_backward's decomposition in functorch (see pytorch/functorch#877), while also avoiding unnecessarily computing forward grads for saved_mean and saved_invstd when they are not needed.
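
As a rough sketch of what those forward grads compute (illustrative only, not this commit's implementation; the helper names and the (N, C)/dim-0 reduction below are assumptions): the tangent of saved_mean is just the mean of the input tangent, and the tangent of saved_invstd follows the chain rule through rsqrt of the biased variance. The hand-derived formulas are checked against `torch.autograd.functional.jvp`:

```python
# Hypothetical sketch (not this commit's implementation) of forward grads
# for batch norm's saved_mean and saved_invstd, assuming an (N, C) input
# with statistics reduced over dim 0.
import torch
from torch.autograd.functional import jvp

def batch_norm_stats(x, eps=1e-5):
    mean = x.mean(dim=0)
    var = x.var(dim=0, unbiased=False)
    return mean, (var + eps).rsqrt()

def batch_norm_stats_jvp(x, dx, eps=1e-5):
    mean = x.mean(dim=0)
    var = x.var(dim=0, unbiased=False)
    dmean = dx.mean(dim=0)                                # JVP of the mean
    dvar = 2.0 * ((x - mean) * (dx - dmean)).mean(dim=0)  # JVP of the biased variance
    dinvstd = -0.5 * (var + eps) ** -1.5 * dvar           # chain rule through rsqrt
    return dmean, dinvstd

x = torch.randn(8, 3, dtype=torch.double)
dx = torch.randn_like(x)
_, (dmean_ref, dinvstd_ref) = jvp(batch_norm_stats, (x,), (dx,))
dmean, dinvstd = batch_norm_stats_jvp(x, dx)
assert torch.allclose(dmean, dmean_ref) and torch.allclose(dinvstd, dinvstd_ref)
```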

Tested locally with: `python test/test_ops.py -k test_jvpvjp_nn_functional_batch_norm`

Issues:
- Not sure if the gradgrad check in core is missing something, but it passes while the fwgrad_bwgrad comparison fails in functorch.


[ghstack-poisoned]
soulitzer committed Jul 21, 2022
2 parents 8af0a6a + 6ffc0a9 commit aed7c7e
Showing 1 changed file with 8 additions and 7 deletions.

torch/autograd/function.py
@@ -22,15 +22,16 @@ def save_for_backward(self, *tensors: torch.Tensor):
         incorrect gradients and memory leaks, and enable the application of saved
         tensor hooks. See :class:`torch.autograd.graph.saved_tensors_hooks`.

-        Note that if intermediary tensors (i.e., tensors that are neither input
-        nor output) are saved for backward, your custom Function may not support
-        `double backward <https://pytorch.org/tutorials/intermediate/custom_function_double_backward_tutorial.html>`_.
+        Note that if intermediary tensors (tensors that are neither input
+        nor output of :func:`forward`) are saved for backward, your custom Function
+        may not support double backward.
         Custom Functions that do not support double backward should decorate their
-        :func:`backward` method with `@once_differentiable` so that performing
-        double backward raises an error. If you'd like to support double backawrd
+        :func:`backward` method with ``@once_differentiable`` so that performing
+        double backward raises an error. If you'd like to support double backward
         you can either recompute intermediaries based on the inputs during backward
-        or return the intermediaries as the outputs of the custom Function. See
-        the tutorial linked above for more details.
+        or return the intermediaries as the outputs of the custom Function. See the
+        `double backward tutorial <https://pytorch.org/tutorials/intermediate/custom_function_double_backward_tutorial.html>`_.
+        for more details.

         In :func:`backward`, saved tensors can be accessed through the :attr:`saved_tensors`
         attribute. Before returning them to the user, a check is made to ensure
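
The docstring describes two options for custom Functions that save intermediary tensors. Here is a hedged, self-contained sketch of both (illustrative only, not part of this commit; the class names are made up): ExpMinusOne saves an intermediary and opts out of double backward with ``@once_differentiable``, while CubeViaSquare returns its intermediary as an extra output so that double backward works.

```python
import torch
from torch.autograd import Function
from torch.autograd.function import once_differentiable

class ExpMinusOne(Function):
    @staticmethod
    def forward(ctx, x):
        e = x.exp()              # intermediary tensor (neither input nor output)
        ctx.save_for_backward(e)
        return e - 1

    @staticmethod
    @once_differentiable         # double backward now raises instead of being silently wrong
    def backward(ctx, grad_out):
        (e,) = ctx.saved_tensors
        return grad_out * e

class CubeViaSquare(Function):
    @staticmethod
    def forward(ctx, x):
        sq = x * x               # intermediary needed by backward
        ctx.save_for_backward(x, sq)
        return x * sq, sq        # expose the intermediary as an extra output

    @staticmethod
    def backward(ctx, grad_out, grad_sq):
        x, sq = ctx.saved_tensors
        # d(x^3)/dx = 3*sq and d(x^2)/dx = 2*x, written in differentiable ops.
        return grad_out * 3 * sq + grad_sq * 2 * x

x = torch.randn(4, dtype=torch.double, requires_grad=True)
torch.autograd.gradcheck(lambda t: CubeViaSquare.apply(t)[0], (x,))
torch.autograd.gradgradcheck(lambda t: CubeViaSquare.apply(t)[0], (x,))
```

Because ``sq`` is an output of forward, the saved copy stays connected to the autograd graph, so differentiating backward() routes through the Function node instead of treating ``sq`` as a constant.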
