Lower precision RoPE computation leads to training instability #1245

Open
viclzhu opened this issue Oct 12, 2024 · 2 comments
viclzhu commented Oct 12, 2024

Hi everyone,

When running a Llama3.1 training job with FSDP and BF16 mixed precision, we noticed a large gap in train/val loss between FP8 autocast enabled and disabled, which we do not see with Llama2.

For context, the main relevant difference between Llama3.1 and Llama2 here is the rotary embedding initialization: the RoPE rotary base value is changed from 10k to 500k.
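For reference, the inverse frequencies follow the standard RoPE formula; a minimal sketch (not the actual TE code in RotaryPositionEmbedding) shows how the base affects them:

import torch

def rope_inv_freq(dim: int, rotary_base: float) -> torch.Tensor:
    # Standard RoPE inverse frequencies: rotary_base^(-2i/dim) for i in [0, dim/2).
    return 1.0 / (rotary_base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))

# The larger base pushes the lowest frequencies much closer to zero,
# widening the dynamic range of the rotation angles across positions.
print(rope_inv_freq(128, 10_000)[-1])   # ~1.2e-4
print(rope_inv_freq(128, 500_000)[-1])  # ~2.5e-6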

After some investigation, we found that although the RoPE inverse frequencies are created in FP32 (and cast to FP32 during fused RoPE application), the query/key layers are in BF16 (due to mixed precision in our case) and are passed to tex.fused_rope_forward(t, freqs) as is (which I assume is intentional).

However, this appears to behave differently from passing both t and freqs explicitly as FP32, which I assume indicates some loss of precision in the fused RoPE kernel.

We can verify this with the following test script:

# TE main (after 1.11 at time of writing)
# needed for the tunable `rotary_base` change in `RotaryPositionEmbedding`
from transformer_engine.pytorch.attention import (
    apply_rotary_pos_emb,
    RotaryPositionEmbedding,
)
import torch

S = 2  # seqlen
B = 1  # bs
H = 1  # num attn heads per partition
D = 8  # hidden size per attn head

device = torch.device("cuda:0")
rope = RotaryPositionEmbedding(dim=D, rotary_base=10000)

# FP32 rope frequencies
rotary_emb = rope(max_seq_len=S)

# BF16 input tensor
t1 = torch.rand((S, B, H, D), dtype=torch.bfloat16, device=device)
# FP32 input tensor
t2 = t1.detach().clone().to(torch.float32)

t1_out = apply_rotary_pos_emb(t=t1, freqs=rotary_emb, fused=True).to(torch.float32)
t2_out = apply_rotary_pos_emb(t=t2, freqs=rotary_emb, fused=True)

# Assertion fails here
assert torch.allclose(t1_out, t2_out)

This precision loss appears to be benign for the smaller rotary_base=10k used by Llama2, but it is amplified with rotary_base=500k and amplified further when FP8 is enabled.
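To quantify the gap for both bases, the test above can be extended like this (just a sketch using the same TE API as above, with larger, more Llama-like shapes; exact numbers will vary):

import torch
from transformer_engine.pytorch.attention import (
    RotaryPositionEmbedding,
    apply_rotary_pos_emb,
)

S, B, H, D = 4096, 1, 1, 128
device = torch.device("cuda:0")

for base in (10_000, 500_000):
    rope = RotaryPositionEmbedding(dim=D, rotary_base=base)
    rotary_emb = rope(max_seq_len=S)  # FP32 frequencies
    t_bf16 = torch.rand((S, B, H, D), dtype=torch.bfloat16, device=device)
    t_fp32 = t_bf16.detach().clone().to(torch.float32)
    out_bf16 = apply_rotary_pos_emb(t=t_bf16, freqs=rotary_emb, fused=True).to(torch.float32)
    out_fp32 = apply_rotary_pos_emb(t=t_fp32, freqs=rotary_emb, fused=True)
    print(f"rotary_base={base}: max abs diff = {(out_bf16 - out_fp32).abs().max().item():.3e}")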

Llama with rotary_base=500k (blue = FP8 autocast off, orange = FP8 autocast on):

[loss curve image]

Llama with rotary_base=500k, with the query/key layers upcast to FP32 before FusedRoPEFunc.apply() (pink = FP8 autocast off, green = FP8 autocast on):

[loss curve image]

What I think is happening is that, even though the RoPE computation itself is not done in FP8, the precision loss is amplified and accumulates across layers/iterations when FP8 is enabled.

Thus, I think there should be an explicit upcast of the query/key layers to FP32, so the logic in MultiheadAttention.forward() would look something like this:

orig_q_dtype = query_layer.dtype
orig_k_dtype = key_layer.dtype

# Upcast qk to FP32.
query_layer = query_layer.to(torch.float32)
key_layer = key_layer.to(torch.float32)

query_layer = apply_rotary_pos_emb(query_layer, q_pos_emb, self.qkv_format, fused=True)
key_layer = apply_rotary_pos_emb(key_layer, k_pos_emb, self.qkv_format, fused=True)

# Cast qk back to orig dtype.
query_layer = query_layer.to(orig_q_dtype)
key_layer = key_layer.to(orig_k_dtype)

Alternatively, it could be an upcast_to_fp32 flag on apply_rotary_pos_emb; either approach works for me.
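For illustration, the flag version could be a thin wrapper along these lines (hypothetical helper, not an existing TE API; the tensor_format argument is passed positionally, mirroring the MultiheadAttention.forward() snippet above):

import torch
from transformer_engine.pytorch.attention import apply_rotary_pos_emb

def apply_rotary_pos_emb_fp32(t, freqs, tensor_format="sbhd", fused=True):
    # Hypothetical wrapper: run the existing RoPE path in FP32 and cast
    # the result back to the caller's dtype.
    orig_dtype = t.dtype
    out = apply_rotary_pos_emb(t.to(torch.float32), freqs, tensor_format, fused=fused)
    return out.to(orig_dtype)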

This should probably be applied in all cases, but if not all, then at least when FP8 is enabled, to mitigate its impact.

Although the context is slightly different, HuggingFace also appears to have seen similar precision issues with RoPE, which shows how impactful RoPE precision can be.

Note 1: I did not investigate the unfused RoPE case, as it is not used in MultiheadAttention.
Note 2: The Llama runs used TE v1.10 with a patched RotaryPositionEmbedding to allow tuning rotary_base.

What do you guys think?

Thanks!

yaox12 (Collaborator) commented Oct 12, 2024

Hi @viclzhu, thanks for your report.

In your test script, change t1_out and t2_out to

t1_out = apply_rotary_pos_emb(t=t1, freqs=rotary_emb, fused=True)
t2_out = apply_rotary_pos_emb(t=t2, freqs=rotary_emb, fused=True).to(torch.bfloat16)

and the assertion succeeds. I think this is closer to the real use case, where we usually pass BF16 q/k/v to the core attention, not FP32.

In fact, the fused RoPE kernel already works the way you suggest: it first casts the input tensor to FP32, does the computation in FP32, and stores the result back in the original dtype.

float v_src = src[offset_src];
float v_src_rotate = (d_id + d2 / 2 < d2)
                         ? -static_cast<float>(src[offset_src + (d2 / 2) * stride_d])
                         : static_cast<float>(src[offset_src + (d2 / 2 - d2) * stride_d]);
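In plain Python, the precision handling is roughly equivalent to the following unfused, rotate-half reference (just an emulation for clarity, not the actual kernel code):

import torch

def rope_reference(t: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
    # Emulate the fused kernel's precision handling: upcast to FP32,
    # rotate, and store the result back in the input dtype.
    # t: [S, B, H, D], freqs: [S, 1, 1, d2] rotation angles with d2 <= D.
    freqs = freqs.to(device=t.device, dtype=torch.float32)
    d2 = freqs.shape[-1]
    x, x_pass = t[..., :d2].float(), t[..., d2:].float()
    x1, x2 = x[..., : d2 // 2], x[..., d2 // 2 :]
    x_rotated = torch.cat((-x2, x1), dim=-1)  # "rotate half"
    out = x * freqs.cos() + x_rotated * freqs.sin()
    return torch.cat((out, x_pass), dim=-1).to(t.dtype)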

So I'm surprised that there is such a big difference in your loss curves.
@ksivaman Can you comment, since I remember you changed the RoPE computation to FP32 in #645?

viclzhu (Author) commented Oct 12, 2024

Thanks @yaox12!

Hm, you're right, the test should compare the outputs cast to BF16, and since those are close enough, there shouldn't be any impact from this.

Not sure why we're seeing this behavior then; maybe there's something else going on in our setup.
