Skip to content

Commit

Permalink
feat(cubesql): Support MEASURE SQL push down
Browse files Browse the repository at this point in the history
  • Loading branch information
MazterQyou committed Sep 23, 2024
1 parent 73d9314 commit 4968959
Show file tree
Hide file tree
Showing 3 changed files with 173 additions and 5 deletions.
4 changes: 3 additions & 1 deletion rust/cubesql/cubesql/src/compile/engine/udf/common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2268,7 +2268,9 @@ pub fn create_measure_udaf() -> AggregateUDF {
DataType::Float64,
Arc::new(DataType::Float64),
Volatility::Immutable,
Arc::new(|| todo!("Not implemented")),
Arc::new(|| {
Err(DataFusionError::NotImplemented("MEASURE function was used in context where it's not supported. Try replacing MEASURE with the measure type-matching function (SUM/AVG/etc).".to_string()))

Check warning on line 2272 in rust/cubesql/cubesql/src/compile/engine/udf/common.rs

View check run for this annotation

Codecov / codecov/patch

rust/cubesql/cubesql/src/compile/engine/udf/common.rs#L2272

Added line #L2272 was not covered by tests
}),
Arc::new(vec![DataType::Float64]),
)
}
Expand Down
33 changes: 33 additions & 0 deletions rust/cubesql/cubesql/src/compile/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18588,4 +18588,37 @@ LIMIT {{ limit }}{% endif %}"#.to_string(),

Ok(())
}

#[tokio::test]
async fn test_measure_func_push_down() {
if !Rewriter::sql_push_down_enabled() {
return;
}
init_testing_logger();

let query_plan = convert_select_to_query_plan(
r#"
SELECT MEASURE("sumPrice") AS "total_price"
FROM "public"."KibanaSampleDataEcommerce"
WHERE lower("customer_gender") = '123'
"#
.to_string(),
DatabaseProtocol::PostgreSQL,
)
.await;

let logical_plan = query_plan.as_logical_plan();
let sql = logical_plan
.find_cube_scan_wrapper()
.wrapped_sql
.unwrap()
.sql;
assert!(sql.contains("SUM("));

let physical_plan = query_plan.as_physical_plan().await.unwrap();
println!(
"Physical plan: {}",
displayable(physical_plan.as_ref()).indent()
);
}
}
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
use crate::{
compile::rewrite::{
agg_fun_expr, analysis::LogicalPlanAnalysis, rewrite, rules::wrapper::WrapperRules,
transforming_rewrite, wrapper_pullup_replacer, wrapper_pushdown_replacer,
AggregateFunctionExprDistinct, AggregateFunctionExprFun, LogicalPlanLanguage,
WrapperPullupReplacerAliasToCube,
agg_fun_expr, alias_expr, analysis::LogicalPlanAnalysis, column_expr, original_expr_name,
rewrite, rules::wrapper::WrapperRules, transforming_chain_rewrite, transforming_rewrite,
udaf_expr, wrapper_pullup_replacer, wrapper_pushdown_replacer,
AggregateFunctionExprDistinct, AggregateFunctionExprFun, AggregateUDFExprFun,
AliasExprAlias, ColumnExprColumn, LogicalPlanLanguage, WrapperPullupReplacerAliasToCube,
},
transport::V1CubeMetaExt,
var, var_iter,
};
use datafusion::physical_plan::aggregates::AggregateFunction;
Expand Down Expand Up @@ -59,6 +61,35 @@ impl WrapperRules {
),
self.transform_agg_fun_expr("?fun", "?distinct", "?alias_to_cube"),
),
transforming_chain_rewrite(
"wrapper-push-down-measure-aggregate-function",
wrapper_pushdown_replacer(
"?udaf",
"?alias_to_cube",
"?ungrouped",
"?in_projection",
"?cube_members",
),
vec![("?udaf", udaf_expr("?fun", vec![column_expr("?column")]))],
alias_expr(
wrapper_pushdown_replacer(
"?output",
"?alias_to_cube",
"?ungrouped",
"?in_projection",
"?cube_members",
),
"?alias",
),
self.transform_measure_udaf_expr(
"?udaf",
"?fun",
"?column",
"?alias_to_cube",
"?output",
"?alias",
),
),
]);
}

Expand Down Expand Up @@ -105,4 +136,106 @@ impl WrapperRules {
false
}
}

fn transform_measure_udaf_expr(
&self,
udaf_var: &'static str,
fun_var: &'static str,
column_var: &'static str,
alias_to_cube_var: &'static str,
output_var: &'static str,
alias_var: &'static str,
) -> impl Fn(&mut EGraph<LogicalPlanLanguage, LogicalPlanAnalysis>, &mut Subst) -> bool {
let udaf_var = var!(udaf_var);
let fun_var = var!(fun_var);
let column_var = var!(column_var);
let alias_to_cube_var = var!(alias_to_cube_var);
let output_var = var!(output_var);
let alias_var = var!(alias_var);
let meta = self.meta_context.clone();
move |egraph, subst| {
let Some(original_alias) = original_expr_name(egraph, subst[udaf_var]) else {
return false;

Check warning on line 158 in rust/cubesql/cubesql/src/compile/rewrite/rules/wrapper/aggregate_function.rs

View check run for this annotation

Codecov / codecov/patch

rust/cubesql/cubesql/src/compile/rewrite/rules/wrapper/aggregate_function.rs#L158

Added line #L158 was not covered by tests
};

for fun in var_iter!(egraph[subst[fun_var]], AggregateUDFExprFun) {
if fun.to_lowercase() != "measure" {
continue;

Check warning on line 163 in rust/cubesql/cubesql/src/compile/rewrite/rules/wrapper/aggregate_function.rs

View check run for this annotation

Codecov / codecov/patch

rust/cubesql/cubesql/src/compile/rewrite/rules/wrapper/aggregate_function.rs#L163

Added line #L163 was not covered by tests
}

for column in var_iter!(egraph[subst[column_var]], ColumnExprColumn) {
for alias_to_cube in var_iter!(
egraph[subst[alias_to_cube_var]],
WrapperPullupReplacerAliasToCube
) {
let Some((_, cube)) = meta.find_cube_by_column(alias_to_cube, column)
else {
continue;

Check warning on line 173 in rust/cubesql/cubesql/src/compile/rewrite/rules/wrapper/aggregate_function.rs

View check run for this annotation

Codecov / codecov/patch

rust/cubesql/cubesql/src/compile/rewrite/rules/wrapper/aggregate_function.rs#L173

Added line #L173 was not covered by tests
};

let Some(measure) = cube.lookup_measure(&column.name) else {
continue;
};

let Some(agg_type) = &measure.agg_type else {
continue;

Check warning on line 181 in rust/cubesql/cubesql/src/compile/rewrite/rules/wrapper/aggregate_function.rs

View check run for this annotation

Codecov / codecov/patch

rust/cubesql/cubesql/src/compile/rewrite/rules/wrapper/aggregate_function.rs#L181

Added line #L181 was not covered by tests
};

let out_fun_distinct = match agg_type.as_str() {
"string" | "time" | "boolean" | "number" => None,
"count" => Some((AggregateFunction::Count, false)),
"countDistinct" => Some((AggregateFunction::Count, true)),
"countDistinctApprox" => {
Some((AggregateFunction::ApproxDistinct, false))

Check warning on line 189 in rust/cubesql/cubesql/src/compile/rewrite/rules/wrapper/aggregate_function.rs

View check run for this annotation

Codecov / codecov/patch

rust/cubesql/cubesql/src/compile/rewrite/rules/wrapper/aggregate_function.rs#L189

Added line #L189 was not covered by tests
}
"sum" => Some((AggregateFunction::Sum, false)),
"avg" => Some((AggregateFunction::Avg, false)),
"min" => Some((AggregateFunction::Min, false)),
"max" => Some((AggregateFunction::Max, false)),
_ => continue,

Check warning on line 195 in rust/cubesql/cubesql/src/compile/rewrite/rules/wrapper/aggregate_function.rs

View check run for this annotation

Codecov / codecov/patch

rust/cubesql/cubesql/src/compile/rewrite/rules/wrapper/aggregate_function.rs#L195

Added line #L195 was not covered by tests
};

let column_expr_id =
egraph.add(LogicalPlanLanguage::ColumnExpr([subst[column_var]]));

let output_id = out_fun_distinct
.map(|(out_fun, distinct)| {
let fun_id =
egraph.add(LogicalPlanLanguage::AggregateFunctionExprFun(
AggregateFunctionExprFun(out_fun),
));
let args_tail_id = egraph
.add(LogicalPlanLanguage::AggregateFunctionExprArgs(vec![]));
let args_id =
egraph.add(LogicalPlanLanguage::AggregateFunctionExprArgs(
vec![column_expr_id, args_tail_id],
));
let distinct_id =
egraph.add(LogicalPlanLanguage::AggregateFunctionExprDistinct(
AggregateFunctionExprDistinct(distinct),
));

egraph.add(LogicalPlanLanguage::AggregateFunctionExpr([
fun_id,
args_id,
distinct_id,
]))
})
.unwrap_or(column_expr_id);

subst.insert(output_var, output_id);

subst.insert(
alias_var,
egraph.add(LogicalPlanLanguage::AliasExprAlias(AliasExprAlias(
original_alias,
))),
);
return true;
}
}
}
false
}
}
}

0 comments on commit 4968959

Please sign in to comment.