diff --git a/requests_oauth2client/auth.py b/requests_oauth2client/auth.py index 2078429..4892109 100644 --- a/requests_oauth2client/auth.py +++ b/requests_oauth2client/auth.py @@ -1,4 +1,5 @@ """This module contains requests-compatible Auth Handlers that implement OAuth 2.0.""" +from __future__ import annotations from typing import TYPE_CHECKING, Any, Optional, Union @@ -17,7 +18,7 @@ class BearerAuth(requests.auth.AuthBase): """An Auth Handler that includes a Bearer Token in API calls, as defined in [RFC6750$2.1]. As a prerequisite to using this `AuthBase`, you have to obtain an access token manually. - You most likely don't want do to that by yourself, but instead use an instance of + You most likely don't want to do that by yourself, but instead use an instance of [OAuth2Client][requests_oauth2client.client.OAuth2Client] to do that for you. See the others Auth Handlers in this module, which will automatically obtain access tokens from an OAuth 2.x server. @@ -40,7 +41,7 @@ class BearerAuth(requests.auth.AuthBase): token: a [BearerToken][requests_oauth2client.tokens.BearerToken] or a string to use as token for this Auth Handler. If `None`, this Auth Handler is a no op. """ - def __init__(self, token: Optional[Union[str, BearerToken]] = None) -> None: + def __init__(self, token: Union[str, BearerToken, None] = None) -> None: self.token = token # type: ignore[assignment] # until https://github.com/python/mypy/issues/3004 is fixed @property @@ -53,7 +54,7 @@ def token(self) -> Optional[BearerToken]: return self._token @token.setter - def token(self, token: Union[str, BearerToken]) -> None: + def token(self, token: Union[str, BearerToken, None]) -> None: """Change the access token used with this AuthHandler. Accepts a [BearerToken][requests_oauth2client.tokens.BearerToken] or an access token as `str`. @@ -88,7 +89,55 @@ def __call__(self, request: requests.PreparedRequest) -> requests.PreparedReques return request -class OAuth2ClientCredentialsAuth(BearerAuth): +class BaseOAuth2RenewableTokenAuth(BearerAuth): + """Base class for Bearer Token based Auth Handlers, with on obtainable or renewable token. + + In addition to adding a properly formatted `Authorization` header, this will obtain a new token + once the current token is expired. + Expiration is detected based on the `expires_in` hint returned by the AS. + A configurable `leeway`, in number of seconds, will make sure that a new token is obtained some seconds before the + actual expiration is reached. This may help in situations where the client, AS and RS have slightly offset clocks. + + Args: + client: an OAuth2Client + token: an initial Access Token, if you have one already. In most cases, leave `None`. + leeway: expiration leeway, in number of seconds + token_kwargs: additional kwargs to include in token requests + """ + + def __init__( + self, + client: OAuth2Client, + token: Union[None, BearerToken, str] = None, + leeway: int = 20, + **token_kwargs: Any, + ) -> None: + super().__init__(token) + self.client = client + self.leeway = leeway + self.token_kwargs = token_kwargs + + def __call__( + self, request: requests.PreparedRequest + ) -> requests.PreparedRequest: # noqa: D102 + token = self.token + if token is None or token.is_expired(self.leeway): + self.renew_token() + return super().__call__(request) + + def renew_token(self) -> None: + """Obtain a new Bearer Token. + + This should be implemented by subclasses. + """ + raise NotImplementedError + + def forget_token(self) -> None: + """Forget the current token, forcing a renewal on the next usage of this Auth Handler.""" + self.token = None + + +class OAuth2ClientCredentialsAuth(BaseOAuth2RenewableTokenAuth): """An Auth Handler for the Client Credentials grant. This [requests AuthBase][requests.auth.AuthBase] automatically gets Access @@ -109,38 +158,16 @@ class OAuth2ClientCredentialsAuth(BearerAuth): ``` """ - def __init__(self, client: "OAuth2Client", **token_kwargs: Any): - super().__init__(None) - self.client = client - self.token_kwargs = token_kwargs - - def __call__(self, request: requests.PreparedRequest) -> requests.PreparedRequest: - """Implement the Client Credentials grant as an Auth Handler. - - This will obtain a token using the Client Credentials Grant, and include that token in requests. - Once the token is expired (detected using the 'expires_in' hint), it will obtain a new token. - - Args: - request: a [PreparedRequest][requests.PreparedRequest] - - Returns: - a [PreparedRequest][requests.PreparedRequest] with an Access Token added in Authorization Header - """ - token = self.token - if token is None or token.is_expired(): - self.renew_token() - return super().__call__(request) - def renew_token(self) -> None: """Obtain a new token for use within this Auth Handler.""" self.token = self.client.client_credentials(**self.token_kwargs) -class OAuth2AccessTokenAuth(BearerAuth): - """Authenticaton Handler for OAuth 2.0 Access Tokens and (optional) Refresh Tokens. +class OAuth2AccessTokenAuth(BaseOAuth2RenewableTokenAuth): + """Authentication Handler for OAuth 2.0 Access Tokens and (optional) Refresh Tokens. This [Requests Auth handler][requests.auth.AuthBase] implementation uses an access token as Bearer token, and can - automatically refreshes it when expired, if a refresh token is available. + automatically refresh it when expired, if a refresh token is available. Token can be a simple `str` containing a raw access token value, or a [BearerToken][requests_oauth2client.tokens.BearerToken] that can contain a refresh_token. If a refresh_token and an expiration date are available, this Auth Handler @@ -163,34 +190,8 @@ class OAuth2AccessTokenAuth(BearerAuth): ```` """ - def __init__( - self, - client: "OAuth2Client", - token: Optional[Union[str, BearerToken]] = None, - **token_kwargs: Any, - ) -> None: - super().__init__(token) - self.client = client - self.token_kwargs = token_kwargs - - def __call__(self, request: requests.PreparedRequest) -> requests.PreparedRequest: - """Implement the usage of OAuth 2.0 access tokens as Bearer Tokens. - - This adds access token in requests, and refreshes that token once it is expired. - - Args: - request: a [PreparedRequest][requests.PreparedRequest] - - Returns: - a [PreparedRequest][requests.PreparedRequest] with an Access Token added in Authorization Header - """ - token = self.token - if token is not None and token.is_expired(): - self.renew_token() - return super().__call__(request) - def renew_token(self) -> None: - """Obtain a new token, by using the Refresh Token, if available.""" + """Obtain a new token, using the Refresh Token, if available.""" if self.token and self.token.refresh_token and self.client is not None: self.token = self.client.refresh_token( refresh_token=self.token.refresh_token, **self.token_kwargs @@ -218,13 +219,13 @@ class OAuth2AuthorizationCodeAuth(OAuth2AccessTokenAuth): def __init__( self, - client: "OAuth2Client", + client: OAuth2Client, code: Union[str, AuthorizationResponse], + leeway: int = 20, **token_kwargs: Any, ) -> None: - super().__init__(client, None) + super().__init__(client, token=None, leeway=leeway, **token_kwargs) self.code: Union[str, AuthorizationResponse, None] = code - self.token_kwargs = token_kwargs def __call__(self, request: requests.PreparedRequest) -> requests.PreparedRequest: """Implement the Authorization Code grant as an Authentication Handler. @@ -276,17 +277,17 @@ class OAuth2DeviceCodeAuth(OAuth2AccessTokenAuth): def __init__( self, - client: "OAuth2Client", + client: OAuth2Client, device_code: Union[str, DeviceAuthorizationResponse], + leeway: int = 20, interval: int = 5, expires_in: int = 360, **token_kwargs: Any, ) -> None: - super().__init__(client, None) + super().__init__(client=client, leeway=leeway, token=None, **token_kwargs) self.device_code: Union[str, DeviceAuthorizationResponse, None] = device_code self.interval = interval self.expires_in = expires_in - self.token_kwargs = token_kwargs def __call__(self, request: requests.PreparedRequest) -> requests.PreparedRequest: """Implement the Device Code grant as a request Authentication Handler. diff --git a/requests_oauth2client/flask/auth.py b/requests_oauth2client/flask/auth.py index edce0ea..297b135 100644 --- a/requests_oauth2client/flask/auth.py +++ b/requests_oauth2client/flask/auth.py @@ -1,6 +1,6 @@ """Helper classes for the [Flask](https://flask.palletsprojects.com) framework.""" -from typing import Any, Optional +from typing import Any, Optional, Union from flask import session @@ -12,14 +12,22 @@ class FlaskSessionAuthMixin: """A Mixin for auth handlers to store their tokens in Flask session. - Storing tokens in Flask session does ensure that each user of a Flask application has a different access token, and that tokens will be persisted between multiple requests to the front-end Flask app. + Storing tokens in Flask session does ensure that each user of a Flask application has a different access token, and + that tokens used for backend API access will be persisted between multiple requests to the front-end Flask app. Args: session_key: the key that will be used to store the access token in session. serializer: the serializer that will be used to store the access token in session. """ - def __init__(self, session_key: str, serializer: Optional[BearerTokenSerializer] = None): + def __init__( + self, + session_key: str, + serializer: Optional[BearerTokenSerializer] = None, + *args: Any, + **kwargs: Any, + ) -> None: + super().__init__(*args, **kwargs) self.serializer = serializer or BearerTokenSerializer() self.session_key = session_key @@ -28,7 +36,7 @@ def token(self) -> Optional[BearerToken]: """Return the Access Token stored in session. Returns: - The current BearerToken for this session, if any. + The current `BearerToken` for this session, if any. """ serialized_token = session.get(self.session_key) if serialized_token is None: @@ -36,25 +44,27 @@ def token(self) -> Optional[BearerToken]: return self.serializer.loads(serialized_token) @token.setter - def token(self, token: Optional[BearerToken]) -> None: + def token(self, token: Union[BearerToken, str, None]) -> None: """Store an Access Token in session. Args: token: the token to store """ + if isinstance(token, str): + token = BearerToken(token) # pragma: no cover if token: serialized_token = self.serializer.dumps(token) session[self.session_key] = serialized_token - else: + elif session and self.session_key in session: session.pop(self.session_key, None) class FlaskOAuth2ClientCredentialsAuth(FlaskSessionAuthMixin, OAuth2ClientCredentialsAuth): - """A `requests` Auth handler for CC that stores its token in Flask session. + """A `requests` Auth handler for CC grant that stores its token in Flask session. - It will automatically gets access tokens from an OAuth 2.x Token Endpoint - with the Client Credentials grant (and can get a new one once it is expired), - and stores the retrieved token in Flask `session`, so that each user has a different access token. + It will automatically get Access Tokens from an OAuth 2.x AS + with the Client Credentials grant (and can get a new one once the first one is expired), + and stores the retrieved token, serialized in Flask `session`, so that each user has a different access token. Args: client: an OAuth2Client that will be used to retrieve tokens. @@ -62,14 +72,3 @@ class FlaskOAuth2ClientCredentialsAuth(FlaskSessionAuthMixin, OAuth2ClientCreden serializer: a serializer that will be used to serialize the access token in Flask session **token_kwargs: additional kwargs for the Token Request """ - - def __init__( - self, - client: OAuth2Client, - session_key: str, - serializer: Optional[BearerTokenSerializer] = None, - **token_kwargs: Any, - ) -> None: - super().__init__(session_key, serializer) - self.client = client - self.token_kwargs = token_kwargs diff --git a/tests/unit_tests/test_flask.py b/tests/unit_tests/test_flask.py index c38d6a8..5dc66f0 100644 --- a/tests/unit_tests/test_flask.py +++ b/tests/unit_tests/test_flask.py @@ -2,6 +2,7 @@ from urllib.parse import parse_qs import pytest +from flask import request from requests_oauth2client import ApiClient, ClientSecretPost, OAuth2Client from tests.conftest import RequestsMocker @@ -23,13 +24,17 @@ def test_flask( from requests_oauth2client.flask import FlaskOAuth2ClientCredentialsAuth except ImportError: pytest.skip("Flask is not available") + return oauth_client = OAuth2Client(token_endpoint, ClientSecretPost(client_id, client_secret)) - api_client = ApiClient( - auth=FlaskOAuth2ClientCredentialsAuth( - oauth_client, session_key=session_key, scope=scope - ) + auth = FlaskOAuth2ClientCredentialsAuth( + session_key=session_key, + scope=scope, + client=oauth_client, ) + api_client = ApiClient(target_api, auth=auth) + + assert isinstance(api_client.session.auth, FlaskOAuth2ClientCredentialsAuth) app = Flask("testapp") app.config["TESTING"] = True @@ -37,7 +42,7 @@ def test_flask( @app.route("/api") def get() -> Any: - return api_client.get(target_api).json() + return api_client.get(params=request.args).json() access_token = "access_token" json_resp = {"status": "success"} @@ -48,22 +53,22 @@ def get() -> Any: requests_mock.get(target_api, json=json_resp) with app.test_client() as client: - resp = client.get("/api") + resp = client.get("/api?call=1") assert resp.json == json_resp - resp = client.get("/api") + resp = client.get("/api?call=2") assert resp.json == json_resp - # api_client.auth.token = None # strangely this has no effect in a test session - with client.session_transaction() as sess: # does what 'api_client.auth.token = None' should do - sess.pop("session_key", None) - resp = client.get("/api") + api_client.session.auth.forget_token() + # assert api_client.session.auth.token is None + with client.session_transaction() as sess: + sess.pop(auth.session_key) + # this should trigger a new token request then the API request + resp = client.get("/api?call=3") assert resp.json == json_resp - assert len(requests_mock.request_history) == 5 - token_request = requests_mock.request_history[0] - api_request1 = requests_mock.request_history[1] - api_request2 = requests_mock.request_history[2] + # assert api_client.session.auth.token == access_token - token_params = parse_qs(token_request.text) + token_request1 = requests_mock.request_history[0] + token_params = parse_qs(token_request1.text) assert token_params.get("client_id") == [client_id] if not scope: assert token_params.get("scope") is None @@ -73,5 +78,22 @@ def get() -> Any: assert token_params.get("scope") == [" ".join(scope)] assert token_params.get("client_secret") == [client_secret] + api_request1 = requests_mock.request_history[1] assert api_request1.headers.get("Authorization") == f"Bearer {access_token}" + + api_request2 = requests_mock.request_history[2] assert api_request2.headers.get("Authorization") == f"Bearer {access_token}" + + token_request2 = requests_mock.request_history[3] + token_params = parse_qs(token_request2.text) + assert token_params.get("client_id") == [client_id] + if not scope: + assert token_params.get("scope") is None + elif isinstance(scope, str): + assert token_params.get("scope") == [scope] + elif isinstance(scope, Iterable): + assert token_params.get("scope") == [" ".join(scope)] + assert token_params.get("client_secret") == [client_secret] + + api_request3 = requests_mock.request_history[4] + assert api_request3.headers.get("Authorization") == f"Bearer {access_token}"