From 06f054d3977c53f187b7387b43f99b9ba75b67b7 Mon Sep 17 00:00:00 2001 From: ritchie Date: Sat, 26 Oct 2024 10:42:03 +0200 Subject: [PATCH] perf: Don't split par if cast to categorical --- .../physical_plan/streaming/convert_alp.rs | 18 ++++- crates/polars-mem-engine/src/planner/lp.rs | 16 +++-- crates/polars-plan/src/plans/aexpr/utils.rs | 70 ++++++++++++++++--- .../src/plans/conversion/dsl_to_ir.rs | 2 +- .../polars-plan/src/plans/conversion/join.rs | 2 +- crates/polars-plan/src/plans/mod.rs | 3 +- .../plans/optimizer/predicate_pushdown/mod.rs | 2 +- .../src/plans/optimizer/slice_pushdown_lp.rs | 2 +- crates/polars-stream/src/skeleton.rs | 4 +- 9 files changed, 95 insertions(+), 24 deletions(-) diff --git a/crates/polars-lazy/src/physical_plan/streaming/convert_alp.rs b/crates/polars-lazy/src/physical_plan/streaming/convert_alp.rs index 7100c083bd47..6c84af4510b5 100644 --- a/crates/polars-lazy/src/physical_plan/streaming/convert_alp.rs +++ b/crates/polars-lazy/src/physical_plan/streaming/convert_alp.rs @@ -163,13 +163,19 @@ pub(crate) fn insert_streaming_nodes( execution_id += 1; match lp_arena.get(root) { Filter { input, predicate } - if is_streamable(predicate.node(), expr_arena, Context::Default) => + if is_streamable( + predicate.node(), + expr_arena, + IsStreamableContext::new(Default::default()), + ) => { state.streamable = true; state.operators_sinks.push(PipelineNode::Operator(root)); stack.push(StackFrame::new(*input, state, current_idx)) }, - HStack { input, exprs, .. } if all_streamable(exprs, expr_arena, Context::Default) => { + HStack { input, exprs, .. } + if all_streamable(exprs, expr_arena, Default::default()) => + { state.streamable = true; state.operators_sinks.push(PipelineNode::Operator(root)); stack.push(StackFrame::new(*input, state, current_idx)) @@ -194,7 +200,13 @@ pub(crate) fn insert_streaming_nodes( state.operators_sinks.push(PipelineNode::Sink(root)); stack.push(StackFrame::new(*input, state, current_idx)) }, - Select { input, expr, .. } if all_streamable(expr, expr_arena, Context::Default) => { + Select { input, expr, .. } + if all_streamable( + expr, + expr_arena, + IsStreamableContext::new(Default::default()), + ) => + { state.streamable = true; state.operators_sinks.push(PipelineNode::Operator(root)); stack.push(StackFrame::new(*input, state, current_idx)) diff --git a/crates/polars-mem-engine/src/planner/lp.rs b/crates/polars-mem-engine/src/planner/lp.rs index 3eef72e86bef..3a5e525867fb 100644 --- a/crates/polars-mem-engine/src/planner/lp.rs +++ b/crates/polars-mem-engine/src/planner/lp.rs @@ -239,7 +239,11 @@ fn create_physical_plan_impl( Ok(Box::new(executors::SliceExec { input, offset, len })) }, Filter { input, predicate } => { - let mut streamable = is_streamable(predicate.node(), expr_arena, Context::Default); + let mut streamable = is_streamable( + predicate.node(), + expr_arena, + IsStreamableContext::new(Context::Default).with_allow_cast_categorical(false), + ); let input_schema = lp_arena.get(input).schema(lp_arena).into_owned(); if streamable { // This can cause problems with string caches @@ -382,7 +386,7 @@ fn create_physical_plan_impl( &mut state, )?; - let streamable = options.should_broadcast && all_streamable(&expr, expr_arena, Context::Default) + let streamable = options.should_broadcast && all_streamable(&expr, expr_arena, IsStreamableContext::new(Context::Default).with_allow_cast_categorical(false)) // If all columns are literal we would get a 1 row per thread. && !phys_expr.iter().all(|p| { p.is_literal() @@ -631,8 +635,12 @@ fn create_physical_plan_impl( let input_schema = lp_arena.get(input).schema(lp_arena).into_owned(); let input = create_physical_plan_impl(input, lp_arena, expr_arena, state)?; - let streamable = - options.should_broadcast && all_streamable(&exprs, expr_arena, Context::Default); + let streamable = options.should_broadcast + && all_streamable( + &exprs, + expr_arena, + IsStreamableContext::new(Context::Default).with_allow_cast_categorical(false), + ); let mut state = ExpressionConversionState::new( POOL.current_num_threads() > exprs.len(), diff --git a/crates/polars-plan/src/plans/aexpr/utils.rs b/crates/polars-plan/src/plans/aexpr/utils.rs index aef7cd157334..6520cc476178 100644 --- a/crates/polars-plan/src/plans/aexpr/utils.rs +++ b/crates/polars-plan/src/plans/aexpr/utils.rs @@ -1,3 +1,5 @@ +use bitflags::bitflags; + use super::*; fn has_series_or_range(ae: &AExpr) -> bool { @@ -7,7 +9,46 @@ fn has_series_or_range(ae: &AExpr) -> bool { ) } -pub fn is_streamable(node: Node, expr_arena: &Arena, context: Context) -> bool { +bitflags! { + #[derive(Default, Copy, Clone)] + struct StreamableFlags: u8 { + const ALLOW_CAST_CATEGORICAL = 1; + } +} + +#[derive(Copy, Clone)] +pub struct IsStreamableContext { + flags: StreamableFlags, + context: Context, +} + +impl Default for IsStreamableContext { + fn default() -> Self { + Self { + flags: StreamableFlags::all(), + context: Default::default(), + } + } +} + +impl IsStreamableContext { + pub fn new(ctx: Context) -> Self { + Self { + flags: StreamableFlags::all(), + context: ctx, + } + } + + pub fn with_allow_cast_categorical(mut self, allow_cast_categorical: bool) -> Self { + self.flags.set( + StreamableFlags::ALLOW_CAST_CATEGORICAL, + allow_cast_categorical, + ); + self + } +} + +pub fn is_streamable(node: Node, expr_arena: &Arena, ctx: IsStreamableContext) -> bool { // check whether leaf column is Col or Lit let mut seen_column = false; let mut seen_lit_range = false; @@ -16,13 +57,14 @@ pub fn is_streamable(node: Node, expr_arena: &Arena, context: Context) -> function: FunctionExpr::SetSortedFlag(_), .. } => true, - AExpr::Function { options, .. } | AExpr::AnonymousFunction { options, .. } => match context - { - Context::Default => matches!( - options.collect_groups, - ApplyOptions::ElementWise | ApplyOptions::ApplyList - ), - Context::Aggregation => matches!(options.collect_groups, ApplyOptions::ElementWise), + AExpr::Function { options, .. } | AExpr::AnonymousFunction { options, .. } => { + match ctx.context { + Context::Default => matches!( + options.collect_groups, + ApplyOptions::ElementWise | ApplyOptions::ApplyList + ), + Context::Aggregation => matches!(options.collect_groups, ApplyOptions::ElementWise), + } }, AExpr::Column(_) => { seen_column = true; @@ -41,6 +83,10 @@ pub fn is_streamable(node: Node, expr_arena: &Arena, context: Context) -> && !has_aexpr(*falsy, expr_arena, has_series_or_range) && !has_aexpr(*predicate, expr_arena, has_series_or_range) }, + #[cfg(feature = "dtype-categorical")] + AExpr::Cast { dtype, .. } if matches!(dtype, DataType::Categorical(_, _)) => { + ctx.flags.contains(StreamableFlags::ALLOW_CAST_CATEGORICAL) + }, AExpr::Alias(_, _) | AExpr::Cast { .. } => true, AExpr::Literal(lv) => match lv { LiteralValue::Series(_) | LiteralValue::Range { .. } => { @@ -64,8 +110,12 @@ pub fn is_streamable(node: Node, expr_arena: &Arena, context: Context) -> false } -pub fn all_streamable(exprs: &[ExprIR], expr_arena: &Arena, context: Context) -> bool { +pub fn all_streamable( + exprs: &[ExprIR], + expr_arena: &Arena, + ctx: IsStreamableContext, +) -> bool { exprs .iter() - .all(|e| is_streamable(e.node(), expr_arena, context)) + .all(|e| is_streamable(e.node(), expr_arena, ctx)) } diff --git a/crates/polars-plan/src/plans/conversion/dsl_to_ir.rs b/crates/polars-plan/src/plans/conversion/dsl_to_ir.rs index 7ee9c7f069d7..793ab63194d7 100644 --- a/crates/polars-plan/src/plans/conversion/dsl_to_ir.rs +++ b/crates/polars-plan/src/plans/conversion/dsl_to_ir.rs @@ -369,7 +369,7 @@ pub fn to_alp_impl(lp: DslPlan, ctxt: &mut DslConversionContext) -> PolarsResult let predicate_ae = to_expr_ir(predicate.clone(), ctxt.expr_arena)?; - return if is_streamable(predicate_ae.node(), ctxt.expr_arena, Context::Default) { + return if is_streamable(predicate_ae.node(), ctxt.expr_arena, Default::default()) { // Split expression that are ANDed into multiple Filter nodes as the optimizer can then // push them down independently. Especially if they refer columns from different tables // this will be more performant. diff --git a/crates/polars-plan/src/plans/conversion/join.rs b/crates/polars-plan/src/plans/conversion/join.rs index 4fa47e7695c8..9d63d18b0a46 100644 --- a/crates/polars-plan/src/plans/conversion/join.rs +++ b/crates/polars-plan/src/plans/conversion/join.rs @@ -115,7 +115,7 @@ pub fn resolve_join( // Every expression must be elementwise so that we are // guaranteed the keys for a join are all the same length. let all_elementwise = - |aexprs: &[ExprIR]| all_streamable(aexprs, &*ctxt.expr_arena, Context::Default); + |aexprs: &[ExprIR]| all_streamable(aexprs, &*ctxt.expr_arena, Default::default()); polars_ensure!( all_elementwise(&left_on) && all_elementwise(&right_on), InvalidOperation: "All join key expressions must be elementwise." diff --git a/crates/polars-plan/src/plans/mod.rs b/crates/polars-plan/src/plans/mod.rs index 03eb06387cc6..314ca8bb0cb2 100644 --- a/crates/polars-plan/src/plans/mod.rs +++ b/crates/polars-plan/src/plans/mod.rs @@ -49,11 +49,12 @@ pub use schema::*; use serde::{Deserialize, Serialize}; use strum_macros::IntoStaticStr; -#[derive(Clone, Copy, Debug)] +#[derive(Clone, Copy, Debug, Default)] pub enum Context { /// Any operation that is done on groups Aggregation, /// Any operation that is done while projection/ selection of data + #[default] Default, } diff --git a/crates/polars-plan/src/plans/optimizer/predicate_pushdown/mod.rs b/crates/polars-plan/src/plans/optimizer/predicate_pushdown/mod.rs index ed3f3e0376bd..ff5f2f89ff0d 100644 --- a/crates/polars-plan/src/plans/optimizer/predicate_pushdown/mod.rs +++ b/crates/polars-plan/src/plans/optimizer/predicate_pushdown/mod.rs @@ -672,7 +672,7 @@ impl<'a> PredicatePushDown<'a> { if let Some(predicate) = predicate { // For IO plugins we only accept streamable expressions as // we want to apply the predicates to the batches. - if !is_streamable(predicate.node(), expr_arena, Context::Default) + if !is_streamable(predicate.node(), expr_arena, Default::default()) && matches!(options.python_source, PythonScanSource::IOPlugin) { let lp = PythonScan { options }; diff --git a/crates/polars-plan/src/plans/optimizer/slice_pushdown_lp.rs b/crates/polars-plan/src/plans/optimizer/slice_pushdown_lp.rs index b656795f53d2..9c2f8497fac8 100644 --- a/crates/polars-plan/src/plans/optimizer/slice_pushdown_lp.rs +++ b/crates/polars-plan/src/plans/optimizer/slice_pushdown_lp.rs @@ -31,7 +31,7 @@ fn can_pushdown_slice_past_projections(exprs: &[ExprIR], arena: &Arena) - // `select(c = Literal([1, 2, 3]).is_in(col(a)))`, for functions like `is_in`, // `str.contains`, `str.contains_many` etc. - observe a column node is present // but the output height is not dependent on it. - let is_elementwise = is_streamable(expr_ir.node(), arena, Context::Default); + let is_elementwise = is_streamable(expr_ir.node(), arena, Default::default()); let (has_column, literals_all_scalar) = arena.iter(expr_ir.node()).fold( (false, true), |(has_column, lit_scalar), (_node, ae)| { diff --git a/crates/polars-stream/src/skeleton.rs b/crates/polars-stream/src/skeleton.rs index c7c996824d11..9516be3b902a 100644 --- a/crates/polars-stream/src/skeleton.rs +++ b/crates/polars-stream/src/skeleton.rs @@ -4,14 +4,14 @@ use std::cmp::Reverse; use polars_core::prelude::*; use polars_core::POOL; use polars_expr::planner::{create_physical_expr, get_expr_depth_limit, ExpressionConversionState}; -use polars_plan::plans::{Context, IRPlan, IR}; +use polars_plan::plans::{Context, IRPlan, IsStreamableContext, IR}; use polars_plan::prelude::expr_ir::ExprIR; use polars_plan::prelude::AExpr; use polars_utils::arena::{Arena, Node}; use slotmap::{SecondaryMap, SlotMap}; fn is_streamable(node: Node, arena: &Arena) -> bool { - polars_plan::plans::is_streamable(node, arena, Context::Default) + polars_plan::plans::is_streamable(node, arena, IsStreamableContext::new(Context::Default)) } pub fn run_query(