From f598a6714b6eb4763b00120711abc9857df1436c Mon Sep 17 00:00:00 2001 From: youkaichao Date: Mon, 30 Dec 2024 17:14:53 +0800 Subject: [PATCH] fix tests Signed-off-by: youkaichao --- tests/kernels/test_attention_selector.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/kernels/test_attention_selector.py b/tests/kernels/test_attention_selector.py index 66f1ea5708997..b321ed7b7453f 100644 --- a/tests/kernels/test_attention_selector.py +++ b/tests/kernels/test_attention_selector.py @@ -3,7 +3,6 @@ import pytest import torch -from tests.kernels.utils import override_backend_env_variable from tests.utils import fork_new_process_for_each_test from vllm.utils import STR_FLASH_ATTN_VAL, STR_INVALID_VAL @@ -16,9 +15,10 @@ def test_env(name: str, device: str, monkeypatch): """Test that the attention selector can be set via environment variable. Note that we do not test FlashAttn because it is the default backend. """ + monkeypatch.setenv("VLLM_TEST_FORCE_PLATFORM", device) + from tests.kernels.utils import override_backend_env_variable override_backend_env_variable(monkeypatch, name) - monkeypatch.setenv("VLLM_TEST_FORCE_PLATFORM", device) from vllm.attention.selector import which_attn_to_use backend = which_attn_to_use(16, torch.float16, torch.float16, 16, False) @@ -30,6 +30,7 @@ def test_flash_attn(monkeypatch): """Test FlashAttn validation.""" # TODO: When testing for v1, pipe in `use_v1` as an argument to # which_attn_to_use + from tests.kernels.utils import override_backend_env_variable override_backend_env_variable(monkeypatch, STR_FLASH_ATTN_VAL) from vllm.attention.selector import which_attn_to_use @@ -67,6 +68,7 @@ def test_flash_attn(monkeypatch): @fork_new_process_for_each_test def test_invalid_env(monkeypatch): """Throw an exception if the backend name is invalid.""" + from tests.kernels.utils import override_backend_env_variable override_backend_env_variable(monkeypatch, STR_INVALID_VAL) from vllm.attention.selector import which_attn_to_use