diff --git a/loki/backend/cgen.py b/loki/backend/cgen.py index 57ecb99f1..2648f15cc 100644 --- a/loki/backend/cgen.py +++ b/loki/backend/cgen.py @@ -14,7 +14,8 @@ from loki.ir import Import, Stringifier, FindNodes from loki.expression import ( LokiStringifyMapper, Array, symbolic_op, Literal, - symbols as sym + symbols as sym, FindVariables, ExpressionFinder, + ExpressionRetriever ) from loki.types import BasicType, SymbolAttributes, DerivedType @@ -140,6 +141,25 @@ def map_c_reference(self, expr, enclosing_prec, *args, **kwargs): def map_c_dereference(self, expr, enclosing_prec, *args, **kwargs): return self.format(' (*%s)', self.rec(expr.expression, PREC_NONE, *args, **kwargs)) + def map_inline_call(self, expr, enclosing_prec, *args, **kwargs): + + class FindFloatLiterals(ExpressionFinder): + retriever = ExpressionRetriever(lambda e: isinstance(e, sym.FloatLiteral)) + + if expr.function.name.lower() == 'mod': + parameters = [self.rec(param, PREC_NONE, *args, **kwargs) for param in expr.parameters] + # TODO: this check is not quite correct, as it should evaluate the + # expression(s) of both arguments/parameters and choose the integer version of modulo ('%') + # instead of the floating-point version ('fmod') + # whenever the mentioned evaluations result in being of kind 'integer' ... + # as an example: 'celing(3.1415)' got an floating point value in it, however it evaluates/returns + # an integer, in that case the wrong modulo function/operation is chosen + if any(var.type.dtype != BasicType.INTEGER for var in FindVariables().visit(expr.parameters)) or\ + FindFloatLiterals().visit(expr.parameters): + return f'fmod({parameters[0]}, {parameters[1]})' + return f'({parameters[0]})%({parameters[1]})' + return super().map_inline_call(expr, enclosing_prec, *args, **kwargs) + class CCodegen(Stringifier): """ diff --git a/loki/transformations/transpile/tests/test_transpile.py b/loki/transformations/transpile/tests/test_transpile.py index d0c1bdc8f..9a8962e28 100644 --- a/loki/transformations/transpile/tests/test_transpile.py +++ b/loki/transformations/transpile/tests/test_transpile.py @@ -19,6 +19,8 @@ from loki.transformations.transpile import FortranCTransformation from loki.transformations.single_column import SCCLowLevelHoist, SCCLowLevelParametrise +# pylint: disable=too-many-lines + @pytest.fixture(scope='function', name='builder') def fixture_builder(tmp_path): yield Builder(source_dirs=tmp_path, build_dir=tmp_path) @@ -1249,6 +1251,77 @@ def test_transpile_multiconditional_range(tmp_path, frontend): with pytest.raises(NotImplementedError): f2c.apply(source=routine, path=tmp_path) + +@pytest.mark.parametrize('frontend', available_frontends()) +@pytest.mark.parametrize('dtype', ('integer', 'real',)) +@pytest.mark.parametrize('add_float', (False, True)) +def test_transpile_special_functions(tmp_path, builder, frontend, dtype, add_float): + """ + A simple test to verify multiconditionals/select case statements. + """ + + fcode = f""" +subroutine transpile_special_functions(in, out) + use iso_fortran_env, only: real64 + implicit none + {dtype}{'(kind=real64)' if dtype == 'real' else ''}, intent(in) :: in + {dtype}{'(kind=real64)' if dtype == 'real' else ''}, intent(inout) :: out + + if (mod(in{'+ 2._real64' if add_float else ''}, 2{'._real64' if dtype == 'real' else ''}{'+ 0._real64' if add_float else ''}) .eq. 0) then + out = 42{'._real64' if dtype == 'real' else ''} + else + out = 11{'._real64' if dtype == 'real' else ''} + endif +end subroutine transpile_special_functions +""".strip() + + def init_var(dtype, val=0): + if dtype == 'real': + return np.float64([val]) + return np.int_([val]) + + # for testing purposes + in_var = init_var(dtype) # np.float64([0]) # 0 + test_vals = [2, 10, 5, 3] + expected_results = [42, 42, 11, 11] + # out_var = np.int_([0]) + out_var = init_var(dtype) # np.float64([0]) + + # compile original Fortran version + routine = Subroutine.from_source(fcode, frontend=frontend) + filepath = tmp_path/f'{routine.name}_{frontend!s}.f90' + function = jit_compile(routine, filepath=filepath, objname=routine.name) + # test Fortran version + for i, val in enumerate(test_vals): + in_var = val + function(in_var, out_var) + assert out_var == expected_results[i] + + clean_test(filepath) + + # apply F2C trafo + f2c = FortranCTransformation() + f2c.apply(source=routine, path=tmp_path) + + # check whether correct modulo was inserted + with open(f2c.c_path, 'r') as f: + ccode = f.read() + if dtype == 'integer' and not add_float: + assert '%' in ccode + if dtype == 'real' or add_float: + assert 'fmod' in ccode + + # compile C version + libname = f'fc_{routine.name}_{frontend}' + c_kernel = jit_compile_lib([f2c.wrapperpath, f2c.c_path], path=tmp_path, name=libname, builder=builder) + fc_function = c_kernel.transpile_special_functions_fc_mod.transpile_special_functions_fc + # test C version + for i, val in enumerate(test_vals): + in_var = val + fc_function(in_var, out_var) + assert int(out_var) == expected_results[i] + + @pytest.fixture(scope='module', name='horizontal') def fixture_horizontal(): return Dimension(name='horizontal', size='nlon', index='jl', bounds=('start', 'iend'))