diff --git a/Dockerfile.rocm b/Dockerfile.rocm index 080e5b04d28bc..e7f52307a6aa2 100644 --- a/Dockerfile.rocm +++ b/Dockerfile.rocm @@ -105,7 +105,7 @@ RUN if [ "$BUILD_TRITON" = "1" ]; then \ && pip uninstall -y triton \ && git clone https://github.com/ROCmSoftwarePlatform/triton.git \ && cd triton/python \ - && pip3 install -e . \ + && pip3 install . \ && cd ../..; \ fi diff --git a/vllm/model_executor/layers/attention/attention.py b/vllm/model_executor/layers/attention/attention.py index 4b63b9eaf59a7..89b5816f7a47a 100644 --- a/vllm/model_executor/layers/attention/attention.py +++ b/vllm/model_executor/layers/attention/attention.py @@ -8,6 +8,7 @@ from vllm.logger import init_logger from vllm.model_executor.input_metadata import InputMetadata from vllm.utils import is_hip +import os logger = init_logger(__name__) @@ -34,11 +35,12 @@ def __init__( sliding_window: Optional[int] = None, ) -> None: super().__init__() - if _use_flash_attn(): + if use_triton := _use_flash_attn(): from vllm.model_executor.layers.attention.backends.flash_attn import FlashAttentionBackend # noqa: E501 self.backend = FlashAttentionBackend(num_heads, head_size, scale, num_kv_heads, alibi_slopes, - sliding_window) + sliding_window, + use_triton == 2) else: from vllm.model_executor.layers.attention.backends.xformers import XFormersBackend # noqa: E501 self.backend = XFormersBackend(num_heads, head_size, scale, @@ -59,26 +61,37 @@ def forward( @lru_cache(maxsize=1) -def _use_flash_attn() -> bool: - try: - import flash_attn # noqa: F401 - except ImportError: - logger.info("flash_attn is not found. Using xformers backend.") - return False - - if is_hip(): - # AMD GPUs. - return False - if torch.cuda.get_device_capability()[0] < 8: +def _use_flash_attn() -> int: + """Returns if and which flash attention to use. + + Returns: + int: 0 for none, 1 for default implementation, 2 for triton implementation. + """ + if not (os.environ.get('VLLM_USE_FLASH_ATTN_TRITON') and is_hip()): + # AMD GPUs can use flash_attn package or triton impl. + try: + import flash_attn # noqa: F401 + except ImportError: + logger.info("flash_attn is not found. Using xformers backend.") + return 0 + + if (not is_hip()) and torch.cuda.get_device_capability()[0] < 8: # Volta and Turing NVIDIA GPUs. logger.info("flash_attn is not supported on Turing or older GPUs. " "Using xformers backend.") - return False + return 0 + + if is_hip() and torch.cuda.get_device_capability()[0] != 9: + # not Instinct series GPUs. + logger.info("flash_atten is not supported on NAVI GPUs. " + "Using xformers backend.") + return 0 + if torch.get_default_dtype() not in (torch.float16, torch.bfloat16): logger.info( "flash_attn only supports torch.float16 or torch.bfloat16. " "Using xformers backend.") - return False + return 0 - logger.info("Using flash_attn backend.") - return True + logger.info(f"Using {'Triton' if os.environ.get('VLLM_USE_FLASH_ATTN_TRITON') else ''} flash_attn backend.") + return 2 if os.environ.get('VLLM_USE_FLASH_ATTN_TRITON') else 1 diff --git a/vllm/model_executor/layers/attention/backends/flash_attn.py b/vllm/model_executor/layers/attention/backends/flash_attn.py index c2d7b5acc467e..726b42cad9e3f 100644 --- a/vllm/model_executor/layers/attention/backends/flash_attn.py +++ b/vllm/model_executor/layers/attention/backends/flash_attn.py @@ -8,7 +8,7 @@ from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.layers.attention.ops.paged_attn import ( PagedAttentionImpl) -from vllm.model_executor.layers.attention.ops.flash_attention_triton import attention +from vllm.model_executor.layers.attention.ops.flash_attention_triton import triton_attention class FlashAttentionBackend: @@ -21,6 +21,7 @@ def __init__( num_kv_heads: Optional[int] = None, alibi_slopes: Optional[List[float]] = None, sliding_window: Optional[int] = None, + use_triton: Optional[bool] = False, ) -> None: self.num_heads = num_heads self.head_size = head_size @@ -30,6 +31,7 @@ def __init__( if alibi_slopes is not None: alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) self.alibi_slopes = alibi_slopes + self.use_triton = use_triton assert self.num_heads % self.num_kv_heads == 0 self.num_queries_per_kv = self.num_heads // self.num_kv_heads @@ -87,8 +89,8 @@ def forward( query = query.unflatten(0, (batch_size, seq_len)) key = key.unflatten(0, (batch_size, seq_len)) value = value.unflatten(0, (batch_size, seq_len)) - if is_hip(): - output, _ = attention( + if self.use_triton: + output, _ = triton_attention( query, key, value, @@ -98,15 +100,25 @@ def forward( self.scale, ) else: - output = flash_attn_func( - query, - key, - value, - softmax_scale=self.scale, - causal=True, - window_size=self.sliding_window, - alibi_slopes=self.alibi_slopes, - ) + if is_hip(): + #XXX: window_size and alibi_slopes not supported + output = flash_attn_func( + query, + key, + value, + softmax_scale=self.scale, + causal=True, + ) + else: + output = flash_attn_func( + query, + key, + value, + softmax_scale=self.scale, + causal=True, + window_size=self.sliding_window, + alibi_slopes=self.alibi_slopes, + ) else: # prefix-enabled attention output = PagedAttentionImpl.forward_prefix( diff --git a/vllm/model_executor/layers/attention/ops/flash_attention_triton.py b/vllm/model_executor/layers/attention/ops/flash_attention_triton.py index 37c15e0e6fa36..80962e4cf9d9a 100644 --- a/vllm/model_executor/layers/attention/ops/flash_attention_triton.py +++ b/vllm/model_executor/layers/attention/ops/flash_attention_triton.py @@ -251,12 +251,12 @@ def attn_fwd( ) acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=Out.type.element_ty) # We still need to write 0s to the result - tl.store(O_block_ptr, acc.to(Out.type.element_ty), boundary_check=(0,1)) - l_ptrs = L + off_z * hq * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q + offs_m + #tl.store(O_block_ptr, acc.to(Out.type.element_ty), boundary_check=(0,1)) + #l_ptrs = L + off_z * hq * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q + offs_m # We store inf to LSE, not -inf because in the bwd pass, we subtract this # from qk which makes it -inf, such that exp(qk - inf) = 0 for these masked blocks. - l = tl.full([BLOCK_M], value=float("inf"), dtype=tl.float32) - tl.store(l_ptrs, l) + #l = tl.full([BLOCK_M], value=float("inf"), dtype=tl.float32) + #tl.store(l_ptrs, l) # TODO: Should dropout and return encoded softmax be handled here too? return @@ -417,17 +417,17 @@ def attn_fwd( z = 0.0 acc = tl.where(out_ptrs_mask, acc, z.to(acc.type.element_ty)) # write back LSE - l_ptrs = L + off_z * hq * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q + offs_m + #l_ptrs = L + off_z * hq * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q + offs_m # If seqlen_q not multiple of BLOCK_M, we need to mask out the last few rows. # This is only true for the last M block. For others, overflow_size will be -ve - overflow_size = end_m_idx - seqlen_q - if overflow_size > 0: - boundary = tl.full((BLOCK_M,), BLOCK_M - overflow_size, dtype=tl.int32) - # This is a > check because mask being 0 blocks the store. - l_ptrs_mask = boundary > tl.arange(0, BLOCK_M) - tl.store(l_ptrs, m_i + tl.math.log2(l_i), mask=l_ptrs_mask) - else: - tl.store(l_ptrs, m_i + tl.math.log2(l_i)) + #overflow_size = end_m_idx - seqlen_q + #if overflow_size > 0: + # boundary = tl.full((BLOCK_M,), BLOCK_M - overflow_size, dtype=tl.int32) + # # This is a > check because mask being 0 blocks the store. + # l_ptrs_mask = boundary > tl.arange(0, BLOCK_M) + # tl.store(l_ptrs, m_i + tl.math.log2(l_i), mask=l_ptrs_mask) + #else: + # tl.store(l_ptrs, m_i + tl.math.log2(l_i)) # write back O o_offset = off_z * stride_oz + cu_seqlens_q_start * stride_om + off_h_q * stride_oh @@ -494,8 +494,6 @@ def forward(ctx, q, k, v, o, metadata, causal=False, sm_scale=1.0, bias=None): encoded_softmax = None - M = torch.empty((batch, nheads_q, metadata.max_seq_len), device=q.device, dtype=torch.float32) - # Seed the RNG so we get reproducible results for testing. philox_seed = 0x1BF52 philox_offset = 0x1D4B42 @@ -507,7 +505,7 @@ def forward(ctx, q, k, v, o, metadata, causal=False, sm_scale=1.0, bias=None): bias_strides = (0,0,0,0) attn_fwd[grid]( - q, k, v, bias, sm_scale, M, o, + q, k, v, bias, sm_scale, None, o, *q_strides, *k_strides, *v_strides, *o_strides, *bias_strides, None, None, dropout_p=0.0, @@ -526,7 +524,6 @@ def forward(ctx, q, k, v, o, metadata, causal=False, sm_scale=1.0, bias=None): RETURN_ENCODED_SOFTMAX=False ) - ctx.save_for_backward(q, k, v, o, M) ctx.grid = grid ctx.sm_scale = sm_scale ctx.BLOCK_DMODEL = head_size @@ -538,4 +535,4 @@ def forward(ctx, q, k, v, o, metadata, causal=False, sm_scale=1.0, bias=None): ctx.return_encoded_softmax = False return o, encoded_softmax -attention = _attention.apply +triton_attention = _attention.apply