Skip to content

Commit

Permalink
feat: add $to_dummies() for DataFrame (#1225)
Browse files Browse the repository at this point in the history
  • Loading branch information
etiennebacher authored Sep 4, 2024
1 parent 20ddf35 commit f55eade
Show file tree
Hide file tree
Showing 7 changed files with 125 additions and 15 deletions.
1 change: 1 addition & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
- New method `$cast()` for `DataFrame` and `LazyFrame` (#1219).
- New argument `strict` in `$drop()` to determine whether unknown column names
should trigger an error (#1220).
- New method `$to_dummies()` for `DataFrame` (#1225).

### Bug fixes

Expand Down
31 changes: 31 additions & 0 deletions R/dataframe__frame.R
Original file line number Diff line number Diff line change
Expand Up @@ -2546,3 +2546,34 @@ DataFrame_gather_every = function(n, offset = 0) {
DataFrame_cast = function(dtypes, ..., strict = TRUE) {
  # Delegate to the lazy implementation, then materialise the result.
  lazy_frame = self$lazy()
  casted = lazy_frame$cast(dtypes, strict = strict)
  casted$collect()
}


#' Convert variables into dummy/indicator variables
#'
#' Each selected column is replaced by a set of indicator columns, one per
#' distinct value. Columns that are not selected are kept as they are.
#'
#' @param columns Character vector with the name(s) of the column(s) that
#' should be converted to dummy variables. If `NULL` (default), all columns
#' of the DataFrame are converted.
#' @param ... Ignored.
#' @param separator String placed between the original column name and the
#' value when building the names of the new columns.
#' @param drop_first If `TRUE`, the first category of each encoded variable is
#' removed from the output.
#'
#' @return A DataFrame
#'
#' @examples
#' df = pl$DataFrame(foo = 1:2, bar = 3:4, ham = c("a", "b"))
#'
#' df$to_dummies()
#'
#' df$to_dummies(drop_first = TRUE)
#'
#' df$to_dummies(c("foo", "bar"), separator = "::")
DataFrame_to_dummies = function(
    columns = NULL,
    ...,
    separator = "_",
    drop_first = FALSE) {
  # Without an explicit selection, every column of the frame is encoded.
  columns = if (is.null(columns)) names(self) else columns

  result = .pr$DataFrame$to_dummies(
    self,
    columns = columns,
    separator = separator,
    drop_first = drop_first
  )
  unwrap(result, "in $to_dummies():")
}
2 changes: 2 additions & 0 deletions R/extendr-wrappers.R
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,8 @@ RPolarsDataFrame$sample_n <- function(n, with_replacement, shuffle, seed) .Call(

# Thin R-side bindings: each function below forwards `self` and its arguments,
# unchanged and in order, to the matching compiled `wrap__RPolarsDataFrame__*`
# routine through `.Call()`. No argument checking happens at this level.
# NOTE(review): the file name suggests these lines are generated by extendr —
# confirm before editing them by hand.
RPolarsDataFrame$sample_frac <- function(frac, with_replacement, shuffle, seed) .Call(wrap__RPolarsDataFrame__sample_frac, self, frac, with_replacement, shuffle, seed)

RPolarsDataFrame$to_dummies <- function(columns, separator, drop_first) .Call(wrap__RPolarsDataFrame__to_dummies, self, columns, separator, drop_first)

RPolarsDataFrame$transpose <- function(keep_names_as, new_col_names) .Call(wrap__RPolarsDataFrame__transpose, self, keep_names_as, new_col_names)

RPolarsDataFrame$clear <- function() .Call(wrap__RPolarsDataFrame__clear, self)
Expand Down
33 changes: 33 additions & 0 deletions man/DataFrame_to_dummies.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

17 changes: 17 additions & 0 deletions src/rust/src/rdataframe/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -473,6 +473,23 @@ impl RPolarsDataFrame {
.map(RPolarsDataFrame)
}

/// Converts columns of the wrapped frame into dummy/indicator columns.
///
/// * `columns` - optional vector of column names; when it converts to `None`
///   the whole frame is encoded, otherwise only the named columns are.
/// * `separator` - optional string forwarded to polars as the name separator.
/// * `drop_first` - boolean forwarded to polars as its `drop_first` flag.
///
/// # Errors
///
/// Returns early if any argument cannot be converted from its R object, and
/// maps a failure reported by polars through `polars_to_rpolars_err`.
fn to_dummies(&self, columns: Robj, separator: Robj, drop_first: Robj) -> RResult<Self> {
    use polars::prelude::DataFrameOps;

    let selected = robj_to!(Option, Vec, String, columns)?;
    let sep = robj_to!(Option, str, separator)?;
    let skip_first = robj_to!(bool, drop_first)?;

    let encoded = if let Some(names) = selected {
        // polars expects borrowed names, so view the owned strings as `&str`.
        let names: Vec<&str> = names.iter().map(String::as_str).collect();
        self.0.columns_to_dummies(names, sep, skip_first)
    } else {
        self.0.to_dummies(sep, skip_first)
    };

    let df = encoded.map_err(polars_to_rpolars_err)?;
    Ok(df.into())
}

pub fn transpose(&mut self, keep_names_as: Robj, new_col_names: Robj) -> RResult<Self> {
let opt_s = robj_to!(Option, str, keep_names_as)?;
let opt_vec_s = robj_to!(Option, Vec, String, new_col_names)?;
Expand Down
30 changes: 15 additions & 15 deletions tests/testthat/_snaps/after-wrappers.md
Original file line number Diff line number Diff line change
Expand Up @@ -91,12 +91,12 @@
[45] "reverse" "rolling" "sample" "schema"
[49] "select" "select_seq" "shape" "shift"
[53] "slice" "sort" "sql" "std"
[57] "sum" "tail" "to_data_frame" "to_list"
[61] "to_raw_ipc" "to_series" "to_struct" "transpose"
[65] "unique" "unnest" "unpivot" "var"
[69] "width" "with_columns" "with_columns_seq" "with_row_index"
[73] "write_csv" "write_ipc" "write_json" "write_ndjson"
[77] "write_parquet"
[57] "sum" "tail" "to_data_frame" "to_dummies"
[61] "to_list" "to_raw_ipc" "to_series" "to_struct"
[65] "transpose" "unique" "unnest" "unpivot"
[69] "var" "width" "with_columns" "with_columns_seq"
[73] "with_row_index" "write_csv" "write_ipc" "write_json"
[77] "write_ndjson" "write_parquet"

---

Expand All @@ -119,15 +119,15 @@
[27] "select" "select_at_idx"
[29] "select_seq" "set_column_from_robj"
[31] "set_column_from_series" "set_column_names_mut"
[33] "shape" "to_list"
[35] "to_list_tag_structs" "to_list_unwind"
[37] "to_raw_ipc" "to_struct"
[39] "transpose" "unnest"
[41] "unpivot" "with_columns"
[43] "with_columns_seq" "with_row_index"
[45] "write_csv" "write_ipc"
[47] "write_json" "write_ndjson"
[49] "write_parquet"
[33] "shape" "to_dummies"
[35] "to_list" "to_list_tag_structs"
[37] "to_list_unwind" "to_raw_ipc"
[39] "to_struct" "transpose"
[41] "unnest" "unpivot"
[43] "with_columns" "with_columns_seq"
[45] "with_row_index" "write_csv"
[47] "write_ipc" "write_json"
[49] "write_ndjson" "write_parquet"

# public and private methods of each class GroupBy

Expand Down
26 changes: 26 additions & 0 deletions tests/testthat/test-dataframe.R
Original file line number Diff line number Diff line change
Expand Up @@ -1735,3 +1735,29 @@ test_that("$cast() works", {
list(x = NA_integer_)
)
})


test_that("$to_dummies() works", {
  dat = pl$DataFrame(foo = 1:2, bar = 3:4, ham = c("a", "b"))

  # Default call: every column is expanded into indicator columns.
  all_encoded = dat$to_dummies()$to_list()
  expect_identical(
    all_encoded,
    list(foo_1 = 1:0, foo_2 = 0:1, bar_3 = 1:0, bar_4 = 0:1, ham_a = 1:0, ham_b = 0:1)
  )

  # drop_first removes the first category of each variable.
  without_first = dat$to_dummies(drop_first = TRUE)$to_list()
  expect_identical(
    without_first,
    list(foo_2 = 0:1, bar_4 = 0:1, ham_b = 0:1)
  )

  # Only the requested columns are encoded; a custom separator is used in names.
  subset_encoded = dat$to_dummies(c("foo", "bar"), separator = "::")$to_list()
  expect_identical(
    subset_encoded,
    list(
      `foo::1` = 1:0,
      `foo::2` = 0:1,
      `bar::3` = 1:0,
      `bar::4` = 0:1,
      ham = c("a", "b")
    )
  )
})

0 comments on commit f55eade

Please sign in to comment.