From 3c91e3f1763d2a30a85187a3a606dbe4d1b9454d Mon Sep 17 00:00:00 2001
From: yuehuayingxueluo <867460659@qq.com>
Date: Thu, 25 Apr 2024 23:11:30 +0800
Subject: [PATCH] [Inference]Adapt to baichuan2 13B (#5614)

* adapt to baichuan2 13B

* adapt to baichuan2 13B

* change BAICHUAN_MODEL_NAME_OR_PATH

* fix test_decoding_attn.py

* Modifications based on review comments.

* change BAICHUAN_MODEL_NAME_OR_PATH

* mv attn mask processes to test flash decoding

* mv get_alibi_slopes baichuan modeling

* fix bugs in test_baichuan.py
---
 colossalai/inference/flash_decoding_utils.py  |   1 +
 .../inference/kv_cache/kvcache_manager.py     |   9 +-
 .../modeling/models/nopadding_baichuan.py     | 208 ++++++++++--
 .../modeling/policy/nopadding_baichuan.py     |  47 +--
 .../kernel/triton/context_attn_unpad.py       | 295 +++++++++++++++---
 colossalai/kernel/triton/flash_decoding.py    | 227 ++++++++++++--
 tests/test_infer/test_models/test_baichuan.py |  36 ++-
 .../test_ops/triton/kernel_utils.py           |   4 -
 .../triton/test_context_attn_unpad.py         |  51 ++-
 .../test_ops/triton/test_decoding_attn.py     |  42 ++-
 10 files changed, 786 insertions(+), 134 deletions(-)

diff --git a/colossalai/inference/flash_decoding_utils.py b/colossalai/inference/flash_decoding_utils.py
index 7563d1e4ecb9..8f9534d6adf4 100644
--- a/colossalai/inference/flash_decoding_utils.py
+++ b/colossalai/inference/flash_decoding_utils.py
@@ -60,4 +60,5 @@ def initialize(
         self._mid_output_lse = torch.empty(
             size=(max_batch_size, num_attn_heads, kv_max_split_num), dtype=dtype, device=device
         )
+        self._tensors_initialized = True

diff --git a/colossalai/inference/kv_cache/kvcache_manager.py b/colossalai/inference/kv_cache/kvcache_manager.py
index 27ceca426b08..8b9605a52e55 100644
--- a/colossalai/inference/kv_cache/kvcache_manager.py
+++ b/colossalai/inference/kv_cache/kvcache_manager.py
@@ -64,8 +64,15 @@ def __init__(self, config: InferenceConfig, model_config: PretrainedConfig, verb
         self.elem_size_in_bytes = torch.tensor([], dtype=self.dtype).element_size()
         self.num_layers = get_model_config_attr(model_config, "num_hidden_layers")
         self.head_num = get_model_config_attr(model_config, "num_attention_heads")
-        self.kv_head_num = get_model_config_attr(model_config, "num_key_value_heads")
         self.head_size = get_model_config_attr(model_config, "hidden_size") // self.head_num
+
+        if hasattr(config, "num_key_value_heads"):
+            self.kv_head_num = getattr(config, "num_key_value_heads")
+        elif hasattr(config, "attribute_map") and hasattr(config, config.attribute_map["num_key_value_heads"]):
+            self.kv_head_num = getattr(config, config.attribute_map["num_key_value_heads"])
+        else:
+            self.kv_head_num = self.head_num
+
         assert (
             self.kv_head_num % self.tp_size == 0
         ), f"Cannot shard {self.kv_head_num} heads with tp size {self.tp_size}"
diff --git a/colossalai/inference/modeling/models/nopadding_baichuan.py b/colossalai/inference/modeling/models/nopadding_baichuan.py
index 893d45c1f2c4..8aaa448e4936 100644
--- a/colossalai/inference/modeling/models/nopadding_baichuan.py
+++ b/colossalai/inference/modeling/models/nopadding_baichuan.py
@@ -1,19 +1,83 @@
 # This code is adapted from huggingface baichuan model: hhttps://huggingface.co/baichuan-inc/Baichuan2-13B-Base/blob/main/modeling_baichuan.py
+import math
 from typing import Optional, Tuple

 import torch
 import torch.nn as nn

 from colossalai.inference.flash_decoding_utils import FDIntermTensors
-from colossalai.inference.modeling.models.nopadding_llama import NopadLlamaAttention
 from colossalai.kernel.kernel_loader import InferenceOpsLoader
+from colossalai.kernel.triton import (
+    context_attention_unpadded,
+    copy_k_to_blocked_cache,
+    decoding_fused_rotary_embedding,
+    flash_decoding_attention,
+    rms_layernorm,
+    rotary_embedding,
+)
 from colossalai.logging import get_dist_logger

+logger = get_dist_logger(__name__)
+
+try:
+    from flash_attn import flash_attn_varlen_func
+
+    use_flash_attn2 = True
+except ImportError:
+    use_flash_attn2 = False
+    logger.warning("flash_attn2 has not been installed yet, we will use triton flash attn instead.")
+
 inference_ops = InferenceOpsLoader().load()

 logger = get_dist_logger(__name__)


+# alibi slopes calculation adapted from https://github.com/huggingface/transformers/blob/v4.36.0/src/transformers/models/bloom/modeling_bloom.py#L57
+def get_alibi_slopes(num_heads: int, device: torch.device) -> torch.Tensor:
+    closest_power_of_2 = 2 ** math.floor(math.log2(num_heads))
+    base = torch.tensor(2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))), dtype=torch.float32, device=device)
+    powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32, device=device)
+    slopes = torch.pow(base, powers)
+    if closest_power_of_2 != num_heads:
+        extra_base = torch.tensor(
+            2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))), dtype=torch.float32, device=device
+        )
+        num_remaining_heads = min(closest_power_of_2, num_heads - closest_power_of_2)
+        extra_powers = torch.arange(1, 1 + 2 * num_remaining_heads, 2, dtype=torch.int32, device=device)
+        slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0)
+    return slopes
+
+
+def baichuan_rmsnorm_forward(
+    self,
+    hidden_states: torch.Tensor,
+    norm_output: torch.Tensor,
+    residual: torch.Tensor = None,
+    use_cuda_kernel: bool = True,
+):
+    # Used to address the issue of inconsistent epsilon variable names in baichuan2 7b and 13b.
+    if hasattr(self, "variance_epsilon"):
+        eps = self.variance_epsilon
+    elif hasattr(self, "epsilon"):
+        eps = self.epsilon
+    else:
+        raise TypeError(
+            "Currently, the variable name for the epsilon of baichuan7B/13B should be 'variance_epsilon' or 'epsilon'."
+        )
+
+    if use_cuda_kernel:
+        if residual is not None:
+            inference_ops.fused_add_rms_layernorm(hidden_states, residual, self.weight.data, eps)
+            return hidden_states, residual
+
+        if norm_output is None:
+            norm_output = torch.empty_like(hidden_states)
+        inference_ops.rms_layernorm(norm_output, hidden_states, self.weight.data, eps)
+        return norm_output, hidden_states
+    else:
+        return rms_layernorm(hidden_states, self.weight.data, eps, norm_output, residual)
+
+
 class NopadBaichuanAttention(nn.Module):
     def __init__(
         self,
@@ -39,9 +103,11 @@ def __init__(
         self.hidden_size = config.hidden_size
         self.num_heads = config.num_attention_heads
         self.head_dim = self.hidden_size // self.num_heads
-
-        # Used to adapt llama_base_attn_forward
-        self.num_key_value_heads = self.num_heads
+        self.alibi_slopes = None
+        self.use_alibi_attn = False
+        if self.hidden_size == 5120:
+            self.use_alibi_attn = True
+            self.alibi_slopes = get_alibi_slopes(self.num_heads, device=attn_qproj_w.device)

         qkv_weight_list = [attn_qproj_w, attn_kproj_w, attn_vproj_w]
         self.qkv_weight = torch.stack(qkv_weight_list, dim=0)
@@ -112,26 +178,124 @@ def forward(
             high_precision(Optional[bool]): Whether to use float32 for underlying calculations of float16 data to achieve higher precision, defaults to False.
""" - return NopadLlamaAttention.forward( - self, - hidden_states=hidden_states, - block_tables=block_tables, - k_cache=k_cache, - v_cache=v_cache, - sequence_lengths=sequence_lengths, - cos_sin=cos_sin, - fd_inter_tensor=fd_inter_tensor, - is_prompts=is_prompts, - is_verifier=is_verifier, - tokens_to_verify=tokens_to_verify, - kv_seq_len=kv_seq_len, - output_tensor=output_tensor, - sm_scale=sm_scale, - use_cuda_kernel=use_cuda_kernel, - cu_seqlens=cu_seqlens, - high_precision=high_precision, + token_nums = hidden_states.size(0) + # fused qkv + hidden_states = hidden_states.expand(3, -1, -1) + query_states, key_states, value_states = ( + torch.bmm(hidden_states, self.qkv_weight).view(3, token_nums, self.num_heads, self.head_dim).unbind(0) ) + block_size = k_cache.size(-2) + + if is_prompts: + if ( + not is_verifier + and use_cuda_kernel + and query_states.dtype != torch.float32 + and use_flash_attn2 + and not self.use_alibi_attn + ): + # flash attn 2 currently only supports FP16/BF16. + inference_ops.rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1], high_precision) + inference_ops.context_kv_cache_memcpy( + key_states, value_states, k_cache, v_cache, sequence_lengths, cu_seqlens, block_tables, kv_seq_len + ) + + attn_output = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens, + cu_seqlens_k=cu_seqlens, + max_seqlen_q=kv_seq_len, + max_seqlen_k=kv_seq_len, + dropout_p=0.0, + softmax_scale=sm_scale, + causal=True, + ) + attn_output = attn_output.view(token_nums, -1) + else: + if not self.use_alibi_attn: + rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1]) + attn_output = context_attention_unpadded( + q=query_states, + k=key_states, + v=value_states, + k_cache=k_cache, + v_cache=v_cache, + context_lengths=sequence_lengths, + block_tables=block_tables, + block_size=block_size, + output=output_tensor, + alibi_slopes=self.alibi_slopes, + max_seq_len=kv_seq_len, + sm_scale=sm_scale, + ) + else: + q_len = tokens_to_verify + 1 if is_verifier else 1 + + if use_cuda_kernel: + if not self.use_alibi_attn: + inference_ops.rotary_embedding_and_cache_copy( + query_states, + key_states, + value_states, + cos_sin[0], + cos_sin[1], + k_cache, + v_cache, + sequence_lengths, + block_tables, + high_precision, + ) + else: + inference_ops.decode_kv_cache_memcpy( + key_states, value_states, k_cache, v_cache, sequence_lengths, block_tables + ) + else: + if not is_verifier and not self.use_alibi_attn: + decoding_fused_rotary_embedding( + query_states, + key_states, + value_states, + cos_sin[0], + cos_sin[1], + k_cache, + v_cache, + block_tables, + sequence_lengths, + ) + else: + if not self.use_alibi_attn: + rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1]) + copy_k_to_blocked_cache( + key_states, k_cache, kv_lengths=sequence_lengths, block_tables=block_tables, n=q_len + ) + copy_k_to_blocked_cache( + value_states, v_cache, kv_lengths=sequence_lengths, block_tables=block_tables, n=q_len + ) + + attn_output = flash_decoding_attention( + q=query_states, + k_cache=k_cache, + v_cache=v_cache, + kv_seq_len=sequence_lengths, + block_tables=block_tables, + block_size=block_size, + max_seq_len_in_batch=kv_seq_len, + output=output_tensor, + mid_output=fd_inter_tensor.mid_output, + mid_output_lse=fd_inter_tensor.mid_output_lse, + alibi_slopes=self.alibi_slopes, + sm_scale=sm_scale, + q_len=q_len, + ) + + attn_output = attn_output.view(-1, self.hidden_size) + attn_output = torch.mm(attn_output, self.o_proj_weight) + + return 
attn_output + # NOTE This will cause difference as out length increases. class NopadBaichuanMLP(nn.Module): diff --git a/colossalai/inference/modeling/policy/nopadding_baichuan.py b/colossalai/inference/modeling/policy/nopadding_baichuan.py index 64dc40dbc0b9..12975aceae8a 100644 --- a/colossalai/inference/modeling/policy/nopadding_baichuan.py +++ b/colossalai/inference/modeling/policy/nopadding_baichuan.py @@ -1,12 +1,15 @@ import torch.nn as nn from torch.nn import Parameter -from colossalai.inference.modeling.models.nopadding_baichuan import NopadBaichuanAttention, NopadBaichuanMLP +from colossalai.inference.modeling.models.nopadding_baichuan import ( + NopadBaichuanAttention, + NopadBaichuanMLP, + baichuan_rmsnorm_forward, +) from colossalai.inference.modeling.models.nopadding_llama import ( llama_causal_lm_forward, llama_decoder_layer_forward, llama_model_forward, - llama_rmsnorm_forward, ) from colossalai.inference.utils import init_to_get_rotary from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, SubModuleReplacementDescription @@ -21,26 +24,30 @@ def module_policy(self): policy = super().module_policy() decoder_attribute_replacement = { - "lm_head.weight": Parameter( - nn.functional.normalize(self.model.lm_head.weight).transpose(0, 1), requires_grad=False - ), + "lm_head.weight": Parameter(nn.functional.normalize(self.model.lm_head.weight), requires_grad=False), } policy["BaichuanForCausalLM"] = ModulePolicyDescription( attribute_replacement=decoder_attribute_replacement, ) - policy["DecoderLayer"] = ModulePolicyDescription( - sub_module_replacement=[ - SubModuleReplacementDescription( - suffix="mlp", - target_module=NopadBaichuanMLP, - ), - SubModuleReplacementDescription( - suffix="self_attn", - target_module=NopadBaichuanAttention, - ), - ] - ) + # used for relpacing Baichuan 7B/13B decoder layer + for layer_name in ["DecoderLayer", "BaichuanLayer"]: + policy[layer_name] = ModulePolicyDescription( + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="mlp", + target_module=NopadBaichuanMLP, + ), + SubModuleReplacementDescription( + suffix="self_attn", + target_module=NopadBaichuanAttention, + ), + ] + ) + + self.append_or_create_method_replacement( + description={"forward": llama_decoder_layer_forward}, policy=policy, target_key=layer_name + ) self.append_or_create_method_replacement( description={"forward": llama_causal_lm_forward}, policy=policy, target_key="BaichuanForCausalLM" @@ -48,11 +55,9 @@ def module_policy(self): self.append_or_create_method_replacement( description={"forward": llama_model_forward}, policy=policy, target_key="BaichuanModel" ) + self.append_or_create_method_replacement( - description={"forward": llama_decoder_layer_forward}, policy=policy, target_key="DecoderLayer" - ) - self.append_or_create_method_replacement( - description={"forward": llama_rmsnorm_forward}, policy=policy, target_key="RMSNorm" + description={"forward": baichuan_rmsnorm_forward}, policy=policy, target_key="RMSNorm" ) return policy diff --git a/colossalai/kernel/triton/context_attn_unpad.py b/colossalai/kernel/triton/context_attn_unpad.py index 3f494b97f4ef..a7b5242ff8fd 100644 --- a/colossalai/kernel/triton/context_attn_unpad.py +++ b/colossalai/kernel/triton/context_attn_unpad.py @@ -185,6 +185,192 @@ def _fwd_context_paged_attention_kernel( return +# Triton 2.1.0 +@triton.jit +def _alibi_fwd_context_paged_attention_kernel( + Q, + K, + V, + O, + KCache, + VCache, + BLOCK_TABLES, # [num_seqs, max_blocks_per_sequence] + batch_size, + 
alibi_slopes, + stride_qt, + stride_qh, + stride_qd, + stride_kt, + stride_kh, + stride_kd, + stride_vt, + stride_vh, + stride_vd, + stride_ot, + stride_oh, + stride_od, + stride_cacheb, + stride_cacheh, + stride_cachebs, + stride_cached, + stride_bts, + stride_btb, + context_lengths, + sm_scale, + KV_GROUPS: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + HEAD_DIM: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, +): + cur_seq_idx = tl.program_id(0) + if cur_seq_idx >= batch_size: + return + cur_head_idx = tl.program_id(1) + block_start_m = tl.program_id(2) # Br, max_input_len // Block_M + cur_kv_head_idx = cur_head_idx // KV_GROUPS + + global_block_start_offest = block_start_m * BLOCK_M + + # NOTE It requires BLOCK_M, BLOCK_N, and BLOCK_SIZE to be the same + tl.static_assert(BLOCK_M == BLOCK_N) + tl.static_assert(BLOCK_N == BLOCK_SIZE) + + # get the current sequence length from provided context lengths tensor + cur_seq_len = tl.load(context_lengths + cur_seq_idx) + # NOTE when talking to fused QKV and a nopadding context attention, + # we assume that the input Q/K/V is contiguous, and thus here `prev_seq_len_sum` + # could be considered as the start index of the current sequence. + # FIXME might want to explore better way to get the summation of prev seq lengths. + # `tl.sum(tensor[:end])` is invalid as tensor slice is not supported in triton. + prev_seq_len_sum = 0 + for i in range(0, cur_seq_idx): + prev_seq_len_sum += tl.load(context_lengths + i) + + offset_q = prev_seq_len_sum * stride_qt + cur_head_idx * stride_qh + offset_kv = prev_seq_len_sum * stride_kt + cur_kv_head_idx * stride_kh + Q_block_ptr = tl.make_block_ptr( + base=Q + offset_q, + shape=(cur_seq_len, HEAD_DIM), + strides=(stride_qt, stride_qd), + offsets=(global_block_start_offest, 0), + block_shape=(BLOCK_M, HEAD_DIM), + order=(1, 0), + ) + K_block_ptr = tl.make_block_ptr( + base=K + offset_kv, + shape=(HEAD_DIM, cur_seq_len), + strides=(stride_kd, stride_kt), + offsets=(0, 0), + block_shape=(HEAD_DIM, BLOCK_N), + order=(0, 1), + ) + V_block_ptr = tl.make_block_ptr( + base=V + offset_kv, + shape=(cur_seq_len, HEAD_DIM), + strides=(stride_vt, stride_vd), + offsets=(0, 0), + block_shape=(BLOCK_N, HEAD_DIM), + order=(1, 0), + ) + O_block_ptr = tl.make_block_ptr( + base=O + offset_q, + shape=(cur_seq_len, HEAD_DIM), + strides=(stride_ot, stride_od), + offsets=(global_block_start_offest, 0), + block_shape=(BLOCK_M, HEAD_DIM), + order=(1, 0), + ) + + # block table for the current sequence + block_table_ptr = BLOCK_TABLES + cur_seq_idx * stride_bts + # block indexes on block table (i.e. 0, 1, 2, ..., max_blocks_per_seq) + # Consider `block_start_m` as the logical block idx in the current block table, + # as we have BLOCK_M the same size as the block size. 
+ cur_block_table_idx = block_start_m + cur_block_id = tl.load(block_table_ptr + cur_block_table_idx * stride_btb) + offset_kvcache = cur_block_id * stride_cacheb + cur_kv_head_idx * stride_cacheh + + offsets_m = global_block_start_offest + tl.arange(0, BLOCK_M) + offsets_n = tl.arange(0, BLOCK_N) + m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32) + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32) + + # load alibi_slope + alibi_slope = tl.load(alibi_slopes + cur_head_idx) + m_alibi_offset = tl.arange(0, BLOCK_M)[:, None] + global_block_start_offest + n_alibi_offset = tl.arange(0, BLOCK_N)[None, :] + + if global_block_start_offest >= cur_seq_len: + return + + Q_i = tl.load(Q_block_ptr, boundary_check=(1, 0)) + + for block_start_n in range(0, (block_start_m + 1) * BLOCK_M, BLOCK_N): + block_start_n = tl.multiple_of(block_start_n, BLOCK_N) + + k = tl.load(K_block_ptr, boundary_check=(0, 1)) + S_ij = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + S_ij += tl.dot(Q_i, k) + S_ij *= sm_scale + S_ij += tl.where(offsets_m[:, None] >= (block_start_n + offsets_n[None, :]), 0, float("-inf")) + + alibi = (n_alibi_offset + block_start_n - m_alibi_offset) * alibi_slope + alibi = tl.where((alibi <= 0) & (m_alibi_offset < cur_seq_len), alibi, float("-inf")) + S_ij += alibi + + m_ij = tl.max(S_ij, 1) # rowmax(Sij) + m_ij = tl.maximum(m_i, m_ij) # m_ij + S_ij -= m_ij[:, None] + p_ij_hat = tl.exp(S_ij) + scale = tl.exp(m_i - m_ij) + l_ij = scale * l_i + tl.sum(p_ij_hat, 1) + acc = acc * scale[:, None] + + v = tl.load(V_block_ptr, boundary_check=(1, 0)) + p_ij_hat = p_ij_hat.to(v.type.element_ty) + + acc += tl.dot(p_ij_hat, v) + l_i = l_ij + m_i = m_ij + K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N)) + V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0)) + + acc = acc / l_i[:, None] + tl.store(O_block_ptr, acc.to(O.type.element_ty), boundary_check=(1, 0)) + + if cur_head_idx % KV_GROUPS == 0: + # Copy k to corresponding cache block + offsets_dmodel = tl.arange(0, HEAD_DIM) + offsets_kt = global_block_start_offest + tl.arange(0, BLOCK_M) + offsets_k = K + offset_kv + offsets_dmodel[None, :] * stride_kd + offsets_kt[:, None] * stride_kt + k = tl.load(offsets_k, mask=offsets_kt[:, None] < cur_seq_len, other=0.0) + offsets_kcachebs = tl.arange(0, BLOCK_SIZE) + offsets_kcache = ( + KCache + + offset_kvcache + + offsets_dmodel[None, :] * stride_cached + + offsets_kcachebs[:, None] * stride_cachebs + ) + tl.store(offsets_kcache, k, mask=offsets_kcachebs[:, None] < cur_seq_len - block_start_m * BLOCK_SIZE) + # Copy v to corresponding cache block + offsets_vd = offsets_dmodel + offsets_vt = block_start_m * BLOCK_N + tl.arange(0, BLOCK_N) + offsets_v = V + offset_kv + offsets_vt[None, :] * stride_vt + offsets_vd[:, None] * stride_vd + v = tl.load(offsets_v, mask=offsets_vt[None, :] < cur_seq_len, other=0.0) + offsets_vcachebs = offsets_kcachebs # same block size range, just to notify here + offsets_vcache = ( + VCache + + offset_kvcache + + offsets_vcachebs[None, :] * stride_cachebs + + offsets_dmodel[:, None] * stride_cached + ) + tl.store(offsets_vcache, v, mask=offsets_vcachebs[None, :] < cur_seq_len - block_start_m * BLOCK_SIZE) + + return + + def context_attention_unpadded( q: torch.Tensor, # [num_tokens, num_heads, head_dim] k: torch.Tensor, # [num_tokens, num_kv_heads, head_dim] @@ -195,6 +381,7 @@ def context_attention_unpadded( block_tables: torch.Tensor, # [num_seqs, max_blocks_per_sequence], block_size: int, output: torch.Tensor = None, # 
[num_tokens, num_heads, head_dim] + alibi_slopes: torch.Tensor = None, # [num_heads] max_seq_len: int = None, sm_scale: int = None, ): @@ -226,40 +413,78 @@ def context_attention_unpadded( # To optimize, revise batching/scheduling to batch 2^n sequences in a batch (preferred) grid = (triton.next_power_of_2(num_seqs), num_heads, triton.cdiv(max_seq_len, BLOCK_M)) - _fwd_context_paged_attention_kernel[grid]( - q, - k, - v, - output, - k_cache, - v_cache, - block_tables, - num_seqs, - q.stride(0), - q.stride(1), - q.stride(2), - k.stride(0), - k.stride(1), - k.stride(2), - v.stride(0), - v.stride(1), - v.stride(2), - output.stride(0), - head_dim, - 1, - k_cache.stride(0), - k_cache.stride(1), - k_cache.stride(2), - k_cache.stride(3), - block_tables.stride(0), - block_tables.stride(1), - context_lengths, - sm_scale, - num_kv_group, - block_size, - HEAD_DIM=Lk, - BLOCK_M=BLOCK_M, - BLOCK_N=BLOCK_N, - ) + if alibi_slopes is not None: + _alibi_fwd_context_paged_attention_kernel[grid]( + q, + k, + v, + output, + k_cache, + v_cache, + block_tables, + num_seqs, + alibi_slopes, + q.stride(0), + q.stride(1), + q.stride(2), + k.stride(0), + k.stride(1), + k.stride(2), + v.stride(0), + v.stride(1), + v.stride(2), + output.stride(0), + head_dim, + 1, + k_cache.stride(0), + k_cache.stride(1), + k_cache.stride(2), + k_cache.stride(3), + block_tables.stride(0), + block_tables.stride(1), + context_lengths, + sm_scale, + num_kv_group, + block_size, + HEAD_DIM=Lk, + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + ) + else: + _fwd_context_paged_attention_kernel[grid]( + q, + k, + v, + output, + k_cache, + v_cache, + block_tables, + num_seqs, + q.stride(0), + q.stride(1), + q.stride(2), + k.stride(0), + k.stride(1), + k.stride(2), + v.stride(0), + v.stride(1), + v.stride(2), + output.stride(0), + head_dim, + 1, + k_cache.stride(0), + k_cache.stride(1), + k_cache.stride(2), + k_cache.stride(3), + block_tables.stride(0), + block_tables.stride(1), + context_lengths, + sm_scale, + num_kv_group, + block_size, + HEAD_DIM=Lk, + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + ) return output diff --git a/colossalai/kernel/triton/flash_decoding.py b/colossalai/kernel/triton/flash_decoding.py index dcbad7bc8bd9..200835ec3cba 100644 --- a/colossalai/kernel/triton/flash_decoding.py +++ b/colossalai/kernel/triton/flash_decoding.py @@ -124,6 +124,129 @@ def _flash_decoding_fwd_kernel( tl.store(mid_o_lse + offsets_mid_o_lse, m + tl.log(l)) +# Triton 2.1.0 +@triton.jit +def _alibi_flash_decoding_fwd_kernel( + Q, # [batch_size * q_len, head_num, head_dim] + KCache, # [num_blocks, num_kv_heads, block_size, head_dim] + VCache, # [num_blocks, num_kv_heads, block_size, head_dim] + block_tables, # [batch_size, max_blocks_per_sequence] + mid_o, # [batch_size * q_len, head_num, kv_split_num, head_dim] + mid_o_lse, # [batch_size * q_len, head_num, kv_split_num] + kv_seq_len, # [batch_size] + q_len, + batch_size, + alibi_slopes, + stride_qt, + stride_qh, + stride_qd, + stride_cacheb, + stride_cacheh, + stride_cachebs, + stride_cached, + stride_bts, + stride_btb, + stride_mid_ot, + stride_mid_oh, + stride_mid_ob, + stride_mid_od, + stride_mid_o_lset, + stride_mid_o_lseh, + stride_mid_o_lseb, + sm_scale, + KV_GROUPS: tl.constexpr, + BLOCK_KV: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + HEAD_DIM: tl.constexpr, +): + cur_token_idx = tl.program_id(0) + cur_seq_idx = cur_token_idx // q_len + if cur_seq_idx >= batch_size: + return + cur_token_off = (cur_token_idx % q_len) - q_len + 1 + cur_head_idx = tl.program_id(1) + block_start_kv = tl.program_id(2) # for 
splitting k/v + + # NOTE It requires BLOCK_KV and BLOCK_SIZE to be the same + # TODO might want to replace with BLOCK_KV % BLOCK_SIZE == 0 (optimize BLOCK_KV as multiple of BLOCK_SIZE) + # and then support calculating multiple kv cache blocks on an instance + tl.static_assert(BLOCK_KV == BLOCK_SIZE) + # get the current (kv) sequence length + # cur_token_off is used as a "mask" here for spec-dec during verification process + cur_kv_seq_len = tl.load(kv_seq_len + cur_seq_idx) + cur_token_off + if block_start_kv * BLOCK_KV >= cur_kv_seq_len: + return + + offsets_dmodel = tl.arange(0, HEAD_DIM) + offsets_q = cur_token_idx * stride_qt + cur_head_idx * stride_qh + offsets_dmodel * stride_qd + q = tl.load(Q + offsets_q) + # block table for the current sequence + block_table_ptr = block_tables + cur_seq_idx * stride_bts + # cur_bt_start_idx = block_start_kv * (BLOCK_KV // BLOCK_SIZE) + # cur_block_id = tl.load(block_table_ptr + cur_bt_start_idx * stride_btb) + cur_block_id = tl.load(block_table_ptr + block_start_kv * stride_btb) + cur_occupied_size = tl.where( + (block_start_kv + 1) * BLOCK_SIZE <= cur_kv_seq_len, BLOCK_SIZE, cur_kv_seq_len - block_start_kv * BLOCK_SIZE + ) + tl.device_assert(cur_occupied_size >= 0) + + cur_kv_head_idx = cur_head_idx // KV_GROUPS + offset_kvcache = cur_block_id * stride_cacheb + cur_kv_head_idx * stride_cacheh + K_block_ptr = tl.make_block_ptr( + base=KCache + offset_kvcache, + shape=(cur_occupied_size, HEAD_DIM), + strides=(stride_cachebs, stride_cached), + offsets=(0, 0), + block_shape=(BLOCK_SIZE, HEAD_DIM), + order=(0, 1), + ) + V_block_ptr = tl.make_block_ptr( + base=VCache + offset_kvcache, + shape=(cur_occupied_size, HEAD_DIM), + strides=(stride_cachebs, stride_cached), + offsets=(0, 0), + block_shape=(BLOCK_SIZE, HEAD_DIM), + order=(0, 1), + ) + k_cur_block = tl.load(K_block_ptr) + v_cur_block = tl.load(V_block_ptr) + acc = tl.zeros([HEAD_DIM], dtype=tl.float32) + # use block size of the paged/blocked kv cache + S_ij = tl.zeros([BLOCK_SIZE], dtype=tl.float32) + + alibi_slope = tl.load(alibi_slopes + cur_head_idx) + position_k_offset = block_start_kv * BLOCK_KV + tl.arange(0, BLOCK_SIZE) + + # NOTE a trick to come across triton's requirement that values in both first and second input shapes must be >= 16, + # Multiplying two tensors with shapes [1, d] * [d, block_size] will fail. 
+ # Refer to https://github.com/openai/triton/discussions/895 + S_ij += tl.sum(q[None, :] * k_cur_block, 1) + S_ij *= sm_scale + S_ij -= alibi_slope * (cur_kv_seq_len - 1 - position_k_offset) + S_ij = tl.where(cur_kv_seq_len > position_k_offset, S_ij, float("-inf")) + + m = tl.max(S_ij, 0) + S_ij -= m + p_ij_hat = tl.exp(S_ij) + l = tl.sum(p_ij_hat, 0) + p_ij_hat = p_ij_hat.to(v_cur_block.type.element_ty) + acc += tl.sum(v_cur_block * p_ij_hat[:, None], 0) + acc = acc / l + + offsets_mid_o = ( + cur_token_idx * stride_mid_ot + + cur_head_idx * stride_mid_oh + + block_start_kv * stride_mid_ob + + offsets_dmodel * stride_mid_od + ) + tl.store(mid_o + offsets_mid_o, acc) + offsets_mid_o_lse = ( + cur_token_idx * stride_mid_o_lset + cur_head_idx * stride_mid_o_lseh + block_start_kv * stride_mid_o_lseb + ) + # logsumexp L^(j) = m^(j) + log(l^(j)) + tl.store(mid_o_lse + offsets_mid_o_lse, m + tl.log(l)) + + # Triton 2.1.0 @triton.jit def _flash_decoding_fwd_reduce_kernel( @@ -197,9 +320,10 @@ def flash_decoding_attention( output: torch.Tensor = None, mid_output: torch.Tensor = None, mid_output_lse: torch.Tensor = None, + alibi_slopes: torch.Tensor = None, sm_scale: int = None, kv_group_num: int = 1, - q_len: int = 1, + q_len: int = 1, # NOTE alibi flash decoding does not support q_len > 1 at this moment. ): """ Flash decoding implemented with a blocked KV Cache (PagedAttention) during decoding stage. @@ -220,6 +344,7 @@ def flash_decoding_attention( mid_output_lse (torch.Tensor): [max_bsz * q_len, num_heads, kv_max_split_num] Log-sum-exp of intermediate output. `max_bsz` should be greater than or equal to `bsz`. q_len > 1 only for verification process in speculative-decoding. + alibi_slopes (torch.Tensor): [num_heads] alibi slopes used for alibi flash decoding. block_size (int): Size of each block in the blocked key/value cache. num_kv_group (int, optional): Number of key/value groups. Defaults to 1. q_length (int): Query length. Use for speculative decoding when `q_length` > 1 (i.e. the last n tokens). 
@@ -280,38 +405,74 @@ def flash_decoding_attention( num_heads, triton.cdiv(triton.next_power_of_2(max_seq_len_in_batch), BLOCK_KV), ) - _flash_decoding_fwd_kernel[grid]( - q, - k_cache, - v_cache, - block_tables, - mid_output, - mid_output_lse, - kv_seq_len, - q_len, - bsz, - q.stride(0), - q.stride(1), - q.stride(2), - k_cache.stride(0), - k_cache.stride(1), - k_cache.stride(2), - k_cache.stride(3), - block_tables.stride(0), - block_tables.stride(1), - mid_output.stride(0), - mid_output.stride(1), - mid_output.stride(2), - mid_output.stride(3), - mid_output_lse.stride(0), - mid_output_lse.stride(1), - mid_output_lse.stride(2), - sm_scale, - KV_GROUPS=kv_group_num, - BLOCK_KV=block_size, - BLOCK_SIZE=block_size, - HEAD_DIM=head_dim, - ) + + if alibi_slopes is not None: + _alibi_flash_decoding_fwd_kernel[grid]( + q, + k_cache, + v_cache, + block_tables, + mid_output, + mid_output_lse, + kv_seq_len, + q_len, + bsz, + alibi_slopes, + q.stride(0), + q.stride(1), + q.stride(2), + k_cache.stride(0), + k_cache.stride(1), + k_cache.stride(2), + k_cache.stride(3), + block_tables.stride(0), + block_tables.stride(1), + mid_output.stride(0), + mid_output.stride(1), + mid_output.stride(2), + mid_output.stride(3), + mid_output_lse.stride(0), + mid_output_lse.stride(1), + mid_output_lse.stride(2), + sm_scale, + KV_GROUPS=kv_group_num, + BLOCK_KV=block_size, + BLOCK_SIZE=block_size, + HEAD_DIM=head_dim, + ) + else: + _flash_decoding_fwd_kernel[grid]( + q, + k_cache, + v_cache, + block_tables, + mid_output, + mid_output_lse, + kv_seq_len, + q_len, + bsz, + q.stride(0), + q.stride(1), + q.stride(2), + k_cache.stride(0), + k_cache.stride(1), + k_cache.stride(2), + k_cache.stride(3), + block_tables.stride(0), + block_tables.stride(1), + mid_output.stride(0), + mid_output.stride(1), + mid_output.stride(2), + mid_output.stride(3), + mid_output_lse.stride(0), + mid_output_lse.stride(1), + mid_output_lse.stride(2), + sm_scale, + KV_GROUPS=kv_group_num, + BLOCK_KV=block_size, + BLOCK_SIZE=block_size, + HEAD_DIM=head_dim, + ) grid = (triton.next_power_of_2(bsz * q_len), num_heads) _flash_decoding_fwd_reduce_kernel[grid]( diff --git a/tests/test_infer/test_models/test_baichuan.py b/tests/test_infer/test_models/test_baichuan.py index 5ca67c5be7b4..27b0c86203a7 100644 --- a/tests/test_infer/test_models/test_baichuan.py +++ b/tests/test_infer/test_models/test_baichuan.py @@ -12,7 +12,8 @@ from colossalai.inference.flash_decoding_utils import FDIntermTensors from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn -BAICHUAN_MODEL_NAME_OR_PATH = "baichuan-inc/Baichuan2-7B-Base" +# BAICHUAN_MODEL_NAME_OR_PATH = "baichuan-inc/Baichuan2-7B-Base" +BAICHUAN_MODEL_NAME_OR_PATH = "/home/data/models/Baichuan2-13B-Base" def setup_seed(seed): @@ -22,12 +23,10 @@ def setup_seed(seed): random.seed(seed) -def check_inference_engine(use_engine=False, prompt_template=None): +def check_inference_engine(use_engine=False, do_sample=False, use_cuda_kernel=False, prompt_template=None): setup_seed(20) tokenizer = AutoTokenizer.from_pretrained(BAICHUAN_MODEL_NAME_OR_PATH, use_fast=False, trust_remote_code=True) - model = AutoModelForCausalLM.from_pretrained( - BAICHUAN_MODEL_NAME_OR_PATH, device_map="auto", torch_dtype=torch.bfloat16, trust_remote_code=True - ).cuda() + model = AutoModelForCausalLM.from_pretrained(BAICHUAN_MODEL_NAME_OR_PATH, trust_remote_code=True).half().cuda() model = model.eval() inputs = [ @@ -35,17 +34,24 @@ def check_inference_engine(use_engine=False, prompt_template=None): ] output_len = 38 - 
do_sample = False + do_sample = do_sample + + if do_sample: + top_p = 0.5 + top_k = 50 + else: + top_p = None + top_k = None if use_engine: inference_config = InferenceConfig( - max_output_len=output_len, prompt_template=prompt_template, dtype="fp32", use_cuda_kernel=True + max_output_len=output_len, prompt_template=prompt_template, use_cuda_kernel=use_cuda_kernel ) inference_engine = InferenceEngine(model, tokenizer, inference_config, verbose=True) assert inference_engine.generation_config.max_new_tokens == output_len inference_engine.add_request(prompts=inputs) assert inference_engine.request_handler._has_waiting() - generation_config = GenerationConfig(do_sample=do_sample) + generation_config = GenerationConfig(do_sample=do_sample, top_p=top_p, top_k=top_k) outputs = inference_engine.generate(generation_config=generation_config) else: if prompt_template: @@ -57,6 +63,8 @@ def check_inference_engine(use_engine=False, prompt_template=None): inputs = inputs.cuda() generation_config = GenerationConfig( do_sample=do_sample, + top_p=top_p, + top_k=top_k, pad_token_id=tokenizer.pad_token_id, max_new_tokens=output_len, ) @@ -67,9 +75,15 @@ def check_inference_engine(use_engine=False, prompt_template=None): @parameterize("prompt_template", [None, "baichuan"]) -def check_output_consistency(prompt_template): - cai_outputs = check_inference_engine(use_engine=True, prompt_template=prompt_template) - transformer_outputs = check_inference_engine(use_engine=False, prompt_template=prompt_template) +@parameterize("do_sample", [True, False]) +@parameterize("use_cuda_kernel", [True, False]) +def check_output_consistency(prompt_template, do_sample, use_cuda_kernel): + cai_outputs = check_inference_engine( + use_engine=True, do_sample=do_sample, use_cuda_kernel=use_cuda_kernel, prompt_template=prompt_template + ) + transformer_outputs = check_inference_engine( + use_engine=False, do_sample=do_sample, use_cuda_kernel=use_cuda_kernel, prompt_template=prompt_template + ) for s1, s2 in zip(cai_outputs, transformer_outputs): assert s1 == s2, f"\nColossalAI Output: {s1}\nTransformers Output: {s2}" diff --git a/tests/test_infer/test_ops/triton/kernel_utils.py b/tests/test_infer/test_ops/triton/kernel_utils.py index 6bb947d00c1e..916691228e7c 100644 --- a/tests/test_infer/test_ops/triton/kernel_utils.py +++ b/tests/test_infer/test_ops/triton/kernel_utils.py @@ -64,10 +64,6 @@ def torch_attn_ref( assert attn_scores.shape == (bsz, num_heads, q_len, kv_len), "Invalid shape of attention scores" if attention_mask is not None: - if attention_mask.size() != (bsz, 1, q_len, kv_len): - raise ValueError( - f"Attention mask should be of size {(bsz, 1, q_len, kv_len)}, but is {attention_mask.size()}" - ) attn_scores = attn_scores + attention_mask attn_weights = F.softmax(attn_scores.to(dtype=torch.float32), dim=-1).to(dtype=q.dtype) diff --git a/tests/test_infer/test_ops/triton/test_context_attn_unpad.py b/tests/test_infer/test_ops/triton/test_context_attn_unpad.py index 2b758c903c26..70f367c0987e 100644 --- a/tests/test_infer/test_ops/triton/test_context_attn_unpad.py +++ b/tests/test_infer/test_ops/triton/test_context_attn_unpad.py @@ -2,6 +2,7 @@ import torch from packaging import version +from colossalai.inference.modeling.models.nopadding_baichuan import get_alibi_slopes from colossalai.kernel.triton import context_attention_unpadded from colossalai.utils import get_current_device from tests.test_infer.test_ops.triton.kernel_utils import generate_caches_and_block_tables_v2, torch_attn_ref @@ -19,8 +20,31 @@ HEAD_DIM = 32 
+def _fill_with_neg_inf(t):
+    return t.float().fill_(float("-inf")).type_as(t)
+
+
+# alibi mask calculation adapted from https://huggingface.co/baichuan-inc/Baichuan2-13B-Chat/blob/main/modeling_baichuan.py
+def generate_alibi_mask(slopes, num_heads, max_seq_len, device):
+    token_position = torch.arange(max_seq_len, device=device) - max_seq_len + 1
+    token_position = token_position.unsqueeze(0).unsqueeze(0).expand(num_heads, -1, -1)
+    diag = torch.diag(token_position[0])
+    token_position = token_position - diag.unsqueeze(0).unsqueeze(0).transpose(-1, -2)
+    alibi = slopes.unsqueeze(1).unsqueeze(1) * token_position
+    alibi = alibi.view(num_heads, 1, max_seq_len)
+    alibi_mask = torch.triu(_fill_with_neg_inf(torch.zeros([max_seq_len, max_seq_len], device=device)), 1)
+    alibi_mask = alibi_mask.unsqueeze(0) + alibi
+    return alibi_mask
+
+
 def torch_attn_unpad(
-    q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, context_lengths: torch.Tensor, num_heads: int, num_kv_heads: int
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    context_lengths: torch.Tensor,
+    num_heads: int,
+    num_kv_heads: int,
+    slopes: torch.Tensor = None,
 ):
     # Process sequence one by one and concatenate them together.
     # q,k,v [num_tokens(sum(context_lengths)), num_heads, head_dim]
@@ -35,6 +59,10 @@ def torch_attn_unpad(
         mask = torch.tril(torch.ones(1, 1, seq_len, seq_len), diagonal=0).to(device=q.device)
         mask[mask == 0.0] = float("-inf")

+        if slopes is not None:
+            alibi_mask = generate_alibi_mask(slopes, num_heads, seq_len, q.device)
+            mask = mask + alibi_mask
+
         torch_attn_ref_out = torch_attn_ref(
             q[start_idx:end_idx].unsqueeze(0).transpose(1, 2),
             k[start_idx:end_idx].unsqueeze(0).transpose(1, 2),
@@ -60,6 +88,7 @@
 @pytest.mark.parametrize("num_attn_heads", [16])
 @pytest.mark.parametrize("kv_group_num", [1, 2, 16])
 @pytest.mark.parametrize("same_context_len", [True, False])
+@pytest.mark.parametrize("use_alibi_slopes", [True, False])
 def test_context_attention(
     bsz: int,
     block_size: int,
@@ -67,6 +96,7 @@ def test_context_attention(
     num_attn_heads: int,
     kv_group_num: int,
     same_context_len: bool,
+    use_alibi_slopes: bool,
 ):
     torch.manual_seed(123)
     # It's necessary to clear cache here.
@@ -79,6 +109,10 @@ def test_context_attention(
     max_seq_len = max_num_blocks_per_seq * block_size
     dtype = torch.float16
     device = get_current_device()
+    alibi_slopes = None
+
+    if use_alibi_slopes:
+        alibi_slopes = get_alibi_slopes(num_attn_heads, device)

     if same_context_len:
         context_lengths = torch.tensor([max_seq_len for _ in range(bsz)], dtype=torch.int32, device=device)
@@ -100,12 +134,19 @@ def test_context_attention(
     _, num_heads, head_dim = q_unpad.shape

     out_triton = context_attention_unpadded(
-        q_unpad, k_unpad, v_unpad, k_cache_triton, v_cache_triton, context_lengths, block_tables, block_size
+        q_unpad,
+        k_unpad,
+        v_unpad,
+        k_cache_triton,
+        v_cache_triton,
+        context_lengths,
+        block_tables,
+        block_size,
+        alibi_slopes=alibi_slopes,
     )
     out_triton = out_triton.view(-1, num_heads, head_dim)
-
-    out_torch = torch_attn_unpad(q_unpad, k_unpad, v_unpad, context_lengths, num_attn_heads, num_kv_heads)
+    out_torch = torch_attn_unpad(q_unpad, k_unpad, v_unpad, context_lengths, num_attn_heads, num_kv_heads, alibi_slopes)

     assert out_torch.shape == out_triton.shape
     assert torch.allclose(out_torch, out_triton, atol=1e-3)

@@ -114,4 +155,4 @@


 if __name__ == "__main__":
-    test_context_attention(4, 32, 8, 16, 1, True)
+    test_context_attention(4, 32, 8, 16, 1, True, True)
diff --git a/tests/test_infer/test_ops/triton/test_decoding_attn.py b/tests/test_infer/test_ops/triton/test_decoding_attn.py
index d52373128dda..5dc3c22c0716 100644
--- a/tests/test_infer/test_ops/triton/test_decoding_attn.py
+++ b/tests/test_infer/test_ops/triton/test_decoding_attn.py
@@ -1,7 +1,9 @@
+import numpy as np
 import pytest
 import torch
 from packaging import version

+from colossalai.inference.modeling.models.nopadding_baichuan import get_alibi_slopes
 from colossalai.kernel.triton import flash_decoding_attention
 from colossalai.utils import get_current_device
 from tests.test_infer.test_ops.triton.kernel_utils import (
@@ -10,6 +12,7 @@
     generate_caches_and_block_tables_v2,
     torch_attn_ref,
 )
+from tests.test_infer.test_ops.triton.test_context_attn_unpad import generate_alibi_mask

 try:
     import triton  # noqa
@@ -24,6 +27,13 @@
 HEAD_DIM = 128


+def numpy_allclose(x, y, rtol, atol):
+    x_numpy = x.detach().cpu().numpy()
+    y_numpy = y.detach().cpu().numpy()
+
+    np.testing.assert_allclose(x_numpy, y_numpy, rtol=rtol, atol=atol)
+
+
 def prepare_data(
     bsz: int,
     num_attn_heads: int,
@@ -64,6 +74,7 @@
 @pytest.mark.parametrize("kv_group_num", [1, 2, 16])
 @pytest.mark.parametrize("same_context_len", [True, False])
 @pytest.mark.parametrize("q_len", [1, 5])
+@pytest.mark.parametrize("use_alibi_slopes", [True, False])
 def test_flash_decoding(
     bsz: int,
     block_size: int,
@@ -72,6 +83,7 @@ def test_flash_decoding(
     kv_group_num: int,
     same_context_len: bool,
     q_len: int,
+    use_alibi_slopes: bool,
 ):
     torch.manual_seed(123)
     torch.cuda.empty_cache()
@@ -83,6 +95,14 @@ def test_flash_decoding(
     max_seq_len = block_size * max_num_blocks_per_seq
     dtype = torch.float16
     device = get_current_device()
+
+    if use_alibi_slopes:
+        alibi_slopes = get_alibi_slopes(num_attn_heads, device)
+        # Currently, alibi flash decoding does not support q_len>1.
+        q_len = 1
+    else:
+        alibi_slopes = None
+
     q, k_unpad, v_unpad, kv_lengths = prepare_data(
         bsz, num_attn_heads, num_kv_heads, HEAD_DIM, same_context_len, q_len, max_seq_len, dtype, device
     )
@@ -92,6 +112,17 @@ def test_flash_decoding(
     k_torch = convert_kv_unpad_to_padded(k_unpad, kv_lengths, bsz, max_kv_len_in_b)
     v_torch = convert_kv_unpad_to_padded(v_unpad, kv_lengths, bsz, max_kv_len_in_b)
     attention_mask = create_attention_mask(kv_lengths, bsz, q_len, max_kv_len_in_b, q.device)
+
+    if use_alibi_slopes:
+        alibi_mask = generate_alibi_mask(alibi_slopes, num_attn_heads, max_kv_len_in_b, q.device)
+        attention_mask = attention_mask + alibi_mask
+
+        if q_len == 1:
+            if len(attention_mask.size()) == 4:
+                attention_mask = attention_mask[:, :, -1:, :]
+            else:
+                attention_mask = attention_mask[:, -1:, :]
+
     out_torch = torch_attn_ref(
         q, k_torch, v_torch, attention_mask, bsz, q_len, max_kv_len_in_b, num_attn_heads, num_kv_heads, HEAD_DIM
     )
@@ -130,14 +161,21 @@ def test_flash_decoding(
         output,
         mid_output,
         mid_output_lse,
+        alibi_slopes=alibi_slopes,
         sm_scale=sm_scale,
         kv_group_num=kv_group_num,
         q_len=q_len,
     )  # [bsz * q_len, num_heads, head_dim]

     assert out_torch.shape == out_triton.shape
-    assert torch.allclose(out_torch, out_triton, atol=1e-3, rtol=1e-4)
+
+    rtol = 1e-4
+    # After the shape becomes larger, some data elements are too small, leading to excessively large relative errors.
+    if bsz == 32 and use_alibi_slopes:
+        rtol = 100
+
+    numpy_allclose(out_torch, out_triton, atol=1e-3, rtol=rtol)


 if __name__ == "__main__":
-    test_flash_decoding(16, 32, 32, 16, 1, True)
+    test_flash_decoding(16, 32, 32, 16, 1, True, 1, True)
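
Note for reference (not applied by this patch): the standalone sketch below shows, in plain PyTorch, the bias that the new ALiBi code paths add to the attention scores. `get_alibi_slopes` is copied verbatim from the helper added to nopadding_baichuan.py above; `naive_alibi_attention` is a hypothetical name used only for this illustration, and it assumes nothing beyond PyTorch (CPU is fine). Each head gets a slope, and score(q_pos, k_pos) is penalized by slope * (q_pos - k_pos) under a causal mask, which is the same bias the Triton context and flash-decoding kernels compute blockwise.

# Minimal sketch, assuming only PyTorch; mirrors the slope computation in this patch.
import math

import torch


def get_alibi_slopes(num_heads: int, device: torch.device) -> torch.Tensor:
    # Same closed form as the helper added in nopadding_baichuan.py (Bloom-style slopes).
    closest_power_of_2 = 2 ** math.floor(math.log2(num_heads))
    base = torch.tensor(2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))), dtype=torch.float32, device=device)
    powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32, device=device)
    slopes = torch.pow(base, powers)
    if closest_power_of_2 != num_heads:
        extra_base = torch.tensor(
            2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))), dtype=torch.float32, device=device
        )
        num_remaining_heads = min(closest_power_of_2, num_heads - closest_power_of_2)
        extra_powers = torch.arange(1, 1 + 2 * num_remaining_heads, 2, dtype=torch.int32, device=device)
        slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0)
    return slopes


def naive_alibi_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, slopes: torch.Tensor) -> torch.Tensor:
    # q, k, v: [num_heads, seq_len, head_dim]; slopes: [num_heads]
    num_heads, seq_len, head_dim = q.shape
    scores = q @ k.transpose(-1, -2) / math.sqrt(head_dim)
    pos = torch.arange(seq_len, device=q.device)
    # Bias is slope * (k_pos - q_pos): zero on the diagonal, increasingly negative for
    # keys further in the past; positions after the query are removed by the causal mask.
    bias = slopes[:, None, None] * (pos[None, None, :] - pos[None, :, None]).float()
    causal = torch.triu(torch.full((seq_len, seq_len), float("-inf"), device=q.device), diagonal=1)
    return torch.softmax(scores + bias + causal, dim=-1) @ v


if __name__ == "__main__":
    num_heads, seq_len, head_dim = 40, 16, 128  # Baichuan2-13B: 40 heads, head_dim 128 (hidden_size 5120)
    slopes = get_alibi_slopes(num_heads, device=torch.device("cpu"))
    q, k, v = (torch.randn(num_heads, seq_len, head_dim) for _ in range(3))
    out = naive_alibi_attention(q, k, v, slopes)
    print(out.shape)  # torch.Size([40, 16, 128])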