Skip to content

Commit

Permalink
refactor: some cleanup and more tests for $rolling() (#1103)
Browse files Browse the repository at this point in the history
  • Loading branch information
etiennebacher authored May 21, 2024
1 parent 432a98c commit 1aba78a
Show file tree
Hide file tree
Showing 8 changed files with 134 additions and 17 deletions.
7 changes: 3 additions & 4 deletions R/dataframe__frame.R
Original file line number Diff line number Diff line change
Expand Up @@ -2131,9 +2131,8 @@ DataFrame_rolling = function(
closed = "right",
group_by = NULL,
check_sorted = TRUE) {
if (is.null(offset)) {
offset = paste0("-", period) # TODO: `paste0` should be executed after `period` is parsed as string
}
period = parse_as_polars_duration_string(period)
offset = parse_as_polars_duration_string(offset) %||% negate_duration_string(period)
construct_rolling_group_by(self, index_column, period, offset, closed, group_by, check_sorted)
}

Expand Down Expand Up @@ -2217,7 +2216,7 @@ DataFrame_group_by_dynamic = function(
start_by = "window",
check_sorted = TRUE) {
every = parse_as_polars_duration_string(every)
offset = parse_as_polars_duration_string(offset) %||% paste0("-", every)
offset = parse_as_polars_duration_string(offset) %||% negate_duration_string(every)
period = parse_as_polars_duration_string(period) %||% every
construct_group_by_dynamic(
self, index_column, every, period, offset, include_boundaries, closed, label,
Expand Down
11 changes: 6 additions & 5 deletions R/expr__expr.R
Original file line number Diff line number Diff line change
Expand Up @@ -3316,11 +3316,12 @@ Expr_peak_max = function() {
Expr_rolling = function(
index_column,
...,
period, offset = NULL,
closed = "right", check_sorted = TRUE) {
if (is.null(offset)) {
offset = paste0("-", period) # TODO: `paste0` should be executed after `period` is parsed as string
}
period,
offset = NULL,
closed = "right",
check_sorted = TRUE) {
period = parse_as_polars_duration_string(period)
offset = parse_as_polars_duration_string(offset) %||% negate_duration_string(period)
.pr$Expr$rolling(self, index_column, period, offset, closed, check_sorted) |>
unwrap("in $rolling():")
}
Expand Down
7 changes: 3 additions & 4 deletions R/lazyframe__lazy.R
Original file line number Diff line number Diff line change
Expand Up @@ -1914,9 +1914,8 @@ LazyFrame_rolling = function(
closed = "right",
group_by = NULL,
check_sorted = TRUE) {
if (is.null(offset)) {
offset = paste0("-", period) # TODO: `paste0` should be executed after `period` is parsed as string
}
period = parse_as_polars_duration_string(period)
offset = parse_as_polars_duration_string(offset) %||% negate_duration_string(period)
.pr$LazyFrame$rolling(
self, index_column, period, offset, closed,
wrap_elist_result(group_by, str_to_lit = FALSE), check_sorted
Expand Down Expand Up @@ -2026,7 +2025,7 @@ LazyFrame_group_by_dynamic = function(
start_by = "window",
check_sorted = TRUE) {
every = parse_as_polars_duration_string(every)
offset = parse_as_polars_duration_string(offset) %||% paste0("-", every)
offset = parse_as_polars_duration_string(offset) %||% negate_duration_string(every)
period = parse_as_polars_duration_string(period) %||% every

.pr$LazyFrame$group_by_dynamic(
Expand Down
14 changes: 10 additions & 4 deletions R/parse_as_duration.R
Original file line number Diff line number Diff line change
Expand Up @@ -56,21 +56,19 @@ parse_as_polars_duration_string.default = function(x, default = NULL, ...) {
unwrap()
}


#' @exportS3Method
parse_as_polars_duration_string.character = function(x, default = NULL, ...) {
if (length(x) != 1L || is.na(x)) {
if (length(x) != 1L) {
Err_plain(paste0("`", deparse(substitute(x)), "` must be a single non-NA character or difftime.")) |>
unwrap()
}

x
}


#' @exportS3Method
parse_as_polars_duration_string.difftime = function(x, default = NULL, ...) {
if (length(x) != 1L || is.na(x)) {
if (length(x) != 1L) {
Err_plain(paste0("`", deparse(substitute(x)), "` must be a single non-NA character or difftime.")) |>
unwrap()
}
Expand All @@ -95,3 +93,11 @@ difftime_to_duration_string = function(dft) {
)
paste0(value, unit)
}

# Flip the sign of a single duration string.
#
# A string that already starts with "-" has that one leading minus removed;
# any other string gets a "-" prepended.
#
# x: a single character string (e.g. "2d" or "-1h").
# Returns: a character string with the opposite sign prefix.
negate_duration_string = function(x) {
  is_negative = startsWith(x, "-")
  if (!is_negative) {
    return(paste0("-", x))
  }
  # The pattern is anchored at the start, so only the first "-" is dropped.
  sub("^-", "", x)
}
32 changes: 32 additions & 0 deletions tests/testthat/test-dataframe.R
Original file line number Diff line number Diff line change
Expand Up @@ -1323,6 +1323,38 @@ test_that("rolling for DataFrame: basic example", {
)
})

test_that("rolling for DataFrame: using difftime as period", {
  # Date index with repeated days and a gap, flagged as sorted for $rolling().
  date_strings = c(
    "2020-01-01", "2020-01-01", "2020-01-01",
    "2020-01-02", "2020-01-03", "2020-01-08"
  )
  df = pl$DataFrame(dt = date_strings, a = c(3, 7, 5, 9, 2, 1))$with_columns(
    pl$col("dt")$str$strptime(pl$Date, format = NULL)$set_sorted()
  )

  # Rolling sum of `a` over `dt` for a given window length.
  rolling_sum = function(period) {
    df$rolling(index_column = "dt", period = period)$agg(
      pl$sum("a")$alias("sum_a")
    )$to_data_frame()
  }

  # The duration string and the equivalent difftime must give the same result.
  expect_equal(
    rolling_sum("2d"),
    rolling_sum(as.difftime(2, units = "days"))
  )
})

# This test lives in the DataFrame test file, so it exercises the eager
# DataFrame$rolling() path (the LazyFrame variant is covered in test-lazy.R).
test_that("rolling for DataFrame: error if period is negative", {
  df = pl$DataFrame(
    index = c(1L, 2L, 3L, 4L, 8L, 9L),
    a = c(3, 7, 5, 9, 2, 1)
  )$with_columns(
    # Flag the index as sorted so the period error is the one being asserted.
    pl$col("index")$set_sorted()
  )
  expect_grepl_error(
    df$rolling(index_column = "index", period = "-2i")$agg(pl$col("a")),
    "rolling window period should be strictly positive"
  )
})

test_that("rolling for DataFrame: can be ungrouped", {
df = pl$DataFrame(
index = c(1:5, 6.0),
Expand Down
36 changes: 36 additions & 0 deletions tests/testthat/test-expr_expr.R
Original file line number Diff line number Diff line change
Expand Up @@ -2577,6 +2577,42 @@ test_that("rolling, arg offset", {
)
})

test_that("rolling: error if period is negative", {
  timestamps = c(
    "2020-01-01 13:45:48", "2020-01-01 16:42:13", "2020-01-01 16:45:09",
    "2020-01-02 18:12:48", "2020-01-03 19:45:32", "2020-01-08 23:16:43"
  )

  # Parse the strings into a microsecond Datetime column flagged as sorted.
  df = pl$DataFrame(dt = timestamps, a = c(3, 7, 5, 9, 2, 1))$with_columns(
    pl$col("dt")$str$strptime(
      pl$Datetime("us"),
      format = "%Y-%m-%d %H:%M:%S"
    )$set_sorted()
  )

  # A negative window length must be rejected.
  expect_grepl_error(
    df$select(pl$col("a")$rolling(index_column = "dt", period = "-2d")),
    "rolling window period should be strictly positive"
  )
})

test_that("rolling: passing a difftime as period works", {
  timestamps = c(
    "2020-01-01 13:45:48", "2020-01-01 16:42:13", "2020-01-01 16:45:09",
    "2020-01-02 18:12:48", "2020-01-03 19:45:32", "2020-01-08 23:16:43"
  )

  df = pl$DataFrame(dt = timestamps, a = c(3, 7, 5, 9, 2, 1))$with_columns(
    pl$col("dt")$str$strptime(
      pl$Datetime("us"),
      format = "%Y-%m-%d %H:%M:%S"
    )$set_sorted()
  )

  # Rolling sum of `a` with a fixed "1d" offset and the given window length.
  rolling_sum_offset = function(period) {
    df$select(
      sum_a_offset1 = pl$sum("a")$rolling(
        index_column = "dt", period = period, offset = "1d"
      )
    )$to_data_frame()
  }

  # "2d" and a 2-day difftime must produce identical output.
  expect_identical(
    rolling_sum_offset("2d"),
    rolling_sum_offset(as.difftime(2, units = "days"))
  )
})

test_that("rolling, arg check_sorted", {
dates = c(
"2020-01-02 18:12:48", "2020-01-03 19:45:32", "2020-01-08 23:16:43",
Expand Down
12 changes: 12 additions & 0 deletions tests/testthat/test-groupby.R
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,18 @@ test_that("group_by_dynamic for LazyFrame: error if not explicitly sorted", {
)
})

test_that("group_by_dynamic for LazyFrame: error if every is negative", {
  # Integer index flagged as sorted.
  sorted_idx = pl$col("idx")$set_sorted()
  df = pl$LazyFrame(idx = 0:5, n = 0:5)$with_columns(sorted_idx)

  # A negative `every` must be rejected when the query is collected.
  expect_grepl_error(
    df$group_by_dynamic("idx", every = "-2i")$agg(pl$col("n")$mean())$collect(),
    "'every' argument must be positive"
  )
})

test_that("group_by_dynamic for LazyFrame: arg 'closed' works", {
df = pl$LazyFrame(
dt = c(
Expand Down
32 changes: 32 additions & 0 deletions tests/testthat/test-lazy.R
Original file line number Diff line number Diff line change
Expand Up @@ -951,6 +951,27 @@ test_that("rolling for LazyFrame: integer variable", {
)
})

test_that("rolling for LazyFrame: using difftime as period", {
  # Date index with repeated days and a gap, flagged as sorted for $rolling().
  date_strings = c(
    "2020-01-01", "2020-01-01", "2020-01-01",
    "2020-01-02", "2020-01-03", "2020-01-08"
  )
  df = pl$LazyFrame(dt = date_strings, a = c(3, 7, 5, 9, 2, 1))$with_columns(
    pl$col("dt")$str$strptime(pl$Date, format = NULL)$set_sorted()
  )

  # Collected rolling sum of `a` over `dt` for a given window length.
  rolling_sum = function(period) {
    df$rolling(index_column = "dt", period = period)$agg(
      pl$sum("a")$alias("sum_a")
    )$collect()$to_data_frame()
  }

  # The duration string and the equivalent difftime must give the same result.
  expect_equal(
    rolling_sum("2d"),
    rolling_sum(as.difftime(2, units = "days"))
  )
})

test_that("rolling for LazyFrame: error if not explicitly sorted", {
df = pl$LazyFrame(
index = c(1L, 2L, 3L, 4L, 8L, 9L),
Expand All @@ -962,6 +983,17 @@ test_that("rolling for LazyFrame: error if not explicitly sorted", {
)
})

test_that("rolling for LazyFrame: error if period is negative", {
  index_values = c(1L, 2L, 3L, 4L, 8L, 9L)
  df = pl$LazyFrame(index = index_values, a = c(3, 7, 5, 9, 2, 1))

  # A negative window length must be rejected when the query is collected.
  expect_grepl_error(
    df$rolling(index_column = "index", period = "-2i")$agg(pl$col("a"))$collect(),
    "rolling window period should be strictly positive"
  )
})

test_that("rolling for LazyFrame: argument 'group_by' works", {
df = pl$LazyFrame(
index = c(1L, 2L, 3L, 4L, 8L, 9L),
Expand Down

0 comments on commit 1aba78a

Please sign in to comment.