diff --git a/test/convergence/test_mini_models_with_logits.py b/test/convergence/test_mini_models_with_logits.py
index e7672c4a4..5ca3e7420 100644
--- a/test/convergence/test_mini_models_with_logits.py
+++ b/test/convergence/test_mini_models_with_logits.py
@@ -18,7 +18,9 @@
 
 import pytest
 import torch
+import transformers
 from datasets import load_from_disk
+from packaging import version
 from torch.utils.data import DataLoader
 from transformers.models.gemma import GemmaConfig, GemmaForCausalLM
 from transformers.models.gemma2 import Gemma2Config, Gemma2ForCausalLM
@@ -538,8 +540,9 @@ def run_mini_model(
             5e-3,
             1e-5,
             marks=pytest.mark.skipif(
-                not QWEN2_VL_AVAILABLE,
-                reason="Qwen2-VL not available in this version of transformers",
+                not QWEN2_VL_AVAILABLE
+                or version.parse(transformers.__version__) >= version.parse("4.47.0"),
+                reason="Qwen2-VL not available in this version of transformers or transformers version >= 4.47.0",
             ),
         ),
         pytest.param(
@@ -558,8 +561,10 @@ def run_mini_model(
                     not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
                 ),
                 pytest.mark.skipif(
-                    not QWEN2_VL_AVAILABLE,
-                    reason="Qwen2-VL not available in this version of transformers",
+                    not QWEN2_VL_AVAILABLE
+                    or version.parse(transformers.__version__)
+                    >= version.parse("4.47.0"),
+                    reason="Qwen2-VL not available in this version of transformers or transformers version >= 4.47.0",
                 ),
             ],
         ),