Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Convert PKCS#12 loading to Rust #10434

Merged
merged 1 commit into from
Feb 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 0 additions & 100 deletions src/cryptography/hazmat/backends/openssl/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,6 @@
PSS,
PKCS1v15,
)
from cryptography.hazmat.primitives.asymmetric.types import (
PrivateKeyTypes,
)
from cryptography.hazmat.primitives.ciphers import (
CipherAlgorithm,
)
Expand All @@ -38,7 +35,6 @@
from cryptography.hazmat.primitives.serialization.pkcs12 import (
PBES,
PKCS12Certificate,
PKCS12KeyAndCertificates,
PKCS12PrivateKeyTypes,
_PKCS12CATypes,
)
Expand Down Expand Up @@ -278,12 +274,6 @@ def _cert2ossl(self, cert: x509.Certificate) -> typing.Any:
x509 = self._ffi.gc(x509, self._lib.X509_free)
return x509

def _ossl2cert(self, x509_ptr: typing.Any) -> x509.Certificate:
bio = self._create_mem_bio_gc()
res = self._lib.i2d_X509_bio(bio, x509_ptr)
self.openssl_assert(res == 1)
return x509.load_der_x509_certificate(self._read_mem_bio(bio))

def _key2ossl(self, key: PKCS12PrivateKeyTypes) -> typing.Any:
data = key.private_bytes(
serialization.Encoding.DER,
Expand Down Expand Up @@ -398,96 +388,6 @@ def _zeroed_null_terminated_buf(self, data):
# Cast to a uint8_t * so we can assign by integer
self._zero_data(self._ffi.cast("uint8_t *", buf), data_len)

def load_key_and_certificates_from_pkcs12(
self, data: bytes, password: bytes | None
) -> tuple[
PrivateKeyTypes | None,
x509.Certificate | None,
list[x509.Certificate],
]:
pkcs12 = self.load_pkcs12(data, password)
return (
pkcs12.key,
pkcs12.cert.certificate if pkcs12.cert else None,
[cert.certificate for cert in pkcs12.additional_certs],
)

def load_pkcs12(
self, data: bytes, password: bytes | None
) -> PKCS12KeyAndCertificates:
if password is not None:
utils._check_byteslike("password", password)

bio = self._bytes_to_bio(data)
p12 = self._lib.d2i_PKCS12_bio(bio.bio, self._ffi.NULL)
if p12 == self._ffi.NULL:
self._consume_errors()
raise ValueError("Could not deserialize PKCS12 data")

p12 = self._ffi.gc(p12, self._lib.PKCS12_free)
evp_pkey_ptr = self._ffi.new("EVP_PKEY **")
x509_ptr = self._ffi.new("X509 **")
sk_x509_ptr = self._ffi.new("Cryptography_STACK_OF_X509 **")
with self._zeroed_null_terminated_buf(password) as password_buf:
res = self._lib.PKCS12_parse(
p12, password_buf, evp_pkey_ptr, x509_ptr, sk_x509_ptr
)
if res == 0:
self._consume_errors()
raise ValueError("Invalid password or PKCS12 data")

cert = None
key = None
additional_certificates = []

if evp_pkey_ptr[0] != self._ffi.NULL:
evp_pkey = self._ffi.gc(evp_pkey_ptr[0], self._lib.EVP_PKEY_free)
# We don't support turning off RSA key validation when loading
# PKCS12 keys
key = rust_openssl.keys.private_key_from_ptr(
int(self._ffi.cast("uintptr_t", evp_pkey)),
unsafe_skip_rsa_key_validation=False,
)

if x509_ptr[0] != self._ffi.NULL:
x509 = self._ffi.gc(x509_ptr[0], self._lib.X509_free)
cert_obj = self._ossl2cert(x509)
name = None
maybe_name = self._lib.X509_alias_get0(x509, self._ffi.NULL)
if maybe_name != self._ffi.NULL:
name = self._ffi.string(maybe_name)
cert = PKCS12Certificate(cert_obj, name)

if sk_x509_ptr[0] != self._ffi.NULL:
sk_x509 = self._ffi.gc(sk_x509_ptr[0], self._lib.sk_X509_free)
num = self._lib.sk_X509_num(sk_x509_ptr[0])

# In OpenSSL < 3.0.0 PKCS12 parsing reverses the order of the
# certificates.
indices: typing.Iterable[int]
if (
rust_openssl.CRYPTOGRAPHY_OPENSSL_300_OR_GREATER
or rust_openssl.CRYPTOGRAPHY_IS_BORINGSSL
):
indices = range(num)
else:
indices = reversed(range(num))

for i in indices:
x509 = self._lib.sk_X509_value(sk_x509, i)
self.openssl_assert(x509 != self._ffi.NULL)
x509 = self._ffi.gc(x509, self._lib.X509_free)
addl_cert = self._ossl2cert(x509)
addl_name = None
maybe_name = self._lib.X509_alias_get0(x509, self._ffi.NULL)
if maybe_name != self._ffi.NULL:
addl_name = self._ffi.string(maybe_name)
additional_certificates.append(
PKCS12Certificate(addl_cert, addl_name)
)

return PKCS12KeyAndCertificates(key, cert, additional_certificates)

def serialize_key_and_certificates_to_pkcs12(
self,
name: bytes | None,
Expand Down
4 changes: 0 additions & 4 deletions src/cryptography/hazmat/bindings/_rust/openssl/keys.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,6 @@ from cryptography.hazmat.primitives.asymmetric.types import (
PublicKeyTypes,
)

def private_key_from_ptr(
ptr: int,
unsafe_skip_rsa_key_validation: bool,
) -> PrivateKeyTypes: ...
def load_der_private_key(
data: bytes,
password: bytes | None,
Expand Down
26 changes: 26 additions & 0 deletions src/cryptography/hazmat/bindings/_rust/pkcs12.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
# This file is dual licensed under the terms of the Apache License, Version
# 2.0, and the BSD License. See the LICENSE file in the root of this repository
# for complete details.

import typing

from cryptography import x509
from cryptography.hazmat.primitives.asymmetric.types import PrivateKeyTypes
from cryptography.hazmat.primitives.serialization.pkcs12 import (
PKCS12KeyAndCertificates,
)

def load_key_and_certificates(
data: bytes,
password: bytes | None,
backend: typing.Any = None,
) -> tuple[
PrivateKeyTypes | None,
x509.Certificate | None,
list[x509.Certificate],
]: ...
def load_pkcs12(
data: bytes,
password: bytes | None,
backend: typing.Any = None,
) -> PKCS12KeyAndCertificates: ...
4 changes: 4 additions & 0 deletions src/cryptography/hazmat/bindings/_rust/pkcs7.pyi
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
# This file is dual licensed under the terms of the Apache License, Version
# 2.0, and the BSD License. See the LICENSE file in the root of this repository
# for complete details.

import typing

from cryptography import x509
Expand Down
25 changes: 3 additions & 22 deletions src/cryptography/hazmat/primitives/serialization/pkcs12.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import typing

from cryptography import x509
from cryptography.hazmat.bindings._rust import pkcs12 as rust_pkcs12
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives._serialization import PBES as PBES
from cryptography.hazmat.primitives.asymmetric import (
Expand Down Expand Up @@ -143,28 +144,8 @@ def __repr__(self) -> str:
return fmt.format(self.key, self.cert, self.additional_certs)


def load_key_and_certificates(
data: bytes,
password: bytes | None,
backend: typing.Any = None,
) -> tuple[
PrivateKeyTypes | None,
x509.Certificate | None,
list[x509.Certificate],
]:
from cryptography.hazmat.backends.openssl.backend import backend as ossl

return ossl.load_key_and_certificates_from_pkcs12(data, password)


def load_pkcs12(
data: bytes,
password: bytes | None,
backend: typing.Any = None,
) -> PKCS12KeyAndCertificates:
from cryptography.hazmat.backends.openssl.backend import backend as ossl

return ossl.load_pkcs12(data, password)
load_key_and_certificates = rust_pkcs12.load_key_and_certificates
load_pkcs12 = rust_pkcs12.load_pkcs12


_PKCS12CATypes = typing.Union[
Expand Down
29 changes: 13 additions & 16 deletions src/rust/src/backend/keys.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
// 2.0, and the BSD License. See the LICENSE file in the root of this repository
// for complete details.

use foreign_types_shared::ForeignTypeRef;
use pyo3::IntoPy;

use crate::backend::utils;
Expand Down Expand Up @@ -61,18 +60,7 @@ fn load_pem_private_key(
private_key_from_pkey(py, &pkey, unsafe_skip_rsa_key_validation)
}

#[pyo3::prelude::pyfunction]
fn private_key_from_ptr(
py: pyo3::Python<'_>,
ptr: usize,
unsafe_skip_rsa_key_validation: bool,
) -> CryptographyResult<pyo3::PyObject> {
// SAFETY: Caller is responsible for passing a valid pointer.
let pkey = unsafe { openssl::pkey::PKeyRef::from_ptr(ptr as *mut _) };
private_key_from_pkey(py, pkey, unsafe_skip_rsa_key_validation)
}

fn private_key_from_pkey(
pub(crate) fn private_key_from_pkey(
py: pyo3::Python<'_>,
pkey: &openssl::pkey::PKeyRef<openssl::pkey::Private>,
unsafe_skip_rsa_key_validation: bool,
Expand Down Expand Up @@ -236,15 +224,13 @@ pub(crate) fn create_module(py: pyo3::Python<'_>) -> pyo3::PyResult<&pyo3::prelu
m.add_function(pyo3::wrap_pyfunction!(load_der_public_key, m)?)?;
m.add_function(pyo3::wrap_pyfunction!(load_pem_public_key, m)?)?;

m.add_function(pyo3::wrap_pyfunction!(private_key_from_ptr, m)?)?;

Ok(m)
}

#[cfg(test)]
mod tests {
#[cfg(not(CRYPTOGRAPHY_IS_BORINGSSL))]
use super::public_key_from_pkey;
use super::{private_key_from_pkey, public_key_from_pkey};

#[test]
#[cfg(not(CRYPTOGRAPHY_IS_BORINGSSL))]
Expand All @@ -260,4 +246,15 @@ mod tests {
assert!(public_key_from_pkey(py, &pkey, openssl::pkey::Id::CMAC).is_err());
});
}

#[test]
#[cfg(not(CRYPTOGRAPHY_IS_BORINGSSL))]
fn test_private_key_from_pkey_unknown_key() {
pyo3::prepare_freethreaded_python();

pyo3::Python::with_gil(|py| {
let pkey = openssl::pkey::PKey::hmac(&[0; 32]).unwrap();
assert!(private_key_from_pkey(py, &pkey, false).is_err());
});
}
}
2 changes: 2 additions & 0 deletions src/rust/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ mod error;
mod exceptions;
pub(crate) mod oid;
mod padding;
mod pkcs12;
mod pkcs7;
pub(crate) mod types;
mod x509;
Expand Down Expand Up @@ -82,6 +83,7 @@ fn _rust(py: pyo3::Python<'_>, m: &pyo3::types::PyModule) -> pyo3::PyResult<()>

m.add_submodule(asn1::create_submodule(py)?)?;
m.add_submodule(pkcs7::create_submodule(py)?)?;
m.add_submodule(pkcs12::create_submodule(py)?)?;
m.add_submodule(exceptions::create_submodule(py)?)?;

let x509_mod = pyo3::prelude::PyModule::new(py, "x509")?;
Expand Down
Loading