From 02427674eef658b9b0acb7142f18e8c1520bdb17 Mon Sep 17 00:00:00 2001 From: Edmondo Porcu Date: Mon, 30 Sep 2024 15:09:40 -0400 Subject: [PATCH] Add `value_from_statisics` to AggregateUDFImpl, remove special case for min/max/count aggregate statistics (#12296) * Removes min/max/count comparison based on name in aggregate statistics * Abstracting away value from statistics * Removing imports * Introduced StatisticsArgs * Fixed docs --- datafusion/expr/src/lib.rs | 2 +- datafusion/expr/src/udaf.rs | 27 ++- datafusion/functions-aggregate/src/count.rs | 35 +++- datafusion/functions-aggregate/src/min_max.rs | 77 +++++++- .../src/aggregate_statistics.rs | 182 ++---------------- datafusion/physical-plan/src/lib.rs | 1 + 6 files changed, 154 insertions(+), 170 deletions(-) diff --git a/datafusion/expr/src/lib.rs b/datafusion/expr/src/lib.rs index 32eac90c3eec..7d94a3b93eab 100644 --- a/datafusion/expr/src/lib.rs +++ b/datafusion/expr/src/lib.rs @@ -90,7 +90,7 @@ pub use logical_plan::*; pub use partition_evaluator::PartitionEvaluator; pub use sqlparser; pub use table_source::{TableProviderFilterPushDown, TableSource, TableType}; -pub use udaf::{AggregateUDF, AggregateUDFImpl, ReversedUDAF}; +pub use udaf::{AggregateUDF, AggregateUDFImpl, ReversedUDAF, StatisticsArgs}; pub use udf::{ScalarUDF, ScalarUDFImpl}; pub use udwf::{ReversedUDWF, WindowUDF, WindowUDFImpl}; pub use window_frame::{WindowFrame, WindowFrameBound, WindowFrameUnits}; diff --git a/datafusion/expr/src/udaf.rs b/datafusion/expr/src/udaf.rs index e3ef672daf5f..d8592bce60cd 100644 --- a/datafusion/expr/src/udaf.rs +++ b/datafusion/expr/src/udaf.rs @@ -26,7 +26,8 @@ use std::vec; use arrow::datatypes::{DataType, Field}; -use datafusion_common::{exec_err, not_impl_err, Result, ScalarValue}; +use datafusion_common::{exec_err, not_impl_err, Result, ScalarValue, Statistics}; +use datafusion_physical_expr_common::physical_expr::PhysicalExpr; use crate::expr::AggregateFunction; use crate::function::{ @@ -94,6 +95,19 @@ impl fmt::Display for AggregateUDF { } } +pub struct StatisticsArgs<'a> { + pub statistics: &'a Statistics, + pub return_type: &'a DataType, + /// Whether the aggregate function is distinct. + /// + /// ```sql + /// SELECT COUNT(DISTINCT column1) FROM t; + /// ``` + pub is_distinct: bool, + /// The physical expression of arguments the aggregate function takes. + pub exprs: &'a [Arc], +} + impl AggregateUDF { /// Create a new `AggregateUDF` from a `[AggregateUDFImpl]` trait object /// @@ -244,6 +258,13 @@ impl AggregateUDF { self.inner.is_descending() } + pub fn value_from_stats( + &self, + statistics_args: &StatisticsArgs, + ) -> Option { + self.inner.value_from_stats(statistics_args) + } + /// See [`AggregateUDFImpl::default_value`] for more details. pub fn default_value(&self, data_type: &DataType) -> Result { self.inner.default_value(data_type) @@ -556,6 +577,10 @@ pub trait AggregateUDFImpl: Debug + Send + Sync { fn is_descending(&self) -> Option { None } + // Return the value of the current UDF from the statistics + fn value_from_stats(&self, _statistics_args: &StatisticsArgs) -> Option { + None + } /// Returns default value of the function given the input is all `null`. /// diff --git a/datafusion/functions-aggregate/src/count.rs b/datafusion/functions-aggregate/src/count.rs index 417e28e72a71..cc245b3572ec 100644 --- a/datafusion/functions-aggregate/src/count.rs +++ b/datafusion/functions-aggregate/src/count.rs @@ -16,7 +16,9 @@ // under the License. use ahash::RandomState; +use datafusion_common::stats::Precision; use datafusion_functions_aggregate_common::aggregate::count_distinct::BytesViewDistinctCountAccumulator; +use datafusion_physical_expr::expressions; use std::collections::HashSet; use std::ops::BitAnd; use std::{fmt::Debug, sync::Arc}; @@ -46,7 +48,7 @@ use datafusion_expr::{ function::AccumulatorArgs, utils::format_state_name, Accumulator, AggregateUDFImpl, EmitTo, GroupsAccumulator, Signature, Volatility, }; -use datafusion_expr::{Expr, ReversedUDAF, TypeSignature}; +use datafusion_expr::{Expr, ReversedUDAF, StatisticsArgs, TypeSignature}; use datafusion_functions_aggregate_common::aggregate::count_distinct::{ BytesDistinctCountAccumulator, FloatDistinctCountAccumulator, PrimitiveDistinctCountAccumulator, @@ -54,6 +56,7 @@ use datafusion_functions_aggregate_common::aggregate::count_distinct::{ use datafusion_functions_aggregate_common::aggregate::groups_accumulator::accumulate::accumulate_indices; use datafusion_physical_expr_common::binary_map::OutputType; +use datafusion_common::utils::expr::COUNT_STAR_EXPANSION; make_udaf_expr_and_func!( Count, count, @@ -291,6 +294,36 @@ impl AggregateUDFImpl for Count { fn default_value(&self, _data_type: &DataType) -> Result { Ok(ScalarValue::Int64(Some(0))) } + + fn value_from_stats(&self, statistics_args: &StatisticsArgs) -> Option { + if statistics_args.is_distinct { + return None; + } + if let Precision::Exact(num_rows) = statistics_args.statistics.num_rows { + if statistics_args.exprs.len() == 1 { + // TODO optimize with exprs other than Column + if let Some(col_expr) = statistics_args.exprs[0] + .as_any() + .downcast_ref::() + { + let current_val = &statistics_args.statistics.column_statistics + [col_expr.index()] + .null_count; + if let &Precision::Exact(val) = current_val { + return Some(ScalarValue::Int64(Some((num_rows - val) as i64))); + } + } else if let Some(lit_expr) = statistics_args.exprs[0] + .as_any() + .downcast_ref::() + { + if lit_expr.value() == &COUNT_STAR_EXPANSION { + return Some(ScalarValue::Int64(Some(num_rows as i64))); + } + } + } + } + None + } } #[derive(Debug)] diff --git a/datafusion/functions-aggregate/src/min_max.rs b/datafusion/functions-aggregate/src/min_max.rs index 961e8639604c..1ce1abe09ea8 100644 --- a/datafusion/functions-aggregate/src/min_max.rs +++ b/datafusion/functions-aggregate/src/min_max.rs @@ -15,7 +15,7 @@ // under the License. //! [`Max`] and [`MaxAccumulator`] accumulator for the `max` function -//! [`Min`] and [`MinAccumulator`] accumulator for the `max` function +//! [`Min`] and [`MinAccumulator`] accumulator for the `min` function // distributed with this work for additional information // regarding copyright ownership. The ASF licenses this file @@ -49,10 +49,12 @@ use arrow::datatypes::{ UInt8Type, }; use arrow_schema::IntervalUnit; +use datafusion_common::stats::Precision; use datafusion_common::{ - downcast_value, exec_err, internal_err, DataFusionError, Result, + downcast_value, exec_err, internal_err, ColumnStatistics, DataFusionError, Result, }; use datafusion_functions_aggregate_common::aggregate::groups_accumulator::prim_op::PrimitiveGroupsAccumulator; +use datafusion_physical_expr::expressions; use std::fmt::Debug; use arrow::datatypes::i256; @@ -63,10 +65,10 @@ use arrow::datatypes::{ }; use datafusion_common::ScalarValue; -use datafusion_expr::GroupsAccumulator; use datafusion_expr::{ function::AccumulatorArgs, Accumulator, AggregateUDFImpl, Signature, Volatility, }; +use datafusion_expr::{GroupsAccumulator, StatisticsArgs}; use half::f16; use std::ops::Deref; @@ -147,6 +149,54 @@ macro_rules! instantiate_min_accumulator { }}; } +trait FromColumnStatistics { + fn value_from_column_statistics( + &self, + stats: &ColumnStatistics, + ) -> Option; + + fn value_from_statistics( + &self, + statistics_args: &StatisticsArgs, + ) -> Option { + if let Precision::Exact(num_rows) = &statistics_args.statistics.num_rows { + match *num_rows { + 0 => return ScalarValue::try_from(statistics_args.return_type).ok(), + value if value > 0 => { + let col_stats = &statistics_args.statistics.column_statistics; + if statistics_args.exprs.len() == 1 { + // TODO optimize with exprs other than Column + if let Some(col_expr) = statistics_args.exprs[0] + .as_any() + .downcast_ref::() + { + return self.value_from_column_statistics( + &col_stats[col_expr.index()], + ); + } + } + } + _ => {} + } + } + None + } +} + +impl FromColumnStatistics for Max { + fn value_from_column_statistics( + &self, + col_stats: &ColumnStatistics, + ) -> Option { + if let Precision::Exact(ref val) = col_stats.max_value { + if !val.is_null() { + return Some(val.clone()); + } + } + None + } +} + impl AggregateUDFImpl for Max { fn as_any(&self) -> &dyn std::any::Any { self @@ -272,6 +322,7 @@ impl AggregateUDFImpl for Max { fn is_descending(&self) -> Option { Some(true) } + fn order_sensitivity(&self) -> datafusion_expr::utils::AggregateOrderSensitivity { datafusion_expr::utils::AggregateOrderSensitivity::Insensitive } @@ -282,6 +333,9 @@ impl AggregateUDFImpl for Max { fn reverse_expr(&self) -> datafusion_expr::ReversedUDAF { datafusion_expr::ReversedUDAF::Identical } + fn value_from_stats(&self, statistics_args: &StatisticsArgs) -> Option { + self.value_from_statistics(statistics_args) + } } // Statically-typed version of min/max(array) -> ScalarValue for string types @@ -926,6 +980,20 @@ impl Default for Min { } } +impl FromColumnStatistics for Min { + fn value_from_column_statistics( + &self, + col_stats: &ColumnStatistics, + ) -> Option { + if let Precision::Exact(ref val) = col_stats.min_value { + if !val.is_null() { + return Some(val.clone()); + } + } + None + } +} + impl AggregateUDFImpl for Min { fn as_any(&self) -> &dyn std::any::Any { self @@ -1052,6 +1120,9 @@ impl AggregateUDFImpl for Min { Some(false) } + fn value_from_stats(&self, statistics_args: &StatisticsArgs) -> Option { + self.value_from_statistics(statistics_args) + } fn order_sensitivity(&self) -> datafusion_expr::utils::AggregateOrderSensitivity { datafusion_expr::utils::AggregateOrderSensitivity::Insensitive } diff --git a/datafusion/physical-optimizer/src/aggregate_statistics.rs b/datafusion/physical-optimizer/src/aggregate_statistics.rs index 71f129be984d..a11b498b955c 100644 --- a/datafusion/physical-optimizer/src/aggregate_statistics.rs +++ b/datafusion/physical-optimizer/src/aggregate_statistics.rs @@ -23,14 +23,12 @@ use datafusion_common::scalar::ScalarValue; use datafusion_common::Result; use datafusion_physical_plan::aggregates::AggregateExec; use datafusion_physical_plan::projection::ProjectionExec; -use datafusion_physical_plan::{expressions, ExecutionPlan, Statistics}; +use datafusion_physical_plan::{expressions, ExecutionPlan}; use crate::PhysicalOptimizerRule; -use datafusion_common::stats::Precision; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; -use datafusion_common::utils::expr::COUNT_STAR_EXPANSION; use datafusion_physical_plan::placeholder_row::PlaceholderRowExec; -use datafusion_physical_plan::udaf::AggregateFunctionExpr; +use datafusion_physical_plan::udaf::{AggregateFunctionExpr, StatisticsArgs}; /// Optimizer that uses available statistics for aggregate functions #[derive(Default, Debug)] @@ -57,14 +55,19 @@ impl PhysicalOptimizerRule for AggregateStatistics { let stats = partial_agg_exec.input().statistics()?; let mut projections = vec![]; for expr in partial_agg_exec.aggr_expr() { - if let Some((non_null_rows, name)) = - take_optimizable_column_and_table_count(expr, &stats) + let field = expr.field(); + let args = expr.expressions(); + let statistics_args = StatisticsArgs { + statistics: &stats, + return_type: field.data_type(), + is_distinct: expr.is_distinct(), + exprs: args.as_slice(), + }; + if let Some((optimizable_statistic, name)) = + take_optimizable_value_from_statistics(&statistics_args, expr) { - projections.push((expressions::lit(non_null_rows), name.to_owned())); - } else if let Some((min, name)) = take_optimizable_min(expr, &stats) { - projections.push((expressions::lit(min), name.to_owned())); - } else if let Some((max, name)) = take_optimizable_max(expr, &stats) { - projections.push((expressions::lit(max), name.to_owned())); + projections + .push((expressions::lit(optimizable_statistic), name.to_owned())); } else { // TODO: we need all aggr_expr to be resolved (cf TODO fullres) break; @@ -135,160 +138,11 @@ fn take_optimizable(node: &dyn ExecutionPlan) -> Option> None } -/// If this agg_expr is a count that can be exactly derived from the statistics, return it. -fn take_optimizable_column_and_table_count( - agg_expr: &AggregateFunctionExpr, - stats: &Statistics, -) -> Option<(ScalarValue, String)> { - let col_stats = &stats.column_statistics; - if is_non_distinct_count(agg_expr) { - if let Precision::Exact(num_rows) = stats.num_rows { - let exprs = agg_expr.expressions(); - if exprs.len() == 1 { - // TODO optimize with exprs other than Column - if let Some(col_expr) = - exprs[0].as_any().downcast_ref::() - { - let current_val = &col_stats[col_expr.index()].null_count; - if let &Precision::Exact(val) = current_val { - return Some(( - ScalarValue::Int64(Some((num_rows - val) as i64)), - agg_expr.name().to_string(), - )); - } - } else if let Some(lit_expr) = - exprs[0].as_any().downcast_ref::() - { - if lit_expr.value() == &COUNT_STAR_EXPANSION { - return Some(( - ScalarValue::Int64(Some(num_rows as i64)), - agg_expr.name().to_string(), - )); - } - } - } - } - } - None -} - -/// If this agg_expr is a min that is exactly defined in the statistics, return it. -fn take_optimizable_min( - agg_expr: &AggregateFunctionExpr, - stats: &Statistics, -) -> Option<(ScalarValue, String)> { - if let Precision::Exact(num_rows) = &stats.num_rows { - match *num_rows { - 0 => { - // MIN/MAX with 0 rows is always null - if is_min(agg_expr) { - if let Ok(min_data_type) = - ScalarValue::try_from(agg_expr.field().data_type()) - { - return Some((min_data_type, agg_expr.name().to_string())); - } - } - } - value if value > 0 => { - let col_stats = &stats.column_statistics; - if is_min(agg_expr) { - let exprs = agg_expr.expressions(); - if exprs.len() == 1 { - // TODO optimize with exprs other than Column - if let Some(col_expr) = - exprs[0].as_any().downcast_ref::() - { - if let Precision::Exact(val) = - &col_stats[col_expr.index()].min_value - { - if !val.is_null() { - return Some(( - val.clone(), - agg_expr.name().to_string(), - )); - } - } - } - } - } - } - _ => {} - } - } - None -} - /// If this agg_expr is a max that is exactly defined in the statistics, return it. -fn take_optimizable_max( +fn take_optimizable_value_from_statistics( + statistics_args: &StatisticsArgs, agg_expr: &AggregateFunctionExpr, - stats: &Statistics, ) -> Option<(ScalarValue, String)> { - if let Precision::Exact(num_rows) = &stats.num_rows { - match *num_rows { - 0 => { - // MIN/MAX with 0 rows is always null - if is_max(agg_expr) { - if let Ok(max_data_type) = - ScalarValue::try_from(agg_expr.field().data_type()) - { - return Some((max_data_type, agg_expr.name().to_string())); - } - } - } - value if value > 0 => { - let col_stats = &stats.column_statistics; - if is_max(agg_expr) { - let exprs = agg_expr.expressions(); - if exprs.len() == 1 { - // TODO optimize with exprs other than Column - if let Some(col_expr) = - exprs[0].as_any().downcast_ref::() - { - if let Precision::Exact(val) = - &col_stats[col_expr.index()].max_value - { - if !val.is_null() { - return Some(( - val.clone(), - agg_expr.name().to_string(), - )); - } - } - } - } - } - } - _ => {} - } - } - None -} - -// TODO: Move this check into AggregateUDFImpl -// https://github.com/apache/datafusion/issues/11153 -fn is_non_distinct_count(agg_expr: &AggregateFunctionExpr) -> bool { - if agg_expr.fun().name() == "count" && !agg_expr.is_distinct() { - return true; - } - false + let value = agg_expr.fun().value_from_stats(statistics_args); + value.map(|val| (val, agg_expr.name().to_string())) } - -// TODO: Move this check into AggregateUDFImpl -// https://github.com/apache/datafusion/issues/11153 -fn is_min(agg_expr: &AggregateFunctionExpr) -> bool { - if agg_expr.fun().name().to_lowercase() == "min" { - return true; - } - false -} - -// TODO: Move this check into AggregateUDFImpl -// https://github.com/apache/datafusion/issues/11153 -fn is_max(agg_expr: &AggregateFunctionExpr) -> bool { - if agg_expr.fun().name().to_lowercase() == "max" { - return true; - } - false -} - -// See tests in datafusion/core/tests/physical_optimizer diff --git a/datafusion/physical-plan/src/lib.rs b/datafusion/physical-plan/src/lib.rs index 7cbfd49afb86..845a74eaea48 100644 --- a/datafusion/physical-plan/src/lib.rs +++ b/datafusion/physical-plan/src/lib.rs @@ -82,6 +82,7 @@ pub mod windows; pub mod work_table; pub mod udaf { + pub use datafusion_expr::StatisticsArgs; pub use datafusion_physical_expr::aggregate::AggregateFunctionExpr; }