From cd2b661d0301a399c9737dc5dbc126d7fc223ea4 Mon Sep 17 00:00:00 2001 From: zhuliquan Date: Mon, 2 Sep 2024 21:00:39 +0800 Subject: [PATCH] feat: add extension for logical_plan_builder --- datafusion/expr/src/logical_plan/builder.rs | 106 ++++++++++++++++++++ 1 file changed, 106 insertions(+) diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs index fc961b83f7b5..72d7a0d078dd 100644 --- a/datafusion/expr/src/logical_plan/builder.rs +++ b/datafusion/expr/src/logical_plan/builder.rs @@ -20,6 +20,7 @@ use std::any::Any; use std::cmp::Ordering; use std::collections::{HashMap, HashSet}; +use std::ops::Deref; use std::sync::Arc; use crate::dml::CopyTo; @@ -54,6 +55,8 @@ use datafusion_common::{ TableReference, ToDFSchema, UnnestOptions, }; +use super::{Extension, UserDefinedLogicalNode}; + /// Default table name for unnamed table pub const UNNAMED_TABLE: &str = "?table?"; @@ -1175,6 +1178,26 @@ impl LogicalPlanBuilder { unnest_with_options(Arc::unwrap_or_clone(self.plan), columns, options) .map(Self::new) } + + /// Apply an extension logical plan + /// Arguments: + /// - other_exprs: additional expressions, passed after the extension node's own exprs + /// - other_inputs: additional inputs, passed after self.plan + /// - extension: the extension node + pub fn extension( + self, + other_exprs: Vec, + other_inputs: Vec, + extension: Arc, + ) -> Result { + extension + .with_exprs_and_inputs( + [extension.expressions(), other_exprs].concat(), + [vec![self.plan.deref().clone()], other_inputs].concat(), + ) + .map(|extension| LogicalPlan::Extension(Extension { node: extension })) + .map(Self::new) + } } impl From for LogicalPlanBuilder { @@ -1659,6 +1682,7 @@ pub fn unnest_with_options( mod tests { use super::*; use crate::logical_plan::StringifiedPlan; + use crate::UserDefinedLogicalNodeCore; use crate::{col, expr, expr_fn::exists, in_subquery, lit, scalar_subquery}; use datafusion_common::SchemaError; @@ -2158,4 +2182,86 @@ mod tests { Ok(()) } + + 
#[derive(PartialEq, Eq, Hash)] + struct TestExtensionNode { + input: LogicalPlan, + } + + impl std::fmt::Debug for TestExtensionNode { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + UserDefinedLogicalNodeCore::fmt_for_explain(self, f) + } + } + + impl UserDefinedLogicalNodeCore for TestExtensionNode { + fn name(&self) -> &str { + "TestExtensionNode" + } + + fn inputs(&self) -> Vec<&LogicalPlan> { + vec![&self.input] + } + + /// Schema for TestExtensionNode is the same as the input + fn schema(&self) -> &DFSchemaRef { + self.input.schema() + } + + fn expressions(&self) -> Vec { + vec![] + } + + fn fmt_for_explain(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!( + f, + "TestExtensionNode: {:?}", + self.input + .schema() + .fields() + .iter() + .map(|f| f.name()) + .collect::>() + ) + } + + fn with_exprs_and_inputs( + &self, + _exprs: Vec, + mut inputs: Vec, + ) -> Result { + Ok(Self { + input: inputs.swap_remove(0), + }) + } + } + + #[test] + fn plan_builder_from_extension() { + let node = TestExtensionNode { + input: LogicalPlan::EmptyRelation(EmptyRelation { + produce_one_row: false, + schema: Arc::new(DFSchema::empty()), + }), + }; + + let plan = table_scan( + Some("test_table"), + &Schema::new(vec![ + Field::new("c1", DataType::UInt32, false), + Field::new("c2", DataType::UInt32, false), + ]), + None, + ) + .unwrap() + .extension(vec![], vec![], Arc::new(node)) + .unwrap() + .build() + .unwrap(); + + let expected = "\ + TestExtensionNode: [\"c1\", \"c2\"]\ + \n TableScan: test_table"; + assert_eq!(expected, format!("{plan}")); + } }