Skip to content

Commit

Permalink
fix ModuleGenerator and unify _create_module_...
Browse files Browse the repository at this point in the history
  • Loading branch information
junkmd committed Apr 11, 2024
1 parent 040152f commit 81bdb66
Showing 1 changed file with 62 additions and 71 deletions.
133 changes: 62 additions & 71 deletions comtypes/client/_generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ def GetModule(tlib: _UnionT[Any, typeinfo.ITypeLib]) -> types.ModuleType:
pathname = None
tlib = _load_tlib(tlib)
logger.debug("GetModule(%s)", tlib.GetLibAttr())
return ModuleGenerator().generate(tlib, pathname)
return ModuleGenerator(tlib, pathname).generate()


def _load_tlib(obj: Any) -> typeinfo.ITypeLib:
Expand Down Expand Up @@ -160,90 +160,81 @@ def _load_tlib(obj: Any) -> typeinfo.ITypeLib:
raise TypeError("'%r' is not supported type for loading typelib" % obj)


def _create_module_in_file(modulename: str, code: str) -> types.ModuleType:
"""create module in file system, and import it"""
def _create_module(modulename: str, code: str) -> types.ModuleType:
"""Creates the module, then imports it."""
# `modulename` is 'comtypes.gen.xxx'
filename = "%s.py" % modulename.split(".")[-1]
with open(os.path.join(comtypes.client.gen_dir, filename), "w") as ofi:
stem = modulename.split(".")[-1]
if comtypes.client.gen_dir is None:
# in memory system
import comtypes.gen as g

mod = types.ModuleType(modulename)
abs_gen_path = os.path.abspath(g.__path__[0]) # type: ignore
mod.__file__ = os.path.join(abs_gen_path, "<memory>")
exec(code, mod.__dict__)
sys.modules[modulename] = mod
setattr(g, stem, mod)
return mod
# in file system
with open(os.path.join(comtypes.client.gen_dir, f"{stem}.py"), "w") as ofi:
print(code, file=ofi)
# clear the import cache to make sure Python sees newly created modules
if hasattr(importlib, "invalidate_caches"):
importlib.invalidate_caches()
importlib.invalidate_caches()
return _my_import(modulename)


def _create_module_in_memory(modulename: str, code: str) -> types.ModuleType:
"""create module in memory system, and import it"""
# `modulename` is 'comtypes.gen.xxx'
import comtypes.gen as g

mod = types.ModuleType(modulename)
abs_gen_path = os.path.abspath(g.__path__[0]) # type: ignore
mod.__file__ = os.path.join(abs_gen_path, "<memory>")
exec(code, mod.__dict__)
sys.modules[modulename] = mod
setattr(g, modulename.split(".")[-1], mod)
return mod


class ModuleGenerator(object):
def __init__(self) -> None:
self.codegen = codegenerator.CodeGenerator(_get_known_symbols())

def generate(
self, tlib: typeinfo.ITypeLib, pathname: Optional[str]
) -> types.ModuleType:
# create and import the real typelib wrapper module
mod = self._create_wrapper_module(tlib, pathname)
# try to get the friendly-name, if not, returns the real typelib wrapper module
modulename = codegenerator.name_friendly_module(tlib)
if modulename is None:
return mod
# create and import the friendly-named module
return self._create_friendly_module(tlib, modulename)
def __init__(self, tlib: typeinfo.ITypeLib, pathname: Optional[str]) -> None:
self.wrapper_name = codegenerator.name_wrapper_module(tlib)
self.friendly_name = codegenerator.name_friendly_module(tlib)
if pathname is None:
self.pathname = tlbparser.get_tlib_filename(tlib)
else:
self.pathname = pathname
self.tlib = tlib

def generate(self) -> types.ModuleType:
# tries to import existing modules
wrapper_module = self._get_existing_wrapper_module()
if wrapper_module is not None:
if self.friendly_name is None:
return wrapper_module
else:
friendly_module = self._get_existing_friendly_module()
if friendly_module is not None:
return friendly_module
# (re)generates wrapper and friendly modules
codegen = codegenerator.CodeGenerator(_get_known_symbols())
codebases: List[Tuple[str, str]] = []
logger.info("# Generating %s", self.wrapper_name)
items = list(tlbparser.TypeLibParser(self.tlib).parse().values())
wrp_code = codegen.generate_wrapper_code(items, filename=self.pathname)
codebases.append((self.wrapper_name, wrp_code))
if self.friendly_name is not None:
logger.info("# Generating %s", self.friendly_name)
frd_code = codegen.generate_friendly_code(self.wrapper_name)
codebases.append((self.friendly_name, frd_code))
for ext_tlib in codegen.externals: # generates dependency COM-lib modules
GetModule(ext_tlib)
return [_create_module(name, code) for (name, code) in codebases][-1]

def _create_friendly_module(
self, tlib: typeinfo.ITypeLib, modulename: str
) -> types.ModuleType:
"""helper which creates and imports the friendly-named module."""
def _get_existing_friendly_module(self) -> Optional[types.ModuleType]:
if self.friendly_name is None:
return
try:
mod = _my_import(modulename)
mod = _my_import(self.friendly_name)
except Exception as details:
logger.info("Could not import %s: %s", modulename, details)
logger.info("Could not import %s: %s", self.friendly_name, details)
else:
return mod
# the module is always regenerated if the import fails
logger.info("# Generating %s", modulename)
# determine the Python module name
modname = codegenerator.name_wrapper_module(tlib)
code = self.codegen.generate_friendly_code(modname)
if comtypes.client.gen_dir is None:
return _create_module_in_memory(modulename, code)
return _create_module_in_file(modulename, code)

def _create_wrapper_module(
self, tlib: typeinfo.ITypeLib, pathname: Optional[str]
) -> types.ModuleType:
"""helper which creates and imports the real typelib wrapper module."""
modulename = codegenerator.name_wrapper_module(tlib)
if modulename in sys.modules:
return sys.modules[modulename]

def _get_existing_wrapper_module(self) -> Optional[types.ModuleType]:
if self.wrapper_name in sys.modules:
return sys.modules[self.wrapper_name]
try:
return _my_import(modulename)
return _my_import(self.wrapper_name)
except Exception as details:
logger.info("Could not import %s: %s", modulename, details)
# generate the module since it doesn't exist or is out of date
logger.info("# Generating %s", modulename)
p = tlbparser.TypeLibParser(tlib)
if pathname is None:
pathname = tlbparser.get_tlib_filename(tlib)
items = list(p.parse().values())
code = self.codegen.generate_wrapper_code(items, filename=pathname)
for ext_tlib in self.codegen.externals: # generates dependency COM-lib modules
GetModule(ext_tlib)
if comtypes.client.gen_dir is None:
return _create_module_in_memory(modulename, code)
return _create_module_in_file(modulename, code)
logger.info("Could not import %s: %s", self.wrapper_name, details)


def _get_known_symbols() -> Dict[str, str]:
Expand Down

0 comments on commit 81bdb66

Please sign in to comment.