Skip to content

Commit

Permalink
Fix misclassification as StatementFunction (#326)
Browse files Browse the repository at this point in the history
  • Loading branch information
reuterbal committed Jun 13, 2024
1 parent eb793e2 commit f5b3980
Show file tree
Hide file tree
Showing 2 changed files with 95 additions and 15 deletions.
71 changes: 71 additions & 0 deletions loki/expression/tests/test_expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -2197,3 +2197,74 @@ def static_func(a):
test_str = 'foo%bar%barbar%barbarbar%val_barbarbar + 1'
parsed = parse_expr(convert_to_case(f'{test_str}', mode=case), evaluate=True, context=context)
assert parsed == 6


@pytest.mark.parametrize('frontend', available_frontends(
skip={OMNI: "OMNI fails on missing module"}
))
def test_stmt_func_heuristic(frontend, tmp_path):
"""
Our Fparser/OFP translation has a heuristic to detect statement function declarations,
but that falsely misinterpreted some assignments as statement functions due to
missing shape information (reported in #326)
"""
fcode = """
SUBROUTINE SOME_ROUTINE(YDFIELDS,YDMODEL,YDCST)
USE FIELDS_MOD , ONLY : FIELDS
USE TYPE_MODEL , ONLY : MODEL
USE VAR_MOD , ONLY : ARR, FNAME
IMPLICIT NONE
TYPE(FIELDS) ,INTENT(INOUT) :: YDFIELDS
TYPE(MODEL) ,INTENT(IN) :: YDMODEL
TYPE(TOMCST) ,INTENT(IN) :: YDCST
CHARACTER(LEN=20) :: CLFILE
REAL :: ZALFA
REAL :: ZALFAG(3)
REAL :: FOEALFA
REAL :: PTARE
FOEALFA(PTARE) = MIN(1.0, PTARE)
#include "fcttre.func.h"
ASSOCIATE(YDSURF=>YDFIELDS%YRSURF,RTT=>YDCST%RTT)
ASSOCIATE(SD_VN=>YDSURF%SD_VN,YSD_VN=>YDSURF%YSD_VN, &
& LEGBRAD=>YDMODEL%YRML_PHY_EC%YREPHY%LEGBRAD)
IF(LEGBRAD)SD_VN(:,YSD_VN%YACCPR5%MP,:)=SD_VN(:,YSD_VN%YACCPR%MP,:)
IF(LEGBRAD)ARR(:,YSD_VN%YACCPR5%MP,:)=SD_VN(:,YSD_VN%YACCPR%MP,:)
CLFILE(1:20)=FNAME
ZALFA=FOEDELTA(RTT)
ZALFAG(1)=FOEDELTA(RTT-1.)
ZALFAG(2)=FOEALFA(RTT)
ZALFAG(3)=FOEALFA(RTT-1.)
END ASSOCIATE
END ASSOCIATE
END SUBROUTINE SOME_ROUTINE
""".strip()
source = Sourcefile.from_source(fcode, frontend=frontend, xmods=[tmp_path])
routine = source['some_routine']

assignments = FindNodes(ir.Assignment).visit(routine.body)

assert [
ass.lhs.name.lower() for ass in assignments
] == [
'sd_vn', 'arr', 'clfile', 'zalfa', 'zalfag', 'zalfag', 'zalfag'
]

sd_vn = assignments[0].lhs
assert isinstance(sd_vn, sym.Array)

arr = assignments[1].lhs
assert isinstance(arr, sym.Array)
assert arr.type.imported

# FOEDELTA cannot be identified as a statement function due to the declarations
# hidden in the external header
assert isinstance(assignments[3].rhs, sym.Array)
assert isinstance(assignments[4].rhs, sym.Array)

# FOEALFA should have been identified as a statement function
stmt_funcs = FindNodes(ir.StatementFunction).visit(routine.ir)
assert len(stmt_funcs) == 1
assert stmt_funcs[0].name.lower() == 'foealfa'
assert isinstance(assignments[5].rhs, sym.InlineCall)
assert isinstance(assignments[6].rhs, sym.InlineCall)
39 changes: 24 additions & 15 deletions loki/frontend/fparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -2983,22 +2983,31 @@ def visit_Assignment_Stmt(self, o, **kwargs):

# Special-case: Identify statement functions using our internal symbol table
symbol_attrs = kwargs['scope'].symbol_attrs
if isinstance(lhs, sym.Array) and not lhs.parent and lhs.name in symbol_attrs:

def _create_stmt_func_type(stmt_func):
name = str(stmt_func.variable)
procedure = LazyNodeLookup(
anchor=kwargs['scope'],
query=lambda x: [
f for f in FindNodes(ir.StatementFunction).visit(x.spec) if f.variable == name
][0]
)
proc_type = ProcedureType(is_function=True, procedure=procedure, name=name)
return SymbolAttributes(dtype=proc_type, is_stmt_func=True)
if isinstance(lhs, sym.Array) and symbol_attrs.lookup(lhs.name) is not None:
# If this looks like an array but we have an explicit scalar declaration then
# this might in fact be a statement function.
# To avoid the costly lookup for declarations on each array assignment, we run through
# some sanity checks instead that allow us to bail out early in most cases
lhs_type = lhs.type
could_be_a_statement_func = not (
lhs_type.shape or lhs_type.length # Declaration with length or dimensions
or lhs.parent # Derived type member (we might lack information from enrichment)
or lhs_type.intent or lhs_type.imported # Dummy argument or imported from module
or isinstance(lhs.scope, ir.Associate) # Symbol stems from an associate
)

if could_be_a_statement_func:
def _create_stmt_func_type(stmt_func):
name = str(stmt_func.variable)
procedure = LazyNodeLookup(
anchor=kwargs['scope'],
query=lambda x: [
f for f in FindNodes(ir.StatementFunction).visit(x.spec) if f.variable == name
][0]
)
proc_type = ProcedureType(is_function=True, procedure=procedure, name=name)
return SymbolAttributes(dtype=proc_type, is_stmt_func=True)

if not lhs.type.shape and not lhs.type.intent:
# If the LHS array access is actually declared as a scalar,
# we are actually dealing with a statement function!
f_symbol = sym.ProcedureSymbol(name=lhs.name, scope=kwargs['scope'])
stmt_func = ir.StatementFunction(
variable=f_symbol, arguments=lhs.dimensions,
Expand Down

0 comments on commit f5b3980

Please sign in to comment.