From a12007b72c778dbb70739b9e7384a61e93762360 Mon Sep 17 00:00:00 2001 From: Costi Ciudatu Date: Sat, 14 Dec 2024 00:05:19 +0200 Subject: [PATCH] [substrait] Add support for ExtensionTable --- datafusion/core/src/execution/context/mod.rs | 24 +-- datafusion/expr/src/registry.rs | 40 ++++- .../substrait/src/logical_plan/consumer.rs | 32 +++- .../substrait/src/logical_plan/producer.rs | 140 +++++++++++++++++- 4 files changed, 199 insertions(+), 37 deletions(-) diff --git a/datafusion/core/src/execution/context/mod.rs b/datafusion/core/src/execution/context/mod.rs index 67236c9a6bd2..e55e0805111e 100644 --- a/datafusion/core/src/execution/context/mod.rs +++ b/datafusion/core/src/execution/context/mod.rs @@ -63,7 +63,7 @@ use datafusion_expr::{ expr_rewriter::FunctionRewrite, logical_plan::{DdlStatement, Statement}, planner::ExprPlanner, - Expr, UserDefinedLogicalNode, WindowUDF, + Expr, WindowUDF, }; // backwards compatibility @@ -1679,27 +1679,7 @@ pub enum RegisterFunction { #[derive(Debug)] pub struct EmptySerializerRegistry; -impl SerializerRegistry for EmptySerializerRegistry { - fn serialize_logical_plan( - &self, - node: &dyn UserDefinedLogicalNode, - ) -> Result> { - not_impl_err!( - "Serializing user defined logical plan node `{}` is not supported", - node.name() - ) - } - - fn deserialize_logical_plan( - &self, - name: &str, - _bytes: &[u8], - ) -> Result> { - not_impl_err!( - "Deserializing user defined logical plan node `{name}` is not supported" - ) - } -} +impl SerializerRegistry for EmptySerializerRegistry {} /// Describes which SQL statements can be run. /// diff --git a/datafusion/expr/src/registry.rs b/datafusion/expr/src/registry.rs index 4eb49710bcf8..588181b14421 100644 --- a/datafusion/expr/src/registry.rs +++ b/datafusion/expr/src/registry.rs @@ -19,7 +19,7 @@ use crate::expr_rewriter::FunctionRewrite; use crate::planner::ExprPlanner; -use crate::{AggregateUDF, ScalarUDF, UserDefinedLogicalNode, WindowUDF}; +use crate::{AggregateUDF, ScalarUDF, TableSource, UserDefinedLogicalNode, WindowUDF}; use datafusion_common::{not_impl_err, plan_datafusion_err, HashMap, Result}; use std::collections::HashSet; use std::fmt::Debug; @@ -123,22 +123,52 @@ pub trait FunctionRegistry { } } -/// Serializer and deserializer registry for extensions like [UserDefinedLogicalNode]. +/// Serializer and deserializer registry for extensions like [UserDefinedLogicalNode] +/// and custom table providers for which the name alone is meaningless in the target +/// execution context, e.g. UDTFs, manually registered tables etc. pub trait SerializerRegistry: Debug + Send + Sync { /// Serialize this node to a byte array. This serialization should not include /// input plans. fn serialize_logical_plan( &self, node: &dyn UserDefinedLogicalNode, - ) -> Result>; + ) -> Result> { + not_impl_err!( + "Serializing user defined logical plan node `{}` is not supported", + node.name() + ) + } /// Deserialize user defined logical plan node ([UserDefinedLogicalNode]) from /// bytes. fn deserialize_logical_plan( &self, name: &str, - bytes: &[u8], - ) -> Result>; + _bytes: &[u8], + ) -> Result> { + not_impl_err!( + "Deserializing user defined logical plan node `{name}` is not supported" + ) + } + + /// Serialized table definition for UDTFs or manually registered table providers that can't be + /// marshaled by reference. Should return some benign error for regular tables that can be + /// found/restored by name in the destination execution context. + fn serialize_custom_table(&self, _table: &dyn TableSource) -> Result> { + not_impl_err!("No custom table support") + } + + /// Deserialize the custom table with the given name. + /// Note: more often than not, the name can't be used as a discriminator if multiple different + /// `TableSource` and/or `TableProvider` implementations are expected (this is particularly true + /// for UDTFs in DataFusion, which are always registered under the same name: `tmp_table`). + fn deserialize_custom_table( + &self, + name: &str, + _bytes: &[u8], + ) -> Result> { + not_impl_err!("Deserializing custom table `{name}` is not supported") + } } /// A [`FunctionRegistry`] that uses in memory [`HashMap`]s diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs index a9e411e35ae8..d1f12db33d78 100644 --- a/datafusion/substrait/src/logical_plan/consumer.rs +++ b/datafusion/substrait/src/logical_plan/consumer.rs @@ -31,7 +31,7 @@ use datafusion::logical_expr::expr::{Exists, InSubquery, Sort}; use datafusion::logical_expr::{ Aggregate, BinaryExpr, Case, EmptyRelation, Expr, ExprSchemable, LogicalPlan, - Operator, Projection, SortExpr, TryCast, Values, + Operator, Projection, SortExpr, TableScan, TryCast, Values, }; use substrait::proto::aggregate_rel::Grouping; use substrait::proto::expression::subquery::set_predicate::PredicateOp; @@ -994,8 +994,34 @@ pub async fn from_substrait_rel( ) .await } - _ => { - not_impl_err!("Unsupported ReadType: {:?}", &read.as_ref().read_type) + Some(ReadType::ExtensionTable(ext)) => { + if let Some(ext_detail) = &ext.detail { + let source = + state.serializer_registry().deserialize_custom_table( + &ext_detail.type_url, + &ext_detail.value, + )?; + let table_name = ext_detail + .type_url + .rsplit_once('/') + .map(|(_, name)| name) + .unwrap_or(&ext_detail.type_url); + let plan = LogicalPlan::TableScan(TableScan::try_new( + table_name, + source, + None, + vec![], + None, + )?); + let schema = apply_masking(substrait_schema, &read.projection)?; + ensure_schema_compatability(plan.schema(), schema.clone())?; + apply_projection(plan, schema) + } else { + substrait_err!("Unexpected empty detail in ExtensionTable") + } + } + None => { + substrait_err!("Unexpected empty read_type") } } } diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs index a128b90e6889..479f72fd0375 100644 --- a/datafusion/substrait/src/logical_plan/producer.rs +++ b/datafusion/substrait/src/logical_plan/producer.rs @@ -65,7 +65,7 @@ use substrait::proto::expression::literal::{ }; use substrait::proto::expression::subquery::InPredicate; use substrait::proto::expression::window_function::BoundsType; -use substrait::proto::read_rel::VirtualTable; +use substrait::proto::read_rel::{ExtensionTable, VirtualTable}; use substrait::proto::rel_common::EmitKind; use substrait::proto::rel_common::EmitKind::Emit; use substrait::proto::{ @@ -212,6 +212,23 @@ pub fn to_substrait_rel( let table_schema = scan.source.schema().to_dfschema_ref()?; let base_schema = to_substrait_named_struct(&table_schema)?; + let table = if let Ok(bytes) = state + .serializer_registry() + .serialize_custom_table(scan.source.as_ref()) + { + ReadType::ExtensionTable(ExtensionTable { + detail: Some(ProtoAny { + type_url: scan.table_name.to_string(), + value: bytes.into(), + }), + }) + } else { + ReadType::NamedTable(NamedTable { + names: scan.table_name.to_vec(), + advanced_extension: None, + }) + }; + Ok(Box::new(Rel { rel_type: Some(RelType::Read(Box::new(ReadRel { common: None, @@ -220,10 +237,7 @@ pub fn to_substrait_rel( best_effort_filter: None, projection, advanced_extension: None, - read_type: Some(ReadType::NamedTable(NamedTable { - names: scan.table_name.to_vec(), - advanced_extension: None, - })), + read_type: Some(table), }))), })) } @@ -2204,7 +2218,8 @@ mod test { use super::*; use crate::logical_plan::consumer::{ from_substrait_extended_expr, from_substrait_literal_without_names, - from_substrait_named_struct, from_substrait_type_without_names, + from_substrait_named_struct, from_substrait_plan, + from_substrait_type_without_names, }; use arrow_buffer::{IntervalDayTime, IntervalMonthDayNano}; use datafusion::arrow::array::{ @@ -2212,8 +2227,13 @@ mod test { }; use datafusion::arrow::datatypes::{Field, Fields, Schema}; use datafusion::common::scalar::ScalarStructBuilder; - use datafusion::common::DFSchema; + use datafusion::common::{assert_contains, DFSchema}; + use datafusion::datasource::empty::EmptyTable; + use datafusion::datasource::{DefaultTableSource, TableProvider}; + use datafusion::execution::registry::SerializerRegistry; use datafusion::execution::SessionStateBuilder; + use datafusion::logical_expr::TableSource; + use datafusion::prelude::SessionContext; #[test] fn round_trip_literals() -> Result<()> { @@ -2540,4 +2560,110 @@ mod test { assert!(matches!(err, Err(DataFusionError::SchemaError(_, _)))); } + + #[tokio::test] + async fn round_trip_extension_table() { + const TABLE_NAME: &str = "custom_table"; + const SERIALIZED: &[u8] = "table definition".as_bytes(); + + fn custom_table() -> Arc { + Arc::new(EmptyTable::new(Arc::new(Schema::new([ + Arc::new(Field::new("id", DataType::Int32, false)), + Arc::new(Field::new("name", DataType::Utf8, false)), + ])))) + } + + #[derive(Debug)] + struct Registry; + impl SerializerRegistry for Registry { + fn serialize_custom_table(&self, table: &dyn TableSource) -> Result> { + if table.schema() == custom_table().schema() { + Ok(SERIALIZED.to_vec()) + } else { + Err(DataFusionError::Internal("Not our table".into())) + } + } + fn deserialize_custom_table( + &self, + name: &str, + bytes: &[u8], + ) -> Result> { + if name == TABLE_NAME && bytes == SERIALIZED { + Ok(Arc::new(DefaultTableSource::new(custom_table()))) + } else { + panic!("Unexpected extension table: {name}"); + } + } + } + + async fn round_trip_logical_plans( + local: &SessionContext, + remote: &SessionContext, + ) -> Result<()> { + local.register_table(TABLE_NAME, custom_table())?; + remote.table_provider(TABLE_NAME).await.expect_err( + "The remote context is not supposed to know about custom_table", + ); + let initial_plan = local + .sql(&format!("select id from {TABLE_NAME}")) + .await? + .logical_plan() + .clone(); + + // write substrait locally + let substrait = to_substrait_plan(&initial_plan, &local.state())?; + + // read substrait remotely + // since we know there's no `custom_table` registered in the remote context, this will only succeed + // if our table got encoded as an ExtensionTable and is now decoded back to a table source. + let restored = from_substrait_plan(&remote.state(), &substrait).await?; + assert_contains!( + // confirm that the Substrait plan contains our custom_table as an ExtensionTable + serde_json::to_string(substrait.as_ref()).unwrap(), + format!(r#""extensionTable":{{"detail":{{"typeUrl":"{TABLE_NAME}","#) + ); + remote // make sure the restored plan is fully working in the remote context + .execute_logical_plan(restored.clone()) + .await? + .collect() + .await + .expect("Restored plan cannot be executed remotely"); + assert_eq!( + // check that the restored plan is functionally equivalent (and almost identical) to the initial one + initial_plan.to_string(), + restored.to_string().replace( + // substrait will add an explicit full-schema projection if the original table had none + &format!("TableScan: {TABLE_NAME} projection=[id, name]"), + &format!("TableScan: {TABLE_NAME}"), + ) + ); + Ok(()) + } + + // take 1 + let failed_attempt = + round_trip_logical_plans(&SessionContext::new(), &SessionContext::new()) + .await + .expect_err( + "The round trip should fail in the absence of a SerializerRegistry", + ); + assert_contains!( + failed_attempt.message(), + format!("No table named '{TABLE_NAME}'") + ); + + // take 2 + fn proper_context() -> SessionContext { + SessionContext::new_with_state( + SessionStateBuilder::new() + // This will transport our custom_table as a Substrait ExtensionTable + .with_serializer_registry(Arc::new(Registry)) + .build(), + ) + } + + round_trip_logical_plans(&proper_context(), &proper_context()) + .await + .expect("Local plan could not be restored remotely"); + } }