Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Transformations: ResolveAssociateTransformer re-write to in-place substitution #387

Merged
merged 5 commits into from
Oct 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 25 additions & 1 deletion loki/ir/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -501,7 +501,14 @@ def association_map(self):
"""
An :any:`collections.OrderedDict` of associated expressions.
"""
return OrderedDict(self.associations)
return CaseInsensitiveDict((str(k), v) for k, v in self.associations)

@property
def inverse_map(self):
"""
An :any:`collections.OrderedDict` of associated expressions.
"""
return CaseInsensitiveDict((str(v), k) for k, v in self.associations)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we simply "stringify" v here, considering that an association can in theory be a rather complex expression?


@property
def variables(self):
Expand Down Expand Up @@ -948,6 +955,23 @@ class CallStatement(LeafNode, _CallStatementBase):

_traversable = ['name', 'arguments', 'kwarguments']

@model_validator(mode='before')
@classmethod
def pre_init(cls, values):
# Ensure non-nested tuples for arguments
if 'arguments' in values.kwargs:
values.kwargs['arguments'] = _sanitize_tuple(values.kwargs['arguments'])
else:
values.kwargs['arguments'] = ()
# Ensure two-level nested tuples for kwarguments
if 'kwarguments' in values.kwargs:
kwarguments = as_tuple(values.kwargs['kwarguments'])
kwarguments = tuple(_sanitize_tuple(pair) for pair in kwarguments)
values.kwargs['kwarguments'] = kwarguments
else:
values.kwargs['kwarguments'] = ()
return values

def __post_init__(self):
super().__post_init__()
assert isinstance(self.arguments, tuple)
Expand Down
48 changes: 48 additions & 0 deletions loki/ir/tests/test_ir_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,3 +196,51 @@ def test_section(scope, one, i, n, a_n, a_i):
assert sec.body == (assign, func, assign, assign)
sec.insert(pos=3, node=func)
assert sec.body == (assign, func, assign, func, assign)


def test_callstatement(scope, one, i, n, a_i):
""" Test constructor of :any:`CallStatement` nodes. """

cname = sym.ProcedureSymbol(name='test', scope=scope)
call = ir.CallStatement(
name=cname, arguments=(n, a_i), kwarguments=(('i', i), ('j', one))
)
assert isinstance(call.name, Expression)
assert isinstance(call.arguments, tuple)
assert all(isinstance(e, Expression) for e in call.arguments)
assert isinstance(call.kwarguments, tuple)
assert all(isinstance(e, tuple) for e in call.kwarguments)
assert all(
isinstance(k, str) and isinstance(v, Expression)
for k, v in call.kwarguments
)

# Ensure "frozen" status of node objects
with pytest.raises(FrozenInstanceError) as error:
call.name = sym.ProcedureSymbol('dave', scope=scope)
with pytest.raises(FrozenInstanceError) as error:
call.arguments = (a_i, n, one)
with pytest.raises(FrozenInstanceError) as error:
call.kwarguments = (('i', one), ('j', i))

# Test auto-casting of the body to tuple
call = ir.CallStatement(name=cname, arguments=[a_i, one])
assert call.arguments == (a_i, one) and call.kwarguments == ()
call = ir.CallStatement(name=cname, arguments=None)
assert call.arguments == () and call.kwarguments == ()
call = ir.CallStatement(name=cname, kwarguments=[('i', i), ('j', one)])
assert call.arguments == () and call.kwarguments == (('i', i), ('j', one))
call = ir.CallStatement(name=cname, kwarguments=None)
assert call.arguments == () and call.kwarguments == ()

# Test errors for wrong contructor usage
with pytest.raises(ValidationError) as error:
ir.CallStatement(name='a', arguments=(sym.Literal(42.0),))
with pytest.raises(ValidationError) as error:
ir.CallStatement(name=cname, arguments=('a',))
with pytest.raises(ValidationError) as error:
ir.Assignment(
name=cname, arguments=(sym.Literal(42.0),), kwarguments=('i', 'i')
)

# TODO: Test pragmas, active and chevron
108 changes: 72 additions & 36 deletions loki/transformations/sanitise.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,11 @@
"""

from loki.batch import Transformation
from loki.expression import Array, RangeIndex
from loki.ir import (
CallStatement, FindNodes, Transformer, NestedTransformer,
FindVariables, SubstituteExpressions
)
from loki.tools import as_tuple, CaseInsensitiveDict
from loki.expression import Array, RangeIndex, LokiIdentityMapper
from loki.ir import nodes as ir, FindNodes, Transformer
from loki.tools import as_tuple
from loki.types import BasicType

from loki.transformations.utilities import recursive_expression_map_update


__all__ = [
'SanitiseTransformation', 'resolve_associates',
Expand Down Expand Up @@ -80,42 +75,83 @@ def resolve_associates(routine):
routine.rescope_symbols()


class ResolveAssociatesTransformer(NestedTransformer):
class ResolveAssociateMapper(LokiIdentityMapper):
"""
Exppression mapper that will resolve symbol associations due
:any:`Associate` scopes.

The mapper will inspect the associated scope of each symbol
and replace it with the inverse of the associate mapping.
"""
:any:`Transformer` class to resolve :any:`Associate` nodes in IR trees

def map_scalar(self, expr, *args, **kwargs):
# Skip unscoped expressions
if not hasattr(expr, 'scope'):
return self.rec(expr, *args, **kwargs)

# Stop if scope is not an associate
if not isinstance(expr.scope, ir.Associate):
return expr

scope = expr.scope

# Recurse on parent first and propagate scope changes
parent = self.rec(expr.parent, *args, **kwargs)
if parent != expr.parent:
expr = expr.clone(parent=parent, scope=parent.scope)

# Find a match in the given inverse map
if expr.basename in scope.inverse_map:
expr = scope.inverse_map[expr.basename]
return self.rec(expr, *args, **kwargs)

return expr

def map_array(self, expr, *args, **kwargs):
""" Special case for arrys: we need to preserve the dimensions """
new = self.map_variable_symbol(expr, *args, **kwargs)

# Recurse over the type's shape
_type = expr.type
if expr.type.shape:
new_shape = self.rec(expr.type.shape, *args, **kwargs)
_type = expr.type.clone(shape=new_shape)

# Recurse over array dimensions
new_dims = self.rec(expr.dimensions, *args, **kwargs)
return new.clone(dimensions=new_dims, type=_type)

map_variable_symbol = map_scalar
map_deferred_type_symbol = map_scalar
map_procedure_symbol = map_scalar


class ResolveAssociatesTransformer(Transformer):
"""
:any:`Transformer` class to resolve :any:`Associate` nodes in IR trees.

This will replace each :any:`Associate` node with its own body,
where all `identifier` symbols have been replaced with the
corresponding `selector` expression defined in ``associations``.
"""

def visit_Associate(self, o, **kwargs):
# First head-recurse, so that all associate blocks beneath are resolved
body = self.visit(o.body, **kwargs)

# Create an inverse association map to look up replacements
invert_assoc = CaseInsensitiveDict({v.name: k for k, v in o.associations})

# Build the expression substitution map
vmap = {}
for v in FindVariables().visit(body):
if v.name in invert_assoc:
# Clone the expression to update its parentage and scoping
inv = invert_assoc[v.name]
if hasattr(v, 'dimensions'):
vmap[v] = inv.clone(dimensions=v.dimensions)
else:
vmap[v] = inv
Importantly, this :any:`Transformer` can also be applied over partial
bodies of :any:`Associate` bodies.
"""
# pylint: disable=unused-argument

# Apply the expression substitution map to itself to handle nested expressions
vmap = recursive_expression_map_update(vmap)
def visit_Expression(self, o, **kwargs):
return ResolveAssociateMapper()(o)

# Mark the associate block for replacement with its body, with all expressions replaced
self.mapper[o] = SubstituteExpressions(vmap).visit(body)
def visit_Associate(self, o, **kwargs):
"""
Replaces an :any:`Associate` node with its transformed body
"""
return self.visit(o.body, **kwargs)

# Return the original object unchanged and let the tuple injection mechanism take care
# of replacing it by its body - otherwise we would end up with nested tuples
return o
def visit_CallStatement(self, o, **kwargs):
arguments = self.visit(o.arguments, **kwargs)
kwarguments = tuple((k, self.visit(v, **kwargs)) for k, v in o.kwarguments)
return o._rebuild(arguments=arguments, kwarguments=kwarguments)


def check_if_scalar_syntax(arg, dummy):
Expand Down Expand Up @@ -169,7 +205,7 @@ def transform_sequence_association(routine):
"""

#List calls in routine, but make sure we have the called routine definition
calls = (c for c in FindNodes(CallStatement).visit(routine.body) if not c.procedure_type is BasicType.DEFERRED)
calls = (c for c in FindNodes(ir.CallStatement).visit(routine.body) if not c.procedure_type is BasicType.DEFERRED)
call_map = {}

# Check all calls and record changes to `call_map` if necessary.
Expand Down
Loading
Loading