From a63ba1d1d5878319b826e91e20ebe338e5e9d416 Mon Sep 17 00:00:00 2001 From: Michael Staneker Date: Thu, 26 Sep 2024 10:54:34 +0200 Subject: [PATCH] [TRANSPILATION] improve on multiconditionals/switch/select case (allow RangeIndex as case value) --- loki/backend/cgen.py | 77 ++++++++++++++++--- .../transpile/tests/test_transpile.py | 60 ++++++++++----- 2 files changed, 105 insertions(+), 32 deletions(-) diff --git a/loki/backend/cgen.py b/loki/backend/cgen.py index 57ecb99f1..77190e1a1 100644 --- a/loki/backend/cgen.py +++ b/loki/backend/cgen.py @@ -458,28 +458,83 @@ def visit_MultiConditional(self, o, **kwargs): Format as switch case () { case : + { ...body... + } [case :] + { [...body...] - [default:] + } + [default:] { [...body...] } + } + + E.g., the following + + select case (in) + case (:2) + out = 1 + case (4, 5, 7:9) + out = 2 + case (6) + out = 3 + case default + out = 4 + end select + + becomes + + switch (in) { + case 0: + case 1: + case 2: + { + out = 1; + break; + } + case 4: + case 5: + case 7: + case 8: + case 9: + { + out = 2; + break; + } + case 6: + { + out = 3; + break; + } + default: + { + out = 4; + breal; + } + } """ header = self.format_line('switch (', self.visit(o.expr, **kwargs), ') {') cases = [] end_cases = [] for value in o.values: - if any(isinstance(val, sym.RangeIndex) for val in value): - # TODO: in Fortran a case can be a range, which is not straight-forward - # to translate/transfer to C - # https://j3-fortran.org/doc/year/10/10-007.pdf#page=200 - raise NotImplementedError - case = self.visit_all(as_tuple(value), **kwargs) - cases.append(self.format_line('case ', self.join_items(case), ':')) - end_cases.append(self.format_line('break;')) + sub_cases = [] + for val in value: + if not isinstance(val, sym.RangeIndex): + sub_cases.append(self.visit(val, **kwargs)) + else: + assert (val.lower is None or isinstance(val.lower, sym.IntLiteral))\ + and isinstance(val.upper, sym.IntLiteral) + lower = val.lower.value if val.lower is not None else 0 + sub_cases.extend([str(v) for v in list(range(lower, val.upper.value + 1))]) + case = () + for sub_case in sub_cases: + case += (self.format_line('case ', self.join_items(as_tuple(sub_case)), ':'),) + cases.append(self.join_lines(*case, self.format_line('{'))) + end_cases.append(self.join_lines(self.format_line('break;'), self.format_line('}'))) if o.else_body: - cases.append(self.format_line('default: ')) - end_cases.append(self.format_line('break;')) + cases.append(self.join_lines(self.format_line('default: '), self.format_line('{'))) + end_cases.append(self.join_lines(self.format_line('break;'), self.format_line('}'))) footer = self.format_line('}') self.depth += 1 bodies = self.visit_all(*o.bodies, o.else_body, **kwargs) diff --git a/loki/transformations/transpile/tests/test_transpile.py b/loki/transformations/transpile/tests/test_transpile.py index d0c1bdc8f..958adafcb 100644 --- a/loki/transformations/transpile/tests/test_transpile.py +++ b/loki/transformations/transpile/tests/test_transpile.py @@ -1145,13 +1145,13 @@ def test_transpile_inline_functions_return(tmp_path, frontend, f_type, codegen): @pytest.mark.parametrize('frontend', available_frontends()) @pytest.mark.parametrize('codegen', (cgen, cppgen, cudagen)) -def test_transpile_multiconditional(tmp_path, builder, frontend, codegen): +def test_transpile_multiconditional_simple(tmp_path, builder, frontend, codegen): """ A simple test to verify multiconditionals/select case statements. """ fcode = """ -subroutine multi_cond(in, out) +subroutine multi_cond_simple(in, out) implicit none integer, intent(in) :: in integer, intent(inout) :: out @@ -1165,7 +1165,7 @@ def test_transpile_multiconditional(tmp_path, builder, frontend, codegen): out = 100 end select -end subroutine multi_cond +end subroutine multi_cond_simple """.strip() # for testing purposes @@ -1194,7 +1194,7 @@ def test_transpile_multiconditional(tmp_path, builder, frontend, codegen): # compile C version libname = f'fc_{routine.name}_{frontend}' c_kernel = jit_compile_lib([f2c.wrapperpath, f2c.c_path], path=tmp_path, name=libname, builder=builder) - fc_function = c_kernel.multi_cond_fc_mod.multi_cond_fc + fc_function = c_kernel.multi_cond_simple_fc_mod.multi_cond_simple_fc # test C version for i, val in enumerate(test_vals): in_var = val @@ -1202,32 +1202,41 @@ def test_transpile_multiconditional(tmp_path, builder, frontend, codegen): assert out_var == expected_results[i] -@pytest.mark.parametrize('frontend', available_frontends()) -def test_transpile_multiconditional_range(tmp_path, frontend): +@pytest.mark.parametrize('frontend', available_frontends( + skip=[(OFP, 'OFP got problems with RangeIndex as case value!')] +)) +def test_transpile_multiconditional(tmp_path, builder, frontend): """ - A simple test to verify multiconditionals/select case statements. + A test to verify multiconditionals/select case statements. """ fcode = """ -subroutine transpile_multi_conditional_range(in, out) +subroutine multi_cond(in, out) implicit none integer, intent(in) :: in integer, intent(inout) :: out select case (in) - case (1:5) + case (:5) out = 10 + case (6, 7, 10:15) + out = 15 + case (8) + out = 12 + case (20:30) + out = 20 case default out = 100 end select -end subroutine transpile_multi_conditional_range +end subroutine multi_cond """.strip() # for testing purposes in_var = 0 - test_vals = [0, 1, 2, 5, 6] - expected_results = [100, 10, 10, 10, 100] + # [(, ), (, ), ...] + test_results = [(0, 10), (1, 10), (5, 10), (6, 15), (10, 15), (11, 15), + (15, 15), (8, 12), (20, 20), (21, 20), (29, 20), (50, 100)] out_var = np.int_([0]) # compile original Fortran version @@ -1235,19 +1244,28 @@ def test_transpile_multiconditional_range(tmp_path, frontend): filepath = tmp_path/f'{routine.name}_{frontend!s}.f90' function = jit_compile(routine, filepath=filepath, objname=routine.name) # test Fortran version - for i, val in enumerate(test_vals): - in_var = val + for val in test_results: + in_var = val[0] function(in_var, out_var) - assert out_var == expected_results[i] - - clean_test(filepath) + assert out_var == val[1] # apply F2C trafo - # TODO: RangeIndex as case is not yet implemented! - # 'NotImplementedError' is raised f2c = FortranCTransformation() - with pytest.raises(NotImplementedError): - f2c.apply(source=routine, path=tmp_path) + f2c.apply(source=routine, path=tmp_path) + + # check whether 'switch' statement is within C code + assert 'switch' in cgen(routine) + + # compile C version + libname = f'fc_{routine.name}_{frontend}' + c_kernel = jit_compile_lib([f2c.wrapperpath, f2c.c_path], path=tmp_path, name=libname, builder=builder) + fc_function = c_kernel.multi_cond_fc_mod.multi_cond_fc + # test C version + for val in test_results: + in_var = val[0] + fc_function(in_var, out_var) + assert out_var == val[1] + @pytest.fixture(scope='module', name='horizontal') def fixture_horizontal():