[bug report] h100 attn_causal kernel #33

Open
xiayuqing0622 opened this issue May 22, 2024 · 3 comments
Comments

@xiayuqing0622

Using the same random seed, the results of the tk h100 attn_causal kernel vary from run to run. In some cases, the max diff between the tk and PyTorch results can be larger than 2.

@xiayuqing0622
Author

It turns out that a wgmma fence needs to be added after the wgmma async wait. I have created a PR for your reference: #34

@Aaryan0404
Collaborator

Thanks for sharing your PR @xiayuqing0622.

There should be a wgmma async_wait after the wgmma is committed. As you know, the wgmma API launches an async matrix multiply on the H100 tensor cores across the 4 warps in the warpgroup via the exposed commit_group function. To ensure the multiply has completed, you need to call async_wait on it - this was missing from the original code and has now been added to the relevant .cu file in the main branch.

Re your PR, the wgmma fence/syncthreads() should not be necessary for correctness. The fence is needed on the output register tile before you launch the wgmma async instruction, not after the wait. Furthermore, the syncthreads() will likely slow performance unnecessarily.
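
For reference, here is a minimal sketch of the ordering the PTX ISA describes for Hopper wgmma: fence before issuing the async MMA (so register writes by regular code are visible to it), commit_group after issuing, and wait_group before the accumulator registers are touched again. The helper names below are illustrative, not the ThunderKittens API, and the wgmma.mma_async operand list is omitted.

// Illustrative inline-PTX helpers (names are mine); requires an sm_90a target.
__device__ inline void wgmma_fence() {
    // Orders prior register accesses before a following wgmma.mma_async reads them.
    asm volatile("wgmma.fence.sync.aligned;" ::: "memory");
}
__device__ inline void wgmma_commit_group() {
    // Commits all previously issued, uncommitted wgmma.mma_async ops as one group.
    asm volatile("wgmma.commit_group.sync.aligned;" ::: "memory");
}
template<int N>
__device__ inline void wgmma_wait_group() {
    // Blocks until at most N committed groups are still pending (N = 0 waits for all).
    asm volatile("wgmma.wait_group.sync.aligned %0;" :: "n"(N) : "memory");
}

// Intended sequence within the warpgroup:
//   wgmma_fence();             // before the wgmma that reads registers written by normal code
//   /* wgmma.mma_async ... */  // the actual MMA issue; operands omitted here
//   wgmma_commit_group();
//   wgmma_wait_group<0>();     // only after this is the accumulator safe to read
//   /* read/rescale the accumulator tile, e.g. the softmax update in the attention kernel */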

Does the latest fix on main resolve the randomness you were seeing?

@xiayuqing0622
Author


@Aaryan0404 Thanks for your reply. Actually, after reading the document of cuda, I also believe the wgmma fence/syncthreads() is not necessary. However, just adding async wait does not fix the randomness (I just tested it on the latest main branch). I don't know why. Here is my test script (just add a debug function in h100_fwd_check.py):

import os
import sys

import torch

current_dir = os.path.dirname(os.path.abspath(__file__))
project_root = os.path.abspath(os.path.join(current_dir, "../../"))
sys.path.insert(0, project_root)
from src.common.pyutils.test_build_utils import __eq
sys.path.append('build/lib.linux-x86_64-cpython-312')
import h100_fwd as mod

torch.manual_seed(0)

def debug(name, expect, actual, atol=1e-3, rtol=1e-3):
    # Report whether the two tensors match within tolerance; if not, print the
    # magnitude of the difference and the location of the worst mismatch.
    all_close = torch.allclose(expect, actual, atol=atol, rtol=rtol)
    print("{}  all_close={}".format(name, all_close))
    if not all_close:
        diff = (expect - actual).abs()
        print("max={}, min={}, mean={}".format(diff.max().item(), diff.min().item(), diff.mean().item()))
        max_indices = torch.nonzero(diff == diff.max().item())
        first_index = tuple(max_indices[0].tolist())
        print(f"Index: {first_index}, expect: {expect[first_index]}, actual: {actual[first_index]}")


def pytorch_test(Q, K, V):
    output = torch.nn.functional.scaled_dot_product_attention(Q, K, V, is_causal=True)
    return output

def h100_fwd_kernel_test(Q, K, V):
    o = torch.zeros_like(Q)
    mod.attention_forward_causal(Q, K, V, o)
    return o

def check_correctness(b, h, n, d):
    print(f"Testing with b={b}, h={h}, n={n}, d={d}")

    Q = torch.randn(b, h, n, d, dtype=torch.bfloat16, device='cuda').contiguous()
    K = torch.randn(b, h, n, d, dtype=torch.bfloat16, device='cuda').contiguous()
    V = torch.randn(b, h, n, d, dtype=torch.bfloat16, device='cuda').contiguous()

    result_pytorch = pytorch_test(Q, K, V)
    tk_result = h100_fwd_kernel_test(Q, K, V)

    # Compare the tk kernel output against the PyTorch reference.
    diff = result_pytorch - tk_result
    avg_diff_mag = torch.mean(torch.abs(diff)).item()
    avg_diff_per = 100 * avg_diff_mag / torch.mean(torch.abs(result_pytorch)).item()

    print(f"Attention output - avg magnitude of diff: {avg_diff_mag:.6f} ({avg_diff_per:.4f}% of mean |output|)")
    print("-" * 40)
    debug("Attention output", result_pytorch, tk_result)

print("Correctness Tests: ")
configurations = [
    # (2,  8, 256,   64),
    # (4,  8, 512,   64),
    # (8,  8, 1024,  64),
    # (16, 8, 2048,  64),
    # (16, 8, 4096,  64),
    # (16, 8, 8192,  64),
    # (16, 8, 16384, 64),
    # (2,  8, 256,   128),
    # (4,  8, 512,   128),
    # (8,  8, 1024,  128),
    # (16, 8, 2048,  128),
    # (16, 8, 4096,  128),
    (16, 8, 8192,  128),
    # (16, 8, 16384, 128)
]
for b, h, n, d in configurations:
    check_correctness(b, h, n, d)
