Skip to content

Commit

Permalink
Merge pull request #839 from Nuullll/fix-group-scan
Browse files Browse the repository at this point in the history
[group] Refine scan_over_group for sub-group
  • Loading branch information
bader authored Dec 1, 2023
2 parents 6052dc6 + 0a7be0b commit 4993dc6
Showing 1 changed file with 81 additions and 42 deletions.
123 changes: 81 additions & 42 deletions tests/group_functions/group_scan.h
Original file line number Diff line number Diff line change
Expand Up @@ -322,7 +322,8 @@ struct ScanOverGroupDataStruct {
// Sizes the host-side vectors from range_size and seeds the input data.
//   ref_input    - range_size elements, filled in the body with consecutive
//                  values starting at U(1).
//   res          - range_size * 4 elements, each initialised to T(-1).
//   local_id     - zero-filled (see the review note below for its size).
//   sub_group_id - range_size zero-filled elements.
// NOTE(review): this listing is a diff view, so the superseded initializer
// `local_id(range_size * 2, 0) {` appears directly above the two lines that
// replace it; only one of the two forms belongs in the real header.
ScanOverGroupDataStruct(size_t range_size)
: ref_input(range_size),
res(range_size * 4, T(-1)),
local_id(range_size * 2, 0) {
local_id(range_size, 0),
sub_group_id(range_size, 0) {
// Input values are consecutive, starting at U(1).
std::iota(ref_input.begin(), ref_input.end(), U(1));
}

Expand All @@ -337,35 +338,70 @@ struct ScanOverGroupDataStruct {
T init_value = with_init ? T(init) : sycl::known_identity<OpT, T>::value;
// res consists of 4 series of results: two pairs of exclusive and inclusive
// scan results, made over 'group' and 'sub_group' respectively.
for (int group_i = 0; group_i < 2; group_i++) {
std::string group_name = group_i == 0 ? "group" : "sub_group";
size_t group_offset = range_size * group_i;
{
std::vector<T> reference(range_size, T(-1));
// There is only one work-group so we can scan over all the input data.
std::exclusive_scan(ref_input.begin(), ref_input.end(), reference.begin(),
init_value, op);
for (int i = 0; i < range_size; i++) {
int shift = i - local_id[i + group_offset];
auto startIter = ref_input.begin() + shift;
// Each group contains two sets of results.
size_t res_i = i + 2 * group_offset;
int res_i = i;
INFO("Check exclusive_scan_over_group on group for element " +
std::to_string(i) + " (Operator: " + op_name + ")");
INFO("Result: " + std::to_string(res[res_i]));
INFO("Expected: " + std::to_string(reference[i]));
CHECK(res[res_i] == reference[i]);
}
std::inclusive_scan(ref_input.begin(), ref_input.end(), reference.begin(),
op, init_value);
for (int i = 0; i < range_size; i++) {
int res_i = range_size + i;
INFO("Check inclusive_scan_over_group on group for element " +
std::to_string(i) + " (Operator: " + op_name + ")");
INFO("Result: " + std::to_string(res[res_i]));
INFO("Expected: " + std::to_string(reference[i]));
CHECK(res[res_i] == reference[i]);
}
}
{
// Mapping from "sub-group id" to "vector of input data (ordered by item
// linear id within the sub-group)"
std::unordered_map<size_t, std::vector<T>> ref_input_per_sub_group;
for (int i = 0; i < range_size; i++) {
size_t sgid = sub_group_id[i];
size_t lid = local_id[i];
std::vector<T>& input_vec = ref_input_per_sub_group[sgid];
// Extend input vector dynamically.
if (input_vec.size() <= lid) input_vec.resize(lid + 1);
// Place the data identified by (sgid, lid).
input_vec[lid] = ref_input[i];
}
// Compute the reference results and verify.
for (int i = 0; i < range_size; i++) {
size_t sgid = sub_group_id[i];
size_t lid = local_id[i];
const std::vector<T>& input_vec = ref_input_per_sub_group[sgid];
// Scan over the first (lid + 1) elements of input_vec to obtain the
// result identified by i.
std::vector<T> reference(lid + 1, T(-1));
std::exclusive_scan(input_vec.begin(), input_vec.begin() + lid + 1,
reference.begin(), init_value, op);
{
INFO("Check exclusive_scan_over_group on " + group_name +
" for element " + std::to_string(i) + " (Operator: " + op_name +
")");
std::vector<T> reference(i + 1, T(-1));
std::exclusive_scan(startIter, ref_input.begin() + i + 1,
reference.begin(), init_value, op);
int res_i = range_size * 2 + i;
INFO("Check exclusive_scan_over_group on sub_group for element " +
std::to_string(i) + " (Operator: " + op_name + ")");
INFO("Result: " + std::to_string(res[res_i]));
INFO("Expected: " + std::to_string(reference[i - shift]));
CHECK(res[res_i] == reference[i - shift]);
INFO("Expected: " + std::to_string(reference[lid]));
CHECK(res[res_i] == reference[lid]);
}
std::inclusive_scan(input_vec.begin(), input_vec.begin() + lid + 1,
reference.begin(), op, init_value);
{
INFO("Check inclusive_scan_over_group on " + group_name +
" for element " + std::to_string(i) + " (Operator: " + op_name +
")");
std::vector<T> reference(i + 1, T(-1));
std::inclusive_scan(startIter, ref_input.begin() + i + 1,
reference.begin(), op, init_value);
INFO("Result: " + std::to_string(res[res_i + range_size]));
INFO("Expected: " + std::to_string(reference[i - shift]));
CHECK(res[res_i + range_size] == reference[i - shift]);
int res_i = range_size * 3 + i;
INFO("Check inclusive_scan_over_group on sub_group for element " +
std::to_string(i) + " (Operator: " + op_name + ")");
INFO("Result: " + std::to_string(res[res_i]));
INFO("Expected: " + std::to_string(reference[lid]));
CHECK(res[res_i] == reference[lid]);
}
}
}
Expand All @@ -383,10 +419,15 @@ struct ScanOverGroupDataStruct {
return {local_id.data(), local_id.size()};
}

// Returns a 1-D sycl::buffer of size_t built from the sub_group_id vector's
// data pointer and element count. The kernel in check_scan_over_group writes
// each item's sub_group.get_group_linear_id() through an accessor on this
// buffer, indexed by the item's global linear id.
sycl::buffer<size_t, 1> create_sub_group_id_buffer() {
return {sub_group_id.data(), sub_group_id.size()};
}

std::vector<U> ref_input;
std::vector<T> res;
bool ret_type[4] = {false, false, false, false};
std::vector<size_t> local_id;
std::vector<size_t> sub_group_id;
};

template <int D, typename T, typename U = T, typename OpT>
Expand All @@ -402,6 +443,7 @@ void check_scan_over_group(sycl::queue& queue, sycl::range<D> range, OpT op,
auto res_sycl = host_data.create_res_buffer();
auto ret_type_sycl = host_data.create_ret_type_buffer();
auto local_id_sycl = host_data.create_local_id_buffer();
auto sub_group_id_sycl = host_data.create_sub_group_id_buffer();

queue
.submit([&](sycl::handler& cgh) {
Expand All @@ -410,18 +452,14 @@ void check_scan_over_group(sycl::queue& queue, sycl::range<D> range, OpT op,
sycl::accessor<T, 1> res_acc(res_sycl, cgh);
sycl::accessor<bool, 1> ret_type_acc(ret_type_sycl, cgh);
sycl::accessor<size_t, 1> local_id_acc(local_id_sycl, cgh);
sycl::accessor<size_t, 1> sub_group_id_acc(sub_group_id_sycl, cgh);

cgh.parallel_for<scan_over_group_kernel<D, T, U, OpT>>(
sycl::nd_range<D>(range, range), [=](sycl::nd_item<D> item) {
sycl::group<D> group = item.get_group();
sycl::sub_group sub_group = item.get_sub_group();

// Use the local id of the item in the group to place results of
// the scan operation in the order of the items.
auto g_index = group.get_group_linear_id() *
group.get_local_linear_range() +
group.get_local_linear_id();
local_id_acc[g_index] = group.get_local_linear_id();
auto g_index = item.get_global_linear_id();

auto res_g_e = exclusive_scan_over_group_helper<T>(
group, ref_input_acc[g_index], op, with_init);
Expand All @@ -433,22 +471,23 @@ void check_scan_over_group(sycl::queue& queue, sycl::range<D> range, OpT op,
res_acc[range_size + g_index] = res_g_i;
ret_type_acc[1] = std::is_same_v<T, decltype(res_g_i)>;

// Use the local id of the item in the sub-group to place
// results of the scan operation in the order of the items.
auto sg_index = sub_group.get_group_linear_id() *
sub_group.get_local_linear_range() +
sub_group.get_local_linear_id();
local_id_acc[range_size + sg_index] =
sub_group.get_local_linear_id();
// Input data is indexed by the global linear id of the item (g_index);
// however, sub-group partitioning and ordering are
// implementation-defined.
// Here we store both the sub-group id and the item's linear id within
// the sub-group so that we can recover the sub-group
// construction when verifying.
sub_group_id_acc[g_index] = sub_group.get_group_linear_id();
local_id_acc[g_index] = sub_group.get_local_linear_id();

auto res_sg_e = exclusive_scan_over_group_helper<T>(
sub_group, ref_input_acc[sg_index], op, with_init);
res_acc[range_size * 2 + sg_index] = res_sg_e;
sub_group, ref_input_acc[g_index], op, with_init);
res_acc[range_size * 2 + g_index] = res_sg_e;
ret_type_acc[2] = std::is_same_v<T, decltype(res_sg_e)>;

auto res_sg_i = inclusive_scan_over_group_helper<T>(
sub_group, ref_input_acc[sg_index], op, with_init);
res_acc[range_size * 3 + sg_index] = res_sg_i;
sub_group, ref_input_acc[g_index], op, with_init);
res_acc[range_size * 3 + g_index] = res_sg_i;
ret_type_acc[3] = std::is_same_v<T, decltype(res_sg_i)>;
});
})
Expand Down

0 comments on commit 4993dc6

Please sign in to comment.