Skip to content

Commit

Permalink
Recurse into declaration attributes only from Import and VariableDecl…
Browse files Browse the repository at this point in the history
…aration
  • Loading branch information
reuterbal committed Mar 26, 2023
1 parent cd13933 commit e6ead63
Show file tree
Hide file tree
Showing 5 changed files with 73 additions and 34 deletions.
6 changes: 4 additions & 2 deletions loki/expression/expr_visitors.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,7 @@ def __init__(self, expr_map, invalidate_source=True, **kwargs):
invalidate_source=invalidate_source)

def visit_Expression(self, o, **kwargs):
return self.expr_mapper(o)
return self.expr_mapper(o, parent_node=kwargs.get('current_node'))


class AttachScopes(Visitor):
Expand Down Expand Up @@ -338,13 +338,15 @@ def visit(self, o, *args, **kwargs):
Default visitor method that dispatches the node-specific handler
"""
kwargs.setdefault('scope', None)
if isinstance(o, Node):
kwargs['current_node'] = o
return super().visit(o, *args, **kwargs)

def visit_Expression(self, o, **kwargs):
"""
Dispatch :any:`AttachScopesMapper` for :any:`Expression` tree nodes
"""
return self.expr_mapper(o, scope=kwargs['scope'])
return self.expr_mapper(o, scope=kwargs['scope'], current_node=kwargs.get('current_node'))

def visit_list(self, o, **kwargs):
"""
Expand Down
53 changes: 26 additions & 27 deletions loki/expression/mappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
except ImportError:
_intrinsic_fortran_names = ()

from loki.ir import VariableDeclaration, Import
from loki.logging import debug
from loki.tools import as_tuple, flatten
from loki.types import SymbolAttributes, BasicType
Expand Down Expand Up @@ -520,6 +521,10 @@ def __init__(self, invalidate_source=True):
def __call__(self, expr, *args, **kwargs):
if expr is None:
return None
kwargs.setdefault(
'recurse_to_declaration_attributes',
'current_node' not in kwargs or isinstance(kwargs['current_node'], (VariableDeclaration, Import))
)
new_expr = super().__call__(expr, *args, **kwargs)
if getattr(expr, 'source', None):
if isinstance(new_expr, tuple):
Expand Down Expand Up @@ -557,34 +562,22 @@ def map_variable_symbol(self, expr, *args, **kwargs):
# it does not affect the outcome of expr.clone
expr.scope.symbol_attrs[expr.name] = expr.type.clone(kind=kind)

if expr.scope and expr.type.initial and expr.name == expr.type.initial:
# FIXME: This is a hack to work around situations where a constant
# symbol (from a parent scope) with the same name as the declared
# variable is used as initializer. This hands down the correct scope
# (in case this traversal is part of ``AttachScopesMapper``) and thus
# interrupts an otherwise infinite recursion (see LOKI-52).
if kwargs['recurse_to_declaration_attributes']:
_kwargs = kwargs.copy()
_kwargs['scope'] = expr.scope.parent
initial = self.rec(expr.type.initial, *args, **_kwargs)
elif expr.type.initial and expr.name.lower() in str(expr.type.initial).lower():
# FIXME: This is another hack to work around situations where the
# variable itself appears in the initializer expression (e.g. to
# inquire value limits via ``HUGE``), which would otherwise result in
# infinite recursion:
# 1. Replace occurences in initializer expression by a temporary variable...
retr = ExpressionRetriever(query=lambda e: e == expr.name)
retr(expr.type.initial)
tmp_map = {e: e.clone(name=f'___tmp___{expr.name}', scope=None) for e in retr.exprs}
tmp_initial = SubstituteExpressionsMapper(tmp_map)(expr.type.initial)
# 2. ...do the recursion into the initializer expression...
tmp_initial = self.rec(tmp_initial, *args, **kwargs)
# 3. ...and reverse the replacement by a temporary variable:
retr = ExpressionRetriever(query=lambda e: e == f'___tmp___{expr.name}')
retr(tmp_initial)
tmp_map = {e: e.clone(name=expr.name) for e in retr.exprs}
initial = SubstituteExpressionsMapper(tmp_map)(tmp_initial)
_kwargs['recurse_to_declaration_attributes'] = False
if expr.scope and expr.type.initial and expr.name == expr.type.initial:
# FIXME: This is a hack to work around situations where a constant
# symbol (from a parent scope) with the same name as the declared
# variable is used as initializer. This hands down the correct scope
# (in case this traversal is part of ``AttachScopesMapper``) and thus
# interrupts an otherwise infinite recursion (see LOKI-52).
_kwargs['scope'] = expr.scope.parent
initial = self.rec(expr.type.initial, *args, **_kwargs)
else:
initial = self.rec(expr.type.initial, *args, **_kwargs)
else:
initial = self.rec(expr.type.initial, *args, **kwargs)
initial = expr.type.initial

if initial is not expr.type.initial and expr.scope:
# Update symbol table entry for initial directly because with a scope attached
# it does not affect the outcome of expr.clone
Expand Down Expand Up @@ -628,7 +621,13 @@ def map_array(self, expr, *args, **kwargs):
# and make sure we don't loose the call parameters (aka dimensions)
return InlineCall(function=symbol.clone(parent=parent), parameters=dimensions)

shape = self.rec(symbol.type.shape, *args, **kwargs)
if kwargs['recurse_to_declaration_attributes']:
_kwargs = kwargs.copy()
_kwargs['recurse_to_declaration_attributes'] = False
shape = self.rec(symbol.type.shape, *args, **_kwargs)
else:
shape = symbol.type.shape

if (getattr(symbol, 'symbol', symbol) is expr.symbol and
all(d is orig_d for d, orig_d in zip_longest(dimensions or (), expr.dimensions or ())) and
all(d is orig_d for d, orig_d in zip_longest(shape or (), symbol.type.shape or ()))):
Expand Down
2 changes: 1 addition & 1 deletion loki/transform/transform_associates.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ class ResolveAssociatesTransformer(Transformer):
corresponding `selector` expression defined in ``associations``.
"""

def visit_Associate(self, o):
def visit_Associate(self, o, **kwargs):
# First head-recurse, so that all associate blocks beneath are resolved
body = self.visit(o.body)

Expand Down
5 changes: 4 additions & 1 deletion loki/visitors/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,8 @@ def visit(self, o, *args, **kwargs):
:any:`Node` or tuple
The rebuilt control flow tree.
"""
if isinstance(o, Node):
kwargs['current_node'] = o
obj = super().visit(o, *args, **kwargs)
if isinstance(o, Node) and obj is not o:
self.rebuilt[o] = obj
Expand Down Expand Up @@ -446,6 +448,8 @@ def visit(self, o, *args, **kwargs):
# to make sure that we don't include any following nodes we clear start
self.start.clear()
self.active = False
if isinstance(o, Node):
kwargs['current_node'] = o
return super().visit(o, *args, **kwargs)

def visit_object(self, o, **kwargs):
Expand Down Expand Up @@ -503,7 +507,6 @@ class NestedMaskedTransformer(MaskedTransformer):

# Handler for leaf nodes


def visit_object(self, o, **kwargs):
"""
Return the object unchanged.
Expand Down
41 changes: 38 additions & 3 deletions tests/test_expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -1309,15 +1309,50 @@ def _check(routine_):
zexplimit = routine_.variable_map['zexplimit']
assert zexplimit.scope is routine_
# Now let's take a closer look at the initializer expression
# The way we currently work around the infinite recursion, is
# that we don't attach the scope for the variable in the rhs
assert 'zexplimit' in str(zexplimit.type.initial).lower()
variables = FindVariables().visit(zexplimit.type.initial)
assert 'zexplimit' in variables
assert variables[variables.index('zexplimit')].scope is None
assert variables[variables.index('zexplimit')].scope is routine_

routine = Subroutine.from_source(fcode, frontend=frontend)
_check(routine)
# Make sure that's still true when doing another scope attachment
routine.rescope_symbols()
_check(routine)


@pytest.mark.parametrize('frontend', available_frontends())
def test_variable_in_dimensions(frontend):
"""
Check correct handling of cases where the variable appears in the
dimensions expression of the same variable (i.e. do not cause
infinite recursion)
"""
fcode = """
module some_mod
implicit none
type multi_level
real, allocatable :: data(:, :)
end type multi_level
contains
subroutine some_routine(levels, num_levels)
type(multi_level), intent(inout) :: levels(:)
integer, intent(in) :: num_levels
integer jscale
do jscale = 2,num_levels
allocate(levels(jscale)%data(size(levels(jscale-1)%data,1), size(levels(jscale-1)%data,2)))
end do
end subroutine some_routine
end module some_mod
""".strip()

module = Module.from_source(fcode, frontend=frontend)
routine = module['some_routine']
assert 'levels%data' in routine.symbol_attrs
shape = routine.symbol_attrs['levels%data'].shape
assert len(shape) == 2
for i, dim in enumerate(shape):
assert isinstance(dim, symbols.InlineCall)
assert str(dim).lower() == f'size(levels(jscale - 1)%data, {i+1})'

0 comments on commit e6ead63

Please sign in to comment.