Skip to content

Commit

Permalink
perf: Don't split par if cast to categorical
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Oct 26, 2024
1 parent ce001f1 commit 06f054d
Show file tree
Hide file tree
Showing 9 changed files with 95 additions and 24 deletions.
18 changes: 15 additions & 3 deletions crates/polars-lazy/src/physical_plan/streaming/convert_alp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -163,13 +163,19 @@ pub(crate) fn insert_streaming_nodes(
execution_id += 1;
match lp_arena.get(root) {
Filter { input, predicate }
if is_streamable(predicate.node(), expr_arena, Context::Default) =>
if is_streamable(
predicate.node(),
expr_arena,
IsStreamableContext::new(Default::default()),
) =>
{
state.streamable = true;
state.operators_sinks.push(PipelineNode::Operator(root));
stack.push(StackFrame::new(*input, state, current_idx))
},
HStack { input, exprs, .. } if all_streamable(exprs, expr_arena, Context::Default) => {
HStack { input, exprs, .. }
if all_streamable(exprs, expr_arena, Default::default()) =>
{
state.streamable = true;
state.operators_sinks.push(PipelineNode::Operator(root));
stack.push(StackFrame::new(*input, state, current_idx))
Expand All @@ -194,7 +200,13 @@ pub(crate) fn insert_streaming_nodes(
state.operators_sinks.push(PipelineNode::Sink(root));
stack.push(StackFrame::new(*input, state, current_idx))
},
Select { input, expr, .. } if all_streamable(expr, expr_arena, Context::Default) => {
Select { input, expr, .. }
if all_streamable(
expr,
expr_arena,
IsStreamableContext::new(Default::default()),
) =>
{
state.streamable = true;
state.operators_sinks.push(PipelineNode::Operator(root));
stack.push(StackFrame::new(*input, state, current_idx))
Expand Down
16 changes: 12 additions & 4 deletions crates/polars-mem-engine/src/planner/lp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,11 @@ fn create_physical_plan_impl(
Ok(Box::new(executors::SliceExec { input, offset, len }))
},
Filter { input, predicate } => {
let mut streamable = is_streamable(predicate.node(), expr_arena, Context::Default);
let mut streamable = is_streamable(
predicate.node(),
expr_arena,
IsStreamableContext::new(Context::Default).with_allow_cast_categorical(false),
);
let input_schema = lp_arena.get(input).schema(lp_arena).into_owned();
if streamable {
// This can cause problems with string caches
Expand Down Expand Up @@ -382,7 +386,7 @@ fn create_physical_plan_impl(
&mut state,
)?;

let streamable = options.should_broadcast && all_streamable(&expr, expr_arena, Context::Default)
let streamable = options.should_broadcast && all_streamable(&expr, expr_arena, IsStreamableContext::new(Context::Default).with_allow_cast_categorical(false))
// If all columns are literal we would get a 1 row per thread.
&& !phys_expr.iter().all(|p| {
p.is_literal()
Expand Down Expand Up @@ -631,8 +635,12 @@ fn create_physical_plan_impl(
let input_schema = lp_arena.get(input).schema(lp_arena).into_owned();
let input = create_physical_plan_impl(input, lp_arena, expr_arena, state)?;

let streamable =
options.should_broadcast && all_streamable(&exprs, expr_arena, Context::Default);
let streamable = options.should_broadcast
&& all_streamable(
&exprs,
expr_arena,
IsStreamableContext::new(Context::Default).with_allow_cast_categorical(false),
);

let mut state = ExpressionConversionState::new(
POOL.current_num_threads() > exprs.len(),
Expand Down
70 changes: 60 additions & 10 deletions crates/polars-plan/src/plans/aexpr/utils.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use bitflags::bitflags;

use super::*;

fn has_series_or_range(ae: &AExpr) -> bool {
Expand All @@ -7,7 +9,46 @@ fn has_series_or_range(ae: &AExpr) -> bool {
)
}

pub fn is_streamable(node: Node, expr_arena: &Arena<AExpr>, context: Context) -> bool {
bitflags! {
#[derive(Default, Copy, Clone)]
struct StreamableFlags: u8 {
const ALLOW_CAST_CATEGORICAL = 1;
}
}

#[derive(Copy, Clone)]
pub struct IsStreamableContext {
flags: StreamableFlags,
context: Context,
}

impl Default for IsStreamableContext {
fn default() -> Self {
Self {
flags: StreamableFlags::all(),
context: Default::default(),
}
}
}

impl IsStreamableContext {
pub fn new(ctx: Context) -> Self {
Self {
flags: StreamableFlags::all(),
context: ctx,
}
}

pub fn with_allow_cast_categorical(mut self, allow_cast_categorical: bool) -> Self {
self.flags.set(
StreamableFlags::ALLOW_CAST_CATEGORICAL,
allow_cast_categorical,
);
self
}
}

pub fn is_streamable(node: Node, expr_arena: &Arena<AExpr>, ctx: IsStreamableContext) -> bool {
// check whether leaf column is Col or Lit
let mut seen_column = false;
let mut seen_lit_range = false;
Expand All @@ -16,13 +57,14 @@ pub fn is_streamable(node: Node, expr_arena: &Arena<AExpr>, context: Context) ->
function: FunctionExpr::SetSortedFlag(_),
..
} => true,
AExpr::Function { options, .. } | AExpr::AnonymousFunction { options, .. } => match context
{
Context::Default => matches!(
options.collect_groups,
ApplyOptions::ElementWise | ApplyOptions::ApplyList
),
Context::Aggregation => matches!(options.collect_groups, ApplyOptions::ElementWise),
AExpr::Function { options, .. } | AExpr::AnonymousFunction { options, .. } => {
match ctx.context {
Context::Default => matches!(
options.collect_groups,
ApplyOptions::ElementWise | ApplyOptions::ApplyList
),
Context::Aggregation => matches!(options.collect_groups, ApplyOptions::ElementWise),
}
},
AExpr::Column(_) => {
seen_column = true;
Expand All @@ -41,6 +83,10 @@ pub fn is_streamable(node: Node, expr_arena: &Arena<AExpr>, context: Context) ->
&& !has_aexpr(*falsy, expr_arena, has_series_or_range)
&& !has_aexpr(*predicate, expr_arena, has_series_or_range)
},
#[cfg(feature = "dtype-categorical")]
AExpr::Cast { dtype, .. } if matches!(dtype, DataType::Categorical(_, _)) => {
ctx.flags.contains(StreamableFlags::ALLOW_CAST_CATEGORICAL)
},
AExpr::Alias(_, _) | AExpr::Cast { .. } => true,
AExpr::Literal(lv) => match lv {
LiteralValue::Series(_) | LiteralValue::Range { .. } => {
Expand All @@ -64,8 +110,12 @@ pub fn is_streamable(node: Node, expr_arena: &Arena<AExpr>, context: Context) ->
false
}

pub fn all_streamable(exprs: &[ExprIR], expr_arena: &Arena<AExpr>, context: Context) -> bool {
pub fn all_streamable(
exprs: &[ExprIR],
expr_arena: &Arena<AExpr>,
ctx: IsStreamableContext,
) -> bool {
exprs
.iter()
.all(|e| is_streamable(e.node(), expr_arena, context))
.all(|e| is_streamable(e.node(), expr_arena, ctx))
}
2 changes: 1 addition & 1 deletion crates/polars-plan/src/plans/conversion/dsl_to_ir.rs
Original file line number Diff line number Diff line change
Expand Up @@ -369,7 +369,7 @@ pub fn to_alp_impl(lp: DslPlan, ctxt: &mut DslConversionContext) -> PolarsResult

let predicate_ae = to_expr_ir(predicate.clone(), ctxt.expr_arena)?;

return if is_streamable(predicate_ae.node(), ctxt.expr_arena, Context::Default) {
return if is_streamable(predicate_ae.node(), ctxt.expr_arena, Default::default()) {
// Split expression that are ANDed into multiple Filter nodes as the optimizer can then
// push them down independently. Especially if they refer columns from different tables
// this will be more performant.
Expand Down
2 changes: 1 addition & 1 deletion crates/polars-plan/src/plans/conversion/join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ pub fn resolve_join(
// Every expression must be elementwise so that we are
// guaranteed the keys for a join are all the same length.
let all_elementwise =
|aexprs: &[ExprIR]| all_streamable(aexprs, &*ctxt.expr_arena, Context::Default);
|aexprs: &[ExprIR]| all_streamable(aexprs, &*ctxt.expr_arena, Default::default());
polars_ensure!(
all_elementwise(&left_on) && all_elementwise(&right_on),
InvalidOperation: "All join key expressions must be elementwise."
Expand Down
3 changes: 2 additions & 1 deletion crates/polars-plan/src/plans/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,11 +49,12 @@ pub use schema::*;
use serde::{Deserialize, Serialize};
use strum_macros::IntoStaticStr;

#[derive(Clone, Copy, Debug)]
#[derive(Clone, Copy, Debug, Default)]
pub enum Context {
/// Any operation that is done on groups
Aggregation,
/// Any operation that is done while projection/ selection of data
#[default]
Default,
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -672,7 +672,7 @@ impl<'a> PredicatePushDown<'a> {
if let Some(predicate) = predicate {
// For IO plugins we only accept streamable expressions as
// we want to apply the predicates to the batches.
if !is_streamable(predicate.node(), expr_arena, Context::Default)
if !is_streamable(predicate.node(), expr_arena, Default::default())
&& matches!(options.python_source, PythonScanSource::IOPlugin)
{
let lp = PythonScan { options };
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ fn can_pushdown_slice_past_projections(exprs: &[ExprIR], arena: &Arena<AExpr>) -
// `select(c = Literal([1, 2, 3]).is_in(col(a)))`, for functions like `is_in`,
// `str.contains`, `str.contains_many` etc. - observe a column node is present
// but the output height is not dependent on it.
let is_elementwise = is_streamable(expr_ir.node(), arena, Context::Default);
let is_elementwise = is_streamable(expr_ir.node(), arena, Default::default());
let (has_column, literals_all_scalar) = arena.iter(expr_ir.node()).fold(
(false, true),
|(has_column, lit_scalar), (_node, ae)| {
Expand Down
4 changes: 2 additions & 2 deletions crates/polars-stream/src/skeleton.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,14 @@ use std::cmp::Reverse;
use polars_core::prelude::*;
use polars_core::POOL;
use polars_expr::planner::{create_physical_expr, get_expr_depth_limit, ExpressionConversionState};
use polars_plan::plans::{Context, IRPlan, IR};
use polars_plan::plans::{Context, IRPlan, IsStreamableContext, IR};
use polars_plan::prelude::expr_ir::ExprIR;
use polars_plan::prelude::AExpr;
use polars_utils::arena::{Arena, Node};
use slotmap::{SecondaryMap, SlotMap};

fn is_streamable(node: Node, arena: &Arena<AExpr>) -> bool {
polars_plan::plans::is_streamable(node, arena, Context::Default)
polars_plan::plans::is_streamable(node, arena, IsStreamableContext::new(Context::Default))
}

pub fn run_query(
Expand Down

0 comments on commit 06f054d

Please sign in to comment.