diff --git a/src/training/graph_group.h b/src/training/graph_group.h index 9e118b1d2..f6b37e7d5 100644 --- a/src/training/graph_group.h +++ b/src/training/graph_group.h @@ -45,16 +45,15 @@ class AsyncGraphGroup : public GraphGroup { std::vector> graphs_; std::mutex sync_; - //std::unique_ptr shardSync_; std::vector shardSync_; std::vector params_; std::vector > paramsAlloc_; std::vector grads_; - std::vector > gradsAlloc_; + std::vector> gradsAlloc_; - std::vector< Ptr > shardOpt_; + std::vector> shardOpt_; int shardSize_; @@ -66,11 +65,20 @@ class AsyncGraphGroup : public GraphGroup { // @TODO read guard on parameters int pos = 0; + + std::vector threads; for (int idx = 0; idx < devices_.size(); idx++) { - std::lock_guard guard( shardSync_[idx] ); - oldParams->subtensor(pos , params_[idx]->size())->copyFrom(params_[idx]); + threads.emplace_back( std::thread( [=](int idx, int pos) { + //individual mutex per-shard + std::lock_guard guard( shardSync_[idx] ); + oldParams->subtensor(pos , params_[idx]->size())->copyFrom(params_[idx]); + }, idx, pos) ); + pos += shardSize_; } + for (auto &&t : threads) { + t.join(); + } } void pushGradients(Tensor newGrads) { @@ -79,18 +87,22 @@ class AsyncGraphGroup : public GraphGroup { } else { // add instead of copy? + std::vector threads; int pos = 0; - for(int idx = 0; idx < devices_.size(); idx++) { - auto task = [=](int idx, int pos) { + for (int idx = 0; idx < devices_.size(); idx++) { + threads.emplace_back( std::thread([=](int idx, int pos) { //individual mutex per-shard std::lock_guard guard( shardSync_[idx] ); grads_[idx]->copyFrom( newGrads->subtensor(pos , grads_[idx]->size() ) ); shardOpt_[idx]->update(params_[idx], grads_[idx]); + cudaDeviceSynchronize(); - }; - std::thread(task, idx, pos).detach(); + } , idx, pos) ); + pos += shardSize_; } + for(auto&& t : threads) + t.join(); } } @@ -193,7 +205,7 @@ class AsyncGraphGroup : public GraphGroup { : GraphGroup(options), builder_{New(options_)}, devices_{options_->get>("device")}, - pool_{devices_.size(), devices_.size() }, + pool_{devices_.size(), devices_.size()}, shardSync_{devices_.size()} { for(auto device : devices_) {