diff --git a/src/cryptography/hazmat/backends/openssl/aead.py b/src/cryptography/hazmat/backends/openssl/aead.py index f0162530b2f9b..4198164ba0092 100644 --- a/src/cryptography/hazmat/backends/openssl/aead.py +++ b/src/cryptography/hazmat/backends/openssl/aead.py @@ -14,36 +14,16 @@ AESCCM, AESGCM, AESOCB3, - ChaCha20Poly1305, ) - _AEADTypes = typing.Union[AESCCM, AESGCM, AESOCB3, ChaCha20Poly1305] - - -def _is_evp_aead_supported_cipher( - backend: Backend, cipher: _AEADTypes -) -> bool: - """ - Checks whether the given cipher is supported through - EVP_AEAD rather than the normal OpenSSL EVP_CIPHER API. - """ - from cryptography.hazmat.primitives.ciphers.aead import ChaCha20Poly1305 - - return backend._lib.Cryptography_HAS_EVP_AEAD and isinstance( - cipher, ChaCha20Poly1305 - ) + _AEADTypes = typing.Union[AESCCM, AESGCM, AESOCB3] def _aead_cipher_supported(backend: Backend, cipher: _AEADTypes) -> bool: - if _is_evp_aead_supported_cipher(backend, cipher): - return True - else: - cipher_name = _evp_cipher_cipher_name(cipher) - if backend._fips_enabled and cipher_name not in backend._fips_aead: - return False - return ( - backend._lib.EVP_get_cipherbyname(cipher_name) != backend._ffi.NULL - ) + cipher_name = _evp_cipher_cipher_name(cipher) + if backend._fips_enabled and cipher_name not in backend._fips_aead: + return False + return backend._lib.EVP_get_cipherbyname(cipher_name) != backend._ffi.NULL def _aead_create_ctx( @@ -51,10 +31,7 @@ def _aead_create_ctx( cipher: _AEADTypes, key: bytes, ): - if _is_evp_aead_supported_cipher(backend, cipher): - return _evp_aead_create_ctx(backend, cipher, key) - else: - return _evp_cipher_create_ctx(backend, cipher, key) + return _evp_cipher_create_ctx(backend, cipher, key) def _encrypt( @@ -66,14 +43,9 @@ def _encrypt( tag_length: int, ctx: typing.Any = None, ) -> bytes: - if _is_evp_aead_supported_cipher(backend, cipher): - return _evp_aead_encrypt( - backend, cipher, nonce, data, associated_data, tag_length, ctx - ) - else: - return _evp_cipher_encrypt( - backend, cipher, nonce, data, associated_data, tag_length, ctx - ) + return _evp_cipher_encrypt( + backend, cipher, nonce, data, associated_data, tag_length, ctx + ) def _decrypt( @@ -85,132 +57,10 @@ def _decrypt( tag_length: int, ctx: typing.Any = None, ) -> bytes: - if _is_evp_aead_supported_cipher(backend, cipher): - return _evp_aead_decrypt( - backend, cipher, nonce, data, associated_data, tag_length, ctx - ) - else: - return _evp_cipher_decrypt( - backend, cipher, nonce, data, associated_data, tag_length, ctx - ) - - -def _evp_aead_create_ctx( - backend: Backend, - cipher: _AEADTypes, - key: bytes, - tag_len: int | None = None, -): - aead_cipher = _evp_aead_get_cipher(backend, cipher) - assert aead_cipher is not None - key_ptr = backend._ffi.from_buffer(key) - tag_len = ( - backend._lib.EVP_AEAD_DEFAULT_TAG_LENGTH - if tag_len is None - else tag_len - ) - ctx = backend._lib.Cryptography_EVP_AEAD_CTX_new( - aead_cipher, key_ptr, len(key), tag_len - ) - backend.openssl_assert(ctx != backend._ffi.NULL) - ctx = backend._ffi.gc(ctx, backend._lib.EVP_AEAD_CTX_free) - return ctx - - -def _evp_aead_get_cipher(backend: Backend, cipher: _AEADTypes): - from cryptography.hazmat.primitives.ciphers.aead import ( - ChaCha20Poly1305, + return _evp_cipher_decrypt( + backend, cipher, nonce, data, associated_data, tag_length, ctx ) - # Currently only ChaCha20-Poly1305 is supported using this API - assert isinstance(cipher, ChaCha20Poly1305) - return backend._lib.EVP_aead_chacha20_poly1305() - - -def _evp_aead_encrypt( - backend: Backend, - cipher: _AEADTypes, - nonce: bytes, - data: bytes, - associated_data: list[bytes], - tag_length: int, - ctx: typing.Any, -) -> bytes: - assert ctx is not None - - aead_cipher = _evp_aead_get_cipher(backend, cipher) - assert aead_cipher is not None - - out_len = backend._ffi.new("size_t *") - # max_out_len should be in_len plus the result of - # EVP_AEAD_max_overhead. - max_out_len = len(data) + backend._lib.EVP_AEAD_max_overhead(aead_cipher) - out_buf = backend._ffi.new("uint8_t[]", max_out_len) - data_ptr = backend._ffi.from_buffer(data) - nonce_ptr = backend._ffi.from_buffer(nonce) - aad = b"".join(associated_data) - aad_ptr = backend._ffi.from_buffer(aad) - - res = backend._lib.EVP_AEAD_CTX_seal( - ctx, - out_buf, - out_len, - max_out_len, - nonce_ptr, - len(nonce), - data_ptr, - len(data), - aad_ptr, - len(aad), - ) - backend.openssl_assert(res == 1) - encrypted_data = backend._ffi.buffer(out_buf, out_len[0])[:] - return encrypted_data - - -def _evp_aead_decrypt( - backend: Backend, - cipher: _AEADTypes, - nonce: bytes, - data: bytes, - associated_data: list[bytes], - tag_length: int, - ctx: typing.Any, -) -> bytes: - if len(data) < tag_length: - raise InvalidTag - - assert ctx is not None - - out_len = backend._ffi.new("size_t *") - # max_out_len should at least in_len - max_out_len = len(data) - out_buf = backend._ffi.new("uint8_t[]", max_out_len) - data_ptr = backend._ffi.from_buffer(data) - nonce_ptr = backend._ffi.from_buffer(nonce) - aad = b"".join(associated_data) - aad_ptr = backend._ffi.from_buffer(aad) - - res = backend._lib.EVP_AEAD_CTX_open( - ctx, - out_buf, - out_len, - max_out_len, - nonce_ptr, - len(nonce), - data_ptr, - len(data), - aad_ptr, - len(aad), - ) - - if res == 0: - backend._consume_errors() - raise InvalidTag - - decrypted_data = backend._ffi.buffer(out_buf, out_len[0])[:] - return decrypted_data - _ENCRYPT = 1 _DECRYPT = 0 @@ -221,12 +71,9 @@ def _evp_cipher_cipher_name(cipher: _AEADTypes) -> bytes: AESCCM, AESGCM, AESOCB3, - ChaCha20Poly1305, ) - if isinstance(cipher, ChaCha20Poly1305): - return b"chacha20-poly1305" - elif isinstance(cipher, AESCCM): + if isinstance(cipher, AESCCM): return f"aes-{len(cipher._key) * 8}-ccm".encode("ascii") elif isinstance(cipher, AESOCB3): return f"aes-{len(cipher._key) * 8}-ocb".encode("ascii") diff --git a/src/cryptography/hazmat/bindings/_rust/openssl/aead.pyi b/src/cryptography/hazmat/bindings/_rust/openssl/aead.pyi index 08a9307127aca..a0a30433b88a5 100644 --- a/src/cryptography/hazmat/bindings/_rust/openssl/aead.pyi +++ b/src/cryptography/hazmat/bindings/_rust/openssl/aead.pyi @@ -2,6 +2,23 @@ # 2.0, and the BSD License. See the LICENSE file in the root of this repository # for complete details. +class ChaCha20Poly1305: + def __init__(self, key: bytes) -> None: ... + @staticmethod + def generate_key() -> bytes: ... + def encrypt( + self, + nonce: bytes, + data: bytes, + associated_data: bytes | None, + ) -> bytes: ... + def decrypt( + self, + nonce: bytes, + data: bytes, + associated_data: bytes | None, + ) -> bytes: ... + class AESSIV: def __init__(self, key: bytes) -> None: ... @staticmethod diff --git a/src/cryptography/hazmat/primitives/ciphers/aead.py b/src/cryptography/hazmat/primitives/ciphers/aead.py index 0feb921dc7bd5..404d24aad65ea 100644 --- a/src/cryptography/hazmat/primitives/ciphers/aead.py +++ b/src/cryptography/hazmat/primitives/ciphers/aead.py @@ -9,7 +9,6 @@ from cryptography import exceptions, utils from cryptography.hazmat.backends.openssl import aead from cryptography.hazmat.backends.openssl.backend import backend -from cryptography.hazmat.bindings._rust import FixedPool from cryptography.hazmat.bindings._rust import openssl as rust_openssl __all__ = [ @@ -20,82 +19,10 @@ "AESSIV", ] +ChaCha20Poly1305 = rust_openssl.aead.ChaCha20Poly1305 AESSIV = rust_openssl.aead.AESSIV -class ChaCha20Poly1305: - _MAX_SIZE = 2**31 - 1 - - def __init__(self, key: bytes): - if not backend.aead_cipher_supported(self): - raise exceptions.UnsupportedAlgorithm( - "ChaCha20Poly1305 is not supported by this version of OpenSSL", - exceptions._Reasons.UNSUPPORTED_CIPHER, - ) - utils._check_byteslike("key", key) - - if len(key) != 32: - raise ValueError("ChaCha20Poly1305 key must be 32 bytes.") - - self._key = key - self._pool = FixedPool(self._create_fn) - - @classmethod - def generate_key(cls) -> bytes: - return os.urandom(32) - - def _create_fn(self): - return aead._aead_create_ctx(backend, self, self._key) - - def encrypt( - self, - nonce: bytes, - data: bytes, - associated_data: bytes | None, - ) -> bytes: - if associated_data is None: - associated_data = b"" - - if len(data) > self._MAX_SIZE or len(associated_data) > self._MAX_SIZE: - # This is OverflowError to match what cffi would raise - raise OverflowError( - "Data or associated data too long. Max 2**31 - 1 bytes" - ) - - self._check_params(nonce, data, associated_data) - with self._pool.acquire() as ctx: - return aead._encrypt( - backend, self, nonce, data, [associated_data], 16, ctx - ) - - def decrypt( - self, - nonce: bytes, - data: bytes, - associated_data: bytes | None, - ) -> bytes: - if associated_data is None: - associated_data = b"" - - self._check_params(nonce, data, associated_data) - with self._pool.acquire() as ctx: - return aead._decrypt( - backend, self, nonce, data, [associated_data], 16, ctx - ) - - def _check_params( - self, - nonce: bytes, - data: bytes, - associated_data: bytes, - ) -> None: - utils._check_byteslike("nonce", nonce) - utils._check_byteslike("data", data) - utils._check_byteslike("associated_data", associated_data) - if len(nonce) != 12: - raise ValueError("Nonce must be 12 bytes") - - class AESCCM: _MAX_SIZE = 2**31 - 1 diff --git a/src/rust/cryptography-openssl/src/aead.rs b/src/rust/cryptography-openssl/src/aead.rs new file mode 100644 index 0000000000000..b592ce204519e --- /dev/null +++ b/src/rust/cryptography-openssl/src/aead.rs @@ -0,0 +1,88 @@ +// This file is dual licensed under the terms of the Apache License, Version +// 2.0, and the BSD License. See the LICENSE file in the root of this repository +// for complete details. + +use crate::{cvt, cvt_p, OpenSSLResult}; +use foreign_types_shared::{ForeignType, ForeignTypeRef}; + +pub enum AeadType { + ChaCha20Poly1305, +} + +foreign_types::foreign_type! { + type CType = ffi::EVP_AEAD_CTX; + fn drop = ffi::EVP_AEAD_CTX_free; + + pub struct AeadCtx; + pub struct AeadCtxRef; +} + +impl AeadCtx { + pub fn new(aead: AeadType, key: &[u8]) -> OpenSSLResult { + let aead = match aead { + AeadType::ChaCha20Poly1305 => unsafe { ffi::EVP_aead_chacha20_poly1305() }, + }; + + unsafe { + let ctx = cvt_p(ffi::EVP_AEAD_CTX_new( + aead, + key.as_ptr(), + key.len(), + ffi::EVP_AEAD_DEFAULT_TAG_LENGTH as usize, + ))?; + Ok(AeadCtx::from_ptr(ctx)) + } + } +} + +impl AeadCtxRef { + pub fn encrypt( + &self, + data: &[u8], + nonce: &[u8], + ad: &[u8], + out: &mut [u8], + ) -> OpenSSLResult<()> { + let mut out_len = out.len(); + unsafe { + cvt(ffi::EVP_AEAD_CTX_seal( + self.as_ptr(), + out.as_mut_ptr(), + &mut out_len, + out.len(), + nonce.as_ptr(), + nonce.len(), + data.as_ptr(), + data.len(), + ad.as_ptr(), + ad.len(), + ))?; + } + Ok(()) + } + + pub fn decrypt( + &self, + data: &[u8], + nonce: &[u8], + ad: &[u8], + out: &mut [u8], + ) -> OpenSSLResult<()> { + let mut out_len = out.len(); + unsafe { + cvt(ffi::EVP_AEAD_CTX_open( + self.as_ptr(), + out.as_mut_ptr(), + &mut out_len, + out.len(), + nonce.as_ptr(), + nonce.len(), + data.as_ptr(), + data.len(), + ad.as_ptr(), + ad.len(), + ))?; + } + Ok(()) + } +} diff --git a/src/rust/cryptography-openssl/src/lib.rs b/src/rust/cryptography-openssl/src/lib.rs index 0a2b48149e0f1..4e300f3dfb0bf 100644 --- a/src/rust/cryptography-openssl/src/lib.rs +++ b/src/rust/cryptography-openssl/src/lib.rs @@ -2,6 +2,8 @@ // 2.0, and the BSD License. See the LICENSE file in the root of this repository // for complete details. +#[cfg(CRYPTOGRAPHY_IS_BORINGSSL)] +pub mod aead; pub mod fips; pub mod hmac; diff --git a/src/rust/src/backend/aead.rs b/src/rust/src/backend/aead.rs index 94a9e949a53ae..630a5001672ff 100644 --- a/src/rust/src/backend/aead.rs +++ b/src/rust/src/backend/aead.rs @@ -20,6 +20,7 @@ fn check_length(data: &[u8]) -> CryptographyResult<()> { } enum Aad<'a> { + Single(CffiBuf<'a>), List(&'a pyo3::types::PyList), } @@ -27,12 +28,19 @@ fn process_aad( ctx: &mut openssl::cipher_ctx::CipherCtx, aad: Option>, ) -> CryptographyResult<()> { - if let Some(Aad::List(ads)) = aad { - for ad in ads.iter() { - let ad = ad.extract::>()?; + match aad { + Some(Aad::Single(ad)) => { check_length(ad.as_bytes())?; ctx.cipher_update(ad.as_bytes(), None)?; } + Some(Aad::List(ads)) => { + for ad in ads.iter() { + let ad = ad.extract::>()?; + check_length(ad.as_bytes())?; + ctx.cipher_update(ad.as_bytes(), None)?; + } + } + None => {} } Ok(()) @@ -55,9 +63,11 @@ fn encrypt_value<'p>( |b| { let ciphertext; let tag; - // TODO: remove once we have a second AEAD implemented here. - assert!(tag_first); - (tag, ciphertext) = b.split_at_mut(tag_len); + if tag_first { + (tag, ciphertext) = b.split_at_mut(tag_len); + } else { + (ciphertext, tag) = b.split_at_mut(plaintext.len()); + }; let n = ctx .cipher_update(plaintext, Some(ciphertext)) @@ -102,6 +112,102 @@ fn decrypt_value<'p>( })?) } +#[pyo3::prelude::pyclass(frozen, module = "cryptography.hazmat.bindings._rust.openssl.aead")] +struct ChaCha20Poly1305 { + key: pyo3::Py, + cipher: &'static openssl::cipher::CipherRef, +} + +#[pyo3::prelude::pymethods] +impl ChaCha20Poly1305 { + #[new] + fn new(py: pyo3::Python<'_>, key: pyo3::Py) -> CryptographyResult { + let key_buf = key.extract::>(py)?; + if key_buf.as_bytes().len() != 32 { + return Err(CryptographyError::from( + pyo3::exceptions::PyValueError::new_err("ChaCha20Poly1305 key must be 32 bytes."), + )); + } + + // TODO: Handle if ChaChaPoly1305 isn't supported by this OpenSSL + // TODO: FixedPool + + Ok(ChaCha20Poly1305 { + key, + cipher: openssl::cipher::Cipher::chacha20_poly1305(), + }) + } + + #[staticmethod] + fn generate_key(py: pyo3::Python<'_>) -> CryptographyResult<&pyo3::PyAny> { + Ok(py + .import(pyo3::intern!(py, "os"))? + .call_method1(pyo3::intern!(py, "urandom"), (32,))?) + } + + fn encrypt<'p>( + &self, + py: pyo3::Python<'p>, + nonce: CffiBuf<'_>, + data: CffiBuf<'_>, + associated_data: Option>, + ) -> CryptographyResult<&'p pyo3::types::PyBytes> { + let key_buf = self.key.extract::>(py)?; + let data_bytes = data.as_bytes(); + let nonce_bytes = nonce.as_bytes(); + let aad = associated_data.map(Aad::Single); + + if nonce_bytes.len() != 12 { + return Err(CryptographyError::from( + pyo3::exceptions::PyValueError::new_err("Nonce must be 12 bytes"), + )); + } + + let mut ctx = openssl::cipher_ctx::CipherCtx::new()?; + ctx.encrypt_init( + Some(self.cipher), + Some(key_buf.as_bytes()), + Some(nonce_bytes), + )?; + + encrypt_value(py, ctx, data_bytes, aad, 16, false) + } + + fn decrypt<'p>( + &self, + py: pyo3::Python<'p>, + nonce: CffiBuf<'_>, + data: CffiBuf<'_>, + associated_data: Option>, + ) -> CryptographyResult<&'p pyo3::types::PyBytes> { + let key_buf = self.key.extract::>(py)?; + let data_bytes = data.as_bytes(); + let nonce_bytes = nonce.as_bytes(); + let aad = associated_data.map(Aad::Single); + + if nonce_bytes.len() != 12 { + return Err(CryptographyError::from( + pyo3::exceptions::PyValueError::new_err("Nonce must be 12 bytes"), + )); + } + + let mut ctx = openssl::cipher_ctx::CipherCtx::new()?; + ctx.decrypt_init( + Some(self.cipher), + Some(key_buf.as_bytes()), + Some(nonce_bytes), + )?; + + if data_bytes.len() < 16 { + return Err(CryptographyError::from(exceptions::InvalidTag::new_err(()))); + } + let (ciphertext, tag) = data_bytes.split_at(data_bytes.len() - 16); + ctx.set_tag(tag)?; + + decrypt_value(py, ctx, ciphertext, aad) + } +} + #[pyo3::prelude::pyclass( frozen, module = "cryptography.hazmat.bindings._rust.openssl.aead", @@ -218,6 +324,7 @@ impl AesSiv { pub(crate) fn create_module(py: pyo3::Python<'_>) -> pyo3::PyResult<&pyo3::prelude::PyModule> { let m = pyo3::prelude::PyModule::new(py, "aead")?; + m.add_class::()?; m.add_class::()?; Ok(m)