Skip to content

Commit

Permalink
[oneDPL][ranges] + a test for ranges::adjacent_find; + range tests im…
Browse files Browse the repository at this point in the history
…provements
  • Loading branch information
MikeDvorskiy committed Mar 20, 2024
1 parent 5ca1e3f commit 4aa794e
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 38 deletions.
23 changes: 15 additions & 8 deletions test/parallel_api/ranges/std_ranges.pass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,20 +27,27 @@ main()
auto proj = [](auto&& val) -> decltype(auto) { return val * 2; };
auto pred = [](auto&& val) -> decltype(auto) { return val == 5; };

test_range_algo<1>{}(oneapi::dpl::ranges::for_each, std::ranges::for_each, f_mutuable, proj_mutuable);
test_range_algo<2>{}(oneapi::dpl::ranges::transform, std::ranges::transform, f, proj);
using namespace test_std_ranges;

test_range_algo<1>{}(oneapi::dpl::ranges::find_if, std::ranges::find_if, pred, proj);
test_range_algo<1>{}(oneapi::dpl::ranges::find_if_not, std::ranges::find_if_not, pred, proj);
test_range_algo<1>{}(oneapi::dpl::ranges::find, std::ranges::find, 4, proj);
test_range_algo{}(oneapi::dpl::ranges::for_each, std::ranges::for_each, f_mutuable, proj_mutuable);

//TODO: oneapi::dpl::ranges::transform has output range as return type, std::ranges::trasnform - output iterator.
test_range_algo<data_in_out, false/*return type check*/>{}(oneapi::dpl::ranges::transform, std::ranges::transform, f, proj);

test_range_algo{}(oneapi::dpl::ranges::find_if, std::ranges::find_if, pred, proj);
test_range_algo{}(oneapi::dpl::ranges::find_if_not, std::ranges::find_if_not, pred, proj);
test_range_algo{}(oneapi::dpl::ranges::find, std::ranges::find, 4, proj);

auto pred1 = [](auto&& val) -> decltype(auto) { return val > 0; };
auto pred2 = [](auto&& val) -> decltype(auto) { return val == 4; };
auto pred3 = [](auto&& val) -> decltype(auto) { return val < 0; };

test_range_algo<1>{}(oneapi::dpl::ranges::all_of, std::ranges::all_of, pred1, proj);
test_range_algo<1>{}(oneapi::dpl::ranges::any_of, std::ranges::any_of, pred2, std::identity{});
test_range_algo<1>{}(oneapi::dpl::ranges::none_of, std::ranges::none_of, pred3, std::identity{});
test_range_algo{}(oneapi::dpl::ranges::all_of, std::ranges::all_of, pred1, proj);
test_range_algo{}(oneapi::dpl::ranges::any_of, std::ranges::any_of, pred2, std::identity{});
test_range_algo{}(oneapi::dpl::ranges::none_of, std::ranges::none_of, pred3, std::identity{});

auto pred_2 = [](auto&& val1, auto&& val2) -> decltype(auto) { return val1 == val2; };
test_range_algo{}(oneapi::dpl::ranges::adjacent_find, std::ranges::adjacent_find, pred_2, proj);

#endif //_ENABLE_STD_RANGES_TESTING

Expand Down
74 changes: 44 additions & 30 deletions test/parallel_api/ranges/std_ranges_test.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,11 @@
#include <span>
#include <iostream>
#include <vector>
#include <ranges>
#include <typeinfo>
#include <string>

namespace test_std_ranges
{;

template<int call_id = 0>
auto dpcpp_policy()
Expand All @@ -39,7 +42,15 @@ auto dpcpp_policy()

auto host_policies() { return std::true_type{};}

template<typename Container, int Ranges = 1>
enum TestDataMode
{
data_in,
data_in_out,
data_in_in,
data_in_in_out
};

template<typename Container, TestDataMode Ranges = data_in, bool RetTypeCheck = true>
struct test
{
template<typename Policy, typename Algo, typename... Args>
Expand All @@ -54,7 +65,7 @@ struct test

template<typename Policy, typename Algo, typename Checker, typename FunctorOrVal, typename Proj = std::identity,
typename Transform = std::identity>
std::enable_if_t<!std::is_same_v<Policy, std::true_type> && Ranges == 1>
std::enable_if_t<!std::is_same_v<Policy, std::true_type> && Ranges == data_in>
operator()(Policy&& exec, Algo algo, Checker checker, FunctorOrVal f, Proj proj = {}, Transform tr = {})
{
constexpr int max_n = 10;
Expand All @@ -69,20 +80,22 @@ struct test

auto res = algo(exec, tr(A), f, proj);

static_assert(std::is_same_v<decltype(res), decltype(checker(tr(A), f, proj))>, "Wrong return type.");
//check result
if constexpr(RetTypeCheck)
static_assert(std::is_same_v<decltype(res), decltype(checker(tr(A), f, proj))>, "Wrong return type");

auto bres = ret_in_val(expected_res, expected_view.begin()) == ret_in_val(res, tr(A).begin());
if(!bres) std::cout << typeid(Algo).name() <<": ";
EXPECT_TRUE(bres, "wrong return value from algo with ranges");
EXPECT_TRUE(bres, (std::string("wrong return value from algo with ranges: ") + typeid(Algo).name()).c_str());
}

//check result
EXPECT_EQ_N(expected, data, max_n, "wrong effect algo with ranges");
EXPECT_EQ_N(expected, data, max_n, (std::string("wrong effect algo with ranges: ")
+ typeid(Algo).name() + typeid(decltype(tr(std::declval<Container&>()()))).name()).c_str());
}

template<typename Policy, typename Algo, typename Checker, typename Functor, typename Proj = std::identity,
typename Transform = std::identity>
std::enable_if_t<!std::is_same_v<Policy, std::true_type> && Ranges == 2>
std::enable_if_t<!std::is_same_v<Policy, std::true_type> && Ranges == data_in_out>
operator()(Policy&& exec, Algo algo, Checker checker, Functor f, Proj proj = {}, Transform tr = {})
{
constexpr int max_n = 10;
Expand All @@ -99,22 +112,21 @@ struct test
typename Container::type& A = cont_in();
typename Container::type& B = cont_out();

auto res = algo(exec, tr(A), tr(B), f, proj);
auto res = algo(exec, tr(A), B, f, proj);

static_assert(std::is_same_v<decltype(res),
decltype(checker(tr(A), tr(B), f, proj))>, "Wrong return type.");
//check result
if constexpr(RetTypeCheck)
static_assert(std::is_same_v<decltype(res), decltype(checker(tr(A), B, f, proj))>, "Wrong return type");

auto bres_in = ret_in_val(expected_res, src_view.begin()) == ret_in_val(res, tr(A).begin());
if(!bres_in) std::cout << typeid(Algo).name() <<": ";
EXPECT_TRUE(bres_in, "wrong return value from algo with input range");
EXPECT_TRUE(bres_in, (std::string("wrong return value from algo with input range: ") + typeid(Algo).name()).c_str());

auto bres_out = ret_out_val(expected_res, src_view.begin()) == ret_out_val(res, tr(A).begin());
if(!bres_out) std::cout << typeid(Algo).name() <<": ";
EXPECT_TRUE(bres_out, "wrong return value from algo with output range");
EXPECT_TRUE(bres_out, (std::string("wrong return value from algo with output range: ") + typeid(Algo).name()).c_str());
}

//check result
EXPECT_EQ_N(expected, data_out, max_n, "wrong effect algo with ranges");
EXPECT_EQ_N(expected, data_out, max_n, (std::string("wrong effect algo with ranges: ") + typeid(Algo).name()).c_str());
}

private:
Expand Down Expand Up @@ -277,7 +289,7 @@ struct usm_subrange_impl
using usm_subrange = usm_subrange_impl<std::ranges::subrange<int*>>;
using usm_span = usm_subrange_impl<std::span<int>>;

template<int NRanges = 1>
template<TestDataMode TestDataMode = data_in, bool RetTypeCheck = true>
struct test_range_algo
{
void operator()(auto algo, auto checker, auto f, auto proj)
Expand All @@ -291,26 +303,28 @@ struct test_range_algo
return std::ranges::subrange(forward_it(v.begin()), forward_it(v.end()));
};

test<host_vector, NRanges>{}(host_policies(), algo, checker, f, std::identity{}, forward_view);
test<host_vector, TestDataMode, RetTypeCheck>{}(host_policies(), algo, checker, f, std::identity{}, forward_view);

test<host_vector, NRanges>{}(host_policies(), algo, checker, f, std::identity{}, subrange_view);
test<host_vector, NRanges>{}(host_policies(), algo, checker, f, std::identity{}, span_view);
test<host_vector, NRanges>{}(host_policies(), algo, checker, f, proj, std::views::all);
test<host_subrange, NRanges>{}(host_policies(), algo, checker, f, proj, std::views::all);
test<host_span, NRanges>{}(host_policies(), algo, checker, f, proj, std::views::all);
test<host_vector, TestDataMode, RetTypeCheck>{}(host_policies(), algo, checker, f, std::identity{}, subrange_view);
test<host_vector, TestDataMode, RetTypeCheck>{}(host_policies(), algo, checker, f, std::identity{}, span_view);
test<host_vector, TestDataMode, RetTypeCheck>{}(host_policies(), algo, checker, f, proj, std::views::all);
test<host_subrange, TestDataMode, RetTypeCheck>{}(host_policies(), algo, checker, f, proj, std::views::all);
test<host_span, TestDataMode, RetTypeCheck>{}(host_policies(), algo, checker, f, proj, std::views::all);

test<usm_vector, NRanges>{}(dpcpp_policy(), algo, checker, f);
test<usm_vector, NRanges>{}(dpcpp_policy(), algo, checker, f, std::identity{}, oneapi::dpl::views::all);
test<usm_vector, NRanges>{}(dpcpp_policy(), algo, checker, f, std::identity{}, subrange_view);
test<usm_vector, NRanges>{}(dpcpp_policy(), algo, checker, f, std::identity{}, span_view);
test<usm_subrange, NRanges>{}(dpcpp_policy(), algo, checker, f);
test<usm_span, NRanges>{}(dpcpp_policy(), algo, checker, f);
test<usm_vector, TestDataMode, RetTypeCheck>{}(dpcpp_policy(), algo, checker, f, proj);
test<usm_vector, TestDataMode, RetTypeCheck>{}(dpcpp_policy(), algo, checker, f, proj, oneapi::dpl::views::all);
test<usm_vector, TestDataMode, RetTypeCheck>{}(dpcpp_policy(), algo, checker, f, proj, subrange_view);
test<usm_vector, TestDataMode, RetTypeCheck>{}(dpcpp_policy(), algo, checker, f, std::identity{}, span_view);
test<usm_subrange, TestDataMode, RetTypeCheck>{}(dpcpp_policy(), algo, checker, f, proj);
test<usm_span, TestDataMode, RetTypeCheck>{}(dpcpp_policy(), algo, checker, f, proj);

#if 0 //sycl buffer
test<sycl_buffer, NRanges>{}(dpcpp_policy(), algo, checker, f, std::identity{}, oneapi::dpl::views::all);
test<sycl_buffer, TestDataMode, RetTypeCheck>{}(dpcpp_policy(), algo, checker, f, std::identity{}, oneapi::dpl::views::all);
#endif
}
};

}; //namespace test_std_ranges

#endif //_ENABLE_STD_RANGES_TESTING

0 comments on commit 4aa794e

Please sign in to comment.