diff --git a/Cargo.lock b/Cargo.lock index f79efe08..5088883a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -68,6 +68,15 @@ version = "1.0.72" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3b13c32d80ecc7ab747b80c3784bce54ee8a7a0cc4fbda9bf4cda2cf6fe90854" +[[package]] +name = "async_rustls" +version = "0.1.0" +dependencies = [ + "futures-io", + "rust_std_stub", + "rustls", +] + [[package]] name = "atomic_refcell" version = "0.1.10" @@ -345,6 +354,13 @@ dependencies = [ "libc", ] +[[package]] +name = "futures-io" +version = "0.1.0" +dependencies = [ + "rust_std_stub", +] + [[package]] name = "generic-array" version = "0.14.7" diff --git a/Cargo.toml b/Cargo.toml index e1d08b9f..e71596dc 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,7 @@ [workspace] members = [ + "src/async_rustls", "src/attestation", "src/crypto", "src/devices/pci", diff --git a/src/async_rustls/Cargo.toml b/src/async_rustls/Cargo.toml new file mode 100644 index 00000000..6e576896 --- /dev/null +++ b/src/async_rustls/Cargo.toml @@ -0,0 +1,15 @@ +[package] +name = "async_rustls" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +rust_std_stub = { path = "../std-support/rust-std-stub" } +futures-io = { path = "../std-support/futures-io" } +rustls = { path = "../../deps/rustls/rustls", default-features = false, features = ["no_std", "alloc"] } + +[features] +dangerous_configuration = ["rustls/dangerous_configuration"] +early-data = [] diff --git a/src/async_rustls/src/client.rs b/src/async_rustls/src/client.rs new file mode 100644 index 00000000..fe6bfbd0 --- /dev/null +++ b/src/async_rustls/src/client.rs @@ -0,0 +1,220 @@ +use super::*; +use crate::common::IoSession; + +/// A wrapper around an underlying raw stream which implements the TLS or SSL +/// protocol. +#[derive(Debug)] +pub struct TlsStream { + pub(crate) io: IO, + pub(crate) session: ClientConnection, + pub(crate) state: TlsState, + + #[cfg(feature = "early-data")] + pub(crate) early_waker: Option, +} + +impl TlsStream { + #[inline] + pub fn get_ref(&self) -> (&IO, &ClientConnection) { + (&self.io, &self.session) + } + + #[inline] + pub fn get_mut(&mut self) -> (&mut IO, &mut ClientConnection) { + (&mut self.io, &mut self.session) + } + + #[inline] + pub fn into_inner(self) -> (IO, ClientConnection) { + (self.io, self.session) + } +} + +impl IoSession for TlsStream { + type Io = IO; + type Session = ClientConnection; + + #[inline] + fn skip_handshake(&self) -> bool { + self.state.is_early_data() + } + + #[inline] + fn get_mut(&mut self) -> (&mut TlsState, &mut Self::Io, &mut Self::Session) { + (&mut self.state, &mut self.io, &mut self.session) + } + + #[inline] + fn into_io(self) -> Self::Io { + self.io + } +} + +impl AsyncRead for TlsStream +where + IO: AsyncRead + AsyncWrite + Unpin, +{ + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut [u8], + ) -> Poll> { + match self.state { + #[cfg(feature = "early-data")] + TlsState::EarlyData(..) => { + let this = self.get_mut(); + + // In the EarlyData state, we have not really established a Tls connection. + // Before writing data through `AsyncWrite` and completing the tls handshake, + // we ignore read readiness and return to pending. + // + // In order to avoid event loss, + // we need to register a waker and wake it up after tls is connected. + if this + .early_waker + .as_ref() + .filter(|waker| cx.waker().will_wake(waker)) + .is_none() + { + this.early_waker = Some(cx.waker().clone()); + } + + Poll::Pending + } + TlsState::Stream | TlsState::WriteShutdown => { + let this = self.get_mut(); + let mut stream = + Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable()); + + match stream.as_mut_pin().poll_read(cx, buf) { + Poll::Ready(Ok(n)) => { + if n == 0 || stream.eof { + this.state.shutdown_read(); + } + + Poll::Ready(Ok(n)) + } + Poll::Ready(Err(err)) if err.kind() == io::ErrorKind::ConnectionAborted => { + this.state.shutdown_read(); + Poll::Ready(Err(err)) + } + output => output, + } + } + TlsState::ReadShutdown | TlsState::FullyShutdown => Poll::Ready(Ok(0)), + } + } +} + +impl AsyncWrite for TlsStream +where + IO: AsyncRead + AsyncWrite + Unpin, +{ + /// Note: that it does not guarantee the final data to be sent. + /// To be cautious, you must manually call `flush`. + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + let this = self.get_mut(); + let mut stream = + Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable()); + + #[allow(clippy::match_single_binding)] + match this.state { + #[cfg(feature = "early-data")] + TlsState::EarlyData(ref mut pos, ref mut data) => { + use rust_std_stub::io::Write; + + // write early data + if let Some(mut early_data) = stream.session.early_data() { + let len = match early_data.write(buf) { + Ok(n) => n, + Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => { + return Poll::Pending + } + Err(err) => return Poll::Ready(Err(err)), + }; + if len != 0 { + data.extend_from_slice(&buf[..len]); + return Poll::Ready(Ok(len)); + } + } + + // complete handshake + while stream.session.is_handshaking() { + ready!(stream.handshake(cx))?; + } + + // write early data (fallback) + if !stream.session.is_early_data_accepted() { + while *pos < data.len() { + let len = ready!(stream.as_mut_pin().poll_write(cx, &data[*pos..]))?; + *pos += len; + } + } + + // end + this.state = TlsState::Stream; + + if let Some(waker) = this.early_waker.take() { + waker.wake(); + } + + stream.as_mut_pin().poll_write(cx, buf) + } + _ => stream.as_mut_pin().poll_write(cx, buf), + } + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = self.get_mut(); + let mut stream = + Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable()); + + #[cfg(feature = "early-data")] + { + if let TlsState::EarlyData(ref mut pos, ref mut data) = this.state { + // complete handshake + while stream.session.is_handshaking() { + ready!(stream.handshake(cx))?; + } + + // write early data (fallback) + if !stream.session.is_early_data_accepted() { + while *pos < data.len() { + let len = ready!(stream.as_mut_pin().poll_write(cx, &data[*pos..]))?; + *pos += len; + } + } + + this.state = TlsState::Stream; + + if let Some(waker) = this.early_waker.take() { + waker.wake(); + } + } + } + + stream.as_mut_pin().poll_flush(cx) + } + + fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + // complete handshake + #[cfg(feature = "early-data")] + if matches!(self.state, TlsState::EarlyData(..)) { + ready!(self.as_mut().poll_flush(cx))?; + } + + if self.state.writeable() { + self.session.send_close_notify(); + self.state.shutdown_write(); + } + + let this = self.get_mut(); + let mut stream = + Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable()); + stream.as_mut_pin().poll_close(cx) + } +} diff --git a/src/async_rustls/src/common/handshake.rs b/src/async_rustls/src/common/handshake.rs new file mode 100644 index 00000000..fee2d3f5 --- /dev/null +++ b/src/async_rustls/src/common/handshake.rs @@ -0,0 +1,70 @@ +use crate::common::{Stream, TlsState}; +use core::future::Future; +use core::ops::{Deref, DerefMut}; +use core::pin::Pin; +use core::task::{Context, Poll}; +use futures_io::{AsyncRead, AsyncWrite}; +use rust_std_stub::{io, mem}; +use rustls::{ConnectionCommon, SideData}; + +pub(crate) trait IoSession { + type Io; + type Session; + + fn skip_handshake(&self) -> bool; + fn get_mut(&mut self) -> (&mut TlsState, &mut Self::Io, &mut Self::Session); + fn into_io(self) -> Self::Io; +} + +pub(crate) enum MidHandshake { + Handshaking(IS), + End, + Error { io: IS::Io, error: io::Error }, +} + +impl Future for MidHandshake +where + IS: IoSession + Unpin, + IS::Io: AsyncRead + AsyncWrite + Unpin, + IS::Session: DerefMut + Deref> + Unpin, + SD: SideData, +{ + type Output = Result; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.get_mut(); + + let mut stream = match mem::replace(this, MidHandshake::End) { + MidHandshake::Handshaking(stream) => stream, + // Starting the handshake returned an error; fail the future immediately. + MidHandshake::Error { io, error } => return Poll::Ready(Err((error, io))), + _ => panic!("unexpected polling after handshake"), + }; + + if !stream.skip_handshake() { + let (state, io, session) = stream.get_mut(); + let mut tls_stream = Stream::new(io, session).set_eof(!state.readable()); + + macro_rules! try_poll { + ( $e:expr ) => { + match $e { + Poll::Ready(Ok(_)) => (), + Poll::Ready(Err(err)) => return Poll::Ready(Err((err, stream.into_io()))), + Poll::Pending => { + *this = MidHandshake::Handshaking(stream); + return Poll::Pending; + } + } + }; + } + + while tls_stream.session.is_handshaking() { + try_poll!(tls_stream.handshake(cx)); + } + + try_poll!(Pin::new(&mut tls_stream).poll_flush(cx)); + } + + Poll::Ready(Ok(stream)) + } +} diff --git a/src/async_rustls/src/common/mod.rs b/src/async_rustls/src/common/mod.rs new file mode 100644 index 00000000..860376e6 --- /dev/null +++ b/src/async_rustls/src/common/mod.rs @@ -0,0 +1,360 @@ +mod handshake; + +#[cfg(feature = "early-data")] +use alloc::vec::Vec; +use core::ops::{Deref, DerefMut}; +use core::pin::Pin; +use core::task::{Context, Poll}; +use futures_io::{AsyncRead, AsyncWrite}; +pub(crate) use handshake::{IoSession, MidHandshake}; +use rust_std_stub::io::{self, IoSlice, Read, Write}; +use rustls::{ConnectionCommon, SideData}; + +#[derive(Debug)] +pub enum TlsState { + #[cfg(feature = "early-data")] + EarlyData(usize, Vec), + Stream, + ReadShutdown, + WriteShutdown, + FullyShutdown, +} + +impl TlsState { + #[inline] + pub fn shutdown_read(&mut self) { + match *self { + TlsState::WriteShutdown | TlsState::FullyShutdown => *self = TlsState::FullyShutdown, + _ => *self = TlsState::ReadShutdown, + } + } + + #[inline] + pub fn shutdown_write(&mut self) { + match *self { + TlsState::ReadShutdown | TlsState::FullyShutdown => *self = TlsState::FullyShutdown, + _ => *self = TlsState::WriteShutdown, + } + } + + #[inline] + pub fn writeable(&self) -> bool { + !matches!(*self, TlsState::WriteShutdown | TlsState::FullyShutdown) + } + + #[inline] + pub fn readable(&self) -> bool { + !matches!(*self, TlsState::ReadShutdown | TlsState::FullyShutdown) + } + + #[inline] + #[cfg(feature = "early-data")] + pub fn is_early_data(&self) -> bool { + matches!(self, TlsState::EarlyData(..)) + } + + #[inline] + #[cfg(not(feature = "early-data"))] + pub fn is_early_data(&self) -> bool { + false + } +} + +pub struct Stream<'a, IO, S> { + pub io: &'a mut IO, + pub session: &'a mut S, + pub eof: bool, +} + +impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S, SD> Stream<'a, IO, S> +where + S: DerefMut + Deref>, + SD: SideData, +{ + pub fn new(io: &'a mut IO, session: &'a mut S) -> Self { + Stream { + io, + session, + // The state so far is only used to detect EOF, so either Stream + // or EarlyData state should both be all right. + eof: false, + } + } + + pub fn set_eof(mut self, eof: bool) -> Self { + self.eof = eof; + self + } + + pub fn as_mut_pin(&mut self) -> Pin<&mut Self> { + Pin::new(self) + } + + pub fn read_io(&mut self, cx: &mut Context) -> Poll> { + let mut reader = SyncReadAdapter { io: self.io, cx }; + + let n = match self.session.read_tls(&mut reader) { + Ok(n) => n, + Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => return Poll::Pending, + Err(err) => return Poll::Ready(Err(err)), + }; + + let stats = self.session.process_new_packets().map_err(|err| { + // In case we have an alert to send describing this error, + // try a last-gasp write -- but don't predate the primary + // error. + let _ = self.write_io(cx); + + io::Error::new(io::ErrorKind::InvalidData, err) + })?; + + if stats.peer_has_closed() && self.session.is_handshaking() { + return Poll::Ready(Err(io::Error::new( + io::ErrorKind::UnexpectedEof, + "tls handshake alert", + ))); + } + + Poll::Ready(Ok(n)) + } + + pub fn write_io(&mut self, cx: &mut Context) -> Poll> { + struct Writer<'a, 'b, T> { + io: &'a mut T, + cx: &'a mut Context<'b>, + } + + impl<'a, 'b, T: Unpin> Writer<'a, 'b, T> { + #[inline] + fn poll_with( + &mut self, + f: impl FnOnce(Pin<&mut T>, &mut Context<'_>) -> Poll>, + ) -> io::Result { + match f(Pin::new(self.io), self.cx) { + Poll::Ready(result) => result, + Poll::Pending => Err(io::ErrorKind::WouldBlock.into()), + } + } + } + + impl<'a, 'b, T: AsyncWrite + Unpin> Write for Writer<'a, 'b, T> { + #[inline] + fn write(&mut self, buf: &[u8]) -> io::Result { + self.poll_with(|io, cx| io.poll_write(cx, buf)) + } + + #[inline] + fn write_vectored(&mut self, bufs: &[IoSlice<'_>]) -> io::Result { + self.poll_with(|io, cx| io.poll_write_vectored(cx, bufs)) + } + + #[inline] + fn flush(&mut self) -> io::Result<()> { + self.poll_with(|io, cx| io.poll_flush(cx)) + } + } + + let mut writer = Writer { io: self.io, cx }; + + match self.session.write_tls(&mut writer) { + Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => Poll::Pending, + result => Poll::Ready(result), + } + } + + pub fn handshake(&mut self, cx: &mut Context) -> Poll> { + let mut wrlen = 0; + let mut rdlen = 0; + + loop { + let mut write_would_block = false; + let mut read_would_block = false; + let mut need_flush = false; + + while self.session.wants_write() { + match self.write_io(cx) { + Poll::Ready(Ok(n)) => { + wrlen += n; + need_flush = true; + } + Poll::Pending => { + write_would_block = true; + break; + } + Poll::Ready(Err(err)) => return Poll::Ready(Err(err)), + } + } + + if need_flush { + match Pin::new(&mut self.io).poll_flush(cx) { + Poll::Ready(Ok(())) => (), + Poll::Ready(Err(err)) => return Poll::Ready(Err(err)), + Poll::Pending => write_would_block = true, + } + } + + while !self.eof && self.session.wants_read() { + match self.read_io(cx) { + Poll::Ready(Ok(0)) => self.eof = true, + Poll::Ready(Ok(n)) => rdlen += n, + Poll::Pending => { + read_would_block = true; + break; + } + Poll::Ready(Err(err)) => return Poll::Ready(Err(err)), + } + } + + return match (self.eof, self.session.is_handshaking()) { + (true, true) => { + let err = io::Error::new(io::ErrorKind::UnexpectedEof, "tls handshake eof"); + Poll::Ready(Err(err)) + } + (_, false) => Poll::Ready(Ok((rdlen, wrlen))), + (_, true) if write_would_block || read_would_block => { + if rdlen != 0 || wrlen != 0 { + Poll::Ready(Ok((rdlen, wrlen))) + } else { + Poll::Pending + } + } + (..) => continue, + }; + } + } +} + +impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S, SD> AsyncRead for Stream<'a, IO, S> +where + S: DerefMut + Deref>, + SD: SideData, +{ + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut [u8], + ) -> Poll> { + let mut io_pending = false; + + // read a packet + while !self.eof && self.session.wants_read() { + match self.read_io(cx) { + Poll::Ready(Ok(0)) => { + self.eof = true; + break; + } + Poll::Ready(Ok(_)) => (), + Poll::Pending => { + io_pending = true; + break; + } + Poll::Ready(Err(err)) => return Poll::Ready(Err(err)), + } + } + + match self.session.reader().read(buf) { + // If Rustls returns `Ok(0)` (while `buf` is non-empty), the peer closed the + // connection with a `CloseNotify` message and no more data will be forthcoming. + // + // Rustls yielded more data: advance the buffer, then see if more data is coming. + // + // We don't need to modify `self.eof` here, because it is only a temporary mark. + // rustls will only return 0 if is has received `CloseNotify`, + // in which case no additional processing is required. + Ok(n) => Poll::Ready(Ok(n)), + + // Rustls doesn't have more data to yield, but it believes the connection is open. + Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => { + if !io_pending { + // If `wants_read()` is satisfied, rustls will not return `WouldBlock`. + // but if it does, we can try again. + // + // If the rustls state is abnormal, it may cause a cyclic wakeup. + // but tokio's cooperative budget will prevent infinite wakeup. + cx.waker().wake_by_ref(); + } + + Poll::Pending + } + + Err(err) => Poll::Ready(Err(err)), + } + } +} + +impl<'a, IO: AsyncRead + AsyncWrite + Unpin, C, SD> AsyncWrite for Stream<'a, IO, C> +where + C: DerefMut + Deref>, + SD: SideData, +{ + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context, + buf: &[u8], + ) -> Poll> { + let mut pos = 0; + + while pos != buf.len() { + let mut would_block = false; + + match self.session.writer().write(&buf[pos..]) { + Ok(n) => pos += n, + Err(err) => return Poll::Ready(Err(err)), + }; + + while self.session.wants_write() { + match self.write_io(cx) { + Poll::Ready(Ok(0)) | Poll::Pending => { + would_block = true; + break; + } + Poll::Ready(Ok(_)) => (), + Poll::Ready(Err(err)) => return Poll::Ready(Err(err)), + } + } + + return match (pos, would_block) { + (0, true) => Poll::Pending, + (n, true) => Poll::Ready(Ok(n)), + (_, false) => continue, + }; + } + + Poll::Ready(Ok(pos)) + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + self.session.writer().flush()?; + while self.session.wants_write() { + ready!(self.write_io(cx))?; + } + Pin::new(&mut self.io).poll_flush(cx) + } + + fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + while self.session.wants_write() { + ready!(self.write_io(cx))?; + } + Pin::new(&mut self.io).poll_close(cx) + } +} + +/// An adapter that implements a [`Read`] interface for [`AsyncRead`] types and an +/// associated [`Context`]. +/// +/// Turns `Poll::Pending` into `WouldBlock`. +pub struct SyncReadAdapter<'a, 'b, T> { + pub io: &'a mut T, + pub cx: &'a mut Context<'b>, +} + +impl<'a, 'b, T: AsyncRead + Unpin> Read for SyncReadAdapter<'a, 'b, T> { + #[inline] + fn read(&mut self, buf: &mut [u8]) -> io::Result { + match Pin::new(&mut self.io).poll_read(self.cx, buf) { + Poll::Ready(Ok(n)) => Ok(n), + Poll::Ready(Err(err)) => Err(err), + Poll::Pending => Err(io::ErrorKind::WouldBlock.into()), + } + } +} diff --git a/src/async_rustls/src/lib.rs b/src/async_rustls/src/lib.rs new file mode 100644 index 00000000..57ebd43c --- /dev/null +++ b/src/async_rustls/src/lib.rs @@ -0,0 +1,450 @@ +#![no_std] + +extern crate alloc; + +macro_rules! ready { + ( $e:expr ) => { + match $e { + core::task::Poll::Ready(t) => t, + core::task::Poll::Pending => return core::task::Poll::Pending, + } + }; +} + +pub mod client; +mod common; +pub mod server; + +use alloc::sync::Arc; +#[cfg(feature = "early-data")] +use alloc::vec::Vec; +use common::{MidHandshake, Stream, TlsState}; +use core::future::Future; +use core::pin::Pin; +use core::task::{Context, Poll}; +use futures_io::{AsyncRead, AsyncWrite}; +use rust_std_stub::io; +use rustls::crypto::ring::Ring; +use rustls::{ClientConfig, ClientConnection, CommonState, ServerConfig, ServerConnection}; + +pub use rustls; + +/// A wrapper around a `rustls::ClientConfig`, providing an async `connect` method. +#[derive(Clone)] +pub struct TlsConnector { + inner: Arc>, + #[cfg(feature = "early-data")] + early_data: bool, +} + +/// A wrapper around a `rustls::ServerConfig`, providing an async `accept` method. +#[derive(Clone)] +pub struct TlsAcceptor { + inner: Arc>, +} + +impl From>> for TlsConnector { + fn from(inner: Arc>) -> TlsConnector { + TlsConnector { + inner, + #[cfg(feature = "early-data")] + early_data: false, + } + } +} + +impl From>> for TlsAcceptor { + fn from(inner: Arc>) -> TlsAcceptor { + TlsAcceptor { inner } + } +} + +impl TlsConnector { + /// Enable 0-RTT. + /// + /// If you want to use 0-RTT, + /// You must also set `ClientConfig.enable_early_data` to `true`. + #[cfg(feature = "early-data")] + pub fn early_data(mut self, flag: bool) -> TlsConnector { + self.early_data = flag; + self + } + + #[inline] + pub fn connect(&self, domain: rustls::ServerName, stream: IO) -> Connect + where + IO: AsyncRead + AsyncWrite + Unpin, + { + self.connect_with(domain, stream, |_| ()) + } + + pub fn connect_with(&self, domain: rustls::ServerName, stream: IO, f: F) -> Connect + where + IO: AsyncRead + AsyncWrite + Unpin, + F: FnOnce(&mut ClientConnection), + { + let mut session = match ClientConnection::new(self.inner.clone(), domain) { + Ok(session) => session, + Err(error) => { + return Connect(MidHandshake::Error { + io: stream, + // TODO(eliza): should this really return an `io::Error`? + // Probably not... + error: io::Error::new(io::ErrorKind::Other, error), + }); + } + }; + f(&mut session); + + Connect(MidHandshake::Handshaking(client::TlsStream { + io: stream, + + #[cfg(not(feature = "early-data"))] + state: TlsState::Stream, + + #[cfg(feature = "early-data")] + state: if self.early_data && session.early_data().is_some() { + TlsState::EarlyData(0, Vec::new()) + } else { + TlsState::Stream + }, + + #[cfg(feature = "early-data")] + early_waker: None, + + session, + })) + } +} + +impl TlsAcceptor { + #[inline] + pub fn accept(&self, stream: IO) -> Accept + where + IO: AsyncRead + AsyncWrite + Unpin, + { + self.accept_with(stream, |_| ()) + } + + pub fn accept_with(&self, stream: IO, f: F) -> Accept + where + IO: AsyncRead + AsyncWrite + Unpin, + F: FnOnce(&mut ServerConnection), + { + let mut session = match ServerConnection::new(self.inner.clone()) { + Ok(session) => session, + Err(error) => { + return Accept(MidHandshake::Error { + io: stream, + // TODO(eliza): should this really return an `io::Error`? + // Probably not... + error: io::Error::new(io::ErrorKind::Other, error), + }); + } + }; + f(&mut session); + + Accept(MidHandshake::Handshaking(server::TlsStream { + session, + io: stream, + state: TlsState::Stream, + })) + } +} + +pub struct LazyConfigAcceptor { + acceptor: rustls::server::Acceptor, + io: Option, +} + +impl LazyConfigAcceptor +where + IO: AsyncRead + AsyncWrite + Unpin, +{ + #[inline] + pub fn new(acceptor: rustls::server::Acceptor, io: IO) -> Self { + Self { + acceptor, + io: Some(io), + } + } +} + +impl Future for LazyConfigAcceptor +where + IO: AsyncRead + AsyncWrite + Unpin, +{ + type Output = Result, io::Error>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.get_mut(); + loop { + let io = match this.io.as_mut() { + Some(io) => io, + None => { + panic!("Acceptor cannot be polled after acceptance."); + } + }; + + let mut reader = common::SyncReadAdapter { io, cx }; + match this.acceptor.read_tls(&mut reader) { + Ok(0) => return Poll::Ready(Err(io::ErrorKind::UnexpectedEof.into())), + Ok(_) => {} + Err(e) if e.kind() == io::ErrorKind::WouldBlock => return Poll::Pending, + Err(e) => return Poll::Ready(Err(e)), + } + + match this.acceptor.accept() { + Ok(Some(accepted)) => { + let io = this.io.take().unwrap(); + return Poll::Ready(Ok(StartHandshake { accepted, io })); + } + Ok(None) => continue, + Err(err) => { + return Poll::Ready(Err(io::Error::new(io::ErrorKind::InvalidInput, err))) + } + } + } + } +} + +pub struct StartHandshake { + accepted: rustls::server::Accepted, + io: IO, +} + +impl StartHandshake +where + IO: AsyncRead + AsyncWrite + Unpin, +{ + pub fn client_hello(&self) -> rustls::server::ClientHello<'_> { + self.accepted.client_hello() + } + + pub fn into_stream(self, config: Arc>) -> Accept { + self.into_stream_with(config, |_| ()) + } + + pub fn into_stream_with(self, config: Arc>, f: F) -> Accept + where + F: FnOnce(&mut ServerConnection), + { + let mut conn = match self.accepted.into_connection(config) { + Ok(conn) => conn, + Err(error) => { + return Accept(MidHandshake::Error { + io: self.io, + // TODO(eliza): should this really return an `io::Error`? + // Probably not... + error: io::Error::new(io::ErrorKind::Other, error), + }); + } + }; + f(&mut conn); + + Accept(MidHandshake::Handshaking(server::TlsStream { + session: conn, + io: self.io, + state: TlsState::Stream, + })) + } +} + +/// Future returned from `TlsConnector::connect` which will resolve +/// once the connection handshake has finished. +pub struct Connect(MidHandshake>); + +/// Future returned from `TlsAcceptor::accept` which will resolve +/// once the accept handshake has finished. +pub struct Accept(MidHandshake>); + +/// Like [Connect], but returns `IO` on failure. +pub struct FallibleConnect(MidHandshake>); + +/// Like [Accept], but returns `IO` on failure. +pub struct FallibleAccept(MidHandshake>); + +impl Connect { + #[inline] + pub fn into_fallible(self) -> FallibleConnect { + FallibleConnect(self.0) + } + + pub fn get_ref(&self) -> Option<&IO> { + match &self.0 { + MidHandshake::Handshaking(sess) => Some(sess.get_ref().0), + MidHandshake::Error { io, .. } => Some(io), + MidHandshake::End => None, + } + } + + pub fn get_mut(&mut self) -> Option<&mut IO> { + match &mut self.0 { + MidHandshake::Handshaking(sess) => Some(sess.get_mut().0), + MidHandshake::Error { io, .. } => Some(io), + MidHandshake::End => None, + } + } +} + +impl Accept { + #[inline] + pub fn into_fallible(self) -> FallibleAccept { + FallibleAccept(self.0) + } + + pub fn get_ref(&self) -> Option<&IO> { + match &self.0 { + MidHandshake::Handshaking(sess) => Some(sess.get_ref().0), + MidHandshake::Error { io, .. } => Some(io), + MidHandshake::End => None, + } + } + + pub fn get_mut(&mut self) -> Option<&mut IO> { + match &mut self.0 { + MidHandshake::Handshaking(sess) => Some(sess.get_mut().0), + MidHandshake::Error { io, .. } => Some(io), + MidHandshake::End => None, + } + } +} + +impl Future for Connect { + type Output = io::Result>; + + #[inline] + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + Pin::new(&mut self.0).poll(cx).map_err(|(err, _)| err) + } +} + +impl Future for Accept { + type Output = io::Result>; + + #[inline] + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + Pin::new(&mut self.0).poll(cx).map_err(|(err, _)| err) + } +} + +impl Future for FallibleConnect { + type Output = Result, (io::Error, IO)>; + + #[inline] + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + Pin::new(&mut self.0).poll(cx) + } +} + +impl Future for FallibleAccept { + type Output = Result, (io::Error, IO)>; + + #[inline] + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + Pin::new(&mut self.0).poll(cx) + } +} + +/// Unified TLS stream type +/// +/// This abstracts over the inner `client::TlsStream` and `server::TlsStream`, so you can use +/// a single type to keep both client- and server-initiated TLS-encrypted connections. +#[allow(clippy::large_enum_variant)] // https://github.com/rust-lang/rust-clippy/issues/9798 +#[derive(Debug)] +pub enum TlsStream { + Client(client::TlsStream), + Server(server::TlsStream), +} + +impl TlsStream { + pub fn get_ref(&self) -> (&T, &CommonState) { + use TlsStream::*; + match self { + Client(io) => { + let (io, session) = io.get_ref(); + (io, session) + } + Server(io) => { + let (io, session) = io.get_ref(); + (io, session) + } + } + } + + pub fn get_mut(&mut self) -> (&mut T, &mut CommonState) { + use TlsStream::*; + match self { + Client(io) => { + let (io, session) = io.get_mut(); + (io, &mut *session) + } + Server(io) => { + let (io, session) = io.get_mut(); + (io, &mut *session) + } + } + } +} + +impl From> for TlsStream { + fn from(s: client::TlsStream) -> Self { + Self::Client(s) + } +} + +impl From> for TlsStream { + fn from(s: server::TlsStream) -> Self { + Self::Server(s) + } +} + +impl AsyncRead for TlsStream +where + T: AsyncRead + AsyncWrite + Unpin, +{ + #[inline] + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut [u8], + ) -> Poll> { + match self.get_mut() { + TlsStream::Client(x) => Pin::new(x).poll_read(cx, buf), + TlsStream::Server(x) => Pin::new(x).poll_read(cx, buf), + } + } +} + +impl AsyncWrite for TlsStream +where + T: AsyncRead + AsyncWrite + Unpin, +{ + #[inline] + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + match self.get_mut() { + TlsStream::Client(x) => Pin::new(x).poll_write(cx, buf), + TlsStream::Server(x) => Pin::new(x).poll_write(cx, buf), + } + } + + #[inline] + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match self.get_mut() { + TlsStream::Client(x) => Pin::new(x).poll_flush(cx), + TlsStream::Server(x) => Pin::new(x).poll_flush(cx), + } + } + + #[inline] + fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match self.get_mut() { + TlsStream::Client(x) => Pin::new(x).poll_close(cx), + TlsStream::Server(x) => Pin::new(x).poll_close(cx), + } + } +} diff --git a/src/async_rustls/src/server.rs b/src/async_rustls/src/server.rs new file mode 100644 index 00000000..b5476b45 --- /dev/null +++ b/src/async_rustls/src/server.rs @@ -0,0 +1,122 @@ +use super::*; +use crate::common::IoSession; + +/// A wrapper around an underlying raw stream which implements the TLS or SSL +/// protocol. +#[derive(Debug)] +pub struct TlsStream { + pub(crate) io: IO, + pub(crate) session: ServerConnection, + pub(crate) state: TlsState, +} + +impl TlsStream { + #[inline] + pub fn get_ref(&self) -> (&IO, &ServerConnection) { + (&self.io, &self.session) + } + + #[inline] + pub fn get_mut(&mut self) -> (&mut IO, &mut ServerConnection) { + (&mut self.io, &mut self.session) + } + + #[inline] + pub fn into_inner(self) -> (IO, ServerConnection) { + (self.io, self.session) + } +} + +impl IoSession for TlsStream { + type Io = IO; + type Session = ServerConnection; + + #[inline] + fn skip_handshake(&self) -> bool { + false + } + + #[inline] + fn get_mut(&mut self) -> (&mut TlsState, &mut Self::Io, &mut Self::Session) { + (&mut self.state, &mut self.io, &mut self.session) + } + + #[inline] + fn into_io(self) -> Self::Io { + self.io + } +} + +impl AsyncRead for TlsStream +where + IO: AsyncRead + AsyncWrite + Unpin, +{ + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut [u8], + ) -> Poll> { + let this = self.get_mut(); + let mut stream = + Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable()); + + match &this.state { + TlsState::Stream | TlsState::WriteShutdown => { + match stream.as_mut_pin().poll_read(cx, buf) { + Poll::Ready(Ok(n)) => { + if n == 0 || stream.eof { + this.state.shutdown_read(); + } + + Poll::Ready(Ok(n)) + } + Poll::Ready(Err(err)) if err.kind() == io::ErrorKind::UnexpectedEof => { + this.state.shutdown_read(); + Poll::Ready(Err(err)) + } + output => output, + } + } + TlsState::ReadShutdown | TlsState::FullyShutdown => Poll::Ready(Ok(0)), + #[cfg(feature = "early-data")] + s => unreachable!("server TLS can not hit this state: {:?}", s), + } + } +} + +impl AsyncWrite for TlsStream +where + IO: AsyncRead + AsyncWrite + Unpin, +{ + /// Note: that it does not guarantee the final data to be sent. + /// To be cautious, you must manually call `flush`. + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + let this = self.get_mut(); + let mut stream = + Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable()); + stream.as_mut_pin().poll_write(cx, buf) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = self.get_mut(); + let mut stream = + Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable()); + stream.as_mut_pin().poll_flush(cx) + } + + fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + if self.state.writeable() { + self.session.send_close_notify(); + self.state.shutdown_write(); + } + + let this = self.get_mut(); + let mut stream = + Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable()); + stream.as_mut_pin().poll_close(cx) + } +}