diff --git a/.gitignore b/.gitignore index c1e5bc64..973db0e1 100755 --- a/.gitignore +++ b/.gitignore @@ -13,3 +13,9 @@ target/ # macOS .DS_Store ._* + +# Example Certificates +localhost-key.pem +localhost.crt +localhost.key +localhost.pem diff --git a/Cargo.toml b/Cargo.toml index 89a45c59..93763f5b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -2,6 +2,7 @@ members = [ "h3", "h3-quinn", + "h3-webtransport", # Internal "examples", diff --git a/examples/Cargo.toml b/examples/Cargo.toml index 7196d051..e1900860 100644 --- a/examples/Cargo.toml +++ b/examples/Cargo.toml @@ -7,19 +7,37 @@ edition = "2018" # If you copy one of the examples into a new project, you should be using # [dependencies] instead. [dev-dependencies] +anyhow = "1.0" bytes = "1" futures = "0.3" h3 = { path = "../h3" } h3-quinn = { path = "../h3-quinn" } +h3-webtransport = { path = "../h3-webtransport" } http = "0.2" -quinn = { version = "0.10", default-features = false, features = ["runtime-tokio", "tls-rustls", "ring"] } +quinn = { version = "0.10", default-features = false, features = [ + "runtime-tokio", + "tls-rustls", + "ring", +] } rcgen = { version = "0.10" } rustls = { version = "0.21", features = ["dangerous_configuration"] } rustls-native-certs = "0.6" structopt = "0.3" tokio = { version = "1.27", features = ["full"] } tracing = "0.1.37" -tracing-subscriber = { version = "0.3", default-features = false, features = ["fmt", "ansi", "env-filter", "time", "tracing-log"] } +tracing-subscriber = { version = "0.3", default-features = false, features = [ + "fmt", + "ansi", + "env-filter", + "time", + "tracing-log", +] } +octets = "0.2.0" + +tracing-tree = { version = "0.2" } + +[features] +tree = [] [[example]] name = "client" @@ -28,3 +46,7 @@ path = "client.rs" [[example]] name = "server" path = "server.rs" + +[[example]] +name = "webtransport_server" +path = "webtransport_server.rs" diff --git a/examples/launch_chrome.sh b/examples/launch_chrome.sh new file mode 100755 index 00000000..97acd7ee --- /dev/null +++ b/examples/launch_chrome.sh @@ -0,0 +1,16 @@ +#!/bin/bash + +set -e + +SPKI=`openssl x509 -inform der -in localhost.crt -pubkey -noout | openssl pkey -pubin -outform der | openssl dgst -sha256 -binary | openssl enc -base64` + +echo "Got cert key $SPKI" + +echo "Opening google chrome" + +case `uname` in + (*Linux*) google-chrome --origin-to-force-quic-on=127.0.0.1:4433 --ignore-certificate-errors-spki-list=$SPKI --enable-logging --v=1 ;; + (*Darwin*) open -a "Google Chrome" --args --origin-to-force-quic-on=127.0.0.1:4433 --ignore-certificate-errors-spki-list=$SPKI --enable-logging --v=1 ;; +esac + +## Logs are stored to ~/Library/Application Support/Google/Chrome/chrome_debug.log diff --git a/examples/webtransport_server.rs b/examples/webtransport_server.rs new file mode 100644 index 00000000..58d4ba43 --- /dev/null +++ b/examples/webtransport_server.rs @@ -0,0 +1,341 @@ +use anyhow::{Context, Result}; +use bytes::{BufMut, Bytes, BytesMut}; +use h3::{ + error::ErrorLevel, + ext::Protocol, + quic::{self, RecvDatagramExt, SendDatagramExt, SendStreamUnframed}, + server::Connection, +}; +use h3_quinn::quinn; +use h3_webtransport::{ + server::{self, WebTransportSession}, + stream, +}; +use http::Method; +use rustls::{Certificate, PrivateKey}; +use std::{net::SocketAddr, path::PathBuf, sync::Arc, time::Duration}; +use structopt::StructOpt; +use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; +use tokio::pin; +use tracing::{error, info, trace_span}; + +#[derive(StructOpt, Debug)] +#[structopt(name = 
"server")] +struct Opt { + #[structopt( + short, + long, + default_value = "127.0.0.1:4433", + help = "What address:port to listen for new connections" + )] + pub listen: SocketAddr, + + #[structopt(flatten)] + pub certs: Certs, +} + +#[derive(StructOpt, Debug)] +pub struct Certs { + #[structopt( + long, + short, + default_value = "examples/localhost.crt", + help = "Certificate for TLS. If present, `--key` is mandatory." + )] + pub cert: PathBuf, + + #[structopt( + long, + short, + default_value = "examples/localhost.key", + help = "Private key for the certificate." + )] + pub key: PathBuf, +} + +#[tokio::main] +async fn main() -> Result<(), Box> { + // 0. Setup tracing + #[cfg(not(feature = "tree"))] + tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .with_span_events(tracing_subscriber::fmt::format::FmtSpan::FULL) + .with_writer(std::io::stderr) + .init(); + + #[cfg(feature = "tree")] + use tracing_subscriber::prelude::*; + #[cfg(feature = "tree")] + tracing_subscriber::registry() + .with(tracing_subscriber::EnvFilter::from_default_env()) + .with(tracing_tree::HierarchicalLayer::new(4).with_bracketed_fields(true)) + .init(); + + // process cli arguments + + let opt = Opt::from_args(); + + tracing::info!("Opt: {opt:#?}"); + let Certs { cert, key } = opt.certs; + + // create quinn server endpoint and bind UDP socket + + // both cert and key must be DER-encoded + let cert = Certificate(std::fs::read(cert)?); + let key = PrivateKey(std::fs::read(key)?); + + let mut tls_config = rustls::ServerConfig::builder() + .with_safe_default_cipher_suites() + .with_safe_default_kx_groups() + .with_protocol_versions(&[&rustls::version::TLS13]) + .unwrap() + .with_no_client_auth() + .with_single_cert(vec![cert], key)?; + + tls_config.max_early_data_size = u32::MAX; + let alpn: Vec> = vec![ + b"h3".to_vec(), + b"h3-32".to_vec(), + b"h3-31".to_vec(), + b"h3-30".to_vec(), + b"h3-29".to_vec(), + ]; + tls_config.alpn_protocols = alpn; + + let mut server_config = quinn::ServerConfig::with_crypto(Arc::new(tls_config)); + let mut transport_config = quinn::TransportConfig::default(); + transport_config.keep_alive_interval(Some(Duration::from_secs(2))); + server_config.transport = Arc::new(transport_config); + let endpoint = quinn::Endpoint::server(server_config, opt.listen)?; + + info!("listening on {}", opt.listen); + + // 2. Accept new quic connections and spawn a new task to handle them + while let Some(new_conn) = endpoint.accept().await { + trace_span!("New connection being attempted"); + + tokio::spawn(async move { + match new_conn.await { + Ok(conn) => { + info!("new http3 established"); + let h3_conn = h3::server::builder() + .enable_webtransport(true) + .enable_connect(true) + .enable_datagram(true) + .max_webtransport_sessions(1) + .send_grease(true) + .build(h3_quinn::Connection::new(conn)) + .await + .unwrap(); + + // tracing::info!("Establishing WebTransport session"); + // // 3. TODO: Conditionally, if the client indicated that this is a webtransport session, we should accept it here, else use regular h3. + // // if this is a webtransport session, then h3 needs to stop handing the datagrams, bidirectional streams, and unidirectional streams and give them + // // to the webtransport session. 
+
+                    tokio::spawn(async move {
+                        if let Err(err) = handle_connection(h3_conn).await {
+                            tracing::error!("Failed to handle connection: {err:?}");
+                        }
+                    });
+                    // let mut session: WebTransportSession<_, Bytes> =
+                    //     WebTransportSession::accept(h3_conn).await.unwrap();
+                    // tracing::info!("Finished establishing webtransport session");
+                    // // 4. Get datagrams, bidirectional streams, and unidirectional streams and wait for client requests here.
+                    // // h3_conn needs to handover the datagrams, bidirectional streams, and unidirectional streams to the webtransport session.
+                    // let result = handle.await;
+                }
+                Err(err) => {
+                    error!("accepting connection failed: {:?}", err);
+                }
+            }
+        });
+    }
+
+    // shut down gracefully
+    // wait for connections to be closed before exiting
+    endpoint.wait_idle().await;
+
+    Ok(())
+}
+
+async fn handle_connection(mut conn: Connection<h3_quinn::Connection, Bytes>) -> Result<()> {
+    // 3. TODO: Conditionally, if the client indicated that this is a webtransport session, we should accept it here, else use regular h3.
+    // if this is a webtransport session, then h3 needs to stop handling the datagrams, bidirectional streams, and unidirectional streams and give them
+    // to the webtransport session.
+
+    loop {
+        match conn.accept().await {
+            Ok(Some((req, stream))) => {
+                info!("new request: {:#?}", req);
+
+                let ext = req.extensions();
+                match req.method() {
+                    &Method::CONNECT if ext.get::<Protocol>() == Some(&Protocol::WEB_TRANSPORT) => {
+                        tracing::info!("Peer wants to initiate a webtransport session");
+
+                        tracing::info!("Handing over connection to WebTransport");
+                        let session = WebTransportSession::accept(req, stream, conn).await?;
+                        tracing::info!("Established webtransport session");
+                        // 4. Get datagrams, bidirectional streams, and unidirectional streams and wait for client requests here.
+                        // h3_conn needs to handover the datagrams, bidirectional streams, and unidirectional streams to the webtransport session.
+                        handle_session_and_echo_all_inbound_messages(session).await?;
+
+                        return Ok(());
+                    }
+                    _ => {
+                        tracing::info!(?req, "Received request");
+                    }
+                }
+            }
+
+            // indicating no more streams to be received
+            Ok(None) => {
+                break;
+            }
+
+            Err(err) => {
+                error!("Error on accept {}", err);
+                match err.get_error_level() {
+                    ErrorLevel::ConnectionError => break,
+                    ErrorLevel::StreamError => continue,
+                }
+            }
+        }
+    }
+    Ok(())
+}
+
+macro_rules! log_result {
+    ($expr:expr) => {
+        if let Err(err) = $expr {
+            tracing::error!("{err:?}");
+        }
+    };
+}
+
+async fn echo_stream<T, R>(send: T, recv: R) -> anyhow::Result<()>
+where
+    T: AsyncWrite,
+    R: AsyncRead,
+{
+    pin!(send);
+    pin!(recv);
+
+    tracing::info!("Got stream");
+    let mut buf = Vec::new();
+    recv.read_to_end(&mut buf).await?;
+
+    let message = Bytes::from(buf);
+
+    send_chunked(send, message).await?;
+
+    Ok(())
+}
+
+// Used to test that all chunks arrive properly as it is easy to write an impl which only reads and
+// writes the first chunk.
+async fn send_chunked(mut send: impl AsyncWrite + Unpin, data: Bytes) -> anyhow::Result<()> {
+    for chunk in data.chunks(4) {
+        tokio::time::sleep(Duration::from_millis(100)).await;
+        tracing::info!("Sending {chunk:?}");
+        send.write_all(chunk).await?;
+    }
+
+    Ok(())
+}
+
+async fn open_bidi_test<S>(mut stream: S) -> anyhow::Result<()>
+where
+    S: Unpin + AsyncRead + AsyncWrite,
+{
+    tracing::info!("Opening bidirectional stream");
+
+    stream
+        .write_all(b"Hello from a server initiated bidi stream")
+        .await
+        .context("Failed to respond")?;
+
+    let mut resp = Vec::new();
+    stream.shutdown().await?;
+    stream.read_to_end(&mut resp).await?;
+
+    tracing::info!("Got response from client: {resp:?}");
+
+    Ok(())
+}
+
+/// This method will echo all inbound datagrams, unidirectional and bidirectional streams.
+#[tracing::instrument(level = "info", skip(session))]
+async fn handle_session_and_echo_all_inbound_messages<C>(
+    session: WebTransportSession<C, Bytes>,
+) -> anyhow::Result<()>
+where
+    // Use trait bounds to ensure we only use implementations that are available for the quinn
+    // backend.
+    C: 'static
+        + Send
+        + h3::quic::Connection<Bytes>
+        + RecvDatagramExt<Buf = Bytes>
+        + SendDatagramExt<Bytes>,
+    <C::SendStream as h3::quic::SendStream<Bytes>>::Error:
+        'static + std::error::Error + Send + Sync + Into<std::io::Error>,
+    <C::RecvStream as h3::quic::RecvStream>::Error:
+        'static + std::error::Error + Send + Sync + Into<std::io::Error>,
+    stream::BidiStream<C::BidiStream, Bytes>:
+        quic::BidiStream<Bytes> + Unpin + AsyncWrite + AsyncRead,
+    <stream::BidiStream<C::BidiStream, Bytes> as quic::BidiStream<Bytes>>::SendStream:
+        Unpin + AsyncWrite + Send + Sync,
+    <stream::BidiStream<C::BidiStream, Bytes> as quic::BidiStream<Bytes>>::RecvStream:
+        Unpin + AsyncRead + Send + Sync,
+    C::SendStream: Send + Unpin,
+    C::RecvStream: Send + Unpin,
+    C::BidiStream: Send + Unpin,
+    stream::SendStream<C::SendStream, Bytes>: AsyncWrite,
+    C::BidiStream: SendStreamUnframed<Bytes>,
+    C::SendStream: SendStreamUnframed<Bytes>,
+{
+    let session_id = session.session_id();
+
+    // This will open a bidirectional stream and send a message to the client right after connecting!
+    let stream = session.open_bi(session_id).await?;
+
+    tokio::spawn(async move { log_result!(open_bidi_test(stream).await) });
+
+    loop {
+        tokio::select! {
+            datagram = session.accept_datagram() => {
+                let datagram = datagram?;
+                if let Some((_, datagram)) = datagram {
+                    tracing::info!("Responding with {datagram:?}");
+                    // Put something before to make sure encoding and decoding works and don't just
+                    // pass through
+                    let mut resp = BytesMut::from(&b"Response: "[..]);
+                    resp.put(datagram);
+
+                    session.send_datagram(resp.freeze())?;
+                    tracing::info!("Finished sending datagram");
+                }
+            }
+            uni_stream = session.accept_uni() => {
+                let (id, stream) = uni_stream?.unwrap();
+
+                let send = session.open_uni(id).await?;
+                tokio::spawn( async move { log_result!(echo_stream(send, stream).await); });
+            }
+            stream = session.accept_bi() => {
+                if let Some(server::AcceptedBi::BidiStream(_, stream)) = stream?
{ + let (send, recv) = quic::BidiStream::split(stream); + tokio::spawn( async move { log_result!(echo_stream(send, recv).await); }); + } + } + else => { + break + } + } + } + + tracing::info!("Finished handling session"); + + Ok(()) +} diff --git a/h3-quinn/Cargo.toml b/h3-quinn/Cargo.toml index 7dce6c76..c0b2291f 100644 --- a/h3-quinn/Cargo.toml +++ b/h3-quinn/Cargo.toml @@ -15,7 +15,10 @@ license = "MIT" [dependencies] h3 = { version = "0.0.2", path = "../h3" } bytes = "1" -quinn = { version = "0.10", default-features = false } +quinn = { version = "0.10", default-features = false, features = [ + "futures-io", +] } quinn-proto = { version = "0.10", default-features = false } tokio-util = { version = "0.7.7" } futures = { version = "0.3.27" } +tokio = { version = "1.28", features = ["io-util"], default-features = false } diff --git a/h3-quinn/src/lib.rs b/h3-quinn/src/lib.rs index 62d84d81..78696dec 100644 --- a/h3-quinn/src/lib.rs +++ b/h3-quinn/src/lib.rs @@ -7,22 +7,27 @@ use std::{ convert::TryInto, fmt::{self, Display}, future::Future, + pin::Pin, sync::Arc, task::{self, Poll}, }; -use bytes::{Buf, Bytes}; +use bytes::{Buf, Bytes, BytesMut}; use futures::{ ready, stream::{self, BoxStream}, StreamExt, }; +use quinn::ReadDatagram; pub use quinn::{ self, crypto::Session, AcceptBi, AcceptUni, Endpoint, OpenBi, OpenUni, VarInt, WriteError, }; -use h3::quic::{self, Error, StreamId, WriteBuf}; +use h3::{ + ext::Datagram, + quic::{self, Error, StreamId, WriteBuf}, +}; use tokio_util::sync::ReusableBoxFuture; /// A QUIC connection backed by Quinn @@ -34,6 +39,7 @@ pub struct Connection { opening_bi: Option as Future>::Output>>, incoming_uni: BoxStream<'static, as Future>::Output>, opening_uni: Option as Future>::Output>>, + datagrams: BoxStream<'static, as Future>::Output>, } impl Connection { @@ -45,10 +51,13 @@ impl Connection { Some((conn.accept_bi().await, conn)) })), opening_bi: None, - incoming_uni: Box::pin(stream::unfold(conn, |conn| async { + incoming_uni: Box::pin(stream::unfold(conn.clone(), |conn| async { Some((conn.accept_uni().await, conn)) })), opening_uni: None, + datagrams: Box::pin(stream::unfold(conn, |conn| async { + Some((conn.read_datagram().await, conn)) + })), } } } @@ -89,6 +98,58 @@ impl From for ConnectionError { } } +/// Types of errors when sending a datagram. +#[derive(Debug)] +pub enum SendDatagramError { + /// Datagrams are not supported by the peer + UnsupportedByPeer, + /// Datagrams are locally disabled + Disabled, + /// The datagram was too large to be sent. 
+ TooLarge, + /// Network error + ConnectionLost(Box), +} + +impl fmt::Display for SendDatagramError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + SendDatagramError::UnsupportedByPeer => write!(f, "datagrams not supported by peer"), + SendDatagramError::Disabled => write!(f, "datagram support disabled"), + SendDatagramError::TooLarge => write!(f, "datagram too large"), + SendDatagramError::ConnectionLost(_) => write!(f, "connection lost"), + } + } +} + +impl std::error::Error for SendDatagramError {} + +impl Error for SendDatagramError { + fn is_timeout(&self) -> bool { + false + } + + fn err_code(&self) -> Option { + match self { + Self::ConnectionLost(err) => err.err_code(), + _ => None, + } + } +} + +impl From for SendDatagramError { + fn from(value: quinn::SendDatagramError) -> Self { + match value { + quinn::SendDatagramError::UnsupportedByPeer => Self::UnsupportedByPeer, + quinn::SendDatagramError::Disabled => Self::Disabled, + quinn::SendDatagramError::TooLarge => Self::TooLarge, + quinn::SendDatagramError::ConnectionLost(err) => { + Self::ConnectionLost(ConnectionError::from(err).into()) + } + } + } +} + impl quic::Connection for Connection where B: Buf, @@ -172,6 +233,40 @@ where } } +impl quic::SendDatagramExt for Connection +where + B: Buf, +{ + type Error = SendDatagramError; + + fn send_datagram(&mut self, data: Datagram) -> Result<(), SendDatagramError> { + // TODO investigate static buffer from known max datagram size + let mut buf = BytesMut::new(); + data.encode(&mut buf); + self.conn.send_datagram(buf.freeze())?; + + Ok(()) + } +} + +impl quic::RecvDatagramExt for Connection { + type Buf = Bytes; + + type Error = ConnectionError; + + #[inline] + fn poll_accept_datagram( + &mut self, + cx: &mut task::Context<'_>, + ) -> Poll, Self::Error>> { + match ready!(self.datagrams.poll_next_unpin(cx)) { + Some(Ok(x)) => Poll::Ready(Ok(Some(x))), + Some(Err(e)) => Poll::Ready(Err(e.into())), + None => Poll::Ready(Ok(None)), + } + } +} + /// Stream opener backed by a Quinn connection /// /// Implements [`quic::OpenStreams`] using [`quinn::Connection`], @@ -265,10 +360,7 @@ where } } -impl quic::RecvStream for BidiStream -where - B: Buf, -{ +impl quic::RecvStream for BidiStream { type Buf = Bytes; type Error = ReadError; @@ -282,6 +374,10 @@ where fn stop_sending(&mut self, error_code: u64) { self.recv.stop_sending(error_code) } + + fn recv_id(&self) -> StreamId { + self.recv.recv_id() + } } impl quic::SendStream for BidiStream @@ -306,8 +402,20 @@ where self.send.send_data(data) } - fn id(&self) -> StreamId { - self.send.id() + fn send_id(&self) -> StreamId { + self.send.send_id() + } +} +impl quic::SendStreamUnframed for BidiStream +where + B: Buf, +{ + fn poll_send( + &mut self, + cx: &mut task::Context<'_>, + buf: &mut D, + ) -> Poll> { + self.send.poll_send(cx, buf) } } @@ -364,6 +472,16 @@ impl quic::RecvStream for RecvStream { .stop(VarInt::from_u64(error_code).expect("invalid error_code")) .ok(); } + + fn recv_id(&self) -> StreamId { + self.stream + .as_ref() + .unwrap() + .id() + .0 + .try_into() + .expect("invalid stream id") + } } /// The error type for [`RecvStream`] @@ -372,7 +490,17 @@ impl quic::RecvStream for RecvStream { #[derive(Debug)] pub struct ReadError(quinn::ReadError); -impl std::error::Error for ReadError {} +impl From for std::io::Error { + fn from(value: ReadError) -> Self { + value.0.into() + } +} + +impl std::error::Error for ReadError { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + self.0.source() + 
} +} impl fmt::Display for ReadError { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { @@ -491,7 +619,7 @@ where Ok(()) } - fn id(&self) -> StreamId { + fn send_id(&self) -> StreamId { self.stream .as_ref() .unwrap() @@ -502,6 +630,48 @@ where } } +impl quic::SendStreamUnframed for SendStream +where + B: Buf, +{ + fn poll_send( + &mut self, + cx: &mut task::Context<'_>, + buf: &mut D, + ) -> Poll> { + if self.writing.is_some() { + // This signifies a bug in implementation + panic!("poll_send called while send stream is not ready") + } + + let s = Pin::new(self.stream.as_mut().unwrap()); + + let res = ready!(futures::io::AsyncWrite::poll_write(s, cx, buf.chunk())); + match res { + Ok(written) => { + buf.advance(written); + Poll::Ready(Ok(written)) + } + Err(err) => { + // We are forced to use AsyncWrite for now because we cannot store + // the result of a call to: + // quinn::send_stream::write<'a>(&'a mut self, buf: &'a [u8]) -> Result. + // + // This is why we have to unpack the error from io::Error instead of having it + // returned directly. This should not panic as long as quinn's AsyncWrite impl + // doesn't change. + let err = err + .into_inner() + .expect("write stream returned an empty error") + .downcast::() + .expect("write stream returned an error which type is not WriteError"); + + Poll::Ready(Err(SendStreamError::Write(*err))) + } + } + } +} + /// The error type for [`SendStream`] /// /// Wraps errors that can happen writing to or polling a send stream. @@ -514,6 +684,17 @@ pub enum SendStreamError { NotReady, } +impl From for std::io::Error { + fn from(value: SendStreamError) -> Self { + match value { + SendStreamError::Write(err) => err.into(), + SendStreamError::NotReady => { + std::io::Error::new(std::io::ErrorKind::Other, "send stream is not ready") + } + } + } +} + impl std::error::Error for SendStreamError {} impl Display for SendStreamError { diff --git a/h3-webtransport/Cargo.toml b/h3-webtransport/Cargo.toml new file mode 100644 index 00000000..c4901844 --- /dev/null +++ b/h3-webtransport/Cargo.toml @@ -0,0 +1,19 @@ +[package] +name = "h3-webtransport" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +bytes = "1" +futures-util = { version = "0.3", default-features = false } +http = "0.2.9" +pin-project-lite = { version = "0.2", default_features = false } +tracing = "0.1.37" +tokio = { version = "1.28", default_features = false } + +[dependencies.h3] +version = "0.0.2" +path = "../h3" +features = ["i-implement-a-third-party-backend-and-opt-into-breaking-changes"] diff --git a/h3-webtransport/src/lib.rs b/h3-webtransport/src/lib.rs new file mode 100644 index 00000000..9900311a --- /dev/null +++ b/h3-webtransport/src/lib.rs @@ -0,0 +1,13 @@ +//! Provides the client and server support for WebTransport sessions. +//! +//! # Relevant Links +//! WebTransport: https://www.w3.org/TR/webtransport/#biblio-web-transport-http3 +//! WebTransport over HTTP/3: https://datatracker.ietf.org/doc/html/draft-ietf-webtrans-http3/ +#![deny(missing_docs)] + +/// Server side WebTransport session support +pub mod server; +/// Webtransport stream types +pub mod stream; + +pub use h3::webtransport::SessionId; diff --git a/h3-webtransport/src/server.rs b/h3-webtransport/src/server.rs new file mode 100644 index 00000000..3a212dd1 --- /dev/null +++ b/h3-webtransport/src/server.rs @@ -0,0 +1,427 @@ +//! 
Provides the server side WebTransport session + +use std::{ + marker::PhantomData, + pin::Pin, + sync::Mutex, + task::{Context, Poll}, +}; + +use bytes::Buf; +use futures_util::{future::poll_fn, ready, Future}; +use h3::{ + connection::ConnectionState, + error::{Code, ErrorLevel}, + ext::{Datagram, Protocol}, + frame::FrameStream, + proto::frame::Frame, + quic::{self, OpenStreams, RecvDatagramExt, SendDatagramExt, WriteBuf}, + server::{self, Connection, RequestStream}, + Error, +}; +use h3::{ + quic::SendStreamUnframed, + stream::{BidiStreamHeader, BufRecvStream, UniStreamHeader}, +}; +use http::{Method, Request, Response, StatusCode}; + +use h3::webtransport::SessionId; +use pin_project_lite::pin_project; + +use crate::stream::{BidiStream, RecvStream, SendStream}; + +/// WebTransport session driver. +/// +/// Maintains the session using the underlying HTTP/3 connection. +/// +/// Similar to [`crate::Connection`] it is generic over the QUIC implementation and Buffer. +pub struct WebTransportSession +where + C: quic::Connection, + B: Buf, +{ + // See: https://datatracker.ietf.org/doc/html/draft-ietf-webtrans-http3/#section-2-3 + session_id: SessionId, + /// The underlying HTTP/3 connection + server_conn: Mutex>, + connect_stream: RequestStream, + opener: Mutex, +} + +impl WebTransportSession +where + C: quic::Connection, + B: Buf, +{ + /// Accepts a *CONNECT* request for establishing a WebTransport session. + /// + /// TODO: is the API or the user responsible for validating the CONNECT request? + pub async fn accept( + request: Request<()>, + mut stream: RequestStream, + mut conn: Connection, + ) -> Result { + let shared = conn.shared_state().clone(); + { + let config = shared.write("Read WebTransport support").peer_config; + + if !config.enable_webtransport() { + return Err(conn.close( + Code::H3_SETTINGS_ERROR, + "webtransport is not supported by client", + )); + } + + if !config.enable_datagram() { + return Err(conn.close( + Code::H3_SETTINGS_ERROR, + "datagrams are not supported by client", + )); + } + } + + // The peer is responsible for validating our side of the webtransport support. + // + // However, it is still advantageous to show a log on the server as (attempting) to + // establish a WebTransportSession without the proper h3 config is usually a mistake. + if !conn.inner.config.enable_webtransport() { + tracing::warn!("Server does not support webtransport"); + } + + if !conn.inner.config.enable_datagram() { + tracing::warn!("Server does not support datagrams"); + } + + if !conn.inner.config.enable_extended_connect() { + tracing::warn!("Server does not support CONNECT"); + } + + // Respond to the CONNECT request. + + //= https://datatracker.ietf.org/doc/html/draft-ietf-webtrans-http3/#section-3.3 + let response = if validate_wt_connect(&request) { + Response::builder() + // This is the only header that chrome cares about. 
+ .header("sec-webtransport-http3-draft", "draft02") + .status(StatusCode::OK) + .body(()) + .unwrap() + } else { + Response::builder() + .status(StatusCode::BAD_REQUEST) + .body(()) + .unwrap() + }; + + stream.send_response(response).await?; + + let session_id = stream.send_id().into(); + let conn_inner = &mut conn.inner.conn; + let opener = Mutex::new(conn_inner.opener()); + + Ok(Self { + session_id, + opener, + server_conn: Mutex::new(conn), + connect_stream: stream, + }) + } + + /// Receive a datagram from the client + pub fn accept_datagram(&self) -> ReadDatagram { + ReadDatagram { + conn: &self.server_conn, + _marker: PhantomData, + } + } + + /// Sends a datagram + /// + /// TODO: maybe make async. `quinn` does not require an async send + pub fn send_datagram(&self, data: B) -> Result<(), Error> + where + C: SendDatagramExt, + { + self.server_conn + .lock() + .unwrap() + .send_datagram(self.connect_stream.id(), data)?; + + Ok(()) + } + + /// Accept an incoming unidirectional stream from the client, it reads the stream until EOF. + pub fn accept_uni(&self) -> AcceptUni { + AcceptUni { + conn: &self.server_conn, + } + } + + /// Accepts an incoming bidirectional stream or request + pub async fn accept_bi(&self) -> Result>, Error> { + // Get the next stream + // Accept the incoming stream + let stream = poll_fn(|cx| { + let mut conn = self.server_conn.lock().unwrap(); + conn.poll_accept_request(cx) + }) + .await; + + let mut stream = match stream { + Ok(Some(s)) => FrameStream::new(BufRecvStream::new(s)), + Ok(None) => { + // FIXME: is proper HTTP GoAway shutdown required? + return Ok(None); + } + Err(err) => { + match err.kind() { + h3::error::Kind::Closed => return Ok(None), + h3::error::Kind::Application { + code, + reason, + level: ErrorLevel::ConnectionError, + .. + } => { + return Err(self.server_conn.lock().unwrap().close( + code, + reason.unwrap_or_else(|| String::into_boxed_str(String::from(""))), + )) + } + _ => return Err(err), + }; + } + }; + + // Read the first frame. + // + // This will determine if it is a webtransport bi-stream or a request stream + let frame = poll_fn(|cx| stream.poll_next(cx)).await; + + match frame { + Ok(None) => Ok(None), + Ok(Some(Frame::WebTransportStream(session_id))) => { + // Take the stream out of the framed reader and split it in half like Paul Allen + let stream = stream.into_inner(); + + Ok(Some(AcceptedBi::BidiStream( + session_id, + BidiStream::new(stream), + ))) + } + // Make the underlying HTTP/3 connection handle the rest + frame => { + let req = { + let mut conn = self.server_conn.lock().unwrap(); + conn.accept_with_frame(stream, frame)? 
+ }; + if let Some(req) = req { + let (req, resp) = req.resolve().await?; + Ok(Some(AcceptedBi::Request(req, resp))) + } else { + Ok(None) + } + } + } + } + + /// Open a new bidirectional stream + pub fn open_bi(&self, session_id: SessionId) -> OpenBi { + OpenBi { + opener: &self.opener, + stream: None, + session_id, + } + } + + /// Open a new unidirectional stream + pub fn open_uni(&self, session_id: SessionId) -> OpenUni { + OpenUni { + opener: &self.opener, + stream: None, + session_id, + } + } + + /// Returns the session id + pub fn session_id(&self) -> SessionId { + self.session_id + } +} + +/// Streams are opened, but the initial webtransport header has not been sent +type PendingStreams = ( + BidiStream<>::BidiStream, B>, + WriteBuf<&'static [u8]>, +); + +/// Streams are opened, but the initial webtransport header has not been sent +type PendingUniStreams = ( + SendStream<>::SendStream, B>, + WriteBuf<&'static [u8]>, +); + +pin_project! { + /// Future for opening a bidi stream + pub struct OpenBi<'a, C:quic::Connection, B:Buf> { + opener: &'a Mutex, + stream: Option>, + session_id: SessionId, + } +} + +impl<'a, B, C> Future for OpenBi<'a, C, B> +where + C: quic::Connection, + B: Buf, + C::BidiStream: SendStreamUnframed, +{ + type Output = Result, Error>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let mut p = self.project(); + loop { + match &mut p.stream { + Some((stream, buf)) => { + while buf.has_remaining() { + ready!(stream.poll_send(cx, buf))?; + } + + let (stream, _) = p.stream.take().unwrap(); + return Poll::Ready(Ok(stream)); + } + None => { + let mut opener = (*p.opener).lock().unwrap(); + // Open the stream first + let res = ready!(opener.poll_open_bidi(cx))?; + let stream = BidiStream::new(BufRecvStream::new(res)); + + let buf = WriteBuf::from(BidiStreamHeader::WebTransportBidi(*p.session_id)); + *p.stream = Some((stream, buf)); + } + } + } + } +} + +pin_project! { + /// Opens a unidirectional stream + pub struct OpenUni<'a, C: quic::Connection, B:Buf> { + opener: &'a Mutex, + stream: Option>, + // Future for opening a uni stream + session_id: SessionId, + } +} + +impl<'a, C, B> Future for OpenUni<'a, C, B> +where + C: quic::Connection, + B: Buf, + C::SendStream: SendStreamUnframed, +{ + type Output = Result, Error>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let mut p = self.project(); + loop { + match &mut p.stream { + Some((send, buf)) => { + while buf.has_remaining() { + ready!(send.poll_send(cx, buf))?; + } + let (send, buf) = p.stream.take().unwrap(); + assert!(!buf.has_remaining()); + return Poll::Ready(Ok(send)); + } + None => { + let mut opener = (*p.opener).lock().unwrap(); + let send = ready!(opener.poll_open_send(cx))?; + let send = BufRecvStream::new(send); + let send = SendStream::new(send); + + let buf = WriteBuf::from(UniStreamHeader::WebTransportUni(*p.session_id)); + *p.stream = Some((send, buf)); + } + } + } + } +} + +/// An accepted incoming bidirectional stream. +/// +/// Since +pub enum AcceptedBi, B: Buf> { + /// An incoming bidirectional stream + BidiStream(SessionId, BidiStream), + /// An incoming HTTP/3 request, passed through a webtransport session. 
+ /// + /// This makes it possible to respond to multiple CONNECT requests + Request(Request<()>, RequestStream), +} + +/// Future for [`Connection::read_datagram`] +pub struct ReadDatagram<'a, C, B> +where + C: quic::Connection, + B: Buf, +{ + conn: &'a Mutex>, + _marker: PhantomData, +} + +impl<'a, C, B> Future for ReadDatagram<'a, C, B> +where + C: quic::Connection + RecvDatagramExt, + B: Buf, +{ + type Output = Result, Error>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let mut conn = self.conn.lock().unwrap(); + match ready!(conn.inner.conn.poll_accept_datagram(cx))? { + Some(v) => { + let datagram = Datagram::decode(v)?; + Poll::Ready(Ok(Some(( + datagram.stream_id().into(), + datagram.into_payload(), + )))) + } + None => Poll::Ready(Ok(None)), + } + } +} + +/// Future for [`WebTransportSession::accept_uni`] +pub struct AcceptUni<'a, C, B> +where + C: quic::Connection, + B: Buf, +{ + conn: &'a Mutex>, +} + +impl<'a, C, B> Future for AcceptUni<'a, C, B> +where + C: quic::Connection, + B: Buf, +{ + type Output = Result)>, Error>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let mut conn = self.conn.lock().unwrap(); + conn.inner.poll_accept_recv(cx)?; + + // Get the currently available streams + let streams = conn.inner.accepted_streams_mut(); + if let Some((id, stream)) = streams.wt_uni_streams.pop() { + return Poll::Ready(Ok(Some((id, RecvStream::new(stream))))); + } + + Poll::Pending + } +} + +fn validate_wt_connect(request: &Request<()>) -> bool { + let protocol = request.extensions().get::(); + matches!((request.method(), protocol), (&Method::CONNECT, Some(p)) if p == &Protocol::WEB_TRANSPORT) +} diff --git a/h3-webtransport/src/stream.rs b/h3-webtransport/src/stream.rs new file mode 100644 index 00000000..4f29f6f5 --- /dev/null +++ b/h3-webtransport/src/stream.rs @@ -0,0 +1,382 @@ +use std::task::Poll; + +use bytes::{Buf, Bytes}; +use h3::{quic, stream::BufRecvStream}; +use pin_project_lite::pin_project; +use tokio::io::ReadBuf; + +pin_project! { + /// WebTransport receive stream + pub struct RecvStream { + #[pin] + stream: BufRecvStream, + } +} + +impl RecvStream { + #[allow(missing_docs)] + pub fn new(stream: BufRecvStream) -> Self { + Self { stream } + } +} + +impl quic::RecvStream for RecvStream +where + S: quic::RecvStream, + B: Buf, +{ + type Buf = Bytes; + + type Error = S::Error; + + fn poll_data( + &mut self, + cx: &mut std::task::Context<'_>, + ) -> Poll, Self::Error>> { + self.stream.poll_data(cx) + } + + fn stop_sending(&mut self, error_code: u64) { + self.stream.stop_sending(error_code) + } + + fn recv_id(&self) -> quic::StreamId { + self.stream.recv_id() + } +} + +impl futures_util::io::AsyncRead for RecvStream +where + BufRecvStream: futures_util::io::AsyncRead, +{ + fn poll_read( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &mut [u8], + ) -> Poll> { + let p = self.project(); + p.stream.poll_read(cx, buf) + } +} + +impl tokio::io::AsyncRead for RecvStream +where + BufRecvStream: tokio::io::AsyncRead, +{ + fn poll_read( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + let p = self.project(); + p.stream.poll_read(cx, buf) + } +} + +pin_project! 
{ + /// WebTransport send stream + pub struct SendStream { + #[pin] + stream: BufRecvStream, + } +} + +impl std::fmt::Debug for SendStream { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("SendStream") + .field("stream", &self.stream) + .finish() + } +} + +impl SendStream { + #[allow(missing_docs)] + pub(crate) fn new(stream: BufRecvStream) -> Self { + Self { stream } + } +} + +impl quic::SendStreamUnframed for SendStream +where + S: quic::SendStreamUnframed, + B: Buf, +{ + fn poll_send( + &mut self, + cx: &mut std::task::Context<'_>, + buf: &mut D, + ) -> Poll> { + self.stream.poll_send(cx, buf) + } +} + +impl quic::SendStream for SendStream +where + S: quic::SendStream, + B: Buf, +{ + type Error = S::Error; + + fn poll_finish(&mut self, cx: &mut std::task::Context<'_>) -> Poll> { + self.stream.poll_finish(cx) + } + + fn reset(&mut self, reset_code: u64) { + self.stream.reset(reset_code) + } + + fn send_id(&self) -> quic::StreamId { + self.stream.send_id() + } + + fn send_data>>(&mut self, data: T) -> Result<(), Self::Error> { + self.stream.send_data(data) + } + + fn poll_ready(&mut self, cx: &mut std::task::Context<'_>) -> Poll> { + self.stream.poll_ready(cx) + } +} + +impl futures_util::io::AsyncWrite for SendStream +where + BufRecvStream: futures_util::io::AsyncWrite, +{ + fn poll_write( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &[u8], + ) -> Poll> { + let p = self.project(); + p.stream.poll_write(cx, buf) + } + + fn poll_flush( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll> { + let p = self.project(); + p.stream.poll_flush(cx) + } + + fn poll_close( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll> { + let p = self.project(); + p.stream.poll_close(cx) + } +} + +impl tokio::io::AsyncWrite for SendStream +where + BufRecvStream: tokio::io::AsyncWrite, +{ + fn poll_write( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &[u8], + ) -> Poll> { + let p = self.project(); + p.stream.poll_write(cx, buf) + } + + fn poll_flush( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll> { + let p = self.project(); + p.stream.poll_flush(cx) + } + + fn poll_shutdown( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll> { + let p = self.project(); + p.stream.poll_shutdown(cx) + } +} + +pin_project! { + /// Combined send and receive stream. + /// + /// Can be split into a [`RecvStream`] and [`SendStream`] if the underlying QUIC implementation + /// supports it. 
+ pub struct BidiStream { + #[pin] + stream: BufRecvStream, + } +} + +impl BidiStream { + pub(crate) fn new(stream: BufRecvStream) -> Self { + Self { stream } + } +} + +impl quic::SendStream for BidiStream +where + S: quic::SendStream, + B: Buf, +{ + type Error = S::Error; + + fn poll_finish(&mut self, cx: &mut std::task::Context<'_>) -> Poll> { + self.stream.poll_finish(cx) + } + + fn reset(&mut self, reset_code: u64) { + self.stream.reset(reset_code) + } + + fn send_id(&self) -> quic::StreamId { + self.stream.send_id() + } + + fn poll_ready(&mut self, cx: &mut std::task::Context<'_>) -> Poll> { + self.stream.poll_ready(cx) + } + + fn send_data>>(&mut self, data: T) -> Result<(), Self::Error> { + self.stream.send_data(data) + } +} + +impl quic::SendStreamUnframed for BidiStream +where + S: quic::SendStreamUnframed, + B: Buf, +{ + fn poll_send( + &mut self, + cx: &mut std::task::Context<'_>, + buf: &mut D, + ) -> Poll> { + self.stream.poll_send(cx, buf) + } +} + +impl quic::RecvStream for BidiStream { + type Buf = Bytes; + + type Error = S::Error; + + fn poll_data( + &mut self, + cx: &mut std::task::Context<'_>, + ) -> Poll, Self::Error>> { + self.stream.poll_data(cx) + } + + fn stop_sending(&mut self, error_code: u64) { + self.stream.stop_sending(error_code) + } + + fn recv_id(&self) -> quic::StreamId { + self.stream.recv_id() + } +} + +impl quic::BidiStream for BidiStream +where + S: quic::BidiStream, + B: Buf, +{ + type SendStream = SendStream; + + type RecvStream = RecvStream; + + fn split(self) -> (Self::SendStream, Self::RecvStream) { + let (send, recv) = self.stream.split(); + (SendStream::new(send), RecvStream::new(recv)) + } +} + +impl futures_util::io::AsyncRead for BidiStream +where + BufRecvStream: futures_util::io::AsyncRead, +{ + fn poll_read( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &mut [u8], + ) -> Poll> { + let p = self.project(); + p.stream.poll_read(cx, buf) + } +} + +impl futures_util::io::AsyncWrite for BidiStream +where + BufRecvStream: futures_util::io::AsyncWrite, +{ + fn poll_write( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &[u8], + ) -> Poll> { + let p = self.project(); + p.stream.poll_write(cx, buf) + } + + fn poll_flush( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll> { + let p = self.project(); + p.stream.poll_flush(cx) + } + + fn poll_close( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll> { + let p = self.project(); + p.stream.poll_close(cx) + } +} + +impl tokio::io::AsyncRead for BidiStream +where + BufRecvStream: tokio::io::AsyncRead, +{ + fn poll_read( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + let p = self.project(); + p.stream.poll_read(cx, buf) + } +} + +impl tokio::io::AsyncWrite for BidiStream +where + BufRecvStream: tokio::io::AsyncWrite, +{ + fn poll_write( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &[u8], + ) -> Poll> { + let p = self.project(); + p.stream.poll_write(cx, buf) + } + + fn poll_flush( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll> { + let p = self.project(); + p.stream.poll_flush(cx) + } + + fn poll_shutdown( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> Poll> { + let p = self.project(); + p.stream.poll_shutdown(cx) + } +} diff --git a/h3/Cargo.toml b/h3/Cargo.toml index eade4f44..bf6a48e7 100644 --- a/h3/Cargo.toml +++ b/h3/Cargo.toml 
@@ -19,11 +19,15 @@ categories = [ "web-programming::http-server", ] +[features] +i-implement-a-third-party-backend-and-opt-into-breaking-changes = [] + [dependencies] bytes = "1" -futures-util = { version = "0.3", default-features = false } +futures-util = { version = "0.3", default-features = false, features = ["io"] } http = "0.2.9" tokio = { version = "1", features = ["sync"] } +pin-project-lite = { version = "0.2", default_features = false } tracing = "0.1.37" fastrand = "1.9.0" diff --git a/h3/src/buf.rs b/h3/src/buf.rs index 99d12d56..c6c5617e 100644 --- a/h3/src/buf.rs +++ b/h3/src/buf.rs @@ -3,6 +3,7 @@ use std::io::IoSlice; use bytes::{Buf, Bytes}; +#[derive(Debug)] pub(crate) struct BufList { bufs: VecDeque, } @@ -32,11 +33,16 @@ impl BufList { } impl BufList { + pub fn take_first_chunk(&mut self) -> Option { + self.bufs.pop_front() + } + pub fn take_chunk(&mut self, max_len: usize) -> Option { let chunk = self .bufs .front_mut() .map(|chunk| chunk.split_to(usize::min(max_len, chunk.remaining()))); + if let Some(front) = self.bufs.front() { if front.remaining() == 0 { let _ = self.bufs.pop_front(); diff --git a/h3/src/client.rs b/h3/src/client.rs index e2643aaa..fb444664 100644 --- a/h3/src/client.rs +++ b/h3/src/client.rs @@ -13,13 +13,14 @@ use http::{request, HeaderMap, Response}; use tracing::{info, trace}; use crate::{ + config::Config, connection::{self, ConnectionInner, ConnectionState, SharedStateRef}, error::{Code, Error, ErrorLevel}, frame::FrameStream, - proto::{frame::Frame, headers::Header, push::PushId, varint::VarInt}, + proto::{frame::Frame, headers::Header, push::PushId}, qpack, quic::{self, StreamId}, - stream, + stream::{self, BufRecvStream}, }; /// Start building a new HTTP/3 client @@ -146,7 +147,7 @@ where ) -> Result, Error> { let (peer_max_field_section_size, closing) = { let state = self.conn_state.read("send request lock state"); - (state.peer_max_field_section_size, state.closing) + (state.peer_config.max_field_section_size, state.closing) }; if closing { @@ -160,7 +161,7 @@ where headers, .. 
        } = parts;
-        let headers = Header::request(method, uri, headers)?;
+        let headers = Header::request(method, uri, headers, Default::default())?;
        //= https://www.rfc-editor.org/rfc/rfc9114#section-4.1
        //= type=implication
@@ -199,7 +200,7 @@
        let request_stream = RequestStream {
            inner: connection::RequestStream::new(
-                FrameStream::new(stream),
+                FrameStream::new(BufRecvStream::new(stream)),
                self.max_field_section_size,
                self.conn_state.clone(),
                self.send_grease_frame,
@@ -253,7 +254,7 @@
            .fetch_sub(1, std::sync::atomic::Ordering::AcqRel)
            == 1
        {
-            if let Some(w) = self.conn_waker.take() {
+            if let Some(w) = Option::take(&mut self.conn_waker) {
                w.wake()
            }
            self.shared_state().write("SendRequest drop").error = Some(Error::closed());
@@ -482,15 +483,13 @@
/// # }
/// ```
pub struct Builder {
-    max_field_section_size: u64,
-    send_grease: bool,
+    config: Config,
}

impl Builder {
    pub(super) fn new() -> Self {
        Builder {
-            max_field_section_size: VarInt::MAX.0,
-            send_grease: true,
+            config: Default::default(),
        }
    }
@@ -500,7 +499,7 @@
    ///
    /// [header size constraints]: https://www.rfc-editor.org/rfc/rfc9114.html#name-header-size-constraints
    pub fn max_field_section_size(&mut self, value: u64) -> &mut Self {
-        self.max_field_section_size = value;
+        self.config.max_field_section_size = value;
        self
    }
@@ -521,13 +520,7 @@
        Ok((
            Connection {
-                inner: ConnectionInner::new(
-                    quic,
-                    self.max_field_section_size,
-                    conn_state.clone(),
-                    self.send_grease,
-                )
-                .await?,
+                inner: ConnectionInner::new(quic, conn_state.clone(), self.config).await?,
                sent_closing: None,
                recv_closing: None,
            },
@@ -535,10 +528,10 @@
                open,
                conn_state,
                conn_waker,
-                max_field_section_size: self.max_field_section_size,
+                max_field_section_size: self.config.max_field_section_size,
                sender_count: Arc::new(AtomicUsize::new(1)),
+                send_grease_frame: self.config.send_grease,
                _buf: PhantomData,
-                send_grease_frame: self.send_grease,
            },
        ))
    }
@@ -571,7 +564,6 @@
/// # async fn doc<T, B>(mut req_stream: RequestStream<T, B>) -> Result<(), Box<dyn std::error::Error>>
/// # where
/// #     T: quic::RecvStream,
-/// #     B: Buf,
/// # {
/// // Prepare the HTTP request to send to the server
/// let request = Request::get("https://www.example.com/").body(())?;
diff --git a/h3/src/config.rs b/h3/src/config.rs
new file mode 100644
index 00000000..a1dbe89b
--- /dev/null
+++ b/h3/src/config.rs
@@ -0,0 +1,65 @@
+use crate::proto::varint::VarInt;
+
+/// Configures the HTTP/3 connection
+#[derive(Debug, Clone, Copy)]
+#[non_exhaustive]
+pub struct Config {
+    /// Just like in HTTP/2, HTTP/3 also uses the concept of "grease"
+    /// to prevent potential interoperability issues in the future.
+    /// In HTTP/3, the concept of grease is used to ensure that the protocol can evolve
+    /// and accommodate future changes without breaking existing implementations.
+    pub(crate) send_grease: bool,
+    /// `SETTINGS_MAX_FIELD_SECTION_SIZE` advertises the largest field section (the decoded
+    /// header list of a single request or response) that this endpoint is willing to accept,
+    /// in bytes. Header compression in HTTP/3 uses QPACK; this setting bounds the size of a
+    /// single field section, not the QPACK dynamic table.
+    pub(crate) max_field_section_size: u64,
+
+    /// https://datatracker.ietf.org/doc/html/draft-ietf-webtrans-http3/#section-3.1
+    /// Sets `SETTINGS_ENABLE_WEBTRANSPORT` if enabled
+    pub(crate) enable_webtransport: bool,
+    /// https://www.rfc-editor.org/info/rfc8441 defines an extended CONNECT method in Section 4,
+    /// enabled by the SETTINGS_ENABLE_CONNECT_PROTOCOL parameter.
+    /// That parameter is only defined for HTTP/2. The WebTransport over HTTP/3 draft does not
+    /// define a corresponding parameter for extended CONNECT in HTTP/3; instead, the SETTINGS_ENABLE_WEBTRANSPORT setting implies that an endpoint supports extended CONNECT.
+    pub(crate) enable_extended_connect: bool,
+    /// Enable HTTP Datagrams, see https://datatracker.ietf.org/doc/rfc9297/ for details
+    pub(crate) enable_datagram: bool,
+    /// The maximum number of concurrent WebTransport sessions the peer may open.
+    pub(crate) max_webtransport_sessions: u64,
+}
+
+impl Config {
+    /// https://datatracker.ietf.org/doc/html/draft-ietf-webtrans-http3/#section-3.1
+    /// Sets `SETTINGS_ENABLE_WEBTRANSPORT` if enabled
+    pub fn enable_webtransport(&self) -> bool {
+        self.enable_webtransport
+    }
+
+    /// Enable HTTP Datagrams, see https://datatracker.ietf.org/doc/rfc9297/ for details
+    pub fn enable_datagram(&self) -> bool {
+        self.enable_datagram
+    }
+
+    /// https://www.rfc-editor.org/info/rfc8441 defines an extended CONNECT method in Section 4,
+    /// enabled by the SETTINGS_ENABLE_CONNECT_PROTOCOL parameter.
+    /// That parameter is only defined for HTTP/2. The WebTransport over HTTP/3 draft does not
+    /// define a corresponding parameter for extended CONNECT in HTTP/3; instead, the SETTINGS_ENABLE_WEBTRANSPORT setting implies that an endpoint supports extended CONNECT.
+    pub fn enable_extended_connect(&self) -> bool {
+        self.enable_extended_connect
+    }
+}
+
+impl Default for Config {
+    fn default() -> Self {
+        Self {
+            max_field_section_size: VarInt::MAX.0,
+            send_grease: true,
+            enable_webtransport: false,
+            enable_extended_connect: false,
+            enable_datagram: false,
+            max_webtransport_sessions: 0,
+        }
+    }
+}
diff --git a/h3/src/connection.rs b/h3/src/connection.rs
index abf7cd92..8c455657 100644
--- a/h3/src/connection.rs
+++ b/h3/src/connection.rs
@@ -7,9 +7,11 @@ use std::{
};
use bytes::{Buf, Bytes, BytesMut};
use futures_util::{future, ready};
use http::HeaderMap;
-use tracing::warn;
+use stream::WriteBuf;
+use tracing::{trace, warn};
use crate::{
+    config::Config,
    error::{Code, Error},
    frame::FrameStream,
    proto::{
@@ -20,13 +22,15 @@ use crate::{
    },
    qpack,
    quic::{self, SendStream as _},
-    stream::{self, AcceptRecvStream, AcceptedRecvStream},
+    stream::{self, AcceptRecvStream, AcceptedRecvStream, BufRecvStream, UniStreamHeader},
+    webtransport::SessionId,
};

#[doc(hidden)]
+#[non_exhaustive]
pub struct SharedState {
-    // maximum size for a header we send
-    pub peer_max_field_section_size: u64,
+    // Peer settings
+    pub peer_config: Config,
    // connection-wide error, concerns all RequestStreams and drivers
    pub error: Option<Error>,
    // Has a GOAWAY frame been sent or received?
@@ -50,13 +54,14 @@ impl SharedStateRef { impl Default for SharedStateRef { fn default() -> Self { Self(Arc::new(RwLock::new(SharedState { - peer_max_field_section_size: VarInt::MAX.0, + peer_config: Config::default(), error: None, closing: false, }))) } } +#[allow(missing_docs)] pub trait ConnectionState { fn shared_state(&self) -> &SharedStateRef; @@ -69,33 +74,76 @@ pub trait ConnectionState { } } +#[allow(missing_docs)] +pub struct AcceptedStreams +where + C: quic::Connection, + B: Buf, +{ + #[allow(missing_docs)] + pub wt_uni_streams: Vec<(SessionId, BufRecvStream)>, +} + +impl Default for AcceptedStreams +where + C: quic::Connection, + B: Buf, +{ + fn default() -> Self { + Self { + wt_uni_streams: Default::default(), + } + } +} + +#[allow(missing_docs)] pub struct ConnectionInner where C: quic::Connection, B: Buf, { pub(super) shared: SharedStateRef, - conn: C, + /// TODO: breaking encapsulation just to see if we can get this to work, will fix before merging + pub conn: C, control_send: C::SendStream, control_recv: Option>, decoder_recv: Option>, encoder_recv: Option>, - pending_recv_streams: Vec>, + /// Buffers incoming uni/recv streams which have yet to be claimed. + /// + /// This is opposed to discarding them by returning in `poll_accept_recv`, which may cause them to be missed by something else polling. + /// + /// See: https://datatracker.ietf.org/doc/html/draft-ietf-webtrans-http3/#section-4.5 + /// + /// In WebTransport over HTTP/3, the client MAY send its SETTINGS frame, as well as + /// multiple WebTransport CONNECT requests, WebTransport data streams and WebTransport + /// datagrams, all within a single flight. As those can arrive out of order, a WebTransport + /// server could be put into a situation where it receives a stream or a datagram without a + /// corresponding session. Similarly, a client may receive a server-initiated stream or a + /// datagram before receiving the CONNECT response headers from the server.To handle this + /// case, WebTransport endpoints SHOULD buffer streams and datagrams until those can be + /// associated with an established session. To avoid resource exhaustion, the endpoints + /// MUST limit the number of buffered streams and datagrams. When the number of buffered + /// streams is exceeded, a stream SHALL be closed by sending a RESET_STREAM and/or + /// STOP_SENDING with the H3_WEBTRANSPORT_BUFFERED_STREAM_REJECTED error code. When the + /// number of buffered datagrams is exceeded, a datagram SHALL be dropped. It is up to an + /// implementation to choose what stream or datagram to discard. 
+ accepted_streams: AcceptedStreams, + + pending_recv_streams: Vec>, + got_peer_settings: bool, - pub(super) send_grease_frame: bool, + pub send_grease_frame: bool, + pub config: Config, } -impl ConnectionInner +impl ConnectionInner where C: quic::Connection, B: Buf, { - pub async fn new( - mut conn: C, - max_field_section_size: u64, - shared: SharedStateRef, - grease: bool, - ) -> Result { + /// Initiates the connection and opens a control stream + pub async fn new(mut conn: C, shared: SharedStateRef, config: Config) -> Result { //= https://www.rfc-editor.org/rfc/rfc9114#section-6.2 //# Endpoints SHOULD create the HTTP control stream as well as the //# unidirectional streams required by mandatory extensions (such as the @@ -106,11 +154,33 @@ where .map_err(|e| Code::H3_STREAM_CREATION_ERROR.with_transport(e))?; let mut settings = Settings::default(); + + settings + .insert( + SettingId::MAX_HEADER_LIST_SIZE, + config.max_field_section_size, + ) + .map_err(|e| Code::H3_INTERNAL_ERROR.with_cause(e))?; + + settings + .insert( + SettingId::ENABLE_CONNECT_PROTOCOL, + config.enable_extended_connect as u64, + ) + .map_err(|e| Code::H3_INTERNAL_ERROR.with_cause(e))?; + settings + .insert( + SettingId::ENABLE_WEBTRANSPORT, + config.enable_webtransport as u64, + ) + .map_err(|e| Code::H3_INTERNAL_ERROR.with_cause(e))?; settings - .insert(SettingId::MAX_HEADER_LIST_SIZE, max_field_section_size) + .insert(SettingId::H3_DATAGRAM, config.enable_datagram as u64) .map_err(|e| Code::H3_INTERNAL_ERROR.with_cause(e))?; - if grease { + tracing::debug!("Sending server settings: {:#x?}", settings); + + if config.send_grease { // Grease Settings (https://www.rfc-editor.org/rfc/rfc9114.html#name-defined-settings-parameters) //= https://www.rfc-editor.org/rfc/rfc9114#section-7.2.4.1 //# Setting identifiers of the format 0x1f * N + 0x21 for non-negative @@ -156,9 +226,10 @@ where //# Endpoints MUST NOT require any data to be received from //# the peer prior to sending the SETTINGS frame; settings MUST be sent //# as soon as the transport is ready to send data. + trace!("Sending Settings frame: {:#x?}", settings); stream::write( &mut control_send, - (StreamType::CONTROL, Frame::Settings(settings)), + WriteBuf::from(UniStreamHeader::Control(settings)), ) .await?; @@ -176,10 +247,12 @@ where encoder_recv: None, pending_recv_streams: Vec::with_capacity(3), got_peer_settings: false, - send_grease_frame: grease, + send_grease_frame: config.send_grease, + config, + accepted_streams: Default::default(), }; // start a grease stream - if grease { + if config.send_grease { //= https://www.rfc-editor.org/rfc/rfc9114#section-7.2.8 //= type=implication //# Frame types of the format 0x1f * N + 0x21 for non-negative integer @@ -193,6 +266,7 @@ where } /// Send GOAWAY with specified max_id, iff max_id is smaller than the previous one. + pub async fn shutdown( &mut self, sent_closing: &mut Option, @@ -220,6 +294,7 @@ where stream::write(&mut self.control_send, Frame::Goaway(max_id.into())).await } + #[allow(missing_docs)] pub fn poll_accept_request( &mut self, cx: &mut Context<'_>, @@ -231,26 +306,31 @@ where } } + // Accept the request by accepting the next bidirectional stream // .into().into() converts the impl QuicError into crate::error::Error. // The `?` operator doesn't work here for some reason. 
self.conn.poll_accept_bidi(cx).map_err(|e| e.into().into()) } - pub fn poll_accept_recv(&mut self, cx: &mut Context<'_>) -> Poll> { + /// Polls incoming streams + /// + /// Accepted streams which are not control, decoder, or encoder streams are buffer in `accepted_recv_streams` + pub fn poll_accept_recv(&mut self, cx: &mut Context<'_>) -> Result<(), Error> { if let Some(ref e) = self.shared.read("poll_accept_request").error { - return Poll::Ready(Err(e.clone())); + return Err(e.clone()); } + // Get all currently pending streams loop { match self.conn.poll_accept_recv(cx)? { Poll::Ready(Some(stream)) => self .pending_recv_streams .push(AcceptRecvStream::new(stream)), Poll::Ready(None) => { - return Poll::Ready(Err(Code::H3_GENERAL_PROTOCOL_ERROR.with_reason( + return Err(Code::H3_GENERAL_PROTOCOL_ERROR.with_reason( "Connection closed unexpected", crate::error::ErrorLevel::ConnectionError, - ))) + )) } Poll::Pending => break, } @@ -275,6 +355,7 @@ where .pending_recv_streams .remove(index - removed) .into_stream()?; + match stream { //= https://www.rfc-editor.org/rfc/rfc9114#section-6.2.1 //# Only one control stream per peer is permitted; @@ -282,26 +363,31 @@ where //# treated as a connection error of type H3_STREAM_CREATION_ERROR. AcceptedRecvStream::Control(s) => { if self.control_recv.is_some() { - return Poll::Ready(Err( + return Err( self.close(Code::H3_STREAM_CREATION_ERROR, "got two control streams") - )); + ); } self.control_recv = Some(s); } enc @ AcceptedRecvStream::Encoder(_) => { if let Some(_prev) = self.encoder_recv.replace(enc) { - return Poll::Ready(Err( + return Err( self.close(Code::H3_STREAM_CREATION_ERROR, "got two encoder streams") - )); + ); } } dec @ AcceptedRecvStream::Decoder(_) => { if let Some(_prev) = self.decoder_recv.replace(dec) { - return Poll::Ready(Err( + return Err( self.close(Code::H3_STREAM_CREATION_ERROR, "got two decoder streams") - )); + ); } } + AcceptedRecvStream::WebTransportUni(id, s) if self.config.enable_webtransport => { + // Store until someone else picks it up, like a webtransport session which is + // not yet established. + self.accepted_streams.wt_uni_streams.push((id, s)) + } //= https://www.rfc-editor.org/rfc/rfc9114#section-6.2.3 //= type=implication @@ -311,28 +397,27 @@ where } } - Poll::Pending + Ok(()) } + /// Waits for the control stream to be received and reads subsequent frames. pub fn poll_control(&mut self, cx: &mut Context<'_>) -> Poll, Error>> { if let Some(ref e) = self.shared.read("poll_accept_request").error { return Poll::Ready(Err(e.clone())); } - loop { - match self.poll_accept_recv(cx) { - Poll::Ready(Ok(_)) => continue, - Poll::Ready(Err(e)) => return Poll::Ready(Err(e)), - Poll::Pending if self.control_recv.is_none() => return Poll::Pending, - _ => break, + let recv = { + // TODO + self.poll_accept_recv(cx)?; + if let Some(v) = &mut self.control_recv { + v + } else { + // Try later + return Poll::Pending; } - } + }; - let recvd = ready!(self - .control_recv - .as_mut() - .expect("control_recv") - .poll_next(cx))?; + let recvd = ready!(recv.poll_next(cx))?; let res = match recvd { //= https://www.rfc-editor.org/rfc/rfc9114#section-6.2.1 @@ -367,11 +452,26 @@ where //= type=implication //# Endpoints MUST NOT consider such settings to have //# any meaning upon receipt. 
- self.shared - .write("connection settings write") - .peer_max_field_section_size = settings + let mut shared = self.shared.write("connection settings write"); + shared.peer_config.max_field_section_size = settings .get(SettingId::MAX_HEADER_LIST_SIZE) .unwrap_or(VarInt::MAX.0); + + shared.peer_config.enable_webtransport = + settings.get(SettingId::ENABLE_WEBTRANSPORT).unwrap_or(0) != 0; + + shared.peer_config.max_webtransport_sessions = settings + .get(SettingId::WEBTRANSPORT_MAX_SESSIONS) + .unwrap_or(0); + + shared.peer_config.enable_datagram = + settings.get(SettingId::H3_DATAGRAM).unwrap_or(0) != 0; + + shared.peer_config.enable_extended_connect = settings + .get(SettingId::ENABLE_CONNECT_PROTOCOL) + .unwrap_or(0) + != 0; + Ok(Frame::Settings(settings)) } f @ Frame::Goaway(_) => Ok(f), @@ -524,8 +624,14 @@ where Err(err) => warn!("grease stream error on close {}", err), }; } + + #[allow(missing_docs)] + pub fn accepted_streams_mut(&mut self) -> &mut AcceptedStreams { + &mut self.accepted_streams + } } +#[allow(missing_docs)] pub struct RequestStream { pub(super) stream: FrameStream, pub(super) trailers: Option, @@ -535,6 +641,7 @@ pub struct RequestStream { } impl RequestStream { + #[allow(missing_docs)] pub fn new( stream: FrameStream, max_field_section_size: u64, @@ -675,6 +782,7 @@ where Ok(Some(Header::try_from(fields)?.into_fields())) } + #[allow(missing_docs)] pub fn stop_sending(&mut self, err_code: Code) { self.stream.stop_sending(err_code); } @@ -707,7 +815,8 @@ where let max_mem_size = self .conn_state .read("send_trailers shared state read") - .peer_max_field_section_size; + .peer_config + .max_field_section_size; //= https://www.rfc-editor.org/rfc/rfc9114#section-4.2.2 //# An implementation that @@ -729,6 +838,7 @@ where self.stream.reset(code.into()); } + #[allow(missing_docs)] pub async fn finish(&mut self) -> Result<(), Error> { if self.send_grease_frame { // send a grease frame once per Connection @@ -737,9 +847,7 @@ where .map_err(|e| self.maybe_conn_err(e))?; self.send_grease_frame = false; } - future::poll_fn(|cx| self.stream.poll_ready(cx)) - .await - .map_err(|e| self.maybe_conn_err(e))?; + future::poll_fn(|cx| self.stream.poll_finish(cx)) .await .map_err(|e| self.maybe_conn_err(e)) diff --git a/h3/src/error.rs b/h3/src/error.rs index 5c693cd6..02663f7d 100644 --- a/h3/src/error.rs +++ b/h3/src/error.rs @@ -12,6 +12,7 @@ pub(crate) type TransportError = Box; /// A general error that can occur when handling the HTTP/3 protocol. #[derive(Clone)] pub struct Error { + /// The error kind. pub(crate) inner: Box, } @@ -37,6 +38,7 @@ impl PartialEq for Code { } } +/// The error kind. #[derive(Clone)] pub(crate) struct ErrorImpl { pub(crate) kind: Kind, @@ -110,6 +112,9 @@ macro_rules! codes { } codes! { + /// Datagram or capsule parse error + /// See: https://www.rfc-editor.org/rfc/rfc9297#section-5.2 + (0x33, H3_DATAGRAM_ERROR); /// No error. This is used when the connection or stream needs to be /// closed, but there is no error to signal. (0x100, H3_NO_ERROR); @@ -272,7 +277,6 @@ impl Error { matches!(&self.inner.kind, Kind::HeaderTooBig { .. }) } - #[cfg(test)] #[doc(hidden)] pub fn kind(&self) -> Kind { self.inner.kind.clone() diff --git a/h3/src/ext.rs b/h3/src/ext.rs new file mode 100644 index 00000000..1ef8a14e --- /dev/null +++ b/h3/src/ext.rs @@ -0,0 +1,109 @@ +//! Extensions for the HTTP/3 protocol. 
+
+use std::convert::TryFrom;
+use std::str::FromStr;
+
+use bytes::{Buf, Bytes};
+
+use crate::{
+    error::Code,
+    proto::{stream::StreamId, varint::VarInt},
+    Error,
+};
+
+/// Describes the `:protocol` pseudo-header for extended CONNECT
+///
+/// See: [https://www.rfc-editor.org/rfc/rfc8441#section-4]
+#[derive(Copy, PartialEq, Debug, Clone)]
+pub struct Protocol(ProtocolInner);
+
+impl Protocol {
+    /// WebTransport protocol
+    pub const WEB_TRANSPORT: Protocol = Protocol(ProtocolInner::WebTransport);
+}
+
+#[derive(Copy, PartialEq, Debug, Clone)]
+enum ProtocolInner {
+    WebTransport,
+}
+
+/// Error when parsing the protocol
+pub struct InvalidProtocol;
+
+impl FromStr for Protocol {
+    type Err = InvalidProtocol;
+
+    fn from_str(s: &str) -> Result<Self, Self::Err> {
+        match s {
+            "webtransport" => Ok(Self(ProtocolInner::WebTransport)),
+            _ => Err(InvalidProtocol),
+        }
+    }
+}
+
+/// HTTP datagram frames
+/// See: https://www.rfc-editor.org/rfc/rfc9297#section-2.1
+pub struct Datagram<B = Bytes> {
+    /// The ID of the request stream this datagram is associated with
+    stream_id: StreamId,
+    /// The data contained in the datagram
+    payload: B,
+}
+
+impl<B> Datagram<B>
+where
+    B: Buf,
+{
+    /// Creates a new datagram frame
+    pub fn new(stream_id: StreamId, payload: B) -> Self {
+        assert!(
+            stream_id.into_inner() % 4 == 0,
+            "StreamId is not divisible by 4"
+        );
+        Self { stream_id, payload }
+    }
+
+    /// Decodes a datagram frame from the QUIC datagram
+    pub fn decode(mut buf: B) -> Result<Self, Error> {
+        let q_stream_id = VarInt::decode(&mut buf)
+            .map_err(|_| Code::H3_DATAGRAM_ERROR.with_cause("Malformed datagram frame"))?;
+
+        //= https://www.rfc-editor.org/rfc/rfc9297#section-2.1
+        // Quarter Stream ID: A variable-length integer that contains the value of the client-initiated bidirectional
+        // stream that this datagram is associated with divided by four (the division by four stems
+        // from the fact that HTTP requests are sent on client-initiated bidirectional streams,
+        // which have stream IDs that are divisible by four). The largest legal QUIC stream ID
+        // value is 2^62 - 1, so the largest legal value of the Quarter Stream ID field is 2^60 - 1.
+        // Receipt of an HTTP/3 Datagram that includes a larger value MUST be treated as an HTTP/3
+        // connection error of type H3_DATAGRAM_ERROR (0x33).
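+        // Worked example (illustrative only): client-initiated bidirectional streams have
+        // IDs 0, 4, 8, ..., so a datagram tied to stream 8 is sent with Quarter Stream ID
+        // 8 / 4 = 2, and the multiplication by 4 below recovers stream ID 8:
+        //
+        //     let d = Datagram::new(StreamId::try_from(8u64).unwrap(), Bytes::from_static(b"ping"));
+        //     let mut wire = bytes::BytesMut::new();
+        //     d.encode(&mut wire); // writes Quarter Stream ID 2, then the payload
+        //     assert!(Datagram::decode(wire.freeze()).unwrap().stream_id() == StreamId::try_from(8u64).unwrap());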
+ let stream_id = StreamId::try_from(u64::from(q_stream_id) * 4) + .map_err(|_| Code::H3_DATAGRAM_ERROR.with_cause("Invalid stream id"))?; + + let payload = buf; + + Ok(Self { stream_id, payload }) + } + + #[inline] + /// Returns the associated stream id of the datagram + pub fn stream_id(&self) -> StreamId { + self.stream_id + } + + #[inline] + /// Returns the datagram payload + pub fn payload(&self) -> &B { + &self.payload + } + + /// Encode the datagram to wire format + pub fn encode(self, buf: &mut D) { + (VarInt::from(self.stream_id) / 4).encode(buf); + buf.put(self.payload); + } + + /// Returns the datagram payload + pub fn into_payload(self) -> B { + self.payload + } +} diff --git a/h3/src/frame.rs b/h3/src/frame.rs index 15d3f634..3c8a58dc 100644 --- a/h3/src/frame.rs +++ b/h3/src/frame.rs @@ -1,11 +1,11 @@ -use std::marker::PhantomData; use std::task::{Context, Poll}; -use bytes::{Buf, Bytes}; +use bytes::Buf; use futures_util::ready; use tracing::trace; +use crate::stream::{BufRecvStream, WriteBuf}; use crate::{ buf::BufList, error::TransportError, @@ -14,34 +14,30 @@ use crate::{ stream::StreamId, }, quic::{BidiStream, RecvStream, SendStream}, - stream::WriteBuf, }; +/// Decodes Frames from the underlying QUIC stream pub struct FrameStream { - stream: S, - bufs: BufList, + pub stream: BufRecvStream, + // Already read data from the stream decoder: FrameDecoder, remaining_data: usize, - /// Set to true when `stream` reaches the end. - is_eos: bool, - _phantom_buffer: PhantomData, } impl FrameStream { - pub fn new(stream: S) -> Self { - Self::with_bufs(stream, BufList::new()) - } - - pub(crate) fn with_bufs(stream: S, bufs: BufList) -> Self { + pub fn new(stream: BufRecvStream) -> Self { Self { stream, - bufs, decoder: FrameDecoder::default(), remaining_data: 0, - is_eos: false, - _phantom_buffer: PhantomData, } } + + /// Unwraps the Framed streamer and returns the underlying stream **without** data loss for + /// partially received/read frames. + pub fn into_inner(self) -> BufRecvStream { + self.stream + } } impl FrameStream @@ -60,18 +56,22 @@ where loop { let end = self.try_recv(cx)?; - return match self.decoder.decode(&mut self.bufs)? { + return match self.decoder.decode(self.stream.buf_mut())? { Some(Frame::Data(PayloadLen(len))) => { self.remaining_data = len; Poll::Ready(Ok(Some(Frame::Data(PayloadLen(len))))) } + frame @ Some(Frame::WebTransportStream(_)) => { + self.remaining_data = usize::MAX; + Poll::Ready(Ok(frame)) + } Some(frame) => Poll::Ready(Ok(Some(frame))), None => match end { // Received a chunk but frame is incomplete, poll until we get `Pending`. Poll::Ready(false) => continue, Poll::Pending => Poll::Pending, Poll::Ready(true) => { - if self.bufs.has_remaining() { + if self.stream.buf_mut().has_remaining() { // Reached the end of receive stream, but there is still some data: // The frame is incomplete. Poll::Ready(Err(FrameStreamError::UnexpectedEnd)) @@ -84,6 +84,10 @@ where } } + /// Retrieves the next piece of data in an incoming data packet or webtransport stream + /// + /// + /// WebTransport bidirectional payload has no finite length and is processed until the end of the stream. 
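+    ///
+    /// For a regular DATA frame the expected call pattern is: `poll_next` yields
+    /// `Frame::Data(PayloadLen(n))`, after which this method is polled repeatedly and
+    /// returns chunks until `n` bytes of payload have been handed out.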
pub fn poll_data( &mut self, cx: &mut Context<'_>, @@ -93,13 +97,14 @@ where }; let end = ready!(self.try_recv(cx))?; - let data = self.bufs.take_chunk(self.remaining_data); + let data = self.stream.buf_mut().take_chunk(self.remaining_data); match (data, end) { (None, true) => Poll::Ready(Ok(None)), (None, false) => Poll::Pending, (Some(d), true) - if d.remaining() < self.remaining_data && !self.bufs.has_remaining() => + if d.remaining() < self.remaining_data + && !self.stream.buf_mut().has_remaining() => { Poll::Ready(Err(FrameStreamError::UnexpectedEnd)) } @@ -110,6 +115,7 @@ where } } + /// Stops the underlying stream with the provided error code pub(crate) fn stop_sending(&mut self, error_code: crate::error::Code) { self.stream.stop_sending(error_code.into()); } @@ -119,26 +125,23 @@ where } pub(crate) fn is_eos(&self) -> bool { - self.is_eos && !self.bufs.has_remaining() + self.stream.is_eos() && !self.stream.buf().has_remaining() } fn try_recv(&mut self, cx: &mut Context<'_>) -> Poll> { - if self.is_eos { + if self.stream.is_eos() { return Poll::Ready(Ok(true)); } - match self.stream.poll_data(cx) { + match self.stream.poll_read(cx) { Poll::Ready(Err(e)) => Poll::Ready(Err(FrameStreamError::Quic(e.into()))), Poll::Pending => Poll::Pending, - Poll::Ready(Ok(None)) => { - self.is_eos = true; - Poll::Ready(Ok(true)) - } - Poll::Ready(Ok(Some(mut d))) => { - self.bufs.push_bytes(&mut d); - Poll::Ready(Ok(false)) - } + Poll::Ready(Ok(eos)) => Poll::Ready(Ok(eos)), } } + + pub fn id(&self) -> StreamId { + self.stream.recv_id() + } } impl SendStream for FrameStream @@ -164,8 +167,8 @@ where self.stream.reset(reset_code) } - fn id(&self) -> StreamId { - self.stream.id() + fn send_id(&self) -> StreamId { + self.stream.send_id() } } @@ -179,19 +182,13 @@ where ( FrameStream { stream: send, - bufs: BufList::new(), decoder: FrameDecoder::default(), remaining_data: 0, - is_eos: false, - _phantom_buffer: PhantomData, }, FrameStream { stream: recv, - bufs: self.bufs, decoder: self.decoder, remaining_data: self.remaining_data, - is_eos: self.is_eos, - _phantom_buffer: PhantomData, }, ) } @@ -266,7 +263,7 @@ mod tests { use super::*; use assert_matches::assert_matches; - use bytes::{BufMut, BytesMut}; + use bytes::{BufMut, Bytes, BytesMut}; use futures_util::future::poll_fn; use std::{collections::VecDeque, fmt, sync::Arc}; @@ -373,7 +370,7 @@ mod tests { Frame::headers(&b"trailer"[..]).encode_with_payload(&mut buf); recv.chunk(buf.freeze()); - let mut stream: FrameStream<_, ()> = FrameStream::new(recv); + let mut stream: FrameStream<_, ()> = FrameStream::new(BufRecvStream::new(recv)); assert_poll_matches!(|cx| stream.poll_next(cx), Ok(Some(Frame::Headers(_)))); assert_poll_matches!( @@ -395,7 +392,7 @@ mod tests { Frame::headers(&b"header"[..]).encode_with_payload(&mut buf); let mut buf = buf.freeze(); recv.chunk(buf.split_to(buf.len() - 1)); - let mut stream: FrameStream<_, ()> = FrameStream::new(recv); + let mut stream: FrameStream<_, ()> = FrameStream::new(BufRecvStream::new(recv)); assert_poll_matches!( |cx| stream.poll_next(cx), @@ -414,7 +411,7 @@ mod tests { FrameType::DATA.encode(&mut buf); VarInt::from(4u32).encode(&mut buf); recv.chunk(buf.freeze()); - let mut stream: FrameStream<_, ()> = FrameStream::new(recv); + let mut stream: FrameStream<_, ()> = FrameStream::new(BufRecvStream::new(recv)); assert_poll_matches!( |cx| stream.poll_next(cx), @@ -436,7 +433,7 @@ mod tests { let mut buf = buf.freeze(); recv.chunk(buf.split_to(buf.len() - 2)); recv.chunk(buf); - let mut stream: FrameStream<_, 
()> = FrameStream::new(recv); + let mut stream: FrameStream<_, ()> = FrameStream::new(BufRecvStream::new(recv)); // We get the total size of data about to be received assert_poll_matches!( @@ -465,7 +462,7 @@ mod tests { VarInt::from(4u32).encode(&mut buf); buf.put_slice(&b"b"[..]); recv.chunk(buf.freeze()); - let mut stream: FrameStream<_, ()> = FrameStream::new(recv); + let mut stream: FrameStream<_, ()> = FrameStream::new(BufRecvStream::new(recv)); assert_poll_matches!( |cx| stream.poll_next(cx), @@ -497,7 +494,7 @@ mod tests { Frame::Data(Bytes::from("body")).encode_with_payload(&mut buf); recv.chunk(buf.freeze()); - let mut stream: FrameStream<_, ()> = FrameStream::new(recv); + let mut stream: FrameStream<_, ()> = FrameStream::new(BufRecvStream::new(recv)); assert_poll_matches!( |cx| stream.poll_next(cx), @@ -519,7 +516,7 @@ mod tests { buf.put_slice(&b"bo"[..]); recv.chunk(buf.clone().freeze()); - let mut stream: FrameStream<_, ()> = FrameStream::new(recv); + let mut stream: FrameStream<_, ()> = FrameStream::new(BufRecvStream::new(recv)); assert_poll_matches!( |cx| stream.poll_next(cx), @@ -528,7 +525,7 @@ mod tests { buf.truncate(0); buf.put_slice(&b"dy"[..]); - stream.bufs.push_bytes(&mut buf.freeze()); + stream.stream.buf_mut().push_bytes(&mut buf.freeze()); assert_poll_matches!( |cx| to_bytes(stream.poll_data(cx)), @@ -569,6 +566,10 @@ mod tests { fn stop_sending(&mut self, _: u64) { unimplemented!() } + + fn recv_id(&self) -> StreamId { + unimplemented!() + } } #[derive(Debug)] diff --git a/h3/src/lib.rs b/h3/src/lib.rs index 1cb59cd8..7fb6496a 100644 --- a/h3/src/lib.rs +++ b/h3/src/lib.rs @@ -3,22 +3,47 @@ #![allow(clippy::derive_partial_eq_without_eq)] pub mod client; +mod config; pub mod error; +pub mod ext; pub mod quic; +pub(crate) mod request; pub mod server; pub use error::Error; mod buf; + +#[cfg(feature = "i-implement-a-third-party-backend-and-opt-into-breaking-changes")] +#[allow(missing_docs)] +pub mod connection; +#[cfg(feature = "i-implement-a-third-party-backend-and-opt-into-breaking-changes")] +#[allow(missing_docs)] +pub mod frame; +#[cfg(feature = "i-implement-a-third-party-backend-and-opt-into-breaking-changes")] +#[allow(missing_docs)] +pub mod proto; +#[cfg(feature = "i-implement-a-third-party-backend-and-opt-into-breaking-changes")] +#[allow(missing_docs)] +pub mod stream; +#[cfg(feature = "i-implement-a-third-party-backend-and-opt-into-breaking-changes")] +#[allow(missing_docs)] +pub mod webtransport; + +#[cfg(not(feature = "i-implement-a-third-party-backend-and-opt-into-breaking-changes"))] mod connection; +#[cfg(not(feature = "i-implement-a-third-party-backend-and-opt-into-breaking-changes"))] mod frame; +#[cfg(not(feature = "i-implement-a-third-party-backend-and-opt-into-breaking-changes"))] mod proto; -#[allow(dead_code)] -mod qpack; +#[cfg(not(feature = "i-implement-a-third-party-backend-and-opt-into-breaking-changes"))] mod stream; +#[cfg(not(feature = "i-implement-a-third-party-backend-and-opt-into-breaking-changes"))] +mod webtransport; +#[allow(dead_code)] +mod qpack; #[cfg(test)] mod tests; - #[cfg(test)] extern crate self as h3; diff --git a/h3/src/proto/frame.rs b/h3/src/proto/frame.rs index 340a9598..f587d01a 100644 --- a/h3/src/proto/frame.rs +++ b/h3/src/proto/frame.rs @@ -1,9 +1,14 @@ use bytes::{Buf, BufMut, Bytes}; -use std::{convert::TryInto, fmt}; +use std::{ + convert::TryInto, + fmt::{self, Debug}, +}; use tracing::trace; +use crate::webtransport::SessionId; + use super::{ - coding::Encode, + coding::{Decode, Encode}, 
push::{InvalidPushId, PushId}, stream::InvalidStreamId, varint::{BufExt, BufMutExt, UnexpectedEnd, VarInt}, @@ -46,13 +51,21 @@ pub enum Frame { PushPromise(PushPromise), Goaway(VarInt), MaxPushId(PushId), + /// Describes the header for a webtransport stream. + /// + /// The payload is sent streaming until the stream is closed + /// + /// Unwrap the framed streamer and read the inner stream until the end. + /// + /// Conversely, when sending, send this frame and unwrap the stream + WebTransportStream(SessionId), Grease, } /// Represents the available data len for a `Data` frame on a RecvStream /// /// Decoding received frames does not handle `Data` frames payload. Instead, receiving it -/// and passing it to the user is left under the responsability of `RequestStream`s. +/// and passing it to the user is left under the responsibility of `RequestStream`s. pub struct PayloadLen(pub usize); impl From for PayloadLen { @@ -62,11 +75,21 @@ impl From for PayloadLen { } impl Frame { - pub const MAX_ENCODED_SIZE: usize = VarInt::MAX_SIZE * 3; + pub const MAX_ENCODED_SIZE: usize = VarInt::MAX_SIZE * 7; + /// Decodes a Frame from the stream according to https://www.rfc-editor.org/rfc/rfc9114#section-7.1 pub fn decode(buf: &mut T) -> Result { let remaining = buf.remaining(); let ty = FrameType::decode(buf).map_err(|_| FrameError::Incomplete(remaining + 1))?; + + // Webtransport streams need special handling as they have no length. + // + // See: https://datatracker.ietf.org/doc/html/draft-ietf-webtrans-http3/#section-4.2 + if ty == FrameType::WEBTRANSPORT_BI_STREAM { + tracing::trace!("webtransport frame"); + return Ok(Frame::WebTransportStream(SessionId::decode(buf)?)); + } + let len = buf .get_var() .map_err(|_| FrameError::Incomplete(remaining + 1))?; @@ -80,6 +103,7 @@ impl Frame { } let mut payload = buf.take(len as usize); + trace!("frame ty: {:?}", ty); let frame = match ty { FrameType::HEADERS => Ok(Frame::Headers(payload.copy_to_bytes(len as usize))), FrameType::SETTINGS => Ok(Frame::Settings(Settings::decode(&mut payload)?)), @@ -91,11 +115,13 @@ impl Frame { | FrameType::H2_PING | FrameType::H2_WINDOW_UPDATE | FrameType::H2_CONTINUATION => Err(FrameError::UnsupportedFrame(ty.0)), + FrameType::WEBTRANSPORT_BI_STREAM | FrameType::DATA => unreachable!(), _ => { buf.advance(len as usize); Err(FrameError::UnknownFrame(ty.0)) } }; + if let Ok(frame) = &frame { trace!( "got frame {:?}, len: {}, remaining: {}", @@ -132,6 +158,11 @@ where buf.write_var(6); buf.put_slice(b"grease"); } + Frame::WebTransportStream(id) => { + FrameType::WEBTRANSPORT_BI_STREAM.encode(buf); + id.encode(buf); + // rest of the data is sent streaming + } } } } @@ -189,6 +220,7 @@ impl fmt::Debug for Frame { Frame::Goaway(id) => write!(f, "GoAway({})", id), Frame::MaxPushId(id) => write!(f, "MaxPushId({})", id), Frame::Grease => write!(f, "Grease()"), + Frame::WebTransportStream(session) => write!(f, "WebTransportStream({:?})", session), } } } @@ -207,6 +239,7 @@ where Frame::Goaway(id) => write!(f, "GoAway({})", id), Frame::MaxPushId(id) => write!(f, "MaxPushId({})", id), Frame::Grease => write!(f, "Grease()"), + Frame::WebTransportStream(_) => write!(f, "WebTransportStream()"), } } } @@ -226,6 +259,9 @@ impl PartialEq> for Frame { Frame::Goaway(x) => matches!(other, Frame::Goaway(y) if x == y), Frame::MaxPushId(x) => matches!(other, Frame::MaxPushId(y) if x == y), Frame::Grease => matches!(other, Frame::Grease), + Frame::WebTransportStream(x) => { + matches!(other, Frame::WebTransportStream(y) if x == y) + } } } } @@ -257,6 
+293,8 @@ frame_types! { H2_WINDOW_UPDATE = 0x8, H2_CONTINUATION = 0x9, MAX_PUSH_ID = 0xD, + // Reserved frame types + WEBTRANSPORT_BI_STREAM = 0x41, } impl FrameType { @@ -350,7 +388,11 @@ impl SettingId { self, SettingId::MAX_HEADER_LIST_SIZE | SettingId::QPACK_MAX_TABLE_CAPACITY - | SettingId::QPACK_MAX_BLOCKED_STREAMS, + | SettingId::QPACK_MAX_BLOCKED_STREAMS + | SettingId::ENABLE_CONNECT_PROTOCOL + | SettingId::ENABLE_WEBTRANSPORT + | SettingId::WEBTRANSPORT_MAX_SESSIONS + | SettingId::H3_DATAGRAM, ) } @@ -389,9 +431,19 @@ setting_identifiers! { QPACK_MAX_TABLE_CAPACITY = 0x1, QPACK_MAX_BLOCKED_STREAMS = 0x7, MAX_HEADER_LIST_SIZE = 0x6, + // https://datatracker.ietf.org/doc/html/rfc9220#section-5 + ENABLE_CONNECT_PROTOCOL = 0x8, + // https://datatracker.ietf.org/doc/html/draft-ietf-masque-h3-datagram-05#section-9.1 + H3_DATAGRAM = 0xFFD277, + // https://datatracker.ietf.org/doc/html/draft-ietf-webtrans-http3/#section-8.2 + ENABLE_WEBTRANSPORT = 0x2B603742, + // https://datatracker.ietf.org/doc/html/draft-ietf-webtrans-http3/#section-8.2 + H3_SETTING_ENABLE_DATAGRAM_CHROME_SPECIFIC= 0xFFD277, + + WEBTRANSPORT_MAX_SESSIONS = 0x2b603743, } -const SETTINGS_LEN: usize = 4; +const SETTINGS_LEN: usize = 8; #[derive(Debug, PartialEq)] pub struct Settings { @@ -446,7 +498,7 @@ impl Settings { None } - pub(super) fn encode(&self, buf: &mut T) { + pub(crate) fn encode(&self, buf: &mut T) { self.encode_header(buf); for (id, val) in self.entries[..self.len].iter() { id.encode(buf); @@ -483,6 +535,8 @@ impl Settings { //# their receipt MUST be treated as a connection error of type //# H3_SETTINGS_ERROR. settings.insert(identifier, value)?; + } else { + tracing::warn!("Unsupported setting: {:#x?}", identifier); } } Ok(settings) @@ -594,6 +648,10 @@ mod tests { (SettingId::QPACK_MAX_TABLE_CAPACITY, 0xfad2), (SettingId::QPACK_MAX_BLOCKED_STREAMS, 0xfad3), (SettingId(95), 0), + (SettingId::NONE, 0), + (SettingId::NONE, 0), + (SettingId::NONE, 0), + (SettingId::NONE, 0), ], len: 4, }), @@ -607,6 +665,10 @@ mod tests { (SettingId::QPACK_MAX_BLOCKED_STREAMS, 0xfad3), // check without the Grease setting because this is ignored (SettingId(0), 0), + (SettingId::NONE, 0), + (SettingId::NONE, 0), + (SettingId::NONE, 0), + (SettingId::NONE, 0), ], len: 3, }), diff --git a/h3/src/proto/headers.rs b/h3/src/proto/headers.rs index 86f8de74..5a4bda13 100644 --- a/h3/src/proto/headers.rs +++ b/h3/src/proto/headers.rs @@ -8,10 +8,10 @@ use std::{ use http::{ header::{self, HeaderName, HeaderValue}, uri::{self, Authority, Parts, PathAndQuery, Scheme, Uri}, - HeaderMap, Method, StatusCode, + Extensions, HeaderMap, Method, StatusCode, }; -use crate::qpack::HeaderField; +use crate::{ext::Protocol, qpack::HeaderField}; #[derive(Debug)] #[cfg_attr(test, derive(PartialEq, Clone))] @@ -22,12 +22,18 @@ pub struct Header { #[allow(clippy::len_without_is_empty)] impl Header { - pub fn request(method: Method, uri: Uri, fields: HeaderMap) -> Result { + /// Creates a new `Header` frame data suitable for sending a request + pub fn request( + method: Method, + uri: Uri, + fields: HeaderMap, + ext: Extensions, + ) -> Result { match (uri.authority(), fields.get("host")) { (None, None) => Err(HeaderError::MissingAuthority), (Some(a), Some(h)) if a.as_str() != h => Err(HeaderError::ContradictedAuthority), _ => Ok(Self { - pseudo: Pseudo::request(method, uri), + pseudo: Pseudo::request(method, uri, ext), fields, }), } @@ -50,7 +56,9 @@ impl Header { } } - pub fn into_request_parts(self) -> Result<(Method, Uri, HeaderMap), HeaderError> { + 
pub fn into_request_parts( + self, + ) -> Result<(Method, Uri, Option, HeaderMap), HeaderError> { let mut uri = Uri::builder(); if let Some(path) = self.pseudo.path { @@ -92,6 +100,7 @@ impl Header { Ok(( self.pseudo.method.ok_or(HeaderError::MissingMethod)?, uri.build().map_err(HeaderError::InvalidRequest)?, + self.pseudo.protocol, self.fields, )) } @@ -221,6 +230,10 @@ impl TryFrom> for Header { Field::Header((n, v)) => { fields.append(n, v); } + Field::Protocol(p) => { + pseudo.protocol = Some(p); + pseudo.len += 1; + } } } @@ -234,6 +247,7 @@ enum Field { Authority(Authority), Path(PathAndQuery), Status(StatusCode), + Protocol(Protocol), Header((HeaderName, HeaderValue)), } @@ -277,6 +291,7 @@ impl Field { StatusCode::from_bytes(value.as_ref()) .map_err(|_| HeaderError::invalid_value(name, value))?, ), + b":protocol" => Field::Protocol(try_value(name, value)?), _ => return Err(HeaderError::invalid_name(name)), }) } @@ -316,12 +331,14 @@ struct Pseudo { // Response status: Option, + protocol: Option, + len: usize, } #[allow(clippy::len_without_is_empty)] impl Pseudo { - fn request(method: Method, uri: Uri) -> Self { + fn request(method: Method, uri: Uri, ext: Extensions) -> Self { let Parts { scheme, authority, @@ -345,7 +362,16 @@ impl Pseudo { }, ); - let len = 3 + if authority.is_some() { 1 } else { 0 }; + // If the method is connect, the `:protocol` pseudo-header MAY be defined + // + // See: [https://www.rfc-editor.org/rfc/rfc8441#section-4] + let protocol = if method == Method::CONNECT { + ext.get::().copied() + } else { + None + }; + + let len = 3 + authority.is_some() as usize + protocol.is_some() as usize; //= https://www.rfc-editor.org/rfc/rfc9114#section-4.3 //= type=implication @@ -364,6 +390,7 @@ impl Pseudo { authority, path: Some(path), status: None, + protocol, len, } } @@ -381,6 +408,7 @@ impl Pseudo { path: None, status: Some(status), len: 1, + protocol: None, } } diff --git a/h3/src/proto/stream.rs b/h3/src/proto/stream.rs index ad53a3d6..2d525167 100644 --- a/h3/src/proto/stream.rs +++ b/h3/src/proto/stream.rs @@ -5,6 +5,8 @@ use std::{ ops::Add, }; +use crate::webtransport::SessionId; + use super::{ coding::{BufExt, BufMutExt, Decode, Encode, UnexpectedEnd}, varint::VarInt, @@ -26,6 +28,8 @@ stream_types! 
{ PUSH = 0x01, ENCODER = 0x02, DECODER = 0x03, + WEBTRANSPORT_BIDI = 0x41, + WEBTRANSPORT_UNI = 0x54, } impl StreamType { @@ -59,6 +63,7 @@ impl fmt::Display for StreamType { &StreamType::CONTROL => write!(f, "Control"), &StreamType::ENCODER => write!(f, "Encoder"), &StreamType::DECODER => write!(f, "Decoder"), + &StreamType::WEBTRANSPORT_UNI => write!(f, "WebTransportUni"), x => write!(f, "StreamType({})", x.0), } } @@ -116,7 +121,7 @@ impl StreamId { } /// Distinguishes streams of the same initiator and directionality - fn index(self) -> u64 { + pub fn index(self) -> u64 { self.0 >> 2 } @@ -128,6 +133,10 @@ impl StreamId { Dir::Uni } } + + pub(crate) fn into_inner(self) -> u64 { + self.0 + } } impl TryFrom for StreamId { @@ -154,7 +163,7 @@ impl From for VarInt { /// Invalid StreamId, for example because it's too large #[derive(Debug, PartialEq)] -pub struct InvalidStreamId(u64); +pub struct InvalidStreamId(pub(crate) u64); impl Display for InvalidStreamId { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { @@ -181,6 +190,12 @@ impl Add for StreamId { } } +impl From for StreamId { + fn from(value: SessionId) -> Self { + Self(value.into_inner()) + } +} + #[derive(Debug, Copy, Clone, Eq, PartialEq)] pub enum Side { /// The initiator of a connection diff --git a/h3/src/proto/varint.rs b/h3/src/proto/varint.rs index 9af8e28f..60cdacf5 100644 --- a/h3/src/proto/varint.rs +++ b/h3/src/proto/varint.rs @@ -1,4 +1,4 @@ -use std::{convert::TryInto, fmt}; +use std::{convert::TryInto, fmt, ops::Div}; use bytes::{Buf, BufMut}; @@ -12,6 +12,14 @@ pub use super::coding::UnexpectedEnd; #[derive(Default, Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Hash)] pub struct VarInt(pub(crate) u64); +impl Div for VarInt { + type Output = Self; + + fn div(self, rhs: u64) -> Self::Output { + Self(self.0 / rhs) + } +} + impl VarInt { /// The largest representable value pub const MAX: VarInt = VarInt((1 << 62) - 1); diff --git a/h3/src/quic.rs b/h3/src/quic.rs index e4f0aaf0..de1ecf4c 100644 --- a/h3/src/quic.rs +++ b/h3/src/quic.rs @@ -7,6 +7,7 @@ use std::task::{self, Poll}; use bytes::Buf; +use crate::ext::Datagram; pub use crate::proto::stream::{InvalidStreamId, StreamId}; pub use crate::stream::WriteBuf; @@ -38,7 +39,12 @@ pub trait Connection { /// The type produced by `poll_accept_recv()` type RecvStream: RecvStream; /// A producer of outgoing Unidirectional and Bidirectional streams. - type OpenStreams: OpenStreams; + type OpenStreams: OpenStreams< + B, + SendStream = Self::SendStream, + RecvStream = Self::RecvStream, + BidiStream = Self::BidiStream, + >; /// Error type yielded by this trait methods type Error: Into>; @@ -77,6 +83,33 @@ pub trait Connection { fn close(&mut self, code: crate::error::Code, reason: &[u8]); } +/// Extends the `Connection` trait for sending datagrams +/// +/// See: https://www.rfc-editor.org/rfc/rfc9297 +pub trait SendDatagramExt { + /// The error type that can occur when sending a datagram + type Error: Into>; + + /// Send a datagram + fn send_datagram(&mut self, data: Datagram) -> Result<(), Self::Error>; +} + +/// Extends the `Connection` trait for receiving datagrams +/// +/// See: https://www.rfc-editor.org/rfc/rfc9297 +pub trait RecvDatagramExt { + /// The type of `Buf` for *raw* datagrams (without the stream_id decoded) + type Buf: Buf; + /// The error type that can occur when receiving a datagram + type Error: Into>; + + /// Poll the connection for incoming datagrams. 
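+    ///
+    /// The yielded buffer is the raw datagram payload: the Quarter Stream ID varint is
+    /// still at the front and is typically stripped by passing the buffer to
+    /// `ext::Datagram::decode`.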
+ fn poll_accept_datagram( + &mut self, + cx: &mut task::Context<'_>, + ) -> Poll, Self::Error>>; +} + /// Trait for opening outgoing streams pub trait OpenStreams { /// The type produced by `poll_open_bidi()` @@ -122,7 +155,21 @@ pub trait SendStream { fn reset(&mut self, reset_code: u64); /// Get QUIC send stream id - fn id(&self) -> StreamId; + fn send_id(&self) -> StreamId; +} + +/// Allows sending unframed pure bytes to a stream. Similar to [`AsyncWrite`](https://docs.rs/tokio/latest/tokio/io/trait.AsyncWrite.html) +pub trait SendStreamUnframed: SendStream { + /// Attempts write data into the stream. + /// + /// Returns the number of bytes written. + /// + /// `buf` is advanced by the number of bytes written. + fn poll_send( + &mut self, + cx: &mut task::Context<'_>, + buf: &mut D, + ) -> Poll>; } /// A trait describing the "receive" actions of a QUIC stream. @@ -143,6 +190,9 @@ pub trait RecvStream { /// Send a `STOP_SENDING` QUIC code. fn stop_sending(&mut self, error_code: u64); + + /// Get QUIC send stream id + fn recv_id(&self) -> StreamId; } /// Optional trait to allow "splitting" a bidirectional stream into two sides. diff --git a/h3/src/request.rs b/h3/src/request.rs new file mode 100644 index 00000000..f705efbc --- /dev/null +++ b/h3/src/request.rs @@ -0,0 +1,92 @@ +use std::convert::TryFrom; + +use bytes::Buf; +use http::{Request, StatusCode}; + +use crate::{error::Code, proto::headers::Header, qpack, quic, server::RequestStream, Error}; + +pub struct ResolveRequest, B: Buf> { + request_stream: RequestStream, + // Ok or `REQUEST_HEADER_FIELDS_TO_LARGE` which neeeds to be sent + decoded: Result, + max_field_section_size: u64, +} + +impl> ResolveRequest { + pub fn new( + request_stream: RequestStream, + decoded: Result, + max_field_section_size: u64, + ) -> Self { + Self { + request_stream, + decoded, + max_field_section_size, + } + } + + /// Finishes the resolution of the request + pub async fn resolve( + mut self, + ) -> Result<(Request<()>, RequestStream), Error> { + let fields = match self.decoded { + Ok(v) => v.fields, + Err(cancel_size) => { + // Send and await the error response + self.request_stream + .send_response( + http::Response::builder() + .status(StatusCode::REQUEST_HEADER_FIELDS_TOO_LARGE) + .body(()) + .expect("header too big response"), + ) + .await?; + + return Err(Error::header_too_big( + cancel_size, + self.max_field_section_size, + )); + } + }; + + // Parse the request headers + let (method, uri, protocol, headers) = match Header::try_from(fields) { + Ok(header) => match header.into_request_parts() { + Ok(parts) => parts, + Err(err) => { + //= https://www.rfc-editor.org/rfc/rfc9114#section-4.1.2 + //# Malformed requests or responses that are + //# detected MUST be treated as a stream error of type H3_MESSAGE_ERROR. + let error: Error = err.into(); + self.request_stream + .stop_stream(error.try_get_code().unwrap_or(Code::H3_MESSAGE_ERROR)); + return Err(error); + } + }, + Err(err) => { + //= https://www.rfc-editor.org/rfc/rfc9114#section-4.1.2 + //# Malformed requests or responses that are + //# detected MUST be treated as a stream error of type H3_MESSAGE_ERROR. 
+ let error: Error = err.into(); + self.request_stream + .stop_stream(error.try_get_code().unwrap_or(Code::H3_MESSAGE_ERROR)); + return Err(error); + } + }; + + // request_stream.stop_stream(Code::H3_MESSAGE_ERROR).await; + let mut req = http::Request::new(()); + *req.method_mut() = method; + *req.uri_mut() = uri; + *req.headers_mut() = headers; + // NOTE: insert `Protocol` and not `Option` + if let Some(protocol) = protocol { + req.extensions_mut().insert(protocol); + } + *req.version_mut() = http::Version::HTTP_3; + // send the grease frame only once + // self.inner.send_grease_frame = false; + tracing::trace!("replying with: {:?}", req); + Ok((req, self.request_stream)) + } +} diff --git a/h3/src/server.rs b/h3/src/server.rs index 596766ee..7416fabb 100644 --- a/h3/src/server.rs +++ b/h3/src/server.rs @@ -52,25 +52,39 @@ use std::{ collections::HashSet, - convert::TryFrom, + marker::PhantomData, + option::Option, + result::Result, sync::Arc, task::{Context, Poll}, }; use bytes::{Buf, BytesMut}; -use futures_util::future; -use http::{response, HeaderMap, Request, Response, StatusCode}; +use futures_util::{ + future::{self, Future}, + ready, +}; +use http::{response, HeaderMap, Request, Response}; +use pin_project_lite::pin_project; +use quic::RecvStream; use quic::StreamId; use tokio::sync::mpsc; use crate::{ + config::Config, connection::{self, ConnectionInner, ConnectionState, SharedStateRef}, error::{Code, Error, ErrorLevel}, - frame::FrameStream, - proto::{frame::Frame, headers::Header, push::PushId, varint::VarInt}, + ext::Datagram, + frame::{FrameStream, FrameStreamError}, + proto::{ + frame::{Frame, PayloadLen}, + headers::Header, + push::PushId, + }, qpack, - quic::{self, RecvStream as _, SendStream as _}, - stream, + quic::{self, RecvDatagramExt, SendDatagramExt, SendStream as _}, + request::ResolveRequest, + stream::{self, BufRecvStream}, }; use tracing::{error, trace, warn}; @@ -94,7 +108,8 @@ where C: quic::Connection, B: Buf, { - inner: ConnectionInner, + /// TODO: temporarily break encapsulation for `WebTransportSession` + pub inner: ConnectionInner, max_field_section_size: u64, // List of all incoming streams that are currently running. ongoing_streams: HashSet, @@ -132,6 +147,11 @@ where pub async fn new(conn: C) -> Result { builder().build(conn).await } + + /// Closes the connection with a code and a reason. + pub fn close>(&mut self, code: Code, reason: T) -> Error { + self.inner.close(code, reason) + } } impl Connection @@ -149,7 +169,7 @@ where ) -> Result, RequestStream)>, Error> { // Accept the incoming stream let mut stream = match future::poll_fn(|cx| self.poll_accept_request(cx)).await { - Ok(Some(s)) => FrameStream::new(s), + Ok(Some(s)) => FrameStream::new(BufRecvStream::new(s)), Ok(None) => { // We always send a last GoAway frame to the client, so it knows which was the last // non-rejected request. @@ -175,7 +195,25 @@ where }; let frame = future::poll_fn(|cx| stream.poll_next(cx)).await; + let req = self.accept_with_frame(stream, frame)?; + if let Some(req) = req { + Ok(Some(req.resolve().await?)) + } else { + Ok(None) + } + } + /// Accepts an http request where the first frame has already been read and decoded. + /// + /// + /// This is needed as a bidirectional stream may be read as part of incoming webtransport + /// bi-streams. If it turns out that the stream is *not* a `WEBTRANSPORT_STREAM` the request + /// may still want to be handled and passed to the user. 
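+    ///
+    /// In other words, a caller that peeks at the first frame itself can route
+    /// `Frame::WebTransportStream(..)` to a WebTransport session and hand any other
+    /// first frame (typically `Frame::Headers`) back to this method to be resolved
+    /// as a regular HTTP request.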
+ pub fn accept_with_frame( + &mut self, + mut stream: FrameStream, + frame: Result>, FrameStreamError>, + ) -> Result>, Error> { let mut encoded = match frame { Ok(Some(Frame::Headers(h))) => h, @@ -188,7 +226,7 @@ where return Err(self.inner.close( Code::H3_REQUEST_INCOMPLETE, "request stream closed before headers", - )) + )); } //= https://www.rfc-editor.org/rfc/rfc9114#section-4.1 @@ -242,7 +280,7 @@ where let mut request_stream = RequestStream { request_end: Arc::new(RequestEnd { request_end: self.request_end_send.clone(), - stream_id: stream.id(), + stream_id: stream.send_id(), }), inner: connection::RequestStream::new( stream, @@ -252,92 +290,54 @@ where ), }; - let qpack::Decoded { fields, .. } = - match qpack::decode_stateless(&mut encoded, self.max_field_section_size) { - //= https://www.rfc-editor.org/rfc/rfc9114#section-4.2.2 - //# An HTTP/3 implementation MAY impose a limit on the maximum size of - //# the message header it will accept on an individual HTTP message. - Err(qpack::DecoderError::HeaderTooLong(cancel_size)) => { - request_stream - .send_response( - http::Response::builder() - .status(StatusCode::REQUEST_HEADER_FIELDS_TOO_LARGE) - .body(()) - .expect("header too big response"), - ) - .await?; - return Err(Error::header_too_big( - cancel_size, - self.max_field_section_size, - )); + let decoded = match qpack::decode_stateless(&mut encoded, self.max_field_section_size) { + //= https://www.rfc-editor.org/rfc/rfc9114#section-4.2.2 + //# An HTTP/3 implementation MAY impose a limit on the maximum size of + //# the message header it will accept on an individual HTTP message. + Err(qpack::DecoderError::HeaderTooLong(cancel_size)) => Err(cancel_size), + Ok(decoded) => { + // send the grease frame only once + self.inner.send_grease_frame = false; + Ok(decoded) + } + Err(e) => { + let err: Error = e.into(); + if err.is_closed() { + return Ok(None); } - Ok(decoded) => decoded, - Err(e) => { - let err: Error = e.into(); - if err.is_closed() { - return Ok(None); - } - match err.inner.kind { - crate::error::Kind::Closed => return Ok(None), - crate::error::Kind::Application { - code, - reason, - level: ErrorLevel::ConnectionError, - } => { - return Err(self.inner.close( - code, - reason.unwrap_or_else(|| String::into_boxed_str(String::from(""))), - )) - } - crate::error::Kind::Application { + match err.inner.kind { + crate::error::Kind::Closed => return Ok(None), + crate::error::Kind::Application { + code, + reason, + level: ErrorLevel::ConnectionError, + } => { + return Err(self.inner.close( code, - reason: _, - level: ErrorLevel::StreamError, - } => { - request_stream.stop_stream(code); - return Err(err); - } - _ => return Err(err), - }; - } - }; - - // Parse the request headers - let (method, uri, headers) = match Header::try_from(fields) { - Ok(header) => match header.into_request_parts() { - Ok(parts) => parts, - Err(err) => { - //= https://www.rfc-editor.org/rfc/rfc9114#section-4.1.2 - //# Malformed requests or responses that are - //# detected MUST be treated as a stream error of type H3_MESSAGE_ERROR. - let error: Error = err.into(); - request_stream - .stop_stream(error.try_get_code().unwrap_or(Code::H3_MESSAGE_ERROR)); - return Err(error); - } - }, - Err(err) => { - //= https://www.rfc-editor.org/rfc/rfc9114#section-4.1.2 - //# Malformed requests or responses that are - //# detected MUST be treated as a stream error of type H3_MESSAGE_ERROR. 
- let error: Error = err.into(); - request_stream.stop_stream(error.try_get_code().unwrap_or(Code::H3_MESSAGE_ERROR)); - return Err(error); + reason.unwrap_or_else(|| String::into_boxed_str(String::from(""))), + )) + } + crate::error::Kind::Application { + code, + reason: _, + level: ErrorLevel::StreamError, + } => { + request_stream.stop_stream(code); + return Err(err); + } + _ => return Err(err), + }; } }; - // request_stream.stop_stream(Code::H3_MESSAGE_ERROR).await; - let mut req = http::Request::new(()); - *req.method_mut() = method; - *req.uri_mut() = uri; - *req.headers_mut() = headers; - *req.version_mut() = http::Version::HTTP_3; - // send the grease frame only once - self.inner.send_grease_frame = false; - Ok(Some((req, request_stream))) + Ok(Some(ResolveRequest::new( + request_stream, + decoded, + self.max_field_section_size, + ))) } - /// Itiniate a graceful shutdown, accepting `max_request` potentially still in-flight + /// Initiate a graceful shutdown, accepting `max_request` potentially still in-flight /// /// See [connection shutdown](https://www.rfc-editor.org/rfc/rfc9114.html#connection-shutdown) for more information. pub async fn shutdown(&mut self, max_requests: usize) -> Result<(), Error> { @@ -349,7 +349,11 @@ where self.inner.shutdown(&mut self.sent_closing, max_id).await } - fn poll_accept_request( + /// Accepts an incoming bidirectional stream. + /// + /// This could be either a *Request* or a *WebTransportBiStream*, the first frame's type + /// decides. + pub fn poll_accept_request( &mut self, cx: &mut Context<'_>, ) -> Poll, Error>> { @@ -380,7 +384,7 @@ where // incoming requests not belonging to the grace interval. It's possible that // some acceptable request streams arrive after rejected requests. if let Some(max_id) = self.sent_closing { - if s.id() > max_id { + if s.send_id() > max_id { s.stop_sending(Code::H3_REQUEST_REJECTED.value()); s.reset(Code::H3_REQUEST_REJECTED.value()); if self.poll_requests_completion(cx).is_ready() { @@ -389,49 +393,57 @@ where continue; } } - self.last_accepted_stream = Some(s.id()); - self.ongoing_streams.insert(s.id()); + self.last_accepted_stream = Some(s.send_id()); + self.ongoing_streams.insert(s.send_id()); break Poll::Ready(Ok(Some(s))); } }; } } - fn poll_control(&mut self, cx: &mut Context<'_>) -> Poll> { - while let Poll::Ready(frame) = self.inner.poll_control(cx)? { - match frame { - Frame::Settings(_) => trace!("Got settings"), - Frame::Goaway(id) => self.inner.process_goaway(&mut self.recv_closing, id)?, - f @ Frame::MaxPushId(_) | f @ Frame::CancelPush(_) => { - warn!("Control frame ignored {:?}", f); - - //= https://www.rfc-editor.org/rfc/rfc9114#section-7.2.3 - //= type=TODO - //# If a server receives a CANCEL_PUSH frame for a push - //# ID that has not yet been mentioned by a PUSH_PROMISE frame, this MUST - //# be treated as a connection error of type H3_ID_ERROR. - - //= https://www.rfc-editor.org/rfc/rfc9114#section-7.2.7 - //= type=TODO - //# A MAX_PUSH_ID frame cannot reduce the maximum push - //# ID; receipt of a MAX_PUSH_ID frame that contains a smaller value than - //# previously received MUST be treated as a connection error of type - //# H3_ID_ERROR. - } + pub(crate) fn poll_control(&mut self, cx: &mut Context<'_>) -> Poll> { + while (self.poll_next_control(cx)?).is_ready() {} + Poll::Pending + } - //= https://www.rfc-editor.org/rfc/rfc9114#section-7.2.5 - //# A server MUST treat the - //# receipt of a PUSH_PROMISE frame as a connection error of type - //# H3_FRAME_UNEXPECTED. 
- frame => { - return Poll::Ready(Err(Code::H3_FRAME_UNEXPECTED.with_reason( - format!("on server control stream: {:?}", frame), - ErrorLevel::ConnectionError, - ))) - } + pub(crate) fn poll_next_control( + &mut self, + cx: &mut Context<'_>, + ) -> Poll, Error>> { + let frame = ready!(self.inner.poll_control(cx))?; + + match &frame { + Frame::Settings(w) => trace!("Got settings > {:?}", w), + &Frame::Goaway(id) => self.inner.process_goaway(&mut self.recv_closing, id)?, + f @ Frame::MaxPushId(_) | f @ Frame::CancelPush(_) => { + warn!("Control frame ignored {:?}", f); + + //= https://www.rfc-editor.org/rfc/rfc9114#section-7.2.3 + //= type=TODO + //# If a server receives a CANCEL_PUSH frame for a push + //# ID that has not yet been mentioned by a PUSH_PROMISE frame, this MUST + //# be treated as a connection error of type H3_ID_ERROR. + + //= https://www.rfc-editor.org/rfc/rfc9114#section-7.2.7 + //= type=TODO + //# A MAX_PUSH_ID frame cannot reduce the maximum push + //# ID; receipt of a MAX_PUSH_ID frame that contains a smaller value than + //# previously received MUST be treated as a connection error of type + //# H3_ID_ERROR. + } + + //= https://www.rfc-editor.org/rfc/rfc9114#section-7.2.5 + //# A server MUST treat the + //# receipt of a PUSH_PROMISE frame as a connection error of type + //# H3_FRAME_UNEXPECTED. + frame => { + return Poll::Ready(Err(Code::H3_FRAME_UNEXPECTED.with_reason( + format!("on server control stream: {:?}", frame), + ErrorLevel::ConnectionError, + ))) } } - Poll::Pending + Poll::Ready(Ok(frame)) } fn poll_requests_completion(&mut self, cx: &mut Context<'_>) -> Poll<()> { @@ -457,6 +469,36 @@ where } } +impl Connection +where + C: quic::Connection + SendDatagramExt, + B: Buf, +{ + /// Sends a datagram + pub fn send_datagram(&mut self, stream_id: StreamId, data: B) -> Result<(), Error> { + self.inner + .conn + .send_datagram(Datagram::new(stream_id, data))?; + tracing::info!("Sent datagram"); + + Ok(()) + } +} + +impl Connection +where + C: quic::Connection + RecvDatagramExt, + B: Buf, +{ + /// Reads an incoming datagram + pub fn read_datagram(&mut self) -> ReadDatagram { + ReadDatagram { + conn: self, + _marker: PhantomData, + } + } +} + impl Drop for Connection where C: quic::Connection, @@ -503,32 +545,66 @@ where /// } /// ``` pub struct Builder { - pub(super) max_field_section_size: u64, - pub(super) send_grease: bool, + pub(crate) config: Config, } impl Builder { /// Creates a new [`Builder`] with default settings. pub(super) fn new() -> Self { Builder { - max_field_section_size: VarInt::MAX.0, - send_grease: true, + config: Default::default(), } } + /// Set the maximum header size this client is willing to accept /// /// See [header size constraints] section of the specification for details. /// /// [header size constraints]: https://www.rfc-editor.org/rfc/rfc9114.html#name-header-size-constraints pub fn max_field_section_size(&mut self, value: u64) -> &mut Self { - self.max_field_section_size = value; + self.config.max_field_section_size = value; self } /// Send grease values to the Client. /// See [setting](https://www.rfc-editor.org/rfc/rfc9114.html#settings-parameters), [frame](https://www.rfc-editor.org/rfc/rfc9114.html#frame-reserved) and [stream](https://www.rfc-editor.org/rfc/rfc9114.html#stream-grease) for more information. + #[inline] pub fn send_grease(&mut self, value: bool) -> &mut Self { - self.send_grease = value; + self.config.send_grease = value; + self + } + + /// Indicates to the peer that WebTransport is supported. 
+ /// + /// See: [establishing a webtransport session](https://datatracker.ietf.org/doc/html/draft-ietf-webtrans-http3/#section-3.1) + /// + /// + /// **Server**: + /// Supporting for webtransport also requires setting `enable_connect` `enable_datagram` + /// and `max_webtransport_sessions`. + #[inline] + pub fn enable_webtransport(&mut self, value: bool) -> &mut Self { + self.config.enable_webtransport = value; + self + } + + /// Enables the CONNECT protocol + pub fn enable_connect(&mut self, value: bool) -> &mut Self { + self.config.enable_extended_connect = value; + self + } + + /// Limits the maximum number of WebTransport sessions + pub fn max_webtransport_sessions(&mut self, value: u64) -> &mut Self { + self.config.max_webtransport_sessions = value; + self + } + + /// Indicates that the client or server supports HTTP/3 datagrams + /// + /// See: https://www.rfc-editor.org/rfc/rfc9297#section-2.1.1 + pub fn enable_datagram(&mut self, value: bool) -> &mut Self { + self.config.enable_datagram = value; self } } @@ -544,14 +620,8 @@ impl Builder { { let (sender, receiver) = mpsc::unbounded_channel(); Ok(Connection { - inner: ConnectionInner::new( - conn, - self.max_field_section_size, - SharedStateRef::default(), - self.send_grease, - ) - .await?, - max_field_section_size: self.max_field_section_size, + inner: ConnectionInner::new(conn, SharedStateRef::default(), self.config).await?, + max_field_section_size: self.config.max_field_section_size, request_end_send: sender, request_end_recv: receiver, ongoing_streams: HashSet::new(), @@ -591,6 +661,7 @@ impl ConnectionState for RequestStream { impl RequestStream where S: quic::RecvStream, + B: Buf, { /// Receive data sent from the client pub async fn recv_data(&mut self) -> Result, Error> { @@ -606,6 +677,11 @@ where pub fn stop_sending(&mut self, error_code: crate::error::Code) { self.inner.stream.stop_sending(error_code) } + + /// Returns the underlying stream id + pub fn id(&self) -> StreamId { + self.inner.stream.id() + } } impl RequestStream @@ -631,7 +707,8 @@ where .inner .conn_state .read("send_response") - .peer_max_field_section_size; + .peer_config + .max_field_section_size; //= https://www.rfc-editor.org/rfc/rfc9114#section-4.2.2 //# An implementation that @@ -686,6 +763,11 @@ where //# implementation resets the sending parts of streams and aborts reading //# on the receiving parts of streams; see Section 2.4 of //# [QUIC-TRANSPORT]. + + /// Returns the underlying stream id + pub fn send_id(&self) -> StreamId { + self.inner.stream.send_id() + } } impl RequestStream @@ -725,3 +807,31 @@ impl Drop for RequestEnd { } } } + +pin_project! { + /// Future for [`Connection::read_datagram`] + pub struct ReadDatagram<'a, C, B> + where + C: quic::Connection, + B: Buf, + { + conn: &'a mut Connection, + _marker: PhantomData, + } +} + +impl<'a, C, B> Future for ReadDatagram<'a, C, B> +where + C: quic::Connection + RecvDatagramExt, + B: Buf, +{ + type Output = Result>, Error>; + + fn poll(mut self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + tracing::trace!("poll: read_datagram"); + match ready!(self.conn.inner.conn.poll_accept_datagram(cx))? 
{ + Some(v) => Poll::Ready(Ok(Some(Datagram::decode(v)?))), + None => Poll::Ready(Ok(None)), + } + } +} diff --git a/h3/src/stream.rs b/h3/src/stream.rs index 0314867d..514a0390 100644 --- a/h3/src/stream.rs +++ b/h3/src/stream.rs @@ -1,24 +1,31 @@ -use std::task::{Context, Poll}; +use std::{ + marker::PhantomData, + pin::Pin, + task::{Context, Poll}, +}; -use bytes::{Buf, BufMut as _, Bytes}; +use bytes::{Buf, BufMut, Bytes}; use futures_util::{future, ready}; -use quic::RecvStream; +use pin_project_lite::pin_project; +use tokio::io::ReadBuf; use crate::{ buf::BufList, error::{Code, ErrorLevel}, frame::FrameStream, proto::{ - coding::{BufExt, Decode as _, Encode}, - frame::Frame, + coding::{Decode as _, Encode}, + frame::{Frame, Settings}, stream::StreamType, varint::VarInt, }, - quic::{self, SendStream}, + quic::{self, BidiStream, RecvStream, SendStream, SendStreamUnframed}, + webtransport::SessionId, Error, }; #[inline] +/// Transmits data by encoding in wire format. pub(crate) async fn write(stream: &mut S, data: D) -> Result<(), Error> where S: SendStream, @@ -43,10 +50,7 @@ const WRITE_BUF_ENCODE_SIZE: usize = StreamType::MAX_ENCODED_SIZE + Frame::MAX_E /// data is necessary (say, in `quic::SendStream::send_data`). It also has a public API ergonomy /// advantage: `WriteBuf` doesn't have to appear in public associated types. On the other hand, /// QUIC implementers have to call `into()`, which will encode the header in `Self::buf`. -pub struct WriteBuf -where - B: Buf, -{ +pub struct WriteBuf { buf: [u8; WRITE_BUF_ENCODE_SIZE], len: usize, pos: usize, @@ -59,10 +63,17 @@ where { fn encode_stream_type(&mut self, ty: StreamType) { let mut buf_mut = &mut self.buf[self.len..]; + ty.encode(&mut buf_mut); self.len = WRITE_BUF_ENCODE_SIZE - buf_mut.remaining_mut(); } + fn encode_value(&mut self, value: impl Encode) { + let mut buf_mut = &mut self.buf[self.len..]; + value.encode(&mut buf_mut); + self.len = WRITE_BUF_ENCODE_SIZE - buf_mut.remaining_mut(); + } + fn encode_frame_header(&mut self) { if let Some(frame) = self.frame.as_ref() { let mut buf_mut = &mut self.buf[self.len..]; @@ -88,6 +99,80 @@ where } } +impl From for WriteBuf +where + B: Buf, +{ + fn from(header: UniStreamHeader) -> Self { + let mut this = Self { + buf: [0; WRITE_BUF_ENCODE_SIZE], + len: 0, + pos: 0, + frame: None, + }; + + this.encode_value(header); + this + } +} + +pub enum UniStreamHeader { + Control(Settings), + WebTransportUni(SessionId), +} + +impl Encode for UniStreamHeader { + fn encode(&self, buf: &mut B) { + match self { + Self::Control(settings) => { + StreamType::CONTROL.encode(buf); + settings.encode(buf); + } + Self::WebTransportUni(session_id) => { + StreamType::WEBTRANSPORT_UNI.encode(buf); + session_id.encode(buf); + } + } + } +} + +impl From for WriteBuf +where + B: Buf, +{ + fn from(header: BidiStreamHeader) -> Self { + let mut this = Self { + buf: [0; WRITE_BUF_ENCODE_SIZE], + len: 0, + pos: 0, + frame: None, + }; + + this.encode_value(header); + this + } +} + +pub enum BidiStreamHeader { + Control(Settings), + WebTransportBidi(SessionId), +} + +impl Encode for BidiStreamHeader { + fn encode(&self, buf: &mut B) { + match self { + Self::Control(settings) => { + StreamType::CONTROL.encode(buf); + settings.encode(buf); + } + Self::WebTransportBidi(session_id) => { + StreamType::WEBTRANSPORT_BIDI.encode(buf); + session_id.encode(buf); + } + } + } +} + impl From> for WriteBuf where B: Buf, @@ -116,7 +201,7 @@ where pos: 0, frame: Some(frame), }; - me.encode_stream_type(ty); + me.encode_value(ty); 
me.encode_frame_header(); me } @@ -162,50 +247,52 @@ where pub(super) enum AcceptedRecvStream where S: quic::RecvStream, + B: Buf, { Control(FrameStream), Push(u64, FrameStream), - Encoder(S), - Decoder(S), + Encoder(BufRecvStream), + Decoder(BufRecvStream), + WebTransportUni(SessionId, BufRecvStream), Reserved, } -pub(super) struct AcceptRecvStream -where - S: quic::RecvStream, -{ - stream: S, +/// Resolves an incoming streams type as well as `PUSH_ID`s and `SESSION_ID`s +pub(super) struct AcceptRecvStream { + stream: BufRecvStream, ty: Option, - push_id: Option, - buf: BufList, + /// push_id or session_id + id: Option, expected: Option, } -impl AcceptRecvStream +impl AcceptRecvStream where S: RecvStream, + B: Buf, { pub fn new(stream: S) -> Self { Self { - stream, + stream: BufRecvStream::new(stream), ty: None, - push_id: None, - buf: BufList::new(), + id: None, expected: None, } } - pub fn into_stream(self) -> Result, Error> { + pub fn into_stream(self) -> Result, Error> { Ok(match self.ty.expect("Stream type not resolved yet") { - StreamType::CONTROL => { - AcceptedRecvStream::Control(FrameStream::with_bufs(self.stream, self.buf)) - } + StreamType::CONTROL => AcceptedRecvStream::Control(FrameStream::new(self.stream)), StreamType::PUSH => AcceptedRecvStream::Push( - self.push_id.expect("Push ID not resolved yet"), - FrameStream::with_bufs(self.stream, self.buf), + self.id.expect("Push ID not resolved yet").into_inner(), + FrameStream::new(self.stream), ), StreamType::ENCODER => AcceptedRecvStream::Encoder(self.stream), StreamType::DECODER => AcceptedRecvStream::Decoder(self.stream), + StreamType::WEBTRANSPORT_UNI => AcceptedRecvStream::WebTransportUni( + SessionId::from_varint(self.id.expect("Session ID not resolved yet")), + self.stream, + ), t if t.value() > 0x21 && (t.value() - 0x21) % 0x1f == 0 => AcceptedRecvStream::Reserved, //= https://www.rfc-editor.org/rfc/rfc9114#section-6.2 @@ -233,37 +320,42 @@ where pub fn poll_type(&mut self, cx: &mut Context) -> Poll> { loop { - match (self.ty.as_ref(), self.push_id) { - // When accepting a Push stream, we want to parse two VarInts: [StreamType, PUSH_ID] - (Some(&StreamType::PUSH), Some(_)) | (Some(_), _) => return Poll::Ready(Ok(())), - _ => (), - } - - match ready!(self.stream.poll_data(cx))? { - Some(mut b) => self.buf.push_bytes(&mut b), - None => { - return Poll::Ready(Err(Code::H3_STREAM_CREATION_ERROR.with_reason( - "Stream closed before type received", - ErrorLevel::ConnectionError, - ))); + // Return if all identification data is met + match self.ty { + Some(StreamType::PUSH | StreamType::WEBTRANSPORT_UNI) => { + if self.id.is_some() { + return Poll::Ready(Ok(())); + } } + Some(_) => return Poll::Ready(Ok(())), + None => (), }; - if self.expected.is_none() && self.buf.remaining() >= 1 { - self.expected = Some(VarInt::encoded_size(self.buf.chunk()[0])); + if ready!(self.stream.poll_read(cx))? 
{ + return Poll::Ready(Err(Code::H3_STREAM_CREATION_ERROR.with_reason( + "Stream closed before type received", + ErrorLevel::ConnectionError, + ))); + }; + + let mut buf = self.stream.buf_mut(); + if self.expected.is_none() && buf.remaining() >= 1 { + self.expected = Some(VarInt::encoded_size(buf.chunk()[0])); } if let Some(expected) = self.expected { - if self.buf.remaining() < expected { + // Poll for more data + if buf.remaining() < expected { continue; } } else { continue; } + // Parse ty and then id if self.ty.is_none() { // Parse StreamType - self.ty = Some(StreamType::decode(&mut self.buf).map_err(|_| { + self.ty = Some(StreamType::decode(&mut buf).map_err(|_| { Code::H3_INTERNAL_ERROR.with_reason( "Unexpected end parsing stream type", ErrorLevel::ConnectionError, @@ -273,9 +365,9 @@ where self.expected = None; } else { // Parse PUSH_ID - self.push_id = Some(self.buf.get_var().map_err(|_| { + self.id = Some(VarInt::decode(&mut buf).map_err(|_| { Code::H3_INTERNAL_ERROR.with_reason( - "Unexpected end parsing stream type", + "Unexpected end parsing push or session id", ErrorLevel::ConnectionError, ) })?); @@ -284,10 +376,336 @@ where } } +pin_project! { + /// A stream which allows partial reading of the data without data loss. + /// + /// This fixes the problem where `poll_data` returns more than the needed amount of bytes, + /// requiring correct implementations to hold on to that extra data and return it later. + /// + /// # Usage + /// + /// Implements `quic::RecvStream` which will first return buffered data, and then read from the + /// stream + pub struct BufRecvStream { + buf: BufList, + // Indicates that the end of the stream has been reached + // + // Data may still be available as buffered + eos: bool, + stream: S, + _marker: PhantomData, + } +} + +impl std::fmt::Debug for BufRecvStream { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("BufRecvStream") + .field("buf", &self.buf) + .field("eos", &self.eos) + .field("stream", &"...") + .finish() + } +} + +impl BufRecvStream { + pub fn new(stream: S) -> Self { + Self { + buf: BufList::new(), + eos: false, + stream, + _marker: PhantomData, + } + } +} + +impl BufRecvStream { + /// Reads more data into the buffer, returning the number of bytes read. + /// + /// Returns `true` if the end of the stream is reached. + pub fn poll_read(&mut self, cx: &mut Context<'_>) -> Poll> { + let data = ready!(self.stream.poll_data(cx))?; + + if let Some(mut data) = data { + self.buf.push_bytes(&mut data); + Poll::Ready(Ok(false)) + } else { + self.eos = true; + Poll::Ready(Ok(true)) + } + } + + /// Returns the currently buffered data, allowing it to be partially read + #[inline] + pub(crate) fn buf_mut(&mut self) -> &mut BufList { + &mut self.buf + } + + /// Returns the next chunk of data from the stream + /// + /// Return `None` when there is no more buffered data; use [`Self::poll_read`]. 
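+    ///
+    /// The returned chunk contains at most `limit` bytes; any remaining buffered data
+    /// is kept for subsequent calls.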
@@ -284,10 +376,336 @@ where
     }
 }
 
+pin_project! {
+    /// A stream which allows partial reading of the data without data loss.
+    ///
+    /// This fixes the problem where `poll_data` returns more bytes than needed,
+    /// requiring correct implementations to hold on to that extra data and return it later.
+    ///
+    /// # Usage
+    ///
+    /// Implements `quic::RecvStream`, which first returns buffered data and then reads from the
+    /// underlying stream.
+    pub struct BufRecvStream<S, B> {
+        buf: BufList<Bytes>,
+        // Indicates that the end of the stream has been reached
+        //
+        // Data may still be available as buffered
+        eos: bool,
+        stream: S,
+        _marker: PhantomData<B>,
+    }
+}
+
+impl<S, B> std::fmt::Debug for BufRecvStream<S, B> {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        f.debug_struct("BufRecvStream")
+            .field("buf", &self.buf)
+            .field("eos", &self.eos)
+            .field("stream", &"...")
+            .finish()
+    }
+}
+
+impl<S, B> BufRecvStream<S, B> {
+    pub fn new(stream: S) -> Self {
+        Self {
+            buf: BufList::new(),
+            eos: false,
+            stream,
+            _marker: PhantomData,
+        }
+    }
+}
+
+impl<S: RecvStream, B> BufRecvStream<S, B> {
+    /// Reads more data from the underlying stream into the buffer.
+    ///
+    /// Returns `true` if the end of the stream is reached.
+    pub fn poll_read(&mut self, cx: &mut Context<'_>) -> Poll<Result<bool, S::Error>> {
+        let data = ready!(self.stream.poll_data(cx))?;
+
+        if let Some(mut data) = data {
+            self.buf.push_bytes(&mut data);
+            Poll::Ready(Ok(false))
+        } else {
+            self.eos = true;
+            Poll::Ready(Ok(true))
+        }
+    }
+
+    /// Returns the currently buffered data, allowing it to be partially read
+    #[inline]
+    pub(crate) fn buf_mut(&mut self) -> &mut BufList<Bytes> {
+        &mut self.buf
+    }
+
+    /// Returns the next chunk of buffered data, up to `limit` bytes.
+    ///
+    /// Returns `None` when there is no more buffered data; use [`Self::poll_read`] to fetch more.
+    pub fn take_chunk(&mut self, limit: usize) -> Option<Bytes> {
+        self.buf.take_chunk(limit)
+    }
+
+    /// Returns true if there is remaining buffered data
+    pub fn has_remaining(&mut self) -> bool {
+        self.buf.has_remaining()
+    }
+
+    #[inline]
+    pub(crate) fn buf(&self) -> &BufList<Bytes> {
+        &self.buf
+    }
+
+    pub fn is_eos(&self) -> bool {
+        self.eos
+    }
+}
+
+impl<S: RecvStream, B> RecvStream for BufRecvStream<S, B> {
+    type Buf = Bytes;
+
+    type Error = S::Error;
+
+    fn poll_data(
+        &mut self,
+        cx: &mut std::task::Context<'_>,
+    ) -> Poll<Result<Option<Self::Buf>, Self::Error>> {
+        // If there is buffered data, return it immediately
+        if let Some(chunk) = self.buf.take_first_chunk() {
+            return Poll::Ready(Ok(Some(chunk)));
+        }
+
+        if let Some(mut data) = ready!(self.stream.poll_data(cx))? {
+            Poll::Ready(Ok(Some(data.copy_to_bytes(data.remaining()))))
+        } else {
+            self.eos = true;
+            Poll::Ready(Ok(None))
+        }
+    }
+
+    fn stop_sending(&mut self, error_code: u64) {
+        self.stream.stop_sending(error_code)
+    }
+
+    fn recv_id(&self) -> quic::StreamId {
+        self.stream.recv_id()
+    }
+}
+
+impl<B, S> SendStream<B> for BufRecvStream<S, B>
+where
+    B: Buf,
+    S: SendStream<B>,
+{
+    type Error = S::Error;
+
+    fn poll_finish(&mut self, cx: &mut std::task::Context<'_>) -> Poll<Result<(), Self::Error>> {
+        self.stream.poll_finish(cx)
+    }
+
+    fn reset(&mut self, reset_code: u64) {
+        self.stream.reset(reset_code)
+    }
+
+    fn send_id(&self) -> quic::StreamId {
+        self.stream.send_id()
+    }
+
+    fn poll_ready(&mut self, cx: &mut std::task::Context<'_>) -> Poll<Result<(), Self::Error>> {
+        self.stream.poll_ready(cx)
+    }
+
+    fn send_data<T: Into<WriteBuf<B>>>(&mut self, data: T) -> Result<(), Self::Error> {
+        self.stream.send_data(data)
+    }
+}
+
+impl<B, S> SendStreamUnframed<B> for BufRecvStream<S, B>
+where
+    B: Buf,
+    S: SendStreamUnframed<B>,
+{
+    #[inline]
+    fn poll_send<D: Buf>(
+        &mut self,
+        cx: &mut std::task::Context<'_>,
+        buf: &mut D,
+    ) -> Poll<Result<usize, Self::Error>> {
+        self.stream.poll_send(cx, buf)
+    }
+}
+
+impl<B, S> BidiStream<B> for BufRecvStream<S, B>
+where
+    B: Buf,
+    S: BidiStream<B>,
+{
+    type SendStream = BufRecvStream<S::SendStream, B>;
+
+    type RecvStream = BufRecvStream<S::RecvStream, B>;
+
+    fn split(self) -> (Self::SendStream, Self::RecvStream) {
+        let (send, recv) = self.stream.split();
+        (
+            BufRecvStream {
+                // Sending is not buffered
+                buf: BufList::new(),
+                eos: self.eos,
+                stream: send,
+                _marker: PhantomData,
+            },
+            BufRecvStream {
+                buf: self.buf,
+                eos: self.eos,
+                stream: recv,
+                _marker: PhantomData,
+            },
+        )
+    }
+}
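The `RecvStream` implementation above encodes the central invariant of `BufRecvStream`: anything already sitting in the internal buffer is handed out before the wrapped stream is polled again. A toy, non-async model of that ordering (the types and names here are stand-ins, not h3's):

```rust
use std::collections::VecDeque;

/// Toy stand-in for BufRecvStream's read path: buffered chunks are drained
/// before the underlying source is consulted again.
struct Buffered {
    buf: VecDeque<Vec<u8>>,    // chunks already pulled off the wire
    source: VecDeque<Vec<u8>>, // stand-in for the wrapped QUIC stream
}

impl Buffered {
    fn next_chunk(&mut self) -> Option<Vec<u8>> {
        // 1. Serve buffered data first, like `take_first_chunk` above.
        if let Some(chunk) = self.buf.pop_front() {
            return Some(chunk);
        }
        // 2. Only then fall back to the underlying stream.
        self.source.pop_front()
    }
}

fn main() {
    let mut s = Buffered {
        buf: VecDeque::from([b"header-leftover".to_vec()]),
        source: VecDeque::from([b"payload".to_vec()]),
    };
    assert_eq!(s.next_chunk().as_deref(), Some(&b"header-leftover"[..]));
    assert_eq!(s.next_chunk().as_deref(), Some(&b"payload"[..]));
    assert_eq!(s.next_chunk(), None);
}
```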
+impl<B, S> futures_util::io::AsyncRead for BufRecvStream<S, B>
+where
+    B: Buf,
+    S: RecvStream,
+    S::Error: Into<std::io::Error>,
+{
+    fn poll_read(
+        mut self: Pin<&mut Self>,
+        cx: &mut Context<'_>,
+        buf: &mut [u8],
+    ) -> Poll<std::io::Result<usize>> {
+        let p = &mut *self;
+        // Poll for data if the buffer is empty
+        //
+        // If there is data available *do not* poll for more data, as that may suspend indefinitely
+        // if no more data is sent, causing data loss.
+        if !p.has_remaining() {
+            let eos = ready!(p.poll_read(cx).map_err(Into::into))?;
+            if eos {
+                return Poll::Ready(Ok(0));
+            }
+        }
+
+        let chunk = p.buf_mut().take_chunk(buf.len());
+        if let Some(chunk) = chunk {
+            assert!(chunk.len() <= buf.len());
+            let len = chunk.len().min(buf.len());
+            // Write the subset into the destination
+            buf[..len].copy_from_slice(&chunk);
+            Poll::Ready(Ok(len))
+        } else {
+            Poll::Ready(Ok(0))
+        }
+    }
+}
+
+impl<B, S> tokio::io::AsyncRead for BufRecvStream<S, B>
+where
+    B: Buf,
+    S: RecvStream,
+    S::Error: Into<std::io::Error>,
+{
+    fn poll_read(
+        mut self: Pin<&mut Self>,
+        cx: &mut Context<'_>,
+        buf: &mut ReadBuf<'_>,
+    ) -> Poll<std::io::Result<()>> {
+        let p = &mut *self;
+        // Poll for data if the buffer is empty
+        //
+        // If there is data available *do not* poll for more data, as that may suspend indefinitely
+        // if no more data is sent, causing data loss.
+        if !p.has_remaining() {
+            let eos = ready!(p.poll_read(cx).map_err(Into::into))?;
+            if eos {
+                return Poll::Ready(Ok(()));
+            }
+        }
+
+        let chunk = p.buf_mut().take_chunk(buf.remaining());
+        if let Some(chunk) = chunk {
+            assert!(chunk.len() <= buf.remaining());
+            // Write the subset into the destination
+            buf.put_slice(&chunk);
+            Poll::Ready(Ok(()))
+        } else {
+            Poll::Ready(Ok(()))
+        }
+    }
+}
+
+impl<B, S> futures_util::io::AsyncWrite for BufRecvStream<S, B>
+where
+    B: Buf,
+    S: SendStreamUnframed<B>,
+    S::Error: Into<std::io::Error>,
+{
+    fn poll_write(
+        mut self: Pin<&mut Self>,
+        cx: &mut Context<'_>,
+        mut buf: &[u8],
+    ) -> Poll<std::io::Result<usize>> {
+        let p = &mut *self;
+        p.poll_send(cx, &mut buf).map_err(Into::into)
+    }
+
+    fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<std::io::Result<()>> {
+        Poll::Ready(Ok(()))
+    }
+
+    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
+        let p = &mut *self;
+        p.poll_finish(cx).map_err(Into::into)
+    }
+}
+
+impl<B, S> tokio::io::AsyncWrite for BufRecvStream<S, B>
+where
+    B: Buf,
+    S: SendStreamUnframed<B>,
+    S::Error: Into<std::io::Error>,
+{
+    fn poll_write(
+        mut self: Pin<&mut Self>,
+        cx: &mut Context<'_>,
+        mut buf: &[u8],
+    ) -> Poll<std::io::Result<usize>> {
+        let p = &mut *self;
+        p.poll_send(cx, &mut buf).map_err(Into::into)
+    }
+
+    fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<std::io::Result<()>> {
+        Poll::Ready(Ok(()))
+    }
+
+    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
+        let p = &mut *self;
+        p.poll_finish(cx).map_err(Into::into)
+    }
+}
+
 #[cfg(test)]
 mod tests {
+    use quinn_proto::coding::BufExt;
+
     use super::*;
 
+    #[test]
+    fn write_wt_uni_header() {
+        let mut w = WriteBuf::<Bytes>::from(UniStreamHeader::WebTransportUni(
+            SessionId::from_varint(VarInt(5)),
+        ));
+
+        let ty = w.get_var().unwrap();
+        println!("Got type: {ty} {ty:#x}");
+        assert_eq!(ty, 0x54);
+
+        let id = w.get_var().unwrap();
+        println!("Got id: {id}");
+    }
+
     #[test]
     fn write_buf_encode_streamtype() {
         let wbuf = WriteBuf::<Bytes>::from(StreamType::ENCODER);
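Because the wrapper implements both the futures and tokio I/O traits, callers can drive accepted WebTransport streams with the usual `AsyncReadExt`/`AsyncWriteExt` combinators. A sketch of that usage style, run here against `tokio::io::duplex` as a stand-in for an accepted stream so the snippet stays self-contained:

```rust
// Sketch only: `tokio::io::duplex` stands in for a stream that, like
// `BufRecvStream`, implements tokio's `AsyncRead` + `AsyncWrite`.
use tokio::io::{AsyncReadExt, AsyncWriteExt};

#[tokio::main]
async fn main() -> std::io::Result<()> {
    let (mut client, mut server) = tokio::io::duplex(64);

    // The "peer" writes a message and shuts down its half.
    tokio::spawn(async move {
        client.write_all(b"hello webtransport").await.unwrap();
        client.shutdown().await.unwrap();
    });

    // The "handler" reads until EOF, exactly as it could with an accepted
    // unidirectional WebTransport stream once these trait impls are in place.
    let mut msg = Vec::new();
    server.read_to_end(&mut msg).await?;
    assert_eq!(msg, b"hello webtransport".to_vec());
    Ok(())
}
```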
diff --git a/h3/src/tests/connection.rs b/h3/src/tests/connection.rs
index f03a7c50..458db214 100644
--- a/h3/src/tests/connection.rs
+++ b/h3/src/tests/connection.rs
@@ -144,7 +144,8 @@ async fn settings_exchange_client() {
             if client
                 .shared_state()
                 .read("client")
-                .peer_max_field_section_size
+                .peer_config
+                .max_field_section_size
                 == 12
             {
                 return;
@@ -202,7 +203,12 @@ async fn settings_exchange_server() {
     let settings_change = async {
         for _ in 0..10 {
-            if state.read("setting_change").peer_max_field_section_size == 12 {
+            if state
+                .read("setting_change")
+                .peer_config
+                .max_field_section_size
+                == 12
+            {
                 return;
             }
             tokio::time::sleep(Duration::from_millis(2)).await;
diff --git a/h3/src/tests/request.rs b/h3/src/tests/request.rs
index 69429559..cb2296ca 100644
--- a/h3/src/tests/request.rs
+++ b/h3/src/tests/request.rs
@@ -381,7 +381,8 @@ async fn header_too_big_client_error() {
         client
             .shared_state()
             .write("client")
-            .peer_max_field_section_size = 12;
+            .peer_config
+            .max_field_section_size = 12;
 
         let req = Request::get("http://localhost/salut").body(()).unwrap();
         let err_kind = client
@@ -432,7 +433,8 @@ async fn header_too_big_client_error_trailer() {
         client
             .shared_state()
             .write("client")
-            .peer_max_field_section_size = 200;
+            .peer_config
+            .max_field_section_size = 200;
 
         let mut request_stream = client
             .send_request(Request::get("http://localhost/salut").body(()).unwrap())
@@ -541,7 +543,8 @@ async fn header_too_big_discard_from_client() {
         incoming_req
             .shared_state()
             .write("client")
-            .peer_max_field_section_size = u64::MAX;
+            .peer_config
+            .max_field_section_size = u64::MAX;
 
         request_stream
             .send_response(
                 Response::builder()
@@ -594,6 +597,7 @@ async fn header_too_big_discard_from_client_trailers() {
         .build::<_, _, Bytes>(pair.client().await)
         .await
         .expect("client init");
+
     let drive_fut = async { future::poll_fn(|cx| driver.poll_close(cx)).await };
 
     let req_fut = async {
         let mut request_stream = client
@@ -627,7 +631,8 @@ async fn header_too_big_discard_from_client_trailers() {
         incoming_req
             .shared_state()
             .write("server")
-            .peer_max_field_section_size = u64::MAX;
+            .peer_config
+            .max_field_section_size = u64::MAX;
 
         request_stream
             .send_response(
@@ -698,7 +703,8 @@ async fn header_too_big_server_error() {
         incoming_req
             .shared_state()
             .write("server")
-            .peer_max_field_section_size = 12;
+            .peer_config
+            .max_field_section_size = 12;
 
         let err_kind = request_stream
             .send_response(
@@ -778,7 +784,8 @@ async fn header_too_big_server_error_trailers() {
         incoming_req
             .shared_state()
             .write("write")
-            .peer_max_field_section_size = 200;
+            .peer_config
+            .max_field_section_size = 200;
 
         let mut trailers = HeaderMap::new();
         trailers.insert("trailer", "value".repeat(100).parse().unwrap());
@@ -1332,7 +1339,7 @@ fn request_encode<B: BufMut>(buf: &mut B, req: http::Request<()>) {
         headers,
         ..
     } = parts;
-    let headers = Header::request(method, uri, headers).unwrap();
+    let headers = Header::request(method, uri, headers, Default::default()).unwrap();
     let mut block = BytesMut::new();
     qpack::encode_stateless(&mut block, headers).unwrap();
     Frame::headers(block).encode_with_payload(buf);
diff --git a/h3/src/webtransport/mod.rs b/h3/src/webtransport/mod.rs
new file mode 100644
index 00000000..74ddc906
--- /dev/null
+++ b/h3/src/webtransport/mod.rs
@@ -0,0 +1,2 @@
+mod session_id;
+pub use session_id::SessionId;
diff --git a/h3/src/webtransport/session_id.rs b/h3/src/webtransport/session_id.rs
new file mode 100644
index 00000000..b6f4424d
--- /dev/null
+++ b/h3/src/webtransport/session_id.rs
@@ -0,0 +1,50 @@
+use std::convert::TryFrom;
+
+use crate::proto::{
+    coding::{Decode, Encode},
+    stream::{InvalidStreamId, StreamId},
+    varint::VarInt,
+};
+
+/// Identifies a WebTransport session
+///
+/// The session id is the same as the stream id of the CONNECT request.
+#[derive(Debug, Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Hash)]
+pub struct SessionId(u64);
+
+impl SessionId {
+    pub(crate) fn from_varint(id: VarInt) -> SessionId {
+        Self(id.0)
+    }
+
+    pub(crate) fn into_inner(self) -> u64 {
+        self.0
+    }
+}
+
+impl TryFrom<u64> for SessionId {
+    type Error = InvalidStreamId;
+    fn try_from(v: u64) -> Result<Self, Self::Error> {
+        if v > VarInt::MAX.0 {
+            return Err(InvalidStreamId(v));
+        }
+        Ok(Self(v))
+    }
+}
+
+impl Encode for SessionId {
+    fn encode<B: bytes::BufMut>(&self, buf: &mut B) {
+        VarInt::from_u64(self.0).unwrap().encode(buf);
+    }
+}
+
+impl Decode for SessionId {
+    fn decode<B: bytes::Buf>(buf: &mut B) -> crate::proto::coding::Result<Self> {
+        Ok(Self(VarInt::decode(buf)?.into_inner()))
+    }
+}
+
+impl From<StreamId> for SessionId {
+    fn from(value: StreamId) -> Self {
+        Self(value.index())
+    }
+}
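The `TryFrom<u64>` guard above simply enforces the QUIC varint range: a session id is the stream id of the CONNECT request, and stream ids never exceed 2^62 - 1. A standalone sketch of the same check, with the values a client-initiated CONNECT can actually produce spelled out (the constant and the `session_id_from_u64` helper are ours, not h3 API):

```rust
// Sketch: the same bounds check as `TryFrom<u64> for SessionId`, outside h3.
const VARINT_MAX: u64 = (1u64 << 62) - 1; // RFC 9000 varints are at most 62 bits

fn session_id_from_u64(v: u64) -> Result<u64, u64> {
    if v > VARINT_MAX {
        Err(v) // corresponds to InvalidStreamId(v)
    } else {
        Ok(v)
    }
}

fn main() {
    // Client-initiated bidirectional streams (where CONNECT requests live)
    // have ids 0, 4, 8, ... per RFC 9000 §2.1, so those are the values a
    // WebTransport session id will actually take.
    assert_eq!(session_id_from_u64(0), Ok(0));
    assert_eq!(session_id_from_u64(4), Ok(4));
    // Anything beyond the varint range is rejected, as in the impl above.
    assert_eq!(session_id_from_u64(u64::MAX), Err(u64::MAX));
}
```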