From 3d0c281f47dfe7ad957bb459db39250d6e571956 Mon Sep 17 00:00:00 2001 From: Alex Qyoun-ae <4062971+MazterQyou@users.noreply.github.com> Date: Thu, 6 Apr 2023 13:44:03 +0400 Subject: [PATCH 1/2] feat: Support scalar subquery in `WHERE` --- datafusion/core/src/logical_plan/builder.rs | 8 +- datafusion/core/src/logical_plan/mod.rs | 2 +- datafusion/core/src/logical_plan/plan.rs | 74 ++++++-- .../core/src/optimizer/projection_drop_out.rs | 2 + .../src/optimizer/projection_push_down.rs | 8 +- datafusion/core/src/optimizer/utils.rs | 5 +- datafusion/core/src/physical_plan/planner.rs | 4 +- datafusion/core/src/physical_plan/subquery.rs | 146 +++++++++------- datafusion/core/src/sql/planner.rs | 161 ++++++++++++++---- datafusion/core/tests/sql/subquery.rs | 94 +++++++++- 10 files changed, 390 insertions(+), 114 deletions(-) diff --git a/datafusion/core/src/logical_plan/builder.rs b/datafusion/core/src/logical_plan/builder.rs index 1f8237d244a7..60e2c072489c 100644 --- a/datafusion/core/src/logical_plan/builder.rs +++ b/datafusion/core/src/logical_plan/builder.rs @@ -47,7 +47,8 @@ use super::{dfschema::ToDFSchema, expr_rewriter::coerce_plan_expr_for_schema, Di use super::{exprlist_to_fields, Expr, JoinConstraint, JoinType, LogicalPlan, PlanType}; use crate::logical_plan::{ columnize_expr, normalize_col, normalize_cols, rewrite_sort_cols_by_aggs, Column, - CrossJoin, DFField, DFSchema, DFSchemaRef, Limit, Partitioning, Repartition, Values, + CrossJoin, DFField, DFSchema, DFSchemaRef, Limit, Partitioning, Repartition, + SubqueryType, Values, }; use crate::sql::utils::group_window_expr_by_sort_keys; @@ -528,12 +529,15 @@ impl LogicalPlanBuilder { pub fn subquery( &self, subqueries: impl IntoIterator>, + types: impl IntoIterator, ) -> Result { let subqueries = subqueries.into_iter().map(|l| l.into()).collect::>(); - let schema = Arc::new(Subquery::merged_schema(&self.plan, &subqueries)); + let types = types.into_iter().collect::>(); + let schema = Arc::new(Subquery::merged_schema(&self.plan, &subqueries, &types)); Ok(Self::from(LogicalPlan::Subquery(Subquery { input: Arc::new(self.plan.clone()), subqueries, + types, schema, }))) } diff --git a/datafusion/core/src/logical_plan/mod.rs b/datafusion/core/src/logical_plan/mod.rs index ce7c342325ac..5c34ee7932d4 100644 --- a/datafusion/core/src/logical_plan/mod.rs +++ b/datafusion/core/src/logical_plan/mod.rs @@ -68,6 +68,6 @@ pub use plan::{ CreateCatalogSchema, CreateExternalTable, CreateMemoryTable, CrossJoin, Distinct, DropTable, EmptyRelation, Filter, JoinConstraint, JoinType, Limit, LogicalPlan, Partitioning, PlanType, PlanVisitor, Repartition, StringifiedPlan, Subquery, - TableScan, ToStringifiedPlan, Union, Values, + SubqueryType, TableScan, ToStringifiedPlan, Union, Values, }; pub use registry::FunctionRegistry; diff --git a/datafusion/core/src/logical_plan/plan.rs b/datafusion/core/src/logical_plan/plan.rs index 3c8d5a72fa0e..41226b486bd9 100644 --- a/datafusion/core/src/logical_plan/plan.rs +++ b/datafusion/core/src/logical_plan/plan.rs @@ -267,22 +267,64 @@ pub struct Limit { /// Evaluates correlated sub queries #[derive(Clone)] pub struct Subquery { - /// The list of sub queries - pub subqueries: Vec, /// The incoming logical plan pub input: Arc, + /// The list of sub queries + pub subqueries: Vec, + /// The list of subquery types + pub types: Vec, /// The schema description of the output pub schema: DFSchemaRef, } +/// Subquery type +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)] +pub enum SubqueryType { + /// Scalar (SELECT, WHERE) evaluating to one value + Scalar, + // This will be extended with `Exists` and `AnyAll` types. +} + +impl Display for SubqueryType { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { + let subquery_type = match self { + SubqueryType::Scalar => "Scalar", + }; + write!(f, "{}", subquery_type) + } +} + impl Subquery { /// Merge schema of main input and correlated subquery columns - pub fn merged_schema(input: &LogicalPlan, subqueries: &[LogicalPlan]) -> DFSchema { - subqueries.iter().fold((**input.schema()).clone(), |a, b| { - let mut res = a; - res.merge(b.schema()); - res - }) + pub fn merged_schema( + input: &LogicalPlan, + subqueries: &[LogicalPlan], + types: &[SubqueryType], + ) -> DFSchema { + subqueries.iter().zip(types.iter()).fold( + (**input.schema()).clone(), + |schema, (plan, typ)| { + let mut schema = schema; + schema.merge(&Self::transform_dfschema(plan.schema(), *typ)); + schema + }, + ) + } + + /// Transform DataFusion schema according to subquery type + pub fn transform_dfschema(schema: &DFSchema, typ: SubqueryType) -> DFSchema { + match typ { + SubqueryType::Scalar => schema.clone(), + // Schema will be transformed for `Exists` and `AnyAll` + } + } + + /// Transform Arrow field according to subquery type + pub fn transform_field(field: &Field, typ: SubqueryType) -> Field { + match typ { + SubqueryType::Scalar => field.clone(), + // Field will be transformed for `Exists` and `AnyAll` + } } } @@ -475,13 +517,23 @@ impl LogicalPlan { LogicalPlan::Values(Values { schema, .. }) => vec![schema], LogicalPlan::Window(Window { input, schema, .. }) | LogicalPlan::Projection(Projection { input, schema, .. }) - | LogicalPlan::Subquery(Subquery { input, schema, .. }) | LogicalPlan::Aggregate(Aggregate { input, schema, .. }) | LogicalPlan::TableUDFs(TableUDFs { input, schema, .. }) => { let mut schemas = input.all_schemas(); schemas.insert(0, schema); schemas } + LogicalPlan::Subquery(Subquery { + input, + subqueries, + schema, + .. + }) => { + let mut schemas = input.all_schemas(); + schemas.extend(subqueries.iter().map(|s| s.schema())); + schemas.insert(0, schema); + schemas + } LogicalPlan::Join(Join { left, right, @@ -1063,7 +1115,9 @@ impl LogicalPlan { } Ok(()) } - LogicalPlan::Subquery(Subquery { .. }) => write!(f, "Subquery"), + LogicalPlan::Subquery(Subquery { types, .. }) => { + write!(f, "Subquery: types={:?}", types) + } LogicalPlan::Filter(Filter { predicate: ref expr, .. diff --git a/datafusion/core/src/optimizer/projection_drop_out.rs b/datafusion/core/src/optimizer/projection_drop_out.rs index e96d365d1e9d..479c9ca917f2 100644 --- a/datafusion/core/src/optimizer/projection_drop_out.rs +++ b/datafusion/core/src/optimizer/projection_drop_out.rs @@ -254,6 +254,7 @@ fn optimize_plan( LogicalPlan::Subquery(Subquery { input, subqueries, + types, schema, }) => { // TODO: subqueries are not optimized @@ -269,6 +270,7 @@ fn optimize_plan( .map(|(p, _)| p)?, ), subqueries: subqueries.clone(), + types: types.clone(), schema: schema.clone(), }), None, diff --git a/datafusion/core/src/optimizer/projection_push_down.rs b/datafusion/core/src/optimizer/projection_push_down.rs index f4c76f8b6882..9c1cdb11bc9f 100644 --- a/datafusion/core/src/optimizer/projection_push_down.rs +++ b/datafusion/core/src/optimizer/projection_push_down.rs @@ -453,7 +453,10 @@ fn optimize_plan( })) } LogicalPlan::Subquery(Subquery { - input, subqueries, .. + input, + subqueries, + types, + .. }) => { let mut subquery_required_columns = HashSet::new(); for subquery in subqueries.iter() { @@ -484,11 +487,12 @@ fn optimize_plan( has_projection, _optimizer_config, )?; - let new_schema = Subquery::merged_schema(&input, subqueries); + let new_schema = Subquery::merged_schema(&input, subqueries, types); Ok(LogicalPlan::Subquery(Subquery { input: Arc::new(input), schema: Arc::new(new_schema), subqueries: subqueries.clone(), + types: types.clone(), })) } // all other nodes: Add any additional columns used by diff --git a/datafusion/core/src/optimizer/utils.rs b/datafusion/core/src/optimizer/utils.rs index 2e741eb892b3..0b8c460b9d01 100644 --- a/datafusion/core/src/optimizer/utils.rs +++ b/datafusion/core/src/optimizer/utils.rs @@ -161,10 +161,11 @@ pub fn from_plan( alias: alias.clone(), })) } - LogicalPlan::Subquery(Subquery { schema, .. }) => { + LogicalPlan::Subquery(Subquery { schema, types, .. }) => { Ok(LogicalPlan::Subquery(Subquery { - subqueries: inputs[1..inputs.len()].to_vec(), input: Arc::new(inputs[0].clone()), + subqueries: inputs[1..inputs.len()].to_vec(), + types: types.clone(), schema: schema.clone(), })) } diff --git a/datafusion/core/src/physical_plan/planner.rs b/datafusion/core/src/physical_plan/planner.rs index 8f3b9c261a65..1347856d70f5 100644 --- a/datafusion/core/src/physical_plan/planner.rs +++ b/datafusion/core/src/physical_plan/planner.rs @@ -917,7 +917,7 @@ impl DefaultPhysicalPlanner { Ok(Arc::new(GlobalLimitExec::new(input, *skip, *fetch))) } - LogicalPlan::Subquery(Subquery { subqueries, input, schema }) => { + LogicalPlan::Subquery(Subquery { input, subqueries, types, schema }) => { let cursor = Arc::new(OuterQueryCursor::new(schema.as_ref().to_owned().into())); let mut new_session_state = session_state.clone(); new_session_state.execution_props = new_session_state.execution_props.with_outer_query_cursor(cursor.clone()); @@ -931,7 +931,7 @@ impl DefaultPhysicalPlanner { }) .collect::>(); let input = self.create_initial_plan(input, &new_session_state).await?; - Ok(Arc::new(SubqueryExec::try_new(subqueries, input, cursor)?)) + Ok(Arc::new(SubqueryExec::try_new(input, subqueries, types.clone(), cursor)?)) } LogicalPlan::CreateExternalTable(_) => { // There is no default plan for "CREATE EXTERNAL diff --git a/datafusion/core/src/physical_plan/subquery.rs b/datafusion/core/src/physical_plan/subquery.rs index c7ffad4cbd55..853fd92701b4 100644 --- a/datafusion/core/src/physical_plan/subquery.rs +++ b/datafusion/core/src/physical_plan/subquery.rs @@ -28,6 +28,7 @@ use std::sync::Arc; use std::task::{Context, Poll}; use crate::error::{DataFusionError, Result}; +use crate::logical_plan::{Subquery, SubqueryType}; use crate::physical_plan::{DisplayFormatType, ExecutionPlan, Partitioning}; use arrow::array::new_null_array; use arrow::datatypes::{Schema, SchemaRef}; @@ -45,12 +46,14 @@ use futures::stream::StreamExt; /// Execution plan for a sub query #[derive(Debug)] pub struct SubqueryExec { + /// The input plan + input: Arc, /// Sub queries subqueries: Vec>, + /// Subquery types + types: Vec, /// Merged schema schema: SchemaRef, - /// The input plan - input: Arc, /// Cursor used to send outer query column values to sub queries cursor: Arc, } @@ -58,15 +61,23 @@ pub struct SubqueryExec { impl SubqueryExec { /// Create a projection on an input pub fn try_new( - subqueries: Vec>, input: Arc, + subqueries: Vec>, + types: Vec, cursor: Arc, ) -> Result { let input_schema = input.schema(); let mut total_fields = input_schema.fields().clone(); - for q in subqueries.iter() { - total_fields.append(&mut q.schema().fields().clone()); + for (q, t) in subqueries.iter().zip(types.iter()) { + total_fields.append( + &mut q + .schema() + .fields() + .iter() + .map(|f| Subquery::transform_field(f, *t)) + .collect(), + ); } let merged_schema = Schema::new_with_metadata(total_fields, HashMap::new()); @@ -78,9 +89,10 @@ impl SubqueryExec { } Ok(Self { + input, subqueries, + types, schema: Arc::new(merged_schema), - input, cursor, }) } @@ -134,8 +146,9 @@ impl ExecutionPlan for SubqueryExec { } Ok(Arc::new(SubqueryExec::try_new( - children.iter().skip(1).cloned().collect(), children[0].clone(), + children.iter().skip(1).cloned().collect(), + self.types.clone(), self.cursor.clone(), )?)) } @@ -148,74 +161,83 @@ impl ExecutionPlan for SubqueryExec { let stream = self.input.execute(partition, context.clone()).await?; let cursor = self.cursor.clone(); let subqueries = self.subqueries.clone(); + let types = self.types.clone(); let context = context.clone(); let size_hint = stream.size_hint(); let schema = self.schema.clone(); - let res_stream = - stream.then(move |batch| { - let cursor = cursor.clone(); - let context = context.clone(); - let subqueries = subqueries.clone(); - let schema = schema.clone(); - async move { - let batch = batch?; - let b = Arc::new(batch.clone()); - cursor.set_batch(b)?; - let mut subquery_arrays = vec![Vec::new(); subqueries.len()]; - for i in 0..batch.num_rows() { - cursor.set_position(i)?; - for (subquery_i, subquery) in subqueries.iter().enumerate() { - let null_array = || { - let schema = subquery.schema(); - let fields = schema.fields(); - if fields.len() != 1 { - return Err(ArrowError::ComputeError(format!( - "Sub query should have only one column but got {}", - fields.len() - ))); - } - - let data_type = fields.get(0).unwrap().data_type(); - Ok(new_null_array(data_type, 1)) - }; + let res_stream = stream.then(move |batch| { + let cursor = cursor.clone(); + let context = context.clone(); + let subqueries = subqueries.clone(); + let types = types.clone(); + let schema = schema.clone(); + async move { + let batch = batch?; + let b = Arc::new(batch.clone()); + cursor.set_batch(b)?; + let mut subquery_arrays = vec![Vec::new(); subqueries.len()]; + for i in 0..batch.num_rows() { + cursor.set_position(i)?; + for (subquery_i, (subquery, subquery_type)) in + subqueries.iter().zip(types.iter()).enumerate() + { + let schema = subquery.schema(); + let fields = schema.fields(); + if fields.len() != 1 { + return Err(ArrowError::ComputeError(format!( + "Sub query should have only one column but got {}", + fields.len() + ))); + } + let data_type = fields.get(0).unwrap().data_type(); + let null_array = || new_null_array(data_type, 1); - if subquery.output_partitioning().partition_count() != 1 { - return Err(ArrowError::ComputeError(format!( - "Sub query should have only one partition but got {}", - subquery.output_partitioning().partition_count() - ))); - } - let mut stream = subquery.execute(0, context.clone()).await?; - let res = stream.next().await; - if let Some(subquery_batch) = res { - let subquery_batch = subquery_batch?; - match subquery_batch.column(0).len() { - 0 => subquery_arrays[subquery_i].push(null_array()?), + if subquery.output_partitioning().partition_count() != 1 { + return Err(ArrowError::ComputeError(format!( + "Sub query should have only one partition but got {}", + subquery.output_partitioning().partition_count() + ))); + } + let mut stream = subquery.execute(0, context.clone()).await?; + let res = stream.next().await; + if let Some(subquery_batch) = res { + let subquery_batch = subquery_batch?; + match subquery_type { + SubqueryType::Scalar => match subquery_batch + .column(0) + .len() + { + 0 => subquery_arrays[subquery_i].push(null_array()), 1 => subquery_arrays[subquery_i] .push(subquery_batch.column(0).clone()), _ => return Err(ArrowError::ComputeError( "Sub query should return no more than one row" .to_string(), )), - }; - } else { - subquery_arrays[subquery_i].push(null_array()?); - } + }, + }; + } else { + match subquery_type { + SubqueryType::Scalar => { + subquery_arrays[subquery_i].push(null_array()) + } + }; } } - let mut new_columns = batch.columns().to_vec(); - for subquery_array in subquery_arrays { - new_columns.push(concat( - subquery_array - .iter() - .map(|a| a.as_ref()) - .collect::>() - .as_slice(), - )?); - } - RecordBatch::try_new(schema.clone(), new_columns) } - }); + let mut new_columns = batch.columns().to_vec(); + for subquery_array in subquery_arrays { + new_columns.push(concat( + subquery_array + .iter() + .map(|a| a.as_ref()) + .collect::>() + .as_slice(), + )?); + } + RecordBatch::try_new(schema.clone(), new_columns) + } + }); Ok(Box::pin(SubQueryStream { schema: self.schema.clone(), stream: Box::pin(res_stream), diff --git a/datafusion/core/src/sql/planner.rs b/datafusion/core/src/sql/planner.rs index 59c7b068910a..bd790b5ccdb1 100644 --- a/datafusion/core/src/sql/planner.rs +++ b/datafusion/core/src/sql/planner.rs @@ -19,8 +19,9 @@ use std::collections::HashSet; use std::iter; +use std::ops::RangeFrom; use std::str::FromStr; -use std::sync::{Arc, RwLock}; +use std::sync::{Arc, Mutex, RwLock}; use std::{convert::TryInto, vec}; use crate::catalog::TableReference; @@ -32,7 +33,7 @@ use crate::logical_plan::{ and, builder::expand_qualified_wildcard, builder::expand_wildcard, col, lit, normalize_col, rewrite_udtfs_to_columns, Column, CreateMemoryTable, DFSchema, DFSchemaRef, DropTable, Expr, ExprSchemable, Like, LogicalPlan, LogicalPlanBuilder, - Operator, PlanType, ToDFSchema, ToStringifiedPlan, + Operator, PlanType, SubqueryType, ToDFSchema, ToStringifiedPlan, }; use crate::optimizer::utils::exprlist_to_columns; use crate::prelude::JoinType; @@ -97,13 +98,14 @@ pub struct SqlToRel<'a, S: ContextProvider> { schema_provider: &'a S, table_columns_precedence_over_projection: bool, context: SqlToRelContext, + subquery_alias_iter: Arc>>, } /// Planning context #[derive(Default)] pub struct SqlToRelContext { outer_query_context_schema: Vec, - subqueries_plans: Option>>, + subqueries_plans: Option>>, } impl SqlToRelContext { @@ -115,12 +117,16 @@ impl SqlToRelContext { } } - fn add_subquery_plan(&self, plan: LogicalPlan) -> Result<()> { - self.subqueries_plans.as_ref().ok_or_else(|| DataFusionError::Plan(format!("Sub query {:?} planned outside of sub query context. This type of sub query isn't supported", plan)))?.write().unwrap().push(plan); + fn add_subquery_plan( + &self, + plan: LogicalPlan, + subquery_type: SubqueryType, + ) -> Result<()> { + self.subqueries_plans.as_ref().ok_or_else(|| DataFusionError::Plan(format!("Sub query {:?} planned outside of sub query context. This type of sub query isn't supported", plan)))?.write().unwrap().push((plan, subquery_type)); Ok(()) } - fn subqueries_plans(&self) -> Result>> { + fn subqueries_plans(&self) -> Result>> { Ok(if let Some(subqueries) = self.subqueries_plans.as_ref() { Some( subqueries @@ -170,6 +176,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { schema_provider, table_columns_precedence_over_projection, context: SqlToRelContext::default(), + subquery_alias_iter: Arc::new(Mutex::new(0..)), } } @@ -182,6 +189,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { table_columns_precedence_over_projection: self .table_columns_precedence_over_projection, context, + subquery_alias_iter: Arc::clone(&self.subquery_alias_iter), } } @@ -877,6 +885,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { selection: Option, plans: Vec, ) -> Result { + // TODO: enable subqueries for joins let plan = match selection { Some(predicate_expr) => { // build join schema @@ -978,6 +987,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { // remove join expressions from filter match remove_join_expressions(&filter_expr, &all_join_keys)? { Some(filter_expr) => { + let left = self.wrap_with_subquery_plan_if_necessary(left)?; LogicalPlanBuilder::from(left).filter(filter_expr)?.build() } _ => Ok(left), @@ -1011,7 +1021,10 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let empty_from = matches!(plans.first(), Some(LogicalPlan::EmptyRelation(_))); // process `where` clause - let plan = self.plan_selection(select.selection, plans)?; + let with_where_outer_query_context = + self.with_context(|c| c.subqueries_plans = Some(RwLock::new(Vec::new()))); + let plan = + with_where_outer_query_context.plan_selection(select.selection, plans)?; // process the SELECT expressions, with wildcards expanded. let with_outer_query_context = @@ -1200,8 +1213,10 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { .read() .map_err(|e| DataFusionError::Plan(e.to_string()))?; if !subqueries.is_empty() { + let (subqueries, types): (Vec<_>, Vec<_>) = + subqueries.clone().into_iter().unzip(); LogicalPlanBuilder::from(plan) - .subquery(subqueries.clone())? + .subquery(subqueries, types)? .build()? } else { plan @@ -1439,7 +1454,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { Expr::Column(col) => match &col.relation { Some(r) => { if let Some(plans) = self.context.subqueries_plans()? { - if plans.into_iter().any(|p| { + if plans.into_iter().any(|(p, _)| { p.schema().field_with_qualified_name(r, &col.name).is_ok() }) { return Ok(()); @@ -1450,7 +1465,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } None => { if let Some(plans) = self.context.subqueries_plans()? { - if plans.into_iter().any(|p| { + if plans.into_iter().any(|(p, _)| { !p.schema() .fields_with_unqualified_name(&col.name) .is_empty() @@ -2277,19 +2292,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { SQLExpr::Nested(e) => self.sql_expr_to_logical_expr(*e, schema), - SQLExpr::Subquery(q) => { - let with_outer_query_context = self.with_context(|c| c.outer_query_context_schema.push(Arc::new(schema.clone()))); - let alias_name = format!("subquery-{}", self.context.subqueries_plans().unwrap_or_default().unwrap_or_default().len()); - let plan = with_outer_query_context.query_to_plan_with_alias(*q, Some(alias_name), &mut HashMap::new())?; - - let fields = plan.schema().fields(); - if fields.len() != 1 { - return Err(DataFusionError::Plan(format!("Correlated sub query requires only one column in result set but found: {:?}", fields))); - } - let column = fields.iter().next().unwrap().qualified_column(); - self.context.add_subquery_plan(plan)?; - Ok(Expr::Column(column)) - } + SQLExpr::Subquery(q) => self.subquery_to_plan(q, SubqueryType::Scalar, schema), SQLExpr::DotExpr { expr, field } => { Ok(Expr::GetIndexedField { @@ -2714,6 +2717,46 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { ))) } } + + fn subquery_to_plan( + &self, + query: Box, + subquery_type: SubqueryType, + schema: &DFSchema, + ) -> Result { + let with_outer_query_context = self.with_context(|c| { + c.outer_query_context_schema.push(Arc::new(schema.clone())) + }); + let alias_name = { + let mut subquery_alias_iter = with_outer_query_context + .subquery_alias_iter + .lock() + .map_err(|_| { + DataFusionError::Plan( + "Unable to lock subquery alias iterator".to_string(), + ) + })?; + let alias_index = subquery_alias_iter.next().ok_or_else(|| { + DataFusionError::Plan( + "Unable to assign an alias to a subquery".to_string(), + ) + })?; + format!("__subquery-{}", alias_index) + }; + let plan = with_outer_query_context.query_to_plan_with_alias( + *query, + Some(alias_name), + &mut HashMap::new(), + )?; + + let fields = plan.schema().fields(); + if fields.len() != 1 { + return Err(DataFusionError::Plan(format!("Correlated sub query requires only one column in result set but found: {:?}", fields))); + } + let column = fields.iter().next().unwrap().qualified_column(); + self.context.add_subquery_plan(plan, subquery_type)?; + Ok(Expr::Column(column)) + } } /// Normalize a SQL object name @@ -4877,25 +4920,81 @@ mod tests { } #[test] - fn subquery() { + fn subquery_select() { let sql = "select person.id, (select lineitem.l_item_id from lineitem where person.id = lineitem.l_item_id limit 1) from person"; - let expected = "Projection: #person.id, #subquery-0.l_item_id\ - \n Subquery\ + let expected = "Projection: #person.id, #__subquery-0.l_item_id\ + \n Subquery: types=[Scalar]\ \n TableScan: person projection=None\ \n Limit: skip=None, fetch=1\ - \n Projection: #lineitem.l_item_id, alias=subquery-0\ + \n Projection: #lineitem.l_item_id, alias=__subquery-0\ \n Filter: ^#person.id = #lineitem.l_item_id\ \n TableScan: lineitem projection=None"; quick_test(sql, expected); } #[test] - fn subquery_no_from() { + fn subquery_select_without_from() { let sql = "select person.id, (select person.age + 1) from person"; - let expected = "Projection: #person.id, #subquery-0.person.age + Int64(1)\ - \n Subquery\ + let expected = "Projection: #person.id, #__subquery-0.person.age + Int64(1)\ + \n Subquery: types=[Scalar]\ \n TableScan: person projection=None\ - \n Projection: ^#person.age + Int64(1), alias=subquery-0\ + \n Projection: ^#person.age + Int64(1), alias=__subquery-0\ + \n EmptyRelation"; + quick_test(sql, expected); + } + + #[test] + fn subquery_where() { + let sql = "select person.id from person where person.id > (select lineitem.l_item_id from lineitem limit 1)"; + let expected = "Projection: #person.id\ + \n Filter: #person.id > #__subquery-0.l_item_id\ + \n Subquery: types=[Scalar]\ + \n TableScan: person projection=None\ + \n Limit: skip=None, fetch=1\ + \n Projection: #lineitem.l_item_id, alias=__subquery-0\ + \n TableScan: lineitem projection=None"; + quick_test(sql, expected); + } + + #[test] + fn subquery_where_without_from() { + let sql = "select person.id from person where person.id = (select person.id)"; + let expected = "Projection: #person.id\ + \n Filter: #person.id = #__subquery-0.person.id\ + \n Subquery: types=[Scalar]\ + \n TableScan: person projection=None\ + \n Projection: ^#person.id, alias=__subquery-0\ + \n EmptyRelation"; + quick_test(sql, expected); + } + + #[test] + fn subquery_select_and_where() { + let sql = "select person.id, (select person.id) from person where person.id > (select lineitem.l_item_id from lineitem limit 1)"; + let expected = "Projection: #person.id, #__subquery-1.person.id\ + \n Subquery: types=[Scalar]\ + \n Filter: #person.id > #__subquery-0.l_item_id\ + \n Subquery: types=[Scalar]\ + \n TableScan: person projection=None\ + \n Limit: skip=None, fetch=1\ + \n Projection: #lineitem.l_item_id, alias=__subquery-0\ + \n TableScan: lineitem projection=None\ + \n Projection: ^#person.id, alias=__subquery-1\ + \n EmptyRelation"; + quick_test(sql, expected); + } + + #[test] + fn subquery_select_and_where_without_from() { + let sql = "select person.id, (select person.id) from person where person.id = (select person.id)"; + let expected = "Projection: #person.id, #__subquery-1.person.id\ + \n Subquery: types=[Scalar]\ + \n Filter: #person.id = #__subquery-0.person.id\ + \n Subquery: types=[Scalar]\ + \n TableScan: person projection=None\ + \n Projection: ^#person.id, alias=__subquery-0\ + \n EmptyRelation\ + \n Projection: ^#person.id, alias=__subquery-1\ \n EmptyRelation"; quick_test(sql, expected); } diff --git a/datafusion/core/tests/sql/subquery.rs b/datafusion/core/tests/sql/subquery.rs index fea43d7d0b41..a60e3801aebd 100644 --- a/datafusion/core/tests/sql/subquery.rs +++ b/datafusion/core/tests/sql/subquery.rs @@ -18,7 +18,7 @@ use super::*; #[tokio::test] -async fn subquery_no_from() -> Result<()> { +async fn subquery_select_no_from() -> Result<()> { let ctx = SessionContext::new(); register_aggregate_simple_csv(&ctx).await?; @@ -39,7 +39,7 @@ async fn subquery_no_from() -> Result<()> { } #[tokio::test] -async fn subquery_with_from() -> Result<()> { +async fn subquery_select_with_from() -> Result<()> { let ctx = SessionContext::new(); register_aggregate_simple_csv(&ctx).await?; @@ -59,6 +59,96 @@ async fn subquery_with_from() -> Result<()> { Ok(()) } +#[tokio::test] +async fn subquery_where_no_from() -> Result<()> { + let ctx = SessionContext::new(); + register_aggregate_simple_csv(&ctx).await?; + + let sql = + "SELECT DISTINCT c1 FROM aggregate_simple o WHERE (SELECT NOT c3) ORDER BY c1"; + let actual = execute_to_batches(&ctx, sql).await; + + let expected = vec![ + "+---------+", + "| c1 |", + "+---------+", + "| 0.00002 |", + "| 0.00004 |", + "+---------+", + ]; + assert_batches_eq!(expected, &actual); + + Ok(()) +} + +#[tokio::test] +async fn subquery_where_with_from() -> Result<()> { + let ctx = SessionContext::new(); + register_aggregate_simple_csv(&ctx).await?; + + let sql = "SELECT DISTINCT c1 FROM aggregate_simple o WHERE (SELECT c3 FROM aggregate_simple p WHERE o.c1 = p.c1 LIMIT 1) ORDER BY c1"; + let actual = execute_to_batches(&ctx, sql).await; + + let expected = vec![ + "+---------+", + "| c1 |", + "+---------+", + "| 0.00001 |", + "| 0.00003 |", + "| 0.00005 |", + "+---------+", + ]; + assert_batches_eq!(expected, &actual); + + Ok(()) +} + +// TODO: plans but does not execute +#[ignore] +#[tokio::test] +async fn subquery_select_and_where_no_from() -> Result<()> { + let ctx = SessionContext::new(); + register_aggregate_simple_csv(&ctx).await?; + + let sql = "SELECT c1, (SELECT c1 + 1) FROM aggregate_simple o WHERE (SELECT NOT c3) ORDER BY c1 LIMIT 2"; + let actual = execute_to_batches(&ctx, sql).await; + + let expected = vec![ + "+---------+------------------+", + "| c1 | c1 Plus Int64(1) |", + "+---------+------------------+", + "| 0.00002 | 1.00002 |", + "| 0.00004 | 1.00004 |", + "+---------+------------------+", + ]; + assert_batches_eq!(expected, &actual); + + Ok(()) +} + +// TODO: plans but does not execute +#[ignore] +#[tokio::test] +async fn subquery_select_and_where_with_from() -> Result<()> { + let ctx = SessionContext::new(); + register_aggregate_simple_csv(&ctx).await?; + + let sql = "SELECT c1, (SELECT c1 + 1) FROM aggregate_simple o WHERE (SELECT c3 FROM aggregate_simple p WHERE o.c1 = p.c1 LIMIT 1) ORDER BY c1 LIMIT 2"; + let actual = execute_to_batches(&ctx, sql).await; + + let expected = vec![ + "+---------+------------------+", + "| c1 | c1 Plus Int64(1) |", + "+---------+------------------+", + "| 0.00001 | 1.00001 |", + "| 0.00003 | 1.00003 |", + "+---------+------------------+", + ]; + assert_batches_eq!(expected, &actual); + + Ok(()) +} + #[tokio::test] async fn subquery_projection_pushdown() -> Result<()> { let ctx = SessionContext::new(); From b57961b31b8c1285d0aee63a3319725fad48e3d5 Mon Sep 17 00:00:00 2001 From: Alexandr Romanenko Date: Mon, 5 Feb 2024 23:52:20 +0300 Subject: [PATCH 2/2] Fix an issue where the query with subqueries in SELECT and WHERE won't execute --- datafusion/core/src/physical_plan/planner.rs | 1 + datafusion/core/tests/sql/subquery.rs | 5 ++--- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/datafusion/core/src/physical_plan/planner.rs b/datafusion/core/src/physical_plan/planner.rs index 1347856d70f5..f93a53485269 100644 --- a/datafusion/core/src/physical_plan/planner.rs +++ b/datafusion/core/src/physical_plan/planner.rs @@ -1033,6 +1033,7 @@ pub fn create_physical_expr( let cursors = execution_props.outer_query_cursors.clone(); let cursor = cursors .iter() + .rev() .find(|cur| cur.schema().field_with_name(c.name.as_str()).is_ok()) .ok_or_else(|| { DataFusionError::Execution(format!( diff --git a/datafusion/core/tests/sql/subquery.rs b/datafusion/core/tests/sql/subquery.rs index a60e3801aebd..572835296e9e 100644 --- a/datafusion/core/tests/sql/subquery.rs +++ b/datafusion/core/tests/sql/subquery.rs @@ -104,13 +104,12 @@ async fn subquery_where_with_from() -> Result<()> { } // TODO: plans but does not execute -#[ignore] #[tokio::test] async fn subquery_select_and_where_no_from() -> Result<()> { let ctx = SessionContext::new(); register_aggregate_simple_csv(&ctx).await?; - let sql = "SELECT c1, (SELECT c1 + 1) FROM aggregate_simple o WHERE (SELECT NOT c3) ORDER BY c1 LIMIT 2"; + let sql = "SELECT c1, (SELECT c1 + 1) FROM aggregate_simple o WHERE (SELECT NOT c3) ORDER BY c1 LIMIT 3"; let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ @@ -118,6 +117,7 @@ async fn subquery_select_and_where_no_from() -> Result<()> { "| c1 | c1 Plus Int64(1) |", "+---------+------------------+", "| 0.00002 | 1.00002 |", + "| 0.00002 | 1.00002 |", "| 0.00004 | 1.00004 |", "+---------+------------------+", ]; @@ -127,7 +127,6 @@ async fn subquery_select_and_where_no_from() -> Result<()> { } // TODO: plans but does not execute -#[ignore] #[tokio::test] async fn subquery_select_and_where_with_from() -> Result<()> { let ctx = SessionContext::new();