-
Notifications
You must be signed in to change notification settings - Fork 22
/
Copy pathmodels.py
329 lines (292 loc) · 12.6 KB
/
models.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from config import config
D = config.connector_dim
Nh = config.num_heads
Dword = config.glove_dim
Dchar = config.char_dim
batch_size = config.batch_size
dropout = config.dropout
dropout_char = config.dropout_char
Lc = config.para_limit
Lq = config.ques_limit
def mask_logits(inputs, mask):
mask = mask.type(torch.float32)
return inputs + (-1e30) * (1 - mask)
class Initialized_Conv1d(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size=1, relu=False, stride=1, padding=0, groups=1, bias=False):
super().__init__()
self.out = nn.Conv1d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, groups=groups, bias=bias)
if relu is True:
self.relu = True
nn.init.kaiming_normal_(self.out.weight, nonlinearity='relu')
else:
self.relu = False
nn.init.xavier_uniform_(self.out.weight)
def forward(self, x):
if self.relu == True:
return F.relu(self.out(x))
else:
return self.out(x)
def PosEncoder(x, min_timescale=1.0, max_timescale=1.0e4):
x = x.transpose(1,2)
length = x.size()[1]
channels = x.size()[2]
signal = get_timing_signal(length, channels, min_timescale, max_timescale)
return (x + signal.cuda()).transpose(1,2)
def get_timing_signal(length, channels, min_timescale=1.0, max_timescale=1.0e4):
position = torch.arange(length).type(torch.float32)
num_timescales = channels // 2
log_timescale_increment = (math.log(float(max_timescale) / float(min_timescale)) / (float(num_timescales)-1))
inv_timescales = min_timescale * torch.exp(
torch.arange(num_timescales).type(torch.float32) * -log_timescale_increment)
scaled_time = position.unsqueeze(1) * inv_timescales.unsqueeze(0)
signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim = 1)
m = nn.ZeroPad2d((0, (channels % 2), 0, 0))
signal = m(signal)
signal = signal.view(1, length, channels)
return signal
class DepthwiseSeparableConv(nn.Module):
def __init__(self, in_ch, out_ch, k, bias=True):
super().__init__()
self.depthwise_conv = nn.Conv1d(in_channels=in_ch, out_channels=in_ch, kernel_size=k, groups=in_ch, padding=k // 2, bias=False)
self.pointwise_conv = nn.Conv1d(in_channels=in_ch, out_channels=out_ch, kernel_size=1, padding=0, bias=bias)
def forward(self, x):
return F.relu(self.pointwise_conv(self.depthwise_conv(x)))
class Highway(nn.Module):
def __init__(self, layer_num: int, size=D):
super().__init__()
self.n = layer_num
self.linear = nn.ModuleList([Initialized_Conv1d(size, size, relu=False, bias=True) for _ in range(self.n)])
self.gate = nn.ModuleList([Initialized_Conv1d(size, size, bias=True) for _ in range(self.n)])
def forward(self, x):
#x: shape [batch_size, hidden_size, length]
for i in range(self.n):
gate = F.sigmoid(self.gate[i](x))
nonlinear = self.linear[i](x)
nonlinear = F.dropout(nonlinear, p=dropout, training=self.training)
x = gate * nonlinear + (1 - gate) * x
#x = F.relu(x)
return x
class SelfAttention(nn.Module):
def __init__(self):
super().__init__()
self.mem_conv = Initialized_Conv1d(D, D*2, kernel_size=1, relu=False, bias=False)
self.query_conv = Initialized_Conv1d(D, D, kernel_size=1, relu=False, bias=False)
bias = torch.empty(1)
nn.init.constant_(bias, 0)
self.bias = nn.Parameter(bias)
def forward(self, queries, mask):
memory = queries
memory = self.mem_conv(memory)
query = self.query_conv(queries)
memory = memory.transpose(1, 2)
query = query.transpose(1, 2)
Q = self.split_last_dim(query, Nh)
K, V = [self.split_last_dim(tensor, Nh) for tensor in torch.split(memory, D, dim=2)]
key_depth_per_head = D // Nh
Q *= key_depth_per_head**-0.5
x = self.dot_product_attention(Q, K, V, mask = mask)
return self.combine_last_two_dim(x.permute(0,2,1,3)).transpose(1, 2)
def dot_product_attention(self, q, k ,v, bias = False, mask = None):
"""dot-product attention.
Args:
q: a Tensor with shape [batch, heads, length_q, depth_k]
k: a Tensor with shape [batch, heads, length_kv, depth_k]
v: a Tensor with shape [batch, heads, length_kv, depth_v]
bias: bias Tensor (see attention_bias())
is_training: a bool of training
scope: an optional string
Returns:
A Tensor.
"""
logits = torch.matmul(q,k.permute(0,1,3,2))
if bias:
logits += self.bias
if mask is not None:
shapes = [x if x != None else -1 for x in list(logits.size())]
mask = mask.view(shapes[0], 1, 1, shapes[-1])
logits = mask_logits(logits, mask)
weights = F.softmax(logits, dim=-1)
# dropping out the attention links for each of the heads
weights = F.dropout(weights, p=dropout, training=self.training)
return torch.matmul(weights, v)
def split_last_dim(self, x, n):
"""Reshape x so that the last dimension becomes two dimensions.
The first of these two dimensions is n.
Args:
x: a Tensor with shape [..., m]
n: an integer.
Returns:
a Tensor with shape [..., n, m/n]
"""
old_shape = list(x.size())
last = old_shape[-1]
new_shape = old_shape[:-1] + [n] + [last // n if last else None]
ret = x.view(new_shape)
return ret.permute(0, 2, 1, 3)
def combine_last_two_dim(self, x):
"""Reshape x so that the last two dimension become one.
Args:
x: a Tensor with shape [..., a, b]
Returns:
a Tensor with shape [..., ab]
"""
old_shape = list(x.size())
a, b = old_shape[-2:]
new_shape = old_shape[:-2] + [a * b if a and b else None]
ret = x.contiguous().view(new_shape)
return ret
class Embedding(nn.Module):
def __init__(self):
super().__init__()
self.conv2d = nn.Conv2d(Dchar, D, kernel_size = (1,5), padding=0, bias=True)
nn.init.kaiming_normal_(self.conv2d.weight, nonlinearity='relu')
self.conv1d = Initialized_Conv1d(Dword+D, D, bias=False)
self.high = Highway(2)
def forward(self, ch_emb, wd_emb, length):
N = ch_emb.size()[0]
ch_emb = ch_emb.permute(0, 3, 1, 2)
ch_emb = F.dropout(ch_emb, p=dropout_char, training=self.training)
ch_emb = self.conv2d(ch_emb)
ch_emb = F.relu(ch_emb)
ch_emb, _ = torch.max(ch_emb, dim=3)
ch_emb = ch_emb.squeeze()
wd_emb = F.dropout(wd_emb, p=dropout, training=self.training)
wd_emb = wd_emb.transpose(1, 2)
emb = torch.cat([ch_emb, wd_emb], dim=1)
emb = self.conv1d(emb)
emb = self.high(emb)
return emb
class EncoderBlock(nn.Module):
def __init__(self, conv_num: int, ch_num: int, k: int):
super().__init__()
self.convs = nn.ModuleList([DepthwiseSeparableConv(ch_num, ch_num, k) for _ in range(conv_num)])
self.self_att = SelfAttention()
self.FFN_1 = Initialized_Conv1d(ch_num, ch_num, relu=True, bias=True)
self.FFN_2 = Initialized_Conv1d(ch_num, ch_num, bias=True)
self.norm_C = nn.ModuleList([nn.LayerNorm(D) for _ in range(conv_num)])
self.norm_1 = nn.LayerNorm(D)
self.norm_2 = nn.LayerNorm(D)
self.conv_num = conv_num
def forward(self, x, mask, l, blks):
total_layers = (self.conv_num+1)*blks
out = PosEncoder(x)
for i, conv in enumerate(self.convs):
res = out
out = self.norm_C[i](out.transpose(1,2)).transpose(1,2)
if (i) % 2 == 0:
out = F.dropout(out, p=dropout, training=self.training)
out = conv(out)
out = self.layer_dropout(out, res, dropout*float(l)/total_layers)
l += 1
res = out
out = self.norm_1(out.transpose(1,2)).transpose(1,2)
out = F.dropout(out, p=dropout, training=self.training)
out = self.self_att(out, mask)
out = self.layer_dropout(out, res, dropout*float(l)/total_layers)
l += 1
res = out
out = self.norm_2(out.transpose(1,2)).transpose(1,2)
out = F.dropout(out, p=dropout, training=self.training)
out = self.FFN_1(out)
out = self.FFN_2(out)
out = self.layer_dropout(out, res, dropout*float(l)/total_layers)
return out
def layer_dropout(self, inputs, residual, dropout):
if self.training == True:
pred = torch.empty(1).uniform_(0,1) < dropout
if pred:
return residual
else:
return F.dropout(inputs, dropout, training=self.training) + residual
else:
return inputs + residual
class CQAttention(nn.Module):
def __init__(self):
super().__init__()
w4C = torch.empty(D, 1)
w4Q = torch.empty(D, 1)
w4mlu = torch.empty(1, 1, D)
nn.init.xavier_uniform_(w4C)
nn.init.xavier_uniform_(w4Q)
nn.init.xavier_uniform_(w4mlu)
self.w4C = nn.Parameter(w4C)
self.w4Q = nn.Parameter(w4Q)
self.w4mlu = nn.Parameter(w4mlu)
bias = torch.empty(1)
nn.init.constant_(bias, 0)
self.bias = nn.Parameter(bias)
def forward(self, C, Q, Cmask, Qmask):
C = C.transpose(1, 2)
Q = Q.transpose(1, 2)
batch_size_c = C.size()[0]
S = self.trilinear_for_attention(C, Q)
Cmask = Cmask.view(batch_size_c, Lc, 1)
Qmask = Qmask.view(batch_size_c, 1, Lq)
S1 = F.softmax(mask_logits(S, Qmask), dim=2)
S2 = F.softmax(mask_logits(S, Cmask), dim=1)
A = torch.bmm(S1, Q)
B = torch.bmm(torch.bmm(S1, S2.transpose(1, 2)), C)
out = torch.cat([C, A, torch.mul(C, A), torch.mul(C, B)], dim=2)
return out.transpose(1, 2)
def trilinear_for_attention(self, C, Q):
C = F.dropout(C, p=dropout, training=self.training)
Q = F.dropout(Q, p=dropout, training=self.training)
subres0 = torch.matmul(C, self.w4C).expand([-1, -1, Lq])
subres1 = torch.matmul(Q, self.w4Q).transpose(1, 2).expand([-1, Lc, -1])
subres2 = torch.matmul(C * self.w4mlu, Q.transpose(1,2))
res = subres0 + subres1 + subres2
res += self.bias
return res
class Pointer(nn.Module):
def __init__(self):
super().__init__()
self.w1 = Initialized_Conv1d(D*2, 1)
self.w2 = Initialized_Conv1d(D*2, 1)
def forward(self, M1, M2, M3, mask):
X1 = torch.cat([M1, M2], dim=1)
X2 = torch.cat([M1, M3], dim=1)
Y1 = mask_logits(self.w1(X1).squeeze(), mask)
Y2 = mask_logits(self.w2(X2).squeeze(), mask)
return Y1, Y2
class QANet(nn.Module):
def __init__(self, word_mat, char_mat):
super().__init__()
if config.pretrained_char:
self.char_emb = nn.Embedding.from_pretrained(torch.Tensor(char_mat))
else:
char_mat = torch.Tensor(char_mat)
self.char_emb = nn.Embedding.from_pretrained(char_mat, freeze=False)
self.word_emb = nn.Embedding.from_pretrained(torch.Tensor(word_mat), freeze=True)
self.emb = Embedding()
self.emb_enc = EncoderBlock(conv_num=4, ch_num=D, k=7)
self.cq_att = CQAttention()
self.cq_resizer = Initialized_Conv1d(D * 4, D)
self.model_enc_blks = nn.ModuleList([EncoderBlock(conv_num=2, ch_num=D, k=5) for _ in range(7)])
self.out = Pointer()
def forward(self, Cwid, Ccid, Qwid, Qcid):
maskC = (torch.zeros_like(Cwid) != Cwid).float()
maskQ = (torch.zeros_like(Qwid) != Qwid).float()
Cw, Cc = self.word_emb(Cwid), self.char_emb(Ccid)
Qw, Qc = self.word_emb(Qwid), self.char_emb(Qcid)
C, Q = self.emb(Cc, Cw, Lc), self.emb(Qc, Qw, Lq)
Ce = self.emb_enc(C, maskC, 1, 1)
Qe = self.emb_enc(Q, maskQ, 1, 1)
X = self.cq_att(Ce, Qe, maskC, maskQ)
M0 = self.cq_resizer(X)
M0 = F.dropout(M0, p=dropout, training=self.training)
for i, blk in enumerate(self.model_enc_blks):
M0 = blk(M0, maskC, i*(2+2)+1, 7)
M1 = M0
for i, blk in enumerate(self.model_enc_blks):
M0 = blk(M0, maskC, i*(2+2)+1, 7)
M2 = M0
M0 = F.dropout(M0, p=dropout, training=self.training)
for i, blk in enumerate(self.model_enc_blks):
M0 = blk(M0, maskC, i*(2+2)+1, 7)
M3 = M0
p1, p2 = self.out(M1, M2, M3, maskC)
return p1, p2