Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve coerce API so it does not need DFSchema #10331

Merged
merged 1 commit into from
May 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion datafusion-examples/examples/expr_api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,7 @@ pub fn physical_expr(schema: &Schema, expr: Expr) -> Result<Arc<dyn PhysicalExpr
ExprSimplifier::new(SimplifyContext::new(&props).with_schema(df_schema.clone()));

// apply type coercion here to ensure types match
let expr = simplifier.coerce(expr, df_schema.clone())?;
let expr = simplifier.coerce(expr, &df_schema)?;
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This shows the effect of this API change on users: callers now pass a borrowed schema (add a `&`) instead of cloning the `DFSchemaRef`.


create_physical_expr(&expr, df_schema.as_ref(), &props)
}
Expand Down
2 changes: 1 addition & 1 deletion datafusion/core/src/test_util/parquet.rs
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ impl TestParquetFile {
let parquet_options = ctx.copied_table_options().parquet;
if let Some(filter) = maybe_filter {
let simplifier = ExprSimplifier::new(context);
let filter = simplifier.coerce(filter, df_schema.clone()).unwrap();
let filter = simplifier.coerce(filter, &df_schema).unwrap();
let physical_filter_expr =
create_physical_expr(&filter, &df_schema, &ExecutionProps::default())?;
let parquet_exec = Arc::new(ParquetExec::new(
Expand Down
99 changes: 48 additions & 51 deletions datafusion/optimizer/src/analyzer/type_coercion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ use datafusion_common::config::ConfigOptions;
use datafusion_common::tree_node::{Transformed, TreeNodeRewriter};
use datafusion_common::{
exec_err, internal_err, not_impl_err, plan_datafusion_err, plan_err, DFSchema,
DFSchemaRef, DataFusionError, Result, ScalarValue,
DataFusionError, Result, ScalarValue,
};
use datafusion_expr::expr::{
self, AggregateFunctionDefinition, Between, BinaryExpr, Case, Exists, InList,
Expand Down Expand Up @@ -99,9 +99,7 @@ fn analyze_internal(
// select t2.c2 from t1 where t1.c1 in (select t2.c1 from t2 where t2.c2=t1.c3)
schema.merge(external_schema);

let mut expr_rewrite = TypeCoercionRewriter {
schema: Arc::new(schema),
};
let mut expr_rewrite = TypeCoercionRewriter { schema: &schema };

let new_expr = plan
.expressions()
Expand All @@ -116,11 +114,11 @@ fn analyze_internal(
plan.with_new_exprs(new_expr, new_inputs)
}

pub(crate) struct TypeCoercionRewriter {
pub(crate) schema: DFSchemaRef,
pub(crate) struct TypeCoercionRewriter<'a> {
pub(crate) schema: &'a DFSchema,
}

impl TreeNodeRewriter for TypeCoercionRewriter {
impl<'a> TreeNodeRewriter for TypeCoercionRewriter<'a> {
type Node = Expr;

fn f_up(&mut self, expr: Expr) -> Result<Transformed<Expr>> {
Expand All @@ -132,14 +130,14 @@ impl TreeNodeRewriter for TypeCoercionRewriter {
subquery,
outer_ref_columns,
}) => {
let new_plan = analyze_internal(&self.schema, &subquery)?;
let new_plan = analyze_internal(self.schema, &subquery)?;
Ok(Transformed::yes(Expr::ScalarSubquery(Subquery {
subquery: Arc::new(new_plan),
outer_ref_columns,
})))
}
Expr::Exists(Exists { subquery, negated }) => {
let new_plan = analyze_internal(&self.schema, &subquery.subquery)?;
let new_plan = analyze_internal(self.schema, &subquery.subquery)?;
Ok(Transformed::yes(Expr::Exists(Exists {
subquery: Subquery {
subquery: Arc::new(new_plan),
Expand All @@ -153,8 +151,8 @@ impl TreeNodeRewriter for TypeCoercionRewriter {
subquery,
negated,
}) => {
let new_plan = analyze_internal(&self.schema, &subquery.subquery)?;
let expr_type = expr.get_type(&self.schema)?;
let new_plan = analyze_internal(self.schema, &subquery.subquery)?;
let expr_type = expr.get_type(self.schema)?;
let subquery_type = new_plan.schema().field(0).data_type();
let common_type = comparison_coercion(&expr_type, subquery_type).ok_or(plan_datafusion_err!(
"expr type {expr_type:?} can't cast to {subquery_type:?} in InSubquery"
Expand All @@ -165,32 +163,32 @@ impl TreeNodeRewriter for TypeCoercionRewriter {
outer_ref_columns: subquery.outer_ref_columns,
};
Ok(Transformed::yes(Expr::InSubquery(InSubquery::new(
Box::new(expr.cast_to(&common_type, &self.schema)?),
Box::new(expr.cast_to(&common_type, self.schema)?),
cast_subquery(new_subquery, &common_type)?,
negated,
))))
}
Expr::Not(expr) => Ok(Transformed::yes(not(get_casted_expr_for_bool_op(
*expr,
&self.schema,
self.schema,
)?))),
Expr::IsTrue(expr) => Ok(Transformed::yes(is_true(
get_casted_expr_for_bool_op(*expr, &self.schema)?,
get_casted_expr_for_bool_op(*expr, self.schema)?,
))),
Expr::IsNotTrue(expr) => Ok(Transformed::yes(is_not_true(
get_casted_expr_for_bool_op(*expr, &self.schema)?,
get_casted_expr_for_bool_op(*expr, self.schema)?,
))),
Expr::IsFalse(expr) => Ok(Transformed::yes(is_false(
get_casted_expr_for_bool_op(*expr, &self.schema)?,
get_casted_expr_for_bool_op(*expr, self.schema)?,
))),
Expr::IsNotFalse(expr) => Ok(Transformed::yes(is_not_false(
get_casted_expr_for_bool_op(*expr, &self.schema)?,
get_casted_expr_for_bool_op(*expr, self.schema)?,
))),
Expr::IsUnknown(expr) => Ok(Transformed::yes(is_unknown(
get_casted_expr_for_bool_op(*expr, &self.schema)?,
get_casted_expr_for_bool_op(*expr, self.schema)?,
))),
Expr::IsNotUnknown(expr) => Ok(Transformed::yes(is_not_unknown(
get_casted_expr_for_bool_op(*expr, &self.schema)?,
get_casted_expr_for_bool_op(*expr, self.schema)?,
))),
Expr::Like(Like {
negated,
Expand All @@ -199,8 +197,8 @@ impl TreeNodeRewriter for TypeCoercionRewriter {
escape_char,
case_insensitive,
}) => {
let left_type = expr.get_type(&self.schema)?;
let right_type = pattern.get_type(&self.schema)?;
let left_type = expr.get_type(self.schema)?;
let right_type = pattern.get_type(self.schema)?;
let coerced_type = like_coercion(&left_type, &right_type).ok_or_else(|| {
let op_name = if case_insensitive {
"ILIKE"
Expand All @@ -211,8 +209,8 @@ impl TreeNodeRewriter for TypeCoercionRewriter {
"There isn't a common type to coerce {left_type} and {right_type} in {op_name} expression"
)
})?;
let expr = Box::new(expr.cast_to(&coerced_type, &self.schema)?);
let pattern = Box::new(pattern.cast_to(&coerced_type, &self.schema)?);
let expr = Box::new(expr.cast_to(&coerced_type, self.schema)?);
let pattern = Box::new(pattern.cast_to(&coerced_type, self.schema)?);
Ok(Transformed::yes(Expr::Like(Like::new(
negated,
expr,
Expand All @@ -223,14 +221,14 @@ impl TreeNodeRewriter for TypeCoercionRewriter {
}
Expr::BinaryExpr(BinaryExpr { left, op, right }) => {
let (left_type, right_type) = get_input_types(
&left.get_type(&self.schema)?,
&left.get_type(self.schema)?,
&op,
&right.get_type(&self.schema)?,
&right.get_type(self.schema)?,
)?;
Ok(Transformed::yes(Expr::BinaryExpr(BinaryExpr::new(
Box::new(left.cast_to(&left_type, &self.schema)?),
Box::new(left.cast_to(&left_type, self.schema)?),
op,
Box::new(right.cast_to(&right_type, &self.schema)?),
Box::new(right.cast_to(&right_type, self.schema)?),
))))
}
Expr::Between(Between {
Expand All @@ -239,15 +237,15 @@ impl TreeNodeRewriter for TypeCoercionRewriter {
low,
high,
}) => {
let expr_type = expr.get_type(&self.schema)?;
let low_type = low.get_type(&self.schema)?;
let expr_type = expr.get_type(self.schema)?;
let low_type = low.get_type(self.schema)?;
let low_coerced_type = comparison_coercion(&expr_type, &low_type)
.ok_or_else(|| {
DataFusionError::Internal(format!(
"Failed to coerce types {expr_type} and {low_type} in BETWEEN expression"
))
})?;
let high_type = high.get_type(&self.schema)?;
let high_type = high.get_type(self.schema)?;
let high_coerced_type = comparison_coercion(&expr_type, &low_type)
.ok_or_else(|| {
DataFusionError::Internal(format!(
Expand All @@ -262,21 +260,21 @@ impl TreeNodeRewriter for TypeCoercionRewriter {
))
})?;
Ok(Transformed::yes(Expr::Between(Between::new(
Box::new(expr.cast_to(&coercion_type, &self.schema)?),
Box::new(expr.cast_to(&coercion_type, self.schema)?),
negated,
Box::new(low.cast_to(&coercion_type, &self.schema)?),
Box::new(high.cast_to(&coercion_type, &self.schema)?),
Box::new(low.cast_to(&coercion_type, self.schema)?),
Box::new(high.cast_to(&coercion_type, self.schema)?),
))))
}
Expr::InList(InList {
expr,
list,
negated,
}) => {
let expr_data_type = expr.get_type(&self.schema)?;
let expr_data_type = expr.get_type(self.schema)?;
let list_data_types = list
.iter()
.map(|list_expr| list_expr.get_type(&self.schema))
.map(|list_expr| list_expr.get_type(self.schema))
.collect::<Result<Vec<_>>>()?;
let result_type =
get_coerce_type_for_list(&expr_data_type, &list_data_types);
Expand All @@ -286,11 +284,11 @@ impl TreeNodeRewriter for TypeCoercionRewriter {
),
Some(coerced_type) => {
// find the coerced type
let cast_expr = expr.cast_to(&coerced_type, &self.schema)?;
let cast_expr = expr.cast_to(&coerced_type, self.schema)?;
let cast_list_expr = list
.into_iter()
.map(|list_expr| {
list_expr.cast_to(&coerced_type, &self.schema)
list_expr.cast_to(&coerced_type, self.schema)
})
.collect::<Result<Vec<_>>>()?;
Ok(Transformed::yes(Expr::InList(InList ::new(
Expand All @@ -302,18 +300,17 @@ impl TreeNodeRewriter for TypeCoercionRewriter {
}
}
Expr::Case(case) => {
let case = coerce_case_expression(case, &self.schema)?;
let case = coerce_case_expression(case, self.schema)?;
Ok(Transformed::yes(Expr::Case(case)))
}
Expr::ScalarFunction(ScalarFunction { func_def, args }) => match func_def {
ScalarFunctionDefinition::UDF(fun) => {
let new_expr = coerce_arguments_for_signature(
args,
&self.schema,
self.schema,
fun.signature(),
)?;
let new_expr =
coerce_arguments_for_fun(new_expr, &self.schema, &fun)?;
let new_expr = coerce_arguments_for_fun(new_expr, self.schema, &fun)?;
Ok(Transformed::yes(Expr::ScalarFunction(
ScalarFunction::new_udf(fun, new_expr),
)))
Expand All @@ -331,7 +328,7 @@ impl TreeNodeRewriter for TypeCoercionRewriter {
let new_expr = coerce_agg_exprs_for_signature(
&fun,
args,
&self.schema,
self.schema,
&fun.signature(),
)?;
Ok(Transformed::yes(Expr::AggregateFunction(
Expand All @@ -348,7 +345,7 @@ impl TreeNodeRewriter for TypeCoercionRewriter {
AggregateFunctionDefinition::UDF(fun) => {
let new_expr = coerce_arguments_for_signature(
args,
&self.schema,
self.schema,
fun.signature(),
)?;
Ok(Transformed::yes(Expr::AggregateFunction(
Expand All @@ -375,14 +372,14 @@ impl TreeNodeRewriter for TypeCoercionRewriter {
null_treatment,
}) => {
let window_frame =
coerce_window_frame(window_frame, &self.schema, &order_by)?;
coerce_window_frame(window_frame, self.schema, &order_by)?;

let args = match &fun {
expr::WindowFunctionDefinition::AggregateFunction(fun) => {
coerce_agg_exprs_for_signature(
fun,
args,
&self.schema,
self.schema,
&fun.signature(),
)?
}
Expand Down Expand Up @@ -495,7 +492,7 @@ fn coerce_frame_bound(
// For example, ROWS and GROUPS frames use `UInt64` during calculations.
fn coerce_window_frame(
window_frame: WindowFrame,
schema: &DFSchemaRef,
schema: &DFSchema,
expressions: &[Expr],
) -> Result<WindowFrame> {
let mut window_frame = window_frame;
Expand Down Expand Up @@ -531,7 +528,7 @@ fn coerce_window_frame(

// Support type coercion for the `IsTrue`, `IsNotTrue`, `IsFalse`, and `IsNotFalse` ops.
// These ops will be rewritten to binary ops when creating the physical op.
fn get_casted_expr_for_bool_op(expr: Expr, schema: &DFSchemaRef) -> Result<Expr> {
fn get_casted_expr_for_bool_op(expr: Expr, schema: &DFSchema) -> Result<Expr> {
let left_type = expr.get_type(schema)?;
get_input_types(&left_type, &Operator::IsDistinctFrom, &DataType::Boolean)?;
expr.cast_to(&DataType::Boolean, schema)
Expand Down Expand Up @@ -615,7 +612,7 @@ fn coerce_agg_exprs_for_signature(
.collect()
}

fn coerce_case_expression(case: Case, schema: &DFSchemaRef) -> Result<Case> {
fn coerce_case_expression(case: Case, schema: &DFSchema) -> Result<Case> {
// Given expressions like:
//
// CASE a1
Expand Down Expand Up @@ -1238,7 +1235,7 @@ mod test {
vec![Field::new("a", DataType::Int64, true)].into(),
std::collections::HashMap::new(),
)?);
let mut rewriter = TypeCoercionRewriter { schema };
let mut rewriter = TypeCoercionRewriter { schema: &schema };
let expr = is_true(lit(12i32).gt(lit(13i64)));
let expected = is_true(cast(lit(12i32), DataType::Int64).gt(lit(13i64)));
let result = expr.rewrite(&mut rewriter).data()?;
Expand All @@ -1249,7 +1246,7 @@ mod test {
vec![Field::new("a", DataType::Int64, true)].into(),
std::collections::HashMap::new(),
)?);
let mut rewriter = TypeCoercionRewriter { schema };
let mut rewriter = TypeCoercionRewriter { schema: &schema };
let expr = is_true(lit(12i32).eq(lit(13i64)));
let expected = is_true(cast(lit(12i32), DataType::Int64).eq(lit(13i64)));
let result = expr.rewrite(&mut rewriter).data()?;
Expand All @@ -1260,7 +1257,7 @@ mod test {
vec![Field::new("a", DataType::Int64, true)].into(),
std::collections::HashMap::new(),
)?);
let mut rewriter = TypeCoercionRewriter { schema };
let mut rewriter = TypeCoercionRewriter { schema: &schema };
let expr = is_true(lit(12i32).lt(lit(13i64)));
let expected = is_true(cast(lit(12i32), DataType::Int64).lt(lit(13i64)));
let result = expr.rewrite(&mut rewriter).data()?;
Expand Down
20 changes: 4 additions & 16 deletions datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,7 @@ use datafusion_common::{
cast::{as_large_list_array, as_list_array},
tree_node::{Transformed, TransformedResult, TreeNode, TreeNodeRewriter},
};
use datafusion_common::{
internal_err, DFSchema, DFSchemaRef, DataFusionError, Result, ScalarValue,
};
use datafusion_common::{internal_err, DFSchema, DataFusionError, Result, ScalarValue};
use datafusion_expr::expr::{InList, InSubquery};
use datafusion_expr::simplify::ExprSimplifyResult;
use datafusion_expr::{
Expand Down Expand Up @@ -208,14 +206,8 @@ impl<S: SimplifyInfo> ExprSimplifier<S> {
///
/// See the [type coercion module](datafusion_expr::type_coercion)
/// documentation for more details on type coercion
///
// Would be nice if this API could use the SimplifyInfo
// rather than creating an DFSchemaRef coerces rather than doing
// it manually.
// https://github.com/apache/datafusion/issues/3793
pub fn coerce(&self, expr: Expr, schema: DFSchemaRef) -> Result<Expr> {
pub fn coerce(&self, expr: Expr, schema: &DFSchema) -> Result<Expr> {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the actual change. (I don't think making it take `SimplifyInfo` is practical, given how much `coerce` does.)

let mut expr_rewrite = TypeCoercionRewriter { schema };

expr.rewrite(&mut expr_rewrite).data()
}

Expand Down Expand Up @@ -1686,7 +1678,7 @@ mod tests {
sync::Arc,
};

use datafusion_common::{assert_contains, ToDFSchema};
use datafusion_common::{assert_contains, DFSchemaRef, ToDFSchema};
use datafusion_expr::{interval_arithmetic::Interval, *};

use crate::simplify_expressions::SimplifyContext;
Expand Down Expand Up @@ -1721,11 +1713,7 @@ mod tests {
// should fully simplify to 3 < i (though i has been coerced to i64)
let expected = lit(3i64).lt(col("i"));

// Would be nice if this API could use the SimplifyInfo
// rather than creating an DFSchemaRef coerces rather than doing
// it manually.
// https://github.com/apache/datafusion/issues/3793
let expr = simplifier.coerce(expr, schema).unwrap();
let expr = simplifier.coerce(expr, &schema).unwrap();

assert_eq!(expected, simplifier.simplify(expr).unwrap());
}
Expand Down