norm.patch
diff --git a/torch/autograd/_functions/reduce.py b/torch/autograd/_functions/reduce.py
index 0a6f37af8f..b5e5425f72 100644
--- a/torch/autograd/_functions/reduce.py
+++ b/torch/autograd/_functions/reduce.py
@@ -226,43 +226,34 @@ def forward(ctx, input, p=2, dim=None, keepdim=None):
         ctx.keepdim = False if keepdim is None else keepdim
 
         if dim is None:
-            ctx.norm = input.norm(p)
-            ctx.save_for_backward(input)
-            return input.new((ctx.norm,))
+            norm = input.norm(p)
+            output = input.new((norm,))
         else:
             if keepdim is not None:
                 output = input.norm(p, dim, keepdim=keepdim)
             else:
                 output = input.norm(p, dim)
-            ctx.save_for_backward(input, output)
-            return output
+        ctx.save_for_backward(input, output)
+        return output
 
     @staticmethod
     def backward(ctx, grad_output):
-        if ctx.dim is None:
-            input, = ctx.saved_variables
-            if ctx.p == 2:
-                scale_v = (grad_output / ctx.norm).expand_as(input)
-                return input.mul(scale_v), None, None, None
-            else:
-                pow = input.abs().pow(ctx.p - 2)
-                scale_v = (grad_output / ctx.norm ** (ctx.p - 1)).expand_as(input)
-                return input.mul(pow).mul(scale_v), None, None, None
+        input, output = ctx.saved_variables
+        if ctx.dim is not None and ctx.keepdim is False and input.dim() != 1:
+            grad_output = grad_output.unsqueeze(ctx.dim)
+            output = output.unsqueeze(ctx.dim)
+
+        if ctx.p == 2:
+            grad_input = input.mul(grad_output).div(output)
         else:
-            input, output = ctx.saved_variables
+            input_pow = input.abs().pow(ctx.p - 2)
+            output_pow = output.pow(ctx.p - 1)
+            grad_input = input.mul(input_pow).mul(grad_output).div(output_pow)
-            if ctx.keepdim is False and input.dim() != 1:
-                grad_output = grad_output.unsqueeze(ctx.dim)
-                output = output.unsqueeze(ctx.dim)
+        # Special case at 0 where we return a subgradient containing 0
+        grad_input.masked_fill_(output == 0, 0)
-            big_grad_output = grad_output.expand_as(input)
-            if ctx.p == 2:
-                big_output = output.expand_as(input)
-                return input.mul(big_grad_output).div(big_output), None, None, None
-            else:
-                pow = input.abs().pow(ctx.p - 2)
-                big_output = output.pow(ctx.p - 1).expand_as(input)
-                return input.mul(pow).mul(big_grad_output).div(big_output), None, None, None
+        return grad_input, None, None, None
 
 # TODO: renorm
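
The patched backward computes grad_input = input * |input|^(p-2) * grad_output / norm^(p-1) (which reduces to input * grad_output / norm for p = 2), with a zero subgradient wherever the norm is zero. The following is a minimal sketch, not part of the patch, that checks this formula against torch.autograd.gradcheck on a current PyTorch build; the helper name manual_norm_grad is made up for illustration.

# Minimal sketch (not part of the patch): numerically check the gradient that the
# patched Norm.backward computes. Assumes a current PyTorch where
# torch.autograd.gradcheck and Tensor.norm(p, dim) are available;
# manual_norm_grad is a hypothetical helper.
import torch
from torch.autograd import gradcheck


def manual_norm_grad(x, p=2, dim=None):
    # Mirrors the patched backward: grad = x * |x|^(p-2) / ||x||_p^(p-1),
    # with a 0 subgradient wherever the norm is 0.
    out = x.norm(p) if dim is None else x.norm(p, dim, keepdim=True)
    grad = x * x.abs().pow(p - 2) / out.pow(p - 1)
    grad.masked_fill_(out == 0, 0)
    return grad


x = torch.randn(3, 4, dtype=torch.double, requires_grad=True)

# gradcheck compares autograd's norm backward against finite differences.
assert gradcheck(lambda t: t.norm(2), (x,))
assert gradcheck(lambda t: t.norm(3, 1), (x,))

# The hand-written formula should agree with autograd for p=3 along dim=1.
x.norm(3, 1).sum().backward()
assert torch.allclose(x.grad, manual_norm_grad(x.detach(), p=3, dim=1))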