Skip to content

Commit

Permalink
fix: incorrect simplification of case expr (#7006)
Browse files — browse the repository at this point in the history
* fix: incorrect simplification of case expr

* Use pattern to match against None
  • Loading branch information
jonahgao authored Jul 18, 2023
1 parent 9338880 commit 796f5f5
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 15 deletions.
12 changes: 12 additions & 0 deletions datafusion/core/tests/sqllogictests/test_files/scalar.slt
Original file line number Diff line number Diff line change
Expand Up @@ -1090,6 +1090,18 @@ FROM t1
999
999

# issue: https://github.com/apache/arrow-datafusion/issues/7004
query B
select case c1
when 'foo' then TRUE
when 'bar' then FALSE
end from t1
----
NULL
NULL
NULL
NULL

statement ok
drop table t1

Expand Down
33 changes: 18 additions & 15 deletions datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@ use datafusion_common::tree_node::{RewriteRecursion, TreeNode, TreeNodeRewriter}
use datafusion_common::{DFSchema, DFSchemaRef, DataFusionError, Result, ScalarValue};
use datafusion_expr::expr::{InList, InSubquery, ScalarFunction};
use datafusion_expr::{
and, expr, lit, or, BinaryExpr, BuiltinScalarFunction, ColumnarValue, Expr, Like,
Volatility,
and, expr, lit, or, BinaryExpr, BuiltinScalarFunction, Case, ColumnarValue, Expr,
Like, Volatility,
};
use datafusion_physical_expr::{create_physical_expr, execution_props::ExecutionProps};

Expand Down Expand Up @@ -1069,17 +1069,20 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> {
//
// Note: the rationale for this rewrite is that the expr can then be further
// simplified using the existing rules for AND/OR
Expr::Case(case)
if !case.when_then_expr.is_empty()
&& case.when_then_expr.len() < 3 // The rewrite is O(n!) so limit to small number
&& info.is_boolean_type(&case.when_then_expr[0].1)? =>
Expr::Case(Case {
expr: None,
when_then_expr,
else_expr,
}) if !when_then_expr.is_empty()
&& when_then_expr.len() < 3 // The rewrite is O(n!) so limit to small number
&& info.is_boolean_type(&when_then_expr[0].1)? =>
{
// The disjunction of all the when predicates encountered so far
let mut filter_expr = lit(false);
// The disjunction of all the cases
let mut out_expr = lit(false);

for (when, then) in case.when_then_expr {
for (when, then) in when_then_expr {
let case_expr = when
.as_ref()
.clone()
Expand All @@ -1090,7 +1093,7 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> {
filter_expr = filter_expr.or(*when);
}

if let Some(else_expr) = case.else_expr {
if let Some(else_expr) = else_expr {
let case_expr = filter_expr.not().and(*else_expr);
out_expr = out_expr.or(case_expr);
}
Expand Down Expand Up @@ -2819,9 +2822,9 @@ mod tests {

#[test]
fn simplify_expr_case_when_then_else() {
// CASE WHERE c2 != false THEN "ok" == "not_ok" ELSE c2 == true
// CASE WHEN c2 != false THEN "ok" == "not_ok" ELSE c2 == true
// -->
// CASE WHERE c2 THEN false ELSE c2
// CASE WHEN c2 THEN false ELSE c2
// -->
// false
assert_eq!(
Expand All @@ -2836,9 +2839,9 @@ mod tests {
col("c2").not().and(col("c2")) // #1716
);

// CASE WHERE c2 != false THEN "ok" == "ok" ELSE c2
// CASE WHEN c2 != false THEN "ok" == "ok" ELSE c2
// -->
// CASE WHERE c2 THEN true ELSE c2
// CASE WHEN c2 THEN true ELSE c2
// -->
// c2
//
Expand All @@ -2856,7 +2859,7 @@ mod tests {
col("c2").or(col("c2").not().and(col("c2"))) // #1716
);

// CASE WHERE ISNULL(c2) THEN true ELSE c2
// CASE WHEN ISNULL(c2) THEN true ELSE c2
// -->
// ISNULL(c2) OR c2
//
Expand All @@ -2873,7 +2876,7 @@ mod tests {
.or(col("c2").is_not_null().and(col("c2")))
);

// CASE WHERE c1 then true WHERE c2 then false ELSE true
// CASE WHEN c1 then true WHEN c2 then false ELSE true
// --> c1 OR (NOT(c1) AND c2 AND FALSE) OR (NOT(c1 OR c2) AND TRUE)
// --> c1 OR (NOT(c1) AND NOT(c2))
// --> c1 OR NOT(c2)
Expand All @@ -2892,7 +2895,7 @@ mod tests {
col("c1").or(col("c1").not().and(col("c2").not()))
);

// CASE WHERE c1 then true WHERE c2 then true ELSE false
// CASE WHEN c1 then true WHEN c2 then true ELSE false
// --> c1 OR (NOT(c1) AND c2 AND TRUE) OR (NOT(c1 OR c2) AND FALSE)
// --> c1 OR (NOT(c1) AND c2)
// --> c1 OR c2
Expand Down

0 comments on commit 796f5f5

Please sign in to comment.