Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reduce the number of kernels to compile in merge-sort #1872

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -202,67 +202,66 @@ struct __leaf_sorter
_GroupSorter __group_sorter;
};

// Please see the comment for __parallel_for_submitter for optional kernel name explanation
// NOTE(review): this span is a rendered diff — removed (pre-change) and added (post-change)
// lines are interleaved with no +/- markers. Pre-change code is recognizable by the
// __exec / _IdType / __event1 naming; post-change code by __q / __event_chain. Do not
// treat this text as directly compilable.
template <typename _IdType, typename _LeafSortName, typename _GlobalSortName, typename _CopyBackName>
struct __parallel_sort_submitter;

// Removed: the monolithic submitter was parameterized by all three kernel names plus _IdType.
template <typename _IdType, typename... _LeafSortName, typename... _GlobalSortName, typename... _CopyBackName>
struct __parallel_sort_submitter<_IdType, __internal::__optional_kernel_name<_LeafSortName...>,
__internal::__optional_kernel_name<_GlobalSortName...>,
__internal::__optional_kernel_name<_CopyBackName...>>
// Added: stage-1 (leaf sort) submitter parameterized only by its own kernel name, so the
// leaf kernel name no longer carries the index type (matches the PR goal of compiling
// fewer kernel instantiations).
template <typename _LeafSortName>
struct __merge_sort_leaf_submitter;

template <typename... _LeafSortName>
struct __merge_sort_leaf_submitter<__internal::__optional_kernel_name<_LeafSortName...>>
{
// Submits the kernel that sorts each leaf (a chunk of __leaf_sorter.__process_size
// elements) of the merge-sort tree, one work-group per leaf, and returns its event.
template <typename _ExecutionPolicy, typename _Range, typename _Compare, typename _LeafSorter>
auto
operator()(_ExecutionPolicy&& __exec, _Range&& __rng, _Compare __comp, _LeafSorter& __leaf_sorter) const
template <typename _Range, typename _Compare, typename _LeafSorter>
sycl::event
operator()(sycl::queue& __q, _Range& __rng, _Compare __comp, _LeafSorter& __leaf_sorter) const
{
// Removed prologue: the size / power-of-two asserts below were hoisted to the new
// __merge_sort entry point (see assertions there) in the post-change code.
using _Tp = oneapi::dpl::__internal::__value_t<_Range>;

const std::size_t __n = __rng.size();
assert(__n > 1);

const std::uint32_t __leaf = __leaf_sorter.__process_size;
assert((__leaf & (__leaf - 1)) == 0 && "Leaf size must be a power of 2");

// 1. Perform sorting of the leaves of the merge sort tree
sycl::event __event1 = __exec.queue().submit([&](sycl::handler& __cgh) {
return __q.submit([&](sycl::handler& __cgh) {
oneapi::dpl::__ranges::__require_access(__cgh, __rng);
auto __storage_acc = __leaf_sorter.create_storage_accessor(__cgh);
const std::uint32_t __wg_count = oneapi::dpl::__internal::__dpl_ceiling_div(__n, __leaf);
// One work-group per leaf: wg_count = ceil(range size / elements processed per group).
const std::uint32_t __wg_count =
oneapi::dpl::__internal::__dpl_ceiling_div(__rng.size(), __leaf_sorter.__process_size);
const sycl::nd_range<1> __nd_range(sycl::range<1>(__wg_count * __leaf_sorter.__workgroup_size),
sycl::range<1>(__leaf_sorter.__workgroup_size));
__cgh.parallel_for<_LeafSortName...>(
__nd_range, [=](sycl::nd_item<1> __item) { __leaf_sorter.sort(__item, __storage_acc); });
});
}
};

// 2. Merge sorting
// NOTE(review): the six lines below are removed (pre-change) code — the setup of the
// temporary buffer and loop state that the new code moves into __merge_sort and into
// the operator() parameters of the submitter declared next.
oneapi::dpl::__par_backend_hetero::__buffer<_ExecutionPolicy, _Tp> __temp_buf(__exec, __n);
auto __temp = __temp_buf.get_buffer();
bool __data_in_temp = false;
_IdType __n_sorted = __leaf;
const bool __is_cpu = __exec.queue().get_device().is_cpu();
// Added: stage-2 (global merge) submitter. It keeps _IndexT in its kernel name because the
// merge kernel genuinely depends on the index type; the other two stages no longer do.
template <typename _IndexT, typename _GlobalSortName>
struct __merge_sort_global_submitter;

template <typename _IndexT, typename... _GlobalSortName>
struct __merge_sort_global_submitter<_IndexT, __internal::__optional_kernel_name<_GlobalSortName...>>
{
// Repeatedly merges pairs of sorted runs of length __n_sorted (starting at __leaf_size and
// doubling each iteration), ping-ponging data between __rng and __temp_buf. On return,
// __data_in_temp tells the caller which buffer holds the final result, and the returned
// event is the tail of the dependency chain.
template <typename _Range, typename _Compare, typename _TempBuf, typename _LeafSizeT>
sycl::event
operator()(sycl::queue& __q, _Range& __rng, _Compare __comp, _LeafSizeT __leaf_size, _TempBuf& __temp_buf,
bool& __data_in_temp, sycl::event __event_chain) const
{
const _IndexT __n = __rng.size();
_IndexT __n_sorted = __leaf_size;
const bool __is_cpu = __q.get_device().is_cpu();
// Per-work-item chunk size tuned by device type — TODO confirm the 32/4 values are
// empirically chosen; no derivation is visible here.
const std::uint32_t __chunk = __is_cpu ? 32 : 4;
const std::size_t __steps = oneapi::dpl::__internal::__dpl_ceiling_div(__n, __chunk);

const std::size_t __n_power2 = oneapi::dpl::__internal::__dpl_bit_ceil(__n);
// ctz precisely calculates log2 of an integral value which is a power of 2, while
// std::log2 may be prone to rounding errors on some architectures
const std::int64_t __n_iter = sycl::ctz(__n_power2) - sycl::ctz(__leaf);
const std::int64_t __n_iter = sycl::ctz(__n_power2) - sycl::ctz(__leaf_size);
for (std::int64_t __i = 0; __i < __n_iter; ++__i)
{
// __n_sorted and __data_in_temp are captured by value: the kernel must see this
// iteration's state, while the host-side variables advance below.
__event1 = __exec.queue().submit([&, __event1, __n_sorted, __data_in_temp](sycl::handler& __cgh) {
__cgh.depends_on(__event1);
__event_chain = __q.submit([&, __event_chain, __n_sorted, __data_in_temp](sycl::handler& __cgh) {
__cgh.depends_on(__event_chain);

oneapi::dpl::__ranges::__require_access(__cgh, __rng);
sycl::accessor __dst(__temp, __cgh, sycl::read_write, sycl::no_init);
sycl::accessor __dst(__temp_buf, __cgh, sycl::read_write, sycl::no_init);

__cgh.parallel_for<_GlobalSortName...>(
sycl::range</*dim=*/1>(__steps), [=](sycl::item</*dim=*/1> __item_id) {
const _IdType __i_elem = __item_id.get_linear_id() * __chunk;
const _IndexT __i_elem = __item_id.get_linear_id() * __chunk;
const auto __i_elem_local = __i_elem % (__n_sorted * 2);

// __offset: start of the pair of runs this item works on; __n1/__n2: the
// (possibly truncated at __n) lengths of the two runs being merged.
const auto __offset = std::min<_IdType>(__i_elem - __i_elem_local, __n);
const auto __n1 = std::min<_IdType>(__offset + __n_sorted, __n) - __offset;
const auto __n2 = std::min<_IdType>(__offset + __n1 + __n_sorted, __n) - (__offset + __n1);
const auto __offset = std::min<_IndexT>(__i_elem - __i_elem_local, __n);
const auto __n1 = std::min<_IndexT>(__offset + __n_sorted, __n) - __offset;
const auto __n2 = std::min<_IndexT>(__offset + __n1 + __n_sorted, __n) - (__offset + __n1);

if (__data_in_temp)
{
// NOTE(review): the diff fold marker below hides the collapsed merge-kernel body
// (original lines ~287-309); the direction-dependent merge calls are not visible here.
Expand All @@ -287,23 +286,32 @@ struct __parallel_sort_submitter<_IdType, __internal::__optional_kernel_name<_Le
// Runs are twice as long next pass, and the roles of __rng / __temp_buf swap.
__n_sorted *= 2;
__data_in_temp = !__data_in_temp;
}
return __event_chain;
}
};

// 3. If the data remained in the temporary buffer then copy it back
// NOTE(review): the block below down to the stray `return __future(__event1);` line is
// removed (pre-change) code — the copy-back step inlined in the old monolithic submitter.
if (__data_in_temp)
{
__event1 = __exec.queue().submit([&, __event1](sycl::handler& __cgh) {
__cgh.depends_on(__event1);
oneapi::dpl::__ranges::__require_access(__cgh, __rng);
auto __temp_acc = __temp.template get_access<access_mode::read>(__cgh);
// We cannot use __cgh.copy here because of zip_iterator usage
__cgh.parallel_for<_CopyBackName...>(sycl::range</*dim=*/1>(__n), [=](sycl::item</*dim=*/1> __item_id) {
const _IdType __idx = __item_id.get_linear_id();
__rng[__idx] = __temp_acc[__idx];
});
});
}
// Added: stage-3 (copy-back) submitter, parameterized only by its kernel name — note the
// new kernel indexes with std::uint64_t, so the kernel name no longer needs _IndexT.
template <typename _CopyBackName>
struct __merge_sort_copy_back_submitter;

// (removed line from the old submitter's epilogue, interleaved here by the diff view)
return __future(__event1);
template <typename... _CopyBackName>
struct __merge_sort_copy_back_submitter<__internal::__optional_kernel_name<_CopyBackName...>>
{
// Copies the sorted data from __temp_buf back into __rng element-by-element, chained
// after __event_chain; returns the event of the copy kernel.
template <typename _Range, typename _TempBuf>
sycl::event
operator()(sycl::queue& __q, _Range& __rng, _TempBuf& __temp_buf, sycl::event __event_chain) const
{
__event_chain = __q.submit([&, __event_chain](sycl::handler& __cgh) {
__cgh.depends_on(__event_chain);
oneapi::dpl::__ranges::__require_access(__cgh, __rng);
auto __temp_acc = __temp_buf.template get_access<access_mode::read>(__cgh);
// We cannot use __cgh.copy here because of zip_iterator usage
__cgh.parallel_for<_CopyBackName...>(sycl::range</*dim=*/1>(__rng.size()),
[=](sycl::item</*dim=*/1> __item_id) {
const std::uint64_t __idx = __item_id.get_linear_id();
__rng[__idx] = __temp_acc[__idx];
});
});
return __event_chain;
}
};

Expand All @@ -316,22 +324,50 @@ class __sort_global_kernel;
template <typename... _Name>
class __sort_copy_back_kernel;

// NOTE(review): diff-interleaved span. The old __submit_selecting_leaf signature appears
// alongside the new __merge_sort entry point that replaces the removed
// __parallel_sort_submitter by chaining the three stage submitters.
template <typename _IndexT, typename _ExecutionPolicy, typename _Range, typename _Compare>
template <typename _IndexT, typename _ExecutionPolicy, typename _Range, typename _Compare, typename _LeafSorter>
auto
__submit_selecting_leaf(_ExecutionPolicy&& __exec, _Range&& __rng, _Compare __comp)
__merge_sort(_ExecutionPolicy&& __exec, _Range&& __rng, _Compare __comp, _LeafSorter& __leaf_sorter)
{
using _Leaf = __leaf_sorter<std::decay_t<_Range>, _Compare>;
using _Tp = oneapi::dpl::__internal::__value_t<_Range>;

using _CustomName = oneapi::dpl::__internal::__policy_kernel_name<_ExecutionPolicy>;
// Removed: the TODO below is what this change implements — leaf and copy-back kernel
// names previously carried _IndexT, forcing one instantiation per index type.
// TODO: split the submitter into multiple ones to avoid extra compilation of kernels:
// - _LeafSortKernel and _CopyBackKernel do not need _IndexT
using _LeafSortKernel = oneapi::dpl::__par_backend_hetero::__internal::__kernel_name_provider<
__sort_leaf_kernel<_CustomName, _IndexT>>;
// Added: leaf and copy-back kernel names now depend only on _CustomName; only the
// global-sort kernel keeps _IndexT in its name.
using _LeafSortKernel =
oneapi::dpl::__par_backend_hetero::__internal::__kernel_name_provider<__sort_leaf_kernel<_CustomName>>;
using _GlobalSortKernel = oneapi::dpl::__par_backend_hetero::__internal::__kernel_name_provider<
__sort_global_kernel<_CustomName, _IndexT>>;
using _CopyBackKernel = oneapi::dpl::__par_backend_hetero::__internal::__kernel_name_provider<
__sort_copy_back_kernel<_CustomName, _IndexT>>;
using _CopyBackKernel =
oneapi::dpl::__par_backend_hetero::__internal::__kernel_name_provider<__sort_copy_back_kernel<_CustomName>>;

// Preconditions hoisted from the old submitter body: non-trivial input and a
// power-of-two leaf size.
assert(__rng.size() > 1);
assert((__leaf_sorter.__process_size & (__leaf_sorter.__process_size - 1)) == 0 &&
"Leaf size must be a power of 2");

auto __q = __exec.queue();

// 1. Perform sorting of the leaves of the merge sort tree
auto __event_chain = __merge_sort_leaf_submitter<_LeafSortKernel>()(__q, __rng, __comp, __leaf_sorter);

// 2. Merge sorting
oneapi::dpl::__par_backend_hetero::__buffer<_ExecutionPolicy, _Tp> __temp(__exec, __rng.size());
auto __temp_buf = __temp.get_buffer();
bool __data_in_temp = false;
__event_chain = __merge_sort_global_submitter<_IndexT, _GlobalSortKernel>()(
__q, __rng, __comp, __leaf_sorter.__process_size, __temp_buf, __data_in_temp, __event_chain);

// 3. If the data remained in the temporary buffer then copy it back
if (__data_in_temp)
{
__event_chain = __merge_sort_copy_back_submitter<_CopyBackKernel>()(__q, __rng, __temp_buf, __event_chain);
}
return __future(__event_chain);
}

// Selects leaf-sorter parameters (work-group size, data per work-item) for the device, then
// dispatches to __merge_sort. NOTE(review): the fold marker below hides ~40 collapsed lines
// (original ~247-375) containing the selection logic — __data_per_workitem and the initial
// __wg_size are computed in the hidden region.
template <typename _IndexT, typename _ExecutionPolicy, typename _Range, typename _Compare>
auto
__submit_selecting_leaf(_ExecutionPolicy&& __exec, _Range&& __rng, _Compare __comp)
{
using _Leaf = __leaf_sorter<std::decay_t<_Range>, _Compare>;
using _Tp = oneapi::dpl::__internal::__value_t<_Range>;

const std::size_t __n = __rng.size();
auto __device = __exec.queue().get_device();
Expand Down Expand Up @@ -376,8 +412,7 @@ __submit_selecting_leaf(_ExecutionPolicy&& __exec, _Range&& __rng, _Compare __co
// Round the work-group size down to a power of two (leaf size must be a power of two).
__wg_size = oneapi::dpl::__internal::__dpl_bit_floor(__wg_size);

_Leaf __leaf(__rng, __comp, __data_per_workitem, __wg_size);
// Removed call into the old monolithic submitter vs. the added call into __merge_sort.
return __parallel_sort_submitter<_IndexT, _LeafSortKernel, _GlobalSortKernel, _CopyBackKernel>()(
std::forward<_ExecutionPolicy>(__exec), std::forward<_Range>(__rng), __comp, __leaf);
return __merge_sort<_IndexT>(std::forward<_ExecutionPolicy>(__exec), std::forward<_Range>(__rng), __comp, __leaf);
};

template <typename _ExecutionPolicy, typename _Range, typename _Compare>
Expand Down
Loading