diff --git a/datafusion/core/src/physical_optimizer/pruning.rs b/datafusion/core/src/physical_optimizer/pruning.rs index 9bc2bb1d1db9..5671345d8a09 100644 --- a/datafusion/core/src/physical_optimizer/pruning.rs +++ b/datafusion/core/src/physical_optimizer/pruning.rs @@ -478,6 +478,31 @@ pub struct PruningPredicate { literal_guarantees: Vec, } +/// Hook to handle predicates that DataFusion can not handle, e.g. certain complex expressions +/// or predicates that reference columns that are not in the schema. +pub trait UnhandledPredicateHook { + /// Called when a predicate can not be handled by DataFusion's transformation rules + /// or is referencing a column that is not in the schema. + fn handle(&self, expr: &Arc) -> Arc; +} + +#[derive(Debug, Clone)] +struct ConstantUnhandledPredicateHook { + default: Arc, +} + +impl ConstantUnhandledPredicateHook { + fn new(default: Arc) -> Self { + Self { default } + } +} + +impl UnhandledPredicateHook for ConstantUnhandledPredicateHook { + fn handle(&self, _expr: &Arc) -> Arc { + self.default.clone() + } +} + impl PruningPredicate { /// Try to create a new instance of [`PruningPredicate`] /// @@ -502,10 +527,33 @@ impl PruningPredicate { /// See the struct level documentation on [`PruningPredicate`] for more /// details. pub fn try_new(expr: Arc, schema: SchemaRef) -> Result { + let unhandled_hook = Arc::new(ConstantUnhandledPredicateHook::new(Arc::new( + phys_expr::Literal::new(ScalarValue::Boolean(Some(true))), + ))); + Self::try_new_with_unhandled_hook(expr, schema, unhandled_hook) + } + + /// Try to create a new instance of [`PruningPredicate`] with a custom + /// unhandled hook. + /// + /// This is the same as [`PruningPredicate::try_new`] but allows for a custom + /// hook to be used when a predicate can not be handled by DataFusion's + /// transformation rules or is referencing a column that is not in the schema. + /// + /// By default, a constant `true` is returned for unhandled predicates. + pub fn try_new_with_unhandled_hook( + expr: Arc, + schema: SchemaRef, + unhandled_hook: Arc, + ) -> Result { // build predicate expression once let mut required_columns = RequiredColumns::new(); - let predicate_expr = - build_predicate_expression(&expr, schema.as_ref(), &mut required_columns); + let predicate_expr = build_predicate_expression( + &expr, + schema.as_ref(), + &mut required_columns, + &unhandled_hook, + ); let literal_guarantees = LiteralGuarantee::analyze(&expr); @@ -1323,16 +1371,13 @@ fn build_predicate_expression( expr: &Arc, schema: &Schema, required_columns: &mut RequiredColumns, + unhandled_hook: &Arc, ) -> Arc { - // Returned for unsupported expressions. Such expressions are - // converted to TRUE. - let unhandled = Arc::new(phys_expr::Literal::new(ScalarValue::Boolean(Some(true)))); - // predicate expression can only be a binary expression let expr_any = expr.as_any(); if let Some(is_null) = expr_any.downcast_ref::() { return build_is_null_column_expr(is_null.arg(), schema, required_columns, false) - .unwrap_or(unhandled); + .unwrap_or_else(|| unhandled_hook.handle(expr)); } if let Some(is_not_null) = expr_any.downcast_ref::() { return build_is_null_column_expr( @@ -1341,19 +1386,19 @@ fn build_predicate_expression( required_columns, true, ) - .unwrap_or(unhandled); + .unwrap_or_else(|| unhandled_hook.handle(expr)); } if let Some(col) = expr_any.downcast_ref::() { return build_single_column_expr(col, schema, required_columns, false) - .unwrap_or(unhandled); + .unwrap_or_else(|| unhandled_hook.handle(expr)); } if let Some(not) = expr_any.downcast_ref::() { // match !col (don't do so recursively) if let Some(col) = not.arg().as_any().downcast_ref::() { return build_single_column_expr(col, schema, required_columns, true) - .unwrap_or(unhandled); + .unwrap_or_else(|| unhandled_hook.handle(expr)); } else { - return unhandled; + return unhandled_hook.handle(expr); } } if let Some(in_list) = expr_any.downcast_ref::() { @@ -1382,9 +1427,14 @@ fn build_predicate_expression( }) .reduce(|a, b| Arc::new(phys_expr::BinaryExpr::new(a, re_op, b)) as _) .unwrap(); - return build_predicate_expression(&change_expr, schema, required_columns); + return build_predicate_expression( + &change_expr, + schema, + required_columns, + unhandled_hook, + ); } else { - return unhandled; + return unhandled_hook.handle(expr); } } @@ -1396,13 +1446,15 @@ fn build_predicate_expression( bin_expr.right().clone(), ) } else { - return unhandled; + return unhandled_hook.handle(expr); } }; if op == Operator::And || op == Operator::Or { - let left_expr = build_predicate_expression(&left, schema, required_columns); - let right_expr = build_predicate_expression(&right, schema, required_columns); + let left_expr = + build_predicate_expression(&left, schema, required_columns, unhandled_hook); + let right_expr = + build_predicate_expression(&right, schema, required_columns, unhandled_hook); // simplify boolean expression if applicable let expr = match (&left_expr, op, &right_expr) { (left, Operator::And, _) if is_always_true(left) => right_expr, @@ -1410,7 +1462,7 @@ fn build_predicate_expression( (left, Operator::Or, right) if is_always_true(left) || is_always_true(right) => { - unhandled + Arc::new(phys_expr::Literal::new(ScalarValue::Boolean(Some(true)))) } _ => Arc::new(phys_expr::BinaryExpr::new(left_expr, op, right_expr)), }; @@ -1423,12 +1475,11 @@ fn build_predicate_expression( Ok(builder) => builder, // allow partial failure in predicate expression generation // this can still produce a useful predicate when multiple conditions are joined using AND - Err(_) => { - return unhandled; - } + Err(_) => return unhandled_hook.handle(expr), }; - build_statistics_expr(&mut expr_builder).unwrap_or(unhandled) + build_statistics_expr(&mut expr_builder) + .unwrap_or_else(|_| unhandled_hook.handle(expr)) } fn build_statistics_expr( @@ -1582,6 +1633,7 @@ mod tests { use arrow_array::UInt64Array; use datafusion_expr::expr::InList; use datafusion_expr::{cast, is_null, try_cast, Expr}; + use datafusion_physical_expr::expressions as phys_expr; use datafusion_physical_expr::planner::logical2physical; #[derive(Debug, Default)] @@ -3397,6 +3449,92 @@ mod tests { // TODO: add test for other case and op } + #[test] + fn test_rewrite_expr_to_prunable_custom_unhandled_hook() { + struct CustomUnhandledHook; + + impl UnhandledPredicateHook for CustomUnhandledHook { + /// This handles an arbitrary case of a column that doesn't exist in the schema + /// by renaming it to yet another column that doesn't exist in the schema + /// (the transformation is arbitrary, the point is that it can do whatever it wants) + fn handle(&self, expr: &Arc) -> Arc { + if let Some(expr) = expr.as_any().downcast_ref::() + { + let left = expr.left(); + let right = expr.right(); + if let Some(column) = + left.as_any().downcast_ref::() + { + if column.name() == "b" + && right + .as_any() + .downcast_ref::() + .is_some() + { + let new_column = + Arc::new(phys_expr::Column::new("c", column.index())) + as _; + return Arc::new(phys_expr::BinaryExpr::new( + new_column, + *expr.op(), + right.clone(), + )); + } + } + } + + Arc::new(phys_expr::Literal::new(ScalarValue::Boolean(Some(true)))) + as Arc + } + } + + let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]); + + let expr = Arc::new(phys_expr::BinaryExpr::new( + Arc::new(phys_expr::Column::new("b", 1)), + Operator::Eq, + Arc::new(phys_expr::Literal::new(ScalarValue::Int32(Some(12)))), + )) as Arc; + + let expected_expr = Arc::new(phys_expr::BinaryExpr::new( + Arc::new(phys_expr::Column::new("c", 1)), + Operator::Eq, + Arc::new(phys_expr::Literal::new(ScalarValue::Int32(Some(12)))), + )) as Arc; + + let handler = Arc::new(CustomUnhandledHook {}) as _; + let actual_expr = build_predicate_expression( + &expr, + &schema, + &mut RequiredColumns::new(), + &handler, + ); + + assert_eq!(actual_expr.to_string(), expected_expr.to_string()); + + // but other cases do end up as `true` + + let expr = Arc::new(phys_expr::BinaryExpr::new( + Arc::new(phys_expr::Column::new("d", 1)), + Operator::Eq, + Arc::new(phys_expr::Literal::new(ScalarValue::Int32(Some(12)))), + )) as Arc; + + let expected_expr = + Arc::new(phys_expr::Literal::new(ScalarValue::Boolean(Some(true)))) + as Arc; + + let handler = Arc::new(CustomUnhandledHook {}) as _; + let actual_expr = build_predicate_expression( + &expr, + &schema, + &mut RequiredColumns::new(), + &handler, + ); + + assert_eq!(actual_expr.to_string(), expected_expr.to_string()); + } + #[test] fn test_rewrite_expr_to_prunable_error() { // cast string value to numeric value @@ -3886,6 +4024,10 @@ mod tests { required_columns: &mut RequiredColumns, ) -> Arc { let expr = logical2physical(expr, schema); - build_predicate_expression(&expr, schema, required_columns) + // return literal true + let unhandled_hook = Arc::new(ConstantUnhandledPredicateHook::new(Arc::new( + phys_expr::Literal::new(ScalarValue::Boolean(Some(true))), + ))) as _; + build_predicate_expression(&expr, schema, required_columns, &unhandled_hook) } }