forked from PaddlePaddle/PaddleDetection
-
Notifications
You must be signed in to change notification settings - Fork 0
/
resnet.py
executable file
·609 lines (542 loc) · 19.4 KB
/
resnet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from numbers import Integral
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from ppdet.core.workspace import register, serializable
from paddle.regularizer import L2Decay
from paddle.nn.initializer import Uniform
from paddle import ParamAttr
from paddle.nn.initializer import Constant
from paddle.vision.ops import DeformConv2D
from .name_adapter import NameAdapter
from ..shape_spec import ShapeSpec
# Public names exported by a star-import of this module.
__all__ = ['ResNet', 'Res5Head', 'Blocks', 'BasicBlock', 'BottleNeck']
# Maps each supported ResNet depth to the number of residual blocks in each
# of its four stages. ResNet indexes this table with `depth` and passes entry
# i as the block count of stage i; an unlisted depth raises KeyError there.
ResNet_cfg = {
18: [2, 2, 2, 2],
34: [3, 4, 6, 3],
50: [3, 4, 6, 3],
101: [3, 4, 23, 3],
152: [3, 8, 36, 3],
}
class ConvNormLayer(nn.Layer):
    """Convolution followed by batch normalization and an optional activation.

    The convolution is a plain ``nn.Conv2D`` or, when ``dcn_v2`` is set, a
    ``DeformConv2D`` whose offsets and mask are predicted from the input by an
    extra convolution (``conv_offset``) initialized to all zeros.

    Args:
        ch_in (int): number of input channels.
        ch_out (int): number of output channels.
        filter_size (int): square kernel size; padding is
            ``(filter_size - 1) // 2``.
        stride (int): stride of the convolution (and of ``conv_offset``).
        groups (int): number of convolution groups.
        act (str|None): name of a function in ``paddle.nn.functional`` to
            apply after normalization; a falsy value skips the activation.
        norm_type (str): ``'bn'`` or ``'sync_bn'``; both build
            ``nn.BatchNorm2D`` here.
        norm_decay (float): L2 decay coefficient for the norm scale and bias.
        freeze_norm (bool): when true, the norm parameters get learning rate
            0, are created non-trainable, have ``stop_gradient`` set, and the
            norm layer uses global statistics.
        lr (float): learning-rate multiplier for the convolution weight and,
            when the norm is not frozen, for the norm parameters.
        dcn_v2 (bool): use the deformable convolution path.
    """

    def __init__(self,
                 ch_in,
                 ch_out,
                 filter_size,
                 stride,
                 groups=1,
                 act=None,
                 norm_type='bn',
                 norm_decay=0.,
                 freeze_norm=True,
                 lr=1.0,
                 dcn_v2=False):
        super(ConvNormLayer, self).__init__()
        assert norm_type in ['bn', 'sync_bn']
        self.norm_type = norm_type
        self.act = act
        self.dcn_v2 = dcn_v2

        # Same padding expression is used by every convolution built below.
        pad = (filter_size - 1) // 2

        if self.dcn_v2:
            # conv_offset emits 3 * k^2 channels: the first 2 * k^2 are
            # offsets, the remaining k^2 the mask (split again in forward).
            self.offset_channel = 2 * filter_size**2
            self.mask_channel = filter_size**2
            self.conv_offset = nn.Conv2D(
                in_channels=ch_in,
                out_channels=3 * filter_size**2,
                kernel_size=filter_size,
                stride=stride,
                padding=pad,
                weight_attr=ParamAttr(initializer=Constant(0.)),
                bias_attr=ParamAttr(initializer=Constant(0.)))
            self.conv = DeformConv2D(
                in_channels=ch_in,
                out_channels=ch_out,
                kernel_size=filter_size,
                stride=stride,
                padding=pad,
                dilation=1,
                groups=groups,
                weight_attr=ParamAttr(learning_rate=lr),
                bias_attr=False)
        else:
            self.conv = nn.Conv2D(
                in_channels=ch_in,
                out_channels=ch_out,
                kernel_size=filter_size,
                stride=stride,
                padding=pad,
                groups=groups,
                weight_attr=ParamAttr(learning_rate=lr),
                bias_attr=False)

        norm_lr = 0. if freeze_norm else lr

        def _norm_param_attr():
            # Scale and bias each get their own, identically configured attr.
            return ParamAttr(
                learning_rate=norm_lr,
                regularizer=L2Decay(norm_decay),
                trainable=not freeze_norm)

        scale_attr = _norm_param_attr()
        shift_attr = _norm_param_attr()
        use_global = True if freeze_norm else None
        if norm_type in ['sync_bn', 'bn']:
            self.norm = nn.BatchNorm2D(
                ch_out,
                weight_attr=scale_attr,
                bias_attr=shift_attr,
                use_global_stats=use_global)
        frozen_candidates = self.norm.parameters()

        if freeze_norm:
            for norm_param in frozen_candidates:
                norm_param.stop_gradient = True

    def forward(self, inputs):
        """Apply convolution, normalization and the optional activation."""
        if self.dcn_v2:
            predicted = self.conv_offset(inputs)
            offset, mask = paddle.split(
                predicted,
                num_or_sections=[self.offset_channel, self.mask_channel],
                axis=1)
            mask = F.sigmoid(mask)
            out = self.conv(inputs, offset, mask=mask)
        else:
            out = self.conv(inputs)

        if self.norm_type in ['bn', 'sync_bn']:
            out = self.norm(out)
        if self.act:
            # The activation is looked up by name in paddle.nn.functional.
            activation = getattr(F, self.act)
            out = activation(out)
        return out
class SELayer(nn.Layer):
    """Channel gating: pool globally, pass through two linear layers, rescale.

    Args:
        ch (int): number of channels of the input feature map.
        reduction_ratio (int): the hidden width is ``ch // reduction_ratio``.
            NOTE: if that quotient is 0 the second init bound divides by zero.
    """

    def __init__(self, ch, reduction_ratio=16):
        super(SELayer, self).__init__()
        self.pool = nn.AdaptiveAvgPool2D(1)

        hidden = ch // reduction_ratio

        # Each linear layer draws its weights uniformly from
        # [-1/sqrt(fan_in), 1/sqrt(fan_in)].
        bound_in = 1.0 / math.sqrt(ch)
        self.squeeze = nn.Linear(
            ch,
            hidden,
            weight_attr=paddle.ParamAttr(
                initializer=Uniform(-bound_in, bound_in)),
            bias_attr=True)

        bound_hidden = 1.0 / math.sqrt(hidden)
        self.extract = nn.Linear(
            hidden,
            ch,
            weight_attr=paddle.ParamAttr(
                initializer=Uniform(-bound_hidden, bound_hidden)),
            bias_attr=True)

    def forward(self, inputs):
        """Return ``inputs`` multiplied by a per-channel sigmoid gate."""
        pooled = paddle.squeeze(self.pool(inputs), axis=[2, 3])
        hidden = F.relu(self.squeeze(pooled))
        gate = F.sigmoid(self.extract(hidden))
        # Restore the two trailing axes so the gate broadcasts over inputs.
        gate = paddle.unsqueeze(gate, axis=[2, 3])
        return gate * inputs
class BasicBlock(nn.Layer):
    """Residual block made of two 3x3 ConvNormLayers plus a shortcut.

    Args:
        ch_in (int): input channels.
        ch_out (int): output channels (``expansion`` is 1).
        stride (int): stride of the first 3x3 convolution and of the
            projection shortcut.
        shortcut (bool): true to add the input unchanged; false to project it
            with a 1x1 ConvNormLayer first.
        variant (str): with ``'d'`` and ``stride == 2`` the projection is a
            2x2 average pool followed by a stride-1 1x1 convolution.
        groups (int): must be 1.
        base_width (int): must be 64.
        lr (float): learning-rate multiplier forwarded to every ConvNormLayer.
        norm_type (str): forwarded to every ConvNormLayer.
        norm_decay (float): forwarded to every ConvNormLayer.
        freeze_norm (bool): forwarded to every ConvNormLayer.
        dcn_v2 (bool): use a deformable convolution for the second 3x3 conv.
        std_senet (bool): apply an SELayer to the residual branch output.
    """

    # Multiplier from ch_out to the block's actual output channels.
    expansion = 1

    def __init__(self,
                 ch_in,
                 ch_out,
                 stride,
                 shortcut,
                 variant='b',
                 groups=1,
                 base_width=64,
                 lr=1.0,
                 norm_type='bn',
                 norm_decay=0.,
                 freeze_norm=True,
                 dcn_v2=False,
                 std_senet=False):
        super(BasicBlock, self).__init__()
        assert groups == 1 and base_width == 64, 'BasicBlock only supports groups=1 and base_width=64'

        # Settings shared by every ConvNormLayer in this block.
        common = dict(
            norm_type=norm_type,
            norm_decay=norm_decay,
            freeze_norm=freeze_norm,
            lr=lr)

        self.shortcut = shortcut
        if not shortcut:
            pooled_projection = variant == 'd' and stride == 2
            if pooled_projection:
                # Downsample with the pool, then project at stride 1.
                self.short = nn.Sequential()
                self.short.add_sublayer(
                    'pool',
                    nn.AvgPool2D(
                        kernel_size=2, stride=2, padding=0, ceil_mode=True))
                self.short.add_sublayer(
                    'conv',
                    ConvNormLayer(
                        ch_in=ch_in,
                        ch_out=ch_out,
                        filter_size=1,
                        stride=1,
                        **common))
            else:
                self.short = ConvNormLayer(
                    ch_in=ch_in,
                    ch_out=ch_out,
                    filter_size=1,
                    stride=stride,
                    **common)

        self.branch2a = ConvNormLayer(
            ch_in=ch_in,
            ch_out=ch_out,
            filter_size=3,
            stride=stride,
            act='relu',
            **common)

        self.branch2b = ConvNormLayer(
            ch_in=ch_out,
            ch_out=ch_out,
            filter_size=3,
            stride=1,
            act=None,
            dcn_v2=dcn_v2,
            **common)

        self.std_senet = std_senet
        if self.std_senet:
            self.se = SELayer(ch_out)

    def forward(self, inputs):
        """Return relu(residual_branch(inputs) + shortcut(inputs))."""
        residual = self.branch2b(self.branch2a(inputs))
        if self.std_senet:
            residual = self.se(residual)

        identity = inputs if self.shortcut else self.short(inputs)

        summed = paddle.add(x=residual, y=identity)
        return F.relu(summed)
class BottleNeck(nn.Layer):
    """Residual block of 1x1, 3x3 and 1x1 ConvNormLayers plus a shortcut.

    Args:
        ch_in (int): input channels.
        ch_out (int): base channels; the block outputs
            ``ch_out * expansion`` channels.
        stride (int): block stride. Variant ``'a'`` applies it in the first
            1x1 convolution, every other variant in the 3x3 convolution.
        shortcut (bool): true to add the input unchanged; false to project it
            with a 1x1 ConvNormLayer first.
        variant (str): besides the stride placement above, ``'d'`` with
            ``stride == 2`` makes the projection a 2x2 average pool followed
            by a stride-1 1x1 convolution.
        groups (int): groups of the 3x3 convolution.
        base_width (int): the inner width is
            ``int(ch_out * (base_width / 64.)) * groups``.
        lr (float): learning-rate multiplier forwarded to every ConvNormLayer.
        norm_type (str): forwarded to every ConvNormLayer.
        norm_decay (float): forwarded to every ConvNormLayer.
        freeze_norm (bool): forwarded to every ConvNormLayer.
        dcn_v2 (bool): use a deformable convolution for the 3x3 convolution.
        std_senet (bool): apply an SELayer to the residual branch output.
    """

    # Multiplier from ch_out to the block's actual output channels.
    expansion = 4

    def __init__(self,
                 ch_in,
                 ch_out,
                 stride,
                 shortcut,
                 variant='b',
                 groups=1,
                 base_width=4,
                 lr=1.0,
                 norm_type='bn',
                 norm_decay=0.,
                 freeze_norm=True,
                 dcn_v2=False,
                 std_senet=False):
        super(BottleNeck, self).__init__()

        # Decide which of the first two convolutions carries the stride.
        stride1, stride2 = (stride, 1) if variant == 'a' else (1, stride)

        # ResNeXt
        width = int(ch_out * (base_width / 64.)) * groups
        expanded = ch_out * self.expansion

        # Settings shared by every ConvNormLayer in this block.
        common = dict(
            norm_type=norm_type,
            norm_decay=norm_decay,
            freeze_norm=freeze_norm,
            lr=lr)

        self.shortcut = shortcut
        if not shortcut:
            pooled_projection = variant == 'd' and stride == 2
            if pooled_projection:
                # Downsample with the pool, then project at stride 1.
                self.short = nn.Sequential()
                self.short.add_sublayer(
                    'pool',
                    nn.AvgPool2D(
                        kernel_size=2, stride=2, padding=0, ceil_mode=True))
                self.short.add_sublayer(
                    'conv',
                    ConvNormLayer(
                        ch_in=ch_in,
                        ch_out=expanded,
                        filter_size=1,
                        stride=1,
                        **common))
            else:
                self.short = ConvNormLayer(
                    ch_in=ch_in,
                    ch_out=expanded,
                    filter_size=1,
                    stride=stride,
                    **common)

        self.branch2a = ConvNormLayer(
            ch_in=ch_in,
            ch_out=width,
            filter_size=1,
            stride=stride1,
            groups=1,
            act='relu',
            **common)

        self.branch2b = ConvNormLayer(
            ch_in=width,
            ch_out=width,
            filter_size=3,
            stride=stride2,
            groups=groups,
            act='relu',
            dcn_v2=dcn_v2,
            **common)

        self.branch2c = ConvNormLayer(
            ch_in=width,
            ch_out=expanded,
            filter_size=1,
            stride=1,
            groups=1,
            **common)

        self.std_senet = std_senet
        if self.std_senet:
            self.se = SELayer(expanded)

    def forward(self, inputs):
        """Return relu(residual_branch(inputs) + shortcut(inputs))."""
        residual = self.branch2a(inputs)
        residual = self.branch2b(residual)
        residual = self.branch2c(residual)
        if self.std_senet:
            residual = self.se(residual)

        identity = inputs if self.shortcut else self.short(inputs)

        summed = paddle.add(x=residual, y=identity)
        return F.relu(summed)
class Blocks(nn.Layer):
    """One stage: ``count`` residual blocks of type ``block`` run in sequence.

    The first block uses a projection shortcut and, unless ``stage_num`` is 2,
    stride 2. Every later block uses an identity shortcut, stride 1 and takes
    ``ch_out * block.expansion`` input channels.

    Args:
        block (type): block class, called with keyword arguments and read for
            its ``expansion`` attribute.
        ch_in (int): input channels of the first block.
        ch_out (int): base output channels handed to every block.
        count (int): number of blocks in the stage.
        name_adapter: object whose ``fix_layer_warp_name(stage_num, count, i)``
            supplies the sublayer name of block ``i``.
        stage_num (int): stage number, used for naming and the stride rule.
        variant, groups, base_width, lr, norm_type, norm_decay, freeze_norm,
        dcn_v2, std_senet: forwarded unchanged to every block.
    """

    def __init__(self,
                 block,
                 ch_in,
                 ch_out,
                 count,
                 name_adapter,
                 stage_num,
                 variant='b',
                 groups=1,
                 base_width=64,
                 lr=1.0,
                 norm_type='bn',
                 norm_decay=0.,
                 freeze_norm=True,
                 dcn_v2=False,
                 std_senet=False):
        super(Blocks, self).__init__()

        self.blocks = []
        for idx in range(count):
            is_first = idx == 0
            sublayer_name = name_adapter.fix_layer_warp_name(stage_num, count,
                                                             idx)
            new_block = block(
                ch_in=ch_in,
                ch_out=ch_out,
                stride=2 if is_first and stage_num != 2 else 1,
                shortcut=not is_first,
                variant=variant,
                groups=groups,
                base_width=base_width,
                lr=lr,
                norm_type=norm_type,
                norm_decay=norm_decay,
                freeze_norm=freeze_norm,
                dcn_v2=dcn_v2,
                std_senet=std_senet)
            # add_sublayer registers the block; its return value is kept in
            # a plain list so forward can iterate in creation order.
            self.blocks.append(self.add_sublayer(sublayer_name, new_block))
            if is_first:
                # Later blocks consume the first block's expanded output.
                ch_in = ch_out * block.expansion

    def forward(self, inputs):
        """Feed ``inputs`` through every block in order and return the result."""
        feat = inputs
        for stage_block in self.blocks:
            feat = stage_block(feat)
        return feat
@register
@serializable
class ResNet(nn.Layer):
    """ResNet / ResNeXt backbone returning the feature maps of chosen stages."""

    __shared__ = ['norm_type']

    # NOTE: the list defaults below are kept as-is for interface
    # compatibility; this class only reads them and never mutates them.
    def __init__(self,
                 depth=50,
                 ch_in=64,
                 variant='b',
                 lr_mult_list=[1.0, 1.0, 1.0, 1.0],
                 groups=1,
                 base_width=64,
                 norm_type='bn',
                 norm_decay=0,
                 freeze_norm=True,
                 freeze_at=0,
                 return_idx=[0, 1, 2, 3],
                 dcn_v2_stages=[-1],
                 num_stages=4,
                 std_senet=False):
        """
        Residual Network, see https://arxiv.org/abs/1512.03385

        Args:
            depth (int): ResNet depth, should be 18, 34, 50, 101, 152.
            ch_in (int): output channels of the stem (conv1) and input
                channels of the first residual stage, default 64
            variant (str): ResNet variant, supports 'a', 'b', 'c', 'd' currently
            lr_mult_list (list): learning rate ratio of different resnet stages(2,3,4,5),
                lower learning rate ratio is needed for pretrained models
                obtained by distillation (default [1.0, 1.0, 1.0, 1.0]).
                Must have exactly 4 entries.
            groups (int): group convolution cardinality
            base_width (int): base width of each group convolution
            norm_type (str): normalization type, 'bn' or 'sync_bn'
            norm_decay (float): weight decay for normalization layer weights
            freeze_norm (bool): freeze normalization layers
            freeze_at (int): when >= 0, the stem and the stages with index
                0..freeze_at (capped at num_stages) get stop_gradient set;
                a negative value freezes nothing
            return_idx (list|int): index of the stages whose feature maps are returned
            dcn_v2_stages (list|int): index of stages that use deformable
                conv v2; the default [-1] matches no stage
            num_stages (int): total number of stages, between 1 and 4
            std_senet (bool): whether to use SE layers in the blocks, default False
        """
        super(ResNet, self).__init__()
        self._model_type = 'ResNet' if groups == 1 else 'ResNeXt'
        assert num_stages >= 1 and num_stages <= 4
        self.depth = depth
        self.variant = variant
        self.groups = groups
        self.base_width = base_width
        self.norm_type = norm_type
        self.norm_decay = norm_decay
        self.freeze_norm = freeze_norm
        self.freeze_at = freeze_at

        # A bare integer is accepted as shorthand for a one-element list.
        if isinstance(return_idx, Integral):
            return_idx = [return_idx]
        assert max(return_idx) < num_stages, \
            'the maximum return index must smaller than num_stages, ' \
            'but received maximum return index is {} and num_stages ' \
            'is {}'.format(max(return_idx), num_stages)
        self.return_idx = return_idx
        self.num_stages = num_stages

        assert len(lr_mult_list) == 4, \
            "lr_mult_list length must be 4 but got {}".format(len(lr_mult_list))

        # Normalised and validated once (the original repeated this verbatim).
        if isinstance(dcn_v2_stages, Integral):
            dcn_v2_stages = [dcn_v2_stages]
        assert max(dcn_v2_stages) < num_stages
        self.dcn_v2_stages = dcn_v2_stages

        block_nums = ResNet_cfg[depth]
        na = NameAdapter(self)

        # Stem: variants 'c' and 'd' use three 3x3 convolutions, the others
        # a single 7x7 convolution. Entries are
        # [in_channels, out_channels, kernel, stride, sublayer_name].
        conv1_name = na.fix_c1_stage_name()
        if variant in ['c', 'd']:
            conv_def = [
                [3, ch_in // 2, 3, 2, "conv1_1"],
                [ch_in // 2, ch_in // 2, 3, 1, "conv1_2"],
                [ch_in // 2, ch_in, 3, 1, "conv1_3"],
            ]
        else:
            conv_def = [[3, ch_in, 7, 2, conv1_name]]

        self.conv1 = nn.Sequential()
        for (c_in, c_out, k, s, _name) in conv_def:
            self.conv1.add_sublayer(
                _name,
                ConvNormLayer(
                    ch_in=c_in,
                    ch_out=c_out,
                    filter_size=k,
                    stride=s,
                    groups=1,
                    act='relu',
                    norm_type=norm_type,
                    norm_decay=norm_decay,
                    freeze_norm=freeze_norm,
                    lr=1.0))

        self.ch_in = ch_in
        ch_out_list = [64, 128, 256, 512]
        block = BottleNeck if depth >= 50 else BasicBlock

        self._out_channels = [block.expansion * v for v in ch_out_list]
        self._out_strides = [4, 8, 16, 32]

        self.res_layers = []
        for i in range(num_stages):
            lr_mult = lr_mult_list[i]
            stage_num = i + 2  # stages are named res2, res3, ...
            res_name = "res{}".format(stage_num)
            res_layer = self.add_sublayer(
                res_name,
                Blocks(
                    block,
                    self.ch_in,
                    ch_out_list[i],
                    count=block_nums[i],
                    name_adapter=na,
                    stage_num=stage_num,
                    variant=variant,
                    groups=groups,
                    base_width=base_width,
                    lr=lr_mult,
                    norm_type=norm_type,
                    norm_decay=norm_decay,
                    freeze_norm=freeze_norm,
                    dcn_v2=(i in self.dcn_v2_stages),
                    std_senet=std_senet))
            self.res_layers.append(res_layer)
            # The next stage consumes this stage's output channels.
            self.ch_in = self._out_channels[i]

        if freeze_at >= 0:
            self._freeze_parameters(self.conv1)
            for i in range(min(freeze_at + 1, num_stages)):
                self._freeze_parameters(self.res_layers[i])

    def _freeze_parameters(self, m):
        """Set ``stop_gradient`` on every parameter of layer ``m``."""
        for p in m.parameters():
            p.stop_gradient = True

    @property
    def out_shape(self):
        """List of ShapeSpec (channels, stride), one per entry of ``return_idx``."""
        return [
            ShapeSpec(
                channels=self._out_channels[i], stride=self._out_strides[i])
            for i in self.return_idx
        ]

    def forward(self, inputs):
        """Run the backbone on ``inputs['image']``.

        Returns:
            list: outputs of the stages whose index is in ``return_idx``, in
            stage order.
        """
        x = inputs['image']
        conv1 = self.conv1(x)
        x = F.max_pool2d(conv1, kernel_size=3, stride=2, padding=1)
        outs = []
        for idx, stage in enumerate(self.res_layers):
            x = stage(x)
            if idx in self.return_idx:
                outs.append(x)
        return outs
@register
class Res5Head(nn.Layer):
    """Head that runs a three-block stage (built with stage_num=5) on its input.

    Args:
        depth (int): selects the block type and channel widths. For
            ``depth >= 50`` it uses BottleNeck with 1024 input channels and
            reports 2048 output channels; otherwise BasicBlock with 256 input
            channels and 512 output channels.
    """

    def __init__(self, depth=50):
        super(Res5Head, self).__init__()
        use_bottleneck = depth >= 50

        feat_out = 512
        feat_in = 1024 if use_bottleneck else 256

        na = NameAdapter(self)
        block = BottleNeck if use_bottleneck else BasicBlock
        self.res5 = Blocks(
            block, feat_in, feat_out, count=3, name_adapter=na, stage_num=5)

        # BottleNeck expands its base width by 4; BasicBlock does not.
        self.feat_out = feat_out * 4 if use_bottleneck else feat_out

    @property
    def out_shape(self):
        """Single-element list holding a ShapeSpec with ``feat_out`` channels and stride 16."""
        return [ShapeSpec(channels=self.feat_out, stride=16)]

    def forward(self, roi_feat, stage=0):
        """Apply the res5 stage to ``roi_feat``; ``stage`` is accepted but unused."""
        return self.res5(roi_feat)