From ae9d15bff9ae800e5dc13e0d232f8df8fbddc428 Mon Sep 17 00:00:00 2001 From: Michael Staneker Date: Tue, 5 Mar 2024 10:14:34 +0000 Subject: [PATCH] DerivedTypeArgumentsTransformation using 'reverse_traversal=True' with test using the scheduler --- .../sources/projDerivedTypes/driver_mod.F90 | 24 ++++++++++ .../sources/projDerivedTypes/kernel_mod.F90 | 19 ++++++++ .../projDerivedTypes/some_derived_type.F90 | 10 +++++ .../tests/test_transform_derived_types.py | 44 +++++++++++++++++++ .../transformations/derived_types.py | 4 ++ 5 files changed, 101 insertions(+) create mode 100644 transformations/tests/sources/projDerivedTypes/driver_mod.F90 create mode 100644 transformations/tests/sources/projDerivedTypes/kernel_mod.F90 create mode 100644 transformations/tests/sources/projDerivedTypes/some_derived_type.F90 diff --git a/transformations/tests/sources/projDerivedTypes/driver_mod.F90 b/transformations/tests/sources/projDerivedTypes/driver_mod.F90 new file mode 100644 index 000000000..cc95a3ba9 --- /dev/null +++ b/transformations/tests/sources/projDerivedTypes/driver_mod.F90 @@ -0,0 +1,24 @@ +module driver_mod + +use some_derived_type_mod, only: some_derived_type +use kernel_mod, only: kernel +implicit none + +contains + subroutine driver(z) + integer, intent(in) :: z + type(some_derived_type) :: t_io + type(some_derived_type) :: t_in, t_out + integer :: m, n + integer :: i, j + + m = 100 + n = 10 + + t_in%a = real(m-1) + t_in%b = real(n-1) + + call kernel(m, n, t_io%a, t_io%b, t_in, t_out) + + end subroutine driver +end module driver_mod diff --git a/transformations/tests/sources/projDerivedTypes/kernel_mod.F90 b/transformations/tests/sources/projDerivedTypes/kernel_mod.F90 new file mode 100644 index 000000000..05f60eb2b --- /dev/null +++ b/transformations/tests/sources/projDerivedTypes/kernel_mod.F90 @@ -0,0 +1,19 @@ +module kernel_mod + +use some_derived_type_mod, only: some_derived_type +implicit none + +contains + + subroutine kernel(m, n, P_a, P_b, Q, R) + integer , intent(in) :: m, n + real, intent(inout) :: P_a, P_b + type(some_derived_type), intent(in) :: Q + type(some_derived_type), intent(out) :: R + integer :: j, k + + R%a = P_a + Q%a + R%b = P_b - Q%b + end subroutine kernel + +end module kernel_mod diff --git a/transformations/tests/sources/projDerivedTypes/some_derived_type.F90 b/transformations/tests/sources/projDerivedTypes/some_derived_type.F90 new file mode 100644 index 000000000..0cf82f4c4 --- /dev/null +++ b/transformations/tests/sources/projDerivedTypes/some_derived_type.F90 @@ -0,0 +1,10 @@ +module some_derived_type_mod + +implicit none + + type some_derived_type + real :: a + real :: b + end type some_derived_type + +end module some_derived_type_mod diff --git a/transformations/tests/test_transform_derived_types.py b/transformations/tests/test_transform_derived_types.py index 2519c6b56..8cd21899e 100644 --- a/transformations/tests/test_transform_derived_types.py +++ b/transformations/tests/test_transform_derived_types.py @@ -5,6 +5,7 @@ # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. +from pathlib import Path from itertools import zip_longest from shutil import rmtree import pytest @@ -20,6 +21,11 @@ DerivedTypeArgumentsTransformation, TypeboundProcedureCallTransformation ) +#pylint: disable=too-many-lines + +@pytest.fixture(scope='module', name='here') +def fixture_here(): + return Path(__file__).parent @pytest.fixture(name='config') @@ -180,6 +186,44 @@ def test_transform_derived_type_arguments_expansion_trivial_derived_type(fronten assert source['caller'].variable_map['n'].type.intent is None +@pytest.mark.parametrize('all_derived_types', (False, True)) +@pytest.mark.parametrize('frontend', available_frontends()) +def test_transform_derived_type_arguments_expansion_trivial_derived_type_scheduler(frontend, all_derived_types, + config, here): + + proj = here / 'sources/projDerivedTypes' + + scheduler = Scheduler(paths=[proj], config=config, seed_routines=['driver'], frontend=frontend) + + # Apply transformation + transformation = DerivedTypeArgumentsTransformation(all_derived_types=all_derived_types) + scheduler.process(transformation=transformation) + + # all derived types, disregarding whether the derived type has pointer/allocatable/derived type members or not + if all_derived_types: + call_args = ('m', 'n', 't_io%a', 't_io%b', 't_in%a', 't_in%b', 't_out%a', 't_out%b') + kernel_args = ('m', 'n', 'P_a', 'P_b', 'Q_a', 'Q_b', 'R_a', 'R_b') + # only the derived type(s) with pointer/allocatable/derived type members, thus no changes expected! + else: + call_args = ('m', 'n', 't_io%a', 't_io%b', 't_in', 't_out') + kernel_args = ('m', 'n', 'P_a', 'P_b', 'Q', 'R') + + driver = scheduler["driver_mod#driver"].ir + kernel = scheduler["kernel_mod#kernel"].ir + calls = FindNodes(CallStatement).visit(driver.body) + call = calls[0] + assert call.name == 'kernel' + assert call.arguments == call_args + assert kernel.arguments == kernel_args + assert all(v.type.intent for v in kernel.arguments) + + # Make sure rescoping hasn't accidentally overwritten the + # type information for local variables that have the same name + # as the shape of another variable + assert driver.variable_map['m'].type.intent is None + assert driver.variable_map['n'].type.intent is None + + @pytest.mark.parametrize('frontend', available_frontends()) def test_transform_derived_type_arguments_expansion(frontend): fcode = f""" diff --git a/transformations/transformations/derived_types.py b/transformations/transformations/derived_types.py index 769a6e962..9f11cfc7c 100644 --- a/transformations/transformations/derived_types.py +++ b/transformations/transformations/derived_types.py @@ -60,6 +60,10 @@ class DerivedTypeArgumentsTransformation(Transformation): """ _key = 'DerivedTypeArgumentsTransformation' + """Default identifier for trafo_data entry""" + + reverse_traversal = True + """Traversal from the leaves upwards""" def __init__(self, all_derived_types=False, key=None, **kwargs): self.all_derived_types = all_derived_types