Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
RasmusSkytte committed Oct 14, 2024
1 parent 2171250 commit ba302cf
Show file tree
Hide file tree
Showing 2 changed files with 148 additions and 15 deletions.
96 changes: 82 additions & 14 deletions R/db_joins.R
Original file line number Diff line number Diff line change
Expand Up @@ -76,17 +76,63 @@ join_na_sql <- function(x, y, by = NULL, na_by = NULL) {
}
join_na_select_fix <- function(vars, na_by, right = FALSE) {
if (!inherits(na_by, "dplyr_join_by")) na_by <- dplyr::join_by(!!na_by)
doubly_selected_columns <- na_by |>
purrr::discard_at("exprs") |>
tibble::as_tibble() |>
dplyr::filter(.data$condition == "==", .data$x == .data$y) |>
dplyr::pull("x")
if (length(doubly_selected_columns) == 0) {
updated_vars <- vars # no doubly selected columns
} else {
# The vars table structure is not consistent between dplyr join types
if (checkmate::test_names(names(vars), identical.to = c("name", "x", "y"))) {
updated_vars <- vars # no doubly selected columns
updated_vars <- rbind(
tibble::tibble(
name = doubly_selected_columns,
x = ifelse(right, NA, doubly_selected_columns),
y = doubly_selected_columns
),
dplyr::filter(vars, .data$x %in% !!doubly_selected_columns | .data$y %in% !!doubly_selected_columns)
) |>
dplyr::symdiff(vars)
} else if (checkmate::test_names(names(vars), identical.to = c("name", "table", "var"))) {
updated_vars <- rbind(
tibble::tibble(name = doubly_selected_columns, table = 1, var = doubly_selected_columns),
dplyr::filter(vars, .data$var %in% !!doubly_selected_columns)
) |>
dplyr::symdiff(vars)
}
}
return(updated_vars)
}
#' Warn users that SQL does not match on NA by default
#'
#' @return
#' A warning that *_joins on SQL backends does not match NA by default.
#' @noRd
join_warn <- function() {
if (interactive() && identical(parent.frame(n = 2), globalenv())) {
rlang::warn(paste("*_joins in database-backend does not match NA by default.\n",
"If your data contains NA, the columns with NA values must be supplied to \"na_by\",",
"or you must specify na_matches = \"na\""),
.frequency = "once", .frequency_id = "*_join NA warning")
rlang::warn(
paste(
"*_joins in database-backend does not match NA by default.\n",
"If your data contains NA, the columns with NA values must be supplied to \"na_by\",",
"or you must specify na_matches = \"na\""
),
.frequency = "once",
.frequency_id = "*_join NA warning"
)
}
}

Expand All @@ -98,8 +144,11 @@ join_warn <- function() {
#' @noRd
join_warn_experimental <- function() {
if (interactive() && identical(parent.frame(n = 2), globalenv())) {
rlang::warn("*_joins with na_by is still experimental. Please report issues.",
.frequency = "once", .frequency_id = "*_join NA warning")
rlang::warn(
"*_joins with na_by is still experimental. Please report issues.",
.frequency = "once",
.frequency_id = "*_join NA warning"
)
}
}

Expand Down Expand Up @@ -163,11 +212,15 @@ inner_join.tbl_sql <- function(x, y, by = NULL, ...) {
}

# Prepare the combined join
query <- dbplyr::remote_query(dplyr::inner_join(x, y, by = join_merger(by, .dots$na_by), ...))
args <- append(as.list(rlang::current_env()), .dots)
args$na_by <- NULL
args$na_matches <- NULL
args$by <- join_na_sql(x, y, by = by, na_by = .dots$na_by)

updated_by <- join_na_sql(x, y, by = by, na_by = .dots$na_by)
out <- do.call(dplyr::inner_join, args = args)
out$lazy_query$vars <- join_na_select_fix(out$lazy_query$vars, .dots$na_by)

return(do.call(dplyr::inner_join, args = args))
return(out)
}

#' @rdname joins
Expand All @@ -186,7 +239,10 @@ left_join.tbl_sql <- function(x, y, by = NULL, ...) {
args$na_matches <- NULL
args$by <- join_na_sql(x, y, by = by, na_by = .dots$na_by)

return(do.call(dplyr::left_join, args = args))
out <- do.call(dplyr::left_join, args = args)
out$lazy_query$vars <- join_na_select_fix(out$lazy_query$vars, .dots$na_by)

return(out)
}

#' @rdname joins
Expand All @@ -205,7 +261,10 @@ right_join.tbl_sql <- function(x, y, by = NULL, ...) {
args$na_matches <- NULL
args$by <- join_na_sql(x, y, by = by, na_by = .dots$na_by)

return(do.call(dplyr::right_join, args = args))
out <- do.call(dplyr::right_join, args = args)
out$lazy_query$vars <- join_na_select_fix(out$lazy_query$vars, .dots$na_by, right = TRUE)

return(out)
}


Expand All @@ -225,7 +284,10 @@ full_join.tbl_sql <- function(x, y, by = NULL, ...) {
args$na_matches <- NULL
args$by <- join_na_sql(x, y, by = by, na_by = .dots$na_by)

return(do.call(dplyr::full_join, args = args))
out <- do.call(dplyr::full_join, args = args)
out$lazy_query$vars <- join_na_select_fix(out$lazy_query$vars, .dots$na_by)

return(out)
}


Expand All @@ -245,7 +307,10 @@ semi_join.tbl_sql <- function(x, y, by = NULL, ...) {
args$na_matches <- NULL
args$by <- join_na_sql(x, y, by = by, na_by = .dots$na_by)

return(do.call(dplyr::semi_join, args = args))
out <- do.call(dplyr::semi_join, args = args)
#out$lazy_query$vars <- join_na_select_fix(out$lazy_query$vars, .dots$na_by)

Check warning on line 311 in R/db_joins.R

View workflow job for this annotation

GitHub Actions / ⚙️ Dispatch / lint / 🖋️ Lint

file=R/db_joins.R,line=311,col=4,[commented_code_linter] Commented code should be removed.

return(out)
}


Expand All @@ -265,5 +330,8 @@ anti_join.tbl_sql <- function(x, y, by = NULL, ...) {
args$na_matches <- NULL
args$by <- join_na_sql(x, y, by = by, na_by = .dots$na_by)

return(do.call(dplyr::anti_join, args = args))
out <- do.call(dplyr::anti_join, args = args)
#out$lazy_query$vars <- join_na_select_fix(out$lazy_query$vars, .dots$na_by)

Check warning on line 334 in R/db_joins.R

View workflow job for this annotation

GitHub Actions / ⚙️ Dispatch / lint / 🖋️ Lint

file=R/db_joins.R,line=334,col=4,[commented_code_linter] Commented code should be removed.

return(out)
}
67 changes: 66 additions & 1 deletion tests/testthat/test-db_joins.R
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ test_that("*_join() works with character `by` and `na_by`", {
dplyr::arrange(date, region_id)
qr <- dplyr::full_join(dplyr::collect(x), dplyr::collect(y), by = c("date", "region_id")) |>
dplyr::arrange(date, region_id)
expect_equal(q, qr)
expect_mapequal(q, qr)



Expand Down Expand Up @@ -186,3 +186,68 @@ test_that("*_join() does not break any dplyr joins", {
connection_clean_up(conn)
}
})



test_that("*_join() with only na_by works as dplyr joins", {
for (conn in get_test_conns()) {

# Define two test datasets
x <- get_table(conn, "__mtcars") |>
dplyr::select(name, mpg, cyl, hp, vs, am, gear, carb)

y <- get_table(conn, "__mtcars") |>
dplyr::select(name, drat, wt, qsec)

# Test the standard joins
# left_join
qr <- dplyr::left_join(dplyr::collect(x), dplyr::collect(y), by = "name")
q <- dplyr::left_join(x, y, na_by = "name") |> dplyr::collect()
expect_equal(q, qr)

q <- dplyr::left_join(x, y, na_by = dplyr::join_by(x$name == y$name)) |> dplyr::collect()
expect_equal(q, qr)

# right_join
qr <- dplyr::right_join(dplyr::collect(x), dplyr::collect(y), by = "name")
q <- dplyr::right_join(x, y, na_by = "name") |> dplyr::collect()
expect_equal(q, qr)

q <- dplyr::right_join(x, y, na_by = dplyr::join_by(x$name == y$name)) |> dplyr::collect()
expect_equal(q, qr)

# inner_join
qr <- dplyr::inner_join(dplyr::collect(x), dplyr::collect(y), by = "name")
q <- dplyr::inner_join(x, y, na_by = "name") |> dplyr::collect()
expect_equal(q, qr)

q <- dplyr::inner_join(x, y, na_by = dplyr::join_by(x$name == y$name)) |> dplyr::collect()
expect_equal(q, qr)

# full_join
qr <- dplyr::full_join(dplyr::collect(x), dplyr::collect(y), by = "name")
q <- dplyr::full_join(x, y, na_by = "name") |> dplyr::collect()
expect_equal(q, qr)

q <- dplyr::full_join(x, y, na_by = dplyr::join_by(x$name == y$name)) |> dplyr::collect()
expect_equal(q, qr)

# semi_join
qr <- dplyr::semi_join(dplyr::collect(x), dplyr::collect(y), by = "name")
q <- dplyr::semi_join(x, y, na_by = "name") |> dplyr::collect()
expect_equal(q, qr)

q <- dplyr::semi_join(x, y, na_by = dplyr::join_by(x$name == y$name)) |> dplyr::collect()
expect_equal(q, qr)

# anti_join
qr <- dplyr::anti_join(dplyr::collect(x), dplyr::collect(y), by = "name")
q <- dplyr::anti_join(x, y, na_by = "name") |> dplyr::collect()
expect_equal(q, qr)

q <- dplyr::anti_join(x, y, na_by = dplyr::join_by(x$name == y$name)) |> dplyr::collect()
expect_equal(q, qr)

connection_clean_up(conn)
}
})

0 comments on commit ba302cf

Please sign in to comment.