Rebase updates and PR review changes
Added a flag for controlling the Triton vs. default flow.
More small changes to the Dockerfile.
jpvillam-amd committed Mar 19, 2024
1 parent c89c0e3 commit d4cb905
Showing 4 changed files with 70 additions and 48 deletions.
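The flag referred to in the commit message is the new VLLM_USE_FLASH_ATTN_TRITON environment variable read in vllm/model_executor/layers/attention/attention.py below. A minimal opt-in sketch, assuming any non-empty value counts as "set" (which is what the os.environ.get() checks in the diff imply) and that the variable is exported before the engine constructs its Attention layers, since _use_flash_attn() is cached with lru_cache:

import os

# Hedged usage sketch, not part of this commit: enable the Triton
# flash-attention path on ROCm. Any non-empty string is treated as "set";
# leave the variable unset to keep the previous flash_attn/xformers selection.
os.environ["VLLM_USE_FLASH_ATTN_TRITON"] = "1"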
2 changes: 1 addition & 1 deletion Dockerfile.rocm
@@ -105,7 +105,7 @@ RUN if [ "$BUILD_TRITON" = "1" ]; then \
&& pip uninstall -y triton \
&& git clone https://github.com/ROCmSoftwarePlatform/triton.git \
&& cd triton/python \
&& pip3 install -e . \
&& pip3 install . \
&& cd ../..; \
fi

47 changes: 30 additions & 17 deletions vllm/model_executor/layers/attention/attention.py
@@ -8,6 +8,7 @@
from vllm.logger import init_logger
from vllm.model_executor.input_metadata import InputMetadata
from vllm.utils import is_hip
import os

logger = init_logger(__name__)

@@ -34,11 +35,12 @@ def __init__(
sliding_window: Optional[int] = None,
) -> None:
super().__init__()
if _use_flash_attn():
if use_triton := _use_flash_attn():
from vllm.model_executor.layers.attention.backends.flash_attn import FlashAttentionBackend # noqa: E501
self.backend = FlashAttentionBackend(num_heads, head_size, scale,
num_kv_heads, alibi_slopes,
sliding_window)
sliding_window,
use_triton == 2)
else:
from vllm.model_executor.layers.attention.backends.xformers import XFormersBackend # noqa: E501
self.backend = XFormersBackend(num_heads, head_size, scale,
@@ -59,26 +61,37 @@ def forward(


@lru_cache(maxsize=1)
def _use_flash_attn() -> bool:
try:
import flash_attn # noqa: F401
except ImportError:
logger.info("flash_attn is not found. Using xformers backend.")
return False

if is_hip():
# AMD GPUs.
return False
if torch.cuda.get_device_capability()[0] < 8:
def _use_flash_attn() -> int:
"""Returns if and which flash attention to use.
Returns:
int: 0 for none, 1 for default implementation, 2 for triton implementation.
"""
if not (os.environ.get('VLLM_USE_FLASH_ATTN_TRITON') and is_hip()):
# AMD GPUs can use flash_attn package or triton impl.
try:
import flash_attn # noqa: F401
except ImportError:
logger.info("flash_attn is not found. Using xformers backend.")
return 0

if (not is_hip()) and torch.cuda.get_device_capability()[0] < 8:
# Volta and Turing NVIDIA GPUs.
logger.info("flash_attn is not supported on Turing or older GPUs. "
"Using xformers backend.")
return False
return 0

if is_hip() and torch.cuda.get_device_capability()[0] != 9:
# Not an Instinct-series GPU.
logger.info("flash_attn is not supported on Navi GPUs. "
"Using xformers backend.")
return 0

if torch.get_default_dtype() not in (torch.float16, torch.bfloat16):
logger.info(
"flash_attn only supports torch.float16 or torch.bfloat16. "
"Using xformers backend.")
return False
return 0

logger.info("Using flash_attn backend.")
return True
logger.info(f"Using {'Triton' if os.environ.get('VLLM_USE_FLASH_ATTN_TRITON') else ''} flash_attn backend.")
return 2 if os.environ.get('VLLM_USE_FLASH_ATTN_TRITON') else 1
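One easy-to-miss consequence of the guard at the top of _use_flash_attn(): when the flag is set on ROCm, the flash_attn import probe is skipped entirely, so the Triton kernel can be selected even if the flash_attn package is not installed. A hedged sanity check, assuming a ROCm Instinct-class (gfx9) GPU, a float16 default dtype, and this branch of vLLM installed:

import os

import torch

os.environ["VLLM_USE_FLASH_ATTN_TRITON"] = "1"  # must be set before the first (cached) call
torch.set_default_dtype(torch.float16)          # the dtype check above requires fp16/bf16

from vllm.utils import is_hip
from vllm.model_executor.layers.attention.attention import _use_flash_attn

assert is_hip(), "this sketch assumes a ROCm build of vLLM"
print(_use_flash_attn())  # expected: 2, i.e. the Triton flash-attention backend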
36 changes: 24 additions & 12 deletions vllm/model_executor/layers/attention/backends/flash_attn.py
@@ -8,7 +8,7 @@
from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.attention.ops.paged_attn import (
PagedAttentionImpl)
from vllm.model_executor.layers.attention.ops.flash_attention_triton import attention
from vllm.model_executor.layers.attention.ops.flash_attention_triton import triton_attention


class FlashAttentionBackend:
@@ -21,6 +21,7 @@ def __init__(
num_kv_heads: Optional[int] = None,
alibi_slopes: Optional[List[float]] = None,
sliding_window: Optional[int] = None,
use_triton: Optional[bool] = False,
) -> None:
self.num_heads = num_heads
self.head_size = head_size
@@ -30,6 +31,7 @@ def __init__(
if alibi_slopes is not None:
alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
self.alibi_slopes = alibi_slopes
self.use_triton = use_triton

assert self.num_heads % self.num_kv_heads == 0
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
@@ -87,8 +89,8 @@ def forward(
query = query.unflatten(0, (batch_size, seq_len))
key = key.unflatten(0, (batch_size, seq_len))
value = value.unflatten(0, (batch_size, seq_len))
if is_hip():
output, _ = attention(
if self.use_triton:
output, _ = triton_attention(
query,
key,
value,
Expand All @@ -98,15 +100,25 @@ def forward(
self.scale,
)
else:
output = flash_attn_func(
query,
key,
value,
softmax_scale=self.scale,
causal=True,
window_size=self.sliding_window,
alibi_slopes=self.alibi_slopes,
)
if is_hip():
#XXX: window_size and alibi_slopes not supported
output = flash_attn_func(
query,
key,
value,
softmax_scale=self.scale,
causal=True,
)
else:
output = flash_attn_func(
query,
key,
value,
softmax_scale=self.scale,
causal=True,
window_size=self.sliding_window,
alibi_slopes=self.alibi_slopes,
)
else:
# prefix-enabled attention
output = PagedAttentionImpl.forward_prefix(
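For context, a construction sketch for the updated backend; the head counts and sizes below are illustrative, not taken from this diff. Note that on ROCm the non-Triton flash_attn branch above deliberately drops window_size and alibi_slopes (see the XXX comment), so those arguments only take effect on the CUDA path.

from vllm.model_executor.layers.attention.backends.flash_attn import FlashAttentionBackend

# Illustrative dimensions only; use_triton is the parameter added in this commit.
backend = FlashAttentionBackend(
    num_heads=32,
    head_size=128,
    scale=128 ** -0.5,
    num_kv_heads=8,
    alibi_slopes=None,
    sliding_window=None,
    use_triton=True,  # route prompt attention through the Triton kernel
)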
33 changes: 15 additions & 18 deletions vllm/model_executor/layers/attention/ops/flash_attention_triton.py
@@ -251,12 +251,12 @@ def attn_fwd(
)
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=Out.type.element_ty)
# We still need to write 0s to the result
tl.store(O_block_ptr, acc.to(Out.type.element_ty), boundary_check=(0,1))
l_ptrs = L + off_z * hq * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q + offs_m
#tl.store(O_block_ptr, acc.to(Out.type.element_ty), boundary_check=(0,1))
#l_ptrs = L + off_z * hq * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q + offs_m
# We store inf to LSE, not -inf because in the bwd pass, we subtract this
# from qk which makes it -inf, such that exp(qk - inf) = 0 for these masked blocks.
l = tl.full([BLOCK_M], value=float("inf"), dtype=tl.float32)
tl.store(l_ptrs, l)
#l = tl.full([BLOCK_M], value=float("inf"), dtype=tl.float32)
#tl.store(l_ptrs, l)
# TODO: Should dropout and return encoded softmax be handled here too?
return

@@ -417,17 +417,17 @@ def attn_fwd(
z = 0.0
acc = tl.where(out_ptrs_mask, acc, z.to(acc.type.element_ty))
# write back LSE
l_ptrs = L + off_z * hq * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q + offs_m
#l_ptrs = L + off_z * hq * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q + offs_m
# If seqlen_q not multiple of BLOCK_M, we need to mask out the last few rows.
# This is only true for the last M block. For others, overflow_size will be -ve
overflow_size = end_m_idx - seqlen_q
if overflow_size > 0:
boundary = tl.full((BLOCK_M,), BLOCK_M - overflow_size, dtype=tl.int32)
# This is a > check because mask being 0 blocks the store.
l_ptrs_mask = boundary > tl.arange(0, BLOCK_M)
tl.store(l_ptrs, m_i + tl.math.log2(l_i), mask=l_ptrs_mask)
else:
tl.store(l_ptrs, m_i + tl.math.log2(l_i))
#overflow_size = end_m_idx - seqlen_q
#if overflow_size > 0:
# boundary = tl.full((BLOCK_M,), BLOCK_M - overflow_size, dtype=tl.int32)
# # This is a > check because mask being 0 blocks the store.
# l_ptrs_mask = boundary > tl.arange(0, BLOCK_M)
# tl.store(l_ptrs, m_i + tl.math.log2(l_i), mask=l_ptrs_mask)
#else:
# tl.store(l_ptrs, m_i + tl.math.log2(l_i))

# write back O
o_offset = off_z * stride_oz + cu_seqlens_q_start * stride_om + off_h_q * stride_oh
@@ -494,8 +494,6 @@ def forward(ctx, q, k, v, o, metadata, causal=False, sm_scale=1.0, bias=None):

encoded_softmax = None

M = torch.empty((batch, nheads_q, metadata.max_seq_len), device=q.device, dtype=torch.float32)

# Seed the RNG so we get reproducible results for testing.
philox_seed = 0x1BF52
philox_offset = 0x1D4B42
@@ -507,7 +505,7 @@ def forward(ctx, q, k, v, o, metadata, causal=False, sm_scale=1.0, bias=None):
bias_strides = (0,0,0,0)

attn_fwd[grid](
q, k, v, bias, sm_scale, M, o,
q, k, v, bias, sm_scale, None, o,
*q_strides, *k_strides, *v_strides, *o_strides, *bias_strides,
None, None,
dropout_p=0.0,
@@ -526,7 +524,6 @@ def forward(ctx, q, k, v, o, metadata, causal=False, sm_scale=1.0, bias=None):
RETURN_ENCODED_SOFTMAX=False
)

ctx.save_for_backward(q, k, v, o, M)
ctx.grid = grid
ctx.sm_scale = sm_scale
ctx.BLOCK_DMODEL = head_size
@@ -538,4 +535,4 @@ def forward(ctx, q, k, v, o, metadata, causal=False, sm_scale=1.0, bias=None):
ctx.return_encoded_softmax = False
return o, encoded_softmax

attention = _attention.apply
triton_attention = _attention.apply
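Two notes on this file for downstream callers: the public symbol is renamed from attention to triton_attention (flash_attn.py above already imports the new name), and with the M tensor, its LSE stores, and ctx.save_for_backward removed, the autograd Function is effectively forward-only. The wrapper still returns an (output, encoded_softmax) pair, with encoded_softmax left as None. A minimal import sketch for call sites outside this diff:

# Import the renamed symbol; the old binding `attention` no longer exists here.
from vllm.model_executor.layers.attention.ops.flash_attention_triton import triton_attention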
