From 4183229922ad33c2006954140bc5ef368d40df21 Mon Sep 17 00:00:00 2001
From: Ruihang Lai
Date: Sun, 9 Jun 2024 11:44:58 -0400
Subject: [PATCH] [KVCache][Test] Fix TIR attn kernels for uncommon group size
 (#17074)

This PR fixes the TIR attention kernels in the PagedKVCache tests, which
had issues when handling uncommon GQA group sizes (e.g., 6).
---
 ...me_builtin_paged_attention_kv_cache_tir.py | 101 +++++++++++-------
 1 file changed, 60 insertions(+), 41 deletions(-)

diff --git a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py
index c5c88211ba18..af55b194fb9a 100644
--- a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py
+++ b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py
@@ -1181,8 +1181,8 @@ def batch_prefill_paged_kv(
 
                                 if T.tvm_thread_invariant(batch_idx[0] < batch_size):
                                     b_idx: T.int32 = batch_idx[0]
-                                    L_start: T.int32 = q_indptr[b_idx] + tile_id[0] * L_per_cta
-                                    H_qo_start: T.int32 = by * group_size
+                                    LH_start: T.int32 = tile_id[0] * tile_x
+                                    q_indptr_val: T.int32 = q_indptr[b_idx]
 
                                     cur_page_indptr_begin: T.int32 = page_indptr[b_idx]
                                     cur_page_indptr_end: T.int32 = page_indptr[b_idx + 1]
@@ -1212,8 +1212,8 @@
                                             i, j = T.axis.remap("SS", [li, lj])
                                             T.reads()
                                             T.writes()
-                                            cur_L = L_start + i // group_size
-                                            cur_H_qo = H_qo_start + i % group_size
+                                            cur_L = q_indptr_val + (LH_start + i) // group_size
+                                            cur_H_qo = by * group_size + (LH_start + i) % group_size
                                             if cur_L < q_indptr[b_idx + 1]:
                                                 Q_smem[i, j] = T.if_then_else(
                                                     rotary_mode == 1,
@@ -1282,9 +1282,10 @@
                                                     m_prev[i] = m_smem[row]
                                                     m_new[i] = m_smem[row]
                                                     # mask out of kv_chunk_len S
+                                                    row_: T.int32 = (LH_start + row) // group_size
                                                     for j in T.serial(tile_z):
                                                         if _causal_mask(causal,
-                                                                row=tile_id[0] * L_per_cta + row // group_size,
+                                                                row=row_,
                                                                 col=L_kv_start + j,
                                                                 kv_len=kv_chunk_len[0],
                                                                 qo_len=q_indptr[b_idx + 1] - q_indptr[b_idx]):
@@ -1297,8 +1298,9 @@
                                                 for j in T.serial(tile_z):
                                                     # this is to avoid sync inside condition branch
                                                     if row < tile_x:
+                                                        row_: T.int32 = (LH_start + row) // group_size
                                                         if _causal_mask(causal,
-                                                                row=tile_id[0] * L_per_cta + row // group_size,
+                                                                row=row_,
                                                                 col=L_kv_start + j,
                                                                 kv_len=kv_chunk_len[0],
                                                                 qo_len=q_indptr[b_idx + 1] - q_indptr[b_idx]):
@@ -1330,15 +1332,19 @@
                                     for li, lj in T.grid(tile_x, tile_y):
                                         with T.block("O_store"):
                                             i, j = T.axis.remap("SS", [li, lj])
-                                            if L_start + i // group_size < q_indptr[b_idx + 1]:
-                                                output[L_start + i // group_size, H_qo_start + i % group_size, j] = O_local[i, j] / d_smem[i]
+                                            cur_L: T.int32 = q_indptr[b_idx] + (LH_start + i) // group_size
+                                            cur_H_qo: T.int32 = by * group_size + (LH_start + i) % group_size
+                                            if cur_L < q_indptr[b_idx + 1]:
+                                                output[cur_L, cur_H_qo, j] = O_local[i, j] / d_smem[i]
 
                                     # Store LSE to gmem
                                     for li in T.grid(tile_x):
                                         with T.block("lse_store"):
                                             i = T.axis.remap("S", [li])
-                                            if L_start + i // group_size < q_indptr[b_idx + 1]:
-                                                lse[L_start + i // group_size, H_qo_start + i % group_size] = m_smem[i] + T.log2(d_smem[i])
+                                            cur_L: T.int32 = q_indptr[b_idx] + (LH_start + i) // group_size
+                                            cur_H_qo: T.int32 = by * group_size + (LH_start + i) % group_size
+                                            if cur_L < q_indptr[b_idx + 1]:
+                                                lse[cur_L, cur_H_qo] = m_smem[i] + T.log2(d_smem[i])
 
                                     # move to next tile
                                     tile_id[0] += NUM_BLKS
@@ -1688,7 +1694,6 @@ def _attention_prefill_ragged(
     bdx = 32
     num_warps = 4
     tile_x, tile_y, tile_z = 64 // ((DataType(dtype).bits + 7) // 8) // max(d // 128, 1), d, 16
-    L_per_cta = tile_x // group_size
 
     # Otherwise we would exceed maxComputeWorkgroupStorageSize
     if (
@@ -1784,8 +1789,8 @@ def batch_prefill_ragged_kv(  # pylint: disable=too-many-arguments,too-many-branc
 
                                 if T.tvm_thread_invariant(batch_idx[0] < batch_size):
                                     b_idx: T.int32 = batch_idx[0]
-                                    L_start: T.int32 = q_indptr[b_idx] + tile_id[0] * L_per_cta
-                                    H_qo_start: T.int32 = by * group_size
+                                    LH_start: T.int32 = tile_id[0] * tile_x
+                                    q_indptr_val: T.int32 = q_indptr[b_idx]
 
                                     kv_chunk_len[0] = kv_indptr[b_idx + 1] - kv_indptr[b_idx]
                                     T.tvm_storage_sync("shared")
@@ -1809,8 +1814,8 @@ def batch_prefill_ragged_kv(  # pylint: disable=too-many-arguments,too-many-branc
                                             i, j = T.axis.remap("SS", [li, lj])
                                             T.reads()
                                             T.writes()
-                                            cur_L = L_start + i // group_size
-                                            cur_H_qo = H_qo_start + i % group_size
+                                            cur_L = q_indptr_val + (LH_start + i) // group_size
+                                            cur_H_qo = by * group_size + (LH_start + i) % group_size
                                             if cur_L < q_indptr[b_idx + 1]:
                                                 Q_smem[i, j] = T.if_then_else(
                                                     rotary_mode == 1,
@@ -1874,9 +1879,10 @@ def batch_prefill_ragged_kv(  # pylint: disable=too-many-arguments,too-many-branc
                                                     m_prev[i] = m_smem[row]
                                                     m_new[i] = m_smem[row]
                                                     # mask out of kv_chunk_len S
+                                                    row_: T.int32 = (LH_start + row) // group_size
                                                     for j in T.serial(tile_z):
                                                         if _causal_mask(causal,
-                                                                row=tile_id[0] * L_per_cta + row // group_size,
+                                                                row=row_,
                                                                 col=L_kv_start + j,
                                                                 kv_len=kv_chunk_len[0],
                                                                 qo_len=q_indptr[b_idx + 1] - q_indptr[b_idx]):
@@ -1889,8 +1895,9 @@ def batch_prefill_ragged_kv(  # pylint: disable=too-many-arguments,too-many-branc
                                                 for j in T.serial(tile_z):
                                                     # this is to avoid sync inside condition branch
                                                     if row < tile_x:
+                                                        row_: T.int32 = (LH_start + row) // group_size
                                                         if _causal_mask(causal,
-                                                                row=tile_id[0] * L_per_cta + row // group_size,
+                                                                row=row_,
                                                                 col=L_kv_start + j,
                                                                 kv_len=kv_chunk_len[0],
                                                                 qo_len=q_indptr[b_idx + 1] - q_indptr[b_idx]):
@@ -1922,15 +1929,19 @@ def batch_prefill_ragged_kv(  # pylint: disable=too-many-arguments,too-many-branc
                                     for li, lj in T.grid(tile_x, tile_y):
                                         with T.block("O_store"):
                                             i, j = T.axis.remap("SS", [li, lj])
-                                            if L_start + i // group_size < q_indptr[b_idx + 1]:
-                                                output[L_start + i // group_size, H_qo_start + i % group_size, j] = O_local[i, j] / d_smem[i]
+                                            cur_L: T.int32 = q_indptr[b_idx] + (LH_start + i) // group_size
+                                            cur_H_qo: T.int32 = by * group_size + (LH_start + i) % group_size
+                                            if cur_L < q_indptr[b_idx + 1]:
+                                                output[cur_L, cur_H_qo, j] = O_local[i, j] / d_smem[i]
 
                                     # Store LSE to gmem
                                     for li in T.grid(tile_x):
                                         with T.block("lse_store"):
                                             i = T.axis.remap("S", [li])
-                                            if L_start + i // group_size < q_indptr[b_idx + 1]:
-                                                lse[L_start + i // group_size, H_qo_start + i % group_size] = m_smem[i] + T.log2(d_smem[i])
+                                            cur_L: T.int32 = q_indptr[b_idx] + (LH_start + i) // group_size
+                                            cur_H_qo: T.int32 = by * group_size + (LH_start + i) % group_size
+                                            if cur_L < q_indptr[b_idx + 1]:
+                                                lse[cur_L, cur_H_qo] = m_smem[i] + T.log2(d_smem[i])
 
                                     # move to next tile
                                     tile_id[0] += NUM_BLKS
@@ -2122,8 +2133,8 @@ def batch_tree_attn(  # pylint: disable=too-many-branches
 
                                 if T.tvm_thread_invariant(batch_idx[0] < batch_size):
                                     b_idx: T.int32 = batch_idx[0]
-                                    L_start: T.int32 = q_indptr[b_idx] + tile_id[0] * L_per_cta
-                                    H_qo_start: T.int32 = by * group_size
+                                    LH_start: T.int32 = tile_id[0] * tile_x
+                                    q_indptr_val: T.int32 = q_indptr[b_idx]
 
                                     kv_chunk_len[0] = kv_indptr[b_idx + 1] - kv_indptr[b_idx]
                                     T.tvm_storage_sync("shared")
@@ -2147,8 +2158,8 @@ def batch_tree_attn(  # pylint: disable=too-many-branches
                                             i, j = T.axis.remap("SS", [li, lj])
                                             T.reads()
                                             T.writes()
-                                            cur_L = L_start + i // group_size
-                                            cur_H_qo = H_qo_start + i % group_size
+                                            cur_L = q_indptr_val + (LH_start + i) // group_size
+                                            cur_H_qo = by * group_size + (LH_start + i) % group_size
                                             if cur_L < q_indptr[b_idx + 1]:
                                                 Q_smem[i, j] = T.if_then_else(
                                                     rotary_mode == 1,
@@ -2203,13 +2214,15 @@ def batch_tree_attn(  # pylint: disable=too-many-branches
                                                     m_prev[i] = m_smem[row]
                                                     m_new[i] = m_smem[row]
                                                     # mask out of kv_chunk_len S
+                                                    row_: T.int32 = (LH_start + row) // group_size
                                                     for j in T.serial(tile_z):
-                                                        if _tree_mask(row=tile_id[0] * L_per_cta + row // group_size,
-                                                                      col=L_kv_start + j,
-                                                                      mask_ptr=mask,
-                                                                      offset=mn_indptr[b_idx],
-                                                                      stride=q_indptr[b_idx + 1] - q_indptr[b_idx],
-                                                                      kv_len=kv_chunk_len[0]):
+                                                        if _tree_mask(
+                                                            row=row_,
+                                                            col=L_kv_start + j,
+                                                            mask_ptr=mask,
+                                                            offset=mn_indptr[b_idx],
+                                                            stride=q_indptr[b_idx + 1] - q_indptr[b_idx],
+                                                            kv_len=kv_chunk_len[0]):
                                                             m_new[i] = T.max(m_new[i], S_smem[row, j])
 
                                                     d_new[i] = d_smem[row] * T.exp2(m_prev[i] - m_new[i])
@@ -2219,12 +2232,14 @@ def batch_tree_attn(  # pylint: disable=too-many-branches
                                                 for j in T.serial(tile_z):
                                                     # this is to avoid sync inside condition branch
                                                     if row < tile_x:
-                                                        if _tree_mask(row=tile_id[0] * L_per_cta + row // group_size,
-                                                                      col=L_kv_start + j,
-                                                                      mask_ptr=mask,
-                                                                      offset=mn_indptr[b_idx],
-                                                                      stride=q_indptr[b_idx + 1] - q_indptr[b_idx],
-                                                                      kv_len=kv_chunk_len[0]):
+                                                        row_: T.int32 = (LH_start + row) // group_size
+                                                        if _tree_mask(
+                                                            row=row_,
+                                                            col=L_kv_start + j,
+                                                            mask_ptr=mask,
+                                                            offset=mn_indptr[b_idx],
+                                                            stride=q_indptr[b_idx + 1] - q_indptr[b_idx],
+                                                            kv_len=kv_chunk_len[0]):
                                                             S_smem[row, j] = T.exp2(S_smem[row, j] - m_new[i])
                                                         else:
                                                             S_smem[row, j] = T.exp2(-5e4 - m_new[i])
@@ -2253,15 +2268,19 @@ def batch_tree_attn(  # pylint: disable=too-many-branches
                                     for li, lj in T.grid(tile_x, tile_y):
                                         with T.block("O_store"):
                                             i, j = T.axis.remap("SS", [li, lj])
-                                            if L_start + i // group_size < q_indptr[b_idx + 1]:
-                                                output[L_start + i // group_size, H_qo_start + i % group_size, j] = O_local[i, j] / d_smem[i]
+                                            cur_L: T.int32 = q_indptr[b_idx] + (LH_start + i) // group_size
+                                            cur_H_qo: T.int32 = by * group_size + (LH_start + i) % group_size
+                                            if cur_L < q_indptr[b_idx + 1]:
+                                                output[cur_L, cur_H_qo, j] = O_local[i, j] / d_smem[i]
 
                                     # Store LSE to gmem
                                     for li in T.grid(tile_x):
                                         with T.block("lse_store"):
                                             i = T.axis.remap("S", [li])
-                                            if L_start + i // group_size < q_indptr[b_idx + 1]:
-                                                lse[L_start + i // group_size, H_qo_start + i % group_size] = m_smem[i] + T.log2(d_smem[i])
+                                            cur_L: T.int32 = q_indptr[b_idx] + (LH_start + i) // group_size
+                                            cur_H_qo: T.int32 = by * group_size + (LH_start + i) % group_size
+                                            if cur_L < q_indptr[b_idx + 1]:
+                                                lse[cur_L, cur_H_qo] = m_smem[i] + T.log2(d_smem[i])
 
                                     # move to next tile
                                     tile_id[0] += NUM_BLKS
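
Note on the indexing change: the removed scheme advanced each CTA by
L_per_cta = tile_x // group_size query positions and decomposed the in-tile
thread index i as (i // group_size, i % group_size), which tiles the
(position, head) pairs cleanly only when group_size divides tile_x. The new
scheme tiles the flattened index space directly via
LH_start = tile_id[0] * tile_x, so every tile covers exactly tile_x
consecutive (cur_L, cur_H_qo) pairs. The standalone Python sketch below is
not part of the patch; tile_x = 32, group_size = 6, and the helper names are
hypothetical, chosen so that group_size does not divide tile_x:

    # Contrast the removed per-position tiling with the flattened tiling.
    tile_x, group_size = 32, 6  # hypothetical sizes; 6 does not divide 32

    def old_tile(tile_id):
        # Removed mapping: advance by L_per_cta positions per tile and
        # split the in-tile index i into (position, head).
        L_per_cta = tile_x // group_size  # 5, yet i // group_size spans 0..5
        return {(tile_id * L_per_cta + i // group_size, i % group_size)
                for i in range(tile_x)}

    def new_tile(tile_id):
        # New mapping: flatten (position, head) into one index LH_start + i,
        # so each tile owns exactly tile_x consecutive pairs.
        LH_start = tile_id * tile_x
        return {((LH_start + i) // group_size, (LH_start + i) % group_size)
                for i in range(tile_x)}

    print(sorted(old_tile(0) & old_tile(1)))  # [(5, 0), (5, 1)]: tiles collide
    print(sorted(new_tile(0) & new_tile(1)))  # []: tiles partition the pairs

Because cur_L is monotone in the flattened index, the single bound check
cur_L < q_indptr[b_idx + 1] in the store blocks also remains sufficient for
ragged sequence lengths.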