Skip to content

Commit

Permalink
Merge pull request #156 from ecmwf-ifs/naml-scc-recursive-hoist
Browse files Browse the repository at this point in the history
Loki-SCC: Adapt to use recursive hoisting transformation
  • Loading branch information
reuterbal authored Oct 2, 2023
2 parents 9a17903 + 57a59ed commit 19fd1c4
Show file tree
Hide file tree
Showing 8 changed files with 334 additions and 467 deletions.
132 changes: 78 additions & 54 deletions loki/transform/transform_hoist_variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,19 +112,16 @@ class HoistVariablesAnalysis(Transformation):

_key = 'HoistVariablesTransformation'

def __init__(self, key=None, disable=None):
def __init__(self, key=None):
if key is not None:
self._key = key
if disable is None:
self.disable = ()
else:
self.disable = [_.lower() for _ in disable]

def transform_subroutine(self, routine, **kwargs):
"""
Analysis applied to :any:`Subroutine` item.
Collects all the variables to be hoisted, including renaming in order to grant for unique variable names.
Collects all the variables to be hoisted, including renaming
in order to grant for unique variable names.
Parameters
----------
Expand All @@ -136,9 +133,7 @@ def transform_subroutine(self, routine, **kwargs):

role = kwargs.get('role', None)
item = kwargs.get('item', None)
_successors = kwargs.get('successors', ())
successors = [_ for _ in _successors if _.local_name.lower()
not in self.disable and _.name.lower() not in self.disable]
successors = kwargs.get('successors', ())

item.trafo_data[self._key] = {}

Expand All @@ -151,8 +146,7 @@ def transform_subroutine(self, routine, **kwargs):
item.trafo_data[self._key]["to_hoist"] = []
item.trafo_data[self._key]["hoist_variables"] = []

calls = [call for call in FindNodes(CallStatement).visit(routine.body) if call.name
not in self.disable]
calls = FindNodes(CallStatement).visit(routine.body)
call_map = CaseInsensitiveDict((str(call.name), call) for call in calls)

for child in successors:
Expand Down Expand Up @@ -204,13 +198,9 @@ class HoistVariablesTransformation(Transformation):

_key = 'HoistVariablesTransformation'

def __init__(self, key=None, disable=None):
def __init__(self, key=None):
if key is not None:
self._key = key
if disable is None:
self.disable = ()
else:
self.disable = [_.lower() for _ in disable]

def transform_subroutine(self, routine, **kwargs):
"""
Expand All @@ -232,25 +222,20 @@ def transform_subroutine(self, routine, **kwargs):
"""
role = kwargs.get('role', None)
item = kwargs.get('item', None)
# targets = kwargs.get('targets', None)
_successors = kwargs.get('successors', ())
successors = [_ for _ in _successors if _.local_name.lower()
not in self.disable and _.name.lower() not in self.disable]
successor_map = {successor.routine.name: successor for successor in successors}

if item.local_name.lower() in self.disable:
return
successors = kwargs.get('successors', ())
successor_map = CaseInsensitiveDict(
(successor.routine.name, successor) for successor in successors
)

if self._key not in item.trafo_data:
raise RuntimeError(f'{self.__class__.__name__} requires key "{self._key}" in item.trafo_data!\n'
f'Make sure to call HoistVariablesAnalysis (or any derived class) before and to provide '
f'the correct key.')

if role == 'driver':
for var in item.trafo_data[self._key]["to_hoist"]:
self.driver_variable_declaration(routine, var)
self.driver_variable_declaration(routine, item.trafo_data[self._key]["to_hoist"])
else:
# We build the list of tempararies that are hoisted to the calling routine
# We build the list of temporaries that are hoisted to the calling routine
# Because this requires adding an intent, we need to make sure they are not
# declared together with non-hoisted variables
hoisted_temporaries = tuple(
Expand All @@ -261,37 +246,69 @@ def transform_subroutine(self, routine, **kwargs):
routine.arguments += hoisted_temporaries

call_map = {}
calls = [_ for _ in FindNodes(CallStatement).visit(routine.body) if _.name not in self.disable]
for call in calls:
new_args = [arg.clone(dimensions=None) for arg
in successor_map[str(call.routine.name)].trafo_data[self._key]["hoist_variables"]]
arguments = list(call.arguments) + new_args
call_map[call] = call.clone(arguments=as_tuple(arguments))
for call in FindNodes(CallStatement).visit(routine.body):
# Only process calls in this call tree
if str(call.name) not in successor_map:
continue

successor_item = successor_map[str(call.routine.name)]
hoisted_variables = successor_item.trafo_data[self._key]["hoist_variables"]
call_map[call] = self.driver_call_argument_remapping(
routine=routine, call=call, variables=hoisted_variables
)

routine.body = Transformer(call_map).visit(routine.body)

def driver_variable_declaration(self, routine, var):
def driver_variable_declaration(self, routine, variables):
"""
**Override**: Define the variable declaration (and possibly allocation, de-allocation, ...)
for each variable to be hoisted.
**Override**: Define the variable declaration (and possibly
allocation, de-allocation, ...) for each variable to be
hoisted.
Declares hoisted variables with a re-scope.
Parameters
----------
routine : :any:`Subroutine`
The subroutine to add the variable declaration to.
var : :any:`Variable`
The variable to be declared.
variables : tuple of :any:`Variable`
The tuple of variables to be declared.
"""
routine.variables += tuple([var.rescope(routine)])
routine.variables += tuple(v.rescope(routine) for v in variables)

def driver_call_argument_remapping(self, routine, call, variables):
"""
Callback method to re-map hoisted arguments for the driver-level routine.
The callback will simply add all the hoisted variable arrays to the call
without dimension range symbols.
This callback is used to adjust the argument variable mapping, so that
the call signature in the driver can be adjusted to the declaration
scheme of subclassed variants of the basic hoisting tnansformation.
Potentially, different variants of the hoist transformation can override
the behaviour here to map to a differnt call invocation scheme.
Parameters
----------
routine : :any:`Subroutine`
The subroutine to add the variable declaration to.
call : :any:`CallStatement`
Call object to which hoisted variables will be added.
variables : tuple of :any:`Variable`
The tuple of variables to be declared.
"""
# pylint: disable=unused-argument
new_args = tuple(v.clone(dimensions=None) for v in variables)
return call.clone(arguments=call.arguments + new_args)


class HoistTemporaryArraysAnalysis(HoistVariablesAnalysis):
"""
**Specialisation** for the *Analysis* part of the hoist variables functionality/transformation, to hoist only
temporary arrays and if provided only temporary arrays with specific variables/variable names within the
array dimensions.
**Specialisation** for the *Analysis* part of the hoist variables
functionality/transformation, to hoist only temporary arrays and
if provided only temporary arrays with specific variables/variable
names within the array dimensions.
.. code-block::python
Expand All @@ -308,8 +325,8 @@ class HoistTemporaryArraysAnalysis(HoistVariablesAnalysis):
for the array dimensions.
"""

def __init__(self, key=None, disable=None, dim_vars=None, **kwargs):
super().__init__(key=key, disable=disable, **kwargs)
def __init__(self, key=None, dim_vars=None, **kwargs):
super().__init__(key=key, **kwargs)
self.dim_vars = dim_vars
if self.dim_vars is not None:
assert is_iterable(self.dim_vars)
Expand All @@ -336,8 +353,10 @@ def find_variables(self, routine):

class HoistTemporaryArraysTransformationAllocatable(HoistVariablesTransformation):
"""
**Specialisation** for the *Synthesis* part of the hoist variables functionality/transformation, to hoist temporary
arrays and make them ``allocatable``, including the actual *allocation* and *de-allocation*.
**Specialisation** for the *Synthesis* part of the hoist variables
functionality/transformation, to hoist temporary arrays and make
them ``allocatable``, including the actual *allocation* and
*de-allocation*.
Parameters
----------
Expand All @@ -346,21 +365,26 @@ class HoistTemporaryArraysTransformationAllocatable(HoistVariablesTransformation
these transformations are carried out in succession.
"""

def __init__(self, key=None, disable=None, **kwargs):
super().__init__(key=key, disable=disable, **kwargs)
def __init__(self, key=None, **kwargs):
super().__init__(key=key, **kwargs)

def driver_variable_declaration(self, routine, var):
def driver_variable_declaration(self, routine, variables):
"""
Declares hoisted arrays as ``allocatable``, including *allocation* and *de-allocation*.
Parameters
----------
routine : :any:`Subroutine`
The subroutine to add the variable declaration to.
var : :any:`Variable`
variables : tuple of :any:`Variable`
The array to be declared, allocated and de-allocated.
"""
routine.variables += tuple([var.clone(scope=routine, dimensions=as_tuple(
[sym.RangeIndex((None, None))] * (len(var.dimensions))), type=var.type.clone(allocatable=True))])
routine.body.prepend(Allocation((var.clone(),)))
routine.body.append(Deallocation((var.clone(dimensions=None),)))
for var in variables:
routine.variables += as_tuple(
var.clone(
dimensions=as_tuple([sym.RangeIndex((None, None))] * len(var.dimensions)),
type=var.type.clone(allocatable=True), scope=routine
)
)
routine.body.prepend(Allocation((var.clone(),)))
routine.body.append(Deallocation((var.clone(dimensions=None),)))
27 changes: 6 additions & 21 deletions loki/transform/transform_parametrise.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,17 +121,15 @@ def error_stop(**kwargs):
dic2p = {'a': 12, 'b': 11}
transformation = ParametriseTransformation(dic2p=dic2p, disable=("ignore_this_func", "ignore_another_func"),
abort_callback=error_stop, entry_points=("driver1", "driver2"))
transformation = ParametriseTransformation(dic2p=dic2p, abort_callback=error_stop,
entry_points=("driver1", "driver2"))
scheduler.process(transformation=transformation)
Parameters
----------
dic2p: dict
Dictionary of variable names and corresponding values to be parametrised.
disable: tuple
Tuple of subroutines not to be processed.
replace_by_value: bool
Replace variables entirely by value (default: `False`)
entry_points: None or tuple
Expand All @@ -152,12 +150,8 @@ def error_stop(**kwargs):

_key = "ParametriseTransformation"

def __init__(self, dic2p, disable=None, replace_by_value=False, entry_points=None, abort_callback=None, key=None):
def __init__(self, dic2p, replace_by_value=False, entry_points=None, abort_callback=None, key=None):
self.dic2p = dic2p
if disable is None:
self.disable = ()
else:
self.disable = [_.upper() for _ in disable]
self.replace_by_value = replace_by_value
if entry_points is not None:
self.entry_points = [_.upper() for _ in entry_points]
Expand Down Expand Up @@ -188,17 +182,8 @@ def transform_subroutine(self, routine, **kwargs):
role = kwargs.get('role', None)

_successors = kwargs.get('successors', None)
successors = []
for successor in _successors:
append = True
for _disable in self.disable:
if _disable.upper() in successor.name.upper():
append = False
break
if append:
successors.append(successor)
successor_map = {successor.routine.name: successor for successor in successors if successor.name.upper()
not in self.disable}
successor_map = {successor.routine.name: successor for successor in _successors}
successors = [successor.local_name.upper() for successor in _successors]

# decide whether subroutine is an entry point or not
process_entry_point = False
Expand Down Expand Up @@ -263,7 +248,7 @@ def transform_subroutine(self, routine, **kwargs):
# remove variables to be parametrised from all call statements
call_map = {}
for call in FindNodes(ir.CallStatement).visit(routine.body):
if str(call.name).upper() not in self.disable:
if str(call.name).upper() in successors:
successor_map[str(call.name)].trafo_data[self._key] = {}
arg_map = dict(call.arg_iter())
arg_map_reversed = {v: k for k, v in arg_map.items()}
Expand Down
Loading

0 comments on commit 19fd1c4

Please sign in to comment.