BlockDiagonalCausalLocalAttentionPaddedKeysMask (fairinternal/xformers#1139)

For decoding with local attention

__original_commit__ = fairinternal/xformers@187147d
bottler authored and xFormers Bot committed Sep 4, 2024
1 parent 67c5055 commit 2009638
Showing 4 changed files with 67 additions and 5 deletions.
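
The commit adds a padded-keys variant of causal local attention, intended for decoding. As a rough usage sketch (not part of this commit): the mask is built with from_seqlens_local and passed to the memory-efficient attention forward. The packed batch-of-1 layout, the shapes, and the CUDA/bfloat16 setup below follow the convention of the existing BlockDiagonal*PaddedKeysMask biases and should be treated as assumptions.

import torch
from xformers.ops import fmha

# Illustrative decode step: one new query token per sequence, keys stored in
# fixed-size padded slots. Layout assumed from the other padded-keys masks.
B, H, K = 3, 8, 64                 # sequences, heads, head dim
kv_padding = 128                   # allocated key/value slots per sequence
q_seqlen = [1, 1, 1]               # one query per sequence (decoding)
kv_seqlen = [37, 100, 5]           # keys actually in use per sequence

bias = fmha.attn_bias.BlockDiagonalCausalLocalAttentionPaddedKeysMask.from_seqlens_local(
    q_seqlen=q_seqlen,
    kv_padding=kv_padding,
    kv_seqlen=kv_seqlen,
    window_size=16,                # each query sees roughly the last 16 in-use keys
)

q = torch.randn(1, sum(q_seqlen), H, K, device="cuda", dtype=torch.bfloat16)
k = torch.randn(1, B * kv_padding, H, K, device="cuda", dtype=torch.bfloat16)
v = torch.randn_like(k)

out = fmha.memory_efficient_attention_forward(q, k, v, attn_bias=bias)
print(out.shape)                   # expected: (1, sum(q_seqlen), H, K)
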
1 change: 1 addition & 0 deletions tests/test_mem_eff_attention.py
@@ -208,6 +208,7 @@ def _generate_op_device_dtype_biasT_B_Mq_Mkv_H_K_Kv(
}:
Mq, Mkv = min(Mkv, Mq), max(Mkv, Mq) + 2
elif bias_type in {
fmha.attn_bias.BlockDiagonalCausalLocalAttentionPaddedKeysMask,
fmha.attn_bias.BlockDiagonalCausalWithOffsetGappyKeysMask,
fmha.attn_bias.BlockDiagonalCausalWithOffsetPaddedKeysMask,
fmha.attn_bias.BlockDiagonalPaddedKeysMask,
19 changes: 14 additions & 5 deletions xformers/attn_bias_utils.py
@@ -156,6 +156,7 @@ def create_attn_bias(
return block_diag
if bias_type in [
fmha.attn_bias.BlockDiagonalPaddedKeysMask,
fmha.attn_bias.BlockDiagonalCausalLocalAttentionPaddedKeysMask,
fmha.attn_bias.BlockDiagonalCausalWithOffsetPaddedKeysMask,
fmha.attn_bias.PagedBlockDiagonalPaddedKeysMask,
fmha.attn_bias.PagedBlockDiagonalCausalWithOffsetPaddedKeysMask,
@@ -167,11 +168,19 @@
if issubclass(bias_type, fmha.attn_bias.PagedBlockDiagonalPaddedKeysMask)
else bias_type
)
-        g_block_diag = block_diag_type.from_seqlens(
-            q_seqlen=q,
-            kv_padding=kv_len,
-            kv_seqlen=k,
-        )
+        if bias_type is fmha.attn_bias.BlockDiagonalCausalLocalAttentionPaddedKeysMask:
+            g_block_diag = block_diag_type.from_seqlens_local(
+                q_seqlen=q,
+                kv_padding=kv_len,
+                kv_seqlen=k,
+                window_size=min(window_size, min(k)),
+            )
+        else:
+            g_block_diag = block_diag_type.from_seqlens(
+                q_seqlen=q,
+                kv_padding=kv_len,
+                kv_seqlen=k,
+            )
if issubclass(bias_type, fmha.attn_bias.PagedBlockDiagonalPaddedKeysMask):
assert page_size is not None
pages_per_row = (kv_len + page_size - 1) // page_size
47 changes: 47 additions & 0 deletions xformers/ops/fmha/attn_bias.py
@@ -1087,6 +1087,53 @@ def from_seqlens(
return cls(q_seqinfo=q_seqinfo, k_seqinfo=k_seqinfo)


@dataclass
class BlockDiagonalCausalLocalAttentionPaddedKeysMask(BlockDiagonalPaddedKeysMask):
    """
    Like :attr:`xformers.ops.fmha.attn_bias.BlockDiagonalCausalWithOffsetPaddedKeysMask`,
    except with a window size.
    A query Q in block i cannot attend to a key which is not in block i,
    nor one which is not in use (i.e. in the padded area),
    nor one which is nearer to the final key in block i
    than Q is to the final query in block i, nor one that is more than
    window_size further from the final key in block i than Q is
    to the final query in block i.
    """

    _window_size: int

    def _create_block_mask(
        self,
        shape: Tuple[int, ...],
        dtype: torch.dtype = torch.float32,
        device: Union[str, torch.device] = "cpu",
    ) -> torch.Tensor:
        return _materialize_causal_mask(
            shape=shape,
            dtype=dtype,
            device=device,
            window_size=self._window_size,
            from_bottomright=True,
        )

    @classmethod
    def from_seqlens_local(
        cls,
        q_seqlen: Sequence[int],
        kv_padding: int,
        kv_seqlen: Sequence[int],
        window_size: int,
    ) -> "BlockDiagonalCausalLocalAttentionPaddedKeysMask":
        assert kv_seqlen is None or len(q_seqlen) == len(kv_seqlen), (
            q_seqlen,
            kv_seqlen,
        )
        q_seqinfo = _SeqLenInfo.from_seqlens(q_seqlen)
        k_seqinfo = _PaddedSeqLenInfo.from_seqlens_padded(kv_seqlen, kv_padding)
        return cls(q_seqinfo=q_seqinfo, k_seqinfo=k_seqinfo, _window_size=window_size)


@dataclass
class PagedBlockDiagonalPaddedKeysMask(AttentionBias):
"""
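
To make the docstring's rule concrete, the new bias can be materialized as a dense tensor on a toy example. This is a minimal sketch for inspection only; it assumes the padded-keys masks accept a (total_queries, num_blocks * kv_padding) shape in materialize, as the existing debugging helpers do.

import torch
from xformers.ops import fmha

bias = fmha.attn_bias.BlockDiagonalCausalLocalAttentionPaddedKeysMask.from_seqlens_local(
    q_seqlen=[2, 3],        # two blocks with 2 and 3 queries
    kv_padding=6,           # each block owns 6 padded key slots
    kv_seqlen=[4, 5],       # of which 4 and 5 are actually in use
    window_size=2,
)

# 5 total queries, 2 blocks * 6 slots = 12 key positions (shape is an assumption).
dense = bias.materialize(shape=(5, 12), dtype=torch.float32, device="cpu")
print(dense)  # 0.0 where attention is allowed; -inf for out-of-block, padded,
              # future, or out-of-window keys
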
5 changes: 5 additions & 0 deletions xformers/ops/fmha/flash.py
@@ -17,6 +17,7 @@
BlockDiagonalCausalFromBottomRightMask,
BlockDiagonalCausalLocalAttentionFromBottomRightMask,
BlockDiagonalCausalLocalAttentionMask,
BlockDiagonalCausalLocalAttentionPaddedKeysMask,
BlockDiagonalCausalMask,
BlockDiagonalCausalWithOffsetGappyKeysMask,
BlockDiagonalCausalWithOffsetPaddedKeysMask,
@@ -470,6 +471,7 @@ def _is_causal(attn_bias: Optional[Union[torch.Tensor, AttentionBias]]) -> bool:
BlockDiagonalCausalLocalAttentionMask,
BlockDiagonalCausalFromBottomRightMask,
BlockDiagonalCausalLocalAttentionFromBottomRightMask,
BlockDiagonalCausalLocalAttentionPaddedKeysMask,
BlockDiagonalCausalWithOffsetGappyKeysMask,
BlockDiagonalCausalWithOffsetPaddedKeysMask,
PagedBlockDiagonalCausalWithOffsetPaddedKeysMask,
@@ -494,6 +496,7 @@ def _window_size(
(
BlockDiagonalCausalLocalAttentionMask,
BlockDiagonalCausalLocalAttentionFromBottomRightMask,
BlockDiagonalCausalLocalAttentionPaddedKeysMask,
LowerTriangularFromBottomRightLocalAttentionMask,
),
):
@@ -597,6 +600,7 @@ class FwOp(AttentionFwOpBase):
BlockDiagonalCausalMask,
BlockDiagonalCausalLocalAttentionMask,
BlockDiagonalCausalLocalAttentionFromBottomRightMask,
BlockDiagonalCausalLocalAttentionPaddedKeysMask,
BlockDiagonalCausalFromBottomRightMask,
BlockDiagonalCausalWithOffsetGappyKeysMask,
BlockDiagonalCausalWithOffsetPaddedKeysMask,
@@ -722,6 +726,7 @@ class BwOp(AttentionBwOpBase):
SUPPORTED_ATTN_BIAS_TYPES: Iterable[Any] = tuple(
set(FwOp.SUPPORTED_ATTN_BIAS_TYPES).difference(
{
BlockDiagonalCausalLocalAttentionPaddedKeysMask,
BlockDiagonalCausalWithOffsetGappyKeysMask,
BlockDiagonalCausalWithOffsetPaddedKeysMask,
BlockDiagonalGappyKeysMask,
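
Note that the new bias is registered for the flash forward op but explicitly removed from the backward op's supported types, so as of this commit it is forward-only, which matches the decoding use case. A quick, hedged way to check support at runtime:

from xformers.ops import fmha

# Sketch: gradients are not available for this bias through the flash backward
# op, so inference/decoding should use the forward-only path.
bias_cls = fmha.attn_bias.BlockDiagonalCausalLocalAttentionPaddedKeysMask
print(bias_cls in fmha.flash.FwOp.SUPPORTED_ATTN_BIAS_TYPES)  # expected: True
print(bias_cls in fmha.flash.BwOp.SUPPORTED_ATTN_BIAS_TYPES)  # expected: False
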
