-
Notifications
You must be signed in to change notification settings - Fork 12
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
handle modulo operator/function for c-like-backends #383
Changes from 2 commits
b25f4cd
1bc4706
2ce21a2
a428a3b
c613469
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -11,10 +11,13 @@ | |
) | ||
|
||
from loki.tools import as_tuple | ||
from loki.ir import Import, Stringifier, FindNodes | ||
from loki.ir import ( | ||
Import, Stringifier, FindNodes, | ||
FindVariables, ExpressionFinder | ||
) | ||
from loki.expression import ( | ||
LokiStringifyMapper, Array, symbolic_op, Literal, | ||
symbols as sym | ||
symbols as sym, ExpressionRetriever | ||
) | ||
from loki.types import BasicType, SymbolAttributes, DerivedType | ||
|
||
|
@@ -140,6 +143,25 @@ def map_c_reference(self, expr, enclosing_prec, *args, **kwargs): | |
def map_c_dereference(self, expr, enclosing_prec, *args, **kwargs): | ||
return self.format(' (*%s)', self.rec(expr.expression, PREC_NONE, *args, **kwargs)) | ||
|
||
def map_inline_call(self, expr, enclosing_prec, *args, **kwargs): | ||
|
||
class FindFloatLiterals(ExpressionFinder): | ||
retriever = ExpressionRetriever(lambda e: isinstance(e, sym.FloatLiteral)) | ||
|
||
if expr.function.name.lower() == 'mod': | ||
parameters = [self.rec(param, PREC_NONE, *args, **kwargs) for param in expr.parameters] | ||
# TODO: this check is not quite correct, as it should evaluate the | ||
# expression(s) of both arguments/parameters and choose the integer version of modulo ('%') | ||
# instead of the floating-point version ('fmod') | ||
# whenever the mentioned evaluations result in being of kind 'integer' ... | ||
# as an example: 'celing(3.1415)' got an floating point value in it, however it evaluates/returns | ||
# an integer, in that case the wrong modulo function/operation is chosen | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [no action] This is indeed a tricky corner case. Solution looks good for now, but might need some more thought when this gets problematic. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Likely the correct approach would be an Definitely way beyond the scope of this PR, though. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes! We could also think about implementing a C++ (templated) function |
||
if any(var.type.dtype != BasicType.INTEGER for var in FindVariables().visit(expr.parameters)) or\ | ||
FindFloatLiterals().visit(expr.parameters): | ||
return f'fmod({parameters[0]}, {parameters[1]})' | ||
return f'({parameters[0]})%({parameters[1]})' | ||
return super().map_inline_call(expr, enclosing_prec, *args, **kwargs) | ||
|
||
|
||
class CCodegen(Stringifier): | ||
""" | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,7 +6,6 @@ | |
# nor does it submit to any jurisdiction. | ||
|
||
from pathlib import Path | ||
# from shutil import rmtree | ||
import pytest | ||
import numpy as np | ||
|
||
|
@@ -19,6 +18,8 @@ | |
from loki.transformations.transpile import FortranCTransformation | ||
from loki.transformations.single_column import SCCLowLevelHoist, SCCLowLevelParametrise | ||
|
||
# pylint: disable=too-many-lines | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We should start splitting the tests into logical units (but that can be a separate PR) |
||
@pytest.fixture(scope='function', name='builder') | ||
def fixture_builder(tmp_path): | ||
yield Builder(source_dirs=tmp_path, build_dir=tmp_path) | ||
|
@@ -1249,6 +1250,80 @@ def test_transpile_multiconditional_range(tmp_path, frontend): | |
with pytest.raises(NotImplementedError): | ||
f2c.apply(source=routine, path=tmp_path) | ||
|
||
|
||
@pytest.mark.parametrize('frontend', available_frontends()) | ||
@pytest.mark.parametrize('dtype', ('integer', 'real',)) | ||
@pytest.mark.parametrize('add_float', (False, True)) | ||
def test_transpile_special_functions(tmp_path, builder, frontend, dtype, add_float): | ||
""" | ||
A simple test to verify multiconditionals/select case statements. | ||
""" | ||
if dtype == 'real': | ||
decl_type = f'{dtype}(kind=real64)' | ||
kind = '._real64' | ||
else: | ||
decl_type = dtype | ||
kind = '' | ||
|
||
fcode = f""" | ||
subroutine transpile_special_functions(in, out) | ||
use iso_fortran_env, only: real64 | ||
implicit none | ||
{decl_type}, intent(in) :: in | ||
{decl_type}, intent(inout) :: out | ||
if (mod(in{'+ 2._real64' if add_float else ''}, 2{kind}{'+ 0._real64' if add_float else ''}) .eq. 0) then | ||
out = 42{kind} | ||
else | ||
out = 11{kind} | ||
endif | ||
end subroutine transpile_special_functions | ||
""".strip() | ||
|
||
def init_var(dtype, val=0): | ||
if dtype == 'real': | ||
return np.float64([val]) | ||
return np.int_([val]) | ||
|
||
# for testing purposes | ||
in_var = init_var(dtype) | ||
test_vals = [2, 10, 5, 3] | ||
expected_results = [42, 42, 11, 11] | ||
out_var = init_var(dtype) | ||
|
||
# compile original Fortran version | ||
routine = Subroutine.from_source(fcode, frontend=frontend) | ||
filepath = tmp_path/f'{routine.name}_{frontend!s}.f90' | ||
function = jit_compile(routine, filepath=filepath, objname=routine.name) | ||
# test Fortran version | ||
for i, val in enumerate(test_vals): | ||
in_var = val | ||
function(in_var, out_var) | ||
assert out_var == expected_results[i] | ||
|
||
clean_test(filepath) | ||
|
||
# apply F2C trafo | ||
f2c = FortranCTransformation() | ||
f2c.apply(source=routine, path=tmp_path) | ||
|
||
# check whether correct modulo was inserted | ||
ccode = Path(f2c.c_path).read_text() | ||
if dtype == 'integer' and not add_float: | ||
assert '%' in ccode | ||
if dtype == 'real' or add_float: | ||
assert 'fmod' in ccode | ||
|
||
# compile C version | ||
libname = f'fc_{routine.name}_{frontend}' | ||
c_kernel = jit_compile_lib([f2c.wrapperpath, f2c.c_path], path=tmp_path, name=libname, builder=builder) | ||
fc_function = c_kernel.transpile_special_functions_fc_mod.transpile_special_functions_fc | ||
# test C version | ||
for i, val in enumerate(test_vals): | ||
in_var = val | ||
fc_function(in_var, out_var) | ||
assert int(out_var) == expected_results[i] | ||
|
||
|
||
@pytest.fixture(scope='module', name='horizontal') | ||
def fixture_horizontal(): | ||
return Dimension(name='horizontal', size='nlon', index='jl', bounds=('start', 'iend')) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just a minor nitpick, but I think this is general enough to live in
loki.ir.expr_visitor.py