diff --git a/loki/backend/cgen.py b/loki/backend/cgen.py index 282c05969..df590f619 100644 --- a/loki/backend/cgen.py +++ b/loki/backend/cgen.py @@ -10,10 +10,13 @@ PREC_UNARY, PREC_LOGICAL_OR, PREC_LOGICAL_AND, PREC_NONE, PREC_CALL ) +from loki.tools import as_tuple from loki.ir import Import, Stringifier, FindNodes -from loki.expression import LokiStringifyMapper, Array, symbolic_op, Literal +from loki.expression import ( + LokiStringifyMapper, Array, symbolic_op, Literal, + symbols as sym +) from loki.types import BasicType, SymbolAttributes, DerivedType - __all__ = ['cgen', 'CCodegen', 'CCodeMapper'] @@ -364,6 +367,40 @@ def visit_TypeDef(self, o, **kwargs): self.depth -= 1 return self.join_lines(header, decls, footer) + def visit_MultiConditional(self, o, **kwargs): + """ + Format as + switch case () { + case : + ...body... + [case :] + [...body...] + [default:] + [...body...] + } + """ + header = self.format_line('switch (', self.visit(o.expr, **kwargs), ') {') + cases = [] + end_cases = [] + for value in o.values: + if any(isinstance(val, sym.RangeIndex) for val in value): + # TODO: in Fortran a case can be a range, which is not straight-forward + # to translate/transfer to C + #  https://j3-fortran.org/doc/year/10/10-007.pdf#page=200 + raise NotImplementedError + case = self.visit_all(as_tuple(value), **kwargs) + cases.append(self.format_line('case ', self.join_items(case), ':')) + end_cases.append(self.format_line('break;')) + if o.else_body: + cases.append(self.format_line('default: ')) + end_cases.append(self.format_line('break;')) + footer = self.format_line('}') + self.depth += 1 + bodies = self.visit_all(*o.bodies, o.else_body, **kwargs) + self.depth -= 1 + branches = [item for branch in zip(cases, bodies, end_cases) for item in branch] + return self.join_lines(header, *branches, footer) + def cgen(ir): """ diff --git a/tests/test_transpile.py b/tests/test_transpile.py index 7be234366..6945271da 100644 --- a/tests/test_transpile.py +++ b/tests/test_transpile.py @@ -11,7 +11,7 @@ from conftest import jit_compile, jit_compile_lib, clean_test, available_frontends from loki import ( - Subroutine, Module, FortranCTransformation, OFP + Subroutine, Module, FortranCTransformation, OFP, cgen ) from loki.build import Builder from loki.transform import normalize_range_indexing @@ -1003,3 +1003,113 @@ def test_transpile_expressions(here, builder, frontend, use_c_ptr): clean_test(filepath) f2c.wrapperpath.unlink() f2c.c_path.unlink() + +@pytest.mark.parametrize('frontend', available_frontends()) +def test_transpile_multiconditional(here, builder, frontend): + """ + A simple test to verify multiconditionals/select case statements. + """ + + fcode = """ +subroutine transpile_multi_conditional(in, out) + implicit none + integer, intent(in) :: in + integer, intent(inout) :: out + + select case (in) + case (1) + out = 10 + case (2) + out = 20 + case default + out = 100 + end select + +end subroutine transpile_multi_conditional +""".strip() + + # for testing purposes + in_var = 0 + test_vals = [0, 1, 2, 5] + expected_results = [100, 10, 20, 100] + out_var = np.int_([0]) + + # compile original Fortran version + routine = Subroutine.from_source(fcode, frontend=frontend) + filepath = here/f'{routine.name}_{frontend!s}.f90' + function = jit_compile(routine, filepath=filepath, objname=routine.name) + # test Fortran version + for i, val in enumerate(test_vals): + in_var = val + function(in_var, out_var) + assert out_var == expected_results[i] + + # apply F2C trafo + f2c = FortranCTransformation() + f2c.apply(source=routine, path=here) + + # check whether 'switch' statement is within C code + assert 'switch' in cgen(routine) + + # compile C version + libname = f'fc_{routine.name}_{frontend}' + c_kernel = jit_compile_lib([f2c.wrapperpath, f2c.c_path], path=here, name=libname, builder=builder) + fc_function = c_kernel.transpile_multi_conditional_fc_mod.transpile_multi_conditional_fc + # test C version + for i, val in enumerate(test_vals): + in_var = val + fc_function(in_var, out_var) + assert out_var == expected_results[i] + + # cleanup ... + builder.clean() + clean_test(filepath) + f2c.wrapperpath.unlink() + f2c.c_path.unlink() + +@pytest.mark.parametrize('frontend', available_frontends()) +def test_transpile_multiconditional_range(here, frontend): + """ + A simple test to verify multiconditionals/select case statements. + """ + + fcode = """ +subroutine transpile_multi_conditional_range(in, out) + implicit none + integer, intent(in) :: in + integer, intent(inout) :: out + + select case (in) + case (1:5) + out = 10 + case default + out = 100 + end select + +end subroutine transpile_multi_conditional_range +""".strip() + + # for testing purposes + in_var = 0 + test_vals = [0, 1, 2, 5, 6] + expected_results = [100, 10, 10, 10, 100] + out_var = np.int_([0]) + + # compile original Fortran version + routine = Subroutine.from_source(fcode, frontend=frontend) + filepath = here/f'{routine.name}_{frontend!s}.f90' + function = jit_compile(routine, filepath=filepath, objname=routine.name) + # test Fortran version + for i, val in enumerate(test_vals): + in_var = val + function(in_var, out_var) + assert out_var == expected_results[i] + + clean_test(filepath) + + # apply F2C trafo + # TODO: RangeIndex as case is not yet implemented! + # 'NotImplementedError' is raised + f2c = FortranCTransformation() + with pytest.raises(NotImplementedError): + f2c.apply(source=routine, path=here)