# mocov2.py
import copy
from typing import List, Tuple

import torch
from pytorch_lightning import LightningModule
from torch import Tensor
from torch.nn import Identity
from torch.optim import SGD
from torchvision.models import resnet50

from lightly.loss import NTXentLoss
from lightly.models.modules import MoCoProjectionHead
from lightly.models.utils import (
    batch_shuffle,
    batch_unshuffle,
    get_weight_decay_parameters,
    update_momentum,
)
from lightly.transforms import MoCoV2Transform
from lightly.utils.benchmarking import OnlineLinearClassifier
from lightly.utils.scheduler import CosineWarmupScheduler


class MoCoV2(LightningModule):
    def __init__(self, batch_size_per_device: int, num_classes: int) -> None:
        super().__init__()
        self.save_hyperparameters()
        self.batch_size_per_device = batch_size_per_device

        resnet = resnet50()
        resnet.fc = Identity()  # Ignore classification head
        self.backbone = resnet
        self.projection_head = MoCoProjectionHead()
        # The key encoder is a momentum-updated copy of the query encoder.
        self.key_backbone = copy.deepcopy(self.backbone)
        self.key_projection_head = MoCoProjectionHead()
        self.criterion = NTXentLoss(
            temperature=0.2,
            memory_bank_size=(65536, 128),
            gather_distributed=True,
        )

        self.online_classifier = OnlineLinearClassifier(num_classes=num_classes)

    def forward(self, x: Tensor) -> Tensor:
        return self.backbone(x)

    def forward_query_encoder(self, x: Tensor) -> Tuple[Tensor, Tensor]:
        features = self(x).flatten(start_dim=1)
        projections = self.projection_head(features)
        return features, projections

    @torch.no_grad()
    def forward_key_encoder(self, x: Tensor) -> Tensor:
        # Shuffle the batch across devices before encoding keys to prevent
        # information leaking through batch norm statistics.
        x, shuffle = batch_shuffle(batch=x, distributed=self.trainer.num_devices > 1)
        features = self.key_backbone(x).flatten(start_dim=1)
        projections = self.key_projection_head(features)
        features = batch_unshuffle(
            batch=features,
            shuffle=shuffle,
            distributed=self.trainer.num_devices > 1,
        )
        projections = batch_unshuffle(
            batch=projections,
            shuffle=shuffle,
            distributed=self.trainer.num_devices > 1,
        )
        return projections

    def training_step(
        self, batch: Tuple[List[Tensor], Tensor, List[str]], batch_idx: int
    ) -> Tensor:
        views, targets = batch[0], batch[1]

        # Encode queries.
        query_features, query_projections = self.forward_query_encoder(views[1])

        # Momentum update. This happens between query and key encoding, following the
        # original implementation from the authors:
        # https://github.com/facebookresearch/moco/blob/5a429c00bb6d4efdf511bf31b6f01e064bf929ab/moco/builder.py#L142
        update_momentum(self.backbone, self.key_backbone, m=0.999)
        update_momentum(self.projection_head, self.key_projection_head, m=0.999)

        # Encode keys.
        key_projections = self.forward_key_encoder(views[0])
        loss = self.criterion(query_projections, key_projections)
        self.log(
            "train_loss", loss, prog_bar=True, sync_dist=True, batch_size=len(targets)
        )

        # Online linear evaluation.
        cls_loss, cls_log = self.online_classifier.training_step(
            (query_features.detach(), targets), batch_idx
        )
        self.log_dict(cls_log, sync_dist=True, batch_size=len(targets))
        return loss + cls_loss

    def validation_step(
        self, batch: Tuple[Tensor, Tensor, List[str]], batch_idx: int
    ) -> Tensor:
        images, targets = batch[0], batch[1]
        features = self.forward(images).flatten(start_dim=1)
        cls_loss, cls_log = self.online_classifier.validation_step(
            (features.detach(), targets), batch_idx
        )
        self.log_dict(cls_log, prog_bar=True, sync_dist=True, batch_size=len(targets))
        return cls_loss

    def configure_optimizers(self):
        # Don't use weight decay for batch norm, bias parameters, and classification
        # head to improve performance.
        # NOTE: The original implementation from the authors uses weight decay for all
        # parameters.
        params, params_no_weight_decay = get_weight_decay_parameters(
            [self.backbone, self.projection_head]
        )
        optimizer = SGD(
            [
                {"name": "mocov2", "params": params},
                {
                    "name": "mocov2_no_weight_decay",
                    "params": params_no_weight_decay,
                    "weight_decay": 0.0,
                },
                {
                    "name": "online_classifier",
                    "params": self.online_classifier.parameters(),
                    "weight_decay": 0.0,
                },
            ],
            # Linear scaling rule: base learning rate 0.03 at a total batch size of 256.
            lr=0.03 * self.batch_size_per_device * self.trainer.world_size / 256,
            momentum=0.9,
            weight_decay=1e-4,
        )
        scheduler = {
            "scheduler": CosineWarmupScheduler(
                optimizer=optimizer,
                warmup_epochs=0,
                max_epochs=int(self.trainer.estimated_stepping_batches),
            ),
            "interval": "step",
        }
        return [optimizer], [scheduler]


transform = MoCoV2Transform()
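

# --- Usage sketch (illustrative, not part of the original file) ---
# A minimal example of how this module could be trained with PyTorch Lightning.
# The dataset path, batch size, and epoch count below are assumptions for
# illustration; the lightly benchmark suite drives this module through its own
# data pipeline.
if __name__ == "__main__":
    from pytorch_lightning import Trainer
    from torch.utils.data import DataLoader
    from torchvision.datasets import ImageFolder

    model = MoCoV2(batch_size_per_device=256, num_classes=1000)
    # MoCoV2Transform produces two augmented views per image, so the default
    # collate yields the (views, targets, ...) batch layout that training_step
    # expects.
    dataset = ImageFolder("path/to/imagenet/train", transform=transform)
    dataloader = DataLoader(
        dataset,
        batch_size=256,
        shuffle=True,
        num_workers=8,
        drop_last=True,
    )
    trainer = Trainer(max_epochs=100, accelerator="gpu", devices=1)
    trainer.fit(model, train_dataloaders=dataloader)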