diff --git a/R/db_joins.R b/R/db_joins.R index 9a1282f1..cb706848 100644 --- a/R/db_joins.R +++ b/R/db_joins.R @@ -63,7 +63,7 @@ inner_join.tbl_sql <- function(x, y, by = NULL, ...) { } # Prepare the combined join - out <- do.call(dplyr::inner_join, args = join_args(.dots)) + out <- do.call(dplyr::inner_join, args = join_args(x, y, by, .dots)) out$lazy_query$vars <- join_na_select_fix(out$lazy_query$vars, .dots$na_by) return(out) @@ -79,7 +79,7 @@ left_join.tbl_sql <- function(x, y, by = NULL, ...) { return(NextMethod("left_join")) } - out <- do.call(dplyr::left_join, args = join_args(.dots)) + out <- do.call(dplyr::left_join, args = join_args(x, y, by, .dots)) out$lazy_query$vars <- join_na_select_fix(out$lazy_query$vars, .dots$na_by) return(out) @@ -95,7 +95,7 @@ right_join.tbl_sql <- function(x, y, by = NULL, ...) { return(NextMethod("right_join")) } - out <- do.call(dplyr::right_join, args = join_args(.dots)) + out <- do.call(dplyr::right_join, args = join_args(x, y, by, .dots)) out$lazy_query$vars <- join_na_select_fix(out$lazy_query$vars, .dots$na_by, right = TRUE) return(out) @@ -112,7 +112,7 @@ full_join.tbl_sql <- function(x, y, by = NULL, ...) { return(NextMethod("full_join")) } - out <- do.call(dplyr::full_join, args = join_args(.dots)) + out <- do.call(dplyr::full_join, args = join_args(x, y, by, .dots)) out$lazy_query$vars <- join_na_select_fix(out$lazy_query$vars, .dots$na_by) return(out) @@ -129,7 +129,7 @@ semi_join.tbl_sql <- function(x, y, by = NULL, ...) { return(NextMethod("semi_join")) } - out <- do.call(dplyr::semi_join, args = join_args(.dots)) + out <- do.call(dplyr::semi_join, args = join_args(x, y, by, .dots)) return(out) } @@ -145,7 +145,7 @@ anti_join.tbl_sql <- function(x, y, by = NULL, ...) { return(NextMethod("anti_join")) } - out <- do.call(dplyr::anti_join, args = join_args(.dots)) + out <- do.call(dplyr::anti_join, args = join_args(x, y, by, .dots)) return(out) } @@ -197,14 +197,11 @@ join_warn_experimental <- function() { #' @param .dots (`list`) \cr #' Arguments passed to the `*_join` function. #' @noRd -join_args <- function(.dots) { - # Grab the environment of the caller and add the dot args - args <- append(as.list(rlang::caller_env()), .dots) +join_args <- function(x, y, by, .dots) { # Remove the na matching args, and let join_na_sql combine the `by` and `na_by` statements - args$na_by <- NULL - args$na_matches <- NULL - args$by <- join_na_sql(args$x, args$y, by = args$by, na_by = .dots$na_by) + by <- join_na_sql(x, y, by, .dots) + args <- append(list(x = x, y = y, by = by), purrr::discard_at(.dots, c("na_by", "na_matches"))) return(args) } @@ -254,7 +251,14 @@ join_merger <- function(by, na_by) { #' A `dplyr_join_by` object to join by such that "NA" are matched with "NA" given the columns listed in `by` and #' `na_by`. #' @noRd -join_na_sql <- function(x, y, by = NULL, na_by = NULL) { +join_na_sql <- function(x, y, by = NULL, .dots = NULL) { + + # Early return if no na_by statement is given + if (is.null(.dots$na_by)) { + return(by) + } else { + na_by <- .dots$na_by + } # Check arguments checkmate::assert( @@ -270,17 +274,24 @@ join_na_sql <- function(x, y, by = NULL, na_by = NULL) { # Convert to dplyr_join_by if not already if (!is.null(by) && !inherits(by, "dplyr_join_by")) { - by <- dplyr::join_by(!!by) + by <- dplyr::join_by(!!!by) } if (!is.null(na_by) && !inherits(na_by, "dplyr_join_by")) { - na_by <- dplyr::join_by(!!na_by) + na_by <- dplyr::join_by(!!!na_by) } - combined_join <- join_merger(by, na_by) - # Get the translation for matching the na_by component of the join - na_subquery <- dbplyr::remote_query(dplyr::inner_join(x, y, by = combined_join, na_matches = "na")) + subquery_args <- purrr::discard_at(.dots, "na_by") |> + modifyList( + list( + x = x, + y = y, + by = join_merger(by, na_by), + na_matches = "na" + ) + ) + na_subquery <- dbplyr::remote_query(do.call(dplyr::inner_join, args = subquery_args)) # Determine the NA matching statement by extracting from the translated query. # E.g. on RSQlite, the keyword "IS" checks if arguments are identical @@ -317,7 +328,9 @@ join_na_sql <- function(x, y, by = NULL, na_by = NULL) { #' A `tibble` with the `vars` component of the `lazy_query` fixed to remove doubly selected columns. #' @noRd join_na_select_fix <- function(vars, na_by, right = FALSE) { - if (!inherits(na_by, "dplyr_join_by")) na_by <- dplyr::join_by(!!na_by) + if (is.null(na_by)) return(vars) + + if (!inherits(na_by, "dplyr_join_by")) na_by <- dplyr::join_by(!!!na_by) # All equality joins in `na_by` are incorrectly translated doubly_selected_columns <- na_by |> diff --git a/tests/testthat/test-filter_keys.R b/tests/testthat/test-filter_keys.R index ad6952cc..76b80b4a 100644 --- a/tests/testthat/test-filter_keys.R +++ b/tests/testthat/test-filter_keys.R @@ -55,7 +55,7 @@ test_that("filter_keys() works with copy = TRUE", { dplyr::collect()) # The above filter_keys with `copy = TRUE` generates a dbplyr_### table. - # We manually remove this since we expect it. If more arrise, we will get an error. + # We manually remove this since we expect it. If more arise, we will get an error. DBI::dbRemoveTable(conn, id(utils::head(get_tables(conn, "dbplyr_"), 1))) connection_clean_up(conn)