Skip to content

Commit

Permalink
Add SessionContext/SessionState::create_physical_expr() to create…
Browse files Browse the repository at this point in the history
… `PhysicalExpressions` from `Expr`s (apache#10330)

* Improve coerce API so it does not need DFSchema

* Add `SessionContext::create_physical_expr()` and `SessionState::create_physical_expr()`

* Apply suggestions from code review

Co-authored-by: Weston Pace <[email protected]>

* Add note on simplification

---------

Co-authored-by: Weston Pace <[email protected]>
  • Loading branch information
alamb and westonpace authored May 7, 2024
1 parent f0e96c6 commit c8b8c74
Show file tree
Hide file tree
Showing 6 changed files with 337 additions and 28 deletions.
34 changes: 11 additions & 23 deletions datafusion-examples/examples/expr_api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,7 @@ use datafusion::arrow::datatypes::{DataType, Field, Schema, TimeUnit};
use datafusion::common::DFSchema;
use datafusion::error::Result;
use datafusion::optimizer::simplify_expressions::ExprSimplifier;
use datafusion::physical_expr::{
analyze, create_physical_expr, AnalysisContext, ExprBoundaries, PhysicalExpr,
};
use datafusion::physical_expr::{analyze, AnalysisContext, ExprBoundaries};
use datafusion::prelude::*;
use datafusion_common::{ScalarValue, ToDFSchema};
use datafusion_expr::execution_props::ExecutionProps;
Expand Down Expand Up @@ -92,7 +90,8 @@ fn evaluate_demo() -> Result<()> {
let expr = col("a").lt(lit(5)).or(col("a").eq(lit(8)));

// First, you make a "physical expression" from the logical `Expr`
let physical_expr = physical_expr(&batch.schema(), expr)?;
let df_schema = DFSchema::try_from(batch.schema())?;
let physical_expr = SessionContext::new().create_physical_expr(expr, &df_schema)?;

// Now, you can evaluate the expression against the RecordBatch
let result = physical_expr.evaluate(&batch)?;
Expand Down Expand Up @@ -213,7 +212,7 @@ fn range_analysis_demo() -> Result<()> {
// `date < '2020-10-01' AND date > '2020-09-01'`

// As always, we need to tell DataFusion the type of column "date"
let schema = Schema::new(vec![make_field("date", DataType::Date32)]);
let schema = Arc::new(Schema::new(vec![make_field("date", DataType::Date32)]));

// You can provide DataFusion any known boundaries on the values of `date`
// (for example, maybe you know you only have data up to `2020-09-15`), but
Expand All @@ -222,9 +221,13 @@ fn range_analysis_demo() -> Result<()> {
let boundaries = ExprBoundaries::try_new_unbounded(&schema)?;

// Now, we invoke the analysis code to perform the range analysis
let physical_expr = physical_expr(&schema, expr)?;
let analysis_result =
analyze(&physical_expr, AnalysisContext::new(boundaries), &schema)?;
let df_schema = DFSchema::try_from(schema)?;
let physical_expr = SessionContext::new().create_physical_expr(expr, &df_schema)?;
let analysis_result = analyze(
&physical_expr,
AnalysisContext::new(boundaries),
df_schema.as_ref(),
)?;

// The results of the analysis is an range, encoded as an `Interval`, for
// each column in the schema, that must be true in order for the predicate
Expand All @@ -248,21 +251,6 @@ fn make_ts_field(name: &str) -> Field {
make_field(name, DataType::Timestamp(TimeUnit::Nanosecond, tz))
}

/// Build a physical expression from a logical one, after applying simplification and type coercion
pub fn physical_expr(schema: &Schema, expr: Expr) -> Result<Arc<dyn PhysicalExpr>> {
let df_schema = schema.clone().to_dfschema_ref()?;

// Simplify
let props = ExecutionProps::new();
let simplifier =
ExprSimplifier::new(SimplifyContext::new(&props).with_schema(df_schema.clone()));

// apply type coercion here to ensure types match
let expr = simplifier.coerce(expr, &df_schema)?;

create_physical_expr(&expr, df_schema.as_ref(), &props)
}

/// This function shows how to use `Expr::get_type` to retrieve the DataType
/// of an expression
fn expression_type_demo() -> Result<()> {
Expand Down
29 changes: 29 additions & 0 deletions datafusion/common/src/dfschema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,20 @@ impl DFSchema {
}
}

/// Return a reference to the inner Arrow [`Schema`]
///
/// Note this does not have the qualifier information
pub fn as_arrow(&self) -> &Schema {
self.inner.as_ref()
}

/// Return a reference to the inner Arrow [`SchemaRef`]
///
/// Note this does not have the qualifier information
pub fn inner(&self) -> &SchemaRef {
&self.inner
}

/// Create a `DFSchema` from an Arrow schema where all the fields have a given qualifier
pub fn new_with_metadata(
qualified_fields: Vec<(Option<TableReference>, Arc<Field>)>,
Expand Down Expand Up @@ -806,6 +820,21 @@ impl From<&DFSchema> for Schema {
}
}

/// Allow DFSchema to be converted into an Arrow `&Schema`
impl AsRef<Schema> for DFSchema {
fn as_ref(&self) -> &Schema {
self.as_arrow()
}
}

/// Allow DFSchema to be converted into an Arrow `&SchemaRef` (to clone, for
/// example)
impl AsRef<SchemaRef> for DFSchema {
fn as_ref(&self) -> &SchemaRef {
self.inner()
}
}

/// Create a `DFSchema` from an Arrow schema
impl TryFrom<Schema> for DFSchema {
type Error = DataFusionError;
Expand Down
113 changes: 108 additions & 5 deletions datafusion/core/src/execution/context/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,13 +71,13 @@ use datafusion_common::{
config::{ConfigExtension, TableOptions},
exec_err, not_impl_err, plan_datafusion_err, plan_err,
tree_node::{TreeNodeRecursion, TreeNodeVisitor},
SchemaReference, TableReference,
DFSchema, SchemaReference, TableReference,
};
use datafusion_execution::registry::SerializerRegistry;
use datafusion_expr::{
logical_plan::{DdlStatement, Statement},
var_provider::is_system_variables,
Expr, StringifiedPlan, UserDefinedLogicalNode, WindowUDF,
Expr, ExprSchemable, StringifiedPlan, UserDefinedLogicalNode, WindowUDF,
};
use datafusion_sql::{
parser::{CopyToSource, CopyToStatement, DFParser},
Expand All @@ -87,15 +87,20 @@ use datafusion_sql::{

use async_trait::async_trait;
use chrono::{DateTime, Utc};
use datafusion_common::tree_node::TreeNode;
use parking_lot::RwLock;
use sqlparser::dialect::dialect_from_str;
use url::Url;
use uuid::Uuid;

use crate::physical_expr::PhysicalExpr;
pub use datafusion_execution::config::SessionConfig;
pub use datafusion_execution::TaskContext;
pub use datafusion_expr::execution_props::ExecutionProps;
use datafusion_expr::expr_rewriter::FunctionRewrite;
use datafusion_expr::simplify::SimplifyInfo;
use datafusion_optimizer::simplify_expressions::ExprSimplifier;
use datafusion_physical_expr::create_physical_expr;

mod avro;
mod csv;
Expand Down Expand Up @@ -523,6 +528,41 @@ impl SessionContext {
}
}

/// Create a [`PhysicalExpr`] from an [`Expr`] after applying type
/// coercion and function rewrites.
///
/// Note: The expression is not [simplified] or otherwise optimized: `a = 1
/// + 2` will not be simplified to `a = 3` as this is a more involved process.
/// See the [expr_api] example for how to simplify expressions.
///
/// # Example
/// ```
/// # use std::sync::Arc;
/// # use arrow::datatypes::{DataType, Field, Schema};
/// # use datafusion::prelude::*;
/// # use datafusion_common::DFSchema;
/// // a = 1 (i64)
/// let expr = col("a").eq(lit(1i64));
/// // provide type information that `a` is an Int32
/// let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]);
/// let df_schema = DFSchema::try_from(schema).unwrap();
/// // Create a PhysicalExpr. Note DataFusion automatically coerces (casts) `1i64` to `1i32`
/// let physical_expr = SessionContext::new()
/// .create_physical_expr(expr, &df_schema).unwrap();
/// ```
/// # See Also
/// * [`SessionState::create_physical_expr`] for a lower level API
///
/// [simplified]: datafusion_optimizer::simplify_expressions
/// [expr_api]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/expr_api.rs
pub fn create_physical_expr(
&self,
expr: Expr,
df_schema: &DFSchema,
) -> Result<Arc<dyn PhysicalExpr>> {
self.state.read().create_physical_expr(expr, df_schema)
}

// return an empty dataframe
fn return_empty_dataframe(&self) -> Result<DataFrame> {
let plan = LogicalPlanBuilder::empty(false).build()?;
Expand Down Expand Up @@ -1946,13 +1986,14 @@ impl SessionState {
}
}

/// Creates a physical plan from a logical plan.
/// Creates a physical [`ExecutionPlan`] plan from a [`LogicalPlan`].
///
/// Note: this first calls [`Self::optimize`] on the provided
/// plan.
///
/// This function will error for [`LogicalPlan`]s such as catalog
/// DDL `CREATE TABLE` must be handled by another layer.
/// This function will error for [`LogicalPlan`]s such as catalog DDL like
/// `CREATE TABLE`, which do not have corresponding physical plans and must
/// be handled by another layer, typically [`SessionContext`].
pub async fn create_physical_plan(
&self,
logical_plan: &LogicalPlan,
Expand All @@ -1963,6 +2004,39 @@ impl SessionState {
.await
}

/// Create a [`PhysicalExpr`] from an [`Expr`] after applying type
/// coercion, and function rewrites.
///
/// Note: The expression is not [simplified] or otherwise optimized: `a = 1
/// + 2` will not be simplified to `a = 3` as this is a more involved process.
/// See the [expr_api] example for how to simplify expressions.
///
/// # See Also:
/// * [`SessionContext::create_physical_expr`] for a higher-level API
/// * [`create_physical_expr`] for a lower-level API
///
/// [simplified]: datafusion_optimizer::simplify_expressions
/// [expr_api]: https://github.com/apache/datafusion/blob/main/datafusion-examples/examples/expr_api.rs
pub fn create_physical_expr(
&self,
expr: Expr,
df_schema: &DFSchema,
) -> Result<Arc<dyn PhysicalExpr>> {
let simplifier =
ExprSimplifier::new(SessionSimplifyProvider::new(self, df_schema));
// apply type coercion here to ensure types match
let mut expr = simplifier.coerce(expr, df_schema)?;

// rewrite Exprs to functions if necessary
let config_options = self.config_options();
for rewrite in self.analyzer.function_rewrites() {
expr = expr
.transform_up(|expr| rewrite.rewrite(expr, df_schema, config_options))?
.data;
}
create_physical_expr(&expr, df_schema, self.execution_props())
}

/// Return the session ID
pub fn session_id(&self) -> &str {
&self.session_id
Expand Down Expand Up @@ -2040,6 +2114,35 @@ impl SessionState {
}
}

struct SessionSimplifyProvider<'a> {
state: &'a SessionState,
df_schema: &'a DFSchema,
}

impl<'a> SessionSimplifyProvider<'a> {
fn new(state: &'a SessionState, df_schema: &'a DFSchema) -> Self {
Self { state, df_schema }
}
}

impl<'a> SimplifyInfo for SessionSimplifyProvider<'a> {
fn is_boolean_type(&self, expr: &Expr) -> Result<bool> {
Ok(expr.get_type(self.df_schema)? == DataType::Boolean)
}

fn nullable(&self, expr: &Expr) -> Result<bool> {
expr.nullable(self.df_schema)
}

fn execution_props(&self) -> &ExecutionProps {
self.state.execution_props()
}

fn get_data_type(&self, expr: &Expr) -> Result<DataType> {
expr.get_type(self.df_schema)
}
}

struct SessionContextProvider<'a> {
state: &'a SessionState,
tables: HashMap<String, Arc<dyn TableSource>>,
Expand Down
3 changes: 3 additions & 0 deletions datafusion/core/tests/core_integration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@ mod dataframe;
/// Run all tests that are found in the `macro_hygiene` directory
mod macro_hygiene;

/// Run all tests that are found in the `expr_api` directory
mod expr_api;

#[cfg(test)]
#[ctor::ctor]
fn init() {
Expand Down
Loading

0 comments on commit c8b8c74

Please sign in to comment.