From 5c4ac1d217867af565abe66b3d6a9497e56d6c87 Mon Sep 17 00:00:00 2001 From: Etienne Bacher <52219252+etiennebacher@users.noreply.github.com> Date: Sat, 4 May 2024 22:40:45 +0100 Subject: [PATCH] Implement `$str$head()` and `$str$tail()` --- NEWS.md | 1 + R/expr__string.R | 70 +++++++++++++++++++++++++ R/extendr-wrappers.R | 4 ++ man/ExprStr_head.Rd | 42 +++++++++++++++ man/ExprStr_tail.Rd | 42 +++++++++++++++ src/rust/src/lazy/dsl.rs | 9 ++++ tests/testthat/_snaps/after-wrappers.md | 47 +++++++++-------- tests/testthat/test-expr_string.R | 37 +++++++++++++ 8 files changed, 229 insertions(+), 23 deletions(-) create mode 100644 man/ExprStr_head.Rd create mode 100644 man/ExprStr_tail.Rd diff --git a/NEWS.md b/NEWS.md index 1b8c38a08..d526b5806 100644 --- a/NEWS.md +++ b/NEWS.md @@ -7,6 +7,7 @@ - `pl$read_ipc()` can read a raw vector of Apache Arrow IPC file (#1072). - New method `$to_raw_ipc()` to serialize a DataFrame to a raw vector of Apache Arrow IPC file format (#1072). +- New methods `$str$head()` and `$str$tail()` (#1074). ## Polars R Package 0.16.3 diff --git a/R/expr__string.R b/R/expr__string.R index 2f709fad6..5e942b60f 100644 --- a/R/expr__string.R +++ b/R/expr__string.R @@ -1028,3 +1028,73 @@ ExprStr_find = function(pattern, ..., literal = FALSE, strict = TRUE) { .pr$Expr$str_find(self, pattern, literal, strict) |> unwrap("in $str$find():") } + +#' Return the first n characters of each string +#' +#' @param n Length of the slice (integer or expression). Strings are parsed as +#' column names. Negative indexing is supported. +#' +#' @details +#' The `n` input is defined in terms of the number of characters in the (UTF-8) +#' string. A character is defined as a Unicode scalar value. A single character +#' is represented by a single byte when working with ASCII text, and a maximum +#' of 4 bytes otherwise. +#' +#' When the `n` input is negative, `head()` returns characters up to the `n`th +#' from the end of the string. For example, if `n = -3`, then all characters +#' except the last three are returned. +#' +#' If the length of the string has fewer than `n` characters, the full string is +#' returned. +#' +#' @return A string Expr +#' +#' @examples +#' df = pl$DataFrame( +#' s = c("pear", NA, "papaya", "dragonfruit"), +#' n = c(3, 4, -2, -5) +#' ) +#' +#' df$with_columns( +#' s_head_5 = pl$col("s")$str$head(5), +#' s_head_n = pl$col("s")$str$head("n") +#' ) +ExprStr_head = function(n) { + .pr$Expr$str_head(self, n) |> + unwrap("in $str$head():") +} + +#' Return the last n characters of each string +#' +#' @param n Length of the slice (integer or expression). Strings are parsed as +#' column names. Negative indexing is supported. +#' +#' @details +#' The `n` input is defined in terms of the number of characters in the (UTF-8) +#' string. A character is defined as a Unicode scalar value. A single character +#' is represented by a single byte when working with ASCII text, and a maximum +#' of 4 bytes otherwise. +#' +#' When the `n` input is negative, `tail()` returns characters starting from the +#' `n`th from the beginning of the string. For example, if `n = -3`, then all +#' characters except the first three are returned. +#' +#' If the length of the string has fewer than `n` characters, the full string is +#' returned. +#' +#' @return A string Expr +#' +#' @examples +#' df = pl$DataFrame( +#' s = c("pear", NA, "papaya", "dragonfruit"), +#' n = c(3, 4, -2, -5) +#' ) +#' +#' df$with_columns( +#' s_tail_5 = pl$col("s")$str$tail(5), +#' s_tail_n = pl$col("s")$str$tail("n") +#' ) +ExprStr_tail = function(n) { + .pr$Expr$str_tail(self, n) |> + unwrap("in $str$tail():") +} diff --git a/R/extendr-wrappers.R b/R/extendr-wrappers.R index 2fe094416..a937949b7 100644 --- a/R/extendr-wrappers.R +++ b/R/extendr-wrappers.R @@ -1074,6 +1074,10 @@ RPolarsExpr$str_replace_many <- function(patterns, replace_with, ascii_case_inse RPolarsExpr$str_find <- function(pat, literal, strict) .Call(wrap__RPolarsExpr__str_find, self, pat, literal, strict) +RPolarsExpr$str_head <- function(n) .Call(wrap__RPolarsExpr__str_head, self, n) + +RPolarsExpr$str_tail <- function(n) .Call(wrap__RPolarsExpr__str_tail, self, n) + RPolarsExpr$bin_contains <- function(lit) .Call(wrap__RPolarsExpr__bin_contains, self, lit) RPolarsExpr$bin_starts_with <- function(sub) .Call(wrap__RPolarsExpr__bin_starts_with, self, sub) diff --git a/man/ExprStr_head.Rd b/man/ExprStr_head.Rd new file mode 100644 index 000000000..6655b7fc5 --- /dev/null +++ b/man/ExprStr_head.Rd @@ -0,0 +1,42 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/expr__string.R +\name{ExprStr_head} +\alias{ExprStr_head} +\title{Return the first n characters of each string} +\usage{ +ExprStr_head(n) +} +\arguments{ +\item{n}{Length of the slice (integer or expression). Strings are parsed as +column names. Negative indexing is supported.} +} +\value{ +A string Expr +} +\description{ +Return the first n characters of each string +} +\details{ +The \code{n} input is defined in terms of the number of characters in the (UTF-8) +string. A character is defined as a Unicode scalar value. A single character +is represented by a single byte when working with ASCII text, and a maximum +of 4 bytes otherwise. + +When the \code{n} input is negative, \code{head()} returns characters up to the \code{n}th +from the end of the string. For example, if \code{n = -3}, then all characters +except the last three are returned. + +If the length of the string has fewer than \code{n} characters, the full string is +returned. +} +\examples{ +df = pl$DataFrame( + s = c("pear", NA, "papaya", "dragonfruit"), + n = c(3, 4, -2, -5) +) + +df$with_columns( + s_head_5 = pl$col("s")$str$head(5), + s_head_n = pl$col("s")$str$head("n") +) +} diff --git a/man/ExprStr_tail.Rd b/man/ExprStr_tail.Rd new file mode 100644 index 000000000..a3f524418 --- /dev/null +++ b/man/ExprStr_tail.Rd @@ -0,0 +1,42 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/expr__string.R +\name{ExprStr_tail} +\alias{ExprStr_tail} +\title{Return the last n characters of each string} +\usage{ +ExprStr_tail(n) +} +\arguments{ +\item{n}{Length of the slice (integer or expression). Strings are parsed as +column names. Negative indexing is supported.} +} +\value{ +A string Expr +} +\description{ +Return the last n characters of each string +} +\details{ +The \code{n} input is defined in terms of the number of characters in the (UTF-8) +string. A character is defined as a Unicode scalar value. A single character +is represented by a single byte when working with ASCII text, and a maximum +of 4 bytes otherwise. + +When the \code{n} input is negative, \code{tail()} returns characters starting from the +\code{n}th from the beginning of the string. For example, if \code{n = -3}, then all +characters except the first three are returned. + +If the length of the string has fewer than \code{n} characters, the full string is +returned. +} +\examples{ +df = pl$DataFrame( + s = c("pear", NA, "papaya", "dragonfruit"), + n = c(3, 4, -2, -5) +) + +df$with_columns( + s_tail_5 = pl$col("s")$str$tail(5), + s_tail_n = pl$col("s")$str$tail("n") +) +} diff --git a/src/rust/src/lazy/dsl.rs b/src/rust/src/lazy/dsl.rs index 0ee8dfcd9..d5c3b19fa 100644 --- a/src/rust/src/lazy/dsl.rs +++ b/src/rust/src/lazy/dsl.rs @@ -2426,6 +2426,15 @@ impl RPolarsExpr { _ => Ok(self.0.clone().str().find(pat, strict).into()), } } + + fn str_head(&self, n: Robj) -> RResult { + Ok(self.0.clone().str().head(robj_to!(PLExprCol, n)?).into()) + } + + fn str_tail(&self, n: Robj) -> RResult { + Ok(self.0.clone().str().tail(robj_to!(PLExprCol, n)?).into()) + } + //binary methods pub fn bin_contains(&self, lit: Robj) -> RResult { Ok(self diff --git a/tests/testthat/_snaps/after-wrappers.md b/tests/testthat/_snaps/after-wrappers.md index 6cb0462cb..7b13a3093 100644 --- a/tests/testthat/_snaps/after-wrappers.md +++ b/tests/testthat/_snaps/after-wrappers.md @@ -414,29 +414,30 @@ [271] "str_count_matches" "str_ends_with" [273] "str_explode" "str_extract" [275] "str_extract_all" "str_extract_groups" - [277] "str_find" "str_hex_decode" - [279] "str_hex_encode" "str_json_decode" - [281] "str_json_path_match" "str_len_bytes" - [283] "str_len_chars" "str_pad_end" - [285] "str_pad_start" "str_replace" - [287] "str_replace_all" "str_replace_many" - [289] "str_reverse" "str_slice" - [291] "str_split" "str_split_exact" - [293] "str_splitn" "str_starts_with" - [295] "str_strip_chars" "str_strip_chars_end" - [297] "str_strip_chars_start" "str_to_date" - [299] "str_to_datetime" "str_to_integer" - [301] "str_to_lowercase" "str_to_time" - [303] "str_to_titlecase" "str_to_uppercase" - [305] "str_zfill" "struct_field_by_name" - [307] "struct_rename_fields" "sub" - [309] "sum" "tail" - [311] "tan" "tanh" - [313] "timestamp" "to_physical" - [315] "top_k" "unique" - [317] "unique_counts" "unique_stable" - [319] "upper_bound" "value_counts" - [321] "var" "xor" + [277] "str_find" "str_head" + [279] "str_hex_decode" "str_hex_encode" + [281] "str_json_decode" "str_json_path_match" + [283] "str_len_bytes" "str_len_chars" + [285] "str_pad_end" "str_pad_start" + [287] "str_replace" "str_replace_all" + [289] "str_replace_many" "str_reverse" + [291] "str_slice" "str_split" + [293] "str_split_exact" "str_splitn" + [295] "str_starts_with" "str_strip_chars" + [297] "str_strip_chars_end" "str_strip_chars_start" + [299] "str_tail" "str_to_date" + [301] "str_to_datetime" "str_to_integer" + [303] "str_to_lowercase" "str_to_time" + [305] "str_to_titlecase" "str_to_uppercase" + [307] "str_zfill" "struct_field_by_name" + [309] "struct_rename_fields" "sub" + [311] "sum" "tail" + [313] "tan" "tanh" + [315] "timestamp" "to_physical" + [317] "top_k" "unique" + [319] "unique_counts" "unique_stable" + [321] "upper_bound" "value_counts" + [323] "var" "xor" # public and private methods of each class When diff --git a/tests/testthat/test-expr_string.R b/tests/testthat/test-expr_string.R index 0552c9252..fe9043908 100644 --- a/tests/testthat/test-expr_string.R +++ b/tests/testthat/test-expr_string.R @@ -884,3 +884,40 @@ test_that("str$find() works", { test$select(lit = pl$col("s")$str$find("(?iAa", strict = TRUE, literal = TRUE)) ) }) + +test_that("$str$head() works", { + df = pl$DataFrame( + s = c("pear", NA, "papaya", "dragonfruit"), + n = c(3, 4, -2, -5) + ) + + expect_equal( + df$select( + s_head_5 = pl$col("s")$str$head(5), + s_head_n = pl$col("s")$str$head("n") + )$to_list(), + list( + s_head_5 = c("pear", NA, "papay", "drago"), + s_head_n = c("pea", NA, "papa", "dragon") + ) + ) +}) + + +test_that("$str$tail() works", { + df = pl$DataFrame( + s = c("pear", NA, "papaya", "dragonfruit"), + n = c(3, 4, -2, -5) + ) + + expect_equal( + df$select( + s_tail_5 = pl$col("s")$str$tail(5), + s_tail_n = pl$col("s")$str$tail("n") + )$to_list(), + list( + s_tail_5 = c("pear", NA, "apaya", "fruit"), + s_tail_n = c("ear", NA, "paya", "nfruit") + ) + ) +})