diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index d14359472ddd..6d4ed4539bc0 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -5526,6 +5526,8 @@ def test_fn(attn_mask=None, is_causal=False): # Test with explicit attn_mask attn_mask = torch.ones((L, S), dtype=torch.bool).tril(diagonal=0) + if torch.cuda.is_available(): + attn_mask = attn_mask.cuda() verify_model(test_fn(attn_mask=attn_mask), [query_4d, key_4d, value_4d]) verify_model(test_fn(attn_mask=attn_mask), [query_4d, key_4d, value_3d]) verify_model(test_fn(attn_mask=attn_mask), [query_4d, key_3d, value_4d])