Skip to content

Commit

Permalink
fix: Only flush if operator can flush in streaming outer join (#16723)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 authored Jun 4, 2024
1 parent 79afe75 commit ddf8126
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ pub struct GenericFullOuterJoinProbe<K: ExtraPayload> {
/// the dataframe is not rechunked.
df_a: Arc<DataFrame>,
// Dummy needed for the flush phase.
df_b_dummy: Option<DataFrame>,
df_b_flush_dummy: Option<DataFrame>,
/// The join columns are all tightly packed
/// the values of a join column(s) can be found
/// by:
Expand Down Expand Up @@ -77,7 +77,7 @@ impl<K: ExtraPayload> GenericFullOuterJoinProbe<K> {
) -> Self {
GenericFullOuterJoinProbe {
df_a: Arc::new(df_a),
df_b_dummy: None,
df_b_flush_dummy: None,
materialized_join_cols,
suffix,
hb,
Expand Down Expand Up @@ -207,8 +207,8 @@ impl<K: ExtraPayload> GenericFullOuterJoinProbe<K> {
self.join_tuples_a.clear();
self.join_tuples_b.clear();

if self.df_b_dummy.is_none() {
self.df_b_dummy = Some(chunk.data.clear())
if self.df_b_flush_dummy.is_none() {
self.df_b_flush_dummy = Some(chunk.data.clear())
}

let mut hashes = std::mem::take(&mut self.hashes);
Expand Down Expand Up @@ -270,7 +270,7 @@ impl<K: ExtraPayload> GenericFullOuterJoinProbe<K> {
};

let size = left_df.height();
let right_df = self.df_b_dummy.as_ref().unwrap();
let right_df = self.df_b_flush_dummy.as_ref().unwrap();

let right_df = unsafe {
DataFrame::new_no_checks(
Expand Down Expand Up @@ -301,7 +301,7 @@ impl<K: ExtraPayload> Operator for GenericFullOuterJoinProbe<K> {
}

fn must_flush(&self) -> bool {
true
self.df_b_flush_dummy.is_some()
}

fn split(&self, thread_no: usize) -> Box<dyn Operator> {
Expand Down
35 changes: 35 additions & 0 deletions py-polars/tests/unit/streaming/test_streaming_join.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

from datetime import datetime
from typing import TYPE_CHECKING, Literal

import numpy as np
Expand All @@ -10,6 +11,8 @@
from polars.testing import assert_frame_equal

if TYPE_CHECKING:
from pathlib import Path

from polars.type_aliases import JoinStrategy

pytestmark = pytest.mark.xdist_group("streaming")
Expand Down Expand Up @@ -270,3 +273,35 @@ def test_non_coalescing_streaming_left_join() -> None:
"a_right": [1, 2, None],
"c": ["j", "i", None],
}


@pytest.mark.write_disk()
def test_streaming_outer_join_partial_flush(tmp_path: Path) -> None:
data = {
"value_at": [datetime(2024, i + 1, 1) for i in range(6)],
"value": list(range(6)),
}

parquet_path = tmp_path / "data.parquet"
pl.DataFrame(data=data).write_parquet(parquet_path)

other_parquet_path = tmp_path / "data2.parquet"
pl.DataFrame(data=data).write_parquet(other_parquet_path)

lf1 = pl.scan_parquet(other_parquet_path)
lf2 = pl.scan_parquet(parquet_path)

join_cols = set(lf1.columns).intersection(set(lf2.columns))
final_lf = lf1.join(lf2, on=list(join_cols), how="full", coalesce=True)

assert final_lf.collect(streaming=True).to_dict(as_series=False) == {
"value_at": [
datetime(2024, 1, 1, 0, 0),
datetime(2024, 2, 1, 0, 0),
datetime(2024, 3, 1, 0, 0),
datetime(2024, 4, 1, 0, 0),
datetime(2024, 5, 1, 0, 0),
datetime(2024, 6, 1, 0, 0),
],
"value": [0, 1, 2, 3, 4, 5],
}

0 comments on commit ddf8126

Please sign in to comment.