Skip to content

Commit

Permalink
[improvement](nereids) support extract from disjunction in join on co…
Browse files Browse the repository at this point in the history
…ndition (#38479) (#43670)

cherry-pick #38479 to branch-2.1
  • Loading branch information
feiniaofeiafei authored Nov 16, 2024
1 parent b120e89 commit 8fcce4f
Show file tree
Hide file tree
Showing 6 changed files with 251 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,11 @@
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
import org.apache.doris.nereids.util.ExpressionUtils;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Lists;

Expand All @@ -50,24 +52,44 @@
* 3. In old optimizer, there is `InferFilterRule` generates redundancy expressions. Its Nereid counterpart also need
* `RemoveRedundantExpression`.
* <p>
* TODO: This rule just match filter, but it could be applied to inner/cross join condition.
*/
public class ExtractSingleTableExpressionFromDisjunction extends OneRewriteRuleFactory {
public class ExtractSingleTableExpressionFromDisjunction implements RewriteRuleFactory {
private static final ImmutableSet<JoinType> ALLOW_JOIN_TYPE = ImmutableSet.of(JoinType.INNER_JOIN,
JoinType.LEFT_OUTER_JOIN, JoinType.RIGHT_OUTER_JOIN, JoinType.LEFT_SEMI_JOIN, JoinType.RIGHT_SEMI_JOIN,
JoinType.LEFT_ANTI_JOIN, JoinType.RIGHT_ANTI_JOIN, JoinType.CROSS_JOIN, JoinType.FULL_OUTER_JOIN);

@Override
public Rule build() {
return logicalFilter().then(filter -> {
List<Expression> dependentPredicates = extractDependentConjuncts(filter.getConjuncts());
if (dependentPredicates.isEmpty()) {
return null;
}
Set<Expression> newPredicates = ImmutableSet.<Expression>builder()
.addAll(filter.getConjuncts())
.addAll(dependentPredicates).build();
if (newPredicates.size() == filter.getConjuncts().size()) {
return null;
}
return new LogicalFilter<>(newPredicates, filter.child());
}).toRule(RuleType.EXTRACT_SINGLE_TABLE_EXPRESSION_FROM_DISJUNCTION);
public List<Rule> buildRules() {
return ImmutableList.of(
logicalFilter().then(filter -> {
List<Expression> dependentPredicates = extractDependentConjuncts(filter.getConjuncts());
if (dependentPredicates.isEmpty()) {
return null;
}
Set<Expression> newPredicates = ImmutableSet.<Expression>builder()
.addAll(filter.getConjuncts())
.addAll(dependentPredicates).build();
if (newPredicates.size() == filter.getConjuncts().size()) {
return null;
}
return new LogicalFilter<>(newPredicates, filter.child());
}).toRule(RuleType.EXTRACT_SINGLE_TABLE_EXPRESSION_FROM_DISJUNCTION),
logicalJoin().when(join -> ALLOW_JOIN_TYPE.contains(join.getJoinType())).then(join -> {
List<Expression> dependentOtherPredicates = extractDependentConjuncts(
ImmutableSet.copyOf(join.getOtherJoinConjuncts()));
if (dependentOtherPredicates.isEmpty()) {
return null;
}
Set<Expression> newOtherPredicates = ImmutableSet.<Expression>builder()
.addAll(join.getOtherJoinConjuncts())
.addAll(dependentOtherPredicates).build();
if (newOtherPredicates.size() == join.getOtherJoinConjuncts().size()) {
return null;
}
return join.withJoinConjuncts(join.getHashJoinConjuncts(),
ImmutableList.copyOf(newOtherPredicates),
join.getMarkJoinConjuncts(), join.getJoinReorderContext());
}).toRule(RuleType.EXTRACT_SINGLE_TABLE_EXPRESSION_FROM_DISJUNCTION));
}

private List<Expression> extractDependentConjuncts(Set<Expression> conjuncts) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
import org.apache.doris.nereids.util.ExpressionUtils;
import org.apache.doris.nereids.util.MemoPatternMatchSupported;
import org.apache.doris.nereids.util.MemoTestUtils;
import org.apache.doris.nereids.util.PlanChecker;
Expand All @@ -41,6 +42,7 @@
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.TestInstance;

import java.util.List;
import java.util.Set;

@TestInstance(TestInstance.Lifecycle.PER_CLASS)
Expand Down Expand Up @@ -179,4 +181,38 @@ private boolean verifySingleTableExpression3(Set<Expression> conjuncts) {

return conjuncts.size() == 2 && conjuncts.contains(or);
}

/**
* test join otherJoinReorderContext
*(cid=1 and sage=10) or sgender=1
* =>
* (sage=10 or sgender=1)
*/
@Test
public void testExtract4() {
Expression expr = new Or(
new And(
new EqualTo(courseCid, new IntegerLiteral(1)),
new EqualTo(studentAge, new IntegerLiteral(10))
),
new EqualTo(studentGender, new IntegerLiteral(1))
);
Plan join = new LogicalJoin<>(JoinType.CROSS_JOIN, ExpressionUtils.EMPTY_CONDITION, ImmutableList.of(expr),
student, course, null);
PlanChecker.from(MemoTestUtils.createConnectContext(), join)
.applyTopDown(new ExtractSingleTableExpressionFromDisjunction())
.matchesFromRoot(
logicalJoin()
.when(j -> verifySingleTableExpression4(j.getOtherJoinConjuncts()))
);
Assertions.assertNotNull(studentGender);
}

private boolean verifySingleTableExpression4(List<Expression> conjuncts) {
Expression or = new Or(
new EqualTo(studentAge, new IntegerLiteral(10)),
new EqualTo(studentGender, new IntegerLiteral(1))
);
return conjuncts.size() == 2 && conjuncts.contains(or);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
-- This file is automatically generated. You should know what you did if you want to edit this
-- !left_semi --
PhysicalResultSink
--hashJoin[LEFT_SEMI_JOIN] hashCondition=((t1.b = t2.b)) otherCondition=((((t2.a = 9) AND (t1.a = 1)) OR ((t1.a = 2) AND (t2.a = 8))))
----filter(a IN (1, 2))
------PhysicalOlapScan[extract_from_disjunction_in_join_t1]
----filter(a IN (8, 9))
------PhysicalOlapScan[extract_from_disjunction_in_join_t2]

-- !right_semi --
PhysicalResultSink
--hashJoin[LEFT_SEMI_JOIN] hashCondition=((t1.b = t2.b)) otherCondition=((((t2.a = 9) AND (t1.a = 1)) OR ((t1.a = 2) AND (t2.a = 8))))
----filter(a IN (8, 9))
------PhysicalOlapScan[extract_from_disjunction_in_join_t2]
----filter(a IN (1, 2))
------PhysicalOlapScan[extract_from_disjunction_in_join_t1]

-- !left --
PhysicalResultSink
--hashJoin[LEFT_OUTER_JOIN] hashCondition=((t1.b = t2.b)) otherCondition=((((t2.a = 9) AND (t1.a = 1)) OR ((t1.a = 2) AND (t2.a = 8))) and a IN (1, 2))
----PhysicalOlapScan[extract_from_disjunction_in_join_t1]
----filter(a IN (8, 9))
------PhysicalOlapScan[extract_from_disjunction_in_join_t2]

-- !right --
PhysicalResultSink
--hashJoin[LEFT_OUTER_JOIN] hashCondition=((t1.b = t2.b)) otherCondition=((((t2.a = 9) AND (t1.a = 1)) OR ((t1.a = 2) AND (t2.a = 8))) and a IN (8, 9))
----PhysicalOlapScan[extract_from_disjunction_in_join_t2]
----filter(a IN (1, 2))
------PhysicalOlapScan[extract_from_disjunction_in_join_t1]

-- !left_anti --
PhysicalResultSink
--hashJoin[LEFT_ANTI_JOIN] hashCondition=((t1.b = t2.b)) otherCondition=((((t2.a = 9) AND (t1.a = 1)) OR ((t1.a = 2) AND (t2.a = 8))) and a IN (1, 2))
----PhysicalOlapScan[extract_from_disjunction_in_join_t1]
----filter(a IN (8, 9))
------PhysicalOlapScan[extract_from_disjunction_in_join_t2]

-- !right_anti --
PhysicalResultSink
--hashJoin[LEFT_ANTI_JOIN] hashCondition=((t1.b = t2.b)) otherCondition=((((t2.a = 9) AND (t1.a = 1)) OR ((t1.a = 2) AND (t2.a = 8))) and a IN (8, 9))
----PhysicalOlapScan[extract_from_disjunction_in_join_t2]
----filter(a IN (1, 2))
------PhysicalOlapScan[extract_from_disjunction_in_join_t1]

-- !inner --
PhysicalResultSink
--hashJoin[INNER_JOIN] hashCondition=((t1.b = t2.b)) otherCondition=((((t2.a = 9) AND (t1.a = 1)) OR ((t1.a = 2) AND (t2.a = 8))))
----filter(a IN (1, 2))
------PhysicalOlapScan[extract_from_disjunction_in_join_t1]
----filter(a IN (8, 9))
------PhysicalOlapScan[extract_from_disjunction_in_join_t2]

-- !outer --
PhysicalResultSink
--hashJoin[LEFT_OUTER_JOIN] hashCondition=((t1.b = t2.b)) otherCondition=((((t2.a = 9) AND (t1.a = 1)) OR ((t1.a = 2) AND (t2.a = 8))) and a IN (1, 2))
----filter((t1.c = 3))
------PhysicalOlapScan[extract_from_disjunction_in_join_t1]
----filter(a IN (8, 9))
------PhysicalOlapScan[extract_from_disjunction_in_join_t2]

-- !left_semi_res --
1
2

-- !right_semi_res --
8
9

-- !left_res --
1
2
3

-- !right_res --
\N
1
2

-- !left_anti_res --
3

-- !right_anti_res --
7

-- !inner_res --
1
2

-- !outer_res --
1
2
3

Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

suite("extract_from_disjunction_in_join") {
sql "SET enable_nereids_planner=true"
sql "SET enable_fallback_to_original_planner=false"
sql "set ignore_shape_nodes='PhysicalDistribute,PhysicalProject'"
sql "set disable_nereids_rules=PRUNE_EMPTY_PARTITION"
sql "set runtime_filter_mode=OFF"


sql "drop table if exists extract_from_disjunction_in_join_t1"
sql "drop table if exists extract_from_disjunction_in_join_t2"
sql """
CREATE TABLE `extract_from_disjunction_in_join_t1` (
`a` INT NULL,
`b` VARCHAR(10) NULL,
`c` INT NULL,
`d` INT NULL
) ENGINE=OLAP
DUPLICATE KEY(`a`, `b`)
DISTRIBUTED BY RANDOM BUCKETS AUTO
PROPERTIES (
"replication_allocation" = "tag.location.default: 1"
);
"""
sql """
CREATE TABLE `extract_from_disjunction_in_join_t2` (
`a` INT NULL,
`b` VARCHAR(10) NULL,
`c` INT NULL,
`d` INT NULL
) ENGINE=OLAP
DUPLICATE KEY(`a`, `b`)
DISTRIBUTED BY RANDOM BUCKETS AUTO
PROPERTIES (
"replication_allocation" = "tag.location.default: 1"
);"""

sql "insert into extract_from_disjunction_in_join_t1 values(1,'d2',3,5),(2,'d2',3,5),(3,'d2',3,5);"
sql "insert into extract_from_disjunction_in_join_t2 values(7,'d2',2,2),(8,'d2',2,2),(9,'d2',2,2);"
qt_left_semi """explain shape plan
select * from extract_from_disjunction_in_join_t1 t1 left semi join extract_from_disjunction_in_join_t2 t2 on t1.b=t2.b and (t2.a=9 && t1.a=1 || t1.a=2 && t2.a=8);"""
qt_right_semi """explain shape plan
select * from extract_from_disjunction_in_join_t1 t1 right semi join extract_from_disjunction_in_join_t2 t2 on t1.b=t2.b and (t2.a=9 && t1.a=1 || t1.a=2 && t2.a=8);"""
qt_left """explain shape plan
select * from extract_from_disjunction_in_join_t1 t1 left join extract_from_disjunction_in_join_t2 t2 on t1.b=t2.b and (t2.a=9 && t1.a=1 || t1.a=2 && t2.a=8);"""
qt_right """explain shape plan
select * from extract_from_disjunction_in_join_t1 t1 right join extract_from_disjunction_in_join_t2 t2 on t1.b=t2.b and (t2.a=9 && t1.a=1 || t1.a=2 && t2.a=8);"""
qt_left_anti """explain shape plan
select * from extract_from_disjunction_in_join_t1 t1 left anti join extract_from_disjunction_in_join_t2 t2 on t1.b=t2.b and (t2.a=9 && t1.a=1 || t1.a=2 && t2.a=8);"""
qt_right_anti """explain shape plan
select * from extract_from_disjunction_in_join_t1 t1 right anti join extract_from_disjunction_in_join_t2 t2 on t1.b=t2.b and (t2.a=9 && t1.a=1 || t1.a=2 && t2.a=8);"""
qt_inner """explain shape plan
select * from extract_from_disjunction_in_join_t1 t1 inner join extract_from_disjunction_in_join_t2 t2 on t1.b=t2.b and (t2.a=9 && t1.a=1 || t1.a=2 && t2.a=8);"""
qt_outer """explain shape plan
select * from extract_from_disjunction_in_join_t1 t1 full join extract_from_disjunction_in_join_t2 t2 on t1.b=t2.b and (t2.a=9 && t1.a=1 || t1.a=2 && t2.a=8)
where t1.c=3;"""

qt_left_semi_res "select t1.a from extract_from_disjunction_in_join_t1 t1 left semi join extract_from_disjunction_in_join_t2 t2 on t1.b=t2.b and (t2.a=9 && t1.a=1 || t1.a=2 && t2.a=8) order by 1;"
qt_right_semi_res "select t2.a from extract_from_disjunction_in_join_t1 t1 right semi join extract_from_disjunction_in_join_t2 t2 on t1.b=t2.b and (t2.a=9 && t1.a=1 || t1.a=2 && t2.a=8) order by 1;"
qt_left_res "select t1.a from extract_from_disjunction_in_join_t1 t1 left join extract_from_disjunction_in_join_t2 t2 on t1.b=t2.b and (t2.a=9 && t1.a=1 || t1.a=2 && t2.a=8) order by 1;"
qt_right_res "select t1.a from extract_from_disjunction_in_join_t1 t1 right join extract_from_disjunction_in_join_t2 t2 on t1.b=t2.b and (t2.a=9 && t1.a=1 || t1.a=2 && t2.a=8) order by 1;"
qt_left_anti_res "select t1.a from extract_from_disjunction_in_join_t1 t1 left anti join extract_from_disjunction_in_join_t2 t2 on t1.b=t2.b and (t2.a=9 && t1.a=1 || t1.a=2 && t2.a=8) order by 1;"
qt_right_anti_res "select t2.a from extract_from_disjunction_in_join_t1 t1 right anti join extract_from_disjunction_in_join_t2 t2 on t1.b=t2.b and (t2.a=9 && t1.a=1 || t1.a=2 && t2.a=8) order by 1;"
qt_inner_res "select t1.a from extract_from_disjunction_in_join_t1 t1 inner join extract_from_disjunction_in_join_t2 t2 on t1.b=t2.b and (t2.a=9 && t1.a=1 || t1.a=2 && t2.a=8) order by 1;"
qt_outer_res """select t1.a from extract_from_disjunction_in_join_t1 t1 full join extract_from_disjunction_in_join_t2 t2 on t1.b=t2.b and (t2.a=9 && t1.a=1 || t1.a=2 && t2.a=8)
where t1.c=3 order by 1;"""
}

0 comments on commit 8fcce4f

Please sign in to comment.