Skip to content

Commit

Permalink
probe spill draft
Browse files Browse the repository at this point in the history
  • Loading branch information
xiaoxmeng committed Mar 4, 2024
1 parent f391c02 commit 96612d5
Show file tree
Hide file tree
Showing 19 changed files with 821 additions and 276 deletions.
3 changes: 1 addition & 2 deletions velox/core/PlanNode.h
Original file line number Diff line number Diff line change
Expand Up @@ -1608,8 +1608,7 @@ class HashJoinNode : public AbstractJoinNode {
// filter set. It requires to cross join the null-key probe rows with all
// the build-side rows for filter evaluation which is not supported under
// spilling.
return !(isAntiJoin() && nullAware_ && filter() != nullptr) &&
queryConfig.joinSpillEnabled();
return !(isAntiJoin() && nullAware_ && filter() != nullptr);
}

bool isNullAware() const {
Expand Down
14 changes: 14 additions & 0 deletions velox/core/QueryConfig.h
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,10 @@ class QueryConfig {
/// Join spilling flag, only applies if "spill_enabled" flag is set.
static constexpr const char* kJoinSpillEnabled = "join_spill_enabled";

static constexpr const char* kJoinBuildSpillEnabled = "join_build_spill_enabled";

static constexpr const char* kJoinProbeSpillEnabled = "join_probe_spill_enabled";

/// OrderBy spilling flag, only applies if "spill_enabled" flag is set.
static constexpr const char* kOrderBySpillEnabled = "order_by_spill_enabled";

Expand Down Expand Up @@ -533,6 +537,16 @@ class QueryConfig {
return get<bool>(kJoinSpillEnabled, true);
}

/// Returns 'is join spilling enabled' flag. Must also check the
/// spillEnabled()!
bool joinBuildSpillEnabled() const {
return get<bool>(kJoinBuildSpillEnabled, true);
}

bool joinProbeSpillEnabled() const {
return get<bool>(kJoinProbeSpillEnabled, true);
}

/// Returns 'is orderby spilling enabled' flag. Must also check the
/// spillEnabled()!
bool orderBySpillEnabled() const {
Expand Down
36 changes: 27 additions & 9 deletions velox/exec/HashBuild.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ HashBuild::HashBuild(
operatorId,
joinNode->id(),
"HashBuild",
joinNode->canSpill(driverCtx->queryConfig())
canHashJoinSpill(joinNode, driverCtx->queryConfig(), true)
? driverCtx->makeSpillConfig(operatorId)
: std::nullopt),
joinNode_(std::move(joinNode)),
Expand Down Expand Up @@ -203,6 +203,9 @@ void HashBuild::setupSpiller(SpillPartition* spillPartition) {
if (!spillEnabled()) {
return;
}
if (spillType_ == nullptr) {
spillType_ = getTableSpillType(tableType_);
}

const auto& spillConfig = spillConfig_.value();
HashBitRange hashBits(
Expand Down Expand Up @@ -239,15 +242,15 @@ void HashBuild::setupSpiller(SpillPartition* spillPartition) {
spiller_ = std::make_unique<Spiller>(
Spiller::Type::kHashJoinBuild,
table_->rows(),
tableType_,
spillType_,
std::move(hashBits),
&spillConfig);

const int32_t numPartitions = spiller_->hashBits().numPartitions();
spillInputIndicesBuffers_.resize(numPartitions);
rawSpillInputIndicesBuffers_.resize(numPartitions);
numSpillInputs_.resize(numPartitions, 0);
spillChildVectors_.resize(tableType_->size());
spillChildVectors_.resize(spillType_->size());
}

bool HashBuild::isInputFromSpill() const {
Expand Down Expand Up @@ -408,6 +411,12 @@ void HashBuild::addInput(RowVectorPtr input) {
}
auto rows = table_->rows();
auto nextOffset = rows->nextOffset();
FlatVector<bool>* probedFlagVector{nullptr};
if (isInputFromSpill()) {
const auto probedFlagChannel = spillType_->size() - 1;
probedFlagVector = input->childAt(probedFlagChannel)->asFlatVector<bool>();
}

activeRows_.applyToSelected([&](auto rowIndex) {
char* newRow = rows->newRow();
if (nextOffset) {
Expand All @@ -422,6 +431,11 @@ void HashBuild::addInput(RowVectorPtr input) {
for (auto i = 0; i < dependentChannels_.size(); ++i) {
rows->store(*decoders_[i], rowIndex, newRow, i + hashers.size());
}
if (probedFlagVector != nullptr) {
if (probedFlagVector->valueAt(rowIndex)) {
rows->setProbedFlag(&newRow, 1);
}
}
});
}

Expand Down Expand Up @@ -590,6 +604,8 @@ void HashBuild::maybeSetupSpillChildVectors(const RowVectorPtr& input) {
for (const auto& channel : dependentChannels_) {
spillChildVectors_[spillChannel++] = input->childAt(channel);
}
spillChildVectors_[spillChannel] = std::make_shared<ConstantVector<bool>>(
pool(), input->size(), /*isNull=*/false, BOOLEAN(), false);
}

void HashBuild::prepareInputIndicesBuffers(
Expand Down Expand Up @@ -640,7 +656,7 @@ void HashBuild::spillPartition(
} else {
spiller_->spill(
partition,
wrap(size, indices, tableType_, spillChildVectors_, input->pool()));
wrap(size, indices, spillType_, spillChildVectors_, input->pool()));
}
}

Expand Down Expand Up @@ -844,8 +860,9 @@ bool HashBuild::finishHashBuild() {
isInputFromSpill() ? spillConfig()->startPartitionBit
: BaseHashTable::kNoSpillInputStartPartitionBit);
addRuntimeStats();
if (joinBridge_->setHashTable(
std::move(table_), std::move(spillPartitions), joinHasNullKeys_)) {
joinBridge_->setHashTable(
std::move(table_), std::move(spillPartitions), joinHasNullKeys_);
if (spillEnabled()) {
intermediateStateCleared_ = true;
spillGroup_->restart();
}
Expand All @@ -857,6 +874,7 @@ bool HashBuild::finishHashBuild() {
}

void HashBuild::recordSpillStats() {
LOG(ERROR) << "record spill stats from hash build";
recordSpillStats(spiller_.get());
}

Expand Down Expand Up @@ -946,8 +964,8 @@ void HashBuild::setupSpillInput(HashJoinBridge::SpillInput spillInput) {
void HashBuild::processSpillInput() {
checkRunning();

while (spillInputReader_->nextBatch(input_)) {
addInput(std::move(input_));
while (spillInputReader_->nextBatch(spillInput_)) {
addInput(std::move(spillInput_));
if (!isRunning()) {
return;
}
Expand Down Expand Up @@ -1214,7 +1232,7 @@ void HashBuild::reclaim(
bool HashBuild::nonReclaimableState() const {
// Apart from being in the nonReclaimable section,
// its also not reclaimable if:
// 1) the hash table has been built by the last build thread (inidicated
// 1) the hash table has been built by the last build thread (indicated
// by state_)
// 2) the last build operator has transferred ownership of 'this' operator's
// intermediate state (table_ and spiller_) to itself
Expand Down
7 changes: 5 additions & 2 deletions velox/exec/HashBuild.h
Original file line number Diff line number Diff line change
Expand Up @@ -261,8 +261,8 @@ class HashBuild final : public Operator {
// The row type used for hash table build and disk spilling.
RowTypePtr tableType_;

// Used to serialize access to intermediate state variables (like 'table_' and
// 'spiller_'). This is only required when variables are accessed
// Used to serialize access to internal state including 'table_' and
// 'spiller_'. This is only required when variables are accessed
// concurrently, that is, when a thread tries to close the operator while
// another thread is building the hash table. Refer to 'close()' and
// finishHashBuild()' for more details.
Expand Down Expand Up @@ -308,6 +308,7 @@ class HashBuild final : public Operator {
uint64_t numSpillRows_{0};
uint64_t numSpillBytes_{0};

RowTypePtr spillType_;
// This can be nullptr if either spilling is not allowed or it has been
// trsnaferred to the last hash build operator while in kWaitForBuild state or
// it has been cleared to setup a new one for recursive spilling.
Expand All @@ -316,6 +317,8 @@ class HashBuild final : public Operator {
// Used to read input from previously spilled data for restoring.
std::unique_ptr<UnorderedStreamReader<BatchStream>> spillInputReader_;

RowVectorPtr spillInput_;

// Reusable memory for spill partition calculation for input data.
std::vector<uint32_t> spillPartitions_;

Expand Down
86 changes: 64 additions & 22 deletions velox/exec/HashJoinBridge.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,15 +29,14 @@ void HashJoinBridge::addBuilder() {
++numBuilders_;
}

bool HashJoinBridge::setHashTable(
void HashJoinBridge::setHashTable(
std::unique_ptr<BaseHashTable> table,
SpillPartitionSet spillPartitionSet,
bool hasNullKeys) {
VELOX_CHECK_NOT_NULL(table, "setHashTable called with null table");

auto spillPartitionIdSet = toSpillPartitionIdSet(spillPartitionSet);

bool hasSpillData;
std::vector<ContinuePromise> promises;
{
std::lock_guard<std::mutex> l(mutex_);
Expand All @@ -64,12 +63,28 @@ bool HashJoinBridge::setHashTable(
std::move(spillPartitionIdSet),
hasNullKeys);
restoringSpillPartitionId_.reset();

hasSpillData = !spillPartitionSets_.empty();
promises = std::move(promises_);
}
notify(std::move(promises));
return hasSpillData;
}

void HashJoinBridge::setSpilledHashTable(SpillPartitionSet spillPartitionSet) {
VELOX_CHECK(!spillPartitionSet.empty());
//std::shared_ptr<BaseHashTable> tableToFree;
{
std::lock_guard<std::mutex> l(mutex_);
VELOX_CHECK(started_);
VELOX_CHECK(buildResult_.has_value());
VELOX_CHECK(restoringSpillShards_.empty());
VELOX_CHECK(!restoringSpillPartitionId_.has_value());

for (auto& partitionEntry : spillPartitionSet) {
const auto id = partitionEntry.first;
VELOX_CHECK_EQ(spillPartitionSets_.count(id), 0);
spillPartitionSets_.emplace(id, std::move(partitionEntry.second));
}
//tableToFree = std::move(buildResult_->table);
}
}

void HashJoinBridge::setAntiJoinHasNullKeys() {
Expand Down Expand Up @@ -131,10 +146,8 @@ bool HashJoinBridge::probeFinished() {
spillPartitionSets_.begin()->second->split(numBuilders_);
VELOX_CHECK_EQ(restoringSpillShards_.size(), numBuilders_);
spillPartitionSets_.erase(spillPartitionSets_.begin());
promises = std::move(promises_);
} else {
VELOX_CHECK(promises_.empty());
}
promises = std::move(promises_);
}
notify(std::move(promises));
return hasSpillInput;
Expand All @@ -149,13 +162,9 @@ std::optional<HashJoinBridge::SpillInput> HashJoinBridge::spillInputOrFuture(
!restoringSpillPartitionId_.has_value() || !buildResult_.has_value());

if (!restoringSpillPartitionId_.has_value()) {
if (spillPartitionSets_.empty()) {
return HashJoinBridge::SpillInput{};
} else {
promises_.emplace_back("HashJoinBridge::spillInputOrFuture");
*future = promises_.back().getSemiFuture();
return std::nullopt;
}
promises_.emplace_back("HashJoinBridge::spillInputOrFuture");
*future = promises_.back().getSemiFuture();
return std::nullopt;
}
VELOX_CHECK(!restoringSpillShards_.empty());
auto spillShard = std::move(restoringSpillShards_.back());
Expand All @@ -175,22 +184,55 @@ uint64_t HashJoinMemoryReclaimer::reclaim(
uint64_t targetBytes,
uint64_t maxWaitMs,
memory::MemoryReclaimer::Stats& stats) {
bool hasReclaimedFromBuild{false};
bool hasReclaimedFromProbe{false};
uint64_t reclaimedBytes{0};
pool->visitChildren([&](memory::MemoryPool* child) {
VELOX_CHECK_EQ(child->kind(), memory::MemoryPool::Kind::kLeaf);
// The hash probe operator do not support memory reclaim.
if (!isHashBuildMemoryPool(*child)) {
return true;
const bool isBuild = isHashBuildMemoryPool(*child);
if (isBuild) {
if (!hasReclaimedFromBuild) {
hasReclaimedFromBuild = true;
reclaimedBytes = child->reclaim(targetBytes, maxWaitMs, stats);
}
// We only need to reclaim from any one of the hash build operators
// which will reclaim from all the peer hash build operators.
return !hasReclaimedFromProbe;
}
// We only need to reclaim from any one of the hash build operators
// which will reclaim from all the peer hash build operators.
reclaimedBytes = child->reclaim(targetBytes, maxWaitMs, stats);
return false;

if (!hasReclaimedFromProbe) {
hasReclaimedFromProbe = true;
reclaimedBytes = child->reclaim(targetBytes, maxWaitMs, stats);
}
return !hasReclaimedFromBuild;
});
return reclaimedBytes;
}

bool isHashBuildMemoryPool(const memory::MemoryPool& pool) {
return folly::StringPiece(pool.name()).endsWith("HashBuild");
}

bool isHashProbeMemoryPool(const memory::MemoryPool& pool) {
return folly::StringPiece(pool.name()).endsWith("HashProbe");
}

bool canHashJoinSpill(
const std::shared_ptr<const core::HashJoinNode>& joinNode,
const core::QueryConfig& queryConfig,
bool isBuild) {
if (!joinNode->canSpill(queryConfig)) {
return false;
}
return isBuild ? queryConfig.joinBuildSpillEnabled()
: queryConfig.joinProbeSpillEnabled();
}

RowTypePtr getTableSpillType(const RowTypePtr& tableType) {
auto names = tableType->names();
names.push_back("probedFlags");
auto types = tableType->children();
types.push_back(BOOLEAN());
return ROW(std::move(names), std::move(types));
}
} // namespace facebook::velox::exec
21 changes: 16 additions & 5 deletions velox/exec/HashJoinBridge.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,13 @@ class HashJoinBridge : public JoinBridge {
/// 'table'. The function returns true if there is spill data to restore
/// after HashProbe operators process 'table', otherwise false. This only
/// applies if the disk spilling is enabled.
bool setHashTable(
void setHashTable(
std::unique_ptr<BaseHashTable> table,
SpillPartitionSet spillPartitionSet,
bool hasNullKeys);

void setSpilledHashTable(SpillPartitionSet spillPartitionSet);

void setAntiJoinHasNullKeys();

/// Represents the result of HashBuild operators: a hash table, an optional
Expand Down Expand Up @@ -75,8 +77,7 @@ class HashJoinBridge : public JoinBridge {
/// HashBuild operators. If HashProbe operator calls this early, 'future' will
/// be set to wait asynchronously, otherwise the built table along with
/// optional spilling related information will be returned in HashBuildResult.
std::optional<HashBuildResult> tableOrFuture(
ContinueFuture* FOLLY_NONNULL future);
std::optional<HashBuildResult> tableOrFuture(ContinueFuture* future);

/// Invoked by HashProbe operator after finishes probing the built table to
/// set one of the previously spilled partition to restore. The HashBuild
Expand All @@ -102,8 +103,7 @@ class HashJoinBridge : public JoinBridge {
/// If HashBuild operator calls this early, 'future' will be set to wait
/// asynchronously. If there is no more spill data to restore, then
/// 'spillPartition' will be set to null in the returned SpillInput.
std::optional<SpillInput> spillInputOrFuture(
ContinueFuture* FOLLY_NONNULL future);
std::optional<SpillInput> spillInputOrFuture(ContinueFuture* future);

private:
uint32_t numBuilders_{0};
Expand Down Expand Up @@ -156,4 +156,15 @@ class HashJoinMemoryReclaimer final : public MemoryReclaimer {
/// Returns true if 'pool' is a hash build operator's memory pool. The check is
/// currently based on the pool name.
bool isHashBuildMemoryPool(const memory::MemoryPool& pool);

/// Returns true if 'pool' is a hash probe operator's memory pool. The check is
/// currently based on the pool name.
bool isHashProbeMemoryPool(const memory::MemoryPool& pool);

bool canHashJoinSpill(
const std::shared_ptr<const core::HashJoinNode>& joinNode,
const core::QueryConfig& queryConfig,
bool isBuild);

RowTypePtr getTableSpillType(const RowTypePtr& tableType);
} // namespace facebook::velox::exec
Loading

0 comments on commit 96612d5

Please sign in to comment.