-
Notifications
You must be signed in to change notification settings - Fork 1
/
fp16_optimizer.py
381 lines (314 loc) · 17.3 KB
/
fp16_optimizer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
import torch
from torch import nn
from torch.autograd import Variable
from torch.nn.parameter import Parameter
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
from loss_scaler import DynamicLossScaler, LossScaler
# Single-precision tensor classes (CPU and CUDA); fp32_to_fp16 matches values
# against this tuple with isinstance() to decide what gets cast to half.
FLOAT_TYPES = (torch.FloatTensor, torch.cuda.FloatTensor)
# Half-precision tensor classes (CPU and CUDA); fp16_to_fp32 matches values
# against this tuple with isinstance() to decide what gets cast to float.
HALF_TYPES = (torch.HalfTensor, torch.cuda.HalfTensor)
def conversion_helper(val, conversion):
    """Apply `conversion` to every leaf of a (possibly nested) tuple/list.

    Args:
        val: A single value, or a tuple/list whose elements are themselves
            values or further tuples/lists.
        conversion: Callable applied to each element that is not a tuple or
            list.

    Returns:
        ``conversion(val)`` when `val` is not a tuple or list. Otherwise a
        container with the same nesting: a ``tuple`` where the input level
        was a tuple, a ``list`` where it was a list.
    """
    if isinstance(val, (tuple, list)):
        converted = [conversion_helper(item, conversion) for item in val]
        # The comprehension always yields a list; restore tuple-ness if needed.
        return tuple(converted) if isinstance(val, tuple) else converted
    return conversion(val)
def fp32_to_fp16(val):
    """Return `val` with every fp32 tensor cast to fp16.

    `val` may be a single value or a nested tuple/list structure; the nesting
    is walked by `conversion_helper`. A value is cast with ``.half()`` only
    when it (or, for a Parameter/Variable, its ``.data``) is an instance of
    one of FLOAT_TYPES. Everything else is returned unchanged.
    """
    def _to_half(item):
        # Parameters/Variables are type-checked through their underlying tensor.
        probe = item.data if isinstance(item, (Parameter, Variable)) else item
        return item.half() if isinstance(probe, FLOAT_TYPES) else item

    return conversion_helper(val, _to_half)
def fp16_to_fp32(val):
    """Return `val` with every fp16 tensor cast to fp32.

    `val` may be a single value or a nested tuple/list structure; the nesting
    is walked by `conversion_helper`. A value is cast with ``.float()`` only
    when it (or, for a Parameter/Variable, its ``.data``) is an instance of
    one of HALF_TYPES. Everything else is returned unchanged.
    """
    def _to_float(item):
        # Parameters/Variables are type-checked through their underlying tensor.
        probe = item.data if isinstance(item, (Parameter, Variable)) else item
        return item.float() if isinstance(probe, HALF_TYPES) else item

    return conversion_helper(val, _to_float)
class FP16_Module(nn.Module):
    """Run a wrapped module in half precision behind an fp32 interface.

    The wrapped module is converted with ``.half()`` and registered as the
    submodule ``module``.
    """

    def __init__(self, module):
        super(FP16_Module, self).__init__()
        half_module = module.half()
        self.add_module('module', half_module)

    def forward(self, *inputs, **kwargs):
        """Cast positional inputs to fp16, call the module, cast outputs to fp32.

        Keyword arguments are forwarded to the wrapped module as given; only
        the positional inputs go through `fp32_to_fp16`.
        """
        half_inputs = fp32_to_fp16(inputs)
        outputs = self.module(*half_inputs, **kwargs)
        return fp16_to_fp32(outputs)
class FP16_Optimizer(object):
    """
    FP16_Optimizer is designed to wrap an existing PyTorch optimizer,
    and enable an fp16 model to be trained using a master copy of fp32 weights.

    Args:
        optimizer (torch.optim.optimizer): Existing optimizer containing
            initialized fp16 parameters. Internally, FP16_Optimizer replaces
            the passed optimizer's fp16 parameters with new fp32 parameters
            copied from the original ones. FP16_Optimizer also stores
            references to the original fp16 parameters, and updates these
            fp16 parameters from the master fp32 copy after each step.
        static_loss_scale (float, optional, default=1.0): Loss scale used
            internally to scale fp16 gradients computed by the model. Scaled
            gradients will be copied to fp32, then downscaled before being
            applied to the fp32 master params, so static_loss_scale should
            not affect learning rate.
        dynamic_loss_scale (bool, optional, default=False): Use dynamic loss
            scaling. If True, this will override any static_loss_scale option.

    Raises:
        SystemError: If CUDA is not available.
        TypeError: If a parameter that requires grad is neither a
            torch.cuda.FloatTensor nor a torch.cuda.HalfTensor.
    """

    def __init__(self, optimizer, static_loss_scale=1.0, dynamic_loss_scale=False):
        # is_available is a function: it has to be *called*. Testing the
        # function object itself is always truthy and the guard never fires.
        if not torch.cuda.is_available():
            raise SystemError('Cannot use fp16 without CUDA')

        # Per param group: the model's fp16 params, the model's fp32 params,
        # and the single flat fp32 master tensor standing in for the fp16 ones
        # (None when the group holds no fp16 params).
        self.fp16_param_groups = []
        self.fp32_param_groups = []
        self.fp32_flattened_groups = []

        for i, param_group in enumerate(optimizer.param_groups):
            print("FP16_Optimizer processing param group {}:".format(i))
            fp16_params_this_group = []
            fp32_params_this_group = []
            for param in param_group['params']:
                if param.requires_grad:
                    if param.type() == 'torch.cuda.HalfTensor':
                        print("FP16_Optimizer received torch.cuda.HalfTensor with {}"
                              .format(param.size()))
                        fp16_params_this_group.append(param)
                    elif param.type() == 'torch.cuda.FloatTensor':
                        print("FP16_Optimizer received torch.cuda.FloatTensor with {}"
                              .format(param.size()))
                        fp32_params_this_group.append(param)
                    else:
                        raise TypeError("Wrapped parameters must be either "
                                        "torch.cuda.FloatTensor or torch.cuda.HalfTensor. "
                                        "Received {}".format(param.type()))

            fp32_flattened_this_group = None
            if len(fp16_params_this_group) > 0:
                # One contiguous fp32 copy of all fp16 params in this group.
                fp32_flattened_this_group = _flatten_dense_tensors(
                    [param.detach().data.clone().float() for param in fp16_params_this_group])
                fp32_flattened_this_group = Variable(fp32_flattened_this_group,
                                                     requires_grad=True)
                fp32_flattened_this_group.grad = self._new_fp32_grad_buffer(
                    fp32_flattened_this_group)

            # The optimizer now sees the flat master (if any) followed by the
            # untouched fp32 params.
            if fp32_flattened_this_group is not None:
                param_group['params'] = [fp32_flattened_this_group] + fp32_params_this_group
            else:
                param_group['params'] = fp32_params_this_group

            self.fp16_param_groups.append(fp16_params_this_group)
            self.fp32_param_groups.append(fp32_params_this_group)
            self.fp32_flattened_groups.append(fp32_flattened_this_group)

        # Rebuild an optimizer of the same class over the rewritten groups.
        # NOTE: only param_groups are handed over; the passed optimizer's
        # state_dict is not loaded into the new instance.
        self.optimizer = optimizer.__class__(optimizer.param_groups)
        self.param_groups = self.optimizer.param_groups

        if dynamic_loss_scale:
            self.dynamic_loss_scale = True
            self.loss_scaler = DynamicLossScaler()
        else:
            self.dynamic_loss_scale = False
            self.loss_scaler = LossScaler(static_loss_scale)

        self.overflow = False
        self.first_closure_call_this_step = True

    @staticmethod
    def _new_fp32_grad_buffer(master):
        """Return a zero-filled buffer shaped like `master`, for use as its .grad."""
        # .new() alone leaves the contents unset; zero them so a read before
        # the first fp16->fp32 grad copy sees well-defined values.
        return master.new(*master.size()).zero_()

    def zero_grad(self):
        """
        Zero fp32 and fp16 parameter grads.
        """
        self.optimizer.zero_grad()
        for fp16_group in self.fp16_param_groups:
            for param in fp16_group:
                if param.grad is not None:
                    # Mirrors torch.optim.Optimizer.zero_grad(): detach, then zero.
                    param.grad.detach_()
                    param.grad.zero_()

    def _check_overflow(self):
        """Ask the loss scaler whether any model grad overflowed; store in self.overflow."""
        params = [param for group in self.fp16_param_groups for param in group]
        params += [param for group in self.fp32_param_groups for param in group]
        self.overflow = self.loss_scaler.has_overflow(params)

    def _update_scale(self, has_overflow=False):
        """Forward the overflow flag to the loss scaler's update_scale."""
        self.loss_scaler.update_scale(has_overflow)

    def _copy_grads_fp16_to_fp32(self):
        """Copy each group's fp16 grads, flattened, into the flat fp32 master's .grad."""
        for fp32_group, fp16_group in zip(self.fp32_flattened_groups, self.fp16_param_groups):
            if len(fp16_group) > 0:
                # This might incur one more deep copy than is necessary.
                flat_fp16_grads = _flatten_dense_tensors(
                    [fp16_param.grad.data for fp16_param in fp16_group])
                if fp32_group.grad is None:
                    # The wrapped optimizer's zero_grad() may have reset the
                    # master grad to None; re-create the buffer before copying.
                    fp32_group.grad = self._new_fp32_grad_buffer(fp32_group)
                fp32_group.grad.data.copy_(flat_fp16_grads)

    def _downscale_fp32(self):
        """Divide every fp32 grad held by the optimizer by the current loss scale."""
        if self.loss_scale != 1.0:
            for param_group in self.optimizer.param_groups:
                for param in param_group['params']:
                    # A parameter without a gradient has nothing to unscale.
                    if param.grad is not None:
                        param.grad.data.mul_(1./self.loss_scale)

    def clip_fp32_grads(self, clip=-1):
        """Clip the norm of the fp32 grads held by the optimizer.

        Args:
            clip (float, optional, default=-1): Maximum norm. Clipping is only
                performed when this is greater than 0.

        Returns:
            The value returned by torch.nn.utils.clip_grad_norm when clipping
            was performed; None if an overflow is flagged or ``clip <= 0``.
        """
        if not self.overflow:
            fp32_params = [param
                           for param_group in self.optimizer.param_groups
                           for param in param_group['params']]
            if clip > 0:
                return torch.nn.utils.clip_grad_norm(fp32_params, clip)

    def _copy_params_fp32_to_fp16(self):
        """Write the flat fp32 master values back into the model's fp16 params."""
        for fp16_group, fp32_group in zip(self.fp16_param_groups, self.fp32_flattened_groups):
            if len(fp16_group) > 0:
                # Unflatten using the fp16 params as the shape template.
                unflattened = _unflatten_dense_tensors(fp32_group.data, fp16_group)
                for fp16_param, fp32_data in zip(fp16_group, unflattened):
                    fp16_param.data.copy_(fp32_data)

    def state_dict(self):
        """
        Returns a dict containing the current state of this FP16_Optimizer instance.
        This dict contains attributes of FP16_Optimizer, as well as the state_dict
        of the contained Pytorch optimizer.

        The loss scaler is stored as the object itself, not as a nested dict.

        Untested.
        """
        return {
            'loss_scaler': self.loss_scaler,
            'dynamic_loss_scale': self.dynamic_loss_scale,
            'overflow': self.overflow,
            'first_closure_call_this_step': self.first_closure_call_this_step,
            'optimizer_state_dict': self.optimizer.state_dict(),
        }

    def load_state_dict(self, state_dict):
        """
        Loads a state_dict created by an earlier call to state_dict.

        Untested.
        """
        self.loss_scaler = state_dict['loss_scaler']
        self.dynamic_loss_scale = state_dict['dynamic_loss_scale']
        self.overflow = state_dict['overflow']
        self.first_closure_call_this_step = state_dict['first_closure_call_this_step']
        self.optimizer.load_state_dict(state_dict['optimizer_state_dict'])

    def step(self, closure=None):  # could add clip option.
        """
        If no closure is supplied, step should be called after
        fp16_optimizer_obj.backward(loss).
        step updates the fp32 master copy of parameters using the optimizer supplied to
        FP16_Optimizer's constructor, then copies the updated fp32 params into the fp16 params
        originally referenced by Fp16_Optimizer's constructor, so the user may immediately run
        another forward pass using their model.

        If a closure is supplied, step may be called without a prior call to
        self.backward(loss).
        However, the user should take care that any loss.backward() call within the closure
        has been replaced by fp16_optimizer_obj.backward(loss).

        If an overflow is flagged, the step is skipped: a message is printed
        and neither the master params nor the fp16 params are updated.

        Args:
            closure (optional): Closure that will be supplied to the underlying
                optimizer originally passed to FP16_Optimizer's constructor.
                closure should call zero_grad on the FP16_Optimizer object,
                compute the loss, call .backward(loss), and return the loss.

        Raises:
            TypeError: If a closure is supplied while the loss scaler is a
                DynamicLossScaler.

        Closure example::

            # optimizer is assumed to be an FP16_Optimizer object, previously constructed
            # from an existing pytorch optimizer.
            for input, target in dataset:
                def closure():
                    optimizer.zero_grad()
                    output = model(input)
                    loss = loss_fn(output, target)
                    optimizer.backward(loss)
                    return loss
                optimizer.step(closure)

        .. note::
            The only changes that need to be made compared to
            `ordinary optimizer closures`_ are that "optimizer" itself should be an
            instance of FP16_Optimizer, and that the call to loss.backward should be
            replaced by optimizer.backward(loss).

        .. warning::
            Currently, calling step with a closure is not compatible with dynamic
            loss scaling.

        .. _`ordinary optimizer closures`:
            http://pytorch.org/docs/master/optim.html#optimizer-step-closure
        """
        if closure is not None and isinstance(self.loss_scaler, DynamicLossScaler):
            raise TypeError("Using step with a closure is currently not "
                            "compatible with dynamic loss scaling.")

        # Read the scale before updating it so the message below reports the
        # scale that was actually attempted.
        scale = self.loss_scaler.loss_scale
        self._update_scale(self.overflow)

        if self.overflow:
            print("OVERFLOW! Skipping step. Attempted loss scale: {}".format(scale))
            return

        if closure is not None:
            self._step_with_closure(closure)
        else:
            self.optimizer.step()

        self._copy_params_fp32_to_fp16()

    def _step_with_closure(self, closure):
        """Run the wrapped optimizer's step with `closure`, keeping fp16 params fresh."""
        def wrapped_closure():
            if self.first_closure_call_this_step:
                # We expect that the fp16 params are initially fresh on entering
                # self.step(), so _copy_params_fp32_to_fp16() is unnecessary the
                # first time wrapped_closure() is called within
                # self.optimizer.step().
                self.first_closure_call_this_step = False
            else:
                # If self.optimizer.step() internally calls wrapped_closure more
                # than once, it may update the fp32 params after each call.
                # However, self.optimizer doesn't know about the fp16 params at
                # all. If the fp32 params get updated, we can't rely on
                # self.optimizer to refresh the fp16 params. We need to handle
                # that manually:
                self._copy_params_fp32_to_fp16()

            # Our API expects the user to give us ownership of the backward()
            # call by replacing all calls to loss.backward() with
            # optimizer.backward(loss). This requirement holds whether or not
            # the call to backward() is made within a closure.
            # If the user is properly calling optimizer.backward(loss) within
            # "closure," calling closure() here will give the fp32 master params
            # fresh gradients for the optimizer to play with, so all
            # wrapped_closure needs to do is call closure() and return the loss.
            temp_loss = closure()
            return temp_loss

        self.optimizer.step(wrapped_closure)

        # Re-arm the flag for the next call to step().
        self.first_closure_call_this_step = True

    def backward(self, loss, update_fp32_grads=True):
        """
        fp16_optimizer_obj.backward performs the following conceptual operations:

        fp32_loss = loss.float() (see first Note below)

        scaled_loss = fp32_loss*loss_scale

        scaled_loss.backward(), which accumulates scaled gradients into the .grad
        attributes of the fp16 model's leaves.

        fp16 grads are then copied to the stored fp32 params' .grad attributes
        (see second Note).

        Finally, fp32 grads are divided by loss_scale.

        In this way, after fp16_optimizer_obj.backward, the fp32 parameters have fresh
        gradients, and fp16_optimizer_obj.step may be called.

        .. note::
            Converting the loss to fp32 before applying the loss scale provides some
            additional safety against overflow if the user has supplied an fp16 value.
            However, for maximum overflow safety, the user should
            compute the loss criterion (MSE, cross entropy, etc) in fp32 before
            supplying it to fp16_optimizer_obj.backward.

        .. note::
            The gradients found in an fp16 model's leaves after a call to
            fp16_optimizer_obj.backward should not be regarded as valid in general,
            because it's possible they have been scaled (and in the case of dynamic
            loss scaling, the scale factor may silently change over time).
            If the user wants to inspect gradients after a call to
            fp16_optimizer_obj.backward, he/she should query the .grad attribute of
            FP16_Optimizer's stored fp32 parameters.

        Args:
            loss: The loss output by the user's model. loss may be either float or
                half (but see first Note above).
            update_fp32_grads (bool, optional, default=True): Option to copy fp16
                grads to fp32 grads on this call. By setting this to False, the user
                can delay this copy, which is useful to eliminate redundant
                fp16->fp32 grad copies if fp16_optimizer_obj.backward is being called
                on multiple losses in one iteration. If set to False, the user
                becomes responsible for calling fp16_optimizer_obj.update_fp32_grads
                before calling fp16_optimizer_obj.step.

        Example::

            # Ordinary operation:
            optimizer.backward(loss)

            # Naive operation with multiple losses (technically valid, but less efficient):
            # fp32 grads will be correct after the second call, but
            # the first call incurs an unnecessary fp16->fp32 grad copy.
            optimizer.backward(loss1)
            optimizer.backward(loss2)

            # More efficient way to handle multiple losses:
            # The fp16->fp32 grad copy is delayed until fp16 grads from all
            # losses have been accumulated.
            optimizer.backward(loss1, update_fp32_grads=False)
            optimizer.backward(loss2, update_fp32_grads=False)
            optimizer.update_fp32_grads()
        """
        self.loss_scaler.backward(loss.float())
        if update_fp32_grads:
            self.update_fp32_grads()

    def update_fp32_grads(self):
        """
        Copy the .grad attribute from stored references to fp16 parameters to
        the .grad attribute of the master fp32 parameters that are directly
        updated by the optimizer. :attr:`update_fp32_grads` only needs to be called if
        fp16_optimizer_obj.backward was called with update_fp32_grads=False.

        With dynamic loss scaling, overflow is checked first; when an overflow
        is found the copy and downscale are skipped.
        """
        if self.dynamic_loss_scale:
            self._check_overflow()
            if self.overflow:
                return
        self._copy_grads_fp16_to_fp32()
        self._downscale_fp32()

    @property
    def loss_scale(self):
        """Current loss scale, as reported by the loss scaler."""
        return self.loss_scaler.loss_scale