Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix overflow issues in scan tests. #838

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 19 additions & 7 deletions tests/group_functions/group_scan.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@
template <int D, typename T, typename U, typename I, typename OpT>
class joint_scan_group_kernel;

// This should never be higher than std::numeric_limits<T>::max() for the
// smallest type tested. Currently, the smallest type tested is
// char/int8_t, so it shouldn't be higher than 127.
constexpr int init = 42;
constexpr size_t test_size = 12;

Expand Down Expand Up @@ -54,14 +57,25 @@ auto joint_exclusive_scan_helper(Group group, T* v_begin, T* v_end,
op);
}

template <typename T, typename U>
template <typename T, typename U, typename I, typename OpT>
struct JointScanDataStruct {
JointScanDataStruct(size_t range_size)
JointScanDataStruct(size_t range_size, OpT op, bool with_init)
: ref_input(range_size), res(range_size * 4, U(-1)) {
std::iota(ref_input.begin(), ref_input.end(), T(1));
if constexpr (std::is_same_v<OpT, sycl::multiplies<I>> ||
std::is_same_v<OpT, sycl::plus<I>>) {
auto identity = sycl::known_identity_v<OpT, I>;
auto acc = with_init ? I{init} : identity;
for (size_t i = 0; i < range_size; ++i) {
I tmp = op(I(acc), I(ref_input[i]));
if (tmp > std::numeric_limits<U>::max()) {
ref_input[i] = identity;
}
acc = op(acc, ref_input[i]);
}
}
}

template <typename I, typename OpT>
void check_results(size_t range_size, OpT op, const std::string& op_name,
bool with_init) {
CHECK(end[0]);
Expand Down Expand Up @@ -128,7 +142,7 @@ template <int D, typename T, typename U, typename I = U, typename OpT>
void check_scan(sycl::queue& queue, size_t size,
sycl::nd_range<D> executionRange, OpT op,
const std::string& op_name, bool with_init) {
JointScanDataStruct<T, U> host_data{size};
JointScanDataStruct<T, U, I, OpT> host_data{size, op, with_init};
{
sycl::buffer<T, 1> ref_input_sycl = host_data.create_ref_input_buffer();
sycl::buffer<U, 1> res_sycl = host_data.create_res_buffer();
Expand Down Expand Up @@ -180,7 +194,7 @@ void check_scan(sycl::queue& queue, size_t size,
.wait_and_throw();
}

host_data.template check_results<I>(size, op, op_name, with_init);
host_data.check_results(size, op, op_name, with_init);
}

/**
Expand Down Expand Up @@ -393,8 +407,6 @@ template <int D, typename T, typename U = T, typename OpT>
void check_scan_over_group(sycl::queue& queue, sycl::range<D> range, OpT op,
const std::string& op_name, bool with_init) {
auto range_size = range.size();
REQUIRE(((range_size * (range_size + 1) / 2) + T(init)) <=
std::numeric_limits<T>::max());

ScanOverGroupDataStruct<T, U> host_data{range_size};
{
Expand Down