From 9022ac324f4d68816f9cfd3478f953fd1436488a Mon Sep 17 00:00:00 2001 From: Alex Qyoun-ae <4062971+MazterQyou@users.noreply.github.com> Date: Wed, 14 Aug 2024 19:49:11 +0400 Subject: [PATCH] feat: Support `PERCENTILE_CONT` planning --- Cargo.lock | 2 +- datafusion-cli/Cargo.lock | 6 +- datafusion/common/Cargo.toml | 2 +- datafusion/core/Cargo.toml | 2 +- .../core/src/physical_plan/aggregates.rs | 8 + datafusion/core/src/sql/planner.rs | 51 ++++- datafusion/expr/Cargo.toml | 2 +- datafusion/expr/src/aggregate_function.rs | 39 ++++ .../physical-expr/src/expressions/mod.rs | 2 + .../src/expressions/percentile_cont.rs | 192 ++++++++++++++++++ datafusion/proto/proto/datafusion.proto | 1 + datafusion/proto/src/from_proto.rs | 1 + datafusion/proto/src/to_proto.rs | 2 + 13 files changed, 298 insertions(+), 12 deletions(-) create mode 100644 datafusion/physical-expr/src/expressions/percentile_cont.rs diff --git a/Cargo.lock b/Cargo.lock index 261fe9001ed40..d0155ebdb749d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2275,7 +2275,7 @@ dependencies = [ [[package]] name = "sqlparser" version = "0.16.0" -source = "git+https://github.com/cube-js/sqlparser-rs.git?rev=6a54d27d3b75a04b9f9cbe309a83078aa54b32fd#6a54d27d3b75a04b9f9cbe309a83078aa54b32fd" +source = "git+https://github.com/cube-js/sqlparser-rs.git?rev=f1a97af2f22d9bfe057c77d5c9673038d8cf295b#f1a97af2f22d9bfe057c77d5c9673038d8cf295b" dependencies = [ "log", ] diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock index 46cd39ff1620e..352aa06a26f40 100644 --- a/datafusion-cli/Cargo.lock +++ b/datafusion-cli/Cargo.lock @@ -67,7 +67,7 @@ checksum = "8da52d66c7071e2e3fa2a1e5c6d088fec47b593032b254f5e980de8ea54454d6" [[package]] name = "arrow" version = "11.1.0" -source = "git+https://github.com/cube-js/arrow-rs.git?rev=d9c12d71b655d356c5a287226a763638417972e9#d9c12d71b655d356c5a287226a763638417972e9" +source = 
"git+https://github.com/cube-js/arrow-rs.git?rev=8fd2aa80114d5c0d4e6a0c370729507a4424e7b3#8fd2aa80114d5c0d4e6a0c370729507a4424e7b3" dependencies = [ "bitflags", "chrono", @@ -1229,7 +1229,7 @@ dependencies = [ [[package]] name = "parquet" version = "11.1.0" -source = "git+https://github.com/cube-js/arrow-rs.git?rev=d9c12d71b655d356c5a287226a763638417972e9#d9c12d71b655d356c5a287226a763638417972e9" +source = "git+https://github.com/cube-js/arrow-rs.git?rev=8fd2aa80114d5c0d4e6a0c370729507a4424e7b3#8fd2aa80114d5c0d4e6a0c370729507a4424e7b3" dependencies = [ "arrow", "base64", @@ -1531,7 +1531,7 @@ checksum = "45456094d1983e2ee2a18fdfebce3189fa451699d0502cb8e3b49dba5ba41451" [[package]] name = "sqlparser" version = "0.16.0" -source = "git+https://github.com/cube-js/sqlparser-rs.git?rev=6a54d27d3b75a04b9f9cbe309a83078aa54b32fd#6a54d27d3b75a04b9f9cbe309a83078aa54b32fd" +source = "git+https://github.com/cube-js/sqlparser-rs.git?rev=f1a97af2f22d9bfe057c77d5c9673038d8cf295b#f1a97af2f22d9bfe057c77d5c9673038d8cf295b" dependencies = [ "log", ] diff --git a/datafusion/common/Cargo.toml b/datafusion/common/Cargo.toml index f317cbb3f485c..1774290cb6085 100644 --- a/datafusion/common/Cargo.toml +++ b/datafusion/common/Cargo.toml @@ -44,4 +44,4 @@ cranelift-module = { version = "0.82.0", optional = true } ordered-float = "2.10" parquet = { git = 'https://github.com/cube-js/arrow-rs.git', rev = "8fd2aa80114d5c0d4e6a0c370729507a4424e7b3", features = ["arrow"], optional = true } pyo3 = { version = "0.16", optional = true } -sqlparser = { git = 'https://github.com/cube-js/sqlparser-rs.git', rev = "6a54d27d3b75a04b9f9cbe309a83078aa54b32fd" } +sqlparser = { git = 'https://github.com/cube-js/sqlparser-rs.git', rev = "f1a97af2f22d9bfe057c77d5c9673038d8cf295b" } diff --git a/datafusion/core/Cargo.toml b/datafusion/core/Cargo.toml index 5cb3cc1337a31..366f7a91b31e9 100644 --- a/datafusion/core/Cargo.toml +++ b/datafusion/core/Cargo.toml @@ -79,7 +79,7 @@ pin-project-lite= "^0.2.7" pyo3 = { 
version = "0.16", optional = true } rand = "0.8" smallvec = { version = "1.6", features = ["union"] } -sqlparser = { git = 'https://github.com/cube-js/sqlparser-rs.git', rev = "6a54d27d3b75a04b9f9cbe309a83078aa54b32fd" } +sqlparser = { git = 'https://github.com/cube-js/sqlparser-rs.git', rev = "f1a97af2f22d9bfe057c77d5c9673038d8cf295b" } tempfile = "3" tokio = { version = "1.0", features = ["macros", "rt", "rt-multi-thread", "sync", "fs", "parking_lot"] } tokio-stream = "0.1" diff --git a/datafusion/core/src/physical_plan/aggregates.rs b/datafusion/core/src/physical_plan/aggregates.rs index 07d85d31fa34c..5ab3d46b51e7a 100644 --- a/datafusion/core/src/physical_plan/aggregates.rs +++ b/datafusion/core/src/physical_plan/aggregates.rs @@ -239,6 +239,14 @@ pub fn create_aggregate_expr( .to_string(), )); } + (AggregateFunction::PercentileCont, _) => { + Arc::new(expressions::PercentileCont::new( + // Pass in the desired percentile expr + name, + coerced_phy_exprs, + return_type, + )?) + } (AggregateFunction::ApproxMedian, false) => { Arc::new(expressions::ApproxMedian::new( coerced_phy_exprs[0].clone(), diff --git a/datafusion/core/src/sql/planner.rs b/datafusion/core/src/sql/planner.rs index 9abc77d3c3e9d..559f107d3ac60 100644 --- a/datafusion/core/src/sql/planner.rs +++ b/datafusion/core/src/sql/planner.rs @@ -56,7 +56,7 @@ use datafusion_expr::expr::GroupingSet; use sqlparser::ast::{ ArrayAgg, BinaryOperator, DataType as SQLDataType, DateTimeField, Expr as SQLExpr, Fetch, FunctionArg, FunctionArgExpr, Ident, Join, JoinConstraint, JoinOperator, - ObjectName, Offset as SQLOffset, Query, Select, SelectItem, SetExpr, SetOperator, + ObjectName, Offset as SQLOffset, PercentileCont, Query, Select, SelectItem, SetExpr, SetOperator, ShowStatementFilter, TableFactor, TableWithJoins, TrimWhereField, UnaryOperator, Value, Values as SQLValues, }; @@ -1440,14 +1440,14 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let order_by_rex = order_by .into_iter() - .map(|e| 
self.order_by_to_sort_expr(e, plan.schema())) + .map(|e| self.order_by_to_sort_expr(e, plan.schema(), true)) .collect::<Result<Vec<_>>>()?; LogicalPlanBuilder::from(plan).sort(order_by_rex)?.build() } /// convert sql OrderByExpr to Expr::Sort - fn order_by_to_sort_expr(&self, e: OrderByExpr, schema: &DFSchema) -> Result<Expr> { + fn order_by_to_sort_expr(&self, e: OrderByExpr, schema: &DFSchema, parse_indexes: bool) -> Result<Expr> { let OrderByExpr { asc, expr, @@ -1455,7 +1455,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } = e; let expr = match expr { - SQLExpr::Value(Value::Number(v, _)) => { + SQLExpr::Value(Value::Number(v, _)) if parse_indexes => { let field_index = v .parse::<usize>() .map_err(|err| DataFusionError::Plan(err.to_string()))?; @@ -2313,7 +2313,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let order_by = window .order_by .into_iter() - .map(|e| self.order_by_to_sort_expr(e, schema)) + .map(|e| self.order_by_to_sort_expr(e, schema, true)) .collect::<Result<Vec<_>>>()?; let window_frame = window .window_frame @@ -2441,6 +2441,8 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { SQLExpr::ArrayAgg(array_agg) => self.parse_array_agg(array_agg, schema), + SQLExpr::PercentileCont(percentile_cont) => self.parse_percentile_cont(percentile_cont, schema), + _ => Err(DataFusionError::NotImplemented(format!( "Unsupported ast node {:?} in sqltorel", sql @@ -2494,6 +2496,36 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { }) } + fn parse_percentile_cont( + &self, + percentile_cont: PercentileCont, + input_schema: &DFSchema, + ) -> Result<Expr> { + let PercentileCont { + expr, + within_group, + } = percentile_cont; + + // Some dialects have special syntax for percentile_cont. DataFusion only supports it like a function. + let expr = self.sql_expr_to_logical_expr(*expr, input_schema)?; + let (order_by_expr, asc, nulls_first) = match self.order_by_to_sort_expr(*within_group, input_schema, false)? 
{ + Expr::Sort { expr, asc, nulls_first } => (expr, asc, nulls_first), + _ => return Err(DataFusionError::Internal("PercentileCont expected Sort expression in ORDER BY".to_string())), + }; + let asc_expr = Expr::Literal(ScalarValue::Boolean(Some(asc))); + let nulls_first_expr = Expr::Literal(ScalarValue::Boolean(Some(nulls_first))); + + let args = vec![expr, *order_by_expr, asc_expr, nulls_first_expr]; + // next, aggregate built-ins + let fun = aggregates::AggregateFunction::PercentileCont; + + Ok(Expr::AggregateFunction { + fun, + distinct: false, + args, + }) + } + fn function_args_to_expr( &self, args: Vec<FunctionArg>, @@ -4133,6 +4165,15 @@ mod tests { quick_test(sql, expected); } + #[test] + fn select_percentile_cont() { + let sql = "SELECT percentile_cont(0.5) WITHIN GROUP (ORDER BY age) FROM person"; + let expected = "Projection: #PERCENTILECONT(Float64(0.5),person.age,Boolean(true),Boolean(false))\ + \n Aggregate: groupBy=[[]], aggr=[[PERCENTILECONT(Float64(0.5), #person.age, Boolean(true), Boolean(false))]]\ + \n TableScan: person projection=None"; + quick_test(sql, expected); + } + #[test] fn select_scalar_func() { let sql = "SELECT sqrt(age) FROM person"; diff --git a/datafusion/expr/Cargo.toml b/datafusion/expr/Cargo.toml index f85e0f247eb8b..c80f8e83ed39a 100644 --- a/datafusion/expr/Cargo.toml +++ b/datafusion/expr/Cargo.toml @@ -38,4 +38,4 @@ path = "src/lib.rs" ahash = { version = "0.7", default-features = false } arrow = { git = 'https://github.com/cube-js/arrow-rs.git', rev = "8fd2aa80114d5c0d4e6a0c370729507a4424e7b3", features = ["prettyprint"] } datafusion-common = { path = "../common", version = "7.0.0" } -sqlparser = { git = 'https://github.com/cube-js/sqlparser-rs.git', rev = "6a54d27d3b75a04b9f9cbe309a83078aa54b32fd" } +sqlparser = { git = 'https://github.com/cube-js/sqlparser-rs.git', rev = "f1a97af2f22d9bfe057c77d5c9673038d8cf295b" } diff --git a/datafusion/expr/src/aggregate_function.rs b/datafusion/expr/src/aggregate_function.rs index 
f81efe5c35e3e..8d502e8f28b08 100644 --- a/datafusion/expr/src/aggregate_function.rs +++ b/datafusion/expr/src/aggregate_function.rs @@ -84,6 +84,8 @@ pub enum AggregateFunction { ApproxPercentileCont, /// Approximate continuous percentile function with weight ApproxPercentileContWithWeight, + /// Continuous percentile function + PercentileCont, /// ApproxMedian ApproxMedian, /// BoolAnd @@ -124,6 +126,7 @@ impl FromStr for AggregateFunction { "approx_percentile_cont_with_weight" => { AggregateFunction::ApproxPercentileContWithWeight } + "percentile_cont" => AggregateFunction::PercentileCont, "approx_median" => AggregateFunction::ApproxMedian, "bool_and" => AggregateFunction::BoolAnd, "bool_or" => AggregateFunction::BoolOr, @@ -178,6 +181,7 @@ pub fn return_type( AggregateFunction::ApproxPercentileContWithWeight => { Ok(coerced_data_types[0].clone()) } + AggregateFunction::PercentileCont => Ok(coerced_data_types[1].clone()), AggregateFunction::ApproxMedian => Ok(coerced_data_types[0].clone()), AggregateFunction::BoolAnd | AggregateFunction::BoolOr => Ok(DataType::Boolean), } @@ -324,6 +328,33 @@ pub fn coerce_types( } Ok(input_types.to_vec()) } + AggregateFunction::PercentileCont => { + if !matches!(input_types[0], DataType::Float64) { + return Err(DataFusionError::Plan(format!( + "The percentile argument for {:?} must be Float64, not {:?}.", + agg_fun, input_types[0] + ))); + } + if !is_approx_percentile_cont_supported_arg_type(&input_types[1]) { + return Err(DataFusionError::Plan(format!( + "The function {:?} does not support inputs of type {:?}.", + agg_fun, input_types[1] + ))); + } + if !matches!(input_types[2], DataType::Boolean) { + return Err(DataFusionError::Plan(format!( + "The asc argument for {:?} must be Boolean, not {:?}.", + agg_fun, input_types[2] + ))); + } + if !matches!(input_types[3], DataType::Boolean) { + return Err(DataFusionError::Plan(format!( + "The nulls_first argument for {:?} must be Boolean, not {:?}.", + agg_fun, input_types[3] + ))); 
+ } + Ok(input_types.to_vec()) + } AggregateFunction::ApproxMedian => { if !is_approx_percentile_cont_supported_arg_type(&input_types[0]) { return Err(DataFusionError::Plan(format!( @@ -395,6 +426,14 @@ pub fn signature(fun: &AggregateFunction) -> Signature { .collect(), Volatility::Immutable, ), + AggregateFunction::PercentileCont => Signature::one_of( + // Accept a float64 percentile paired with any numeric value, plus bool values + NUMERICS + .iter() + .map(|t| TypeSignature::Exact(vec![DataType::Float64, t.clone(), DataType::Boolean, DataType::Boolean])) + .collect(), + Volatility::Immutable, + ), AggregateFunction::BoolAnd | AggregateFunction::BoolOr => { Signature::exact(vec![DataType::Boolean], Volatility::Immutable) } diff --git a/datafusion/physical-expr/src/expressions/mod.rs b/datafusion/physical-expr/src/expressions/mod.rs index d55814deae867..c92eb5d88fe9f 100644 --- a/datafusion/physical-expr/src/expressions/mod.rs +++ b/datafusion/physical-expr/src/expressions/mod.rs @@ -49,6 +49,7 @@ mod not; mod nth_value; mod nullif; mod outer_column; +mod percentile_cont; mod rank; mod row_number; mod stats; @@ -95,6 +96,7 @@ pub use not::{not, NotExpr}; pub use nth_value::NthValue; pub use nullif::nullif_func; pub use outer_column::OuterColumn; +pub use percentile_cont::PercentileCont; pub use rank::{dense_rank, percent_rank, rank}; pub use row_number::RowNumber; pub use stats::StatsType; diff --git a/datafusion/physical-expr/src/expressions/percentile_cont.rs b/datafusion/physical-expr/src/expressions/percentile_cont.rs new file mode 100644 index 0000000000000..58d640ffb72bc --- /dev/null +++ b/datafusion/physical-expr/src/expressions/percentile_cont.rs @@ -0,0 +1,192 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file // to you under the Apache License, Version 2.0 (the // "License"); you may not use this file except in compliance // with the License. You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, // software distributed under the License is distributed on an // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY // KIND, either express or implied. See the License for the // specific language governing permissions and limitations // under the License. + +use super::{format_state_name, Literal}; +use crate::tdigest::TryIntoOrderedF64; +use crate::{ + tdigest::{TDigest, DEFAULT_MAX_SIZE}, + AggregateExpr, PhysicalExpr, +}; +use arrow::{ + array::{ + ArrayRef, Float32Array, Float64Array, Int16Array, Int32Array, Int64Array, + Int8Array, UInt16Array, UInt32Array, UInt64Array, UInt8Array, + }, + datatypes::{DataType, Field}, +}; +use datafusion_common::DataFusionError; +use datafusion_common::Result; +use datafusion_common::ScalarValue; +use datafusion_expr::Accumulator; +use ordered_float::OrderedFloat; +use std::{any::Any, iter, sync::Arc}; + +/// PERCENTILE_CONT aggregate expression +#[derive(Debug)] +pub struct PercentileCont { + name: String, + input_data_type: DataType, + percentile: f64, + expr: Arc<dyn PhysicalExpr>, + asc: bool, + nulls_first: bool, +} + +impl PercentileCont { + /// Create a new [`PercentileCont`] aggregate function. + pub fn new( + name: impl Into<String>, + expr: Vec<Arc<dyn PhysicalExpr>>, + input_data_type: DataType, + ) -> Result<Self> { + // Arguments should be [DesiredPercentileLiteral, ColumnExpr, AscLiteral, NullsFirstLiteral] + debug_assert_eq!(expr.len(), 4); + + // Extract the desired percentile literal + let lit = expr[0] + .as_any() + .downcast_ref::<Literal>() + .ok_or_else(|| { + DataFusionError::Internal( + "desired percentile argument must be float literal".to_string(), + ) + })? 
+ .value(); + let percentile = match lit { + ScalarValue::Float32(Some(q)) => *q as f64, + ScalarValue::Float64(Some(q)) => *q as f64, + got => return Err(DataFusionError::NotImplemented(format!( + "Percentile value for 'PERCENTILE_CONT' must be Float32 or Float64 literal (got data type {})", + got + ))) + }; + + // Ensure the percentile is between 0 and 1. + if !(0.0..=1.0).contains(&percentile) { + return Err(DataFusionError::Plan(format!( + "Percentile value must be between 0.0 and 1.0 inclusive, {} is invalid", + percentile + ))); + } + + // Extract the desired asc literal + let lit = expr[2] + .as_any() + .downcast_ref::<Literal>() + .ok_or_else(|| { + DataFusionError::Internal( + "desired asc argument must be boolean literal".to_string(), + ) + })? + .value(); + let asc = match lit { + ScalarValue::Boolean(Some(q)) => *q, + got => return Err(DataFusionError::NotImplemented(format!( + "ASC value for 'PERCENTILE_CONT' must be Boolean literal (got data type {})", + got + ))) + }; + + // Extract the desired nulls_first literal + let lit = expr[3] + .as_any() + .downcast_ref::<Literal>() + .ok_or_else(|| { + DataFusionError::Internal( + "desired nulls_first argument must be boolean literal".to_string(), + ) + })? + .value(); + let nulls_first = match lit { + ScalarValue::Boolean(Some(q)) => *q, + got => return Err(DataFusionError::NotImplemented(format!( + "NULLS_FIRST value for 'PERCENTILE_CONT' must be Boolean literal (got data type {})", + got + ))) + }; + + Ok(Self { + name: name.into(), + input_data_type, + percentile, + // The physical expr to evaluate during accumulation + expr: expr[1].clone(), + asc, + nulls_first, + }) + } +} + +impl AggregateExpr for PercentileCont { + fn as_any(&self) -> &dyn Any { + self + } + + fn field(&self) -> Result<Field> { + Ok(Field::new(&self.name, self.input_data_type.clone(), false)) + } + + #[allow(rustdoc::private_intra_doc_links)] + /// See [`TDigest::to_scalar_state()`] for a description of the serialised + /// state. 
+ fn state_fields(&self) -> Result<Vec<Field>> { + Ok(vec![ + Field::new( + &format_state_name(&self.name, "max_size"), + DataType::UInt64, + false, + ), + Field::new( + &format_state_name(&self.name, "sum"), + DataType::Float64, + false, + ), + Field::new( + &format_state_name(&self.name, "count"), + DataType::Float64, + false, + ), + Field::new( + &format_state_name(&self.name, "max"), + DataType::Float64, + false, + ), + Field::new( + &format_state_name(&self.name, "min"), + DataType::Float64, + false, + ), + Field::new( + &format_state_name(&self.name, "centroids"), + DataType::List(Box::new(Field::new("item", DataType::Float64, true))), + false, + ), + ]) + } + + fn expressions(&self) -> Vec<Arc<dyn PhysicalExpr>> { + vec![self.expr.clone()] + } + + fn create_accumulator(&self) -> Result<Box<dyn Accumulator>> { + Err(DataFusionError::NotImplemented("percentile_cont(...) execution is not implemented".to_string())) + } + + fn name(&self) -> &str { + &self.name + } +} diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index d159e1511f875..a862f76fce5ec 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -224,6 +224,7 @@ enum AggregateFunction { // Cubesql BOOL_AND = 17; BOOL_OR = 18; + PERCENTILE_CONT = 19; } message AggregateExprNode { diff --git a/datafusion/proto/src/from_proto.rs b/datafusion/proto/src/from_proto.rs index 893e011f9c3be..4a3c5d3a32331 100644 --- a/datafusion/proto/src/from_proto.rs +++ b/datafusion/proto/src/from_proto.rs @@ -462,6 +462,7 @@ impl From<protobuf::AggregateFunction> for AggregateFunction { protobuf::AggregateFunction::ApproxPercentileContWithWeight => { Self::ApproxPercentileContWithWeight } + protobuf::AggregateFunction::PercentileCont => Self::PercentileCont, protobuf::AggregateFunction::ApproxMedian => Self::ApproxMedian, protobuf::AggregateFunction::BoolAnd => Self::BoolAnd, protobuf::AggregateFunction::BoolOr => Self::BoolOr, diff --git a/datafusion/proto/src/to_proto.rs b/datafusion/proto/src/to_proto.rs index 
34c4254e46696..66d28306c7802 100644 --- a/datafusion/proto/src/to_proto.rs +++ b/datafusion/proto/src/to_proto.rs @@ -312,6 +312,7 @@ impl From<&AggregateFunction> for protobuf::AggregateFunction { AggregateFunction::ApproxPercentileContWithWeight => { Self::ApproxPercentileContWithWeight } + AggregateFunction::PercentileCont => Self::PercentileCont, AggregateFunction::ApproxMedian => Self::ApproxMedian, AggregateFunction::BoolAnd => Self::BoolAnd, AggregateFunction::BoolOr => Self::BoolOr, @@ -528,6 +529,7 @@ impl TryFrom<&Expr> for protobuf::LogicalExprNode { AggregateFunction::ApproxPercentileContWithWeight => { protobuf::AggregateFunction::ApproxPercentileContWithWeight } + AggregateFunction::PercentileCont => protobuf::AggregateFunction::PercentileCont, AggregateFunction::ArrayAgg => protobuf::AggregateFunction::ArrayAgg, AggregateFunction::Min => protobuf::AggregateFunction::Min, AggregateFunction::Max => protobuf::AggregateFunction::Max,