diff --git a/include/oneapi/dpl/pstl/algorithm_ranges_impl.h b/include/oneapi/dpl/pstl/algorithm_ranges_impl.h index 3aeb9ef4bd8..e2753a51a18 100644 --- a/include/oneapi/dpl/pstl/algorithm_ranges_impl.h +++ b/include/oneapi/dpl/pstl/algorithm_ranges_impl.h @@ -234,11 +234,11 @@ __pattern_search_impl(_Tag __tag, _ExecutionPolicy&& __exec, _R1&& __r1, _R2&& _ auto __pred_2 = [__pred, __proj1, __proj2](auto&& __val1, auto&& __val2) { return __pred(__proj1(__val1), __proj2(__val2));}; - auto __res = oneapi::dpl::__internal::__pattern_search(std::forward<_ExecutionPolicy>(__exec), + auto __res = oneapi::dpl::__internal::__pattern_search(__tag, std::forward<_ExecutionPolicy>(__exec), std::ranges::begin(__r1), std::ranges::begin(__r1) + __r1.size(), std::ranges::begin(__r2), std::ranges::begin(__r2) + __r2.size(), __pred_2); - return std::ranges::borrowed_subrange_t<_R1>(__res, __res + __r2.size()); + return std::ranges::borrowed_subrange_t<_R1>(__res, __res == std::ranges::end(__r1) ? __res : __res + __r2.size()); } template __tag, _ExecutionPolicy&& __exec, _R1 oneapi::dpl::views::all_read(::std::forward<_R1>(__r1)), oneapi::dpl::views::all_read(::std::forward<_R2>(__r2)), __pred_2); - return std::ranges::borrowed_subrange_t<_R1>(__r1.begin() + __idx, __r1.begin() + __idx + __r2.size()); + auto __end = (__idx == __r1.size() ? __r1.begin() + __idx : __r1.begin() + __idx + __r2.size()); + return std::ranges::borrowed_subrange_t<_R1>(__r1.begin() + __idx, __end); } //------------------------------------------------------------------------ diff --git a/test/parallel_api/ranges/std_ranges.pass.cpp b/test/parallel_api/ranges/std_ranges.pass.cpp index 3f6f4c1bf3a..f8b9093a553 100644 --- a/test/parallel_api/ranges/std_ranges.pass.cpp +++ b/test/parallel_api/ranges/std_ranges.pass.cpp @@ -26,6 +26,7 @@ main() auto f = [](auto&& val) -> decltype(auto) { return val * val; }; auto proj = [](auto&& val) -> decltype(auto) { return val * 2; }; auto pred = [](auto&& val) -> decltype(auto) { return val == 5; }; + auto pred_2 = [](auto&& val1, auto&& val2) -> decltype(auto) { return val1 == val2; }; using namespace test_std_ranges; @@ -46,9 +47,10 @@ main() test_range_algo{}(oneapi::dpl::ranges::any_of, std::ranges::any_of, pred2, std::identity{}); test_range_algo{}(oneapi::dpl::ranges::none_of, std::ranges::none_of, pred3, std::identity{}); - auto pred_2 = [](auto&& val1, auto&& val2) -> decltype(auto) { return val1 == val2; }; test_range_algo{}(oneapi::dpl::ranges::adjacent_find, std::ranges::adjacent_find, pred_2, proj); + test_range_algo{}(oneapi::dpl::ranges::search, std::ranges::search, pred_2, proj); + #endif //_ENABLE_STD_RANGES_TESTING return TestUtils::done(_ENABLE_STD_RANGES_TESTING); diff --git a/test/parallel_api/ranges/std_ranges_test.h b/test/parallel_api/ranges/std_ranges_test.h index e856e1d57ae..7c7524f9607 100644 --- a/test/parallel_api/ranges/std_ranges_test.h +++ b/test/parallel_api/ranges/std_ranges_test.h @@ -121,7 +121,7 @@ struct test auto bres_in = ret_in_val(expected_res, src_view.begin()) == ret_in_val(res, tr(A).begin()); EXPECT_TRUE(bres_in, (std::string("wrong return value from algo with input range: ") + typeid(Algo).name()).c_str()); - auto bres_out = ret_out_val(expected_res, src_view.begin()) == ret_out_val(res, tr(A).begin()); + auto bres_out = ret_out_val(expected_res, expected) == ret_out_val(res, B.begin()); EXPECT_TRUE(bres_out, (std::string("wrong return value from algo with output range: ") + typeid(Algo).name()).c_str()); } @@ -129,6 +129,36 @@ struct test EXPECT_EQ_N(expected, data_out, max_n, (std::string("wrong effect algo with ranges: ") + typeid(Algo).name()).c_str()); } + template + std::enable_if_t && Ranges == data_in_in> + operator()(Policy&& exec, Algo algo, Checker checker, Functor f, Proj proj = {}, Transform tr = {}) + { + constexpr int max_n = 10; + int data_in1[max_n] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}; + int data_in2[max_n] = {0, 0, 2, 3, 4, 5, 0, 0, 0, 0}; + + auto src_view1 = tr(std::ranges::subrange(data_in1, data_in1 + max_n)); + auto src_view2 = tr(std::ranges::subrange(data_in2, data_in2 + max_n)); + auto expected_res = checker(src_view1, src_view2, f, proj, proj); + { + Container cont_in1(exec, data_in1, max_n); + Container cont_in2(exec, data_in2, max_n); + + typename Container::type& A = cont_in1(); + typename Container::type& B = cont_in2(); + + auto res = algo(exec, tr(A), tr(B), f, proj, proj); + + if constexpr(RetTypeCheck) + static_assert(std::is_same_v, "Wrong return type"); + + auto bres_in = ret_in_val(expected_res, src_view1.begin()) == ret_in_val(res, tr(A).begin()); + EXPECT_TRUE(bres_in, (std::string("wrong return value from algo: ") + typeid(Algo).name() + + typeid(decltype(tr(std::declval()()))).name()).c_str()); + } + } + private: template @@ -138,48 +168,50 @@ struct test static constexpr bool is_iterator()), decltype(*std::declval())>> = true; - template - static constexpr auto check_in(int) -> decltype(std::declval().in, std::true_type{}) - { - return {}; - } + template + static constexpr bool check_in{}; - template - static constexpr auto check_in(...) -> std::false_type - { - return {}; - } + template + static constexpr + bool check_in().in)>> = true; + + template + static constexpr bool check_out{}; + + template + static constexpr + bool check_out().out)>> = true; + + + template + static constexpr bool is_range{}; + + template + static constexpr + bool is_range().begin())>> = true; template auto ret_in_val(Ret&& ret, Begin&& begin) { - if constexpr (check_in(0)) + if constexpr (check_in) return std::distance(begin, ret.in); else if constexpr (is_iterator) return std::distance(begin, ret); + else if constexpr(is_range) + return std::pair{std::distance(begin, ret.begin()), std::ranges::distance(ret.begin(), ret.end())}; else return ret; } - - template - static constexpr auto check_out(int) -> decltype(std::declval().out, std::true_type{}) - { - return {}; - } - - template - static constexpr auto check_out(...) -> std::false_type - { - return {}; - } template auto ret_out_val(Ret&& ret, Begin&& begin) { - if constexpr (check_out(0)) - return std::distance(begin, ret.in); + if constexpr (check_out) + return std::distance(begin, ret.out); else if constexpr (is_iterator) return std::distance(begin, ret); + else if constexpr(is_range) + return std::pair{std::distance(begin, ret.begin()), std::ranges::distance(ret.begin(), ret.end())}; else return ret; }