Skip to content

Commit

Permalink
[GraphBolt][io_uring] Use RAII to ensure queues are returned. (dmlc#7680
Browse files Browse the repository at this point in the history
)
  • Loading branch information
mfbalin authored Aug 10, 2024
1 parent 90c26be commit c86776d
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 33 deletions.
39 changes: 8 additions & 31 deletions graphbolt/src/cnumpy.cc
Original file line number Diff line number Diff line change
Expand Up @@ -171,27 +171,19 @@ torch::Tensor OnDiskNpyArray::IndexSelectIOUringImpl(torch::Tensor index) {
// Indicator for index error.
std::atomic<int> error_flag{};
std::atomic<int64_t> work_queue{};
std::atomic_flag exiting_first = ATOMIC_FLAG_INIT;
// Consume a slot so that parallel_for is called only if there are available
// queues.
semaphore_.acquire();
std::atomic<int> num_semaphore_acquisitions = 1;
graphbolt::parallel_for_each_interop(0, num_thread_, 1, [&](int thread_id) {
// Construct a QueueAndBufferAcquirer object so that the worker threads can
// share the available queues and buffers.
QueueAndBufferAcquirer queue_source(this);
graphbolt::parallel_for_each_interop(0, num_thread_, 1, [&](int) {
// The completion queue might contain 4 * kGroupSize while we may submit
// 4 * kGroupSize more. No harm in overallocation here.
CircularQueue<ReadRequest> read_queue(8 * kGroupSize);
int64_t num_submitted = 0;
int64_t num_completed = 0;
{
// We consume a slot from the semaphore to use a queue.
semaphore_.acquire();
num_semaphore_acquisitions.fetch_add(1, std::memory_order_relaxed);
std::lock_guard lock(available_queues_mtx_);
TORCH_CHECK(!available_queues_.empty());
thread_id = available_queues_.back();
available_queues_.pop_back();
}
auto &io_uring_queue = io_uring_queue_[thread_id];
auto [acquired_queue_handle, my_read_buffer2] = queue_source.get();
auto &io_uring_queue = acquired_queue_handle.get();
// Capturing structured binding is available only in C++20, so we rename.
auto my_read_buffer = my_read_buffer2;
auto submit_fn = [&](int64_t submission_minimum_batch_size) {
if (read_queue.Size() < submission_minimum_batch_size) return;
TORCH_CHECK( // Check for sqe overflow.
Expand All @@ -207,7 +199,6 @@ torch::Tensor OnDiskNpyArray::IndexSelectIOUringImpl(torch::Tensor index) {
read_queue.PopN(submitted);
}
};
auto my_read_buffer = ReadBuffer(thread_id);
for (int64_t read_buffer_slot = 0; true;) {
auto request_read_buffer = [&]() {
return my_read_buffer + (aligned_length_ + block_size_) *
Expand Down Expand Up @@ -307,21 +298,7 @@ torch::Tensor OnDiskNpyArray::IndexSelectIOUringImpl(torch::Tensor index) {
io_uring_cq_advance(&io_uring_queue, num_cqes_seen);
num_completed += num_cqes_seen;
}
{
// We give back the slot we used.
std::lock_guard lock(available_queues_mtx_);
available_queues_.push_back(thread_id);
}
// If this is the first thread exiting, release the master thread's ticket
// as well by releasing 2 slots. Otherwise, release 1 slot.
const auto releasing = exiting_first.test_and_set() ? 1 : 2;
semaphore_.release(releasing);
num_semaphore_acquisitions.fetch_add(-releasing, std::memory_order_relaxed);
});
// If any of the worker threads exit early without being able to release the
// semaphore, we make sure to release it for them in the main thread.
semaphore_.release(
num_semaphore_acquisitions.load(std::memory_order_relaxed));
const auto ret_val = error_flag.load(std::memory_order_relaxed);
switch (ret_val) {
case 0: // Successful.
Expand Down
65 changes: 63 additions & 2 deletions graphbolt/src/cnumpy.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,10 @@
#include <cstdlib>
#include <cstring>
#include <cuda/std/semaphore>
#include <fstream>
#include <iostream>
#include <memory>
#include <mutex>
#include <string>
#include <utility>
#include <vector>

namespace graphbolt {
Expand Down Expand Up @@ -147,6 +146,68 @@ class OnDiskNpyArray : public torch::CustomClassHolder {
static inline std::mutex available_queues_mtx_; // available_queues_ mutex.
static inline std::vector<int> available_queues_;

struct QueueAndBufferAcquirer {
struct UniqueQueue {
UniqueQueue(QueueAndBufferAcquirer* acquirer, int thread_id)
: acquirer_(acquirer), thread_id_(thread_id) {}
UniqueQueue(const UniqueQueue&) = delete;
UniqueQueue& operator=(const UniqueQueue&) = delete;

~UniqueQueue() {
{
// We give back the slot we used.
std::lock_guard lock(available_queues_mtx_);
available_queues_.push_back(thread_id_);
}
// If this is the first thread exiting, release the master thread's
// ticket as well by releasing 2 slots. Otherwise, release 1 slot.
const auto releasing = acquirer_->exiting_first_.test_and_set() ? 1 : 2;
semaphore_.release(releasing);
acquirer_->num_acquisitions_.fetch_add(
-releasing, std::memory_order_relaxed);
}

::io_uring& get() const { return io_uring_queue_[thread_id_]; }

private:
QueueAndBufferAcquirer* acquirer_;
int thread_id_;
};

QueueAndBufferAcquirer(OnDiskNpyArray* array) : array_(array) {
semaphore_.acquire();
}

~QueueAndBufferAcquirer() {
// If any of the worker threads exit early without being able to release
// the semaphore, we make sure to release it for them in the main thread.
const auto releasing = num_acquisitions_.load(std::memory_order_relaxed);
semaphore_.release(releasing);
TORCH_CHECK(releasing == 0, "An io_uring worker thread didn't not exit.");
}

std::pair<UniqueQueue, char*> get() {
// We consume a slot from the semaphore to use a queue.
semaphore_.acquire();
num_acquisitions_.fetch_add(1, std::memory_order_relaxed);
const auto thread_id = [&] {
std::lock_guard lock(available_queues_mtx_);
TORCH_CHECK(!available_queues_.empty());
const auto thread_id = available_queues_.back();
available_queues_.pop_back();
return thread_id;
}();
return {
std::piecewise_construct, std::make_tuple(this, thread_id),
std::make_tuple(array_->ReadBuffer(thread_id))};
}

private:
const OnDiskNpyArray* array_;
std::atomic_flag exiting_first_ = ATOMIC_FLAG_INIT;
std::atomic<int> num_acquisitions_ = 1;
};

#endif // HAVE_LIBRARY_LIBURING
};

Expand Down

0 comments on commit c86776d

Please sign in to comment.