diff --git a/src/cryptography/hazmat/backends/openssl/aead.py b/src/cryptography/hazmat/backends/openssl/aead.py deleted file mode 100644 index dd2485481203..000000000000 --- a/src/cryptography/hazmat/backends/openssl/aead.py +++ /dev/null @@ -1,255 +0,0 @@ -# This file is dual licensed under the terms of the Apache License, Version -# 2.0, and the BSD License. See the LICENSE file in the root of this repository -# for complete details. - -from __future__ import annotations - -import typing - -from cryptography.exceptions import InvalidTag - -if typing.TYPE_CHECKING: - from cryptography.hazmat.backends.openssl.backend import Backend - from cryptography.hazmat.primitives.ciphers.aead import ( - AESCCM, - ) - - _AEADTypes = typing.Union[AESCCM] - - -def _aead_cipher_supported(backend: Backend, cipher: _AEADTypes) -> bool: - cipher_name = _evp_cipher_cipher_name(cipher) - - return backend._lib.EVP_get_cipherbyname(cipher_name) != backend._ffi.NULL - - -def _encrypt( - backend: Backend, - cipher: _AEADTypes, - nonce: bytes, - data: bytes, - associated_data: list[bytes], - tag_length: int, -) -> bytes: - return _evp_cipher_encrypt( - backend, cipher, nonce, data, associated_data, tag_length - ) - - -def _decrypt( - backend: Backend, - cipher: _AEADTypes, - nonce: bytes, - data: bytes, - associated_data: list[bytes], - tag_length: int, -) -> bytes: - return _evp_cipher_decrypt( - backend, cipher, nonce, data, associated_data, tag_length - ) - - -_ENCRYPT = 1 -_DECRYPT = 0 - - -def _evp_cipher_cipher_name(cipher: _AEADTypes) -> bytes: - from cryptography.hazmat.primitives.ciphers.aead import AESCCM - - assert isinstance(cipher, AESCCM) - return f"aes-{len(cipher._key) * 8}-ccm".encode("ascii") - - -def _evp_cipher(cipher_name: bytes, backend: Backend): - evp_cipher = backend._lib.EVP_get_cipherbyname(cipher_name) - backend.openssl_assert(evp_cipher != backend._ffi.NULL) - return evp_cipher - - -def _evp_cipher_aead_setup( - backend: Backend, - cipher_name: bytes, - key: bytes, - nonce: bytes, - tag: bytes | None, - tag_len: int, - operation: int, -): - evp_cipher = _evp_cipher(cipher_name, backend) - ctx = backend._lib.EVP_CIPHER_CTX_new() - ctx = backend._ffi.gc(ctx, backend._lib.EVP_CIPHER_CTX_free) - res = backend._lib.EVP_CipherInit_ex( - ctx, - evp_cipher, - backend._ffi.NULL, - backend._ffi.NULL, - backend._ffi.NULL, - int(operation == _ENCRYPT), - ) - backend.openssl_assert(res != 0) - # CCM requires the IVLEN to be set before calling SET_TAG on decrypt - res = backend._lib.EVP_CIPHER_CTX_ctrl( - ctx, - backend._lib.EVP_CTRL_AEAD_SET_IVLEN, - len(nonce), - backend._ffi.NULL, - ) - backend.openssl_assert(res != 0) - if operation == _DECRYPT: - assert tag is not None - _evp_cipher_set_tag(backend, ctx, tag) - else: - assert cipher_name.endswith(b"-ccm") - res = backend._lib.EVP_CIPHER_CTX_ctrl( - ctx, - backend._lib.EVP_CTRL_AEAD_SET_TAG, - tag_len, - backend._ffi.NULL, - ) - backend.openssl_assert(res != 0) - - nonce_ptr = backend._ffi.from_buffer(nonce) - key_ptr = backend._ffi.from_buffer(key) - res = backend._lib.EVP_CipherInit_ex( - ctx, - backend._ffi.NULL, - backend._ffi.NULL, - key_ptr, - nonce_ptr, - int(operation == _ENCRYPT), - ) - backend.openssl_assert(res != 0) - return ctx - - -def _evp_cipher_set_tag(backend, ctx, tag: bytes) -> None: - tag_ptr = backend._ffi.from_buffer(tag) - res = backend._lib.EVP_CIPHER_CTX_ctrl( - ctx, backend._lib.EVP_CTRL_AEAD_SET_TAG, len(tag), tag_ptr - ) - backend.openssl_assert(res != 0) - - -def _evp_cipher_set_length(backend: Backend, ctx, data_len: int) -> None: - intptr = backend._ffi.new("int *") - res = backend._lib.EVP_CipherUpdate( - ctx, backend._ffi.NULL, intptr, backend._ffi.NULL, data_len - ) - backend.openssl_assert(res != 0) - - -def _evp_cipher_process_aad( - backend: Backend, ctx, associated_data: bytes -) -> None: - outlen = backend._ffi.new("int *") - a_data_ptr = backend._ffi.from_buffer(associated_data) - res = backend._lib.EVP_CipherUpdate( - ctx, backend._ffi.NULL, outlen, a_data_ptr, len(associated_data) - ) - backend.openssl_assert(res != 0) - - -def _evp_cipher_process_data(backend: Backend, ctx, data: bytes) -> bytes: - outlen = backend._ffi.new("int *") - buf = backend._ffi.new("unsigned char[]", len(data)) - data_ptr = backend._ffi.from_buffer(data) - res = backend._lib.EVP_CipherUpdate(ctx, buf, outlen, data_ptr, len(data)) - backend.openssl_assert(res != 0) - return backend._ffi.buffer(buf, outlen[0])[:] - - -def _evp_cipher_encrypt( - backend: Backend, - cipher: _AEADTypes, - nonce: bytes, - data: bytes, - associated_data: list[bytes], - tag_length: int, -) -> bytes: - from cryptography.hazmat.primitives.ciphers.aead import AESCCM - - cipher_name = _evp_cipher_cipher_name(cipher) - ctx = _evp_cipher_aead_setup( - backend, - cipher_name, - cipher._key, - nonce, - None, - tag_length, - _ENCRYPT, - ) - - # CCM requires us to pass the length of the data before processing - # anything. - # However calling this with any other AEAD results in an error - assert isinstance(cipher, AESCCM) - _evp_cipher_set_length(backend, ctx, len(data)) - - for ad in associated_data: - _evp_cipher_process_aad(backend, ctx, ad) - processed_data = _evp_cipher_process_data(backend, ctx, data) - outlen = backend._ffi.new("int *") - # All AEADs we support besides OCB are streaming so they return nothing - # in finalization. OCB can return up to (16 byte block - 1) bytes so - # we need a buffer here too. - buf = backend._ffi.new("unsigned char[]", 16) - res = backend._lib.EVP_CipherFinal_ex(ctx, buf, outlen) - backend.openssl_assert(res != 0) - processed_data += backend._ffi.buffer(buf, outlen[0])[:] - tag_buf = backend._ffi.new("unsigned char[]", tag_length) - res = backend._lib.EVP_CIPHER_CTX_ctrl( - ctx, backend._lib.EVP_CTRL_AEAD_GET_TAG, tag_length, tag_buf - ) - backend.openssl_assert(res != 0) - tag = backend._ffi.buffer(tag_buf)[:] - - return processed_data + tag - - -def _evp_cipher_decrypt( - backend: Backend, - cipher: _AEADTypes, - nonce: bytes, - data: bytes, - associated_data: list[bytes], - tag_length: int, -) -> bytes: - from cryptography.hazmat.primitives.ciphers.aead import AESCCM - - if len(data) < tag_length: - raise InvalidTag - - tag = data[-tag_length:] - data = data[:-tag_length] - cipher_name = _evp_cipher_cipher_name(cipher) - ctx = _evp_cipher_aead_setup( - backend, - cipher_name, - cipher._key, - nonce, - tag, - tag_length, - _DECRYPT, - ) - - # CCM requires us to pass the length of the data before processing - # anything. - # However calling this with any other AEAD results in an error - assert isinstance(cipher, AESCCM) - _evp_cipher_set_length(backend, ctx, len(data)) - - for ad in associated_data: - _evp_cipher_process_aad(backend, ctx, ad) - # CCM has a different error path if the tag doesn't match. Errors are - # raised in Update and Final is irrelevant. - outlen = backend._ffi.new("int *") - buf = backend._ffi.new("unsigned char[]", len(data)) - d_ptr = backend._ffi.from_buffer(data) - res = backend._lib.EVP_CipherUpdate(ctx, buf, outlen, d_ptr, len(data)) - if res != 1: - backend._consume_errors() - raise InvalidTag - - processed_data = backend._ffi.buffer(buf, outlen[0])[:] - - return processed_data diff --git a/src/cryptography/hazmat/backends/openssl/backend.py b/src/cryptography/hazmat/backends/openssl/backend.py index f296303ced1f..1f76092625bc 100644 --- a/src/cryptography/hazmat/backends/openssl/backend.py +++ b/src/cryptography/hazmat/backends/openssl/backend.py @@ -11,7 +11,6 @@ from cryptography import utils, x509 from cryptography.exceptions import UnsupportedAlgorithm -from cryptography.hazmat.backends.openssl import aead from cryptography.hazmat.backends.openssl.ciphers import _CipherContext from cryptography.hazmat.bindings._rust import openssl as rust_openssl from cryptography.hazmat.bindings.openssl import binding @@ -559,9 +558,6 @@ def ed448_supported(self) -> bool: and not self._lib.CRYPTOGRAPHY_IS_BORINGSSL ) - def aead_cipher_supported(self, cipher) -> bool: - return aead._aead_cipher_supported(self, cipher) - def _zero_data(self, data, length: int) -> None: # We clear things this way because at the moment we're not # sure of a better way that can guarantee it overwrites the diff --git a/src/cryptography/hazmat/bindings/_rust/openssl/aead.pyi b/src/cryptography/hazmat/bindings/_rust/openssl/aead.pyi index e274073f201e..047f49d819c1 100644 --- a/src/cryptography/hazmat/bindings/_rust/openssl/aead.pyi +++ b/src/cryptography/hazmat/bindings/_rust/openssl/aead.pyi @@ -36,6 +36,23 @@ class ChaCha20Poly1305: associated_data: bytes | None, ) -> bytes: ... +class AESCCM: + def __init__(self, key: bytes, tag_length: int = 16) -> None: ... + @staticmethod + def generate_key(key_size: int) -> bytes: ... + def encrypt( + self, + nonce: bytes, + data: bytes, + associated_data: bytes | None, + ) -> bytes: ... + def decrypt( + self, + nonce: bytes, + data: bytes, + associated_data: bytes | None, + ) -> bytes: ... + class AESSIV: def __init__(self, key: bytes) -> None: ... @staticmethod diff --git a/src/cryptography/hazmat/primitives/ciphers/aead.py b/src/cryptography/hazmat/primitives/ciphers/aead.py index e96b735b18f9..f82a05685e02 100644 --- a/src/cryptography/hazmat/primitives/ciphers/aead.py +++ b/src/cryptography/hazmat/primitives/ciphers/aead.py @@ -4,11 +4,6 @@ from __future__ import annotations -import os - -from cryptography import exceptions, utils -from cryptography.hazmat.backends.openssl import aead -from cryptography.hazmat.backends.openssl.backend import backend from cryptography.hazmat.bindings._rust import openssl as rust_openssl __all__ = [ @@ -22,91 +17,7 @@ AESGCM = rust_openssl.aead.AESGCM ChaCha20Poly1305 = rust_openssl.aead.ChaCha20Poly1305 +AESCCM = rust_openssl.aead.AESCCM AESSIV = rust_openssl.aead.AESSIV AESOCB3 = rust_openssl.aead.AESOCB3 AESGCMSIV = rust_openssl.aead.AESGCMSIV - - -class AESCCM: - _MAX_SIZE = 2**31 - 1 - - def __init__(self, key: bytes, tag_length: int = 16): - utils._check_byteslike("key", key) - if len(key) not in (16, 24, 32): - raise ValueError("AESCCM key must be 128, 192, or 256 bits.") - - self._key = key - if not isinstance(tag_length, int): - raise TypeError("tag_length must be an integer") - - if tag_length not in (4, 6, 8, 10, 12, 14, 16): - raise ValueError("Invalid tag_length") - - self._tag_length = tag_length - - if not backend.aead_cipher_supported(self): - raise exceptions.UnsupportedAlgorithm( - "AESCCM is not supported by this version of OpenSSL", - exceptions._Reasons.UNSUPPORTED_CIPHER, - ) - - @classmethod - def generate_key(cls, bit_length: int) -> bytes: - if not isinstance(bit_length, int): - raise TypeError("bit_length must be an integer") - - if bit_length not in (128, 192, 256): - raise ValueError("bit_length must be 128, 192, or 256") - - return os.urandom(bit_length // 8) - - def encrypt( - self, - nonce: bytes, - data: bytes, - associated_data: bytes | None, - ) -> bytes: - if associated_data is None: - associated_data = b"" - - if len(data) > self._MAX_SIZE or len(associated_data) > self._MAX_SIZE: - # This is OverflowError to match what cffi would raise - raise OverflowError( - "Data or associated data too long. Max 2**31 - 1 bytes" - ) - - self._check_params(nonce, data, associated_data) - self._validate_lengths(nonce, len(data)) - return aead._encrypt( - backend, self, nonce, data, [associated_data], self._tag_length - ) - - def decrypt( - self, - nonce: bytes, - data: bytes, - associated_data: bytes | None, - ) -> bytes: - if associated_data is None: - associated_data = b"" - - self._check_params(nonce, data, associated_data) - return aead._decrypt( - backend, self, nonce, data, [associated_data], self._tag_length - ) - - def _validate_lengths(self, nonce: bytes, data_len: int) -> None: - # For information about computing this, see - # https://tools.ietf.org/html/rfc3610#section-2.1 - l_val = 15 - len(nonce) - if 2 ** (8 * l_val) < data_len: - raise ValueError("Data too long for nonce") - - def _check_params( - self, nonce: bytes, data: bytes, associated_data: bytes - ) -> None: - utils._check_byteslike("nonce", nonce) - utils._check_byteslike("data", data) - utils._check_byteslike("associated_data", associated_data) - if not 7 <= len(nonce) <= 13: - raise ValueError("Nonce must be between 7 and 13 bytes") diff --git a/src/rust/src/backend/aead.rs b/src/rust/src/backend/aead.rs index b13a420c7588..7afd7a172e94 100644 --- a/src/rust/src/backend/aead.rs +++ b/src/rust/src/backend/aead.rs @@ -77,6 +77,7 @@ impl EvpCipherAead { ctx: &mut openssl::cipher_ctx::CipherCtx, data: &[u8], out: &mut [u8], + is_ccm: bool, ) -> CryptographyResult<()> { let bs = ctx.block_size(); @@ -87,9 +88,11 @@ impl EvpCipherAead { let n = ctx.cipher_update(data, Some(out))?; assert_eq!(n, data.len()); - let mut final_block = [0]; - let n = ctx.cipher_final(&mut final_block)?; - assert_eq!(n, 0); + if !is_ccm { + let mut final_block = [0]; + let n = ctx.cipher_final(&mut final_block)?; + assert_eq!(n, 0); + } } else { // Our algorithm here is: split the data into the full chunks, and // the remaining partial chunk. Feed the full chunks into OpenSSL @@ -131,9 +134,19 @@ impl EvpCipherAead { ) -> CryptographyResult<&'p pyo3::types::PyBytes> { let mut ctx = openssl::cipher_ctx::CipherCtx::new()?; ctx.copy(&self.base_encryption_ctx)?; - Self::encrypt_with_context(py, ctx, plaintext, aad, nonce, self.tag_len, self.tag_first) + Self::encrypt_with_context( + py, + ctx, + plaintext, + aad, + nonce, + self.tag_len, + self.tag_first, + false, + ) } + #[allow(clippy::too_many_arguments)] fn encrypt_with_context<'p>( py: pyo3::Python<'p>, mut ctx: openssl::cipher_ctx::CipherCtx, @@ -142,13 +155,19 @@ impl EvpCipherAead { nonce: Option<&[u8]>, tag_len: usize, tag_first: bool, + is_ccm: bool, ) -> CryptographyResult<&'p pyo3::types::PyBytes> { check_length(plaintext)?; - if let Some(nonce) = nonce { - ctx.set_iv_length(nonce.len())?; + if !is_ccm { + if let Some(nonce) = nonce { + ctx.set_iv_length(nonce.len())?; + } + ctx.encrypt_init(None, None, nonce)?; + } + if is_ccm { + ctx.set_data_len(plaintext.len())?; } - ctx.encrypt_init(None, None, nonce)?; Self::process_aad(&mut ctx, aad)?; @@ -164,7 +183,7 @@ impl EvpCipherAead { (ciphertext, tag) = b.split_at_mut(plaintext.len()); } - Self::process_data(&mut ctx, plaintext, ciphertext)?; + Self::process_data(&mut ctx, plaintext, ciphertext, is_ccm)?; ctx.tag(tag).map_err(CryptographyError::from)?; @@ -190,9 +209,11 @@ impl EvpCipherAead { nonce, self.tag_len, self.tag_first, + false, ) } + #[allow(clippy::too_many_arguments)] fn decrypt_with_context<'p>( py: pyo3::Python<'p>, mut ctx: openssl::cipher_ctx::CipherCtx, @@ -201,16 +222,12 @@ impl EvpCipherAead { nonce: Option<&[u8]>, tag_len: usize, tag_first: bool, + is_ccm: bool, ) -> CryptographyResult<&'p pyo3::types::PyBytes> { if ciphertext.len() < tag_len { return Err(CryptographyError::from(exceptions::InvalidTag::new_err(()))); } - if let Some(nonce) = nonce { - ctx.set_iv_length(nonce.len())?; - } - ctx.decrypt_init(None, None, nonce)?; - let tag; let ciphertext_data; if tag_first { @@ -221,7 +238,18 @@ impl EvpCipherAead { } else { (ciphertext_data, tag) = ciphertext.split_at(ciphertext.len() - tag_len); } - ctx.set_tag(tag)?; + + if !is_ccm { + if let Some(nonce) = nonce { + ctx.set_iv_length(nonce.len())?; + } + + ctx.decrypt_init(None, None, nonce)?; + ctx.set_tag(tag)?; + } + if is_ccm { + ctx.set_data_len(ciphertext_data.len())?; + } Self::process_aad(&mut ctx, aad)?; @@ -229,7 +257,7 @@ impl EvpCipherAead { py, ciphertext_data.len(), |b| { - Self::process_data(&mut ctx, ciphertext_data, b) + Self::process_data(&mut ctx, ciphertext_data, b, is_ccm) .map_err(|_| exceptions::InvalidTag::new_err(()))?; Ok(()) @@ -238,38 +266,29 @@ impl EvpCipherAead { } } -#[cfg(not(any( - CRYPTOGRAPHY_IS_LIBRESSL, - CRYPTOGRAPHY_IS_BORINGSSL, - not(CRYPTOGRAPHY_OPENSSL_300_OR_GREATER), - CRYPTOGRAPHY_OPENSSL_320_OR_GREATER -)))] struct LazyEvpCipherAead { cipher: &'static openssl::cipher::CipherRef, key: pyo3::Py, tag_len: usize, tag_first: bool, + is_ccm: bool, } -#[cfg(not(any( - CRYPTOGRAPHY_IS_LIBRESSL, - CRYPTOGRAPHY_IS_BORINGSSL, - not(CRYPTOGRAPHY_OPENSSL_300_OR_GREATER), - CRYPTOGRAPHY_OPENSSL_320_OR_GREATER -)))] impl LazyEvpCipherAead { fn new( cipher: &'static openssl::cipher::CipherRef, key: pyo3::Py, tag_len: usize, tag_first: bool, + is_ccm: bool, ) -> LazyEvpCipherAead { LazyEvpCipherAead { cipher, key, tag_len, tag_first, + is_ccm, } } @@ -283,7 +302,15 @@ impl LazyEvpCipherAead { let key_buf = self.key.as_ref(py).extract::>()?; let mut encryption_ctx = openssl::cipher_ctx::CipherCtx::new()?; - encryption_ctx.encrypt_init(Some(self.cipher), Some(key_buf.as_bytes()), None)?; + if self.is_ccm { + encryption_ctx.encrypt_init(Some(self.cipher), None, None)?; + encryption_ctx.set_iv_length(nonce.as_ref().unwrap().len())?; + encryption_ctx.set_tag_length(self.tag_len)?; + encryption_ctx.encrypt_init(None, Some(key_buf.as_bytes()), nonce)?; + } else { + encryption_ctx.encrypt_init(Some(self.cipher), Some(key_buf.as_bytes()), None)?; + } + EvpCipherAead::encrypt_with_context( py, encryption_ctx, @@ -292,6 +319,7 @@ impl LazyEvpCipherAead { nonce, self.tag_len, self.tag_first, + self.is_ccm, ) } @@ -305,7 +333,22 @@ impl LazyEvpCipherAead { let key_buf = self.key.as_ref(py).extract::>()?; let mut decryption_ctx = openssl::cipher_ctx::CipherCtx::new()?; - decryption_ctx.decrypt_init(Some(self.cipher), Some(key_buf.as_bytes()), None)?; + if self.is_ccm { + decryption_ctx.decrypt_init(Some(self.cipher), None, None)?; + decryption_ctx.set_iv_length(nonce.as_ref().unwrap().len())?; + + if ciphertext.len() < self.tag_len { + return Err(CryptographyError::from(exceptions::InvalidTag::new_err(()))); + } + + let (_, tag) = ciphertext.split_at(ciphertext.len() - self.tag_len); + decryption_ctx.set_tag(tag)?; + + decryption_ctx.decrypt_init(None, Some(key_buf.as_bytes()), nonce)?; + } else { + decryption_ctx.decrypt_init(Some(self.cipher), Some(key_buf.as_bytes()), None)?; + } + EvpCipherAead::decrypt_with_context( py, decryption_ctx, @@ -314,6 +357,7 @@ impl LazyEvpCipherAead { nonce, self.tag_len, self.tag_first, + self.is_ccm, ) } } @@ -478,6 +522,7 @@ impl ChaCha20Poly1305 { key, 16, false, + false, ) }) } @@ -583,7 +628,7 @@ impl AesGcm { }) } else { Ok(AesGcm { - ctx: LazyEvpCipherAead::new(cipher, key, 16, false), + ctx: LazyEvpCipherAead::new(cipher, key, 16, false, false), }) } @@ -642,6 +687,135 @@ impl AesGcm { } } +#[pyo3::prelude::pyclass( + frozen, + module = "cryptography.hazmat.bindings._rust.openssl.aead", + name = "AESCCM" +)] +struct AesCcm { + ctx: LazyEvpCipherAead, +} + +#[pyo3::prelude::pymethods] +impl AesCcm { + #[new] + fn new( + py: pyo3::Python<'_>, + key: pyo3::Py, + tag_length: Option, + ) -> CryptographyResult { + cfg_if::cfg_if! { + if #[cfg(CRYPTOGRAPHY_IS_BORINGSSL)] { + return Err(CryptographyError::from( + exceptions::UnsupportedAlgorithm::new_err(( + "AES-CCM is not supported by this version of OpenSSL", + exceptions::Reasons::UNSUPPORTED_CIPHER, + )), + )); + } else { + let key_buf = key.extract::>(py)?; + let cipher = match key_buf.as_bytes().len() { + 16 => openssl::cipher::Cipher::aes_128_ccm(), + 24 => openssl::cipher::Cipher::aes_192_ccm(), + 32 => openssl::cipher::Cipher::aes_256_ccm(), + _ => { + return Err(CryptographyError::from( + pyo3::exceptions::PyValueError::new_err( + "AESCCM key must be 128, 192, or 256 bits.", + ), + )) + } + }; + let tag_length = tag_length.unwrap_or(16); + if ![4, 6, 8, 10, 12, 14, 16].contains(&tag_length) { + return Err(CryptographyError::from( + pyo3::exceptions::PyValueError::new_err("Invalid tag_length"), + )); + } + + Ok(AesCcm { + ctx: LazyEvpCipherAead::new(cipher, key, tag_length, false, true), + }) + } + } + } + + #[staticmethod] + fn generate_key(py: pyo3::Python<'_>, bit_length: usize) -> CryptographyResult<&pyo3::PyAny> { + if bit_length != 128 && bit_length != 192 && bit_length != 256 { + return Err(CryptographyError::from( + pyo3::exceptions::PyValueError::new_err("bit_length must be 128, 192, or 256"), + )); + } + + Ok(types::OS_URANDOM.get(py)?.call1((bit_length / 8,))?) + } + + fn encrypt<'p>( + &self, + py: pyo3::Python<'p>, + nonce: CffiBuf<'_>, + data: CffiBuf<'_>, + associated_data: Option>, + ) -> CryptographyResult<&'p pyo3::types::PyBytes> { + let nonce_bytes = nonce.as_bytes(); + let data_bytes = data.as_bytes(); + let aad = associated_data.map(Aad::Single); + + if nonce_bytes.len() < 7 || nonce_bytes.len() > 13 { + return Err(CryptographyError::from( + pyo3::exceptions::PyValueError::new_err("Nonce must be between 7 and 13 bytes"), + )); + } + + check_length(data_bytes)?; + // For information about computing this, see + // https://tools.ietf.org/html/rfc3610#section-2.1 + let l_val = 15 - nonce_bytes.len(); + let max_length = 1usize.checked_shl(8 * l_val as u32); + // If `max_length` overflowed, then it's not possible for data to be + // longer than it. + if max_length.map(|v| v < data_bytes.len()).unwrap_or(false) { + return Err(CryptographyError::from( + pyo3::exceptions::PyValueError::new_err("Data too long for nonce"), + )); + } + + self.ctx.encrypt(py, data_bytes, aad, Some(nonce_bytes)) + } + + fn decrypt<'p>( + &self, + py: pyo3::Python<'p>, + nonce: CffiBuf<'_>, + data: CffiBuf<'_>, + associated_data: Option>, + ) -> CryptographyResult<&'p pyo3::types::PyBytes> { + let nonce_bytes = nonce.as_bytes(); + let data_bytes = data.as_bytes(); + let aad = associated_data.map(Aad::Single); + + if nonce_bytes.len() < 7 || nonce_bytes.len() > 13 { + return Err(CryptographyError::from( + pyo3::exceptions::PyValueError::new_err("Nonce must be between 7 and 13 bytes"), + )); + } + // For information about computing this, see + // https://tools.ietf.org/html/rfc3610#section-2.1 + let l_val = 15 - nonce_bytes.len(); + let max_length = 1usize.checked_shl(8 * l_val as u32); + // If `max_length` overflowed, then it's not possible for data to be + // longer than it. + if max_length.map(|v| v < data_bytes.len()).unwrap_or(false) { + return Err(CryptographyError::from( + pyo3::exceptions::PyValueError::new_err("Data too long for nonce"), + )); + } + + self.ctx.decrypt(py, data_bytes, aad, Some(nonce_bytes)) + } +} + #[pyo3::prelude::pyclass( frozen, module = "cryptography.hazmat.bindings._rust.openssl.aead", @@ -957,6 +1131,7 @@ pub(crate) fn create_module(py: pyo3::Python<'_>) -> pyo3::PyResult<&pyo3::prelu m.add_class::()?; m.add_class::()?; + m.add_class::()?; m.add_class::()?; m.add_class::()?; m.add_class::()?; diff --git a/tests/hazmat/primitives/test_aead.py b/tests/hazmat/primitives/test_aead.py index 5228edbbd2d3..a1f99ab815ed 100644 --- a/tests/hazmat/primitives/test_aead.py +++ b/tests/hazmat/primitives/test_aead.py @@ -296,6 +296,9 @@ def test_nonce_too_long(self, backend): with pytest.raises(ValueError): aesccm.encrypt(nonce, pt, None) + with pytest.raises(ValueError): + aesccm.decrypt(nonce, pt, None) + @pytest.mark.parametrize( ("nonce", "data", "associated_data"), [