Skip to content

Commit

Permalink
Support scratch space without result space (#1773)
Browse files (browse the repository at this point in the history)
* Support scratch space without result space

* Fix type error of n_scratch

* Update include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl_utils.h

---------

Co-authored-by: Sergey Kopienko <[email protected]>
  • Loading branch information
julianmi and SergeyKopienko authored Aug 12, 2024
1 parent 52efa5b commit fb6357c
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 56 deletions.
8 changes: 4 additions & 4 deletions include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl.h
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,7 @@ struct __parallel_scan_submitter<_CustomName, __internal::__optional_kernel_name
// Storage for the results of scan for each workgroup

using __result_and_scratch_storage_t = __result_and_scratch_storage<_ExecutionPolicy, _Type>;
__result_and_scratch_storage_t __result_and_scratch{__exec, __n_groups + 1};
__result_and_scratch_storage_t __result_and_scratch{__exec, 1, __n_groups + 1};

_PRINT_INFO_IN_DEBUG_MODE(__exec, __wgroup_size, __max_cu);

Expand Down Expand Up @@ -568,7 +568,7 @@ struct __parallel_copy_if_static_single_group_submitter<_Size, _ElemsPerItem, _W

constexpr ::std::uint32_t __elems_per_wg = _ElemsPerItem * _WGSize;
using __result_and_scratch_storage_t = __result_and_scratch_storage<_Policy, _Size>;
__result_and_scratch_storage_t __result{__policy, 0};
__result_and_scratch_storage_t __result{__policy, 1, 0};

auto __event = __policy.queue().submit([&](sycl::handler& __hdl) {
oneapi::dpl::__ranges::__require_access(__hdl, __in_rng, __out_rng);
Expand Down Expand Up @@ -664,7 +664,7 @@ __parallel_transform_scan_single_group(oneapi::dpl::__internal::__device_backend
// Although we do not actually need result storage in this case, we need to construct
// a placeholder here to match the return type of the non-single-work-group implementation
using __result_and_scratch_storage_t = __result_and_scratch_storage<_ExecutionPolicy, _ValueType>;
__result_and_scratch_storage_t __dummy_result_and_scratch{__exec, 0};
__result_and_scratch_storage_t __dummy_result_and_scratch{__exec, 0, 0};

if (__max_wg_size >= __targeted_wg_size)
{
Expand Down Expand Up @@ -1212,7 +1212,7 @@ __parallel_find_or_impl_one_wg(oneapi::dpl::__internal::__device_backend_tag, _E
const __FoundStateType __init_value, _Predicate __pred, _Ranges&&... __rngs)
{
using __result_and_scratch_storage_t = __result_and_scratch_storage<_ExecutionPolicy, __FoundStateType>;
__result_and_scratch_storage_t __result_storage(__exec, 0);
__result_and_scratch_storage_t __result_storage{__exec, 1, 0};

// Calculate the number of elements to be processed by each work-item.
const auto __iters_per_work_item = oneapi::dpl::__internal::__dpl_ceiling_div(__rng_n, __wgroup_size);
Expand Down
35 changes: 16 additions & 19 deletions include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl_reduce.h
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,8 @@ struct __parallel_transform_reduce_small_submitter<_Tp, _Commutative, _VecSize,
auto __reduce_pattern = unseq_backend::reduce_over_group<_ExecutionPolicy, _ReduceOp, _Tp>{__reduce_op};
const bool __is_full = __n == __work_group_size * __iters_per_work_item;

__result_and_scratch_storage<_ExecutionPolicy, _Tp> __scratch_container(__exec, 0);
using __result_and_scratch_storage_t = __result_and_scratch_storage<_ExecutionPolicy, _Tp>;
__result_and_scratch_storage_t __scratch_container{__exec, 1, 0};

sycl::event __reduce_event = __exec.queue().submit([&, __n](sycl::handler& __cgh) {
oneapi::dpl::__ranges::__require_access(__cgh, __rngs...); // get an access to data under SYCL buffer
Expand All @@ -146,9 +147,7 @@ struct __parallel_transform_reduce_small_submitter<_Tp, _Commutative, _VecSize,
__cgh.parallel_for<_Name...>(
sycl::nd_range<1>(sycl::range<1>(__work_group_size), sycl::range<1>(__work_group_size)),
[=](sycl::nd_item<1> __item_id) {
auto __res_ptr =
__result_and_scratch_storage<_ExecutionPolicy, _Tp>::__get_usm_or_buffer_accessor_ptr(
__res_acc);
auto __res_ptr = __result_and_scratch_storage_t::__get_usm_or_buffer_accessor_ptr(__res_acc);
__work_group_reduce_kernel<_Tp>(__item_id, __n, __iters_per_work_item, __is_full,
__transform_pattern, __reduce_pattern, __init, __temp_local,
__res_ptr, __rngs...);
Expand Down Expand Up @@ -191,7 +190,8 @@ struct __parallel_transform_reduce_device_kernel_submitter<_Tp, _Commutative, _V
auto
operator()(oneapi::dpl::__internal::__device_backend_tag, _ExecutionPolicy&& __exec, const _Size __n,
const _Size __work_group_size, const _Size __iters_per_work_item, _ReduceOp __reduce_op,
_TransformOp __transform_op, __result_and_scratch_storage<_ExecutionPolicy2, _Tp> __scratch_container,
_TransformOp __transform_op,
const __result_and_scratch_storage<_ExecutionPolicy2, _Tp>& __scratch_container,
_Ranges&&... __rngs) const
{
auto __transform_pattern =
Expand Down Expand Up @@ -238,7 +238,7 @@ struct __parallel_transform_reduce_work_group_kernel_submitter<_Tp, _Commutative
auto
operator()(oneapi::dpl::__internal::__device_backend_tag, _ExecutionPolicy&& __exec, sycl::event& __reduce_event,
const _Size __n, const _Size __work_group_size, const _Size __iters_per_work_item, _ReduceOp __reduce_op,
_InitType __init, __result_and_scratch_storage<_ExecutionPolicy2, _Tp> __scratch_container) const
_InitType __init, const __result_and_scratch_storage<_ExecutionPolicy2, _Tp>& __scratch_container) const
{
using _NoOpFunctor = unseq_backend::walk_n<_ExecutionPolicy, oneapi::dpl::__internal::__no_op>;
auto __transform_pattern =
Expand All @@ -248,6 +248,8 @@ struct __parallel_transform_reduce_work_group_kernel_submitter<_Tp, _Commutative

const bool __is_full = __n == __work_group_size * __iters_per_work_item;

using __result_and_scratch_storage_t = __result_and_scratch_storage<_ExecutionPolicy2, _Tp>;

__reduce_event = __exec.queue().submit([&, __n](sycl::handler& __cgh) {
__cgh.depends_on(__reduce_event);

Expand All @@ -258,12 +260,8 @@ struct __parallel_transform_reduce_work_group_kernel_submitter<_Tp, _Commutative
__cgh.parallel_for<_KernelName...>(
sycl::nd_range<1>(sycl::range<1>(__work_group_size), sycl::range<1>(__work_group_size)),
[=](sycl::nd_item<1> __item_id) {
auto __temp_ptr =
__result_and_scratch_storage<_ExecutionPolicy2, _Tp>::__get_usm_or_buffer_accessor_ptr(
__temp_acc);
auto __res_ptr =
__result_and_scratch_storage<_ExecutionPolicy2, _Tp>::__get_usm_or_buffer_accessor_ptr(
__res_acc, __n);
auto __temp_ptr = __result_and_scratch_storage_t::__get_usm_or_buffer_accessor_ptr(__temp_acc);
auto __res_ptr = __result_and_scratch_storage_t::__get_usm_or_buffer_accessor_ptr(__res_acc, __n);
__work_group_reduce_kernel<_Tp>(__item_id, __n, __iters_per_work_item, __is_full,
__transform_pattern, __reduce_pattern, __init, __temp_local,
__res_ptr, __temp_ptr);
Expand Down Expand Up @@ -292,7 +290,7 @@ __parallel_transform_reduce_mid_impl(oneapi::dpl::__internal::__device_backend_t
// number of buffer elements processed within workgroup
const _Size __size_per_work_group = __iters_per_work_item_device_kernel * __work_group_size;
const _Size __n_groups = oneapi::dpl::__internal::__dpl_ceiling_div(__n, __size_per_work_group);
__result_and_scratch_storage<_ExecutionPolicy, _Tp> __scratch_container(__exec, __n_groups);
__result_and_scratch_storage<_ExecutionPolicy, _Tp> __scratch_container{__exec, 1, __n_groups};

sycl::event __reduce_event =
__parallel_transform_reduce_device_kernel_submitter<_Tp, _Commutative, _VecSize, _ReduceDeviceKernel>()(
Expand Down Expand Up @@ -341,7 +339,9 @@ struct __parallel_transform_reduce_impl
_Size __n_groups = oneapi::dpl::__internal::__dpl_ceiling_div(__n, __size_per_work_group);

// Create temporary global buffers to store temporary values
__result_and_scratch_storage<_ExecutionPolicy, _Tp> __scratch_container(__exec, 2 * __n_groups);
const std::size_t __n_scratch = 2 * __n_groups;
using __result_and_scratch_storage_t = __result_and_scratch_storage<_ExecutionPolicy, _Tp>;
__result_and_scratch_storage_t __scratch_container{__exec, 1, __n_scratch};

// __is_first == true. Reduce over each work_group
// __is_first == false. Reduce between work groups
Expand Down Expand Up @@ -375,12 +375,9 @@ struct __parallel_transform_reduce_impl
sycl::nd_range<1>(sycl::range<1>(__n_groups * __work_group_size),
sycl::range<1>(__work_group_size)),
[=](sycl::nd_item<1> __item_id) {
auto __temp_ptr =
__result_and_scratch_storage<_ExecutionPolicy, _Tp>::__get_usm_or_buffer_accessor_ptr(
__temp_acc);
auto __temp_ptr = __result_and_scratch_storage_t::__get_usm_or_buffer_accessor_ptr(__temp_acc);
auto __res_ptr =
__result_and_scratch_storage<_ExecutionPolicy, _Tp>::__get_usm_or_buffer_accessor_ptr(
__res_acc, 2 * __n_groups);
__result_and_scratch_storage_t::__get_usm_or_buffer_accessor_ptr(__res_acc, 2 * __n_groups);
auto __local_idx = __item_id.get_local_id(0);
auto __group_idx = __item_id.get_group(0);
// 1. Initialization (transform part). Fill local memory
Expand Down
78 changes: 45 additions & 33 deletions include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -511,11 +511,12 @@ struct __result_and_scratch_storage
using __sycl_buffer_t = sycl::buffer<_T, 1>;

_ExecutionPolicy __exec;
::std::shared_ptr<_T> __scratch_buf;
::std::shared_ptr<_T> __result_buf;
::std::shared_ptr<__sycl_buffer_t> __sycl_buf;
std::shared_ptr<_T> __scratch_buf;
std::shared_ptr<_T> __result_buf;
std::shared_ptr<__sycl_buffer_t> __sycl_buf;

::std::size_t __scratch_n;
std::size_t __result_n;
std::size_t __scratch_n;
bool __use_USM_host;
bool __supports_USM_device;

Expand Down Expand Up @@ -548,40 +549,50 @@ struct __result_and_scratch_storage
}

public:
__result_and_scratch_storage(const _ExecutionPolicy& __exec_, ::std::size_t __scratch_n)
: __exec{__exec_}, __scratch_n{__scratch_n}, __use_USM_host{__use_USM_host_allocations(__exec.queue())},
__supports_USM_device{__use_USM_allocations(__exec.queue())}
__result_and_scratch_storage(const _ExecutionPolicy& __exec_, std::size_t __result_n, std::size_t __scratch_n)
: __exec{__exec_}, __result_n{__result_n}, __scratch_n{__scratch_n},
__use_USM_host{__use_USM_host_allocations(__exec.queue())}, __supports_USM_device{
__use_USM_allocations(__exec.queue())}
{
if (__use_USM_host && __supports_USM_device)
const std::size_t __total_n = __scratch_n + __result_n;
// Skip in case this is a dummy container
if (__total_n > 0)
{
// Separate scratch (device) and result (host) allocations on performant backends (i.e. L0)
if (__scratch_n > 0)
if (__use_USM_host && __supports_USM_device)
{
// Separate scratch (device) and result (host) allocations on performant backends (i.e. L0)
if (__scratch_n > 0)
{
__scratch_buf = std::shared_ptr<_T>(
__internal::__sycl_usm_alloc<_ExecutionPolicy, _T, sycl::usm::alloc::device>{__exec}(
__scratch_n),
__internal::__sycl_usm_free<_ExecutionPolicy, _T>{__exec});
}
if (__result_n > 0)
{
__result_buf = std::shared_ptr<_T>(
__internal::__sycl_usm_alloc<_ExecutionPolicy, _T, sycl::usm::alloc::host>{__exec}(__result_n),
__internal::__sycl_usm_free<_ExecutionPolicy, _T>{__exec});
}
}
else if (__supports_USM_device)
{
__scratch_buf = ::std::shared_ptr<_T>(
__internal::__sycl_usm_alloc<_ExecutionPolicy, _T, sycl::usm::alloc::device>{__exec}(__scratch_n),
// If we don't use host memory, malloc only a single unified device allocation
__scratch_buf = std::shared_ptr<_T>(
__internal::__sycl_usm_alloc<_ExecutionPolicy, _T, sycl::usm::alloc::device>{__exec}(__total_n),
__internal::__sycl_usm_free<_ExecutionPolicy, _T>{__exec});
}
__result_buf = ::std::shared_ptr<_T>(
__internal::__sycl_usm_alloc<_ExecutionPolicy, _T, sycl::usm::alloc::host>{__exec}(1),
__internal::__sycl_usm_free<_ExecutionPolicy, _T>{__exec});
}
else if (__supports_USM_device)
{
// If we don't use host memory, malloc only a single unified device allocation
__scratch_buf = ::std::shared_ptr<_T>(
__internal::__sycl_usm_alloc<_ExecutionPolicy, _T, sycl::usm::alloc::device>{__exec}(__scratch_n + 1),
__internal::__sycl_usm_free<_ExecutionPolicy, _T>{__exec});
}
else
{
// If we don't have USM support allocate memory here
__sycl_buf = ::std::make_shared<__sycl_buffer_t>(__sycl_buffer_t(__scratch_n + 1));
else
{
// If we don't have USM support allocate memory here
__sycl_buf = std::make_shared<__sycl_buffer_t>(__sycl_buffer_t(__total_n));
}
}
}

template <typename _Acc>
static auto
__get_usm_or_buffer_accessor_ptr(const _Acc& __acc, ::std::size_t __scratch_n = 0)
__get_usm_or_buffer_accessor_ptr(const _Acc& __acc, std::size_t __scratch_n = 0)
{
#if _ONEDPL_SYCL_UNIFIED_USM_BUFFER_PRESENT
return __acc.__get_pointer();
Expand All @@ -591,7 +602,7 @@ struct __result_and_scratch_storage
}

auto
__get_result_acc(sycl::handler& __cgh)
__get_result_acc(sycl::handler& __cgh) const
{
#if _ONEDPL_SYCL_UNIFIED_USM_BUFFER_PRESENT
if (__use_USM_host && __supports_USM_device)
Expand All @@ -605,7 +616,7 @@ struct __result_and_scratch_storage
}

auto
__get_scratch_acc(sycl::handler& __cgh)
__get_scratch_acc(sycl::handler& __cgh) const
{
#if _ONEDPL_SYCL_UNIFIED_USM_BUFFER_PRESENT
if (__use_USM_host || __supports_USM_device)
Expand All @@ -627,6 +638,7 @@ struct __result_and_scratch_storage
_T
__get_value(size_t idx = 0) const
{
assert(idx < __result_n);
if (__use_USM_host && __supports_USM_device)
{
return *(__result_buf.get() + idx);
Expand Down Expand Up @@ -663,22 +675,22 @@ class __future : private std::tuple<_Args...>

// Overload for a result held in a sycl::buffer: reads the buffer through a
// read-only host accessor and returns its first element.
template <typename _T>
constexpr auto
__wait_and_get_value(sycl::buffer<_T>& __buf)
__wait_and_get_value(const sycl::buffer<_T>& __buf)
{
// By contract, the result is stored in a one-element sycl::buffer, so only index 0 is read.
return __buf.get_host_access(sycl::read_only)[0];
}

// Overload for a result held in a __result_and_scratch_storage: forwards to the
// storage's own __wait_and_get_value, passing __my_event, and returns its result.
template <typename _ExecutionPolicy, typename _T>
constexpr auto
__wait_and_get_value(__result_and_scratch_storage<_ExecutionPolicy, _T>& __storage)
__wait_and_get_value(const __result_and_scratch_storage<_ExecutionPolicy, _T>& __storage)
{
return __storage.__wait_and_get_value(__my_event);
}

template <typename _T>
constexpr auto
__wait_and_get_value(_T& __val)
__wait_and_get_value(const _T& __val)
{
wait();
return __val;
Expand Down

0 comments on commit fb6357c

Please sign in to comment.