Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix function inlining when only interface is available (fixes #397) #402

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions loki/frontend/fparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -394,6 +394,13 @@ def visit_Part_Ref(self, o, **kwargs):
# This should go away once fparser has a basic symbol table, see
# https://github.com/stfc/fparser/issues/201 for some details
_type = kwargs['scope'].symbol_attrs.lookup(name.name)
if _type is None and (definition := self.definitions.get(name.name)):
# We don't have any type information for this, which means it has
# not been declared locally. Check the definitions for enriched
# type information:
if isinstance(dtype := definition.procedure_type, ProcedureType):
_type = name.type.clone(dtype=dtype)
name = name.clone(type=_type)
if _type and isinstance(_type.dtype, ProcedureType):
name = name.clone(dimensions=None)
call = sym.InlineCall(name, parameters=dimensions, kw_parameters=())
Expand Down
2 changes: 1 addition & 1 deletion loki/logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@


__all__ = ['logger', 'log_levels', 'set_log_level', 'FileLogger',
'debug', 'info', 'warning', 'error', 'log']
'debug', 'detail', 'perf', 'info', 'warning', 'error', 'log']


def FileLogger(name, filename, level=None, file_level=None, fmt=None,
Expand Down
10 changes: 9 additions & 1 deletion loki/transformations/inline/mapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,11 @@
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

from loki.expression import LokiIdentityMapper
from loki.ir import (
FindNodes, Assignment, StatementFunction, SubstituteExpressions
)
from loki.expression import LokiIdentityMapper
from loki.logging import detail
from loki.types import BasicType


Expand Down Expand Up @@ -73,6 +74,13 @@ def map_inline_call(self, expr, *args, **kwargs):
function = expr.procedure_type.procedure
v_result = [v for v in function.variables if v == function.name][0]

scope = kwargs.get('scope') or expr.function.scope
if scope and function.name in scope.interface_map:
# Inline call to a function that is provided via an interface
# We don't have the function body available for inlining
detail(f'Cannot inline {expr.function.name} into {scope.name}. Only interface available.')
return super().map_inline_call(expr, *args, **kwargs)

# Substitute all arguments through the elemental body
arg_map = dict(expr.arg_iter())
fbody = SubstituteExpressions(arg_map).visit(function.body)
Expand Down
122 changes: 121 additions & 1 deletion loki/transformations/inline/tests/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from loki.ir import (
nodes as ir, FindNodes, FindVariables, FindInlineCalls
)
from loki.types import ProcedureType

from loki.transformations.inline import (
inline_elemental_functions, inline_statement_functions
Expand Down Expand Up @@ -216,7 +217,7 @@ def test_inline_statement_functions(frontend, stmt_decls):
real, parameter :: rtt = 1.0
{stmt_decls_code if stmt_decls else '#include "fcttre.func.h"'}

ret = foeew(arr)
ret = foeew(arr)
ret2 = foedelta(3.0)
end subroutine stmt_func
""".strip()
Expand All @@ -239,3 +240,122 @@ def test_inline_statement_functions(frontend, stmt_decls):
assert assignments[1].rhs == "3.0 + 1.0"
else:
assert FindInlineCalls().visit(routine.body)

@pytest.mark.parametrize('frontend', available_frontends(
skip={OFP: "OFP apparently has problems dealing with those Statement Functions",
OMNI: "OMNI automatically inlines Statement Functions"}
))
@pytest.mark.parametrize('provide_myfunc', ('import', 'module', 'interface', 'intfb', 'routine'))
def test_inline_statement_functions_inline_call(frontend, provide_myfunc, tmp_path):
fcode_myfunc = """
elemental function myfunc(a)
real, intent(in) :: a
real :: myfunc
myfunc = a * 2.0
end function myfunc
""".strip()

if provide_myfunc == 'module':
fcode_myfunc = f"""
module my_mod
implicit none
contains
{fcode_myfunc}
end module my_mod
""".strip()

if provide_myfunc in ('import', 'module'):
module_import = 'use my_mod, only: myfunc'
else:
module_import = ''

if provide_myfunc == 'interface':
intf = """
interface
elemental function myfunc(a)
implicit none
real a
real myfunc
end function myfunc
end interface
"""
elif provide_myfunc in ('intfb', 'routine'):
intf = '#include "myfunc.intfb.h"'
else:
intf = ''

fcode = f"""
subroutine stmt_func(arr, val, ret)
{module_import}
implicit none
real, intent(in) :: arr(:)
real, intent(in) :: val
real, intent(inout) :: ret(:)
real :: ret2
real, parameter :: rtt = 1.0
real :: PTARE
real :: FOEDELTA
FOEDELTA ( PTARE ) = PTARE + 1.0 + MYFUNC(PTARE)
real :: FOEEW
FOEEW ( PTARE ) = PTARE + FOEDELTA(PTARE) + MYFUNC(PTARE)
{intf}

ret = foeew(arr)
ret2 = foedelta(3.0) + foedelta(val)
end subroutine stmt_func
""".strip()

if provide_myfunc == 'module':
definitions = (Module.from_source(fcode_myfunc, xmods=[tmp_path]),)
elif provide_myfunc == 'routine':
definitions = (Subroutine.from_source(fcode_myfunc, xmods=[tmp_path]),)
else:
definitions = None
routine = Subroutine.from_source(fcode, frontend=frontend, definitions=definitions, xmods=[tmp_path])

# Check the spec
statement_funcs = FindNodes(ir.StatementFunction).visit(routine.spec)
assert len(statement_funcs) == 2

inline_calls = FindInlineCalls(unique=False).visit(routine.spec)
if provide_myfunc in ('module', 'interface', 'routine'):
# Enough information available that MYFUNC is recognized as a procedure call
assert len(inline_calls) == 3
assert all(isinstance(call.function.type.dtype, ProcedureType) for call in inline_calls)
else:
# No information available about MYFUNC, so fparser treats it as an ArraySubscript
assert len(inline_calls) == 1
assert inline_calls[0].function == 'foedelta'
assert isinstance(inline_calls[0].function.type.dtype, ProcedureType)

# Check the body
inline_calls = FindInlineCalls().visit(routine.body)
assert len(inline_calls) == 3

# Apply the transformation
inline_statement_functions(routine)

# Check the outcome
assert not FindNodes(ir.StatementFunction).visit(routine.spec)
inline_calls = FindInlineCalls(unique=False).visit(routine.body)
assignments = FindNodes(ir.Assignment).visit(routine.body)

if provide_myfunc in ('import', 'intfb'):
# MYFUNC(arr) is misclassified as array subscript
assert len(inline_calls) == 0
elif provide_myfunc in ('module', 'routine'):
# MYFUNC(arr) is eliminated due to inlining
assert len(inline_calls) == 0
else:
assert len(inline_calls) == 4

assert assignments[0].lhs == 'ret'
assert assignments[1].lhs == 'ret2'
if provide_myfunc in ('module', 'routine'):
# Fully inlined due to definition of myfunc available
assert assignments[0].rhs == "arr + arr + 1.0 + arr*2.0 + arr*2.0"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[no action] This highlights that the entire separation of inline_statement_function and inline_elemental is basically now broken. All four definition types should yield consistent behaviour.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, the behaviour difference comes again from the lack of information. Inlining always works the same if you have the definition of the function body available, and fails when you don't have it

In fact, let me add a fifth case provide_myfunc == 'subroutine', where we have the intfb and provide the definition separately as a free function (not embedded in a module). Providing this as a definition should also allow the inlining to be effective

assert assignments[1].rhs == "3.0 + 1.0 + 3.0*2.0 + val + 1.0 + val*2.0"
else:
# myfunc not inlined
assert assignments[0].rhs == "arr + arr + 1.0 + myfunc(arr) + myfunc(arr)"
assert assignments[1].rhs == "3.0 + 1.0 + myfunc(3.0) + val + 1.0 + myfunc(val)"
Loading