Skip to content

Commit

Permalink
Better test
Browse files Browse the repository at this point in the history
  • Loading branch information
vshampor committed Aug 22, 2023
1 parent 5d288bd commit 992c4fa
Showing 1 changed file with 21 additions and 15 deletions.
36 changes: 21 additions & 15 deletions tests/common/test_framework_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
import logging
import re
import sys
import unittest
from importlib import import_module
from importlib.machinery import ModuleSpec
from typing import List
Expand All @@ -22,31 +24,35 @@
import nncf

SUPPORTED_FRAMEWORKS = nncf._SUPPORTED_FRAMEWORKS # pylint:disable=protected-access
_REAL_FIND_SPEC = importlib._bootstrap._find_spec # pylint:disable=protected-access


@pytest.mark.parametrize("ref_available_frameworks", [["torch"], ["torch", "tensorflow"], ["onnx", "openvino"], []])
def test_frameworks_detected(ref_available_frameworks: List[str], nncf_caplog):
with mock.patch.dict(sys.modules):
for supp_fw in SUPPORTED_FRAMEWORKS:
if supp_fw in sys.modules:
del sys.modules[supp_fw]
del sys.modules["nncf"]

for fw in ref_available_frameworks:
mock_spec = ModuleSpec(fw, loader=MagicMock(), origin="foo/bar")
module = MagicMock()
module.__spec__ = mock_spec
sys.modules[fw] = module
class FailForModules:
def __init__(self, mocked_modules: List[str], hidden_modules: List[str]):
self._mocked_modules = mocked_modules
self._hidden_modules = hidden_modules

def __call__(self, fullname, path=None, target=None):
if fullname in self._hidden_modules:
return None
elif fullname in self._mocked_modules:
return ModuleSpec(fullname, loader=MagicMock(), origin="foo/bar")
return _REAL_FIND_SPEC(fullname, path, target)


@pytest.mark.parametrize("ref_available_frameworks", [["torch"], ["torch", "tensorflow"], ["onnx", "openvino"], []])
def test_frameworks_detected(ref_available_frameworks: List[str], nncf_caplog, mocker):
unavailable_frameworks = [fw for fw in SUPPORTED_FRAMEWORKS if fw not in ref_available_frameworks]
failer = FailForModules(ref_available_frameworks, unavailable_frameworks)
with unittest.mock.patch("importlib.util.find_spec", wraps=failer) as mocked_fs:
with nncf_caplog.at_level(logging.INFO):
import_module("nncf")
importlib.reload(nncf)
matches = re.search(r"Supported frameworks detected: (.*)", nncf_caplog.text)
if ref_available_frameworks:
assert matches is not None
match_text = matches[0]
for fw in ref_available_frameworks:
assert fw in match_text
unavailable_frameworks = [fw for fw in SUPPORTED_FRAMEWORKS if fw not in ref_available_frameworks]
for fw in unavailable_frameworks:
assert fw not in match_text
else:
Expand Down

0 comments on commit 992c4fa

Please sign in to comment.