diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock index 95b114ca4a00..932f44d98486 100644 --- a/datafusion-cli/Cargo.lock +++ b/datafusion-cli/Cargo.lock @@ -2880,9 +2880,9 @@ dependencies = [ [[package]] name = "regex" -version = "1.10.4" +version = "1.10.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c117dbdfde9c8308975b6a18d71f3f385c89461f7b3fb054288ecf2a2058ba4c" +checksum = "b91213439dad192326a0d7c6ee3955910425f441d7038e0d6933b0aec5c4517f" dependencies = [ "aho-corasick", "memchr", @@ -2892,9 +2892,9 @@ dependencies = [ [[package]] name = "regex-automata" -version = "0.4.6" +version = "0.4.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "86b83b8b9847f9bf95ef68afb0b8e6cdb80f498442f5179a29fad448fcc1eaea" +checksum = "38caf58cc5ef2fed281f89292ef23f6365465ed9a41b7a7754eb4e26496c92df" dependencies = [ "aho-corasick", "memchr", @@ -2903,15 +2903,15 @@ dependencies = [ [[package]] name = "regex-lite" -version = "0.1.5" +version = "0.1.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "30b661b2f27137bdbc16f00eda72866a92bb28af1753ffbd56744fb6e2e9cd8e" +checksum = "53a49587ad06b26609c52e423de037e7f57f20d53535d66e08c695f347df952a" [[package]] name = "regex-syntax" -version = "0.8.3" +version = "0.8.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "adad44e29e4c806119491a7f06f03de4d1af22c3a680dd47f1e6e179439d1f56" +checksum = "7a66a03ae7c801facd77a29370b4faec201768915ac14a721ba36f20bc9c209b" [[package]] name = "reqwest" @@ -3846,9 +3846,9 @@ checksum = "daf8dba3b7eb870caf1ddeed7bc9d2a049f3cfdfae7cb521b087cc33ae4c49da" [[package]] name = "utf8parse" -version = "0.2.1" +version = "0.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "711b9620af191e0cdc7468a8d14e709c3dcdb115b36f838e601583af800a370a" +checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821" [[package]] name = "uuid" diff --git a/datafusion/core/tests/dataframe/dataframe_functions.rs b/datafusion/core/tests/dataframe/dataframe_functions.rs index 7d155bb16c72..b05769a6ce9d 100644 --- a/datafusion/core/tests/dataframe/dataframe_functions.rs +++ b/datafusion/core/tests/dataframe/dataframe_functions.rs @@ -33,6 +33,7 @@ use datafusion::assert_batches_eq; use datafusion_common::{DFSchema, ScalarValue}; use datafusion_expr::expr::Alias; use datafusion_expr::ExprSchemable; +use datafusion_functions_aggregate::expr_fn::approx_median; fn test_schema() -> SchemaRef { Arc::new(Schema::new(vec![ @@ -342,7 +343,7 @@ async fn test_fn_approx_median() -> Result<()> { let expected = [ "+-----------------------+", - "| APPROX_MEDIAN(test.b) |", + "| approx_median(test.b) |", "+-----------------------+", "| 10 |", "+-----------------------+", diff --git a/datafusion/expr/src/aggregate_function.rs b/datafusion/expr/src/aggregate_function.rs index 9e4f7a50ac24..6227df814f91 100644 --- a/datafusion/expr/src/aggregate_function.rs +++ b/datafusion/expr/src/aggregate_function.rs @@ -71,8 +71,6 @@ pub enum AggregateFunction { ApproxPercentileCont, /// Approximate continuous percentile function with weight ApproxPercentileContWithWeight, - /// ApproxMedian - ApproxMedian, /// Grouping Grouping, /// Bit And @@ -112,7 +110,6 @@ impl AggregateFunction { RegrSXY => "REGR_SXY", ApproxPercentileCont => "APPROX_PERCENTILE_CONT", ApproxPercentileContWithWeight => "APPROX_PERCENTILE_CONT_WITH_WEIGHT", - ApproxMedian => "APPROX_MEDIAN", Grouping => "GROUPING", BitAnd => "BIT_AND", BitOr => "BIT_OR", @@ -161,7 +158,6 @@ impl FromStr for AggregateFunction { "regr_sxy" => AggregateFunction::RegrSXY, // approximate "approx_distinct" => AggregateFunction::ApproxDistinct, - "approx_median" => AggregateFunction::ApproxMedian, "approx_percentile_cont" => AggregateFunction::ApproxPercentileCont, "approx_percentile_cont_with_weight" => { AggregateFunction::ApproxPercentileContWithWeight @@ -234,7 +230,6 @@ impl AggregateFunction { AggregateFunction::ApproxPercentileContWithWeight => { Ok(coerced_data_types[0].clone()) } - AggregateFunction::ApproxMedian => Ok(coerced_data_types[0].clone()), AggregateFunction::Grouping => Ok(DataType::Int32), AggregateFunction::NthValue => Ok(coerced_data_types[0].clone()), AggregateFunction::StringAgg => Ok(DataType::LargeUtf8), @@ -284,7 +279,8 @@ impl AggregateFunction { AggregateFunction::BoolAnd | AggregateFunction::BoolOr => { Signature::uniform(1, vec![DataType::Boolean], Volatility::Immutable) } - AggregateFunction::Avg | AggregateFunction::ApproxMedian => { + + AggregateFunction::Avg => { Signature::uniform(1, NUMERICS.to_vec(), Volatility::Immutable) } AggregateFunction::NthValue => Signature::any(2, Volatility::Immutable), diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index 0360478eac54..5626d343a6cb 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -284,18 +284,6 @@ pub fn approx_distinct(expr: Expr) -> Expr { )) } -/// Calculate an approximation of the median for `expr`. -pub fn approx_median(expr: Expr) -> Expr { - Expr::AggregateFunction(AggregateFunction::new( - aggregate_function::AggregateFunction::ApproxMedian, - vec![expr], - false, - None, - None, - None, - )) -} - /// Calculate an approximation of the specified `percentile` for `expr`. pub fn approx_percentile_cont(expr: Expr, percentile: Expr) -> Expr { Expr::AggregateFunction(AggregateFunction::new( diff --git a/datafusion/expr/src/type_coercion/aggregates.rs b/datafusion/expr/src/type_coercion/aggregates.rs index 4b4d5265324e..efd3c9f371ef 100644 --- a/datafusion/expr/src/type_coercion/aggregates.rs +++ b/datafusion/expr/src/type_coercion/aggregates.rs @@ -231,16 +231,6 @@ pub fn coerce_types( } Ok(input_types.to_vec()) } - AggregateFunction::ApproxMedian => { - if !is_approx_percentile_cont_supported_arg_type(&input_types[0]) { - return plan_err!( - "The function {:?} does not support inputs of type {:?}.", - agg_fun, - input_types[0] - ); - } - Ok(input_types.to_vec()) - } AggregateFunction::NthValue => Ok(input_types.to_vec()), AggregateFunction::Grouping => Ok(vec![input_types[0].clone()]), AggregateFunction::StringAgg => { diff --git a/datafusion/functions-aggregate/src/approx_median.rs b/datafusion/functions-aggregate/src/approx_median.rs new file mode 100644 index 000000000000..b8b86d30557a --- /dev/null +++ b/datafusion/functions-aggregate/src/approx_median.rs @@ -0,0 +1,129 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Defines physical expressions for APPROX_MEDIAN that can be evaluated MEDIAN at runtime during query execution + +use std::any::Any; +use std::fmt::Debug; + +use arrow::{datatypes::DataType, datatypes::Field}; +use arrow_schema::DataType::{Float64, UInt64}; + +use datafusion_common::{not_impl_err, plan_err, Result}; +use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; +use datafusion_expr::type_coercion::aggregates::NUMERICS; +use datafusion_expr::utils::format_state_name; +use datafusion_expr::{Accumulator, AggregateUDFImpl, Signature, Volatility}; +use datafusion_physical_expr_common::aggregate::utils::down_cast_any_ref; + +use crate::approx_percentile_cont::ApproxPercentileAccumulator; + +make_udaf_expr_and_func!( + ApproxMedian, + approx_median, + expression, + "Computes the approximate median of a set of numbers", + approx_median_udaf +); + +/// APPROX_MEDIAN aggregate expression +pub struct ApproxMedian { + signature: Signature, +} + +impl Debug for ApproxMedian { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + f.debug_struct("ApproxMedian") + .field("name", &self.name()) + .field("signature", &self.signature) + .finish() + } +} + +impl Default for ApproxMedian { + fn default() -> Self { + Self::new() + } +} + +impl ApproxMedian { + /// Create a new APPROX_MEDIAN aggregate function + pub fn new() -> Self { + Self { + signature: Signature::uniform(1, NUMERICS.to_vec(), Volatility::Immutable), + } + } +} + +impl AggregateUDFImpl for ApproxMedian { + /// Return a reference to Any that can be used for downcasting + fn as_any(&self) -> &dyn Any { + self + } + + fn state_fields(&self, args: StateFieldsArgs) -> Result> { + Ok(vec![ + Field::new(format_state_name(args.name, "max_size"), UInt64, false), + Field::new(format_state_name(args.name, "sum"), Float64, false), + Field::new(format_state_name(args.name, "count"), Float64, false), + Field::new(format_state_name(args.name, "max"), Float64, false), + Field::new(format_state_name(args.name, "min"), Float64, false), + Field::new_list( + format_state_name(args.name, "centroids"), + Field::new("item", Float64, true), + false, + ), + ]) + } + + fn name(&self) -> &str { + "approx_median" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + if !arg_types[0].is_numeric() { + return plan_err!("ApproxMedian requires numeric input types"); + } + Ok(arg_types[0].clone()) + } + + fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { + if acc_args.is_distinct { + return not_impl_err!( + "APPROX_MEDIAN(DISTINCT) aggregations are not available" + ); + } + + Ok(Box::new(ApproxPercentileAccumulator::new( + 0.5_f64, + acc_args.input_type.clone(), + ))) + } +} + +impl PartialEq for ApproxMedian { + fn eq(&self, other: &dyn Any) -> bool { + down_cast_any_ref(other) + .downcast_ref::() + .map(|x| self.signature == x.signature) + .unwrap_or(false) + } +} diff --git a/datafusion/functions-aggregate/src/approx_percentile_cont.rs b/datafusion/functions-aggregate/src/approx_percentile_cont.rs new file mode 100644 index 000000000000..e75417efc684 --- /dev/null +++ b/datafusion/functions-aggregate/src/approx_percentile_cont.rs @@ -0,0 +1,255 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::{ + array::{ + ArrayRef, Float32Array, Float64Array, Int16Array, Int32Array, Int64Array, + Int8Array, UInt16Array, UInt32Array, UInt64Array, UInt8Array, + }, + datatypes::DataType, +}; + +use datafusion_common::{downcast_value, internal_err, DataFusionError, ScalarValue}; +use datafusion_expr::Accumulator; +use datafusion_physical_expr_common::aggregate::tdigest::{ + TDigest, TryIntoF64, DEFAULT_MAX_SIZE, +}; + +#[derive(Debug)] +pub struct ApproxPercentileAccumulator { + digest: TDigest, + percentile: f64, + return_type: DataType, +} + +impl ApproxPercentileAccumulator { + pub fn new(percentile: f64, return_type: DataType) -> Self { + Self { + digest: TDigest::new(DEFAULT_MAX_SIZE), + percentile, + return_type, + } + } + + pub fn new_with_max_size( + percentile: f64, + return_type: DataType, + max_size: usize, + ) -> Self { + Self { + digest: TDigest::new(max_size), + percentile, + return_type, + } + } + + // public for approx_percentile_cont_with_weight + pub fn merge_digests(&mut self, digests: &[TDigest]) { + let digests = digests.iter().chain(std::iter::once(&self.digest)); + self.digest = TDigest::merge_digests(digests) + } + + // public for approx_percentile_cont_with_weight + pub fn convert_to_float(values: &ArrayRef) -> datafusion_common::Result> { + match values.data_type() { + DataType::Float64 => { + let array = downcast_value!(values, Float64Array); + Ok(array + .values() + .iter() + .filter_map(|v| v.try_as_f64().transpose()) + .collect::>>()?) + } + DataType::Float32 => { + let array = downcast_value!(values, Float32Array); + Ok(array + .values() + .iter() + .filter_map(|v| v.try_as_f64().transpose()) + .collect::>>()?) + } + DataType::Int64 => { + let array = downcast_value!(values, Int64Array); + Ok(array + .values() + .iter() + .filter_map(|v| v.try_as_f64().transpose()) + .collect::>>()?) + } + DataType::Int32 => { + let array = downcast_value!(values, Int32Array); + Ok(array + .values() + .iter() + .filter_map(|v| v.try_as_f64().transpose()) + .collect::>>()?) + } + DataType::Int16 => { + let array = downcast_value!(values, Int16Array); + Ok(array + .values() + .iter() + .filter_map(|v| v.try_as_f64().transpose()) + .collect::>>()?) + } + DataType::Int8 => { + let array = downcast_value!(values, Int8Array); + Ok(array + .values() + .iter() + .filter_map(|v| v.try_as_f64().transpose()) + .collect::>>()?) + } + DataType::UInt64 => { + let array = downcast_value!(values, UInt64Array); + Ok(array + .values() + .iter() + .filter_map(|v| v.try_as_f64().transpose()) + .collect::>>()?) + } + DataType::UInt32 => { + let array = downcast_value!(values, UInt32Array); + Ok(array + .values() + .iter() + .filter_map(|v| v.try_as_f64().transpose()) + .collect::>>()?) + } + DataType::UInt16 => { + let array = downcast_value!(values, UInt16Array); + Ok(array + .values() + .iter() + .filter_map(|v| v.try_as_f64().transpose()) + .collect::>>()?) + } + DataType::UInt8 => { + let array = downcast_value!(values, UInt8Array); + Ok(array + .values() + .iter() + .filter_map(|v| v.try_as_f64().transpose()) + .collect::>>()?) + } + e => internal_err!( + "APPROX_PERCENTILE_CONT is not expected to receive the type {e:?}" + ), + } + } +} + +impl Accumulator for ApproxPercentileAccumulator { + fn state(&mut self) -> datafusion_common::Result> { + Ok(self.digest.to_scalar_state().into_iter().collect()) + } + + fn update_batch(&mut self, values: &[ArrayRef]) -> datafusion_common::Result<()> { + let values = &values[0]; + let sorted_values = &arrow::compute::sort(values, None)?; + let sorted_values = ApproxPercentileAccumulator::convert_to_float(sorted_values)?; + self.digest = self.digest.merge_sorted_f64(&sorted_values); + Ok(()) + } + + fn evaluate(&mut self) -> datafusion_common::Result { + if self.digest.count() == 0.0 { + return ScalarValue::try_from(self.return_type.clone()); + } + let q = self.digest.estimate_quantile(self.percentile); + + // These acceptable return types MUST match the validation in + // ApproxPercentile::create_accumulator. + Ok(match &self.return_type { + DataType::Int8 => ScalarValue::Int8(Some(q as i8)), + DataType::Int16 => ScalarValue::Int16(Some(q as i16)), + DataType::Int32 => ScalarValue::Int32(Some(q as i32)), + DataType::Int64 => ScalarValue::Int64(Some(q as i64)), + DataType::UInt8 => ScalarValue::UInt8(Some(q as u8)), + DataType::UInt16 => ScalarValue::UInt16(Some(q as u16)), + DataType::UInt32 => ScalarValue::UInt32(Some(q as u32)), + DataType::UInt64 => ScalarValue::UInt64(Some(q as u64)), + DataType::Float32 => ScalarValue::Float32(Some(q as f32)), + DataType::Float64 => ScalarValue::Float64(Some(q)), + v => unreachable!("unexpected return type {:?}", v), + }) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> datafusion_common::Result<()> { + if states.is_empty() { + return Ok(()); + } + + let states = (0..states[0].len()) + .map(|index| { + states + .iter() + .map(|array| ScalarValue::try_from_array(array, index)) + .collect::>>() + .map(|state| TDigest::from_scalar_state(&state)) + }) + .collect::>>()?; + + self.merge_digests(&states); + + Ok(()) + } + + fn size(&self) -> usize { + std::mem::size_of_val(self) + self.digest.size() + - std::mem::size_of_val(&self.digest) + + self.return_type.size() + - std::mem::size_of_val(&self.return_type) + } + + fn supports_retract_batch(&self) -> bool { + true + } +} + +#[cfg(test)] +mod tests { + use arrow_schema::DataType; + + use datafusion_physical_expr_common::aggregate::tdigest::TDigest; + + use crate::approx_percentile_cont::ApproxPercentileAccumulator; + + #[test] + fn test_combine_approx_percentile_accumulator() { + let mut digests: Vec = Vec::new(); + + // one TDigest with 50_000 values from 1 to 1_000 + for _ in 1..=50 { + let t = TDigest::new(100); + let values: Vec<_> = (1..=1_000).map(f64::from).collect(); + let t = t.merge_unsorted_f64(values); + digests.push(t) + } + + let t1 = TDigest::merge_digests(&digests); + let t2 = TDigest::merge_digests(&digests); + + let mut accumulator = + ApproxPercentileAccumulator::new_with_max_size(0.5, DataType::Float64, 100); + + accumulator.merge_digests(&[t1]); + assert_eq!(accumulator.digest.count(), 50_000.0); + accumulator.merge_digests(&[t2]); + assert_eq!(accumulator.digest.count(), 100_000.0); + } +} diff --git a/datafusion/functions-aggregate/src/lib.rs b/datafusion/functions-aggregate/src/lib.rs index b8a2e7032acd..274ab8302e2a 100644 --- a/datafusion/functions-aggregate/src/lib.rs +++ b/datafusion/functions-aggregate/src/lib.rs @@ -62,6 +62,9 @@ pub mod stddev; pub mod sum; pub mod variance; +pub mod approx_median; +pub mod approx_percentile_cont; + use datafusion_common::Result; use datafusion_execution::FunctionRegistry; use datafusion_expr::AggregateUDF; @@ -70,6 +73,7 @@ use std::sync::Arc; /// Fluent-style API for creating `Expr`s pub mod expr_fn { + pub use super::approx_median::approx_median; pub use super::covariance::covar_pop; pub use super::covariance::covar_samp; pub use super::first_last::first_value; @@ -95,6 +99,7 @@ pub fn all_default_aggregate_functions() -> Vec> { variance::var_pop_udaf(), stddev::stddev_udaf(), stddev::stddev_pop_udaf(), + approx_median::approx_median_udaf(), ] } diff --git a/datafusion/functions-aggregate/src/sum.rs b/datafusion/functions-aggregate/src/sum.rs index 9d3fa2522265..b9293bc2ca28 100644 --- a/datafusion/functions-aggregate/src/sum.rs +++ b/datafusion/functions-aggregate/src/sum.rs @@ -46,7 +46,7 @@ make_udaf_expr_and_func!( Sum, sum, expression, - "Returns the first value in a group of values.", + "Returns the sum of a group of values.", sum_udaf ); diff --git a/datafusion/physical-expr-common/src/aggregate/mod.rs b/datafusion/physical-expr-common/src/aggregate/mod.rs index 2273418c6096..ec02df57b82d 100644 --- a/datafusion/physical-expr-common/src/aggregate/mod.rs +++ b/datafusion/physical-expr-common/src/aggregate/mod.rs @@ -17,6 +17,7 @@ pub mod groups_accumulator; pub mod stats; +pub mod tdigest; pub mod utils; use arrow::datatypes::{DataType, Field, Schema}; diff --git a/datafusion/physical-expr/src/aggregate/tdigest.rs b/datafusion/physical-expr-common/src/aggregate/tdigest.rs similarity index 95% rename from datafusion/physical-expr/src/aggregate/tdigest.rs rename to datafusion/physical-expr-common/src/aggregate/tdigest.rs index e3b23b91d0ff..5107d0ab8e52 100644 --- a/datafusion/physical-expr/src/aggregate/tdigest.rs +++ b/datafusion/physical-expr-common/src/aggregate/tdigest.rs @@ -28,7 +28,7 @@ //! [Facebook's Folly TDigest]: https://github.com/facebook/folly/blob/main/folly/stats/TDigest.h use arrow::datatypes::DataType; -use arrow_array::types::Float64Type; +use arrow::datatypes::Float64Type; use datafusion_common::cast::as_primitive_array; use datafusion_common::Result; use datafusion_common::ScalarValue; @@ -50,7 +50,7 @@ macro_rules! cast_scalar_f64 { /// This trait is implemented for each type a [`TDigest`] can operate on, /// allowing it to support both numerical rust types (obtained from /// `PrimitiveArray` instances), and [`ScalarValue`] instances. -pub(crate) trait TryIntoF64 { +pub trait TryIntoF64 { /// A fallible conversion of a possibly null `self` into a [`f64`]. /// /// If `self` is null, this method must return `Ok(None)`. @@ -84,7 +84,7 @@ impl_try_ordered_f64!(u8); /// Centroid implementation to the cluster mentioned in the paper. #[derive(Debug, PartialEq, Clone)] -pub(crate) struct Centroid { +pub struct Centroid { mean: f64, weight: f64, } @@ -104,21 +104,21 @@ impl Ord for Centroid { } impl Centroid { - pub(crate) fn new(mean: f64, weight: f64) -> Self { + pub fn new(mean: f64, weight: f64) -> Self { Centroid { mean, weight } } #[inline] - pub(crate) fn mean(&self) -> f64 { + pub fn mean(&self) -> f64 { self.mean } #[inline] - pub(crate) fn weight(&self) -> f64 { + pub fn weight(&self) -> f64 { self.weight } - pub(crate) fn add(&mut self, sum: f64, weight: f64) -> f64 { + pub fn add(&mut self, sum: f64, weight: f64) -> f64 { let new_sum = sum + self.weight * self.mean; let new_weight = self.weight + weight; self.weight = new_weight; @@ -138,7 +138,7 @@ impl Default for Centroid { /// T-Digest to be operated on. #[derive(Debug, PartialEq, Clone)] -pub(crate) struct TDigest { +pub struct TDigest { centroids: Vec, max_size: usize, sum: f64, @@ -148,7 +148,7 @@ pub(crate) struct TDigest { } impl TDigest { - pub(crate) fn new(max_size: usize) -> Self { + pub fn new(max_size: usize) -> Self { TDigest { centroids: Vec::new(), max_size, @@ -159,7 +159,7 @@ impl TDigest { } } - pub(crate) fn new_with_centroid(max_size: usize, centroid: Centroid) -> Self { + pub fn new_with_centroid(max_size: usize, centroid: Centroid) -> Self { TDigest { centroids: vec![centroid.clone()], max_size, @@ -171,27 +171,27 @@ impl TDigest { } #[inline] - pub(crate) fn count(&self) -> f64 { + pub fn count(&self) -> f64 { self.count } #[inline] - pub(crate) fn max(&self) -> f64 { + pub fn max(&self) -> f64 { self.max } #[inline] - pub(crate) fn min(&self) -> f64 { + pub fn min(&self) -> f64 { self.min } #[inline] - pub(crate) fn max_size(&self) -> usize { + pub fn max_size(&self) -> usize { self.max_size } /// Size in bytes including `Self`. - pub(crate) fn size(&self) -> usize { + pub fn size(&self) -> usize { std::mem::size_of_val(self) + (std::mem::size_of::() * self.centroids.capacity()) } @@ -228,14 +228,14 @@ impl TDigest { v.clamp(lo, hi) } - #[cfg(test)] - pub(crate) fn merge_unsorted_f64(&self, unsorted_values: Vec) -> TDigest { + // public for testing in other modules + pub fn merge_unsorted_f64(&self, unsorted_values: Vec) -> TDigest { let mut values = unsorted_values; values.sort_by(|a, b| a.total_cmp(b)); self.merge_sorted_f64(&values) } - pub(crate) fn merge_sorted_f64(&self, sorted_values: &[f64]) -> TDigest { + pub fn merge_sorted_f64(&self, sorted_values: &[f64]) -> TDigest { #[cfg(debug_assertions)] debug_assert!(is_sorted(sorted_values), "unsorted input to TDigest"); @@ -370,9 +370,7 @@ impl TDigest { } // Merge multiple T-Digests - pub(crate) fn merge_digests<'a>( - digests: impl IntoIterator, - ) -> TDigest { + pub fn merge_digests<'a>(digests: impl IntoIterator) -> TDigest { let digests = digests.into_iter().collect::>(); let n_centroids: usize = digests.iter().map(|d| d.centroids.len()).sum(); if n_centroids == 0 { @@ -465,7 +463,7 @@ impl TDigest { } /// To estimate the value located at `q` quantile - pub(crate) fn estimate_quantile(&self, q: f64) -> f64 { + pub fn estimate_quantile(&self, q: f64) -> f64 { if self.centroids.is_empty() { return 0.0; } @@ -569,7 +567,7 @@ impl TDigest { /// The [`TDigest::from_scalar_state()`] method reverses this processes, /// consuming the output of this method and returning an unpacked /// [`TDigest`]. - pub(crate) fn to_scalar_state(&self) -> Vec { + pub fn to_scalar_state(&self) -> Vec { // Gather up all the centroids let centroids: Vec = self .centroids @@ -598,7 +596,7 @@ impl TDigest { /// Providing input to this method that was not obtained from /// [`Self::to_scalar_state()`] results in undefined behaviour and may /// panic. - pub(crate) fn from_scalar_state(state: &[ScalarValue]) -> Self { + pub fn from_scalar_state(state: &[ScalarValue]) -> Self { assert_eq!(state.len(), 6, "invalid TDigest state"); let max_size = match &state[0] { diff --git a/datafusion/physical-expr/src/aggregate/approx_median.rs b/datafusion/physical-expr/src/aggregate/approx_median.rs deleted file mode 100644 index cbbfef5a8919..000000000000 --- a/datafusion/physical-expr/src/aggregate/approx_median.rs +++ /dev/null @@ -1,99 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! Defines physical expressions for APPROX_MEDIAN that can be evaluated MEDIAN at runtime during query execution - -use crate::aggregate::utils::down_cast_any_ref; -use crate::expressions::{lit, ApproxPercentileCont}; -use crate::{AggregateExpr, PhysicalExpr}; -use arrow::{datatypes::DataType, datatypes::Field}; -use datafusion_common::Result; -use datafusion_expr::Accumulator; -use std::any::Any; -use std::sync::Arc; - -/// MEDIAN aggregate expression -#[derive(Debug)] -pub struct ApproxMedian { - name: String, - expr: Arc, - data_type: DataType, - approx_percentile: ApproxPercentileCont, -} - -impl ApproxMedian { - /// Create a new APPROX_MEDIAN aggregate function - pub fn try_new( - expr: Arc, - name: impl Into, - data_type: DataType, - ) -> Result { - let name: String = name.into(); - let approx_percentile = ApproxPercentileCont::new( - vec![expr.clone(), lit(0.5_f64)], - name.clone(), - data_type.clone(), - )?; - Ok(Self { - name, - expr, - data_type, - approx_percentile, - }) - } -} - -impl AggregateExpr for ApproxMedian { - /// Return a reference to Any that can be used for downcasting - fn as_any(&self) -> &dyn Any { - self - } - - fn field(&self) -> Result { - Ok(Field::new(&self.name, self.data_type.clone(), true)) - } - - fn create_accumulator(&self) -> Result> { - self.approx_percentile.create_accumulator() - } - - fn state_fields(&self) -> Result> { - self.approx_percentile.state_fields() - } - - fn expressions(&self) -> Vec> { - vec![self.expr.clone()] - } - - fn name(&self) -> &str { - &self.name - } -} - -impl PartialEq for ApproxMedian { - fn eq(&self, other: &dyn Any) -> bool { - down_cast_any_ref(other) - .downcast_ref::() - .map(|x| { - self.name == x.name - && self.data_type == x.data_type - && self.expr.eq(&x.expr) - && self.approx_percentile == x.approx_percentile - }) - .unwrap_or(false) - } -} diff --git a/datafusion/physical-expr/src/aggregate/approx_percentile_cont.rs b/datafusion/physical-expr/src/aggregate/approx_percentile_cont.rs index 63a4c85f9e80..f2068bbc92cc 100644 --- a/datafusion/physical-expr/src/aggregate/approx_percentile_cont.rs +++ b/datafusion/physical-expr/src/aggregate/approx_percentile_cont.rs @@ -15,26 +15,19 @@ // specific language governing permissions and limitations // under the License. -use crate::aggregate::tdigest::TryIntoF64; -use crate::aggregate::tdigest::{TDigest, DEFAULT_MAX_SIZE}; -use crate::aggregate::utils::down_cast_any_ref; -use crate::expressions::format_state_name; -use crate::{AggregateExpr, PhysicalExpr}; -use arrow::{ - array::{ - ArrayRef, Float32Array, Float64Array, Int16Array, Int32Array, Int64Array, - Int8Array, UInt16Array, UInt32Array, UInt64Array, UInt8Array, - }, - datatypes::{DataType, Field}, -}; +use std::{any::Any, sync::Arc}; + +use arrow::datatypes::{DataType, Field}; use arrow_array::RecordBatch; use arrow_schema::Schema; -use datafusion_common::{ - downcast_value, internal_err, not_impl_err, plan_err, DataFusionError, Result, - ScalarValue, -}; + +use datafusion_common::{not_impl_err, plan_err, DataFusionError, Result, ScalarValue}; use datafusion_expr::{Accumulator, ColumnarValue}; -use std::{any::Any, sync::Arc}; +use datafusion_functions_aggregate::approx_percentile_cont::ApproxPercentileAccumulator; + +use crate::aggregate::utils::down_cast_any_ref; +use crate::expressions::format_state_name; +use crate::{AggregateExpr, PhysicalExpr}; /// APPROX_PERCENTILE_CONT aggregate expression #[derive(Debug)] @@ -195,7 +188,7 @@ impl AggregateExpr for ApproxPercentileCont { } #[allow(rustdoc::private_intra_doc_links)] - /// See [`TDigest::to_scalar_state()`] for a description of the serialised + /// See [`datafusion_physical_expr_common::aggregate::tdigest::TDigest::to_scalar_state()`] for a description of the serialised /// state. fn state_fields(&self) -> Result> { Ok(vec![ @@ -254,220 +247,3 @@ impl PartialEq for ApproxPercentileCont { .unwrap_or(false) } } - -#[derive(Debug)] -pub struct ApproxPercentileAccumulator { - digest: TDigest, - percentile: f64, - return_type: DataType, -} - -impl ApproxPercentileAccumulator { - pub fn new(percentile: f64, return_type: DataType) -> Self { - Self { - digest: TDigest::new(DEFAULT_MAX_SIZE), - percentile, - return_type, - } - } - - pub fn new_with_max_size( - percentile: f64, - return_type: DataType, - max_size: usize, - ) -> Self { - Self { - digest: TDigest::new(max_size), - percentile, - return_type, - } - } - - pub(crate) fn merge_digests(&mut self, digests: &[TDigest]) { - let digests = digests.iter().chain(std::iter::once(&self.digest)); - self.digest = TDigest::merge_digests(digests) - } - - pub(crate) fn convert_to_float(values: &ArrayRef) -> Result> { - match values.data_type() { - DataType::Float64 => { - let array = downcast_value!(values, Float64Array); - Ok(array - .values() - .iter() - .filter_map(|v| v.try_as_f64().transpose()) - .collect::>>()?) - } - DataType::Float32 => { - let array = downcast_value!(values, Float32Array); - Ok(array - .values() - .iter() - .filter_map(|v| v.try_as_f64().transpose()) - .collect::>>()?) - } - DataType::Int64 => { - let array = downcast_value!(values, Int64Array); - Ok(array - .values() - .iter() - .filter_map(|v| v.try_as_f64().transpose()) - .collect::>>()?) - } - DataType::Int32 => { - let array = downcast_value!(values, Int32Array); - Ok(array - .values() - .iter() - .filter_map(|v| v.try_as_f64().transpose()) - .collect::>>()?) - } - DataType::Int16 => { - let array = downcast_value!(values, Int16Array); - Ok(array - .values() - .iter() - .filter_map(|v| v.try_as_f64().transpose()) - .collect::>>()?) - } - DataType::Int8 => { - let array = downcast_value!(values, Int8Array); - Ok(array - .values() - .iter() - .filter_map(|v| v.try_as_f64().transpose()) - .collect::>>()?) - } - DataType::UInt64 => { - let array = downcast_value!(values, UInt64Array); - Ok(array - .values() - .iter() - .filter_map(|v| v.try_as_f64().transpose()) - .collect::>>()?) - } - DataType::UInt32 => { - let array = downcast_value!(values, UInt32Array); - Ok(array - .values() - .iter() - .filter_map(|v| v.try_as_f64().transpose()) - .collect::>>()?) - } - DataType::UInt16 => { - let array = downcast_value!(values, UInt16Array); - Ok(array - .values() - .iter() - .filter_map(|v| v.try_as_f64().transpose()) - .collect::>>()?) - } - DataType::UInt8 => { - let array = downcast_value!(values, UInt8Array); - Ok(array - .values() - .iter() - .filter_map(|v| v.try_as_f64().transpose()) - .collect::>>()?) - } - e => internal_err!( - "APPROX_PERCENTILE_CONT is not expected to receive the type {e:?}" - ), - } - } -} - -impl Accumulator for ApproxPercentileAccumulator { - fn state(&mut self) -> Result> { - Ok(self.digest.to_scalar_state().into_iter().collect()) - } - - fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - let values = &values[0]; - let sorted_values = &arrow::compute::sort(values, None)?; - let sorted_values = ApproxPercentileAccumulator::convert_to_float(sorted_values)?; - self.digest = self.digest.merge_sorted_f64(&sorted_values); - Ok(()) - } - - fn evaluate(&mut self) -> Result { - if self.digest.count() == 0.0 { - return ScalarValue::try_from(self.return_type.clone()); - } - let q = self.digest.estimate_quantile(self.percentile); - - // These acceptable return types MUST match the validation in - // ApproxPercentile::create_accumulator. - Ok(match &self.return_type { - DataType::Int8 => ScalarValue::Int8(Some(q as i8)), - DataType::Int16 => ScalarValue::Int16(Some(q as i16)), - DataType::Int32 => ScalarValue::Int32(Some(q as i32)), - DataType::Int64 => ScalarValue::Int64(Some(q as i64)), - DataType::UInt8 => ScalarValue::UInt8(Some(q as u8)), - DataType::UInt16 => ScalarValue::UInt16(Some(q as u16)), - DataType::UInt32 => ScalarValue::UInt32(Some(q as u32)), - DataType::UInt64 => ScalarValue::UInt64(Some(q as u64)), - DataType::Float32 => ScalarValue::Float32(Some(q as f32)), - DataType::Float64 => ScalarValue::Float64(Some(q)), - v => unreachable!("unexpected return type {:?}", v), - }) - } - - fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { - if states.is_empty() { - return Ok(()); - } - - let states = (0..states[0].len()) - .map(|index| { - states - .iter() - .map(|array| ScalarValue::try_from_array(array, index)) - .collect::>>() - .map(|state| TDigest::from_scalar_state(&state)) - }) - .collect::>>()?; - - self.merge_digests(&states); - - Ok(()) - } - - fn size(&self) -> usize { - std::mem::size_of_val(self) + self.digest.size() - - std::mem::size_of_val(&self.digest) - + self.return_type.size() - - std::mem::size_of_val(&self.return_type) - } -} - -#[cfg(test)] -mod tests { - use crate::aggregate::approx_percentile_cont::ApproxPercentileAccumulator; - use crate::aggregate::tdigest::TDigest; - use arrow_schema::DataType; - - #[test] - fn test_combine_approx_percentile_accumulator() { - let mut digests: Vec = Vec::new(); - - // one TDigest with 50_000 values from 1 to 1_000 - for _ in 1..=50 { - let t = TDigest::new(100); - let values: Vec<_> = (1..=1_000).map(f64::from).collect(); - let t = t.merge_unsorted_f64(values); - digests.push(t) - } - - let t1 = TDigest::merge_digests(&digests); - let t2 = TDigest::merge_digests(&digests); - - let mut accumulator = - ApproxPercentileAccumulator::new_with_max_size(0.5, DataType::Float64, 100); - - accumulator.merge_digests(&[t1]); - assert_eq!(accumulator.digest.count(), 50_000.0); - accumulator.merge_digests(&[t2]); - assert_eq!(accumulator.digest.count(), 100_000.0); - } -} diff --git a/datafusion/physical-expr/src/aggregate/approx_percentile_cont_with_weight.rs b/datafusion/physical-expr/src/aggregate/approx_percentile_cont_with_weight.rs index 3fa715a59238..07c2aff3437f 100644 --- a/datafusion/physical-expr/src/aggregate/approx_percentile_cont_with_weight.rs +++ b/datafusion/physical-expr/src/aggregate/approx_percentile_cont_with_weight.rs @@ -15,14 +15,16 @@ // specific language governing permissions and limitations // under the License. -use crate::aggregate::approx_percentile_cont::ApproxPercentileAccumulator; -use crate::aggregate::tdigest::{Centroid, TDigest, DEFAULT_MAX_SIZE}; use crate::expressions::ApproxPercentileCont; use crate::{AggregateExpr, PhysicalExpr}; use arrow::{ array::ArrayRef, datatypes::{DataType, Field}, }; +use datafusion_functions_aggregate::approx_percentile_cont::ApproxPercentileAccumulator; +use datafusion_physical_expr_common::aggregate::tdigest::{ + Centroid, TDigest, DEFAULT_MAX_SIZE, +}; use datafusion_common::Result; use datafusion_common::ScalarValue; diff --git a/datafusion/physical-expr/src/aggregate/build_in.rs b/datafusion/physical-expr/src/aggregate/build_in.rs index f0cff53fb3c5..89de6ad49c39 100644 --- a/datafusion/physical-expr/src/aggregate/build_in.rs +++ b/datafusion/physical-expr/src/aggregate/build_in.rs @@ -280,18 +280,6 @@ pub fn create_aggregate_expr( "approx_percentile_cont_with_weight(DISTINCT) aggregations are not available" ); } - (AggregateFunction::ApproxMedian, false) => { - Arc::new(expressions::ApproxMedian::try_new( - input_phy_exprs[0].clone(), - name, - data_type, - )?) - } - (AggregateFunction::ApproxMedian, true) => { - return not_impl_err!( - "APPROX_MEDIAN(DISTINCT) aggregations are not available" - ); - } (AggregateFunction::NthValue, _) => { let expr = &input_phy_exprs[0]; let Some(n) = input_phy_exprs[1] @@ -337,9 +325,8 @@ mod tests { use super::*; use crate::expressions::{ - try_cast, ApproxDistinct, ApproxMedian, ApproxPercentileCont, ArrayAgg, Avg, - BitAnd, BitOr, BitXor, BoolAnd, BoolOr, Count, DistinctArrayAgg, DistinctCount, - Max, Min, + try_cast, ApproxDistinct, ApproxPercentileCont, ArrayAgg, Avg, BitAnd, BitOr, + BitXor, BoolAnd, BoolOr, Count, DistinctArrayAgg, DistinctCount, Max, Min, }; use datafusion_common::{plan_err, DataFusionError, ScalarValue}; @@ -686,60 +673,6 @@ mod tests { Ok(()) } - #[test] - fn test_median_expr() -> Result<()> { - let funcs = vec![AggregateFunction::ApproxMedian]; - let data_types = vec![ - DataType::UInt32, - DataType::UInt64, - DataType::Int32, - DataType::Int64, - DataType::Float32, - DataType::Float64, - ]; - for fun in funcs { - for data_type in &data_types { - let input_schema = - Schema::new(vec![Field::new("c1", data_type.clone(), true)]); - let input_phy_exprs: Vec> = vec![Arc::new( - expressions::Column::new_with_schema("c1", &input_schema).unwrap(), - )]; - let result_agg_phy_exprs = create_physical_agg_expr_for_test( - &fun, - false, - &input_phy_exprs[0..1], - &input_schema, - "c1", - )?; - - if fun == AggregateFunction::ApproxMedian { - assert!(result_agg_phy_exprs.as_any().is::()); - assert_eq!("c1", result_agg_phy_exprs.name()); - assert_eq!( - Field::new("c1", data_type.clone(), true), - result_agg_phy_exprs.field().unwrap() - ); - } - } - } - Ok(()) - } - - #[test] - fn test_median() -> Result<()> { - let observed = AggregateFunction::ApproxMedian.return_type(&[DataType::Utf8]); - assert!(observed.is_err()); - - let observed = AggregateFunction::ApproxMedian.return_type(&[DataType::Int32])?; - assert_eq!(DataType::Int32, observed); - - let observed = - AggregateFunction::ApproxMedian.return_type(&[DataType::Decimal128(10, 6)]); - assert!(observed.is_err()); - - Ok(()) - } - #[test] fn test_min_max() -> Result<()> { let observed = AggregateFunction::Min.return_type(&[DataType::Utf8])?; diff --git a/datafusion/physical-expr/src/aggregate/mod.rs b/datafusion/physical-expr/src/aggregate/mod.rs index 2c14c1550eb8..9db80f155ab3 100644 --- a/datafusion/physical-expr/src/aggregate/mod.rs +++ b/datafusion/physical-expr/src/aggregate/mod.rs @@ -18,10 +18,8 @@ pub use datafusion_physical_expr_common::aggregate::AggregateExpr; mod hyperloglog; -mod tdigest; pub(crate) mod approx_distinct; -pub(crate) mod approx_median; pub(crate) mod approx_percentile_cont; pub(crate) mod approx_percentile_cont_with_weight; pub(crate) mod array_agg; diff --git a/datafusion/physical-expr/src/expressions/mod.rs b/datafusion/physical-expr/src/expressions/mod.rs index 476cbe390733..656cc570ca60 100644 --- a/datafusion/physical-expr/src/expressions/mod.rs +++ b/datafusion/physical-expr/src/expressions/mod.rs @@ -38,7 +38,6 @@ pub mod helpers { } pub use crate::aggregate::approx_distinct::ApproxDistinct; -pub use crate::aggregate::approx_median::ApproxMedian; pub use crate::aggregate::approx_percentile_cont::ApproxPercentileCont; pub use crate::aggregate::approx_percentile_cont_with_weight::ApproxPercentileContWithWeight; pub use crate::aggregate::array_agg::ArrayAgg; diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index 0071a43bbe3a..9f23824b3af7 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -487,7 +487,7 @@ enum AggregateFunction { // STDDEV_POP = 12; CORRELATION = 13; APPROX_PERCENTILE_CONT = 14; - APPROX_MEDIAN = 15; + // APPROX_MEDIAN = 15; APPROX_PERCENTILE_CONT_WITH_WEIGHT = 16; GROUPING = 17; // MEDIAN = 18; diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index e6aded8901ee..28f80c5ee1d5 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -540,7 +540,6 @@ impl serde::Serialize for AggregateFunction { Self::ArrayAgg => "ARRAY_AGG", Self::Correlation => "CORRELATION", Self::ApproxPercentileCont => "APPROX_PERCENTILE_CONT", - Self::ApproxMedian => "APPROX_MEDIAN", Self::ApproxPercentileContWithWeight => "APPROX_PERCENTILE_CONT_WITH_WEIGHT", Self::Grouping => "GROUPING", Self::BitAnd => "BIT_AND", @@ -578,7 +577,6 @@ impl<'de> serde::Deserialize<'de> for AggregateFunction { "ARRAY_AGG", "CORRELATION", "APPROX_PERCENTILE_CONT", - "APPROX_MEDIAN", "APPROX_PERCENTILE_CONT_WITH_WEIGHT", "GROUPING", "BIT_AND", @@ -645,7 +643,6 @@ impl<'de> serde::Deserialize<'de> for AggregateFunction { "ARRAY_AGG" => Ok(AggregateFunction::ArrayAgg), "CORRELATION" => Ok(AggregateFunction::Correlation), "APPROX_PERCENTILE_CONT" => Ok(AggregateFunction::ApproxPercentileCont), - "APPROX_MEDIAN" => Ok(AggregateFunction::ApproxMedian), "APPROX_PERCENTILE_CONT_WITH_WEIGHT" => Ok(AggregateFunction::ApproxPercentileContWithWeight), "GROUPING" => Ok(AggregateFunction::Grouping), "BIT_AND" => Ok(AggregateFunction::BitAnd), diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index 7ec91874912e..9741b2bc4209 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -1931,7 +1931,7 @@ pub enum AggregateFunction { /// STDDEV_POP = 12; Correlation = 13, ApproxPercentileCont = 14, - ApproxMedian = 15, + /// APPROX_MEDIAN = 15; ApproxPercentileContWithWeight = 16, Grouping = 17, /// MEDIAN = 18; @@ -1967,7 +1967,6 @@ impl AggregateFunction { AggregateFunction::ArrayAgg => "ARRAY_AGG", AggregateFunction::Correlation => "CORRELATION", AggregateFunction::ApproxPercentileCont => "APPROX_PERCENTILE_CONT", - AggregateFunction::ApproxMedian => "APPROX_MEDIAN", AggregateFunction::ApproxPercentileContWithWeight => { "APPROX_PERCENTILE_CONT_WITH_WEIGHT" } @@ -2001,7 +2000,6 @@ impl AggregateFunction { "ARRAY_AGG" => Some(Self::ArrayAgg), "CORRELATION" => Some(Self::Correlation), "APPROX_PERCENTILE_CONT" => Some(Self::ApproxPercentileCont), - "APPROX_MEDIAN" => Some(Self::ApproxMedian), "APPROX_PERCENTILE_CONT_WITH_WEIGHT" => { Some(Self::ApproxPercentileContWithWeight) } diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index a77d3619831d..5c083fa27a9b 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -164,7 +164,6 @@ impl From for AggregateFunction { protobuf::AggregateFunction::ApproxPercentileContWithWeight => { Self::ApproxPercentileContWithWeight } - protobuf::AggregateFunction::ApproxMedian => Self::ApproxMedian, protobuf::AggregateFunction::Grouping => Self::Grouping, protobuf::AggregateFunction::NthValueAgg => Self::NthValue, protobuf::AggregateFunction::StringAgg => Self::StringAgg, diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index 9c4c7685b34e..e2259896b26e 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -133,7 +133,6 @@ impl From<&AggregateFunction> for protobuf::AggregateFunction { AggregateFunction::ApproxPercentileContWithWeight => { Self::ApproxPercentileContWithWeight } - AggregateFunction::ApproxMedian => Self::ApproxMedian, AggregateFunction::Grouping => Self::Grouping, AggregateFunction::NthValue => Self::NthValueAgg, AggregateFunction::StringAgg => Self::StringAgg, @@ -430,9 +429,6 @@ pub fn serialize_expr( AggregateFunction::RegrSXX => protobuf::AggregateFunction::RegrSxx, AggregateFunction::RegrSYY => protobuf::AggregateFunction::RegrSyy, AggregateFunction::RegrSXY => protobuf::AggregateFunction::RegrSxy, - AggregateFunction::ApproxMedian => { - protobuf::AggregateFunction::ApproxMedian - } AggregateFunction::Grouping => protobuf::AggregateFunction::Grouping, AggregateFunction::NthValue => { protobuf::AggregateFunction::NthValueAgg diff --git a/datafusion/proto/src/physical_plan/to_proto.rs b/datafusion/proto/src/physical_plan/to_proto.rs index 5d07d5c0fa8a..19ba4a40d52b 100644 --- a/datafusion/proto/src/physical_plan/to_proto.rs +++ b/datafusion/proto/src/physical_plan/to_proto.rs @@ -23,12 +23,12 @@ use datafusion::datasource::file_format::parquet::ParquetSink; use datafusion::physical_expr::window::{NthValueKind, SlidingAggregateWindowExpr}; use datafusion::physical_expr::{PhysicalSortExpr, ScalarFunctionExpr}; use datafusion::physical_plan::expressions::{ - ApproxDistinct, ApproxMedian, ApproxPercentileCont, ApproxPercentileContWithWeight, - ArrayAgg, Avg, BinaryExpr, BitAnd, BitOr, BitXor, BoolAnd, BoolOr, CaseExpr, - CastExpr, Column, Correlation, Count, CumeDist, DistinctArrayAgg, DistinctBitXor, - DistinctCount, Grouping, InListExpr, IsNotNullExpr, IsNullExpr, Literal, Max, Min, - NegativeExpr, NotExpr, NthValue, NthValueAgg, Ntile, OrderSensitiveArrayAgg, Rank, - RankType, Regr, RegrType, RowNumber, StringAgg, TryCastExpr, WindowShift, + ApproxDistinct, ApproxPercentileCont, ApproxPercentileContWithWeight, ArrayAgg, Avg, + BinaryExpr, BitAnd, BitOr, BitXor, BoolAnd, BoolOr, CaseExpr, CastExpr, Column, + Correlation, Count, CumeDist, DistinctArrayAgg, DistinctBitXor, DistinctCount, + Grouping, InListExpr, IsNotNullExpr, IsNullExpr, Literal, Max, Min, NegativeExpr, + NotExpr, NthValue, NthValueAgg, Ntile, OrderSensitiveArrayAgg, Rank, RankType, Regr, + RegrType, RowNumber, StringAgg, TryCastExpr, WindowShift, }; use datafusion::physical_plan::udaf::AggregateFunctionExpr; use datafusion::physical_plan::windows::{BuiltInWindowExpr, PlainAggregateWindowExpr}; @@ -296,8 +296,6 @@ fn aggr_expr_to_aggr_fn(expr: &dyn AggregateExpr) -> Result { .is_some() { protobuf::AggregateFunction::ApproxPercentileContWithWeight - } else if aggr_expr.downcast_ref::().is_some() { - protobuf::AggregateFunction::ApproxMedian } else if aggr_expr.downcast_ref::().is_some() { protobuf::AggregateFunction::StringAgg } else if aggr_expr.downcast_ref::().is_some() { diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index b1cad69b14fb..699697dd2f2c 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -33,6 +33,7 @@ use datafusion::datasource::TableProvider; use datafusion::execution::context::SessionState; use datafusion::execution::runtime_env::{RuntimeConfig, RuntimeEnv}; use datafusion::execution::FunctionRegistry; +use datafusion::functions_aggregate::approx_median::approx_median; use datafusion::functions_aggregate::expr_fn::{ covar_pop, covar_samp, first_value, median, stddev, stddev_pop, sum, var_pop, var_sample, @@ -658,6 +659,7 @@ async fn roundtrip_expr_api() -> Result<()> { var_pop(lit(2.2)), stddev(lit(2.2)), stddev_pop(lit(2.2)), + approx_median(lit(2)), ]; // ensure expressions created with the expr api can be round tripped diff --git a/datafusion/sql/tests/sql_integration.rs b/datafusion/sql/tests/sql_integration.rs index 6a99f9719de9..7b9d39a2b51e 100644 --- a/datafusion/sql/tests/sql_integration.rs +++ b/datafusion/sql/tests/sql_integration.rs @@ -37,6 +37,7 @@ use datafusion_sql::{ planner::{ParserOptions, SqlToRel}, }; +use datafusion_functions_aggregate::approx_median::approx_median_udaf; use rstest::rstest; use sqlparser::dialect::{Dialect, GenericDialect, HiveDialect, MySqlDialect}; @@ -1649,8 +1650,8 @@ fn select_count_column() { #[test] fn select_approx_median() { let sql = "SELECT approx_median(age) FROM person"; - let expected = "Projection: APPROX_MEDIAN(person.age)\ - \n Aggregate: groupBy=[[]], aggr=[[APPROX_MEDIAN(person.age)]]\ + let expected = "Projection: approx_median(person.age)\ + \n Aggregate: groupBy=[[]], aggr=[[approx_median(person.age)]]\ \n TableScan: person"; quick_test(sql, expected); } @@ -2581,8 +2582,8 @@ fn approx_median_window() { let sql = "SELECT order_id, APPROX_MEDIAN(qty) OVER(PARTITION BY order_id) from orders"; let expected = "\ - Projection: orders.order_id, APPROX_MEDIAN(orders.qty) PARTITION BY [orders.order_id] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING\ - \n WindowAggr: windowExpr=[[APPROX_MEDIAN(orders.qty) PARTITION BY [orders.order_id] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]]\ + Projection: orders.order_id, approx_median(orders.qty) PARTITION BY [orders.order_id] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING\ + \n WindowAggr: windowExpr=[[approx_median(orders.qty) PARTITION BY [orders.order_id] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]]\ \n TableScan: orders"; quick_test(sql, expected); } @@ -2700,7 +2701,8 @@ fn logical_plan_with_dialect_and_options( DataType::Int32, )) .with_udf(make_udf("sqrt", vec![DataType::Int64], DataType::Int64)) - .with_udaf(sum_udaf()); + .with_udaf(sum_udaf()) + .with_udaf(approx_median_udaf()); let planner = SqlToRel::new_with_options(&context, options); let result = DFParser::parse_sql_with_dialect(sql, dialect); diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt index 9958f8ac38ea..a245793ebd09 100644 --- a/datafusion/sqllogictest/test_files/aggregate.slt +++ b/datafusion/sqllogictest/test_files/aggregate.slt @@ -518,6 +518,11 @@ SELECT approx_median(c12) FROM aggregate_test_100 ---- 0.555006541052 +# csv_query_approx_median_4 +# test with string, approx median only supports numeric +statement error +SELECT approx_median(c1) FROM aggregate_test_100 + # csv_query_median_1 query I SELECT median(c2) FROM aggregate_test_100 @@ -637,6 +642,11 @@ select median(c), arrow_typeof(median(c)) from t; ---- 0.0003 Decimal128(10, 4) +query RT +select approx_median(c), arrow_typeof(approx_median(c)) from t; +---- +0.00035 Float64 + statement ok drop table t;