Pre-compute RoPE embeddings in fp32 (#1041)
* Pre-commit

Signed-off-by: Dashiell Stander <[email protected]>

* Sequence dimension is 0

Signed-off-by: Dashiell Stander <[email protected]>

* Update NeoXArgs docs automatically

* Update NeoXArgs docs automatically

---------

Signed-off-by: Dashiell Stander <[email protected]>
Co-authored-by: github-actions <[email protected]>
Co-authored-by: Quentin Anthony <[email protected]>
3 people authored Sep 25, 2023
1 parent 2ab05be commit 3bfedf4
Showing 3 changed files with 65 additions and 34 deletions.
2 changes: 1 addition & 1 deletion configs/neox_arguments.md
@@ -111,7 +111,7 @@ Logging Arguments
 
 - **git_hash**: str
 
-    Default = 3f8c63c
+    Default = cd88435
 
     current git hash of repository
 
57 changes: 42 additions & 15 deletions megatron/model/positional_embeddings.py
@@ -36,31 +36,58 @@ def forward(self, x, seq_dim=1):
 
 
 class RotaryEmbedding(torch.nn.Module):
-    def __init__(self, dim, base=10000, precision=torch.half):
+    def __init__(self, dim, max_seq_len, base=10000, precision=torch.half):
         super().__init__()
         inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
         self.register_buffer("inv_freq", inv_freq)
         self.seq_len_cached = None
         self.cos_cached = None
         self.sin_cached = None
         self.precision = precision
+        self.max_seq_len = max_seq_len
+        self.base = base
+        self.dim = dim
+
+        # precompute cos_cached, sin_cached in fp32
+        cos_cached, sin_cached, inv_freq = self._prepare_cache(
+            max_seq_len, precision, base
+        )
+
+        self.register_buffer("inv_freq", inv_freq)
+        self.cos_cached = cos_cached
+        self.sin_cached = sin_cached
+
+    def _prepare_cache(self, seq_len, precision, base):
+        # precompute cos_cached, sin_cached in fp32
+        inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float() / self.dim))
+
+        t = torch.arange(seq_len).type_as(inv_freq)
+        freqs = torch.einsum("i,j->ij", t, inv_freq)
+        emb = torch.cat((freqs, freqs), dim=-1)
+
+        cos_cached = emb.cos()[:, None, None, :]
+        sin_cached = emb.sin()[:, None, None, :]
 
-    def forward(self, x, seq_dim=1, seq_len=None):
+        return (
+            cos_cached.to(precision),
+            sin_cached.to(precision),
+            inv_freq.to(precision),
+        )
+
+    def forward(self, x, seq_dim=0, seq_len=None):
         if seq_len is None:
             seq_len = x.shape[seq_dim]
-        if seq_len != self.seq_len_cached:
-            self.seq_len_cached = seq_len
-            t = torch.arange(seq_len, device=x.device).type_as(self.inv_freq)
-            freqs = torch.einsum("i,j->ij", t, self.inv_freq)
-            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
-            if self.precision == torch.bfloat16:
-                emb = emb.float()
-            self.cos_cached = emb.cos()[:, None, None, :]
-            self.sin_cached = emb.sin()[:, None, None, :]
-            if self.precision == torch.bfloat16:
-                self.cos_cached = self.cos_cached.bfloat16()
-                self.sin_cached = self.sin_cached.bfloat16()
-        return self.cos_cached, self.sin_cached
+
+        assert seq_len <= self.max_seq_len
+
+        if seq_len != self.max_seq_len:
+            # y, z, _ = self._prepare_cache(seq_len, self.precision, self.base)
+            return (
+                self.cos_cached[:seq_len, ...].to(x.device),
+                self.sin_cached[:seq_len, ...].to(x.device),
+            )
+        else:
+            return self.cos_cached.to(x.device), self.sin_cached.to(x.device)
 
 
 # rotary pos emb helpers:
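For readers wondering why the cache is built in fp32 and only then cast, here is a small self-contained sketch. It is not part of the commit; `rope_cache`, `HEAD_DIM`, `MAX_SEQ_LEN`, and `BASE` are illustrative names and sizes, not values from the repo. It builds the cos/sin tables the same way `_prepare_cache` does (fp32 math, cast at the end) and contrasts them with tables whose angles were computed directly in bfloat16, where position indices above 256 can no longer be represented exactly.

```python
import torch

HEAD_DIM = 64       # illustrative per-head dimension (assumption)
MAX_SEQ_LEN = 2048  # illustrative maximum sequence length (assumption)
BASE = 10000

def rope_cache(dim, max_seq_len, base=10000, dtype=torch.bfloat16):
    # All intermediate math stays in fp32; only the final tables are cast.
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
    t = torch.arange(max_seq_len).float()
    freqs = torch.einsum("i,j->ij", t, inv_freq)   # [seq, dim/2]
    emb = torch.cat((freqs, freqs), dim=-1)        # [seq, dim]
    return emb.cos().to(dtype), emb.sin().to(dtype)

cos_fp32_first, sin_fp32_first = rope_cache(HEAD_DIM, MAX_SEQ_LEN)

# For contrast: doing the same math directly in bf16 rounds t * inv_freq
# before cos/sin are applied, so angles at large positions drift.
inv_freq = 1.0 / (BASE ** (torch.arange(0, HEAD_DIM, 2).float() / HEAD_DIM))
freqs_bf16 = torch.einsum(
    "i,j->ij", torch.arange(MAX_SEQ_LEN).bfloat16(), inv_freq.bfloat16()
)
emb_bf16 = torch.cat((freqs_bf16, freqs_bf16), dim=-1)
print((cos_fp32_first.float() - emb_bf16.cos().float()).abs().max())
```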
40 changes: 22 additions & 18 deletions megatron/model/transformer.py
@@ -222,22 +222,23 @@ def __init__(
                 skip_bias_add=False,
                 mup_rescale_parameters=is_last_layer,  # rescale params only called if neox_args.use_mup = True, despite it not being included here
             )
-        # else:
-        #     print(
-        #         'ERROR: Output layer parallelism over the hidden dim is currently broken (https://github.com/EleutherAI/gpt-neox/issues/905). Please run with output_layer_parallelism = "column" until this issue is fixed.'
-        #     )
-        #     exit()
-        #     self.final_linear = mpu.RowParallelLinear(
-        #         neox_args=neox_args,
-        #         input_size=neox_args.hidden_size,
-        #         output_size=neox_args.padded_vocab_size,
-        #         bias=False,
-        #         input_is_parallel=False,
-        #         init_method=init_method,
-        #         parallel_output=parallel_output,
-        #         skip_bias_add=False,
-        #         mup_rescale_parameters=is_last_layer, # only called if neox_args.use_mup = True, despite it not being included here
-        #     )
+
+        # else:
+        #     print(
+        #         'ERROR: Output layer parallelism over the hidden dim is currently broken (https://github.com/EleutherAI/gpt-neox/issues/905). Please run with output_layer_parallelism = "column" until this issue is fixed.'
+        #     )
+        #     exit()
+        #     self.final_linear = mpu.RowParallelLinear(
+        #         neox_args=neox_args,
+        #         input_size=neox_args.hidden_size,
+        #         output_size=neox_args.padded_vocab_size,
+        #         bias=False,
+        #         input_is_parallel=False,
+        #         init_method=init_method,
+        #         parallel_output=parallel_output,
+        #         skip_bias_add=False,
+        #         mup_rescale_parameters=is_last_layer, # only called if neox_args.use_mup = True, despite it not being included here
+        #     )
 
     def forward(self, hidden_states):
         return self.final_linear(hidden_states)
@@ -327,7 +328,10 @@ def __init__(
                 else self.hidden_size_per_attention_head
             )
             self.rotary_emb = RotaryEmbedding(
-                dim, base=neox_args.rotary_emb_base, precision=neox_args.params_dtype
+                dim,
+                base=neox_args.rotary_emb_base,
+                max_seq_len=neox_args.seq_length,
+                precision=neox_args.params_dtype,
             )
         else:
             self.rotary_emb = None
@@ -350,7 +354,7 @@ def __init__(
                 # Change of function names going from flash attention 1 -> flash attention 2
                 flash_attn_varlen_qkvpacked_func,
                 flash_attn_varlen_kvpacked_func,
-                flash_attn_unpadded_unpacked_func_triton
+                flash_attn_unpadded_unpacked_func_triton,
             )
 
             self.flash_triton_fn = flash_attn_unpadded_unpacked_func_triton
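A minimal usage sketch of the updated interface, mirroring the constructor keywords and the `seq_dim=0` forward signature from the diffs above. It assumes gpt-neox is importable; the tensor sizes, batch layout, and the bfloat16 choice are illustrative assumptions rather than values from the repo.

```python
import torch
from megatron.model.positional_embeddings import RotaryEmbedding

# Illustrative sizes; in gpt-neox, dim comes from the per-head hidden size
# and max_seq_len from neox_args.seq_length.
head_dim, max_seq_len = 64, 2048

rotary_emb = RotaryEmbedding(
    head_dim,
    base=10000,
    max_seq_len=max_seq_len,
    precision=torch.bfloat16,
)

# Assume a query laid out sequence-first: [seq, batch, heads, head_dim].
query = torch.randn(128, 4, 8, head_dim, dtype=torch.bfloat16)

# With seq_dim=0 (the new default), forward slices the precomputed
# fp32-derived tables to the current sequence length instead of
# recomputing them on every call.
cos, sin = rotary_emb(query, seq_dim=0, seq_len=query.shape[0])
print(cos.shape)  # torch.Size([128, 1, 1, 64])
```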
