-
Notifications
You must be signed in to change notification settings - Fork 53
/
model_torch.py
304 lines (240 loc) · 9.41 KB
/
model_torch.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
from functools import partial
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.modules.batchnorm import _BatchNorm
from bn_lib.nn.modules import SynchronizedBatchNorm2d
import settings
# Normalization-layer factory used by every block in this file: builds a
# SynchronizedBatchNorm2d (from bn_lib) with momentum fixed to
# settings.BN_MOM. Called as norm_layer(num_features). Note that
# settings.BN_MOM is read once, when this module is imported.
norm_layer = partial(SynchronizedBatchNorm2d, momentum=settings.BN_MOM)
class Bottleneck(nn.Module):
    """ResNet bottleneck residual block: 1x1 -> 3x3 -> 1x1 convolutions.

    The block outputs ``planes * Bottleneck.expansion`` channels. The 3x3
    convolution carries both the stride and the dilation; its padding equals
    the dilation, so with stride 1 the spatial size is preserved.

    Args:
        inplanes: Number of input channels.
        planes: Width of the 1x1/3x3 convolutions inside the block.
        stride: Stride of the 3x3 convolution.
        dilation: Dilation (and padding) of the 3x3 convolution.
        downsample: Optional module applied to the block input to produce the
            residual when the main branch changes shape.
        previous_dilation: Accepted for interface compatibility with the
            callers that pass it; not used by this block.
    """
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, dilation=1,
                 downsample=None, previous_dilation=1):
        super().__init__()
        # Derive the output width from ``expansion`` (the same value
        # ResNet._make_layer uses for the shortcut) instead of repeating 4.
        out_planes = planes * self.expansion
        self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
        self.bn1 = norm_layer(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=dilation, dilation=dilation,
                               bias=False)
        self.bn2 = norm_layer(planes)
        self.conv3 = nn.Conv2d(planes, out_planes, 1, bias=False)
        self.bn3 = norm_layer(out_planes)
        self.relu = nn.ReLU()
        self.downsample = downsample
        self.dilation = dilation
        self.stride = stride

    def forward(self, x):
        """Apply the bottleneck branch, add the (optionally projected) input, then ReLU."""
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)
        return out
class ResNet(nn.Module):
    """Dilated ResNet with a three-convolution stem.

    The stem (stride 2) and max-pool (stride 2) plus ``layer2`` (stride 2)
    reduce the input by 8. ``stride`` then selects the output stride of
    ``layer4``: with 16, ``layer3`` is strided and ``layer4`` dilated; with
    8, both ``layer3`` and ``layer4`` are dilated instead of strided.

    Args:
        block: Residual block class (e.g. ``Bottleneck``); must expose an
            ``expansion`` attribute.
        layers: Number of blocks in each of the four stages.
        num_classes: Output size of the ``fc`` classification head.
        stride: Output stride of ``layer4``; 8 or 16.

    Raises:
        ValueError: If ``stride`` is neither 8 nor 16.
    """

    def __init__(self, block, layers, num_classes=1000, stride=8):
        super().__init__()
        # Validate up front: any other value used to skip building layer3 and
        # layer4 silently, so the failure only surfaced later as an
        # AttributeError.
        if stride not in (8, 16):
            raise ValueError(
                'unsupported output stride {}; expected 8 or 16'.format(stride))
        self.inplanes = 128
        # Stem: three 3x3 convolutions (the first one strided) ending at 128
        # channels; its final normalization/ReLU are bn1/relu below.
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1, bias=False),
            norm_layer(64),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=False),
            norm_layer(64),
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1, bias=False))
        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        if stride == 16:
            self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
            self.layer4 = self._make_layer(
                block, 512, layers[3], stride=1, dilation=2, grids=[1, 2, 4])
        else:  # stride == 8: dilate layer3 instead of striding it.
            self.layer3 = self._make_layer(
                block, 256, layers[2], stride=1, dilation=2)
            self.layer4 = self._make_layer(
                block, 512, layers[3], stride=1, dilation=4, grids=[1, 2, 4])
        # Global average pooling. The previous AvgPool2d(7, stride=1) only
        # reduced to 1x1 when layer4 was exactly 7x7, which the dilated
        # configurations almost never produce, so ``fc`` received the wrong
        # number of features. On a 7x7 map both poolings are identical.
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        # Convolutions: normal(0, sqrt(2 / fan_out)), where fan_out is
        # kernel_h * kernel_w * out_channels. Batch norms: weight 1, bias 0
        # (assumes norm_layer's class derives from _BatchNorm - TODO confirm).
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, _BatchNorm):
                m.weight.data.fill_(1)
                if m.bias is not None:
                    m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1, dilation=1,
                    grids=None):
        """Build one stage of ``blocks`` residual blocks.

        Only the first block can change resolution or width (``stride`` and
        a 1x1 projection shortcut). That first block uses dilation 1 when
        ``dilation`` is 1 or 2, and dilation 2 when it is 4. Block ``i >= 1``
        uses ``dilation * grids[i]`` (``grids`` defaults to all ones).
        Updates ``self.inplanes`` to the stage's output width.

        Raises:
            RuntimeError: If ``dilation`` is not 1, 2 or 4.
        """
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            # Projection shortcut so the residual matches the block output.
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                norm_layer(planes * block.expansion))

        if grids is None:
            grids = [1] * blocks
        if dilation in (1, 2):
            first_dilation = 1
        elif dilation == 4:
            first_dilation = 2
        else:
            raise RuntimeError('=> unknown dilation size: {}'.format(dilation))

        layers = [block(self.inplanes, planes, stride, dilation=first_dilation,
                        downsample=downsample,
                        previous_dilation=dilation)]
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes,
                                dilation=dilation * grids[i],
                                previous_dilation=dilation))
        return nn.Sequential(*layers)

    def forward(self, x):
        """Run the classifier: stem, four stages, global pooling, ``fc``."""
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x
def resnet(n_layers, stride):
    """Build a dilated Bottleneck ResNet and load its pretrained weights.

    Args:
        n_layers: Network depth; one of 50, 101 or 152.
        stride: Output stride forwarded to ``ResNet``.

    Returns:
        The ``ResNet`` with parameters loaded from the matching checkpoint
        under ``./models``. Loading uses ``strict=False``, so keys missing
        from, or unexpected in, the checkpoint are skipped without error.

    Raises:
        KeyError: If ``n_layers`` is not 50, 101 or 152.
    """
    # depth -> (blocks per stage, pretrained checkpoint path)
    layers, pretrained_path = {
        50: ([3, 4, 6, 3], './models/resnet50-ebb6acbb.pth'),
        101: ([3, 4, 23, 3], './models/resnet101-2a57e44d.pth'),
        152: ([3, 8, 36, 3], './models/resnet152-0d43d698.pth'),
    }[n_layers]
    net = ResNet(Bottleneck, layers=layers, stride=stride)
    # Deserialize onto the CPU: a checkpoint saved from GPU tensors would
    # otherwise fail to load on a CPU-only host. load_state_dict copies the
    # values into the network's own tensors, so their device is unaffected.
    state_dict = torch.load(pretrained_path, map_location='cpu')
    net.load_state_dict(state_dict, strict=False)
    return net
class ConvBNReLU(nn.Module):
    """Bias-free Conv2d, then ``norm_layer``, then ReLU.

    Args:
        c_in: Number of input channels.
        c_out: Number of output channels.
        kernel_size, stride, padding, dilation: Passed to ``nn.Conv2d``.
    """

    def __init__(self, c_in, c_out, kernel_size, stride, padding, dilation):
        super().__init__()
        # No conv bias: the normalization that follows supplies the shift.
        self.conv = nn.Conv2d(c_in, c_out, kernel_size,
                              stride=stride, padding=padding,
                              dilation=dilation, bias=False)
        self.bn = norm_layer(c_out)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.relu(self.bn(self.conv(x)))
class External_attention(nn.Module):
    '''
    Attention block with a residual connection.

    Pipeline: 1x1 Conv2d; flatten the spatial positions; Conv1d from c to
    k=64 units; softmax over positions, then l1-normalize across the k units;
    Conv1d back to c channels; 1x1 Conv2d + norm_layer; add the block input;
    ReLU. Output shape equals input shape.

    Arguments:
        c (int): The input and output channel number.
    '''
    def __init__(self, c: int):
        super(External_attention, self).__init__()
        self.conv1 = nn.Conv2d(c, c, 1)

        # Number of intermediate attention units.
        self.k = 64
        # c -> k projection applied independently at every spatial position.
        self.linear_0 = nn.Conv1d(c, self.k, 1, bias=False)
        # k -> c projection back to the input width.
        self.linear_1 = nn.Conv1d(self.k, c, 1, bias=False)
        # NOTE(review): assigning ``.data`` to a ``permute`` view makes
        # linear_1.weight share storage with linear_0.weight (transposed).
        # The init loop below then writes both layers' random draws into that
        # same memory (linear_1 is visited second, so its draw appears to
        # win), and the two weights appear to stay tied afterwards. Confirm
        # whether independent weights were intended before changing this.
        self.linear_1.weight.data = self.linear_0.weight.data.permute(1, 0, 2)

        self.conv2 = nn.Sequential(
            nn.Conv2d(c, c, 1, bias=False),
            norm_layer(c))

        # Init: convolutions ~ normal(0, sqrt(2 / n)) with n = kernel size *
        # out_channels; batch-norm weight 1 and bias 0. Iteration follows
        # module registration order, which matters given the aliasing above.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.Conv1d):
                n = m.kernel_size[0] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, _BatchNorm):
                m.weight.data.fill_(1)
                if m.bias is not None:
                    m.bias.data.zero_()

    def forward(self, x):
        # Kept for the residual connection at the end.
        idn = x
        x = self.conv1(x)

        b, c, h, w = x.size()
        n = h*w  # number of spatial positions (not used below)
        x = x.view(b, c, h*w)   # b * c * n

        attn = self.linear_0(x) # b, k, n
        # Double normalization: softmax over the n positions, then divide by
        # the sum over the k units (epsilon guards against division by 0).
        attn = F.softmax(attn, dim=-1) # b, k, n
        attn = attn / (1e-9 + attn.sum(dim=1, keepdim=True)) # b, k, n
        x = self.linear_1(attn) # b, c, n

        x = x.view(b, c, h, w)
        x = self.conv2(x)
        x = x + idn
        x = F.relu(x)
        return x
class EANet(nn.Module):
    """Segmentation network: dilated ResNet backbone + external-attention head.

    Args:
        n_classes: Number of output channels of the final 1x1 classifier.
        n_layers: ResNet depth passed to ``resnet`` (output stride is
            ``settings.STRIDE``).
    """

    def __init__(self, n_classes, n_layers):
        super().__init__()
        backbone = resnet(n_layers, settings.STRIDE)
        # Reuse the backbone's stem and four stages; its avgpool/fc are dropped.
        stages = [
            backbone.conv1, backbone.bn1, backbone.relu, backbone.maxpool,
            backbone.layer1, backbone.layer2, backbone.layer3, backbone.layer4,
        ]
        self.extractor = nn.Sequential(*stages)
        # 2048 = layer4 width (512 * Bottleneck expansion).
        self.fc0 = ConvBNReLU(2048, 512, 3, 1, 1, 1)
        self.linu = External_attention(512)
        self.fc1 = nn.Sequential(
            ConvBNReLU(512, 256, 3, 1, 1, 1),
            nn.Dropout2d(p=0.1))
        self.fc2 = nn.Conv2d(256, n_classes, 1)
        self.crit = CrossEntropyLoss2d(ignore_index=settings.IGNORE_LABEL,
                                       reduction='none')

    def forward(self, img, lbl=None, size=None):
        """Predict per-pixel logits, or the training loss.

        Logits are bilinearly resized to ``size`` (default: the spatial size
        of ``img``). In training mode with ``lbl`` given, returns the
        per-sample loss from ``self.crit`` instead of the logits.
        """
        out = self.extractor(img)
        for head in (self.fc0, self.linu, self.fc1, self.fc2):
            out = head(out)

        target_size = img.size()[-2:] if size is None else size
        pred = F.interpolate(out, size=target_size, mode='bilinear',
                             align_corners=True)

        if not (self.training and lbl is not None):
            return pred
        return self.crit(pred, lbl)
class CrossEntropyLoss2d(nn.Module):
    """Pixel-wise cross-entropy: ``NLLLoss(log_softmax(inputs, dim=1), targets)``.

    With ``reduction='none'`` (the default), the per-pixel losses, e.g. of
    shape (N, H, W) for (N, C, H, W) inputs, are averaged over the two
    trailing dimensions, giving one loss per sample. Pixels equal to
    ``ignore_index`` contribute 0 but are still counted in that average.
    With ``'mean'`` or ``'sum'``, the scalar computed by ``nn.NLLLoss`` is
    returned as-is.

    Args:
        weight: Optional per-class rescaling weights for ``nn.NLLLoss``.
        reduction: ``'none'``, ``'mean'`` or ``'sum'``.
        ignore_index: Target value excluded from the loss.
    """

    def __init__(self, weight=None, reduction='none', ignore_index=-1):
        super().__init__()
        self.nll_loss = nn.NLLLoss(weight, reduction=reduction,
                                   ignore_index=ignore_index)

    def forward(self, inputs, targets):
        loss = self.nll_loss(F.log_softmax(inputs, dim=1), targets)
        if self.nll_loss.reduction != 'none':
            # Already reduced to a scalar; the spatial averaging below would
            # fail on a 0-d tensor.
            return loss
        return loss.mean(dim=2).mean(dim=1)
def test_net():
    """Smoke test: build a 21-class EANet-50 and run one inference.

    Prints the model's top-level children and the size of the prediction for
    a random 1x3x513x513 image. Building the model loads the pretrained
    ResNet-50 checkpoint that ``resnet`` reads from ``./models``.
    """
    model = EANet(n_classes=21, n_layers=50)
    model.eval()
    print(list(model.named_children()))

    image = torch.randn(1, 3, 513, 513)
    label = torch.zeros(1, 513, 513, dtype=torch.long)
    # eval() does not disable autograd; skip building the graph for this
    # inference-only pass to save memory.
    with torch.no_grad():
        pred = model(image, label)
    print(pred.size())


if __name__ == '__main__':
    test_net()