-
Notifications
You must be signed in to change notification settings - Fork 12
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
RemoveDuplicateArgs: move to new file/module 'routine_signatures'
- Loading branch information
1 parent
cfb121d
commit d89f737
Showing
5 changed files
with
363 additions
and
327 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,203 @@ | ||
# (C) Copyright 2018- ECMWF. | ||
# This software is licensed under the terms of the Apache Licence Version 2.0 | ||
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. | ||
# In applying this licence, ECMWF does not waive the privileges and immunities | ||
# granted to it by virtue of its status as an intergovernmental organisation | ||
# nor does it submit to any jurisdiction. | ||
|
||
""" | ||
Collection of utilities and transformations altering routine signatures. | ||
""" | ||
|
||
import os | ||
import itertools as it | ||
from loki.batch import Transformation, ProcedureItem | ||
from loki.expression import ( | ||
FindVariables, | ||
SubstituteExpressions | ||
) | ||
from loki.ir import ( | ||
VariableDeclaration, | ||
Transformer, FindNodes, CallStatement | ||
) | ||
from loki.tools import as_tuple, flatten | ||
from loki.types import BasicType | ||
|
||
__all__ = ['RemoveDuplicateArgs', 'remove_duplicate_args_from_calls', | ||
'modify_variable_declarations'] | ||
|
||
|
||
class RemoveDuplicateArgs(Transformation): | ||
""" | ||
Transformation to remove duplicate arguments for both caller | ||
and callee. | ||
.. warning:: | ||
this won't work properly for multiple calls to the same routine | ||
with differing duplicate arguments | ||
Parameters | ||
---------- | ||
recurse_to_kernels : bool, optional | ||
Remove duplicate arguments only at the driver level or recurse to | ||
(nested) kernels (Default: `True`). | ||
rename_common : bool, optional | ||
Try to rename dummy arguments in called routines that received the same argument | ||
on the caller side, by finding a common name pattern in those names (Default: `False`). | ||
""" | ||
|
||
# This trafo only operates on procedures | ||
item_filter = (ProcedureItem,) | ||
|
||
def __init__(self, recurse_to_kernels=True, rename_common=False): | ||
self.recurse_to_kernels = recurse_to_kernels | ||
self.rename_common = rename_common | ||
|
||
def transform_subroutine(self, routine, **kwargs): | ||
role = kwargs['role'] | ||
if role == 'driver' or self.recurse_to_kernels: | ||
remove_duplicate_args_from_calls(routine, rename_common=self.rename_common) | ||
|
||
def remove_duplicate_args_from_calls(routine, rename_common=False): | ||
""" | ||
Utility to remove duplicate arguments from calls in :data:`routine` | ||
This updates the calls as well as the called routines. It requires calls | ||
to be enriched with interprocedural information. | ||
.. warning:: | ||
this won't work properly for multiple calls to the same routine | ||
with differing duplicate arguments | ||
Parameters | ||
---------- | ||
routine : :any:`Subroutine` | ||
The subroutine where calls should be transformed. | ||
rename_common : bool, optional | ||
Try to rename dummy arguments in called routines that received the same argument | ||
on the caller side, by finding a common name pattern in those names (Default: `False`). | ||
""" | ||
|
||
def remove_duplicate_args_call(call): | ||
arg_map = {} | ||
for routine_arg, call_arg in call.arg_iter(): | ||
arg_map.setdefault(call_arg, []).append(routine_arg) | ||
# filter duplicate kwargs (comparing to the other kwarguments) | ||
_new_kwargs = as_tuple(list(kw_vals)[0] for g, kw_vals in it.groupby(call.kwarguments, key=lambda x: x[1])) | ||
# filter duplicate kwargs (comparing to the arguments) | ||
new_kwargs = tuple(kwarg for kwarg in _new_kwargs if kwarg[1] not in call.arguments) | ||
# (filter duplicate arguments and) update call | ||
call._update(arguments=as_tuple(dict.fromkeys(call.arguments)), kwarguments=new_kwargs) | ||
return arg_map | ||
|
||
def modify_callee(callee, callee_arg_map): | ||
|
||
def allowed_rename(routine, rename): | ||
# check whether rename is already "used" in routine | ||
if rename in routine.arguments or rename in routine.variables: | ||
return False | ||
return True | ||
|
||
combine = [routine_args for call_arg, routine_args in callee_arg_map.items() if len(routine_args) > 1] | ||
if rename_common: | ||
matches = [ | ||
os.path.commonprefix([str(elem.name) for elem in args]).rstrip('_') or | ||
os.path.commonprefix([str(elem.name)[::-1] for elem in args]).rstrip('_')[::-1] | ||
for args in combine | ||
] | ||
rename_common_map = {c[0].name: m for c, m in zip(combine, matches) if m} | ||
# check whether found rename is already "used" in routine | ||
unallowed_renames = () | ||
for name, rename in rename_common_map.items(): | ||
if not allowed_rename(callee, rename): | ||
unallowed_renames += (name,) | ||
# and if already "used", remove and use instead default | ||
for key in unallowed_renames: | ||
del rename_common_map[key] | ||
else: | ||
rename_common_map = {} | ||
redundant = flatten([routine_args[1:] for routine_args in combine]) | ||
combine_map = {routine_args[0]: as_tuple(routine_args[1:]) for routine_args in combine} | ||
arg_map = {arg.name: rename_common_map.get(common_arg.name, common_arg.name) | ||
for common_arg, redundant_args in combine_map.items() for arg in redundant_args} | ||
# remove duplicates from callee.arguments | ||
new_routine_args = tuple(arg for arg in callee.arguments if arg not in redundant) | ||
# rename if common name is possible | ||
new_routine_args = as_tuple(arg.clone(name=rename_common_map[arg.name]) | ||
if arg.name in rename_common_map else arg for arg in new_routine_args) | ||
callee.arguments = new_routine_args | ||
|
||
# rename usage/occurences in callee.body | ||
var_map = {} | ||
variables = FindVariables(unique=False).visit(callee.body) | ||
var_map = {var: var.clone(name=arg_map[var.name]) for var in variables if var.name in arg_map} | ||
var_map.update({var: var.clone(name=rename_common_map[var.name]) for var in variables | ||
if var.name in rename_common_map}) | ||
callee.body = SubstituteExpressions(var_map).visit(callee.body) | ||
# modify the variable declarations, thus remove redundant variable declarations and possibly rename | ||
modify_variable_declarations(callee, remove_symbols=redundant, rename_symbols=rename_common_map) | ||
# store the information for possibly later renaming kwarguments on caller side | ||
return rename_common_map | ||
|
||
def rename_kwarguments(relevant_calls, rename_common_map_routine): | ||
for call in relevant_calls: | ||
kwarguments = call.kwarguments | ||
if kwarguments: | ||
call_name = str(call.routine.name).lower() | ||
new_kwargs = as_tuple((rename_common_map_routine[call_name][kw[0]], kw[1]) | ||
if kw[0] in rename_common_map_routine[call_name] else kw for kw in kwarguments) | ||
call._update(kwarguments=new_kwargs) | ||
|
||
calls = FindNodes(CallStatement).visit(routine.body) | ||
call_arg_map = {} | ||
relevant_calls = [] | ||
# adapt call statements (and remove duplicate args/kwargs) | ||
for call in calls: | ||
if call.routine is BasicType.DEFERRED: | ||
continue | ||
call_arg_map[call.routine] = remove_duplicate_args_call(call) | ||
relevant_calls.append(call) | ||
rename_common_map_routine = {} | ||
# modify/adapt callees | ||
for callee, callee_arg_map in call_arg_map.items(): | ||
rename_common_map_routine[str(callee.name).lower()] = modify_callee(callee, callee_arg_map) | ||
# handle possibly renamed kwarguments on caller side | ||
if rename_common: | ||
rename_kwarguments(relevant_calls, rename_common_map_routine) | ||
|
||
|
||
def modify_variable_declarations(routine, remove_symbols=(), rename_symbols=None): | ||
""" | ||
Utility to modify variable declarations by either removing symbols or renaming | ||
symbols. | ||
.. note:: | ||
This utility only works on the variable declarations itself and | ||
won't modify variable/symbol usages elsewhere! | ||
Parameters | ||
---------- | ||
routine : :any:`Subroutine` | ||
The subroutine to be transformed. | ||
remove_symbols : list, tuple | ||
List of symbols for which their declaration should be removed. | ||
rename_symbols : dict | ||
Dict/Map of symbols for which their declaration should be renamed. | ||
""" | ||
rename_symbols = rename_symbols if rename_symbols is not None else {} | ||
var_decls = FindNodes(VariableDeclaration).visit(routine.spec) | ||
remove_symbol_names = [var.name.lower() for var in remove_symbols] | ||
decl_map = {} | ||
already_declared = () | ||
for decl in var_decls: | ||
symbols = [symbol for symbol in decl.symbols if symbol.name.lower() not in remove_symbol_names] | ||
symbols = [symbol.clone(name=rename_symbols[symbol.name]) | ||
if symbol.name in rename_symbols else symbol for symbol in symbols] | ||
symbols = [symbol for symbol in symbols if not symbol.name.lower() in already_declared] | ||
already_declared += tuple(symbol.name.lower() for symbol in symbols) | ||
if symbols and symbols != decl.symbols: | ||
decl_map[decl] = decl.clone(symbols=as_tuple(symbols)) | ||
else: | ||
if not symbols: | ||
decl_map[decl] = None | ||
routine.spec = Transformer(decl_map).visit(routine.spec) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,155 @@ | ||
# (C) Copyright 2018- ECMWF. | ||
# This software is licensed under the terms of the Apache Licence Version 2.0 | ||
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. | ||
# In applying this licence, ECMWF does not waive the privileges and immunities | ||
# granted to it by virtue of its status as an intergovernmental organisation | ||
# nor does it submit to any jurisdiction. | ||
|
||
import pytest | ||
|
||
from loki import Module, Subroutine | ||
from loki.frontend import available_frontends | ||
from loki.ir import FindNodes, CallStatement | ||
from loki.transformations.routine_signatures import RemoveDuplicateArgs | ||
|
||
@pytest.mark.parametrize('frontend', available_frontends()) | ||
@pytest.mark.parametrize('pass_as_kwarg', (True, False)) | ||
@pytest.mark.parametrize('recurse_to_kernels', (True, False)) | ||
@pytest.mark.parametrize('rename_common', (True, False)) | ||
def test_utilities_remove_duplicate_args(tmp_path, frontend, pass_as_kwarg, recurse_to_kernels, rename_common): | ||
""" | ||
Test lowering constant array indices | ||
""" | ||
fcode_driver = f""" | ||
subroutine driver(nlon,nlev,nb,var) | ||
use kernel_mod, only: kernel | ||
implicit none | ||
integer, intent(in) :: nlon,nlev,nb | ||
real, intent(inout) :: var(nlon,nlev,5,nb) | ||
integer :: ibl | ||
integer :: offset | ||
integer :: some_val | ||
integer :: loop_start, loop_end | ||
loop_start = 2 | ||
loop_end = nb | ||
some_val = 0 | ||
offset = 1 | ||
!$omp test | ||
do ibl=loop_start, loop_end | ||
call kernel(nlon,nlev, & | ||
& {'var1=' if pass_as_kwarg else ''}var(:,:,1,ibl),& | ||
& {'var2=' if pass_as_kwarg else ''}var(:,:,1,ibl),& | ||
& {'another_var=' if pass_as_kwarg else ''}var(:,:,2:5,ibl),& | ||
& {'icend=' if pass_as_kwarg else ''}offset,& | ||
& {'lstart=' if pass_as_kwarg else ''}loop_start,& | ||
& {'lend=' if pass_as_kwarg else ''}loop_end,& | ||
& {'kend=' if pass_as_kwarg else ''}nlev) | ||
call kernel(nlon,nlev, & | ||
& {'var1=' if pass_as_kwarg else ''}var(:,:,1,ibl),& | ||
& {'var2=' if pass_as_kwarg else ''}var(:,:,1,ibl),& | ||
& {'another_var=' if pass_as_kwarg else ''}var(:,:,2:5,ibl),& | ||
& {'icend=' if pass_as_kwarg else ''}offset,& | ||
& {'lstart=' if pass_as_kwarg else ''}loop_start,& | ||
& {'lend=' if pass_as_kwarg else ''}loop_end,& | ||
& {'kend=' if pass_as_kwarg else ''}nlev) | ||
enddo | ||
end subroutine driver | ||
""" | ||
|
||
fcode_kernel = """ | ||
module kernel_mod | ||
implicit none | ||
contains | ||
subroutine kernel(nlon,nlev,var1,var2,another_var,icend,lstart,lend,kend) | ||
use compute_mod, only: compute | ||
implicit none | ||
integer, intent(in) :: nlon,nlev,icend,lstart,lend,kend | ||
real, intent(inout) :: var1(nlon,nlev) | ||
real, intent(inout) :: var2(nlon,nlev) | ||
real, intent(inout) :: another_var(nlon,nlev,4) | ||
integer :: jk, jl, jt | ||
var1(:,:) = 0. | ||
do jk = 1,kend | ||
do jl = 1, nlon | ||
var1(jl, jk) = 0. | ||
var2(jl, jk) = 1.0 | ||
do jt= 1,4 | ||
another_var(jl, jk, jt) = 0.0 | ||
end do | ||
end do | ||
end do | ||
call compute(nlon,nlev,var1, var2) | ||
call compute(nlon,nlev,var1, var2) | ||
end subroutine kernel | ||
end module kernel_mod | ||
""" | ||
|
||
fcode_nested_kernel = """ | ||
module compute_mod | ||
implicit none | ||
contains | ||
subroutine compute(nlon,nlev,b_var,a_var) | ||
implicit none | ||
integer, intent(in) :: nlon,nlev | ||
real, intent(inout) :: b_var(nlon,nlev) | ||
real, intent(inout) :: a_var(nlon,nlev) | ||
real :: VAR ! create name clash on purpose (if rename_common) | ||
b_var(:,:) = 0. | ||
a_var(:,:) = 1.0 | ||
end subroutine compute | ||
end module compute_mod | ||
""" | ||
|
||
nested_kernel_mod = Module.from_source(fcode_nested_kernel, frontend=frontend, xmods=[tmp_path]) | ||
kernel_mod = Module.from_source(fcode_kernel, frontend=frontend, definitions=nested_kernel_mod, xmods=[tmp_path]) | ||
driver = Subroutine.from_source(fcode_driver, frontend=frontend, definitions=kernel_mod, xmods=[tmp_path]) | ||
|
||
transformation = RemoveDuplicateArgs(recurse_to_kernels=recurse_to_kernels, rename_common=rename_common) | ||
transformation.apply(driver, role='driver', targets=('kernel',)) | ||
transformation.apply(kernel_mod['kernel'], role='kernel', targets=('compute',)) | ||
transformation.apply(nested_kernel_mod['compute'], role='kernel') | ||
|
||
# driver | ||
kernel_var_name = 'var' if rename_common else 'var1' | ||
kernel_calls = FindNodes(CallStatement).visit(driver.body) | ||
for kernel_call in kernel_calls: | ||
if pass_as_kwarg: | ||
assert (kernel_var_name, 'var(:, :, 1, ibl)') in kernel_call.kwarguments | ||
assert ('var2', 'var(:, :, 1, ibl)') not in kernel_call.kwarguments | ||
arg1 = kernel_call.kwarguments[0][1] | ||
arg2 = kernel_call.kwarguments[1][1] | ||
else: | ||
assert 'var(:, :, 1, ibl)' in kernel_call.arguments | ||
assert 'var2(:, :, 1, ibl)' not in kernel_call.arguments | ||
arg1 = kernel_call.arguments[2] | ||
arg2 = kernel_call.arguments[3] | ||
assert arg1.dimensions == (':', ':', '1', 'ibl') | ||
assert arg2.dimensions == (':', ':', '2:5', 'ibl') | ||
# kernel | ||
kernel_vars = kernel_mod['kernel'].variable_map | ||
kernel_args = kernel_mod['kernel']._dummies | ||
assert kernel_var_name in kernel_args | ||
assert 'var2' not in kernel_args | ||
assert 'var2' not in kernel_vars | ||
assert kernel_vars[kernel_var_name].shape == ('nlon', 'nlev') | ||
assert kernel_vars['another_var'].dimensions == ('nlon', 'nlev', 4) | ||
compute_calls = FindNodes(CallStatement).visit(kernel_mod['kernel'].body) | ||
for compute_call in compute_calls: | ||
assert kernel_var_name in compute_call.arguments | ||
assert 'var2' not in compute_call.arguments | ||
# nested_kernel | ||
nested_kernel = nested_kernel_mod['compute'] | ||
nested_kernel_vars = nested_kernel.variable_map | ||
nested_kernel_args = [arg.name.lower() for arg in nested_kernel.arguments] | ||
# it's always 'b_var' as a rename would clash with the already "used" variable "var" | ||
nested_kernel_var_name = 'b_var' | ||
if recurse_to_kernels: | ||
assert nested_kernel_var_name in nested_kernel_args | ||
assert 'a_var' not in nested_kernel_args | ||
assert nested_kernel_var_name in nested_kernel_vars | ||
assert 'a_var' not in nested_kernel_vars | ||
else: | ||
assert 'b_var' in nested_kernel_args | ||
assert 'a_var' in nested_kernel_args | ||
assert 'b_var' in nested_kernel_vars | ||
assert 'a_var' in nested_kernel_vars |
Oops, something went wrong.