Skip to content

Commit

Permalink
[spark tests]
Browse files Browse the repository at this point in the history
  • Loading branch information
shehabgamin committed Oct 7, 2024
1 parent 619f09f commit d8c4522
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 18 deletions.
38 changes: 26 additions & 12 deletions crates/sail-plan/src/function/aggregate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,19 @@ use datafusion_expr::sqlparser::ast::NullTreatment;
use lazy_static::lazy_static;

use crate::error::{PlanError, PlanResult};
use crate::function::common::AggFunction;
use crate::function::common::{AggFunction, AggFunctionContext};
use crate::utils::ItemTaker;

// Name -> builder lookup table for the built-in aggregate functions.
// The map is collected from the `(name, AggFunction)` pairs returned by
// `list_built_in_aggregate_functions()`; `lazy_static!` defers that work
// until the static is first dereferenced.
lazy_static! {
static ref BUILT_IN_AGGREGATE_FUNCTIONS: HashMap<&'static str, AggFunction> =
HashMap::from_iter(list_built_in_aggregate_functions());
}

fn min_max_by(args: Vec<expr::Expr>, distinct: bool, asc: bool) -> PlanResult<expr::Expr> {
fn min_max_by(
args: Vec<expr::Expr>,
agg_function_context: AggFunctionContext,
asc: bool,
) -> PlanResult<expr::Expr> {
let (args, order_by, filter) = if args.len() == 2 {
let (first, second) = args.two()?;
(vec![first], second, None)
Expand All @@ -35,7 +39,7 @@ fn min_max_by(args: Vec<expr::Expr>, distinct: bool, asc: bool) -> PlanResult<ex
Ok(expr::Expr::AggregateFunction(AggregateFunction {
func: first_last::first_value_udaf(),
args,
distinct,
distinct: agg_function_context.distinct(),
filter,
order_by,
null_treatment: None,
Expand All @@ -44,7 +48,7 @@ fn min_max_by(args: Vec<expr::Expr>, distinct: bool, asc: bool) -> PlanResult<ex

fn first_last_value(
args: Vec<expr::Expr>,
distinct: bool,
agg_function_context: AggFunctionContext,
first_value: bool,
) -> PlanResult<expr::Expr> {
let (args, ignore_nulls) = if args.len() == 1 {
Expand Down Expand Up @@ -78,7 +82,7 @@ fn first_last_value(
Ok(expr::Expr::AggregateFunction(AggregateFunction {
func,
args,
distinct,
distinct: agg_function_context.distinct(),
filter: None,
order_by: None,
null_treatment: Some(ignore_nulls),
Expand All @@ -92,7 +96,9 @@ fn list_built_in_aggregate_functions() -> Vec<(&'static str, AggFunction)> {
("any", F::default(bool_and_or::bool_or_udaf)),
(
"any_value",
F::custom(|args, distinct| first_last_value(args, distinct, true)),
F::custom(|args, agg_function_context| {
first_last_value(args, agg_function_context, true)
}),
),
(
"approx_count_distinct",
Expand Down Expand Up @@ -122,11 +128,15 @@ fn list_built_in_aggregate_functions() -> Vec<(&'static str, AggFunction)> {
("every", F::default(bool_and_or::bool_and_udaf)),
(
"first",
F::custom(|args, distinct| first_last_value(args, distinct, true)),
F::custom(|args, agg_function_context| {
first_last_value(args, agg_function_context, true)
}),
),
(
"first_value",
F::custom(|args, distinct| first_last_value(args, distinct, true)),
F::custom(|args, agg_function_context| {
first_last_value(args, agg_function_context, true)
}),
),
("grouping", F::default(grouping::grouping_udaf)),
("grouping_id", F::unknown("grouping_id")),
Expand All @@ -136,23 +146,27 @@ fn list_built_in_aggregate_functions() -> Vec<(&'static str, AggFunction)> {
("kurtosis", F::default(kurtosis_pop::kurtosis_pop_udaf)),
(
"last",
F::custom(|args, distinct| first_last_value(args, distinct, false)),
F::custom(|args, agg_function_context| {
first_last_value(args, agg_function_context, false)
}),
),
(
"last_value",
F::custom(|args, distinct| first_last_value(args, distinct, false)),
F::custom(|args, agg_function_context| {
first_last_value(args, agg_function_context, false)
}),
),
("max", F::default(min_max::max_udaf)),
(
"max_by",
F::custom(|args, distinct| min_max_by(args, distinct, false)),
F::custom(|args, agg_function_context| min_max_by(args, agg_function_context, false)),
),
("mean", F::default(average::avg_udaf)),
("median", F::default(median::median_udaf)),
("min", F::default(min_max::min_udaf)),
(
"min_by",
F::custom(|args, distinct| min_max_by(args, distinct, true)),
F::custom(|args, agg_function_context| min_max_by(args, agg_function_context, true)),
),
("mode", F::unknown("mode")),
("percentile", F::unknown("percentile")),
Expand Down
25 changes: 21 additions & 4 deletions crates/sail-plan/src/function/common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -143,8 +143,22 @@ impl FunctionBuilder {
}
}

/// Call-site information handed to an aggregate function builder.
///
/// At present this only carries the `DISTINCT` flag of the aggregate call.
/// Wrapping it in a struct (rather than passing a bare `bool`) lets further
/// per-call settings be added without changing every builder signature.
#[derive(Debug, Clone)]
pub struct AggFunctionContext {
    /// Whether the aggregate call was marked as distinct.
    distinct: bool,
}

impl AggFunctionContext {
    /// Creates a context for an aggregate call with the given `distinct` flag.
    pub fn new(distinct: bool) -> Self {
        Self { distinct }
    }

    /// Returns `true` when the aggregate call was marked as distinct.
    pub fn distinct(&self) -> bool {
        self.distinct
    }
}

pub(crate) type AggFunction =
Arc<dyn Fn(Vec<expr::Expr>, bool) -> PlanResult<expr::Expr> + Send + Sync>;
Arc<dyn Fn(Vec<expr::Expr>, AggFunctionContext) -> PlanResult<expr::Expr> + Send + Sync>;

pub(crate) struct AggFunctionBuilder;

Expand All @@ -153,11 +167,11 @@ impl AggFunctionBuilder {
where
F: Fn() -> Arc<AggregateUDF> + Send + Sync + 'static,
{
Arc::new(move |args, distinct| {
Arc::new(move |args, agg_function_context| {
Ok(expr::Expr::AggregateFunction(AggregateFunction {
func: f(),
args,
distinct,
distinct: agg_function_context.distinct(),
filter: None,
order_by: None,
null_treatment: None,
Expand All @@ -167,7 +181,10 @@ impl AggFunctionBuilder {

pub fn custom<F>(f: F) -> AggFunction
where
F: Fn(Vec<expr::Expr>, bool) -> PlanResult<expr::Expr> + Send + Sync + 'static,
F: Fn(Vec<expr::Expr>, AggFunctionContext) -> PlanResult<expr::Expr>
+ Send
+ Sync
+ 'static,
{
Arc::new(f)
}
Expand Down
5 changes: 3 additions & 2 deletions crates/sail-plan/src/resolver/expression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ use sail_python_udf::udf::unresolved_pyspark_udf::UnresolvedPySparkUDF;
use crate::error::{PlanError, PlanResult};
use crate::extension::function::drop_struct_field::DropStructField;
use crate::extension::function::update_struct_field::UpdateStructField;
use crate::function::common::FunctionContext;
use crate::function::common::{AggFunctionContext, FunctionContext};
use crate::function::{
get_built_in_aggregate_function, get_built_in_function, get_built_in_window_function,
};
Expand Down Expand Up @@ -783,7 +783,8 @@ impl PlanResolver<'_> {
let function_context = FunctionContext::new(self.config.clone());
func(arguments.clone(), &function_context)?
} else if let Ok(func) = get_built_in_aggregate_function(function_name.as_str()) {
func(arguments.clone(), is_distinct)?
let agg_function_context = AggFunctionContext::new(is_distinct);
func(arguments.clone(), agg_function_context)?
} else {
return Err(PlanError::unsupported(format!(
"unknown function: {function_name}",
Expand Down

0 comments on commit d8c4522

Please sign in to comment.