From e407d5227d1322014f924d29fc4dff2714ef7e05 Mon Sep 17 00:00:00 2001 From: Michael Lange Date: Thu, 3 Oct 2024 05:25:37 +0000 Subject: [PATCH 1/5] IR: Add inverse_map to Associate and make assoc maps case-insensitive --- loki/ir/nodes.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/loki/ir/nodes.py b/loki/ir/nodes.py index 78a24e01f..da85c24ff 100644 --- a/loki/ir/nodes.py +++ b/loki/ir/nodes.py @@ -501,7 +501,14 @@ def association_map(self): """ An :any:`collections.OrderedDict` of associated expressions. """ - return OrderedDict(self.associations) + return CaseInsensitiveDict((str(k), v) for k, v in self.associations) + + @property + def inverse_map(self): + """ + An :any:`collections.OrderedDict` of associated expressions. + """ + return CaseInsensitiveDict((str(v), k) for k, v in self.associations) @property def variables(self): From c1790ca9573402b45380209625fc2cb255027f3a Mon Sep 17 00:00:00 2001 From: Michael Lange Date: Wed, 2 Oct 2024 07:03:19 +0000 Subject: [PATCH 2/5] Transformations: Rewrite Associate resolver as an in-place mapper Instead of finding all symbols and substituting them with their inverse-association, we now simply run over each symbol, find its inverse from the `.scope` and apply in-place. This requires the substitution to be written as a symbol mapper. As an addition feature, we can now do partial resolution in associate bodies, where only a select sub-region of the code has its associated symbols resolved. I've added a test to this extend. --- loki/transformations/sanitise.py | 99 +++++++++++++-------- loki/transformations/tests/test_sanitise.py | 54 ++++++++++- 2 files changed, 115 insertions(+), 38 deletions(-) diff --git a/loki/transformations/sanitise.py b/loki/transformations/sanitise.py index 839178be6..4c43fcae2 100644 --- a/loki/transformations/sanitise.py +++ b/loki/transformations/sanitise.py @@ -13,16 +13,11 @@ """ from loki.batch import Transformation -from loki.expression import Array, RangeIndex -from loki.ir import ( - CallStatement, FindNodes, Transformer, NestedTransformer, - FindVariables, SubstituteExpressions -) -from loki.tools import as_tuple, CaseInsensitiveDict +from loki.expression import Array, RangeIndex, LokiIdentityMapper +from loki.ir import nodes as ir, FindNodes, Transformer +from loki.tools import as_tuple from loki.types import BasicType -from loki.transformations.utilities import recursive_expression_map_update - __all__ = [ 'SanitiseTransformation', 'resolve_associates', @@ -80,42 +75,74 @@ def resolve_associates(routine): routine.rescope_symbols() -class ResolveAssociatesTransformer(NestedTransformer): +class ResolveAssociateMapper(LokiIdentityMapper): + """ + Exppression mapper that will resolve symbol associations due + :any:`Associate` scopes. + + The mapper will inspect the associated scope of each symbol + and replace it with the inverse of the associate mapping. """ - :any:`Transformer` class to resolve :any:`Associate` nodes in IR trees + + def map_scalar(self, expr, *args, **kwargs): + # Skip unscoped expressions + if not hasattr(expr, 'scope'): + return self.rec(expr, *args, **kwargs) + + # Stop if scope is not an associate + if not isinstance(expr.scope, ir.Associate): + return expr + + scope = expr.scope + + # Recurse on parent first and propagate scope changes + parent = self.rec(expr.parent, *args, **kwargs) + if parent != expr.parent: + expr = expr.clone(parent=parent, scope=parent.scope) + + # Find a match in the given inverse map + if expr.basename in scope.inverse_map: + expr = scope.inverse_map[expr.basename] + return self.rec(expr, *args, **kwargs) + + return expr + + def map_array(self, expr, *args, **kwargs): + """ Special case for arrys: we need to preserve the dimensions """ + new = self.map_variable_symbol(expr, *args, **kwargs) + return new.clone(dimensions=expr.dimensions) + + map_variable_symbol = map_scalar + map_deferred_type_symbol = map_scalar + map_procedure_symbol = map_scalar + + +class ResolveAssociatesTransformer(Transformer): + """ + :any:`Transformer` class to resolve :any:`Associate` nodes in IR trees. This will replace each :any:`Associate` node with its own body, where all `identifier` symbols have been replaced with the corresponding `selector` expression defined in ``associations``. - """ - def visit_Associate(self, o, **kwargs): - # First head-recurse, so that all associate blocks beneath are resolved - body = self.visit(o.body, **kwargs) - - # Create an inverse association map to look up replacements - invert_assoc = CaseInsensitiveDict({v.name: k for k, v in o.associations}) - - # Build the expression substitution map - vmap = {} - for v in FindVariables().visit(body): - if v.name in invert_assoc: - # Clone the expression to update its parentage and scoping - inv = invert_assoc[v.name] - if hasattr(v, 'dimensions'): - vmap[v] = inv.clone(dimensions=v.dimensions) - else: - vmap[v] = inv + Importantly, this :any:`Transformer` can also be applied over partial + bodies of :any:`Associate` bodies. + """ + # pylint: disable=unused-argument - # Apply the expression substitution map to itself to handle nested expressions - vmap = recursive_expression_map_update(vmap) + def visit_Expression(self, o, **kwargs): + return ResolveAssociateMapper()(o) - # Mark the associate block for replacement with its body, with all expressions replaced - self.mapper[o] = SubstituteExpressions(vmap).visit(body) + def visit_Associate(self, o, **kwargs): + """ + Replaces an :any:`Associate` node with its transformed body + """ + return self.visit(o.body, **kwargs) - # Return the original object unchanged and let the tuple injection mechanism take care - # of replacing it by its body - otherwise we would end up with nested tuples - return o + def visit_CallStatement(self, o, **kwargs): + arguments = self.visit(o.arguments, **kwargs) + kwarguments = tuple((k, self.visit(v, **kwargs)) for k, v in o.kwarguments) + return o._rebuild(arguments=arguments, kwarguments=kwarguments) def check_if_scalar_syntax(arg, dummy): @@ -169,7 +196,7 @@ def transform_sequence_association(routine): """ #List calls in routine, but make sure we have the called routine definition - calls = (c for c in FindNodes(CallStatement).visit(routine.body) if not c.procedure_type is BasicType.DEFERRED) + calls = (c for c in FindNodes(ir.CallStatement).visit(routine.body) if not c.procedure_type is BasicType.DEFERRED) call_map = {} # Check all calls and record changes to `call_map` if necessary. diff --git a/loki/transformations/tests/test_sanitise.py b/loki/transformations/tests/test_sanitise.py index 811a3d7e1..be71ffded 100644 --- a/loki/transformations/tests/test_sanitise.py +++ b/loki/transformations/tests/test_sanitise.py @@ -11,11 +11,14 @@ BasicType, FindNodes, Subroutine, Module, fgen ) from loki.frontend import available_frontends, OMNI -from loki.ir import Assignment, Associate, CallStatement, Conditional +from loki.ir import ( + nodes as ir, FindNodes, Assignment, Associate, CallStatement, + Conditional +) from loki.transformations.sanitise import ( resolve_associates, transform_sequence_association, - SanitiseTransformation + ResolveAssociatesTransformer, SanitiseTransformation ) @@ -200,6 +203,53 @@ def test_transform_associates_nested_conditional(frontend): assert assign.rhs.parent.scope == routine +@pytest.mark.parametrize('frontend', available_frontends( + xfail=[(OMNI, 'OMNI does not handle missing type definitions')] +)) +def test_transform_associates_partial_body(frontend): + """ + Test resolving associated symbols, but only for a part of an + associate's body. + """ + fcode = """ +subroutine transform_associates_partial + use some_module, only: some_obj + implicit none + + integer :: i + real :: local_var + + associate (a=>some_obj%a, b=>some_obj%b) + local_var = a(1) + + do i=1, some_obj%n + a(i) = a(i) + 1. + b(i) = b(i) + 1. + end do + end associate +end subroutine transform_associates_partial +""" + routine = Subroutine.from_source(fcode, frontend=frontend) + + assert len(FindNodes(ir.Assignment).visit(routine.body)) == 3 + loops = FindNodes(ir.Loop).visit(routine.body) + assert len(loops) == 1 + + transformer = ResolveAssociatesTransformer(inplace=True) + transformer.visit(loops[0]) + + # Check that associated symbols have been resolved in loop body only + assert len(FindNodes(ir.Loop).visit(routine.body)) == 1 + assigns = FindNodes(ir.Assignment).visit(routine.body) + assert len(assigns) == 3 + assert assigns[0].lhs == 'local_var' + assert assigns[0].rhs == 'a(1)' + assert assigns[1].lhs == 'some_obj%a(i)' + assert assigns[1].rhs == 'some_obj%a(i) + 1.' + assert assigns[2].lhs == 'some_obj%b(i)' + assert assigns[2].rhs == 'some_obj%b(i) + 1.' + + @pytest.mark.parametrize('frontend', available_frontends()) def test_transform_sequence_assocaition_scalar_notation(frontend, tmp_path): fcode = """ From 90af901fc417038372272fed315fb6599be0a46e Mon Sep 17 00:00:00 2001 From: Michael Lange Date: Wed, 2 Oct 2024 07:12:42 +0000 Subject: [PATCH 3/5] Transformations: Small test cleanup for test_sanitise. --- loki/transformations/tests/test_sanitise.py | 67 ++++++++++----------- 1 file changed, 32 insertions(+), 35 deletions(-) diff --git a/loki/transformations/tests/test_sanitise.py b/loki/transformations/tests/test_sanitise.py index be71ffded..97426f677 100644 --- a/loki/transformations/tests/test_sanitise.py +++ b/loki/transformations/tests/test_sanitise.py @@ -11,10 +11,7 @@ BasicType, FindNodes, Subroutine, Module, fgen ) from loki.frontend import available_frontends, OMNI -from loki.ir import ( - nodes as ir, FindNodes, Assignment, Associate, CallStatement, - Conditional -) +from loki.ir import nodes as ir, FindNodes from loki.transformations.sanitise import ( resolve_associates, transform_sequence_association, @@ -43,18 +40,18 @@ def test_transform_associates_simple(frontend): """ routine = Subroutine.from_source(fcode, frontend=frontend) - assert len(FindNodes(Associate).visit(routine.body)) == 1 - assert len(FindNodes(Assignment).visit(routine.body)) == 1 - assign = FindNodes(Assignment).visit(routine.body)[0] + assert len(FindNodes(ir.Associate).visit(routine.body)) == 1 + assert len(FindNodes(ir.Assignment).visit(routine.body)) == 1 + assign = FindNodes(ir.Assignment).visit(routine.body)[0] assert assign.rhs == 'a' and 'some_obj' not in assign.rhs assert assign.rhs.type.dtype == BasicType.DEFERRED # Now apply the association resolver resolve_associates(routine) - assert len(FindNodes(Associate).visit(routine.body)) == 0 - assert len(FindNodes(Assignment).visit(routine.body)) == 1 - assign = FindNodes(Assignment).visit(routine.body)[0] + assert len(FindNodes(ir.Associate).visit(routine.body)) == 0 + assert len(FindNodes(ir.Assignment).visit(routine.body)) == 1 + assign = FindNodes(ir.Assignment).visit(routine.body)[0] assert assign.rhs == 'some_obj%a' assert assign.rhs.parent == 'some_obj' assert assign.rhs.type.dtype == BasicType.DEFERRED @@ -87,18 +84,18 @@ def test_transform_associates_nested(frontend): """ routine = Subroutine.from_source(fcode, frontend=frontend) - assert len(FindNodes(Associate).visit(routine.body)) == 3 - assert len(FindNodes(Assignment).visit(routine.body)) == 1 - assign = FindNodes(Assignment).visit(routine.body)[0] + assert len(FindNodes(ir.Associate).visit(routine.body)) == 3 + assert len(FindNodes(ir.Assignment).visit(routine.body)) == 1 + assign = FindNodes(ir.Assignment).visit(routine.body)[0] assert assign.lhs == 'rick' and assign.rhs == 'a' assert assign.rhs.type.dtype == BasicType.DEFERRED # Now apply the association resolver resolve_associates(routine) - assert len(FindNodes(Associate).visit(routine.body)) == 0 - assert len(FindNodes(Assignment).visit(routine.body)) == 1 - assign = FindNodes(Assignment).visit(routine.body)[0] + assert len(FindNodes(ir.Associate).visit(routine.body)) == 0 + assert len(FindNodes(ir.Assignment).visit(routine.body)) == 1 + assign = FindNodes(ir.Assignment).visit(routine.body)[0] assert assign.rhs == 'some_obj%never%gonna%give%you%up' @@ -129,18 +126,18 @@ def test_transform_associates_array_call(frontend): routine = Subroutine.from_source(fcode, frontend=frontend) - assert len(FindNodes(Associate).visit(routine.body)) == 1 - assert len(FindNodes(CallStatement).visit(routine.body)) == 1 - call = FindNodes(CallStatement).visit(routine.body)[0] + assert len(FindNodes(ir.Associate).visit(routine.body)) == 1 + assert len(FindNodes(ir.CallStatement).visit(routine.body)) == 1 + call = FindNodes(ir.CallStatement).visit(routine.body)[0] assert call.kwarguments[0][1] == 'some_array(i)%n' assert call.kwarguments[0][1].type.dtype == BasicType.DEFERRED # Now apply the association resolver resolve_associates(routine) - assert len(FindNodes(Associate).visit(routine.body)) == 0 - assert len(FindNodes(CallStatement).visit(routine.body)) == 1 - call = FindNodes(CallStatement).visit(routine.body)[0] + assert len(FindNodes(ir.Associate).visit(routine.body)) == 0 + assert len(FindNodes(ir.CallStatement).visit(routine.body)) == 1 + call = FindNodes(ir.CallStatement).visit(routine.body)[0] assert call.kwarguments[0][1] == 'some_obj%some_array(i)%n' assert call.kwarguments[0][1].scope == routine assert call.kwarguments[0][1].type.dtype == BasicType.DEFERRED @@ -182,20 +179,20 @@ def test_transform_associates_nested_conditional(frontend): """ routine = Subroutine.from_source(fcode, frontend=frontend) - assert len(FindNodes(Conditional).visit(routine.body)) == 2 - assert len(FindNodes(Associate).visit(routine.body)) == 1 - assert len(FindNodes(Assignment).visit(routine.body)) == 3 - assign = FindNodes(Assignment).visit(routine.body)[1] + assert len(FindNodes(ir.Conditional).visit(routine.body)) == 2 + assert len(FindNodes(ir.Associate).visit(routine.body)) == 1 + assert len(FindNodes(ir.Assignment).visit(routine.body)) == 3 + assign = FindNodes(ir.Assignment).visit(routine.body)[1] assert assign.rhs == 'a' and 'some_obj' not in assign.rhs assert assign.rhs.type.dtype == BasicType.DEFERRED # Now apply the association resolver resolve_associates(routine) - assert len(FindNodes(Conditional).visit(routine.body)) == 2 - assert len(FindNodes(Associate).visit(routine.body)) == 0 - assert len(FindNodes(Assignment).visit(routine.body)) == 3 - assign = FindNodes(Assignment).visit(routine.body)[1] + assert len(FindNodes(ir.Conditional).visit(routine.body)) == 2 + assert len(FindNodes(ir.Associate).visit(routine.body)) == 0 + assert len(FindNodes(ir.Assignment).visit(routine.body)) == 3 + assign = FindNodes(ir.Assignment).visit(routine.body)[1] assert assign.rhs == 'some_obj%a' assert assign.rhs.parent == 'some_obj' assert assign.rhs.type.dtype == BasicType.DEFERRED @@ -299,7 +296,7 @@ def test_transform_sequence_assocaition_scalar_notation(frontend, tmp_path): transform_sequence_association(routine) - calls = FindNodes(CallStatement).visit(routine.body) + calls = FindNodes(ir.CallStatement).visit(routine.body) assert fgen(calls[0]).lower() == 'call sub_x(array(1:10, 1), 1)' assert fgen(calls[1]).lower() == 'call sub_x(array(2:10, 2), 2)' @@ -349,9 +346,9 @@ def test_transformation_sanitise(frontend, resolve_associate, resolve_sequence, module = Module.from_source(fcode, frontend=frontend, xmods=[tmp_path]) routine = module['test_transformation_sanitise'] - assoc = FindNodes(Associate).visit(routine.body) + assoc = FindNodes(ir.Associate).visit(routine.body) assert len(assoc) == 1 - calls = FindNodes(CallStatement).visit(routine.body) + calls = FindNodes(ir.CallStatement).visit(routine.body) assert len(calls) == 1 assert calls[0].arguments[0] == 'a(1)' @@ -361,9 +358,9 @@ def test_transformation_sanitise(frontend, resolve_associate, resolve_sequence, ) trafo.apply(routine) - assoc = FindNodes(Associate).visit(routine.body) + assoc = FindNodes(ir.Associate).visit(routine.body) assert len(assoc) == 0 if resolve_associate else 1 - calls = FindNodes(CallStatement).visit(routine.body) + calls = FindNodes(ir.CallStatement).visit(routine.body) assert len(calls) == 1 assert calls[0].arguments[0] == 'a(1:3)' if resolve_sequence else 'a(1)' From 1c576e45774db42767dcf2b0b7860ba2b70283ed Mon Sep 17 00:00:00 2001 From: Michael Lange Date: Sat, 5 Oct 2024 04:08:19 +0000 Subject: [PATCH 4/5] Transformations: Recurse to array dimes and types in ResolveAssoc --- loki/transformations/sanitise.py | 11 ++++++++++- loki/transformations/tests/test_sanitise.py | 8 +++++++- 2 files changed, 17 insertions(+), 2 deletions(-) diff --git a/loki/transformations/sanitise.py b/loki/transformations/sanitise.py index 4c43fcae2..f66f5fb14 100644 --- a/loki/transformations/sanitise.py +++ b/loki/transformations/sanitise.py @@ -110,7 +110,16 @@ def map_scalar(self, expr, *args, **kwargs): def map_array(self, expr, *args, **kwargs): """ Special case for arrys: we need to preserve the dimensions """ new = self.map_variable_symbol(expr, *args, **kwargs) - return new.clone(dimensions=expr.dimensions) + + # Recurse over the type's shape + _type = expr.type + if expr.type.shape: + new_shape = self.rec(expr.type.shape, *args, **kwargs) + _type = expr.type.clone(shape=new_shape) + + # Recurse over array dimensions + new_dims = self.rec(expr.dimensions, *args, **kwargs) + return new.clone(dimensions=new_dims, type=_type) map_variable_symbol = map_scalar map_deferred_type_symbol = map_scalar diff --git a/loki/transformations/tests/test_sanitise.py b/loki/transformations/tests/test_sanitise.py index 97426f677..f7881f9da 100644 --- a/loki/transformations/tests/test_sanitise.py +++ b/loki/transformations/tests/test_sanitise.py @@ -114,8 +114,10 @@ def test_transform_associates_array_call(frontend): integer :: i real :: local_var + real, allocatable :: local_arr(:) - associate (some_array => some_obj%some_array) + associate (some_array => some_obj%some_array, a => some_obj%a) + allocate(local_arr(a%n)) do i=1, 5 call another_routine(i, n=some_array(i)%n) @@ -131,6 +133,7 @@ def test_transform_associates_array_call(frontend): call = FindNodes(ir.CallStatement).visit(routine.body)[0] assert call.kwarguments[0][1] == 'some_array(i)%n' assert call.kwarguments[0][1].type.dtype == BasicType.DEFERRED + assert routine.variable_map['local_arr'].type.shape == ('a%n',) # Now apply the association resolver resolve_associates(routine) @@ -142,6 +145,9 @@ def test_transform_associates_array_call(frontend): assert call.kwarguments[0][1].scope == routine assert call.kwarguments[0][1].type.dtype == BasicType.DEFERRED + # Test the special case of shapes derived from allocations + assert routine.variable_map['local_arr'].type.shape == ('some_obj%a%n',) + @pytest.mark.parametrize('frontend', available_frontends( xfail=[(OMNI, 'OMNI does not handle missing type definitions')] From 4445dd0f89ecd6e8f36f0fc6c5acd5f120702145 Mon Sep 17 00:00:00 2001 From: Michael Lange Date: Mon, 7 Oct 2024 04:49:35 +0000 Subject: [PATCH 5/5] IR: Add constructor test and pre-init hooks for CallStatements This adds "auto-fixer" behaviour for call.arguments and call.kwarguments, which makes it safe to always iterate them. --- loki/ir/nodes.py | 17 ++++++++++++ loki/ir/tests/test_ir_nodes.py | 48 ++++++++++++++++++++++++++++++++++ 2 files changed, 65 insertions(+) diff --git a/loki/ir/nodes.py b/loki/ir/nodes.py index da85c24ff..e9f60121b 100644 --- a/loki/ir/nodes.py +++ b/loki/ir/nodes.py @@ -955,6 +955,23 @@ class CallStatement(LeafNode, _CallStatementBase): _traversable = ['name', 'arguments', 'kwarguments'] + @model_validator(mode='before') + @classmethod + def pre_init(cls, values): + # Ensure non-nested tuples for arguments + if 'arguments' in values.kwargs: + values.kwargs['arguments'] = _sanitize_tuple(values.kwargs['arguments']) + else: + values.kwargs['arguments'] = () + # Ensure two-level nested tuples for kwarguments + if 'kwarguments' in values.kwargs: + kwarguments = as_tuple(values.kwargs['kwarguments']) + kwarguments = tuple(_sanitize_tuple(pair) for pair in kwarguments) + values.kwargs['kwarguments'] = kwarguments + else: + values.kwargs['kwarguments'] = () + return values + def __post_init__(self): super().__post_init__() assert isinstance(self.arguments, tuple) diff --git a/loki/ir/tests/test_ir_nodes.py b/loki/ir/tests/test_ir_nodes.py index cb17ea943..97a410c6b 100644 --- a/loki/ir/tests/test_ir_nodes.py +++ b/loki/ir/tests/test_ir_nodes.py @@ -196,3 +196,51 @@ def test_section(scope, one, i, n, a_n, a_i): assert sec.body == (assign, func, assign, assign) sec.insert(pos=3, node=func) assert sec.body == (assign, func, assign, func, assign) + + +def test_callstatement(scope, one, i, n, a_i): + """ Test constructor of :any:`CallStatement` nodes. """ + + cname = sym.ProcedureSymbol(name='test', scope=scope) + call = ir.CallStatement( + name=cname, arguments=(n, a_i), kwarguments=(('i', i), ('j', one)) + ) + assert isinstance(call.name, Expression) + assert isinstance(call.arguments, tuple) + assert all(isinstance(e, Expression) for e in call.arguments) + assert isinstance(call.kwarguments, tuple) + assert all(isinstance(e, tuple) for e in call.kwarguments) + assert all( + isinstance(k, str) and isinstance(v, Expression) + for k, v in call.kwarguments + ) + + # Ensure "frozen" status of node objects + with pytest.raises(FrozenInstanceError) as error: + call.name = sym.ProcedureSymbol('dave', scope=scope) + with pytest.raises(FrozenInstanceError) as error: + call.arguments = (a_i, n, one) + with pytest.raises(FrozenInstanceError) as error: + call.kwarguments = (('i', one), ('j', i)) + + # Test auto-casting of the body to tuple + call = ir.CallStatement(name=cname, arguments=[a_i, one]) + assert call.arguments == (a_i, one) and call.kwarguments == () + call = ir.CallStatement(name=cname, arguments=None) + assert call.arguments == () and call.kwarguments == () + call = ir.CallStatement(name=cname, kwarguments=[('i', i), ('j', one)]) + assert call.arguments == () and call.kwarguments == (('i', i), ('j', one)) + call = ir.CallStatement(name=cname, kwarguments=None) + assert call.arguments == () and call.kwarguments == () + + # Test errors for wrong contructor usage + with pytest.raises(ValidationError) as error: + ir.CallStatement(name='a', arguments=(sym.Literal(42.0),)) + with pytest.raises(ValidationError) as error: + ir.CallStatement(name=cname, arguments=('a',)) + with pytest.raises(ValidationError) as error: + ir.Assignment( + name=cname, arguments=(sym.Literal(42.0),), kwarguments=('i', 'i') + ) + + # TODO: Test pragmas, active and chevron