From 11ec97b2340149a653f9f75420663be42dabadb5 Mon Sep 17 00:00:00 2001 From: Byron Hsu Date: Mon, 18 Nov 2024 22:17:24 -0800 Subject: [PATCH] fix qwen2 import failure in test (#394) ## Summary ## Testing Done - Hardware Type: - [ ] run `make test` to ensure correctness - [ ] run `make checkstyle` to ensure code style - [ ] run `make test-convergence` to ensure convergence --- test/transformers/test_qwen2vl_mrope.py | 20 ++++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/test/transformers/test_qwen2vl_mrope.py b/test/transformers/test_qwen2vl_mrope.py index f8bcfd2a2..fb3f4b80e 100644 --- a/test/transformers/test_qwen2vl_mrope.py +++ b/test/transformers/test_qwen2vl_mrope.py @@ -2,16 +2,25 @@ import pytest import torch -from transformers.models.qwen2_vl.modeling_qwen2_vl import ( - Qwen2VLRotaryEmbedding, - apply_multimodal_rotary_pos_emb, -) + +try: + from transformers.models.qwen2_vl.modeling_qwen2_vl import ( + Qwen2VLRotaryEmbedding, + apply_multimodal_rotary_pos_emb, + ) + + IS_QWEN_AVAILABLE = True +except Exception: + IS_QWEN_AVAILABLE = False from liger_kernel.ops.qwen2vl_mrope import LigerQwen2VLMRopeFunction from liger_kernel.transformers.functional import liger_qwen2vl_mrope from liger_kernel.transformers.qwen2vl_mrope import liger_multimodal_rotary_pos_emb +@pytest.mark.skipif( + not IS_QWEN_AVAILABLE, reason="Qwen is not available in transformers." +) @pytest.mark.parametrize("bsz", [1, 2]) @pytest.mark.parametrize("seq_len", [128, 131]) @pytest.mark.parametrize("num_q_heads, num_kv_heads", [(64, 8), (28, 4), (12, 2)]) @@ -87,6 +96,9 @@ def test_correctness( torch.testing.assert_close(k1_grad, k2_grad, atol=atol, rtol=rtol) +@pytest.mark.skipif( + not IS_QWEN_AVAILABLE, reason="Qwen is not available in transformers." +) @pytest.mark.parametrize( "bsz, seq_len, num_q_heads, num_kv_heads, head_dim, mrope_section", [