diff --git a/include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl_radix_sort.h b/include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl_radix_sort.h index b46fb50c831..b625d3ffd18 100644 --- a/include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl_radix_sort.h +++ b/include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl_radix_sort.h @@ -20,6 +20,7 @@ #include #include #include +#include #include "sycl_defs.h" #include "parallel_backend_sycl_utils.h" @@ -778,6 +779,9 @@ __parallel_radix_sort(oneapi::dpl::__internal::__device_backend_tag, _ExecutionP //TODO: 1.to reduce number of the kernels; 2.to define work group size in runtime, depending on number of elements constexpr auto __wg_size = 64; + const auto __subgroup_sizes = __exec.queue().get_device().template get_info(); + const bool __dev_has_sg16 = std::find(__subgroup_sizes.begin(), __subgroup_sizes.end(), + static_cast(16)) != __subgroup_sizes.end(); //TODO: with _RadixSortKernel also the following a couple of compile time constants is used for unique kernel name using _RadixSortKernel = oneapi::dpl::__internal::__policy_kernel_name<_ExecutionPolicy>; @@ -803,10 +807,15 @@ __parallel_radix_sort(oneapi::dpl::__internal::__device_backend_tag, _ExecutionP else if (__n <= 4096 && __wg_size * 4 <= __max_wg_size) __event = __subgroup_radix_sort<_RadixSortKernel, __wg_size * 4, 16, __radix_bits, __is_ascending>{}( __exec.queue(), ::std::forward<_Range>(__in_rng), __proj); - else if (__n <= 8192 && __wg_size * 8 <= __max_wg_size) + // In __subgroup_radix_sort, we request a sub-group size via _ONEDPL_SYCL_REQD_SUB_GROUP_SIZE_IF_SUPPORTED + // based upon the iters per item. For the below cases, register spills that result in runtime exceptions have + // been observed on accelerators that do not support the requested sub-group size of 16. For the above cases + // that request but may not receive a sub-group size of 16, inputs are small enough to avoid register + // spills on assessed hardware. + else if (__n <= 8192 && __wg_size * 8 <= __max_wg_size && __dev_has_sg16) __event = __subgroup_radix_sort<_RadixSortKernel, __wg_size * 8, 16, __radix_bits, __is_ascending>{}( __exec.queue(), ::std::forward<_Range>(__in_rng), __proj); - else if (__n <= 16384 && __wg_size * 8 <= __max_wg_size) + else if (__n <= 16384 && __wg_size * 8 <= __max_wg_size && __dev_has_sg16) __event = __subgroup_radix_sort<_RadixSortKernel, __wg_size * 8, 32, __radix_bits, __is_ascending>{}( __exec.queue(), ::std::forward<_Range>(__in_rng), __proj); else