Skip to content

Commit

Permalink
fix: issue apache#9213 substitute ArrayAgg to NthValue to optimize qu…
Browse files Browse the repository at this point in the history
…ery plan
  • Loading branch information
Lordworms committed Feb 21, 2024
1 parent b2a0451 commit 029e970
Show file tree
Hide file tree
Showing 2 changed files with 124 additions and 4 deletions.
36 changes: 32 additions & 4 deletions datafusion/sql/src/expr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -203,9 +203,34 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
}

SQLExpr::ArrayIndex { obj, indexes } => {
fn simplify_array_index_expr(expr: Expr, index: Expr) -> (Expr, bool) {
match &expr {
Expr::AggregateFunction(agg_func) if agg_func.func_def == datafusion_expr::expr::AggregateFunctionDefinition::BuiltIn(AggregateFunction::ArrayAgg) => {
let mut new_args = agg_func.args.clone();
new_args.push(index.clone());
(Expr::AggregateFunction(datafusion_expr::expr::AggregateFunction::new(
datafusion_expr::AggregateFunction::NthValue,
new_args,
agg_func.distinct,
agg_func.filter.clone(),
agg_func.order_by.clone(),
)), true)
},
_ => (expr, false),
}
}
let expr =
self.sql_expr_to_logical_expr(*obj, schema, planner_context)?;
self.plan_indexed(expr, indexes, schema, planner_context)
if indexes.len() > 1 {
return self.plan_indexed(expr, indexes, schema, planner_context);
}
let (new_expr, changed) =
simplify_array_index_expr(expr, self.sql_expr_to_logical_expr(indexes[0].clone(), schema, planner_context)?);
if changed {
Ok(new_expr)
} else {
self.plan_indexed(new_expr, indexes, schema, planner_context)
}
}

SQLExpr::CompoundIdentifier(ids) => {
Expand Down Expand Up @@ -557,7 +582,6 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
limit,
within_group,
} = array_agg;

let order_by = if let Some(order_by) = order_by {
Some(self.order_by_to_sort_expr(
&order_by,
Expand All @@ -581,10 +605,14 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
vec![self.sql_expr_to_logical_expr(*expr, input_schema, planner_context)?];

// next, aggregate built-ins
let fun = AggregateFunction::ArrayAgg;
Ok(Expr::AggregateFunction(expr::AggregateFunction::new(
fun, args, distinct, None, order_by,
AggregateFunction::ArrayAgg,
args,
distinct,
None,
order_by,
)))
// see if we can rewrite it into NTH-VALUE
}

fn sql_in_list_to_expr(
Expand Down
92 changes: 92 additions & 0 deletions datafusion/sqllogictest/test_files/agg_func_substitute.slt
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

#######
# Setup test data table
#######
statement ok
CREATE EXTERNAL TABLE multiple_ordered_table (
a0 INTEGER,
a INTEGER,
b INTEGER,
c INTEGER,
d INTEGER
)
STORED AS CSV
WITH HEADER ROW
WITH ORDER (a ASC, b ASC)
WITH ORDER (c ASC)
LOCATION '../../datafusion/core/tests/data/window_2.csv';


query TT
EXPLAIN SELECT a, ARRAY_AGG(c ORDER BY c)[1] as result
FROM multiple_ordered_table
GROUP BY a;
----
logical_plan
Projection: multiple_ordered_table.a, NTH_VALUE(multiple_ordered_table.c,Int64(1)) ORDER BY [multiple_ordered_table.c ASC NULLS LAST] AS result
--Aggregate: groupBy=[[multiple_ordered_table.a]], aggr=[[NTH_VALUE(multiple_ordered_table.c, Int64(1)) ORDER BY [multiple_ordered_table.c ASC NULLS LAST]]]
----TableScan: multiple_ordered_table projection=[a, c]
physical_plan
ProjectionExec: expr=[a@0 as a, NTH_VALUE(multiple_ordered_table.c,Int64(1)) ORDER BY [multiple_ordered_table.c ASC NULLS LAST]@1 as result]
--AggregateExec: mode=FinalPartitioned, gby=[a@0 as a], aggr=[NTH_VALUE(multiple_ordered_table.c,Int64(1))], ordering_mode=Sorted
----SortExec: expr=[a@0 ASC NULLS LAST]
------CoalesceBatchesExec: target_batch_size=8192
--------RepartitionExec: partitioning=Hash([a@0], 4), input_partitions=4
----------AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[NTH_VALUE(multiple_ordered_table.c,Int64(1))], ordering_mode=Sorted
------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
--------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, c], output_orderings=[[a@0 ASC NULLS LAST], [c@1 ASC NULLS LAST]], has_header=true


query TT
EXPLAIN SELECT a, NTH_VALUE(c, 1 ORDER BY c) as result
FROM multiple_ordered_table
GROUP BY a;
----
logical_plan
Projection: multiple_ordered_table.a, NTH_VALUE(multiple_ordered_table.c,Int64(1)) ORDER BY [multiple_ordered_table.c ASC NULLS LAST] AS result
--Aggregate: groupBy=[[multiple_ordered_table.a]], aggr=[[NTH_VALUE(multiple_ordered_table.c, Int64(1)) ORDER BY [multiple_ordered_table.c ASC NULLS LAST]]]
----TableScan: multiple_ordered_table projection=[a, c]
physical_plan
ProjectionExec: expr=[a@0 as a, NTH_VALUE(multiple_ordered_table.c,Int64(1)) ORDER BY [multiple_ordered_table.c ASC NULLS LAST]@1 as result]
--AggregateExec: mode=FinalPartitioned, gby=[a@0 as a], aggr=[NTH_VALUE(multiple_ordered_table.c,Int64(1))], ordering_mode=Sorted
----SortExec: expr=[a@0 ASC NULLS LAST]
------CoalesceBatchesExec: target_batch_size=8192
--------RepartitionExec: partitioning=Hash([a@0], 4), input_partitions=4
----------AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[NTH_VALUE(multiple_ordered_table.c,Int64(1))], ordering_mode=Sorted
------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
--------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, c], output_orderings=[[a@0 ASC NULLS LAST], [c@1 ASC NULLS LAST]], has_header=true

query TT
EXPLAIN SELECT a, ARRAY_AGG(c ORDER BY c)[1 + 100] as result
FROM multiple_ordered_table
GROUP BY a;
----
logical_plan
Projection: multiple_ordered_table.a, NTH_VALUE(multiple_ordered_table.c,Int64(1) + Int64(100)) ORDER BY [multiple_ordered_table.c ASC NULLS LAST] AS result
--Aggregate: groupBy=[[multiple_ordered_table.a]], aggr=[[NTH_VALUE(multiple_ordered_table.c, Int64(101)) ORDER BY [multiple_ordered_table.c ASC NULLS LAST] AS NTH_VALUE(multiple_ordered_table.c,Int64(1) + Int64(100)) ORDER BY [multiple_ordered_table.c ASC NULLS LAST]]]
----TableScan: multiple_ordered_table projection=[a, c]
physical_plan
ProjectionExec: expr=[a@0 as a, NTH_VALUE(multiple_ordered_table.c,Int64(1) + Int64(100)) ORDER BY [multiple_ordered_table.c ASC NULLS LAST]@1 as result]
--AggregateExec: mode=FinalPartitioned, gby=[a@0 as a], aggr=[NTH_VALUE(multiple_ordered_table.c,Int64(1) + Int64(100)) ORDER BY [multiple_ordered_table.c ASC NULLS LAST]], ordering_mode=Sorted
----SortExec: expr=[a@0 ASC NULLS LAST]
------CoalesceBatchesExec: target_batch_size=8192
--------RepartitionExec: partitioning=Hash([a@0], 4), input_partitions=4
----------AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[NTH_VALUE(multiple_ordered_table.c,Int64(1) + Int64(100)) ORDER BY [multiple_ordered_table.c ASC NULLS LAST]], ordering_mode=Sorted
------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
--------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/core/tests/data/window_2.csv]]}, projection=[a, c], output_orderings=[[a@0 ASC NULLS LAST], [c@1 ASC NULLS LAST]], has_header=true

0 comments on commit 029e970

Please sign in to comment.