Skip to content

Commit

Permalink
chore: add db_slice lock to protect segments from preemptions (#3406)
Browse files Browse the repository at this point in the history
DastTable::Traverse is error prone when the callback passed preempts because the segment might change. This is problematic and we need atomicity while traversing segments with preemption. The fix is to add Traverse in DbSlice and protect the traversal via ThreadLocalMutex.

* add ConditionFlag to DbSlice
* add Traverse in DbSlice and protect it with the ConditionFlag
* remove condition flag from snapshot
* remove condition flag from streamer

---------

Signed-off-by: kostas <[email protected]>
  • Loading branch information
kostasrim authored Jul 30, 2024
1 parent f536f8a commit aa02070
Show file tree
Hide file tree
Showing 13 changed files with 65 additions and 71 deletions.
21 changes: 21 additions & 0 deletions src/server/common.cc
Original file line number Diff line number Diff line change
Expand Up @@ -452,4 +452,25 @@ RandomPick UniquePicksGenerator::Generate() {
return max_index;
}

ThreadLocalMutex::ThreadLocalMutex() {
shard_ = EngineShard::tlocal();
}

ThreadLocalMutex::~ThreadLocalMutex() {
DCHECK_EQ(EngineShard::tlocal(), shard_);
}

void ThreadLocalMutex::lock() {
DCHECK_EQ(EngineShard::tlocal(), shard_);
util::fb2::NoOpLock noop_lk_;
cond_var_.wait(noop_lk_, [this]() { return !flag_; });
flag_ = true;
}

void ThreadLocalMutex::unlock() {
DCHECK_EQ(EngineShard::tlocal(), shard_);
flag_ = false;
cond_var_.notify_one();
}

} // namespace dfly
41 changes: 7 additions & 34 deletions src/server/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -365,45 +365,18 @@ struct ConditionFlag {
};

// Helper class used to guarantee atomicity between serialization of buckets
class ConditionGuard {
class ThreadLocalMutex {
public:
explicit ConditionGuard(ConditionFlag* enclosing) : enclosing_(enclosing) {
util::fb2::NoOpLock noop_lk_;
enclosing_->cond_var.wait(noop_lk_, [this]() { return !enclosing_->flag; });
enclosing_->flag = true;
}

~ConditionGuard() {
enclosing_->flag = false;
enclosing_->cond_var.notify_one();
}
ThreadLocalMutex();
~ThreadLocalMutex();

private:
ConditionFlag* enclosing_;
};

class LocalBlockingCounter {
public:
void lock() {
++mutating_;
}

void unlock() {
DCHECK(mutating_ > 0);
--mutating_;
if (mutating_ == 0) {
cond_var_.notify_one();
}
}

void Wait() {
util::fb2::NoOpLock noop_lk_;
cond_var_.wait(noop_lk_, [this]() { return mutating_ == 0; });
}
void lock();
void unlock();

private:
EngineShard* shard_;
util::fb2::CondVarAny cond_var_;
size_t mutating_ = 0;
bool flag_ = false;
};

} // namespace dfly
15 changes: 7 additions & 8 deletions src/server/db_slice.cc
Original file line number Diff line number Diff line change
Expand Up @@ -744,7 +744,7 @@ void DbSlice::FlushSlotsFb(const cluster::SlotSet& slot_ids) {
PrimeTable::Cursor cursor;
uint64_t i = 0;
do {
PrimeTable::Cursor next = pt->Traverse(cursor, del_entry_cb);
PrimeTable::Cursor next = Traverse(pt, cursor, del_entry_cb);
++i;
cursor = next;
if (i % 100 == 0) {
Expand Down Expand Up @@ -1149,7 +1149,7 @@ void DbSlice::ExpireAllIfNeeded() {

ExpireTable::Cursor cursor;
do {
cursor = db.expire.Traverse(cursor, cb);
cursor = Traverse(&db.expire, cursor, cb);
} while (cursor);
}
}
Expand All @@ -1160,7 +1160,6 @@ uint64_t DbSlice::RegisterOnChange(ChangeCallback cb) {

void DbSlice::FlushChangeToEarlierCallbacks(DbIndex db_ind, Iterator it, uint64_t upper_bound) {
FetchedItemsRestorer fetched_restorer(&fetched_items_);
std::unique_lock<LocalBlockingCounter> lk(block_counter_);

uint64_t bucket_version = it.GetVersion();
// change_cb_ is ordered by version.
Expand All @@ -1184,7 +1183,7 @@ void DbSlice::FlushChangeToEarlierCallbacks(DbIndex db_ind, Iterator it, uint64_

//! Unregisters the callback.
void DbSlice::UnregisterOnChange(uint64_t id) {
block_counter_.Wait();
std::unique_lock lk(local_mu_);
auto it = find_if(change_cb_.begin(), change_cb_.end(),
[id](const auto& cb) { return cb.first == id; });
CHECK(it != change_cb_.end());
Expand Down Expand Up @@ -1216,13 +1215,13 @@ auto DbSlice::DeleteExpiredStep(const Context& cntx, unsigned count) -> DeleteEx

unsigned i = 0;
for (; i < count / 3; ++i) {
db.expire_cursor = db.expire.Traverse(db.expire_cursor, cb);
db.expire_cursor = Traverse(&db.expire, db.expire_cursor, cb);
}

// continue traversing only if we had strong deletion rate based on the first sample.
if (result.deleted * 4 > result.traversed) {
for (; i < count; ++i) {
db.expire_cursor = db.expire.Traverse(db.expire_cursor, cb);
db.expire_cursor = Traverse(&db.expire, db.expire_cursor, cb);
}
}

Expand Down Expand Up @@ -1388,7 +1387,7 @@ void DbSlice::ClearOffloadedEntries(absl::Span<const DbIndex> indices, const DbT
// Delete all tiered entries
PrimeTable::Cursor cursor;
do {
cursor = db_ptr->prime.Traverse(cursor, [&](PrimeIterator it) {
cursor = Traverse(&db_ptr->prime, cursor, [&](PrimeIterator it) {
if (it->second.IsExternal()) {
tiered_storage->Delete(index, &it->second);
} else if (it->second.HasStashPending()) {
Expand Down Expand Up @@ -1515,7 +1514,7 @@ void DbSlice::CallChangeCallbacks(DbIndex id, std::string_view key, const Change

DVLOG(2) << "Running callbacks for key " << key << " in dbid " << id;
FetchedItemsRestorer fetched_restorer(&fetched_items_);
std::unique_lock<LocalBlockingCounter> lk(block_counter_);
std::unique_lock lk(local_mu_);

const size_t limit = change_cb_.size();
auto ccb = change_cb_.begin();
Expand Down
23 changes: 16 additions & 7 deletions src/server/db_slice.h
Original file line number Diff line number Diff line change
Expand Up @@ -497,6 +497,20 @@ class DbSlice {
void PerformDeletion(Iterator del_it, DbTable* table);
void PerformDeletion(PrimeIterator del_it, DbTable* table);

// Provides access to the internal lock of db_slice for flows that serialize
// entries with preemption and need to synchronize with Traverse below which
// acquires the same lock.
ThreadLocalMutex* GetSerializationMutex() {
return &local_mu_;
}

// Wrapper around DashTable::Traverse that allows preemptions
template <typename Cb, typename DashTable>
PrimeTable::Cursor Traverse(DashTable* pt, PrimeTable::Cursor cursor, Cb&& cb) {
std::unique_lock lk(local_mu_);
return pt->Traverse(cursor, std::forward<Cb>(cb));
}

private:
void PreUpdate(DbIndex db_ind, Iterator it, std::string_view key);
void PostUpdate(DbIndex db_ind, Iterator it, std::string_view key, size_t orig_size);
Expand Down Expand Up @@ -550,13 +564,8 @@ class DbSlice {

void CallChangeCallbacks(DbIndex id, std::string_view key, const ChangeReq& cr) const;

// We need this because registered callbacks might yield. If RegisterOnChange
// gets called after we preempt while iterating over the registered callbacks
// (let's say in FlushChangeToEarlierCallbacks) we will get UB, because we pushed
// into a vector which might get resized, invalidating the iterators that are being
// used by the preempted FlushChangeToEarlierCallbacks. LocalBlockingCounter
// protects us against this case.
mutable LocalBlockingCounter block_counter_;
// Used to provide exclusive access while Traversing segments
mutable ThreadLocalMutex local_mu_;
ShardId shard_id_;
uint8_t caching_mode_ : 1;

Expand Down
2 changes: 1 addition & 1 deletion src/server/debugcmd.cc
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,7 @@ void DoBuildObjHist(EngineShard* shard, ConnectionContext* cntx, ObjHistMap* obj
continue;
PrimeTable::Cursor cursor;
do {
cursor = dbt->prime.Traverse(cursor, [&](PrimeIterator it) {
cursor = db_slice.Traverse(&dbt->prime, cursor, [&](PrimeIterator it) {
unsigned obj_type = it->second.ObjType();
auto& hist_ptr = (*obj_hist_map)[obj_type];
if (!hist_ptr) {
Expand Down
2 changes: 1 addition & 1 deletion src/server/engine_shard_set.cc
Original file line number Diff line number Diff line change
Expand Up @@ -317,7 +317,7 @@ bool EngineShard::DoDefrag() {
uint64_t attempts = 0;

do {
cur = prime_table->Traverse(cur, [&](PrimeIterator it) {
cur = slice.Traverse(prime_table, cur, [&](PrimeIterator it) {
// for each value check whether we should move it because it
// seats on underutilized page of memory, and if so, do it.
bool did = it->second.DefragIfNeeded(threshold);
Expand Down
5 changes: 3 additions & 2 deletions src/server/generic_family.cc
Original file line number Diff line number Diff line change
Expand Up @@ -582,8 +582,9 @@ void OpScan(const OpArgs& op_args, const ScanOpts& scan_opts, uint64_t* cursor,
auto [prime_table, expire_table] = db_slice.GetTables(op_args.db_cntx.db_index);
string scratch;
do {
cur = prime_table->Traverse(
cur, [&](PrimeIterator it) { cnt += ScanCb(op_args, it, scan_opts, &scratch, vec); });
cur = db_slice.Traverse(prime_table, cur, [&](PrimeIterator it) {
cnt += ScanCb(op_args, it, scan_opts, &scratch, vec);
});
} while (cur && cnt < scan_opts.limit);

VLOG(1) << "OpScan " << db_slice.shard_id() << " cursor: " << cur.value();
Expand Down
6 changes: 1 addition & 5 deletions src/server/journal/streamer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -213,9 +213,7 @@ void RestoreStreamer::Run() {
return;

bool written = false;
cursor = pt->Traverse(cursor, [&](PrimeTable::bucket_iterator it) {
ConditionGuard guard(&bucket_ser_);

cursor = db_slice_->Traverse(pt, cursor, [&](PrimeTable::bucket_iterator it) {
db_slice_->FlushChangeToEarlierCallbacks(0 /*db_id always 0 for cluster*/,
DbSlice::Iterator::FromPrime(it), snapshot_version_);
if (WriteBucket(it)) {
Expand Down Expand Up @@ -313,8 +311,6 @@ bool RestoreStreamer::WriteBucket(PrimeTable::bucket_iterator it) {
void RestoreStreamer::OnDbChange(DbIndex db_index, const DbSlice::ChangeReq& req) {
DCHECK_EQ(db_index, 0) << "Restore migration only allowed in cluster mode in db0";

ConditionGuard guard(&bucket_ser_);

PrimeTable* table = db_slice_->GetTables(0).first;

if (const PrimeTable::bucket_iterator* bit = req.update()) {
Expand Down
2 changes: 0 additions & 2 deletions src/server/journal/streamer.h
Original file line number Diff line number Diff line change
Expand Up @@ -107,8 +107,6 @@ class RestoreStreamer : public JournalStreamer {
cluster::SlotSet my_slots_;
bool fiber_cancelled_ = false;
bool snapshot_finished_ = false;

ConditionFlag bucket_ser_;
};

} // namespace dfly
10 changes: 3 additions & 7 deletions src/server/snapshot.cc
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ void SliceSnapshot::IterateBucketsFb(const Cancellation* cll, bool send_full_syn
return;

PrimeTable::Cursor next =
pt->Traverse(cursor, absl::bind_front(&SliceSnapshot::BucketSaveCb, this));
db_slice_->Traverse(pt, cursor, absl::bind_front(&SliceSnapshot::BucketSaveCb, this));
cursor = next;
PushSerializedToChannel(false);

Expand Down Expand Up @@ -253,8 +253,6 @@ void SliceSnapshot::IterateBucketsFb(const Cancellation* cll, bool send_full_syn
}

bool SliceSnapshot::BucketSaveCb(PrimeIterator it) {
ConditionGuard guard(&bucket_ser_);

++stats_.savecb_calls;

auto check = [&](auto v) {
Expand Down Expand Up @@ -364,8 +362,6 @@ bool SliceSnapshot::PushSerializedToChannel(bool force) {
}

void SliceSnapshot::OnDbChange(DbIndex db_index, const DbSlice::ChangeReq& req) {
ConditionGuard guard(&bucket_ser_);

PrimeTable* table = db_slice_->GetTables(db_index).first;
const PrimeTable::bucket_iterator* bit = req.update();

Expand All @@ -390,7 +386,7 @@ void SliceSnapshot::OnJournalEntry(const journal::JournalItem& item, bool await)
// To enable journal flushing to sync after non auto journal command is executed we call
// TriggerJournalWriteToSink. This call uses the NOOP opcode with await=true. Since there is no
// additional journal change to serialize, it simply invokes PushSerializedToChannel.
ConditionGuard guard(&bucket_ser_);
std::unique_lock lk(*db_slice_->GetSerializationMutex());
if (item.opcode != journal::Op::NOOP) {
serializer_->WriteJournalEntry(item.data);
}
Expand All @@ -403,7 +399,7 @@ void SliceSnapshot::OnJournalEntry(const journal::JournalItem& item, bool await)
}

void SliceSnapshot::CloseRecordChannel() {
ConditionGuard guard(&bucket_ser_);
std::unique_lock lk(*db_slice_->GetSerializationMutex());

CHECK(!serialize_bucket_running_);
// Make sure we close the channel only once with a CAS check.
Expand Down
2 changes: 0 additions & 2 deletions src/server/snapshot.h
Original file line number Diff line number Diff line change
Expand Up @@ -179,8 +179,6 @@ class SliceSnapshot {
size_t savecb_calls = 0;
size_t keys_total = 0;
} stats_;

ConditionFlag bucket_ser_;
};

} // namespace dfly
5 changes: 4 additions & 1 deletion tests/dragonfly/instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,9 @@ def __init__(self, params: DflyParams, args):
if threads > 1:
self.args["num_shards"] = threads - 1

# Add 1 byte limit for big values
self.args["serialization_max_chunk_size"] = 1

def __del__(self):
assert self.proc == None

Expand Down Expand Up @@ -163,7 +166,7 @@ def stop(self, kill=False):
proc.kill()
else:
proc.terminate()
proc.communicate(timeout=15)
proc.communicate(timeout=120)
# if the return code is 0 it means normal termination
# if the return code is negative it means termination by signal
# if the return code is positive it means abnormal exit
Expand Down
2 changes: 1 addition & 1 deletion tests/dragonfly/replication_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ async def test_replication_all(
)

# Wait for all replicas to transition into stable sync
async with async_timeout.timeout(20):
async with async_timeout.timeout(240):
await wait_for_replicas_state(*c_replicas)

# Stop streaming data once every replica is in stable sync
Expand Down

0 comments on commit aa02070

Please sign in to comment.