From d4ce3f1e7c12be51bf12f3c51416735666d82582 Mon Sep 17 00:00:00 2001 From: Jiaqi Gao Date: Tue, 8 Aug 2023 10:19:26 -0400 Subject: [PATCH 1/2] std-support: add `futures_io` This crate inherits from the `futures-io`crate and replaces its dependency on std with 'rust-std-stub' using the prelude, with some deletions made. Signed-off-by: Jiaqi Gao --- src/std-support/futures-io/Cargo.toml | 9 + src/std-support/futures-io/src/lib.rs | 550 ++++++++++++++++++++++++++ 2 files changed, 559 insertions(+) create mode 100644 src/std-support/futures-io/Cargo.toml create mode 100644 src/std-support/futures-io/src/lib.rs diff --git a/src/std-support/futures-io/Cargo.toml b/src/std-support/futures-io/Cargo.toml new file mode 100644 index 00000000..04d8bc95 --- /dev/null +++ b/src/std-support/futures-io/Cargo.toml @@ -0,0 +1,9 @@ +[package] +name = "futures-io" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +rust_std_stub = { path = "../rust-std-stub" } diff --git a/src/std-support/futures-io/src/lib.rs b/src/std-support/futures-io/src/lib.rs new file mode 100644 index 00000000..25b69a84 --- /dev/null +++ b/src/std-support/futures-io/src/lib.rs @@ -0,0 +1,550 @@ +//! Asynchronous I/O +//! +//! This crate contains the `AsyncRead`, `AsyncWrite`, `AsyncSeek`, and +//! `AsyncBufRead` traits, the asynchronous analogs to +//! `std::io::{Read, Write, Seek, BufRead}`. The primary difference is +//! that these traits integrate with the asynchronous task system. +//! +//! All items of this library are only available when the `std` feature of this +//! library is activated, and it is activated by default. + +#![no_std] +#![warn( + missing_debug_implementations, + missing_docs, + rust_2018_idioms, + unreachable_pub +)] +// It cannot be included in the published code because this lints have false positives in the minimum required version. +#![cfg_attr(test, warn(single_use_lifetimes))] +#![doc(test( + no_crate_inject, + attr( + deny(warnings, rust_2018_idioms, single_use_lifetimes), + allow(dead_code, unused_assignments, unused_variables) + ) +))] +#![cfg_attr(docsrs, feature(doc_cfg))] + +extern crate alloc; + +use alloc::boxed::Box; +use alloc::vec::Vec; +use core::ops::DerefMut; +use core::pin::Pin; +use core::task::{Context, Poll}; +use rust_std_stub::io; + +// Re-export some types from `std::io` so that users don't have to deal +// with conflicts when `use`ing `futures::io` and `std::io`. +#[allow(unreachable_pub)] // https://github.com/rust-lang/rust/issues/57411 +#[doc(no_inline)] +pub use io::{Error, ErrorKind, IoSlice, IoSliceMut, Result}; + +/// Read bytes asynchronously. +/// +/// This trait is analogous to the `std::io::Read` trait, but integrates +/// with the asynchronous task system. In particular, the `poll_read` +/// method, unlike `Read::read`, will automatically queue the current task +/// for wakeup and return if data is not yet available, rather than blocking +/// the calling thread. +pub trait AsyncRead { + /// Attempt to read from the `AsyncRead` into `buf`. + /// + /// On success, returns `Poll::Ready(Ok(num_bytes_read))`. + /// + /// If no data is available for reading, the method returns + /// `Poll::Pending` and arranges for the current task (via + /// `cx.waker().wake_by_ref()`) to receive a notification when the object becomes + /// readable or is closed. + /// + /// # Implementation + /// + /// This function may not return errors of kind `WouldBlock` or + /// `Interrupted`. Implementations must convert `WouldBlock` into + /// `Poll::Pending` and either internally retry or convert + /// `Interrupted` into another error kind. + fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8]) + -> Poll>; + + /// Attempt to read from the `AsyncRead` into `bufs` using vectored + /// IO operations. + /// + /// This method is similar to `poll_read`, but allows data to be read + /// into multiple buffers using a single operation. + /// + /// On success, returns `Poll::Ready(Ok(num_bytes_read))`. + /// + /// If no data is available for reading, the method returns + /// `Poll::Pending` and arranges for the current task (via + /// `cx.waker().wake_by_ref()`) to receive a notification when the object becomes + /// readable or is closed. + /// By default, this method delegates to using `poll_read` on the first + /// nonempty buffer in `bufs`, or an empty one if none exists. Objects which + /// support vectored IO should override this method. + /// + /// # Implementation + /// + /// This function may not return errors of kind `WouldBlock` or + /// `Interrupted`. Implementations must convert `WouldBlock` into + /// `Poll::Pending` and either internally retry or convert + /// `Interrupted` into another error kind. + fn poll_read_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &mut [IoSliceMut<'_>], + ) -> Poll> { + for b in bufs { + if !b.is_empty() { + return self.poll_read(cx, b); + } + } + + self.poll_read(cx, &mut []) + } +} + +/// Write bytes asynchronously. +/// +/// This trait is analogous to the `std::io::Write` trait, but integrates +/// with the asynchronous task system. In particular, the `poll_write` +/// method, unlike `Write::write`, will automatically queue the current task +/// for wakeup and return if the writer cannot take more data, rather than blocking +/// the calling thread. +pub trait AsyncWrite { + /// Attempt to write bytes from `buf` into the object. + /// + /// On success, returns `Poll::Ready(Ok(num_bytes_written))`. + /// + /// If the object is not ready for writing, the method returns + /// `Poll::Pending` and arranges for the current task (via + /// `cx.waker().wake_by_ref()`) to receive a notification when the object becomes + /// writable or is closed. + /// + /// # Implementation + /// + /// This function may not return errors of kind `WouldBlock` or + /// `Interrupted`. Implementations must convert `WouldBlock` into + /// `Poll::Pending` and either internally retry or convert + /// `Interrupted` into another error kind. + /// + /// `poll_write` must try to make progress by flushing the underlying object if + /// that is the only way the underlying object can become writable again. + fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll>; + + /// Attempt to write bytes from `bufs` into the object using vectored + /// IO operations. + /// + /// This method is similar to `poll_write`, but allows data from multiple buffers to be written + /// using a single operation. + /// + /// On success, returns `Poll::Ready(Ok(num_bytes_written))`. + /// + /// If the object is not ready for writing, the method returns + /// `Poll::Pending` and arranges for the current task (via + /// `cx.waker().wake_by_ref()`) to receive a notification when the object becomes + /// writable or is closed. + /// + /// By default, this method delegates to using `poll_write` on the first + /// nonempty buffer in `bufs`, or an empty one if none exists. Objects which + /// support vectored IO should override this method. + /// + /// # Implementation + /// + /// This function may not return errors of kind `WouldBlock` or + /// `Interrupted`. Implementations must convert `WouldBlock` into + /// `Poll::Pending` and either internally retry or convert + /// `Interrupted` into another error kind. + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[IoSlice<'_>], + ) -> Poll> { + for b in bufs { + if !b.is_empty() { + return self.poll_write(cx, b); + } + } + + self.poll_write(cx, &[]) + } + + /// Attempt to flush the object, ensuring that any buffered data reach + /// their destination. + /// + /// On success, returns `Poll::Ready(Ok(()))`. + /// + /// If flushing cannot immediately complete, this method returns + /// `Poll::Pending` and arranges for the current task (via + /// `cx.waker().wake_by_ref()`) to receive a notification when the object can make + /// progress towards flushing. + /// + /// # Implementation + /// + /// This function may not return errors of kind `WouldBlock` or + /// `Interrupted`. Implementations must convert `WouldBlock` into + /// `Poll::Pending` and either internally retry or convert + /// `Interrupted` into another error kind. + /// + /// It only makes sense to do anything here if you actually buffer data. + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll>; + + /// Attempt to close the object. + /// + /// On success, returns `Poll::Ready(Ok(()))`. + /// + /// If closing cannot immediately complete, this function returns + /// `Poll::Pending` and arranges for the current task (via + /// `cx.waker().wake_by_ref()`) to receive a notification when the object can make + /// progress towards closing. + /// + /// # Implementation + /// + /// This function may not return errors of kind `WouldBlock` or + /// `Interrupted`. Implementations must convert `WouldBlock` into + /// `Poll::Pending` and either internally retry or convert + /// `Interrupted` into another error kind. + fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll>; +} + +// /// Seek bytes asynchronously. +// /// +// /// This trait is analogous to the `std::io::Seek` trait, but integrates +// /// with the asynchronous task system. In particular, the `poll_seek` +// /// method, unlike `Seek::seek`, will automatically queue the current task +// /// for wakeup and return if data is not yet available, rather than blocking +// /// the calling thread. +// pub trait AsyncSeek { +// /// Attempt to seek to an offset, in bytes, in a stream. +// /// +// /// A seek beyond the end of a stream is allowed, but behavior is defined +// /// by the implementation. +// /// +// /// If the seek operation completed successfully, +// /// this method returns the new position from the start of the stream. +// /// That position can be used later with [`SeekFrom::Start`]. +// /// +// /// # Errors +// /// +// /// Seeking to a negative offset is considered an error. +// /// +// /// # Implementation +// /// +// /// This function may not return errors of kind `WouldBlock` or +// /// `Interrupted`. Implementations must convert `WouldBlock` into +// /// `Poll::Pending` and either internally retry or convert +// /// `Interrupted` into another error kind. +// fn poll_seek( +// self: Pin<&mut Self>, +// cx: &mut Context<'_>, +// pos: SeekFrom, +// ) -> Poll>; +// } + +/// Read bytes asynchronously. +/// +/// This trait is analogous to the `std::io::BufRead` trait, but integrates +/// with the asynchronous task system. In particular, the `poll_fill_buf` +/// method, unlike `BufRead::fill_buf`, will automatically queue the current task +/// for wakeup and return if data is not yet available, rather than blocking +/// the calling thread. +pub trait AsyncBufRead: AsyncRead { + /// Attempt to return the contents of the internal buffer, filling it with more data + /// from the inner reader if it is empty. + /// + /// On success, returns `Poll::Ready(Ok(buf))`. + /// + /// If no data is available for reading, the method returns + /// `Poll::Pending` and arranges for the current task (via + /// `cx.waker().wake_by_ref()`) to receive a notification when the object becomes + /// readable or is closed. + /// + /// This function is a lower-level call. It needs to be paired with the + /// [`consume`] method to function properly. When calling this + /// method, none of the contents will be "read" in the sense that later + /// calling [`poll_read`] may return the same contents. As such, [`consume`] must + /// be called with the number of bytes that are consumed from this buffer to + /// ensure that the bytes are never returned twice. + /// + /// [`poll_read`]: AsyncRead::poll_read + /// [`consume`]: AsyncBufRead::consume + /// + /// An empty buffer returned indicates that the stream has reached EOF. + /// + /// # Implementation + /// + /// This function may not return errors of kind `WouldBlock` or + /// `Interrupted`. Implementations must convert `WouldBlock` into + /// `Poll::Pending` and either internally retry or convert + /// `Interrupted` into another error kind. + fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll>; + + /// Tells this buffer that `amt` bytes have been consumed from the buffer, + /// so they should no longer be returned in calls to [`poll_read`]. + /// + /// This function is a lower-level call. It needs to be paired with the + /// [`poll_fill_buf`] method to function properly. This function does + /// not perform any I/O, it simply informs this object that some amount of + /// its buffer, returned from [`poll_fill_buf`], has been consumed and should + /// no longer be returned. As such, this function may do odd things if + /// [`poll_fill_buf`] isn't called before calling it. + /// + /// The `amt` must be `<=` the number of bytes in the buffer returned by + /// [`poll_fill_buf`]. + /// + /// [`poll_read`]: AsyncRead::poll_read + /// [`poll_fill_buf`]: AsyncBufRead::poll_fill_buf + fn consume(self: Pin<&mut Self>, amt: usize); +} + +macro_rules! deref_async_read { + () => { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut [u8], + ) -> Poll> { + Pin::new(&mut **self).poll_read(cx, buf) + } + + fn poll_read_vectored( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &mut [IoSliceMut<'_>], + ) -> Poll> { + Pin::new(&mut **self).poll_read_vectored(cx, bufs) + } + }; +} + +impl AsyncRead for Box { + deref_async_read!(); +} + +impl AsyncRead for &mut T { + deref_async_read!(); +} + +impl

AsyncRead for Pin

+where + P: DerefMut + Unpin, + P::Target: AsyncRead, +{ + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut [u8], + ) -> Poll> { + self.get_mut().as_mut().poll_read(cx, buf) + } + + fn poll_read_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &mut [IoSliceMut<'_>], + ) -> Poll> { + self.get_mut().as_mut().poll_read_vectored(cx, bufs) + } +} + +macro_rules! delegate_async_read_to_stdio { + () => { + fn poll_read( + mut self: Pin<&mut Self>, + _: &mut Context<'_>, + buf: &mut [u8], + ) -> Poll> { + Poll::Ready(io::Read::read(&mut *self, buf)) + } + + fn poll_read_vectored( + mut self: Pin<&mut Self>, + _: &mut Context<'_>, + bufs: &mut [IoSliceMut<'_>], + ) -> Poll> { + Poll::Ready(io::Read::read_vectored(&mut *self, bufs)) + } + }; +} + +impl AsyncRead for &[u8] { + delegate_async_read_to_stdio!(); +} + +macro_rules! deref_async_write { + () => { + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + Pin::new(&mut **self).poll_write(cx, buf) + } + + fn poll_write_vectored( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[IoSlice<'_>], + ) -> Poll> { + Pin::new(&mut **self).poll_write_vectored(cx, bufs) + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut **self).poll_flush(cx) + } + + fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut **self).poll_close(cx) + } + }; +} + +impl AsyncWrite for Box { + deref_async_write!(); +} + +impl AsyncWrite for &mut T { + deref_async_write!(); +} + +impl

AsyncWrite for Pin

+where + P: DerefMut + Unpin, + P::Target: AsyncWrite, +{ + fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll> { + self.get_mut().as_mut().poll_write(cx, buf) + } + + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[IoSlice<'_>], + ) -> Poll> { + self.get_mut().as_mut().poll_write_vectored(cx, bufs) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.get_mut().as_mut().poll_flush(cx) + } + + fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.get_mut().as_mut().poll_close(cx) + } +} + +macro_rules! delegate_async_write_to_stdio { + () => { + fn poll_write( + mut self: Pin<&mut Self>, + _: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + Poll::Ready(io::Write::write(&mut *self, buf)) + } + + fn poll_write_vectored( + mut self: Pin<&mut Self>, + _: &mut Context<'_>, + bufs: &[IoSlice<'_>], + ) -> Poll> { + Poll::Ready(io::Write::write_vectored(&mut *self, bufs)) + } + + fn poll_flush(mut self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { + Poll::Ready(io::Write::flush(&mut *self)) + } + + fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.poll_flush(cx) + } + }; +} + +impl AsyncWrite for Vec { + delegate_async_write_to_stdio!(); +} + +// macro_rules! deref_async_seek { +// () => { +// fn poll_seek( +// mut self: Pin<&mut Self>, +// cx: &mut Context<'_>, +// pos: SeekFrom, +// ) -> Poll> { +// Pin::new(&mut **self).poll_seek(cx, pos) +// } +// }; +// } + +// impl AsyncSeek for Box { +// deref_async_seek!(); +// } + +// impl AsyncSeek for &mut T { +// deref_async_seek!(); +// } + +// impl

AsyncSeek for Pin

+// where +// P: DerefMut + Unpin, +// P::Target: AsyncSeek, +// { +// fn poll_seek( +// self: Pin<&mut Self>, +// cx: &mut Context<'_>, +// pos: SeekFrom, +// ) -> Poll> { +// self.get_mut().as_mut().poll_seek(cx, pos) +// } +// } + +// macro_rules! deref_async_buf_read { +// () => { +// fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { +// Pin::new(&mut **self.get_mut()).poll_fill_buf(cx) +// } + +// fn consume(mut self: Pin<&mut Self>, amt: usize) { +// Pin::new(&mut **self).consume(amt) +// } +// }; +// } + +// impl AsyncBufRead for Box { +// deref_async_buf_read!(); +// } + +// impl AsyncBufRead for &mut T { +// deref_async_buf_read!(); +// } + +// impl

AsyncBufRead for Pin

+// where +// P: DerefMut + Unpin, +// P::Target: AsyncBufRead, +// { +// fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { +// self.get_mut().as_mut().poll_fill_buf(cx) +// } + +// fn consume(self: Pin<&mut Self>, amt: usize) { +// self.get_mut().as_mut().consume(amt) +// } +// } + +// macro_rules! delegate_async_buf_read_to_stdio { +// () => { +// fn poll_fill_buf(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { +// Poll::Ready(io::BufRead::fill_buf(self.get_mut())) +// } + +// fn consume(self: Pin<&mut Self>, amt: usize) { +// io::BufRead::consume(self.get_mut(), amt) +// } +// }; +// } + +// impl AsyncBufRead for &[u8] { +// delegate_async_buf_read_to_stdio!(); +// } From fd80aa0719f990bacb773726764f567c94330b03 Mon Sep 17 00:00:00 2001 From: Jiaqi Gao Date: Tue, 8 Aug 2023 10:22:37 -0400 Subject: [PATCH 2/2] add `async_rustls` This crate inherits from the `async-rustls` crate and replaces its dependency on std with 'rust-std-stub' using the prelude. Signed-off-by: Jiaqi Gao --- Cargo.lock | 16 + Cargo.toml | 1 + src/async_rustls/Cargo.toml | 15 + src/async_rustls/src/client.rs | 220 +++++++++++ src/async_rustls/src/common/handshake.rs | 70 ++++ src/async_rustls/src/common/mod.rs | 360 ++++++++++++++++++ src/async_rustls/src/lib.rs | 450 +++++++++++++++++++++++ src/async_rustls/src/server.rs | 122 ++++++ 8 files changed, 1254 insertions(+) create mode 100644 src/async_rustls/Cargo.toml create mode 100644 src/async_rustls/src/client.rs create mode 100644 src/async_rustls/src/common/handshake.rs create mode 100644 src/async_rustls/src/common/mod.rs create mode 100644 src/async_rustls/src/lib.rs create mode 100644 src/async_rustls/src/server.rs diff --git a/Cargo.lock b/Cargo.lock index f79efe08..5088883a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -68,6 +68,15 @@ version = "1.0.72" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3b13c32d80ecc7ab747b80c3784bce54ee8a7a0cc4fbda9bf4cda2cf6fe90854" +[[package]] +name = "async_rustls" +version = "0.1.0" +dependencies = [ + "futures-io", + "rust_std_stub", + "rustls", +] + [[package]] name = "atomic_refcell" version = "0.1.10" @@ -345,6 +354,13 @@ dependencies = [ "libc", ] +[[package]] +name = "futures-io" +version = "0.1.0" +dependencies = [ + "rust_std_stub", +] + [[package]] name = "generic-array" version = "0.14.7" diff --git a/Cargo.toml b/Cargo.toml index e1d08b9f..e71596dc 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,7 @@ [workspace] members = [ + "src/async_rustls", "src/attestation", "src/crypto", "src/devices/pci", diff --git a/src/async_rustls/Cargo.toml b/src/async_rustls/Cargo.toml new file mode 100644 index 00000000..6e576896 --- /dev/null +++ b/src/async_rustls/Cargo.toml @@ -0,0 +1,15 @@ +[package] +name = "async_rustls" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +rust_std_stub = { path = "../std-support/rust-std-stub" } +futures-io = { path = "../std-support/futures-io" } +rustls = { path = "../../deps/rustls/rustls", default-features = false, features = ["no_std", "alloc"] } + +[features] +dangerous_configuration = ["rustls/dangerous_configuration"] +early-data = [] diff --git a/src/async_rustls/src/client.rs b/src/async_rustls/src/client.rs new file mode 100644 index 00000000..fe6bfbd0 --- /dev/null +++ b/src/async_rustls/src/client.rs @@ -0,0 +1,220 @@ +use super::*; +use crate::common::IoSession; + +/// A wrapper around an underlying raw stream which implements the TLS or SSL +/// protocol. +#[derive(Debug)] +pub struct TlsStream { + pub(crate) io: IO, + pub(crate) session: ClientConnection, + pub(crate) state: TlsState, + + #[cfg(feature = "early-data")] + pub(crate) early_waker: Option, +} + +impl TlsStream { + #[inline] + pub fn get_ref(&self) -> (&IO, &ClientConnection) { + (&self.io, &self.session) + } + + #[inline] + pub fn get_mut(&mut self) -> (&mut IO, &mut ClientConnection) { + (&mut self.io, &mut self.session) + } + + #[inline] + pub fn into_inner(self) -> (IO, ClientConnection) { + (self.io, self.session) + } +} + +impl IoSession for TlsStream { + type Io = IO; + type Session = ClientConnection; + + #[inline] + fn skip_handshake(&self) -> bool { + self.state.is_early_data() + } + + #[inline] + fn get_mut(&mut self) -> (&mut TlsState, &mut Self::Io, &mut Self::Session) { + (&mut self.state, &mut self.io, &mut self.session) + } + + #[inline] + fn into_io(self) -> Self::Io { + self.io + } +} + +impl AsyncRead for TlsStream +where + IO: AsyncRead + AsyncWrite + Unpin, +{ + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut [u8], + ) -> Poll> { + match self.state { + #[cfg(feature = "early-data")] + TlsState::EarlyData(..) => { + let this = self.get_mut(); + + // In the EarlyData state, we have not really established a Tls connection. + // Before writing data through `AsyncWrite` and completing the tls handshake, + // we ignore read readiness and return to pending. + // + // In order to avoid event loss, + // we need to register a waker and wake it up after tls is connected. + if this + .early_waker + .as_ref() + .filter(|waker| cx.waker().will_wake(waker)) + .is_none() + { + this.early_waker = Some(cx.waker().clone()); + } + + Poll::Pending + } + TlsState::Stream | TlsState::WriteShutdown => { + let this = self.get_mut(); + let mut stream = + Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable()); + + match stream.as_mut_pin().poll_read(cx, buf) { + Poll::Ready(Ok(n)) => { + if n == 0 || stream.eof { + this.state.shutdown_read(); + } + + Poll::Ready(Ok(n)) + } + Poll::Ready(Err(err)) if err.kind() == io::ErrorKind::ConnectionAborted => { + this.state.shutdown_read(); + Poll::Ready(Err(err)) + } + output => output, + } + } + TlsState::ReadShutdown | TlsState::FullyShutdown => Poll::Ready(Ok(0)), + } + } +} + +impl AsyncWrite for TlsStream +where + IO: AsyncRead + AsyncWrite + Unpin, +{ + /// Note: that it does not guarantee the final data to be sent. + /// To be cautious, you must manually call `flush`. + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + let this = self.get_mut(); + let mut stream = + Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable()); + + #[allow(clippy::match_single_binding)] + match this.state { + #[cfg(feature = "early-data")] + TlsState::EarlyData(ref mut pos, ref mut data) => { + use rust_std_stub::io::Write; + + // write early data + if let Some(mut early_data) = stream.session.early_data() { + let len = match early_data.write(buf) { + Ok(n) => n, + Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => { + return Poll::Pending + } + Err(err) => return Poll::Ready(Err(err)), + }; + if len != 0 { + data.extend_from_slice(&buf[..len]); + return Poll::Ready(Ok(len)); + } + } + + // complete handshake + while stream.session.is_handshaking() { + ready!(stream.handshake(cx))?; + } + + // write early data (fallback) + if !stream.session.is_early_data_accepted() { + while *pos < data.len() { + let len = ready!(stream.as_mut_pin().poll_write(cx, &data[*pos..]))?; + *pos += len; + } + } + + // end + this.state = TlsState::Stream; + + if let Some(waker) = this.early_waker.take() { + waker.wake(); + } + + stream.as_mut_pin().poll_write(cx, buf) + } + _ => stream.as_mut_pin().poll_write(cx, buf), + } + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = self.get_mut(); + let mut stream = + Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable()); + + #[cfg(feature = "early-data")] + { + if let TlsState::EarlyData(ref mut pos, ref mut data) = this.state { + // complete handshake + while stream.session.is_handshaking() { + ready!(stream.handshake(cx))?; + } + + // write early data (fallback) + if !stream.session.is_early_data_accepted() { + while *pos < data.len() { + let len = ready!(stream.as_mut_pin().poll_write(cx, &data[*pos..]))?; + *pos += len; + } + } + + this.state = TlsState::Stream; + + if let Some(waker) = this.early_waker.take() { + waker.wake(); + } + } + } + + stream.as_mut_pin().poll_flush(cx) + } + + fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + // complete handshake + #[cfg(feature = "early-data")] + if matches!(self.state, TlsState::EarlyData(..)) { + ready!(self.as_mut().poll_flush(cx))?; + } + + if self.state.writeable() { + self.session.send_close_notify(); + self.state.shutdown_write(); + } + + let this = self.get_mut(); + let mut stream = + Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable()); + stream.as_mut_pin().poll_close(cx) + } +} diff --git a/src/async_rustls/src/common/handshake.rs b/src/async_rustls/src/common/handshake.rs new file mode 100644 index 00000000..fee2d3f5 --- /dev/null +++ b/src/async_rustls/src/common/handshake.rs @@ -0,0 +1,70 @@ +use crate::common::{Stream, TlsState}; +use core::future::Future; +use core::ops::{Deref, DerefMut}; +use core::pin::Pin; +use core::task::{Context, Poll}; +use futures_io::{AsyncRead, AsyncWrite}; +use rust_std_stub::{io, mem}; +use rustls::{ConnectionCommon, SideData}; + +pub(crate) trait IoSession { + type Io; + type Session; + + fn skip_handshake(&self) -> bool; + fn get_mut(&mut self) -> (&mut TlsState, &mut Self::Io, &mut Self::Session); + fn into_io(self) -> Self::Io; +} + +pub(crate) enum MidHandshake { + Handshaking(IS), + End, + Error { io: IS::Io, error: io::Error }, +} + +impl Future for MidHandshake +where + IS: IoSession + Unpin, + IS::Io: AsyncRead + AsyncWrite + Unpin, + IS::Session: DerefMut + Deref> + Unpin, + SD: SideData, +{ + type Output = Result; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.get_mut(); + + let mut stream = match mem::replace(this, MidHandshake::End) { + MidHandshake::Handshaking(stream) => stream, + // Starting the handshake returned an error; fail the future immediately. + MidHandshake::Error { io, error } => return Poll::Ready(Err((error, io))), + _ => panic!("unexpected polling after handshake"), + }; + + if !stream.skip_handshake() { + let (state, io, session) = stream.get_mut(); + let mut tls_stream = Stream::new(io, session).set_eof(!state.readable()); + + macro_rules! try_poll { + ( $e:expr ) => { + match $e { + Poll::Ready(Ok(_)) => (), + Poll::Ready(Err(err)) => return Poll::Ready(Err((err, stream.into_io()))), + Poll::Pending => { + *this = MidHandshake::Handshaking(stream); + return Poll::Pending; + } + } + }; + } + + while tls_stream.session.is_handshaking() { + try_poll!(tls_stream.handshake(cx)); + } + + try_poll!(Pin::new(&mut tls_stream).poll_flush(cx)); + } + + Poll::Ready(Ok(stream)) + } +} diff --git a/src/async_rustls/src/common/mod.rs b/src/async_rustls/src/common/mod.rs new file mode 100644 index 00000000..860376e6 --- /dev/null +++ b/src/async_rustls/src/common/mod.rs @@ -0,0 +1,360 @@ +mod handshake; + +#[cfg(feature = "early-data")] +use alloc::vec::Vec; +use core::ops::{Deref, DerefMut}; +use core::pin::Pin; +use core::task::{Context, Poll}; +use futures_io::{AsyncRead, AsyncWrite}; +pub(crate) use handshake::{IoSession, MidHandshake}; +use rust_std_stub::io::{self, IoSlice, Read, Write}; +use rustls::{ConnectionCommon, SideData}; + +#[derive(Debug)] +pub enum TlsState { + #[cfg(feature = "early-data")] + EarlyData(usize, Vec), + Stream, + ReadShutdown, + WriteShutdown, + FullyShutdown, +} + +impl TlsState { + #[inline] + pub fn shutdown_read(&mut self) { + match *self { + TlsState::WriteShutdown | TlsState::FullyShutdown => *self = TlsState::FullyShutdown, + _ => *self = TlsState::ReadShutdown, + } + } + + #[inline] + pub fn shutdown_write(&mut self) { + match *self { + TlsState::ReadShutdown | TlsState::FullyShutdown => *self = TlsState::FullyShutdown, + _ => *self = TlsState::WriteShutdown, + } + } + + #[inline] + pub fn writeable(&self) -> bool { + !matches!(*self, TlsState::WriteShutdown | TlsState::FullyShutdown) + } + + #[inline] + pub fn readable(&self) -> bool { + !matches!(*self, TlsState::ReadShutdown | TlsState::FullyShutdown) + } + + #[inline] + #[cfg(feature = "early-data")] + pub fn is_early_data(&self) -> bool { + matches!(self, TlsState::EarlyData(..)) + } + + #[inline] + #[cfg(not(feature = "early-data"))] + pub fn is_early_data(&self) -> bool { + false + } +} + +pub struct Stream<'a, IO, S> { + pub io: &'a mut IO, + pub session: &'a mut S, + pub eof: bool, +} + +impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S, SD> Stream<'a, IO, S> +where + S: DerefMut + Deref>, + SD: SideData, +{ + pub fn new(io: &'a mut IO, session: &'a mut S) -> Self { + Stream { + io, + session, + // The state so far is only used to detect EOF, so either Stream + // or EarlyData state should both be all right. + eof: false, + } + } + + pub fn set_eof(mut self, eof: bool) -> Self { + self.eof = eof; + self + } + + pub fn as_mut_pin(&mut self) -> Pin<&mut Self> { + Pin::new(self) + } + + pub fn read_io(&mut self, cx: &mut Context) -> Poll> { + let mut reader = SyncReadAdapter { io: self.io, cx }; + + let n = match self.session.read_tls(&mut reader) { + Ok(n) => n, + Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => return Poll::Pending, + Err(err) => return Poll::Ready(Err(err)), + }; + + let stats = self.session.process_new_packets().map_err(|err| { + // In case we have an alert to send describing this error, + // try a last-gasp write -- but don't predate the primary + // error. + let _ = self.write_io(cx); + + io::Error::new(io::ErrorKind::InvalidData, err) + })?; + + if stats.peer_has_closed() && self.session.is_handshaking() { + return Poll::Ready(Err(io::Error::new( + io::ErrorKind::UnexpectedEof, + "tls handshake alert", + ))); + } + + Poll::Ready(Ok(n)) + } + + pub fn write_io(&mut self, cx: &mut Context) -> Poll> { + struct Writer<'a, 'b, T> { + io: &'a mut T, + cx: &'a mut Context<'b>, + } + + impl<'a, 'b, T: Unpin> Writer<'a, 'b, T> { + #[inline] + fn poll_with( + &mut self, + f: impl FnOnce(Pin<&mut T>, &mut Context<'_>) -> Poll>, + ) -> io::Result { + match f(Pin::new(self.io), self.cx) { + Poll::Ready(result) => result, + Poll::Pending => Err(io::ErrorKind::WouldBlock.into()), + } + } + } + + impl<'a, 'b, T: AsyncWrite + Unpin> Write for Writer<'a, 'b, T> { + #[inline] + fn write(&mut self, buf: &[u8]) -> io::Result { + self.poll_with(|io, cx| io.poll_write(cx, buf)) + } + + #[inline] + fn write_vectored(&mut self, bufs: &[IoSlice<'_>]) -> io::Result { + self.poll_with(|io, cx| io.poll_write_vectored(cx, bufs)) + } + + #[inline] + fn flush(&mut self) -> io::Result<()> { + self.poll_with(|io, cx| io.poll_flush(cx)) + } + } + + let mut writer = Writer { io: self.io, cx }; + + match self.session.write_tls(&mut writer) { + Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => Poll::Pending, + result => Poll::Ready(result), + } + } + + pub fn handshake(&mut self, cx: &mut Context) -> Poll> { + let mut wrlen = 0; + let mut rdlen = 0; + + loop { + let mut write_would_block = false; + let mut read_would_block = false; + let mut need_flush = false; + + while self.session.wants_write() { + match self.write_io(cx) { + Poll::Ready(Ok(n)) => { + wrlen += n; + need_flush = true; + } + Poll::Pending => { + write_would_block = true; + break; + } + Poll::Ready(Err(err)) => return Poll::Ready(Err(err)), + } + } + + if need_flush { + match Pin::new(&mut self.io).poll_flush(cx) { + Poll::Ready(Ok(())) => (), + Poll::Ready(Err(err)) => return Poll::Ready(Err(err)), + Poll::Pending => write_would_block = true, + } + } + + while !self.eof && self.session.wants_read() { + match self.read_io(cx) { + Poll::Ready(Ok(0)) => self.eof = true, + Poll::Ready(Ok(n)) => rdlen += n, + Poll::Pending => { + read_would_block = true; + break; + } + Poll::Ready(Err(err)) => return Poll::Ready(Err(err)), + } + } + + return match (self.eof, self.session.is_handshaking()) { + (true, true) => { + let err = io::Error::new(io::ErrorKind::UnexpectedEof, "tls handshake eof"); + Poll::Ready(Err(err)) + } + (_, false) => Poll::Ready(Ok((rdlen, wrlen))), + (_, true) if write_would_block || read_would_block => { + if rdlen != 0 || wrlen != 0 { + Poll::Ready(Ok((rdlen, wrlen))) + } else { + Poll::Pending + } + } + (..) => continue, + }; + } + } +} + +impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S, SD> AsyncRead for Stream<'a, IO, S> +where + S: DerefMut + Deref>, + SD: SideData, +{ + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut [u8], + ) -> Poll> { + let mut io_pending = false; + + // read a packet + while !self.eof && self.session.wants_read() { + match self.read_io(cx) { + Poll::Ready(Ok(0)) => { + self.eof = true; + break; + } + Poll::Ready(Ok(_)) => (), + Poll::Pending => { + io_pending = true; + break; + } + Poll::Ready(Err(err)) => return Poll::Ready(Err(err)), + } + } + + match self.session.reader().read(buf) { + // If Rustls returns `Ok(0)` (while `buf` is non-empty), the peer closed the + // connection with a `CloseNotify` message and no more data will be forthcoming. + // + // Rustls yielded more data: advance the buffer, then see if more data is coming. + // + // We don't need to modify `self.eof` here, because it is only a temporary mark. + // rustls will only return 0 if is has received `CloseNotify`, + // in which case no additional processing is required. + Ok(n) => Poll::Ready(Ok(n)), + + // Rustls doesn't have more data to yield, but it believes the connection is open. + Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => { + if !io_pending { + // If `wants_read()` is satisfied, rustls will not return `WouldBlock`. + // but if it does, we can try again. + // + // If the rustls state is abnormal, it may cause a cyclic wakeup. + // but tokio's cooperative budget will prevent infinite wakeup. + cx.waker().wake_by_ref(); + } + + Poll::Pending + } + + Err(err) => Poll::Ready(Err(err)), + } + } +} + +impl<'a, IO: AsyncRead + AsyncWrite + Unpin, C, SD> AsyncWrite for Stream<'a, IO, C> +where + C: DerefMut + Deref>, + SD: SideData, +{ + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context, + buf: &[u8], + ) -> Poll> { + let mut pos = 0; + + while pos != buf.len() { + let mut would_block = false; + + match self.session.writer().write(&buf[pos..]) { + Ok(n) => pos += n, + Err(err) => return Poll::Ready(Err(err)), + }; + + while self.session.wants_write() { + match self.write_io(cx) { + Poll::Ready(Ok(0)) | Poll::Pending => { + would_block = true; + break; + } + Poll::Ready(Ok(_)) => (), + Poll::Ready(Err(err)) => return Poll::Ready(Err(err)), + } + } + + return match (pos, would_block) { + (0, true) => Poll::Pending, + (n, true) => Poll::Ready(Ok(n)), + (_, false) => continue, + }; + } + + Poll::Ready(Ok(pos)) + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + self.session.writer().flush()?; + while self.session.wants_write() { + ready!(self.write_io(cx))?; + } + Pin::new(&mut self.io).poll_flush(cx) + } + + fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + while self.session.wants_write() { + ready!(self.write_io(cx))?; + } + Pin::new(&mut self.io).poll_close(cx) + } +} + +/// An adapter that implements a [`Read`] interface for [`AsyncRead`] types and an +/// associated [`Context`]. +/// +/// Turns `Poll::Pending` into `WouldBlock`. +pub struct SyncReadAdapter<'a, 'b, T> { + pub io: &'a mut T, + pub cx: &'a mut Context<'b>, +} + +impl<'a, 'b, T: AsyncRead + Unpin> Read for SyncReadAdapter<'a, 'b, T> { + #[inline] + fn read(&mut self, buf: &mut [u8]) -> io::Result { + match Pin::new(&mut self.io).poll_read(self.cx, buf) { + Poll::Ready(Ok(n)) => Ok(n), + Poll::Ready(Err(err)) => Err(err), + Poll::Pending => Err(io::ErrorKind::WouldBlock.into()), + } + } +} diff --git a/src/async_rustls/src/lib.rs b/src/async_rustls/src/lib.rs new file mode 100644 index 00000000..57ebd43c --- /dev/null +++ b/src/async_rustls/src/lib.rs @@ -0,0 +1,450 @@ +#![no_std] + +extern crate alloc; + +macro_rules! ready { + ( $e:expr ) => { + match $e { + core::task::Poll::Ready(t) => t, + core::task::Poll::Pending => return core::task::Poll::Pending, + } + }; +} + +pub mod client; +mod common; +pub mod server; + +use alloc::sync::Arc; +#[cfg(feature = "early-data")] +use alloc::vec::Vec; +use common::{MidHandshake, Stream, TlsState}; +use core::future::Future; +use core::pin::Pin; +use core::task::{Context, Poll}; +use futures_io::{AsyncRead, AsyncWrite}; +use rust_std_stub::io; +use rustls::crypto::ring::Ring; +use rustls::{ClientConfig, ClientConnection, CommonState, ServerConfig, ServerConnection}; + +pub use rustls; + +/// A wrapper around a `rustls::ClientConfig`, providing an async `connect` method. +#[derive(Clone)] +pub struct TlsConnector { + inner: Arc>, + #[cfg(feature = "early-data")] + early_data: bool, +} + +/// A wrapper around a `rustls::ServerConfig`, providing an async `accept` method. +#[derive(Clone)] +pub struct TlsAcceptor { + inner: Arc>, +} + +impl From>> for TlsConnector { + fn from(inner: Arc>) -> TlsConnector { + TlsConnector { + inner, + #[cfg(feature = "early-data")] + early_data: false, + } + } +} + +impl From>> for TlsAcceptor { + fn from(inner: Arc>) -> TlsAcceptor { + TlsAcceptor { inner } + } +} + +impl TlsConnector { + /// Enable 0-RTT. + /// + /// If you want to use 0-RTT, + /// You must also set `ClientConfig.enable_early_data` to `true`. + #[cfg(feature = "early-data")] + pub fn early_data(mut self, flag: bool) -> TlsConnector { + self.early_data = flag; + self + } + + #[inline] + pub fn connect(&self, domain: rustls::ServerName, stream: IO) -> Connect + where + IO: AsyncRead + AsyncWrite + Unpin, + { + self.connect_with(domain, stream, |_| ()) + } + + pub fn connect_with(&self, domain: rustls::ServerName, stream: IO, f: F) -> Connect + where + IO: AsyncRead + AsyncWrite + Unpin, + F: FnOnce(&mut ClientConnection), + { + let mut session = match ClientConnection::new(self.inner.clone(), domain) { + Ok(session) => session, + Err(error) => { + return Connect(MidHandshake::Error { + io: stream, + // TODO(eliza): should this really return an `io::Error`? + // Probably not... + error: io::Error::new(io::ErrorKind::Other, error), + }); + } + }; + f(&mut session); + + Connect(MidHandshake::Handshaking(client::TlsStream { + io: stream, + + #[cfg(not(feature = "early-data"))] + state: TlsState::Stream, + + #[cfg(feature = "early-data")] + state: if self.early_data && session.early_data().is_some() { + TlsState::EarlyData(0, Vec::new()) + } else { + TlsState::Stream + }, + + #[cfg(feature = "early-data")] + early_waker: None, + + session, + })) + } +} + +impl TlsAcceptor { + #[inline] + pub fn accept(&self, stream: IO) -> Accept + where + IO: AsyncRead + AsyncWrite + Unpin, + { + self.accept_with(stream, |_| ()) + } + + pub fn accept_with(&self, stream: IO, f: F) -> Accept + where + IO: AsyncRead + AsyncWrite + Unpin, + F: FnOnce(&mut ServerConnection), + { + let mut session = match ServerConnection::new(self.inner.clone()) { + Ok(session) => session, + Err(error) => { + return Accept(MidHandshake::Error { + io: stream, + // TODO(eliza): should this really return an `io::Error`? + // Probably not... + error: io::Error::new(io::ErrorKind::Other, error), + }); + } + }; + f(&mut session); + + Accept(MidHandshake::Handshaking(server::TlsStream { + session, + io: stream, + state: TlsState::Stream, + })) + } +} + +pub struct LazyConfigAcceptor { + acceptor: rustls::server::Acceptor, + io: Option, +} + +impl LazyConfigAcceptor +where + IO: AsyncRead + AsyncWrite + Unpin, +{ + #[inline] + pub fn new(acceptor: rustls::server::Acceptor, io: IO) -> Self { + Self { + acceptor, + io: Some(io), + } + } +} + +impl Future for LazyConfigAcceptor +where + IO: AsyncRead + AsyncWrite + Unpin, +{ + type Output = Result, io::Error>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.get_mut(); + loop { + let io = match this.io.as_mut() { + Some(io) => io, + None => { + panic!("Acceptor cannot be polled after acceptance."); + } + }; + + let mut reader = common::SyncReadAdapter { io, cx }; + match this.acceptor.read_tls(&mut reader) { + Ok(0) => return Poll::Ready(Err(io::ErrorKind::UnexpectedEof.into())), + Ok(_) => {} + Err(e) if e.kind() == io::ErrorKind::WouldBlock => return Poll::Pending, + Err(e) => return Poll::Ready(Err(e)), + } + + match this.acceptor.accept() { + Ok(Some(accepted)) => { + let io = this.io.take().unwrap(); + return Poll::Ready(Ok(StartHandshake { accepted, io })); + } + Ok(None) => continue, + Err(err) => { + return Poll::Ready(Err(io::Error::new(io::ErrorKind::InvalidInput, err))) + } + } + } + } +} + +pub struct StartHandshake { + accepted: rustls::server::Accepted, + io: IO, +} + +impl StartHandshake +where + IO: AsyncRead + AsyncWrite + Unpin, +{ + pub fn client_hello(&self) -> rustls::server::ClientHello<'_> { + self.accepted.client_hello() + } + + pub fn into_stream(self, config: Arc>) -> Accept { + self.into_stream_with(config, |_| ()) + } + + pub fn into_stream_with(self, config: Arc>, f: F) -> Accept + where + F: FnOnce(&mut ServerConnection), + { + let mut conn = match self.accepted.into_connection(config) { + Ok(conn) => conn, + Err(error) => { + return Accept(MidHandshake::Error { + io: self.io, + // TODO(eliza): should this really return an `io::Error`? + // Probably not... + error: io::Error::new(io::ErrorKind::Other, error), + }); + } + }; + f(&mut conn); + + Accept(MidHandshake::Handshaking(server::TlsStream { + session: conn, + io: self.io, + state: TlsState::Stream, + })) + } +} + +/// Future returned from `TlsConnector::connect` which will resolve +/// once the connection handshake has finished. +pub struct Connect(MidHandshake>); + +/// Future returned from `TlsAcceptor::accept` which will resolve +/// once the accept handshake has finished. +pub struct Accept(MidHandshake>); + +/// Like [Connect], but returns `IO` on failure. +pub struct FallibleConnect(MidHandshake>); + +/// Like [Accept], but returns `IO` on failure. +pub struct FallibleAccept(MidHandshake>); + +impl Connect { + #[inline] + pub fn into_fallible(self) -> FallibleConnect { + FallibleConnect(self.0) + } + + pub fn get_ref(&self) -> Option<&IO> { + match &self.0 { + MidHandshake::Handshaking(sess) => Some(sess.get_ref().0), + MidHandshake::Error { io, .. } => Some(io), + MidHandshake::End => None, + } + } + + pub fn get_mut(&mut self) -> Option<&mut IO> { + match &mut self.0 { + MidHandshake::Handshaking(sess) => Some(sess.get_mut().0), + MidHandshake::Error { io, .. } => Some(io), + MidHandshake::End => None, + } + } +} + +impl Accept { + #[inline] + pub fn into_fallible(self) -> FallibleAccept { + FallibleAccept(self.0) + } + + pub fn get_ref(&self) -> Option<&IO> { + match &self.0 { + MidHandshake::Handshaking(sess) => Some(sess.get_ref().0), + MidHandshake::Error { io, .. } => Some(io), + MidHandshake::End => None, + } + } + + pub fn get_mut(&mut self) -> Option<&mut IO> { + match &mut self.0 { + MidHandshake::Handshaking(sess) => Some(sess.get_mut().0), + MidHandshake::Error { io, .. } => Some(io), + MidHandshake::End => None, + } + } +} + +impl Future for Connect { + type Output = io::Result>; + + #[inline] + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + Pin::new(&mut self.0).poll(cx).map_err(|(err, _)| err) + } +} + +impl Future for Accept { + type Output = io::Result>; + + #[inline] + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + Pin::new(&mut self.0).poll(cx).map_err(|(err, _)| err) + } +} + +impl Future for FallibleConnect { + type Output = Result, (io::Error, IO)>; + + #[inline] + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + Pin::new(&mut self.0).poll(cx) + } +} + +impl Future for FallibleAccept { + type Output = Result, (io::Error, IO)>; + + #[inline] + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + Pin::new(&mut self.0).poll(cx) + } +} + +/// Unified TLS stream type +/// +/// This abstracts over the inner `client::TlsStream` and `server::TlsStream`, so you can use +/// a single type to keep both client- and server-initiated TLS-encrypted connections. +#[allow(clippy::large_enum_variant)] // https://github.com/rust-lang/rust-clippy/issues/9798 +#[derive(Debug)] +pub enum TlsStream { + Client(client::TlsStream), + Server(server::TlsStream), +} + +impl TlsStream { + pub fn get_ref(&self) -> (&T, &CommonState) { + use TlsStream::*; + match self { + Client(io) => { + let (io, session) = io.get_ref(); + (io, session) + } + Server(io) => { + let (io, session) = io.get_ref(); + (io, session) + } + } + } + + pub fn get_mut(&mut self) -> (&mut T, &mut CommonState) { + use TlsStream::*; + match self { + Client(io) => { + let (io, session) = io.get_mut(); + (io, &mut *session) + } + Server(io) => { + let (io, session) = io.get_mut(); + (io, &mut *session) + } + } + } +} + +impl From> for TlsStream { + fn from(s: client::TlsStream) -> Self { + Self::Client(s) + } +} + +impl From> for TlsStream { + fn from(s: server::TlsStream) -> Self { + Self::Server(s) + } +} + +impl AsyncRead for TlsStream +where + T: AsyncRead + AsyncWrite + Unpin, +{ + #[inline] + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut [u8], + ) -> Poll> { + match self.get_mut() { + TlsStream::Client(x) => Pin::new(x).poll_read(cx, buf), + TlsStream::Server(x) => Pin::new(x).poll_read(cx, buf), + } + } +} + +impl AsyncWrite for TlsStream +where + T: AsyncRead + AsyncWrite + Unpin, +{ + #[inline] + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + match self.get_mut() { + TlsStream::Client(x) => Pin::new(x).poll_write(cx, buf), + TlsStream::Server(x) => Pin::new(x).poll_write(cx, buf), + } + } + + #[inline] + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match self.get_mut() { + TlsStream::Client(x) => Pin::new(x).poll_flush(cx), + TlsStream::Server(x) => Pin::new(x).poll_flush(cx), + } + } + + #[inline] + fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match self.get_mut() { + TlsStream::Client(x) => Pin::new(x).poll_close(cx), + TlsStream::Server(x) => Pin::new(x).poll_close(cx), + } + } +} diff --git a/src/async_rustls/src/server.rs b/src/async_rustls/src/server.rs new file mode 100644 index 00000000..b5476b45 --- /dev/null +++ b/src/async_rustls/src/server.rs @@ -0,0 +1,122 @@ +use super::*; +use crate::common::IoSession; + +/// A wrapper around an underlying raw stream which implements the TLS or SSL +/// protocol. +#[derive(Debug)] +pub struct TlsStream { + pub(crate) io: IO, + pub(crate) session: ServerConnection, + pub(crate) state: TlsState, +} + +impl TlsStream { + #[inline] + pub fn get_ref(&self) -> (&IO, &ServerConnection) { + (&self.io, &self.session) + } + + #[inline] + pub fn get_mut(&mut self) -> (&mut IO, &mut ServerConnection) { + (&mut self.io, &mut self.session) + } + + #[inline] + pub fn into_inner(self) -> (IO, ServerConnection) { + (self.io, self.session) + } +} + +impl IoSession for TlsStream { + type Io = IO; + type Session = ServerConnection; + + #[inline] + fn skip_handshake(&self) -> bool { + false + } + + #[inline] + fn get_mut(&mut self) -> (&mut TlsState, &mut Self::Io, &mut Self::Session) { + (&mut self.state, &mut self.io, &mut self.session) + } + + #[inline] + fn into_io(self) -> Self::Io { + self.io + } +} + +impl AsyncRead for TlsStream +where + IO: AsyncRead + AsyncWrite + Unpin, +{ + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut [u8], + ) -> Poll> { + let this = self.get_mut(); + let mut stream = + Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable()); + + match &this.state { + TlsState::Stream | TlsState::WriteShutdown => { + match stream.as_mut_pin().poll_read(cx, buf) { + Poll::Ready(Ok(n)) => { + if n == 0 || stream.eof { + this.state.shutdown_read(); + } + + Poll::Ready(Ok(n)) + } + Poll::Ready(Err(err)) if err.kind() == io::ErrorKind::UnexpectedEof => { + this.state.shutdown_read(); + Poll::Ready(Err(err)) + } + output => output, + } + } + TlsState::ReadShutdown | TlsState::FullyShutdown => Poll::Ready(Ok(0)), + #[cfg(feature = "early-data")] + s => unreachable!("server TLS can not hit this state: {:?}", s), + } + } +} + +impl AsyncWrite for TlsStream +where + IO: AsyncRead + AsyncWrite + Unpin, +{ + /// Note: that it does not guarantee the final data to be sent. + /// To be cautious, you must manually call `flush`. + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + let this = self.get_mut(); + let mut stream = + Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable()); + stream.as_mut_pin().poll_write(cx, buf) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = self.get_mut(); + let mut stream = + Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable()); + stream.as_mut_pin().poll_flush(cx) + } + + fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + if self.state.writeable() { + self.session.send_close_notify(); + self.state.shutdown_write(); + } + + let this = self.get_mut(); + let mut stream = + Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable()); + stream.as_mut_pin().poll_close(cx) + } +}