Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add unhandled hook to PruningPredicate #12606

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
186 changes: 164 additions & 22 deletions datafusion/core/src/physical_optimizer/pruning.rs
Original file line number Diff line number Diff line change
Expand Up @@ -478,6 +478,31 @@ pub struct PruningPredicate {
literal_guarantees: Vec<LiteralGuarantee>,
}

/// Hook to handle predicates that DataFusion can not handle, e.g. certain complex expressions
/// or predicates that reference columns that are not in the schema.
pub trait UnhandledPredicateHook {
/// Called when a predicate can not be handled by DataFusion's transformation rules
/// or is referencing a column that is not in the schema.
fn handle(&self, expr: &Arc<dyn PhysicalExpr>) -> Arc<dyn PhysicalExpr>;
}
Comment on lines +483 to +487
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This could be a closure but I had issues with lifetimes, etc. Having the trait also gives it a useful name 😄

The other API questions are:

  • Should this be mutable? I think implementers can just use interior mutability if needed.
  • Should this make it easier to say "use the existing expression"? I don't think that's a common case, and the current APIs use &Arc<dyn PhysicalExpr> -> Arc<dyn PhysicalExpr> as well. Plus it's as easy as a Clone on an Arc.


#[derive(Debug, Clone)]
struct ConstantUnhandledPredicateHook {
default: Arc<dyn PhysicalExpr>,
}

impl ConstantUnhandledPredicateHook {
fn new(default: Arc<dyn PhysicalExpr>) -> Self {
Self { default }
}
}

impl UnhandledPredicateHook for ConstantUnhandledPredicateHook {
fn handle(&self, _expr: &Arc<dyn PhysicalExpr>) -> Arc<dyn PhysicalExpr> {
self.default.clone()
}
}

impl PruningPredicate {
/// Try to create a new instance of [`PruningPredicate`]
///
Expand All @@ -502,10 +527,33 @@ impl PruningPredicate {
/// See the struct level documentation on [`PruningPredicate`] for more
/// details.
pub fn try_new(expr: Arc<dyn PhysicalExpr>, schema: SchemaRef) -> Result<Self> {
let unhandled_hook = Arc::new(ConstantUnhandledPredicateHook::new(Arc::new(
phys_expr::Literal::new(ScalarValue::Boolean(Some(true))),
)));
Self::try_new_with_unhandled_hook(expr, schema, unhandled_hook)
}

/// Try to create a new instance of [`PruningPredicate`] with a custom
/// unhandled hook.
///
/// This is the same as [`PruningPredicate::try_new`] but allows for a custom
/// hook to be used when a predicate can not be handled by DataFusion's
/// transformation rules or is referencing a column that is not in the schema.
///
/// By default, a constant `true` is returned for unhandled predicates.
pub fn try_new_with_unhandled_hook(
expr: Arc<dyn PhysicalExpr>,
schema: SchemaRef,
unhandled_hook: Arc<dyn UnhandledPredicateHook>,
) -> Result<Self> {
// build predicate expression once
let mut required_columns = RequiredColumns::new();
let predicate_expr =
build_predicate_expression(&expr, schema.as_ref(), &mut required_columns);
let predicate_expr = build_predicate_expression(
&expr,
schema.as_ref(),
&mut required_columns,
&unhandled_hook,
);

let literal_guarantees = LiteralGuarantee::analyze(&expr);

Expand Down Expand Up @@ -1323,16 +1371,13 @@ fn build_predicate_expression(
expr: &Arc<dyn PhysicalExpr>,
schema: &Schema,
required_columns: &mut RequiredColumns,
unhandled_hook: &Arc<dyn UnhandledPredicateHook>,
) -> Arc<dyn PhysicalExpr> {
// Returned for unsupported expressions. Such expressions are
// converted to TRUE.
let unhandled = Arc::new(phys_expr::Literal::new(ScalarValue::Boolean(Some(true))));

// predicate expression can only be a binary expression
let expr_any = expr.as_any();
if let Some(is_null) = expr_any.downcast_ref::<phys_expr::IsNullExpr>() {
return build_is_null_column_expr(is_null.arg(), schema, required_columns, false)
.unwrap_or(unhandled);
.unwrap_or_else(|| unhandled_hook.handle(expr));
}
if let Some(is_not_null) = expr_any.downcast_ref::<phys_expr::IsNotNullExpr>() {
return build_is_null_column_expr(
Expand All @@ -1341,19 +1386,19 @@ fn build_predicate_expression(
required_columns,
true,
)
.unwrap_or(unhandled);
.unwrap_or_else(|| unhandled_hook.handle(expr));
}
if let Some(col) = expr_any.downcast_ref::<phys_expr::Column>() {
return build_single_column_expr(col, schema, required_columns, false)
.unwrap_or(unhandled);
.unwrap_or_else(|| unhandled_hook.handle(expr));
}
if let Some(not) = expr_any.downcast_ref::<phys_expr::NotExpr>() {
// match !col (don't do so recursively)
if let Some(col) = not.arg().as_any().downcast_ref::<phys_expr::Column>() {
return build_single_column_expr(col, schema, required_columns, true)
.unwrap_or(unhandled);
.unwrap_or_else(|| unhandled_hook.handle(expr));
} else {
return unhandled;
return unhandled_hook.handle(expr);
}
}
if let Some(in_list) = expr_any.downcast_ref::<phys_expr::InListExpr>() {
Expand Down Expand Up @@ -1382,9 +1427,14 @@ fn build_predicate_expression(
})
.reduce(|a, b| Arc::new(phys_expr::BinaryExpr::new(a, re_op, b)) as _)
.unwrap();
return build_predicate_expression(&change_expr, schema, required_columns);
return build_predicate_expression(
&change_expr,
schema,
required_columns,
unhandled_hook,
);
} else {
return unhandled;
return unhandled_hook.handle(expr);
}
}

Expand All @@ -1396,21 +1446,23 @@ fn build_predicate_expression(
bin_expr.right().clone(),
)
} else {
return unhandled;
return unhandled_hook.handle(expr);
}
};

if op == Operator::And || op == Operator::Or {
let left_expr = build_predicate_expression(&left, schema, required_columns);
let right_expr = build_predicate_expression(&right, schema, required_columns);
let left_expr =
build_predicate_expression(&left, schema, required_columns, unhandled_hook);
let right_expr =
build_predicate_expression(&right, schema, required_columns, unhandled_hook);
// simplify boolean expression if applicable
let expr = match (&left_expr, op, &right_expr) {
(left, Operator::And, _) if is_always_true(left) => right_expr,
(_, Operator::And, right) if is_always_true(right) => left_expr,
(left, Operator::Or, right)
if is_always_true(left) || is_always_true(right) =>
{
unhandled
Arc::new(phys_expr::Literal::new(ScalarValue::Boolean(Some(true))))
}
_ => Arc::new(phys_expr::BinaryExpr::new(left_expr, op, right_expr)),
};
Expand All @@ -1423,12 +1475,11 @@ fn build_predicate_expression(
Ok(builder) => builder,
// allow partial failure in predicate expression generation
// this can still produce a useful predicate when multiple conditions are joined using AND
Err(_) => {
return unhandled;
}
Err(_) => return unhandled_hook.handle(expr),
};

build_statistics_expr(&mut expr_builder).unwrap_or(unhandled)
build_statistics_expr(&mut expr_builder)
.unwrap_or_else(|_| unhandled_hook.handle(expr))
}

fn build_statistics_expr(
Expand Down Expand Up @@ -1582,6 +1633,7 @@ mod tests {
use arrow_array::UInt64Array;
use datafusion_expr::expr::InList;
use datafusion_expr::{cast, is_null, try_cast, Expr};
use datafusion_physical_expr::expressions as phys_expr;
use datafusion_physical_expr::planner::logical2physical;

#[derive(Debug, Default)]
Expand Down Expand Up @@ -3397,6 +3449,92 @@ mod tests {
// TODO: add test for other case and op
}

#[test]
fn test_rewrite_expr_to_prunable_custom_unhandled_hook() {
struct CustomUnhandledHook;

impl UnhandledPredicateHook for CustomUnhandledHook {
/// This handles an arbitrary case of a column that doesn't exist in the schema
/// by renaming it to yet another column that doesn't exist in the schema
/// (the transformation is arbitrary, the point is that it can do whatever it wants)
fn handle(&self, expr: &Arc<dyn PhysicalExpr>) -> Arc<dyn PhysicalExpr> {
if let Some(expr) = expr.as_any().downcast_ref::<phys_expr::BinaryExpr>()
{
let left = expr.left();
let right = expr.right();
if let Some(column) =
left.as_any().downcast_ref::<phys_expr::Column>()
{
if column.name() == "b"
&& right
.as_any()
.downcast_ref::<phys_expr::Literal>()
.is_some()
{
let new_column =
Arc::new(phys_expr::Column::new("c", column.index()))
as _;
return Arc::new(phys_expr::BinaryExpr::new(
new_column,
*expr.op(),
right.clone(),
));
}
}
}

Arc::new(phys_expr::Literal::new(ScalarValue::Boolean(Some(true))))
as Arc<dyn PhysicalExpr>
}
}

let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]);

let expr = Arc::new(phys_expr::BinaryExpr::new(
Arc::new(phys_expr::Column::new("b", 1)),
Operator::Eq,
Arc::new(phys_expr::Literal::new(ScalarValue::Int32(Some(12)))),
)) as Arc<dyn PhysicalExpr>;

let expected_expr = Arc::new(phys_expr::BinaryExpr::new(
Arc::new(phys_expr::Column::new("c", 1)),
Operator::Eq,
Arc::new(phys_expr::Literal::new(ScalarValue::Int32(Some(12)))),
)) as Arc<dyn PhysicalExpr>;

let handler = Arc::new(CustomUnhandledHook {}) as _;
let actual_expr = build_predicate_expression(
&expr,
&schema,
&mut RequiredColumns::new(),
&handler,
);

assert_eq!(actual_expr.to_string(), expected_expr.to_string());

// but other cases do end up as `true`

let expr = Arc::new(phys_expr::BinaryExpr::new(
Arc::new(phys_expr::Column::new("d", 1)),
Operator::Eq,
Arc::new(phys_expr::Literal::new(ScalarValue::Int32(Some(12)))),
)) as Arc<dyn PhysicalExpr>;

let expected_expr =
Arc::new(phys_expr::Literal::new(ScalarValue::Boolean(Some(true))))
as Arc<dyn PhysicalExpr>;

let handler = Arc::new(CustomUnhandledHook {}) as _;
let actual_expr = build_predicate_expression(
&expr,
&schema,
&mut RequiredColumns::new(),
&handler,
);

assert_eq!(actual_expr.to_string(), expected_expr.to_string());
}

#[test]
fn test_rewrite_expr_to_prunable_error() {
// cast string value to numeric value
Expand Down Expand Up @@ -3886,6 +4024,10 @@ mod tests {
required_columns: &mut RequiredColumns,
) -> Arc<dyn PhysicalExpr> {
let expr = logical2physical(expr, schema);
build_predicate_expression(&expr, schema, required_columns)
// return literal true
let unhandled_hook = Arc::new(ConstantUnhandledPredicateHook::new(Arc::new(
phys_expr::Literal::new(ScalarValue::Boolean(Some(true))),
))) as _;
build_predicate_expression(&expr, schema, required_columns, &unhandled_hook)
}
}