import torch
from torch import nn
from torch.nn import init
from torch.nn import Module
from config import config as cfg
from torch.nn.parameter import Parameter
import torch.nn.functional as F
import math
import torch.distributed as dist


class classifier(Module):
    """Sampled softmax classifier: the full class-center table is kept in a plain
    tensor on this rank's GPU, and only a sampled subset (`sub_weight`) is exposed
    to the optimizer as a Parameter."""

    @torch.no_grad()
    def __init__(self, in_features, out_features, sample_rate=1.0):
        super(classifier, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        # Full weight / momentum tables live outside the optimizer.
        self.weight = torch.empty([out_features, in_features], device=cfg.local_rank)
        self.momentum = torch.zeros_like(self.weight)
        self.sample_rate = sample_rate
        self.sub_num = int(sample_rate * out_features)
        self.stream = torch.cuda.Stream(cfg.local_rank)
        init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if sample_rate == 1.0:
            # No sampling: expose the full table directly and make update() a no-op.
            self.update = lambda: 0
            self.sub_weight = Parameter(self.weight)
            self.sub_momentum = self.momentum
        else:
            self.sub_weight = Parameter(torch.empty([0, 0]).cuda(cfg.local_rank))
            self.perm = torch.LongTensor(cfg.num).cuda(cfg.local_rank)

    @torch.no_grad()
    def sample(self, global_label):
        # Keep only labels owned by this rank ([cfg.s, cfg.s + cfg.num)) and shift
        # them to local indices; labels belonging to other ranks become -1.
        P = (cfg.s <= global_label) & (global_label < cfg.s + cfg.num)
        global_label[~P] = -1
        global_label[P] -= cfg.s
        if self.sample_rate != 1.0:
            positive = torch.unique(global_label[P], sorted=False)
            if self.sub_num - positive.size(0) > 0:
                # Pad the positive classes with random negatives up to sub_num centers.
                torch.randperm(cfg.num, out=self.perm)
                start = cfg.num - self.sub_num
                index = torch.cat((positive, self.perm[start:]))
                index = torch.unique(index, sorted=False)
                start = index.size()[0] - self.sub_num
                index = index[start:]
            else:
                index = positive
            index = torch.sort(index)[0].long()
            self.index = index
            # Remap local labels to positions within the sampled subset.
            global_label[P] = torch.searchsorted(index, global_label[P])
            self.sub_weight = Parameter(self.weight[index])
            self.sub_momentum = self.momentum[index]

    def forward(self, x_gather, norm_weight):
        # Wait for the sampling work queued on the side stream in prepare().
        torch.cuda.current_stream().wait_stream(self.stream)
        logits = F.linear(x_gather, norm_weight)
        return logits

    @torch.no_grad()
    def update(self):
        # Write the optimizer-updated sub-centers and their momentum back into the full tables.
        self.momentum[self.index] = self.sub_momentum
        self.weight[self.index] = self.sub_weight

    def prepare(self, label, optimizer):
        with torch.cuda.stream(self.stream):
            # Gather every rank's labels, sample the class centers, and swap the
            # sampled Parameter (and its momentum buffer) into the optimizer.
            label_gather = torch.zeros(label.size()[0] * cfg.world_size, device=cfg.local_rank, dtype=torch.long)
            dist.all_gather(list(label_gather.chunk(cfg.world_size, dim=0)), label)
            self.sample(label_gather)
            optimizer.state.pop(optimizer.param_groups[-1]['params'][0], None)
            optimizer.param_groups[-1]['params'][0] = self.sub_weight
            optimizer.state[self.sub_weight]['momentum_buffer'] = self.sub_momentum
            norm_weight = F.normalize(self.sub_weight)
        return label_gather, norm_weight
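

def _example_train_step(clf, features_gather, label, optimizer):
    """Hypothetical helper (not part of the original file): a sketch of the intended
    per-iteration call order around one optimizer step. It assumes `features_gather`
    already holds the features gathered from every rank, and it uses a plain
    per-rank cross-entropy as a stand-in loss; a real distributed setup would
    compute the softmax across all ranks instead."""
    label_gather, norm_weight = clf.prepare(label, optimizer)   # sample sub-centers, swap them into the optimizer
    logits = clf(features_gather, norm_weight)                  # logits against the sampled, normalized centers
    loss = F.cross_entropy(logits, label_gather, ignore_index=-1)  # labels owned by other ranks are -1
    loss.backward()
    optimizer.step()
    clf.update()                                                # copy updated sub-centers/momentum back
    return loss
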

if __name__ == "__main__":
    # Minimal single-process smoke test; cfg.num is set to match out_features,
    # which the sampling logic above assumes.
    cfg.local_rank = 0
    cfg.s = 0
    cfg.num = 25
    clss = classifier(5, 25, sample_rate=0.1)
    print(list(clss.parameters()))
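
    # A minimal sketch of one sampling step (single process, skipping
    # dist.all_gather; the random labels below are purely illustrative).
    fake_label = torch.randint(0, 25, (8,), device=cfg.local_rank)
    clss.sample(fake_label)
    print(clss.sub_weight.shape, clss.index)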