diff --git a/formulaic/parser/parser.py b/formulaic/parser/parser.py
index e030594..2871926 100644
--- a/formulaic/parser/parser.py
+++ b/formulaic/parser/parser.py
@@ -5,6 +5,8 @@
 from dataclasses import dataclass, field
 from typing import List, Iterable, Sequence, Tuple, Union, cast
 
+from formulaic.parser.types.factor import Factor
+
 from .algos.tokenize import tokenize
 from .types import (
     FormulaParser,
@@ -114,6 +116,65 @@ def get_tokens(self, formula: str) -> Iterable[Token]:
 
         return tokens
 
+    def get_terms(self, formula: str) -> Structured[List[Term]]:
+        """
+        Assemble the `Term` instances for a formula string. Depending on the
+        operators involved, this may be an iterable of `Term` instances, or
+        an iterable of iterables of `Term`s, etc.
+
+        This implementation also verifies that the formula is well-formed, in
+        that it does not have any literals apart from 1 or numeric scaling of
+        other terms.
+
+        Args:
+            formula: The formula string for which terms should be assembled.
+
+        Returns:
+            The terms assembled by the parent implementation's `get_terms`.
+
+        Raises:
+            FormulaSyntaxError: (as expected by the accompanying tests; built
+                via `exc_for_token`) if a term is a bare literal other than
+                `1`, or a multi-factor term contains a non-numeric literal.
+        """
+        terms = super().get_terms(formula)
+
+        def check_terms(terms: Iterable[Term]) -> None:
+            for term in terms:
+                if len(term.factors) == 1:
+                    # A term made of a single literal is only accepted when
+                    # that literal is exactly `1`.
+                    factor = term.factors[0]
+                    if (
+                        factor.eval_method is Factor.EvalMethod.LITERAL
+                        and factor.expr != "1"
+                    ):
+                        raise exc_for_token(
+                            factor.token or Token(),
+                            "Numeric literals other than `1` can only be used "
+                            "to scale other terms. (tip: Use `:` rather than "
+                            "`*` when scaling terms)"
+                            if factor.expr.replace(".", "", 1).isnumeric()
+                            else "String literals are not valid in formulae.",
+                        )
+                else:
+                    # Within a multi-factor term numeric literals are allowed
+                    # (they scale the other factors); other literals are not.
+                    for factor in term.factors:
+                        if (
+                            factor.eval_method is Factor.EvalMethod.LITERAL
+                            and not factor.expr.replace(".", "", 1).isnumeric()
+                        ):
+                            raise exc_for_token(
+                                factor.token or Token(),
+                                "String literals are not valid in formulae.",
+                            )
+
+        # Run the check over the structured result via its `_map` helper.
+        terms._map(check_terms)
+
+        return terms
+
 
 class DefaultOperatorResolver(OperatorResolver):
     """
diff --git a/tests/parser/test_parser.py b/tests/parser/test_parser.py
index 315bd2f..44e1ad5 100644
--- a/tests/parser/test_parser.py
+++ b/tests/parser/test_parser.py
@@ -73,6 +73,8 @@
     "(a+b)**2": ["1", "a", "a:b", "b"],
     "(a+b)^2": ["1", "a", "a:b", "b"],
     "(a+b)**3": ["1", "a", "a:b", "b"],
+    "50:a": ["1", "50:a"],
+    "1 * a": ["1", "a", "1:a"],
     # Nested products
     "a/b": ["1", "a", "a:b"],
     "(b+a)/c": ["1", "b", "a", "b:a:c"],
@@ -162,6 +164,32 @@ def test_long_formula(self):
         terms = PARSER_NO_INTERCEPT.get_terms(expr)
         assert {str(term) for term in terms} == names
 
+    def test_invalid_literals(self):
+        with pytest.raises(
+            FormulaSyntaxError,
+            match=re.escape(
+                "Numeric literals other than `1` can only be used to scale other terms."
+            ),
+        ):
+            PARSER.get_terms("50")
+        with pytest.raises(
+            FormulaSyntaxError,
+            match=re.escape("String literals are not valid in formulae."),
+        ):
+            PARSER.get_terms("'asd'")
+        with pytest.raises(
+            FormulaSyntaxError,
+            match=re.escape(
+                "Numeric literals other than `1` can only be used to scale other terms."
+            ),
+        ):
+            PARSER.get_terms("50*a")
+        with pytest.raises(
+            FormulaSyntaxError,
+            match=re.escape("String literals are not valid in formulae."),
+        ):
+            PARSER.get_terms("'asd':a")
+
 
 class TestDefaultOperatorResolver:
     @pytest.fixture