Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

assumed size array handling for 'normalize_array_shape_and_access' #218

Merged
merged 3 commits into from
Jan 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 14 additions & 6 deletions loki/transform/transform_array_indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -502,15 +502,22 @@ def normalize_array_shape_and_access(routine):
"""
Shift all arrays to start counting at "1"
"""
def is_range_index(dim):
return isinstance(dim, sym.RangeIndex) and not dim.lower == 1
def is_explicit_range_index(dim):
# return False if assumed sized array or lower dimension equals to 1
# return (isinstance(dim, sym.RangeIndex) and not dim.lower == 1 and not dim is None
# and not dim.lower is None and not dim.upper is None)
return (isinstance(dim, sym.RangeIndex)
and not (dim.lower == 1 or dim.lower is None or dim.upper is None))

vmap = {}
for v in FindVariables(unique=False).visit(routine.body):
if isinstance(v, sym.Array):
# skip if e.g., `array(len)`, passed as `call routine(array)`
if not v.dimensions:
continue
new_dims = []
for i, d in enumerate(v.shape):
if isinstance(d, sym.RangeIndex):
if is_explicit_range_index(d):
if isinstance(v.dimensions[i], sym.RangeIndex):
start = simplify(v.dimensions[i].start - d.start + 1) if d.start is not None else None
stop = simplify(v.dimensions[i].stop - d.start + 1) if d.stop is not None else None
Expand All @@ -520,16 +527,17 @@ def is_range_index(dim):
new_dims += [start]
else:
new_dims += [v.dimensions[i]]
vmap[v] = v.clone(dimensions=as_tuple(new_dims))
if new_dims:
vmap[v] = v.clone(dimensions=as_tuple(new_dims))
routine.body = SubstituteExpressions(vmap).visit(routine.body)

vmap = {}
for v in routine.variables:
if isinstance(v, sym.Array):
new_dims = [sym.RangeIndex((1, simplify(d.upper - d.lower + 1)))
if is_range_index(d) else d for d in v.dimensions]
if is_explicit_range_index(d) else d for d in v.dimensions]
new_shape = [sym.RangeIndex((1, simplify(d.upper - d.lower + 1)))
if is_range_index(d) else d for d in v.shape]
if is_explicit_range_index(d) else d for d in v.shape]
new_type = v.type.clone(shape=as_tuple(new_shape))
vmap[v] = v.clone(dimensions=as_tuple(new_dims), type=new_type)
routine.variables = [vmap.get(v, v) for v in routine.variables]
Expand Down
79 changes: 57 additions & 22 deletions tests/test_transform_array_indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import numpy as np

from conftest import jit_compile, jit_compile_lib, clean_test, available_frontends
from loki import Subroutine, FindVariables, Array
from loki import Module, Subroutine, FindVariables, Array
from loki.expression import symbols as sym
from loki.transform import (
promote_variables, demote_variables, normalize_range_indexing,
Expand Down Expand Up @@ -330,12 +330,19 @@ def test_transform_demote_dimension_arguments(here, frontend):
@pytest.mark.parametrize('start_index', (0, 1, 5))
def test_transform_normalize_array_shape_and_access(here, frontend, start_index):
"""
Test flattening or arrays, meaning converting multi-dimensional
arrays to one-dimensional arrays including corresponding
index arithmetic.
Test normalization of array shape and access, thus changing arrays with start
index different than "1" to have start index "1".

E.g., ``x1(5:len)`` -> ```x1(1:len-4)``
"""
fcode = f"""
subroutine transform_normalize_array_shape_and_access(x1, x2, x3, x4, l1, l2, l3, l4)
module transform_normalize_array_shape_and_access_mod
implicit none

contains

subroutine transform_normalize_array_shape_and_access(x1, x2, x3, x4, assumed_x1, l1, l2, l3, l4)
! use nested_routine_mod, only : nested_routine
implicit none
integer :: i1, i2, i3, i4, c1, c2, c3, c4
integer, intent(in) :: l1, l2, l3, l4
Expand All @@ -347,10 +354,15 @@ def test_transform_normalize_array_shape_and_access(here, frontend, start_index)
integer, intent(inout) :: x4({start_index}:l4+{start_index}-1, &
& {start_index}:l3+{start_index}-1, {start_index}:l2+{start_index}-1, &
& {start_index}:l1+{start_index}-1)
integer, intent(inout) :: assumed_x1(l1)
c1 = 1
c2 = 1
c3 = 1
c4 = 1
do i1=1,l1
assumed_x1(i1) = c1
call nested_routine(assumed_x1, l1, c1)
end do
x1({start_index}:l4+{start_index}-1) = 0
do i1={start_index},l1+{start_index}-1
x1(i1) = c1
Expand All @@ -368,15 +380,28 @@ def test_transform_normalize_array_shape_and_access(here, frontend, start_index)
end do
c1 = c1 + 1
end do

end subroutine transform_normalize_array_shape_and_access

subroutine nested_routine(nested_x1, l1, c1)
implicit none
integer, intent(in) :: l1, c1
integer, intent(inout) :: nested_x1(:)
integer :: i1
do i1=1,l1
nested_x1(i1) = c1
end do
end subroutine nested_routine

end module transform_normalize_array_shape_and_access_mod
"""

def init_arguments(l1, l2, l3, l4):
x1 = np.zeros(shape=(l1,), order='F', dtype=np.int32)
assumed_x1 = np.zeros(shape=(l1,), order='F', dtype=np.int32)
x2 = np.zeros(shape=(l2,l1,), order='F', dtype=np.int32)
x3 = np.zeros(shape=(l3,l2,l1,), order='F', dtype=np.int32)
x4 = np.zeros(shape=(l4,l3,l2,l1,), order='F', dtype=np.int32)
return x1, x2, x3, x4
return x1, x2, x3, x4, assumed_x1

def validate_routine(routine):
arrays = [var for var in FindVariables().visit(routine.body) if isinstance(var, Array)]
Expand All @@ -387,30 +412,40 @@ def validate_routine(routine):
l2 = 3
l3 = 4
l4 = 5
routine = Subroutine.from_source(fcode, frontend=frontend)
normalize_range_indexing(routine) # Fix OMNI nonsense
filepath = here/(f'{routine.name}_{frontend}.f90')
function = jit_compile(routine, filepath=filepath, objname=routine.name)
orig_x1, orig_x2, orig_x3, orig_x4 = init_arguments(l1, l2, l3, l4)
function(orig_x1, orig_x2, orig_x3, orig_x4, l1, l2, l3, l4)
module = Module.from_source(fcode, frontend=frontend)
for routine in module.routines:
normalize_range_indexing(routine) # Fix OMNI nonsense
filepath = here/(f'transform_normalize_array_shape_and_access_{frontend}.f90')
# compile and test "original" module/function
mod = jit_compile(module, filepath=filepath, objname='transform_normalize_array_shape_and_access_mod')
function = getattr(mod, 'transform_normalize_array_shape_and_access')
orig_x1, orig_x2, orig_x3, orig_x4, orig_assumed_x1 = init_arguments(l1, l2, l3, l4)
function(orig_x1, orig_x2, orig_x3, orig_x4, orig_assumed_x1, l1, l2, l3, l4)
clean_test(filepath)

routine = Subroutine.from_source(fcode, frontend=frontend)
normalize_array_shape_and_access(routine)
normalize_range_indexing(routine)
filepath = here/(f'{routine.name}_normalized_{frontend}.f90')
function = jit_compile(routine, filepath=filepath, objname=routine.name)
x1, x2, x3, x4 = init_arguments(l1, l2, l3, l4)
function(x1, x2, x3, x4, l1, l2, l3, l4)
validate_routine(routine)
# apply `normalize_array_shape_and_access`
for routine in module.routines:
normalize_array_shape_and_access(routine)

filepath = here/(f'transform_normalize_array_shape_and_access_normalized_{frontend}.f90')
# compile and test "normalized" module/function
mod = jit_compile(module, filepath=filepath, objname='transform_normalize_array_shape_and_access_mod')
function = getattr(mod, 'transform_normalize_array_shape_and_access')
x1, x2, x3, x4, assumed_x1 = init_arguments(l1, l2, l3, l4)
function(x1, x2, x3, x4, assumed_x1, l1, l2, l3, l4)
clean_test(filepath)
# validate the routine "transform_normalize_array_shape_and_access"
validate_routine(module.subroutines[0])
# validate the nested routine to see whether the assumed size array got correctly handled
assert module.subroutines[1].variable_map['nested_x1'] == 'nested_x1(:)'

# check whether results generated by the "original" and "normalized" version agree
assert (x1 == orig_x1).all()
assert (assumed_x1 == orig_assumed_x1).all()
assert (x2 == orig_x2).all()
assert (x3 == orig_x3).all()
assert (x4 == orig_x4).all()


@pytest.mark.parametrize('frontend', available_frontends())
@pytest.mark.parametrize('start_index', (0, 1, 5))
def test_transform_flatten_arrays(here, frontend, builder, start_index):
Expand Down
Loading