norms.py
import torch
import torch.nn as nn

class ChannelWiseLayerNorm(nn.LayerNorm):
    """
    Channel-wise layer normalization (cLN): layer norm applied over the
    channel dimension of an N x C x T input.
    """

    def __init__(self, *args, **kwargs):
        super(ChannelWiseLayerNorm, self).__init__(*args, **kwargs)

    def forward(self, x):
        """
        x: N x C x T
        """
        # nn.LayerNorm normalizes the last dimension, so move C to the end,
        # normalize, then restore the N x C x T layout.
        x = torch.transpose(x, 1, 2)
        x = super().forward(x)
        x = torch.transpose(x, 1, 2)
        return x

class GlobalChannelLayerNorm(nn.Module):
    """
    Global layer normalization (gLN): normalizes over both the channel and
    time dimensions, with learnable per-channel gain and bias.
    """

    def __init__(self, channel_size):
        super(GlobalChannelLayerNorm, self).__init__()
        self.channel_size = channel_size
        self.gamma = nn.Parameter(torch.ones(channel_size),
                                  requires_grad=True)
        self.beta = nn.Parameter(torch.zeros(channel_size),
                                 requires_grad=True)

    def apply_gain_and_bias(self, normed_x):
        """ Assumes input of size `[batch, channel, *]`. """
        # Move the channel axis last so gamma/beta broadcast per channel,
        # then move it back.
        return (self.gamma * normed_x.transpose(1, -1) +
                self.beta).transpose(1, -1)

    def forward(self, x):
        """
        x: N x C x T
        """
        # Mean and variance over every dimension except the batch dimension.
        dims = list(range(1, len(x.shape)))
        mean = x.mean(dim=dims, keepdim=True)
        var = torch.pow(x - mean, 2).mean(dim=dims, keepdim=True)
        return self.apply_gain_and_bias((x - mean) / (var + 1e-8).sqrt())

def select_norm(norm, dim):
    """
    Build a normalization layer.
    Layer norm variants cost more memory than batch norm.
    """
    if norm not in ["cLN", "gLN", "BN"]:
        raise RuntimeError("Unsupported normalize layer: {}".format(norm))
    if norm == "cLN":
        return ChannelWiseLayerNorm(dim, elementwise_affine=True)
    elif norm == "BN":
        return nn.BatchNorm1d(dim)
    else:
        return GlobalChannelLayerNorm(dim)
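
# Minimal usage sketch (an illustrative assumption, not part of the original
# module): build each supported norm via select_norm and pass a random
# N x C x T batch through it to confirm the shape is preserved.
if __name__ == "__main__":
    N, C, T = 4, 128, 100
    x = torch.randn(N, C, T)
    for name in ["cLN", "gLN", "BN"]:
        layer = select_norm(name, C)
        y = layer(x)
        assert y.shape == x.shape, (name, tuple(y.shape))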