diff --git a/rust/cubesql/cubesql/src/compile/engine/udf/common.rs b/rust/cubesql/cubesql/src/compile/engine/udf/common.rs index 494641bbcac35..3e103578cf4fc 100644 --- a/rust/cubesql/cubesql/src/compile/engine/udf/common.rs +++ b/rust/cubesql/cubesql/src/compile/engine/udf/common.rs @@ -2268,7 +2268,9 @@ pub fn create_measure_udaf() -> AggregateUDF { DataType::Float64, Arc::new(DataType::Float64), Volatility::Immutable, - Arc::new(|| todo!("Not implemented")), + Arc::new(|| { + Err(DataFusionError::NotImplemented("MEASURE function was used in context where it's not supported. Try replacing MEASURE with the measure type-matching function (SUM/AVG/etc).".to_string())) + }), Arc::new(vec![DataType::Float64]), ) } diff --git a/rust/cubesql/cubesql/src/compile/mod.rs b/rust/cubesql/cubesql/src/compile/mod.rs index 10f7a4587ba28..1569aeea1d259 100644 --- a/rust/cubesql/cubesql/src/compile/mod.rs +++ b/rust/cubesql/cubesql/src/compile/mod.rs @@ -18588,4 +18588,37 @@ LIMIT {{ limit }}{% endif %}"#.to_string(), Ok(()) } + + #[tokio::test] + async fn test_measure_func_push_down() { + if !Rewriter::sql_push_down_enabled() { + return; + } + init_testing_logger(); + + let query_plan = convert_select_to_query_plan( + r#" + SELECT MEASURE("sumPrice") AS "total_price" + FROM "public"."KibanaSampleDataEcommerce" + WHERE lower("customer_gender") = '123' + "# + .to_string(), + DatabaseProtocol::PostgreSQL, + ) + .await; + + let logical_plan = query_plan.as_logical_plan(); + let sql = logical_plan + .find_cube_scan_wrapper() + .wrapped_sql + .unwrap() + .sql; + assert!(sql.contains("SUM(")); + + let physical_plan = query_plan.as_physical_plan().await.unwrap(); + println!( + "Physical plan: {}", + displayable(physical_plan.as_ref()).indent() + ); + } } diff --git a/rust/cubesql/cubesql/src/compile/rewrite/rules/wrapper/aggregate_function.rs b/rust/cubesql/cubesql/src/compile/rewrite/rules/wrapper/aggregate_function.rs index 3d5280f017912..6c8ce0d9d9436 100644 --- a/rust/cubesql/cubesql/src/compile/rewrite/rules/wrapper/aggregate_function.rs +++ b/rust/cubesql/cubesql/src/compile/rewrite/rules/wrapper/aggregate_function.rs @@ -1,10 +1,12 @@ use crate::{ compile::rewrite::{ - agg_fun_expr, analysis::LogicalPlanAnalysis, rewrite, rules::wrapper::WrapperRules, - transforming_rewrite, wrapper_pullup_replacer, wrapper_pushdown_replacer, - AggregateFunctionExprDistinct, AggregateFunctionExprFun, LogicalPlanLanguage, - WrapperPullupReplacerAliasToCube, + agg_fun_expr, alias_expr, analysis::LogicalPlanAnalysis, column_expr, original_expr_name, + rewrite, rules::wrapper::WrapperRules, transforming_chain_rewrite, transforming_rewrite, + udaf_expr, wrapper_pullup_replacer, wrapper_pushdown_replacer, + AggregateFunctionExprDistinct, AggregateFunctionExprFun, AggregateUDFExprFun, + AliasExprAlias, ColumnExprColumn, LogicalPlanLanguage, WrapperPullupReplacerAliasToCube, }, + transport::V1CubeMetaExt, var, var_iter, }; use datafusion::physical_plan::aggregates::AggregateFunction; @@ -59,6 +61,35 @@ impl WrapperRules { ), self.transform_agg_fun_expr("?fun", "?distinct", "?alias_to_cube"), ), + transforming_chain_rewrite( + "wrapper-push-down-measure-aggregate-function", + wrapper_pushdown_replacer( + "?udaf", + "?alias_to_cube", + "?ungrouped", + "?in_projection", + "?cube_members", + ), + vec![("?udaf", udaf_expr("?fun", vec![column_expr("?column")]))], + alias_expr( + wrapper_pushdown_replacer( + "?output", + "?alias_to_cube", + "?ungrouped", + "?in_projection", + "?cube_members", + ), + "?alias", + ), + self.transform_measure_udaf_expr( + "?udaf", + "?fun", + "?column", + "?alias_to_cube", + "?output", + "?alias", + ), + ), ]); } @@ -105,4 +136,106 @@ impl WrapperRules { false } } + + fn transform_measure_udaf_expr( + &self, + udaf_var: &'static str, + fun_var: &'static str, + column_var: &'static str, + alias_to_cube_var: &'static str, + output_var: &'static str, + alias_var: &'static str, + ) -> impl Fn(&mut EGraph, &mut Subst) -> bool { + let udaf_var = var!(udaf_var); + let fun_var = var!(fun_var); + let column_var = var!(column_var); + let alias_to_cube_var = var!(alias_to_cube_var); + let output_var = var!(output_var); + let alias_var = var!(alias_var); + let meta = self.meta_context.clone(); + move |egraph, subst| { + let Some(original_alias) = original_expr_name(egraph, subst[udaf_var]) else { + return false; + }; + + for fun in var_iter!(egraph[subst[fun_var]], AggregateUDFExprFun) { + if fun.to_lowercase() != "measure" { + continue; + } + + for column in var_iter!(egraph[subst[column_var]], ColumnExprColumn) { + for alias_to_cube in var_iter!( + egraph[subst[alias_to_cube_var]], + WrapperPullupReplacerAliasToCube + ) { + let Some((_, cube)) = meta.find_cube_by_column(alias_to_cube, column) + else { + continue; + }; + + let Some(measure) = cube.lookup_measure(&column.name) else { + continue; + }; + + let Some(agg_type) = &measure.agg_type else { + continue; + }; + + let out_fun_distinct = match agg_type.as_str() { + "string" | "time" | "boolean" | "number" => None, + "count" => Some((AggregateFunction::Count, false)), + "countDistinct" => Some((AggregateFunction::Count, true)), + "countDistinctApprox" => { + Some((AggregateFunction::ApproxDistinct, false)) + } + "sum" => Some((AggregateFunction::Sum, false)), + "avg" => Some((AggregateFunction::Avg, false)), + "min" => Some((AggregateFunction::Min, false)), + "max" => Some((AggregateFunction::Max, false)), + _ => continue, + }; + + let column_expr_id = + egraph.add(LogicalPlanLanguage::ColumnExpr([subst[column_var]])); + + let output_id = out_fun_distinct + .map(|(out_fun, distinct)| { + let fun_id = + egraph.add(LogicalPlanLanguage::AggregateFunctionExprFun( + AggregateFunctionExprFun(out_fun), + )); + let args_tail_id = egraph + .add(LogicalPlanLanguage::AggregateFunctionExprArgs(vec![])); + let args_id = + egraph.add(LogicalPlanLanguage::AggregateFunctionExprArgs( + vec![column_expr_id, args_tail_id], + )); + let distinct_id = + egraph.add(LogicalPlanLanguage::AggregateFunctionExprDistinct( + AggregateFunctionExprDistinct(distinct), + )); + + egraph.add(LogicalPlanLanguage::AggregateFunctionExpr([ + fun_id, + args_id, + distinct_id, + ])) + }) + .unwrap_or(column_expr_id); + + subst.insert(output_var, output_id); + + subst.insert( + alias_var, + egraph.add(LogicalPlanLanguage::AliasExprAlias(AliasExprAlias( + original_alias, + ))), + ); + return true; + } + } + } + false + } + } }