diff --git a/include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl.h b/include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl.h
index b006eae051b..aeb844b6b62 100644
--- a/include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl.h
+++ b/include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl.h
@@ -997,25 +997,27 @@ struct __parallel_find_backward_tag
 // Tag for __parallel_find_or for or-semantic
 struct __parallel_or_tag
 {
-    class __atomic_compare
+    using _AtomicType = unsigned int;
+
+    static constexpr _AtomicType __found_state = 1;
+    static constexpr _AtomicType __not_found_state = 0;
+
+    struct __compare_state
     {
-      public:
-        template <typename _LocalAtomic, typename _GlobalAtomic>
         bool
-        operator()(const _LocalAtomic& __found_local, const _GlobalAtomic& __found) const
+        operator()(const _AtomicType __found_local, const _AtomicType __found) const
         {
-            return __found_local == 1 && __found == 0;
+            return __found_local == __found_state && __found == __not_found_state;
         }
     };
 
-    using _AtomicType = int32_t;
-    using _Compare = __atomic_compare;
+    using _Compare = __compare_state;
 
     // The template parameter is intended to unify __init_value in tags.
     template <typename _DiffType>
     constexpr static _AtomicType
     __init_value(_DiffType)
     {
-        return 0;
+        return __not_found_state;
     }
 };
 
@@ -1028,52 +1030,89 @@ struct __early_exit_find_or
 {
     _Pred __pred;
 
+    // operator() overload for __parallel_or_tag
+    template <typename _NDItemId, typename _IterSize, typename _WgSize, typename _Compare,
+              typename _FoundLocalState, typename... _Ranges>
+    void
+    operator()(const _NDItemId __nd_item, const _IterSize __n_iter, const _WgSize __wg_size, _Compare __comp,
+               _FoundLocalState& __found_local, __parallel_or_tag, _Ranges&&... __rngs) const
+    {
+        const auto __n = oneapi::dpl::__ranges::__get_first_range_size(__rngs...);
+
+        std::size_t __shift = 16;
+        const std::size_t __local_idx = __nd_item.get_local_id(0);
+        const std::size_t __group_idx = __nd_item.get_group(0);
+
+        // each work_item processes N_ELEMENTS with step SHIFT
+        const std::size_t __leader = (__local_idx / __shift) * __shift;
+        const std::size_t __init_index =
+            __group_idx * __wg_size * __n_iter + __leader * __n_iter + __local_idx % __shift;
+
+        // if our "line" is out of work group size, reduce the line to the number of the rest elements
+        if (__wg_size - __leader < __shift)
+            __shift = __wg_size - __leader;
+
+        for (_IterSize __i = 0; __i < __n_iter; ++__i)
+        {
+            // Point #B1 - not required to have _ShiftedIdxType
+
+            _IterSize __current_iter = __i;
+            // Point #B2 - not required
+
+            // Point #B3 - rewritten
+            const auto __shifted_idx = __init_index + __current_iter * __shift;
+            // Point #B4 - rewritten
+            if (__shifted_idx < __n && __pred(__shifted_idx, __rngs...))
+            {
+                __found_local = __parallel_or_tag::__found_state;
+
+                // Doesn't make sense to continue if we found the element
+                break;
+            }
+        }
+    }
+
+    // operator() overload for __parallel_find_forward_tag and for __parallel_find_backward_tag
     template <typename _NDItemId, typename _IterSize, typename _WgSize, typename _Compare, typename _LocalAtomic,
               typename _BrickTag, typename... _Ranges>
     void
-    operator()(const _NDItemId __item_id, const _IterSize __n_iter, const _WgSize __wg_size, _Compare __comp,
+    operator()(const _NDItemId __nd_item, const _IterSize __n_iter, const _WgSize __wg_size, _Compare __comp,
               _LocalAtomic& __found_local, _BrickTag, _Ranges&&... __rngs) const
     {
-        using __par_backend_hetero::__parallel_or_tag;
-        using _OrTagType = ::std::is_same<_BrickTag, __par_backend_hetero::__parallel_or_tag>;
-        using _BackwardTagType = ::std::is_same<_BrickTag, __par_backend_hetero::__parallel_find_backward_tag>;
+        using _BackwardTagType = std::is_same<_BrickTag, __par_backend_hetero::__parallel_find_backward_tag>;
 
-        auto __n = oneapi::dpl::__ranges::__get_first_range_size(__rngs...);
+        const auto __n = oneapi::dpl::__ranges::__get_first_range_size(__rngs...);
 
-        ::std::size_t __shift = 16;
-        ::std::size_t __local_idx = __item_id.get_local_id(0);
-        ::std::size_t __group_idx = __item_id.get_group(0);
+        std::size_t __shift = 16;
+        const std::size_t __local_idx = __nd_item.get_local_id(0);
+        const std::size_t __group_idx = __nd_item.get_group(0);
 
         // each work_item processes N_ELEMENTS with step SHIFT
-        ::std::size_t __leader = (__local_idx / __shift) * __shift;
-        ::std::size_t __init_index = __group_idx * __wg_size * __n_iter + __leader * __n_iter + __local_idx % __shift;
+        const std::size_t __leader = (__local_idx / __shift) * __shift;
+        const std::size_t __init_index =
+            __group_idx * __wg_size * __n_iter + __leader * __n_iter + __local_idx % __shift;
 
         // if our "line" is out of work group size, reduce the line to the number of the rest elements
         if (__wg_size - __leader < __shift)
             __shift = __wg_size - __leader;
 
         for (_IterSize __i = 0; __i < __n_iter; ++__i)
         {
+            // Point #B1
             //in case of find-semantic __shifted_idx must be the same type as the atomic for a correct comparison
-            using _ShiftedIdxType = ::std::conditional_t<_OrTagType::value, decltype(__init_index + __i * __shift),
-                                                         decltype(__found_local.load())>;
+            using _ShiftedIdxType = decltype(__found_local.load());
 
             _IterSize __current_iter = __i;
+            // Point #B2
             if constexpr (_BackwardTagType::value)
                 __current_iter = __n_iter - 1 - __i;
 
-            _ShiftedIdxType __shifted_idx = __init_index + __current_iter * __shift;
-            // TODO:[Performance] the issue with atomic load (in comparison with __shifted_idx for early exit)
-            // should be investigated later, with other HW
+            // Point #B3
+            const _ShiftedIdxType __shifted_idx = __init_index + __current_iter * __shift;
+            // Point #B4
             if (__shifted_idx < __n && __pred(__shifted_idx, __rngs...))
             {
-                if constexpr (_OrTagType::value)
-                    __found_local.store(1);
-                else
+                for (auto __old = __found_local.load(); __comp(__shifted_idx, __old); __old = __found_local.load())
                 {
-                    for (auto __old = __found_local.load(); __comp(__shifted_idx, __old); __old = __found_local.load())
-                    {
-                        __found_local.compare_exchange_strong(__old, __shifted_idx);
-                    }
+                    __found_local.compare_exchange_strong(__old, __shifted_idx);
                 }
             }
         }
@@ -1084,13 +1123,98 @@ struct __early_exit_find_or
 // parallel_find_or - sync pattern
 //------------------------------------------------------------------------
 
-// Base pattern for __parallel_or and __parallel_find. The execution depends on tag type _BrickTag.
-template <typename _ExecutionPolicy, typename _Brick, typename _BrickTag, typename... _Ranges>
-::std::conditional_t<
-    ::std::is_same_v<_BrickTag, __parallel_or_tag>, bool,
-    oneapi::dpl::__internal::__difference_t<typename oneapi::dpl::__ranges::__get_first_range_type<_Ranges...>::type>>
+// Specialization for __parallel_or_tag
+template <typename _ExecutionPolicy, typename _Brick, typename... _Ranges>
+bool
 __parallel_find_or(oneapi::dpl::__internal::__device_backend_tag, _ExecutionPolicy&& __exec, _Brick __f,
-                   _BrickTag __brick_tag, _Ranges&&... __rngs)
+                   __parallel_or_tag /*__brick_tag*/, _Ranges&&... __rngs)
+{
+    using _BrickTag = __parallel_or_tag;
+
+    using _CustomName = oneapi::dpl::__internal::__policy_kernel_name<_ExecutionPolicy>;
+    using _AtomicType = typename _BrickTag::_AtomicType;
+    using _FindOrKernel =
+        oneapi::dpl::__par_backend_hetero::__internal::__kernel_name_generator<__find_or_kernel, _CustomName, _Brick,
+                                                                               _BrickTag, _Ranges...>;
+
+    auto __rng_n = oneapi::dpl::__ranges::__get_first_range_size(__rngs...);
+    assert(__rng_n > 0);
+
+    // TODO: find a way to generalize getting of reliable work-group size
+    auto __wgroup_size = oneapi::dpl::__internal::__max_work_group_size(__exec);
+#if _ONEDPL_COMPILE_KERNEL
+    auto __kernel = __internal::__kernel_compiler<_FindOrKernel>::__compile(__exec);
+    __wgroup_size = ::std::min(__wgroup_size, oneapi::dpl::__internal::__kernel_work_group_size(__exec, __kernel));
+#endif
+    auto __max_cu = oneapi::dpl::__internal::__max_compute_units(__exec);
+
+    auto __n_groups = (__rng_n - 1) / __wgroup_size + 1;
+    // TODO: try to change __n_groups with another formula for better load balancing
+    __n_groups = ::std::min(__n_groups, decltype(__n_groups)(__max_cu));
+
+    auto __n_iter = (__rng_n - 1) / (__n_groups * __wgroup_size) + 1;
+
+    _PRINT_INFO_IN_DEBUG_MODE(__exec, __wgroup_size, __max_cu);
+
+    const _AtomicType __init_value = _BrickTag::__init_value(__rng_n);
+    _AtomicType __result = __init_value;
+
+    const auto __pred = oneapi::dpl::__par_backend_hetero::__early_exit_find_or<_ExecutionPolicy, _Brick>{__f};
+
+    // scope is to copy data back to __result after destruction of temporary sycl::buffer
+    {
+        auto __temp = sycl::buffer<_AtomicType, 1>(&__result, 1); // temporary storage for global atomic
+
+        // main parallel_for
+        __exec.queue().submit([&](sycl::handler& __cgh) {
+            oneapi::dpl::__ranges::__require_access(__cgh, __rngs...);
+            auto __temp_acc = __temp.template get_access<access_mode::read_write>(__cgh);
+
+#if _ONEDPL_COMPILE_KERNEL && _ONEDPL_KERNEL_BUNDLE_PRESENT
+            __cgh.use_kernel_bundle(__kernel.get_kernel_bundle());
+#endif
+            __cgh.parallel_for<_FindOrKernel>(
+#if _ONEDPL_COMPILE_KERNEL && !_ONEDPL_KERNEL_BUNDLE_PRESENT
+                __kernel,
+#endif
+                sycl::nd_range</*dim=*/1>(sycl::range</*dim=*/1>(__n_groups * __wgroup_size),
+                                          sycl::range</*dim=*/1>(__wgroup_size)),
+                [=](sycl::nd_item</*dim=*/1> __nd_item) {
+
+                    // using __atomic_ref = sycl::atomic_ref<_AtomicType, sycl::memory_order::relaxed, sycl::memory_scope::work_group, _Space>;
+                    __dpl_sycl::__atomic_ref<_AtomicType, sycl::access::address_space::global_space> __found(
+                        *__dpl_sycl::__get_accessor_ptr(__temp_acc));
+                    // Point #A1 - not required
+
+                    // Point #A2 - rewritten
+                    _AtomicType __found_local = __init_value;
+                    // Point #A2.1 - not required
+
+                    // Point #A3 - rewritten
+                    constexpr auto __comp = typename _BrickTag::_Compare{};
+                    __pred(__nd_item, __n_iter, __wgroup_size, __comp, __found_local, __parallel_or_tag{}, __rngs...);
+                    // Point #A3.1 - not required
+
+                    // Point #A4 - rewritten
+                    // Set found state result to global atomic
+                    if (__found_local != __init_value)
+                    {
+                        __found.fetch_or(__found_local);
+                    }
+                });
+        });
+        //The end of the scope - a point of synchronization (on temporary sycl buffer destruction)
+    }
+
+    return __result != __init_value;
+}
+
+// Specialization for __parallel_find_forward_tag, __parallel_find_backward_tag
+template <typename _ExecutionPolicy, typename _Brick, typename _BrickTag, typename... _Ranges>
+oneapi::dpl::__internal::__difference_t<typename oneapi::dpl::__ranges::__get_first_range_type<_Ranges...>::type>
+__parallel_find_or(oneapi::dpl::__internal::__device_backend_tag, _ExecutionPolicy&& __exec, _Brick __f,
+                   _BrickTag, _Ranges&&... __rngs)
 {
     using _CustomName = oneapi::dpl::__internal::__policy_kernel_name<_ExecutionPolicy>;
     using _AtomicType = typename _BrickTag::_AtomicType;
@@ -1099,6 +1223,7 @@ __parallel_find_or(oneapi::dpl::__internal::__device_backend_tag, _ExecutionPoli
                                                                  _BrickTag, _Ranges...>;
 
     constexpr bool __or_tag_check = ::std::is_same_v<_BrickTag, __parallel_or_tag>;
+    static_assert(!std::is_same_v<_BrickTag, __parallel_or_tag>);
 
     auto __rng_n = oneapi::dpl::__ranges::__get_first_range_size(__rngs...);
    assert(__rng_n > 0);
@@ -1121,7 +1246,7 @@ __parallel_find_or(oneapi::dpl::__internal::__device_backend_tag, _ExecutionPoli
     _AtomicType __init_value = _BrickTag::__init_value(__rng_n);
     auto __result = __init_value;
 
-    auto __pred = oneapi::dpl::__par_backend_hetero::__early_exit_find_or<_ExecutionPolicy, _Brick>{__f};
+    const auto __pred = oneapi::dpl::__par_backend_hetero::__early_exit_find_or<_ExecutionPolicy, _Brick>{__f};
 
     // scope is to copy data back to __result after destruction of temporary sycl:buffer
     {
@@ -1143,36 +1268,37 @@ __parallel_find_or(oneapi::dpl::__internal::__device_backend_tag, _ExecutionPoli
 #endif
                 sycl::nd_range</*dim=*/1>(sycl::range</*dim=*/1>(__n_groups * __wgroup_size),
                                           sycl::range</*dim=*/1>(__wgroup_size)),
-                [=](sycl::nd_item</*dim=*/1> __item_id) {
-                    auto __local_idx = __item_id.get_local_id(0);
+                [=](sycl::nd_item</*dim=*/1> __nd_item) {
+                    auto __local_idx = __nd_item.get_local_id(0);
 
                     __dpl_sycl::__atomic_ref<_AtomicType, sycl::access::address_space::global_space> __found(
                         *__dpl_sycl::__get_accessor_ptr(__temp_acc));
+                    // Point #A1
                     __dpl_sycl::__atomic_ref<_AtomicType, sycl::access::address_space::local_space> __found_local(
                         *__dpl_sycl::__get_accessor_ptr(__temp_local));
+                    // Point #A2
                     // 1. Set initial value to local atomic
                     if (__local_idx == 0)
                         __found_local.store(__init_value);
-                    __dpl_sycl::__group_barrier(__item_id);
+                    // Point #A2.1
+                    __dpl_sycl::__group_barrier(__nd_item);
 
+                    // Point #A3
                     // 2. Find any element that satisfies pred and set local atomic value to global atomic
                     constexpr auto __comp = typename _BrickTag::_Compare{};
-                    __pred(__item_id, __n_iter, __wgroup_size, __comp, __found_local, __brick_tag, __rngs...);
-                    __dpl_sycl::__group_barrier(__item_id);
+                    __pred(__nd_item, __n_iter, __wgroup_size, __comp, __found_local, _BrickTag{}, __rngs...);
+                    // Point #A3.1
+                    __dpl_sycl::__group_barrier(__nd_item);
 
+                    // Point #A4
                     // Set local atomic value to global atomic
                     if (__local_idx == 0 && __comp(__found_local.load(), __found.load()))
                     {
-                        if constexpr (__or_tag_check)
-                            __found.store(1);
-                        else
+                        for (auto __old = __found.load(); __comp(__found_local.load(), __old);
+                             __old = __found.load())
                         {
-                            for (auto __old = __found.load(); __comp(__found_local.load(), __old);
-                                 __old = __found.load())
-                            {
-                                __found.compare_exchange_strong(__old, __found_local.load());
-                            }
+                            __found.compare_exchange_strong(__old, __found_local.load());
                         }
                     }
                 });
@@ -1180,10 +1306,7 @@ __parallel_find_or(oneapi::dpl::__internal::__device_backend_tag, _ExecutionPoli
         //The end of the scope - a point of synchronization (on temporary sycl buffer destruction)
     }
 
-    if constexpr (__or_tag_check)
-        return __result;
-    else
-        return __result != __init_value ? __result : __rng_n;
+    return __result != __init_value ? __result : __rng_n;
 }
 
 //------------------------------------------------------------------------
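
Reviewer note: below are two standalone sketches of the two synchronization patterns this patch separates. They are illustrative only, not part of oneDPL; the names any_of_fetch_or and publish_candidate are invented for the examples.

The __parallel_or_tag specialization keeps a plain (non-atomic) per-item flag and publishes it at most once via a relaxed fetch_or on a global unsigned int, so work-items that find nothing perform no atomic operation at all. fetch_or suffices here because or-semantics only distinguish "found" from "not found": the operation is idempotent and order-independent. A minimal sketch of that shape, assuming a SYCL 2020 compiler:

#include <sycl/sycl.hpp>

#include <cstddef>
#include <iostream>
#include <vector>

constexpr unsigned int not_found_state = 0; // mirrors __parallel_or_tag::__not_found_state
constexpr unsigned int found_state = 1;     // mirrors __parallel_or_tag::__found_state

// Or-semantics reduction: plain per-item flag, one relaxed fetch_or to publish.
template <typename T, typename Pred>
bool
any_of_fetch_or(sycl::queue& q, const std::vector<T>& data, Pred pred)
{
    const std::size_t n = data.size();
    unsigned int result = not_found_state;

    // Scope ensures the result buffer writes back before `result` is read,
    // the same trick the patch uses with its temporary sycl::buffer.
    {
        sycl::buffer<T, 1> data_buf(data.data(), sycl::range<1>(n));
        sycl::buffer<unsigned int, 1> result_buf(&result, 1);

        q.submit([&](sycl::handler& cgh) {
            sycl::accessor acc(data_buf, cgh, sycl::read_only);
            sycl::accessor res_acc(result_buf, cgh, sycl::read_write);

            cgh.parallel_for(sycl::range<1>(n), [=](sycl::id<1> idx) {
                // Plain (non-atomic) per-item state, as in the or-tag overload.
                unsigned int found_local = not_found_state;
                if (pred(acc[idx]))
                    found_local = found_state;

                // Publish only when something was found, so the common
                // "not found" path issues no atomic traffic.
                if (found_local != not_found_state)
                {
                    sycl::atomic_ref<unsigned int, sycl::memory_order::relaxed, sycl::memory_scope::device,
                                     sycl::access::address_space::global_space>
                        found(res_acc[0]);
                    found.fetch_or(found_local);
                }
            });
        });
    }

    return result != not_found_state;
}

int main()
{
    sycl::queue q;
    std::vector<int> v(1024, 1);
    v[511] = 42;

    std::cout << std::boolalpha << any_of_fetch_or(q, v, [](int x) { return x == 42; }) << '\n' // true
              << any_of_fetch_or(q, v, [](int x) { return x < 0; }) << '\n';                    // false
}

The find-forward/backward specialization cannot use fetch_or because it must publish the best index, so it keeps the compare_exchange loop: a candidate is installed only while it still beats the currently published value under the tag's comparator (smaller wins for the forward tag, larger for the backward tag). The same loop shape, sketched host-side with std::atomic:

#include <atomic>
#include <cstdint>
#include <functional>
#include <iostream>
#include <thread>
#include <vector>

// Install `idx` only while it beats the published value; "smaller wins"
// corresponds to the forward tag's comparator.
void
publish_candidate(std::atomic<std::uint32_t>& found, std::uint32_t idx)
{
    for (auto old_value = found.load(); idx < old_value; old_value = found.load())
    {
        // On failure another thread moved `found`; the loop condition then
        // re-checks whether our candidate still wins before retrying.
        found.compare_exchange_strong(old_value, idx);
    }
}

int main()
{
    const std::uint32_t n = 1000; // "not found" sentinel: one past the end
    std::atomic<std::uint32_t> found(n);

    std::vector<std::thread> threads;
    for (std::uint32_t idx : {700u, 42u, 901u, 137u})
        threads.emplace_back(publish_candidate, std::ref(found), idx);
    for (auto& t : threads)
        t.join();

    std::cout << found.load() << '\n'; // prints 42, the smallest published index
}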