Skip to content

Commit

Permalink
add qk normalization
Browse files Browse the repository at this point in the history
  • Loading branch information
lintangsutawika committed Dec 19, 2023
1 parent 050f560 commit a97a984
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 0 deletions.
15 changes: 15 additions & 0 deletions megatron/model/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,16 @@ def __init__(
neox_args.num_attention_heads, world_size
)
self.pos_emb = neox_args.pos_emb
self.use_qk_layernorm = neox_args.use_qk_layernorm
if self.use_qk_layernorm:
norm, eps = get_norm(neox_args)
self.qk_layernorm = norm(
[
self.num_attention_heads_per_partition,
self.hidden_size_per_attention_head,
],
eps=eps,
)

# Strided linear layer.
self.query_key_value = mpu.ColumnParallelLinear(
Expand Down Expand Up @@ -639,6 +649,11 @@ def forward(self, hidden_states, attention_mask, layer_past=None):
mixed_x_layer, 3
)

# QK Normalization https://arxiv.org/abs/2302.05442
if self.use_qk_layernorm:
query_layer = self.qk_layernorm(query_layer)
key_layer = self.qk_layernorm(key_layer)

if exists(self.rotary_emb):
if exists(self.rotary_ndims):
# partial rotary
Expand Down
5 changes: 5 additions & 0 deletions megatron/neox_arguments/neox_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,11 @@ class NeoXArgsModel(NeoXArgsTemplate):
Normalization layer to use. Choose from "layernorm", "rmsnorm", "scalenorm".
"""

use_qk_layernorm: bool = False
"""
Use QK Normalization
"""

layernorm_epsilon: float = 1.0e-5
"""
Layer norm epsilon.
Expand Down

0 comments on commit a97a984

Please sign in to comment.