diff --git a/datafusion/catalog/src/table.rs b/datafusion/catalog/src/table.rs index 6c36d907acc3..ca3a2bef882e 100644 --- a/datafusion/catalog/src/table.rs +++ b/datafusion/catalog/src/table.rs @@ -25,6 +25,7 @@ use arrow_schema::SchemaRef; use async_trait::async_trait; use datafusion_common::Result; use datafusion_common::{not_impl_err, Constraints, Statistics}; +use datafusion_expr::dml::InsertOp; use datafusion_expr::{ CreateExternalTable, Expr, LogicalPlan, TableProviderFilterPushDown, TableType, }; @@ -274,7 +275,7 @@ pub trait TableProvider: Debug + Sync + Send { &self, _state: &dyn Session, _input: Arc, - _overwrite: bool, + _insert_op: InsertOp, ) -> Result> { not_impl_err!("Insert into not implemented for this table") } diff --git a/datafusion/common/src/utils/mod.rs b/datafusion/common/src/utils/mod.rs index 116dab316bf5..5bf0f08b092a 100644 --- a/datafusion/common/src/utils/mod.rs +++ b/datafusion/common/src/utils/mod.rs @@ -291,6 +291,9 @@ pub(crate) fn parse_identifiers(s: &str) -> Result> { } /// Construct a new [`Vec`] of [`ArrayRef`] from the rows of the `arrays` at the `indices`. +/// +/// TODO: use implementation in arrow-rs when available: +/// pub fn take_arrays(arrays: &[ArrayRef], indices: &dyn Array) -> Result> { arrays .iter() diff --git a/datafusion/core/src/dataframe/mod.rs b/datafusion/core/src/dataframe/mod.rs index 72b763ce0f2b..70c507511453 100644 --- a/datafusion/core/src/dataframe/mod.rs +++ b/datafusion/core/src/dataframe/mod.rs @@ -52,6 +52,7 @@ use datafusion_common::config::{CsvOptions, JsonOptions}; use datafusion_common::{ plan_err, Column, DFSchema, DataFusionError, ParamValues, SchemaError, UnnestOptions, }; +use datafusion_expr::dml::InsertOp; use datafusion_expr::{case, is_null, lit, SortExpr}; use datafusion_expr::{ utils::COUNT_STAR_EXPANSION, TableProviderFilterPushDown, UNNAMED_TABLE, @@ -66,8 +67,9 @@ use datafusion_catalog::Session; /// Contains options that control how data is /// written out from a DataFrame pub struct DataFrameWriteOptions { - /// Controls if existing data should be overwritten - overwrite: bool, + /// Controls how new data should be written to the table, determining whether + /// to append, overwrite, or replace existing data. + insert_op: InsertOp, /// Controls if all partitions should be coalesced into a single output file /// Generally will have slower performance when set to true. single_file_output: bool, @@ -80,14 +82,15 @@ impl DataFrameWriteOptions { /// Create a new DataFrameWriteOptions with default values pub fn new() -> Self { DataFrameWriteOptions { - overwrite: false, + insert_op: InsertOp::Append, single_file_output: false, partition_by: vec![], } } - /// Set the overwrite option to true or false - pub fn with_overwrite(mut self, overwrite: bool) -> Self { - self.overwrite = overwrite; + + /// Set the insert operation + pub fn with_insert_operation(mut self, insert_op: InsertOp) -> Self { + self.insert_op = insert_op; self } @@ -1525,7 +1528,7 @@ impl DataFrame { self.plan, table_name.to_owned(), &arrow_schema, - write_options.overwrite, + write_options.insert_op, )? 
.build()?; @@ -1566,10 +1569,11 @@ impl DataFrame { options: DataFrameWriteOptions, writer_options: Option, ) -> Result, DataFusionError> { - if options.overwrite { - return Err(DataFusionError::NotImplemented( - "Overwrites are not implemented for DataFrame::write_csv.".to_owned(), - )); + if options.insert_op != InsertOp::Append { + return Err(DataFusionError::NotImplemented(format!( + "{} is not implemented for DataFrame::write_csv.", + options.insert_op + ))); } let format = if let Some(csv_opts) = writer_options { @@ -1626,10 +1630,11 @@ impl DataFrame { options: DataFrameWriteOptions, writer_options: Option, ) -> Result, DataFusionError> { - if options.overwrite { - return Err(DataFusionError::NotImplemented( - "Overwrites are not implemented for DataFrame::write_json.".to_owned(), - )); + if options.insert_op != InsertOp::Append { + return Err(DataFusionError::NotImplemented(format!( + "{} is not implemented for DataFrame::write_json.", + options.insert_op + ))); } let format = if let Some(json_opts) = writer_options { diff --git a/datafusion/core/src/dataframe/parquet.rs b/datafusion/core/src/dataframe/parquet.rs index 66974e37f453..f90b35fde6ba 100644 --- a/datafusion/core/src/dataframe/parquet.rs +++ b/datafusion/core/src/dataframe/parquet.rs @@ -26,6 +26,7 @@ use super::{ }; use datafusion_common::config::TableParquetOptions; +use datafusion_expr::dml::InsertOp; impl DataFrame { /// Execute the `DataFrame` and write the results to Parquet file(s). @@ -57,10 +58,11 @@ impl DataFrame { options: DataFrameWriteOptions, writer_options: Option, ) -> Result, DataFusionError> { - if options.overwrite { - return Err(DataFusionError::NotImplemented( - "Overwrites are not implemented for DataFrame::write_parquet.".to_owned(), - )); + if options.insert_op != InsertOp::Append { + return Err(DataFusionError::NotImplemented(format!( + "{} is not implemented for DataFrame::write_parquet.", + options.insert_op + ))); } let format = if let Some(parquet_opts) = writer_options { diff --git a/datafusion/core/src/datasource/file_format/arrow.rs b/datafusion/core/src/datasource/file_format/arrow.rs index 6ee4280956e8..c10ebbd6c9ea 100644 --- a/datafusion/core/src/datasource/file_format/arrow.rs +++ b/datafusion/core/src/datasource/file_format/arrow.rs @@ -47,6 +47,7 @@ use datafusion_common::{ not_impl_err, DataFusionError, GetExt, Statistics, DEFAULT_ARROW_EXTENSION, }; use datafusion_execution::{SendableRecordBatchStream, TaskContext}; +use datafusion_expr::dml::InsertOp; use datafusion_physical_expr::PhysicalExpr; use datafusion_physical_plan::insert::{DataSink, DataSinkExec}; use datafusion_physical_plan::metrics::MetricsSet; @@ -181,7 +182,7 @@ impl FileFormat for ArrowFormat { conf: FileSinkConfig, order_requirements: Option, ) -> Result> { - if conf.overwrite { + if conf.insert_op != InsertOp::Append { return not_impl_err!("Overwrites are not implemented yet for Arrow format"); } diff --git a/datafusion/core/src/datasource/file_format/csv.rs b/datafusion/core/src/datasource/file_format/csv.rs index 99e8f13776fc..e821fa806fce 100644 --- a/datafusion/core/src/datasource/file_format/csv.rs +++ b/datafusion/core/src/datasource/file_format/csv.rs @@ -46,6 +46,7 @@ use datafusion_common::{ exec_err, not_impl_err, DataFusionError, GetExt, DEFAULT_CSV_EXTENSION, }; use datafusion_execution::TaskContext; +use datafusion_expr::dml::InsertOp; use datafusion_physical_expr::PhysicalExpr; use datafusion_physical_plan::metrics::MetricsSet; @@ -382,7 +383,7 @@ impl FileFormat for CsvFormat { conf: 
FileSinkConfig, order_requirements: Option, ) -> Result> { - if conf.overwrite { + if conf.insert_op != InsertOp::Append { return not_impl_err!("Overwrites are not implemented yet for CSV"); } diff --git a/datafusion/core/src/datasource/file_format/json.rs b/datafusion/core/src/datasource/file_format/json.rs index 4471d7d6cb31..c9ed0c0d2805 100644 --- a/datafusion/core/src/datasource/file_format/json.rs +++ b/datafusion/core/src/datasource/file_format/json.rs @@ -46,6 +46,7 @@ use datafusion_common::config::{ConfigField, ConfigFileType, JsonOptions}; use datafusion_common::file_options::json_writer::JsonWriterOptions; use datafusion_common::{not_impl_err, GetExt, DEFAULT_JSON_EXTENSION}; use datafusion_execution::TaskContext; +use datafusion_expr::dml::InsertOp; use datafusion_physical_expr::PhysicalExpr; use datafusion_physical_plan::metrics::MetricsSet; use datafusion_physical_plan::ExecutionPlan; @@ -252,7 +253,7 @@ impl FileFormat for JsonFormat { conf: FileSinkConfig, order_requirements: Option, ) -> Result> { - if conf.overwrite { + if conf.insert_op != InsertOp::Append { return not_impl_err!("Overwrites are not implemented yet for Json"); } diff --git a/datafusion/core/src/datasource/file_format/parquet.rs b/datafusion/core/src/datasource/file_format/parquet.rs index 35296b0d7907..98ae0ce14bd7 100644 --- a/datafusion/core/src/datasource/file_format/parquet.rs +++ b/datafusion/core/src/datasource/file_format/parquet.rs @@ -53,6 +53,7 @@ use datafusion_common::{ use datafusion_common_runtime::SpawnedTask; use datafusion_execution::memory_pool::{MemoryConsumer, MemoryPool, MemoryReservation}; use datafusion_execution::TaskContext; +use datafusion_expr::dml::InsertOp; use datafusion_expr::Expr; use datafusion_functions_aggregate::min_max::{MaxAccumulator, MinAccumulator}; use datafusion_physical_expr::PhysicalExpr; @@ -403,7 +404,7 @@ impl FileFormat for ParquetFormat { conf: FileSinkConfig, order_requirements: Option, ) -> Result> { - if conf.overwrite { + if conf.insert_op != InsertOp::Append { return not_impl_err!("Overwrites are not implemented yet for Parquet"); } @@ -2269,7 +2270,7 @@ mod tests { table_paths: vec![ListingTableUrl::parse("file:///")?], output_schema: schema.clone(), table_partition_cols: vec![], - overwrite: true, + insert_op: InsertOp::Overwrite, keep_partition_by_columns: false, }; let parquet_sink = Arc::new(ParquetSink::new( @@ -2364,7 +2365,7 @@ mod tests { table_paths: vec![ListingTableUrl::parse("file:///")?], output_schema: schema.clone(), table_partition_cols: vec![("a".to_string(), DataType::Utf8)], // add partitioning - overwrite: true, + insert_op: InsertOp::Overwrite, keep_partition_by_columns: false, }; let parquet_sink = Arc::new(ParquetSink::new( @@ -2447,7 +2448,7 @@ mod tests { table_paths: vec![ListingTableUrl::parse("file:///")?], output_schema: schema.clone(), table_partition_cols: vec![], - overwrite: true, + insert_op: InsertOp::Overwrite, keep_partition_by_columns: false, }; let parquet_sink = Arc::new(ParquetSink::new( diff --git a/datafusion/core/src/datasource/listing/table.rs b/datafusion/core/src/datasource/listing/table.rs index 2a35fddeb033..3eb8eed9de36 100644 --- a/datafusion/core/src/datasource/listing/table.rs +++ b/datafusion/core/src/datasource/listing/table.rs @@ -34,6 +34,7 @@ use crate::datasource::{ use crate::execution::context::SessionState; use datafusion_catalog::TableProvider; use datafusion_common::{DataFusionError, Result}; +use datafusion_expr::dml::InsertOp; use datafusion_expr::{utils::conjunction, Expr, 
TableProviderFilterPushDown}; use datafusion_expr::{SortExpr, TableType}; use datafusion_physical_plan::{empty::EmptyExec, ExecutionPlan, Statistics}; @@ -916,7 +917,7 @@ impl TableProvider for ListingTable { &self, state: &dyn Session, input: Arc, - overwrite: bool, + insert_op: InsertOp, ) -> Result> { // Check that the schema of the plan matches the schema of this table. if !self @@ -975,7 +976,7 @@ impl TableProvider for ListingTable { file_groups, output_schema: self.schema(), table_partition_cols: self.options.table_partition_cols.clone(), - overwrite, + insert_op, keep_partition_by_columns, }; @@ -1990,7 +1991,8 @@ mod tests { // Therefore, we will have 8 partitions in the final plan. // Create an insert plan to insert the source data into the initial table let insert_into_table = - LogicalPlanBuilder::insert_into(scan_plan, "t", &schema, false)?.build()?; + LogicalPlanBuilder::insert_into(scan_plan, "t", &schema, InsertOp::Append)? + .build()?; // Create a physical plan from the insert plan let plan = session_ctx .state() diff --git a/datafusion/core/src/datasource/memory.rs b/datafusion/core/src/datasource/memory.rs index 70f3c36b81e1..24a4938e7b2b 100644 --- a/datafusion/core/src/datasource/memory.rs +++ b/datafusion/core/src/datasource/memory.rs @@ -39,6 +39,7 @@ use arrow::datatypes::SchemaRef; use arrow::record_batch::RecordBatch; use datafusion_common::{not_impl_err, plan_err, Constraints, DFSchema, SchemaExt}; use datafusion_execution::TaskContext; +use datafusion_expr::dml::InsertOp; use datafusion_physical_plan::metrics::MetricsSet; use async_trait::async_trait; @@ -262,7 +263,7 @@ impl TableProvider for MemTable { &self, _state: &dyn Session, input: Arc, - overwrite: bool, + insert_op: InsertOp, ) -> Result> { // If we are inserting into the table, any sort order may be messed up so reset it here *self.sort_order.lock() = vec![]; @@ -289,8 +290,8 @@ impl TableProvider for MemTable { .collect::>() ); } - if overwrite { - return not_impl_err!("Overwrite not implemented for MemoryTable yet"); + if insert_op != InsertOp::Append { + return not_impl_err!("{insert_op} not implemented for MemoryTable yet"); } let sink = Arc::new(MemSink::new(self.batches.clone())); Ok(Arc::new(DataSinkExec::new( @@ -638,7 +639,8 @@ mod tests { let scan_plan = LogicalPlanBuilder::scan("source", source, None)?.build()?; // Create an insert plan to insert the source data into the initial table let insert_into_table = - LogicalPlanBuilder::insert_into(scan_plan, "t", &schema, false)?.build()?; + LogicalPlanBuilder::insert_into(scan_plan, "t", &schema, InsertOp::Append)? 
+ .build()?; // Create a physical plan from the insert plan let plan = session_ctx .state() diff --git a/datafusion/core/src/datasource/physical_plan/mod.rs b/datafusion/core/src/datasource/physical_plan/mod.rs index 4018b3bb2920..6e8752ccfbf4 100644 --- a/datafusion/core/src/datasource/physical_plan/mod.rs +++ b/datafusion/core/src/datasource/physical_plan/mod.rs @@ -36,6 +36,7 @@ pub use self::parquet::{ParquetExec, ParquetFileMetrics, ParquetFileReaderFactor pub use arrow_file::ArrowExec; pub use avro::AvroExec; pub use csv::{CsvConfig, CsvExec, CsvExecBuilder, CsvOpener}; +use datafusion_expr::dml::InsertOp; pub use file_groups::FileGroupPartitioner; pub use file_scan_config::{ wrap_partition_type_in_dict, wrap_partition_value_in_dict, FileScanConfig, @@ -83,8 +84,9 @@ pub struct FileSinkConfig { /// A vector of column names and their corresponding data types, /// representing the partitioning columns for the file pub table_partition_cols: Vec<(String, DataType)>, - /// Controls whether existing data should be overwritten by this sink - pub overwrite: bool, + /// Controls how new data should be written to the file, determining whether + /// to append to, overwrite, or replace records in existing files. + pub insert_op: InsertOp, /// Controls whether partition columns are kept for the file pub keep_partition_by_columns: bool, } diff --git a/datafusion/core/src/datasource/stream.rs b/datafusion/core/src/datasource/stream.rs index d30247e2c67a..34023fbbb620 100644 --- a/datafusion/core/src/datasource/stream.rs +++ b/datafusion/core/src/datasource/stream.rs @@ -33,6 +33,7 @@ use arrow_schema::SchemaRef; use datafusion_common::{config_err, plan_err, Constraints, DataFusionError, Result}; use datafusion_common_runtime::SpawnedTask; use datafusion_execution::{SendableRecordBatchStream, TaskContext}; +use datafusion_expr::dml::InsertOp; use datafusion_expr::{CreateExternalTable, Expr, SortExpr, TableType}; use datafusion_physical_plan::insert::{DataSink, DataSinkExec}; use datafusion_physical_plan::metrics::MetricsSet; @@ -350,7 +351,7 @@ impl TableProvider for StreamTable { &self, _state: &dyn Session, input: Arc, - _overwrite: bool, + _insert_op: InsertOp, ) -> Result> { let ordering = match self.0.order.first() { Some(x) => { diff --git a/datafusion/core/src/execution/session_state.rs b/datafusion/core/src/execution/session_state.rs index cffb63f52047..4953eecd66e3 100644 --- a/datafusion/core/src/execution/session_state.rs +++ b/datafusion/core/src/execution/session_state.rs @@ -174,27 +174,30 @@ pub struct SessionState { } impl Debug for SessionState { + /// Prefer having short fields at the top and long vector fields near the end + /// Group fields by fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("SessionState") .field("session_id", &self.session_id) .field("config", &self.config) .field("runtime_env", &self.runtime_env) - .field("catalog_list", &"...") - .field("serializer_registry", &"...") + .field("catalog_list", &self.catalog_list) + .field("serializer_registry", &self.serializer_registry) + .field("file_formats", &self.file_formats) .field("execution_props", &self.execution_props) .field("table_options", &self.table_options) - .field("table_factories", &"...") - .field("function_factory", &"...") - .field("expr_planners", &"...") - .field("query_planner", &"...") - .field("analyzer", &"...") - .field("optimizer", &"...") - .field("physical_optimizers", &"...") - .field("table_functions", &"...") + .field("table_factories", 
&self.table_factories) + .field("function_factory", &self.function_factory) + .field("expr_planners", &self.expr_planners) + .field("query_planners", &self.query_planner) + .field("analyzer", &self.analyzer) + .field("optimizer", &self.optimizer) + .field("physical_optimizers", &self.physical_optimizers) + .field("table_functions", &self.table_functions) .field("scalar_functions", &self.scalar_functions) .field("aggregate_functions", &self.aggregate_functions) .field("window_functions", &self.window_functions) - .finish_non_exhaustive() + .finish() } } diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index b2b912d8add2..520392c9f075 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -71,7 +71,7 @@ use datafusion_common::{ exec_err, internal_datafusion_err, internal_err, not_impl_err, plan_err, DFSchema, ScalarValue, }; -use datafusion_expr::dml::CopyTo; +use datafusion_expr::dml::{CopyTo, InsertOp}; use datafusion_expr::expr::{ physical_name, AggregateFunction, Alias, GroupingSet, WindowFunction, }; @@ -529,7 +529,7 @@ impl DefaultPhysicalPlanner { file_groups: vec![], output_schema: Arc::new(schema), table_partition_cols, - overwrite: false, + insert_op: InsertOp::Append, keep_partition_by_columns, }; @@ -542,7 +542,7 @@ impl DefaultPhysicalPlanner { } LogicalPlan::Dml(DmlStatement { table_name, - op: WriteOp::InsertInto, + op: WriteOp::Insert(insert_op), .. }) => { let name = table_name.table(); @@ -550,23 +550,7 @@ impl DefaultPhysicalPlanner { if let Some(provider) = schema.table(name).await? { let input_exec = children.one()?; provider - .insert_into(session_state, input_exec, false) - .await? - } else { - return exec_err!("Table '{table_name}' does not exist"); - } - } - LogicalPlan::Dml(DmlStatement { - table_name, - op: WriteOp::InsertOverwrite, - .. - }) => { - let name = table_name.table(); - let schema = session_state.schema_for_ref(table_name.clone())?; - if let Some(provider) = schema.table(name).await? { - let input_exec = children.one()?; - provider - .insert_into(session_state, input_exec, true) + .insert_into(session_state, input_exec, *insert_op) .await? } else { return exec_err!("Table '{table_name}' does not exist"); diff --git a/datafusion/core/tests/user_defined/insert_operation.rs b/datafusion/core/tests/user_defined/insert_operation.rs new file mode 100644 index 000000000000..ff14fa0be3fb --- /dev/null +++ b/datafusion/core/tests/user_defined/insert_operation.rs @@ -0,0 +1,188 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use std::{any::Any, sync::Arc}; + +use arrow_schema::{DataType, Field, Schema, SchemaRef}; +use async_trait::async_trait; +use datafusion::{ + error::Result, + prelude::{SessionConfig, SessionContext}, +}; +use datafusion_catalog::{Session, TableProvider}; +use datafusion_expr::{dml::InsertOp, Expr, TableType}; +use datafusion_physical_expr::{EquivalenceProperties, Partitioning}; +use datafusion_physical_plan::{DisplayAs, ExecutionMode, ExecutionPlan, PlanProperties}; + +#[tokio::test] +async fn insert_operation_is_passed_correctly_to_table_provider() { + // Use the SQLite syntax so we can test the "INSERT OR REPLACE INTO" syntax + let ctx = session_ctx_with_dialect("SQLite"); + let table_provider = Arc::new(TestInsertTableProvider::new()); + ctx.register_table("testing", table_provider.clone()) + .unwrap(); + + let sql = "INSERT INTO testing (column) VALUES (1)"; + assert_insert_op(&ctx, sql, InsertOp::Append).await; + + let sql = "INSERT OVERWRITE testing (column) VALUES (1)"; + assert_insert_op(&ctx, sql, InsertOp::Overwrite).await; + + let sql = "REPLACE INTO testing (column) VALUES (1)"; + assert_insert_op(&ctx, sql, InsertOp::Replace).await; + + let sql = "INSERT OR REPLACE INTO testing (column) VALUES (1)"; + assert_insert_op(&ctx, sql, InsertOp::Replace).await; +} + +async fn assert_insert_op(ctx: &SessionContext, sql: &str, insert_op: InsertOp) { + let df = ctx.sql(sql).await.unwrap(); + let plan = df.create_physical_plan().await.unwrap(); + let exec = plan.as_any().downcast_ref::().unwrap(); + assert_eq!(exec.op, insert_op); +} + +fn session_ctx_with_dialect(dialect: impl Into) -> SessionContext { + let mut config = SessionConfig::new(); + let options = config.options_mut(); + options.sql_parser.dialect = dialect.into(); + SessionContext::new_with_config(config) +} + +#[derive(Debug)] +struct TestInsertTableProvider { + schema: SchemaRef, +} + +impl TestInsertTableProvider { + fn new() -> Self { + Self { + schema: SchemaRef::new(Schema::new(vec![Field::new( + "column", + DataType::Int64, + false, + )])), + } + } +} + +#[async_trait] +impl TableProvider for TestInsertTableProvider { + fn as_any(&self) -> &dyn Any { + self + } + + fn schema(&self) -> SchemaRef { + self.schema.clone() + } + + fn table_type(&self) -> TableType { + TableType::Base + } + + async fn scan( + &self, + _state: &dyn Session, + _projection: Option<&Vec>, + _filters: &[Expr], + _limit: Option, + ) -> Result> { + unimplemented!("TestInsertTableProvider is a stub for testing.") + } + + async fn insert_into( + &self, + _state: &dyn Session, + _input: Arc, + insert_op: InsertOp, + ) -> Result> { + Ok(Arc::new(TestInsertExec::new(insert_op))) + } +} + +#[derive(Debug)] +struct TestInsertExec { + op: InsertOp, + plan_properties: PlanProperties, +} + +impl TestInsertExec { + fn new(op: InsertOp) -> Self { + let eq_properties = EquivalenceProperties::new(make_count_schema()); + let plan_properties = PlanProperties::new( + eq_properties, + Partitioning::UnknownPartitioning(1), + ExecutionMode::Bounded, + ); + Self { + op, + plan_properties, + } + } +} + +impl DisplayAs for TestInsertExec { + fn fmt_as( + &self, + _t: datafusion_physical_plan::DisplayFormatType, + f: &mut std::fmt::Formatter, + ) -> std::fmt::Result { + write!(f, "TestInsertExec") + } +} + +impl ExecutionPlan for TestInsertExec { + fn name(&self) -> &str { + "TestInsertExec" + } + + fn as_any(&self) -> &dyn Any { + self + } + + fn properties(&self) -> &PlanProperties { + &self.plan_properties + } + + fn children(&self) -> Vec<&Arc> { + vec![] + } + 
+ fn with_new_children( + self: Arc, + children: Vec>, + ) -> Result> { + assert!(children.is_empty()); + Ok(self) + } + + fn execute( + &self, + _partition: usize, + _context: Arc, + ) -> Result { + unimplemented!("TestInsertExec is a stub for testing.") + } +} + +fn make_count_schema() -> SchemaRef { + Arc::new(Schema::new(vec![Field::new( + "count", + DataType::UInt64, + false, + )])) +} diff --git a/datafusion/core/tests/user_defined/mod.rs b/datafusion/core/tests/user_defined/mod.rs index 56cec8df468b..5d84cdb69283 100644 --- a/datafusion/core/tests/user_defined/mod.rs +++ b/datafusion/core/tests/user_defined/mod.rs @@ -32,3 +32,6 @@ mod user_defined_table_functions; /// Tests for Expression Planner mod expr_planner; + +/// Tests for insert operations +mod insert_operation; diff --git a/datafusion/execution/src/stream.rs b/datafusion/execution/src/stream.rs index 7fc5e458b86b..f3eb7b77e03c 100644 --- a/datafusion/execution/src/stream.rs +++ b/datafusion/execution/src/stream.rs @@ -20,7 +20,9 @@ use datafusion_common::Result; use futures::Stream; use std::pin::Pin; -/// Trait for types that stream [arrow::record_batch::RecordBatch] +/// Trait for types that stream [RecordBatch] +/// +/// See [`SendableRecordBatchStream`] for more details. pub trait RecordBatchStream: Stream> { /// Returns the schema of this `RecordBatchStream`. /// @@ -29,5 +31,23 @@ pub trait RecordBatchStream: Stream> { fn schema(&self) -> SchemaRef; } -/// Trait for a [`Stream`] of [`RecordBatch`]es +/// Trait for a [`Stream`] of [`RecordBatch`]es that can be passed between threads +/// +/// This trait is used to retrieve the results of DataFusion execution plan nodes. +/// +/// The trait is a specialized Rust Async [`Stream`] that also knows the schema +/// of the data it will return (even if the stream has no data). Every +/// `RecordBatch` returned by the stream should have the same schema as returned +/// by [`schema`](`RecordBatchStream::schema`). +/// +/// # Error Handling +/// +/// Once a stream returns an error, it should not be polled again (the caller +/// should stop calling `next`) and handle the error. +/// +/// However, returning `Ready(None)` (end of stream) is likely the safest +/// behavior after an error. Like [`Stream`]s, `RecordBatchStream`s should not +/// be polled after end of stream or returning an error. However, also like +/// [`Stream`]s there is no mechanism to prevent callers polling so returning +/// `Ready(None)` is recommended. 
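The expanded `RecordBatchStream` / `SendableRecordBatchStream` docs above describe how execution results are pulled and how an error should end a stream. A minimal consumption sketch under those rules (the `collect_rows` helper is illustrative only and not part of this PR):

```rust
use datafusion_common::Result;
use datafusion_execution::SendableRecordBatchStream;
use futures::StreamExt;

/// Drain a stream of record batches, stopping at the first error as the
/// documentation above recommends.
async fn collect_rows(mut stream: SendableRecordBatchStream) -> Result<usize> {
    // The schema is known up front, even if the stream yields no batches.
    let _schema = stream.schema();
    let mut rows = 0;
    while let Some(batch) = stream.next().await {
        // Propagate the first error and do not poll the stream again.
        rows += batch?.num_rows();
    }
    Ok(rows)
}
```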
pub type SendableRecordBatchStream = Pin>; diff --git a/datafusion/expr/src/lib.rs b/datafusion/expr/src/lib.rs index 260065f69af9..32eac90c3eec 100644 --- a/datafusion/expr/src/lib.rs +++ b/datafusion/expr/src/lib.rs @@ -92,7 +92,7 @@ pub use sqlparser; pub use table_source::{TableProviderFilterPushDown, TableSource, TableType}; pub use udaf::{AggregateUDF, AggregateUDFImpl, ReversedUDAF}; pub use udf::{ScalarUDF, ScalarUDFImpl}; -pub use udwf::{WindowUDF, WindowUDFImpl}; +pub use udwf::{ReversedUDWF, WindowUDF, WindowUDFImpl}; pub use window_frame::{WindowFrame, WindowFrameBound, WindowFrameUnits}; #[cfg(test)] diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs index ad96f6a85d0e..cc8ddf8ec8e8 100644 --- a/datafusion/expr/src/logical_plan/builder.rs +++ b/datafusion/expr/src/logical_plan/builder.rs @@ -54,6 +54,7 @@ use datafusion_common::{ TableReference, ToDFSchema, UnnestOptions, }; +use super::dml::InsertOp; use super::plan::{ColumnUnnestList, ColumnUnnestType}; /// Default table name for unnamed table @@ -307,20 +308,14 @@ impl LogicalPlanBuilder { input: LogicalPlan, table_name: impl Into, table_schema: &Schema, - overwrite: bool, + insert_op: InsertOp, ) -> Result { let table_schema = table_schema.clone().to_dfschema_ref()?; - let op = if overwrite { - WriteOp::InsertOverwrite - } else { - WriteOp::InsertInto - }; - Ok(Self::new(LogicalPlan::Dml(DmlStatement::new( table_name.into(), table_schema, - op, + WriteOp::Insert(insert_op), Arc::new(input), )))) } diff --git a/datafusion/expr/src/logical_plan/dml.rs b/datafusion/expr/src/logical_plan/dml.rs index c2ed9dc0781c..68b3ac41fa08 100644 --- a/datafusion/expr/src/logical_plan/dml.rs +++ b/datafusion/expr/src/logical_plan/dml.rs @@ -146,8 +146,7 @@ impl PartialOrd for DmlStatement { #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)] pub enum WriteOp { - InsertOverwrite, - InsertInto, + Insert(InsertOp), Delete, Update, Ctas, @@ -157,8 +156,7 @@ impl WriteOp { /// Return a descriptive name of this [`WriteOp`] pub fn name(&self) -> &str { match self { - WriteOp::InsertOverwrite => "Insert Overwrite", - WriteOp::InsertInto => "Insert Into", + WriteOp::Insert(insert) => insert.name(), WriteOp::Delete => "Delete", WriteOp::Update => "Update", WriteOp::Ctas => "Ctas", @@ -172,6 +170,37 @@ impl Display for WriteOp { } } +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Hash)] +pub enum InsertOp { + /// Appends new rows to the existing table without modifying any + /// existing rows. This corresponds to the SQL `INSERT INTO` query. + Append, + /// Overwrites all existing rows in the table with the new rows. + /// This corresponds to the SQL `INSERT OVERWRITE` query. + Overwrite, + /// If any existing rows collides with the inserted rows (typically based + /// on a unique key or primary key), those existing rows are replaced. + /// This corresponds to the SQL `REPLACE INTO` query and its equivalents. 
+ Replace, +} + +impl InsertOp { + /// Return a descriptive name of this [`InsertOp`] + pub fn name(&self) -> &str { + match self { + InsertOp::Append => "Insert Into", + InsertOp::Overwrite => "Insert Overwrite", + InsertOp::Replace => "Replace Into", + } + } +} + +impl Display for InsertOp { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.name()) + } +} + fn make_count_schema() -> DFSchemaRef { Arc::new( Schema::new(vec![Field::new("count", DataType::UInt64, false)]) diff --git a/datafusion/expr/src/udwf.rs b/datafusion/expr/src/udwf.rs index 7cc57523a14d..678a0b62cd9a 100644 --- a/datafusion/expr/src/udwf.rs +++ b/datafusion/expr/src/udwf.rs @@ -172,6 +172,14 @@ impl WindowUDF { pub fn coerce_types(&self, arg_types: &[DataType]) -> Result> { self.inner.coerce_types(arg_types) } + + /// Returns the reversed user-defined window function when the + /// order of evaluation is reversed. + /// + /// See [`WindowUDFImpl::reverse_expr`] for more details. + pub fn reverse_expr(&self) -> ReversedUDWF { + self.inner.reverse_expr() + } } impl From for WindowUDF @@ -351,6 +359,24 @@ pub trait WindowUDFImpl: Debug + Send + Sync { fn coerce_types(&self, _arg_types: &[DataType]) -> Result> { not_impl_err!("Function {} does not implement coerce_types", self.name()) } + + /// Allows customizing the behavior of the user-defined window + /// function when it is evaluated in reverse order. + fn reverse_expr(&self) -> ReversedUDWF { + ReversedUDWF::NotSupported + } +} + +pub enum ReversedUDWF { + /// The result of evaluating the user-defined window function + /// remains identical when reversed. + Identical, + /// A window function which does not support evaluating the result + /// in reverse order. + NotSupported, + /// Customize the user-defined window function for evaluating the + /// result in reverse order. + Reversed(Arc), } impl PartialEq for dyn WindowUDFImpl { diff --git a/datafusion/physical-expr/src/expressions/binary/kernels.rs b/datafusion/physical-expr/src/expressions/binary/kernels.rs index 1f9cfed1a44f..c0685c6decde 100644 --- a/datafusion/physical-expr/src/expressions/binary/kernels.rs +++ b/datafusion/physical-expr/src/expressions/binary/kernels.rs @@ -24,7 +24,7 @@ use arrow::compute::kernels::bitwise::{ bitwise_xor, bitwise_xor_scalar, }; use arrow::datatypes::DataType; -use datafusion_common::internal_err; +use datafusion_common::plan_err; use datafusion_common::{Result, ScalarValue}; use arrow_schema::ArrowError; @@ -70,7 +70,7 @@ macro_rules! create_dyn_kernel { DataType::UInt64 => { call_bitwise_kernel!(left, right, $KERNEL, UInt64Array) } - other => internal_err!( + other => plan_err!( "Data type {:?} not supported for binary operation '{}' on dyn arrays", other, stringify!($KERNEL) @@ -116,7 +116,7 @@ macro_rules! 
create_dyn_scalar_kernel { DataType::UInt16 => call_bitwise_scalar_kernel!(array, scalar, $KERNEL, UInt16Array, u16), DataType::UInt32 => call_bitwise_scalar_kernel!(array, scalar, $KERNEL, UInt32Array, u32), DataType::UInt64 => call_bitwise_scalar_kernel!(array, scalar, $KERNEL, UInt64Array, u64), - other => internal_err!( + other => plan_err!( "Data type {:?} not supported for binary operation '{}' on dyn arrays", other, stringify!($KERNEL) diff --git a/datafusion/physical-plan/src/aggregates/group_values/column.rs b/datafusion/physical-plan/src/aggregates/group_values/column.rs index 1565c483c24c..91d87302ce99 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/column.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/column.rs @@ -37,7 +37,9 @@ use datafusion_physical_expr::binary_map::OutputType; use hashbrown::raw::RawTable; -/// Compare GroupValue Rows column by column +/// A [`GroupValues`] that stores multiple columns of group values. +/// +/// pub struct GroupValuesColumn { /// The output schema schema: SchemaRef, @@ -56,8 +58,13 @@ pub struct GroupValuesColumn { map_size: usize, /// The actual group by values, stored column-wise. Compare from - /// the left to right, each column is stored as `ArrayRowEq`. - /// This is shown faster than the row format + /// the left to right, each column is stored as [`GroupColumn`]. + /// + /// Performance tests showed that this design is faster than using the + /// more general purpose [`GroupValuesRows`]. See the ticket for details: + /// + /// + /// [`GroupValuesRows`]: crate::aggregates::group_values::row::GroupValuesRows group_values: Vec>, /// reused buffer to store hashes diff --git a/datafusion/physical-plan/src/aggregates/group_values/group_column.rs b/datafusion/physical-plan/src/aggregates/group_values/group_column.rs index bde871836258..7409f5c214b9 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/group_column.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/group_column.rs @@ -37,12 +37,13 @@ use datafusion_physical_expr_common::binary_map::{OutputType, INITIAL_BUFFER_CAP use std::sync::Arc; use std::vec; -/// Trait for group values column-wise row comparison +/// Trait for storing a single column of group values in [`GroupValuesColumn`] /// -/// Implementations of this trait store a in-progress collection of group values +/// Implementations of this trait store an in-progress collection of group values /// (similar to various builders in Arrow-rs) that allow for quick comparison to /// incoming rows. 
/// +/// [`GroupValuesColumn`]: crate::aggregates::group_values::GroupValuesColumn pub trait GroupColumn: Send + Sync { /// Returns equal if the row stored in this builder at `lhs_row` is equal to /// the row in `array` at `rhs_row` @@ -62,7 +63,7 @@ pub trait GroupColumn: Send + Sync { fn take_n(&mut self, n: usize) -> ArrayRef; } -/// Stores a collection of primitive group values which are known to have no nulls +/// An implementation of [`GroupColumn`] for primitive values which are known to have no nulls #[derive(Debug)] pub struct NonNullPrimitiveGroupValueBuilder { group_values: Vec, @@ -120,7 +121,7 @@ impl GroupColumn for NonNullPrimitiveGroupValueBuilder } } -/// Stores a collection of primitive group values which may have nulls +/// An implementation of [`GroupColumn`] for primitive values which may have nulls #[derive(Debug)] pub struct PrimitiveGroupValueBuilder { group_values: Vec, @@ -188,13 +189,14 @@ impl GroupColumn for PrimitiveGroupValueBuilder { } } +/// An implementation of [`GroupColumn`] for binary and utf8 types. pub struct ByteGroupValueBuilder where O: OffsetSizeTrait, { output_type: OutputType, buffer: BufferBuilder, - /// Offsets into `buffer` for each distinct value. These offsets as used + /// Offsets into `buffer` for each distinct value. These offsets as used /// directly to create the final `GenericBinaryArray`. The `i`th string is /// stored in the range `offsets[i]..offsets[i+1]` in `buffer`. Null values /// are stored as a zero length string. diff --git a/datafusion/physical-plan/src/aggregates/group_values/mod.rs b/datafusion/physical-plan/src/aggregates/group_values/mod.rs index bc05e8a40516..fb7b66775092 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/mod.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/mod.rs @@ -15,6 +15,8 @@ // specific language governing permissions and limitations // under the License. +//! [`GroupValues`] trait for storing and interning group keys + use arrow::record_batch::RecordBatch; use arrow_array::{downcast_primitive, ArrayRef}; use arrow_schema::{DataType, SchemaRef}; @@ -38,18 +40,61 @@ use datafusion_physical_expr::binary_map::OutputType; mod group_column; mod null_builder; -/// An interning store for group keys +/// Stores the group values during hash aggregation. +/// +/// # Background +/// +/// In a query such as `SELECT a, b, count(*) FROM t GROUP BY a, b`, the group values +/// identify each group, and correspond to all the distinct values of `(a,b)`. +/// +/// ```sql +/// -- Input has 4 rows with 3 distinct combinations of (a,b) ("groups") +/// create table t(a int, b varchar) +/// as values (1, 'a'), (2, 'b'), (1, 'a'), (3, 'c'); +/// +/// select a, b, count(*) from t group by a, b; +/// ---- +/// 1 a 2 +/// 2 b 1 +/// 3 c 1 +/// ``` +/// +/// # Design +/// +/// Managing group values is a performance critical operation in hash +/// aggregation. The major operations are: +/// +/// 1. Intern: Quickly finding existing and adding new group values +/// 2. Emit: Returning the group values as an array +/// +/// There are multiple specialized implementations of this trait optimized for +/// different data types and number of columns, optimized for these operations. +/// See [`new_group_values`] for details. +/// +/// # Group Ids +/// +/// Each distinct group in a hash aggregation is identified by a unique group id +/// (usize) which is assigned by instances of this trait. Group ids are +/// continuous without gaps, starting from 0. 
pub trait GroupValues: Send { - /// Calculates the `groups` for each input row of `cols` + /// Calculates the group id for each input row of `cols`, assigning new + /// group ids as necessary. + /// + /// When the function returns, `groups` must contain the group id for each + /// row in `cols`. + /// + /// If a row has the same value as a previous row, the same group id is + /// assigned. If a row has a new value, the next available group id is + /// assigned. fn intern(&mut self, cols: &[ArrayRef], groups: &mut Vec) -> Result<()>; - /// Returns the number of bytes used by this [`GroupValues`] + /// Returns the number of bytes of memory used by this [`GroupValues`] fn size(&self) -> usize; /// Returns true if this [`GroupValues`] is empty fn is_empty(&self) -> bool; - /// The number of values stored in this [`GroupValues`] + /// The number of values (distinct group values) stored in this [`GroupValues`] fn len(&self) -> usize; /// Emits the group values @@ -59,6 +104,7 @@ pub trait GroupValues: Send { fn clear_shrink(&mut self, batch: &RecordBatch); } +/// Return a specialized implementation of [`GroupValues`] for the given schema. pub fn new_group_values(schema: SchemaRef) -> Result> { if schema.fields.len() == 1 { let d = schema.fields[0].data_type(); diff --git a/datafusion/physical-plan/src/aggregates/group_values/row.rs b/datafusion/physical-plan/src/aggregates/group_values/row.rs index b252d0008784..8ca88257bf1a 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/row.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/row.rs @@ -30,6 +30,13 @@ use hashbrown::raw::RawTable; use std::sync::Arc; /// A [`GroupValues`] making use of [`Rows`] +/// +/// This is a general implementation of [`GroupValues`] that works for any +/// combination of data types and number of columns, including nested types such as +/// structs and lists. +/// +/// It uses the arrow-rs [`Rows`] to store the group values, which is a row-wise +/// representation. 
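As a concrete illustration of the interning contract documented above (equal keys get the same group id, new keys get the next id, ids are dense starting at 0), here is a toy hash-map based sketch. It is not the actual `GroupValues` implementation, which operates on Arrow arrays column-wise:

```rust
use std::collections::HashMap;

/// Toy illustration of the `intern` contract: ids are dense, start at 0,
/// and repeat for duplicate keys.
fn intern_keys(keys: &[&str]) -> Vec<usize> {
    let mut ids: HashMap<String, usize> = HashMap::new();
    keys.iter()
        .map(|k| {
            let next = ids.len();
            *ids.entry((*k).to_string()).or_insert(next)
        })
        .collect()
}

#[test]
fn dense_group_ids() {
    // Mirrors the `(a, b)` example above: 4 input rows, 3 distinct groups.
    assert_eq!(intern_keys(&["1a", "2b", "1a", "3c"]), vec![0, 1, 0, 2]);
}
```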
pub struct GroupValuesRows { /// The output schema schema: SchemaRef, @@ -220,7 +227,8 @@ impl GroupValues for GroupValuesRows { } }; - // TODO: Materialize dictionaries in group keys (#7647) + // TODO: Materialize dictionaries in group keys + // https://github.com/apache/datafusion/issues/7647 for (field, array) in self.schema.fields.iter().zip(&mut output) { let expected = field.data_type(); *array = dictionary_encode_if_necessary( diff --git a/datafusion/physical-plan/src/aggregates/row_hash.rs b/datafusion/physical-plan/src/aggregates/row_hash.rs index d4dbdf0f029d..a043905765ec 100644 --- a/datafusion/physical-plan/src/aggregates/row_hash.rs +++ b/datafusion/physical-plan/src/aggregates/row_hash.rs @@ -609,14 +609,11 @@ impl Stream for GroupedHashAggregateStream { match &self.exec_state { ExecutionState::ReadingInput => 'reading_input: { match ready!(self.input.poll_next_unpin(cx)) { - // new batch to aggregate - Some(Ok(batch)) => { + // New batch to aggregate in partial aggregation operator + Some(Ok(batch)) if self.mode == AggregateMode::Partial => { let timer = elapsed_compute.timer(); let input_rows = batch.num_rows(); - // Make sure we have enough capacity for `batch`, otherwise spill - extract_ok!(self.spill_previous_if_necessary(&batch)); - // Do the grouping extract_ok!(self.group_aggregate_batch(batch)); @@ -649,10 +646,49 @@ impl Stream for GroupedHashAggregateStream { timer.done(); } + + // New batch to aggregate in terminal aggregation operator + // (Final/FinalPartitioned/Single/SinglePartitioned) + Some(Ok(batch)) => { + let timer = elapsed_compute.timer(); + + // Make sure we have enough capacity for `batch`, otherwise spill + extract_ok!(self.spill_previous_if_necessary(&batch)); + + // Do the grouping + extract_ok!(self.group_aggregate_batch(batch)); + + // If we can begin emitting rows, do so, + // otherwise keep consuming input + assert!(!self.input_done); + + // If the number of group values equals or exceeds the soft limit, + // emit all groups and switch to producing output + if self.hit_soft_group_limit() { + timer.done(); + extract_ok!(self.set_input_done_and_produce_output()); + // make sure the exec_state just set is not overwritten below + break 'reading_input; + } + + if let Some(to_emit) = self.group_ordering.emit_to() { + let batch = extract_ok!(self.emit(to_emit, false)); + self.exec_state = ExecutionState::ProducingOutput(batch); + timer.done(); + // make sure the exec_state just set is not overwritten below + break 'reading_input; + } + + timer.done(); + } + + // Found error from input stream Some(Err(e)) => { // inner had error, return to caller return Poll::Ready(Some(Err(e))); } + + // Found end from input stream None => { // inner is done, emit all rows and switch to producing output extract_ok!(self.set_input_done_and_produce_output()); @@ -691,7 +727,12 @@ impl Stream for GroupedHashAggregateStream { ( if self.input_done { ExecutionState::Done - } else if self.should_skip_aggregation() { + } + // In Partial aggregation, we also need to check + // if we should trigger partial skipping + else if self.mode == AggregateMode::Partial + && self.should_skip_aggregation() + { ExecutionState::SkippingAggregation } else { ExecutionState::ReadingInput @@ -879,10 +920,10 @@ impl GroupedHashAggregateStream { if self.group_values.len() > 0 && batch.num_rows() > 0 && matches!(self.group_ordering, GroupOrdering::None) - && !matches!(self.mode, AggregateMode::Partial) && !self.spill_state.is_stream_merging && self.update_memory_reservation().is_err() { + 
assert_ne!(self.mode, AggregateMode::Partial); // Use input batch (Partial mode) schema for spilling because // the spilled data will be merged and re-evaluated later. self.spill_state.spill_schema = batch.schema(); @@ -927,9 +968,9 @@ impl GroupedHashAggregateStream { fn emit_early_if_necessary(&mut self) -> Result<()> { if self.group_values.len() >= self.batch_size && matches!(self.group_ordering, GroupOrdering::None) - && matches!(self.mode, AggregateMode::Partial) && self.update_memory_reservation().is_err() { + assert_eq!(self.mode, AggregateMode::Partial); let n = self.group_values.len() / self.batch_size * self.batch_size; let batch = self.emit(EmitTo::First(n), false)?; self.exec_state = ExecutionState::ProducingOutput(batch); @@ -1002,6 +1043,8 @@ impl GroupedHashAggregateStream { } /// Updates skip aggregation probe state. + /// + /// Notice: It should only be called in Partial aggregation fn update_skip_aggregation_probe(&mut self, input_rows: usize) { if let Some(probe) = self.skip_aggregation_probe.as_mut() { // Skip aggregation probe is not supported if stream has any spills, @@ -1013,6 +1056,8 @@ impl GroupedHashAggregateStream { /// In case the probe indicates that aggregation may be /// skipped, forces stream to produce currently accumulated output. + /// + /// Notice: It should only be called in Partial aggregation fn switch_to_skip_aggregation(&mut self) -> Result<()> { if let Some(probe) = self.skip_aggregation_probe.as_mut() { if probe.should_skip() { @@ -1026,6 +1071,8 @@ impl GroupedHashAggregateStream { /// Returns true if the aggregation probe indicates that aggregation /// should be skipped. + /// + /// Notice: It should only be called in Partial aggregation fn should_skip_aggregation(&self) -> bool { self.skip_aggregation_probe .as_ref() diff --git a/datafusion/physical-plan/src/execution_plan.rs b/datafusion/physical-plan/src/execution_plan.rs index 542861688dfe..b14021f4a99b 100644 --- a/datafusion/physical-plan/src/execution_plan.rs +++ b/datafusion/physical-plan/src/execution_plan.rs @@ -228,6 +228,16 @@ pub trait ExecutionPlan: Debug + DisplayAs + Send + Sync { /// [`TryStreamExt`]: futures::stream::TryStreamExt /// [`RecordBatchStreamAdapter`]: crate::stream::RecordBatchStreamAdapter /// + /// # Error handling + /// + /// Any error that occurs during execution is sent as an `Err` in the output + /// stream. + /// + /// `ExecutionPlan` implementations in DataFusion cancel additional work + /// immediately once an error occurs. The rationale is that if the overall + /// query will return an error, any additional work such as continued + /// polling of inputs will be wasted as it will be thrown away. + /// /// # Cancellation / Aborting Execution /// /// The [`Stream`] that is returned must ensure that any allocated resources diff --git a/datafusion/physical-plan/src/repartition/mod.rs b/datafusion/physical-plan/src/repartition/mod.rs index 5b25d582d20c..4fd364cca4d0 100644 --- a/datafusion/physical-plan/src/repartition/mod.rs +++ b/datafusion/physical-plan/src/repartition/mod.rs @@ -377,6 +377,11 @@ impl BatchPartitioner { /// `───────' `───────' ///``` /// +/// # Error Handling +/// +/// If any of the input partitions return an error, the error is propagated to +/// all output partitions and inputs are not polled again. 
+/// /// # Output Ordering /// /// If more than one stream is being repartitioned, the output will be some diff --git a/datafusion/physical-plan/src/sorts/merge.rs b/datafusion/physical-plan/src/sorts/merge.rs index 875922ac34b5..e0644e3d99e5 100644 --- a/datafusion/physical-plan/src/sorts/merge.rs +++ b/datafusion/physical-plan/src/sorts/merge.rs @@ -39,6 +39,7 @@ use futures::Stream; /// A fallible [`PartitionedStream`] of [`Cursor`] and [`RecordBatch`] type CursorStream = Box>>; +/// Merges a stream of sorted cursors and record batches into a single sorted stream #[derive(Debug)] pub(crate) struct SortPreservingMergeStream { in_progress: BatchBuilder, diff --git a/datafusion/physical-plan/src/sorts/sort_preserving_merge.rs b/datafusion/physical-plan/src/sorts/sort_preserving_merge.rs index f83bb58d08dd..b00a11a5355f 100644 --- a/datafusion/physical-plan/src/sorts/sort_preserving_merge.rs +++ b/datafusion/physical-plan/src/sorts/sort_preserving_merge.rs @@ -65,6 +65,11 @@ use log::{debug, trace}; /// Input Streams Output stream /// (sorted) (sorted) /// ``` +/// +/// # Error Handling +/// +/// If any of the input partitions return an error, the error is propagated to +/// the output and inputs are not polled again. #[derive(Debug)] pub struct SortPreservingMergeExec { /// Input plan diff --git a/datafusion/physical-plan/src/windows/mod.rs b/datafusion/physical-plan/src/windows/mod.rs index 6e1cb8db5f09..b6f34ec69f68 100644 --- a/datafusion/physical-plan/src/windows/mod.rs +++ b/datafusion/physical-plan/src/windows/mod.rs @@ -34,8 +34,8 @@ use datafusion_common::{ exec_datafusion_err, exec_err, DataFusionError, Result, ScalarValue, }; use datafusion_expr::{ - BuiltInWindowFunction, PartitionEvaluator, WindowFrame, WindowFunctionDefinition, - WindowUDF, + BuiltInWindowFunction, PartitionEvaluator, ReversedUDWF, WindowFrame, + WindowFunctionDefinition, WindowUDF, }; use datafusion_physical_expr::aggregate::{AggregateExprBuilder, AggregateFunctionExpr}; use datafusion_physical_expr::equivalence::collapse_lex_req; @@ -130,7 +130,7 @@ pub fn create_window_expr( } // TODO: Ordering not supported for Window UDFs yet WindowFunctionDefinition::WindowUDF(fun) => Arc::new(BuiltInWindowExpr::new( - create_udwf_window_expr(fun, args, input_schema, name)?, + create_udwf_window_expr(fun, args, input_schema, name, ignore_nulls)?, partition_by, order_by, window_frame, @@ -329,6 +329,7 @@ fn create_udwf_window_expr( args: &[Arc], input_schema: &Schema, name: String, + ignore_nulls: bool, ) -> Result> { // need to get the types into an owned vec for some reason let input_types: Vec<_> = args @@ -341,6 +342,8 @@ fn create_udwf_window_expr( args: args.to_vec(), input_types, name, + is_reversed: false, + ignore_nulls, })) } @@ -353,6 +356,12 @@ struct WindowUDFExpr { name: String, /// Types of input expressions input_types: Vec, + /// This is set to `true` only if the user-defined window function + /// expression supports evaluation in reverse order, and the + /// evaluation order is reversed. + is_reversed: bool, + /// Set to `true` if `IGNORE NULLS` is defined, `false` otherwise. 
+ ignore_nulls: bool, } impl BuiltInWindowFunctionExpr for WindowUDFExpr { @@ -378,7 +387,18 @@ impl BuiltInWindowFunctionExpr for WindowUDFExpr { } fn reverse_expr(&self) -> Option> { - None + match self.fun.reverse_expr() { + ReversedUDWF::Identical => Some(Arc::new(self.clone())), + ReversedUDWF::NotSupported => None, + ReversedUDWF::Reversed(fun) => Some(Arc::new(WindowUDFExpr { + fun, + args: self.args.clone(), + name: self.name.clone(), + input_types: self.input_types.clone(), + is_reversed: !self.is_reversed, + ignore_nulls: self.ignore_nulls, + })), + } } fn get_result_ordering(&self, schema: &SchemaRef) -> Option { diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index 1204c843fdb1..e36c91e7d004 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -731,14 +731,21 @@ message PartitionColumn { message FileSinkConfig { reserved 6; // writer_mode + reserved 8; // was `overwrite` which has been superseded by `insert_op` string object_store_url = 1; repeated PartitionedFile file_groups = 2; repeated string table_paths = 3; datafusion_common.Schema output_schema = 4; repeated PartitionColumn table_partition_cols = 5; - bool overwrite = 8; bool keep_partition_by_columns = 9; + InsertOp insert_op = 10; +} + +enum InsertOp { + Append = 0; + Overwrite = 1; + Replace = 2; } message JsonSink { diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index 0614e33b7a4b..004798b3ba93 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -5832,10 +5832,10 @@ impl serde::Serialize for FileSinkConfig { if !self.table_partition_cols.is_empty() { len += 1; } - if self.overwrite { + if self.keep_partition_by_columns { len += 1; } - if self.keep_partition_by_columns { + if self.insert_op != 0 { len += 1; } let mut struct_ser = serializer.serialize_struct("datafusion.FileSinkConfig", len)?; @@ -5854,12 +5854,14 @@ impl serde::Serialize for FileSinkConfig { if !self.table_partition_cols.is_empty() { struct_ser.serialize_field("tablePartitionCols", &self.table_partition_cols)?; } - if self.overwrite { - struct_ser.serialize_field("overwrite", &self.overwrite)?; - } if self.keep_partition_by_columns { struct_ser.serialize_field("keepPartitionByColumns", &self.keep_partition_by_columns)?; } + if self.insert_op != 0 { + let v = InsertOp::try_from(self.insert_op) + .map_err(|_| serde::ser::Error::custom(format!("Invalid variant {}", self.insert_op)))?; + struct_ser.serialize_field("insertOp", &v)?; + } struct_ser.end() } } @@ -5880,9 +5882,10 @@ impl<'de> serde::Deserialize<'de> for FileSinkConfig { "outputSchema", "table_partition_cols", "tablePartitionCols", - "overwrite", "keep_partition_by_columns", "keepPartitionByColumns", + "insert_op", + "insertOp", ]; #[allow(clippy::enum_variant_names)] @@ -5892,8 +5895,8 @@ impl<'de> serde::Deserialize<'de> for FileSinkConfig { TablePaths, OutputSchema, TablePartitionCols, - Overwrite, KeepPartitionByColumns, + InsertOp, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -5920,8 +5923,8 @@ impl<'de> serde::Deserialize<'de> for FileSinkConfig { "tablePaths" | "table_paths" => Ok(GeneratedField::TablePaths), "outputSchema" | "output_schema" => Ok(GeneratedField::OutputSchema), "tablePartitionCols" | "table_partition_cols" => Ok(GeneratedField::TablePartitionCols), - "overwrite" => Ok(GeneratedField::Overwrite), 
"keepPartitionByColumns" | "keep_partition_by_columns" => Ok(GeneratedField::KeepPartitionByColumns), + "insertOp" | "insert_op" => Ok(GeneratedField::InsertOp), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -5946,8 +5949,8 @@ impl<'de> serde::Deserialize<'de> for FileSinkConfig { let mut table_paths__ = None; let mut output_schema__ = None; let mut table_partition_cols__ = None; - let mut overwrite__ = None; let mut keep_partition_by_columns__ = None; + let mut insert_op__ = None; while let Some(k) = map_.next_key()? { match k { GeneratedField::ObjectStoreUrl => { @@ -5980,18 +5983,18 @@ impl<'de> serde::Deserialize<'de> for FileSinkConfig { } table_partition_cols__ = Some(map_.next_value()?); } - GeneratedField::Overwrite => { - if overwrite__.is_some() { - return Err(serde::de::Error::duplicate_field("overwrite")); - } - overwrite__ = Some(map_.next_value()?); - } GeneratedField::KeepPartitionByColumns => { if keep_partition_by_columns__.is_some() { return Err(serde::de::Error::duplicate_field("keepPartitionByColumns")); } keep_partition_by_columns__ = Some(map_.next_value()?); } + GeneratedField::InsertOp => { + if insert_op__.is_some() { + return Err(serde::de::Error::duplicate_field("insertOp")); + } + insert_op__ = Some(map_.next_value::()? as i32); + } } } Ok(FileSinkConfig { @@ -6000,8 +6003,8 @@ impl<'de> serde::Deserialize<'de> for FileSinkConfig { table_paths: table_paths__.unwrap_or_default(), output_schema: output_schema__, table_partition_cols: table_partition_cols__.unwrap_or_default(), - overwrite: overwrite__.unwrap_or_default(), keep_partition_by_columns: keep_partition_by_columns__.unwrap_or_default(), + insert_op: insert_op__.unwrap_or_default(), }) } } @@ -7198,6 +7201,80 @@ impl<'de> serde::Deserialize<'de> for InListNode { deserializer.deserialize_struct("datafusion.InListNode", FIELDS, GeneratedVisitor) } } +impl serde::Serialize for InsertOp { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + let variant = match self { + Self::Append => "Append", + Self::Overwrite => "Overwrite", + Self::Replace => "Replace", + }; + serializer.serialize_str(variant) + } +} +impl<'de> serde::Deserialize<'de> for InsertOp { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "Append", + "Overwrite", + "Replace", + ]; + + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = InsertOp; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + fn visit_i64(self, v: i64) -> std::result::Result + where + E: serde::de::Error, + { + i32::try_from(v) + .ok() + .and_then(|x| x.try_into().ok()) + .ok_or_else(|| { + serde::de::Error::invalid_value(serde::de::Unexpected::Signed(v), &self) + }) + } + + fn visit_u64(self, v: u64) -> std::result::Result + where + E: serde::de::Error, + { + i32::try_from(v) + .ok() + .and_then(|x| x.try_into().ok()) + .ok_or_else(|| { + serde::de::Error::invalid_value(serde::de::Unexpected::Unsigned(v), &self) + }) + } + + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "Append" => Ok(InsertOp::Append), + "Overwrite" => Ok(InsertOp::Overwrite), + "Replace" => Ok(InsertOp::Replace), + _ => Err(serde::de::Error::unknown_variant(value, FIELDS)), + } + } + } + 
deserializer.deserialize_any(GeneratedVisitor) + } +} impl serde::Serialize for InterleaveExecNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index 21d88e565e80..436347330d92 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -1067,10 +1067,10 @@ pub struct FileSinkConfig { pub output_schema: ::core::option::Option, #[prost(message, repeated, tag = "5")] pub table_partition_cols: ::prost::alloc::vec::Vec, - #[prost(bool, tag = "8")] - pub overwrite: bool, #[prost(bool, tag = "9")] pub keep_partition_by_columns: bool, + #[prost(enumeration = "InsertOp", tag = "10")] + pub insert_op: i32, } #[derive(Clone, PartialEq, ::prost::Message)] pub struct JsonSink { @@ -1954,6 +1954,35 @@ impl DateUnit { } #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] #[repr(i32)] +pub enum InsertOp { + Append = 0, + Overwrite = 1, + Replace = 2, +} +impl InsertOp { + /// String value of the enum field names used in the ProtoBuf definition. + /// + /// The values are not transformed in any way and thus are considered stable + /// (if the ProtoBuf definition does not change) and safe for programmatic use. + pub fn as_str_name(&self) -> &'static str { + match self { + Self::Append => "Append", + Self::Overwrite => "Overwrite", + Self::Replace => "Replace", + } + } + /// Creates an enum from field names used in the ProtoBuf definition. + pub fn from_str_name(value: &str) -> ::core::option::Option { + match value { + "Append" => Some(Self::Append), + "Overwrite" => Some(Self::Overwrite), + "Replace" => Some(Self::Replace), + _ => None, + } + } +} +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] +#[repr(i32)] pub enum PartitionMode { CollectLeft = 0, Partitioned = 1, diff --git a/datafusion/proto/src/physical_plan/from_proto.rs b/datafusion/proto/src/physical_plan/from_proto.rs index b2f92f4b2ee4..20ec5eeaeaf8 100644 --- a/datafusion/proto/src/physical_plan/from_proto.rs +++ b/datafusion/proto/src/physical_plan/from_proto.rs @@ -21,6 +21,7 @@ use std::sync::Arc; use arrow::compute::SortOptions; use chrono::{TimeZone, Utc}; +use datafusion_expr::dml::InsertOp; use object_store::path::Path; use object_store::ObjectMeta; @@ -640,13 +641,18 @@ impl TryFrom<&protobuf::FileSinkConfig> for FileSinkConfig { Ok((name.clone(), data_type)) }) .collect::>>()?; + let insert_op = match conf.insert_op() { + protobuf::InsertOp::Append => InsertOp::Append, + protobuf::InsertOp::Overwrite => InsertOp::Overwrite, + protobuf::InsertOp::Replace => InsertOp::Replace, + }; Ok(Self { object_store_url: ObjectStoreUrl::parse(&conf.object_store_url)?, file_groups, table_paths, output_schema: Arc::new(convert_required!(conf.output_schema)?), table_partition_cols, - overwrite: conf.overwrite, + insert_op, keep_partition_by_columns: conf.keep_partition_by_columns, }) } diff --git a/datafusion/proto/src/physical_plan/to_proto.rs b/datafusion/proto/src/physical_plan/to_proto.rs index 6981c77228a8..6f6065a1c284 100644 --- a/datafusion/proto/src/physical_plan/to_proto.rs +++ b/datafusion/proto/src/physical_plan/to_proto.rs @@ -642,8 +642,8 @@ impl TryFrom<&FileSinkConfig> for protobuf::FileSinkConfig { table_paths, output_schema: Some(conf.output_schema.as_ref().try_into()?), table_partition_cols, - overwrite: conf.overwrite, keep_partition_by_columns: conf.keep_partition_by_columns, + 
insert_op: conf.insert_op as i32, }) } } diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs index db84a08e5b40..025676f790a8 100644 --- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs @@ -27,6 +27,7 @@ use arrow::csv::WriterBuilder; use arrow::datatypes::{Fields, TimeUnit}; use datafusion::physical_expr::aggregate::AggregateExprBuilder; use datafusion::physical_plan::coalesce_batches::CoalesceBatchesExec; +use datafusion_expr::dml::InsertOp; use datafusion_functions_aggregate::approx_percentile_cont::approx_percentile_cont_udaf; use datafusion_functions_aggregate::array_agg::array_agg_udaf; use datafusion_functions_aggregate::min_max::max_udaf; @@ -1143,7 +1144,7 @@ fn roundtrip_json_sink() -> Result<()> { table_paths: vec![ListingTableUrl::parse("file:///")?], output_schema: schema.clone(), table_partition_cols: vec![("plan_type".to_string(), DataType::Utf8)], - overwrite: true, + insert_op: InsertOp::Overwrite, keep_partition_by_columns: true, }; let data_sink = Arc::new(JsonSink::new( @@ -1179,7 +1180,7 @@ fn roundtrip_csv_sink() -> Result<()> { table_paths: vec![ListingTableUrl::parse("file:///")?], output_schema: schema.clone(), table_partition_cols: vec![("plan_type".to_string(), DataType::Utf8)], - overwrite: true, + insert_op: InsertOp::Overwrite, keep_partition_by_columns: true, }; let data_sink = Arc::new(CsvSink::new( @@ -1238,7 +1239,7 @@ fn roundtrip_parquet_sink() -> Result<()> { table_paths: vec![ListingTableUrl::parse("file:///")?], output_schema: schema.clone(), table_partition_cols: vec![("plan_type".to_string(), DataType::Utf8)], - overwrite: true, + insert_op: InsertOp::Overwrite, keep_partition_by_columns: true, }; let data_sink = Arc::new(ParquetSink::new( diff --git a/datafusion/sql/src/statement.rs b/datafusion/sql/src/statement.rs index 29dfe25993f1..895285c59737 100644 --- a/datafusion/sql/src/statement.rs +++ b/datafusion/sql/src/statement.rs @@ -37,7 +37,7 @@ use datafusion_common::{ DataFusionError, Result, ScalarValue, SchemaError, SchemaReference, TableReference, ToDFSchema, }; -use datafusion_expr::dml::CopyTo; +use datafusion_expr::dml::{CopyTo, InsertOp}; use datafusion_expr::expr_rewriter::normalize_col_with_schemas_and_ambiguity_check; use datafusion_expr::logical_plan::builder::project; use datafusion_expr::logical_plan::DdlStatement; @@ -53,7 +53,7 @@ use datafusion_expr::{ TransactionConclusion, TransactionEnd, TransactionIsolationLevel, TransactionStart, Volatility, WriteOp, }; -use sqlparser::ast; +use sqlparser::ast::{self, SqliteOnConflict}; use sqlparser::ast::{ Assignment, AssignmentTarget, ColumnDef, CreateIndex, CreateTable, CreateTableOptions, Delete, DescribeAlias, Expr as SQLExpr, FromTable, Ident, Insert, @@ -665,12 +665,15 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { returning, ignore, table_alias, - replace_into, + mut replace_into, priority, insert_alias, }) => { - if or.is_some() { - plan_err!("Inserts with or clauses not supported")?; + if let Some(or) = or { + match or { + SqliteOnConflict::Replace => replace_into = true, + _ => plan_err!("Inserts with {or} clause is not supported")?, + } } if partitioned.is_some() { plan_err!("Partitioned inserts not yet supported")?; @@ -698,9 +701,6 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { "Inserts with a table alias not supported: {table_alias:?}" )? 
}; - if replace_into { - plan_err!("Inserts with a `REPLACE INTO` clause not supported")? - }; if let Some(priority) = priority { plan_err!( "Inserts with a `PRIORITY` clause not supported: {priority:?}"
@@ -710,7 +710,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { plan_err!("Inserts with an alias not supported")?; } let _ = into; // optional keyword doesn't change behavior - self.insert_to_plan(table_name, columns, source, overwrite) + self.insert_to_plan(table_name, columns, source, overwrite, replace_into) } Statement::Update { table,
@@ -1605,6 +1605,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { columns: Vec<Ident>, source: Box<Query>, overwrite: bool, + replace_into: bool, ) -> Result<LogicalPlan> { // Do a table lookup to verify the table exists let table_name = self.object_name_to_table_reference(table_name)?;
@@ -1707,16 +1708,17 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { .collect::<Result<Vec<Expr>>>()?; let source = project(source, exprs)?; - let op = if overwrite { - WriteOp::InsertOverwrite - } else { - WriteOp::InsertInto + let insert_op = match (overwrite, replace_into) { + (false, false) => InsertOp::Append, + (true, false) => InsertOp::Overwrite, + (false, true) => InsertOp::Replace, + (true, true) => plan_err!("Conflicting insert operations: `overwrite` and `replace_into` cannot both be true")?, }; let plan = LogicalPlan::Dml(DmlStatement::new( table_name, Arc::new(table_schema), - op, + WriteOp::Insert(insert_op), Arc::new(source), )); Ok(plan)
diff --git a/docs/source/user-guide/introduction.md b/docs/source/user-guide/introduction.md index 8f8983061eb6..7c975055d152 100644 --- a/docs/source/user-guide/introduction.md +++ b/docs/source/user-guide/introduction.md
@@ -96,6 +96,7 @@ Here are some active projects using DataFusion: - [Arroyo](https://github.com/ArroyoSystems/arroyo) Distributed stream processing engine in Rust - [Ballista](https://github.com/apache/datafusion-ballista) Distributed SQL Query Engine +- [Blaze](https://github.com/kwai/blaze) The Blaze accelerator for Apache Spark leverages native vectorized execution to accelerate query processing - [CnosDB](https://github.com/cnosdb/cnosdb) Open Source Distributed Time Series Database - [Comet](https://github.com/apache/datafusion-comet) Apache Spark native query execution plugin - [Cube Store](https://github.com/cube-js/cube.js/tree/master/rust)
@@ -124,7 +125,6 @@ Here are some active projects using DataFusion: Here are some less active projects that used DataFusion: - [bdt](https://github.com/datafusion-contrib/bdt) Boring Data Tool -- [Blaze](https://github.com/blaze-init/blaze) Spark accelerator with DataFusion at its core - [Cloudfuse Buzz](https://github.com/cloudfuse-io/buzz-rust) - [datafusion-tui](https://github.com/datafusion-contrib/datafusion-tui) Text UI for DataFusion - [Flock](https://github.com/flock-lab/flock)
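
For context, a minimal standalone sketch of the flag-to-`InsertOp` mapping that `insert_to_plan` now performs is shown below. The helper name `sql_flags_to_insert_op` is hypothetical and exists only for illustration; `InsertOp`, `Result`, and `plan_err!` are the DataFusion items used in the patch above.

use datafusion_common::{plan_err, Result};
use datafusion_expr::dml::InsertOp;

// Hypothetical helper (not part of the patch): restates the
// `match (overwrite, replace_into)` added to insert_to_plan in
// datafusion/sql/src/statement.rs as a standalone function.
fn sql_flags_to_insert_op(overwrite: bool, replace_into: bool) -> Result<InsertOp> {
    match (overwrite, replace_into) {
        (false, false) => Ok(InsertOp::Append),   // plain INSERT INTO
        (true, false) => Ok(InsertOp::Overwrite), // INSERT OVERWRITE
        (false, true) => Ok(InsertOp::Replace),   // REPLACE INTO / INSERT OR REPLACE
        (true, true) => plan_err!(
            "Conflicting insert operations: `overwrite` and `replace_into` cannot both be true"
        ),
    }
}

Both `REPLACE INTO` and `INSERT OR REPLACE` collapse to `InsertOp::Replace`, while `INSERT OVERWRITE` keeps its own variant, so consumers of `WriteOp::Insert` can distinguish all three semantics.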