Skip to content

Commit

Permalink
Merge pull request #392 from ecmwf-ifs/naml_fix_multicond_scc_vector
Browse files Browse the repository at this point in the history
SingleColumn: Fix vectorisation of nested else-if bodies
  • Loading branch information
reuterbal authored Oct 10, 2024
2 parents 3834903 + 53b2080 commit 0189d24
Show file tree
Hide file tree
Showing 4 changed files with 85 additions and 13 deletions.
10 changes: 10 additions & 0 deletions loki/ir/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -704,6 +704,16 @@ def __repr__(self):
return f'Conditional:: {self.name}'
return 'Conditional::'

@property
def else_bodies(self):
"""
Return all nested node tuples in the ``ELSEIF``/``ELSE`` part
of the conditional chain.
"""
if self.has_elseif:
return (self.else_body[0].body,) + self.else_body[0].else_bodies
return (self.else_body,) if self.else_body else ()


@dataclass_strict(frozen=True)
class _PragmaRegionBase():
Expand Down
48 changes: 48 additions & 0 deletions loki/ir/tests/test_ir_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,54 @@ def test_conditional(scope, one, i, n, a_i):
# TODO: Test inline, name, has_elseif


def test_multi_conditional(scope, one, i, n, a_i):
"""
Test nested chains of constructors of :any:`Conditional` to form
multi-conditional.
"""
multicond = ir.Conditional(
condition=sym.Comparison(i, '==', sym.IntLiteral(1)),
body=ir.Assignment(lhs=a_i, rhs=sym.Literal(1.0)),
else_body=ir.Assignment(lhs=a_i, rhs=sym.Literal(42.0))
)
for idx in range(2, 4):
multicond = ir.Conditional(
condition=sym.Comparison(i, '==', sym.IntLiteral(idx)),
body=ir.Assignment(lhs=a_i, rhs=sym.Literal(float(idx))),
else_body=multicond, has_elseif=True
)

# Check that we can recover all bodies from a nested else-if construct
else_bodies = multicond.else_bodies
assert len(else_bodies) == 3
assert all(isinstance(b, tuple) for b in else_bodies)
assert isinstance(else_bodies[0][0], ir.Assignment)
assert else_bodies[0][0].lhs == 'a(i)' and else_bodies[0][0].rhs == '2.0'
assert isinstance(else_bodies[1][0], ir.Assignment)
assert else_bodies[1][0].lhs == 'a(i)' and else_bodies[1][0].rhs == '1.0'
assert isinstance(else_bodies[2][0], ir.Assignment)
assert else_bodies[2][0].lhs == 'a(i)' and else_bodies[2][0].rhs == '42.0'

# Not try without the final else
multicond = ir.Conditional(
condition=sym.Comparison(i, '==', sym.IntLiteral(1)),
body=ir.Assignment(lhs=a_i, rhs=sym.Literal(1.0)),
)
for idx in range(2, 4):
multicond = ir.Conditional(
condition=sym.Comparison(i, '==', sym.IntLiteral(idx)),
body=ir.Assignment(lhs=a_i, rhs=sym.Literal(float(idx))),
else_body=multicond, has_elseif=True
)
else_bodies = multicond.else_bodies
assert len(else_bodies) == 2
assert all(isinstance(b, tuple) for b in else_bodies)
assert isinstance(else_bodies[0][0], ir.Assignment)
assert else_bodies[0][0].lhs == 'a(i)' and else_bodies[0][0].rhs == '2.0'
assert isinstance(else_bodies[1][0], ir.Assignment)
assert else_bodies[1][0].lhs == 'a(i)' and else_bodies[1][0].rhs == '1.0'


def test_section(scope, one, i, n, a_n, a_i):
"""
Test constructors and behaviour of :any:`Section` nodes.
Expand Down
28 changes: 23 additions & 5 deletions loki/transformations/single_column/tests/test_scc_vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -573,20 +573,30 @@ def test_scc_devector_section_special_case(frontend, horizontal, vertical, block
"""

fcode_kernel = """
subroutine some_kernel(start, end, nlon, flag0, flag1)
subroutine some_kernel(start, end, nlon, flag0, flag1, flag2)
implicit none
integer, intent(in) :: nlon, start, end
logical, intent(in) :: flag0, flag1
logical, intent(in) :: flag0, flag1, flag2
real, dimension(nlon) :: work
integer :: jl
if(flag0)then
if (flag0) then
call some_other_kernel()
elseif(flag1)then
elseif (flag1) then
do jl=start,end
work(jl) = 1.
enddo
elseif (flag2) then
do jl=start,end
work(jl) = 1.
work(jl) = 2.
enddo
else
do jl=start,end
work(jl) = 41.
work(jl) = 42.
enddo
endif
Expand All @@ -595,7 +605,7 @@ def test_scc_devector_section_special_case(frontend, horizontal, vertical, block

routine = Subroutine.from_source(fcode_kernel, frontend=frontend)

# check whether pipeline can be applied and works as expected
# check whether pipeline can be applied and works as expected
scc_pipeline = SCCVectorPipeline(
horizontal=horizontal, vertical=vertical, block_dim=blocking,
directive='openacc', trim_vector_sections=trim_vector_sections
Expand All @@ -611,3 +621,11 @@ def test_scc_devector_section_special_case(frontend, horizontal, vertical, block
assert isinstance(conditional.else_body[0].body[0], ir.Comment)
assert isinstance(conditional.else_body[0].body[1], ir.Loop)
assert conditional.else_body[0].body[1].pragma[0].content.lower() == 'loop vector'

# Check that all else-bodies have been wrapped
else_bodies = conditional.else_bodies
assert(len(else_bodies) == 3)
for body in else_bodies:
assert isinstance(body[0], ir.Comment)
assert isinstance(body[1], ir.Loop)
assert body[1].pragma[0].content.lower() == 'loop vector'
12 changes: 4 additions & 8 deletions loki/transformations/single_column/vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,15 +137,11 @@ def extract_vector_sections(cls, section, horizontal):
subsec_body = cls.extract_vector_sections(separator.body, horizontal)
if subsec_body:
subsections += subsec_body
# we need to prevent that the whole 'else_body' is wrapped in a section,
# as 'Conditional's rely on the fact that the first element of the 'else_body'
# we need to prevent that all (possibly nested) 'else_bodies' are completely wrapped as a section,
# as 'Conditional's rely on the fact that the first element of each 'else_body'
# (if 'has_elseif') is a Conditional itself
if separator.has_elseif and separator.else_body:
subsec_else = cls.extract_vector_sections(separator.else_body[0].body, horizontal)
else:
subsec_else = cls.extract_vector_sections(separator.else_body, horizontal)
if subsec_else:
subsections += subsec_else
for ebody in separator.else_bodies:
subsections += cls.extract_vector_sections(ebody, horizontal)

if isinstance(separator, ir.MultiConditional):
for body in separator.bodies:
Expand Down

0 comments on commit 0189d24

Please sign in to comment.