diff --git a/velox/exec/MergeJoin.cpp b/velox/exec/MergeJoin.cpp index b4281ea970d6..87e3dbf8648d 100644 --- a/velox/exec/MergeJoin.cpp +++ b/velox/exec/MergeJoin.cpp @@ -20,6 +20,13 @@ namespace facebook::velox::exec { +namespace { +bool supportsMergeJoin(std::shared_ptr joinNode) { + return joinNode->isInnerJoin() || joinNode->isLeftJoin() || + joinNode->isLeftSemiFilterJoin() || joinNode->isRightSemiFilterJoin() || + joinNode->isAntiJoin() || joinNode->isRightJoin(); +} +} // namespace MergeJoin::MergeJoin( int32_t operatorId, DriverCtx* driverCtx, @@ -35,10 +42,9 @@ MergeJoin::MergeJoin( numKeys_{joinNode->leftKeys().size()}, joinNode_(joinNode) { VELOX_USER_CHECK( - joinNode_->isInnerJoin() || joinNode_->isLeftJoin() || - joinNode_->isLeftSemiFilterJoin() || - joinNode_->isRightSemiFilterJoin() || joinNode_->isAntiJoin(), - "Merge join supports only inner, left and left semi joins. Other join types are not supported yet."); + supportsMergeJoin(joinNode_), + "The join type is not supported by merge join: ", + joinTypeName(joinNode_->joinType())); } void MergeJoin::initialize() { @@ -89,13 +95,14 @@ void MergeJoin::initialize() { if (joinNode_->filter()) { initializeFilter(joinNode_->filter(), leftType, rightType); - if (joinNode_->isLeftJoin() || joinNode_->isAntiJoin()) { - leftJoinTracker_ = LeftJoinTracker(outputBatchSize_, pool()); + if (joinNode_->isLeftJoin() || joinNode_->isAntiJoin() || + joinNode_->isRightJoin()) { + joinTracker_ = JoinTracker(outputBatchSize_, pool()); } } else if (joinNode_->isAntiJoin()) { // Anti join needs to track the left side rows that have no match on the // right. - leftJoinTracker_ = LeftJoinTracker(outputBatchSize_, pool()); + joinTracker_ = JoinTracker(outputBatchSize_, pool()); } joinNode_.reset(); @@ -183,6 +190,9 @@ BlockingReason MergeJoin::isBlocked(ContinueFuture* future) { } bool MergeJoin::needsInput() const { + if (isRightJoin(joinType_)) { + return (input_ == nullptr || rightInput_ == nullptr); + } return input_ == nullptr; } @@ -190,8 +200,8 @@ void MergeJoin::addInput(RowVectorPtr input) { input_ = std::move(input); index_ = 0; - if (leftJoinTracker_) { - leftJoinTracker_->resetLastVector(); + if (joinTracker_) { + joinTracker_->resetLastVector(); } } @@ -269,6 +279,7 @@ void copyRow( void MergeJoin::addOutputRowForLeftJoin( const RowVectorPtr& left, vector_size_t leftIndex) { + VELOX_USER_CHECK(isLeftJoin(joinType_) || isAntiJoin(joinType_)); rawLeftIndices_[outputSize_] = leftIndex; for (const auto& projection : rightProjections_) { @@ -276,9 +287,28 @@ void MergeJoin::addOutputRowForLeftJoin( target->setNull(outputSize_, true); } - if (leftJoinTracker_) { + if (joinTracker_) { // Record left-side row with no match on the right side. - leftJoinTracker_->addMiss(outputSize_); + joinTracker_->addMiss(outputSize_); + } + + ++outputSize_; +} + +void MergeJoin::addOutputRowForRightJoin( + const RowVectorPtr& right, + vector_size_t rightIndex) { + VELOX_USER_CHECK(isRightJoin(joinType_)); + rawRightIndices_[outputSize_] = rightIndex; + + for (const auto& projection : leftProjections_) { + const auto& target = output_->childAt(projection.outputChannel); + target->setNull(outputSize_, true); + } + + if (joinTracker_) { + // Record right-side row with no match on the left side. + joinTracker_->addMiss(outputSize_); } ++outputSize_; @@ -320,18 +350,23 @@ void MergeJoin::addOutputRow( copyRow(left, leftIndex, filterInput_, outputSize_, filterLeftInputs_); copyRow(right, rightIndex, filterInput_, outputSize_, filterRightInputs_); - if (leftJoinTracker_) { - // Record left-side row with a match on the right-side. - leftJoinTracker_->addMatch(left, leftIndex, outputSize_); + if (joinTracker_) { + if (isRightJoin(joinType_)) { + // Record right-side row with a match on the left-side. + joinTracker_->addMatch(right, rightIndex, outputSize_); + } else { + // Record left-side row with a match on the right-side. + joinTracker_->addMatch(left, leftIndex, outputSize_); + } } } // Anti join needs to track the left side rows that have no match on the // right. if (isAntiJoin(joinType_)) { - VELOX_CHECK(leftJoinTracker_); + VELOX_CHECK(joinTracker_); // Record left-side row with a match on the right-side. - leftJoinTracker_->addMatch(left, leftIndex, outputSize_); + joinTracker_->addMatch(left, leftIndex, outputSize_); } ++outputSize_; @@ -348,6 +383,10 @@ bool MergeJoin::prepareOutput( return true; } + if (isRightJoin(joinType_) && right != currentRight_) { + return true; + } + // If there is a new right, we need to flatten the dictionary. if (!isRightFlattened_ && right && currentRight_ != right) { flattenRightProjections(); @@ -363,14 +402,23 @@ bool MergeJoin::prepareOutput( rightIndices_ = allocateIndices(outputBatchSize_, pool()); rawRightIndices_ = rightIndices_->asMutable(); - // Create output dictionary vectors for left projections. + // Create left side projection outputs. std::vector localColumns(outputType_->size()); - for (const auto& projection : leftProjections_) { - localColumns[projection.outputChannel] = BaseVector::wrapInDictionary( - {}, - leftIndices_, - outputBatchSize_, - newLeft->childAt(projection.inputChannel)); + if (newLeft == nullptr) { + for (const auto& projection : leftProjections_) { + localColumns[projection.outputChannel] = BaseVector::create( + outputType_->childAt(projection.outputChannel), + outputBatchSize_, + operatorCtx_->pool()); + } + } else { + for (const auto& projection : leftProjections_) { + localColumns[projection.outputChannel] = BaseVector::wrapInDictionary( + {}, + leftIndices_, + outputBatchSize_, + newLeft->childAt(projection.inputChannel)); + } } currentLeft_ = newLeft; @@ -556,7 +604,7 @@ vector_size_t firstNonNull( RowVectorPtr MergeJoin::filterOutputForAntiJoin(const RowVectorPtr& output) { auto numRows = output->size(); - const auto& filterRows = leftJoinTracker_->matchingRows(numRows); + const auto& filterRows = joinTracker_->matchingRows(numRows); auto numPassed = 0; BufferPtr indices = allocateIndices(numRows, pool()); @@ -738,6 +786,35 @@ RowVectorPtr MergeJoin::doGetOutput() { output_->resize(outputSize_); return std::move(output_); } + } else if (isRightJoin(joinType_)) { + if (rightInput_ && noMoreInput_) { + // If output_ is currently wrapping a different buffer, return it + // first. + if (prepareOutput(nullptr, rightInput_)) { + output_->resize(outputSize_); + return std::move(output_); + } + + while (true) { + if (outputSize_ == outputBatchSize_) { + return std::move(output_); + } + + addOutputRowForRightJoin(rightInput_, rightIndex_); + + ++rightIndex_; + if (rightIndex_ == rightInput_->size()) { + // Ran out of rows on the right side. + rightInput_ = nullptr; + return nullptr; + } + } + } + + if (noMoreRightInput_ && output_) { + output_->resize(outputSize_); + return std::move(output_); + } } else { if (noMoreInput_ || noMoreRightInput_) { if (output_) { @@ -770,9 +847,11 @@ RowVectorPtr MergeJoin::doGetOutput() { return std::move(output_); } addOutputRowForLeftJoin(input_, index_); + ++index_; + } else { + index_ = firstNonNull(input_, leftKeys_, index_ + 1); } - ++index_; if (index_ == input_->size()) { // Ran out of rows on the left side. input_ = nullptr; @@ -783,7 +862,24 @@ RowVectorPtr MergeJoin::doGetOutput() { // Catch up rightInput_ with input_. while (compareResult > 0) { - rightIndex_ = firstNonNull(rightInput_, rightKeys_, rightIndex_ + 1); + if (isRightJoin(joinType_)) { + // If output_ is currently wrapping a different buffer, return it + // first. + if (prepareOutput(nullptr, rightInput_)) { + output_->resize(outputSize_); + return std::move(output_); + } + + if (outputSize_ == outputBatchSize_) { + return std::move(output_); + } + + addOutputRowForRightJoin(rightInput_, rightIndex_); + ++rightIndex_; + } else { + rightIndex_ = firstNonNull(rightInput_, rightKeys_, rightIndex_ + 1); + } + if (rightIndex_ == rightInput_->size()) { // Ran out of rows on the right side. rightInput_ = nullptr; @@ -862,8 +958,8 @@ RowVectorPtr MergeJoin::applyFilter(const RowVectorPtr& output) { auto rawIndices = indices->asMutable(); vector_size_t numPassed = 0; - if (leftJoinTracker_) { - const auto& filterRows = leftJoinTracker_->matchingRows(numRows); + if (joinTracker_) { + const auto& filterRows = joinTracker_->matchingRows(numRows); if (!filterRows.hasSelections()) { // No matches in the output, no need to evaluate the filter. @@ -878,9 +974,16 @@ RowVectorPtr MergeJoin::applyFilter(const RowVectorPtr& output) { if (!isAntiJoin(joinType_)) { rawIndices[numPassed++] = row; - for (auto& projection : rightProjections_) { - auto target = output->childAt(projection.outputChannel); - target->setNull(row, true); + if (!isRightJoin(joinType_)) { + for (auto& projection : rightProjections_) { + auto target = output->childAt(projection.outputChannel); + target->setNull(row, true); + } + } else { + for (auto& projection : leftProjections_) { + auto target = output->childAt(projection.outputChannel); + target->setNull(row, true); + } } } }; @@ -890,7 +993,7 @@ RowVectorPtr MergeJoin::applyFilter(const RowVectorPtr& output) { const bool passed = !decodedFilterResult_.isNullAt(i) && decodedFilterResult_.valueAt(i); - leftJoinTracker_->processFilterResult(i, passed, onMiss); + joinTracker_->processFilterResult(i, passed, onMiss); if (isAntiJoin(joinType_)) { if (!passed) { @@ -927,8 +1030,8 @@ RowVectorPtr MergeJoin::applyFilter(const RowVectorPtr& output) { // 2. leftMatch_ may not be nullopt, but may be related to a different // (subsequent) left key. So we check if the last row in the batch has the // same left row number as the last key match. - if (!leftMatch_ || !leftJoinTracker_->isCurrentLeftMatch(numRows - 1)) { - leftJoinTracker_->noMoreFilterResults(onMiss); + if (!leftMatch_ || !joinTracker_->isCurrentLeftMatch(numRows - 1)) { + joinTracker_->noMoreFilterResults(onMiss); } } else { filterRows_.resize(numRows); @@ -966,6 +1069,12 @@ void MergeJoin::evaluateFilter(const SelectivityVector& rows) { } bool MergeJoin::isFinished() { + if (isRightJoin(joinType_)) { + // If all rows on both the left and right sides match, we must also verify + // the 'noMoreInput_' on the left side to ensure that all results are + // complete. + return noMoreInput_ && noMoreRightInput_ && rightInput_ == nullptr; + } return noMoreInput_ && input_ == nullptr; } diff --git a/velox/exec/MergeJoin.h b/velox/exec/MergeJoin.h index e5f5b71999f1..42222f83ae2e 100644 --- a/velox/exec/MergeJoin.h +++ b/velox/exec/MergeJoin.h @@ -215,6 +215,13 @@ class MergeJoin : public Operator { const RowVectorPtr& left, vector_size_t leftIndex); + /// Adds one row of output for a right-side row with no left-side match. + /// Copies values from the 'rightIndex' row of 'right' and fills in nulls + /// for columns that correspond to the right side. + void addOutputRowForRightJoin( + const RowVectorPtr& right, + vector_size_t rightIndex); + /// Evaluates join filter on 'filterInput_' and returns 'output' that contains /// a subset of rows on which the filter passed. Returns nullptr if no rows /// passed the filter. @@ -231,9 +238,9 @@ class MergeJoin : public Operator { /// rows from the left side that have a match on the right. RowVectorPtr filterOutputForAntiJoin(const RowVectorPtr& output); - /// As we populate the results of the left join, we track whether a given + /// As we populate the results of the join, we track whether a given /// output row is a result of a match between left and right sides or a miss. - /// We use LeftJoinTracker::addMatch and addMiss methods for that. + /// We use JoinTracker::addMatch and addMiss methods for that. /// /// The semantic of the filter is to include at least one left side row in the /// output after filters are applied. Therefore: @@ -256,8 +263,8 @@ class MergeJoin : public Operator { /// block, we keep the subset of passing rows. However, if the filter failed /// on all rows in such a block, we add one of these rows back and update /// build-side columns to null. - struct LeftJoinTracker { - LeftJoinTracker(vector_size_t numRows, memory::MemoryPool* pool) + struct JoinTracker { + JoinTracker(vector_size_t numRows, memory::MemoryPool* pool) : matchingRows_{numRows, false} { leftRowNumbers_ = AlignedBuffer::allocate(numRows, pool); rawLeftRowNumbers_ = leftRowNumbers_->asMutable(); @@ -391,7 +398,8 @@ class MergeJoin : public Operator { bool currentRowPassed_{false}; }; - std::optional leftJoinTracker_{std::nullopt}; + /// Used to record both left and right join. + std::optional joinTracker_{std::nullopt}; // Indices buffer used by the output dictionaries. All projection from the // left share `leftIndices_`, and projections in the right share diff --git a/velox/exec/fuzzer/JoinFuzzer.cpp b/velox/exec/fuzzer/JoinFuzzer.cpp index 87bc26b15376..e3b969f46115 100644 --- a/velox/exec/fuzzer/JoinFuzzer.cpp +++ b/velox/exec/fuzzer/JoinFuzzer.cpp @@ -861,7 +861,7 @@ void JoinFuzzer::makeAlternativePlans( // Use OrderBy + MergeJoin if (joinNode->isInnerJoin() || joinNode->isLeftJoin() || joinNode->isLeftSemiFilterJoin() || joinNode->isRightSemiFilterJoin() || - joinNode->isAntiJoin()) { + joinNode->isAntiJoin() || joinNode->isRightJoin()) { auto planWithSplits = makeMergeJoinPlan( joinType, probeKeys, buildKeys, probeInput, buildInput, outputColumns); plans.push_back(planWithSplits); diff --git a/velox/exec/tests/MergeJoinTest.cpp b/velox/exec/tests/MergeJoinTest.cpp index 27b99f7373ef..a91e62ca7b17 100644 --- a/velox/exec/tests/MergeJoinTest.cpp +++ b/velox/exec/tests/MergeJoinTest.cpp @@ -155,34 +155,69 @@ class MergeJoinTest : public HiveConnectorTestBase { // Test LEFT join. planNodeIdGenerator = std::make_shared(); - plan = PlanBuilder(planNodeIdGenerator) - .values(left) - .mergeJoin( - {"c0"}, - {"u_c0"}, - PlanBuilder(planNodeIdGenerator) - .values(right) - .project({"c1 as u_c1", "c0 as u_c0"}) - .planNode(), - "", - {"c0", "c1", "u_c1"}, - core::JoinType::kLeft) - .planNode(); + auto leftPlan = PlanBuilder(planNodeIdGenerator) + .values(left) + .mergeJoin( + {"c0"}, + {"u_c0"}, + PlanBuilder(planNodeIdGenerator) + .values(right) + .project({"c1 as u_c1", "c0 as u_c0"}) + .planNode(), + "", + {"c0", "c1", "u_c1"}, + core::JoinType::kLeft) + .planNode(); // Use very small output batch size. assertQuery( - makeCursorParameters(plan, 16), + makeCursorParameters(leftPlan, 16), "SELECT t.c0, t.c1, u.c1 FROM t LEFT JOIN u ON t.c0 = u.c0"); // Use regular output batch size. assertQuery( - makeCursorParameters(plan, 1024), + makeCursorParameters(leftPlan, 1024), "SELECT t.c0, t.c1, u.c1 FROM t LEFT JOIN u ON t.c0 = u.c0"); // Use very large output batch size. assertQuery( - makeCursorParameters(plan, 10'000), + makeCursorParameters(leftPlan, 10'000), "SELECT t.c0, t.c1, u.c1 FROM t LEFT JOIN u ON t.c0 = u.c0"); + + // Test RIGHT join. + planNodeIdGenerator = std::make_shared(); + auto rightPlan = PlanBuilder(planNodeIdGenerator) + .values(right) + .mergeJoin( + {"c0"}, + {"u_c0"}, + PlanBuilder(planNodeIdGenerator) + .values(left) + .project({"c1 as u_c1", "c0 as u_c0"}) + .planNode(), + "", + {"u_c0", "u_c1", "c1"}, + core::JoinType::kRight) + .planNode(); + + // Use very small output batch size. + assertQuery( + makeCursorParameters(rightPlan, 16), + "SELECT t.c0, t.c1, u.c1 FROM u RIGHT JOIN t ON t.c0 = u.c0"); + + // Use regular output batch size. + assertQuery( + makeCursorParameters(rightPlan, 1024), + "SELECT t.c0, t.c1, u.c1 FROM u RIGHT JOIN t ON t.c0 = u.c0"); + + // Use very large output batch size. + assertQuery( + makeCursorParameters(rightPlan, 10'000), + "SELECT t.c0, t.c1, u.c1 FROM u RIGHT JOIN t ON t.c0 = u.c0"); + + // Test right join and left join with same result. + auto expectedResult = AssertQueryBuilder(leftPlan).copyResults(pool_.get()); + AssertQueryBuilder(rightPlan).assertResults(expectedResult); } }; @@ -346,7 +381,7 @@ TEST_F(MergeJoinTest, innerJoinFilter) { "SELECT t_c0, u_c0, u_c1 FROM t, u WHERE t_c0 = u_c0 AND (t_c1 + u_c1) % 2 = 0"); } -TEST_F(MergeJoinTest, leftJoinFilter) { +TEST_F(MergeJoinTest, leftAndRightJoinFilter) { // Each row on the left side has at most one match on the right side. auto left = makeRowVector( {"t_c0", "t_c1"}, @@ -366,7 +401,7 @@ TEST_F(MergeJoinTest, leftJoinFilter) { createDuckDbTable("u", {right}); auto planNodeIdGenerator = std::make_shared(); - auto plan = [&](const std::string& filter) { + auto leftPlan = [&](const std::string& filter) { return PlanBuilder(planNodeIdGenerator) .values({left}) .mergeJoin( @@ -379,11 +414,28 @@ TEST_F(MergeJoinTest, leftJoinFilter) { .planNode(); }; + auto rightPlan = [&](const std::string& filter) { + return PlanBuilder(planNodeIdGenerator) + .values({right}) + .mergeJoin( + {"u_c0"}, + {"t_c0"}, + PlanBuilder(planNodeIdGenerator).values({left}).planNode(), + filter, + {"t_c0", "t_c1", "u_c1"}, + core::JoinType::kRight) + .planNode(); + }; + // Test with different output batch sizes. for (auto batchSize : {1, 3, 16}) { assertQuery( - makeCursorParameters(plan("(t_c1 + u_c1) % 2 = 0"), batchSize), + makeCursorParameters(leftPlan("(t_c1 + u_c1) % 2 = 0"), batchSize), "SELECT t_c0, t_c1, u_c1 FROM t LEFT JOIN u ON t_c0 = u_c0 AND (t_c1 + u_c1) % 2 = 0"); + + assertQuery( + makeCursorParameters(rightPlan("(t_c1 + u_c1) % 2 = 0"), batchSize), + "SELECT t_c0, t_c1, u_c1 FROM u RIGHT JOIN t ON t_c0 = u_c0 AND (t_c1 + u_c1) % 2 = 0"); } // A left-side row with multiple matches on the right side. @@ -412,10 +464,15 @@ TEST_F(MergeJoinTest, leftJoinFilter) { "t_c1 + u_c1 > 100", "t_c1 + u_c1 < 100"}) { assertQuery( - makeCursorParameters(plan(filter), batchSize), + makeCursorParameters(leftPlan(filter), batchSize), fmt::format( "SELECT t_c0, t_c1, u_c1 FROM t LEFT JOIN u ON t_c0 = u_c0 AND {}", filter)); + assertQuery( + makeCursorParameters(rightPlan(filter), batchSize), + fmt::format( + "SELECT t_c0, t_c1, u_c1 FROM u RIGHT JOIN t ON t_c0 = u_c0 AND {}", + filter)); } } } @@ -592,6 +649,52 @@ TEST_F(MergeJoinTest, semiJoin) { core::JoinType::kRightSemiFilter); } +TEST_F(MergeJoinTest, rightJoin) { + auto left = makeRowVector( + {"t0"}, + {makeNullableFlatVector( + {1, 2, std::nullopt, 5, 6, std::nullopt})}); + + auto right = makeRowVector( + {"u0"}, + {makeNullableFlatVector( + {1, 5, 6, 8, std::nullopt, std::nullopt})}); + + createDuckDbTable("t", {left}); + createDuckDbTable("u", {right}); + + // Right join. + auto planNodeIdGenerator = std::make_shared(); + auto rightPlan = + PlanBuilder(planNodeIdGenerator) + .values({left}) + .mergeJoin( + {"t0"}, + {"u0"}, + PlanBuilder(planNodeIdGenerator).values({right}).planNode(), + "t0 > 2", + {"t0", "u0"}, + core::JoinType::kRight) + .planNode(); + AssertQueryBuilder(rightPlan, duckDbQueryRunner_) + .assertResults( + "SELECT * FROM t RIGHT JOIN u ON t.t0 = u.u0 AND t.t0 > 2"); + + auto leftPlan = + PlanBuilder(planNodeIdGenerator) + .values({right}) + .mergeJoin( + {"u0"}, + {"t0"}, + PlanBuilder(planNodeIdGenerator).values({left}).planNode(), + "t0 > 2", + {"t0", "u0"}, + core::JoinType::kLeft) + .planNode(); + auto expectedResult = AssertQueryBuilder(leftPlan).copyResults(pool_.get()); + AssertQueryBuilder(rightPlan).assertResults(expectedResult); +} + TEST_F(MergeJoinTest, nullKeys) { auto left = makeRowVector( {"t0"}, {makeNullableFlatVector({1, 2, 5, std::nullopt})});