From c1da480968510da79e1cc0f4510bafef8872ffd6 Mon Sep 17 00:00:00 2001 From: Sergey Kopienko Date: Fri, 20 Dec 2024 12:24:37 +0100 Subject: [PATCH] include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl.h - remove usage of __kernel_name_generator as not required Signed-off-by: Sergey Kopienko --- .../pstl/hetero/dpcpp/parallel_backend_sycl.h | 216 ++++++++++-------- 1 file changed, 116 insertions(+), 100 deletions(-) diff --git a/include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl.h b/include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl.h index 96d63e33aee..1366417c967 100644 --- a/include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl.h +++ b/include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl.h @@ -1837,100 +1837,46 @@ struct __parallel_find_or_nd_range_tuner -__FoundStateType -__parallel_find_or_impl_one_wg(oneapi::dpl::__internal::__device_backend_tag, _ExecutionPolicy&& __exec, - _BrickTag __brick_tag, const std::size_t __rng_n, const std::size_t __wgroup_size, - const __FoundStateType __init_value, _Predicate __pred, _Ranges&&... __rngs) -{ - using __result_and_scratch_storage_t = __result_and_scratch_storage<_ExecutionPolicy, __FoundStateType>; - __result_and_scratch_storage_t __result_storage{__exec, 1, 0}; - - // Calculate the number of elements to be processed by each work-item. - const auto __iters_per_work_item = oneapi::dpl::__internal::__dpl_ceiling_div(__rng_n, __wgroup_size); - - // main parallel_for - auto __event = __exec.queue().submit([&](sycl::handler& __cgh) { - oneapi::dpl::__ranges::__require_access(__cgh, __rngs...); - auto __result_acc = - __result_storage.template __get_result_acc(__cgh, __dpl_sycl::__no_init{}); - - __cgh.parallel_for( - sycl::nd_range(sycl::range(__wgroup_size), sycl::range(__wgroup_size)), - [=](sycl::nd_item __item_id) { - auto __local_idx = __item_id.get_local_id(0); - - // 1. Set initial value to local found state - __FoundStateType __found_local = __init_value; - - // 2. Find any element that satisfies pred - // - after this call __found_local may still have initial value: - // 1) if no element satisfies pred; - // 2) early exit from sub-group occurred: in this case the state of __found_local will updated in the next group operation (3) - __pred(__item_id, __rng_n, __iters_per_work_item, __wgroup_size, __found_local, __brick_tag, __rngs...); - - // 3. Reduce over group: find __dpl_sycl::__minimum (for the __parallel_find_forward_tag), - // find __dpl_sycl::__maximum (for the __parallel_find_backward_tag) - // or update state with __dpl_sycl::__any_of_group (for the __parallel_or_tag) - // inside all our group items - if constexpr (__or_tag_check) - __found_local = __dpl_sycl::__any_of_group(__item_id.get_group(), __found_local); - else - __found_local = __dpl_sycl::__reduce_over_group(__item_id.get_group(), __found_local, - typename _BrickTag::_LocalResultsReduceOp{}); - - // Set local found state value value to global state to have correct result - if (__local_idx == 0) - { - __result_and_scratch_storage_t::__get_usm_or_buffer_accessor_ptr(__result_acc)[0] = __found_local; - } - }); - }); - - // Wait and return result - return __result_storage.__wait_and_get_value(__event); -} +template +struct __parallel_find_or_impl_one_wg; // Base pattern for __parallel_or and __parallel_find. The execution depends on tag type _BrickTag. -template -_AtomicType -__parallel_find_or_impl_multiple_wgs(oneapi::dpl::__internal::__device_backend_tag, _ExecutionPolicy&& __exec, - _BrickTag __brick_tag, const std::size_t __rng_n, const std::size_t __n_groups, - const std::size_t __wgroup_size, const _AtomicType __init_value, _Predicate __pred, - _Ranges&&... __rngs) -{ - auto __result = __init_value; - - // Calculate the number of elements to be processed by each work-item. - const auto __iters_per_work_item = oneapi::dpl::__internal::__dpl_ceiling_div(__rng_n, __n_groups * __wgroup_size); - - // scope is to copy data back to __result after destruction of temporary sycl:buffer +template +struct __parallel_find_or_impl_one_wg<__internal::__optional_kernel_name, __or_tag_check> +{ + template + __FoundStateType + operator()(_ExecutionPolicy&& __exec, _BrickTag __brick_tag, const std::size_t __rng_n, + const std::size_t __wgroup_size, const __FoundStateType __init_value, _Predicate __pred, + _Ranges&&... __rngs) { - sycl::buffer<_AtomicType, 1> __result_sycl_buf(&__result, 1); // temporary storage for global atomic + using __result_and_scratch_storage_t = __result_and_scratch_storage<_ExecutionPolicy, __FoundStateType>; + __result_and_scratch_storage_t __result_storage{__exec, 1, 0}; + + // Calculate the number of elements to be processed by each work-item. + const auto __iters_per_work_item = oneapi::dpl::__internal::__dpl_ceiling_div(__rng_n, __wgroup_size); // main parallel_for - __exec.queue().submit([&](sycl::handler& __cgh) { + auto __event = __exec.queue().submit([&](sycl::handler& __cgh) { oneapi::dpl::__ranges::__require_access(__cgh, __rngs...); - auto __result_sycl_buf_acc = __result_sycl_buf.template get_access(__cgh); + auto __result_acc = + __result_storage.template __get_result_acc(__cgh, __dpl_sycl::__no_init{}); - __cgh.parallel_for( - sycl::nd_range(sycl::range(__n_groups * __wgroup_size), - sycl::range(__wgroup_size)), + __cgh.parallel_for( + sycl::nd_range(sycl::range(__wgroup_size), sycl::range(__wgroup_size)), [=](sycl::nd_item __item_id) { auto __local_idx = __item_id.get_local_id(0); // 1. Set initial value to local found state - _AtomicType __found_local = __init_value; + __FoundStateType __found_local = __init_value; // 2. Find any element that satisfies pred // - after this call __found_local may still have initial value: // 1) if no element satisfies pred; // 2) early exit from sub-group occurred: in this case the state of __found_local will updated in the next group operation (3) - __pred(__item_id, __rng_n, __iters_per_work_item, __n_groups * __wgroup_size, __found_local, - __brick_tag, __rngs...); + __pred(__item_id, __rng_n, __iters_per_work_item, __wgroup_size, __found_local, __brick_tag, + __rngs...); // 3. Reduce over group: find __dpl_sycl::__minimum (for the __parallel_find_forward_tag), // find __dpl_sycl::__maximum (for the __parallel_find_backward_tag) @@ -1942,22 +1888,92 @@ __parallel_find_or_impl_multiple_wgs(oneapi::dpl::__internal::__device_backend_t __found_local = __dpl_sycl::__reduce_over_group(__item_id.get_group(), __found_local, typename _BrickTag::_LocalResultsReduceOp{}); - // Set local found state value value to global atomic - if (__local_idx == 0 && __found_local != __init_value) + // Set local found state value value to global state to have correct result + if (__local_idx == 0) { - __dpl_sycl::__atomic_ref<_AtomicType, sycl::access::address_space::global_space> __found( - *__dpl_sycl::__get_accessor_ptr(__result_sycl_buf_acc)); - - // Update global (for all groups) atomic state with the found index - _BrickTag::__save_state_to_atomic(__found, __found_local); + __result_and_scratch_storage_t::__get_usm_or_buffer_accessor_ptr(__result_acc)[0] = + __found_local; } }); }); - //The end of the scope - a point of synchronization (on temporary sycl buffer destruction) + + // Wait and return result + return __result_storage.__wait_and_get_value(__event); } +}; - return __result; -} +template +struct __parallel_find_or_impl_multiple_wgs; + +// Base pattern for __parallel_or and __parallel_find. The execution depends on tag type _BrickTag. +template +struct __parallel_find_or_impl_multiple_wgs<__internal::__optional_kernel_name, __or_tag_check> +{ + template + _AtomicType + operator()(_ExecutionPolicy&& __exec, _BrickTag __brick_tag, + const std::size_t __rng_n, const std::size_t __n_groups, const std::size_t __wgroup_size, + const _AtomicType __init_value, _Predicate __pred, _Ranges&&... __rngs) + { + auto __result = __init_value; + + // Calculate the number of elements to be processed by each work-item. + const auto __iters_per_work_item = + oneapi::dpl::__internal::__dpl_ceiling_div(__rng_n, __n_groups * __wgroup_size); + + // scope is to copy data back to __result after destruction of temporary sycl:buffer + { + sycl::buffer<_AtomicType, 1> __result_sycl_buf(&__result, 1); // temporary storage for global atomic + + // main parallel_for + __exec.queue().submit([&](sycl::handler& __cgh) { + oneapi::dpl::__ranges::__require_access(__cgh, __rngs...); + auto __result_sycl_buf_acc = __result_sycl_buf.template get_access(__cgh); + + __cgh.parallel_for( + sycl::nd_range(sycl::range(__n_groups * __wgroup_size), + sycl::range(__wgroup_size)), + [=](sycl::nd_item __item_id) { + auto __local_idx = __item_id.get_local_id(0); + + // 1. Set initial value to local found state + _AtomicType __found_local = __init_value; + + // 2. Find any element that satisfies pred + // - after this call __found_local may still have initial value: + // 1) if no element satisfies pred; + // 2) early exit from sub-group occurred: in this case the state of __found_local will updated in the next group operation (3) + __pred(__item_id, __rng_n, __iters_per_work_item, __n_groups * __wgroup_size, __found_local, + __brick_tag, __rngs...); + + // 3. Reduce over group: find __dpl_sycl::__minimum (for the __parallel_find_forward_tag), + // find __dpl_sycl::__maximum (for the __parallel_find_backward_tag) + // or update state with __dpl_sycl::__any_of_group (for the __parallel_or_tag) + // inside all our group items + if constexpr (__or_tag_check) + __found_local = __dpl_sycl::__any_of_group(__item_id.get_group(), __found_local); + else + __found_local = __dpl_sycl::__reduce_over_group( + __item_id.get_group(), __found_local, typename _BrickTag::_LocalResultsReduceOp{}); + + // Set local found state value value to global atomic + if (__local_idx == 0 && __found_local != __init_value) + { + __dpl_sycl::__atomic_ref<_AtomicType, sycl::access::address_space::global_space> __found( + *__dpl_sycl::__get_accessor_ptr(__result_sycl_buf_acc)); + + // Update global (for all groups) atomic state with the found index + _BrickTag::__save_state_to_atomic(__found, __found_local); + } + }); + }); + //The end of the scope - a point of synchronization (on temporary sycl buffer destruction) + } + + return __result; + } +}; // Base pattern for __parallel_or and __parallel_find. The execution depends on tag type _BrickTag. template @@ -1968,12 +1984,6 @@ __parallel_find_or(oneapi::dpl::__internal::__device_backend_tag, _ExecutionPoli _BrickTag __brick_tag, _Ranges&&... __rngs) { using _CustomName = oneapi::dpl::__internal::__policy_kernel_name<_ExecutionPolicy>; - using _FindOrKernelOneWG = - oneapi::dpl::__par_backend_hetero::__internal::__kernel_name_generator<__find_or_kernel_one_wg, _CustomName, - _Brick, _BrickTag, _Ranges...>; - using _FindOrKernel = - oneapi::dpl::__par_backend_hetero::__internal::__kernel_name_generator<__find_or_kernel, _CustomName, _Brick, - _BrickTag, _Ranges...>; auto __rng_n = oneapi::dpl::__ranges::__get_first_range_size(__rngs...); assert(__rng_n > 0); @@ -1996,20 +2006,26 @@ __parallel_find_or(oneapi::dpl::__internal::__device_backend_tag, _ExecutionPoli // We shouldn't have any restrictions for _AtomicType type here // because we have a single work-group and we don't need to use atomics for inter-work-group communication. + using _KernelName = + oneapi::dpl::__par_backend_hetero::__internal::__kernel_name_provider<__find_or_kernel_one_wg<_CustomName>>; + // Single WG implementation - __result = __parallel_find_or_impl_one_wg<_FindOrKernelOneWG, __or_tag_check>( - oneapi::dpl::__internal::__device_backend_tag{}, std::forward<_ExecutionPolicy>(__exec), __brick_tag, - __rng_n, __wgroup_size, __init_value, __pred, std::forward<_Ranges>(__rngs)...); + __result = __parallel_find_or_impl_one_wg<_KernelName, __or_tag_check>()( + std::forward<_ExecutionPolicy>(__exec), __brick_tag, __rng_n, __wgroup_size, __init_value, __pred, + std::forward<_Ranges>(__rngs)...); } else { assert("This device does not support 64-bit atomics" && (sizeof(_AtomicType) < 8 || __exec.queue().get_device().has(sycl::aspect::atomic64))); + using _KernelName = + oneapi::dpl::__par_backend_hetero::__internal::__kernel_name_provider<__find_or_kernel<_CustomName>>; + // Multiple WG implementation - __result = __parallel_find_or_impl_multiple_wgs<_FindOrKernel, __or_tag_check>( - oneapi::dpl::__internal::__device_backend_tag{}, std::forward<_ExecutionPolicy>(__exec), __brick_tag, - __rng_n, __n_groups, __wgroup_size, __init_value, __pred, std::forward<_Ranges>(__rngs)...); + __result = __parallel_find_or_impl_multiple_wgs<_KernelName, __or_tag_check>()( + std::forward<_ExecutionPolicy>(__exec), __brick_tag, __rng_n, __n_groups, __wgroup_size, __init_value, + __pred, std::forward<_Ranges>(__rngs)...); } if constexpr (__or_tag_check)