From d8c452259add93d528be8c04989575ce198a9f80 Mon Sep 17 00:00:00 2001 From: Shehab <11789402+shehabgamin@users.noreply.github.com> Date: Mon, 7 Oct 2024 13:57:01 -0700 Subject: [PATCH] [spark tests] --- crates/sail-plan/src/function/aggregate.rs | 38 ++++++++++++++------- crates/sail-plan/src/function/common.rs | 25 +++++++++++--- crates/sail-plan/src/resolver/expression.rs | 5 +-- 3 files changed, 50 insertions(+), 18 deletions(-) diff --git a/crates/sail-plan/src/function/aggregate.rs b/crates/sail-plan/src/function/aggregate.rs index b0e28919..80be852e 100644 --- a/crates/sail-plan/src/function/aggregate.rs +++ b/crates/sail-plan/src/function/aggregate.rs @@ -12,7 +12,7 @@ use datafusion_expr::sqlparser::ast::NullTreatment; use lazy_static::lazy_static; use crate::error::{PlanError, PlanResult}; -use crate::function::common::AggFunction; +use crate::function::common::{AggFunction, AggFunctionContext}; use crate::utils::ItemTaker; lazy_static! { @@ -20,7 +20,11 @@ lazy_static! { HashMap::from_iter(list_built_in_aggregate_functions()); } -fn min_max_by(args: Vec, distinct: bool, asc: bool) -> PlanResult { +fn min_max_by( + args: Vec, + agg_function_context: AggFunctionContext, + asc: bool, +) -> PlanResult { let (args, order_by, filter) = if args.len() == 2 { let (first, second) = args.two()?; (vec![first], second, None) @@ -35,7 +39,7 @@ fn min_max_by(args: Vec, distinct: bool, asc: bool) -> PlanResult, distinct: bool, asc: bool) -> PlanResult, - distinct: bool, + agg_function_context: AggFunctionContext, first_value: bool, ) -> PlanResult { let (args, ignore_nulls) = if args.len() == 1 { @@ -78,7 +82,7 @@ fn first_last_value( Ok(expr::Expr::AggregateFunction(AggregateFunction { func, args, - distinct, + distinct: agg_function_context.distinct(), filter: None, order_by: None, null_treatment: Some(ignore_nulls), @@ -92,7 +96,9 @@ fn list_built_in_aggregate_functions() -> Vec<(&'static str, AggFunction)> { ("any", F::default(bool_and_or::bool_or_udaf)), ( "any_value", - F::custom(|args, distinct| first_last_value(args, distinct, true)), + F::custom(|args, agg_function_context| { + first_last_value(args, agg_function_context, true) + }), ), ( "approx_count_distinct", @@ -122,11 +128,15 @@ fn list_built_in_aggregate_functions() -> Vec<(&'static str, AggFunction)> { ("every", F::default(bool_and_or::bool_and_udaf)), ( "first", - F::custom(|args, distinct| first_last_value(args, distinct, true)), + F::custom(|args, agg_function_context| { + first_last_value(args, agg_function_context, true) + }), ), ( "first_value", - F::custom(|args, distinct| first_last_value(args, distinct, true)), + F::custom(|args, agg_function_context| { + first_last_value(args, agg_function_context, true) + }), ), ("grouping", F::default(grouping::grouping_udaf)), ("grouping_id", F::unknown("grouping_id")), @@ -136,23 +146,27 @@ fn list_built_in_aggregate_functions() -> Vec<(&'static str, AggFunction)> { ("kurtosis", F::default(kurtosis_pop::kurtosis_pop_udaf)), ( "last", - F::custom(|args, distinct| first_last_value(args, distinct, false)), + F::custom(|args, agg_function_context| { + first_last_value(args, agg_function_context, false) + }), ), ( "last_value", - F::custom(|args, distinct| first_last_value(args, distinct, false)), + F::custom(|args, agg_function_context| { + first_last_value(args, agg_function_context, false) + }), ), ("max", F::default(min_max::max_udaf)), ( "max_by", - F::custom(|args, distinct| min_max_by(args, distinct, false)), + F::custom(|args, agg_function_context| min_max_by(args, agg_function_context, false)), ), ("mean", F::default(average::avg_udaf)), ("median", F::default(median::median_udaf)), ("min", F::default(min_max::min_udaf)), ( "min_by", - F::custom(|args, distinct| min_max_by(args, distinct, true)), + F::custom(|args, agg_function_context| min_max_by(args, agg_function_context, true)), ), ("mode", F::unknown("mode")), ("percentile", F::unknown("percentile")), diff --git a/crates/sail-plan/src/function/common.rs b/crates/sail-plan/src/function/common.rs index c8160e8b..663906e4 100644 --- a/crates/sail-plan/src/function/common.rs +++ b/crates/sail-plan/src/function/common.rs @@ -143,8 +143,22 @@ impl FunctionBuilder { } } +pub struct AggFunctionContext { + distinct: bool, +} + +impl AggFunctionContext { + pub fn new(distinct: bool) -> Self { + Self { distinct } + } + + pub fn distinct(&self) -> bool { + self.distinct + } +} + pub(crate) type AggFunction = - Arc, bool) -> PlanResult + Send + Sync>; + Arc, AggFunctionContext) -> PlanResult + Send + Sync>; pub(crate) struct AggFunctionBuilder; @@ -153,11 +167,11 @@ impl AggFunctionBuilder { where F: Fn() -> Arc + Send + Sync + 'static, { - Arc::new(move |args, distinct| { + Arc::new(move |args, agg_function_context| { Ok(expr::Expr::AggregateFunction(AggregateFunction { func: f(), args, - distinct, + distinct: agg_function_context.distinct(), filter: None, order_by: None, null_treatment: None, @@ -167,7 +181,10 @@ impl AggFunctionBuilder { pub fn custom(f: F) -> AggFunction where - F: Fn(Vec, bool) -> PlanResult + Send + Sync + 'static, + F: Fn(Vec, AggFunctionContext) -> PlanResult + + Send + + Sync + + 'static, { Arc::new(f) } diff --git a/crates/sail-plan/src/resolver/expression.rs b/crates/sail-plan/src/resolver/expression.rs index 19b332a6..e348ec1b 100644 --- a/crates/sail-plan/src/resolver/expression.rs +++ b/crates/sail-plan/src/resolver/expression.rs @@ -24,7 +24,7 @@ use sail_python_udf::udf::unresolved_pyspark_udf::UnresolvedPySparkUDF; use crate::error::{PlanError, PlanResult}; use crate::extension::function::drop_struct_field::DropStructField; use crate::extension::function::update_struct_field::UpdateStructField; -use crate::function::common::FunctionContext; +use crate::function::common::{AggFunctionContext, FunctionContext}; use crate::function::{ get_built_in_aggregate_function, get_built_in_function, get_built_in_window_function, }; @@ -783,7 +783,8 @@ impl PlanResolver<'_> { let function_context = FunctionContext::new(self.config.clone()); func(arguments.clone(), &function_context)? } else if let Ok(func) = get_built_in_aggregate_function(function_name.as_str()) { - func(arguments.clone(), is_distinct)? + let agg_function_context = AggFunctionContext::new(is_distinct); + func(arguments.clone(), agg_function_context)? } else { return Err(PlanError::unsupported(format!( "unknown function: {function_name}",