Skip to content

Commit

Permalink
split other agg function into seperate aggregate
Browse files Browse the repository at this point in the history
  • Loading branch information
feiniaofeiafei committed Dec 18, 2024
1 parent 23c563e commit 3951e5c
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -152,25 +152,46 @@ public Plan visitLogicalAggregate(LogicalAggregate<? extends Plan> agg, Distinct
Alias alias = new Alias(newDistinctAggFunc);
outputExpressions.add(alias);
if (i == 0) {
List<Expression> otherAggFuncAliases = otherAggFuncs.stream()
.map(e -> ExpressionUtils.replace(e, producerToConsumerSlotMap)).collect(Collectors.toList());
for (Expression otherAggFuncAlias : otherAggFuncAliases) {
// otherAggFunc is instance of Alias
Alias outputOtherFunc = new Alias(otherAggFuncAlias.child(0));
outputExpressions.add(outputOtherFunc);
newToOriginDistinctFuncAlias.put(outputOtherFunc, (Alias) otherAggFuncAlias);
}
// save replacedGroupBy
outputJoinGroupBys.addAll(replacedGroupBy);
}
LogicalAggregate<Plan> newAgg = new LogicalAggregate<>(replacedGroupBy, outputExpressions, consumer);
newAggs.add(newAgg);
newToOriginDistinctFuncAlias.put(alias, distinctFuncWithAlias.get(i));
}
buildOtherAggFuncAggregate(otherAggFuncs, producer, ctx, cloneAgg, newToOriginDistinctFuncAlias, newAggs);
List<Expression> groupBy = agg.getGroupByExpressions();
LogicalJoin<Plan, Plan> join = constructJoin(newAggs, groupBy);
return constructProject(groupBy, newToOriginDistinctFuncAlias,
outputJoinGroupBys, join);
return constructProject(groupBy, newToOriginDistinctFuncAlias, outputJoinGroupBys, join);
}

private static void buildOtherAggFuncAggregate(List<Alias> otherAggFuncs, LogicalCTEProducer<Plan> producer,
DistinctSplitContext ctx, LogicalAggregate<Plan> cloneAgg, Map<Alias, Alias> newToOriginDistinctFuncAlias,
List<LogicalAggregate<Plan>> newAggs) {
if (otherAggFuncs.isEmpty()) {
return;
}
LogicalCTEConsumer consumer = new LogicalCTEConsumer(ctx.statementContext.getNextRelationId(),
producer.getCteId(), "", producer);
ctx.cascadesContext.putCTEIdToConsumer(consumer);
Map<Slot, Slot> producerToConsumerSlotMap = new HashMap<>();
for (Map.Entry<Slot, Slot> entry : consumer.getConsumerToProducerOutputMap().entrySet()) {
producerToConsumerSlotMap.put(entry.getValue(), entry.getKey());
}
List<Expression> replacedGroupBy = ExpressionUtils.replace(cloneAgg.getGroupByExpressions(),
producerToConsumerSlotMap);
List<NamedExpression> outputExpressions = replacedGroupBy.stream()
.map(Slot.class::cast).collect(Collectors.toList());
List<Expression> otherAggFuncAliases = otherAggFuncs.stream()
.map(e -> ExpressionUtils.replace(e, producerToConsumerSlotMap)).collect(Collectors.toList());
for (Expression otherAggFuncAlias : otherAggFuncAliases) {
// otherAggFunc is instance of Alias
Alias outputOtherFunc = new Alias(otherAggFuncAlias.child(0));
outputExpressions.add(outputOtherFunc);
newToOriginDistinctFuncAlias.put(outputOtherFunc, (Alias) otherAggFuncAlias);
}
LogicalAggregate<Plan> newAgg = new LogicalAggregate<>(replacedGroupBy, outputExpressions, consumer);
newAggs.add(newAgg);
}

private static boolean isDistinctMultiColumns(AggregateFunction func) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -411,10 +411,6 @@ PhysicalCteAnchor ( cteId=CTEId#0 )
--PhysicalResultSink
----hashJoin[INNER_JOIN] hashCondition=((.d = .d)) otherCondition=()
------hashJoin[INNER_JOIN] hashCondition=((.d = .d)) otherCondition=()
--------hashAgg[DISTINCT_LOCAL]
----------hashAgg[GLOBAL]
------------hashAgg[LOCAL]
--------------PhysicalCteConsumer ( cteId=CTEId#0 )
--------hashJoin[INNER_JOIN] hashCondition=((.d = .d)) otherCondition=()
----------hashAgg[DISTINCT_LOCAL]
------------hashAgg[GLOBAL]
Expand All @@ -424,11 +420,36 @@ PhysicalCteAnchor ( cteId=CTEId#0 )
------------hashAgg[GLOBAL]
--------------hashAgg[LOCAL]
----------------PhysicalCteConsumer ( cteId=CTEId#0 )
--------hashAgg[DISTINCT_LOCAL]
----------hashAgg[GLOBAL]
------------hashAgg[LOCAL]
--------------PhysicalCteConsumer ( cteId=CTEId#0 )
------hashAgg[DISTINCT_LOCAL]
--------hashAgg[GLOBAL]
----------hashAgg[LOCAL]
------------PhysicalCteConsumer ( cteId=CTEId#0 )

-- !has_other_func --
PhysicalCteAnchor ( cteId=CTEId#0 )
--PhysicalCteProducer ( cteId=CTEId#0 )
----PhysicalOlapScan[test_distinct_multi]
--PhysicalResultSink
----NestedLoopJoin[CROSS_JOIN]
------NestedLoopJoin[CROSS_JOIN]
--------hashAgg[DISTINCT_GLOBAL]
----------hashAgg[DISTINCT_LOCAL]
------------hashAgg[GLOBAL]
--------------hashAgg[LOCAL]
----------------PhysicalCteConsumer ( cteId=CTEId#0 )
--------hashAgg[DISTINCT_GLOBAL]
----------hashAgg[DISTINCT_LOCAL]
------------hashAgg[GLOBAL]
--------------hashAgg[LOCAL]
----------------PhysicalCteConsumer ( cteId=CTEId#0 )
------hashAgg[GLOBAL]
--------hashAgg[LOCAL]
----------PhysicalCteConsumer ( cteId=CTEId#0 )

-- !multi_count_with_gby --
PhysicalResultSink
--hashAgg[GLOBAL]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ suite("distinct_split") {
qt_count_multi_sum_avg_no_gby "select sum(distinct b), count(distinct a,d), avg(distinct c) from test_distinct_multi"
qt_count_sum_avg_with_gby "select sum(distinct b), count(distinct a), avg(distinct c) from test_distinct_multi group by b,a order by 1,2,3"
qt_count_multi_sum_avg_with_gby "select sum(distinct b), count(distinct a,d), avg(distinct c) from test_distinct_multi group by a,b order by 1,2,3"
//这里以下还需要验证结果

// There is a reference query in the upper layer
qt_multi_sum_has_upper """select c1+ c2 from (select sum(distinct b) c1, sum(distinct a) c2 from test_distinct_multi) t"""
qt_000_count_has_upper """select abs(c1) from (select count(distinct a) c1 from test_distinct_multi) t"""
Expand Down Expand Up @@ -189,6 +189,7 @@ suite("distinct_split") {
qt_multi_count_mulitcols_with_gby """explain shape plan select count(distinct b,c), count(distinct a,b) from test_distinct_multi group by d"""
qt_three_count_mulitcols_without_gby """explain shape plan select count(distinct b,c), count(distinct a,b), count(distinct a,b,c) from test_distinct_multi"""
qt_four_count_mulitcols_with_gby """explain shape plan select count(distinct b,c), count(distinct a,b),count(distinct b,c,d), count(distinct a,b,c) from test_distinct_multi group by d"""
qt_has_other_func "explain shape plan select count(distinct b), count(distinct a), max(b),sum(c),min(a) from test_distinct_multi"

// should not rewrite
qt_multi_count_with_gby """explain shape plan select count(distinct b), count(distinct a) from test_distinct_multi group by c"""
Expand Down

0 comments on commit 3951e5c

Please sign in to comment.