From 3d300e9b783e46d7063c13c7c0f085d2a79144a1 Mon Sep 17 00:00:00 2001
From: xiaoxmeng
Date: Thu, 29 Feb 2024 14:34:53 -0800
Subject: [PATCH] Hash join changes for probe side spilling support

The hash join related changes made to support probe side spilling:

- The hash build operator now always waits for the hash probe to finish,
  whether or not there are pending spilled partitions to restore. The
  reason is that the hash probe side might trigger spilling as well. This
  requires a change in the hash build operator and the corresponding API
  change in the hash join bridge: setHashTable.
- Extend the hash join bridge with a setSpilledHashTable API, which is
  used by the hash probe operator to save the spilled table partitions in
  the hash join bridge. The function also clears the spilled hash table
  cached in the build result in the hash join bridge to release the held
  memory resources.
- Extend the hash join node reclaimer to reclaim from both the build and
  the probe side once. The actual probe side spilling is in a followup.
---
 velox/exec/HashBuild.cpp                |  15 +-
 velox/exec/HashBuild.h                  |   8 +-
 velox/exec/HashJoinBridge.cpp           |  79 ++++++++---
 velox/exec/HashJoinBridge.h             |  29 ++--
 velox/exec/HashTable.cpp                |   9 +-
 velox/exec/HashTable.h                  |   8 +-
 velox/exec/tests/HashJoinBridgeTest.cpp | 176 ++++++++++++++++--------
 velox/exec/tests/HashJoinTest.cpp       | 176 ++----------------------
 velox/exec/tests/HashTableTest.cpp      |  11 +-
 9 files changed, 236 insertions(+), 275 deletions(-)
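In outline, one build/probe round trip through the bridge now looks as follows. This is an illustrative sketch, not part of the patch: the wrapper function and the single-builder setup are assumptions, while the bridge calls are the APIs changed or added in the diffs below.

// Illustrative sketch (not part of the patch): one build/probe round trip
// through HashJoinBridge under the new contract. Assumes a started bridge
// with one registered builder; names come from the APIs in this diff.
void exampleRoundTrip(
    HashJoinBridge& bridge,
    std::unique_ptr<BaseHashTable> table,
    SpillPartitionSet buildSpillSet) {
  // Build side: publish the table plus any partitions spilled during build.
  // setHashTable() no longer reports whether spill data remains; the build
  // side discovers that below via spillInputOrFuture().
  bridge.setHashTable(std::move(table), std::move(buildSpillSet), false);

  // Build side: always blocks on the probe side now, even with no spilled
  // partitions, because probing may add more via setSpilledHashTable().
  ContinueFuture future{ContinueFuture::makeEmpty()};
  auto spillInput = bridge.spillInputOrFuture(&future);
  VELOX_CHECK(!spillInput.has_value()); // 'future' completes on probeFinished().

  // Probe side: optionally park partitions spilled during probing, then
  // signal completion, which wakes up the waiting build operators.
  // bridge.setSpilledHashTable(std::move(probeSpillSet));
  bridge.probeFinished();
}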
diff --git a/velox/exec/HashBuild.cpp b/velox/exec/HashBuild.cpp
index a1805e69cd2e..7ec2a62c505a 100644
--- a/velox/exec/HashBuild.cpp
+++ b/velox/exec/HashBuild.cpp
@@ -845,8 +845,9 @@ bool HashBuild::finishHashBuild() {
       isInputFromSpill() ? spillConfig()->startPartitionBit
                          : BaseHashTable::kNoSpillInputStartPartitionBit);
   addRuntimeStats();
-  if (joinBridge_->setHashTable(
-          std::move(table_), std::move(spillPartitions), joinHasNullKeys_)) {
+  joinBridge_->setHashTable(
+      std::move(table_), std::move(spillPartitions), joinHasNullKeys_);
+  if (spillEnabled()) {
     intermediateStateCleared_ = true;
     spillGroup_->restart();
   }
@@ -1222,11 +1223,11 @@ void HashBuild::reclaim(
 }
 
 bool HashBuild::nonReclaimableState() const {
-  // Apart from being in the nonReclaimable section,
-  // its also not reclaimable if:
-  // 1) the hash table has been built by the last build thread (inidicated
-  // by state_)
-  // 2) the last build operator has transferred ownership of 'this' operator's
+  // Apart from being in the nonReclaimable section, it's also not reclaimable
+  // if:
+  // 1) the hash table has been built by the last build thread (indicated by
+  //    state_)
+  // 2) the last build operator has transferred ownership of 'this' operator's
   //    intermediate state (table_ and spiller_) to itself
   // 3) it has completed spilling before reaching either of the previous
   //    two states.
diff --git a/velox/exec/HashBuild.h b/velox/exec/HashBuild.h
index 84366f5cff2b..ae736e8ffffb 100644
--- a/velox/exec/HashBuild.h
+++ b/velox/exec/HashBuild.h
@@ -264,8 +264,8 @@ class HashBuild final : public Operator {
   // The row type used for hash table build and disk spilling.
   RowTypePtr tableType_;
 
-  // Used to serialize access to intermediate state variables (like 'table_' and
-  // 'spiller_'). This is only required when variables are accessed
+  // Used to serialize access to internal state including 'table_' and
+  // 'spiller_'. This is only required when variables are accessed
   // concurrently, that is, when a thread tries to close the operator while
   // another thread is building the hash table. Refer to 'close()' and
   // 'finishHashBuild()' for more details.
@@ -316,8 +316,8 @@ class HashBuild final : public Operator {
   uint64_t numSpillBytes_{0};
 
   // This can be nullptr if either spilling is not allowed or it has been
-  // trsnaferred to the last hash build operator while in kWaitForBuild state or
-  // it has been cleared to setup a new one for recursive spilling.
+  // transferred to the last hash build operator while in kWaitForBuild state or
+  // it has been cleared to set up a new one for recursive spilling.
   std::unique_ptr<Spiller> spiller_;
 
   // Used to read input from previously spilled data for restoring.
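The build-side consequence of the bridge change is easiest to see as a loop. This is a hypothetical, simplified sketch (a real operator returns a blocked state rather than waiting inline); 'joinBridge_' and the SpillInput contract are from this patch.

// Hypothetical sketch of the build-side loop implied by this patch: after
// publishing the table, the last build operator always waits for the probe
// side, then either restores the next spilled partition or finishes.
for (;;) {
  ContinueFuture future{ContinueFuture::makeEmpty()};
  auto input = joinBridge_->spillInputOrFuture(&future);
  if (!input.has_value()) {
    future.wait(); // Probe side still running; woken by probeFinished().
    continue;
  }
  if (input->spillPartition == nullptr) {
    break; // No spill data left anywhere: the join is fully done.
  }
  // Restore 'input->spillPartition' and run another build/probe round.
}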
diff --git a/velox/exec/HashJoinBridge.cpp b/velox/exec/HashJoinBridge.cpp
index 464f7dbff100..e9a717cbcbd9 100644
--- a/velox/exec/HashJoinBridge.cpp
+++ b/velox/exec/HashJoinBridge.cpp
@@ -29,7 +29,7 @@ void HashJoinBridge::addBuilder() {
   ++numBuilders_;
 }
 
-bool HashJoinBridge::setHashTable(
+void HashJoinBridge::setHashTable(
     std::unique_ptr<BaseHashTable> table,
     SpillPartitionSet spillPartitionSet,
     bool hasNullKeys) {
@@ -37,7 +37,6 @@ bool HashJoinBridge::setHashTable(
 
   auto spillPartitionIdSet = toSpillPartitionIdSet(spillPartitionSet);
 
-  bool hasSpillData;
   std::vector<ContinuePromise> promises;
   {
     std::lock_guard<std::mutex> l(mutex_);
@@ -64,12 +63,25 @@ bool HashJoinBridge::setHashTable(
         std::move(spillPartitionIdSet),
         hasNullKeys);
     restoringSpillPartitionId_.reset();
-
-    hasSpillData = !spillPartitionSets_.empty();
     promises = std::move(promises_);
   }
   notify(std::move(promises));
-  return hasSpillData;
+}
+
+void HashJoinBridge::setSpilledHashTable(SpillPartitionSet spillPartitionSet) {
+  VELOX_CHECK(
+      !spillPartitionSet.empty(), "Spilled table partitions can't be empty");
+  std::lock_guard<std::mutex> l(mutex_);
+  VELOX_CHECK(started_);
+  VELOX_CHECK(buildResult_.has_value());
+  VELOX_CHECK(restoringSpillShards_.empty());
+  VELOX_CHECK(!restoringSpillPartitionId_.has_value());
+
+  for (auto& partitionEntry : spillPartitionSet) {
+    const auto id = partitionEntry.first;
+    VELOX_CHECK_EQ(spillPartitionSets_.count(id), 0);
+    spillPartitionSets_.emplace(id, std::move(partitionEntry.second));
+  }
 }
 
 void HashJoinBridge::setAntiJoinHasNullKeys() {
@@ -131,10 +143,8 @@ bool HashJoinBridge::probeFinished() {
         spillPartitionSets_.begin()->second->split(numBuilders_);
     VELOX_CHECK_EQ(restoringSpillShards_.size(), numBuilders_);
     spillPartitionSets_.erase(spillPartitionSets_.begin());
-      promises = std::move(promises_);
-    } else {
-      VELOX_CHECK(promises_.empty());
     }
+    promises = std::move(promises_);
   }
   notify(std::move(promises));
   return hasSpillInput;
@@ -148,15 +158,23 @@ std::optional<HashJoinBridge::SpillInput> HashJoinBridge::spillInputOrFuture(
   VELOX_DCHECK(
       !restoringSpillPartitionId_.has_value() || !buildResult_.has_value());
 
+  // If 'buildResult_' is set, then the probe side is still processing it. The
+  // build side just waits.
+  if (buildResult_.has_value()) {
+    VELOX_CHECK(!restoringSpillPartitionId_.has_value());
+    promises_.emplace_back("HashJoinBridge::spillInputOrFuture");
+    *future = promises_.back().getSemiFuture();
+    return std::nullopt;
+  }
+
+  // If 'restoringSpillPartitionId_' is not set after the probe side is done,
+  // then the join processing is all done.
   if (!restoringSpillPartitionId_.has_value()) {
-    if (spillPartitionSets_.empty()) {
-      return HashJoinBridge::SpillInput{};
-    } else {
-      promises_.emplace_back("HashJoinBridge::spillInputOrFuture");
-      *future = promises_.back().getSemiFuture();
-      return std::nullopt;
-    }
+    VELOX_CHECK(spillPartitionSets_.empty());
+    VELOX_CHECK(restoringSpillShards_.empty());
+    return HashJoinBridge::SpillInput{};
   }
+  VELOX_CHECK(!restoringSpillShards_.empty());
   auto spillShard = std::move(restoringSpillShards_.back());
   restoringSpillShards_.pop_back();
@@ -175,17 +193,30 @@ uint64_t HashJoinMemoryReclaimer::reclaim(
     uint64_t targetBytes,
     uint64_t maxWaitMs,
     memory::MemoryReclaimer::Stats& stats) {
+  // Flags to track whether we have reclaimed from the build and the probe
+  // operators under a hash join node.
+  bool hasReclaimedFromBuild{false};
+  bool hasReclaimedFromProbe{false};
   uint64_t reclaimedBytes{0};
   pool->visitChildren([&](memory::MemoryPool* child) {
     VELOX_CHECK_EQ(child->kind(), memory::MemoryPool::Kind::kLeaf);
-    // The hash probe operator do not support memory reclaim.
-    if (!isHashBuildMemoryPool(*child)) {
-      return true;
+    const bool isBuild = isHashBuildMemoryPool(*child);
+    if (isBuild) {
+      if (!hasReclaimedFromBuild) {
+        // We only need to reclaim from one of the hash build operators.
+        hasReclaimedFromBuild = true;
+        reclaimedBytes += child->reclaim(targetBytes, maxWaitMs, stats);
+      }
+      return !hasReclaimedFromProbe;
     }
-    // We only need to reclaim from any one of the hash build operators
-    // which will reclaim from all the peer hash build operators.
-    reclaimedBytes = child->reclaim(targetBytes, maxWaitMs, stats);
-    return false;
+
+    if (!hasReclaimedFromProbe) {
+      // As with the build side, we only need to reclaim from one of the hash
+      // probe operators.
+      hasReclaimedFromProbe = true;
+      reclaimedBytes += child->reclaim(targetBytes, maxWaitMs, stats);
+    }
+    return !hasReclaimedFromBuild;
   });
   return reclaimedBytes;
 }
@@ -193,4 +224,8 @@ uint64_t HashJoinMemoryReclaimer::reclaim(
 bool isHashBuildMemoryPool(const memory::MemoryPool& pool) {
   return folly::StringPiece(pool.name()).endsWith("HashBuild");
 }
+
+bool isHashProbeMemoryPool(const memory::MemoryPool& pool) {
+  return folly::StringPiece(pool.name()).endsWith("HashProbe");
+}
 } // namespace facebook::velox::exec
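The reclaimer's early-exit traversal above can be exercised standalone. A minimal, runnable sketch (assumes C++20 for std::string::ends_with; plain strings stand in for the child memory pool names):

#include <iostream>
#include <string>
#include <vector>

int main() {
  // Child pool names under a hash join node pool; ordering is arbitrary.
  const std::vector<std::string> children{
      "op.1.HashProbe", "op.2.HashBuild", "op.3.HashBuild", "op.4.HashProbe"};
  bool reclaimedBuild = false;
  bool reclaimedProbe = false;
  for (const auto& name : children) {
    if (name.ends_with("HashBuild")) {
      if (!reclaimedBuild) {
        reclaimedBuild = true; // child->reclaim(...) in the real code; one
        std::cout << "reclaim build via " << name << '\n'; // build reclaims all peers.
      }
    } else if (name.ends_with("HashProbe") && !reclaimedProbe) {
      reclaimedProbe = true;
      std::cout << "reclaim probe via " << name << '\n';
    }
    if (reclaimedBuild && reclaimedProbe) {
      break; // Mirrors the visitChildren() callback returning false.
    }
  }
  return 0;
}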
diff --git a/velox/exec/HashJoinBridge.h b/velox/exec/HashJoinBridge.h
index 899f8fa5c63f..1af8de0814f6 100644
--- a/velox/exec/HashJoinBridge.h
+++ b/velox/exec/HashJoinBridge.h
@@ -22,6 +22,10 @@
 
 namespace facebook::velox::exec {
 
+namespace test {
+class HashJoinBridgeTestHelper;
+}
+
 /// Hands over a hash table from a multi-threaded build pipeline to a
 /// multi-threaded probe pipeline. This is owned by shared_ptr by all the build
 /// and probe Operator instances concerned. Corresponds to the Presto concept of
@@ -35,15 +39,20 @@ class HashJoinBridge : public JoinBridge {
   /// HashBuild operators to parallelize the restoring operation.
   void addBuilder();
 
+  /// Invoked by the build operator to set the built hash table.
   /// 'spillPartitionSet' contains the spilled partitions while building
-  /// 'table'. The function returns true if there is spill data to restore
-  /// after HashProbe operators process 'table', otherwise false. This only
-  /// applies if the disk spilling is enabled.
-  bool setHashTable(
+  /// 'table', which only applies if disk spilling is enabled.
+  void setHashTable(
      std::unique_ptr<BaseHashTable> table,
      SpillPartitionSet spillPartitionSet,
      bool hasNullKeys);
 
+  /// Invoked by the probe operator to set the hash table spilled during
+  /// probing. The function puts the spilled table partitions into the
+  /// 'spillPartitionSets_' stack. This only applies if disk spilling is
+  /// enabled.
+  void setSpilledHashTable(SpillPartitionSet spillPartitionSet);
+
   void setAntiJoinHasNullKeys();
 
   /// Represents the result of HashBuild operators: a hash table, an optional
@@ -75,8 +84,7 @@ class HashJoinBridge : public JoinBridge {
   /// HashBuild operators. If HashProbe operator calls this early, 'future' will
   /// be set to wait asynchronously, otherwise the built table along with
   /// optional spilling related information will be returned in HashBuildResult.
-  std::optional<HashBuildResult> tableOrFuture(
-      ContinueFuture* FOLLY_NONNULL future);
+  std::optional<HashBuildResult> tableOrFuture(ContinueFuture* future);
 
   /// Invoked by HashProbe operator after finishes probing the built table to
   /// set one of the previously spilled partition to restore. The HashBuild
@@ -102,8 +110,7 @@ class HashJoinBridge : public JoinBridge {
   /// If HashBuild operator calls this early, 'future' will be set to wait
   /// asynchronously. If there is no more spill data to restore, then
   /// 'spillPartition' will be set to null in the returned SpillInput.
-  std::optional<SpillInput> spillInputOrFuture(
-      ContinueFuture* FOLLY_NONNULL future);
+  std::optional<SpillInput> spillInputOrFuture(ContinueFuture* future);
 
  private:
   uint32_t numBuilders_{0};
@@ -129,6 +136,8 @@ class HashJoinBridge : public JoinBridge {
   // This set can grow if HashBuild operator cannot load full partition in
   // memory and engages in recursive spilling.
   SpillPartitionSet spillPartitionSets_;
+
+  friend test::HashJoinBridgeTestHelper;
 };
 
 // Indicates if 'joinNode' is null-aware anti or left semi project join type and
@@ -156,4 +165,8 @@ class HashJoinMemoryReclaimer final : public MemoryReclaimer {
 /// Returns true if 'pool' is a hash build operator's memory pool. The check is
 /// currently based on the pool name.
 bool isHashBuildMemoryPool(const memory::MemoryPool& pool);
+
+/// Returns true if 'pool' is a hash probe operator's memory pool. The check is
+/// currently based on the pool name.
+bool isHashProbeMemoryPool(const memory::MemoryPool& pool);
 } // namespace facebook::velox::exec
diff --git a/velox/exec/HashTable.cpp b/velox/exec/HashTable.cpp
index 44cd1e82e273..56f3d68376b7 100644
--- a/velox/exec/HashTable.cpp
+++ b/velox/exec/HashTable.cpp
@@ -726,11 +726,16 @@ void HashTable<ignoreNullKeys>::allocateTables(uint64_t size) {
 }
 
 template <bool ignoreNullKeys>
-void HashTable<ignoreNullKeys>::clear() {
+void HashTable<ignoreNullKeys>::clear(bool freeTable) {
   rows_->clear();
   if (table_) {
     // All modes have 8 bytes per slot.
-    memset(table_, 0, capacity_ * sizeof(char*));
+    if (!freeTable) {
+      ::memset(table_, 0, capacity_ * sizeof(char*));
+    } else {
+      rows_->pool()->freeContiguous(tableAllocation_);
+      table_ = nullptr;
+    }
   }
   numDistinct_ = 0;
   numTombstones_ = 0;
diff --git a/velox/exec/HashTable.h b/velox/exec/HashTable.h
index eec394caf599..4a828ce1bfde 100644
--- a/velox/exec/HashTable.h
+++ b/velox/exec/HashTable.h
@@ -258,9 +258,9 @@ class BaseHashTable {
   /// owned by 'this'.
   virtual int64_t allocatedBytes() const = 0;
 
-  /// Deletes any content of 'this' but does not free the memory. Can
-  /// be used for flushing a partial group by, for example.
-  virtual void clear() = 0;
+  /// Deletes any content of 'this'. If 'freeTable' is false, the table itself
+  /// is not freed, which can be used for flushing a partial group by, for example.
+  virtual void clear(bool freeTable = false) = 0;
 
   /// Returns the capacity of the internal hash table which is number of rows
   /// it can stores in a group by or hash join build.
@@ -498,7 +498,7 @@ class HashTable : public BaseHashTable {
       int32_t maxRows,
       char** rows) override;
 
-  void clear() override;
+  void clear(bool freeTable = false) override;
 
   int64_t allocatedBytes() const override {
     // For each row: sizeof(char*) per table entry + memory
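To make the new 'freeTable' flag concrete, a short hedged sketch (the helper below is hypothetical; clear() is the API changed above):

// Hypothetical helper contrasting the two clear() modes added above; 'table'
// is assumed to be a built hash table.
void resetTable(BaseHashTable& table, bool releaseMemory) {
  // freeTable = false: drop contents but keep the bucket array allocated,
  // e.g. when flushing a partial group by that will refill the table.
  // freeTable = true: also free the bucket array and return the memory,
  // e.g. once a spilled hash table has been handed off to the join bridge.
  table.clear(/*freeTable=*/releaseMemory);
}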
diff --git a/velox/exec/tests/HashJoinBridgeTest.cpp b/velox/exec/tests/HashJoinBridgeTest.cpp
index 9ad04c331a5f..e3874da9ab79 100644
--- a/velox/exec/tests/HashJoinBridgeTest.cpp
+++ b/velox/exec/tests/HashJoinBridgeTest.cpp
@@ -24,6 +24,26 @@ using namespace facebook::velox;
 using namespace facebook::velox::exec;
 using facebook::velox::exec::test::TempDirectoryPath;
 
+namespace facebook::velox::exec::test {
+
+class HashJoinBridgeTestHelper {
+ public:
+  static HashJoinBridgeTestHelper create(HashJoinBridge* bridge) {
+    return HashJoinBridgeTestHelper(bridge);
+  }
+
+  std::optional<HashJoinBridge::HashBuildResult>& buildResult() const {
+    return bridge_->buildResult_;
+  }
+
+ private:
+  explicit HashJoinBridgeTestHelper(HashJoinBridge* bridge) : bridge_(bridge) {
+    VELOX_CHECK_NOT_NULL(bridge_);
+  }
+
+  HashJoinBridge* const bridge_;
+};
+
 struct TestParam {
   int32_t numProbers{1};
   int32_t numBuilders{1};
@@ -150,17 +170,20 @@ TEST_P(HashJoinBridgeTest, withoutSpill) {
     auto joinBridge = createJoinBridge();
     // Can't call any other APIs except addBuilder() before start a join bridge
     // first.
-    ASSERT_ANY_THROW(
-        joinBridge->setHashTable(createFakeHashTable(), {}, false));
-    ASSERT_ANY_THROW(joinBridge->setAntiJoinHasNullKeys());
-    ASSERT_ANY_THROW(joinBridge->probeFinished());
-    ASSERT_ANY_THROW(joinBridge->tableOrFuture(&futures[0]));
-    ASSERT_ANY_THROW(joinBridge->spillInputOrFuture(&futures[0]));
+    VELOX_ASSERT_THROW(
+        joinBridge->setHashTable(createFakeHashTable(), {}, false), "");
+    VELOX_ASSERT_THROW(joinBridge->setAntiJoinHasNullKeys(), "");
+    VELOX_ASSERT_THROW(joinBridge->probeFinished(), "");
+    VELOX_ASSERT_THROW(joinBridge->tableOrFuture(&futures[0]), "");
+    VELOX_ASSERT_THROW(joinBridge->spillInputOrFuture(&futures[0]), "");
+    VELOX_ASSERT_THROW(
+        joinBridge->setSpilledHashTable(makeFakeSpillPartitionSet(0)), "");
 
     // Can't start a bridge without any builders.
- ASSERT_ANY_THROW(joinBridge->start()); + VELOX_ASSERT_THROW(joinBridge->start(), ""); joinBridge = createJoinBridge(); + auto helper = HashJoinBridgeTestHelper::create(joinBridge.get()); for (int32_t i = 0; i < numBuilders_; ++i) { joinBridge->addBuilder(); @@ -172,18 +195,20 @@ TEST_P(HashJoinBridgeTest, withoutSpill) { ASSERT_FALSE(tableOr.has_value()); ASSERT_TRUE(futures[i].valid()); } + ASSERT_FALSE(helper.buildResult().has_value()); BaseHashTable* rawTable = nullptr; if (hasNullKeys) { joinBridge->setAntiJoinHasNullKeys(); - ASSERT_ANY_THROW(joinBridge->setAntiJoinHasNullKeys()); + VELOX_ASSERT_THROW(joinBridge->setAntiJoinHasNullKeys(), ""); } else { auto table = createFakeHashTable(); rawTable = table.get(); joinBridge->setHashTable(std::move(table), {}, false); - ASSERT_ANY_THROW( - joinBridge->setHashTable(createFakeHashTable(), {}, false)); + VELOX_ASSERT_THROW( + joinBridge->setHashTable(createFakeHashTable(), {}, false), ""); } + ASSERT_TRUE(helper.buildResult().has_value()); for (int32_t i = 0; i < numProbers_; ++i) { futures[i].wait(); @@ -208,16 +233,28 @@ TEST_P(HashJoinBridgeTest, withoutSpill) { ASSERT_TRUE(tableOr.value().spillPartitionIds.empty()); } } + ASSERT_TRUE(helper.buildResult().has_value()); - // Verify builder will see no spill input. + // Verify builder will wait for probe side finish signal even if there is no + // spill input. auto inputOr = joinBridge->spillInputOrFuture(&futures[0]); + ASSERT_FALSE(inputOr.has_value()); + ASSERT_TRUE(futures[0].valid()); + + // Probe side completion. + ASSERT_FALSE(joinBridge->probeFinished()); + ASSERT_FALSE(helper.buildResult().has_value()); + + futures[0].wait(); + + futures = createEmptyFutures(1); + inputOr = joinBridge->spillInputOrFuture(&futures[0]); ASSERT_TRUE(inputOr.has_value()); ASSERT_FALSE(futures[0].valid()); ASSERT_TRUE(inputOr.value().spillPartition == nullptr); - // Probe side completion. - ASSERT_FALSE(joinBridge->probeFinished()); - ASSERT_ANY_THROW(joinBridge->probeFinished()); + VELOX_ASSERT_THROW(joinBridge->probeFinished(), ""); + ASSERT_FALSE(helper.buildResult().has_value()); } } @@ -269,15 +306,23 @@ TEST_P(HashJoinBridgeTest, withSpill) { } else { spillPartitionSet = makeFakeSpillPartitionSet(startPartitionBitOffset_); } - bool hasMoreSpill; + + bool spillByProber{false}; + bool hasMoreSpill{false}; if (spillLevel >= testData.spillLevel && testData.endWithNull) { joinBridge->setAntiJoinHasNullKeys(); hasMoreSpill = false; } else { numSpilledPartitions += spillPartitionSet.size(); - spillPartitionIdSet = toSpillPartitionIdSet(spillPartitionSet); - hasMoreSpill = joinBridge->setHashTable( - createFakeHashTable(), std::move(spillPartitionSet), false); + if (oneIn(2)) { + spillPartitionIdSet = toSpillPartitionIdSet(spillPartitionSet); + joinBridge->setHashTable( + createFakeHashTable(), std::move(spillPartitionSet), false); + } else { + spillByProber = !spillPartitionSet.empty(); + joinBridge->setHashTable(createFakeHashTable(), {}, false); + } + hasMoreSpill = numSpilledPartitions > numRestoredPartitions; } // Get built table from probe side. @@ -296,22 +341,33 @@ TEST_P(HashJoinBridgeTest, withSpill) { } // Wait for probe to complete from build side. 
+ for (int32_t i = 0; i < numBuilders_; ++i) { + ASSERT_FALSE( + joinBridge->spillInputOrFuture(&buildFutures[i]).has_value()); + } + + if (spillByProber) { + VELOX_ASSERT_THROW( + joinBridge->setSpilledHashTable({}), + "Spilled table partitions can't be empty"); + joinBridge->setSpilledHashTable(std::move(spillPartitionSet)); + } + + // Probe table. + ASSERT_EQ(hasMoreSpill, joinBridge->probeFinished()); + // Probe can't set spilled table partitions after it finishes probe. + VELOX_ASSERT_THROW( + joinBridge->setSpilledHashTable({}), + "Spilled table partitions can't be empty"); + VELOX_ASSERT_THROW( + joinBridge->setSpilledHashTable(makeFakeSpillPartitionSet(0)), ""); + if (!hasMoreSpill) { for (int32_t i = 0; i < numBuilders_; ++i) { auto inputOr = joinBridge->spillInputOrFuture(&buildFutures[i]); ASSERT_TRUE(inputOr.has_value()); ASSERT_TRUE(inputOr.value().spillPartition == nullptr); } - } else { - for (int32_t i = 0; i < numBuilders_; ++i) { - ASSERT_FALSE( - joinBridge->spillInputOrFuture(&buildFutures[i]).has_value()); - } - } - - // Probe table. - ASSERT_EQ(hasMoreSpill, joinBridge->probeFinished()); - if (!hasMoreSpill) { break; } @@ -428,23 +484,20 @@ TEST_P(HashJoinBridgeTest, multiThreading) { tableOr = joinBridge->tableOrFuture(&tableFuture); ASSERT_TRUE(tableOr.has_value()); } - if (tableOr.value().hasNullKeys) { - break; - } - ASSERT_TRUE(tableOr.value().table != nullptr); - for (const auto& id : tableOr.value().spillPartitionIds) { - ASSERT_FALSE(spillPartitionIdSet.contains(id)); - spillPartitionIdSet.insert(id); - } - if (tableOr.value().restoredPartitionId.has_value()) { - ASSERT_TRUE(spillPartitionIdSet.contains( - tableOr.value().restoredPartitionId.value())); - spillPartitionIdSet.erase( - tableOr.value().restoredPartitionId.value()); - } - - if (spillPartitionIdSet.empty()) { - break; + if (!tableOr.value().hasNullKeys) { + ASSERT_TRUE(tableOr.value().table != nullptr); + for (const auto& id : tableOr.value().spillPartitionIds) { + ASSERT_FALSE(spillPartitionIdSet.contains(id)); + spillPartitionIdSet.insert(id); + } + if (tableOr.value().restoredPartitionId.has_value()) { + ASSERT_TRUE(spillPartitionIdSet.contains( + tableOr.value().restoredPartitionId.value())); + spillPartitionIdSet.erase( + tableOr.value().restoredPartitionId.value()); + } + } else { + spillPartitionIdSet.clear(); } // Wait for probe to finish. 
@@ -465,11 +518,15 @@
           probeFuture.wait();
         } else {
           proberBarrier.reset(new BarrierState());
-          ASSERT_TRUE(joinBridge->probeFinished());
+          ASSERT_EQ(
+              joinBridge->probeFinished(), !spillPartitionIdSet.empty());
           for (auto& promise : promises) {
             promise.setValue();
           }
         }
+        if (spillPartitionIdSet.empty()) {
+          break;
+        }
       }
     });
   }
@@ -483,29 +540,32 @@
   }
 }
 
-TEST_P(HashJoinBridgeTest, isHashBuildMemoryPool) {
+TEST_P(HashJoinBridgeTest, isHashJoinMemoryPools) {
   auto root = memory::memoryManager()->addRootPool("isHashBuildMemoryPool");
   struct {
     std::string poolName;
-    bool expectedHashBuildPool;
+    bool isHashBuildPool;
+    bool isHashProbePool;
 
     std::string debugString() const {
       return fmt::format(
-          "poolName: {}, expectedHashBuildPool: {}",
+          "poolName: {}, isHashBuildPool: {}, isHashProbePool: {}",
          poolName,
-          expectedHashBuildPool);
+          isHashBuildPool,
+          isHashProbePool);
     }
   } testSettings[] = {
-      {"HashBuild", true},
-      {"HashBuildd", false},
-      {"hHashBuild", true},
-      {"hHashProbe", false},
-      {"HashProbe", false},
-      {"HashProbeh", false}};
+      {"HashBuild", true, false},
+      {"HashBuildd", false, false},
+      {"hHashBuild", true, false},
+      {"hHashProbe", false, true},
+      {"HashProbe", false, true},
+      {"HashProbeh", false, false}};
 
   for (const auto& testData : testSettings) {
     SCOPED_TRACE(testData.debugString());
     const auto pool = root->addLeafChild(testData.poolName);
-    ASSERT_EQ(isHashBuildMemoryPool(*pool), testData.expectedHashBuildPool);
+    ASSERT_EQ(isHashBuildMemoryPool(*pool), testData.isHashBuildPool);
+    ASSERT_EQ(isHashProbeMemoryPool(*pool), testData.isHashProbePool);
   }
 }
 
@@ -513,3 +573,5 @@ VELOX_INSTANTIATE_TEST_SUITE_P(
     HashJoinBridgeTest,
     HashJoinBridgeTest,
     testing::ValuesIn(HashJoinBridgeTest::getTestParams()));
+
+} // namespace facebook::velox::exec::test
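The HashJoinTest.cpp changes that follow replace folly::EventCount prepareWait()/wait() pairs with an atomic flag polled by await(). A standalone, runnable sketch of that pattern (assumes folly is available):

// Standalone sketch of the folly::EventCount pattern adopted below: pair the
// event with an atomic flag and use await() so that a notify racing ahead of
// the waiter cannot be lost, unlike a pre-armed prepareWait()/wait() pair.
#include <folly/experimental/EventCount.h>
#include <atomic>
#include <thread>

int main() {
  std::atomic_bool ready{false};
  folly::EventCount event;
  std::thread producer([&] {
    ready = true;      // Publish the state change first...
    event.notifyAll(); // ...then wake any waiter; harmless if none yet.
  });
  event.await([&] { return ready.load(); }); // Rechecks the flag; no lost wakeup.
  producer.join();
  return 0;
}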
diff --git a/velox/exec/tests/HashJoinTest.cpp b/velox/exec/tests/HashJoinTest.cpp
index 08dc5b8f7604..0cebc7cd9a3c 100644
--- a/velox/exec/tests/HashJoinTest.cpp
+++ b/velox/exec/tests/HashJoinTest.cpp
@@ -24,7 +24,6 @@
 #include "velox/exec/HashBuild.h"
 #include "velox/exec/HashJoinBridge.h"
 #include "velox/exec/PlanNodeStats.h"
-#include "velox/exec/TableScan.h"
 #include "velox/exec/tests/utils/ArbitratorTestUtil.h"
 #include "velox/exec/tests/utils/AssertQueryBuilder.h"
 #include "velox/exec/tests/utils/Cursor.h"
@@ -5498,7 +5497,6 @@ DEBUG_ONLY_TEST_F(HashJoinTest, reclaimDuringAllocation) {
       "facebook::velox::exec::Driver::runInternal::addInput",
       std::function<void(Operator*)>(([&](Operator* testOp) {
         if (testOp->operatorType() != "HashBuild") {
-          ASSERT_FALSE(testOp->canReclaim());
           return;
         }
         op = testOp;
@@ -5746,10 +5744,10 @@ DEBUG_ONLY_TEST_F(HashJoinTest, reclaimDuringWaitForProbe) {
                       concat(probeType_->names(), buildType_->names()))
                   .planNode();
 
+  std::atomic_bool driverWaitFlag{true};
   folly::EventCount driverWait;
-  auto driverWaitKey = driverWait.prepareWait();
+  std::atomic_bool testWaitFlag{true};
   folly::EventCount testWait;
-  auto testWaitKey = testWait.prepareWait();
 
   Operator* op;
   std::atomic<bool> injectSpillOnce{true};
@@ -5780,7 +5778,6 @@ DEBUG_ONLY_TEST_F(HashJoinTest, reclaimDuringWaitForProbe) {
         if (testOp->operatorType() != "HashProbe") {
           return;
         }
-        ASSERT_FALSE(testOp->canReclaim());
         if (!injectOnce.exchange(false)) {
           return;
         }
@@ -5790,11 +5787,12 @@ DEBUG_ONLY_TEST_F(HashJoinTest, reclaimDuringWaitForProbe) {
         const bool reclaimable = op->reclaimableBytes(reclaimableBytes);
         ASSERT_TRUE(reclaimable);
         ASSERT_GT(reclaimableBytes, 0);
-        testWait.notify();
+        testWaitFlag = false;
+        testWait.notifyAll();
         auto* driver = testOp->testingOperatorCtx()->driver();
         auto task = driver->task();
         SuspendedSection suspendedSection(driver);
-        driverWait.wait(driverWaitKey);
+        driverWait.await([&]() { return !driverWaitFlag.load(); });
       })));
 
   std::thread taskThread([&]() {
@@ -5817,7 +5815,7 @@
         .run();
   });
 
-  testWait.wait(testWaitKey);
+  testWait.await([&]() { return !testWaitFlag.load(); });
   ASSERT_TRUE(op != nullptr);
   auto task = op->testingOperatorCtx()->task();
   auto taskPauseWait = task->requestPause();
@@ -5840,7 +5838,8 @@
   // No reclaim as the build operator is not in building table state.
   ASSERT_EQ(usedMemoryBytes, op->pool()->currentBytes());
 
-  driverWait.notify();
+  driverWaitFlag = false;
+  driverWait.notifyAll();
   Task::resume(task);
   task.reset();
 
@@ -7075,165 +7074,6 @@ DEBUG_ONLY_TEST_F(HashJoinTest, reclaimDuringJoinTableBuild) {
   waitForAllTasksToBeDeleted();
 }
 
-// This test is to reproduce a race condition that memory arbitrator tries to
-// reclaim from a set of hash build operators in which the last hash build
-// operator has finished.
-DEBUG_ONLY_TEST_F(HashJoinTest, raceBetweenRaclaimAndJoinFinish) {
-  std::unique_ptr<memory::MemoryManager> memoryManager = createMemoryManager();
-  const auto& arbitrator = memoryManager->arbitrator();
-  auto rowType = ROW({
-      {"c0", INTEGER()},
-      {"c1", INTEGER()},
-      {"c2", VARCHAR()},
-  });
-  // Build a large vector to trigger memory arbitration.
-  fuzzerOpts_.vectorSize = 10'000;
-  std::vector<RowVectorPtr> vectors = createVectors(2, rowType, fuzzerOpts_);
-  createDuckDbTable(vectors);
-
-  std::shared_ptr<core::QueryCtx> joinQueryCtx =
-      newQueryCtx(memoryManager, executor_, kMemoryCapacity);
-
-  auto planNodeIdGenerator = std::make_shared<core::PlanNodeIdGenerator>();
-  core::PlanNodeId planNodeId;
-  auto plan = PlanBuilder(planNodeIdGenerator)
-                  .values(vectors, false)
-                  .project({"c0 AS t0", "c1 AS t1", "c2 AS t2"})
-                  .hashJoin(
-                      {"t0"},
-                      {"u0"},
-                      PlanBuilder(planNodeIdGenerator)
-                          .values(vectors, true)
-                          .project({"c0 AS u0", "c1 AS u1", "c2 AS u2"})
-                          .planNode(),
-                      "",
-                      {"t1"},
-                      core::JoinType::kAnti)
-                  .capturePlanNodeId(planNodeId)
-                  .planNode();
-
-  std::atomic<bool> waitForBuildFinishFlag{true};
-  folly::EventCount waitForBuildFinishEvent;
-  std::atomic<Driver*> lastBuildDriver{nullptr};
-  std::atomic<Task*> task{nullptr};
-  std::atomic<bool> isLastBuildFirstChildPool{false};
-  SCOPED_TESTVALUE_SET(
-      "facebook::velox::exec::HashBuild::finishHashBuild",
-      std::function<void(exec::HashBuild*)>([&](exec::HashBuild* buildOp) {
-        lastBuildDriver = buildOp->testingOperatorCtx()->driver();
-        // Checks if the last build memory pool is the first build pool in its
-        // parent node pool. It is used to check the test result.
-        int buildPoolIndex{0};
-        buildOp->pool()->parent()->visitChildren([&](memory::MemoryPool* pool) {
-          if (pool == buildOp->pool()) {
-            return false;
-          }
-          if (isHashBuildMemoryPool(*pool)) {
-            ++buildPoolIndex;
-          }
-          return true;
-        });
-        isLastBuildFirstChildPool = (buildPoolIndex == 0);
-        task = lastBuildDriver.load()->task().get();
-        waitForBuildFinishFlag = false;
-        waitForBuildFinishEvent.notifyAll();
-      }));
-
-  std::atomic<bool> waitForReclaimFlag{true};
-  folly::EventCount waitForReclaimEvent;
-  SCOPED_TESTVALUE_SET(
-      "facebook::velox::exec::Driver::runInternal",
-      std::function<void(Driver*)>([&](Driver* driver) {
-        auto* op = driver->findOperator(planNodeId);
-        if (op->operatorType() != "HashBuild" &&
-            op->operatorType() != "HashProbe") {
-          return;
-        }
-
-        // Suspend hash probe driver to wait for the test triggered reclaim to
-        // finish.
-        if (op->operatorType() == "HashProbe") {
-          op->pool()->reclaimer()->enterArbitration();
-          waitForReclaimEvent.await(
-              [&]() { return !waitForReclaimFlag.load(); });
-          op->pool()->reclaimer()->leaveArbitration();
-        }
-
-        // Check if we have reached to the last hash build operator or not. The
-        // testvalue callback will set the last build driver.
-        if (lastBuildDriver == nullptr) {
-          return;
-        }
-
-        // Suspend all the remaining hash build drivers until the test triggered
-        // reclaim finish.
-        op->pool()->reclaimer()->enterArbitration();
-        waitForReclaimEvent.await([&]() { return !waitForReclaimFlag.load(); });
-        op->pool()->reclaimer()->leaveArbitration();
-      }));
-
-  const int numDrivers = 4;
-  std::thread queryThread([&]() {
-    const auto spillDirectory = exec::test::TempDirectoryPath::create();
-    AssertQueryBuilder(plan, duckDbQueryRunner_)
-        .maxDrivers(numDrivers)
-        .queryCtx(joinQueryCtx)
-        .spillDirectory(spillDirectory->path)
-        .config(core::QueryConfig::kSpillEnabled, true)
-        .config(core::QueryConfig::kJoinSpillEnabled, true)
-        .assertResults(
-            "SELECT c1 FROM tmp WHERE c0 NOT IN (SELECT c0 FROM tmp)");
-  });
-
-  // Wait for the last hash build operator to start building the hash table.
-  waitForBuildFinishEvent.await([&] { return !waitForBuildFinishFlag.load(); });
-  ASSERT_TRUE(lastBuildDriver != nullptr);
-  ASSERT_TRUE(task != nullptr);
-
-  // Wait until the last build driver gets removed from the task after finishes.
-  while (task.load()->numFinishedDrivers() != 1) {
-    bool foundLastBuildDriver{false};
-    task.load()->testingVisitDrivers([&](Driver* driver) {
-      if (driver == lastBuildDriver) {
-        foundLastBuildDriver = true;
-      }
-    });
-    if (!foundLastBuildDriver) {
-      break;
-    }
-  }
-
-  // Reclaim from the task, and we can't reclaim anything as we don't support
-  // spill after hash table built.
-  memory::MemoryReclaimer::Stats stats;
-  const uint64_t oldCapacity = joinQueryCtx->pool()->capacity();
-  task.load()->pool()->shrink();
-  task.load()->pool()->reclaim(1'000, 0, stats);
-  // If the last build memory pool is first child of its parent memory pool,
-  // then memory arbitration (or join node memory pool) will reclaim from the
-  // last build operator first which simply quits as the driver has gone. If
-  // not, we expect to get numNonReclaimableAttempts from any one of the
-  // remaining hash build operator.
-  if (isLastBuildFirstChildPool) {
-    ASSERT_EQ(stats.numNonReclaimableAttempts, 0);
-  } else {
-    ASSERT_EQ(stats.numNonReclaimableAttempts, 1);
-  }
-  // Make sure we don't leak memory capacity since we reclaim from task pool
-  // directly.
-  static_cast<memory::MemoryPoolImpl*>(task.load()->pool())
-      ->testingSetCapacity(oldCapacity);
-  waitForReclaimFlag = false;
-  waitForReclaimEvent.notifyAll();
-
-  queryThread.join();
-
-  waitForAllTasksToBeDeleted();
-  ASSERT_EQ(arbitrator->stats().numFailures, 0);
-  ASSERT_EQ(arbitrator->stats().numReclaimedBytes, 0);
-  ASSERT_EQ(arbitrator->stats().numReserves, 1);
-}
-
 DEBUG_ONLY_TEST_F(HashJoinTest, joinBuildSpillError) {
   const int kMemoryCapacity = 32 << 20;
   // Set a small memory capacity to trigger spill.
diff --git a/velox/exec/tests/HashTableTest.cpp b/velox/exec/tests/HashTableTest.cpp
index fcad6cd1a1de..0190df05b601 100644
--- a/velox/exec/tests/HashTableTest.cpp
+++ b/velox/exec/tests/HashTableTest.cpp
@@ -166,6 +166,9 @@ class HashTableTest : public testing::TestWithParam<bool>,
     testEraseEveryN(4);
     testProbe();
     testGroupBySpill(size, buildType, numKeys);
+    const auto memoryUsage = pool()->currentBytes();
+    topTable_->clear(true);
+    ASSERT_LT(pool()->currentBytes(), memoryUsage);
   }
 
   // Inserts and deletes rows in a HashTable, similarly to a group by
@@ -622,9 +625,11 @@ TEST_P(HashTableTest, clear) {
       BIGINT(),
       config);
 
-  auto table = HashTable<true>::createForAggregation(
-      std::move(keyHashers), {Accumulator{aggregate.get(), nullptr}}, pool());
-  ASSERT_NO_THROW(table->clear());
+  for (const bool clearTable : {false, true}) {
+    auto table = HashTable<true>::createForAggregation(
+        std::move(keyHashers), {Accumulator{aggregate.get(), nullptr}}, pool());
+    ASSERT_NO_THROW(table->clear(clearTable));
+  }
 }
 
 // Test a specific code path in HashTable::decodeHashMode where