Skip to content

Commit

Permalink
Add format_expr utility function, and only use astor for Python <…
Browse files Browse the repository at this point in the history
…3.9.
  • Loading branch information
matthewwardrop committed Dec 24, 2023
1 parent 56202d9 commit 84ebde7
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 13 deletions.
17 changes: 13 additions & 4 deletions formulaic/utils/code.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,27 @@
import ast
import keyword
import re
from typing import MutableMapping
import sys
from typing import MutableMapping, Union

import astor
import numpy

from .iterators import peekable_iter

# Expression formatting


def format_expr(expr: str) -> str:
code = ast.parse(expr, mode="eval")
def format_expr(expr: Union[str, ast.AST]) -> str: # pragma: no cover; branched code
if sys.version_info >= (3, 9):
code = ast.parse(expr, mode="eval") if isinstance(expr, str) else expr
return ast.unparse(code).replace("\n", " ")

import astor # pylint: disable=import-error

# Note: We use `mode="exec"` here because `astor` inserts parentheses around
# expressions that cannot be naively removed. We still require that these
# are `eval`-uable in the `stateful_eval` method.
code = ast.parse(expr, mode="exec") if isinstance(expr, str) else expr
return astor.to_source(code).strip().replace("\n ", "")


Expand Down
10 changes: 3 additions & 7 deletions formulaic/utils/stateful_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,7 @@
cast,
)

import astor

from .code import sanitize_variable_names
from .code import format_expr, sanitize_variable_names
from .layered_mapping import LayeredMapping
from .variables import get_expression_variables, Variable

Expand Down Expand Up @@ -148,9 +146,7 @@ def stateful_eval(
stateful_nodes: Dict[str, ast.Call] = {}
for node in ast.walk(code):
if _is_stateful_transform(node, env):
stateful_nodes[astor.to_source(node).strip().replace("\n ", "")] = cast(
ast.Call, node
)
stateful_nodes[format_expr(node)] = cast(ast.Call, node)

# Mutate stateful nodes to pass in state from a shared dictionary.
for name, node in stateful_nodes.items():
Expand Down Expand Up @@ -221,7 +217,7 @@ def _is_stateful_transform(node: ast.AST, env: Mapping) -> bool:

try:
func = eval(
compile(astor.to_source(node.func).strip(), "", "eval"), {}, env
compile(format_expr(node.func), "", "eval"), {}, env
) # nosec; Get function handle (assuming it exists in env)
return getattr(func, "__is_stateful_transform__", False)
except NameError:
Expand Down
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ classifiers = [
]
requires-python = ">=3.7.2"
dependencies = [
"astor>=0.8",
"astor>=0.8; python_version < \"3.9\"",
"cached-property>=1.3.0; python_version < \"3.8\"",
"graphlib-backport>=1.0.0; python_version < \"3.9\"",
"interface-meta>=1.2.0",
Expand Down Expand Up @@ -92,7 +92,7 @@ dependencies = [
"formulaic[arrow,calculus]",
"pytest==7.2.0",
"pytest-cov==4.0.0",
"astor==0.8",
"astor==0.8; python_version < \"3.9\"",
"cached-property==1.3.0; python_version < \"3.8\"",
"graphlib-backport==1.0.0; python_version < \"3.9\"",
"interface-meta==1.2.0",
Expand Down

0 comments on commit 84ebde7

Please sign in to comment.