Add decomposition for aten.native_batch_norm_backward op
This commit adds a decomposition for the `aten.native_batch_norm_backward` op.

Signed-off-by: Gaurav Shukla <[email protected]>
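For reference, the decomposition follows the standard batch-norm backward formulas. Writing $g$ for `grad_out`, $\hat{x}_i = (x_i - \mu)\cdot\mathrm{invstd}$, and $N$ for `num_features` (elements per channel), in training mode

$$
\frac{\partial L}{\partial \beta} = \sum_j g_j, \qquad
\frac{\partial L}{\partial \gamma} = \sum_j g_j \hat{x}_j, \qquad
\frac{\partial L}{\partial x_i} = \gamma \cdot \mathrm{invstd} \left( g_i - \frac{1}{N}\sum_j g_j - \frac{\hat{x}_i}{N}\sum_j g_j \hat{x}_j \right),
$$

while in eval mode (`train=False`) the statistics are treated as constants and $\partial L / \partial x_i$ reduces to $\gamma \cdot \mathrm{invstd} \cdot g_i$.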
Shukla-Gaurav committed Apr 8, 2022
1 parent 6a52d17 commit 5d06c10
Showing 1 changed file with 54 additions and 0 deletions.
functorch/_src/decompositions.py (54 additions, 0 deletions)
@@ -497,6 +497,60 @@ def native_layer_norm_backward(grad_out: Tensor, input: Tensor, normalized_shape
    return (d_input, d_weight, d_bias)


@register_decomposition(aten.native_batch_norm_backward)
def native_batch_norm_backward(grad_out: Tensor, input: Tensor, weight: Optional[Tensor], running_mean: Optional[Tensor], running_var: Optional[Tensor], save_mean: Optional[Tensor], save_invstd: Optional[Tensor], train: bool, eps: float, output_mask: List[bool]) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
    input_shape = input.shape
    input_rank = input.dim()
    assert input_rank >= 2, "rank of the input must be at least 2"

    # Batch norm statistics are per-channel; the channel dimension is 1.
    axis = 1
    num_features = prod(input_shape) / input_shape[axis]
    mean = save_mean
    invstd = save_invstd
    if train:
        assert save_mean is not None and save_invstd is not None, "when train=True, save_mean and save_invstd are required"
    else:
        # In eval mode the saved statistics are not available, so fall back
        # to the running statistics.
        mean = running_mean
        invstd = torch.rsqrt(running_var + eps)

    # Shape that broadcasts per-channel values against the full input.
    broadcast_mask = [1] * input_rank
    broadcast_mask[axis] = input_shape[axis]

    # Reduce over every dimension except the channel dimension.
    reduction_axes = []
    for i in range(input_rank):
        if i != axis:
            reduction_axes.append(i)

    mean = torch.reshape(mean, broadcast_mask)
    norm = 1.0 / num_features
    # Per-channel sums of grad_out and of grad_out * (input - mean).
    grad_output_sum = torch.sum(grad_out, reduction_axes)
    dot_p = torch.sum(grad_out * (input - mean), reduction_axes)

    grad_mean = torch.reshape(grad_output_sum * norm, broadcast_mask)
    proj_scale = torch.reshape(torch.mul(dot_p * norm, invstd * invstd), broadcast_mask)

    grad_scale = None
    if weight is None:
        grad_scale = torch.reshape(invstd, broadcast_mask) * 1.0
    else:
        grad_scale = torch.reshape(invstd * weight, broadcast_mask)
    grad_input = None
    if train:
        # Training mode: the normalization statistics depend on the input, so
        # subtract the projection onto (input - mean) and the mean gradient.
        proj = (input - mean) * proj_scale
        grad_input = ((grad_out - proj) - grad_mean) * grad_scale
    else:
        # Eval mode: the statistics are constants with respect to the input.
        grad_input = grad_out * grad_scale

    grad_weight = None
    if output_mask[1]:
        grad_weight = dot_p * invstd

    grad_bias = None
    if output_mask[2]:
        grad_bias = grad_output_sum
    return (grad_input, grad_weight, grad_bias)


@register_decomposition(aten.clamp_min)
def clamp_min(self: Tensor, min: float):
    return torch.clamp(self, min=min)
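As a quick illustration of how the new decomposition could be exercised, a minimal sanity check might look like the sketch below (hypothetical snippet, not part of this commit; it assumes the decorated function is importable from functorch._src.decompositions and remains directly callable):

# Hypothetical check: compare the Python decomposition against the eager
# ATen kernel in training mode.
import torch
from functorch._src.decompositions import native_batch_norm_backward

x = torch.randn(4, 3, 8, 8, dtype=torch.float64)
weight = torch.randn(3, dtype=torch.float64)
grad_out = torch.randn_like(x)

# The forward op returns the saved per-channel mean and inverse std that
# the backward needs when train=True.
_, save_mean, save_invstd = torch.ops.aten.native_batch_norm(
    x, weight, None, None, None, True, 0.1, 1e-5)

args = (grad_out, x, weight, None, None, save_mean, save_invstd,
        True, 1e-5, [True, True, True])
ref = torch.ops.aten.native_batch_norm_backward(*args)
dec = native_batch_norm_backward(*args)

for r, d in zip(ref, dec):
    torch.testing.assert_close(r, d)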
