Skip to content

Commit

Permalink
probe side support
Browse files Browse the repository at this point in the history
  • Loading branch information
xiaoxmeng committed Feb 29, 2024
1 parent b557ab6 commit 72eb438
Show file tree
Hide file tree
Showing 24 changed files with 659 additions and 265 deletions.
5 changes: 3 additions & 2 deletions velox/common/memory/Memory.h
Original file line number Diff line number Diff line change
Expand Up @@ -224,8 +224,9 @@ class MemoryManager {
bool threadSafe = true);

/// Invoked to shrink alive pools to free 'targetBytes' capacity. The function
/// returns the actual freed memory capacity in bytes.
uint64_t shrinkPools(uint64_t targetBytes);
/// returns the actual freed memory capacity in bytes. If 'targetBytes' is
/// zero, then try to reclaim all the memory from the alive pools.
uint64_t shrinkPools(uint64_t targetBytes = 0);

/// Default unmanaged leaf pool with no threadsafe stats support. Libraries
/// using this method can get a pool that is shared with other threads. The
Expand Down
5 changes: 3 additions & 2 deletions velox/common/memory/MemoryArbitrator.h
Original file line number Diff line number Diff line change
Expand Up @@ -149,8 +149,9 @@ class MemoryArbitrator {

/// Invoked by the memory manager to shrink memory capacity from a given list
/// of memory pools by reclaiming free and used memory. The freed memory
/// capacity is given back to the arbitrator. The function returns the actual
/// freed memory capacity in bytes.
/// capacity is given back to the arbitrator. If 'targetBytes' is zero, then
/// try to reclaim all the memory from 'pools'. The function returns the
/// actual freed memory capacity in bytes.
virtual uint64_t shrinkCapacity(
const std::vector<std::shared_ptr<MemoryPool>>& pools,
uint64_t targetBytes) = 0;
Expand Down
1 change: 1 addition & 0 deletions velox/common/memory/MemoryPool.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -683,6 +683,7 @@ std::shared_ptr<MemoryPool> MemoryPoolImpl::genChild(

bool MemoryPoolImpl::maybeReserve(uint64_t increment) {
CHECK_AND_INC_MEM_OP_STATS(Reserves);
//LOG(ERROR) << name_ << " reserve " << succinctBytes(increment) << " " << succinctBytes(currentBytes()) << " " << succinctBytes(reservedBytes());
TestValue::adjust(
"facebook::velox::common::memory::MemoryPoolImpl::maybeReserve", this);
// TODO: make this a configurable memory pool option.
Expand Down
3 changes: 1 addition & 2 deletions velox/core/PlanNode.h
Original file line number Diff line number Diff line change
Expand Up @@ -1608,8 +1608,7 @@ class HashJoinNode : public AbstractJoinNode {
// filter set. It requires to cross join the null-key probe rows with all
// the build-side rows for filter evaluation which is not supported under
// spilling.
return !(isAntiJoin() && nullAware_ && filter() != nullptr) &&
queryConfig.joinSpillEnabled();
return !(isAntiJoin() && nullAware_ && filter() != nullptr);
}

bool isNullAware() const {
Expand Down
14 changes: 14 additions & 0 deletions velox/core/QueryConfig.h
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,10 @@ class QueryConfig {
/// Join spilling flag, only applies if "spill_enabled" flag is set.
static constexpr const char* kJoinSpillEnabled = "join_spill_enabled";

static constexpr const char* kJoinBuildSpillEnabled = "join_build_spill_enabled";

static constexpr const char* kJoinProbeSpillEnabled = "join_probe_spill_enabled";

/// OrderBy spilling flag, only applies if "spill_enabled" flag is set.
static constexpr const char* kOrderBySpillEnabled = "order_by_spill_enabled";

Expand Down Expand Up @@ -533,6 +537,16 @@ class QueryConfig {
return get<bool>(kJoinSpillEnabled, true);
}

/// Returns 'is join spilling enabled' flag. Must also check the
/// spillEnabled()!
bool joinBuildSpillEnabled() const {
return get<bool>(kJoinBuildSpillEnabled, true);
}

bool joinProbeSpillEnabled() const {
return get<bool>(kJoinProbeSpillEnabled, true);
}

/// Returns 'is orderby spilling enabled' flag. Must also check the
/// spillEnabled()!
bool orderBySpillEnabled() const {
Expand Down
36 changes: 27 additions & 9 deletions velox/exec/HashBuild.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ HashBuild::HashBuild(
operatorId,
joinNode->id(),
"HashBuild",
joinNode->canSpill(driverCtx->queryConfig())
canHashJoinSpill(joinNode, driverCtx->queryConfig(), true)
? driverCtx->makeSpillConfig(operatorId)
: std::nullopt),
joinNode_(std::move(joinNode)),
Expand Down Expand Up @@ -203,6 +203,9 @@ void HashBuild::setupSpiller(SpillPartition* spillPartition) {
if (!spillEnabled()) {
return;
}
if (spillType_ == nullptr) {
spillType_ = getTableSpillType(tableType_);
}

const auto& spillConfig = spillConfig_.value();
HashBitRange hashBits(
Expand Down Expand Up @@ -239,15 +242,15 @@ void HashBuild::setupSpiller(SpillPartition* spillPartition) {
spiller_ = std::make_unique<Spiller>(
Spiller::Type::kHashJoinBuild,
table_->rows(),
tableType_,
spillType_,
std::move(hashBits),
&spillConfig);

const int32_t numPartitions = spiller_->hashBits().numPartitions();
spillInputIndicesBuffers_.resize(numPartitions);
rawSpillInputIndicesBuffers_.resize(numPartitions);
numSpillInputs_.resize(numPartitions, 0);
spillChildVectors_.resize(tableType_->size());
spillChildVectors_.resize(spillType_->size());
}

bool HashBuild::isInputFromSpill() const {
Expand Down Expand Up @@ -408,6 +411,12 @@ void HashBuild::addInput(RowVectorPtr input) {
}
auto rows = table_->rows();
auto nextOffset = rows->nextOffset();
FlatVector<bool>* probedFlagVector{nullptr};
if (isInputFromSpill()) {
const auto probedFlagChannel = spillType_->size() - 1;
probedFlagVector = input->childAt(probedFlagChannel)->asFlatVector<bool>();
}

activeRows_.applyToSelected([&](auto rowIndex) {
char* newRow = rows->newRow();
if (nextOffset) {
Expand All @@ -422,6 +431,11 @@ void HashBuild::addInput(RowVectorPtr input) {
for (auto i = 0; i < dependentChannels_.size(); ++i) {
rows->store(*decoders_[i], rowIndex, newRow, i + hashers.size());
}
if (probedFlagVector != nullptr) {
if (probedFlagVector->valueAt(rowIndex)) {
rows->setProbedFlag(&newRow, 1);
}
}
});
}

Expand Down Expand Up @@ -591,6 +605,8 @@ void HashBuild::maybeSetupSpillChildVectors(const RowVectorPtr& input) {
for (const auto& channel : dependentChannels_) {
spillChildVectors_[spillChannel++] = input->childAt(channel);
}
spillChildVectors_[spillChannel] = std::make_shared<ConstantVector<bool>>(
pool(), input->size(), /*isNull=*/false, BOOLEAN(), false);
}

void HashBuild::prepareInputIndicesBuffers(
Expand Down Expand Up @@ -641,7 +657,7 @@ void HashBuild::spillPartition(
} else {
spiller_->spill(
partition,
wrap(size, indices, tableType_, spillChildVectors_, input->pool()));
wrap(size, indices, spillType_, spillChildVectors_, input->pool()));
}
}

Expand Down Expand Up @@ -845,8 +861,9 @@ bool HashBuild::finishHashBuild() {
isInputFromSpill() ? spillConfig()->startPartitionBit
: BaseHashTable::kNoSpillInputStartPartitionBit);
addRuntimeStats();
if (joinBridge_->setHashTable(
std::move(table_), std::move(spillPartitions), joinHasNullKeys_)) {
joinBridge_->setHashTable(
std::move(table_), std::move(spillPartitions), joinHasNullKeys_);
if (spillEnabled()) {
intermediateStateCleared_ = true;
spillGroup_->restart();
}
Expand All @@ -858,6 +875,7 @@ bool HashBuild::finishHashBuild() {
}

void HashBuild::recordSpillStats() {
LOG(ERROR) << "record spill stats from hash build";
recordSpillStats(spiller_.get());
}

Expand Down Expand Up @@ -947,8 +965,8 @@ void HashBuild::setupSpillInput(HashJoinBridge::SpillInput spillInput) {
void HashBuild::processSpillInput() {
checkRunning();

while (spillInputReader_->nextBatch(input_)) {
addInput(std::move(input_));
while (spillInputReader_->nextBatch(spillInput_)) {
addInput(std::move(spillInput_));
if (!isRunning()) {
return;
}
Expand Down Expand Up @@ -1224,7 +1242,7 @@ void HashBuild::reclaim(
bool HashBuild::nonReclaimableState() const {
// Apart from being in the nonReclaimable section,
// its also not reclaimable if:
// 1) the hash table has been built by the last build thread (inidicated
// 1) the hash table has been built by the last build thread (indicated
// by state_)
// 2) the last build operator has transferred ownership of 'this' operator's
// intermediate state (table_ and spiller_) to itself
Expand Down
7 changes: 5 additions & 2 deletions velox/exec/HashBuild.h
Original file line number Diff line number Diff line change
Expand Up @@ -264,8 +264,8 @@ class HashBuild final : public Operator {
// The row type used for hash table build and disk spilling.
RowTypePtr tableType_;

// Used to serialize access to intermediate state variables (like 'table_' and
// 'spiller_'). This is only required when variables are accessed
// Used to serialize access to internal state including 'table_' and
// 'spiller_'. This is only required when variables are accessed
// concurrently, that is, when a thread tries to close the operator while
// another thread is building the hash table. Refer to 'close()' and
// finishHashBuild()' for more details.
Expand Down Expand Up @@ -315,6 +315,7 @@ class HashBuild final : public Operator {
uint64_t numSpillRows_{0};
uint64_t numSpillBytes_{0};

RowTypePtr spillType_;
// This can be nullptr if either spilling is not allowed or it has been
// trsnaferred to the last hash build operator while in kWaitForBuild state or
// it has been cleared to setup a new one for recursive spilling.
Expand All @@ -323,6 +324,8 @@ class HashBuild final : public Operator {
// Used to read input from previously spilled data for restoring.
std::unique_ptr<UnorderedStreamReader<BatchStream>> spillInputReader_;

RowVectorPtr spillInput_;

// Reusable memory for spill partition calculation for input data.
std::vector<uint32_t> spillPartitions_;

Expand Down
86 changes: 64 additions & 22 deletions velox/exec/HashJoinBridge.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,15 +29,14 @@ void HashJoinBridge::addBuilder() {
++numBuilders_;
}

bool HashJoinBridge::setHashTable(
void HashJoinBridge::setHashTable(
std::unique_ptr<BaseHashTable> table,
SpillPartitionSet spillPartitionSet,
bool hasNullKeys) {
VELOX_CHECK_NOT_NULL(table, "setHashTable called with null table");

auto spillPartitionIdSet = toSpillPartitionIdSet(spillPartitionSet);

bool hasSpillData;
std::vector<ContinuePromise> promises;
{
std::lock_guard<std::mutex> l(mutex_);
Expand All @@ -64,12 +63,28 @@ bool HashJoinBridge::setHashTable(
std::move(spillPartitionIdSet),
hasNullKeys);
restoringSpillPartitionId_.reset();

hasSpillData = !spillPartitionSets_.empty();
promises = std::move(promises_);
}
notify(std::move(promises));
return hasSpillData;
}

void HashJoinBridge::setSpilledHashTable(SpillPartitionSet spillPartitionSet) {
VELOX_CHECK(!spillPartitionSet.empty());
std::shared_ptr<BaseHashTable> tableToFree;
{
std::lock_guard<std::mutex> l(mutex_);
VELOX_CHECK(started_);
VELOX_CHECK(buildResult_.has_value());
VELOX_CHECK(restoringSpillShards_.empty());
VELOX_CHECK(!restoringSpillPartitionId_.has_value());

for (auto& partitionEntry : spillPartitionSet) {
const auto id = partitionEntry.first;
VELOX_CHECK_EQ(spillPartitionSets_.count(id), 0);
spillPartitionSets_.emplace(id, std::move(partitionEntry.second));
}
tableToFree = std::move(buildResult_->table);
}
}

void HashJoinBridge::setAntiJoinHasNullKeys() {
Expand Down Expand Up @@ -131,10 +146,8 @@ bool HashJoinBridge::probeFinished() {
spillPartitionSets_.begin()->second->split(numBuilders_);
VELOX_CHECK_EQ(restoringSpillShards_.size(), numBuilders_);
spillPartitionSets_.erase(spillPartitionSets_.begin());
promises = std::move(promises_);
} else {
VELOX_CHECK(promises_.empty());
}
promises = std::move(promises_);
}
notify(std::move(promises));
return hasSpillInput;
Expand All @@ -149,13 +162,9 @@ std::optional<HashJoinBridge::SpillInput> HashJoinBridge::spillInputOrFuture(
!restoringSpillPartitionId_.has_value() || !buildResult_.has_value());

if (!restoringSpillPartitionId_.has_value()) {
if (spillPartitionSets_.empty()) {
return HashJoinBridge::SpillInput{};
} else {
promises_.emplace_back("HashJoinBridge::spillInputOrFuture");
*future = promises_.back().getSemiFuture();
return std::nullopt;
}
promises_.emplace_back("HashJoinBridge::spillInputOrFuture");
*future = promises_.back().getSemiFuture();
return std::nullopt;
}
VELOX_CHECK(!restoringSpillShards_.empty());
auto spillShard = std::move(restoringSpillShards_.back());
Expand All @@ -175,22 +184,55 @@ uint64_t HashJoinMemoryReclaimer::reclaim(
uint64_t targetBytes,
uint64_t maxWaitMs,
memory::MemoryReclaimer::Stats& stats) {
bool hasReclaimedFromBuild{false};
bool hasReclaimedFromProbe{false};
uint64_t reclaimedBytes{0};
pool->visitChildren([&](memory::MemoryPool* child) {
VELOX_CHECK_EQ(child->kind(), memory::MemoryPool::Kind::kLeaf);
// The hash probe operator do not support memory reclaim.
if (!isHashBuildMemoryPool(*child)) {
return true;
const bool isBuild = isHashBuildMemoryPool(*child);
if (isBuild) {
if (!hasReclaimedFromBuild) {
hasReclaimedFromBuild = true;
reclaimedBytes = child->reclaim(targetBytes, maxWaitMs, stats);
}
// We only need to reclaim from any one of the hash build operators
// which will reclaim from all the peer hash build operators.
return !hasReclaimedFromProbe;
}
// We only need to reclaim from any one of the hash build operators
// which will reclaim from all the peer hash build operators.
reclaimedBytes = child->reclaim(targetBytes, maxWaitMs, stats);
return false;

if (!hasReclaimedFromProbe) {
hasReclaimedFromProbe = true;
reclaimedBytes = child->reclaim(targetBytes, maxWaitMs, stats);
}
return !hasReclaimedFromBuild;
});
return reclaimedBytes;
}

bool isHashBuildMemoryPool(const memory::MemoryPool& pool) {
return folly::StringPiece(pool.name()).endsWith("HashBuild");
}

bool isHashProbeMemoryPool(const memory::MemoryPool& pool) {
return folly::StringPiece(pool.name()).endsWith("HashProbe");
}

bool canHashJoinSpill(
const std::shared_ptr<const core::HashJoinNode>& joinNode,
const core::QueryConfig& queryConfig,
bool isBuild) {
if (!joinNode->canSpill(queryConfig)) {
return false;
}
return isBuild ? queryConfig.joinBuildSpillEnabled()
: queryConfig.joinProbeSpillEnabled();
}

RowTypePtr getTableSpillType(const RowTypePtr& tableType) {
auto names = tableType->names();
names.push_back("probedFlags");
auto types = tableType->children();
types.push_back(BOOLEAN());
return ROW(std::move(names), std::move(types));
}
} // namespace facebook::velox::exec
Loading

0 comments on commit 72eb438

Please sign in to comment.