Hi everyone,
When running a Llama3.1 training job with FSDP and BF16 mixed precision, we noticed a large gap in train/val loss between FP8 autocast enabled vs. disabled which we do not see with Llama2.
For context, the main difference between Llama3.1 and Llama2 is that the rotary embedding initialization is different (main relevant portion is that the RoPE rotary base value is changed from 10k to 500k).
After some investigation, we found that this is because although the RoPE inverse frequencies are created in FP32 (and cast to FP32 during fused RoPE application), the query/key layers are in BF16 (due to mixed precision in our case) and are passed to tex.fused_rope_forward(t, freqs) as is (which I assume is intentional). However, this appears to behave differently than passing both t and freqs explicitly as FP32, which I assume is some loss of precision in the fused RoPE kernel.
We can verify this with the following test script:
# TE main (after 1.11 at time of writing)
# needed for the tunable `rotary_base` change in `RotaryPositionEmbedding`
from transformer_engine.pytorch.attention import (
    apply_rotary_pos_emb,
    RotaryPositionEmbedding,
)
import torch
S = 2 # seqlen
B = 1 # bs
H = 1 # num attn heads per partition
D = 8 # hidden size per attn head
device = torch.device("cuda:0")
rope = RotaryPositionEmbedding(dim=D, rotary_base=10000)
# FP32 rope frequencies
rotary_emb = rope(max_seq_len=S)
# BF16 input tensor
t1 = torch.rand((S, B, H, D), dtype=torch.bfloat16, device=device)
# FP32 input tensor
t2 = t1.detach().clone().to(torch.float32)
t1_out = apply_rotary_pos_emb(t=t1, freqs=rotary_emb, fused=True).to(torch.float32)
t2_out = apply_rotary_pos_emb(t=t2, freqs=rotary_emb, fused=True)
# Assertion fails here
assert torch.allclose(t1_out, t2_out)
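As a small addition (not part of the original script), one can also print the maximum absolute difference instead of only asserting, to see how large the mismatch actually is:
# Not in the original script: quantify the mismatch before the assert.
max_abs_err = (t1_out - t2_out).abs().max().item()
print(f"max abs diff (BF16 input vs. FP32 input): {max_abs_err}")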
This precision loss appears to be tolerable at the smaller rotary_base=10k used by Llama2, but it is exaggerated at rotary_base=500k, and even more so when FP8 is enabled.
[Plot] Llama with rotary_base=500k: blue = FP8 autocast off, orange = FP8 autocast on.
[Plot] Llama with rotary_base=500k, with the query/key layers upcast to FP32 before FusedRoPEFunc.apply(): pink = FP8 autocast off, green = FP8 autocast on.
What I think is happening is that, even though the RoPE computation itself is not done in FP8, the precision loss gets amplified and accumulates across layers/iterations when FP8 is enabled.
Thus, I think there should be an explicit upcast of the query/key layers to FP32, so the logic in MultiheadAttention.forward() would look something like this:
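Roughly (a sketch only; query_layer, key_layer, q_pos_emb, and k_pos_emb are illustrative stand-ins for the corresponding tensors in MultiheadAttention.forward(), not the exact TE code):
# Sketch of the proposed change: upcast q/k to FP32 before applying RoPE,
# then cast back to the original dtype afterwards.
rope_dtype = query_layer.dtype
query_layer = apply_rotary_pos_emb(
    query_layer.to(torch.float32), q_pos_emb, fused=True
).to(rope_dtype)
key_layer = apply_rotary_pos_emb(
    key_layer.to(torch.float32), k_pos_emb, fused=True
).to(rope_dtype)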
It could also be an upcast_to_fp32 flag in apply_rotary_pos_emb; either way looks good to me. This should probably be applied in all cases, but if not, at least when FP8 is enabled, to mitigate its impact.
Although the situation is slightly different, HuggingFace also appears to have run into similar precision issues with RoPE, which just goes to show how impactful RoPE precision is.
Note 1: I did not investigate the unfused RoPE case, as it is not used in MultiheadAttention.
Note 2: The Llama runs were done with TE v1.10, with a patched RotaryPositionEmbedding to allow tuning rotary_base.
What do you guys think?
Thanks!
If we compare the outputs in BF16 instead, the assertion succeeds. I think this is closer to the real case, where we usually pass BF16 q/k/v to the core attention, not FP32.
In fact, the fused RoPE kernel already works the way you suggest: it first casts the input tensor to FP32, does the calculation in FP32, and stores the results back in the original dtype. So I'm surprised that there is such a big difference in your loss curves. @ksivaman Can you comment, as I remember you changed the RoPE computation to FP32 in #645?
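For illustration, here is an assumed variation of the test script from the issue (reusing t1, t2, and rotary_emb from above) that casts the FP32 result back to BF16 before comparing, i.e. compares in the dtype that actually reaches core attention:
# Assumed variation of the original test: compare in BF16, the dtype that
# actually reaches core attention.
t1_out_bf16 = apply_rotary_pos_emb(t=t1, freqs=rotary_emb, fused=True)
t2_out_bf16 = apply_rotary_pos_emb(t=t2, freqs=rotary_emb, fused=True).to(torch.bfloat16)
# Expected to pass, since the fused kernel computes in FP32 internally and
# only the final store differs in dtype.
assert torch.allclose(t1_out_bf16.float(), t2_out_bf16.float())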