Skip to content

Commit

Permalink
Merge pull request #386 from ecmwf-ifs/nams-resolve-vector-notation-a…
Browse files Browse the repository at this point in the history
…vailable-loops

Extend 'resolve_vector_notation' to look for available and appropriate loops
  • Loading branch information
reuterbal authored Oct 7, 2024
2 parents a4b1df9 + 924ae71 commit 0c5750b
Show file tree
Hide file tree
Showing 2 changed files with 132 additions and 3 deletions.
18 changes: 15 additions & 3 deletions loki/transformations/array_indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,10 @@ def resolve_vector_notation(routine):
loop_map = {}
index_vars = set()
vmap = {}
# find available loops and create map {(lower, upper, step): loop_variable}
loops = FindNodes(Loop).visit(routine.body)
loop_map = {(loop.bounds.lower, loop.bounds.upper, loop.bounds.step or 1):
loop.variable for loop in loops}
for stmt in FindNodes(Assignment).visit(routine.body):
# Loop over all variables and replace them with loop indices
vdims = []
Expand All @@ -173,9 +177,17 @@ def resolve_vector_notation(routine):
ivar_basename = f'i_{stmt.lhs.basename}'
for i, dim, s in zip(count(), v.dimensions, as_tuple(v.shape)):
if isinstance(dim, sym.RangeIndex):
# Create new index variable
vtype = SymbolAttributes(BasicType.INTEGER)
ivar = sym.Variable(name=f'{ivar_basename}_{i}', type=vtype, scope=routine)
# create tuple to test whether an appropriate loop is already available
test_range = (sym.IntLiteral(1), s, 1) if not isinstance(s, sym.RangeIndex)\
else (s.lower, s.upper, 1)
# actually test for it
if test_range in loop_map:
# Use index variable of available matching loop
ivar = loop_map[test_range]
else:
# Create new index variable
vtype = SymbolAttributes(BasicType.INTEGER)
ivar = sym.Variable(name=f'{ivar_basename}_{i}', type=vtype, scope=routine)
shape_index_map[(i, s)] = ivar
index_range_map[ivar] = s

Expand Down
117 changes: 117 additions & 0 deletions loki/transformations/tests/test_array_indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -1062,6 +1062,123 @@ def test_transform_promote_resolve_vector_notation(tmp_path, frontend):
assert np.all(ret1 == 11)
assert np.all(ret2 == 42)


@pytest.mark.parametrize('frontend', available_frontends())
def test_transform_resolve_vector_notation_common_loops(tmp_path, frontend):
"""
Apply and test resolve vector notation utility with already
available/appropriate loops.
"""
fcode = """
subroutine transform_resolve_vector_notation_common_loops(scalar, vector, matrix, n, m, l)
implicit none
integer, intent(in) :: n, m, l
integer, intent(inout) :: scalar, vector(n), matrix(l, n)
integer :: tmp_scalar, tmp_vector(n, m), tmp_matrix(l, m, n), tmp_dummy(n, 0:4)
integer :: jl, jk, jm
tmp_dummy(:,:) = 0
tmp_vector(:, 1) = tmp_dummy(:, 1)
tmp_vector(:, :) = 0
tmp_matrix(:, :, :) = 0
matrix(:, :) = 0
do jl=1,n
do jm=1,m
tmp_vector(jl, jm) = scalar + jl
end do
end do
do jm=1,m
do jl=1,n
scalar = jl
vector(jl) = tmp_vector(jl, jm) + tmp_vector(jl, jm)
do jk=1,l
tmp_matrix(jk, jm, jl) = vector(jl) + jk
end do
end do
end do
do jk=1,l
matrix(jk, :) = 0
do jm=1,m
do jl=1,n
matrix(jk, jl) = tmp_matrix(jk, jm, jl)
end do
end do
end do
end subroutine transform_resolve_vector_notation_common_loops
""".strip()
routine = Subroutine.from_source(fcode, frontend=frontend)
# Test the original implementation
filepath = tmp_path/(f'{routine.name}_{frontend}.f90')
function = jit_compile(routine, filepath=filepath, objname=routine.name)

n = 3
m = 2
l = 3
scalar = np.zeros(shape=(1,), order='F', dtype=np.int32)
vector = np.zeros(shape=(n,), order='F', dtype=np.int32)
matrix = np.zeros(shape=(n, n), order='F', dtype=np.int32)
function(scalar, vector, matrix, n, m, l)

assert all(scalar == 3)
assert np.all(vector == np.arange(1, n + 1)*2)
assert np.all(matrix == np.sum(np.mgrid[1:4,2:8:2], axis=0))

resolve_vector_notation(routine)

loops = FindNodes(Loop).visit(routine.body)
arrays = [var for var in FindVariables(unique=False).visit(routine.body) if isinstance(var, sym.Array)]

assert len(loops) == 19
assert loops[0].variable == 'i_tmp_dummy_1' and loops[0].bounds.children == (0, 4, None)
assert loops[1].variable == 'jl' and loops[1].bounds.children == (1, 'n', 1)
assert loops[2].variable == 'jl' and loops[2].bounds.children == (1, 'n', 1)
assert loops[3].variable == 'jm' and loops[3].bounds.children == (1, 'm', 1)
assert loops[4].variable == 'jl' and loops[4].bounds.children == (1, 'n', 1)
assert loops[5].variable == 'jl' and loops[5].bounds.children == (1, 'n', 1)
assert loops[6].variable == 'jm' and loops[6].bounds.children == (1, 'm', 1)
assert loops[7].variable == 'jk' and loops[7].bounds.children == (1, 'l', 1)
assert loops[8].variable == 'jl' and loops[8].bounds.children == (1, 'n', 1)
assert loops[9].variable == 'jk' and loops[9].bounds.children == (1, 'l', 1)
assert loops[10].variable == 'jl' and loops[10].bounds.children == (1, 'n', None)
assert loops[11].variable == 'jm' and loops[11].bounds.children == (1, 'm', None)
assert loops[12].variable == 'jm' and loops[12].bounds.children == (1, 'm', None)
assert loops[13].variable == 'jl' and loops[13].bounds.children == (1, 'n', None)
assert loops[14].variable == 'jk' and loops[14].bounds.children == (1, 'l', None)
assert loops[15].variable == 'jk' and loops[15].bounds.children == (1, 'l', None)
assert loops[16].variable == 'jl' and loops[16].bounds.children == (1, 'n', 1)
assert loops[17].variable == 'jm' and loops[17].bounds.children == (1, 'm', None)
assert loops[18].variable == 'jl' and loops[18].bounds.children == (1, 'n', None)

assert len(arrays) == 15
assert arrays[0].name.lower() == 'tmp_dummy' and arrays[0].dimensions == ('jl', 'i_tmp_dummy_1')
assert arrays[1].name.lower() == 'tmp_vector' and arrays[1].dimensions == ('jl', 1)
assert arrays[2].name.lower() == 'tmp_dummy' and arrays[2].dimensions == ('jl', 1)
assert arrays[3].name.lower() == 'tmp_vector' and arrays[3].dimensions == ('jl', 'jm')
assert arrays[4].name.lower() == 'tmp_matrix' and arrays[4].dimensions == ('jk', 'jm', 'jl')

# Test promoted routine
resolved_filepath = tmp_path/(f'{routine.name}_resolved_{frontend}.f90')
resolved_function = jit_compile(routine, filepath=resolved_filepath, objname=routine.name)

n = 3
m = 2
l = 3
scalar = np.zeros(shape=(1,), order='F', dtype=np.int32)
vector = np.zeros(shape=(n,), order='F', dtype=np.int32)
matrix = np.zeros(shape=(n, n), order='F', dtype=np.int32)
resolved_function(scalar, vector, matrix, n, m, l)

assert all(scalar == 3)
assert np.all(vector == np.arange(1, n + 1)*2)
assert np.all(matrix == np.sum(np.mgrid[1:4,2:8:2], axis=0))


@pytest.mark.parametrize('frontend', available_frontends())
@pytest.mark.parametrize('calls_only', (False, True))
def test_transform_explicit_dimensions(tmp_path, frontend, builder, calls_only):
Expand Down

0 comments on commit 0c5750b

Please sign in to comment.