Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Skip driver routine in GlobalVariableAnalysis #265

Merged
merged 3 commits into from
Apr 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 34 additions & 6 deletions transformations/tests/test_data_offload.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,24 +388,20 @@ def test_global_variable_analysis(frontend, key, config, global_variable_analysi
nval_dim = '1:5'
nfld_data = set()
nval_data = set()
nval_offload = set()
nfld_offload = set()
else:
nfld_dim = 'nfld'
nval_dim = 'nval'
nfld_data = {('nfld', 'global_var_analysis_header_mod')}
nval_data = {('nval', 'global_var_analysis_header_mod')}
nval_offload = {'nval'}
nfld_offload = {'nfld'}

expected_trafo_data = {
'global_var_analysis_header_mod': {
'declares': {f'iarr({nfld_dim})', f'rarr({nval_dim}, {nfld_dim})'},
'offload': {f'iarr({nfld_dim})', f'rarr({nval_dim}, {nfld_dim})'} | nval_offload | nfld_offload,
'offload': {}
},
'global_var_analysis_data_mod': {
'declares': {'rdata(:, :, :)', 'tt'},
'offload': {'rdata(:, :, :)', 'tt', 'tt%vals'}
'offload': {}
},
'global_var_analysis_data_mod#some_routine': {'defines_symbols': set(), 'uses_symbols': set()},
'global_var_analysis_kernel_mod#kernel_a': {
Expand Down Expand Up @@ -454,6 +450,14 @@ def test_global_variable_offload(frontend, key, config, global_variable_analysis
'driver': {'role': 'driver'}
}

# OMNI handles array indices and parameters differently
if frontend == OMNI:
nfld_dim = '1:3'
nval_dim = '1:5'
else:
nfld_dim = 'nfld'
nval_dim = 'nval'

scheduler = Scheduler(
paths=(global_variable_analysis_code,), config=config, seed_routines='driver',
frontend=frontend, xmods=(global_variable_analysis_code,)
Expand All @@ -462,6 +466,30 @@ def test_global_variable_offload(frontend, key, config, global_variable_analysis
scheduler.process(GlobalVarOffloadTransformation(key=key))
driver = scheduler['#driver'].ir

if key is None:
key = GlobalVariableAnalysis._key

expected_trafo_data = {
'global_var_analysis_header_mod': {
'declares': {f'iarr({nfld_dim})', f'rarr({nval_dim}, {nfld_dim})'},
'offload': {f'iarr({nfld_dim})', f'rarr({nval_dim}, {nfld_dim})'}
},
'global_var_analysis_data_mod': {
'declares': {'rdata(:, :, :)', 'tt'},
'offload': {'rdata(:, :, :)', 'tt', 'tt%vals'}
},
}

# Verify module offload sets
for item in [scheduler['global_var_analysis_header_mod'], scheduler['global_var_analysis_data_mod']]:
for trafo_data_key, trafo_data_value in item.trafo_data[key].items():
assert (
sorted(
tuple(str(vv) for vv in v) if isinstance(v, tuple) else str(v)
for v in trafo_data_value
) == sorted(expected_trafo_data[item.name][trafo_data_key])
)

# Verify imports have been added to the driver
expected_imports = {
'global_var_analysis_header_mod': {'iarr', 'rarr'},
Expand Down
30 changes: 19 additions & 11 deletions transformations/transformations/data_offload.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,17 +323,6 @@ def _map_var_to_module(var):
_map_var_to_module(var) for var in defines_imported_symbols
}

# Propagate offload requirement to the items of the global variables
successors_map = CaseInsensitiveDict(
(item.name, item) for item in successors if isinstance(item, ModuleItem)
)
for var, module in chain(
item.trafo_data[self._key]['uses_symbols'],
item.trafo_data[self._key]['defines_symbols']
):
if successor := successors_map.get(module):
successor.trafo_data[self._key]['offload'].add(var)

# Amend analysis data with data from successors
# Note: This is a temporary workaround for the incomplete list of successor items
# provided by the current scheduler implementation
Expand Down Expand Up @@ -476,9 +465,28 @@ def transform_subroutine(self, routine, **kwargs):
"""
role = kwargs.get('role')
successors = kwargs.get('successors', ())
item = kwargs['item']

if role == 'driver':
self.process_driver(routine, successors)
elif role == 'kernel':
self.process_kernel(item, successors)

def process_kernel(self, item, successors):
"""
Propagate offload requirement to the items of the global variables
"""
successors_map = CaseInsensitiveDict(
(item.name, item) for item in successors if isinstance(item, ModuleItem)
)
for var, module in chain(
item.trafo_data[self._key]['uses_symbols'],
item.trafo_data[self._key]['defines_symbols']
):
if var.type.parameter:
continue
if successor := successors_map.get(module):
successor.trafo_data[self._key]['offload'].add(var)

def process_driver(self, routine, successors):
"""
Expand Down
Loading