From db84b1db4e4271772fae7758e7aee6f3b63fe73d Mon Sep 17 00:00:00 2001 From: Michael Staneker Date: Mon, 8 Apr 2024 07:41:54 +0000 Subject: [PATCH] Continued 2: Loki string parser based on pymbolic parser --- loki/expression/parser.py | 65 +++++++-- tests/test_expression.py | 299 +++++++++++++++++++++++++++++++------- 2 files changed, 300 insertions(+), 64 deletions(-) diff --git a/loki/expression/parser.py b/loki/expression/parser.py index 218b8d778..bcdf39d56 100644 --- a/loki/expression/parser.py +++ b/loki/expression/parser.py @@ -12,6 +12,7 @@ from pymbolic.mapper import Mapper import pymbolic.primitives as pmbl +from loki.tools.util import CaseInsensitiveDict from loki.expression import symbols as sym __all__ = ['LokiParser', 'loki_parse'] @@ -27,23 +28,24 @@ def __init__(self, scope=None): def map_product(self, expr, *args, **kwargs): return sym.Product(tuple(self.rec(child, *args, **kwargs) for child in expr.children)) - map_Mul = map_product + # map_Mul = map_product def map_sum(self, expr, *args, **kwargs): return sym.Sum(tuple(self.rec(child, *args, **kwargs) for child in expr.children)) - map_Add = map_sum + # map_Add = map_sum def map_power(self, expr, *args, **kwargs): - return sym.Power(base=self.rec(expr.base), - exponent=self.rec(expr.exponent)) + return sym.Power(base=self.rec(expr.base, *args, **kwargs), + exponent=self.rec(expr.exponent, *args, **kwargs)) def map_quotient(self, expr, *args, **kwargs): - return sym.Quotient(numerator=expr.numerator, denominator=expr.denominator) + return sym.Quotient(numerator=self.rec(expr.numerator, *args, **kwargs), + denominator=self.rec(expr.denominator, *args, **kwargs)) def map_comparison(self, expr, *args, **kwargs): - return sym.Comparison(left=self.rec(expr.left), + return sym.Comparison(left=self.rec(expr.left, *args, **kwargs), operator=expr.operator, - right=self.rec(expr.right)) + right=self.rec(expr.right, *args, **kwargs)) def map_logical_and(self, expr, *args, **kwargs): return sym.LogicalAnd(tuple(self.rec(child, *args, **kwargs) for child in expr.children)) @@ -52,9 +54,11 @@ def map_logical_or(self, expr, *args, **kwargs): return sym.LogicalOr(tuple(self.rec(child, *args, **kwargs) for child in expr.children)) def map_logical_not(self, expr, *args, **kwargs): - return sym.LogicalNot(self.rec(expr.child)) + return sym.LogicalNot(self.rec(expr.child, *args, **kwargs)) def map_constant(self, expr, *args, **kwargs): + if isinstance(expr, (sym.FloatLiteral, sym.IntLiteral)): + return expr return sym.Literal(expr) map_logic_literal = map_constant map_string_literal = map_constant @@ -67,7 +71,7 @@ def map_constant(self, expr, *args, **kwargs): def map_meta_symbol(self, expr, *args, **kwargs): return sym.Variable(name=str(expr.name), scope=self.scope) - map_Symbol = map_meta_symbol + # map_Symbol = map_meta_symbol map_scalar = map_meta_symbol map_array = map_meta_symbol @@ -88,17 +92,29 @@ def map_algebraic_leaf(self, expr, *args, **kwargs): if isinstance(expr, pmbl.Call): if self.scope is not None: if expr.function.name in self.scope.symbol_attrs: - return sym.Variable(name=expr.function.name, scope=self.scope, dimensions=self.rec(expr.parameters)) - return expr + return sym.Variable(name=expr.function.name, scope=self.scope, + dimensions=tuple(self.rec(param, *args, **kwargs) for param in expr.parameters)) + return sym.InlineCall(function=sym.Variable(name=expr.function.name), + parameters=tuple(self.rec(param, *args, **kwargs) for param in expr.parameters)) # else: try: - return self.map_variable(expr) + return self.map_variable(expr, *args, **kwargs) except Exception as e: print(f"Exception: {e}") return expr + def map_call_with_kwargs(self, expr, *args, **kwargs): + name = sym.Variable(name=expr.function.name) + parameters = tuple(self.rec(param, *args, **kwargs) for param in expr.parameters) + kw_parameters = {key: self.rec(value, *args, **kwargs) for key, value\ + in CaseInsensitiveDict(expr.kw_parameters).items()} + if expr.function.name.lower() in ('real', 'int'): + return sym.Cast(name, parameters, kind=kw_parameters['kind']) + + return sym.InlineCall(function=name, parameters=parameters, kw_parameters=kw_parameters) + def map_tuple(self, expr, *args, **kwargs): - return tuple(self.rec(elem) for elem in expr) + return tuple(self.rec(elem, *args, **kwargs) for elem in expr) class LokiParser(ParserBase): @@ -112,6 +128,8 @@ class LokiParser(ParserBase): _f_and = intern("and") _f_or = intern("or") _f_not = intern("not") + _f_float = intern("f_float") + _f_int = intern("f_int") lex_table = [ (_f_lessequal, pytools.lex.RE(r"\.le\.", re.IGNORECASE)), @@ -123,6 +141,8 @@ class LokiParser(ParserBase): (_f_and, pytools.lex.RE(r"\.and\.", re.IGNORECASE)), (_f_or, pytools.lex.RE(r"\.or\.", re.IGNORECASE)), (_f_not, pytools.lex.RE(r"\.not\.", re.IGNORECASE)), + (_f_float, ("|", pytools.lex.RE(r"[0-9]+\.[0-9]*([eEdD][+-]?[0-9]+)?(_[a-zA-Z]*)", re.IGNORECASE))), + (_f_int, pytools.lex.RE(r"[0-9]+?(_[a-zA-Z]*)", re.IGNORECASE)), ] + ParserBase.lex_table ParserBase._COMP_TABLE.update({ @@ -134,8 +154,27 @@ class LokiParser(ParserBase): _f_notequal: "!=" }) + def parse_terminal(self, pstate): + # next_tag = pstate.next_tag() + if pstate.is_next(self._f_float): # next_tag is self._f_float: + return self.parse_f_float(pstate.next_str_and_advance()) + if pstate.is_next(self._f_int): + return self.parse_f_int(pstate.next_str_and_advance()) + return super().parse_terminal(pstate) + def __call__(self, expr_str, scope=None, min_precedence=0): result = super().__call__(expr_str, min_precedence) return PymbolicMapper(scope=scope)(result) + def parse_f_float(self, s): + stripped = s.split('_', 1) + if len(stripped) == 2: + return sym.Literal(value=self.parse_float(stripped[0]), kind=stripped[1].lower()) + return self.parse_float(stripped[0]) + + def parse_f_int(self, s): + stripped = s.split('_', 1) + value = int(stripped[0].replace("d", "e").replace("D", "e")) + return sym.IntLiteral(value=value, kind=stripped[1].lower()) + loki_parse = LokiParser() diff --git a/tests/test_expression.py b/tests/test_expression.py index d292280cf..31d5c7a60 100644 --- a/tests/test_expression.py +++ b/tests/test_expression.py @@ -1552,72 +1552,269 @@ def test_expression_c_de_reference(frontend): assert '(&renamed_var_reference)=1' in c_str assert '(*renamed_var_dereference)=2' in c_str +@pytest.mark.parametrize('case', ('upper', 'lower', 'random')) @pytest.mark.parametrize('frontend', available_frontends()) -def test_parser(frontend): +def test_parser(frontend, case): fcode = """ subroutine some_routine() implicit none +! INTEGER, PARAMETER :: JPIM = SELECTED_INT_KIND(9) +! integer, parameter :: jprb = selected_real_kind(13,300) integer :: i1, i2, i3, len1, len2, len3 real :: a, b real :: arr(len1, len2, len3) end subroutine some_routine """.strip() + def convert_to_case(_str, mode='upper'): + if mode == 'upper': + # print(f"{_str.upper()}") + return _str.upper() + if mode == 'lower': + # print(f"{_str.lower()}") + return _str.lower() + if mode == 'random': + # this is obviously not random, but fulfils its purpose ... + result = '' + for i, char in enumerate(_str): + result += char.upper() if i%2==0 and i<3 else char.lower() + # print(f"{result}") + return result + return convert_to_case(_str) + + routine = Subroutine.from_source(fcode, frontend=frontend) - print("") - parsed = loki_parse('a + b') - print(f"{parsed} | type: {type(parsed)} | children type(s): {[type(child) for child in parsed.children]}") - parsed = loki_parse('a + b', scope=routine) - print(f"{parsed} | type: {type(parsed)} | children type(s): {[type(child) for child in parsed.children]}") - parsed = loki_parse('a + b + 2 + 10', scope=routine) - print(f"{parsed} | {type(parsed)} | {[type(child) for child in parsed.children]}") - parsed = loki_parse('a - b', scope=routine) - print(f"{parsed} | {type(parsed)}") - parsed = loki_parse('a * b', scope=routine) - print(f"{parsed} | {type(parsed)}") - parsed = loki_parse('a / b', scope=routine) - print(f"{parsed} | {type(parsed)}") - parsed = loki_parse('a ** b', scope=routine) - print(f"{parsed} | {type(parsed)}") - parsed = loki_parse('a:b', scope=routine) - print(f"{parsed} | {type(parsed)}") - parsed = loki_parse('a>b', scope=routine) - print(f"{parsed} | {type(parsed)}") - parsed = loki_parse('a.gt.b', scope=routine) - print(f"{parsed} | {type(parsed)}") - - parsed = loki_parse('arr(i1, i2, i3)') - print(f"{parsed} | {type(parsed)}") #  | shape: {parsed.shape} | dimensions: {parsed.dimensions}") - parsed = loki_parse('arr(i1, i2, i3)', scope=routine) - print(f"{parsed} | {type(parsed)} | shape: {parsed.shape} | dimensions: {parsed.dimensions}") - - parsed = loki_parse('a') - print(f"{parsed} | {type(parsed)} | scope: {parsed.scope} | type: {parsed.type}") - parsed = loki_parse('a', scope=routine) - print(f"{parsed} | {type(parsed)} | scope: {parsed.scope} | type: {parsed.type}") - parsed = loki_parse('3.1415') - print(f"{parsed} | {type(parsed)}") - - parsed = loki_parse('MODULO(A, B)') - print(f"{parsed} | {type(parsed)}") + # print("") + parsed = loki_parse(convert_to_case('a + b', mode=case)) + assert isinstance(parsed, symbols.Sum) + assert all(isinstance(_parsed, symbols.DeferredTypeSymbol) for _parsed in parsed.children) + # print(f"{parsed} | type: {type(parsed)} | children type(s): {[type(child) for child in parsed.children]}") + + parsed = loki_parse(convert_to_case('a + b', mode=case), scope=routine) + assert isinstance(parsed, symbols.Sum) + assert all(isinstance(_parsed, symbols.Scalar) for _parsed in parsed.children) + assert all(_parsed.scope == routine for _parsed in parsed.children) + # print(f"{parsed} | type: {type(parsed)} | children type(s): {[type(child) for child in parsed.children]}") + + parsed = loki_parse(convert_to_case('a + b + 2 + 10', mode=case), scope=routine) + assert isinstance(parsed, symbols.Sum) + assert all(isinstance(_parsed, (symbols.Scalar, symbols.IntLiteral)) for _parsed in parsed.children) + # print(f"{parsed} | {type(parsed)} | {[type(child) for child in parsed.children]}") + + parsed = loki_parse(convert_to_case('a - b', mode=case), scope=routine) + assert isinstance(parsed, symbols.Sum) + # assert all(isinstance(_parsed, symbols.Scalar) for _parsed in parsed.children) + assert isinstance(parsed.children[0], symbols.Scalar) + assert isinstance(parsed.children[1], symbols.Product) + assert isinstance(parsed.children[1].children[0], symbols.IntLiteral) + assert isinstance(parsed.children[1].children[1], symbols.Scalar) + # print(f"{parsed} | {type(parsed)}") + + parsed = loki_parse(convert_to_case('a * b', mode=case), scope=routine) + assert isinstance(parsed, symbols.Product) + assert all(isinstance(_parsed, symbols.Scalar) for _parsed in parsed.children) + assert all(_parsed.scope == routine for _parsed in parsed.children) + # print(f"{parsed} | {type(parsed)}") + + parsed = loki_parse(convert_to_case('a / b', mode=case), scope=routine) + assert isinstance(parsed, symbols.Quotient) + assert all(isinstance(_parsed, symbols.Scalar) for _parsed in [parsed.numerator, parsed.denominator]) + assert all(_parsed.scope == routine for _parsed in [parsed.numerator, parsed.denominator]) + # print(f"{parsed} | {type(parsed)}") + + parsed = loki_parse(convert_to_case('a ** b', mode=case), scope=routine) + assert isinstance(parsed, symbols.Power) + assert all(isinstance(_parsed, symbols.Scalar) for _parsed in [parsed.base, parsed.exponent]) + assert all(_parsed.scope == routine for _parsed in [parsed.base, parsed.exponent]) + # print(f"{parsed} | {type(parsed)}") + + parsed = loki_parse(convert_to_case('a:b', mode=case), scope=routine) + assert isinstance(parsed, symbols.RangeIndex) + assert all(isinstance(_parsed, symbols.Scalar) for _parsed in [parsed.lower, parsed.upper]) + assert all(_parsed.scope == routine for _parsed in [parsed.lower, parsed.upper]) + # print(f"{parsed} | {type(parsed)}") + + parsed = loki_parse(convert_to_case('a:b:5', mode=case), scope=routine) + assert isinstance(parsed, symbols.RangeIndex) + assert all(isinstance(_parsed, (symbols.Scalar, symbols.IntLiteral)) + for _parsed in [parsed.lower, parsed.upper, parsed.step]) + # print(f"{parsed} | {type(parsed)}") + + parsed = loki_parse(convert_to_case('a == b', mode=case), scope=routine) + assert parsed.operator == '==' + assert isinstance(parsed, symbols.Comparison) + assert all(isinstance(_parsed, symbols.Scalar) for _parsed in [parsed.left, parsed.right]) + assert all(_parsed.scope == routine for _parsed in [parsed.left, parsed.right]) + # print(f"{parsed} | {type(parsed)}") + parsed = loki_parse(convert_to_case('a.eq.b', mode=case), scope=routine) + assert parsed.operator == '==' + assert isinstance(parsed, symbols.Comparison) + assert all(isinstance(_parsed, symbols.Scalar) for _parsed in [parsed.left, parsed.right]) + assert all(_parsed.scope == routine for _parsed in [parsed.left, parsed.right]) + # print(f"{parsed} | {type(parsed)}") + + parsed = loki_parse(convert_to_case('a!=b', mode=case), scope=routine) + assert parsed.operator == '!=' + assert isinstance(parsed, symbols.Comparison) + assert all(isinstance(_parsed, symbols.Scalar) for _parsed in [parsed.left, parsed.right]) + assert all(_parsed.scope == routine for _parsed in [parsed.left, parsed.right]) + # print(f"{parsed} | {type(parsed)}") + parsed = loki_parse(convert_to_case('a.ne.b', mode=case), scope=routine) + assert parsed.operator == '!=' + assert isinstance(parsed, symbols.Comparison) + assert all(isinstance(_parsed, symbols.Scalar) for _parsed in [parsed.left, parsed.right]) + assert all(_parsed.scope == routine for _parsed in [parsed.left, parsed.right]) + # print(f"{parsed} | {type(parsed)}") + + parsed = loki_parse(convert_to_case('a>b', mode=case), scope=routine) + assert parsed.operator == '>' + assert isinstance(parsed, symbols.Comparison) + assert all(isinstance(_parsed, symbols.Scalar) for _parsed in [parsed.left, parsed.right]) + assert all(_parsed.scope == routine for _parsed in [parsed.left, parsed.right]) + # print(f"{parsed} | {type(parsed)}") + parsed = loki_parse(convert_to_case('a.gt.b', mode=case), scope=routine) + assert parsed.operator == '>' + assert isinstance(parsed, symbols.Comparison) + assert all(isinstance(_parsed, symbols.Scalar) for _parsed in [parsed.left, parsed.right]) + assert all(_parsed.scope == routine for _parsed in [parsed.left, parsed.right]) + # print(f"{parsed} | {type(parsed)}") + + parsed = loki_parse(convert_to_case('a>=b', mode=case), scope=routine) + assert parsed.operator == '>=' + assert isinstance(parsed, symbols.Comparison) + assert all(isinstance(_parsed, symbols.Scalar) for _parsed in [parsed.left, parsed.right]) + assert all(_parsed.scope == routine for _parsed in [parsed.left, parsed.right]) + # print(f"{parsed} | {type(parsed)}") + parsed = loki_parse(convert_to_case('a.ge.b', mode=case), scope=routine) + assert parsed.operator == '>=' + assert isinstance(parsed, symbols.Comparison) + assert all(isinstance(_parsed, symbols.Scalar) for _parsed in [parsed.left, parsed.right]) + assert all(_parsed.scope == routine for _parsed in [parsed.left, parsed.right]) + # print(f"{parsed} | {type(parsed)}") + + parsed = loki_parse(convert_to_case('a