Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Transformations: Move common SCC utility routines to utilities #354

Merged
merged 9 commits into from
Aug 1, 2024
11 changes: 7 additions & 4 deletions loki/transformations/block_index_transformations.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,10 @@
from loki.types import SymbolAttributes, BasicType
from loki.expression import Variable, Array, RangeIndex, FindVariables, SubstituteExpressions
from loki.transformations.sanitise import resolve_associates
from loki.transformations.utilities import recursive_expression_map_update
from loki.transformations.utilities import (
recursive_expression_map_update, get_integer_variable,
get_loop_bounds, check_routine_pragmas
)
from loki.transformations.single_column.base import SCCBaseTransformation

__all__ = ['BlockViewToFieldViewTransformation', 'InjectBlockIndexTransformation']
Expand Down Expand Up @@ -232,14 +235,14 @@ def process_kernel(self, routine, item, successors, targets, exclude_arrays):

# Sanitize the subroutine
resolve_associates(routine)
v_index = SCCBaseTransformation.get_integer_variable(routine, name=self.horizontal.index)
v_index = get_integer_variable(routine, name=self.horizontal.index)
SCCBaseTransformation.resolve_masked_stmts(routine, loop_variable=v_index)

# Bail if routine is marked as sequential or routine has already been processed
if SCCBaseTransformation.check_routine_pragmas(routine, directive=None):
if check_routine_pragmas(routine, directive=None):
return

bounds = SCCBaseTransformation.get_horizontal_loop_bounds(routine, self.horizontal)
bounds = get_loop_bounds(routine, self.horizontal)
SCCBaseTransformation.resolve_vector_dimension(routine, loop_variable=v_index, bounds=bounds)

# for kernels we process the entire body
Expand Down
20 changes: 11 additions & 9 deletions loki/transformations/single_column/annotate.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,12 @@
pragma_regions_attached, is_loki_pragma
)
from loki.logging import info
from loki.tools import as_tuple, flatten, CaseInsensitiveDict
from loki.tools import as_tuple, flatten
from loki.types import DerivedType

from loki.transformations.single_column.base import SCCBaseTransformation
from loki.transformations.utilities import (
find_driver_loops, get_local_arrays, check_routine_pragmas
)


__all__ = ['SCCAnnotateTransformation']
Expand Down Expand Up @@ -63,11 +65,11 @@ def kernel_annotate_vector_loops_openacc(cls, routine, horizontal):
"""

# Find any local arrays that need explicitly privatization
argument_map = CaseInsensitiveDict({a.name: a for a in routine.arguments})
private_arrays = [v for v in routine.variables if not v.name in argument_map]
private_arrays = [v for v in private_arrays if isinstance(v, sym.Array)]
private_arrays = [v for v in private_arrays
if all(is_dimension_constant(d) for d in v.shape)]
private_arrays = get_local_arrays(routine, section=routine.spec)
private_arrays = [
v for v in private_arrays
if all(is_dimension_constant(d) for d in v.shape)
]

if private_arrays:
# Log private arrays in vector regions, as these can impact performance
Expand Down Expand Up @@ -204,7 +206,7 @@ def process_kernel(self, routine):
"""

# Bail if routine is marked as sequential
if SCCBaseTransformation.check_routine_pragmas(routine, self.directive):
if check_routine_pragmas(routine, self.directive):
return

if self.directive == 'openacc':
Expand Down Expand Up @@ -239,7 +241,7 @@ def process_driver(self, routine, targets=None):
break

with pragmas_attached(routine, ir.Loop, attach_pragma_post=True):
driver_loops = SCCBaseTransformation.find_driver_loops(routine=routine, targets=targets)
driver_loops = find_driver_loops(routine=routine, targets=targets)
for loop in driver_loops:
loops = FindNodes(ir.Loop).visit(loop.body)
kernel_loops = [l for l in loops if l.variable == self.horizontal.index]
Expand Down
160 changes: 7 additions & 153 deletions loki/transformations/single_column/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,12 @@
symbols as sym, FindExpressions, SubstituteExpressions
)
from loki.ir import nodes as ir, FindNodes, Transformer
from loki.logging import debug
from loki.tools import as_tuple
from loki.types import SymbolAttributes, BasicType


from loki.transformations.sanitise import resolve_associates
from loki.transformations.utilities import (
get_integer_variable, get_loop_bounds, check_routine_pragmas
)


__all__ = ['SCCBaseTransformation']
Expand Down Expand Up @@ -43,89 +43,6 @@ def __init__(self, horizontal, directive=None):
assert directive in [None, 'openacc']
self.directive = directive

@classmethod
def check_routine_pragmas(cls, routine, directive):
"""
Check if routine is marked as sequential or has already been processed.

Parameters
----------
routine : :any:`Subroutine`
Subroutine to perform checks on.
directive: string or None
Directives flavour to use for parallelism annotations; either
``'openacc'`` or ``None``.
"""

pragmas = FindNodes(ir.Pragma).visit(routine.ir)
routine_pragmas = [p for p in pragmas if p.keyword.lower() in ['loki', 'acc']]
routine_pragmas = [p for p in routine_pragmas if 'routine' in p.content.lower()]

seq_pragmas = [r for r in routine_pragmas if 'seq' in r.content.lower()]
if seq_pragmas:
loki_seq_pragmas = [r for r in routine_pragmas if 'loki' == r.keyword.lower()]
if loki_seq_pragmas:
if directive == 'openacc':
# Mark routine as acc seq
mapper = {seq_pragmas[0]: None}
routine.spec = Transformer(mapper).visit(routine.spec)
routine.body = Transformer(mapper).visit(routine.body)

# Append the acc pragma to routine.spec, regardless of where the corresponding
# loki pragma is found
routine.spec.append(ir.Pragma(keyword='acc', content='routine seq'))
return True

vec_pragmas = [r for r in routine_pragmas if 'vector' in r.content.lower()]
if vec_pragmas:
if directive == 'openacc':
return True

return False

@classmethod
def get_horizontal_loop_bounds(cls, routine, horizontal):
"""
Check for horizontal loop bounds in a :any:`Subroutine`.

Parameters
----------
routine : :any:`Subroutine`
Subroutine to perform checks on.
horizontal : :any:`Dimension`
:any:`Dimension` object describing the variable conventions used in code
to define the horizontal data dimension and iteration space.
"""

bounds = ()
variables = routine.variables
for name, _bounds in zip(['start', 'end'], horizontal.bounds_expressions):
for bound in _bounds:
if bound.split('%', maxsplit=1)[0] in variables:
bounds += (bound,)
break
else:
raise RuntimeError(f'No horizontol {name} variable matching {_bounds[0]} found in {routine.name}')

return bounds

@classmethod
def get_integer_variable(cls, routine, name):
"""
Find a local variable in the routine, or create an integer-typed one.

Parameters
----------
routine : :any:`Subroutine`
The subroutine in which to find the variable
name : string
Name of the variable to find the in the routine.
"""
if not (v_index := routine.symbol_map.get(name, None)):
dtype = SymbolAttributes(BasicType.INTEGER)
v_index = sym.Variable(name=name, type=dtype, scope=routine)
return v_index

@classmethod
def resolve_masked_stmts(cls, routine, loop_variable):
"""
Expand Down Expand Up @@ -180,25 +97,14 @@ def resolve_vector_dimension(cls, routine, loop_variable, bounds):

bounds_str = f'{bounds[0]}:{bounds[1]}'

variable_map = routine.variable_map
try:
bounds_v = (routine.resolve_typebound_var(bounds[0], variable_map),
routine.resolve_typebound_var(bounds[1], variable_map))
except KeyError:
debug(
'SCCBaseTransformation.resolve_vector_dimension: '
f'Dimension bound {bounds[0]} or {bounds[1]} not found in {routine.name}.'
)
return

mapper = {}
for stmt in FindNodes(ir.Assignment).visit(routine.body):
ranges = [e for e in FindExpressions().visit(stmt)
if isinstance(e, sym.RangeIndex) and e == bounds_str]
if ranges:
exprmap = {r: loop_variable for r in ranges}
loop = ir.Loop(
variable=loop_variable, bounds=sym.LoopRange(bounds_v),
variable=loop_variable, bounds=sym.LoopRange(bounds),
body=as_tuple(SubstituteExpressions(exprmap).visit(stmt))
)
mapper[stmt] = loop
Expand All @@ -209,58 +115,6 @@ def resolve_vector_dimension(cls, routine, loop_variable, bounds):
if mapper and loop_variable not in routine.variables:
routine.variables += as_tuple(loop_variable)

@staticmethod
def is_driver_loop(loop, targets):
"""
Test/check whether a given loop is a *driver loop*.

Parameters
----------
loop : :any: `Loop`
The loop to test if it is a *driver loop*.
targets : list or string
List of subroutines that are to be considered as part of
the transformation call tree.
"""
if loop.pragma:
for pragma in loop.pragma:
if pragma.keyword.lower() == "loki" and pragma.content.lower() == "driver-loop":
return True
for call in FindNodes(ir.CallStatement).visit(loop.body):
if call.name in targets:
return True
return False

@classmethod
def find_driver_loops(cls, routine, targets):
"""
Find and return all driver loops of a given `routine`.

A *driver loop* is specified either by a call to a routine within
`targets` or by the pragma `!$loki driver-loop`.

Parameters
----------
routine : :any:`Subroutine`
The subroutine in which to find the driver loops.
targets : list or string
List of subroutines that are to be considered as part of
the transformation call tree.
"""

driver_loops = []
nested_driver_loops = []
for loop in FindNodes(ir.Loop).visit(routine.body):
if loop in nested_driver_loops:
continue

if not cls.is_driver_loop(loop, targets):
continue

driver_loops.append(loop)
loops = FindNodes(ir.Loop).visit(loop.body)
nested_driver_loops.extend(loops)
return driver_loops

def transform_subroutine(self, routine, **kwargs):
"""
Expand Down Expand Up @@ -292,14 +146,14 @@ def process_kernel(self, routine):
"""

# Bail if routine is marked as sequential or routine has already been processed
if self.check_routine_pragmas(routine, self.directive):
if check_routine_pragmas(routine, self.directive):
return

# check for horizontal loop bounds in subroutine symbol table
bounds = self.get_horizontal_loop_bounds(routine, self.horizontal)
bounds = get_loop_bounds(routine, dimension=self.horizontal)

# Find the iteration index variable for the specified horizontal
v_index = self.get_integer_variable(routine, name=self.horizontal.index)
v_index = get_integer_variable(routine, name=self.horizontal.index)

# Associates at the highest level, so they don't interfere
# with the sections we need to do for detecting subroutine calls
Expand Down
6 changes: 3 additions & 3 deletions loki/transformations/single_column/hoist.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from loki.ir import nodes as ir

from loki.transformations.hoist_variables import HoistVariablesTransformation
from loki.transformations.single_column.base import SCCBaseTransformation
from loki.transformations.utilities import get_integer_variable


__all__ = ['SCCHoistTemporaryArraysTransformation']
Expand Down Expand Up @@ -55,7 +55,7 @@ def driver_variable_declaration(self, routine, variables):
'for array argument hoisting.'
)

block_var = SCCBaseTransformation.get_integer_variable(routine, self.block_dim.size)
block_var = get_integer_variable(routine, self.block_dim.size)
routine.variables += tuple(
v.clone(
dimensions=v.dimensions + (block_var,),
Expand Down Expand Up @@ -95,7 +95,7 @@ def driver_call_argument_remapping(self, routine, call, variables):
'[Loki] SingleColumnCoalescedTransform: No blocking dimension found '
'for array argument hoisting.'
)
idx_var = SCCBaseTransformation.get_integer_variable(routine, self.block_dim.index)
idx_var = get_integer_variable(routine, self.block_dim.index)
if self.as_kwarguments:
new_kwargs = tuple(
(a.name, v.clone(dimensions=tuple(sym.RangeIndex((None, None))
Expand Down
6 changes: 4 additions & 2 deletions loki/transformations/single_column/scc_cuf.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@
from loki.transformations.sanitise import resolve_associates
from loki.transformations.single_column.base import SCCBaseTransformation
from loki.transformations.single_column.vector import SCCDevectorTransformation
from loki.transformations.utilities import single_variable_declaration
from loki.transformations.utilities import (
single_variable_declaration, get_integer_variable
)


__all__ = [
Expand Down Expand Up @@ -745,7 +747,7 @@ def process_routine_kernel(self, routine, depth=1, targets=None):
The subroutines depth
"""

v_index = SCCBaseTransformation.get_integer_variable(routine, name=self.horizontal.index)
v_index = get_integer_variable(routine, name=self.horizontal.index)
resolve_associates(routine)
SCCBaseTransformation.resolve_masked_stmts(routine, loop_variable=v_index)
SCCBaseTransformation.resolve_vector_dimension(routine, loop_variable=v_index, bounds=self.horizontal.bounds)
Expand Down
7 changes: 4 additions & 3 deletions loki/transformations/single_column/tests/test_scc.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@
)

from loki.transformations import (
DataOffloadTransformation, SanitiseTransformation, InlineTransformation
DataOffloadTransformation, SanitiseTransformation,
InlineTransformation, get_loop_bounds
)
from loki.transformations.single_column import (
SCCBaseTransformation, SCCDevectorTransformation,
Expand Down Expand Up @@ -948,11 +949,11 @@ def test_scc_base_horizontal_bounds_checks(frontend, horizontal, horizontal_boun
transform = SCCBaseTransformation(horizontal=horizontal_bounds_aliases)
transform.apply(alias, role='kernel')

bounds = SCCBaseTransformation.get_horizontal_loop_bounds(routine, horizontal_bounds_aliases)
bounds = get_loop_bounds(routine, dimension=horizontal_bounds_aliases)
assert bounds[0] == 'start'
assert bounds[1] == 'end'

bounds = SCCBaseTransformation.get_horizontal_loop_bounds(alias, horizontal_bounds_aliases)
bounds = get_loop_bounds(alias, dimension=horizontal_bounds_aliases)
assert bounds[0] == 'bnds%start'
assert bounds[1] == 'bnds%end'

Expand Down
Loading
Loading