fix tests
Signed-off-by: youkaichao <[email protected]>
youkaichao committed Dec 30, 2024
1 parent cbd030c commit 487b06a
Showing 1 changed file with 9 additions and 3 deletions.
12 changes: 9 additions & 3 deletions tests/kernels/test_attention_selector.py
@@ -4,31 +4,34 @@
 import torch
 
 from tests.kernels.utils import override_backend_env_variable
-from vllm.attention.selector import which_attn_to_use
+from tests.utils import fork_new_process_for_each_test
 from vllm.utils import STR_FLASH_ATTN_VAL, STR_INVALID_VAL
 
 
 @pytest.mark.parametrize(
     "name", ["TORCH_SDPA", "ROCM_FLASH", "XFORMERS", "FLASHINFER", "OPENVINO"])
 @pytest.mark.parametrize("device", ["cpu", "openvino", "rocm", "cuda"])
+@fork_new_process_for_each_test
 def test_env(name: str, device: str, monkeypatch):
     """Test that the attention selector can be set via environment variable.
     Note that we do not test FlashAttn because it is the default backend.
     """
 
     override_backend_env_variable(monkeypatch, name)
 
     monkeypatch.setenv("VLLM_TEST_FORCE_PLATFORM", device)
+    from vllm.attention.selector import which_attn_to_use
 
     backend = which_attn_to_use(16, torch.float16, torch.float16, 16, False)
     assert backend.name == name
 
 
+@fork_new_process_for_each_test
 def test_flash_attn(monkeypatch):
     """Test FlashAttn validation."""
     # TODO: When testing for v1, pipe in `use_v1` as an argument to
     # which_attn_to_use
 
     override_backend_env_variable(monkeypatch, STR_FLASH_ATTN_VAL)
+    from vllm.attention.selector import which_attn_to_use
 
     # Unsupported CUDA arch
     with patch("torch.cuda.get_device_capability", return_value=(7, 5)):
@@ -61,8 +64,11 @@ def test_flash_attn(monkeypatch):
     assert backend.name != STR_FLASH_ATTN_VAL
 
 
+@fork_new_process_for_each_test
 def test_invalid_env(monkeypatch):
     """Throw an exception if the backend name is invalid."""
     override_backend_env_variable(monkeypatch, STR_INVALID_VAL)
+    from vllm.attention.selector import which_attn_to_use
 
     with pytest.raises(ValueError):
         which_attn_to_use(16, torch.float16, None, 16, False)
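For context on the pattern this diff applies: the shape of the fix suggests that importing vllm.attention.selector pins platform-dependent state, so each test now defers that import until after VLLM_TEST_FORCE_PLATFORM is set, and runs in a freshly forked process so no cached state survives from a previous test. A minimal sketch of what a decorator like fork_new_process_for_each_test can look like — an illustrative assumption, not the actual tests.utils implementation:

import functools
import os


def fork_new_process_for_each_test(f):
    """Run the wrapped test in a forked child process so that
    module-level state (such as platform detection performed at
    import time) cannot leak from one test into the next."""

    @functools.wraps(f)
    def wrapper(*args, **kwargs):
        pid = os.fork()
        if pid == 0:
            # Child: run the test and report the outcome via exit code.
            try:
                f(*args, **kwargs)
                os._exit(0)
            except BaseException:
                import traceback
                traceback.print_exc()
                os._exit(1)
        else:
            # Parent: block until the child exits and surface any failure.
            _, status = os.waitpid(pid, 0)
            assert os.WEXITSTATUS(status) == 0, (
                f"{f.__name__} failed in the forked child process")

    return wrapper

Combined with the in-function from vllm.attention.selector import which_attn_to_use, each parametrized case imports the selector under its own forced platform instead of reusing whichever platform the first test happened to pin at module import.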
