From c0ea6ad9f57a7ebb44c5ce79dd235c60cfc5ce0a Mon Sep 17 00:00:00 2001 From: Aaron Jacobs Date: Wed, 27 Nov 2024 17:11:24 -0500 Subject: [PATCH] Add support for viewer-based credentials for Databricks & Cortex. This commit brings support for Posit Connect's "viewer-based credentials" feature [0] to the Databricks and Cortex chatbots. Similar to the recent work in `odbc` [1], the way this is exposed to R users is to require them to pass a Shiny session argument to `chat_databricks()` or `chat_cortex()`. Checks for viewer-based credentials are designed to fall back gracefully to existing authentication methods. This is intended to allow users to -- for example -- develop and test a Shiny app that uses Databricks or Snowflake credentials in desktop Positron/RStudio or Posit Workbench and deploy it with no code changes to Connect. Most of the actual work is outsourced to a new shared package, `connectcreds` [2]. Note that this commit also brings the internal auth handling for Snowflake much closer to Databricks, notably by making `cortex_credentials()` internal and analogous to `databricks_token()`. Unit tests are included. [0]: https://docs.posit.co/connect/user/oauth-integrations/ [1]: https://github.com/r-dbi/odbc/pull/853 [2]: https://github.com/posit-dev/connectcreds/ Signed-off-by: Aaron Jacobs --- DESCRIPTION | 4 +- NAMESPACE | 1 - R/provider-cortex.R | 76 +++++++++++++------- R/provider-databricks.R | 30 ++++++-- man/chat_cortex.Rd | 16 ++--- man/chat_databricks.Rd | 6 +- tests/testthat/_snaps/provider-cortex.md | 51 +++++++++++++ tests/testthat/_snaps/provider-databricks.md | 54 ++++++++++++++ tests/testthat/test-provider-cortex.R | 44 ++++++++++++ tests/testthat/test-provider-databricks.R | 38 ++++++++++ 10 files changed, 279 insertions(+), 41 deletions(-) diff --git a/DESCRIPTION b/DESCRIPTION index 289e06f9..358b8154 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -24,6 +24,7 @@ Imports: Suggests: base64enc, bslib, + connectcreds, curl (>= 6.0.1), gitcreds, knitr, @@ -39,7 +40,8 @@ VignetteBuilder: knitr Remotes: r-lib/httr2, - jcheng5/shinychat + jcheng5/shinychat, + posit-dev/connectcreds Config/Needs/website: tidyverse/tidytemplate, rmarkdown Config/testthat/edition: 3 Config/testthat/parallel: true diff --git a/NAMESPACE b/NAMESPACE index d5fb98a9..d38564ba 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -24,7 +24,6 @@ export(chat_perplexity) export(content_image_file) export(content_image_plot) export(content_image_url) -export(cortex_credentials) export(create_tool_def) export(interpolate) export(interpolate_file) diff --git a/R/provider-cortex.R b/R/provider-cortex.R index ae9c4da5..d36144ac 100644 --- a/R/provider-cortex.R +++ b/R/provider-cortex.R @@ -20,6 +20,16 @@ NULL #' previous messages. Nor does it support registering tools, and attempting to #' do so will result in an error. #' +#' `chat_cortex()` picks up the following ambient Snowflake credentials: +#' +#' - A static OAuth token defined via the `SNOWFLAKE_TOKEN` environment +#' variable. +#' - Key-pair authentication credentials defined via the `SNOWFLAKE_USER` and +#' `SNOWFLAKE_PRIVATE_KEY` (which can be a PEM-encoded private key or a path +#' to one) environment variables. +#' - Posit Workbench-managed Snowflake credentials for the corresponding +#' `account`. +#' #' @param account A Snowflake [account identifier](https://docs.snowflake.com/en/user-guide/admin-account-identifier), #' e.g. `"testorg-test_account"`. #' @param credentials A list of authentication headers to pass into @@ -32,6 +42,7 @@ NULL #' @param model_file Path to a semantic model file stored in a Snowflake Stage, #' or `NULL` when using `model_spec` instead. #' @inheritParams chat_openai +#' @inheritParams chat_databricks #' @inherit chat_openai return #' @family chatbots #' @examplesIf elmer:::cortex_credentials_exist() @@ -41,40 +52,53 @@ NULL #' chat$chat("What questions can I ask?") #' @export chat_cortex <- function(account = Sys.getenv("SNOWFLAKE_ACCOUNT"), - credentials = cortex_credentials, + credentials = NULL, model_spec = NULL, model_file = NULL, api_args = list(), - echo = c("none", "text", "all")) { + echo = c("none", "text", "all"), + session = NULL) { check_string(account, allow_empty = FALSE) check_string(model_spec, allow_empty = FALSE, allow_null = TRUE) check_string(model_file, allow_empty = FALSE, allow_null = TRUE) check_exclusive(model_spec, model_file) echo <- check_echo(echo) + if (!is.null(session)) { + check_installed("connectcreds", "for viewer-based authentication") + if (!connectcreds::has_viewer_token(session, snowflake_url(account))) { + session <- NULL + } + } if (is_list(credentials)) { static_credentials <- force(credentials) credentials <- function(account) static_credentials } - check_function(credentials) + check_function(credentials, allow_null = TRUE) provider <- ProviderCortex( account = account, credentials = credentials, model_spec = model_spec, model_file = model_file, - extra_args = api_args + extra_args = api_args, + session = session ) Chat$new(provider = provider, turns = NULL, echo = echo) } +snowflake_url <- function(account) { + paste0("https://", account, ".snowflakecomputing.com") +} + ProviderCortex <- new_class( "ProviderCortex", parent = Provider, constructor = function(account, credentials, model_spec = NULL, - model_file = NULL, extra_args = list()) { - base_url <- paste0("https://", account, ".snowflakecomputing.com") + model_file = NULL, extra_args = list(), + session = NULL) { + base_url <- snowflake_url(account) extra_args <- compact(list2( semantic_model = model_spec, semantic_model_file = model_file, @@ -88,8 +112,9 @@ ProviderCortex <- new_class( }, properties = list( account = prop_string(), - credentials = class_function, - extra_args = class_list + credentials = class_function | NULL, + extra_args = class_list, + session = class_list | NULL ) ) @@ -110,9 +135,12 @@ method(chat_request, ProviderCortex) <- function(provider, req <- request(provider@base_url) req <- req_url_path_append(req, "/api/v2/cortex/analyst/message") - req <- httr2::req_headers(req, - !!!provider@credentials(provider@account), .redact = "Authorization" + creds <- cortex_credentials( + provider@account, + provider@credentials, + provider@session ) + req <- httr2::req_headers(req, !!!creds, .redact = "Authorization") req <- req_retry(req, max_tries = 2) req <- req_timeout(req, 60) @@ -348,21 +376,19 @@ cortex_credentials_exist <- function(...) { tryCatch(is_list(cortex_credentials(...)), error = function(e) FALSE) } -#' @details -#' `cortex_credentials()` picks up the following ambient Snowflake credentials: -#' -#' - A static OAuth token defined via the `SNOWFLAKE_TOKEN` environment -#' variable. -#' - Key-pair authentication credentials defined via the `SNOWFLAKE_USER` and -#' `SNOWFLAKE_PRIVATE_KEY` (which can be a PEM-encoded private key or a path -#' to one) environment variables. -#' - Posit Workbench-managed Snowflake credentials for the corresponding -#' `account`. -#' -#' @inheritParams chat_cortex -#' @export -#' @rdname chat_cortex -cortex_credentials <- function(account = Sys.getenv("SNOWFLAKE_ACCOUNT")) { +cortex_credentials <- function(account = Sys.getenv("SNOWFLAKE_ACCOUNT"), + credentials = NULL, + session = NULL) { + # Session credentials take precedence over static credentials. + if (!is.null(session)) { + return(connectcreds::connect_viewer_token(session, snowflake_url(account))) + } + + # User-supplied credentials. + if (!is.null(credentials)) { + return(credentials(account)) + } + token <- Sys.getenv("SNOWFLAKE_TOKEN") if (nchar(token) != 0) { return( diff --git a/R/provider-databricks.R b/R/provider-databricks.R index df602b16..1e4e61fc 100644 --- a/R/provider-databricks.R +++ b/R/provider-databricks.R @@ -35,6 +35,8 @@ #' - `databricks-meta-llama-3-1-405b-instruct` #' @param token An authentication token for the Databricks workspace, or #' `NULL` to use ambient credentials. +#' @param session A Shiny session object, when using viewer-based credentials on +#' Posit Connect. #' @inheritParams chat_openai #' @inherit chat_openai return #' @export @@ -44,12 +46,19 @@ chat_databricks <- function(workspace = databricks_workspace(), model = NULL, token = NULL, api_args = list(), - echo = c("none", "text", "all")) { + echo = c("none", "text", "all"), + session = NULL) { check_string(workspace, allow_empty = FALSE) check_string(token, allow_empty = FALSE, allow_null = TRUE) model <- set_default(model, "databricks-dbrx-instruct") turns <- normalize_turns(turns, system_prompt) echo <- check_echo(echo) + if (!is.null(session)) { + check_installed("connectcreds", "for viewer-based authentication") + if (!connectcreds::has_viewer_token(session, workspace)) { + session <- NULL + } + } provider <- ProviderDatabricks( base_url = workspace, model = model, @@ -57,7 +66,8 @@ chat_databricks <- function(workspace = databricks_workspace(), token = token, # Databricks APIs use bearer tokens, not API keys, but we need to pass an # empty string here anyway to make S7::validate() happy. - api_key = "" + api_key = "", + session = session ) Chat$new(provider = provider, turns = turns, echo = echo) } @@ -65,7 +75,10 @@ chat_databricks <- function(workspace = databricks_workspace(), ProviderDatabricks <- new_class( "ProviderDatabricks", parent = ProviderOpenAI, - properties = list(token = prop_string(allow_null = TRUE)) + properties = list( + token = prop_string(allow_null = TRUE), + session = class_list | NULL + ) ) method(chat_request, ProviderDatabricks) <- function(provider, @@ -80,7 +93,7 @@ method(chat_request, ProviderDatabricks) <- function(provider, # `/serving-endpoints//invocations`. req <- req_url_path_append(req, "/serving-endpoints/chat/completions") req <- req_auth_bearer_token(req, - databricks_token(provider@base_url, provider@token) + databricks_token(provider@base_url, provider@token, provider@session) ) req <- req_retry(req, max_tries = 2) req <- req_error(req, body = function(resp) { @@ -165,9 +178,16 @@ databricks_workspace <- function() { # Try various ways to get Databricks credentials. This implements a subset of # the "Databricks client unified authentication" model. -databricks_token <- function(workspace = databricks_workspace(), token = NULL) { +databricks_token <- function(workspace = databricks_workspace(), + token = NULL, + session = NULL) { host <- gsub("https://|/$", "", workspace) + # Session credentials take precedence over static credentials. + if (!is.null(session)) { + return(connectcreds::connect_viewer_token(session, workspace)) + } + # An explicit bearer token takes precedence over everything else. token <- token %||% Sys.getenv("DATABRICKS_TOKEN") if (nchar(token)) { diff --git a/man/chat_cortex.Rd b/man/chat_cortex.Rd index 69e60856..e000646a 100644 --- a/man/chat_cortex.Rd +++ b/man/chat_cortex.Rd @@ -2,19 +2,17 @@ % Please edit documentation in R/provider-cortex.R \name{chat_cortex} \alias{chat_cortex} -\alias{cortex_credentials} \title{Create a chatbot that speaks to the Snowflake Cortex Analyst} \usage{ chat_cortex( account = Sys.getenv("SNOWFLAKE_ACCOUNT"), - credentials = cortex_credentials, + credentials = NULL, model_spec = NULL, model_file = NULL, api_args = list(), - echo = c("none", "text", "all") + echo = c("none", "text", "all"), + session = NULL ) - -cortex_credentials(account = Sys.getenv("SNOWFLAKE_ACCOUNT")) } \arguments{ \item{account}{A Snowflake \href{https://docs.snowflake.com/en/user-guide/admin-account-identifier}{account identifier}, @@ -44,6 +42,9 @@ the console). } Note this only affects the \code{chat()} method.} + +\item{session}{A Shiny session object, when using viewer-based credentials on +Posit Connect.} } \value{ A \link{Chat} object. @@ -60,9 +61,8 @@ reference to an existing file in a Snowflake Stage. Note that Cortex does not support multi-turn, so it will not remember previous messages. Nor does it support registering tools, and attempting to do so will result in an error. -} -\details{ -\code{cortex_credentials()} picks up the following ambient Snowflake credentials: + +\code{chat_cortex()} picks up the following ambient Snowflake credentials: \itemize{ \item A static OAuth token defined via the \code{SNOWFLAKE_TOKEN} environment variable. diff --git a/man/chat_databricks.Rd b/man/chat_databricks.Rd index c9cf7fbc..5879c9a1 100644 --- a/man/chat_databricks.Rd +++ b/man/chat_databricks.Rd @@ -11,7 +11,8 @@ chat_databricks( model = NULL, token = NULL, api_args = list(), - echo = c("none", "text", "all") + echo = c("none", "text", "all"), + session = NULL ) } \arguments{ @@ -51,6 +52,9 @@ the console). } Note this only affects the \code{chat()} method.} + +\item{session}{A Shiny session object, when using viewer-based credentials on +Posit Connect.} } \value{ A \link{Chat} object. diff --git a/tests/testthat/_snaps/provider-cortex.md b/tests/testthat/_snaps/provider-cortex.md index 2860b513..f571f66a 100644 --- a/tests/testthat/_snaps/provider-cortex.md +++ b/tests/testthat/_snaps/provider-cortex.md @@ -77,3 +77,54 @@ [1] "@my_db.my_schema.my_stage/model.yaml" +# the session parameter is ignored when not on Connect + + Code + . <- chat_cortex("testorg-test_account", model_file = "model.yaml", session = session) + Message + ! Ignoring the `session` parameter. + i Viewer-based credentials are only available when running on Connect. + +# missing viewer credentials generate errors on Connect + + Code + . <- chat_cortex("testorg-test_account", model_file = "model.yaml", session = session) + Condition + Error in `connectcreds::has_viewer_token()`: + ! Cannot fetch viewer-based credentials for the current Shiny session. + Caused by error in `connect_viewer_token()`: + ! Viewer-based credentials are not supported by this version of Connect. + +# token exchange requests to Connect look correct + + Code + list(url = req$url, headers = req$headers, body = req$body$data) + Output + $url + [1] "localhost:3030/__api__/v1/oauth/integrations/credentials" + + $headers + $headers$Authorization + [1] "Key key" + + $headers$Accept + [1] "application/json" + + attr(,"redact") + [1] "Authorization" + + $body + $body$grant_type + [1] "urn%3Aietf%3Aparams%3Aoauth%3Agrant-type%3Atoken-exchange" + + $body$subject_token + [1] "user-token" + + $body$subject_token_type + [1] "urn%3Aposit%3Aconnect%3Auser-session-token" + + $body$resource + [1] "https%3A%2F%2Ftestorg-test_account.snowflakecomputing.com" + + + diff --git a/tests/testthat/_snaps/provider-databricks.md b/tests/testthat/_snaps/provider-databricks.md index 01411386..cec4530d 100644 --- a/tests/testthat/_snaps/provider-databricks.md +++ b/tests/testthat/_snaps/provider-databricks.md @@ -32,3 +32,57 @@ +# the session parameter is ignored when not on Connect + + Code + . <- chat_databricks(session = session) + Message + Using model = "databricks-dbrx-instruct". + ! Ignoring the `session` parameter. + i Viewer-based credentials are only available when running on Connect. + +# missing viewer credentials generate errors on Connect + + Code + . <- chat_databricks(session = session) + Message + Using model = "databricks-dbrx-instruct". + Condition + Error in `connectcreds::has_viewer_token()`: + ! Cannot fetch viewer-based credentials for the current Shiny session. + Caused by error in `connect_viewer_token()`: + ! Viewer-based credentials are not supported by this version of Connect. + +# token exchange requests to Connect look correct + + Code + list(url = req$url, headers = req$headers, body = req$body$data) + Output + $url + [1] "localhost:3030/__api__/v1/oauth/integrations/credentials" + + $headers + $headers$Authorization + [1] "Key key" + + $headers$Accept + [1] "application/json" + + attr(,"redact") + [1] "Authorization" + + $body + $body$grant_type + [1] "urn%3Aietf%3Aparams%3Aoauth%3Agrant-type%3Atoken-exchange" + + $body$subject_token + [1] "user-token" + + $body$subject_token_type + [1] "urn%3Aposit%3Aconnect%3Auser-session-token" + + $body$resource + [1] "https%3A%2F%2Fexample.cloud.databricks.com" + + + diff --git a/tests/testthat/test-provider-cortex.R b/tests/testthat/test-provider-cortex.R index 853c8c3e..332d8fa1 100644 --- a/tests/testthat/test-provider-cortex.R +++ b/tests/testthat/test-provider-cortex.R @@ -142,3 +142,47 @@ verified_queries: [] # Note: It may not be 100 percent certain this will be in the output. expect_match(resp, "semantic data model") }) + +# Auth -------------------------------------------------------------------- + +test_that("the session parameter is ignored when not on Connect", { + session <- structure(list(request = list()), class = "ShinySession") + expect_snapshot(. <- chat_cortex( + "testorg-test_account", + model_file = "model.yaml", + session = session + )) +}) + +test_that("missing viewer credentials generate errors on Connect", { + # Mock a Connect environment that *does not* support viewer-based credentials. + withr::local_envvar(RSTUDIO_PRODUCT = "CONNECT") + session <- structure(list(request = list()), class = "ShinySession") + expect_snapshot(. <- chat_cortex( + "testorg-test_account", + model_file = "model.yaml", + session = session + ), error = TRUE) +}) + +test_that("token exchange requests to Connect look correct", { + # Mock a Connect environment that supports viewer-based credentials. + withr::local_envvar( + SNOWFLAKE_ACCOUNT = "testorg-test_account", + RSTUDIO_PRODUCT = "CONNECT", + CONNECT_SERVER = "localhost:3030", + CONNECT_API_KEY = "key" + ) + local_mocked_responses(function(req) { + # Snapshot relevant fields of the outgoing request. + expect_snapshot( + list(url = req$url, headers = req$headers, body = req$body$data) + ) + response_json(body = list(access_token = "token")) + }) + session <- structure( + list(request = list(HTTP_POSIT_CONNECT_USER_SESSION_TOKEN = "user-token")), + class = "ShinySession" + ) + expect_equal(cortex_credentials(session = session), "token") +}) diff --git a/tests/testthat/test-provider-databricks.R b/tests/testthat/test-provider-databricks.R index c826eba0..e59c0b77 100644 --- a/tests/testthat/test-provider-databricks.R +++ b/tests/testthat/test-provider-databricks.R @@ -83,3 +83,41 @@ test_that("M2M authentication requests look correct", { }) expect_equal(databricks_token(), "token") }) + +test_that("the session parameter is ignored when not on Connect", { + withr::local_envvar(DATABRICKS_HOST = "https://example.cloud.databricks.com") + session <- structure(list(request = list()), class = "ShinySession") + expect_snapshot(. <- chat_databricks(session = session)) +}) + +test_that("missing viewer credentials generate errors on Connect", { + # Mock a Connect environment that *does not* support viewer-based credentials. + withr::local_envvar( + DATABRICKS_HOST = "https://example.cloud.databricks.com", + RSTUDIO_PRODUCT = "CONNECT" + ) + session <- structure(list(request = list()), class = "ShinySession") + expect_snapshot(. <- chat_databricks(session = session), error = TRUE) +}) + +test_that("token exchange requests to Connect look correct", { + # Mock a Connect environment that supports viewer-based credentials. + withr::local_envvar( + DATABRICKS_HOST = "https://example.cloud.databricks.com", + RSTUDIO_PRODUCT = "CONNECT", + CONNECT_SERVER = "localhost:3030", + CONNECT_API_KEY = "key" + ) + local_mocked_responses(function(req) { + # Snapshot relevant fields of the outgoing request. + expect_snapshot( + list(url = req$url, headers = req$headers, body = req$body$data) + ) + response_json(body = list(access_token = "token")) + }) + session <- structure( + list(request = list(HTTP_POSIT_CONNECT_USER_SESSION_TOKEN = "user-token")), + class = "ShinySession" + ) + expect_equal(databricks_token(session = session), "token") +})