diff --git a/datafusion/functions-aggregate/src/average.rs b/datafusion/functions-aggregate/src/average.rs index 228bce1979a38..0d0029c5a541e 100644 --- a/datafusion/functions-aggregate/src/average.rs +++ b/datafusion/functions-aggregate/src/average.rs @@ -19,8 +19,9 @@ use arrow::array::{ self, Array, ArrayRef, ArrowNativeTypeOp, ArrowNumericType, ArrowPrimitiveType, - AsArray, PrimitiveArray, PrimitiveBuilder, UInt64Array, + AsArray, BooleanArray, PrimitiveArray, PrimitiveBuilder, UInt64Array, }; + use arrow::compute::sum; use arrow::datatypes::{ i256, ArrowNativeType, DataType, Decimal128Type, Decimal256Type, DecimalType, Field, @@ -35,6 +36,9 @@ use datafusion_expr::{ Accumulator, AggregateUDFImpl, EmitTo, GroupsAccumulator, ReversedUDAF, Signature, }; use datafusion_physical_expr_common::aggregate::groups_accumulator::accumulate::NullState; +use datafusion_physical_expr_common::aggregate::groups_accumulator::nulls::{ + filtered_null_mask, set_nulls, +}; use datafusion_physical_expr_common::aggregate::utils::DecimalAverager; use log::debug; use std::any::Any; @@ -554,6 +558,30 @@ where Ok(()) } + fn convert_to_state( + &self, + values: &[ArrayRef], + opt_filter: Option<&BooleanArray>, + ) -> Result> { + let sums = values[0] + .as_primitive::() + .clone() + .with_data_type(self.sum_data_type.clone()); + let counts = UInt64Array::from_value(1, sums.len()); + + let nulls = filtered_null_mask(opt_filter, &sums); + + // set nulls on the arrays + let counts = set_nulls(counts, nulls.clone()); + let sums = set_nulls(sums, nulls); + + Ok(vec![Arc::new(counts) as ArrayRef, Arc::new(sums)]) + } + + fn supports_convert_to_state(&self) -> bool { + true + } + fn size(&self) -> usize { self.counts.capacity() * std::mem::size_of::() + self.sums.capacity() * std::mem::size_of::() diff --git a/datafusion/physical-expr-common/src/aggregate/groups_accumulator/mod.rs b/datafusion/physical-expr-common/src/aggregate/groups_accumulator/mod.rs index 5b0182c5db8a7..220f142927b32 100644 --- a/datafusion/physical-expr-common/src/aggregate/groups_accumulator/mod.rs +++ b/datafusion/physical-expr-common/src/aggregate/groups_accumulator/mod.rs @@ -19,4 +19,5 @@ pub mod accumulate; pub mod bool_op; +pub mod nulls; pub mod prim_op; diff --git a/datafusion/physical-expr-common/src/aggregate/groups_accumulator/nulls.rs b/datafusion/physical-expr-common/src/aggregate/groups_accumulator/nulls.rs new file mode 100644 index 0000000000000..cf24c575ac695 --- /dev/null +++ b/datafusion/physical-expr-common/src/aggregate/groups_accumulator/nulls.rs @@ -0,0 +1,87 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! XX utlities for working with nulls + +use arrow::array::{Array, ArrowNumericType, BooleanArray, PrimitiveArray}; +use arrow::buffer::NullBuffer; + +/// Sets the validity mask for a `PrimitiveArray` to `nulls` +/// replacing any existing null mask +pub fn set_nulls( + array: PrimitiveArray, + nulls: Option, +) -> PrimitiveArray { + let (dt, values, _old_nulls) = array.into_parts(); + PrimitiveArray::::new(values, nulls).with_data_type(dt) +} + +/// Converts a `BooleanBuffer` representing a filter to a `NullBuffer. +/// +/// The `NullBuffer` is +/// * `true` (representing valid) for values that were `true` in filter +/// * `false` (representing null) for values that were `false` or `null` in filter +fn filter_to_nulls(filter: &BooleanArray) -> Option { + let (filter_bools, filter_nulls) = filter.clone().into_parts(); + let filter_bools = NullBuffer::from(filter_bools); + NullBuffer::union(Some(&filter_bools), filter_nulls.as_ref()) +} + +/// Compute an output validity mask for an array that has been filtered +/// +/// This can be used to compute nulls for the output of +/// [`GroupsAccumulator::emit_to_state`], which quickly applies an optional +/// filter to the input rows by setting any filtered rows to NULL in the output. +/// Subsequent applications of aggregate functions that ignore NULLs (most of +/// them) will thus ignore the filtered rows as well. +/// +/// # Output element is `true` +/// * A `true` in the output represents non null output for all values that were both: +/// +/// * `true` in any `opt_filter` (aka values that passed the filter) +/// +/// * `non null` in `input` +/// +/// # Output element is `false` +/// * is false (null) for all values that were false in the filter or null in the input +/// +/// # Example +/// +/// ```text +/// ┌─────┐ ┌─────┐ ┌─────┐ +/// │true │ │NULL │ │NULL │ +/// │true │ │ │true │ │true │ +/// │true │ ───┼─── │false│ ────────▶ │false│ filtered_nulls +/// │false│ │ │NULL │ │NULL │ +/// │false│ │true │ │true │ +/// └─────┘ └─────┘ └─────┘ +/// array opt_filter output nulls +/// .nulls() +/// +/// false = NULL true = pass false = NULL Meanings +/// true = valid false = filter true = valid +/// NULL = filter +/// ``` +/// +/// [`GroupsAccumulator::emit_to_state`]: datafusion_expr::groups_accumulator::GroupsAccumulator::emit_to_state +pub fn filtered_null_mask( + opt_filter: Option<&BooleanArray>, + input: &dyn Array, +) -> Option { + let opt_filter = opt_filter.and_then(filter_to_nulls); + NullBuffer::union(opt_filter.as_ref(), input.nulls()) +} diff --git a/datafusion/sqllogictest/test_files/aggregate_skip_partial.slt b/datafusion/sqllogictest/test_files/aggregate_skip_partial.slt index 65efc24ec037c..56ae7edd3cdc1 100644 --- a/datafusion/sqllogictest/test_files/aggregate_skip_partial.slt +++ b/datafusion/sqllogictest/test_files/aggregate_skip_partial.slt @@ -209,6 +209,20 @@ SELECT c2, sum(c3), sum(c11) FROM aggregate_test_100 GROUP BY c2 ORDER BY c2; 4 29 9.531112968922 5 -194 7.074412226677 +# Test avg for bigint / float +query IRR +SELECT + c2, + avg(c10), + avg(c11) +FROM aggregate_test_100 GROUP BY c2 ORDER BY c2; +---- +1 9803675241365398000 0.552420626987 +2 6843194947657418000 0.435355657881 +3 10700987547561746000 0.504783755855 +4 7199224282513318000 0.41439621604 +5 9295051061697067000 0.505315159048 + # Enabling PG dialect for filtered aggregates tests statement ok set datafusion.sql_parser.dialect = 'Postgres'; @@ -267,6 +281,20 @@ FROM aggregate_test_100_null GROUP BY c2 ORDER BY c2; 4 11 14 5 8 7 +# Test avg for bigint / float with filter +query IRR +SELECT + c2, + avg(c10) FILTER (WHERE c2 != 'e'), + avg(c11) FILTER (WHERE c2 != 'e') +FROM aggregate_test_100 GROUP BY c2 ORDER BY c2; +---- +1 9803675241365398000 0.552420626987 +2 6843194947657418000 0.435355657881 +3 10700987547561746000 0.504783755855 +4 7199224282513318000 0.41439621604 +5 9295051061697067000 0.505315159048 + # Test count with nullable fields and nullable filter query III SELECT c2,