Skip to content

Commit

Permalink
Merge pull request #5144 from ye-luo/update-ParallelExecutorSTD
Browse files Browse the repository at this point in the history
ParallelExecutor no longer needs std::ref
  • Loading branch information
prckent authored Aug 23, 2024
2 parents ea7ecc5 + e30e813 commit 8e075b4
Show file tree
Hide file tree
Showing 5 changed files with 15 additions and 36 deletions.
21 changes: 2 additions & 19 deletions src/Concurrency/ParallelExecutorSTD.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,34 +26,17 @@

namespace qmcplusplus
{
/** TaskWrapper code from
* https://stackoverflow.com/questions/48268322/wrap-stdthread-call-function
*
* Required to forward an arbitrary var args function
*/
template<typename F>
struct TaskWrapper
{
F f;

template<typename... T>
void operator()(T&&... args)
{
f(std::forward<T>(args)...);
}
};

/** implements parallel tasks executed by STD threads. One task one thread mapping.
*/
template<>
template<typename F, typename... Args>
void ParallelExecutor<Executor::STD_THREADS>::operator()(int num_tasks, F&& f, Args&&... args)
{
std::vector<std::thread> threads(num_tasks_);
std::vector<std::thread> threads(num_tasks);

for (int task_id = 0; task_id < num_tasks; ++task_id)
{
threads[task_id] = std::thread(TaskWrapper<F>{std::forward<F>(f)}, task_id, std::forward<Args>(args)...);
threads[task_id] = std::thread(f, task_id, std::ref(std::forward<Args>(args))...);
}

for (int task_id = 0; task_id < num_tasks; ++task_id)
Expand Down
8 changes: 4 additions & 4 deletions src/Concurrency/tests/test_ParallelExecutorOPENMP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ TEST_CASE("ParallelExecutor<OPENMP> function case", "[concurrency]")
const int num_threads = omp_get_max_threads();
ParallelExecutor<Executor::OPENMP> test_block;
int count(0);
test_block(num_threads, TestTaskOMP, std::ref(count));
test_block(num_threads, TestTaskOMP, count);
REQUIRE(count == num_threads);
}

Expand All @@ -44,7 +44,7 @@ TEST_CASE("ParallelExecutor<OPENMP> lambda case", "[concurrency]")
#pragma omp atomic update
c++;
},
std::ref(count));
count);
REQUIRE(count == num_threads);
}

Expand All @@ -55,10 +55,10 @@ TEST_CASE("ParallelExecutor<OPENMP> nested case", "[concurrency]")
int count(0);
auto nested_tasks = [num_threads](int task_id, int& my_count) {
ParallelExecutor<Executor::OPENMP> test_block2;
test_block2(num_threads, TestTaskOMP, std::ref(my_count));
test_block2(num_threads, TestTaskOMP, my_count);
};
#ifdef _OPENMP
REQUIRE_THROWS_WITH(test_block(num_threads, nested_tasks, std::ref(count)),
REQUIRE_THROWS_WITH(test_block(num_threads, nested_tasks, count),
Catch::Contains("ParallelExecutor should not be used for nested openmp threading"));
#endif
}
Expand Down
8 changes: 4 additions & 4 deletions src/Concurrency/tests/test_ParallelExecutorSTD.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ TEST_CASE("ParallelExecutor<STD> function case", "[concurrency]")
int num_threads = 8;
ParallelExecutor<Executor::STD_THREADS> test_block;
std::atomic<int> count(0);
test_block(num_threads, TestTask, std::ref(count));
test_block(num_threads, TestTask, count);
REQUIRE(count == 8);
}

Expand All @@ -36,7 +36,7 @@ TEST_CASE("ParallelExecutor<STD> lambda case", "[concurrency]")
ParallelExecutor<Executor::STD_THREADS> test_block;
std::atomic<int> count(0);
test_block(
num_threads, [](int id, std::atomic<int>& my_count) { ++my_count; }, std::ref(count));
num_threads, [](int id, std::atomic<int>& my_count) { ++my_count; }, count);
REQUIRE(count == 8);
}

Expand All @@ -49,9 +49,9 @@ TEST_CASE("ParallelExecutor<STD> nested case", "[concurrency]")
num_threads,
[num_threads](int task_id, std::atomic<int>& my_count) {
ParallelExecutor<Executor::STD_THREADS> test_block2;
test_block2(num_threads, TestTask, std::ref(my_count));
test_block2(num_threads, TestTask, my_count);
},
std::ref(count));
count);
REQUIRE(count == 64);
}

Expand Down
6 changes: 2 additions & 4 deletions src/QMCDrivers/DMC/DMCBatched.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -473,8 +473,7 @@ bool DMCBatched::run()
ScopedTimer local_timer(timers_.init_walkers_timer);
ParallelExecutor<> section_start_task;
auto step_contexts_refs = getContextForStepsRefs();
section_start_task(crowds_.size(), initialLogEvaluation, std::ref(crowds_), std::ref(step_contexts_refs),
serializing_crowd_walkers_);
section_start_task(crowds_.size(), initialLogEvaluation, crowds_, step_contexts_refs, serializing_crowd_walkers_);

FullPrecRealType energy, variance;
population_.measureGlobalEnergyVariance(*myComm, energy, variance);
Expand Down Expand Up @@ -516,8 +515,7 @@ bool DMCBatched::run()

dmc_state.step = step;
dmc_state.global_step = global_step;
crowd_task(crowds_.size(), runDMCStep, dmc_state, timers_, dmc_timers_, std::ref(step_contexts_),
std::ref(crowds_));
crowd_task(crowds_.size(), runDMCStep, dmc_state, timers_, dmc_timers_, step_contexts_, crowds_);

{
const int iter = block * steps_per_block_ + step;
Expand Down
8 changes: 3 additions & 5 deletions src/QMCDrivers/VMC/VMCBatched.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -340,8 +340,7 @@ bool VMCBatched::run()
ScopedTimer local_timer(timers_.init_walkers_timer);
ParallelExecutor<> section_start_task;
auto step_contexts_refs = getContextForStepsRefs();
section_start_task(crowds_.size(), initialLogEvaluation, std::ref(crowds_), std::ref(step_contexts_refs),
serializing_crowd_walkers_);
section_start_task(crowds_.size(), initialLogEvaluation, crowds_, step_contexts_refs, serializing_crowd_walkers_);
print_mem("VMCBatched after initialLogEvaluation", app_summary());
if (qmcdriver_input_.get_measure_imbalance())
measureImbalance("InitialLogEvaluation");
Expand Down Expand Up @@ -371,8 +370,7 @@ bool VMCBatched::run()
for (int step = 0; step < qmcdriver_input_.get_warmup_steps(); ++step)
{
ScopedTimer local_timer(timers_.run_steps_timer);
crowd_task(crowds_.size(), runWarmupStep, vmc_state, std::ref(timers_), std::ref(step_contexts_),
std::ref(crowds_));
crowd_task(crowds_.size(), runWarmupStep, vmc_state, timers_, step_contexts_, crowds_);
}

app_log() << "VMC Warmup completed in " << std::setprecision(4) << warmup_timer.elapsed() << " secs" << std::endl;
Expand Down Expand Up @@ -405,7 +403,7 @@ bool VMCBatched::run()
ScopedTimer local_timer(timers_.run_steps_timer);
vmc_state.step = step;
vmc_state.global_step = global_step;
crowd_task(crowds_.size(), runVMCStep, vmc_state, timers_, std::ref(step_contexts_), std::ref(crowds_));
crowd_task(crowds_.size(), runVMCStep, vmc_state, timers_, step_contexts_, crowds_);

if (collect_samples_)
{
Expand Down

0 comments on commit 8e075b4

Please sign in to comment.