Skip to content

Commit

Permalink
Cache OAuth access token per host and user pair
Browse files Browse the repository at this point in the history
  • Loading branch information
hovaesco authored and hashhar committed Feb 16, 2024
1 parent e90ba3a commit d8b3b68
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 8 deletions.
5 changes: 4 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,10 @@ the [OAuth2 authentication type](https://trino.io/docs/current/security/oauth2.h

A callback to handle the redirect url can be provided via param `redirect_auth_url_handler` of the `trino.auth.OAuth2Authentication` class. By default, it will try to launch a web browser (`trino.auth.WebBrowserRedirectHandler`) to go through the authentication flow and output the redirect url to stdout (`trino.auth.ConsoleRedirectHandler`). Multiple redirect handlers are combined using the `trino.auth.CompositeRedirectHandler` class.

The OAuth2 token will be cached either per `trino.auth.OAuth2Authentication` instance or, when keyring is installed, it will be cached within a secure backend (MacOS keychain, Windows credential locker, etc) under a key including host of the Trino connection. Keyring can be installed using `pip install 'trino[external-authentication-token-cache]'`.
The OAuth2 token will be cached either per `trino.auth.OAuth2Authentication` instance and username or, when keyring is installed, it will be cached within a secure backend (MacOS keychain, Windows credential locker, etc) under a key including host of the Trino connection. Keyring can be installed using `pip install 'trino[external-authentication-token-cache]'`.

> [!WARNING]
> If username is not specified then the OAuth2 token cache is shared and stored per host.

- DBAPI

Expand Down
33 changes: 26 additions & 7 deletions trino/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import re
import threading
import webbrowser
from typing import Any, Callable, Dict, List, Optional, Tuple
from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple
from urllib.parse import urlparse

from requests import PreparedRequest, Request, Response, Session
Expand All @@ -26,6 +26,7 @@

import trino.logging
from trino.client import exceptions
from trino.constants import HEADER_USER

logger = trino.logging.get_logger(__name__)

Expand Down Expand Up @@ -218,7 +219,8 @@ def store_token_to_cache(self, key: Optional[str], token: str) -> None:

class _OAuth2TokenInMemoryCache(_OAuth2TokenCache):
"""
In-memory token cache implementation. The token is stored per host, so multiple clients can share the same cache.
Multiple clients can share the same cache only if each connection explicitly specifies
a user otherwise the first cached token will be used to authenticate all other users.
"""

def __init__(self) -> None:
Expand All @@ -233,7 +235,7 @@ def store_token_to_cache(self, key: Optional[str], token: str) -> None:

class _OAuth2KeyRingTokenCache(_OAuth2TokenCache):
"""
Keyring Token Cache implementation
Keyring token cache implementation
"""

def __init__(self) -> None:
Expand Down Expand Up @@ -268,7 +270,7 @@ def store_token_to_cache(self, key: Optional[str], token: str) -> None:

class _OAuth2TokenBearer(AuthBase):
"""
Custom implementation of Trino Oauth2 based authorization to get the token
Custom implementation of Trino OAuth2 based authentication to get the token
"""
MAX_OAUTH_ATTEMPTS = 5
_BEARER_PREFIX = re.compile(r"bearer", flags=re.IGNORECASE)
Expand All @@ -283,7 +285,9 @@ def __init__(self, redirect_auth_url_handler: Callable[[str], None]):

def __call__(self, r: PreparedRequest) -> PreparedRequest:
host = self._determine_host(r.url)
token = self._get_token_from_cache(host)
user = self._determine_user(r.headers)
key = self._construct_cache_key(host, user)
token = self._get_token_from_cache(key)

if token is not None:
r.headers['Authorization'] = "Bearer " + token
Expand Down Expand Up @@ -341,15 +345,19 @@ def _attempt_oauth(self, response: Response, **kwargs: Any) -> None:

request = response.request
host = self._determine_host(request.url)
self._store_token_to_cache(host, token)
user = self._determine_user(request.headers)
key = self._construct_cache_key(host, user)
self._store_token_to_cache(key, token)

def _retry_request(self, response: Response, **kwargs: Any) -> Optional[Response]:
request = response.request.copy()
extract_cookies_to_jar(request._cookies, response.request, response.raw) # type: ignore
request.prepare_cookies(request._cookies) # type: ignore

host = self._determine_host(response.request.url)
token = self._get_token_from_cache(host)
user = self._determine_user(request.headers)
key = self._construct_cache_key(host, user)
token = self._get_token_from_cache(key)
if token is not None:
request.headers['Authorization'] = "Bearer " + token
retry_response = response.connection.send(request, **kwargs) # type: ignore
Expand Down Expand Up @@ -394,6 +402,17 @@ def _store_token_to_cache(self, key: Optional[str], token: str) -> None:
def _determine_host(url: Optional[str]) -> Any:
return urlparse(url).hostname

@staticmethod
def _determine_user(headers: Mapping[Any, Any]) -> Optional[Any]:
return headers.get(HEADER_USER)

@staticmethod
def _construct_cache_key(host: Optional[str], user: Optional[str]) -> Optional[str]:
if user is None:
return host
else:
return f"{host}@{user}"


class OAuth2Authentication(Authentication):
def __init__(self, redirect_auth_url_handler: CompositeRedirectHandler = CompositeRedirectHandler([
Expand Down

0 comments on commit d8b3b68

Please sign in to comment.