-
Notifications
You must be signed in to change notification settings - Fork 8
/
Copy pathloss.py
69 lines (54 loc) · 2.51 KB
/
loss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import torch
import torch.nn as nn
import torch.nn.functional as F
class LabelDifference(nn.Module):
    """Pairwise distance between the label vectors of a batch.

    Args:
        distance_type: Name of the distance to use. Only ``'l1'`` (sum of
            absolute coordinate differences) is implemented.
    """

    def __init__(self, distance_type='l1'):
        super(LabelDifference, self).__init__()
        self.distance_type = distance_type

    def forward(self, labels):
        """Return the matrix of pairwise label distances.

        Args:
            labels: Tensor of shape ``[bs, label_dim]``. A 1-D tensor of
                shape ``[bs]`` (one scalar label per sample) is also
                accepted and treated as ``[bs, 1]``.

        Returns:
            Tensor of shape ``[bs, bs]`` whose entry ``(i, j)`` is the
            distance between ``labels[i]`` and ``labels[j]``.

        Raises:
            ValueError: If ``distance_type`` is not ``'l1'``.
        """
        if self.distance_type != 'l1':
            raise ValueError(self.distance_type)
        if labels.dim() == 1:
            # Scalar labels: add the trailing label_dim axis so the
            # broadcasting below sees a [bs, 1] tensor instead of failing.
            labels = labels.unsqueeze(-1)
        # Broadcast [bs, 1, d] against [1, bs, d] and reduce over d.
        return torch.abs(labels[:, None, :] - labels[None, :, :]).sum(dim=-1)
class FeatureSimilarity(nn.Module):
    """Pairwise similarity between the feature vectors of a batch.

    The only implemented ``similarity_type`` is ``'l2'``: the negated
    Euclidean distance, so identical features score 0 and the score
    decreases as features move apart.
    """

    def __init__(self, similarity_type='l2'):
        super().__init__()
        self.similarity_type = similarity_type

    def forward(self, features):
        """Return the pairwise similarity matrix.

        Args:
            features: Tensor of shape ``[bs, feat_dim]``.

        Returns:
            Tensor of shape ``[bs, bs]``; entry ``(i, j)`` is
            ``-||features[i] - features[j]||_2``.

        Raises:
            ValueError: If ``similarity_type`` is not ``'l2'``.
        """
        if self.similarity_type != 'l2':
            raise ValueError(self.similarity_type)
        as_rows = features[:, None, :]  # [bs, 1, feat_dim]
        as_cols = features[None, :, :]  # [1, bs, feat_dim]
        pairwise_delta = as_rows - as_cols
        return torch.neg(pairwise_delta.norm(p=2, dim=-1))
class RnCLoss(nn.Module):
    """Contrastive loss that ranks samples by label distance.

    For every anchor ``i`` and every other sample ``j`` in the two-view
    batch, ``j`` is treated as the positive and the softmax denominator is
    restricted to the samples whose label distance to ``i`` is at least as
    large as that of ``j`` (``j`` itself included). The negative
    log-probabilities are averaged over all ``n * (n - 1)`` ordered pairs,
    where ``n = 2 * bs``.

    Args:
        temperature: Divisor applied to the feature similarities.
        label_diff: Distance type forwarded to ``LabelDifference``.
        feature_sim: Similarity type forwarded to ``FeatureSimilarity``.
    """

    def __init__(self, temperature=2, label_diff='l1', feature_sim='l2'):
        super(RnCLoss, self).__init__()
        self.t = temperature
        self.label_diff_fn = LabelDifference(label_diff)
        self.feature_sim_fn = FeatureSimilarity(feature_sim)

    def forward(self, features, labels):
        """Compute the loss for a batch of two-view features.

        Args:
            features: Tensor of shape ``[bs, 2, feat_dim]`` (two views per
                sample).
            labels: Tensor of shape ``[bs, label_dim]``. A 1-D tensor of
                shape ``[bs]`` is also accepted and treated as ``[bs, 1]``.

        Returns:
            Scalar tensor holding the loss.
        """
        if labels.dim() == 1:
            # Without this, repeat(2, 1) below would yield [2, bs], not [2bs, 1].
            labels = labels.unsqueeze(-1)

        # Stack the two views along the batch axis; labels are duplicated
        # in the same order so row r of both tensors refers to one sample.
        features = torch.cat([features[:, 0], features[:, 1]], dim=0)  # [2bs, feat_dim]
        labels = labels.repeat(2, 1)  # [2bs, label_dim]

        label_diffs = self.label_diff_fn(labels)  # [2bs, 2bs]
        logits = self.feature_sim_fn(features).div(self.t)  # [2bs, 2bs]

        # Subtract the per-row max before exponentiating for numerical
        # stability; the shift cancels in the log-ratio computed below.
        logits_max, _ = torch.max(logits, dim=1, keepdim=True)
        logits = logits - logits_max.detach()
        exp_logits = logits.exp()

        n = logits.shape[0]  # n = 2bs

        # Drop self-pairs: build the off-diagonal mask once, directly on the
        # target device, and reuse it for all three matrices.
        off_diag = ~torch.eye(n, dtype=torch.bool, device=logits.device)
        logits = logits.masked_select(off_diag).view(n, n - 1)
        exp_logits = exp_logits.masked_select(off_diag).view(n, n - 1)
        label_diffs = label_diffs.masked_select(off_diag).view(n, n - 1)

        loss = 0.
        for k in range(n - 1):
            pos_logits = logits[:, k]  # [2bs]
            pos_label_diffs = label_diffs[:, k]  # [2bs]
            # Keep in the denominator only samples at least as far (in label
            # space) from the anchor as the current positive.
            neg_mask = (label_diffs >= pos_label_diffs.view(-1, 1)).float()  # [2bs, 2bs - 1]
            pos_log_probs = pos_logits - torch.log((neg_mask * exp_logits).sum(dim=-1))  # [2bs]
            loss += - (pos_log_probs / (n * (n - 1))).sum()

        return loss