Skip to content

Commit

Permalink
RemoveDuplicateArgs: solution for possible name clashes if 'rename_co…
Browse files Browse the repository at this point in the history
…mmon'
  • Loading branch information
MichaelSt98 committed Oct 8, 2024
1 parent d16279d commit cfb121d
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 1 deletion.
4 changes: 3 additions & 1 deletion loki/transformations/tests/test_utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -629,6 +629,7 @@ def test_utilities_remove_duplicate_args(tmp_path, frontend, pass_as_kwarg, recu
integer, intent(in) :: nlon,nlev
real, intent(inout) :: b_var(nlon,nlev)
real, intent(inout) :: a_var(nlon,nlev)
real :: VAR ! create name clash on purpose (if rename_common)
b_var(:,:) = 0.
a_var(:,:) = 1.0
end subroutine compute
Expand Down Expand Up @@ -676,7 +677,8 @@ def test_utilities_remove_duplicate_args(tmp_path, frontend, pass_as_kwarg, recu
nested_kernel = nested_kernel_mod['compute']
nested_kernel_vars = nested_kernel.variable_map
nested_kernel_args = [arg.name.lower() for arg in nested_kernel.arguments]
nested_kernel_var_name = 'var' if rename_common else 'b_var'
# it's always 'b_var' as a rename would clash with the already "used" variable "var"
nested_kernel_var_name = 'b_var'
if recurse_to_kernels:
assert nested_kernel_var_name in nested_kernel_args
assert 'a_var' not in nested_kernel_args
Expand Down
15 changes: 15 additions & 0 deletions loki/transformations/utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,13 @@ def remove_duplicate_args_call(call):
return arg_map

def modify_callee(callee, callee_arg_map):

def allowed_rename(routine, rename):
# check whether rename is already "used" in routine
if rename in routine.arguments or rename in routine.variables:
return False
return True

combine = [routine_args for call_arg, routine_args in callee_arg_map.items() if len(routine_args) > 1]
if rename_common:
matches = [
Expand All @@ -112,6 +119,14 @@ def modify_callee(callee, callee_arg_map):
for args in combine
]
rename_common_map = {c[0].name: m for c, m in zip(combine, matches) if m}
# check whether found rename is already "used" in routine
unallowed_renames = ()
for name, rename in rename_common_map.items():
if not allowed_rename(callee, rename):
unallowed_renames += (name,)
# and if already "used", remove and use instead default
for key in unallowed_renames:
del rename_common_map[key]
else:
rename_common_map = {}
redundant = flatten([routine_args[1:] for routine_args in combine])
Expand Down

0 comments on commit cfb121d

Please sign in to comment.