Skip to content

Commit

Permalink
Add support for viewer-based credentials for Databricks & Cortex.
Browse files Browse the repository at this point in the history
This commit brings support for Posit Connect's "viewer-based
credentials" feature [0] to the Databricks and Cortex chatbots. Similar
to the recent work in `odbc` [1], the way this is exposed to R users is
to require them to pass a Shiny session argument to `chat_databricks()`
or `chat_cortex()`.

Checks for viewer-based credentials are designed to fall back gracefully
to existing authentication methods. This is intended to allow users to
-- for example -- develop and test a Shiny app that uses Databricks or
Snowflake credentials in desktop Positron/RStudio or Posit Workbench and
deploy it with no code changes to Connect.

Most of the actual work is outsourced to a new shared package,
`connectcreds` [2].

Note that this commit also brings the internal auth handling for
Snowflake much closer to Databricks, notably by making
`cortex_credentials()` internal and analogous to `databricks_token()`.

Unit tests are included.

[0]: https://docs.posit.co/connect/user/oauth-integrations/
[1]: r-dbi/odbc#853
[2]: https://github.com/posit-dev/connectcreds/

Signed-off-by: Aaron Jacobs <[email protected]>
  • Loading branch information
atheriel committed Nov 27, 2024
1 parent ad380e6 commit c0ea6ad
Show file tree
Hide file tree
Showing 10 changed files with 279 additions and 41 deletions.
4 changes: 3 additions & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ Imports:
Suggests:
base64enc,
bslib,
connectcreds,
curl (>= 6.0.1),
gitcreds,
knitr,
Expand All @@ -39,7 +40,8 @@ VignetteBuilder:
knitr
Remotes:
r-lib/httr2,
jcheng5/shinychat
jcheng5/shinychat,
posit-dev/connectcreds
Config/Needs/website: tidyverse/tidytemplate, rmarkdown
Config/testthat/edition: 3
Config/testthat/parallel: true
Expand Down
1 change: 0 additions & 1 deletion NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ export(chat_perplexity)
export(content_image_file)
export(content_image_plot)
export(content_image_url)
export(cortex_credentials)
export(create_tool_def)
export(interpolate)
export(interpolate_file)
Expand Down
76 changes: 51 additions & 25 deletions R/provider-cortex.R
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,16 @@ NULL
#' previous messages. Nor does it support registering tools, and attempting to
#' do so will result in an error.
#'
#' `chat_cortex()` picks up the following ambient Snowflake credentials:
#'
#' - A static OAuth token defined via the `SNOWFLAKE_TOKEN` environment
#' variable.
#' - Key-pair authentication credentials defined via the `SNOWFLAKE_USER` and
#' `SNOWFLAKE_PRIVATE_KEY` (which can be a PEM-encoded private key or a path
#' to one) environment variables.
#' - Posit Workbench-managed Snowflake credentials for the corresponding
#' `account`.
#'
#' @param account A Snowflake [account identifier](https://docs.snowflake.com/en/user-guide/admin-account-identifier),
#' e.g. `"testorg-test_account"`.
#' @param credentials A list of authentication headers to pass into
Expand All @@ -32,6 +42,7 @@ NULL
#' @param model_file Path to a semantic model file stored in a Snowflake Stage,
#' or `NULL` when using `model_spec` instead.
#' @inheritParams chat_openai
#' @inheritParams chat_databricks
#' @inherit chat_openai return
#' @family chatbots
#' @examplesIf elmer:::cortex_credentials_exist()
Expand All @@ -41,40 +52,53 @@ NULL
#' chat$chat("What questions can I ask?")
#' @export
chat_cortex <- function(account = Sys.getenv("SNOWFLAKE_ACCOUNT"),
credentials = cortex_credentials,
credentials = NULL,
model_spec = NULL,
model_file = NULL,
api_args = list(),
echo = c("none", "text", "all")) {
echo = c("none", "text", "all"),
session = NULL) {
check_string(account, allow_empty = FALSE)
check_string(model_spec, allow_empty = FALSE, allow_null = TRUE)
check_string(model_file, allow_empty = FALSE, allow_null = TRUE)
check_exclusive(model_spec, model_file)
echo <- check_echo(echo)
if (!is.null(session)) {
check_installed("connectcreds", "for viewer-based authentication")
if (!connectcreds::has_viewer_token(session, snowflake_url(account))) {
session <- NULL
}
}

if (is_list(credentials)) {
static_credentials <- force(credentials)
credentials <- function(account) static_credentials
}
check_function(credentials)
check_function(credentials, allow_null = TRUE)

provider <- ProviderCortex(
account = account,
credentials = credentials,
model_spec = model_spec,
model_file = model_file,
extra_args = api_args
extra_args = api_args,
session = session
)

Chat$new(provider = provider, turns = NULL, echo = echo)
}

# Build the base HTTPS URL for a Snowflake account identifier, e.g.
# "testorg-test_account" -> "https://testorg-test_account.snowflakecomputing.com".
snowflake_url <- function(account) {
  host <- paste0(account, ".snowflakecomputing.com")
  paste0("https://", host)
}

ProviderCortex <- new_class(
"ProviderCortex",
parent = Provider,
constructor = function(account, credentials, model_spec = NULL,
model_file = NULL, extra_args = list()) {
base_url <- paste0("https://", account, ".snowflakecomputing.com")
model_file = NULL, extra_args = list(),
session = NULL) {
base_url <- snowflake_url(account)
extra_args <- compact(list2(
semantic_model = model_spec,
semantic_model_file = model_file,
Expand All @@ -88,8 +112,9 @@ ProviderCortex <- new_class(
},
properties = list(
account = prop_string(),
credentials = class_function,
extra_args = class_list
credentials = class_function | NULL,
extra_args = class_list,
session = class_list | NULL
)
)

Expand All @@ -110,9 +135,12 @@ method(chat_request, ProviderCortex) <- function(provider,

req <- request(provider@base_url)
req <- req_url_path_append(req, "/api/v2/cortex/analyst/message")
req <- httr2::req_headers(req,
!!!provider@credentials(provider@account), .redact = "Authorization"
creds <- cortex_credentials(
provider@account,
provider@credentials,
provider@session
)
req <- httr2::req_headers(req, !!!creds, .redact = "Authorization")
req <- req_retry(req, max_tries = 2)
req <- req_timeout(req, 60)

Expand Down Expand Up @@ -348,21 +376,19 @@ cortex_credentials_exist <- function(...) {
tryCatch(is_list(cortex_credentials(...)), error = function(e) FALSE)
}

#' @details
#' `cortex_credentials()` picks up the following ambient Snowflake credentials:
#'
#' - A static OAuth token defined via the `SNOWFLAKE_TOKEN` environment
#' variable.
#' - Key-pair authentication credentials defined via the `SNOWFLAKE_USER` and
#' `SNOWFLAKE_PRIVATE_KEY` (which can be a PEM-encoded private key or a path
#' to one) environment variables.
#' - Posit Workbench-managed Snowflake credentials for the corresponding
#' `account`.
#'
#' @inheritParams chat_cortex
#' @export
#' @rdname chat_cortex
cortex_credentials <- function(account = Sys.getenv("SNOWFLAKE_ACCOUNT")) {
cortex_credentials <- function(account = Sys.getenv("SNOWFLAKE_ACCOUNT"),
credentials = NULL,
session = NULL) {
# Session credentials take precedence over static credentials.
if (!is.null(session)) {
return(connectcreds::connect_viewer_token(session, snowflake_url(account)))
}

# User-supplied credentials.
if (!is.null(credentials)) {
return(credentials(account))
}

token <- Sys.getenv("SNOWFLAKE_TOKEN")
if (nchar(token) != 0) {
return(
Expand Down
30 changes: 25 additions & 5 deletions R/provider-databricks.R
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@
#' - `databricks-meta-llama-3-1-405b-instruct`
#' @param token An authentication token for the Databricks workspace, or
#' `NULL` to use ambient credentials.
#' @param session A Shiny session object, when using viewer-based credentials on
#' Posit Connect.
#' @inheritParams chat_openai
#' @inherit chat_openai return
#' @export
Expand All @@ -44,28 +46,39 @@ chat_databricks <- function(workspace = databricks_workspace(),
model = NULL,
token = NULL,
api_args = list(),
echo = c("none", "text", "all")) {
echo = c("none", "text", "all"),
session = NULL) {
check_string(workspace, allow_empty = FALSE)
check_string(token, allow_empty = FALSE, allow_null = TRUE)
model <- set_default(model, "databricks-dbrx-instruct")
turns <- normalize_turns(turns, system_prompt)
echo <- check_echo(echo)
if (!is.null(session)) {
check_installed("connectcreds", "for viewer-based authentication")
if (!connectcreds::has_viewer_token(session, workspace)) {
session <- NULL
}
}
provider <- ProviderDatabricks(
base_url = workspace,
model = model,
extra_args = api_args,
token = token,
# Databricks APIs use bearer tokens, not API keys, but we need to pass an
# empty string here anyway to make S7::validate() happy.
api_key = ""
api_key = "",
session = session
)
Chat$new(provider = provider, turns = turns, echo = echo)
}

ProviderDatabricks <- new_class(
"ProviderDatabricks",
parent = ProviderOpenAI,
properties = list(token = prop_string(allow_null = TRUE))
properties = list(
token = prop_string(allow_null = TRUE),
session = class_list | NULL
)
)

method(chat_request, ProviderDatabricks) <- function(provider,
Expand All @@ -80,7 +93,7 @@ method(chat_request, ProviderDatabricks) <- function(provider,
# `/serving-endpoints/<model>/invocations`.
req <- req_url_path_append(req, "/serving-endpoints/chat/completions")
req <- req_auth_bearer_token(req,
databricks_token(provider@base_url, provider@token)
databricks_token(provider@base_url, provider@token, provider@session)
)
req <- req_retry(req, max_tries = 2)
req <- req_error(req, body = function(resp) {
Expand Down Expand Up @@ -165,9 +178,16 @@ databricks_workspace <- function() {

# Try various ways to get Databricks credentials. This implements a subset of
# the "Databricks client unified authentication" model.
databricks_token <- function(workspace = databricks_workspace(), token = NULL) {
databricks_token <- function(workspace = databricks_workspace(),
token = NULL,
session = NULL) {
host <- gsub("https://|/$", "", workspace)

# Session credentials take precedence over static credentials.
if (!is.null(session)) {
return(connectcreds::connect_viewer_token(session, workspace))
}

# Otherwise, an explicit bearer token takes precedence over all remaining
# (non-session) credential sources.
token <- token %||% Sys.getenv("DATABRICKS_TOKEN")
if (nchar(token)) {
Expand Down
16 changes: 8 additions & 8 deletions man/chat_cortex.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 5 additions & 1 deletion man/chat_databricks.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

51 changes: 51 additions & 0 deletions tests/testthat/_snaps/provider-cortex.md
Original file line number Diff line number Diff line change
Expand Up @@ -77,3 +77,54 @@
[1] "@my_db.my_schema.my_stage/model.yaml"

# the session parameter is ignored when not on Connect

Code
. <- chat_cortex("testorg-test_account", model_file = "model.yaml", session = session)
Message
! Ignoring the `session` parameter.
i Viewer-based credentials are only available when running on Connect.

# missing viewer credentials generate errors on Connect

Code
. <- chat_cortex("testorg-test_account", model_file = "model.yaml", session = session)
Condition
Error in `connectcreds::has_viewer_token()`:
! Cannot fetch viewer-based credentials for the current Shiny session.
Caused by error in `connect_viewer_token()`:
! Viewer-based credentials are not supported by this version of Connect.

# token exchange requests to Connect look correct

Code
list(url = req$url, headers = req$headers, body = req$body$data)
Output
$url
[1] "localhost:3030/__api__/v1/oauth/integrations/credentials"
$headers
$headers$Authorization
[1] "Key key"
$headers$Accept
[1] "application/json"
attr(,"redact")
[1] "Authorization"
$body
$body$grant_type
[1] "urn%3Aietf%3Aparams%3Aoauth%3Agrant-type%3Atoken-exchange"
$body$subject_token
[1] "user-token"
$body$subject_token_type
[1] "urn%3Aposit%3Aconnect%3Auser-session-token"
$body$resource
[1] "https%3A%2F%2Ftestorg-test_account.snowflakecomputing.com"

Loading

0 comments on commit c0ea6ad

Please sign in to comment.