Skip to content

Commit

Permalink
IR: Improved implementation of attach_region_pragma via Transformer
Browse files Browse the repository at this point in the history
  • Loading branch information
mlange05 committed Apr 12, 2024
1 parent a594b44 commit a008854
Showing 1 changed file with 56 additions and 29 deletions.
85 changes: 56 additions & 29 deletions loki/pragma_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,19 +8,21 @@
import re
from collections import defaultdict
from contextlib import contextmanager
from codetiming import Timer

from loki.expression import symbols as sym
from loki.ir import (
VariableDeclaration, Pragma, PragmaRegion,
FindNodes, Visitor, Transformer, MaskedTransformer
VariableDeclaration, Pragma, PragmaRegion, FindNodes, Visitor,
Transformer
)
from loki.tools.util import as_tuple, flatten
from loki.tools.util import as_tuple
from loki.types import BasicType
from loki.logging import debug, warning


__all__ = [
'is_loki_pragma', 'get_pragma_parameters', 'process_dimension_pragmas',
'attach_pragmas', 'detach_pragmas', 'extract_pragma_region',
'attach_pragmas', 'detach_pragmas',
'pragmas_attached', 'attach_pragma_regions', 'detach_pragma_regions',
'pragma_regions_attached', 'PragmaAttacher', 'PragmaDetacher'
]
Expand Down Expand Up @@ -425,34 +427,58 @@ def _matches_starting_pragma(start, p):
return matches


def extract_pragma_region(ir, start, end):
class PragmaRegionAttacher(Transformer):
"""
Create a :any:`PragmaRegion` object defined by two :any:`Pragma` node
objects :data:`start` and :data:`end`.
Utility transformer that inserts :any:`PragmaRegion` objects to
mark code section between matching :any:`Pragma` pairs.
The resulting :any:`PragmaRegion` object will be inserted into the
:data:`ir` tree without rebuilding any IR nodes via ``Transformer(...,
inplace=True)``.
Matching pragma pairs are assumed to be of the form
``!$<keyword> <marker>`` and ``!$<keyword> end <marker>``.
The matching of pragma pairs only happens if the matching pragmas
are stored within the same tuple, or in other words at the same
depth of the IR tree. Ending a pragma region in a different
nesting depth, eg. inside a loop body, will result in a warning
and no region object being inserted into the IR tree.
Parameters
----------
pragma_pairs : tuple of tuple of :any:`Pragma`
Tuple of 2-tuples of matching pragma pairs
"""
assert isinstance(start, Pragma)
assert isinstance(end, Pragma)

# Pick out the marked code block for the PragmaRegion
block = MaskedTransformer(start=start, stop=end, inplace=True).visit(ir)
block = as_tuple(flatten(block))[1:] # Drop the initial pragma node
region = PragmaRegion(body=block, pragma=start, pragma_post=end)

# Remove the content of the code region and replace
# starting pragma with new PragmaRegion node.
mapper = {}
for node in block:
mapper[node] = None
mapper[start] = region
mapper[end] = None

return Transformer(mapper, inplace=True).visit(ir)
def __init__(self, pragma_pairs=None, **kwargs):
self.pragma_pairs = pragma_pairs

super().__init__(**kwargs)

def visit_tuple(self, o, **kwargs):
""" Replace pragma-body-end in tuples """
for start, stop in self.pragma_pairs:
if start in o:
# If a pair does not live in the same tuple we have a problem.
if stop not in o:
warning('[Loki::IR] Cannot find matching end for pragma {start} at same IR level!')
continue

# Create the PragmaRegion node and replace in tuple
idx_start = o.index(start)
idx_stop = o.index(stop)
region = PragmaRegion(
body=o[idx_start+1:idx_stop], pragma=start, pragma_post=stop
)
o = o[:idx_start] + (region,) + o[idx_stop+1:]

# Then recurse over the new nodes
visited = tuple(self.visit(i, **kwargs) for i in o)

# Strip empty sublists/subtuples or None entries
return tuple(i for i in visited if i is not None and as_tuple(i))

visit_list = visit_tuple


@Timer(logger=debug, text=lambda s: f'[Loki::IR] Executed attach_pragma_regions in {s:.2f}s')
def attach_pragma_regions(ir):
"""
Create :any:`PragmaRegion` node objects for all matching pairs of
Expand All @@ -466,9 +492,10 @@ def attach_pragma_regions(ir):
is performed in-place, without rebuilding any IR nodes.
"""
pragmas = FindNodes(Pragma).visit(ir)
for start, end in get_matching_region_pragmas(pragmas):
ir = extract_pragma_region(ir, start=start, end=end)
return ir
pragma_pairs = get_matching_region_pragmas(pragmas)

return PragmaRegionAttacher(pragma_pairs=pragma_pairs, inplace=True).visit(ir)


def detach_pragma_regions(ir):
"""
Expand Down

0 comments on commit a008854

Please sign in to comment.