Skip to content

Commit

Permalink
fix(db_join): Account for extra args such as copy
Browse files Browse the repository at this point in the history
  • Loading branch information
RasmusSkytte committed Oct 15, 2024
1 parent f7c25d7 commit ec4c03b
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 20 deletions.
51 changes: 32 additions & 19 deletions R/db_joins.R
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ inner_join.tbl_sql <- function(x, y, by = NULL, ...) {
}

# Prepare the combined join
out <- do.call(dplyr::inner_join, args = join_args(.dots))
out <- do.call(dplyr::inner_join, args = join_args(x, y, by, .dots))
out$lazy_query$vars <- join_na_select_fix(out$lazy_query$vars, .dots$na_by)

return(out)
Expand All @@ -79,7 +79,7 @@ left_join.tbl_sql <- function(x, y, by = NULL, ...) {
return(NextMethod("left_join"))
}

out <- do.call(dplyr::left_join, args = join_args(.dots))
out <- do.call(dplyr::left_join, args = join_args(x, y, by, .dots))
out$lazy_query$vars <- join_na_select_fix(out$lazy_query$vars, .dots$na_by)

return(out)
Expand All @@ -95,7 +95,7 @@ right_join.tbl_sql <- function(x, y, by = NULL, ...) {
return(NextMethod("right_join"))
}

out <- do.call(dplyr::right_join, args = join_args(.dots))
out <- do.call(dplyr::right_join, args = join_args(x, y, by, .dots))
out$lazy_query$vars <- join_na_select_fix(out$lazy_query$vars, .dots$na_by, right = TRUE)

return(out)
Expand All @@ -112,7 +112,7 @@ full_join.tbl_sql <- function(x, y, by = NULL, ...) {
return(NextMethod("full_join"))
}

out <- do.call(dplyr::full_join, args = join_args(.dots))
out <- do.call(dplyr::full_join, args = join_args(x, y, by, .dots))
out$lazy_query$vars <- join_na_select_fix(out$lazy_query$vars, .dots$na_by)

return(out)
Expand All @@ -129,7 +129,7 @@ semi_join.tbl_sql <- function(x, y, by = NULL, ...) {
return(NextMethod("semi_join"))
}

out <- do.call(dplyr::semi_join, args = join_args(.dots))
out <- do.call(dplyr::semi_join, args = join_args(x, y, by, .dots))

return(out)
}
Expand All @@ -145,7 +145,7 @@ anti_join.tbl_sql <- function(x, y, by = NULL, ...) {
return(NextMethod("anti_join"))
}

out <- do.call(dplyr::anti_join, args = join_args(.dots))
out <- do.call(dplyr::anti_join, args = join_args(x, y, by, .dots))

return(out)
}
Expand Down Expand Up @@ -197,14 +197,11 @@ join_warn_experimental <- function() {
#' @param .dots (`list`) \cr
#' Arguments passed to the `*_join` function.
#' @noRd
join_args <- function(.dots) {
# Grab the environment of the caller and add the dot args
args <- append(as.list(rlang::caller_env()), .dots)
join_args <- function(x, y, by, .dots) {

# Remove the na matching args, and let join_na_sql combine the `by` and `na_by` statements
args$na_by <- NULL
args$na_matches <- NULL
args$by <- join_na_sql(args$x, args$y, by = args$by, na_by = .dots$na_by)
by <- join_na_sql(x, y, by, .dots)
args <- append(list(x = x, y = y, by = by), purrr::discard_at(.dots, c("na_by", "na_matches")))

return(args)
}
Expand Down Expand Up @@ -254,7 +251,14 @@ join_merger <- function(by, na_by) {
#' A `dplyr_join_by` object to join by such that "NA" are matched with "NA" given the columns listed in `by` and
#' `na_by`.
#' @noRd
join_na_sql <- function(x, y, by = NULL, na_by = NULL) {
join_na_sql <- function(x, y, by = NULL, .dots = NULL) {

# Early return if no na_by statement is given
if (is.null(.dots$na_by)) {
return(by)
} else {
na_by <- .dots$na_by
}

# Check arguments
checkmate::assert(
Expand All @@ -270,17 +274,24 @@ join_na_sql <- function(x, y, by = NULL, na_by = NULL) {

# Convert to dplyr_join_by if not already
if (!is.null(by) && !inherits(by, "dplyr_join_by")) {
by <- dplyr::join_by(!!by)
by <- dplyr::join_by(!!!by)
}

if (!is.null(na_by) && !inherits(na_by, "dplyr_join_by")) {
na_by <- dplyr::join_by(!!na_by)
na_by <- dplyr::join_by(!!!na_by)
}

combined_join <- join_merger(by, na_by)

# Get the translation for matching the na_by component of the join
na_subquery <- dbplyr::remote_query(dplyr::inner_join(x, y, by = combined_join, na_matches = "na"))
subquery_args <- purrr::discard_at(.dots, "na_by") |>
modifyList(
list(
x = x,
y = y,
by = join_merger(by, na_by),
na_matches = "na"
)
)
na_subquery <- dbplyr::remote_query(do.call(dplyr::inner_join, args = subquery_args))

# Determine the NA matching statement by extracting from the translated query.
# E.g. on RSQlite, the keyword "IS" checks if arguments are identical
Expand Down Expand Up @@ -317,7 +328,9 @@ join_na_sql <- function(x, y, by = NULL, na_by = NULL) {
#' A `tibble` with the `vars` component of the `lazy_query` fixed to remove doubly selected columns.
#' @noRd
join_na_select_fix <- function(vars, na_by, right = FALSE) {
if (!inherits(na_by, "dplyr_join_by")) na_by <- dplyr::join_by(!!na_by)
if (is.null(na_by)) return(vars)
if (!inherits(na_by, "dplyr_join_by")) na_by <- dplyr::join_by(!!!na_by)
# All equality joins in `na_by` are incorrectly translated
doubly_selected_columns <- na_by |>
Expand Down
2 changes: 1 addition & 1 deletion tests/testthat/test-filter_keys.R
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ test_that("filter_keys() works with copy = TRUE", {
dplyr::collect())

# The above filter_keys with `copy = TRUE` generates a dbplyr_### table.
# We manually remove this since we expect it. If more arrise, we will get an error.
# We manually remove this since we expect it. If more arise, we will get an error.
DBI::dbRemoveTable(conn, id(utils::head(get_tables(conn, "dbplyr_"), 1)))

connection_clean_up(conn)
Expand Down

0 comments on commit ec4c03b

Please sign in to comment.