diff --git a/include/oneapi/dpl/pstl/hetero/algorithm_impl_hetero.h b/include/oneapi/dpl/pstl/hetero/algorithm_impl_hetero.h index 0bea23faf1f..8e99f3701e0 100644 --- a/include/oneapi/dpl/pstl/hetero/algorithm_impl_hetero.h +++ b/include/oneapi/dpl/pstl/hetero/algorithm_impl_hetero.h @@ -759,16 +759,44 @@ _Iterator __pattern_find_if(__hetero_tag<_BackendTag>, _ExecutionPolicy&& __exec, _Iterator __first, _Iterator __last, _Pred __pred) { - if (__first == __last) + using _difference_type = typename ::std::iterator_traits<_Iterator>::difference_type; + + const _difference_type __n = __last - __first; + if (__n == 0) return __last; - using _Predicate = oneapi::dpl::unseq_backend::single_match_pred<_ExecutionPolicy, _Pred>; + using _result_type = oneapi::dpl::__internal::tuple; + const auto __init = _result_type{false, __n}; - return __par_backend_hetero::__parallel_find( - _BackendTag{}, ::std::forward<_ExecutionPolicy>(__exec), - __par_backend_hetero::make_iter_mode<__par_backend_hetero::access_mode::read>(__first), - __par_backend_hetero::make_iter_mode<__par_backend_hetero::access_mode::read>(__last), _Predicate{__pred}, - ::std::true_type{}); + // __counting_iterator_t - iterate position (index) in source data + using __counting_iterator_t = oneapi::dpl::counting_iterator<_difference_type>; + + using _zipped_data_type = typename std::iterator_traits::value_type; + + __find_if_binary_reduce_op<_zipped_data_type, /*_IsFirst*/ std::true_type> __reduce_op; + __find_if_unary_transform_op<_zipped_data_type, _Pred> __transform_op{__pred}; + + using _Functor = unseq_backend::walk_n<_ExecutionPolicy, decltype(__transform_op)>; + using _RepackedTp = __par_backend_hetero::__repacked_tuple_t<_result_type>; + + auto __keep_src_data = + oneapi::dpl::__ranges::__get_sycl_range<__par_backend_hetero::access_mode::read, _Iterator>(); + auto __buf_src_data = __keep_src_data(__first, __last); + + const __counting_iterator_t __counting_it_first{0}, __counting_it_last{__n}; + auto __keep_counting_it = + oneapi::dpl::__ranges::__get_sycl_range<__par_backend_hetero::access_mode::read, __counting_iterator_t>(); + auto __buf_counting_it = __keep_counting_it(__counting_it_first, __counting_it_last); + + auto res = + oneapi::dpl::__par_backend_hetero::__parallel_transform_reduce<_RepackedTp, std::true_type /*is_commutative*/>( + _BackendTag{}, std::forward<_ExecutionPolicy>(__exec), __reduce_op, _Functor{__transform_op}, + unseq_backend::__init_value<_RepackedTp>{__init}, // initial value + oneapi::dpl::__ranges::make_zip_view(__buf_src_data.all_view(), __buf_counting_it.all_view())) + .get(); + + return std::get<0>(res) ? __first + std::get<1>(res) : __last; } //------------------------------------------------------------------------