diff --git a/transformations/tests/test_pool_allocator.py b/transformations/tests/test_pool_allocator.py index 77f486684..ec86132ad 100644 --- a/transformations/tests/test_pool_allocator.py +++ b/transformations/tests/test_pool_allocator.py @@ -71,7 +71,7 @@ def check_stack_created_in_driver( assert len(loops) == num_block_loops assignments = FindNodes(Assignment).visit(loops[0].body) assert assignments[0].lhs == 'ylstack_l' - if cray_ptr_loc_rhs: # generate_driver_stack: + if cray_ptr_loc_rhs: assert assignments[0].rhs == '1' else: assert isinstance(assignments[0].rhs, InlineCall) and assignments[0].rhs.function == 'loc' @@ -91,12 +91,11 @@ def check_stack_created_in_driver( else: assert assignments[1].lhs == 'ylstack_u' and ( assignments[1].rhs == f'ylstack_l + max(c_sizeof(real(1, kind={kind_real})), 8)*istsz') - # expected_rhs = f'ylstack_l + max(c_sizeof(real(1, kind={kind_real})), 8)*istsz' + if cray_ptr_loc_rhs: expected_rhs = 'ylstack_l + istsz' else: expected_rhs = f'ylstack_l + max(c_sizeof(real(1, kind={kind_real})), 8)*istsz' - # expected_rhs = remove_redundant_substrings(expected_rhs, kind_real=kind_real) assert assignments[1].lhs == 'ylstack_u' and assignments[1].rhs == expected_rhs # Check that stack assignment happens before kernel call @@ -335,10 +334,10 @@ def test_pool_allocator_temporaries(frontend, generate_driver_stack, block_dim, if cray_ptr_loc_rhs: kind_real = kind_real.replace(' ', '') trafo_data_compare = trafo_data_compare.replace(f'max(c_sizeof(real(1,kind={kind_real})),8)*', '') - # if generate_driver_stack: # not generate_driver_stack: stack_size = remove_redundant_substrings(stack_size, kind_real) - # TODO: ... nice if stack_size[-2:] == "+2": + # This is a little hacky but unless we start to properly assemble the size expression + # symbolically, this is the easiest to fix the expression ordering stack_size = f"2+{stack_size[:-2]}" assert kernel_item.trafo_data[transformation._key]['stack_size'] == trafo_data_compare assert all(v.scope is None for v in @@ -347,7 +346,6 @@ def test_pool_allocator_temporaries(frontend, generate_driver_stack, block_dim, # # A few checks on the driver # - # normalize_range_indexing(scheduler['#driver'].ir) driver = scheduler['#driver'].ir # Has c_sizeof procedure been imported? check_c_sizeof_import(driver) @@ -364,12 +362,15 @@ def test_pool_allocator_temporaries(frontend, generate_driver_stack, block_dim, else: expected_kwargs = (('YDSTACK_L', 'ylstack_l'),) if cray_ptr_loc_rhs: - expected_kwargs += (('ZSTACK', 'zstack(:,b)'),) + if frontend == OMNI and not generate_driver_stack: + # If the stack exists already in the driver, that variable is used. And because + # OMNI lower-cases everything, this will result in a lower-case name for the + # argument for that particular case... + expected_kwargs += (('zstack', 'zstack(:,b)'),) + else: + expected_kwargs += (('ZSTACK', 'zstack(:,b)'),) assert calls[0].arguments == expected_args - if frontend == OMNI and cray_ptr_loc_rhs: - pass # TODO: ... WTF - else: - assert calls[0].kwarguments == expected_kwargs + assert calls[0].kwarguments == expected_kwargs if generate_driver_stack: check_stack_created_in_driver(driver, stack_size, calls[0], 1, generate_driver_stack, check_bounds=check_bounds, @@ -671,7 +672,7 @@ def test_pool_allocator_temporaries_kernel_sequence(frontend, block_dim, directi f'max(c_sizeof(real(1, kind=jprb)), 8)' if cray_ptr_loc_rhs: stack_size = 'max(3*nlon + nlon*nz + nz, 3*nlon*nz + nlon)' - # TODO: continue + check_stack_created_in_driver(driver, stack_size, calls[0], 2, cray_ptr_loc_rhs=cray_ptr_loc_rhs) # Has the data sharing been updated?