From 14ef3491b0af1bbb72d4433913f4bb0ae58429f6 Mon Sep 17 00:00:00 2001 From: Michael Staneker Date: Wed, 31 Jan 2024 15:19:40 +0000 Subject: [PATCH 01/10] First implementation of a 'GlobalVarHoistTransformation' trafo (utilizing the slightly modified 'GlobalVariableAnalysis' for analysis) --- transformations/tests/test_data_offload.py | 52 +++++- .../transformations/data_offload.py | 176 +++++++++++++++++- 2 files changed, 223 insertions(+), 5 deletions(-) diff --git a/transformations/tests/test_data_offload.py b/transformations/tests/test_data_offload.py index 89baa0bd7..a4e80e40e 100644 --- a/transformations/tests/test_data_offload.py +++ b/transformations/tests/test_data_offload.py @@ -12,10 +12,13 @@ from loki import ( Sourcefile, FindNodes, Pragma, PragmaRegion, Loop, CallStatement, pragma_regions_attached, get_pragma_parameters, - gettempdir, Scheduler, OMNI, Import + gettempdir, Scheduler, OMNI, Import, fgen ) from conftest import available_frontends -from transformations import DataOffloadTransformation, GlobalVariableAnalysis, GlobalVarOffloadTransformation +from transformations import ( + DataOffloadTransformation, GlobalVariableAnalysis, + GlobalVarOffloadTransformation, GlobalVarHoistTransformation +) @pytest.fixture(scope='module', name='here') @@ -435,6 +438,8 @@ def test_global_variable_analysis(frontend, key, config, global_variable_analysi if item == 'global_var_analysis_data_mod#some_type': continue for trafo_data_key, trafo_data_value in item.trafo_data[key].items(): + if trafo_data_key not in ('defines_symbols', 'uses_symbols'): + continue assert ( sorted( tuple(str(vv) for vv in v) if isinstance(v, tuple) else str(v) @@ -634,3 +639,46 @@ def test_transformation_global_var_import_derived_type(here, config, frontend): # Note: g is not offloaded because it is not used by the kernel (albeit imported) assert 'p0' in pragmas[1].content assert 'p_array' in pragmas[2].content + + +@pytest.mark.parametrize('frontend', available_frontends()) +def test_transformation_global_var_hoist(here, config, frontend): + """ + Test the generation of offload instructions of global variable imports. + """ + config['default']['enable_imports'] = True + config['routines'] = { + 'driver': {'role': 'driver'} + } + + scheduler = Scheduler(paths=here/'sources/projGlobalVarImports', config=config, frontend=frontend) + scheduler.process(transformation=GlobalVariableAnalysis()) + scheduler.process(transformation=GlobalVarHoistTransformation(hoist_parameters=False)) + + print("") + print("") + + driver = scheduler['#driver'].routine + kernel0 = scheduler['#kernel0'].routine + kernel1 = scheduler['#kernel1'].routine + kernel2 = scheduler['#kernel2'].routine + kernel3 = scheduler['#kernel3'].routine + moduleA = scheduler['modulea#var0'].scope + moduleB = scheduler['moduleb#var2'].scope + moduleC = scheduler['modulec#var4'].scope + + print(fgen(driver)) + print("----------") + print(fgen(kernel0)) + print("----------") + print(fgen(kernel1)) + print("----------") + print(fgen(kernel2)) + print("----------") + print(fgen(kernel3)) + print("----------") + # print(fgen(moduleA)) + print("----------") + # print(fgen(moduleB)) + print("----------") + # print(fgen(moduleC)) diff --git a/transformations/transformations/data_offload.py b/transformations/transformations/data_offload.py index 1d113abdc..152652ae0 100644 --- a/transformations/transformations/data_offload.py +++ b/transformations/transformations/data_offload.py @@ -17,7 +17,7 @@ __all__ = [ 'DataOffloadTransformation', 'GlobalVariableAnalysis', - 'GlobalVarOffloadTransformation' + 'GlobalVarOffloadTransformation', 'GlobalVarHoistTransformation' ] @@ -210,7 +210,9 @@ class GlobalVariableAnalysis(Transformation): For procedures, use the the Loki dataflow analysis functionality to compile a list of used and/or defined variables (i.e., read and/or written). - Store these under the keys ``'uses_symbols'`` and ``'defines_symbols'``, respectively. + Store these under the keys ``'uses_symbols'`` and ``'defines_symbols'``, + respectively (and additionally ``'uses_parameters'`` for used symbols being + compile time constants). For modules/:any:`GlobalVarImportItem`, store the list of variables declared in the module under the key ``'declares'`` and out of this the subset of variables that @@ -228,6 +230,7 @@ class GlobalVariableAnalysis(Transformation): SubroutineItem: { 'uses_symbols': set( (Variable, ''), (Variable, ''), ...), + 'uses_parameters': set( (Variable, ''), (Variable, ''), ...), 'defines_symbols': set((Variable, ''), (Variable, ''), ...) } @@ -301,7 +304,9 @@ def transform_subroutine(self, routine, **kwargs): defines_imported_symbols = {var for var in defines_imported_symbols if isinstance(var, (Scalar, Array))} # Discard parameters (which are read-only by definition) - uses_imported_symbols = {var for var in uses_imported_symbols if not var.type.parameter} + # uses_imported_symbols = {var for var in uses_imported_symbols if not var.type.parameter} + uses_imported_parameters = {var for var in uses_imported_symbols if var.type.parameter} + uses_imported_symbols ^= uses_imported_parameters # TODO: does it matter, wether ^= or -= ??? def _map_var_to_module(var): if var.parent: @@ -323,6 +328,9 @@ def _map_var_to_module(var): item.trafo_data[self._key]['uses_symbols'] = { _map_var_to_module(var) for var in uses_imported_symbols } + item.trafo_data[self._key]['uses_parameters'] = { + _map_var_to_module(var) for var in uses_imported_parameters + } item.trafo_data[self._key]['defines_symbols'] = { _map_var_to_module(var) for var in defines_imported_symbols } @@ -348,6 +356,7 @@ def _map_var_to_module(var): for successor in successors: if isinstance(successor, SubroutineItem): item.trafo_data[self._key]['uses_symbols'] |= successor.trafo_data[self._key]['uses_symbols'] + item.trafo_data[self._key]['uses_parameters'] |= successor.trafo_data[self._key]['uses_parameters'] item.trafo_data[self._key]['defines_symbols'] |= successor.trafo_data[self._key]['defines_symbols'] @@ -613,3 +622,164 @@ def process_driver(self, routine, successors): '![Loki::GlobalVarOffloadTransformation] ' '-------- Added global variable imports for offload directives -----------' ))) + + +class GlobalVarHoistTransformation(Transformation): + """ + TODO: DOCSTRING ... + """ + # Include module variable imports in the underlying graph + # connectivity for traversal with the Scheduler + item_filter = (SubroutineItem,) # GlobalVarImportItem) + + def __init__(self, hoist_parameters=True, key=None): + self._key = key or GlobalVariableAnalysis._key + self.hoist_parameters = hoist_parameters + + def transform_subroutine(self, routine, **kwargs): + """ + Add data offload and pull-back directives to the driver + """ + role = kwargs.get('role') + successors = kwargs.get('successors', ()) + item = kwargs.get('item', None) + + if role == 'driver': + self.process_driver(routine, successors) + elif role == 'kernel': + self.process_kernel(routine, successors, item) + + def process_driver(self, routine, successors): + defines_symbols, uses_symbols = self._get_symbols(routine, successors) + ## collect disregarding module + all_defines_symbols = set() + all_uses_symbols = set() + for _, value in defines_symbols.items(): + all_defines_symbols |= value + for _, value in uses_symbols.items(): + all_uses_symbols |= value + ## + # DONE: add to calls + offload_map = defaultdict(set) + # all_symbols = uses_symbols[key]|defines_symbols[key] + for key, _ in uses_symbols.items(): + all_symbols = uses_symbols[key]|defines_symbols[key] + for var, module in all_symbols: #sorted(all_symbols, key=lambda symbol: symbol[0].name): # uses_symbols[key]|defines_symbols[key]: + offload_map[key].add(var.parents[0] if var.parent else var) + call_map = {} + calls = FindNodes(CallStatement).visit(routine.body) + for call in calls: + if call.routine.name in uses_symbols: # in chain(uses_symbols[call.routine.name], defines_symbols[call.routine.name]): + arguments = call.arguments + call_map[call] = call.clone(arguments=arguments + tuple(sorted([var.clone() for var in offload_map[call.routine.name]], key=lambda symbol: symbol.name))) # uses_symbols[call.routine.name])) + routine.body = Transformer(call_map).visit(routine.body) + # DONE: append/add/adapt imports + # Add imports for offload variables + offload_map = defaultdict(set) + for var, module in chain(all_uses_symbols, all_defines_symbols): + offload_map[module].add(var.parents[0] if var.parent else var) + + import_map = CaseInsensitiveDict() + scope = routine + while scope: + import_map.update(scope.import_map) + scope = scope.parent + + missing_imports_map = defaultdict(set) + for module, variables in offload_map.items(): + missing_imports_map[module] |= {var for var in variables if var.name not in import_map} + + if missing_imports_map: + routine.spec.prepend(Comment(text=( + '![Loki::GlobalVarHoistTransformation] ---------------------------------------' + ))) + for module, variables in missing_imports_map.items(): + symbols = tuple(var.clone(dimensions=None, scope=routine) for var in variables) + routine.spec.prepend(Import(module=module, symbols=symbols)) + + routine.spec.prepend(Comment(text=( + '![Loki::GlobalVarHoistTransformation] ' + '-------- Added global variable imports for offload directives -----------' + ))) + ## + + def process_kernel(self, routine, successors, item): + defines_symbols, uses_symbols = self._get_symbols(routine, successors) + # DONE: add to routine arguments + # TODO: intent('out') is evil, just use 'intent('inout') like currently implemented?! + all_defines_symbols = item.trafo_data.get(self._key, {}).get('defines_symbols', set()) # set() + all_defines_vars = [var.parents[0] if var.parent else var for var, _ in all_defines_symbols] + all_uses_symbols = item.trafo_data.get(self._key, {}).get('uses_symbols', set()) # set() + all_uses_vars = [var.parent[0] if var.parent else var for var, _ in all_uses_symbols] + all_symbols = all_uses_symbols|all_defines_symbols + new_arguments = [] + for var, _ in all_symbols: + new_arguments.append(var.parents[0] if var.parent else var) + + new_arguments = [arg.clone(type=arg.type.clone(intent='inout' if arg in all_defines_vars else 'in', parameter=False, initial=None)) for arg in new_arguments] + routine.arguments += tuple(sorted(new_arguments, key=lambda symbol: symbol.name)) + ### + # DONE: add to calls + offload_map = defaultdict(set) + # all_symbols = uses_symbols[key]|defines_symbols[key] + for key, _ in uses_symbols.items(): + all_symbols = uses_symbols[key]|defines_symbols[key] + for var, module in all_symbols: #sorted(all_symbols, key=lambda symbol: symbol[0].name): # uses_symbols[key]|defines_symbols[key]: + offload_map[key].add(var.parents[0] if var.parent else var) + call_map = {} + calls = FindNodes(CallStatement).visit(routine.body) + for call in calls: + if call.routine.name in uses_symbols: # in chain(uses_symbols[call.routine.name], defines_symbols[call.routine.name]): + arguments = call.arguments + call_map[call] = call.clone(arguments=arguments + tuple(sorted([var.clone() for var in offload_map[call.routine.name]], key=lambda symbol: symbol.name))) # uses_symbols[call.routine.name])) + routine.body = Transformer(call_map).visit(routine.body) + ### + # DONE: remove/adapt imports + # Add imports for offload variables + all_defines_symbols = item.trafo_data.get(self._key, {}).get('defines_symbols', set()) # set() + all_uses_symbols = item.trafo_data.get(self._key, {}).get('uses_symbols', set()) # set() + + offload_map = defaultdict(set) + for var, module in chain(all_uses_symbols, all_defines_symbols): + offload_map[module].add(var.parents[0] if var.parent else var) + + import_map = CaseInsensitiveDict() + scope = routine + while scope: + import_map.update(scope.import_map) + scope = scope.parent + + redundant_imports_map = defaultdict(set) + for module, variables in offload_map.items(): + redundant_imports_map[module] |= {var for var in variables if var.name in import_map} + + import_map = {} + imports = FindNodes(Import).visit(routine.spec) + for _import in imports: + new_symbols = tuple(var.clone(dimensions=None, scope=routine) + for var in set(_import.symbols)^redundant_imports_map[_import.module.lower()]) + if new_symbols: + import_map[_import] = _import.clone(symbols=new_symbols) + else: + import_map[_import] = None + routine.spec = Transformer(import_map).visit(routine.spec) + ## + + def _get_symbols(self, routine, successors): + # Combine analysis data across successor items + defines_symbols = {} # set() + uses_symbols = {} # set() + for item in successors: + if isinstance(item, GlobalVarImportItem): + continue + defines_symbols[item.routine.name] = set() + uses_symbols[item.routine.name] = set() + # defines_symbols |= item.trafo_data.get(self._key, {}).get('defines_symbols', set()) + defines_symbols[item.routine.name] = item.trafo_data.get(self._key, {}).get('defines_symbols', set()) + # uses_symbols |= item.trafo_data.get(self._key, {}).get('uses_symbols', set()) + uses_symbols[item.routine.name] = item.trafo_data.get(self._key, {}).get('uses_symbols', set()) + if self.hoist_parameters: + # uses_symbols |= item.trafo_data.get(self._key, {}).get('uses_parameters', set()) + uses_symbols[item.routine.name] |= item.trafo_data.get(self._key, {}).get('uses_parameters', set()) + return defines_symbols, uses_symbols + From a56fc3d5b9e1f28662c6e4b67576f094d2712702 Mon Sep 17 00:00:00 2001 From: Michael Staneker Date: Thu, 1 Feb 2024 08:31:41 +0000 Subject: [PATCH 02/10] Continued implementation of a 'GlobalVarHoistTransformation' trafo --- transformations/tests/test_data_offload.py | 30 +++++++- .../transformations/data_offload.py | 69 +++++++++---------- 2 files changed, 59 insertions(+), 40 deletions(-) diff --git a/transformations/tests/test_data_offload.py b/transformations/tests/test_data_offload.py index a4e80e40e..b2b37574c 100644 --- a/transformations/tests/test_data_offload.py +++ b/transformations/tests/test_data_offload.py @@ -642,7 +642,8 @@ def test_transformation_global_var_import_derived_type(here, config, frontend): @pytest.mark.parametrize('frontend', available_frontends()) -def test_transformation_global_var_hoist(here, config, frontend): +@pytest.mark.parametrize('hoist_parameters', (False, True)) +def test_transformation_global_var_hoist(here, config, frontend, hoist_parameters): """ Test the generation of offload instructions of global variable imports. """ @@ -653,7 +654,7 @@ def test_transformation_global_var_hoist(here, config, frontend): scheduler = Scheduler(paths=here/'sources/projGlobalVarImports', config=config, frontend=frontend) scheduler.process(transformation=GlobalVariableAnalysis()) - scheduler.process(transformation=GlobalVarHoistTransformation(hoist_parameters=False)) + scheduler.process(transformation=GlobalVarHoistTransformation(hoist_parameters=hoist_parameters)) print("") print("") @@ -682,3 +683,28 @@ def test_transformation_global_var_hoist(here, config, frontend): # print(fgen(moduleB)) print("----------") # print(fgen(moduleC)) + + +@pytest.mark.parametrize('frontend', available_frontends()) +@pytest.mark.parametrize('hoist_parameters', (False, True)) +def test_transformation_global_var_derived_type_hoist(here, config, frontend, hoist_parameters): + """ + Test the generation of offload instructions of derived-type global variable imports. + """ + + config['default']['enable_imports'] = True + config['routines'] = { + 'driver_derived_type': {'role': 'driver'} + } + + scheduler = Scheduler(paths=here/'sources/projGlobalVarImports', config=config, frontend=frontend) + scheduler.process(transformation=GlobalVariableAnalysis()) + scheduler.process(transformation=GlobalVarHoistTransformation(hoist_parameters)) + + driver = scheduler['#driver_derived_type'].routine + kernel = scheduler['#kernel_derived_type'].routine + module = scheduler['module_derived_type#p'].scope + + print(fgen(driver)) + print("--------------------") + print(fgen(kernel)) diff --git a/transformations/transformations/data_offload.py b/transformations/transformations/data_offload.py index 152652ae0..a3a3db843 100644 --- a/transformations/transformations/data_offload.py +++ b/transformations/transformations/data_offload.py @@ -659,20 +659,7 @@ def process_driver(self, routine, successors): for _, value in uses_symbols.items(): all_uses_symbols |= value ## - # DONE: add to calls - offload_map = defaultdict(set) - # all_symbols = uses_symbols[key]|defines_symbols[key] - for key, _ in uses_symbols.items(): - all_symbols = uses_symbols[key]|defines_symbols[key] - for var, module in all_symbols: #sorted(all_symbols, key=lambda symbol: symbol[0].name): # uses_symbols[key]|defines_symbols[key]: - offload_map[key].add(var.parents[0] if var.parent else var) - call_map = {} - calls = FindNodes(CallStatement).visit(routine.body) - for call in calls: - if call.routine.name in uses_symbols: # in chain(uses_symbols[call.routine.name], defines_symbols[call.routine.name]): - arguments = call.arguments - call_map[call] = call.clone(arguments=arguments + tuple(sorted([var.clone() for var in offload_map[call.routine.name]], key=lambda symbol: symbol.name))) # uses_symbols[call.routine.name])) - routine.body = Transformer(call_map).visit(routine.body) + self._append_call_arguments(routine, uses_symbols, defines_symbols) # DONE: append/add/adapt imports # Add imports for offload variables offload_map = defaultdict(set) @@ -707,32 +694,10 @@ def process_kernel(self, routine, successors, item): defines_symbols, uses_symbols = self._get_symbols(routine, successors) # DONE: add to routine arguments # TODO: intent('out') is evil, just use 'intent('inout') like currently implemented?! - all_defines_symbols = item.trafo_data.get(self._key, {}).get('defines_symbols', set()) # set() - all_defines_vars = [var.parents[0] if var.parent else var for var, _ in all_defines_symbols] - all_uses_symbols = item.trafo_data.get(self._key, {}).get('uses_symbols', set()) # set() - all_uses_vars = [var.parent[0] if var.parent else var for var, _ in all_uses_symbols] - all_symbols = all_uses_symbols|all_defines_symbols - new_arguments = [] - for var, _ in all_symbols: - new_arguments.append(var.parents[0] if var.parent else var) - - new_arguments = [arg.clone(type=arg.type.clone(intent='inout' if arg in all_defines_vars else 'in', parameter=False, initial=None)) for arg in new_arguments] - routine.arguments += tuple(sorted(new_arguments, key=lambda symbol: symbol.name)) + self._append_routine_arguments(routine, item) ### # DONE: add to calls - offload_map = defaultdict(set) - # all_symbols = uses_symbols[key]|defines_symbols[key] - for key, _ in uses_symbols.items(): - all_symbols = uses_symbols[key]|defines_symbols[key] - for var, module in all_symbols: #sorted(all_symbols, key=lambda symbol: symbol[0].name): # uses_symbols[key]|defines_symbols[key]: - offload_map[key].add(var.parents[0] if var.parent else var) - call_map = {} - calls = FindNodes(CallStatement).visit(routine.body) - for call in calls: - if call.routine.name in uses_symbols: # in chain(uses_symbols[call.routine.name], defines_symbols[call.routine.name]): - arguments = call.arguments - call_map[call] = call.clone(arguments=arguments + tuple(sorted([var.clone() for var in offload_map[call.routine.name]], key=lambda symbol: symbol.name))) # uses_symbols[call.routine.name])) - routine.body = Transformer(call_map).visit(routine.body) + self._append_call_arguments(routine, uses_symbols, defines_symbols) ### # DONE: remove/adapt imports # Add imports for offload variables @@ -783,3 +748,31 @@ def _get_symbols(self, routine, successors): uses_symbols[item.routine.name] |= item.trafo_data.get(self._key, {}).get('uses_parameters', set()) return defines_symbols, uses_symbols + def _append_call_arguments(self, routine, uses_symbols, defines_symbols): + # DONE: add to calls + offload_map = defaultdict(set) + # all_symbols = uses_symbols[key]|defines_symbols[key] + for key, _ in uses_symbols.items(): + all_symbols = uses_symbols[key]|defines_symbols[key] + for var, module in all_symbols: #sorted(all_symbols, key=lambda symbol: symbol[0].name): # uses_symbols[key]|defines_symbols[key]: + offload_map[key].add(var.parents[0] if var.parent else var) + call_map = {} + calls = FindNodes(CallStatement).visit(routine.body) + for call in calls: + if call.routine.name in uses_symbols: # in chain(uses_symbols[call.routine.name], defines_symbols[call.routine.name]): + arguments = call.arguments + call_map[call] = call.clone(arguments=arguments + tuple(sorted([var.clone(dimensions=None) for var in offload_map[call.routine.name]], key=lambda symbol: symbol.name))) # uses_symbols[call.routine.name])) + routine.body = Transformer(call_map).visit(routine.body) + + def _append_routine_arguments(self, routine, item): + all_defines_symbols = item.trafo_data.get(self._key, {}).get('defines_symbols', set()) # set() + all_defines_vars = [var.parents[0] if var.parent else var for var, _ in all_defines_symbols] + all_uses_symbols = item.trafo_data.get(self._key, {}).get('uses_symbols', set()) # set() + all_uses_vars = [var.parent[0] if var.parent else var for var, _ in all_uses_symbols] + all_symbols = all_uses_symbols|all_defines_symbols + new_arguments = [] + for var, _ in all_symbols: + new_arguments.append(var.parents[0] if var.parent else var) + new_arguments = set(new_arguments) # remove duplicates + new_arguments = [arg.clone(type=arg.type.clone(intent='inout' if arg in all_defines_vars else 'in', parameter=False, initial=None)) for arg in new_arguments] + routine.arguments += tuple(sorted(new_arguments, key=lambda symbol: symbol.name)) From 553f6f18890238496fe852f4860de79dab787f74 Mon Sep 17 00:00:00 2001 From: Michael Staneker Date: Thu, 1 Feb 2024 10:22:19 +0000 Subject: [PATCH 03/10] Allow ignoring modules in 'GlobalVarHoistTransformation' --- transformations/tests/test_data_offload.py | 2 +- .../transformations/data_offload.py | 31 +++++++++++++++++-- 2 files changed, 30 insertions(+), 3 deletions(-) diff --git a/transformations/tests/test_data_offload.py b/transformations/tests/test_data_offload.py index b2b37574c..61d6ccd54 100644 --- a/transformations/tests/test_data_offload.py +++ b/transformations/tests/test_data_offload.py @@ -654,7 +654,7 @@ def test_transformation_global_var_hoist(here, config, frontend, hoist_parameter scheduler = Scheduler(paths=here/'sources/projGlobalVarImports', config=config, frontend=frontend) scheduler.process(transformation=GlobalVariableAnalysis()) - scheduler.process(transformation=GlobalVarHoistTransformation(hoist_parameters=hoist_parameters)) + scheduler.process(transformation=GlobalVarHoistTransformation(hoist_parameters=hoist_parameters, ignore_modules=('modulec',))) print("") print("") diff --git a/transformations/transformations/data_offload.py b/transformations/transformations/data_offload.py index a3a3db843..7a8b383ba 100644 --- a/transformations/transformations/data_offload.py +++ b/transformations/transformations/data_offload.py @@ -632,9 +632,13 @@ class GlobalVarHoistTransformation(Transformation): # connectivity for traversal with the Scheduler item_filter = (SubroutineItem,) # GlobalVarImportItem) - def __init__(self, hoist_parameters=True, key=None): + def __init__(self, hoist_parameters=True, ignore_modules=None, key=None): self._key = key or GlobalVariableAnalysis._key self.hoist_parameters = hoist_parameters + if ignore_modules is None: + self.ignore_modules = () + else: + self.ignore_modules = [module.lower() for module in ignore_modules] def transform_subroutine(self, routine, **kwargs): """ @@ -664,6 +668,10 @@ def process_driver(self, routine, successors): # Add imports for offload variables offload_map = defaultdict(set) for var, module in chain(all_uses_symbols, all_defines_symbols): + print(f"{module.lower()} vs. {self.ignore_modules}") + if module.lower() in self.ignore_modules: + print(f"ignoring module: {module.lower()}") + continue offload_map[module].add(var.parents[0] if var.parent else var) import_map = CaseInsensitiveDict() @@ -674,6 +682,9 @@ def process_driver(self, routine, successors): missing_imports_map = defaultdict(set) for module, variables in offload_map.items(): + if module.lower() in self.ignore_modules: + print(f"ignoring module: {module.lower()}") + continue missing_imports_map[module] |= {var for var in variables if var.name not in import_map} if missing_imports_map: @@ -681,6 +692,9 @@ def process_driver(self, routine, successors): '![Loki::GlobalVarHoistTransformation] ---------------------------------------' ))) for module, variables in missing_imports_map.items(): + if module.lower() in self.ignore_modules: + print(f"ignoring module: {module.lower()}") + continue symbols = tuple(var.clone(dimensions=None, scope=routine) for var in variables) routine.spec.prepend(Import(module=module, symbols=symbols)) @@ -706,6 +720,9 @@ def process_kernel(self, routine, successors, item): offload_map = defaultdict(set) for var, module in chain(all_uses_symbols, all_defines_symbols): + if module.lower() in self.ignore_modules: + print(f"ignoring module: {module.lower()}") + continue offload_map[module].add(var.parents[0] if var.parent else var) import_map = CaseInsensitiveDict() @@ -716,6 +733,10 @@ def process_kernel(self, routine, successors, item): redundant_imports_map = defaultdict(set) for module, variables in offload_map.items(): + # probably not necessary ... + #if module.lower() in self.ignore_modules: + # print(f"ignoring module: {module.lower()}") + # continue redundant_imports_map[module] |= {var for var in variables if var.name in import_map} import_map = {} @@ -755,6 +776,9 @@ def _append_call_arguments(self, routine, uses_symbols, defines_symbols): for key, _ in uses_symbols.items(): all_symbols = uses_symbols[key]|defines_symbols[key] for var, module in all_symbols: #sorted(all_symbols, key=lambda symbol: symbol[0].name): # uses_symbols[key]|defines_symbols[key]: + if module.lower() in self.ignore_modules: + print(f"ignoring module: {module.lower()}") + continue offload_map[key].add(var.parents[0] if var.parent else var) call_map = {} calls = FindNodes(CallStatement).visit(routine.body) @@ -771,7 +795,10 @@ def _append_routine_arguments(self, routine, item): all_uses_vars = [var.parent[0] if var.parent else var for var, _ in all_uses_symbols] all_symbols = all_uses_symbols|all_defines_symbols new_arguments = [] - for var, _ in all_symbols: + for var, module in all_symbols: + if module.lower() in self.ignore_modules: + print(f"ignoring module: {module.lower()}") + continue new_arguments.append(var.parents[0] if var.parent else var) new_arguments = set(new_arguments) # remove duplicates new_arguments = [arg.clone(type=arg.type.clone(intent='inout' if arg in all_defines_vars else 'in', parameter=False, initial=None)) for arg in new_arguments] From 747228ffa561c0df64547a83b73d0101fddc48d5 Mon Sep 17 00:00:00 2001 From: Michael Staneker Date: Thu, 1 Feb 2024 12:48:12 +0000 Subject: [PATCH 04/10] Don't discard parameters in 'GlobalVariableAnalysis' anymore, instead discard them in the transformation pass instead (if wanted/necessary) --- transformations/tests/test_data_offload.py | 18 ++++++----- .../transformations/data_offload.py | 30 +++++++++---------- 2 files changed, 25 insertions(+), 23 deletions(-) diff --git a/transformations/tests/test_data_offload.py b/transformations/tests/test_data_offload.py index 61d6ccd54..6fa7c41c0 100644 --- a/transformations/tests/test_data_offload.py +++ b/transformations/tests/test_data_offload.py @@ -386,11 +386,11 @@ def test_global_variable_analysis(frontend, key, config, global_variable_analysi expected_trafo_data = { 'global_var_analysis_header_mod#nval': { 'declares': {f'iarr({nfld_dim})', f'rarr({nval_dim}, {nfld_dim})'}, - 'offload': set() + 'offload': {'nval'} }, 'global_var_analysis_header_mod#nfld': { 'declares': {f'iarr({nfld_dim})', f'rarr({nval_dim}, {nfld_dim})'}, - 'offload': set() + 'offload': ['nfld'] }, 'global_var_analysis_header_mod#iarr': { 'declares': {f'iarr({nfld_dim})', f'rarr({nval_dim}, {nfld_dim})'}, @@ -412,6 +412,8 @@ def test_global_variable_analysis(frontend, key, config, global_variable_analysi 'global_var_analysis_kernel_mod#kernel_a': { 'defines_symbols': set(), 'uses_symbols': { + ('iarr(nfld)', 'global_var_analysis_header_mod'), ('nfld', 'global_var_analysis_header_mod'), + ('nval', 'global_var_analysis_header_mod'), (f'iarr({nfld_dim})', 'global_var_analysis_header_mod'), (f'rarr({nval_dim}, {nfld_dim})', 'global_var_analysis_header_mod') } @@ -419,6 +421,7 @@ def test_global_variable_analysis(frontend, key, config, global_variable_analysi 'global_var_analysis_kernel_mod#kernel_b': { 'defines_symbols': {('rdata(:, :, :)', 'global_var_analysis_data_mod')}, 'uses_symbols': { + ('iarr(nfld)', 'global_var_analysis_header_mod'), ('nfld', 'global_var_analysis_header_mod'), ('rdata(:, :, :)', 'global_var_analysis_data_mod'), ('tt', 'global_var_analysis_data_mod'), ('tt%vals', 'global_var_analysis_data_mod'), (f'iarr({nfld_dim})', 'global_var_analysis_header_mod') } @@ -426,8 +429,10 @@ def test_global_variable_analysis(frontend, key, config, global_variable_analysi '#driver': { 'defines_symbols': {('rdata(:, :, :)', 'global_var_analysis_data_mod')}, 'uses_symbols': { - ('rdata(:, :, :)', 'global_var_analysis_data_mod'), ('tt', 'global_var_analysis_data_mod'), - ('tt%vals', 'global_var_analysis_data_mod'), (f'iarr({nfld_dim})', 'global_var_analysis_header_mod'), + ('iarr(nfld)', 'global_var_analysis_header_mod'), ('nfld', 'global_var_analysis_header_mod'), + ('nval', 'global_var_analysis_header_mod'), ('rdata(:, :, :)', 'global_var_analysis_data_mod'), + ('tt', 'global_var_analysis_data_mod'), ('tt%vals', 'global_var_analysis_data_mod'), + (f'iarr({nfld_dim})', 'global_var_analysis_header_mod'), (f'rarr({nval_dim}, {nfld_dim})', 'global_var_analysis_header_mod') } } @@ -438,8 +443,7 @@ def test_global_variable_analysis(frontend, key, config, global_variable_analysi if item == 'global_var_analysis_data_mod#some_type': continue for trafo_data_key, trafo_data_value in item.trafo_data[key].items(): - if trafo_data_key not in ('defines_symbols', 'uses_symbols'): - continue + print(f"item: {item} | trafo_data_key: {trafo_data_key}") assert ( sorted( tuple(str(vv) for vv in v) if isinstance(v, tuple) else str(v) @@ -654,7 +658,7 @@ def test_transformation_global_var_hoist(here, config, frontend, hoist_parameter scheduler = Scheduler(paths=here/'sources/projGlobalVarImports', config=config, frontend=frontend) scheduler.process(transformation=GlobalVariableAnalysis()) - scheduler.process(transformation=GlobalVarHoistTransformation(hoist_parameters=hoist_parameters, ignore_modules=('modulec',))) + scheduler.process(transformation=GlobalVarHoistTransformation(hoist_parameters=hoist_parameters)) # , ignore_modules=('modulec',))) print("") print("") diff --git a/transformations/transformations/data_offload.py b/transformations/transformations/data_offload.py index 7a8b383ba..83de1a43b 100644 --- a/transformations/transformations/data_offload.py +++ b/transformations/transformations/data_offload.py @@ -211,8 +211,7 @@ class GlobalVariableAnalysis(Transformation): For procedures, use the the Loki dataflow analysis functionality to compile a list of used and/or defined variables (i.e., read and/or written). Store these under the keys ``'uses_symbols'`` and ``'defines_symbols'``, - respectively (and additionally ``'uses_parameters'`` for used symbols being - compile time constants). + respectively. For modules/:any:`GlobalVarImportItem`, store the list of variables declared in the module under the key ``'declares'`` and out of this the subset of variables that @@ -230,7 +229,6 @@ class GlobalVariableAnalysis(Transformation): SubroutineItem: { 'uses_symbols': set( (Variable, ''), (Variable, ''), ...), - 'uses_parameters': set( (Variable, ''), (Variable, ''), ...), 'defines_symbols': set((Variable, ''), (Variable, ''), ...) } @@ -303,11 +301,6 @@ def transform_subroutine(self, routine, **kwargs): uses_imported_symbols = {var for var in uses_imported_symbols if isinstance(var, (Scalar, Array))} defines_imported_symbols = {var for var in defines_imported_symbols if isinstance(var, (Scalar, Array))} - # Discard parameters (which are read-only by definition) - # uses_imported_symbols = {var for var in uses_imported_symbols if not var.type.parameter} - uses_imported_parameters = {var for var in uses_imported_symbols if var.type.parameter} - uses_imported_symbols ^= uses_imported_parameters # TODO: does it matter, wether ^= or -= ??? - def _map_var_to_module(var): if var.parent: module = var.parents[0].type.module @@ -328,9 +321,6 @@ def _map_var_to_module(var): item.trafo_data[self._key]['uses_symbols'] = { _map_var_to_module(var) for var in uses_imported_symbols } - item.trafo_data[self._key]['uses_parameters'] = { - _map_var_to_module(var) for var in uses_imported_parameters - } item.trafo_data[self._key]['defines_symbols'] = { _map_var_to_module(var) for var in defines_imported_symbols } @@ -356,7 +346,6 @@ def _map_var_to_module(var): for successor in successors: if isinstance(successor, SubroutineItem): item.trafo_data[self._key]['uses_symbols'] |= successor.trafo_data[self._key]['uses_symbols'] - item.trafo_data[self._key]['uses_parameters'] |= successor.trafo_data[self._key]['uses_parameters'] item.trafo_data[self._key]['defines_symbols'] |= successor.trafo_data[self._key]['defines_symbols'] @@ -474,10 +463,10 @@ def transform_module(self, module, **kwargs): for v in as_tuple(acc_pragma_parameters.get('create')) ])) - # Build list of symbols to be offloaded + # Build list of symbols to be offloaded (discard variables being parameter) offload_variables = { var.parents[0] if var.parent else var - for var in item.trafo_data[self._key].get('offload', ()) + for var in item.trafo_data[self._key].get('offload', ()) if not var.type.parameter } if (invalid_vars := offload_variables - set(module.variables)): @@ -517,6 +506,9 @@ def process_driver(self, routine, successors): for item in successors: defines_symbols |= item.trafo_data.get(self._key, {}).get('defines_symbols', set()) uses_symbols |= item.trafo_data.get(self._key, {}).get('uses_symbols', set()) + # discard variables being parameter + parameters = {(var, module) for var, module in uses_symbols if var.type.parameter} + uses_symbols ^= parameters # Filter out arrays of derived types and nested derived types # For these, automatic offloading is currently not supported @@ -764,9 +756,11 @@ def _get_symbols(self, routine, successors): defines_symbols[item.routine.name] = item.trafo_data.get(self._key, {}).get('defines_symbols', set()) # uses_symbols |= item.trafo_data.get(self._key, {}).get('uses_symbols', set()) uses_symbols[item.routine.name] = item.trafo_data.get(self._key, {}).get('uses_symbols', set()) - if self.hoist_parameters: + if not self.hoist_parameters: + parameters = {(var, module) for var, module in uses_symbols[item.routine.name] if var.type.parameter} + uses_symbols[item.routine.name] ^= parameters # uses_symbols |= item.trafo_data.get(self._key, {}).get('uses_parameters', set()) - uses_symbols[item.routine.name] |= item.trafo_data.get(self._key, {}).get('uses_parameters', set()) + # uses_symbols[item.routine.name] |= item.trafo_data.get(self._key, {}).get('uses_parameters') # , set()) return defines_symbols, uses_symbols def _append_call_arguments(self, routine, uses_symbols, defines_symbols): @@ -792,7 +786,11 @@ def _append_routine_arguments(self, routine, item): all_defines_symbols = item.trafo_data.get(self._key, {}).get('defines_symbols', set()) # set() all_defines_vars = [var.parents[0] if var.parent else var for var, _ in all_defines_symbols] all_uses_symbols = item.trafo_data.get(self._key, {}).get('uses_symbols', set()) # set() + parameters = {(var, module) for var, module in all_uses_symbols if var.type.parameter} + if not self.hoist_parameters: + all_uses_symbols ^= parameters all_uses_vars = [var.parent[0] if var.parent else var for var, _ in all_uses_symbols] + all_symbols = all_uses_symbols|all_defines_symbols new_arguments = [] for var, module in all_symbols: From 804245ab793fd414ef5c05111a832122109b65b8 Mon Sep 17 00:00:00 2001 From: Michael Staneker Date: Fri, 2 Feb 2024 11:26:47 +0000 Subject: [PATCH 05/10] Finalized 'GlobalVarHoistTransformation' including tests --- transformations/tests/test_data_offload.py | 133 +++++++--- .../transformations/data_offload.py | 241 ++++++++++++------ 2 files changed, 262 insertions(+), 112 deletions(-) diff --git a/transformations/tests/test_data_offload.py b/transformations/tests/test_data_offload.py index 6fa7c41c0..08a891555 100644 --- a/transformations/tests/test_data_offload.py +++ b/transformations/tests/test_data_offload.py @@ -12,11 +12,11 @@ from loki import ( Sourcefile, FindNodes, Pragma, PragmaRegion, Loop, CallStatement, pragma_regions_attached, get_pragma_parameters, - gettempdir, Scheduler, OMNI, Import, fgen + gettempdir, Scheduler, OMNI, Import ) from conftest import available_frontends from transformations import ( - DataOffloadTransformation, GlobalVariableAnalysis, + DataOffloadTransformation, GlobalVariableAnalysis, GlobalVarOffloadTransformation, GlobalVarHoistTransformation ) @@ -647,9 +647,10 @@ def test_transformation_global_var_import_derived_type(here, config, frontend): @pytest.mark.parametrize('frontend', available_frontends()) @pytest.mark.parametrize('hoist_parameters', (False, True)) -def test_transformation_global_var_hoist(here, config, frontend, hoist_parameters): +@pytest.mark.parametrize('ignore_modules', (None, ('moduleb',))) +def test_transformation_global_var_hoist(here, config, frontend, hoist_parameters, ignore_modules): """ - Test the generation of offload instructions of global variable imports. + Test hoisting of global variable imports. """ config['default']['enable_imports'] = True config['routines'] = { @@ -658,42 +659,85 @@ def test_transformation_global_var_hoist(here, config, frontend, hoist_parameter scheduler = Scheduler(paths=here/'sources/projGlobalVarImports', config=config, frontend=frontend) scheduler.process(transformation=GlobalVariableAnalysis()) - scheduler.process(transformation=GlobalVarHoistTransformation(hoist_parameters=hoist_parameters)) # , ignore_modules=('modulec',))) - - print("") - print("") + scheduler.process(transformation=GlobalVarHoistTransformation(hoist_parameters=hoist_parameters, + ignore_modules=ignore_modules)) driver = scheduler['#driver'].routine kernel0 = scheduler['#kernel0'].routine - kernel1 = scheduler['#kernel1'].routine - kernel2 = scheduler['#kernel2'].routine - kernel3 = scheduler['#kernel3'].routine - moduleA = scheduler['modulea#var0'].scope - moduleB = scheduler['moduleb#var2'].scope - moduleC = scheduler['modulec#var4'].scope - - print(fgen(driver)) - print("----------") - print(fgen(kernel0)) - print("----------") - print(fgen(kernel1)) - print("----------") - print(fgen(kernel2)) - print("----------") - print(fgen(kernel3)) - print("----------") - # print(fgen(moduleA)) - print("----------") - # print(fgen(moduleB)) - print("----------") - # print(fgen(moduleC)) + kernel_map = {key: scheduler[f'#{key}'].routine for key in ['kernel1', 'kernel2', 'kernel3']} + + # symbols within each module + expected_symbols = {'modulea': ['var0', 'var1'], 'moduleb': ['var2', 'var3'], + 'modulec': ['var4', 'var5']} + # expected intent of those variables (if hoisted) + var_intent_map = {'var0': 'in', 'var1': 'in', 'var2': 'in', + 'var3': 'in', 'var4': 'inout', 'var5': 'inout', 'tmp': None} + # DRIVER + imports = FindNodes(Import).visit(driver.spec) + import_names = [_import.module.lower() for _import in imports] + # check driver imports + expected_driver_modules = ['modulec'] + expected_driver_modules += ['moduleb'] if ignore_modules is None else [] + if frontend != OMNI: + expected_driver_modules += ['modulea'] if hoist_parameters else [] + assert len(imports) == len(expected_driver_modules) + assert sorted(expected_driver_modules) == sorted(import_names) + for _import in imports: + assert sorted([sym.name for sym in _import.symbols]) == expected_symbols[_import.module.lower()] + # check driver call + driver_calls = FindNodes(CallStatement).visit(driver.body) + expected_args = [] + for module in expected_driver_modules: + expected_args.extend(expected_symbols[module]) + assert [arg.name for arg in driver_calls[0].arguments] == sorted(expected_args) + + originally = {'kernel1': ['modulea'], 'kernel2': ['moduleb'], + 'kernel3': ['moduleb', 'modulec']} + # KERNEL0 + assert [arg.name for arg in kernel0.arguments] == sorted(expected_args) + assert [arg.name for arg in kernel0.variables] == sorted(expected_args) + for var in kernel0.variables: + assert kernel0.variable_map[var.name.lower()].type.intent == var_intent_map[var.name.lower()] + kernel0_calls = FindNodes(CallStatement).visit(kernel0.body) + # KERNEL1 & KERNEL2 & KERNEL3 + for call in kernel0_calls: + expected_args = [] + expected_imports = [] + kernel_expected_symbols = [] + for module in originally[call.routine.name]: + # always, since at least 'some_func' is imported + if call.routine.name == 'kernel1' and module == 'modulea': + expected_imports.append(module) + kernel_expected_symbols.append('some_func') + if module in expected_driver_modules: + expected_args.extend(expected_symbols[module]) + else: + # already added + if module != 'modulea': + expected_imports.append(module) + kernel_expected_symbols.extend(expected_symbols[module]) + assert len(expected_args) == len(call.arguments) + assert [arg.name for arg in call.arguments] == expected_args + assert [arg.name for arg in kernel_map[call.routine.name].arguments] == expected_args + for var in kernel_map[call.routine.name].variables: + var_intent = kernel_map[call.routine.name].variable_map[var.name.lower()].type.intent + assert var_intent == var_intent_map[var.name.lower()] + if call.routine.name in ['kernel1', 'kernel2']: + expected_args = ['tmp'] + expected_args + assert [arg.name for arg in kernel_map[call.routine.name].variables] == expected_args + kernel_imports = FindNodes(Import).visit(call.routine.spec) + assert sorted([_import.module.lower() for _import in kernel_imports]) == sorted(expected_imports) + imported_symbols = [] # _import.symbols for _import in kernel_imports] + for _import in kernel_imports: + imported_symbols.extend([sym.name.lower() for sym in _import.symbols]) + assert sorted(imported_symbols) == sorted(kernel_expected_symbols) @pytest.mark.parametrize('frontend', available_frontends()) @pytest.mark.parametrize('hoist_parameters', (False, True)) def test_transformation_global_var_derived_type_hoist(here, config, frontend, hoist_parameters): """ - Test the generation of offload instructions of derived-type global variable imports. + Test hoisting of derived-type global variable imports. """ config['default']['enable_imports'] = True @@ -707,8 +751,25 @@ def test_transformation_global_var_derived_type_hoist(here, config, frontend, ho driver = scheduler['#driver_derived_type'].routine kernel = scheduler['#kernel_derived_type'].routine - module = scheduler['module_derived_type#p'].scope - - print(fgen(driver)) - print("--------------------") - print(fgen(kernel)) + + # DRIVER + imports = FindNodes(Import).visit(driver.spec) + assert len(imports) == 1 + assert imports[0].module.lower() == 'module_derived_type' + assert sorted([sym.name.lower() for sym in imports[0].symbols]) == sorted(['p', 'p_array', 'p0']) + calls = FindNodes(CallStatement).visit(driver.body) + assert len(calls) == 1 + # KERNEL + assert [arg.name for arg in calls[0].arguments] == ['p', 'p0', 'p_array'] + assert [arg.name for arg in kernel.arguments] == ['p', 'p0', 'p_array'] + kernel_imports = FindNodes(Import).visit(kernel.spec) + assert len(kernel_imports) == 1 + assert [sym.name.lower() for sym in kernel_imports[0].symbols] == ['g'] + assert sorted([var.name for var in kernel.variables]) == ['i', 'j', 'p', 'p0', 'p_array'] + assert kernel.variable_map['p_array'].type.allocatable + assert kernel.variable_map['p_array'].type.intent == 'inout' + assert kernel.variable_map['p_array'].type.dtype.name == 'point' + assert kernel.variable_map['p'].type.intent == 'inout' + assert kernel.variable_map['p'].type.dtype.name == 'point' + assert kernel.variable_map['p0'].type.intent == 'in' + assert kernel.variable_map['p0'].type.dtype.name == 'point' diff --git a/transformations/transformations/data_offload.py b/transformations/transformations/data_offload.py index 83de1a43b..42d27c641 100644 --- a/transformations/transformations/data_offload.py +++ b/transformations/transformations/data_offload.py @@ -618,13 +618,98 @@ def process_driver(self, routine, successors): class GlobalVarHoistTransformation(Transformation): """ - TODO: DOCSTRING ... + Transformation to hoist module variables used in device routines + + This requires a prior analysis pass with :any:`GlobalVariableAnalysis` to collect + the relevant global variable use information. + + Modules to be ignored can be specified. Further, it is possible to + configure whether parameters/compile time constants are hoisted as well + or not. + + .. note:: + Hoisted variables that could theoretically be ``intent(out)`` + are despite specified as ``intent(inout)``. + + For example, the following code: + + .. code-block:: fortran + + module moduleB + real :: var2 + real :: var3 + end module moduleB + + module moduleC + real :: var4 + real :: var5 + end module moduleC + + subroutine driver() + implicit none + + call kernel() + + end subroutine driver + + subroutine kernel() + use moduleB, only: var2,var3 + use moduleC, only: var4,var5 + implicit none + + var4 = var2 + var5 = var3 + + end subroutine kernel + + is transformed to: + + .. code-block:: fortran + + module moduleB + real :: var2 + real :: var3 + end module moduleB + + module moduleC + real :: var4 + real :: var5 + end module moduleC + + subroutine driver() + use moduleB, only: var2,var3 + use moduleC, only: var4,var5 + implicit none + + call kernel(var2, var3, var4, var5) + + end subroutine driver + + subroutine kernel(var2, var3, var4, var5) + implicit none + real, intent(in) :: var2 + real, intent(in) :: var3 + real, intent(inout) :: var4 + real, intent(inout) :: var5 + + var4 = var2 + var5 = var3 + + end subroutine kernel + + Parameters + ---------- + hoist_parameters : bool, optional + Whether or not to hoist module variables being parameter/compile + time constants (default: `False`). + ignore_modules : (list, tuple) of str + Modules to be ignored (default: `None`, thus no module to be ignored). + key : str, optional + Overwrite the key that is used to store analysis results in ``trafo_data``. """ - # Include module variable imports in the underlying graph - # connectivity for traversal with the Scheduler - item_filter = (SubroutineItem,) # GlobalVarImportItem) + item_filter = (SubroutineItem,) - def __init__(self, hoist_parameters=True, ignore_modules=None, key=None): + def __init__(self, hoist_parameters=False, ignore_modules=None, key=None): self._key = key or GlobalVariableAnalysis._key self.hoist_parameters = hoist_parameters if ignore_modules is None: @@ -634,7 +719,7 @@ def __init__(self, hoist_parameters=True, ignore_modules=None, key=None): def transform_subroutine(self, routine, **kwargs): """ - Add data offload and pull-back directives to the driver + Hoist module variables. """ role = kwargs.get('role') successors = kwargs.get('successors', ()) @@ -644,49 +729,48 @@ def transform_subroutine(self, routine, **kwargs): self.process_driver(routine, successors) elif role == 'kernel': self.process_kernel(routine, successors, item) - + def process_driver(self, routine, successors): - defines_symbols, uses_symbols = self._get_symbols(routine, successors) - ## collect disregarding module + """ + Hoist module variables for driver routines. + + This includes: appending the corresponding variables + to calls within the driver and adding the relevant + imports. + """ + # get symbols per routine (successors) + defines_symbols, uses_symbols = self._get_symbols(successors) + + # append symbols to calls (arguments) + self._append_call_arguments(routine, uses_symbols, defines_symbols) + + # combine/collect symbols disregarding routine all_defines_symbols = set() all_uses_symbols = set() for _, value in defines_symbols.items(): all_defines_symbols |= value for _, value in uses_symbols.items(): all_uses_symbols |= value - ## - self._append_call_arguments(routine, uses_symbols, defines_symbols) - # DONE: append/add/adapt imports - # Add imports for offload variables - offload_map = defaultdict(set) + # add imports for symbols hoisted + symbol_map = defaultdict(set) for var, module in chain(all_uses_symbols, all_defines_symbols): - print(f"{module.lower()} vs. {self.ignore_modules}") + # filter modules that are supposed to be ignored if module.lower() in self.ignore_modules: - print(f"ignoring module: {module.lower()}") continue - offload_map[module].add(var.parents[0] if var.parent else var) - + symbol_map[module].add(var.parents[0] if var.parent else var) import_map = CaseInsensitiveDict() scope = routine while scope: import_map.update(scope.import_map) scope = scope.parent - missing_imports_map = defaultdict(set) - for module, variables in offload_map.items(): - if module.lower() in self.ignore_modules: - print(f"ignoring module: {module.lower()}") - continue + for module, variables in symbol_map.items(): missing_imports_map[module] |= {var for var in variables if var.name not in import_map} - if missing_imports_map: routine.spec.prepend(Comment(text=( '![Loki::GlobalVarHoistTransformation] ---------------------------------------' ))) for module, variables in missing_imports_map.items(): - if module.lower() in self.ignore_modules: - print(f"ignoring module: {module.lower()}") - continue symbols = tuple(var.clone(dimensions=None, scope=routine) for var in variables) routine.spec.prepend(Import(module=module, symbols=symbols)) @@ -694,110 +778,115 @@ def process_driver(self, routine, successors): '![Loki::GlobalVarHoistTransformation] ' '-------- Added global variable imports for offload directives -----------' ))) - ## def process_kernel(self, routine, successors, item): - defines_symbols, uses_symbols = self._get_symbols(routine, successors) - # DONE: add to routine arguments - # TODO: intent('out') is evil, just use 'intent('inout') like currently implemented?! + """ + Hoist mdule variables for kernel routines. + + This includes: appending the corresponding variables + to the routine arguments as well as to calls within the kernel + and removing the imports that became unused. + """ + # get symbols per routine (successors) + defines_symbols, uses_symbols = self._get_symbols(successors) + + # append symbols to routine (arguments) self._append_routine_arguments(routine, item) - ### - # DONE: add to calls + + # append symbols to calls (arguments) self._append_call_arguments(routine, uses_symbols, defines_symbols) - ### - # DONE: remove/adapt imports - # Add imports for offload variables - all_defines_symbols = item.trafo_data.get(self._key, {}).get('defines_symbols', set()) # set() - all_uses_symbols = item.trafo_data.get(self._key, {}).get('uses_symbols', set()) # set() - offload_map = defaultdict(set) - for var, module in chain(all_uses_symbols, all_defines_symbols): + # get symbols for this routine/kernel + kernel_defines_symbols = item.trafo_data.get(self._key, {}).get('defines_symbols', set()) + kernel_uses_symbols = item.trafo_data.get(self._key, {}).get('uses_symbols', set()) + # remove imports for symbols hoisted + symbol_map = defaultdict(set) + for var, module in chain(kernel_uses_symbols, kernel_defines_symbols): + # filter modules that are supposed to be ignored if module.lower() in self.ignore_modules: - print(f"ignoring module: {module.lower()}") continue - offload_map[module].add(var.parents[0] if var.parent else var) - + symbol_map[module].add(var.parents[0] if var.parent else var) import_map = CaseInsensitiveDict() scope = routine while scope: import_map.update(scope.import_map) scope = scope.parent - redundant_imports_map = defaultdict(set) - for module, variables in offload_map.items(): - # probably not necessary ... - #if module.lower() in self.ignore_modules: - # print(f"ignoring module: {module.lower()}") - # continue - redundant_imports_map[module] |= {var for var in variables if var.name in import_map} - + for module, variables in symbol_map.items(): + redundant = [var.parent[0] if var.parent else var for var in variables] + redundant = {var.clone(dimensions=None) for var in redundant if var.name in import_map} + redundant_imports_map[module] |= redundant import_map = {} imports = FindNodes(Import).visit(routine.spec) for _import in imports: new_symbols = tuple(var.clone(dimensions=None, scope=routine) - for var in set(_import.symbols)^redundant_imports_map[_import.module.lower()]) + for var in set(_import.symbols)-redundant_imports_map[_import.module.lower()]) if new_symbols: import_map[_import] = _import.clone(symbols=new_symbols) else: import_map[_import] = None routine.spec = Transformer(import_map).visit(routine.spec) - ## - def _get_symbols(self, routine, successors): - # Combine analysis data across successor items - defines_symbols = {} # set() - uses_symbols = {} # set() + def _get_symbols(self, successors): + """ + Get module variables/symbols (grouped by routine/successor). + """ + defines_symbols = {} + uses_symbols = {} for item in successors: if isinstance(item, GlobalVarImportItem): continue defines_symbols[item.routine.name] = set() uses_symbols[item.routine.name] = set() - # defines_symbols |= item.trafo_data.get(self._key, {}).get('defines_symbols', set()) defines_symbols[item.routine.name] = item.trafo_data.get(self._key, {}).get('defines_symbols', set()) - # uses_symbols |= item.trafo_data.get(self._key, {}).get('uses_symbols', set()) uses_symbols[item.routine.name] = item.trafo_data.get(self._key, {}).get('uses_symbols', set()) + # remove parameters if hoist_parameters is False if not self.hoist_parameters: parameters = {(var, module) for var, module in uses_symbols[item.routine.name] if var.type.parameter} uses_symbols[item.routine.name] ^= parameters - # uses_symbols |= item.trafo_data.get(self._key, {}).get('uses_parameters', set()) - # uses_symbols[item.routine.name] |= item.trafo_data.get(self._key, {}).get('uses_parameters') # , set()) return defines_symbols, uses_symbols def _append_call_arguments(self, routine, uses_symbols, defines_symbols): - # DONE: add to calls - offload_map = defaultdict(set) - # all_symbols = uses_symbols[key]|defines_symbols[key] + """ + Helper to append variables to the call(s) (arguments). + """ + symbol_map = defaultdict(set) for key, _ in uses_symbols.items(): all_symbols = uses_symbols[key]|defines_symbols[key] - for var, module in all_symbols: #sorted(all_symbols, key=lambda symbol: symbol[0].name): # uses_symbols[key]|defines_symbols[key]: + for var, module in all_symbols: + # filter modules that are supposed to be ignored if module.lower() in self.ignore_modules: - print(f"ignoring module: {module.lower()}") continue - offload_map[key].add(var.parents[0] if var.parent else var) + symbol_map[key].add(var.parents[0] if var.parent else var) call_map = {} calls = FindNodes(CallStatement).visit(routine.body) for call in calls: - if call.routine.name in uses_symbols: # in chain(uses_symbols[call.routine.name], defines_symbols[call.routine.name]): + if call.routine.name in uses_symbols: arguments = call.arguments - call_map[call] = call.clone(arguments=arguments + tuple(sorted([var.clone(dimensions=None) for var in offload_map[call.routine.name]], key=lambda symbol: symbol.name))) # uses_symbols[call.routine.name])) + new_args = sorted([var.clone(dimensions=None) for var in symbol_map[call.routine.name]], + key=lambda symbol: symbol.name) + call_map[call] = call.clone(arguments=arguments + tuple(new_args)) routine.body = Transformer(call_map).visit(routine.body) def _append_routine_arguments(self, routine, item): - all_defines_symbols = item.trafo_data.get(self._key, {}).get('defines_symbols', set()) # set() + """ + Helper to append variables to the routine (arguments). + """ + all_defines_symbols = item.trafo_data.get(self._key, {}).get('defines_symbols', set()) all_defines_vars = [var.parents[0] if var.parent else var for var, _ in all_defines_symbols] - all_uses_symbols = item.trafo_data.get(self._key, {}).get('uses_symbols', set()) # set() - parameters = {(var, module) for var, module in all_uses_symbols if var.type.parameter} + all_uses_symbols = item.trafo_data.get(self._key, {}).get('uses_symbols', set()) + # remove parameters if hoist_parameters is False if not self.hoist_parameters: + parameters = {(var, module) for var, module in all_uses_symbols if var.type.parameter} all_uses_symbols ^= parameters - all_uses_vars = [var.parent[0] if var.parent else var for var, _ in all_uses_symbols] - all_symbols = all_uses_symbols|all_defines_symbols new_arguments = [] for var, module in all_symbols: + # filter modules that are supposed to be ignored if module.lower() in self.ignore_modules: - print(f"ignoring module: {module.lower()}") continue new_arguments.append(var.parents[0] if var.parent else var) new_arguments = set(new_arguments) # remove duplicates - new_arguments = [arg.clone(type=arg.type.clone(intent='inout' if arg in all_defines_vars else 'in', parameter=False, initial=None)) for arg in new_arguments] + new_arguments = [arg.clone(type=arg.type.clone(intent='inout' if arg in all_defines_vars + else 'in', parameter=False, initial=None)) for arg in new_arguments] routine.arguments += tuple(sorted(new_arguments, key=lambda symbol: symbol.name)) From 6192eb1a11db068bcbef355eac70de76fb8dc914 Mon Sep 17 00:00:00 2001 From: Michael Staneker Date: Fri, 2 Feb 2024 13:21:21 +0000 Subject: [PATCH 06/10] fix test for 'GlobalVariableAnalysis' for frontend OMNI --- transformations/tests/test_data_offload.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/transformations/tests/test_data_offload.py b/transformations/tests/test_data_offload.py index 08a891555..fe022102f 100644 --- a/transformations/tests/test_data_offload.py +++ b/transformations/tests/test_data_offload.py @@ -386,11 +386,11 @@ def test_global_variable_analysis(frontend, key, config, global_variable_analysi expected_trafo_data = { 'global_var_analysis_header_mod#nval': { 'declares': {f'iarr({nfld_dim})', f'rarr({nval_dim}, {nfld_dim})'}, - 'offload': {'nval'} + 'offload': set() if frontend == OMNI else {'nval'} }, 'global_var_analysis_header_mod#nfld': { 'declares': {f'iarr({nfld_dim})', f'rarr({nval_dim}, {nfld_dim})'}, - 'offload': ['nfld'] + 'offload': set() if frontend == OMNI else ['nfld'] }, 'global_var_analysis_header_mod#iarr': { 'declares': {f'iarr({nfld_dim})', f'rarr({nval_dim}, {nfld_dim})'}, @@ -412,6 +412,8 @@ def test_global_variable_analysis(frontend, key, config, global_variable_analysi 'global_var_analysis_kernel_mod#kernel_a': { 'defines_symbols': set(), 'uses_symbols': { + ('iarr(1:3)', 'global_var_analysis_header_mod'), ('rarr(1:5, 1:3)', 'global_var_analysis_header_mod') + } if frontend == OMNI else { ('iarr(nfld)', 'global_var_analysis_header_mod'), ('nfld', 'global_var_analysis_header_mod'), ('nval', 'global_var_analysis_header_mod'), (f'iarr({nfld_dim})', 'global_var_analysis_header_mod'), @@ -421,6 +423,9 @@ def test_global_variable_analysis(frontend, key, config, global_variable_analysi 'global_var_analysis_kernel_mod#kernel_b': { 'defines_symbols': {('rdata(:, :, :)', 'global_var_analysis_data_mod')}, 'uses_symbols': { + ('iarr(1:3)', 'global_var_analysis_header_mod'), ('rdata(:, :, :)', 'global_var_analysis_data_mod'), + ('tt', 'global_var_analysis_data_mod'), ('tt%vals', 'global_var_analysis_data_mod') + } if frontend == OMNI else { ('iarr(nfld)', 'global_var_analysis_header_mod'), ('nfld', 'global_var_analysis_header_mod'), ('rdata(:, :, :)', 'global_var_analysis_data_mod'), ('tt', 'global_var_analysis_data_mod'), ('tt%vals', 'global_var_analysis_data_mod'), (f'iarr({nfld_dim})', 'global_var_analysis_header_mod') @@ -429,6 +434,10 @@ def test_global_variable_analysis(frontend, key, config, global_variable_analysi '#driver': { 'defines_symbols': {('rdata(:, :, :)', 'global_var_analysis_data_mod')}, 'uses_symbols': { + ('iarr(1:3)', 'global_var_analysis_header_mod'), ('rarr(1:5, 1:3)', 'global_var_analysis_header_mod'), + ('rdata(:, :, :)', 'global_var_analysis_data_mod'), ('tt', 'global_var_analysis_data_mod'), + ('tt%vals', 'global_var_analysis_data_mod') + } if frontend == OMNI else { ('iarr(nfld)', 'global_var_analysis_header_mod'), ('nfld', 'global_var_analysis_header_mod'), ('nval', 'global_var_analysis_header_mod'), ('rdata(:, :, :)', 'global_var_analysis_data_mod'), ('tt', 'global_var_analysis_data_mod'), ('tt%vals', 'global_var_analysis_data_mod'), @@ -443,7 +452,6 @@ def test_global_variable_analysis(frontend, key, config, global_variable_analysi if item == 'global_var_analysis_data_mod#some_type': continue for trafo_data_key, trafo_data_value in item.trafo_data[key].items(): - print(f"item: {item} | trafo_data_key: {trafo_data_key}") assert ( sorted( tuple(str(vv) for vv in v) if isinstance(v, tuple) else str(v) From ee561641dc6de29bf05f48e4c41206bedaebc0a6 Mon Sep 17 00:00:00 2001 From: Michael Staneker <50531288+MichaelSt98@users.noreply.github.com> Date: Mon, 5 Feb 2024 13:13:56 +0100 Subject: [PATCH 07/10] improve init of `ignore_modules` in 'GlobalVarHoistTransformation' Co-authored-by: Balthasar Reuter <6384870+reuterbal@users.noreply.github.com> --- transformations/transformations/data_offload.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/transformations/transformations/data_offload.py b/transformations/transformations/data_offload.py index 42d27c641..9aec6b5ab 100644 --- a/transformations/transformations/data_offload.py +++ b/transformations/transformations/data_offload.py @@ -712,10 +712,7 @@ class GlobalVarHoistTransformation(Transformation): def __init__(self, hoist_parameters=False, ignore_modules=None, key=None): self._key = key or GlobalVariableAnalysis._key self.hoist_parameters = hoist_parameters - if ignore_modules is None: - self.ignore_modules = () - else: - self.ignore_modules = [module.lower() for module in ignore_modules] + self.ignore_modules = [module.lower() for module in as_tuple(ignore_modules)] def transform_subroutine(self, routine, **kwargs): """ From 6e56895b47e031a35aed28129bf2109f15070d52 Mon Sep 17 00:00:00 2001 From: Michael Staneker Date: Mon, 5 Feb 2024 15:40:14 +0000 Subject: [PATCH 08/10] simplified expression and handling of edge case for GlobalVarHoisting trafo --- transformations/transformations/data_offload.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/transformations/transformations/data_offload.py b/transformations/transformations/data_offload.py index 9aec6b5ab..4f7df8b3a 100644 --- a/transformations/transformations/data_offload.py +++ b/transformations/transformations/data_offload.py @@ -742,12 +742,8 @@ def process_driver(self, routine, successors): self._append_call_arguments(routine, uses_symbols, defines_symbols) # combine/collect symbols disregarding routine - all_defines_symbols = set() - all_uses_symbols = set() - for _, value in defines_symbols.items(): - all_defines_symbols |= value - for _, value in uses_symbols.items(): - all_uses_symbols |= value + all_defines_symbols = set.union(*defines_symbols.values(), set()) + all_uses_symbols = set.union(*uses_symbols.values(), set()) # add imports for symbols hoisted symbol_map = defaultdict(set) for var, module in chain(all_uses_symbols, all_defines_symbols): @@ -803,7 +799,9 @@ def process_kernel(self, routine, successors, item): if module.lower() in self.ignore_modules: continue symbol_map[module].add(var.parents[0] if var.parent else var) - import_map = CaseInsensitiveDict() + import_map = CaseInsensitiveDict( + (s.name, imprt) for imprt in routine.all_imports[::-1] for s in imprt.symbols + ) scope = routine while scope: import_map.update(scope.import_map) From 45db01a92cf38881eaa59a09984f314a60f61581 Mon Sep 17 00:00:00 2001 From: Michael Staneker Date: Mon, 5 Feb 2024 15:41:03 +0000 Subject: [PATCH 09/10] improved readability for tests GlobalVariableAnalyis and GlobalVarHoisting trafo --- transformations/tests/test_data_offload.py | 45 ++++++++++++---------- 1 file changed, 24 insertions(+), 21 deletions(-) diff --git a/transformations/tests/test_data_offload.py b/transformations/tests/test_data_offload.py index fe022102f..425c1ba16 100644 --- a/transformations/tests/test_data_offload.py +++ b/transformations/tests/test_data_offload.py @@ -381,16 +381,31 @@ def test_global_variable_analysis(frontend, key, config, global_variable_analysi key = GlobalVariableAnalysis._key # Validate the analysis trafo_data - nfld_dim = '1:3' if frontend == OMNI else 'nfld' - nval_dim = '1:5' if frontend == OMNI else 'nval' + + # OMNI handles array indices and parameters differently + if frontend == OMNI: + nfld_dim = '1:3' + nval_dim = '1:5' + nfld_data = set() + nval_data = set() + nval_offload = set() + nfld_offload = set() + else: + nfld_dim = 'nfld' + nval_dim = 'nval' + nfld_data = {('nfld', 'global_var_analysis_header_mod')} + nval_data = {('nval', 'global_var_analysis_header_mod')} + nval_offload = {'nval'} + nfld_offload = {'nfld'} + expected_trafo_data = { 'global_var_analysis_header_mod#nval': { 'declares': {f'iarr({nfld_dim})', f'rarr({nval_dim}, {nfld_dim})'}, - 'offload': set() if frontend == OMNI else {'nval'} + 'offload': nval_offload }, 'global_var_analysis_header_mod#nfld': { 'declares': {f'iarr({nfld_dim})', f'rarr({nval_dim}, {nfld_dim})'}, - 'offload': set() if frontend == OMNI else ['nfld'] + 'offload': nfld_offload }, 'global_var_analysis_header_mod#iarr': { 'declares': {f'iarr({nfld_dim})', f'rarr({nval_dim}, {nfld_dim})'}, @@ -411,35 +426,22 @@ def test_global_variable_analysis(frontend, key, config, global_variable_analysi 'global_var_analysis_data_mod#some_routine': {'defines_symbols': set(), 'uses_symbols': set()}, 'global_var_analysis_kernel_mod#kernel_a': { 'defines_symbols': set(), - 'uses_symbols': { - ('iarr(1:3)', 'global_var_analysis_header_mod'), ('rarr(1:5, 1:3)', 'global_var_analysis_header_mod') - } if frontend == OMNI else { - ('iarr(nfld)', 'global_var_analysis_header_mod'), ('nfld', 'global_var_analysis_header_mod'), - ('nval', 'global_var_analysis_header_mod'), + 'uses_symbols': nval_data | nfld_data | { (f'iarr({nfld_dim})', 'global_var_analysis_header_mod'), (f'rarr({nval_dim}, {nfld_dim})', 'global_var_analysis_header_mod') } }, 'global_var_analysis_kernel_mod#kernel_b': { 'defines_symbols': {('rdata(:, :, :)', 'global_var_analysis_data_mod')}, - 'uses_symbols': { - ('iarr(1:3)', 'global_var_analysis_header_mod'), ('rdata(:, :, :)', 'global_var_analysis_data_mod'), - ('tt', 'global_var_analysis_data_mod'), ('tt%vals', 'global_var_analysis_data_mod') - } if frontend == OMNI else { - ('iarr(nfld)', 'global_var_analysis_header_mod'), ('nfld', 'global_var_analysis_header_mod'), + 'uses_symbols': nfld_data | { ('rdata(:, :, :)', 'global_var_analysis_data_mod'), ('tt', 'global_var_analysis_data_mod'), ('tt%vals', 'global_var_analysis_data_mod'), (f'iarr({nfld_dim})', 'global_var_analysis_header_mod') } }, '#driver': { 'defines_symbols': {('rdata(:, :, :)', 'global_var_analysis_data_mod')}, - 'uses_symbols': { - ('iarr(1:3)', 'global_var_analysis_header_mod'), ('rarr(1:5, 1:3)', 'global_var_analysis_header_mod'), - ('rdata(:, :, :)', 'global_var_analysis_data_mod'), ('tt', 'global_var_analysis_data_mod'), - ('tt%vals', 'global_var_analysis_data_mod') - } if frontend == OMNI else { - ('iarr(nfld)', 'global_var_analysis_header_mod'), ('nfld', 'global_var_analysis_header_mod'), - ('nval', 'global_var_analysis_header_mod'), ('rdata(:, :, :)', 'global_var_analysis_data_mod'), + 'uses_symbols': nval_data | nfld_data | { + ('rdata(:, :, :)', 'global_var_analysis_data_mod'), ('tt', 'global_var_analysis_data_mod'), ('tt%vals', 'global_var_analysis_data_mod'), (f'iarr({nfld_dim})', 'global_var_analysis_header_mod'), (f'rarr({nval_dim}, {nfld_dim})', 'global_var_analysis_header_mod') @@ -686,6 +688,7 @@ def test_transformation_global_var_hoist(here, config, frontend, hoist_parameter # check driver imports expected_driver_modules = ['modulec'] expected_driver_modules += ['moduleb'] if ignore_modules is None else [] + # OMNI handles parameters differently, ModuleA only contains parameters if frontend != OMNI: expected_driver_modules += ['modulea'] if hoist_parameters else [] assert len(imports) == len(expected_driver_modules) From a840d1dea52db1b5ded98c5a831444b8f78ed3a4 Mon Sep 17 00:00:00 2001 From: Michael Staneker <50531288+MichaelSt98@users.noreply.github.com> Date: Mon, 5 Feb 2024 21:46:22 +0100 Subject: [PATCH 10/10] proper handling of edge case for 'GlobalVarHoisting' Co-authored-by: Balthasar Reuter <6384870+reuterbal@users.noreply.github.com> --- transformations/transformations/data_offload.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/transformations/transformations/data_offload.py b/transformations/transformations/data_offload.py index 4f7df8b3a..058d8ed7b 100644 --- a/transformations/transformations/data_offload.py +++ b/transformations/transformations/data_offload.py @@ -802,10 +802,6 @@ def process_kernel(self, routine, successors, item): import_map = CaseInsensitiveDict( (s.name, imprt) for imprt in routine.all_imports[::-1] for s in imprt.symbols ) - scope = routine - while scope: - import_map.update(scope.import_map) - scope = scope.parent redundant_imports_map = defaultdict(set) for module, variables in symbol_map.items(): redundant = [var.parent[0] if var.parent else var for var in variables]