Skip to content

Commit

Permalink
[Enhancement] Prune mv that not contains all columns used in sub quer…
Browse files Browse the repository at this point in the history
…y when mv is SPG (backport #51044) (#51106)

Co-authored-by: kaijianding <[email protected]>
  • Loading branch information
mergify[bot] and kaijianding authored Sep 20, 2024
1 parent 80ef3ea commit c2b00bc
Show file tree
Hide file tree
Showing 5 changed files with 52 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -681,7 +681,8 @@ public AggRewriteInfo visitLogicalTableScan(OptExpression optExpression, AggRewr
// rewrite by mv.
OptExpression rewritten = doRewritePushDownAgg(ctx, optAggOp);
if (rewritten == null) {
logMVRewrite(mvRewriteContext, "Rewrite table scan node by mv failed");
logMVRewrite(mvRewriteContext,
"Rewrite table " + scanOp.getTable().getTableIdentifier() + " scan node by mv failed");
return AggRewriteInfo.NOT_REWRITE;
}
// Generate the push down aggregate function for the given call operator.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2674,7 +2674,7 @@ private static void getPermutationsOfTableIds(List<Integer> tableIds, int target
private Map<Table, Set<Integer>> getTableToRelationid(
OptExpression optExpression, ColumnRefFactory refFactory, List<Table> tableList) {
Map<Table, Set<Integer>> tableToRelationId = Maps.newHashMap();
List<ColumnRefOperator> validColumnRefs = MvUtils.collectScanColumn(optExpression);
Set<ColumnRefOperator> validColumnRefs = MvUtils.collectScanColumn(optExpression);
for (Map.Entry<ColumnRefOperator, Table> entry : refFactory.getColumnRefToTable().entrySet()) {
if (!tableList.contains(entry.getValue())) {
continue;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,8 @@

import java.time.LocalDateTime;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
Expand Down Expand Up @@ -367,6 +369,7 @@ public static ScalarOperator compensateQueryPartitionPredicate(MaterializationCo
Map<Pair<LogicalScanOperator, Boolean>, List<ScalarOperator>> scanOperatorScalarOperatorMap =
mvContext.getScanOpToPartitionCompensatePredicates();
MaterializedView mv = mvContext.getMv();
final Set<Table> baseTables = new HashSet<>(mvContext.getBaseTables());
for (LogicalScanOperator scanOperator : scanOperators) {
if (!SUPPORTED_PARTITION_COMPENSATE_SCAN_TYPES.contains(scanOperator.getOpType())) {
// If the scan operator is not supported, then return null when compensate type is not NO_COMPENSATE
Expand All @@ -379,6 +382,9 @@ public static ScalarOperator compensateQueryPartitionPredicate(MaterializationCo
}
List<ScalarOperator> partitionPredicate = scanOperatorScalarOperatorMap
.computeIfAbsent(Pair.create(scanOperator, isCompensatePartition), x -> {
if (!baseTables.contains(scanOperator.getTable())) {
return Collections.emptyList();
}
return isCompensatePartition ? getCompensatePartitionPredicates(mvContext, columnRefFactory,
scanOperator) : getScanOpPrunedPartitionPredicates(mv, scanOperator);
});
Expand Down Expand Up @@ -503,26 +509,26 @@ private static List<ScalarOperator> compensatePartitionPredicateForExternalTable
*/
private static List<ScalarOperator> compensatePartitionPredicateForOlapScan(LogicalOlapScanOperator olapScanOperator,
ColumnRefFactory columnRefFactory) {
List<ScalarOperator> partitionPredicates = Lists.newArrayList();
Preconditions.checkState(olapScanOperator.getTable().isNativeTableOrMaterializedView());
OlapTable olapTable = (OlapTable) olapScanOperator.getTable();

// compensate nothing for single partition table
if (olapTable.getPartitionInfo() instanceof SinglePartitionInfo) {
return partitionPredicates;
return Collections.emptyList();
}

// compensate nothing if selected partitions are the same with the total partitions.
if (olapScanOperator.getSelectedPartitionId() != null
&& olapScanOperator.getSelectedPartitionId().size() == olapTable.getPartitions().size()) {
return partitionPredicates;
return Collections.emptyList();
}

// if no partitions are selected, return pruned partition predicates directly.
if (olapScanOperator.getSelectedPartitionId().isEmpty()) {
return olapScanOperator.getPrunedPartitionPredicates();
}

List<ScalarOperator> partitionPredicates = Lists.newArrayList();
if (olapTable.getPartitionInfo() instanceof ExpressionRangePartitionInfo) {
ExpressionRangePartitionInfo partitionInfo = (ExpressionRangePartitionInfo) olapTable.getPartitionInfo();
Expr partitionExpr = partitionInfo.getPartitionExprs(olapTable.getIdToColumn()).get(0);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -871,14 +871,14 @@ public static Map<ColumnRefOperator, ScalarOperator> getColumnRefMap(
return columnRefMap;
}

public static List<ColumnRefOperator> collectScanColumn(OptExpression optExpression) {
public static Set<ColumnRefOperator> collectScanColumn(OptExpression optExpression) {
return collectScanColumn(optExpression, Predicates.alwaysTrue());
}

public static List<ColumnRefOperator> collectScanColumn(OptExpression optExpression,
public static Set<ColumnRefOperator> collectScanColumn(OptExpression optExpression,
Predicate<LogicalScanOperator> predicate) {

List<ColumnRefOperator> columnRefOperators = Lists.newArrayList();
Set<ColumnRefOperator> columnRefOperators = Sets.newHashSet();
OptExpressionVisitor visitor = new OptExpressionVisitor<Void, Void>() {
@Override
public Void visit(OptExpression optExpression, Void context) {
Expand Down Expand Up @@ -1163,7 +1163,7 @@ public static void inactiveRelatedMaterializedViews(Database db,
for (MvPlanContext mvPlanContext : mvPlanContexts) {
if (mvPlanContext != null) {
OptExpression mvPlan = mvPlanContext.getLogicalPlan();
List<ColumnRefOperator> usedColRefs = MvUtils.collectScanColumn(mvPlan, scan -> {
Set<ColumnRefOperator> usedColRefs = MvUtils.collectScanColumn(mvPlan, scan -> {
if (scan == null) {
return false;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
package com.starrocks.sql.optimizer.rule.transformation.materialization.rule;

import com.google.api.client.util.Lists;
import com.google.common.base.Predicate;
import com.starrocks.catalog.Column;
import com.starrocks.sql.optimizer.MaterializationContext;
import com.starrocks.sql.optimizer.MvRewriteContext;
import com.starrocks.sql.optimizer.OptExpression;
Expand All @@ -28,11 +30,17 @@
import com.starrocks.sql.optimizer.operator.logical.LogicalProjectOperator;
import com.starrocks.sql.optimizer.operator.logical.LogicalScanOperator;
import com.starrocks.sql.optimizer.operator.pattern.Pattern;
import com.starrocks.sql.optimizer.operator.scalar.ColumnRefOperator;
import com.starrocks.sql.optimizer.rule.RuleType;
import com.starrocks.sql.optimizer.rule.transformation.materialization.AggregatedMaterializedViewPushDownRewriter;
import com.starrocks.sql.optimizer.rule.transformation.materialization.IMaterializedViewRewriter;
import com.starrocks.sql.optimizer.rule.transformation.materialization.MvUtils;

import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;

import static com.starrocks.sql.optimizer.OptimizerTraceUtil.logMVRewrite;

/**
* Support to push down aggregate functions below join operator and rewrite the query by mv transparently.
Expand Down Expand Up @@ -121,15 +129,42 @@ public static boolean isLogicalSPG(OptExpression root) {
public List<MaterializationContext> doPrune(OptExpression queryExpression,
OptimizerContext context,
List<MaterializationContext> mvCandidateContexts) {
List<LogicalScanOperator> scanOperators = MvUtils.getScanOperator(queryExpression);
List<MaterializationContext> validCandidateContexts = Lists.newArrayList();
for (MaterializationContext mvContext : mvCandidateContexts) {
if (isLogicalSPG(mvContext.getMvExpression())) {
if (isLogicalSPG(mvContext.getMvExpression()) && validMv(mvContext, scanOperators)) {
validCandidateContexts.add(mvContext);
} else {
logMVRewrite(mvContext, "mv pruned");
}
}
return validCandidateContexts;
}

private boolean validMv(MaterializationContext mvContext, List<LogicalScanOperator> scanOperators) {
// mv is SPG, so there is only one baseTable
long baseTableId = mvContext.getBaseTables().get(0).getId();
Set<ColumnRefOperator> mvUsedColRefs = MvUtils.collectScanColumn(mvContext.getMvExpression());
Set<String> mvUsedColNames = mvUsedColRefs.stream()
.map(ColumnRefOperator::getName)
.collect(Collectors.toSet());
for (LogicalScanOperator scanOperator : scanOperators) {
if (scanOperator.getTable().getId() != baseTableId) {
continue;
}
// mv should contain all columns that used in at least one query
if (mvContainsAllColumnsUsedInScan(mvUsedColNames, scanOperator)) {
return true;
}
}
return false;
}

private boolean mvContainsAllColumnsUsedInScan(Set<String> mvUsedColNames, LogicalScanOperator scanOperator) {
return scanOperator.getColRefToColumnMetaMap().values().stream().allMatch(
(Predicate<Column>) c -> mvUsedColNames.contains(c.getName()));
}

@Override
public IMaterializedViewRewriter createRewriter(OptimizerContext optimizerContext,
MvRewriteContext mvContext) {
Expand Down

0 comments on commit c2b00bc

Please sign in to comment.