Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Update on "[WIP] Compute forward grads for saved_mean and saved_var w…
…hen input requires grad" We want to avoid having to recompute saved_mean and saved_invstd in batch_norm_backward's decomposition in functorch (see pytorch/functorch#877), but also avoid unnecessarily computing forward grads for saved_mean and saved_invstd when they are not needed. Tested locally with: `python test/test_ops.py -k test_jvpvjp_nn_functional_batch_norm` Issues: - not sure if gradgrad in core is missing something, but it is able to pass while the fwgrad_bwgrad comparison fails in functorch [ghstack-poisoned]
- Loading branch information