From 826a6804ac1ce0f1f2990bfedef7fd59a165d573 Mon Sep 17 00:00:00 2001 From: M Umar Khan Date: Wed, 5 Apr 2023 23:51:03 +0500 Subject: [PATCH] chore: add pyjwt requirement --- lti_consumer/lti_1p3/key_handlers.py | 146 +++++++----------- lti_consumer/lti_1p3/tests/test_consumer.py | 46 ++++-- .../lti_1p3/tests/test_key_handlers.py | 70 ++++----- lti_consumer/lti_1p3/tests/utils.py | 8 +- lti_consumer/plugin/views.py | 11 +- .../tests/unit/plugin/test_proctoring.py | 4 +- requirements/base.in | 1 + requirements/base.txt | 2 + requirements/ci.txt | 2 + requirements/dev.txt | 2 + requirements/quality.txt | 2 + requirements/test.txt | 2 + 12 files changed, 144 insertions(+), 152 deletions(-) diff --git a/lti_consumer/lti_1p3/key_handlers.py b/lti_consumer/lti_1p3/key_handlers.py index 64048c5d..76f94779 100644 --- a/lti_consumer/lti_1p3/key_handlers.py +++ b/lti_consumer/lti_1p3/key_handlers.py @@ -4,16 +4,14 @@ This handles validating messages sent by the tool and generating access token with LTI scopes. """ -import codecs import copy -import time import json +import math +import time +import sys +import jwt from Cryptodome.PublicKey import RSA -from jwkest import BadSignature, BadSyntax, WrongNumberOfParts, jwk -from jwkest.jwk import RSAKey, load_jwks_from_url -from jwkest.jws import JWS, NoSuitableSigningKeys -from jwkest.jwt import JWT from . import exceptions @@ -47,14 +45,9 @@ def __init__(self, public_key=None, keyset_url=None): # Import from public key if public_key: try: - new_key = RSAKey(use='sig') - - # Unescape key before importing it - raw_key = codecs.decode(public_key, 'unicode_escape') - # Import Key and save to internal state - new_key.load_key(RSA.import_key(raw_key)) - self.public_key = new_key + algo_obj = jwt.get_algorithm_by_name('RS256') + self.public_key = algo_obj.prepare_key(public_key) except ValueError as err: raise exceptions.InvalidRsaKey() from err @@ -69,7 +62,7 @@ def _get_keyset(self, kid=None): if self.keyset_url: try: - keys = load_jwks_from_url(self.keyset_url) + keys = jwt.PyJWKClient(self.keyset_url).get_jwk_set() except Exception as err: # Broad Exception is required here because jwkest raises # an Exception object explicitly. @@ -78,13 +71,13 @@ def _get_keyset(self, kid=None): raise exceptions.NoSuitableKeys() from err keyset.extend(keys) - if self.public_key and kid: - # Fill in key id of stored key. - # This is needed because if the JWS is signed with a - # key with a kid, pyjwkest doesn't match them with - # keys without kid (kid=None) and fails verification - self.public_key.kid = kid - + if self.public_key: + if kid: + # Fill in key id of stored key. + # This is needed because if the JWS is signed with a + # key with a kid, pyjwkest doesn't match them with + # keys without kid (kid=None) and fails verification + self.public_key.kid = kid # Add to keyset keyset.append(self.public_key) @@ -100,32 +93,24 @@ def validate_and_decode(self, token): iss, sub, exp, aud and jti claims. """ try: - # Get KID from JWT header - jwt = JWT().unpack(token) - - # Verify message signature - message = JWS().verify_compact( - token, - keys=self._get_keyset( - jwt.headers.get('kid') - ) - ) - - # If message is valid, check expiration from JWT - if 'exp' in message and message['exp'] < time.time(): - raise exceptions.TokenSignatureExpired() - - # TODO: Validate other JWT claims - - # Else returns decoded message - return message - - except NoSuitableSigningKeys as err: - raise exceptions.NoSuitableKeys() from err - except (BadSyntax, WrongNumberOfParts) as err: - raise exceptions.MalformedJwtToken() from err - except BadSignature as err: - raise exceptions.BadJwtSignature() from err + key_set = self._get_keyset() + if not key_set: + raise exceptions.NoSuitableKeys() + for i in range(len(key_set)): + try: + message = jwt.decode( + token, + key=key_set[i], + algorithms=['RS256', 'RS512',], + options={'verify_signature': True} + ) + return message + except Exception: + if i == len(key_set) - 1: + raise + except Exception as token_error: + exc_info = sys.exc_info() + raise jwt.InvalidTokenError(exc_info[2]) from token_error class PlatformKeyHandler: @@ -144,14 +129,8 @@ def __init__(self, key_pem, kid=None): if key_pem: # Import JWK from RSA key try: - self.key = RSAKey( - # Using the same key ID as client id - # This way we can easily serve multiple public - # keys on teh same endpoint and keep all - # LTI 1.3 blocks working - kid=kid, - key=RSA.import_key(key_pem) - ) + algo = jwt.get_algorithm_by_name('RS256') + self.key = algo.prepare_key(key_pem) except ValueError as err: raise exceptions.InvalidRsaKey() from err @@ -167,28 +146,26 @@ def encode_and_sign(self, message, expiration=None): # Set iat and exp if expiration is set if expiration: _message.update({ - "iat": int(round(time.time())), - "exp": int(round(time.time()) + expiration), + "iat": int(math.floor(time.time())), + "exp": int(math.floor(time.time()) + expiration), }) # The class instance that sets up the signing operation # An RS 256 key is required for LTI 1.3 - _jws = JWS(_message, alg="RS256", cty="JWT") - - # Encode and sign LTI message - return _jws.sign_compact([self.key]) + return jwt.encode(_message, self.key, algorithm="RS256") def get_public_jwk(self): """ Export Public JWK """ - public_keys = jwk.KEYS() + jwk = {"keys": []} # Only append to keyset if a key exists if self.key: - public_keys.append(self.key) - - return json.loads(public_keys.dump_jwks()) + algo_obj = jwt.get_algorithm_by_name('RS256') + public_key = algo_obj.prepare_key(self.key).public_key() + jwk['keys'].append(json.loads(algo_obj.to_jwk(public_key))) + return jwk def validate_and_decode(self, token, iss=None, aud=None): """ @@ -197,29 +174,22 @@ def validate_and_decode(self, token, iss=None, aud=None): Validates a token sent by the tool using the platform's RSA Key. Optionally validate iss and aud claims if provided. """ + if not self.key: + raise exceptions.RsaKeyNotSet() try: - # Verify message signature - message = JWS().verify_compact(token, keys=[self.key]) - - # If message is valid, check expiration from JWT - if 'exp' in message and message['exp'] < time.time(): - raise exceptions.TokenSignatureExpired() - - # Validate issuer claim (if present) - if iss: - if 'iss' not in message or message['iss'] != iss: - raise exceptions.InvalidClaimValue('The required iss claim is either missing or does ' - 'not match the expected iss value.') - - # Validate audience claim (if present) - if aud: - if 'aud' not in message or aud not in message['aud']: - raise exceptions.InvalidClaimValue('The required aud claim is missing.') - - # Else return token contents + message = jwt.decode( + token, + key=self.key.public_key(), + audience=aud, + issuer=iss, + algorithms=['RS256', 'RS512'], + options={ + 'verify_signature': True, + 'verify_aud': True if aud else False + } + ) return message - except NoSuitableSigningKeys as err: - raise exceptions.NoSuitableKeys() from err - except BadSyntax as err: - raise exceptions.MalformedJwtToken() from err + except Exception as token_error: + exc_info = sys.exc_info() + raise jwt.InvalidTokenError(exc_info[2]) from token_error diff --git a/lti_consumer/lti_1p3/tests/test_consumer.py b/lti_consumer/lti_1p3/tests/test_consumer.py index b702a883..7a341e00 100644 --- a/lti_consumer/lti_1p3/tests/test_consumer.py +++ b/lti_consumer/lti_1p3/tests/test_consumer.py @@ -2,16 +2,16 @@ Unit tests for LTI 1.3 consumer implementation """ -import json from unittest.mock import patch from urllib.parse import parse_qs, urlparse import ddt +import jwt +import sys from Cryptodome.PublicKey import RSA from django.test.testcases import TestCase from edx_django_utils.cache import get_cache_key, TieredCache -from jwkest.jwk import load_jwks -from jwkest.jws import JWS +from jwt.api_jwk import PyJWKSet from lti_consumer.data import Lti1p3LaunchData from lti_consumer.lti_1p3 import exceptions @@ -34,7 +34,9 @@ STATE = "ABCD" # Consider storing a fixed key RSA_KEY_ID = "1" -RSA_KEY = RSA.generate(2048).export_key('PEM') +RSA_KEY = RSA.generate(2048) +RSA_PRIVATE_KEY = RSA_KEY.export_key('PEM') +RSA_PUBLIC_KEY = RSA_KEY.public_key().export_key('PEM') # Test classes @@ -53,11 +55,11 @@ def setUp(self): lti_launch_url=LAUNCH_URL, client_id=CLIENT_ID, deployment_id=DEPLOYMENT_ID, - rsa_key=RSA_KEY, + rsa_key=RSA_PRIVATE_KEY, rsa_key_id=RSA_KEY_ID, redirect_uris=REDIRECT_URIS, # Use the same key for testing purposes - tool_key=RSA_KEY + tool_key=RSA_PUBLIC_KEY ) def _setup_lti_launch_data(self): @@ -102,9 +104,25 @@ def _decode_token(self, token): This also tests the public keyset function. """ public_keyset = self.lti_consumer.get_public_keyset() - key_set = load_jwks(json.dumps(public_keyset)) - - return JWS().verify_compact(token, keys=key_set) + keyset = PyJWKSet.from_dict(public_keyset).keys + + for i in range(len(keyset)): + try: + message = jwt.decode( + token, + key=keyset[i].key, + algorithms=['RS256', 'RS512'], + options={ + 'verify_signature': True, + 'verify_aud': False + } + ) + return message + except Exception as token_error: + if i < len(keyset) - 1: + continue + exc_info = sys.exc_info() + raise jwt.InvalidTokenError(exc_info[2]) from token_error @ddt.data( ({"client_id": CLIENT_ID, "redirect_uri": LAUNCH_URL, "nonce": STATE, "state": STATE}, True), @@ -526,7 +544,7 @@ def test_access_token_invalid_jwt(self): "scope": "", } - with self.assertRaises(exceptions.MalformedJwtToken): + with self.assertRaises(jwt.exceptions.InvalidTokenError): self.lti_consumer.access_token(request_data) def test_access_token(self): @@ -641,11 +659,11 @@ def setUp(self): lti_launch_url=LAUNCH_URL, client_id=CLIENT_ID, deployment_id=DEPLOYMENT_ID, - rsa_key=RSA_KEY, + rsa_key=RSA_PRIVATE_KEY, rsa_key_id=RSA_KEY_ID, redirect_uris=REDIRECT_URIS, # Use the same key for testing purposes - tool_key=RSA_KEY + tool_key=RSA_PUBLIC_KEY ) self.preflight_response = {} @@ -884,11 +902,11 @@ def setUp(self): lti_launch_url=LAUNCH_URL, client_id=CLIENT_ID, deployment_id=DEPLOYMENT_ID, - rsa_key=RSA_KEY, + rsa_key=RSA_PRIVATE_KEY, rsa_key_id=RSA_KEY_ID, redirect_uris=REDIRECT_URIS, # Use the same key for testing purposes - tool_key=RSA_KEY + tool_key=RSA_PUBLIC_KEY ) self.preflight_response = {} diff --git a/lti_consumer/lti_1p3/tests/test_key_handlers.py b/lti_consumer/lti_1p3/tests/test_key_handlers.py index e087ad8d..5a0d053b 100644 --- a/lti_consumer/lti_1p3/tests/test_key_handlers.py +++ b/lti_consumer/lti_1p3/tests/test_key_handlers.py @@ -3,9 +3,12 @@ """ import json +import math +import time from unittest.mock import patch import ddt +import jwt from Cryptodome.PublicKey import RSA from django.test.testcases import TestCase from jwkest import BadSignature @@ -106,18 +109,17 @@ def test_empty_rsa_key(self): {'keys': []} ) - # pylint: disable=unused-argument - @patch('time.time', return_value=1000) - def test_validate_and_decode(self, mock_time): + def test_validate_and_decode(self): """ Test validate and decode with all parameters. """ + expiration = 1000 signed_token = self.key_handler.encode_and_sign( { "iss": "test-issuer", "aud": "test-aud", }, - expiration=1000 + expiration=expiration ) self.assertEqual( @@ -125,14 +127,12 @@ def test_validate_and_decode(self, mock_time): { "iss": "test-issuer", "aud": "test-aud", - "iat": 1000, - "exp": 2000 + "iat": int(math.floor(time.time())), + "exp": int(math.floor(time.time()) + expiration), } ) - # pylint: disable=unused-argument - @patch('time.time', return_value=1000) - def test_validate_and_decode_expired(self, mock_time): + def test_validate_and_decode_expired(self): """ Test validate and decode with all parameters. """ @@ -141,7 +141,7 @@ def test_validate_and_decode_expired(self, mock_time): expiration=-10 ) - with self.assertRaises(exceptions.TokenSignatureExpired): + with self.assertRaises(jwt.InvalidTokenError): self.key_handler.validate_and_decode(signed_token) def test_validate_and_decode_invalid_iss(self): @@ -150,7 +150,7 @@ def test_validate_and_decode_invalid_iss(self): """ signed_token = self.key_handler.encode_and_sign({"iss": "wrong"}) - with self.assertRaises(exceptions.InvalidClaimValue): + with self.assertRaises(jwt.InvalidTokenError): self.key_handler.validate_and_decode(signed_token, iss="right") def test_validate_and_decode_invalid_aud(self): @@ -159,14 +159,14 @@ def test_validate_and_decode_invalid_aud(self): """ signed_token = self.key_handler.encode_and_sign({"aud": "wrong"}) - with self.assertRaises(exceptions.InvalidClaimValue): + with self.assertRaises(jwt.InvalidTokenError): self.key_handler.validate_and_decode(signed_token, aud="right") def test_validate_and_decode_no_jwt(self): """ Test validate and decode with invalid JWT. """ - with self.assertRaises(exceptions.MalformedJwtToken): + with self.assertRaises(jwt.InvalidTokenError): self.key_handler.validate_and_decode("1.2.3") def test_validate_and_decode_no_keys(self): @@ -174,10 +174,10 @@ def test_validate_and_decode_no_keys(self): Test validate and decode when no keys are available. """ signed_token = self.key_handler.encode_and_sign({}) - # Changing the KID so it doesn't match - self.key_handler.key.kid = "invalid_kid" - with self.assertRaises(exceptions.NoSuitableKeys): + self.key_handler.key = None + + with self.assertRaises(exceptions.RsaKeyNotSet): self.key_handler.validate_and_decode(signed_token) @@ -192,12 +192,10 @@ def setUp(self): self.rsa_key_id = "1" # Generate RSA and save exports - rsa_key = RSA.generate(2048) - self.key = RSAKey( - key=rsa_key, - kid=self.rsa_key_id - ) - self.public_key = rsa_key.publickey().export_key() + rsa_key = RSA.generate(2048).export_key('PEM') + algo_obj = jwt.get_algorithm_by_name('RS256') + self.key = algo_obj.prepare_key(rsa_key) + self.public_key = self.key.public_key() # Key handler self.key_handler = None @@ -247,9 +245,7 @@ def test_get_keyset_with_pub_key(self): self.rsa_key_id ) - # pylint: disable=unused-argument - @patch('time.time', return_value=1000) - def test_validate_and_decode(self, mock_time): + def test_validate_and_decode(self): """ Check that the validate and decode works. """ @@ -258,7 +254,7 @@ def test_validate_and_decode(self, mock_time): message = { "test": "test_message", "iat": 1000, - "exp": 1200, + "exp": int(math.floor(time.time()) + 1000), } signed = create_jwt(self.key, message) @@ -266,9 +262,7 @@ def test_validate_and_decode(self, mock_time): decoded_message = self.key_handler.validate_and_decode(signed) self.assertEqual(decoded_message, message) - # pylint: disable=unused-argument - @patch('time.time', return_value=1000) - def test_validate_and_decode_expired(self, mock_time): + def test_validate_and_decode_expired(self): """ Check that the validate and decode raises when signature expires. """ @@ -282,7 +276,7 @@ def test_validate_and_decode_expired(self, mock_time): signed = create_jwt(self.key, message) # Decode and check results - with self.assertRaises(exceptions.TokenSignatureExpired): + with self.assertRaises(jwt.InvalidTokenError): self.key_handler.validate_and_decode(signed) def test_validate_and_decode_no_keys(self): @@ -299,14 +293,13 @@ def test_validate_and_decode_no_keys(self): signed = create_jwt(self.key, message) # Decode and check results - with self.assertRaises(exceptions.NoSuitableKeys): + with self.assertRaises(jwt.InvalidTokenError): key_handler.validate_and_decode(signed) - @patch("lti_consumer.lti_1p3.key_handlers.JWS.verify_compact") - def test_validate_and_decode_bad_signature(self, mock_verify_compact): - mock_verify_compact.side_effect = BadSignature() - - key_handler = ToolKeyHandler() + @patch("lti_consumer.lti_1p3.key_handlers.jwt.decode") + def test_validate_and_decode_bad_signature(self, mock_jwt_decode): + mock_jwt_decode.side_effect = Exception() + self._setup_key_handler() message = { "test": "test_message", @@ -315,6 +308,5 @@ def test_validate_and_decode_bad_signature(self, mock_verify_compact): } signed = create_jwt(self.key, message) - # Decode and check results - with self.assertRaises(exceptions.BadJwtSignature): - key_handler.validate_and_decode(signed) + with self.assertRaises(jwt.InvalidTokenError): + self.key_handler.validate_and_decode(signed) diff --git a/lti_consumer/lti_1p3/tests/utils.py b/lti_consumer/lti_1p3/tests/utils.py index 3a76d162..3aae56ca 100644 --- a/lti_consumer/lti_1p3/tests/utils.py +++ b/lti_consumer/lti_1p3/tests/utils.py @@ -1,12 +1,14 @@ """ Test utils """ -from jwkest.jws import JWS +import jwt def create_jwt(key, message): """ Uses private key to create a JWS from a dict. """ - jws = JWS(message, alg="RS256", cty="JWT") - return jws.sign_compact([key]) + token = jwt.encode( + message, key, algorithm='RS256' + ) + return token diff --git a/lti_consumer/plugin/views.py b/lti_consumer/plugin/views.py index c6503b28..f9a8a77a 100644 --- a/lti_consumer/plugin/views.py +++ b/lti_consumer/plugin/views.py @@ -4,6 +4,7 @@ import logging import urllib +import jwt from django.contrib.auth import get_user_model from django.core.exceptions import ObjectDoesNotExist, PermissionDenied, ValidationError from django.db import transaction @@ -15,7 +16,6 @@ from django.views.decorators.http import require_http_methods from django_filters.rest_framework import DjangoFilterBackend from edx_django_utils.cache import TieredCache, get_cache_key -from jwkest.jwt import JWT, BadSyntax from opaque_keys import InvalidKeyError from opaque_keys.edx.keys import UsageKey from rest_framework import status, viewsets @@ -740,13 +740,12 @@ def start_proctoring_assessment_endpoint(request): token = request.POST.get('JWT') try: - jwt = JWT().unpack(token) - except BadSyntax: + decoded_jwt = jwt.decode(token, options={'verify_signature': False}) + except Exception: return render(request, 'html/lti_proctoring_start_error.html', status=HTTP_400_BAD_REQUEST) - jwt_payload = jwt.payload() - iss = jwt_payload.get('iss') - resource_link_id = jwt_payload.get('https://purl.imsglobal.org/spec/lti/claim/resource_link', {}).get('id') + iss = decoded_jwt.get('iss') + resource_link_id = decoded_jwt.get('https://purl.imsglobal.org/spec/lti/claim/resource_link', {}).get('id') try: lti_config = LtiConfiguration.objects.get(lti_1p3_client_id=iss) diff --git a/lti_consumer/tests/unit/plugin/test_proctoring.py b/lti_consumer/tests/unit/plugin/test_proctoring.py index 25a62456..7d048df8 100644 --- a/lti_consumer/tests/unit/plugin/test_proctoring.py +++ b/lti_consumer/tests/unit/plugin/test_proctoring.py @@ -134,8 +134,8 @@ def test_valid_token(self): def test_unparsable_token(self): """Tests that a call to the start_assessment_endpoint with an unparsable token results in a 400 response.""" - with patch("lti_consumer.plugin.views.JWT.unpack") as mock_jwt_unpack_method: - mock_jwt_unpack_method.side_effect = BadSyntax(value="", msg="") + with patch("lti_consumer.plugin.views.jwt.decode") as mock_jwt_decode_method: + mock_jwt_decode_method.side_effect = Exception response = self.client.post( self.url, diff --git a/requirements/base.in b/requirements/base.in index 0be36300..74b62d6d 100644 --- a/requirements/base.in +++ b/requirements/base.in @@ -10,6 +10,7 @@ mako lazy XBlock xblock-utils +pyjwt pycryptodomex pyjwkest edx-opaque-keys[django] diff --git a/requirements/base.txt b/requirements/base.txt index ee977ac7..ac9e685b 100644 --- a/requirements/base.txt +++ b/requirements/base.txt @@ -85,6 +85,8 @@ pycryptodomex==3.17 # pyjwkest pyjwkest==1.4.2 # via -r requirements/base.in +pyjwt==2.6.0 + # via -r requirements/base.in pymongo==3.13.0 # via edx-opaque-keys pynacl==1.5.0 diff --git a/requirements/ci.txt b/requirements/ci.txt index a456b95a..5223567b 100644 --- a/requirements/ci.txt +++ b/requirements/ci.txt @@ -313,6 +313,8 @@ pygments==2.14.0 # rich pyjwkest==1.4.2 # via -r requirements/test.txt +pyjwt==2.6.0 + # via -r requirements/test.txt pylint==2.17.2 # via # -r requirements/test.txt diff --git a/requirements/dev.txt b/requirements/dev.txt index a41ec30f..61caa679 100644 --- a/requirements/dev.txt +++ b/requirements/dev.txt @@ -126,6 +126,8 @@ pycryptodomex==3.17 # pyjwkest pyjwkest==1.4.2 # via -r requirements/base.txt +pyjwt==2.6.0 + # via -r requirements/base.txt pymongo==3.13.0 # via # -r requirements/base.txt diff --git a/requirements/quality.txt b/requirements/quality.txt index 7a6d980e..515be10b 100644 --- a/requirements/quality.txt +++ b/requirements/quality.txt @@ -189,6 +189,8 @@ pycryptodomex==3.17 # pyjwkest pyjwkest==1.4.2 # via -r requirements/base.txt +pyjwt==2.6.0 + # via -r requirements/base.txt pylint==2.17.2 # via # -r requirements/quality.in diff --git a/requirements/test.txt b/requirements/test.txt index ad419332..ad45e0c5 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -230,6 +230,8 @@ pygments==2.14.0 # rich pyjwkest==1.4.2 # via -r requirements/base.txt +pyjwt==2.6.0 + # via -r requirements/base.txt pylint==2.17.2 # via # edx-lint