-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathCOCOLoss.py
30 lines (23 loc) · 953 Bytes
/
COCOLoss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
# encoding:utf-8
import torch
import torch.nn.functional as F
import torch.nn as nn
from torch.nn import Parameter
class COCOLoss(nn.Module):
    """Scaled cosine-similarity head between features and learnable class centers.

    Both the input features and the class centers are L2-normalised along their
    last dimension, and the features are scaled by ``alpha`` before the dot
    product. ``forward`` therefore returns logits whose entries lie in
    ``[-alpha, alpha]``.

    NOTE(review): the name suggests the "congenerous cosine" (COCO) loss; the
    returned logits are presumably fed to a softmax cross-entropy by the caller
    -- confirm against the training code.

    Args:
        num_classes: Number of rows (class centers) in ``self.centers``.
        feat_dim: Size of the last dimension of the input features and centers.
        alpha: Scale applied to the normalised features.
        use_gpu: If true, the centers are created on the default CUDA device.
    """

    def __init__(self, num_classes, feat_dim=2048, alpha=50, use_gpu=True):
        super(COCOLoss, self).__init__()
        self.feat_dim = feat_dim
        self.num_classes = num_classes
        self.alpha = alpha
        self.use_gpu = use_gpu
        # Allocate the initial centers exactly once; previously a CPU parameter
        # was built and then thrown away when use_gpu was set.
        centers = torch.randn(num_classes, feat_dim)
        if use_gpu:
            centers = centers.cuda()
        self.centers = nn.Parameter(centers)

    def forward(self, feat):
        """Return ``alpha * normalize(feat) @ normalize(centers).T``.

        Args:
            feat: Tensor whose last dimension has size ``feat_dim``.

        Returns:
            Tensor of shape ``feat.shape[:-1] + (num_classes,)`` holding the
            scaled cosine similarities to each class center.
        """
        # F.normalize clamps the norm at a small eps, so an all-zero feature or
        # center row yields zeros instead of NaN (plain x / ||x|| gives 0/0).
        scaled_feat = self.alpha * F.normalize(feat, p=2, dim=-1)
        unit_centers = F.normalize(self.centers, p=2, dim=-1)
        logits = torch.matmul(scaled_feat, unit_centers.t())
        return logits