
Commit

Added rng tracker to lnmlp and placed rope in te_mha init instead of forward
aurelion-source committed Oct 1, 2024
1 parent bb76510 commit 43cf4ee
Showing 2 changed files with 8 additions and 13 deletions.
9 changes: 4 additions & 5 deletions megatron/model/positional_embeddings.py
@@ -37,8 +37,7 @@ def forward(self, x, seq_dim=1):
 
 class RotaryEmbedding(torch.nn.Module):
     def __init__(
-        self, dim, max_seq_len, base=10000, precision=torch.half, save_inv_freqs=False, return_embeddings=False
-    ):
+        self, dim, max_seq_len, base=10000, precision=torch.half, save_inv_freqs=False):
         super().__init__()
         inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
         self.register_buffer("inv_freq", inv_freq, persistent=save_inv_freqs)
@@ -49,7 +48,6 @@ def __init__(
         self.max_seq_len = max_seq_len
         self.base = base
         self.dim = dim
-        self.return_embeddings = return_embeddings
 
         # precompute cos_cached, sin_cached in fp32
         cos_cached, sin_cached, inv_freq = self._prepare_cache(
@@ -79,9 +77,10 @@ def _prepare_cache(self, seq_len, precision, base):
             inv_freq.to(precision),
         )
 
+    def get_emb(self):
+        return self.emb.to(self.precision).cuda()
+
     def forward(self, x, seq_dim=0, seq_len=None):
-        if self.return_embeddings:
-            return self.emb.to(self.precision).to(x.device)
         if seq_len is None:
             seq_len = x.shape[seq_dim]
 
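Note on the new get_emb() helper: instead of the return_embeddings flag and the early return in forward, callers now get the precomputed table via an explicit accessor that casts it to self.precision and moves it to the current CUDA device. The tensor cached in self.emb is built by _prepare_cache, which this hunk does not show; the sketch below only illustrates the standard rotary precomputation implied by the inv_freq buffer above, with illustrative names and shapes rather than the repository's exact code.

import torch

def build_rope_angle_table(dim, max_seq_len, base=10000, precision=torch.half):
    # per-channel-pair inverse frequencies, as in RotaryEmbedding.__init__
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
    # one rotation angle per (position, frequency) pair
    positions = torch.arange(max_seq_len, dtype=inv_freq.dtype)
    freqs = torch.einsum("i,j->ij", positions, inv_freq)
    # duplicate so the last dimension matches the full head dimension
    emb = torch.cat((freqs, freqs), dim=-1)  # [max_seq_len, dim]
    return emb.to(precision)

Because the table depends only on dim, max_seq_len, and base, it can be built once at construction and reused for every batch, which is what get_emb() enables for the consumer in transformer_engine.py below.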
12 changes: 4 additions & 8 deletions megatron/model/transformer_engine.py
@@ -200,7 +200,7 @@ def __init__(
             init_method=self.init_method, output_layer_init_method=self.output_layer_init_method,
             device=torch.cuda.current_device(), set_parallel_mode=self.set_parallel_mode,
             sequence_parallel=self.sequence_parallel, tp_group=self.tp_group, tp_size=self.world_size,
-            return_bias=True, params_dtype=self.params_dtype, seq_length=self.seq_len,
+            return_bias=True, params_dtype=self.params_dtype, seq_length=self.seq_len, get_rng_state_tracker=get_cuda_rng_tracker,
             micro_batch_size=self.micro_batch_size)
 
 
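The added get_rng_state_tracker argument is the hook Transformer Engine uses to draw random state from Megatron-style per-rank RNG streams (e.g. for dropout) rather than the default CUDA generator, which keeps results consistent under tensor parallelism. A minimal sketch of the pattern with placeholder sizes; the import path is an assumption, since the diff only shows the name get_cuda_rng_tracker:

import torch
import transformer_engine.pytorch as te
from megatron.mpu import get_cuda_rng_tracker  # assumed import path for the tracker

# Passing the tracker lets TE fork into the same per-tensor-parallel-rank
# RNG streams the rest of the Megatron model uses whenever it needs randomness.
mlp = te.LayerNormMLP(
    hidden_size=1024,       # placeholder size
    ffn_hidden_size=4096,   # placeholder size
    params_dtype=torch.bfloat16,
    get_rng_state_tracker=get_cuda_rng_tracker,
)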
@@ -538,16 +538,12 @@ def __init__(self,
                 base=neox_args.rotary_emb_base,
                 max_seq_len=neox_args.seq_length,
                 precision=neox_args.params_dtype,
-                save_inv_freqs=neox_args.rotary_save_freqs_buffer,
-                return_embeddings=True
+                save_inv_freqs=neox_args.rotary_save_freqs_buffer
             )
+            self.rope_emb=self.rotary_embeddings.get_emb()
 
     def forward(self, hidden_states, attention_mask, layer_past=None, rope_emb=None, **kwargs):
-        if self.neox_args.pos_emb == "rotary":
-            rope_emb=self.rotary_embeddings(hidden_states)
-
-        output = super(TEMultiheadAttention, self).forward(hidden_states, attention_mask, rotary_pos_emb=rope_emb, **kwargs)
-
+        output = super(TEMultiheadAttention, self).forward(hidden_states, attention_mask, rotary_pos_emb=self.rope_emb, **kwargs)
         return output
 
 
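Taken together, this hunk replaces the per-forward call self.rotary_embeddings(hidden_states) with a table cached once in __init__ and handed straight to Transformer Engine as rotary_pos_emb. A minimal sketch of the same pattern with a stand-in attention callable; only get_emb() and the rotary_pos_emb keyword come from the diff above, everything else is illustrative:

import torch

class CachedRopeAttention(torch.nn.Module):
    # Illustrative wrapper: build the rotary table once at construction,
    # reuse the cached tensor on every forward call.
    def __init__(self, rotary_embeddings, attention):
        super().__init__()
        self.attention = attention                    # any module accepting rotary_pos_emb
        self.rope_emb = rotary_embeddings.get_emb()   # one-time cost; get_emb() moves it to GPU

    def forward(self, hidden_states, attention_mask, **kwargs):
        # no per-step call into the rotary module; just hand over the cached table
        return self.attention(
            hidden_states, attention_mask, rotary_pos_emb=self.rope_emb, **kwargs
        )

The practical effect is that the dtype cast and device copy the old forward performed on every step (self.emb.to(self.precision).to(x.device)) now happen once at construction.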
