Use new custom_op for FlashAttention ops
ghstack-source-id: b9470d2381bdc6d653ad7ddfa777d4ba3efa3ac4
Pull Request resolved: fairinternal/xformers#1200

__original_commit__ = fairinternal/xformers@9227a37
lw authored and xFormers Bot committed Aug 22, 2024
1 parent e639746 commit 1f6242d
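
For context, the commit moves from the string-schema registration (torch.library.define plus torch.library.impl and torch.library.impl_abstract) to the torch.library.custom_op and torch.library.register_fake decorators introduced in PyTorch 2.4. The snippet below is a minimal, hypothetical sketch of that pattern, not the xFormers code; the op name mylib::toy_attention, its arguments, and its body are invented for illustration.

import torch


@torch.library.custom_op(
    "mylib::toy_attention", mutates_args=(), device_types=["cuda"]
)
def toy_attention(
    query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, scale: float
) -> torch.Tensor:
    # Eager CUDA implementation. The operator schema is inferred from the type
    # annotations, so no separate schema string (torch.library.define) is needed.
    attn = torch.softmax(scale * query @ key.transpose(-2, -1), dim=-1)
    return attn @ value


@torch.library.register_fake("mylib::toy_attention")
def toy_attention_fake(
    query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, scale: float
) -> torch.Tensor:
    # Shape/dtype-only implementation used under FakeTensor tracing and
    # torch.compile; this plays the role previously filled by
    # torch.library.impl_abstract.
    return query.new_empty(query.shape[:-1] + (value.shape[-1],))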
Showing 1 changed file with 44 additions and 57 deletions: xformers/ops/fmha/flash.py
@@ -78,45 +78,28 @@
 VARLEN_LSE_PACKED = False
 _USE_PT_FLASH_ATTN = True

-# create library so that flash-attn goes through the PyTorch Dispatcher
-torch.library.define(
+@torch.library.custom_op(
     "xformers_flash::flash_fwd",
-    "(Tensor query, Tensor key, Tensor value, "
-    "Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, Tensor? seqused_k, "
-    "int max_seqlen_q, int max_seqlen_k, "
-    "float p, float softmax_scale, "
-    "bool is_causal, int window_left, "
-    "int window_right, bool return_softmax, Tensor? block_tables) -> (Tensor, Tensor, Tensor)",
+    mutates_args=(),
+    device_types=["cuda"],
 )
-
-torch.library.define(
-    "xformers_flash::flash_bwd",
-    "(bool grads_share_storage, Tensor dout, Tensor query, Tensor key, Tensor value, "
-    "Tensor out, Tensor softmax_lse_, "
-    "Tensor cu_seqlens_q, Tensor cu_seqlens_k, "
-    "int max_seqlen_q, int max_seqlen_k, "
-    "float p, float softmax_scale, bool is_causal, "
-    "int window_left, int window_right, Tensor rng_state) -> (Tensor dq, Tensor dk, Tensor dv)",
-)
-
-@torch.library.impl("xformers_flash::flash_fwd", "default")
 def _flash_fwd(
-    query,
-    key,
-    value,
-    cu_seq_lens_q,
-    cu_seq_lens_k,
-    seqused_k,
-    max_seq_len_q,
-    max_seq_len_k,
-    p,
-    softmax_scale,
-    is_causal,
-    window_left,
-    window_right,
-    return_softmax,
-    block_tables,
-):
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    cu_seq_lens_q: Optional[torch.Tensor],
+    cu_seq_lens_k: Optional[torch.Tensor],
+    seqused_k: Optional[torch.Tensor],
+    max_seq_len_q: int,
+    max_seq_len_k: int,
+    p: float,
+    softmax_scale: float,
+    is_causal: bool,
+    window_left: int,
+    window_right: int,
+    return_softmax: bool,
+    block_tables: Optional[torch.Tensor],
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
     softcap = 0.0
     if _USE_PT_FLASH_ATTN:
         (
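
With torch.library.custom_op, the operator schema is inferred from the Python type annotations on the decorated function, which is why the hunk above adds annotations to the parameters and deletes the explicit schema strings. A small, made-up example of the same mechanism (mylib::scaled_add is hypothetical, not a real op):

import torch


@torch.library.custom_op("mylib::scaled_add", mutates_args=())
def scaled_add(x: torch.Tensor, y: torch.Tensor, alpha: float) -> torch.Tensor:
    # Eager implementation; registered for all device types because
    # device_types is left unrestricted here.
    return x + alpha * y


# Calls route through the PyTorch dispatcher under the registered name.
out = torch.ops.mylib.scaled_add(torch.randn(8), torch.randn(8), 0.5)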
@@ -207,7 +190,7 @@ def _flash_fwd(
         )
     return out, softmax_lse, rng_state

-@torch.library.impl_abstract("xformers_flash::flash_fwd")
+@torch.library.register_fake("xformers_flash::flash_fwd")
 def _flash_fwd_abstract(
     query,
     key,
@@ -240,26 +223,30 @@ def _flash_fwd_abstract
     rng_state = torch.empty([2], device=query.device, dtype=torch.int64)
     return out, softmax_lse, rng_state

-@torch.library.impl("xformers_flash::flash_bwd", "default")
+@torch.library.custom_op(
+    "xformers_flash::flash_bwd",
+    mutates_args=(),
+    device_types=["cuda"],
+)
 def _flash_bwd(
-    grads_share_storage,
-    grad,
-    query,
-    key,
-    value,
-    out,
-    lse,
-    cu_seq_lens_q,
-    cu_seq_lens_k,
-    max_seq_len_q,
-    max_seq_len_k,
-    p,
-    softmax_scale,
-    is_causal,
-    window_left,
-    window_right,
-    rng_state,
-):
+    grads_share_storage: bool,
+    grad: torch.Tensor,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    out: torch.Tensor,
+    lse: torch.Tensor,
+    cu_seq_lens_q: torch.Tensor,
+    cu_seq_lens_k: torch.Tensor,
+    max_seq_len_q: int,
+    max_seq_len_k: int,
+    p: float,
+    softmax_scale: float,
+    is_causal: bool,
+    window_left: int,
+    window_right: int,
+    rng_state: torch.Tensor,
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
     softcap = 0.0
     if _USE_PT_FLASH_ATTN:
         assert softcap == 0.0
@@ -341,7 +328,7 @@ def _flash_bwd(
         )
     return dq, dk, dv

-@torch.library.impl_abstract("xformers_flash::flash_bwd")
+@torch.library.register_fake("xformers_flash::flash_bwd")
 def _flash_bwd_abstract(
     grads_share_storage,
     grad,
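Both ops keep their dispatcher names, so after this change they are still reachable as torch.ops.xformers_flash.flash_fwd and torch.ops.xformers_flash.flash_bwd. The sketch below shows what a direct call to the forward op could look like; it assumes a CUDA device, an xFormers build with the flash-attn kernels available, illustrative shapes in (batch, seqlen, heads, head_dim) layout, and the flash-attn convention of -1 for "no sliding window". This code is not part of the commit.

import torch

import xformers.ops.fmha.flash  # noqa: F401  (importing registers the xformers_flash ops)

if torch.cuda.is_available():
    B, M, H, K = 2, 128, 8, 64  # batch, seqlen, heads, head dim (illustrative)
    q = torch.randn(B, M, H, K, device="cuda", dtype=torch.float16)
    k = torch.randn_like(q)
    v = torch.randn_like(q)
    out, lse, rng_state = torch.ops.xformers_flash.flash_fwd(
        q, k, v,           # dense (non-varlen) inputs
        None, None, None,  # cu_seqlens_q, cu_seqlens_k, seqused_k
        M, M,              # max_seqlen_q, max_seqlen_k
        0.0,               # dropout probability p
        K**-0.5,           # softmax_scale
        False,             # is_causal
        -1, -1,            # window_left, window_right (-1: no sliding window)
        False,             # return_softmax
        None,              # block_tables
    )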
