Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add token sanitization to make equivalent Python factors share the same formatting. #167

Merged
merged 3 commits into from
Dec 24, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 35 additions & 0 deletions formulaic/parser/algos/sanitize_tokens.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
from typing import Dict, Iterable

from formulaic.utils.code import format_expr, sanitize_variable_names

from ..types.token import Token


def sanitize_tokens(tokens: Iterable[Token]) -> Iterable[Token]:
    """
    Lazily apply hygiene transforms to a stream of tokens.

    Tokens carry user-supplied code, so they are normalised here to keep
    downstream behaviour consistent. The transforms currently applied are:
        - tokens of kind `Token.Kind.PYTHON` have their `token` text passed
          through `sanitize_python_code`, so that set operators and stateful
          transforms can recognise equivalent tokens.
        - possibly more in the future.

    Args:
        tokens: The tokens to sanitize.

    Yields:
        The same token objects, in their original order. Python tokens are
        updated in place before being yielded.
    """
    for token in tokens:
        # Anything that is not Python code passes through untouched.
        if token.kind is not Token.Kind.PYTHON:
            yield token
            continue
        token.token = sanitize_python_code(token.token)
        yield token


def sanitize_python_code(expr: str) -> str:
    """
    Ensure that Python code is consistently formatted, and that quoted portions
    (by backticks) are properly handled.

    Backtick-quoted names are first swapped for placeholder identifiers (via
    `sanitize_variable_names`) so that the expression can be run through
    `format_expr`; the placeholders are then swapped back for the original
    backtick-quoted names.

    Args:
        expr: The Python expression to sanitize.

    Returns:
        The formatted expression, with backtick quoting restored.
    """
    import re  # Local import: only needed for restoring the quoted names.

    aliases: Dict[str, str] = {}
    expr = format_expr(
        sanitize_variable_names(expr, {}, aliases, template="_formulaic_{}")
    )
    if not aliases:
        return expr

    # Restore every alias in a single pass, matching whole identifiers only.
    # A plain `str.replace` would also rewrite substrings of unrelated names
    # (a quoted name that is already a valid identifier is its own alias, so
    # "`a` + bar" would become "`a` + b`a`r"), and sequential replacement could
    # rewrite text that an earlier replacement had just restored. The
    # look-behind also rejects ".", so attribute names are left alone.
    alias_matcher = re.compile(
        r"(?<![\w.])(?:"
        + "|".join(
            re.escape(alias) for alias in sorted(aliases, key=len, reverse=True)
        )
        + r")(?!\w)"
    )
    # A callable replacement avoids backslashes in the original name being
    # interpreted as regex escape sequences.
    return alias_matcher.sub(lambda match: f"`{aliases[match.group(0)]}`", expr)
6 changes: 3 additions & 3 deletions formulaic/parser/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@
from dataclasses import dataclass, field
from typing import List, Iterable, Sequence, Tuple, Union, cast

from formulaic.parser.types.factor import Factor

from .algos.sanitize_tokens import sanitize_tokens
from .algos.tokenize import tokenize
from .types import (
Factor,
FormulaParser,
Operator,
OperatorResolver,
Expand Down Expand Up @@ -71,7 +71,7 @@ def get_tokens(self, formula: str) -> Iterable[Token]:
token_plus = Token("+", kind=Token.Kind.OPERATOR)
token_minus = Token("-", kind=Token.Kind.OPERATOR)

tokens = tokenize(formula)
tokens = sanitize_tokens(tokenize(formula))

# Substitute "0" with "-1"
tokens = replace_tokens(
Expand Down
3 changes: 2 additions & 1 deletion formulaic/parser/types/formula_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,9 @@ def get_tokens(self, formula: str) -> Iterable[Token]:
formula: The formula string to be tokenized.
"""
from ..algos.tokenize import tokenize
from ..algos.sanitize_tokens import sanitize_tokens

return tokenize(formula)
return sanitize_tokens(tokenize(formula))

def get_ast(self, formula: str) -> Union[None, Token, ASTNode]:
"""
Expand Down
121 changes: 121 additions & 0 deletions formulaic/utils/code.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
import ast
import keyword
import re
import sys
from typing import MutableMapping, Union

import numpy

from .iterators import peekable_iter

# Expression formatting


def format_expr(expr: Union[str, ast.AST]) -> str:  # pragma: no cover; branched code
    """
    Render a Python expression as consistently formatted source code.

    Args:
        expr: Either source code (which is parsed first) or an already parsed
            AST node.

    Returns:
        The regenerated source code. Python 3.9+ uses `ast.unparse`, with any
        newlines replaced by spaces; older versions fall back to `astor`.
    """
    if sys.version_info < (3, 9):
        import astor  # pylint: disable=import-error

        # Note: `mode="exec"` is used here because `astor` wraps expressions
        # in parentheses that cannot be naively removed. The result is still
        # required to be `eval`-uable in the `stateful_eval` method.
        tree = ast.parse(expr, mode="exec") if isinstance(expr, str) else expr
        return astor.to_source(tree).strip().replace("\n ", "")

    tree = ast.parse(expr, mode="eval") if isinstance(expr, str) else expr
    return ast.unparse(tree).replace("\n", " ")


# Variable sanitization


# Matches, in order of preference: an escaped double quote, a complete
# double-quoted string, an escaped single quote, a complete single-quoted
# string, or a lone backtick. Because quoted strings are consumed as a single
# match, backticks inside them never appear as separate matches. The capture
# group means `re.split` keeps the matched text in its output.
UNQUOTED_BACKTICK_MATCHER = re.compile(
    r"(\\\"|\"(?:\\\"|[^\"])*\"|\\'|'(?:\\'|[^'])*'|`)"
)


def sanitize_variable_names(
    expr: str, env: MutableMapping, aliases: MutableMapping, *, template: str = "{}"
) -> str:
    """
    Replace backtick-quoted variable names in `expr` with names produced by
    `sanitize_variable_name`, allowing field names that are not valid Python
    identifiers to be referenced from Python expressions.

    Each quoted name is replaced by its sanitized name padded with a single
    space on either side, and the result is stripped of leading/trailing
    whitespace. A backtick with no closing partner is left in the output
    unchanged, along with the text following it.

    Args:
        expr: The expression to sanitize.
        env: The environment passed through to `sanitize_variable_name`, which
            may mutate it in place to register substituted names.
        aliases: A mapping that is updated in place with an entry for every
            quoted name found, keyed by the sanitized name and pointing back
            to the original name.
        template: A template to use for sanitized names, which is mainly useful
            if you need to undo the sanitization by string replacement.

    Returns:
        The sanitized expression.
    """
    pieces = peekable_iter(UNQUOTED_BACKTICK_MATCHER.split(expr))
    output = []

    for piece in pieces:
        # Everything other than an opening backtick is copied verbatim.
        if piece != "`":
            output.append(piece)
            continue

        # Gather everything up to the closing backtick (or the end of input).
        quoted = []
        while pieces.peek(None) not in ("`", None):
            quoted.append(next(pieces))
        original_name = "".join(quoted)

        # Unterminated quote: emit the text as it was written.
        if pieces.peek(None) is None:
            output.append(f"`{original_name}")
            continue

        next(pieces)  # Discard the closing backtick.
        safe_name = sanitize_variable_name(original_name, env, template=template)
        aliases[safe_name] = original_name
        output.append(f" {safe_name} ")

    return "".join(output).strip()


def sanitize_variable_name(
    name: str, env: MutableMapping, *, template: str = "{}"
) -> str:
    """
    Generate a valid Python variable name for variable identifier `name`.

    Names that are already identifiers (or keywords) are returned unchanged
    and `env` is not touched. Otherwise every non-word character is replaced
    with an underscore, and an underscore is prepended if the result is empty
    or starts with a digit.

    Args:
        name: The variable name to sanitize.
        env: The mapping of variable name to values in the evaluation
            environment. If `name` is present in this mapping, an alias is
            created for the same value for the new variable name.
        template: A template to use for sanitized names, which is mainly useful
            if you need to undo the sanitization by string replacement.

    Returns:
        The sanitized variable name.
    """
    if name.isidentifier() or keyword.iskeyword(name):
        return name

    # Compute a recognisable base name.
    base_name = re.sub(r"\W", "_", name)
    # Identifiers can be neither empty nor start with a digit. The emptiness
    # check also avoids an IndexError for an empty `name`.
    if not base_name or base_name[0].isdigit():
        base_name = "_" + base_name

    # If the new name is already taken in `env`, retry with random suffixes
    # until an unused name is found.
    new_name = template.format(base_name)
    while new_name in env:
        new_name = template.format(
            base_name
            + "_"
            + "".join(numpy.random.choice(list("abcefghiklmnopqrstuvwxyz"), 10))
        )

    # Reuse the value for `name` for `new_name` also.
    if name in env:
        env[new_name] = env[name]

    return new_name
103 changes: 3 additions & 100 deletions formulaic/utils/stateful_transforms.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
import ast
import functools
import inspect
import keyword
import re
from typing import (
Any,
Callable,
Expand All @@ -15,10 +13,7 @@
cast,
)

import astor
import numpy

from .iterators import peekable_iter
from .code import format_expr, sanitize_variable_names
from .layered_mapping import LayeredMapping
from .variables import get_expression_variables, Variable

Expand Down Expand Up @@ -151,9 +146,7 @@ def stateful_eval(
stateful_nodes: Dict[str, ast.Call] = {}
for node in ast.walk(code):
if _is_stateful_transform(node, env):
stateful_nodes[astor.to_source(node).strip().replace("\n ", "")] = cast(
ast.Call, node
)
stateful_nodes[format_expr(node)] = cast(ast.Call, node)

# Mutate stateful nodes to pass in state from a shared dictionary.
for name, node in stateful_nodes.items():
Expand Down Expand Up @@ -224,98 +217,8 @@ def _is_stateful_transform(node: ast.AST, env: Mapping) -> bool:

try:
func = eval(
compile(astor.to_source(node.func).strip(), "", "eval"), {}, env
compile(format_expr(node.func), "", "eval"), {}, env
) # nosec; Get function handle (assuming it exists in env)
return getattr(func, "__is_stateful_transform__", False)
except NameError:
return False


# Variable sanitization


# Matches, in order of preference: an escaped double quote, a complete
# double-quoted string, an escaped single quote, a complete single-quoted
# string, or a lone backtick. Because quoted strings are consumed as a single
# match, backticks inside them never appear as separate matches. The capture
# group means `re.split` keeps the matched text in its output.
UNQUOTED_BACKTICK_MATCHER = re.compile(
    r"(\\\"|\"(?:\\\"|[^\"])*\"|\\'|'(?:\\'|[^'])*'|`)"
)


def sanitize_variable_names(
    expr: str, env: MutableMapping, aliases: MutableMapping
) -> str:
    """
    Replace backtick-quoted variable names in `expr` with names produced by
    `sanitize_variable_name`, allowing field names that are not valid Python
    identifiers to be referenced from Python expressions.

    Each quoted name is replaced by its sanitized name padded with a single
    space on either side, and the result is stripped of leading/trailing
    whitespace. A backtick with no closing partner is left in the output
    unchanged, along with the text following it.

    Args:
        expr: The expression to sanitize.
        env: The environment passed through to `sanitize_variable_name`, which
            may mutate it in place to register substituted names.
        aliases: A mapping that is updated in place with an entry for every
            quoted name found, keyed by the sanitized name and pointing back
            to the original name.

    Returns:
        The sanitized expression.
    """
    pieces = peekable_iter(UNQUOTED_BACKTICK_MATCHER.split(expr))
    output = []

    for piece in pieces:
        # Everything other than an opening backtick is copied verbatim.
        if piece != "`":
            output.append(piece)
            continue

        # Gather everything up to the closing backtick (or the end of input).
        quoted = []
        while pieces.peek(None) not in ("`", None):
            quoted.append(next(pieces))
        original_name = "".join(quoted)

        # Unterminated quote: emit the text as it was written.
        if pieces.peek(None) is None:
            output.append(f"`{original_name}")
            continue

        next(pieces)  # Discard the closing backtick.
        safe_name = sanitize_variable_name(original_name, env)
        aliases[safe_name] = original_name
        output.append(f" {safe_name} ")

    return "".join(output).strip()


def sanitize_variable_name(name: str, env: MutableMapping) -> str:
    """
    Generate a valid Python variable name for variable identifier `name`.

    Names that are already identifiers (or keywords) are returned unchanged
    and `env` is not touched. Otherwise every non-word character is replaced
    with an underscore, and an underscore is prepended if the result is empty
    or starts with a digit.

    Args:
        name: The variable name to sanitize.
        env: The mapping of variable name to values in the evaluation
            environment. If `name` is present in this mapping, an alias is
            created for the same value for the new variable name.

    Returns:
        The sanitized variable name.
    """
    if name.isidentifier() or keyword.iskeyword(name):
        return name

    # Compute a recognisable base name.
    base_name = re.sub(r"\W", "_", name)
    # Identifiers can be neither empty nor start with a digit. The emptiness
    # check also avoids an IndexError for an empty `name`.
    if not base_name or base_name[0].isdigit():
        base_name = "_" + base_name

    # If the new name is already taken in `env`, retry with random suffixes
    # until an unused name is found.
    new_name = base_name
    while new_name in env:
        new_name = (
            base_name
            + "_"
            + "".join(numpy.random.choice(list("abcefghiklmnopqrstuvwxyz"), 10))
        )

    # Reuse the value for `name` for `new_name` also.
    if name in env:
        env[new_name] = env[name]

    return new_name
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ classifiers = [
]
requires-python = ">=3.7.2"
dependencies = [
"astor>=0.8",
"astor>=0.8; python_version < \"3.9\"",
"cached-property>=1.3.0; python_version < \"3.8\"",
"graphlib-backport>=1.0.0; python_version < \"3.9\"",
"interface-meta>=1.2.0",
Expand Down Expand Up @@ -92,7 +92,7 @@ dependencies = [
"formulaic[arrow,calculus]",
"pytest==7.2.0",
"pytest-cov==4.0.0",
"astor==0.8",
"astor==0.8; python_version < \"3.9\"",
"cached-property==1.3.0; python_version < \"3.8\"",
"graphlib-backport==1.0.0; python_version < \"3.9\"",
"interface-meta==1.2.0",
Expand Down
Loading