From db59add8e8a1012b536b87ba8d9ecbd43dcf1256 Mon Sep 17 00:00:00 2001 From: "Maronas, Marcos" Date: Wed, 22 Nov 2023 07:36:02 -0800 Subject: [PATCH 1/3] Fix overflow issues in scan tests. Signed-off-by: Maronas, Marcos --- tests/group_functions/group_scan.h | 27 ++++++++++++++++++++------- 1 file changed, 20 insertions(+), 7 deletions(-) diff --git a/tests/group_functions/group_scan.h b/tests/group_functions/group_scan.h index d15bd8bda..8c7ed7980 100644 --- a/tests/group_functions/group_scan.h +++ b/tests/group_functions/group_scan.h @@ -25,6 +25,9 @@ template class joint_scan_group_kernel; +// This should never be higher than std::numeric_limits::max() for the +// smallest type tested. Currently, the smallest type tested is +// char/int8_t, so it shouldn't be higher than 127. constexpr int init = 42; constexpr size_t test_size = 12; @@ -54,14 +57,26 @@ auto joint_exclusive_scan_helper(Group group, T* v_begin, T* v_end, op); } -template +template struct JointScanDataStruct { - JointScanDataStruct(size_t range_size) + JointScanDataStruct(size_t range_size, OpT op, bool with_init) : ref_input(range_size), res(range_size * 4, U(-1)) { + std::iota(ref_input.begin(), ref_input.end(), T(1)); + if constexpr (std::is_same_v> || + std::is_same_v>) { + auto identity = sycl::known_identity_v; + auto acc = with_init ? I(init) : identity; + for (size_t i = 0; i < range_size; ++i) { + I tmp = op(I(acc), I(ref_input[i])); + if (tmp > std::numeric_limits::max()) { + ref_input[i] = identity; + } + acc = op(acc, ref_input[i]); + } + } } - template void check_results(size_t range_size, OpT op, const std::string& op_name, bool with_init) { CHECK(end[0]); @@ -128,7 +143,7 @@ template void check_scan(sycl::queue& queue, size_t size, sycl::nd_range executionRange, OpT op, const std::string& op_name, bool with_init) { - JointScanDataStruct host_data{size}; + JointScanDataStruct host_data{size, op, with_init}; { sycl::buffer ref_input_sycl = host_data.create_ref_input_buffer(); sycl::buffer res_sycl = host_data.create_res_buffer(); @@ -180,7 +195,7 @@ void check_scan(sycl::queue& queue, size_t size, .wait_and_throw(); } - host_data.template check_results(size, op, op_name, with_init); + host_data.check_results(size, op, op_name, with_init); } /** @@ -393,8 +408,6 @@ template void check_scan_over_group(sycl::queue& queue, sycl::range range, OpT op, const std::string& op_name, bool with_init) { auto range_size = range.size(); - REQUIRE(((range_size * (range_size + 1) / 2) + T(init)) <= - std::numeric_limits::max()); ScanOverGroupDataStruct host_data{range_size}; { From eb46016e1ae496dca1ccd4c212902e3b9d3ef3c9 Mon Sep 17 00:00:00 2001 From: "Maronas, Marcos" Date: Wed, 22 Nov 2023 12:35:34 -0800 Subject: [PATCH 2/3] Fix clang-format issue. --- tests/group_functions/group_scan.h | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/group_functions/group_scan.h b/tests/group_functions/group_scan.h index 8c7ed7980..f2721884b 100644 --- a/tests/group_functions/group_scan.h +++ b/tests/group_functions/group_scan.h @@ -61,7 +61,6 @@ template struct JointScanDataStruct { JointScanDataStruct(size_t range_size, OpT op, bool with_init) : ref_input(range_size), res(range_size * 4, U(-1)) { - std::iota(ref_input.begin(), ref_input.end(), T(1)); if constexpr (std::is_same_v> || std::is_same_v>) { From f7ef9925d499dd3e416b8a8affd8a22171f37753 Mon Sep 17 00:00:00 2001 From: "Maronas, Marcos" Date: Fri, 1 Dec 2023 02:30:33 -0800 Subject: [PATCH 3/3] Address code review comments. Signed-off-by: Maronas, Marcos --- tests/group_functions/group_scan.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/group_functions/group_scan.h b/tests/group_functions/group_scan.h index f2721884b..f9b7d1eaf 100644 --- a/tests/group_functions/group_scan.h +++ b/tests/group_functions/group_scan.h @@ -65,7 +65,7 @@ struct JointScanDataStruct { if constexpr (std::is_same_v> || std::is_same_v>) { auto identity = sycl::known_identity_v; - auto acc = with_init ? I(init) : identity; + auto acc = with_init ? I{init} : identity; for (size_t i = 0; i < range_size; ++i) { I tmp = op(I(acc), I(ref_input[i])); if (tmp > std::numeric_limits::max()) {