Skip to content

Commit

Permalink
DerivedTypeArgumentsTransformation using 'reverse_traversal=True' wit…
Browse files Browse the repository at this point in the history
…h test using the scheduler
  • Loading branch information
MichaelSt98 committed Mar 5, 2024
1 parent 2520c10 commit 6b8c210
Show file tree
Hide file tree
Showing 5 changed files with 101 additions and 0 deletions.
24 changes: 24 additions & 0 deletions transformations/tests/sources/projDerivedTypes/driver_mod.F90
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
module driver_mod

use some_derived_type_mod, only: some_derived_type
use kernel_mod, only: kernel
implicit none

contains
subroutine driver(z)
integer, intent(in) :: z
type(some_derived_type) :: t_io
type(some_derived_type) :: t_in, t_out
integer :: m, n
integer :: i, j

m = 100
n = 10

t_in%a = real(m-1)
t_in%b = real(n-1)

call kernel(m, n, t_io%a, t_io%b, t_in, t_out)

end subroutine driver
end module driver_mod
19 changes: 19 additions & 0 deletions transformations/tests/sources/projDerivedTypes/kernel_mod.F90
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
module kernel_mod

use some_derived_type_mod, only: some_derived_type
implicit none

contains

subroutine kernel(m, n, P_a, P_b, Q, R)
integer , intent(in) :: m, n
real, intent(inout) :: P_a, P_b
type(some_derived_type), intent(in) :: Q
type(some_derived_type), intent(out) :: R
integer :: j, k

R%a = P_a + Q%a
R%b = P_b - Q%b
end subroutine kernel

end module kernel_mod
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
module some_derived_type_mod

implicit none

type some_derived_type
real :: a
real :: b
end type some_derived_type

end module some_derived_type_mod
44 changes: 44 additions & 0 deletions transformations/tests/test_transform_derived_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

from pathlib import Path
from itertools import zip_longest
from shutil import rmtree
import pytest
Expand All @@ -20,6 +21,11 @@
DerivedTypeArgumentsTransformation,
TypeboundProcedureCallTransformation
)
#pylint: disable=too-many-lines

@pytest.fixture(scope='module', name='here')
def fixture_here():
return Path(__file__).parent


@pytest.fixture(name='config')
Expand Down Expand Up @@ -179,6 +185,44 @@ def test_transform_derived_type_arguments_expansion_trivial_derived_type(fronten
assert source['caller'].variable_map['n'].type.intent is None


@pytest.mark.parametrize('all_derived_types', (False, True))
@pytest.mark.parametrize('frontend', available_frontends())
def test_transform_derived_type_arguments_expansion_trivial_derived_type_scheduler(frontend, all_derived_types,
config, here):

proj = here / 'sources/projDerivedTypes'

scheduler = Scheduler(paths=[proj], config=config, seed_routines=['driver'], frontend=frontend)

# Apply transformation
transformation = DerivedTypeArgumentsTransformation(all_derived_types=all_derived_types)
scheduler.process(transformation=transformation)

# all derived types, disregarding whether the derived type has pointer/allocatable/derived type members or not
if all_derived_types:
call_args = ('m', 'n', 't_io%a', 't_io%b', 't_in%a', 't_in%b', 't_out%a', 't_out%b')
kernel_args = ('m', 'n', 'P_a', 'P_b', 'Q_a', 'Q_b', 'R_a', 'R_b')
# only the derived type(s) with pointer/allocatable/derived type members, thus no changes expected!
else:
call_args = ('m', 'n', 't_io%a', 't_io%b', 't_in', 't_out')
kernel_args = ('m', 'n', 'P_a', 'P_b', 'Q', 'R')

driver = scheduler["driver_mod#driver"].ir
kernel = scheduler["kernel_mod#kernel"].ir
calls = FindNodes(CallStatement).visit(driver.body)
call = calls[0]
assert call.name == 'kernel'
assert call.arguments == call_args
assert kernel.arguments == kernel_args
assert all(v.type.intent for v in kernel.arguments)

# Make sure rescoping hasn't accidentally overwritten the
# type information for local variables that have the same name
# as the shape of another variable
assert driver.variable_map['m'].type.intent is None
assert driver.variable_map['n'].type.intent is None


@pytest.mark.parametrize('frontend', available_frontends())
def test_transform_derived_type_arguments_expansion(frontend):
fcode = f"""
Expand Down
4 changes: 4 additions & 0 deletions transformations/transformations/derived_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,10 @@ class DerivedTypeArgumentsTransformation(Transformation):
"""

_key = 'DerivedTypeArgumentsTransformation'
"""Default identifier for trafo_data entry"""

reverse_traversal = True
"""Traversal from the leaves upwards"""

def __init__(self, all_derived_types=False, key=None, **kwargs):
self.all_derived_types = all_derived_types
Expand Down

0 comments on commit 6b8c210

Please sign in to comment.