diff --git a/tests/test_mem_eff_attention.py b/tests/test_mem_eff_attention.py
index dce31201e1..778d818540 100644
--- a/tests/test_mem_eff_attention.py
+++ b/tests/test_mem_eff_attention.py
@@ -1649,10 +1649,16 @@ def _test_to_copy(attn_bias: torch.Tensor) -> None:
     attn_bias = fmha.attn_bias.LowerTriangularMask().to("cpu")
     _test_to_copy(attn_bias)
 
+    with torch.inference_mode():
+        _test_to_copy(attn_bias)
+
     tensor_bias = torch.tensor([[1.0, 2.0, 3.0], [3.0, 4.0, 5.0]])
     attn_bias = fmha.attn_bias.LowerTriangularMaskWithTensorBias(tensor_bias).to("cpu")
     _test_to_copy(attn_bias)
 
+    with torch.inference_mode():
+        _test_to_copy(attn_bias)
+
 
 def _kv_heads_label(kv_heads: Optional[int]) -> str:
     if kv_heads is None:
diff --git a/xformers/ops/fmha/attn_bias.py b/xformers/ops/fmha/attn_bias.py
index d3c5f9487a..d69393ffff 100644
--- a/xformers/ops/fmha/attn_bias.py
+++ b/xformers/ops/fmha/attn_bias.py
@@ -1568,8 +1568,9 @@ def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
             torch.ops.aten.clone,
             torch.ops.aten.detach,
             torch.ops.aten._to_copy,
+            torch.ops.aten.to,
         ]:
-            return cls(_subtensor=func(args[0]._subtensor, **kwargs))
+            return cls(_subtensor=func(args[0]._subtensor, *args[1:], **kwargs))
         return NotImplemented
 
     def __tensor_flatten__(self):
@@ -1669,6 +1670,7 @@ def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
             torch.ops.aten.clone,
             torch.ops.aten.detach,
             torch.ops.aten._to_copy,
+            torch.ops.aten.to,
         ]:
             output = func(
                 *[a._subtensor if isinstance(a, cls) else a for a in args],
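
For readers skimming the diff, here is a small sketch of the call pattern the new test lines exercise. It assumes an xformers build that already includes the `attn_bias.py` change above; the motivation appears to be that under `torch.inference_mode()` a `.to()` call can reach `__torch_dispatch__` as `aten.to` (with the target passed positionally) rather than being decomposed to `aten._to_copy`, hence the new entry in the op list and the `*args[1:]` forwarding.

```python
# Hedged sketch of the usage covered by the new test; not part of the diff.
# Assumes an xformers installation that includes the attn_bias.py change above.
import torch
from xformers.ops import fmha

bias = fmha.attn_bias.LowerTriangularMask().to("cpu")

# Outside inference mode this already worked: .to() is handled through
# aten._to_copy, which was in the op list before this change.
converted = bias.to(torch.float16)
assert isinstance(converted, fmha.attn_bias.LowerTriangularMask)

# Under inference mode the same call presumably reaches __torch_dispatch__
# as aten.to with positional arguments, so without the new allow-list entry
# and the *args[1:] forwarding it would fall through to NotImplemented.
with torch.inference_mode():
    converted = bias.to(torch.float16)
    assert isinstance(converted, fmha.attn_bias.LowerTriangularMask)
```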