diff --git a/transformations/tests/test_data_offload.py b/transformations/tests/test_data_offload.py index 4b60148fb..8745a2bf5 100644 --- a/transformations/tests/test_data_offload.py +++ b/transformations/tests/test_data_offload.py @@ -388,24 +388,20 @@ def test_global_variable_analysis(frontend, key, config, global_variable_analysi nval_dim = '1:5' nfld_data = set() nval_data = set() - nval_offload = set() - nfld_offload = set() else: nfld_dim = 'nfld' nval_dim = 'nval' nfld_data = {('nfld', 'global_var_analysis_header_mod')} nval_data = {('nval', 'global_var_analysis_header_mod')} - nval_offload = {'nval'} - nfld_offload = {'nfld'} expected_trafo_data = { 'global_var_analysis_header_mod': { 'declares': {f'iarr({nfld_dim})', f'rarr({nval_dim}, {nfld_dim})'}, - 'offload': {f'iarr({nfld_dim})', f'rarr({nval_dim}, {nfld_dim})'} | nval_offload | nfld_offload, + 'offload': {} }, 'global_var_analysis_data_mod': { 'declares': {'rdata(:, :, :)', 'tt'}, - 'offload': {'rdata(:, :, :)', 'tt', 'tt%vals'} + 'offload': {} }, 'global_var_analysis_data_mod#some_routine': {'defines_symbols': set(), 'uses_symbols': set()}, 'global_var_analysis_kernel_mod#kernel_a': { @@ -454,6 +450,14 @@ def test_global_variable_offload(frontend, key, config, global_variable_analysis 'driver': {'role': 'driver'} } + # OMNI handles array indices and parameters differently + if frontend == OMNI: + nfld_dim = '1:3' + nval_dim = '1:5' + else: + nfld_dim = 'nfld' + nval_dim = 'nval' + scheduler = Scheduler( paths=(global_variable_analysis_code,), config=config, seed_routines='driver', frontend=frontend, xmods=(global_variable_analysis_code,) @@ -462,6 +466,30 @@ def test_global_variable_offload(frontend, key, config, global_variable_analysis scheduler.process(GlobalVarOffloadTransformation(key=key)) driver = scheduler['#driver'].ir + if key is None: + key = GlobalVariableAnalysis._key + + expected_trafo_data = { + 'global_var_analysis_header_mod': { + 'declares': {f'iarr({nfld_dim})', f'rarr({nval_dim}, {nfld_dim})'}, + 'offload': {f'iarr({nfld_dim})', f'rarr({nval_dim}, {nfld_dim})'} + }, + 'global_var_analysis_data_mod': { + 'declares': {'rdata(:, :, :)', 'tt'}, + 'offload': {'rdata(:, :, :)', 'tt', 'tt%vals'} + }, + } + + # Verify module offload sets + for item in [scheduler['global_var_analysis_header_mod'], scheduler['global_var_analysis_data_mod']]: + for trafo_data_key, trafo_data_value in item.trafo_data[key].items(): + assert ( + sorted( + tuple(str(vv) for vv in v) if isinstance(v, tuple) else str(v) + for v in trafo_data_value + ) == sorted(expected_trafo_data[item.name][trafo_data_key]) + ) + # Verify imports have been added to the driver expected_imports = { 'global_var_analysis_header_mod': {'iarr', 'rarr'}, diff --git a/transformations/transformations/data_offload.py b/transformations/transformations/data_offload.py index cfc28e54f..dae7d3326 100644 --- a/transformations/transformations/data_offload.py +++ b/transformations/transformations/data_offload.py @@ -323,17 +323,6 @@ def _map_var_to_module(var): _map_var_to_module(var) for var in defines_imported_symbols } - # Propagate offload requirement to the items of the global variables - successors_map = CaseInsensitiveDict( - (item.name, item) for item in successors if isinstance(item, ModuleItem) - ) - for var, module in chain( - item.trafo_data[self._key]['uses_symbols'], - item.trafo_data[self._key]['defines_symbols'] - ): - if successor := successors_map.get(module): - successor.trafo_data[self._key]['offload'].add(var) - # Amend analysis data with data from successors # Note: This is a temporary workaround for the incomplete list of successor items # provided by the current scheduler implementation @@ -476,9 +465,28 @@ def transform_subroutine(self, routine, **kwargs): """ role = kwargs.get('role') successors = kwargs.get('successors', ()) + item = kwargs['item'] if role == 'driver': self.process_driver(routine, successors) + elif role == 'kernel': + self.process_kernel(item, successors) + + def process_kernel(self, item, successors): + """ + Propagate offload requirement to the items of the global variables + """ + successors_map = CaseInsensitiveDict( + (item.name, item) for item in successors if isinstance(item, ModuleItem) + ) + for var, module in chain( + item.trafo_data[self._key]['uses_symbols'], + item.trafo_data[self._key]['defines_symbols'] + ): + if var.type.parameter: + continue + if successor := successors_map.get(module): + successor.trafo_data[self._key]['offload'].add(var) def process_driver(self, routine, successors): """