Skip to content

Commit

Permalink
address feedback
Browse files Browse the repository at this point in the history
Signed-off-by: Dan Hoeflinger <[email protected]>
  • Loading branch information
danhoeflinger committed Jan 21, 2025
1 parent 0a2847f commit ee65630
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 23 deletions.
30 changes: 15 additions & 15 deletions include/oneapi/dpl/pstl/omp/util.h
Original file line number Diff line number Diff line change
Expand Up @@ -157,39 +157,39 @@ __process_chunk(const __chunk_metrics& __metrics, _Iterator __base, _Index __chu

// abstract class to allow inclusion in __enumerable_thread_local_storage as member without requiring explicit template
// instantiation of param types
template <typename _StorageType>
template <typename _ValueType>
class __construct_by_args_base
{
public:
virtual ~__construct_by_args_base() = default;
virtual std::unique_ptr<_StorageType> construct() = 0;
virtual std::unique_ptr<_ValueType> construct() = 0;
};

// Helper class to allow construction of _StorageType from a stored argument pack
template <typename _StorageType, typename... _P>
class __construct_by_args : public __construct_by_args_base<_StorageType>
// Helper class to allow construction of _ValueType from a stored argument pack
template <typename _ValueType, typename... _P>
class __construct_by_args : public __construct_by_args_base<_ValueType>
{
public:
std::unique_ptr<_StorageType>
std::unique_ptr<_ValueType>
construct() override
{
return std::apply([](_P... __arg_pack) { return std::make_unique<_StorageType>(__arg_pack...); }, __pack);
return std::apply([](_P... __arg_pack) { return std::make_unique<_ValueType>(__arg_pack...); }, __pack);
}
__construct_by_args(_P&&... __args) : __pack(std::forward<_P>(__args)...) {}

private:
const std::tuple<_P...> __pack;
};

template <typename _StorageType>
template <typename _ValueType>
struct __enumerable_thread_local_storage
{
template <typename... Args>
__enumerable_thread_local_storage(Args&&... __args) : __num_elements(0)
{
__storage_factory = std::make_unique<__construct_by_args<_StorageType, Args...>>(std::forward<Args>(__args)...);
_PSTL_PRAGMA(omp parallel)
_PSTL_PRAGMA(omp single) { __thread_specific_storage.resize(omp_get_num_threads()); }
__storage_factory = std::make_unique<__construct_by_args<_ValueType, Args...>>(std::forward<Args>(__args)...);
std::size_t __num_threads = omp_in_parallel() ? omp_get_num_threads() : omp_get_max_threads();
__thread_specific_storage.resize(__num_threads);
}

// Note: Size should not be used concurrently with parallel loops which may instantiate storage objects, as it may
Expand All @@ -202,7 +202,7 @@ struct __enumerable_thread_local_storage
return __num_elements.load();
}

_StorageType&
_ValueType&
get_with_id(std::size_t __i)
{
assert(__i < size());
Expand All @@ -226,7 +226,7 @@ struct __enumerable_thread_local_storage
return *__thread_specific_storage[__j - 1];
}

_StorageType&
_ValueType&
get_for_current_thread()
{
std::size_t __i = omp_get_thread_num();
Expand All @@ -239,9 +239,9 @@ struct __enumerable_thread_local_storage
return *__thread_specific_storage[__i];
}

std::vector<std::unique_ptr<_StorageType>> __thread_specific_storage;
std::vector<std::unique_ptr<_ValueType>> __thread_specific_storage;
std::atomic_size_t __num_elements;
std::unique_ptr<__construct_by_args_base<_StorageType>> __storage_factory;
std::unique_ptr<__construct_by_args_base<_ValueType>> __storage_factory;
};

} // namespace __omp_backend
Expand Down
8 changes: 4 additions & 4 deletions include/oneapi/dpl/pstl/parallel_backend_serial.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ __cancel_execution(oneapi::dpl::__internal::__serial_backend_tag)
{
}

template <typename _StorageType>
template <typename _ValueType>
struct __enumerable_thread_local_storage
{
template <typename... Args>
Expand All @@ -56,19 +56,19 @@ struct __enumerable_thread_local_storage
return std::size_t{1};
}

_StorageType&
_ValueType&
get_for_current_thread()
{
return __storage;
}

_StorageType&
_ValueType&
get_with_id(std::size_t /*__i*/)
{
return get_for_current_thread();
}

_StorageType __storage;
_ValueType __storage;
};

template <class _ExecutionPolicy, class _Index, class _Fp>
Expand Down
8 changes: 4 additions & 4 deletions include/oneapi/dpl/pstl/parallel_backend_tbb.h
Original file line number Diff line number Diff line change
Expand Up @@ -1307,7 +1307,7 @@ __parallel_for_each(oneapi::dpl::__internal::__tbb_backend_tag, _ExecutionPolicy
tbb::this_task_arena::isolate([&]() { tbb::parallel_for_each(__begin, __end, __f); });
}

template <typename _StorageType>
template <typename _ValueType>
struct __enumerable_thread_local_storage
{
template <typename... Args>
Expand All @@ -1321,19 +1321,19 @@ struct __enumerable_thread_local_storage
return __thread_specific_storage.size();
}

_StorageType&
_ValueType&
get_for_current_thread()
{
return __thread_specific_storage.local();
}

_StorageType&
_ValueType&
get_with_id(std::size_t __i)
{
return __thread_specific_storage.begin()[__i];
}

tbb::enumerable_thread_specific<_StorageType> __thread_specific_storage;
tbb::enumerable_thread_specific<_ValueType> __thread_specific_storage;
};

} // namespace __tbb_backend
Expand Down

0 comments on commit ee65630

Please sign in to comment.