diff --git a/transformations/tests/test_single_column_coalesced.py b/transformations/tests/test_single_column_coalesced.py index 7a2def8ca..c431a3805 100644 --- a/transformations/tests/test_single_column_coalesced.py +++ b/transformations/tests/test_single_column_coalesced.py @@ -706,9 +706,16 @@ def test_single_column_coalesced_hoist_openacc(frontend, horizontal, vertical, b when hoisting array temporaries to driver. """ + fcode_mod = """ + MODULE BLOCK_DIM_MOD + INTEGER :: nb + END MODULE BLOCK_DIM_MOD + """ + fcode_driver = """ - SUBROUTINE column_driver(nlon, nz, q, nb) - INTEGER, INTENT(IN) :: nlon, nz, nb ! Size of the horizontal and vertical + SUBROUTINE column_driver(nlon, nz, q) + USE BLOCK_DIM_MOD, ONLY : nb + INTEGER, INTENT(IN) :: nlon, nz ! Size of the horizontal and vertical REAL, INTENT(INOUT) :: q(nlon,nz,nb) INTEGER :: b, start, end @@ -760,8 +767,9 @@ def test_single_column_coalesced_hoist_openacc(frontend, horizontal, vertical, b """.strip() # Mimic the scheduler internal mechanis to apply the transformation cascade + mod_source = Sourcefile.from_source(fcode_mod, frontend=frontend) kernel_source = Sourcefile.from_source(fcode_kernel, frontend=frontend) - driver_source = Sourcefile.from_source(fcode_driver, frontend=frontend) + driver_source = Sourcefile.from_source(fcode_driver, frontend=frontend, definitions=mod_source.modules) module_source = Sourcefile.from_source(fcode_module, frontend=frontend) driver = driver_source['column_driver'] kernel = kernel_source['compute_column'] @@ -784,6 +792,9 @@ def test_single_column_coalesced_hoist_openacc(frontend, horizontal, vertical, b driver, role='driver', item=driver_item, successors=(kernel_item,), targets=['compute_column'] ) + # Check that blocking size has not been redefined + assert driver.symbol_map[blocking.size].type.module.name.lower() == 'block_dim_mod' + with pragmas_attached(kernel, Loop): # Ensure kernel routine is anntoated at vector level kernel_pragmas = FindNodes(Pragma).visit(kernel.ir) diff --git a/transformations/transformations/single_column_base.py b/transformations/transformations/single_column_base.py index 3dbe50a90..795d7d1f4 100644 --- a/transformations/transformations/single_column_base.py +++ b/transformations/transformations/single_column_base.py @@ -110,9 +110,7 @@ def get_integer_variable(cls, routine, name): name : string Name of the variable to find the in the routine. """ - if name in routine.variable_map: - v_index = routine.variable_map[name] - else: + if not (v_index := routine.symbol_map.get(name, None)): dtype = SymbolAttributes(BasicType.INTEGER) v_index = sym.Variable(name=name, type=dtype, scope=routine) return v_index