diff --git a/test/convergence/test_mini_models_multimodal.py b/test/convergence/test_mini_models_multimodal.py index f67e96c50..07ddd9493 100644 --- a/test/convergence/test_mini_models_multimodal.py +++ b/test/convergence/test_mini_models_multimodal.py @@ -16,9 +16,7 @@ import pytest import torch -import transformers from datasets import load_dataset -from packaging import version from torch.utils.data import DataLoader from transformers import PreTrainedTokenizerFast @@ -380,9 +378,8 @@ def run_mini_model_multimodal( 5e-3, 1e-5, marks=pytest.mark.skipif( - not QWEN2_VL_AVAILABLE - or version.parse(transformers.__version__) >= version.parse("4.47.0"), - reason="Qwen2-VL not available in this version of transformers or transformers version >= 4.47.0", + not QWEN2_VL_AVAILABLE, + reason="Qwen2-VL not available in this version of transformers", ), ), pytest.param( @@ -401,10 +398,8 @@ def run_mini_model_multimodal( not supports_bfloat16(), reason="bfloat16 not supported on this GPU" ), pytest.mark.skipif( - not QWEN2_VL_AVAILABLE - or version.parse(transformers.__version__) - >= version.parse("4.47.0"), - reason="Qwen2-VL not available in this version of transformers or transformers version >= 4.47.0", + not QWEN2_VL_AVAILABLE, + reason="Qwen2-VL not available in this version of transformers", ), ], ), diff --git a/test/convergence/test_mini_models_with_logits.py b/test/convergence/test_mini_models_with_logits.py index 5ca3e7420..e7672c4a4 100644 --- a/test/convergence/test_mini_models_with_logits.py +++ b/test/convergence/test_mini_models_with_logits.py @@ -18,9 +18,7 @@ import pytest import torch -import transformers from datasets import load_from_disk -from packaging import version from torch.utils.data import DataLoader from transformers.models.gemma import GemmaConfig, GemmaForCausalLM from transformers.models.gemma2 import Gemma2Config, Gemma2ForCausalLM @@ -540,9 +538,8 @@ def run_mini_model( 5e-3, 1e-5, marks=pytest.mark.skipif( - not QWEN2_VL_AVAILABLE - or version.parse(transformers.__version__) >= version.parse("4.47.0"), - reason="Qwen2-VL not available in this version of transformers or transformers version >= 4.47.0", + not QWEN2_VL_AVAILABLE, + reason="Qwen2-VL not available in this version of transformers", ), ), pytest.param( @@ -561,10 +558,8 @@ def run_mini_model( not supports_bfloat16(), reason="bfloat16 not supported on this GPU" ), pytest.mark.skipif( - not QWEN2_VL_AVAILABLE - or version.parse(transformers.__version__) - >= version.parse("4.47.0"), - reason="Qwen2-VL not available in this version of transformers or transformers version >= 4.47.0", + not QWEN2_VL_AVAILABLE, + reason="Qwen2-VL not available in this version of transformers", ), ], ),