diff --git a/loki/frontend/fparser.py b/loki/frontend/fparser.py index 613adc568..9d85ccde7 100644 --- a/loki/frontend/fparser.py +++ b/loki/frontend/fparser.py @@ -394,6 +394,13 @@ def visit_Part_Ref(self, o, **kwargs): # This should go away once fparser has a basic symbol table, see # https://github.com/stfc/fparser/issues/201 for some details _type = kwargs['scope'].symbol_attrs.lookup(name.name) + if _type is None and (definition := self.definitions.get(name.name)): + # We don't have any type information for this, which means it has + # not been declared locally. Check the definitions for enriched + # type information: + if isinstance(dtype := definition.procedure_type, ProcedureType): + _type = name.type.clone(dtype=dtype) + name = name.clone(type=_type) if _type and isinstance(_type.dtype, ProcedureType): name = name.clone(dimensions=None) call = sym.InlineCall(name, parameters=dimensions, kw_parameters=()) diff --git a/loki/logging.py b/loki/logging.py index e81bd8784..93f4f9433 100644 --- a/loki/logging.py +++ b/loki/logging.py @@ -13,7 +13,7 @@ __all__ = ['logger', 'log_levels', 'set_log_level', 'FileLogger', - 'debug', 'info', 'warning', 'error', 'log'] + 'debug', 'detail', 'perf', 'info', 'warning', 'error', 'log'] def FileLogger(name, filename, level=None, file_level=None, fmt=None, diff --git a/loki/transformations/inline/mapper.py b/loki/transformations/inline/mapper.py index 80439fc6c..232e38ee6 100644 --- a/loki/transformations/inline/mapper.py +++ b/loki/transformations/inline/mapper.py @@ -5,10 +5,11 @@ # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. +from loki.expression import LokiIdentityMapper from loki.ir import ( FindNodes, Assignment, StatementFunction, SubstituteExpressions ) -from loki.expression import LokiIdentityMapper +from loki.logging import detail from loki.types import BasicType @@ -73,6 +74,13 @@ def map_inline_call(self, expr, *args, **kwargs): function = expr.procedure_type.procedure v_result = [v for v in function.variables if v == function.name][0] + scope = kwargs.get('scope') or expr.function.scope + if scope and function.name in scope.interface_map: + # Inline call to a function that is provided via an interface + # We don't have the function body available for inlining + detail(f'Cannot inline {expr.function.name} into {scope.name}. Only interface available.') + return super().map_inline_call(expr, *args, **kwargs) + # Substitute all arguments through the elemental body arg_map = dict(expr.arg_iter()) fbody = SubstituteExpressions(arg_map).visit(function.body) diff --git a/loki/transformations/inline/tests/test_functions.py b/loki/transformations/inline/tests/test_functions.py index 64b160407..4dc815b9b 100644 --- a/loki/transformations/inline/tests/test_functions.py +++ b/loki/transformations/inline/tests/test_functions.py @@ -13,6 +13,7 @@ from loki.ir import ( nodes as ir, FindNodes, FindVariables, FindInlineCalls ) +from loki.types import ProcedureType from loki.transformations.inline import ( inline_elemental_functions, inline_statement_functions @@ -216,7 +217,7 @@ def test_inline_statement_functions(frontend, stmt_decls): real, parameter :: rtt = 1.0 {stmt_decls_code if stmt_decls else '#include "fcttre.func.h"'} - ret = foeew(arr) + ret = foeew(arr) ret2 = foedelta(3.0) end subroutine stmt_func """.strip() @@ -239,3 +240,122 @@ def test_inline_statement_functions(frontend, stmt_decls): assert assignments[1].rhs == "3.0 + 1.0" else: assert FindInlineCalls().visit(routine.body) + +@pytest.mark.parametrize('frontend', available_frontends( + skip={OFP: "OFP apparently has problems dealing with those Statement Functions", + OMNI: "OMNI automatically inlines Statement Functions"} +)) +@pytest.mark.parametrize('provide_myfunc', ('import', 'module', 'interface', 'intfb', 'routine')) +def test_inline_statement_functions_inline_call(frontend, provide_myfunc, tmp_path): + fcode_myfunc = """ +elemental function myfunc(a) + real, intent(in) :: a + real :: myfunc + myfunc = a * 2.0 +end function myfunc + """.strip() + + if provide_myfunc == 'module': + fcode_myfunc = f""" +module my_mod +implicit none +contains +{fcode_myfunc} +end module my_mod + """.strip() + + if provide_myfunc in ('import', 'module'): + module_import = 'use my_mod, only: myfunc' + else: + module_import = '' + + if provide_myfunc == 'interface': + intf = """ + interface + elemental function myfunc(a) + implicit none + real a + real myfunc + end function myfunc + end interface + """ + elif provide_myfunc in ('intfb', 'routine'): + intf = '#include "myfunc.intfb.h"' + else: + intf = '' + + fcode = f""" +subroutine stmt_func(arr, val, ret) + {module_import} + implicit none + real, intent(in) :: arr(:) + real, intent(in) :: val + real, intent(inout) :: ret(:) + real :: ret2 + real, parameter :: rtt = 1.0 + real :: PTARE + real :: FOEDELTA + FOEDELTA ( PTARE ) = PTARE + 1.0 + MYFUNC(PTARE) + real :: FOEEW + FOEEW ( PTARE ) = PTARE + FOEDELTA(PTARE) + MYFUNC(PTARE) + {intf} + + ret = foeew(arr) + ret2 = foedelta(3.0) + foedelta(val) +end subroutine stmt_func + """.strip() + + if provide_myfunc == 'module': + definitions = (Module.from_source(fcode_myfunc, xmods=[tmp_path]),) + elif provide_myfunc == 'routine': + definitions = (Subroutine.from_source(fcode_myfunc, xmods=[tmp_path]),) + else: + definitions = None + routine = Subroutine.from_source(fcode, frontend=frontend, definitions=definitions, xmods=[tmp_path]) + + # Check the spec + statement_funcs = FindNodes(ir.StatementFunction).visit(routine.spec) + assert len(statement_funcs) == 2 + + inline_calls = FindInlineCalls(unique=False).visit(routine.spec) + if provide_myfunc in ('module', 'interface', 'routine'): + # Enough information available that MYFUNC is recognized as a procedure call + assert len(inline_calls) == 3 + assert all(isinstance(call.function.type.dtype, ProcedureType) for call in inline_calls) + else: + # No information available about MYFUNC, so fparser treats it as an ArraySubscript + assert len(inline_calls) == 1 + assert inline_calls[0].function == 'foedelta' + assert isinstance(inline_calls[0].function.type.dtype, ProcedureType) + + # Check the body + inline_calls = FindInlineCalls().visit(routine.body) + assert len(inline_calls) == 3 + + # Apply the transformation + inline_statement_functions(routine) + + # Check the outcome + assert not FindNodes(ir.StatementFunction).visit(routine.spec) + inline_calls = FindInlineCalls(unique=False).visit(routine.body) + assignments = FindNodes(ir.Assignment).visit(routine.body) + + if provide_myfunc in ('import', 'intfb'): + # MYFUNC(arr) is misclassified as array subscript + assert len(inline_calls) == 0 + elif provide_myfunc in ('module', 'routine'): + # MYFUNC(arr) is eliminated due to inlining + assert len(inline_calls) == 0 + else: + assert len(inline_calls) == 4 + + assert assignments[0].lhs == 'ret' + assert assignments[1].lhs == 'ret2' + if provide_myfunc in ('module', 'routine'): + # Fully inlined due to definition of myfunc available + assert assignments[0].rhs == "arr + arr + 1.0 + arr*2.0 + arr*2.0" + assert assignments[1].rhs == "3.0 + 1.0 + 3.0*2.0 + val + 1.0 + val*2.0" + else: + # myfunc not inlined + assert assignments[0].rhs == "arr + arr + 1.0 + myfunc(arr) + myfunc(arr)" + assert assignments[1].rhs == "3.0 + 1.0 + myfunc(3.0) + val + 1.0 + myfunc(val)"