From 1aba78a7df5495b797860afc9e4e2696bc649042 Mon Sep 17 00:00:00 2001
From: Etienne Bacher <52219252+etiennebacher@users.noreply.github.com>
Date: Tue, 21 May 2024 23:28:58 +0100
Subject: [PATCH] refactor: some cleanup and more tests for `$rolling()` (#1103)

---
Review notes (ignored by `git am`):
* `$rolling()` on DataFrame, LazyFrame and Expr now parses `period` with
  `parse_as_polars_duration_string()` before deriving the default `offset`,
  so a difftime `period` works; the default offset is built with the new
  `negate_duration_string()` helper instead of `paste0("-", ...)`.
* The negative-period test added to tests/testthat/test-dataframe.R builds an
  eager DataFrame, so it covers `DataFrame$rolling()` rather than repeating
  the LazyFrame test added to tests/testthat/test-lazy.R.

 R/dataframe__frame.R            |  7 +++----
 R/expr__expr.R                  | 11 +++++-----
 R/lazyframe__lazy.R             |  7 +++----
 R/parse_as_duration.R           | 14 +++++++++----
 tests/testthat/test-dataframe.R | 32 +++++++++++++++++++++++++++++
 tests/testthat/test-expr_expr.R | 36 +++++++++++++++++++++++++++++++++
 tests/testthat/test-groupby.R   | 12 +++++++++++
 tests/testthat/test-lazy.R      | 32 +++++++++++++++++++++++++++++
 8 files changed, 134 insertions(+), 17 deletions(-)

diff --git a/R/dataframe__frame.R b/R/dataframe__frame.R
index 724b690e9..a2268608b 100644
--- a/R/dataframe__frame.R
+++ b/R/dataframe__frame.R
@@ -2131,9 +2131,8 @@ DataFrame_rolling = function(
     closed = "right",
     group_by = NULL,
     check_sorted = TRUE) {
-  if (is.null(offset)) {
-    offset = paste0("-", period) # TODO: `paste0` should be executed after `period` is parsed as string
-  }
+  period = parse_as_polars_duration_string(period)
+  offset = parse_as_polars_duration_string(offset) %||% negate_duration_string(period)
   construct_rolling_group_by(self, index_column, period, offset, closed, group_by, check_sorted)
 }
 
@@ -2217,7 +2216,7 @@ DataFrame_group_by_dynamic = function(
     start_by = "window",
     check_sorted = TRUE) {
   every = parse_as_polars_duration_string(every)
-  offset = parse_as_polars_duration_string(offset) %||% paste0("-", every)
+  offset = parse_as_polars_duration_string(offset) %||% negate_duration_string(every)
   period = parse_as_polars_duration_string(period) %||% every
   construct_group_by_dynamic(
     self, index_column, every, period, offset, include_boundaries, closed, label,
diff --git a/R/expr__expr.R b/R/expr__expr.R
index bc0130430..ae345edb4 100644
--- a/R/expr__expr.R
+++ b/R/expr__expr.R
@@ -3316,11 +3316,12 @@ Expr_peak_max = function() {
 Expr_rolling = function(
     index_column,
     ...,
-    period, offset = NULL,
-    closed = "right", check_sorted = TRUE) {
-  if (is.null(offset)) {
-    offset = paste0("-", period) # TODO: `paste0` should be executed after `period` is parsed as string
-  }
+    period,
+    offset = NULL,
+    closed = "right",
+    check_sorted = TRUE) {
+  period = parse_as_polars_duration_string(period)
+  offset = parse_as_polars_duration_string(offset) %||% negate_duration_string(period)
   .pr$Expr$rolling(self, index_column, period, offset, closed, check_sorted) |>
     unwrap("in $rolling():")
 }
diff --git a/R/lazyframe__lazy.R b/R/lazyframe__lazy.R
index b2921823f..b637e72d1 100644
--- a/R/lazyframe__lazy.R
+++ b/R/lazyframe__lazy.R
@@ -1914,9 +1914,8 @@ LazyFrame_rolling = function(
     closed = "right",
     group_by = NULL,
     check_sorted = TRUE) {
-  if (is.null(offset)) {
-    offset = paste0("-", period) # TODO: `paste0` should be executed after `period` is parsed as string
-  }
+  period = parse_as_polars_duration_string(period)
+  offset = parse_as_polars_duration_string(offset) %||% negate_duration_string(period)
   .pr$LazyFrame$rolling(
     self, index_column, period, offset, closed,
     wrap_elist_result(group_by, str_to_lit = FALSE), check_sorted
@@ -2026,7 +2025,7 @@ LazyFrame_group_by_dynamic = function(
     start_by = "window",
     check_sorted = TRUE) {
   every = parse_as_polars_duration_string(every)
-  offset = parse_as_polars_duration_string(offset) %||% paste0("-", every)
+  offset = parse_as_polars_duration_string(offset) %||% negate_duration_string(every)
   period = parse_as_polars_duration_string(period) %||% every
 
   .pr$LazyFrame$group_by_dynamic(
diff --git a/R/parse_as_duration.R b/R/parse_as_duration.R
index f06aec60d..a5705afd9 100644
--- a/R/parse_as_duration.R
+++ b/R/parse_as_duration.R
@@ -56,10 +56,9 @@ parse_as_polars_duration_string.default = function(x, default = NULL, ...) {
     unwrap()
 }
 
-
 #' @exportS3Method
 parse_as_polars_duration_string.character = function(x, default = NULL, ...) {
-  if (length(x) != 1L || is.na(x)) {
+  if (length(x) != 1L) {
     Err_plain(paste0("`", deparse(substitute(x)), "` must be a single non-NA character or difftime.")) |>
       unwrap()
   }
@@ -67,10 +66,9 @@ parse_as_polars_duration_string.character = function(x, default = NULL, ...) {
   x
 }
 
-
 #' @exportS3Method
 parse_as_polars_duration_string.difftime = function(x, default = NULL, ...) {
-  if (length(x) != 1L || is.na(x)) {
+  if (length(x) != 1L) {
     Err_plain(paste0("`", deparse(substitute(x)), "` must be a single non-NA character or difftime.")) |>
       unwrap()
   }
@@ -95,3 +93,11 @@ difftime_to_duration_string = function(dft) {
   )
   paste0(value, unit)
 }
+
+negate_duration_string = function(x) {
+  if (startsWith(x, "-")) {
+    gsub("^-", "", x)
+  } else {
+    paste0("-", x)
+  }
+}
diff --git a/tests/testthat/test-dataframe.R b/tests/testthat/test-dataframe.R
index f90c0d9de..fbe2b0914 100644
--- a/tests/testthat/test-dataframe.R
+++ b/tests/testthat/test-dataframe.R
@@ -1323,6 +1323,38 @@ test_that("rolling for DataFrame: basic example", {
   )
 })
 
+test_that("rolling for DataFrame: using difftime as period", {
+  df = pl$DataFrame(
+    dt = c(
+      "2020-01-01", "2020-01-01", "2020-01-01",
+      "2020-01-02", "2020-01-03", "2020-01-08"
+    ),
+    a = c(3, 7, 5, 9, 2, 1)
+  )$with_columns(
+    pl$col("dt")$str$strptime(pl$Date, format = NULL)$set_sorted()
+  )
+
+  expect_equal(
+    df$rolling(index_column = "dt", period = "2d")$agg(
+      pl$sum("a")$alias("sum_a")
+    )$to_data_frame(),
+    df$rolling(index_column = "dt", period = as.difftime(2, units = "days"))$agg(
+      pl$sum("a")$alias("sum_a")
+    )$to_data_frame()
+  )
+})
+
+test_that("rolling for DataFrame: error if period is negative", {
+  df = pl$DataFrame(
+    index = c(1L, 2L, 3L, 4L, 8L, 9L),
+    a = c(3, 7, 5, 9, 2, 1)
+  )
+  expect_grepl_error(
+    df$rolling(index_column = "index", period = "-2i")$agg(pl$col("a")),
+    "rolling window period should be strictly positive"
+  )
+})
+
 test_that("rolling for DataFrame: can be ungrouped", {
   df = pl$DataFrame(
     index = c(1:5, 6.0),
diff --git a/tests/testthat/test-expr_expr.R b/tests/testthat/test-expr_expr.R
index ff0638276..688e167d6 100644
--- a/tests/testthat/test-expr_expr.R
+++ b/tests/testthat/test-expr_expr.R
@@ -2577,6 +2577,42 @@ test_that("rolling, arg offset", {
   )
 })
 
+test_that("rolling: error if period is negative", {
+  dates = c(
+    "2020-01-01 13:45:48", "2020-01-01 16:42:13", "2020-01-01 16:45:09",
+    "2020-01-02 18:12:48", "2020-01-03 19:45:32", "2020-01-08 23:16:43"
+  )
+
+  df = pl$DataFrame(dt = dates, a = c(3, 7, 5, 9, 2, 1))$
+    with_columns(
+      pl$col("dt")$str$strptime(pl$Datetime("us"), format = "%Y-%m-%d %H:%M:%S")$set_sorted()
+    )
+  expect_grepl_error(
+    df$select(pl$col("a")$rolling(index_column = "dt", period = "-2d")),
+    "rolling window period should be strictly positive"
+  )
+})
+
+test_that("rolling: passing a difftime as period works", {
+  dates = c(
+    "2020-01-01 13:45:48", "2020-01-01 16:42:13", "2020-01-01 16:45:09",
+    "2020-01-02 18:12:48", "2020-01-03 19:45:32", "2020-01-08 23:16:43"
+  )
+
+  df = pl$DataFrame(dt = dates, a = c(3, 7, 5, 9, 2, 1))$
+    with_columns(
+      pl$col("dt")$str$strptime(pl$Datetime("us"), format = "%Y-%m-%d %H:%M:%S")$set_sorted()
+    )
+  expect_identical(
+    df$select(
+      sum_a_offset1 = pl$sum("a")$rolling(index_column = "dt", period = "2d", offset = "1d")
+    )$to_data_frame(),
+    df$select(
+      sum_a_offset1 = pl$sum("a")$rolling(index_column = "dt", period = as.difftime(2, units = "days"), offset = "1d")
+    )$to_data_frame()
+  )
+})
+
 test_that("rolling, arg check_sorted", {
   dates = c(
     "2020-01-02 18:12:48", "2020-01-03 19:45:32", "2020-01-08 23:16:43",
diff --git a/tests/testthat/test-groupby.R b/tests/testthat/test-groupby.R
index fcb81a4e3..cbf9ee1d3 100644
--- a/tests/testthat/test-groupby.R
+++ b/tests/testthat/test-groupby.R
@@ -274,6 +274,18 @@ test_that("group_by_dynamic for LazyFrame: error if not explicitly sorted", {
   )
 })
 
+test_that("group_by_dynamic for LazyFrame: error if every is negative", {
+  df = pl$LazyFrame(
+    idx = 0:5,
+    n = 0:5
+  )$with_columns(pl$col("idx")$set_sorted())
+
+  expect_grepl_error(
+    df$group_by_dynamic("idx", every = "-2i")$agg(pl$col("n")$mean())$collect(),
+    "'every' argument must be positive"
+  )
+})
+
 test_that("group_by_dynamic for LazyFrame: arg 'closed' works", {
   df = pl$LazyFrame(
     dt = c(
diff --git a/tests/testthat/test-lazy.R b/tests/testthat/test-lazy.R
index b67c17e78..28501f875 100644
--- a/tests/testthat/test-lazy.R
+++ b/tests/testthat/test-lazy.R
@@ -951,6 +951,27 @@ test_that("rolling for LazyFrame: integer variable", {
   )
 })
 
+test_that("rolling for LazyFrame: using difftime as period", {
+  df = pl$LazyFrame(
+    dt = c(
+      "2020-01-01", "2020-01-01", "2020-01-01",
+      "2020-01-02", "2020-01-03", "2020-01-08"
+    ),
+    a = c(3, 7, 5, 9, 2, 1)
+  )$with_columns(
+    pl$col("dt")$str$strptime(pl$Date, format = NULL)$set_sorted()
+  )
+
+  expect_equal(
+    df$rolling(index_column = "dt", period = "2d")$agg(
+      pl$sum("a")$alias("sum_a")
+    )$collect()$to_data_frame(),
+    df$rolling(index_column = "dt", period = as.difftime(2, units = "days"))$agg(
+      pl$sum("a")$alias("sum_a")
+    )$collect()$to_data_frame()
+  )
+})
+
 test_that("rolling for LazyFrame: error if not explicitly sorted", {
   df = pl$LazyFrame(
     index = c(1L, 2L, 3L, 4L, 8L, 9L),
@@ -962,6 +983,17 @@ test_that("rolling for LazyFrame: error if not explicitly sorted", {
   )
 })
 
+test_that("rolling for LazyFrame: error if period is negative", {
+  df = pl$LazyFrame(
+    index = c(1L, 2L, 3L, 4L, 8L, 9L),
+    a = c(3, 7, 5, 9, 2, 1)
+  )
+  expect_grepl_error(
+    df$rolling(index_column = "index", period = "-2i")$agg(pl$col("a"))$collect(),
+    "rolling window period should be strictly positive"
+  )
+})
+
 test_that("rolling for LazyFrame: argument 'group_by' works", {
   df = pl$LazyFrame(
     index = c(1L, 2L, 3L, 4L, 8L, 9L),