diff --git a/xformers/ops/fmha/__init__.py b/xformers/ops/fmha/__init__.py index ff07f3bd0..c4391752a 100644 --- a/xformers/ops/fmha/__init__.py +++ b/xformers/ops/fmha/__init__.py @@ -17,7 +17,12 @@ flash3, triton_splitk, ) -from .attn_bias import VARLEN_BIASES, AttentionBias, LowerTriangularMask +from .attn_bias import ( + VARLEN_BIASES, + AttentionBias, + LowerTriangularMask, + BlockDiagonalMask, +) from .common import ( AttentionBwOpBase, AttentionFwOpBase,