From 7cf3e3a9055af83226539509a7452263e14f5c22 Mon Sep 17 00:00:00 2001 From: Alex Gaynor Date: Mon, 19 Feb 2024 17:18:54 -0500 Subject: [PATCH] Convert PKCS#12 loading to Rust --- .../hazmat/backends/openssl/backend.py | 100 ------------ .../hazmat/bindings/_rust/openssl/keys.pyi | 4 - .../hazmat/bindings/_rust/pkcs12.pyi | 26 ++++ .../hazmat/bindings/_rust/pkcs7.pyi | 4 + .../hazmat/primitives/serialization/pkcs12.py | 25 +-- src/rust/src/backend/keys.rs | 16 +- src/rust/src/lib.rs | 2 + src/rust/src/pkcs12.rs | 146 ++++++++++++++++++ src/rust/src/types.rs | 9 ++ tests/hazmat/backends/test_openssl.py | 9 -- 10 files changed, 191 insertions(+), 150 deletions(-) create mode 100644 src/cryptography/hazmat/bindings/_rust/pkcs12.pyi create mode 100644 src/rust/src/pkcs12.rs diff --git a/src/cryptography/hazmat/backends/openssl/backend.py b/src/cryptography/hazmat/backends/openssl/backend.py index 6a4aeca7521f6..56d8206612e69 100644 --- a/src/cryptography/hazmat/backends/openssl/backend.py +++ b/src/cryptography/hazmat/backends/openssl/backend.py @@ -22,9 +22,6 @@ PSS, PKCS1v15, ) -from cryptography.hazmat.primitives.asymmetric.types import ( - PrivateKeyTypes, -) from cryptography.hazmat.primitives.ciphers import ( CipherAlgorithm, ) @@ -38,7 +35,6 @@ from cryptography.hazmat.primitives.serialization.pkcs12 import ( PBES, PKCS12Certificate, - PKCS12KeyAndCertificates, PKCS12PrivateKeyTypes, _PKCS12CATypes, ) @@ -278,12 +274,6 @@ def _cert2ossl(self, cert: x509.Certificate) -> typing.Any: x509 = self._ffi.gc(x509, self._lib.X509_free) return x509 - def _ossl2cert(self, x509_ptr: typing.Any) -> x509.Certificate: - bio = self._create_mem_bio_gc() - res = self._lib.i2d_X509_bio(bio, x509_ptr) - self.openssl_assert(res == 1) - return x509.load_der_x509_certificate(self._read_mem_bio(bio)) - def _key2ossl(self, key: PKCS12PrivateKeyTypes) -> typing.Any: data = key.private_bytes( serialization.Encoding.DER, @@ -398,96 +388,6 @@ def _zeroed_null_terminated_buf(self, data): # Cast to a uint8_t * so we can assign by integer self._zero_data(self._ffi.cast("uint8_t *", buf), data_len) - def load_key_and_certificates_from_pkcs12( - self, data: bytes, password: bytes | None - ) -> tuple[ - PrivateKeyTypes | None, - x509.Certificate | None, - list[x509.Certificate], - ]: - pkcs12 = self.load_pkcs12(data, password) - return ( - pkcs12.key, - pkcs12.cert.certificate if pkcs12.cert else None, - [cert.certificate for cert in pkcs12.additional_certs], - ) - - def load_pkcs12( - self, data: bytes, password: bytes | None - ) -> PKCS12KeyAndCertificates: - if password is not None: - utils._check_byteslike("password", password) - - bio = self._bytes_to_bio(data) - p12 = self._lib.d2i_PKCS12_bio(bio.bio, self._ffi.NULL) - if p12 == self._ffi.NULL: - self._consume_errors() - raise ValueError("Could not deserialize PKCS12 data") - - p12 = self._ffi.gc(p12, self._lib.PKCS12_free) - evp_pkey_ptr = self._ffi.new("EVP_PKEY **") - x509_ptr = self._ffi.new("X509 **") - sk_x509_ptr = self._ffi.new("Cryptography_STACK_OF_X509 **") - with self._zeroed_null_terminated_buf(password) as password_buf: - res = self._lib.PKCS12_parse( - p12, password_buf, evp_pkey_ptr, x509_ptr, sk_x509_ptr - ) - if res == 0: - self._consume_errors() - raise ValueError("Invalid password or PKCS12 data") - - cert = None - key = None - additional_certificates = [] - - if evp_pkey_ptr[0] != self._ffi.NULL: - evp_pkey = self._ffi.gc(evp_pkey_ptr[0], self._lib.EVP_PKEY_free) - # We don't support turning off RSA key validation when loading - # PKCS12 keys - key = rust_openssl.keys.private_key_from_ptr( - int(self._ffi.cast("uintptr_t", evp_pkey)), - unsafe_skip_rsa_key_validation=False, - ) - - if x509_ptr[0] != self._ffi.NULL: - x509 = self._ffi.gc(x509_ptr[0], self._lib.X509_free) - cert_obj = self._ossl2cert(x509) - name = None - maybe_name = self._lib.X509_alias_get0(x509, self._ffi.NULL) - if maybe_name != self._ffi.NULL: - name = self._ffi.string(maybe_name) - cert = PKCS12Certificate(cert_obj, name) - - if sk_x509_ptr[0] != self._ffi.NULL: - sk_x509 = self._ffi.gc(sk_x509_ptr[0], self._lib.sk_X509_free) - num = self._lib.sk_X509_num(sk_x509_ptr[0]) - - # In OpenSSL < 3.0.0 PKCS12 parsing reverses the order of the - # certificates. - indices: typing.Iterable[int] - if ( - rust_openssl.CRYPTOGRAPHY_OPENSSL_300_OR_GREATER - or rust_openssl.CRYPTOGRAPHY_IS_BORINGSSL - ): - indices = range(num) - else: - indices = reversed(range(num)) - - for i in indices: - x509 = self._lib.sk_X509_value(sk_x509, i) - self.openssl_assert(x509 != self._ffi.NULL) - x509 = self._ffi.gc(x509, self._lib.X509_free) - addl_cert = self._ossl2cert(x509) - addl_name = None - maybe_name = self._lib.X509_alias_get0(x509, self._ffi.NULL) - if maybe_name != self._ffi.NULL: - addl_name = self._ffi.string(maybe_name) - additional_certificates.append( - PKCS12Certificate(addl_cert, addl_name) - ) - - return PKCS12KeyAndCertificates(key, cert, additional_certificates) - def serialize_key_and_certificates_to_pkcs12( self, name: bytes | None, diff --git a/src/cryptography/hazmat/bindings/_rust/openssl/keys.pyi b/src/cryptography/hazmat/bindings/_rust/openssl/keys.pyi index e312d51dc58ba..6815b7d9154ba 100644 --- a/src/cryptography/hazmat/bindings/_rust/openssl/keys.pyi +++ b/src/cryptography/hazmat/bindings/_rust/openssl/keys.pyi @@ -9,10 +9,6 @@ from cryptography.hazmat.primitives.asymmetric.types import ( PublicKeyTypes, ) -def private_key_from_ptr( - ptr: int, - unsafe_skip_rsa_key_validation: bool, -) -> PrivateKeyTypes: ... def load_der_private_key( data: bytes, password: bytes | None, diff --git a/src/cryptography/hazmat/bindings/_rust/pkcs12.pyi b/src/cryptography/hazmat/bindings/_rust/pkcs12.pyi new file mode 100644 index 0000000000000..c82892f6debcc --- /dev/null +++ b/src/cryptography/hazmat/bindings/_rust/pkcs12.pyi @@ -0,0 +1,26 @@ +# This file is dual licensed under the terms of the Apache License, Version +# 2.0, and the BSD License. See the LICENSE file in the root of this repository +# for complete details. + +import typing + +from cryptography import x509 +from cryptography.hazmat.primitives.asymmetric.types import PrivateKeyTypes +from cryptography.hazmat.primitives.serialization.pkcs12 import ( + PKCS12KeyAndCertificates, +) + +def load_key_and_certificates( + data: bytes, + password: bytes | None, + backend: typing.Any = None, +) -> tuple[ + PrivateKeyTypes | None, + x509.Certificate | None, + list[x509.Certificate], +]: ... +def load_pkcs12( + data: bytes, + password: bytes | None, + backend: typing.Any = None, +) -> PKCS12KeyAndCertificates: ... diff --git a/src/cryptography/hazmat/bindings/_rust/pkcs7.pyi b/src/cryptography/hazmat/bindings/_rust/pkcs7.pyi index a849782465720..f7f9883eb3114 100644 --- a/src/cryptography/hazmat/bindings/_rust/pkcs7.pyi +++ b/src/cryptography/hazmat/bindings/_rust/pkcs7.pyi @@ -1,3 +1,7 @@ +# This file is dual licensed under the terms of the Apache License, Version +# 2.0, and the BSD License. See the LICENSE file in the root of this repository +# for complete details. + import typing from cryptography import x509 diff --git a/src/cryptography/hazmat/primitives/serialization/pkcs12.py b/src/cryptography/hazmat/primitives/serialization/pkcs12.py index 006a248bd2444..b6d6a198a4f6b 100644 --- a/src/cryptography/hazmat/primitives/serialization/pkcs12.py +++ b/src/cryptography/hazmat/primitives/serialization/pkcs12.py @@ -7,6 +7,7 @@ import typing from cryptography import x509 +from cryptography.hazmat.bindings._rust import pkcs12 as rust_pkcs12 from cryptography.hazmat.primitives import serialization from cryptography.hazmat.primitives._serialization import PBES as PBES from cryptography.hazmat.primitives.asymmetric import ( @@ -143,28 +144,8 @@ def __repr__(self) -> str: return fmt.format(self.key, self.cert, self.additional_certs) -def load_key_and_certificates( - data: bytes, - password: bytes | None, - backend: typing.Any = None, -) -> tuple[ - PrivateKeyTypes | None, - x509.Certificate | None, - list[x509.Certificate], -]: - from cryptography.hazmat.backends.openssl.backend import backend as ossl - - return ossl.load_key_and_certificates_from_pkcs12(data, password) - - -def load_pkcs12( - data: bytes, - password: bytes | None, - backend: typing.Any = None, -) -> PKCS12KeyAndCertificates: - from cryptography.hazmat.backends.openssl.backend import backend as ossl - - return ossl.load_pkcs12(data, password) +load_key_and_certificates = rust_pkcs12.load_key_and_certificates +load_pkcs12 = rust_pkcs12.load_pkcs12 _PKCS12CATypes = typing.Union[ diff --git a/src/rust/src/backend/keys.rs b/src/rust/src/backend/keys.rs index 6af0b923aebc4..edb99065c76db 100644 --- a/src/rust/src/backend/keys.rs +++ b/src/rust/src/backend/keys.rs @@ -2,7 +2,6 @@ // 2.0, and the BSD License. See the LICENSE file in the root of this repository // for complete details. -use foreign_types_shared::ForeignTypeRef; use pyo3::IntoPy; use crate::backend::utils; @@ -61,18 +60,7 @@ fn load_pem_private_key( private_key_from_pkey(py, &pkey, unsafe_skip_rsa_key_validation) } -#[pyo3::prelude::pyfunction] -fn private_key_from_ptr( - py: pyo3::Python<'_>, - ptr: usize, - unsafe_skip_rsa_key_validation: bool, -) -> CryptographyResult { - // SAFETY: Caller is responsible for passing a valid pointer. - let pkey = unsafe { openssl::pkey::PKeyRef::from_ptr(ptr as *mut _) }; - private_key_from_pkey(py, pkey, unsafe_skip_rsa_key_validation) -} - -fn private_key_from_pkey( +pub(crate) fn private_key_from_pkey( py: pyo3::Python<'_>, pkey: &openssl::pkey::PKeyRef, unsafe_skip_rsa_key_validation: bool, @@ -236,8 +224,6 @@ pub(crate) fn create_module(py: pyo3::Python<'_>) -> pyo3::PyResult<&pyo3::prelu m.add_function(pyo3::wrap_pyfunction!(load_der_public_key, m)?)?; m.add_function(pyo3::wrap_pyfunction!(load_pem_public_key, m)?)?; - m.add_function(pyo3::wrap_pyfunction!(private_key_from_ptr, m)?)?; - Ok(m) } diff --git a/src/rust/src/lib.rs b/src/rust/src/lib.rs index a92fdebe42df9..af9eb42a520b8 100644 --- a/src/rust/src/lib.rs +++ b/src/rust/src/lib.rs @@ -18,6 +18,7 @@ mod error; mod exceptions; pub(crate) mod oid; mod padding; +mod pkcs12; mod pkcs7; pub(crate) mod types; mod x509; @@ -82,6 +83,7 @@ fn _rust(py: pyo3::Python<'_>, m: &pyo3::types::PyModule) -> pyo3::PyResult<()> m.add_submodule(asn1::create_submodule(py)?)?; m.add_submodule(pkcs7::create_submodule(py)?)?; + m.add_submodule(pkcs12::create_submodule(py)?)?; m.add_submodule(exceptions::create_submodule(py)?)?; let x509_mod = pyo3::prelude::PyModule::new(py, "x509")?; diff --git a/src/rust/src/pkcs12.rs b/src/rust/src/pkcs12.rs new file mode 100644 index 0000000000000..178bdc5c868d3 --- /dev/null +++ b/src/rust/src/pkcs12.rs @@ -0,0 +1,146 @@ +// This file is dual licensed under the terms of the Apache License, Version +// 2.0, and the BSD License. See the LICENSE file in the root of this repository +// for complete details. + +use crate::backend::keys; +use crate::buf::CffiBuf; +use crate::error::CryptographyResult; +use crate::{types, x509}; +use pyo3::IntoPy; + +#[pyo3::prelude::pyfunction] +fn load_key_and_certificates<'p>( + py: pyo3::Python<'p>, + data: CffiBuf<'_>, + password: Option>, + backend: Option<&pyo3::PyAny>, +) -> CryptographyResult<( + pyo3::PyObject, + Option, + &'p pyo3::types::PyList, +)> { + let _ = backend; + + let p12 = openssl::pkcs12::Pkcs12::from_der(data.as_bytes()).map_err(|_| { + pyo3::exceptions::PyValueError::new_err("Could not deserialize PKCS12 data") + })?; + // XXX: Python version treated None as -> NULL, this treats it as "". + let parsed = p12 + .parse2( + password + .as_ref() + .map_or("", |p| std::str::from_utf8(p.as_bytes()).expect("XXX")), + ) + .map_err(|_| pyo3::exceptions::PyValueError::new_err("Invalid password or PKCS12 data"))?; + + let private_key = if let Some(pkey) = parsed.pkey { + keys::private_key_from_pkey(py, &pkey, false)? + } else { + py.None() + }; + let cert = if let Some(ossl_cert) = parsed.cert { + let cert_der = pyo3::types::PyBytes::new(py, &ossl_cert.to_der()?).into_py(py); + Some(x509::certificate::load_der_x509_certificate( + py, cert_der, None, + )?) + } else { + None + }; + let additional_certs = pyo3::types::PyList::empty(py); + if let Some(ossl_certs) = parsed.ca { + cfg_if::cfg_if! { + if #[cfg(any( + CRYPTOGRAPHY_OPENSSL_300_OR_GREATER, CRYPTOGRAPHY_IS_BORINGSSL + ))] { + let it = ossl_certs.iter(); + } else { + let it = ossl_certs.iter().rev(); + } + }; + + for ossl_cert in it { + let cert_der = pyo3::types::PyBytes::new(py, &ossl_cert.to_der()?).into_py(py); + let cert = x509::certificate::load_der_x509_certificate(py, cert_der, None)?; + additional_certs.append(cert.into_py(py))?; + } + } + + Ok((private_key, cert, additional_certs)) +} + +#[pyo3::prelude::pyfunction] +fn load_pkcs12<'p>( + py: pyo3::Python<'p>, + data: CffiBuf<'_>, + password: Option>, + backend: Option<&pyo3::PyAny>, +) -> CryptographyResult<&'p pyo3::PyAny> { + let _ = backend; + + let p12 = openssl::pkcs12::Pkcs12::from_der(data.as_bytes()).map_err(|_| { + pyo3::exceptions::PyValueError::new_err("Could not deserialize PKCS12 data") + })?; + // XXX: Python version treated None as -> NULL, this treats it as "". + let parsed = p12 + .parse2( + password + .as_ref() + .map_or("", |p| std::str::from_utf8(p.as_bytes()).expect("XXX")), + ) + .map_err(|_| pyo3::exceptions::PyValueError::new_err("Invalid password or PKCS12 data"))?; + + let private_key = if let Some(pkey) = parsed.pkey { + keys::private_key_from_pkey(py, &pkey, false)? + } else { + py.None() + }; + let cert = if let Some(ossl_cert) = parsed.cert { + let cert_der = pyo3::types::PyBytes::new(py, &ossl_cert.to_der()?).into_py(py); + let cert = x509::certificate::load_der_x509_certificate(py, cert_der, None)?; + let alias = ossl_cert.alias(); + + types::PKCS12CERTIFICATE + .get(py)? + .call1((cert, alias))? + .into_py(py) + } else { + py.None() + }; + let additional_certs = pyo3::types::PyList::empty(py); + if let Some(ossl_certs) = parsed.ca { + cfg_if::cfg_if! { + if #[cfg(any( + CRYPTOGRAPHY_OPENSSL_300_OR_GREATER, CRYPTOGRAPHY_IS_BORINGSSL + ))] { + let it = ossl_certs.iter(); + } else { + let it = ossl_certs.iter().rev(); + } + }; + + for ossl_cert in it { + let cert_der = pyo3::types::PyBytes::new(py, &ossl_cert.to_der()?).into_py(py); + let cert = x509::certificate::load_der_x509_certificate(py, cert_der, None)?; + let alias = ossl_cert.alias(); + + let p12_cert = types::PKCS12CERTIFICATE + .get(py)? + .call1((cert, alias))? + .into_py(py); + additional_certs.append(p12_cert)?; + } + } + + Ok(types::PKCS12KEYANDCERTIFICATES + .get(py)? + .call1((private_key, cert, additional_certs))?) +} + +pub(crate) fn create_submodule(py: pyo3::Python<'_>) -> pyo3::PyResult<&pyo3::prelude::PyModule> { + let submod = pyo3::prelude::PyModule::new(py, "pkcs12")?; + + submod.add_function(pyo3::wrap_pyfunction!(load_key_and_certificates, submod)?)?; + submod.add_function(pyo3::wrap_pyfunction!(load_pkcs12, submod)?)?; + + Ok(submod) +} diff --git a/src/rust/src/types.rs b/src/rust/src/types.rs index 98dd9ecbb2690..3afdbb9809149 100644 --- a/src/rust/src/types.rs +++ b/src/rust/src/types.rs @@ -327,6 +327,15 @@ pub static SMIME_ENCODE: LazyPyImport = LazyPyImport::new( &["_smime_encode"], ); +pub static PKCS12CERTIFICATE: LazyPyImport = LazyPyImport::new( + "cryptography.hazmat.primitives.serialization.pkcs12", + &["PKCS12Certificate"], +); +pub static PKCS12KEYANDCERTIFICATES: LazyPyImport = LazyPyImport::new( + "cryptography.hazmat.primitives.serialization.pkcs12", + &["PKCS12KeyAndCertificates"], +); + pub static HASHES_MODULE: LazyPyImport = LazyPyImport::new("cryptography.hazmat.primitives.hashes", &[]); pub static HASH_ALGORITHM: LazyPyImport = diff --git a/tests/hazmat/backends/test_openssl.py b/tests/hazmat/backends/test_openssl.py index 7cf98afe91d0c..901eec59776f7 100644 --- a/tests/hazmat/backends/test_openssl.py +++ b/tests/hazmat/backends/test_openssl.py @@ -201,15 +201,6 @@ def test_unsupported_mgf1_hash_algorithm_md5_decrypt(self, rsa_key_2048): class TestOpenSSLSerializationWithOpenSSL: - def test_unsupported_evp_pkey_type(self): - key = backend._lib.EVP_PKEY_new() - key = backend._ffi.gc(key, backend._lib.EVP_PKEY_free) - with raises_unsupported_algorithm(None): - rust_openssl.keys.private_key_from_ptr( - int(backend._ffi.cast("uintptr_t", key)), - unsafe_skip_rsa_key_validation=False, - ) - def test_very_long_pem_serialization_password(self): password = b"x" * 1025