Skip to content

Commit

Permalink
add async_rustls
Browse files Browse the repository at this point in the history
This crate inherits from the `async-rustls` crate and replaces its
dependency on std with 'rust-std-stub' using the prelude.

Signed-off-by: Jiaqi Gao <[email protected]>
  • Loading branch information
gaojiaqi7 committed Aug 8, 2023
1 parent d4ce3f1 commit fd80aa0
Show file tree
Hide file tree
Showing 8 changed files with 1,254 additions and 0 deletions.
16 changes: 16 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
[workspace]

members = [
"src/async_rustls",
"src/attestation",
"src/crypto",
"src/devices/pci",
Expand Down
15 changes: 15 additions & 0 deletions src/async_rustls/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
[package]
name = "async_rustls"
version = "0.1.0"
edition = "2021"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
rust_std_stub = { path = "../std-support/rust-std-stub" }
futures-io = { path = "../std-support/futures-io" }
rustls = { path = "../../deps/rustls/rustls", default-features = false, features = ["no_std", "alloc"] }

[features]
dangerous_configuration = ["rustls/dangerous_configuration"]
early-data = []
220 changes: 220 additions & 0 deletions src/async_rustls/src/client.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,220 @@
use super::*;
use crate::common::IoSession;

/// A wrapper around an underlying raw stream which implements the TLS or SSL
/// protocol.
#[derive(Debug)]
pub struct TlsStream<IO> {
pub(crate) io: IO,
pub(crate) session: ClientConnection,
pub(crate) state: TlsState,

#[cfg(feature = "early-data")]
pub(crate) early_waker: Option<core::task::Waker>,
}

impl<IO> TlsStream<IO> {
#[inline]
pub fn get_ref(&self) -> (&IO, &ClientConnection) {
(&self.io, &self.session)
}

#[inline]
pub fn get_mut(&mut self) -> (&mut IO, &mut ClientConnection) {
(&mut self.io, &mut self.session)
}

#[inline]
pub fn into_inner(self) -> (IO, ClientConnection) {
(self.io, self.session)
}
}

impl<IO> IoSession for TlsStream<IO> {
type Io = IO;
type Session = ClientConnection;

#[inline]
fn skip_handshake(&self) -> bool {
self.state.is_early_data()
}

#[inline]
fn get_mut(&mut self) -> (&mut TlsState, &mut Self::Io, &mut Self::Session) {
(&mut self.state, &mut self.io, &mut self.session)
}

#[inline]
fn into_io(self) -> Self::Io {
self.io
}
}

impl<IO> AsyncRead for TlsStream<IO>
where
IO: AsyncRead + AsyncWrite + Unpin,
{
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<io::Result<usize>> {
match self.state {
#[cfg(feature = "early-data")]
TlsState::EarlyData(..) => {
let this = self.get_mut();

// In the EarlyData state, we have not really established a Tls connection.
// Before writing data through `AsyncWrite` and completing the tls handshake,
// we ignore read readiness and return to pending.
//
// In order to avoid event loss,
// we need to register a waker and wake it up after tls is connected.
if this
.early_waker
.as_ref()
.filter(|waker| cx.waker().will_wake(waker))
.is_none()
{
this.early_waker = Some(cx.waker().clone());
}

Poll::Pending
}
TlsState::Stream | TlsState::WriteShutdown => {
let this = self.get_mut();
let mut stream =
Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());

match stream.as_mut_pin().poll_read(cx, buf) {
Poll::Ready(Ok(n)) => {
if n == 0 || stream.eof {
this.state.shutdown_read();
}

Poll::Ready(Ok(n))
}
Poll::Ready(Err(err)) if err.kind() == io::ErrorKind::ConnectionAborted => {
this.state.shutdown_read();
Poll::Ready(Err(err))
}
output => output,
}
}
TlsState::ReadShutdown | TlsState::FullyShutdown => Poll::Ready(Ok(0)),
}
}
}

impl<IO> AsyncWrite for TlsStream<IO>
where
IO: AsyncRead + AsyncWrite + Unpin,
{
/// Note: that it does not guarantee the final data to be sent.
/// To be cautious, you must manually call `flush`.
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
let this = self.get_mut();
let mut stream =
Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());

#[allow(clippy::match_single_binding)]
match this.state {
#[cfg(feature = "early-data")]
TlsState::EarlyData(ref mut pos, ref mut data) => {
use rust_std_stub::io::Write;

// write early data
if let Some(mut early_data) = stream.session.early_data() {
let len = match early_data.write(buf) {
Ok(n) => n,
Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => {
return Poll::Pending
}
Err(err) => return Poll::Ready(Err(err)),
};
if len != 0 {
data.extend_from_slice(&buf[..len]);
return Poll::Ready(Ok(len));
}
}

// complete handshake
while stream.session.is_handshaking() {
ready!(stream.handshake(cx))?;
}

// write early data (fallback)
if !stream.session.is_early_data_accepted() {
while *pos < data.len() {
let len = ready!(stream.as_mut_pin().poll_write(cx, &data[*pos..]))?;
*pos += len;
}
}

// end
this.state = TlsState::Stream;

if let Some(waker) = this.early_waker.take() {
waker.wake();
}

stream.as_mut_pin().poll_write(cx, buf)
}
_ => stream.as_mut_pin().poll_write(cx, buf),
}
}

fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
let this = self.get_mut();
let mut stream =
Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());

#[cfg(feature = "early-data")]
{
if let TlsState::EarlyData(ref mut pos, ref mut data) = this.state {
// complete handshake
while stream.session.is_handshaking() {
ready!(stream.handshake(cx))?;
}

// write early data (fallback)
if !stream.session.is_early_data_accepted() {
while *pos < data.len() {
let len = ready!(stream.as_mut_pin().poll_write(cx, &data[*pos..]))?;
*pos += len;
}
}

this.state = TlsState::Stream;

if let Some(waker) = this.early_waker.take() {
waker.wake();
}
}
}

stream.as_mut_pin().poll_flush(cx)
}

fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
// complete handshake
#[cfg(feature = "early-data")]
if matches!(self.state, TlsState::EarlyData(..)) {
ready!(self.as_mut().poll_flush(cx))?;
}

if self.state.writeable() {
self.session.send_close_notify();
self.state.shutdown_write();
}

let this = self.get_mut();
let mut stream =
Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
stream.as_mut_pin().poll_close(cx)
}
}
70 changes: 70 additions & 0 deletions src/async_rustls/src/common/handshake.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
use crate::common::{Stream, TlsState};
use core::future::Future;
use core::ops::{Deref, DerefMut};
use core::pin::Pin;
use core::task::{Context, Poll};
use futures_io::{AsyncRead, AsyncWrite};
use rust_std_stub::{io, mem};
use rustls::{ConnectionCommon, SideData};

pub(crate) trait IoSession {
type Io;
type Session;

fn skip_handshake(&self) -> bool;
fn get_mut(&mut self) -> (&mut TlsState, &mut Self::Io, &mut Self::Session);
fn into_io(self) -> Self::Io;
}

pub(crate) enum MidHandshake<IS: IoSession> {
Handshaking(IS),
End,
Error { io: IS::Io, error: io::Error },
}

impl<IS, SD> Future for MidHandshake<IS>
where
IS: IoSession + Unpin,
IS::Io: AsyncRead + AsyncWrite + Unpin,
IS::Session: DerefMut + Deref<Target = ConnectionCommon<SD>> + Unpin,
SD: SideData,
{
type Output = Result<IS, (io::Error, IS::Io)>;

fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.get_mut();

let mut stream = match mem::replace(this, MidHandshake::End) {
MidHandshake::Handshaking(stream) => stream,
// Starting the handshake returned an error; fail the future immediately.
MidHandshake::Error { io, error } => return Poll::Ready(Err((error, io))),
_ => panic!("unexpected polling after handshake"),
};

if !stream.skip_handshake() {
let (state, io, session) = stream.get_mut();
let mut tls_stream = Stream::new(io, session).set_eof(!state.readable());

macro_rules! try_poll {
( $e:expr ) => {
match $e {
Poll::Ready(Ok(_)) => (),
Poll::Ready(Err(err)) => return Poll::Ready(Err((err, stream.into_io()))),
Poll::Pending => {
*this = MidHandshake::Handshaking(stream);
return Poll::Pending;
}
}
};
}

while tls_stream.session.is_handshaking() {
try_poll!(tls_stream.handshake(cx));
}

try_poll!(Pin::new(&mut tls_stream).poll_flush(cx));
}

Poll::Ready(Ok(stream))
}
}
Loading

0 comments on commit fd80aa0

Please sign in to comment.