Skip to content

Commit

Permalink
Modernize codegenerator.namespaces. (enthought#622)
Browse files Browse the repository at this point in the history
* Update `EnumerationNamespaces`.

* Update `DeclaredNamespaces`.

* Update `ImportedNamespaces`.
  • Loading branch information
junkmd authored Sep 23, 2024
1 parent ffaf502 commit 945800b
Showing 1 changed file with 22 additions and 15 deletions.
37 changes: 22 additions & 15 deletions comtypes/tools/codegenerator/namespaces.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,19 @@
import textwrap
from typing import Optional, Union as _UnionT
from typing import Dict, List, Set, Tuple
from typing import Mapping, Sequence


class ImportedNamespaces(object):
def __init__(self):
self.data = {}

def add(self, name1, name2=None, symbols=None):
def __init__(self) -> None:
self.data: Dict[str, Optional[str]] = {}

def add(
self,
name1: str,
name2: Optional[str] = None,
symbols: Optional[Mapping[str, str]] = None,
) -> None:
"""Adds a namespace will be imported.
Examples:
Expand Down Expand Up @@ -43,7 +50,7 @@ def add(self, name1, name2=None, symbols=None):
from_, import_ = name1, name2
self.data[import_] = from_

def __contains__(self, item):
def __contains__(self, item: _UnionT[str, Tuple[str, str]]) -> bool:
"""Returns item has already added.
Examples:
Expand Down Expand Up @@ -77,7 +84,7 @@ def get_symbols(self) -> Set[str]:
names.add(key)
return names

def _make_line(self, from_, imports):
def _make_line(self, from_: str, imports: Sequence[str]) -> str:
import_ = ", ".join(imports)
code = "from %s import %s" % (from_, import_)
if len(code) <= 80:
Expand All @@ -89,16 +96,16 @@ def _make_line(self, from_, imports):
code = "from %s import (\n%s\n)" % (from_, import_)
return code

def getvalue(self):
ns = {}
lines = []
def getvalue(self) -> str:
ns: Dict[str, Optional[Set[str]]] = {}
lines: List[str] = []
for key, val in self.data.items():
if val is None:
ns[key] = val
elif key == "*":
lines.append("from %s import *" % val)
else:
ns.setdefault(val, set()).add(key)
ns.setdefault(val, set()).add(key) # type: ignore
for key, val in ns.items():
if val is None:
lines.append("import %s" % key)
Expand All @@ -109,10 +116,10 @@ def getvalue(self):


class DeclaredNamespaces(object):
def __init__(self):
self.data = {}
def __init__(self) -> None:
self.data: Dict[Tuple[str, str], Optional[str]] = {}

def add(self, alias, definition, comment=None):
def add(self, alias: str, definition: str, comment: Optional[str] = None) -> None:
"""Adds a namespace will be declared.
Examples:
Expand All @@ -134,7 +141,7 @@ def get_symbols(self) -> Set[str]:
names.add(alias)
return names

def getvalue(self):
def getvalue(self) -> str:
lines = []
for (alias, definition), comment in self.data.items():
code = "%s = %s" % (alias, definition)
Expand All @@ -145,7 +152,7 @@ def getvalue(self):


class EnumerationNamespaces(object):
def __init__(self):
def __init__(self) -> None:
self.data: Dict[str, List[Tuple[str, int]]] = {}

def add(self, enum_name: str, member_name: str, value: int) -> None:
Expand Down

0 comments on commit 945800b

Please sign in to comment.