diff --git a/loki/transformations/inline.py b/loki/transformations/inline.py index 196095398..d960fddab 100644 --- a/loki/transformations/inline.py +++ b/loki/transformations/inline.py @@ -15,7 +15,8 @@ from loki.batch import Transformation from loki.ir import ( Import, Comment, Assignment, VariableDeclaration, CallStatement, - Transformer, FindNodes, pragmas_attached, is_loki_pragma, Interface + Transformer, FindNodes, pragmas_attached, is_loki_pragma, Interface, + StatementFunction ) from loki.expression import ( symbols as sym, FindVariables, FindInlineCalls, FindLiterals, @@ -36,7 +37,8 @@ __all__ = [ 'inline_constant_parameters', 'inline_elemental_functions', 'inline_internal_procedures', 'inline_member_procedures', - 'inline_marked_subroutines', 'InlineTransformation' + 'inline_marked_subroutines', 'InlineTransformation', + 'inline_statement_functions' ] @@ -54,6 +56,10 @@ class InlineTransformation(Transformation): Replaces :any:`InlineCall` expression to elemental functions with the called function's body (see :any:`inline_elemental_functions`); default: True. + inline_stmt_funcs: bool + Replaces :any:`InlineCall` expression to statement functions + with the corresponding rhs of the statement function if + the statement function declaration is available; default: False. inline_internals : bool Inline internal procedure (see :any:`inline_internal_procedures`); default: False. @@ -84,13 +90,14 @@ class InlineTransformation(Transformation): def __init__( self, inline_constants=False, inline_elementals=True, - inline_internals=False, inline_marked=True, - remove_dead_code=True, allowed_aliases=None, - adjust_imports=True, external_only=True, - resolve_sequence_association=False + inline_stmt_funcs=False, inline_internals=False, + inline_marked=True, remove_dead_code=True, + allowed_aliases=None, adjust_imports=True, + external_only=True, resolve_sequence_association=False ): self.inline_constants = inline_constants self.inline_elementals = inline_elementals + self.inline_stmt_funcs = inline_stmt_funcs self.inline_internals = inline_internals self.inline_marked = inline_marked self.remove_dead_code = remove_dead_code @@ -121,6 +128,10 @@ def transform_subroutine(self, routine, **kwargs): if self.inline_elementals: inline_elemental_functions(routine) + # Inline Statement Functions + if self.inline_stmt_funcs: + inline_statement_functions(routine) + # Inline internal (contained) procedures if self.inline_internals: inline_internal_procedures(routine, allowed_aliases=self.allowed_aliases) @@ -184,11 +195,19 @@ def map_inline_call(self, expr, *args, **kwargs): # We still need to recurse and ensure re-scoping return super().map_inline_call(expr, *args, **kwargs) + # if it is an inline call to a Statement Function + if isinstance(expr.routine, StatementFunction): + function = expr.routine + # Substitute all arguments through the elemental body + arg_map = dict(expr.arg_iter()) + fbody = SubstituteExpressions(arg_map).visit(function.rhs) + return fbody + function = expr.procedure_type.procedure v_result = [v for v in function.variables if v == function.name][0] # Substitute all arguments through the elemental body - arg_map = dict(zip(function.arguments, expr.parameters)) + arg_map = dict(expr.arg_iter()) fbody = SubstituteExpressions(arg_map).visit(function.body) # Extract the RHS of the final result variable assignment @@ -338,6 +357,40 @@ def inline_elemental_functions(routine): import_map[im] = None routine.spec = Transformer(import_map).visit(routine.spec) +def inline_statement_functions(routine): + """ + Replaces :any:`InlineCall` expression to statement functions with the + called statement functions rhs. + """ + # Keep track of removed symbols + removed_functions = set() + + stmt_func_decls = FindNodes(StatementFunction).visit(routine.spec) + exprmap = {} + for call in FindInlineCalls().visit(routine.body): + proc_type = call.procedure_type + if proc_type is BasicType.DEFERRED: + continue + if proc_type.is_function and isinstance(call.routine, StatementFunction): + exprmap[call] = InlineSubstitutionMapper()(call, scope=routine) + removed_functions.add(call.routine) + # Apply the map to itself to handle nested statement function calls + exprmap = recursive_expression_map_update(exprmap, max_iterations=10, mapper_cls=InlineSubstitutionMapper) + # Apply expression-level substitution to routine + routine.body = SubstituteExpressions(exprmap).visit(routine.body) + + # remove statement function declarations as well as statement function argument(s) declarations + vars_to_remove = {stmt_func.variable.name.lower() for stmt_func in stmt_func_decls} + vars_to_remove |= {arg.name.lower() for stmt_func in stmt_func_decls for arg in stmt_func.arguments} + spec_map = {stmt_func: None for stmt_func in stmt_func_decls} + for decl in routine.declarations: + if any(var in vars_to_remove for var in decl.symbols): + symbols = tuple(var for var in decl.symbols if var not in vars_to_remove) + if symbols: + decl._update(symbols=symbols) + else: + spec_map[decl] = None + routine.spec = Transformer(spec_map).visit(routine.spec) def map_call_to_procedure_body(call, caller): """ diff --git a/loki/transformations/tests/test_inline.py b/loki/transformations/tests/test_inline.py index 3e9f29ee9..99d2ea2c2 100644 --- a/loki/transformations/tests/test_inline.py +++ b/loki/transformations/tests/test_inline.py @@ -21,11 +21,13 @@ from loki.transformations.inline import ( inline_elemental_functions, inline_constant_parameters, inline_member_procedures, inline_marked_subroutines, - InlineTransformation, + inline_statement_functions, InlineTransformation, ) from loki.transformations.sanitise import ResolveAssociatesTransformer from loki.transformations.utilities import replace_selected_kind +# pylint: disable=too-many-lines + @pytest.fixture(name='builder') def fixture_builder(tmp_path): @@ -700,6 +702,53 @@ def test_inline_member_routines_with_associate(frontend): assert len(assocs) == 2 +@pytest.mark.parametrize('frontend', available_frontends( + skip={OFP: "OFP apparently has problems dealing with those Statement Functions", + OMNI: "OMNI automatically inlines Statement Functions"} +)) +@pytest.mark.parametrize('stmt_decls', (True, False)) +def test_inline_statement_functions(frontend, stmt_decls): + stmt_decls_code = """ + real :: PTARE + real :: FOEDELTA + FOEDELTA ( PTARE ) = PTARE + 1.0 + real :: FOEEW + FOEEW ( PTARE ) = PTARE + FOEDELTA(PTARE) + """.strip() + + fcode = f""" +subroutine stmt_func(arr, ret) + implicit none + real, intent(in) :: arr(:) + real, intent(inout) :: ret(:) + real :: ret2 + real, parameter :: rtt = 1.0 + {stmt_decls_code if stmt_decls else '#include "fcttre.func.h"'} + + ret = foeew(arr) + ret2 = foedelta(3.0) +end subroutine stmt_func + """.strip() + + routine = Subroutine.from_source(fcode, frontend=frontend) + if stmt_decls: + assert FindNodes(ir.StatementFunction).visit(routine.spec) + else: + assert not FindNodes(ir.StatementFunction).visit(routine.spec) + assert FindInlineCalls().visit(routine.body) + inline_statement_functions(routine) + + assert not FindNodes(ir.StatementFunction).visit(routine.spec) + if stmt_decls: + assert not FindInlineCalls().visit(routine.body) + assignments = FindNodes(ir.Assignment).visit(routine.body) + assert assignments[0].lhs == 'ret' + assert assignments[0].rhs == "arr + arr + 1.0" + assert assignments[1].lhs == 'ret2' + assert assignments[1].rhs == "3.0 + 1.0" + else: + assert FindInlineCalls().visit(routine.body) + @pytest.mark.parametrize('frontend', available_frontends()) @pytest.mark.parametrize('adjust_imports', [True, False]) def test_inline_marked_subroutines(frontend, adjust_imports, tmp_path): @@ -1095,7 +1144,8 @@ def test_inline_marked_subroutines_declarations(frontend, tmp_path): @pytest.mark.parametrize('frontend', available_frontends( (OFP, 'Prefix/elemental support not implemented')) ) -def test_inline_transformation(frontend, tmp_path): +@pytest.mark.parametrize('pass_as_kwarg', (False, True)) +def test_inline_transformation(tmp_path, frontend, pass_as_kwarg): """Test combining recursive inlining via :any:`InliningTransformation`.""" fcode_module = """ @@ -1125,25 +1175,30 @@ def test_inline_transformation(frontend, tmp_path): end subroutine add_one_and_two """ - fcode = """ + fcode = f""" subroutine test_inline_pragma(a, b) implicit none real(kind=8), intent(inout) :: a(3), b(3) integer, parameter :: n = 3 integer :: i + real :: stmt_arg + real :: some_stmt_func + some_stmt_func ( stmt_arg ) = stmt_arg + 3.1415 #include "add_one_and_two.intfb.h" do i=1, n !$loki inline - call add_one_and_two(a(i)) + call add_one_and_two({'a=' if pass_as_kwarg else ''}a(i)) end do do i=1, n !$loki inline - call add_one_and_two(b(i)) + call add_one_and_two({'a=' if pass_as_kwarg else ''}b(i)) end do + a(1) = some_stmt_func({'stmt_arg=' if pass_as_kwarg else ''}a(2)) + end subroutine test_inline_pragma """ module = Module.from_source(fcode_module, frontend=frontend, xmods=[tmp_path]) @@ -1152,7 +1207,8 @@ def test_inline_transformation(frontend, tmp_path): routine.enrich(inner) trafo = InlineTransformation( - inline_constants=True, external_only=True, inline_elementals=True + inline_constants=True, external_only=True, inline_elementals=True, + inline_stmt_funcs=True ) calls = FindNodes(ir.CallStatement).visit(routine.body) @@ -1173,11 +1229,12 @@ def test_inline_transformation(frontend, tmp_path): calls = FindNodes(ir.CallStatement).visit(routine.body) assert len(calls) == 0 assigns = FindNodes(ir.Assignment).visit(routine.body) - assert len(assigns) == 4 + assert len(assigns) == 5 assert assigns[0].lhs == 'a(i)' and assigns[0].rhs == 'a(i) + 1.0' assert assigns[1].lhs == 'a(i)' and assigns[1].rhs == 'a(i) + 2.0' assert assigns[2].lhs == 'b(i)' and assigns[2].rhs == 'b(i) + 1.0' assert assigns[3].lhs == 'b(i)' and assigns[3].rhs == 'b(i) + 2.0' + assert assigns[4].lhs == 'a(1)' and assigns[4].rhs == 'a(2) + 3.1415' @pytest.mark.parametrize('frontend', available_frontends()) diff --git a/loki/transformations/utilities.py b/loki/transformations/utilities.py index e85178055..c237d6d66 100644 --- a/loki/transformations/utilities.py +++ b/loki/transformations/utilities.py @@ -468,7 +468,7 @@ def replace_selected_kind(routine): routine.spec.prepend(imprt) -def recursive_expression_map_update(expr_map, max_iterations=10): +def recursive_expression_map_update(expr_map, max_iterations=10, mapper_cls=SubstituteExpressionsMapper): """ Utility function to apply a substitution map for expressions to itself @@ -490,6 +490,8 @@ def recursive_expression_map_update(expr_map, max_iterations=10): max_iterations : int Maximum number of iterations, corresponds to the maximum level of nesting that can be replaced. + mapper_cls: :any:`SubstituteExpressionsMapper` + The underlying mapper to be used (default: :any:`SubstituteExpressionsMapper`). """ def apply_to_init_arg(name, arg, expr, mapper): # Helper utility to apply the mapper only to expression arguments and @@ -504,7 +506,7 @@ def apply_to_init_arg(name, arg, expr, mapper): # We update the expression map by applying it to the children of each replacement # node, thus making sure node replacements are also applied to nested attributes, # e.g. call arguments or array subscripts etc. - mapper = SubstituteExpressionsMapper(expr_map) + mapper = mapper_cls(expr_map) prev_map, expr_map = expr_map, { expr: type(replacement)(**{ name: apply_to_init_arg(name, arg, expr, mapper)