diff --git a/include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl_merge.h b/include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl_merge.h
index 3a2c893bf15..cc34a8144d8 100644
--- a/include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl_merge.h
+++ b/include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl_merge.h
@@ -336,17 +336,24 @@ struct __parallel_merge_submitter_large<_IdType, _CustomName,
         _IdType __chunk = __exec.queue().get_device().is_cpu() ? 128 : __data_items_in_slm_bank;
         assert(__chunk > 0);
 
-        // Pessimistically only use 2/3 of the memory to take into account memory used by compiled kernel
-        const auto __slm_adjusted_work_group_size = oneapi::dpl::__internal::__slm_adjusted_work_group_size(__exec, sizeof(_RangeValueType));
-        const auto __slm_adjusted_work_group_size_x_part = __slm_adjusted_work_group_size * 4 / 5;
+        // Get the size of local memory arena in bytes.
+        const std::size_t __slm_mem_size = __exec.queue().get_device().template get_info<sycl::info::device::local_mem_size>();
 
-        // The amount of data must be a multiple of the chunk size.
-        const std::size_t __max_source_data_items_fit_into_slm = __slm_adjusted_work_group_size_x_part - __slm_adjusted_work_group_size_x_part % __chunk;
-        assert(__max_source_data_items_fit_into_slm > 0);
-        assert(__max_source_data_items_fit_into_slm % __chunk == 0);
+        // Pessimistically only use 4/5 of the memory to take into account memory used by the compiled kernel
+        const std::size_t __slm_mem_size_x_part = __slm_mem_size * 4 / 5;
+
+        // Calculate how many items we may place into SLM memory
+        const auto __slm_cached_items_count = __slm_mem_size_x_part / sizeof(_RangeValueType);
 
         // The amount of items in the each work-group is the amount of diagonals processing between two work-groups + 1 (for the left base diagonal in work-group)
-        const std::size_t __wi_in_one_wg = __max_source_data_items_fit_into_slm / __chunk;
+        std::size_t __wi_in_one_wg = __slm_cached_items_count / __chunk;
+        const std::size_t __max_wi_in_one_wg = __exec.queue().get_device().template get_info<sycl::info::device::max_work_item_sizes<1>>()[0];
+        if (__wi_in_one_wg > __max_wi_in_one_wg)
+        {
+            __chunk = oneapi::dpl::__internal::__dpl_ceiling_div(__slm_cached_items_count, __max_wi_in_one_wg);
+            __wi_in_one_wg = __slm_cached_items_count / __chunk;
+            assert(__wi_in_one_wg <= __max_wi_in_one_wg);
+        }
         assert(__wi_in_one_wg > 0);
 
         // The amount of the base diagonals is the amount of the work-groups
@@ -396,7 +403,7 @@ struct __parallel_merge_submitter_large<_IdType, _CustomName,
             auto __base_diagonals_sp_global_ptr =
                 __base_diagonals_sp_storage_t::__get_usm_or_buffer_accessor_ptr(__base_diagonals_sp_global_acc);
             const std::size_t __slm_cached_data_size = __wi_in_one_wg * __chunk;
-            __dpl_sycl::__local_accessor<_RangeValueType> __loc_acc(2 * __slm_cached_data_size, __cgh);
+            __dpl_sycl::__local_accessor<_RangeValueType> __loc_acc(__slm_cached_data_size, __cgh);
 
             // Run nd_range parallel_for to process all the data
             // - each work-group caching source data in SLM and processing diagonals between two base diagonals;
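
For context outside the oneDPL tree, below is a minimal, self-contained sketch of the work-group sizing logic this patch introduces. It is illustrative only: the `choose_chunk_and_wg_size` helper, its signature, and the inlined ceiling division are assumptions made for the example (the patch itself runs this logic inline in the submitter and uses `oneapi::dpl::__internal::__dpl_ceiling_div`); variable names mirror the patch so the two are easy to compare, and a SYCL 2020 implementation is assumed.

```cpp
// Sketch only: mirrors the sizing logic added by this patch, not oneDPL code itself.
#include <sycl/sycl.hpp>

#include <cassert>
#include <cstddef>
#include <utility>

// Returns {chunk size, work-items per work-group} for a given element type.
template <typename _ValueT>
std::pair<std::size_t, std::size_t>
choose_chunk_and_wg_size(const sycl::device& __dev, std::size_t __chunk)
{
    // Size of the local memory (SLM) arena in bytes.
    const std::size_t __slm_mem_size = __dev.get_info<sycl::info::device::local_mem_size>();

    // Pessimistically budget only 4/5 of it for cached source data; the rest is
    // left for local memory the compiled kernel may need for its own purposes.
    const std::size_t __slm_mem_size_x_part = __slm_mem_size * 4 / 5;

    // How many source items fit into that budget.
    const std::size_t __slm_cached_items_count = __slm_mem_size_x_part / sizeof(_ValueT);

    // Each work-item processes one chunk, so this is the tentative work-group size.
    std::size_t __wi_in_one_wg = __slm_cached_items_count / __chunk;

    // If that exceeds the device limit, grow the chunk instead of shrinking the
    // amount of cached data, then recompute the work-group size.
    const std::size_t __max_wi_in_one_wg =
        __dev.get_info<sycl::info::device::max_work_item_sizes<1>>()[0];
    if (__wi_in_one_wg > __max_wi_in_one_wg)
    {
        // Plain ceiling division, standing in for __dpl_ceiling_div.
        __chunk = (__slm_cached_items_count + __max_wi_in_one_wg - 1) / __max_wi_in_one_wg;
        __wi_in_one_wg = __slm_cached_items_count / __chunk;
        assert(__wi_in_one_wg <= __max_wi_in_one_wg);
    }
    assert(__wi_in_one_wg > 0);

    // The kernel then requests a single SLM arena of __wi_in_one_wg * __chunk
    // elements (the second hunk drops the previous 2x allocation).
    return {__chunk, __wi_in_one_wg};
}
```

The point the sketch tries to show: when the SLM budget would imply more work-items than the device allows, the patch keeps the full budget and enlarges `__chunk` instead, so the amount of data cached per work-group stays tied to the available local memory rather than to the work-group size limit.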