Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace pyjwkest with pyjwt package #349

Merged
merged 4 commits into from
Jan 9, 2025
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
195 changes: 64 additions & 131 deletions lti_consumer/lti_1p3/key_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,16 @@
This handles validating messages sent by the tool and generating
access token with LTI scopes.
"""
import codecs
import copy
import time
import json
import math
import sys
import time
import logging

from Cryptodome.PublicKey import RSA
import jwt
from edx_django_utils.monitoring import function_trace
from jwkest import BadSignature, BadSyntax, WrongNumberOfParts, jwk
from jwkest.jwk import RSAKey, load_jwks_from_url
from jwkest.jws import JWS, NoSuitableSigningKeys, UnknownAlgorithm
from jwkest.jwt import JWT
from jwt.api_jwk import PyJWK

from . import exceptions

Expand Down Expand Up @@ -52,14 +50,11 @@ def __init__(self, public_key=None, keyset_url=None):
# Import from public key
if public_key:
try:
new_key = RSAKey(use='sig')

# Unescape key before importing it
raw_key = codecs.decode(public_key, 'unicode_escape')

# Import Key and save to internal state
new_key.load_key(RSA.import_key(raw_key))
self.public_key = new_key
algo_obj = jwt.get_algorithm_by_name('RS256')
public_key = algo_obj.prepare_key(public_key)
public_jwk = json.loads(algo_obj.to_jwk(public_key))
self.public_key = PyJWK.from_dict(public_jwk)
except ValueError as err:
log.warning(
'An error was encountered while loading the LTI tool\'s key from the public key. '
Expand All @@ -78,7 +73,7 @@ def _get_keyset(self, kid=None):

if self.keyset_url:
try:
keys = load_jwks_from_url(self.keyset_url)
keys = jwt.PyJWKClient(self.keyset_url).get_jwk_set()
except Exception as err:
# Broad Exception is required here because jwkest raises
# an Exception object explicitly.
Expand All @@ -89,7 +84,7 @@ def _get_keyset(self, kid=None):
'The RSA keys could not be loaded.'
)
raise exceptions.NoSuitableKeys() from err
keyset.extend(keys)
keyset.extend(keys.keys)

if self.public_key and kid:
# Fill in key id of stored key.
Expand All @@ -98,6 +93,7 @@ def _get_keyset(self, kid=None):
# keys without kid (kid=None) and fails verification
self.public_key.kid = kid

if self.public_key:
# Add to keyset
keyset.append(self.public_key)

Expand All @@ -112,49 +108,29 @@ def validate_and_decode(self, token):
The authorization server decodes the JWT and MUST validate the values for the
iss, sub, exp, aud and jti claims.
"""
try:
# Get KID from JWT header
jwt = JWT().unpack(token)
key_set = self._get_keyset()

# Verify message signature
message = JWS().verify_compact(
token,
keys=self._get_keyset(
jwt.headers.get('kid')
)
)

# If message is valid, check expiration from JWT
if 'exp' in message and message['exp'] < time.time():
log.warning(
'An error was encountered while verifying the OAuth 2.0 Client-Credentials Grant JWT. '
'The JWT has expired.'
for i, obj in enumerate(key_set):
try:
if hasattr(obj.key, 'public_key'):
key = obj.key.public_key()
else:
key = obj.key
message = jwt.decode(
token,
key,
algorithms=['RS256', 'RS512',],
options={
'verify_signature': True,
'verify_aud': False
}
)
raise exceptions.TokenSignatureExpired()
return message
except Exception: # pylint: disable=broad-except
if i == len(key_set) - 1:
raise

# TODO: Validate other JWT claims

# Else returns decoded message
return message

except NoSuitableSigningKeys as err:
log.warning(
'An error was encountered while verifying the OAuth 2.0 Client-Credentials Grant JWT. '
'There is no suitable signing key.'
)
raise exceptions.NoSuitableKeys() from err
except (BadSyntax, WrongNumberOfParts) as err:
log.warning(
'An error was encountered while verifying the OAuth 2.0 Client-Credentials Grant JWT. '
'The JWT is malformed.'
)
raise exceptions.MalformedJwtToken() from err
except BadSignature as err:
log.warning(
'An error was encountered while verifying the OAuth 2.0 Client-Credentials Grant JWT. '
'The JWT signature is incorrect.'
)
raise exceptions.BadJwtSignature() from err
raise exceptions.NoSuitableKeys()


class PlatformKeyHandler:
Expand All @@ -174,14 +150,11 @@ def __init__(self, key_pem, kid=None):
if key_pem:
# Import JWK from RSA key
try:
self.key = RSAKey(
# Using the same key ID as client id
# This way we can easily serve multiple public
# keys on the same endpoint and keep all
# LTI 1.3 blocks working
kid=kid,
key=RSA.import_key(key_pem)
)
algo = jwt.get_algorithm_by_name('RS256')
private_key = algo.prepare_key(key_pem)
private_jwk = json.loads(algo.to_jwk(private_key))
private_jwk['kid'] = kid
self.key = PyJWK.from_dict(private_jwk)
except ValueError as err:
log.warning(
'An error was encountered while loading the LTI platform\'s key. '
Expand All @@ -206,92 +179,52 @@ def encode_and_sign(self, message, expiration=None):
# Set iat and exp if expiration is set
if expiration:
_message.update({
"iat": int(round(time.time())),
"exp": int(round(time.time()) + expiration),
"iat": int(math.floor(time.time())),
"exp": int(math.floor(time.time()) + expiration),
})

# The class instance that sets up the signing operation
# An RS 256 key is required for LTI 1.3
_jws = JWS(_message, alg="RS256", cty="JWT")

try:
# Encode and sign LTI message
return _jws.sign_compact([self.key])
except NoSuitableSigningKeys as err:
log.warning(
'An error was encountered while signing the OAuth 2.0 access token JWT. '
'There is no suitable signing key.'
)
raise exceptions.NoSuitableKeys() from err
except UnknownAlgorithm as err:
log.warning(
'An error was encountered while signing the OAuth 2.0 access token JWT. '
'There algorithm is unknown.'
)
raise exceptions.MalformedJwtToken() from err
return jwt.encode(_message, self.key.key, algorithm="RS256")

def get_public_jwk(self):
"""
Export Public JWK
"""
public_keys = jwk.KEYS()
jwk = {"keys": []}

# Only append to keyset if a key exists
if self.key:
public_keys.append(self.key)

return json.loads(public_keys.dump_jwks())
algo_obj = jwt.get_algorithm_by_name('RS256')
public_key = algo_obj.prepare_key(self.key.key).public_key()
jwk['keys'].append(json.loads(algo_obj.to_jwk(public_key)))
return jwk

def validate_and_decode(self, token, iss=None, aud=None):
def validate_and_decode(self, token, iss=None, aud=None, exp=True):
"""
Check if a platform token is valid, and return allowed scopes.

Validates a token sent by the tool using the platform's RSA Key.
Optionally validate iss and aud claims if provided.
"""
if not self.key:
raise exceptions.RsaKeyNotSet()
try:
# Verify message signature
message = JWS().verify_compact(token, keys=[self.key])

# If message is valid, check expiration from JWT
if 'exp' in message and message['exp'] < time.time():
log.warning(
'An error was encountered while verifying the OAuth 2.0 access token. '
'The JWT has expired.'
)
raise exceptions.TokenSignatureExpired()

# Validate issuer claim (if present)
log_message_base = 'An error was encountered while verifying the OAuth 2.0 access token. '
if iss:
if 'iss' not in message or message['iss'] != iss:
error_message = 'The required iss claim is missing or does not match the expected iss value. '
log_message = log_message_base + error_message

log.warning(log_message)
raise exceptions.InvalidClaimValue(error_message)

# Validate audience claim (if present)
if aud:
if 'aud' not in message or aud not in message['aud']:
error_message = 'The required aud claim is missing.'
log_message = log_message_base + error_message

log.warning(log_message)
raise exceptions.InvalidClaimValue(error_message)

# Else return token contents
message = jwt.decode(
token,
key=self.key.key.public_key(),
audience=aud,
issuer=iss,
algorithms=['RS256', 'RS512'],
options={
'verify_signature': True,
'verify_exp': bool(exp),
'verify_iss': bool(iss),
'verify_aud': bool(aud)
}
)
return message

except NoSuitableSigningKeys as err:
log.warning(
'An error was encountered while verifying the OAuth 2.0 access token. '
'There is no suitable signing key.'
)
raise exceptions.NoSuitableKeys() from err
except BadSyntax as err:
log.warning(
'An error was encountered while verifying the OAuth 2.0 access token. '
'The JWT is malformed.'
)
raise exceptions.MalformedJwtToken() from err
except Exception as token_error:
exc_info = sys.exc_info()
raise jwt.InvalidTokenError(exc_info[2]) from token_error
41 changes: 27 additions & 14 deletions lti_consumer/lti_1p3/tests/test_consumer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,17 @@
Unit tests for LTI 1.3 consumer implementation
"""

import json
from unittest.mock import patch
from urllib.parse import parse_qs, urlparse
import uuid

import ddt
import jwt
from Cryptodome.PublicKey import RSA
from django.conf import settings
from django.test.testcases import TestCase
from edx_django_utils.cache import get_cache_key, TieredCache
from jwkest.jwk import load_jwks
from jwkest.jws import JWS
from jwt.api_jwk import PyJWKSet

from lti_consumer.data import Lti1p3LaunchData
from lti_consumer.lti_1p3 import exceptions
Expand All @@ -36,7 +35,9 @@
STATE = "ABCD"
# Consider storing a fixed key
RSA_KEY_ID = "1"
RSA_KEY = RSA.generate(2048).export_key('PEM')
RSA_KEY = RSA.generate(2048)
RSA_PRIVATE_KEY = RSA_KEY.export_key('PEM')
RSA_PUBLIC_KEY = RSA_KEY.public_key().export_key('PEM')


def _generate_token_request_data(token, scope):
Expand Down Expand Up @@ -69,11 +70,11 @@ def setUp(self):
lti_launch_url=LAUNCH_URL,
client_id=CLIENT_ID,
deployment_id=DEPLOYMENT_ID,
rsa_key=RSA_KEY,
rsa_key=RSA_PRIVATE_KEY,
rsa_key_id=RSA_KEY_ID,
redirect_uris=REDIRECT_URIS,
# Use the same key for testing purposes
tool_key=RSA_KEY
tool_key=RSA_PUBLIC_KEY
)

def _setup_lti_launch_data(self):
Expand Down Expand Up @@ -113,14 +114,26 @@ def _get_lti_message(

def _decode_token(self, token):
"""
Checks for a valid signarute and decodes JWT signed LTI message
Checks for a valid signature and decodes JWT signed LTI message

This also tests the public keyset function.
"""
public_keyset = self.lti_consumer.get_public_keyset()
key_set = load_jwks(json.dumps(public_keyset))
keyset = PyJWKSet.from_dict(public_keyset).keys

for obj in keyset:
message = jwt.decode(
token,
key=obj.key,
algorithms=['RS256', 'RS512'],
options={
'verify_signature': True,
'verify_aud': False
}
)
return message

return JWS().verify_compact(token, keys=key_set)
return exceptions.NoSuitableKeys()

@ddt.data(
({"client_id": CLIENT_ID, "redirect_uri": LAUNCH_URL, "nonce": STATE, "state": STATE}, True),
Expand Down Expand Up @@ -558,7 +571,7 @@ def test_access_token_invalid_jwt(self):
"""
request_data = _generate_token_request_data("invalid_jwt", "")

with self.assertRaises(exceptions.MalformedJwtToken):
with self.assertRaises(jwt.exceptions.InvalidTokenError):
self.lti_consumer.access_token(request_data)

def test_access_token_no_acs(self):
Expand Down Expand Up @@ -686,11 +699,11 @@ def setUp(self):
lti_launch_url=LAUNCH_URL,
client_id=CLIENT_ID,
deployment_id=DEPLOYMENT_ID,
rsa_key=RSA_KEY,
rsa_key=RSA_PRIVATE_KEY,
rsa_key_id=RSA_KEY_ID,
redirect_uris=REDIRECT_URIS,
# Use the same key for testing purposes
tool_key=RSA_KEY
tool_key=RSA_PUBLIC_KEY
)

self.preflight_response = {}
Expand Down Expand Up @@ -930,11 +943,11 @@ def setUp(self):
lti_launch_url=LAUNCH_URL,
client_id=CLIENT_ID,
deployment_id=DEPLOYMENT_ID,
rsa_key=RSA_KEY,
rsa_key=RSA_PRIVATE_KEY,
rsa_key_id=RSA_KEY_ID,
redirect_uris=REDIRECT_URIS,
# Use the same key for testing purposes
tool_key=RSA_KEY
tool_key=RSA_PUBLIC_KEY
)

self.preflight_response = {}
Expand Down
Loading
Loading