Skip to content

Commit

Permalink
Add rewrite rule for push-down-limit for Extension and tests
Browse files: browse the repository at this point in the history
Signed-off-by: Austin Liu <[email protected]>
  • Loading branch information
austin362667 committed Sep 30, 2024
1 parent e9d8574 commit af8babf
Showing 1 changed file with 255 additions and 1 deletion.
256 changes: 255 additions & 1 deletion datafusion/optimizer/src/push_down_limit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,36 @@ impl OptimizerRule for PushDownLimit {
subquery_alias.input = Arc::new(new_limit);
Ok(Transformed::yes(LogicalPlan::SubqueryAlias(subquery_alias)))
}
LogicalPlan::Extension(extension_plan) => {
// The limit sits directly above a user-defined extension node. The limit is
// only copied beneath the node when the node opts in through
// `allows_limit_to_inputs()`.
if !extension_plan.node.allows_limit_to_inputs() {
// The node did not opt in, so its inputs are left untouched and the plan is
// handed to `original_limit` (presumably re-creating the limit above the
// unchanged node and reporting no change — confirm against that helper).
return original_limit(
skip,
fetch,
LogicalPlan::Extension(extension_plan),
);
}

// Wrap every input of the extension node in its own `Limit`. The pushed-down
// limit never skips (`skip: 0`) and fetches `fetch + skip` rows, so each input
// still produces enough rows for a limit above the node to apply `skip`.
// NOTE(review): `fetch + skip` is a plain addition; confirm that overflow for
// very large values is not a concern here.
let new_children = extension_plan
.node
.inputs()
.into_iter()
.map(|child| {
LogicalPlan::Limit(Limit {
skip: 0,
fetch: Some(fetch + skip),
// `inputs()` hands out borrowed plans, so each child is cloned to own it
input: Arc::new(child.clone()),
})
})
.collect::<Vec<_>>();

// Rebuild the extension node with the same expressions but the limited inputs
let child_plan = LogicalPlan::Extension(extension_plan);
let new_extension =
child_plan.with_new_exprs(child_plan.expressions(), new_children)?;

// Hand the rewritten node to `transformed_limit` together with the original
// `skip`/`fetch` (presumably keeping the outer limit on top — the tests in this
// file expect a `Limit` above the extension node after optimization).
transformed_limit(skip, fetch, new_extension)
}
input => original_limit(skip, fetch, input),
}
}
Expand Down Expand Up @@ -258,17 +288,241 @@ fn push_down_join(mut join: Join, limit: usize) -> Transformed<Join> {

#[cfg(test)]
mod test {
use std::cmp::Ordering;
use std::fmt::{Debug, Formatter};
use std::vec;

use super::*;
use crate::test::*;
use datafusion_expr::{col, exists, logical_plan::builder::LogicalPlanBuilder};

use datafusion_common::DFSchemaRef;
use datafusion_expr::{
col, exists, logical_plan::builder::LogicalPlanBuilder, Expr, Extension,
UserDefinedLogicalNodeCore,
};
use datafusion_functions_aggregate::expr_fn::max;

/// Test helper: builds a fresh [`PushDownLimit`] rule and delegates to
/// `assert_optimized_plan_eq` with the given `plan` and `expected` text.
fn assert_optimized_plan_equal(plan: LogicalPlan, expected: &str) -> Result<()> {
    let rule = Arc::new(PushDownLimit::new());
    assert_optimized_plan_eq(rule, plan, expected)
}

/// Test-only user-defined plan node holding a list of input plans and a schema.
/// Its `UserDefinedLogicalNodeCore` implementation in this module returns `true`
/// from `allows_limit_to_inputs`, i.e. it opts in to limit push-down.
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct NoopPlan {
/// Child plans, handed out by `inputs()`.
input: Vec<LogicalPlan>,
/// Schema handed out by `schema()`.
schema: DFSchemaRef,
}

// Hand-written rather than derived because of the `schema` field: the ordering
// looks at the `input` plans only and leaves `schema` out of the comparison.
impl PartialOrd for NoopPlan {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        let (lhs, rhs) = (&self.input, &other.input);
        PartialOrd::partial_cmp(lhs, rhs)
    }
}

/// Extension-node behaviour for [`NoopPlan`]: a pass-through style node used by
/// the tests that opts in to having limits pushed onto its inputs.
impl UserDefinedLogicalNodeCore for NoopPlan {
    /// Name reported for this node.
    fn name(&self) -> &str {
        "NoopPlan"
    }

    /// Borrows every stored child plan, in stored order.
    fn inputs(&self) -> Vec<&LogicalPlan> {
        let mut children = Vec::with_capacity(self.input.len());
        children.extend(self.input.iter());
        children
    }

    /// Returns the schema stored on the node.
    fn schema(&self) -> &DFSchemaRef {
        &self.schema
    }

    /// Concatenates the expressions of all child plans, child by child.
    fn expressions(&self) -> Vec<Expr> {
        let mut collected = Vec::new();
        for child in &self.input {
            collected.extend(child.expressions());
        }
        collected
    }

    /// Writes the fixed label used when the plan is displayed.
    fn fmt_for_explain(&self, f: &mut Formatter) -> std::fmt::Result {
        f.write_str("NoopPlan")
    }

    /// Builds a copy of the node with `inputs` as its children; the passed
    /// expressions are ignored and the existing schema is shared.
    fn with_exprs_and_inputs(
        &self,
        _exprs: Vec<Expr>,
        inputs: Vec<LogicalPlan>,
    ) -> Result<Self> {
        let schema = Arc::clone(&self.schema);
        Ok(Self {
            input: inputs,
            schema,
        })
    }

    /// Opts in to limit push-down onto this node's inputs.
    fn allows_limit_to_inputs(&self) -> bool {
        true
    }
}

/// Test-only user-defined plan node holding a list of input plans and a schema.
/// Its `UserDefinedLogicalNodeCore` implementation in this module returns `false`
/// from `allows_limit_to_inputs`, i.e. it does not opt in to limit push-down.
#[derive(Debug, PartialEq, Eq, Hash)]
struct NoLimitNoopPlan {
/// Child plans, handed out by `inputs()`.
input: Vec<LogicalPlan>,
/// Schema handed out by `schema()`.
schema: DFSchemaRef,
}

// Hand-written rather than derived because of the `schema` field: the ordering
// looks at the `input` plans only and leaves `schema` out of the comparison.
impl PartialOrd for NoLimitNoopPlan {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        let (lhs, rhs) = (&self.input, &other.input);
        PartialOrd::partial_cmp(lhs, rhs)
    }
}

/// Extension-node behaviour for [`NoLimitNoopPlan`]: identical in shape to the
/// opt-in test node, except that it refuses limit push-down onto its inputs.
impl UserDefinedLogicalNodeCore for NoLimitNoopPlan {
    /// Name reported for this node.
    fn name(&self) -> &str {
        "NoLimitNoopPlan"
    }

    /// Borrows every stored child plan, in stored order.
    fn inputs(&self) -> Vec<&LogicalPlan> {
        let mut children = Vec::with_capacity(self.input.len());
        children.extend(self.input.iter());
        children
    }

    /// Returns the schema stored on the node.
    fn schema(&self) -> &DFSchemaRef {
        &self.schema
    }

    /// Concatenates the expressions of all child plans, child by child.
    fn expressions(&self) -> Vec<Expr> {
        let mut collected = Vec::new();
        for child in &self.input {
            collected.extend(child.expressions());
        }
        collected
    }

    /// Writes the fixed label used when the plan is displayed.
    fn fmt_for_explain(&self, f: &mut Formatter) -> std::fmt::Result {
        f.write_str("NoLimitNoopPlan")
    }

    /// Builds a copy of the node with `inputs` as its children; the passed
    /// expressions are ignored and the existing schema is shared.
    fn with_exprs_and_inputs(
        &self,
        _exprs: Vec<Expr>,
        inputs: Vec<LogicalPlan>,
    ) -> Result<Self> {
        let schema = Arc::clone(&self.schema);
        Ok(Self {
            input: inputs,
            schema,
        })
    }

    /// Refuses limit push-down onto this node's inputs.
    fn allows_limit_to_inputs(&self) -> bool {
        false
    }
}
/// A limit above an opt-in extension node is copied onto the node's single
/// input, and from there into the table scan's `fetch`.
#[test]
fn limit_pushdown_basic() -> Result<()> {
    let table_scan = test_table_scan()?;
    // Grab the schema first so the scan itself can be moved into the node.
    let schema = Arc::clone(table_scan.schema());
    let noop_plan = LogicalPlan::Extension(Extension {
        node: Arc::new(NoopPlan {
            input: vec![table_scan],
            schema,
        }),
    });

    let plan = LogicalPlanBuilder::from(noop_plan)
        .limit(0, Some(1000))?
        .build()?;

    // The indented plan display adds two spaces per nesting level.
    let expected = "Limit: skip=0, fetch=1000\
    \n  NoopPlan\
    \n    Limit: skip=0, fetch=1000\
    \n      TableScan: test, fetch=1000";

    assert_optimized_plan_equal(plan, expected)
}

/// With a non-zero `skip`, the pushed-down limit must fetch `skip + fetch` rows
/// and skip nothing, while the outer limit keeps the original `skip`.
#[test]
fn limit_pushdown_with_skip() -> Result<()> {
    let table_scan = test_table_scan()?;
    // Grab the schema first so the scan itself can be moved into the node.
    let schema = Arc::clone(table_scan.schema());
    let noop_plan = LogicalPlan::Extension(Extension {
        node: Arc::new(NoopPlan {
            input: vec![table_scan],
            schema,
        }),
    });

    let plan = LogicalPlanBuilder::from(noop_plan)
        .limit(10, Some(1000))?
        .build()?;

    // The indented plan display adds two spaces per nesting level.
    let expected = "Limit: skip=10, fetch=1000\
    \n  NoopPlan\
    \n    Limit: skip=0, fetch=1010\
    \n      TableScan: test, fetch=1010";

    assert_optimized_plan_equal(plan, expected)
}

/// Two stacked limits are expected to collapse into one (`skip=30, fetch=500`)
/// before the combined limit is pushed below the extension node.
#[test]
fn limit_pushdown_multiple_limits() -> Result<()> {
    let table_scan = test_table_scan()?;
    // Grab the schema first so the scan itself can be moved into the node.
    let schema = Arc::clone(table_scan.schema());
    let noop_plan = LogicalPlan::Extension(Extension {
        node: Arc::new(NoopPlan {
            input: vec![table_scan],
            schema,
        }),
    });

    let plan = LogicalPlanBuilder::from(noop_plan)
        .limit(10, Some(1000))?
        .limit(20, Some(500))?
        .build()?;

    // The indented plan display adds two spaces per nesting level.
    let expected = "Limit: skip=30, fetch=500\
    \n  NoopPlan\
    \n    Limit: skip=0, fetch=530\
    \n      TableScan: test, fetch=530";

    assert_optimized_plan_equal(plan, expected)
}

/// When the extension node has several inputs, every input receives its own
/// copy of the pushed-down limit.
#[test]
fn limit_pushdown_multiple_inputs() -> Result<()> {
    let table_scan = test_table_scan()?;
    // Grab the schema first; only one clone of the scan is then needed.
    let schema = Arc::clone(table_scan.schema());
    let noop_plan = LogicalPlan::Extension(Extension {
        node: Arc::new(NoopPlan {
            input: vec![table_scan.clone(), table_scan],
            schema,
        }),
    });

    let plan = LogicalPlanBuilder::from(noop_plan)
        .limit(0, Some(1000))?
        .build()?;

    // The indented plan display adds two spaces per nesting level; the two
    // inputs are siblings and therefore share the same depth.
    let expected = "Limit: skip=0, fetch=1000\
    \n  NoopPlan\
    \n    Limit: skip=0, fetch=1000\
    \n      TableScan: test, fetch=1000\
    \n    Limit: skip=0, fetch=1000\
    \n      TableScan: test, fetch=1000";

    assert_optimized_plan_equal(plan, expected)
}

/// An extension node that does not opt in keeps the limit above it; nothing is
/// pushed onto its input and the table scan gets no `fetch`.
#[test]
fn limit_pushdown_disallowed_noop_plan() -> Result<()> {
    let table_scan = test_table_scan()?;
    // Grab the schema first so the scan itself can be moved into the node.
    let schema = Arc::clone(table_scan.schema());
    let no_limit_noop_plan = LogicalPlan::Extension(Extension {
        node: Arc::new(NoLimitNoopPlan {
            input: vec![table_scan],
            schema,
        }),
    });

    let plan = LogicalPlanBuilder::from(no_limit_noop_plan)
        .limit(0, Some(1000))?
        .build()?;

    // The indented plan display adds two spaces per nesting level.
    let expected = "Limit: skip=0, fetch=1000\
    \n  NoLimitNoopPlan\
    \n    TableScan: test";

    assert_optimized_plan_equal(plan, expected)
}

#[test]
fn limit_pushdown_projection_table_provider() -> Result<()> {
let table_scan = test_table_scan()?;
Expand Down

0 comments on commit af8babf

Please sign in to comment.