Skip to content

Commit

Permalink
[oneDPL][ranges] + dispatch fixes for ranges::sort, ranges::sort_stab…
Browse files Browse the repository at this point in the history
…le implementation
  • Loading branch information
MikeDvorskiy committed Sep 24, 2024
1 parent 284df8c commit de0349b
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 11 deletions.
31 changes: 28 additions & 3 deletions include/oneapi/dpl/pstl/algorithm_ranges_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -342,11 +342,9 @@ __pattern_sort_ranges(_Tag __tag, _ExecutionPolicy&& __exec, _R&& __r, _Comp __c

auto __comp_2 = [__comp, __proj](auto&& __val1, auto&& __val2) { return std::invoke(__comp, std::invoke(__proj,
std::forward<decltype(__val1)>(__val1)), std::invoke(__proj, std::forward<decltype(__val2)>(__val2)));};
// Use stable sort as a leaf since __pattern_sort_ranges is shared between sort and stable_sort
// TODO: add a separate pattern for ranges::sort for better performance
oneapi::dpl::__internal::__pattern_sort(__tag, std::forward<_ExecutionPolicy>(__exec), std::ranges::begin(__r),
std::ranges::begin(__r) + std::ranges::size(__r), __comp_2,
std::ranges::stable_sort);
std::ranges::sort);

return std::ranges::borrowed_iterator_t<_R>(std::ranges::begin(__r) + std::ranges::size(__r));
}
Expand All @@ -355,6 +353,33 @@ template <typename _ExecutionPolicy, typename _R, typename _Proj, typename _Comp
auto
__pattern_sort_ranges(__serial_tag</*IsVector*/ std::false_type>, _ExecutionPolicy&& __exec, _R&& __r, _Comp __comp,
_Proj __proj)
{
return std::ranges::sort(std::forward<_R>(__r), __comp, __proj);
}

//---------------------------------------------------------------------------------------------------------------------
// pattern_stable_sort_ranges
//---------------------------------------------------------------------------------------------------------------------

template <typename _Tag, typename _ExecutionPolicy, typename _R, typename _Proj, typename _Comp>
auto
__pattern_stable_sort_ranges(_Tag __tag, _ExecutionPolicy&& __exec, _R&& __r, _Comp __comp, _Proj __proj)
{
static_assert(__is_parallel_tag_v<_Tag> || typename _Tag::__is_vector{});

auto __comp_2 = [__comp, __proj](auto&& __val1, auto&& __val2) { return std::invoke(__comp, std::invoke(__proj,
std::forward<decltype(__val1)>(__val1)), std::invoke(__proj, std::forward<decltype(__val2)>(__val2)));};
oneapi::dpl::__internal::__pattern_sort(__tag, std::forward<_ExecutionPolicy>(__exec), std::ranges::begin(__r),
std::ranges::begin(__r) + std::ranges::size(__r), __comp_2,
std::ranges::stable_sort);

return std::ranges::borrowed_iterator_t<_R>(std::ranges::begin(__r) + std::ranges::size(__r));
}

template <typename _ExecutionPolicy, typename _R, typename _Proj, typename _Comp>
auto
__pattern_stable_sort_ranges(__serial_tag</*IsVector*/ std::false_type>, _ExecutionPolicy&& __exec, _R&& __r, _Comp __comp,
_Proj __proj)
{
return std::ranges::stable_sort(std::forward<_R>(__r), __comp, __proj);
}
Expand Down
13 changes: 8 additions & 5 deletions include/oneapi/dpl/pstl/glue_algorithm_ranges_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -421,7 +421,7 @@ struct __stable_sort_fn
operator()(_ExecutionPolicy&& __exec, _R&& __r, _Comp __comp = {}, _Proj __proj = {}) const
{
const auto __dispatch_tag = oneapi::dpl::__ranges::__select_backend(__exec);
return oneapi::dpl::__internal::__ranges::__pattern_sort_ranges(__dispatch_tag,
return oneapi::dpl::__internal::__ranges::__pattern_stable_sort_ranges(__dispatch_tag,
std::forward<_ExecutionPolicy>(__exec), std::forward<_R>(__r), __comp, __proj);
}
}; //__stable_sort_fn
Expand All @@ -440,8 +440,9 @@ struct __sort_fn
auto
operator()(_ExecutionPolicy&& __exec, _R&& __r, _Comp __comp = {}, _Proj __proj = {}) const
{
return oneapi::dpl::ranges::stable_sort(std::forward<_ExecutionPolicy>(__exec), std::forward<_R>(__r),
__comp, __proj);
const auto __dispatch_tag = oneapi::dpl::__ranges::__select_backend(__exec);
return oneapi::dpl::__internal::__ranges::__pattern_sort_ranges(__dispatch_tag,
std::forward<_ExecutionPolicy>(__exec), std::forward<_R>(__r), __comp, __proj);
}
}; //__sort_fn
} //__internal
Expand Down Expand Up @@ -1062,14 +1063,16 @@ template <typename _ExecutionPolicy, typename _Range, typename _Compare>
oneapi::dpl::__internal::__enable_if_execution_policy<_ExecutionPolicy>
stable_sort(_ExecutionPolicy&& __exec, _Range&& __rng, _Compare __comp)
{
sort(::std::forward<_ExecutionPolicy>(__exec), ::std::forward<_Range>(__rng), __comp);
const auto __dispatch_tag = oneapi::dpl::__ranges::__select_backend(__exec, __rng);
oneapi::dpl::__internal::__ranges::__pattern_stable_sort(__dispatch_tag, std::forward<_ExecutionPolicy>(__exec),
views::all(std::forward<_Range>(__rng)), __comp, __proj);
}

template <typename _ExecutionPolicy, typename _Range>
oneapi::dpl::__internal::__enable_if_execution_policy<_ExecutionPolicy>
stable_sort(_ExecutionPolicy&& __exec, _Range&& __rng)
{
sort(::std::forward<_ExecutionPolicy>(__exec), ::std::forward<_Range>(__rng),
stable_sort(std::forward<_ExecutionPolicy>(__exec), std::forward<_Range>(__rng),
oneapi::dpl::__internal::__pstl_less());
}

Expand Down
15 changes: 12 additions & 3 deletions include/oneapi/dpl/pstl/hetero/algorithm_ranges_impl_hetero.h
Original file line number Diff line number Diff line change
Expand Up @@ -749,7 +749,7 @@ __pattern_merge(__hetero_tag<_BackendTag> __tag, _ExecutionPolicy&& __exec, _R1&

template <typename _BackendTag, typename _ExecutionPolicy, typename _Range, typename _Compare, typename _Proj>
void
__pattern_sort(__hetero_tag<_BackendTag>, _ExecutionPolicy&& __exec, _Range&& __rng, _Compare __comp, _Proj __proj)
__pattern_stable_sort(__hetero_tag<_BackendTag>, _ExecutionPolicy&& __exec, _Range&& __rng, _Compare __comp, _Proj __proj)
{
if (__rng.size() >= 2)
__par_backend_hetero::__parallel_stable_sort(_BackendTag{}, ::std::forward<_ExecutionPolicy>(__exec),
Expand All @@ -758,16 +758,25 @@ __pattern_sort(__hetero_tag<_BackendTag>, _ExecutionPolicy&& __exec, _Range&& __
}

#if _ONEDPL_CPP20_RANGES_PRESENT

template <typename _BackendTag, typename _ExecutionPolicy, typename _R, typename _Comp, typename _Proj>
auto
__pattern_sort_ranges(__hetero_tag<_BackendTag> __tag, _ExecutionPolicy&& __exec, _R&& __r, _Comp __comp, _Proj __proj)
__pattern_stable_sort_ranges(__hetero_tag<_BackendTag> __tag, _ExecutionPolicy&& __exec, _R&& __r, _Comp __comp, _Proj __proj)
{
oneapi::dpl::__internal::__ranges::__pattern_sort(__tag, std::forward<_ExecutionPolicy>(__exec),
oneapi::dpl::__internal::__ranges::__pattern_stable_sort(__tag, std::forward<_ExecutionPolicy>(__exec),
oneapi::dpl::__ranges::views::all(__r), __comp, __proj);

return std::ranges::borrowed_iterator_t<_R>(std::ranges::begin(__r) + std::ranges::size(__r));
}

template <typename _BackendTag, typename _ExecutionPolicy, typename _R, typename _Comp, typename _Proj>
auto
__pattern_sort_ranges(__hetero_tag<_BackendTag> __tag, _ExecutionPolicy&& __exec, _R&& __r, _Comp __comp, _Proj __proj)
{
//TODO: to use potentially more effective unstable sorting pattern (not implemented yet)
return __pattern_stable_sort_ranges(__tag, std::forward<_ExecutionPolicy>(__exec), std::forward<_R>(__r), __comp, __proj);
}

#endif //_ONEDPL_CPP20_RANGES_PRESENT

//------------------------------------------------------------------------
Expand Down

0 comments on commit de0349b

Please sign in to comment.