diff --git a/tests/common.py b/tests/common.py index 62ff7dd0a..e84c7ac36 100644 --- a/tests/common.py +++ b/tests/common.py @@ -60,11 +60,3 @@ def tmp_rng_seed(device: Device, seed: int = 0) -> Generator[None, None, None]: torch.manual_seed(seed) yield - - -def python_devel_only() -> bool: - """Return ``True`` if fairseq2 is installed for Python development only.""" - import fairseq2 - import fairseq2n - - return fairseq2.__version__ != fairseq2n.__version__ diff --git a/tests/conftest.py b/tests/conftest.py index 8ed979b76..56f225026 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -4,23 +4,18 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import warnings from argparse import ArgumentTypeError from pathlib import Path from typing import cast import pytest +from packaging.version import Version import tests.common from fairseq2.typing import Device -def parse_device_arg(value: str) -> Device: - try: - return Device(value) - except RuntimeError: - raise ArgumentTypeError(f"'{value}' is not a valid device name.") - - def pytest_addoption(parser: pytest.Parser) -> None: # fmt: off parser.addoption( @@ -34,6 +29,20 @@ def pytest_addoption(parser: pytest.Parser) -> None: # fmt: on +def parse_device_arg(value: str) -> Device: + try: + return Device(value) + except RuntimeError: + raise ArgumentTypeError(f"'{value}' is not a valid device name.") + + +def pytest_configure(config: pytest.Config) -> None: + config.addinivalue_line( + "markers", + "fairseq2n(version): mark test to run only on the specified fairseq2n version or greater", + ) + + def pytest_sessionstart(session: pytest.Session) -> None: tests.common.device = cast(Device, session.config.getoption("device")) @@ -46,3 +55,28 @@ def pytest_ignore_collect( return not cast(bool, config.getoption("integration")) return False + + +def pytest_runtest_setup(item: pytest.Function) -> None: + marker = item.get_closest_marker(name="fairseq2n") + if marker is not None: + skip_if_fairseq2n_newer(*marker.args, **marker.kwargs) + + +def skip_if_fairseq2n_newer(version: str) -> None: + import fairseq2n + + installed_version = Version(fairseq2n.__version__) + annotated_version = Version(version) + + # fmt: off + if installed_version < annotated_version: + pytest.skip(f"The test requires fairseq2n v{annotated_version} or greater.") + elif ( + installed_version.major != annotated_version.major or + installed_version.minor != annotated_version.minor + ): + warnings.warn( + f"The test requires fairseq2n v{annotated_version} which is older than the current version (v{installed_version}). The marker can be safely removed." + ) + # fmt: on diff --git a/tests/unit/data/data_pipeline/test_sample.py b/tests/unit/data/data_pipeline/test_sample.py index fcb71254c..c414a6d23 100644 --- a/tests/unit/data/data_pipeline/test_sample.py +++ b/tests/unit/data/data_pipeline/test_sample.py @@ -10,15 +10,12 @@ from fairseq2.data import DataPipeline, read_sequence from fairseq2.data.text.text_reader import read_text from fairseq2.utils.version import is_pt2_or_greater -from tests.common import python_devel_only, tmp_rng_seed +from tests.common import tmp_rng_seed cpu_device = torch.device("cpu") -@pytest.mark.skipif( - python_devel_only(), - reason="New fairseq2n API in Python-only installation. Skipping till v0.2.", -) +@pytest.mark.fairseq2n("0.2a0") @pytest.mark.skipif( not is_pt2_or_greater(), reason="Different sampling results with versions lower than PyTorch 2.0",