Skip to content

Commit

Permalink
introduce flag to allow removing all derived types from procedure sig…
Browse files Browse the repository at this point in the history
…natures (via replacing by its members)
  • Loading branch information
MichaelSt98 committed Jan 22, 2024
1 parent c86b83a commit bdf3993
Show file tree
Hide file tree
Showing 2 changed files with 94 additions and 3 deletions.
79 changes: 79 additions & 0 deletions transformations/tests/test_transform_derived_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,85 @@ def test_transform_derived_type_arguments_analysis(frontend):
assert member.type.dtype != BasicType.DEFERRED


@pytest.mark.parametrize('all_derived_types', (False, True))
@pytest.mark.parametrize('frontend', available_frontends())
def test_transform_derived_type_arguments_expansion_trivial_derived_type(frontend, all_derived_types):
fcode = """
module transform_derived_type_arguments_mod
implicit none
type some_derived_type
real :: a
real :: b
end type some_derived_type
contains
subroutine caller(z)
integer, intent(in) :: z
type(some_derived_type) :: t_io
type(some_derived_type) :: t_in, t_out
integer :: m, n
integer :: i, j
m = 100
n = 10
t_in%a = real(m-1)
t_in%b = real(n-1)
call kernel(m, n, t_io%a, t_io%b, t_in, t_out)
end subroutine caller
subroutine kernel(m, n, P_a, P_b, Q, R)
integer , intent(in) :: m, n
real, intent(inout) :: P_a, P_b
type(some_derived_type), intent(in) :: Q
type(some_derived_type), intent(out) :: R
integer :: j, k
R%a = P_a + Q%a
R%b = P_b - Q%b
end subroutine kernel
end module transform_derived_type_arguments_mod
""".strip()

source = Sourcefile.from_source(fcode, frontend=frontend)

call_tree = [
SubroutineItem(name='transform_derived_type_arguments_mod#caller', source=source, config={'role': 'driver'}),
SubroutineItem(name='transform_derived_type_arguments_mod#kernel', source=source, config={'role': 'kernel'}),
]

# Apply transformation
transformation = DerivedTypeArgumentsTransformation(all_derived_types=all_derived_types)
for item, successor in reversed(list(zip_longest(call_tree, call_tree[1:]))):
transformation.apply(item.routine, role=item.role, item=item, successors=as_tuple(successor))

# all derived types, disregarding whether the derived type has pointer/allocatable/derived type members or not
if all_derived_types:
call_args = ('m', 'n', 't_io%a', 't_io%b', 't_in%a', 't_in%b', 't_out%a', 't_out%b')
kernel_args = ('m', 'n', 'P_a', 'P_b', 'Q_a', 'Q_b', 'R_a', 'R_b')
# only the derived type(s) with pointer/allocatable/derived type members, thus no changes expected!
else:
call_args = ('m', 'n', 't_io%a', 't_io%b', 't_in', 't_out')
kernel_args = ('m', 'n', 'P_a', 'P_b', 'Q', 'R')

call = FindNodes(CallStatement).visit(source['caller'].ir)[0]
assert call.name == 'kernel'
assert call.arguments == call_args
assert source['kernel'].arguments == kernel_args
assert all(v.type.intent for v in source['kernel'].arguments)

# Make sure rescoping hasn't accidentally overwritten the
# type information for local variables that have the same name
# as the shape of another variable
assert source['caller'].variable_map['m'].type.intent is None
assert source['caller'].variable_map['n'].type.intent is None


@pytest.mark.parametrize('frontend', available_frontends())
def test_transform_derived_type_arguments_expansion(frontend):
fcode = f"""
Expand Down
18 changes: 15 additions & 3 deletions transformations/transformations/derived_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
class DerivedTypeArgumentsTransformation(Transformation):
"""
Remove derived types from procedure signatures by replacing the
relevant derived type arguments by its member variables
(relevant) derived type arguments by its member variables
.. note::
This transformation requires a Scheduler traversal that
Expand All @@ -47,11 +47,22 @@ class DerivedTypeArgumentsTransformation(Transformation):
the routine. The information about the expansion map is stored
in the :any:`Item`'s ``trafo_data``.
See :meth:`expand_derived_args_kernel` for more information.
Parameters
----------
all_derived_types : bool, optional
Whether to remove all derived types from procedure signatures by
replacing the derived type arguments using its member variables or
only the "relevant" ones, referring to derived types with array
members or nested derived types (default: `False`).
key : str, optional
Overwrite the key that is used to store analysis results in ``trafo_data``.
"""

_key = 'DerivedTypeArgumentsTransformation'

def __init__(self, key=None, **kwargs):
def __init__(self, all_derived_types=False, key=None, **kwargs):
self.all_derived_types = all_derived_types
if key is not None:
self._key = key
super().__init__(**kwargs)
Expand Down Expand Up @@ -270,9 +281,10 @@ def expand_derived_args_kernel(self, routine):
candidates = []
for arg in routine.arguments:
if isinstance(arg.type.dtype, DerivedType):
if any(v.type.pointer or v.type.allocatable or
if self.all_derived_types or any(v.type.pointer or v.type.allocatable or
isinstance(v.type.dtype, DerivedType) for v in as_tuple(arg.variables)):
# Only include derived types with array members or nested derived types
# unless self.all_derived_types is True
candidates += [arg]

# Inspect all derived type member use and determine their expansion
Expand Down

0 comments on commit bdf3993

Please sign in to comment.