diff --git a/include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl.h b/include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl.h index 2c8b5c472ae..1d10b8ac29b 100644 --- a/include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl.h +++ b/include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl.h @@ -864,6 +864,51 @@ struct __gen_red_by_seg_scan_input template struct __red_by_seg_op { + // Consider the following segment / value pairs that would be processed in reduce-then-scan by a sub-group of size 8: + // ---------------------------------------------------------- + // Keys: 0 0 1 1 2 2 2 2 + // Values: 1 1 1 1 1 1 1 1 + // ---------------------------------------------------------- + // The reduce and scan input generation phase flags new segments (excluding index 0) for use in the sub-group scan + // operation. The above key, value pairs correspond to the following flag, value pairs: + // ---------------------------------------------------------- + // Flags: 0 0 1 0 1 0 0 0 + // Values: 1 1 1 1 1 1 1 1 + // ---------------------------------------------------------- + // The sub-group scan operation looks back by powers-of-2 applying encountered prefixes. The __red_by_seg_op + // operation performs a standard inclusive scan over the flags to compute output indices while performing a masked + // scan over values to avoid applying a previous segment's partial reduction. Previous value elements are reduced + // so long as the current index's flag is 0, indicating that input within its segment is still being processed + // ---------------------------------------------------------- + // Start: + // ---------------------------------------------------------- + // Flags: 0 0 1 0 1 0 0 0 + // Values: 1 1 1 1 1 1 1 1 + // ---------------------------------------------------------- + // After step 1 (apply the i-1th value if the ith flag is 0): + // ---------------------------------------------------------- + // Flags: 0 0 1 1 1 1 0 0 + // Values: 1 2 1 2 1 2 2 2 + // ---------------------------------------------------------- + // After step 2 (apply the i-2th value if the ith flag is 0): + // ---------------------------------------------------------- + // Flags: 0 0 1 1 2 2 1 1 + // Values: 1 2 1 2 1 2 3 4 + // ---------------------------------------------------------- + // After step 3 (apply the i-4th value if the ith flag is 0): + // ---------------------------------------------------------- + // Flags: 0 0 1 1 2 2 2 2 + // Values: 1 2 1 2 1 2 3 4 + // ^ ^ ^ + // ---------------------------------------------------------- + // Note that the scan of segment flags results in the desired output index of the reduce_by_segment operation in + // each segment and the item corresponding to the final key in a segment contains its output reduction value. This + // operation is first applied within a sub-group and then across sub-groups, work-groups, and blocks to + // reduce-by-segment across the full input. The result of these operations combined with cached key data in + // __gen_red_by_seg_scan_input enables the write phase to output keys and reduction values. + // => + // Segments : 0 1 2 + // Values : 2 2 4 template auto operator()(const _Tup1& __lhs_tup, const _Tup2& __rhs_tup) const @@ -1326,10 +1371,15 @@ __parallel_reduce_by_segment_reduce_then_scan(oneapi::dpl::__internal::__device_ _Range3&& __out_keys, _Range4&& __out_values, _BinaryPredicate __binary_pred, _BinaryOperator __binary_op) { + // Flags new segments and passes input value through a 2-tuple using _GenReduceInput = __gen_red_by_seg_reduce_input<_BinaryPredicate>; + // Operation that computes output indices and output reduction values per segment using _ReduceOp = __red_by_seg_op<_BinaryOperator>; + // Returns 4-component tuple which contains flags, keys, value, and a flag to write output using _GenScanInput = __gen_red_by_seg_scan_input<_BinaryPredicate>; + // Returns the first component from scan input which is scanned over using _ScanInputTransform = __get_zeroth_element; + // Writes current segment's output reduction and the next segment's output key using _WriteOp = __write_red_by_seg<_BinaryPredicate>; using _ValueType = oneapi::dpl::__internal::__value_t<_Range2>; std::size_t __n = __keys.size();