[bug report] h100 attn_causal kernel #33
It turns out that a wgmma fence needs to be added after the wgmma async wait. I have created a PR for your reference: #34
Thanks for sharing your PR @xiayuqing0622. There should be a wgmma async_wait after the wgmma is committed. As you know, the wgmma API launches an async matrix multiply on the H100 tensor cores across the 4 warps in the warpgroup via the exposed commit_group function. To ensure the multiply has completed, you need to call async_wait on it; this was missing from the original code and has now been added to the relevant .cu file on the main branch. Regarding your PR, the wgmma fence/syncthreads() should not be necessary for correctness: the fence is needed on the output register tile before you launch the wgmma async instruction, and the syncthreads() will likely slow performance unnecessarily. Does the latest fix on main resolve the randomness you were seeing?
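The ordering requirement behind the fix is the generic async-producer contract: work launched asynchronously must be explicitly waited on before its result is read. A toy Python threading sketch (an analogy only, not the CUDA wgmma API; all names here are hypothetical) shows how skipping the wait shows up as run-to-run nondeterminism:

```python
import threading
import time

def async_matmul_analogy(wait: bool) -> int:
    """Toy analogy in plain Python threading (NOT the CUDA API):
    a worker thread plays the role of the tensor cores, writing its
    result asynchronously. The consumer must wait (join) before
    reading, or it may observe a stale value."""
    result = {"value": 0}          # stale placeholder, like an uninitialized accumulator

    def worker():
        time.sleep(0.01)           # pretend the matrix multiply takes a while
        result["value"] = 42       # the completed product

    t = threading.Thread(target=worker)
    t.start()                      # analogous to commit_group: work is now in flight
    if wait:
        t.join()                   # analogous to async_wait: block until the work is done
    return result["value"]
```

With `wait=True` the call always returns 42; without the join, the read races against the worker and usually observes the stale 0, but not always, which is the same intermittent-mismatch symptom reported in this issue.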
@Aaryan0404 Thanks for your reply. Actually, after reading the CUDA documentation, I also believe the wgmma fence/syncthreads() is not necessary. However, just adding the async wait does not fix the randomness (I just tested it on the latest main branch), and I don't know why. Here is my test script (h100_fwd_check.py with a debug function added):

```python
import sys
import os
import time
from collections import defaultdict
from statistics import median

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange

current_dir = os.path.dirname(os.path.abspath(__file__))
project_root = os.path.abspath(os.path.join(current_dir, "../../"))
sys.path.insert(0, project_root)
from src.common.pyutils.test_build_utils import __eq
sys.path.append('build/lib.linux-x86_64-cpython-312')
import h100_fwd as mod

torch.manual_seed(0)

def debug(name, expect, actual, atol=1e-3, rtol=1e-3):
    all_close = torch.allclose(expect, actual, atol=atol, rtol=rtol)
    print(name + " all_close={}".format(all_close))
    if not all_close:
        diff = (expect - actual).abs()
        print("all_close={}, max={}, min={}, mean={}".format(
            all_close, diff.max().item(), diff.min().item(), diff.mean().item()))
        max_indices = torch.nonzero(diff == diff.max().item())
        first_index = tuple(max_indices[0].tolist())
        print(f"Index: {first_index}, expect: {expect[first_index]}, actual: {actual[first_index]}")

def pytorch_test(Q, K, V):
    output = torch.nn.functional.scaled_dot_product_attention(Q, K, V, is_causal=True)
    return output

def h100_fwd_kernel_test(Q, K, V):
    o = torch.zeros_like(Q)
    mod.attention_forward_causal(Q, K, V, o)
    return o

def check_correctness(b, h, n, d):
    print(f"Testing with b={b}, h={h}, n={n}, d={d}")
    Q = torch.randn(b, h, n, d, dtype=torch.bfloat16, device='cuda').contiguous()
    K = torch.randn(b, h, n, d, dtype=torch.bfloat16, device='cuda').contiguous()
    V = torch.randn(b, h, n, d, dtype=torch.bfloat16, device='cuda').contiguous()
    result_pytorch = pytorch_test(Q, K, V)
    tk_result = h100_fwd_kernel_test(Q, K, V)
    diff = result_pytorch - tk_result
    avg_diff_mag = torch.mean(torch.abs(diff)).item()
    avg_diff_per = 100 * avg_diff_mag / torch.mean(torch.abs(result_pytorch)).item()
    print(f"Attention output - avg magnitude of diff: {avg_diff_mag:.6f}")
    print("-" * 40)
    debug("Attention output", result_pytorch, tk_result)

print("Correctness Tests: ")
configurations = [
    # (2, 8, 256, 64),
    # (4, 8, 512, 64),
    # (8, 8, 1024, 64),
    # (16, 8, 2048, 64),
    # (16, 8, 4096, 64),
    # (16, 8, 8192, 64),
    # (16, 8, 16384, 64),
    # (2, 8, 256, 128),
    # (4, 8, 512, 128),
    # (8, 8, 1024, 128),
    # (16, 8, 2048, 128),
    # (16, 8, 4096, 128),
    (16, 8, 8192, 128),
    # (16, 8, 16384, 128)
]
for b, h, n, d in configurations:
    check_correctness(b, h, n, d)
```
With the same random seed, the results of the tk h100 attn_causal kernel vary from run to run. In some cases, the max diff between the tk and PyTorch results can be larger than 2.
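One way to separate a synchronization race like this from an ordinary numerical-tolerance issue is to run the kernel repeatedly on bitwise-identical inputs and require bitwise-identical outputs. A minimal sketch of such a harness (`check_determinism` and `flaky` are hypothetical names, not part of this repo; with torch tensors you would compare via `torch.equal` rather than `!=`):

```python
def check_determinism(fn, inputs, runs=5):
    """Call `fn` several times on the exact same inputs and report
    whether every run reproduces the first run's output.
    A deterministic kernel returns bitwise-identical output every time;
    a kernel with a missing synchronization typically does not."""
    reference = fn(*inputs)
    for i in range(1, runs):
        if fn(*inputs) != reference:   # for torch tensors: not torch.equal(out, reference)
            return False, i            # first run that disagreed
    return True, runs

# Hypothetical demo: a deterministic function passes the check, while a
# stateful "flaky" one (standing in for a racy kernel) fails immediately.
calls = {"n": 0}
def flaky(x):
    calls["n"] += 1
    return x + (calls["n"] % 2)        # alternates between x+1 and x+0

print(check_determinism(lambda x: 2 * x, (3,)))  # → (True, 5)
print(check_determinism(flaky, (3,)))            # → (False, 1)
```

Applied to this issue, calling `h100_fwd_kernel_test` through such a loop on fixed Q/K/V would confirm that the mismatch is run-to-run nondeterminism in the kernel itself rather than an accumulation-order difference versus PyTorch.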