Skip to content

Commit

Permalink
include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl_merge.h - …
Browse files Browse the repository at this point in the history
…using __parallel_merge_submitter_large for merge data equal to or greater than 4M items

Signed-off-by: Sergey Kopienko <[email protected]>
  • Loading branch information
SergeyKopienko committed Nov 7, 2024
1 parent b33656a commit d4721ca
Showing 1 changed file with 54 additions and 23 deletions.
77 changes: 54 additions & 23 deletions include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl_merge.h
Original file line number Diff line number Diff line change
Expand Up @@ -232,14 +232,11 @@ __serial_merge(const _Rng1& __rng1, const _Rng2& __rng2, _Rng3& __rng3, _Index _
}
}

template <typename... _Name>
class _find_split_points_kernel_on_mid_diagonal;

// Please see the comment for __parallel_for_submitter for optional kernel name explanation
template <typename _IdType, typename _Name>
struct __parallel_merge_submitter;

template <typename _IdType, typename _Name>
template <typename _IdType, typename _CustomName, typename _Name>
struct __parallel_merge_submitter_large;

template <typename _IdType, typename... _Name>
Expand Down Expand Up @@ -275,8 +272,14 @@ struct __parallel_merge_submitter<_IdType, __internal::__optional_kernel_name<_N
}
};

template <typename _IdType, typename... _Name>
struct __parallel_merge_submitter_large<_IdType, __internal::__optional_kernel_name<_Name...>>
// Kernel name tag for the "find split points on the mid diagonal" kernel of
// __parallel_merge_submitter_large; selected when _IdType is std::uint32_t.
template <typename... _Name>
class _find_split_points_kernel_on_mid_diagonal_uint32_t;

// Kernel name tag for the "find split points on the mid diagonal" kernel of
// __parallel_merge_submitter_large; selected when _IdType is not std::uint32_t
// (the visible callers then use std::uint64_t).
template <typename... _Name>
class _find_split_points_kernel_on_mid_diagonal_uint64_t;

template <typename _IdType, typename _CustomName, typename... _Name>
struct __parallel_merge_submitter_large<_IdType, _CustomName, __internal::__optional_kernel_name<_Name...>>
{
template <typename _ExecutionPolicy, typename _Range1, typename _Range2, typename _Range3, typename _Compare>
auto
Expand All @@ -290,11 +293,12 @@ struct __parallel_merge_submitter_large<_IdType, __internal::__optional_kernel_n

_PRINT_INFO_IN_DEBUG_MODE(__exec);

using _CustomName = oneapi::dpl::__internal::__policy_kernel_name<_ExecutionPolicy>;

using _FindSplitPointsKernelOnMidDiagonal =
using _FindSplitPointsKernelOnMidDiagonal = std::conditional_t<
std::is_same_v<_IdType, std::uint32_t>,
oneapi::dpl::__par_backend_hetero::__internal::__kernel_name_generator<
_find_split_points_kernel_on_mid_diagonal_uint32_t, _CustomName, _Range1, _Range2, _IdType, _Compare>,
oneapi::dpl::__par_backend_hetero::__internal::__kernel_name_generator<
_find_split_points_kernel_on_mid_diagonal, _CustomName, _Range1, _Range2, _IdType, _Compare>;
_find_split_points_kernel_on_mid_diagonal_uint64_t, _CustomName, _Range1, _Range2, _IdType, _Compare>>;

// Empirical number of values to process per work-item
const std::uint8_t __chunk = __exec.queue().get_device().is_cpu() ? 128 : 4;
Expand Down Expand Up @@ -379,6 +383,9 @@ struct __parallel_merge_submitter_large<_IdType, __internal::__optional_kernel_n
// Kernel name tag passed (via __kernel_name_provider) to __parallel_merge_submitter;
// __parallel_merge instantiates it with the policy kernel name and the index type.
template <typename... _Name>
class __merge_kernel_name;

// Kernel name tag passed (via __kernel_name_provider) to __parallel_merge_submitter_large;
// __parallel_merge instantiates it with the policy kernel name and the index type.
template <typename... _Name>
class __merge_kernel_name_large;

template <typename _ExecutionPolicy, typename _Range1, typename _Range2, typename _Range3, typename _Compare>
auto
__parallel_merge(oneapi::dpl::__internal::__device_backend_tag, _ExecutionPolicy&& __exec, _Range1&& __rng1,
Expand All @@ -387,23 +394,47 @@ __parallel_merge(oneapi::dpl::__internal::__device_backend_tag, _ExecutionPolicy
using _CustomName = oneapi::dpl::__internal::__policy_kernel_name<_ExecutionPolicy>;

const auto __n = __rng1.size() + __rng2.size();
if (__n <= std::numeric_limits<std::uint32_t>::max())
if (__n < 4 * 1'048'576)
{
using _WiIndex = std::uint32_t;
using _MergeKernel = oneapi::dpl::__par_backend_hetero::__internal::__kernel_name_provider<
__merge_kernel_name<_CustomName, _WiIndex>>;
return __parallel_merge_submitter<_WiIndex, _MergeKernel>()(
std::forward<_ExecutionPolicy>(__exec), std::forward<_Range1>(__rng1), std::forward<_Range2>(__rng2),
std::forward<_Range3>(__rng3), __comp);
if (__n <= std::numeric_limits<std::uint32_t>::max())
{
using _WiIndex = std::uint32_t;
using _MergeKernel = oneapi::dpl::__par_backend_hetero::__internal::__kernel_name_provider<
__merge_kernel_name<_CustomName, _WiIndex>>;
return __parallel_merge_submitter<_WiIndex, _MergeKernel>()(
std::forward<_ExecutionPolicy>(__exec), std::forward<_Range1>(__rng1), std::forward<_Range2>(__rng2),
std::forward<_Range3>(__rng3), __comp);
}
else
{
using _WiIndex = std::uint64_t;
using _MergeKernel = oneapi::dpl::__par_backend_hetero::__internal::__kernel_name_provider<
__merge_kernel_name<_CustomName, _WiIndex>>;
return __parallel_merge_submitter<_WiIndex, _MergeKernel>()(
std::forward<_ExecutionPolicy>(__exec), std::forward<_Range1>(__rng1), std::forward<_Range2>(__rng2),
std::forward<_Range3>(__rng3), __comp);
}
}
else
{
using _WiIndex = std::uint64_t;
using _MergeKernel = oneapi::dpl::__par_backend_hetero::__internal::__kernel_name_provider<
__merge_kernel_name<_CustomName, _WiIndex>>;
return __parallel_merge_submitter<_WiIndex, _MergeKernel>()(
std::forward<_ExecutionPolicy>(__exec), std::forward<_Range1>(__rng1), std::forward<_Range2>(__rng2),
std::forward<_Range3>(__rng3), __comp);
if (__n <= std::numeric_limits<std::uint32_t>::max())
{
using _WiIndex = std::uint32_t;
using _MergeKernelLarge = oneapi::dpl::__par_backend_hetero::__internal::__kernel_name_provider<
__merge_kernel_name_large<_CustomName, _WiIndex>>;
return __parallel_merge_submitter_large<_WiIndex, _CustomName, _MergeKernelLarge>()(
std::forward<_ExecutionPolicy>(__exec), std::forward<_Range1>(__rng1), std::forward<_Range2>(__rng2),
std::forward<_Range3>(__rng3), __comp);
}
else
{
using _WiIndex = std::uint64_t;
using _MergeKernelLarge = oneapi::dpl::__par_backend_hetero::__internal::__kernel_name_provider<
__merge_kernel_name_large<_CustomName, _WiIndex>>;
return __parallel_merge_submitter_large<_WiIndex, _CustomName, _MergeKernelLarge>()(
std::forward<_ExecutionPolicy>(__exec), std::forward<_Range1>(__rng1), std::forward<_Range2>(__rng2),
std::forward<_Range3>(__rng3), __comp);
}
}
}

Expand Down

0 comments on commit d4721ca

Please sign in to comment.