Improve explanations of reduce-by-segment approach
Signed-off-by: Matthew Michel <[email protected]>
mmichel11 committed Nov 22, 2024
1 parent 3c8154e commit db63d45
Showing 1 changed file with 50 additions and 0 deletions.
50 changes: 50 additions & 0 deletions include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl.h
@@ -864,6 +864,51 @@ struct __gen_red_by_seg_scan_input
template <typename _BinaryOp>
struct __red_by_seg_op
{
// Consider the following segment / value pairs that would be processed in reduce-then-scan by a sub-group of size 8:
// ----------------------------------------------------------
// Keys:   0 0 1 1 2 2 2 2
// Values: 1 1 1 1 1 1 1 1
// ----------------------------------------------------------
// The reduce and scan input generation phase flags new segments (excluding index 0) for use in the sub-group scan
// operation. The above key, value pairs correspond to the following flag, value pairs:
// ----------------------------------------------------------
// Flags:  0 0 1 0 1 0 0 0
// Values: 1 1 1 1 1 1 1 1
// ----------------------------------------------------------
// The sub-group scan operation looks back by powers-of-2 applying encountered prefixes. The __red_by_seg_op
// operation performs a standard inclusive scan over the flags to compute output indices while performing a masked
// scan over values to avoid applying a previous segment's partial reduction. Previous value elements are reduced
// so long as the current index's flag is 0, indicating that input within its segment is still being processed.
// ----------------------------------------------------------
// Start:
// ----------------------------------------------------------
// Flags:  0 0 1 0 1 0 0 0
// Values: 1 1 1 1 1 1 1 1
// ----------------------------------------------------------
// After step 1 (apply the (i-1)th value if the ith flag is 0):
// ----------------------------------------------------------
// Flags:  0 0 1 1 1 1 0 0
// Values: 1 2 1 2 1 2 2 2
// ----------------------------------------------------------
// After step 2 (apply the (i-2)th value if the ith flag is 0):
// ----------------------------------------------------------
// Flags:  0 0 1 1 2 2 1 1
// Values: 1 2 1 2 1 2 3 4
// ----------------------------------------------------------
// After step 3 (apply the (i-4)th value if the ith flag is 0):
// ----------------------------------------------------------
// Flags:  0 0 1 1 2 2 2 2
// Values: 1 2 1 2 1 2 3 4
//           ^   ^       ^
// ----------------------------------------------------------
// Note that the scan of segment flags results in the desired output index of the reduce_by_segment operation in
// each segment and the item corresponding to the final key in a segment contains its output reduction value. This
// operation is first applied within a sub-group and then across sub-groups, work-groups, and blocks to
// reduce-by-segment across the full input. The result of these operations combined with cached key data in
// __gen_red_by_seg_scan_input enables the write phase to output keys and reduction values.
// =>
// Segments: 0 1 2
// Values:   2 2 4
template <typename _Tup1, typename _Tup2>
auto
operator()(const _Tup1& __lhs_tup, const _Tup2& __rhs_tup) const
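The walkthrough in the comment above can be reproduced on the host with a small sketch. This is not the oneDPL implementation: `seg_op`, `lookback_scan`, and `worked_example` are hypothetical names, the elements are plain `(flag, value)` pairs rather than the tuples used in the header, and the sub-group is simulated with an array of 8 lanes. It assumes the operator always sums flags and only folds in the left-hand value while the right-hand flag is still 0, as the comment describes.

```cpp
#include <array>
#include <cstddef>
#include <functional>
#include <utility>

// Hypothetical stand-in for the operator described above.
// Each element is a (flag, value) pair.
template <typename BinaryOp>
struct seg_op
{
    BinaryOp op;

    std::pair<int, int>
    operator()(const std::pair<int, int>& lhs, const std::pair<int, int>& rhs) const
    {
        // Masked scan over values: the left value is applied only while the
        // right-hand flag is 0, i.e. no segment boundary has been absorbed yet.
        const int value = (rhs.first == 0) ? op(lhs.second, rhs.second) : rhs.second;
        // Plain inclusive scan over flags, which yields the output index.
        return {lhs.first + rhs.first, value};
    }
};

// Simulates a sub-group inclusive scan that looks back by powers of two.
template <std::size_t N, typename Op>
std::array<std::pair<int, int>, N>
lookback_scan(std::array<std::pair<int, int>, N> data, Op op)
{
    for (std::size_t shift = 1; shift < N; shift *= 2)
    {
        // Snapshot so that every lane reads the values from before this step.
        const auto prev = data;
        for (std::size_t i = shift; i < N; ++i)
            data[i] = op(prev[i - shift], prev[i]);
    }
    return data;
}

// Flags 0 0 1 0 1 0 0 0 with all values equal to 1, as in the comment.
inline std::array<std::pair<int, int>, 8>
worked_example()
{
    const std::array<std::pair<int, int>, 8> input{
        {{0, 1}, {0, 1}, {1, 1}, {0, 1}, {1, 1}, {0, 1}, {0, 1}, {0, 1}}};
    return lookback_scan(input, seg_op<std::plus<int>>{});
}
```

Calling `worked_example()` performs three look-back steps (shifts of 1, 2, and 4) and ends with flags `0 0 1 1 2 2 2 2` and values `1 2 1 2 1 2 3 4`, matching the state shown after step 3.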
@@ -1326,10 +1371,15 @@ __parallel_reduce_by_segment_reduce_then_scan(oneapi::dpl::__internal::__device_
_Range3&& __out_keys, _Range4&& __out_values,
_BinaryPredicate __binary_pred, _BinaryOperator __binary_op)
{
// Flags new segments and passes input value through a 2-tuple
using _GenReduceInput = __gen_red_by_seg_reduce_input<_BinaryPredicate>;
// Operation that computes output indices and output reduction values per segment
using _ReduceOp = __red_by_seg_op<_BinaryOperator>;
// Returns a 4-component tuple which contains flags, keys, value, and a flag to write output
using _GenScanInput = __gen_red_by_seg_scan_input<_BinaryPredicate>;
// Returns the first component from scan input which is scanned over
using _ScanInputTransform = __get_zeroth_element;
// Writes current segment's output reduction and the next segment's output key
using _WriteOp = __write_red_by_seg<_BinaryPredicate>;
using _ValueType = oneapi::dpl::__internal::__value_t<_Range2>;
std::size_t __n = __keys.size();
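As a rough mental model of how the stages named by these aliases fit together, the following sequential host-side sketch flags new segments, carries a running output index and partial reduction, and writes results at segment ends. It is an assumption-laden illustration, not the library's code: `reduce_by_segment_reference` is a hypothetical name, the scan is done sequentially instead of with reduce-then-scan, and the way the first key is written is a guess made to keep the sketch complete. It also assumes `keys` and `values` have the same length.

```cpp
#include <cstddef>
#include <functional>
#include <utility>
#include <vector>

// Hypothetical sequential reference; names do not correspond to oneDPL internals.
template <typename Key, typename Value, typename BinaryPred = std::equal_to<Key>,
          typename BinaryOp = std::plus<Value>>
std::pair<std::vector<Key>, std::vector<Value>>
reduce_by_segment_reference(const std::vector<Key>& keys, const std::vector<Value>& values,
                            BinaryPred pred = {}, BinaryOp op = {})
{
    std::vector<Key> out_keys;
    std::vector<Value> out_values;
    const std::size_t n = keys.size();
    if (n == 0)
        return {out_keys, out_values};

    // Assumption: the first key is written directly because no earlier
    // segment end exists to emit it.
    out_keys.push_back(keys[0]);

    std::size_t out_idx = 0; // running sum of flags == output index
    Value acc = values[0];   // partial reduction of the current segment
    for (std::size_t i = 0; i < n; ++i)
    {
        if (i > 0)
        {
            // Input generation: flag a new segment (index 0 is never flagged).
            const bool new_segment = !pred(keys[i - 1], keys[i]);
            // Scan step: sum the flag, reset or extend the partial reduction.
            out_idx += new_segment ? 1 : 0;
            acc = new_segment ? values[i] : op(acc, values[i]);
        }
        // Write step: at a segment end, emit this segment's reduction and
        // the key of the segment that follows, if there is one.
        const bool segment_end = (i + 1 == n) || !pred(keys[i], keys[i + 1]);
        if (segment_end)
        {
            out_values.resize(out_idx + 1);
            out_values[out_idx] = acc;
            if (i + 1 < n)
            {
                out_keys.resize(out_idx + 2);
                out_keys[out_idx + 1] = keys[i + 1];
            }
        }
    }
    return {out_keys, out_values};
}
```

For keys `0 0 1 1 2 2 2 2` with all values equal to 1, this returns keys `0 1 2` and values `2 2 4`, the same result as the example in the `__red_by_seg_op` comment.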