From ec02b9f26310425506ce8bc7afaf40a974991bf7 Mon Sep 17 00:00:00 2001 From: RobinTF <83676088+RobinTF@users.noreply.github.com> Date: Wed, 20 Nov 2024 16:35:45 +0100 Subject: [PATCH] Use robuster approach for thread safety --- src/engine/Join.cpp | 60 +++++++++++++++++++++++++++++---------------- 1 file changed, 39 insertions(+), 21 deletions(-) diff --git a/src/engine/Join.cpp b/src/engine/Join.cpp index 8a1b4e6632..8db4ce55ef 100644 --- a/src/engine/Join.cpp +++ b/src/engine/Join.cpp @@ -21,7 +21,6 @@ #include "util/Exception.h" #include "util/HashMap.h" #include "util/JoinAlgorithms/JoinAlgorithms.h" -#include "util/ThreadSafeQueue.h" using std::endl; using std::string; @@ -418,14 +417,21 @@ ProtoResult Join::createResult( std::function> auto action, std::vector permutation) const { if (requestedLaziness) { - return { - [](auto innerAction, auto innerPermutation) -> Result::Generator { - ad_utility::data_structures::ThreadSafeQueue - queue{1}; - ad_utility::JThread thread{ - [&queue, &innerAction, &innerPermutation]() { - auto addValue = [&queue, &innerPermutation]( + return {[](auto innerAction, auto innerPermutation) -> Result::Generator { + std::atomic_flag write = true; + std::variant + storage; + ad_utility::JThread thread{[&write, &storage, &innerAction, + &innerPermutation]() { + auto writeValue = [&write, &storage](auto value) noexcept { + storage = std::move(value); + write.clear(); + write.notify_one(); + }; + auto addValue = [&write, &writeValue, &innerPermutation]( IdTable& idTable, LocalVocab& localVocab) { + AD_CORRECTNESS_CHECK(write.test()); if (idTable.size() < CHUNK_SIZE) { return; } @@ -434,30 +440,42 @@ ProtoResult Join::createResult( if (!innerPermutation.empty()) { pair.idTable_.setColumnSubset(innerPermutation); } - queue.push(std::move(pair)); + writeValue(std::move(pair)); + // Wait until we are allowed to write again. + write.wait(false); }; try { auto finalValue = innerAction(addValue); + AD_CORRECTNESS_CHECK(write.test()); if (!finalValue.idTable_.empty()) { if (!innerPermutation.empty()) { finalValue.idTable_.setColumnSubset(innerPermutation); } - queue.push(std::move(finalValue)); + writeValue(std::move(finalValue)); + // Wait until we are allowed to write again. + write.wait(false); } - queue.finish(); + writeValue(std::monostate{}); } catch (...) { - queue.pushException(std::current_exception()); + writeValue(std::current_exception()); } }}; - while (true) { - auto val = queue.pop(); - if (!val.has_value()) { - break; - } - co_yield val.value(); - } - }(std::move(action), std::move(permutation)), - resultSortedOn()}; + while (true) { + // Wait for read phase. + write.wait(true); + if (std::holds_alternative(storage)) { + break; + } + if (std::holds_alternative(storage)) { + std::rethrow_exception(std::get(storage)); + } + co_yield std::get(storage); + // Initiate write phase. + write.test_and_set(); + write.notify_one(); + } + }(std::move(action), std::move(permutation)), + resultSortedOn()}; } else { auto [idTable, localVocab] = action(ad_utility::noop); if (!permutation.empty()) {