Skip to content

Commit

Permalink
Merge pull request #269 from ecmwf-ifs/nams_transpile_function_return
Browse files Browse the repository at this point in the history
`cgen`: return type and var for function(s)
  • Loading branch information
reuterbal authored Apr 10, 2024
2 parents 6216cca + 1bdb019 commit 2b9ab70
Show file tree
Hide file tree
Showing 8 changed files with 80 additions and 32 deletions.
14 changes: 12 additions & 2 deletions loki/backend/cgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,14 @@ def visit_Subroutine(self, o, **kwargs):
aptr += ['']
arguments = [f'{self.visit(a.type, **kwargs)} {p}{a.name.lower()}'
for a, p in zip(o.arguments, aptr)]
header += [self.format_line('int ', o.name, '(', self.join_items(arguments), ') {')]

# check whether to return something and define function return type accordingly
if o.is_function:
return_type = c_intrinsic_type(o.return_type)
else:
return_type = 'void'

header += [self.format_line(f'{return_type} ', o.name, '(', self.join_items(arguments), ') {')]

self.depth += 1

Expand All @@ -183,7 +190,10 @@ def visit_Subroutine(self, o, **kwargs):

# Fill the body
body += [self.visit(o.body, **kwargs)]
body += [self.format_line('return 0;')]

# if something to be returned, add 'return <var>' statement
if o.result_name is not None:
body += [self.format_line(f'return {o.result_name.lower()};')]

# Close everything off
self.depth -= 1
Expand Down
3 changes: 2 additions & 1 deletion loki/backend/fgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,8 @@ def visit_Subroutine(self, o, **kwargs):
if o.prefix:
prefix += ' '
arguments = self.join_items(o.argnames)
result = f' RESULT({o.result_name})' if o.result_name else ''
result = f' RESULT({o.result_name})' if o.result_name\
and o.result_name.lower() != o.name.lower() else ''
if isinstance(o.bind, str):
bind_c = f' BIND(c, name="{o.bind}")'
elif isinstance(o.bind, StringLiteral):
Expand Down
8 changes: 5 additions & 3 deletions loki/subroutine.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,10 @@ def __initialize__(
self.result_name = result_name
self.is_function = is_function

# Make sure 'result_name' is defined if it's a function
if self.result_name is None and self.is_function:
self.result_name = name

# Additional IR components
if body is not None and not isinstance(body, ir.Section):
body = ir.Section(body=body)
Expand Down Expand Up @@ -327,9 +331,7 @@ def return_type(self):
"""
if not self.is_function:
return None
if self.result_name is not None:
return self.symbol_attrs.get(self.result_name)
return self.symbol_attrs.get(self.name)
return self.symbol_attrs.get(self.result_name)

variables = ProgramUnit.variables

Expand Down
22 changes: 1 addition & 21 deletions loki/transform/dependency_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from pathlib import Path

from loki.backend import fgen
from loki.expression import Variable, FindInlineCalls, SubstituteExpressions
from loki.expression import Variable, FindInlineCalls
from loki.ir import (
CallStatement, Import, Section, Interface, FindNodes, Transformer
)
Expand Down Expand Up @@ -155,8 +155,6 @@ def transform_subroutine(self, routine, **kwargs):
return

# Change the name of kernel routines
if routine.is_function and not routine.result_name:
self.update_result_var(routine)
routine.name += self.suffix
if item:
item.name += self.suffix.lower()
Expand Down Expand Up @@ -220,24 +218,6 @@ def derive_module_name(self, modname):
return f'{modname}{self.suffix}{self.module_suffix}'
return f'{modname}{self.suffix}'

def update_result_var(self, routine):
"""
Update name of result variable for function calls.
Parameters
----------
routine : :any:`Subroutine`
The function object for which the result variable is to be renamed
"""
assert routine.name in routine.variables

vmap = {
v: v.clone(name=v.name + self.suffix)
for v in routine.variables if v == routine.name
}
routine.spec = SubstituteExpressions(vmap).visit(routine.spec)
routine.body = SubstituteExpressions(vmap).visit(routine.body)

def rename_calls(self, routine, targets=None, item=None):
"""
Update :any:`CallStatement` and :any:`InlineCall` to actively
Expand Down
4 changes: 2 additions & 2 deletions loki/transform/fortran_c_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -407,8 +407,8 @@ def generate_c_kernel(self, routine):
such as the explicit getter calls for imported module-level variables.
"""

# Work with a copy of the original routine to not break the
# dependency graph of the Scheduler through the rename
# CAUTION! Work with a copy of the original routine to not break the
# dependency graph of the Scheduler through the rename
kernel = routine.clone()
kernel.name = f'{kernel.name.lower()}_c'

Expand Down
2 changes: 1 addition & 1 deletion tests/test_subroutine.py
Original file line number Diff line number Diff line change
Expand Up @@ -1646,7 +1646,7 @@ def test_subroutine_suffix(frontend):

check_value = module.interface_map['check_value'].body[0]
assert check_value.is_function
assert check_value.result_name is None
assert check_value.result_name == 'check_value'
assert check_value.return_type.dtype is BasicType.INTEGER
assert check_value.return_type.kind == 'c_int'
if frontend != OMNI:
Expand Down
5 changes: 3 additions & 2 deletions tests/test_transform_dependency.py
Original file line number Diff line number Diff line change
Expand Up @@ -486,8 +486,9 @@ def test_dependency_transformation_inline_call(frontend):
assert kernel.modules[0].name == 'kernel_test_mod'
assert kernel['kernel_test_mod'] == kernel.modules[0]

# Check that the return name has been added as a variable
assert 'kernel_test' in kernel['kernel_test'].variables
# Check that the return name hasn't changed
assert 'kernel' in kernel['kernel_test'].variables
assert kernel['kernel_test'].result_name == 'kernel'

# Check that the driver name has not changed
assert len(driver.modules) == 0
Expand Down
54 changes: 54 additions & 0 deletions tests/test_transpile.py
Original file line number Diff line number Diff line change
Expand Up @@ -1004,6 +1004,59 @@ def test_transpile_expressions(here, builder, frontend, use_c_ptr):
f2c.wrapperpath.unlink()
f2c.c_path.unlink()


@pytest.mark.parametrize('frontend', available_frontends())
@pytest.mark.parametrize('f_type', ['integer', 'real'])
def test_transpile_inline_functions(here, frontend, f_type):
"""
Test correct transpilation of functions in C transpilation.
"""

fcode = f"""
function add(a, b)
{f_type} :: add
{f_type}, intent(in) :: a, b
add = a + b
end function add
""".format(f_type)

routine = Subroutine.from_source(fcode, frontend=frontend)
f2c = FortranCTransformation()
f2c.apply(source=routine, path=here)

f_type_map = {'integer': 'int', 'real': 'double'}
c_routine = cgen(routine)
assert 'return add;' in c_routine
assert f'{f_type_map[f_type]} add(' in c_routine


@pytest.mark.parametrize('frontend', available_frontends())
@pytest.mark.parametrize('f_type', ['integer', 'real'])
def test_transpile_inline_functions_return(here, frontend, f_type):
"""
Test correct transpilation of functions in C transpilation.
"""

fcode = f"""
function add(a, b) result(res)
{f_type} :: res
{f_type}, intent(in) :: a, b
res = a + b
end function add
""".format(f_type)

routine = Subroutine.from_source(fcode, frontend=frontend)
f2c = FortranCTransformation()
f2c.apply(source=routine, path=here)

f_type_map = {'integer': 'int', 'real': 'double'}
c_routine = cgen(routine)
assert 'return res;' in c_routine
assert f'{f_type_map[f_type]} add(' in c_routine


@pytest.mark.parametrize('frontend', available_frontends())
def test_transpile_multiconditional(here, builder, frontend):
"""
Expand Down Expand Up @@ -1067,6 +1120,7 @@ def test_transpile_multiconditional(here, builder, frontend):
f2c.wrapperpath.unlink()
f2c.c_path.unlink()


@pytest.mark.parametrize('frontend', available_frontends())
def test_transpile_multiconditional_range(here, frontend):
"""
Expand Down

0 comments on commit 2b9ab70

Please sign in to comment.