Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Block-index injection transformations #303

Merged
merged 37 commits into from
May 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
37 commits
Select commit Hold shift + click to select a range
9770a4b
Dimension: add index alias attribute
awnawab Apr 4, 2024
925a71b
DeprivatiseStructsTrafo: first implementation
awnawab Apr 4, 2024
460b81b
BlockIndexInjectTransformation: first implementation
awnawab Apr 5, 2024
1abe60b
Change trafo name to UnprivatiseStructsTransformation
awnawab Apr 5, 2024
78e070c
SCRIPTS: add unprivatise_structs option to scripts
awnawab Apr 5, 2024
acc16d5
BlockIndexInject: cleanup and fixes
awnawab Apr 5, 2024
7d1c5f1
BlockIndexInject: fix call arg rank logic
awnawab Apr 15, 2024
95b7121
Various improvements to UnprivatiseStructsTrafo and BlockIndexInjectT…
awnawab Apr 17, 2024
4114ae6
Add configurable exclude list to unprivatise/blockinject trafos
awnawab Apr 18, 2024
f67806d
SCRIPTS: read unprivatise pipeline from config
awnawab Apr 18, 2024
e60eca6
Adapt unprivatise/blockinject trafos so they can also be run on drive…
awnawab Apr 19, 2024
9eacc8e
Rename file to block_index_transformations.py
awnawab Apr 22, 2024
7ffb876
BlockViewToFieldViewTrafo: add documentation
awnawab Apr 22, 2024
2712538
BlockIndexInjectTrafo: add documentation
awnawab Apr 22, 2024
2fcdc08
Change name of loki-transform script arg
awnawab Apr 22, 2024
dd8f0d4
Make exclude_arrays configurable via the scheduler
awnawab Apr 24, 2024
5ce3b84
BlockViewtoFieldViewTrafo: make switch to global gfl_ptr optional
awnawab Apr 24, 2024
a750450
BlockViewtoFieldViewTrafo: add bailout for routines marked as seq
awnawab Apr 24, 2024
1d474ea
Dimension: add index_expressions property
awnawab Apr 24, 2024
964002b
BlockIndexInjectTrafos: add tests
awnawab Apr 24, 2024
b58be6d
Appease pylint
awnawab Apr 29, 2024
facaca9
BlockIndexInjectTrafo: add support for call statement kwargs
awnawab May 10, 2024
4fc6af1
BlockIndexTrafos: fix typos and cover untested lines
awnawab May 10, 2024
729120c
Rebase cleanup
awnawab May 10, 2024
9a8df04
InjectBlockIndexTransformation: rename trafo
awnawab May 10, 2024
d31043a
block_index_trafos: remove trafo_data key from constructor args
awnawab May 10, 2024
7e67321
Enrich derived-type variable declarations
awnawab May 11, 2024
07366c0
BlockViewToFieldViewTrafo: update documentation
awnawab May 11, 2024
f30b376
block_index_trafos: misc cleanup
awnawab May 11, 2024
8787f0d
SCRIPTS: remove BLOCKVIEW_TO_FIELDVIEW CLI arg
awnawab May 11, 2024
212677b
Angry linting gods
awnawab May 11, 2024
e71adc4
SCRIPTS: sanitize mode string for custom pipelines
awnawab May 12, 2024
1e05c1b
Change FindVariables searches to non-unique to restore SL functionality
awnawab May 12, 2024
2c01d69
Add a test for derived type enrichment
reuterbal May 13, 2024
998d1f4
Fix enrichment of derived type dtypes for local variables
reuterbal May 13, 2024
9ae7b50
block_index_trafos: change scope of test config fixture to function
awnawab May 13, 2024
6707094
Appease extra fussy upgraded linter
awnawab May 27, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion loki/backend/maxgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ class <name> extends Kernel {

# Class signature
if is_manager:
if is_interface:
if is_interface: # pylint: disable=possibly-used-before-assignment
header += [self.format_line(
'public interface ', o.name, ' extends ManagerPCIe, ManagerKernel {')]
else:
Expand Down
18 changes: 17 additions & 1 deletion loki/dimension.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,15 +34,19 @@ class Dimension:
bounds_aliases : list or tuple of strings
String representations of alternative bounds variables that are
used to define loop ranges.
index_aliases : list or tuple of strings
String representations of alternative loop index variables associated
with this dimension.
"""

def __init__(self, name=None, index=None, bounds=None, size=None, aliases=None,
bounds_aliases=None):
bounds_aliases=None, index_aliases=None):
self.name = name
self._index = index
self._bounds = as_tuple(bounds)
self._size = size
self._aliases = as_tuple(aliases)
self._index_aliases = as_tuple(index_aliases)

if bounds_aliases:
if len(bounds_aliases) != 2:
Expand Down Expand Up @@ -118,3 +122,15 @@ def bounds_expressions(self):
exprs = [expr + (b,) for expr, b in zip(exprs, self._bounds_aliases)]

return as_tuple(exprs)

@property
def index_expressions(self):
"""
A list of all expression strings representing the index expression of an iteration space (loop).
"""

exprs = [self.index,]
if self._index_aliases:
exprs += list(self._index_aliases)

return as_tuple(exprs)
1 change: 1 addition & 0 deletions loki/frontend/fparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -685,6 +685,7 @@ def visit_Char_Selector(self, o, **kwargs):
* some scalar expression for the kind
"""
length = None
kind = None
if o.children[0] is not None:
length = self.visit(o.children[0], **kwargs)
if o.children[1] is not None:
Expand Down
9 changes: 8 additions & 1 deletion loki/program_unit.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ def from_source(cls, source, definitions=None, preprocess=False,
if frontend == Frontend.OFP:
ast = parse_ofp_source(source)
return cls.from_ofp(ast=ast, raw_source=source, definitions=definitions,
pp_info=pp_info, parent=parent)
pp_info=pp_info, parent=parent) # pylint: disable=possibly-used-before-assignment

if frontend == Frontend.FP:
ast = parse_fparser_source(source)
Expand Down Expand Up @@ -361,6 +361,13 @@ def enrich(self, definitions, recurse=False):
updated_symbol_attrs[local_name] = symbol.type.clone(
dtype=remote_node.dtype, imported=True, module=module
)
# Update dtype for local variables using this type
variables_with_this_type = {
name: type_.clone(dtype=remote_node.dtype)
for name, type_ in self.symbol_attrs.items()
if getattr(type_.dtype, 'name') == remote_node.dtype.name
}
updated_symbol_attrs.update(variables_with_this_type)
elif hasattr(remote_node, 'type'):
# This is a global variable or interface import
updated_symbol_attrs[local_name] = remote_node.type.clone(
Expand Down
99 changes: 79 additions & 20 deletions loki/tests/test_subroutine.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,16 @@

from loki import (
Sourcefile, Module, Subroutine, FindVariables, FindNodes, Section,
CallStatement, BasicType, Array, Scalar, Variable,
Array, Scalar, Variable,
SymbolAttributes, StringLiteral, fgen, fexprgen,
VariableDeclaration, Transformer, FindTypedSymbols,
ProcedureSymbol, ProcedureType, StatementFunction,
normalize_range_indexing, DeferredTypeSymbol, Assignment,
Interface
ProcedureSymbol, StatementFunction,
normalize_range_indexing, DeferredTypeSymbol
)
from loki.build import jit_compile, jit_compile_lib, clean_test
from loki.frontend import available_frontends, OFP, OMNI, REGEX
from loki.types import BasicType, DerivedType, ProcedureType
from loki.ir import nodes as ir


@pytest.fixture(scope='module', name='here')
Expand Down Expand Up @@ -767,7 +768,7 @@ def test_routine_call_arrays(header_path, frontend):
"""
header = Sourcefile.from_file(header_path, frontend=frontend)['header']
routine = Subroutine.from_source(fcode, frontend=frontend, definitions=header)
call = FindNodes(CallStatement).visit(routine.body)[0]
call = FindNodes(ir.CallStatement).visit(routine.body)[0]

assert str(call.arguments[0]) == 'x'
assert str(call.arguments[1]) == 'y'
Expand Down Expand Up @@ -797,7 +798,7 @@ def test_call_no_arg(frontend):
call abort
end subroutine routine_call_no_arg
""")
calls = FindNodes(CallStatement).visit(routine.body)
calls = FindNodes(ir.CallStatement).visit(routine.body)
assert len(calls) == 1
assert calls[0].arguments == ()
assert calls[0].kwarguments == ()
Expand All @@ -813,7 +814,7 @@ def test_call_kwargs(frontend):
call mpl_init(kprocs=kprocs, cdstring='routine_call_kwargs')
end subroutine routine_call_kwargs
""")
calls = FindNodes(CallStatement).visit(routine.body)
calls = FindNodes(ir.CallStatement).visit(routine.body)
assert len(calls) == 1
assert calls[0].name == 'mpl_init'

Expand All @@ -838,7 +839,7 @@ def test_call_args_kwargs(frontend):
call mpl_send(pbuf, ktag, kdest, cdstring='routine_call_args_kwargs')
end subroutine routine_call_args_kwargs
""")
calls = FindNodes(CallStatement).visit(routine.body)
calls = FindNodes(ir.CallStatement).visit(routine.body)
assert len(calls) == 1
assert calls[0].name == 'mpl_send'
assert len(calls[0].arguments) == 3
Expand Down Expand Up @@ -1520,7 +1521,7 @@ def test_subroutine_stmt_func(here, frontend):
routine.name += f'_{frontend!s}'

# Make sure the statement function injection doesn't invalidate source
for assignment in FindNodes(Assignment).visit(routine.body):
for assignment in FindNodes(ir.Assignment).visit(routine.body):
assert assignment.source is not None

# OMNI inlines statement functions, so we can only check correct representation
Expand Down Expand Up @@ -1958,7 +1959,7 @@ def test_subroutine_clone_contained(frontend):
kernels = driver.subroutines

def _verify_call_enrichment(driver_, kernels_):
calls = FindNodes(CallStatement).visit(driver_.body)
calls = FindNodes(ir.CallStatement).visit(driver_.body)
assert len(calls) == 2

for call in calls:
Expand Down Expand Up @@ -2048,12 +2049,12 @@ def test_enrich_explicit_interface(frontend):
driver.enrich(kernel)

# check if call is enriched correctly
calls = FindNodes(CallStatement).visit(driver.body)
calls = FindNodes(ir.CallStatement).visit(driver.body)
assert calls[0].routine is kernel

# check if the procedure symbol in the interface block has been removed from
# driver's symbol table
intfs = FindNodes(Interface).visit(driver.spec)
intfs = FindNodes(ir.Interface).visit(driver.spec)
assert not intfs[0].body[0].parent

# check that call still points to correct subroutine
Expand All @@ -2065,6 +2066,64 @@ def test_enrich_explicit_interface(frontend):
assert calls[0].routine is kernel


@pytest.mark.parametrize('frontend', available_frontends())
def test_enrich_derived_types(tmp_path, frontend):
fcode = """
subroutine enrich_derived_types_routine(yda_array)
use field_array_module, only : field_3rb_array
implicit none
type(field_3rb_array), intent(inout) :: yda_array
yda_array%p = 0.
end subroutine enrich_derived_types_routine
""".strip()

fcode_module = """
module field_array_module
implicit none
type field_3rb_array
real, pointer :: p(:,:,:)
end type field_3rb_array
end module field_array_module
""".strip()

module = Module.from_source(fcode_module, frontend=frontend, xmods=[tmp_path])
routine = Subroutine.from_source(fcode, frontend=frontend, xmods=[tmp_path])

# The derived type is a dangling import
field_3rb_symbol = routine.symbol_map['field_3rb_array']
assert field_3rb_symbol.type.imported
assert field_3rb_symbol.type.module is None
assert field_3rb_symbol.type.dtype is BasicType.DEFERRED

# The variable type is recognized as a derived type but without enrichment
yda_array = routine.variable_map['yda_array']
assert isinstance(yda_array.type.dtype, DerivedType)
assert routine.variable_map['yda_array'].type.dtype.typedef is BasicType.DEFERRED

# The pointer member has no type information
yda_array_p = routine.resolve_typebound_var('yda_array%p')
assert yda_array_p.type.dtype is BasicType.DEFERRED
assert yda_array_p.type.shape is None

# Pick out the typedef (before enrichment to validate object consistency)
field_3rb_tdef = module['field_3rb_array']
assert isinstance(field_3rb_tdef, ir.TypeDef)

# Enrich the routine with module definitions
routine.enrich(module)

# Ensure the imported type symbol is correctly enriched
assert field_3rb_symbol.type.imported
assert field_3rb_symbol.type.module is module
assert isinstance(field_3rb_symbol.type.dtype, DerivedType)

# Ensure the information has been propagated to other variables
assert isinstance(yda_array.type.dtype, DerivedType)
assert yda_array.type.dtype.typedef is field_3rb_tdef
assert yda_array_p.type.dtype is BasicType.REAL
assert yda_array_p.type.shape == (':', ':', ':')


@pytest.mark.parametrize('frontend', available_frontends(
xfail=[(OMNI, 'OMNI cannot handle external type defs without source')]
))
Expand Down Expand Up @@ -2099,15 +2158,15 @@ def test_subroutine_deep_clone(frontend):

# Replace all assignments with dummy calls
map_nodes={}
for assign in FindNodes(Assignment).visit(new_routine.body):
map_nodes[assign] = CallStatement(
for assign in FindNodes(ir.Assignment).visit(new_routine.body):
map_nodes[assign] = ir.CallStatement(
name=DeferredTypeSymbol(name='testcall'), arguments=(assign.lhs,), scope=new_routine
)
new_routine.body = Transformer(map_nodes).visit(new_routine.body)

# Ensure that the original copy of the routine remains unaffected
assert len(FindNodes(Assignment).visit(routine.body)) == 3
assert len(FindNodes(Assignment).visit(new_routine.body)) == 0
assert len(FindNodes(ir.Assignment).visit(routine.body)) == 3
assert len(FindNodes(ir.Assignment).visit(new_routine.body)) == 0

@pytest.mark.parametrize('frontend', available_frontends())
def test_call_args_kwargs_conversion(frontend):
Expand Down Expand Up @@ -2162,20 +2221,20 @@ def test_call_args_kwargs_conversion(frontend):
len_kwargs = (0, 7, 7, 2)

# sort kwargs
for i_call, call in enumerate(FindNodes(CallStatement).visit(driver.body)):
for i_call, call in enumerate(FindNodes(ir.CallStatement).visit(driver.body)):
assert call.check_kwarguments_order() == kwargs_in_order[i_call]
call.sort_kwarguments()

# check calls with sorted kwargs
for i_call, call in enumerate(FindNodes(CallStatement).visit(driver.body)):
for i_call, call in enumerate(FindNodes(ir.CallStatement).visit(driver.body)):
assert tuple(arg[1].name for arg in call.arg_iter()) == call_args
assert len(call.kwarguments) == len_kwargs[i_call]

# kwarg to arg conversion
for call in FindNodes(CallStatement).visit(driver.body):
for call in FindNodes(ir.CallStatement).visit(driver.body):
call.convert_kwargs_to_args()

# check calls with kwargs converted to args
for call in FindNodes(CallStatement).visit(driver.body):
for call in FindNodes(ir.CallStatement).visit(driver.body):
assert tuple(arg.name for arg in call.arguments) == call_args
assert call.kwarguments == ()
1 change: 1 addition & 0 deletions loki/transformations/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,3 +30,4 @@
from loki.transformations.transform_region import * # noqa
from loki.transformations.pool_allocator import * # noqa
from loki.transformations.utilities import * # noqa
from loki.transformations.block_index_transformations import * # noqa
Loading
Loading