diff --git a/.gitmodules b/.gitmodules index 36634ee2..288ea77a 100644 --- a/.gitmodules +++ b/.gitmodules @@ -4,6 +4,3 @@ [submodule "deps/td-shim"] path = deps/td-shim url = https://github.com/confidential-containers/td-shim -[submodule "deps/rustls"] - path = deps/rustls - url = https://github.com/rustls/rustls.git diff --git a/Cargo.lock b/Cargo.lock index 190a2ee8..9f5bb9d4 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -283,6 +283,7 @@ dependencies = [ "rustls", "rustls-pemfile", "rustls-pki-types", + "sys_time", "zeroize", ] @@ -813,10 +814,12 @@ dependencies = [ [[package]] name = "rustls" -version = "0.22.4" +version = "0.23.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c58f8c84392efc0a126acce10fa59ff7b3d2ac06ab451a33f2741989b806b044" dependencies = [ + "once_cell", "ring", - "rust_std_stub", "rustls-pki-types", "rustls-webpki", "subtle", @@ -835,15 +838,15 @@ dependencies = [ [[package]] name = "rustls-pki-types" -version = "1.0.0" +version = "1.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eb0a1f9b9efec70d32e6d6aa3e58ebd88c3754ec98dfe9145c63cf54cc829b83" +checksum = "fc0a2ce646f8655401bb81e7927b812614bd5d91dbc968696be50603510fcaf0" [[package]] name = "rustls-webpki" -version = "0.102.0" +version = "0.102.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "de2635c8bc2b88d367767c5de8ea1d8db9af3f6219eba28442242d9ab81d1b89" +checksum = "64ca1bc8749bd4cf37b5ce386cc146580777b4e8572c7b97baf22c83f444bee9" dependencies = [ "ring", "rustls-pki-types", @@ -1540,6 +1543,6 @@ dependencies = [ [[package]] name = "zeroize" -version = "1.6.0" +version = "1.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2a0956f1ba7c7909bfb66c2e9e4124ab6f6482560f6628b5aaeba39207c9aad9" +checksum = "ced3678a2879b30306d323f4542626697a464a97c0a07c9aebf7ebca65cd4dde" diff --git a/Cargo.toml b/Cargo.toml index 377fb10b..b91e4c77 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -31,6 +31,7 @@ resolver = "2" panic = "abort" # disable stack unwinding on panic opt-level = "z" lto = true +strip = true # the profile used for `cargo build --release` [profile.release] @@ -40,4 +41,3 @@ lto = true [patch.crates-io] ring = { path = "deps/td-shim/library/ring" } -rustls = { path = "deps/rustls/rustls" } diff --git a/deps/patches/rustls.diff b/deps/patches/rustls.diff deleted file mode 100644 index 6766f60c..00000000 --- a/deps/patches/rustls.diff +++ /dev/null @@ -1,268 +0,0 @@ -diff --git a/rustls/Cargo.toml b/rustls/Cargo.toml -index 4ec52f86..1962fb28 100644 ---- a/rustls/Cargo.toml -+++ b/rustls/Cargo.toml -@@ -18,14 +18,18 @@ rustversion = { version = "1.0.6", optional = true } - [dependencies] - aws-lc-rs = { version = "1.5", optional = true } - log = { version = "0.4.4", optional = true } --ring = { version = "0.17", optional = true } -+ring = { version = "0.17", features = ["alloc", "less-safe-getrandom-custom-or-rdrand"], default-features = false, optional = true } - subtle = { version = "2.5.0", default-features = false } --webpki = { package = "rustls-webpki", version = "0.102.1", features = ["std"], default-features = false } --pki-types = { package = "rustls-pki-types", version = "1", features = ["std"] } -+webpki = { package = "rustls-webpki", version = "0.102", features = ["alloc", "ring"], default-features = false } -+pki-types = { package = "rustls-pki-types", version = "1" } -+rust_std_stub = { path = "../../../src/std-support/rust-std-stub", optional = true } - zeroize = "1.6.0" - - [features] - default = ["logging", "ring", "tls12"] -+alloc = ["ring/alloc", "webpki/alloc"] -+std = ["alloc", "ring/std", "webpki/std"] -+no_std = ["rust_std_stub", "alloc"] - logging = ["log"] - aws_lc_rs = ["dep:aws-lc-rs", "webpki/aws_lc_rs"] - ring = ["dep:ring", "webpki/ring"] -diff --git a/rustls/src/client/hs.rs b/rustls/src/client/hs.rs -index 26ce6383..3ce9d5dc 100644 ---- a/rustls/src/client/hs.rs -+++ b/rustls/src/client/hs.rs -@@ -68,7 +68,15 @@ fn find_session( - None - }) - .and_then(|resuming| { -+ #[cfg(feature = "std")] - let retrieved = persist::Retrieved::new(resuming, UnixTime::now()); -+ #[cfg(not(feature = "std"))] -+ let retrieved = persist::Retrieved::new( -+ resuming, -+ UnixTime::since_unix_epoch(core::time::Duration::from_secs( -+ std::time::now().as_secs(), -+ )), -+ ); - match retrieved.has_expired() { - false => Some(retrieved), - true => None, -diff --git a/rustls/src/client/tls13.rs b/rustls/src/client/tls13.rs -index fdd53b95..e8926b47 100644 ---- a/rustls/src/client/tls13.rs -+++ b/rustls/src/client/tls13.rs -@@ -673,7 +673,12 @@ impl State for ExpectCertificateVerify { - intermediates, - &self.server_name, - &self.server_cert.ocsp_response, -+ #[cfg(feature = "std")] - UnixTime::now(), -+ #[cfg(not(feature = "std"))] -+ UnixTime::since_unix_epoch(core::time::Duration::from_secs( -+ std::time::now().as_secs(), -+ )), - ) - .map_err(|err| { - cx.common -@@ -956,7 +961,10 @@ impl ExpectTraffic { - .peer_certificates - .clone() - .unwrap_or_default(), -+ #[cfg(feature = "std")] - UnixTime::now(), -+ #[cfg(not(feature = "std"))] -+ UnixTime::since_unix_epoch(core::time::Duration::from_secs(std::time::now().as_secs())), - nst.lifetime, - nst.age_add, - nst.get_max_early_data_size() -diff --git a/rustls/src/error.rs b/rustls/src/error.rs -index 7d692b7f..4c7c3eb2 100644 ---- a/rustls/src/error.rs -+++ b/rustls/src/error.rs -@@ -550,7 +550,7 @@ impl From for Error { - /// - /// Enums holding this type will never compare equal to each other. - #[derive(Debug, Clone)] --pub struct OtherError(pub Arc); -+pub struct OtherError(pub Arc); - - impl PartialEq for OtherError { - fn eq(&self, _other: &Self) -> bool { -@@ -570,12 +570,6 @@ impl fmt::Display for OtherError { - } - } - --impl StdError for OtherError { -- fn source(&self) -> Option<&(dyn StdError + 'static)> { -- Some(self.0.as_ref()) -- } --} -- - #[cfg(test)] - mod tests { - use super::{Error, InvalidMessage}; -diff --git a/rustls/src/lib.rs b/rustls/src/lib.rs -index 8988f31c..2ba40be4 100644 ---- a/rustls/src/lib.rs -+++ b/rustls/src/lib.rs -@@ -273,7 +273,9 @@ - - // Require docs for public APIs, deny unsafe code, etc. - #![forbid(unsafe_code, unused_must_use)] --#![cfg_attr(not(any(read_buf, bench)), forbid(unstable_features))] -+// If std feature enabled, forbit unstable_features -+#![cfg_attr(feature = "std", forbid(unstable_features))] -+#![cfg_attr(feature = "std", deny(unused_qualifications))] - #![deny( - clippy::alloc_instead_of_core, - clippy::clone_on_ref_ptr, -@@ -285,8 +287,7 @@ - missing_docs, - unreachable_pub, - unused_import_braces, -- unused_extern_crates, -- unused_qualifications -+ unused_extern_crates - )] - // Relax these clippy lints: - // - ptr_arg: this triggers on references to type aliases that are Vec -@@ -305,6 +306,8 @@ - clippy::single_component_path_imports, - clippy::new_without_default - )] -+#![allow(internal_features)] -+#![feature(prelude_import)] - // Enable documentation for all features on docs.rs - #![cfg_attr(docsrs, feature(doc_cfg, doc_auto_cfg))] - // XXX: Because of https://github.com/rust-lang/rust/issues/54726, we cannot -@@ -326,8 +329,20 @@ extern crate alloc; - // is in `std::prelude` but not in `core::prelude`. This helps maintain no-std support as even - // developers that are not interested in, or aware of, no-std support and / or that never run - // `cargo build --no-default-features` locally will get errors when they rely on `std::prelude` API. -+#[cfg(all(not(test), feature = "std"))] - extern crate std; - -+#[cfg(not(feature = "std"))] -+extern crate rust_std_stub as std; -+ -+// prelude internal_std for calling Vec, String, Mutex, HashMap, etc. -+#[cfg(not(feature = "std"))] -+#[prelude_import] -+#[allow(unused_imports)] -+#[allow(unused_attributes)] -+#[macro_use] -+use std::prelude::*; -+ - // Import `test` sysroot crate for `Bencher` definitions. - #[cfg(bench)] - #[allow(unused_extern_crates)] -@@ -374,6 +389,7 @@ mod bs_debug; - mod builder; - mod enums; - mod key_log; -+#[cfg(feature = "std")] - mod key_log_file; - mod suites; - mod versions; -@@ -441,7 +457,10 @@ pub use crate::error::{ - CertRevocationListError, CertificateError, Error, InvalidMessage, OtherError, PeerIncompatible, - PeerMisbehaved, - }; -+#[cfg(not(feature = "std"))] -+pub use crate::key_log::NoKeyLog as KeyLogFile; - pub use crate::key_log::{KeyLog, NoKeyLog}; -+#[cfg(feature = "std")] - pub use crate::key_log_file::KeyLogFile; - pub use crate::msgs::enums::NamedGroup; - pub use crate::msgs::handshake::DistinguishedName; -diff --git a/rustls/src/server/tls13.rs b/rustls/src/server/tls13.rs -index 290fb3db..45121a2d 100644 ---- a/rustls/src/server/tls13.rs -+++ b/rustls/src/server/tls13.rs -@@ -312,10 +312,16 @@ mod client_hello { - } - - for (i, psk_id) in psk_offer.identities.iter().enumerate() { -+ #[cfg(feature = "std")] -+ let now = UnixTime::now(); -+ #[cfg(not(feature = "std"))] -+ let now = UnixTime::since_unix_epoch(core::time::Duration::from_secs( -+ std::time::now().as_secs(), -+ )); - let resume = match self - .attempt_tls13_ticket_decryption(&psk_id.identity.0) - .map(|resumedata| { -- resumedata.set_freshness(psk_id.obfuscated_ticket_age, UnixTime::now()) -+ resumedata.set_freshness(psk_id.obfuscated_ticket_age, now) - }) - .filter(|resumedata| { - hs::can_resume(self.suite.into(), &cx.data.sni, false, resumedata) -@@ -921,9 +927,14 @@ impl State for ExpectCertificate { - Some(chain) => chain, - }; - -+ #[cfg(feature = "std")] -+ let now = UnixTime::now(); -+ #[cfg(not(feature = "std"))] -+ let now = -+ UnixTime::since_unix_epoch(core::time::Duration::from_secs(std::time::now().as_secs())); - self.config - .verifier -- .verify_client_cert(end_entity, intermediates, UnixTime::now()) -+ .verify_client_cert(end_entity, intermediates, now) - .map_err(|err| { - cx.common - .send_cert_verify_error_alert(err) -@@ -1096,7 +1107,10 @@ impl ExpectFinished { - key_schedule, - cx, - &nonce, -+ #[cfg(feature = "std")] - UnixTime::now(), -+ #[cfg(not(feature = "std"))] -+ UnixTime::since_unix_epoch(core::time::Duration::from_secs(std::time::now().as_secs())), - age_add, - ) - .get_encoding(); -diff --git a/rustls/src/ticketer.rs b/rustls/src/ticketer.rs -index ddadb0ef..8a43f5bc 100644 ---- a/rustls/src/ticketer.rs -+++ b/rustls/src/ticketer.rs -@@ -46,9 +46,12 @@ impl TicketSwitcher { - next: Some(generator()?), - current: generator()?, - previous: None, -+ #[cfg(feature = "std")] - next_switch_time: UnixTime::now() - .as_secs() - .saturating_add(u64::from(lifetime)), -+ #[cfg(not(feature = "std"))] -+ next_switch_time: std::time::now().as_secs(), - }), - }) - } -@@ -144,13 +147,25 @@ impl ProducesTickets for TicketSwitcher { - } - - fn encrypt(&self, message: &[u8]) -> Option> { -- let state = self.maybe_roll(UnixTime::now())?; -+ #[cfg(feature = "std")] -+ let now = UnixTime::now(); -+ #[cfg(not(feature = "std"))] -+ let now = -+ UnixTime::since_unix_epoch(core::time::Duration::from_secs(std::time::now().as_secs())); -+ -+ let state = self.maybe_roll(now)?; - - state.current.encrypt(message) - } - - fn decrypt(&self, ciphertext: &[u8]) -> Option> { -- let state = self.maybe_roll(UnixTime::now())?; -+ #[cfg(feature = "std")] -+ let now = UnixTime::now(); -+ #[cfg(not(feature = "std"))] -+ let now = -+ UnixTime::since_unix_epoch(core::time::Duration::from_secs(std::time::now().as_secs())); -+ -+ let state = self.maybe_roll(now)?; - - // Decrypt with the current key; if that fails, try with the previous. - state diff --git a/deps/rustls b/deps/rustls deleted file mode 160000 index ae277bef..00000000 --- a/deps/rustls +++ /dev/null @@ -1 +0,0 @@ -Subproject commit ae277befb5061bbd4c44fea1c2697f2da5b2f6fa diff --git a/sh_script/preparation.sh b/sh_script/preparation.sh index 74c1965a..18126c77 100755 --- a/sh_script/preparation.sh +++ b/sh_script/preparation.sh @@ -4,12 +4,6 @@ preparation() { pushd deps/td-shim bash sh_script/preparation.sh popd - - pushd deps/rustls - git reset --hard ae277befb5061bbd4c44fea1c2697f2da5b2f6fa - git clean -f -d - patch -p 1 -i ../patches/rustls.diff - popd } preparation diff --git a/src/crypto/Cargo.toml b/src/crypto/Cargo.toml index d4c0b226..14f2c21e 100644 --- a/src/crypto/Cargo.toml +++ b/src/crypto/Cargo.toml @@ -9,9 +9,10 @@ cfg-if = "1.0" der = {version = "0.7.9", features = ["oid", "alloc", "derive"]} pki-types = { package = "rustls-pki-types", version = "1" } rust_std_stub = { path = "../std-support/rust-std-stub" } -rustls = { path = "../../deps/rustls/rustls", default-features = false, features = ["no_std"], optional = true } +rustls = { version = "=0.23.12", default-features = false, features = ["ring" ], optional = true } rustls-pemfile = { version = "2.0.0", default-features = false } -ring = { path = "../../deps/td-shim/library/ring", default-features = false, features = ["alloc"], optional = true } +ring = { path = "../../deps/td-shim/library/ring", default-features = false, features = ["alloc", "less-safe-getrandom-custom-or-rdrand"], optional = true } +sys_time = { path = "../std-support/sys_time" } zeroize = "1.5.7" [features] diff --git a/src/crypto/src/lib.rs b/src/crypto/src/lib.rs index 908e6273..e9d89b51 100644 --- a/src/crypto/src/lib.rs +++ b/src/crypto/src/lib.rs @@ -64,6 +64,9 @@ pub enum Error { /// Unable to verify the TLS peer's certificates TlsVerifyPeerCert(String), + /// Error occurs during processing the tls connection + TlsConnection, + /// Pem certificate parsing error DecodePemCert, diff --git a/src/crypto/src/rustls_impl/tls.rs b/src/crypto/src/rustls_impl/tls.rs index d44edfd3..f2ffa3c6 100644 --- a/src/crypto/src/rustls_impl/tls.rs +++ b/src/crypto/src/rustls_impl/tls.rs @@ -2,11 +2,14 @@ // // SPDX-License-Identifier: BSD-2-Clause-Patent +use core::time::Duration; + use alloc::string::ToString; +use alloc::sync::Arc; use alloc::vec::Vec; +use connection::{TlsClientConnection, TlsConnectionError, TlsServerConnection}; use pki_types::{CertificateDer, PrivateKeyDer, PrivatePkcs8KeyDer, ServerName, UnixTime}; -use rust_std_stub::io::{self, Read, Write}; -use rust_std_stub::sync::Arc; +use rust_std_stub::io::{Read, Write}; use rustls::client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier}; use rustls::client::ResolvesClientCert; use rustls::crypto::ring::cipher_suite::TLS13_AES_256_GCM_SHA384; @@ -17,8 +20,10 @@ use rustls::crypto::{verify_tls12_signature, verify_tls13_signature, CryptoProvi use rustls::server::danger::{ClientCertVerified, ClientCertVerifier}; use rustls::server::{ClientHello, ResolvesServerCert}; use rustls::sign::{CertifiedKey, SigningKey}; +use rustls::time_provider::TimeProvider; use rustls::version::TLS13; -use rustls::{ClientConfig, ClientConnection, ServerConfig, ServerConnection}; +use rustls::{ClientConfig, ServerConfig}; +extern crate alloc; use crate::{Error, Result}; @@ -28,96 +33,57 @@ pub type TlsLibError = rustls::Error; const TLS_CUSTOM_CALLBACK_ERROR: &str = "TlsCustomCallbackError"; pub struct SecureChannel { - conn: TlsConnection, - stream: T, + conn: TlsConnection, } impl SecureChannel where T: Read + Write, { - fn new(conn: TlsConnection, stream: T) -> Self { - SecureChannel { conn, stream } + fn new(conn: TlsConnection) -> Self { + SecureChannel { conn } } pub fn write(&mut self, data: &[u8]) -> Result { - self.conn.write(&mut self.stream, data) + self.conn.write(data) } pub fn read(&mut self, data: &mut [u8]) -> Result { - self.conn.read(&mut self.stream, data) - } - - pub fn peer_cert(&mut self) -> Option> { - self.conn.peer_cert() + self.conn.read(data) } } -enum TlsConnection { - Server(ServerConnection), - Client(ClientConnection), +enum TlsConnection { + Server(TlsServerConnection), + Client(TlsClientConnection), } -impl TlsConnection { - fn read(&mut self, stream: &mut T, data: &mut [u8]) -> Result { - match self { - Self::Server(conn) => { - let mut tls_stream = rustls::Stream::new(conn, stream); - tls_stream.read(data).map_err(Self::handle_stream_error) - } - Self::Client(conn) => { - let mut tls_stream = rustls::Stream::new(conn, stream); - tls_stream.read(data).map_err(Self::handle_stream_error) - } - } - } - - fn write(&mut self, stream: &mut T, data: &[u8]) -> Result { +impl TlsConnection { + fn read(&mut self, data: &mut [u8]) -> Result { match self { - Self::Server(conn) => { - let mut tls_stream = rustls::Stream::new(conn, stream); - tls_stream.write(data).map_err(Self::handle_stream_error) - } - Self::Client(conn) => { - let mut tls_stream = rustls::Stream::new(conn, stream); - tls_stream.write(data).map_err(Self::handle_stream_error) - } + Self::Server(conn) => conn.read(data).map_err(Self::handle_stream_error), + Self::Client(conn) => conn.read(data).map_err(Self::handle_stream_error), } } - fn peer_cert(&self) -> Option> { - let mut list = Vec::new(); + fn write(&mut self, data: &[u8]) -> Result { match self { - Self::Server(conn) => conn.peer_certificates().map(|certs| { - for cert in certs { - list.push(cert.as_ref()) - } - list - }), - Self::Client(conn) => conn.peer_certificates().map(|certs| { - for cert in certs { - list.push(cert.as_ref()) - } - list - }), + Self::Server(conn) => conn.write(data).map_err(Self::handle_stream_error), + Self::Client(conn) => conn.write(data).map_err(Self::handle_stream_error), } } - fn handle_stream_error(e: io::Error) -> Error { - match e.kind() { - io::ErrorKind::InvalidData => { - let desc = e.to_string(); - + fn handle_stream_error(e: TlsConnectionError) -> Error { + match e { + TlsConnectionError::TlsLib(rustls::Error::General(desc)) => { if let Some(index) = desc.find(TLS_CUSTOM_CALLBACK_ERROR) { let start = index + TLS_CUSTOM_CALLBACK_ERROR.len() + 1; let end = match desc[start..].find(')') { Some(index) => start + index, - None => return Error::Unexpected, + None => return Error::TlsStream, }; - return Error::TlsVerifyPeerCert(desc[start..end].to_string()); } - Error::TlsStream } _ => Error::TlsStream, @@ -134,7 +100,7 @@ impl TlsConfig { pub fn new( certs_der: Vec>, signing_key: EcdsaPk, - verify_callback: fn(&[u8], &[u8]) -> Result<()>, + verify_callback: fn(&[u8], &[u8]) -> core::result::Result<(), Error>, verify_callback_data: Vec, ) -> Result { let mut certs = Vec::new(); @@ -163,7 +129,7 @@ impl TlsConfig { pub fn set_verify_callback( &mut self, - cb: fn(&[u8], &[u8]) -> Result<()>, + cb: fn(&[u8], &[u8]) -> core::result::Result<(), Error>, data: Vec, ) -> Result<()> { self.verifier = Verifier::new(cb, data); @@ -172,41 +138,38 @@ impl TlsConfig { } pub fn tls_client(self, stream: T) -> Result> { - let client_config = ClientConfig::builder_with_provider(Arc::new(crypto_provider())) - .with_protocol_versions(&[&TLS13]) - .map_err(Error::SetupTlsContext)? - // `dangerous()` method of `ClientConfig` allows setting inadvisable options, such as replacing the - // certificate verification process. - .dangerous() - .with_custom_certificate_verifier(Arc::new(self.verifier)) - .with_client_cert_resolver(Arc::new(self.resolver)); - - let connection = rustls::ClientConnection::new( - alloc::sync::Arc::new(client_config), - ServerName::try_from("localhost").map_err(|_| Error::InvalidDnsName)?, + let client_config = ClientConfig::builder_with_details( + Arc::new(crypto_provider()), + Arc::new(TlsTimeProvider {}), ) - .map_err(Error::SetupTlsContext)?; - - Ok(SecureChannel::new( - TlsConnection::Client(connection), - stream, - )) + .with_protocol_versions(&[&TLS13]) + .map_err(Error::SetupTlsContext)? + // `dangerous()` method of `ClientConfig` allows setting inadvisable options, such as replacing the + // certificate verification process. + .dangerous() + .with_custom_certificate_verifier(Arc::new(self.verifier)) + .with_client_cert_resolver(Arc::new(self.resolver)); + + let connection = TlsClientConnection::new(Arc::new(client_config), stream) + .map_err(|_| Error::TlsConnection)?; + + Ok(SecureChannel::new(TlsConnection::Client(connection))) } pub fn tls_server(self, stream: T) -> Result> { - let server_config = ServerConfig::builder_with_provider(Arc::new(crypto_provider())) - .with_protocol_versions(&[&TLS13]) - .map_err(Error::SetupTlsContext)? - .with_client_cert_verifier(Arc::new(self.verifier)) - .with_cert_resolver(Arc::new(self.resolver)); + let server_config = ServerConfig::builder_with_details( + Arc::new(crypto_provider()), + Arc::new(TlsTimeProvider {}), + ) + .with_protocol_versions(&[&TLS13]) + .map_err(Error::SetupTlsContext)? + .with_client_cert_verifier(Arc::new(self.verifier)) + .with_cert_resolver(Arc::new(self.resolver)); - let connection = rustls::ServerConnection::new(alloc::sync::Arc::new(server_config)) - .map_err(Error::SetupTlsContext)?; + let connection = TlsServerConnection::new(Arc::new(server_config), stream) + .map_err(|_| Error::TlsConnection)?; - Ok(SecureChannel::new( - TlsConnection::Server(connection), - stream, - )) + Ok(SecureChannel::new(TlsConnection::Server(connection))) } } @@ -391,3 +354,584 @@ impl ClientCertVerifier for Verifier { &[] } } + +pub(crate) mod connection { + use alloc::{collections::VecDeque, sync::Arc, vec::Vec}; + use rust_std_stub::io::{self, Read, Write}; + use rustls::{ + client::UnbufferedClientConnection, + server::UnbufferedServerConnection, + unbuffered::{ + AppDataRecord, ConnectionState, EncodeError, EncryptError, InsufficientSizeError, + UnbufferedStatus, + }, + ClientConfig, ServerConfig, + }; + + const PAGE_SIZE: usize = 0x1000; + const TLS_BUFFER_SIZE: usize = 16 * PAGE_SIZE; + const APP_DATA_BUFFER_LIMIT: usize = PAGE_SIZE; + + #[derive(Debug)] + pub enum TlsConnectionError { + /// Error occurs during encoding tls data + Encode, + + /// Error occurs during encrypt tls data + Encrypt, + + /// Unexpected tls state + UnexpectedState, + + /// Tls lib error + TlsLib(rustls::Error), + + /// Failed to read/write transport + Transport, + } + + impl From for TlsConnectionError { + fn from(_: io::Error) -> Self { + Self::Transport + } + } + + impl From for TlsConnectionError { + fn from(value: rustls::Error) -> Self { + Self::TlsLib(value) + } + } + + impl From for TlsConnectionError { + fn from(_: rustls::unbuffered::EncodeError) -> Self { + Self::Encode + } + } + + impl From for TlsConnectionError { + fn from(_: rustls::unbuffered::EncryptError) -> Self { + Self::Encrypt + } + } + + struct TlsBuffer { + inner: Vec, + used: usize, + } + + impl TlsBuffer { + fn new() -> Self { + TlsBuffer { + inner: vec![0u8; TLS_BUFFER_SIZE], + used: 0, + } + } + + // Try to run `f` and resize the buffer and try again if we got `InsufficientSizeError` + fn try_or_resize_and_retry( + &mut self, + mut f: impl FnMut(&mut [u8]) -> Result, + map_err: impl FnOnce(E) -> Result, + ) -> Result { + let written = match f(self.unused_mut()) { + Ok(written) => written, + + Err(e) => { + let InsufficientSizeError { required_size } = map_err(e)?; + let new_len = self.used + required_size; + self.inner.resize(new_len, 0); + + f(self.unused_mut()).map_err(|_| TlsConnectionError::Encode)? + } + }; + + self.used += written; + + Ok(written) + } + + // Get the immutable reference of used buffer + fn used(&self) -> &[u8] { + &self.inner[..self.used] + } + + // Get the mutable reference of used buffer + fn used_mut(&mut self) -> &mut [u8] { + &mut self.inner[..self.used] + } + + // Get the mutable reference of unused buffer + fn unused_mut(&mut self) -> &mut [u8] { + &mut self.inner[self.used..] + } + + // Reset the used + fn reset(&mut self) { + self.used = 0; + } + + // Accumulate number of used bytes + fn consume(&mut self, size: usize) { + self.used += size; + } + + // Discard the first `size` bytes + fn discard(&mut self, size: usize) { + self.inner.copy_within(size..self.used, 0); + self.used -= size; + } + } + + pub struct TlsServerConnection { + conn: UnbufferedServerConnection, + input: TlsBuffer, + output: TlsBuffer, + transport: T, + is_handshaking: bool, + received_app_data: ChunkVecBuffer, + } + + impl TlsServerConnection { + pub fn new(config: Arc, transport: T) -> Result { + Ok(Self { + conn: UnbufferedServerConnection::new(config)?, + transport, + input: TlsBuffer::new(), + output: TlsBuffer::new(), + is_handshaking: true, + received_app_data: ChunkVecBuffer::new(Some(APP_DATA_BUFFER_LIMIT)), + }) + } + + pub fn read(&mut self, data: &mut [u8]) -> Result { + if self.is_handshaking { + self.process_tls_status()?; + } + + if !self.received_app_data.is_empty() { + return Ok(self.received_app_data.read(data)); + } + + loop { + let UnbufferedStatus { mut discard, state } = + self.conn.process_tls_records(self.input.used_mut()); + match state? { + ConnectionState::ReadTraffic(mut state) => { + while let Some(res) = state.next_record() { + let AppDataRecord { + discard: new_discard, + payload, + } = res?; + discard += new_discard; + self.received_app_data.append(payload.to_vec()); + } + let read = self.received_app_data.read(data); + self.input.discard(discard); + return Ok(read); + } + ConnectionState::WriteTraffic(..) => { + let size = self.transport.read(self.input.unused_mut())?; + self.input.consume(size); + } + _ => return Err(TlsConnectionError::UnexpectedState), + } + self.input.discard(discard); + } + } + + pub fn write(&mut self, data: &[u8]) -> Result { + if self.is_handshaking { + self.process_tls_status()?; + } + + loop { + let UnbufferedStatus { mut discard, state } = + self.conn.process_tls_records(self.input.used_mut()); + + match state? { + ConnectionState::ReadTraffic(mut state) => { + while let Some(res) = state.next_record() { + let AppDataRecord { + discard: new_discard, + payload, + } = res?; + discard += new_discard; + self.received_app_data.append(payload.to_vec()); + } + } + ConnectionState::WriteTraffic(mut state) => { + let map_err = |e| { + if let EncryptError::InsufficientSize(is) = &e { + Ok(*is) + } else { + Err(e.into()) + } + }; + self.output.try_or_resize_and_retry( + |out_buffer| state.encrypt(data, out_buffer), + map_err, + )?; + self.transport.write(self.output.used())?; + self.output.reset(); + break; + } + _ => return Err(TlsConnectionError::UnexpectedState), + } + self.input.discard(discard); + } + Ok(data.len()) + } + + fn process_tls_status(&mut self) -> Result<(), TlsConnectionError> { + loop { + let UnbufferedStatus { mut discard, state } = + self.conn.process_tls_records(self.input.used_mut()); + + match state { + Ok(state) => match state { + ConnectionState::EncodeTlsData(mut state) => { + let _ = self.output.try_or_resize_and_retry( + |out_buffer| state.encode(out_buffer), + |e| { + if let EncodeError::InsufficientSize(is) = &e { + Ok(*is) + } else { + Err(e.into()) + } + }, + )?; + } + ConnectionState::TransmitTlsData(state) => { + self.transport.write(self.output.used())?; + self.output.reset(); + state.done(); + } + ConnectionState::BlockedHandshake { .. } => { + let size = self.transport.read(self.input.unused_mut())?; + self.input.consume(size); + } + ConnectionState::ReadTraffic(mut state) => { + while let Some(res) = state.next_record() { + let AppDataRecord { + discard: new_discard, + payload, + } = res?; + discard += new_discard; + self.received_app_data.append(payload.to_vec()); + } + self.is_handshaking = false; + self.input.discard(discard); + break; + } + ConnectionState::WriteTraffic { .. } => { + self.is_handshaking = false; + self.input.discard(discard); + break; + } + _ => return Err(TlsConnectionError::UnexpectedState), + }, + Err(e) => { + self.input.discard(discard); + self.handle_tls_error()?; + return Err(TlsConnectionError::TlsLib(e)); + } + } + self.input.discard(discard); + } + Ok(()) + } + + fn handle_tls_error(&mut self) -> Result<(), TlsConnectionError> { + let status = self.conn.process_tls_records(self.input.used_mut()); + match status.state? { + ConnectionState::EncodeTlsData(mut state) => { + let _ = self.output.try_or_resize_and_retry( + |out_buffer| state.encode(out_buffer), + |e| { + if let EncodeError::InsufficientSize(is) = &e { + Ok(*is) + } else { + Err(e.into()) + } + }, + )?; + self.transport.write(self.output.used())?; + self.output.reset(); + Ok(()) + } + _ => Ok(()), + } + } + } + + // Derived from `rustls::vecbuf` + struct ChunkVecBuffer { + chunks: VecDeque>, + limit: Option, + } + + impl ChunkVecBuffer { + fn new(limit: Option) -> Self { + Self { + chunks: VecDeque::new(), + limit, + } + } + + fn is_full(&self) -> bool { + self.limit + .map(|limit| self.len() > limit) + .unwrap_or_default() + } + + fn is_empty(&self) -> bool { + self.chunks.is_empty() + } + + fn len(&self) -> usize { + let mut len = 0; + for ch in &self.chunks { + len += ch.len(); + } + len + } + + fn append(&mut self, bytes: Vec) -> usize { + let len = bytes.len(); + + if !bytes.is_empty() { + self.chunks.push_back(bytes); + } + + len + } + + fn read(&mut self, buf: &mut [u8]) -> usize { + let mut offs = 0; + + while offs < buf.len() && !self.is_empty() { + let used; + if buf.len() - offs >= self.chunks[0].len() { + used = self.chunks[0].len(); + buf[offs..offs + used].copy_from_slice(&self.chunks[0]); + } else { + used = buf.len() - offs; + buf[offs..].copy_from_slice(&self.chunks[0][..used]); + } + + self.consume(used); + offs += used; + } + + offs + } + + fn consume(&mut self, mut used: usize) { + while let Some(mut buf) = self.chunks.pop_front() { + if used < buf.len() { + buf.drain(..used); + self.chunks.push_front(buf); + break; + } else { + used -= buf.len(); + } + } + } + } + + pub struct TlsClientConnection { + conn: UnbufferedClientConnection, + input: TlsBuffer, + output: TlsBuffer, + transport: T, + is_handshaking: bool, + received_app_data: ChunkVecBuffer, + } + + impl TlsClientConnection { + pub fn new(config: Arc, transport: T) -> Result { + Ok(Self { + conn: UnbufferedClientConnection::new(config, "localhost".try_into().unwrap())?, + transport, + input: TlsBuffer::new(), + output: TlsBuffer::new(), + is_handshaking: true, + received_app_data: ChunkVecBuffer::new(Some(APP_DATA_BUFFER_LIMIT)), + }) + } + + pub fn read(&mut self, data: &mut [u8]) -> Result { + if self.is_handshaking { + self.process_tls_status()?; + } + + if !self.received_app_data.is_empty() { + return Ok(self.received_app_data.read(data)); + } + + loop { + let UnbufferedStatus { mut discard, state } = + self.conn.process_tls_records(self.input.used_mut()); + match state? { + ConnectionState::ReadTraffic(mut state) => { + while let Some(res) = state.next_record() { + let AppDataRecord { + discard: new_discard, + payload, + } = res?; + if !self.received_app_data.is_full() { + discard += new_discard; + self.received_app_data.append(payload.to_vec()); + } + } + let read = self.received_app_data.read(data); + self.input.discard(discard); + return Ok(read); + } + ConnectionState::WriteTraffic(..) => { + let size = self.transport.read(self.input.unused_mut())?; + self.input.consume(size); + } + _ => return Err(TlsConnectionError::UnexpectedState), + } + self.input.discard(discard); + } + } + + pub fn write(&mut self, data: &[u8]) -> Result { + if self.is_handshaking { + self.process_tls_status()?; + } + + loop { + let UnbufferedStatus { mut discard, state } = + self.conn.process_tls_records(self.input.used_mut()); + + match state? { + ConnectionState::ReadTraffic(mut state) => { + while let Some(res) = state.next_record() { + let AppDataRecord { + discard: new_discard, + payload, + } = res?; + discard += new_discard; + self.received_app_data.append(payload.to_vec()); + } + } + ConnectionState::WriteTraffic(mut state) => { + let map_err = |e| { + if let EncryptError::InsufficientSize(is) = &e { + Ok(*is) + } else { + Err(e.into()) + } + }; + self.output.try_or_resize_and_retry( + |out_buffer| state.encrypt(data, out_buffer), + map_err, + )?; + self.transport.write(self.output.used())?; + self.output.reset(); + break; + } + _ => return Err(TlsConnectionError::UnexpectedState), + } + self.input.discard(discard); + } + Ok(data.len()) + } + + fn process_tls_status(&mut self) -> Result<(), TlsConnectionError> { + loop { + let UnbufferedStatus { mut discard, state } = + self.conn.process_tls_records(self.input.used_mut()); + + match state { + Ok(state) => match state { + ConnectionState::EncodeTlsData(mut state) => { + let _ = self.output.try_or_resize_and_retry( + |out_buffer| state.encode(out_buffer), + |e| { + if let EncodeError::InsufficientSize(is) = &e { + Ok(*is) + } else { + Err(e.into()) + } + }, + )?; + } + ConnectionState::TransmitTlsData(state) => { + self.transport.write(self.output.used())?; + self.output.reset(); + state.done(); + } + ConnectionState::BlockedHandshake { .. } => { + let size = self.transport.read(self.input.unused_mut())?; + self.input.consume(size); + } + ConnectionState::ReadTraffic(mut state) => { + while let Some(res) = state.next_record() { + let AppDataRecord { + discard: new_discard, + payload, + } = res?; + discard += new_discard; + self.received_app_data.append(payload.to_vec()); + } + self.is_handshaking = false; + self.input.discard(discard); + break; + } + ConnectionState::WriteTraffic { .. } => { + self.is_handshaking = false; + self.input.discard(discard); + break; + } + _ => return Err(TlsConnectionError::UnexpectedState), + }, + Err(e) => { + self.input.discard(discard); + self.handle_tls_error()?; + return Err(TlsConnectionError::TlsLib(e)); + } + } + self.input.discard(discard); + } + Ok(()) + } + + fn handle_tls_error(&mut self) -> Result<(), TlsConnectionError> { + let status = self.conn.process_tls_records(self.input.used_mut()); + match status.state? { + ConnectionState::EncodeTlsData(mut state) => { + let _ = self.output.try_or_resize_and_retry( + |out_buffer| state.encode(out_buffer), + |e| { + if let EncodeError::InsufficientSize(is) = &e { + Ok(*is) + } else { + Err(e.into()) + } + }, + )?; + self.transport.write(self.output.used())?; + self.output.reset(); + Ok(()) + } + _ => Ok(()), + } + } + } +} + +#[derive(Debug)] +struct TlsTimeProvider; + +impl TimeProvider for TlsTimeProvider { + fn current_time(&self) -> Option { + Some(UnixTime::since_unix_epoch(Duration::new( + sys_time::get_sys_time()? as u64, + 0, + ))) + } +}