diff --git a/loki/expression/tests/test_expression.py b/loki/expression/tests/test_expression.py index 1c16f651f..b55cd935a 100644 --- a/loki/expression/tests/test_expression.py +++ b/loki/expression/tests/test_expression.py @@ -2197,3 +2197,74 @@ def static_func(a): test_str = 'foo%bar%barbar%barbarbar%val_barbarbar + 1' parsed = parse_expr(convert_to_case(f'{test_str}', mode=case), evaluate=True, context=context) assert parsed == 6 + + +@pytest.mark.parametrize('frontend', available_frontends( + skip={OMNI: "OMNI fails on missing module"} +)) +def test_stmt_func_heuristic(frontend, tmp_path): + """ + Our Fparser/OFP translation has a heuristic to detect statement function declarations, + but that falsely misinterpreted some assignments as statement functions due to + missing shape information (reported in #326) + """ + fcode = """ +SUBROUTINE SOME_ROUTINE(YDFIELDS,YDMODEL,YDCST) +USE FIELDS_MOD , ONLY : FIELDS +USE TYPE_MODEL , ONLY : MODEL +USE VAR_MOD , ONLY : ARR, FNAME +IMPLICIT NONE +TYPE(FIELDS) ,INTENT(INOUT) :: YDFIELDS +TYPE(MODEL) ,INTENT(IN) :: YDMODEL +TYPE(TOMCST) ,INTENT(IN) :: YDCST +CHARACTER(LEN=20) :: CLFILE +REAL :: ZALFA +REAL :: ZALFAG(3) +REAL :: FOEALFA +REAL :: PTARE +FOEALFA(PTARE) = MIN(1.0, PTARE) +#include "fcttre.func.h" + +ASSOCIATE(YDSURF=>YDFIELDS%YRSURF,RTT=>YDCST%RTT) +ASSOCIATE(SD_VN=>YDSURF%SD_VN,YSD_VN=>YDSURF%YSD_VN, & + & LEGBRAD=>YDMODEL%YRML_PHY_EC%YREPHY%LEGBRAD) +IF(LEGBRAD)SD_VN(:,YSD_VN%YACCPR5%MP,:)=SD_VN(:,YSD_VN%YACCPR%MP,:) +IF(LEGBRAD)ARR(:,YSD_VN%YACCPR5%MP,:)=SD_VN(:,YSD_VN%YACCPR%MP,:) +CLFILE(1:20)=FNAME +ZALFA=FOEDELTA(RTT) +ZALFAG(1)=FOEDELTA(RTT-1.) +ZALFAG(2)=FOEALFA(RTT) +ZALFAG(3)=FOEALFA(RTT-1.) +END ASSOCIATE +END ASSOCIATE +END SUBROUTINE SOME_ROUTINE + """.strip() + source = Sourcefile.from_source(fcode, frontend=frontend, xmods=[tmp_path]) + routine = source['some_routine'] + + assignments = FindNodes(ir.Assignment).visit(routine.body) + + assert [ + ass.lhs.name.lower() for ass in assignments + ] == [ + 'sd_vn', 'arr', 'clfile', 'zalfa', 'zalfag', 'zalfag', 'zalfag' + ] + + sd_vn = assignments[0].lhs + assert isinstance(sd_vn, sym.Array) + + arr = assignments[1].lhs + assert isinstance(arr, sym.Array) + assert arr.type.imported + + # FOEDELTA cannot be identified as a statement function due to the declarations + # hidden in the external header + assert isinstance(assignments[3].rhs, sym.Array) + assert isinstance(assignments[4].rhs, sym.Array) + + # FOEALFA should have been identified as a statement function + stmt_funcs = FindNodes(ir.StatementFunction).visit(routine.ir) + assert len(stmt_funcs) == 1 + assert stmt_funcs[0].name.lower() == 'foealfa' + assert isinstance(assignments[5].rhs, sym.InlineCall) + assert isinstance(assignments[6].rhs, sym.InlineCall) diff --git a/loki/frontend/fparser.py b/loki/frontend/fparser.py index 65729ed00..57a0f2b02 100644 --- a/loki/frontend/fparser.py +++ b/loki/frontend/fparser.py @@ -2983,22 +2983,31 @@ def visit_Assignment_Stmt(self, o, **kwargs): # Special-case: Identify statement functions using our internal symbol table symbol_attrs = kwargs['scope'].symbol_attrs - if isinstance(lhs, sym.Array) and not lhs.parent and lhs.name in symbol_attrs: - - def _create_stmt_func_type(stmt_func): - name = str(stmt_func.variable) - procedure = LazyNodeLookup( - anchor=kwargs['scope'], - query=lambda x: [ - f for f in FindNodes(ir.StatementFunction).visit(x.spec) if f.variable == name - ][0] - ) - proc_type = ProcedureType(is_function=True, procedure=procedure, name=name) - return SymbolAttributes(dtype=proc_type, is_stmt_func=True) + if isinstance(lhs, sym.Array) and symbol_attrs.lookup(lhs.name) is not None: + # If this looks like an array but we have an explicit scalar declaration then + # this might in fact be a statement function. + # To avoid the costly lookup for declarations on each array assignment, we run through + # some sanity checks instead that allow us to bail out early in most cases + lhs_type = lhs.type + could_be_a_statement_func = not ( + lhs_type.shape or lhs_type.length # Declaration with length or dimensions + or lhs.parent # Derived type member (we might lack information from enrichment) + or lhs_type.intent or lhs_type.imported # Dummy argument or imported from module + or isinstance(lhs.scope, ir.Associate) # Symbol stems from an associate + ) + + if could_be_a_statement_func: + def _create_stmt_func_type(stmt_func): + name = str(stmt_func.variable) + procedure = LazyNodeLookup( + anchor=kwargs['scope'], + query=lambda x: [ + f for f in FindNodes(ir.StatementFunction).visit(x.spec) if f.variable == name + ][0] + ) + proc_type = ProcedureType(is_function=True, procedure=procedure, name=name) + return SymbolAttributes(dtype=proc_type, is_stmt_func=True) - if not lhs.type.shape and not lhs.type.intent: - # If the LHS array access is actually declared as a scalar, - # we are actually dealing with a statement function! f_symbol = sym.ProcedureSymbol(name=lhs.name, scope=kwargs['scope']) stmt_func = ir.StatementFunction( variable=f_symbol, arguments=lhs.dimensions,