-
Notifications
You must be signed in to change notification settings - Fork 0
/
model.py
167 lines (144 loc) · 5.96 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
import math
import torch
import torch.nn as nn
from colossalai.context import ParallelMode
from colossalai.nn.layer import VanillaPatchEmbedding, VanillaClassifier, \
WrappedDropout as Dropout, WrappedDropPath as DropPath
from colossalai.nn.layer.moe import Experts, MoeLayer, Top2Router, NormalNoiseGenerator
from .util import moe_sa_args, moe_mlp_args
from colossalai.global_variables import moe_env
from colossalai.utils import get_current_device
class MultiHeadAttention(nn.Module):
    """Multi-head self-attention as used in a standard ViT block.

    A single fused linear layer produces queries, keys and values for every
    head at once; scaled dot-product attention is computed per head and the
    heads are merged back through an output projection.

    Args:
        d_model: Width of the input and output features (last dimension).
        n_heads: Number of attention heads.
        d_kv: Per-head width of queries, keys and values.
        attention_drop: Dropout probability on the attention weights
            (unused when ``dropout1`` is supplied).
        drop_rate: Dropout probability on the projected output
            (unused when ``dropout2`` is supplied).
        bias: Whether the fused QKV projection has a bias term. The output
            projection is always created with its default bias.
        dropout1: Optional module used instead of the attention-weight dropout.
        dropout2: Optional module used instead of the output dropout.
    """

    def __init__(self,
                 d_model: int,
                 n_heads: int,
                 d_kv: int,
                 attention_drop: float = 0,
                 drop_rate: float = 0,
                 bias: bool = True,
                 dropout1=None,
                 dropout2=None):
        super().__init__()
        self.n_heads = n_heads
        self.d_kv = d_kv
        # Attention logits are scaled by 1/sqrt(d_kv).
        self.scale = 1.0 / math.sqrt(self.d_kv)

        device = get_current_device()
        inner_dim = n_heads * d_kv
        # Submodules are created in a fixed order: dense1, softmax, atten_drop,
        # dense2, dropout (affects seeded init and state_dict ordering).
        self.dense1 = nn.Linear(d_model, 3 * inner_dim, bias, device=device)
        self.softmax = nn.Softmax(dim=-1)
        if dropout1 is None:
            dropout1 = nn.Dropout(attention_drop)
        self.atten_drop = dropout1
        self.dense2 = nn.Linear(inner_dim, d_model, device=device)
        if dropout2 is None:
            dropout2 = nn.Dropout(drop_rate)
        self.dropout = dropout2

    def forward(self, x):
        """Apply self-attention to ``x`` and return a tensor of the same shape.

        The reshape below splits only the first two dimensions, so ``x`` is
        assumed to be ``(batch, seq, d_model)``.
        """
        fused = self.dense1(x)
        lead = fused.shape[:2]
        # (B, S, 3*H*D) -> (B, S, 3, H, D) -> (3, B, H, S, D)
        fused = fused.view(*lead, 3, self.n_heads, self.d_kv).permute(2, 0, 3, 1, 4)
        q, k, v = fused.unbind(0)

        scores = (q @ k.transpose(-2, -1)) * self.scale
        weights = self.atten_drop(self.softmax(scores))

        # (B, H, S, D) -> (B, S, H, D) -> (B, S, H*D)
        merged = (weights @ v).transpose(1, 2)
        merged = merged.reshape(*merged.shape[:2], self.n_heads * self.d_kv)
        return self.dropout(self.dense2(merged))
# Multilayer perceptron (feed-forward network with two linear layers)
class mlp(nn.Module):
    """Two-layer feed-forward network.

    Layout of ``self.ffns``: Linear(d_model -> d_ff), activation, dropout,
    Linear(d_ff -> d_model), dropout.

    Args:
        d_model: Input and output feature width.
        d_ff: Hidden feature width.
        activation: Activation module; ``nn.GELU()`` when omitted.
        drop_rate: Dropout probability for both dropout layers (each one is
            unused when the matching ``dropout1``/``dropout2`` is supplied).
        bias: Whether both linear layers have a bias term.
        dropout1: Optional module used instead of the hidden dropout.
        dropout2: Optional module used instead of the output dropout.
    """

    def __init__(self,
                 d_model: int,
                 d_ff: int,
                 activation=None,
                 drop_rate: float = 0,
                 bias: bool = True,
                 dropout1=None,
                 dropout2=None):
        super().__init__()
        device = get_current_device()
        # The expanding layer is built before the shrinking one so seeded
        # parameter initialisation follows layer order.
        expand = nn.Linear(d_model, d_ff, bias, device=device)
        if activation is None:
            activation = nn.GELU()
        shrink = nn.Linear(d_ff, d_model, bias, device=device)
        hidden_drop = dropout1 if dropout1 is not None else nn.Dropout(drop_rate)
        out_drop = dropout2 if dropout2 is not None else nn.Dropout(drop_rate)
        # Positions inside the Sequential (ffns.0 .. ffns.4) define the
        # parameter names in the state_dict.
        self.ffns = nn.Sequential(expand, activation, hidden_drop, shrink, out_drop)

    def forward(self, x):
        """Run ``x`` through the feed-forward stack (output width ``d_model``)."""
        return self.ffns(x)
class TransformerLayer(nn.Module):
    """Pre-norm residual transformer block assembled from supplied parts.

    ``forward`` computes ``x = x + droppath(att(norm1(x)))`` and then
    ``x = x + droppath(ffn(norm2(x)))``; one ``droppath`` module serves both
    residual branches.

    Args:
        att: Attention sub-module.
        ffn: Feed-forward sub-module.
        norm1: Normalisation applied before ``att``.
        norm2: Normalisation applied before ``ffn``.
        droppath: Module applied to each branch output before the residual
            add; a ``DropPath(droppath_rate)`` is created when omitted.
        droppath_rate: Rate for the default ``DropPath``.
    """

    def __init__(self,
                 att: nn.Module,
                 ffn: nn.Module,
                 norm1: nn.Module,
                 norm2: nn.Module,
                 droppath=None,
                 droppath_rate: float = 0):
        super().__init__()
        self.att = att
        self.ffn = ffn
        self.norm1 = norm1
        self.norm2 = norm2
        if droppath is None:
            droppath = DropPath(droppath_rate)
        self.droppath = droppath

    def forward(self, x):
        """Apply the attention branch, then the feed-forward branch, each residually."""
        for norm, branch in ((self.norm1, self.att), (self.norm2, self.ffn)):
            x = x + self.droppath(branch(norm(x)))
        return x
class Widenet(nn.Module):
    """WideNet-style vision transformer with Mixture-of-Experts feed-forward layers.

    One attention module, one ``Top2Router`` and one ``Experts`` pool are
    created once and reused by all ``depth`` transformer layers. Each layer
    owns only its two LayerNorms, its ``MoeLayer`` wrapper and its
    ``DropPath``. Drop-path rates are spaced linearly from 0 to ``drop_path``
    over the layers. The classifier head is zero-initialised.

    Args:
        num_experts: Number of experts in the shared pool (also passed to the
            router's noise generator).
        capacity_factor: Passed to ``Top2Router``.
        img_size: Passed to ``VanillaPatchEmbedding``.
        patch_size: Passed to ``VanillaPatchEmbedding``.
        in_chans: Passed to ``VanillaPatchEmbedding``.
        num_classes: Output size of the classifier head.
        depth: Number of transformer layers.
        d_model: Embedding width.
        num_heads: Number of attention heads.
        d_kv: Per-head query/key/value width.
        d_ff: Hidden width of each expert MLP.
        attention_drop: Dropout probability on attention weights.
        drop_rate: Dropout probability for embeddings, attention output and experts.
        drop_path: Upper end of the linearly spaced stochastic-depth rates.
    """

    def __init__(self,
                 num_experts: int,
                 capacity_factor: float,
                 img_size: int = 224,
                 patch_size: int = 16,
                 in_chans: int = 3,
                 num_classes: int = 1000,
                 depth: int = 12,
                 d_model: int = 768,
                 num_heads: int = 12,
                 d_kv: int = 64,
                 d_ff: int = 4096,
                 attention_drop: float = 0.,
                 drop_rate: float = 0.1,
                 drop_path: float = 0.):
        super().__init__()
        # NOTE: module construction order is kept fixed; it determines how the
        # seeded RNG initialises parameters.
        patch_embed = VanillaPatchEmbedding(img_size=img_size,
                                            patch_size=patch_size,
                                            in_chans=in_chans,
                                            embed_size=d_model)
        patch_drop = Dropout(p=drop_rate, mode=ParallelMode.TENSOR)

        # Shared across every layer (parameter sharing across depth).
        attention = MultiHeadAttention(d_model=d_model,
                                       n_heads=num_heads,
                                       d_kv=d_kv,
                                       attention_drop=attention_drop,
                                       drop_rate=drop_rate)
        router = Top2Router(capacity_factor, noisy_func=NormalNoiseGenerator(num_experts))
        experts = Experts(expert=mlp,
                          num_experts=num_experts,
                          d_model=d_model,
                          d_ff=d_ff,
                          drop_rate=drop_rate)

        # Stochastic depth decay rule: one linearly spaced rate per layer.
        layers = []
        for rate in torch.linspace(0, drop_path, depth).tolist():
            layers.append(
                TransformerLayer(att=attention,
                                 ffn=MoeLayer(dim_model=d_model,
                                              num_experts=num_experts,
                                              router=router,
                                              experts=experts),
                                 norm1=nn.LayerNorm(d_model, eps=1e-6),
                                 norm2=nn.LayerNorm(d_model, eps=1e-6),
                                 droppath=DropPath(p=rate, mode=ParallelMode.TENSOR)))
        final_norm = nn.LayerNorm(d_model, eps=1e-6)

        # Assigned before ``self.widenet`` so ``linear.*`` comes first in the
        # module / state_dict order.
        self.linear = VanillaClassifier(in_features=d_model, num_classes=num_classes)
        for param in (self.linear.weight, self.linear.bias):
            nn.init.zeros_(param)
        self.widenet = nn.Sequential(patch_embed, patch_drop, *layers, final_norm)

    def forward(self, x):
        """Return classifier logits for the input batch ``x``."""
        # Reset the MoE environment's loss accumulator before this pass
        # (presumably the auxiliary loss collected by the MoE layers — confirm).
        moe_env.reset_loss()
        features = self.widenet(x)
        # Average over dim 1 (presumably the token dimension — confirm against
        # VanillaPatchEmbedding's output layout).
        pooled = features.mean(dim=1)
        return self.linear(pooled)