diff --git a/loki/transformations/__init__.py b/loki/transformations/__init__.py index 0e274ce1d..11c8bdbf0 100644 --- a/loki/transformations/__init__.py +++ b/loki/transformations/__init__.py @@ -35,3 +35,4 @@ from loki.transformations.block_index_transformations import * # noqa from loki.transformations.split_read_write import * # noqa from loki.transformations.loop_blocking import * # noqa +from loki.transformations.routine_signatures import * # noqa diff --git a/loki/transformations/routine_signatures.py b/loki/transformations/routine_signatures.py new file mode 100644 index 000000000..449385a65 --- /dev/null +++ b/loki/transformations/routine_signatures.py @@ -0,0 +1,200 @@ +# (C) Copyright 2018- ECMWF. +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +""" +Collection of utilities and transformations altering routine signatures. +""" + +import os +import itertools as it +from loki.batch import Transformation, ProcedureItem +from loki.ir import ( + VariableDeclaration, FindVariables, + Transformer, FindNodes, CallStatement, + SubstituteExpressions +) +from loki.tools import as_tuple, flatten +from loki.types import BasicType + +__all__ = ['RemoveDuplicateArgs', 'remove_duplicate_args_from_calls', + 'modify_variable_declarations'] + + +class RemoveDuplicateArgs(Transformation): + """ + Transformation to remove duplicate arguments for both caller + and callee. + + .. warning:: + this won't work properly for multiple calls to the same routine + with differing duplicate arguments + + Parameters + ---------- + recurse_to_kernels : bool, optional + Remove duplicate arguments only at the driver level or recurse to + (nested) kernels (Default: `True`). + rename_common : bool, optional + Try to rename dummy arguments in called routines that received the same argument + on the caller side, by finding a common name pattern in those names (Default: `False`). + """ + + # This trafo only operates on procedures + item_filter = (ProcedureItem,) + + def __init__(self, recurse_to_kernels=True, rename_common=False): + self.recurse_to_kernels = recurse_to_kernels + self.rename_common = rename_common + + def transform_subroutine(self, routine, **kwargs): + role = kwargs['role'] + if role == 'driver' or self.recurse_to_kernels: + remove_duplicate_args_from_calls(routine, rename_common=self.rename_common) + +def remove_duplicate_args_from_calls(routine, rename_common=False): + """ + Utility to remove duplicate arguments from calls in :data:`routine` + + This updates the calls as well as the called routines. It requires calls + to be enriched with interprocedural information. + + .. warning:: + this won't work properly for multiple calls to the same routine + with differing duplicate arguments + + Parameters + ---------- + routine : :any:`Subroutine` + The subroutine where calls should be transformed. + rename_common : bool, optional + Try to rename dummy arguments in called routines that received the same argument + on the caller side, by finding a common name pattern in those names (Default: `False`). + """ + + def remove_duplicate_args_call(call): + arg_map = {} + for routine_arg, call_arg in call.arg_iter(): + arg_map.setdefault(call_arg, []).append(routine_arg) + # filter duplicate kwargs (comparing to the other kwarguments) + _new_kwargs = as_tuple(list(kw_vals)[0] for g, kw_vals in it.groupby(call.kwarguments, key=lambda x: x[1])) + # filter duplicate kwargs (comparing to the arguments) + new_kwargs = tuple(kwarg for kwarg in _new_kwargs if kwarg[1] not in call.arguments) + # (filter duplicate arguments and) update call + call._update(arguments=as_tuple(dict.fromkeys(call.arguments)), kwarguments=new_kwargs) + return arg_map + + def modify_callee(callee, callee_arg_map): + + def allowed_rename(routine, rename): + # check whether rename is already "used" in routine + if rename in routine.arguments or rename in routine.variables: + return False + return True + + combine = [routine_args for call_arg, routine_args in callee_arg_map.items() if len(routine_args) > 1] + if rename_common: + matches = [ + os.path.commonprefix([str(elem.name) for elem in args]).rstrip('_') or + os.path.commonprefix([str(elem.name)[::-1] for elem in args]).rstrip('_')[::-1] + for args in combine + ] + rename_common_map = {c[0].name: m for c, m in zip(combine, matches) if m} + # check whether found rename is already "used" in routine + unallowed_renames = () + for name, rename in rename_common_map.items(): + if not allowed_rename(callee, rename): + unallowed_renames += (name,) + # and if already "used", remove and use instead default + for key in unallowed_renames: + del rename_common_map[key] + else: + rename_common_map = {} + redundant = flatten([routine_args[1:] for routine_args in combine]) + combine_map = {routine_args[0]: as_tuple(routine_args[1:]) for routine_args in combine} + arg_map = {arg.name: rename_common_map.get(common_arg.name, common_arg.name) + for common_arg, redundant_args in combine_map.items() for arg in redundant_args} + # remove duplicates from callee.arguments + new_routine_args = tuple(arg for arg in callee.arguments if arg not in redundant) + # rename if common name is possible + new_routine_args = as_tuple(arg.clone(name=rename_common_map[arg.name]) + if arg.name in rename_common_map else arg for arg in new_routine_args) + callee.arguments = new_routine_args + + # rename usage/occurences in callee.body + var_map = {} + variables = FindVariables(unique=False).visit(callee.body) + var_map = {var: var.clone(name=arg_map[var.name]) for var in variables if var.name in arg_map} + var_map.update({var: var.clone(name=rename_common_map[var.name]) for var in variables + if var.name in rename_common_map}) + callee.body = SubstituteExpressions(var_map).visit(callee.body) + # modify the variable declarations, thus remove redundant variable declarations and possibly rename + modify_variable_declarations(callee, remove_symbols=redundant, rename_symbols=rename_common_map) + # store the information for possibly later renaming kwarguments on caller side + return rename_common_map + + def rename_kwarguments(relevant_calls, rename_common_map_routine): + for call in relevant_calls: + kwarguments = call.kwarguments + if kwarguments: + call_name = str(call.routine.name).lower() + new_kwargs = as_tuple((rename_common_map_routine[call_name][kw[0]], kw[1]) + if kw[0] in rename_common_map_routine[call_name] else kw for kw in kwarguments) + call._update(kwarguments=new_kwargs) + + calls = FindNodes(CallStatement).visit(routine.body) + call_arg_map = {} + relevant_calls = [] + # adapt call statements (and remove duplicate args/kwargs) + for call in calls: + if call.routine is BasicType.DEFERRED: + continue + call_arg_map[call.routine] = remove_duplicate_args_call(call) + relevant_calls.append(call) + rename_common_map_routine = {} + # modify/adapt callees + for callee, callee_arg_map in call_arg_map.items(): + rename_common_map_routine[str(callee.name).lower()] = modify_callee(callee, callee_arg_map) + # handle possibly renamed kwarguments on caller side + if rename_common: + rename_kwarguments(relevant_calls, rename_common_map_routine) + + +def modify_variable_declarations(routine, remove_symbols=(), rename_symbols=None): + """ + Utility to modify variable declarations by either removing symbols or renaming + symbols. + + .. note:: + This utility only works on the variable declarations itself and + won't modify variable/symbol usages elsewhere! + + Parameters + ---------- + routine : :any:`Subroutine` + The subroutine to be transformed. + remove_symbols : list, tuple + List of symbols for which their declaration should be removed. + rename_symbols : dict + Dict/Map of symbols for which their declaration should be renamed. + """ + rename_symbols = rename_symbols if rename_symbols is not None else {} + var_decls = FindNodes(VariableDeclaration).visit(routine.spec) + remove_symbol_names = [var.name.lower() for var in remove_symbols] + decl_map = {} + already_declared = () + for decl in var_decls: + symbols = [symbol for symbol in decl.symbols if symbol.name.lower() not in remove_symbol_names] + symbols = [symbol.clone(name=rename_symbols[symbol.name]) + if symbol.name in rename_symbols else symbol for symbol in symbols] + symbols = [symbol for symbol in symbols if not symbol.name.lower() in already_declared] + already_declared += tuple(symbol.name.lower() for symbol in symbols) + if symbols and symbols != decl.symbols: + decl_map[decl] = decl.clone(symbols=as_tuple(symbols)) + else: + if not symbols: + decl_map[decl] = None + routine.spec = Transformer(decl_map).visit(routine.spec) diff --git a/loki/transformations/tests/test_routine_signatures.py b/loki/transformations/tests/test_routine_signatures.py new file mode 100644 index 000000000..efc5e07ff --- /dev/null +++ b/loki/transformations/tests/test_routine_signatures.py @@ -0,0 +1,155 @@ +# (C) Copyright 2018- ECMWF. +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +import pytest + +from loki import Module, Subroutine +from loki.frontend import available_frontends +from loki.ir import FindNodes, CallStatement +from loki.transformations.routine_signatures import RemoveDuplicateArgs + +@pytest.mark.parametrize('frontend', available_frontends()) +@pytest.mark.parametrize('pass_as_kwarg', (True, False)) +@pytest.mark.parametrize('recurse_to_kernels', (True, False)) +@pytest.mark.parametrize('rename_common', (True, False)) +def test_utilities_remove_duplicate_args(tmp_path, frontend, pass_as_kwarg, recurse_to_kernels, rename_common): + """ + Test lowering constant array indices + """ + fcode_driver = f""" +subroutine driver(nlon,nlev,nb,var) + use kernel_mod, only: kernel + implicit none + integer, intent(in) :: nlon,nlev,nb + real, intent(inout) :: var(nlon,nlev,5,nb) + integer :: ibl + integer :: offset + integer :: some_val + integer :: loop_start, loop_end + loop_start = 2 + loop_end = nb + some_val = 0 + offset = 1 + !$omp test + do ibl=loop_start, loop_end + call kernel(nlon,nlev, & + & {'var1=' if pass_as_kwarg else ''}var(:,:,1,ibl),& + & {'var2=' if pass_as_kwarg else ''}var(:,:,1,ibl),& + & {'another_var=' if pass_as_kwarg else ''}var(:,:,2:5,ibl),& + & {'icend=' if pass_as_kwarg else ''}offset,& + & {'lstart=' if pass_as_kwarg else ''}loop_start,& + & {'lend=' if pass_as_kwarg else ''}loop_end,& + & {'kend=' if pass_as_kwarg else ''}nlev) + call kernel(nlon,nlev, & + & {'var1=' if pass_as_kwarg else ''}var(:,:,1,ibl),& + & {'var2=' if pass_as_kwarg else ''}var(:,:,1,ibl),& + & {'another_var=' if pass_as_kwarg else ''}var(:,:,2:5,ibl),& + & {'icend=' if pass_as_kwarg else ''}offset,& + & {'lstart=' if pass_as_kwarg else ''}loop_start,& + & {'lend=' if pass_as_kwarg else ''}loop_end,& + & {'kend=' if pass_as_kwarg else ''}nlev) + enddo +end subroutine driver +""" + + fcode_kernel = """ +module kernel_mod +implicit none +contains +subroutine kernel(nlon,nlev,var1,var2,another_var,icend,lstart,lend,kend) + use compute_mod, only: compute + implicit none + integer, intent(in) :: nlon,nlev,icend,lstart,lend,kend + real, intent(inout) :: var1(nlon,nlev) + real, intent(inout) :: var2(nlon,nlev) + real, intent(inout) :: another_var(nlon,nlev,4) + integer :: jk, jl, jt + var1(:,:) = 0. + do jk = 1,kend + do jl = 1, nlon + var1(jl, jk) = 0. + var2(jl, jk) = 1.0 + do jt= 1,4 + another_var(jl, jk, jt) = 0.0 + end do + end do + end do + call compute(nlon,nlev,var1, var2) + call compute(nlon,nlev,var1, var2) +end subroutine kernel +end module kernel_mod +""" + + fcode_nested_kernel = """ +module compute_mod +implicit none +contains +subroutine compute(nlon,nlev,b_var,a_var) + implicit none + integer, intent(in) :: nlon,nlev + real, intent(inout) :: b_var(nlon,nlev) + real, intent(inout) :: a_var(nlon,nlev) + real :: VAR ! create name clash on purpose (if rename_common) + b_var(:,:) = 0. + a_var(:,:) = 1.0 +end subroutine compute +end module compute_mod +""" + + nested_kernel_mod = Module.from_source(fcode_nested_kernel, frontend=frontend, xmods=[tmp_path]) + kernel_mod = Module.from_source(fcode_kernel, frontend=frontend, definitions=nested_kernel_mod, xmods=[tmp_path]) + driver = Subroutine.from_source(fcode_driver, frontend=frontend, definitions=kernel_mod, xmods=[tmp_path]) + + transformation = RemoveDuplicateArgs(recurse_to_kernels=recurse_to_kernels, rename_common=rename_common) + transformation.apply(driver, role='driver', targets=('kernel',)) + transformation.apply(kernel_mod['kernel'], role='kernel', targets=('compute',)) + transformation.apply(nested_kernel_mod['compute'], role='kernel') + + # driver + kernel_var_name = 'var' if rename_common else 'var1' + kernel_calls = FindNodes(CallStatement).visit(driver.body) + for kernel_call in kernel_calls: + if pass_as_kwarg: + assert (kernel_var_name, 'var(:, :, 1, ibl)') in kernel_call.kwarguments + assert ('var2', 'var(:, :, 1, ibl)') not in kernel_call.kwarguments + arg1 = kernel_call.kwarguments[0][1] + arg2 = kernel_call.kwarguments[1][1] + else: + assert 'var(:, :, 1, ibl)' in kernel_call.arguments + assert 'var2(:, :, 1, ibl)' not in kernel_call.arguments + arg1 = kernel_call.arguments[2] + arg2 = kernel_call.arguments[3] + assert arg1.dimensions == (':', ':', '1', 'ibl') + assert arg2.dimensions == (':', ':', '2:5', 'ibl') + # kernel + kernel_vars = kernel_mod['kernel'].variable_map + kernel_args = kernel_mod['kernel']._dummies + assert kernel_var_name in kernel_args + assert 'var2' not in kernel_args + assert 'var2' not in kernel_vars + assert kernel_vars[kernel_var_name].shape == ('nlon', 'nlev') + assert kernel_vars['another_var'].dimensions == ('nlon', 'nlev', 4) + compute_calls = FindNodes(CallStatement).visit(kernel_mod['kernel'].body) + for compute_call in compute_calls: + assert kernel_var_name in compute_call.arguments + assert 'var2' not in compute_call.arguments + # nested_kernel + nested_kernel = nested_kernel_mod['compute'] + nested_kernel_vars = nested_kernel.variable_map + nested_kernel_args = [arg.name.lower() for arg in nested_kernel.arguments] + # it's always 'b_var' as a rename would clash with the already "used" variable "var" + nested_kernel_var_name = 'b_var' + if recurse_to_kernels: + assert nested_kernel_var_name in nested_kernel_args + assert 'a_var' not in nested_kernel_args + assert nested_kernel_var_name in nested_kernel_vars + assert 'a_var' not in nested_kernel_vars + else: + assert 'b_var' in nested_kernel_args + assert 'a_var' in nested_kernel_args + assert 'b_var' in nested_kernel_vars + assert 'a_var' in nested_kernel_vars