feat: Support scalar subquery in WHERE
MazterQyou committed Apr 6, 2023
1 parent 6b006e5 commit b39fc37
Showing 8 changed files with 269 additions and 113 deletions.
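
For context, this commit lets a parenthesized scalar subquery appear as a value inside a WHERE predicate. A minimal, hypothetical example of the query shape it enables (the table and column names are illustrative, and the SessionContext API is assumed from upstream DataFusion of this generation):

use datafusion::error::Result;
use datafusion::prelude::*;

#[tokio::main]
async fn main() -> Result<()> {
    let ctx = SessionContext::new();
    ctx.register_csv("orders", "orders.csv", CsvReadOptions::new())
        .await?;
    // The subquery must evaluate to at most one row; its single value
    // is compared against every row of the outer scan.
    let df = ctx
        .sql("SELECT id FROM orders WHERE amount > (SELECT AVG(amount) FROM orders)")
        .await?;
    df.show().await?;
    Ok(())
}
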
5 changes: 3 additions & 2 deletions datafusion/core/src/logical_plan/builder.rs
@@ -47,7 +48,8 @@ use super::{dfschema::ToDFSchema, expr_rewriter::coerce_plan_expr_for_schema, Di
use super::{exprlist_to_fields, Expr, JoinConstraint, JoinType, LogicalPlan, PlanType};
use crate::logical_plan::{
columnize_expr, normalize_col, normalize_cols, rewrite_sort_cols_by_aggs, Column,
CrossJoin, DFField, DFSchema, DFSchemaRef, Limit, Partitioning, Repartition, Values,
CrossJoin, DFField, DFSchema, DFSchemaRef, Limit, Partitioning, Repartition,
SubqueryType, Values,
};
use crate::sql::utils::group_window_expr_by_sort_keys;

@@ -527,7 +528,7 @@ impl LogicalPlanBuilder {
/// Apply correlated sub query
pub fn subquery(
&self,
subqueries: impl IntoIterator<Item = impl Into<LogicalPlan>>,
subqueries: impl IntoIterator<Item = impl Into<(LogicalPlan, SubqueryType)>>,
) -> Result<Self> {
let subqueries = subqueries.into_iter().map(|l| l.into()).collect::<Vec<_>>();
let schema = Arc::new(Subquery::merged_schema(&self.plan, &subqueries));
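
The builder now threads a SubqueryType alongside every subquery plan. A minimal sketch of the new call shape (both input plans are hypothetical and built elsewhere; a plain `(LogicalPlan, SubqueryType)` tuple satisfies the `Into` bound):

use datafusion::error::Result;
use datafusion::logical_plan::{LogicalPlan, LogicalPlanBuilder, SubqueryType};

fn attach_scalar_subquery(input: LogicalPlan, scalar: LogicalPlan) -> Result<LogicalPlan> {
    // Each subquery is tagged with the type that later drives schema
    // transformation and physical execution.
    LogicalPlanBuilder::from(input)
        .subquery(vec![(scalar, SubqueryType::Scalar)])?
        .build()
}
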
2 changes: 1 addition & 1 deletion datafusion/core/src/logical_plan/mod.rs
@@ -68,6 +68,6 @@ pub use plan::{
CreateCatalogSchema, CreateExternalTable, CreateMemoryTable, CrossJoin, Distinct,
DropTable, EmptyRelation, Filter, JoinConstraint, JoinType, Limit, LogicalPlan,
Partitioning, PlanType, PlanVisitor, Repartition, StringifiedPlan, Subquery,
TableScan, ToStringifiedPlan, Union, Values,
SubqueryType, TableScan, ToStringifiedPlan, Union, Values,
};
pub use registry::FunctionRegistry;
48 changes: 39 additions & 9 deletions datafusion/core/src/logical_plan/plan.rs
@@ -268,21 +268,51 @@ pub struct Limit {
#[derive(Clone)]
pub struct Subquery {
/// The list of sub queries
pub subqueries: Vec<LogicalPlan>,
pub subqueries: Vec<(LogicalPlan, SubqueryType)>,
/// The incoming logical plan
pub input: Arc<LogicalPlan>,
/// The schema description of the output
pub schema: DFSchemaRef,
}

/// Subquery type
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum SubqueryType {
/// Scalar (SELECT, WHERE) evaluating to one value
Scalar,
// This will be extended with `Exists` and `AnyAll` types.
}

impl Subquery {
/// Merge schema of main input and correlated subquery columns
pub fn merged_schema(input: &LogicalPlan, subqueries: &[LogicalPlan]) -> DFSchema {
subqueries.iter().fold((**input.schema()).clone(), |a, b| {
let mut res = a;
res.merge(b.schema());
res
})
pub fn merged_schema(
input: &LogicalPlan,
subqueries: &[(LogicalPlan, SubqueryType)],
) -> DFSchema {
subqueries
.iter()
.fold((**input.schema()).clone(), |input_schema, (plan, typ)| {
let mut res = input_schema;
let subquery_schema = Self::transform_dfschema(plan.schema(), *typ);
res.merge(&subquery_schema);
res
})
}

/// Transform DataFusion schema according to subquery type
pub fn transform_dfschema(schema: &DFSchema, typ: SubqueryType) -> DFSchema {
match typ {
SubqueryType::Scalar => schema.clone(),
// Schema will be transformed for `Exists` and `AnyAll`
}
}

/// Transform Arrow field according to subquery type
pub fn transform_field(field: &Field, typ: SubqueryType) -> Field {
match typ {
SubqueryType::Scalar => field.clone(),
// Field will be transformed for `Exists` and `AnyAll`
}
}
}
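
For `Scalar`, both transforms are the identity: a scalar subquery appends its single column to the merged schema unchanged, and the planned `Exists` and `AnyAll` variants are where real rewriting would happen. A small sketch against this fork's API (the field name is hypothetical):

use arrow::datatypes::{DataType, Field};
use datafusion::logical_plan::{Subquery, SubqueryType};

fn main() {
    // Hypothetical single output column of a scalar subquery.
    let field = Field::new("avg_amount", DataType::Float64, true);
    // Identity for Scalar: the field passes through untouched.
    let transformed = Subquery::transform_field(&field, SubqueryType::Scalar);
    assert_eq!(field, transformed);
}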

@@ -585,7 +615,7 @@ impl LogicalPlan {
input, subqueries, ..
}) => vec![input.as_ref()]
.into_iter()
.chain(subqueries.iter())
.chain(subqueries.iter().map(|(q, _)| q))
.collect(),
LogicalPlan::Filter(Filter { input, .. }) => vec![input],
LogicalPlan::Repartition(Repartition { input, .. }) => vec![input],
@@ -728,7 +758,7 @@ impl LogicalPlan {
input, subqueries, ..
}) => {
input.accept(visitor)?;
for input in subqueries {
for (input, _) in subqueries {
if !input.accept(visitor)? {
return Ok(false);
}
2 changes: 1 addition & 1 deletion datafusion/core/src/optimizer/projection_push_down.rs
@@ -456,7 +456,7 @@ fn optimize_plan(
input, subqueries, ..
}) => {
let mut subquery_required_columns = HashSet::new();
for subquery in subqueries.iter() {
for subquery in subqueries.iter().map(|(q, _)| q) {
let mut inputs = vec![subquery];
while !inputs.is_empty() {
let mut next_inputs = Vec::new();
18 changes: 11 additions & 7 deletions datafusion/core/src/optimizer/utils.rs
@@ -161,13 +161,17 @@ pub fn from_plan(
alias: alias.clone(),
}))
}
LogicalPlan::Subquery(Subquery { schema, .. }) => {
Ok(LogicalPlan::Subquery(Subquery {
subqueries: inputs[1..inputs.len()].to_vec(),
input: Arc::new(inputs[0].clone()),
schema: schema.clone(),
}))
}
LogicalPlan::Subquery(Subquery {
schema, subqueries, ..
}) => Ok(LogicalPlan::Subquery(Subquery {
subqueries: inputs[1..inputs.len()]
.iter()
.zip(subqueries.iter())
.map(|(input, (_, t))| (input.clone(), *t))
.collect(),
input: Arc::new(inputs[0].clone()),
schema: schema.clone(),
})),
LogicalPlan::TableUDFs(TableUDFs { .. }) => {
Ok(LogicalPlan::TableUDFs(TableUDFs {
expr: expr.to_vec(),
9 changes: 5 additions & 4 deletions datafusion/core/src/physical_plan/planner.rs
@@ -31,7 +31,7 @@ use crate::logical_plan::plan::{
};
use crate::logical_plan::{
unalias, unnormalize_cols, CrossJoin, DFSchema, Distinct, Expr, Like, LogicalPlan,
Operator, Partitioning as LogicalPartitioning, PlanType, Repartition,
Operator, Partitioning as LogicalPartitioning, PlanType, Repartition, SubqueryType,
ToStringifiedPlan, Union, UserDefinedLogicalNode,
};
use crate::logical_plan::{Limit, Values};
@@ -923,11 +923,12 @@ impl DefaultPhysicalPlanner {
new_session_state.execution_props = new_session_state.execution_props.with_outer_query_cursor(cursor.clone());
new_session_state.config.target_partitions = 1;
let subqueries = futures::stream::iter(subqueries)
.then(|lp| self.create_initial_plan(lp, &new_session_state))
.then(|(lp, _)| self.create_initial_plan(lp, &new_session_state))
.try_collect::<Vec<_>>()
.await?.into_iter()
.map(|p| -> Arc<dyn ExecutionPlan> {
Arc::new(CoalescePartitionsExec::new(p))
.zip(subqueries.iter())
.map(|(p, (_, t))| -> (Arc<dyn ExecutionPlan>, SubqueryType) {
(Arc::new(CoalescePartitionsExec::new(p)), *t)
})
.collect::<Vec<_>>();
let input = self.create_initial_plan(input, &new_session_state).await?;
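
Each planned subquery is collapsed to a single output partition before being handed to SubqueryExec, which only ever executes partition 0. The wrapping step in isolation (a sketch; `plan` is a hypothetical physical plan built elsewhere):

use std::sync::Arc;

use datafusion::logical_plan::SubqueryType;
use datafusion::physical_plan::coalesce_partitions::CoalescePartitionsExec;
use datafusion::physical_plan::ExecutionPlan;

fn wrap(plan: Arc<dyn ExecutionPlan>, typ: SubqueryType) -> (Arc<dyn ExecutionPlan>, SubqueryType) {
    // CoalescePartitionsExec merges all input partitions into one, so the
    // partition-count check in SubqueryExec::execute passes downstream.
    (Arc::new(CoalescePartitionsExec::new(plan)), typ)
}
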
143 changes: 82 additions & 61 deletions datafusion/core/src/physical_plan/subquery.rs
@@ -28,6 +28,7 @@ use std::sync::Arc;
use std::task::{Context, Poll};

use crate::error::{DataFusionError, Result};
use crate::logical_plan::{Subquery, SubqueryType};
use crate::physical_plan::{DisplayFormatType, ExecutionPlan, Partitioning};
use arrow::array::new_null_array;
use arrow::datatypes::{Schema, SchemaRef};
@@ -46,7 +47,7 @@ use futures::stream::StreamExt;
#[derive(Debug)]
pub struct SubqueryExec {
/// Sub queries
subqueries: Vec<Arc<dyn ExecutionPlan>>,
subqueries: Vec<(Arc<dyn ExecutionPlan>, SubqueryType)>,
/// Merged schema
schema: SchemaRef,
/// The input plan
@@ -58,15 +59,22 @@ pub struct SubqueryExec {
impl SubqueryExec {
/// Create a projection on an input
pub fn try_new(
subqueries: Vec<Arc<dyn ExecutionPlan>>,
subqueries: Vec<(Arc<dyn ExecutionPlan>, SubqueryType)>,
input: Arc<dyn ExecutionPlan>,
cursor: Arc<OuterQueryCursor>,
) -> Result<Self> {
let input_schema = input.schema();

let mut total_fields = input_schema.fields().clone();
for q in subqueries.iter() {
total_fields.append(&mut q.schema().fields().clone());
for (q, t) in subqueries.iter() {
total_fields.append(
&mut q
.schema()
.fields()
.iter()
.map(|f| Subquery::transform_field(f, *t))
.collect(),
);
}

let merged_schema = Schema::new_with_metadata(total_fields, HashMap::new());
@@ -100,7 +108,7 @@ impl ExecutionPlan for SubqueryExec {

fn children(&self) -> Vec<Arc<dyn ExecutionPlan>> {
let mut res = vec![self.input.clone()];
res.extend(self.subqueries.iter().cloned());
res.extend(self.subqueries.iter().map(|(i, _)| i).cloned());
res
}

@@ -134,7 +142,13 @@ impl ExecutionPlan for SubqueryExec {
}

Ok(Arc::new(SubqueryExec::try_new(
children.iter().skip(1).cloned().collect(),
children
.iter()
.skip(1)
.cloned()
.zip(self.subqueries.iter())
.map(|(p, (_, t))| (p, *t))
.collect(),
children[0].clone(),
self.cursor.clone(),
)?))
@@ -151,71 +165,78 @@ impl ExecutionPlan for SubqueryExec {
let context = context.clone();
let size_hint = stream.size_hint();
let schema = self.schema.clone();
let res_stream =
stream.then(move |batch| {
let cursor = cursor.clone();
let context = context.clone();
let subqueries = subqueries.clone();
let schema = schema.clone();
async move {
let batch = batch?;
let b = Arc::new(batch.clone());
cursor.set_batch(b)?;
let mut subquery_arrays = vec![Vec::new(); subqueries.len()];
for i in 0..batch.num_rows() {
cursor.set_position(i)?;
for (subquery_i, subquery) in subqueries.iter().enumerate() {
let null_array = || {
let schema = subquery.schema();
let fields = schema.fields();
if fields.len() != 1 {
return Err(ArrowError::ComputeError(format!(
"Sub query should have only one column but got {}",
fields.len()
)));
}

let data_type = fields.get(0).unwrap().data_type();
Ok(new_null_array(data_type, 1))
};
let res_stream = stream.then(move |batch| {
let cursor = cursor.clone();
let context = context.clone();
let subqueries = subqueries.clone();
let schema = schema.clone();
async move {
let batch = batch?;
let b = Arc::new(batch.clone());
cursor.set_batch(b)?;
let mut subquery_arrays = vec![Vec::new(); subqueries.len()];
for i in 0..batch.num_rows() {
cursor.set_position(i)?;
for (subquery_i, (subquery, subquery_type)) in
subqueries.iter().enumerate()
{
let schema = subquery.schema();
let fields = schema.fields();
if fields.len() != 1 {
return Err(ArrowError::ComputeError(format!(
"Sub query should have only one column but got {}",
fields.len()
)));
}
let data_type = fields.get(0).unwrap().data_type();
let null_array = || new_null_array(data_type, 1);

if subquery.output_partitioning().partition_count() != 1 {
return Err(ArrowError::ComputeError(format!(
"Sub query should have only one partition but got {}",
subquery.output_partitioning().partition_count()
)));
}
let mut stream = subquery.execute(0, context.clone()).await?;
let res = stream.next().await;
if let Some(subquery_batch) = res {
let subquery_batch = subquery_batch?;
match subquery_batch.column(0).len() {
0 => subquery_arrays[subquery_i].push(null_array()?),
if subquery.output_partitioning().partition_count() != 1 {
return Err(ArrowError::ComputeError(format!(
"Sub query should have only one partition but got {}",
subquery.output_partitioning().partition_count()
)));
}
let mut stream = subquery.execute(0, context.clone()).await?;
let res = stream.next().await;
if let Some(subquery_batch) = res {
let subquery_batch = subquery_batch?;
match subquery_type {
SubqueryType::Scalar => match subquery_batch
.column(0)
.len()
{
0 => subquery_arrays[subquery_i].push(null_array()),
1 => subquery_arrays[subquery_i]
.push(subquery_batch.column(0).clone()),
_ => return Err(ArrowError::ComputeError(
"Sub query should return no more than one row"
.to_string(),
)),
};
} else {
subquery_arrays[subquery_i].push(null_array()?);
}
},
};
} else {
match subquery_type {
SubqueryType::Scalar => {
subquery_arrays[subquery_i].push(null_array())
}
};
}
}
let mut new_columns = batch.columns().to_vec();
for subquery_array in subquery_arrays {
new_columns.push(concat(
subquery_array
.iter()
.map(|a| a.as_ref())
.collect::<Vec<_>>()
.as_slice(),
)?);
}
RecordBatch::try_new(schema.clone(), new_columns)
}
});
let mut new_columns = batch.columns().to_vec();
for subquery_array in subquery_arrays {
new_columns.push(concat(
subquery_array
.iter()
.map(|a| a.as_ref())
.collect::<Vec<_>>()
.as_slice(),
)?);
}
RecordBatch::try_new(schema.clone(), new_columns)
}
});
Ok(Box::pin(SubQueryStream {
schema: self.schema.clone(),
stream: Box::pin(res_stream),
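
The per-row rule implemented above reduces to: an empty subquery result contributes NULL, exactly one row contributes its value, and anything more is an error. Distilled into a standalone helper (a sketch for illustration, not part of the diff):

use arrow::array::{new_null_array, ArrayRef};
use arrow::datatypes::DataType;
use arrow::error::ArrowError;

// Hypothetical distillation of the scalar-subquery row rule in SubqueryExec.
fn scalar_value(column: Option<ArrayRef>, data_type: &DataType) -> Result<ArrayRef, ArrowError> {
    match column {
        // The subquery produced no batch at all: same as an empty result.
        None => Ok(new_null_array(data_type, 1)),
        Some(col) => match col.len() {
            0 => Ok(new_null_array(data_type, 1)),
            1 => Ok(col),
            _ => Err(ArrowError::ComputeError(
                "Sub query should return no more than one row".to_string(),
            )),
        },
    }
}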
