Skip to content

Commit

Permalink
Revert "remove partition pattern family from reduce_then_scan"
Browse files Browse the repository at this point in the history
This reverts commit 3d69c3c.
  • Loading branch information
danhoeflinger committed Aug 6, 2024
1 parent e33efaa commit 06da0e2
Show file tree
Hide file tree
Showing 4 changed files with 81 additions and 10 deletions.
25 changes: 15 additions & 10 deletions include/oneapi/dpl/pstl/hetero/algorithm_impl_hetero.h
Original file line number Diff line number Diff line change
Expand Up @@ -951,19 +951,24 @@ __pattern_partition_copy(__hetero_tag<_BackendTag> __tag, _ExecutionPolicy&& __e
return ::std::make_pair(__result1, __result2);

using _It1DifferenceType = typename ::std::iterator_traits<_Iterator1>::difference_type;
using _ReduceOp = ::std::plus<_It1DifferenceType>;

unseq_backend::__create_mask<_UnaryPredicate, _It1DifferenceType> __create_mask_op{__pred};
unseq_backend::__partition_by_mask<_ReduceOp, /*inclusive*/ ::std::true_type> __copy_by_mask_op{_ReduceOp{}};
_It1DifferenceType __n = __last - __first;

auto __result = __pattern_scan_copy(
__tag, ::std::forward<_ExecutionPolicy>(__exec), __first, __last,
__par_backend_hetero::zip(
__par_backend_hetero::make_iter_mode<__par_backend_hetero::access_mode::write>(__result1),
__par_backend_hetero::make_iter_mode<__par_backend_hetero::access_mode::write>(__result2)),
__create_mask_op, __copy_by_mask_op);
auto __keep1 = oneapi::dpl::__ranges::__get_sycl_range<__par_backend_hetero::access_mode::read, _Iterator1>();
auto __buf1 = __keep1(__first, __last);

auto __zipped_res = __par_backend_hetero::zip(
__par_backend_hetero::make_iter_mode<__par_backend_hetero::access_mode::write>(__result1),
__par_backend_hetero::make_iter_mode<__par_backend_hetero::access_mode::write>(__result2));

auto __keep2 =
oneapi::dpl::__ranges::__get_sycl_range<__par_backend_hetero::access_mode::write, decltype(__zipped_res)>();
auto __buf2 = __keep2(__zipped_res, __zipped_res + __n);

auto __result = oneapi::dpl::__par_backend_hetero::__parallel_partition_copy(
_BackendTag{}, std::forward<_ExecutionPolicy>(__exec), __buf1.all_view(), __buf2.all_view(), __pred);

return ::std::make_pair(__result1 + __result.second, __result2 + (__last - __first - __result.second));
return std::make_pair(__result1 + __result.get(), __result2 + (__last - __first - __result.get()));
}

//------------------------------------------------------------------------
Expand Down
46 changes: 46 additions & 0 deletions include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl.h
Original file line number Diff line number Diff line change
Expand Up @@ -868,6 +868,25 @@ struct __write_to_idx_if
Assign __assign;
};

template <typename Assign = oneapi::dpl::__internal::__pstl_assign>
struct __write_to_idx_if_else
{
template <typename _OutRng, typename _SizeType, typename ValueType>
void
operator()(_OutRng&& __out_rng, _SizeType __idx, const ValueType& __v) const
{
using _ConvertedTupleType =
typename oneapi::dpl::__internal::__get_tuple_type<std::decay_t<decltype(std::get<2>(__v))>,
std::decay_t<decltype(__out_rng[__idx])>>::__type;
if (std::get<1>(__v))
__assign(static_cast<_ConvertedTupleType>(std::get<2>(__v)), std::get<0>(__out_rng[std::get<0>(__v) - 1]));
else
__assign(static_cast<_ConvertedTupleType>(std::get<2>(__v)),
std::get<1>(__out_rng[__idx - std::get<0>(__v)]));
}
Assign __assign;
};

template <typename _ExecutionPolicy, typename _Range1, typename _Range2, typename _UnaryOperation, typename _InitType,
typename _BinaryOperation, typename _Inclusive>
auto
Expand Down Expand Up @@ -1038,6 +1057,33 @@ __parallel_scan_copy(oneapi::dpl::__internal::__device_backend_tag __backend_tag
__copy_by_mask_op);
}

template <typename _ExecutionPolicy, typename _Range1, typename _Range2, typename _UnaryPredicate>
auto
__parallel_partition_copy(oneapi::dpl::__internal::__device_backend_tag __backend_tag, _ExecutionPolicy&& __exec,
_Range1&& __rng, _Range2&& __result, _UnaryPredicate __pred)
{
auto __n = __rng.size();
if (oneapi::dpl::__par_backend_hetero::__prefer_reduce_then_scan(__exec))
{
using _GenMask = oneapi::dpl::__par_backend_hetero::__gen_mask<_UnaryPredicate>;
using _WriteOp =
oneapi::dpl::__par_backend_hetero::__write_to_idx_if_else<oneapi::dpl::__internal::__pstl_assign>;

return __parallel_reduce_then_scan_copy(__backend_tag, std::forward<_ExecutionPolicy>(__exec),
std::forward<_Range1>(__rng), std::forward<_Range2>(__result), __n,
_GenMask{__pred}, _WriteOp{});
}
else
{
using _ReduceOp = std::plus<decltype(__n)>;
using _CreateOp = unseq_backend::__create_mask<_UnaryPredicate, decltype(__n)>;
using _CopyOp = unseq_backend::__partition_by_mask<_ReduceOp, /*inclusive*/ std::true_type>;

return __parallel_scan_copy(__backend_tag, std::forward<_ExecutionPolicy>(__exec), std::forward<_Range1>(__rng),
std::forward<_Range2>(__result), __n, _CreateOp{__pred}, _CopyOp{_ReduceOp{}});
}
}

template <typename _ExecutionPolicy, typename _InRng, typename _OutRng, typename _Size, typename _Pred,
typename _Assign = oneapi::dpl::__internal::__pstl_assign>
auto
Expand Down
10 changes: 10 additions & 0 deletions include/oneapi/dpl/pstl/hetero/dpcpp/sycl_traits.h
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,9 @@ struct __gen_expand_count_mask;
template <int32_t __offset, typename Assign>
struct __write_to_idx_if;

template <typename Assign>
struct __write_to_idx_if_else;

template <typename _ExecutionPolicy, typename _Pred>
struct __early_exit_find_or;

Expand Down Expand Up @@ -286,6 +289,13 @@ struct sycl::is_device_copyable<_ONEDPL_SPECIALIZE_FOR(oneapi::dpl::__par_backen
{
};

template <typename Assign>
struct sycl::is_device_copyable<_ONEDPL_SPECIALIZE_FOR(oneapi::dpl::__par_backend_hetero::__write_to_idx_if_else,
Assign)>
: oneapi::dpl::__internal::__are_all_device_copyable<Assign>
{
};

template <typename _ExecutionPolicy, typename _Pred>
struct sycl::is_device_copyable<_ONEDPL_SPECIALIZE_FOR(oneapi::dpl::__par_backend_hetero::__early_exit_find_or,
_ExecutionPolicy, _Pred)>
Expand Down
10 changes: 10 additions & 0 deletions test/general/implementation_details/device_copyable.pass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,11 @@ test_device_copyable()
sycl::is_device_copyable_v<oneapi::dpl::__par_backend_hetero::__write_to_idx_if<0, assign_device_copyable>>,
"__write_to_idx_if is not device copyable with device copyable types");

//__write_to_idx_if_else
static_assert(
sycl::is_device_copyable_v<oneapi::dpl::__par_backend_hetero::__write_to_idx_if_else<assign_device_copyable>>,
"__write_to_idx_if_else is not device copyable with device copyable types");

// __early_exit_find_or
static_assert(
sycl::is_device_copyable_v<
Expand Down Expand Up @@ -391,6 +396,11 @@ test_non_device_copyable()
oneapi::dpl::__par_backend_hetero::__write_to_idx_if<0, assign_non_device_copyable>>,
"__write_to_idx_if is device copyable with non device copyable types");

//__write_to_idx_if_else
static_assert(!sycl::is_device_copyable_v<
oneapi::dpl::__par_backend_hetero::__write_to_idx_if_else<assign_non_device_copyable>>,
"__write_to_idx_if_else is device copyable with non device copyable types");

// __early_exit_find_or
static_assert(
!sycl::is_device_copyable_v<oneapi::dpl::__par_backend_hetero::__early_exit_find_or<policy_non_device_copyable,
Expand Down

0 comments on commit 06da0e2

Please sign in to comment.