Skip to content

Commit

Permalink
Merge pull request #267 from ecmwf-ifs/nams_cgen_multiconditional
Browse files Browse the repository at this point in the history
`cgen`: multiconditional/switch/select case statement
  • Loading branch information
reuterbal authored Apr 8, 2024
2 parents 2814083 + 9d3b2f5 commit 5772340
Show file tree
Hide file tree
Showing 2 changed files with 150 additions and 3 deletions.
41 changes: 39 additions & 2 deletions loki/backend/cgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,13 @@
PREC_UNARY, PREC_LOGICAL_OR, PREC_LOGICAL_AND, PREC_NONE, PREC_CALL
)

from loki.tools import as_tuple
from loki.ir import Import, Stringifier, FindNodes
from loki.expression import LokiStringifyMapper, Array, symbolic_op, Literal
from loki.expression import (
LokiStringifyMapper, Array, symbolic_op, Literal,
symbols as sym
)
from loki.types import BasicType, SymbolAttributes, DerivedType

__all__ = ['cgen', 'CCodegen', 'CCodeMapper']


Expand Down Expand Up @@ -364,6 +367,40 @@ def visit_TypeDef(self, o, **kwargs):
self.depth -= 1
return self.join_lines(header, decls, footer)

def visit_MultiConditional(self, o, **kwargs):
"""
Format as
switch case (<expr>) {
case <value>:
...body...
[case <value>:]
[...body...]
[default:]
[...body...]
}
"""
header = self.format_line('switch (', self.visit(o.expr, **kwargs), ') {')
cases = []
end_cases = []
for value in o.values:
if any(isinstance(val, sym.RangeIndex) for val in value):
# TODO: in Fortran a case can be a range, which is not straight-forward
# to translate/transfer to C
#  https://j3-fortran.org/doc/year/10/10-007.pdf#page=200
raise NotImplementedError
case = self.visit_all(as_tuple(value), **kwargs)
cases.append(self.format_line('case ', self.join_items(case), ':'))
end_cases.append(self.format_line('break;'))
if o.else_body:
cases.append(self.format_line('default: '))
end_cases.append(self.format_line('break;'))
footer = self.format_line('}')
self.depth += 1
bodies = self.visit_all(*o.bodies, o.else_body, **kwargs)
self.depth -= 1
branches = [item for branch in zip(cases, bodies, end_cases) for item in branch]
return self.join_lines(header, *branches, footer)


def cgen(ir):
"""
Expand Down
112 changes: 111 additions & 1 deletion tests/test_transpile.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

from conftest import jit_compile, jit_compile_lib, clean_test, available_frontends
from loki import (
Subroutine, Module, FortranCTransformation, OFP
Subroutine, Module, FortranCTransformation, OFP, cgen
)
from loki.build import Builder
from loki.transform import normalize_range_indexing
Expand Down Expand Up @@ -1003,3 +1003,113 @@ def test_transpile_expressions(here, builder, frontend, use_c_ptr):
clean_test(filepath)
f2c.wrapperpath.unlink()
f2c.c_path.unlink()

@pytest.mark.parametrize('frontend', available_frontends())
def test_transpile_multiconditional(here, builder, frontend):
"""
A simple test to verify multiconditionals/select case statements.
"""

fcode = """
subroutine transpile_multi_conditional(in, out)
implicit none
integer, intent(in) :: in
integer, intent(inout) :: out
select case (in)
case (1)
out = 10
case (2)
out = 20
case default
out = 100
end select
end subroutine transpile_multi_conditional
""".strip()

# for testing purposes
in_var = 0
test_vals = [0, 1, 2, 5]
expected_results = [100, 10, 20, 100]
out_var = np.int_([0])

# compile original Fortran version
routine = Subroutine.from_source(fcode, frontend=frontend)
filepath = here/f'{routine.name}_{frontend!s}.f90'
function = jit_compile(routine, filepath=filepath, objname=routine.name)
# test Fortran version
for i, val in enumerate(test_vals):
in_var = val
function(in_var, out_var)
assert out_var == expected_results[i]

# apply F2C trafo
f2c = FortranCTransformation()
f2c.apply(source=routine, path=here)

# check whether 'switch' statement is within C code
assert 'switch' in cgen(routine)

# compile C version
libname = f'fc_{routine.name}_{frontend}'
c_kernel = jit_compile_lib([f2c.wrapperpath, f2c.c_path], path=here, name=libname, builder=builder)
fc_function = c_kernel.transpile_multi_conditional_fc_mod.transpile_multi_conditional_fc
# test C version
for i, val in enumerate(test_vals):
in_var = val
fc_function(in_var, out_var)
assert out_var == expected_results[i]

# cleanup ...
builder.clean()
clean_test(filepath)
f2c.wrapperpath.unlink()
f2c.c_path.unlink()

@pytest.mark.parametrize('frontend', available_frontends())
def test_transpile_multiconditional_range(here, frontend):
"""
A simple test to verify multiconditionals/select case statements.
"""

fcode = """
subroutine transpile_multi_conditional_range(in, out)
implicit none
integer, intent(in) :: in
integer, intent(inout) :: out
select case (in)
case (1:5)
out = 10
case default
out = 100
end select
end subroutine transpile_multi_conditional_range
""".strip()

# for testing purposes
in_var = 0
test_vals = [0, 1, 2, 5, 6]
expected_results = [100, 10, 10, 10, 100]
out_var = np.int_([0])

# compile original Fortran version
routine = Subroutine.from_source(fcode, frontend=frontend)
filepath = here/f'{routine.name}_{frontend!s}.f90'
function = jit_compile(routine, filepath=filepath, objname=routine.name)
# test Fortran version
for i, val in enumerate(test_vals):
in_var = val
function(in_var, out_var)
assert out_var == expected_results[i]

clean_test(filepath)

# apply F2C trafo
# TODO: RangeIndex as case is not yet implemented!
# 'NotImplementedError' is raised
f2c = FortranCTransformation()
with pytest.raises(NotImplementedError):
f2c.apply(source=routine, path=here)

0 comments on commit 5772340

Please sign in to comment.