merged with working version
emjotde committed Feb 12, 2017
Commit 10bafa3, 2 parents: 0d64384 + a0493f4
Showing 1 changed file with 22 additions and 10 deletions.

src/training/graph_group.h
@@ -45,16 +45,15 @@ class AsyncGraphGroup : public GraphGroup {
   std::vector<Ptr<ExpressionGraph>> graphs_;
 
   std::mutex sync_;
-  //std::unique_ptr<std::mutex[]> shardSync_;
+  std::vector<std::mutex> shardSync_;
 
   std::vector<Tensor> params_;
   std::vector<Ptr<TensorAllocator> > paramsAlloc_;
 
   std::vector<Tensor> grads_;
-  std::vector<Ptr<TensorAllocator> > gradsAlloc_;
+  std::vector<Ptr<TensorAllocator>> gradsAlloc_;
 
-  std::vector< Ptr<OptimizerBase> > shardOpt_;
+  std::vector<Ptr<OptimizerBase>> shardOpt_;
 
   int shardSize_;

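A side note on the shardSync_ change above: std::mutex is neither copyable nor movable, so a std::vector<std::mutex> can only be given its size at construction, which is exactly what the constructor further down does with shardSync_{devices_.size()}. A minimal standalone sketch of that constraint (ShardLocks is an invented name, not Marian code):

    #include <cstddef>
    #include <mutex>
    #include <vector>

    struct ShardLocks {
      std::vector<std::mutex> locks_;
      // The count constructor default-constructs n mutexes in place;
      // this is the only practical way to populate the vector, since
      //   locks_.push_back(std::mutex());  // does not compile (no copy/move)
      //   locks_.resize(n);                // does not compile either
      explicit ShardLocks(std::size_t n) : locks_(n) {}
    };
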
@@ -66,11 +65,20 @@ class AsyncGraphGroup : public GraphGroup {
 
     // @TODO read guard on parameters
     int pos = 0;
+
+    std::vector<std::thread> threads;
     for (int idx = 0; idx < devices_.size(); idx++) {
-      std::lock_guard<std::mutex> guard( shardSync_[idx] );
-      oldParams->subtensor(pos , params_[idx]->size())->copyFrom(params_[idx]);
+      threads.emplace_back( std::thread( [=](int idx, int pos) {
+        //individual mutex per-shard
+        std::lock_guard<std::mutex> guard( shardSync_[idx] );
+        oldParams->subtensor(pos , params_[idx]->size())->copyFrom(params_[idx]);
+      }, idx, pos) );
 
       pos += shardSize_;
     }
+    for (auto &&t : threads) {
+      t.join();
+    }
   }

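The hunk above turns the sequential shard-by-shard parameter copy into one worker thread per shard, each taking only its own shard's mutex and all joined before the function returns. A standalone sketch of the same pattern, with plain float buffers standing in for Marian's Tensor (all names here are invented for illustration):

    #include <cstddef>
    #include <cstring>
    #include <mutex>
    #include <thread>
    #include <vector>

    // Copy every shard into one contiguous destination buffer. Each
    // worker locks only its own shard, so updates to other shards are
    // not blocked; the join loop guarantees the copy is complete when
    // the function returns.
    void fetchShards(float* dst,
                     const std::vector<std::vector<float>>& shards,
                     std::vector<std::mutex>& locks) {
      std::vector<std::thread> threads;
      std::size_t pos = 0;
      for (std::size_t idx = 0; idx < shards.size(); ++idx) {
        threads.emplace_back([&, idx, pos] {
          std::lock_guard<std::mutex> guard(locks[idx]);  // per-shard mutex
          std::memcpy(dst + pos, shards[idx].data(),
                      shards[idx].size() * sizeof(float));
        });
        pos += shards[idx].size();
      }
      for (auto& t : threads)
        t.join();
    }
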
@@ -79,18 +87,22 @@ void pushGradients(Tensor newGrads) {
     }
     else {
       // add instead of copy?
+      std::vector<std::thread> threads;
       int pos = 0;
-      for(int idx = 0; idx < devices_.size(); idx++) {
-        auto task = [=](int idx, int pos) {
+      for (int idx = 0; idx < devices_.size(); idx++) {
+        threads.emplace_back( std::thread([=](int idx, int pos) {
           //individual mutex per-shard
           std::lock_guard<std::mutex> guard( shardSync_[idx] );
           grads_[idx]->copyFrom( newGrads->subtensor(pos , grads_[idx]->size() ) );
           shardOpt_[idx]->update(params_[idx], grads_[idx]);
 
           cudaDeviceSynchronize();
-        };
-        std::thread(task, idx, pos).detach();
+        } , idx, pos) );
 
         pos += shardSize_;
       }
+      for(auto&& t : threads)
+        t.join();
     }
   }

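pushGradients gets the same treatment, with one important extra fix: the old code detached its worker threads, so the function could return while workers were still reading newGrads and applying optimizer updates, and the caller had no point at which the push was known to be complete. Joining the threads closes that window. A stripped-down illustration of the difference (invented names, not the commit's code):

    #include <thread>
    #include <vector>

    // Stand-in for the real per-shard work (copy + optimizer update).
    void applyUpdate(const std::vector<float>& grads) { (void)grads; }

    // Old shape: the caller cannot tell when the worker finishes, and
    // grads must somehow outlive a thread the function no longer tracks.
    void pushDetached(const std::vector<float>& grads) {
      std::thread([&grads] { applyUpdate(grads); }).detach();
    }

    // New shape: on return, every worker is done with grads.
    void pushJoined(const std::vector<float>& grads) {
      std::vector<std::thread> threads;
      threads.emplace_back([&grads] { applyUpdate(grads); });
      for (auto& t : threads)
        t.join();
    }
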
@@ -193,7 +205,7 @@ class AsyncGraphGroup : public GraphGroup {
       : GraphGroup(options),
         builder_{New<Builder>(options_)},
         devices_{options_->get<std::vector<size_t>>("device")},
-        pool_{devices_.size(), devices_.size() },
+        pool_{devices_.size(), devices_.size()},
         shardSync_{devices_.size()} {
 
     for(auto device : devices_) {
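The constructor hunk ties it together: pool_ and shardSync_ are both sized from devices_, one pool slot and one mutex per device, and the pos += shardSize_ walk in the methods above assumes the parameter vector is split into contiguous shards of at most shardSize_ elements. A small sketch of that partitioning arithmetic, assuming ceil-division shard sizing (shardRanges is an invented helper, not Marian code):

    #include <algorithm>
    #include <cstddef>
    #include <utility>
    #include <vector>

    // Split `total` parameters into one contiguous (offset, length)
    // range per device; the last shard may be shorter when total does
    // not divide evenly. Requires devices > 0.
    std::vector<std::pair<std::size_t, std::size_t>>
    shardRanges(std::size_t total, std::size_t devices) {
      std::size_t shardSize = (total + devices - 1) / devices;  // ceil
      std::vector<std::pair<std::size_t, std::size_t>> ranges;
      for (std::size_t pos = 0; pos < total; pos += shardSize)
        ranges.emplace_back(pos, std::min(shardSize, total - pos));
      return ranges;
    }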
