Skip to content

Commit

Permalink
wip [skip ci]
Browse files Browse the repository at this point in the history
  • Loading branch information
eitsupi committed May 22, 2024
1 parent 960b401 commit 092bb4c
Show file tree
Hide file tree
Showing 15 changed files with 155 additions and 54 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -119,4 +119,4 @@ Collate:
Config/rextendr/version: 0.3.1
VignetteBuilder: knitr
Config/polars/LibVersion: 0.39.4
Config/polars/RustToolchainVersion: nightly-2024-04-15
Config/polars/RustToolchainVersion: nightly-2024-05-14
7 changes: 3 additions & 4 deletions R/dataframe__frame.R
Original file line number Diff line number Diff line change
Expand Up @@ -2131,9 +2131,8 @@ DataFrame_rolling = function(
closed = "right",
group_by = NULL,
check_sorted = TRUE) {
if (is.null(offset)) {
offset = paste0("-", period) # TODO: `paste0` should be executed after `period` is parsed as string
}
period = parse_as_polars_duration_string(period)
offset = parse_as_polars_duration_string(offset) %||% negate_duration_string(period)
construct_rolling_group_by(self, index_column, period, offset, closed, group_by, check_sorted)
}

Expand Down Expand Up @@ -2217,7 +2216,7 @@ DataFrame_group_by_dynamic = function(
start_by = "window",
check_sorted = TRUE) {
every = parse_as_polars_duration_string(every)
offset = parse_as_polars_duration_string(offset) %||% paste0("-", every)
offset = parse_as_polars_duration_string(offset) %||% negate_duration_string(every)
period = parse_as_polars_duration_string(period) %||% every
construct_group_by_dynamic(
self, index_column, every, period, offset, include_boundaries, closed, label,
Expand Down
11 changes: 6 additions & 5 deletions R/expr__expr.R
Original file line number Diff line number Diff line change
Expand Up @@ -3316,11 +3316,12 @@ Expr_peak_max = function() {
Expr_rolling = function(
index_column,
...,
period, offset = NULL,
closed = "right", check_sorted = TRUE) {
if (is.null(offset)) {
offset = paste0("-", period) # TODO: `paste0` should be executed after `period` is parsed as string
}
period,
offset = NULL,
closed = "right",
check_sorted = TRUE) {
period = parse_as_polars_duration_string(period)
offset = parse_as_polars_duration_string(offset) %||% negate_duration_string(period)
.pr$Expr$rolling(self, index_column, period, offset, closed, check_sorted) |>
unwrap("in $rolling():")
}
Expand Down
7 changes: 3 additions & 4 deletions R/lazyframe__lazy.R
Original file line number Diff line number Diff line change
Expand Up @@ -1914,9 +1914,8 @@ LazyFrame_rolling = function(
closed = "right",
group_by = NULL,
check_sorted = TRUE) {
if (is.null(offset)) {
offset = paste0("-", period) # TODO: `paste0` should be executed after `period` is parsed as string
}
period = parse_as_polars_duration_string(period)
offset = parse_as_polars_duration_string(offset) %||% negate_duration_string(period)
.pr$LazyFrame$rolling(
self, index_column, period, offset, closed,
wrap_elist_result(group_by, str_to_lit = FALSE), check_sorted
Expand Down Expand Up @@ -2026,7 +2025,7 @@ LazyFrame_group_by_dynamic = function(
start_by = "window",
check_sorted = TRUE) {
every = parse_as_polars_duration_string(every)
offset = parse_as_polars_duration_string(offset) %||% paste0("-", every)
offset = parse_as_polars_duration_string(offset) %||% negate_duration_string(every)
period = parse_as_polars_duration_string(period) %||% every

.pr$LazyFrame$group_by_dynamic(
Expand Down
14 changes: 10 additions & 4 deletions R/parse_as_duration.R
Original file line number Diff line number Diff line change
Expand Up @@ -56,21 +56,19 @@ parse_as_polars_duration_string.default = function(x, default = NULL, ...) {
unwrap()
}


#' @exportS3Method
parse_as_polars_duration_string.character = function(x, default = NULL, ...) {
if (length(x) != 1L || is.na(x)) {
if (length(x) != 1L) {
Err_plain(paste0("`", deparse(substitute(x)), "` must be a single non-NA character or difftime.")) |>
unwrap()
}

x
}


#' @exportS3Method
parse_as_polars_duration_string.difftime = function(x, default = NULL, ...) {
if (length(x) != 1L || is.na(x)) {
if (length(x) != 1L) {
Err_plain(paste0("`", deparse(substitute(x)), "` must be a single non-NA character or difftime.")) |>
unwrap()
}
Expand All @@ -95,3 +93,11 @@ difftime_to_duration_string = function(dft) {
)
paste0(value, unit)
}

# Flip the sign of a duration string.
#
# `x` is a single string. If it begins with "-", that leading minus is
# stripped; otherwise a "-" is prepended.
negate_duration_string = function(x) {
  is_negative = startsWith(x, "-")
  if (!is_negative) {
    return(paste0("-", x))
  }
  # The pattern is anchored, so at most one leading "-" is removed.
  sub("^-", "", x)
}
16 changes: 8 additions & 8 deletions src/rust/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 3 additions & 2 deletions src/rust/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ rpolars_debug_print = []
[workspace]
# prevents package from thinking it's in the workspace
[target.'cfg(any(not(target_os = "linux"), use_mimalloc))'.dependencies]
mimalloc = { version = "0.1.41", default-features = false }
mimalloc = { version = "0.1.42", default-features = false }

[target.'cfg(all(target_os = "linux", not(use_mimalloc)))'.dependencies]
jemallocator = { version = "0.5.0", features = ["disable_initial_exec_tls"] }
Expand Down Expand Up @@ -51,7 +51,7 @@ serde = { version = "1.0.202", features = ["derive"] }
serde_json = "*"
smartstring = "1.0.1"
state = "0.6.0"
thiserror = "1.0.60"
thiserror = "1.0.61"
polars-core = { git = "https://github.com/pola-rs/polars.git", rev = "7bc70141f4dad7863a2026849522551abb274f00", default-features = false }
polars-lazy = { git = "https://github.com/pola-rs/polars.git", rev = "7bc70141f4dad7863a2026849522551abb274f00", default-features = false }
either = "1"
Expand Down Expand Up @@ -139,6 +139,7 @@ features = [
"replace",
"rle",
"rolling_window",
"rolling_window_by",
"round_series",
"row_hash",
"rows",
Expand Down
2 changes: 1 addition & 1 deletion src/rust/src/lazy/dataframe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ impl RPolarsLazyFrame {

fn deserialize(json: Robj) -> RResult<Self> {
let json = robj_to!(str, json)?;
let lp = serde_json::from_str::<pl::LogicalPlan>(json)
let lp = serde_json::from_str::<pl::DslPlan>(json)
.map_err(|err| RPolarsErr::new().plain(format!("{err:?}")))?;
Ok(RPolarsLazyFrame(pl::LazyFrame::from(lp)))
}
Expand Down
19 changes: 5 additions & 14 deletions src/rust/src/lazy/dsl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -684,18 +684,12 @@ impl RPolarsExpr {
weights: Robj,
min_periods: Robj,
center: Robj,
by: Robj,
closed: Robj,
warn_if_unsorted: Robj,
) -> RResult<Self> {
let options = pl::RollingOptions {
window_size: pl::Duration::parse(robj_to!(str, window_size)?),
let options = pl::RollingOptionsFixedWindow {
window_size: robj_to!(usize, window_size)?,
weights: robj_to!(Option, Vec, f64, weights)?,
min_periods: robj_to!(usize, min_periods)?,
center: robj_to!(bool, center)?,
by: robj_to!(Option, String, by)?,
closed_window: robj_to!(Option, ClosedWindow, closed)?,
warn_if_unsorted: robj_to!(bool, warn_if_unsorted)?,
fn_params: None,
};
let quantile = robj_to!(f64, quantile)?;
Expand Down Expand Up @@ -2624,15 +2618,12 @@ pub fn make_rolling_options(
by_null: Robj,
closed_null: Robj,
warn_if_unsorted: Robj,
) -> RResult<pl::RollingOptions> {
Ok(pl::RollingOptions {
window_size: pl::Duration::parse(robj_to!(str, window_size)?),
) -> RResult<pl::RollingOptionsFixedWindow> {
Ok(pl::RollingOptionsFixedWindow {
window_size: robj_to!(usize, window_size)?,
weights: robj_to!(Option, Vec, f64, weights)?,
min_periods: robj_to!(usize, min_periods)?,
center: robj_to!(bool, center)?,
by: robj_to!(Option, String, by_null)?,
closed_window: robj_to!(Option, ClosedWindow, closed_null)?,
warn_if_unsorted: robj_to!(bool, warn_if_unsorted)?,
..Default::default()
})
}
Expand Down
8 changes: 0 additions & 8 deletions src/rust/src/lazy/utils.rs

This file was deleted.

2 changes: 1 addition & 1 deletion src/rust/src/rdataframe/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ impl Iterator for OwnedDataFrameIterator {
.collect();
self.idx += 1;

let chunk = polars::frame::ArrowChunk::new(batch_cols);
let chunk = arrow::record_batch::RecordBatch::new(batch_cols);
let array = arrow::array::StructArray::new(
self.data_type.clone(),
chunk.into_arrays(),
Expand Down
32 changes: 32 additions & 0 deletions tests/testthat/test-dataframe.R
Original file line number Diff line number Diff line change
Expand Up @@ -1323,6 +1323,38 @@ test_that("rolling for DataFrame: basic example", {
)
})

test_that("rolling for DataFrame: using difftime as period", {
  # Date index parsed from strings and flagged with `set_sorted()`.
  dt_strings = c(
    "2020-01-01", "2020-01-01", "2020-01-01",
    "2020-01-02", "2020-01-03", "2020-01-08"
  )
  df = pl$DataFrame(dt = dt_strings, a = c(3, 7, 5, 9, 2, 1))$with_columns(
    pl$col("dt")$str$strptime(pl$Date, format = NULL)$set_sorted()
  )

  # The same rolling sum, with the period given as a string and as a difftime.
  from_string = df$rolling(index_column = "dt", period = "2d")$agg(
    pl$sum("a")$alias("sum_a")
  )$to_data_frame()
  from_difftime = df$rolling(
    index_column = "dt",
    period = as.difftime(2, units = "days")
  )$agg(
    pl$sum("a")$alias("sum_a")
  )$to_data_frame()

  expect_equal(from_string, from_difftime)
})

test_that("rolling for LazyFrame: error if period is negative", {
  lf = pl$LazyFrame(
    index = c(1L, 2L, 3L, 4L, 8L, 9L),
    a = c(3, 7, 5, 9, 2, 1)
  )
  expected_msg = "rolling window period should be strictly positive"

  # The whole query stays inside the expectation so that the error is caught
  # wherever in the chain it is raised.
  expect_grepl_error(
    lf$rolling(index_column = "index", period = "-2i")$agg(pl$col("a"))$collect(),
    expected_msg
  )
})

test_that("rolling for DataFrame: can be ungrouped", {
df = pl$DataFrame(
index = c(1:5, 6.0),
Expand Down
36 changes: 36 additions & 0 deletions tests/testthat/test-expr_expr.R
Original file line number Diff line number Diff line change
Expand Up @@ -2577,6 +2577,42 @@ test_that("rolling, arg offset", {
)
})

test_that("rolling: error if period is negative", {
  # Datetime index parsed from strings and flagged with `set_sorted()`.
  df = pl$DataFrame(
    dt = c(
      "2020-01-01 13:45:48", "2020-01-01 16:42:13", "2020-01-01 16:45:09",
      "2020-01-02 18:12:48", "2020-01-03 19:45:32", "2020-01-08 23:16:43"
    ),
    a = c(3, 7, 5, 9, 2, 1)
  )$with_columns(
    pl$col("dt")$str$strptime(pl$Datetime("us"), format = "%Y-%m-%d %H:%M:%S")$set_sorted()
  )
  expected_msg = "rolling window period should be strictly positive"

  expect_grepl_error(
    df$select(pl$col("a")$rolling(index_column = "dt", period = "-2d")),
    expected_msg
  )
})

test_that("rolling: passing a difftime as period works", {
  # Datetime index parsed from strings and flagged with `set_sorted()`.
  df = pl$DataFrame(
    dt = c(
      "2020-01-01 13:45:48", "2020-01-01 16:42:13", "2020-01-01 16:45:09",
      "2020-01-02 18:12:48", "2020-01-03 19:45:32", "2020-01-08 23:16:43"
    ),
    a = c(3, 7, 5, 9, 2, 1)
  )$with_columns(
    pl$col("dt")$str$strptime(pl$Datetime("us"), format = "%Y-%m-%d %H:%M:%S")$set_sorted()
  )

  # Rolling sum of `a` with a fixed "1d" offset; only `period` varies.
  rolling_sum = function(period) {
    df$select(
      sum_a_offset1 = pl$sum("a")$rolling(index_column = "dt", period = period, offset = "1d")
    )$to_data_frame()
  }

  expect_identical(
    rolling_sum("2d"),
    rolling_sum(as.difftime(2, units = "days"))
  )
})

test_that("rolling, arg check_sorted", {
dates = c(
"2020-01-02 18:12:48", "2020-01-03 19:45:32", "2020-01-08 23:16:43",
Expand Down
12 changes: 12 additions & 0 deletions tests/testthat/test-groupby.R
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,18 @@ test_that("group_by_dynamic for LazyFrame: error if not explicitly sorted", {
)
})

test_that("group_by_dynamic for LazyFrame: error if every is negative", {
  # Integer index flagged with `set_sorted()`.
  lf = pl$LazyFrame(idx = 0:5, n = 0:5)$with_columns(
    pl$col("idx")$set_sorted()
  )
  expected_msg = "'every' argument must be positive"

  # The whole query stays inside the expectation so that the error is caught
  # wherever in the chain it is raised.
  expect_grepl_error(
    lf$group_by_dynamic("idx", every = "-2i")$agg(pl$col("n")$mean())$collect(),
    expected_msg
  )
})

test_that("group_by_dynamic for LazyFrame: arg 'closed' works", {
df = pl$LazyFrame(
dt = c(
Expand Down
Loading

0 comments on commit 092bb4c

Please sign in to comment.