diff --git a/.gitignore b/.gitignore index 4f13394..58f2d90 100644 --- a/.gitignore +++ b/.gitignore @@ -49,3 +49,5 @@ env/ pyvenv.cfg share/* + +venv diff --git a/thriftpy2/hook.py b/thriftpy2/hook.py index fd35386..2ae383b 100644 --- a/thriftpy2/hook.py +++ b/thriftpy2/hook.py @@ -3,25 +3,49 @@ from __future__ import absolute_import import sys +import importlib.abc +import importlib.util +import types from .parser import load_module -class ThriftImporter(object): - def __init__(self, extension="_thrift"): - self.extension = extension - def __eq__(self, other): - return self.__class__.__module__ == other.__class__.__module__ and \ - self.__class__.__name__ == other.__class__.__name__ and \ - self.extension == other.extension +# TODO: The load process does not compatible with Python standard, e.g., if the +# specified thrift file does not exists, it raises FileNotFoundError, and skiped +# the other meta finders in the sys.meta_path. +if sys.version_info >= (3, 4): + class ThriftImporter(importlib.abc.MetaPathFinder): + def __init__(self, extension="_thrift"): + self.extension = extension - def find_module(self, fullname, path=None): - if fullname.endswith(self.extension): - return self + def find_spec(self, fullname, path, target=None): + if not fullname.endswith(self.extension): + return None + return importlib.util.spec_from_loader(fullname, + ThriftLoader(fullname)) - def load_module(self, fullname): - return load_module(fullname) + + class ThriftLoader(importlib.abc.Loader): + def __init__(self, fullname): + self.fullname = fullname + + def create_module(self, spec): + return load_module(self.fullname) + + def exec_module(self, module): + pass +else: + class ThriftImporter(object): + def __init__(self, extension="_thrift"): + self.extension = extension + + def find_module(self, fullname, path=None): + if fullname.endswith(self.extension): + return self + + def load_module(self, fullname): + return load_module(fullname) _imp = ThriftImporter() @@ -29,9 +53,9 @@ def load_module(self, fullname): def install_import_hook(): global _imp - sys.meta_path[:] = [x for x in sys.meta_path if _imp != x] + [_imp] + sys.meta_path[:] = [x for x in sys.meta_path if _imp is not x] + [_imp] def remove_import_hook(): global _imp - sys.meta_path[:] = [x for x in sys.meta_path if _imp != x] + sys.meta_path[:] = [x for x in sys.meta_path if _imp is not x]