Skip to content

Commit

Permalink
Work around find_spec finding local subpackages (#2075)
Browse files Browse the repository at this point in the history
### Changes
Fixed a bug with backend detection where NNCF would report every package
detected.

### Reason for changes
`importlib.utils.find_spec` turns out to find the local subpackages
instead of global packages if the global package is not present. Since
we have every backend as a local subpackage (e.g. at `nncf.torch`,
`nncf.tensorflow` etc.), the previous code would report every backend as
available, regardless of actual `torch` or `tensorflow` packages being
installed or not.

### Related tickets
N/A

### Tests
test_framework_detection
  • Loading branch information
vshampor authored Aug 23, 2023
1 parent 5b8d438 commit c027c8b
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 1 deletion.
8 changes: 7 additions & 1 deletion nncf/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,13 @@

from importlib.util import find_spec as _find_spec # pylint:disable=wrong-import-position

_AVAILABLE_FRAMEWORKS = {name: _find_spec(name) is not None for name in _SUPPORTED_FRAMEWORKS}
_AVAILABLE_FRAMEWORKS = {}

for fw_name in _SUPPORTED_FRAMEWORKS:
spec = _find_spec(fw_name)
# if the framework is not present, spec may still be not None because it found our nncf.*backend_name* subpackage
framework_present = spec is not None and spec.origin is not None and "nncf" not in spec.origin
_AVAILABLE_FRAMEWORKS[fw_name] = framework_present

if not any(_AVAILABLE_FRAMEWORKS.values()):
nncf_logger.error(
Expand Down
12 changes: 12 additions & 0 deletions tests/common/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
# Copyright (c) 2023 Intel Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from tests.shared.logging import nncf_caplog # pylint:disable=unused-import
56 changes: 56 additions & 0 deletions tests/common/test_framework_detection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
# Copyright (c) 2023 Intel Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
import logging
import re
import unittest
from importlib.machinery import ModuleSpec
from typing import List
from unittest.mock import MagicMock

import pytest

import nncf

SUPPORTED_FRAMEWORKS = nncf._SUPPORTED_FRAMEWORKS # pylint:disable=protected-access
_REAL_FIND_SPEC = importlib._bootstrap._find_spec # pylint:disable=protected-access


class FailForModules:
def __init__(self, mocked_modules: List[str], hidden_modules: List[str]):
self._mocked_modules = mocked_modules
self._hidden_modules = hidden_modules

def __call__(self, fullname, path=None, target=None):
if fullname in self._hidden_modules:
return None
if fullname in self._mocked_modules:
return ModuleSpec(fullname, loader=MagicMock(), origin="foo/bar")
return _REAL_FIND_SPEC(fullname, path, target)


@pytest.mark.parametrize("ref_available_frameworks", [["torch"], ["torch", "tensorflow"], ["onnx", "openvino"], []])
def test_frameworks_detected(ref_available_frameworks: List[str], nncf_caplog, mocker):
unavailable_frameworks = [fw for fw in SUPPORTED_FRAMEWORKS if fw not in ref_available_frameworks]
failer = FailForModules(ref_available_frameworks, unavailable_frameworks)
with unittest.mock.patch("importlib.util.find_spec", wraps=failer):
with nncf_caplog.at_level(logging.INFO):
importlib.reload(nncf)
matches = re.search(r"Supported frameworks detected: (.*)", nncf_caplog.text)
if ref_available_frameworks:
assert matches is not None
match_text = matches[0]
for fw in ref_available_frameworks:
assert fw in match_text
for fw in unavailable_frameworks:
assert fw not in match_text
else:
assert matches is None

0 comments on commit c027c8b

Please sign in to comment.