From 57afb28d9fcc72e6212008b4cdd2d3b717177b85 Mon Sep 17 00:00:00 2001 From: Yilong Guo Date: Wed, 29 Nov 2023 15:51:27 +0800 Subject: [PATCH] Refine scan_over_group for sub-group Reorder ref_input according to the actual sub-group partitioning and ordering. --- tests/group_functions/group_scan.h | 100 ++++++++++++++++++++++------- 1 file changed, 76 insertions(+), 24 deletions(-) diff --git a/tests/group_functions/group_scan.h b/tests/group_functions/group_scan.h index ef43f1dfb..f8a10e16f 100644 --- a/tests/group_functions/group_scan.h +++ b/tests/group_functions/group_scan.h @@ -322,7 +322,8 @@ struct ScanOverGroupDataStruct { ScanOverGroupDataStruct(size_t range_size) : ref_input(range_size), res(range_size * 4, T(-1)), - local_id(range_size * 2, 0) { + local_id(range_size, 0), + sub_group_id(range_size, 0) { std::iota(ref_input.begin(), ref_input.end(), U(1)); } @@ -337,35 +338,76 @@ struct ScanOverGroupDataStruct { T init_value = with_init ? T(init) : sycl::known_identity::value; // res consists of 4 series of results: two pairs of exclusive and inclusive // scan results made over 'group' and 'sub_group' accordingly. - for (int group_i = 0; group_i < 2; group_i++) { - std::string group_name = group_i == 0 ? "group" : "sub_group"; - size_t group_offset = range_size * group_i; + { + std::string group_name = "group"; + std::vector reference(range_size, T(-1)); + // There is only one work-group so we can scan over all the input data. + std::exclusive_scan(ref_input.begin(), ref_input.end(), reference.begin(), + init_value, op); for (int i = 0; i < range_size; i++) { - int shift = i - local_id[i + group_offset]; - auto startIter = ref_input.begin() + shift; - // Each group contains two sets of results. - size_t res_i = i + 2 * group_offset; + int res_i = i; + INFO("Check exclusive_scan_over_group on " + group_name + + " for element " + std::to_string(i) + " (Operator: " + op_name + + ")"); + INFO("Result: " + std::to_string(res[res_i])); + INFO("Expected: " + std::to_string(reference[i])); + CHECK(res[res_i] == reference[i]); + } + std::inclusive_scan(ref_input.begin(), ref_input.end(), reference.begin(), + op, init_value); + for (int i = 0; i < range_size; i++) { + int res_i = range_size + i; + INFO("Check inclusive_scan_over_group on " + group_name + + " for element " + std::to_string(i) + " (Operator: " + op_name + + ")"); + INFO("Result: " + std::to_string(res[res_i])); + INFO("Expected: " + std::to_string(reference[i])); + CHECK(res[res_i] == reference[i]); + } + } + { + std::string group_name = "sub_group"; + // Mapping from "sub-group id" to "vector of input data (ordered by item + // linear id within the sub-group)" + std::unordered_map> ref_input_per_sub_group; + for (int i = 0; i < range_size; i++) { + size_t sgid = sub_group_id[i]; + size_t lid = local_id[i]; + std::vector& input_vec = ref_input_per_sub_group[sgid]; + // Extend input vector dynamically. + if (input_vec.size() <= lid) input_vec.resize(lid + 1); + // Place the data identified by (sgid, lid). + input_vec[lid] = ref_input[i]; + } + // Compute the reference results and verify. + for (int i = 0; i < range_size; i++) { + size_t sgid = sub_group_id[i]; + size_t lid = local_id[i]; + const std::vector& input_vec = ref_input_per_sub_group[sgid]; + // Scan over the first (lid + 1) elements of input_vec to obtain the + // result identified by i. + std::vector reference(lid + 1, T(-1)); + std::exclusive_scan(input_vec.begin(), input_vec.begin() + lid + 1, + reference.begin(), init_value, op); { + int res_i = range_size * 2 + i; INFO("Check exclusive_scan_over_group on " + group_name + " for element " + std::to_string(i) + " (Operator: " + op_name + ")"); - std::vector reference(i + 1, T(-1)); - std::exclusive_scan(startIter, ref_input.begin() + i + 1, - reference.begin(), init_value, op); INFO("Result: " + std::to_string(res[res_i])); - INFO("Expected: " + std::to_string(reference[i - shift])); - CHECK(res[res_i] == reference[i - shift]); + INFO("Expected: " + std::to_string(reference[lid])); + CHECK(res[res_i] == reference[lid]); } + std::inclusive_scan(input_vec.begin(), input_vec.begin() + lid + 1, + reference.begin(), op, init_value); { + int res_i = range_size * 3 + i; INFO("Check inclusive_scan_over_group on " + group_name + " for element " + std::to_string(i) + " (Operator: " + op_name + ")"); - std::vector reference(i + 1, T(-1)); - std::inclusive_scan(startIter, ref_input.begin() + i + 1, - reference.begin(), op, init_value); - INFO("Result: " + std::to_string(res[res_i + range_size])); - INFO("Expected: " + std::to_string(reference[i - shift])); - CHECK(res[res_i + range_size] == reference[i - shift]); + INFO("Result: " + std::to_string(res[res_i])); + INFO("Expected: " + std::to_string(reference[lid])); + CHECK(res[res_i] == reference[lid]); } } } @@ -383,10 +425,15 @@ struct ScanOverGroupDataStruct { return {local_id.data(), local_id.size()}; } + sycl::buffer create_sub_group_id_buffer() { + return {sub_group_id.data(), sub_group_id.size()}; + } + std::vector ref_input; std::vector res; bool ret_type[4] = {false, false, false, false}; std::vector local_id; + std::vector sub_group_id; }; template @@ -402,6 +449,7 @@ void check_scan_over_group(sycl::queue& queue, sycl::range range, OpT op, auto res_sycl = host_data.create_res_buffer(); auto ret_type_sycl = host_data.create_ret_type_buffer(); auto local_id_sycl = host_data.create_local_id_buffer(); + auto sub_group_id_sycl = host_data.create_sub_group_id_buffer(); queue .submit([&](sycl::handler& cgh) { @@ -410,16 +458,14 @@ void check_scan_over_group(sycl::queue& queue, sycl::range range, OpT op, sycl::accessor res_acc(res_sycl, cgh); sycl::accessor ret_type_acc(ret_type_sycl, cgh); sycl::accessor local_id_acc(local_id_sycl, cgh); + sycl::accessor sub_group_id_acc(sub_group_id_acc, cgh); cgh.parallel_for>( sycl::nd_range(range, range), [=](sycl::nd_item item) { sycl::group group = item.get_group(); sycl::sub_group sub_group = item.get_sub_group(); - // Use the global linear id of the item in the group to place - // results of the scan operation in the order of the items. auto g_index = item.get_global_linear_id(); - local_id_acc[g_index] = group.get_local_linear_id(); auto res_g_e = exclusive_scan_over_group_helper( group, ref_input_acc[g_index], op, with_init); @@ -431,8 +477,14 @@ void check_scan_over_group(sycl::queue& queue, sycl::range range, OpT op, res_acc[range_size + g_index] = res_g_i; ret_type_acc[1] = std::is_same_v; - local_id_acc[range_size + g_index] = - sub_group.get_local_linear_id(); + // Input data is indexed by global linear id of item (g_index), + // however, sub-group partitioning and ordering are + // implementation-defined. + // Here we store both the sub-group id and item linear id within + // the sub-group so that we could recover the sub-group + // construction when verifying. + sub_group_id_acc[g_index] = sub_group.get_group_linear_id(); + local_id_acc[g_index] = sub_group.get_local_linear_id(); auto res_sg_e = exclusive_scan_over_group_helper( sub_group, ref_input_acc[g_index], op, with_init);