Skip to content

Commit

Permalink
Merge pull request #345 from ecmwf-ifs/nams-inline-stmt-funcs
Browse files Browse the repository at this point in the history
Allow inlining of Statement Functions
  • Loading branch information
reuterbal authored Jul 24, 2024
2 parents 78e7605 + a29f729 commit 7ca9dd8
Show file tree
Hide file tree
Showing 3 changed files with 128 additions and 16 deletions.
67 changes: 60 additions & 7 deletions loki/transformations/inline.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@
from loki.batch import Transformation
from loki.ir import (
Import, Comment, Assignment, VariableDeclaration, CallStatement,
Transformer, FindNodes, pragmas_attached, is_loki_pragma, Interface
Transformer, FindNodes, pragmas_attached, is_loki_pragma, Interface,
StatementFunction
)
from loki.expression import (
symbols as sym, FindVariables, FindInlineCalls, FindLiterals,
Expand All @@ -36,7 +37,8 @@
__all__ = [
'inline_constant_parameters', 'inline_elemental_functions',
'inline_internal_procedures', 'inline_member_procedures',
'inline_marked_subroutines', 'InlineTransformation'
'inline_marked_subroutines', 'InlineTransformation',
'inline_statement_functions'
]


Expand All @@ -54,6 +56,10 @@ class InlineTransformation(Transformation):
Replaces :any:`InlineCall` expression to elemental functions
with the called function's body (see :any:`inline_elemental_functions`);
default: True.
inline_stmt_funcs: bool
Replaces :any:`InlineCall` expression to statement functions
with the corresponding rhs of the statement function if
the statement function declaration is available; default: False.
inline_internals : bool
Inline internal procedure (see :any:`inline_internal_procedures`);
default: False.
Expand Down Expand Up @@ -84,13 +90,14 @@ class InlineTransformation(Transformation):

def __init__(
self, inline_constants=False, inline_elementals=True,
inline_internals=False, inline_marked=True,
remove_dead_code=True, allowed_aliases=None,
adjust_imports=True, external_only=True,
resolve_sequence_association=False
inline_stmt_funcs=False, inline_internals=False,
inline_marked=True, remove_dead_code=True,
allowed_aliases=None, adjust_imports=True,
external_only=True, resolve_sequence_association=False
):
self.inline_constants = inline_constants
self.inline_elementals = inline_elementals
self.inline_stmt_funcs = inline_stmt_funcs
self.inline_internals = inline_internals
self.inline_marked = inline_marked
self.remove_dead_code = remove_dead_code
Expand Down Expand Up @@ -121,6 +128,10 @@ def transform_subroutine(self, routine, **kwargs):
if self.inline_elementals:
inline_elemental_functions(routine)

# Inline Statement Functions
if self.inline_stmt_funcs:
inline_statement_functions(routine)

# Inline internal (contained) procedures
if self.inline_internals:
inline_internal_procedures(routine, allowed_aliases=self.allowed_aliases)
Expand Down Expand Up @@ -184,11 +195,19 @@ def map_inline_call(self, expr, *args, **kwargs):
# We still need to recurse and ensure re-scoping
return super().map_inline_call(expr, *args, **kwargs)

# if it is an inline call to a Statement Function
if isinstance(expr.routine, StatementFunction):
function = expr.routine
# Substitute all arguments through the elemental body
arg_map = dict(expr.arg_iter())
fbody = SubstituteExpressions(arg_map).visit(function.rhs)
return fbody

function = expr.procedure_type.procedure
v_result = [v for v in function.variables if v == function.name][0]

# Substitute all arguments through the elemental body
arg_map = dict(zip(function.arguments, expr.parameters))
arg_map = dict(expr.arg_iter())
fbody = SubstituteExpressions(arg_map).visit(function.body)

# Extract the RHS of the final result variable assignment
Expand Down Expand Up @@ -338,6 +357,40 @@ def inline_elemental_functions(routine):
import_map[im] = None
routine.spec = Transformer(import_map).visit(routine.spec)

def inline_statement_functions(routine):
"""
Replaces :any:`InlineCall` expression to statement functions with the
called statement functions rhs.
"""
# Keep track of removed symbols
removed_functions = set()

stmt_func_decls = FindNodes(StatementFunction).visit(routine.spec)
exprmap = {}
for call in FindInlineCalls().visit(routine.body):
proc_type = call.procedure_type
if proc_type is BasicType.DEFERRED:
continue
if proc_type.is_function and isinstance(call.routine, StatementFunction):
exprmap[call] = InlineSubstitutionMapper()(call, scope=routine)
removed_functions.add(call.routine)
# Apply the map to itself to handle nested statement function calls
exprmap = recursive_expression_map_update(exprmap, max_iterations=10, mapper_cls=InlineSubstitutionMapper)
# Apply expression-level substitution to routine
routine.body = SubstituteExpressions(exprmap).visit(routine.body)

# remove statement function declarations as well as statement function argument(s) declarations
vars_to_remove = {stmt_func.variable.name.lower() for stmt_func in stmt_func_decls}
vars_to_remove |= {arg.name.lower() for stmt_func in stmt_func_decls for arg in stmt_func.arguments}
spec_map = {stmt_func: None for stmt_func in stmt_func_decls}
for decl in routine.declarations:
if any(var in vars_to_remove for var in decl.symbols):
symbols = tuple(var for var in decl.symbols if var not in vars_to_remove)
if symbols:
decl._update(symbols=symbols)
else:
spec_map[decl] = None
routine.spec = Transformer(spec_map).visit(routine.spec)

def map_call_to_procedure_body(call, caller):
"""
Expand Down
71 changes: 64 additions & 7 deletions loki/transformations/tests/test_inline.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,13 @@
from loki.transformations.inline import (
inline_elemental_functions, inline_constant_parameters,
inline_member_procedures, inline_marked_subroutines,
InlineTransformation,
inline_statement_functions, InlineTransformation,
)
from loki.transformations.sanitise import ResolveAssociatesTransformer
from loki.transformations.utilities import replace_selected_kind

# pylint: disable=too-many-lines


@pytest.fixture(name='builder')
def fixture_builder(tmp_path):
Expand Down Expand Up @@ -700,6 +702,53 @@ def test_inline_member_routines_with_associate(frontend):
assert len(assocs) == 2


@pytest.mark.parametrize('frontend', available_frontends(
skip={OFP: "OFP apparently has problems dealing with those Statement Functions",
OMNI: "OMNI automatically inlines Statement Functions"}
))
@pytest.mark.parametrize('stmt_decls', (True, False))
def test_inline_statement_functions(frontend, stmt_decls):
stmt_decls_code = """
real :: PTARE
real :: FOEDELTA
FOEDELTA ( PTARE ) = PTARE + 1.0
real :: FOEEW
FOEEW ( PTARE ) = PTARE + FOEDELTA(PTARE)
""".strip()

fcode = f"""
subroutine stmt_func(arr, ret)
implicit none
real, intent(in) :: arr(:)
real, intent(inout) :: ret(:)
real :: ret2
real, parameter :: rtt = 1.0
{stmt_decls_code if stmt_decls else '#include "fcttre.func.h"'}
ret = foeew(arr)
ret2 = foedelta(3.0)
end subroutine stmt_func
""".strip()

routine = Subroutine.from_source(fcode, frontend=frontend)
if stmt_decls:
assert FindNodes(ir.StatementFunction).visit(routine.spec)
else:
assert not FindNodes(ir.StatementFunction).visit(routine.spec)
assert FindInlineCalls().visit(routine.body)
inline_statement_functions(routine)

assert not FindNodes(ir.StatementFunction).visit(routine.spec)
if stmt_decls:
assert not FindInlineCalls().visit(routine.body)
assignments = FindNodes(ir.Assignment).visit(routine.body)
assert assignments[0].lhs == 'ret'
assert assignments[0].rhs == "arr + arr + 1.0"
assert assignments[1].lhs == 'ret2'
assert assignments[1].rhs == "3.0 + 1.0"
else:
assert FindInlineCalls().visit(routine.body)

@pytest.mark.parametrize('frontend', available_frontends())
@pytest.mark.parametrize('adjust_imports', [True, False])
def test_inline_marked_subroutines(frontend, adjust_imports, tmp_path):
Expand Down Expand Up @@ -1095,7 +1144,8 @@ def test_inline_marked_subroutines_declarations(frontend, tmp_path):
@pytest.mark.parametrize('frontend', available_frontends(
(OFP, 'Prefix/elemental support not implemented'))
)
def test_inline_transformation(frontend, tmp_path):
@pytest.mark.parametrize('pass_as_kwarg', (False, True))
def test_inline_transformation(tmp_path, frontend, pass_as_kwarg):
"""Test combining recursive inlining via :any:`InliningTransformation`."""

fcode_module = """
Expand Down Expand Up @@ -1125,25 +1175,30 @@ def test_inline_transformation(frontend, tmp_path):
end subroutine add_one_and_two
"""

fcode = """
fcode = f"""
subroutine test_inline_pragma(a, b)
implicit none
real(kind=8), intent(inout) :: a(3), b(3)
integer, parameter :: n = 3
integer :: i
real :: stmt_arg
real :: some_stmt_func
some_stmt_func ( stmt_arg ) = stmt_arg + 3.1415
#include "add_one_and_two.intfb.h"
do i=1, n
!$loki inline
call add_one_and_two(a(i))
call add_one_and_two({'a=' if pass_as_kwarg else ''}a(i))
end do
do i=1, n
!$loki inline
call add_one_and_two(b(i))
call add_one_and_two({'a=' if pass_as_kwarg else ''}b(i))
end do
a(1) = some_stmt_func({'stmt_arg=' if pass_as_kwarg else ''}a(2))
end subroutine test_inline_pragma
"""
module = Module.from_source(fcode_module, frontend=frontend, xmods=[tmp_path])
Expand All @@ -1152,7 +1207,8 @@ def test_inline_transformation(frontend, tmp_path):
routine.enrich(inner)

trafo = InlineTransformation(
inline_constants=True, external_only=True, inline_elementals=True
inline_constants=True, external_only=True, inline_elementals=True,
inline_stmt_funcs=True
)

calls = FindNodes(ir.CallStatement).visit(routine.body)
Expand All @@ -1173,11 +1229,12 @@ def test_inline_transformation(frontend, tmp_path):
calls = FindNodes(ir.CallStatement).visit(routine.body)
assert len(calls) == 0
assigns = FindNodes(ir.Assignment).visit(routine.body)
assert len(assigns) == 4
assert len(assigns) == 5
assert assigns[0].lhs == 'a(i)' and assigns[0].rhs == 'a(i) + 1.0'
assert assigns[1].lhs == 'a(i)' and assigns[1].rhs == 'a(i) + 2.0'
assert assigns[2].lhs == 'b(i)' and assigns[2].rhs == 'b(i) + 1.0'
assert assigns[3].lhs == 'b(i)' and assigns[3].rhs == 'b(i) + 2.0'
assert assigns[4].lhs == 'a(1)' and assigns[4].rhs == 'a(2) + 3.1415'


@pytest.mark.parametrize('frontend', available_frontends())
Expand Down
6 changes: 4 additions & 2 deletions loki/transformations/utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -468,7 +468,7 @@ def replace_selected_kind(routine):
routine.spec.prepend(imprt)


def recursive_expression_map_update(expr_map, max_iterations=10):
def recursive_expression_map_update(expr_map, max_iterations=10, mapper_cls=SubstituteExpressionsMapper):
"""
Utility function to apply a substitution map for expressions to itself
Expand All @@ -490,6 +490,8 @@ def recursive_expression_map_update(expr_map, max_iterations=10):
max_iterations : int
Maximum number of iterations, corresponds to the maximum level of
nesting that can be replaced.
mapper_cls: :any:`SubstituteExpressionsMapper`
The underlying mapper to be used (default: :any:`SubstituteExpressionsMapper`).
"""
def apply_to_init_arg(name, arg, expr, mapper):
# Helper utility to apply the mapper only to expression arguments and
Expand All @@ -504,7 +506,7 @@ def apply_to_init_arg(name, arg, expr, mapper):
# We update the expression map by applying it to the children of each replacement
# node, thus making sure node replacements are also applied to nested attributes,
# e.g. call arguments or array subscripts etc.
mapper = SubstituteExpressionsMapper(expr_map)
mapper = mapper_cls(expr_map)
prev_map, expr_map = expr_map, {
expr: type(replacement)(**{
name: apply_to_init_arg(name, arg, expr, mapper)
Expand Down

0 comments on commit 7ca9dd8

Please sign in to comment.