From 51ea13ed3a6101d732ed805787a5b991f6b39276 Mon Sep 17 00:00:00 2001 From: Alex Gaynor Date: Wed, 13 Sep 2023 20:02:23 -0400 Subject: [PATCH] Migrate OCB3 to Rust (#9569) --- .../hazmat/backends/openssl/aead.py | 6 +- .../hazmat/bindings/_rust/openssl/aead.pyi | 17 ++ .../hazmat/primitives/ciphers/aead.py | 70 +----- src/rust/Cargo.lock | 1 + src/rust/Cargo.toml | 1 + src/rust/src/backend/aead.rs | 214 ++++++++++++++++-- 6 files changed, 212 insertions(+), 97 deletions(-) diff --git a/src/cryptography/hazmat/backends/openssl/aead.py b/src/cryptography/hazmat/backends/openssl/aead.py index f0162530b2f9..95c5133c1dc9 100644 --- a/src/cryptography/hazmat/backends/openssl/aead.py +++ b/src/cryptography/hazmat/backends/openssl/aead.py @@ -13,11 +13,10 @@ from cryptography.hazmat.primitives.ciphers.aead import ( AESCCM, AESGCM, - AESOCB3, ChaCha20Poly1305, ) - _AEADTypes = typing.Union[AESCCM, AESGCM, AESOCB3, ChaCha20Poly1305] + _AEADTypes = typing.Union[AESCCM, AESGCM, ChaCha20Poly1305] def _is_evp_aead_supported_cipher( @@ -220,7 +219,6 @@ def _evp_cipher_cipher_name(cipher: _AEADTypes) -> bytes: from cryptography.hazmat.primitives.ciphers.aead import ( AESCCM, AESGCM, - AESOCB3, ChaCha20Poly1305, ) @@ -228,8 +226,6 @@ def _evp_cipher_cipher_name(cipher: _AEADTypes) -> bytes: return b"chacha20-poly1305" elif isinstance(cipher, AESCCM): return f"aes-{len(cipher._key) * 8}-ccm".encode("ascii") - elif isinstance(cipher, AESOCB3): - return f"aes-{len(cipher._key) * 8}-ocb".encode("ascii") else: assert isinstance(cipher, AESGCM) return f"aes-{len(cipher._key) * 8}-gcm".encode("ascii") diff --git a/src/cryptography/hazmat/bindings/_rust/openssl/aead.pyi b/src/cryptography/hazmat/bindings/_rust/openssl/aead.pyi index 08a9307127ac..981d69d13219 100644 --- a/src/cryptography/hazmat/bindings/_rust/openssl/aead.pyi +++ b/src/cryptography/hazmat/bindings/_rust/openssl/aead.pyi @@ -16,3 +16,20 @@ class AESSIV: data: bytes, associated_data: list[bytes] | None, ) -> bytes: ... + +class AESOCB3: + def __init__(self, key: bytes) -> None: ... + @staticmethod + def generate_key(key_size: int) -> bytes: ... + def encrypt( + self, + nonce: bytes, + data: bytes, + associated_data: bytes | None, + ) -> bytes: ... + def decrypt( + self, + nonce: bytes, + data: bytes, + associated_data: bytes | None, + ) -> bytes: ... diff --git a/src/cryptography/hazmat/primitives/ciphers/aead.py b/src/cryptography/hazmat/primitives/ciphers/aead.py index 0feb921dc7bd..291513d75f04 100644 --- a/src/cryptography/hazmat/primitives/ciphers/aead.py +++ b/src/cryptography/hazmat/primitives/ciphers/aead.py @@ -21,6 +21,7 @@ ] AESSIV = rust_openssl.aead.AESSIV +AESOCB3 = rust_openssl.aead.AESOCB3 class ChaCha20Poly1305: @@ -242,72 +243,3 @@ def _check_params( utils._check_byteslike("associated_data", associated_data) if len(nonce) < 8 or len(nonce) > 128: raise ValueError("Nonce must be between 8 and 128 bytes") - - -class AESOCB3: - _MAX_SIZE = 2**31 - 1 - - def __init__(self, key: bytes): - utils._check_byteslike("key", key) - if len(key) not in (16, 24, 32): - raise ValueError("AESOCB3 key must be 128, 192, or 256 bits.") - - self._key = key - - if not backend.aead_cipher_supported(self): - raise exceptions.UnsupportedAlgorithm( - "OCB3 is not supported by this version of OpenSSL", - exceptions._Reasons.UNSUPPORTED_CIPHER, - ) - - @classmethod - def generate_key(cls, bit_length: int) -> bytes: - if not isinstance(bit_length, int): - raise TypeError("bit_length must be an integer") - - if bit_length not in (128, 192, 256): - raise ValueError("bit_length must be 128, 192, or 256") - - return os.urandom(bit_length // 8) - - def encrypt( - self, - nonce: bytes, - data: bytes, - associated_data: bytes | None, - ) -> bytes: - if associated_data is None: - associated_data = b"" - - if len(data) > self._MAX_SIZE or len(associated_data) > self._MAX_SIZE: - # This is OverflowError to match what cffi would raise - raise OverflowError( - "Data or associated data too long. Max 2**31 - 1 bytes" - ) - - self._check_params(nonce, data, associated_data) - return aead._encrypt(backend, self, nonce, data, [associated_data], 16) - - def decrypt( - self, - nonce: bytes, - data: bytes, - associated_data: bytes | None, - ) -> bytes: - if associated_data is None: - associated_data = b"" - - self._check_params(nonce, data, associated_data) - return aead._decrypt(backend, self, nonce, data, [associated_data], 16) - - def _check_params( - self, - nonce: bytes, - data: bytes, - associated_data: bytes, - ) -> None: - utils._check_byteslike("nonce", nonce) - utils._check_byteslike("data", data) - utils._check_byteslike("associated_data", associated_data) - if len(nonce) < 12 or len(nonce) > 15: - raise ValueError("Nonce must be between 12 and 15 bytes") diff --git a/src/rust/Cargo.lock b/src/rust/Cargo.lock index 18d790e9c8f1..6134c6a02b72 100644 --- a/src/rust/Cargo.lock +++ b/src/rust/Cargo.lock @@ -86,6 +86,7 @@ version = "0.1.0" dependencies = [ "asn1", "cc", + "cfg-if", "cryptography-cffi", "cryptography-openssl", "cryptography-x509", diff --git a/src/rust/Cargo.toml b/src/rust/Cargo.toml index 6e408e9b4355..9d41d805fc16 100644 --- a/src/rust/Cargo.toml +++ b/src/rust/Cargo.toml @@ -9,6 +9,7 @@ rust-version = "1.63.0" [dependencies] once_cell = "1" +cfg-if = "1" pyo3 = { version = "0.19", features = ["abi3-py37"] } asn1 = { version = "0.15.5", default-features = false } cryptography-cffi = { path = "cryptography-cffi" } diff --git a/src/rust/src/backend/aead.rs b/src/rust/src/backend/aead.rs index de330448b9e9..0965b71a7005 100644 --- a/src/rust/src/backend/aead.rs +++ b/src/rust/src/backend/aead.rs @@ -20,6 +20,7 @@ fn check_length(data: &[u8]) -> CryptographyResult<()> { } enum Aad<'a> { + Single(CffiBuf<'a>), List(&'a pyo3::types::PyList), } @@ -55,12 +56,19 @@ impl EvpCipherAead { ctx: &mut openssl::cipher_ctx::CipherCtx, aad: Option>, ) -> CryptographyResult<()> { - if let Some(Aad::List(ads)) = aad { - for ad in ads.iter() { - let ad = ad.extract::>()?; + match aad { + Some(Aad::Single(ad)) => { check_length(ad.as_bytes())?; ctx.cipher_update(ad.as_bytes(), None)?; } + Some(Aad::List(ads)) => { + for ad in ads.iter() { + let ad = ad.extract::>()?; + check_length(ad.as_bytes())?; + ctx.cipher_update(ad.as_bytes(), None)?; + } + } + None => {} } Ok(()) @@ -72,12 +80,46 @@ impl EvpCipherAead { data: &[u8], out: &mut [u8], ) -> CryptographyResult<()> { - let n = ctx.cipher_update(data, Some(out))?; - assert_eq!(n, data.len()); - - let mut final_block = [0]; - let n = ctx.cipher_final(&mut final_block)?; - assert_eq!(n, 0); + let bs = ctx.block_size(); + + // For AEADs that operate as if they are streaming there's an easy + // path. For AEADs that are more like block ciphers (notably, OCB), + // this is a bit more complicated. + if bs == 1 { + let n = ctx.cipher_update(data, Some(out))?; + assert_eq!(n, data.len()); + + let mut final_block = [0]; + let n = ctx.cipher_final(&mut final_block)?; + assert_eq!(n, 0); + } else { + // Our algorithm here is: split the data into the full chunks, and + // the remaining partial chunk. Feed the full chunks into OpenSSL + // and let it write the results to `out`. Then feed the trailer + // in, allowing it to write the results to a buffer on the + // stack -- this never writes anything. Finally, finalize the AEAD + // and let it write the results to the stack buffer, then copy + // from the stack buffer over to `out`. The indirection via the + // stack buffer is required because OpenSSL uses it as scratch + // space, and `out` wouldn't be long enough. + let (initial, trailer) = data.split_at((data.len() / bs) * bs); + + let n = + // SAFETY: `initial.len()` is a precise multiple of the block + // size, which means the space required in the output is + // exactly `initial.len()`. + unsafe { ctx.cipher_update_unchecked(initial, Some(&mut out[..initial.len()]))? }; + assert_eq!(n, initial.len()); + + assert!(bs <= 16); + let mut buf = [0; 32]; + let n = ctx.cipher_update(trailer, Some(&mut buf))?; + assert_eq!(n, 0); + + let n = ctx.cipher_final(&mut buf)?; + assert_eq!(n, trailer.len()); + out[initial.len()..].copy_from_slice(&buf[..n]); + } Ok(()) } @@ -93,6 +135,9 @@ impl EvpCipherAead { let mut ctx = openssl::cipher_ctx::CipherCtx::new()?; ctx.copy(&self.base_encryption_ctx)?; + if let Some(nonce) = nonce { + ctx.set_iv_length(nonce.len())?; + } ctx.encrypt_init(None, None, nonce)?; self.process_aad(&mut ctx, aad)?; @@ -103,9 +148,11 @@ impl EvpCipherAead { |b| { let ciphertext; let tag; - // TODO: remove once we have a second AEAD implemented here. - assert!(self.tag_first); - (tag, ciphertext) = b.split_at_mut(self.tag_len); + if self.tag_first { + (tag, ciphertext) = b.split_at_mut(self.tag_len); + } else { + (ciphertext, tag) = b.split_at_mut(plaintext.len()); + } self.process_data(&mut ctx, plaintext, ciphertext)?; @@ -129,24 +176,35 @@ impl EvpCipherAead { let mut ctx = openssl::cipher_ctx::CipherCtx::new()?; ctx.copy(&self.base_decryption_ctx)?; + if let Some(nonce) = nonce { + ctx.set_iv_length(nonce.len())?; + } ctx.decrypt_init(None, None, nonce)?; - assert!(self.tag_first); - // RFC 5297 defines the output as IV || C, where the tag we generate - // is the "IV" and C is the ciphertext. This is the opposite of our - // other AEADs, which are Ciphertext || Tag. - let (tag, ciphertext) = ciphertext.split_at(self.tag_len); + let tag; + let ciphertext_data; + if self.tag_first { + // RFC 5297 defines the output as IV || C, where the tag we generate + // is the "IV" and C is the ciphertext. This is the opposite of our + // other AEADs, which are Ciphertext || Tag. + (tag, ciphertext_data) = ciphertext.split_at(self.tag_len); + } else { + (ciphertext_data, tag) = ciphertext.split_at(ciphertext.len() - self.tag_len); + } ctx.set_tag(tag)?; self.process_aad(&mut ctx, aad)?; - Ok(pyo3::types::PyBytes::new_with(py, ciphertext.len(), |b| { - // AES SIV can error here if the data is invalid on decrypt - self.process_data(&mut ctx, ciphertext, b) - .map_err(|_| exceptions::InvalidTag::new_err(()))?; + Ok(pyo3::types::PyBytes::new_with( + py, + ciphertext_data.len(), + |b| { + self.process_data(&mut ctx, ciphertext_data, b) + .map_err(|_| exceptions::InvalidTag::new_err(()))?; - Ok(()) - })?) + Ok(()) + }, + )?) } } @@ -215,6 +273,7 @@ impl AesSiv { Ok(types::OS_URANDOM.get(py)?.call1((bit_length / 8,))?) } + #[pyo3(signature = (data, associated_data))] fn encrypt<'p>( &self, py: pyo3::Python<'p>, @@ -232,6 +291,7 @@ impl AesSiv { self.ctx.encrypt(py, data_bytes, aad, None) } + #[pyo3(signature = (data, associated_data))] fn decrypt<'p>( &self, py: pyo3::Python<'p>, @@ -243,10 +303,118 @@ impl AesSiv { } } +#[pyo3::prelude::pyclass( + frozen, + module = "cryptography.hazmat.bindings._rust.openssl.aead", + name = "AESOCB3" +)] +struct AesOcb3 { + ctx: EvpCipherAead, +} + +#[pyo3::prelude::pymethods] +impl AesOcb3 { + #[new] + fn new(py: pyo3::Python<'_>, key: pyo3::Py) -> CryptographyResult { + let key_buf = key.extract::>(py)?; + + cfg_if::cfg_if! { + if #[cfg(any(CRYPTOGRAPHY_IS_LIBRESSL, CRYPTOGRAPHY_IS_BORINGSSL))] { + return Err(CryptographyError::from( + exceptions::UnsupportedAlgorithm::new_err(( + "AES-OCB3 is not supported by this version of OpenSSL", + exceptions::Reasons::UNSUPPORTED_CIPHER, + )), + )); + } else { + if cryptography_openssl::fips::is_enabled() { + return Err(CryptographyError::from( + exceptions::UnsupportedAlgorithm::new_err(( + "AES-OCB3 is not supported by this version of OpenSSL", + exceptions::Reasons::UNSUPPORTED_CIPHER, + )), + )); + } + + let cipher = match key_buf.as_bytes().len() { + 16 => openssl::cipher::Cipher::aes_128_ocb(), + 24 => openssl::cipher::Cipher::aes_192_ocb(), + 32 => openssl::cipher::Cipher::aes_256_ocb(), + _ => { + return Err(CryptographyError::from( + pyo3::exceptions::PyValueError::new_err( + "AESOCB3 key must be 128, 192, or 256 bits.", + ), + )) + } + }; + + Ok(AesOcb3 { + ctx: EvpCipherAead::new(cipher, key_buf.as_bytes(), 16, false)?, + }) + } + } + } + + #[staticmethod] + fn generate_key(py: pyo3::Python<'_>, bit_length: usize) -> CryptographyResult<&pyo3::PyAny> { + if bit_length != 128 && bit_length != 192 && bit_length != 256 { + return Err(CryptographyError::from( + pyo3::exceptions::PyValueError::new_err("bit_length must be 128, 192, or 256"), + )); + } + + Ok(types::OS_URANDOM.get(py)?.call1((bit_length / 8,))?) + } + + #[pyo3(signature = (nonce, data, associated_data))] + fn encrypt<'p>( + &self, + py: pyo3::Python<'p>, + nonce: CffiBuf<'_>, + data: CffiBuf<'_>, + associated_data: Option>, + ) -> CryptographyResult<&'p pyo3::types::PyBytes> { + let nonce_bytes = nonce.as_bytes(); + let aad = associated_data.map(Aad::Single); + + if nonce_bytes.len() < 12 || nonce_bytes.len() > 15 { + return Err(CryptographyError::from( + pyo3::exceptions::PyValueError::new_err("Nonce must be between 12 and 15 bytes"), + )); + } + + self.ctx + .encrypt(py, data.as_bytes(), aad, Some(nonce_bytes)) + } + + #[pyo3(signature = (nonce, data, associated_data))] + fn decrypt<'p>( + &self, + py: pyo3::Python<'p>, + nonce: CffiBuf<'_>, + data: CffiBuf<'_>, + associated_data: Option>, + ) -> CryptographyResult<&'p pyo3::types::PyBytes> { + let nonce_bytes = nonce.as_bytes(); + let aad = associated_data.map(Aad::Single); + + if nonce_bytes.len() < 12 || nonce_bytes.len() > 15 { + return Err(CryptographyError::from( + pyo3::exceptions::PyValueError::new_err("Nonce must be between 12 and 15 bytes"), + )); + } + + self.ctx + .decrypt(py, data.as_bytes(), aad, Some(nonce_bytes)) + } +} + pub(crate) fn create_module(py: pyo3::Python<'_>) -> pyo3::PyResult<&pyo3::prelude::PyModule> { let m = pyo3::prelude::PyModule::new(py, "aead")?; m.add_class::()?; + m.add_class::()?; Ok(m) }