From 8ee1dfe3b0aa0d74b98e442fda777819fadc6036 Mon Sep 17 00:00:00 2001
From: jefffffyang
Date: Thu, 26 Sep 2024 18:49:58 +0800
Subject: [PATCH 1/2] Repartition/Sort/SortPreservingMerge will propagate input
 errors and continue execution

---
 .../physical-plan/src/repartition/mod.rs      |  36 +++--
 datafusion/physical-plan/src/sorts/merge.rs   |  16 ++-
 datafusion/physical-plan/src/sorts/sort.rs    | 135 +++++++++++++++---
 .../src/sorts/sort_preserving_merge.rs        |  95 +++++++++++-
 4 files changed, 243 insertions(+), 39 deletions(-)

diff --git a/datafusion/physical-plan/src/repartition/mod.rs b/datafusion/physical-plan/src/repartition/mod.rs
index 5a3fcb5029e1..cb6271a86956 100644
--- a/datafusion/physical-plan/src/repartition/mod.rs
+++ b/datafusion/physical-plan/src/repartition/mod.rs
@@ -792,7 +792,21 @@ impl RepartitionExec {
         // Input is done
         let batch = match result {
-            Some(result) => result?,
+            Some(Ok(result)) => result,
+            Some(Err(e)) => {
+                // The input task failed; propagate the error to all output partitions
+                let e = Arc::new(e);
+
+                for (tx, _) in output_channels.values() {
+                    // Wrap the shared error so a copy can be sent to every output partition
+                    let err =
+                        Err(DataFusionError::External(Box::new(Arc::clone(&e))));
+                    tx.send(Some(err)).await.ok();
+                }
+
+                // Keep pulling batches from the input
+                continue;
+            }
             None => break,
         };

@@ -1236,23 +1250,25 @@ mod tests {
         let err = exec_err!("bad data error");

         let schema = batch.schema();
-        let input = MockExec::new(vec![Ok(batch), err], schema);
+        let input = MockExec::new(vec![err, Ok(batch.clone())], schema);
         let partitioning = Partitioning::RoundRobinBatch(1);
         let exec = RepartitionExec::try_new(Arc::new(input), partitioning).unwrap();

         // Note: this should pass (the stream can be created) but the
         // error when the input is executed should get passed back
-        let output_stream = exec.execute(0, task_ctx).unwrap();
+        let mut output_stream = exec.execute(0, task_ctx).unwrap();

-        // Expect that an error is returned
-        let result_string = crate::common::collect(output_stream)
-            .await
-            .unwrap_err()
-            .to_string();
+        // Ensure that repartition keeps polling the input even after an error
+        let error_string = output_stream.next().await.unwrap().unwrap_err().to_string();
         assert!(
-            result_string.contains("bad data error"),
-            "actual: {result_string}"
+            error_string.contains("bad data error"),
+            "actual: {error_string}"
         );
+
+        let result = output_stream.next().await.unwrap().unwrap();
+        assert_eq!(result, batch);
+
+        assert!(output_stream.next().await.is_none());
     }

     #[tokio::test]
diff --git a/datafusion/physical-plan/src/sorts/merge.rs b/datafusion/physical-plan/src/sorts/merge.rs
index 85418ff36119..9a426f976f27 100644
--- a/datafusion/physical-plan/src/sorts/merge.rs
+++ b/datafusion/physical-plan/src/sorts/merge.rs
@@ -155,14 +155,15 @@ impl SortPreservingMergeStream {
             return Poll::Ready(None);
         }
         // try to initialize the loser tree
-        if self.loser_tree.is_empty() {
-            // Ensure all non-exhausted streams have a cursor from which
-            // rows can be pulled
-            for i in 0..self.streams.partitions() {
+        if self.loser_tree.len() != self.streams.partitions() {
+            // The loser tree is not fully constructed yet, so continue building
+            // it. This way, the i-th stream keeps getting polled until it yields
+            // its first `Ok(batch)`
+            for i in self.loser_tree.len()..self.streams.partitions() {
                 if let Err(e) = ready!(self.maybe_poll_stream(cx, i)) {
-                    self.aborted = true;
                     return Poll::Ready(Some(Err(e)));
                 }
+                self.loser_tree.push(usize::MAX);
             }
             self.init_loser_tree();
         }
@@ -177,7 +178,8 @@ impl SortPreservingMergeStream {
             if !self.loser_tree_adjusted {
                 let winner = self.loser_tree[0];
                 if let Err(e) = ready!(self.maybe_poll_stream(cx, winner)) {
-                    self.aborted = true;
+                    // Propagate the input error; the next poll call will retry the
+                    // same stream
                     return Poll::Ready(Some(Err(e)));
                 }
                 self.update_loser_tree();
@@ -275,7 +277,7 @@ impl SortPreservingMergeStream {
     /// non exhausted input, if possible
     fn init_loser_tree(&mut self) {
         // Init loser tree
-        self.loser_tree = vec![usize::MAX; self.cursors.len()];
+        assert_eq!(self.loser_tree.len(), self.cursors.len());
         for i in 0..self.cursors.len() {
             let mut winner = i;
             let mut cmp_node = self.lt_leaf_node_index(i);
diff --git a/datafusion/physical-plan/src/sorts/sort.rs b/datafusion/physical-plan/src/sorts/sort.rs
index a81b09948cca..9b2659a63bb5 100644
--- a/datafusion/physical-plan/src/sorts/sort.rs
+++ b/datafusion/physical-plan/src/sorts/sort.rs
@@ -52,7 +52,7 @@ use datafusion_execution::runtime_env::RuntimeEnv;
 use datafusion_execution::TaskContext;
 use datafusion_physical_expr::LexOrdering;

-use futures::{StreamExt, TryStreamExt};
+use futures::{Future, StreamExt, TryStreamExt};
 use log::{debug, trace};

 struct ExternalSorterMetrics {
@@ -868,14 +868,14 @@ impl ExecutionPlan for SortExec {
     ) -> Result<SendableRecordBatchStream> {
         trace!("Start SortExec::execute for partition {} of context session_id {} and task_id {:?}", partition, context.session_id(), context.task_id());

-        let mut input = self.input.execute(partition, Arc::clone(&context))?;
+        let input = self.input.execute(partition, Arc::clone(&context))?;

         let execution_options = &context.session_config().options().execution;

         trace!("End SortExec's input.execute for partition: {}", partition);

         if let Some(fetch) = self.fetch.as_ref() {
-            let mut topk = TopK::try_new(
+            let topk = TopK::try_new(
                 partition,
                 input.schema(),
                 self.expr.clone(),
@@ -888,17 +888,14 @@ impl ExecutionPlan for SortExec {

             Ok(Box::pin(RecordBatchStreamAdapter::new(
                 self.schema(),
-                futures::stream::once(async move {
-                    while let Some(batch) = input.next().await {
-                        let batch = batch?;
-                        topk.insert_batch(batch)?;
-                    }
-                    topk.emit()
-                })
-                .try_flatten(),
+                futures::stream::unfold(
+                    SortStreamState::new(input, topk),
+                    SortStreamState::poll_next,
+                )
+                .fuse(),
             )))
         } else {
-            let mut sorter = ExternalSorter::new(
+            let sorter = ExternalSorter::new(
                 partition,
                 input.schema(),
                 self.expr.clone(),
@@ -912,14 +909,11 @@ impl ExecutionPlan for SortExec {

             Ok(Box::pin(RecordBatchStreamAdapter::new(
                 self.schema(),
-                futures::stream::once(async move {
-                    while let Some(batch) = input.next().await {
-                        let batch = batch?;
-                        sorter.insert_batch(batch).await?;
-                    }
-                    sorter.sort()
-                })
-                .try_flatten(),
+                futures::stream::unfold(
+                    SortStreamState::new(input, sorter),
+                    SortStreamState::poll_next,
+                )
+                .fuse(),
             )))
         }
     }
@@ -948,6 +942,107 @@ impl ExecutionPlan for SortExec {
     }
 }

+/// A sorter sinks a stream of record batches and then emits the sorted output
+trait Sorter: Send + 'static {
+    /// Sink a record batch into the sorter
+    fn sink(&mut self, batch: RecordBatch) -> impl Future<Output = Result<()>> + Send;
+
+    /// Consume the sorter and return the stream of sorted batches
+    fn source(self) -> Result<SendableRecordBatchStream>;
+}
+
+impl Sorter for TopK {
+    async fn sink(&mut self, batch: RecordBatch) -> Result<()> {
+        self.insert_batch(batch)
+    }
+
+    fn source(self) -> Result<SendableRecordBatchStream> {
+        self.emit()
+    }
+}
+
+impl Sorter for ExternalSorter {
+    fn sink(&mut self, batch: RecordBatch) -> impl Future<Output = Result<()>> + Send {
+        self.insert_batch(batch)
+    }
+
+    fn source(mut self) -> Result<SendableRecordBatchStream> {
+        self.sort()
+    }
+}
+
+enum SortStreamState<T: Sorter> {
+    SinkPhase {
+        input: SendableRecordBatchStream,
+        sorter: Option<T>,
+    },
+    SourcePhase {
+        output: SendableRecordBatchStream,
+    },
+}
+
+impl<T: Sorter> SortStreamState<T> {
+    fn new(input: SendableRecordBatchStream, sorter: T) -> Self {
+        Self::SinkPhase {
+            input,
+            sorter: Some(sorter),
+        }
+    }
+
+    async fn poll_next(self) -> Option<(Result<RecordBatch>, Self)> {
+        match self {
+            Self::SinkPhase {
+                mut input,
+                mut sorter,
+            } => {
+                while let Some(sorter_) = &mut sorter {
+                    if let Some(item) = input.next().await {
+                        match item {
+                            Ok(batch) => {
+                                if let Err(e) = sorter_.sink(batch).await {
+                                    return Some((
+                                        Err(e),
+                                        Self::SinkPhase {
+                                            input,
+                                            sorter: None,
+                                        },
+                                    ));
+                                }
+                            }
+                            Err(e) => {
+                                // Propagate the input error and keep sinking the remaining batches
+                                return Some((Err(e), Self::SinkPhase { input, sorter }));
+                            }
+                        }
+                    } else {
+                        let sorter = sorter.take().unwrap();
+                        match sorter.source() {
+                            Ok(mut output) => {
+                                let rb = output.next().await;
+                                return rb.map(|rb| (rb, Self::SourcePhase { output }));
+                            }
+                            Err(e) => {
+                                return Some((
+                                    Err(e),
+                                    Self::SinkPhase {
+                                        input,
+                                        sorter: None,
+                                    },
+                                ));
+                            }
+                        }
+                    }
+                }
+                None
+            }
+            Self::SourcePhase { mut output } => {
+                let rb = output.next().await;
+                rb.map(|rb| (rb, Self::SourcePhase { output }))
+            }
+        }
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use std::collections::HashMap;
diff --git a/datafusion/physical-plan/src/sorts/sort_preserving_merge.rs b/datafusion/physical-plan/src/sorts/sort_preserving_merge.rs
index 131fa71217cc..ec388ab2d12b 100644
--- a/datafusion/physical-plan/src/sorts/sort_preserving_merge.rs
+++ b/datafusion/physical-plan/src/sorts/sort_preserving_merge.rs
@@ -308,7 +308,9 @@ mod tests {
     use crate::metrics::{MetricValue, Timestamp};
     use crate::sorts::sort::SortExec;
     use crate::stream::RecordBatchReceiverStream;
-    use crate::test::exec::{assert_strong_count_converges_to_zero, BlockingExec};
+    use crate::test::exec::{
+        assert_strong_count_converges_to_zero, BlockingExec, MockExec,
+    };
     use crate::test::{self, assert_is_pending, make_partition};
     use crate::{collect, common};

@@ -316,7 +318,7 @@ mod tests {
     use arrow::compute::SortOptions;
     use arrow::datatypes::{DataType, Field, Schema};
     use arrow::record_batch::RecordBatch;
-    use datafusion_common::{assert_batches_eq, assert_contains};
+    use datafusion_common::{assert_batches_eq, assert_contains, exec_err};
     use datafusion_execution::config::SessionConfig;

     use futures::{FutureExt, StreamExt};
@@ -1141,4 +1143,93 @@ mod tests {
             collected.as_slice()
         );
     }
+
+    #[tokio::test]
+    async fn test_merge_with_error_in_stream() {
+        let task_ctx = Arc::new(TaskContext::default());
+        let schema = Arc::new(Schema::new(vec![
+            Field::new("sorted", DataType::Utf8, false),
+            Field::new("payload", DataType::Int32, false),
+        ]));
+
+        // Error returned while updating the loser tree
+        let t0 = MockExec::new(
+            vec![
+                Ok(RecordBatch::try_new(
+                    Arc::clone(&schema),
+                    vec![
+                        Arc::new(StringArray::from(vec!["bar", "foo"])) as ArrayRef,
+                        Arc::new(Int32Array::from(vec![10, 2])) as ArrayRef,
+                    ],
+                )
+                .unwrap()),
+                exec_err!("t0 bad data"),
+                Ok(RecordBatch::try_new(
+                    Arc::clone(&schema),
+                    vec![
+                        Arc::new(StringArray::from(vec!["zip"])) as ArrayRef,
+                        Arc::new(Int32Array::from(vec![111])) as ArrayRef,
+                    ],
+                )
+                .unwrap()),
+            ],
+            Arc::clone(&schema),
+        );
+
+        let batch = RecordBatch::try_new(
+            Arc::clone(&schema),
+            vec![
+                Arc::new(StringArray::from(vec!["arrow", "datafusion"])) as ArrayRef,
+                Arc::new(Int32Array::from(vec![1, 99])) as ArrayRef,
+            ],
+        )
+        .unwrap();
+
+        // Error returned while constructing the loser tree
+        let t1 = MockExec::new(
+            vec![exec_err!("t1 bad data"), Ok(batch)],
+            Arc::clone(&schema),
+        );
+
+        let sort_exprs = vec![PhysicalSortExpr {
+            expr: col("sorted", &schema).unwrap(),
+            options: Default::default(),
+        }];
+
+        let merge = SortPreservingMergeExec::new(
+            sort_exprs,
+            Arc::new(crate::union::UnionExec::new(vec![
+                Arc::new(t0),
+                Arc::new(t1),
+            ])),
+        );
+
+        let mut stream = merge.execute(0, task_ctx).unwrap();
+
+        let err = stream.next().await.unwrap().unwrap_err().to_string();
+        assert!(err.contains("t1 bad data"), "actual: {err}");
+
+        let err = stream.next().await.unwrap().unwrap_err().to_string();
+        assert!(err.contains("t0 bad data"), "actual: {err}");
+
+        let batch = stream.next().await.unwrap().unwrap();
+
+        assert_batches_eq!(
+            &[
+                "+------------+---------+",
+                "| sorted     | payload |",
+                "+------------+---------+",
+                "| arrow      | 1       |",
+                "| bar        | 10      |",
+                "| datafusion | 99      |",
+                "| foo        | 2       |",
+                "| zip        | 111     |",
+                "+------------+---------+",
+            ],
+            &[batch]
+        );
+
+        assert!(stream.next().await.is_none());
+    }
 }

From 6488a61dd9583828b7b3f82d1cfbc4af33679ffe Mon Sep 17 00:00:00 2001
From: jefffffyang
Date: Fri, 27 Sep 2024 10:33:54 +0800
Subject: [PATCH 2/2] Add test_sort_with_error_in_stream

---
 datafusion/physical-plan/src/sorts/sort.rs | 70 +++++++++++++++++++++-
 1 file changed, 68 insertions(+), 2 deletions(-)

diff --git a/datafusion/physical-plan/src/sorts/sort.rs b/datafusion/physical-plan/src/sorts/sort.rs
index 9b2659a63bb5..dedb159114f8 100644
--- a/datafusion/physical-plan/src/sorts/sort.rs
+++ b/datafusion/physical-plan/src/sorts/sort.rs
@@ -1054,7 +1054,9 @@ mod tests {
     use crate::memory::MemoryExec;
     use crate::test;
     use crate::test::assert_is_pending;
-    use crate::test::exec::{assert_strong_count_converges_to_zero, BlockingExec};
+    use crate::test::exec::{
+        assert_strong_count_converges_to_zero, BlockingExec, MockExec,
+    };

     use arrow::array::*;
     use arrow::compute::SortOptions;
@@ -1063,7 +1065,7 @@ mod tests {
     use datafusion_execution::config::SessionConfig;
     use datafusion_execution::runtime_env::RuntimeConfig;

-    use datafusion_common::ScalarValue;
+    use datafusion_common::{assert_batches_eq, exec_err, ScalarValue};
     use datafusion_physical_expr::expressions::Literal;
     use futures::FutureExt;

@@ -1504,4 +1506,68 @@ mod tests {
         let result = sort_batch(&batch, &expressions, None).unwrap();
         assert_eq!(result.num_rows(), 1);
     }
+
+    #[tokio::test]
+    async fn test_sort_with_error_in_stream() {
+        let task_ctx = Arc::new(TaskContext::default());
+        let schema = Arc::new(Schema::new(vec![
+            Field::new("key", DataType::Utf8, false),
+            Field::new("payload", DataType::Int32, false),
+        ]));
+
+        let mock_exec = MockExec::new(
+            vec![
+                exec_err!("bad data 0"),
+                Ok(RecordBatch::try_new(
+                    Arc::clone(&schema),
+                    vec![
+                        Arc::new(StringArray::from(vec!["datafusion", "arrow"]))
+                            as ArrayRef,
+                        Arc::new(Int32Array::from(vec![10, 2])) as ArrayRef,
+                    ],
+                )
+                .unwrap()),
+                exec_err!("bad data 1"),
+                Ok(RecordBatch::try_new(
+                    Arc::clone(&schema),
+                    vec![
+                        Arc::new(StringArray::from(vec!["parquet", "comet"])) as ArrayRef,
+                        Arc::new(Int32Array::from(vec![7, 99])) as ArrayRef,
+                    ],
+                )
+                .unwrap()),
+            ],
+            Arc::clone(&schema),
+        );
+
+        let sort_exprs = vec![PhysicalSortExpr {
+            expr: col("key", &schema).unwrap(),
+            options: Default::default(),
+        }];
+
+        let sort = SortExec::new(sort_exprs, Arc::new(mock_exec));
+        let mut stream = sort.execute(0, task_ctx).unwrap();
+
+        let err = stream.next().await.unwrap().unwrap_err().to_string();
+        assert!(err.contains("bad data 0"), "actual: {err}");
+
+        let err = stream.next().await.unwrap().unwrap_err().to_string();
+        assert!(err.contains("bad data 1"), "actual: {err}");
+
+        let batch = stream.next().await.unwrap().unwrap();
+
+        assert_batches_eq!(
+            &[
+                "+------------+---------+",
+                "| key        | payload |",
+                "+------------+---------+",
+                "| arrow      | 2       |",
+                "| comet      | 99      |",
+                "| datafusion | 10      |",
+                "| parquet    | 7       |",
+                "+------------+---------+",
+            ],
+            &[batch]
+        );
+    }
 }