Add SageAttention kernel and test code #16

Open · wants to merge 1 commit into main
2 changes: 2 additions & 0 deletions kernels/__init__.py
@@ -2,6 +2,7 @@
from . import blocksparse
from .cross_entropy import _cross_entropy, cross_entropy
from .flash_attention import attention
from .sage_attention import sageattn
from .matmul import _matmul, get_higher_dtype, matmul

__all__ = [
@@ -11,5 +12,6 @@
"_matmul",
"matmul",
"attention",
"sageattn",
"get_higher_dtype",
]
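
With this export, the kernel becomes available at package level, e.g. (assuming the package is imported as `kernels`):

    from kernels import sageattn
    out = sageattn(q, k, v, tensor_layout="HND", is_causal=True)
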
287 changes: 287 additions & 0 deletions kernels/sage_attention.py
@@ -0,0 +1,287 @@
from typing import Any, Optional

import torch
import triton
import triton.language as tl

@triton.jit
def quant_per_block_int8_kernel(Input, Output, Scale, L,
stride_iz, stride_ih, stride_in,
stride_oz, stride_oh, stride_on,
stride_sz, stride_sh,
sm_scale,
C: tl.constexpr, BLK: tl.constexpr):
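    # Quantize one (BLK, C) tile per program: the fp32 scale is max(|x|) / 127
    # and values are rounded to the nearest int8. For Q, the caller folds
    # sm_scale * log2(e) into sm_scale so the attention kernel can use exp2().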
off_blk = tl.program_id(0)
off_h = tl.program_id(1)
off_b = tl.program_id(2)
offs_n = off_blk * BLK + tl.arange(0, BLK)
offs_k = tl.arange(0, C)
input_ptrs = Input + off_b * stride_iz + off_h * stride_ih + offs_n[:, None] * stride_in + offs_k[None, :]
output_ptrs = Output + off_b * stride_oz + off_h * stride_oh + offs_n[:, None] * stride_on + offs_k[None, :]
scale_ptrs = Scale + off_b * stride_sz + off_h * stride_sh + off_blk
x = tl.load(input_ptrs, mask=offs_n[:, None] < L)
x = x.to(tl.float32)
x *= sm_scale
scale = tl.max(tl.abs(x)) / 127.
x_int8 = x / scale
x_int8 += 0.5 * tl.where(x_int8 >= 0, 1, -1)
x_int8 = x_int8.to(tl.int8)
tl.store(output_ptrs, x_int8, mask=offs_n[:, None] < L)
tl.store(scale_ptrs, scale)

def per_block_int8(q, k, BLKQ=128, BLKK=64, sm_scale=None, tensor_layout="HND"):
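    # Quantize q and k to int8 with one fp32 scale per (BLKQ or BLKK, head_dim)
    # block. The softmax scale (times log2(e)) is baked into q's values; k is
    # quantized with sm_scale=1.0.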
q_int8 = torch.empty(q.shape, dtype=torch.int8, device=q.device)
k_int8 = torch.empty(k.shape, dtype=torch.int8, device=k.device)
if tensor_layout == "HND":
b, h_qo, qo_len, head_dim = q.shape
_, h_kv, kv_len, _ = k.shape
stride_bz_q, stride_h_q, stride_seq_q = q.stride(0), q.stride(1), q.stride(2)
stride_bz_qo, stride_h_qo, stride_seq_qo = q_int8.stride(0), q_int8.stride(1), q_int8.stride(2)
stride_bz_k, stride_h_k, stride_seq_k = k.stride(0), k.stride(1), k.stride(2)
stride_bz_ko, stride_h_ko, stride_seq_ko = k_int8.stride(0), k_int8.stride(1), k_int8.stride(2)
elif tensor_layout == "NHD":
b, qo_len, h_qo, head_dim = q.shape
_, kv_len, h_kv, _ = k.shape
stride_bz_q, stride_h_q, stride_seq_q = q.stride(0), q.stride(2), q.stride(1)
stride_bz_qo, stride_h_qo, stride_seq_qo = q_int8.stride(0), q_int8.stride(2), q_int8.stride(1)
stride_bz_k, stride_h_k, stride_seq_k = k.stride(0), k.stride(2), k.stride(1)
stride_bz_ko, stride_h_ko, stride_seq_ko = k_int8.stride(0), k_int8.stride(2), k_int8.stride(1)
else:
raise ValueError(f"Unknown tensor layout: {tensor_layout}")

q_scale = torch.empty((b, h_qo, (qo_len + BLKQ - 1) // BLKQ, 1), device=q.device, dtype=torch.float32)
k_scale = torch.empty((b, h_kv, (kv_len + BLKK - 1) // BLKK, 1), device=q.device, dtype=torch.float32)

if sm_scale is None:
sm_scale = head_dim**-0.5

grid = ((qo_len + BLKQ - 1) // BLKQ, h_qo, b)
quant_per_block_int8_kernel[grid](
q, q_int8, q_scale, qo_len,
stride_bz_q, stride_h_q, stride_seq_q,
stride_bz_qo, stride_h_qo, stride_seq_qo,
q_scale.stride(0), q_scale.stride(1),
sm_scale=(sm_scale * 1.44269504),
C=head_dim, BLK=BLKQ
)

grid = ((kv_len + BLKK - 1) // BLKK, h_kv, b)
quant_per_block_int8_kernel[grid](
k, k_int8, k_scale, kv_len,
stride_bz_k, stride_h_k, stride_seq_k,
stride_bz_ko, stride_h_ko, stride_seq_ko,
k_scale.stride(0), k_scale.stride(1),
sm_scale=1.0,
C=head_dim, BLK=BLKK
)

return q_int8, q_scale, k_int8, k_scale


@triton.jit
def _attn_fwd_inner(acc, l_i, m_i, q, q_scale, kv_len,
K_ptrs, K_scale_ptr, V_ptrs, stride_kn, stride_vn,
start_m,
BLOCK_M: tl.constexpr, HEAD_DIM: tl.constexpr, BLOCK_N: tl.constexpr,
STAGE: tl.constexpr, offs_m: tl.constexpr, offs_n: tl.constexpr,
):
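    # FlashAttention-style online-softmax loop over K/V blocks. QK^T is computed
    # on int8 tiles and rescaled by the per-block q/k scales; the running max
    # (m_i), denominator (l_i) and accumulator (acc) are updated each iteration.
    # exp2() is used because log2(e) was folded into the quantized Q.
    # STAGE 1: blocks strictly below the diagonal block; STAGE 2: the diagonal
    # block with the causal mask applied; STAGE 3: the full range (non-causal).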
if STAGE == 1:
lo, hi = 0, start_m * BLOCK_M
elif STAGE == 2:
lo, hi = start_m * BLOCK_M, (start_m + 1) * BLOCK_M
lo = tl.multiple_of(lo, BLOCK_M)
K_scale_ptr += lo // BLOCK_N
K_ptrs += stride_kn * lo
V_ptrs += stride_vn * lo
elif STAGE == 3:
lo, hi = 0, kv_len
for start_n in range(lo, hi, BLOCK_N):
start_n = tl.multiple_of(start_n, BLOCK_N)
k_mask = offs_n[None, :] < (kv_len - start_n)
k = tl.load(K_ptrs, mask = k_mask)
k_scale = tl.load(K_scale_ptr)
qk = tl.dot(q, k).to(tl.float32) * q_scale * k_scale

if STAGE == 2:
mask = offs_m[:, None] >= (start_n + offs_n[None, :])
qk = qk + tl.where(mask, 0, -1.0e6)
m_ij = tl.maximum(m_i, tl.max(qk, 1))
qk -= m_ij[:, None]
else:
m_ij = tl.maximum(m_i, tl.max(qk, 1))
qk = qk - m_ij[:, None]
p = tl.math.exp2(qk)
l_ij = tl.sum(p, 1)
alpha = tl.math.exp2(m_i - m_ij)
l_i = l_i * alpha + l_ij
acc = acc * alpha[:, None]
v = tl.load(V_ptrs, mask = offs_n[:, None] < (kv_len - start_n))
p = p.to(tl.float16)
acc += tl.dot(p, v, out_dtype=tl.float16)
m_i = m_ij
K_ptrs += BLOCK_N * stride_kn
K_scale_ptr += 1
V_ptrs += BLOCK_N * stride_vn
return acc, l_i, m_i

@triton.jit
def _attn_fwd(Q, K, V, Q_scale, K_scale, Out,
stride_qz, stride_qh, stride_qn,
stride_kz, stride_kh, stride_kn,
stride_vz, stride_vh, stride_vn,
stride_oz, stride_oh, stride_on,
qo_len, kv_len, H:tl.constexpr, num_kv_groups:tl.constexpr,
HEAD_DIM: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
STAGE: tl.constexpr
):
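    # One program handles one (query block, head, batch) triple. K/V heads are
    # shared across num_kv_groups query heads (grouped-query attention), hence
    # the off_h // num_kv_groups indexing for K, V and their scales. For causal
    # attention (STAGE=3) the inner loop runs an off-diagonal pass followed by
    # a masked diagonal pass; the accumulator is finally normalized by l_i.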
start_m = tl.program_id(0)
off_z = tl.program_id(2).to(tl.int64)
off_h = tl.program_id(1).to(tl.int64)
q_scale_offset = (off_z * H + off_h) * tl.cdiv(qo_len, BLOCK_M)
k_scale_offset = (off_z * (H // num_kv_groups) + off_h // num_kv_groups) * tl.cdiv(kv_len, BLOCK_N)
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n = tl.arange(0, BLOCK_N)
offs_k = tl.arange(0, HEAD_DIM)
Q_ptrs = Q + (off_z * stride_qz + off_h * stride_qh) + offs_m[:, None] * stride_qn + offs_k[None, :]
Q_scale_ptr = Q_scale + q_scale_offset + start_m
K_ptrs = K + (off_z * stride_kz + (off_h // num_kv_groups) * stride_kh) + offs_n[None, :] * stride_kn + offs_k[:, None]
K_scale_ptr = K_scale + k_scale_offset
V_ptrs = V + (off_z * stride_vz + (off_h // num_kv_groups) * stride_vh) + offs_n[:, None] * stride_vn + offs_k[None, :]
O_block_ptr = Out + (off_z * stride_oz + off_h * stride_oh) + offs_m[:, None] * stride_on + offs_k[None, :]
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0
acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32)
q = tl.load(Q_ptrs, mask = offs_m[:, None] < qo_len)
q_scale = tl.load(Q_scale_ptr)
acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, q_scale, kv_len, K_ptrs, K_scale_ptr, V_ptrs, stride_kn, stride_vn,
start_m,
BLOCK_M, HEAD_DIM, BLOCK_N,
4 - STAGE, offs_m, offs_n
)
if STAGE != 1:
acc, l_i, _ = _attn_fwd_inner(acc, l_i, m_i, q, q_scale, kv_len, K_ptrs, K_scale_ptr, V_ptrs, stride_kn, stride_vn,
start_m,
BLOCK_M, HEAD_DIM, BLOCK_N,
2, offs_m, offs_n
)
acc = acc / l_i[:, None]
tl.store(O_block_ptr, acc.to(Out.type.element_ty), mask = (offs_m[:, None] < qo_len))


def forward(q, k, v, q_scale, k_scale, is_causal, tensor_layout="HND", output_dtype=torch.float16):
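    # Host-side launcher. Q is tiled in 128-row blocks and K/V in 64-row blocks,
    # matching the BLKQ/BLKK defaults used for quantization so that exactly one
    # scale is loaded per tile inside the kernel.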
BLOCK_M = 128
BLOCK_N = 64
    stage = 3 if is_causal else 1

o = torch.empty(q.shape, dtype=output_dtype, device=q.device)

if tensor_layout == "HND":
b, h_qo, qo_len, head_dim = q.shape
_, h_kv, kv_len, _ = k.shape
stride_bz_q, stride_h_q, stride_seq_q = q.stride(0), q.stride(1), q.stride(2)
stride_bz_k, stride_h_k, stride_seq_k = k.stride(0), k.stride(1), k.stride(2)
stride_bz_v, stride_h_v, stride_seq_v = v.stride(0), v.stride(1), v.stride(2)
stride_bz_o, stride_h_o, stride_seq_o = o.stride(0), o.stride(1), o.stride(2)
elif tensor_layout == "NHD":
b, qo_len, h_qo, head_dim = q.shape
_, kv_len, h_kv, _ = k.shape
stride_bz_q, stride_h_q, stride_seq_q = q.stride(0), q.stride(2), q.stride(1)
stride_bz_k, stride_h_k, stride_seq_k = k.stride(0), k.stride(2), k.stride(1)
stride_bz_v, stride_h_v, stride_seq_v = v.stride(0), v.stride(2), v.stride(1)
stride_bz_o, stride_h_o, stride_seq_o = o.stride(0), o.stride(2), o.stride(1)
else:
raise ValueError(f"tensor_layout {tensor_layout} not supported")

assert qo_len == kv_len, "qo_len and kv_len must be equal for causal attention"

HEAD_DIM_K = head_dim
num_kv_groups = h_qo // h_kv

    grid = (triton.cdiv(qo_len, BLOCK_M), h_qo, b)
_attn_fwd[grid](
q, k, v, q_scale, k_scale, o,
stride_bz_q, stride_h_q, stride_seq_q,
stride_bz_k, stride_h_k, stride_seq_k,
stride_bz_v, stride_h_v, stride_seq_v,
stride_bz_o, stride_h_o, stride_seq_o,
qo_len, kv_len,
h_qo, num_kv_groups,
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, HEAD_DIM=HEAD_DIM_K,
STAGE=stage,
num_warps=4 if head_dim == 64 else 8,
num_stages=4)
return o


def sageattn(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
    tensor_layout: str = "HND",
    is_causal: bool = False,
    sm_scale: Optional[float] = None,
    smooth_k: bool = True,
**kwargs: Any,
) -> torch.Tensor:
"""
Parameters
----------
q : torch.Tensor
The query tensor. Shape:
- If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``.
- If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``.
k : torch.Tensor
The key tensor. Shape:
- If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``.
- If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``.
v : torch.Tensor
The value tensor. Shape:
- If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``.
- If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``.
tensor_layout : str
The tensor layout, either "HND" or "NHD".
Default: "HND".
is_causal : bool
Whether to apply causal mask to the attention matrix. Only applicable when qo_len == kv_len.
Default: False.
sm_scale : Optional[float]
        The scale used in softmax; if not provided, it will be set to ``1.0 / sqrt(head_dim)``.
smooth_k : bool
Whether to smooth the key tensor by subtracting the mean along the sequence dimension.
Default: True.
Returns
-------
torch.Tensor
The output tensor. Shape:
- If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``.
- If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``.
Note
----
- ``num_qo_heads`` must be divisible by ``num_kv_heads``.
    - The tensors `q`, `k`, and `v` must have dtype ``torch.float16`` or ``torch.bfloat16``.
    - All tensors must be on the same CUDA device.
"""

dtype = q.dtype
headdim = q.size(-1)
assert headdim in [64, 128], "headdim should be in [64, 128]."

# assert last dim is contiguous
assert q.stride(-1) == 1 and k.stride(-1) == 1 and v.stride(-1) == 1, "Last dim of qkv must be contiguous."
seq_dim = 1 if tensor_layout == "NHD" else 2

if smooth_k:
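        # Subtracting the mean of K along the sequence dimension shifts each
        # row of QK^T by a query-dependent constant, which leaves softmax
        # unchanged but reduces the dynamic range the INT8 quantizer must cover.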
km = k.mean(dim=seq_dim, keepdim=True)
        k = k - km  # avoid modifying the caller's key tensor in place
else:
km = None
if dtype == torch.bfloat16:
v = v.to(torch.float16)

q_int8, q_scale, k_int8, k_scale = per_block_int8(q, k, sm_scale=sm_scale, tensor_layout=tensor_layout)
    o = forward(q_int8, k_int8, v, q_scale, k_scale, is_causal=is_causal, tensor_layout=tensor_layout, output_dtype=dtype)

return o
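
For reference, a minimal smoke test for `sageattn` might look like the sketch below; the shapes, the tolerance check, and the use of torch.nn.functional.scaled_dot_product_attention as an fp16 reference are assumptions for illustration, not part of this commit:

    # Hypothetical sanity check (not in this PR): compare the INT8-quantized
    # kernel against PyTorch's fp16 SDPA on a single shape.
    import torch
    import torch.nn.functional as F
    from kernels import sageattn  # assumes the package is importable as `kernels`

    b, h, n, d = 2, 8, 1024, 128
    q = torch.randn(b, h, n, d, dtype=torch.float16, device="cuda")
    k = torch.randn(b, h, n, d, dtype=torch.float16, device="cuda")
    v = torch.randn(b, h, n, d, dtype=torch.float16, device="cuda")

    out = sageattn(q, k, v, tensor_layout="HND", is_causal=True)
    ref = F.scaled_dot_product_attention(q, k, v, is_causal=True)

    # INT8 quantization of Q/K introduces some error, so use a loose tolerance.
    print((out - ref).abs().max())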
