diff --git a/loki/transformations/inline/functions.py b/loki/transformations/inline/functions.py index f162a1de3..a39b81e5f 100644 --- a/loki/transformations/inline/functions.py +++ b/loki/transformations/inline/functions.py @@ -7,7 +7,8 @@ from collections import ChainMap -from loki.expression import symbols as sym, ExpressionRetriever +from loki.logging import warning +from loki.expression import symbols as sym, ExpressionRetriever, ExpressionDimensionsMapper from loki.ir import ( Transformer, FindNodes, FindVariables, Import, StatementFunction, FindInlineCalls, ExpressionFinder, SubstituteExpressions, @@ -87,6 +88,13 @@ def _inline_functions(routine, inline_elementals_only=False, functions=None): inlined in the next call to this function. """ + def is_array(expr): + """ + Check whether expr evaluates to an array. + E.g., for arr(:, :) return True, for arr(1, 1) or arr(jl, jk) return False. + """ + return any(d != '1' for d in ExpressionDimensionsMapper()(expr)) + class ExpressionRetrieverSkipInlineCallParameters(ExpressionRetriever): """ Expression retriever skipping parameters of inline calls. @@ -102,6 +110,11 @@ def __init__(self, query, recurse_query=None, inline_elementals_only=False, def map_inline_call(self, expr, *args, **kwargs): if not self.visit(expr, *args, **kwargs): return + if not expr.procedure_type is BasicType.DEFERRED and expr.procedure_type.is_elemental: + if any(is_array(val) for val in expr.arg_map.values() if isinstance(val, sym.Array)): + warning(f"Call to elemental function '{expr.routine.name}' with array arguments." + f' There is currently no support to inline those calls!') + return self.rec(expr.function, *args, **kwargs) # SKIP parameters/args/kwargs on purpose # under certain circumstances @@ -142,9 +155,10 @@ class FindInlineCallsSkipInlineCallParameters(ExpressionFinder): for call in calls: if call.procedure_type is BasicType.DEFERRED or isinstance(call.routine, StatementFunction): continue - if inline_elementals_only: - if not (call.procedure_type.is_function and call.procedure_type.is_elemental): - continue + if not call.procedure_type.is_function: + continue + if inline_elementals_only and not call.procedure_type.is_elemental: + continue if functions: if call.routine not in functions: continue diff --git a/loki/transformations/inline/tests/test_functions.py b/loki/transformations/inline/tests/test_functions.py index 4dc815b9b..73ec81d15 100644 --- a/loki/transformations/inline/tests/test_functions.py +++ b/loki/transformations/inline/tests/test_functions.py @@ -6,6 +6,7 @@ # nor does it submit to any jurisdiction. import pytest +import numpy as np from loki import Module, Subroutine from loki.build import jit_compile_lib, Builder, Obj @@ -100,12 +101,8 @@ def test_transform_inline_elemental_functions(tmp_path, builder, frontend): builder.clean() (tmp_path/f'{routine.name}.f90').unlink() - -@pytest.mark.parametrize('frontend', available_frontends()) -def test_transform_inline_elemental_functions_extended(tmp_path, builder, frontend): - """ - Test correct inlining of elemental functions. - """ +@pytest.fixture(name='multiply_extended_mod', params=available_frontends()) +def fixture_multiply_extended_mod(request, tmp_path): fcode_module = """ module multiply_extended_mod use iso_fortran_env, only: real64 @@ -144,8 +141,15 @@ def test_transform_inline_elemental_functions_extended(tmp_path, builder, fronte end module multiply_extended_mod """ + frontend = request.param + module = Module.from_source(fcode_module, frontend=frontend, xmods=[tmp_path]) + return module, frontend + +def test_transform_inline_elemental_functions_extended_scalar(multiply_extended_mod, builder, tmp_path): + module, frontend = multiply_extended_mod + fcode = """ -subroutine transform_inline_elemental_functions_extended(v1, v2, v3) +subroutine transform_inline_elemental_functions_extended_scalar(v1, v2, v3) use iso_fortran_env, only: real64 use multiply_extended_mod, only: multiply, multiply_single_line, add real(kind=real64), intent(in) :: v1 @@ -154,44 +158,88 @@ def test_transform_inline_elemental_functions_extended(tmp_path, builder, fronte v2 = multiply(v1, 6._real64) + multiply_single_line(v1, 3._real64) v3 = add(param1, 200._real64) + add(150._real64, 150._real64) + multiply(6._real64, 11._real64) -end subroutine transform_inline_elemental_functions_extended +end subroutine transform_inline_elemental_functions_extended_scalar """ - # Generate reference code, compile run and verify - module = Module.from_source(fcode_module, frontend=frontend, xmods=[tmp_path]) - routine = Subroutine.from_source(fcode, frontend=frontend, xmods=[tmp_path]) - + routine = Subroutine.from_source(fcode, frontend=frontend, definitions=[module], xmods=[tmp_path]) refname = f'ref_{routine.name}_{frontend}' reference = jit_compile_lib([module, routine], path=tmp_path, name=refname, builder=builder) - - v2, v3 = reference.transform_inline_elemental_functions_extended(11.) + v2, v3 = reference.transform_inline_elemental_functions_extended_scalar(11.) assert v2 == 99. assert v3 == 666. - (tmp_path/f'{module.name}.f90').unlink() (tmp_path/f'{routine.name}.f90').unlink() # Now inline elemental functions routine = Subroutine.from_source(fcode, definitions=module, frontend=frontend, xmods=[tmp_path]) inline_elemental_functions(routine) - - # Make sure there are no more inline calls in the routine body assert not FindInlineCalls().visit(routine.body) - # Verify correct scope of inlined elements assert all(v.scope is routine for v in FindVariables().visit(routine.body)) - # Hack: rename routine to use a different filename in the build routine.name = f'{routine.name}_' - kernel = jit_compile_lib([routine], path=tmp_path, name=routine.name, builder=builder) - - v2, v3 = kernel.transform_inline_elemental_functions_extended_(11.) + kernel = jit_compile_lib([routine, module], path=tmp_path, name=routine.name, builder=builder) + v2, v3 = kernel.transform_inline_elemental_functions_extended_scalar_(11.) assert v2 == 99. assert v3 == 666. builder.clean() (tmp_path/f'{routine.name}.f90').unlink() + (tmp_path/f'{module.name}.f90').unlink() + +def test_transform_inline_elemental_functions_extended_arr(multiply_extended_mod, builder, tmp_path): + module, frontend = multiply_extended_mod + + fcode_arr = """ +subroutine transform_inline_elemental_functions_extended_array(v1, v2, v3, len) + use iso_fortran_env, only: real64 + use multiply_extended_mod, only: multiply, multiply_single_line, add + integer, intent(in) :: len + real(kind=real64), intent(in) :: v1(len) + real(kind=real64), intent(inout) :: v2(len), v3(len) + real(kind=real64), parameter :: param1 = 100. + integer, parameter :: arr_index = 1 + + v2 = multiply(v1(:), 6._real64) + multiply_single_line(v1(:), 3._real64) + v3 = add(param1, 200._real64) + add(v1, 150._real64) + multiply(v1(arr_index), v2(1)) +end subroutine transform_inline_elemental_functions_extended_array +""" + + routine = Subroutine.from_source(fcode_arr, frontend=frontend, definitions=[module], xmods=[tmp_path]) + refname = f'ref_{routine.name}_frontend' + reference = jit_compile_lib([module, routine], path=tmp_path, name=refname, builder=builder) + arr_len = 5 + v1 = np.array([1.0, 2.0, 3.0, 5.0, 3.0], dtype=np.float64, order='F') + v2 = np.zeros((arr_len,), dtype=np.float64, order='F') + v3 = np.zeros((arr_len,), dtype=np.float64, order='F') + reference.transform_inline_elemental_functions_extended_array(v1, v2, v3, arr_len) + assert (v2 == np.array([9., 18., 27., 45., 27.], dtype=np.float64, order='F')).all() + assert (v3 == np.array([460., 461., 462., 464., 462.], dtype=np.float64, order='F')).all() + + (tmp_path/f'{routine.name}.f90').unlink() + + routine = Subroutine.from_source(fcode_arr, definitions=module, frontend=frontend, xmods=[tmp_path]) + inline_elemental_functions(routine) + # TODO: Make sure there are no more inline calls in the routine body + # assert not FindInlineCalls().visit(routine.body) + # this is currently not achievable as calls to elemental functions with array arguments + # can't be properly inlined and therefore are skipped + # Verify correct scope of inlined elements + assert all(v.scope is routine for v in FindVariables().visit(routine.body)) + # Hack: rename routine to use a different filename in the build + routine.name = f'{routine.name}_' + kernel = jit_compile_lib([routine, module], path=tmp_path, name=routine.name, builder=builder) + v1 = np.array([1.0, 2.0, 3.0, 5.0, 3.0], dtype=np.float64, order='F') + v2 = np.zeros((arr_len,), dtype=np.float64, order='F') + v3 = np.zeros((arr_len,), dtype=np.float64, order='F') + kernel.transform_inline_elemental_functions_extended_array_(v1, v2, v3, arr_len) + assert (v2 == np.array([9., 18., 27., 45., 27.], dtype=np.float64, order='F')).all() + assert (v3 == np.array([460., 461., 462., 464., 462.], dtype=np.float64, order='F')).all() + + builder.clean() + (tmp_path/f'{routine.name}.f90').unlink() + (tmp_path/f'{module.name}.f90').unlink() @pytest.mark.parametrize('frontend', available_frontends(