Skip to content

Commit

Permalink
feat: add $join_where() for inequality joins (#1237)
Browse files Browse the repository at this point in the history
Co-authored-by: eitsupi <[email protected]>
Co-authored-by: etiennebacher <[email protected]>
  • Loading branch information
3 people authored Oct 3, 2024
1 parent 64c9076 commit 341c12f
Show file tree
Hide file tree
Showing 12 changed files with 959 additions and 44 deletions.
2 changes: 2 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
should trigger an error (#1220).
- New method `$to_dummies()` for `DataFrame` (#1225).
- New argument `include_file_paths` in `pl_scan_csv()` and `pl_read_csv()` (#1235).
- New method `$join_where()` for `DataFrame` and `LazyFrame` to perform
inequality joins (#1237).

### Bug fixes

Expand Down
43 changes: 43 additions & 0 deletions R/dataframe__frame.R
Original file line number Diff line number Diff line change
Expand Up @@ -2577,3 +2577,46 @@ DataFrame_to_dummies = function(
.pr$DataFrame$to_dummies(self, columns = columns, separator = separator, drop_first = drop_first) |>
unwrap("in $to_dummies():")
}

#' @inherit LazyFrame_join_where title params
#'
#' @description
#' This performs an inner join, so only rows where all predicates are true are
#' included in the result, and a row from either DataFrame may be included
#' multiple times in the result.
#'
#' Note that the row order of the input DataFrames is not preserved.
#'
#' @param other DataFrame to join with.
#'
#' @return A DataFrame
#'
#' @examples
#' east = pl$DataFrame(
#'   id = c(100, 101, 102),
#'   dur = c(120, 140, 160),
#'   rev = c(12, 14, 16),
#'   cores = c(2, 8, 4)
#' )
#'
#' west = pl$DataFrame(
#'   t_id = c(404, 498, 676, 742),
#'   time = c(90, 130, 150, 170),
#'   cost = c(9, 13, 15, 16),
#'   cores = c(4, 2, 1, 4)
#' )
#'
#' east$join_where(
#'   west,
#'   pl$col("dur") < pl$col("time"),
#'   pl$col("rev") < pl$col("cost")
#' )
DataFrame_join_where = function(
    other,
    ...,
    suffix = "_right") {
  # Reject anything that is not a DataFrame up front. The error carries the
  # method name as context so the user can see which call failed.
  if (!is_polars_df(other)) {
    Err_plain("`other` must be a DataFrame.") |> unwrap("in $join_where():")
  }

  # Both frames are turned lazy, joined through the LazyFrame method (which
  # receives the predicates in `...` untouched), and the result is collected
  # back into a DataFrame.
  self$lazy()$join_where(other = other$lazy(), ..., suffix = suffix)$collect()
}
2 changes: 2 additions & 0 deletions R/extendr-wrappers.R
Original file line number Diff line number Diff line change
Expand Up @@ -1258,6 +1258,8 @@ RPolarsLazyFrame$join_asof <- function(other, left_on, right_on, left_by, right_

# Thin bindings: each function below forwards `self` and its arguments, in the
# same order and without any validation, to the native routine named
# `wrap__RPolarsLazyFrame__<method>` through `.Call()`.
# NOTE(review): this file looks generated (R/extendr-wrappers.R) -- confirm
# whether manual comments survive regeneration before relying on them.
RPolarsLazyFrame$join <- function(other, left_on, right_on, how, validate, join_nulls, suffix, allow_parallel, force_parallel, coalesce) .Call(wrap__RPolarsLazyFrame__join, self, other, left_on, right_on, how, validate, join_nulls, suffix, allow_parallel, force_parallel, coalesce)

RPolarsLazyFrame$join_where <- function(other, predicates, suffix) .Call(wrap__RPolarsLazyFrame__join_where, self, other, predicates, suffix)

RPolarsLazyFrame$sort_by_exprs <- function(by, dotdotdot, descending, nulls_last, maintain_order, multithreaded) .Call(wrap__RPolarsLazyFrame__sort_by_exprs, self, by, dotdotdot, descending, nulls_last, maintain_order, multithreaded)

RPolarsLazyFrame$unpivot <- function(on, index, value_name, variable_name) .Call(wrap__RPolarsLazyFrame__unpivot, self, on, index, value_name, variable_name)
Expand Down
56 changes: 56 additions & 0 deletions R/lazyframe__lazy.R
Original file line number Diff line number Diff line change
Expand Up @@ -1367,6 +1367,62 @@ LazyFrame_join = function(
uw()
}

#' Perform a join based on one or multiple (in)equality predicates
#'
#' @description
#' The join is an inner join: a pair of rows is kept only when every predicate
#' evaluates to true for it. As a consequence, a single row of either
#' LazyFrame can show up several times in the output.
#'
#' Note that the row order of the input LazyFrames is not preserved.
#'
#' @param other LazyFrame to join with.
#' @param ... (In)Equality conditions to join the two tables on. If a column
#' name exists in both tables, the predicate must use the suffixed name to
#' refer to the right-hand column. For instance, when both tables contain a
#' column `"x"` that is used in a condition, the column of the right table has
#' to be written as `"x<suffix>"`.
#' @param suffix Suffix to append to columns with a duplicate name.
#'
#' @return A LazyFrame
#'
#' @examples
#' east = pl$LazyFrame(
#'   id = c(100, 101, 102),
#'   dur = c(120, 140, 160),
#'   rev = c(12, 14, 16),
#'   cores = c(2, 8, 4)
#' )
#'
#' west = pl$LazyFrame(
#'   t_id = c(404, 498, 676, 742),
#'   time = c(90, 130, 150, 170),
#'   cost = c(9, 13, 15, 16),
#'   cores = c(4, 2, 1, 4)
#' )
#'
#' east$join_where(
#'   west,
#'   pl$col("dur") < pl$col("time"),
#'   pl$col("rev") < pl$col("cost")
#' )$collect()
LazyFrame_join_where = function(
    other,
    ...,
    suffix = "_right") {
  # Every failure raised here is reported with the method name as context.
  with_context = function(result) unwrap(result, "in $join_where():")

  # Guard: the right-hand side has to be a LazyFrame.
  if (!is_polars_lf(other)) {
    with_context(Err_plain("`other` must be a LazyFrame."))
  }

  # Collect the predicates passed through `...` into a single list, then hand
  # everything to the private join_where binding.
  predicates = unpack_list(..., .context = "in $join_where():")
  with_context(.pr$LazyFrame$join_where(self, other, predicates, suffix))
}



#' Sort the LazyFrame by the given columns
#'
Expand Down
50 changes: 50 additions & 0 deletions man/DataFrame_join_where.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

50 changes: 50 additions & 0 deletions man/LazyFrame_join_where.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions src/rust/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ features = [
"fmt",
"gcp",
"http",
"iejoin",
"interpolate",
"ipc",
"is_between",
Expand Down
16 changes: 16 additions & 0 deletions src/rust/src/lazy/dataframe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -478,6 +478,22 @@ impl RPolarsLazyFrame {
))
}

// Joins this lazy frame with `other` on the given predicate expressions.
//
// `other`, `predicates` and `suffix` arrive as raw R objects and are
// converted in that order; the first conversion that fails is returned as
// the error. `suffix` is handed to the join builder's `suffix` setting.
fn join_where(&self, other: Robj, predicates: Robj, suffix: Robj) -> RResult<Self> {
    let rhs = robj_to!(PLLazyFrame, other)?;
    let conditions = robj_to!(VecPLExprColNamed, predicates)?;
    let suffix = robj_to!(str, suffix)?;

    // Work on a clone of the wrapped frame; `self` is only borrowed.
    let joined = self
        .0
        .clone()
        .join_builder()
        .with(rhs)
        .suffix(suffix)
        .join_where(conditions);

    Ok(joined.into())
}

pub fn sort_by_exprs(
&self,
by: Robj,
Expand Down
89 changes: 45 additions & 44 deletions tests/testthat/_snaps/after-wrappers.md
Original file line number Diff line number Diff line change
Expand Up @@ -84,19 +84,19 @@
[17] "first" "flags" "gather_every" "get_column"
[21] "get_columns" "glimpse" "group_by" "group_by_dynamic"
[25] "head" "height" "item" "join"
[29] "join_asof" "last" "lazy" "limit"
[33] "max" "mean" "median" "min"
[37] "n_chunks" "null_count" "partition_by" "pivot"
[41] "print" "quantile" "rechunk" "rename"
[45] "reverse" "rolling" "sample" "schema"
[49] "select" "select_seq" "shape" "shift"
[53] "slice" "sort" "sql" "std"
[57] "sum" "tail" "to_data_frame" "to_dummies"
[61] "to_list" "to_raw_ipc" "to_series" "to_struct"
[65] "transpose" "unique" "unnest" "unpivot"
[69] "var" "width" "with_columns" "with_columns_seq"
[73] "with_row_index" "write_csv" "write_ipc" "write_json"
[77] "write_ndjson" "write_parquet"
[29] "join_asof" "join_where" "last" "lazy"
[33] "limit" "max" "mean" "median"
[37] "min" "n_chunks" "null_count" "partition_by"
[41] "pivot" "print" "quantile" "rechunk"
[45] "rename" "reverse" "rolling" "sample"
[49] "schema" "select" "select_seq" "shape"
[53] "shift" "slice" "sort" "sql"
[57] "std" "sum" "tail" "to_data_frame"
[61] "to_dummies" "to_list" "to_raw_ipc" "to_series"
[65] "to_struct" "transpose" "unique" "unnest"
[69] "unpivot" "var" "width" "with_columns"
[73] "with_columns_seq" "with_row_index" "write_csv" "write_ipc"
[77] "write_json" "write_ndjson" "write_parquet"

---

Expand Down Expand Up @@ -150,19 +150,19 @@
[13] "fill_nan" "fill_null" "filter"
[16] "first" "gather_every" "group_by"
[19] "group_by_dynamic" "head" "join"
[22] "join_asof" "last" "limit"
[25] "max" "mean" "median"
[28] "min" "print" "profile"
[31] "quantile" "rename" "reverse"
[34] "rolling" "schema" "select"
[37] "select_seq" "serialize" "shift"
[40] "sink_csv" "sink_ipc" "sink_ndjson"
[43] "sink_parquet" "slice" "sort"
[46] "sql" "std" "sum"
[49] "tail" "to_dot" "unique"
[52] "unnest" "unpivot" "var"
[55] "width" "with_columns" "with_columns_seq"
[58] "with_context" "with_row_index"
[22] "join_asof" "join_where" "last"
[25] "limit" "max" "mean"
[28] "median" "min" "print"
[31] "profile" "quantile" "rename"
[34] "reverse" "rolling" "schema"
[37] "select" "select_seq" "serialize"
[40] "shift" "sink_csv" "sink_ipc"
[43] "sink_ndjson" "sink_parquet" "slice"
[46] "sort" "sql" "std"
[49] "sum" "tail" "to_dot"
[52] "unique" "unnest" "unpivot"
[55] "var" "width" "with_columns"
[58] "with_columns_seq" "with_context" "with_row_index"

---

Expand All @@ -180,24 +180,25 @@
[17] "fill_null" "filter"
[19] "first" "group_by"
[21] "group_by_dynamic" "join"
[23] "join_asof" "last"
[25] "max" "mean"
[27] "median" "min"
[29] "optimization_toggle" "print"
[31] "profile" "quantile"
[33] "rename" "reverse"
[35] "rolling" "schema"
[37] "select" "select_seq"
[39] "serialize" "shift"
[41] "sink_csv" "sink_ipc"
[43] "sink_json" "sink_parquet"
[45] "slice" "sort_by_exprs"
[47] "std" "sum"
[49] "tail" "to_dot"
[51] "unique" "unnest"
[53] "unpivot" "var"
[55] "with_columns" "with_columns_seq"
[57] "with_context" "with_row_index"
[23] "join_asof" "join_where"
[25] "last" "max"
[27] "mean" "median"
[29] "min" "optimization_toggle"
[31] "print" "profile"
[33] "quantile" "rename"
[35] "reverse" "rolling"
[37] "schema" "select"
[39] "select_seq" "serialize"
[41] "shift" "sink_csv"
[43] "sink_ipc" "sink_json"
[45] "sink_parquet" "slice"
[47] "sort_by_exprs" "std"
[49] "sum" "tail"
[51] "to_dot" "unique"
[53] "unnest" "unpivot"
[55] "var" "with_columns"
[57] "with_columns_seq" "with_context"
[59] "with_row_index"

# public and private methods of each class Expr

Expand Down
Loading

0 comments on commit 341c12f

Please sign in to comment.