Skip to content

Commit

Permalink
fix: remove useless tests
Browse files Browse the repository at this point in the history
  • Loading branch information
mumarkhan999 committed Jan 8, 2025
1 parent 65ffe29 commit 6549014
Show file tree
Hide file tree
Showing 12 changed files with 143 additions and 172 deletions.
82 changes: 47 additions & 35 deletions lti_consumer/lti_1p3/key_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,13 @@
import copy
import json
import math
import time
import sys
import time
import logging

import jwt
from Cryptodome.PublicKey import RSA
from edx_django_utils.monitoring import function_trace
from jwt.api_jwk import PyJWK

from . import exceptions

Expand Down Expand Up @@ -52,7 +52,9 @@ def __init__(self, public_key=None, keyset_url=None):
try:
# Import Key and save to internal state
algo_obj = jwt.get_algorithm_by_name('RS256')
self.public_key = algo_obj.prepare_key(public_key)
public_key = algo_obj.prepare_key(public_key)
public_jwk = json.loads(algo_obj.to_jwk(public_key))
self.public_key = PyJWK.from_dict(public_jwk)
except ValueError as err:
log.warning(
'An error was encountered while loading the LTI tool\'s key from the public key. '
Expand Down Expand Up @@ -82,15 +84,16 @@ def _get_keyset(self, kid=None):
'The RSA keys could not be loaded.'
)
raise exceptions.NoSuitableKeys() from err
keyset.extend(keys)
keyset.extend(keys.keys)

if self.public_key and kid:
# Fill in key id of stored key.
# This is needed because if the JWS is signed with a
# key with a kid, pyjwkest doesn't match them with
# keys without kid (kid=None) and fails verification
self.public_key.kid = kid

if self.public_key:
if kid:
# Fill in key id of stored key.
# This is needed because if the JWS is signed with a
# key with a kid, pyjwkest doesn't match them with
# keys without kid (kid=None) and fails verification
self.public_key.kid = kid
# Add to keyset
keyset.append(self.public_key)

Expand All @@ -105,25 +108,29 @@ def validate_and_decode(self, token):
The authorization server decodes the JWT and MUST validate the values for the
iss, sub, exp, aud and jti claims.
"""
try:
key_set = self._get_keyset()
if not key_set:
raise exceptions.NoSuitableKeys()
for i in range(len(key_set)):
try:
message = jwt.decode(
token,
key=key_set[i],
algorithms=['RS256', 'RS512',],
options={'verify_signature': True}
)
return message
except Exception:
if i == len(key_set) - 1:
raise
except Exception as token_error:
exc_info = sys.exc_info()
raise jwt.InvalidTokenError(exc_info[2]) from token_error
key_set = self._get_keyset()

for i, obj in enumerate(key_set):
try:
if hasattr(obj.key, 'public_key'):
key = obj.key.public_key()
else:
key = obj.key
message = jwt.decode(
token,
key,
algorithms=['RS256', 'RS512',],
options={
'verify_signature': True,
'verify_aud': False
}
)
return message
except Exception: # pylint: disable=broad-except
if i == len(key_set) - 1:
raise

raise exceptions.NoSuitableKeys()


class PlatformKeyHandler:
Expand All @@ -144,7 +151,10 @@ def __init__(self, key_pem, kid=None):
# Import JWK from RSA key
try:
algo = jwt.get_algorithm_by_name('RS256')
self.key = algo.prepare_key(key_pem)
private_key = algo.prepare_key(key_pem)
private_jwk = json.loads(algo.to_jwk(private_key))
private_jwk['kid'] = kid
self.key = PyJWK.from_dict(private_jwk)
except ValueError as err:
log.warning(
'An error was encountered while loading the LTI platform\'s key. '
Expand Down Expand Up @@ -175,7 +185,7 @@ def encode_and_sign(self, message, expiration=None):

# The class instance that sets up the signing operation
# An RS 256 key is required for LTI 1.3
return jwt.encode(_message, self.key, algorithm="RS256")
return jwt.encode(_message, self.key.key, algorithm="RS256")

def get_public_jwk(self):
"""
Expand All @@ -186,11 +196,11 @@ def get_public_jwk(self):
# Only append to keyset if a key exists
if self.key:
algo_obj = jwt.get_algorithm_by_name('RS256')
public_key = algo_obj.prepare_key(self.key).public_key()
public_key = algo_obj.prepare_key(self.key.key).public_key()
jwk['keys'].append(json.loads(algo_obj.to_jwk(public_key)))
return jwk

def validate_and_decode(self, token, iss=None, aud=None):
def validate_and_decode(self, token, iss=None, aud=None, exp=True):
"""
Check if a platform token is valid, and return allowed scopes.
Expand All @@ -202,13 +212,15 @@ def validate_and_decode(self, token, iss=None, aud=None):
try:
message = jwt.decode(
token,
key=self.key.public_key(),
key=self.key.key.public_key(),
audience=aud,
issuer=iss,
algorithms=['RS256', 'RS512'],
options={
'verify_signature': True,
'verify_aud': True if aud else False
'verify_exp': bool(exp),
'verify_iss': bool(iss),
'verify_aud': bool(aud)
}
)
return message
Expand Down
33 changes: 14 additions & 19 deletions lti_consumer/lti_1p3/tests/test_consumer.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@

import ddt
import jwt
import sys
from Cryptodome.PublicKey import RSA
from django.conf import settings
from django.test.testcases import TestCase
Expand Down Expand Up @@ -115,30 +114,26 @@ def _get_lti_message(

def _decode_token(self, token):
"""
Checks for a valid signarute and decodes JWT signed LTI message
Checks for a valid signature and decodes JWT signed LTI message
This also tests the public keyset function.
"""
public_keyset = self.lti_consumer.get_public_keyset()
keyset = PyJWKSet.from_dict(public_keyset).keys

for i in range(len(keyset)):
try:
message = jwt.decode(
token,
key=keyset[i].key,
algorithms=['RS256', 'RS512'],
options={
'verify_signature': True,
'verify_aud': False
}
)
return message
except Exception as token_error:
if i < len(keyset) - 1:
continue
exc_info = sys.exc_info()
raise jwt.InvalidTokenError(exc_info[2]) from token_error
for obj in keyset:
message = jwt.decode(
token,
key=obj.key,
algorithms=['RS256', 'RS512'],
options={
'verify_signature': True,
'verify_aud': False
}
)
return message

return exceptions.NoSuitableKeys()

@ddt.data(
({"client_id": CLIENT_ID, "redirect_uri": LAUNCH_URL, "nonce": STATE, "state": STATE}, True),
Expand Down
71 changes: 23 additions & 48 deletions lti_consumer/lti_1p3/tests/test_key_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,14 @@
import json
import math
import time
from datetime import datetime, timezone
from unittest.mock import patch

import ddt
import jwt
from Cryptodome.PublicKey import RSA
from django.test.testcases import TestCase
from jwkest import BadSignature
from jwkest.jwk import RSAKey, load_jwks
from jwkest.jws import JWS, NoSuitableSigningKeys, UnknownAlgorithm

from jwt.api_jwk import PyJWK

from lti_consumer.lti_1p3 import exceptions
from lti_consumer.lti_1p3.key_handlers import PlatformKeyHandler, ToolKeyHandler
Expand All @@ -39,16 +37,13 @@ def setUp(self):
kid=self.rsa_key_id
)

def _decode_token(self, token):
def _decode_token(self, token, exp=True):
"""
Checks for a valid signarute and decodes JWT signed LTI message
Checks for a valid signature and decodes JWT signed LTI message
This also touches the public keyset method.
"""
public_keyset = self.key_handler.get_public_jwk()
key_set = load_jwks(json.dumps(public_keyset))

return JWS().verify_compact(token, keys=key_set)
return self.key_handler.validate_and_decode(token, exp=exp)

def test_encode_and_sign(self):
"""
Expand All @@ -59,7 +54,7 @@ def test_encode_and_sign(self):
}
signed_token = self.key_handler.encode_and_sign(message)
self.assertEqual(
self._decode_token(signed_token),
self._decode_token(signed_token, exp=False),
message
)

Expand All @@ -72,45 +67,21 @@ def test_encode_and_sign_with_exp(self, mock_time):
message = {
"test": "test"
}

expiration = int(datetime.now(tz=timezone.utc).timestamp())
signed_token = self.key_handler.encode_and_sign(
message,
expiration=1000
expiration=expiration
)

self.assertEqual(
self._decode_token(signed_token),
{
"test": "test",
"iat": 1000,
"exp": 2000
"exp": expiration + 1000
}
)

def test_encode_and_sign_no_suitable_keys(self):
"""
Test if an exception is raised when there are no suitable keys when signing the JWT.
"""
message = {
"test": "test"
}

with patch('lti_consumer.lti_1p3.key_handlers.JWS.sign_compact', side_effect=NoSuitableSigningKeys):
with self.assertRaises(exceptions.NoSuitableKeys):
self.key_handler.encode_and_sign(message)

def test_encode_and_sign_unknown_algorithm(self):
"""
Test if an exception is raised when the signing algorithm is unknown when signing the JWT.
"""
message = {
"test": "test"
}

with patch('lti_consumer.lti_1p3.key_handlers.JWS.sign_compact', side_effect=UnknownAlgorithm):
with self.assertRaises(exceptions.MalformedJwtToken):
self.key_handler.encode_and_sign(message)

def test_invalid_rsa_key(self):
"""
Check that class raises when trying to import invalid RSA Key.
Expand Down Expand Up @@ -217,10 +188,14 @@ def setUp(self):
self.rsa_key_id = "1"

# Generate RSA and save exports
rsa_key = RSA.generate(2048).export_key('PEM')
rsa_key = RSA.generate(2048)
algo_obj = jwt.get_algorithm_by_name('RS256')
self.key = algo_obj.prepare_key(rsa_key)
self.public_key = self.key.public_key()
private_key = algo_obj.prepare_key(rsa_key.export_key())
private_jwk = json.loads(algo_obj.to_jwk(private_key))
private_jwk['kid'] = self.rsa_key_id
self.key = PyJWK.from_dict(private_jwk)

self.public_key = rsa_key.publickey().export_key()

# Key handler
self.key_handler = None
Expand Down Expand Up @@ -318,20 +293,20 @@ def test_validate_and_decode_no_keys(self):
signed = create_jwt(self.key, message)

# Decode and check results
with self.assertRaises(jwt.InvalidTokenError):
with self.assertRaises(exceptions.NoSuitableKeys):
key_handler.validate_and_decode(signed)

@patch("lti_consumer.lti_1p3.key_handlers.jwt.decode")
def test_validate_and_decode_bad_signature(self, mock_jwt_decode):
mock_jwt_decode.side_effect = Exception()
def test_validate_and_decode_bad_signature(self):
self._setup_key_handler()

message = {
"test": "test_message",
"iat": 1000,
"exp": 1200,
"exp": int(datetime.now(tz=timezone.utc).timestamp()) + 1000,
}
signed = create_jwt(self.key, message)
# Tamper with the token
bad_signed = signed[:-1] + "X"

with self.assertRaises(jwt.InvalidTokenError):
self.key_handler.validate_and_decode(signed)
with self.assertRaises(jwt.exceptions.InvalidSignatureError):
self.key_handler.validate_and_decode(bad_signed)
2 changes: 1 addition & 1 deletion lti_consumer/lti_1p3/tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,6 @@ def create_jwt(key, message):
Uses private key to create a JWS from a dict.
"""
token = jwt.encode(
message, key, algorithm='RS256'
message, key.key, algorithm='RS256'
)
return token
Loading

0 comments on commit 6549014

Please sign in to comment.