Couple of FA optimizations (#608)

Set the SM scale multiplication to a constexpr, which yields a minor asm improvement.
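
A minimal sketch of the constexpr pattern, outside the attention kernel (the kernel name, pointer arguments, and BLOCK size below are made up for illustration): marking the scale as tl.constexpr lets Triton fold the log2(e) multiplication at compile time, so the scaled value ends up as a literal in the generated code instead of a runtime multiply.

    import triton
    import triton.language as tl

    @triton.jit
    def scale_kernel(X, Y, N, SM_SCALE: tl.constexpr, BLOCK: tl.constexpr):
        # Both operands are constexpr, so this product is folded at compile time.
        QK_SCALE: tl.constexpr = SM_SCALE * 1.44269504089
        offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
        mask = offs < N
        x = tl.load(X + offs, mask=mask, other=0.0)
        tl.store(Y + offs, x * QK_SCALE, mask=mask)

Note that constexpr arguments are part of the kernel specialization, so a different sm_scale value means a recompile; that is typically fine for attention, where the scale is fixed by the head dimension.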

Changed the acc scaling in the epilogue from dividing by the softmax normalizer to
multiplying by its reciprocal. ~10% perf improvement.
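
The transformation itself is small; a NumPy sketch (tile sizes and names below are illustrative) shows the numerical equivalence. The speedup comes from the GPU side: the expensive reciprocal (Newton-Raphson) now touches only the BLOCK_M row sums instead of every element of the BLOCK_M x BLOCK_DMODEL accumulator.

    import numpy as np

    # Stand-ins for the BLOCK_M x BLOCK_DMODEL accumulator and per-row softmax sums.
    acc = np.random.rand(128, 64).astype(np.float32)
    l_i = (np.random.rand(128) + 1.0).astype(np.float32)

    # Before: divide every accumulator element by its row's normalizer.
    out_before = acc / l_i[:, None]

    # After: one reciprocal per row, then cheap multiplies across the tile.
    l_recip = 1.0 / l_i[:, None]
    out_after = acc * l_recip

    assert np.allclose(out_before, out_after, rtol=1e-6)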

---------

Co-authored-by: Michael Melesse <[email protected]>
vgokhale and micmelesse authored Jul 19, 2024
1 parent 23ba546 commit df4c4d3
Showing 1 changed file with 18 additions and 21 deletions.
39 changes: 18 additions & 21 deletions python/perf-kernels/flash-attention.py
@@ -301,35 +301,28 @@ def _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stri

@triton.autotune(
configs=[
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1,
num_warps=8),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1,
num_warps=4),
triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1,
num_warps=8),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'PRE_LOAD_V': True}, num_stages=1,
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1,
num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'PRE_LOAD_V': False}, num_stages=1,
num_warps=4),
# triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'PRE_LOAD_V': False}, num_stages=1,
# num_warps=4),
triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'waves_per_eu': 4, 'PRE_LOAD_V': False}, num_stages=1,
num_warps=8),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1,
num_warps=4),
triton.Config({'BLOCK_M': 32, 'BLOCK_N': 32, 'waves_per_eu': 4, 'PRE_LOAD_V': False}, num_stages=1,
num_warps=8),
# TODO: This config fails with data mismatches when head_size is not a power of 2. Check why.
# triton.Config({'BLOCK_M': 32, 'BLOCK_N': 16, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, num_warps=4),
triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1,
num_warps=4),
# Fall-back config.
triton.Config({'BLOCK_M': 16, 'BLOCK_N': 16, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1,
num_warps=4),
],
key=['IS_CAUSAL', 'dropout_p', 'BLOCK_DMODEL'],
key=['IS_CAUSAL', 'dropout_p', 'MAX_SEQLENS_Q', 'MAX_SEQLENS_K', 'ACTUAL_BLOCK_DMODEL', 'VARLEN', 'HQ', 'HK'],
use_cuda_graph=True,
)
@triton.jit
def attn_fwd(Q, K, V, bias, sm_scale, L, Out, stride_qz, stride_qh, stride_qm, stride_qk, stride_kz, stride_kh,
stride_kn, stride_kk, stride_vz, stride_vh, stride_vk, stride_vn, stride_oz, stride_oh, stride_om,
stride_on, stride_bz, stride_bh, stride_bm, stride_bn, stride_az, stride_ah, cu_seqlens_q, cu_seqlens_k,
dropout_p, philox_seed, philox_offset_base, encoded_softmax, alibi_slopes, HQ: tl.constexpr,
def attn_fwd(Q, K, V, bias, SM_SCALE: tl.constexpr, L, Out, stride_qz, stride_qh, stride_qm, stride_qk, stride_kz,
stride_kh, stride_kn, stride_kk, stride_vz, stride_vh, stride_vk, stride_vn, stride_oz, stride_oh,
stride_om, stride_on, stride_bz, stride_bh, stride_bm, stride_bn, stride_az, stride_ah, cu_seqlens_q,
cu_seqlens_k, dropout_p, philox_seed, philox_offset_base, encoded_softmax, alibi_slopes, HQ: tl.constexpr,
HK: tl.constexpr, ACTUAL_BLOCK_DMODEL: tl.constexpr, MAX_SEQLENS_Q: tl.constexpr,
MAX_SEQLENS_K: tl.constexpr, VARLEN: tl.constexpr, IS_CAUSAL: tl.constexpr, BLOCK_M: tl.constexpr,
BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, PRE_LOAD_V: tl.constexpr, USE_BIAS: tl.constexpr,
@@ -446,13 +439,13 @@ def attn_fwd(Q, K, V, bias, sm_scale, L, Out, stride_qz, stride_qh, stride_qm, s
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
# scale sm_scale by log_2(e) and use 2^x in the loop as we do not
# have native e^x support in HW.
qk_scale = sm_scale * 1.44269504089
QK_SCALE: tl.constexpr = SM_SCALE * 1.44269504089
# Q is loaded once at the beginning and shared by all N blocks.
q_ptrs_mask = offs_m[:, None] < seqlen_q
if PADDED_HEAD:
q_ptrs_mask = q_ptrs_mask & (offs_d[None, :] < ACTUAL_BLOCK_DMODEL)
q = tl.load(q_ptrs, mask=q_ptrs_mask, other=0.0)
q = (q * qk_scale).to(q.type.element_ty)
q = (q * QK_SCALE).to(q.type.element_ty)

# Here we compute how many full and masked blocks we have.
padded_block_k = n_extra_tokens != 0
@@ -509,7 +502,10 @@ def attn_fwd(Q, K, V, bias, sm_scale, L, Out, stride_qz, stride_qh, stride_qm, s
PRE_LOAD_V, True, ENABLE_DROPOUT, RETURN_ENCODED_SOFTMAX, PADDED_HEAD,
ACTUAL_BLOCK_DMODEL)
# epilogue
acc = acc / l_i[:, None]
# Computing the reciprocal first lets the compiler do Newton-Raphson on l_i rather than on acc, which is much larger.
l_recip = 1 / l_i[:, None]
acc = acc * l_recip

if ENABLE_DROPOUT:
acc = acc / (1 - dropout_p)
# If seqlen_q > seqlen_k but the delta is not a multiple of BLOCK_M,
@@ -1198,6 +1194,7 @@ def test_op_fwd_bias(Z, H, N_CTX_Q, N_CTX_K, D_HEAD, causal, use_bias, dtype):
# this by converting the NaNs to 0s, which is what they should be out of the softmax.
nan_mask = torch.isnan(p)
p[nan_mask == 1] = 0

ref_out = torch.einsum('bhqk,bhkd->bhqd', p.to(dtype), v)
# compare
torch.testing.assert_close(ref_out, tri_out, atol=2e-2, rtol=2e-2)
