diff --git a/loki/transformations/__init__.py b/loki/transformations/__init__.py index 3f4136103..5bffd0bf7 100644 --- a/loki/transformations/__init__.py +++ b/loki/transformations/__init__.py @@ -34,3 +34,4 @@ from loki.transformations.utilities import * # noqa from loki.transformations.block_index_transformations import * # noqa from loki.transformations.split_read_write import * # noqa +from loki.transformations.routine_signatures import * # noqa diff --git a/loki/transformations/routine_signatures.py b/loki/transformations/routine_signatures.py new file mode 100644 index 000000000..69da32560 --- /dev/null +++ b/loki/transformations/routine_signatures.py @@ -0,0 +1,203 @@ +# (C) Copyright 2018- ECMWF. +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +""" +Collection of utilities and transformations altering routine signatures. +""" + +import os +import itertools as it +from loki.batch import Transformation, ProcedureItem +from loki.expression import ( + FindVariables, + SubstituteExpressions +) +from loki.ir import ( + VariableDeclaration, + Transformer, FindNodes, CallStatement +) +from loki.tools import as_tuple, flatten +from loki.types import BasicType + +__all__ = ['RemoveDuplicateArgs', 'remove_duplicate_args_from_calls', + 'modify_variable_declarations'] + + +class RemoveDuplicateArgs(Transformation): + """ + Transformation to remove duplicate arguments for both caller + and callee. + + .. warning:: + this won't work properly for multiple calls to the same routine + with differing duplicate arguments + + Parameters + ---------- + recurse_to_kernels : bool, optional + Remove duplicate arguments only at the driver level or recurse to + (nested) kernels (Default: `True`). + rename_common : bool, optional + Try to rename dummy arguments in called routines that received the same argument + on the caller side, by finding a common name pattern in those names (Default: `False`). + """ + + # This trafo only operates on procedures + item_filter = (ProcedureItem,) + + def __init__(self, recurse_to_kernels=True, rename_common=False): + self.recurse_to_kernels = recurse_to_kernels + self.rename_common = rename_common + + def transform_subroutine(self, routine, **kwargs): + role = kwargs['role'] + if role == 'driver' or self.recurse_to_kernels: + remove_duplicate_args_from_calls(routine, rename_common=self.rename_common) + +def remove_duplicate_args_from_calls(routine, rename_common=False): + """ + Utility to remove duplicate arguments from calls in :data:`routine` + + This updates the calls as well as the called routines. It requires calls + to be enriched with interprocedural information. + + .. warning:: + this won't work properly for multiple calls to the same routine + with differing duplicate arguments + + Parameters + ---------- + routine : :any:`Subroutine` + The subroutine where calls should be transformed. + rename_common : bool, optional + Try to rename dummy arguments in called routines that received the same argument + on the caller side, by finding a common name pattern in those names (Default: `False`). + """ + + def remove_duplicate_args_call(call): + arg_map = {} + for routine_arg, call_arg in call.arg_iter(): + arg_map.setdefault(call_arg, []).append(routine_arg) + # filter duplicate kwargs (comparing to the other kwarguments) + _new_kwargs = as_tuple(list(kw_vals)[0] for g, kw_vals in it.groupby(call.kwarguments, key=lambda x: x[1])) + # filter duplicate kwargs (comparing to the arguments) + new_kwargs = tuple(kwarg for kwarg in _new_kwargs if kwarg[1] not in call.arguments) + # (filter duplicate arguments and) update call + call._update(arguments=as_tuple(dict.fromkeys(call.arguments)), kwarguments=new_kwargs) + return arg_map + + def modify_callee(callee, callee_arg_map): + + def allowed_rename(routine, rename): + # check whether rename is already "used" in routine + if rename in routine.arguments or rename in routine.variables: + return False + return True + + combine = [routine_args for call_arg, routine_args in callee_arg_map.items() if len(routine_args) > 1] + if rename_common: + matches = [ + os.path.commonprefix([str(elem.name) for elem in args]).rstrip('_') or + os.path.commonprefix([str(elem.name)[::-1] for elem in args]).rstrip('_')[::-1] + for args in combine + ] + rename_common_map = {c[0].name: m for c, m in zip(combine, matches) if m} + # check whether found rename is already "used" in routine + unallowed_renames = () + for name, rename in rename_common_map.items(): + if not allowed_rename(callee, rename): + unallowed_renames += (name,) + # and if already "used", remove and use instead default + for key in unallowed_renames: + del rename_common_map[key] + else: + rename_common_map = {} + redundant = flatten([routine_args[1:] for routine_args in combine]) + combine_map = {routine_args[0]: as_tuple(routine_args[1:]) for routine_args in combine} + arg_map = {arg.name: rename_common_map.get(common_arg.name, common_arg.name) + for common_arg, redundant_args in combine_map.items() for arg in redundant_args} + # remove duplicates from callee.arguments + new_routine_args = tuple(arg for arg in callee.arguments if arg not in redundant) + # rename if common name is possible + new_routine_args = as_tuple(arg.clone(name=rename_common_map[arg.name]) + if arg.name in rename_common_map else arg for arg in new_routine_args) + callee.arguments = new_routine_args + + # rename usage/occurences in callee.body + var_map = {} + variables = FindVariables(unique=False).visit(callee.body) + var_map = {var: var.clone(name=arg_map[var.name]) for var in variables if var.name in arg_map} + var_map.update({var: var.clone(name=rename_common_map[var.name]) for var in variables + if var.name in rename_common_map}) + callee.body = SubstituteExpressions(var_map).visit(callee.body) + # modify the variable declarations, thus remove redundant variable declarations and possibly rename + modify_variable_declarations(callee, remove_symbols=redundant, rename_symbols=rename_common_map) + # store the information for possibly later renaming kwarguments on caller side + return rename_common_map + + def rename_kwarguments(relevant_calls, rename_common_map_routine): + for call in relevant_calls: + kwarguments = call.kwarguments + if kwarguments: + call_name = str(call.routine.name).lower() + new_kwargs = as_tuple((rename_common_map_routine[call_name][kw[0]], kw[1]) + if kw[0] in rename_common_map_routine[call_name] else kw for kw in kwarguments) + call._update(kwarguments=new_kwargs) + + calls = FindNodes(CallStatement).visit(routine.body) + call_arg_map = {} + relevant_calls = [] + # adapt call statements (and remove duplicate args/kwargs) + for call in calls: + if call.routine is BasicType.DEFERRED: + continue + call_arg_map[call.routine] = remove_duplicate_args_call(call) + relevant_calls.append(call) + rename_common_map_routine = {} + # modify/adapt callees + for callee, callee_arg_map in call_arg_map.items(): + rename_common_map_routine[str(callee.name).lower()] = modify_callee(callee, callee_arg_map) + # handle possibly renamed kwarguments on caller side + if rename_common: + rename_kwarguments(relevant_calls, rename_common_map_routine) + + +def modify_variable_declarations(routine, remove_symbols=(), rename_symbols=None): + """ + Utility to modify variable declarations by either removing symbols or renaming + symbols. + + .. note:: + This utility only works on the variable declarations itself and + won't modify variable/symbol usages elsewhere! + + Parameters + ---------- + routine : :any:`Subroutine` + The subroutine to be transformed. + remove_symbols : list, tuple + List of symbols for which their declaration should be removed. + rename_symbols : dict + Dict/Map of symbols for which their declaration should be renamed. + """ + rename_symbols = rename_symbols if rename_symbols is not None else {} + var_decls = FindNodes(VariableDeclaration).visit(routine.spec) + remove_symbol_names = [var.name.lower() for var in remove_symbols] + decl_map = {} + already_declared = () + for decl in var_decls: + symbols = [symbol for symbol in decl.symbols if symbol.name.lower() not in remove_symbol_names] + symbols = [symbol.clone(name=rename_symbols[symbol.name]) + if symbol.name in rename_symbols else symbol for symbol in symbols] + symbols = [symbol for symbol in symbols if not symbol.name.lower() in already_declared] + already_declared += tuple(symbol.name.lower() for symbol in symbols) + if symbols and symbols != decl.symbols: + decl_map[decl] = decl.clone(symbols=as_tuple(symbols)) + else: + if not symbols: + decl_map[decl] = None + routine.spec = Transformer(decl_map).visit(routine.spec) diff --git a/loki/transformations/tests/test_routine_signatures.py b/loki/transformations/tests/test_routine_signatures.py new file mode 100644 index 000000000..efc5e07ff --- /dev/null +++ b/loki/transformations/tests/test_routine_signatures.py @@ -0,0 +1,155 @@ +# (C) Copyright 2018- ECMWF. +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +import pytest + +from loki import Module, Subroutine +from loki.frontend import available_frontends +from loki.ir import FindNodes, CallStatement +from loki.transformations.routine_signatures import RemoveDuplicateArgs + +@pytest.mark.parametrize('frontend', available_frontends()) +@pytest.mark.parametrize('pass_as_kwarg', (True, False)) +@pytest.mark.parametrize('recurse_to_kernels', (True, False)) +@pytest.mark.parametrize('rename_common', (True, False)) +def test_utilities_remove_duplicate_args(tmp_path, frontend, pass_as_kwarg, recurse_to_kernels, rename_common): + """ + Test lowering constant array indices + """ + fcode_driver = f""" +subroutine driver(nlon,nlev,nb,var) + use kernel_mod, only: kernel + implicit none + integer, intent(in) :: nlon,nlev,nb + real, intent(inout) :: var(nlon,nlev,5,nb) + integer :: ibl + integer :: offset + integer :: some_val + integer :: loop_start, loop_end + loop_start = 2 + loop_end = nb + some_val = 0 + offset = 1 + !$omp test + do ibl=loop_start, loop_end + call kernel(nlon,nlev, & + & {'var1=' if pass_as_kwarg else ''}var(:,:,1,ibl),& + & {'var2=' if pass_as_kwarg else ''}var(:,:,1,ibl),& + & {'another_var=' if pass_as_kwarg else ''}var(:,:,2:5,ibl),& + & {'icend=' if pass_as_kwarg else ''}offset,& + & {'lstart=' if pass_as_kwarg else ''}loop_start,& + & {'lend=' if pass_as_kwarg else ''}loop_end,& + & {'kend=' if pass_as_kwarg else ''}nlev) + call kernel(nlon,nlev, & + & {'var1=' if pass_as_kwarg else ''}var(:,:,1,ibl),& + & {'var2=' if pass_as_kwarg else ''}var(:,:,1,ibl),& + & {'another_var=' if pass_as_kwarg else ''}var(:,:,2:5,ibl),& + & {'icend=' if pass_as_kwarg else ''}offset,& + & {'lstart=' if pass_as_kwarg else ''}loop_start,& + & {'lend=' if pass_as_kwarg else ''}loop_end,& + & {'kend=' if pass_as_kwarg else ''}nlev) + enddo +end subroutine driver +""" + + fcode_kernel = """ +module kernel_mod +implicit none +contains +subroutine kernel(nlon,nlev,var1,var2,another_var,icend,lstart,lend,kend) + use compute_mod, only: compute + implicit none + integer, intent(in) :: nlon,nlev,icend,lstart,lend,kend + real, intent(inout) :: var1(nlon,nlev) + real, intent(inout) :: var2(nlon,nlev) + real, intent(inout) :: another_var(nlon,nlev,4) + integer :: jk, jl, jt + var1(:,:) = 0. + do jk = 1,kend + do jl = 1, nlon + var1(jl, jk) = 0. + var2(jl, jk) = 1.0 + do jt= 1,4 + another_var(jl, jk, jt) = 0.0 + end do + end do + end do + call compute(nlon,nlev,var1, var2) + call compute(nlon,nlev,var1, var2) +end subroutine kernel +end module kernel_mod +""" + + fcode_nested_kernel = """ +module compute_mod +implicit none +contains +subroutine compute(nlon,nlev,b_var,a_var) + implicit none + integer, intent(in) :: nlon,nlev + real, intent(inout) :: b_var(nlon,nlev) + real, intent(inout) :: a_var(nlon,nlev) + real :: VAR ! create name clash on purpose (if rename_common) + b_var(:,:) = 0. + a_var(:,:) = 1.0 +end subroutine compute +end module compute_mod +""" + + nested_kernel_mod = Module.from_source(fcode_nested_kernel, frontend=frontend, xmods=[tmp_path]) + kernel_mod = Module.from_source(fcode_kernel, frontend=frontend, definitions=nested_kernel_mod, xmods=[tmp_path]) + driver = Subroutine.from_source(fcode_driver, frontend=frontend, definitions=kernel_mod, xmods=[tmp_path]) + + transformation = RemoveDuplicateArgs(recurse_to_kernels=recurse_to_kernels, rename_common=rename_common) + transformation.apply(driver, role='driver', targets=('kernel',)) + transformation.apply(kernel_mod['kernel'], role='kernel', targets=('compute',)) + transformation.apply(nested_kernel_mod['compute'], role='kernel') + + # driver + kernel_var_name = 'var' if rename_common else 'var1' + kernel_calls = FindNodes(CallStatement).visit(driver.body) + for kernel_call in kernel_calls: + if pass_as_kwarg: + assert (kernel_var_name, 'var(:, :, 1, ibl)') in kernel_call.kwarguments + assert ('var2', 'var(:, :, 1, ibl)') not in kernel_call.kwarguments + arg1 = kernel_call.kwarguments[0][1] + arg2 = kernel_call.kwarguments[1][1] + else: + assert 'var(:, :, 1, ibl)' in kernel_call.arguments + assert 'var2(:, :, 1, ibl)' not in kernel_call.arguments + arg1 = kernel_call.arguments[2] + arg2 = kernel_call.arguments[3] + assert arg1.dimensions == (':', ':', '1', 'ibl') + assert arg2.dimensions == (':', ':', '2:5', 'ibl') + # kernel + kernel_vars = kernel_mod['kernel'].variable_map + kernel_args = kernel_mod['kernel']._dummies + assert kernel_var_name in kernel_args + assert 'var2' not in kernel_args + assert 'var2' not in kernel_vars + assert kernel_vars[kernel_var_name].shape == ('nlon', 'nlev') + assert kernel_vars['another_var'].dimensions == ('nlon', 'nlev', 4) + compute_calls = FindNodes(CallStatement).visit(kernel_mod['kernel'].body) + for compute_call in compute_calls: + assert kernel_var_name in compute_call.arguments + assert 'var2' not in compute_call.arguments + # nested_kernel + nested_kernel = nested_kernel_mod['compute'] + nested_kernel_vars = nested_kernel.variable_map + nested_kernel_args = [arg.name.lower() for arg in nested_kernel.arguments] + # it's always 'b_var' as a rename would clash with the already "used" variable "var" + nested_kernel_var_name = 'b_var' + if recurse_to_kernels: + assert nested_kernel_var_name in nested_kernel_args + assert 'a_var' not in nested_kernel_args + assert nested_kernel_var_name in nested_kernel_vars + assert 'a_var' not in nested_kernel_vars + else: + assert 'b_var' in nested_kernel_args + assert 'a_var' in nested_kernel_args + assert 'b_var' in nested_kernel_vars + assert 'a_var' in nested_kernel_vars diff --git a/loki/transformations/tests/test_utilities.py b/loki/transformations/tests/test_utilities.py index 77b5a56d1..88564bcd5 100644 --- a/loki/transformations/tests/test_utilities.py +++ b/loki/transformations/tests/test_utilities.py @@ -12,7 +12,7 @@ symbols as sym, FindVariables, FindInlineCalls, SubstituteExpressions ) from loki.frontend import available_frontends, OMNI -from loki.ir import nodes as ir, FindNodes, pragmas_attached, CallStatement +from loki.ir import nodes as ir, FindNodes, pragmas_attached from loki.types import BasicType from loki.transformations.utilities import ( @@ -20,7 +20,6 @@ convert_to_lower_case, replace_intrinsics, rename_variables, get_integer_variable, get_loop_bounds, is_driver_loop, find_driver_loops, get_local_arrays, check_routine_pragmas, - RemoveDuplicateArgs ) @@ -547,145 +546,3 @@ def test_transform_utilites_check_routine_pragmas(frontend, tmp_path): assert check_routine_pragmas(module['test_acc_seq'], directive=None) assert check_routine_pragmas(module['test_loki_seq'], directive=None) assert check_routine_pragmas(module['test_acc_vec'], directive='openacc') - -@pytest.mark.parametrize('frontend', available_frontends()) -@pytest.mark.parametrize('pass_as_kwarg', (True, False)) -@pytest.mark.parametrize('recurse_to_kernels', (True, False)) -@pytest.mark.parametrize('rename_common', (True, False)) -def test_utilities_remove_duplicate_args(tmp_path, frontend, pass_as_kwarg, recurse_to_kernels, rename_common): - """ - Test lowering constant array indices - """ - fcode_driver = f""" -subroutine driver(nlon,nlev,nb,var) - use kernel_mod, only: kernel - implicit none - integer, intent(in) :: nlon,nlev,nb - real, intent(inout) :: var(nlon,nlev,5,nb) - integer :: ibl - integer :: offset - integer :: some_val - integer :: loop_start, loop_end - loop_start = 2 - loop_end = nb - some_val = 0 - offset = 1 - !$omp test - do ibl=loop_start, loop_end - call kernel(nlon,nlev, & - & {'var1=' if pass_as_kwarg else ''}var(:,:,1,ibl),& - & {'var2=' if pass_as_kwarg else ''}var(:,:,1,ibl),& - & {'another_var=' if pass_as_kwarg else ''}var(:,:,2:5,ibl),& - & {'icend=' if pass_as_kwarg else ''}offset,& - & {'lstart=' if pass_as_kwarg else ''}loop_start,& - & {'lend=' if pass_as_kwarg else ''}loop_end,& - & {'kend=' if pass_as_kwarg else ''}nlev) - call kernel(nlon,nlev, & - & {'var1=' if pass_as_kwarg else ''}var(:,:,1,ibl),& - & {'var2=' if pass_as_kwarg else ''}var(:,:,1,ibl),& - & {'another_var=' if pass_as_kwarg else ''}var(:,:,2:5,ibl),& - & {'icend=' if pass_as_kwarg else ''}offset,& - & {'lstart=' if pass_as_kwarg else ''}loop_start,& - & {'lend=' if pass_as_kwarg else ''}loop_end,& - & {'kend=' if pass_as_kwarg else ''}nlev) - enddo -end subroutine driver -""" - - fcode_kernel = """ -module kernel_mod -implicit none -contains -subroutine kernel(nlon,nlev,var1,var2,another_var,icend,lstart,lend,kend) - use compute_mod, only: compute - implicit none - integer, intent(in) :: nlon,nlev,icend,lstart,lend,kend - real, intent(inout) :: var1(nlon,nlev) - real, intent(inout) :: var2(nlon,nlev) - real, intent(inout) :: another_var(nlon,nlev,4) - integer :: jk, jl, jt - var1(:,:) = 0. - do jk = 1,kend - do jl = 1, nlon - var1(jl, jk) = 0. - var2(jl, jk) = 1.0 - do jt= 1,4 - another_var(jl, jk, jt) = 0.0 - end do - end do - end do - call compute(nlon,nlev,var1, var2) - call compute(nlon,nlev,var1, var2) -end subroutine kernel -end module kernel_mod -""" - - fcode_nested_kernel = """ -module compute_mod -implicit none -contains -subroutine compute(nlon,nlev,b_var,a_var) - implicit none - integer, intent(in) :: nlon,nlev - real, intent(inout) :: b_var(nlon,nlev) - real, intent(inout) :: a_var(nlon,nlev) - real :: VAR ! create name clash on purpose (if rename_common) - b_var(:,:) = 0. - a_var(:,:) = 1.0 -end subroutine compute -end module compute_mod -""" - - nested_kernel_mod = Module.from_source(fcode_nested_kernel, frontend=frontend, xmods=[tmp_path]) - kernel_mod = Module.from_source(fcode_kernel, frontend=frontend, definitions=nested_kernel_mod, xmods=[tmp_path]) - driver = Subroutine.from_source(fcode_driver, frontend=frontend, definitions=kernel_mod, xmods=[tmp_path]) - - transformation = RemoveDuplicateArgs(recurse_to_kernels=recurse_to_kernels, rename_common=rename_common) - transformation.apply(driver, role='driver', targets=('kernel',)) - transformation.apply(kernel_mod['kernel'], role='kernel', targets=('compute',)) - transformation.apply(nested_kernel_mod['compute'], role='kernel') - - # driver - kernel_var_name = 'var' if rename_common else 'var1' - kernel_calls = FindNodes(CallStatement).visit(driver.body) - for kernel_call in kernel_calls: - if pass_as_kwarg: - assert (kernel_var_name, 'var(:, :, 1, ibl)') in kernel_call.kwarguments - assert ('var2', 'var(:, :, 1, ibl)') not in kernel_call.kwarguments - arg1 = kernel_call.kwarguments[0][1] - arg2 = kernel_call.kwarguments[1][1] - else: - assert 'var(:, :, 1, ibl)' in kernel_call.arguments - assert 'var2(:, :, 1, ibl)' not in kernel_call.arguments - arg1 = kernel_call.arguments[2] - arg2 = kernel_call.arguments[3] - assert arg1.dimensions == (':', ':', '1', 'ibl') - assert arg2.dimensions == (':', ':', '2:5', 'ibl') - # kernel - kernel_vars = kernel_mod['kernel'].variable_map - kernel_args = kernel_mod['kernel']._dummies - assert kernel_var_name in kernel_args - assert 'var2' not in kernel_args - assert 'var2' not in kernel_vars - assert kernel_vars[kernel_var_name].shape == ('nlon', 'nlev') - assert kernel_vars['another_var'].dimensions == ('nlon', 'nlev', 4) - compute_calls = FindNodes(CallStatement).visit(kernel_mod['kernel'].body) - for compute_call in compute_calls: - assert kernel_var_name in compute_call.arguments - assert 'var2' not in compute_call.arguments - # nested_kernel - nested_kernel = nested_kernel_mod['compute'] - nested_kernel_vars = nested_kernel.variable_map - nested_kernel_args = [arg.name.lower() for arg in nested_kernel.arguments] - # it's always 'b_var' as a rename would clash with the already "used" variable "var" - nested_kernel_var_name = 'b_var' - if recurse_to_kernels: - assert nested_kernel_var_name in nested_kernel_args - assert 'a_var' not in nested_kernel_args - assert nested_kernel_var_name in nested_kernel_vars - assert 'a_var' not in nested_kernel_vars - else: - assert 'b_var' in nested_kernel_args - assert 'a_var' in nested_kernel_args - assert 'b_var' in nested_kernel_vars - assert 'a_var' in nested_kernel_vars diff --git a/loki/transformations/utilities.py b/loki/transformations/utilities.py index cfcbe1721..5eb277324 100644 --- a/loki/transformations/utilities.py +++ b/loki/transformations/utilities.py @@ -9,12 +9,9 @@ Collection of utility routines to deal with general language conversion. """ -import os import platform from collections import defaultdict -import itertools as it from pymbolic.primitives import Expression -from loki.batch import Transformation, ProcedureItem from loki.expression import ( symbols as sym, FindVariables, FindInlineCalls, FindLiterals, SubstituteExpressions, SubstituteExpressionsMapper, ExpressionFinder, @@ -22,11 +19,11 @@ ) from loki.ir import ( nodes as ir, Import, TypeDef, VariableDeclaration, - StatementFunction, Transformer, FindNodes, CallStatement + StatementFunction, Transformer, FindNodes ) from loki.module import Module from loki.subroutine import Subroutine -from loki.tools import CaseInsensitiveDict, as_tuple, flatten +from loki.tools import CaseInsensitiveDict, as_tuple from loki.types import SymbolAttributes, BasicType, DerivedType, ProcedureType @@ -35,187 +32,10 @@ 'sanitise_imports', 'replace_selected_kind', 'single_variable_declaration', 'recursive_expression_map_update', 'get_integer_variable', 'get_loop_bounds', 'find_driver_loops', - 'get_local_arrays', 'check_routine_pragmas', 'remove_duplicate_args_from_calls', - 'modify_variable_declarations', 'RemoveDuplicateArgs', + 'get_local_arrays', 'check_routine_pragmas' ] -class RemoveDuplicateArgs(Transformation): - """ - Transformation to remove duplicate arguments for both caller - and callee. - - .. warning:: - this won't work properly for multiple calls to the same routine - with differing duplicate arguments - - Parameters - ---------- - recurse_to_kernels : bool, optional - Remove duplicate arguments only at the driver level or recurse to - (nested) kernels (Default: `True`). - rename_common : bool, optional - Try to rename dummy arguments in called routines that received the same argument - on the caller side, by finding a common name pattern in those names (Default: `False`). - """ - - # This trafo only operates on procedures - item_filter = (ProcedureItem,) - - def __init__(self, recurse_to_kernels=True, rename_common=False): - self.recurse_to_kernels = recurse_to_kernels - self.rename_common = rename_common - - def transform_subroutine(self, routine, **kwargs): - role = kwargs['role'] - if role == 'driver' or self.recurse_to_kernels: - remove_duplicate_args_from_calls(routine, rename_common=self.rename_common) - -def remove_duplicate_args_from_calls(routine, rename_common=False): - """ - Utility to remove duplicate arguments from calls in :data:`routine` - - This updates the calls as well as the called routines. It requires calls - to be enriched with interprocedural information. - - .. warning:: - this won't work properly for multiple calls to the same routine - with differing duplicate arguments - - Parameters - ---------- - routine : :any:`Subroutine` - The subroutine where calls should be transformed. - rename_common : bool, optional - Try to rename dummy arguments in called routines that received the same argument - on the caller side, by finding a common name pattern in those names (Default: `False`). - """ - - def remove_duplicate_args_call(call): - arg_map = {} - for routine_arg, call_arg in call.arg_iter(): - arg_map.setdefault(call_arg, []).append(routine_arg) - # filter duplicate kwargs (comparing to the other kwarguments) - _new_kwargs = as_tuple(list(kw_vals)[0] for g, kw_vals in it.groupby(call.kwarguments, key=lambda x: x[1])) - # filter duplicate kwargs (comparing to the arguments) - new_kwargs = tuple(kwarg for kwarg in _new_kwargs if kwarg[1] not in call.arguments) - # (filter duplicate arguments and) update call - call._update(arguments=as_tuple(dict.fromkeys(call.arguments)), kwarguments=new_kwargs) - return arg_map - - def modify_callee(callee, callee_arg_map): - - def allowed_rename(routine, rename): - # check whether rename is already "used" in routine - if rename in routine.arguments or rename in routine.variables: - return False - return True - - combine = [routine_args for call_arg, routine_args in callee_arg_map.items() if len(routine_args) > 1] - if rename_common: - matches = [ - os.path.commonprefix([str(elem.name) for elem in args]).rstrip('_') or - os.path.commonprefix([str(elem.name)[::-1] for elem in args]).rstrip('_')[::-1] - for args in combine - ] - rename_common_map = {c[0].name: m for c, m in zip(combine, matches) if m} - # check whether found rename is already "used" in routine - unallowed_renames = () - for name, rename in rename_common_map.items(): - if not allowed_rename(callee, rename): - unallowed_renames += (name,) - # and if already "used", remove and use instead default - for key in unallowed_renames: - del rename_common_map[key] - else: - rename_common_map = {} - redundant = flatten([routine_args[1:] for routine_args in combine]) - combine_map = {routine_args[0]: as_tuple(routine_args[1:]) for routine_args in combine} - arg_map = {arg.name: rename_common_map.get(common_arg.name, common_arg.name) - for common_arg, redundant_args in combine_map.items() for arg in redundant_args} - # remove duplicates from callee.arguments - new_routine_args = tuple(arg for arg in callee.arguments if arg not in redundant) - # rename if common name is possible - new_routine_args = as_tuple(arg.clone(name=rename_common_map[arg.name]) - if arg.name in rename_common_map else arg for arg in new_routine_args) - callee.arguments = new_routine_args - - # rename usage/occurences in callee.body - var_map = {} - variables = FindVariables(unique=False).visit(callee.body) - var_map = {var: var.clone(name=arg_map[var.name]) for var in variables if var.name in arg_map} - var_map.update({var: var.clone(name=rename_common_map[var.name]) for var in variables - if var.name in rename_common_map}) - callee.body = SubstituteExpressions(var_map).visit(callee.body) - # modify the variable declarations, thus remove redundant variable declarations and possibly rename - modify_variable_declarations(callee, remove_symbols=redundant, rename_symbols=rename_common_map) - # store the information for possibly later renaming kwarguments on caller side - return rename_common_map - - def rename_kwarguments(relevant_calls, rename_common_map_routine): - for call in relevant_calls: - kwarguments = call.kwarguments - if kwarguments: - call_name = str(call.routine.name).lower() - new_kwargs = as_tuple((rename_common_map_routine[call_name][kw[0]], kw[1]) - if kw[0] in rename_common_map_routine[call_name] else kw for kw in kwarguments) - call._update(kwarguments=new_kwargs) - - calls = FindNodes(CallStatement).visit(routine.body) - call_arg_map = {} - relevant_calls = [] - # adapt call statements (and remove duplicate args/kwargs) - for call in calls: - if call.routine is BasicType.DEFERRED: - continue - call_arg_map[call.routine] = remove_duplicate_args_call(call) - relevant_calls.append(call) - rename_common_map_routine = {} - # modify/adapt callees - for callee, callee_arg_map in call_arg_map.items(): - rename_common_map_routine[str(callee.name).lower()] = modify_callee(callee, callee_arg_map) - # handle possibly renamed kwarguments on caller side - if rename_common: - rename_kwarguments(relevant_calls, rename_common_map_routine) - - -def modify_variable_declarations(routine, remove_symbols=(), rename_symbols=None): - """ - Utility to modify variable declarations by either removing symbols or renaming - symbols. - - .. note:: - This utility only works on the variable declarations itself and - won't modify variable/symbol usages elsewhere! - - Parameters - ---------- - routine : :any:`Subroutine` - The subroutine to be transformed. - remove_symbols : list, tuple - List of symbols for which their declaration should be removed. - rename_symbols : dict - Dict/Map of symbols for which their declaration should be renamed. - """ - rename_symbols = rename_symbols if rename_symbols is not None else {} - var_decls = FindNodes(VariableDeclaration).visit(routine.spec) - remove_symbol_names = [var.name.lower() for var in remove_symbols] - decl_map = {} - already_declared = () - for decl in var_decls: - symbols = [symbol for symbol in decl.symbols if symbol.name.lower() not in remove_symbol_names] - symbols = [symbol.clone(name=rename_symbols[symbol.name]) - if symbol.name in rename_symbols else symbol for symbol in symbols] - symbols = [symbol for symbol in symbols if not symbol.name.lower() in already_declared] - already_declared += tuple(symbol.name.lower() for symbol in symbols) - if symbols and symbols != decl.symbols: - decl_map[decl] = decl.clone(symbols=as_tuple(symbols)) - else: - if not symbols: - decl_map[decl] = None - routine.spec = Transformer(decl_map).visit(routine.spec) - - def single_variable_declaration(routine, variables=None, group_by_shape=False): """ Modify/extend variable declarations to