Resolve register spills in dispatch of __subgroup_radix_sort (#1626)
This PR resolves register spills that lead to runtime crashes on certain accelerators in single work-group radix sort by checking for sub-group size 16 support before dispatching __subgroup_radix_sort.
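
The support check described above can be sketched in plain C++ without a SYCL runtime. This is a minimal illustration, not the library's code: `has_sub_group_size` is a hypothetical helper name, and the `supported` vector stands in for the list a device reports through the `sycl::info::device::sub_group_sizes` query.

```cpp
#include <algorithm>
#include <cstddef>
#include <vector>

// Hypothetical helper (not part of oneDPL): returns true when `wanted`
// appears in the list of sub-group sizes a device reports as supported.
// `supported` is an assumption standing in for the result of
// device.get_info<sycl::info::device::sub_group_sizes>().
inline bool has_sub_group_size(const std::vector<std::size_t>& supported, std::size_t wanted)
{
    return std::find(supported.begin(), supported.end(), wanted) != supported.end();
}
```

With a device reporting `{8, 16, 32}` the helper returns true for 16; with `{32, 64}` it returns false, which is the case the commit guards against.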

---------

Signed-off-by: Matthew Michel <[email protected]>
mmichel11 authored Jun 20, 2024
1 parent 224acf7 commit c36337d
Showing 1 changed file with 11 additions and 2 deletions.
@@ -20,6 +20,7 @@
 #include <type_traits>
 #include <utility>
 #include <cstdint>
+#include <algorithm>

 #include "sycl_defs.h"
 #include "parallel_backend_sycl_utils.h"
@@ -778,6 +779,9 @@ __parallel_radix_sort(oneapi::dpl::__internal::__device_backend_tag, _ExecutionP

     //TODO: 1.to reduce number of the kernels; 2.to define work group size in runtime, depending on number of elements
     constexpr auto __wg_size = 64;
+    const auto __subgroup_sizes = __exec.queue().get_device().template get_info<sycl::info::device::sub_group_sizes>();
+    const bool __dev_has_sg16 = std::find(__subgroup_sizes.begin(), __subgroup_sizes.end(),
+                                          static_cast<std::size_t>(16)) != __subgroup_sizes.end();

     //TODO: with _RadixSortKernel also the following a couple of compile time constants is used for unique kernel name
     using _RadixSortKernel = oneapi::dpl::__internal::__policy_kernel_name<_ExecutionPolicy>;
@@ -803,10 +807,15 @@ __parallel_radix_sort(oneapi::dpl::__internal::__device_backend_tag, _ExecutionP
     else if (__n <= 4096 && __wg_size * 4 <= __max_wg_size)
         __event = __subgroup_radix_sort<_RadixSortKernel, __wg_size * 4, 16, __radix_bits, __is_ascending>{}(
             __exec.queue(), ::std::forward<_Range>(__in_rng), __proj);
-    else if (__n <= 8192 && __wg_size * 8 <= __max_wg_size)
+    // In __subgroup_radix_sort, we request a sub-group size via _ONEDPL_SYCL_REQD_SUB_GROUP_SIZE_IF_SUPPORTED
+    // based upon the iters per item. For the below cases, register spills that result in runtime exceptions have
+    // been observed on accelerators that do not support the requested sub-group size of 16. For the above cases
+    // that request but may not receive a sub-group size of 16, inputs are small enough to avoid register
+    // spills on assessed hardware.
+    else if (__n <= 8192 && __wg_size * 8 <= __max_wg_size && __dev_has_sg16)
         __event = __subgroup_radix_sort<_RadixSortKernel, __wg_size * 8, 16, __radix_bits, __is_ascending>{}(
             __exec.queue(), ::std::forward<_Range>(__in_rng), __proj);
-    else if (__n <= 16384 && __wg_size * 8 <= __max_wg_size)
+    else if (__n <= 16384 && __wg_size * 8 <= __max_wg_size && __dev_has_sg16)
         __event = __subgroup_radix_sort<_RadixSortKernel, __wg_size * 8, 32, __radix_bits, __is_ascending>{}(
             __exec.queue(), ::std::forward<_Range>(__in_rng), __proj);
     else
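
To make the effect of the new guard easier to see, the branch conditions in the last hunk can be modeled as a pure function. This is a hedged sketch covering only the three branches visible in the diff: the branches before `__n <= 4096` and the body of the final `else` are elided in the diff and are not reproduced here. `Path` and `choose_path` are hypothetical names introduced for illustration; the enumerator names simply record the work-group size and the third template argument passed to `__subgroup_radix_sort`.

```cpp
#include <cstddef>

// Hypothetical labels for the branches visible in the diff.
// "Fallback" stands for whatever the final `else` does, which the diff does not show.
enum class Path { Wg256Arg16, Wg512Arg16, Wg512Arg32, Fallback };

// Models the post-commit conditions: the two larger cases are taken only
// when the device reports support for sub-group size 16.
inline Path choose_path(std::size_t n, std::size_t max_wg_size, bool dev_has_sg16)
{
    constexpr std::size_t wg_size = 64; // mirrors __wg_size in the diff

    if (n <= 4096 && wg_size * 4 <= max_wg_size)
        return Path::Wg256Arg16; // unchanged by the commit: no sub-group size check
    if (n <= 8192 && wg_size * 8 <= max_wg_size && dev_has_sg16)
        return Path::Wg512Arg16;
    if (n <= 16384 && wg_size * 8 <= max_wg_size && dev_has_sg16)
        return Path::Wg512Arg32;
    return Path::Fallback;
}
```

Under this model, an input of 8000 elements on a device without sub-group size 16 now falls through to the final `else` instead of the single work-group kernel, while inputs of 4096 elements or fewer are unaffected by the check.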