Skip to content

Commit

Permalink
Fixes in __detail namespace (#1729)
Browse files Browse the repository at this point in the history
* Relative paths in includes

* format fix

* fixes in __detail namespace

* removed unnamed namespaces

* fixes in namespaces

* ranges::local_or_identity in place of __detail::local

* minor fix

* ranges.hpp update

* ranges.hpp update

* local_or_identity made __detail

* removed is_localizable

---------

Co-authored-by: Łukasz Ślusarczyk <[email protected]>
Co-authored-by: Łukasz Ślusarczyk <[email protected]>
  • Loading branch information
3 people authored Aug 5, 2024
1 parent d6f21d3 commit 917eeb7
Show file tree
Hide file tree
Showing 18 changed files with 66 additions and 223 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,22 +19,12 @@
#include <iterator>
#include <type_traits>

#include "ranges.hpp"
#include "std_ranges_shim.hpp"

namespace oneapi::dpl::experimental::dr
{

namespace
{

template <typename R>
concept has_segments_method = requires(R r)
{
{r.segments()};
};

} // namespace

template <typename Accessor>
class iterator_adaptor
{
Expand Down Expand Up @@ -208,7 +198,7 @@ class iterator_adaptor
}

auto
segments() const noexcept requires(has_segments_method<accessor_type>)
segments() const noexcept requires(ranges::__detail::has_segments_method<accessor_type>)
{
return accessor_.segments();
}
Expand Down
133 changes: 14 additions & 119 deletions include/oneapi/dpl/internal/distributed_ranges_impl/detail/ranges.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ namespace ranges
template <typename>
inline constexpr bool disable_rank = false;

namespace
namespace __detail
{

template <typename T>
Expand All @@ -54,14 +54,8 @@ template <typename Iter>
concept is_remote_iterator_shadow_impl_ =
std::forward_iterator<Iter> && has_rank_method<Iter> && !disable_rank<std::remove_cv_t<Iter>>;

} // namespace

namespace
{

struct rank_fn_
{

// Return the rank associated with a remote range.
// This is either:
// 1) r.rank(), if the remote range has a `rank()` method
Expand Down Expand Up @@ -101,11 +95,11 @@ struct rank_fn_
}
};

} // namespace
} // namespace __detail

inline constexpr auto rank = rank_fn_{};
inline constexpr auto rank = __detail::rank_fn_{};

namespace
namespace __detail
{

template <typename R>
Expand Down Expand Up @@ -164,13 +158,6 @@ struct segments_fn_
}
};

} // namespace

inline constexpr auto segments = segments_fn_{};

namespace
{

template <typename Iter>
concept has_local_adl = requires(Iter& iter)
{
Expand All @@ -187,35 +174,6 @@ concept iter_has_local_method = std::forward_iterator<Iter> && requires(Iter ite
} -> std::forward_iterator;
};

template <typename T>
struct is_localizable_helper : std::false_type
{
};

template <has_local_adl T>
struct is_localizable_helper<T> : std::true_type
{
};

template <iter_has_local_method T>
struct is_localizable_helper<T> : std::true_type
{
};

template <std::forward_iterator Iter>
requires(not iter_has_local_method<Iter> && not has_local_adl<Iter>) && requires() { std::iter_value_t<Iter>(); }
struct is_localizable_helper<Iter> : is_localizable_helper<std::iter_value_t<Iter>>
{
};

template <stdrng::forward_range R>
struct is_localizable_helper<R> : is_localizable_helper<stdrng::iterator_t<R>>
{
};

template <typename T>
concept is_localizable = is_localizable_helper<T>::value;

template <typename Segment>
concept segment_has_local_method = stdrng::forward_range<Segment> && requires(Segment segment)
{
Expand All @@ -227,62 +185,8 @@ concept segment_has_local_method = stdrng::forward_range<Segment> && requires(Se
struct local_fn_
{

// based on https://ericniebler.github.io/range-v3/#autotoc_md30 "Create
// custom iterators"
// TODO: rewrite using iterator_interface from
// https://github.com/boostorg/stl_interfaces
template <typename Iter>
requires stdrng::forward_range<typename Iter::value_type>
struct cursor_over_local_ranges
{
Iter iter;
auto
make_begin_for_counted() const
{
if constexpr (iter_has_local_method<stdrng::iterator_t<typename Iter::value_type>>)
return stdrng::begin(*iter).local();
else
return std::iterator<std::bidirectional_iterator_tag,
cursor_over_local_ranges<stdrng::iterator_t<typename Iter::value_type>>>(
stdrng::begin(*iter));
}
auto
read() const
{
return stdrng::views::counted(make_begin_for_counted(), stdrng::size(*iter));
}
bool
equal(const cursor_over_local_ranges& other) const
{
return iter == other.iter;
}
void
next()
{
++iter;
}
void
prev()
{
--iter;
}
void
advance(std::ptrdiff_t n)
{
this->iter += n;
}
std::ptrdiff_t
distance_to(const cursor_over_local_ranges& other) const
{
return other.iter - this->iter;
}
cursor_over_local_ranges() = default;
cursor_over_local_ranges(Iter iter) : iter(iter) {}
};

template <std::forward_iterator Iter>
requires(has_local_adl<Iter> || iter_has_local_method<Iter> || std::contiguous_iterator<Iter> ||
is_localizable<Iter>) auto
requires(has_local_adl<Iter> || iter_has_local_method<Iter> || std::contiguous_iterator<Iter>) auto
operator()(Iter iter) const
{
if constexpr (iter_has_local_method<Iter>)
Expand All @@ -293,10 +197,6 @@ struct local_fn_
{
return local_(iter);
}
else if constexpr (is_localizable<Iter>)
{
return std::iterator<std::bidirectional_iterator_tag, cursor_over_local_ranges<Iter>>(iter);
}
else if constexpr (std::contiguous_iterator<Iter>)
{
return iter;
Expand All @@ -305,7 +205,7 @@ struct local_fn_

template <stdrng::forward_range R>
requires(has_local_adl<R> || iter_has_local_method<stdrng::iterator_t<R>> || segment_has_local_method<R> ||
std::contiguous_iterator<stdrng::iterator_t<R>> || is_localizable<R> || stdrng::contiguous_range<R>) auto
std::contiguous_iterator<stdrng::iterator_t<R>> || stdrng::contiguous_range<R>) auto
operator()(R&& r) const
{
if constexpr (segment_has_local_method<R>)
Expand All @@ -320,23 +220,16 @@ struct local_fn_
{
return local_(std::forward<R>(r));
}
else if constexpr (is_localizable<R>)
{
return stdrng::views::counted(
std::iterator<std::input_iterator_tag, cursor_over_local_ranges<stdrng::iterator_t<R>>>(
stdrng::begin(r)),
stdrng::size(r));
}
else if constexpr (std::contiguous_iterator<stdrng::iterator_t<R>>)
{
return std::span(stdrng::begin(r), stdrng::size(r));
}
}
};

} // namespace
} // namespace __detail

inline constexpr auto local = local_fn_{};
inline constexpr auto local = __detail::local_fn_{};

namespace __detail
{
Expand All @@ -349,7 +242,7 @@ concept has_local = requires(T& t)
} -> std::convertible_to<std::any>;
};

struct local_fn_
struct local_or_identity_fn_
{
template <typename T>
requires(has_local<T>) auto
Expand All @@ -359,17 +252,19 @@ struct local_fn_
}

template <typename T>
decltype(auto)
auto
operator()(T&& t) const
{
return t;
return std::forward<T>(t);
}
};

inline constexpr auto local = local_fn_{};
inline constexpr auto local_or_identity = local_or_identity_fn_{};

} // namespace __detail

inline constexpr auto segments = __detail::segments_fn_{};

} // namespace ranges

} // namespace oneapi::dpl::experimental::dr
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,4 +69,4 @@ inline constexpr bool stdrng::enable_borrowed_range<oneapi::dpl::experimental::d

#endif

#endif /* _ONEDPL_DR_DETAIL_REMOTE_SUBRANGE_HPP */
#endif /* _ONEDPL_DR_DETAIL_REMOTE_SUBRANGE_HPP */
Original file line number Diff line number Diff line change
Expand Up @@ -34,4 +34,4 @@ namespace stdrng = ::ranges;

#endif /* DR_USE_RANGES_V3 */

#endif /* _ONEDPL_DR_DETAIL_RANGES_SHIM_HPP */
#endif /* _ONEDPL_DR_DETAIL_RANGES_SHIM_HPP */
Original file line number Diff line number Diff line change
Expand Up @@ -28,4 +28,4 @@
#include "sort.hpp"
#include "transform.hpp"

#endif
#endif
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@
namespace oneapi::dpl::experimental::dr::sp
{

namespace __detail
{
template <typename ExecutionPolicy, distributed_contiguous_range R, distributed_contiguous_range O, typename U,
typename BinaryOp>
void
Expand All @@ -54,7 +56,7 @@ exclusive_scan_impl_(ExecutionPolicy&& policy, R&& r, O&& o, U init, BinaryOp bi
{
auto&& [in_segment, out_segment] = segs;

auto last_element = stdrng::prev(stdrng::end(__detail::local(in_segment)));
auto last_element = stdrng::prev(stdrng::end(ranges::__detail::local_or_identity(in_segment)));
auto dest = d_inits + segment_id;

auto&& q = __detail::queue(ranges::rank(in_segment));
Expand Down Expand Up @@ -151,7 +153,7 @@ exclusive_scan_impl_(ExecutionPolicy&& policy, R&& r, O&& o, U init, BinaryOp bi
auto first = stdrng::begin(out_segment);
dr::__detail::direct_iterator d_first(first);

auto d_sum = ranges::__detail::local(partial_sums).begin() + idx - 1;
auto d_sum = ranges::__detail::local_or_identity(partial_sums).begin() + idx - 1;

sycl::event e = dr::__detail::parallel_for(q, sycl::range<>(stdrng::distance(out_segment)), [=](auto idx) {
d_first[idx] = binary_op(d_first[idx], *d_sum);
Expand All @@ -165,37 +167,38 @@ exclusive_scan_impl_(ExecutionPolicy&& policy, R&& r, O&& o, U init, BinaryOp bi
__detail::wait(events);
}

} // namespace __detail
// Ranges versions

template <typename ExecutionPolicy, distributed_contiguous_range R, distributed_contiguous_range O, typename T,
typename BinaryOp>
void
exclusive_scan(ExecutionPolicy&& policy, R&& r, O&& o, T init, BinaryOp binary_op)
{
exclusive_scan_impl_(std::forward<ExecutionPolicy>(policy), std::forward<R>(r), std::forward<O>(o), init,
binary_op);
__detail::exclusive_scan_impl_(std::forward<ExecutionPolicy>(policy), std::forward<R>(r), std::forward<O>(o), init,
binary_op);
}

template <typename ExecutionPolicy, distributed_contiguous_range R, distributed_contiguous_range O, typename T>
void
exclusive_scan(ExecutionPolicy&& policy, R&& r, O&& o, T init)
{
exclusive_scan_impl_(std::forward<ExecutionPolicy>(policy), std::forward<R>(r), std::forward<O>(o), init,
std::plus<>{});
__detail::exclusive_scan_impl_(std::forward<ExecutionPolicy>(policy), std::forward<R>(r), std::forward<O>(o), init,
std::plus<>{});
}

template <distributed_contiguous_range R, distributed_contiguous_range O, typename T, typename BinaryOp>
void
exclusive_scan(R&& r, O&& o, T init, BinaryOp binary_op)
{
exclusive_scan_impl_(par_unseq, std::forward<R>(r), std::forward<O>(o), init, binary_op);
__detail::exclusive_scan_impl_(par_unseq, std::forward<R>(r), std::forward<O>(o), init, binary_op);
}

template <distributed_contiguous_range R, distributed_contiguous_range O, typename T>
void
exclusive_scan(R&& r, O&& o, T init)
{
exclusive_scan_impl_(par_unseq, std::forward<R>(r), std::forward<O>(o), init, std::plus<>{});
__detail::exclusive_scan_impl_(par_unseq, std::forward<R>(r), std::forward<O>(o), init, std::plus<>{});
}

// Iterator versions
Expand All @@ -208,8 +211,8 @@ exclusive_scan(ExecutionPolicy&& policy, Iter first, Iter last, OutputIter d_fir
auto dist = stdrng::distance(first, last);
auto d_last = d_first;
stdrng::advance(d_last, dist);
exclusive_scan_impl_(std::forward<ExecutionPolicy>(policy), stdrng::subrange(first, last),
stdrng::subrange(d_first, d_last), init, binary_op);
__detail::exclusive_scan_impl_(std::forward<ExecutionPolicy>(policy), stdrng::subrange(first, last),
stdrng::subrange(d_first, d_last), init, binary_op);
}

template <typename ExecutionPolicy, distributed_iterator Iter, distributed_iterator OutputIter, typename T>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,4 +54,4 @@ struct sycl_device_collection

} // namespace oneapi::dpl::experimental::dr::sp

#endif /* _ONEDPL_DR_SP_EXECUTION_POLICY_HPP */
#endif /* _ONEDPL_DR_SP_EXECUTION_POLICY_HPP */
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ for_each(ExecutionPolicy&& policy, R&& r, Fn fn)

assert(stdrng::distance(segment) > 0);

auto local_segment = __detail::local(segment);
auto local_segment = ranges::__detail::local_or_identity(segment);

auto first = stdrng::begin(local_segment);

Expand Down
Loading

0 comments on commit 917eeb7

Please sign in to comment.