diff --git a/include/oneapi/dpl/pstl/hetero/dpcpp/sycl_iterator.h b/include/oneapi/dpl/pstl/hetero/dpcpp/sycl_iterator.h index 05047f61d77..27bfb0d4a88 100644 --- a/include/oneapi/dpl/pstl/hetero/dpcpp/sycl_iterator.h +++ b/include/oneapi/dpl/pstl/hetero/dpcpp/sycl_iterator.h @@ -149,6 +149,34 @@ struct _ModeConverter static constexpr access_mode __value = access_mode::discard_write; }; +template ::value_type> +using __default_alloc_vec_iter = typename std::vector::iterator; + +template ::value_type> +using __usm_shared_alloc_vec_iter = + typename std::vector>::iterator; + +template ::value_type> +using __usm_host_alloc_vec_iter = + typename std::vector>::iterator; + +// Evaluates to true if the provided type is an iterator with a value_type and if the implementation of a +// std::vector::iterator can be distinguished between three different allocators, the +// default, usm_shared, and usm_host. If all are distinct, it is very unlikely any non-usm based allocator +// could be confused with a usm allocator. +template +struct __vector_iter_distinguishes_by_allocator : std::false_type +{ +}; +template +struct __vector_iter_distinguishes_by_allocator< + Iter, std::enable_if_t, __usm_shared_alloc_vec_iter> && + !std::is_same_v<__default_alloc_vec_iter, __usm_host_alloc_vec_iter> && + !std::is_same_v<__usm_host_alloc_vec_iter, __usm_shared_alloc_vec_iter>>> + : std::true_type +{ +}; + } // namespace __internal template diff --git a/include/oneapi/dpl/pstl/hetero/dpcpp/utils_ranges_sycl.h b/include/oneapi/dpl/pstl/hetero/dpcpp/utils_ranges_sycl.h index 1821301e911..9351b20dc88 100644 --- a/include/oneapi/dpl/pstl/hetero/dpcpp/utils_ranges_sycl.h +++ b/include/oneapi/dpl/pstl/hetero/dpcpp/utils_ranges_sycl.h @@ -22,6 +22,7 @@ #include "../../utils_ranges.h" #include "../../iterator_impl.h" #include "../../glue_numeric_defs.h" +#include "sycl_iterator.h" #include "sycl_defs.h" namespace oneapi @@ -206,6 +207,16 @@ struct is_passed_directly +struct is_passed_directly< + Iter, std::enable_if_t<(std::is_same_v> || + std::is_same_v>) && + oneapi::dpl::__internal::__vector_iter_distinguishes_by_allocator::value>> : + std::true_type +{ +}; + template struct is_passed_directly> : ::std::true_type {