Skip to content

Commit

Permalink
Merge pull request #418 from Daverball/minor-type-improvements
Browse files Browse the repository at this point in the history
Improves mypy config. Adds explicit re-exports to `__init__.py`.
  • Loading branch information
malthe authored Apr 4, 2024
2 parents 0822ac9 + 0d52645 commit b757a4d
Show file tree
Hide file tree
Showing 8 changed files with 91 additions and 48 deletions.
8 changes: 7 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@ mypy_path = "$MYPY_CONFIG_FILE_DIR/src"
# we may want to include tests eventually
exclude = "/tests/"
follow_imports = "silent"
warn_redundant_casts = true
warn_unused_configs = true
warn_unused_ignores = true
warn_return_any = true

[[tool.mypy.overrides]]
# strict config for fully typed modules and public API
Expand All @@ -23,7 +27,9 @@ disallow_untyped_calls = true
disallow_untyped_defs = true
disallow_incomplete_defs = true
disallow_untyped_decorators = true
warn_unused_ignores = true
no_implicit_reexport = true
strict_equality = true
extra_checks = true

[[tool.mypy.overrides]]
module = ["zope.*"]
Expand Down
3 changes: 0 additions & 3 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,6 @@ universal = 0
[flake8]
doctests = 1
extend-select = TC1
# F401 imported but unused
per-file-ignores =
src/chameleon/__init__.py: F401

[check-manifest]
ignore =
Expand Down
10 changes: 10 additions & 0 deletions src/chameleon/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,13 @@
from chameleon.zpt.template import PageTemplateFile
from chameleon.zpt.template import PageTextTemplate
from chameleon.zpt.template import PageTextTemplateFile


__all__ = (
'TemplateError',
'PageTemplateLoader',
'PageTemplate',
'PageTemplateFile',
'PageTextTemplate',
'PageTextTemplateFile',
)
58 changes: 38 additions & 20 deletions src/chameleon/astutil.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,35 +5,44 @@
import ast
from copy import deepcopy
from typing import TYPE_CHECKING
from typing import Any
from typing import ClassVar


if TYPE_CHECKING:
from collections.abc import Callable
from typing import Union
_NodeTransform = Callable[[ast.AST], Union[ast.AST, None]]
from collections.abc import Hashable
from typing import Optional

from chameleon.tokenize import Token

_NodeTransform = Callable[[ast.AST], Optional[ast.AST]]


__docformat__ = 'restructuredtext en'


def parse(source, mode='eval'):
def parse(source, mode: str = 'eval') -> ast.AST:
return compile(source, '', mode, ast.PyCF_ONLY_AST)


def load(name):
def load(name: str) -> ast.Name:
return ast.Name(id=name, ctx=ast.Load())


def store(name):
def store(name: str) -> ast.Name:
return ast.Name(id=name, ctx=ast.Store())


def param(name):
def param(name: str) -> ast.Name:
return ast.Name(id=name, ctx=ast.Param())


def subscript(name, value, ctx):
def subscript(
name: str,
value: ast.expr,
ctx: ast.expr_context
) -> ast.Subscript:
return ast.Subscript(
value=value,
slice=ast.Index(value=ast.Str(s=name)),
Expand All @@ -47,7 +56,7 @@ class Node(ast.AST):

_fields: ClassVar[tuple[str, ...]] = ()

def __init__(self, *args, **kwargs) -> None:
def __init__(self, *args: Any, **kwargs: Any) -> None:
assert isinstance(self._fields, tuple)
self.__dict__.update(kwargs)
for name, value in zip(self._fields, args):
Expand Down Expand Up @@ -77,6 +86,7 @@ class Builtin(Node):

_fields = "id", "ctx"

id: str
ctx = ast.Load()


Expand All @@ -85,24 +95,32 @@ class Symbol(Node):

_fields = "value",

# Apart from a few builtins this should be type[Any]
value: type[Any] | Hashable


class Static(Node):
"""Represents a static value."""

_fields = "value", "name"

name = None
value: ast.expr
name: str | None = None


class Comment(Node):
_fields = "text",

text: str


class TokenRef(Node):
"""Represents a source-code token reference."""

_fields = "token",

token: Token


class NodeTransformerBase(ast.NodeTransformer):
def __init__(self, transform: _NodeTransform):
Expand All @@ -120,34 +138,34 @@ def __init__(self, transform: _NodeTransform):
self.scopes: list[set[str]] = [set()]
super().__init__(transform)

def __call__(self, node) -> ast.AST:
def __call__(self, node: ast.AST) -> ast.AST:
clone = deepcopy(node)
return self.visit(clone)
return self.visit(clone) # type: ignore[no-any-return]

def visit_arg(self, node) -> ast.AST:
def visit_arg(self, node: ast.arg) -> ast.AST:
scope = self.scopes[-1]
scope.add(node.arg)
return node

def visit_Name(self, node) -> ast.AST:
def visit_Name(self, node: ast.Name) -> ast.AST:
scope = self.scopes[-1]
if isinstance(node.ctx, ast.Param):
scope.add(node.id)
return node
if node.id not in scope:
node = self.apply_transform(node)
return self.apply_transform(node)
return node

def visit_FunctionDef(self, node) -> ast.AST:
def visit_FunctionDef(self, node: ast.FunctionDef) -> ast.AST:
self.scopes[-1].add(node.name)
return super().generic_visit(node)

def visit_alias(self, node) -> ast.AST:
def visit_alias(self, node: ast.alias) -> ast.AST:
name = node.asname if node.asname is not None else node.name
self.scopes[-1].add(name)
return super().generic_visit(node)

def visit_Lambda(self, node) -> ast.AST:
def visit_Lambda(self, node: ast.Lambda) -> ast.AST:
self.scopes.append(set())
try:
return super().generic_visit(node)
Expand All @@ -156,6 +174,6 @@ def visit_Lambda(self, node) -> ast.AST:


class ItemLookupOnAttributeErrorVisitor(NodeTransformerBase):
def visit_Attribute(self, node) -> ast.AST:
node = self.apply_transform(node)
return self.generic_visit(node)
def visit_Attribute(self, node: ast.Attribute) -> ast.AST:
transformed = self.apply_transform(node)
return self.generic_visit(transformed)
36 changes: 24 additions & 12 deletions src/chameleon/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
from ast import NodeTransformer
from ast import alias
from ast import unparse
from typing import TYPE_CHECKING
from typing import Any

from chameleon.astutil import Builtin
from chameleon.astutil import Symbol
Expand All @@ -24,7 +26,15 @@
from chameleon.exc import CompilationError


reverse_builtin_map = {}
if TYPE_CHECKING:
import ast
from collections.abc import Hashable

from chameleon.astutil import Comment
from chameleon.astutil import Static


reverse_builtin_map: dict[type[Any] | Hashable, str] = {}
for name, value in builtins.__dict__.items():
try:
hash(value)
Expand Down Expand Up @@ -60,7 +70,7 @@ def visit_FunctionDef(self, node) -> AST:
lineno=None,
)

def visit_Name(self, node) -> AST:
def visit_Name(self, node: ast.Name) -> AST:
value = symbols.get(node.id, self)
if value is self:
if node.id == 'None' or \
Expand All @@ -77,7 +87,7 @@ def visit_Name(self, node) -> AST:
if isinstance(value, str):
value = load(value)

return value
return value # type: ignore[no-any-return]

expr = parse(textwrap.dedent(source), mode=mode)

Expand All @@ -103,6 +113,8 @@ class TemplateCodeGenerator(NodeTransformer):

names = ()

imports: dict[type[Any] | Hashable, ast.Name]

def __init__(self, tree):
self.comments = []
self.defines = {}
Expand Down Expand Up @@ -142,7 +154,7 @@ def define(self, name, node):

return load(name)

def require(self, value):
def require(self, value: type[Any] | Hashable) -> ast.Name:
node = self.imports.get(value)
if node is None:
# we come up with a unique symbol based on the class name
Expand All @@ -155,9 +167,9 @@ def require(self, value):

return node

def visit_Module(self, module) -> AST:
def visit_Module(self, module: Module) -> AST:
assert isinstance(module, Module)
module = super().generic_visit(module)
module = super().generic_visit(module) # type: ignore[assignment]
preamble: list[AST] = []

for name, node in self.defines.items():
Expand Down Expand Up @@ -188,21 +200,21 @@ def visit_Module(self, module) -> AST:

return Module(imports + preamble + module.body, ())

def visit_Comment(self, node) -> AST:
def visit_Comment(self, node: Comment) -> AST:
self.comments.append(node.text)
return Expr(Constant(...))

def visit_Builtin(self, node) -> AST:
def visit_Builtin(self, node: Builtin) -> AST:
name = load(node.id)
return self.visit(name)
return self.visit(name) # type: ignore[no-any-return]

def visit_Symbol(self, node) -> AST:
def visit_Symbol(self, node: Symbol) -> AST:
return self.require(node.value)

def visit_Static(self, node) -> AST:
def visit_Static(self, node: Static) -> AST:
if node.name is None:
name = "_static_%s" % str(id(node.value)).replace('-', '_')
else:
name = node.name
node = self.define(name, node.value)
return self.visit(node)
return self.visit(node) # type: ignore[no-any-return]
18 changes: 9 additions & 9 deletions src/chameleon/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,24 +62,24 @@
RE_NAME = re.compile('^%s$' % NAME)


def identifier(prefix, suffix=None) -> str:
def identifier(prefix: str, suffix: str | None = None) -> str:
return "__{}_{}".format(prefix, mangle(suffix or id(prefix)))


def mangle(string):
def mangle(string: int | str) -> str:
return RE_MANGLE.sub('_', str(string)).replace('\n', '').replace('-', '_')


def load_econtext(name):
return template("getname(KEY)", KEY=ast.Str(s=name), mode="eval")


def store_econtext(name):
def store_econtext(name: object) -> ast.Subscript:
name = str(name)
return subscript(name, load("econtext"), ast.Store())


def store_rcontext(name):
def store_rcontext(name: object) -> ast.Subscript:
name = str(name)
return subscript(name, load("rcontext"), ast.Store())

Expand All @@ -101,7 +101,7 @@ def eval_token(token):
)


def indent(s):
def indent(s: str | None) -> str:
return textwrap.indent(s, " ") if s else ""


Expand Down Expand Up @@ -1024,16 +1024,16 @@ class Generator(TemplateCodeGenerator):

def visit_EmitText(self, node) -> ast.AST:
append = load(self.scopes[-1].append or "__append")
node = ast.Expr(ast.Call(
expr = ast.Expr(ast.Call(
func=append,
args=[ast.Str(s=node.s)],
keywords=[],
starargs=None,
kwargs=None
))
return self.visit(node)
return self.visit(expr) # type: ignore[no-any-return]

def visit_Name(self, node) -> ast.AST:
def visit_Name(self, node: ast.Name) -> ast.AST:
if isinstance(node.ctx, ast.Load):
scope = self.scopes[-1]
for name in ("append", "stream"):
Expand All @@ -1049,7 +1049,7 @@ def visit_TranslationContext(self, node) -> list[ast.AST]:
self.scopes.pop()
return stmts

def visit_TokenRef(self, node) -> ast.AST:
def visit_TokenRef(self, node: TokenRef) -> ast.AST:
self.tokens.append((node.token.pos, len(node.token)))
return ast.Assign(
[store("__token")],
Expand Down
4 changes: 2 additions & 2 deletions src/chameleon/i18n.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,11 +69,11 @@ def fast_translate(
return None

if target_language is not None or context is not None:
result = translate(
result: str = translate(
msgid, domain=domain, mapping=mapping, context=context,
target_language=target_language, default=default)
if result != msgid:
return result # type: ignore[no-any-return]
return result

if isinstance(msgid, Message):
default = msgid.default
Expand Down
2 changes: 1 addition & 1 deletion src/chameleon/template.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,7 +342,7 @@ def _compile(self, body: str, builtins: Collection[str]) -> str:
builtins=builtins,
strict=self.strict
)
return compiler.code
return compiler.code # type: ignore[no-any-return]


class BaseTemplateFile(BaseTemplate):
Expand Down

0 comments on commit b757a4d

Please sign in to comment.