From 5c318fe2a0fd17bd9f8c983f100cf48e7a2572e2 Mon Sep 17 00:00:00 2001 From: Abhilash Majumder <30946547+abhilash1910@users.noreply.github.com> Date: Mon, 1 Jul 2024 09:27:03 +0530 Subject: [PATCH] use nd range arg, fix func name --- .../src/util_group_load_store_test.cpp | 165 +++--------------- 1 file changed, 24 insertions(+), 141 deletions(-) diff --git a/help_function/src/util_group_load_store_test.cpp b/help_function/src/util_group_load_store_test.cpp index 73bd27418..da995d771 100644 --- a/help_function/src/util_group_load_store_test.cpp +++ b/help_function/src/util_group_load_store_test.cpp @@ -29,7 +29,7 @@ bool helper_validation_function(const int *ptr, const char *func_name) { return true; } -template bool test_group_load_store() { +template bool test_group_load_store(sycl::nd_range<3> &range, char *func_name) { // Tests dpct::group::load_algorithm::BLOCK_LOAD_DIRECT // dpct::group::load_algorithm::BLOCK_LOAD_STRIPED in its entirety as API // functions @@ -49,7 +49,8 @@ template bool te sycl::local_accessor tacc(sycl::range<1>(temp_storage_size), h); sycl::accessor data_accessor_read_write(buffer, h, sycl::read_write); h.parallel_for( - sycl::nd_range<3>(sycl::range<3>(1, 1, 128), sycl::range<3>(1, 1, 128)), + range, + //sycl::nd_range<3>(sycl::range<3>(1, 1, 128), sycl::range<3>(1, 1, 128)), [=](sycl::nd_item<3> item) { int thread_data[4]; auto *d_r_w = @@ -66,47 +67,7 @@ template bool te sycl::host_accessor data_accessor(buffer, sycl::read_write); const int *ptr = data_accessor.get_multi_ptr(); - return helper_validation_function(ptr, "test_group_load_store"); -} - -template bool test_group_load_store_multiple_wgs() { - // Tests dpct::group::load_algorithm::BLOCK_LOAD_DIRECT - // dpct::group::load_algorithm::BLOCK_LOAD_STRIPED in its entirety as API - // functions - // Tests dpct::group::store_algorithm::BLOCK_STORE_DIRECT - // dpct::group::store_algorithm::BLOCK_STORE_STRIPED in its entirety as API - // functions - sycl::queue q(dpct::get_default_queue()); - oneapi::dpl::counting_iterator count_it(0); - sycl::buffer buffer(count_it, count_it + 512); - - q.submit([&](sycl::handler &h) { - using group_load = - dpct::group::workgroup_load<4, T, int, const int *, sycl::nd_item<3>>; - using group_store = - dpct::group::workgroup_store<4, S, int, int *, sycl::nd_item<3>>; - size_t temp_storage_size = group_load::get_local_memory_size(128); - sycl::local_accessor tacc(sycl::range<1>(temp_storage_size), h); - sycl::accessor data_accessor_read_write(buffer, h, sycl::read_write); - h.parallel_for( - sycl::nd_range<3>(sycl::range<3>(2, 2, 64), sycl::range<3>(1, 1, 64)), - [=](sycl::nd_item<3> item) { - int thread_data[4]; - auto *d_r_w = - data_accessor_read_write.get_multi_ptr() - .get(); - auto *tmp = tacc.get_multi_ptr().get(); - // Load thread_data of each work item to blocked arrangement - group_load(tmp).load(item, d_r_w, thread_data); - // Store thread_data of each work item from blocked arrangement - group_store(tmp).store(item, d_r_w, thread_data); - }); - }); - q.wait_and_throw(); - - sycl::host_accessor data_accessor(buffer, sycl::read_write); - const int *ptr = data_accessor.get_multi_ptr(); - return helper_validation_function(ptr, "test_group_load_store"); + return helper_validation_function(ptr, func_name); } bool test_load_store_subgroup_striped_standalone() { @@ -153,91 +114,7 @@ bool test_load_store_subgroup_striped_standalone() { ptr, "test_subgroup_striped_standalone"); } -bool test_load_store_subgroup_striped_standalone_multiple_wgs() { - // Tests dpct::group::load_subgroup_striped as standalone method - sycl::queue q(dpct::get_default_queue()); - int data[512]; - for (int i = 0; i < 512; i++) - data[i] = i; - sycl::buffer buffer(data, 512); - sycl::buffer sg_sz_buf{sycl::range<1>(1)}; - - q.submit([&](sycl::handler &h) { - sycl::accessor dacc_read_write(buffer, h, sycl::read_write); - sycl::accessor sg_sz_dacc(sg_sz_buf, h, sycl::read_write); - h.parallel_for( - sycl::nd_range<3>(sycl::range<3>(2, 2, 64), sycl::range<3>(1, 1, 64)), - [=](sycl::nd_item<3> item) { - int thread_data[4]; - auto *d_r_w = - dacc_read_write.get_multi_ptr().get(); - auto *sg_sz_acc = - sg_sz_dacc.get_multi_ptr().get(); - size_t gid = item.get_global_linear_id(); - if (gid == 0) { - sg_sz_acc[0] = item.get_sub_group().get_local_linear_range(); - } - dpct::group::uninitialized_load_subgroup_striped<4, int>(item, d_r_w, - thread_data); - dpct::group::store_subgroup_striped<4, int>(item, d_r_w, thread_data); - //call destructor of thread type - for (size_t i = 0; i < 4; ++i) { - thread_data[i].~int(); - } - }); - }); - q.wait_and_throw(); - - sycl::host_accessor data_accessor(buffer, sycl::read_only); - const int *ptr = data_accessor.get_multi_ptr(); - sycl::host_accessor data_accessor_sg(sg_sz_buf, sycl::read_only); - const uint32_t *ptr_sg = - data_accessor_sg.get_multi_ptr(); - return helper_validation_function( - ptr, "test_subgroup_striped_standalone"); -} - -template bool test_group_load_store_standalone() { - // Tests dpct::group::load_algorithm::BLOCK_LOAD_DIRECT & - // dpct::group::load_algorithm::BLOCK_LOAD_STRIPED as standalone methods - // Tests dpct::group::store_algorithm::BLOCK_STORE_DIRECT & - // dpct::group::store_algorithm::BLOCK_STORE_STRIPED as standalone methods - sycl::queue q(dpct::get_default_queue()); - int data[512]; - for (int i = 0; i < 512; i++) - data[i] = i; - sycl::buffer buffer(data, 512); - - q.submit([&](sycl::handler &h) { - sycl::accessor dacc_read_write(buffer, h, sycl::read_write); - h.parallel_for( - sycl::nd_range<3>(sycl::range<3>(1, 1, 128), sycl::range<3>(1, 1, 128)), - [=](sycl::nd_item<3> item) { - int thread_data[4]; - auto *d_r_w = - dacc_read_write.get_multi_ptr().get(); - // Load thread_data of each work item to blocked arrangement - if (T == dpct::group::load_algorithm::BLOCK_LOAD_DIRECT) { - dpct::group::load_blocked<4, int>(item, d_r_w, thread_data); - } else { - dpct::group::load_striped<4, int>(item, d_r_w, thread_data); - } - // Store thread_data of each work item from blocked arrangement - if (S == dpct::group::store_algorithm::BLOCK_STORE_DIRECT) { - dpct::group::store_blocked<4, int>(item, d_r_w, thread_data); - } else { - dpct::group::store_striped<4, int>(item, d_r_w, thread_data); - } - }); - }); - q.wait_and_throw(); - - sycl::host_accessor data_accessor(buffer, sycl::read_only); - const int *ptr = data_accessor.get_multi_ptr(); - return helper_validation_function(ptr, "test_group_load_store"); -} - -template bool test_group_load_store_standalone_multi_wgs() { +template bool test_group_load_store_standalone(sycl::nd_range<3> & range, char *func_name) { // Tests dpct::group::load_algorithm::BLOCK_LOAD_DIRECT & // dpct::group::load_algorithm::BLOCK_LOAD_STRIPED as standalone methods // Tests dpct::group::store_algorithm::BLOCK_STORE_DIRECT & @@ -251,7 +128,8 @@ template bool te q.submit([&](sycl::handler &h) { sycl::accessor dacc_read_write(buffer, h, sycl::read_write); h.parallel_for( - sycl::nd_range<3>(sycl::range<3>(2, 2, 64), sycl::range<3>(1, 1, 64)), + range, + //sycl::nd_range<3>(sycl::range<3>(1, 1, 128), sycl::range<3>(1, 1, 128)), [=](sycl::nd_item<3> item) { int thread_data[4]; auto *d_r_w = @@ -274,34 +152,39 @@ template bool te sycl::host_accessor data_accessor(buffer, sycl::read_only); const int *ptr = data_accessor.get_multi_ptr(); - return helper_validation_function(ptr, "test_group_load_store"); + return helper_validation_function(ptr, func_name); } - int main() { + sycl::range<3> global_range{1, 1, 128}; + sycl::range<3> local_range{1, 1, 128}; + sycl::nd_range<3> range{global_range, local_range}; + sycl::range<3> global_range_multi{2, 2, 64}; + sycl::range<3> local_range_multi{1, 1, 64}; + sycl::nd_range<3> range_multi{global_range_multi, local_range_multi}; return !( // Calls test_group_load with blocked and striped strategies , should pass // both results. - test_group_load_store() && - test_group_load_store() && + test_group_load_store(range, "test_group_load_store") && + test_group_load_store(range, "test_group_load_store") && // Calls test_load_subgroup_striped_standalone and should pass test_load_store_subgroup_striped_standalone() && // Calls test_group_load_standalone with blocked and striped strategies as // free functions, should pass both results. test_group_load_store_standalone< - dpct::group::load_algorithm::BLOCK_LOAD_STRIPED, dpct::group::store_algorithm::BLOCK_STORE_STRIPED>() && + dpct::group::load_algorithm::BLOCK_LOAD_STRIPED, dpct::group::store_algorithm::BLOCK_STORE_STRIPED>(range, "test_group_load_store_standalone") && test_group_load_store_standalone< - dpct::group::load_algorithm::BLOCK_LOAD_DIRECT, dpct::group::store_algorithm::BLOCK_STORE_DIRECT>()) && + dpct::group::load_algorithm::BLOCK_LOAD_DIRECT, dpct::group::store_algorithm::BLOCK_STORE_DIRECT>(range, "test_group_load_store_standalone") && - test_group_load_store_multiple_wgs() && - test_group_load_store_multiple_wgs() && + test_group_load_store(range_multi, "test_group_load_store_multiple_wgs") && + test_group_load_store_multiple_wgs(range_multi, "test_group_load_store_multiple_wgs") && // Calls test_load_subgroup_striped_standalone and should pass test_load_store_subgroup_striped_standalone_multiple_wgs() && // Calls test_group_load_standalone with blocked and striped strategies as // free functions, should pass both results. - test_group_load_store_standalone_multiple_wgs< - dpct::group::load_algorithm::BLOCK_LOAD_STRIPED, dpct::group::store_algorithm::BLOCK_STORE_STRIPED>() && - test_group_load_store_standalone_multiple_wgs< - dpct::group::load_algorithm::BLOCK_LOAD_DIRECT, dpct::group::store_algorithm::BLOCK_STORE_DIRECT>()); + test_group_load_store_standalone< + dpct::group::load_algorithm::BLOCK_LOAD_STRIPED, dpct::group::store_algorithm::BLOCK_STORE_STRIPED>(range_multi, "test_group_load_store_standalone_multiple_wgs") && + test_group_load_store_standalone< + dpct::group::load_algorithm::BLOCK_LOAD_DIRECT, dpct::group::store_algorithm::BLOCK_STORE_DIRECT>(range_multi, "test_group_load_store_standalone_multiple_wgs")); }