diff --git a/loki/expression/expr_visitors.py b/loki/expression/expr_visitors.py index a3de74e35..fc3fcfde0 100644 --- a/loki/expression/expr_visitors.py +++ b/loki/expression/expr_visitors.py @@ -297,8 +297,25 @@ def __init__(self, expr_map, invalidate_source=True, **kwargs): invalidate_source=invalidate_source) def visit_Expression(self, o, **kwargs): + """ + call :any:`SubstituteExpressionsMapper` for the given expression node + """ + if kwargs.get('recurse_to_declaration_attributes'): + return self.expr_mapper(o, recurse_to_declaration_attributes=True) return self.expr_mapper(o) + def visit_Import(self, o, **kwargs): + """ + For :any:`Import` (as well as :any:`VariableDeclaration` and :any:`ProcedureDeclaration`) + we set ``recurse_to_declaration_attributes=True`` to make sure properties in the symbol + table are updated during dispatch to the expression mapper + """ + kwargs['recurse_to_declaration_attributes'] = True + return super().visit_Node(o, **kwargs) + + visit_VariableDeclaration = visit_Import + visit_ProcedureDeclaration = visit_Import + class AttachScopes(Visitor): """ @@ -344,6 +361,8 @@ def visit_Expression(self, o, **kwargs): """ Dispatch :any:`AttachScopesMapper` for :any:`Expression` tree nodes """ + if kwargs.get('recurse_to_declaration_attributes'): + return self.expr_mapper(o, scope=kwargs['scope'], recurse_to_declaration_attributes=True) return self.expr_mapper(o, scope=kwargs['scope']) def visit_list(self, o, **kwargs): @@ -363,6 +382,18 @@ def visit_Node(self, o, **kwargs): children = tuple(self.visit(i, **kwargs) for i in o.children) return self._update(o, children) + def visit_Import(self, o, **kwargs): + """ + For :any:`Import` (as well as :any:`VariableDeclaration` and :any:`ProcedureDeclaration`) + we set ``recurse_to_declaration_attributes=True`` to make sure properties in the symbol + table are updated during dispatch to the expression mapper + """ + kwargs['recurse_to_declaration_attributes'] = True + return self.visit_Node(o, **kwargs) + + visit_VariableDeclaration = visit_Import + visit_ProcedureDeclaration = visit_Import + def visit_Scope(self, o, **kwargs): """ Generic handler for :any:`Scope` objects diff --git a/loki/expression/mappers.py b/loki/expression/mappers.py index 7119a033c..23db7a164 100644 --- a/loki/expression/mappers.py +++ b/loki/expression/mappers.py @@ -520,6 +520,7 @@ def __init__(self, invalidate_source=True): def __call__(self, expr, *args, **kwargs): if expr is None: return None + kwargs.setdefault('recurse_to_declaration_attributes', False) new_expr = super().__call__(expr, *args, **kwargs) if getattr(expr, 'source', None): if isinstance(new_expr, tuple): @@ -551,38 +552,63 @@ def map_int_literal(self, expr, *args, **kwargs): map_float_literal = map_int_literal def map_variable_symbol(self, expr, *args, **kwargs): - kind = self.rec(expr.type.kind, *args, **kwargs) - if kind is not expr.type.kind and expr.scope: - # Update symbol table entry for kind directly because with a scope attached - # it does not affect the outcome of expr.clone - expr.scope.symbol_attrs[expr.name] = expr.type.clone(kind=kind) - - if expr.scope and expr.type.initial and expr.name == expr.type.initial: - # FIXME: This is a hack to work around situations where a constant - # symbol (from a parent scope) with the same name as the declared - # variable is used as initializer. This hands down the correct scope - # (in case this traversal is part of ``AttachScopesMapper``) and thus - # interrupts an otherwise infinite recursion (see LOKI-52). - _kwargs = kwargs.copy() - _kwargs['scope'] = expr.scope.parent - initial = self.rec(expr.type.initial, *args, **_kwargs) - else: - initial = self.rec(expr.type.initial, *args, **kwargs) - if initial is not expr.type.initial and expr.scope: - # Update symbol table entry for initial directly because with a scope attached - # it does not affect the outcome of expr.clone - expr.scope.symbol_attrs[expr.name] = expr.type.clone(initial=initial) - - bind_names = self.rec(expr.type.bind_names, *args, **kwargs) - if not (bind_names is None or all(new is old for new, old in zip_longest(bind_names, expr.type.bind_names))): - # Update symbol table entry for bind_names directly because with a scope attached - # it does not affect the outcome of expr.clone - expr.scope.symbol_attrs[expr.name] = expr.type.clone(bind_names=as_tuple(bind_names)) + # When updating declaration attributes, which are stored in the symbol table, + # we need to disable `recurse_to_declaration_attributes` to avoid infinite + # recursion because of the various ways that Fortran allows to use the declared + # symbol also inside the declaration expression + recurse_to_declaration_attributes = kwargs['recurse_to_declaration_attributes'] or expr.scope is None + kwargs['recurse_to_declaration_attributes'] = False + + if recurse_to_declaration_attributes: + old_type = expr.type + kind = self.rec(old_type.kind, *args, **kwargs) + + if expr.scope and expr.name == old_type.initial: + # FIXME: This is a hack to work around situations where a constant + # symbol (from a parent scope) with the same name as the declared + # variable is used as initializer. This hands down the correct scope + # (in case this traversal is part of ``AttachScopesMapper``) and thus + # interrupts an otherwise infinite recursion (see LOKI-52). + _kwargs = kwargs.copy() + _kwargs['scope'] = expr.scope.parent + initial = self.rec(old_type.initial, *args, **_kwargs) + else: + initial = self.rec(old_type.initial, *args, **kwargs) + + if old_type.bind_names: + bind_names = () + for bind_name in old_type.bind_names: + if bind_name == expr.name: + # FIXME: This is a hack to work around situations where an + # explicit interface is used with the same name as the + # type bound procedure. This hands down the correct scope. + _kwargs = kwargs.copy() + _kwargs['scope'] = expr.scope.parent + bind_names += (self.rec(bind_name, *args, **_kwargs),) + else: + bind_names += (self.rec(bind_name, *args, **kwargs),) + else: + bind_names = None + + is_type_changed = ( + kind is not old_type.kind or initial is not old_type.initial or + any(new is not old for new, old in zip_longest(as_tuple(bind_names), as_tuple(old_type.bind_names))) + ) + if is_type_changed: + new_type = old_type.clone(kind=kind, initial=initial, bind_names=bind_names) + if expr.scope: + # Update symbol table entry + expr.scope.symbol_attrs[expr.name] = new_type parent = self.rec(expr.parent, *args, **kwargs) - if parent is expr.parent and (kind is expr.type.kind or expr.scope): + if expr.scope is None: + if parent is expr.parent and not is_type_changed: + return expr + return expr.clone(parent=parent, type=new_type) + + if parent is expr.parent: return expr - return expr.clone(parent=parent, type=expr.type.clone(kind=kind)) + return expr.clone(parent=parent) map_deferred_type_symbol = map_variable_symbol map_procedure_symbol = map_variable_symbol @@ -611,7 +637,13 @@ def map_array(self, expr, *args, **kwargs): # and make sure we don't loose the call parameters (aka dimensions) return InlineCall(function=symbol.clone(parent=parent), parameters=dimensions) - shape = self.rec(symbol.type.shape, *args, **kwargs) + if kwargs['recurse_to_declaration_attributes']: + _kwargs = kwargs.copy() + _kwargs['recurse_to_declaration_attributes'] = False + shape = self.rec(symbol.type.shape, *args, **_kwargs) + else: + shape = symbol.type.shape + if (getattr(symbol, 'symbol', symbol) is expr.symbol and all(d is orig_d for d, orig_d in zip_longest(dimensions or (), expr.dimensions or ())) and all(d is orig_d for d, orig_d in zip_longest(shape or (), symbol.type.shape or ()))): diff --git a/loki/frontend/fparser.py b/loki/frontend/fparser.py index a7ea09121..42c8a42ac 100644 --- a/loki/frontend/fparser.py +++ b/loki/frontend/fparser.py @@ -1782,7 +1782,8 @@ def visit_Subroutine_Subprogram(self, o, **kwargs): if return_type is not None: routine.symbol_attrs[routine.name] = return_type return_var = sym.Variable(name=routine.name, scope=routine) - return_var_decl = ir.VariableDeclaration(symbols=(return_var,)) + decl_source = self.get_source(subroutine_stmt, source=None) + return_var_decl = ir.VariableDeclaration(symbols=(return_var,), source=decl_source) decls = FindNodes((ir.VariableDeclaration, ir.ProcedureDeclaration)).visit(spec) if not decls: @@ -2099,8 +2100,17 @@ def visit_If_Construct(self, o, **kwargs): else_if_stmt_index, else_if_stmts = zip(*else_if_stmts) else: else_if_stmt_index = () - else_stmt = get_child(o, Fortran2003.Else_Stmt) - else_stmt_index = o.children.index(else_stmt) if else_stmt else end_if_stmt_index + + # Note: we need to use here the same method as for else-if because finding Else_Stmt + # directly and checking its position via o.children.index may give the wrong result. + # This is because Else_Stmt may erronously compare equal to other node types. + # See https://github.com/stfc/fparser/issues/400 + else_stmt = tuple((i, c) for i, c in enumerate(o.children) if isinstance(c, Fortran2003.Else_Stmt)) + if else_stmt: + assert len(else_stmt) == 1 + else_stmt_index, else_stmt = else_stmt[0] + else: + else_stmt_index = end_if_stmt_index conditions = as_tuple(self.visit(c, **kwargs) for c in (if_then_stmt,) + else_if_stmts) bodies = tuple( tuple(flatten(as_tuple(self.visit(c, **kwargs) for c in o.children[start+1:stop]))) diff --git a/loki/frontend/ofp.py b/loki/frontend/ofp.py index 742b95b16..25d5aa39c 100644 --- a/loki/frontend/ofp.py +++ b/loki/frontend/ofp.py @@ -609,7 +609,7 @@ def visit_specific_binding(self, o, **kwargs): type=SymbolAttributes(ProcedureType(interface)) ) - _type = interface.type + _type = interface.type.clone(bind_names=(interface,)) elif o.attrib['procedureName']: # Binding provided ( => ) diff --git a/loki/frontend/util.py b/loki/frontend/util.py index 603e35c58..76c91ecd4 100644 --- a/loki/frontend/util.py +++ b/loki/frontend/util.py @@ -205,7 +205,7 @@ def inject_statement_functions(routine): def create_stmt_func(assignment): arguments = assignment.lhs.dimensions variable = assignment.lhs.clone(dimensions=None) - return StatementFunction(variable, arguments, assignment.rhs, variable.type) + return StatementFunction(variable, arguments, assignment.rhs, variable.type, source=assignment.source) def create_type(stmt_func): name = str(stmt_func.variable) @@ -253,7 +253,9 @@ def create_type(stmt_func): if variable.name.lower() in stmt_funcs: if isinstance(variable, Array): parameters = variable.dimensions - expr_map_spec[variable] = InlineCall(variable.clone(dimensions=None), parameters=parameters) + expr_map_spec[variable] = InlineCall( + variable.clone(dimensions=None), parameters=parameters, source=variable.source + ) elif not isinstance(variable, ProcedureSymbol): expr_map_spec[variable] = variable.clone() expr_map_body = {} @@ -261,7 +263,9 @@ def create_type(stmt_func): if variable.name.lower() in stmt_funcs: if isinstance(variable, Array): parameters = variable.dimensions - expr_map_body[variable] = InlineCall(variable.clone(dimensions=None), parameters=parameters) + expr_map_body[variable] = InlineCall( + variable.clone(dimensions=None), parameters=parameters, source=variable.source + ) elif not isinstance(variable, ProcedureSymbol): expr_map_body[variable] = variable.clone() @@ -271,15 +275,15 @@ def create_type(stmt_func): # Apply transformer with the built maps if spec_map: - routine.spec = Transformer(spec_map).visit(routine.spec) + routine.spec = Transformer(spec_map, invalidate_source=False).visit(routine.spec) if body_map: - routine.body = Transformer(body_map).visit(routine.body) + routine.body = Transformer(body_map, invalidate_source=False).visit(routine.body) if spec_appendix: routine.spec.append(spec_appendix) if expr_map_spec: - routine.spec = SubstituteExpressions(expr_map_spec).visit(routine.spec) + routine.spec = SubstituteExpressions(expr_map_spec, invalidate_source=False).visit(routine.spec) if expr_map_body: - routine.body = SubstituteExpressions(expr_map_body).visit(routine.body) + routine.body = SubstituteExpressions(expr_map_body, invalidate_source=False).visit(routine.body) # And make sure all symbols have the right type routine.rescope_symbols() diff --git a/loki/ir.py b/loki/ir.py index d1680a2cc..885715790 100644 --- a/loki/ir.py +++ b/loki/ir.py @@ -39,7 +39,7 @@ 'Comment', 'CommentBlock', 'Pragma', 'PreprocessorDirective', 'Import', 'VariableDeclaration', 'ProcedureDeclaration', 'DataDeclaration', 'StatementFunction', 'TypeDef', 'MultiConditional', 'MaskedStatement', - 'Intrinsic', 'Enumeration', 'RawSource' + 'Intrinsic', 'Enumeration', 'RawSource', ] # Configuration for validation mechanism via pydantic diff --git a/loki/transform/transform_associates.py b/loki/transform/transform_associates.py index 7d1d17f6d..706be7b34 100644 --- a/loki/transform/transform_associates.py +++ b/loki/transform/transform_associates.py @@ -40,7 +40,7 @@ class ResolveAssociatesTransformer(Transformer): corresponding `selector` expression defined in ``associations``. """ - def visit_Associate(self, o): + def visit_Associate(self, o, **kwargs): # pylint: disable=unused-argument # First head-recurse, so that all associate blocks beneath are resolved body = self.visit(o.body) diff --git a/loki/visitors/transform.py b/loki/visitors/transform.py index f05510c6d..45af0fe5d 100644 --- a/loki/visitors/transform.py +++ b/loki/visitors/transform.py @@ -503,7 +503,6 @@ class NestedMaskedTransformer(MaskedTransformer): # Handler for leaf nodes - def visit_object(self, o, **kwargs): """ Return the object unchanged. diff --git a/tests/test_control_flow.py b/tests/test_control_flow.py index dbccb6b30..5b558f93b 100644 --- a/tests/test_control_flow.py +++ b/tests/test_control_flow.py @@ -10,7 +10,7 @@ import numpy as np from conftest import jit_compile, clean_test, available_frontends -from loki import OMNI, Subroutine, FindNodes, Loop, Conditional, Node +from loki import OMNI, Subroutine, FindNodes, Loop, Conditional, Node, Intrinsic @pytest.fixture(scope='module', name='here') @@ -455,3 +455,45 @@ def test_conditional_bodies(frontend): c.else_body and isinstance(c.else_body, tuple) and all(isinstance(n, Node) for n in c.else_body) for c in conditionals ) + + +@pytest.mark.parametrize('frontend', available_frontends()) +def test_conditional_else_body_return(frontend): + fcode = """ +FUNCTION FUNC(PX,KN) +IMPLICIT NONE +INTEGER,INTENT(INOUT) :: KN +REAL,INTENT(IN) :: PX +REAL :: FUNC +INTEGER :: J +REAL :: Z0, Z1, Z2 +Z0= 1.0 +Z1= PX +IF (KN == 0) THEN + FUNC= Z0 + RETURN +ELSEIF (KN == 1) THEN + FUNC= Z1 + RETURN +ELSE + DO J=2,KN + Z2= Z0+Z1 + Z0= Z1 + Z1= Z2 + ENDDO + FUNC= Z2 + RETURN +ENDIF +END FUNCTION FUNC + """.strip() + + routine = Subroutine.from_source(fcode, frontend=frontend) + conditionals = FindNodes(Conditional).visit(routine.body) + assert len(conditionals) == 2 + assert isinstance(conditionals[0].body[-1], Intrinsic) + assert conditionals[0].body[-1].text.upper() == 'RETURN' + assert conditionals[0].else_body == (conditionals[1],) + assert isinstance(conditionals[1].body[-1], Intrinsic) + assert conditionals[1].body[-1].text.upper() == 'RETURN' + assert isinstance(conditionals[1].else_body[-1], Intrinsic) + assert conditionals[1].else_body[-1].text.upper() == 'RETURN' diff --git a/tests/test_derived_types.py b/tests/test_derived_types.py index e26beb10d..87b8da3a0 100644 --- a/tests/test_derived_types.py +++ b/tests/test_derived_types.py @@ -1359,3 +1359,42 @@ def test_derived_types_nested_type(frontend): assert assignment.rhs.parent.type.dtype.typedef is module['some_type'] assert assignment.rhs.parent.parent.type.dtype.name == 'other_type' assert assignment.rhs.parent.parent.type.dtype.typedef is module['other_type'] + + +@pytest.mark.parametrize('frontend', available_frontends()) +def test_derived_types_abstract_deferred_procedure(frontend): + fcode = """ +module some_mod + implicit none + type, abstract :: abstract_type + contains + procedure (some_proc), deferred :: some_proc + procedure (other_proc), deferred :: other_proc + end type abstract_type + + abstract interface + subroutine some_proc(this) + import abstract_type + class(abstract_type), intent(in) :: this + end subroutine some_proc + end interface + + abstract interface + subroutine other_proc(this) + import abstract_type + class(abstract_type), intent(inout) :: this + end subroutine other_proc + end interface +end module some_mod + """.strip() + module = Module.from_source(fcode, frontend=frontend) + typedef = module['abstract_type'] + assert typedef.abstract is True + assert typedef.variables == ('some_proc', 'other_proc') + for symbol in typedef.variables: + assert isinstance(symbol, ProcedureSymbol) + assert isinstance(symbol.type.dtype, ProcedureType) + assert symbol.type.dtype.name.lower() == symbol.name.lower() + assert symbol.type.bind_names == (symbol,) + assert symbol.scope is typedef + assert symbol.type.bind_names[0].scope is module diff --git a/tests/test_expression.py b/tests/test_expression.py index 1d9d023a8..2361b475b 100644 --- a/tests/test_expression.py +++ b/tests/test_expression.py @@ -1289,3 +1289,74 @@ def test_recursive_substitution(frontend): routine.body = SubstituteExpressions(expr_map).visit(routine.body) assignment = FindNodes(Assignment).visit(routine.body)[0] assert assignment.lhs == 'var(j + 1)' + + +@pytest.mark.parametrize('frontend', available_frontends()) +def test_variable_in_declaration_initializer(frontend): + """ + Check correct handling of cases where the variable appears + in the initializer expression (i.e. no infinite recursion) + """ + fcode = """ +subroutine some_routine(var) +implicit none +INTEGER, PARAMETER :: JPRB = SELECTED_REAL_KIND(13,300) +REAL(KIND=JPRB), PARAMETER :: ZEXPLIMIT = LOG(HUGE(ZEXPLIMIT)) +real(kind=jprb), intent(inout) :: var +var = var + ZEXPLIMIT +end subroutine some_routine + """.strip() + + def _check(routine_): + # A few sanity checks + assert 'zexplimit' in routine_.variable_map + zexplimit = routine_.variable_map['zexplimit'] + assert zexplimit.scope is routine_ + # Now let's take a closer look at the initializer expression + assert 'zexplimit' in str(zexplimit.type.initial).lower() + variables = FindVariables().visit(zexplimit.type.initial) + assert 'zexplimit' in variables + assert variables[variables.index('zexplimit')].scope is routine_ + + routine = Subroutine.from_source(fcode, frontend=frontend) + _check(routine) + # Make sure that's still true when doing another scope attachment + routine.rescope_symbols() + _check(routine) + + +@pytest.mark.parametrize('frontend', available_frontends()) +def test_variable_in_dimensions(frontend): + """ + Check correct handling of cases where the variable appears in the + dimensions expression of the same variable (i.e. do not cause + infinite recursion) + """ + fcode = """ +module some_mod + implicit none + + type multi_level + real, allocatable :: data(:, :) + end type multi_level +contains + subroutine some_routine(levels, num_levels) + type(multi_level), intent(inout) :: levels(:) + integer, intent(in) :: num_levels + integer jscale + + do jscale = 2,num_levels + allocate(levels(jscale)%data(size(levels(jscale-1)%data,1), size(levels(jscale-1)%data,2))) + end do + end subroutine some_routine +end module some_mod + """.strip() + + module = Module.from_source(fcode, frontend=frontend) + routine = module['some_routine'] + assert 'levels%data' in routine.symbol_attrs + shape = routine.symbol_attrs['levels%data'].shape + assert len(shape) == 2 + for i, dim in enumerate(shape): + assert isinstance(dim, symbols.InlineCall) + assert str(dim).lower() == f'size(levels(jscale - 1)%data, {i+1})' diff --git a/tests/test_sourcefile.py b/tests/test_sourcefile.py index 4c4766a97..64714fc19 100644 --- a/tests/test_sourcefile.py +++ b/tests/test_sourcefile.py @@ -216,6 +216,7 @@ def test_sourcefile_cpp_stmt_func(here, frontend): assert isinstance(var, ProcedureSymbol) assert isinstance(var.type.dtype, ProcedureType) assert var.type.dtype.procedure is decl + assert decl.source is not None # Generate code and compile filepath = here/f'{module.name}.f90' diff --git a/tests/test_subroutine.py b/tests/test_subroutine.py index 0c5dd5c74..fb08c049d 100644 --- a/tests/test_subroutine.py +++ b/tests/test_subroutine.py @@ -16,7 +16,8 @@ Section, CallStatement, BasicType, Array, Scalar, Variable, SymbolAttributes, StringLiteral, fgen, fexprgen, VariableDeclaration, Transformer, FindTypedSymbols, ProcedureSymbol, ProcedureType, - StatementFunction, normalize_range_indexing, DeferredTypeSymbol + StatementFunction, normalize_range_indexing, DeferredTypeSymbol, + Assignment ) @@ -1504,6 +1505,10 @@ def test_subroutine_stmt_func(here, frontend): routine = Subroutine.from_source(fcode, frontend=frontend) routine.name += f'_{frontend!s}' + # Make sure the statement function injection doesn't invalidate source + for assignment in FindNodes(Assignment).visit(routine.body): + assert assignment.source is not None + # OMNI inlines statement functions, so we can only check correct representation # for fparser if frontend != OMNI: @@ -1515,6 +1520,7 @@ def test_subroutine_stmt_func(here, frontend): assert isinstance(var, ProcedureSymbol) assert isinstance(var.type.dtype, ProcedureType) assert var.type.dtype.procedure is stmt_func_decls[var] + assert stmt_func_decls[var].source is not None # Make sure this produces the correct result filepath = here/f'{routine.name}.f90'