Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow inlining of Statement Functions #345

Merged
merged 7 commits into from
Jul 24, 2024
67 changes: 60 additions & 7 deletions loki/transformations/inline.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@
from loki.batch import Transformation
from loki.ir import (
Import, Comment, Assignment, VariableDeclaration, CallStatement,
Transformer, FindNodes, pragmas_attached, is_loki_pragma, Interface
Transformer, FindNodes, pragmas_attached, is_loki_pragma, Interface,
StatementFunction
)
from loki.expression import (
symbols as sym, FindVariables, FindInlineCalls, FindLiterals,
Expand All @@ -36,7 +37,8 @@
__all__ = [
'inline_constant_parameters', 'inline_elemental_functions',
'inline_internal_procedures', 'inline_member_procedures',
'inline_marked_subroutines', 'InlineTransformation'
'inline_marked_subroutines', 'InlineTransformation',
'inline_statement_functions'
]


Expand All @@ -54,6 +56,10 @@ class InlineTransformation(Transformation):
Replaces :any:`InlineCall` expression to elemental functions
with the called function's body (see :any:`inline_elemental_functions`);
default: True.
inline_stmt_funcs: bool
Replaces :any:`InlineCall` expression to statement functions
with the corresponding rhs of the statement function if
the statement function declaration is available; default: False.
inline_internals : bool
Inline internal procedure (see :any:`inline_internal_procedures`);
default: False.
Expand Down Expand Up @@ -84,13 +90,14 @@ class InlineTransformation(Transformation):

def __init__(
self, inline_constants=False, inline_elementals=True,
inline_internals=False, inline_marked=True,
remove_dead_code=True, allowed_aliases=None,
adjust_imports=True, external_only=True,
resolve_sequence_association=False
inline_stmt_funcs=False, inline_internals=False,
inline_marked=True, remove_dead_code=True,
allowed_aliases=None, adjust_imports=True,
external_only=True, resolve_sequence_association=False
):
self.inline_constants = inline_constants
self.inline_elementals = inline_elementals
self.inline_stmt_funcs = inline_stmt_funcs
self.inline_internals = inline_internals
self.inline_marked = inline_marked
self.remove_dead_code = remove_dead_code
Expand Down Expand Up @@ -121,6 +128,10 @@ def transform_subroutine(self, routine, **kwargs):
if self.inline_elementals:
inline_elemental_functions(routine)

# Inline Statement Functions
if self.inline_stmt_funcs:
inline_statement_functions(routine)

Comment on lines +131 to +134
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This code path is currently untested. Can we add a test where InlineTransformation is called to apply the statement function inlining?

# Inline internal (contained) procedures
if self.inline_internals:
inline_internal_procedures(routine, allowed_aliases=self.allowed_aliases)
Expand Down Expand Up @@ -184,11 +195,19 @@ def map_inline_call(self, expr, *args, **kwargs):
# We still need to recurse and ensure re-scoping
return super().map_inline_call(expr, *args, **kwargs)

# if it is an inline call to a Statement Function
if isinstance(expr.routine, StatementFunction):
function = expr.routine
# Substitute all arguments through the elemental body
arg_map = dict(expr.arg_iter())
fbody = SubstituteExpressions(arg_map).visit(function.rhs)
return fbody

function = expr.procedure_type.procedure
v_result = [v for v in function.variables if v == function.name][0]

# Substitute all arguments through the elemental body
arg_map = dict(zip(function.arguments, expr.parameters))
arg_map = dict(expr.arg_iter())
fbody = SubstituteExpressions(arg_map).visit(function.body)

# Extract the RHS of the final result variable assignment
Expand Down Expand Up @@ -338,6 +357,40 @@ def inline_elemental_functions(routine):
import_map[im] = None
routine.spec = Transformer(import_map).visit(routine.spec)

def inline_statement_functions(routine):
"""
Replaces :any:`InlineCall` expression to statement functions with the
called statement functions rhs.
"""
# Keep track of removed symbols
removed_functions = set()

stmt_func_decls = FindNodes(StatementFunction).visit(routine.spec)
exprmap = {}
for call in FindInlineCalls().visit(routine.body):
proc_type = call.procedure_type
if proc_type is BasicType.DEFERRED:
continue
if proc_type.is_function and isinstance(call.routine, StatementFunction):
exprmap[call] = InlineSubstitutionMapper()(call, scope=routine)
removed_functions.add(call.routine)
# Apply the map to itself to handle nested statement function calls
exprmap = recursive_expression_map_update(exprmap, max_iterations=10, mapper_cls=InlineSubstitutionMapper)
# Apply expression-level substitution to routine
routine.body = SubstituteExpressions(exprmap).visit(routine.body)

# remove statement function declarations as well as statement function argument(s) declarations
vars_to_remove = {stmt_func.variable.name.lower() for stmt_func in stmt_func_decls}
vars_to_remove |= {arg.name.lower() for stmt_func in stmt_func_decls for arg in stmt_func.arguments}
spec_map = {stmt_func: None for stmt_func in stmt_func_decls}
for decl in routine.declarations:
if any(var in vars_to_remove for var in decl.symbols):
symbols = tuple(var for var in decl.symbols if var not in vars_to_remove)
if symbols:
decl._update(symbols=symbols)
else:
spec_map[decl] = None
routine.spec = Transformer(spec_map).visit(routine.spec)

def map_call_to_procedure_body(call, caller):
"""
Expand Down
71 changes: 64 additions & 7 deletions loki/transformations/tests/test_inline.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,13 @@
from loki.transformations.inline import (
inline_elemental_functions, inline_constant_parameters,
inline_member_procedures, inline_marked_subroutines,
InlineTransformation,
inline_statement_functions, InlineTransformation,
)
from loki.transformations.sanitise import ResolveAssociatesTransformer
from loki.transformations.utilities import replace_selected_kind

# pylint: disable=too-many-lines


@pytest.fixture(name='builder')
def fixture_builder(tmp_path):
Expand Down Expand Up @@ -700,6 +702,53 @@ def test_inline_member_routines_with_associate(frontend):
assert len(assocs) == 2


@pytest.mark.parametrize('frontend', available_frontends(
skip={OFP: "OFP apparently has problems dealing with those Statement Functions",
OMNI: "OMNI automatically inlines Statement Functions"}
))
@pytest.mark.parametrize('stmt_decls', (True, False))
def test_inline_statement_functions(frontend, stmt_decls):
stmt_decls_code = """
real :: PTARE
real :: FOEDELTA
FOEDELTA ( PTARE ) = PTARE + 1.0
real :: FOEEW
FOEEW ( PTARE ) = PTARE + FOEDELTA(PTARE)
""".strip()

fcode = f"""
subroutine stmt_func(arr, ret)
implicit none
real, intent(in) :: arr(:)
real, intent(inout) :: ret(:)
real :: ret2
real, parameter :: rtt = 1.0
{stmt_decls_code if stmt_decls else '#include "fcttre.func.h"'}
ret = foeew(arr)
ret2 = foedelta(3.0)
end subroutine stmt_func
""".strip()

routine = Subroutine.from_source(fcode, frontend=frontend)
if stmt_decls:
assert FindNodes(ir.StatementFunction).visit(routine.spec)
else:
assert not FindNodes(ir.StatementFunction).visit(routine.spec)
assert FindInlineCalls().visit(routine.body)
inline_statement_functions(routine)

assert not FindNodes(ir.StatementFunction).visit(routine.spec)
if stmt_decls:
assert not FindInlineCalls().visit(routine.body)
assignments = FindNodes(ir.Assignment).visit(routine.body)
assert assignments[0].lhs == 'ret'
assert assignments[0].rhs == "arr + arr + 1.0"
assert assignments[1].lhs == 'ret2'
assert assignments[1].rhs == "3.0 + 1.0"
else:
assert FindInlineCalls().visit(routine.body)

@pytest.mark.parametrize('frontend', available_frontends())
@pytest.mark.parametrize('adjust_imports', [True, False])
def test_inline_marked_subroutines(frontend, adjust_imports, tmp_path):
Expand Down Expand Up @@ -1095,7 +1144,8 @@ def test_inline_marked_subroutines_declarations(frontend, tmp_path):
@pytest.mark.parametrize('frontend', available_frontends(
(OFP, 'Prefix/elemental support not implemented'))
)
def test_inline_transformation(frontend, tmp_path):
@pytest.mark.parametrize('pass_as_kwarg', (False, True))
def test_inline_transformation(tmp_path, frontend, pass_as_kwarg):
"""Test combining recursive inlining via :any:`InliningTransformation`."""

fcode_module = """
Expand Down Expand Up @@ -1125,25 +1175,30 @@ def test_inline_transformation(frontend, tmp_path):
end subroutine add_one_and_two
"""

fcode = """
fcode = f"""
subroutine test_inline_pragma(a, b)
implicit none
real(kind=8), intent(inout) :: a(3), b(3)
integer, parameter :: n = 3
integer :: i
real :: stmt_arg
real :: some_stmt_func
some_stmt_func ( stmt_arg ) = stmt_arg + 3.1415
#include "add_one_and_two.intfb.h"
do i=1, n
!$loki inline
call add_one_and_two(a(i))
call add_one_and_two({'a=' if pass_as_kwarg else ''}a(i))
end do
do i=1, n
!$loki inline
call add_one_and_two(b(i))
call add_one_and_two({'a=' if pass_as_kwarg else ''}b(i))
end do
a(1) = some_stmt_func({'stmt_arg=' if pass_as_kwarg else ''}a(2))
end subroutine test_inline_pragma
"""
module = Module.from_source(fcode_module, frontend=frontend, xmods=[tmp_path])
Expand All @@ -1152,7 +1207,8 @@ def test_inline_transformation(frontend, tmp_path):
routine.enrich(inner)

trafo = InlineTransformation(
inline_constants=True, external_only=True, inline_elementals=True
inline_constants=True, external_only=True, inline_elementals=True,
inline_stmt_funcs=True
)

calls = FindNodes(ir.CallStatement).visit(routine.body)
Expand All @@ -1173,11 +1229,12 @@ def test_inline_transformation(frontend, tmp_path):
calls = FindNodes(ir.CallStatement).visit(routine.body)
assert len(calls) == 0
assigns = FindNodes(ir.Assignment).visit(routine.body)
assert len(assigns) == 4
assert len(assigns) == 5
assert assigns[0].lhs == 'a(i)' and assigns[0].rhs == 'a(i) + 1.0'
assert assigns[1].lhs == 'a(i)' and assigns[1].rhs == 'a(i) + 2.0'
assert assigns[2].lhs == 'b(i)' and assigns[2].rhs == 'b(i) + 1.0'
assert assigns[3].lhs == 'b(i)' and assigns[3].rhs == 'b(i) + 2.0'
assert assigns[4].lhs == 'a(1)' and assigns[4].rhs == 'a(2) + 3.1415'


@pytest.mark.parametrize('frontend', available_frontends())
Expand Down
6 changes: 4 additions & 2 deletions loki/transformations/utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -468,7 +468,7 @@ def replace_selected_kind(routine):
routine.spec.prepend(imprt)


def recursive_expression_map_update(expr_map, max_iterations=10):
def recursive_expression_map_update(expr_map, max_iterations=10, mapper_cls=SubstituteExpressionsMapper):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No action: Very neat! 👏

"""
Utility function to apply a substitution map for expressions to itself
Expand All @@ -490,6 +490,8 @@ def recursive_expression_map_update(expr_map, max_iterations=10):
max_iterations : int
Maximum number of iterations, corresponds to the maximum level of
nesting that can be replaced.
mapper_cls: :any:`SubstituteExpressionsMapper`
The underlying mapper to be used (default: :any:`SubstituteExpressionsMapper`).
"""
def apply_to_init_arg(name, arg, expr, mapper):
# Helper utility to apply the mapper only to expression arguments and
Expand All @@ -504,7 +506,7 @@ def apply_to_init_arg(name, arg, expr, mapper):
# We update the expression map by applying it to the children of each replacement
# node, thus making sure node replacements are also applied to nested attributes,
# e.g. call arguments or array subscripts etc.
mapper = SubstituteExpressionsMapper(expr_map)
mapper = mapper_cls(expr_map)
prev_map, expr_map = expr_map, {
expr: type(replacement)(**{
name: apply_to_init_arg(name, arg, expr, mapper)
Expand Down
Loading