diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index 0cc9deec..d4116c17 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -539,6 +539,8 @@ def test_oauth2_authentication_missing_headers(header, error): 'Bearer x_token_server="{token_server}", x_redirect_server="{redirect_server}"', 'Basic realm="Trino", Bearer x_redirect_server="{redirect_server}", x_token_server="{token_server}"', 'Bearer x_redirect_server="{redirect_server}", x_token_server="{token_server}", Basic realm="Trino"', + 'realm="Trino", Bearer realm="Trino", token_type="JWT", Bearer x_redirect_server="{redirect_server}", ' + 'x_token_server="{token_server}"' ]) @httprettified def test_oauth2_header_parsing(header, sample_post_response_data): diff --git a/trino/auth.py b/trino/auth.py index 6262f95a..5b7b20ff 100644 --- a/trino/auth.py +++ b/trino/auth.py @@ -22,7 +22,6 @@ from requests import PreparedRequest, Request, Response, Session from requests.auth import AuthBase, extract_cookies_to_jar -from requests.utils import parse_dict_header import trino.logging from trino.client import exceptions @@ -421,10 +420,13 @@ def _attempt_oauth(self, response: Response, **kwargs: Any) -> None: if not _OAuth2TokenBearer._BEARER_PREFIX.search(auth_info): raise exceptions.TrinoAuthError(f"Error: header info didn't match {auth_info}") - auth_info_headers = parse_dict_header( - _OAuth2TokenBearer._BEARER_PREFIX.sub("", auth_info, count=1)) # type: ignore + # Example www-authenticate header value: + # 'Basic realm="Trino", Bearer realm="Trino", token_type="JWT", + # Bearer x_redirect_server="https://trino.com/oauth2/token/uuid4", + # x_token_server="https://trino.com/oauth2/token/uuid4"' + auth_info_headers = self._parse_authenticate_header(auth_info) - auth_server = auth_info_headers.get('x_redirect_server') + auth_server = auth_info_headers.get('bearer x_redirect_server', auth_info_headers.get('x_redirect_server')) token_server = auth_info_headers.get('x_token_server') if token_server is None: raise exceptions.TrinoAuthError("Error: header info didn't have x_token_server") @@ -510,6 +512,21 @@ def _construct_cache_key(host: Optional[str], user: Optional[str]) -> Optional[s else: return f"{host}@{user}" + @staticmethod + def _parse_authenticate_header(header: str) -> Dict[str, str]: + split_challenge = header.split(" ", 1) + trimmed_challenge = split_challenge[1] if len(split_challenge) > 1 else "" + auth_info_headers = {} + + for item in trimmed_challenge.split(","): + comps = item.split("=") + if len(comps) == 2: + key = comps[0].strip(' "') + value = comps[1].strip(' "') + if key: + auth_info_headers[key.lower()] = value + return auth_info_headers + class OAuth2Authentication(Authentication): def __init__(self, redirect_auth_url_handler: CompositeRedirectHandler = CompositeRedirectHandler([