diff --git a/comtypes/tools/codegenerator.py b/comtypes/tools/codegenerator.py index 7aaf5b6a..8b04c9fa 100644 --- a/comtypes/tools/codegenerator.py +++ b/comtypes/tools/codegenerator.py @@ -445,7 +445,7 @@ def __init__(self, known_symbols=None, known_interfaces=None) -> None: self.done = set() # type descriptions that have been generated self.names = set() # names that have been generated self.externals = [] # typelibs imported to generated module - self.aliases: Dict[str, str] = {} + self.enum_aliases: Dict[str, str] = {} self.last_item_class = False def generate(self, item): @@ -594,7 +594,14 @@ def generate_wrapper_code( for n, v in self.unnamed_enum_members: print(f"{n} = {v}", file=output) print(file=output) - print(self.enums.to_constants(), file=output) + if self.enums: + print(self.enums.to_constants(), file=output) + print(file=output) + if self.enum_aliases: + print("# aliases for enums", file=output) + for k, v in self.enum_aliases.items(): + print(f"{k} = {v}", file=output) + print(file=output) print(self.stream.getvalue(), file=output) print(self._make_dunder_all_part(), file=output) print(file=output) @@ -619,12 +626,12 @@ def generate_friendly_code(self, modname: str) -> str: print(self._make_friendly_module_import_part(modname), file=output) print(file=output) print(file=output) - print(self.enums.to_intflags(), file=output) - print(file=output) - print(file=output) - enum_aliases = self.enum_aliases - if enum_aliases: - for k, v in enum_aliases.items(): + if self.enums: + print(self.enums.to_intflags(), file=output) + print(file=output) + print(file=output) + if self.enum_aliases: + for k, v in self.enum_aliases.items(): print(f"{k} = {v}", file=output) print(file=output) print(file=output) @@ -662,10 +669,6 @@ def _make_friendly_module_import_part(self, modname: str) -> str: part = f"from {modname} import (\n{joined_names}\n)" return part - @property - def enum_aliases(self) -> Dict[str, str]: - return {k: v for k, v in self.aliases.items() if v in self.enums} - def need_VARIANT_imports(self, value): text = repr(value) if "Decimal(" in text: @@ -690,7 +693,7 @@ def EnumValue(self, tp: typedesc.EnumValue) -> None: if keyword.iskeyword(tp.name): # XXX use logging! if __warn_on_munge__: - print("# Fixing keyword as EnumValue for %s" % tp.name) + print(f"# Fixing keyword as EnumValue for {tp.name}") tp_name = self._to_type_name(tp) if tp.enumeration.name: self.enums.add(tp.enumeration.name, tp_name, value) @@ -716,9 +719,11 @@ def Typedef(self, tp: typedesc.Typedef) -> None: if definition in self.known_symbols: self.declarations.add(tp.name, definition) else: - print("%s = %s" % (tp.name, definition), file=self.stream) - self.aliases[tp.name] = definition - self.last_item_class = False + if isinstance(tp.typ, typedesc.Enumeration): + self.enum_aliases[tp.name] = definition + else: + print(f"{tp.name} = {definition}", file=self.stream) + self.last_item_class = False self.names.add(tp.name) def FundamentalType(self, item: typedesc.FundamentalType) -> None: @@ -1610,7 +1615,9 @@ def add(self, enum_name: str, member_name: str, value: int) -> None: Examples: is necessary for doctest >>> enums = EnumerationNamespaces() + >>> assert not enums >>> enums.add('Foo', 'ham', 1) + >>> assert enums >>> enums.add('Foo', 'spam', 2) >>> enums.add('Bar', 'bacon', 3) >>> assert 'Foo' in enums @@ -1638,6 +1645,9 @@ class Bar(IntFlag): def __contains__(self, item: str) -> bool: return item in self.data + def __bool__(self) -> bool: + return bool(self.data) + def get_symbols(self) -> Set[str]: return set(self.data)