Skip to content

Commit

Permalink
Transformations: Add merge_associates utility
Browse files Browse the repository at this point in the history
  • Loading branch information
mlange05 committed Oct 14, 2024
1 parent 7f3c716 commit 02c5b7a
Show file tree
Hide file tree
Showing 2 changed files with 138 additions and 4 deletions.
84 changes: 82 additions & 2 deletions loki/transformations/sanitise.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,13 @@

from loki.batch import Transformation
from loki.expression import Array, RangeIndex, LokiIdentityMapper
from loki.ir import nodes as ir, FindNodes, Transformer
from loki.ir import nodes as ir, FindNodes, Transformer, NestedTransformer
from loki.tools import as_tuple, dict_override
from loki.types import BasicType


__all__ = [
'SanitiseTransformation', 'resolve_associates',
'SanitiseTransformation', 'resolve_associates', 'merge_associates',
'ResolveAssociatesTransformer', 'transform_sequence_association',
'transform_sequence_association_append_map'
]
Expand Down Expand Up @@ -194,6 +194,86 @@ def visit_CallStatement(self, o, **kwargs):
return o._rebuild(arguments=arguments, kwarguments=kwarguments)


def merge_associates(routine, max_parents=None):
"""
Moves associate mappings in :any:`Associate` within a
:any:`Subroutine` to the outermost parent scope.
Please see :any:`MergeAssociatesTransformer` for mode details.
Note
----
This method can be combined with :any:`resolve_associates` to
create a more unified look-and-feel for nested ASSOCIATE blocks.
Parameters
----------
routine : :any:`Subroutine`
The subroutine for which to resolve all associate blocks.
max_parents : int, optional
Maximum number of parent symbols for valid selector to have.
"""
transformer = MergeAssociatesTransformer(max_parents=max_parents)
routine.body = transformer.visit(routine.body)


class MergeAssociatesTransformer(NestedTransformer):
"""
:any:`NestedTransformer` that moves associate mappings in
:any:`Associate` to parent nodes.
If a selector expression depends on a symbol from a parent
:any:`Associate` exists, it does not get moved.
Additionally, a maximum parent-depth can be specified for the
selector to prevent overly long symbols to be moved up.
Parameters
----------
routine : :any:`Subroutine`
The subroutine for which to resolve all associate blocks.
max_parents : int, optional
Maximum number of parent symbols for valid selector to have.
"""

def __init__(self, max_parents=None, **kwargs):
self.max_parents = max_parents
super().__init__(**kwargs)

def visit_Associate(self, o, **kwargs):
body = self.visit(o.body, **kwargs)

if not o.parent or not isinstance(o.parent, ir.Associate):
return o._rebuild(body=body)

# Find all associate mapping that can be moved up
to_move = tuple(
(expr, name) for expr, name in o.associations
if not expr.scope == o.parent
)

if self.max_parents:
# Optionally filter by depth of symbol-parentage
to_move = tuple(
(expr, name) for expr, name in to_move
if not len(expr.parents) > self.max_parents
)

# Move up to parent ...
parent_assoc = tuple(
(expr, name) for expr, name in to_move
if (expr, name) not in o.parent.associations
)
o.parent._update(associations=o.parent.associations + parent_assoc)

# ... and remove from this associate node
new_assocs = tuple(
(expr, name) for expr, name in o.associations
if (expr, name) not in to_move
)
return o._rebuild(body=body, associations=new_assocs)


def check_if_scalar_syntax(arg, dummy):
"""
Check if an array argument, arg,
Expand Down
58 changes: 56 additions & 2 deletions loki/transformations/tests/test_sanitise.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,9 @@
from loki.ir import nodes as ir, FindNodes

from loki.transformations.sanitise import (
resolve_associates, transform_sequence_association,
ResolveAssociatesTransformer, SanitiseTransformation
resolve_associates, merge_associates,
transform_sequence_association, ResolveAssociatesTransformer,
SanitiseTransformation
)


Expand Down Expand Up @@ -302,6 +303,59 @@ def test_transform_associates_start_depth(frontend):
assert assigns[2].rhs == 'b%d(i) + 1.'


@pytest.mark.parametrize('frontend', available_frontends(
xfail=[(OMNI, 'OMNI does not handle missing type definitions')]
))
def test_merge_associates_nested(frontend):
"""
Test association merging for nested mappings.
"""
fcode = """
subroutine merge_associates_simple(base)
use some_module, only: some_type
implicit none
type(some_type), intent(inout) :: base
integer :: i
real :: local_var
associate(a => base%a)
associate(b => base%other%symbol, c => a%more)
associate(d => base%other%symbol%really%deep, &
& a => base%a)
do i=1, 5
call another_routine(i, n=b(c)%n)
d(i) = 42.0
end do
end associate
end associate
end associate
end subroutine merge_associates_simple
"""

routine = Subroutine.from_source(fcode, frontend=frontend)

assocs = FindNodes(ir.Associate).visit(routine.body)
assert len(assocs) == 3
assert len(assocs[0].associations) == 1
assert len(assocs[1].associations) == 2
assert len(assocs[2].associations) == 2

# Move associate mapping around
merge_associates(routine, max_parents=2)

assocs = FindNodes(ir.Associate).visit(routine.body)
assert len(assocs) == 3
assert len(assocs[0].associations) == 2
assert assocs[0].associations[0] == ('base%a', 'a')
assert assocs[0].associations[1] == ('base%other%symbol', 'b')
assert len(assocs[1].associations) == 1
assert assocs[1].associations[0] == ('a%more', 'c')
assert len(assocs[2].associations) == 1
assert assocs[2].associations[0] == ('base%other%symbol%really%deep', 'd')


@pytest.mark.parametrize('frontend', available_frontends())
def test_transform_sequence_assocaition_scalar_notation(frontend, tmp_path):
fcode = """
Expand Down

0 comments on commit 02c5b7a

Please sign in to comment.