diff --git a/defusedxml/common.py b/defusedxml/common.py index 668b609..6c0b79e 100644 --- a/defusedxml/common.py +++ b/defusedxml/common.py @@ -75,13 +75,25 @@ class NotSupportedError(DefusedXmlException): def _apply_defusing(defused_mod): assert defused_mod is sys.modules[defused_mod.__name__] stdlib_name = defused_mod.__origin__ - __import__(stdlib_name, {}, {}, ["*"]) - stdlib_mod = sys.modules[stdlib_name] - stdlib_names = set(dir(stdlib_mod)) - for name, obj in vars(defused_mod).items(): - if name.startswith("_") or name not in stdlib_names: - continue - setattr(stdlib_mod, name, obj) + if PY3: + from unittest.mock import patch + for name, obj in vars(defused_mod).items(): + if name.startswith("_"): + continue + try: + patcher = patch(stdlib_name + '.' + name, obj) + patcher.start() + except AttributeError: + pass + stdlib_mod = patcher + else: + __import__(stdlib_name, {}, {}, ["*"]) + stdlib_mod = sys.modules[stdlib_name] + stdlib_names = set(dir(stdlib_mod)) + for name, obj in vars(defused_mod).items(): + if name.startswith("_") or name not in stdlib_names: + continue + setattr(stdlib_mod, name, obj) return stdlib_mod diff --git a/tests.py b/tests.py index 145172c..f461b9d 100644 --- a/tests.py +++ b/tests.py @@ -482,6 +482,26 @@ def test_defused_gzip_response(self): self.decode_response(response, 4095, 8192) +def get_std_module(defused_module): + name = defused_module.__origin__ + obj = __import__(name, globals(), locals(), [], 0) + for part in name.split('.')[1:]: + obj = getattr(obj, part) + return obj + + +class TestStdElementTree(TestDefusedElementTree): + module = get_std_module(ElementTree) + + +class TestStdMinidom(TestDefusedMinidom): + module = get_std_module(minidom) + + +class TestStdPulldom(TestDefusedPulldom): + module = get_std_module(pulldom) + + def test_main(): suite = unittest.TestSuite() suite.addTests(unittest.makeSuite(TestDefusedcElementTree)) @@ -497,9 +517,18 @@ def test_main(): return suite +def test_origin(): + suite = unittest.TestSuite() + suite.addTests(unittest.makeSuite(TestStdElementTree)) + suite.addTests(unittest.makeSuite(TestStdMinidom)) + return suite + + if __name__ == "__main__": suite = test_main() result = unittest.TextTestRunner(verbosity=1).run(suite) - # TODO: test that it actually works defuse_stdlib() - sys.exit(not result.wasSuccessful()) + suite = test_origin() + result_std = unittest.TextTestRunner(verbosity=1).run(suite) + success = result.wasSuccessful() and result_std.wasSuccessful() + sys.exit(not success)