From 1ee9affe03690255f091b193788d228d5b1c1003 Mon Sep 17 00:00:00 2001 From: MikeDvorskiy Date: Fri, 22 Nov 2024 18:41:48 +0100 Subject: [PATCH] [oneDPL][ranges][merge] + fixes --- .../oneapi/dpl/pstl/algorithm_ranges_impl.h | 32 +++++++++++-------- .../ranges/std_ranges_merge.pass.cpp | 16 +++++----- 2 files changed, 26 insertions(+), 22 deletions(-) diff --git a/include/oneapi/dpl/pstl/algorithm_ranges_impl.h b/include/oneapi/dpl/pstl/algorithm_ranges_impl.h index 4acadd6b726..ff3f1e65e78 100644 --- a/include/oneapi/dpl/pstl/algorithm_ranges_impl.h +++ b/include/oneapi/dpl/pstl/algorithm_ranges_impl.h @@ -492,7 +492,7 @@ __brick_merge(It1 __it_1, It1 __it_1_e, It2 __it_2, It2 __it_2_e, ItOut __it_out } else { - assert(__it_2 == __it_2_e); + //assert(__it_2 == __it_2_e); for(; __it_1 != __it_1_e && __it_out != __it_out_e; ++__it_1, ++__it_out) *__it_out = *__it_1; } @@ -511,9 +511,14 @@ __pattern_merge(__parallel_tag<_IsVector>, _ExecutionPolicy&& __exec, _R1&& __r1 std::invoke(__proj1, std::forward(__val1)), std::invoke(__proj2, std::forward(__val2)));}; - auto __n_1 = std::ranges::size(__r1); - auto __n_2 = std::ranges::size(__r2); - auto __n_out = std::min(__n_1 + __n_2, std::ranges::size(__out_r)); + using _Index1 = std::ranges::range_difference_t<_R1>; + using _Index2 = std::ranges::range_difference_t<_R2>; + using _Index3 = std::ranges::range_difference_t<_OutRange>; + + _Index1 __n_1 = std::ranges::size(__r1); + _Index2 __n_2 = std::ranges::size(__r2); + + _Index3 __n_out = std::min<_Index3>(__n_1 + __n_2, std::ranges::size(__out_r)); auto __it_1 = std::ranges::begin(__r1); auto __it_2 = std::ranges::begin(__r2); @@ -523,8 +528,8 @@ __pattern_merge(__parallel_tag<_IsVector>, _ExecutionPolicy&& __exec, _R1&& __r1 std::ranges::borrowed_iterator_t<_R1> __it_res_2; __internal::__except_handler([&]() { - __par_backend::__parallel_for(__backend_tag{}, ::std::forward<_ExecutionPolicy>(__exec), 0, __n_out, - [=, &__r1, &__r2, &__out_r, &__it_res_1, &__it_res_2](auto __i, auto __j) + __par_backend::__parallel_for(__backend_tag{}, ::std::forward<_ExecutionPolicy>(__exec), _Index3(0), __n_out, + [=, &__r1, &__r2, &__out_r, &__it_res_1, &__it_res_2](_Index3 __i, _Index3 __j) {/*...*/ //a start merging point on the merge path; for each thread @@ -534,19 +539,18 @@ __pattern_merge(__parallel_tag<_IsVector>, _ExecutionPolicy&& __exec, _R1&& __r1 if(__i > 0) { //calc merge path intersection: - using _Index3 = std::ranges::range_difference_t<_OutRange>; - _Index3 __d_size = std::abs(std::max(0, __i - __n_2) - (std::min(__i, __n_1) - 1)) + 1; + const _Index3 __d_size = std::abs(std::max<_Index2>(0, __i - __n_2) - (std::min<_Index1>(__i, __n_1) - 1)) + 1; - auto __get_x = [__i, __n_1](auto __d) { return std::min(__i, __n_1) - __d; }; - auto __get_y = [__i, __n_1](auto __d) { return std::max(0, __i - __n_1) + __d; }; + auto __get_x = [__i, __n_1](auto __d) { return std::min<_Index1>(__i, __n_1) - __d; }; + auto __get_y = [__i, __n_1](auto __d) { return std::max<_Index1>(0, __i - __n_1) + __d; }; oneapi::dpl::counting_iterator<_Index3> __it_d(0); auto __res_d = *std::lower_bound(__it_d, __it_d + __d_size, 1, - [](auto __d, auto __val) { + [&](auto __d, auto __val) { auto __x = __get_x(__d); auto __y = __get_y(__d); - const auto __res = __comp_2(__r1[__x], __r2[__y]) ? 0 : 1 + const auto __res = __comp_2(__r1[__x], __r2[__y]) ? 0 : 1; return __res < __val; } ); @@ -556,7 +560,7 @@ __pattern_merge(__parallel_tag<_IsVector>, _ExecutionPolicy&& __exec, _R1&& __r1 } - const auto __n = __j - __i; + const _Index3 __n = __j - __i; //serial merge n elements, starting from input x and y, to [i, j) output range auto __res = __brick_merge(__it_1 + __x, __it_1 + __n_1, @@ -647,7 +651,7 @@ __pattern_merge(__serial_tag, _ExecutionPolicy&& __ } else { - assert(__it_2 == std::ranges::end(__r2); + //assert(__it_2 == std::ranges::end(__r2); for(; __it_1 != std::ranges::end(__r1) && __it_out != std::ranges::end(__out_r); ++__it_1, ++__it_out) *__it_out = *__it_1; } diff --git a/test/parallel_api/ranges/std_ranges_merge.pass.cpp b/test/parallel_api/ranges/std_ranges_merge.pass.cpp index e381de3f73f..373145a20da 100644 --- a/test/parallel_api/ranges/std_ranges_merge.pass.cpp +++ b/test/parallel_api/ranges/std_ranges_merge.pass.cpp @@ -26,7 +26,7 @@ main() auto merge_checker = [](std::ranges::random_access_range auto&& r_1, std::ranges::random_access_range auto&& r_2, std::ranges::random_access_range auto&& r_out, auto comp, auto proj1, - auto proj2, auto&&... args) + auto proj2) { using ret_type = std::ranges::merge_result, std::ranges::borrowed_iterator_t, std::ranges::borrowed_iterator_t>; @@ -37,19 +37,19 @@ main() for(;;) { if(it_out == std::ranges::end(r_out)) - return ret_type{res.in1, res.in2, res.out}; + return ret_type{it_1, it_2, it_out}; if(it_1 == std::ranges::end(r_1)) { - for(auto it_2 = std::ranges::begin(r_2) && it_out != std::ranges::end(r_out); ++it_2, ++it_out) + for(auto it_2 = std::ranges::begin(r_2); it_out != std::ranges::end(r_out); ++it_2, ++it_out) *it_out = *it_2; - return ret_type{res.in1, res.in2, res.out}; + return ret_type{it_1, it_2, it_out}; } else if(it_2 == std::ranges::end(r_2)) { - for(auto it_1 = std::ranges::begin(r_1) && it_out != std::ranges::end(r_out); ++it_1, ++it_out) + for(auto it_1 = std::ranges::begin(r_1); it_out != std::ranges::end(r_out); ++it_1, ++it_out) *it_out = *it_1; - return ret_type{res.in1, res.in2, res.out}; + return ret_type{it_1, it_2, it_out}; } if (std::invoke(comp, std::invoke(proj1, *it_1), std::invoke(proj2, *it_2))) @@ -64,10 +64,10 @@ main() } } - return ret_type{res.in1, res.in2, res.out}; + return ret_type{it_1, it_2, it_out}; }; - test_range_algo<0, int, data_in_in_out>{big_sz}(dpl_ranges::merge, merge_checker, std::ranges::less{}); + test_range_algo<0, int, data_in_in_out>{big_sz}(dpl_ranges::merge, merge_checker, std::ranges::less{}, std::identity{}, std::identity{}); test_range_algo<1, int, data_in_in_out>{}(dpl_ranges::merge, merge_checker, std::ranges::less{}, proj, proj); test_range_algo<2, P2, data_in_in_out>{}(dpl_ranges::merge, merge_checker, std::ranges::less{}, &P2::x, &P2::x);