
Commit

export from fbcode (fairinternal/xformers#1212)
__original_commit__ = fairinternal/xformers@a728c49
bottler authored and xFormers Bot committed Aug 29, 2024
1 parent a974032 commit d444815
Showing 3 changed files with 205 additions and 3 deletions.
2 changes: 1 addition & 1 deletion xformers/benchmarks/benchmark_mem_eff_attention.py
@@ -72,7 +72,7 @@
OPS = [
    (xformers.ops.fmha.cutlass.FwOp, xformers.ops.fmha.cutlass.BwOp),
    (xformers.ops.fmha.flash.FwOp, xformers.ops.fmha.flash.BwOp),
-    (xformers.ops.fmha.flash3.FwOp, xformers.ops.fmha.flash.BwOp),
+    (xformers.ops.fmha.flash3.FwOp, xformers.ops.fmha.flash3.BwOp),
    (xformers.ops.fmha.ck.FwOp, xformers.ops.fmha.ck.BwOp),
]

3 changes: 3 additions & 0 deletions xformers/ops/fmha/flash.py
@@ -42,6 +42,7 @@

FLASH_VERSION = "0.0.0"
VARLEN_LSE_PACKED = False
+_TRY_PT_FLASH_ATTN = torch.version.hip is None
_USE_PT_FLASH_ATTN = False

try:
@@ -73,6 +74,8 @@
    )
    VARLEN_LSE_PACKED = True
except ImportError:
+    if not _TRY_PT_FLASH_ATTN:
+        raise
    assert is_pt_flash_compatible(force=True)
    FLASH_VERSION = torch.nn.attention._get_flash_version()  # type: ignore
    VARLEN_LSE_PACKED = False
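Taken together, the two hunks above make the PyTorch-flash fallback CUDA-only: on ROCm builds the original ImportError is re-raised instead of silently switching backends. A minimal standalone sketch of that control flow follows; the `flash_attn` import and the final `_USE_PT_FLASH_ATTN = True` assignment are illustrative assumptions, only the guard and the re-raise are shown in the diff itself.

```python
import torch

# Only consider PyTorch's built-in flash kernels as a fallback on CUDA builds.
_TRY_PT_FLASH_ATTN = torch.version.hip is None
_USE_PT_FLASH_ATTN = False

try:
    import flash_attn  # external flash-attention package (illustrative import)

    FLASH_VERSION = flash_attn.__version__
except ImportError:
    if not _TRY_PT_FLASH_ATTN:
        # ROCm build: no PyTorch-flash fallback, surface the original error.
        raise
    # CUDA build without flash_attn installed: fall back to PyTorch's kernels.
    FLASH_VERSION = torch.nn.attention._get_flash_version()  # type: ignore
    _USE_PT_FLASH_ATTN = True  # assumed; this assignment is outside the visible hunk
```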
203 changes: 201 additions & 2 deletions xformers/ops/fmha/flash3.py
@@ -4,12 +4,13 @@
# LICENSE file in the root directory of this source tree.


-from typing import Any, Iterable, List, Optional, Set, Tuple
+from typing import Any, Iterable, List, Optional, Sequence, Set, Tuple

import torch

from ..common import get_operator, register_operator
from .attn_bias import (
+    VARLEN_BIASES,
    BlockDiagonalCausalFromBottomRightMask,
    BlockDiagonalCausalMask,
    BlockDiagonalCausalWithOffsetGappyKeysMask,
@@ -20,7 +21,14 @@
    LowerTriangularFromBottomRightMask,
    LowerTriangularMask,
)
-from .common import AttentionFwOpBase, Context, Inputs, check_lastdim_alignment_stride1
+from .common import (
+    AttentionBwOpBase,
+    AttentionFwOpBase,
+    Context,
+    Gradients,
+    Inputs,
+    check_lastdim_alignment_stride1,
+)
from .flash import (
    _check_needs_no_topleft,
    _convert_input_format,
@@ -116,6 +124,99 @@ def mha_fwd_fake(
    lse = query.new_empty(lse_shape, dtype=torch.float32)
    return out, lse

def _create_dq_dk_dv(
    grads_share_storage: bool, query, key, value
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    # Create dq,dk,dv
    # If Q/K/V come from a single QKV tensor, let's put the gradient in the
    # right strides, so we can avoid a `cat`
    if grads_share_storage:
        chunk = torch.empty(
            (*query.shape[0:-2], 3, query.shape[-2], query.shape[-1]),
            dtype=query.dtype,
            device=query.device,
        )
        return chunk.select(-3, 0), chunk.select(-3, 1), chunk.select(-3, 2)
    return torch.empty_like(query), torch.empty_like(key), torch.empty_like(value)

@torch.library.custom_op(
    "xformers_flash3::flash_bwd", mutates_args=(), device_types=["cuda"]
)
def mha_bwd(
    grads_share_storage: bool,
    dout: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    out: torch.Tensor,
    softmax_lse: torch.Tensor,
    cu_seqlens_q: torch.Tensor,
    cu_seqlens_k: torch.Tensor,
    max_seqlen_q: int,
    max_seqlen_k: int,
    softmax_scale: float,
    is_causal: bool,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    dq, dk, dv = _create_dq_dk_dv(grads_share_storage, query, key, value)
    is_deterministic = False
    if cu_seqlens_q is None:
        assert cu_seqlens_k is None
        dq, dk, dv, softmax_d, *rest = _C_flashattention3.bwd(
            dout,
            query,
            key,
            value,
            out,
            softmax_lse,
            dq,
            dk,
            dv,
            softmax_scale,
            is_causal,
            is_deterministic,
        )
    else:
        dq, dk, dv, softmax_d, *rest = _C_flashattention3.varlen_bwd(
            dout,
            query,
            key,
            value,
            out,
            softmax_lse,
            dq,
            dk,
            dv,
            cu_seqlens_q,
            cu_seqlens_k,
            max_seqlen_q,
            max_seqlen_k,
            softmax_scale,
            is_causal,
            is_deterministic,
        )
    return dq, dk, dv

@torch.library.register_fake("xformers_flash3::flash_bwd")
def mha_bwd_fake(
    grads_share_storage: bool,
    dout: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    out: torch.Tensor,
    softmax_lse: torch.Tensor,
    cu_seqlens_q: torch.Tensor,
    cu_seqlens_k: torch.Tensor,
    max_seqlen_q: int,
    max_seqlen_k: int,
    softmax_scale: float,
    is_causal: bool,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    dq = torch.empty_like(query)
    dk = torch.empty_like(key)
    dv = torch.empty_like(value)
    return dq, dk, dv
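`mha_bwd` and `mha_bwd_fake` above follow PyTorch's custom-op pattern: the real implementation is registered for CUDA, and a fake (meta) kernel only allocates correctly shaped outputs so FakeTensor tracing and `torch.compile` can reason about the op without launching the kernel. A toy, self-contained illustration of the same pattern (the `mylib::toy_matmul` op is invented for this example and is not part of the commit):

```python
import torch


@torch.library.custom_op("mylib::toy_matmul", mutates_args=(), device_types=["cuda"])
def toy_matmul(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # Real kernel: dispatched only for CUDA tensors.
    return a @ b


@torch.library.register_fake("mylib::toy_matmul")
def toy_matmul_fake(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # Fake kernel: produce an output with the right shape/dtype/device without
    # doing any work, mirroring what mha_bwd_fake does for dq/dk/dv.
    return a.new_empty((a.shape[0], b.shape[1]))
```

Under tracing, calls to `torch.ops.mylib.toy_matmul` resolve shapes through the fake kernel; at runtime they dispatch to the real CUDA implementation.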


@register_operator
class FwOp(AttentionFwOpBase):
@@ -218,3 +319,101 @@ def apply(
            lse=_post_process_lse(softmax_lse, inp, tuple(original_query_shape)),
        )
        return (out, ctx)


@register_operator
class BwOp(AttentionBwOpBase):
    __doc__ = FwOp.__doc__

    OPERATOR = get_operator("xformers_flash3", "flash_bwd")
    SUPPORTED_DEVICES = FwOp.SUPPORTED_DEVICES
    CUDA_MINIMUM_COMPUTE_CAPABILITY = FwOp.CUDA_MINIMUM_COMPUTE_CAPABILITY
    SUPPORTED_DTYPES = FwOp.SUPPORTED_DTYPES
    SUPPORTED_MAX_K = FwOp.SUPPORTED_MAX_K
    SUPPORTED_MIN_K = FwOp.SUPPORTED_MIN_K
    SUPPORTED_ATTN_BIAS_TYPES = (
        # Exclude padded or gappy masks, since seqused_k is not supported by the kernel.
        type(None),
        LowerTriangularMask,
        LowerTriangularFromBottomRightMask,
        BlockDiagonalMask,
        BlockDiagonalCausalMask,
        BlockDiagonalCausalFromBottomRightMask,
    )

    SUPPORTS_DROPOUT = FwOp.SUPPORTS_DROPOUT
    SUPPORTS_CUSTOM_SCALE = FwOp.SUPPORTS_CUSTOM_SCALE
    SUPPORTS_DIFFERENT_VALUE_EMBED = FwOp.SUPPORTS_DIFFERENT_VALUE_EMBED
    IS_DETERMINISTIC = False
    SUPPORTS_BMGHK = False
    SUPPORTS_LSE_FORMATS: Sequence[str] = ["", "varlen_flat"]
    NAME = f"fa3B@{FLASH_VERSION}"
    VERSION = FLASH_VERSION

    @classmethod
    def not_supported_reasons(cls, d: Inputs) -> List[str]:
        reasons = super(BwOp, cls).not_supported_reasons(d)
        check_lastdim_alignment_stride1(reasons, "query", d.query, 8)
        _check_needs_no_topleft(d, reasons)
        if d.query.shape[-1] not in [64, 128]:
            reasons.append("only head-dim 64 or 128 is supported")

        _check_needs_no_topleft(d, reasons)
        return reasons

    @classmethod
    def apply(cls, ctx: Context, inp: Inputs, grad: torch.Tensor) -> Gradients:

        dq_shape, dk_shape, dv_shape = inp.query.shape, inp.key.shape, inp.value.shape
        (
            inp,
            cu_seqlens_q,
            max_seqlen_q,
            cu_seqlens_k,
            max_seqlen_k,
            _,  # seqused_k,
        ) = _convert_input_format(inp, supports_mqa=False)
        ctx_lse = ctx.lse

        if isinstance(inp.attn_bias, VARLEN_BIASES):
            assert ctx_lse.shape[0] == 1
            ctx_lse = ctx_lse[0]
        else:
            # NOTE: cutlass pads the last dimension, we need to slice it
            assert ctx_lse.shape[2] >= max_seqlen_q
            ctx_lse = ctx_lse[:, :, :max_seqlen_q].contiguous()

        kernel_out_shape = [
            *inp.query.shape[:-1],
            inp.value.shape[-1],
        ]
        assert grad.dtype in cls.SUPPORTED_DTYPES

        if inp.query.numel() and inp.key.numel():
            dq, dk, dv = cls.OPERATOR(
                ctx.qkv_share_storage,
                grad.reshape(kernel_out_shape).contiguous(),
                inp.query,
                inp.key,
                inp.value,
                ctx.out.reshape(kernel_out_shape),
                ctx.lse,
                cu_seqlens_q,
                cu_seqlens_k,
                max_seqlen_q,
                max_seqlen_k,
                softmax_scale=inp.scale_float,
                is_causal=_is_causal(inp.attn_bias),
            )
            grads = Gradients(dq, dk, dv)
        else:
            grads = Gradients(
                dq=torch.zeros_like(inp.query),
                dk=torch.zeros_like(inp.key),
                dv=torch.zeros_like(inp.value),
            )

        grads.dq = grads.dq.reshape(dq_shape)
        grads.dk = grads.dk.reshape(dk_shape)
        grads.dv = grads.dv.reshape(dv_shape)
        return grads
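With `FwOp` and `BwOp` both registered, the new backward can be exercised end to end through the dispatcher. A hedged usage sketch follows; it assumes a GPU where flash3 is available (e.g. an H100-class device), and the shapes, dtype, and head dim of 128 are illustrative choices that satisfy the head-dim check in `not_supported_reasons` above:

```python
import torch

import xformers.ops as xops
from xformers.ops.fmha import flash3
from xformers.ops.fmha.attn_bias import LowerTriangularMask

# Illustrative shapes: [batch, seqlen, heads, head_dim] with head_dim=128,
# since BwOp above only supports head dims 64 and 128.
q, k, v = (
    torch.randn(
        2, 1024, 8, 128, device="cuda", dtype=torch.bfloat16, requires_grad=True
    )
    for _ in range(3)
)

out = xops.memory_efficient_attention(
    q,
    k,
    v,
    attn_bias=LowerTriangularMask(),  # causal mask, supported by FwOp and BwOp
    op=(flash3.FwOp, flash3.BwOp),    # the pairing added to the benchmark above
)
out.sum().backward()  # routes through xformers_flash3::flash_bwd
print(q.grad.shape, k.grad.shape, v.grad.shape)
```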
