diff --git a/loki/transformations/array_indexing.py b/loki/transformations/array_indexing.py index cca64b128..cdbfb4267 100644 --- a/loki/transformations/array_indexing.py +++ b/loki/transformations/array_indexing.py @@ -157,6 +157,10 @@ def resolve_vector_notation(routine): loop_map = {} index_vars = set() vmap = {} + # find available loops and create map {(lower, upper, step): loop_variable} + loops = FindNodes(Loop).visit(routine.body) + loop_map = {(loop.bounds.lower, loop.bounds.upper, loop.bounds.step or 1): + loop.variable for loop in loops} for stmt in FindNodes(Assignment).visit(routine.body): # Loop over all variables and replace them with loop indices vdims = [] @@ -173,9 +177,17 @@ def resolve_vector_notation(routine): ivar_basename = f'i_{stmt.lhs.basename}' for i, dim, s in zip(count(), v.dimensions, as_tuple(v.shape)): if isinstance(dim, sym.RangeIndex): - # Create new index variable - vtype = SymbolAttributes(BasicType.INTEGER) - ivar = sym.Variable(name=f'{ivar_basename}_{i}', type=vtype, scope=routine) + # create tuple to test whether an appropriate loop is already available + test_range = (sym.IntLiteral(1), s, 1) if not isinstance(s, sym.RangeIndex)\ + else (s.lower, s.upper, 1) + # actually test for it + if test_range in loop_map: + # Use index variable of available matching loop + ivar = loop_map[test_range] + else: + # Create new index variable + vtype = SymbolAttributes(BasicType.INTEGER) + ivar = sym.Variable(name=f'{ivar_basename}_{i}', type=vtype, scope=routine) shape_index_map[(i, s)] = ivar index_range_map[ivar] = s diff --git a/loki/transformations/tests/test_array_indexing.py b/loki/transformations/tests/test_array_indexing.py index 224ce64c9..4cf073b85 100644 --- a/loki/transformations/tests/test_array_indexing.py +++ b/loki/transformations/tests/test_array_indexing.py @@ -1062,6 +1062,123 @@ def test_transform_promote_resolve_vector_notation(tmp_path, frontend): assert np.all(ret1 == 11) assert np.all(ret2 == 42) + +@pytest.mark.parametrize('frontend', available_frontends()) +def test_transform_resolve_vector_notation_common_loops(tmp_path, frontend): + """ + Apply and test resolve vector notation utility with already + available/appropriate loops. + """ + fcode = """ +subroutine transform_resolve_vector_notation_common_loops(scalar, vector, matrix, n, m, l) + implicit none + integer, intent(in) :: n, m, l + integer, intent(inout) :: scalar, vector(n), matrix(l, n) + integer :: tmp_scalar, tmp_vector(n, m), tmp_matrix(l, m, n), tmp_dummy(n, 0:4) + integer :: jl, jk, jm + + tmp_dummy(:,:) = 0 + tmp_vector(:, 1) = tmp_dummy(:, 1) + tmp_vector(:, :) = 0 + tmp_matrix(:, :, :) = 0 + matrix(:, :) = 0 + + do jl=1,n + do jm=1,m + tmp_vector(jl, jm) = scalar + jl + end do + end do + + do jm=1,m + do jl=1,n + scalar = jl + vector(jl) = tmp_vector(jl, jm) + tmp_vector(jl, jm) + + do jk=1,l + tmp_matrix(jk, jm, jl) = vector(jl) + jk + end do + end do + end do + + + do jk=1,l + matrix(jk, :) = 0 + do jm=1,m + do jl=1,n + matrix(jk, jl) = tmp_matrix(jk, jm, jl) + end do + end do + end do + +end subroutine transform_resolve_vector_notation_common_loops + """.strip() + routine = Subroutine.from_source(fcode, frontend=frontend) + # Test the original implementation + filepath = tmp_path/(f'{routine.name}_{frontend}.f90') + function = jit_compile(routine, filepath=filepath, objname=routine.name) + + n = 3 + m = 2 + l = 3 + scalar = np.zeros(shape=(1,), order='F', dtype=np.int32) + vector = np.zeros(shape=(n,), order='F', dtype=np.int32) + matrix = np.zeros(shape=(n, n), order='F', dtype=np.int32) + function(scalar, vector, matrix, n, m, l) + + assert all(scalar == 3) + assert np.all(vector == np.arange(1, n + 1)*2) + assert np.all(matrix == np.sum(np.mgrid[1:4,2:8:2], axis=0)) + + resolve_vector_notation(routine) + + loops = FindNodes(Loop).visit(routine.body) + arrays = [var for var in FindVariables(unique=False).visit(routine.body) if isinstance(var, sym.Array)] + + assert len(loops) == 19 + assert loops[0].variable == 'i_tmp_dummy_1' and loops[0].bounds.children == (0, 4, None) + assert loops[1].variable == 'jl' and loops[1].bounds.children == (1, 'n', 1) + assert loops[2].variable == 'jl' and loops[2].bounds.children == (1, 'n', 1) + assert loops[3].variable == 'jm' and loops[3].bounds.children == (1, 'm', 1) + assert loops[4].variable == 'jl' and loops[4].bounds.children == (1, 'n', 1) + assert loops[5].variable == 'jl' and loops[5].bounds.children == (1, 'n', 1) + assert loops[6].variable == 'jm' and loops[6].bounds.children == (1, 'm', 1) + assert loops[7].variable == 'jk' and loops[7].bounds.children == (1, 'l', 1) + assert loops[8].variable == 'jl' and loops[8].bounds.children == (1, 'n', 1) + assert loops[9].variable == 'jk' and loops[9].bounds.children == (1, 'l', 1) + assert loops[10].variable == 'jl' and loops[10].bounds.children == (1, 'n', None) + assert loops[11].variable == 'jm' and loops[11].bounds.children == (1, 'm', None) + assert loops[12].variable == 'jm' and loops[12].bounds.children == (1, 'm', None) + assert loops[13].variable == 'jl' and loops[13].bounds.children == (1, 'n', None) + assert loops[14].variable == 'jk' and loops[14].bounds.children == (1, 'l', None) + assert loops[15].variable == 'jk' and loops[15].bounds.children == (1, 'l', None) + assert loops[16].variable == 'jl' and loops[16].bounds.children == (1, 'n', 1) + assert loops[17].variable == 'jm' and loops[17].bounds.children == (1, 'm', None) + assert loops[18].variable == 'jl' and loops[18].bounds.children == (1, 'n', None) + + assert len(arrays) == 15 + assert arrays[0].name.lower() == 'tmp_dummy' and arrays[0].dimensions == ('jl', 'i_tmp_dummy_1') + assert arrays[1].name.lower() == 'tmp_vector' and arrays[1].dimensions == ('jl', 1) + assert arrays[2].name.lower() == 'tmp_dummy' and arrays[2].dimensions == ('jl', 1) + assert arrays[3].name.lower() == 'tmp_vector' and arrays[3].dimensions == ('jl', 'jm') + assert arrays[4].name.lower() == 'tmp_matrix' and arrays[4].dimensions == ('jk', 'jm', 'jl') + + # Test promoted routine + resolved_filepath = tmp_path/(f'{routine.name}_resolved_{frontend}.f90') + resolved_function = jit_compile(routine, filepath=resolved_filepath, objname=routine.name) + + n = 3 + m = 2 + l = 3 + scalar = np.zeros(shape=(1,), order='F', dtype=np.int32) + vector = np.zeros(shape=(n,), order='F', dtype=np.int32) + matrix = np.zeros(shape=(n, n), order='F', dtype=np.int32) + resolved_function(scalar, vector, matrix, n, m, l) + + assert all(scalar == 3) + assert np.all(vector == np.arange(1, n + 1)*2) + assert np.all(matrix == np.sum(np.mgrid[1:4,2:8:2], axis=0)) + + @pytest.mark.parametrize('frontend', available_frontends()) @pytest.mark.parametrize('calls_only', (False, True)) def test_transform_explicit_dimensions(tmp_path, frontend, builder, calls_only):