Skip to content

Commit

Permalink
fix(cubesql): Transform IN filter with one value to = with all ex…
Browse files Browse the repository at this point in the history
…pressions
  • Loading branch information
MazterQyou authored Jul 19, 2024
1 parent f135933 commit 671e067
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 33 deletions.
42 changes: 42 additions & 0 deletions rust/cubesql/cubesql/src/compile/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24325,4 +24325,46 @@ LIMIT {{ limit }}{% endif %}"#.to_string(),

Ok(())
}

#[tokio::test]
async fn test_filter_time_dimension_equals_as_date_range() {
init_logger();

let logical_plan = convert_select_to_query_plan(
r#"
SELECT
measure(count) AS cnt,
date_trunc('month', order_date) AS dt
FROM KibanaSampleDataEcommerce
WHERE order_date IN (to_timestamp('2019-01-01 00:00:00.000000', 'YYYY-MM-DD HH24:MI:SS.US'))
GROUP BY 2
;"#
.to_string(),
DatabaseProtocol::PostgreSQL,
)
.await
.as_logical_plan();

assert_eq!(
logical_plan.find_cube_scan().request,
V1LoadRequestQuery {
measures: Some(vec!["KibanaSampleDataEcommerce.count".to_string()]),
dimensions: Some(vec![]),
segments: Some(vec![]),
time_dimensions: Some(vec![V1LoadRequestQueryTimeDimension {
dimension: "KibanaSampleDataEcommerce.order_date".to_string(),
granularity: Some("month".to_string()),
date_range: Some(json!(vec![
"2019-01-01T00:00:00.000Z".to_string(),
"2019-01-01T00:00:00.000Z".to_string()
]))
}]),
order: None,
limit: None,
offset: None,
filters: None,
ungrouped: None,
}
)
}
}
4 changes: 4 additions & 0 deletions rust/cubesql/cubesql/src/compile/rewrite/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1536,6 +1536,10 @@ fn inlist_expr(expr: impl Display, list: impl Display, negated: impl Display) ->
format!("(InListExpr {} {} {})", expr, list, negated)
}

fn inlist_expr_list(exprs: Vec<impl Display>) -> String {
flat_list_expr("InListExprList", exprs, true)
}

fn insubquery_expr(expr: impl Display, subquery: impl Display, negated: impl Display) -> String {
format!("(InSubqueryExpr {} {} {})", expr, subquery, negated)
}
Expand Down
46 changes: 13 additions & 33 deletions rust/cubesql/cubesql/src/compile/rewrite/rules/filters.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,10 @@ use crate::{
column_expr, cube_scan, cube_scan_filters, cube_scan_filters_empty_tail, cube_scan_members,
dimension_expr, expr_column_name, filter, filter_member, filter_op, filter_op_filters,
filter_op_filters_empty_tail, filter_replacer, filter_simplify_replacer, fun_expr,
fun_expr_args_legacy, fun_expr_var_arg, inlist_expr, is_not_null_expr, is_null_expr,
like_expr, limit, list_rewrite, literal_bool, literal_expr, literal_int, literal_string,
measure_expr, member_name_to_expr_by_alias, negative_expr, not_expr, projection, rewrite,
fun_expr_args_legacy, fun_expr_var_arg, inlist_expr, inlist_expr_list, is_not_null_expr,
is_null_expr, like_expr, limit, list_rewrite, literal_bool, literal_expr, literal_int,
literal_string, measure_expr, member_name_to_expr_by_alias, negative_expr, not_expr,
projection, rewrite,
rewriter::RewriteRules,
scalar_fun_expr_args_empty_tail, segment_member, time_dimension_date_range_replacer,
time_dimension_expr, transform_original_expr_to_alias, transforming_chain_rewrite,
Expand Down Expand Up @@ -394,18 +395,18 @@ impl RewriteRules for FilterRules {
transforming_rewrite(
"in-filter-equal",
filter_replacer(
inlist_expr("?expr", "?list", "?negated"),
inlist_expr("?expr", inlist_expr_list(vec!["?elem"]), "?negated"),
"?alias_to_cube",
"?members",
"?filter_aliases",
),
filter_replacer(
"?binary_expr",
binary_expr("?expr", "?op", "?elem"),
"?alias_to_cube",
"?members",
"?filter_aliases",
),
self.transform_filter_in_to_equal("?expr", "?list", "?negated", "?binary_expr"),
self.transform_filter_in_to_equal("?negated", "?op"),
),
transforming_rewrite(
"filter-in-list-datetrunc",
Expand Down Expand Up @@ -3288,45 +3289,24 @@ impl FilterRules {
// Transform ?expr IN (?literal) to ?expr = ?literal
fn transform_filter_in_to_equal(
&self,
expr_val: &'static str,
list_var: &'static str,
negated_var: &'static str,
return_binary_expr_var: &'static str,
op_var: &'static str,
) -> impl Fn(&mut EGraph<LogicalPlanLanguage, LogicalPlanAnalysis>, &mut Subst) -> bool {
let expr_val = var!(expr_val);
let list_var = var!(list_var);
let negated_var = var!(negated_var);
let return_binary_expr_var = var!(return_binary_expr_var);
let op_var = var!(op_var);

move |egraph, subst| {
let expr_id = subst[expr_val];
let scalar = match &egraph[subst[list_var]].data.constant_in_list {
Some(list) if list.len() == 1 => list[0].clone(),
_ => return false,
};

for negated in var_iter!(egraph[subst[negated_var]], InListExprNegated) {
let operator = if *negated {
Operator::NotEq
} else {
Operator::Eq
};
let operator =
egraph.add(LogicalPlanLanguage::BinaryExprOp(BinaryExprOp(operator)));

let literal_expr = egraph.add(LogicalPlanLanguage::LiteralExprValue(
LiteralExprValue(scalar),
));
let literal_expr = egraph.add(LogicalPlanLanguage::LiteralExpr([literal_expr]));

let return_binary_expr = egraph.add(LogicalPlanLanguage::BinaryExpr([
expr_id,
operator,
literal_expr,
]));

subst.insert(return_binary_expr_var, return_binary_expr);

subst.insert(
op_var,
egraph.add(LogicalPlanLanguage::BinaryExprOp(BinaryExprOp(operator))),
);
return true;
}

Expand Down

0 comments on commit 671e067

Please sign in to comment.