Skip to content

Commit

Permalink
[oneDPL][ranges] + fixes in ranges::search implementation; + test for…
Browse files Browse the repository at this point in the history
… ranges::search
  • Loading branch information
MikeDvorskiy committed Mar 21, 2024
1 parent 131003e commit 0d1c1d5
Show file tree
Hide file tree
Showing 4 changed files with 65 additions and 30 deletions.
4 changes: 2 additions & 2 deletions include/oneapi/dpl/pstl/algorithm_ranges_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -234,11 +234,11 @@ __pattern_search_impl(_Tag __tag, _ExecutionPolicy&& __exec, _R1&& __r1, _R2&& _
auto __pred_2 =
[__pred, __proj1, __proj2](auto&& __val1, auto&& __val2) { return __pred(__proj1(__val1), __proj2(__val2));};

auto __res = oneapi::dpl::__internal::__pattern_search(std::forward<_ExecutionPolicy>(__exec),
auto __res = oneapi::dpl::__internal::__pattern_search(__tag, std::forward<_ExecutionPolicy>(__exec),
std::ranges::begin(__r1), std::ranges::begin(__r1) + __r1.size(), std::ranges::begin(__r2),
std::ranges::begin(__r2) + __r2.size(), __pred_2);

return std::ranges::borrowed_subrange_t<_R1>(__res, __res + __r2.size());
return std::ranges::borrowed_subrange_t<_R1>(__res, __res == std::ranges::end(__r1) ? __res : __res + __r2.size());
}

template<typename _IsVector, typename _ExecutionPolicy, typename _R1, typename _R2, typename _Pred, typename _Proj1,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -313,7 +313,8 @@ __pattern_search(__hetero_tag<_BackendTag> __tag, _ExecutionPolicy&& __exec, _R1
oneapi::dpl::views::all_read(::std::forward<_R1>(__r1)), oneapi::dpl::views::all_read(::std::forward<_R2>(__r2)),
__pred_2);

return std::ranges::borrowed_subrange_t<_R1>(__r1.begin() + __idx, __r1.begin() + __idx + __r2.size());
auto __end = (__idx == __r1.size() ? __r1.begin() + __idx : __r1.begin() + __idx + __r2.size());
return std::ranges::borrowed_subrange_t<_R1>(__r1.begin() + __idx, __end);
}

//------------------------------------------------------------------------
Expand Down
4 changes: 3 additions & 1 deletion test/parallel_api/ranges/std_ranges.pass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ main()
auto f = [](auto&& val) -> decltype(auto) { return val * val; };
auto proj = [](auto&& val) -> decltype(auto) { return val * 2; };
auto pred = [](auto&& val) -> decltype(auto) { return val == 5; };
auto pred_2 = [](auto&& val1, auto&& val2) -> decltype(auto) { return val1 == val2; };

using namespace test_std_ranges;

Expand All @@ -46,9 +47,10 @@ main()
test_range_algo{}(oneapi::dpl::ranges::any_of, std::ranges::any_of, pred2, std::identity{});
test_range_algo{}(oneapi::dpl::ranges::none_of, std::ranges::none_of, pred3, std::identity{});

auto pred_2 = [](auto&& val1, auto&& val2) -> decltype(auto) { return val1 == val2; };
test_range_algo{}(oneapi::dpl::ranges::adjacent_find, std::ranges::adjacent_find, pred_2, proj);

test_range_algo<data_in_in>{}(oneapi::dpl::ranges::search, std::ranges::search, pred_2, proj);

#endif //_ENABLE_STD_RANGES_TESTING

return TestUtils::done(_ENABLE_STD_RANGES_TESTING);
Expand Down
84 changes: 58 additions & 26 deletions test/parallel_api/ranges/std_ranges_test.h
Original file line number Diff line number Diff line change
Expand Up @@ -121,14 +121,44 @@ struct test
auto bres_in = ret_in_val(expected_res, src_view.begin()) == ret_in_val(res, tr(A).begin());
EXPECT_TRUE(bres_in, (std::string("wrong return value from algo with input range: ") + typeid(Algo).name()).c_str());

auto bres_out = ret_out_val(expected_res, src_view.begin()) == ret_out_val(res, tr(A).begin());
auto bres_out = ret_out_val(expected_res, expected) == ret_out_val(res, B.begin());
EXPECT_TRUE(bres_out, (std::string("wrong return value from algo with output range: ") + typeid(Algo).name()).c_str());
}

//check result
EXPECT_EQ_N(expected, data_out, max_n, (std::string("wrong effect algo with ranges: ") + typeid(Algo).name()).c_str());
}

template<typename Policy, typename Algo, typename Checker, typename Functor, typename Proj = std::identity,
typename Transform = std::identity>
std::enable_if_t<!std::is_same_v<Policy, std::true_type> && Ranges == data_in_in>
operator()(Policy&& exec, Algo algo, Checker checker, Functor f, Proj proj = {}, Transform tr = {})
{
constexpr int max_n = 10;
int data_in1[max_n] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
int data_in2[max_n] = {0, 0, 2, 3, 4, 5, 0, 0, 0, 0};

auto src_view1 = tr(std::ranges::subrange(data_in1, data_in1 + max_n));
auto src_view2 = tr(std::ranges::subrange(data_in2, data_in2 + max_n));
auto expected_res = checker(src_view1, src_view2, f, proj, proj);
{
Container cont_in1(exec, data_in1, max_n);
Container cont_in2(exec, data_in2, max_n);

typename Container::type& A = cont_in1();
typename Container::type& B = cont_in2();

auto res = algo(exec, tr(A), tr(B), f, proj, proj);

if constexpr(RetTypeCheck)
static_assert(std::is_same_v<decltype(res), decltype(checker(tr(A), tr(B), f, proj, proj))>, "Wrong return type");

auto bres_in = ret_in_val(expected_res, src_view1.begin()) == ret_in_val(res, tr(A).begin());
EXPECT_TRUE(bres_in, (std::string("wrong return value from algo: ") + typeid(Algo).name() +
typeid(decltype(tr(std::declval<Container&>()()))).name()).c_str());
}
}

private:

template<typename, typename = void>
Expand All @@ -138,48 +168,50 @@ struct test
static constexpr
bool is_iterator<T, std::void_t<decltype(++std::declval<T&>()), decltype(*std::declval<T&>())>> = true;

template <typename Ret>
static constexpr auto check_in(int) -> decltype(std::declval<Ret>().in, std::true_type{})
{
return {};
}
template<typename, typename = void>
static constexpr bool check_in{};

template <typename Ret>
static constexpr auto check_in(...) -> std::false_type
{
return {};
}
template<typename T>
static constexpr
bool check_in<T, std::void_t<decltype(std::declval<T>().in)>> = true;

template<typename, typename = void>
static constexpr bool check_out{};

template<typename T>
static constexpr
bool check_out<T, std::void_t<decltype(std::declval<T>().out)>> = true;


template<typename, typename = void>
static constexpr bool is_range{};

template<typename T>
static constexpr
bool is_range<T, std::void_t<decltype(std::declval<T&>().begin())>> = true;

template<typename Ret, typename Begin>
auto ret_in_val(Ret&& ret, Begin&& begin)
{
if constexpr (check_in<Ret>(0))
if constexpr (check_in<Ret>)
return std::distance(begin, ret.in);
else if constexpr (is_iterator<Ret>)
return std::distance(begin, ret);
else if constexpr(is_range<Ret>)
return std::pair{std::distance(begin, ret.begin()), std::ranges::distance(ret.begin(), ret.end())};
else
return ret;
}

template <typename Ret>
static constexpr auto check_out(int) -> decltype(std::declval<Ret>().out, std::true_type{})
{
return {};
}

template <typename Ret>
static constexpr auto check_out(...) -> std::false_type
{
return {};
}

template<typename Ret, typename Begin>
auto ret_out_val(Ret&& ret, Begin&& begin)
{
if constexpr (check_out<Ret>(0))
return std::distance(begin, ret.in);
if constexpr (check_out<Ret>)
return std::distance(begin, ret.out);
else if constexpr (is_iterator<Ret>)
return std::distance(begin, ret);
else if constexpr(is_range<Ret>)
return std::pair{std::distance(begin, ret.begin()), std::ranges::distance(ret.begin(), ret.end())};
else
return ret;
}
Expand Down

0 comments on commit 0d1c1d5

Please sign in to comment.