Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(rust, python): include null count in rolling window validity with min_periods #13863

Merged
merged 3 commits into from
Jan 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions crates/polars-core/src/chunked_array/ops/rolling_window.rs
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,12 @@ mod inner_mod {
// we are in bounds
let arr_window = unsafe { arr.slice_typed_unchecked(start, size) };

// ensure we still meet window size criteria after removing null values
if size - arr_window.null_count() < options.min_periods {
Copy link
Contributor Author

@mcrumiller mcrumiller Jan 20, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I realize that the null count is more efficiently calculated by only tracking the head and tail of each new window (i.e. decrementing if a null drops off and incrementing if a new null pops up), but because this is the slow rolling_map this is almost guaranteed not to be a bottleneck.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nah.. Indeed. This might be the worst code in the repo. :(

builder.append_null();
continue;
}

// Safety.
// ptr is not dropped as we are in scope
// We are also the only owner of the contents of the Arc
Expand Down Expand Up @@ -159,6 +165,12 @@ mod inner_mod {
// we are in bounds
let arr_window = unsafe { arr.slice_typed_unchecked(start, size) };

// ensure we still meet window size criteria after removing null values
if size - arr_window.null_count() < options.min_periods {
builder.append_null();
continue;
}

// Safety.
// ptr is not dropped as we are in scope
// We are also the only owner of the contents of the Arc
Expand Down
10 changes: 1 addition & 9 deletions crates/polars/tests/it/core/rolling_window.rs
Original file line number Diff line number Diff line change
Expand Up @@ -190,15 +190,7 @@ fn test_rolling_map() {

assert_eq!(
Vec::from(out),
&[
None,
None,
Some(3.0),
Some(3.0),
Some(2.0),
Some(5.0),
Some(11.0)
]
&[None, None, Some(3.0), None, None, None, None,]
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This test was also previously incorrect.

);
}

Expand Down
9 changes: 8 additions & 1 deletion py-polars/tests/unit/operations/rolling/test_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,13 @@ def test_rolling_map_window_size_9160(input: list[int], output: list[int]) -> No
assert_series_equal(result, expected)


def testing_rolling_map_window_size_with_nulls() -> None:
s = pl.Series([0, 1, None, 3, 4, 5])
result = s.rolling_map(lambda x: sum(x), window_size=3, min_periods=3)
expected = pl.Series([None, None, None, None, None, 12])
assert_series_equal(result, expected)


def test_rolling_map_clear_reuse_series_state_10681() -> None:
df = pl.DataFrame(
{
Expand Down Expand Up @@ -86,7 +93,7 @@ def test_rolling_map_sum_int_cast_to_float() -> None:
function=lambda s: s.sum(), window_size=3, weights=[1.0, 2.0, 3.0]
)

expected = pl.Series("A", [None, None, 32.0, 20.0, 48.0], dtype=pl.Float64)
expected = pl.Series("A", [None, None, 32.0, None, None], dtype=pl.Float64)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This test was previously incorrect, as it did not take into account the null values in the window.

assert_series_equal(result, expected)


Expand Down