-
Notifications
You must be signed in to change notification settings - Fork 0
/
critic.py
33 lines (28 loc) · 1.41 KB
/
critic.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
import torch
from torch import nn
class LinearCritic(nn.Module):
    """Projection head followed by a pairwise cosine-similarity critic.

    Two batches of representations are passed through the same two-layer
    projection (Linear -> BatchNorm -> ReLU -> Linear -> BatchNorm) and
    every pair of projected rows is scored with temperature-scaled cosine
    similarity.

    Args:
        latent_dim: Size of the last dimension of the inputs; also the
            width of the hidden projection layer.
        temperature: Divisor applied to every cosine similarity.
        projection_dim: Output width of the projection head. Defaults to
            128, the value that was previously hard-coded.
    """

    def __init__(self, latent_dim, temperature=1., projection_dim=128):
        super(LinearCritic, self).__init__()
        self.temperature = temperature
        self.projection_dim = projection_dim
        self.w1 = nn.Linear(latent_dim, latent_dim, bias=False)
        self.bn1 = nn.BatchNorm1d(latent_dim)
        self.relu = nn.ReLU()
        self.w2 = nn.Linear(latent_dim, self.projection_dim, bias=False)
        # Final normalisation has no learnable scale/shift (affine=False).
        self.bn2 = nn.BatchNorm1d(self.projection_dim, affine=False)
        self.cossim = nn.CosineSimilarity(dim=-1)

    def project(self, h):
        """Map representations ``h`` to the ``projection_dim``-wide space."""
        hidden = self.relu(self.bn1(self.w1(h)))
        return self.bn2(self.w2(hidden))

    def _pairwise_similarity(self, a, b):
        """Return temperature-scaled cosine similarity of all row pairs."""
        # Broadcasting (n, 1, dim) against (1, n, dim) yields an (n, n)
        # matrix whose [i, j] entry compares a[i] with b[j].
        return self.cossim(a.unsqueeze(-2), b.unsqueeze(-3)) / self.temperature

    def forward(self, h1, h2):
        """Score every pair of projected rows drawn from ``h1`` and ``h2``.

        Args:
            h1: First batch of representations; assumed to be
                ``(batch, latent_dim)`` -- TODO confirm against callers.
            h2: Second batch, same shape as ``h1``; row ``i`` is treated
                as the partner of ``h1[i]``.

        Returns:
            A tuple ``(raw_scores, targets)``. ``raw_scores`` is a
            ``(2n, 2n)`` matrix of scaled similarities in which each row's
            similarity with itself is ``-inf``. ``targets`` is
            ``arange(2n)``: for every row it is the column that holds the
            similarity with that row's partner in the other batch.
        """
        z1, z2 = self.project(h1), self.project(h2)
        sim11 = self._pairwise_similarity(z1, z1)
        sim22 = self._pairwise_similarity(z2, z2)
        sim12 = self._pairwise_similarity(z1, z2)

        n = sim12.shape[-1]
        # A row must never be matched with itself, so the diagonals of the
        # within-batch matrices are pushed to -inf. masked_fill is used
        # out-of-place to avoid mutating an autograd intermediate.
        self_pairs = torch.eye(n, dtype=torch.bool, device=sim12.device)
        sim11 = sim11.masked_fill(self_pairs, float('-inf'))
        sim22 = sim22.masked_fill(self_pairs, float('-inf'))

        # Block layout chosen so that the partner of row i sits in column i:
        #   rows 0..n-1  : [sim12   | sim11]
        #   rows n..2n-1 : [sim22   | sim12^T]
        top_rows = torch.cat([sim12, sim11], dim=-1)
        bottom_rows = torch.cat([sim22, sim12.transpose(-1, -2)], dim=-1)
        raw_scores = torch.cat([top_rows, bottom_rows], dim=-2)
        targets = torch.arange(2 * n, dtype=torch.long, device=raw_scores.device)
        return raw_scores, targets