diff --git a/datafusion/core/tests/memory_limit.rs b/datafusion/core/tests/memory_limit.rs index 5b18d616b3f9..a7cff6cbd758 100644 --- a/datafusion/core/tests/memory_limit.rs +++ b/datafusion/core/tests/memory_limit.rs @@ -19,12 +19,13 @@ use arrow::datatypes::SchemaRef; use arrow::record_batch::RecordBatch; +use datafusion::physical_optimizer::PhysicalOptimizerRule; use datafusion::physical_plan::streaming::PartitionStream; use futures::StreamExt; use std::sync::Arc; use datafusion::datasource::streaming::StreamingTable; -use datafusion::datasource::MemTable; +use datafusion::datasource::{MemTable, TableProvider}; use datafusion::execution::context::SessionState; use datafusion::execution::disk_manager::DiskManagerConfig; use datafusion::execution::runtime_env::{RuntimeConfig, RuntimeEnv}; @@ -46,19 +47,20 @@ fn init() { #[tokio::test] async fn oom_sort() { - run_limit_test( + TestCase::new( "select * from t order by host DESC", vec![ "Resources exhausted: Memory Exhausted while Sorting (DiskManager is disabled)", ], 200_000, ) - .await + .run() + .await } #[tokio::test] async fn group_by_none() { - run_limit_test( + TestCase::new( "select median(image) from t", vec![ "Resources exhausted: Failed to allocate additional", @@ -66,12 +68,13 @@ async fn group_by_none() { ], 20_000, ) + .run() .await } #[tokio::test] async fn group_by_row_hash() { - run_limit_test( + TestCase::new( "select count(*) from t GROUP BY response_bytes", vec![ "Resources exhausted: Failed to allocate additional", @@ -79,12 +82,13 @@ async fn group_by_row_hash() { ], 2_000, ) + .run() .await } #[tokio::test] async fn group_by_hash() { - run_limit_test( + TestCase::new( // group by dict column "select count(*) from t GROUP BY service, host, pod, container", vec![ @@ -93,42 +97,45 @@ async fn group_by_hash() { ], 1_000, ) + .run() .await } #[tokio::test] async fn join_by_key_multiple_partitions() { let config = SessionConfig::new().with_target_partitions(2); - run_limit_test_with_config( + 
TestCase::new( "select t1.* from t t1 JOIN t t2 ON t1.service = t2.service", vec![ "Resources exhausted: Failed to allocate additional", "HashJoinInput[0]", ], 1_000, - config, ) + .with_config(config) + .run() .await } #[tokio::test] async fn join_by_key_single_partition() { let config = SessionConfig::new().with_target_partitions(1); - run_limit_test_with_config( + TestCase::new( "select t1.* from t t1 JOIN t t2 ON t1.service = t2.service", vec![ "Resources exhausted: Failed to allocate additional", "HashJoinInput", ], 1_000, - config, ) + .with_config(config) + .run() .await } #[tokio::test] async fn join_by_expression() { - run_limit_test( + TestCase::new( "select t1.* from t t1 JOIN t t2 ON t1.service != t2.service", vec![ "Resources exhausted: Failed to allocate additional", @@ -136,12 +143,13 @@ async fn join_by_expression() { ], 1_000, ) + .run() .await } #[tokio::test] async fn cross_join() { - run_limit_test( + TestCase::new( "select t1.* from t t1 CROSS JOIN t t2", vec![ "Resources exhausted: Failed to allocate additional", @@ -149,6 +157,7 @@ async fn cross_join() { ], 1_000, ) + .run() .await } @@ -159,94 +168,185 @@ async fn merge_join() { .with_target_partitions(2) .set_bool("datafusion.optimizer.prefer_hash_join", false); - run_limit_test_with_config( + TestCase::new( "select t1.* from t t1 JOIN t t2 ON t1.pod = t2.pod AND t1.time = t2.time", vec![ "Resources exhausted: Failed to allocate additional", "SMJStream", ], 1_000, - config, ) + .with_config(config) + .run() .await } #[tokio::test] -async fn test_limit_symmetric_hash_join() { - let config = SessionConfig::new(); - - run_streaming_test_with_config( +async fn symmetric_hash_join() { + TestCase::new( "select t1.* from t t1 JOIN t t2 ON t1.pod = t2.pod AND t1.time = t2.time", vec![ "Resources exhausted: Failed to allocate additional", "SymmetricHashJoinStream", ], 1_000, - config, ) + .with_scenario(Scenario::AccessLogStreaming) + .run() .await } -/// 50 byte memory limit -const 
MEMORY_FRACTION: f64 = 0.95; - -/// runs the specified query against 1000 rows with specified -/// memory limit and no disk manager enabled with default SessionConfig. -async fn run_limit_test( - query: &str, - expected_error_contains: Vec<&str>, +/// Run the query with the specified memory limit, +/// and verifies the expected errors are returned +#[derive(Clone, Debug)] +struct TestCase { + query: String, + expected_errors: Vec<String>, memory_limit: usize, -) { - let config = SessionConfig::new(); - run_limit_test_with_config(query, expected_error_contains, memory_limit, config).await + config: SessionConfig, + scenario: Scenario, } -/// runs the specified query against 1000 rows with a 50 -/// byte memory limit and no disk manager enabled -/// with specified SessionConfig instance -async fn run_limit_test_with_config( - query: &str, - expected_error_contains: Vec<&str>, - memory_limit: usize, - config: SessionConfig, -) { - let batches: Vec<_> = AccessLogGenerator::new() - .with_row_limit(1000) - .with_max_batch_size(50) - .collect(); +impl TestCase { + fn new<'a>( + query: impl Into<String>, + expected_errors: impl IntoIterator<Item = &'a str>, + memory_limit: usize, + ) -> Self { + let expected_errors: Vec<String> = + expected_errors.into_iter().map(|s| s.to_string()).collect(); + + Self { + query: query.into(), + expected_errors, + memory_limit, + config: SessionConfig::new(), + scenario: Scenario::AccessLog, + } + } - let table = MemTable::try_new(batches[0].schema(), vec![batches]).unwrap(); + /// Specify the configuration to use + pub fn with_config(mut self, config: SessionConfig) -> Self { + self.config = config; + self + } - let rt_config = RuntimeConfig::new() - // do not allow spilling - .with_disk_manager(DiskManagerConfig::Disabled) - .with_memory_limit(memory_limit, MEMORY_FRACTION); + /// Specify the scenario to run + pub fn with_scenario(mut self, scenario: Scenario) -> Self { + self.scenario = scenario; + self + } + + /// Run the test, panic'ing on error + async fn run(self) { + let 
Self { + query, + expected_errors, + memory_limit, + config, + scenario, + } = self; + + let table = scenario.table(); - let runtime = RuntimeEnv::new(rt_config).unwrap(); + let rt_config = RuntimeConfig::new() + // do not allow spilling + .with_disk_manager(DiskManagerConfig::Disabled) + .with_memory_limit(memory_limit, MEMORY_FRACTION); - // Disabling physical optimizer rules to avoid sorts / repartitions - // (since RepartitionExec / SortExec also has a memory budget which we'll likely hit first) - let state = SessionState::with_config_rt(config, Arc::new(runtime)) - .with_physical_optimizer_rules(vec![]); + let runtime = RuntimeEnv::new(rt_config).unwrap(); - let ctx = SessionContext::with_state(state); - ctx.register_table("t", Arc::new(table)) - .expect("registering table"); + // Configure execution + let state = SessionState::with_config_rt(config, Arc::new(runtime)) + .with_physical_optimizer_rules(scenario.rules()); - let df = ctx.sql(query).await.expect("Planning query"); + let ctx = SessionContext::with_state(state); + ctx.register_table("t", table).expect("registering table"); - match df.collect().await { - Ok(_batches) => { - panic!("Unexpected success when running, expected memory limit failure") + let df = ctx.sql(&query).await.expect("Planning query"); + + match df.collect().await { + Ok(_batches) => { + panic!("Unexpected success when running, expected memory limit failure") + } + Err(e) => { + for error_substring in expected_errors { + assert_contains!(e.to_string(), error_substring); + } + } } - Err(e) => { - for error_substring in expected_error_contains { - assert_contains!(e.to_string(), error_substring); + } +} + +/// 50 byte memory limit +const MEMORY_FRACTION: f64 = 0.95; + +/// Different data scenarios +#[derive(Clone, Debug)] +enum Scenario { + /// 1000 rows of access log data with batches of 50 rows + AccessLog, + + /// 1000 rows of access log data with batches of 50 rows in a + /// [`StreamingTable`] + AccessLogStreaming, +} + +impl 
Scenario { + /// return a TableProvider with data for the test + fn table(&self) -> Arc<dyn TableProvider> { + match self { + Self::AccessLog => { + let batches = access_log_batches(); + let table = + MemTable::try_new(batches[0].schema(), vec![batches]).unwrap(); + Arc::new(table) + } + Self::AccessLogStreaming => { + let batches = access_log_batches(); + + // Create a new streaming table with the generated schema and batches + let table = StreamingTable::try_new( + batches[0].schema(), + vec![Arc::new(DummyStreamPartition { + schema: batches[0].schema(), + batches: batches.clone(), + })], + ) + .unwrap() + .with_infinite_table(true); + Arc::new(table) + } + } + } + + /// return the optimizer rules to use + fn rules(&self) -> Vec<Arc<dyn PhysicalOptimizerRule + Send + Sync>> { + match self { + Self::AccessLog => { + // Disabling physical optimizer rules to avoid sorts / + // repartitions (since RepartitionExec / SortExec also + // has a memory budget which we'll likely hit first) + vec![] + } + Self::AccessLogStreaming => { + // Disable all physical optimizer rules except the + // JoinSelection rule to avoid sorts or repartition, + // as they also have memory budgets that may be hit + // first + vec![Arc::new(JoinSelection::new())] + } + } + } } +fn access_log_batches() -> Vec<RecordBatch> { + AccessLogGenerator::new() + .with_row_limit(1000) + .with_max_batch_size(50) + .collect() +} + struct DummyStreamPartition { schema: SchemaRef, batches: Vec<RecordBatch>, } @@ -266,66 +366,3 @@ impl PartitionStream for DummyStreamPartition { )) } } - -async fn run_streaming_test_with_config( - query: &str, - expected_error_contains: Vec<&str>, - memory_limit: usize, - config: SessionConfig, -) { - // Generate a set of access logs with a row limit of 1000 and a max batch size of 50 - let batches: Vec<_> = AccessLogGenerator::new() - .with_row_limit(1000) - .with_max_batch_size(50) - .collect(); - - // Create a new streaming table with the generated schema and batches - let table = StreamingTable::try_new( - batches[0].schema(), - vec![Arc::new(DummyStreamPartition { - 
schema: batches[0].schema(), - batches: batches.clone(), - })], - ) - .unwrap() - .with_infinite_table(true); - - // Configure the runtime environment with custom settings - let rt_config = RuntimeConfig::new() - // Disable disk manager to disallow spilling - .with_disk_manager(DiskManagerConfig::Disabled) - // Set memory limit to 50 bytes - .with_memory_limit(memory_limit, MEMORY_FRACTION); - - // Create a new runtime environment with the configured settings - let runtime = RuntimeEnv::new(rt_config).unwrap(); - - // Create a new session state with the given configuration and runtime environment - // Disable all physical optimizer rules except the PipelineFixer rule to avoid sorts or - // repartition, as they also have memory budgets that may be hit first - let state = SessionState::with_config_rt(config, Arc::new(runtime)) - .with_physical_optimizer_rules(vec![Arc::new(JoinSelection::new())]); - - // Create a new session context with the session state - let ctx = SessionContext::with_state(state); - // Register the streaming table with the session context - ctx.register_table("t", Arc::new(table)) - .expect("registering table"); - - // Execute the SQL query and get a DataFrame - let df = ctx.sql(query).await.expect("Planning query"); - - // Collect the results of the DataFrame execution - match df.collect().await { - // If the execution succeeds, panic as we expect memory limit failure - Ok(_batches) => { - panic!("Unexpected success when running, expected memory limit failure") - } - // If the execution fails, verify if the error contains the expected substrings - Err(e) => { - for error_substring in expected_error_contains { - assert_contains!(e.to_string(), error_substring); - } - } - } -} diff --git a/datafusion/execution/src/config.rs b/datafusion/execution/src/config.rs index 97770eb99c58..3e607af705cb 100644 --- a/datafusion/execution/src/config.rs +++ b/datafusion/execution/src/config.rs @@ -25,7 +25,7 @@ use std::{ use 
datafusion_common::{config::ConfigOptions, Result, ScalarValue}; /// Configuration options for Execution context -#[derive(Clone)] +#[derive(Clone, Debug)] pub struct SessionConfig { /// Configuration options options: ConfigOptions,