Skip to content

Commit

Permalink
Merge pull request #226 from ecmwf-ifs/nams_global_var_hoisting
Browse files Browse the repository at this point in the history
Global var hoisting
  • Loading branch information
reuterbal authored Feb 9, 2024
2 parents 73c6635 + a840d1d commit ff12ecb
Show file tree
Hide file tree
Showing 2 changed files with 435 additions and 17 deletions.
170 changes: 160 additions & 10 deletions transformations/tests/test_data_offload.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,10 @@
gettempdir, Scheduler, OMNI, Import
)
from conftest import available_frontends
from transformations import DataOffloadTransformation, GlobalVariableAnalysis, GlobalVarOffloadTransformation
from transformations import (
DataOffloadTransformation, GlobalVariableAnalysis,
GlobalVarOffloadTransformation, GlobalVarHoistTransformation
)


@pytest.fixture(scope='module', name='here')
Expand Down Expand Up @@ -378,16 +381,31 @@ def test_global_variable_analysis(frontend, key, config, global_variable_analysi
key = GlobalVariableAnalysis._key

# Validate the analysis trafo_data
nfld_dim = '1:3' if frontend == OMNI else 'nfld'
nval_dim = '1:5' if frontend == OMNI else 'nval'

# OMNI handles array indices and parameters differently
if frontend == OMNI:
nfld_dim = '1:3'
nval_dim = '1:5'
nfld_data = set()
nval_data = set()
nval_offload = set()
nfld_offload = set()
else:
nfld_dim = 'nfld'
nval_dim = 'nval'
nfld_data = {('nfld', 'global_var_analysis_header_mod')}
nval_data = {('nval', 'global_var_analysis_header_mod')}
nval_offload = {'nval'}
nfld_offload = {'nfld'}

expected_trafo_data = {
'global_var_analysis_header_mod#nval': {
'declares': {f'iarr({nfld_dim})', f'rarr({nval_dim}, {nfld_dim})'},
'offload': set()
'offload': nval_offload
},
'global_var_analysis_header_mod#nfld': {
'declares': {f'iarr({nfld_dim})', f'rarr({nval_dim}, {nfld_dim})'},
'offload': set()
'offload': nfld_offload
},
'global_var_analysis_header_mod#iarr': {
'declares': {f'iarr({nfld_dim})', f'rarr({nval_dim}, {nfld_dim})'},
Expand All @@ -408,23 +426,24 @@ def test_global_variable_analysis(frontend, key, config, global_variable_analysi
'global_var_analysis_data_mod#some_routine': {'defines_symbols': set(), 'uses_symbols': set()},
'global_var_analysis_kernel_mod#kernel_a': {
'defines_symbols': set(),
'uses_symbols': {
'uses_symbols': nval_data | nfld_data | {
(f'iarr({nfld_dim})', 'global_var_analysis_header_mod'),
(f'rarr({nval_dim}, {nfld_dim})', 'global_var_analysis_header_mod')
}
},
'global_var_analysis_kernel_mod#kernel_b': {
'defines_symbols': {('rdata(:, :, :)', 'global_var_analysis_data_mod')},
'uses_symbols': {
'uses_symbols': nfld_data | {
('rdata(:, :, :)', 'global_var_analysis_data_mod'), ('tt', 'global_var_analysis_data_mod'),
('tt%vals', 'global_var_analysis_data_mod'), (f'iarr({nfld_dim})', 'global_var_analysis_header_mod')
}
},
'#driver': {
'defines_symbols': {('rdata(:, :, :)', 'global_var_analysis_data_mod')},
'uses_symbols': {
('rdata(:, :, :)', 'global_var_analysis_data_mod'), ('tt', 'global_var_analysis_data_mod'),
('tt%vals', 'global_var_analysis_data_mod'), (f'iarr({nfld_dim})', 'global_var_analysis_header_mod'),
'uses_symbols': nval_data | nfld_data | {
('rdata(:, :, :)', 'global_var_analysis_data_mod'),
('tt', 'global_var_analysis_data_mod'), ('tt%vals', 'global_var_analysis_data_mod'),
(f'iarr({nfld_dim})', 'global_var_analysis_header_mod'),
(f'rarr({nval_dim}, {nfld_dim})', 'global_var_analysis_header_mod')
}
}
Expand Down Expand Up @@ -634,3 +653,134 @@ def test_transformation_global_var_import_derived_type(here, config, frontend):
# Note: g is not offloaded because it is not used by the kernel (albeit imported)
assert 'p0' in pragmas[1].content
assert 'p_array' in pragmas[2].content


@pytest.mark.parametrize('frontend', available_frontends())
@pytest.mark.parametrize('hoist_parameters', (False, True))
@pytest.mark.parametrize('ignore_modules', (None, ('moduleb',)))
def test_transformation_global_var_hoist(here, config, frontend, hoist_parameters, ignore_modules):
"""
Test hoisting of global variable imports.
"""
config['default']['enable_imports'] = True
config['routines'] = {
'driver': {'role': 'driver'}
}

scheduler = Scheduler(paths=here/'sources/projGlobalVarImports', config=config, frontend=frontend)
scheduler.process(transformation=GlobalVariableAnalysis())
scheduler.process(transformation=GlobalVarHoistTransformation(hoist_parameters=hoist_parameters,
ignore_modules=ignore_modules))

driver = scheduler['#driver'].routine
kernel0 = scheduler['#kernel0'].routine
kernel_map = {key: scheduler[f'#{key}'].routine for key in ['kernel1', 'kernel2', 'kernel3']}

# symbols within each module
expected_symbols = {'modulea': ['var0', 'var1'], 'moduleb': ['var2', 'var3'],
'modulec': ['var4', 'var5']}
# expected intent of those variables (if hoisted)
var_intent_map = {'var0': 'in', 'var1': 'in', 'var2': 'in',
'var3': 'in', 'var4': 'inout', 'var5': 'inout', 'tmp': None}
# DRIVER
imports = FindNodes(Import).visit(driver.spec)
import_names = [_import.module.lower() for _import in imports]
# check driver imports
expected_driver_modules = ['modulec']
expected_driver_modules += ['moduleb'] if ignore_modules is None else []
# OMNI handles parameters differently, ModuleA only contains parameters
if frontend != OMNI:
expected_driver_modules += ['modulea'] if hoist_parameters else []
assert len(imports) == len(expected_driver_modules)
assert sorted(expected_driver_modules) == sorted(import_names)
for _import in imports:
assert sorted([sym.name for sym in _import.symbols]) == expected_symbols[_import.module.lower()]
# check driver call
driver_calls = FindNodes(CallStatement).visit(driver.body)
expected_args = []
for module in expected_driver_modules:
expected_args.extend(expected_symbols[module])
assert [arg.name for arg in driver_calls[0].arguments] == sorted(expected_args)

originally = {'kernel1': ['modulea'], 'kernel2': ['moduleb'],
'kernel3': ['moduleb', 'modulec']}
# KERNEL0
assert [arg.name for arg in kernel0.arguments] == sorted(expected_args)
assert [arg.name for arg in kernel0.variables] == sorted(expected_args)
for var in kernel0.variables:
assert kernel0.variable_map[var.name.lower()].type.intent == var_intent_map[var.name.lower()]
kernel0_calls = FindNodes(CallStatement).visit(kernel0.body)
# KERNEL1 & KERNEL2 & KERNEL3
for call in kernel0_calls:
expected_args = []
expected_imports = []
kernel_expected_symbols = []
for module in originally[call.routine.name]:
# always, since at least 'some_func' is imported
if call.routine.name == 'kernel1' and module == 'modulea':
expected_imports.append(module)
kernel_expected_symbols.append('some_func')
if module in expected_driver_modules:
expected_args.extend(expected_symbols[module])
else:
# already added
if module != 'modulea':
expected_imports.append(module)
kernel_expected_symbols.extend(expected_symbols[module])
assert len(expected_args) == len(call.arguments)
assert [arg.name for arg in call.arguments] == expected_args
assert [arg.name for arg in kernel_map[call.routine.name].arguments] == expected_args
for var in kernel_map[call.routine.name].variables:
var_intent = kernel_map[call.routine.name].variable_map[var.name.lower()].type.intent
assert var_intent == var_intent_map[var.name.lower()]
if call.routine.name in ['kernel1', 'kernel2']:
expected_args = ['tmp'] + expected_args
assert [arg.name for arg in kernel_map[call.routine.name].variables] == expected_args
kernel_imports = FindNodes(Import).visit(call.routine.spec)
assert sorted([_import.module.lower() for _import in kernel_imports]) == sorted(expected_imports)
imported_symbols = [] # _import.symbols for _import in kernel_imports]
for _import in kernel_imports:
imported_symbols.extend([sym.name.lower() for sym in _import.symbols])
assert sorted(imported_symbols) == sorted(kernel_expected_symbols)


@pytest.mark.parametrize('frontend', available_frontends())
@pytest.mark.parametrize('hoist_parameters', (False, True))
def test_transformation_global_var_derived_type_hoist(here, config, frontend, hoist_parameters):
"""
Test hoisting of derived-type global variable imports.
"""

config['default']['enable_imports'] = True
config['routines'] = {
'driver_derived_type': {'role': 'driver'}
}

scheduler = Scheduler(paths=here/'sources/projGlobalVarImports', config=config, frontend=frontend)
scheduler.process(transformation=GlobalVariableAnalysis())
scheduler.process(transformation=GlobalVarHoistTransformation(hoist_parameters))

driver = scheduler['#driver_derived_type'].routine
kernel = scheduler['#kernel_derived_type'].routine

# DRIVER
imports = FindNodes(Import).visit(driver.spec)
assert len(imports) == 1
assert imports[0].module.lower() == 'module_derived_type'
assert sorted([sym.name.lower() for sym in imports[0].symbols]) == sorted(['p', 'p_array', 'p0'])
calls = FindNodes(CallStatement).visit(driver.body)
assert len(calls) == 1
# KERNEL
assert [arg.name for arg in calls[0].arguments] == ['p', 'p0', 'p_array']
assert [arg.name for arg in kernel.arguments] == ['p', 'p0', 'p_array']
kernel_imports = FindNodes(Import).visit(kernel.spec)
assert len(kernel_imports) == 1
assert [sym.name.lower() for sym in kernel_imports[0].symbols] == ['g']
assert sorted([var.name for var in kernel.variables]) == ['i', 'j', 'p', 'p0', 'p_array']
assert kernel.variable_map['p_array'].type.allocatable
assert kernel.variable_map['p_array'].type.intent == 'inout'
assert kernel.variable_map['p_array'].type.dtype.name == 'point'
assert kernel.variable_map['p'].type.intent == 'inout'
assert kernel.variable_map['p'].type.dtype.name == 'point'
assert kernel.variable_map['p0'].type.intent == 'in'
assert kernel.variable_map['p0'].type.dtype.name == 'point'
Loading

0 comments on commit ff12ecb

Please sign in to comment.