Skip to content

Commit

Permalink
Convert VariancePopulation to UDAF (#10836)
Browse files Browse the repository at this point in the history
  • Loading branch information
mknaw authored Jun 10, 2024
1 parent 9503456 commit e8fdc09
Show file tree
Hide file tree
Showing 15 changed files with 105 additions and 265 deletions.
11 changes: 1 addition & 10 deletions datafusion/expr/src/aggregate_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,6 @@ pub enum AggregateFunction {
ArrayAgg,
/// N'th value in a group according to some ordering
NthValue,
/// Variance (Population)
VariancePop,
/// Correlation
Correlation,
/// Slope from linear regression
Expand Down Expand Up @@ -102,7 +100,6 @@ impl AggregateFunction {
ApproxDistinct => "APPROX_DISTINCT",
ArrayAgg => "ARRAY_AGG",
NthValue => "NTH_VALUE",
VariancePop => "VAR_POP",
Correlation => "CORR",
RegrSlope => "REGR_SLOPE",
RegrIntercept => "REGR_INTERCEPT",
Expand Down Expand Up @@ -153,7 +150,6 @@ impl FromStr for AggregateFunction {
"string_agg" => AggregateFunction::StringAgg,
// statistical
"corr" => AggregateFunction::Correlation,
"var_pop" => AggregateFunction::VariancePop,
"regr_slope" => AggregateFunction::RegrSlope,
"regr_intercept" => AggregateFunction::RegrIntercept,
"regr_count" => AggregateFunction::RegrCount,
Expand Down Expand Up @@ -216,9 +212,6 @@ impl AggregateFunction {
AggregateFunction::BoolAnd | AggregateFunction::BoolOr => {
Ok(DataType::Boolean)
}
AggregateFunction::VariancePop => {
variance_return_type(&coerced_data_types[0])
}
AggregateFunction::Correlation => {
correlation_return_type(&coerced_data_types[0])
}
Expand Down Expand Up @@ -291,9 +284,7 @@ impl AggregateFunction {
AggregateFunction::BoolAnd | AggregateFunction::BoolOr => {
Signature::uniform(1, vec![DataType::Boolean], Volatility::Immutable)
}
AggregateFunction::Avg
| AggregateFunction::VariancePop
| AggregateFunction::ApproxMedian => {
AggregateFunction::Avg | AggregateFunction::ApproxMedian => {
Signature::uniform(1, NUMERICS.to_vec(), Volatility::Immutable)
}
AggregateFunction::NthValue => Signature::any(2, Volatility::Immutable),
Expand Down
10 changes: 0 additions & 10 deletions datafusion/expr/src/type_coercion/aggregates.rs
Original file line number Diff line number Diff line change
Expand Up @@ -151,16 +151,6 @@ pub fn coerce_types(
}
Ok(input_types.to_vec())
}
AggregateFunction::VariancePop => {
if !is_variance_support_arg_type(&input_types[0]) {
return plan_err!(
"The function {:?} does not support inputs of type {:?}.",
agg_fun,
input_types[0]
);
}
Ok(vec![Float64, Float64])
}
AggregateFunction::Correlation => {
if !is_correlation_support_arg_type(&input_types[0]) {
return plan_err!(
Expand Down
2 changes: 2 additions & 0 deletions datafusion/functions-aggregate/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ pub mod expr_fn {
pub use super::stddev::stddev;
pub use super::stddev::stddev_pop;
pub use super::sum::sum;
pub use super::variance::var_pop;
pub use super::variance::var_sample;
}

Expand All @@ -91,6 +92,7 @@ pub fn all_default_aggregate_functions() -> Vec<Arc<AggregateUDF>> {
covariance::covar_pop_udaf(),
median::median_udaf(),
variance::var_samp_udaf(),
variance::var_pop_udaf(),
stddev::stddev_udaf(),
stddev::stddev_pop_udaf(),
]
Expand Down
85 changes: 84 additions & 1 deletion datafusion/functions-aggregate/src/variance.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@
// specific language governing permissions and limitations
// under the License.

//! [`VarianceSample`]: covariance sample aggregations.
//! [`VarianceSample`]: variance sample aggregations.
//! [`VariancePopulation`]: variance population aggregations.

use std::fmt::Debug;

Expand Down Expand Up @@ -43,6 +44,14 @@ make_udaf_expr_and_func!(
var_samp_udaf
);

make_udaf_expr_and_func!(
VariancePopulation,
var_pop,
expression,
"Computes the population variance.",
var_pop_udaf
);

pub struct VarianceSample {
signature: Signature,
aliases: Vec<String>,
Expand Down Expand Up @@ -115,6 +124,80 @@ impl AggregateUDFImpl for VarianceSample {
}
}

pub struct VariancePopulation {
signature: Signature,
aliases: Vec<String>,
}

impl Debug for VariancePopulation {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
f.debug_struct("VariancePopulation")
.field("name", &self.name())
.field("signature", &self.signature)
.finish()
}
}

impl Default for VariancePopulation {
fn default() -> Self {
Self::new()
}
}

impl VariancePopulation {
pub fn new() -> Self {
Self {
aliases: vec![String::from("var_population")],
signature: Signature::numeric(1, Volatility::Immutable),
}
}
}

impl AggregateUDFImpl for VariancePopulation {
fn as_any(&self) -> &dyn std::any::Any {
self
}

fn name(&self) -> &str {
"var_pop"
}

fn signature(&self) -> &Signature {
&self.signature
}

fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
if !arg_types[0].is_numeric() {
return plan_err!("Variance requires numeric input types");
}

Ok(DataType::Float64)
}

fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> {
let name = args.name;
Ok(vec![
Field::new(format_state_name(name, "count"), DataType::UInt64, true),
Field::new(format_state_name(name, "mean"), DataType::Float64, true),
Field::new(format_state_name(name, "m2"), DataType::Float64, true),
])
}

fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
if acc_args.is_distinct {
return not_impl_err!("VAR_POP(DISTINCT) aggregations are not available");
}

Ok(Box::new(VarianceAccumulator::try_new(
StatsType::Population,
)?))
}

fn aliases(&self) -> &[String] {
&self.aliases
}
}

/// An accumulator to compute variance
/// The algrithm used is an online implementation and numerically stable. It is based on this paper:
/// Welford, B. P. (1962). "Note on a method for calculating corrected sums of squares and products".
Expand Down
45 changes: 0 additions & 45 deletions datafusion/physical-expr/src/aggregate/build_in.rs
Original file line number Diff line number Diff line change
Expand Up @@ -157,12 +157,6 @@ pub fn create_aggregate_expr(
(AggregateFunction::Avg, true) => {
return not_impl_err!("AVG(DISTINCT) aggregations are not available");
}
(AggregateFunction::VariancePop, false) => Arc::new(
expressions::VariancePop::new(input_phy_exprs[0].clone(), name, data_type),
),
(AggregateFunction::VariancePop, true) => {
return not_impl_err!("VAR_POP(DISTINCT) aggregations are not available");
}
(AggregateFunction::Correlation, false) => {
Arc::new(expressions::Correlation::new(
input_phy_exprs[0].clone(),
Expand Down Expand Up @@ -340,7 +334,6 @@ pub fn create_aggregate_expr(
#[cfg(test)]
mod tests {
use arrow::datatypes::{DataType, Field};
use expressions::VariancePop;

use super::*;
use crate::expressions::{
Expand Down Expand Up @@ -693,44 +686,6 @@ mod tests {
Ok(())
}

#[test]
fn test_var_pop_expr() -> Result<()> {
let funcs = vec![AggregateFunction::VariancePop];
let data_types = vec![
DataType::UInt32,
DataType::UInt64,
DataType::Int32,
DataType::Int64,
DataType::Float32,
DataType::Float64,
];
for fun in funcs {
for data_type in &data_types {
let input_schema =
Schema::new(vec![Field::new("c1", data_type.clone(), true)]);
let input_phy_exprs: Vec<Arc<dyn PhysicalExpr>> = vec![Arc::new(
expressions::Column::new_with_schema("c1", &input_schema).unwrap(),
)];
let result_agg_phy_exprs = create_physical_agg_expr_for_test(
&fun,
false,
&input_phy_exprs[0..1],
&input_schema,
"c1",
)?;
if fun == AggregateFunction::VariancePop {
assert!(result_agg_phy_exprs.as_any().is::<VariancePop>());
assert_eq!("c1", result_agg_phy_exprs.name());
assert_eq!(
Field::new("c1", DataType::Float64, true),
result_agg_phy_exprs.field().unwrap()
)
}
}
}
Ok(())
}

#[test]
fn test_median_expr() -> Result<()> {
let funcs = vec![AggregateFunction::ApproxMedian];
Expand Down
Loading

0 comments on commit e8fdc09

Please sign in to comment.