Skip to content

Commit

Permalink
use nd range arg, fix func name
Browse files Browse the repository at this point in the history
  • Loading branch information
abhilash1910 authored Jul 1, 2024
1 parent 099fe54 commit 5c318fe
Showing 1 changed file with 24 additions and 141 deletions.
165 changes: 24 additions & 141 deletions help_function/src/util_group_load_store_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ bool helper_validation_function(const int *ptr, const char *func_name) {
return true;
}

template <dpct::group::load_algorithm T, dpct::group::store_algorithm S> bool test_group_load_store() {
template <dpct::group::load_algorithm T, dpct::group::store_algorithm S> bool test_group_load_store(sycl::nd_range<3> &range, char *func_name) {
// Tests dpct::group::load_algorithm::BLOCK_LOAD_DIRECT
// dpct::group::load_algorithm::BLOCK_LOAD_STRIPED in its entirety as API
// functions
Expand All @@ -49,7 +49,8 @@ template <dpct::group::load_algorithm T, dpct::group::store_algorithm S> bool te
sycl::local_accessor<uint8_t, 1> tacc(sycl::range<1>(temp_storage_size), h);
sycl::accessor data_accessor_read_write(buffer, h, sycl::read_write);
h.parallel_for(
sycl::nd_range<3>(sycl::range<3>(1, 1, 128), sycl::range<3>(1, 1, 128)),
range,
//sycl::nd_range<3>(sycl::range<3>(1, 1, 128), sycl::range<3>(1, 1, 128)),
[=](sycl::nd_item<3> item) {
int thread_data[4];
auto *d_r_w =
Expand All @@ -66,47 +67,7 @@ template <dpct::group::load_algorithm T, dpct::group::store_algorithm S> bool te

sycl::host_accessor data_accessor(buffer, sycl::read_write);
const int *ptr = data_accessor.get_multi_ptr<sycl::access::decorated::yes>();
return helper_validation_function(ptr, "test_group_load_store");
}

// Multi-work-group variant of the workgroup_load / workgroup_store test.
//
// T selects the dpct::group::load_algorithm and S the
// dpct::group::store_algorithm that are plugged into
// dpct::group::workgroup_load / dpct::group::workgroup_store below.
// The kernel is launched over nd_range<3>({2, 2, 64}, {1, 1, 64}); each
// work-item loads into a 4-element private array and stores it back to the
// same buffer. The buffer is then handed to helper_validation_function and
// that function's result is returned.
template <dpct::group::load_algorithm T, dpct::group::store_algorithm S> bool test_group_load_store_multiple_wgs() {
// Exercises the class-based API (workgroup_load / workgroup_store) with the
// load algorithm given by T, e.g.
// dpct::group::load_algorithm::BLOCK_LOAD_DIRECT or
// dpct::group::load_algorithm::BLOCK_LOAD_STRIPED,
// and the store algorithm given by S, e.g.
// dpct::group::store_algorithm::BLOCK_STORE_DIRECT or
// dpct::group::store_algorithm::BLOCK_STORE_STRIPED.
sycl::queue q(dpct::get_default_queue());
// The buffer is constructed from the iterator range [count_it, count_it + 512)
// where count_it starts at 0.
oneapi::dpl::counting_iterator<int> count_it(0);
sycl::buffer<int, 1> buffer(count_it, count_it + 512);

q.submit([&](sycl::handler &h) {
using group_load =
dpct::group::workgroup_load<4, T, int, const int *, sycl::nd_item<3>>;
using group_store =
dpct::group::workgroup_store<4, S, int, int *, sycl::nd_item<3>>;
// NOTE(review): the scratch size is queried for 128 work-items while the
// launch below uses a local range of 64 -- presumably a harmless
// over-allocation carried over from the single-work-group test; confirm.
size_t temp_storage_size = group_load::get_local_memory_size(128);
// Byte-typed local accessor used as scratch storage by both the load and the
// store object.
sycl::local_accessor<uint8_t, 1> tacc(sycl::range<1>(temp_storage_size), h);
sycl::accessor data_accessor_read_write(buffer, h, sycl::read_write);
h.parallel_for(
sycl::nd_range<3>(sycl::range<3>(2, 2, 64), sycl::range<3>(1, 1, 64)),
[=](sycl::nd_item<3> item) {
int thread_data[4];
// Raw pointers to the global data and to the local scratch storage.
auto *d_r_w =
data_accessor_read_write.get_multi_ptr<sycl::access::decorated::yes>()
.get();
auto *tmp = tacc.get_multi_ptr<sycl::access::decorated::yes>().get();
// Load into thread_data using the algorithm selected by T
group_load(tmp).load(item, d_r_w, thread_data);
// Store thread_data back using the algorithm selected by S
group_store(tmp).store(item, d_r_w, thread_data);
});
});
// Block until the kernel has finished (and surface asynchronous errors)
// before reading the buffer on the host.
q.wait_and_throw();

sycl::host_accessor data_accessor(buffer, sycl::read_write);
const int *ptr = data_accessor.get_multi_ptr<sycl::access::decorated::yes>();
// NOTE(review): two return statements follow one another here. The first
// passes the label "test_group_load_store" rather than this function's own
// name; the second is unreachable and uses func_name, which is not a
// parameter of this function. This looks like residue from two versions of
// the code -- keep exactly one return.
return helper_validation_function(ptr, "test_group_load_store");
return helper_validation_function(ptr, func_name);
}

bool test_load_store_subgroup_striped_standalone() {
Expand Down Expand Up @@ -153,91 +114,7 @@ bool test_load_store_subgroup_striped_standalone() {
ptr, "test_subgroup_striped_standalone");
}

// Multi-work-group variant of the sub-group striped load/store test.
//
// Fills a 512-element buffer with 0..511 and launches a kernel over
// nd_range<3>({2, 2, 64}, {1, 1, 64}). Every work-item calls
// dpct::group::uninitialized_load_subgroup_striped<4, int> followed by
// dpct::group::store_subgroup_striped<4, int> on the same buffer, after which
// the buffer contents are passed to helper_validation_function.
//
// Returns the result of helper_validation_function.
bool test_load_store_subgroup_striped_standalone_multiple_wgs() {
  sycl::queue q(dpct::get_default_queue());
  int data[512];
  for (int i = 0; i < 512; i++)
    data[i] = i;
  sycl::buffer<int, 1> buffer(data, 512);
  // Receives the sub-group size reported by work-item 0. The value is only
  // recorded; nothing in this test validates against it.
  sycl::buffer<uint32_t, 1> sg_sz_buf{sycl::range<1>(1)};

  q.submit([&](sycl::handler &h) {
    sycl::accessor dacc_read_write(buffer, h, sycl::read_write);
    sycl::accessor sg_sz_dacc(sg_sz_buf, h, sycl::read_write);
    h.parallel_for(
        sycl::nd_range<3>(sycl::range<3>(2, 2, 64), sycl::range<3>(1, 1, 64)),
        [=](sycl::nd_item<3> item) {
          // Alias for the element type so that the explicit destructor calls
          // below are well-formed: a pseudo-destructor call needs a
          // type-name, and the keyword `int` is not one.
          using value_type = int;
          value_type thread_data[4];
          auto *d_r_w =
              dacc_read_write.get_multi_ptr<sycl::access::decorated::yes>()
                  .get();
          auto *sg_sz_acc =
              sg_sz_dacc.get_multi_ptr<sycl::access::decorated::yes>().get();
          size_t gid = item.get_global_linear_id();
          if (gid == 0) {
            sg_sz_acc[0] = item.get_sub_group().get_local_linear_range();
          }
          dpct::group::uninitialized_load_subgroup_striped<4, int>(
              item, d_r_w, thread_data);
          dpct::group::store_subgroup_striped<4, int>(item, d_r_w,
                                                      thread_data);
          // Explicitly end the lifetime of the loaded elements (a no-op for
          // int, kept to mirror the uninitialized-load usage pattern).
          for (size_t i = 0; i < 4; ++i) {
            thread_data[i].~value_type();
          }
        });
  });
  q.wait_and_throw();

  sycl::host_accessor data_accessor(buffer, sycl::read_only);
  const int *ptr = data_accessor.get_multi_ptr<sycl::access::decorated::yes>();
  // Use a label distinct from the single-work-group test so a failure can be
  // attributed to this variant.
  return helper_validation_function(
      ptr, "test_subgroup_striped_standalone_multiple_wgs");
}

// Tests the free-function (standalone) group load/store helpers on a single
// work-group of 128 work-items (nd_range<3>({1, 1, 128}, {1, 1, 128})).
//
// Template parameters:
//   T - load algorithm; BLOCK_LOAD_DIRECT selects dpct::group::load_blocked,
//       any other value selects dpct::group::load_striped.
//   S - store algorithm; BLOCK_STORE_DIRECT selects
//       dpct::group::store_blocked, any other value selects
//       dpct::group::store_striped.
//
// A 512-element buffer is filled with 0..511, each work-item loads into a
// 4-element private array and stores it back, and the buffer is then passed
// to helper_validation_function, whose result is returned.
template <dpct::group::load_algorithm T, dpct::group::store_algorithm S>
bool test_group_load_store_standalone() {
  sycl::queue q(dpct::get_default_queue());
  int data[512];
  for (int i = 0; i < 512; i++)
    data[i] = i;
  sycl::buffer<int, 1> buffer(data, 512);

  q.submit([&](sycl::handler &h) {
    sycl::accessor dacc_read_write(buffer, h, sycl::read_write);
    h.parallel_for(
        sycl::nd_range<3>(sycl::range<3>(1, 1, 128),
                          sycl::range<3>(1, 1, 128)),
        [=](sycl::nd_item<3> item) {
          int thread_data[4];
          auto *d_r_w =
              dacc_read_write.get_multi_ptr<sycl::access::decorated::yes>()
                  .get();
          // T and S are compile-time constants, so pick the helper at
          // compile time instead of branching inside the kernel.
          if constexpr (T == dpct::group::load_algorithm::BLOCK_LOAD_DIRECT) {
            dpct::group::load_blocked<4, int>(item, d_r_w, thread_data);
          } else {
            dpct::group::load_striped<4, int>(item, d_r_w, thread_data);
          }
          if constexpr (S ==
                        dpct::group::store_algorithm::BLOCK_STORE_DIRECT) {
            dpct::group::store_blocked<4, int>(item, d_r_w, thread_data);
          } else {
            dpct::group::store_striped<4, int>(item, d_r_w, thread_data);
          }
        });
  });
  q.wait_and_throw();

  sycl::host_accessor data_accessor(buffer, sycl::read_only);
  const int *ptr = data_accessor.get_multi_ptr<sycl::access::decorated::yes>();
  // Report under this test's own name rather than "test_group_load_store",
  // which belongs to the class-based test.
  return helper_validation_function(ptr, "test_group_load_store_standalone");
}

template <dpct::group::load_algorithm T, dpct::group::store_algorithm S> bool test_group_load_store_standalone_multi_wgs() {
template <dpct::group::load_algorithm T, dpct::group::store_algorithm S> bool test_group_load_store_standalone(sycl::nd_range<3> & range, char *func_name) {
// Tests dpct::group::load_algorithm::BLOCK_LOAD_DIRECT &
// dpct::group::load_algorithm::BLOCK_LOAD_STRIPED as standalone methods
// Tests dpct::group::store_algorithm::BLOCK_STORE_DIRECT &
Expand All @@ -251,7 +128,8 @@ template <dpct::group::load_algorithm T, dpct::group::store_algorithm S> bool te
q.submit([&](sycl::handler &h) {
sycl::accessor dacc_read_write(buffer, h, sycl::read_write);
h.parallel_for(
sycl::nd_range<3>(sycl::range<3>(2, 2, 64), sycl::range<3>(1, 1, 64)),
range,
//sycl::nd_range<3>(sycl::range<3>(1, 1, 128), sycl::range<3>(1, 1, 128)),
[=](sycl::nd_item<3> item) {
int thread_data[4];
auto *d_r_w =
Expand All @@ -274,34 +152,39 @@ template <dpct::group::load_algorithm T, dpct::group::store_algorithm S> bool te

sycl::host_accessor data_accessor(buffer, sycl::read_only);
const int *ptr = data_accessor.get_multi_ptr<sycl::access::decorated::yes>();
return helper_validation_function(ptr, "test_group_load_store");
return helper_validation_function(ptr, func_name);
}


// Test driver.
//
// Builds two launch configurations -- `range` (global 1x1x128, local 1x1x128,
// i.e. a single work-group) and `range_multi` (global 2x2x64, local 1x1x64,
// i.e. several work-groups) -- and chains the test functions with && inside
// !( ... ), so the returned exit status is 0 only if every chained test
// returns true.
//
// NOTE(review): the return expression below contains each templated call in
// two forms side by side (a zero-argument form and a (range, name) form), and
// the parentheses do not balance as written. It reads like both sides of a
// change shown together; reduce it to one consistent set of calls before
// compiling.
int main() {
sycl::range<3> global_range{1, 1, 128};
sycl::range<3> local_range{1, 1, 128};
sycl::nd_range<3> range{global_range, local_range};
sycl::range<3> global_range_multi{2, 2, 64};
sycl::range<3> local_range_multi{1, 1, 64};
sycl::nd_range<3> range_multi{global_range_multi, local_range_multi};

return !(
// Calls test_group_load_store with the blocked and the striped strategies;
// both must pass.
test_group_load_store<dpct::group::load_algorithm::BLOCK_LOAD_DIRECT, dpct::group::store_algorithm::BLOCK_STORE_DIRECT>() &&
test_group_load_store<dpct::group::load_algorithm::BLOCK_LOAD_STRIPED, dpct::group::store_algorithm::BLOCK_STORE_STRIPED>() &&
test_group_load_store<dpct::group::load_algorithm::BLOCK_LOAD_DIRECT, dpct::group::store_algorithm::BLOCK_STORE_DIRECT>(range, "test_group_load_store") &&
test_group_load_store<dpct::group::load_algorithm::BLOCK_LOAD_STRIPED, dpct::group::store_algorithm::BLOCK_STORE_STRIPED>(range, "test_group_load_store") &&
// Calls test_load_store_subgroup_striped_standalone; it must pass.
test_load_store_subgroup_striped_standalone() &&
// Calls test_group_load_store_standalone with the blocked and the striped
// strategies as free functions; both must pass.
test_group_load_store_standalone<
dpct::group::load_algorithm::BLOCK_LOAD_STRIPED, dpct::group::store_algorithm::BLOCK_STORE_STRIPED>() &&
dpct::group::load_algorithm::BLOCK_LOAD_STRIPED, dpct::group::store_algorithm::BLOCK_STORE_STRIPED>(range, "test_group_load_store_standalone") &&
test_group_load_store_standalone<
dpct::group::load_algorithm::BLOCK_LOAD_DIRECT, dpct::group::store_algorithm::BLOCK_STORE_DIRECT>()) &&
dpct::group::load_algorithm::BLOCK_LOAD_DIRECT, dpct::group::store_algorithm::BLOCK_STORE_DIRECT>(range, "test_group_load_store_standalone") &&

// Multi-work-group runs of the same tests.
test_group_load_store_multiple_wgs<dpct::group::load_algorithm::BLOCK_LOAD_DIRECT, dpct::group::store_algorithm::BLOCK_STORE_DIRECT>() &&
test_group_load_store_multiple_wgs<dpct::group::load_algorithm::BLOCK_LOAD_STRIPED, dpct::group::store_algorithm::BLOCK_STORE_STRIPED>() &&
test_group_load_store<dpct::group::load_algorithm::BLOCK_LOAD_DIRECT, dpct::group::store_algorithm::BLOCK_STORE_DIRECT>(range_multi, "test_group_load_store_multiple_wgs") &&
// NOTE(review): the next call passes (range_multi, name) to
// test_group_load_store_multiple_wgs, whereas the call just above passes the
// same arguments to test_group_load_store. Unless a two-argument
// test_group_load_store_multiple_wgs exists, this looks like a leftover name
// that should read test_group_load_store.
test_group_load_store_multiple_wgs<dpct::group::load_algorithm::BLOCK_LOAD_STRIPED, dpct::group::store_algorithm::BLOCK_STORE_STRIPED>(range_multi, "test_group_load_store_multiple_wgs") &&
// Calls the multi-work-group sub-group striped test; it must pass.
// NOTE(review): unlike its neighbours this call takes no range argument --
// confirm the callee still exists in that form.
test_load_store_subgroup_striped_standalone_multiple_wgs() &&
// Calls the standalone load/store test with the blocked and the striped
// strategies as free functions on the multi-work-group range; both must pass.
test_group_load_store_standalone_multiple_wgs<
dpct::group::load_algorithm::BLOCK_LOAD_STRIPED, dpct::group::store_algorithm::BLOCK_STORE_STRIPED>() &&
test_group_load_store_standalone_multiple_wgs<
dpct::group::load_algorithm::BLOCK_LOAD_DIRECT, dpct::group::store_algorithm::BLOCK_STORE_DIRECT>());
test_group_load_store_standalone<
dpct::group::load_algorithm::BLOCK_LOAD_STRIPED, dpct::group::store_algorithm::BLOCK_STORE_STRIPED>(range_multi, "test_group_load_store_standalone_multiple_wgs") &&
test_group_load_store_standalone<
dpct::group::load_algorithm::BLOCK_LOAD_DIRECT, dpct::group::store_algorithm::BLOCK_STORE_DIRECT>(range_multi, "test_group_load_store_standalone_multiple_wgs"));
}

0 comments on commit 5c318fe

Please sign in to comment.