-
Notifications
You must be signed in to change notification settings - Fork 17
/
lsoftmax.py
84 lines (68 loc) · 3.17 KB
/
lsoftmax.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import math
import torch
from torch import nn
from scipy.special import binom
class LSoftmaxLinear(nn.Module):
    """Large-margin softmax (L-Softmax) linear layer.

    Final classification layer from "Large-Margin Softmax Loss for
    Convolutional Neural Networks" (Liu et al., ICML 2016).  During training
    the target-class logit is replaced by ||w_y|| * ||x|| * psi(theta)
    (eqs. 6-7 of the paper), annealed against the plain logit with a decaying
    weight ``beta``.  In eval mode it is an ordinary bias-free linear layer.
    """

    def __init__(self, input_features, output_features, margin, device):
        """
        Args:
            input_features (int): dimensionality of the incoming features
                (i.e. the output size of the preceding fc layer).
            output_features (int): number of classes.
            margin (int): angular margin m (m >= 1).
            device (torch.device): device the constant tensors are placed on.
        """
        super().__init__()
        self.input_dim = input_features   # number of input features
        self.output_dim = output_features # number of classes
        self.margin = margin              # m
        self.beta = 100                   # annealing weight, decayed every training step
        self.beta_min = 0                 # floor for beta
        self.scale = 0.99                 # multiplicative decay applied to beta
        self.device = device              # gpu or cpu

        # L-Softmax weight; note the (in, out) layout, so forward uses x.mm(w).
        self.weight = nn.Parameter(torch.FloatTensor(input_features, output_features))

        # Constants for the closed-form expansion of cos(m*theta) (eq. 7):
        #   cos(m*t) = sum_n (-1)^n * C(m, 2n) * cos^{m-2n}(t) * (sin^2 t)^n
        self.divisor = math.pi / self.margin  # pi/m, sector width used to find k (eq. 6)
        self.C_m_2n = torch.Tensor(binom(margin, range(0, margin + 1, 2))).to(device)  # C(m, 2n)
        self.cos_powers = torch.Tensor(range(self.margin, -1, -2)).to(device)          # m - 2n
        self.sin2_powers = torch.Tensor(range(len(self.cos_powers))).to(device)        # n
        self.signs = torch.ones(margin // 2 + 1).to(device)  # (-1)^n: 1, -1, 1, -1, ...
        self.signs[1::2] = -1

        # BUGFIX: the original defined reset_parameters() but never called it,
        # so the weight stayed as uninitialized memory unless every caller
        # remembered to invoke it manually.  Initialize here; an additional
        # external call remains harmless.
        self.reset_parameters()

    def reset_parameters(self):
        # Kaiming init on the transposed (out, in) view so fan-in/fan-out are
        # computed in the conventional layout; .t() is a view, so the in-place
        # init writes through to self.weight.
        nn.init.kaiming_normal_(self.weight.data.t())

    def calculate_cos_m_theta(self, cos_theta):
        """Compute cos(m*theta) from cos(theta) via the closed form of eq. 7.

        Args:
            cos_theta (Tensor): 1-D tensor of cosines, one per sample.

        Returns:
            Tensor: 1-D tensor of cos(m*theta), same length as the input.
        """
        sin2_theta = 1 - cos_theta ** 2
        cos_terms = cos_theta.unsqueeze(1) ** self.cos_powers.unsqueeze(0)   # cos^{m-2n}
        sin2_terms = (sin2_theta.unsqueeze(1)                                # (sin^2)^n
                      ** self.sin2_powers.unsqueeze(0))
        # (-1)^n * C(m, 2n) * cos^{m-2n} * (sin^2)^n, summed over n
        cos_m_theta = (self.signs.unsqueeze(0) *
                       self.C_m_2n.unsqueeze(0) *
                       cos_terms *
                       sin2_terms).sum(1)
        return cos_m_theta

    def find_k(self, cos):
        """Return k of eq. 6: the index of the [k*pi/m, (k+1)*pi/m) sector of theta."""
        # Clamp to account for acos numerical errors at +/-1.
        eps = 1e-7
        cos = torch.clamp(cos, -1 + eps, 1 - eps)
        acos = cos.acos()
        # k is piecewise constant in theta, so it carries no gradient: detach.
        k = (acos / self.divisor).floor().detach()
        return k

    def forward(self, input, target=None):
        """Compute logits; in training mode apply the L-Softmax margin to the target class.

        Args:
            input (Tensor): (batch, input_features) feature matrix.
            target (LongTensor | None): (batch,) class indices; required in
                training mode, must be None in eval mode.

        Returns:
            Tensor: (batch, output_features) logits.
        """
        if self.training:
            assert target is not None
            x, w = input, self.weight
            beta = max(self.beta, self.beta_min)
            logit = x.mm(w)
            indexes = range(logit.size(0))
            logit_target = logit[indexes, target]

            # cos(theta) = w . x / (||w|| * ||x||); epsilon guards zero norms.
            w_target_norm = w[:, target].norm(p=2, dim=0)
            x_norm = x.norm(p=2, dim=1)
            cos_theta_target = logit_target / (w_target_norm * x_norm + 1e-10)

            # Equation 7: closed-form cos(m*theta).
            cos_m_theta_target = self.calculate_cos_m_theta(cos_theta_target)

            # Find k in equation 6.
            k = self.find_k(cos_theta_target)

            # f_y_i = ||w_y|| * ||x|| * ((-1)^k * cos(m*theta) - 2k)
            logit_target_updated = (w_target_norm *
                                    x_norm *
                                    (((-1) ** k * cos_m_theta_target) - 2 * k))
            # Anneal between the margin logit and the plain logit (eq. annealing trick).
            logit_target_updated_beta = (logit_target_updated + beta * logit[indexes, target]) / (1 + beta)

            logit[indexes, target] = logit_target_updated_beta
            self.beta *= self.scale  # decay the annealing weight
            return logit
        else:
            assert target is None
            return input.mm(self.weight)