From 81bdb667912956ec05dbd588de8ec5a5584d2f8b Mon Sep 17 00:00:00 2001 From: junkmd Date: Fri, 12 Apr 2024 08:11:27 +0900 Subject: [PATCH] fix `ModuleGenerator` and unify `_create_module_...` --- comtypes/client/_generate.py | 133 ++++++++++++++++------------------- 1 file changed, 62 insertions(+), 71 deletions(-) diff --git a/comtypes/client/_generate.py b/comtypes/client/_generate.py index 9468293e..ccf86b9e 100644 --- a/comtypes/client/_generate.py +++ b/comtypes/client/_generate.py @@ -121,7 +121,7 @@ def GetModule(tlib: _UnionT[Any, typeinfo.ITypeLib]) -> types.ModuleType: pathname = None tlib = _load_tlib(tlib) logger.debug("GetModule(%s)", tlib.GetLibAttr()) - return ModuleGenerator().generate(tlib, pathname) + return ModuleGenerator(tlib, pathname).generate() def _load_tlib(obj: Any) -> typeinfo.ITypeLib: @@ -160,90 +160,81 @@ def _load_tlib(obj: Any) -> typeinfo.ITypeLib: raise TypeError("'%r' is not supported type for loading typelib" % obj) -def _create_module_in_file(modulename: str, code: str) -> types.ModuleType: - """create module in file system, and import it""" +def _create_module(modulename: str, code: str) -> types.ModuleType: + """Creates the module, then imports it.""" # `modulename` is 'comtypes.gen.xxx' - filename = "%s.py" % modulename.split(".")[-1] - with open(os.path.join(comtypes.client.gen_dir, filename), "w") as ofi: + stem = modulename.split(".")[-1] + if comtypes.client.gen_dir is None: + # in memory system + import comtypes.gen as g + + mod = types.ModuleType(modulename) + abs_gen_path = os.path.abspath(g.__path__[0]) # type: ignore + mod.__file__ = os.path.join(abs_gen_path, "") + exec(code, mod.__dict__) + sys.modules[modulename] = mod + setattr(g, stem, mod) + return mod + # in file system + with open(os.path.join(comtypes.client.gen_dir, f"{stem}.py"), "w") as ofi: print(code, file=ofi) # clear the import cache to make sure Python sees newly created modules - if hasattr(importlib, "invalidate_caches"): - importlib.invalidate_caches() + importlib.invalidate_caches() return _my_import(modulename) -def _create_module_in_memory(modulename: str, code: str) -> types.ModuleType: - """create module in memory system, and import it""" - # `modulename` is 'comtypes.gen.xxx' - import comtypes.gen as g - - mod = types.ModuleType(modulename) - abs_gen_path = os.path.abspath(g.__path__[0]) # type: ignore - mod.__file__ = os.path.join(abs_gen_path, "") - exec(code, mod.__dict__) - sys.modules[modulename] = mod - setattr(g, modulename.split(".")[-1], mod) - return mod - - class ModuleGenerator(object): - def __init__(self) -> None: - self.codegen = codegenerator.CodeGenerator(_get_known_symbols()) - - def generate( - self, tlib: typeinfo.ITypeLib, pathname: Optional[str] - ) -> types.ModuleType: - # create and import the real typelib wrapper module - mod = self._create_wrapper_module(tlib, pathname) - # try to get the friendly-name, if not, returns the real typelib wrapper module - modulename = codegenerator.name_friendly_module(tlib) - if modulename is None: - return mod - # create and import the friendly-named module - return self._create_friendly_module(tlib, modulename) + def __init__(self, tlib: typeinfo.ITypeLib, pathname: Optional[str]) -> None: + self.wrapper_name = codegenerator.name_wrapper_module(tlib) + self.friendly_name = codegenerator.name_friendly_module(tlib) + if pathname is None: + self.pathname = tlbparser.get_tlib_filename(tlib) + else: + self.pathname = pathname + self.tlib = tlib + + def generate(self) -> types.ModuleType: + # tries to import existing modules + wrapper_module = self._get_existing_wrapper_module() + if wrapper_module is not None: + if self.friendly_name is None: + return wrapper_module + else: + friendly_module = self._get_existing_friendly_module() + if friendly_module is not None: + return friendly_module + # (re)generates wrapper and friendly modules + codegen = codegenerator.CodeGenerator(_get_known_symbols()) + codebases: List[Tuple[str, str]] = [] + logger.info("# Generating %s", self.wrapper_name) + items = list(tlbparser.TypeLibParser(self.tlib).parse().values()) + wrp_code = codegen.generate_wrapper_code(items, filename=self.pathname) + codebases.append((self.wrapper_name, wrp_code)) + if self.friendly_name is not None: + logger.info("# Generating %s", self.friendly_name) + frd_code = codegen.generate_friendly_code(self.wrapper_name) + codebases.append((self.friendly_name, frd_code)) + for ext_tlib in codegen.externals: # generates dependency COM-lib modules + GetModule(ext_tlib) + return [_create_module(name, code) for (name, code) in codebases][-1] - def _create_friendly_module( - self, tlib: typeinfo.ITypeLib, modulename: str - ) -> types.ModuleType: - """helper which creates and imports the friendly-named module.""" + def _get_existing_friendly_module(self) -> Optional[types.ModuleType]: + if self.friendly_name is None: + return try: - mod = _my_import(modulename) + mod = _my_import(self.friendly_name) except Exception as details: - logger.info("Could not import %s: %s", modulename, details) + logger.info("Could not import %s: %s", self.friendly_name, details) else: return mod - # the module is always regenerated if the import fails - logger.info("# Generating %s", modulename) - # determine the Python module name - modname = codegenerator.name_wrapper_module(tlib) - code = self.codegen.generate_friendly_code(modname) - if comtypes.client.gen_dir is None: - return _create_module_in_memory(modulename, code) - return _create_module_in_file(modulename, code) - - def _create_wrapper_module( - self, tlib: typeinfo.ITypeLib, pathname: Optional[str] - ) -> types.ModuleType: - """helper which creates and imports the real typelib wrapper module.""" - modulename = codegenerator.name_wrapper_module(tlib) - if modulename in sys.modules: - return sys.modules[modulename] + + def _get_existing_wrapper_module(self) -> Optional[types.ModuleType]: + if self.wrapper_name in sys.modules: + return sys.modules[self.wrapper_name] try: - return _my_import(modulename) + return _my_import(self.wrapper_name) except Exception as details: - logger.info("Could not import %s: %s", modulename, details) - # generate the module since it doesn't exist or is out of date - logger.info("# Generating %s", modulename) - p = tlbparser.TypeLibParser(tlib) - if pathname is None: - pathname = tlbparser.get_tlib_filename(tlib) - items = list(p.parse().values()) - code = self.codegen.generate_wrapper_code(items, filename=pathname) - for ext_tlib in self.codegen.externals: # generates dependency COM-lib modules - GetModule(ext_tlib) - if comtypes.client.gen_dir is None: - return _create_module_in_memory(modulename, code) - return _create_module_in_file(modulename, code) + logger.info("Could not import %s: %s", self.wrapper_name, details) def _get_known_symbols() -> Dict[str, str]: