Skip to content

Commit

Permalink
include/oneapi/dpl/pstl/hetero/algorithm_impl_hetero.h - implementati…
Browse files Browse the repository at this point in the history
…on of __pattern_any_of on __parallel_transform_reduce - fix review comment: counting iterator not required.

Signed-off-by: Sergey Kopienko <[email protected]>
  • Loading branch information
SergeyKopienko committed Jun 13, 2024
1 parent 7395451 commit f3bd4e9
Showing 1 changed file with 45 additions and 49 deletions.
94 changes: 45 additions & 49 deletions include/oneapi/dpl/pstl/hetero/algorithm_impl_hetero.h
Original file line number Diff line number Diff line change
Expand Up @@ -646,34 +646,12 @@ __pattern_count(__hetero_tag<_BackendTag>, _ExecutionPolicy&& __exec, _Iterator
// any_of
//------------------------------------------------------------------------

template <typename _Tuple, typename _UnaryTransformOp>
struct __find_if_unary_transform_op
struct __any_of_binary_reduce_op
{
_UnaryTransformOp __transform_op;

template <typename Arg>
_Tuple
operator()(const Arg& arg) const
bool
operator()(bool op1, bool op2) const
{
return {__transform_op(std::get<0>(arg)), std::get<1>(arg)};
}
};

template <typename _Tuple, typename _IsFirst>
struct __find_if_binary_reduce_op
{
_Tuple
operator()(const _Tuple& op1, const _Tuple& op2) const
{
if (std::get<0>(op1) && std::get<0>(op2))
{
if constexpr (_IsFirst{})
return {true, std::min(std::get<1>(op1), std::get<1>(op2))};
else
return {true, std::max(std::get<1>(op1), std::get<1>(op2))};
}

return std::get<0>(op1) ? op1 : op2;
return op1 || op2;
}
};

Expand All @@ -688,38 +666,25 @@ __pattern_any_of(__hetero_tag<_BackendTag>, _ExecutionPolicy&& __exec, _Iterator
if (__n == 0)
return false;

using _result_type = oneapi::dpl::__internal::tuple<bool, _difference_type>;
const auto __init = _result_type{false, __n};

// __counting_iterator_t - iterate position (index) in source data
using __counting_iterator_t = oneapi::dpl::counting_iterator<_difference_type>;
using _data_type = typename ::std::iterator_traits<_Iterator>::value_type;

using _zipped_data_type = typename std::iterator_traits<decltype(oneapi::dpl::make_zip_iterator(
__first, __counting_iterator_t{0}))>::value_type;
using _result_type = oneapi::dpl::__internal::tuple<bool, _difference_type>;
const bool __init = false;

__find_if_binary_reduce_op<_zipped_data_type, /*_IsFirst*/ std::true_type> __reduce_op;
__find_if_unary_transform_op<_zipped_data_type, _Pred> __transform_op{__pred};
__any_of_binary_reduce_op __reduce_op;

using _Functor = unseq_backend::walk_n<_ExecutionPolicy, decltype(__transform_op)>;
using _Functor = unseq_backend::walk_n<_ExecutionPolicy, decltype(__pred)>;
using _RepackedTp = __par_backend_hetero::__repacked_tuple_t<_result_type>;

auto __keep_src_data =
oneapi::dpl::__ranges::__get_sycl_range<__par_backend_hetero::access_mode::read, _Iterator>();
auto __buf_src_data = __keep_src_data(__first, __last);

const __counting_iterator_t __counting_it_first{0}, __counting_it_last{__n};
auto __keep_counting_it =
oneapi::dpl::__ranges::__get_sycl_range<__par_backend_hetero::access_mode::read, __counting_iterator_t>();
auto __buf_counting_it = __keep_counting_it(__counting_it_first, __counting_it_last);

auto res =
oneapi::dpl::__par_backend_hetero::__parallel_transform_reduce<_RepackedTp, std::true_type /*is_commutative*/>(
_BackendTag{}, std::forward<_ExecutionPolicy>(__exec), __reduce_op, _Functor{__transform_op},
unseq_backend::__init_value<_RepackedTp>{__init}, // initial value
oneapi::dpl::__ranges::make_zip_view(__buf_src_data.all_view(), __buf_counting_it.all_view()))
.get();

return std::get<0>(res);
return oneapi::dpl::__par_backend_hetero::__parallel_transform_reduce<bool, std::true_type /*is_commutative*/>(
_BackendTag{}, std::forward<_ExecutionPolicy>(__exec), __reduce_op, _Functor{__pred},
unseq_backend::__init_value<bool>{__init}, // initial value
__buf_src_data.all_view())
.get();
}

//------------------------------------------------------------------------
Expand Down Expand Up @@ -766,6 +731,37 @@ __pattern_equal(__hetero_tag<_BackendTag> __tag, _ExecutionPolicy&& __exec, _Ite
// find_if
//------------------------------------------------------------------------

template <typename _Tuple, typename _UnaryTransformOp>
struct __find_if_unary_transform_op
{
_UnaryTransformOp __transform_op;

template <typename Arg>
_Tuple
operator()(const Arg& arg) const
{
return {__transform_op(std::get<0>(arg)), std::get<1>(arg)};
}
};

template <typename _Tuple, typename _IsFirst>
struct __find_if_binary_reduce_op
{
_Tuple
operator()(const _Tuple& op1, const _Tuple& op2) const
{
if (std::get<0>(op1) && std::get<0>(op2))
{
if constexpr (_IsFirst{})
return {true, std::min(std::get<1>(op1), std::get<1>(op2))};
else
return {true, std::max(std::get<1>(op1), std::get<1>(op2))};
}

return std::get<0>(op1) ? op1 : op2;
}
};

template <typename _BackendTag, typename _ExecutionPolicy, typename _Iterator, typename _Pred>
_Iterator
__pattern_find_if(__hetero_tag<_BackendTag>, _ExecutionPolicy&& __exec, _Iterator __first, _Iterator __last,
Expand Down

0 comments on commit f3bd4e9

Please sign in to comment.