Skip to content

Commit

Permalink
add multiple workgroups
Browse files Browse the repository at this point in the history
  • Loading branch information
abhilash1910 authored Jun 19, 2024
1 parent 703f9ed commit 099fe54
Showing 1 changed file with 136 additions and 0 deletions.
136 changes: 136 additions & 0 deletions help_function/src/util_group_load_store_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,46 @@ template <dpct::group::load_algorithm T, dpct::group::store_algorithm S> bool te
return helper_validation_function(ptr, "test_group_load_store");
}

template <dpct::group::load_algorithm T, dpct::group::store_algorithm S> bool test_group_load_store_multiple_wgs() {
// Tests dpct::group::load_algorithm::BLOCK_LOAD_DIRECT
// dpct::group::load_algorithm::BLOCK_LOAD_STRIPED in its entirety as API
// functions
// Tests dpct::group::store_algorithm::BLOCK_STORE_DIRECT
// dpct::group::store_algorithm::BLOCK_STORE_STRIPED in its entirety as API
// functions
sycl::queue q(dpct::get_default_queue());
oneapi::dpl::counting_iterator<int> count_it(0);
sycl::buffer<int, 1> buffer(count_it, count_it + 512);

q.submit([&](sycl::handler &h) {
using group_load =
dpct::group::workgroup_load<4, T, int, const int *, sycl::nd_item<3>>;
using group_store =
dpct::group::workgroup_store<4, S, int, int *, sycl::nd_item<3>>;
size_t temp_storage_size = group_load::get_local_memory_size(128);
sycl::local_accessor<uint8_t, 1> tacc(sycl::range<1>(temp_storage_size), h);
sycl::accessor data_accessor_read_write(buffer, h, sycl::read_write);
h.parallel_for(
sycl::nd_range<3>(sycl::range<3>(2, 2, 64), sycl::range<3>(1, 1, 64)),
[=](sycl::nd_item<3> item) {
int thread_data[4];
auto *d_r_w =
data_accessor_read_write.get_multi_ptr<sycl::access::decorated::yes>()
.get();
auto *tmp = tacc.get_multi_ptr<sycl::access::decorated::yes>().get();
// Load thread_data of each work item to blocked arrangement
group_load(tmp).load(item, d_r_w, thread_data);
// Store thread_data of each work item from blocked arrangement
group_store(tmp).store(item, d_r_w, thread_data);
});
});
q.wait_and_throw();

sycl::host_accessor data_accessor(buffer, sycl::read_write);
const int *ptr = data_accessor.get_multi_ptr<sycl::access::decorated::yes>();
return helper_validation_function(ptr, "test_group_load_store");
}

bool test_load_store_subgroup_striped_standalone() {
// Tests dpct::group::load_subgroup_striped as standalone method
sycl::queue q(dpct::get_default_queue());
Expand Down Expand Up @@ -113,6 +153,50 @@ bool test_load_store_subgroup_striped_standalone() {
ptr, "test_subgroup_striped_standalone");
}

bool test_load_store_subgroup_striped_standalone_multiple_wgs() {
// Tests dpct::group::load_subgroup_striped as standalone method
sycl::queue q(dpct::get_default_queue());
int data[512];
for (int i = 0; i < 512; i++)
data[i] = i;
sycl::buffer<int, 1> buffer(data, 512);
sycl::buffer<uint32_t, 1> sg_sz_buf{sycl::range<1>(1)};

q.submit([&](sycl::handler &h) {
sycl::accessor dacc_read_write(buffer, h, sycl::read_write);
sycl::accessor sg_sz_dacc(sg_sz_buf, h, sycl::read_write);
h.parallel_for(
sycl::nd_range<3>(sycl::range<3>(2, 2, 64), sycl::range<3>(1, 1, 64)),
[=](sycl::nd_item<3> item) {
int thread_data[4];
auto *d_r_w =
dacc_read_write.get_multi_ptr<sycl::access::decorated::yes>().get();
auto *sg_sz_acc =
sg_sz_dacc.get_multi_ptr<sycl::access::decorated::yes>().get();
size_t gid = item.get_global_linear_id();
if (gid == 0) {
sg_sz_acc[0] = item.get_sub_group().get_local_linear_range();
}
dpct::group::uninitialized_load_subgroup_striped<4, int>(item, d_r_w,
thread_data);
dpct::group::store_subgroup_striped<4, int>(item, d_r_w, thread_data);
//call destructor of thread type
for (size_t i = 0; i < 4; ++i) {
thread_data[i].~int();
}
});
});
q.wait_and_throw();

sycl::host_accessor data_accessor(buffer, sycl::read_only);
const int *ptr = data_accessor.get_multi_ptr<sycl::access::decorated::yes>();
sycl::host_accessor data_accessor_sg(sg_sz_buf, sycl::read_only);
const uint32_t *ptr_sg =
data_accessor_sg.get_multi_ptr<sycl::access::decorated::yes>();
return helper_validation_function(
ptr, "test_subgroup_striped_standalone");
}

template <dpct::group::load_algorithm T, dpct::group::store_algorithm S> bool test_group_load_store_standalone() {
// Tests dpct::group::load_algorithm::BLOCK_LOAD_DIRECT &
// dpct::group::load_algorithm::BLOCK_LOAD_STRIPED as standalone methods
Expand Down Expand Up @@ -153,6 +237,47 @@ template <dpct::group::load_algorithm T, dpct::group::store_algorithm S> bool te
return helper_validation_function(ptr, "test_group_load_store");
}

template <dpct::group::load_algorithm T, dpct::group::store_algorithm S> bool test_group_load_store_standalone_multi_wgs() {
// Tests dpct::group::load_algorithm::BLOCK_LOAD_DIRECT &
// dpct::group::load_algorithm::BLOCK_LOAD_STRIPED as standalone methods
// Tests dpct::group::store_algorithm::BLOCK_STORE_DIRECT &
// dpct::group::store_algorithm::BLOCK_STORE_STRIPED as standalone methods
sycl::queue q(dpct::get_default_queue());
int data[512];
for (int i = 0; i < 512; i++)
data[i] = i;
sycl::buffer<int, 1> buffer(data, 512);

q.submit([&](sycl::handler &h) {
sycl::accessor dacc_read_write(buffer, h, sycl::read_write);
h.parallel_for(
sycl::nd_range<3>(sycl::range<3>(2, 2, 64), sycl::range<3>(1, 1, 64)),
[=](sycl::nd_item<3> item) {
int thread_data[4];
auto *d_r_w =
dacc_read_write.get_multi_ptr<sycl::access::decorated::yes>().get();
// Load thread_data of each work item to blocked arrangement
if (T == dpct::group::load_algorithm::BLOCK_LOAD_DIRECT) {
dpct::group::load_blocked<4, int>(item, d_r_w, thread_data);
} else {
dpct::group::load_striped<4, int>(item, d_r_w, thread_data);
}
// Store thread_data of each work item from blocked arrangement
if (S == dpct::group::store_algorithm::BLOCK_STORE_DIRECT) {
dpct::group::store_blocked<4, int>(item, d_r_w, thread_data);
} else {
dpct::group::store_striped<4, int>(item, d_r_w, thread_data);
}
});
});
q.wait_and_throw();

sycl::host_accessor data_accessor(buffer, sycl::read_only);
const int *ptr = data_accessor.get_multi_ptr<sycl::access::decorated::yes>();
return helper_validation_function(ptr, "test_group_load_store");
}


int main() {

return !(
Expand All @@ -167,5 +292,16 @@ int main() {
test_group_load_store_standalone<
dpct::group::load_algorithm::BLOCK_LOAD_STRIPED, dpct::group::store_algorithm::BLOCK_STORE_STRIPED>() &&
test_group_load_store_standalone<
dpct::group::load_algorithm::BLOCK_LOAD_DIRECT, dpct::group::store_algorithm::BLOCK_STORE_DIRECT>()) &&

test_group_load_store_multiple_wgs<dpct::group::load_algorithm::BLOCK_LOAD_DIRECT, dpct::group::store_algorithm::BLOCK_STORE_DIRECT>() &&
test_group_load_store_multiple_wgs<dpct::group::load_algorithm::BLOCK_LOAD_STRIPED, dpct::group::store_algorithm::BLOCK_STORE_STRIPED>() &&
// Calls test_load_subgroup_striped_standalone and should pass
test_load_store_subgroup_striped_standalone_multiple_wgs() &&
// Calls test_group_load_standalone with blocked and striped strategies as
// free functions, should pass both results.
test_group_load_store_standalone_multiple_wgs<
dpct::group::load_algorithm::BLOCK_LOAD_STRIPED, dpct::group::store_algorithm::BLOCK_STORE_STRIPED>() &&
test_group_load_store_standalone_multiple_wgs<
dpct::group::load_algorithm::BLOCK_LOAD_DIRECT, dpct::group::store_algorithm::BLOCK_STORE_DIRECT>());
}

0 comments on commit 099fe54

Please sign in to comment.