swa_hook.py

import os.path as osp
from copy import deepcopy

import torch
from mmcv.runner import HOOKS, Hook
from mmcv.runner.checkpoint import save_checkpoint
from mmcv.runner.log_buffer import LogBuffer
from mmdet.core import DistEvalHook, EvalHook


@HOOKS.register_module()
class SWAHook(Hook):
    r"""SWA Object Detection Hook.

    This hook works together with the SWA training config files to train
    SWA object detectors <https://arxiv.org/abs/2012.12645>.

    Args:
        swa_eval (bool): Whether to evaluate the swa model.
            Defaults to True.
        eval_hook (Hook): Hook class that contains evaluation functions.
            Defaults to None.
        swa_interval (int): The epoch interval to perform swa.
            Defaults to 1.
    """

    def __init__(self, swa_eval=True, eval_hook=None, swa_interval=1):
        if not isinstance(swa_eval, bool):
            raise TypeError('swa_eval must be a bool, but got '
                            f'{type(swa_eval)}')
        if swa_eval:
            if not isinstance(eval_hook, EvalHook) and \
                    not isinstance(eval_hook, DistEvalHook):
                raise TypeError('eval_hook must be either an EvalHook or a '
                                'DistEvalHook when swa_eval = True, but got '
                                f'{type(eval_hook)}')
        self.swa_eval = swa_eval
        self.eval_hook = eval_hook
        self.swa_interval = swa_interval

    def before_run(self, runner):
        """Construct the averaged model which will keep track of the running
        averages of the parameters of the model."""
        model = runner.model
        self.model = AveragedModel(model)
        self.meta = runner.meta
        if self.meta is None:
            self.meta = dict()
            self.meta.setdefault('hook_msgs', dict())
        if isinstance(self.meta, dict) and 'hook_msgs' not in self.meta:
            self.meta.setdefault('hook_msgs', dict())
        self.log_buffer = LogBuffer()

    def after_train_epoch(self, runner):
        """Update the parameters of the averaged model, save and evaluate the
        updated averaged model."""
        model = runner.model
        # Whether to perform swa
        if (runner.epoch + 1) % self.swa_interval == 0:
            swa_flag = True
        else:
            swa_flag = False
        # update the parameters of the averaged model
        if swa_flag:
            self.model.update_parameters(model)

            # save the swa model
            runner.logger.info(
                f'Saving swa model at swa-training {runner.epoch + 1} epoch')
            filename = 'swa_model_{}.pth'.format(runner.epoch + 1)
            filepath = osp.join(runner.work_dir, filename)
            optimizer = runner.optimizer
            self.meta['hook_msgs']['last_ckpt'] = filepath
            save_checkpoint(
                self.model.module,
                filepath,
                optimizer=optimizer,
                meta=self.meta)

        # evaluate the swa model
        if self.swa_eval and swa_flag:
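            # The eval hook treats its argument as a runner, so copy the
            # runner attributes it reads (work_dir, rank, epoch, logger, log
            # buffer, meta) onto this hook and pass it in place of the runner;
            # self.model then points at the averaged model being evaluated.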
            self.work_dir = runner.work_dir
            self.rank = runner.rank
            self.epoch = runner.epoch
            self.logger = runner.logger
            self.meta['hook_msgs']['last_ckpt'] = filename
            self.eval_hook.after_train_epoch(self)
            for name, val in self.log_buffer.output.items():
                name = 'swa_' + name
                runner.log_buffer.output[name] = val
            runner.log_buffer.ready = True
            self.log_buffer.clear()

    def after_run(self, runner):
        # since BN layers in the backbone are frozen,
        # we do not need to update the BN for the swa model
        pass

    def before_epoch(self, runner):
        pass


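# A minimal usage sketch, not part of the original file: assuming an mmcv
# EpochBasedRunner `runner` and a validation dataloader `val_dataloader` have
# already been built, the hook could be attached roughly like this:
#
#     from mmdet.core import EvalHook
#     eval_hook = EvalHook(val_dataloader, interval=1)
#     runner.register_hook(
#         SWAHook(swa_eval=True, eval_hook=eval_hook, swa_interval=1))
#
# The SWA training config files mentioned in the docstring above are the
# intended entry point; this snippet only illustrates the constructor
# arguments.

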
class AveragedModel(torch.nn.Module):
    r"""Implements averaged model for Stochastic Weight Averaging (SWA).

    AveragedModel creates a copy of the provided model on the given device
    and allows computing running averages of the parameters of the model.

    Args:
        model (torch.nn.Module): model to use with SWA.
        device (torch.device, optional): if provided, the averaged model
            will be stored on the device. Defaults to None.
        avg_fn (function, optional): the averaging function used to update
            parameters; the function must take in the current value of the
            AveragedModel parameter, the current value of the model
            parameter and the number of models already averaged; if None,
            an equally weighted average is used. Defaults to None.
    """

    def __init__(self, model, device=None, avg_fn=None):
        super(AveragedModel, self).__init__()
        self.module = deepcopy(model)
        if device is not None:
            self.module = self.module.to(device)
        self.register_buffer('n_averaged',
                             torch.tensor(0, dtype=torch.long, device=device))
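        # Default averaging function: an equally weighted running mean,
        # avg_{n+1} = avg_n + (x_{n+1} - avg_n) / (n + 1).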
        if avg_fn is None:
            def avg_fn(averaged_model_parameter, model_parameter,
                       num_averaged):
                return averaged_model_parameter + (
                    model_parameter - averaged_model_parameter) / (
                        num_averaged + 1)
        self.avg_fn = avg_fn

    def forward(self, *args, **kwargs):
        return self.module(*args, **kwargs)

    def update_parameters(self, model):
        for p_swa, p_model in zip(self.parameters(), model.parameters()):
            device = p_swa.device
            p_model_ = p_model.detach().to(device)
            if self.n_averaged == 0:
                p_swa.detach().copy_(p_model_)
            else:
                p_swa.detach().copy_(
                    self.avg_fn(p_swa.detach(), p_model_,
                                self.n_averaged.to(device)))
        self.n_averaged += 1
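

# A minimal sketch, not part of the original file, of how AveragedModel keeps
# a running average; `train_one_epoch` and `num_epochs` are hypothetical and
# the linear layer merely stands in for a detector:
#
#     net = torch.nn.Linear(2, 2)
#     swa_net = AveragedModel(net)
#     for _ in range(num_epochs):
#         train_one_epoch(net)            # hypothetical training step
#         swa_net.update_parameters(net)  # fold current weights into the mean
#     # swa_net.module now holds the equal-weight average of the snapshots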