Skip to content

Commit

Permalink
Add TokenVerificationMiddleware
Browse files Browse the repository at this point in the history
  • Loading branch information
timonegk committed Dec 19, 2023
1 parent d381461 commit 9f9d557
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 19 deletions.
43 changes: 43 additions & 0 deletions src/simple_openid_connect/integrations/django/middleware.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
import logging
from datetime import datetime, timezone
from typing import Callable

from django.http import HttpRequest, HttpResponse, HttpResponseRedirect
from django.urls import reverse

from simple_openid_connect.data import TokenSuccessResponse
from simple_openid_connect.integrations.django.apps import OpenidAppConfig
from simple_openid_connect.integrations.django.models import OpenidSession

logger = logging.getLogger(__name__)


class TokenVerificationMiddleware:
def __init__(self, get_response: Callable[[HttpRequest], HttpResponse]) -> None:
self.get_response = get_response

def __call__(self, request: HttpRequest) -> HttpResponse:
response = self.get_response(request)
openid_session_id = request.session.get("openid_session")
if not openid_session_id:
return response

openid_session = OpenidSession.objects.get(id=openid_session_id)
refresh_token = openid_session.refresh_token
session_valid_until = openid_session.access_token_expiry
access_token_valid = (
session_valid_until is not None
and session_valid_until > datetime.now(timezone.utc)
)
if access_token_valid:
return response

logger.debug("access token expired, trying to refresh")
client = OpenidAppConfig.get_instance().get_client(request)
exchange_response = client.exchange_refresh_token(refresh_token)
if isinstance(exchange_response, TokenSuccessResponse):
openid_session.update_session(exchange_response)
openid_session.save()
return response
else:
return HttpResponseRedirect(reverse("logout"))
39 changes: 21 additions & 18 deletions src/simple_openid_connect/integrations/django/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,12 @@
from simple_openid_connect.integrations.django.apps import OpenidAppConfig


def _calc_expiry(t: Optional[int]) -> Optional[datetime]:
if t is not None:
return timezone.now() + timedelta(seconds=t)
return None


class OpenidUserManager(models.Manager["OpenidUser"]):
"""
Custom user manager for the :class:`OpenidUser` model.
Expand Down Expand Up @@ -53,44 +59,34 @@ class OpenidUser(models.Model):

def update_session(
self, token_response: TokenSuccessResponse, id_token: IdToken
) -> None:
) -> "OpenidSession":
"""
Update session information based on the given openid token response.
If the token contains a session id, that session is updated with newer information and if not, a new session
object is created.
"""

def calc_expiry(t: Optional[int]) -> Optional[datetime]:
if t is not None:
return timezone.now() + timedelta(seconds=t)
return None

# update the existing session if possible
if id_token.sid is not None:
query = OpenidSession.objects.filter(sid=id_token.sid)
if query.exists():
session = query.get() # type: OpenidSession
session.scope = str(token_response.scope)
session.access_token = token_response.access_token
session.access_token_expiry = calc_expiry(token_response.expires_in)
session.refresh_token = token_response.refresh_token or ""
session.refresh_token_expiry = calc_expiry(
token_response.refresh_expires_in
)
session.update_session(token_response)
session.id_token = id_token
return
session.save()
return session

# fall back to creating a new session
OpenidSession.objects.create(
return OpenidSession.objects.create(
user=self,
sid=id_token.sid or "",
scope=str(token_response.scope),
access_token=token_response.access_token,
access_token_expiry=calc_expiry(token_response.expires_in),
access_token_expiry=_calc_expiry(token_response.expires_in),
refresh_token=token_response.refresh_token or "",
refresh_token_expiry=calc_expiry(token_response.refresh_expires_in),
_id_token=id_token.json(),
refresh_token_expiry=_calc_expiry(token_response.refresh_expires_in),
_id_token=id_token.json(), # type: ignore
)


Expand Down Expand Up @@ -119,3 +115,10 @@ def id_token(self) -> IdToken:
@id_token.setter
def id_token(self, value: IdToken) -> None:
self._id_token = value.json()

def update_session(self, token_response: TokenSuccessResponse) -> None:
self.scope = str(token_response.scope)
self.access_token = token_response.access_token
self.access_token_expiry = _calc_expiry(token_response.expires_in)
self.refresh_token = token_response.refresh_token or ""
self.refresh_token_expiry = _calc_expiry(token_response.refresh_expires_in)
3 changes: 2 additions & 1 deletion src/simple_openid_connect/integrations/django/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,8 @@ def get(self, request: HttpRequest) -> HttpResponse:
user = OpenidAppConfig.get_instance().user_mapper.handle_federated_userinfo(
id_token
)
user.openid.update_session(token_response, id_token)
openid_session = user.openid.update_session(token_response, id_token)
request.session["openid_session"] = openid_session.id
login(request, user, backend=settings.AUTHENTICATION_BACKENDS[0])

# redirect to the next get parameter if present, otherwise to the configured default
Expand Down

0 comments on commit 9f9d557

Please sign in to comment.