diff --git a/comtypes/tools/codegenerator.py b/comtypes/tools/codegenerator.py index 4dca21e2..7aaf5b6a 100644 --- a/comtypes/tools/codegenerator.py +++ b/comtypes/tools/codegenerator.py @@ -1052,23 +1052,11 @@ def CoClass(self, coclass: typedesc.CoClass) -> None: print(file=self.stream) print(file=self.stream) - for itf, idlflags in coclass.interfaces: + for itf, _ in coclass.interfaces: self.generate(itf.get_head()) - implemented = [] - sources = [] - for item in coclass.interfaces: - # item is (interface class, impltypeflags) - if item[1] & 2: # IMPLTYPEFLAG_FSOURCE - # source interface - where = sources - else: - # sink interface - where = implemented - if item[1] & 1: # IMPLTYPEFLAG_FDEAULT - # The default interface should be the first item on the list - where.insert(0, self._to_type_name(item[0])) - else: - where.append(self._to_type_name(item[0])) + impl, src = typedesc.groupby_impltypeflags(coclass.interfaces) + implemented = [self._to_type_name(itf) for itf in impl] + sources = [self._to_type_name(itf) for itf in src] if implemented: self.last_item_class = False diff --git a/comtypes/tools/typedesc.py b/comtypes/tools/typedesc.py index 8bf167cc..015e009c 100644 --- a/comtypes/tools/typedesc.py +++ b/comtypes/tools/typedesc.py @@ -2,8 +2,9 @@ # in typedesc_base import ctypes -from typing import Any, List, Optional, Tuple, Union as _UnionT +from typing import Any, List, Optional, Sequence, Tuple, Union as _UnionT +from comtypes import typeinfo from comtypes.typeinfo import ITypeLib, TLIBATTR from comtypes.tools.typedesc_base import * @@ -195,6 +196,10 @@ def get_head(self) -> ComInterfaceHead: return self.itf_head +_ImplTypeFlags = int +_Interface = _UnionT[ComInterface, DispInterface] + + class CoClass(object): def __init__( self, name: str, clsid: str, idlflags: List[str], tlibattr: TLIBATTR @@ -203,7 +208,31 @@ def __init__( self.clsid = clsid self.idlflags = idlflags self.tlibattr = tlibattr - self.interfaces: List[Tuple[Any, int]] = [] + self.interfaces: List[Tuple[_Interface, _ImplTypeFlags]] = [] - def add_interface(self, itf: Any, idlflags: int) -> None: + def add_interface(self, itf: _Interface, idlflags: _ImplTypeFlags) -> None: self.interfaces.append((itf, idlflags)) + + +_ImplementedInterfaces = Sequence[_Interface] +_SourceInterfaces = Sequence[_Interface] + + +def groupby_impltypeflags( + seq: Sequence[Tuple[_Interface, _ImplTypeFlags]] +) -> Tuple[_ImplementedInterfaces, _SourceInterfaces]: + implemented = [] + sources = [] + for itf, impltypeflags in seq: + if impltypeflags & typeinfo.IMPLTYPEFLAG_FSOURCE: + # source interface + where = sources + else: + # sink interface + where = implemented + if impltypeflags & typeinfo.IMPLTYPEFLAG_FDEFAULT: + # The default interface should be the first item on the list + where.insert(0, itf) + else: + where.append(itf) + return implemented, sources