Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Global var hoisting #226

Merged
merged 10 commits into from
Feb 9, 2024
170 changes: 160 additions & 10 deletions transformations/tests/test_data_offload.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,10 @@
gettempdir, Scheduler, OMNI, Import
)
from conftest import available_frontends
from transformations import DataOffloadTransformation, GlobalVariableAnalysis, GlobalVarOffloadTransformation
from transformations import (
DataOffloadTransformation, GlobalVariableAnalysis,
GlobalVarOffloadTransformation, GlobalVarHoistTransformation
)


@pytest.fixture(scope='module', name='here')
Expand Down Expand Up @@ -378,16 +381,31 @@ def test_global_variable_analysis(frontend, key, config, global_variable_analysi
key = GlobalVariableAnalysis._key

# Validate the analysis trafo_data
nfld_dim = '1:3' if frontend == OMNI else 'nfld'
nval_dim = '1:5' if frontend == OMNI else 'nval'

# OMNI handles array indices and parameters differently
if frontend == OMNI:
nfld_dim = '1:3'
nval_dim = '1:5'
nfld_data = set()
nval_data = set()
nval_offload = set()
nfld_offload = set()
else:
nfld_dim = 'nfld'
nval_dim = 'nval'
nfld_data = {('nfld', 'global_var_analysis_header_mod')}
nval_data = {('nval', 'global_var_analysis_header_mod')}
nval_offload = {'nval'}
nfld_offload = {'nfld'}

expected_trafo_data = {
'global_var_analysis_header_mod#nval': {
'declares': {f'iarr({nfld_dim})', f'rarr({nval_dim}, {nfld_dim})'},
'offload': set()
'offload': nval_offload
},
'global_var_analysis_header_mod#nfld': {
'declares': {f'iarr({nfld_dim})', f'rarr({nval_dim}, {nfld_dim})'},
'offload': set()
'offload': nfld_offload
},
'global_var_analysis_header_mod#iarr': {
'declares': {f'iarr({nfld_dim})', f'rarr({nval_dim}, {nfld_dim})'},
Expand All @@ -408,23 +426,24 @@ def test_global_variable_analysis(frontend, key, config, global_variable_analysi
'global_var_analysis_data_mod#some_routine': {'defines_symbols': set(), 'uses_symbols': set()},
'global_var_analysis_kernel_mod#kernel_a': {
'defines_symbols': set(),
'uses_symbols': {
'uses_symbols': nval_data | nfld_data | {
(f'iarr({nfld_dim})', 'global_var_analysis_header_mod'),
(f'rarr({nval_dim}, {nfld_dim})', 'global_var_analysis_header_mod')
}
},
'global_var_analysis_kernel_mod#kernel_b': {
'defines_symbols': {('rdata(:, :, :)', 'global_var_analysis_data_mod')},
'uses_symbols': {
'uses_symbols': nfld_data | {
('rdata(:, :, :)', 'global_var_analysis_data_mod'), ('tt', 'global_var_analysis_data_mod'),
('tt%vals', 'global_var_analysis_data_mod'), (f'iarr({nfld_dim})', 'global_var_analysis_header_mod')
}
},
'#driver': {
'defines_symbols': {('rdata(:, :, :)', 'global_var_analysis_data_mod')},
'uses_symbols': {
('rdata(:, :, :)', 'global_var_analysis_data_mod'), ('tt', 'global_var_analysis_data_mod'),
('tt%vals', 'global_var_analysis_data_mod'), (f'iarr({nfld_dim})', 'global_var_analysis_header_mod'),
'uses_symbols': nval_data | nfld_data | {
('rdata(:, :, :)', 'global_var_analysis_data_mod'),
('tt', 'global_var_analysis_data_mod'), ('tt%vals', 'global_var_analysis_data_mod'),
(f'iarr({nfld_dim})', 'global_var_analysis_header_mod'),
(f'rarr({nval_dim}, {nfld_dim})', 'global_var_analysis_header_mod')
}
}
Expand Down Expand Up @@ -634,3 +653,134 @@ def test_transformation_global_var_import_derived_type(here, config, frontend):
# Note: g is not offloaded because it is not used by the kernel (albeit imported)
assert 'p0' in pragmas[1].content
assert 'p_array' in pragmas[2].content


@pytest.mark.parametrize('frontend', available_frontends())
@pytest.mark.parametrize('hoist_parameters', (False, True))
@pytest.mark.parametrize('ignore_modules', (None, ('moduleb',)))
def test_transformation_global_var_hoist(here, config, frontend, hoist_parameters, ignore_modules):
"""
Test hoisting of global variable imports.
"""
config['default']['enable_imports'] = True
config['routines'] = {
'driver': {'role': 'driver'}
}

scheduler = Scheduler(paths=here/'sources/projGlobalVarImports', config=config, frontend=frontend)
scheduler.process(transformation=GlobalVariableAnalysis())
scheduler.process(transformation=GlobalVarHoistTransformation(hoist_parameters=hoist_parameters,
ignore_modules=ignore_modules))

driver = scheduler['#driver'].routine
kernel0 = scheduler['#kernel0'].routine
kernel_map = {key: scheduler[f'#{key}'].routine for key in ['kernel1', 'kernel2', 'kernel3']}

# symbols within each module
expected_symbols = {'modulea': ['var0', 'var1'], 'moduleb': ['var2', 'var3'],
'modulec': ['var4', 'var5']}
# expected intent of those variables (if hoisted)
var_intent_map = {'var0': 'in', 'var1': 'in', 'var2': 'in',
'var3': 'in', 'var4': 'inout', 'var5': 'inout', 'tmp': None}
# DRIVER
imports = FindNodes(Import).visit(driver.spec)
import_names = [_import.module.lower() for _import in imports]
# check driver imports
expected_driver_modules = ['modulec']
expected_driver_modules += ['moduleb'] if ignore_modules is None else []
# OMNI handles parameters differently, ModuleA only contains parameters
if frontend != OMNI:
expected_driver_modules += ['modulea'] if hoist_parameters else []
reuterbal marked this conversation as resolved.
Show resolved Hide resolved
assert len(imports) == len(expected_driver_modules)
assert sorted(expected_driver_modules) == sorted(import_names)
for _import in imports:
assert sorted([sym.name for sym in _import.symbols]) == expected_symbols[_import.module.lower()]
# check driver call
driver_calls = FindNodes(CallStatement).visit(driver.body)
expected_args = []
for module in expected_driver_modules:
expected_args.extend(expected_symbols[module])
assert [arg.name for arg in driver_calls[0].arguments] == sorted(expected_args)

originally = {'kernel1': ['modulea'], 'kernel2': ['moduleb'],
'kernel3': ['moduleb', 'modulec']}
# KERNEL0
assert [arg.name for arg in kernel0.arguments] == sorted(expected_args)
assert [arg.name for arg in kernel0.variables] == sorted(expected_args)
for var in kernel0.variables:
assert kernel0.variable_map[var.name.lower()].type.intent == var_intent_map[var.name.lower()]
kernel0_calls = FindNodes(CallStatement).visit(kernel0.body)
# KERNEL1 & KERNEL2 & KERNEL3
for call in kernel0_calls:
expected_args = []
expected_imports = []
kernel_expected_symbols = []
for module in originally[call.routine.name]:
# always, since at least 'some_func' is imported
if call.routine.name == 'kernel1' and module == 'modulea':
expected_imports.append(module)
kernel_expected_symbols.append('some_func')
if module in expected_driver_modules:
expected_args.extend(expected_symbols[module])
else:
# already added
if module != 'modulea':
expected_imports.append(module)
kernel_expected_symbols.extend(expected_symbols[module])
assert len(expected_args) == len(call.arguments)
assert [arg.name for arg in call.arguments] == expected_args
assert [arg.name for arg in kernel_map[call.routine.name].arguments] == expected_args
for var in kernel_map[call.routine.name].variables:
var_intent = kernel_map[call.routine.name].variable_map[var.name.lower()].type.intent
assert var_intent == var_intent_map[var.name.lower()]
if call.routine.name in ['kernel1', 'kernel2']:
expected_args = ['tmp'] + expected_args
assert [arg.name for arg in kernel_map[call.routine.name].variables] == expected_args
kernel_imports = FindNodes(Import).visit(call.routine.spec)
assert sorted([_import.module.lower() for _import in kernel_imports]) == sorted(expected_imports)
imported_symbols = [] # _import.symbols for _import in kernel_imports]
for _import in kernel_imports:
imported_symbols.extend([sym.name.lower() for sym in _import.symbols])
assert sorted(imported_symbols) == sorted(kernel_expected_symbols)


@pytest.mark.parametrize('frontend', available_frontends())
@pytest.mark.parametrize('hoist_parameters', (False, True))
def test_transformation_global_var_derived_type_hoist(here, config, frontend, hoist_parameters):
"""
Test hoisting of derived-type global variable imports.
"""

config['default']['enable_imports'] = True
config['routines'] = {
'driver_derived_type': {'role': 'driver'}
}

scheduler = Scheduler(paths=here/'sources/projGlobalVarImports', config=config, frontend=frontend)
scheduler.process(transformation=GlobalVariableAnalysis())
scheduler.process(transformation=GlobalVarHoistTransformation(hoist_parameters))

driver = scheduler['#driver_derived_type'].routine
kernel = scheduler['#kernel_derived_type'].routine

# DRIVER
imports = FindNodes(Import).visit(driver.spec)
assert len(imports) == 1
assert imports[0].module.lower() == 'module_derived_type'
assert sorted([sym.name.lower() for sym in imports[0].symbols]) == sorted(['p', 'p_array', 'p0'])
calls = FindNodes(CallStatement).visit(driver.body)
assert len(calls) == 1
# KERNEL
assert [arg.name for arg in calls[0].arguments] == ['p', 'p0', 'p_array']
assert [arg.name for arg in kernel.arguments] == ['p', 'p0', 'p_array']
kernel_imports = FindNodes(Import).visit(kernel.spec)
assert len(kernel_imports) == 1
assert [sym.name.lower() for sym in kernel_imports[0].symbols] == ['g']
assert sorted([var.name for var in kernel.variables]) == ['i', 'j', 'p', 'p0', 'p_array']
assert kernel.variable_map['p_array'].type.allocatable
assert kernel.variable_map['p_array'].type.intent == 'inout'
assert kernel.variable_map['p_array'].type.dtype.name == 'point'
assert kernel.variable_map['p'].type.intent == 'inout'
assert kernel.variable_map['p'].type.dtype.name == 'point'
assert kernel.variable_map['p0'].type.intent == 'in'
assert kernel.variable_map['p0'].type.dtype.name == 'point'
Loading
Loading