From 8fcce4f591c2b3ffbf4d67b762c75281a56f6edd Mon Sep 17 00:00:00 2001 From: feiniaofeiafei Date: Sat, 16 Nov 2024 17:15:39 +0800 Subject: [PATCH] [improvement](nereids) support extract from disjunction in join on condition (#38479) (#43670) cherry-pick #38479 to branch-2.1 --- ...tSingleTableExpressionFromDisjunction.java | 54 +++++++---- ...gleTableExpressionFromDisjunctionTest.java | 36 +++++++ .../extract_from_disjunction_in_join.out | 94 +++++++++++++++++++ .../push_down_filter_through_window.out | 0 .../extract_from_disjunction_in_join.groovy | 83 ++++++++++++++++ .../push_down_filter_through_window.groovy | 0 6 files changed, 251 insertions(+), 16 deletions(-) create mode 100644 regression-test/data/nereids_rules_p0/push_down_filter/extract_from_disjunction_in_join.out rename regression-test/data/nereids_rules_p0/{push_down_filter_through_window => push_down_filter}/push_down_filter_through_window.out (100%) create mode 100644 regression-test/suites/nereids_rules_p0/push_down_filter/extract_from_disjunction_in_join.groovy rename regression-test/suites/nereids_rules_p0/{push_down_filter_through_window => push_down_filter}/push_down_filter_through_window.groovy (100%) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/ExtractSingleTableExpressionFromDisjunction.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/ExtractSingleTableExpressionFromDisjunction.java index 2f8e1404b7199e..fe2d7072ef5d98 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/ExtractSingleTableExpressionFromDisjunction.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/ExtractSingleTableExpressionFromDisjunction.java @@ -21,9 +21,11 @@ import org.apache.doris.nereids.rules.RuleType; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.Slot; +import org.apache.doris.nereids.trees.plans.JoinType; import org.apache.doris.nereids.trees.plans.logical.LogicalFilter; import org.apache.doris.nereids.util.ExpressionUtils; +import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; import com.google.common.collect.Lists; @@ -50,24 +52,44 @@ * 3. In old optimizer, there is `InferFilterRule` generates redundancy expressions. Its Nereid counterpart also need * `RemoveRedundantExpression`. *

- * TODO: This rule just match filter, but it could be applied to inner/cross join condition. */ -public class ExtractSingleTableExpressionFromDisjunction extends OneRewriteRuleFactory { +public class ExtractSingleTableExpressionFromDisjunction implements RewriteRuleFactory { + private static final ImmutableSet ALLOW_JOIN_TYPE = ImmutableSet.of(JoinType.INNER_JOIN, + JoinType.LEFT_OUTER_JOIN, JoinType.RIGHT_OUTER_JOIN, JoinType.LEFT_SEMI_JOIN, JoinType.RIGHT_SEMI_JOIN, + JoinType.LEFT_ANTI_JOIN, JoinType.RIGHT_ANTI_JOIN, JoinType.CROSS_JOIN, JoinType.FULL_OUTER_JOIN); + @Override - public Rule build() { - return logicalFilter().then(filter -> { - List dependentPredicates = extractDependentConjuncts(filter.getConjuncts()); - if (dependentPredicates.isEmpty()) { - return null; - } - Set newPredicates = ImmutableSet.builder() - .addAll(filter.getConjuncts()) - .addAll(dependentPredicates).build(); - if (newPredicates.size() == filter.getConjuncts().size()) { - return null; - } - return new LogicalFilter<>(newPredicates, filter.child()); - }).toRule(RuleType.EXTRACT_SINGLE_TABLE_EXPRESSION_FROM_DISJUNCTION); + public List buildRules() { + return ImmutableList.of( + logicalFilter().then(filter -> { + List dependentPredicates = extractDependentConjuncts(filter.getConjuncts()); + if (dependentPredicates.isEmpty()) { + return null; + } + Set newPredicates = ImmutableSet.builder() + .addAll(filter.getConjuncts()) + .addAll(dependentPredicates).build(); + if (newPredicates.size() == filter.getConjuncts().size()) { + return null; + } + return new LogicalFilter<>(newPredicates, filter.child()); + }).toRule(RuleType.EXTRACT_SINGLE_TABLE_EXPRESSION_FROM_DISJUNCTION), + logicalJoin().when(join -> ALLOW_JOIN_TYPE.contains(join.getJoinType())).then(join -> { + List dependentOtherPredicates = extractDependentConjuncts( + ImmutableSet.copyOf(join.getOtherJoinConjuncts())); + if (dependentOtherPredicates.isEmpty()) { + return null; + } + Set newOtherPredicates = ImmutableSet.builder() + .addAll(join.getOtherJoinConjuncts()) + .addAll(dependentOtherPredicates).build(); + if (newOtherPredicates.size() == join.getOtherJoinConjuncts().size()) { + return null; + } + return join.withJoinConjuncts(join.getHashJoinConjuncts(), + ImmutableList.copyOf(newOtherPredicates), + join.getMarkJoinConjuncts(), join.getJoinReorderContext()); + }).toRule(RuleType.EXTRACT_SINGLE_TABLE_EXPRESSION_FROM_DISJUNCTION)); } private List extractDependentConjuncts(Set conjuncts) { diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/ExtractSingleTableExpressionFromDisjunctionTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/ExtractSingleTableExpressionFromDisjunctionTest.java index fc55f473ee6417..39706d39f2cb0d 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/ExtractSingleTableExpressionFromDisjunctionTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/ExtractSingleTableExpressionFromDisjunctionTest.java @@ -29,6 +29,7 @@ import org.apache.doris.nereids.trees.plans.logical.LogicalFilter; import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan; +import org.apache.doris.nereids.util.ExpressionUtils; import org.apache.doris.nereids.util.MemoPatternMatchSupported; import org.apache.doris.nereids.util.MemoTestUtils; import org.apache.doris.nereids.util.PlanChecker; @@ -41,6 +42,7 @@ import org.junit.jupiter.api.Test; import org.junit.jupiter.api.TestInstance; +import java.util.List; import java.util.Set; @TestInstance(TestInstance.Lifecycle.PER_CLASS) @@ -179,4 +181,38 @@ private boolean verifySingleTableExpression3(Set conjuncts) { return conjuncts.size() == 2 && conjuncts.contains(or); } + + /** + * test join otherJoinReorderContext + *(cid=1 and sage=10) or sgender=1 + * => + * (sage=10 or sgender=1) + */ + @Test + public void testExtract4() { + Expression expr = new Or( + new And( + new EqualTo(courseCid, new IntegerLiteral(1)), + new EqualTo(studentAge, new IntegerLiteral(10)) + ), + new EqualTo(studentGender, new IntegerLiteral(1)) + ); + Plan join = new LogicalJoin<>(JoinType.CROSS_JOIN, ExpressionUtils.EMPTY_CONDITION, ImmutableList.of(expr), + student, course, null); + PlanChecker.from(MemoTestUtils.createConnectContext(), join) + .applyTopDown(new ExtractSingleTableExpressionFromDisjunction()) + .matchesFromRoot( + logicalJoin() + .when(j -> verifySingleTableExpression4(j.getOtherJoinConjuncts())) + ); + Assertions.assertNotNull(studentGender); + } + + private boolean verifySingleTableExpression4(List conjuncts) { + Expression or = new Or( + new EqualTo(studentAge, new IntegerLiteral(10)), + new EqualTo(studentGender, new IntegerLiteral(1)) + ); + return conjuncts.size() == 2 && conjuncts.contains(or); + } } diff --git a/regression-test/data/nereids_rules_p0/push_down_filter/extract_from_disjunction_in_join.out b/regression-test/data/nereids_rules_p0/push_down_filter/extract_from_disjunction_in_join.out new file mode 100644 index 00000000000000..9077ecb24b9b56 --- /dev/null +++ b/regression-test/data/nereids_rules_p0/push_down_filter/extract_from_disjunction_in_join.out @@ -0,0 +1,94 @@ +-- This file is automatically generated. You should know what you did if you want to edit this +-- !left_semi -- +PhysicalResultSink +--hashJoin[LEFT_SEMI_JOIN] hashCondition=((t1.b = t2.b)) otherCondition=((((t2.a = 9) AND (t1.a = 1)) OR ((t1.a = 2) AND (t2.a = 8)))) +----filter(a IN (1, 2)) +------PhysicalOlapScan[extract_from_disjunction_in_join_t1] +----filter(a IN (8, 9)) +------PhysicalOlapScan[extract_from_disjunction_in_join_t2] + +-- !right_semi -- +PhysicalResultSink +--hashJoin[LEFT_SEMI_JOIN] hashCondition=((t1.b = t2.b)) otherCondition=((((t2.a = 9) AND (t1.a = 1)) OR ((t1.a = 2) AND (t2.a = 8)))) +----filter(a IN (8, 9)) +------PhysicalOlapScan[extract_from_disjunction_in_join_t2] +----filter(a IN (1, 2)) +------PhysicalOlapScan[extract_from_disjunction_in_join_t1] + +-- !left -- +PhysicalResultSink +--hashJoin[LEFT_OUTER_JOIN] hashCondition=((t1.b = t2.b)) otherCondition=((((t2.a = 9) AND (t1.a = 1)) OR ((t1.a = 2) AND (t2.a = 8))) and a IN (1, 2)) +----PhysicalOlapScan[extract_from_disjunction_in_join_t1] +----filter(a IN (8, 9)) +------PhysicalOlapScan[extract_from_disjunction_in_join_t2] + +-- !right -- +PhysicalResultSink +--hashJoin[LEFT_OUTER_JOIN] hashCondition=((t1.b = t2.b)) otherCondition=((((t2.a = 9) AND (t1.a = 1)) OR ((t1.a = 2) AND (t2.a = 8))) and a IN (8, 9)) +----PhysicalOlapScan[extract_from_disjunction_in_join_t2] +----filter(a IN (1, 2)) +------PhysicalOlapScan[extract_from_disjunction_in_join_t1] + +-- !left_anti -- +PhysicalResultSink +--hashJoin[LEFT_ANTI_JOIN] hashCondition=((t1.b = t2.b)) otherCondition=((((t2.a = 9) AND (t1.a = 1)) OR ((t1.a = 2) AND (t2.a = 8))) and a IN (1, 2)) +----PhysicalOlapScan[extract_from_disjunction_in_join_t1] +----filter(a IN (8, 9)) +------PhysicalOlapScan[extract_from_disjunction_in_join_t2] + +-- !right_anti -- +PhysicalResultSink +--hashJoin[LEFT_ANTI_JOIN] hashCondition=((t1.b = t2.b)) otherCondition=((((t2.a = 9) AND (t1.a = 1)) OR ((t1.a = 2) AND (t2.a = 8))) and a IN (8, 9)) +----PhysicalOlapScan[extract_from_disjunction_in_join_t2] +----filter(a IN (1, 2)) +------PhysicalOlapScan[extract_from_disjunction_in_join_t1] + +-- !inner -- +PhysicalResultSink +--hashJoin[INNER_JOIN] hashCondition=((t1.b = t2.b)) otherCondition=((((t2.a = 9) AND (t1.a = 1)) OR ((t1.a = 2) AND (t2.a = 8)))) +----filter(a IN (1, 2)) +------PhysicalOlapScan[extract_from_disjunction_in_join_t1] +----filter(a IN (8, 9)) +------PhysicalOlapScan[extract_from_disjunction_in_join_t2] + +-- !outer -- +PhysicalResultSink +--hashJoin[LEFT_OUTER_JOIN] hashCondition=((t1.b = t2.b)) otherCondition=((((t2.a = 9) AND (t1.a = 1)) OR ((t1.a = 2) AND (t2.a = 8))) and a IN (1, 2)) +----filter((t1.c = 3)) +------PhysicalOlapScan[extract_from_disjunction_in_join_t1] +----filter(a IN (8, 9)) +------PhysicalOlapScan[extract_from_disjunction_in_join_t2] + +-- !left_semi_res -- +1 +2 + +-- !right_semi_res -- +8 +9 + +-- !left_res -- +1 +2 +3 + +-- !right_res -- +\N +1 +2 + +-- !left_anti_res -- +3 + +-- !right_anti_res -- +7 + +-- !inner_res -- +1 +2 + +-- !outer_res -- +1 +2 +3 + diff --git a/regression-test/data/nereids_rules_p0/push_down_filter_through_window/push_down_filter_through_window.out b/regression-test/data/nereids_rules_p0/push_down_filter/push_down_filter_through_window.out similarity index 100% rename from regression-test/data/nereids_rules_p0/push_down_filter_through_window/push_down_filter_through_window.out rename to regression-test/data/nereids_rules_p0/push_down_filter/push_down_filter_through_window.out diff --git a/regression-test/suites/nereids_rules_p0/push_down_filter/extract_from_disjunction_in_join.groovy b/regression-test/suites/nereids_rules_p0/push_down_filter/extract_from_disjunction_in_join.groovy new file mode 100644 index 00000000000000..858f39e5e65cf2 --- /dev/null +++ b/regression-test/suites/nereids_rules_p0/push_down_filter/extract_from_disjunction_in_join.groovy @@ -0,0 +1,83 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +suite("extract_from_disjunction_in_join") { + sql "SET enable_nereids_planner=true" + sql "SET enable_fallback_to_original_planner=false" + sql "set ignore_shape_nodes='PhysicalDistribute,PhysicalProject'" + sql "set disable_nereids_rules=PRUNE_EMPTY_PARTITION" + sql "set runtime_filter_mode=OFF" + + + sql "drop table if exists extract_from_disjunction_in_join_t1" + sql "drop table if exists extract_from_disjunction_in_join_t2" + sql """ + CREATE TABLE `extract_from_disjunction_in_join_t1` ( + `a` INT NULL, + `b` VARCHAR(10) NULL, + `c` INT NULL, + `d` INT NULL + ) ENGINE=OLAP + DUPLICATE KEY(`a`, `b`) + DISTRIBUTED BY RANDOM BUCKETS AUTO + PROPERTIES ( + "replication_allocation" = "tag.location.default: 1" + ); + """ + sql """ + CREATE TABLE `extract_from_disjunction_in_join_t2` ( + `a` INT NULL, + `b` VARCHAR(10) NULL, + `c` INT NULL, + `d` INT NULL + ) ENGINE=OLAP + DUPLICATE KEY(`a`, `b`) + DISTRIBUTED BY RANDOM BUCKETS AUTO + PROPERTIES ( + "replication_allocation" = "tag.location.default: 1" + );""" + + sql "insert into extract_from_disjunction_in_join_t1 values(1,'d2',3,5),(2,'d2',3,5),(3,'d2',3,5);" + sql "insert into extract_from_disjunction_in_join_t2 values(7,'d2',2,2),(8,'d2',2,2),(9,'d2',2,2);" + qt_left_semi """explain shape plan + select * from extract_from_disjunction_in_join_t1 t1 left semi join extract_from_disjunction_in_join_t2 t2 on t1.b=t2.b and (t2.a=9 && t1.a=1 || t1.a=2 && t2.a=8);""" + qt_right_semi """explain shape plan + select * from extract_from_disjunction_in_join_t1 t1 right semi join extract_from_disjunction_in_join_t2 t2 on t1.b=t2.b and (t2.a=9 && t1.a=1 || t1.a=2 && t2.a=8);""" + qt_left """explain shape plan + select * from extract_from_disjunction_in_join_t1 t1 left join extract_from_disjunction_in_join_t2 t2 on t1.b=t2.b and (t2.a=9 && t1.a=1 || t1.a=2 && t2.a=8);""" + qt_right """explain shape plan + select * from extract_from_disjunction_in_join_t1 t1 right join extract_from_disjunction_in_join_t2 t2 on t1.b=t2.b and (t2.a=9 && t1.a=1 || t1.a=2 && t2.a=8);""" + qt_left_anti """explain shape plan + select * from extract_from_disjunction_in_join_t1 t1 left anti join extract_from_disjunction_in_join_t2 t2 on t1.b=t2.b and (t2.a=9 && t1.a=1 || t1.a=2 && t2.a=8);""" + qt_right_anti """explain shape plan + select * from extract_from_disjunction_in_join_t1 t1 right anti join extract_from_disjunction_in_join_t2 t2 on t1.b=t2.b and (t2.a=9 && t1.a=1 || t1.a=2 && t2.a=8);""" + qt_inner """explain shape plan + select * from extract_from_disjunction_in_join_t1 t1 inner join extract_from_disjunction_in_join_t2 t2 on t1.b=t2.b and (t2.a=9 && t1.a=1 || t1.a=2 && t2.a=8);""" + qt_outer """explain shape plan + select * from extract_from_disjunction_in_join_t1 t1 full join extract_from_disjunction_in_join_t2 t2 on t1.b=t2.b and (t2.a=9 && t1.a=1 || t1.a=2 && t2.a=8) + where t1.c=3;""" + + qt_left_semi_res "select t1.a from extract_from_disjunction_in_join_t1 t1 left semi join extract_from_disjunction_in_join_t2 t2 on t1.b=t2.b and (t2.a=9 && t1.a=1 || t1.a=2 && t2.a=8) order by 1;" + qt_right_semi_res "select t2.a from extract_from_disjunction_in_join_t1 t1 right semi join extract_from_disjunction_in_join_t2 t2 on t1.b=t2.b and (t2.a=9 && t1.a=1 || t1.a=2 && t2.a=8) order by 1;" + qt_left_res "select t1.a from extract_from_disjunction_in_join_t1 t1 left join extract_from_disjunction_in_join_t2 t2 on t1.b=t2.b and (t2.a=9 && t1.a=1 || t1.a=2 && t2.a=8) order by 1;" + qt_right_res "select t1.a from extract_from_disjunction_in_join_t1 t1 right join extract_from_disjunction_in_join_t2 t2 on t1.b=t2.b and (t2.a=9 && t1.a=1 || t1.a=2 && t2.a=8) order by 1;" + qt_left_anti_res "select t1.a from extract_from_disjunction_in_join_t1 t1 left anti join extract_from_disjunction_in_join_t2 t2 on t1.b=t2.b and (t2.a=9 && t1.a=1 || t1.a=2 && t2.a=8) order by 1;" + qt_right_anti_res "select t2.a from extract_from_disjunction_in_join_t1 t1 right anti join extract_from_disjunction_in_join_t2 t2 on t1.b=t2.b and (t2.a=9 && t1.a=1 || t1.a=2 && t2.a=8) order by 1;" + qt_inner_res "select t1.a from extract_from_disjunction_in_join_t1 t1 inner join extract_from_disjunction_in_join_t2 t2 on t1.b=t2.b and (t2.a=9 && t1.a=1 || t1.a=2 && t2.a=8) order by 1;" + qt_outer_res """select t1.a from extract_from_disjunction_in_join_t1 t1 full join extract_from_disjunction_in_join_t2 t2 on t1.b=t2.b and (t2.a=9 && t1.a=1 || t1.a=2 && t2.a=8) + where t1.c=3 order by 1;""" +} \ No newline at end of file diff --git a/regression-test/suites/nereids_rules_p0/push_down_filter_through_window/push_down_filter_through_window.groovy b/regression-test/suites/nereids_rules_p0/push_down_filter/push_down_filter_through_window.groovy similarity index 100% rename from regression-test/suites/nereids_rules_p0/push_down_filter_through_window/push_down_filter_through_window.groovy rename to regression-test/suites/nereids_rules_p0/push_down_filter/push_down_filter_through_window.groovy