From 3eb11dc3998aa12211c5b445550c11b614c98bb1 Mon Sep 17 00:00:00 2001 From: juanifioren Date: Wed, 11 Dec 2024 23:08:24 -0300 Subject: [PATCH] Tox improvement + id token better serializer --- oidc_provider/models.py | 215 +++--- .../tests/cases/test_token_endpoint.py | 635 +++++++++--------- oidc_provider/tests/cases/test_utils.py | 25 + tox.ini | 7 +- 4 files changed, 456 insertions(+), 426 deletions(-) diff --git a/oidc_provider/models.py b/oidc_provider/models.py index d2111086..2edf78b7 100644 --- a/oidc_provider/models.py +++ b/oidc_provider/models.py @@ -1,31 +1,32 @@ import base64 import binascii -from hashlib import md5, sha256 import json +from hashlib import md5 +from hashlib import sha256 +from django.conf import settings +from django.core.serializers.json import DjangoJSONEncoder from django.db import models from django.utils import timezone from django.utils.translation import gettext_lazy as _ -from django.conf import settings - CLIENT_TYPE_CHOICES = [ - ('confidential', 'Confidential'), - ('public', 'Public'), + ("confidential", "Confidential"), + ("public", "Public"), ] RESPONSE_TYPE_CHOICES = [ - ('code', 'code (Authorization Code Flow)'), - ('id_token', 'id_token (Implicit Flow)'), - ('id_token token', 'id_token token (Implicit Flow)'), - ('code token', 'code token (Hybrid Flow)'), - ('code id_token', 'code id_token (Hybrid Flow)'), - ('code id_token token', 'code id_token token (Hybrid Flow)'), + ("code", "code (Authorization Code Flow)"), + ("id_token", "id_token (Implicit Flow)"), + ("id_token token", "id_token token (Implicit Flow)"), + ("code token", "code token (Hybrid Flow)"), + ("code id_token", "code id_token (Hybrid Flow)"), + ("code id_token token", "code id_token token (Hybrid Flow)"), ] JWT_ALGS = [ - ('HS256', 'HS256'), - ('RS256', 'RS256'), + ("HS256", "HS256"), + ("RS256", "RS256"), ] @@ -41,82 +42,102 @@ class ResponseType(models.Model): max_length=30, choices=RESPONSE_TYPE_CHOICES, unique=True, - verbose_name=_(u'Response Type Value')) + verbose_name=_("Response Type Value"), + ) description = models.CharField( max_length=50, ) def natural_key(self): - return self.value, # natural_key must return tuple + return (self.value,) # natural_key must return tuple def __str__(self): - return u'{0}'.format(self.description) + return "{0}".format(self.description) class Client(models.Model): - - name = models.CharField(max_length=100, default='', verbose_name=_(u'Name')) + name = models.CharField(max_length=100, default="", verbose_name=_("Name")) owner = models.ForeignKey( - settings.AUTH_USER_MODEL, verbose_name=_(u'Owner'), blank=True, - null=True, default=None, on_delete=models.SET_NULL, related_name='oidc_clients_set') + settings.AUTH_USER_MODEL, + verbose_name=_("Owner"), + blank=True, + null=True, + default=None, + on_delete=models.SET_NULL, + related_name="oidc_clients_set", + ) client_type = models.CharField( max_length=30, choices=CLIENT_TYPE_CHOICES, - default='confidential', - verbose_name=_(u'Client Type'), - help_text=_(u'Confidential clients are capable of maintaining the confidentiality' - u' of their credentials. Public clients are incapable.')) - client_id = models.CharField(max_length=255, unique=True, verbose_name=_(u'Client ID')) - client_secret = models.CharField(max_length=255, blank=True, verbose_name=_(u'Client SECRET')) + default="confidential", + verbose_name=_("Client Type"), + help_text=_( + "Confidential clients are capable of maintaining the confidentiality" + " of their credentials. Public clients are incapable." + ), + ) + client_id = models.CharField(max_length=255, unique=True, verbose_name=_("Client ID")) + client_secret = models.CharField(max_length=255, blank=True, verbose_name=_("Client SECRET")) response_types = models.ManyToManyField(ResponseType) jwt_alg = models.CharField( max_length=10, choices=JWT_ALGS, - default='RS256', - verbose_name=_(u'JWT Algorithm'), - help_text=_(u'Algorithm used to encode ID Tokens.')) - date_created = models.DateField(auto_now_add=True, verbose_name=_(u'Date Created')) + default="RS256", + verbose_name=_("JWT Algorithm"), + help_text=_("Algorithm used to encode ID Tokens."), + ) + date_created = models.DateField(auto_now_add=True, verbose_name=_("Date Created")) website_url = models.CharField( - max_length=255, blank=True, default='', verbose_name=_(u'Website URL')) + max_length=255, blank=True, default="", verbose_name=_("Website URL") + ) terms_url = models.CharField( max_length=255, blank=True, - default='', - verbose_name=_(u'Terms URL'), - help_text=_(u'External reference to the privacy policy of the client.')) + default="", + verbose_name=_("Terms URL"), + help_text=_("External reference to the privacy policy of the client."), + ) contact_email = models.CharField( - max_length=255, blank=True, default='', verbose_name=_(u'Contact Email')) + max_length=255, blank=True, default="", verbose_name=_("Contact Email") + ) logo = models.FileField( - blank=True, default='', upload_to='oidc_provider/clients', verbose_name=_(u'Logo Image')) + blank=True, default="", upload_to="oidc_provider/clients", verbose_name=_("Logo Image") + ) reuse_consent = models.BooleanField( default=True, - verbose_name=_('Reuse Consent?'), - help_text=_('If enabled, server will save the user consent given to a specific client, ' - 'so that user won\'t be prompted for the same authorization multiple times.')) + verbose_name=_("Reuse Consent?"), + help_text=_( + "If enabled, server will save the user consent given to a specific client, " + "so that user won't be prompted for the same authorization multiple times." + ), + ) require_consent = models.BooleanField( default=True, - verbose_name=_('Require Consent?'), - help_text=_('If disabled, the Server will NEVER ask the user for consent.')) + verbose_name=_("Require Consent?"), + help_text=_("If disabled, the Server will NEVER ask the user for consent."), + ) _redirect_uris = models.TextField( - default='', verbose_name=_(u'Redirect URIs'), - help_text=_(u'Enter each URI on a new line.')) + default="", verbose_name=_("Redirect URIs"), help_text=_("Enter each URI on a new line.") + ) _post_logout_redirect_uris = models.TextField( blank=True, - default='', - verbose_name=_(u'Post Logout Redirect URIs'), - help_text=_(u'Enter each URI on a new line.')) + default="", + verbose_name=_("Post Logout Redirect URIs"), + help_text=_("Enter each URI on a new line."), + ) _scope = models.TextField( blank=True, - default='', - verbose_name=_(u'Scopes'), - help_text=_('Specifies the authorized scope values for the client app.')) + default="", + verbose_name=_("Scopes"), + help_text=_("Specifies the authorized scope values for the client app."), + ) class Meta: - verbose_name = _(u'Client') - verbose_name_plural = _(u'Clients') + verbose_name = _("Client") + verbose_name_plural = _("Clients") def __str__(self): - return u'{0}'.format(self.name) + return "{0}".format(self.name) def __unicode__(self): return self.__str__() @@ -134,7 +155,7 @@ def redirect_uris(self): @redirect_uris.setter def redirect_uris(self, value): - self._redirect_uris = '\n'.join(value) + self._redirect_uris = "\n".join(value) @property def post_logout_redirect_uris(self): @@ -142,7 +163,7 @@ def post_logout_redirect_uris(self): @post_logout_redirect_uris.setter def post_logout_redirect_uris(self, value): - self._post_logout_redirect_uris = '\n'.join(value) + self._post_logout_redirect_uris = "\n".join(value) @property def scope(self): @@ -150,18 +171,17 @@ def scope(self): @scope.setter def scope(self, value): - self._scope = ' '.join(value) + self._scope = " ".join(value) @property def default_redirect_uri(self): - return self.redirect_uris[0] if self.redirect_uris else '' + return self.redirect_uris[0] if self.redirect_uris else "" class BaseCodeTokenModel(models.Model): - - client = models.ForeignKey(Client, verbose_name=_(u'Client'), on_delete=models.CASCADE) - expires_at = models.DateTimeField(verbose_name=_(u'Expiration Date')) - _scope = models.TextField(default='', verbose_name=_(u'Scopes')) + client = models.ForeignKey(Client, verbose_name=_("Client"), on_delete=models.CASCADE) + expires_at = models.DateTimeField(verbose_name=_("Expiration Date")) + _scope = models.TextField(default="", verbose_name=_("Scopes")) class Meta: abstract = True @@ -172,7 +192,7 @@ def scope(self): @scope.setter def scope(self, value): - self._scope = ' '.join(value) + self._scope = " ".join(value) def __unicode__(self): return self.__str__() @@ -182,35 +202,36 @@ def has_expired(self): class Code(BaseCodeTokenModel): - user = models.ForeignKey( - settings.AUTH_USER_MODEL, verbose_name=_(u'User'), on_delete=models.CASCADE) - code = models.CharField(max_length=255, unique=True, verbose_name=_(u'Code')) - nonce = models.CharField(max_length=255, blank=True, default='', verbose_name=_(u'Nonce')) - is_authentication = models.BooleanField(default=False, verbose_name=_(u'Is Authentication?')) - code_challenge = models.CharField(max_length=255, null=True, verbose_name=_(u'Code Challenge')) + settings.AUTH_USER_MODEL, verbose_name=_("User"), on_delete=models.CASCADE + ) + code = models.CharField(max_length=255, unique=True, verbose_name=_("Code")) + nonce = models.CharField(max_length=255, blank=True, default="", verbose_name=_("Nonce")) + is_authentication = models.BooleanField(default=False, verbose_name=_("Is Authentication?")) + code_challenge = models.CharField(max_length=255, null=True, verbose_name=_("Code Challenge")) code_challenge_method = models.CharField( - max_length=255, null=True, verbose_name=_(u'Code Challenge Method')) + max_length=255, null=True, verbose_name=_("Code Challenge Method") + ) class Meta: - verbose_name = _(u'Authorization Code') - verbose_name_plural = _(u'Authorization Codes') + verbose_name = _("Authorization Code") + verbose_name_plural = _("Authorization Codes") def __str__(self): - return u'{0} - {1}'.format(self.client, self.code) + return "{0} - {1}".format(self.client, self.code) class Token(BaseCodeTokenModel): - user = models.ForeignKey( - settings.AUTH_USER_MODEL, null=True, verbose_name=_(u'User'), on_delete=models.CASCADE) - access_token = models.CharField(max_length=255, unique=True, verbose_name=_(u'Access Token')) - refresh_token = models.CharField(max_length=255, unique=True, verbose_name=_(u'Refresh Token')) - _id_token = models.TextField(verbose_name=_(u'ID Token')) + settings.AUTH_USER_MODEL, null=True, verbose_name=_("User"), on_delete=models.CASCADE + ) + access_token = models.CharField(max_length=255, unique=True, verbose_name=_("Access Token")) + refresh_token = models.CharField(max_length=255, unique=True, verbose_name=_("Refresh Token")) + _id_token = models.TextField(verbose_name=_("ID Token")) class Meta: - verbose_name = _(u'Token') - verbose_name_plural = _(u'Tokens') + verbose_name = _("Token") + verbose_name_plural = _("Tokens") @property def id_token(self): @@ -218,50 +239,48 @@ def id_token(self): @id_token.setter def id_token(self, value): - self._id_token = json.dumps(value) + self._id_token = json.dumps(value, cls=DjangoJSONEncoder, skipkeys=True, default=str) def __str__(self): - return u'{0} - {1}'.format(self.client, self.access_token) + return "{0} - {1}".format(self.client, self.access_token) @property def at_hash(self): # @@@ d-o-p only supports 256 bits (change this if that changes) - hashed_access_token = sha256( - self.access_token.encode('ascii') - ).hexdigest().encode('ascii') - return base64.urlsafe_b64encode( - binascii.unhexlify( - hashed_access_token[:len(hashed_access_token) // 2] + hashed_access_token = sha256(self.access_token.encode("ascii")).hexdigest().encode("ascii") + return ( + base64.urlsafe_b64encode( + binascii.unhexlify(hashed_access_token[: len(hashed_access_token) // 2]) ) - ).rstrip(b'=').decode('ascii') + .rstrip(b"=") + .decode("ascii") + ) class UserConsent(BaseCodeTokenModel): - user = models.ForeignKey( - settings.AUTH_USER_MODEL, verbose_name=_(u'User'), on_delete=models.CASCADE) - date_given = models.DateTimeField(verbose_name=_(u'Date Given')) + settings.AUTH_USER_MODEL, verbose_name=_("User"), on_delete=models.CASCADE + ) + date_given = models.DateTimeField(verbose_name=_("Date Given")) class Meta: - unique_together = ('user', 'client') + unique_together = ("user", "client") class RSAKey(models.Model): - - key = models.TextField( - verbose_name=_(u'Key'), help_text=_(u'Paste your private RSA Key here.')) + key = models.TextField(verbose_name=_("Key"), help_text=_("Paste your private RSA Key here.")) class Meta: ordering = ["id"] - verbose_name = _(u'RSA Key') - verbose_name_plural = _(u'RSA Keys') + verbose_name = _("RSA Key") + verbose_name_plural = _("RSA Keys") def __str__(self): - return u'{0}'.format(self.kid) + return "{0}".format(self.kid) def __unicode__(self): return self.__str__() @property def kid(self): - return u'{0}'.format(md5(self.key.encode('utf-8')).hexdigest() if self.key else '') + return "{0}".format(md5(self.key.encode("utf-8")).hexdigest() if self.key else "") diff --git a/oidc_provider/tests/cases/test_token_endpoint.py b/oidc_provider/tests/cases/test_token_endpoint.py index 8990d3d2..df0d6ad1 100644 --- a/oidc_provider/tests/cases/test_token_endpoint.py +++ b/oidc_provider/tests/cases/test_token_endpoint.py @@ -1,7 +1,6 @@ import json import time import uuid - from base64 import b64encode from django.db import DatabaseError @@ -18,11 +17,9 @@ from django.urls import reverse except ImportError: from django.core.urlresolvers import reverse -from django.test import ( - RequestFactory, - override_settings, -) +from django.test import RequestFactory from django.test import TestCase +from django.test import override_settings from django.views.decorators.http import require_http_methods from jwkest.jwk import KEYS from jwkest.jws import JWS @@ -33,19 +30,15 @@ from oidc_provider.lib.utils.oauth2 import protected_resource_view from oidc_provider.lib.utils.token import create_code from oidc_provider.models import Token -from oidc_provider.tests.app.utils import ( - create_fake_user, - create_fake_client, - FAKE_CODE_CHALLENGE, - FAKE_CODE_VERIFIER, - FAKE_NONCE, - FAKE_RANDOM_STRING, -) -from oidc_provider.views import ( - JwksView, - TokenView, - userinfo, -) +from oidc_provider.tests.app.utils import FAKE_CODE_CHALLENGE +from oidc_provider.tests.app.utils import FAKE_CODE_VERIFIER +from oidc_provider.tests.app.utils import FAKE_NONCE +from oidc_provider.tests.app.utils import FAKE_RANDOM_STRING +from oidc_provider.tests.app.utils import create_fake_client +from oidc_provider.tests.app.utils import create_fake_user +from oidc_provider.views import JwksView +from oidc_provider.views import TokenView +from oidc_provider.views import userinfo class TokenTestCase(TestCase): @@ -54,25 +47,26 @@ class TokenTestCase(TestCase): Token Request to the Token Endpoint to obtain a Token Response when using the Authorization Code Flow. """ - SCOPE = 'openid email' - SCOPE_LIST = SCOPE.split(' ') + + SCOPE = "openid email" + SCOPE_LIST = SCOPE.split(" ") def setUp(self): - call_command('creatersakey') + call_command("creatersakey") self.factory = RequestFactory() self.user = create_fake_user() self.request_client = self.client - self.client = create_fake_client(response_type='code') + self.client = create_fake_client(response_type="code") def _password_grant_post_data(self, scope=None): result = { - 'username': 'johndoe', - 'password': '1234', - 'grant_type': 'password', - 'scope': TokenTestCase.SCOPE, + "username": "johndoe", + "password": "1234", + "grant_type": "password", + "scope": TokenTestCase.SCOPE, } if scope is not None: - result['scope'] = ' '.join(scope) + result["scope"] = " ".join(scope) return result def _auth_code_post_data(self, code, scope=None): @@ -80,15 +74,15 @@ def _auth_code_post_data(self, code, scope=None): All the data that will be POSTed to the Token Endpoint. """ post_data = { - 'client_id': self.client.client_id, - 'client_secret': self.client.client_secret, - 'redirect_uri': self.client.default_redirect_uri, - 'grant_type': 'authorization_code', - 'code': code, - 'state': uuid.uuid4().hex, + "client_id": self.client.client_id, + "client_secret": self.client.client_secret, + "redirect_uri": self.client.default_redirect_uri, + "grant_type": "authorization_code", + "code": code, + "state": uuid.uuid4().hex, } if scope is not None: - post_data['scope'] = ' '.join(scope) + post_data["scope"] = " ".join(scope) return post_data @@ -97,24 +91,24 @@ def _refresh_token_post_data(self, refresh_token, scope=None): All the data that will be POSTed to the Token Endpoint. """ post_data = { - 'client_id': self.client.client_id, - 'client_secret': self.client.client_secret, - 'grant_type': 'refresh_token', - 'refresh_token': refresh_token, + "client_id": self.client.client_id, + "client_secret": self.client.client_secret, + "grant_type": "refresh_token", + "refresh_token": refresh_token, } if scope is not None: - post_data['scope'] = ' '.join(scope) + post_data["scope"] = " ".join(scope) return post_data def _client_credentials_post_data(self, scope=None): post_data = { - 'client_id': self.client.client_id, - 'client_secret': self.client.client_secret, - 'grant_type': 'client_credentials', + "client_id": self.client.client_id, + "client_secret": self.client.client_secret, + "grant_type": "client_credentials", } if scope is not None: - post_data['scope'] = ' '.join(scope) + post_data["scope"] = " ".join(scope) return post_data def _post_request(self, post_data, extras={}): @@ -123,13 +117,14 @@ def _post_request(self, post_data, extras={}): `post_data` parameters using the 'application/x-www-form-urlencoded' format. """ - url = reverse('oidc_provider:token') + url = reverse("oidc_provider:token") request = self.factory.post( url, data=urlencode(post_data), - content_type='application/x-www-form-urlencoded', - **extras) + content_type="application/x-www-form-urlencoded", + **extras, + ) response = TokenView.as_view()(request) @@ -144,7 +139,8 @@ def _create_code(self, scope=None): client=self.client, scope=(scope if scope else TokenTestCase.SCOPE_LIST), nonce=FAKE_NONCE, - is_authentication=True) + is_authentication=True, + ) code.save() return code @@ -153,141 +149,135 @@ def _get_keys(self): """ Get public key from discovery. """ - request = self.factory.get(reverse('oidc_provider:jwks')) + request = self.factory.get(reverse("oidc_provider:jwks")) response = JwksView.as_view()(request) - jwks_dic = json.loads(response.content.decode('utf-8')) + jwks_dic = json.loads(response.content.decode("utf-8")) SIGKEYS = KEYS() SIGKEYS.load_dict(jwks_dic) return SIGKEYS def _get_userinfo(self, access_token): - url = reverse('oidc_provider:userinfo') + url = reverse("oidc_provider:userinfo") request = self.factory.get(url) - request.META['HTTP_AUTHORIZATION'] = 'Bearer ' + access_token + request.META["HTTP_AUTHORIZATION"] = "Bearer " + access_token return userinfo(request) def _password_grant_auth_header(self): - user_pass = self.client.client_id + ':' + self.client.client_secret - auth = b'Basic ' + b64encode(user_pass.encode('utf-8')) - auth_header = {'HTTP_AUTHORIZATION': auth.decode('utf-8')} + user_pass = self.client.client_id + ":" + self.client.client_secret + auth = b"Basic " + b64encode(user_pass.encode("utf-8")) + auth_header = {"HTTP_AUTHORIZATION": auth.decode("utf-8")} return auth_header def test_default_setting_does_not_allow_grant_type_password(self): post_data = self._password_grant_post_data() response = self._post_request( - post_data=post_data, - extras=self._password_grant_auth_header() + post_data=post_data, extras=self._password_grant_auth_header() ) - response_dict = json.loads(response.content.decode('utf-8')) + response_dict = json.loads(response.content.decode("utf-8")) self.assertEqual(400, response.status_code) - self.assertEqual('unsupported_grant_type', response_dict['error']) + self.assertEqual("unsupported_grant_type", response_dict["error"]) @override_settings(OIDC_GRANT_TYPE_PASSWORD_ENABLE=True) def test_password_grant_get_access_token_without_scope(self): post_data = self._password_grant_post_data() - del (post_data['scope']) + del post_data["scope"] response = self._post_request( - post_data=post_data, - extras=self._password_grant_auth_header() + post_data=post_data, extras=self._password_grant_auth_header() ) - response_dict = json.loads(response.content.decode('utf-8')) - self.assertIn('access_token', response_dict) + response_dict = json.loads(response.content.decode("utf-8")) + self.assertIn("access_token", response_dict) @override_settings(OIDC_GRANT_TYPE_PASSWORD_ENABLE=True) def test_password_grant_get_access_token_with_scope(self): response = self._post_request( - post_data=self._password_grant_post_data(), - extras=self._password_grant_auth_header() + post_data=self._password_grant_post_data(), extras=self._password_grant_auth_header() ) - response_dict = json.loads(response.content.decode('utf-8')) - self.assertIn('access_token', response_dict) + response_dict = json.loads(response.content.decode("utf-8")) + self.assertIn("access_token", response_dict) @override_settings(OIDC_GRANT_TYPE_PASSWORD_ENABLE=True) def test_password_grant_get_access_token_invalid_user_credentials(self): invalid_post = self._password_grant_post_data() - invalid_post['password'] = 'wrong!' + invalid_post["password"] = "wrong!" response = self._post_request( - post_data=invalid_post, - extras=self._password_grant_auth_header() + post_data=invalid_post, extras=self._password_grant_auth_header() ) - response_dict = json.loads(response.content.decode('utf-8')) + response_dict = json.loads(response.content.decode("utf-8")) self.assertEqual(403, response.status_code) - self.assertEqual('access_denied', response_dict['error']) + self.assertEqual("access_denied", response_dict["error"]) def test_password_grant_get_access_token_invalid_client_credentials(self): - self.client.client_id = 'foo' - self.client.client_secret = 'bar' + self.client.client_id = "foo" + self.client.client_secret = "bar" response = self._post_request( - post_data=self._password_grant_post_data(), - extras=self._password_grant_auth_header() + post_data=self._password_grant_post_data(), extras=self._password_grant_auth_header() ) - response_dict = json.loads(response.content.decode('utf-8')) + response_dict = json.loads(response.content.decode("utf-8")) self.assertEqual(400, response.status_code) - self.assertEqual('invalid_client', response_dict['error']) + self.assertEqual("invalid_client", response_dict["error"]) def test_password_grant_full_response(self): - self.check_password_grant(scope=['openid', 'email']) + self.check_password_grant(scope=["openid", "email"]) def test_password_grant_scope(self): - scopes_list = ['openid', 'profile'] + scopes_list = ["openid", "profile"] self.client.scope = scopes_list self.client.save() self.check_password_grant(scope=scopes_list) - @override_settings(OIDC_TOKEN_EXPIRE=120, - OIDC_GRANT_TYPE_PASSWORD_ENABLE=True) + @override_settings(OIDC_TOKEN_EXPIRE=120, OIDC_GRANT_TYPE_PASSWORD_ENABLE=True) def check_password_grant(self, scope): response = self._post_request( post_data=self._password_grant_post_data(scope), - extras=self._password_grant_auth_header() + extras=self._password_grant_auth_header(), ) - response_dict = json.loads(response.content.decode('utf-8')) - id_token = JWS().verify_compact( - response_dict['id_token'].encode('utf-8'), self._get_keys()) + response_dict = json.loads(response.content.decode("utf-8")) + id_token = JWS().verify_compact(response_dict["id_token"].encode("utf-8"), self._get_keys()) token = Token.objects.get(user=self.user) - self.assertEqual(response_dict['access_token'], token.access_token) - self.assertEqual(response_dict['refresh_token'], token.refresh_token) - self.assertEqual(response_dict['expires_in'], 120) - self.assertEqual(response_dict['token_type'], 'bearer') - self.assertEqual(id_token['sub'], str(self.user.id)) - self.assertEqual(id_token['aud'], self.client.client_id) + self.assertEqual(response_dict["access_token"], token.access_token) + self.assertEqual(response_dict["refresh_token"], token.refresh_token) + self.assertEqual(response_dict["expires_in"], 120) + self.assertEqual(response_dict["token_type"], "bearer") + self.assertEqual(id_token["sub"], str(self.user.id)) + self.assertEqual(id_token["aud"], self.client.client_id) # Check the scope is honored by checking the claims in the userinfo - userinfo_response = self._get_userinfo(response_dict['access_token']) - userinfo = json.loads(userinfo_response.content.decode('utf-8')) + userinfo_response = self._get_userinfo(response_dict["access_token"]) + userinfo = json.loads(userinfo_response.content.decode("utf-8")) - for (scope_param, claim) in [('email', 'email'), ('profile', 'name')]: + for scope_param, claim in [("email", "email"), ("profile", "name")]: if scope_param in scope: self.assertIn(claim, userinfo) else: self.assertNotIn(claim, userinfo) - @override_settings(OIDC_GRANT_TYPE_PASSWORD_ENABLE=True, - AUTHENTICATION_BACKENDS=("oidc_provider.tests.app.utils.TestAuthBackend",)) + @override_settings( + OIDC_GRANT_TYPE_PASSWORD_ENABLE=True, + AUTHENTICATION_BACKENDS=("oidc_provider.tests.app.utils.TestAuthBackend",), + ) def test_password_grant_passes_request_to_backend(self): response = self._post_request( - post_data=self._password_grant_post_data(), - extras=self._password_grant_auth_header() + post_data=self._password_grant_post_data(), extras=self._password_grant_auth_header() ) - response_dict = json.loads(response.content.decode('utf-8')) - self.assertIn('access_token', response_dict) + response_dict = json.loads(response.content.decode("utf-8")) + self.assertIn("access_token", response_dict) @override_settings(OIDC_TOKEN_EXPIRE=720) def test_authorization_code(self): @@ -302,17 +292,17 @@ def test_authorization_code(self): post_data = self._auth_code_post_data(code=code.code) response = self._post_request(post_data) - response_dic = json.loads(response.content.decode('utf-8')) + response_dic = json.loads(response.content.decode("utf-8")) - id_token = JWS().verify_compact(response_dic['id_token'].encode('utf-8'), SIGKEYS) + id_token = JWS().verify_compact(response_dic["id_token"].encode("utf-8"), SIGKEYS) token = Token.objects.get(user=self.user) - self.assertEqual(response_dic['access_token'], token.access_token) - self.assertEqual(response_dic['refresh_token'], token.refresh_token) - self.assertEqual(response_dic['token_type'], 'bearer') - self.assertEqual(response_dic['expires_in'], 720) - self.assertEqual(id_token['sub'], str(self.user.id)) - self.assertEqual(id_token['aud'], self.client.client_id) + self.assertEqual(response_dic["access_token"], token.access_token) + self.assertEqual(response_dic["refresh_token"], token.refresh_token) + self.assertEqual(response_dic["token_type"], "bearer") + self.assertEqual(response_dic["expires_in"], 720) + self.assertEqual(id_token["sub"], str(self.user.id)) + self.assertEqual(id_token["aud"], self.client.client_id) @override_settings(OIDC_TOKEN_EXPIRE=720) def test_authorization_code_cant_be_reused(self): @@ -323,46 +313,44 @@ def test_authorization_code_cant_be_reused(self): code = self._create_code() post_data = self._auth_code_post_data(code=code.code) - with patch('django.db.models.query.QuerySet.select_for_update') as select_for_update_func: + with patch("django.db.models.query.QuerySet.select_for_update") as select_for_update_func: select_for_update_func.side_effect = DatabaseError() response = self._post_request(post_data) select_for_update_func.assert_called_once() self.assertEqual(response.status_code, 400) - response_dic = json.loads(response.content.decode('utf-8')) - self.assertEqual(response_dic['error'], 'invalid_grant') + response_dic = json.loads(response.content.decode("utf-8")) + self.assertEqual(response_dic["error"], "invalid_grant") - @override_settings(OIDC_TOKEN_EXPIRE=720, - OIDC_IDTOKEN_INCLUDE_CLAIMS=True) + @override_settings(OIDC_TOKEN_EXPIRE=720, OIDC_IDTOKEN_INCLUDE_CLAIMS=True) def test_scope_is_ignored_for_auth_code(self): """ Scope is ignored for token respones to auth code grant type. This comes down to that the scopes requested in authorize are returned. """ SIGKEYS = self._get_keys() - for code_scope in [['openid'], ['openid', 'email'], ['openid', 'profile']]: + for code_scope in [["openid"], ["openid", "email"], ["openid", "profile"]]: code = self._create_code(code_scope) - post_data = self._auth_code_post_data( - code=code.code, scope=code_scope) + post_data = self._auth_code_post_data(code=code.code, scope=code_scope) response = self._post_request(post_data) - response_dic = json.loads(response.content.decode('utf-8')) + response_dic = json.loads(response.content.decode("utf-8")) self.assertEqual(response.status_code, 200) - id_token = JWS().verify_compact(response_dic['id_token'].encode('utf-8'), SIGKEYS) + id_token = JWS().verify_compact(response_dic["id_token"].encode("utf-8"), SIGKEYS) - if 'email' in code_scope: - self.assertIn('email', id_token) - self.assertIn('email_verified', id_token) + if "email" in code_scope: + self.assertIn("email", id_token) + self.assertIn("email_verified", id_token) else: - self.assertNotIn('email', id_token) + self.assertNotIn("email", id_token) - if 'profile' in code_scope: - self.assertIn('given_name', id_token) + if "profile" in code_scope: + self.assertIn("given_name", id_token) else: - self.assertNotIn('given_name', id_token) + self.assertNotIn("given_name", id_token) def test_refresh_token(self): """ @@ -380,7 +368,7 @@ def test_refresh_token_invalid_scope(self): though the original authorized scope in the authorization code request is only ['openid', 'email']. """ - self.do_refresh_token_check(scope=['openid', 'profile']) + self.do_refresh_token_check(scope=["openid", "profile"]) def test_refresh_token_narrowed_scope(self): """ @@ -390,7 +378,7 @@ def test_refresh_token_narrowed_scope(self): though the original authorized scope in the authorization code request is ['openid', 'email']. """ - self.do_refresh_token_check(scope=['openid']) + self.do_refresh_token_check(scope=["openid"]) @override_settings(OIDC_IDTOKEN_INCLUDE_CLAIMS=True) def do_refresh_token_check(self, scope=None): @@ -401,70 +389,69 @@ def do_refresh_token_check(self, scope=None): self.assertEqual(code.scope, TokenTestCase.SCOPE_LIST) post_data = self._auth_code_post_data(code=code.code) start_time = time.time() - with patch('oidc_provider.lib.utils.token.time.time') as time_func: + with patch("oidc_provider.lib.utils.token.time.time") as time_func: time_func.return_value = start_time response = self._post_request(post_data) - response_dic1 = json.loads(response.content.decode('utf-8')) - id_token1 = JWS().verify_compact(response_dic1['id_token'].encode('utf-8'), SIGKEYS) + response_dic1 = json.loads(response.content.decode("utf-8")) + id_token1 = JWS().verify_compact(response_dic1["id_token"].encode("utf-8"), SIGKEYS) # Use refresh token to obtain new token - post_data = self._refresh_token_post_data( - response_dic1['refresh_token'], scope) - with patch('oidc_provider.lib.utils.token.time.time') as time_func: + post_data = self._refresh_token_post_data(response_dic1["refresh_token"], scope) + with patch("oidc_provider.lib.utils.token.time.time") as time_func: time_func.return_value = start_time + 600 response = self._post_request(post_data) - response_dic2 = json.loads(response.content.decode('utf-8')) + response_dic2 = json.loads(response.content.decode("utf-8")) if scope and set(scope) - set(code.scope): # too broad scope self.assertEqual(response.status_code, 400) # Bad Request - self.assertIn('error', response_dic2) - self.assertEqual(response_dic2['error'], 'invalid_scope') + self.assertIn("error", response_dic2) + self.assertEqual(response_dic2["error"], "invalid_scope") return # No more checks - id_token2 = JWS().verify_compact(response_dic2['id_token'].encode('utf-8'), SIGKEYS) + id_token2 = JWS().verify_compact(response_dic2["id_token"].encode("utf-8"), SIGKEYS) - if scope and 'email' not in scope: # narrowed scope The auth + if scope and "email" not in scope: # narrowed scope The auth # The auth code request had email in scope, so it should be # in the first id token - self.assertIn('email', id_token1) + self.assertIn("email", id_token1) # but the refresh request had no email in scope - self.assertNotIn('email', id_token2, 'email was not requested') + self.assertNotIn("email", id_token2, "email was not requested") - self.assertNotEqual(response_dic1['id_token'], response_dic2['id_token']) - self.assertNotEqual(response_dic1['access_token'], response_dic2['access_token']) - self.assertNotEqual(response_dic1['refresh_token'], response_dic2['refresh_token']) + self.assertNotEqual(response_dic1["id_token"], response_dic2["id_token"]) + self.assertNotEqual(response_dic1["access_token"], response_dic2["access_token"]) + self.assertNotEqual(response_dic1["refresh_token"], response_dic2["refresh_token"]) # http://openid.net/specs/openid-connect-core-1_0.html#rfc.section.12.2 - self.assertEqual(id_token1['iss'], id_token2['iss']) - self.assertEqual(id_token1['sub'], id_token2['sub']) - self.assertNotEqual(id_token1['iat'], id_token2['iat']) - self.assertEqual(id_token1['iat'], int(start_time)) - self.assertEqual(id_token2['iat'], int(start_time + 600)) - self.assertEqual(id_token1['aud'], id_token2['aud']) - self.assertEqual(id_token1['auth_time'], id_token2['auth_time']) - self.assertEqual(id_token1.get('azp'), id_token2.get('azp')) + self.assertEqual(id_token1["iss"], id_token2["iss"]) + self.assertEqual(id_token1["sub"], id_token2["sub"]) + self.assertNotEqual(id_token1["iat"], id_token2["iat"]) + self.assertEqual(id_token1["iat"], int(start_time)) + self.assertEqual(id_token2["iat"], int(start_time + 600)) + self.assertEqual(id_token1["aud"], id_token2["aud"]) + self.assertEqual(id_token1["auth_time"], id_token2["auth_time"]) + self.assertEqual(id_token1.get("azp"), id_token2.get("azp")) # Refresh token can't be reused - post_data = self._refresh_token_post_data(response_dic1['refresh_token']) + post_data = self._refresh_token_post_data(response_dic1["refresh_token"]) response = self._post_request(post_data) - self.assertIn('invalid_grant', response.content.decode('utf-8')) + self.assertIn("invalid_grant", response.content.decode("utf-8")) # Old access token is invalidated - self.assertEqual(self._get_userinfo(response_dic1['access_token']).status_code, 401) - self.assertEqual(self._get_userinfo(response_dic2['access_token']).status_code, 200) + self.assertEqual(self._get_userinfo(response_dic1["access_token"]).status_code, 401) + self.assertEqual(self._get_userinfo(response_dic2["access_token"]).status_code, 200) # Empty refresh token is invalid - post_data = self._refresh_token_post_data('') + post_data = self._refresh_token_post_data("") response = self._post_request(post_data) - self.assertIn('invalid_grant', response.content.decode('utf-8')) + self.assertIn("invalid_grant", response.content.decode("utf-8")) # No refresh token is invalid - post_data = self._refresh_token_post_data('') - del post_data['refresh_token'] + post_data = self._refresh_token_post_data("") + del post_data["refresh_token"] response = self._post_request(post_data) - self.assertIn('invalid_grant', response.content.decode('utf-8')) + self.assertIn("invalid_grant", response.content.decode("utf-8")) def test_client_redirect_uri(self): """ @@ -477,29 +464,29 @@ def test_client_redirect_uri(self): post_data = self._auth_code_post_data(code=code.code) # Unregistered URI - post_data['redirect_uri'] = 'http://invalid.example.org' + post_data["redirect_uri"] = "http://invalid.example.org" response = self._post_request(post_data) - self.assertIn('invalid_client', response.content.decode('utf-8')) + self.assertIn("invalid_client", response.content.decode("utf-8")) # Registered URI, but with query string appended - post_data['redirect_uri'] = self.client.default_redirect_uri + '?foo=bar' + post_data["redirect_uri"] = self.client.default_redirect_uri + "?foo=bar" response = self._post_request(post_data) - self.assertIn('invalid_client', response.content.decode('utf-8')) + self.assertIn("invalid_client", response.content.decode("utf-8")) # Registered URI - post_data['redirect_uri'] = self.client.default_redirect_uri + post_data["redirect_uri"] = self.client.default_redirect_uri response = self._post_request(post_data) - self.assertNotIn('invalid_client', response.content.decode('utf-8')) + self.assertNotIn("invalid_client", response.content.decode("utf-8")) def test_request_methods(self): """ Client sends an HTTP POST request to the Token Endpoint. Other request methods MUST NOT be allowed. """ - url = reverse('oidc_provider:token') + url = reverse("oidc_provider:token") requests = [ self.factory.get(url), @@ -511,16 +498,18 @@ def test_request_methods(self): response = TokenView.as_view()(request) self.assertEqual( - response.status_code, 405, - msg=request.method + ' request does not return a 405 status.') + response.status_code, + 405, + msg=request.method + " request does not return a 405 status.", + ) request = self.factory.post(url) response = TokenView.as_view()(request) self.assertEqual( - response.status_code, 400, - msg=request.method + ' request does not return a 400 status.') + response.status_code, 400, msg=request.method + " request does not return a 400 status." + ) def test_client_authentication(self): """ @@ -538,42 +527,45 @@ def test_client_authentication(self): response = self._post_request(post_data) self.assertNotIn( - 'invalid_client', - response.content.decode('utf-8'), - msg='Client authentication fails using request-body credentials.') + "invalid_client", + response.content.decode("utf-8"), + msg="Client authentication fails using request-body credentials.", + ) # Now, test with an invalid client_id. invalid_data = post_data.copy() - invalid_data['client_id'] = self.client.client_id * 2 # Fake id. + invalid_data["client_id"] = self.client.client_id * 2 # Fake id. # Create another grant code. code = self._create_code() - invalid_data['code'] = code.code + invalid_data["code"] = code.code response = self._post_request(invalid_data) self.assertIn( - 'invalid_client', - response.content.decode('utf-8'), - msg='Client authentication success with an invalid "client_id".') + "invalid_client", + response.content.decode("utf-8"), + msg='Client authentication success with an invalid "client_id".', + ) # Now, test using HTTP Basic Authentication method. basicauth_data = post_data.copy() # Create another grant code. code = self._create_code() - basicauth_data['code'] = code.code + basicauth_data["code"] = code.code - del basicauth_data['client_id'] - del basicauth_data['client_secret'] + del basicauth_data["client_id"] + del basicauth_data["client_secret"] response = self._post_request(basicauth_data, self._password_grant_auth_header()) - response.content.decode('utf-8') + response.content.decode("utf-8") self.assertNotIn( - 'invalid_client', - response.content.decode('utf-8'), - msg='Client authentication fails using HTTP Basic Auth.') + "invalid_client", + response.content.decode("utf-8"), + msg="Client authentication fails using HTTP Basic Auth.", + ) def test_access_token_contains_nonce(self): """ @@ -591,21 +583,21 @@ def test_access_token_contains_nonce(self): response = self._post_request(post_data) - response_dic = json.loads(response.content.decode('utf-8')) - id_token = JWT().unpack(response_dic['id_token'].encode('utf-8')).payload() + response_dic = json.loads(response.content.decode("utf-8")) + id_token = JWT().unpack(response_dic["id_token"].encode("utf-8")).payload() - self.assertEqual(id_token.get('nonce'), FAKE_NONCE) + self.assertEqual(id_token.get("nonce"), FAKE_NONCE) # Client does not supply a nonce parameter. - code.nonce = '' + code.nonce = "" code.save() response = self._post_request(post_data) - response_dic = json.loads(response.content.decode('utf-8')) + response_dic = json.loads(response.content.decode("utf-8")) - id_token = JWT().unpack(response_dic['id_token'].encode('utf-8')).payload() + id_token = JWT().unpack(response_dic["id_token"].encode("utf-8")).payload() - self.assertEqual(id_token.get('nonce'), None) + self.assertEqual(id_token.get("nonce"), None) def test_id_token_contains_at_hash(self): """ @@ -617,10 +609,10 @@ def test_id_token_contains_at_hash(self): response = self._post_request(post_data) - response_dic = json.loads(response.content.decode('utf-8')) - id_token = JWT().unpack(response_dic['id_token'].encode('utf-8')).payload() + response_dic = json.loads(response.content.decode("utf-8")) + id_token = JWT().unpack(response_dic["id_token"].encode("utf-8")).payload() - self.assertTrue(id_token.get('at_hash')) + self.assertTrue(id_token.get("at_hash")) def test_idtoken_sign_validation(self): """ @@ -629,19 +621,20 @@ def test_idtoken_sign_validation(self): the JOSE Header. """ SIGKEYS = self._get_keys() - RSAKEYS = [k for k in SIGKEYS if k.kty == 'RSA'] + RSAKEYS = [k for k in SIGKEYS if k.kty == "RSA"] code = self._create_code() post_data = self._auth_code_post_data(code=code.code) response = self._post_request(post_data) - response_dic = json.loads(response.content.decode('utf-8')) + response_dic = json.loads(response.content.decode("utf-8")) - JWS().verify_compact(response_dic['id_token'].encode('utf-8'), RSAKEYS) + JWS().verify_compact(response_dic["id_token"].encode("utf-8"), RSAKEYS) @override_settings( - OIDC_IDTOKEN_SUB_GENERATOR='oidc_provider.tests.app.utils.fake_sub_generator') + OIDC_IDTOKEN_SUB_GENERATOR="oidc_provider.tests.app.utils.fake_sub_generator" + ) def test_custom_sub_generator(self): """ Test custom function for setting OIDC_IDTOKEN_SUB_GENERATOR. @@ -652,35 +645,15 @@ def test_custom_sub_generator(self): response = self._post_request(post_data) - response_dic = json.loads(response.content.decode('utf-8')) - id_token = JWT().unpack(response_dic['id_token'].encode('utf-8')).payload() + response_dic = json.loads(response.content.decode("utf-8")) + id_token = JWT().unpack(response_dic["id_token"].encode("utf-8")).payload() - self.assertEqual(id_token.get('sub'), self.user.email) + self.assertEqual(id_token.get("sub"), self.user.email) @override_settings( - OIDC_IDTOKEN_PROCESSING_HOOK='oidc_provider.tests.app.utils.fake_idtoken_processing_hook') - def test_additional_idtoken_processing_hook(self): - """ - Test custom function for setting OIDC_IDTOKEN_PROCESSING_HOOK. - """ - code = self._create_code() - - post_data = self._auth_code_post_data(code=code.code) - - response = self._post_request(post_data) - - response_dic = json.loads(response.content.decode('utf-8')) - id_token = JWT().unpack(response_dic['id_token'].encode('utf-8')).payload() - - self.assertEqual(id_token.get('test_idtoken_processing_hook'), FAKE_RANDOM_STRING) - self.assertEqual(id_token.get('test_idtoken_processing_hook_user_email'), self.user.email) - - @override_settings( - OIDC_IDTOKEN_PROCESSING_HOOK=( - 'oidc_provider.tests.app.utils.fake_idtoken_processing_hook', - ) + OIDC_IDTOKEN_PROCESSING_HOOK=("oidc_provider.tests.app.utils.fake_idtoken_processing_hook",) ) - def test_additional_idtoken_processing_hook_one_element_in_tuple(self): + def test_additional_idtoken_processing_hook(self): """ Test custom function for setting OIDC_IDTOKEN_PROCESSING_HOOK. """ @@ -690,15 +663,15 @@ def test_additional_idtoken_processing_hook_one_element_in_tuple(self): response = self._post_request(post_data) - response_dic = json.loads(response.content.decode('utf-8')) - id_token = JWT().unpack(response_dic['id_token'].encode('utf-8')).payload() + response_dic = json.loads(response.content.decode("utf-8")) + id_token = JWT().unpack(response_dic["id_token"].encode("utf-8")).payload() - self.assertEqual(id_token.get('test_idtoken_processing_hook'), FAKE_RANDOM_STRING) - self.assertEqual(id_token.get('test_idtoken_processing_hook_user_email'), self.user.email) + self.assertEqual(id_token.get("test_idtoken_processing_hook"), FAKE_RANDOM_STRING) + self.assertEqual(id_token.get("test_idtoken_processing_hook_user_email"), self.user.email) @override_settings( OIDC_IDTOKEN_PROCESSING_HOOK=[ - 'oidc_provider.tests.app.utils.fake_idtoken_processing_hook', + "oidc_provider.tests.app.utils.fake_idtoken_processing_hook", ] ) def test_additional_idtoken_processing_hook_one_element_in_list(self): @@ -711,16 +684,16 @@ def test_additional_idtoken_processing_hook_one_element_in_list(self): response = self._post_request(post_data) - response_dic = json.loads(response.content.decode('utf-8')) - id_token = JWT().unpack(response_dic['id_token'].encode('utf-8')).payload() + response_dic = json.loads(response.content.decode("utf-8")) + id_token = JWT().unpack(response_dic["id_token"].encode("utf-8")).payload() - self.assertEqual(id_token.get('test_idtoken_processing_hook'), FAKE_RANDOM_STRING) - self.assertEqual(id_token.get('test_idtoken_processing_hook_user_email'), self.user.email) + self.assertEqual(id_token.get("test_idtoken_processing_hook"), FAKE_RANDOM_STRING) + self.assertEqual(id_token.get("test_idtoken_processing_hook_user_email"), self.user.email) @override_settings( OIDC_IDTOKEN_PROCESSING_HOOK=[ - 'oidc_provider.tests.app.utils.fake_idtoken_processing_hook', - 'oidc_provider.tests.app.utils.fake_idtoken_processing_hook2', + "oidc_provider.tests.app.utils.fake_idtoken_processing_hook", + "oidc_provider.tests.app.utils.fake_idtoken_processing_hook2", ] ) def test_additional_idtoken_processing_hook_two_elements_in_list(self): @@ -733,19 +706,19 @@ def test_additional_idtoken_processing_hook_two_elements_in_list(self): response = self._post_request(post_data) - response_dic = json.loads(response.content.decode('utf-8')) - id_token = JWT().unpack(response_dic['id_token'].encode('utf-8')).payload() + response_dic = json.loads(response.content.decode("utf-8")) + id_token = JWT().unpack(response_dic["id_token"].encode("utf-8")).payload() - self.assertEqual(id_token.get('test_idtoken_processing_hook'), FAKE_RANDOM_STRING) - self.assertEqual(id_token.get('test_idtoken_processing_hook_user_email'), self.user.email) + self.assertEqual(id_token.get("test_idtoken_processing_hook"), FAKE_RANDOM_STRING) + self.assertEqual(id_token.get("test_idtoken_processing_hook_user_email"), self.user.email) - self.assertEqual(id_token.get('test_idtoken_processing_hook2'), FAKE_RANDOM_STRING) - self.assertEqual(id_token.get('test_idtoken_processing_hook_user_email2'), self.user.email) + self.assertEqual(id_token.get("test_idtoken_processing_hook2"), FAKE_RANDOM_STRING) + self.assertEqual(id_token.get("test_idtoken_processing_hook_user_email2"), self.user.email) @override_settings( OIDC_IDTOKEN_PROCESSING_HOOK=( - 'oidc_provider.tests.app.utils.fake_idtoken_processing_hook', - 'oidc_provider.tests.app.utils.fake_idtoken_processing_hook2', + "oidc_provider.tests.app.utils.fake_idtoken_processing_hook", + "oidc_provider.tests.app.utils.fake_idtoken_processing_hook2", ) ) def test_additional_idtoken_processing_hook_two_elements_in_tuple(self): @@ -758,43 +731,41 @@ def test_additional_idtoken_processing_hook_two_elements_in_tuple(self): response = self._post_request(post_data) - response_dic = json.loads(response.content.decode('utf-8')) - id_token = JWT().unpack(response_dic['id_token'].encode('utf-8')).payload() + response_dic = json.loads(response.content.decode("utf-8")) + id_token = JWT().unpack(response_dic["id_token"].encode("utf-8")).payload() - self.assertEqual(id_token.get('test_idtoken_processing_hook'), FAKE_RANDOM_STRING) - self.assertEqual(id_token.get('test_idtoken_processing_hook_user_email'), self.user.email) + self.assertEqual(id_token.get("test_idtoken_processing_hook"), FAKE_RANDOM_STRING) + self.assertEqual(id_token.get("test_idtoken_processing_hook_user_email"), self.user.email) - self.assertEqual(id_token.get('test_idtoken_processing_hook2'), FAKE_RANDOM_STRING) - self.assertEqual(id_token.get('test_idtoken_processing_hook_user_email2'), self.user.email) + self.assertEqual(id_token.get("test_idtoken_processing_hook2"), FAKE_RANDOM_STRING) + self.assertEqual(id_token.get("test_idtoken_processing_hook_user_email2"), self.user.email) @override_settings( - OIDC_IDTOKEN_PROCESSING_HOOK=( - 'oidc_provider.tests.app.utils.fake_idtoken_processing_hook3')) + OIDC_IDTOKEN_PROCESSING_HOOK=("oidc_provider.tests.app.utils.fake_idtoken_processing_hook3") + ) def test_additional_idtoken_processing_hook_scope_available(self): """ Test scope is available in OIDC_IDTOKEN_PROCESSING_HOOK. """ - id_token = self._request_id_token_with_scope( - ['openid', 'email', 'profile', 'dummy']) + id_token = self._request_id_token_with_scope(["openid", "email", "profile", "dummy"]) self.assertEqual( - id_token.get('scope_of_token_passed_to_processing_hook'), - ['openid', 'email', 'profile', 'dummy']) + id_token.get("scope_of_token_passed_to_processing_hook"), + ["openid", "email", "profile", "dummy"], + ) @override_settings( - OIDC_IDTOKEN_PROCESSING_HOOK=( - 'oidc_provider.tests.app.utils.fake_idtoken_processing_hook4')) + OIDC_IDTOKEN_PROCESSING_HOOK=("oidc_provider.tests.app.utils.fake_idtoken_processing_hook4") + ) def test_additional_idtoken_processing_hook_kwargs(self): """ Test correct kwargs are passed to OIDC_IDTOKEN_PROCESSING_HOOK. """ - id_token = self._request_id_token_with_scope(['openid', 'profile']) - kwargs_passed = id_token.get('kwargs_passed_to_processing_hook') + id_token = self._request_id_token_with_scope(["openid", "profile"]) + kwargs_passed = id_token.get("kwargs_passed_to_processing_hook") assert kwargs_passed - self.assertTrue(kwargs_passed.get('token').startswith( - '") - self.assertEqual(set(kwargs_passed.keys()), {'token', 'request'}) + self.assertTrue(kwargs_passed.get("token").startswith("") + self.assertEqual(set(kwargs_passed.keys()), {"token", "request"}) def _request_id_token_with_scope(self, scope): code = self._create_code(scope) @@ -803,8 +774,8 @@ def _request_id_token_with_scope(self, scope): response = self._post_request(post_data) - response_dic = json.loads(response.content.decode('utf-8')) - id_token = JWT().unpack(response_dic['id_token'].encode('utf-8')).payload() + response_dic = json.loads(response.content.decode("utf-8")) + id_token = JWT().unpack(response_dic["id_token"].encode("utf-8")).payload() return id_token def test_pkce_parameters(self): @@ -812,19 +783,25 @@ def test_pkce_parameters(self): Test Proof Key for Code Exchange by OAuth Public Clients. https://tools.ietf.org/html/rfc7636 """ - code = create_code(user=self.user, client=self.client, - scope=['openid', 'email'], nonce=FAKE_NONCE, is_authentication=True, - code_challenge=FAKE_CODE_CHALLENGE, code_challenge_method='S256') + code = create_code( + user=self.user, + client=self.client, + scope=["openid", "email"], + nonce=FAKE_NONCE, + is_authentication=True, + code_challenge=FAKE_CODE_CHALLENGE, + code_challenge_method="S256", + ) code.save() post_data = self._auth_code_post_data(code=code.code) # Add parameters. - post_data['code_verifier'] = FAKE_CODE_VERIFIER + post_data["code_verifier"] = FAKE_CODE_VERIFIER response = self._post_request(post_data) - self.assertIn('access_token', json.loads(response.content.decode('utf-8'))) + self.assertIn("access_token", json.loads(response.content.decode("utf-8"))) def test_pkce_missing_code_verifier(self): """ @@ -832,22 +809,28 @@ def test_pkce_missing_code_verifier(self): fails when PKCE was used during the authorization request. """ - code = create_code(user=self.user, client=self.client, - scope=['openid', 'email'], nonce=FAKE_NONCE, is_authentication=True, - code_challenge=FAKE_CODE_CHALLENGE, code_challenge_method='S256') + code = create_code( + user=self.user, + client=self.client, + scope=["openid", "email"], + nonce=FAKE_NONCE, + is_authentication=True, + code_challenge=FAKE_CODE_CHALLENGE, + code_challenge_method="S256", + ) code.save() post_data = self._auth_code_post_data(code=code.code) - assert 'code_verifier' not in post_data + assert "code_verifier" not in post_data response = self._post_request(post_data) - assert json.loads(response.content.decode('utf-8')).get('error') == 'invalid_grant' + assert json.loads(response.content.decode("utf-8")).get("error") == "invalid_grant" @override_settings(OIDC_INTROSPECTION_VALIDATE_AUDIENCE_SCOPE=False) def test_client_credentials_grant_type(self): - fake_scopes_list = ['scopeone', 'scopetwo', INTROSPECTION_SCOPE] + fake_scopes_list = ["scopeone", "scopetwo", INTROSPECTION_SCOPE] # Add scope for this client. self.client.scope = fake_scopes_list @@ -855,139 +838,139 @@ def test_client_credentials_grant_type(self): post_data = self._client_credentials_post_data() response = self._post_request(post_data) - response_dict = json.loads(response.content.decode('utf-8')) + response_dict = json.loads(response.content.decode("utf-8")) # Ensure access token exists in the response, also check if scopes are # the ones we registered previously. - self.assertTrue('access_token' in response_dict) - self.assertEqual(' '.join(fake_scopes_list), response_dict['scope']) + self.assertTrue("access_token" in response_dict) + self.assertEqual(" ".join(fake_scopes_list), response_dict["scope"]) - access_token = response_dict['access_token'] + access_token = response_dict["access_token"] # Create a protected resource and test the access_token. - @require_http_methods(['GET']) + @require_http_methods(["GET"]) @protected_resource_view(fake_scopes_list) def protected_api(request, *args, **kwargs): - return JsonResponse({'protected': 'information'}, status=200) + return JsonResponse({"protected": "information"}, status=200) # Deploy view on some url. So, base url could be anything. - request = self.factory.get( - '/api/protected/?access_token={0}'.format(access_token)) + request = self.factory.get("/api/protected/?access_token={0}".format(access_token)) response = protected_api(request) - response_dict = json.loads(response.content.decode('utf-8')) + response_dict = json.loads(response.content.decode("utf-8")) self.assertEqual(response.status_code, 200) - self.assertTrue('protected' in response_dict) + self.assertTrue("protected" in response_dict) # Protected resource test ends here. # Verify access_token can be validated with token introspection response = self.request_client.post( - reverse('oidc_provider:token-introspection'), data={'token': access_token}, - **self._password_grant_auth_header()) + reverse("oidc_provider:token-introspection"), + data={"token": access_token}, + **self._password_grant_auth_header(), + ) self.assertEqual(response.status_code, 200) - response_dict = json.loads(response.content.decode('utf-8')) - self.assertTrue(response_dict.get('active')) + response_dict = json.loads(response.content.decode("utf-8")) + self.assertTrue(response_dict.get("active")) # End token introspection test # Clean scopes for this client. - self.client.scope = '' + self.client.scope = "" self.client.save() response = self._post_request(post_data) - response_dict = json.loads(response.content.decode('utf-8')) + response_dict = json.loads(response.content.decode("utf-8")) # It should fail when client does not have any scope added. self.assertEqual(400, response.status_code) - self.assertEqual('invalid_scope', response_dict['error']) + self.assertEqual("invalid_scope", response_dict["error"]) def test_printing_token_used_by_client_credentials_grant_type(self): # Add scope for this client. - self.client.scope = ['something'] + self.client.scope = ["something"] self.client.save() response = self._post_request(self._client_credentials_post_data()) - response_dict = json.loads(response.content.decode('utf-8')) - token = Token.objects.get(access_token=response_dict['access_token']) + response_dict = json.loads(response.content.decode("utf-8")) + token = Token.objects.get(access_token=response_dict["access_token"]) self.assertTrue(str(token)) @override_settings(OIDC_GRANT_TYPE_PASSWORD_ENABLE=True) def test_requested_scope(self): # GRANT_TYPE=PASSWORD response = self._post_request( - post_data=self._password_grant_post_data(['openid', 'invalid_scope']), - extras=self._password_grant_auth_header() + post_data=self._password_grant_post_data(["openid", "invalid_scope"]), + extras=self._password_grant_auth_header(), ) - response_dict = json.loads(response.content.decode('utf-8')) + response_dict = json.loads(response.content.decode("utf-8")) # It should fail when client requested an invalid scope. self.assertEqual(400, response.status_code) - self.assertEqual('invalid_scope', response_dict['error']) + self.assertEqual("invalid_scope", response_dict["error"]) # happy path: no scope response = self._post_request( - post_data=self._password_grant_post_data([]), - extras=self._password_grant_auth_header() + post_data=self._password_grant_post_data([]), extras=self._password_grant_auth_header() ) - response_dict = json.loads(response.content.decode('utf-8')) + response_dict = json.loads(response.content.decode("utf-8")) self.assertEqual(200, response.status_code) - self.assertEqual(TokenTestCase.SCOPE, response_dict['scope']) + self.assertEqual(TokenTestCase.SCOPE, response_dict["scope"]) # happy path: single scope response = self._post_request( - post_data=self._password_grant_post_data(['email']), - extras=self._password_grant_auth_header() + post_data=self._password_grant_post_data(["email"]), + extras=self._password_grant_auth_header(), ) - response_dict = json.loads(response.content.decode('utf-8')) + response_dict = json.loads(response.content.decode("utf-8")) self.assertEqual(200, response.status_code) - self.assertEqual('email', response_dict['scope']) + self.assertEqual("email", response_dict["scope"]) # happy path: multiple scopes response = self._post_request( - post_data=self._password_grant_post_data(['email', 'openid']), - extras=self._password_grant_auth_header() + post_data=self._password_grant_post_data(["email", "openid"]), + extras=self._password_grant_auth_header(), ) # GRANT_TYPE=CLIENT_CREDENTIALS - response_dict = json.loads(response.content.decode('utf-8')) + response_dict = json.loads(response.content.decode("utf-8")) self.assertEqual(200, response.status_code) - self.assertEqual('email openid', response_dict['scope']) + self.assertEqual("email openid", response_dict["scope"]) response = self._post_request( - post_data=self._client_credentials_post_data(['openid', 'invalid_scope']) + post_data=self._client_credentials_post_data(["openid", "invalid_scope"]) ) - response_dict = json.loads(response.content.decode('utf-8')) + response_dict = json.loads(response.content.decode("utf-8")) # It should fail when client requested an invalid scope. self.assertEqual(400, response.status_code) - self.assertEqual('invalid_scope', response_dict['error']) + self.assertEqual("invalid_scope", response_dict["error"]) # happy path: no scope response = self._post_request(post_data=self._client_credentials_post_data()) - response_dict = json.loads(response.content.decode('utf-8')) + response_dict = json.loads(response.content.decode("utf-8")) self.assertEqual(200, response.status_code) - self.assertEqual(TokenTestCase.SCOPE, response_dict['scope']) + self.assertEqual(TokenTestCase.SCOPE, response_dict["scope"]) # happy path: single scope - response = self._post_request(post_data=self._client_credentials_post_data(['email'])) + response = self._post_request(post_data=self._client_credentials_post_data(["email"])) - response_dict = json.loads(response.content.decode('utf-8')) + response_dict = json.loads(response.content.decode("utf-8")) self.assertEqual(200, response.status_code) - self.assertEqual('email', response_dict['scope']) + self.assertEqual("email", response_dict["scope"]) # happy path: multiple scopes response = self._post_request( - post_data=self._client_credentials_post_data(['email', 'openid']) + post_data=self._client_credentials_post_data(["email", "openid"]) ) - response_dict = json.loads(response.content.decode('utf-8')) + response_dict = json.loads(response.content.decode("utf-8")) self.assertEqual(200, response.status_code) - self.assertEqual('email openid', response_dict['scope']) + self.assertEqual("email openid", response_dict["scope"]) diff --git a/oidc_provider/tests/cases/test_utils.py b/oidc_provider/tests/cases/test_utils.py index 24c9ae65..c0e64ee7 100644 --- a/oidc_provider/tests/cases/test_utils.py +++ b/oidc_provider/tests/cases/test_utils.py @@ -1,4 +1,5 @@ import time +from datetime import date from datetime import datetime from hashlib import sha224 from unittest import mock @@ -116,6 +117,30 @@ def test_create_id_token_with_include_claims_setting_and_extra(self): self.assertIn("pizza", id_token_data) self.assertEqual(id_token_data["pizza"], "Margherita") + def test_token_saving_id_token_with_non_serialized_objects(self): + client = create_fake_client("code") + token = create_token(self.user, client, scope=["openid", "email", "pizza"]) + token.id_token = { + "iss": "http://localhost:8000/openid", + "sub": "1", + "aud": "test-aud", + "exp": 1733946683, + "iat": 1733946083, + "auth_time": 1733946082, + "email": "johndoe@example.com", + "email_verified": True, + "_extra_datetime": datetime(2002, 10, 15, 9), + "_extra_date": date(2000, 12, 25), + "_extra_object": object, + } + token.save() + + # A raw datetime/date object should be serialized. + self.assertEqual(token.id_token["_extra_datetime"], "2002-10-15 09:00:00") + self.assertEqual(token.id_token["_extra_date"], "2000-12-25") + # Even a raw object should be serialized wit str() at least. + self.assertEqual(token.id_token["_extra_object"], "") + class BrowserStateTest(TestCase): @override_settings(OIDC_UNAUTHENTICATED_SESSION_MANAGEMENT_KEY="my_static_key") diff --git a/tox.ini b/tox.ini index 088c1390..42e252d0 100644 --- a/tox.ini +++ b/tox.ini @@ -3,8 +3,9 @@ envlist= docs, py38-django{32,40,41,42}, py39-django{32,40,41,42}, - py310-django{32,40,41,42}, - py311-django{32,40,41,42}, + py310-django{32,40,41,42,50,51}, + py311-django{41,42,50,51}, + py312-django{42,50,51}, flake8 [testenv] @@ -21,6 +22,8 @@ deps = django40: django>=4.0,<4.1 django41: django>=4.1,<4.2 django42: django>=4.2,<4.3 + django50: django>=5.0,<5.1 + django51: django>=5.1,<5.2 commands = pytest --cov=oidc_provider {posargs}