diff --git a/formulaic/parser/parser.py b/formulaic/parser/parser.py index 3100c21..c0d12a6 100644 --- a/formulaic/parser/parser.py +++ b/formulaic/parser/parser.py @@ -228,6 +228,23 @@ def power(arg: OrderedSet[Term], power: OrderedSet[Term]) -> OrderedSet[Term]: for term in itertools.product(*[arg] * int(power_term.factors[0].expr)) ) + def sub_formula(lhs, rhs): + def get_terms(terms): + return [ + Term( + factors=[Factor(str(t) + "_hat", eval_method="lookup")], + origin=t, + ) + for t in terms + ] + + if isinstance(lhs, Structured): + lhs_hat = lhs._update(root=get_terms(lhs.root)) + else: + lhs_hat = get_terms(lhs) + + return Structured(lhs_hat, deps=(Structured(lhs=lhs, rhs=rhs),)) + return [ Operator( "~", @@ -248,6 +265,15 @@ def power(arg: OrderedSet[Term], power: OrderedSet[Term]) -> OrderedSet[Term]: accepts_context=lambda context: len(context) == 0, structural=True, ), + Operator( + "~", + arity=2, + precedence=-100, + associativity=None, + to_terms=sub_formula, + accepts_context=lambda context: context and context[-1] == "[", + structural=True, + ), Operator( "|", arity=2, diff --git a/formulaic/parser/types/term.py b/formulaic/parser/types/term.py index 1cb575d..56a388f 100644 --- a/formulaic/parser/types/term.py +++ b/formulaic/parser/types/term.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Any, Iterable, TYPE_CHECKING +from typing import Any, Iterable, Optional, TYPE_CHECKING if TYPE_CHECKING: from .factor import Factor # pragma: no cover @@ -15,10 +15,13 @@ class Term: Attributes: factors: The set of factors to be multiplied to form the term. + origin: If this `Term` has been derived from another `Term`, for example + in subformulae, a reference to the original term. """ - def __init__(self, factors: Iterable["Factor"]): + def __init__(self, factors: Iterable["Factor"], origin: Optional[Term] = None): self.factors = tuple(dict.fromkeys(factors)) + self.origin = origin self._factor_key = tuple(factor.expr for factor in sorted(self.factors)) self._hash = hash(":".join(self._factor_key))