From 24b67c862d68fd8676d2edbf1219ecd227de9df8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rasmus=20Skytte=20Randl=C3=B8v?= Date: Fri, 11 Oct 2024 15:25:30 +0200 Subject: [PATCH 01/17] fix(db_joins): Allow dplyr::join_by as by argument --- R/db_joins.R | 31 ++++++++++++++++++++++++------- tests/testthat/test-db_joins.R | 31 ++++++++++++++++++++++++++++++- 2 files changed, 54 insertions(+), 8 deletions(-) diff --git a/R/db_joins.R b/R/db_joins.R index 0d207903..03a214ed 100644 --- a/R/db_joins.R +++ b/R/db_joins.R @@ -175,8 +175,10 @@ inner_join.tbl_sql <- function(x, y, by = NULL, ...) { # Check arguments assert_data_like(x) assert_data_like(y) - checkmate::assert_character(by, null.ok = TRUE) - + checkmate::assert( + checkmate::check_character(by, null.ok = TRUE), + checkmate::check_class(by, "dplyr_join_by", null.ok = TRUE) + ) .dots <- list(...) if (!"na_by" %in% names(.dots)) { @@ -209,7 +211,10 @@ left_join.tbl_sql <- function(x, y, by = NULL, ...) { # Check arguments assert_data_like(x) assert_data_like(y) - checkmate::assert_character(by, null.ok = TRUE) + checkmate::assert( + checkmate::check_character(by, null.ok = TRUE), + checkmate::check_class(by, "dplyr_join_by", null.ok = TRUE) + ) .dots <- list(...) @@ -244,7 +249,10 @@ right_join.tbl_sql <- function(x, y, by = NULL, ...) { # Check arguments assert_data_like(x) assert_data_like(y) - checkmate::assert_character(by, null.ok = TRUE) + checkmate::assert( + checkmate::check_character(by, null.ok = TRUE), + checkmate::check_class(by, "dplyr_join_by", null.ok = TRUE) + ) .dots <- list(...) @@ -280,7 +288,10 @@ full_join.tbl_sql <- function(x, y, by = NULL, ...) { # Check arguments assert_data_like(x) assert_data_like(y) - checkmate::assert_character(by, null.ok = TRUE) + checkmate::assert( + checkmate::check_character(by, null.ok = TRUE), + checkmate::check_class(by, "dplyr_join_by", null.ok = TRUE) + ) .dots <- list(...) @@ -304,7 +315,10 @@ semi_join.tbl_sql <- function(x, y, by = NULL, ...) 
{ # Check arguments assert_data_like(x) assert_data_like(y) - checkmate::assert_character(by, null.ok = TRUE) + checkmate::assert( + checkmate::check_character(by, null.ok = TRUE), + checkmate::check_class(by, "dplyr_join_by", null.ok = TRUE) + ) .dots <- list(...) @@ -324,7 +338,10 @@ anti_join.tbl_sql <- function(x, y, by = NULL, ...) { # Check arguments assert_data_like(x) assert_data_like(y) - checkmate::assert_character(by, null.ok = TRUE) + checkmate::assert( + checkmate::check_character(by, null.ok = TRUE), + checkmate::check_class(by, "dplyr_join_by", null.ok = TRUE) + ) .dots <- list(...) diff --git a/tests/testthat/test-db_joins.R b/tests/testthat/test-db_joins.R index 320c7c25..a7c95885 100644 --- a/tests/testthat/test-db_joins.R +++ b/tests/testthat/test-db_joins.R @@ -1,4 +1,4 @@ -test_that("*_join() works", { +test_that("*_join() works with character `by` and `na_by`", { for (conn in get_test_conns()) { # Define two test datasets @@ -115,3 +115,32 @@ test_that("*_join() works", { connection_clean_up(conn) } }) + + +test_that("*_join() works with `dplyr::join_by()`", { + for (conn in get_test_conns()) { + + # Define two test datasets + x <- get_table(conn, "__mtcars") |> + dplyr::select(name, mpg, cyl, hp, vs, am, gear, carb) + + y <- get_table(conn, "__mtcars") |> + dplyr::select(name, drat, wt, qsec) + + + # Test the implemented joins + q <- dplyr::left_join(x, y, by = dplyr::join_by(x$name == y$name)) |> dplyr::collect() + qr <- dplyr::left_join(dplyr::collect(x), dplyr::collect(y), by = dplyr::join_by(x$name == y$name)) + expect_equal(q, qr) + + q <- dplyr::right_join(x, y, by = dplyr::join_by(x$name == y$name)) |> dplyr::collect() + qr <- dplyr::right_join(dplyr::collect(x), dplyr::collect(y), by = dplyr::join_by(x$name == y$name)) + expect_equal(q, qr) + + q <- dplyr::inner_join(x, y, by = dplyr::join_by(x$name == y$name)) |> dplyr::collect() + qr <- dplyr::inner_join(dplyr::collect(x), dplyr::collect(y), by = dplyr::join_by(x$name == y$name)) + 
expect_equal(q, qr) + + connection_clean_up(conn) + } +}) From b3941da3b2b9c6fbf21f14b42d8509b0d67aa1cf Mon Sep 17 00:00:00 2001 From: RasmusSkytte Date: Fri, 11 Oct 2024 13:34:39 +0000 Subject: [PATCH 02/17] chore: Update pak.lock --- pak.lock | 64 ++++++++++++++++++++++++++++---------------------------- 1 file changed, 32 insertions(+), 32 deletions(-) diff --git a/pak.lock b/pak.lock index 96a02e9c..6b153271 100644 --- a/pak.lock +++ b/pak.lock @@ -7,7 +7,7 @@ { "ref": "askpass", "package": "askpass", - "version": "1.2.0", + "version": "1.2.1", "type": "standard", "direct": false, "binary": true, @@ -19,10 +19,10 @@ "RemoteRef": "askpass", "RemoteRepos": "https://packagemanager.posit.co/cran/__linux__/jammy/latest", "RemotePkgPlatform": "x86_64-pc-linux-gnu-ubuntu-22.04", - "RemoteSha": "1.2.0" + "RemoteSha": "1.2.1" }, - "sources": "https://packagemanager.posit.co/cran/__linux__/jammy/latest/src/contrib/askpass_1.2.0.tar.gz", - "target": "src/contrib/x86_64-pc-linux-gnu-ubuntu-22.04/4.4/askpass_1.2.0.tar.gz", + "sources": "https://packagemanager.posit.co/cran/__linux__/jammy/latest/src/contrib/askpass_1.2.1.tar.gz", + "target": "src/contrib/x86_64-pc-linux-gnu-ubuntu-22.04/4.4/askpass_1.2.1.tar.gz", "platform": "x86_64-pc-linux-gnu-ubuntu-22.04", "rversion": "4.4", "directpkg": false, @@ -534,7 +534,7 @@ { "ref": "commonmark", "package": "commonmark", - "version": "1.9.1", + "version": "1.9.2", "type": "standard", "direct": false, "binary": true, @@ -546,10 +546,10 @@ "RemoteRef": "commonmark", "RemoteRepos": "https://packagemanager.posit.co/cran/__linux__/jammy/latest", "RemotePkgPlatform": "x86_64-pc-linux-gnu-ubuntu-22.04", - "RemoteSha": "1.9.1" + "RemoteSha": "1.9.2" }, - "sources": "https://packagemanager.posit.co/cran/__linux__/jammy/latest/src/contrib/commonmark_1.9.1.tar.gz", - "target": "src/contrib/x86_64-pc-linux-gnu-ubuntu-22.04/4.4/commonmark_1.9.1.tar.gz", + "sources": 
"https://packagemanager.posit.co/cran/__linux__/jammy/latest/src/contrib/commonmark_1.9.2.tar.gz", + "target": "src/contrib/x86_64-pc-linux-gnu-ubuntu-22.04/4.4/commonmark_1.9.2.tar.gz", "platform": "x86_64-pc-linux-gnu-ubuntu-22.04", "rversion": "4.4", "directpkg": false, @@ -702,7 +702,7 @@ { "ref": "data.table", "package": "data.table", - "version": "1.16.0", + "version": "1.16.2", "type": "standard", "direct": false, "binary": true, @@ -714,10 +714,10 @@ "RemoteRef": "data.table", "RemoteRepos": "https://packagemanager.posit.co/cran/__linux__/jammy/latest", "RemotePkgPlatform": "x86_64-pc-linux-gnu-ubuntu-22.04", - "RemoteSha": "1.16.0" + "RemoteSha": "1.16.2" }, - "sources": "https://packagemanager.posit.co/cran/__linux__/jammy/latest/src/contrib/data.table_1.16.0.tar.gz", - "target": "src/contrib/x86_64-pc-linux-gnu-ubuntu-22.04/4.4/data.table_1.16.0.tar.gz", + "sources": "https://packagemanager.posit.co/cran/__linux__/jammy/latest/src/contrib/data.table_1.16.2.tar.gz", + "target": "src/contrib/x86_64-pc-linux-gnu-ubuntu-22.04/4.4/data.table_1.16.2.tar.gz", "platform": "x86_64-pc-linux-gnu-ubuntu-22.04", "rversion": "4.4", "directpkg": false, @@ -1012,7 +1012,7 @@ { "ref": "evaluate", "package": "evaluate", - "version": "1.0.0", + "version": "1.0.1", "type": "standard", "direct": false, "binary": true, @@ -1024,10 +1024,10 @@ "RemoteRef": "evaluate", "RemoteRepos": "https://packagemanager.posit.co/cran/__linux__/jammy/latest", "RemotePkgPlatform": "x86_64-pc-linux-gnu-ubuntu-22.04", - "RemoteSha": "1.0.0" + "RemoteSha": "1.0.1" }, - "sources": "https://packagemanager.posit.co/cran/__linux__/jammy/latest/src/contrib/evaluate_1.0.0.tar.gz", - "target": "src/contrib/x86_64-pc-linux-gnu-ubuntu-22.04/4.4/evaluate_1.0.0.tar.gz", + "sources": "https://packagemanager.posit.co/cran/__linux__/jammy/latest/src/contrib/evaluate_1.0.1.tar.gz", + "target": "src/contrib/x86_64-pc-linux-gnu-ubuntu-22.04/4.4/evaluate_1.0.1.tar.gz", "platform": 
"x86_64-pc-linux-gnu-ubuntu-22.04", "rversion": "4.4", "directpkg": false, @@ -1683,7 +1683,7 @@ { "ref": "hunspell", "package": "hunspell", - "version": "3.0.4", + "version": "3.0.5", "type": "standard", "direct": false, "binary": true, @@ -1695,10 +1695,10 @@ "RemoteRef": "hunspell", "RemoteRepos": "https://packagemanager.posit.co/cran/__linux__/jammy/latest", "RemotePkgPlatform": "x86_64-pc-linux-gnu-ubuntu-22.04", - "RemoteSha": "3.0.4" + "RemoteSha": "3.0.5" }, - "sources": "https://packagemanager.posit.co/cran/__linux__/jammy/latest/src/contrib/hunspell_3.0.4.tar.gz", - "target": "src/contrib/x86_64-pc-linux-gnu-ubuntu-22.04/4.4/hunspell_3.0.4.tar.gz", + "sources": "https://packagemanager.posit.co/cran/__linux__/jammy/latest/src/contrib/hunspell_3.0.5.tar.gz", + "target": "src/contrib/x86_64-pc-linux-gnu-ubuntu-22.04/4.4/hunspell_3.0.5.tar.gz", "platform": "x86_64-pc-linux-gnu-ubuntu-22.04", "rversion": "4.4", "directpkg": false, @@ -3619,7 +3619,7 @@ { "ref": "spelling", "package": "spelling", - "version": "2.3.0", + "version": "2.3.1", "type": "standard", "direct": false, "binary": true, @@ -3631,10 +3631,10 @@ "RemoteRef": "spelling", "RemoteRepos": "https://packagemanager.posit.co/cran/__linux__/jammy/latest", "RemotePkgPlatform": "x86_64-pc-linux-gnu-ubuntu-22.04", - "RemoteSha": "2.3.0" + "RemoteSha": "2.3.1" }, - "sources": "https://packagemanager.posit.co/cran/__linux__/jammy/latest/src/contrib/spelling_2.3.0.tar.gz", - "target": "src/contrib/x86_64-pc-linux-gnu-ubuntu-22.04/4.4/spelling_2.3.0.tar.gz", + "sources": "https://packagemanager.posit.co/cran/__linux__/jammy/latest/src/contrib/spelling_2.3.1.tar.gz", + "target": "src/contrib/x86_64-pc-linux-gnu-ubuntu-22.04/4.4/spelling_2.3.1.tar.gz", "platform": "x86_64-pc-linux-gnu-ubuntu-22.04", "rversion": "4.4", "directpkg": false, @@ -3719,7 +3719,7 @@ { "ref": "sys", "package": "sys", - "version": "3.4.2", + "version": "3.4.3", "type": "standard", "direct": false, "binary": true, @@ -3731,10 +3731,10 
@@ "RemoteRef": "sys", "RemoteRepos": "https://packagemanager.posit.co/cran/__linux__/jammy/latest", "RemotePkgPlatform": "x86_64-pc-linux-gnu-ubuntu-22.04", - "RemoteSha": "3.4.2" + "RemoteSha": "3.4.3" }, - "sources": "https://packagemanager.posit.co/cran/__linux__/jammy/latest/src/contrib/sys_3.4.2.tar.gz", - "target": "src/contrib/x86_64-pc-linux-gnu-ubuntu-22.04/4.4/sys_3.4.2.tar.gz", + "sources": "https://packagemanager.posit.co/cran/__linux__/jammy/latest/src/contrib/sys_3.4.3.tar.gz", + "target": "src/contrib/x86_64-pc-linux-gnu-ubuntu-22.04/4.4/sys_3.4.3.tar.gz", "platform": "x86_64-pc-linux-gnu-ubuntu-22.04", "rversion": "4.4", "directpkg": false, @@ -4340,7 +4340,7 @@ { "ref": "xfun", "package": "xfun", - "version": "0.47", + "version": "0.48", "type": "standard", "direct": false, "binary": true, @@ -4352,10 +4352,10 @@ "RemoteRef": "xfun", "RemoteRepos": "https://packagemanager.posit.co/cran/__linux__/jammy/latest", "RemotePkgPlatform": "x86_64-pc-linux-gnu-ubuntu-22.04", - "RemoteSha": "0.47" + "RemoteSha": "0.48" }, - "sources": "https://packagemanager.posit.co/cran/__linux__/jammy/latest/src/contrib/xfun_0.47.tar.gz", - "target": "src/contrib/x86_64-pc-linux-gnu-ubuntu-22.04/4.4/xfun_0.47.tar.gz", + "sources": "https://packagemanager.posit.co/cran/__linux__/jammy/latest/src/contrib/xfun_0.48.tar.gz", + "target": "src/contrib/x86_64-pc-linux-gnu-ubuntu-22.04/4.4/xfun_0.48.tar.gz", "platform": "x86_64-pc-linux-gnu-ubuntu-22.04", "rversion": "4.4", "directpkg": false, From 01e53e7eef814349f1e9556c7aada325d281aeda Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rasmus=20Skytte=20Randl=C3=B8v?= Date: Fri, 11 Oct 2024 15:34:41 +0200 Subject: [PATCH 03/17] docs(NEWS): Add entry on `*_joins` fix for `dplyr::join_by()` --- NEWS.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/NEWS.md b/NEWS.md index ad8693cc..81d607b1 100644 --- a/NEWS.md +++ b/NEWS.md @@ -8,6 +8,8 @@ * `update_snapshot()` has been optimized and now runs faster on all the supported backends 
(#137). +* `*_joins()` can now take `dplyr::join_by()` as `by` argument (#156). + ## Documentation * A vignette including benchmarks of `update_snapshot()` across various backends is added (#138). From ddb1f543e3ede7a076db5fa215bf4fa6c596f9ab Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rasmus=20Skytte=20Randl=C3=B8v?= Date: Sun, 13 Oct 2024 21:46:16 +0200 Subject: [PATCH 04/17] fix(db_joins): Remove checkmate checks for dplyr dispatch --- R/db_joins.R | 111 +++++++++++++++++++-------------------------------- 1 file changed, 42 insertions(+), 69 deletions(-) diff --git a/R/db_joins.R b/R/db_joins.R index 03a214ed..84760229 100644 --- a/R/db_joins.R +++ b/R/db_joins.R @@ -171,20 +171,18 @@ join_warn_experimental <- function() { #' @seealso [dplyr::show_query] #' @exportS3Method dplyr::inner_join inner_join.tbl_sql <- function(x, y, by = NULL, ...) { + .dots <- list(...) + + if (!"na_by" %in% names(.dots)) { + join_warn() + return(NextMethod("inner_join")) + } # Check arguments - assert_data_like(x) - assert_data_like(y) checkmate::assert( checkmate::check_character(by, null.ok = TRUE), checkmate::check_class(by, "dplyr_join_by", null.ok = TRUE) ) - .dots <- list(...) - - if (!"na_by" %in% names(.dots)) { - if (inherits(x, "tbl_dbi") || inherits(y, "tbl_dbi")) join_warn() - return(NextMethod("inner_join")) - } join_warn_experimental() @@ -207,23 +205,19 @@ inner_join.tbl_sql <- function(x, y, by = NULL, ...) { #' @rdname joins #' @exportS3Method dplyr::left_join left_join.tbl_sql <- function(x, y, by = NULL, ...) { + .dots <- list(...) + + if (!"na_by" %in% names(.dots)) { + join_warn() + return(NextMethod("left_join")) + } # Check arguments - assert_data_like(x) - assert_data_like(y) checkmate::assert( checkmate::check_character(by, null.ok = TRUE), checkmate::check_class(by, "dplyr_join_by", null.ok = TRUE) ) - .dots <- list(...) 
- - if (!"na_by" %in% names(.dots)) { - if (inherits(x, "tbl_dbi") || inherits(y, "tbl_dbi")) join_warn() - - return(NextMethod("left_join")) - } - join_warn_experimental() args <- as.list(rlang::current_env()) |> @@ -245,23 +239,19 @@ left_join.tbl_sql <- function(x, y, by = NULL, ...) { #' @rdname joins #' @exportS3Method dplyr::right_join right_join.tbl_sql <- function(x, y, by = NULL, ...) { + .dots <- list(...) + + if (!"na_by" %in% names(.dots)) { + join_warn() + return(NextMethod("right_join")) + } # Check arguments - assert_data_like(x) - assert_data_like(y) checkmate::assert( checkmate::check_character(by, null.ok = TRUE), checkmate::check_class(by, "dplyr_join_by", null.ok = TRUE) ) - .dots <- list(...) - - if (!"na_by" %in% names(.dots)) { - if (inherits(x, "tbl_dbi") || inherits(y, "tbl_dbi")) join_warn() - - return(NextMethod("right_join")) - } - join_warn_experimental() args <- as.list(rlang::current_env()) |> @@ -284,71 +274,54 @@ right_join.tbl_sql <- function(x, y, by = NULL, ...) { #' @rdname joins #' @exportS3Method dplyr::full_join full_join.tbl_sql <- function(x, y, by = NULL, ...) { + .dots <- list(...) + + if (!"na_by" %in% names(.dots)) { + join_warn() + return(NextMethod("full_join")) + } # Check arguments - assert_data_like(x) - assert_data_like(y) checkmate::assert( checkmate::check_character(by, null.ok = TRUE), checkmate::check_class(by, "dplyr_join_by", null.ok = TRUE) ) - .dots <- list(...) + join_warn_experimental() - if ("na_by" %in% names(.dots)) { - join_warn_experimental() - # Full joins are hard... - out <- dplyr::union(dplyr::left_join(x, y, by = by, na_by = .dots$na_by), - dplyr::right_join(x, y, by = by, na_by = .dots$na_by)) - return(out) - } else { - if (inherits(x, "tbl_dbi") || inherits(y, "tbl_dbi")) join_warn() - return(dplyr::full_join(x, y, by = by, ...)) - } + # Full joins are hard... 
+ out <- dplyr::union( + dplyr::left_join(x, y, by = by, na_by = .dots$na_by), + dplyr::right_join(x, y, by = by, na_by = .dots$na_by) + ) + + return(out) } #' @rdname joins #' @exportS3Method dplyr::semi_join semi_join.tbl_sql <- function(x, y, by = NULL, ...) { - - # Check arguments - assert_data_like(x) - assert_data_like(y) - checkmate::assert( - checkmate::check_character(by, null.ok = TRUE), - checkmate::check_class(by, "dplyr_join_by", null.ok = TRUE) - ) - .dots <- list(...) - if ("na_by" %in% names(.dots)) { - stop("Not implemented") - } else { - if (inherits(x, "tbl_dbi") || inherits(y, "tbl_dbi")) join_warn() - return(dplyr::semi_join(x, y, by = by, ...)) + if (!"na_by" %in% names(.dots)) { + join_warn() + return(NextMethod("semi_join")) } + + stop("Not implemented") } #' @rdname joins #' @exportS3Method dplyr::anti_join anti_join.tbl_sql <- function(x, y, by = NULL, ...) { - - # Check arguments - assert_data_like(x) - assert_data_like(y) - checkmate::assert( - checkmate::check_character(by, null.ok = TRUE), - checkmate::check_class(by, "dplyr_join_by", null.ok = TRUE) - ) - .dots <- list(...) 
- if ("na_by" %in% names(.dots)) { - stop("Not implemented") - } else { - if (inherits(x, "tbl_dbi") || inherits(y, "tbl_dbi")) join_warn() - return(dplyr::anti_join(x, y, by = by, ...)) + if (!"na_by" %in% names(.dots)) { + join_warn() + return(NextMethod("anti_join")) } + + stop("Not implemented") } From d2c6f7d7adfb959c75504bd53bcc94e9d59722c1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rasmus=20Skytte=20Randl=C3=B8v?= Date: Sun, 13 Oct 2024 21:51:09 +0200 Subject: [PATCH 05/17] test(db_joins): Test all dplyr joins work as expected --- tests/testthat/test-db_joins.R | 86 +++++++++++++++++++++++++--------- 1 file changed, 64 insertions(+), 22 deletions(-) diff --git a/tests/testthat/test-db_joins.R b/tests/testthat/test-db_joins.R index a7c95885..10415c2c 100644 --- a/tests/testthat/test-db_joins.R +++ b/tests/testthat/test-db_joins.R @@ -1,28 +1,6 @@ test_that("*_join() works with character `by` and `na_by`", { for (conn in get_test_conns()) { - # Define two test datasets - x <- get_table(conn, "__mtcars") |> - dplyr::select(name, mpg, cyl, hp, vs, am, gear, carb) - - y <- get_table(conn, "__mtcars") |> - dplyr::select(name, drat, wt, qsec) - - - # Test the implemented joins - q <- dplyr::left_join(x, y, by = "name") |> dplyr::collect() - qr <- dplyr::left_join(dplyr::collect(x), dplyr::collect(y), by = "name") - expect_equal(q, qr) - - q <- dplyr::right_join(x, y, by = "name") |> dplyr::collect() - qr <- dplyr::right_join(dplyr::collect(x), dplyr::collect(y), by = "name") - expect_equal(q, qr) - - q <- dplyr::inner_join(x, y, by = "name") |> dplyr::collect() - qr <- dplyr::inner_join(dplyr::collect(x), dplyr::collect(y), by = "name") - expect_equal(q, qr) - - # Create two more synthetic test data set with NA data # First test case @@ -144,3 +122,67 @@ test_that("*_join() works with `dplyr::join_by()`", { connection_clean_up(conn) } }) + + +test_that("*_join() does not break any dplyr joins", { + for (conn in get_test_conns()) { + + # Define two test datasets + x 
<- get_table(conn, "__mtcars") |> + dplyr::select(name, mpg, cyl, hp, vs, am, gear, carb) + + y <- get_table(conn, "__mtcars") |> + dplyr::select(name, drat, wt, qsec) + + # Test the standard joins + # left_join + qr <- dplyr::left_join(dplyr::collect(x), dplyr::collect(y), by = "name") + q <- dplyr::left_join(x, y, by = "name") |> dplyr::collect() + expect_equal(q, qr) + + q <- dplyr::left_join(x, y, by = dplyr::join_by(x$name == y$name)) |> dplyr::collect() + expect_equal(q, qr) + + # right_join + qr <- dplyr::right_join(dplyr::collect(x), dplyr::collect(y), by = "name") + q <- dplyr::right_join(x, y, by = "name") |> dplyr::collect() + expect_equal(q, qr) + + q <- dplyr::right_join(x, y, by = dplyr::join_by(x$name == y$name)) |> dplyr::collect() + expect_equal(q, qr) + + # inner_join + qr <- dplyr::inner_join(dplyr::collect(x), dplyr::collect(y), by = "name") + q <- dplyr::inner_join(x, y, by = "name") |> dplyr::collect() + expect_equal(q, qr) + + q <- dplyr::inner_join(x, y, by = dplyr::join_by(x$name == y$name)) |> dplyr::collect() + expect_equal(q, qr) + + # full_join + qr <- dplyr::full_join(dplyr::collect(x), dplyr::collect(y), by = "name") + q <- dplyr::full_join(x, y, by = "name") |> dplyr::collect() + expect_equal(q, qr) + + q <- dplyr::full_join(x, y, by = dplyr::join_by(x$name == y$name)) |> dplyr::collect() + expect_equal(q, qr) + + # semi_join + qr <- dplyr::semi_join(dplyr::collect(x), dplyr::collect(y), by = "name") + q <- dplyr::semi_join(x, y, by = "name") |> dplyr::collect() + expect_equal(q, qr) + + q <- dplyr::semi_join(x, y, by = dplyr::join_by(x$name == y$name)) |> dplyr::collect() + expect_equal(q, qr) + + # anti_join + qr <- dplyr::anti_join(dplyr::collect(x), dplyr::collect(y), by = "name") + q <- dplyr::anti_join(x, y, by = "name") |> dplyr::collect() + expect_equal(q, qr) + + q <- dplyr::anti_join(x, y, by = dplyr::join_by(x$name == y$name)) |> dplyr::collect() + expect_equal(q, qr) + + connection_clean_up(conn) + } +}) From 
12e9d0601ed843a3e40962cf63697f3cbe4eed7a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rasmus=20Skytte=20Randl=C3=B8v?= Date: Mon, 14 Oct 2024 10:41:53 +0200 Subject: [PATCH 06/17] docs(NEWS): Be more precise on when `dplyr::join_by()` works --- NEWS.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/NEWS.md b/NEWS.md index 81d607b1..0e266d02 100644 --- a/NEWS.md +++ b/NEWS.md @@ -8,7 +8,7 @@ * `update_snapshot()` has been optimized and now runs faster on all the supported backends (#137). -* `*_joins()` can now take `dplyr::join_by()` as `by` argument (#156). +* `*_joins()` can now take `dplyr::join_by()` as `by` argument when no `na_by` argument is given (#156). ## Documentation From ec8aabed6366143c2450d04586ae9319420272f6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rasmus=20Skytte=20Randl=C3=B8v?= Date: Tue, 15 Oct 2024 10:12:47 +0200 Subject: [PATCH 07/17] feat(db_joins): Use dbplyr to translate joins --- R/db_joins.R | 430 +++++++++++++++++++++++++++------------------------ man/joins.Rd | 16 +- 2 files changed, 236 insertions(+), 210 deletions(-) diff --git a/R/db_joins.R b/R/db_joins.R index 84760229..006520bb 100644 --- a/R/db_joins.R +++ b/R/db_joins.R @@ -1,138 +1,22 @@ -#' Generate sql_on statement for na joins -#' -#' @description -#' This function generates a much faster SQL statement for NA join compared to dbplyr's _join with na_matches = "na". -#' @inheritParams left_join -#' @return -#' A sql_on statement to join by such that "NA" are matched with "NA" given the columns listed in "by" and "na_by". 
-#' @noRd -join_na_sql <- function(x, by, na_by) { - UseMethod("join_na_sql") -} - -join_na_not_distinct <- function(by, na_by = NULL) { - sql_on <- "" - if (!missing(by)) { - for (i in seq_along(by)) { - sql_on <- paste0(sql_on, '"LHS"."', by[i], '" = "RHS"."', by[i], '"') - if (i < length(by) || !is.null(na_by)) { - sql_on <- paste(sql_on, "\nAND ") - } - } - } - - if (!missing(na_by)) { - for (i in seq_along(na_by)) { - sql_on <- paste0(sql_on, '"LHS"."', na_by[i], '" IS NOT DISTINCT FROM "RHS"."', na_by[i], '"') - if (i < length(na_by)) { - sql_on <- paste(sql_on, "\nAND ") - } - } - } - - return(sql_on) -} - -join_na_not_null <- function(by, na_by = NULL) { - sql_on <- "" - if (!missing(by)) { - for (i in seq_along(by)) { - sql_on <- paste0(sql_on, '"LHS"."', by[i], '" = "RHS"."', by[i], '"') - if (i < length(by) || !is.null(na_by)) { - sql_on <- paste(sql_on, "\nAND ") - } - } - } - - if (!missing(na_by)) { - for (i in seq_along(na_by)) { - sql_on <- paste0(sql_on, - '("LHS"."', na_by[i], '" IS NULL AND "RHS"."', na_by[i], '" IS NULL ', - 'OR "LHS"."', na_by[i], '" = "RHS"."', na_by[i], '")') - if (i < length(na_by)) { - sql_on <- paste(sql_on, "\nAND ") - } - } - } - - return(sql_on) -} - -#' @noRd -join_na_sql.tbl_dbi <- function(x, by, na_by) { - return(join_na_not_distinct(by = by, na_by = na_by)) -} - -#' @noRd -`join_na_sql.tbl_Microsoft SQL Server` <- function(x, by, na_by) { - return(join_na_not_null(by = by, na_by = na_by)) -} - -#' Get colnames to select -#' -#' @inheritParams left_join -#' @param left (`logical(1)`)\cr -#' Is the join a left (alternatively right) join? -#' @return -#' A named character vector indicating which columns to select from x and y. 
-#' @noRd -select_na_sql <- function(x, y, by, na_by, left = TRUE) { - - all_by <- c(by, na_by) # Variables to be common after join - cx <- dplyr::setdiff(colnames(x), colnames(y)) # Variables only in x - cy <- dplyr::setdiff(colnames(y), colnames(x)) # Variables only in y - - sql_select <- - c(paste0(colnames(x), ifelse(colnames(x) %in% cx, "", ".x")), - paste0(colnames(y), ifelse(colnames(y) %in% cy, "", ".y"))[!colnames(y) %in% all_by]) |> - stats::setNames(c(colnames(x), - paste0(colnames(y), ifelse(colnames(y) %in% colnames(x), ".y", ""))[!colnames(y) %in% all_by])) - - return(sql_select) -} - - -#' Warn users that SQL does not match on NA by default -#' -#' @return -#' A warning that *_joins on SQL backends does not match NA by default. -#' @noRd -join_warn <- function() { - if (interactive() && identical(parent.frame(n = 2), globalenv())) { - rlang::warn(paste("*_joins in database-backend does not match NA by default.\n", - "If your data contains NA, the columns with NA values must be supplied to \"na_by\",", - "or you must specifiy na_matches = \"na\""), - .frequency = "once", .frequency_id = "*_join NA warning") - } -} - - -#' Warn users that SQL joins by NA is experimental -#' -#' @return -#' A warning that *_joins are still experimental. -#' @noRd -join_warn_experimental <- function() { - if (interactive() && identical(parent.frame(n = 2), globalenv())) { - rlang::warn("*_joins with na_by is still experimental. Please report issues.", - .frequency = "once", .frequency_id = "*_join NA warning") - } -} - - #' SQL Joins #' #' @name joins #' #' @description +#' `r lifecycle::badge("experimental")` +#' #' Overloads the dplyr `*_join` to accept an `na_by` argument. #' By default, joining using SQL does not match on `NA` / `NULL`. -#' dbplyr `*_join`s has the option "na_matches = na" to match on `NA` / `NULL` but this is very inefficient in some -#' cases. 
-#' This function does the matching more efficiently: +#' dbplyr `*_join`s have the option "na_matches = na" to match on `NA` / `NULL` but this operation is substantially +#' slower since it turns all equality comparisons to identical comparisons. +#' +#' This function does the matching more efficiently by allowing the user to specify which columns contain +#' `NA` / `NULL` values and which do not: #' If a column contains `NA` / `NULL`, the names of these columns can be passed via the `na_by` argument and -#' efficiently match as if "na_matches = na". -#' If no `na_by` argument is given is given, the function defaults to using `dplyr::*_join`. +#' efficiently match as if `na_matches = "na"`. +#' Columns without `NA` / `NULL` values are passed via the `by` argument and matched as if `na_matches = "never"`. +#' +#' If no `na_by` argument is given, the function defaults to using `dplyr::*_join` without modification. #' #' @inheritParams dbplyr::join.tbl_sql #' @return Another \code{tbl_lazy}. Use \code{\link[dplyr:show_query]{show_query()}} to see the generated @@ -178,28 +62,11 @@ inner_join.tbl_sql <- function(x, y, by = NULL, ...) 
{ return(NextMethod("inner_join")) } - # Check arguments - checkmate::assert( - checkmate::check_character(by, null.ok = TRUE), - checkmate::check_class(by, "dplyr_join_by", null.ok = TRUE) - ) - - join_warn_experimental() + # Prepare the combined join + out <- do.call(dplyr::inner_join, args = join_args(.dots)) + out$lazy_query$vars <- join_na_select_fix(out$lazy_query$vars, .dots$na_by) - args <- as.list(rlang::current_env()) |> - append(.dots) - - .renamer <- select_na_sql(x, y, by, .dots$na_by) - - # Remove na_by from args to avoid infinite loops - args$na_by <- NULL - args$sql_on <- join_na_sql(x, by, .dots$na_by) - - join_result <- do.call(dplyr::inner_join, args = args) |> - dplyr::rename(!!.renamer) |> - dplyr::select(tidyselect::all_of(names(.renamer))) - - return(join_result) + return(out) } #' @rdname joins @@ -212,28 +79,10 @@ left_join.tbl_sql <- function(x, y, by = NULL, ...) { return(NextMethod("left_join")) } - # Check arguments - checkmate::assert( - checkmate::check_character(by, null.ok = TRUE), - checkmate::check_class(by, "dplyr_join_by", null.ok = TRUE) - ) + out <- do.call(dplyr::left_join, args = join_args(.dots)) + out$lazy_query$vars <- join_na_select_fix(out$lazy_query$vars, .dots$na_by) - join_warn_experimental() - - args <- as.list(rlang::current_env()) |> - append(.dots) - - .renamer <- select_na_sql(x, y, by, .dots$na_by) - - # Remove na_by from args to avoid infinite loops - args$na_by <- NULL - args$sql_on <- join_na_sql(x, by, .dots$na_by) - - join_result <- do.call(dplyr::left_join, args = args) |> - dplyr::rename(!!.renamer) |> - dplyr::select(tidyselect::all_of(names(.renamer))) - - return(join_result) + return(out) } #' @rdname joins @@ -246,28 +95,10 @@ right_join.tbl_sql <- function(x, y, by = NULL, ...) 
{ return(NextMethod("right_join")) } - # Check arguments - checkmate::assert( - checkmate::check_character(by, null.ok = TRUE), - checkmate::check_class(by, "dplyr_join_by", null.ok = TRUE) - ) - - join_warn_experimental() - - args <- as.list(rlang::current_env()) |> - append(.dots) - - .renamer <- select_na_sql(x, y, by, .dots$na_by) - - # Remove na_by from args to avoid infinite loops - args$na_by <- NULL - args$sql_on <- join_na_sql(x, by, .dots$na_by) - - join_result <- do.call(dplyr::right_join, args = args) |> - dplyr::rename(!!.renamer) |> - dplyr::select(tidyselect::all_of(names(.renamer))) + out <- do.call(dplyr::right_join, args = join_args(.dots)) + out$lazy_query$vars <- join_na_select_fix(out$lazy_query$vars, .dots$na_by, right = TRUE) - return(join_result) + return(out) } @@ -281,19 +112,8 @@ full_join.tbl_sql <- function(x, y, by = NULL, ...) { return(NextMethod("full_join")) } - # Check arguments - checkmate::assert( - checkmate::check_character(by, null.ok = TRUE), - checkmate::check_class(by, "dplyr_join_by", null.ok = TRUE) - ) - - join_warn_experimental() - - # Full joins are hard... - out <- dplyr::union( - dplyr::left_join(x, y, by = by, na_by = .dots$na_by), - dplyr::right_join(x, y, by = by, na_by = .dots$na_by) - ) + out <- do.call(dplyr::full_join, args = join_args(.dots)) + out$lazy_query$vars <- join_na_select_fix(out$lazy_query$vars, .dots$na_by) return(out) } @@ -309,7 +129,9 @@ semi_join.tbl_sql <- function(x, y, by = NULL, ...) { return(NextMethod("semi_join")) } - stop("Not implemented") + out <- do.call(dplyr::semi_join, args = join_args(.dots)) + + return(out) } @@ -323,5 +145,203 @@ anti_join.tbl_sql <- function(x, y, by = NULL, ...) { return(NextMethod("anti_join")) } - stop("Not implemented") + out <- do.call(dplyr::anti_join, args = join_args(.dots)) + + return(out) +} + + +#' Warn users that SQL does not match on NA by default +#' +#' @return +#' A warning that *_joins on SQL backends does not match NA by default. 
+#' @noRd +join_warn <- function() { + if (interactive() && identical(parent.frame(n = 2), globalenv())) { + rlang::warn( + paste( + "*_joins in database-backend does not match NA by default.\n", + "If your data contains NA, the columns with NA values must be supplied to \"na_by\",", + "or you must specify na_matches = \"na\"" + ), + .frequency = "once", + .frequency_id = "*_join NA warning" + ) + } +} + + +#' Warn users that SQL joins by NA is experimental +#' +#' @return +#' A warning that *_joins are still experimental. +#' @noRd +join_warn_experimental <- function() { + if (interactive() && identical(parent.frame(n = 2), globalenv())) { + rlang::warn( + "*_joins with na_by is still experimental. Please report issues.", + .frequency = "once", + .frequency_id = "*_join NA warning" + ) + } +} + + +#' Construct the arguments to `*_join` that account for the NA matching +#' @param x (`tbl_sql`) The left table to join. +#' @param y (`tbl_sql`) The right table to join. +#' @param by (`dplyr_join_by` or `character`) The columns to match on without NA values. +#' @param .dots (`list`) Arguments passed to the `*_join` function. +#' @noRd +join_args <- function(.dots) { + # Grab the environment of the caller and add the dot args + args <- append(as.list(rlang::caller_env()), .dots) + + # Remove the na matching args, and let join_na_sql combine the `by` and `na_by` statements + args$na_by <- NULL + args$na_matches <- NULL + args$by <- join_na_sql(args$x, args$y, by = args$by, na_by = .dots$na_by) + + return(args) +} + + +#' Merge two `dplyr_join_by` objects +#' @param by (`dplyr_join_by` or `character`) The columns to match on without NA values. +#' @param na_by (`dplyr_join_by` or `character`) The columns to match on NA. 
+#' @noRd +join_merger <- function(by, na_by) { + + # Early return if only one by statement is given + if (is.null(by) && is.null(na_by)) { + stop("Both by and na_by cannot be NULL") + } else if (is.null(by)) { + return(na_by) + } else if (is.null(na_by)) { + return(by) + } + + # Combine the by and na_by statements by unclassing, merging and reclassing + combined_join <- list( + "exprs" = c(purrr::pluck(by, "exprs"), purrr::pluck(na_by, "exprs")) + ) |> + modifyList( + purrr::map2(purrr::discard_at(by, "exprs"), purrr::discard_at(na_by, "exprs"), ~ c(.x, .y)) + ) + class(combined_join) <- "dplyr_join_by" + + return(combined_join) +} + + +#' Generate `dplyr_join_by` statement for na joins +#' +#' @description +#' This function creates a `dplyr_join_by` object to join by where the statements supplied in `by` are treated as not +#' having NA values while the columns listed in `na_by` are treated as having NA values. +#' This latter translation corresponds to using `dplyr::*_join` with `na_matches = "na"`. +#' @inheritParams left_join +#' @param na_by (`character`)\cr +#' The columns to match on NA. If a column contains NA, the names of these columns can be passed via the `na_by` +#' argument. These will then be matched as if with the `na_matches = "na"` argument. +#' @return +#' A `dplyr_join_by` object to join by such that "NA" are matched with "NA" given the columns listed in `by` and +#' `na_by`. 
+#' @noRd +join_na_sql <- function(x, y, by = NULL, na_by = NULL) { + + # Check arguments + checkmate::assert( + checkmate::check_character(by, null.ok = TRUE), + checkmate::check_class(by, "dplyr_join_by", null.ok = TRUE) + ) + + join_warn_experimental() + + # Convert to dplyr_join_by if not already + if (!is.null(by) && !inherits(by, "dplyr_join_by")) { + by <- dplyr::join_by(!!by) + } + + if (!is.null(na_by) && !inherits(na_by, "dplyr_join_by")) { + na_by <- dplyr::join_by(!!na_by) + } + + combined_join <- join_merger(by, na_by) + + # Get the translation for matching the na_by component of the join + na_subquery <- dbplyr::remote_query(dplyr::inner_join(x, y, by = combined_join, na_matches = "na")) + + # Determine the NA matching statement by extracting from the translated query. + # E.g. on RSQlite, the keyword "IS" checks if arguments are identical + # and on PostgreSQL, the keyword "IS NOT DISTINCT FROM" checks if arguments are identical. + na_matching <- na_subquery |> + stringr::str_remove_all(stringr::fixed("\n")) |> # Remove newlines from the formatted query + stringr::str_replace_all(r"{\s{2,}}", " ") |> # Remove multiple spaces from the formatted query + stringr::str_extract(r"{(?<=ON \().*(?=\))}") |> # Extract the contents of the ON statement + stringr::str_extract(pattern = r"{(?:["'`´]\s)([\w\s]+)(?:\s["'`´])}", group = 1) # First non quoted word(s) + + # Replace NA equals with NA matching statement + na_by$condition[na_by$condition == "=="] <- na_matching + + return(join_merger(by, na_by)) +} + + +#' Manually fixes the select component of the `lazy_query` after overwriting the `by` statement. +#' +#' @description +#' After overwriting the `by` statement in the `lazy_query`, the `vars` component of the `lazy_query` is not +#' consistent with the new non-overwritten `by` statement. +#' As a result, columns which are matched in the join are included as both `.x` and ``, instead of just +#' as ``. 
+#' This function fixes the `vars` component of the `lazy_query` to remove the doubly selected columns and rename +#' to the expected name. +#' @param vars (`tibble`) The `vars` component of the `lazy_query`. +#' @param na_by (`dplyr_join_by`) The `na_by` statement used in the join. +#' @param right (`logical`) If the join is a right join. +#' @return +#' A `tibble` with the `vars` component of the `lazy_query` fixed to remove doubly selected columns. +#' @noRd +join_na_select_fix <- function(vars, na_by, right = FALSE) { + if (!inherits(na_by, "dplyr_join_by")) na_by <- dplyr::join_by(!!na_by) + + # All equality joins in `na_by` are incorrectly translated + doubly_selected_columns <- na_by |> + purrr::discard_at("exprs") |> + tibble::as_tibble() |> + dplyr::filter(.data$condition == "==", .data$x == .data$y) |> + dplyr::pull("x") + + if (length(doubly_selected_columns) == 0) { + updated_vars <- vars # no doubly selected columns + } else { + + # The vars table structure is not consistent between dplyr join types + # There are two formats which we needs to manage independently. 
+ if (checkmate::test_names(names(vars), identical.to = c("name", "x", "y"))) { + updated_vars <- rbind( + tibble::tibble( + "name" = doubly_selected_columns, + "x" = ifelse(right, NA, doubly_selected_columns), + "y" = doubly_selected_columns + ), + dplyr::filter(vars, .data$x %in% !!doubly_selected_columns | .data$y %in% !!doubly_selected_columns) + ) |> + dplyr::symdiff(vars) + + } else if (checkmate::test_names(names(vars), identical.to = c("name", "table", "var"))) { + updated_vars <- rbind( + tibble::tibble( + "name" = doubly_selected_columns, + "table" = 1, + "var" = doubly_selected_columns + ), + dplyr::filter(vars, .data$var %in% !!doubly_selected_columns) + ) |> + dplyr::symdiff(vars) + } + } + + return(updated_vars) } diff --git a/man/joins.Rd b/man/joins.Rd index dd54252b..7d70ffcd 100644 --- a/man/joins.Rd +++ b/man/joins.Rd @@ -62,14 +62,20 @@ query, and use \code{\link[dbplyr:collect.tbl_sql]{collect()}} to execute the qu and return data to R. } \description{ +\ifelse{html}{\href{https://lifecycle.r-lib.org/articles/stages.html#experimental}{\figure{lifecycle-experimental.svg}{options: alt='[Experimental]'}}}{\strong{[Experimental]}} + Overloads the dplyr \verb{*_join} to accept an \code{na_by} argument. By default, joining using SQL does not match on \code{NA} / \code{NULL}. -dbplyr \verb{*_join}s has the option "na_matches = na" to match on \code{NA} / \code{NULL} but this is very inefficient in some -cases. -This function does the matching more efficiently: +dbplyr \verb{*_join}s has the option "na_matches = na" to match on \code{NA} / \code{NULL} but this operation is substantially +slower since it turns all equality comparisons to identical comparisons. 
+ +This function does the matching more efficiently by allowing the user to specify which column contains +\code{NA} / \code{NULL} values and which does not: If a column contains \code{NA} / \code{NULL}, the names of these columns can be passed via the \code{na_by} argument and -efficiently match as if "na_matches = na". -If no \code{na_by} argument is given is given, the function defaults to using \verb{dplyr::*_join}. +efficiently match as if \code{na_matches = "na"}. +Columns without \code{NA} / \code{NULL} values is passed via the \code{by} argument and will be matched \code{na_matches = "never"}. + +If no \code{na_by} argument is given, the function defaults to using \verb{dplyr::*_join} without modification. } \examples{ \dontshow{if (requireNamespace("RSQLite", quietly = TRUE)) (if (getRversion() >= "3.4") withAutoprint else force)(\{ # examplesIf} From bd1817de36ac74d20b727878a9c66ddf34adf66e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rasmus=20Skytte=20Randl=C3=B8v?= Date: Tue, 15 Oct 2024 10:14:02 +0200 Subject: [PATCH 08/17] test(db_joins): Test that na_by and by arguments are interchangeable without NA data --- tests/testthat/test-db_joins.R | 64 ++++++++++++++++++++++++++++++++++ 1 file changed, 64 insertions(+) diff --git a/tests/testthat/test-db_joins.R b/tests/testthat/test-db_joins.R index 10415c2c..03e1457e 100644 --- a/tests/testthat/test-db_joins.R +++ b/tests/testthat/test-db_joins.R @@ -186,3 +186,67 @@ test_that("*_join() does not break any dplyr joins", { connection_clean_up(conn) } }) + + +test_that("*_join() with only na_by works as dplyr joins", { + for (conn in get_test_conns()) { + + # Define two test datasets + x <- get_table(conn, "__mtcars") |> + dplyr::select(name, mpg, cyl, hp, vs, am, gear, carb) + + y <- get_table(conn, "__mtcars") |> + dplyr::select(name, drat, wt, qsec) + + # Test the standard joins + # left_join + qr <- dplyr::left_join(dplyr::collect(x), dplyr::collect(y), by = "name") + q <- dplyr::left_join(x, y, na_by = 
"name") |> dplyr::collect() + expect_equal(q, qr) + + q <- dplyr::left_join(x, y, na_by = dplyr::join_by(x$name == y$name)) |> dplyr::collect() + expect_equal(q, qr) + + # right_join + qr <- dplyr::right_join(dplyr::collect(x), dplyr::collect(y), by = "name") + q <- dplyr::right_join(x, y, na_by = "name") |> dplyr::collect() + expect_equal(q, qr) + + q <- dplyr::right_join(x, y, na_by = dplyr::join_by(x$name == y$name)) |> dplyr::collect() + expect_equal(q, qr) + + # inner_join + qr <- dplyr::inner_join(dplyr::collect(x), dplyr::collect(y), by = "name") + q <- dplyr::inner_join(x, y, na_by = "name") |> dplyr::collect() + expect_equal(q, qr) + + q <- dplyr::inner_join(x, y, na_by = dplyr::join_by(x$name == y$name)) |> dplyr::collect() + expect_equal(q, qr) + + # full_join + qr <- dplyr::full_join(dplyr::collect(x), dplyr::collect(y), by = "name") + q <- dplyr::full_join(x, y, na_by = "name") |> dplyr::collect() + expect_equal(q, qr) + + q <- dplyr::full_join(x, y, na_by = dplyr::join_by(x$name == y$name)) |> dplyr::collect() + expect_equal(q, qr) + + # semi_join + qr <- dplyr::semi_join(dplyr::collect(x), dplyr::collect(y), by = "name") + q <- dplyr::semi_join(x, y, na_by = "name") |> dplyr::collect() + expect_equal(q, qr) + + q <- dplyr::semi_join(x, y, na_by = dplyr::join_by(x$name == y$name)) |> dplyr::collect() + expect_equal(q, qr) + + # anti_join + qr <- dplyr::anti_join(dplyr::collect(x), dplyr::collect(y), by = "name") + q <- dplyr::anti_join(x, y, na_by = "name") |> dplyr::collect() + expect_equal(q, qr) + + q <- dplyr::anti_join(x, y, na_by = dplyr::join_by(x$name == y$name)) |> dplyr::collect() + expect_equal(q, qr) + + connection_clean_up(conn) + } +}) From 01a6c668eeffe150e8049bb0a9641845f8ce839c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rasmus=20Skytte=20Randl=C3=B8v?= Date: Tue, 15 Oct 2024 10:37:43 +0200 Subject: [PATCH 09/17] fix(db_joins): Match column order with non-overwritten joins --- R/db_joins.R | 9 +++++++++ tests/testthat/test-db_joins.R | 2 
+- 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/R/db_joins.R b/R/db_joins.R index 006520bb..ace49612 100644 --- a/R/db_joins.R +++ b/R/db_joins.R @@ -330,6 +330,11 @@ join_na_select_fix <- function(vars, na_by, right = FALSE) { ) |> dplyr::symdiff(vars) + # Reorder our updated columns to match the original order + updated_vars <- updated_vars[ + order(match(updated_vars$name, unique(purrr::pmap_chr(vars, ~ dplyr::coalesce(..2, ..3))))), + ] + } else if (checkmate::test_names(names(vars), identical.to = c("name", "table", "var"))) { updated_vars <- rbind( tibble::tibble( @@ -340,6 +345,10 @@ join_na_select_fix <- function(vars, na_by, right = FALSE) { dplyr::filter(vars, .data$var %in% !!doubly_selected_columns) ) |> dplyr::symdiff(vars) + + # Reorder our updated columns to match the original order + updated_vars <- updated_vars[order(match(updated_vars$name, unique(vars$var))), ] + } } diff --git a/tests/testthat/test-db_joins.R b/tests/testthat/test-db_joins.R index 03e1457e..9670d44a 100644 --- a/tests/testthat/test-db_joins.R +++ b/tests/testthat/test-db_joins.R @@ -21,7 +21,7 @@ test_that("*_join() works with character `by` and `na_by`", { dplyr::arrange(number, t, letter) qr <- dplyr::left_join(dplyr::collect(x), dplyr::collect(y), by = "number", multiple = "all") |> dplyr::arrange(number, t, letter) - expect_mapequal(q, qr) + expect_equal(q, qr) q <- dplyr::right_join(x, y, na_by = "number") |> dplyr::collect() |> From 13747234b07061d6a80f81105aea9553b2335e35 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rasmus=20Skytte=20Randl=C3=B8v?= Date: Tue, 15 Oct 2024 10:53:24 +0200 Subject: [PATCH 10/17] chore(db_joins): Fix lints --- R/db_joins.R | 29 +++++++++++++++++++---------- 1 file changed, 19 insertions(+), 10 deletions(-) diff --git a/R/db_joins.R b/R/db_joins.R index ace49612..73c376a8 100644 --- a/R/db_joins.R +++ b/R/db_joins.R @@ -188,10 +188,14 @@ join_warn_experimental <- function() { #' Construct the arguments to `*_join` that accounts for 
the na matching -#' @param x (`tbl_sql`) The left table to join. -#' @param y (`tbl_sql`) The right table to join. -#' @param by (`dbplyr_join_by` or `character`) The columns to match on without NA values. -#' @param .dots (`list`) Arguments passed to the `*_join` function. +#' @param x (`tbl_sql`) \cr +#' The left table to join. +#' @param y (`tbl_sql`) \cr +#' The right table to join. +#' @param by (`dbplyr_join_by` or `character`) \cr +#' The columns to match on without NA values. +#' @param .dots (`list`) \cr +#' Arguments passed to the `*_join` function. #' @noRd join_args <- function(.dots) { # Grab the environment of the caller and add the dot args @@ -207,8 +211,10 @@ join_args <- function(.dots) { #' Merge two `dplyr_join_by` objects -#' @param by (`dplyr_join_by` or `character`) The columns to match on without NA values. -#' @param na_by (`dplyr_join_by` or `character`) The columns to match on NA. +#' @param by (`dplyr_join_by` or `character`) \cr +#' The columns to match on without NA values. +#' @param na_by (`dplyr_join_by` or `character`) \cr +#' The columns to match on NA. #' @noRd join_merger <- function(by, na_by) { @@ -279,7 +285,7 @@ join_na_sql <- function(x, y, by = NULL, na_by = NULL) { stringr::str_remove_all(stringr::fixed("\n")) |> # Remove newlines from the formatted query stringr::str_replace_all(r"{\s{2,}}", " ") |> # Remove multiple spaces from the formatted query stringr::str_extract(r"{(?<=ON \().*(?=\))}") |> # Extract the contents of the ON statement - stringr::str_extract(pattern = r"{(?:["'`´]\s)([\w\s]+)(?:\s["'`´])}", group = 1) # First non quoted word(s) + stringr::str_extract(pattern = r"{(?:["'`]\s)([\w\s]+)(?:\s["'`])}", group = 1) # First non quoted word(s) # Replace NA equals with NA matching statement na_by$condition[na_by$condition == "=="] <- na_matching @@ -297,9 +303,12 @@ join_na_sql <- function(x, y, by = NULL, na_by = NULL) { #' as ``. 
#' This function fixes the `vars` component of the `lazy_query` to remove the doubly selected columns and rename #' to the expected name. -#' @param vars (`tibble`) The `vars` component of the `lazy_query`. -#' @param na_by (`dplyr_join_by`) The `na_by` statement used in the join. -#' @param right (`logical`) If the join is a right join. +#' @param vars (`tibble`)\cr +#' The `vars` component of the `lazy_query`. +#' @param na_by (`dplyr_join_by`)\cr +#' The `na_by` statement used in the join. +#' @param right (`logical`)\cr +#' If the join is a right join. #' @return #' A `tibble` with the `vars` component of the `lazy_query` fixed to remove doubly selected columns. #' @noRd From 2c178ed45d54c642e11f079ba6410e2a1e38ec6d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rasmus=20Skytte=20Randl=C3=B8v?= Date: Tue, 15 Oct 2024 10:53:53 +0200 Subject: [PATCH 11/17] docs(db_joins): Update example with `dplyr::join_by` --- R/db_joins.R | 8 ++++---- man/joins.Rd | 8 ++++---- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/R/db_joins.R b/R/db_joins.R index 73c376a8..64078fd8 100644 --- a/R/db_joins.R +++ b/R/db_joins.R @@ -43,13 +43,13 @@ #' # But you can activate R's usual behaviour with the na_matches argument #' left_join(db, label, by = "x", na_matches = "na") #' -#' # By default, joins are equijoins, but you can use `sql_on` to +#' # By default, joins are equijoins, but you can use `dplyr::join_by()` to #' # express richer relationships -#' db1 <- memdb_frame(x = 1:5) -#' db2 <- memdb_frame(x = 1:3, y = letters[1:3]) +#' db1 <- memdb_frame(id = 1:5) +#' db2 <- memdb_frame(id = 1:3, y = letters[1:3]) #' #' left_join(db1, db2) |> show_query() -#' left_join(db1, db2, sql_on = "LHS.x < RHS.x") |> show_query() +#' left_join(db1, db2, by = join_by(x$id < y$id)) |> show_query() #' @seealso [dplyr::mutate-joins] which this function wraps. #' @seealso [dbplyr::join.tbl_sql] which this function wraps. 
#' @seealso [dplyr::show_query] diff --git a/man/joins.Rd b/man/joins.Rd index 7d70ffcd..5af76132 100644 --- a/man/joins.Rd +++ b/man/joins.Rd @@ -99,13 +99,13 @@ If no \code{na_by} argument is given, the function defaults to using \verb{dplyr # But you can activate R's usual behaviour with the na_matches argument left_join(db, label, by = "x", na_matches = "na") - # By default, joins are equijoins, but you can use `sql_on` to + # By default, joins are equijoins, but you can use `dplyr::join_by()` to # express richer relationships - db1 <- memdb_frame(x = 1:5) - db2 <- memdb_frame(x = 1:3, y = letters[1:3]) + db1 <- memdb_frame(id = 1:5) + db2 <- memdb_frame(id = 1:3, y = letters[1:3]) left_join(db1, db2) |> show_query() - left_join(db1, db2, sql_on = "LHS.x < RHS.x") |> show_query() + left_join(db1, db2, by = join_by(x$id < y$id)) |> show_query() \dontshow{\}) # examplesIf} } \seealso{ From 94cc1df80cf54cca132cc730df6e85af80e2b7ea Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rasmus=20Skytte=20Randl=C3=B8v?= Date: Tue, 15 Oct 2024 10:57:21 +0200 Subject: [PATCH 12/17] docs(NEWS) Add entry on improved db_joins --- NEWS.md | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/NEWS.md b/NEWS.md index 0e266d02..7cccf667 100644 --- a/NEWS.md +++ b/NEWS.md @@ -8,7 +8,9 @@ * `update_snapshot()` has been optimized and now runs faster on all the supported backends (#137). -* `*_joins()` can now take `dplyr::join_by()` as `by` argument when no `na_by` argument is given (#156). +* `*_joins()` have been more robust: + * `dbplyr` is now used internally which improves `full_join` and adds `anti_join` and `semi_join` (#157). + * When not supplying a `na_by` argument no input validation is made and unmodified `dplyr::*_join()` is called (#156). 
## Documentation From f7c25d762c016c4f503ed50970ab9b90d2fd0d4a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rasmus=20Skytte=20Randl=C3=B8v?= Date: Tue, 15 Oct 2024 12:44:28 +0200 Subject: [PATCH 13/17] feat(db_joins): Add extra checkmate checks --- R/db_joins.R | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/R/db_joins.R b/R/db_joins.R index 64078fd8..9a1282f1 100644 --- a/R/db_joins.R +++ b/R/db_joins.R @@ -261,6 +261,10 @@ join_na_sql <- function(x, y, by = NULL, na_by = NULL) { checkmate::check_character(by, null.ok = TRUE), checkmate::check_class(by, "dplyr_join_by", null.ok = TRUE) ) + checkmate::assert( + checkmate::check_character(na_by, null.ok = TRUE), + checkmate::check_class(na_by, "dplyr_join_by", null.ok = TRUE) + ) join_warn_experimental() From ec4c03b9a6156d92a7e355ecdd63bede812e01f8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rasmus=20Skytte=20Randl=C3=B8v?= Date: Tue, 15 Oct 2024 12:45:16 +0200 Subject: [PATCH 14/17] fix(db_join): Account for extra args such as `copy` --- R/db_joins.R | 51 +++++++++++++++++++------------ tests/testthat/test-filter_keys.R | 2 +- 2 files changed, 33 insertions(+), 20 deletions(-) diff --git a/R/db_joins.R b/R/db_joins.R index 9a1282f1..cb706848 100644 --- a/R/db_joins.R +++ b/R/db_joins.R @@ -63,7 +63,7 @@ inner_join.tbl_sql <- function(x, y, by = NULL, ...) { } # Prepare the combined join - out <- do.call(dplyr::inner_join, args = join_args(.dots)) + out <- do.call(dplyr::inner_join, args = join_args(x, y, by, .dots)) out$lazy_query$vars <- join_na_select_fix(out$lazy_query$vars, .dots$na_by) return(out) @@ -79,7 +79,7 @@ left_join.tbl_sql <- function(x, y, by = NULL, ...) { return(NextMethod("left_join")) } - out <- do.call(dplyr::left_join, args = join_args(.dots)) + out <- do.call(dplyr::left_join, args = join_args(x, y, by, .dots)) out$lazy_query$vars <- join_na_select_fix(out$lazy_query$vars, .dots$na_by) return(out) @@ -95,7 +95,7 @@ right_join.tbl_sql <- function(x, y, by = NULL, ...) 
{ return(NextMethod("right_join")) } - out <- do.call(dplyr::right_join, args = join_args(.dots)) + out <- do.call(dplyr::right_join, args = join_args(x, y, by, .dots)) out$lazy_query$vars <- join_na_select_fix(out$lazy_query$vars, .dots$na_by, right = TRUE) return(out) @@ -112,7 +112,7 @@ full_join.tbl_sql <- function(x, y, by = NULL, ...) { return(NextMethod("full_join")) } - out <- do.call(dplyr::full_join, args = join_args(.dots)) + out <- do.call(dplyr::full_join, args = join_args(x, y, by, .dots)) out$lazy_query$vars <- join_na_select_fix(out$lazy_query$vars, .dots$na_by) return(out) @@ -129,7 +129,7 @@ semi_join.tbl_sql <- function(x, y, by = NULL, ...) { return(NextMethod("semi_join")) } - out <- do.call(dplyr::semi_join, args = join_args(.dots)) + out <- do.call(dplyr::semi_join, args = join_args(x, y, by, .dots)) return(out) } @@ -145,7 +145,7 @@ anti_join.tbl_sql <- function(x, y, by = NULL, ...) { return(NextMethod("anti_join")) } - out <- do.call(dplyr::anti_join, args = join_args(.dots)) + out <- do.call(dplyr::anti_join, args = join_args(x, y, by, .dots)) return(out) } @@ -197,14 +197,11 @@ join_warn_experimental <- function() { #' @param .dots (`list`) \cr #' Arguments passed to the `*_join` function. #' @noRd -join_args <- function(.dots) { - # Grab the environment of the caller and add the dot args - args <- append(as.list(rlang::caller_env()), .dots) +join_args <- function(x, y, by, .dots) { # Remove the na matching args, and let join_na_sql combine the `by` and `na_by` statements - args$na_by <- NULL - args$na_matches <- NULL - args$by <- join_na_sql(args$x, args$y, by = args$by, na_by = .dots$na_by) + by <- join_na_sql(x, y, by, .dots) + args <- append(list(x = x, y = y, by = by), purrr::discard_at(.dots, c("na_by", "na_matches"))) return(args) } @@ -254,7 +251,14 @@ join_merger <- function(by, na_by) { #' A `dplyr_join_by` object to join by such that "NA" are matched with "NA" given the columns listed in `by` and #' `na_by`. 
#' @noRd -join_na_sql <- function(x, y, by = NULL, na_by = NULL) { +join_na_sql <- function(x, y, by = NULL, .dots = NULL) { + + # Early return if no na_by statement is given + if (is.null(.dots$na_by)) { + return(by) + } else { + na_by <- .dots$na_by + } # Check arguments checkmate::assert( @@ -270,17 +274,24 @@ join_na_sql <- function(x, y, by = NULL, na_by = NULL) { # Convert to dplyr_join_by if not already if (!is.null(by) && !inherits(by, "dplyr_join_by")) { - by <- dplyr::join_by(!!by) + by <- dplyr::join_by(!!!by) } if (!is.null(na_by) && !inherits(na_by, "dplyr_join_by")) { - na_by <- dplyr::join_by(!!na_by) + na_by <- dplyr::join_by(!!!na_by) } - combined_join <- join_merger(by, na_by) - # Get the translation for matching the na_by component of the join - na_subquery <- dbplyr::remote_query(dplyr::inner_join(x, y, by = combined_join, na_matches = "na")) + subquery_args <- purrr::discard_at(.dots, "na_by") |> + modifyList( + list( + x = x, + y = y, + by = join_merger(by, na_by), + na_matches = "na" + ) + ) + na_subquery <- dbplyr::remote_query(do.call(dplyr::inner_join, args = subquery_args)) # Determine the NA matching statement by extracting from the translated query. # E.g. on RSQlite, the keyword "IS" checks if arguments are identical @@ -317,7 +328,9 @@ join_na_sql <- function(x, y, by = NULL, na_by = NULL) { #' A `tibble` with the `vars` component of the `lazy_query` fixed to remove doubly selected columns. 
#' @noRd join_na_select_fix <- function(vars, na_by, right = FALSE) { - if (!inherits(na_by, "dplyr_join_by")) na_by <- dplyr::join_by(!!na_by) + if (is.null(na_by)) return(vars) + + if (!inherits(na_by, "dplyr_join_by")) na_by <- dplyr::join_by(!!!na_by) # All equality joins in `na_by` are incorrectly translated doubly_selected_columns <- na_by |> diff --git a/tests/testthat/test-filter_keys.R b/tests/testthat/test-filter_keys.R index ad6952cc..76b80b4a 100644 --- a/tests/testthat/test-filter_keys.R +++ b/tests/testthat/test-filter_keys.R @@ -55,7 +55,7 @@ test_that("filter_keys() works with copy = TRUE", { dplyr::collect()) # The above filter_keys with `copy = TRUE` generates a dbplyr_### table. - # We manually remove this since we expect it. If more arrise, we will get an error. + # We manually remove this since we expect it. If more arise, we will get an error. DBI::dbRemoveTable(conn, id(utils::head(get_tables(conn, "dbplyr_"), 1))) connection_clean_up(conn) From 6908d7a005ccd11a090d28d8b29ee4a8f1256d0c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rasmus=20Skytte=20Randl=C3=B8v?= Date: Tue, 15 Oct 2024 12:50:06 +0200 Subject: [PATCH 15/17] test(db_join): Remove duplicated test --- tests/testthat/test-db_joins.R | 33 ++------------------------------- 1 file changed, 2 insertions(+), 31 deletions(-) diff --git a/tests/testthat/test-db_joins.R b/tests/testthat/test-db_joins.R index 9670d44a..ed4abe77 100644 --- a/tests/testthat/test-db_joins.R +++ b/tests/testthat/test-db_joins.R @@ -95,36 +95,7 @@ test_that("*_join() works with character `by` and `na_by`", { }) -test_that("*_join() works with `dplyr::join_by()`", { - for (conn in get_test_conns()) { - - # Define two test datasets - x <- get_table(conn, "__mtcars") |> - dplyr::select(name, mpg, cyl, hp, vs, am, gear, carb) - - y <- get_table(conn, "__mtcars") |> - dplyr::select(name, drat, wt, qsec) - - - # Test the implemented joins - q <- dplyr::left_join(x, y, by = dplyr::join_by(x$name == y$name)) |> 
dplyr::collect() - qr <- dplyr::left_join(dplyr::collect(x), dplyr::collect(y), by = dplyr::join_by(x$name == y$name)) - expect_equal(q, qr) - - q <- dplyr::right_join(x, y, by = dplyr::join_by(x$name == y$name)) |> dplyr::collect() - qr <- dplyr::right_join(dplyr::collect(x), dplyr::collect(y), by = dplyr::join_by(x$name == y$name)) - expect_equal(q, qr) - - q <- dplyr::inner_join(x, y, by = dplyr::join_by(x$name == y$name)) |> dplyr::collect() - qr <- dplyr::inner_join(dplyr::collect(x), dplyr::collect(y), by = dplyr::join_by(x$name == y$name)) - expect_equal(q, qr) - - connection_clean_up(conn) - } -}) - - -test_that("*_join() does not break any dplyr joins", { +test_that("*_join() does not break any dplyr joins when no `na_by` argument is given", { for (conn in get_test_conns()) { # Define two test datasets @@ -188,7 +159,7 @@ test_that("*_join() does not break any dplyr joins", { }) -test_that("*_join() with only na_by works as dplyr joins", { +test_that("*_join() with only `na_by` works identically as if `by` was given instead when no data is NA", { for (conn in get_test_conns()) { # Define two test datasets From e2daf77d23dd60b9f693517c0ae609335e3a1217 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rasmus=20Skytte=20Randl=C3=B8v?= Date: Tue, 15 Oct 2024 12:55:28 +0200 Subject: [PATCH 16/17] chore(db_joins): Add visible binding to `modifyList` --- R/db_joins.R | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/R/db_joins.R b/R/db_joins.R index cb706848..6bf5e3f2 100644 --- a/R/db_joins.R +++ b/R/db_joins.R @@ -228,7 +228,7 @@ join_merger <- function(by, na_by) { combined_join <- list( "exprs" = c(purrr::pluck(by, "exprs"), purrr::pluck(na_by, "exprs")) ) |> - modifyList( + utils::modifyList( purrr::map2(purrr::discard_at(by, "exprs"), purrr::discard_at(na_by, "exprs"), ~ c(.x, .y)) ) class(combined_join) <- "dplyr_join_by" @@ -283,7 +283,7 @@ join_na_sql <- function(x, y, by = NULL, .dots = NULL) { # Get the translation for matching the na_by 
component of the join subquery_args <- purrr::discard_at(.dots, "na_by") |> - modifyList( + utils::modifyList( list( x = x, y = y, From c84dffaea741255067a8ddb941e4cba91b03facb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rasmus=20Skytte=20Randl=C3=B8v?= Date: Tue, 15 Oct 2024 13:45:55 +0200 Subject: [PATCH 17/17] debug(db_joins): Inspect ON query --- R/db_joins.R | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/R/db_joins.R b/R/db_joins.R index 6bf5e3f2..700a256a 100644 --- a/R/db_joins.R +++ b/R/db_joins.R @@ -299,7 +299,12 @@ join_na_sql <- function(x, y, by = NULL, .dots = NULL) { na_matching <- na_subquery |> stringr::str_remove_all(stringr::fixed("\n")) |> # Remove newlines from the formatted query stringr::str_replace_all(r"{\s{2,}}", " ") |> # Remove multiple spaces from the formatted query - stringr::str_extract(r"{(?<=ON \().*(?=\))}") |> # Extract the contents of the ON statement + stringr::str_extract(r"{(?<=ON \().*(?=\))}") # Extract the contents of the ON statement + + print("ON subquery") + print(na_matching) + + na_matching <- na_matching |> stringr::str_extract(pattern = r"{(?:["'`]\s)([\w\s]+)(?:\s["'`])}", group = 1) # First non quoted word(s) # Replace NA equals with NA matching statement