From 6dfc9d8de690aa52d29e8510bb7a39db5582138c Mon Sep 17 00:00:00 2001 From: Zihao Ye Date: Thu, 12 Dec 2024 18:21:13 -0800 Subject: [PATCH] bugfix: fix unittests not yield error in AOT mode (#657) The unittests in AOT mode failed since https://github.com/flashinfer-ai/flashinfer/pull/629 because we didn't use return instead of yield in warmup functions, this PR fixes the issue. --- tests/test_alibi.py | 2 +- tests/test_batch_decode_kernels.py | 2 +- tests/test_batch_prefill_kernels.py | 2 +- tests/test_block_sparse.py | 2 +- tests/test_logits_cap.py | 2 +- tests/test_non_contiguous_decode.py | 2 +- tests/test_non_contiguous_prefill.py | 2 +- tests/test_shared_prefix_kernels.py | 2 +- tests/test_sliding_window.py | 2 +- tests/test_tensor_cores_decode.py | 2 +- 10 files changed, 10 insertions(+), 10 deletions(-) diff --git a/tests/test_alibi.py b/tests/test_alibi.py index f01811ec..2b15106b 100644 --- a/tests/test_alibi.py +++ b/tests/test_alibi.py @@ -26,7 +26,7 @@ @pytest.fixture(autouse=True, scope="module") def warmup_jit(): if flashinfer.jit.has_prebuilt_ops: - return + yield try: flashinfer.jit.parallel_load_modules( jit_decode_attention_func_args( diff --git a/tests/test_batch_decode_kernels.py b/tests/test_batch_decode_kernels.py index 4d2d67c6..834d8ef3 100644 --- a/tests/test_batch_decode_kernels.py +++ b/tests/test_batch_decode_kernels.py @@ -24,7 +24,7 @@ @pytest.fixture(autouse=True, scope="module") def warmup_jit(): if flashinfer.jit.has_prebuilt_ops: - return + yield try: flashinfer.jit.parallel_load_modules( jit_decode_attention_func_args( diff --git a/tests/test_batch_prefill_kernels.py b/tests/test_batch_prefill_kernels.py index f9ceadee..11ba55f2 100644 --- a/tests/test_batch_prefill_kernels.py +++ b/tests/test_batch_prefill_kernels.py @@ -24,7 +24,7 @@ @pytest.fixture(autouse=True, scope="module") def warmup_jit(): if flashinfer.jit.has_prebuilt_ops: - return + yield try: flashinfer.jit.parallel_load_modules( jit_prefill_attention_func_args( diff --git a/tests/test_block_sparse.py b/tests/test_block_sparse.py index 8672dbb0..682a4ada 100644 --- a/tests/test_block_sparse.py +++ b/tests/test_block_sparse.py @@ -26,7 +26,7 @@ @pytest.fixture(autouse=True, scope="module") def warmup_jit(): if flashinfer.jit.has_prebuilt_ops: - return + yield try: flashinfer.jit.parallel_load_modules( jit_decode_attention_func_args( diff --git a/tests/test_logits_cap.py b/tests/test_logits_cap.py index c42278aa..9bcf882a 100644 --- a/tests/test_logits_cap.py +++ b/tests/test_logits_cap.py @@ -26,7 +26,7 @@ @pytest.fixture(autouse=True, scope="module") def warmup_jit(): if flashinfer.jit.has_prebuilt_ops: - return + yield try: flashinfer.jit.parallel_load_modules( jit_decode_attention_func_args( diff --git a/tests/test_non_contiguous_decode.py b/tests/test_non_contiguous_decode.py index 22db5f87..f83449dc 100644 --- a/tests/test_non_contiguous_decode.py +++ b/tests/test_non_contiguous_decode.py @@ -8,7 +8,7 @@ @pytest.fixture(autouse=True, scope="module") def warmup_jit(): if flashinfer.jit.has_prebuilt_ops: - return + yield try: flashinfer.jit.parallel_load_modules( jit_decode_attention_func_args( diff --git a/tests/test_non_contiguous_prefill.py b/tests/test_non_contiguous_prefill.py index a45c09ad..601d1caa 100644 --- a/tests/test_non_contiguous_prefill.py +++ b/tests/test_non_contiguous_prefill.py @@ -24,7 +24,7 @@ @pytest.fixture(autouse=True, scope="module") def warmup_jit(): if flashinfer.jit.has_prebuilt_ops: - return + yield try: flashinfer.jit.parallel_load_modules( jit_prefill_attention_func_args( diff --git a/tests/test_shared_prefix_kernels.py b/tests/test_shared_prefix_kernels.py index 5a8bbf2c..77338840 100644 --- a/tests/test_shared_prefix_kernels.py +++ b/tests/test_shared_prefix_kernels.py @@ -24,7 +24,7 @@ @pytest.fixture(autouse=True, scope="module") def warmup_jit(): if flashinfer.jit.has_prebuilt_ops: - return + yield try: flashinfer.jit.parallel_load_modules( jit_decode_attention_func_args( diff --git a/tests/test_sliding_window.py b/tests/test_sliding_window.py index c552f73b..0b4f6fda 100644 --- a/tests/test_sliding_window.py +++ b/tests/test_sliding_window.py @@ -24,7 +24,7 @@ @pytest.fixture(autouse=True, scope="module") def warmup_jit(): if flashinfer.jit.has_prebuilt_ops: - return + yield try: flashinfer.jit.parallel_load_modules( jit_decode_attention_func_args( diff --git a/tests/test_tensor_cores_decode.py b/tests/test_tensor_cores_decode.py index bf312fb8..66309f45 100644 --- a/tests/test_tensor_cores_decode.py +++ b/tests/test_tensor_cores_decode.py @@ -24,7 +24,7 @@ @pytest.fixture(autouse=True, scope="module") def warmup_jit(): if flashinfer.jit.has_prebuilt_ops: - return + yield try: flashinfer.jit.parallel_load_modules( jit_decode_attention_func_args(