diff --git a/megatron/model/flash_attention.py b/megatron/model/flash_attention.py index be3ebb14e..9d3f212c1 100644 --- a/megatron/model/flash_attention.py +++ b/megatron/model/flash_attention.py @@ -5,7 +5,7 @@ import torch.nn.functional as F from flash_attn import flash_attn_triton -import flash_attn_cuda +import flash_attn_2_cuda as flash_attn_cuda # For flash_attn version 2.1.1 def flash_attn_unpadded_unpacked_func_triton( @@ -462,3 +462,325 @@ def flash_attn_unpadded_func_cuda( causal, return_attn_probs, ) + + +# For flash-attention 2 integration +def _flash_attn_varlen_forward( + q, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_p, + softmax_scale, + causal, + return_softmax, +): + maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x + q, k, v = [maybe_contiguous(x) for x in (q, k, v)] + out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_cuda.varlen_fwd( + q, + k, + v, + None, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_p, + softmax_scale, + False, + causal, + return_softmax, + None, + ) + # if out.isnan().any() or softmax_lse.isnan().any(): + # breakpoint() + return out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state + + +def _flash_attn_varlen_backward( + dout, + q, + k, + v, + out, + softmax_lse, + dq, + dk, + dv, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_p, + softmax_scale, + causal, + rng_state=None, +): + maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x + # dq, dk, dv are allocated by us so they should already be contiguous + dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)] + dq, dk, dv, softmax_d, = flash_attn_cuda.varlen_bwd( + dout, + q, + k, + v, + out, + softmax_lse, + dq, + dk, + dv, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_p, + softmax_scale, + False, + causal, + None, + rng_state, + ) + # if dk.isnan().any() or dk.isnan().any() or dv.isnan().any() or softmax_d.isnan().any(): + # breakpoint() + return dq, dk, dv, softmax_d + + +class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function): + @staticmethod + def forward(ctx, qkv, cu_seqlens, max_seqlen, dropout_p, softmax_scale, causal, return_softmax): + if softmax_scale is None: + softmax_scale = qkv.shape[-1] ** (-0.5) + out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_varlen_forward( + qkv[:, 0], + qkv[:, 1], + qkv[:, 2], + cu_seqlens, + cu_seqlens, + max_seqlen, + max_seqlen, + dropout_p, + softmax_scale, + causal=causal, + return_softmax=return_softmax and dropout_p > 0, + ) + ctx.save_for_backward(q, k, v, out_padded, softmax_lse, cu_seqlens, rng_state) + ctx.dropout_p = dropout_p + ctx.max_seqlen = max_seqlen + ctx.softmax_scale = softmax_scale + ctx.causal = causal + return out if not return_softmax else (out, softmax_lse, S_dmask) + + @staticmethod + def backward(ctx, dout, *args): + q, k, v, out, softmax_lse, cu_seqlens, rng_state = ctx.saved_tensors + qkv_shape = q.shape[:-2] + (3, *q.shape[-2:]) + dqkv = torch.empty(qkv_shape, dtype=q.dtype, device=q.device) + _flash_attn_varlen_backward( + dout, + q, + k, + v, + out, + softmax_lse, + dqkv[:, 0], + dqkv[:, 1], + dqkv[:, 2], + cu_seqlens, + cu_seqlens, + ctx.max_seqlen, + ctx.max_seqlen, + ctx.dropout_p, + ctx.softmax_scale, + ctx.causal, + rng_state=rng_state, + ) + dqkv = dqkv[..., : dout.shape[-1]] # We could have padded the head dimension + return dqkv, None, None, None, None, None, None + + +def flash_attn_varlen_qkvpacked_func( + qkv, + cu_seqlens, + max_seqlen, + dropout_p=0.0, + softmax_scale=None, + causal=False, + return_attn_probs=False, +): + """dropout_p should be set to 0.0 during evaluation + If Q, K, V are already stacked into 1 tensor, this function will be faster than + calling flash_attn_varlen_func on Q, K, V since the backward pass avoids explicit concatenation + of the gradients of Q, K, V. + For multi-query and grouped-query attention (MQA/GQA), please see + flash_attn_varlen_kvpacked_func and flash_attn_varlen_func. + + Arguments: + qkv: (total, 3, nheads, headdim), where total = total number of tokens in the batch. + cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths + of the sequences in the batch, used to index into qkv. + max_seqlen: int. Maximum sequence length in the batch. + dropout_p: float. Dropout probability. + softmax_scale: float. The scaling of QK^T before applying softmax. + Default to 1 / sqrt(headdim). + causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). + return_attn_probs: bool. Whether to return the attention probabilities. This option is for + testing only. The returned probabilities are not guaranteed to be correct + (they might not have the right scaling). + Return: + out: (total, nheads, headdim). + softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The + logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax + normalization factor). + S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen). + The output of softmax (possibly with different scaling). It also encodes the dropout + pattern (negative means that location was dropped, nonnegative means it was kept). + """ + return FlashAttnVarlenQKVPackedFunc.apply( + qkv, cu_seqlens, max_seqlen, dropout_p, softmax_scale, causal, return_attn_probs + ) + + +class FlashAttnVarlenKVPackedFunc(torch.autograd.Function): + @staticmethod + def forward( + ctx, + q, + kv, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_p, + softmax_scale, + causal, + return_softmax, + ): + if softmax_scale is None: + softmax_scale = q.shape[-1] ** (-0.5) + out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_varlen_forward( + q, + kv[:, 0], + kv[:, 1], + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_p, + softmax_scale, + causal=causal, + return_softmax=return_softmax and dropout_p > 0, + ) + ctx.save_for_backward( + q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state + ) + ctx.dropout_p = dropout_p + ctx.max_seqlen_q = max_seqlen_q + ctx.max_seqlen_k = max_seqlen_k + ctx.softmax_scale = softmax_scale + ctx.causal = causal + return out if not return_softmax else (out, softmax_lse, S_dmask) + + @staticmethod + def backward(ctx, dout, *args): + q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state = ctx.saved_tensors + dq = torch.empty_like(q) + kv_shape = k.shape[:-2] + (2, *k.shape[-2:]) + dkv = torch.empty(kv_shape, dtype=k.dtype, device=k.device) + _flash_attn_varlen_backward( + dout, + q, + k, + v, + out, + softmax_lse, + dq, + dkv[:, 0], + dkv[:, 1], + cu_seqlens_q, + cu_seqlens_k, + ctx.max_seqlen_q, + ctx.max_seqlen_k, + ctx.dropout_p, + ctx.softmax_scale, + ctx.causal, + rng_state=rng_state, + ) + dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension + dkv = dkv[..., : dout.shape[-1]] + return dq, dkv, None, None, None, None, None, None, None, None + + +def flash_attn_varlen_kvpacked_func( + q, + kv, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_p=0.0, + softmax_scale=None, + causal=False, + return_attn_probs=False, +): + """dropout_p should be set to 0.0 during evaluation + If K, V are already stacked into 1 tensor, this function will be faster than + calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation + of the gradients of K, V. + Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads + than Q. Note that the number of heads in Q must be divisible by the number of heads in KV. + For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head + 0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V. + + If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix. + For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is: + 1 1 1 1 0 + 1 1 1 1 1 + If seqlen_q = 5 and seqlen_k = 2, the causal mask is: + 0 0 + 0 0 + 0 0 + 1 0 + 1 1 + If the row of the mask is all zero, the output will be zero. + + Arguments: + q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch. + kv: (total_k, 2, nheads_k, headdim), where total_k = total number of key tokens in the batch. + cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths + of the sequences in the batch, used to index into q. + cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths + of the sequences in the batch, used to index into kv. + max_seqlen_q: int. Maximum query sequence length in the batch. + max_seqlen_k: int. Maximum key sequence length in the batch. + dropout_p: float. Dropout probability. + softmax_scale: float. The scaling of QK^T before applying softmax. + Default to 1 / sqrt(headdim). + causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). + return_attn_probs: bool. Whether to return the attention probabilities. This option is for + testing only. The returned probabilities are not guaranteed to be correct + (they might not have the right scaling). + Return: + out: (total, nheads, headdim). + softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The + logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax + normalization factor). + S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen). + The output of softmax (possibly with different scaling). It also encodes the dropout + pattern (negative means that location was dropped, nonnegative means it was kept). + """ + return FlashAttnVarlenKVPackedFunc.apply( + q, + kv, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_p, + softmax_scale, + causal, + return_attn_probs, + ) diff --git a/megatron/model/transformer.py b/megatron/model/transformer.py index 4e81b70b6..09b5c6985 100644 --- a/megatron/model/transformer.py +++ b/megatron/model/transformer.py @@ -345,14 +345,17 @@ def __init__( else: if self.use_flash_attention: from megatron.model.flash_attention import ( - flash_attn_unpadded_qkvpacked_func_cuda, - flash_attn_unpadded_kvpacked_func_cuda, - flash_attn_unpadded_unpacked_func_triton, + # flash_attn_unpadded_qkvpacked_func_cuda, + # flash_attn_unpadded_kvpacked_func_cuda, + # Change of function names going from flash attention 1 -> flash attention 2 + flash_attn_varlen_qkvpacked_func, + flash_attn_varlen_kvpacked_func, + flash_attn_unpadded_unpacked_func_triton ) self.flash_triton_fn = flash_attn_unpadded_unpacked_func_triton - self.flash_qkv_fn = flash_attn_unpadded_qkvpacked_func_cuda - self.flash_kv_fn = flash_attn_unpadded_kvpacked_func_cuda + self.flash_qkv_fn = flash_attn_varlen_qkvpacked_func + self.flash_kv_fn = flash_attn_varlen_kvpacked_func else: self.scale_mask_softmax = FusedScaleMaskSoftmax( input_in_fp16=self.fp16, diff --git a/requirements/requirements-flashattention.txt b/requirements/requirements-flashattention.txt index 0c7d41e59..8cebdaa50 100644 --- a/requirements/requirements-flashattention.txt +++ b/requirements/requirements-flashattention.txt @@ -1 +1 @@ -flash-attn==0.2.2 +flash-attn==2.2.1