Skip to content

Commit

Permalink
Convert PKCS#12 loading to Rust (#10434)
Browse files Browse the repository at this point in the history
  • Loading branch information
alex authored Feb 20, 2024
1 parent fb2d6ec commit 8224447
Show file tree
Hide file tree
Showing 11 changed files with 222 additions and 153 deletions.
100 changes: 0 additions & 100 deletions src/cryptography/hazmat/backends/openssl/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,6 @@
PSS,
PKCS1v15,
)
from cryptography.hazmat.primitives.asymmetric.types import (
PrivateKeyTypes,
)
from cryptography.hazmat.primitives.ciphers import (
CipherAlgorithm,
)
Expand All @@ -38,7 +35,6 @@
from cryptography.hazmat.primitives.serialization.pkcs12 import (
PBES,
PKCS12Certificate,
PKCS12KeyAndCertificates,
PKCS12PrivateKeyTypes,
_PKCS12CATypes,
)
Expand Down Expand Up @@ -278,12 +274,6 @@ def _cert2ossl(self, cert: x509.Certificate) -> typing.Any:
x509 = self._ffi.gc(x509, self._lib.X509_free)
return x509

def _ossl2cert(self, x509_ptr: typing.Any) -> x509.Certificate:
bio = self._create_mem_bio_gc()
res = self._lib.i2d_X509_bio(bio, x509_ptr)
self.openssl_assert(res == 1)
return x509.load_der_x509_certificate(self._read_mem_bio(bio))

def _key2ossl(self, key: PKCS12PrivateKeyTypes) -> typing.Any:
data = key.private_bytes(
serialization.Encoding.DER,
Expand Down Expand Up @@ -398,96 +388,6 @@ def _zeroed_null_terminated_buf(self, data):
# Cast to a uint8_t * so we can assign by integer
self._zero_data(self._ffi.cast("uint8_t *", buf), data_len)

def load_key_and_certificates_from_pkcs12(
self, data: bytes, password: bytes | None
) -> tuple[
PrivateKeyTypes | None,
x509.Certificate | None,
list[x509.Certificate],
]:
pkcs12 = self.load_pkcs12(data, password)
return (
pkcs12.key,
pkcs12.cert.certificate if pkcs12.cert else None,
[cert.certificate for cert in pkcs12.additional_certs],
)

def load_pkcs12(
self, data: bytes, password: bytes | None
) -> PKCS12KeyAndCertificates:
if password is not None:
utils._check_byteslike("password", password)

bio = self._bytes_to_bio(data)
p12 = self._lib.d2i_PKCS12_bio(bio.bio, self._ffi.NULL)
if p12 == self._ffi.NULL:
self._consume_errors()
raise ValueError("Could not deserialize PKCS12 data")

p12 = self._ffi.gc(p12, self._lib.PKCS12_free)
evp_pkey_ptr = self._ffi.new("EVP_PKEY **")
x509_ptr = self._ffi.new("X509 **")
sk_x509_ptr = self._ffi.new("Cryptography_STACK_OF_X509 **")
with self._zeroed_null_terminated_buf(password) as password_buf:
res = self._lib.PKCS12_parse(
p12, password_buf, evp_pkey_ptr, x509_ptr, sk_x509_ptr
)
if res == 0:
self._consume_errors()
raise ValueError("Invalid password or PKCS12 data")

cert = None
key = None
additional_certificates = []

if evp_pkey_ptr[0] != self._ffi.NULL:
evp_pkey = self._ffi.gc(evp_pkey_ptr[0], self._lib.EVP_PKEY_free)
# We don't support turning off RSA key validation when loading
# PKCS12 keys
key = rust_openssl.keys.private_key_from_ptr(
int(self._ffi.cast("uintptr_t", evp_pkey)),
unsafe_skip_rsa_key_validation=False,
)

if x509_ptr[0] != self._ffi.NULL:
x509 = self._ffi.gc(x509_ptr[0], self._lib.X509_free)
cert_obj = self._ossl2cert(x509)
name = None
maybe_name = self._lib.X509_alias_get0(x509, self._ffi.NULL)
if maybe_name != self._ffi.NULL:
name = self._ffi.string(maybe_name)
cert = PKCS12Certificate(cert_obj, name)

if sk_x509_ptr[0] != self._ffi.NULL:
sk_x509 = self._ffi.gc(sk_x509_ptr[0], self._lib.sk_X509_free)
num = self._lib.sk_X509_num(sk_x509_ptr[0])

# In OpenSSL < 3.0.0 PKCS12 parsing reverses the order of the
# certificates.
indices: typing.Iterable[int]
if (
rust_openssl.CRYPTOGRAPHY_OPENSSL_300_OR_GREATER
or rust_openssl.CRYPTOGRAPHY_IS_BORINGSSL
):
indices = range(num)
else:
indices = reversed(range(num))

for i in indices:
x509 = self._lib.sk_X509_value(sk_x509, i)
self.openssl_assert(x509 != self._ffi.NULL)
x509 = self._ffi.gc(x509, self._lib.X509_free)
addl_cert = self._ossl2cert(x509)
addl_name = None
maybe_name = self._lib.X509_alias_get0(x509, self._ffi.NULL)
if maybe_name != self._ffi.NULL:
addl_name = self._ffi.string(maybe_name)
additional_certificates.append(
PKCS12Certificate(addl_cert, addl_name)
)

return PKCS12KeyAndCertificates(key, cert, additional_certificates)

def serialize_key_and_certificates_to_pkcs12(
self,
name: bytes | None,
Expand Down
4 changes: 0 additions & 4 deletions src/cryptography/hazmat/bindings/_rust/openssl/keys.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,6 @@ from cryptography.hazmat.primitives.asymmetric.types import (
PublicKeyTypes,
)

def private_key_from_ptr(
ptr: int,
unsafe_skip_rsa_key_validation: bool,
) -> PrivateKeyTypes: ...
def load_der_private_key(
data: bytes,
password: bytes | None,
Expand Down
26 changes: 26 additions & 0 deletions src/cryptography/hazmat/bindings/_rust/pkcs12.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
# This file is dual licensed under the terms of the Apache License, Version
# 2.0, and the BSD License. See the LICENSE file in the root of this repository
# for complete details.

import typing

from cryptography import x509
from cryptography.hazmat.primitives.asymmetric.types import PrivateKeyTypes
from cryptography.hazmat.primitives.serialization.pkcs12 import (
PKCS12KeyAndCertificates,
)

def load_key_and_certificates(
data: bytes,
password: bytes | None,
backend: typing.Any = None,
) -> tuple[
PrivateKeyTypes | None,
x509.Certificate | None,
list[x509.Certificate],
]: ...
def load_pkcs12(
data: bytes,
password: bytes | None,
backend: typing.Any = None,
) -> PKCS12KeyAndCertificates: ...
4 changes: 4 additions & 0 deletions src/cryptography/hazmat/bindings/_rust/pkcs7.pyi
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
# This file is dual licensed under the terms of the Apache License, Version
# 2.0, and the BSD License. See the LICENSE file in the root of this repository
# for complete details.

import typing

from cryptography import x509
Expand Down
25 changes: 3 additions & 22 deletions src/cryptography/hazmat/primitives/serialization/pkcs12.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import typing

from cryptography import x509
from cryptography.hazmat.bindings._rust import pkcs12 as rust_pkcs12
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives._serialization import PBES as PBES
from cryptography.hazmat.primitives.asymmetric import (
Expand Down Expand Up @@ -143,28 +144,8 @@ def __repr__(self) -> str:
return fmt.format(self.key, self.cert, self.additional_certs)


def load_key_and_certificates(
data: bytes,
password: bytes | None,
backend: typing.Any = None,
) -> tuple[
PrivateKeyTypes | None,
x509.Certificate | None,
list[x509.Certificate],
]:
from cryptography.hazmat.backends.openssl.backend import backend as ossl

return ossl.load_key_and_certificates_from_pkcs12(data, password)


def load_pkcs12(
data: bytes,
password: bytes | None,
backend: typing.Any = None,
) -> PKCS12KeyAndCertificates:
from cryptography.hazmat.backends.openssl.backend import backend as ossl

return ossl.load_pkcs12(data, password)
load_key_and_certificates = rust_pkcs12.load_key_and_certificates
load_pkcs12 = rust_pkcs12.load_pkcs12


_PKCS12CATypes = typing.Union[
Expand Down
29 changes: 13 additions & 16 deletions src/rust/src/backend/keys.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
// 2.0, and the BSD License. See the LICENSE file in the root of this repository
// for complete details.

use foreign_types_shared::ForeignTypeRef;
use pyo3::IntoPy;

use crate::backend::utils;
Expand Down Expand Up @@ -61,18 +60,7 @@ fn load_pem_private_key(
private_key_from_pkey(py, &pkey, unsafe_skip_rsa_key_validation)
}

#[pyo3::prelude::pyfunction]
fn private_key_from_ptr(
py: pyo3::Python<'_>,
ptr: usize,
unsafe_skip_rsa_key_validation: bool,
) -> CryptographyResult<pyo3::PyObject> {
// SAFETY: Caller is responsible for passing a valid pointer.
let pkey = unsafe { openssl::pkey::PKeyRef::from_ptr(ptr as *mut _) };
private_key_from_pkey(py, pkey, unsafe_skip_rsa_key_validation)
}

fn private_key_from_pkey(
pub(crate) fn private_key_from_pkey(
py: pyo3::Python<'_>,
pkey: &openssl::pkey::PKeyRef<openssl::pkey::Private>,
unsafe_skip_rsa_key_validation: bool,
Expand Down Expand Up @@ -236,15 +224,13 @@ pub(crate) fn create_module(py: pyo3::Python<'_>) -> pyo3::PyResult<&pyo3::prelu
m.add_function(pyo3::wrap_pyfunction!(load_der_public_key, m)?)?;
m.add_function(pyo3::wrap_pyfunction!(load_pem_public_key, m)?)?;

m.add_function(pyo3::wrap_pyfunction!(private_key_from_ptr, m)?)?;

Ok(m)
}

#[cfg(test)]
mod tests {
#[cfg(not(CRYPTOGRAPHY_IS_BORINGSSL))]
use super::public_key_from_pkey;
use super::{private_key_from_pkey, public_key_from_pkey};

#[test]
#[cfg(not(CRYPTOGRAPHY_IS_BORINGSSL))]
Expand All @@ -260,4 +246,15 @@ mod tests {
assert!(public_key_from_pkey(py, &pkey, openssl::pkey::Id::CMAC).is_err());
});
}

#[test]
#[cfg(not(CRYPTOGRAPHY_IS_BORINGSSL))]
fn test_private_key_from_pkey_unknown_key() {
pyo3::prepare_freethreaded_python();

pyo3::Python::with_gil(|py| {
let pkey = openssl::pkey::PKey::hmac(&[0; 32]).unwrap();
assert!(private_key_from_pkey(py, &pkey, false).is_err());
});
}
}
2 changes: 2 additions & 0 deletions src/rust/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ mod error;
mod exceptions;
pub(crate) mod oid;
mod padding;
mod pkcs12;
mod pkcs7;
pub(crate) mod types;
mod x509;
Expand Down Expand Up @@ -82,6 +83,7 @@ fn _rust(py: pyo3::Python<'_>, m: &pyo3::types::PyModule) -> pyo3::PyResult<()>

m.add_submodule(asn1::create_submodule(py)?)?;
m.add_submodule(pkcs7::create_submodule(py)?)?;
m.add_submodule(pkcs12::create_submodule(py)?)?;
m.add_submodule(exceptions::create_submodule(py)?)?;

let x509_mod = pyo3::prelude::PyModule::new(py, "x509")?;
Expand Down
Loading

0 comments on commit 8224447

Please sign in to comment.