Skip to content

Commit

Permalink
Reduce code duplication with lambda
Browse files Browse the repository at this point in the history
Signed-off-by: Matthew Michel <[email protected]>
  • Loading branch information
mmichel11 committed Mar 22, 2024
1 parent 918ec1c commit be9345f
Showing 1 changed file with 31 additions and 34 deletions.
65 changes: 31 additions & 34 deletions include/oneapi/dpl/internal/binary_search_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -138,26 +138,26 @@ lower_bound_impl(__internal::__hetero_tag<_BackendTag>, Policy&& policy, InputIt
auto keep_result = oneapi::dpl::__ranges::__get_sycl_range<__bknd::access_mode::read_write, OutputIterator>();
auto result_buf = keep_result(result, result + value_size);
auto zip_vw = make_zip_view(input_buf.all_view(), value_buf.all_view(), result_buf.all_view());
auto run_lower_bound = [comp, value_size, zip_vw](auto&& policy, auto size) {
__bknd::__parallel_for(
_BackendTag{}, ::std::forward<decltype(policy)>(policy),
custom_brick<StrictWeakOrdering, decltype(size), search_algorithm::lower_bound>{comp, size}, value_size,
zip_vw)
.wait();
};
// Enable index calculation to proceed with uint32_t if input range is small enough.
if (size <= ::std::numeric_limits<::std::uint32_t>::max())
{
__bknd::__parallel_for(_BackendTag{}, ::std::forward<Policy>(policy),
custom_brick<StrictWeakOrdering, ::std::uint32_t, search_algorithm::lower_bound>{
comp, static_cast<::std::uint32_t>(size)},
value_size, zip_vw)
.wait();
run_lower_bound(::std::forward<Policy>(policy), static_cast<::std::uint32_t>(size));
}
else
{
auto fallback_policy = __bknd::make_wrapped_policy<lower_bound_fallback>(::std::forward<Policy>(policy));
__bknd::__parallel_for(_BackendTag{}, ::std::move(fallback_policy),
custom_brick<StrictWeakOrdering, ::std::size_t, search_algorithm::lower_bound>{
comp, static_cast<::std::size_t>(size)},
value_size, zip_vw)
.wait();
run_lower_bound(__bknd::make_wrapped_policy<lower_bound_fallback>(::std::forward<Policy>(policy)),
static_cast<::std::size_t>(size));
}
return result + value_size;
}

template <typename T>
class upper_bound_fallback;

Expand All @@ -184,23 +184,22 @@ upper_bound_impl(__internal::__hetero_tag<_BackendTag>, Policy&& policy, InputIt
auto keep_result = oneapi::dpl::__ranges::__get_sycl_range<__bknd::access_mode::read_write, OutputIterator>();
auto result_buf = keep_result(result, result + value_size);
auto zip_vw = make_zip_view(input_buf.all_view(), value_buf.all_view(), result_buf.all_view());
auto run_upper_bound = [comp, value_size, zip_vw](auto&& policy, auto size) {
__bknd::__parallel_for(
_BackendTag{}, ::std::forward<decltype(policy)>(policy),
custom_brick<StrictWeakOrdering, decltype(size), search_algorithm::upper_bound>{comp, size}, value_size,
zip_vw)
.wait();
};
// Enable index calculation to proceed with uint32_t if input range is small enough.
if (size <= ::std::numeric_limits<::std::uint32_t>::max())
{
__bknd::__parallel_for(_BackendTag{}, ::std::forward<Policy>(policy),
custom_brick<StrictWeakOrdering, ::std::uint32_t, search_algorithm::upper_bound>{
comp, static_cast<::std::uint32_t>(size)},
value_size, zip_vw)
.wait();
run_upper_bound(::std::forward<Policy>(policy), static_cast<::std::uint32_t>(size));
}
else
{
auto fallback_policy = __bknd::make_wrapped_policy<upper_bound_fallback>(::std::forward<Policy>(policy));
__bknd::__parallel_for(_BackendTag{}, ::std::move(fallback_policy),
custom_brick<StrictWeakOrdering, ::std::size_t, search_algorithm::upper_bound>{
comp, static_cast<::std::size_t>(size)},
value_size, zip_vw)
.wait();
run_upper_bound(__bknd::make_wrapped_policy<upper_bound_fallback>(::std::forward<Policy>(policy)),
static_cast<::std::size_t>(size));
}
return result + value_size;
}
Expand Down Expand Up @@ -231,24 +230,22 @@ binary_search_impl(__internal::__hetero_tag<_BackendTag>, Policy&& policy, Input
auto keep_result = oneapi::dpl::__ranges::__get_sycl_range<__bknd::access_mode::read_write, OutputIterator>();
auto result_buf = keep_result(result, result + value_size);
auto zip_vw = make_zip_view(input_buf.all_view(), value_buf.all_view(), result_buf.all_view());

auto run_binary_search = [&comp, &value_size, &zip_vw](auto&& policy, auto size) {
__bknd::__parallel_for(
_BackendTag{}, ::std::forward<decltype(policy)>(policy),
custom_brick<StrictWeakOrdering, decltype(size), search_algorithm::binary_search>{comp, size}, value_size,
zip_vw)
.wait();
};
// Enable index calculation to proceed with uint32_t if input range is small enough.
if (size <= ::std::numeric_limits<::std::uint32_t>::max())
{
__bknd::__parallel_for(_BackendTag{}, ::std::forward<Policy>(policy),
custom_brick<StrictWeakOrdering, ::std::uint32_t, search_algorithm::binary_search>{
comp, static_cast<::std::uint32_t>(size)},
value_size, zip_vw)
.wait();
run_binary_search(::std::forward<Policy>(policy), static_cast<::std::uint32_t>(size));
}
else
{
auto fallback_policy = __bknd::make_wrapped_policy<binary_search_fallback>(::std::forward<Policy>(policy));
__bknd::__parallel_for(_BackendTag{}, ::std::move(fallback_policy),
custom_brick<StrictWeakOrdering, ::std::size_t, search_algorithm::binary_search>{
comp, static_cast<::std::size_t>(size)},
value_size, zip_vw)
.wait();
run_binary_search(__bknd::make_wrapped_policy<binary_search_fallback>(::std::forward<Policy>(policy)),
static_cast<::std::size_t>(size));
}
return result + value_size;
}
Expand Down

0 comments on commit be9345f

Please sign in to comment.