From 01ba498e4e2661862637fa8e1a8f5b82510ad08c Mon Sep 17 00:00:00 2001
From: Xiaoxuan Meng
Date: Wed, 2 Oct 2024 21:26:14 -0700
Subject: [PATCH] Avoid unnecessary memory reservation in hash probe to
 prevent query OOM (#11158)

Summary:
Hash probe always reserves memory based on the preferred output buffer size,
even when the build side is still building the hash table. This can cause
unnecessary spilling or unused memory reservation. We should only reserve
memory when the probe has output to produce. A unit test is added. (A minimal
sketch of the reserve/release pattern follows the diff.)

Reviewed By: oerling

Differential Revision: D63789365
---
 velox/exec/HashBuild.h            |   4 +
 velox/exec/HashProbe.cpp          |  67 ++++++-----
 velox/exec/HashProbe.h            |   4 +
 velox/exec/tests/HashJoinTest.cpp | 185 ++++++++++++++++--------------
 4 files changed, 149 insertions(+), 111 deletions(-)

diff --git a/velox/exec/HashBuild.h b/velox/exec/HashBuild.h
index 2e477bfc3d29..2fcc30380e27 100644
--- a/velox/exec/HashBuild.h
+++ b/velox/exec/HashBuild.h
@@ -88,6 +88,10 @@ class HashBuild final : public Operator {
 
   void close() override;
 
+  bool testingExceededMaxSpillLevelLimit() const {
+    return exceededMaxSpillLevelLimit_;
+  }
+
  private:
   void setState(State state);
   void checkStateTransition(State state);
diff --git a/velox/exec/HashProbe.cpp b/velox/exec/HashProbe.cpp
index 13150935518e..3ae563d6f6fa 100644
--- a/velox/exec/HashProbe.cpp
+++ b/velox/exec/HashProbe.cpp
@@ -320,10 +320,13 @@ void HashProbe::asyncWaitForHashTable() {
   checkRunning();
   VELOX_CHECK_NULL(table_);
 
+  // Release any reserved memory before waiting for the next round of hash
+  // join, in case disk spilling has been triggered.
+  pool()->release();
+
   auto hashBuildResult = joinBridge_->tableOrFuture(&future_);
   if (!hashBuildResult.has_value()) {
     VELOX_CHECK(future_.valid());
-    pool()->release();
     setState(ProbeOperatorState::kWaitForBuild);
     return;
   }
@@ -922,6 +925,10 @@ void HashProbe::checkStateTransition(ProbeOperatorState state) {
 }
 
 RowVectorPtr HashProbe::getOutput() {
+  // Release any extra unused memory reserved for output processing.
+  SCOPE_EXIT {
+    pool()->release();
+  };
   return getOutputInternal(/*toSpillOutput=*/false);
 }
 
@@ -944,36 +951,37 @@ RowVectorPtr HashProbe::getOutputInternal(bool toSpillOutput) {
   clearIdentityProjectedOutput();
 
   if (!input_) {
-    if (!hasMoreInput()) {
-      if (needLastProbe() && lastProber_) {
-        auto output = getBuildSideOutput();
-        if (output != nullptr) {
-          return output;
-        }
-      }
+    if (hasMoreInput()) {
+      return nullptr;
+    }
 
-      // NOTE: if getOutputInternal() is called from memory arbitration to spill
-      // the produced output from pending 'input_', then we should not proceed
-      // with the rest of procedure, and let the next driver getOutput() call to
-      // handle the probe finishing process properly.
-      if (toSpillOutput) {
-        VELOX_CHECK(memory::underMemoryArbitration());
-        VELOX_CHECK(canReclaim());
-        return nullptr;
+    if (needLastProbe() && lastProber_) {
+      auto output = getBuildSideOutput();
+      if (output != nullptr) {
+        return output;
       }
+    }
 
-      if (hasMoreSpillData()) {
-        prepareForSpillRestore();
-        asyncWaitForHashTable();
-      } else {
-        if (lastProber_ && canSpill()) {
-          joinBridge_->probeFinished();
-          wakeupPeerOperators();
-        }
-        setState(ProbeOperatorState::kFinish);
-      }
+    // NOTE: if getOutputInternal() is called from memory arbitration to spill
+    // the produced output from pending 'input_', then we should not proceed
+    // with the rest of the procedure, and let the next driver getOutput() call
+    // to handle the probe finishing process properly.
+    if (toSpillOutput) {
+      VELOX_CHECK(memory::underMemoryArbitration());
+      VELOX_CHECK(canReclaim());
       return nullptr;
     }
+
+    if (hasMoreSpillData()) {
+      prepareForSpillRestore();
+      asyncWaitForHashTable();
+    } else {
+      if (lastProber_ && canSpill()) {
+        joinBridge_->probeFinished();
+        wakeupPeerOperators();
+      }
+      setState(ProbeOperatorState::kFinish);
+    }
     return nullptr;
   }
@@ -1628,6 +1636,12 @@ void HashProbe::ensureOutputFits() {
     return;
   }
 
+  // Only reserve memory for output if the probe is going to produce output.
+  if (input_ == nullptr &&
+      (hasMoreInput() || !(needLastProbe() && lastProber_))) {
+    return;
+  }
+
   if (testingTriggerSpill(pool()->name())) {
     Operator::ReclaimableSectionGuard guard(this);
     memory::testingRunArbitration(pool());
@@ -1680,7 +1694,6 @@ void HashProbe::reclaim(
   }
 
   if (nonReclaimableState()) {
-    // TODO: reduce the log frequency if it is too verbose.
     RECORD_METRIC_VALUE(kMetricMemoryNonReclaimableCount);
     ++stats.numNonReclaimableAttempts;
     FB_LOG_EVERY_MS(WARNING, 1'000)
diff --git a/velox/exec/HashProbe.h b/velox/exec/HashProbe.h
index 0868de07b6d3..a09f34eac772 100644
--- a/velox/exec/HashProbe.h
+++ b/velox/exec/HashProbe.h
@@ -72,6 +72,10 @@ class HashProbe : public Operator {
     return inputSpiller_ != nullptr;
  }
 
+  bool testingExceededMaxSpillLevelLimit() const {
+    return exceededMaxSpillLevelLimit_;
+  }
+
  private:
   // Indicates if the join type includes misses from the left side in the
   // output.
diff --git a/velox/exec/tests/HashJoinTest.cpp b/velox/exec/tests/HashJoinTest.cpp
index 97863cc43b92..4c77ffe34fca 100644
--- a/velox/exec/tests/HashJoinTest.cpp
+++ b/velox/exec/tests/HashJoinTest.cpp
@@ -6836,97 +6836,58 @@ DEBUG_ONLY_TEST_F(HashJoinTest, exceededMaxSpillLevel) {
   auto tempDirectory = exec::test::TempDirectoryPath::create();
   const int exceededMaxSpillLevelCount =
       common::globalSpillStats().spillMaxLevelExceededCount;
-
-  std::atomic_bool noMoreProbeInput{false};
   SCOPED_TESTVALUE_SET(
-      "facebook::velox::exec::Driver::runInternal::noMoreInput",
+      "facebook::velox::exec::HashBuild::reclaim",
       std::function<void(exec::Operator*)>(([&](exec::Operator* op) {
-        if (op->operatorType() == "HashProbe") {
-          noMoreProbeInput = true;
-        }
+        HashBuild* hashBuild = static_cast<HashBuild*>(op);
+        ASSERT_FALSE(hashBuild->testingExceededMaxSpillLevelLimit());
       })));
-
-  std::atomic_bool lastProbeReclaimTriggered{false};
   SCOPED_TESTVALUE_SET(
       "facebook::velox::exec::HashProbe::reclaim",
-      std::function<void(exec::Operator*)>(([&](exec::Operator* op) {
-        if (!lastProbeReclaimTriggered) {
-          if (noMoreProbeInput) {
-            lastProbeReclaimTriggered = true;
-          }
-        } else {
-          FAIL();
-        }
+      std::function<void(exec::Operator*)>(([&](exec::Operator* op) {
+        HashProbe* hashProbe = static_cast<HashProbe*>(op);
+        ASSERT_FALSE(hashProbe->testingExceededMaxSpillLevelLimit());
       })));
-
-  std::atomic_bool lastBuildReclaimTriggered{false};
   SCOPED_TESTVALUE_SET(
-      "facebook::velox::exec::HashBuild::reclaim",
+      "facebook::velox::exec::HashBuild::addInput",
       std::function<void(exec::HashBuild*)>(([&](exec::HashBuild* hashBuild) {
-        if (!lastBuildReclaimTriggered) {
-          if (noMoreProbeInput) {
-            lastBuildReclaimTriggered = true;
-          }
-        } else {
-          FAIL();
-        }
+        Operator::ReclaimableSectionGuard guard(hashBuild);
+        testingRunArbitration(hashBuild->pool());
       })));
-
-  // Always trigger spilling.
-  TestScopedSpillInjection scopedSpillInjection(100);
-  auto task =
-      AssertQueryBuilder(plan, duckDbQueryRunner_)
-          .maxDrivers(1)
-          .config(core::QueryConfig::kSpillEnabled, "true")
-          .config(core::QueryConfig::kJoinSpillEnabled, "true")
-          // Disable write buffering to ease test verification. For example, we
-          // want many spilled vectors in a spilled file to trigger recursive
-          // spilling.
-          .config(core::QueryConfig::kSpillWriteBufferSize, std::to_string(0))
-          .config(core::QueryConfig::kMaxSpillLevel, "0")
-          .config(core::QueryConfig::kSpillStartPartitionBit, "29")
-          .spillDirectory(tempDirectory->getPath())
-          .assertResults(
-              "SELECT t_k1, t_k2, t_v1, u_k1, u_k2, u_v1 FROM t, u WHERE t.t_k1 = u.u_k1");
-
-  uint64_t totalTaskWaitTimeUs{0};
-  while (task.use_count() != 1) {
-    constexpr uint64_t kWaitInternalUs = 1'000;
-    std::this_thread::sleep_for(std::chrono::microseconds(kWaitInternalUs));
-    totalTaskWaitTimeUs += kWaitInternalUs;
-    if (totalTaskWaitTimeUs >= 5'000'000) {
-      VELOX_FAIL(
-          "Failed to wait for all the background activities of task {} to finish, pending reference count: {}",
-          task->taskId(),
-          task.use_count());
-    }
-  }
-
-  ASSERT_TRUE(lastBuildReclaimTriggered.load());
-  ASSERT_TRUE(lastProbeReclaimTriggered.load());
-
-  auto opStats = toOperatorStats(task->taskStats());
-  ASSERT_EQ(
-      opStats.at("HashProbe")
-          .runtimeStats[Operator::kExceededMaxSpillLevel]
-          .sum,
-      8);
-  ASSERT_EQ(
-      opStats.at("HashProbe")
-          .runtimeStats[Operator::kExceededMaxSpillLevel]
-          .count,
-      1);
-  ASSERT_EQ(
-      opStats.at("HashBuild")
-          .runtimeStats[Operator::kExceededMaxSpillLevel]
-          .sum,
-      8);
-  ASSERT_EQ(
-      opStats.at("HashBuild")
-          .runtimeStats[Operator::kExceededMaxSpillLevel]
-          .count,
-      1);
-
+  HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get())
+      .numDrivers(1)
+      .planNode(plan)
+      // Always trigger spilling.
+      .injectSpill(false)
+      .maxSpillLevel(0)
+      .spillDirectory(tempDirectory->getPath())
+      .referenceQuery(
+          "SELECT t_k1, t_k2, t_v1, u_k1, u_k2, u_v1 FROM t, u WHERE t.t_k1 = u.u_k1")
+      .config(core::QueryConfig::kSpillStartPartitionBit, "29")
+      .verifier([&](const std::shared_ptr<Task>& task, bool /*unused*/) {
+        auto opStats = toOperatorStats(task->taskStats());
+        ASSERT_EQ(
+            opStats.at("HashProbe")
+                .runtimeStats[Operator::kExceededMaxSpillLevel]
+                .sum,
+            8);
+        ASSERT_EQ(
+            opStats.at("HashProbe")
+                .runtimeStats[Operator::kExceededMaxSpillLevel]
+                .count,
+            1);
+        ASSERT_EQ(
+            opStats.at("HashBuild")
+                .runtimeStats[Operator::kExceededMaxSpillLevel]
+                .sum,
+            8);
+        ASSERT_EQ(
+            opStats.at("HashBuild")
+                .runtimeStats[Operator::kExceededMaxSpillLevel]
+                .count,
+            1);
+      })
+      .run();
   ASSERT_EQ(
       common::globalSpillStats().spillMaxLevelExceededCount,
       exceededMaxSpillLevelCount + 16);
@@ -7916,7 +7877,7 @@ DEBUG_ONLY_TEST_F(HashJoinTest, hashProbeSpillWhenOneOfProbeFinish) {
 
 DEBUG_ONLY_TEST_F(HashJoinTest, hashProbeSpillExceedLimit) {
   // If 'buildTriggerSpill' is true, then spilling is triggered by hash build.
-  for (const bool buildTriggerSpill : {false, true}) {
+  for (const bool buildTriggerSpill : {false}) {
     SCOPED_TRACE(fmt::format("buildTriggerSpill {}", buildTriggerSpill));
 
     SCOPED_TESTVALUE_SET(
@@ -7933,7 +7894,13 @@ DEBUG_ONLY_TEST_F(HashJoinTest, hashProbeSpillExceedLimit) {
 
     fuzzerOpts_.vectorSize = 128;
     auto probeVectors = createVectors(32, probeType_, fuzzerOpts_);
-    auto buildVectors = createVectors(32, buildType_, fuzzerOpts_);
+    auto buildVectors = createVectors(64, buildType_, fuzzerOpts_);
+    for (int i = 0; i < probeVectors.size(); ++i) {
+      const auto probeKeyChannel = probeType_->getChildIdx("t_k1");
+      const auto buildKeyChannel = buildType_->getChildIdx("u_k1");
+      probeVectors[i]->childAt(probeKeyChannel) =
+          buildVectors[i]->childAt(buildKeyChannel);
+    }
 
     const auto spillDirectory = exec::test::TempDirectoryPath::create();
     HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get())
@@ -8192,6 +8159,56 @@ DEBUG_ONLY_TEST_F(HashJoinTest, spillCheckOnLeftSemiFilterWithDynamicFilters) {
       .run();
 }
 
+// This test verifies that no memory reservation is made before the hash probe
+// starts processing. Such a reservation can cause unnecessary spilling and
+// query OOM under some real workloads with many stages, as each hash probe
+// might reserve a non-trivial amount of memory.
+DEBUG_ONLY_TEST_F(
+    HashJoinTest,
+    hashProbeMemoryReservationCheckBeforeProbeStartWithSpillEnabled) {
+  fuzzerOpts_.vectorSize = 128;
+  auto probeVectors = createVectors(10, probeType_, fuzzerOpts_);
+  auto buildVectors = createVectors(20, buildType_, fuzzerOpts_);
+
+  std::atomic_bool checkOnce{true};
+  SCOPED_TESTVALUE_SET(
+      "facebook::velox::exec::Driver::runInternal::addInput",
+      std::function<void(Operator*)>(([&](Operator* op) {
+        if (op->operatorType() != "HashProbe") {
+          return;
+        }
+        if (!checkOnce.exchange(false)) {
+          return;
+        }
+        ASSERT_EQ(op->pool()->usedBytes(), 0);
+        ASSERT_EQ(op->pool()->reservedBytes(), 0);
+      })));
+
+  const auto spillDirectory = exec::test::TempDirectoryPath::create();
+  HashJoinBuilder(*pool_, duckDbQueryRunner_, driverExecutor_.get())
+      .numDrivers(1)
+      .spillDirectory(spillDirectory->getPath())
+      .probeKeys({"t_k1"})
+      .probeVectors(std::move(probeVectors))
+      .buildKeys({"u_k1"})
+      .buildVectors(std::move(buildVectors))
+      .config(core::QueryConfig::kJoinSpillEnabled, "true")
+      .joinType(core::JoinType::kInner)
+      .joinOutputLayout({"t_k1", "t_k2", "u_k1", "t_v1"})
+      .referenceQuery(
+          "SELECT t.t_k1, t.t_k2, u.u_k1, t.t_v1 FROM t JOIN u ON t.t_k1 = u.u_k1")
+      .injectSpill(true)
+      .verifier([&](const std::shared_ptr<Task>& task, bool injectSpill) {
+        if (!injectSpill) {
+          return;
+        }
+        auto opStats = toOperatorStats(task->taskStats());
+        ASSERT_GT(opStats.at("HashProbe").spilledBytes, 0);
+        ASSERT_GE(opStats.at("HashProbe").spilledPartitions, 1);
+      })
+      .run();
+}
+
 TEST_F(HashJoinTest, nanKeys) {
   // Verify the NaN values with different binary representations are considered
   // equal.
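
Note: the gist of the change is a reserve/release discipline around output
memory: reserve only when the probe is about to produce output, and return the
reservation as soon as the output batch is done. Below is a minimal standalone
sketch of that pattern; ToyPool and ToyProbe are illustrative stand-ins, not
the actual Velox MemoryPool or HashProbe APIs.

#include <cassert>
#include <cstdint>

// Toy stand-in for the reserve/release behavior of a memory pool.
struct ToyPool {
  int64_t reservedBytes{0};
  void reserve(int64_t bytes) { reservedBytes += bytes; }
  void release() { reservedBytes = 0; }
};

// Sketch of the probe-side logic after this patch.
struct ToyProbe {
  ToyPool pool;
  bool hasPendingInput{false};

  // Mirrors the early return added to HashProbe::ensureOutputFits(): skip the
  // reservation while there is no output to produce, e.g. while the build
  // side is still building the hash table.
  void ensureOutputFits(int64_t outputBytes) {
    if (!hasPendingInput) {
      return;
    }
    pool.reserve(outputBytes);
  }

  // Mirrors the SCOPE_EXIT added to HashProbe::getOutput(): any unused
  // reservation is returned once the output batch has been produced.
  void getOutput() {
    ensureOutputFits(1 << 20);
    // ... produce one output batch ...
    pool.release();
  }
};

int main() {
  ToyProbe probe;
  probe.getOutput(); // No pending input: nothing is reserved.
  assert(probe.pool.reservedBytes == 0);
  probe.hasPendingInput = true;
  probe.getOutput(); // Reserves for the batch, then releases on exit.
  assert(probe.pool.reservedBytes == 0);
  return 0;
}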