sut_transformer.py
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from fairseq.dataclass.utils import gen_parser_from_dataclass
from fairseq.models import (
    register_model,
    register_model_architecture,
)

from .transformer_config import (
    TransformerConfig,
    DEFAULT_MAX_SOURCE_POSITIONS,
    DEFAULT_MAX_TARGET_POSITIONS,
    DEFAULT_MIN_PARAMS_TO_WRAP,
)

# from fairseq.models.transformer.transformer_base import (
#     TransformerModelBase,
# )
from .transformer_base import TransformerModelBase


@register_model("sut_transformer")
class SUTTransformerModel(TransformerModelBase):
def __init__(self, args, encoder, decoder):
cfg = TransformerConfig.from_namespace(args)
super().__init__(cfg, encoder, decoder)
self.args = args
@classmethod
def add_args(cls, parser):
"""Add model-specific arguments to the parser."""
# we want to build the args recursively in this case.
# do not set defaults so that settings defaults from various architectures still works
gen_parser_from_dataclass(
parser, TransformerConfig(), delete_default=True, with_prefix=""
)
@classmethod
def build_model(cls, args, task):
"""Build a new model instance."""
# make sure all arguments are present in older models
base_ende_architecture(args)
if args.encoder_layers_to_keep:
args.encoder_layers = len(args.encoder_layers_to_keep.split(","))
if args.decoder_layers_to_keep:
args.decoder_layers = len(args.decoder_layers_to_keep.split(","))
if getattr(args, "max_source_positions", None) is None:
args.max_source_positions = DEFAULT_MAX_SOURCE_POSITIONS
if getattr(args, "max_target_positions", None) is None:
args.max_target_positions = DEFAULT_MAX_TARGET_POSITIONS
src_dict, tgt_dict = task.source_dictionary, task.target_dictionary
if args.share_all_embeddings:
if src_dict != tgt_dict:
raise ValueError("--share-all-embeddings requires a joined dictionary")
if args.encoder_embed_dim != args.decoder_embed_dim:
raise ValueError(
"--share-all-embeddings requires --encoder-embed-dim to match --decoder-embed-dim"
)
if args.decoder_embed_path and (
args.decoder_embed_path != args.encoder_embed_path
):
raise ValueError(
"--share-all-embeddings not compatible with --decoder-embed-path"
)
args.share_decoder_input_output_embed = True
if getattr(args, "offload_activations", False):
args.checkpoint_activations = True # offloading implies checkpointing
if not args.share_all_embeddings:
args.min_params_to_wrap = getattr(
args, "min_params_to_wrap", DEFAULT_MIN_PARAMS_TO_WRAP
)
cfg = TransformerConfig.from_namespace(args)
return super().build_model(cfg, task)
@classmethod
def build_embedding(cls, args, dictionary, embed_dim, path=None):
return super().build_embedding(
TransformerConfig.from_namespace(args), dictionary, embed_dim, path
)
@classmethod
def build_encoder(cls, args, src_dict, embed_tokens):
return super().build_encoder(
TransformerConfig.from_namespace(args), src_dict, embed_tokens
)
@classmethod
def build_decoder(cls, args, tgt_dict, embed_tokens):
return super().build_decoder(
TransformerConfig.from_namespace(args), tgt_dict, embed_tokens
)


# architectures


@register_model_architecture("sut_transformer", "sut_transformer_base")
def base_ende_architecture(args):
    args.encoder_embed_path = getattr(args, "encoder_embed_path", None)
    args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512)
    args.head_dim = getattr(args, "head_dim", 256)
    args.encoder_ff_expert_dim = getattr(args, "encoder_ff_expert_dim", 2048)
    args.encoder_layers = getattr(args, "encoder_layers", 6)
    args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 8)
    args.encoder_normalize_before = getattr(args, "encoder_normalize_before", False)
    args.encoder_learned_pos = getattr(args, "encoder_learned_pos", False)

    # SUT-specific defaults: number of experts, auxiliary loss weights, routing
    args.num_expert = getattr(args, "num_expert", 12)
    args.switchloss = getattr(args, "switchloss", 0)
    args.cvloss = getattr(args, "cvloss", 0)
    args.zloss = getattr(args, "zloss", 0)
    args.miloss = getattr(args, "miloss", 0)
    args.actloss = getattr(args, "actloss", 0)
    args.sample_topk = getattr(args, "sample_topk", 0)

    args.decoder_embed_path = getattr(args, "decoder_embed_path", None)
    args.decoder_embed_dim = getattr(args, "decoder_embed_dim", args.encoder_embed_dim)
    args.decoder_ff_expert_dim = getattr(
        args, "decoder_ff_expert_dim", args.encoder_ff_expert_dim
    )
    args.decoder_layers = getattr(args, "decoder_layers", 6)
    args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 8)
    args.decoder_normalize_before = getattr(args, "decoder_normalize_before", False)
    args.decoder_learned_pos = getattr(args, "decoder_learned_pos", False)

    args.attention_dropout = getattr(args, "attention_dropout", 0.1)
    args.activation_dropout = getattr(args, "activation_dropout", 0.1)
    args.activation_fn = getattr(args, "activation_fn", "gelu")
    args.dropout = getattr(args, "dropout", 0.3)
    args.halting_dropout = getattr(args, "halting_dropout", 0)
    args.gating_dropout = getattr(args, "gating_dropout", 0.2)

    args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None)
    args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0.1)
    args.share_decoder_input_output_embed = getattr(
        args, "share_decoder_input_output_embed", False
    )
    args.share_all_embeddings = getattr(args, "share_all_embeddings", False)
    args.no_token_positional_embeddings = getattr(
        args, "no_token_positional_embeddings", False
    )
    args.adaptive_input = getattr(args, "adaptive_input", False)
    args.no_cross_attention = getattr(args, "no_cross_attention", False)
    args.cross_self_attention = getattr(args, "cross_self_attention", False)

    args.decoder_output_dim = getattr(
        args, "decoder_output_dim", args.decoder_embed_dim
    )
    args.decoder_input_dim = getattr(args, "decoder_input_dim", args.decoder_embed_dim)

    args.no_scale_embedding = getattr(args, "no_scale_embedding", False)
    args.layernorm_embedding = getattr(args, "layernorm_embedding", False)
    args.tie_adaptive_weights = getattr(args, "tie_adaptive_weights", False)

    args.checkpoint_activations = getattr(args, "checkpoint_activations", False)
    args.offload_activations = getattr(args, "offload_activations", False)
    if args.offload_activations:
        args.checkpoint_activations = True  # offloading implies checkpointing

    args.encoder_layers_to_keep = getattr(args, "encoder_layers_to_keep", None)
    args.decoder_layers_to_keep = getattr(args, "decoder_layers_to_keep", None)
    args.encoder_layerdrop = getattr(args, "encoder_layerdrop", 0)
    args.decoder_layerdrop = getattr(args, "decoder_layerdrop", 0)
    args.quant_noise_pq = getattr(args, "quant_noise_pq", 0)
    args.quant_noise_pq_block_size = getattr(args, "quant_noise_pq_block_size", 8)
    args.quant_noise_scalar = getattr(args, "quant_noise_scalar", 0)
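

# Example usage (illustrative sketch): `sut_transformer_base` is registered
# above, so it can be selected through the standard fairseq CLI with `--arch`.
# The SUT-specific options shown (num_expert, sample_topk) are assumed to be
# exposed as command-line flags by this repo's TransformerConfig dataclass,
# following fairseq's usual underscore-to-hyphen flag naming; adjust them to
# whatever the config actually defines.
#
#   fairseq-train data-bin/my_dataset \
#       --task translation \
#       --arch sut_transformer_base \
#       --optimizer adam --lr 5e-4 --lr-scheduler inverse_sqrt \
#       --max-tokens 4096 \
#       --num-expert 12 --sample-topk 0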