Skip to content

Commit

Permalink
include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl_merge.h - …
Browse files Browse the repository at this point in the history
…fix review comment: remove extra condition check from __find_start_point_in

Signed-off-by: Sergey Kopienko <[email protected]>
  • Loading branch information
SergeyKopienko committed Dec 4, 2024
1 parent c350675 commit 6ad8170
Showing 1 changed file with 44 additions and 54 deletions.
98 changes: 44 additions & 54 deletions include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl_merge.h
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,9 @@ __find_start_point_in(const _Rng1& __rng1, const _Index __rng1_from, _Index __rn
if constexpr (!std::is_pointer_v<_Rng2>)
assert(__rng2_to <= __rng2.size());

assert(__i_elem >= 0);
// We shouldn't call this function with __i_elem == 0 because we a priory know that
// split point for this case is {0, 0}
assert(__i_elem > 0);

// ----------------------- EXAMPLE ------------------------
// Let's consider the following input data:
Expand Down Expand Up @@ -150,76 +152,64 @@ __find_start_point_in(const _Rng1& __rng1, const _Index __rng1_from, _Index __rn
// - where for every comparing pairs idx(rng1) + idx(rng2) == i_diag - 1

////////////////////////////////////////////////////////////////////////////////////
// Process the corner case: for the first diagonal with the index 0 split point
// is equal to (0, 0) regardless of the size and content of the data.
if (__i_elem > 0)
{
////////////////////////////////////////////////////////////////////////////////////
// Taking into account the specified constraints of the range of processed data
const auto __index_sum = __i_elem - 1;
// Taking into account the specified constraints of the range of processed data
const auto __index_sum = __i_elem - 1;

using _IndexSigned = std::make_signed_t<_Index>;
using _IndexSigned = std::make_signed_t<_Index>;

_IndexSigned idx1_from = __rng1_from;
_IndexSigned idx1_to = __rng1_to;
assert(idx1_from <= idx1_to);
_IndexSigned idx1_from = __rng1_from;
_IndexSigned idx1_to = __rng1_to;
assert(idx1_from <= idx1_to);

_IndexSigned idx2_from = __index_sum - (__rng1_to - 1);
_IndexSigned idx2_to = __index_sum - __rng1_from + 1;
assert(idx2_from <= idx2_to);
_IndexSigned idx2_from = __index_sum - (__rng1_to - 1);
_IndexSigned idx2_to = __index_sum - __rng1_from + 1;
assert(idx2_from <= idx2_to);

const _IndexSigned idx2_from_diff =
idx2_from < (_IndexSigned)__rng2_from ? (_IndexSigned)__rng2_from - idx2_from : 0;
const _IndexSigned idx2_to_diff = idx2_to > (_IndexSigned)__rng2_to ? idx2_to - (_IndexSigned)__rng2_to : 0;
const _IndexSigned idx2_from_diff =
idx2_from < (_IndexSigned)__rng2_from ? (_IndexSigned)__rng2_from - idx2_from : 0;
const _IndexSigned idx2_to_diff = idx2_to > (_IndexSigned)__rng2_to ? idx2_to - (_IndexSigned)__rng2_to : 0;

idx1_to -= idx2_from_diff;
idx1_from += idx2_to_diff;
idx1_to -= idx2_from_diff;
idx1_from += idx2_to_diff;

idx2_from = __index_sum - (idx1_to - 1);
idx2_to = __index_sum - idx1_from + 1;
idx2_from = __index_sum - (idx1_to - 1);
idx2_to = __index_sum - idx1_from + 1;

assert(idx1_from <= idx1_to);
assert(__rng1_from <= idx1_from && idx1_to <= __rng1_to);
assert(idx1_from <= idx1_to);
assert(__rng1_from <= idx1_from && idx1_to <= __rng1_to);

assert(idx2_from <= idx2_to);
assert(__rng2_from <= idx2_from && idx2_to <= __rng2_to);
assert(idx2_from <= idx2_to);
assert(__rng2_from <= idx2_from && idx2_to <= __rng2_to);

////////////////////////////////////////////////////////////////////////////////////
// Run search of split point on diagonal
////////////////////////////////////////////////////////////////////////////////////
// Run search of split point on diagonal

using __it_t = oneapi::dpl::counting_iterator<_Index>;
using __it_t = oneapi::dpl::counting_iterator<_Index>;

__it_t __diag_it_begin(idx1_from);
__it_t __diag_it_end(idx1_to);
__it_t __diag_it_begin(idx1_from);
__it_t __diag_it_end(idx1_to);

constexpr int kValue = 1;
const __it_t __res =
std::lower_bound(__diag_it_begin, __diag_it_end, kValue, [&](_Index __idx, const auto& __value) {
const auto __rng1_idx = __idx;
const auto __rng2_idx = __index_sum - __idx;
constexpr int kValue = 1;
const __it_t __res =
std::lower_bound(__diag_it_begin, __diag_it_end, kValue, [&](_Index __idx, const auto& __value) {
const auto __rng1_idx = __idx;
const auto __rng2_idx = __index_sum - __idx;

assert(__rng1_from <= __rng1_idx && __rng1_idx < __rng1_to);
assert(__rng2_from <= __rng2_idx && __rng2_idx < __rng2_to);
assert(__rng1_idx + __rng2_idx == __index_sum);
assert(__rng1_from <= __rng1_idx && __rng1_idx < __rng1_to);
assert(__rng2_from <= __rng2_idx && __rng2_idx < __rng2_to);
assert(__rng1_idx + __rng2_idx == __index_sum);

const auto __zero_or_one = __comp(__rng2[__rng2_idx], __rng1[__rng1_idx]);
return __zero_or_one < kValue;
});
const auto __zero_or_one = __comp(__rng2[__rng2_idx], __rng1[__rng1_idx]);
return __zero_or_one < kValue;
});

const _split_point_t<_Index> __result{ *__res, __index_sum - *__res + 1 };
assert(__result.first + __result.second == __i_elem);
const _split_point_t<_Index> __result{ *__res, __index_sum - *__res + 1 };
assert(__result.first + __result.second == __i_elem);

assert(__rng1_from <= __result.first && __result.first <= __rng1_to);
assert(__rng2_from <= __result.second && __result.second <= __rng2_to);
assert(__rng1_from <= __result.first && __result.first <= __rng1_to);
assert(__rng2_from <= __result.second && __result.second <= __rng2_to);

return __result;
}
else
{
assert(__rng1_from == 0);
assert(__rng2_from == 0);
return { __rng1_from, __rng2_from };
}
return __result;
}

// Do serial merge of the data from rng1 (starting from start1) and rng2 (starting from start2) and writing
Expand Down

0 comments on commit 6ad8170

Please sign in to comment.