Skip to content

Commit

Permalink
Replace file based mocking by replacing __import__.
Browse files Browse the repository at this point in the history
  • Loading branch information
terrorfisch committed Sep 28, 2018
1 parent 9ab256a commit f4982a7
Showing 1 changed file with 28 additions and 25 deletions.
53 changes: 28 additions & 25 deletions tests/utils/time_type_tests.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import sys
import unittest
import os
import builtins
import contextlib
import importlib
import fractions
Expand All @@ -14,23 +14,29 @@
import qupulse.utils.types as qutypes


def mock_missing_gmpy2(exit_stack: contextlib.ExitStack):
if os.path.exists('gmpy2.py'):
raise RuntimeError('Cannot mock missing gmpy2 due to existing file')
@contextlib.contextmanager
def mock_missing_module(module_name: str):
exit_stack = contextlib.ExitStack()

with open('gmpy2.py', 'w') as gmpy_file:
exit_stack.callback(os.remove, 'gmpy2.py')
if module_name in sys.modules:
# temporarily remove gmpy2 from the imported modules

gmpy_file.write("raise ImportError()")
temp_modules = sys.modules.copy()
del temp_modules[module_name]
exit_stack.enter_context(mock.patch.dict(sys.modules, temp_modules))

if 'gmpy2' in sys.modules:
modules_patcher = mock.patch.dict(sys.modules,
values=((name, module)
for name, module in sys.modules.items()
if name != 'gmpy2'),
clear=True)
modules_patcher.__enter__()
exit_stack.push(modules_patcher)
original_import = builtins.__import__

def mock_import(name, *args, **kwargs):
if name == module_name:
raise ImportError(name)
else:
return original_import(name, *args, **kwargs)

exit_stack.enter_context(mock.patch('builtins.__import__', mock_import))

with exit_stack:
yield


class TestTimeType(unittest.TestCase):
Expand All @@ -39,25 +45,22 @@ class TestTimeType(unittest.TestCase):
@property
def fallback_qutypes(self):
if not self._fallback_qutypes:
exit_stack = contextlib.ExitStack()

with exit_stack:

if gmpy2:
# create a local file that raises ImportError on import
mock_missing_gmpy2(exit_stack)

if gmpy2:
with mock_missing_module('gmpy2'):
self._fallback_qutypes = importlib.reload(qutypes)

else:
self._fallback_qutypes = qutypes
else:
self._fallback_qutypes = qutypes
return self._fallback_qutypes

def test_fraction_fallback(self):
self.assertIs(fractions.Fraction, self.fallback_qutypes.TimeType)

@unittest.skipIf(gmpy2 is None, "gmpy2 not available.")
def test_default_time_from_float(self):
# assert mocking did no permanent damage
self.assertIs(gmpy2.mpq, qutypes.TimeType)

self.assertEqual(qutypes.time_from_float(123/931), gmpy2.mpq(123, 931))

self.assertEqual(qutypes.time_from_float(1000000/1000001, 1e-5), gmpy2.mpq(1))
Expand Down

0 comments on commit f4982a7

Please sign in to comment.