Skip to content

Commit

Permalink
Migrate ChaCha20Poly1305 AEAD to Rust
Browse files Browse the repository at this point in the history
  • Loading branch information
alex committed Jan 12, 2024
1 parent a1ed534 commit 272dd6b
Show file tree
Hide file tree
Showing 6 changed files with 461 additions and 317 deletions.
254 changes: 29 additions & 225 deletions src/cryptography/hazmat/backends/openssl/aead.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,47 +13,15 @@
from cryptography.hazmat.primitives.ciphers.aead import (
AESCCM,
AESGCM,
ChaCha20Poly1305,
)

_AEADTypes = typing.Union[AESCCM, AESGCM, ChaCha20Poly1305]


def _is_evp_aead_supported_cipher(
backend: Backend, cipher: _AEADTypes
) -> bool:
"""
Checks whether the given cipher is supported through
EVP_AEAD rather than the normal OpenSSL EVP_CIPHER API.
"""
from cryptography.hazmat.primitives.ciphers.aead import ChaCha20Poly1305

return backend._lib.Cryptography_HAS_EVP_AEAD and isinstance(
cipher, ChaCha20Poly1305
)
_AEADTypes = typing.Union[AESCCM, AESGCM]


def _aead_cipher_supported(backend: Backend, cipher: _AEADTypes) -> bool:
if _is_evp_aead_supported_cipher(backend, cipher):
return True
else:
cipher_name = _evp_cipher_cipher_name(cipher)
if backend._fips_enabled and cipher_name not in backend._fips_aead:
return False
return (
backend._lib.EVP_get_cipherbyname(cipher_name) != backend._ffi.NULL
)

cipher_name = _evp_cipher_cipher_name(cipher)

def _aead_create_ctx(
backend: Backend,
cipher: _AEADTypes,
key: bytes,
):
if _is_evp_aead_supported_cipher(backend, cipher):
return _evp_aead_create_ctx(backend, cipher, key)
else:
return _evp_cipher_create_ctx(backend, cipher, key)
return backend._lib.EVP_get_cipherbyname(cipher_name) != backend._ffi.NULL


def _encrypt(
Expand All @@ -63,153 +31,24 @@ def _encrypt(
data: bytes,
associated_data: list[bytes],
tag_length: int,
ctx: typing.Any = None,
) -> bytes:
if _is_evp_aead_supported_cipher(backend, cipher):
return _evp_aead_encrypt(
backend, cipher, nonce, data, associated_data, tag_length, ctx
)
else:
return _evp_cipher_encrypt(
backend, cipher, nonce, data, associated_data, tag_length, ctx
)


def _decrypt(
backend: Backend,
cipher: _AEADTypes,
nonce: bytes,
data: bytes,
associated_data: list[bytes],
tag_length: int,
ctx: typing.Any = None,
) -> bytes:
if _is_evp_aead_supported_cipher(backend, cipher):
return _evp_aead_decrypt(
backend, cipher, nonce, data, associated_data, tag_length, ctx
)
else:
return _evp_cipher_decrypt(
backend, cipher, nonce, data, associated_data, tag_length, ctx
)


def _evp_aead_create_ctx(
backend: Backend,
cipher: _AEADTypes,
key: bytes,
tag_len: int | None = None,
):
aead_cipher = _evp_aead_get_cipher(backend, cipher)
assert aead_cipher is not None
key_ptr = backend._ffi.from_buffer(key)
tag_len = (
backend._lib.EVP_AEAD_DEFAULT_TAG_LENGTH
if tag_len is None
else tag_len
)
ctx = backend._lib.Cryptography_EVP_AEAD_CTX_new(
aead_cipher, key_ptr, len(key), tag_len
)
backend.openssl_assert(ctx != backend._ffi.NULL)
ctx = backend._ffi.gc(ctx, backend._lib.EVP_AEAD_CTX_free)
return ctx


def _evp_aead_get_cipher(backend: Backend, cipher: _AEADTypes):
from cryptography.hazmat.primitives.ciphers.aead import (
ChaCha20Poly1305,
return _evp_cipher_encrypt(
backend, cipher, nonce, data, associated_data, tag_length
)

# Currently only ChaCha20-Poly1305 is supported using this API
assert isinstance(cipher, ChaCha20Poly1305)
return backend._lib.EVP_aead_chacha20_poly1305()


def _evp_aead_encrypt(
backend: Backend,
cipher: _AEADTypes,
nonce: bytes,
data: bytes,
associated_data: list[bytes],
tag_length: int,
ctx: typing.Any,
) -> bytes:
assert ctx is not None

aead_cipher = _evp_aead_get_cipher(backend, cipher)
assert aead_cipher is not None

out_len = backend._ffi.new("size_t *")
# max_out_len should be in_len plus the result of
# EVP_AEAD_max_overhead.
max_out_len = len(data) + backend._lib.EVP_AEAD_max_overhead(aead_cipher)
out_buf = backend._ffi.new("uint8_t[]", max_out_len)
data_ptr = backend._ffi.from_buffer(data)
nonce_ptr = backend._ffi.from_buffer(nonce)
aad = b"".join(associated_data)
aad_ptr = backend._ffi.from_buffer(aad)

res = backend._lib.EVP_AEAD_CTX_seal(
ctx,
out_buf,
out_len,
max_out_len,
nonce_ptr,
len(nonce),
data_ptr,
len(data),
aad_ptr,
len(aad),
)
backend.openssl_assert(res == 1)
encrypted_data = backend._ffi.buffer(out_buf, out_len[0])[:]
return encrypted_data


def _evp_aead_decrypt(
def _decrypt(
backend: Backend,
cipher: _AEADTypes,
nonce: bytes,
data: bytes,
associated_data: list[bytes],
tag_length: int,
ctx: typing.Any,
) -> bytes:
if len(data) < tag_length:
raise InvalidTag

assert ctx is not None

out_len = backend._ffi.new("size_t *")
# max_out_len should at least in_len
max_out_len = len(data)
out_buf = backend._ffi.new("uint8_t[]", max_out_len)
data_ptr = backend._ffi.from_buffer(data)
nonce_ptr = backend._ffi.from_buffer(nonce)
aad = b"".join(associated_data)
aad_ptr = backend._ffi.from_buffer(aad)

res = backend._lib.EVP_AEAD_CTX_open(
ctx,
out_buf,
out_len,
max_out_len,
nonce_ptr,
len(nonce),
data_ptr,
len(data),
aad_ptr,
len(aad),
return _evp_cipher_decrypt(
backend, cipher, nonce, data, associated_data, tag_length
)

if res == 0:
backend._consume_errors()
raise InvalidTag

decrypted_data = backend._ffi.buffer(out_buf, out_len[0])[:]
return decrypted_data


_ENCRYPT = 1
_DECRYPT = 0
Expand All @@ -219,12 +58,9 @@ def _evp_cipher_cipher_name(cipher: _AEADTypes) -> bytes:
from cryptography.hazmat.primitives.ciphers.aead import (
AESCCM,
AESGCM,
ChaCha20Poly1305,
)

if isinstance(cipher, ChaCha20Poly1305):
return b"chacha20-poly1305"
elif isinstance(cipher, AESCCM):
if isinstance(cipher, AESCCM):
return f"aes-{len(cipher._key) * 8}-ccm".encode("ascii")
else:
assert isinstance(cipher, AESGCM)
Expand All @@ -237,29 +73,6 @@ def _evp_cipher(cipher_name: bytes, backend: Backend):
return evp_cipher


def _evp_cipher_create_ctx(
backend: Backend,
cipher: _AEADTypes,
key: bytes,
):
ctx = backend._lib.EVP_CIPHER_CTX_new()
backend.openssl_assert(ctx != backend._ffi.NULL)
ctx = backend._ffi.gc(ctx, backend._lib.EVP_CIPHER_CTX_free)
cipher_name = _evp_cipher_cipher_name(cipher)
evp_cipher = _evp_cipher(cipher_name, backend)
key_ptr = backend._ffi.from_buffer(key)
res = backend._lib.EVP_CipherInit_ex(
ctx,
evp_cipher,
backend._ffi.NULL,
key_ptr,
backend._ffi.NULL,
0,
)
backend.openssl_assert(res != 0)
return ctx


def _evp_cipher_aead_setup(
backend: Backend,
cipher_name: bytes,
Expand Down Expand Up @@ -373,23 +186,19 @@ def _evp_cipher_encrypt(
data: bytes,
associated_data: list[bytes],
tag_length: int,
ctx: typing.Any = None,
) -> bytes:
from cryptography.hazmat.primitives.ciphers.aead import AESCCM

if ctx is None:
cipher_name = _evp_cipher_cipher_name(cipher)
ctx = _evp_cipher_aead_setup(
backend,
cipher_name,
cipher._key,
nonce,
None,
tag_length,
_ENCRYPT,
)
else:
_evp_cipher_set_nonce_operation(backend, ctx, nonce, _ENCRYPT)
cipher_name = _evp_cipher_cipher_name(cipher)
ctx = _evp_cipher_aead_setup(
backend,
cipher_name,
cipher._key,
nonce,
None,
tag_length,
_ENCRYPT,
)

# CCM requires us to pass the length of the data before processing
# anything.
Expand Down Expand Up @@ -425,7 +234,6 @@ def _evp_cipher_decrypt(
data: bytes,
associated_data: list[bytes],
tag_length: int,
ctx: typing.Any = None,
) -> bytes:
from cryptography.hazmat.primitives.ciphers.aead import AESCCM

Expand All @@ -434,20 +242,16 @@ def _evp_cipher_decrypt(

tag = data[-tag_length:]
data = data[:-tag_length]
if ctx is None:
cipher_name = _evp_cipher_cipher_name(cipher)
ctx = _evp_cipher_aead_setup(
backend,
cipher_name,
cipher._key,
nonce,
tag,
tag_length,
_DECRYPT,
)
else:
_evp_cipher_set_nonce_operation(backend, ctx, nonce, _DECRYPT)
_evp_cipher_set_tag(backend, ctx, tag)
cipher_name = _evp_cipher_cipher_name(cipher)
ctx = _evp_cipher_aead_setup(
backend,
cipher_name,
cipher._key,
nonce,
tag,
tag_length,
_DECRYPT,
)

# CCM requires us to pass the length of the data before processing
# anything.
Expand Down
17 changes: 17 additions & 0 deletions src/cryptography/hazmat/bindings/_rust/openssl/aead.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,23 @@
# 2.0, and the BSD License. See the LICENSE file in the root of this repository
# for complete details.

class ChaCha20Poly1305:
def __init__(self, key: bytes) -> None: ...
@staticmethod
def generate_key() -> bytes: ...
def encrypt(
self,
nonce: bytes,
data: bytes,
associated_data: bytes | None,
) -> bytes: ...
def decrypt(
self,
nonce: bytes,
data: bytes,
associated_data: bytes | None,
) -> bytes: ...

class AESSIV:
def __init__(self, key: bytes) -> None: ...
@staticmethod
Expand Down
Loading

0 comments on commit 272dd6b

Please sign in to comment.