Skip to content

Commit

Permalink
[TRANSPILATION] handle modulo operator/function for c-like-backends
Browse files Browse the repository at this point in the history
  • Loading branch information
MichaelSt98 committed Sep 25, 2024
1 parent 477c56d commit 14e7b37
Show file tree
Hide file tree
Showing 2 changed files with 94 additions and 1 deletion.
22 changes: 21 additions & 1 deletion loki/backend/cgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@
from loki.ir import Import, Stringifier, FindNodes
from loki.expression import (
LokiStringifyMapper, Array, symbolic_op, Literal,
symbols as sym
symbols as sym, FindVariables, ExpressionFinder,
ExpressionRetriever
)
from loki.types import BasicType, SymbolAttributes, DerivedType

Expand Down Expand Up @@ -140,6 +141,25 @@ def map_c_reference(self, expr, enclosing_prec, *args, **kwargs):
def map_c_dereference(self, expr, enclosing_prec, *args, **kwargs):
return self.format(' (*%s)', self.rec(expr.expression, PREC_NONE, *args, **kwargs))

def map_inline_call(self, expr, enclosing_prec, *args, **kwargs):

class FindFloatLiterals(ExpressionFinder):
retriever = ExpressionRetriever(lambda e: isinstance(e, sym.FloatLiteral))

if expr.function.name.lower() == 'mod':
parameters = [self.rec(param, PREC_NONE, *args, **kwargs) for param in expr.parameters]
# TODO: this check is not quite correct, as it should evaluate the
# expression(s) of both arguments/parameters and choose the integer version of modulo ('%')
# instead of the floating-point version ('fmod')
# whenever the mentioned evaluations result in being of kind 'integer' ...
# as an example: 'celing(3.1415)' got an floating point value in it, however it evaluates/returns
# an integer, in that case the wrong modulo function/operation is chosen
if any(var.type.dtype != BasicType.INTEGER for var in FindVariables().visit(expr.parameters)) or\
FindFloatLiterals().visit(expr.parameters):
return f'fmod({parameters[0]}, {parameters[1]})'
return f'({parameters[0]})%({parameters[1]})'
return super().map_inline_call(expr, enclosing_prec, *args, **kwargs)


class CCodegen(Stringifier):
"""
Expand Down
73 changes: 73 additions & 0 deletions loki/transformations/transpile/tests/test_transpile.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
from loki.transformations.transpile import FortranCTransformation
from loki.transformations.single_column import SCCLowLevelHoist, SCCLowLevelParametrise

# pylint: disable=too-many-lines

@pytest.fixture(scope='function', name='builder')
def fixture_builder(tmp_path):
yield Builder(source_dirs=tmp_path, build_dir=tmp_path)
Expand Down Expand Up @@ -1249,6 +1251,77 @@ def test_transpile_multiconditional_range(tmp_path, frontend):
with pytest.raises(NotImplementedError):
f2c.apply(source=routine, path=tmp_path)


@pytest.mark.parametrize('frontend', available_frontends())
@pytest.mark.parametrize('dtype', ('integer', 'real',))
@pytest.mark.parametrize('add_float', (False, True))
def test_transpile_special_functions(tmp_path, builder, frontend, dtype, add_float):
"""
A simple test to verify multiconditionals/select case statements.
"""

fcode = f"""
subroutine transpile_special_functions(in, out)
use iso_fortran_env, only: real64
implicit none
{dtype}{'(kind=real64)' if dtype == 'real' else ''}, intent(in) :: in
{dtype}{'(kind=real64)' if dtype == 'real' else ''}, intent(inout) :: out
if (mod(in{'+ 2._real64' if add_float else ''}, 2{'._real64' if dtype == 'real' else ''}{'+ 0._real64' if add_float else ''}) .eq. 0) then
out = 42{'._real64' if dtype == 'real' else ''}
else
out = 11{'._real64' if dtype == 'real' else ''}
endif
end subroutine transpile_special_functions
""".strip()

def init_var(dtype, val=0):
if dtype == 'real':
return np.float64([val])
return np.int_([val])

# for testing purposes
in_var = init_var(dtype) # np.float64([0]) # 0
test_vals = [2, 10, 5, 3]
expected_results = [42, 42, 11, 11]
# out_var = np.int_([0])
out_var = init_var(dtype) # np.float64([0])

# compile original Fortran version
routine = Subroutine.from_source(fcode, frontend=frontend)
filepath = tmp_path/f'{routine.name}_{frontend!s}.f90'
function = jit_compile(routine, filepath=filepath, objname=routine.name)
# test Fortran version
for i, val in enumerate(test_vals):
in_var = val
function(in_var, out_var)
assert out_var == expected_results[i]

clean_test(filepath)

# apply F2C trafo
f2c = FortranCTransformation()
f2c.apply(source=routine, path=tmp_path)

# check whether correct modulo was inserted
with open(f2c.c_path, 'r') as f:
ccode = f.read()
if dtype == 'integer' and not add_float:
assert '%' in ccode
if dtype == 'real' or add_float:
assert 'fmod' in ccode

# compile C version
libname = f'fc_{routine.name}_{frontend}'
c_kernel = jit_compile_lib([f2c.wrapperpath, f2c.c_path], path=tmp_path, name=libname, builder=builder)
fc_function = c_kernel.transpile_special_functions_fc_mod.transpile_special_functions_fc
# test C version
for i, val in enumerate(test_vals):
in_var = val
fc_function(in_var, out_var)
assert int(out_var) == expected_results[i]


@pytest.fixture(scope='module', name='horizontal')
def fixture_horizontal():
return Dimension(name='horizontal', size='nlon', index='jl', bounds=('start', 'iend'))
Expand Down

0 comments on commit 14e7b37

Please sign in to comment.