Skip to content

Commit

Permalink
Fix tests for OMNI
Browse files Browse the repository at this point in the history
  • Loading branch information
reuterbal committed Apr 9, 2024
1 parent f6e381c commit f939a8c
Showing 1 changed file with 13 additions and 12 deletions.
25 changes: 13 additions & 12 deletions transformations/tests/test_pool_allocator.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def check_stack_created_in_driver(
assert len(loops) == num_block_loops
assignments = FindNodes(Assignment).visit(loops[0].body)
assert assignments[0].lhs == 'ylstack_l'
if cray_ptr_loc_rhs: # generate_driver_stack:
if cray_ptr_loc_rhs:
assert assignments[0].rhs == '1'
else:
assert isinstance(assignments[0].rhs, InlineCall) and assignments[0].rhs.function == 'loc'
Expand All @@ -91,12 +91,11 @@ def check_stack_created_in_driver(
else:
assert assignments[1].lhs == 'ylstack_u' and (
assignments[1].rhs == f'ylstack_l + max(c_sizeof(real(1, kind={kind_real})), 8)*istsz')
# expected_rhs = f'ylstack_l + max(c_sizeof(real(1, kind={kind_real})), 8)*istsz'

if cray_ptr_loc_rhs:
expected_rhs = 'ylstack_l + istsz'
else:
expected_rhs = f'ylstack_l + max(c_sizeof(real(1, kind={kind_real})), 8)*istsz'
# expected_rhs = remove_redundant_substrings(expected_rhs, kind_real=kind_real)
assert assignments[1].lhs == 'ylstack_u' and assignments[1].rhs == expected_rhs

# Check that stack assignment happens before kernel call
Expand Down Expand Up @@ -335,10 +334,10 @@ def test_pool_allocator_temporaries(frontend, generate_driver_stack, block_dim,
if cray_ptr_loc_rhs:
kind_real = kind_real.replace(' ', '')
trafo_data_compare = trafo_data_compare.replace(f'max(c_sizeof(real(1,kind={kind_real})),8)*', '')
# if generate_driver_stack: # not generate_driver_stack:
stack_size = remove_redundant_substrings(stack_size, kind_real)
# TODO: ... nice
if stack_size[-2:] == "+2":
# This is a little hacky but unless we start to properly assemble the size expression
# symbolically, this is the easiest to fix the expression ordering
stack_size = f"2+{stack_size[:-2]}"
assert kernel_item.trafo_data[transformation._key]['stack_size'] == trafo_data_compare
assert all(v.scope is None for v in
Expand All @@ -347,7 +346,6 @@ def test_pool_allocator_temporaries(frontend, generate_driver_stack, block_dim,
#
# A few checks on the driver
#
# normalize_range_indexing(scheduler['#driver'].ir)
driver = scheduler['#driver'].ir
# Has c_sizeof procedure been imported?
check_c_sizeof_import(driver)
Expand All @@ -364,12 +362,15 @@ def test_pool_allocator_temporaries(frontend, generate_driver_stack, block_dim,
else:
expected_kwargs = (('YDSTACK_L', 'ylstack_l'),)
if cray_ptr_loc_rhs:
expected_kwargs += (('ZSTACK', 'zstack(:,b)'),)
if frontend == OMNI and not generate_driver_stack:
# If the stack exists already in the driver, that variable is used. And because
# OMNI lower-cases everything, this will result in a lower-case name for the
# argument for that particular case...
expected_kwargs += (('zstack', 'zstack(:,b)'),)
else:
expected_kwargs += (('ZSTACK', 'zstack(:,b)'),)
assert calls[0].arguments == expected_args
if frontend == OMNI and cray_ptr_loc_rhs:
pass # TODO: ... WTF
else:
assert calls[0].kwarguments == expected_kwargs
assert calls[0].kwarguments == expected_kwargs

if generate_driver_stack:
check_stack_created_in_driver(driver, stack_size, calls[0], 1, generate_driver_stack, check_bounds=check_bounds,
Expand Down Expand Up @@ -671,7 +672,7 @@ def test_pool_allocator_temporaries_kernel_sequence(frontend, block_dim, directi
f'max(c_sizeof(real(1, kind=jprb)), 8)'
if cray_ptr_loc_rhs:
stack_size = 'max(3*nlon + nlon*nz + nz, 3*nlon*nz + nlon)'
# TODO: continue

check_stack_created_in_driver(driver, stack_size, calls[0], 2, cray_ptr_loc_rhs=cray_ptr_loc_rhs)

# Has the data sharing been updated?
Expand Down

0 comments on commit f939a8c

Please sign in to comment.