From d444815c02227b2eb9ab27ec09112a98e482f19c Mon Sep 17 00:00:00 2001
From: Jeremy Reizenstein <669761+bottler@users.noreply.github.com>
Date: Thu, 29 Aug 2024 21:57:35 +0000
Subject: [PATCH] export from fbcode (fairinternal/xformers#1212)

__original_commit__ = fairinternal/xformers@a728c49c87d3dbb00dc306cfd5d8e0ba0569d692
---
 .../benchmarks/benchmark_mem_eff_attention.py |   2 +-
 xformers/ops/fmha/flash.py                    |   3 +
 xformers/ops/fmha/flash3.py                   | 203 +++++++++++++++++-
 3 files changed, 205 insertions(+), 3 deletions(-)

diff --git a/xformers/benchmarks/benchmark_mem_eff_attention.py b/xformers/benchmarks/benchmark_mem_eff_attention.py
index a6e848703f..2c7418e9b7 100644
--- a/xformers/benchmarks/benchmark_mem_eff_attention.py
+++ b/xformers/benchmarks/benchmark_mem_eff_attention.py
@@ -72,7 +72,7 @@
 OPS = [
     (xformers.ops.fmha.cutlass.FwOp, xformers.ops.fmha.cutlass.BwOp),
     (xformers.ops.fmha.flash.FwOp, xformers.ops.fmha.flash.BwOp),
-    (xformers.ops.fmha.flash3.FwOp, xformers.ops.fmha.flash.BwOp),
+    (xformers.ops.fmha.flash3.FwOp, xformers.ops.fmha.flash3.BwOp),
     (xformers.ops.fmha.ck.FwOp, xformers.ops.fmha.ck.BwOp),
 ]
 
diff --git a/xformers/ops/fmha/flash.py b/xformers/ops/fmha/flash.py
index 79f398b57b..0df27a432a 100644
--- a/xformers/ops/fmha/flash.py
+++ b/xformers/ops/fmha/flash.py
@@ -42,6 +42,7 @@
 
 FLASH_VERSION = "0.0.0"
 VARLEN_LSE_PACKED = False
+_TRY_PT_FLASH_ATTN = torch.version.hip is None
 _USE_PT_FLASH_ATTN = False
 
 try:
@@ -73,6 +74,8 @@
     )
     VARLEN_LSE_PACKED = True
 except ImportError:
+    if not _TRY_PT_FLASH_ATTN:
+        raise
     assert is_pt_flash_compatible(force=True)
     FLASH_VERSION = torch.nn.attention._get_flash_version()  # type: ignore
     VARLEN_LSE_PACKED = False
diff --git a/xformers/ops/fmha/flash3.py b/xformers/ops/fmha/flash3.py
index 82db6b9876..e03a9fcea8 100644
--- a/xformers/ops/fmha/flash3.py
+++ b/xformers/ops/fmha/flash3.py
@@ -4,12 +4,13 @@
 # LICENSE file in the root directory of this source tree.
 
 
-from typing import Any, Iterable, List, Optional, Set, Tuple
+from typing import Any, Iterable, List, Optional, Sequence, Set, Tuple
 
 import torch
 
 from ..common import get_operator, register_operator
 from .attn_bias import (
+    VARLEN_BIASES,
     BlockDiagonalCausalFromBottomRightMask,
     BlockDiagonalCausalMask,
     BlockDiagonalCausalWithOffsetGappyKeysMask,
@@ -20,7 +21,14 @@
     LowerTriangularFromBottomRightMask,
     LowerTriangularMask,
 )
-from .common import AttentionFwOpBase, Context, Inputs, check_lastdim_alignment_stride1
+from .common import (
+    AttentionBwOpBase,
+    AttentionFwOpBase,
+    Context,
+    Gradients,
+    Inputs,
+    check_lastdim_alignment_stride1,
+)
 from .flash import (
     _check_needs_no_topleft,
     _convert_input_format,
@@ -116,6 +124,99 @@ def mha_fwd_fake(
         lse = query.new_empty(lse_shape, dtype=torch.float32)
         return out, lse
 
+    def _create_dq_dk_dv(
+        grads_share_storage: bool, query, key, value
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        # Create dq,dk,dv
+        # If Q/K/V come from a single QKV tensor, let's put the gradient in the
+        # right strides, so we can avoid a `cat`
+        if grads_share_storage:
+            chunk = torch.empty(
+                (*query.shape[0:-2], 3, query.shape[-2], query.shape[-1]),
+                dtype=query.dtype,
+                device=query.device,
+            )
+            return chunk.select(-3, 0), chunk.select(-3, 1), chunk.select(-3, 2)
+        return torch.empty_like(query), torch.empty_like(key), torch.empty_like(value)
+
+    @torch.library.custom_op(
+        "xformers_flash3::flash_bwd", mutates_args=(), device_types=["cuda"]
+    )
+    def mha_bwd(
+        grads_share_storage: bool,
+        dout: torch.Tensor,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        out: torch.Tensor,
+        softmax_lse: torch.Tensor,
+        cu_seqlens_q: torch.Tensor,
+        cu_seqlens_k: torch.Tensor,
+        max_seqlen_q: int,
+        max_seqlen_k: int,
+        softmax_scale: float,
+        is_causal: bool,
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        dq, dk, dv = _create_dq_dk_dv(grads_share_storage, query, key, value)
+        is_deterministic = False
+        if cu_seqlens_q is None:
+            assert cu_seqlens_k is None
+            dq, dk, dv, softmax_d, *rest = _C_flashattention3.bwd(
+                dout,
+                query,
+                key,
+                value,
+                out,
+                softmax_lse,
+                dq,
+                dk,
+                dv,
+                softmax_scale,
+                is_causal,
+                is_deterministic,
+            )
+        else:
+            dq, dk, dv, softmax_d, *rest = _C_flashattention3.varlen_bwd(
+                dout,
+                query,
+                key,
+                value,
+                out,
+                softmax_lse,
+                dq,
+                dk,
+                dv,
+                cu_seqlens_q,
+                cu_seqlens_k,
+                max_seqlen_q,
+                max_seqlen_k,
+                softmax_scale,
+                is_causal,
+                is_deterministic,
+            )
+        return dq, dk, dv
+
+    @torch.library.register_fake("xformers_flash3::flash_bwd")
+    def mha_bwd_fake(
+        grads_share_storage: bool,
+        dout: torch.Tensor,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        out: torch.Tensor,
+        softmax_lse: torch.Tensor,
+        cu_seqlens_q: torch.Tensor,
+        cu_seqlens_k: torch.Tensor,
+        max_seqlen_q: int,
+        max_seqlen_k: int,
+        softmax_scale: float,
+        is_causal: bool,
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        dq = torch.empty_like(query)
+        dk = torch.empty_like(key)
+        dv = torch.empty_like(value)
+        return dq, dk, dv
+
 
 @register_operator
 class FwOp(AttentionFwOpBase):
@@ -218,3 +319,101 @@ def apply(
             lse=_post_process_lse(softmax_lse, inp, tuple(original_query_shape)),
         )
         return (out, ctx)
+
+
+@register_operator
+class BwOp(AttentionBwOpBase):
+    __doc__ = FwOp.__doc__
+
+    OPERATOR = get_operator("xformers_flash3", "flash_bwd")
+    SUPPORTED_DEVICES = FwOp.SUPPORTED_DEVICES
+    CUDA_MINIMUM_COMPUTE_CAPABILITY = FwOp.CUDA_MINIMUM_COMPUTE_CAPABILITY
+    SUPPORTED_DTYPES = FwOp.SUPPORTED_DTYPES
+    SUPPORTED_MAX_K = FwOp.SUPPORTED_MAX_K
+    SUPPORTED_MIN_K = FwOp.SUPPORTED_MIN_K
+    SUPPORTED_ATTN_BIAS_TYPES = (
+        # Exclude padded or gappy masks, since seqused_k is not supported by the kernel.
+        type(None),
+        LowerTriangularMask,
+        LowerTriangularFromBottomRightMask,
+        BlockDiagonalMask,
+        BlockDiagonalCausalMask,
+        BlockDiagonalCausalFromBottomRightMask,
+    )
+
+    SUPPORTS_DROPOUT = FwOp.SUPPORTS_DROPOUT
+    SUPPORTS_CUSTOM_SCALE = FwOp.SUPPORTS_CUSTOM_SCALE
+    SUPPORTS_DIFFERENT_VALUE_EMBED = FwOp.SUPPORTS_DIFFERENT_VALUE_EMBED
+    IS_DETERMINISTIC = False
+    SUPPORTS_BMGHK = False
+    SUPPORTS_LSE_FORMATS: Sequence[str] = ["", "varlen_flat"]
+    NAME = f"fa3B@{FLASH_VERSION}"
+    VERSION = FLASH_VERSION
+
+    @classmethod
+    def not_supported_reasons(cls, d: Inputs) -> List[str]:
+        reasons = super(BwOp, cls).not_supported_reasons(d)
+        check_lastdim_alignment_stride1(reasons, "query", d.query, 8)
+        _check_needs_no_topleft(d, reasons)
+        if d.query.shape[-1] not in [64, 128]:
+            reasons.append("only head-dim 64 or 128 is supported")
+
+        _check_needs_no_topleft(d, reasons)
+        return reasons
+
+    @classmethod
+    def apply(cls, ctx: Context, inp: Inputs, grad: torch.Tensor) -> Gradients:
+
+        dq_shape, dk_shape, dv_shape = inp.query.shape, inp.key.shape, inp.value.shape
+        (
+            inp,
+            cu_seqlens_q,
+            max_seqlen_q,
+            cu_seqlens_k,
+            max_seqlen_k,
+            _,  # seqused_k,
+        ) = _convert_input_format(inp, supports_mqa=False)
+        ctx_lse = ctx.lse
+
+        if isinstance(inp.attn_bias, VARLEN_BIASES):
+            assert ctx_lse.shape[0] == 1
+            ctx_lse = ctx_lse[0]
+        else:
+            # NOTE: cutlass pads the last dimension, we need to slice it
+            assert ctx_lse.shape[2] >= max_seqlen_q
+            ctx_lse = ctx_lse[:, :, :max_seqlen_q].contiguous()
+
+        kernel_out_shape = [
+            *inp.query.shape[:-1],
+            inp.value.shape[-1],
+        ]
+        assert grad.dtype in cls.SUPPORTED_DTYPES
+
+        if inp.query.numel() and inp.key.numel():
+            dq, dk, dv = cls.OPERATOR(
+                ctx.qkv_share_storage,
+                grad.reshape(kernel_out_shape).contiguous(),
+                inp.query,
+                inp.key,
+                inp.value,
+                ctx.out.reshape(kernel_out_shape),
+                ctx.lse,
+                cu_seqlens_q,
+                cu_seqlens_k,
+                max_seqlen_q,
+                max_seqlen_k,
+                softmax_scale=inp.scale_float,
+                is_causal=_is_causal(inp.attn_bias),
+            )
+            grads = Gradients(dq, dk, dv)
+        else:
+            grads = Gradients(
+                dq=torch.zeros_like(inp.query),
+                dk=torch.zeros_like(inp.key),
+                dv=torch.zeros_like(inp.value),
+            )
+
+        grads.dq = grads.dq.reshape(dq_shape)
+        grads.dk = grads.dk.reshape(dk_shape)
+        grads.dv = grads.dv.reshape(dv_shape)
+        return grads
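
Reviewer note, not part of the patch: a minimal smoke test for the new flash3 backward path could look like the sketch below. It assumes an xformers build with the FlashAttention-3 extension compiled and an SM90 (H100-class) GPU; the shapes, dtype, and LowerTriangularMask bias are illustrative only, with the head dimension kept at 128 to satisfy BwOp.not_supported_reasons.

import torch

import xformers.ops.fmha as fmha
from xformers.ops.fmha import flash3

B, M, H, K = 2, 1024, 8, 128  # the new BwOp only supports head dims 64 and 128
q, k, v = (
    torch.randn(B, M, H, K, device="cuda", dtype=torch.bfloat16, requires_grad=True)
    for _ in range(3)
)

# Route both the forward and the backward pass through FlashAttention-3.
out = fmha.memory_efficient_attention(
    q,
    k,
    v,
    attn_bias=fmha.attn_bias.LowerTriangularMask(),
    op=(flash3.FwOp, flash3.BwOp),
)
out.sum().backward()
assert q.grad.shape == q.shape and k.grad.shape == k.shape and v.grad.shape == v.shape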