Skip to content

Commit

Permalink
pyjwt: ES256 algorithm support for PyJWT.
Browse files Browse the repository at this point in the history
Add optional support for ES256 JWT signing/verifying to PyJWT using
@dmazzella's cryptography port.

Signed-off-by: Jonah Bron <[email protected]>
  • Loading branch information
jonahbron committed Mar 18, 2024
1 parent ffb07db commit 5dec263
Show file tree
Hide file tree
Showing 2 changed files with 123 additions and 17 deletions.
87 changes: 75 additions & 12 deletions python-ecosys/pyjwt/jwt.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,17 @@
import json
from time import time

# Optionally depend on https://github.com/dmazzella/ucryptography
try:
# Try importing from ucryptography port.
import cryptography
from cryptography import hashes, ec, serialization, utils

_ec_supported = True
except ImportError:
# No cryptography library available, no EC256 support.
_ec_supported = False


def _to_b64url(data):
return (
Expand All @@ -19,6 +30,28 @@ def _from_b64url(data):
return binascii.a2b_base64(data.replace(b"-", b"+").replace(b"_", b"/") + b"===")


def _sig_der_to_jws(signed):
"""Accept a DER signature and convert to JSON Web Signature bytes.
`cryptography` produces signatures encoded in DER ASN.1 binary format.
JSON Web Algorithm instead encodes the signature as the point coordinates
as bigendian byte strings concatenated.
See https://datatracker.ietf.org/doc/html/rfc7518#section-3.4
"""
r, s = utils.decode_dss_signature(signed)
return r.to_bytes(32, "big") + s.to_bytes(32, "big")


def _sig_jws_to_der(signed):
"""Accept a JSON Web Signature and convert to a DER signature.
See `_sig_der_to_jws()`
"""
r, s = int.from_bytes(signed[0:32], "big"), int.from_bytes(signed[32:], "big")
return utils.encode_dss_signature(r, s)


class exceptions:
class PyJWTError(Exception):
pass
Expand All @@ -37,19 +70,32 @@ class ExpiredSignatureError(PyJWTError):


def encode(payload, key, algorithm="HS256"):
if algorithm != "HS256":
if algorithm != "HS256" and algorithm != "ES256":
raise exceptions.InvalidAlgorithmError

if isinstance(key, str):
key = key.encode()
header = _to_b64url(json.dumps({"typ": "JWT", "alg": algorithm}).encode())
payload = _to_b64url(json.dumps(payload).encode())
signature = _to_b64url(hmac.new(key, header + b"." + payload, hashlib.sha256).digest())

if algorithm == "HS256":
if isinstance(key, str):
key = key.encode()
signature = _to_b64url(hmac.new(key, header + b"." + payload, hashlib.sha256).digest())
elif algorithm == "ES256":
if not _ec_supported:
raise exceptions.InvalidAlgorithmError(
"Required dependencies for ES256 are not available"
)
if isinstance(key, int):
key = ec.derive_private_key(key, ec.SECP256R1())
signature = _to_b64url(
_sig_der_to_jws(key.sign(header + b"." + payload, ec.ECDSA(hashes.SHA256())))
)

return (header + b"." + payload + b"." + signature).decode()


def decode(token, key, algorithms=["HS256"]):
if "HS256" not in algorithms:
def decode(token, key, algorithms=["HS256", "ES256"]):
if "HS256" not in algorithms and "ES256" not in algorithms:
raise exceptions.InvalidAlgorithmError

parts = token.encode().split(b".")
Expand All @@ -63,14 +109,31 @@ def decode(token, key, algorithms=["HS256"]):
except Exception:
raise exceptions.InvalidTokenError

if header["alg"] not in algorithms or header["alg"] != "HS256":
if header["alg"] not in algorithms or (header["alg"] != "HS256" and header["alg"] != "ES256"):
raise exceptions.InvalidAlgorithmError

if isinstance(key, str):
key = key.encode()
calculated_signature = hmac.new(key, parts[0] + b"." + parts[1], hashlib.sha256).digest()
if signature != calculated_signature:
raise exceptions.InvalidSignatureError
if header["alg"] == "HS256":
if isinstance(key, str):
key = key.encode()
calculated_signature = hmac.new(key, parts[0] + b"." + parts[1], hashlib.sha256).digest()
if signature != calculated_signature:
raise exceptions.InvalidSignatureError
elif header["alg"] == "ES256":
if not _ec_supported:
raise exceptions.InvalidAlgorithmError(
"Required dependencies for ES256 are not available"
)

if isinstance(key, bytes):
key = ec.EllipticCurvePublicKey.from_encoded_point(key, ec.SECP256R1())
try:
key.verify(
_sig_jws_to_der(signature),
parts[0] + b"." + parts[1],
ec.ECDSA(hashes.SHA256()),
)
except cryptography.exceptions.InvalidSignature:
raise exceptions.InvalidSignatureError

if "exp" in payload:
if time() > payload["exp"]:
Expand Down
53 changes: 48 additions & 5 deletions python-ecosys/pyjwt/test_jwt.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,71 @@
import jwt
from time import time

"""
Run tests by executing:
```
mpremote fs cp jwt.py :lib/jwt.py + run test_jwt.py
```
Only the full test suite can be run if
[ucryptography](https://github.com/dmazzella/ucryptography) is present in the
firmware.
"""

# Indentation
I = " "

print("Testing HS256")
secret_key = "top-secret!"

token = jwt.encode({"user": "joe"}, secret_key, algorithm="HS256")
print(token)
decoded = jwt.decode(token, secret_key, algorithms=["HS256"])
if decoded != {"user": "joe"}:
raise Exception("Invalid decoded JWT")
else:
print("Encode/decode test: OK")
print(I, "Encode/decode test: OK")

try:
decoded = jwt.decode(token, "wrong-secret", algorithms=["HS256"])
except jwt.exceptions.InvalidSignatureError:
print("Invalid signature test: OK")
print(I, "Invalid signature test: OK")
else:
raise Exception("Invalid JWT should have failed decoding")

token = jwt.encode({"user": "joe", "exp": time() - 1}, secret_key)
print(token)
try:
decoded = jwt.decode(token, secret_key, algorithms=["HS256"])
except jwt.exceptions.ExpiredSignatureError:
print("Expired token test: OK")
print(I, "Expired token test: OK")
else:
raise Exception("Expired JWT should have failed decoding")


print("Testing ES256")
try:
from cryptography import ec
except ImportError:
raise Exception("No cryptography lib present, can't test ES256")

private_key = ec.derive_private_key(
0xEB6DFB26C7A3C23D33C60F7C7BA61B6893451F2643E0737B20759E457825EE75, ec.SECP256R1()
)
wrong_private_key = ec.derive_private_key(
0x25D91A0DA38F69283A0CE32B87D82817CA4E134A1693BE6083C2292BF562A451, ec.SECP256R1()
)

token = jwt.encode({"user": "joe"}, private_key, algorithm="ES256")
decoded = jwt.decode(token, private_key.public_key(), algorithms=["ES256"])
if decoded != {"user": "joe"}:
raise Exception("Invalid decoded JWT")
else:
print(I, "Encode/decode test: OK")

token = jwt.encode({"user": "joe"}, private_key, algorithm="ES256")
try:
decoded = jwt.decode(token + "a", wrong_private_key.public_key(), algorithms=["ES256"])
except jwt.exceptions.InvalidSignatureError:
print(I, "Invalid signature test: OK")
else:
raise Exception("Invalid JWT should have fialed decoding")

0 comments on commit 5dec263

Please sign in to comment.