diff --git a/transformations/tests/test_data_offload.py b/transformations/tests/test_data_offload.py index 89baa0bd7..425c1ba16 100644 --- a/transformations/tests/test_data_offload.py +++ b/transformations/tests/test_data_offload.py @@ -15,7 +15,10 @@ gettempdir, Scheduler, OMNI, Import ) from conftest import available_frontends -from transformations import DataOffloadTransformation, GlobalVariableAnalysis, GlobalVarOffloadTransformation +from transformations import ( + DataOffloadTransformation, GlobalVariableAnalysis, + GlobalVarOffloadTransformation, GlobalVarHoistTransformation +) @pytest.fixture(scope='module', name='here') @@ -378,16 +381,31 @@ def test_global_variable_analysis(frontend, key, config, global_variable_analysi key = GlobalVariableAnalysis._key # Validate the analysis trafo_data - nfld_dim = '1:3' if frontend == OMNI else 'nfld' - nval_dim = '1:5' if frontend == OMNI else 'nval' + + # OMNI handles array indices and parameters differently + if frontend == OMNI: + nfld_dim = '1:3' + nval_dim = '1:5' + nfld_data = set() + nval_data = set() + nval_offload = set() + nfld_offload = set() + else: + nfld_dim = 'nfld' + nval_dim = 'nval' + nfld_data = {('nfld', 'global_var_analysis_header_mod')} + nval_data = {('nval', 'global_var_analysis_header_mod')} + nval_offload = {'nval'} + nfld_offload = {'nfld'} + expected_trafo_data = { 'global_var_analysis_header_mod#nval': { 'declares': {f'iarr({nfld_dim})', f'rarr({nval_dim}, {nfld_dim})'}, - 'offload': set() + 'offload': nval_offload }, 'global_var_analysis_header_mod#nfld': { 'declares': {f'iarr({nfld_dim})', f'rarr({nval_dim}, {nfld_dim})'}, - 'offload': set() + 'offload': nfld_offload }, 'global_var_analysis_header_mod#iarr': { 'declares': {f'iarr({nfld_dim})', f'rarr({nval_dim}, {nfld_dim})'}, @@ -408,23 +426,24 @@ def test_global_variable_analysis(frontend, key, config, global_variable_analysi 'global_var_analysis_data_mod#some_routine': {'defines_symbols': set(), 'uses_symbols': set()}, 
'global_var_analysis_kernel_mod#kernel_a': { 'defines_symbols': set(), - 'uses_symbols': { + 'uses_symbols': nval_data | nfld_data | { (f'iarr({nfld_dim})', 'global_var_analysis_header_mod'), (f'rarr({nval_dim}, {nfld_dim})', 'global_var_analysis_header_mod') } }, 'global_var_analysis_kernel_mod#kernel_b': { 'defines_symbols': {('rdata(:, :, :)', 'global_var_analysis_data_mod')}, - 'uses_symbols': { + 'uses_symbols': nfld_data | { ('rdata(:, :, :)', 'global_var_analysis_data_mod'), ('tt', 'global_var_analysis_data_mod'), ('tt%vals', 'global_var_analysis_data_mod'), (f'iarr({nfld_dim})', 'global_var_analysis_header_mod') } }, '#driver': { 'defines_symbols': {('rdata(:, :, :)', 'global_var_analysis_data_mod')}, - 'uses_symbols': { - ('rdata(:, :, :)', 'global_var_analysis_data_mod'), ('tt', 'global_var_analysis_data_mod'), - ('tt%vals', 'global_var_analysis_data_mod'), (f'iarr({nfld_dim})', 'global_var_analysis_header_mod'), + 'uses_symbols': nval_data | nfld_data | { + ('rdata(:, :, :)', 'global_var_analysis_data_mod'), + ('tt', 'global_var_analysis_data_mod'), ('tt%vals', 'global_var_analysis_data_mod'), + (f'iarr({nfld_dim})', 'global_var_analysis_header_mod'), (f'rarr({nval_dim}, {nfld_dim})', 'global_var_analysis_header_mod') } } @@ -634,3 +653,134 @@ def test_transformation_global_var_import_derived_type(here, config, frontend): # Note: g is not offloaded because it is not used by the kernel (albeit imported) assert 'p0' in pragmas[1].content assert 'p_array' in pragmas[2].content + + +@pytest.mark.parametrize('frontend', available_frontends()) +@pytest.mark.parametrize('hoist_parameters', (False, True)) +@pytest.mark.parametrize('ignore_modules', (None, ('moduleb',))) +def test_transformation_global_var_hoist(here, config, frontend, hoist_parameters, ignore_modules): + """ + Test hoisting of global variable imports. 
+ """ + config['default']['enable_imports'] = True + config['routines'] = { + 'driver': {'role': 'driver'} + } + + scheduler = Scheduler(paths=here/'sources/projGlobalVarImports', config=config, frontend=frontend) + scheduler.process(transformation=GlobalVariableAnalysis()) + scheduler.process(transformation=GlobalVarHoistTransformation(hoist_parameters=hoist_parameters, + ignore_modules=ignore_modules)) + + driver = scheduler['#driver'].routine + kernel0 = scheduler['#kernel0'].routine + kernel_map = {key: scheduler[f'#{key}'].routine for key in ['kernel1', 'kernel2', 'kernel3']} + + # symbols within each module + expected_symbols = {'modulea': ['var0', 'var1'], 'moduleb': ['var2', 'var3'], + 'modulec': ['var4', 'var5']} + # expected intent of those variables (if hoisted) + var_intent_map = {'var0': 'in', 'var1': 'in', 'var2': 'in', + 'var3': 'in', 'var4': 'inout', 'var5': 'inout', 'tmp': None} + # DRIVER + imports = FindNodes(Import).visit(driver.spec) + import_names = [_import.module.lower() for _import in imports] + # check driver imports + expected_driver_modules = ['modulec'] + expected_driver_modules += ['moduleb'] if ignore_modules is None else [] + # OMNI handles parameters differently, ModuleA only contains parameters + if frontend != OMNI: + expected_driver_modules += ['modulea'] if hoist_parameters else [] + assert len(imports) == len(expected_driver_modules) + assert sorted(expected_driver_modules) == sorted(import_names) + for _import in imports: + assert sorted([sym.name for sym in _import.symbols]) == expected_symbols[_import.module.lower()] + # check driver call + driver_calls = FindNodes(CallStatement).visit(driver.body) + expected_args = [] + for module in expected_driver_modules: + expected_args.extend(expected_symbols[module]) + assert [arg.name for arg in driver_calls[0].arguments] == sorted(expected_args) + + originally = {'kernel1': ['modulea'], 'kernel2': ['moduleb'], + 'kernel3': ['moduleb', 'modulec']} + # KERNEL0 + assert [arg.name for 
arg in kernel0.arguments] == sorted(expected_args) + assert [arg.name for arg in kernel0.variables] == sorted(expected_args) + for var in kernel0.variables: + assert kernel0.variable_map[var.name.lower()].type.intent == var_intent_map[var.name.lower()] + kernel0_calls = FindNodes(CallStatement).visit(kernel0.body) + # KERNEL1 & KERNEL2 & KERNEL3 + for call in kernel0_calls: + expected_args = [] + expected_imports = [] + kernel_expected_symbols = [] + for module in originally[call.routine.name]: + # always, since at least 'some_func' is imported + if call.routine.name == 'kernel1' and module == 'modulea': + expected_imports.append(module) + kernel_expected_symbols.append('some_func') + if module in expected_driver_modules: + expected_args.extend(expected_symbols[module]) + else: + # already added + if module != 'modulea': + expected_imports.append(module) + kernel_expected_symbols.extend(expected_symbols[module]) + assert len(expected_args) == len(call.arguments) + assert [arg.name for arg in call.arguments] == expected_args + assert [arg.name for arg in kernel_map[call.routine.name].arguments] == expected_args + for var in kernel_map[call.routine.name].variables: + var_intent = kernel_map[call.routine.name].variable_map[var.name.lower()].type.intent + assert var_intent == var_intent_map[var.name.lower()] + if call.routine.name in ['kernel1', 'kernel2']: + expected_args = ['tmp'] + expected_args + assert [arg.name for arg in kernel_map[call.routine.name].variables] == expected_args + kernel_imports = FindNodes(Import).visit(call.routine.spec) + assert sorted([_import.module.lower() for _import in kernel_imports]) == sorted(expected_imports) + imported_symbols = [] # _import.symbols for _import in kernel_imports] + for _import in kernel_imports: + imported_symbols.extend([sym.name.lower() for sym in _import.symbols]) + assert sorted(imported_symbols) == sorted(kernel_expected_symbols) + + +@pytest.mark.parametrize('frontend', available_frontends()) 
+@pytest.mark.parametrize('hoist_parameters', (False, True)) +def test_transformation_global_var_derived_type_hoist(here, config, frontend, hoist_parameters): + """ + Test hoisting of derived-type global variable imports. + """ + + config['default']['enable_imports'] = True + config['routines'] = { + 'driver_derived_type': {'role': 'driver'} + } + + scheduler = Scheduler(paths=here/'sources/projGlobalVarImports', config=config, frontend=frontend) + scheduler.process(transformation=GlobalVariableAnalysis()) + scheduler.process(transformation=GlobalVarHoistTransformation(hoist_parameters)) + + driver = scheduler['#driver_derived_type'].routine + kernel = scheduler['#kernel_derived_type'].routine + + # DRIVER + imports = FindNodes(Import).visit(driver.spec) + assert len(imports) == 1 + assert imports[0].module.lower() == 'module_derived_type' + assert sorted([sym.name.lower() for sym in imports[0].symbols]) == sorted(['p', 'p_array', 'p0']) + calls = FindNodes(CallStatement).visit(driver.body) + assert len(calls) == 1 + # KERNEL + assert [arg.name for arg in calls[0].arguments] == ['p', 'p0', 'p_array'] + assert [arg.name for arg in kernel.arguments] == ['p', 'p0', 'p_array'] + kernel_imports = FindNodes(Import).visit(kernel.spec) + assert len(kernel_imports) == 1 + assert [sym.name.lower() for sym in kernel_imports[0].symbols] == ['g'] + assert sorted([var.name for var in kernel.variables]) == ['i', 'j', 'p', 'p0', 'p_array'] + assert kernel.variable_map['p_array'].type.allocatable + assert kernel.variable_map['p_array'].type.intent == 'inout' + assert kernel.variable_map['p_array'].type.dtype.name == 'point' + assert kernel.variable_map['p'].type.intent == 'inout' + assert kernel.variable_map['p'].type.dtype.name == 'point' + assert kernel.variable_map['p0'].type.intent == 'in' + assert kernel.variable_map['p0'].type.dtype.name == 'point' diff --git a/transformations/transformations/data_offload.py b/transformations/transformations/data_offload.py index 
1d113abdc..058d8ed7b 100644 --- a/transformations/transformations/data_offload.py +++ b/transformations/transformations/data_offload.py @@ -17,7 +17,7 @@ __all__ = [ 'DataOffloadTransformation', 'GlobalVariableAnalysis', - 'GlobalVarOffloadTransformation' + 'GlobalVarOffloadTransformation', 'GlobalVarHoistTransformation' ] @@ -210,7 +210,8 @@ class GlobalVariableAnalysis(Transformation): For procedures, use the the Loki dataflow analysis functionality to compile a list of used and/or defined variables (i.e., read and/or written). - Store these under the keys ``'uses_symbols'`` and ``'defines_symbols'``, respectively. + Store these under the keys ``'uses_symbols'`` and ``'defines_symbols'``, + respectively. For modules/:any:`GlobalVarImportItem`, store the list of variables declared in the module under the key ``'declares'`` and out of this the subset of variables that @@ -300,9 +301,6 @@ def transform_subroutine(self, routine, **kwargs): uses_imported_symbols = {var for var in uses_imported_symbols if isinstance(var, (Scalar, Array))} defines_imported_symbols = {var for var in defines_imported_symbols if isinstance(var, (Scalar, Array))} - # Discard parameters (which are read-only by definition) - uses_imported_symbols = {var for var in uses_imported_symbols if not var.type.parameter} - def _map_var_to_module(var): if var.parent: module = var.parents[0].type.module @@ -465,10 +463,10 @@ def transform_module(self, module, **kwargs): for v in as_tuple(acc_pragma_parameters.get('create')) ])) - # Build list of symbols to be offloaded + # Build list of symbols to be offloaded (discard variables being parameter) offload_variables = { var.parents[0] if var.parent else var - for var in item.trafo_data[self._key].get('offload', ()) + for var in item.trafo_data[self._key].get('offload', ()) if not var.type.parameter } if (invalid_vars := offload_variables - set(module.variables)): @@ -508,6 +506,9 @@ def process_driver(self, routine, successors): for item in successors: 
defines_symbols |= item.trafo_data.get(self._key, {}).get('defines_symbols', set()) uses_symbols |= item.trafo_data.get(self._key, {}).get('uses_symbols', set()) + # discard variables being parameter + parameters = {(var, module) for var, module in uses_symbols if var.type.parameter} + uses_symbols ^= parameters # Filter out arrays of derived types and nested derived types # For these, automatic offloading is currently not supported @@ -613,3 +614,270 @@ def process_driver(self, routine, successors): '![Loki::GlobalVarOffloadTransformation] ' '-------- Added global variable imports for offload directives -----------' ))) + + +class GlobalVarHoistTransformation(Transformation): + """ + Transformation to hoist module variables used in device routines + + This requires a prior analysis pass with :any:`GlobalVariableAnalysis` to collect + the relevant global variable use information. + + Modules to be ignored can be specified. Further, it is possible to + configure whether parameters/compile time constants are hoisted as well + or not. + + .. note:: + Hoisted variables that could theoretically be ``intent(out)`` + are nevertheless specified as ``intent(inout)``. + + For example, the following code: + + .. code-block:: fortran + + module moduleB + real :: var2 + real :: var3 + end module moduleB + + module moduleC + real :: var4 + real :: var5 + end module moduleC + + subroutine driver() + implicit none + + call kernel() + + end subroutine driver + + subroutine kernel() + use moduleB, only: var2,var3 + use moduleC, only: var4,var5 + implicit none + + var4 = var2 + var5 = var3 + + end subroutine kernel + + is transformed to: + + ..
code-block:: fortran + + module moduleB + real :: var2 + real :: var3 + end module moduleB + + module moduleC + real :: var4 + real :: var5 + end module moduleC + + subroutine driver() + use moduleB, only: var2,var3 + use moduleC, only: var4,var5 + implicit none + + call kernel(var2, var3, var4, var5) + + end subroutine driver + + subroutine kernel(var2, var3, var4, var5) + implicit none + real, intent(in) :: var2 + real, intent(in) :: var3 + real, intent(inout) :: var4 + real, intent(inout) :: var5 + + var4 = var2 + var5 = var3 + + end subroutine kernel + + Parameters + ---------- + hoist_parameters : bool, optional + Whether or not to hoist module variables being parameter/compile + time constants (default: `False`). + ignore_modules : (list, tuple) of str + Modules to be ignored (default: `None`, thus no module to be ignored). + key : str, optional + Overwrite the key that is used to store analysis results in ``trafo_data``. + """ + item_filter = (SubroutineItem,) + + def __init__(self, hoist_parameters=False, ignore_modules=None, key=None): + self._key = key or GlobalVariableAnalysis._key + self.hoist_parameters = hoist_parameters + self.ignore_modules = [module.lower() for module in as_tuple(ignore_modules)] + + def transform_subroutine(self, routine, **kwargs): + """ + Hoist module variables. + """ + role = kwargs.get('role') + successors = kwargs.get('successors', ()) + item = kwargs.get('item', None) + + if role == 'driver': + self.process_driver(routine, successors) + elif role == 'kernel': + self.process_kernel(routine, successors, item) + + def process_driver(self, routine, successors): + """ + Hoist module variables for driver routines. + + This includes: appending the corresponding variables + to calls within the driver and adding the relevant + imports. 
+ """ + # get symbols per routine (successors) + defines_symbols, uses_symbols = self._get_symbols(successors) + + # append symbols to calls (arguments) + self._append_call_arguments(routine, uses_symbols, defines_symbols) + + # combine/collect symbols disregarding routine + all_defines_symbols = set.union(*defines_symbols.values(), set()) + all_uses_symbols = set.union(*uses_symbols.values(), set()) + # add imports for symbols hoisted + symbol_map = defaultdict(set) + for var, module in chain(all_uses_symbols, all_defines_symbols): + # filter modules that are supposed to be ignored + if module.lower() in self.ignore_modules: + continue + symbol_map[module].add(var.parents[0] if var.parent else var) + import_map = CaseInsensitiveDict() + scope = routine + while scope: + import_map.update(scope.import_map) + scope = scope.parent + missing_imports_map = defaultdict(set) + for module, variables in symbol_map.items(): + missing_imports_map[module] |= {var for var in variables if var.name not in import_map} + if missing_imports_map: + routine.spec.prepend(Comment(text=( + '![Loki::GlobalVarHoistTransformation] ---------------------------------------' + ))) + for module, variables in missing_imports_map.items(): + symbols = tuple(var.clone(dimensions=None, scope=routine) for var in variables) + routine.spec.prepend(Import(module=module, symbols=symbols)) + + routine.spec.prepend(Comment(text=( + '![Loki::GlobalVarHoistTransformation] ' + '-------- Added global variable imports for offload directives -----------' + ))) + + def process_kernel(self, routine, successors, item): + """ + Hoist module variables for kernel routines. + + This includes: appending the corresponding variables + to the routine arguments as well as to calls within the kernel + and removing the imports that became unused.
+ """ + # get symbols per routine (successors) + defines_symbols, uses_symbols = self._get_symbols(successors) + + # append symbols to routine (arguments) + self._append_routine_arguments(routine, item) + + # append symbols to calls (arguments) + self._append_call_arguments(routine, uses_symbols, defines_symbols) + + # get symbols for this routine/kernel + kernel_defines_symbols = item.trafo_data.get(self._key, {}).get('defines_symbols', set()) + kernel_uses_symbols = item.trafo_data.get(self._key, {}).get('uses_symbols', set()) + # remove imports for symbols hoisted + symbol_map = defaultdict(set) + for var, module in chain(kernel_uses_symbols, kernel_defines_symbols): + # filter modules that are supposed to be ignored + if module.lower() in self.ignore_modules: + continue + symbol_map[module].add(var.parents[0] if var.parent else var) + import_map = CaseInsensitiveDict( + (s.name, imprt) for imprt in routine.all_imports[::-1] for s in imprt.symbols + ) + redundant_imports_map = defaultdict(set) + for module, variables in symbol_map.items(): + redundant = [var.parent[0] if var.parent else var for var in variables] + redundant = {var.clone(dimensions=None) for var in redundant if var.name in import_map} + redundant_imports_map[module] |= redundant + import_map = {} + imports = FindNodes(Import).visit(routine.spec) + for _import in imports: + new_symbols = tuple(var.clone(dimensions=None, scope=routine) + for var in set(_import.symbols)-redundant_imports_map[_import.module.lower()]) + if new_symbols: + import_map[_import] = _import.clone(symbols=new_symbols) + else: + import_map[_import] = None + routine.spec = Transformer(import_map).visit(routine.spec) + + def _get_symbols(self, successors): + """ + Get module variables/symbols (grouped by routine/successor). 
+ """ + defines_symbols = {} + uses_symbols = {} + for item in successors: + if isinstance(item, GlobalVarImportItem): + continue + defines_symbols[item.routine.name] = set() + uses_symbols[item.routine.name] = set() + defines_symbols[item.routine.name] = item.trafo_data.get(self._key, {}).get('defines_symbols', set()) + uses_symbols[item.routine.name] = item.trafo_data.get(self._key, {}).get('uses_symbols', set()) + # remove parameters if hoist_parameters is False + if not self.hoist_parameters: + parameters = {(var, module) for var, module in uses_symbols[item.routine.name] if var.type.parameter} + uses_symbols[item.routine.name] ^= parameters + return defines_symbols, uses_symbols + + def _append_call_arguments(self, routine, uses_symbols, defines_symbols): + """ + Helper to append variables to the call(s) (arguments). + """ + symbol_map = defaultdict(set) + for key, _ in uses_symbols.items(): + all_symbols = uses_symbols[key]|defines_symbols[key] + for var, module in all_symbols: + # filter modules that are supposed to be ignored + if module.lower() in self.ignore_modules: + continue + symbol_map[key].add(var.parents[0] if var.parent else var) + call_map = {} + calls = FindNodes(CallStatement).visit(routine.body) + for call in calls: + if call.routine.name in uses_symbols: + arguments = call.arguments + new_args = sorted([var.clone(dimensions=None) for var in symbol_map[call.routine.name]], + key=lambda symbol: symbol.name) + call_map[call] = call.clone(arguments=arguments + tuple(new_args)) + routine.body = Transformer(call_map).visit(routine.body) + + def _append_routine_arguments(self, routine, item): + """ + Helper to append variables to the routine (arguments). 
+ """ + all_defines_symbols = item.trafo_data.get(self._key, {}).get('defines_symbols', set()) + all_defines_vars = [var.parents[0] if var.parent else var for var, _ in all_defines_symbols] + all_uses_symbols = item.trafo_data.get(self._key, {}).get('uses_symbols', set()) + # remove parameters if hoist_parameters is False + if not self.hoist_parameters: + parameters = {(var, module) for var, module in all_uses_symbols if var.type.parameter} + all_uses_symbols ^= parameters + all_symbols = all_uses_symbols|all_defines_symbols + new_arguments = [] + for var, module in all_symbols: + # filter modules that are supposed to be ignored + if module.lower() in self.ignore_modules: + continue + new_arguments.append(var.parents[0] if var.parent else var) + new_arguments = set(new_arguments) # remove duplicates + new_arguments = [arg.clone(type=arg.type.clone(intent='inout' if arg in all_defines_vars + else 'in', parameter=False, initial=None)) for arg in new_arguments] + routine.arguments += tuple(sorted(new_arguments, key=lambda symbol: symbol.name))