Skip to content

Commit

Permalink
integrated flash attention 2
Browse files Browse the repository at this point in the history
  • Loading branch information
a663E-36z1120 committed Sep 18, 2023
1 parent 70af6e8 commit f386a4b
Show file tree
Hide file tree
Showing 3 changed files with 332 additions and 7 deletions.
324 changes: 323 additions & 1 deletion megatron/model/flash_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import torch.nn.functional as F

from flash_attn import flash_attn_triton
import flash_attn_cuda
import flash_attn_2_cuda as flash_attn_cuda # For flash_attn version 2.1.1


def flash_attn_unpadded_unpacked_func_triton(
Expand Down Expand Up @@ -462,3 +462,325 @@ def flash_attn_unpadded_func_cuda(
causal,
return_attn_probs,
)


# For flash-attention 2 integration
def _flash_attn_varlen_forward(
q,
k,
v,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
dropout_p,
softmax_scale,
causal,
return_softmax,
):
maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_cuda.varlen_fwd(
q,
k,
v,
None,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
dropout_p,
softmax_scale,
False,
causal,
return_softmax,
None,
)
# if out.isnan().any() or softmax_lse.isnan().any():
# breakpoint()
return out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state


def _flash_attn_varlen_backward(
dout,
q,
k,
v,
out,
softmax_lse,
dq,
dk,
dv,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
dropout_p,
softmax_scale,
causal,
rng_state=None,
):
maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
# dq, dk, dv are allocated by us so they should already be contiguous
dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
dq, dk, dv, softmax_d, = flash_attn_cuda.varlen_bwd(
dout,
q,
k,
v,
out,
softmax_lse,
dq,
dk,
dv,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
dropout_p,
softmax_scale,
False,
causal,
None,
rng_state,
)
# if dk.isnan().any() or dk.isnan().any() or dv.isnan().any() or softmax_d.isnan().any():
# breakpoint()
return dq, dk, dv, softmax_d


class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function):
@staticmethod
def forward(ctx, qkv, cu_seqlens, max_seqlen, dropout_p, softmax_scale, causal, return_softmax):
if softmax_scale is None:
softmax_scale = qkv.shape[-1] ** (-0.5)
out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_varlen_forward(
qkv[:, 0],
qkv[:, 1],
qkv[:, 2],
cu_seqlens,
cu_seqlens,
max_seqlen,
max_seqlen,
dropout_p,
softmax_scale,
causal=causal,
return_softmax=return_softmax and dropout_p > 0,
)
ctx.save_for_backward(q, k, v, out_padded, softmax_lse, cu_seqlens, rng_state)
ctx.dropout_p = dropout_p
ctx.max_seqlen = max_seqlen
ctx.softmax_scale = softmax_scale
ctx.causal = causal
return out if not return_softmax else (out, softmax_lse, S_dmask)

@staticmethod
def backward(ctx, dout, *args):
q, k, v, out, softmax_lse, cu_seqlens, rng_state = ctx.saved_tensors
qkv_shape = q.shape[:-2] + (3, *q.shape[-2:])
dqkv = torch.empty(qkv_shape, dtype=q.dtype, device=q.device)
_flash_attn_varlen_backward(
dout,
q,
k,
v,
out,
softmax_lse,
dqkv[:, 0],
dqkv[:, 1],
dqkv[:, 2],
cu_seqlens,
cu_seqlens,
ctx.max_seqlen,
ctx.max_seqlen,
ctx.dropout_p,
ctx.softmax_scale,
ctx.causal,
rng_state=rng_state,
)
dqkv = dqkv[..., : dout.shape[-1]] # We could have padded the head dimension
return dqkv, None, None, None, None, None, None


def flash_attn_varlen_qkvpacked_func(
qkv,
cu_seqlens,
max_seqlen,
dropout_p=0.0,
softmax_scale=None,
causal=False,
return_attn_probs=False,
):
"""dropout_p should be set to 0.0 during evaluation
If Q, K, V are already stacked into 1 tensor, this function will be faster than
calling flash_attn_varlen_func on Q, K, V since the backward pass avoids explicit concatenation
of the gradients of Q, K, V.
For multi-query and grouped-query attention (MQA/GQA), please see
flash_attn_varlen_kvpacked_func and flash_attn_varlen_func.
Arguments:
qkv: (total, 3, nheads, headdim), where total = total number of tokens in the batch.
cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
of the sequences in the batch, used to index into qkv.
max_seqlen: int. Maximum sequence length in the batch.
dropout_p: float. Dropout probability.
softmax_scale: float. The scaling of QK^T before applying softmax.
Default to 1 / sqrt(headdim).
causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
return_attn_probs: bool. Whether to return the attention probabilities. This option is for
testing only. The returned probabilities are not guaranteed to be correct
(they might not have the right scaling).
Return:
out: (total, nheads, headdim).
softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
normalization factor).
S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
The output of softmax (possibly with different scaling). It also encodes the dropout
pattern (negative means that location was dropped, nonnegative means it was kept).
"""
return FlashAttnVarlenQKVPackedFunc.apply(
qkv, cu_seqlens, max_seqlen, dropout_p, softmax_scale, causal, return_attn_probs
)


class FlashAttnVarlenKVPackedFunc(torch.autograd.Function):
@staticmethod
def forward(
ctx,
q,
kv,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
dropout_p,
softmax_scale,
causal,
return_softmax,
):
if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5)
out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_varlen_forward(
q,
kv[:, 0],
kv[:, 1],
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
dropout_p,
softmax_scale,
causal=causal,
return_softmax=return_softmax and dropout_p > 0,
)
ctx.save_for_backward(
q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state
)
ctx.dropout_p = dropout_p
ctx.max_seqlen_q = max_seqlen_q
ctx.max_seqlen_k = max_seqlen_k
ctx.softmax_scale = softmax_scale
ctx.causal = causal
return out if not return_softmax else (out, softmax_lse, S_dmask)

@staticmethod
def backward(ctx, dout, *args):
q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state = ctx.saved_tensors
dq = torch.empty_like(q)
kv_shape = k.shape[:-2] + (2, *k.shape[-2:])
dkv = torch.empty(kv_shape, dtype=k.dtype, device=k.device)
_flash_attn_varlen_backward(
dout,
q,
k,
v,
out,
softmax_lse,
dq,
dkv[:, 0],
dkv[:, 1],
cu_seqlens_q,
cu_seqlens_k,
ctx.max_seqlen_q,
ctx.max_seqlen_k,
ctx.dropout_p,
ctx.softmax_scale,
ctx.causal,
rng_state=rng_state,
)
dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension
dkv = dkv[..., : dout.shape[-1]]
return dq, dkv, None, None, None, None, None, None, None, None


def flash_attn_varlen_kvpacked_func(
q,
kv,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
dropout_p=0.0,
softmax_scale=None,
causal=False,
return_attn_probs=False,
):
"""dropout_p should be set to 0.0 during evaluation
If K, V are already stacked into 1 tensor, this function will be faster than
calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation
of the gradients of K, V.
Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.
If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
1 1 1 1 0
1 1 1 1 1
If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
0 0
0 0
0 0
1 0
1 1
If the row of the mask is all zero, the output will be zero.
Arguments:
q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
kv: (total_k, 2, nheads_k, headdim), where total_k = total number of key tokens in the batch.
cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
of the sequences in the batch, used to index into q.
cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
of the sequences in the batch, used to index into kv.
max_seqlen_q: int. Maximum query sequence length in the batch.
max_seqlen_k: int. Maximum key sequence length in the batch.
dropout_p: float. Dropout probability.
softmax_scale: float. The scaling of QK^T before applying softmax.
Default to 1 / sqrt(headdim).
causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
return_attn_probs: bool. Whether to return the attention probabilities. This option is for
testing only. The returned probabilities are not guaranteed to be correct
(they might not have the right scaling).
Return:
out: (total, nheads, headdim).
softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The
logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
normalization factor).
S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen).
The output of softmax (possibly with different scaling). It also encodes the dropout
pattern (negative means that location was dropped, nonnegative means it was kept).
"""
return FlashAttnVarlenKVPackedFunc.apply(
q,
kv,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
dropout_p,
softmax_scale,
causal,
return_attn_probs,
)
13 changes: 8 additions & 5 deletions megatron/model/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,14 +345,17 @@ def __init__(
else:
if self.use_flash_attention:
from megatron.model.flash_attention import (
flash_attn_unpadded_qkvpacked_func_cuda,
flash_attn_unpadded_kvpacked_func_cuda,
flash_attn_unpadded_unpacked_func_triton,
# flash_attn_unpadded_qkvpacked_func_cuda,
# flash_attn_unpadded_kvpacked_func_cuda,
# Change of function names going from flash attention 1 -> flash attention 2
flash_attn_varlen_qkvpacked_func,
flash_attn_varlen_kvpacked_func,
flash_attn_unpadded_unpacked_func_triton
)

self.flash_triton_fn = flash_attn_unpadded_unpacked_func_triton
self.flash_qkv_fn = flash_attn_unpadded_qkvpacked_func_cuda
self.flash_kv_fn = flash_attn_unpadded_kvpacked_func_cuda
self.flash_qkv_fn = flash_attn_varlen_qkvpacked_func
self.flash_kv_fn = flash_attn_varlen_kvpacked_func
else:
self.scale_mask_softmax = FusedScaleMaskSoftmax(
input_in_fp16=self.fp16,
Expand Down
2 changes: 1 addition & 1 deletion requirements/requirements-flashattention.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
flash-attn==0.2.2
flash-attn==2.2.1

0 comments on commit f386a4b

Please sign in to comment.