cherry-pick qtc optimization updates by jackz (#723)
* update to poplar 2.3.1 & remove unneeded code

* update log messages & add an option to enable logging

* Change the wait to use a condition variable, and change ContextQueues to use a mutex

Format the code to meet the style requirements

Correct == to = when setting the buffer_ element to nullptr

Co-authored-by: yanwei <[email protected]>
Co-authored-by: gcuser <[email protected]>
3 people authored Dec 7, 2021
1 parent 97f5a65 commit 6e06f70
Showing 2 changed files with 104 additions and 99 deletions.
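
The change replaces ContextQueues' double-buffered std::queue design with a fixed-capacity ring buffer guarded by a mutex, and replaces the millisecond sleep-polling in LockFreeQueue::get_ctx_by_tensor with a timed condition-variable wait. The standalone C++ sketch below illustrates only that wait/notify pattern; the member names mirror the diff, but the BatchWaiter type, wait_until_past() and main() driver are illustrative and are not part of the commit.

#include <atomic>
#include <chrono>
#include <condition_variable>
#include <cstdint>
#include <mutex>
#include <thread>

// Illustrative only: a producer publishes an item by advancing tail_ and
// notifying; a consumer waits with a 5 ms timeout instead of sleep-polling.
struct BatchWaiter {
  std::atomic<std::uint32_t> tail_{0};
  std::mutex batch_wait_mutex_;
  std::condition_variable batch_wait_cv_;

  void put() { // producer side, mirroring the notify added to put()
    tail_.fetch_add(1);
    std::unique_lock<std::mutex> lock(batch_wait_mutex_);
    batch_wait_cv_.notify_one();
  }

  void wait_until_past(std::uint32_t idx) { // consumer side, mirroring get_ctx_by_tensor
    while (idx == tail_.load()) {
      std::unique_lock<std::mutex> lock(batch_wait_mutex_);
      // The timeout bounds the stall if a notification slips in between the
      // tail_ check above and this wait.
      batch_wait_cv_.wait_for(lock, std::chrono::milliseconds(5));
    }
  }
};

int main() {
  BatchWaiter w;
  std::thread consumer([&] { w.wait_until_past(0); });
  w.put();
  consumer.join();
  return 0;
}

Notifying while holding the mutex, as the diff does, is correct; some codebases notify after releasing the lock to avoid waking a thread that immediately blocks again, and either choice works with the timed wait.
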
157 changes: 85 additions & 72 deletions ODLA/platforms/odla_popart/odla_pipeline.cc
@@ -100,87 +100,101 @@ void QManager::deleteQ() {
}
}

+void ContextQueues::init(std::size_t capacity) {
+  buffer_ = new odla_context[capacity];
+  if (nullptr == buffer_)
+    throw std::invalid_argument(
+        "ContextQueues::init failed to create buffer for queue with capacity "
+        ": " +
+        std::to_string(capacity));
+  for (int i = 0; i < capacity; i++) buffer_[i] = nullptr;
+  capacity_ = capacity;
+}
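// (Illustrative note, not part of the diff:) after init(8) the ring holds 8
// null slots and, per the constructor in odla_pipeline.h, head_, tail_ and
// wait_ all start at 0, which is the "empty" state.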

void ContextQueues::put(odla_context ctx) {
  popart::logging::info("ContextQueues::put -> ctx: {}.", ctx);
  {
-    std::lock_guard<std::mutex> guard(write_mutex);
-    write_queue->push(ctx);
-    write_wait_queue->push(
-        ctx); // put the ctx to input & wait_output queue in same order.
-  }
+    std::lock_guard<std::mutex> guard(queue_mutex_);
+    auto new_tail = (tail_ + 1) % capacity_;
+    if (new_tail == wait_) // last item as the boundary
+      throw std::out_of_range("ContextQueues::put the queue is full");
+    buffer_[tail_] = ctx;
+    tail_ = new_tail;
+  } // Make sure the queue mutex released here.
+  // Notify the batch wait we got a batch data
+  std::unique_lock<std::mutex> lock(batch_wait_mutex_);
+  batch_wait_cv_.notify_one();
}
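// (Worked example, not part of the diff:) with capacity_ == 4, wait_ == 0 and
// tail_ == 3, new_tail == (3 + 1) % 4 == 0 == wait_, so put() throws even
// though buffer_[3] is unused: one slot is always sacrificed so that
// tail_ == wait_ can unambiguously mean "empty".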

odla_context ContextQueues::get_input_context() {
-  if (nullptr != input_ctx) {
-    return input_ctx;
-  }
-  if (!read_queue->empty())
-    input_ctx = read_queue->front();
-  else // read queue is empty, switch it
-  {
-    std::lock_guard<std::mutex> guard(write_mutex);
-    std::queue<odla_context>* tmp = read_queue;
-    read_queue = write_queue;
-    write_queue = tmp;
-    popart::logging::info(
-        "switched the read write queue, now read queu size is: {}.",
-        read_queue->size());
-    if (!read_queue->empty())
-      input_ctx = read_queue->front();
-    else { // create a zero data if there's not data in the 2 queues
-      input_ctx = create_empty_odla_context();
-      write_wait_queue->push(
-          input_ctx); // Make it wait for the return for the empty data
-    }
-  }
-
-  return input_ctx;
+  throw std::runtime_error(
+      "ContextQueues::get_input_context we should never call this.");
}

odla_context ContextQueues::get_output_context() {
-  if (output_ctx != nullptr) return output_ctx;
-  if (!read_wait_queue->empty())
-    output_ctx = read_wait_queue->front();
-  else {
-    // switch the wait queue
-    std::lock_guard<std::mutex> guard(
-        write_mutex); // Use the same mutex to save 1 mutex lock for every put
-    std::queue<odla_context>* tmp = read_wait_queue;
-    read_wait_queue = write_wait_queue;
-    write_wait_queue = tmp;
-    popart::logging::info(
-        "switched the read write wait queue, now read queu size is: {}.",
-        read_wait_queue->size());
-  }
-  if (!read_wait_queue->empty()) output_ctx = read_wait_queue->front();
-  if (nullptr == output_ctx)
+  if (wait_ == tail_)
    throw std::out_of_range(
-        "*** FATAL ERROR *** No context in the queue when an output gotten");
-  return output_ctx;
+        "ContextQueues: queue is empty when get_output_context()");
+  return buffer_[wait_];
}

void ContextQueues::pop_input(odla_context ctx) {
-  popart::logging::info("ContextQueues::pop_input with ctx: {}", input_ctx);
-  if (!input_ctx->deletable()) // Only pop the non zero ctx, the zero one not in
-                               // the queue
-    read_queue->pop();
-  input_ctx = nullptr;
+  popart::logging::info("ContextQueues::pop_input with ctx: {}", ctx);
+  assert(ctx == buffer_[head_]);
+  head_ = (head_ + 1) % capacity_;
}

-void ContextQueues::pop_output(
-    odla_context
-        ctx) { // Never delete a context here, only operate on the queue
-  // wait_output_queue.pop();
-  if (!read_wait_queue
-           ->empty()) // There must be an element when all tensor written
-    read_wait_queue->pop(); // pop the first one from the read wait queue
-  else {
-    throw std::out_of_range(
-        "*** FATAL ERROR *** no ctx in read_wait_queue when pop_output called");
-  }
-  output_ctx = nullptr;
+void ContextQueues::pop_output(odla_context ctx) {
+  if (wait_ == head_)
+    throw std::runtime_error("Got out before input all read on index " +
+                             std::to_string(wait_));
+  assert(ctx == buffer_[wait_]);
+  buffer_[wait_] = nullptr; // clear the buffer to nullptr;
+  wait_ = (wait_ + 1) % capacity_;
}
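// (Summary comment, not part of the diff:) the ring uses three cursors:
//   tail_ : next free slot, advanced by put()
//   head_ : next context to hand out as input, advanced by pop_input()
//   wait_ : oldest context still awaiting its output, advanced by pop_output()
// wait_ trails head_, which trails tail_, so the wait_ == head_ check above
// rejects popping an output whose input has not been consumed yet.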

+odla_context ContextQueues::get_ctx_by_tensor(const popart::TensorId& id) {
+  std::uint32_t idx = -1;
+  odla_context ctx = nullptr;
+  // Get current index
+  auto iter = tensor_to_idx_.find(id);
+  if (tensor_to_idx_.end() == iter)
+    idx = 0;
+  else
+    idx = iter->second;
+  // Check whether it is empty; tail always points to the first element not written
+  std::uint32_t cnt = 0;
+  popart::logging::info("ContextQueues::get_ctx_by_tensor queue has size: {}",
+                        size());
+  while (idx == tail_) {
+    auto locked_tail = tail_;
+    {
+      std::lock_guard<std::mutex> guard(queue_mutex_);
+      locked_tail = tail_;
+    }
+    if (idx == locked_tail) {
+      std::unique_lock<std::mutex> lock(batch_wait_mutex_);
+      batch_wait_cv_.wait_for(lock, std::chrono::milliseconds(5));
+    }
+    if (idx != tail_) break;
+    popart::logging::info(
+        "[get_ctx_by_tensor] the queue is empty when read, add zero contexts");
+    if (cnt++ > 1)
+      throw std::runtime_error(
+          "[get_ctx_by_tensor] Must get one ctx in 2 fetch, as empty one "
+          "created.");
+    odla_context zero_ctx = create_empty_odla_context();
+    put(zero_ctx);
+  }
+  // The lock ensured the ctx has been written
+  ctx = buffer_[idx];
+  popart::logging::info(
+      "ContextQueues::get_ctx_by_tensor tensorid:{} got ctx:{} with idx: {}",
+      id, ctx, idx);
+  // Update the index of the tensor to next
+  tensor_to_idx_[id] = (idx + 1) % capacity_;
+  return ctx;
+}
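// (Illustrative example, not part of the diff:) if a step feeds tensors "A"
// and "B" of the same batch, the first calls to get_ctx_by_tensor("A") and
// get_ctx_by_tensor("B") both start at idx 0 and each advances its own entry
// in tensor_to_idx_, so every input tensor of batch i is served from
// buffer_[i % capacity_].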
/*------------------------------------------------------------------------*/
LockFreeQueue::LockFreeQueue() : head_(0), tail_(0), wait_(0) {}

@@ -222,6 +236,9 @@ void LockFreeQueue::put(odla_context ctx) {
    if (cnt++ > 5)
      throw std::runtime_error("LockFreeQueue::put No one should stop me");
  }
+  // Notify the batch wait we got a batch data
+  std::unique_lock<std::mutex> lock(batch_wait_mutex_);
+  batch_wait_cv_.notify_one();
  popart::logging::info(
      "[LockFreeQueue::put] Set the idx: {} for ctx: {} in {} times.", idx, ctx,
      cnt);
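// (Illustrative note, not part of the diff:) if the notify_one() added above
// fires between a consumer's tail_ check and its wait_for() call, the wakeup
// is missed, but the 5 ms timeout in get_ctx_by_tensor below bounds the delay.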
@@ -290,15 +307,11 @@ odla_context LockFreeQueue::get_ctx_by_tensor(const popart::TensorId& id) {
  popart::logging::info("LockFreeQueue::get_ctx_by_tensor queue has size: {}",
                        size());
  while (idx == tail_.load()) {
-    bool got_data = false;
-    for (int i = 0; i < 5; i++) {
-      std::this_thread::sleep_for(std::chrono::milliseconds(1));
-      if (idx != tail_.load()) {
-        got_data = true;
-        break;
-      }
+    if (idx == tail_.load()) {
+      std::unique_lock<std::mutex> lock(batch_wait_mutex_);
+      batch_wait_cv_.wait_for(lock, std::chrono::milliseconds(5));
    }
-    if (got_data) break;
+    if (idx != tail_.load()) break;
    popart::logging::info(
        "[get_ctx_by_tensor] the queue is empty when read, add zero contexts");
    if (cnt++ > 1)
46 changes: 19 additions & 27 deletions ODLA/platforms/odla_popart/odla_pipeline.h
@@ -23,6 +23,7 @@

#include <atomic>
#include <chrono>
+#include <condition_variable>
#include <mutex>
#include <popart/stepio.hpp>
#include <queue>
@@ -48,40 +49,29 @@ class Queue {

class ContextQueues : public Queue {
private:
-  std::queue<odla_context> input_queue_1;
-  std::queue<odla_context> input_queue_2;
-  std::queue<odla_context> wait_output_queue_1;
-  std::queue<odla_context> wait_output_queue_2;
-  std::mutex write_mutex;
-  std::queue<odla_context>* read_queue;
-  std::queue<odla_context>* write_queue;
-  std::queue<odla_context>* read_wait_queue;
-  std::queue<odla_context>* write_wait_queue;
-  odla_context input_ctx;  // the context which is under reading
-  odla_context output_ctx; // the context which is under writing
+  odla_context* buffer_;
+  std::size_t capacity_;
+  std::uint32_t head_;
+  std::uint32_t tail_;
+  std::uint32_t wait_;
+  std::map<popart::TensorId, std::uint32_t> tensor_to_idx_;
+  std::mutex batch_wait_mutex_;
+  std::condition_variable batch_wait_cv_;
+  std::mutex queue_mutex_; // lock the read & write

public:
-  ContextQueues()
-      : read_queue(&input_queue_1),
-        write_queue(&input_queue_2),
-        read_wait_queue(&wait_output_queue_1),
-        write_wait_queue(&wait_output_queue_2),
-        input_ctx(nullptr),
-        output_ctx(nullptr) {}
-
-  ~ContextQueues() {}
-  void init(std::size_t capacity) final {}
+  ContextQueues() : head_(0), tail_(0), wait_(0){};
+  ~ContextQueues() {
+    if (buffer_) delete[] buffer_;
+  }
+  void init(std::size_t capacity);
void put(odla_context ctx) final;
odla_context get_input_context() final;
-  odla_context get_ctx_by_tensor(const popart::TensorId& id) final {
-    return nullptr;
-  }
+  odla_context get_ctx_by_tensor(const popart::TensorId& id) final;
odla_context get_output_context() final;
void pop_input(odla_context ctx) final;
void pop_output(odla_context ctx) final;
-  std::size_t size() final {
-    return input_queue_1.size() + input_queue_2.size();
-  }
+  std::size_t size() final { return (tail_ - wait_ + capacity_) % capacity_; }
};
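// (Worked example, not part of the diff:) with capacity_ == 8, tail_ == 2 and
// wait_ == 6, size() == (2 - 6 + 8) % 8 == 4: four contexts have been put()
// but not yet pop_output()'d, counted correctly across the wrap-around.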

class LockFreeQueue : public Queue {
Expand All @@ -92,6 +82,8 @@ class LockFreeQueue : public Queue {
std::atomic<uint32_t> tail_;
std::uint32_t wait_;
std::map<popart::TensorId, std::uint32_t> tensor_to_idx_;
+  std::mutex batch_wait_mutex_;
+  std::condition_variable batch_wait_cv_;

public:
LockFreeQueue();
