diff --git a/DESCRIPTION b/DESCRIPTION index feea8e169..aa542f794 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -119,4 +119,4 @@ Collate: Config/rextendr/version: 0.3.1 VignetteBuilder: knitr Config/polars/LibVersion: 0.39.4 -Config/polars/RustToolchainVersion: nightly-2024-04-15 +Config/polars/RustToolchainVersion: nightly-2024-05-14 diff --git a/R/dataframe__frame.R b/R/dataframe__frame.R index 724b690e9..a2268608b 100644 --- a/R/dataframe__frame.R +++ b/R/dataframe__frame.R @@ -2131,9 +2131,8 @@ DataFrame_rolling = function( closed = "right", group_by = NULL, check_sorted = TRUE) { - if (is.null(offset)) { - offset = paste0("-", period) # TODO: `paste0` should be executed after `period` is parsed as string - } + period = parse_as_polars_duration_string(period) + offset = parse_as_polars_duration_string(offset) %||% negate_duration_string(period) construct_rolling_group_by(self, index_column, period, offset, closed, group_by, check_sorted) } @@ -2217,7 +2216,7 @@ DataFrame_group_by_dynamic = function( start_by = "window", check_sorted = TRUE) { every = parse_as_polars_duration_string(every) - offset = parse_as_polars_duration_string(offset) %||% paste0("-", every) + offset = parse_as_polars_duration_string(offset) %||% negate_duration_string(every) period = parse_as_polars_duration_string(period) %||% every construct_group_by_dynamic( self, index_column, every, period, offset, include_boundaries, closed, label, diff --git a/R/expr__expr.R b/R/expr__expr.R index bc0130430..ae345edb4 100644 --- a/R/expr__expr.R +++ b/R/expr__expr.R @@ -3316,11 +3316,12 @@ Expr_peak_max = function() { Expr_rolling = function( index_column, ..., - period, offset = NULL, - closed = "right", check_sorted = TRUE) { - if (is.null(offset)) { - offset = paste0("-", period) # TODO: `paste0` should be executed after `period` is parsed as string - } + period, + offset = NULL, + closed = "right", + check_sorted = TRUE) { + period = parse_as_polars_duration_string(period) + offset = parse_as_polars_duration_string(offset) %||% negate_duration_string(period) .pr$Expr$rolling(self, index_column, period, offset, closed, check_sorted) |> unwrap("in $rolling():") } diff --git a/R/lazyframe__lazy.R b/R/lazyframe__lazy.R index b2921823f..b637e72d1 100644 --- a/R/lazyframe__lazy.R +++ b/R/lazyframe__lazy.R @@ -1914,9 +1914,8 @@ LazyFrame_rolling = function( closed = "right", group_by = NULL, check_sorted = TRUE) { - if (is.null(offset)) { - offset = paste0("-", period) # TODO: `paste0` should be executed after `period` is parsed as string - } + period = parse_as_polars_duration_string(period) + offset = parse_as_polars_duration_string(offset) %||% negate_duration_string(period) .pr$LazyFrame$rolling( self, index_column, period, offset, closed, wrap_elist_result(group_by, str_to_lit = FALSE), check_sorted @@ -2026,7 +2025,7 @@ LazyFrame_group_by_dynamic = function( start_by = "window", check_sorted = TRUE) { every = parse_as_polars_duration_string(every) - offset = parse_as_polars_duration_string(offset) %||% paste0("-", every) + offset = parse_as_polars_duration_string(offset) %||% negate_duration_string(every) period = parse_as_polars_duration_string(period) %||% every .pr$LazyFrame$group_by_dynamic( diff --git a/R/parse_as_duration.R b/R/parse_as_duration.R index f06aec60d..a5705afd9 100644 --- a/R/parse_as_duration.R +++ b/R/parse_as_duration.R @@ -56,10 +56,9 @@ parse_as_polars_duration_string.default = function(x, default = NULL, ...) { unwrap() } - #' @exportS3Method parse_as_polars_duration_string.character = function(x, default = NULL, ...) { - if (length(x) != 1L || is.na(x)) { + if (length(x) != 1L) { Err_plain(paste0("`", deparse(substitute(x)), "` must be a single non-NA character or difftime.")) |> unwrap() } @@ -67,10 +66,9 @@ parse_as_polars_duration_string.character = function(x, default = NULL, ...) { x } - #' @exportS3Method parse_as_polars_duration_string.difftime = function(x, default = NULL, ...) { - if (length(x) != 1L || is.na(x)) { + if (length(x) != 1L) { Err_plain(paste0("`", deparse(substitute(x)), "` must be a single non-NA character or difftime.")) |> unwrap() } @@ -95,3 +93,11 @@ difftime_to_duration_string = function(dft) { ) paste0(value, unit) } + +negate_duration_string = function(x) { + if (startsWith(x, "-")) { + gsub("^-", "", x) + } else { + paste0("-", x) + } +} diff --git a/src/rust/Cargo.lock b/src/rust/Cargo.lock index 7b19c134e..9f69351f2 100644 --- a/src/rust/Cargo.lock +++ b/src/rust/Cargo.lock @@ -1168,9 +1168,9 @@ checksum = "4ec2a862134d2a7d32d7983ddcdd1c4923530833c9f2ea1a44fc5fa473989058" [[package]] name = "libmimalloc-sys" -version = "0.1.37" +version = "0.1.38" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "81eb4061c0582dedea1cbc7aff2240300dd6982e0239d1c99e65c1dbf4a30ba7" +checksum = "0e7bb23d733dfcc8af652a78b7bf232f0e967710d044732185e561e47c0336b6" dependencies = [ "cc", "libc", @@ -1289,9 +1289,9 @@ dependencies = [ [[package]] name = "mimalloc" -version = "0.1.41" +version = "0.1.42" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9f41a2280ded0da56c8cf898babb86e8f10651a34adcfff190ae9a1159c6908d" +checksum = "e9186d86b79b52f4a77af65604b51225e8db1d6ee7e3f41aec1e40829c71a176" dependencies = [ "libmimalloc-sys", ] @@ -2824,18 +2824,18 @@ dependencies = [ [[package]] name = "thiserror" -version = "1.0.60" +version = "1.0.61" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "579e9083ca58dd9dcf91a9923bb9054071b9ebbd800b342194c9feb0ee89fc18" +checksum = "c546c80d6be4bc6a00c0f01730c08df82eaa7a7a61f11d656526506112cc1709" dependencies = [ "thiserror-impl", ] [[package]] name = "thiserror-impl" -version = "1.0.60" +version = "1.0.61" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e2470041c06ec3ac1ab38d0356a6119054dedaea53e12fbefc0de730a1c08524" +checksum = "46c3384250002a6d5af4d114f2845d37b57521033f30d5c3f46c4d70e1197533" dependencies = [ "proc-macro2", "quote", diff --git a/src/rust/Cargo.toml b/src/rust/Cargo.toml index 4df9cce70..6d3adcbf8 100644 --- a/src/rust/Cargo.toml +++ b/src/rust/Cargo.toml @@ -21,7 +21,7 @@ rpolars_debug_print = [] [workspace] # prevents package from thinking it's in the workspace [target.'cfg(any(not(target_os = "linux"), use_mimalloc))'.dependencies] -mimalloc = { version = "0.1.41", default-features = false } +mimalloc = { version = "0.1.42", default-features = false } [target.'cfg(all(target_os = "linux", not(use_mimalloc)))'.dependencies] jemallocator = { version = "0.5.0", features = ["disable_initial_exec_tls"] } @@ -51,7 +51,7 @@ serde = { version = "1.0.202", features = ["derive"] } serde_json = "*" smartstring = "1.0.1" state = "0.6.0" -thiserror = "1.0.60" +thiserror = "1.0.61" polars-core = { git = "https://github.com/pola-rs/polars.git", rev = "7bc70141f4dad7863a2026849522551abb274f00", default-features = false } polars-lazy = { git = "https://github.com/pola-rs/polars.git", rev = "7bc70141f4dad7863a2026849522551abb274f00", default-features = false } either = "1" @@ -139,6 +139,7 @@ features = [ "replace", "rle", "rolling_window", + "rolling_window_by", "round_series", "row_hash", "rows", diff --git a/src/rust/src/lazy/dataframe.rs b/src/rust/src/lazy/dataframe.rs index 8d14f2478..53879986c 100644 --- a/src/rust/src/lazy/dataframe.rs +++ b/src/rust/src/lazy/dataframe.rs @@ -87,7 +87,7 @@ impl RPolarsLazyFrame { fn deserialize(json: Robj) -> RResult { let json = robj_to!(str, json)?; - let lp = serde_json::from_str::(json) + let lp = serde_json::from_str::(json) .map_err(|err| RPolarsErr::new().plain(format!("{err:?}")))?; Ok(RPolarsLazyFrame(pl::LazyFrame::from(lp))) } diff --git a/src/rust/src/lazy/dsl.rs b/src/rust/src/lazy/dsl.rs index 744bb928f..a9ed08159 100644 --- a/src/rust/src/lazy/dsl.rs +++ b/src/rust/src/lazy/dsl.rs @@ -684,18 +684,12 @@ impl RPolarsExpr { weights: Robj, min_periods: Robj, center: Robj, - by: Robj, - closed: Robj, - warn_if_unsorted: Robj, ) -> RResult { - let options = pl::RollingOptions { - window_size: pl::Duration::parse(robj_to!(str, window_size)?), + let options = pl::RollingOptionsFixedWindow { + window_size: robj_to!(usize, window_size)?, weights: robj_to!(Option, Vec, f64, weights)?, min_periods: robj_to!(usize, min_periods)?, center: robj_to!(bool, center)?, - by: robj_to!(Option, String, by)?, - closed_window: robj_to!(Option, ClosedWindow, closed)?, - warn_if_unsorted: robj_to!(bool, warn_if_unsorted)?, fn_params: None, }; let quantile = robj_to!(f64, quantile)?; @@ -2624,15 +2618,12 @@ pub fn make_rolling_options( by_null: Robj, closed_null: Robj, warn_if_unsorted: Robj, -) -> RResult { - Ok(pl::RollingOptions { - window_size: pl::Duration::parse(robj_to!(str, window_size)?), +) -> RResult { + Ok(pl::RollingOptionsFixedWindow { + window_size: robj_to!(usize, window_size)?, weights: robj_to!(Option, Vec, f64, weights)?, min_periods: robj_to!(usize, min_periods)?, center: robj_to!(bool, center)?, - by: robj_to!(Option, String, by_null)?, - closed_window: robj_to!(Option, ClosedWindow, closed_null)?, - warn_if_unsorted: robj_to!(bool, warn_if_unsorted)?, ..Default::default() }) } diff --git a/src/rust/src/lazy/utils.rs b/src/rust/src/lazy/utils.rs deleted file mode 100644 index 3eb2f94dd..000000000 --- a/src/rust/src/lazy/utils.rs +++ /dev/null @@ -1,8 +0,0 @@ -use polars::lazy::dsl::Expr as PLExpr; -use crate::lazy::dsl::Expr as ArghExpr; - -pub fn r_exprs_to_exprs(r_exprs: Vec) -> Vec { - // Safety: - // transparent struct - unsafe { std::mem::transmute(r_exprs) } -} diff --git a/src/rust/src/rdataframe/mod.rs b/src/rust/src/rdataframe/mod.rs index f8fa4def0..33a024e3b 100644 --- a/src/rust/src/rdataframe/mod.rs +++ b/src/rust/src/rdataframe/mod.rs @@ -66,7 +66,7 @@ impl Iterator for OwnedDataFrameIterator { .collect(); self.idx += 1; - let chunk = polars::frame::ArrowChunk::new(batch_cols); + let chunk = arrow::record_batch::RecordBatch::new(batch_cols); let array = arrow::array::StructArray::new( self.data_type.clone(), chunk.into_arrays(), diff --git a/tests/testthat/test-dataframe.R b/tests/testthat/test-dataframe.R index f90c0d9de..fbe2b0914 100644 --- a/tests/testthat/test-dataframe.R +++ b/tests/testthat/test-dataframe.R @@ -1323,6 +1323,38 @@ test_that("rolling for DataFrame: basic example", { ) }) +test_that("rolling for DataFrame: using difftime as period", { + df = pl$DataFrame( + dt = c( + "2020-01-01", "2020-01-01", "2020-01-01", + "2020-01-02", "2020-01-03", "2020-01-08" + ), + a = c(3, 7, 5, 9, 2, 1) + )$with_columns( + pl$col("dt")$str$strptime(pl$Date, format = NULL)$set_sorted() + ) + + expect_equal( + df$rolling(index_column = "dt", period = "2d")$agg( + pl$sum("a")$alias("sum_a") + )$to_data_frame(), + df$rolling(index_column = "dt", period = as.difftime(2, units = "days"))$agg( + pl$sum("a")$alias("sum_a") + )$to_data_frame() + ) +}) + +test_that("rolling for LazyFrame: error if period is negative", { + df = pl$LazyFrame( + index = c(1L, 2L, 3L, 4L, 8L, 9L), + a = c(3, 7, 5, 9, 2, 1) + ) + expect_grepl_error( + df$rolling(index_column = "index", period = "-2i")$agg(pl$col("a"))$collect(), + "rolling window period should be strictly positive" + ) +}) + test_that("rolling for DataFrame: can be ungrouped", { df = pl$DataFrame( index = c(1:5, 6.0), diff --git a/tests/testthat/test-expr_expr.R b/tests/testthat/test-expr_expr.R index ff0638276..688e167d6 100644 --- a/tests/testthat/test-expr_expr.R +++ b/tests/testthat/test-expr_expr.R @@ -2577,6 +2577,42 @@ test_that("rolling, arg offset", { ) }) +test_that("rolling: error if period is negative", { + dates = c( + "2020-01-01 13:45:48", "2020-01-01 16:42:13", "2020-01-01 16:45:09", + "2020-01-02 18:12:48", "2020-01-03 19:45:32", "2020-01-08 23:16:43" + ) + + df = pl$DataFrame(dt = dates, a = c(3, 7, 5, 9, 2, 1))$ + with_columns( + pl$col("dt")$str$strptime(pl$Datetime("us"), format = "%Y-%m-%d %H:%M:%S")$set_sorted() + ) + expect_grepl_error( + df$select(pl$col("a")$rolling(index_column = "dt", period = "-2d")), + "rolling window period should be strictly positive" + ) +}) + +test_that("rolling: passing a difftime as period works", { + dates = c( + "2020-01-01 13:45:48", "2020-01-01 16:42:13", "2020-01-01 16:45:09", + "2020-01-02 18:12:48", "2020-01-03 19:45:32", "2020-01-08 23:16:43" + ) + + df = pl$DataFrame(dt = dates, a = c(3, 7, 5, 9, 2, 1))$ + with_columns( + pl$col("dt")$str$strptime(pl$Datetime("us"), format = "%Y-%m-%d %H:%M:%S")$set_sorted() + ) + expect_identical( + df$select( + sum_a_offset1 = pl$sum("a")$rolling(index_column = "dt", period = "2d", offset = "1d") + )$to_data_frame(), + df$select( + sum_a_offset1 = pl$sum("a")$rolling(index_column = "dt", period = as.difftime(2, units = "days"), offset = "1d") + )$to_data_frame() + ) +}) + test_that("rolling, arg check_sorted", { dates = c( "2020-01-02 18:12:48", "2020-01-03 19:45:32", "2020-01-08 23:16:43", diff --git a/tests/testthat/test-groupby.R b/tests/testthat/test-groupby.R index fcb81a4e3..cbf9ee1d3 100644 --- a/tests/testthat/test-groupby.R +++ b/tests/testthat/test-groupby.R @@ -274,6 +274,18 @@ test_that("group_by_dynamic for LazyFrame: error if not explicitly sorted", { ) }) +test_that("group_by_dynamic for LazyFrame: error if every is negative", { + df = pl$LazyFrame( + idx = 0:5, + n = 0:5 + )$with_columns(pl$col("idx")$set_sorted()) + + expect_grepl_error( + df$group_by_dynamic("idx", every = "-2i")$agg(pl$col("n")$mean())$collect(), + "'every' argument must be positive" + ) +}) + test_that("group_by_dynamic for LazyFrame: arg 'closed' works", { df = pl$LazyFrame( dt = c( diff --git a/tests/testthat/test-lazy.R b/tests/testthat/test-lazy.R index b67c17e78..7e0a1ba2f 100644 --- a/tests/testthat/test-lazy.R +++ b/tests/testthat/test-lazy.R @@ -555,7 +555,7 @@ test_that("join_asof_simple", { # test allow_parallel and force_parallel - # export LogicalPlan as json string + # export DslPlan as json string logical_json_plan_TT = pop$join_asof(gdp, on = "date", allow_parallel = TRUE, force_parallel = TRUE) |> .pr$LazyFrame$debug_plan() |> @@ -570,7 +570,7 @@ test_that("join_asof_simple", { allow_p_pat = r"{*"allow_parallel":\s*([^,]*)}" # find allow_parallel value in json string force_p_pat = r"{*"force_parallel":\s*([^,]*)}" - # test if setting was as expected in LogicalPlan + # test if setting was as expected in DslPlan expect_identical(get_reg(logical_json_plan_TT, allow_p_pat), "\"allow_parallel\": Bool(true)") expect_identical(get_reg(logical_json_plan_TT, force_p_pat), "\"force_parallel\": Bool(true)") expect_identical(get_reg(logical_json_plan_FF, allow_p_pat), "\"allow_parallel\": Bool(false)") @@ -951,6 +951,27 @@ test_that("rolling for LazyFrame: integer variable", { ) }) +test_that("rolling for LazyFrame: using difftime as period", { + df = pl$LazyFrame( + dt = c( + "2020-01-01", "2020-01-01", "2020-01-01", + "2020-01-02", "2020-01-03", "2020-01-08" + ), + a = c(3, 7, 5, 9, 2, 1) + )$with_columns( + pl$col("dt")$str$strptime(pl$Date, format = NULL)$set_sorted() + ) + + expect_equal( + df$rolling(index_column = "dt", period = "2d")$agg( + pl$sum("a")$alias("sum_a") + )$collect()$to_data_frame(), + df$rolling(index_column = "dt", period = as.difftime(2, units = "days"))$agg( + pl$sum("a")$alias("sum_a") + )$collect()$to_data_frame() + ) +}) + test_that("rolling for LazyFrame: error if not explicitly sorted", { df = pl$LazyFrame( index = c(1L, 2L, 3L, 4L, 8L, 9L), @@ -962,6 +983,17 @@ test_that("rolling for LazyFrame: error if not explicitly sorted", { ) }) +test_that("rolling for LazyFrame: error if period is negative", { + df = pl$LazyFrame( + index = c(1L, 2L, 3L, 4L, 8L, 9L), + a = c(3, 7, 5, 9, 2, 1) + ) + expect_grepl_error( + df$rolling(index_column = "index", period = "-2i")$agg(pl$col("a"))$collect(), + "rolling window period should be strictly positive" + ) +}) + test_that("rolling for LazyFrame: argument 'group_by' works", { df = pl$LazyFrame( index = c(1L, 2L, 3L, 4L, 8L, 9L),