Skip to content

Commit

Permalink
implement expiration leeway in Bearer Token based auth handlers, #18 (#…
Browse files Browse the repository at this point in the history
…20)

* implement expiration leeway in Bearer Token based auth handlers, #18
  • Loading branch information
guillp authored Aug 14, 2023
1 parent 269df74 commit ad66ecf
Show file tree
Hide file tree
Showing 3 changed files with 121 additions and 99 deletions.
125 changes: 63 additions & 62 deletions requests_oauth2client/auth.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""This module contains requests-compatible Auth Handlers that implement OAuth 2.0."""
from __future__ import annotations

from typing import TYPE_CHECKING, Any, Optional, Union

Expand All @@ -17,7 +18,7 @@ class BearerAuth(requests.auth.AuthBase):
"""An Auth Handler that includes a Bearer Token in API calls, as defined in [RFC6750$2.1].
As a prerequisite to using this `AuthBase`, you have to obtain an access token manually.
You most likely don't want do to that by yourself, but instead use an instance of
You most likely don't want to do that by yourself, but instead use an instance of
[OAuth2Client][requests_oauth2client.client.OAuth2Client] to do that for you.
See the others Auth Handlers in this module, which will automatically obtain access tokens from an OAuth 2.x server.
Expand All @@ -40,7 +41,7 @@ class BearerAuth(requests.auth.AuthBase):
token: a [BearerToken][requests_oauth2client.tokens.BearerToken] or a string to use as token for this Auth Handler. If `None`, this Auth Handler is a no op.
"""

def __init__(self, token: Optional[Union[str, BearerToken]] = None) -> None:
def __init__(self, token: Union[str, BearerToken, None] = None) -> None:
self.token = token # type: ignore[assignment] # until https://github.com/python/mypy/issues/3004 is fixed

@property
Expand All @@ -53,7 +54,7 @@ def token(self) -> Optional[BearerToken]:
return self._token

@token.setter
def token(self, token: Union[str, BearerToken]) -> None:
def token(self, token: Union[str, BearerToken, None]) -> None:
"""Change the access token used with this AuthHandler.
Accepts a [BearerToken][requests_oauth2client.tokens.BearerToken] or an access token as `str`.
Expand Down Expand Up @@ -88,7 +89,55 @@ def __call__(self, request: requests.PreparedRequest) -> requests.PreparedReques
return request


class OAuth2ClientCredentialsAuth(BearerAuth):
class BaseOAuth2RenewableTokenAuth(BearerAuth):
"""Base class for Bearer Token based Auth Handlers, with on obtainable or renewable token.
In addition to adding a properly formatted `Authorization` header, this will obtain a new token
once the current token is expired.
Expiration is detected based on the `expires_in` hint returned by the AS.
A configurable `leeway`, in number of seconds, will make sure that a new token is obtained some seconds before the
actual expiration is reached. This may help in situations where the client, AS and RS have slightly offset clocks.
Args:
client: an OAuth2Client
token: an initial Access Token, if you have one already. In most cases, leave `None`.
leeway: expiration leeway, in number of seconds
token_kwargs: additional kwargs to include in token requests
"""

def __init__(
self,
client: OAuth2Client,
token: Union[None, BearerToken, str] = None,
leeway: int = 20,
**token_kwargs: Any,
) -> None:
super().__init__(token)
self.client = client
self.leeway = leeway
self.token_kwargs = token_kwargs

def __call__(
self, request: requests.PreparedRequest
) -> requests.PreparedRequest: # noqa: D102
token = self.token
if token is None or token.is_expired(self.leeway):
self.renew_token()
return super().__call__(request)

def renew_token(self) -> None:
"""Obtain a new Bearer Token.
This should be implemented by subclasses.
"""
raise NotImplementedError

def forget_token(self) -> None:
"""Forget the current token, forcing a renewal on the next usage of this Auth Handler."""
self.token = None


class OAuth2ClientCredentialsAuth(BaseOAuth2RenewableTokenAuth):
"""An Auth Handler for the Client Credentials grant.
This [requests AuthBase][requests.auth.AuthBase] automatically gets Access
Expand All @@ -109,38 +158,16 @@ class OAuth2ClientCredentialsAuth(BearerAuth):
```
"""

def __init__(self, client: "OAuth2Client", **token_kwargs: Any):
super().__init__(None)
self.client = client
self.token_kwargs = token_kwargs

def __call__(self, request: requests.PreparedRequest) -> requests.PreparedRequest:
"""Implement the Client Credentials grant as an Auth Handler.
This will obtain a token using the Client Credentials Grant, and include that token in requests.
Once the token is expired (detected using the 'expires_in' hint), it will obtain a new token.
Args:
request: a [PreparedRequest][requests.PreparedRequest]
Returns:
a [PreparedRequest][requests.PreparedRequest] with an Access Token added in Authorization Header
"""
token = self.token
if token is None or token.is_expired():
self.renew_token()
return super().__call__(request)

def renew_token(self) -> None:
"""Obtain a new token for use within this Auth Handler."""
self.token = self.client.client_credentials(**self.token_kwargs)


class OAuth2AccessTokenAuth(BearerAuth):
"""Authenticaton Handler for OAuth 2.0 Access Tokens and (optional) Refresh Tokens.
class OAuth2AccessTokenAuth(BaseOAuth2RenewableTokenAuth):
"""Authentication Handler for OAuth 2.0 Access Tokens and (optional) Refresh Tokens.
This [Requests Auth handler][requests.auth.AuthBase] implementation uses an access token as Bearer token, and can
automatically refreshes it when expired, if a refresh token is available.
automatically refresh it when expired, if a refresh token is available.
Token can be a simple `str` containing a raw access token value, or a [BearerToken][requests_oauth2client.tokens.BearerToken]
that can contain a refresh_token. If a refresh_token and an expiration date are available, this Auth Handler
Expand All @@ -163,34 +190,8 @@ class OAuth2AccessTokenAuth(BearerAuth):
````
"""

def __init__(
self,
client: "OAuth2Client",
token: Optional[Union[str, BearerToken]] = None,
**token_kwargs: Any,
) -> None:
super().__init__(token)
self.client = client
self.token_kwargs = token_kwargs

def __call__(self, request: requests.PreparedRequest) -> requests.PreparedRequest:
"""Implement the usage of OAuth 2.0 access tokens as Bearer Tokens.
This adds access token in requests, and refreshes that token once it is expired.
Args:
request: a [PreparedRequest][requests.PreparedRequest]
Returns:
a [PreparedRequest][requests.PreparedRequest] with an Access Token added in Authorization Header
"""
token = self.token
if token is not None and token.is_expired():
self.renew_token()
return super().__call__(request)

def renew_token(self) -> None:
"""Obtain a new token, by using the Refresh Token, if available."""
"""Obtain a new token, using the Refresh Token, if available."""
if self.token and self.token.refresh_token and self.client is not None:
self.token = self.client.refresh_token(
refresh_token=self.token.refresh_token, **self.token_kwargs
Expand Down Expand Up @@ -218,13 +219,13 @@ class OAuth2AuthorizationCodeAuth(OAuth2AccessTokenAuth):

def __init__(
self,
client: "OAuth2Client",
client: OAuth2Client,
code: Union[str, AuthorizationResponse],
leeway: int = 20,
**token_kwargs: Any,
) -> None:
super().__init__(client, None)
super().__init__(client, token=None, leeway=leeway, **token_kwargs)
self.code: Union[str, AuthorizationResponse, None] = code
self.token_kwargs = token_kwargs

def __call__(self, request: requests.PreparedRequest) -> requests.PreparedRequest:
"""Implement the Authorization Code grant as an Authentication Handler.
Expand Down Expand Up @@ -276,17 +277,17 @@ class OAuth2DeviceCodeAuth(OAuth2AccessTokenAuth):

def __init__(
self,
client: "OAuth2Client",
client: OAuth2Client,
device_code: Union[str, DeviceAuthorizationResponse],
leeway: int = 20,
interval: int = 5,
expires_in: int = 360,
**token_kwargs: Any,
) -> None:
super().__init__(client, None)
super().__init__(client=client, leeway=leeway, token=None, **token_kwargs)
self.device_code: Union[str, DeviceAuthorizationResponse, None] = device_code
self.interval = interval
self.expires_in = expires_in
self.token_kwargs = token_kwargs

def __call__(self, request: requests.PreparedRequest) -> requests.PreparedRequest:
"""Implement the Device Code grant as a request Authentication Handler.
Expand Down
41 changes: 20 additions & 21 deletions requests_oauth2client/flask/auth.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Helper classes for the [Flask](https://flask.palletsprojects.com) framework."""

from typing import Any, Optional
from typing import Any, Optional, Union

from flask import session

Expand All @@ -12,14 +12,22 @@
class FlaskSessionAuthMixin:
"""A Mixin for auth handlers to store their tokens in Flask session.
Storing tokens in Flask session does ensure that each user of a Flask application has a different access token, and that tokens will be persisted between multiple requests to the front-end Flask app.
Storing tokens in Flask session does ensure that each user of a Flask application has a different access token, and
that tokens used for backend API access will be persisted between multiple requests to the front-end Flask app.
Args:
session_key: the key that will be used to store the access token in session.
serializer: the serializer that will be used to store the access token in session.
"""

def __init__(self, session_key: str, serializer: Optional[BearerTokenSerializer] = None):
def __init__(
self,
session_key: str,
serializer: Optional[BearerTokenSerializer] = None,
*args: Any,
**kwargs: Any,
) -> None:
super().__init__(*args, **kwargs)
self.serializer = serializer or BearerTokenSerializer()
self.session_key = session_key

Expand All @@ -28,48 +36,39 @@ def token(self) -> Optional[BearerToken]:
"""Return the Access Token stored in session.
Returns:
The current BearerToken for this session, if any.
The current `BearerToken` for this session, if any.
"""
serialized_token = session.get(self.session_key)
if serialized_token is None:
return None
return self.serializer.loads(serialized_token)

@token.setter
def token(self, token: Optional[BearerToken]) -> None:
def token(self, token: Union[BearerToken, str, None]) -> None:
"""Store an Access Token in session.
Args:
token: the token to store
"""
if isinstance(token, str):
token = BearerToken(token) # pragma: no cover
if token:
serialized_token = self.serializer.dumps(token)
session[self.session_key] = serialized_token
else:
elif session and self.session_key in session:
session.pop(self.session_key, None)


class FlaskOAuth2ClientCredentialsAuth(FlaskSessionAuthMixin, OAuth2ClientCredentialsAuth):
"""A `requests` Auth handler for CC that stores its token in Flask session.
"""A `requests` Auth handler for CC grant that stores its token in Flask session.
It will automatically gets access tokens from an OAuth 2.x Token Endpoint
with the Client Credentials grant (and can get a new one once it is expired),
and stores the retrieved token in Flask `session`, so that each user has a different access token.
It will automatically get Access Tokens from an OAuth 2.x AS
with the Client Credentials grant (and can get a new one once the first one is expired),
and stores the retrieved token, serialized in Flask `session`, so that each user has a different access token.
Args:
client: an OAuth2Client that will be used to retrieve tokens.
session_key: the key that will be used to store the access token in Flask session
serializer: a serializer that will be used to serialize the access token in Flask session
**token_kwargs: additional kwargs for the Token Request
"""

def __init__(
self,
client: OAuth2Client,
session_key: str,
serializer: Optional[BearerTokenSerializer] = None,
**token_kwargs: Any,
) -> None:
super().__init__(session_key, serializer)
self.client = client
self.token_kwargs = token_kwargs
54 changes: 38 additions & 16 deletions tests/unit_tests/test_flask.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from urllib.parse import parse_qs

import pytest
from flask import request

from requests_oauth2client import ApiClient, ClientSecretPost, OAuth2Client
from tests.conftest import RequestsMocker
Expand All @@ -23,21 +24,25 @@ def test_flask(
from requests_oauth2client.flask import FlaskOAuth2ClientCredentialsAuth
except ImportError:
pytest.skip("Flask is not available")
return

oauth_client = OAuth2Client(token_endpoint, ClientSecretPost(client_id, client_secret))
api_client = ApiClient(
auth=FlaskOAuth2ClientCredentialsAuth(
oauth_client, session_key=session_key, scope=scope
)
auth = FlaskOAuth2ClientCredentialsAuth(
session_key=session_key,
scope=scope,
client=oauth_client,
)
api_client = ApiClient(target_api, auth=auth)

assert isinstance(api_client.session.auth, FlaskOAuth2ClientCredentialsAuth)

app = Flask("testapp")
app.config["TESTING"] = True
app.config["SECRET_KEY"] = "thisissecret"

@app.route("/api")
def get() -> Any:
return api_client.get(target_api).json()
return api_client.get(params=request.args).json()

access_token = "access_token"
json_resp = {"status": "success"}
Expand All @@ -48,22 +53,22 @@ def get() -> Any:
requests_mock.get(target_api, json=json_resp)

with app.test_client() as client:
resp = client.get("/api")
resp = client.get("/api?call=1")
assert resp.json == json_resp
resp = client.get("/api")
resp = client.get("/api?call=2")
assert resp.json == json_resp
# api_client.auth.token = None # strangely this has no effect in a test session
with client.session_transaction() as sess: # does what 'api_client.auth.token = None' should do
sess.pop("session_key", None)
resp = client.get("/api")
api_client.session.auth.forget_token()
# assert api_client.session.auth.token is None
with client.session_transaction() as sess:
sess.pop(auth.session_key)
# this should trigger a new token request then the API request
resp = client.get("/api?call=3")
assert resp.json == json_resp

assert len(requests_mock.request_history) == 5
token_request = requests_mock.request_history[0]
api_request1 = requests_mock.request_history[1]
api_request2 = requests_mock.request_history[2]
# assert api_client.session.auth.token == access_token

token_params = parse_qs(token_request.text)
token_request1 = requests_mock.request_history[0]
token_params = parse_qs(token_request1.text)
assert token_params.get("client_id") == [client_id]
if not scope:
assert token_params.get("scope") is None
Expand All @@ -73,5 +78,22 @@ def get() -> Any:
assert token_params.get("scope") == [" ".join(scope)]
assert token_params.get("client_secret") == [client_secret]

api_request1 = requests_mock.request_history[1]
assert api_request1.headers.get("Authorization") == f"Bearer {access_token}"

api_request2 = requests_mock.request_history[2]
assert api_request2.headers.get("Authorization") == f"Bearer {access_token}"

token_request2 = requests_mock.request_history[3]
token_params = parse_qs(token_request2.text)
assert token_params.get("client_id") == [client_id]
if not scope:
assert token_params.get("scope") is None
elif isinstance(scope, str):
assert token_params.get("scope") == [scope]
elif isinstance(scope, Iterable):
assert token_params.get("scope") == [" ".join(scope)]
assert token_params.get("client_secret") == [client_secret]

api_request3 = requests_mock.request_history[4]
assert api_request3.headers.get("Authorization") == f"Bearer {access_token}"

0 comments on commit ad66ecf

Please sign in to comment.