Skip to content

Commit

Permalink
Do not count literal factors toward degree (fixing sorting of e.g. global intercept).
Browse files Browse the repository at this point in the history
  • Loading branch information
matthewwardrop committed Sep 25, 2023
1 parent 0f4ea0b commit 5ed5066
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 5 deletions.
2 changes: 1 addition & 1 deletion formulaic/formula.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ def _prepare_item(self, key: str, item: FormulaSpec) -> Union[List[Term], Formul
# Order terms appropriately
orderer = None
if self._ordering is OrderingMethod.DEGREE:
orderer = lambda terms: sorted(terms, key=lambda term: len(term.factors))
orderer = lambda terms: sorted(terms, key=lambda term: term.degree)
elif self._ordering is OrderingMethod.SORT:
orderer = lambda terms: sorted(
[Term(factors=sorted(term.factors)) for term in terms]
Expand Down
15 changes: 12 additions & 3 deletions formulaic/parser/types/term.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,23 @@ class Term:
a formula is made up of a sum of terms.
Attributes:
factors: The set of factors to be multipled to form the term.
factors: The set of factors to be multiplied to form the term.
"""

def __init__(self, factors: Iterable["Factor"]):
    """
    Store the de-duplicated factors and precompute the sort key and hash.

    Args:
        factors: The factors making up this term. Repeated factors are
            collapsed to their first occurrence; the order of first
            appearance is kept.
    """
    # An insertion-ordered dict gives an order-preserving de-duplication.
    unique_factors = {}
    for factor in factors:
        unique_factors.setdefault(factor, None)
    self.factors = tuple(unique_factors)

    # The key is built from the factors in sorted order, so it does not
    # depend on the order in which the factors were supplied.
    ordered_exprs = [factor.expr for factor in sorted(self.factors)]
    self._factor_key = tuple(ordered_exprs)
    self._hash = hash(":".join(self._factor_key))

@property
def degree(self) -> int:
    """
    The degree of the `Term`. Literal factors do not contribute to the degree.

    Returns:
        The number of factors in `self.factors` whose `eval_method` differs
        from that method's `LITERAL` attribute.
    """
    # Count lazily instead of materialising a throwaway tuple whose only
    # purpose is to be measured with `len`.
    return sum(
        1
        for factor in self.factors
        if factor.eval_method != factor.eval_method.LITERAL
    )

# Transforms and comparisons

def __mul__(self, other: Any) -> Term:
Expand All @@ -41,9 +50,9 @@ def __eq__(self, other: Any) -> bool:

def __lt__(self, other: Any) -> bool:
if isinstance(other, Term):
if len(self.factors) == len(other.factors):
if self.degree == other.degree:
return sorted(self.factors) < sorted(other.factors)
if len(self.factors) < len(other.factors):
if self.degree < other.degree:
return True
return False
return NotImplemented
Expand Down
6 changes: 6 additions & 0 deletions tests/parser/types/test_term.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,3 +42,9 @@ def test_sort(self, term1, term2, term3):

def test_repr(self, term1):
    """Check that `repr(term1)` is exactly the string "c:b"."""
    rendered = repr(term1)
    assert rendered == "c:b"

def test_degree(self, term1, term3):
    """Check `Term.degree`, including that a literal factor adds nothing."""
    # Terms supplied by fixtures.
    assert term1.degree == 2
    assert term3.degree == 3

    # A term made of a single literal factor has degree zero.
    literal_only = Term([Factor("1", eval_method="literal")])
    assert literal_only.degree == 0

    # Adding `Factor("x")` next to the literal raises the degree to one.
    literal_and_x = Term([Factor("1", eval_method="literal"), Factor("x")])
    assert literal_and_x.degree == 1
2 changes: 1 addition & 1 deletion tests/test_formula.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,10 @@ def data(self):
def test_constructor(self):
assert [str(t) for t in Formula(["a", "b", "c"])] == ["a", "b", "c"]
assert [str(t) for t in Formula(["a", "c", "b", "1"])] == [
"1",
"a",
"c",
"b",
"1",
]

f = Formula((["a", "b"], ["c", "d"]))
Expand Down

0 comments on commit 5ed5066

Please sign in to comment.