Skip to content

Commit

Permalink
fixup
Browse files Browse the repository at this point in the history
  • Loading branch information
MazterQyou committed Sep 11, 2024
1 parent 98359f9 commit 11a4ed1
Show file tree
Hide file tree
Showing 22 changed files with 360 additions and 193 deletions.
2 changes: 1 addition & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion datafusion-cli/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion datafusion/common/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -44,4 +44,4 @@ cranelift-module = { version = "0.82.0", optional = true }
ordered-float = "2.10"
parquet = { git = 'https://github.com/cube-js/arrow-rs.git', rev = "a03d4eef5640e05dddf99fc2357ad6d58b5337cb", features = ["arrow"], optional = true }
pyo3 = { version = "0.16", optional = true }
sqlparser = { git = 'https://github.com/cube-js/sqlparser-rs.git', rev = "3a3a7e582f51576c4d2ac2350512564633fe02dd" }
sqlparser = { git = 'https://github.com/cube-js/sqlparser-rs.git', rev = "5fe1b77d1a91b80529a0b7af0b89411d3cba5137" }
2 changes: 1 addition & 1 deletion datafusion/core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ pin-project-lite= "^0.2.7"
pyo3 = { version = "0.16", optional = true }
rand = "0.8"
smallvec = { version = "1.6", features = ["union"] }
sqlparser = { git = 'https://github.com/cube-js/sqlparser-rs.git', rev = "3a3a7e582f51576c4d2ac2350512564633fe02dd" }
sqlparser = { git = 'https://github.com/cube-js/sqlparser-rs.git', rev = "5fe1b77d1a91b80529a0b7af0b89411d3cba5137" }
tempfile = "3"
tokio = { version = "1.0", features = ["macros", "rt", "rt-multi-thread", "sync", "fs", "parking_lot"] }
tokio-stream = "0.1"
Expand Down
18 changes: 13 additions & 5 deletions datafusion/core/src/logical_plan/expr_rewriter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -252,11 +252,19 @@ impl ExprRewritable for Expr {
args,
fun,
distinct,
} => Expr::AggregateFunction {
args: rewrite_vec(args, rewriter)?,
fun,
distinct,
},
within_group,
} => {
let within_group = match within_group {
Some(within_group) => Some(rewrite_vec(within_group, rewriter)?),
None => None,
};
Expr::AggregateFunction {
args: rewrite_vec(args, rewriter)?,
fun,
distinct,
within_group,
}
}
Expr::GroupingSet(grouping_set) => match grouping_set {
GroupingSet::Rollup(exprs) => {
Expr::GroupingSet(GroupingSet::Rollup(rewrite_vec(exprs, rewriter)?))
Expand Down
15 changes: 13 additions & 2 deletions datafusion/core/src/logical_plan/expr_schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,12 +92,23 @@ impl ExprSchemable for Expr {
.collect::<Result<Vec<_>>>()?;
window_function::return_type(fun, &data_types)
}
Expr::AggregateFunction { fun, args, .. } => {
Expr::AggregateFunction {
fun,
args,
within_group,
..
} => {
let data_types = args
.iter()
.map(|e| e.get_type(schema))
.collect::<Result<Vec<_>>>()?;
aggregate_function::return_type(fun, &data_types)
let within_group = within_group
.as_ref()
.unwrap_or(&vec![])
.iter()
.map(|e| e.get_type(schema))
.collect::<Result<Vec<_>>>()?;
aggregate_function::return_type(fun, &data_types, &within_group)
}
Expr::AggregateUDF { fun, args, .. } => {
let data_types = args
Expand Down
16 changes: 15 additions & 1 deletion datafusion/core/src/logical_plan/expr_visitor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -179,10 +179,24 @@ impl ExprVisitable for Expr {
Expr::ScalarFunction { args, .. }
| Expr::ScalarUDF { args, .. }
| Expr::TableUDF { args, .. }
| Expr::AggregateFunction { args, .. }
| Expr::AggregateUDF { args, .. } => args
.iter()
.try_fold(visitor, |visitor, arg| arg.accept(visitor)),
Expr::AggregateFunction {
args, within_group, ..
} => {
let visitor = args
.iter()
.try_fold(visitor, |visitor, arg| arg.accept(visitor))?;
let visitor = if let Some(within_group) = within_group.as_ref() {
within_group
.iter()
.try_fold(visitor, |visitor, arg| arg.accept(visitor))?
} else {
visitor
};
Ok(visitor)
}
Expr::WindowFunction {
args,
partition_by,
Expand Down
14 changes: 12 additions & 2 deletions datafusion/core/src/optimizer/single_distinct_to_groupby.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ fn optimize(plan: &LogicalPlan) -> Result<LogicalPlan> {
fun: fun.clone(),
args: vec![col(SINGLE_DISTINCT_ALIAS)],
distinct: false,
within_group: None,
}
}
_ => agg_expr.clone(),
Expand Down Expand Up @@ -168,13 +169,21 @@ fn is_single_distinct_agg(plan: &LogicalPlan) -> bool {
.iter()
.filter(|expr| {
let mut is_distinct = false;
if let Expr::AggregateFunction { distinct, args, .. } = expr {
let mut is_within_group = false;
if let Expr::AggregateFunction {
distinct,
args,
within_group,
..
} = expr
{
is_distinct = *distinct;
is_within_group = within_group.is_some();
args.iter().for_each(|expr| {
fields_set.insert(expr.name(input.schema()).unwrap());
})
}
is_distinct
is_distinct && !is_within_group
})
.count()
== aggr_expr.len()
Expand Down Expand Up @@ -314,6 +323,7 @@ mod tests {
fun: aggregates::AggregateFunction::Max,
distinct: true,
args: vec![col("b")],
within_group: None,
},
],
)?
Expand Down
32 changes: 26 additions & 6 deletions datafusion/core/src/optimizer/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -339,8 +339,14 @@ pub fn expr_sub_expressions(expr: &Expr) -> Result<Vec<Expr>> {
Expr::ScalarFunction { args, .. }
| Expr::ScalarUDF { args, .. }
| Expr::TableUDF { args, .. }
| Expr::AggregateFunction { args, .. }
| Expr::AggregateUDF { args, .. } => Ok(args.clone()),
Expr::AggregateFunction {
args, within_group, ..
} => Ok(args
.iter()
.chain(within_group.as_ref().unwrap_or(&vec![]))
.cloned()
.collect()),
Expr::GroupingSet(grouping_set) => match grouping_set {
GroupingSet::Rollup(exprs) => Ok(exprs.clone()),
GroupingSet::Cube(exprs) => Ok(exprs.clone()),
Expand Down Expand Up @@ -517,11 +523,25 @@ pub fn rewrite_expression(expr: &Expr, expressions: &[Expr]) -> Result<Expr> {
})
}
}
Expr::AggregateFunction { fun, distinct, .. } => Ok(Expr::AggregateFunction {
fun: fun.clone(),
args: expressions.to_vec(),
distinct: *distinct,
}),
Expr::AggregateFunction {
fun,
distinct,
args,
..
} => {
let args_limit = args.len();
let within_group = if expressions.len() > args_limit {
Some(expressions[args_limit..].to_vec())
} else {
None
};
Ok(Expr::AggregateFunction {
fun: fun.clone(),
args: expressions[..args_limit].to_vec(),
distinct: *distinct,
within_group,
})
}
Expr::AggregateUDF { fun, .. } => Ok(Expr::AggregateUDF {
fun: fun.clone(),
args: expressions.to_vec(),
Expand Down
Loading

0 comments on commit 11a4ed1

Please sign in to comment.