diff --git a/velox/common/memory/MemoryArbitrator.cpp b/velox/common/memory/MemoryArbitrator.cpp index 9220a8e30c3b..0de3e40bc9b0 100644 --- a/velox/common/memory/MemoryArbitrator.cpp +++ b/velox/common/memory/MemoryArbitrator.cpp @@ -180,6 +180,7 @@ std::unique_ptr MemoryReclaimer::create() { uint64_t MemoryReclaimer::run( const std::function& func, Stats& stats) { + VELOX_CHECK(underMemoryArbitration()); uint64_t execTimeUs{0}; int64_t reclaimedBytes{0}; { diff --git a/velox/common/memory/MemoryArbitrator.h b/velox/common/memory/MemoryArbitrator.h index b587fe16cc37..29d8ce6403ae 100644 --- a/velox/common/memory/MemoryArbitrator.h +++ b/velox/common/memory/MemoryArbitrator.h @@ -453,11 +453,14 @@ template std::shared_ptr> createAsyncMemoryReclaimTask( std::function()> task) { auto* arbitrationCtx = memory::memoryArbitrationContext(); - VELOX_CHECK_NOT_NULL(arbitrationCtx); return std::make_shared>( [asyncTask = std::move(task), arbitrationCtx]() -> std::unique_ptr { - VELOX_CHECK_NOT_NULL(arbitrationCtx); - memory::ScopedMemoryArbitrationContext ctx(arbitrationCtx->requestor); + std::unique_ptr restoreArbitrationCtx; + if (arbitrationCtx != nullptr) { + restoreArbitrationCtx = + std::make_unique( + arbitrationCtx->requestor); + } return asyncTask(); }); } diff --git a/velox/common/memory/SharedArbitrator.cpp b/velox/common/memory/SharedArbitrator.cpp index e2e2c590b86e..938dc9b41e8e 100644 --- a/velox/common/memory/SharedArbitrator.cpp +++ b/velox/common/memory/SharedArbitrator.cpp @@ -556,6 +556,10 @@ uint64_t SharedArbitrator::getCapacityGrowthTarget( } bool SharedArbitrator::growCapacity(MemoryPool* pool, uint64_t requestBytes) { + // NOTE: we shouldn't trigger the recursive memory capacity growth under + // memory arbitration context. 
+ VELOX_CHECK(!underMemoryArbitration()); + ArbitrationOperation op( pool, requestBytes, getCapacityGrowthTarget(*pool, requestBytes)); ScopedArbitration scopedArbitration(this, &op); diff --git a/velox/exec/MemoryReclaimer.cpp b/velox/exec/MemoryReclaimer.cpp index 16bb2d81dbe5..8b64b9498907 100644 --- a/velox/exec/MemoryReclaimer.cpp +++ b/velox/exec/MemoryReclaimer.cpp @@ -122,7 +122,7 @@ uint64_t ParallelMemoryReclaimer::reclaim( if (candidate.reclaimableBytes == 0) { continue; } - reclaimTasks.push_back(std::make_shared>( + reclaimTasks.push_back(memory::createAsyncMemoryReclaimTask( [&, reclaimPool = candidate.pool]() { try { Stats reclaimStats; diff --git a/velox/exec/Spiller.cpp b/velox/exec/Spiller.cpp index 96c38f67e2ed..eee401a4e70b 100644 --- a/velox/exec/Spiller.cpp +++ b/velox/exec/Spiller.cpp @@ -17,6 +17,7 @@ #include "velox/exec/Spiller.h" #include #include "velox/common/base/AsyncSource.h" +#include "velox/common/memory/MemoryArbitrator.h" #include "velox/common/testutil/TestValue.h" #include "velox/exec/Aggregate.h" #include "velox/exec/HashJoinBridge.h" @@ -476,7 +477,7 @@ void Spiller::runSpill(bool lastRun) { if (spillRuns_[partition].rows.empty()) { continue; } - writes.push_back(std::make_shared>( + writes.push_back(memory::createAsyncMemoryReclaimTask( [partition, this]() { return writeSpill(partition); })); if ((writes.size() > 1) && executor_ != nullptr) { executor_->add([source = writes.back()]() { source->prepare(); }); diff --git a/velox/exec/tests/AggregationTest.cpp b/velox/exec/tests/AggregationTest.cpp index de32aad132ee..002635c1a43b 100644 --- a/velox/exec/tests/AggregationTest.cpp +++ b/velox/exec/tests/AggregationTest.cpp @@ -2125,10 +2125,13 @@ DEBUG_ONLY_TEST_F(AggregationTest, reclaimDuringInputProcessing) { if (testData.expectedReclaimable) { const auto usedMemory = op->pool()->usedBytes(); - op->pool()->reclaim( - folly::Random::oneIn(2) ? 
0 : folly::Random::rand32(rng_), - 0, - reclaimerStats_); + { + memory::ScopedMemoryArbitrationContext ctx(op->pool()); + op->pool()->reclaim( + folly::Random::oneIn(2) ? 0 : folly::Random::rand32(rng_), + 0, + reclaimerStats_); + } ASSERT_GT(reclaimerStats_.reclaimExecTimeUs, 0); ASSERT_GT(reclaimerStats_.reclaimedBytes, 0); reclaimerStats_.reset(); @@ -2136,11 +2139,14 @@ DEBUG_ONLY_TEST_F(AggregationTest, reclaimDuringInputProcessing) { // uses some memory. ASSERT_LT(op->pool()->usedBytes(), usedMemory); } else { - VELOX_ASSERT_THROW( - op->reclaim( - folly::Random::oneIn(2) ? 0 : folly::Random::rand32(rng_), - reclaimerStats_), - ""); + { + memory::ScopedMemoryArbitrationContext ctx(op->pool()); + VELOX_ASSERT_THROW( + op->reclaim( + folly::Random::oneIn(2) ? 0 : folly::Random::rand32(rng_), + reclaimerStats_), + ""); + } } Task::resume(task); @@ -2249,10 +2255,13 @@ DEBUG_ONLY_TEST_F(AggregationTest, reclaimDuringReserve) { ASSERT_GT(reclaimableBytes, 0); const auto usedMemory = op->pool()->usedBytes(); - op->pool()->reclaim( - folly::Random::oneIn(2) ? 0 : folly::Random::rand32(rng_), - 0, - reclaimerStats_); + { + memory::ScopedMemoryArbitrationContext ctx(op->pool()); + op->pool()->reclaim( + folly::Random::oneIn(2) ? 0 : folly::Random::rand32(rng_), + 0, + reclaimerStats_); + } ASSERT_GT(reclaimerStats_.reclaimExecTimeUs, 0); ASSERT_GE(reclaimerStats_.reclaimedBytes, 0); reclaimerStats_.reset(); @@ -2492,6 +2501,7 @@ DEBUG_ONLY_TEST_F(AggregationTest, reclaimDuringOutputProcessing) { if (enableSpilling) { ASSERT_GT(reclaimableBytes, 0); const auto usedMemory = op->pool()->usedBytes(); + memory::ScopedMemoryArbitrationContext ctx(op->pool()); op->pool()->reclaim( folly::Random::oneIn(2) ? 
0 : folly::Random::rand32(rng_), 0, @@ -2503,6 +2513,7 @@ DEBUG_ONLY_TEST_F(AggregationTest, reclaimDuringOutputProcessing) { reclaimerStats_.reset(); } else { ASSERT_EQ(reclaimableBytes, 0); + memory::ScopedMemoryArbitrationContext ctx(op->pool()); VELOX_ASSERT_THROW( op->reclaim( folly::Random::oneIn(2) ? 0 : folly::Random::rand32(rng_), diff --git a/velox/exec/tests/MemoryReclaimerTest.cpp b/velox/exec/tests/MemoryReclaimerTest.cpp index 4f28a08f1ccd..697caf9d18cb 100644 --- a/velox/exec/tests/MemoryReclaimerTest.cpp +++ b/velox/exec/tests/MemoryReclaimerTest.cpp @@ -14,6 +14,7 @@ * limitations under the License. */ +#include "velox/common/base/tests/GTestUtils.h" #include "velox/common/memory/MemoryArbitrator.h" #include "velox/common/memory/MemoryPool.h" #include "velox/exec/tests/utils/OperatorTestBase.h" @@ -167,50 +168,64 @@ TEST(ReclaimableSectionGuard, basic) { ASSERT_TRUE(nonReclaimableSection); } -TEST_F(MemoryReclaimerTest, parallelMemoryReclaimer) { - class MockMemoryReclaimer : public memory::MemoryReclaimer { - public: - static std::unique_ptr create( - bool reclaimable, - uint64_t memoryBytes) { - return std::unique_ptr( - new MockMemoryReclaimer(reclaimable, memoryBytes)); - } +namespace { +class MockMemoryReclaimer : public memory::MemoryReclaimer { + public: + static std::unique_ptr create( + bool reclaimable, + uint64_t memoryBytes, + const std::function& reclaimCallback = + nullptr) { + return std::unique_ptr( + new MockMemoryReclaimer(reclaimable, memoryBytes, reclaimCallback)); + } - bool reclaimableBytes(const MemoryPool& pool, uint64_t& reclaimableBytes) - const override { - reclaimableBytes = 0; - if (!reclaimable_) { - return false; - } - reclaimableBytes = memoryBytes_; - return true; + bool reclaimableBytes(const MemoryPool& pool, uint64_t& reclaimableBytes) + const override { + reclaimableBytes = 0; + if (!reclaimable_) { + return false; } + reclaimableBytes = memoryBytes_; + return true; + } - uint64_t reclaim( - MemoryPool* pool, - 
uint64_t targetBytes, - uint64_t maxWaitMs, - Stats& stats) override { - VELOX_CHECK(reclaimable_); - const uint64_t reclaimedBytes = memoryBytes_; - memoryBytes_ = 0; - return reclaimedBytes; + uint64_t reclaim( + MemoryPool* pool, + uint64_t targetBytes, + uint64_t maxWaitMs, + Stats& stats) override { + VELOX_CHECK(underMemoryArbitration()); + VELOX_CHECK(reclaimable_); + if (reclaimCallback_) { + reclaimCallback_(pool); } + const uint64_t reclaimedBytes = memoryBytes_; + memoryBytes_ = 0; + return reclaimedBytes; + } - uint64_t memoryBytes() const { - return memoryBytes_; - } + uint64_t memoryBytes() const { + return memoryBytes_; + } - private: - MockMemoryReclaimer(bool reclaimable, uint64_t memoryBytes) - : reclaimable_(reclaimable), memoryBytes_(memoryBytes) {} + private: + MockMemoryReclaimer( + bool reclaimable, + uint64_t memoryBytes, + const std::function& reclaimCallback) + : reclaimCallback_(reclaimCallback), + reclaimable_(reclaimable), + memoryBytes_(memoryBytes) {} - bool reclaimable_{false}; - int reclaimCount_{0}; - uint64_t memoryBytes_{0}; - }; + const std::function reclaimCallback_; + bool reclaimable_{false}; + int reclaimCount_{0}; + uint64_t memoryBytes_{0}; +}; +} // namespace +TEST_F(MemoryReclaimerTest, parallelMemoryReclaimer) { struct TestReclaimer { bool reclaimable; uint64_t memoryBytes; @@ -261,3 +276,39 @@ TEST_F(MemoryReclaimerTest, parallelMemoryReclaimer) { } } } + +// This test is to verify if the parallel memory reclaimer can prevent recursive +// arbitration. 
+TEST_F(MemoryReclaimerTest, recursiveArbitrationWithParallelReclaim) { + std::atomic_bool reclaimExecuted{false}; + auto rootPool = memory::memoryManager()->addRootPool( + "recursiveArbitrationWithParallelReclaim", + 32 << 20, + exec::ParallelMemoryReclaimer::create(executor_.get())); + const auto reclaimCallback = [&](memory::MemoryPool* pool) { + void* buffer = pool->allocate(64 << 20); + pool->free(buffer, 64 << 20); + reclaimExecuted = true; + }; + const int numLeafPools = 10; + const int bufferSize = 1 << 20; + std::vector memoryReclaimers; + std::vector> leafPools; + std::vector buffers; + for (int i = 0; i < numLeafPools; ++i) { + auto reclaimer = + MockMemoryReclaimer::create(true, bufferSize, reclaimCallback); + leafPools.push_back( + rootPool->addLeafChild(std::to_string(i), true, std::move(reclaimer))); + buffers.push_back(leafPools.back()->allocate(bufferSize)); + memoryReclaimers.push_back( + static_cast(leafPools.back()->reclaimer())); + } + + memory::testingRunArbitration(); + + for (int i = 0; i < numLeafPools; ++i) { + leafPools[i]->free(buffers[i], bufferSize); + } + ASSERT_TRUE(reclaimExecuted); +} diff --git a/velox/exec/tests/OrderByTest.cpp b/velox/exec/tests/OrderByTest.cpp index 3b5d45d86129..405e7aa13f0c 100644 --- a/velox/exec/tests/OrderByTest.cpp +++ b/velox/exec/tests/OrderByTest.cpp @@ -649,10 +649,13 @@ DEBUG_ONLY_TEST_F(OrderByTest, reclaimDuringInputProcessing) { } if (testData.expectedReclaimable) { - op->pool()->reclaim( - folly::Random::oneIn(2) ? 0 : folly::Random::rand32(rng_), - 0, - reclaimerStats_); + { + memory::ScopedMemoryArbitrationContext ctx(op->pool()); + op->pool()->reclaim( + folly::Random::oneIn(2) ? 
0 : folly::Random::rand32(rng_), + 0, + reclaimerStats_); + } ASSERT_GT(reclaimerStats_.reclaimedBytes, 0); ASSERT_GT(reclaimerStats_.reclaimExecTimeUs, 0); reclaimerStats_.reset(); @@ -772,10 +775,13 @@ DEBUG_ONLY_TEST_F(OrderByTest, reclaimDuringReserve) { ASSERT_TRUE(reclaimable); ASSERT_GT(reclaimableBytes, 0); - op->pool()->reclaim( - folly::Random::oneIn(2) ? 0 : folly::Random::rand32(rng_), - 0, - reclaimerStats_); + { + memory::ScopedMemoryArbitrationContext ctx(op->pool()); + op->pool()->reclaim( + folly::Random::oneIn(2) ? 0 : folly::Random::rand32(rng_), + 0, + reclaimerStats_); + } ASSERT_GT(reclaimerStats_.reclaimedBytes, 0); ASSERT_GT(reclaimerStats_.reclaimExecTimeUs, 0); reclaimerStats_.reset();