Skip to content

Commit

Permalink
Continued 2: Loki string parser based on pymbolic parser
Browse files Browse the repository at this point in the history
  • Loading branch information
MichaelSt98 committed Apr 8, 2024
1 parent c48db05 commit db84b1d
Show file tree
Hide file tree
Showing 2 changed files with 300 additions and 64 deletions.
65 changes: 52 additions & 13 deletions loki/expression/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from pymbolic.mapper import Mapper
import pymbolic.primitives as pmbl

from loki.tools.util import CaseInsensitiveDict
from loki.expression import symbols as sym

__all__ = ['LokiParser', 'loki_parse']
Expand All @@ -27,23 +28,24 @@ def __init__(self, scope=None):

def map_product(self, expr, *args, **kwargs):
return sym.Product(tuple(self.rec(child, *args, **kwargs) for child in expr.children))
map_Mul = map_product
map_Mul = map_product

def map_sum(self, expr, *args, **kwargs):
return sym.Sum(tuple(self.rec(child, *args, **kwargs) for child in expr.children))
map_Add = map_sum
# map_Add = map_sum

def map_power(self, expr, *args, **kwargs):
return sym.Power(base=self.rec(expr.base),
exponent=self.rec(expr.exponent))
return sym.Power(base=self.rec(expr.base, *args, **kwargs),
exponent=self.rec(expr.exponent, *args, **kwargs))

def map_quotient(self, expr, *args, **kwargs):
return sym.Quotient(numerator=expr.numerator, denominator=expr.denominator)
return sym.Quotient(numerator=self.rec(expr.numerator, *args, **kwargs),
denominator=self.rec(expr.denominator, *args, **kwargs))

def map_comparison(self, expr, *args, **kwargs):
return sym.Comparison(left=self.rec(expr.left),
return sym.Comparison(left=self.rec(expr.left, *args, **kwargs),
operator=expr.operator,
right=self.rec(expr.right))
right=self.rec(expr.right, *args, **kwargs))

def map_logical_and(self, expr, *args, **kwargs):
return sym.LogicalAnd(tuple(self.rec(child, *args, **kwargs) for child in expr.children))
Expand All @@ -52,9 +54,11 @@ def map_logical_or(self, expr, *args, **kwargs):
return sym.LogicalOr(tuple(self.rec(child, *args, **kwargs) for child in expr.children))

def map_logical_not(self, expr, *args, **kwargs):
return sym.LogicalNot(self.rec(expr.child))
return sym.LogicalNot(self.rec(expr.child, *args, **kwargs))

def map_constant(self, expr, *args, **kwargs):
if isinstance(expr, (sym.FloatLiteral, sym.IntLiteral)):
return expr
return sym.Literal(expr)
map_logic_literal = map_constant
map_string_literal = map_constant
Expand All @@ -67,7 +71,7 @@ def map_constant(self, expr, *args, **kwargs):

def map_meta_symbol(self, expr, *args, **kwargs):
return sym.Variable(name=str(expr.name), scope=self.scope)
map_Symbol = map_meta_symbol
map_Symbol = map_meta_symbol
map_scalar = map_meta_symbol
map_array = map_meta_symbol

Expand All @@ -88,17 +92,29 @@ def map_algebraic_leaf(self, expr, *args, **kwargs):
if isinstance(expr, pmbl.Call):
if self.scope is not None:
if expr.function.name in self.scope.symbol_attrs:
return sym.Variable(name=expr.function.name, scope=self.scope, dimensions=self.rec(expr.parameters))
return expr
return sym.Variable(name=expr.function.name, scope=self.scope,
dimensions=tuple(self.rec(param, *args, **kwargs) for param in expr.parameters))
return sym.InlineCall(function=sym.Variable(name=expr.function.name),
parameters=tuple(self.rec(param, *args, **kwargs) for param in expr.parameters))
# else:
try:
return self.map_variable(expr)
return self.map_variable(expr, *args, **kwargs)
except Exception as e:
print(f"Exception: {e}")
return expr

def map_call_with_kwargs(self, expr, *args, **kwargs):
name = sym.Variable(name=expr.function.name)
parameters = tuple(self.rec(param, *args, **kwargs) for param in expr.parameters)
kw_parameters = {key: self.rec(value, *args, **kwargs) for key, value\
in CaseInsensitiveDict(expr.kw_parameters).items()}
if expr.function.name.lower() in ('real', 'int'):
return sym.Cast(name, parameters, kind=kw_parameters['kind'])

return sym.InlineCall(function=name, parameters=parameters, kw_parameters=kw_parameters)

def map_tuple(self, expr, *args, **kwargs):
return tuple(self.rec(elem) for elem in expr)
return tuple(self.rec(elem, *args, **kwargs) for elem in expr)


class LokiParser(ParserBase):
Expand All @@ -112,6 +128,8 @@ class LokiParser(ParserBase):
_f_and = intern("and")
_f_or = intern("or")
_f_not = intern("not")
_f_float = intern("f_float")
_f_int = intern("f_int")

lex_table = [
(_f_lessequal, pytools.lex.RE(r"\.le\.", re.IGNORECASE)),
Expand All @@ -123,6 +141,8 @@ class LokiParser(ParserBase):
(_f_and, pytools.lex.RE(r"\.and\.", re.IGNORECASE)),
(_f_or, pytools.lex.RE(r"\.or\.", re.IGNORECASE)),
(_f_not, pytools.lex.RE(r"\.not\.", re.IGNORECASE)),
(_f_float, ("|", pytools.lex.RE(r"[0-9]+\.[0-9]*([eEdD][+-]?[0-9]+)?(_[a-zA-Z]*)", re.IGNORECASE))),
(_f_int, pytools.lex.RE(r"[0-9]+?(_[a-zA-Z]*)", re.IGNORECASE)),
] + ParserBase.lex_table

ParserBase._COMP_TABLE.update({
Expand All @@ -134,8 +154,27 @@ class LokiParser(ParserBase):
_f_notequal: "!="
})

def parse_terminal(self, pstate):
# next_tag = pstate.next_tag()
if pstate.is_next(self._f_float): # next_tag is self._f_float:
return self.parse_f_float(pstate.next_str_and_advance())
if pstate.is_next(self._f_int):
return self.parse_f_int(pstate.next_str_and_advance())
return super().parse_terminal(pstate)

def __call__(self, expr_str, scope=None, min_precedence=0):
result = super().__call__(expr_str, min_precedence)
return PymbolicMapper(scope=scope)(result)

def parse_f_float(self, s):
stripped = s.split('_', 1)
if len(stripped) == 2:
return sym.Literal(value=self.parse_float(stripped[0]), kind=stripped[1].lower())
return self.parse_float(stripped[0])

def parse_f_int(self, s):
stripped = s.split('_', 1)
value = int(stripped[0].replace("d", "e").replace("D", "e"))
return sym.IntLiteral(value=value, kind=stripped[1].lower())

loki_parse = LokiParser()
Loading

0 comments on commit db84b1d

Please sign in to comment.