-
Notifications
You must be signed in to change notification settings - Fork 70
/
model.py
252 lines (208 loc) · 8.93 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
import torch
import torch.nn as nn
import torch.nn.functional as F
def get_model_parameters(model):
    """Return the total number of scalar values across ``model.parameters()``.

    Each parameter tensor contributes the product of its dimensions (a 0-d
    parameter counts as 1); a model with no parameters yields 0.
    """
    return sum(param.numel() for param in model.parameters())
def _weights_init(m):
if isinstance(m, nn.Conv2d):
torch.nn.init.xavier_uniform_(m.weight)
if m.bias is not None:
torch.nn.init.zeros_(m.bias)
elif isinstance(m, nn.BatchNorm2d):
m.weight.data.fill_(1)
m.bias.data.zero_()
elif isinstance(m, nn.Linear):
n = m.weight.size(1)
m.weight.data.normal_(0, 0.01)
m.bias.data.zero_()
class h_sigmoid(nn.Module):
    """Hard sigmoid activation: ``relu6(x + 3) / 6``, mapping inputs into [0, 1].

    Args:
        inplace: forwarded to ``F.relu6``. It only ever mutates the temporary
            ``x + 3`` tensor, so the caller's input is never modified.
    """

    def __init__(self, inplace=True):
        super().__init__()
        self.inplace = inplace

    def forward(self, x):
        shifted = x + 3.0
        return F.relu6(shifted, inplace=self.inplace) / 6.0
class h_swish(nn.Module):
    """Hard swish activation: ``x * relu6(x + 3) / 6``.

    Args:
        inplace: forwarded to ``F.relu6``. It only mutates the temporary
            ``x + 3`` tensor, never the caller's input.
    """

    def __init__(self, inplace=True):
        super().__init__()
        self.inplace = inplace

    def forward(self, x):
        shifted = x + 3.0
        gate = F.relu6(shifted, inplace=self.inplace) / 6.0
        return x * gate
def _make_divisible(v, divisor=8, min_value=None):
if min_value is None:
min_value = divisor
new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
# Make sure that round down does not go down by more than 10%.
if new_v < 0.9 * v:
new_v += divisor
return new_v
class SqueezeBlock(nn.Module):
    """Squeeze-and-excite gate that rescales each channel of its input.

    The input is average-pooled over its full spatial extent. The result is
    passed through Linear(exp_size -> exp_size // divide) -> ReLU ->
    Linear(-> exp_size) -> h_sigmoid, and the resulting per-channel weights
    multiply the input.

    Args:
        exp_size: number of channels of the input tensor.
        divide: reduction factor for the hidden layer width.
    """

    def __init__(self, exp_size, divide=4):
        super().__init__()
        hidden = exp_size // divide
        self.dense = nn.Sequential(
            nn.Linear(exp_size, hidden),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, exp_size),
            h_sigmoid(),
        )

    def forward(self, x):
        # The 4-way unpack requires a rank-4 (batch, channels, height, width) input.
        n, c, h, w = x.size()
        pooled = F.avg_pool2d(x, kernel_size=[h, w]).view(n, -1)
        channel_weights = self.dense(pooled).view(n, c, 1, 1)
        return x * channel_weights
class MobileBlock(nn.Module):
    """Bottleneck block: 1x1 expand -> depthwise conv -> optional SE -> 1x1 project.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels produced by the projection conv.
        kernal_size: depthwise kernel size. Padding is ``(kernal_size - 1) // 2``,
            which keeps the spatial size at stride 1 for odd kernels.
        stride: stride of the depthwise conv.
        nonLinear: "RE" selects ``nn.ReLU``; any other value selects ``h_swish``.
        SE: if truthy, a ``SqueezeBlock`` is applied after the depthwise conv.
        exp_size: channel count of the expanded representation.

    The input is added to the output when ``stride == 1`` and
    ``in_channels == out_channels``.

    NOTE(review): here the depthwise stage has no activation, while the
    projection stage ends with one. The MobileNetV3 paper does the opposite
    (activation after depthwise, linear projection). Changing this alters the
    outputs of any model trained with this code, so confirm before "fixing".
    """

    def __init__(self, in_channels, out_channels, kernal_size, stride, nonLinear, SE, exp_size):
        super().__init__()
        self.out_channels = out_channels
        self.nonLinear = nonLinear
        self.SE = SE
        self.use_connect = stride == 1 and in_channels == out_channels

        act = nn.ReLU if self.nonLinear == "RE" else h_swish
        pad = (kernal_size - 1) // 2

        # 1x1 expansion (conv has no bias since BN follows) + activation.
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, exp_size, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(exp_size),
            act(inplace=True),
        )
        # Depthwise conv (groups == channels) + BN; no activation here.
        self.depth_conv = nn.Sequential(
            nn.Conv2d(
                exp_size, exp_size,
                kernel_size=kernal_size, stride=stride, padding=pad, groups=exp_size,
            ),
            nn.BatchNorm2d(exp_size),
        )
        if self.SE:
            self.squeeze_block = SqueezeBlock(exp_size)
        # 1x1 projection + BN + activation.
        self.point_conv = nn.Sequential(
            nn.Conv2d(exp_size, out_channels, kernel_size=1, stride=1, padding=0),
            nn.BatchNorm2d(out_channels),
            act(inplace=True),
        )

    def forward(self, x):
        y = self.depth_conv(self.conv(x))
        if self.SE:
            y = self.squeeze_block(y)
        y = self.point_conv(y)
        return x + y if self.use_connect else y
class MobileNetV3(nn.Module):
    """MobileNetV3 image classifier in a "LARGE" or "SMALL" configuration.

    Layout: 3x3 stride-2 stem (``init_conv``) -> stack of ``MobileBlock``s
    built from a per-mode table (``block``) -> 1x1 conv head (``out_conv1``)
    -> global average pool -> 1x1 conv classifier (``out_conv2``).

    Args:
        model_mode: "LARGE" or "SMALL"; selects the block table and head widths.
        num_classes: number of logits produced per image.
        multiplier: width multiplier; every channel count is scaled by it and
            rounded with ``_make_divisible``.
        dropout_rate: dropout probability applied before the final classifier conv.

    Raises:
        ValueError: if ``model_mode`` is neither "LARGE" nor "SMALL".
    """

    # Block tables. Each row is
    # (in_channels, out_channels, kernel_size, stride, nonlinearity, use_se, exp_size),
    # where nonlinearity "RE" selects ReLU and any other value selects h_swish.
    # NOTE(review): some rows appear to differ from the published MobileNetV3
    # tables. In LARGE, the 112->160 row has stride 1 and the next row has stride 2
    # with exp 672, where the paper appears to list stride 2, then stride 1 with
    # exp 960. In SMALL, the 40-channel rows use "RE" where the paper appears to
    # use hard-swish. Changing either alters outputs, and the LARGE change also
    # alters parameter shapes, which breaks existing checkpoints. Confirm before
    # editing.
    _LARGE_LAYERS = (
        (16, 16, 3, 1, "RE", False, 16),
        (16, 24, 3, 2, "RE", False, 64),
        (24, 24, 3, 1, "RE", False, 72),
        (24, 40, 5, 2, "RE", True, 72),
        (40, 40, 5, 1, "RE", True, 120),
        (40, 40, 5, 1, "RE", True, 120),
        (40, 80, 3, 2, "HS", False, 240),
        (80, 80, 3, 1, "HS", False, 200),
        (80, 80, 3, 1, "HS", False, 184),
        (80, 80, 3, 1, "HS", False, 184),
        (80, 112, 3, 1, "HS", True, 480),
        (112, 112, 3, 1, "HS", True, 672),
        (112, 160, 5, 1, "HS", True, 672),
        (160, 160, 5, 2, "HS", True, 672),
        (160, 160, 5, 1, "HS", True, 960),
    )
    _SMALL_LAYERS = (
        (16, 16, 3, 2, "RE", True, 16),
        (16, 24, 3, 2, "RE", False, 72),
        (24, 24, 3, 1, "RE", False, 88),
        (24, 40, 5, 2, "RE", True, 96),
        (40, 40, 5, 1, "RE", True, 240),
        (40, 40, 5, 1, "RE", True, 240),
        (40, 48, 5, 1, "HS", True, 120),
        (48, 48, 5, 1, "HS", True, 144),
        (48, 96, 5, 2, "HS", True, 288),
        (96, 96, 5, 1, "HS", True, 576),
        (96, 96, 5, 1, "HS", True, 576),
    )

    def __init__(self, model_mode="LARGE", num_classes=1000, multiplier=1.0, dropout_rate=0.0):
        super(MobileNetV3, self).__init__()
        # Validate before building anything: an unknown mode would otherwise
        # produce a model with no layers that only fails inside forward().
        if model_mode == "LARGE":
            layers, last_channels, head_channels, head_se = self._LARGE_LAYERS, 160, 960, False
        elif model_mode == "SMALL":
            layers, last_channels, head_channels, head_se = self._SMALL_LAYERS, 96, 576, True
        else:
            raise ValueError(
                'model_mode must be "LARGE" or "SMALL", got {!r}'.format(model_mode)
            )
        self.num_classes = num_classes

        def scale(channels):
            # Apply the width multiplier and round to a multiple of 8.
            return _make_divisible(channels * multiplier)

        # Modules are created in the same order as before (stem, blocks, head,
        # classifier), so parameter names and RNG consumption are unchanged.
        init_conv_out = scale(16)
        self.init_conv = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=init_conv_out, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(init_conv_out),
            h_swish(inplace=True),
        )

        self.block = nn.Sequential(*[
            MobileBlock(scale(c_in), scale(c_out), k, s, nl, se, scale(exp))
            for c_in, c_out, k, s, nl, se, exp in layers
        ])

        head_in = scale(last_channels)
        head_out = scale(head_channels)
        head = [nn.Conv2d(head_in, head_out, kernel_size=1, stride=1)]
        if head_se:
            # SMALL inserts squeeze-and-excite between the head conv and its BN.
            head.append(SqueezeBlock(head_out))
        head += [nn.BatchNorm2d(head_out), h_swish(inplace=True)]
        self.out_conv1 = nn.Sequential(*head)

        classifier_hidden = scale(1280)
        self.out_conv2 = nn.Sequential(
            nn.Conv2d(head_out, classifier_hidden, kernel_size=1, stride=1),
            h_swish(inplace=True),
            nn.Dropout(dropout_rate),
            nn.Conv2d(classifier_hidden, self.num_classes, kernel_size=1, stride=1),
        )

        self.apply(_weights_init)

    def forward(self, x):
        """Return logits of shape (batch, num_classes).

        ``x`` is a batch of 3-channel images (the stem conv has ``in_channels=3``).
        """
        out = self.init_conv(x)
        out = self.block(out)
        out = self.out_conv1(out)
        batch, channels, height, width = out.size()
        # Pool over the full spatial extent -> (batch, C, 1, 1).
        out = F.avg_pool2d(out, kernel_size=[height, width])
        out = self.out_conv2(out).view(batch, -1)
        return out
# temp = torch.zeros((1, 3, 224, 224))
# model = MobileNetV3(model_mode="LARGE", num_classes=1000, multiplier=1.0)
# print(model(temp).shape)
# print(get_model_parameters(model))