[KVCache][Test] Fix TIR attn kernels for uncommon group size (#17074)
This PR fixes the TIR attention kernels in the PagedKVCache tests, which
had issues when handling uncommon GQA group sizes (e.g., 6).
MasterJH5574 authored Jun 9, 2024
1 parent 0db8220 commit 4183229
Showing 1 changed file with 60 additions and 41 deletions.
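
The core of the change, repeated across batch_prefill_paged_kv, batch_prefill_ragged_kv, and batch_tree_attn below, is to replace the per-CTA L_start/H_qo_start indexing (which depends on L_per_cta = tile_x // group_size) with a flattened LH_start = tile_id[0] * tile_x index. The following standalone sketch (not part of the commit; tile_x = 32 matches the fp16 setting of these kernels, but all values are illustrative) contrasts the two mappings:

def old_mapping(tile_id, i, tile_x, group_size):
    # Pre-fix: per-CTA query offset, truncated by L_per_cta = tile_x // group_size.
    L_per_cta = tile_x // group_size
    return tile_id * L_per_cta + i // group_size, i % group_size

def new_mapping(tile_id, i, tile_x, group_size):
    # Post-fix: flattened (position, head) offset via LH_start = tile_id * tile_x.
    LH = tile_id * tile_x + i
    return LH // group_size, LH % group_size

for group_size in (8, 6):  # 8 divides tile_x = 32, 6 does not
    mismatches = [
        (t, i)
        for t in range(4)
        for i in range(32)
        if old_mapping(t, i, 32, group_size) != new_mapping(t, i, 32, group_size)
    ]
    print(group_size, "->", "identical" if not mismatches else f"{len(mismatches)} rows differ")

For group size 8 the two schemes agree on every row, while for group size 6 every row outside the first tile lands on a different (position, head) pair; that divergence is what the new LH_start-based indexing in the diff removes.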
@@ -1181,8 +1181,8 @@ def batch_prefill_paged_kv(

  if T.tvm_thread_invariant(batch_idx[0] < batch_size):
  b_idx: T.int32 = batch_idx[0]
- L_start: T.int32 = q_indptr[b_idx] + tile_id[0] * L_per_cta
- H_qo_start: T.int32 = by * group_size
+ LH_start: T.int32 = tile_id[0] * tile_x
+ q_indptr_val: T.int32 = q_indptr[b_idx]

  cur_page_indptr_begin: T.int32 = page_indptr[b_idx]
  cur_page_indptr_end: T.int32 = page_indptr[b_idx + 1]
@@ -1212,8 +1212,8 @@ def batch_prefill_paged_kv(
  i, j = T.axis.remap("SS", [li, lj])
  T.reads()
  T.writes()
- cur_L = L_start + i // group_size
- cur_H_qo = H_qo_start + i % group_size
+ cur_L = q_indptr_val + (LH_start + i) // group_size
+ cur_H_qo = by * group_size + (LH_start + i) % group_size
  if cur_L < q_indptr[b_idx + 1]:
  Q_smem[i, j] = T.if_then_else(
  rotary_mode == 1,
@@ -1282,9 +1282,10 @@ def batch_prefill_paged_kv(
  m_prev[i] = m_smem[row]
  m_new[i] = m_smem[row]
  # mask out of kv_chunk_len S
+ row_: T.int32 = (LH_start + row) // group_size
  for j in T.serial(tile_z):
  if _causal_mask(causal,
- row=tile_id[0] * L_per_cta + row // group_size,
+ row=row_,
  col=L_kv_start + j,
  kv_len=kv_chunk_len[0],
  qo_len=q_indptr[b_idx + 1] - q_indptr[b_idx]):
@@ -1297,8 +1298,9 @@ def batch_prefill_paged_kv(
  for j in T.serial(tile_z):
  # this is to avoid sync inside condition branch
  if row < tile_x:
+ row_: T.int32 = (LH_start + row) // group_size
  if _causal_mask(causal,
- row=tile_id[0] * L_per_cta + row // group_size,
+ row=row_,
  col=L_kv_start + j,
  kv_len=kv_chunk_len[0],
  qo_len=q_indptr[b_idx + 1] - q_indptr[b_idx]):
@@ -1330,15 +1332,19 @@ def batch_prefill_paged_kv(
  for li, lj in T.grid(tile_x, tile_y):
  with T.block("O_store"):
  i, j = T.axis.remap("SS", [li, lj])
- if L_start + i // group_size < q_indptr[b_idx + 1]:
- output[L_start + i // group_size, H_qo_start + i % group_size, j] = O_local[i, j] / d_smem[i]
+ cur_L: T.int32 = q_indptr[b_idx] + (LH_start + i) // group_size
+ cur_H_qo: T.int32 = by * group_size + (LH_start + i) % group_size
+ if cur_L < q_indptr[b_idx + 1]:
+ output[cur_L, cur_H_qo, j] = O_local[i, j] / d_smem[i]

  # Store LSE to gmem
  for li in T.grid(tile_x):
  with T.block("lse_store"):
  i = T.axis.remap("S", [li])
- if L_start + i // group_size < q_indptr[b_idx + 1]:
- lse[L_start + i // group_size, H_qo_start + i % group_size] = m_smem[i] + T.log2(d_smem[i])
+ cur_L: T.int32 = q_indptr[b_idx] + (LH_start + i) // group_size
+ cur_H_qo: T.int32 = by * group_size + (LH_start + i) % group_size
+ if cur_L < q_indptr[b_idx + 1]:
+ lse[cur_L, cur_H_qo] = m_smem[i] + T.log2(d_smem[i])

  # move to next tile
  tile_id[0] += NUM_BLKS
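
An aside on the O_store/lse_store change just above (the same pattern recurs in the ragged and tree-attention kernels further down): with the flattened indexing, the last tile of a request can run past its qo_len * group_size valid rows, and the cur_L < q_indptr[b_idx + 1] guard masks the overflow. A small sketch under illustrative assumptions (qo_len = 5, group_size = 6, tile_x = 32, a single request, by = 0; none of these values come from the diff):

qo_len, group_size, tile_x = 5, 6, 32
q_indptr = [0, qo_len]           # a single request starting at position 0
b_idx, LH_start = 0, 0 * tile_x  # first (and only) tile of this request

stored, skipped = [], []
for i in range(tile_x):
    cur_L = q_indptr[b_idx] + (LH_start + i) // group_size
    cur_H_qo = 0 * group_size + (LH_start + i) % group_size  # by = 0
    if cur_L < q_indptr[b_idx + 1]:
        stored.append((cur_L, cur_H_qo))
    else:
        skipped.append((cur_L, cur_H_qo))

assert len(stored) == qo_len * group_size  # 30 real (position, head) pairs
print("rows masked out by the guard:", skipped)  # the 2 overflow rows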
@@ -1688,7 +1694,6 @@ def _attention_prefill_ragged(
  bdx = 32
  num_warps = 4
  tile_x, tile_y, tile_z = 64 // ((DataType(dtype).bits + 7) // 8) // max(d // 128, 1), d, 16
- L_per_cta = tile_x // group_size

  # Otherwise we would exceed maxComputeWorkgroupStorageSize
  if (
@@ -1784,8 +1789,8 @@ def batch_prefill_ragged_kv(  # pylint: disable=too-many-arguments,too-many-branc

  if T.tvm_thread_invariant(batch_idx[0] < batch_size):
  b_idx: T.int32 = batch_idx[0]
- L_start: T.int32 = q_indptr[b_idx] + tile_id[0] * L_per_cta
- H_qo_start: T.int32 = by * group_size
+ LH_start: T.int32 = tile_id[0] * tile_x
+ q_indptr_val: T.int32 = q_indptr[b_idx]

  kv_chunk_len[0] = kv_indptr[b_idx + 1] - kv_indptr[b_idx]
  T.tvm_storage_sync("shared")
@@ -1809,8 +1814,8 @@ def batch_prefill_ragged_kv(  # pylint: disable=too-many-arguments,too-many-branc
  i, j = T.axis.remap("SS", [li, lj])
  T.reads()
  T.writes()
- cur_L = L_start + i // group_size
- cur_H_qo = H_qo_start + i % group_size
+ cur_L = q_indptr_val + (LH_start + i) // group_size
+ cur_H_qo = by * group_size + (LH_start + i) % group_size
  if cur_L < q_indptr[b_idx + 1]:
  Q_smem[i, j] = T.if_then_else(
  rotary_mode == 1,
@@ -1874,9 +1879,10 @@ def batch_prefill_ragged_kv(  # pylint: disable=too-many-arguments,too-many-branc
  m_prev[i] = m_smem[row]
  m_new[i] = m_smem[row]
  # mask out of kv_chunk_len S
+ row_: T.int32 = (LH_start + row) // group_size
  for j in T.serial(tile_z):
  if _causal_mask(causal,
- row=tile_id[0] * L_per_cta + row // group_size,
+ row=row_,
  col=L_kv_start + j,
  kv_len=kv_chunk_len[0],
  qo_len=q_indptr[b_idx + 1] - q_indptr[b_idx]):
@@ -1889,8 +1895,9 @@ def batch_prefill_ragged_kv(  # pylint: disable=too-many-arguments,too-many-branc
  for j in T.serial(tile_z):
  # this is to avoid sync inside condition branch
  if row < tile_x:
+ row_: T.int32 = (LH_start + row) // group_size
  if _causal_mask(causal,
- row=tile_id[0] * L_per_cta + row // group_size,
+ row=row_,
  col=L_kv_start + j,
  kv_len=kv_chunk_len[0],
  qo_len=q_indptr[b_idx + 1] - q_indptr[b_idx]):
@@ -1922,15 +1929,19 @@ def batch_prefill_ragged_kv(  # pylint: disable=too-many-arguments,too-many-branc
  for li, lj in T.grid(tile_x, tile_y):
  with T.block("O_store"):
  i, j = T.axis.remap("SS", [li, lj])
- if L_start + i // group_size < q_indptr[b_idx + 1]:
- output[L_start + i // group_size, H_qo_start + i % group_size, j] = O_local[i, j] / d_smem[i]
+ cur_L: T.int32 = q_indptr[b_idx] + (LH_start + i) // group_size
+ cur_H_qo: T.int32 = by * group_size + (LH_start + i) % group_size
+ if cur_L < q_indptr[b_idx + 1]:
+ output[cur_L, cur_H_qo, j] = O_local[i, j] / d_smem[i]

  # Store LSE to gmem
  for li in T.grid(tile_x):
  with T.block("lse_store"):
  i = T.axis.remap("S", [li])
- if L_start + i // group_size < q_indptr[b_idx + 1]:
- lse[L_start + i // group_size, H_qo_start + i % group_size] = m_smem[i] + T.log2(d_smem[i])
+ cur_L: T.int32 = q_indptr[b_idx] + (LH_start + i) // group_size
+ cur_H_qo: T.int32 = by * group_size + (LH_start + i) % group_size
+ if cur_L < q_indptr[b_idx + 1]:
+ lse[cur_L, cur_H_qo] = m_smem[i] + T.log2(d_smem[i])

  # move to next tile
  tile_id[0] += NUM_BLKS
@@ -2122,8 +2133,8 @@ def batch_tree_attn(  # pylint: disable=too-many-branches

  if T.tvm_thread_invariant(batch_idx[0] < batch_size):
  b_idx: T.int32 = batch_idx[0]
- L_start: T.int32 = q_indptr[b_idx] + tile_id[0] * L_per_cta
- H_qo_start: T.int32 = by * group_size
+ LH_start: T.int32 = tile_id[0] * tile_x
+ q_indptr_val: T.int32 = q_indptr[b_idx]

  kv_chunk_len[0] = kv_indptr[b_idx + 1] - kv_indptr[b_idx]
  T.tvm_storage_sync("shared")
@@ -2147,8 +2158,8 @@ def batch_tree_attn(  # pylint: disable=too-many-branches
  i, j = T.axis.remap("SS", [li, lj])
  T.reads()
  T.writes()
- cur_L = L_start + i // group_size
- cur_H_qo = H_qo_start + i % group_size
+ cur_L = q_indptr_val + (LH_start + i) // group_size
+ cur_H_qo = by * group_size + (LH_start + i) % group_size
  if cur_L < q_indptr[b_idx + 1]:
  Q_smem[i, j] = T.if_then_else(
  rotary_mode == 1,
@@ -2203,13 +2214,15 @@ def batch_tree_attn(  # pylint: disable=too-many-branches
  m_prev[i] = m_smem[row]
  m_new[i] = m_smem[row]
  # mask out of kv_chunk_len S
+ row_: T.int32 = (LH_start + row) // group_size
  for j in T.serial(tile_z):
- if _tree_mask(row=tile_id[0] * L_per_cta + row // group_size,
- col=L_kv_start + j,
- mask_ptr=mask,
- offset=mn_indptr[b_idx],
- stride=q_indptr[b_idx + 1] - q_indptr[b_idx],
- kv_len=kv_chunk_len[0]):
+ if _tree_mask(
+ row=row_,
+ col=L_kv_start + j,
+ mask_ptr=mask,
+ offset=mn_indptr[b_idx],
+ stride=q_indptr[b_idx + 1] - q_indptr[b_idx],
+ kv_len=kv_chunk_len[0]):
  m_new[i] = T.max(m_new[i], S_smem[row, j])
  d_new[i] = d_smem[row] * T.exp2(m_prev[i] - m_new[i])

@@ -2219,12 +2232,14 @@ def batch_tree_attn(  # pylint: disable=too-many-branches
  for j in T.serial(tile_z):
  # this is to avoid sync inside condition branch
  if row < tile_x:
- if _tree_mask(row=tile_id[0] * L_per_cta + row // group_size,
- col=L_kv_start + j,
- mask_ptr=mask,
- offset=mn_indptr[b_idx],
- stride=q_indptr[b_idx + 1] - q_indptr[b_idx],
- kv_len=kv_chunk_len[0]):
+ row_: T.int32 = (LH_start + row) // group_size
+ if _tree_mask(
+ row=row_,
+ col=L_kv_start + j,
+ mask_ptr=mask,
+ offset=mn_indptr[b_idx],
+ stride=q_indptr[b_idx + 1] - q_indptr[b_idx],
+ kv_len=kv_chunk_len[0]):
  S_smem[row, j] = T.exp2(S_smem[row, j] - m_new[i])
  else:
  S_smem[row, j] = T.exp2(-5e4 - m_new[i])
@@ -2253,15 +2268,19 @@ def batch_tree_attn(  # pylint: disable=too-many-branches
  for li, lj in T.grid(tile_x, tile_y):
  with T.block("O_store"):
  i, j = T.axis.remap("SS", [li, lj])
- if L_start + i // group_size < q_indptr[b_idx + 1]:
- output[L_start + i // group_size, H_qo_start + i % group_size, j] = O_local[i, j] / d_smem[i]
+ cur_L: T.int32 = q_indptr[b_idx] + (LH_start + i) // group_size
+ cur_H_qo: T.int32 = by * group_size + (LH_start + i) % group_size
+ if cur_L < q_indptr[b_idx + 1]:
+ output[cur_L, cur_H_qo, j] = O_local[i, j] / d_smem[i]

  # Store LSE to gmem
  for li in T.grid(tile_x):
  with T.block("lse_store"):
  i = T.axis.remap("S", [li])
- if L_start + i // group_size < q_indptr[b_idx + 1]:
- lse[L_start + i // group_size, H_qo_start + i % group_size] = m_smem[i] + T.log2(d_smem[i])
+ cur_L: T.int32 = q_indptr[b_idx] + (LH_start + i) // group_size
+ cur_H_qo: T.int32 = by * group_size + (LH_start + i) % group_size
+ if cur_L < q_indptr[b_idx + 1]:
+ lse[cur_L, cur_H_qo] = m_smem[i] + T.log2(d_smem[i])

  # move to next tile
  tile_id[0] += NUM_BLKS