Skip to content

Commit

Permalink
feat: add extension for logical_plan_builder
Browse files Browse the repository at this point in the history
  • Loading branch information
zhuliquan committed Sep 2, 2024
1 parent 88dd305 commit e825bc1
Showing 1 changed file with 105 additions and 0 deletions.
105 changes: 105 additions & 0 deletions datafusion/expr/src/logical_plan/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
use std::any::Any;
use std::cmp::Ordering;
use std::collections::{HashMap, HashSet};
use std::ops::Deref;
use std::sync::Arc;

use crate::dml::CopyTo;
Expand Down Expand Up @@ -54,6 +55,8 @@ use datafusion_common::{
TableReference, ToDFSchema, UnnestOptions,
};

use super::{Extension, UserDefinedLogicalNode};

/// Default table name for unnamed table
pub const UNNAMED_TABLE: &str = "?table?";

Expand Down Expand Up @@ -1175,6 +1178,28 @@ impl LogicalPlanBuilder {
unnest_with_options(Arc::unwrap_or_clone(self.plan), columns, options)
.map(Self::new)
}

/// Apply a extension logical plan
/// arguemnts:
/// - other_exprs: other expressions exclude the extension node's exprs
/// - other_inputs: other inputs exclude self.plan
/// - extension: the extension node
pub fn extension(
self,
other_exprs: Vec<Expr>,
other_inputs: Vec<LogicalPlan>,
extension: Arc<dyn UserDefinedLogicalNode>,
) -> Result<Self> {
extension
.with_exprs_and_inputs(
vec![extension.expressions(), other_exprs].concat(),
vec![vec![self.plan.deref().clone()], other_inputs].concat(),
)
.map(|extension| {
LogicalPlan::Extension(Extension{node: extension})
})
.map(Self::new)
}
}

impl From<LogicalPlan> for LogicalPlanBuilder {
Expand Down Expand Up @@ -1659,6 +1684,7 @@ pub fn unnest_with_options(
mod tests {
use super::*;
use crate::logical_plan::StringifiedPlan;
use crate::UserDefinedLogicalNodeCore;
use crate::{col, expr, expr_fn::exists, in_subquery, lit, scalar_subquery};

use datafusion_common::SchemaError;
Expand Down Expand Up @@ -2158,4 +2184,83 @@ mod tests {

Ok(())
}

#[derive(PartialEq, Eq, Hash)]
struct TestExtensionNode {
input: LogicalPlan,
}

impl std::fmt::Debug for TestExtensionNode {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
UserDefinedLogicalNodeCore::fmt_for_explain(self, f)
}
}

impl UserDefinedLogicalNodeCore for TestExtensionNode {
fn name(&self) -> &str {
"TestExtensionNode"
}

fn inputs(&self) -> Vec<&LogicalPlan> {
vec![&self.input]
}

/// Schema for TopK is the same as the input
fn schema(&self) -> &DFSchemaRef {
self.input.schema()
}

fn expressions(&self) -> Vec<Expr> {
vec![]
}

fn fmt_for_explain(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
write!(f, "TestExtensionNode: {:?}", self.input
.schema()
.fields()
.iter()
.map(|f| f.name())
.collect::<Vec<&String>>()
)
}

fn with_exprs_and_inputs(
&self,
_exprs: Vec<Expr>,
mut inputs: Vec<LogicalPlan>,
) -> Result<Self> {
Ok(Self {
input: inputs.swap_remove(0),
})
}
}

#[test]
fn plan_builder_from_extension() {
let node = TestExtensionNode {
input: LogicalPlan::EmptyRelation(EmptyRelation{
produce_one_row: false,
schema: Arc::new(DFSchema::empty()),
}),
};

let plan = table_scan(
Some("test_table"),
&Schema::new(vec![
Field::new("c1", DataType::UInt32, false),
Field::new("c2", DataType::UInt32, false),
]),
None,
)
.unwrap()
.extension(vec![], vec![], Arc::new(node))
.unwrap()
.build()
.unwrap();

let expected = "\
TestExtensionNode: [\"c1\", \"c2\"]\
\n TableScan: test_table";
assert_eq!(expected, format!("{plan}"));
}
}

0 comments on commit e825bc1

Please sign in to comment.