🤖FFPA: Yet another Faster Flash Prefill Attention with O(1)⚡️GPU SRAM complexity for large headdim🐑
🤖[WIP] FFPA: Yet another Faster Flash Prefill Attention with O(1) SRAM complexity & O(d/4) or O(1) register complexity for large headdim (D > 256), almost 1.8x~3x 🎉 faster than SDPA EA with or without MMA Acc F32 on many devices: 📈L20 ~1.9x↑🎉, 📈A30 ~1.8x↑🎉, 📈3080 ~2.9x↑🎉, 📈4090 ~2.1x↑🎉. FFPA Attention Algo: Fine-grained tiling for large headdim, FA-2 Attention Algo: Coarse-grained tiling for small headdim.
💡NOTE: This project is still in its early dev stages and now provides some kernels and benchmarks for reference. More features will be added in the future. (Welcome to 🌟👆🏻star this repo to support me ~)
@misc{ffpa-attn-mma@2025,
title={FFPA: Yet another Faster Flash Prefill Attention for large headdim.},
url={https://github.com/DefTruth/ffpa-attn-mma.git},
note={Open-source software available at https://github.com/DefTruth/ffpa-attn-mma.git},
author={DefTruth etc},
year={2025}
}
- 📖 Installation⚙️
- 📖 Python Testing👇
- 📖 FFPA L1~L3 Design💡
- 📈 FFPA L1: L20 ~1.9x↑🎉
- 📈 FFPA L1: A30 ~1.8x↑🎉
- 📈 FFPA L1: 3080 ~2.9x↑🎉
- 📈 FFPA L1: 4090 ~2.1x↑🎉
We have extended FlashAttention for large headdim (D > 256) by implementing Fine-grained Tiling at the MMA level (GEMM style) for the Q@K^T and P@V matmul. This approach results in a constant SRAM usage of Br * 16 or Bc * 16 (Br = Bc) for Q, K, and V, leading to an overall SRAM complexity of O(2 * Br * 16) ≈ O(1) and a register complexity of O(d/4) or O(1). Consequently, this method allows us to extend headdim beyond 256 and achieve faster performance compared to SDPA with or without MMA Accumulation F32 (1.8x~3x 🎉 faster than SDPA EA).
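The idea behind fine-grained tiling can be illustrated on the host side. The sketch below is a plain-Python illustration, not the repo's CUDA kernel: it accumulates Q@K^T by walking the headdim in 16-wide slices (the k-extent of the m16n8k16 MMA atom), so only a Br x 16 slice of Q and a Bc x 16 slice of K are needed at any step, whatever the value of d. The function names and the tiny Br = Bc = 4 sizes are assumptions chosen for readability.

```python
import random

MMA_K = 16  # k-extent of one m16n8k16 MMA atom, as stated in the text


def qk_scores_tiled(Q, K, chunk=MMA_K):
    """Accumulate S = Q @ K^T by looping over headdim in `chunk`-wide slices."""
    br, bc, d = len(Q), len(K), len(Q[0])
    S = [[0.0] * bc for _ in range(br)]
    for k0 in range(0, d, chunk):
        # Only these two slices are "resident" per step: Br x 16 and Bc x 16.
        q_tile = [row[k0:k0 + chunk] for row in Q]
        k_tile = [row[k0:k0 + chunk] for row in K]
        for i in range(br):
            for j in range(bc):
                S[i][j] += sum(a * b for a, b in zip(q_tile[i], k_tile[j]))
    return S


def qk_scores_full(Q, K):
    """Reference: each dot product over the whole headdim at once."""
    return [[sum(a * b for a, b in zip(q, k)) for k in K] for q in Q]


random.seed(0)
Br, Bc, d = 4, 4, 320  # illustrative sizes only (hypothetical)
Q = [[random.uniform(-1, 1) for _ in range(d)] for _ in range(Br)]
K = [[random.uniform(-1, 1) for _ in range(d)] for _ in range(Bc)]

tiled = qk_scores_tiled(Q, K)
full = qk_scores_full(Q, K)
max_err = max(abs(t - f) for tr, fr in zip(tiled, full) for t, f in zip(tr, fr))
print(f"max abs difference between tiled and full: {max_err:.2e}")
```

The slice size stays fixed as d grows; only the number of loop steps (d / 16) changes, which is the reason the SRAM footprint for Q and K can be treated as constant.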
We have named this new attention tiling technique FFPA: Faster Flash Prefill Attention. We have designed three levels (L1~L3) of FFPA based on SRAM and register complexity considerations. None of the levels introduces additional VRAM requirements, so the HBM memory complexity remains the same as FlashAttention. 👇
- 📚L1: level 1, O(2xBrx16)≈O(1) SRAM complexity, ≈O(d/4) register complexity.
- 📚L2: level 2, O(2xBrx16)≈O(1) SRAM complexity, ≈O(1) register complexity + Q@K^T recomputation.
- 📚L3: level 3, O(2xBrx16)≈O(1) SRAM complexity, ≈O(1) register complexity + scaling O via HBM offloading.
By leveraging this approach, we can achieve better performance for large headdim (D > 256) through a balanced utilization of FlashAttention (which is not designed to support D > 256) and SDPA EA. Approximate SRAM and register complexity analysis for L1~L3 is as follows: (d=headdim, C,Br,Bc=Constant, Br=Bc) 👇
📚Complexity | 📚FFPA L1 | 📚FFPA L2 | 📚FFPA L3 | 📚FA-2 |
---|---|---|---|---|
SRAM | O(2xBrx16)≈O(1) | O(2xBrx16)≈O(1) | O(2xBrx16)≈O(1) | ≈O(3xBrxd), d↑ |
Register | ≈O(d/4), d↑ | O((Bc/16)x4+2C)≈O(1) | O((Bc/16)x4+2C)≈O(1) | ≈O(d/2), d↑ |
HBM | ≈FA2≈O(Nd), O | ≈FA2≈O(Nd), O | ≈FA2≈O(Nd), O | ≈O(Nd), O |
Extra HBM | ≈FA2≈O(N), m,l | ≈FA2≈O(N), m,l | ≈FA2≈O(N), m,l | ≈O(N), m,l |
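To make the SRAM row of the table concrete, the sketch below plugs numbers into the two expressions, 2 x Br x 16 for FFPA and 3 x Br x d for FA-2. It is a rough estimate under stated assumptions: fp16 elements (2 bytes), Br = 64, a single stage, and no padding or other scratch space. The function names are hypothetical.

```python
BYTES_PER_HALF = 2  # assumption: Q/K/V tiles are stored as fp16


def ffpa_smem_bytes(br, mma_k=16):
    """SRAM estimate from the table: 2 * Br * 16 elements, independent of d."""
    return 2 * br * mma_k * BYTES_PER_HALF


def fa2_smem_bytes(br, d):
    """SRAM estimate from the table: 3 * Br * d elements, growing with d."""
    return 3 * br * d * BYTES_PER_HALF


BR = 64  # assumption: Br = Bc = 64
for d in (320, 512, 1024):
    ffpa_kib = ffpa_smem_bytes(BR) / 1024
    fa2_kib = fa2_smem_bytes(BR, d) / 1024
    print(f"d={d:4d}: FFPA ~{ffpa_kib:.0f} KiB, FA-2 style ~{fa2_kib:.0f} KiB")
```

Under these assumptions the FFPA estimate stays at 4 KiB per stage for every d, while the coarse-grained estimate grows linearly with d.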
📚👇Core Features🎉🎉: I have implemented FFPA L1~L3 using pure MMA PTX instructions, which support many features such as Split-Q, SMEM Swizzle/Padding, QKV Multi-Stages(1~4), Tile MMAs/Warps, Mixed MMA F32/F16 Acc (Q@K^T MMA Acc F32 + P@V MMA Acc F16), Fully Shared QKV SMEM, Prefetch QKV g2s, Persist Q s2r/g2s, Fully QKV Fine-grained Tiling(GEMM style), Collective Store, etc.
📚Feature | 📚Feature | 📚Feature | 📚Feature |
---|---|---|---|
✔️Tensor Cores | ✔️Loop over N/D | ✔️Tile Block(Br, Bc) | ✔️MMA(m16n8k16) |
✔️Split Q(FA-2) | ✔️Pack LDST(128 bits) | ✔️SMEM Swizzle/Pad | ✔️Copy Async |
✔️Tile MMA/Warp | ✔️QKV Multi-Stages(1~4) | ✔️Collective Store(Shfl) | ✔️Prefetch QKV g2s |
✔️QKV Fine-grained Tiling | ✔️Shared QKV SMEM | ✔️Mixed MMA Acc | ✔️Persist Q s2r/g2s |
- 📚 case: FFPA L1 kernel template signature: ffpa_attn_templates_L1.cuh
template<
const int kHeadDim, // Headdim, 32~1024
const int kMmaAtomM, // MMA Atom M, 16
const int kMmaAtomN, // MMA Atom N, 8
const int kMmaAtomK, // MMA Atom K, 16
const int kMmaTileSeqLenQ, // 4, more MMA(warp), M=16*4=64, Q@K^T=[Br(M), d(K)]@[d(K), Bc(N)]
const int kMmaTileSeqLenK, // 1, more MMA(warp), N=8*1 =8, Q@K^T=[Br(M), d(K)]@[d(K), Bc(N)]
const int kMmaTileSeqLenP, // 4, more MMA(warp), M=16*4=64, P@V =[Br(M),Bc(K)]@[Bc(K), d(N) ]
const int kMmaTileHeadDimV, // 1, more MMA(warp), N=8*1 =8, P@V =[Br(M),Bc(K)]@[Bc(K), d(N) ]
const int kWarpTileSeqLenQ, // 1, more values, M, Br=64*1=64, matmul M
const int kWarpTileSeqLenK, // 8, more values, N, Bc=8*8 =64, matmul N
const int kWarpTileSeqLenP, // 1, more values, M, Br=64*1=64, matmul M
const int kWarpTileHeadDimV, // 8, more values, N, d=8*(1|2|3|4|...)=8|...|32|64|96|128|...
const int kMmaAccFloat32QK, // 0/1, Q@K^T, 0 MMA Acc with fp16, 1 MMA Acc with fp32.
const int kMmaAccFloat32PV, // 0/1, P@V, 0 MMA Acc with fp16, 1 MMA Acc with fp32.
const int kOStorageAccFloat32, // 0/1, MMA Acc is always f32/f16, but O storage can be fp32 or half.
const int kPrefetchQK, // Prefetch QK at the Appropriate Time Point.
const int kPrefetchPV, // Prefetch V at the Appropriate Time Point.
const int kShareSmemQKV, // QKV share the same shared memory, reuse QK smem for V.
const int kPersistQs2r, // Persist load Q s2r for headdim < 512, more registers, but still keep O(1) SRAM.
const int kPersistQg2s, // Persist load Q g2s for headdim <= 320, more SRAM, but still keep register usage.
const int kStageQK, // <= 4, may apply different multi stages policy for QK and V (<=4)
const int kStagePV, // <= 4, may apply different multi stages policy for QK and V (<=4)
const int kPadQ, // Pad Q/K/V 0,8; 0 -> smem swizzle, > 0 -> padding
const int kPadK, // Pad Q/K/V 0,8; 0 -> smem swizzle, > 0 -> padding
const int kPadV // Pad Q/K/V 0,8; 0 -> smem swizzle, > 0 -> padding
> __global__ void // Q, K, V, O -> [B, H, N, D]
// FFPA Attention Algo: Fine-grained tiling at MMA level for large headdim (d>=256),
// which can achieve 1.8x~3x🎉 faster than SDPA EA with or without MMA Acc F32.
ffpa_mma_stages_split_q_L1_large_d_template(half* Q, half* K, half* V, half* O, ...);
// FA-2 Attention Algo: Coarse-grained tiling at Attention level for small headdim (d<256),
// which can achieve 95%-150%🎉 performance as SDPA FA-2 BE with MMA Acc F32 for N<=4096,
// and achieve almost 1.2x~1.4x🎉 faster than SDPA FA-2 via Mixed MMA Acc(Q@K^T F32 +
// P@V F16) for all range N.
ffpa_mma_stages_split_q_L1_small_d_template(half* Q, half* K, half* V, half* O, ...);
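The comments in the template describe how the block sizes follow from the tiling parameters (for example, M=16*4=64 and Br=64*1=64). The Python sketch below simply replays that arithmetic with the default values quoted in the comments. The `L1Tiling` class and its field names are hypothetical helpers for illustration, and reading `kHeadDim / kMmaAtomK` as the number of headdim steps is an interpretation of the Q@K^T shape comment, not something taken from the kernel source.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class L1Tiling:
    """Mirror of a few template parameters (defaults taken from the comments)."""
    head_dim: int
    mma_atom_m: int = 16         # kMmaAtomM
    mma_atom_n: int = 8          # kMmaAtomN
    mma_atom_k: int = 16         # kMmaAtomK
    mma_tile_seqlen_q: int = 4   # kMmaTileSeqLenQ
    mma_tile_seqlen_k: int = 1   # kMmaTileSeqLenK
    warp_tile_seqlen_q: int = 1  # kWarpTileSeqLenQ
    warp_tile_seqlen_k: int = 8  # kWarpTileSeqLenK

    @property
    def br(self) -> int:
        # Br = MMA atom M * MMA tiles along Q * warp tiles along Q
        return self.mma_atom_m * self.mma_tile_seqlen_q * self.warp_tile_seqlen_q

    @property
    def bc(self) -> int:
        # Bc = MMA atom N * MMA tiles along K * warp tiles along K
        return self.mma_atom_n * self.mma_tile_seqlen_k * self.warp_tile_seqlen_k

    @property
    def headdim_steps(self) -> int:
        # Assumed: one step per 16-wide slice of the headdim in Q@K^T.
        return self.head_dim // self.mma_atom_k


tiling = L1Tiling(head_dim=320)
print(f"Br={tiling.br}, Bc={tiling.bc}, headdim steps={tiling.headdim_steps}")
```

With the quoted defaults this gives Br = Bc = 64, matching the Br = Bc assumption used in the complexity analysis.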
- Python >= 3.10
- PyTorch >= 2.4.0, CUDA >= 12.4
- Recommended: PyTorch 2.5.1, CUDA 12.5
- Docker: nvcr.io/nvidia/pytorch:24.10-py3
The FFPA implemented in this repo can be installed as a Python library, namely, the ffpa-attn library (optional).
git clone https://github.com/DefTruth/ffpa-attn-mma.git
# clone, then, run bash .dev/install.sh directly or run commands:
python3 setup.py bdist_wheel && cd dist && python3 -m pip install *.whl # pip uninstall ffpa-attn -y
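After installing the wheel, a quick sanity check can confirm the prerequisites. The sketch below uses only the Python standard library and the distribution name `ffpa-attn` that appears in the `pip uninstall` hint above; the helper name is hypothetical, and the check says nothing about whether the CUDA extension itself loads correctly.

```python
import sys
from importlib import metadata


def ffpa_attn_version():
    """Return the installed ffpa-attn version string, or None if not installed."""
    try:
        return metadata.version("ffpa-attn")
    except metadata.PackageNotFoundError:
        return None


python_ok = sys.version_info >= (3, 10)  # the README asks for Python >= 3.10
installed = ffpa_attn_version()

print(f"Python >= 3.10: {python_ok}")
print(f"ffpa-attn: {installed if installed else 'not installed'}")
```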
L1: level 1, O(2xBrx16)≈O(1) SRAM complexity, O(d/4) register complexity, the same GPU HBM memory complexity as FlashAttention. B=1, H=48, N=8192, D=320-1024 (FA2 not supported 👀). (Notes: *=MMA Acc F32, ^=MMA Acc F16, Softmax Acc dtype is always F32, T=TFLOPS, 👇Benchmark)
- 📚 NVIDIA L20 (*=MMA Acc F32, ^=MMA Acc F16, T=TFLOPS, ~1.8x↑🎉)
Algorithm | 320 | 384 | 448 | 512 | 576 | 640 | 704 | 768 | 832 | 896 | 960 | 1024 |
---|---|---|---|---|---|---|---|---|---|---|---|---|
SDPA EA | 56T | 63T | 58T | 58T | 55T | 56T | 54T | 55T | 54T | 55T | 54T | 56T |
FFPA L1* | 102T | 102T | 103T | 104T | 103T | 95T | 95T | 95T | 95T | 96T | 95T | 94T |
Speedup | 1.82x | 1.62x | 1.78x | 1.79x | 1.87x | 1.7x | 1.76x | 1.73x | 1.76x | 1.75x | 1.76x | 1.68x |
FFPA L1^ | 104T | 103T | 103T | 102T | 104T | 103T | 102T | 94T | 94T | 94T | 100T | 100T |
Speedup | 1.86x | 1.63x | 1.78x | 1.76x | 1.89x | 1.84x | 1.89x | 1.71x | 1.74x | 1.71x | 1.85x | 1.79x |
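The Speedup rows can be reproduced from the rounded TFLOPS values. As a small worked example, the sketch below recomputes the first Speedup row of the L20 table above (FFPA L1* over SDPA EA) as the element-wise ratio rounded to two decimals. The variable and function names are hypothetical, and the numbers are copied from the table rather than measured here.

```python
# Headdim columns and TFLOPS values copied from the NVIDIA L20 table above.
headdims = [320, 384, 448, 512, 576, 640, 704, 768, 832, 896, 960, 1024]
sdpa_ea = [56, 63, 58, 58, 55, 56, 54, 55, 54, 55, 54, 56]
ffpa_l1_f32 = [102, 102, 103, 104, 103, 95, 95, 95, 95, 96, 95, 94]


def speedup_row(baseline, candidate):
    """Element-wise candidate/baseline ratio, rounded to two decimals."""
    return [round(c / b, 2) for b, c in zip(baseline, candidate)]


speedups = speedup_row(sdpa_ea, ffpa_l1_f32)
for d, s in zip(headdims, speedups):
    print(f"D={d:4d}: {s}x")
```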
- 📚 NVIDIA L20 (*=MMA Acc: QK F32 + PV F16, ^=MMA Acc F16, T=TFLOPS, ~1.9x↑🎉)
Algorithm | 320 | 384 | 448 | 512 | 576 | 640 | 704 | 768 | 832 | 896 | 960 | 1024 |
---|---|---|---|---|---|---|---|---|---|---|---|---|
SDPA EA | 56T | 64T | 58T | 58T | 55T | 56T | 54T | 55T | 54T | 55T | 54T | 56T |
FFPA L1* | 105T | 102T | 104T | 103T | 105T | 95T | 95T | 94T | 94T | 94T | 102T | 101T |
Speedup | 1.88x | 1.59x | 1.79x | 1.78x | 1.91x | 1.7x | 1.76x | 1.71x | 1.74x | 1.71x | 1.89x | 1.8x |
FFPA L1^ | 104T | 103T | 103T | 102T | 103T | 103T | 102T | 94T | 94T | 94T | 100T | 100T |
Speedup | 1.86x | 1.61x | 1.78x | 1.76x | 1.87x | 1.84x | 1.89x | 1.71x | 1.74x | 1.71x | 1.85x | 1.79x |
- 📚 NVIDIA A30 (*=MMA Acc F32, ^=MMA Acc F16, T=TFLOPS, ~1.8x↑🎉)
Algorithm | 320 | 384 | 448 | 512 | 576 | 640 | 704 | 768 | 832 | 896 | 960 | 1024 |
---|---|---|---|---|---|---|---|---|---|---|---|---|
SDPA EA | 25T | 25T | 24T | 24T | 24T | 24T | 23T | 22T | 22T | 22T | 22T | 18T |
FFPA L1* | 45T | 44T | 44T | 43T | 43T | 38T | 37T | 37T | 37T | 36T | 33T | 32T |
Speedup | 1.8x | 1.76x | 1.83x | 1.79x | 1.79x | 1.58x | 1.61x | 1.68x | 1.68x | 1.64x | 1.5x | 1.78x |
FFPA L1^ | 48T | 46T | 45T | 43T | 44T | 44T | 44T | 38T | 37T | 36T | 40T | 34T |
Speedup | 1.92x | 1.84x | 1.88x | 1.79x | 1.83x | 1.83x | 1.91x | 1.73x | 1.68x | 1.64x | 1.82x | 1.89x |
- 📚 NVIDIA A30 (*=MMA Acc: QK F32 + PV F16, ^=MMA Acc F16, T=TFLOPS, ~1.9x↑🎉)
Algorithm | 320 | 384 | 448 | 512 | 576 | 640 | 704 | 768 | 832 | 896 | 960 | 1024 |
---|---|---|---|---|---|---|---|---|---|---|---|---|
SDPA EA | 25T | 25T | 24T | 24T | 24T | 24T | 23T | 22T | 22T | 22T | 22T | 18T |
FFPA L1* | 48T | 46T | 46T | 43T | 44T | 38T | 38T | 38T | 37T | 36T | 40T | 34T |
Speedup | 1.92x | 1.84x | 1.92x | 1.79x | 1.83x | 1.58x | 1.65x | 1.73x | 1.68x | 1.64x | 1.82x | 1.89x |
FFPA L1^ | 48T | 46T | 45T | 43T | 44T | 44T | 44T | 38T | 37T | 36T | 39T | 34T |
Speedup | 1.92x | 1.84x | 1.88x | 1.79x | 1.83x | 1.83x | 1.91x | 1.73x | 1.68x | 1.64x | 1.77x | 1.89x |
- 📚 NVIDIA RTX 3080 Laptop (*=MMA Acc F32, ^=MMA Acc F16, T=TFLOPS, ~2.5x↑🎉)
Algorithm | 320 | 384 | 448 | 512 | 576 | 640 | 704 | 768 | 832 | 896 | 960 | 1024 |
---|---|---|---|---|---|---|---|---|---|---|---|---|
SDPA EA | 13T | 16T | 11T | 16T | 15T | 15T | 15T | 15T | 14T | 14T | 14T | 14T |
FFPA L1* | 33T | 31T | 30T | 30T | 30T | 27T | 27T | 26T | 26T | 26T | 26T | 25T |
Speedup | 2.54x | 1.94x | 2.73x | 1.88x | 2.0x | 1.8x | 1.8x | 1.73x | 1.86x | 1.86x | 1.86x | 1.79x |
FFPA L1^ | 43T | 41T | 39T | 39T | 39T | 39T | 39T | 36T | 34T | 33T | 31T | 33T |
Speedup | 3.31x | 2.56x | 3.55x | 2.44x | 2.6x | 2.6x | 2.6x | 2.4x | 2.43x | 2.36x | 2.21x | 2.36x |
- 📚 NVIDIA RTX 3080 Laptop (*=MMA Acc: QK F32 + PV F16, ^=MMA Acc F16, T=TFLOPS, ~2.9x↑🎉)
Algorithm | 320 | 384 | 448 | 512 | 576 | 640 | 704 | 768 | 832 | 896 | 960 | 1024 |
---|---|---|---|---|---|---|---|---|---|---|---|---|
SDPA EA | 13T | 15T | 12T | 15T | 14T | 15T | 14T | 14T | 14T | 14T | 14T | 14T |
FFPA L1* | 38T | 36T | 34T | 35T | 34T | 31T | 32T | 31T | 30T | 28T | 27T | 27T |
Speedup | 2.92x | 2.4x | 2.83x | 2.33x | 2.43x | 2.07x | 2.29x | 2.21x | 2.14x | 2.0x | 1.93x | 1.93x |
FFPA L1^ | 44T | 41T | 39T | 39T | 38T | 39T | 39T | 36T | 34T | 32T | 31T | 33T |
Speedup | 3.38x | 2.73x | 3.25x | 2.6x | 2.71x | 2.6x | 2.79x | 2.57x | 2.43x | 2.29x | 2.21x | 2.36x |
- 📚 NVIDIA RTX 4090 (*=MMA Acc F32, ^=MMA Acc F16, T=TFLOPS, ~1.8x↑🎉)
Algorithm | 320 | 384 | 448 | 512 | 576 | 640 | 704 | 768 | 832 | 896 | 960 | 1024 |
---|---|---|---|---|---|---|---|---|---|---|---|---|
SDPA EA | 81T | 94T | 85T | 85T | 79T | 81T | 79T | 80T | 79T | 80T | 78T | 78T |
FFPA L1* | 149T | 150T | 150T | 150T | 150T | 140T | 140T | 140T | 139T | 139T | 137T | 134T |
Speedup | 1.84x | 1.6x | 1.76x | 1.76x | 1.9x | 1.73x | 1.77x | 1.75x | 1.76x | 1.74x | 1.76x | 1.72x |
FFPA L1^ | 194T | 194T | 189T | 191T | 197T | 188T | 184T | 180T | 177T | 172T | 171T | 171T |
Speedup | 2.4x | 2.06x | 2.22x | 2.25x | 2.49x | 2.32x | 2.33x | 2.25x | 2.24x | 2.15x | 2.19x | 2.19x |
- 📚 NVIDIA RTX 4090 (*=MMA Acc: QK F32 + PV F16, ^=MMA Acc F16, T=TFLOPS, ~2.1x↑🎉)
Algorithm | 320 | 384 | 448 | 512 | 576 | 640 | 704 | 768 | 832 | 896 | 960 | 1024 |
---|---|---|---|---|---|---|---|---|---|---|---|---|
SDPA EA | 82T | 92T | 85T | 84T | 78T | 81T | 79T | 80T | 78T | 79T | 77T | 78T |
FFPA L1* | 176T | 170T | 171T | 171T | 171T | 161T | 160T | 161T | 160T | 158T | 165T | 164T |
Speedup | 2.15x | 1.85x | 2.01x | 2.04x | 2.19x | 1.99x | 2.03x | 2.01x | 2.05x | 2.0x | 2.14x | 2.1x |
FFPA L1^ | 200T | 191T | 189T | 191T | 188T | 188T | 186T | 179T | 175T | 173T | 172T | 170T |
Speedup | 2.44x | 2.08x | 2.22x | 2.27x | 2.41x | 2.32x | 2.35x | 2.24x | 2.24x | 2.19x | 2.23x | 2.18x |
👇You can test many custom FFPA kernels via Python and compare their performance. The --gen-bench and --plot options help you generate a benchmark table in Markdown style and speedup bar plots on your device. Contributions of your benchmark tables and plots are welcome via a PR 🎉🎉.
- 📚 case: B=1, H=48, N=8192, D=320 (FA2 not supported)
# You can test on many devices, such as Volta, Ampere, Ada, Hopper, ...
cd tests && python3 test.py --B 1 --H 48 --N 8192 --show-all --D 320
- 📚 case: Generate a benchmark table and speedup bar plots on your device.
cd tests && pip install matplotlib && python3 test.py --gen-bench --show-all --plot
- 📚 case: Compare small headdim (d<256, e.g., 64), FFPA-L1 vs SDPA FA-2 BE.
# Enable the ffpa-attn small d kernel, which uses the coarse-grained tiling method.
export ENABLE_FFPA_PERSIST_Q_G2S=1 && export ENABLE_FFPA_PERSIST_KV_G2S=1
python3 test.py --B 1 --H 32 --N 1024 --check --show-all --D 64 # NVIDIA L20
---------------------------------------B=1, H=32, N=1024, D=64, Warmup: 1, Iters: 5--------------------
(sdpa): ['-0.02571106'], time:0.154352ms, TFLOPS:56.72 (+0.00 %)(~1.00x)
(ffpa+acc+f32+L1+stage1): ['-0.02572632'], time:0.103998ms, TFLOPS:84.19 (+48.42%)(~1.48x)
(ffpa+acc+f32+L1+stage2): ['-0.02572632'], time:0.101900ms, TFLOPS:85.92 (+2.06 %)(~1.51x)
(ffpa+acc+f16+L1+stage1): ['-0.02568054'], time:0.113105ms, TFLOPS:77.41 (+0.00 %)(~1.36x)
(ffpa+acc+f16+L1+stage2): ['-0.02568054'], time:0.112771ms, TFLOPS:77.64 (+0.00 %)(~1.37x)
(ffpa+acc+f32+L1+stage3): ['-0.02572632'], time:0.101947ms, TFLOPS:85.88 (+0.00 %)(~1.51x)
(ffpa+acc+f32+L1+stage4): ['-0.02572632'], time:0.102043ms, TFLOPS:85.80 (+0.00 %)(~1.51x)
(ffpa+acc+f16+L1+stage3): ['-0.02568054'], time:0.111246ms, TFLOPS:78.70 (+0.00 %)(~1.39x)
(ffpa+acc+f16+L1+stage4): ['-0.02568054'], time:0.108432ms, TFLOPS:80.75 (+0.00 %)(~1.42x)
--------------------------------------------------------------------------------------------------------
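When comparing many kernel variants, it can help to post-process the log. The sketch below parses a few lines copied from the sample output above with a regular expression and derives the speedup over `sdpa` from the reported times. The regex and helper names are assumptions about this particular log layout, not part of test.py.

```python
import re

# A few lines copied verbatim from the sample log above.
LOG = """\
(sdpa): ['-0.02571106'], time:0.154352ms, TFLOPS:56.72 (+0.00 %)(~1.00x)
(ffpa+acc+f32+L1+stage1): ['-0.02572632'], time:0.103998ms, TFLOPS:84.19 (+48.42%)(~1.48x)
(ffpa+acc+f32+L1+stage2): ['-0.02572632'], time:0.101900ms, TFLOPS:85.92 (+2.06 %)(~1.51x)
(ffpa+acc+f16+L1+stage1): ['-0.02568054'], time:0.113105ms, TFLOPS:77.41 (+0.00 %)(~1.36x)
"""

LINE_RE = re.compile(
    r"\((?P<name>[^)]+)\):.*?time:(?P<ms>[\d.]+)ms, TFLOPS:(?P<tflops>[\d.]+)"
)


def parse_log(text):
    """Return {kernel name: (time in ms, TFLOPS)} for every matching line."""
    results = {}
    for line in text.splitlines():
        match = LINE_RE.search(line)
        if match:
            results[match["name"]] = (float(match["ms"]), float(match["tflops"]))
    return results


results = parse_log(LOG)
baseline_ms = results["sdpa"][0]
speedups = {name: round(baseline_ms / ms, 2) for name, (ms, _) in results.items()}
fastest = min(results, key=lambda name: results[name][0])

for name, s in speedups.items():
    print(f"{name}: ~{s:.2f}x")
print(f"fastest: {fastest}")
```

The ratios derived from the times agree with the `(~1.48x)` style annotations printed at the end of each log line.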
💡NOTE: Please check all configurable environment variables in env.py.
GNU General Public License v3.0
How to contribute? Welcome to star⭐️ this repo to support me👆🏻 ~