Skip to content

Commit

Permalink
[Disco] Reduce Process/ThreadSession message queue reads and writes (a…
Browse files Browse the repository at this point in the history
…pache#16817)

This PR reduces the number of reads and writes for the message queue
of ProcessSession and ThreadSession in Disco by caching all the data
to read/write.

Prior to this PR, the message queue in ThreadSession acquired the mutex
multiple times for each batch of data to read/write. This PR
reads/writes the data from/to a local buffer first, and then transfers it
from/to the critical region in a single step. This reduces the number of
mutex acquisitions per batch to one.

Prior to this PR, the message queue in ProcessSession read/wrote
the inter-process pipe multiple times for each batch of data.
This PR uses a local buffer to cache all the data first, and then
issues a single read/write from/to the pipe. This effectively reduces
the number of pipe reads/writes, each of which may cause extra
system overhead.
  • Loading branch information
MasterJH5574 authored and thaisacs committed Apr 3, 2024
1 parent 25b76c0 commit 7d928a8
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 27 deletions.
53 changes: 43 additions & 10 deletions src/runtime/disco/process_session.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,38 +36,71 @@
namespace tvm {
namespace runtime {

class DiscoPipeMessageQueue : private ::tvm::support::Pipe,
private DiscoProtocol<DiscoPipeMessageQueue> {
class DiscoPipeMessageQueue : private dmlc::Stream, private DiscoProtocol<DiscoPipeMessageQueue> {
public:
explicit DiscoPipeMessageQueue(int64_t handle) : ::tvm::support::Pipe(handle) {}
explicit DiscoPipeMessageQueue(int64_t handle) : pipe_(handle) {}

~DiscoPipeMessageQueue() = default;

/*!
 * \brief Send one packed-argument sequence through this message queue.
 * \param args The packed arguments (values, type codes and count) to send.
 */
void Send(const TVMArgs& args) {
  // Hand the arguments to the RPC serializer with this object as the target stream
  // (presumably it emits the bytes through this class's Write -- confirm in RPCReference).
  RPCReference::ReturnPackedSeq(args.values, args.type_codes, args.num_args, this);
  // Push everything that was staged for this message out in one step.
  CommitSendAndNotifyEnqueue();
}

TVMArgs Recv() {
{
this->RecycleAll();
uint64_t packet_nbytes = 0;
RPCCode code = RPCCode::kReturn;
this->Read(&packet_nbytes);
this->Read(&code);
}
DequeueNextPacket();
TVMValue* values = nullptr;
int* type_codes = nullptr;
int num_args = 0;
RPCReference::RecvPackedSeq(&values, &type_codes, &num_args, this);
return TVMArgs(values, type_codes, num_args);
}

protected:
/*!
 * \brief Flush the locally staged bytes to the pipe with one write call,
 *  then empty the staging buffer so the next message starts from scratch.
 */
void CommitSendAndNotifyEnqueue() {
  const size_t staged_nbytes = write_buffer_.size();
  pipe_.Write(write_buffer_.data(), staged_nbytes);
  write_buffer_.clear();
}

/*!
 * \brief Pull the next packet from the pipe into `read_buffer_`.
 *
 * The packet is read as a `uint64_t` byte count followed by that many payload
 * bytes. After the payload is buffered, the read cursor is reset and the
 * leading RPCCode is consumed so that subsequent `Read` calls start at the
 * packed arguments.
 *
 * Fails (via ICHECK) if either the length prefix or the payload cannot be
 * read in full.
 */
void DequeueNextPacket() {
  uint64_t packet_nbytes = 0;
  size_t read_size = pipe_.Read(&packet_nbytes, sizeof(packet_nbytes));
  ICHECK_EQ(read_size, sizeof(packet_nbytes))
      << "Pipe closed without proper shutdown. Please make sure to explicitly call "
         "`Session::Shutdown`";
  read_buffer_.resize(packet_nbytes);
  // The payload read must be checked as well: a short read would otherwise leave
  // the tail of `read_buffer_` zero-filled and be decoded as if it were real data.
  read_size = pipe_.Read(read_buffer_.data(), packet_nbytes);
  ICHECK_EQ(read_size, packet_nbytes)
      << "Pipe closed or truncated in the middle of a packet. Please make sure to "
         "explicitly call `Session::Shutdown`";
  read_offset_ = 0;
  this->RecycleAll();
  // Consume the RPCCode at the head of the packet; the value itself is not used here.
  RPCCode code = RPCCode::kReturn;
  this->Read(&code);
}

/*!
 * \brief Copy the next `size` bytes of the buffered packet into `data`.
 * \param data Destination buffer; must have room for `size` bytes.
 * \param size Number of bytes to copy.
 * \return `size` (the request is either satisfied in full or the check fails).
 */
size_t Read(void* data, size_t size) final {
  // Validate BEFORE copying so an over-long request never reads past the end of
  // `read_buffer_`. The subtraction form cannot overflow because `read_offset_`
  // never exceeds `read_buffer_.size()`.
  ICHECK_LE(size, read_buffer_.size() - read_offset_);
  std::memcpy(data, read_buffer_.data() + read_offset_, size);
  read_offset_ += size;
  return size;
}

/*!
 * \brief Stage `size` bytes from `data` at the end of the local write buffer.
 *  Nothing is sent to the pipe here; the staged bytes are flushed separately.
 * \param data Source bytes.
 * \param size Number of bytes to stage.
 */
void Write(const void* data, size_t size) final {
  const char* src_bytes = static_cast<const char*>(data);
  write_buffer_.append(src_bytes, size);
}

using dmlc::Stream::Read;
using dmlc::Stream::ReadArray;
using dmlc::Stream::Write;
using dmlc::Stream::WriteArray;
friend struct RPCReference;
friend struct DiscoProtocol<DiscoPipeMessageQueue>;

// The read/write buffer will only be accessed by the producer thread.
std::string write_buffer_;
std::string read_buffer_;
size_t read_offset_ = 0;
support::Pipe pipe_;
};

class DiscoProcessChannel final : public DiscoChannel {
Expand Down
48 changes: 31 additions & 17 deletions src/runtime/disco/threaded_session.cc
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,11 @@ class DiscoThreadedMessageQueue : private dmlc::Stream,
public:
void Send(const TVMArgs& args) {
RPCReference::ReturnPackedSeq(args.values, args.type_codes, args.num_args, this);
NotifyEnqueue();
CommitSendAndNotifyEnqueue();
}

TVMArgs Recv() {
WaitDequeue();
DequeueNextPacket();
TVMValue* values = nullptr;
int* type_codes = nullptr;
int num_args = 0;
Expand All @@ -55,43 +55,51 @@ class DiscoThreadedMessageQueue : private dmlc::Stream,
}

protected:
void NotifyEnqueue() {
void CommitSendAndNotifyEnqueue() {
bool need_notify = false;
{
std::lock_guard<std::mutex> lock{mutex_};
++msg_cnt_;
ring_buffer_.Write(write_buffer_.data(), write_buffer_.size());
need_notify = dequeue_waiting_;
}
condition_.notify_one();
if (need_notify) {
condition_.notify_one();
}
write_buffer_.clear();
}

void WaitDequeue() {
void DequeueNextPacket() {
{
std::unique_lock<std::mutex> lock(mutex_);
dequeue_waiting_ = true;
condition_.wait(lock, [this] { return msg_cnt_.load() > 0; });
dequeue_waiting_ = false;
--msg_cnt_;
uint64_t packet_nbytes = 0;
ring_buffer_.Read(&packet_nbytes, sizeof(packet_nbytes));
read_buffer_.resize(packet_nbytes);
ring_buffer_.Read(read_buffer_.data(), packet_nbytes);
read_offset_ = 0;
}
this->RecycleAll();
uint64_t packet_nbytes = 0;
RPCCode code = RPCCode::kReturn;
this->Read(&packet_nbytes);
this->Read(&code);
}

void MessageStart(uint64_t packet_nbytes) {
std::lock_guard<std::mutex> lock(mutex_);
size_t n = ring_buffer_.bytes_available();
n += packet_nbytes + sizeof(uint64_t);
this->ring_buffer_.Reserve(n);
}
/*!
 * \brief Hook invoked with the size of the packet about to be written.
 * \param packet_nbytes Size of the upcoming packet in bytes (unused).
 */
void MessageStart(uint64_t packet_nbytes) {
  // Intentionally a no-op: nothing is reserved or locked at message start.
}

size_t Read(void* data, size_t size) final {
std::lock_guard<std::mutex> lock(mutex_);
ring_buffer_.Read(data, size);
std::memcpy(data, read_buffer_.data() + read_offset_, size);
read_offset_ += size;
ICHECK_LE(read_offset_, read_buffer_.size());
return size;
}

void Write(const void* data, size_t size) final {
std::lock_guard<std::mutex> lock(mutex_);
ring_buffer_.Write(data, size);
size_t cur_size = write_buffer_.size();
write_buffer_.resize(cur_size + size);
std::memcpy(write_buffer_.data() + cur_size, data, size);
}

using dmlc::Stream::Read;
Expand All @@ -101,6 +109,12 @@ class DiscoThreadedMessageQueue : private dmlc::Stream,
friend struct RPCReference;
friend struct DiscoProtocol<DiscoThreadedMessageQueue>;

// The read/write buffer will only be accessed by the producer thread.
std::string write_buffer_;
std::string read_buffer_;
size_t read_offset_ = 0;
bool dequeue_waiting_ = false;

std::mutex mutex_;
std::atomic<int> msg_cnt_{0};
std::condition_variable condition_;
Expand Down

0 comments on commit 7d928a8

Please sign in to comment.