diff --git a/megatron/model/transformer.py b/megatron/model/transformer.py index 63f4122e2..195e57925 100644 --- a/megatron/model/transformer.py +++ b/megatron/model/transformer.py @@ -284,6 +284,16 @@ def __init__( neox_args.num_attention_heads, world_size ) self.pos_emb = neox_args.pos_emb + self.use_qk_layernorm = neox_args.use_qk_layernorm + if self.use_qk_layernorm: + norm, eps = get_norm(neox_args) + self.qk_layernorm = norm( + [ + self.num_attention_heads_per_partition, + self.hidden_size_per_attention_head, + ], + eps=eps, + ) # Strided linear layer. self.query_key_value = mpu.ColumnParallelLinear( @@ -639,6 +649,11 @@ def forward(self, hidden_states, attention_mask, layer_past=None): mixed_x_layer, 3 ) + # QK Normalization https://arxiv.org/abs/2302.05442 + if self.use_qk_layernorm: + query_layer = self.qk_layernorm(query_layer) + key_layer = self.qk_layernorm(key_layer) + if exists(self.rotary_emb): if exists(self.rotary_ndims): # partial rotary diff --git a/megatron/neox_arguments/neox_args.py b/megatron/neox_arguments/neox_args.py index 324a379d4..2cfed465d 100644 --- a/megatron/neox_arguments/neox_args.py +++ b/megatron/neox_arguments/neox_args.py @@ -125,6 +125,11 @@ class NeoXArgsModel(NeoXArgsTemplate): Normalization layer to use. Choose from "layernorm", "rmsnorm", "scalenorm". """ + use_qk_layernorm: bool = False + """ + Use QK Normalization + """ + layernorm_epsilon: float = 1.0e-5 """ Layer norm epsilon.