From 4b22bb9644b1c4345b40b7a4d303c783ef846158 Mon Sep 17 00:00:00 2001 From: Balthasar Reuter Date: Fri, 24 Mar 2023 21:54:54 +0000 Subject: [PATCH 01/10] Fix infinite recursion for declarations with the symbol in initializer --- loki/expression/mappers.py | 17 +++++++++++++++++ tests/test_expression.py | 36 ++++++++++++++++++++++++++++++++++++ 2 files changed, 53 insertions(+) diff --git a/loki/expression/mappers.py b/loki/expression/mappers.py index 7119a033c..70ce06325 100644 --- a/loki/expression/mappers.py +++ b/loki/expression/mappers.py @@ -566,6 +566,23 @@ def map_variable_symbol(self, expr, *args, **kwargs): _kwargs = kwargs.copy() _kwargs['scope'] = expr.scope.parent initial = self.rec(expr.type.initial, *args, **_kwargs) + elif expr.type.initial and expr.name.lower() in str(expr.type.initial).lower(): + # FIXME: This is another hack to work around situations where the + # variable itself appears in the initializer expression (e.g. to + # inquire value limits via ``HUGE``), which would otherwise result in + # infinite recursion: + # 1. Replace occurences in initializer expression by a temporary variable... + retr = ExpressionRetriever(query=lambda e: e == expr.name) + retr(expr.type.initial) + tmp_map = {e: e.clone(name=f'___tmp___{expr.name}', scope=None) for e in retr.exprs} + tmp_initial = SubstituteExpressionsMapper(tmp_map)(expr.type.initial) + # 2. ...do the recursion into the initializer expression... + tmp_initial = self.rec(tmp_initial, *args, **kwargs) + # 3. ...and reverse the replacement by a temporary variable: + retr = ExpressionRetriever(query=lambda e: e == f'___tmp___{expr.name}') + retr(tmp_initial) + tmp_map = {e: e.clone(name=expr.name) for e in retr.exprs} + initial = SubstituteExpressionsMapper(tmp_map)(tmp_initial) else: initial = self.rec(expr.type.initial, *args, **kwargs) if initial is not expr.type.initial and expr.scope: diff --git a/tests/test_expression.py b/tests/test_expression.py index 1d9d023a8..b7f8a60b5 100644 --- a/tests/test_expression.py +++ b/tests/test_expression.py @@ -1289,3 +1289,39 @@ def test_recursive_substitution(frontend): routine.body = SubstituteExpressions(expr_map).visit(routine.body) assignment = FindNodes(Assignment).visit(routine.body)[0] assert assignment.lhs == 'var(j + 1)' + + +@pytest.mark.parametrize('frontend', available_frontends()) +def test_variable_in_declaration_initializer(frontend): + """ + Check correct handling of cases where the variable appears + in the initializer expression (i.e. no infinite recursion) + """ + fcode = """ +subroutine some_routine(var) +implicit none +INTEGER, PARAMETER :: JPRB = SELECTED_REAL_KIND(13,300) +REAL(KIND=JPRB), PARAMETER :: ZEXPLIMIT = LOG(HUGE(ZEXPLIMIT)) +real(kind=jprb), intent(inout) :: var +var = var + ZEXPLIMIT +end subroutine some_routine + """.strip() + + def _check(routine_): + # A few sanity checks + assert 'zexplimit' in routine_.variable_map + zexplimit = routine_.variable_map['zexplimit'] + assert zexplimit.scope is routine_ + # Now let's take a closer look at the initializer expression + # The way we currently work around the infinite recursion, is + # that we don't attach the scope for the variable in the rhs + assert 'zexplimit' in str(zexplimit.type.initial).lower() + variables = FindVariables().visit(zexplimit.type.initial) + assert 'zexplimit' in variables + assert variables[variables.index('zexplimit')].scope is None + + routine = Subroutine.from_source(fcode, frontend=frontend) + _check(routine) + # Make sure that's still true when doing another scope attachment + routine.rescope_symbols() + _check(routine) From 16cd4424eeb871fba1442513040fc1d4c17439c4 Mon Sep 17 00:00:00 2001 From: Balthasar Reuter Date: Fri, 24 Mar 2023 22:32:13 +0000 Subject: [PATCH 02/10] Do not invalidate source in statement function injection --- loki/frontend/util.py | 8 ++++---- tests/test_subroutine.py | 7 ++++++- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/loki/frontend/util.py b/loki/frontend/util.py index 603e35c58..9415bee71 100644 --- a/loki/frontend/util.py +++ b/loki/frontend/util.py @@ -271,15 +271,15 @@ def create_type(stmt_func): # Apply transformer with the built maps if spec_map: - routine.spec = Transformer(spec_map).visit(routine.spec) + routine.spec = Transformer(spec_map, invalidate_source=False).visit(routine.spec) if body_map: - routine.body = Transformer(body_map).visit(routine.body) + routine.body = Transformer(body_map, invalidate_source=False).visit(routine.body) if spec_appendix: routine.spec.append(spec_appendix) if expr_map_spec: - routine.spec = SubstituteExpressions(expr_map_spec).visit(routine.spec) + routine.spec = SubstituteExpressions(expr_map_spec, invalidate_source=False).visit(routine.spec) if expr_map_body: - routine.body = SubstituteExpressions(expr_map_body).visit(routine.body) + routine.body = SubstituteExpressions(expr_map_body, invalidate_source=False).visit(routine.body) # And make sure all symbols have the right type routine.rescope_symbols() diff --git a/tests/test_subroutine.py b/tests/test_subroutine.py index 0c5dd5c74..726582d6e 100644 --- a/tests/test_subroutine.py +++ b/tests/test_subroutine.py @@ -16,7 +16,8 @@ Section, CallStatement, BasicType, Array, Scalar, Variable, SymbolAttributes, StringLiteral, fgen, fexprgen, VariableDeclaration, Transformer, FindTypedSymbols, ProcedureSymbol, ProcedureType, - StatementFunction, normalize_range_indexing, DeferredTypeSymbol + StatementFunction, normalize_range_indexing, DeferredTypeSymbol, + Assignment ) @@ -1504,6 +1505,10 @@ def test_subroutine_stmt_func(here, frontend): routine = Subroutine.from_source(fcode, frontend=frontend) routine.name += f'_{frontend!s}' + # Make sure the statement function injection doesn't invalidate source + for assignment in FindNodes(Assignment).visit(routine.body): + assert assignment.source is not None + # OMNI inlines statement functions, so we can only check correct representation # for fparser if frontend != OMNI: From 54462c04ad104818b0c2633b85b61abb5f0b26fc Mon Sep 17 00:00:00 2001 From: Balthasar Reuter Date: Fri, 24 Mar 2023 23:18:30 +0000 Subject: [PATCH 03/10] Store source for helper var declaration representing return type --- loki/frontend/fparser.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/loki/frontend/fparser.py b/loki/frontend/fparser.py index a7ea09121..1efc3c8ac 100644 --- a/loki/frontend/fparser.py +++ b/loki/frontend/fparser.py @@ -1782,7 +1782,8 @@ def visit_Subroutine_Subprogram(self, o, **kwargs): if return_type is not None: routine.symbol_attrs[routine.name] = return_type return_var = sym.Variable(name=routine.name, scope=routine) - return_var_decl = ir.VariableDeclaration(symbols=(return_var,)) + decl_source = self.get_source(subroutine_stmt, source=None) + return_var_decl = ir.VariableDeclaration(symbols=(return_var,), source=decl_source) decls = FindNodes((ir.VariableDeclaration, ir.ProcedureDeclaration)).visit(spec) if not decls: From 47ed3a28638956049ff07c44d3904369336b9b3a Mon Sep 17 00:00:00 2001 From: Balthasar Reuter Date: Sun, 26 Mar 2023 00:16:07 +0000 Subject: [PATCH 04/10] Recurse into declaration attributes only from Import and VariableDeclaration --- loki/expression/expr_visitors.py | 6 ++- loki/expression/mappers.py | 53 +++++++++++++------------- loki/transform/transform_associates.py | 2 +- loki/visitors/transform.py | 5 ++- tests/test_expression.py | 41 ++++++++++++++++++-- 5 files changed, 73 insertions(+), 34 deletions(-) diff --git a/loki/expression/expr_visitors.py b/loki/expression/expr_visitors.py index a3de74e35..bc9edb58c 100644 --- a/loki/expression/expr_visitors.py +++ b/loki/expression/expr_visitors.py @@ -297,7 +297,7 @@ def __init__(self, expr_map, invalidate_source=True, **kwargs): invalidate_source=invalidate_source) def visit_Expression(self, o, **kwargs): - return self.expr_mapper(o) + return self.expr_mapper(o, parent_node=kwargs.get('current_node')) class AttachScopes(Visitor): @@ -338,13 +338,15 @@ def visit(self, o, *args, **kwargs): Default visitor method that dispatches the node-specific handler """ kwargs.setdefault('scope', None) + if isinstance(o, Node): + kwargs['current_node'] = o return super().visit(o, *args, **kwargs) def visit_Expression(self, o, **kwargs): """ Dispatch :any:`AttachScopesMapper` for :any:`Expression` tree nodes """ - return self.expr_mapper(o, scope=kwargs['scope']) + return self.expr_mapper(o, scope=kwargs['scope'], current_node=kwargs.get('current_node')) def visit_list(self, o, **kwargs): """ diff --git a/loki/expression/mappers.py b/loki/expression/mappers.py index 70ce06325..5c435fb3d 100644 --- a/loki/expression/mappers.py +++ b/loki/expression/mappers.py @@ -23,6 +23,7 @@ except ImportError: _intrinsic_fortran_names = () +from loki.ir import VariableDeclaration, Import from loki.logging import debug from loki.tools import as_tuple, flatten from loki.types import SymbolAttributes, BasicType @@ -520,6 +521,10 @@ def __init__(self, invalidate_source=True): def __call__(self, expr, *args, **kwargs): if expr is None: return None + kwargs.setdefault( + 'recurse_to_declaration_attributes', + 'current_node' not in kwargs or isinstance(kwargs['current_node'], (VariableDeclaration, Import)) + ) new_expr = super().__call__(expr, *args, **kwargs) if getattr(expr, 'source', None): if isinstance(new_expr, tuple): @@ -557,34 +562,22 @@ def map_variable_symbol(self, expr, *args, **kwargs): # it does not affect the outcome of expr.clone expr.scope.symbol_attrs[expr.name] = expr.type.clone(kind=kind) - if expr.scope and expr.type.initial and expr.name == expr.type.initial: - # FIXME: This is a hack to work around situations where a constant - # symbol (from a parent scope) with the same name as the declared - # variable is used as initializer. This hands down the correct scope - # (in case this traversal is part of ``AttachScopesMapper``) and thus - # interrupts an otherwise infinite recursion (see LOKI-52). + if kwargs['recurse_to_declaration_attributes']: _kwargs = kwargs.copy() - _kwargs['scope'] = expr.scope.parent - initial = self.rec(expr.type.initial, *args, **_kwargs) - elif expr.type.initial and expr.name.lower() in str(expr.type.initial).lower(): - # FIXME: This is another hack to work around situations where the - # variable itself appears in the initializer expression (e.g. to - # inquire value limits via ``HUGE``), which would otherwise result in - # infinite recursion: - # 1. Replace occurences in initializer expression by a temporary variable... - retr = ExpressionRetriever(query=lambda e: e == expr.name) - retr(expr.type.initial) - tmp_map = {e: e.clone(name=f'___tmp___{expr.name}', scope=None) for e in retr.exprs} - tmp_initial = SubstituteExpressionsMapper(tmp_map)(expr.type.initial) - # 2. ...do the recursion into the initializer expression... - tmp_initial = self.rec(tmp_initial, *args, **kwargs) - # 3. ...and reverse the replacement by a temporary variable: - retr = ExpressionRetriever(query=lambda e: e == f'___tmp___{expr.name}') - retr(tmp_initial) - tmp_map = {e: e.clone(name=expr.name) for e in retr.exprs} - initial = SubstituteExpressionsMapper(tmp_map)(tmp_initial) + _kwargs['recurse_to_declaration_attributes'] = False + if expr.scope and expr.type.initial and expr.name == expr.type.initial: + # FIXME: This is a hack to work around situations where a constant + # symbol (from a parent scope) with the same name as the declared + # variable is used as initializer. This hands down the correct scope + # (in case this traversal is part of ``AttachScopesMapper``) and thus + # interrupts an otherwise infinite recursion (see LOKI-52). + _kwargs['scope'] = expr.scope.parent + initial = self.rec(expr.type.initial, *args, **_kwargs) + else: + initial = self.rec(expr.type.initial, *args, **_kwargs) else: - initial = self.rec(expr.type.initial, *args, **kwargs) + initial = expr.type.initial + if initial is not expr.type.initial and expr.scope: # Update symbol table entry for initial directly because with a scope attached # it does not affect the outcome of expr.clone @@ -628,7 +621,13 @@ def map_array(self, expr, *args, **kwargs): # and make sure we don't loose the call parameters (aka dimensions) return InlineCall(function=symbol.clone(parent=parent), parameters=dimensions) - shape = self.rec(symbol.type.shape, *args, **kwargs) + if kwargs['recurse_to_declaration_attributes']: + _kwargs = kwargs.copy() + _kwargs['recurse_to_declaration_attributes'] = False + shape = self.rec(symbol.type.shape, *args, **_kwargs) + else: + shape = symbol.type.shape + if (getattr(symbol, 'symbol', symbol) is expr.symbol and all(d is orig_d for d, orig_d in zip_longest(dimensions or (), expr.dimensions or ())) and all(d is orig_d for d, orig_d in zip_longest(shape or (), symbol.type.shape or ()))): diff --git a/loki/transform/transform_associates.py b/loki/transform/transform_associates.py index 7d1d17f6d..987811fd6 100644 --- a/loki/transform/transform_associates.py +++ b/loki/transform/transform_associates.py @@ -40,7 +40,7 @@ class ResolveAssociatesTransformer(Transformer): corresponding `selector` expression defined in ``associations``. """ - def visit_Associate(self, o): + def visit_Associate(self, o, **kwargs): # First head-recurse, so that all associate blocks beneath are resolved body = self.visit(o.body) diff --git a/loki/visitors/transform.py b/loki/visitors/transform.py index f05510c6d..3113cd6e8 100644 --- a/loki/visitors/transform.py +++ b/loki/visitors/transform.py @@ -246,6 +246,8 @@ def visit(self, o, *args, **kwargs): :any:`Node` or tuple The rebuilt control flow tree. """ + if isinstance(o, Node): + kwargs['current_node'] = o obj = super().visit(o, *args, **kwargs) if isinstance(o, Node) and obj is not o: self.rebuilt[o] = obj @@ -446,6 +448,8 @@ def visit(self, o, *args, **kwargs): # to make sure that we don't include any following nodes we clear start self.start.clear() self.active = False + if isinstance(o, Node): + kwargs['current_node'] = o return super().visit(o, *args, **kwargs) def visit_object(self, o, **kwargs): @@ -503,7 +507,6 @@ class NestedMaskedTransformer(MaskedTransformer): # Handler for leaf nodes - def visit_object(self, o, **kwargs): """ Return the object unchanged. diff --git a/tests/test_expression.py b/tests/test_expression.py index b7f8a60b5..2361b475b 100644 --- a/tests/test_expression.py +++ b/tests/test_expression.py @@ -1313,15 +1313,50 @@ def _check(routine_): zexplimit = routine_.variable_map['zexplimit'] assert zexplimit.scope is routine_ # Now let's take a closer look at the initializer expression - # The way we currently work around the infinite recursion, is - # that we don't attach the scope for the variable in the rhs assert 'zexplimit' in str(zexplimit.type.initial).lower() variables = FindVariables().visit(zexplimit.type.initial) assert 'zexplimit' in variables - assert variables[variables.index('zexplimit')].scope is None + assert variables[variables.index('zexplimit')].scope is routine_ routine = Subroutine.from_source(fcode, frontend=frontend) _check(routine) # Make sure that's still true when doing another scope attachment routine.rescope_symbols() _check(routine) + + +@pytest.mark.parametrize('frontend', available_frontends()) +def test_variable_in_dimensions(frontend): + """ + Check correct handling of cases where the variable appears in the + dimensions expression of the same variable (i.e. do not cause + infinite recursion) + """ + fcode = """ +module some_mod + implicit none + + type multi_level + real, allocatable :: data(:, :) + end type multi_level +contains + subroutine some_routine(levels, num_levels) + type(multi_level), intent(inout) :: levels(:) + integer, intent(in) :: num_levels + integer jscale + + do jscale = 2,num_levels + allocate(levels(jscale)%data(size(levels(jscale-1)%data,1), size(levels(jscale-1)%data,2))) + end do + end subroutine some_routine +end module some_mod + """.strip() + + module = Module.from_source(fcode, frontend=frontend) + routine = module['some_routine'] + assert 'levels%data' in routine.symbol_attrs + shape = routine.symbol_attrs['levels%data'].shape + assert len(shape) == 2 + for i, dim in enumerate(shape): + assert isinstance(dim, symbols.InlineCall) + assert str(dim).lower() == f'size(levels(jscale - 1)%data, {i+1})' From 766c201c0a5cd07e2a13b2e9ca119b3e83a55e9f Mon Sep 17 00:00:00 2001 From: Balthasar Reuter Date: Sun, 26 Mar 2023 22:16:27 +0100 Subject: [PATCH 05/10] Retain source object for injected statement functions --- loki/frontend/util.py | 10 +++++++--- tests/test_sourcefile.py | 1 + tests/test_subroutine.py | 1 + 3 files changed, 9 insertions(+), 3 deletions(-) diff --git a/loki/frontend/util.py b/loki/frontend/util.py index 9415bee71..76c91ecd4 100644 --- a/loki/frontend/util.py +++ b/loki/frontend/util.py @@ -205,7 +205,7 @@ def inject_statement_functions(routine): def create_stmt_func(assignment): arguments = assignment.lhs.dimensions variable = assignment.lhs.clone(dimensions=None) - return StatementFunction(variable, arguments, assignment.rhs, variable.type) + return StatementFunction(variable, arguments, assignment.rhs, variable.type, source=assignment.source) def create_type(stmt_func): name = str(stmt_func.variable) @@ -253,7 +253,9 @@ def create_type(stmt_func): if variable.name.lower() in stmt_funcs: if isinstance(variable, Array): parameters = variable.dimensions - expr_map_spec[variable] = InlineCall(variable.clone(dimensions=None), parameters=parameters) + expr_map_spec[variable] = InlineCall( + variable.clone(dimensions=None), parameters=parameters, source=variable.source + ) elif not isinstance(variable, ProcedureSymbol): expr_map_spec[variable] = variable.clone() expr_map_body = {} @@ -261,7 +263,9 @@ def create_type(stmt_func): if variable.name.lower() in stmt_funcs: if isinstance(variable, Array): parameters = variable.dimensions - expr_map_body[variable] = InlineCall(variable.clone(dimensions=None), parameters=parameters) + expr_map_body[variable] = InlineCall( + variable.clone(dimensions=None), parameters=parameters, source=variable.source + ) elif not isinstance(variable, ProcedureSymbol): expr_map_body[variable] = variable.clone() diff --git a/tests/test_sourcefile.py b/tests/test_sourcefile.py index 4c4766a97..64714fc19 100644 --- a/tests/test_sourcefile.py +++ b/tests/test_sourcefile.py @@ -216,6 +216,7 @@ def test_sourcefile_cpp_stmt_func(here, frontend): assert isinstance(var, ProcedureSymbol) assert isinstance(var.type.dtype, ProcedureType) assert var.type.dtype.procedure is decl + assert decl.source is not None # Generate code and compile filepath = here/f'{module.name}.f90' diff --git a/tests/test_subroutine.py b/tests/test_subroutine.py index 726582d6e..fb08c049d 100644 --- a/tests/test_subroutine.py +++ b/tests/test_subroutine.py @@ -1520,6 +1520,7 @@ def test_subroutine_stmt_func(here, frontend): assert isinstance(var, ProcedureSymbol) assert isinstance(var.type.dtype, ProcedureType) assert var.type.dtype.procedure is stmt_func_decls[var] + assert stmt_func_decls[var].source is not None # Make sure this produces the correct result filepath = here/f'{routine.name}.f90' From 86be4d52eea527bab5ce5edaf231e5c864f43ec7 Mon Sep 17 00:00:00 2001 From: Balthasar Reuter Date: Sun, 26 Mar 2023 22:41:53 +0100 Subject: [PATCH 06/10] Fix infinite recursion on abstract procedure declaration --- loki/expression/mappers.py | 28 +++++++++++++----- loki/frontend/ofp.py | 2 +- loki/ir.py | 10 ++++++- loki/transform/transform_associates.py | 2 +- tests/test_derived_types.py | 39 ++++++++++++++++++++++++++ 5 files changed, 71 insertions(+), 10 deletions(-) diff --git a/loki/expression/mappers.py b/loki/expression/mappers.py index 5c435fb3d..900bb017e 100644 --- a/loki/expression/mappers.py +++ b/loki/expression/mappers.py @@ -23,7 +23,7 @@ except ImportError: _intrinsic_fortran_names = () -from loki.ir import VariableDeclaration, Import +from loki.ir import DECLARATION_NODES from loki.logging import debug from loki.tools import as_tuple, flatten from loki.types import SymbolAttributes, BasicType @@ -523,7 +523,7 @@ def __call__(self, expr, *args, **kwargs): return None kwargs.setdefault( 'recurse_to_declaration_attributes', - 'current_node' not in kwargs or isinstance(kwargs['current_node'], (VariableDeclaration, Import)) + 'current_node' not in kwargs or isinstance(kwargs['current_node'], DECLARATION_NODES) ) new_expr = super().__call__(expr, *args, **kwargs) if getattr(expr, 'source', None): @@ -583,11 +583,25 @@ def map_variable_symbol(self, expr, *args, **kwargs): # it does not affect the outcome of expr.clone expr.scope.symbol_attrs[expr.name] = expr.type.clone(initial=initial) - bind_names = self.rec(expr.type.bind_names, *args, **kwargs) - if not (bind_names is None or all(new is old for new, old in zip_longest(bind_names, expr.type.bind_names))): - # Update symbol table entry for bind_names directly because with a scope attached - # it does not affect the outcome of expr.clone - expr.scope.symbol_attrs[expr.name] = expr.type.clone(bind_names=as_tuple(bind_names)) + if kwargs['recurse_to_declaration_attributes']: + _kwargs = kwargs.copy() + _kwargs['recurse_to_declaration_attributes'] = False + if (old_bind_names := expr.type.bind_names): + bind_names = () + for bind_name in old_bind_names: + if bind_name == expr.name: + # FIXME: This is a hack to work around situations where an + # explicit interface is used with the same name as the + # type bound procedure. This hands down the correct scope. + __kwargs = _kwargs.copy() + __kwargs['scope'] = expr.scope.parent + bind_names += (self.rec(bind_name, *args, **__kwargs),) + else: + bind_names += (self.rec(bind_name, *args, **_kwargs),) + if bind_names and any(new is not old for new, old in zip_longest(bind_names, expr.type.bind_names)): + # Update symbol table entry for bind_names directly because with a scope attached + # it does not affect the outcome of expr.clone + expr.scope.symbol_attrs[expr.name] = expr.type.clone(bind_names=bind_names) parent = self.rec(expr.parent, *args, **kwargs) if parent is expr.parent and (kind is expr.type.kind or expr.scope): diff --git a/loki/frontend/ofp.py b/loki/frontend/ofp.py index 742b95b16..25d5aa39c 100644 --- a/loki/frontend/ofp.py +++ b/loki/frontend/ofp.py @@ -609,7 +609,7 @@ def visit_specific_binding(self, o, **kwargs): type=SymbolAttributes(ProcedureType(interface)) ) - _type = interface.type + _type = interface.type.clone(bind_names=(interface,)) elif o.attrib['procedureName']: # Binding provided ( => ) diff --git a/loki/ir.py b/loki/ir.py index d1680a2cc..be93a09a6 100644 --- a/loki/ir.py +++ b/loki/ir.py @@ -39,7 +39,9 @@ 'Comment', 'CommentBlock', 'Pragma', 'PreprocessorDirective', 'Import', 'VariableDeclaration', 'ProcedureDeclaration', 'DataDeclaration', 'StatementFunction', 'TypeDef', 'MultiConditional', 'MaskedStatement', - 'Intrinsic', 'Enumeration', 'RawSource' + 'Intrinsic', 'Enumeration', 'RawSource', + # List of nodes with specific properties + 'DECLARATION_NODES' ] # Configuration for validation mechanism via pydantic @@ -1821,3 +1823,9 @@ class RawSource(LeafNode, _RawSourceBase): def __repr__(self): return f'RawSource:: {truncate_string(self.text.strip())}' + + +DECLARATION_NODES = (Import, VariableDeclaration, ProcedureDeclaration) +""" +List of IR nodes that are considered to be the authority on a symbol's attributes +""" diff --git a/loki/transform/transform_associates.py b/loki/transform/transform_associates.py index 987811fd6..706be7b34 100644 --- a/loki/transform/transform_associates.py +++ b/loki/transform/transform_associates.py @@ -40,7 +40,7 @@ class ResolveAssociatesTransformer(Transformer): corresponding `selector` expression defined in ``associations``. """ - def visit_Associate(self, o, **kwargs): + def visit_Associate(self, o, **kwargs): # pylint: disable=unused-argument # First head-recurse, so that all associate blocks beneath are resolved body = self.visit(o.body) diff --git a/tests/test_derived_types.py b/tests/test_derived_types.py index e26beb10d..87b8da3a0 100644 --- a/tests/test_derived_types.py +++ b/tests/test_derived_types.py @@ -1359,3 +1359,42 @@ def test_derived_types_nested_type(frontend): assert assignment.rhs.parent.type.dtype.typedef is module['some_type'] assert assignment.rhs.parent.parent.type.dtype.name == 'other_type' assert assignment.rhs.parent.parent.type.dtype.typedef is module['other_type'] + + +@pytest.mark.parametrize('frontend', available_frontends()) +def test_derived_types_abstract_deferred_procedure(frontend): + fcode = """ +module some_mod + implicit none + type, abstract :: abstract_type + contains + procedure (some_proc), deferred :: some_proc + procedure (other_proc), deferred :: other_proc + end type abstract_type + + abstract interface + subroutine some_proc(this) + import abstract_type + class(abstract_type), intent(in) :: this + end subroutine some_proc + end interface + + abstract interface + subroutine other_proc(this) + import abstract_type + class(abstract_type), intent(inout) :: this + end subroutine other_proc + end interface +end module some_mod + """.strip() + module = Module.from_source(fcode, frontend=frontend) + typedef = module['abstract_type'] + assert typedef.abstract is True + assert typedef.variables == ('some_proc', 'other_proc') + for symbol in typedef.variables: + assert isinstance(symbol, ProcedureSymbol) + assert isinstance(symbol.type.dtype, ProcedureType) + assert symbol.type.dtype.name.lower() == symbol.name.lower() + assert symbol.type.bind_names == (symbol,) + assert symbol.scope is typedef + assert symbol.type.bind_names[0].scope is module From 3b6f00dc187010239fd23f67da48b6a42717e64a Mon Sep 17 00:00:00 2001 From: Balthasar Reuter Date: Wed, 12 Apr 2023 18:39:07 +0100 Subject: [PATCH 07/10] Workaround for fparser behaviour --- loki/frontend/fparser.py | 13 +++++++++-- tests/test_control_flow.py | 44 +++++++++++++++++++++++++++++++++++++- 2 files changed, 54 insertions(+), 3 deletions(-) diff --git a/loki/frontend/fparser.py b/loki/frontend/fparser.py index 1efc3c8ac..42c8a42ac 100644 --- a/loki/frontend/fparser.py +++ b/loki/frontend/fparser.py @@ -2100,8 +2100,17 @@ def visit_If_Construct(self, o, **kwargs): else_if_stmt_index, else_if_stmts = zip(*else_if_stmts) else: else_if_stmt_index = () - else_stmt = get_child(o, Fortran2003.Else_Stmt) - else_stmt_index = o.children.index(else_stmt) if else_stmt else end_if_stmt_index + + # Note: we need to use here the same method as for else-if because finding Else_Stmt + # directly and checking its position via o.children.index may give the wrong result. + # This is because Else_Stmt may erronously compare equal to other node types. + # See https://github.com/stfc/fparser/issues/400 + else_stmt = tuple((i, c) for i, c in enumerate(o.children) if isinstance(c, Fortran2003.Else_Stmt)) + if else_stmt: + assert len(else_stmt) == 1 + else_stmt_index, else_stmt = else_stmt[0] + else: + else_stmt_index = end_if_stmt_index conditions = as_tuple(self.visit(c, **kwargs) for c in (if_then_stmt,) + else_if_stmts) bodies = tuple( tuple(flatten(as_tuple(self.visit(c, **kwargs) for c in o.children[start+1:stop]))) diff --git a/tests/test_control_flow.py b/tests/test_control_flow.py index dbccb6b30..5b558f93b 100644 --- a/tests/test_control_flow.py +++ b/tests/test_control_flow.py @@ -10,7 +10,7 @@ import numpy as np from conftest import jit_compile, clean_test, available_frontends -from loki import OMNI, Subroutine, FindNodes, Loop, Conditional, Node +from loki import OMNI, Subroutine, FindNodes, Loop, Conditional, Node, Intrinsic @pytest.fixture(scope='module', name='here') @@ -455,3 +455,45 @@ def test_conditional_bodies(frontend): c.else_body and isinstance(c.else_body, tuple) and all(isinstance(n, Node) for n in c.else_body) for c in conditionals ) + + +@pytest.mark.parametrize('frontend', available_frontends()) +def test_conditional_else_body_return(frontend): + fcode = """ +FUNCTION FUNC(PX,KN) +IMPLICIT NONE +INTEGER,INTENT(INOUT) :: KN +REAL,INTENT(IN) :: PX +REAL :: FUNC +INTEGER :: J +REAL :: Z0, Z1, Z2 +Z0= 1.0 +Z1= PX +IF (KN == 0) THEN + FUNC= Z0 + RETURN +ELSEIF (KN == 1) THEN + FUNC= Z1 + RETURN +ELSE + DO J=2,KN + Z2= Z0+Z1 + Z0= Z1 + Z1= Z2 + ENDDO + FUNC= Z2 + RETURN +ENDIF +END FUNCTION FUNC + """.strip() + + routine = Subroutine.from_source(fcode, frontend=frontend) + conditionals = FindNodes(Conditional).visit(routine.body) + assert len(conditionals) == 2 + assert isinstance(conditionals[0].body[-1], Intrinsic) + assert conditionals[0].body[-1].text.upper() == 'RETURN' + assert conditionals[0].else_body == (conditionals[1],) + assert isinstance(conditionals[1].body[-1], Intrinsic) + assert conditionals[1].body[-1].text.upper() == 'RETURN' + assert isinstance(conditionals[1].else_body[-1], Intrinsic) + assert conditionals[1].else_body[-1].text.upper() == 'RETURN' From 64d25b5608efe644b811cdabd2fec3ae63875bf2 Mon Sep 17 00:00:00 2001 From: Balthasar Reuter Date: Wed, 3 May 2023 14:54:31 +0100 Subject: [PATCH 08/10] Set `recurse_to_declaration_attributes` in the control flow visitors before dispatching to expr tree mappers --- loki/expression/expr_visitors.py | 35 ++++++++++++++++++++++++++++++-- loki/expression/mappers.py | 6 +----- loki/ir.py | 8 -------- loki/visitors/transform.py | 4 ---- 4 files changed, 34 insertions(+), 19 deletions(-) diff --git a/loki/expression/expr_visitors.py b/loki/expression/expr_visitors.py index bc9edb58c..bc2673954 100644 --- a/loki/expression/expr_visitors.py +++ b/loki/expression/expr_visitors.py @@ -297,7 +297,24 @@ def __init__(self, expr_map, invalidate_source=True, **kwargs): invalidate_source=invalidate_source) def visit_Expression(self, o, **kwargs): - return self.expr_mapper(o, parent_node=kwargs.get('current_node')) + """ + call :any:`SubstituteExpressionsMapper` for the given expression node + """ + if kwargs.get('recurse_to_declaration_attributes'): + return self.expr_mapper(o, recurse_to_declaration_attributes=True) + return self.expr_mapper(o) + + def visit_Import(self, o, **kwargs): + """ + For :any:`Import` (as well as :any:`VariableDeclaration` and :any:`ProcedureDeclaration`) + we set ``recurse_to_declaration_attributes=True`` to make sure properties in the symbol + table are updated during dispatch to the expression mapper + """ + kwargs['recurse_to_declaration_attributes'] = True + return super().visit_Node(o, **kwargs) + + visit_VariableDeclaration = visit_Import + visit_ProcedureDeclaration = visit_Import class AttachScopes(Visitor): @@ -346,7 +363,9 @@ def visit_Expression(self, o, **kwargs): """ Dispatch :any:`AttachScopesMapper` for :any:`Expression` tree nodes """ - return self.expr_mapper(o, scope=kwargs['scope'], current_node=kwargs.get('current_node')) + if kwargs.get('recurse_to_declaration_attributes'): + return self.expr_mapper(o, scope=kwargs['scope'], recurse_to_declaration_attributes=True) + return self.expr_mapper(o, scope=kwargs['scope']) def visit_list(self, o, **kwargs): """ @@ -365,6 +384,18 @@ def visit_Node(self, o, **kwargs): children = tuple(self.visit(i, **kwargs) for i in o.children) return self._update(o, children) + def visit_Import(self, o, **kwargs): + """ + For :any:`Import` (as well as :any:`VariableDeclaration` and :any:`ProcedureDeclaration`) + we set ``recurse_to_declaration_attributes=True`` to make sure properties in the symbol + table are updated during dispatch to the expression mapper + """ + kwargs['recurse_to_declaration_attributes'] = True + return self.visit_Node(o, **kwargs) + + visit_VariableDeclaration = visit_Import + visit_ProcedureDeclaration = visit_Import + def visit_Scope(self, o, **kwargs): """ Generic handler for :any:`Scope` objects diff --git a/loki/expression/mappers.py b/loki/expression/mappers.py index 900bb017e..fe8663955 100644 --- a/loki/expression/mappers.py +++ b/loki/expression/mappers.py @@ -23,7 +23,6 @@ except ImportError: _intrinsic_fortran_names = () -from loki.ir import DECLARATION_NODES from loki.logging import debug from loki.tools import as_tuple, flatten from loki.types import SymbolAttributes, BasicType @@ -521,10 +520,7 @@ def __init__(self, invalidate_source=True): def __call__(self, expr, *args, **kwargs): if expr is None: return None - kwargs.setdefault( - 'recurse_to_declaration_attributes', - 'current_node' not in kwargs or isinstance(kwargs['current_node'], DECLARATION_NODES) - ) + kwargs.setdefault('recurse_to_declaration_attributes', False) new_expr = super().__call__(expr, *args, **kwargs) if getattr(expr, 'source', None): if isinstance(new_expr, tuple): diff --git a/loki/ir.py b/loki/ir.py index be93a09a6..885715790 100644 --- a/loki/ir.py +++ b/loki/ir.py @@ -40,8 +40,6 @@ 'Import', 'VariableDeclaration', 'ProcedureDeclaration', 'DataDeclaration', 'StatementFunction', 'TypeDef', 'MultiConditional', 'MaskedStatement', 'Intrinsic', 'Enumeration', 'RawSource', - # List of nodes with specific properties - 'DECLARATION_NODES' ] # Configuration for validation mechanism via pydantic @@ -1823,9 +1821,3 @@ class RawSource(LeafNode, _RawSourceBase): def __repr__(self): return f'RawSource:: {truncate_string(self.text.strip())}' - - -DECLARATION_NODES = (Import, VariableDeclaration, ProcedureDeclaration) -""" -List of IR nodes that are considered to be the authority on a symbol's attributes -""" diff --git a/loki/visitors/transform.py b/loki/visitors/transform.py index 3113cd6e8..45af0fe5d 100644 --- a/loki/visitors/transform.py +++ b/loki/visitors/transform.py @@ -246,8 +246,6 @@ def visit(self, o, *args, **kwargs): :any:`Node` or tuple The rebuilt control flow tree. """ - if isinstance(o, Node): - kwargs['current_node'] = o obj = super().visit(o, *args, **kwargs) if isinstance(o, Node) and obj is not o: self.rebuilt[o] = obj @@ -448,8 +446,6 @@ def visit(self, o, *args, **kwargs): # to make sure that we don't include any following nodes we clear start self.start.clear() self.active = False - if isinstance(o, Node): - kwargs['current_node'] = o return super().visit(o, *args, **kwargs) def visit_object(self, o, **kwargs): From 0ee83015774c35250f297c63bd583b2098e08fed Mon Sep 17 00:00:00 2001 From: Balthasar Reuter Date: Wed, 3 May 2023 16:13:43 +0100 Subject: [PATCH 09/10] Make control flow in map_variable_symbol cleaner --- loki/expression/mappers.py | 74 ++++++++++++++++++++------------------ 1 file changed, 40 insertions(+), 34 deletions(-) diff --git a/loki/expression/mappers.py b/loki/expression/mappers.py index fe8663955..23db7a164 100644 --- a/loki/expression/mappers.py +++ b/loki/expression/mappers.py @@ -552,57 +552,63 @@ def map_int_literal(self, expr, *args, **kwargs): map_float_literal = map_int_literal def map_variable_symbol(self, expr, *args, **kwargs): - kind = self.rec(expr.type.kind, *args, **kwargs) - if kind is not expr.type.kind and expr.scope: - # Update symbol table entry for kind directly because with a scope attached - # it does not affect the outcome of expr.clone - expr.scope.symbol_attrs[expr.name] = expr.type.clone(kind=kind) - - if kwargs['recurse_to_declaration_attributes']: - _kwargs = kwargs.copy() - _kwargs['recurse_to_declaration_attributes'] = False - if expr.scope and expr.type.initial and expr.name == expr.type.initial: + # When updating declaration attributes, which are stored in the symbol table, + # we need to disable `recurse_to_declaration_attributes` to avoid infinite + # recursion because of the various ways that Fortran allows to use the declared + # symbol also inside the declaration expression + recurse_to_declaration_attributes = kwargs['recurse_to_declaration_attributes'] or expr.scope is None + kwargs['recurse_to_declaration_attributes'] = False + + if recurse_to_declaration_attributes: + old_type = expr.type + kind = self.rec(old_type.kind, *args, **kwargs) + + if expr.scope and expr.name == old_type.initial: # FIXME: This is a hack to work around situations where a constant # symbol (from a parent scope) with the same name as the declared # variable is used as initializer. This hands down the correct scope # (in case this traversal is part of ``AttachScopesMapper``) and thus # interrupts an otherwise infinite recursion (see LOKI-52). + _kwargs = kwargs.copy() _kwargs['scope'] = expr.scope.parent - initial = self.rec(expr.type.initial, *args, **_kwargs) + initial = self.rec(old_type.initial, *args, **_kwargs) else: - initial = self.rec(expr.type.initial, *args, **_kwargs) - else: - initial = expr.type.initial - - if initial is not expr.type.initial and expr.scope: - # Update symbol table entry for initial directly because with a scope attached - # it does not affect the outcome of expr.clone - expr.scope.symbol_attrs[expr.name] = expr.type.clone(initial=initial) + initial = self.rec(old_type.initial, *args, **kwargs) - if kwargs['recurse_to_declaration_attributes']: - _kwargs = kwargs.copy() - _kwargs['recurse_to_declaration_attributes'] = False - if (old_bind_names := expr.type.bind_names): + if old_type.bind_names: bind_names = () - for bind_name in old_bind_names: + for bind_name in old_type.bind_names: if bind_name == expr.name: # FIXME: This is a hack to work around situations where an # explicit interface is used with the same name as the # type bound procedure. This hands down the correct scope. - __kwargs = _kwargs.copy() - __kwargs['scope'] = expr.scope.parent - bind_names += (self.rec(bind_name, *args, **__kwargs),) - else: + _kwargs = kwargs.copy() + _kwargs['scope'] = expr.scope.parent bind_names += (self.rec(bind_name, *args, **_kwargs),) - if bind_names and any(new is not old for new, old in zip_longest(bind_names, expr.type.bind_names)): - # Update symbol table entry for bind_names directly because with a scope attached - # it does not affect the outcome of expr.clone - expr.scope.symbol_attrs[expr.name] = expr.type.clone(bind_names=bind_names) + else: + bind_names += (self.rec(bind_name, *args, **kwargs),) + else: + bind_names = None + + is_type_changed = ( + kind is not old_type.kind or initial is not old_type.initial or + any(new is not old for new, old in zip_longest(as_tuple(bind_names), as_tuple(old_type.bind_names))) + ) + if is_type_changed: + new_type = old_type.clone(kind=kind, initial=initial, bind_names=bind_names) + if expr.scope: + # Update symbol table entry + expr.scope.symbol_attrs[expr.name] = new_type parent = self.rec(expr.parent, *args, **kwargs) - if parent is expr.parent and (kind is expr.type.kind or expr.scope): + if expr.scope is None: + if parent is expr.parent and not is_type_changed: + return expr + return expr.clone(parent=parent, type=new_type) + + if parent is expr.parent: return expr - return expr.clone(parent=parent, type=expr.type.clone(kind=kind)) + return expr.clone(parent=parent) map_deferred_type_symbol = map_variable_symbol map_procedure_symbol = map_variable_symbol From 9cef8f4aa08adf4211c31ff9c0eb0569cc611581 Mon Sep 17 00:00:00 2001 From: Balthasar Reuter Date: Wed, 3 May 2023 16:25:12 +0100 Subject: [PATCH 10/10] Eliminate another stray `current_node` argument --- loki/expression/expr_visitors.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/loki/expression/expr_visitors.py b/loki/expression/expr_visitors.py index bc2673954..fc3fcfde0 100644 --- a/loki/expression/expr_visitors.py +++ b/loki/expression/expr_visitors.py @@ -355,8 +355,6 @@ def visit(self, o, *args, **kwargs): Default visitor method that dispatches the node-specific handler """ kwargs.setdefault('scope', None) - if isinstance(o, Node): - kwargs['current_node'] = o return super().visit(o, *args, **kwargs) def visit_Expression(self, o, **kwargs):