Skip to content

Commit

Permalink
IR: Improved implementation of detach_pragma_region via Transformer
Browse files Browse the repository at this point in the history
  • Loading branch information
mlange05 committed Apr 12, 2024
1 parent a008854 commit 024018b
Showing 1 changed file with 30 additions and 4 deletions.
34 changes: 30 additions & 4 deletions loki/pragma_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
VariableDeclaration, Pragma, PragmaRegion, FindNodes, Visitor,
Transformer
)
from loki.tools.util import as_tuple
from loki.tools.util import as_tuple, replace_windowed
from loki.types import BasicType
from loki.logging import debug, warning

Expand Down Expand Up @@ -497,6 +497,33 @@ def attach_pragma_regions(ir):
return PragmaRegionAttacher(pragma_pairs=pragma_pairs, inplace=True).visit(ir)


class PragmaRegionDetacher(Transformer):
"""
Remove any :any:`PragmaRegion` node objects and insert the tuple
of ``(r.pragma, r.body, r.pragma_post)`` in the enclosing tuple.
"""

def visit_tuple(self, o, **kwargs):
""" Unpack :any:`PragmaRegion` objects and insert in current tuple """

# We unpack regions here to avoid creating nested tuples, or
# forcing general tuple-flattening, which can affect other
# nodes types.
regions = tuple(n for n in o if isinstance(n, PragmaRegion))
for r in regions:
handle = (r.pragma,) + self.visit(r.body, **kwargs) + (r.pragma_post,)
o = replace_windowed(o, r, subs=handle)

# First recurse over the new nodes
visited = tuple(self.visit(i, **kwargs) for i in o)

# Strip empty sublists/subtuples or None entries
return tuple(i for i in visited if i is not None and as_tuple(i))

visit_list = visit_tuple


@Timer(logger=debug, text=lambda s: f'[Loki::IR] Executed detach_pragma_regions in {s:.2f}s')
def detach_pragma_regions(ir):
"""
Remove any :any:`PragmaRegion` node objects and replace each with a
Expand All @@ -506,9 +533,8 @@ def detach_pragma_regions(ir):
All replacements are performed in-place, without rebuilding any IR
nodes.
"""
mapper = {region: as_tuple(region.pragma) + region.body + as_tuple(region.pragma_post)
for region in FindNodes(PragmaRegion).visit(ir)}
return Transformer(mapper, inplace=True).visit(ir)

return PragmaRegionDetacher(inplace=True).visit(ir)


@contextmanager
Expand Down

0 comments on commit 024018b

Please sign in to comment.