diff --git a/include/oneapi/dpl/pstl/omp/parallel_scan.h b/include/oneapi/dpl/pstl/omp/parallel_scan.h
index c3bc022cb2e..d4abe26aff0 100644
--- a/include/oneapi/dpl/pstl/omp/parallel_scan.h
+++ b/include/oneapi/dpl/pstl/omp/parallel_scan.h
@@ -79,6 +79,37 @@ __downsweep(_Index __i, _Index __m, _Index __tilesize, _Tp* __r, _Index __lastsi
     }
 }
 
+template <typename _Index, typename _Tp, typename _Cp, typename _Sp>
+std::pair<_Index, _Index>
+__downsweep(_Index __i, _Index __m, _Index __tilesize, _Tp* __r, _Index __lastsize, _Tp __initial, _Cp __combine,
+            _Sp __scan, _Index __n_out)
+{
+    std::pair<_Index, _Index> __res{};
+    if (__m == 1)
+    {
+        if(__initial < __n_out)
+            __scan(__i * __tilesize, __lastsize, __initial, __n_out - __initial);
+    }
+    else
+    {
+        const _Index __k = __split(__m);
+        std::pair<_Index, _Index> __res_1{}, __res_2{};
+        oneapi::dpl::__omp_backend::__parallel_invoke_body(
+            [=, &__res_1] {
+                __res_1 = oneapi::dpl::__omp_backend::__downsweep(__i, __k, __tilesize, __r, __tilesize, __initial, __combine,
+                                                                  __scan, __n_out);
+            },
+            // Assumes that __combine never throws.
+            // TODO: Consider adding a requirement for user functors to be constant.
+            [=, &__combine, &__res_2] {
+                __res_2 = oneapi::dpl::__omp_backend::__downsweep(__i + __k, __m - __k, __tilesize, __r + __k, __lastsize,
+                                                                  __combine(__initial, __r[__k - 1]), __combine, __scan, __n_out);
+            });
+        __res = std::make_pair(__res_1.first + __res_2.first, __res_1.second + __res_2.second);
+    }
+    return __res;
+}
+
 template <typename _ExecutionPolicy, typename _Index, typename _Tp, typename _Rp, typename _Cp, typename _Sp,
           typename _Ap>
 void
@@ -107,6 +138,35 @@ __parallel_strict_scan_body(_ExecutionPolicy&& __exec, _Index __n, _Tp __initial
                                             __initial, __combine, __scan);
 }
 
+template <typename _ExecutionPolicy, typename _Index, typename _Tp, typename _Rp, typename _Cp, typename _Sp,
+          typename _Ap>
+void
+__parallel_strict_scan_body(_ExecutionPolicy&& __exec, _Index __n, _Tp __initial, _Rp __reduce, _Cp __combine,
+                            _Sp __scan, _Ap __apex, _Index __n_out)
+{
+    _Index __p = omp_get_num_threads();
+    const _Index __slack = 4;
+    _Index __tilesize = (__n - 1) / (__slack * __p) + 1;
+    _Index __m = (__n - 1) / __tilesize;
+    __buffer<_ExecutionPolicy, _Tp> __buf(::std::forward<_ExecutionPolicy>(__exec), __m + 1);
+    _Tp* __r = __buf.get();
+
+    oneapi::dpl::__omp_backend::__upsweep(_Index(0), _Index(__m + 1), __tilesize, __r, __n - __m * __tilesize, __reduce,
+                                          __combine);
+
+    std::size_t __k = __m + 1;
+    _Tp __t = __r[__k - 1];
+    while ((__k &= __k - 1))
+    {
+        __t = __combine(__r[__k - 1], __t);
+    }
+
+    auto __res = oneapi::dpl::__omp_backend::__downsweep(_Index(0), _Index(__m + 1), __tilesize, __r, __n - __m * __tilesize,
+                                                         __initial, __combine, __scan, __n_out);
+    __apex(__res.first, __res.second);
+}
+
 template <typename _ExecutionPolicy, typename _Index, typename _Tp, typename _Rp, typename _Cp, typename _Sp,
           typename _Ap>
 void
 __parallel_strict_scan(oneapi::dpl::__internal::__omp_backend_tag, _ExecutionPolicy&& __exec, _Index __n, _Tp __initial,
@@ -143,6 +203,42 @@ __parallel_strict_scan(oneapi::dpl::__internal::__omp_backend_tag, _ExecutionPol
     }
 }
 
+template <typename _ExecutionPolicy, typename _Index, typename _Tp, typename _Rp, typename _Cp, typename _Sp,
+          typename _Ap>
+void
+__parallel_strict_scan(oneapi::dpl::__internal::__omp_backend_tag, _ExecutionPolicy&& __exec, _Index __n, _Tp __initial,
+                       _Rp __reduce, _Cp __combine, _Sp __scan, _Ap __apex, _Index __n_out)
+{
+    if(__n_out == 0)
+        return;
+    else if(__n_out < 0)
+        __n_out = __n;
+
+    if (__n <= __default_chunk_size)
+    {
+        if (__n)
+        {
+            auto __res = __scan(_Index(0), __n, __initial, __n_out);
+            __apex(__res.first, __res.second);
+        }
+        return;
+    }
+
+    if (omp_in_parallel())
+    {
+        oneapi::dpl::__omp_backend::__parallel_strict_scan_body(::std::forward<_ExecutionPolicy>(__exec), __n,
+                                                                __initial, __reduce, __combine, __scan, __apex, __n_out);
+    }
+    else
+    {
+        _PSTL_PRAGMA(omp parallel)
+        _PSTL_PRAGMA(omp single nowait)
+        {
+            oneapi::dpl::__omp_backend::__parallel_strict_scan_body(::std::forward<_ExecutionPolicy>(__exec), __n,
+                                                                    __initial, __reduce, __combine, __scan, __apex, __n_out);
+        }
+    }
+}
+
 } // namespace __omp_backend
 } // namespace dpl
 } // namespace oneapi
diff --git a/include/oneapi/dpl/pstl/parallel_backend_serial.h b/include/oneapi/dpl/pstl/parallel_backend_serial.h
index 6acd4b617f9..50ab0a2cd01 100644
--- a/include/oneapi/dpl/pstl/parallel_backend_serial.h
+++ b/include/oneapi/dpl/pstl/parallel_backend_serial.h
@@ -86,6 +86,23 @@ __parallel_strict_scan(oneapi::dpl::__internal::__serial_backend_tag, _Execution
         __scan(_Index(0), __n, __initial);
 }
 
+template <typename _ExecutionPolicy, typename _Index, typename _Tp, typename _Rp, typename _Cp, typename _Sp,
+          typename _Ap>
+void
+__parallel_strict_scan(oneapi::dpl::__internal::__serial_backend_tag, _ExecutionPolicy&&, _Index __n, _Tp __initial,
+                       _Rp __reduce, _Cp __combine, _Sp __scan, _Ap __apex, _Index __n_out)
+{
+    if(__n_out == 0)
+        return;
+    else if(__n_out < 0)
+        __n_out = __n;
+
+    if (__n)
+    {
+        auto __res = __scan(_Index(0), __n, __initial, __n_out);
+        __apex(__res.first, __res.second);
+    }
+}
+
 template <typename _ExecutionPolicy, typename _Index, typename _UnaryOp, typename _Tp, typename _BinaryOp,
           typename _Scan>
 _Tp
 __parallel_transform_scan(oneapi::dpl::__internal::__serial_backend_tag, _ExecutionPolicy&&, _Index __n, _UnaryOp,
diff --git a/include/oneapi/dpl/pstl/parallel_backend_tbb.h b/include/oneapi/dpl/pstl/parallel_backend_tbb.h
index fd3f7f41a5d..4a441711fa7 100644
--- a/include/oneapi/dpl/pstl/parallel_backend_tbb.h
+++ b/include/oneapi/dpl/pstl/parallel_backend_tbb.h
@@ -316,6 +316,27 @@ __upsweep(_Index __i, _Index __m, _Index __tilesize, _Tp* __r, _Index __lastsize
     }
 }
 
+template <typename _Index, typename _Tp, typename _Cp, typename _Sp>
+void
+__downsweep(_Index __i, _Index __m, _Index __tilesize, _Tp* __r, _Index __lastsize, _Tp __initial, _Cp __combine,
+            _Sp __scan)
+{
+    if (__m == 1)
+        __scan(__i * __tilesize, __lastsize, __initial);
+    else
+    {
+        const _Index __k = __split(__m);
+        tbb::parallel_invoke(
+            [=] { __tbb_backend::__downsweep(__i, __k, __tilesize, __r, __tilesize, __initial, __combine, __scan); },
+            // Assumes that __combine never throws.
+            //TODO: Consider adding a requirement for user functors to be constant.
+            [=, &__combine] {
+                __tbb_backend::__downsweep(__i + __k, __m - __k, __tilesize, __r + __k, __lastsize,
+                                           __combine(__initial, __r[__k - 1]), __combine, __scan);
+            });
+    }
+}
+
 template <typename _Index, typename _Tp, typename _Cp, typename _Sp>
 std::pair<_Index, _Index>
 __downsweep(_Index __i, _Index __m, _Index __tilesize, _Tp* __r, _Index __lastsize, _Tp __initial, _Cp __combine,
@@ -353,6 +374,7 @@ __downsweep(_Index __i, _Index __m, _Index __tilesize, _Tp* __r, _Index __lastsi
 // combine(s1,s2) -> s -- return merged sum
 // apex(s) -- do any processing necessary between reduce and scan.
 // scan(i,len,initial) -- perform scan over i:len starting with initial.
+// [n_out -- limit for output range]
 // The initial range 0:n is partitioned into consecutive subranges.
 // reduce and scan are each called exactly once per subrange.
 // Thus callers can rely upon side effects in reduce.
@@ -363,9 +385,51 @@ __downsweep(_Index __i, _Index __m, _Index __tilesize, _Tp* __r, _Index __lastsi
 template <typename _ExecutionPolicy, typename _Index, typename _Tp, typename _Rp, typename _Cp, typename _Sp,
           typename _Ap>
 void
 __parallel_strict_scan(oneapi::dpl::__internal::__tbb_backend_tag, _ExecutionPolicy&& __exec, _Index __n, _Tp __initial,
-                       _Rp __reduce, _Cp __combine, _Sp __scan, _Ap __apex, int __n_out = -1)
+                       _Rp __reduce, _Cp __combine, _Sp __scan, _Ap __apex)
+{
+    tbb::this_task_arena::isolate([=, &__combine]() {
+        if (__n > 1)
+        {
+            _Index __p = tbb::this_task_arena::max_concurrency();
+            const _Index __slack = 4;
+            _Index __tilesize = (__n - 1) / (__slack * __p) + 1;
+            _Index __m = (__n - 1) / __tilesize;
+            __tbb_backend::__buffer<_ExecutionPolicy, _Tp> __buf(__exec, __m + 1);
+            _Tp* __r = __buf.get();
+            __tbb_backend::__upsweep(_Index(0), _Index(__m + 1), __tilesize, __r, __n - __m * __tilesize, __reduce,
+                                     __combine);
+
+            // When __apex is a no-op and __combine has no side effects, a good optimizer
+            // should be able to eliminate all code between here and __apex.
+            // Alternatively, provide a default value for __apex that can be
+            // recognized by metaprogramming that conditionally executes the following.
+            size_t __k = __m + 1;
+            _Tp __t = __r[__k - 1];
+            while ((__k &= __k - 1))
+                __t = __combine(__r[__k - 1], __t);
+            __apex(__combine(__initial, __t));
+            __tbb_backend::__downsweep(_Index(0), _Index(__m + 1), __tilesize, __r, __n - __m * __tilesize, __initial,
+                                       __combine, __scan);
+            return;
+        }
+        // Fewer than 2 elements in sequence, or out of memory. Handle has single block.
+        _Tp __sum = __initial;
+        if (__n)
+            __sum = __combine(__sum, __reduce(_Index(0), __n));
+        __apex(__sum);
+        if (__n)
+            __scan(_Index(0), __n, __initial);
+    });
+}
+
+template <typename _ExecutionPolicy, typename _Index, typename _Tp, typename _Rp, typename _Cp, typename _Sp,
+          typename _Ap>
+void
+__parallel_strict_scan(oneapi::dpl::__internal::__tbb_backend_tag, _ExecutionPolicy&& __exec, _Index __n, _Tp __initial,
+                       _Rp __reduce, _Cp __combine, _Sp __scan, _Ap __apex, _Index __n_out)
 {
-    if(__n_out < 0)
+    if(__n_out == 0)
+        return;
+    else if(__n_out < 0)
         __n_out = __n;
     tbb::this_task_arena::isolate([=, &__combine]() {
         if (__n > 1)
@@ -394,11 +458,7 @@ __parallel_strict_scan(oneapi::dpl::__internal::__tbb_backend_tag, _ExecutionPol
             return;
         }
         // Fewer than 2 elements in sequence, or out of memory. Handle has single block.
-        _Tp __sum = __initial;
-        if (__n && __n_out > 0)
-            __sum = __combine(__sum, __reduce(_Index(0), __n));
-        //__apex(__sum);
-        if (__n && __n_out > 0)
+        if (__n)
         {
             auto __res = __scan(_Index(0), __n, __initial, __n_out);
             __apex(__res.first, __res.second);
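
For reference, the shape of the contract these overloads add: the __scan brick now receives an output limit and returns a pair whose two components __apex consumes, __n_out == 0 means "produce no output", and __n_out < 0 means "no limit" (clamped to __n). The standalone sketch below mirrors only the control flow of the serial overload added in parallel_backend_serial.h; the copy_if-style brick bodies and the (consumed, produced) reading of the returned pair are illustrative assumptions, not code taken from this patch.

// Hypothetical standalone illustration of the __n_out-limited scan contract.
// The predicate-based "scan" brick and the (consumed, produced) meaning of the
// returned pair are assumptions made for this sketch only.
#include <cstddef>
#include <iostream>
#include <utility>
#include <vector>

int
main()
{
    const std::vector<int> in{3, -1, 4, -1, 5, -9, 2, 6};
    std::vector<int> out(in.size());

    // Mirrors the __scan brick shape used above: process [i, i + len), start writing
    // at output offset `initial`, stop after `limit` results, and return a pair
    // (inputs consumed, outputs produced) for the apex callback.
    auto scan = [&](std::ptrdiff_t i, std::ptrdiff_t len, std::ptrdiff_t initial, std::ptrdiff_t limit) {
        std::ptrdiff_t consumed = 0, produced = 0;
        for (; consumed < len && produced < limit; ++consumed)
            if (in[i + consumed] > 0) // hypothetical copy_if-style predicate
                out[initial + produced++] = in[i + consumed];
        return std::make_pair(consumed, produced);
    };

    // Mirrors __apex(__res.first, __res.second) from the new overloads.
    auto apex = [](std::ptrdiff_t consumed, std::ptrdiff_t produced) {
        std::cout << "consumed " << consumed << ", produced " << produced << '\n';
    };

    // Same control flow as the serial overload: __n_out == 0 -> nothing to do,
    // __n_out < 0 -> no limit (clamped to n), otherwise run the limited scan once.
    std::ptrdiff_t n = static_cast<std::ptrdiff_t>(in.size());
    std::ptrdiff_t n_out = 3;
    if (n_out != 0)
    {
        if (n_out < 0)
            n_out = n;
        if (n)
        {
            auto res = scan(0, n, 0, n_out);
            apex(res.first, res.second);
        }
    }
}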