Skip to content

Commit

Permalink
[mv](nereids) mv cost related PRs (apache#35652 apache#35701 apache#3…
Browse files Browse the repository at this point in the history
…5864 apache#36368 apache#36789 apache#34970) (apache#37097)

## Proposed changes
pick from apache#35652 apache#35701 apache#35864 apache#36368 apache#36789 apache#34970

Issue Number: close #xxx

<!--Describe your changes.-->
  • Loading branch information
englefly authored Jul 4, 2024
1 parent 077fda4 commit 26be313
Show file tree
Hide file tree
Showing 25 changed files with 494 additions and 299 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,18 @@ public void setSchema(List<Column> newSchema) throws IOException {
initColumnNameMap();
}

/**
 * Collects the leading run of key columns of this schema.
 *
 * <p>Columns are visited in schema order and collection stops at the first
 * column that is not a key, so key columns appearing after a non-key column
 * are not included.
 *
 * @return a new list holding the leading key columns; empty when the first
 *         column is not a key or the schema has no columns
 */
public List<Column> getPrefixKeyColumns() {
    List<Column> prefixKeys = Lists.newArrayList();
    for (Column column : schema) {
        if (!column.isKey()) {
            // the key prefix ends at the first non-key column
            break;
        }
        prefixKeys.add(column);
    }
    return prefixKeys;
}

/**
 * Replaces the stored schema hash with the given value.
 *
 * @param newSchemaHash the value assigned to {@code schemaHash}
 */
public void setSchemaHash(int newSchemaHash) {
this.schemaHash = newSchemaHash;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,21 +17,27 @@

package org.apache.doris.nereids.cost;

import org.apache.doris.catalog.Column;
import org.apache.doris.catalog.KeysType;
import org.apache.doris.catalog.OlapTable;
import org.apache.doris.nereids.PlanContext;
import org.apache.doris.nereids.properties.DistributionSpec;
import org.apache.doris.nereids.properties.DistributionSpecGather;
import org.apache.doris.nereids.properties.DistributionSpecHash;
import org.apache.doris.nereids.properties.DistributionSpecReplicated;
import org.apache.doris.nereids.trees.expressions.ComparisonPredicate;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.algebra.OlapScan;
import org.apache.doris.nereids.trees.plans.physical.PhysicalAssertNumRows;
import org.apache.doris.nereids.trees.plans.physical.PhysicalDeferMaterializeOlapScan;
import org.apache.doris.nereids.trees.plans.physical.PhysicalDeferMaterializeTopN;
import org.apache.doris.nereids.trees.plans.physical.PhysicalDistribute;
import org.apache.doris.nereids.trees.plans.physical.PhysicalEsScan;
import org.apache.doris.nereids.trees.plans.physical.PhysicalFileScan;
import org.apache.doris.nereids.trees.plans.physical.PhysicalFilter;
import org.apache.doris.nereids.trees.plans.physical.PhysicalGenerate;
import org.apache.doris.nereids.trees.plans.physical.PhysicalHashAggregate;
import org.apache.doris.nereids.trees.plans.physical.PhysicalHashJoin;
Expand All @@ -52,8 +58,11 @@
import org.apache.doris.statistics.Statistics;

import com.google.common.base.Preconditions;
import com.google.common.collect.Sets;

import java.util.Collections;
import java.util.List;
import java.util.Set;

class CostModelV1 extends PlanVisitor<Cost, PlanContext> {

Expand Down Expand Up @@ -113,6 +122,57 @@ public Cost visitPhysicalOlapScan(PhysicalOlapScan physicalOlapScan, PlanContext
return CostV1.ofCpu(context.getSessionVariable(), rows - aggMvBonus);
}

/**
 * Collects the columns that are compared against a literal.
 *
 * <p>Only {@link ComparisonPredicate}s with a {@link SlotReference} on one side
 * and a {@link Literal} on the other contribute. Slots that carry no column are ignored.
 *
 * @param expressions the predicates to inspect
 * @return the set of columns referenced by slot-versus-literal comparisons
 */
private Set<Column> getColumnForRangePredicate(Set<Expression> expressions) {
    Set<Column> result = Sets.newHashSet();
    for (Expression expression : expressions) {
        if (!(expression instanceof ComparisonPredicate)) {
            continue;
        }
        ComparisonPredicate predicate = (ComparisonPredicate) expression;
        Expression lhs = predicate.left();
        Expression rhs = predicate.right();
        boolean leftIsSlot = lhs instanceof SlotReference;
        boolean anySlot = leftIsSlot || rhs instanceof SlotReference;
        boolean anyLiteral = lhs instanceof Literal || rhs instanceof Literal;
        if (!anySlot || !anyLiteral) {
            continue;
        }
        // exactly one side is the slot here: prefer the left, otherwise it must be the right
        SlotReference slot = (SlotReference) (leftIsSlot ? lhs : rhs);
        if (slot.getColumn().isPresent()) {
            result.add(slot.getColumn().get());
        }
    }
    return result;
}

/**
 * Costs a filter as {@code (conjunctCount - prefixIndexMatched + exprCost) * filterCostFactor}.
 *
 * <p>{@code prefixIndexMatched} counts how many leading prefix key columns of the
 * selected index of a child {@link OlapScan} are compared against a literal by
 * this filter; counting stops at the first prefix key column that is not.
 * The factor comes from the session variable of the current {@code ConnectContext},
 * falling back to 0.0001 when there is no such context.
 */
@Override
public Cost visitPhysicalFilter(PhysicalFilter<? extends Plan> filter, PlanContext context) {
    double exprCost = expressionTreeCost(filter.getExpressions());

    ConnectContext connectContext = ConnectContext.get();
    double filterCostFactor = connectContext == null
            ? 0.0001
            : connectContext.getSessionVariable().filterCostFactor;

    OlapScan scan = filter.getGroupExpression().isPresent()
            ? (OlapScan) filter.getGroupExpression().get().getFirstChildPlan(OlapScan.class)
            : null;

    int matchedPrefixKeys = 0;
    if (scan != null) {
        // check prefix index
        long selectedIndexId = scan.getSelectedIndexId();
        List<Column> prefixKeys = scan.getTable().getIndexMetaByIndexId(selectedIndexId).getPrefixKeyColumns();
        Set<Column> comparedColumns = getColumnForRangePredicate(filter.getConjuncts());
        for (Column key : prefixKeys) {
            if (!comparedColumns.contains(key)) {
                break;
            }
            matchedPrefixKeys++;
        }
    }

    double cpuCost = (filter.getConjuncts().size() - matchedPrefixKeys + exprCost) * filterCostFactor;
    return CostV1.ofCpu(context.getSessionVariable(), cpuCost);
}

@Override
public Cost visitPhysicalDeferMaterializeOlapScan(PhysicalDeferMaterializeOlapScan deferMaterializeOlapScan,
PlanContext context) {
Expand Down Expand Up @@ -141,7 +201,8 @@ public Cost visitPhysicalFileScan(PhysicalFileScan physicalFileScan, PlanContext

/**
 * Costs a project as a constant 1 plus the expression-tree cost of its projections,
 * so that projects computing expressions are more expensive than pure slot pass-through.
 */
@Override
public Cost visitPhysicalProject(PhysicalProject<? extends Plan> physicalProject, PlanContext context) {
    double exprCost = expressionTreeCost(physicalProject.getProjects());
    return CostV1.ofCpu(context.getSessionVariable(), exprCost + 1);
}

@Override
Expand Down Expand Up @@ -252,16 +313,29 @@ public Cost visitPhysicalDistribute(
intputRowCount * childStatistics.dataSizeFactor() * RANDOM_SHUFFLE_TO_HASH_SHUFFLE_FACTOR / beNumber);
}

/**
 * Sums the cost of the given expressions as reported by {@link ExpressionCostEvaluator}.
 * Top-level {@link SlotReference}s are skipped and contribute nothing.
 *
 * @param expressions the expressions to evaluate
 * @return the accumulated cost; 0.0 for an empty list
 */
private double expressionTreeCost(List<? extends Expression> expressions) {
    ExpressionCostEvaluator evaluator = new ExpressionCostEvaluator();
    double total = 0.0;
    for (Expression expression : expressions) {
        if (expression instanceof SlotReference) {
            continue;
        }
        total += expression.accept(evaluator, null);
    }
    return total;
}

@Override
public Cost visitPhysicalHashAggregate(
PhysicalHashAggregate<? extends Plan> aggregate, PlanContext context) {
Statistics inputStatistics = context.getChildStatistics(0);
double exprCost = expressionTreeCost(aggregate.getExpressions());
if (aggregate.getAggPhase().isLocal()) {
return CostV1.of(context.getSessionVariable(), inputStatistics.getRowCount() / beNumber,
return CostV1.of(context.getSessionVariable(),
exprCost / 100 + inputStatistics.getRowCount() / beNumber,
inputStatistics.getRowCount() / beNumber, 0);
} else {
// global
return CostV1.of(context.getSessionVariable(), inputStatistics.getRowCount(),
return CostV1.of(context.getSessionVariable(), exprCost / 100 + inputStatistics.getRowCount(),
inputStatistics.getRowCount(), 0);
}
}
Expand Down Expand Up @@ -289,7 +363,7 @@ public Cost visitPhysicalHashJoin(

double leftRowCount = probeStats.getRowCount();
double rightRowCount = buildStats.getRowCount();
if (leftRowCount == rightRowCount
if ((long) leftRowCount == (long) rightRowCount
&& physicalHashJoin.getGroupExpression().isPresent()
&& physicalHashJoin.getGroupExpression().get().getOwnerGroup() != null
&& !physicalHashJoin.getGroupExpression().get().getOwnerGroup().isStatsReliable()) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

package org.apache.doris.nereids.cost;

import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import org.apache.doris.nereids.types.ArrayType;
import org.apache.doris.nereids.types.CharType;
import org.apache.doris.nereids.types.DecimalV2Type;
import org.apache.doris.nereids.types.DecimalV3Type;
import org.apache.doris.nereids.types.MapType;
import org.apache.doris.nereids.types.StringType;
import org.apache.doris.nereids.types.StructType;
import org.apache.doris.nereids.types.VarcharType;

import com.google.common.collect.Maps;

import java.util.Map;

/**
 * Estimates the computing cost of an expression tree.
 *
 * <p>The cost of a node is the sum, over its children, of the child's own cost plus a
 * weight chosen by the child's data type: decimal, string-like and complex
 * (array/map/struct) types have dedicated weights, every other type uses
 * {@link #DEFAULT_TYPE_COST}. Slot references and literals cost nothing by themselves,
 * and an alias costs exactly what its child costs.
 */
public class ExpressionCostEvaluator extends ExpressionVisitor<Double, Void> {
    /** Weight added for a child whose data type has no dedicated entry. */
    private static final double DEFAULT_TYPE_COST = 0.1;

    /** Per-data-type weight added for every child of that type; populated once and never modified. */
    private static final Map<Class<?>, Double> DATA_TYPE_COST = Maps.newHashMap();

    static {
        DATA_TYPE_COST.put(DecimalV2Type.class, 1.5);
        DATA_TYPE_COST.put(DecimalV3Type.class, 1.5);
        DATA_TYPE_COST.put(StringType.class, 2.0);
        DATA_TYPE_COST.put(CharType.class, 2.0);
        DATA_TYPE_COST.put(VarcharType.class, 2.0);
        DATA_TYPE_COST.put(ArrayType.class, 3.0);
        DATA_TYPE_COST.put(MapType.class, 3.0);
        DATA_TYPE_COST.put(StructType.class, 3.0);
    }

    /**
     * Default rule: accumulate each child's cost plus the weight of that child's data type.
     * A node without children therefore costs 0.0.
     */
    @Override
    public Double visit(Expression expr, Void context) {
        double cost = 0.0;
        for (Expression child : expr.children()) {
            cost += child.accept(this, context);
            // the more children, the more computing cost
            cost += DATA_TYPE_COST.getOrDefault(child.getDataType().getClass(), DEFAULT_TYPE_COST);
        }
        return cost;
    }

    /** A slot reference contributes no cost of its own. */
    @Override
    public Double visitSlotReference(SlotReference slot, Void context) {
        return 0.0;
    }

    /** A literal contributes no cost of its own. */
    @Override
    public Double visitLiteral(Literal literal, Void context) {
        return 0.0;
    }

    /**
     * An alias is transparent: it costs what its child costs, and nothing when it
     * merely renames a slot.
     */
    @Override
    public Double visitAlias(Alias alias, Void context) {
        Expression child = alias.child();
        if (child instanceof SlotReference) {
            return 0.0;
        }
        return child.accept(this, context);
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -349,4 +349,28 @@ public String toString() {
/**
 * Returns the value of the {@code id} field.
 */
public ObjectId getId() {
return id;
}

/**
 * Finds the first plan of the given type among the expressions of the child groups.
 *
 * <p>The logical expressions of all child groups are searched first, in child order;
 * only when none matches are the physical expressions searched (needed for dphyp).
 *
 * @param clazz the operator type to look for, like join/aggregate
 * @return the first child plan that is an instance of {@code clazz}, or {@code null} if none is found
 */
public Plan getFirstChildPlan(Class<?> clazz) {
    for (Group childGroup : children) {
        for (GroupExpression logical : childGroup.getLogicalExpressions()) {
            if (clazz.isInstance(logical.getPlan())) {
                return logical.getPlan();
            }
        }
    }
    // for dphyp
    for (Group childGroup : children) {
        for (GroupExpression physical : childGroup.getPhysicalExpressions()) {
            if (clazz.isInstance(physical.getPlan())) {
                return physical.getPlan();
            }
        }
    }
    return null;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -603,24 +603,24 @@ private Statistics computeAssertNumRows(AssertNumRowsElement assertNumRowsElemen

private Statistics computeFilter(Filter filter) {
Statistics stats = groupExpression.childStatistics(0);
Plan plan = tryToFindChild(groupExpression);
boolean isOnBaseTable = false;
if (plan != null) {
if (plan instanceof OlapScan) {
isOnBaseTable = true;
} else if (plan instanceof Aggregate) {
Aggregate agg = ((Aggregate<?>) plan);
List<NamedExpression> expressions = agg.getOutputExpressions();
Set<Slot> slots = expressions
.stream()
.filter(Alias.class::isInstance)
.filter(s -> ((Alias) s).child().anyMatch(AggregateFunction.class::isInstance))
.map(NamedExpression::toSlot).collect(Collectors.toSet());
Expression predicate = filter.getPredicate();
if (predicate.anyMatch(s -> slots.contains(s))) {
return new FilterEstimation(slots).estimate(filter.getPredicate(), stats);
}
} else if (plan instanceof LogicalJoin && filter instanceof LogicalFilter
if (groupExpression.getFirstChildPlan(OlapScan.class) != null) {
return new FilterEstimation(true).estimate(filter.getPredicate(), stats);
}
if (groupExpression.getFirstChildPlan(Aggregate.class) != null) {
Aggregate agg = (Aggregate<?>) groupExpression.getFirstChildPlan(Aggregate.class);
List<NamedExpression> expressions = agg.getOutputExpressions();
Set<Slot> slots = expressions
.stream()
.filter(Alias.class::isInstance)
.filter(s -> ((Alias) s).child().anyMatch(AggregateFunction.class::isInstance))
.map(NamedExpression::toSlot).collect(Collectors.toSet());
Expression predicate = filter.getPredicate();
if (predicate.anyMatch(s -> slots.contains(s))) {
return new FilterEstimation(slots).estimate(filter.getPredicate(), stats);
}
} else if (groupExpression.getFirstChildPlan(LogicalJoin.class) != null) {
LogicalJoin plan = (LogicalJoin) groupExpression.getFirstChildPlan(LogicalJoin.class);
if (filter instanceof LogicalFilter
&& filter.getConjuncts().stream().anyMatch(e -> e instanceof IsNull)) {
Statistics isNullStats = computeGeneratedIsNullStats((LogicalJoin) plan, filter);
if (isNullStats != null) {
Expand All @@ -640,8 +640,7 @@ private Statistics computeFilter(Filter filter) {
}
}
}

return new FilterEstimation(isOnBaseTable).estimate(filter.getPredicate(), stats);
return new FilterEstimation(false).estimate(filter.getPredicate(), stats);
}

private Statistics computeGeneratedIsNullStats(LogicalJoin join, Filter filter) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1205,6 +1205,8 @@ public void setEnableLeftZigZag(boolean enableLeftZigZag) {
@VariableMgr.VarAttr(name = ENABLE_NEW_COST_MODEL, needForward = true)
private boolean enableNewCostModel = false;

// Multiplier applied to the filter cost in CostModelV1.visitPhysicalFilter:
// (conjunct count - matched prefix key columns + expression cost) * filterCostFactor.
@VariableMgr.VarAttr(name = "filter_cost_factor", needForward = true)
public double filterCostFactor = 0.0001;
@VariableMgr.VarAttr(name = NEREIDS_STAR_SCHEMA_SUPPORT)
private boolean nereidsStarSchemaSupport = true;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,22 +26,22 @@ PhysicalResultSink
----------------------------PhysicalOlapScan[customer_demographics] apply RFs: RF4
----------------------------PhysicalDistribute[DistributionSpecReplicated]
------------------------------PhysicalProject
--------------------------------hashJoin[RIGHT_SEMI_JOIN] hashCondition=((c.c_customer_sk = store_sales.ss_customer_sk)) otherCondition=() build RFs:RF3 c_customer_sk->[ss_customer_sk]
--------------------------------hashJoin[RIGHT_SEMI_JOIN] hashCondition=((c.c_customer_sk = web_sales.ws_bill_customer_sk)) otherCondition=()
----------------------------------PhysicalDistribute[DistributionSpecHash]
------------------------------------PhysicalProject
--------------------------------------hashJoin[INNER_JOIN] hashCondition=((store_sales.ss_sold_date_sk = date_dim.d_date_sk)) otherCondition=() build RFs:RF2 d_date_sk->[ss_sold_date_sk]
--------------------------------------hashJoin[INNER_JOIN] hashCondition=((web_sales.ws_sold_date_sk = date_dim.d_date_sk)) otherCondition=() build RFs:RF3 d_date_sk->[ws_sold_date_sk]
----------------------------------------PhysicalProject
------------------------------------------PhysicalOlapScan[store_sales] apply RFs: RF2 RF3
------------------------------------------PhysicalOlapScan[web_sales] apply RFs: RF3
----------------------------------------PhysicalDistribute[DistributionSpecReplicated]
------------------------------------------PhysicalProject
--------------------------------------------filter((date_dim.d_moy <= 6) and (date_dim.d_moy >= 3) and (date_dim.d_year = 2001))
----------------------------------------------PhysicalOlapScan[date_dim]
----------------------------------hashJoin[RIGHT_SEMI_JOIN] hashCondition=((c.c_customer_sk = web_sales.ws_bill_customer_sk)) otherCondition=()
----------------------------------hashJoin[RIGHT_SEMI_JOIN] hashCondition=((c.c_customer_sk = store_sales.ss_customer_sk)) otherCondition=() build RFs:RF2 c_customer_sk->[ss_customer_sk]
------------------------------------PhysicalDistribute[DistributionSpecHash]
--------------------------------------PhysicalProject
----------------------------------------hashJoin[INNER_JOIN] hashCondition=((web_sales.ws_sold_date_sk = date_dim.d_date_sk)) otherCondition=() build RFs:RF1 d_date_sk->[ws_sold_date_sk]
----------------------------------------hashJoin[INNER_JOIN] hashCondition=((store_sales.ss_sold_date_sk = date_dim.d_date_sk)) otherCondition=() build RFs:RF1 d_date_sk->[ss_sold_date_sk]
------------------------------------------PhysicalProject
--------------------------------------------PhysicalOlapScan[web_sales] apply RFs: RF1
--------------------------------------------PhysicalOlapScan[store_sales] apply RFs: RF1 RF2
------------------------------------------PhysicalDistribute[DistributionSpecReplicated]
--------------------------------------------PhysicalProject
----------------------------------------------filter((date_dim.d_moy <= 6) and (date_dim.d_moy >= 3) and (date_dim.d_year = 2001))
Expand Down
Loading

0 comments on commit 26be313

Please sign in to comment.