diff --git a/loki/expression/expr_visitors.py b/loki/expression/expr_visitors.py index a3de74e35..bc9edb58c 100644 --- a/loki/expression/expr_visitors.py +++ b/loki/expression/expr_visitors.py @@ -297,7 +297,7 @@ def __init__(self, expr_map, invalidate_source=True, **kwargs): invalidate_source=invalidate_source) def visit_Expression(self, o, **kwargs): - return self.expr_mapper(o) + return self.expr_mapper(o, parent_node=kwargs.get('current_node')) class AttachScopes(Visitor): @@ -338,13 +338,15 @@ def visit(self, o, *args, **kwargs): Default visitor method that dispatches the node-specific handler """ kwargs.setdefault('scope', None) + if isinstance(o, Node): + kwargs['current_node'] = o return super().visit(o, *args, **kwargs) def visit_Expression(self, o, **kwargs): """ Dispatch :any:`AttachScopesMapper` for :any:`Expression` tree nodes """ - return self.expr_mapper(o, scope=kwargs['scope']) + return self.expr_mapper(o, scope=kwargs['scope'], current_node=kwargs.get('current_node')) def visit_list(self, o, **kwargs): """ diff --git a/loki/expression/mappers.py b/loki/expression/mappers.py index 70ce06325..5c435fb3d 100644 --- a/loki/expression/mappers.py +++ b/loki/expression/mappers.py @@ -23,6 +23,7 @@ except ImportError: _intrinsic_fortran_names = () +from loki.ir import VariableDeclaration, Import from loki.logging import debug from loki.tools import as_tuple, flatten from loki.types import SymbolAttributes, BasicType @@ -520,6 +521,10 @@ def __init__(self, invalidate_source=True): def __call__(self, expr, *args, **kwargs): if expr is None: return None + kwargs.setdefault( + 'recurse_to_declaration_attributes', + 'current_node' not in kwargs or isinstance(kwargs['current_node'], (VariableDeclaration, Import)) + ) new_expr = super().__call__(expr, *args, **kwargs) if getattr(expr, 'source', None): if isinstance(new_expr, tuple): @@ -557,34 +562,22 @@ def map_variable_symbol(self, expr, *args, **kwargs): # it does not affect the outcome of expr.clone expr.scope.symbol_attrs[expr.name] = expr.type.clone(kind=kind) - if expr.scope and expr.type.initial and expr.name == expr.type.initial: - # FIXME: This is a hack to work around situations where a constant - # symbol (from a parent scope) with the same name as the declared - # variable is used as initializer. This hands down the correct scope - # (in case this traversal is part of ``AttachScopesMapper``) and thus - # interrupts an otherwise infinite recursion (see LOKI-52). + if kwargs['recurse_to_declaration_attributes']: _kwargs = kwargs.copy() - _kwargs['scope'] = expr.scope.parent - initial = self.rec(expr.type.initial, *args, **_kwargs) - elif expr.type.initial and expr.name.lower() in str(expr.type.initial).lower(): - # FIXME: This is another hack to work around situations where the - # variable itself appears in the initializer expression (e.g. to - # inquire value limits via ``HUGE``), which would otherwise result in - # infinite recursion: - # 1. Replace occurences in initializer expression by a temporary variable... - retr = ExpressionRetriever(query=lambda e: e == expr.name) - retr(expr.type.initial) - tmp_map = {e: e.clone(name=f'___tmp___{expr.name}', scope=None) for e in retr.exprs} - tmp_initial = SubstituteExpressionsMapper(tmp_map)(expr.type.initial) - # 2. ...do the recursion into the initializer expression... - tmp_initial = self.rec(tmp_initial, *args, **kwargs) - # 3. ...and reverse the replacement by a temporary variable: - retr = ExpressionRetriever(query=lambda e: e == f'___tmp___{expr.name}') - retr(tmp_initial) - tmp_map = {e: e.clone(name=expr.name) for e in retr.exprs} - initial = SubstituteExpressionsMapper(tmp_map)(tmp_initial) + _kwargs['recurse_to_declaration_attributes'] = False + if expr.scope and expr.type.initial and expr.name == expr.type.initial: + # FIXME: This is a hack to work around situations where a constant + # symbol (from a parent scope) with the same name as the declared + # variable is used as initializer. This hands down the correct scope + # (in case this traversal is part of ``AttachScopesMapper``) and thus + # interrupts an otherwise infinite recursion (see LOKI-52). + _kwargs['scope'] = expr.scope.parent + initial = self.rec(expr.type.initial, *args, **_kwargs) + else: + initial = self.rec(expr.type.initial, *args, **_kwargs) else: - initial = self.rec(expr.type.initial, *args, **kwargs) + initial = expr.type.initial + if initial is not expr.type.initial and expr.scope: # Update symbol table entry for initial directly because with a scope attached # it does not affect the outcome of expr.clone @@ -628,7 +621,13 @@ def map_array(self, expr, *args, **kwargs): # and make sure we don't loose the call parameters (aka dimensions) return InlineCall(function=symbol.clone(parent=parent), parameters=dimensions) - shape = self.rec(symbol.type.shape, *args, **kwargs) + if kwargs['recurse_to_declaration_attributes']: + _kwargs = kwargs.copy() + _kwargs['recurse_to_declaration_attributes'] = False + shape = self.rec(symbol.type.shape, *args, **_kwargs) + else: + shape = symbol.type.shape + if (getattr(symbol, 'symbol', symbol) is expr.symbol and all(d is orig_d for d, orig_d in zip_longest(dimensions or (), expr.dimensions or ())) and all(d is orig_d for d, orig_d in zip_longest(shape or (), symbol.type.shape or ()))): diff --git a/loki/transform/transform_associates.py b/loki/transform/transform_associates.py index 7d1d17f6d..987811fd6 100644 --- a/loki/transform/transform_associates.py +++ b/loki/transform/transform_associates.py @@ -40,7 +40,7 @@ class ResolveAssociatesTransformer(Transformer): corresponding `selector` expression defined in ``associations``. """ - def visit_Associate(self, o): + def visit_Associate(self, o, **kwargs): # First head-recurse, so that all associate blocks beneath are resolved body = self.visit(o.body) diff --git a/loki/visitors/transform.py b/loki/visitors/transform.py index f05510c6d..3113cd6e8 100644 --- a/loki/visitors/transform.py +++ b/loki/visitors/transform.py @@ -246,6 +246,8 @@ def visit(self, o, *args, **kwargs): :any:`Node` or tuple The rebuilt control flow tree. """ + if isinstance(o, Node): + kwargs['current_node'] = o obj = super().visit(o, *args, **kwargs) if isinstance(o, Node) and obj is not o: self.rebuilt[o] = obj @@ -446,6 +448,8 @@ def visit(self, o, *args, **kwargs): # to make sure that we don't include any following nodes we clear start self.start.clear() self.active = False + if isinstance(o, Node): + kwargs['current_node'] = o return super().visit(o, *args, **kwargs) def visit_object(self, o, **kwargs): @@ -503,7 +507,6 @@ class NestedMaskedTransformer(MaskedTransformer): # Handler for leaf nodes - def visit_object(self, o, **kwargs): """ Return the object unchanged. diff --git a/tests/test_expression.py b/tests/test_expression.py index eea0d29f0..0299efc1f 100644 --- a/tests/test_expression.py +++ b/tests/test_expression.py @@ -1309,15 +1309,50 @@ def _check(routine_): zexplimit = routine_.variable_map['zexplimit'] assert zexplimit.scope is routine_ # Now let's take a closer look at the initializer expression - # The way we currently work around the infinite recursion, is - # that we don't attach the scope for the variable in the rhs assert 'zexplimit' in str(zexplimit.type.initial).lower() variables = FindVariables().visit(zexplimit.type.initial) assert 'zexplimit' in variables - assert variables[variables.index('zexplimit')].scope is None + assert variables[variables.index('zexplimit')].scope is routine_ routine = Subroutine.from_source(fcode, frontend=frontend) _check(routine) # Make sure that's still true when doing another scope attachment routine.rescope_symbols() _check(routine) + + +@pytest.mark.parametrize('frontend', available_frontends()) +def test_variable_in_dimensions(frontend): + """ + Check correct handling of cases where the variable appears in the + dimensions expression of the same variable (i.e. do not cause + infinite recursion) + """ + fcode = """ +module some_mod + implicit none + + type multi_level + real, allocatable :: data(:, :) + end type multi_level +contains + subroutine some_routine(levels, num_levels) + type(multi_level), intent(inout) :: levels(:) + integer, intent(in) :: num_levels + integer jscale + + do jscale = 2,num_levels + allocate(levels(jscale)%data(size(levels(jscale-1)%data,1), size(levels(jscale-1)%data,2))) + end do + end subroutine some_routine +end module some_mod + """.strip() + + module = Module.from_source(fcode, frontend=frontend) + routine = module['some_routine'] + assert 'levels%data' in routine.symbol_attrs + shape = routine.symbol_attrs['levels%data'].shape + assert len(shape) == 2 + for i, dim in enumerate(shape): + assert isinstance(dim, symbols.InlineCall) + assert str(dim).lower() == f'size(levels(jscale - 1)%data, {i+1})'