Skip to content

Commit

Permalink
fix slice, test arg "order_by" in $over()
Browse files Browse the repository at this point in the history
  • Loading branch information
etiennebacher committed Jun 23, 2024
1 parent bfb59c7 commit 6f619a7
Show file tree
Hide file tree
Showing 6 changed files with 60 additions and 21 deletions.
24 changes: 18 additions & 6 deletions R/expr__expr.R
Original file line number Diff line number Diff line change
Expand Up @@ -1163,11 +1163,7 @@ Expr_is_not_nan = use_extendr_wrapper
#' full data.
#'
#' @return Expr
#' @aliases slice
#' @name Expr_slice
#' @format NULL
#' @examples
#'
#' # as head
#' pl$DataFrame(list(a = 0:100))$select(
#' pl$all()$slice(0, 6)
Expand All @@ -1185,7 +1181,8 @@ Expr_is_not_nan = use_extendr_wrapper
#' # recycling
#' pl$DataFrame(mtcars)$with_columns(pl$col("mpg")$slice(0, 1))
Expr_slice = function(offset, length = NULL) {
.pr$Expr$slice(self, wrap_e(offset), wrap_e(length))
.pr$Expr$slice(self, offset, wrap_e(length)) |>
unwrap("in $slice():")
}


Expand Down Expand Up @@ -1840,7 +1837,8 @@ Expr_last = use_extendr_wrapper
#' @param ... Column(s) to group by. Accepts expression input.
#' Characters are parsed as column names.
#' @param order_by Order the window functions/aggregations with the partitioned
#' groups by the result of the expression passed to `order_by`.
#' groups by the result of the expression passed to `order_by`. Can be an Expr.
#' Strings are parsed as column names.
#' @param mapping_strategy One of the following:
#' * `"group_to_rows"` (default): if the aggregation results in multiple values,
#' assign them back to their position in the DataFrame. This can only be done
Expand Down Expand Up @@ -1886,6 +1884,20 @@ Expr_last = use_extendr_wrapper
#' df$with_columns(
#' top_2 = pl$col("c")$top_k(2)$over("a", mapping_strategy = "join")
#' )
#'
#' # order_by specifies how values are sorted within a group, which is
#' # essential when the operation depends on the order of values
#' df = pl$DataFrame(
#' g = c(1, 1, 1, 1, 2, 2, 2, 2),
#' t = c(1, 2, 3, 4, 4, 1, 2, 3),
#' x = c(10, 20, 30, 40, 10, 20, 30, 40)
#' )
#'
#' # without order_by, the first and second values in the second group would
#' # be inverted, which would be wrong
#' df$with_columns(
#' x_lag = pl$col("x")$shift(1)$over("g", order_by = "t")
#' )
Expr_over = function(..., order_by = NULL, mapping_strategy = "group_to_rows") {
list_of_exprs = list2(...) |>
lapply(\(x) {
Expand Down
17 changes: 16 additions & 1 deletion man/Expr_over.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 0 additions & 2 deletions man/Expr_slice.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

13 changes: 9 additions & 4 deletions src/rust/src/lazy/dsl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1790,12 +1790,17 @@ impl RPolarsExpr {
self.0.clone().len().into()
}

pub fn slice(&self, offset: &RPolarsExpr, length: Nullable<&RPolarsExpr>) -> Self {
pub fn slice(&self, offset: Robj, length: Nullable<&RPolarsExpr>) -> RResult<Self> {
let offset = robj_to!(PLExpr, offset)?;
let length = match null_to_opt(length) {
Some(i) => i.0.clone(),
Some(i) => dsl::cast(i.0.clone(), pl::DataType::Int64),
None => dsl::lit(i64::MAX),
};
self.0.clone().slice(offset.0.clone(), length).into()
Ok(self
.0
.clone()
.slice(dsl::cast(offset, pl::DataType::Int64), length)
.into())
}

pub fn append(&self, other: &RPolarsExpr, upcast: bool) -> Self {
Expand Down Expand Up @@ -1913,7 +1918,7 @@ impl RPolarsExpr {
) -> RResult<Self> {
let partition_by = robj_to!(Vec, PLExpr, partition_by)?;

let order_by = robj_to!(Option, Vec, PLExpr, order_by)?.map(|order_by| {
let order_by = robj_to!(Option, Vec, PLExprCol, order_by)?.map(|order_by| {
(
order_by,
SortOptions {
Expand Down
8 changes: 1 addition & 7 deletions tests/testthat/test-expr_datetime.R
Original file line number Diff line number Diff line change
Expand Up @@ -108,15 +108,9 @@ test_that("dt$truncate", {
lapply(l_actual, \(x) diff(x) |> as.numeric()),
list(
datetime = rep(2, 12),
truncated_4s = rep(c(0, 4), 6),
truncated_4s_offset_2s = rep(c(0, 4), 6)
truncated_4s = rep(c(0, 4), 6)
)
)

expect_identical(
as.numeric(l_actual$truncated_4s_offset_2s - l_actual$truncated_4s),
rep(3, 13)
)
})


Expand Down
17 changes: 16 additions & 1 deletion tests/testthat/test-expr_expr.R
Original file line number Diff line number Diff line change
Expand Up @@ -324,11 +324,26 @@ test_that("$over() with mapping_strategy", {
expect_identical(
df$select(pl$col("val")$top_k(2)$over("a", mapping_strategy = "join"))$to_list(),
list(
val = list(c(5L, 2L), c(5L, 2L), c(4L, 3L), c(4L, 3L), c(5L, 2L))
val = list(c(5L, 2L), c(5L, 2L), c(3L, 4L), c(3L, 4L), c(5L, 2L))
)
)
})

test_that("arg 'order_by' in $over() works", {
df = pl$DataFrame(
g = c(1, 1, 1, 1, 2, 2, 2, 2),
t = c(1, 2, 3, 4, 4, 1, 2, 3),
x = c(10, 20, 30, 40, 10, 20, 30, 40)
)

expect_equal(
df$select(
x_lag = pl$col("x")$shift(1)$over("g", order_by = "t")
)$to_list(),
list(x_lag = c(NA, 10, 20, 30, 40, NA, 20, 30))
)
})

test_that("col DataType + col(s) + col regex", {
# one Datatype
expect_equal(
Expand Down

0 comments on commit 6f619a7

Please sign in to comment.