Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add value_from_statisics to AggregateUDFImpl, remove special case for min/max/count aggregate statistics #12296

Merged
merged 5 commits into from
Sep 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion datafusion/expr/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ pub use logical_plan::*;
pub use partition_evaluator::PartitionEvaluator;
pub use sqlparser;
pub use table_source::{TableProviderFilterPushDown, TableSource, TableType};
pub use udaf::{AggregateUDF, AggregateUDFImpl, ReversedUDAF};
pub use udaf::{AggregateUDF, AggregateUDFImpl, ReversedUDAF, StatisticsArgs};
pub use udf::{ScalarUDF, ScalarUDFImpl};
pub use udwf::{WindowUDF, WindowUDFImpl};
pub use window_frame::{WindowFrame, WindowFrameBound, WindowFrameUnits};
Expand Down
27 changes: 26 additions & 1 deletion datafusion/expr/src/udaf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@ use std::vec;

use arrow::datatypes::{DataType, Field};

use datafusion_common::{exec_err, not_impl_err, Result, ScalarValue};
use datafusion_common::{exec_err, not_impl_err, Result, ScalarValue, Statistics};
use datafusion_physical_expr_common::physical_expr::PhysicalExpr;

use crate::expr::AggregateFunction;
use crate::function::{
Expand Down Expand Up @@ -93,6 +94,19 @@ impl fmt::Display for AggregateUDF {
}
}

pub struct StatisticsArgs<'a> {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

❤️

I think it would also be great to add some documentation. Perhaps like this:

Suggested change
pub struct StatisticsArgs<'a> {
/// Arguments passed to [`AggregateUDFImpl::value_from_stats`]
pub struct StatisticsArgs<'a> {

pub statistics: &'a Statistics,
pub return_type: &'a DataType,
/// Whether the aggregate function is distinct.
///
/// ```sql
/// SELECT COUNT(DISTINCT column1) FROM t;
/// ```
pub is_distinct: bool,
/// The physical expression of arguments the aggregate function takes.
pub exprs: &'a [Arc<dyn PhysicalExpr>],
}

impl AggregateUDF {
/// Create a new AggregateUDF
///
Expand Down Expand Up @@ -262,6 +276,13 @@ impl AggregateUDF {
self.inner.is_descending()
}

pub fn value_from_stats(
&self,
statistics_args: &StatisticsArgs,
) -> Option<ScalarValue> {
self.inner.value_from_stats(statistics_args)
}

/// See [`AggregateUDFImpl::default_value`] for more details.
pub fn default_value(&self, data_type: &DataType) -> Result<ScalarValue> {
self.inner.default_value(data_type)
Expand Down Expand Up @@ -574,6 +595,10 @@ pub trait AggregateUDFImpl: Debug + Send + Sync {
fn is_descending(&self) -> Option<bool> {
None
}
// Return the value of the current UDF from the statistics
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
// Return the value of the current UDF from the statistics
/// Return the value of this UDF for the query if it can be determined entirely from
/// statistics and arguments.
///
/// For example, if the minimum value of column `x` is known exactly in the statistics,
/// then `MIN(x)` can be replaced by that value, significantly improving query performance.

fn value_from_stats(&self, _statistics_args: &StatisticsArgs) -> Option<ScalarValue> {
None
}

/// Returns default value of the function given the input is all `null`.
///
Expand Down
35 changes: 34 additions & 1 deletion datafusion/functions-aggregate/src/count.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@
// under the License.

use ahash::RandomState;
use datafusion_common::stats::Precision;
use datafusion_functions_aggregate_common::aggregate::count_distinct::BytesViewDistinctCountAccumulator;
use datafusion_physical_expr::expressions;
use std::collections::HashSet;
use std::ops::BitAnd;
use std::{fmt::Debug, sync::Arc};
Expand Down Expand Up @@ -46,14 +48,15 @@ use datafusion_expr::{
function::AccumulatorArgs, utils::format_state_name, Accumulator, AggregateUDFImpl,
EmitTo, GroupsAccumulator, Signature, Volatility,
};
use datafusion_expr::{Expr, ReversedUDAF, TypeSignature};
use datafusion_expr::{Expr, ReversedUDAF, StatisticsArgs, TypeSignature};
use datafusion_functions_aggregate_common::aggregate::count_distinct::{
BytesDistinctCountAccumulator, FloatDistinctCountAccumulator,
PrimitiveDistinctCountAccumulator,
};
use datafusion_functions_aggregate_common::aggregate::groups_accumulator::accumulate::accumulate_indices;
use datafusion_physical_expr_common::binary_map::OutputType;

use datafusion_common::utils::expr::COUNT_STAR_EXPANSION;
make_udaf_expr_and_func!(
Count,
count,
Expand Down Expand Up @@ -291,6 +294,36 @@ impl AggregateUDFImpl for Count {
fn default_value(&self, _data_type: &DataType) -> Result<ScalarValue> {
Ok(ScalarValue::Int64(Some(0)))
}

fn value_from_stats(&self, statistics_args: &StatisticsArgs) -> Option<ScalarValue> {
if statistics_args.is_distinct {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍

return None;
}
if let Precision::Exact(num_rows) = statistics_args.statistics.num_rows {
if statistics_args.exprs.len() == 1 {
// TODO optimize with exprs other than Column
if let Some(col_expr) = statistics_args.exprs[0]
.as_any()
.downcast_ref::<expressions::Column>()
{
let current_val = &statistics_args.statistics.column_statistics
[col_expr.index()]
.null_count;
if let &Precision::Exact(val) = current_val {
return Some(ScalarValue::Int64(Some((num_rows - val) as i64)));
}
} else if let Some(lit_expr) = statistics_args.exprs[0]
.as_any()
.downcast_ref::<expressions::Literal>()
{
if lit_expr.value() == &COUNT_STAR_EXPANSION {
return Some(ScalarValue::Int64(Some(num_rows as i64)));
}
}
}
}
None
}
}

#[derive(Debug)]
Expand Down
77 changes: 74 additions & 3 deletions datafusion/functions-aggregate/src/min_max.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
// under the License.

//! [`Max`] and [`MaxAccumulator`] accumulator for the `max` function
//! [`Min`] and [`MinAccumulator`] accumulator for the `max` function
//! [`Min`] and [`MinAccumulator`] accumulator for the `min` function

// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
Expand Down Expand Up @@ -49,10 +49,12 @@ use arrow::datatypes::{
UInt8Type,
};
use arrow_schema::IntervalUnit;
use datafusion_common::stats::Precision;
use datafusion_common::{
downcast_value, exec_err, internal_err, DataFusionError, Result,
downcast_value, exec_err, internal_err, ColumnStatistics, DataFusionError, Result,
};
use datafusion_functions_aggregate_common::aggregate::groups_accumulator::prim_op::PrimitiveGroupsAccumulator;
use datafusion_physical_expr::expressions;
use std::fmt::Debug;

use arrow::datatypes::i256;
Expand All @@ -63,10 +65,10 @@ use arrow::datatypes::{
};

use datafusion_common::ScalarValue;
use datafusion_expr::GroupsAccumulator;
use datafusion_expr::{
function::AccumulatorArgs, Accumulator, AggregateUDFImpl, Signature, Volatility,
};
use datafusion_expr::{GroupsAccumulator, StatisticsArgs};
use half::f16;
use std::ops::Deref;

Expand Down Expand Up @@ -147,6 +149,54 @@ macro_rules! instantiate_min_accumulator {
}};
}

trait FromColumnStatistics {
fn value_from_column_statistics(
&self,
stats: &ColumnStatistics,
) -> Option<ScalarValue>;

fn value_from_statistics(
&self,
statistics_args: &StatisticsArgs,
) -> Option<ScalarValue> {
if let Precision::Exact(num_rows) = &statistics_args.statistics.num_rows {
match *num_rows {
0 => return ScalarValue::try_from(statistics_args.return_type).ok(),
value if value > 0 => {
let col_stats = &statistics_args.statistics.column_statistics;
if statistics_args.exprs.len() == 1 {
// TODO optimize with exprs other than Column
if let Some(col_expr) = statistics_args.exprs[0]
.as_any()
.downcast_ref::<expressions::Column>()
{
return self.value_from_column_statistics(
&col_stats[col_expr.index()],
);
}
}
}
_ => {}
}
}
None
}
}

impl FromColumnStatistics for Max {
fn value_from_column_statistics(
&self,
col_stats: &ColumnStatistics,
) -> Option<ScalarValue> {
if let Precision::Exact(ref val) = col_stats.max_value {
if !val.is_null() {
return Some(val.clone());
}
}
None
}
}

impl AggregateUDFImpl for Max {
fn as_any(&self) -> &dyn std::any::Any {
self
Expand Down Expand Up @@ -272,6 +322,7 @@ impl AggregateUDFImpl for Max {
fn is_descending(&self) -> Option<bool> {
Some(true)
}

fn order_sensitivity(&self) -> datafusion_expr::utils::AggregateOrderSensitivity {
datafusion_expr::utils::AggregateOrderSensitivity::Insensitive
}
Expand All @@ -282,6 +333,9 @@ impl AggregateUDFImpl for Max {
fn reverse_expr(&self) -> datafusion_expr::ReversedUDAF {
datafusion_expr::ReversedUDAF::Identical
}
fn value_from_stats(&self, statistics_args: &StatisticsArgs) -> Option<ScalarValue> {
self.value_from_statistics(statistics_args)
}
}

// Statically-typed version of min/max(array) -> ScalarValue for string types
Expand Down Expand Up @@ -926,6 +980,20 @@ impl Default for Min {
}
}

impl FromColumnStatistics for Min {
fn value_from_column_statistics(
&self,
col_stats: &ColumnStatistics,
) -> Option<ScalarValue> {
if let Precision::Exact(ref val) = col_stats.min_value {
if !val.is_null() {
return Some(val.clone());
}
}
None
}
}

impl AggregateUDFImpl for Min {
fn as_any(&self) -> &dyn std::any::Any {
self
Expand Down Expand Up @@ -1052,6 +1120,9 @@ impl AggregateUDFImpl for Min {
Some(false)
}

fn value_from_stats(&self, statistics_args: &StatisticsArgs) -> Option<ScalarValue> {
self.value_from_statistics(statistics_args)
}
fn order_sensitivity(&self) -> datafusion_expr::utils::AggregateOrderSensitivity {
datafusion_expr::utils::AggregateOrderSensitivity::Insensitive
}
Expand Down
Loading