Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improvements to onedpl_test_sort_by_key #719

Merged
merged 3 commits into from
Aug 26, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 18 additions & 35 deletions help_function/src/onedpl_test_sort_by_key.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ int main() {
// #14 SORT BY KEY TEST //

{
sycl::queue myQueue;
const int N = 6;
sycl::buffer<int, 1> keys_buf{ sycl::range<1>(N) };
sycl::buffer<int, 1> values_buf{ sycl::range<1>(N) };
Expand All @@ -116,21 +117,21 @@ int main() {
auto values_it = oneapi::dpl::begin(values_buf);

{
auto keys = keys_it.get_buffer().template get_access<sycl::access::mode::write>();
auto values = values_it.get_buffer().template get_access<sycl::access::mode::write>();
auto keys = keys_it.get_buffer().get_host_access();
auto values = values_it.get_buffer().get_host_access();

keys[0] = 1; keys[1] = 4; keys[2] = 2; keys[3] = 8; keys[4] = 5; keys[5] = 7;
values[0] = 'a'; values[1] = 'b'; values[2] = 'c'; values[3] = 'd'; values[4] = 'e';values[5] = 'f';
}

// call algorithm:
dpct::sort(oneapi::dpl::execution::dpcpp_default, keys_it, keys_it + N, values_it);
dpct::sort(oneapi::dpl::execution::make_device_policy<class kernel1>(myQueue), keys_it, keys_it + N, values_it);

// keys is now { 1, 2, 4, 5, 7, 8}
// values is now {'a', 'c', 'b', 'e', 'f', 'd'}
{
test_name = "Regular call to sort";
auto values = values_it.get_buffer().template get_access<sycl::access::mode::read>();
auto values = values_it.get_buffer().get_host_access();
num_failing += ASSERT_EQUAL(test_name, values[0], 'a');
num_failing += ASSERT_EQUAL(test_name, values[1], 'c');
num_failing += ASSERT_EQUAL(test_name, values[2], 'b');
Expand Down Expand Up @@ -170,7 +171,7 @@ int main() {
auto keys_end = dpct::device_pointer<int>(keysArray + 10);
auto values_begin = dpct::device_pointer<int>(valuesArray);
// call algorithm
dpct::sort(oneapi::dpl::execution::make_device_policy<>(myQueue), keys_begin, keys_end, values_begin);
dpct::sort(oneapi::dpl::execution::make_device_policy<class kernel2>(myQueue), keys_begin, keys_end, values_begin);
}

// copy back
Expand Down Expand Up @@ -201,40 +202,22 @@ int main() {

{
// Test Two, test calls to dpct::sort using device vectors
dpct::device_vector<int> keys_vec(10);
dpct::device_vector<int> values_vec(10);

std::vector<int> keys_data{4, 8, 5, 3, 0, 9, 7, 2, 1, 6};
std::vector<int> values_data{13, 16, 17, 11, 19, 14, 12, 18, 10, 15};

dpct::get_default_queue().submit([&](sycl::handler& h) {
h.memcpy(keys_vec.data(), keys_data.data(), 10 * sizeof(int));
});

dpct::get_default_queue().submit([&](sycl::handler& h) {
h.memcpy(values_vec.data(), values_data.data(), 10 * sizeof(int));
});
dpct::get_default_queue().wait();
dpct::device_vector<int> keys_vec(keys_data);
dpct::device_vector<int> values_vec(values_data);

auto keys_it = keys_vec.begin();
auto keys_it_end = keys_vec.end();
auto values_it = values_vec.begin();
{
// call algorithm
dpct::sort(oneapi::dpl::execution::make_device_policy<>(dpct::get_default_queue()), keys_it, keys_it_end, values_it);
dpct::sort(oneapi::dpl::execution::dpcpp_default, keys_it, keys_it_end, values_it);
// keys is now = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}
// values is now = {19, 10, 18, 11, 13, 17, 15, 12, 16, 14}
}

dpct::get_default_queue().submit([&](sycl::handler& h) {
h.memcpy(keys_data.data(), keys_vec.data(), 10 * sizeof(int));
});

dpct::get_default_queue().submit([&](sycl::handler& h) {
h.memcpy(values_data.data(), values_vec.data(), 10 * sizeof(int));
});
dpct::get_default_queue().wait();

{
int check_keys[10] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
int check_values[10] = {19, 10, 18, 11, 13, 17, 15, 12, 16, 14};
Expand All @@ -243,8 +226,8 @@ int main() {
// check that values and keys are correct

for (int i = 0; i != 10; ++i) {
num_failing += ASSERT_EQUAL(test_name, values_data[i], check_values[i]);
num_failing += ASSERT_EQUAL(test_name, keys_data[i], check_keys[i]);
num_failing += ASSERT_EQUAL(test_name, values_vec[i], check_values[i]);
num_failing += ASSERT_EQUAL(test_name, keys_vec[i], check_keys[i]);
}

failed_tests += test_passed(num_failing, test_name);
Expand All @@ -264,8 +247,8 @@ int main() {
auto values_it = oneapi::dpl::begin(values_buf);

{
auto keys = keys_it.get_buffer().template get_access<sycl::access::mode::write>();
auto values = values_it.get_buffer().template get_access<sycl::access::mode::write>();
auto keys = keys_it.get_buffer().get_host_access();
auto values = values_it.get_buffer().get_host_access();
keys[0] = 1; keys[1] = 4; keys[2] = 2; keys[3] = 8; keys[4] = 5; keys[5] = 7;
values[0] = 'a'; values[1] = 'b'; values[2] = 'c'; values[3] = 'd'; values[4] = 'e';values[5] = 'f';
}
Expand All @@ -277,7 +260,7 @@ int main() {
// values is now {'a', 'c', 'b', 'e', 'f', 'd'}
{
test_name = "Regular call to stable_sort";
auto values = values_it.get_buffer().template get_access<sycl::access::mode::read>();
auto values = values_it.get_buffer().get_host_access();

num_failing += ASSERT_EQUAL(test_name, values[0], 'a');
num_failing += ASSERT_EQUAL(test_name, values[1], 'c');
Expand Down Expand Up @@ -357,8 +340,8 @@ int main() {
auto values_it = oneapi::dpl::begin(values_buf);

{
auto keys = keys_it.get_buffer().template get_access<sycl::access::mode::write>();
auto values = values_it.get_buffer().template get_access<sycl::access::mode::write>();
auto keys = keys_it.get_buffer().get_host_access();
auto values = values_it.get_buffer().get_host_access();
// keys = {8, 3, 0, 2, 6, 5, 1, 8, 9, 10, 7, 4, 5, 2, 2, 10}
keys[0] = 8; keys[1] = 3; keys[2] = 0; keys[3] = 2; keys[4] = 6; keys[5] = 5;
keys[6] = 1; keys[7] = 8; keys[8] = 9; keys[9] = 10; keys[10] = 7; keys[11] = 4;
Expand All @@ -375,8 +358,8 @@ int main() {
// keys is now = {0, 1, 2, 2, 2, 3, 4, 5, 5, 6, 7, 8, 8, 9, 10, 10}
// values is now = {'k', 'n', 'g', 'j', 'l', 'm', 'p', 'c', 'o', 'd', 'h', 'b', 'f', 'i', 'e', 'a'}
{
auto keys = keys_it.get_buffer().template get_access<sycl::access::mode::read>();
auto values = values_it.get_buffer().template get_access<sycl::access::mode::read>();
auto keys = keys_it.get_buffer().get_host_access();
auto values = values_it.get_buffer().get_host_access();
int check_values[16] = {'k', 'n', 'g', 'j', 'l', 'm', 'p', 'c', 'o', 'd', 'h', 'b', 'f', 'i', 'e', 'a'};
int check_keys[16] = {0, 1, 2, 2, 2, 3, 4, 5, 5, 6, 7, 8, 8, 9, 10, 10};
// check that values and keys are correct
Expand Down