fMHA: Remove 'AttentionOpDispatch' class
ghstack-source-id: ea17ae1745957b123f8ebf8c6c7b61dae03ac504
Pull Request resolved: https://github.com/fairinternal/xformers/pull/1183

__original_commit__ = fairinternal/xformers@60a85db468d212335d3abfbf29014ac627c37ac3
danthe3rd authored and xFormers Bot committed Aug 9, 2024
1 parent e61c6f7 commit fae0ceb
Showing 3 changed files with 0 additions and 40 deletions.
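
The removed `AttentionOpDispatch` class was only a deprecated shim: callers get the same automatic operator selection by calling the public `xformers.ops.memory_efficient_attention` entry point and leaving `op` unset. A minimal migration sketch, assuming the public API as of this commit (the tensor shapes, dtype, and device below are illustrative, not part of the diff):

    import torch
    import xformers.ops as xops

    # BMHK layout: (batch, seq_len, num_heads, head_dim); fp16 on CUDA is illustrative.
    q = torch.randn(1, 128, 8, 64, device="cuda", dtype=torch.float16)
    k = torch.randn(1, 128, 8, 64, device="cuda", dtype=torch.float16)
    v = torch.randn(1, 128, 8, 64, device="cuda", dtype=torch.float16)

    # Before this commit (deprecated):
    #   op = xops.AttentionOpDispatch.from_arguments(query=q, key=k, value=v).op
    #   out = xops.memory_efficient_attention(q, k, v, op=op)
    # After: leave op=None and the dispatcher picks the forward/backward operators internally.
    out = xops.memory_efficient_attention(q, k, v, attn_bias=None, p=0.0, op=None)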
2 changes: 0 additions & 2 deletions xformers/ops/__init__.py
@@ -9,7 +9,6 @@
     AttentionBias,
     AttentionOp,
     AttentionOpBase,
-    AttentionOpDispatch,
     LowerTriangularMask,
     MemoryEfficientAttentionCkOp,
     MemoryEfficientAttentionCutlassFwdFlashBwOp,
@@ -78,7 +77,6 @@ def masked_matmul(a, b, mask=None):
     "AttentionMask",
     "AttentionOp",
     "AttentionOpBase",
-    "AttentionOpDispatch",
     "LowerTriangularMask",
     "MemoryEfficientAttentionCutlassFwdFlashBwOp",
     "MemoryEfficientAttentionCutlassOp",
2 changes: 0 additions & 2 deletions xformers/ops/fmha/__init__.py
@@ -16,7 +16,6 @@
     AttentionFwOpBase,
     AttentionOp,
     AttentionOpBase,
-    AttentionOpDispatch,
     Context,
     Gradients,
     Inputs,
@@ -829,7 +828,6 @@ def backward(
     "AttentionBias",
     "AttentionOp",
     "AttentionOpBase",
-    "AttentionOpDispatch",
     "LowerTriangularMask",
     "MemoryEfficientAttentionCutlassFwdFlashBwOp",
     "MemoryEfficientAttentionCutlassOp",
36 changes: 0 additions & 36 deletions xformers/ops/fmha/common.py
@@ -547,42 +547,6 @@ def attn_operator_flop(
 ]
 
 
-@dataclass
-class AttentionOpDispatch:
-    """Dispatcher to automatically select
-    the best operator to run memory-efficient attention.
-    :Deprecated:
-        This class is deprecated and will be removed in a later version
-    """
-
-    op: AttentionOp
-
-    @classmethod
-    def from_arguments(
-        cls,
-        query: torch.Tensor,
-        key: torch.Tensor,
-        value: torch.Tensor,
-        attn_bias: Optional[Union[torch.Tensor, AttentionBias]] = None,
-        p: float = 0.0,
-        scale: Optional[float] = None,
-    ) -> "AttentionOpDispatch":
-        """Here for backward compatibility"""
-        from .dispatch import _dispatch_bw, _dispatch_fw
-
-        inp = Inputs(
-            query=query,
-            key=key,
-            value=value,
-            attn_bias=attn_bias,
-            p=p,
-            scale=scale,
-        )
-        return AttentionOpDispatch(op=(_dispatch_fw(inp, True), _dispatch_bw(inp)))
-
-
 def bmk2bmhk(tensor, num_heads: int) -> torch.Tensor:
     if tensor.ndim == 4:
         return tensor
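
The deleted `from_arguments` body above shows that the class was a thin wrapper over the private dispatch helpers. If something still needs to inspect which operators would be chosen, the same logic can be reproduced directly; this is a hedged sketch only (the helper name `pick_attention_ops` is made up here, and `Inputs`, `_dispatch_fw`, and `_dispatch_bw` are internal symbols whose signatures are assumed to match the removed code):

    from typing import Optional, Union

    import torch

    from xformers.ops import AttentionBias
    from xformers.ops.fmha.common import Inputs
    from xformers.ops.fmha.dispatch import _dispatch_bw, _dispatch_fw


    def pick_attention_ops(
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attn_bias: Optional[Union[torch.Tensor, AttentionBias]] = None,
        p: float = 0.0,
        scale: Optional[float] = None,
    ):
        # Same steps as the removed AttentionOpDispatch.from_arguments: pack the
        # arguments into an Inputs container, then ask the dispatcher for a
        # forward operator and a backward operator.
        inp = Inputs(query=query, key=key, value=value, attn_bias=attn_bias, p=p, scale=scale)
        return _dispatch_fw(inp, True), _dispatch_bw(inp)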
