diff --git a/formulaic/materializers/base.py b/formulaic/materializers/base.py index deab32b..e295544 100644 --- a/formulaic/materializers/base.py +++ b/formulaic/materializers/base.py @@ -596,6 +596,7 @@ def _lookup(self, name: str) -> Tuple[Any, Set[Variable]]: def _evaluate( self, expr: str, metadata: Any, spec: ModelSpec ) -> Tuple[Any, Set[Variable]]: + variables = set() return ( stateful_eval( expr, @@ -603,8 +604,9 @@ def _evaluate( {expr: metadata}, spec.transform_state, spec, + variables=variables, ), - get_expression_variables(expr, self.layered_context), + variables, ) def _is_categorical(self, values: Any) -> bool: diff --git a/formulaic/utils/stateful_transforms.py b/formulaic/utils/stateful_transforms.py index b0bbb0d..2958a59 100644 --- a/formulaic/utils/stateful_transforms.py +++ b/formulaic/utils/stateful_transforms.py @@ -10,6 +10,7 @@ Mapping, MutableMapping, Optional, + Set, TYPE_CHECKING, cast, ) @@ -19,6 +20,7 @@ from .iterators import peekable_iter from .layered_mapping import LayeredMapping +from .variables import get_expression_variables, Variable if TYPE_CHECKING: from formulaic.model_spec import ModelSpec # pragma: no cover @@ -97,6 +99,7 @@ def stateful_eval( metadata: Optional[Mapping], state: Optional[MutableMapping], spec: Optional["ModelSpec"], + variables: Optional[Set[Variable]] = None, ) -> Any: """ Evaluate an expression in a nominated environment and with a nominated state. @@ -116,6 +119,8 @@ def stateful_eval( stateful transforms). spec: The current `ModelSpec` instance being evaluated (passed through to stateful transforms). + variables: An (optional) set of variables to update with the variables + used in this stateful evaluation. Returns: The result of the evaluation. @@ -133,11 +138,15 @@ def stateful_eval( # Ensure that variable names in code are valid for Python's interpreter # If not, create new variable in mutable env layer, and update code. 
- expr = sanitize_variable_names(expr, env) + aliases = {} + expr = sanitize_variable_names(expr, env, aliases) # Parse Python code code = ast.parse(expr, mode="eval") + if variables is not None: + variables.update(get_expression_variables(code, env, aliases)) + # Extract the nodes of the graph that correspond to stateful transforms stateful_nodes: Dict[str, ast.Call] = {} for node in ast.walk(code): @@ -230,7 +239,9 @@ def _is_stateful_transform(node: ast.AST, env: Mapping) -> bool: ) -def sanitize_variable_names(expr: str, env: MutableMapping) -> str: +def sanitize_variable_names( + expr: str, env: MutableMapping, aliases: MutableMapping +) -> str: """ Sanitize any variables names in the expression that are not valid Python identifiers and are surrounded by backticks (`). This allows use of field @@ -246,6 +257,8 @@ def sanitize_variable_names(expr: str, env: MutableMapping) -> str: expr: The expression to sanitize. env: The environment to keep updated with any name substitutions. This environment mapping will be mutated in place during this evaluation. + aliases: A dictionary/map to update with any variable mappings. Mapping + is from the sanitized variable back to the original variable. Returns: The sanitized expression. 
@@ -266,6 +279,7 @@ def sanitize_variable_names(expr: str, env: MutableMapping) -> str: else: next(expr_parts) new_name = sanitize_variable_name(variable_name, env) + aliases[new_name] = variable_name sanitized_expr.append(f" {new_name} ") else: sanitized_expr.append(expr_part) diff --git a/formulaic/utils/variables.py b/formulaic/utils/variables.py index 14f5cc1..f1f8dac 100644 --- a/formulaic/utils/variables.py +++ b/formulaic/utils/variables.py @@ -46,11 +46,21 @@ def union(cls, *variable_sets: Set[Variable]) -> Set[Variable]: def get_expression_variables( - expr: Union[str, ast.AST], context: Mapping + expr: Union[str, ast.AST], context: Mapping, aliases: Optional[Mapping] = None ) -> Set[Variable]: + """ + Extract the variables that are used in the nominated Python expression. + + Args: + expr: The string or AST representing the python expression. + context: The context from which variable values will be looked up. + aliases: A mapping from variable name in the expression to the alias to + assign to the variable (primarily useful when reverting a variable + renaming performed during sanitization). 
+ """ if isinstance(expr, str): expr = ast.parse(expr, mode="eval") - variables = _get_ast_node_variables(expr) + variables = _get_ast_node_variables(expr, aliases or {}) if isinstance(context, LayeredMapping): out = set() @@ -61,20 +71,23 @@ def get_expression_variables( return set(variables) -def _get_ast_node_variables(node: ast.AST) -> List[Variable]: +def _get_ast_node_variables(node: ast.AST, aliases: Mapping) -> List[Variable]: variables: List[Variable] = [] todo = deque([node]) while todo: node = todo.popleft() + if not isinstance(node, (ast.Call, ast.Attribute, ast.Name)): + todo.extend(ast.iter_child_nodes(node)) + continue + name = _get_ast_node_name(node) + name = aliases.get(name, name) if isinstance(node, ast.Call): - variables.append(Variable(_get_ast_node_name(node), roles=["callable"])) + variables.append(Variable(name, roles=["callable"])) todo.extend(node.args) todo.extend(node.keywords) - elif isinstance(node, (ast.Attribute, ast.Name)): - variables.append(Variable(_get_ast_node_name(node), roles=["value"])) else: - todo.extend(ast.iter_child_nodes(node)) + variables.append(Variable(name, roles=["value"])) return variables diff --git a/tests/materializers/test_pandas.py b/tests/materializers/test_pandas.py index 0ffce03..66c2c8b 100644 --- a/tests/materializers/test_pandas.py +++ b/tests/materializers/test_pandas.py @@ -436,3 +436,12 @@ def test_none_values(self, data): assert mm.model_spec.structure == [ EncodedTermStructure(term="None", scoped_terms=[], columns=[]), ] + + def test_quoted_python_args(self): + data = pandas.DataFrame({"exotic!~ -name": [1, 2, 3]}) + mm = PandasMaterializer(data, output="pandas").get_model_matrix( + "np.power(`exotic!~ -name`, 2)" + ) + assert mm.shape == (3, 2) + assert len(mm.model_spec.structure) == 2 + assert numpy.all(mm.values == numpy.array([[1, 1], [1, 4], [1, 9]]))