Skip to content

Commit

Permalink
[oneDPL][ranges][merge] + fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
MikeDvorskiy committed Nov 22, 2024
1 parent 0ac0ae3 commit 1ee9aff
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 22 deletions.
32 changes: 18 additions & 14 deletions include/oneapi/dpl/pstl/algorithm_ranges_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -492,7 +492,7 @@ __brick_merge(It1 __it_1, It1 __it_1_e, It2 __it_2, It2 __it_2_e, ItOut __it_out
}
else
{
assert(__it_2 == __it_2_e);
//assert(__it_2 == __it_2_e);
for(; __it_1 != __it_1_e && __it_out != __it_out_e; ++__it_1, ++__it_out)
*__it_out = *__it_1;
}
Expand All @@ -511,9 +511,14 @@ __pattern_merge(__parallel_tag<_IsVector>, _ExecutionPolicy&& __exec, _R1&& __r1
std::invoke(__proj1, std::forward<decltype(__val1)>(__val1)), std::invoke(__proj2,
std::forward<decltype(__val2)>(__val2)));};

auto __n_1 = std::ranges::size(__r1);
auto __n_2 = std::ranges::size(__r2);
auto __n_out = std::min(__n_1 + __n_2, std::ranges::size(__out_r));
using _Index1 = std::ranges::range_difference_t<_R1>;
using _Index2 = std::ranges::range_difference_t<_R2>;
using _Index3 = std::ranges::range_difference_t<_OutRange>;

_Index1 __n_1 = std::ranges::size(__r1);
_Index2 __n_2 = std::ranges::size(__r2);

_Index3 __n_out = std::min<_Index3>(__n_1 + __n_2, std::ranges::size(__out_r));

auto __it_1 = std::ranges::begin(__r1);
auto __it_2 = std::ranges::begin(__r2);
Expand All @@ -523,8 +528,8 @@ __pattern_merge(__parallel_tag<_IsVector>, _ExecutionPolicy&& __exec, _R1&& __r1
std::ranges::borrowed_iterator_t<_R1> __it_res_2;

__internal::__except_handler([&]() {
__par_backend::__parallel_for(__backend_tag{}, ::std::forward<_ExecutionPolicy>(__exec), 0, __n_out,
[=, &__r1, &__r2, &__out_r, &__it_res_1, &__it_res_2](auto __i, auto __j)
__par_backend::__parallel_for(__backend_tag{}, ::std::forward<_ExecutionPolicy>(__exec), _Index3(0), __n_out,
[=, &__r1, &__r2, &__out_r, &__it_res_1, &__it_res_2](_Index3 __i, _Index3 __j)
{/*...*/

//a start merging point on the merge path; for each thread
Expand All @@ -534,19 +539,18 @@ __pattern_merge(__parallel_tag<_IsVector>, _ExecutionPolicy&& __exec, _R1&& __r1
if(__i > 0)
{
//calc merge path intersection:
using _Index3 = std::ranges::range_difference_t<_OutRange>;
_Index3 __d_size = std::abs(std::max(0, __i - __n_2) - (std::min(__i, __n_1) - 1)) + 1;
const _Index3 __d_size = std::abs(std::max<_Index2>(0, __i - __n_2) - (std::min<_Index1>(__i, __n_1) - 1)) + 1;

auto __get_x = [__i, __n_1](auto __d) { return std::min(__i, __n_1) - __d; };
auto __get_y = [__i, __n_1](auto __d) { return std::max(0, __i - __n_1) + __d; };
auto __get_x = [__i, __n_1](auto __d) { return std::min<_Index1>(__i, __n_1) - __d; };
auto __get_y = [__i, __n_1](auto __d) { return std::max<_Index1>(0, __i - __n_1) + __d; };

oneapi::dpl::counting_iterator<_Index3> __it_d(0);
auto __res_d = *std::lower_bound(__it_d, __it_d + __d_size, 1,
[](auto __d, auto __val) {
[&](auto __d, auto __val) {
auto __x = __get_x(__d);
auto __y = __get_y(__d);

const auto __res = __comp_2(__r1[__x], __r2[__y]) ? 0 : 1
const auto __res = __comp_2(__r1[__x], __r2[__y]) ? 0 : 1;
return __res < __val;
}
);
Expand All @@ -556,7 +560,7 @@ __pattern_merge(__parallel_tag<_IsVector>, _ExecutionPolicy&& __exec, _R1&& __r1

}

const auto __n = __j - __i;
const _Index3 __n = __j - __i;

//serial merge n elements, starting from input x and y, to [i, j) output range
auto __res = __brick_merge(__it_1 + __x, __it_1 + __n_1,
Expand Down Expand Up @@ -647,7 +651,7 @@ __pattern_merge(__serial_tag</*IsVector*/std::false_type>, _ExecutionPolicy&& __
}
else
{
assert(__it_2 == std::ranges::end(__r2);
//assert(__it_2 == std::ranges::end(__r2);
for(; __it_1 != std::ranges::end(__r1) && __it_out != std::ranges::end(__out_r); ++__it_1, ++__it_out)
*__it_out = *__it_1;
}
Expand Down
16 changes: 8 additions & 8 deletions test/parallel_api/ranges/std_ranges_merge.pass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ main()
auto merge_checker = [](std::ranges::random_access_range auto&& r_1,
std::ranges::random_access_range auto&& r_2,
std::ranges::random_access_range auto&& r_out, auto comp, auto proj1,
auto proj2, auto&&... args)
auto proj2)
{
using ret_type = std::ranges::merge_result<std::ranges::borrowed_iterator_t<decltype(r_1)>,
std::ranges::borrowed_iterator_t<decltype(r_2)>, std::ranges::borrowed_iterator_t<decltype(r_out)>>;
Expand All @@ -37,19 +37,19 @@ main()
for(;;)
{
if(it_out == std::ranges::end(r_out))
return ret_type{res.in1, res.in2, res.out};
return ret_type{it_1, it_2, it_out};

if(it_1 == std::ranges::end(r_1))
{
for(auto it_2 = std::ranges::begin(r_2) && it_out != std::ranges::end(r_out); ++it_2, ++it_out)
for(auto it_2 = std::ranges::begin(r_2); it_out != std::ranges::end(r_out); ++it_2, ++it_out)
*it_out = *it_2;
return ret_type{res.in1, res.in2, res.out};
return ret_type{it_1, it_2, it_out};
}
else if(it_2 == std::ranges::end(r_2))
{
for(auto it_1 = std::ranges::begin(r_1) && it_out != std::ranges::end(r_out); ++it_1, ++it_out)
for(auto it_1 = std::ranges::begin(r_1); it_out != std::ranges::end(r_out); ++it_1, ++it_out)
*it_out = *it_1;
return ret_type{res.in1, res.in2, res.out};
return ret_type{it_1, it_2, it_out};
}

if (std::invoke(comp, std::invoke(proj1, *it_1), std::invoke(proj2, *it_2)))
Expand All @@ -64,10 +64,10 @@ main()
}
}

return ret_type{res.in1, res.in2, res.out};
return ret_type{it_1, it_2, it_out};
};

test_range_algo<0, int, data_in_in_out>{big_sz}(dpl_ranges::merge, merge_checker, std::ranges::less{});
test_range_algo<0, int, data_in_in_out>{big_sz}(dpl_ranges::merge, merge_checker, std::ranges::less{}, std::identity{}, std::identity{});

test_range_algo<1, int, data_in_in_out>{}(dpl_ranges::merge, merge_checker, std::ranges::less{}, proj, proj);
test_range_algo<2, P2, data_in_in_out>{}(dpl_ranges::merge, merge_checker, std::ranges::less{}, &P2::x, &P2::x);
Expand Down

0 comments on commit 1ee9aff

Please sign in to comment.