Skip to content

Commit

Permalink
feat: Support PERCENTILE_CONT planning
Browse files Browse the repository at this point in the history
  • Loading branch information
MazterQyou committed Aug 14, 2024
1 parent bba28d6 commit 9022ac3
Show file tree
Hide file tree
Showing 13 changed files with 298 additions and 12 deletions.
2 changes: 1 addition & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 3 additions & 3 deletions datafusion-cli/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion datafusion/common/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -44,4 +44,4 @@ cranelift-module = { version = "0.82.0", optional = true }
ordered-float = "2.10"
parquet = { git = 'https://github.com/cube-js/arrow-rs.git', rev = "8fd2aa80114d5c0d4e6a0c370729507a4424e7b3", features = ["arrow"], optional = true }
pyo3 = { version = "0.16", optional = true }
sqlparser = { git = 'https://github.com/cube-js/sqlparser-rs.git', rev = "6a54d27d3b75a04b9f9cbe309a83078aa54b32fd" }
sqlparser = { git = 'https://github.com/cube-js/sqlparser-rs.git', rev = "f1a97af2f22d9bfe057c77d5c9673038d8cf295b" }
2 changes: 1 addition & 1 deletion datafusion/core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ pin-project-lite= "^0.2.7"
pyo3 = { version = "0.16", optional = true }
rand = "0.8"
smallvec = { version = "1.6", features = ["union"] }
sqlparser = { git = 'https://github.com/cube-js/sqlparser-rs.git', rev = "6a54d27d3b75a04b9f9cbe309a83078aa54b32fd" }
sqlparser = { git = 'https://github.com/cube-js/sqlparser-rs.git', rev = "f1a97af2f22d9bfe057c77d5c9673038d8cf295b" }
tempfile = "3"
tokio = { version = "1.0", features = ["macros", "rt", "rt-multi-thread", "sync", "fs", "parking_lot"] }
tokio-stream = "0.1"
Expand Down
8 changes: 8 additions & 0 deletions datafusion/core/src/physical_plan/aggregates.rs
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,14 @@ pub fn create_aggregate_expr(
.to_string(),
));
}
(AggregateFunction::PercentileCont, _) => {
Arc::new(expressions::PercentileCont::new(
// Pass in the desired percentile expr
name,
coerced_phy_exprs,
return_type,
)?)
}
(AggregateFunction::ApproxMedian, false) => {
Arc::new(expressions::ApproxMedian::new(
coerced_phy_exprs[0].clone(),
Expand Down
51 changes: 46 additions & 5 deletions datafusion/core/src/sql/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ use datafusion_expr::expr::GroupingSet;
use sqlparser::ast::{
ArrayAgg, BinaryOperator, DataType as SQLDataType, DateTimeField, Expr as SQLExpr,
Fetch, FunctionArg, FunctionArgExpr, Ident, Join, JoinConstraint, JoinOperator,
ObjectName, Offset as SQLOffset, Query, Select, SelectItem, SetExpr, SetOperator,
ObjectName, Offset as SQLOffset, PercentileCont, Query, Select, SelectItem, SetExpr, SetOperator,
ShowStatementFilter, TableFactor, TableWithJoins, TrimWhereField, UnaryOperator,
Value, Values as SQLValues,
};
Expand Down Expand Up @@ -1440,22 +1440,22 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {

let order_by_rex = order_by
.into_iter()
.map(|e| self.order_by_to_sort_expr(e, plan.schema()))
.map(|e| self.order_by_to_sort_expr(e, plan.schema(), true))
.collect::<Result<Vec<_>>>()?;

LogicalPlanBuilder::from(plan).sort(order_by_rex)?.build()
}

/// convert sql OrderByExpr to Expr::Sort
fn order_by_to_sort_expr(&self, e: OrderByExpr, schema: &DFSchema) -> Result<Expr> {
fn order_by_to_sort_expr(&self, e: OrderByExpr, schema: &DFSchema, parse_indexes: bool) -> Result<Expr> {
let OrderByExpr {
asc,
expr,
nulls_first,
} = e;

let expr = match expr {
SQLExpr::Value(Value::Number(v, _)) => {
SQLExpr::Value(Value::Number(v, _)) if parse_indexes => {
let field_index = v
.parse::<usize>()
.map_err(|err| DataFusionError::Plan(err.to_string()))?;
Expand Down Expand Up @@ -2313,7 +2313,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
let order_by = window
.order_by
.into_iter()
.map(|e| self.order_by_to_sort_expr(e, schema))
.map(|e| self.order_by_to_sort_expr(e, schema, true))
.collect::<Result<Vec<_>>>()?;
let window_frame = window
.window_frame
Expand Down Expand Up @@ -2441,6 +2441,8 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {

SQLExpr::ArrayAgg(array_agg) => self.parse_array_agg(array_agg, schema),

SQLExpr::PercentileCont(percentile_cont) => self.parse_percentile_cont(percentile_cont, schema),

_ => Err(DataFusionError::NotImplemented(format!(
"Unsupported ast node {:?} in sqltorel",
sql
Expand Down Expand Up @@ -2494,6 +2496,36 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
})
}

fn parse_percentile_cont(
&self,
percentile_cont: PercentileCont,
input_schema: &DFSchema,
) -> Result<Expr> {
let PercentileCont {
expr,
within_group,
} = percentile_cont;

// Some dialects have special syntax for percentile_cont. DataFusion only supports it like a function.
let expr = self.sql_expr_to_logical_expr(*expr, input_schema)?;
let (order_by_expr, asc, nulls_first) = match self.order_by_to_sort_expr(*within_group, input_schema, false)? {
Expr::Sort { expr, asc, nulls_first } => (expr, asc, nulls_first),
_ => return Err(DataFusionError::Internal("PercentileCont expected Sort expression in ORDER BY".to_string())),
};
let asc_expr = Expr::Literal(ScalarValue::Boolean(Some(asc)));
let nulls_first_expr = Expr::Literal(ScalarValue::Boolean(Some(nulls_first)));

let args = vec![expr, *order_by_expr, asc_expr, nulls_first_expr];
// next, aggregate built-ins
let fun = aggregates::AggregateFunction::PercentileCont;

Ok(Expr::AggregateFunction {
fun,
distinct: false,
args,
})
}

fn function_args_to_expr(
&self,
args: Vec<FunctionArg>,
Expand Down Expand Up @@ -4133,6 +4165,15 @@ mod tests {
quick_test(sql, expected);
}

#[test]
fn select_percentile_cont() {
let sql = "SELECT percentile_cont(0.5) WITHIN GROUP (ORDER BY age) FROM person";
let expected = "Projection: #PERCENTILECONT(Float64(0.5),person.age,Boolean(true),Boolean(false))\
\n Aggregate: groupBy=[[]], aggr=[[PERCENTILECONT(Float64(0.5), #person.age, Boolean(true), Boolean(false))]]\
\n TableScan: person projection=None";
quick_test(sql, expected);
}

#[test]
fn select_scalar_func() {
let sql = "SELECT sqrt(age) FROM person";
Expand Down
2 changes: 1 addition & 1 deletion datafusion/expr/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,4 +38,4 @@ path = "src/lib.rs"
ahash = { version = "0.7", default-features = false }
arrow = { git = 'https://github.com/cube-js/arrow-rs.git', rev = "8fd2aa80114d5c0d4e6a0c370729507a4424e7b3", features = ["prettyprint"] }
datafusion-common = { path = "../common", version = "7.0.0" }
sqlparser = { git = 'https://github.com/cube-js/sqlparser-rs.git', rev = "6a54d27d3b75a04b9f9cbe309a83078aa54b32fd" }
sqlparser = { git = 'https://github.com/cube-js/sqlparser-rs.git', rev = "f1a97af2f22d9bfe057c77d5c9673038d8cf295b" }
39 changes: 39 additions & 0 deletions datafusion/expr/src/aggregate_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,8 @@ pub enum AggregateFunction {
ApproxPercentileCont,
/// Approximate continuous percentile function with weight
ApproxPercentileContWithWeight,
/// Continuous percentile function
PercentileCont,
/// ApproxMedian
ApproxMedian,
/// BoolAnd
Expand Down Expand Up @@ -124,6 +126,7 @@ impl FromStr for AggregateFunction {
"approx_percentile_cont_with_weight" => {
AggregateFunction::ApproxPercentileContWithWeight
}
"percentile_cont" => AggregateFunction::PercentileCont,
"approx_median" => AggregateFunction::ApproxMedian,
"bool_and" => AggregateFunction::BoolAnd,
"bool_or" => AggregateFunction::BoolOr,
Expand Down Expand Up @@ -178,6 +181,7 @@ pub fn return_type(
AggregateFunction::ApproxPercentileContWithWeight => {
Ok(coerced_data_types[0].clone())
}
AggregateFunction::PercentileCont => Ok(coerced_data_types[1].clone()),
AggregateFunction::ApproxMedian => Ok(coerced_data_types[0].clone()),
AggregateFunction::BoolAnd | AggregateFunction::BoolOr => Ok(DataType::Boolean),
}
Expand Down Expand Up @@ -324,6 +328,33 @@ pub fn coerce_types(
}
Ok(input_types.to_vec())
}
AggregateFunction::PercentileCont => {
if !matches!(input_types[0], DataType::Float64) {
return Err(DataFusionError::Plan(format!(
"The percentile argument for {:?} must be Float64, not {:?}.",
agg_fun, input_types[0]
)));
}
if !is_approx_percentile_cont_supported_arg_type(&input_types[1]) {
return Err(DataFusionError::Plan(format!(
"The function {:?} does not support inputs of type {:?}.",
agg_fun, input_types[1]
)));
}
if !matches!(input_types[2], DataType::Boolean) {
return Err(DataFusionError::Plan(format!(
"The asc argument for {:?} must be Boolean, not {:?}.",
agg_fun, input_types[2]
)));
}
if !matches!(input_types[3], DataType::Boolean) {
return Err(DataFusionError::Plan(format!(
"The nulls_first argument for {:?} must be Boolean, not {:?}.",
agg_fun, input_types[3]
)));
}
Ok(input_types.to_vec())
}
AggregateFunction::ApproxMedian => {
if !is_approx_percentile_cont_supported_arg_type(&input_types[0]) {
return Err(DataFusionError::Plan(format!(
Expand Down Expand Up @@ -395,6 +426,14 @@ pub fn signature(fun: &AggregateFunction) -> Signature {
.collect(),
Volatility::Immutable,
),
AggregateFunction::PercentileCont => Signature::one_of(
// Accept a float64 percentile paired with any numeric value, plus bool values
NUMERICS
.iter()
.map(|t| TypeSignature::Exact(vec![DataType::Float64, t.clone(), DataType::Boolean, DataType::Boolean]))
.collect(),
Volatility::Immutable,
),
AggregateFunction::BoolAnd | AggregateFunction::BoolOr => {
Signature::exact(vec![DataType::Boolean], Volatility::Immutable)
}
Expand Down
2 changes: 2 additions & 0 deletions datafusion/physical-expr/src/expressions/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ mod not;
mod nth_value;
mod nullif;
mod outer_column;
mod percentile_cont;
mod rank;
mod row_number;
mod stats;
Expand Down Expand Up @@ -95,6 +96,7 @@ pub use not::{not, NotExpr};
pub use nth_value::NthValue;
pub use nullif::nullif_func;
pub use outer_column::OuterColumn;
pub use percentile_cont::PercentileCont;
pub use rank::{dense_rank, percent_rank, rank};
pub use row_number::RowNumber;
pub use stats::StatsType;
Expand Down
Loading

0 comments on commit 9022ac3

Please sign in to comment.