Skip to content

Commit

Permalink
__result_and_scratch_storage changes for scan
Browse files Browse the repository at this point in the history
Signed-off-by: Dan Hoeflinger <[email protected]>
  • Loading branch information
danhoeflinger committed Aug 6, 2024
1 parent a697e86 commit c8a274a
Show file tree
Hide file tree
Showing 3 changed files with 88 additions and 79 deletions.
99 changes: 45 additions & 54 deletions include/oneapi/dpl/pstl/hetero/dpcpp/parallel_backend_sycl.h
Original file line number Diff line number Diff line change
Expand Up @@ -262,29 +262,6 @@ __parallel_for(oneapi::dpl::__internal::__device_backend_tag, _ExecutionPolicy&&
//------------------------------------------------------------------------
// parallel_transform_scan - async pattern
//------------------------------------------------------------------------
template <typename _GlobalScan, typename _Range2, typename _Range1, typename _Accessor, typename _Size>
struct __global_scan_caller
{
__global_scan_caller(const _GlobalScan& __global_scan, const _Range2& __rng2, const _Range1& __rng1,
const _Accessor& __wg_sums_acc, _Size __n, ::std::size_t __size_per_wg)
: __m_global_scan(__global_scan), __m_rng2(__rng2), __m_rng1(__rng1), __m_wg_sums_acc(__wg_sums_acc),
__m_n(__n), __m_size_per_wg(__size_per_wg)
{
}

void operator()(sycl::item<1> __item) const
{
__m_global_scan(__item, __m_rng2, __m_rng1, __m_wg_sums_acc, __m_n, __m_size_per_wg);
}

private:
_GlobalScan __m_global_scan;
_Range2 __m_rng2;
_Range1 __m_rng1;
_Accessor __m_wg_sums_acc;
_Size __m_n;
::std::size_t __m_size_per_wg;
};

// Please see the comment for __parallel_for_submitter for optional kernel name explanation
template <typename _CustomName, typename _PropagateScanName>
Expand Down Expand Up @@ -332,14 +309,16 @@ struct __parallel_scan_submitter<_CustomName, __internal::__optional_kernel_name
auto __size_per_wg = __iters_per_witem * __wgroup_size;
auto __n_groups = oneapi::dpl::__internal::__dpl_ceiling_div(__n, __size_per_wg);
// Storage for the results of scan for each workgroup
sycl::buffer<_Type> __wg_sums(__n_groups);

using _TempStorage = __result_and_scratch_storage<std::decay_t<_ExecutionPolicy>, _Type>;
_TempStorage __result_and_scratch{__exec, __n_groups + 1};

_PRINT_INFO_IN_DEBUG_MODE(__exec, __wgroup_size, __max_cu);

// 1. Local scan on each workgroup
auto __submit_event = __exec.queue().submit([&](sycl::handler& __cgh) {
oneapi::dpl::__ranges::__require_access(__cgh, __rng1, __rng2); //get an access to data under SYCL buffer
auto __wg_sums_acc = __wg_sums.template get_access<access_mode::discard_write>(__cgh);
auto __temp_acc = __result_and_scratch.__get_scratch_acc(__cgh);
__dpl_sycl::__local_accessor<_Type> __local_acc(__wgroup_size, __cgh);
#if _ONEDPL_COMPILE_KERNEL && _ONEDPL_KERNEL_BUNDLE_PRESENT
__cgh.use_kernel_bundle(__kernel_1.get_kernel_bundle());
Expand All @@ -349,7 +328,8 @@ struct __parallel_scan_submitter<_CustomName, __internal::__optional_kernel_name
__kernel_1,
#endif
sycl::nd_range<1>(__n_groups * __wgroup_size, __wgroup_size), [=](sycl::nd_item<1> __item) {
__local_scan(__item, __n, __local_acc, __rng1, __rng2, __wg_sums_acc, __size_per_wg, __wgroup_size,
auto __temp_ptr = _TempStorage::__get_usm_or_buffer_accessor_ptr(__temp_acc);
__local_scan(__item, __n, __local_acc, __rng1, __rng2, __temp_ptr, __size_per_wg, __wgroup_size,
__iters_per_witem, __init);
});
});
Expand All @@ -359,7 +339,7 @@ struct __parallel_scan_submitter<_CustomName, __internal::__optional_kernel_name
auto __iters_per_single_wg = oneapi::dpl::__internal::__dpl_ceiling_div(__n_groups, __wgroup_size);
__submit_event = __exec.queue().submit([&](sycl::handler& __cgh) {
__cgh.depends_on(__submit_event);
auto __wg_sums_acc = __wg_sums.template get_access<access_mode::read_write>(__cgh);
auto __temp_acc = __result_and_scratch.__get_scratch_acc(__cgh);
__dpl_sycl::__local_accessor<_Type> __local_acc(__wgroup_size, __cgh);
#if _ONEDPL_COMPILE_KERNEL && _ONEDPL_KERNEL_BUNDLE_PRESENT
__cgh.use_kernel_bundle(__kernel_2.get_kernel_bundle());
Expand All @@ -370,8 +350,9 @@ struct __parallel_scan_submitter<_CustomName, __internal::__optional_kernel_name
#endif
// TODO: try to balance work between several workgroups instead of one
sycl::nd_range<1>(__wgroup_size, __wgroup_size), [=](sycl::nd_item<1> __item) {
__group_scan(__item, __n_groups, __local_acc, __wg_sums_acc, __wg_sums_acc,
/*dummy*/ __wg_sums_acc, __n_groups, __wgroup_size, __iters_per_single_wg);
auto __temp_ptr = _TempStorage::__get_usm_or_buffer_accessor_ptr(__temp_acc);
__group_scan(__item, __n_groups, __local_acc, __temp_ptr, __temp_ptr,
/*dummy*/ __temp_ptr, __n_groups, __wgroup_size, __iters_per_single_wg);
});
});
}
Expand All @@ -380,15 +361,16 @@ struct __parallel_scan_submitter<_CustomName, __internal::__optional_kernel_name
auto __final_event = __exec.queue().submit([&](sycl::handler& __cgh) {
__cgh.depends_on(__submit_event);
oneapi::dpl::__ranges::__require_access(__cgh, __rng1, __rng2); //get an access to data under SYCL buffer
auto __wg_sums_acc = __wg_sums.template get_access<access_mode::read>(__cgh);
__cgh.parallel_for<_PropagateScanName...>(
sycl::range<1>(__n_groups * __size_per_wg),
__global_scan_caller<_GlobalScan, ::std::decay_t<_Range2>, ::std::decay_t<_Range1>,
decltype(__wg_sums_acc), decltype(__n)>(__global_scan, __rng2, __rng1,
__wg_sums_acc, __n, __size_per_wg));
auto __temp_acc = __result_and_scratch.__get_scratch_acc(__cgh);
auto __res_acc = __result_and_scratch.__get_result_acc(__cgh);
__cgh.parallel_for<_PropagateScanName...>(sycl::range<1>(__n_groups * __size_per_wg), [=](auto __item) {
auto __temp_ptr = _TempStorage::__get_usm_or_buffer_accessor_ptr(__temp_acc);
auto __res_ptr = _TempStorage::__get_usm_or_buffer_accessor_ptr(__res_acc, __n_groups + 1);
__global_scan(__item, __rng2, __rng1, __temp_ptr, __res_ptr, __n, __size_per_wg);
});
});

return __future(__final_event, sycl::buffer(__wg_sums, sycl::id<1>(__n_groups - 1), sycl::range<1>(1)));
return __future(__final_event, __result_and_scratch);
}
};

Expand Down Expand Up @@ -434,7 +416,7 @@ struct __parallel_transform_scan_dynamic_single_group_submitter<_Inclusive,
const ::std::uint16_t __elems_per_item = oneapi::dpl::__internal::__dpl_ceiling_div(__n, __wg_size);
const ::std::uint16_t __elems_per_wg = __elems_per_item * __wg_size;

auto __event = __policy.queue().submit([&](sycl::handler& __hdl) {
return __policy.queue().submit([&](sycl::handler& __hdl) {
oneapi::dpl::__ranges::__require_access(__hdl, __in_rng, __out_rng);

auto __lacc = __dpl_sycl::__local_accessor<_ValueType>(sycl::range<1>{__elems_per_wg}, __hdl);
Expand Down Expand Up @@ -466,7 +448,6 @@ struct __parallel_transform_scan_dynamic_single_group_submitter<_Inclusive,
}
});
});
return __future(__event);
}
};

Expand All @@ -489,7 +470,7 @@ struct __parallel_transform_scan_static_single_group_submitter<_Inclusive, _Elem

constexpr ::uint32_t __elems_per_wg = _ElemsPerItem * _WGSize;

auto __event = __policy.queue().submit([&](sycl::handler& __hdl) {
return __policy.queue().submit([&](sycl::handler& __hdl) {
oneapi::dpl::__ranges::__require_access(__hdl, __in_rng, __out_rng);

auto __lacc = __dpl_sycl::__local_accessor<_ValueType>(sycl::range<1>{__elems_per_wg}, __hdl);
Expand Down Expand Up @@ -559,7 +540,6 @@ struct __parallel_transform_scan_static_single_group_submitter<_Inclusive, _Elem
}
});
});
return __future(__event);
}
};

Expand All @@ -575,7 +555,7 @@ struct __parallel_copy_if_static_single_group_submitter<_Size, _ElemsPerItem, _W
template <typename _Policy, typename _InRng, typename _OutRng, typename _InitType, typename _BinaryOperation,
typename _UnaryOp>
auto
operator()(const _Policy& __policy, _InRng&& __in_rng, _OutRng&& __out_rng, ::std::size_t __n, _InitType __init,
operator()(_Policy&& __policy, _InRng&& __in_rng, _OutRng&& __out_rng, ::std::size_t __n, _InitType __init,
_BinaryOperation __bin_op, _UnaryOp __unary_op)
{
using _ValueType = ::std::uint16_t;
Expand All @@ -587,7 +567,7 @@ struct __parallel_copy_if_static_single_group_submitter<_Size, _ElemsPerItem, _W

constexpr ::std::uint32_t __elems_per_wg = _ElemsPerItem * _WGSize;

sycl::buffer<_Size> __res(sycl::range<1>(1));
__result_and_scratch_storage<std::decay_t<_Policy>, _Size> __result{__policy, 0};

auto __event = __policy.queue().submit([&](sycl::handler& __hdl) {
oneapi::dpl::__ranges::__require_access(__hdl, __in_rng, __out_rng);
Expand All @@ -596,10 +576,12 @@ struct __parallel_copy_if_static_single_group_submitter<_Size, _ElemsPerItem, _W
// predicate on each element of the input range. The second half stores the index of the output
// range to copy elements of the input range.
auto __lacc = __dpl_sycl::__local_accessor<_ValueType>(sycl::range<1>{__elems_per_wg * 2}, __hdl);
auto __res_acc = __res.template get_access<access_mode::write>(__hdl);
auto __res_acc = __result.__get_result_acc(__hdl);

__hdl.parallel_for<_ScanKernelName...>(
sycl::nd_range<1>(_WGSize, _WGSize), [=](sycl::nd_item<1> __self_item) {
auto __res_ptr =
__result_and_scratch_storage<_Policy, _Size>::__get_usm_or_buffer_accessor_ptr(__res_acc);
const auto& __group = __self_item.get_group();
const auto& __subgroup = __self_item.get_sub_group();
// This kernel is only launched for sizes less than 2^16
Expand Down Expand Up @@ -654,11 +636,11 @@ struct __parallel_copy_if_static_single_group_submitter<_Size, _ElemsPerItem, _W
if (__item_id == 0)
{
// Add predicate of last element to account for the scan's exclusivity
__res_acc[0] = __lacc[__elems_per_wg + __n - 1] + __lacc[__n - 1];
__res_ptr[0] = __lacc[__elems_per_wg + __n - 1] + __lacc[__n - 1];
}
});
});
return __future(__event, __res);
return __future(__event, __result);
}
};

Expand All @@ -677,6 +659,13 @@ __parallel_transform_scan_single_group(oneapi::dpl::__internal::__device_backend
// Specialization for devices that have a max work-group size of 1024
constexpr ::std::uint16_t __targeted_wg_size = 1024;

using _ValueType = typename _InitType::__value_type;

// Although we do not actually need result storage in this case, we need to construct
// a placeholder here to match the return type of the non-single-work-group implementation
using _TempStorage = __result_and_scratch_storage<std::decay_t<_ExecutionPolicy>, _ValueType>;
_TempStorage __dummy_result_and_scratch{__exec, 0};

if (__max_wg_size >= __targeted_wg_size)
{
auto __single_group_scan_f = [&](auto __size_constant) {
Expand All @@ -686,8 +675,9 @@ __parallel_transform_scan_single_group(oneapi::dpl::__internal::__device_backend
oneapi::dpl::__internal::__dpl_ceiling_div(__size, __wg_size);
const bool __is_full_group = __n == __wg_size;

sycl::event __event;
if (__is_full_group)
return __parallel_transform_scan_static_single_group_submitter<
__event = __parallel_transform_scan_static_single_group_submitter<
_Inclusive::value, __num_elems_per_item, __wg_size,
/* _IsFullGroup= */ true,
oneapi::dpl::__par_backend_hetero::__internal::__kernel_name_provider<__scan_single_wg_kernel<
Expand All @@ -697,7 +687,7 @@ __parallel_transform_scan_single_group(oneapi::dpl::__internal::__device_backend
::std::forward<_ExecutionPolicy>(__exec), std::forward<_InRng>(__in_rng),
std::forward<_OutRng>(__out_rng), __n, __init, __binary_op, __unary_op);
else
return __parallel_transform_scan_static_single_group_submitter<
__event = __parallel_transform_scan_static_single_group_submitter<
_Inclusive::value, __num_elems_per_item, __wg_size,
/* _IsFullGroup= */ false,
oneapi::dpl::__par_backend_hetero::__internal::__kernel_name_provider<__scan_single_wg_kernel<
Expand All @@ -706,6 +696,7 @@ __parallel_transform_scan_single_group(oneapi::dpl::__internal::__device_backend
/* _IsFullGroup= */ ::std::false_type, _Inclusive, _CustomName>>>()(
::std::forward<_ExecutionPolicy>(__exec), std::forward<_InRng>(__in_rng),
std::forward<_OutRng>(__out_rng), __n, __init, __binary_op, __unary_op);
return __future(__event, __dummy_result_and_scratch);
};
if (__n <= 16)
return __single_group_scan_f(std::integral_constant<::std::uint16_t, 16>{});
Expand Down Expand Up @@ -735,9 +726,11 @@ __parallel_transform_scan_single_group(oneapi::dpl::__internal::__device_backend
using _DynamicGroupScanKernel = oneapi::dpl::__par_backend_hetero::__internal::__kernel_name_provider<
__par_backend_hetero::__scan_single_wg_dynamic_kernel<_BinaryOperation, _CustomName>>;

return __parallel_transform_scan_dynamic_single_group_submitter<_Inclusive::value, _DynamicGroupScanKernel>()(
::std::forward<_ExecutionPolicy>(__exec), std::forward<_InRng>(__in_rng), std::forward<_OutRng>(__out_rng),
__n, __init, __binary_op, __unary_op, __max_wg_size);
auto __event =
__parallel_transform_scan_dynamic_single_group_submitter<_Inclusive::value, _DynamicGroupScanKernel>()(
std::forward<_ExecutionPolicy>(__exec), std::forward<_InRng>(__in_rng),
std::forward<_OutRng>(__out_rng), __n, __init, __binary_op, __unary_op, __max_wg_size);
return __future(__event, __dummy_result_and_scratch);
}
}

Expand Down Expand Up @@ -807,8 +800,7 @@ __parallel_transform_scan(oneapi::dpl::__internal::__device_backend_tag __backen
_NoAssign __no_assign_op;
_NoOpFunctor __get_data_op;

return __future(
__parallel_transform_scan_base(
return __parallel_transform_scan_base(
__backend_tag, ::std::forward<_ExecutionPolicy>(__exec), ::std::forward<_Range1>(__in_rng),
::std::forward<_Range2>(__out_rng), __binary_op, __init,
// local scan
Expand All @@ -820,8 +812,7 @@ __parallel_transform_scan(oneapi::dpl::__internal::__device_backend_tag __backen
_NoAssign, _Assigner, _NoOpFunctor, unseq_backend::__no_init_value<_Type>>{
__binary_op, _NoOpFunctor{}, __no_assign_op, __assign_op, __get_data_op},
// global scan
unseq_backend::__global_scan_functor<_Inclusive, _BinaryOperation, _InitType>{__binary_op, __init})
.event());
unseq_backend::__global_scan_functor<_Inclusive, _BinaryOperation, _InitType>{__binary_op, __init});
}

template <typename _SizeType>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -547,7 +547,8 @@ struct __result_and_scratch_storage
}

public:
__result_and_scratch_storage(_ExecutionPolicy& __exec, ::std::size_t __scratch_n)
template <typename _Policy>
__result_and_scratch_storage(_Policy&& __exec, ::std::size_t __scratch_n)
: __exec{__exec}, __scratch_n{__scratch_n}, __use_USM_host{__use_USM_host_allocations(__exec.queue())},
__supports_USM_device{__use_USM_allocations(__exec.queue())}
{
Expand Down
Loading

0 comments on commit c8a274a

Please sign in to comment.