Skip to content

Commit

Permalink
[oneDPL][ranges][merge] support size limit for output; parallel pattern (based on merge path) for host back-ends
Browse files Browse the repository at this point in the history
  • Loading branch information
MikeDvorskiy committed Nov 22, 2024
1 parent 568036c commit be10c8a
Showing 1 changed file with 147 additions and 0 deletions.
147 changes: 147 additions & 0 deletions include/oneapi/dpl/pstl/algorithm_ranges_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -465,6 +465,153 @@ __pattern_merge(_Tag __tag, _ExecutionPolicy&& __exec, _R1&& __r1, _R2&& __r2, _
return __return_type{std::ranges::begin(__r1) + std::ranges::size(__r1), std::ranges::begin(__r2) + std::ranges::size(__r2), __res};
}

// Serial merge of [__it_1, __it_1_e) and [__it_2, __it_2_e) into the size-limited output
// [__it_out, __it_out_e).
//
// Writing stops as soon as either both inputs are exhausted or the output range is full.
// Returns the pair of input positions reached, i.e. the first element of each input that was
// NOT written to the output (or the respective end iterator).
//
// Tie-breaking: an element of the first input is taken only when __comp(*__it_1, *__it_2) is
// true; otherwise the element of the second input is taken. __comp is always called with the
// first-input element as its first argument. A caller that needs elements of the first input
// to precede equivalent elements of the second one must therefore pass a predicate that
// returns true for equivalent elements.
//
// Only operator!=, pre-increment and dereference are used on the iterators.
template <typename _It1, typename _It2, typename _ItOut, typename _Comp>
std::pair<_It1, _It2>
__brick_merge(_It1 __it_1, _It1 __it_1_e, _It2 __it_2, _It2 __it_2_e, _ItOut __it_out, _ItOut __it_out_e,
              _Comp __comp)
{
    // Both inputs still have elements. The output bound is part of the loop condition so that
    // it is checked BEFORE every write; an empty output range is never dereferenced.
    while (__it_1 != __it_1_e && __it_2 != __it_2_e && __it_out != __it_out_e)
    {
        if (__comp(*__it_1, *__it_2))
        {
            *__it_out = *__it_1;
            ++__it_1;
        }
        else
        {
            *__it_out = *__it_2;
            ++__it_2;
        }
        ++__it_out;
    }

    // Copy the tail of whichever input is left while there is room in the output.
    // At most one of the two loops below does any work: after the loop above either the
    // output is full (both loops are skipped) or at least one input is exhausted.
    for (; __it_1 != __it_1_e && __it_out != __it_out_e; ++__it_1, ++__it_out)
        *__it_out = *__it_1;

    for (; __it_2 != __it_2_e && __it_out != __it_out_e; ++__it_2, ++__it_out)
        *__it_out = *__it_2;

    return {__it_1, __it_2};
}

// Parallel merge of two sorted random-access ranges into a size-limited output range.
//
// At most min(size(__r1) + size(__r2), size(__out_r)) elements are written. The result holds
// the positions reached in both inputs (first element of each input that was not written) and
// the end of the written part of the output.
//
// The output index space [0, __n_out) is handed to __par_backend::__parallel_for. For a chunk
// starting at output index __i the body locates, by binary search along the "merge path"
// diagonal x + y == __i, how many elements of __r1 (x) and of __r2 (y) precede that output
// position, and then merges serially from (x, y) into the chunk with __brick_merge.
//
// NOTE(review): assumes __parallel_for calls the functor with half-open chunks [__i, __j) that
// together cover [0, __n_out) exactly once and that it returns only after all chunks are done
// (the results are read right after the call) - confirm against the backend.
template <typename _IsVector, typename _ExecutionPolicy, typename _R1, typename _R2, typename _OutRange, typename _Comp,
          typename _Proj1, typename _Proj2>
auto
__pattern_merge(__parallel_tag<_IsVector>, _ExecutionPolicy&& __exec, _R1&& __r1, _R2&& __r2, _OutRange&& __out_r,
                _Comp __comp, _Proj1 __proj1, _Proj2 __proj2)
{
    using __backend_tag = typename __parallel_tag<_IsVector>::__backend_tag;

    using _Index1 = std::ranges::range_difference_t<_R1>;
    using _Index2 = std::ranges::range_difference_t<_R2>;
    using _Index3 = std::ranges::range_difference_t<_OutRange>;
    // One signed type wide enough for all three difference types; every index below uses it,
    // so no signed/unsigned or mixed-width arithmetic takes place.
    using _Index = decltype(_Index1{} + _Index2{} + _Index3{});

    // "Not greater" predicate on projected values: true when the element of the first range
    // must be placed before the element of the second range. __brick_merge takes the
    // first-range element exactly when its predicate is true, so returning true for
    // equivalent elements makes first-range elements precede equivalent second-range ones.
    // __val1 is always an element of __r1 (projected with __proj1), __val2 an element of __r2.
    auto __comp_2 = [__comp, __proj1, __proj2](auto&& __val1, auto&& __val2) {
        return !std::invoke(__comp, std::invoke(__proj2, std::forward<decltype(__val2)>(__val2)),
                            std::invoke(__proj1, std::forward<decltype(__val1)>(__val1)));
    };

    const _Index __n_1 = static_cast<_Index>(std::ranges::size(__r1));
    const _Index __n_2 = static_cast<_Index>(std::ranges::size(__r2));
    const _Index __n_out = std::min<_Index>(__n_1 + __n_2, static_cast<_Index>(std::ranges::size(__out_r)));

    auto __it_1 = std::ranges::begin(__r1);
    auto __it_2 = std::ranges::begin(__r2);
    auto __it_out = std::ranges::begin(__out_r);

    // Initialised to the input beginnings: if nothing is written (__n_out == 0) the body below
    // may never run, and "nothing consumed" is then the correct result.
    std::ranges::borrowed_iterator_t<_R1> __it_res_1 = __it_1;
    std::ranges::borrowed_iterator_t<_R2> __it_res_2 = __it_2;

    __internal::__except_handler([&]() {
        __par_backend::__parallel_for(
            __backend_tag{}, ::std::forward<_ExecutionPolicy>(__exec), _Index{0}, __n_out,
            [&__it_res_1, &__it_res_2, __it_1, __it_2, __it_out, __n_1, __n_2, __n_out, __comp_2](auto __i, auto __j) {
                const _Index __out_first = static_cast<_Index>(__i);
                const _Index __out_last = static_cast<_Index>(__j);

                // Start point (x, y) on the merge path for this chunk: x + y == __out_first.
                _Index __x = 0;
                if (__out_first > 0)
                {
                    // Feasible x values: 0 <= x <= __n_1 and 0 <= y == __out_first - x <= __n_2.
                    _Index __lo = std::max<_Index>(0, __out_first - __n_2);
                    _Index __hi = std::min<_Index>(__out_first, __n_1);

                    // Find the first x in [__lo, __hi] for which r1[x] does not have to precede
                    // r2[__out_first - 1 - x]. The predicate is true for small x and false for
                    // large x because both inputs are sorted. For every probed x both indices
                    // are in range: x < __hi <= __n_1 and 0 <= __out_first - 1 - x < __n_2.
                    while (__lo < __hi)
                    {
                        const _Index __mid = __lo + (__hi - __lo) / 2;
                        if (__comp_2(__it_1[static_cast<_Index1>(__mid)],
                                     __it_2[static_cast<_Index2>(__out_first - 1 - __mid)]))
                            __lo = __mid + 1;
                        else
                            __hi = __mid;
                    }
                    __x = __lo;
                }
                const _Index __y = __out_first - __x;

                // Serial merge starting from inputs x and y into the output chunk [__i, __j).
                auto __res = __brick_merge(__it_1 + static_cast<_Index1>(__x), __it_1 + static_cast<_Index1>(__n_1),
                                           __it_2 + static_cast<_Index2>(__y), __it_2 + static_cast<_Index2>(__n_2),
                                           __it_out + static_cast<_Index3>(__out_first),
                                           __it_out + static_cast<_Index3>(__out_last), __comp_2);

                // Only the chunk that ends at __n_out knows the final input positions; it is
                // the only one that writes the shared result variables.
                if (__out_last == __n_out)
                {
                    __it_res_1 = __res.first;
                    __it_res_2 = __res.second;
                }
            });
    });

    using __return_type =
        std::ranges::merge_result<std::ranges::borrowed_iterator_t<_R1>, std::ranges::borrowed_iterator_t<_R2>,
                                  std::ranges::borrowed_iterator_t<_OutRange>>;

    return __return_type{__it_res_1, __it_res_2, __it_out + static_cast<_Index3>(__n_out)};
}

// TODO: disabled sketch of an index-based serial merge; it is excluded from compilation by
// the `#if 0` below and is not usable as written:
//  - `a`, `b`, `c` (the two inputs and the output) and the input bounds `x_e`, `y_e` are
//    referenced but are neither parameters nor declared in this snippet;
//  - the element comparison is a hard-coded `operator<`, no comparator is taken;
//  - nothing is returned, so the reached input positions `x` and `y` are not reported.
// Intended behaviour, as far as the code shows: fill output positions [i, j) starting from
// input positions x and y, taking from `b` once `a` is exhausted and vice versa.
#if 0
template<typename Index1, typename Index2, typename Index3>
auto
__brick_merge(Index1 x, Index2 y, Index3 i, Index3 j)
{
for(Index3 k = i; k < j; ++k)
{
// first input exhausted: the remaining output must come from the second input
if(x >= x_e)
{
assert(y < y_e);
c[k] = b[y];
++y;
}
// second input exhausted: the remaining output must come from the first input
else if(y >= y_e)
{
assert(x < x_e);
c[k] = a[x];
++x;
}
// both inputs available: take from the first input only when it is strictly smaller
else if(a[x] < b[y])
{
c[k] = a[x];
++x;
}
else
{
c[k] = b[y];
++y;
}
}
}
#endif


template<typename _ExecutionPolicy, typename _R1, typename _R2, typename _OutRange, typename _Comp,
typename _Proj1, typename _Proj2>
auto
Expand Down

0 comments on commit be10c8a

Please sign in to comment.