diff --git a/mozilla_django_oidc/auth.py b/mozilla_django_oidc/auth.py index f8243fe3..1290a1e2 100644 --- a/mozilla_django_oidc/auth.py +++ b/mozilla_django_oidc/auth.py @@ -1,19 +1,16 @@ import base64 import hashlib -import json import logging import inspect +import jwt import requests from django.contrib.auth import get_user_model from django.contrib.auth.backends import ModelBackend from django.core.exceptions import ImproperlyConfigured, SuspiciousOperation from django.urls import reverse -from django.utils.encoding import force_bytes, smart_bytes, smart_str +from django.utils.encoding import force_bytes, smart_str from django.utils.module_loading import import_string -from josepy.b64 import b64decode -from josepy.jwk import JWK -from josepy.jws import JWS, Header from requests.auth import HTTPBasicAuth from requests.exceptions import HTTPError @@ -127,10 +124,10 @@ def update_user(self, user, claims): def _verify_jws(self, payload, key): """Verify the given JWS payload with the given key and return the payload""" - jws = JWS.from_compact(payload) + jws = jwt.get_unverified_header(payload) try: - alg = jws.signature.combined.alg.name + alg = jws["alg"] except KeyError: msg = "No alg value found in header" raise SuspiciousOperation(msg) @@ -142,21 +139,19 @@ def _verify_jws(self, payload, key): ) raise SuspiciousOperation(msg) - if isinstance(key, str): - # Use smart_bytes here since the key string comes from settings. - jwk = JWK.load(smart_bytes(key)) - else: - # The key is a json returned from the IDP JWKS endpoint. - jwk = JWK.from_json(key) - - if not jws.verify(jwk): + try: + # Maybe add a settings to enforce audiance validation + return jwt.decode(payload, key, algorithms=alg, options={"verify_aud": False}) + except jwt.DecodeError: msg = "JWS token verification failed." raise SuspiciousOperation(msg) - return jws.payload - def retrieve_matching_jwk(self, token): - """Get the signing key by exploring the JWKS endpoint of the OP.""" + """Get the signing key by exploring the JWKS endpoint of the OP. + + Don't use jwt.PyJWKClient()get_signing_key_from_jwt() because it doesn't check + the algorithm in case of multiple jwk with the same kid. + """ response_jwks = requests.get( self.OIDC_OP_JWKS_ENDPOINT, verify=self.get_settings("OIDC_VERIFY_SSL", True), @@ -167,32 +162,29 @@ def retrieve_matching_jwk(self, token): jwks = response_jwks.json() # Compute the current header from the given token to find a match - jws = JWS.from_compact(token) - json_header = jws.signature.protected - header = Header.json_loads(json_header) + jws = jwt.get_unverified_header(token) key = None for jwk in jwks["keys"]: if import_from_settings("OIDC_VERIFY_KID", True) and jwk[ "kid" - ] != smart_str(header.kid): + ] != smart_str(jws["kid"]): continue - if "alg" in jwk and jwk["alg"] != smart_str(header.alg): + if "alg" in jwk and jwk["alg"] != smart_str(jws["alg"]): continue key = jwk if key is None: raise SuspiciousOperation("Could not find a valid JWKS.") - return key + return jwt.PyJWK(key) def get_payload_data(self, token, key): """Helper method to get the payload of the JWT token.""" if self.get_settings("OIDC_ALLOW_UNSECURED_JWT", False): - header, payload_data, signature = token.split(b".") - header = json.loads(smart_str(b64decode(header))) + header = jwt.get_unverified_header(token) # If config allows unsecured JWTs check the header and return the decoded payload if "alg" in header and header["alg"] == "none": - return b64decode(payload_data) + return jwt.decode(token, options={"verify_signature": False}) # By default fallback to verify JWT signatures return self._verify_jws(token, key) @@ -201,7 +193,6 @@ def verify_token(self, token, **kwargs): """Validate the token signature.""" nonce = kwargs.get("nonce") - token = force_bytes(token) if self.OIDC_RP_SIGN_ALGO.startswith("RS") or self.OIDC_RP_SIGN_ALGO.startswith( "ES" ): @@ -212,16 +203,7 @@ def verify_token(self, token, **kwargs): else: key = self.OIDC_RP_CLIENT_SECRET - payload_data = self.get_payload_data(token, key) - - # The 'token' will always be a byte string since it's - # the result of base64.urlsafe_b64decode(). - # The payload is always the result of base64.urlsafe_b64decode(). - # In Python 3 and 2, that's always a byte string. - # In Python3.6, the json.loads() function can accept a byte string - # as it will automagically decode it to a unicode string before - # deserializing https://bugs.python.org/issue17909 - payload = json.loads(payload_data.decode("utf-8")) + payload = self.get_payload_data(token, key) token_nonce = payload.get("nonce") if self.get_settings("OIDC_USE_NONCE", True) and nonce != token_nonce: diff --git a/mozilla_django_oidc/utils.py b/mozilla_django_oidc/utils.py index a09e4ce1..0f571fa7 100644 --- a/mozilla_django_oidc/utils.py +++ b/mozilla_django_oidc/utils.py @@ -1,13 +1,13 @@ import logging import time import warnings +from base64 import urlsafe_b64decode, urlsafe_b64encode from hashlib import sha256 from urllib.request import parse_http_list, parse_keqv_list -# Make it obvious that these aren't the usual base64 functions -import josepy.b64 from django.conf import settings from django.core.exceptions import ImproperlyConfigured +from django.utils.encoding import force_bytes LOGGER = logging.getLogger(__name__) @@ -57,16 +57,12 @@ def is_authenticated(user): def base64_url_encode(bytes_like_obj): """Return a URL-Safe, base64 encoded version of bytes_like_obj - Implements base64urlencode as described in https://datatracker.ietf.org/doc/html/rfc7636#appendix-A + This function is not used by the OpenID client; it's just for testing PKCE related functions. """ - - s = josepy.b64.b64encode(bytes_like_obj).decode("ascii") # base64 encode - # the josepy base64 encoder (strips '='s padding) automatically - s = s.replace("+", "-") # 62nd char of encoding - s = s.replace("/", "_") # 63rd char of encoding - + s = urlsafe_b64encode(force_bytes(bytes_like_obj)).decode('utf-8') + s = s.rstrip("=") return s @@ -78,11 +74,14 @@ def base64_url_decode(string_like_obj): """ s = string_like_obj - s = s.replace("_", "/") # 63rd char of encoding - s = s.replace("-", "+") # 62nd char of encoding - b = josepy.b64.b64decode(s) # josepy base64 encoder (decodes without '='s padding) - - return b + size = len(s) % 4 + if size == 2: + s += '==' + elif size == 3: + s += '=' + elif size != 0: + raise ValueError('Invalid base64 string') + return urlsafe_b64decode(s.encode('utf-8')) def generate_code_challenge(code_verifier, method): diff --git a/setup.py b/setup.py index 781b11ea..071cb96a 100755 --- a/setup.py +++ b/setup.py @@ -33,7 +33,7 @@ install_requirements = [ "Django >= 3.2", - "josepy", + "pyjwt", "requests", "cryptography", ] diff --git a/tests/test_auth.py b/tests/test_auth.py index 5bdda157..345de645 100644 --- a/tests/test_auth.py +++ b/tests/test_auth.py @@ -1,18 +1,18 @@ import json from unittest.mock import Mock, call, patch +import jwt from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives import hashes, hmac, serialization -from cryptography.hazmat.primitives.asymmetric import ec +from cryptography.hazmat.primitives.asymmetric import ec, rsa from django.conf import settings from django.contrib.auth import get_user_model from django.core.exceptions import ImproperlyConfigured, SuspiciousOperation from django.test import RequestFactory, TestCase, override_settings from django.utils.encoding import force_bytes, smart_str -from josepy.b64 import b64encode -from josepy.jwa import ES256 from mozilla_django_oidc.auth import OIDCAuthenticationBackend, default_username_algo +from mozilla_django_oidc.utils import base64_url_encode User = get_user_model() @@ -69,13 +69,12 @@ def test_invalid_token(self, request_mock, token_mock): @override_settings(OIDC_ALLOW_UNSECURED_JWT=True) def test_allowed_unsecured_token(self): """Test payload data from unsecured token (allowed).""" - header = force_bytes(json.dumps({"alg": "none"})) - payload = force_bytes(json.dumps({"foo": "bar"})) + header = json.dumps({"alg": "none"}) + payload = {"foo": "bar"} + payload_data = json.dumps(payload) signature = "" - token = force_bytes( - "{}.{}.{}".format( - smart_str(b64encode(header)), smart_str(b64encode(payload)), signature - ) + token = "{}.{}.{}".format( + base64_url_encode(header), base64_url_encode(payload_data), signature ) extracted_payload = self.backend.get_payload_data(token, None) @@ -84,122 +83,105 @@ def test_allowed_unsecured_token(self): @override_settings(OIDC_ALLOW_UNSECURED_JWT=False) def test_disallowed_unsecured_token(self): """Test payload data from unsecured token (disallowed).""" - header = force_bytes(json.dumps({"alg": "none"})) - payload = force_bytes(json.dumps({"foo": "bar"})) + header = json.dumps({"alg": "none"}) + payload = {"foo": "bar"} + payload_data = json.dumps(payload) signature = "" - token = force_bytes( - "{}.{}.{}".format( - smart_str(b64encode(header)), smart_str(b64encode(payload)), signature - ) + token = "{}.{}.{}".format( + base64_url_encode(header), base64_url_encode(payload_data), signature ) - with self.assertRaises(KeyError): + with self.assertRaises(SuspiciousOperation): self.backend.get_payload_data(token, None) @override_settings(OIDC_ALLOW_UNSECURED_JWT=True) def test_allowed_unsecured_valid_token(self): """Test payload data from valid secured token (unsecured allowed).""" - header = force_bytes(json.dumps({"alg": "HS256", "typ": "JWT"})) - payload = force_bytes(json.dumps({"foo": "bar"})) + header = json.dumps({"alg": "HS256", "typ": "JWT"}) + payload = {"foo": "bar"} + payload_data = json.dumps(payload) # Compute signature key = b"mysupersecuretestkey" h = hmac.HMAC(key, hashes.SHA256(), backend=default_backend()) - msg = "{}.{}".format( - smart_str(b64encode(header)), smart_str(b64encode(payload)) - ) + msg = "{}.{}".format(base64_url_encode(header), base64_url_encode(payload_data)) h.update(force_bytes(msg)) - signature = b64encode(h.finalize()) + signature = base64_url_encode(h.finalize()) token = "{}.{}.{}".format( - smart_str(b64encode(header)), - smart_str(b64encode(payload)), - smart_str(signature), + base64_url_encode(header), base64_url_encode(payload_data), signature ) - token_bytes = force_bytes(token) key_text = smart_str(key) - output = self.backend.get_payload_data(token_bytes, key_text) + output = self.backend.get_payload_data(token, key_text) self.assertEqual(output, payload) @override_settings(OIDC_ALLOW_UNSECURED_JWT=False) def test_disallowed_unsecured_valid_token(self): """Test payload data from valid secure token (unsecured disallowed).""" - header = force_bytes(json.dumps({"alg": "HS256", "typ": "JWT"})) - payload = force_bytes(json.dumps({"foo": "bar"})) + header = json.dumps({"alg": "HS256", "typ": "JWT"}) + payload = {"foo": "bar"} + payload_data = json.dumps(payload) # Compute signature key = b"mysupersecuretestkey" h = hmac.HMAC(key, hashes.SHA256(), backend=default_backend()) - msg = "{}.{}".format( - smart_str(b64encode(header)), smart_str(b64encode(payload)) - ) + msg = "{}.{}".format(base64_url_encode(header), base64_url_encode(payload_data)) h.update(force_bytes(msg)) - signature = b64encode(h.finalize()) + signature = base64_url_encode(h.finalize()) token = "{}.{}.{}".format( - smart_str(b64encode(header)), - smart_str(b64encode(payload)), - smart_str(signature), + base64_url_encode(header), base64_url_encode(payload_data), signature ) - token_bytes = force_bytes(token) key_text = smart_str(key) - output = self.backend.get_payload_data(token_bytes, key_text) + output = self.backend.get_payload_data(token, key_text) self.assertEqual(output, payload) @override_settings(OIDC_ALLOW_UNSECURED_JWT=True) def test_allowed_unsecured_invalid_token(self): """Test payload data from invalid secure token (unsecured allowed).""" - header = force_bytes(json.dumps({"alg": "HS256", "typ": "JWT"})) - payload = force_bytes(json.dumps({"foo": "bar"})) + header = json.dumps({"alg": "HS256", "typ": "JWT"}) + payload = {"foo": "bar"} + payload_data = json.dumps(payload) # Compute signature key = b"mysupersecuretestkey" fake_key = b"mysupersecurefaketestkey" h = hmac.HMAC(key, hashes.SHA256(), backend=default_backend()) - msg = "{}.{}".format( - smart_str(b64encode(header)), smart_str(b64encode(payload)) - ) + msg = "{}.{}".format(base64_url_encode(header), base64_url_encode(payload_data)) h.update(force_bytes(msg)) - signature = b64encode(h.finalize()) + signature = base64_url_encode(h.finalize()) token = "{}.{}.{}".format( - smart_str(b64encode(header)), - smart_str(b64encode(payload)), - smart_str(signature), + base64_url_encode(header), base64_url_encode(payload_data), signature ) - token_bytes = force_bytes(token) key_text = smart_str(fake_key) with self.assertRaises(SuspiciousOperation) as ctx: - self.backend.get_payload_data(token_bytes, key_text) + self.backend.get_payload_data(token, key_text) self.assertEqual(ctx.exception.args[0], "JWS token verification failed.") @override_settings(OIDC_ALLOW_UNSECURED_JWT=False) def test_disallowed_unsecured_invalid_token(self): """Test payload data from invalid secure token (unsecured disallowed).""" - header = force_bytes(json.dumps({"alg": "HS256", "typ": "JWT"})) - payload = force_bytes(json.dumps({"foo": "bar"})) + header = json.dumps({"alg": "HS256", "typ": "JWT"}) + payload = {"foo": "bar"} + payload_data = json.dumps(payload) # Compute signature key = b"mysupersecuretestkey" fake_key = b"mysupersecurefaketestkey" h = hmac.HMAC(key, hashes.SHA256(), backend=default_backend()) - msg = "{}.{}".format( - smart_str(b64encode(header)), smart_str(b64encode(payload)) - ) + msg = "{}.{}".format(base64_url_encode(header), base64_url_encode(payload_data)) h.update(force_bytes(msg)) - signature = b64encode(h.finalize()) + signature = base64_url_encode(h.finalize()) token = "{}.{}.{}".format( - smart_str(b64encode(header)), - smart_str(b64encode(payload)), - smart_str(signature), + base64_url_encode(header), base64_url_encode(payload_data), signature ) - token_bytes = force_bytes(token) key_text = smart_str(fake_key) with self.assertRaises(SuspiciousOperation) as ctx: - self.backend.get_payload_data(token_bytes, key_text) + self.backend.get_payload_data(token, key_text) self.assertEqual(ctx.exception.args[0], "JWS token verification failed.") def test_get_user(self): @@ -566,7 +548,7 @@ def test_jwt_decode_params(self, request_mock, jws_mock): auth_request = RequestFactory().get("/foo", {"code": "foo", "state": "bar"}) auth_request.session = {} - jws_mock.return_value = json.dumps({"aud": "audience"}).encode("utf-8") + jws_mock.return_value = {"aud": "audience"} get_json_mock = Mock() get_json_mock.json.return_value = { "nickname": "username", @@ -580,7 +562,7 @@ def test_jwt_decode_params(self, request_mock, jws_mock): } request_mock.post.return_value = post_json_mock self.backend.authenticate(request=auth_request) - calls = [call(force_bytes("token"), "client_secret")] + calls = [call("token", "client_secret")] jws_mock.assert_has_calls(calls) @override_settings(OIDC_VERIFY_JWT=False) @@ -592,7 +574,7 @@ def test_jwt_decode_params_verify_false(self, request_mock, jws_mock): auth_request = RequestFactory().get("/foo", {"code": "foo", "state": "bar"}) auth_request.session = {} - jws_mock.return_value = json.dumps({"aud": "audience"}).encode("utf-8") + jws_mock.return_value = {"aud": "audience"} get_json_mock = Mock() get_json_mock.json.return_value = { "nickname": "username", @@ -605,7 +587,7 @@ def test_jwt_decode_params_verify_false(self, request_mock, jws_mock): "access_token": "access_token", } request_mock.post.return_value = post_json_mock - calls = [call(force_bytes("token"), "client_secret")] + calls = [call("token", "client_secret")] self.backend.authenticate(request=auth_request) jws_mock.assert_has_calls(calls) @@ -614,9 +596,7 @@ def test_jwt_decode_params_verify_false(self, request_mock, jws_mock): def test_jwt_failed_nonce(self, jws_mock): """Test Nonce verification.""" - jws_mock.return_value = json.dumps({"nonce": "foobar", "aud": "aud"}).encode( - "utf-8" - ) + jws_mock.return_value = {"nonce": "foobar", "aud": "aud"} id_token = "my_token" with self.assertRaisesMessage( SuspiciousOperation, "JWT Nonce verification failed." @@ -632,7 +612,7 @@ def test_create_user_disabled(self, request_mock, jws_mock): auth_request = RequestFactory().get("/foo", {"code": "foo", "state": "bar"}) auth_request.session = {} - jws_mock.return_value = json.dumps({"nonce": "nonce"}).encode("utf-8") + jws_mock.return_value = {"nonce": "nonce"} get_json_mock = Mock() get_json_mock.json.return_value = { "nickname": "a_username", @@ -656,7 +636,7 @@ def test_create_user_enabled(self, request_mock, jws_mock): auth_request.session = {} self.assertEqual(User.objects.filter(email="email@example.com").exists(), False) - jws_mock.return_value = json.dumps({"nonce": "nonce"}).encode("utf-8") + jws_mock.return_value = {"nonce": "nonce"} get_json_mock = Mock() get_json_mock.json.return_value = { "nickname": "a_username", @@ -685,7 +665,7 @@ def test_custom_username_algo(self, request_mock, jws_mock, algo_mock): self.assertEqual(User.objects.filter(email="email@example.com").exists(), False) algo_mock.return_value = "username_algo" - jws_mock.return_value = json.dumps({"nonce": "nonce"}).encode("utf-8") + jws_mock.return_value = {"nonce": "nonce"} get_json_mock = Mock() get_json_mock.json.return_value = { "nickname": "a_username", @@ -715,7 +695,7 @@ def test_custom_username_algo_dotted_path(self, request_mock, jws_mock): auth_request.session = {} self.assertEqual(User.objects.filter(email="email@example.com").exists(), False) - jws_mock.return_value = json.dumps({"nonce": "nonce"}).encode("utf-8") + jws_mock.return_value = {"nonce": "nonce"} get_json_mock = Mock() get_json_mock.json.return_value = { "nickname": "a_username", @@ -745,7 +725,7 @@ def test_dotted_username_algo_callback_with_claims(self, request_mock, jws_mock) auth_request.session = {} self.assertEqual(User.objects.filter(email="email@example.com").exists(), False) - jws_mock.return_value = json.dumps({"nonce": "nonce"}).encode("utf-8") + jws_mock.return_value = {"nonce": "nonce"} domain = "django.con" get_json_mock = Mock() get_json_mock.json.return_value = { @@ -775,7 +755,7 @@ def test_duplicate_emails_exact(self, request_mock, jws_mock): User.objects.create(username="user1", email="email@example.com") User.objects.create(username="user2", email="email@example.com") - jws_mock.return_value = json.dumps({"nonce": "nonce"}).encode("utf-8") + jws_mock.return_value = {"nonce": "nonce"} get_json_mock = Mock() get_json_mock.json.return_value = { "nickname": "a_username", @@ -800,7 +780,7 @@ def test_duplicate_emails_case_mismatch(self, request_mock, jws_mock): User.objects.create(username="user1", email="email@example.com") User.objects.create(username="user2", email="eMaIl@ExAmPlE.cOm") - jws_mock.return_value = json.dumps({"nonce": "nonce"}).encode("utf-8") + jws_mock.return_value = {"nonce": "nonce"} get_json_mock = Mock() get_json_mock.json.return_value = { "nickname": "a_username", @@ -834,7 +814,7 @@ def update_user(user, claims): update_user_mock.side_effect = update_user - jws_mock.return_value = json.dumps({"nonce": "nonce"}).encode("utf-8") + jws_mock.return_value = {"nonce": "nonce"} get_json_mock = Mock() get_json_mock.json.return_value = { "nickname": "a_username", @@ -860,19 +840,34 @@ class OIDCAuthenticationBackendRS256WithKeyTestCase(TestCase): @override_settings(OIDC_RP_CLIENT_ID="example_id") @override_settings(OIDC_RP_CLIENT_SECRET="client_secret") @override_settings(OIDC_RP_SIGN_ALGO="RS256") - @override_settings(OIDC_RP_IDP_SIGN_KEY="sign_key") - def setUp(self): - self.backend = OIDCAuthenticationBackend() - @override_settings(OIDC_USE_NONCE=False) - @patch("mozilla_django_oidc.auth.OIDCAuthenticationBackend._verify_jws") @patch("mozilla_django_oidc.auth.requests") - def test_jwt_verify_sign_key(self, request_mock, jws_mock): + def test_jwt_verify_sign_key(self, request_mock): """Test jwt verification signature.""" + + # Generate a private key to create a test token with + key = rsa.generate_private_key(public_exponent=65537, key_size=2048) + + # Make the public key available through the JWKS response + public_key = smart_str(key.public_key().public_bytes( + serialization.Encoding.PEM, + serialization.PublicFormat.PKCS1, + )) + + with override_settings(OIDC_RP_IDP_SIGN_KEY=public_key): + backend = OIDCAuthenticationBackend() + + # Generate id_token + header = { + "typ": "JWT", + "alg": "RS256", + } + data = {"name": "John Doe", "test": "test_jwt_verify_sign_key"} + id_token = jwt.encode(payload=data, key=key, algorithm="RS256", headers=header) + auth_request = RequestFactory().get("/foo", {"code": "foo", "state": "bar"}) auth_request.session = {} - jws_mock.return_value = json.dumps({"aud": "audience"}).encode("utf-8") get_json_mock = Mock() get_json_mock.json.return_value = { "nickname": "username", @@ -881,13 +876,11 @@ def test_jwt_verify_sign_key(self, request_mock, jws_mock): request_mock.get.return_value = get_json_mock post_json_mock = Mock(status_code=200) post_json_mock.json.return_value = { - "id_token": "token", + "id_token": id_token, "access_token": "access_token", } request_mock.post.return_value = post_json_mock - self.backend.authenticate(request=auth_request) - calls = [call(force_bytes("token"), "sign_key")] - jws_mock.assert_has_calls(calls) + self.assertIsNotNone(backend.authenticate(request=auth_request)) class OIDCAuthenticationBackendRS256WithJwksEndpointTestCase(TestCase): @@ -921,7 +914,7 @@ def test_jwt_verify_sign_key_calls(self, request_mock, jwk_mock, jws_mock): } jwk_mock.return_value = jwk_mock_ret - jws_mock.return_value = json.dumps({"aud": "audience"}).encode("utf-8") + jws_mock.return_value = {"aud": "audience"} get_json_mock = Mock() get_json_mock.json.return_value = { "nickname": "username", @@ -935,7 +928,7 @@ def test_jwt_verify_sign_key_calls(self, request_mock, jwk_mock, jws_mock): } request_mock.post.return_value = post_json_mock self.backend.authenticate(request=auth_request) - calls = [call(force_bytes("token"), jwk_mock_ret)] + calls = [call("token", jwk_mock_ret)] jws_mock.assert_has_calls(calls) @patch("mozilla_django_oidc.auth.requests") @@ -947,38 +940,36 @@ def test_retrieve_matching_jwk(self, mock_requests): "keys": [ { "alg": "RS256", + "e": "AQAB", "kid": "foobar", + "kty": "RSA", + "n": "radom_value", }, { "alg": "RS512", + "e": "AQAB", "kid": "foobar512", + "kty": "RSA", + "n": "radom_value", }, ] } mock_requests.get.return_value = get_json_mock - header = force_bytes( - json.dumps({"alg": "RS256", "typ": "JWT", "kid": "foobar"}) - ) - payload = force_bytes(json.dumps({"foo": "bar"})) + header = json.dumps({"alg": "RS256", "typ": "JWT", "kid": "foobar"}) + payload = json.dumps({"foo": "bar"}) # Compute signature key = b"mysupersecuretestkey" h = hmac.HMAC(key, hashes.SHA256(), backend=default_backend()) - msg = "{}.{}".format( - smart_str(b64encode(header)), smart_str(b64encode(payload)) - ) + msg = "{}.{}".format(base64_url_encode(header), base64_url_encode(payload)) h.update(force_bytes(msg)) - signature = b64encode(h.finalize()) + signature = base64_url_encode(h.finalize()) - token = "{}.{}.{}".format( - smart_str(b64encode(header)), - smart_str(b64encode(payload)), - smart_str(signature), - ) + token = "{}.{}.{}".format(base64_url_encode(header), base64_url_encode(payload), signature) - jwk_key = self.backend.retrieve_matching_jwk(force_bytes(token)) - self.assertEqual(jwk_key, get_json_mock.json.return_value["keys"][0]) + jwk_key = self.backend.retrieve_matching_jwk(token) + self.assertEqual(jwk_key._jwk_data, get_json_mock.json.return_value["keys"][0]) @patch("mozilla_django_oidc.auth.requests") def test_retrieve_matching_jwk_same_kid(self, mock_requests): @@ -989,42 +980,43 @@ def test_retrieve_matching_jwk_same_kid(self, mock_requests): "keys": [ { "alg": "RS512", + "e": "AQAB", "kid": "foobar", + "kty": "RSA", + "n": "radom_value", }, { "alg": "RS384", + "e": "AQAB", "kid": "foobar", + "kty": "RSA", + "n": "radom_value", }, { "alg": "RS256", + "e": "AQAB", "kid": "foobar", + "kty": "RSA", + "n": "radom_value", }, ] } mock_requests.get.return_value = get_json_mock - header = force_bytes( - json.dumps({"alg": "RS256", "typ": "JWT", "kid": "foobar"}) - ) - payload = force_bytes(json.dumps({"foo": "bar"})) + header = json.dumps({"alg": "RS256", "typ": "JWT", "kid": "foobar"}) + payload = json.dumps({"foo": "bar"}) # Compute signature key = b"mysupersecuretestkey" h = hmac.HMAC(key, hashes.SHA256(), backend=default_backend()) - msg = "{}.{}".format( - smart_str(b64encode(header)), smart_str(b64encode(payload)) - ) + msg = "{}.{}".format(base64_url_encode(header), base64_url_encode(payload)) h.update(force_bytes(msg)) - signature = b64encode(h.finalize()) + signature = base64_url_encode(h.finalize()) - token = "{}.{}.{}".format( - smart_str(b64encode(header)), - smart_str(b64encode(payload)), - smart_str(signature), - ) + token = "{}.{}.{}".format(base64_url_encode(header), base64_url_encode(payload), signature) - jwk_key = self.backend.retrieve_matching_jwk(force_bytes(token)) - self.assertEqual(jwk_key, get_json_mock.json.return_value["keys"][2]) + jwk_key = self.backend.retrieve_matching_jwk(token) + self.assertEqual(jwk_key._jwk_data, get_json_mock.json.return_value["keys"][2]) @patch("mozilla_django_oidc.auth.requests") def test_retrieve_mismatcing_jwk_alg(self, mock_requests): @@ -1034,33 +1026,30 @@ def test_retrieve_mismatcing_jwk_alg(self, mock_requests): get_json_mock.json.return_value = { "keys": [ { - "alg": "foo", - "kid": "bar", - } + "alg": "RS256", + "e": "AQAB", + "kid": "foobar", + "kty": "RSA", + "n": "radom_value", + }, ] } mock_requests.get.return_value = get_json_mock - header = force_bytes(json.dumps({"alg": "HS256", "typ": "JWT", "kid": "bar"})) - payload = force_bytes(json.dumps({"foo": "bar"})) + header = json.dumps({"alg": "HS256", "typ": "JWT", "kid": "foobar"}) + payload = json.dumps({"foo": "bar"}) # Compute signature key = b"mysupersecuretestkey" h = hmac.HMAC(key, hashes.SHA256(), backend=default_backend()) - msg = "{}.{}".format( - smart_str(b64encode(header)), smart_str(b64encode(payload)) - ) + msg = "{}.{}".format(base64_url_encode(header), base64_url_encode(payload)) h.update(force_bytes(msg)) - signature = b64encode(h.finalize()) + signature = base64_url_encode(h.finalize()) - token = "{}.{}.{}".format( - smart_str(b64encode(header)), - smart_str(b64encode(payload)), - smart_str(signature), - ) + token = "{}.{}.{}".format(base64_url_encode(header), base64_url_encode(payload), signature) with self.assertRaises(SuspiciousOperation) as ctx: - self.backend.retrieve_matching_jwk(force_bytes(token)) + self.backend.retrieve_matching_jwk(token) self.assertEqual(ctx.exception.args[0], "Could not find a valid JWKS.") @@ -1072,33 +1061,30 @@ def test_retrieve_mismatcing_jwk_kid(self, mock_requests): get_json_mock.json.return_value = { "keys": [ { - "alg": "HS256", + "alg": "RS256", + "e": "AQAB", "kid": "foobar", - } + "kty": "RSA", + "n": "radom_value", + }, ] } mock_requests.get.return_value = get_json_mock - header = force_bytes(json.dumps({"alg": "HS256", "typ": "JWT", "kid": "bar"})) - payload = force_bytes(json.dumps({"foo": "bar"})) + header = json.dumps({"alg": "HS256", "typ": "JWT", "kid": "bar"}) + payload = json.dumps({"foo": "bar"}) # Compute signature key = b"mysupersecuretestkey" h = hmac.HMAC(key, hashes.SHA256(), backend=default_backend()) - msg = "{}.{}".format( - smart_str(b64encode(header)), smart_str(b64encode(payload)) - ) + msg = "{}.{}".format(base64_url_encode(header), base64_url_encode(payload)) h.update(force_bytes(msg)) - signature = b64encode(h.finalize()) + signature = base64_url_encode(h.finalize()) - token = "{}.{}.{}".format( - smart_str(b64encode(header)), - smart_str(b64encode(payload)), - smart_str(signature), - ) + token = "{}.{}.{}".format(base64_url_encode(header), base64_url_encode(payload), signature) with self.assertRaises(SuspiciousOperation) as ctx: - self.backend.retrieve_matching_jwk(force_bytes(token)) + self.backend.retrieve_matching_jwk(token) self.assertEqual(ctx.exception.args[0], "Could not find a valid JWKS.") @@ -1110,63 +1096,62 @@ def test_retrieve_jwk_optional_alg(self, mock_requests): get_json_mock.json.return_value = { "keys": [ { + "e": "AQAB", "kid": "kid", + "kty": "RSA", + "n": "radom_value", } ] } mock_requests.get.return_value = get_json_mock - header = force_bytes(json.dumps({"alg": "HS256", "typ": "JWT", "kid": "kid"})) - payload = force_bytes(json.dumps({"foo": "bar"})) + header = json.dumps({"alg": "HS256", "typ": "JWT", "kid": "kid"}) + payload = json.dumps({"foo": "bar"}) # Compute signature key = b"mysupersecuretestkey" h = hmac.HMAC(key, hashes.SHA256(), backend=default_backend()) - msg = "{}.{}".format( - smart_str(b64encode(header)), smart_str(b64encode(payload)) - ) + msg = "{}.{}".format(base64_url_encode(header), base64_url_encode(payload)) h.update(force_bytes(msg)) - signature = b64encode(h.finalize()) + signature = base64_url_encode(h.finalize()) - token = "{}.{}.{}".format( - smart_str(b64encode(header)), - smart_str(b64encode(payload)), - smart_str(signature), - ) + token = "{}.{}.{}".format(base64_url_encode(header), base64_url_encode(payload), signature) - jwk_key = self.backend.retrieve_matching_jwk(force_bytes(token)) - self.assertEqual(jwk_key, get_json_mock.json.return_value["keys"][0]) + jwk_key = self.backend.retrieve_matching_jwk(token) + self.assertEqual(jwk_key._jwk_data, get_json_mock.json.return_value["keys"][0]) @patch("mozilla_django_oidc.auth.requests") def test_retrieve_not_existing_jwk(self, mock_requests): """Test retrieving jwk that doesn't exist.""" get_json_mock = Mock() - get_json_mock.json.return_value = {"keys": [{"alg": "RS256", "kid": "kid"}]} + get_json_mock.json.return_value = { + "keys": [ + { + "alg": "RS256", + "e": "AQAB", + "kid": "kid", + "kty": "RSA", + "n": "radom_value", + }, + ] + } mock_requests.get.return_value = get_json_mock - header = force_bytes( - json.dumps({"alg": "RS256", "typ": "JWT", "kid": "differentkid"}) - ) - payload = force_bytes(json.dumps({"foo": "bar"})) + header = json.dumps({"alg": "RS256", "typ": "JWT", "kid": "differentkid"}) + payload = json.dumps({"foo": "bar"}) # Compute signature key = b"mysupersecuretestkey" h = hmac.HMAC(key, hashes.SHA256(), backend=default_backend()) - msg = "{}.{}".format( - smart_str(b64encode(header)), smart_str(b64encode(payload)) - ) + msg = "{}.{}".format(base64_url_encode(header), base64_url_encode(payload)) h.update(force_bytes(msg)) - signature = b64encode(h.finalize()) + signature = base64_url_encode(h.finalize()) - token = "{}.{}.{}".format( - smart_str(b64encode(header)), - smart_str(b64encode(payload)), - smart_str(signature), - ) + token = "{}.{}.{}".format(base64_url_encode(header), base64_url_encode(payload), signature) with self.assertRaises(SuspiciousOperation) as ctx: - self.backend.retrieve_matching_jwk(force_bytes(token)) + self.backend.retrieve_matching_jwk(token) self.assertEqual(ctx.exception.args[0], "Could not find a valid JWKS.") @@ -1235,11 +1220,6 @@ def test_es256_alg_verification(self, mock_requests): # Generate a private key to create a test token with private_key = ec.generate_private_key(ec.SECP256R1, default_backend()) - private_key_pem = private_key.private_bytes( - serialization.Encoding.PEM, - serialization.PrivateFormat.PKCS8, - serialization.NoEncryption(), - ) # Make the public key available through the JWKS response public_numbers = private_key.public_key().public_numbers() @@ -1251,37 +1231,21 @@ def test_es256_alg_verification(self, mock_requests): "kty": "EC", "alg": "ES256", "use": "sig", - "x": smart_str(b64encode(public_numbers.x.to_bytes(32, "big"))), - "y": smart_str(b64encode(public_numbers.y.to_bytes(32, "big"))), + "x": base64_url_encode(public_numbers.x.to_bytes(32, "big")), + "y": base64_url_encode(public_numbers.y.to_bytes(32, "big")), "crv": "P-256", } ] } mock_requests.get.return_value = get_json_mock - header = force_bytes( - json.dumps( - { - "typ": "JWT", - "alg": "ES256", - "kid": "eckid", - }, - ) - ) + header = { + "typ": "JWT", + "alg": "ES256", + "kid": "eckid", + } data = {"name": "John Doe", "test": "test_es256_alg_verification"} - - h = hmac.HMAC(private_key_pem, hashes.SHA256(), backend=default_backend()) - msg = "{}.{}".format( - smart_str(b64encode(header)), - smart_str(b64encode(force_bytes(json.dumps(data)))), - ) - h.update(force_bytes(msg)) - - signature = b64encode(ES256.sign(private_key, force_bytes(msg))) - token = "{}.{}".format( - msg, - smart_str(signature), - ) + token = jwt.encode(payload=data, key=private_key, algorithm="ES256", headers=header) # Verify the token created with the private key by using the JWKS endpoint, # where the public numbers are.