Skip to content

Commit

Permalink
Fix quoting of variables used in Python calls.
Browse files Browse the repository at this point in the history
  • Loading branch information
matthewwardrop committed Sep 20, 2023
1 parent b3d2d92 commit 13e3414
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 10 deletions.
4 changes: 3 additions & 1 deletion formulaic/materializers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -596,15 +596,17 @@ def _lookup(self, name: str) -> Tuple[Any, Set[Variable]]:
def _evaluate(
    self, expr: str, metadata: Any, spec: ModelSpec
) -> Tuple[Any, Set[Variable]]:
    """
    Evaluate a Python expression and report the variables it used.

    Args:
        expr: The Python expression to evaluate.
        metadata: Metadata associated with `expr`; passed to `stateful_eval`
            keyed by the expression string.
        spec: The model specification being materialized; supplies the
            transform state and is itself passed to `stateful_eval`.

    Returns:
        A 2-tuple of the value returned by `stateful_eval` and the set of
        variables it recorded while evaluating `expr`.
    """
    # `stateful_eval` populates this set in place *after* it has sanitized
    # any backtick-quoted names, mapping them back to their original
    # spelling. Computing the variables from the raw `expr` here instead
    # would require parsing an expression that may not be valid Python.
    variables: Set[Variable] = set()
    result = stateful_eval(
        expr,
        self.layered_context,
        {expr: metadata},
        spec.transform_state,
        spec,
        variables=variables,
    )
    return result, variables

def _is_categorical(self, values: Any) -> bool:
Expand Down
18 changes: 16 additions & 2 deletions formulaic/utils/stateful_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
Mapping,
MutableMapping,
Optional,
Set,
TYPE_CHECKING,
cast,
)
Expand All @@ -19,6 +20,7 @@

from .iterators import peekable_iter
from .layered_mapping import LayeredMapping
from .variables import get_expression_variables, Variable

if TYPE_CHECKING:
from formulaic.model_spec import ModelSpec # pragma: no cover
Expand Down Expand Up @@ -97,6 +99,7 @@ def stateful_eval(
metadata: Optional[Mapping],
state: Optional[MutableMapping],
spec: Optional["ModelSpec"],
variables: Optional[Set[Variable]] = None,
) -> Any:
"""
Evaluate an expression in a nominated environment and with a nominated state.
Expand All @@ -116,6 +119,8 @@ def stateful_eval(
stateful transforms).
spec: The current `ModelSpec` instance being evaluated (passed through
to stateful transforms).
variables: An optional set of variables to update with the variables
used in this stateful evaluation.
Returns:
The result of the evaluation.
Expand All @@ -133,11 +138,15 @@ def stateful_eval(

# Ensure that variable names in code are valid for Python's interpreter
# If not, create new variable in mutable env layer, and update code.
expr = sanitize_variable_names(expr, env)
aliases = {}
expr = sanitize_variable_names(expr, env, aliases)

# Parse Python code
code = ast.parse(expr, mode="eval")

if variables is not None:
variables.update(get_expression_variables(code, env, aliases))

# Extract the nodes of the graph that correspond to stateful transforms
stateful_nodes: Dict[str, ast.Call] = {}
for node in ast.walk(code):
Expand Down Expand Up @@ -230,7 +239,9 @@ def _is_stateful_transform(node: ast.AST, env: Mapping) -> bool:
)


def sanitize_variable_names(expr: str, env: MutableMapping) -> str:
def sanitize_variable_names(
expr: str, env: MutableMapping, aliases: MutableMapping
) -> str:
"""
Sanitize any variable names in the expression that are not valid Python
identifiers and are surrounded by backticks (`). This allows use of field
Expand All @@ -246,6 +257,8 @@ def sanitize_variable_names(expr: str, env: MutableMapping) -> str:
expr: The expression to sanitize.
env: The environment to keep updated with any name substitutions. This
environment mapping will be mutated in place during this evaluation.
aliases: A dictionary/map to update with any variable mappings. Mapping
is from the sanitized variable back to the original variable.
Returns:
The sanitized expression.
Expand All @@ -266,6 +279,7 @@ def sanitize_variable_names(expr: str, env: MutableMapping) -> str:
else:
next(expr_parts)
new_name = sanitize_variable_name(variable_name, env)
aliases[new_name] = variable_name
sanitized_expr.append(f" {new_name} ")
else:
sanitized_expr.append(expr_part)
Expand Down
27 changes: 20 additions & 7 deletions formulaic/utils/variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,11 +46,21 @@ def union(cls, *variable_sets: Set[Variable]) -> Set[Variable]:


def get_expression_variables(
expr: Union[str, ast.AST], context: Mapping
expr: Union[str, ast.AST], context: Mapping, aliases: Optional[Mapping] = None
) -> Set[Variable]:
"""
Extract the variables that are used in the nominated Python expression.
Args:
expr: The string or AST representing the python expression.
context: The context from which variable values will be looked up.
aliases: A mapping from variable name in the expression to the alias to
assign to the variable (primarily useful when reverting a variable
renaming performed during sanitization).
"""
if isinstance(expr, str):
expr = ast.parse(expr, mode="eval")
variables = _get_ast_node_variables(expr)
variables = _get_ast_node_variables(expr, aliases or {})

if isinstance(context, LayeredMapping):
out = set()
Expand All @@ -61,20 +71,23 @@ def get_expression_variables(
return set(variables)


def _get_ast_node_variables(node: ast.AST, aliases: Mapping) -> List[Variable]:
    """
    Collect the variables referenced beneath an AST node.

    Nodes are processed from a FIFO queue seeded with `node`. Every
    `ast.Call` is recorded with the "callable" role, and every
    `ast.Attribute` or `ast.Name` with the "value" role.

    Args:
        node: The root AST node to inspect.
        aliases: Mapping from a name as it appears in the AST to the name
            that should be reported instead. Names without an entry are
            reported unchanged.

    Returns:
        The variables found, in the order they were encountered (duplicates
        are not removed).
    """
    variables: List[Variable] = []

    todo = deque([node])
    while todo:
        node = todo.popleft()

        # Structural nodes (operators, subscripts, etc.) carry no name of
        # their own; just descend into their children.
        if not isinstance(node, (ast.Call, ast.Attribute, ast.Name)):
            todo.extend(ast.iter_child_nodes(node))
            continue

        # Resolve the name once for every named node so that aliases are
        # applied uniformly to both callables and values.
        name = _get_ast_node_name(node)
        name = aliases.get(name, name)

        if isinstance(node, ast.Call):
            variables.append(Variable(name, roles=["callable"]))
            # Only the arguments are queued; the callee expression itself
            # is already represented by the "callable" entry above.
            todo.extend(node.args)
            todo.extend(node.keywords)
        elif isinstance(node, (ast.Attribute, ast.Name)):
            variables.append(Variable(name, roles=["value"]))

    return variables

Expand Down
9 changes: 9 additions & 0 deletions tests/materializers/test_pandas.py
Original file line number Diff line number Diff line change
Expand Up @@ -436,3 +436,12 @@ def test_none_values(self, data):
assert mm.model_spec.structure == [
EncodedTermStructure(term="None", scoped_terms=[], columns=[]),
]

def test_quoted_python_args(self):
    """
    Check that a backtick-quoted column name can be passed as an argument
    to a Python call inside a formula.
    """
    # The column name is deliberately not a valid Python identifier.
    frame = pandas.DataFrame({"exotic!~ -name": [1, 2, 3]})
    materializer = PandasMaterializer(frame, output="pandas")
    mm = materializer.get_model_matrix("np.power(`exotic!~ -name`, 2)")

    # Expect a column of ones followed by the squared input values.
    expected = numpy.array([[1, 1], [1, 4], [1, 9]])

    assert mm.shape == (3, 2)
    assert len(mm.model_spec.structure) == 2
    assert numpy.all(mm.values == expected)

0 comments on commit 13e3414

Please sign in to comment.