Commit
include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl.h - performance optimization of __parallel_find_or + __device_backend_tag for use with __parallel_or_tag

Signed-off-by: Sergey Kopienko <[email protected]>
SergeyKopienko committed Jun 7, 2024
1 parent 758205a commit 5c64105
Showing 1 changed file with 177 additions and 54 deletions.
231 changes: 177 additions & 54 deletions include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl.h
@@ -997,25 +997,27 @@ struct __parallel_find_backward_tag
 // Tag for __parallel_find_or for or-semantic
 struct __parallel_or_tag
 {
-    class __atomic_compare
+    using _AtomicType = unsigned int;
+
+    static constexpr _AtomicType __found_state = 1;
+    static constexpr _AtomicType __not_found_state = 0;
+
+    struct __compare_state
     {
-    public:
-        template <typename _LocalAtomic, typename _GlobalAtomic>
         bool
-        operator()(const _LocalAtomic& __found_local, const _GlobalAtomic& __found) const
+        operator()(const _AtomicType __found_local, const _AtomicType __found) const
         {
-            return __found_local == 1 && __found == 0;
+            return __found_local == __found_state && __found == __not_found_state;
         }
     };

-    using _AtomicType = int32_t;
-    using _Compare = __atomic_compare;
+    using _Compare = __compare_state;

     // The template parameter is intended to unify __init_value in tags.
     template <typename _DiffType>
     constexpr static _AtomicType __init_value(_DiffType)
     {
-        return 0;
+        return __not_found_state;
     }
 };

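For reference, a minimal host-side sketch (not part of the commit; the `_demo` suffix marks hypothetical names) of how the new `__compare_state` functor and the found/not-found constants compose:

```cpp
#include <cassert>

// Mirror of the tag's pieces, renamed with a _demo suffix to mark this as an
// illustration rather than the library's own code.
struct __parallel_or_tag_demo
{
    using _AtomicType = unsigned int;
    static constexpr _AtomicType __found_state = 1;
    static constexpr _AtomicType __not_found_state = 0;

    struct __compare_state
    {
        bool
        operator()(_AtomicType __found_local, _AtomicType __found) const
        {
            // True only when this work-item found a match and the global
            // state still says "not found", i.e. a store is worthwhile.
            return __found_local == __found_state && __found == __not_found_state;
        }
    };
};

int main()
{
    __parallel_or_tag_demo::__compare_state __comp;
    assert(__comp(1u, 0u));  // local hit, global empty: publish
    assert(!__comp(0u, 0u)); // no local hit: nothing to publish
    assert(!__comp(1u, 1u)); // global already set: skip the store
}
```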
@@ -1028,52 +1030,89 @@ struct __early_exit_find_or
 {
     _Pred __pred;

+    // operator() overload for __parallel_or_tag
+    template <typename _NDItemId, typename _IterSize, typename _WgSize, typename _FoundLocalState, typename _Compare,
+              typename... _Ranges>
+    void
+    operator()(const _NDItemId __nd_item, const _IterSize __n_iter, const _WgSize __wg_size, _Compare __comp,
+               _FoundLocalState& __found_local, __parallel_or_tag, _Ranges&&... __rngs) const
+    {
+        const auto __n = oneapi::dpl::__ranges::__get_first_range_size(__rngs...);
+
+        std::size_t __shift = 16;
+        const std::size_t __local_idx = __nd_item.get_local_id(0);
+        const std::size_t __group_idx = __nd_item.get_group(0);
+
+        // each work_item processes N_ELEMENTS with step SHIFT
+        const std::size_t __leader = (__local_idx / __shift) * __shift;
+        const std::size_t __init_index =
+            __group_idx * __wg_size * __n_iter + __leader * __n_iter + __local_idx % __shift;
+
+        // if our "line" is out of work group size, reduce the line to the number of the rest elements
+        if (__wg_size - __leader < __shift)
+            __shift = __wg_size - __leader;
+        for (_IterSize __i = 0; __i < __n_iter; ++__i)
+        {
+            // Point #B1 - not required to have _ShiftedIdxType
+
+            _IterSize __current_iter = __i;
+            // Point #B2 - not required
+
+            // Point #B3 - rewritten
+            const auto __shifted_idx = __init_index + __current_iter * __shift;
+            // Point #B4 - rewritten
+            if (__shifted_idx < __n && __pred(__shifted_idx, __rngs...))
+            {
+                __found_local = __parallel_or_tag::__found_state;

+                // Doesn't make sense to continue if we found the element
+                break;
+            }
+        }
+    }
+
+    // operator() overload for __parallel_find_forward_tag and for __parallel_find_backward_tag
     template <typename _NDItemId, typename _IterSize, typename _WgSize, typename _LocalAtomic, typename _Compare,
               typename _BrickTag, typename... _Ranges>
     void
-    operator()(const _NDItemId __item_id, const _IterSize __n_iter, const _WgSize __wg_size, _Compare __comp,
+    operator()(const _NDItemId __nd_item, const _IterSize __n_iter, const _WgSize __wg_size, _Compare __comp,
                _LocalAtomic& __found_local, _BrickTag, _Ranges&&... __rngs) const
     {
-        using __par_backend_hetero::__parallel_or_tag;
-        using _OrTagType = ::std::is_same<_BrickTag, __par_backend_hetero::__parallel_or_tag>;
-        using _BackwardTagType = ::std::is_same<typename _BrickTag::_Compare, oneapi::dpl::__internal::__pstl_greater>;
+        using _BackwardTagType = std::is_same<typename _BrickTag::_Compare, oneapi::dpl::__internal::__pstl_greater>;

-        auto __n = oneapi::dpl::__ranges::__get_first_range_size(__rngs...);
+        const auto __n = oneapi::dpl::__ranges::__get_first_range_size(__rngs...);

-        ::std::size_t __shift = 16;
-        ::std::size_t __local_idx = __item_id.get_local_id(0);
-        ::std::size_t __group_idx = __item_id.get_group(0);
+        std::size_t __shift = 16;
+        const std::size_t __local_idx = __nd_item.get_local_id(0);
+        const std::size_t __group_idx = __nd_item.get_group(0);

         // each work_item processes N_ELEMENTS with step SHIFT
-        ::std::size_t __leader = (__local_idx / __shift) * __shift;
-        ::std::size_t __init_index = __group_idx * __wg_size * __n_iter + __leader * __n_iter + __local_idx % __shift;
+        const std::size_t __leader = (__local_idx / __shift) * __shift;
+        const std::size_t __init_index =
+            __group_idx * __wg_size * __n_iter + __leader * __n_iter + __local_idx % __shift;

         // if our "line" is out of work group size, reduce the line to the number of the rest elements
         if (__wg_size - __leader < __shift)
             __shift = __wg_size - __leader;
         for (_IterSize __i = 0; __i < __n_iter; ++__i)
         {
-            //in case of find-semantic __shifted_idx must be the same type as the atomic for a correct comparison
-            using _ShiftedIdxType = ::std::conditional_t<_OrTagType::value, decltype(__init_index + __i * __shift),
-                                                         decltype(__found_local.load())>;
+            // Point #B1
+            using _ShiftedIdxType = decltype(__init_index + __i * __shift);

             _IterSize __current_iter = __i;
+            // Point #B2
             if constexpr (_BackwardTagType::value)
                 __current_iter = __n_iter - 1 - __i;

-            _ShiftedIdxType __shifted_idx = __init_index + __current_iter * __shift;
-            // TODO:[Performance] the issue with atomic load (in comparison with __shifted_idx for early exit)
-            // should be investigated later, with other HW
+            // Point #B3
+            const _ShiftedIdxType __shifted_idx = __init_index + __current_iter * __shift;
+            // Point #B4
             if (__shifted_idx < __n && __pred(__shifted_idx, __rngs...))
             {
-                if constexpr (_OrTagType::value)
-                    __found_local.store(1);
-                else
+                for (auto __old = __found_local.load(); __comp(__shifted_idx, __old); __old = __found_local.load())
                 {
-                    for (auto __old = __found_local.load(); __comp(__shifted_idx, __old); __old = __found_local.load())
-                    {
-                        __found_local.compare_exchange_strong(__old, __shifted_idx);
-                    }
+                    __found_local.compare_exchange_strong(__old, __shifted_idx);
                 }
             }
         }
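The leader/shift indexing above is easier to see on the host. A standalone sketch (toy values assumed for `__wg_size`, `__n_iter`, and `__group_idx`; not library code) that enumerates the indices each work-item visits:

```cpp
#include <cstddef>
#include <iostream>

int main()
{
    // Assumed toy sizes: one group of 8 work-items, 4 iterations each.
    const std::size_t __wg_size = 8, __n_iter = 4, __group_idx = 0;

    for (std::size_t __local_idx = 0; __local_idx < __wg_size; ++__local_idx)
    {
        std::size_t __shift = 16;
        const std::size_t __leader = (__local_idx / __shift) * __shift;
        const std::size_t __init_index =
            __group_idx * __wg_size * __n_iter + __leader * __n_iter + __local_idx % __shift;
        // Shrink the "line" when it runs past the work-group size,
        // exactly as the kernel does.
        if (__wg_size - __leader < __shift)
            __shift = __wg_size - __leader;

        std::cout << "item " << __local_idx << " visits:";
        for (std::size_t __i = 0; __i < __n_iter; ++__i)
            std::cout << ' ' << __init_index + __i * __shift;
        std::cout << '\n'; // item 0 visits: 0 8 16 24; item 1: 1 9 17 25; ...
    }
}
```

With these sizes the 8 items cover indices 0..31 (that is, `__wg_size * __n_iter` elements) exactly once, each stepping by the clamped `__shift`.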
@@ -1084,13 +1123,98 @@ struct __early_exit_find_or
 // parallel_find_or - sync pattern
 //------------------------------------------------------------------------

-// Base pattern for __parallel_or and __parallel_find. The execution depends on tag type _BrickTag.
-template <typename _ExecutionPolicy, typename _Brick, typename _BrickTag, typename... _Ranges>
-::std::conditional_t<
-    ::std::is_same_v<_BrickTag, __parallel_or_tag>, bool,
-    oneapi::dpl::__internal::__difference_t<typename oneapi::dpl::__ranges::__get_first_range_type<_Ranges...>::type>>
+// Specialization for __parallel_or_tag
+template <typename _ExecutionPolicy, typename _Brick, typename... _Ranges>
+bool
 __parallel_find_or(oneapi::dpl::__internal::__device_backend_tag, _ExecutionPolicy&& __exec, _Brick __f,
-                   _BrickTag __brick_tag, _Ranges&&... __rngs)
+                   __parallel_or_tag /*__brick_tag*/, _Ranges&&... __rngs)
 {
+    using _BrickTag = __parallel_or_tag;
+
+    using _CustomName = oneapi::dpl::__internal::__policy_kernel_name<_ExecutionPolicy>;
+    using _AtomicType = typename _BrickTag::_AtomicType;
+    using _FindOrKernel =
+        oneapi::dpl::__par_backend_hetero::__internal::__kernel_name_generator<__find_or_kernel, _CustomName, _Brick,
+                                                                               _BrickTag, _Ranges...>;
+
+    auto __rng_n = oneapi::dpl::__ranges::__get_first_range_size(__rngs...);
+    assert(__rng_n > 0);
+
+    // TODO: find a way to generalize getting of reliable work-group size
+    auto __wgroup_size = oneapi::dpl::__internal::__max_work_group_size(__exec);
+#if _ONEDPL_COMPILE_KERNEL
+    auto __kernel = __internal::__kernel_compiler<_FindOrKernel>::__compile(__exec);
+    __wgroup_size = ::std::min(__wgroup_size, oneapi::dpl::__internal::__kernel_work_group_size(__exec, __kernel));
+#endif
+    auto __max_cu = oneapi::dpl::__internal::__max_compute_units(__exec);
+
+    auto __n_groups = (__rng_n - 1) / __wgroup_size + 1;
+    // TODO: try to change __n_groups with another formula for more perfect load balancing
+    __n_groups = ::std::min(__n_groups, decltype(__n_groups)(__max_cu));
+
+    auto __n_iter = (__rng_n - 1) / (__n_groups * __wgroup_size) + 1;
+
+    _PRINT_INFO_IN_DEBUG_MODE(__exec, __wgroup_size, __max_cu);
+
+    const _AtomicType __init_value = _BrickTag::__init_value(__rng_n);
+    _AtomicType __result = __init_value;
+
+    const auto __pred = oneapi::dpl::__par_backend_hetero::__early_exit_find_or<_ExecutionPolicy, _Brick>{__f};
+
+    // scope is to copy data back to __result after destruction of temporary sycl::buffer
+    {
+        auto __temp = sycl::buffer<_AtomicType, 1>(&__result, 1); // temporary storage for global atomic
+
+        // main parallel_for
+        __exec.queue().submit([&](sycl::handler& __cgh) {
+            oneapi::dpl::__ranges::__require_access(__cgh, __rngs...);
+            auto __temp_acc = __temp.template get_access<access_mode::read_write>(__cgh);
+
+#if _ONEDPL_COMPILE_KERNEL && _ONEDPL_KERNEL_BUNDLE_PRESENT
+            __cgh.use_kernel_bundle(__kernel.get_kernel_bundle());
+#endif
+            __cgh.parallel_for<_FindOrKernel>(
+#if _ONEDPL_COMPILE_KERNEL && !_ONEDPL_KERNEL_BUNDLE_PRESENT
+                __kernel,
+#endif
+                sycl::nd_range</*dim=*/1>(sycl::range</*dim=*/1>(__n_groups * __wgroup_size),
+                                          sycl::range</*dim=*/1>(__wgroup_size)),
+                [=](sycl::nd_item</*dim=*/1> __nd_item) {
+
+                    // using __atomic_ref = sycl::atomic_ref<_AtomicType, sycl::memory_order::relaxed, sycl::memory_scope::work_group, _Space>;
+                    __dpl_sycl::__atomic_ref<_AtomicType, sycl::access::address_space::global_space> __found(
+                        *__dpl_sycl::__get_accessor_ptr(__temp_acc));
+                    // Point #A1 - not required
+
+                    // Point #A2 - rewritten
+                    _AtomicType __found_local = __init_value;
+                    // Point #A2.1 - not required
+
+                    // Point #A3 - rewritten
+                    constexpr auto __comp = typename _BrickTag::_Compare{};
+                    __pred(__nd_item, __n_iter, __wgroup_size, __comp, __found_local, __parallel_or_tag{}, __rngs...);
+                    // Point #A3.1 - not required
+
+                    // Point #A4 - rewritten
+                    // Set found state result to global atomic
+                    if (__found_local != __init_value)
+                    {
+                        __found.fetch_or(__found_local);
+                    }
+                });
+        });
+        // The end of the scope - a point of synchronization (on temporary sycl buffer destruction)
+    }
+
+    return __result != __init_value;
+}
+
+// Specialization for __parallel_find_forward_tag, __parallel_find_backward_tag
+template <typename _ExecutionPolicy, typename _Brick, typename _BrickTag, typename... _Ranges>
+oneapi::dpl::__internal::__difference_t<typename oneapi::dpl::__ranges::__get_first_range_type<_Ranges...>::type>
+__parallel_find_or(oneapi::dpl::__internal::__device_backend_tag, _ExecutionPolicy&& __exec, _Brick __f,
+                   _BrickTag, _Ranges&&... __rngs)
+{
     using _CustomName = oneapi::dpl::__internal::__policy_kernel_name<_ExecutionPolicy>;
     using _AtomicType = typename _BrickTag::_AtomicType;
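The core of the new `__parallel_or_tag` specialization is that each work-item keeps a plain (non-atomic) local flag and performs at most one global atomic RMW. A host-side sketch of that shape, using `std::thread` and `std::atomic` as stand-ins for SYCL work-items and the global `atomic_ref` (the data and predicate are made up for illustration):

```cpp
#include <atomic>
#include <cassert>
#include <cstddef>
#include <thread>
#include <vector>

int main()
{
    std::atomic<unsigned> __found{0};           // plays the global atomic
    const std::vector<int> __data{3, 5, 8, 13}; // 8 satisfies the predicate
    auto __pred = [](int __v) { return __v % 2 == 0; };

    auto __worker = [&](std::size_t __begin, std::size_t __end) {
        unsigned __found_local = 0; // plain local state, no atomics here
        for (std::size_t __i = __begin; __i < __end; ++__i)
            if (__pred(__data[__i]))
            {
                __found_local = 1;
                break; // early exit, as in the kernel
            }
        if (__found_local != 0)
            __found.fetch_or(__found_local); // at most one global RMW
    };

    std::thread __t1(__worker, 0, 2), __t2(__worker, 2, 4);
    __t1.join();
    __t2.join();
    assert(__found.load() == 1); // "any element matched" or-semantic
}
```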
@@ -1099,6 +1223,7 @@ __parallel_find_or(oneapi::dpl::__internal::__device_backend_tag, _ExecutionPoli
                                                                                _BrickTag, _Ranges...>;

     constexpr bool __or_tag_check = ::std::is_same_v<_BrickTag, __parallel_or_tag>;
+    static_assert(!std::is_same_v<_BrickTag, __parallel_or_tag>);
     auto __rng_n = oneapi::dpl::__ranges::__get_first_range_size(__rngs...);
     assert(__rng_n > 0);

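The elided body of this specialization computes its launch geometry the same way as the or-specialization above; the `__n_groups`/`__n_iter` formulas are plain ceiling divisions. A host-side sketch of that arithmetic, with assumed values standing in for `__max_work_group_size`/`__max_compute_units`:

```cpp
#include <algorithm>
#include <cassert>

int main()
{
    // Assumed device limits and input size, for illustration only.
    const long __rng_n = 100000, __wgroup_size = 256, __max_cu = 32;

    long __n_groups = (__rng_n - 1) / __wgroup_size + 1; // ceil(n / wg_size)
    __n_groups = std::min(__n_groups, __max_cu);         // cap by compute units

    // Each work-item then covers __n_iter elements of the range.
    const long __n_iter = (__rng_n - 1) / (__n_groups * __wgroup_size) + 1;

    assert(__n_groups == 32);
    assert(__n_iter == 13);                                    // ceil(100000 / 8192)
    assert(__n_groups * __wgroup_size * __n_iter >= __rng_n);  // full coverage
}
```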
@@ -1121,7 +1246,7 @@ __parallel_find_or(oneapi::dpl::__internal::__device_backend_tag, _ExecutionPoli
     _AtomicType __init_value = _BrickTag::__init_value(__rng_n);
     auto __result = __init_value;

-    auto __pred = oneapi::dpl::__par_backend_hetero::__early_exit_find_or<_ExecutionPolicy, _Brick>{__f};
+    const auto __pred = oneapi::dpl::__par_backend_hetero::__early_exit_find_or<_ExecutionPolicy, _Brick>{__f};

     // scope is to copy data back to __result after destruction of temporary sycl::buffer
     {
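Both specializations rely on the same idiom here: the host variable is wrapped in a `sycl::buffer`, and the buffer's destruction at the end of the scope is the synchronization point that writes the device result back. A minimal standalone sketch of just that idiom (a trivial `single_task` stands in for the real kernel):

```cpp
#include <sycl/sycl.hpp>
#include <cassert>

int main()
{
    sycl::queue __q;
    unsigned __result = 0;
    {
        sycl::buffer<unsigned, 1> __temp(&__result, 1); // wraps the host variable
        __q.submit([&](sycl::handler& __cgh) {
            auto __acc = __temp.get_access<sycl::access::mode::read_write>(__cgh);
            __cgh.single_task([=]() { __acc[0] = 1; });
        });
        // No explicit wait(): destroying __temp here blocks until the kernel
        // finishes and copies the device value back into __result.
    }
    assert(__result == 1);
}
```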
@@ -1143,47 +1268,45 @@ __parallel_find_or(oneapi::dpl::__internal::__device_backend_tag, _ExecutionPoli
 #endif
                 sycl::nd_range</*dim=*/1>(sycl::range</*dim=*/1>(__n_groups * __wgroup_size),
                                           sycl::range</*dim=*/1>(__wgroup_size)),
-                [=](sycl::nd_item</*dim=*/1> __item_id) {
-                    auto __local_idx = __item_id.get_local_id(0);
+                [=](sycl::nd_item</*dim=*/1> __nd_item) {
+                    auto __local_idx = __nd_item.get_local_id(0);

                     __dpl_sycl::__atomic_ref<_AtomicType, sycl::access::address_space::global_space> __found(
                         *__dpl_sycl::__get_accessor_ptr(__temp_acc));
+                    // Point #A1
                     __dpl_sycl::__atomic_ref<_AtomicType, sycl::access::address_space::local_space> __found_local(
                         *__dpl_sycl::__get_accessor_ptr(__temp_local));

+                    // Point #A2
                     // 1. Set initial value to local atomic
                     if (__local_idx == 0)
                         __found_local.store(__init_value);
-                    __dpl_sycl::__group_barrier(__item_id);
+                    // Point #A2.1
+                    __dpl_sycl::__group_barrier(__nd_item);

+                    // Point #A3
                     // 2. Find any element that satisfies pred and set local atomic value to global atomic
                     constexpr auto __comp = typename _BrickTag::_Compare{};
-                    __pred(__item_id, __n_iter, __wgroup_size, __comp, __found_local, __brick_tag, __rngs...);
-                    __dpl_sycl::__group_barrier(__item_id);
+                    __pred(__nd_item, __n_iter, __wgroup_size, __comp, __found_local, _BrickTag{}, __rngs...);
+                    // Point #A3.1
+                    __dpl_sycl::__group_barrier(__nd_item);

+                    // Point #A4
                     // Set local atomic value to global atomic
                     if (__local_idx == 0 && __comp(__found_local.load(), __found.load()))
                     {
-                        if constexpr (__or_tag_check)
-                            __found.store(1);
-                        else
+                        for (auto __old = __found.load(); __comp(__found_local.load(), __old);
+                             __old = __found.load())
                         {
-                            for (auto __old = __found.load(); __comp(__found_local.load(), __old);
-                                 __old = __found.load())
-                            {
-                                __found.compare_exchange_strong(__old, __found_local.load());
-                            }
+                            __found.compare_exchange_strong(__old, __found_local.load());
                         }
                     }
                 });
         });
         // The end of the scope - a point of synchronization (on temporary sycl buffer destruction)
     }

-    if constexpr (__or_tag_check)
-        return __result;
-    else
-        return __result != __init_value ? __result : __rng_n;
+    return __result != __init_value ? __result : __rng_n;
 }

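The find specializations keep their two-level reduction: work-items race on a work-group-local atomic, and after a barrier the group leader merges the local winner into the global atomic with the `compare_exchange_strong` loop seen at Point #A4. A sequential host sketch of that merge (`std::atomic` standing in for both SYCL atomics, two hand-made "work-groups" of candidate indices):

```cpp
#include <atomic>
#include <cassert>
#include <cstddef>
#include <functional>
#include <vector>

// Keep publishing __candidate while it still beats the stored value under
// __comp; this is the same loop shape as Point #A4 above.
template <typename _Tp, typename _Compare>
void
__store_if_better(std::atomic<_Tp>& __found, _Tp __candidate, _Compare __comp)
{
    for (auto __old = __found.load(); __comp(__candidate, __old); __old = __found.load())
        __found.compare_exchange_strong(__old, __candidate);
}

int main()
{
    const std::ptrdiff_t __init_value = 1000;          // "not found" sentinel
    std::atomic<std::ptrdiff_t> __found{__init_value}; // global winner
    std::less<std::ptrdiff_t> __comp;                  // forward find: smaller index wins

    // Matching indices found by two "work-groups".
    const std::vector<std::vector<std::ptrdiff_t>> __groups{{7, 12}, {3, 9}};
    for (const auto& __grp : __groups)
    {
        std::atomic<std::ptrdiff_t> __found_local{__init_value}; // per-group
        for (auto __idx : __grp) // items store their hits locally
            __store_if_better(__found_local, __idx, __comp);
        // (group barrier here in the kernel) leader merges local -> global
        __store_if_better(__found, __found_local.load(), __comp);
    }
    assert(__found.load() == 3); // smallest matching index wins overall
}
```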
//------------------------------------------------------------------------