Skip to content

Commit

Permalink
LowerTriangularMask.to inference_mode fix
Browse files Browse the repository at this point in the history
ghstack-source-id: e6b798c2743aaed8b02b6a55a3d3e85deb45cf38
Pull Request resolved: fairinternal/xformers#1165

__original_commit__ = fairinternal/xformers@083e3a4
  • Loading branch information
bottler authored and xFormers Bot committed Jul 26, 2024
1 parent 0b9cb70 commit 3610a54
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 1 deletion.
6 changes: 6 additions & 0 deletions tests/test_mem_eff_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -1649,10 +1649,16 @@ def _test_to_copy(attn_bias: torch.Tensor) -> None:
attn_bias = fmha.attn_bias.LowerTriangularMask().to("cpu")
_test_to_copy(attn_bias)

with torch.inference_mode():
_test_to_copy(attn_bias)

tensor_bias = torch.tensor([[1.0, 2.0, 3.0], [3.0, 4.0, 5.0]])
attn_bias = fmha.attn_bias.LowerTriangularMaskWithTensorBias(tensor_bias).to("cpu")
_test_to_copy(attn_bias)

with torch.inference_mode():
_test_to_copy(attn_bias)


def _kv_heads_label(kv_heads: Optional[int]) -> str:
if kv_heads is None:
Expand Down
4 changes: 3 additions & 1 deletion xformers/ops/fmha/attn_bias.py
Original file line number Diff line number Diff line change
Expand Up @@ -1568,8 +1568,9 @@ def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
torch.ops.aten.clone,
torch.ops.aten.detach,
torch.ops.aten._to_copy,
torch.ops.aten.to,
]:
return cls(_subtensor=func(args[0]._subtensor, **kwargs))
return cls(_subtensor=func(args[0]._subtensor, *args[1:], **kwargs))
return NotImplemented

def __tensor_flatten__(self):
Expand Down Expand Up @@ -1669,6 +1670,7 @@ def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
torch.ops.aten.clone,
torch.ops.aten.detach,
torch.ops.aten._to_copy,
torch.ops.aten.to,
]:
output = func(
*[a._subtensor if isinstance(a, cls) else a for a in args],
Expand Down

0 comments on commit 3610a54

Please sign in to comment.