From 6a244bfc79d948278a05b362f75b8684444293f2 Mon Sep 17 00:00:00 2001 From: Marshall Crumiller Date: Fri, 19 Jan 2024 14:46:20 -0500 Subject: [PATCH 1/3] Include nulls in window validity --- .../src/chunked_array/ops/rolling_window.rs | 13 +++++++++++++ py-polars/tests/unit/operations/rolling/test_map.py | 9 ++++++++- 2 files changed, 21 insertions(+), 1 deletion(-) diff --git a/crates/polars-core/src/chunked_array/ops/rolling_window.rs b/crates/polars-core/src/chunked_array/ops/rolling_window.rs index d1134d2d96ad..8a2de10b6442 100644 --- a/crates/polars-core/src/chunked_array/ops/rolling_window.rs +++ b/crates/polars-core/src/chunked_array/ops/rolling_window.rs @@ -111,6 +111,13 @@ mod inner_mod { // we are in bounds let arr_window = unsafe { arr.slice_typed_unchecked(start, size) }; + + // ensure we still meet window size criteria after removing null values + if size - arr_window.null_count() < options.min_periods { + builder.append_null(); + continue; + } + // Safety. // ptr is not dropped as we are in scope // We are also the only owner of the contents of the Arc @@ -159,6 +166,12 @@ mod inner_mod { // we are in bounds let arr_window = unsafe { arr.slice_typed_unchecked(start, size) }; + // ensure we still meet window size criteria after removing null values + if size - arr_window.null_count() < options.min_periods { + builder.append_null(); + continue; + } + // Safety. // ptr is not dropped as we are in scope // We are also the only owner of the contents of the Arc diff --git a/py-polars/tests/unit/operations/rolling/test_map.py b/py-polars/tests/unit/operations/rolling/test_map.py index 3ffc9d736d83..f1b8eb2e7dea 100644 --- a/py-polars/tests/unit/operations/rolling/test_map.py +++ b/py-polars/tests/unit/operations/rolling/test_map.py @@ -21,6 +21,13 @@ def test_rolling_map_window_size_9160(input: list[int], output: list[int]) -> No assert_series_equal(result, expected) +def testing_rolling_map_window_size_with_nulls() -> None: + s = pl.Series([0, 1, None, 3, 4, 5]) + result = s.rolling_map(lambda x: sum(x), window_size=3, min_periods=3) + expected = pl.Series([None, None, None, None, None, 12]) + assert_series_equal(result, expected) + + def test_rolling_map_clear_reuse_series_state_10681() -> None: df = pl.DataFrame( { @@ -86,7 +93,7 @@ def test_rolling_map_sum_int_cast_to_float() -> None: function=lambda s: s.sum(), window_size=3, weights=[1.0, 2.0, 3.0] ) - expected = pl.Series("A", [None, None, 32.0, 20.0, 48.0], dtype=pl.Float64) + expected = pl.Series("A", [None, None, 32.0, None, None], dtype=pl.Float64) assert_series_equal(result, expected) From 0acb11150cffd3b3506709d62d6af2971097cda5 Mon Sep 17 00:00:00 2001 From: Marshall Crumiller Date: Fri, 19 Jan 2024 14:51:00 -0500 Subject: [PATCH 2/3] lint --- crates/polars-core/src/chunked_array/ops/rolling_window.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/crates/polars-core/src/chunked_array/ops/rolling_window.rs b/crates/polars-core/src/chunked_array/ops/rolling_window.rs index 8a2de10b6442..dff2e76e2616 100644 --- a/crates/polars-core/src/chunked_array/ops/rolling_window.rs +++ b/crates/polars-core/src/chunked_array/ops/rolling_window.rs @@ -111,7 +111,6 @@ mod inner_mod { // we are in bounds let arr_window = unsafe { arr.slice_typed_unchecked(start, size) }; - // ensure we still meet window size criteria after removing null values if size - arr_window.null_count() < options.min_periods { builder.append_null(); From 31fcb232541a71ab5d5cb170af43d085eaf4751e Mon Sep 17 00:00:00 2001 From: Marshall Crumiller Date: Fri, 19 Jan 2024 15:05:11 -0500 Subject: [PATCH 3/3] Fix invalid test --- crates/polars/tests/it/core/rolling_window.rs | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/crates/polars/tests/it/core/rolling_window.rs b/crates/polars/tests/it/core/rolling_window.rs index a79691e79336..17270932bca2 100644 --- a/crates/polars/tests/it/core/rolling_window.rs +++ b/crates/polars/tests/it/core/rolling_window.rs @@ -190,15 +190,7 @@ fn test_rolling_map() { assert_eq!( Vec::from(out), - &[ - None, - None, - Some(3.0), - Some(3.0), - Some(2.0), - Some(5.0), - Some(11.0) - ] + &[None, None, Some(3.0), None, None, None, None,] ); }