Skip to content

Commit

Permalink
SCCBase: get_integer_variable now also checks module imports
Browse files Browse the repository at this point in the history
  • Loading branch information
awnawab committed Apr 11, 2024
1 parent a594b44 commit 47314ba
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 6 deletions.
17 changes: 14 additions & 3 deletions transformations/tests/test_single_column_coalesced.py
Original file line number Diff line number Diff line change
Expand Up @@ -706,9 +706,16 @@ def test_single_column_coalesced_hoist_openacc(frontend, horizontal, vertical, b
when hoisting array temporaries to driver.
"""

fcode_mod = """
MODULE BLOCK_DIM_MOD
INTEGER :: nb
END MODULE BLOCK_DIM_MOD
"""

fcode_driver = """
SUBROUTINE column_driver(nlon, nz, q, nb)
INTEGER, INTENT(IN) :: nlon, nz, nb ! Size of the horizontal and vertical
SUBROUTINE column_driver(nlon, nz, q)
USE BLOCK_DIM_MOD, ONLY : nb
INTEGER, INTENT(IN) :: nlon, nz ! Size of the horizontal and vertical
REAL, INTENT(INOUT) :: q(nlon,nz,nb)
INTEGER :: b, start, end
Expand Down Expand Up @@ -760,8 +767,9 @@ def test_single_column_coalesced_hoist_openacc(frontend, horizontal, vertical, b
""".strip()

# Mimic the scheduler internal mechanis to apply the transformation cascade
mod_source = Sourcefile.from_source(fcode_mod, frontend=frontend)
kernel_source = Sourcefile.from_source(fcode_kernel, frontend=frontend)
driver_source = Sourcefile.from_source(fcode_driver, frontend=frontend)
driver_source = Sourcefile.from_source(fcode_driver, frontend=frontend, definitions=mod_source.modules)
module_source = Sourcefile.from_source(fcode_module, frontend=frontend)
driver = driver_source['column_driver']
kernel = kernel_source['compute_column']
Expand All @@ -784,6 +792,9 @@ def test_single_column_coalesced_hoist_openacc(frontend, horizontal, vertical, b
driver, role='driver', item=driver_item, successors=(kernel_item,), targets=['compute_column']
)

# Check that blocking size has not been redefined
assert driver.symbol_map[blocking.size].type.module.name.lower() == 'block_dim_mod'

with pragmas_attached(kernel, Loop):
# Ensure kernel routine is anntoated at vector level
kernel_pragmas = FindNodes(Pragma).visit(kernel.ir)
Expand Down
4 changes: 1 addition & 3 deletions transformations/transformations/single_column_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,9 +110,7 @@ def get_integer_variable(cls, routine, name):
name : string
Name of the variable to find the in the routine.
"""
if name in routine.variable_map:
v_index = routine.variable_map[name]
else:
if not (v_index := routine.symbol_map.get(name, None)):
dtype = SymbolAttributes(BasicType.INTEGER)
v_index = sym.Variable(name=name, type=dtype, scope=routine)
return v_index
Expand Down

0 comments on commit 47314ba

Please sign in to comment.