Skip to content

Commit

Permalink
Refactor key conversion to be in rust
Browse files Browse the repository at this point in the history
removes a lot of unsafe
  • Loading branch information
alex committed Nov 3, 2023
1 parent 9d5682e commit b831ec5
Show file tree
Hide file tree
Showing 22 changed files with 213 additions and 249 deletions.
130 changes: 9 additions & 121 deletions src/cryptography/hazmat/backends/openssl/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,127 +392,19 @@ def _evp_pkey_to_private_key(
Return the appropriate type of PrivateKey given an evp_pkey cdata
pointer.
"""

key_type = self._lib.EVP_PKEY_id(evp_pkey)

if key_type == self._lib.EVP_PKEY_RSA:
return rust_openssl.rsa.private_key_from_ptr(
int(self._ffi.cast("uintptr_t", evp_pkey)),
unsafe_skip_rsa_key_validation=unsafe_skip_rsa_key_validation,
)
elif (
key_type == self._lib.EVP_PKEY_RSA_PSS
and not self._lib.CRYPTOGRAPHY_IS_LIBRESSL
and not self._lib.CRYPTOGRAPHY_IS_BORINGSSL
and not self._lib.CRYPTOGRAPHY_OPENSSL_LESS_THAN_111E
):
# At the moment the way we handle RSA PSS keys is to strip the
# PSS constraints from them and treat them as normal RSA keys
# Unfortunately the RSA * itself tracks this data so we need to
# extract, serialize, and reload it without the constraints.
rsa_cdata = self._lib.EVP_PKEY_get1_RSA(evp_pkey)
self.openssl_assert(rsa_cdata != self._ffi.NULL)
rsa_cdata = self._ffi.gc(rsa_cdata, self._lib.RSA_free)
bio = self._create_mem_bio_gc()
res = self._lib.i2d_RSAPrivateKey_bio(bio, rsa_cdata)
self.openssl_assert(res == 1)
return self.load_der_private_key(
self._read_mem_bio(bio),
password=None,
unsafe_skip_rsa_key_validation=unsafe_skip_rsa_key_validation,
)
elif key_type == self._lib.EVP_PKEY_DSA:
return rust_openssl.dsa.private_key_from_ptr(
int(self._ffi.cast("uintptr_t", evp_pkey))
)
elif key_type == self._lib.EVP_PKEY_EC:
return rust_openssl.ec.private_key_from_ptr(
int(self._ffi.cast("uintptr_t", evp_pkey))
)
elif key_type in self._dh_types:
return rust_openssl.dh.private_key_from_ptr(
int(self._ffi.cast("uintptr_t", evp_pkey))
)
elif key_type == getattr(self._lib, "EVP_PKEY_ED25519", None):
# EVP_PKEY_ED25519 is not present in CRYPTOGRAPHY_IS_LIBRESSL
return rust_openssl.ed25519.private_key_from_ptr(
int(self._ffi.cast("uintptr_t", evp_pkey))
)
elif key_type == getattr(self._lib, "EVP_PKEY_X448", None):
# EVP_PKEY_X448 is not present in CRYPTOGRAPHY_IS_LIBRESSL
return rust_openssl.x448.private_key_from_ptr(
int(self._ffi.cast("uintptr_t", evp_pkey))
)
elif key_type == self._lib.EVP_PKEY_X25519:
return rust_openssl.x25519.private_key_from_ptr(
int(self._ffi.cast("uintptr_t", evp_pkey))
)
elif key_type == getattr(self._lib, "EVP_PKEY_ED448", None):
# EVP_PKEY_ED448 is not present in CRYPTOGRAPHY_IS_LIBRESSL
return rust_openssl.ed448.private_key_from_ptr(
int(self._ffi.cast("uintptr_t", evp_pkey))
)
else:
raise UnsupportedAlgorithm("Unsupported key type.")
return rust_openssl.keys.private_key_from_ptr(
int(self._ffi.cast("uintptr_t", evp_pkey)),
unsafe_skip_rsa_key_validation=unsafe_skip_rsa_key_validation,
)

def _evp_pkey_to_public_key(self, evp_pkey) -> PublicKeyTypes:
"""
Return the appropriate type of PublicKey given an evp_pkey cdata
pointer.
"""

key_type = self._lib.EVP_PKEY_id(evp_pkey)

if key_type == self._lib.EVP_PKEY_RSA:
return rust_openssl.rsa.public_key_from_ptr(
int(self._ffi.cast("uintptr_t", evp_pkey))
)
elif (
key_type == self._lib.EVP_PKEY_RSA_PSS
and not self._lib.CRYPTOGRAPHY_IS_LIBRESSL
and not self._lib.CRYPTOGRAPHY_IS_BORINGSSL
and not self._lib.CRYPTOGRAPHY_OPENSSL_LESS_THAN_111E
):
rsa_cdata = self._lib.EVP_PKEY_get1_RSA(evp_pkey)
self.openssl_assert(rsa_cdata != self._ffi.NULL)
rsa_cdata = self._ffi.gc(rsa_cdata, self._lib.RSA_free)
bio = self._create_mem_bio_gc()
res = self._lib.i2d_RSAPublicKey_bio(bio, rsa_cdata)
self.openssl_assert(res == 1)
return self.load_der_public_key(self._read_mem_bio(bio))
elif key_type == self._lib.EVP_PKEY_DSA:
return rust_openssl.dsa.public_key_from_ptr(
int(self._ffi.cast("uintptr_t", evp_pkey))
)
elif key_type == self._lib.EVP_PKEY_EC:
return rust_openssl.ec.public_key_from_ptr(
int(self._ffi.cast("uintptr_t", evp_pkey))
)
elif key_type in self._dh_types:
return rust_openssl.dh.public_key_from_ptr(
int(self._ffi.cast("uintptr_t", evp_pkey))
)
elif key_type == getattr(self._lib, "EVP_PKEY_ED25519", None):
# EVP_PKEY_ED25519 is not present in CRYPTOGRAPHY_IS_LIBRESSL
return rust_openssl.ed25519.public_key_from_ptr(
int(self._ffi.cast("uintptr_t", evp_pkey))
)
elif key_type == getattr(self._lib, "EVP_PKEY_X448", None):
# EVP_PKEY_X448 is not present in CRYPTOGRAPHY_IS_LIBRESSL
return rust_openssl.x448.public_key_from_ptr(
int(self._ffi.cast("uintptr_t", evp_pkey))
)
elif key_type == self._lib.EVP_PKEY_X25519:
return rust_openssl.x25519.public_key_from_ptr(
int(self._ffi.cast("uintptr_t", evp_pkey))
)
elif key_type == getattr(self._lib, "EVP_PKEY_ED448", None):
# EVP_PKEY_ED448 is not present in CRYPTOGRAPHY_IS_LIBRESSL
return rust_openssl.ed448.public_key_from_ptr(
int(self._ffi.cast("uintptr_t", evp_pkey))
)
else:
raise UnsupportedAlgorithm("Unsupported key type.")
return rust_openssl.keys.public_key_from_ptr(
int(self._ffi.cast("uintptr_t", evp_pkey))
)

def _oaep_hash_supported(self, algorithm: hashes.HashAlgorithm) -> bool:
if self._fips_enabled and isinstance(algorithm, hashes.SHA1):
Expand Down Expand Up @@ -620,9 +512,7 @@ def load_pem_public_key(self, data: bytes) -> PublicKeyTypes:
if rsa_cdata != self._ffi.NULL:
rsa_cdata = self._ffi.gc(rsa_cdata, self._lib.RSA_free)
evp_pkey = self._rsa_cdata_to_evp_pkey(rsa_cdata)
return rust_openssl.rsa.public_key_from_ptr(
int(self._ffi.cast("uintptr_t", evp_pkey))
)
return self._evp_pkey_to_public_key(evp_pkey)
else:
self._handle_key_loading_error()

Expand Down Expand Up @@ -685,9 +575,7 @@ def load_der_public_key(self, data: bytes) -> PublicKeyTypes:
if rsa_cdata != self._ffi.NULL:
rsa_cdata = self._ffi.gc(rsa_cdata, self._lib.RSA_free)
evp_pkey = self._rsa_cdata_to_evp_pkey(rsa_cdata)
return rust_openssl.rsa.public_key_from_ptr(
int(self._ffi.cast("uintptr_t", evp_pkey))
)
return self._evp_pkey_to_public_key(evp_pkey)
else:
self._handle_key_loading_error()

Expand Down
2 changes: 2 additions & 0 deletions src/cryptography/hazmat/bindings/_rust/openssl/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ from cryptography.hazmat.bindings._rust.openssl import (
hashes,
hmac,
kdf,
keys,
poly1305,
rsa,
x448,
Expand All @@ -32,6 +33,7 @@ __all__ = [
"hashes",
"hmac",
"kdf",
"keys",
"ed448",
"ed25519",
"rsa",
Expand Down
2 changes: 0 additions & 2 deletions src/cryptography/hazmat/bindings/_rust/openssl/dh.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,6 @@ class DHPublicKey: ...
class DHParameters: ...

def generate_parameters(generator: int, key_size: int) -> dh.DHParameters: ...
def private_key_from_ptr(ptr: int) -> dh.DHPrivateKey: ...
def public_key_from_ptr(ptr: int) -> dh.DHPublicKey: ...
def from_pem_parameters(data: bytes) -> dh.DHParameters: ...
def from_der_parameters(data: bytes) -> dh.DHParameters: ...
def from_private_numbers(numbers: dh.DHPrivateNumbers) -> dh.DHPrivateKey: ...
Expand Down
2 changes: 0 additions & 2 deletions src/cryptography/hazmat/bindings/_rust/openssl/dsa.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,6 @@ class DSAPublicKey: ...
class DSAParameters: ...

def generate_parameters(key_size: int) -> dsa.DSAParameters: ...
def private_key_from_ptr(ptr: int) -> dsa.DSAPrivateKey: ...
def public_key_from_ptr(ptr: int) -> dsa.DSAPublicKey: ...
def from_private_numbers(
numbers: dsa.DSAPrivateNumbers,
) -> dsa.DSAPrivateKey: ...
Expand Down
2 changes: 0 additions & 2 deletions src/cryptography/hazmat/bindings/_rust/openssl/ec.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,6 @@ class ECPrivateKey: ...
class ECPublicKey: ...

def curve_supported(curve: ec.EllipticCurve) -> bool: ...
def private_key_from_ptr(ptr: int) -> ec.EllipticCurvePrivateKey: ...
def public_key_from_ptr(ptr: int) -> ec.EllipticCurvePublicKey: ...
def generate_private_key(
curve: ec.EllipticCurve,
) -> ec.EllipticCurvePrivateKey: ...
Expand Down
2 changes: 0 additions & 2 deletions src/cryptography/hazmat/bindings/_rust/openssl/ed25519.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,5 @@ class Ed25519PrivateKey: ...
class Ed25519PublicKey: ...

def generate_key() -> ed25519.Ed25519PrivateKey: ...
def private_key_from_ptr(ptr: int) -> ed25519.Ed25519PrivateKey: ...
def public_key_from_ptr(ptr: int) -> ed25519.Ed25519PublicKey: ...
def from_private_bytes(data: bytes) -> ed25519.Ed25519PrivateKey: ...
def from_public_bytes(data: bytes) -> ed25519.Ed25519PublicKey: ...
2 changes: 0 additions & 2 deletions src/cryptography/hazmat/bindings/_rust/openssl/ed448.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,5 @@ class Ed448PrivateKey: ...
class Ed448PublicKey: ...

def generate_key() -> ed448.Ed448PrivateKey: ...
def private_key_from_ptr(ptr: int) -> ed448.Ed448PrivateKey: ...
def public_key_from_ptr(ptr: int) -> ed448.Ed448PublicKey: ...
def from_private_bytes(data: bytes) -> ed448.Ed448PrivateKey: ...
def from_public_bytes(data: bytes) -> ed448.Ed448PublicKey: ...
14 changes: 14 additions & 0 deletions src/cryptography/hazmat/bindings/_rust/openssl/keys.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# This file is dual licensed under the terms of the Apache License, Version
# 2.0, and the BSD License. See the LICENSE file in the root of this repository
# for complete details.

from cryptography.hazmat.primitives.asymmetric.types import (
PrivateKeyTypes,
PublicKeyTypes,
)

def private_key_from_ptr(
ptr: int,
unsafe_skip_rsa_key_validation: bool,
) -> PrivateKeyTypes: ...
def public_key_from_ptr(ptr: int) -> PublicKeyTypes: ...
5 changes: 0 additions & 5 deletions src/cryptography/hazmat/bindings/_rust/openssl/rsa.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,6 @@ def generate_private_key(
public_exponent: int,
key_size: int,
) -> rsa.RSAPrivateKey: ...
def private_key_from_ptr(
ptr: int,
unsafe_skip_rsa_key_validation: bool,
) -> rsa.RSAPrivateKey: ...
def public_key_from_ptr(ptr: int) -> rsa.RSAPublicKey: ...
def from_private_numbers(
numbers: rsa.RSAPrivateNumbers,
unsafe_skip_rsa_key_validation: bool,
Expand Down
2 changes: 0 additions & 2 deletions src/cryptography/hazmat/bindings/_rust/openssl/x25519.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,5 @@ class X25519PrivateKey: ...
class X25519PublicKey: ...

def generate_key() -> x25519.X25519PrivateKey: ...
def private_key_from_ptr(ptr: int) -> x25519.X25519PrivateKey: ...
def public_key_from_ptr(ptr: int) -> x25519.X25519PublicKey: ...
def from_private_bytes(data: bytes) -> x25519.X25519PrivateKey: ...
def from_public_bytes(data: bytes) -> x25519.X25519PublicKey: ...
2 changes: 0 additions & 2 deletions src/cryptography/hazmat/bindings/_rust/openssl/x448.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,5 @@ class X448PrivateKey: ...
class X448PublicKey: ...

def generate_key() -> x448.X448PrivateKey: ...
def private_key_from_ptr(ptr: int) -> x448.X448PrivateKey: ...
def public_key_from_ptr(ptr: int) -> x448.X448PublicKey: ...
def from_private_bytes(data: bytes) -> x448.X448PrivateKey: ...
def from_public_bytes(data: bytes) -> x448.X448PublicKey: ...
21 changes: 8 additions & 13 deletions src/rust/src/backend/dh.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,16 @@ use crate::backend::utils;
use crate::error::{CryptographyError, CryptographyResult};
use crate::{types, x509};
use cryptography_x509::common;
use foreign_types_shared::ForeignTypeRef;

const MIN_MODULUS_SIZE: u32 = 512;

#[pyo3::prelude::pyclass(frozen, module = "cryptography.hazmat.bindings._rust.openssl.dh")]
struct DHPrivateKey {
pub(crate) struct DHPrivateKey {
pkey: openssl::pkey::PKey<openssl::pkey::Private>,
}

#[pyo3::prelude::pyclass(frozen, module = "cryptography.hazmat.bindings._rust.openssl.dh")]
struct DHPublicKey {
pub(crate) struct DHPublicKey {
pkey: openssl::pkey::PKey<openssl::pkey::Public>,
}

Expand Down Expand Up @@ -47,19 +46,17 @@ fn generate_parameters(generator: u32, key_size: u32) -> CryptographyResult<DHPa
Ok(DHParameters { dh })
}

#[pyo3::prelude::pyfunction]
fn private_key_from_ptr(ptr: usize) -> DHPrivateKey {
// SAFETY: Caller is responsible for passing a valid pointer.
let pkey = unsafe { openssl::pkey::PKeyRef::from_ptr(ptr as *mut _) };
pub(crate) fn private_key_from_pkey(
pkey: &openssl::pkey::PKeyRef<openssl::pkey::Private>,
) -> DHPrivateKey {
DHPrivateKey {
pkey: pkey.to_owned(),
}
}

#[pyo3::prelude::pyfunction]
fn public_key_from_ptr(ptr: usize) -> DHPublicKey {
// SAFETY: Caller is responsible for passing a valid pointer.
let pkey = unsafe { openssl::pkey::PKeyRef::from_ptr(ptr as *mut _) };
pub(crate) fn public_key_from_pkey(
pkey: &openssl::pkey::PKeyRef<openssl::pkey::Public>,
) -> DHPublicKey {
DHPublicKey {
pkey: pkey.to_owned(),
}
Expand Down Expand Up @@ -390,8 +387,6 @@ impl DHParameters {
pub(crate) fn create_module(py: pyo3::Python<'_>) -> pyo3::PyResult<&pyo3::prelude::PyModule> {
let m = pyo3::prelude::PyModule::new(py, "dh")?;
m.add_function(pyo3::wrap_pyfunction!(generate_parameters, m)?)?;
m.add_function(pyo3::wrap_pyfunction!(private_key_from_ptr, m)?)?;
m.add_function(pyo3::wrap_pyfunction!(public_key_from_ptr, m)?)?;
m.add_function(pyo3::wrap_pyfunction!(from_der_parameters, m)?)?;
m.add_function(pyo3::wrap_pyfunction!(from_pem_parameters, m)?)?;
#[cfg(not(CRYPTOGRAPHY_IS_BORINGSSL))]
Expand Down
21 changes: 8 additions & 13 deletions src/rust/src/backend/dsa.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,13 @@
use crate::backend::utils;
use crate::error::{CryptographyError, CryptographyResult};
use crate::{exceptions, types};
use foreign_types_shared::ForeignTypeRef;

#[pyo3::prelude::pyclass(
frozen,
module = "cryptography.hazmat.bindings._rust.openssl.dsa",
name = "DSAPrivateKey"
)]
struct DsaPrivateKey {
pub(crate) struct DsaPrivateKey {
pkey: openssl::pkey::PKey<openssl::pkey::Private>,
}

Expand All @@ -21,7 +20,7 @@ struct DsaPrivateKey {
module = "cryptography.hazmat.bindings._rust.openssl.dsa",
name = "DSAPublicKey"
)]
struct DsaPublicKey {
pub(crate) struct DsaPublicKey {
pkey: openssl::pkey::PKey<openssl::pkey::Public>,
}

Expand All @@ -34,19 +33,17 @@ struct DsaParameters {
dsa: openssl::dsa::Dsa<openssl::pkey::Params>,
}

#[pyo3::prelude::pyfunction]
fn private_key_from_ptr(ptr: usize) -> DsaPrivateKey {
// SAFETY: Caller is responsible for passing a valid pointer.
let pkey = unsafe { openssl::pkey::PKeyRef::from_ptr(ptr as *mut _) };
pub(crate) fn private_key_from_pkey(
pkey: &openssl::pkey::PKeyRef<openssl::pkey::Private>,
) -> DsaPrivateKey {
DsaPrivateKey {
pkey: pkey.to_owned(),
}
}

#[pyo3::prelude::pyfunction]
fn public_key_from_ptr(ptr: usize) -> DsaPublicKey {
// SAFETY: Caller is responsible for passing a valid pointer.
let pkey = unsafe { openssl::pkey::PKeyRef::from_ptr(ptr as *mut _) };
pub(crate) fn public_key_from_pkey(
pkey: &openssl::pkey::PKeyRef<openssl::pkey::Public>,
) -> DsaPublicKey {
DsaPublicKey {
pkey: pkey.to_owned(),
}
Expand Down Expand Up @@ -293,8 +290,6 @@ impl DsaParameters {

pub(crate) fn create_module(py: pyo3::Python<'_>) -> pyo3::PyResult<&pyo3::prelude::PyModule> {
let m = pyo3::prelude::PyModule::new(py, "dsa")?;
m.add_function(pyo3::wrap_pyfunction!(private_key_from_ptr, m)?)?;
m.add_function(pyo3::wrap_pyfunction!(public_key_from_ptr, m)?)?;
m.add_function(pyo3::wrap_pyfunction!(generate_parameters, m)?)?;
m.add_function(pyo3::wrap_pyfunction!(from_private_numbers, m)?)?;
m.add_function(pyo3::wrap_pyfunction!(from_public_numbers, m)?)?;
Expand Down
Loading

0 comments on commit b831ec5

Please sign in to comment.