-
Notifications
You must be signed in to change notification settings - Fork 2
/
model.py
156 lines (121 loc) · 5.01 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
import torch
from torch import nn
from transformers import BertConfig
from transformers.models.bert.modeling_bert import BertLayer as _BertLayer
from utils import pos_encoding
class SafeEmbedding(nn.Embedding):
    """Embedding that maps out-of-vocabulary ids to an all-zero vector.

    Ids in ``[0, num_embeddings)`` are looked up normally. Any other id
    (too large, or negative) yields a zero row instead of raising.
    """

    def forward(self, input):
        """Embed an integer id tensor of any shape.

        Args:
            input: integer tensor of ids.

        Returns:
            Tensor of shape ``(*input.shape, embedding_dim)`` with the dtype
            of ``self.weight``. Rows for unseen ids are zero and receive no
            gradient.
        """
        # Start from zeros so unseen positions need no separate fill step
        # (and no host sync to count them).
        output = torch.zeros((*input.size(), self.embedding_dim),
                             device=input.device,
                             dtype=self.weight.dtype)
        # Negative ids are unseen too: nn.Embedding would raise on them.
        seen = (input >= 0) & (input < self.num_embeddings)
        output[seen] = super().forward(input[seen])
        return output
class FlattenBatchNorm1d(nn.BatchNorm1d):
    """BatchNorm1d over the last dimension of an input of any rank >= 2.

    All leading dimensions are folded into the batch dimension. For example,
    ``(N, C, L)`` is normalized as ``(N * C, L)`` with ``L`` features, and
    the result is restored to the input's shape.
    """

    def forward(self, input):
        shape = input.size()
        # reshape (not view) so non-contiguous inputs such as transposed
        # tensors are accepted instead of raising.
        flat = input.reshape(-1, shape[-1])
        return super().forward(flat).reshape(shape)
class BertLayer(_BertLayer):
    """transformers' BertLayer that returns just its output hidden states.

    The upstream ``forward`` returns a sequence of outputs. This wrapper
    keeps only the first element, so the layer can be used as a plain
    tensor-in / tensor-out module.
    """

    def forward(self, *args, **kwargs):
        layer_outputs = super().forward(*args, **kwargs)
        hidden_states = layer_outputs[0]
        return hidden_states
class MemBertLayer(BertLayer):
    """BertLayer that carries recurrent memory tokens across time steps.

    For each time step the memory tokens are prepended to that step's
    sequence, and the combined sequence goes through the BERT layer. The
    leading outputs (one per memory token) become the memory for the next
    step. With a single step this is equivalent to::

        h = super().forward(torch.cat([mem, hidden], dim=1))
        mem, hidden = h[:, :mem.size(1)], h[:, mem.size(1):]
    """

    def __init__(self, *args, n_mem=1, **kwargs):
        super().__init__(*args, **kwargs)
        # Number of memory tokens created when forward() gets no memory.
        self.n_mem = n_mem

    def forward(self, hidden, mem=None, **kwargs):
        """Run the layer sequentially over the time dimension.

        Args:
            hidden: 3-D tensor ``(T, L, D)``. T is the time dimension, L the
                number of items (investments) per time step, D the hidden
                size.
            mem: optional ``(1, M, D)`` memory, e.g. from a previous call.
                When None, ``pos_encoding(self.n_mem, D).unsqueeze(0)`` is
                used (presumably ``(1, n_mem, D)`` — confirm in utils).
            **kwargs: forwarded to ``BertLayer.forward`` at every step.
                NOTE(review): each step's sequence has length M + L, so an
                attention mask passed here must also cover the memory tokens.

        Returns:
            ``(out_hidden, mem)``. ``out_hidden`` is ``(T, L, D)``; ``mem`` is
            the memory after the last step (the initial memory when T == 0).
            ``mem`` is not detached. If the caller feeds it into the next
            call, the autograd history keeps growing unless the caller
            detaches it.
        """
        assert hidden.dim() == 3, hidden.size()
        if mem is None:
            mem = pos_encoding(self.n_mem, hidden.size(2),
                               device=hidden.device,
                               dtype=hidden.dtype).unsqueeze(0)
        n_mem = mem.size(1)
        step_outputs = []
        for step in hidden:
            # (1, M + L, D): memory tokens first, then this step's items.
            combined = torch.cat([mem, step.unsqueeze(0)], dim=1)
            combined = super().forward(combined, **kwargs)
            mem = combined[:, :n_mem]
            step_outputs.append(combined[:, n_mem:])
        if not step_outputs:
            # T == 0: return an empty tensor with the expected trailing dims.
            return hidden.new_empty((0, *hidden.shape[1:])), mem
        # One concatenation at the end instead of re-copying the growing
        # output at every step (which was O(T^2) in copies).
        return torch.cat(step_outputs, dim=0), mem
class BasicLayer(nn.Module):
    """Linear -> FlattenBatchNorm1d -> SiLU [-> Dropout] [-> attention layer].

    Args:
        args: config namespace. ``args.dropout`` is always read;
            ``args.n_mem`` is read when ``mha`` is true.
        in_sz: input feature size.
        out_sz: output feature size, also used as the attention hidden size.
        mha: whether to append a BERT attention sublayer. It is a
            ``MemBertLayer`` when ``args.n_mem`` is truthy, else a
            ``BertLayer``.
    """

    def __init__(self, args, in_sz, out_sz, mha=False):
        super().__init__()
        self.args = args
        layers = [
            nn.Linear(in_sz, out_sz),
            FlattenBatchNorm1d(out_sz),
            nn.SiLU(),
        ]
        if args.dropout > 0.0:
            layers.append(nn.Dropout(p=args.dropout))
        self.layers = nn.Sequential(*layers)
        self.mha = self._maybe_get_mha(args, out_sz, mha)

    def _maybe_get_mha(self, args, size, mha):
        """Build the optional attention sublayer, or return None."""
        if not mha:
            return None
        # NOTE(review): transformers requires hidden_size to be divisible by
        # num_attention_heads, so ``size`` must be a multiple of 8 here.
        config = BertConfig(num_attention_heads=8,
                            hidden_size=size,
                            intermediate_size=size)
        if args.n_mem:
            return MemBertLayer(config, n_mem=args.n_mem)
        return BertLayer(config)

    def forward(self, input, mem=None):
        """Apply the MLP stack, then the attention sublayer if there is one.

        Args:
            input: features whose last dimension is ``in_sz``.
            mem: memory for a ``MemBertLayer`` sublayer. Ignored when there
                is no attention sublayer.

        Returns:
            ``(output, mem)`` when the sublayer is a ``MemBertLayer``,
            otherwise the output tensor.

        Raises:
            ValueError: if ``mem`` is given but the sublayer is a plain
                ``BertLayer``, which has no memory.
        """
        output = self.layers(input)
        if self.mha is None:
            return output
        if isinstance(self.mha, MemBertLayer):
            return self.mha(output, mem=mem)
        if mem is not None:
            # Passing mem positionally to a plain BertLayer would silently
            # bind it to the second positional parameter of transformers'
            # BertLayer.forward (attention_mask).
            raise ValueError('mem is only supported by a MemBertLayer')
        return self.mha(output)
class Net(nn.Module):
    """Id embedding + feature MLP stack with a scalar output head.

    ``forward`` returns ``(output, mem)`` when a memory tensor is in play:
    either a ``MemBertLayer`` is present, or the caller passed ``mem``.
    Otherwise it returns ``output`` alone.
    """

    def __init__(self, args, n_embed, n_feature):
        super().__init__()
        self.emb = SafeEmbedding(n_embed, args.emb_dim)
        in_size = args.emb_dim + n_feature
        # Accept any iterable of hidden sizes (list, tuple, ...).
        szs = [in_size, *args.szs]
        # Not read by this class; kept in case other code references it.
        self.mem_placeholder = None
        self.basic_layers = self._get_layers(args, szs)
        self.fc = nn.Linear(szs[-1], 1)
        self._post_init()

    def _get_layers(self, args, szs):
        """Build one BasicLayer per consecutive size pair.

        A layer gets an attention sublayer when its index is in ``args.mhas``.

        Raises:
            ValueError: if more than one MemBertLayer would be created.
        """
        layers = nn.ModuleList([
            BasicLayer(args, in_sz, out_sz, layer_i in args.mhas)
            for layer_i, (in_sz, out_sz) in enumerate(zip(szs[:-1], szs[1:]))
        ])
        n_mem_layers = sum(
            isinstance(layer.mha, MemBertLayer) for layer in layers
        )
        if n_mem_layers > 1:
            # forward() threads a single ``mem`` tensor through the stack,
            # so only one layer can own it.
            raise ValueError('Support at most one MemBertLayer')
        return layers

    def forward(self, x_id, x_feat, mem=None):
        """Score each item.

        Args:
            x_id: integer ids fed to ``SafeEmbedding``.
            x_feat: features concatenated with the embeddings on the last
                dim. Its leading dims must match those of ``x_id``.
            mem: optional memory for the ``MemBertLayer``, if any.

        Returns:
            The fc output with its size-1 last dim squeezed, plus ``mem``
            when ``mem`` is not None.
        """
        x_emb = self.emb(x_id)
        output = torch.cat((x_emb, x_feat), dim=-1)
        for layer in self.basic_layers:
            if isinstance(layer.mha, MemBertLayer):
                output, mem = layer(output, mem=mem)
            else:
                output = layer(output)
        output = self.fc(output).squeeze(-1)
        if mem is not None:
            return output, mem
        return output

    def _post_init(self):
        """Re-initialize parameters of every submodule, BERT layers included."""
        for m in self.modules():
            # SafeEmbedding subclasses nn.Embedding, so it must be matched
            # before the generic nn.Embedding branch below.
            if isinstance(m, (nn.Linear, SafeEmbedding)):
                nn.init.kaiming_normal_(
                    m.weight, mode="fan_out", nonlinearity="relu")
            elif isinstance(m, (nn.BatchNorm1d, nn.LayerNorm)):
                # Non-affine norm layers have weight/bias set to None.
                if m.weight is not None:
                    nn.init.constant_(m.weight, 1)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Embedding):
                m.weight.data.normal_(mean=0.0, std=0.02)