diff --git a/configs/neox_arguments.md b/configs/neox_arguments.md
index efb828799..f7d41dc22 100644
--- a/configs/neox_arguments.md
+++ b/configs/neox_arguments.md
@@ -111,7 +111,7 @@ Logging Arguments

 - **git_hash**: str

-    Default = 3f8c63c
+    Default = cd88435

     current git hash of repository

diff --git a/megatron/model/positional_embeddings.py b/megatron/model/positional_embeddings.py
index 68815075a..29c8a7e65 100644
--- a/megatron/model/positional_embeddings.py
+++ b/megatron/model/positional_embeddings.py
@@ -36,7 +36,7 @@ def forward(self, x, seq_dim=1):


 class RotaryEmbedding(torch.nn.Module):
-    def __init__(self, dim, base=10000, precision=torch.half):
+    def __init__(self, dim, max_seq_len, base=10000, precision=torch.half):
         super().__init__()
         inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
         self.register_buffer("inv_freq", inv_freq)
@@ -44,23 +44,50 @@ def __init__(self, dim, base=10000, precision=torch.half):
         self.cos_cached = None
         self.sin_cached = None
         self.precision = precision
+        self.max_seq_len = max_seq_len
+        self.base = base
+        self.dim = dim
+
+        # precompute cos_cached, sin_cached in fp32
+        cos_cached, sin_cached, inv_freq = self._prepare_cache(
+            max_seq_len, precision, base
+        )
+
+        self.register_buffer("inv_freq", inv_freq)
+        self.cos_cached = cos_cached
+        self.sin_cached = sin_cached
+
+    def _prepare_cache(self, seq_len, precision, base):
+        # precompute cos_cached, sin_cached in fp32
+        inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float() / self.dim))
+
+        t = torch.arange(seq_len).type_as(inv_freq)
+        freqs = torch.einsum("i,j->ij", t, inv_freq)
+        emb = torch.cat((freqs, freqs), dim=-1)
+
+        cos_cached = emb.cos()[:, None, None, :]
+        sin_cached = emb.sin()[:, None, None, :]

-    def forward(self, x, seq_dim=1, seq_len=None):
+        return (
+            cos_cached.to(precision),
+            sin_cached.to(precision),
+            inv_freq.to(precision),
+        )
+
+    def forward(self, x, seq_dim=0, seq_len=None):
         if seq_len is None:
             seq_len = x.shape[seq_dim]
-        if seq_len != self.seq_len_cached:
-            self.seq_len_cached = seq_len
-            t = torch.arange(seq_len, device=x.device).type_as(self.inv_freq)
-            freqs = torch.einsum("i,j->ij", t, self.inv_freq)
-            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
-            if self.precision == torch.bfloat16:
-                emb = emb.float()
-            self.cos_cached = emb.cos()[:, None, None, :]
-            self.sin_cached = emb.sin()[:, None, None, :]
-            if self.precision == torch.bfloat16:
-                self.cos_cached = self.cos_cached.bfloat16()
-                self.sin_cached = self.sin_cached.bfloat16()
-        return self.cos_cached, self.sin_cached
+
+        assert seq_len <= self.max_seq_len
+
+        if seq_len != self.max_seq_len:
+            # y, z, _ = self._prepare_cache(seq_len, self.precision, self.base)
+            return (
+                self.cos_cached[:seq_len, ...].to(x.device),
+                self.sin_cached[:seq_len, ...].to(x.device),
+            )
+        else:
+            return self.cos_cached.to(x.device), self.sin_cached.to(x.device)


 # rotary pos emb helpers:
diff --git a/megatron/model/transformer.py b/megatron/model/transformer.py
index 09b5c6985..63f4122e2 100644
--- a/megatron/model/transformer.py
+++ b/megatron/model/transformer.py
@@ -222,22 +222,23 @@ def __init__(
                 skip_bias_add=False,
                 mup_rescale_parameters=is_last_layer,  # rescale params only called if neox_args.use_mup = True, despite it not being included here
             )
-#        else:
-#            print(
-#                'ERROR: Output layer parallelism over the hidden dim is currently broken (https://github.com/EleutherAI/gpt-neox/issues/905). Please run with output_layer_parallelism = "column" until this issue is fixed.'
-#            )
-#            exit()
-#            self.final_linear = mpu.RowParallelLinear(
-#                neox_args=neox_args,
-#                input_size=neox_args.hidden_size,
-#                output_size=neox_args.padded_vocab_size,
-#                bias=False,
-#                input_is_parallel=False,
-#                init_method=init_method,
-#                parallel_output=parallel_output,
-#                skip_bias_add=False,
-#                mup_rescale_parameters=is_last_layer, # only called if neox_args.use_mup = True, despite it not being included here
-#            )
+
+        # else:
+        #     print(
+        #         'ERROR: Output layer parallelism over the hidden dim is currently broken (https://github.com/EleutherAI/gpt-neox/issues/905). Please run with output_layer_parallelism = "column" until this issue is fixed.'
+        #     )
+        #     exit()
+        #     self.final_linear = mpu.RowParallelLinear(
+        #         neox_args=neox_args,
+        #         input_size=neox_args.hidden_size,
+        #         output_size=neox_args.padded_vocab_size,
+        #         bias=False,
+        #         input_is_parallel=False,
+        #         init_method=init_method,
+        #         parallel_output=parallel_output,
+        #         skip_bias_add=False,
+        #         mup_rescale_parameters=is_last_layer, # only called if neox_args.use_mup = True, despite it not being included here
+        #     )

     def forward(self, hidden_states):
         return self.final_linear(hidden_states)
@@ -327,7 +328,10 @@ def __init__(
                 else self.hidden_size_per_attention_head
             )
             self.rotary_emb = RotaryEmbedding(
-                dim, base=neox_args.rotary_emb_base, precision=neox_args.params_dtype
+                dim,
+                base=neox_args.rotary_emb_base,
+                max_seq_len=neox_args.seq_length,
+                precision=neox_args.params_dtype,
             )
         else:
             self.rotary_emb = None
@@ -350,7 +354,7 @@ def __init__(
                 # Change of function names going from flash attention 1 -> flash attention 2
                 flash_attn_varlen_qkvpacked_func,
                 flash_attn_varlen_kvpacked_func,
-                flash_attn_unpadded_unpacked_func_triton
+                flash_attn_unpadded_unpacked_func_triton,
             )

             self.flash_triton_fn = flash_attn_unpadded_unpacked_func_triton
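Usage note (not part of the patch): the sketch below is a hypothetical smoke test for the precomputed rotary cache, assuming a gpt-neox checkout with this patch applied (and its dependencies) is importable. It only illustrates that RotaryEmbedding.forward now slices the cos/sin buffers built once in __init__, and rejects requests beyond max_seq_len, instead of recomputing the tables per call; the tensor layout [seq, batch, heads, head_dim] is an assumption implied by the new seq_dim=0 default.

    import torch

    from megatron.model.positional_embeddings import RotaryEmbedding

    # Build the module once; cos/sin caches for max_seq_len positions are
    # precomputed in fp32 by _prepare_cache and cast to the requested precision.
    rotary = RotaryEmbedding(dim=64, max_seq_len=2048, base=10000, precision=torch.half)

    # Dummy query tensor: 128 positions, batch 2, 8 heads, head_dim 64.
    q = torch.randn(128, 2, 8, 64, dtype=torch.half)

    # forward() slices the cached tables down to the requested length.
    cos, sin = rotary(q, seq_dim=0)
    print(cos.shape, sin.shape)  # torch.Size([128, 1, 1, 64]) for both

    # Asking for more positions than max_seq_len trips the new assert.
    try:
        rotary(torch.randn(4096, 2, 8, 64, dtype=torch.half), seq_dim=0)
    except AssertionError:
        print("seq_len > max_seq_len is rejected")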