From b3575c8449d4b8e5b20ff5201abb000e1398c272 Mon Sep 17 00:00:00 2001 From: Matthew Wardrop Date: Mon, 26 Sep 2022 20:59:26 -0700 Subject: [PATCH] Keep track of original term for reverse lookup during model evaluation. --- formulaic/parser/parser.py | 2 +- formulaic/parser/types/term.py | 9 +++++++-- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/formulaic/parser/parser.py b/formulaic/parser/parser.py index 8b97a25..6c1b1e1 100644 --- a/formulaic/parser/parser.py +++ b/formulaic/parser/parser.py @@ -166,7 +166,7 @@ def power(arg: Set[Term], power: Set[Term]) -> Set[Term]: def sub_formula(lhs, rhs): def get_terms(terms): return [ - Term(factors=[Factor(str(t) + "_hat", eval_method="lookup")]) + Term(factors=[Factor(str(t) + "_hat", eval_method="lookup")], origin=t) for t in terms ] diff --git a/formulaic/parser/types/term.py b/formulaic/parser/types/term.py index 8a45853..5dc31f1 100644 --- a/formulaic/parser/types/term.py +++ b/formulaic/parser/types/term.py @@ -1,4 +1,6 @@ -from typing import Iterable, TYPE_CHECKING +from __future__ import annotations + +from typing import Iterable, Optional, TYPE_CHECKING if TYPE_CHECKING: from .factor import Factor # pragma: no cover @@ -13,10 +15,13 @@ class Term: Attributes: factors: The set of factors to be multipled to form the term. + origin: If this `Term` has been derived from another `Term`, for example + in subformulae, a reference to the original term. """ - def __init__(self, factors: Iterable["Factor"]): + def __init__(self, factors: Iterable["Factor"], origin: Optional[Term] = None): self.factors = tuple(sorted(set(factors))) + self.origin = origin self._factor_exprs = tuple(factor.expr for factor in self.factors) self._hash = hash(repr(self))