diff --git a/nncf/__init__.py b/nncf/__init__.py index a429a0f290d..6140a16e383 100644 --- a/nncf/__init__.py +++ b/nncf/__init__.py @@ -33,7 +33,13 @@ from importlib.util import find_spec as _find_spec # pylint:disable=wrong-import-position -_AVAILABLE_FRAMEWORKS = {name: _find_spec(name) is not None for name in _SUPPORTED_FRAMEWORKS} +_AVAILABLE_FRAMEWORKS = {} + +for fw_name in _SUPPORTED_FRAMEWORKS: + spec = _find_spec(fw_name) + # if the framework is not present, spec may still be not None because it found our nncf.*backend_name* subpackage + framework_present = spec is not None and spec.origin is not None and "nncf" not in spec.origin + _AVAILABLE_FRAMEWORKS[fw_name] = framework_present if not any(_AVAILABLE_FRAMEWORKS.values()): nncf_logger.error( diff --git a/tests/common/conftest.py b/tests/common/conftest.py new file mode 100644 index 00000000000..b9217f8077e --- /dev/null +++ b/tests/common/conftest.py @@ -0,0 +1,12 @@ +# Copyright (c) 2023 Intel Corporation +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from tests.shared.logging import nncf_caplog # pylint:disable=unused-import diff --git a/tests/common/test_framework_detection.py b/tests/common/test_framework_detection.py new file mode 100644 index 00000000000..3a4d637e9cf --- /dev/null +++ b/tests/common/test_framework_detection.py @@ -0,0 +1,56 @@ +# Copyright (c) 2023 Intel Corporation +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import importlib +import logging +import re +import unittest +from importlib.machinery import ModuleSpec +from typing import List +from unittest.mock import MagicMock + +import pytest + +import nncf + +SUPPORTED_FRAMEWORKS = nncf._SUPPORTED_FRAMEWORKS # pylint:disable=protected-access +_REAL_FIND_SPEC = importlib._bootstrap._find_spec # pylint:disable=protected-access + + +class FailForModules: + def __init__(self, mocked_modules: List[str], hidden_modules: List[str]): + self._mocked_modules = mocked_modules + self._hidden_modules = hidden_modules + + def __call__(self, fullname, path=None, target=None): + if fullname in self._hidden_modules: + return None + if fullname in self._mocked_modules: + return ModuleSpec(fullname, loader=MagicMock(), origin="foo/bar") + return _REAL_FIND_SPEC(fullname, path, target) + + +@pytest.mark.parametrize("ref_available_frameworks", [["torch"], ["torch", "tensorflow"], ["onnx", "openvino"], []]) +def test_frameworks_detected(ref_available_frameworks: List[str], nncf_caplog, mocker): + unavailable_frameworks = [fw for fw in SUPPORTED_FRAMEWORKS if fw not in ref_available_frameworks] + failer = FailForModules(ref_available_frameworks, unavailable_frameworks) + with unittest.mock.patch("importlib.util.find_spec", wraps=failer): + with nncf_caplog.at_level(logging.INFO): + importlib.reload(nncf) + matches = re.search(r"Supported frameworks detected: (.*)", nncf_caplog.text) + if ref_available_frameworks: + assert matches is not None + match_text = matches[0] + for fw in ref_available_frameworks: + assert fw in match_text + for fw in unavailable_frameworks: + assert fw not in match_text + else: + assert matches is None