Skip to content

Commit

Permalink
fix: Fix cse union schema (pola-rs#19305)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 authored Oct 18, 2024
1 parent 1a3d928 commit 997ebb4
Show file tree
Hide file tree
Showing 5 changed files with 41 additions and 9 deletions.
4 changes: 2 additions & 2 deletions crates/polars-lazy/src/tests/cse.rs
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ fn test_cse_joins_4954() -> PolarsResult<()> {
let (mut expr_arena, mut lp_arena) = get_arenas();
let lp = c.optimize(&mut lp_arena, &mut expr_arena).unwrap();

// Ensure we get only one cache and the it is not above the join
// Ensure we get only one cache and it is not above the join
// and ensure that every cache only has 1 hit.
let cache_ids = (&lp_arena)
.iter(lp)
Expand All @@ -218,7 +218,7 @@ fn test_cse_joins_4954() -> PolarsResult<()> {
..
} => {
assert_eq!(*cache_hits, 1);
assert!(matches!(lp_arena.get(*input), IR::DataFrameScan { .. }));
assert!(matches!(lp_arena.get(*input), IR::SimpleProjection { .. }));

Some(*id)
},
Expand Down
18 changes: 12 additions & 6 deletions crates/polars-plan/src/plans/optimizer/cache_states.rs
Original file line number Diff line number Diff line change
Expand Up @@ -348,17 +348,23 @@ pub(super) fn set_cache_states(

let lp = IRBuilder::new(new_child, expr_arena, lp_arena)
.project_simple(projection)
.unwrap()
.expect("unique names")
.build();

let lp = proj_pd.optimize(lp, lp_arena, expr_arena)?;
// Remove the projection added by the optimization.
let lp =
if let IR::Select { input, .. } | IR::SimpleProjection { input, .. } = lp {
lp_arena.take(input)
// Optimization can lead to a double projection. Only take the last.
let lp = if let IR::SimpleProjection { input, columns } = lp {
let input = if let IR::SimpleProjection { input: input2, .. } =
lp_arena.get(input)
{
*input2
} else {
lp
input
};
IR::SimpleProjection { input, columns }
} else {
lp
};
lp_arena.replace(child, lp);
}
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -344,6 +344,7 @@ impl ProjectionPushDown {
projections_seen,
lp_arena,
expr_arena,
false,
),
SimpleProjection { columns, input, .. } => {
let exprs = names_to_expr_irs(columns.iter_names_cloned(), expr_arena);
Expand All @@ -356,6 +357,7 @@ impl ProjectionPushDown {
projections_seen,
lp_arena,
expr_arena,
true,
)
},
DataFrameScan {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@ pub(super) fn process_projection(
projections_seen: usize,
lp_arena: &mut Arena<IR>,
expr_arena: &mut Arena<AExpr>,
// Whether the node being processed is a `SimpleProjection`.
simple: bool,
) -> PolarsResult<IR> {
let mut local_projection = Vec::with_capacity(exprs.len());

Expand Down Expand Up @@ -130,7 +132,14 @@ pub(super) fn process_projection(
)?;

let builder = IRBuilder::new(input, expr_arena, lp_arena);
let lp = proj_pd.finish_node(local_projection, builder);

let lp = if !local_projection.is_empty() && simple {
builder
.project_simple_nodes(local_projection.into_iter().map(|e| e.node()))?
.build()
} else {
proj_pd.finish_node(local_projection, builder)
};

Ok(lp)
}
15 changes: 15 additions & 0 deletions py-polars/tests/unit/test_cse.py
Original file line number Diff line number Diff line change
Expand Up @@ -804,3 +804,18 @@ def test_cse_skip_as_struct_19253() -> None:
"q1": [{"x": -3.5}, {"x": -2.5}],
"q2": [{"x": -3.0}, {"x": -3.0}],
}


def test_cse_union_19227() -> None:
    """Check the collected schema of a concat of two related join plans.

    Both branches are built from the same base LazyFrame and end with the
    same ``select("C", "A", "B")``; the second branch additionally joins
    against the first. The concatenated result must expose columns
    ``C``, ``A``, ``B`` (all Int64), in that order.
    """
    base = pl.LazyFrame({"A": [1], "B": [2]})

    # Two projections of the base frame; both expose column "A" as "C".
    keeps_b = base.select(C="A", B="B")
    b_as_a = base.select(C="A", A="B")

    # First branch: join the renamed projection back onto the base frame.
    direct = b_as_a.join(base, on=["A"]).select("C", "A", "B")

    # Second branch: reuses the first branch as the right side of a join.
    indirect = keeps_b.join(direct, on=["C", "B"]).select("C", "A", "B")

    result = pl.concat([direct, indirect]).collect()
    expected = pl.Schema([("C", pl.Int64), ("A", pl.Int64), ("B", pl.Int64)])
    assert result.schema == expected

0 comments on commit 997ebb4

Please sign in to comment.