diff --git a/src/server/conn/auto.rs b/src/server/conn/auto.rs index 47d88dc..5cd3694 100644 --- a/src/server/conn/auto.rs +++ b/src/server/conn/auto.rs @@ -58,6 +58,8 @@ pub struct Builder { http1: http1::Builder, #[cfg(feature = "http2")] http2: http2::Builder, + #[cfg(any(feature = "http1", feature = "http2"))] + version: Option, #[cfg(not(feature = "http2"))] _executor: E, } @@ -84,6 +86,8 @@ impl Builder { http1: http1::Builder::new(), #[cfg(feature = "http2")] http2: http2::Builder::new(executor), + #[cfg(any(feature = "http1", feature = "http2"))] + version: None, #[cfg(not(feature = "http2"))] _executor: executor, } @@ -101,6 +105,26 @@ impl Builder { Http2Builder { inner: self } } + /// Only accepts HTTP/2 + /// + /// Does not do anything if used with [`serve_connection_with_upgrades`] + #[cfg(feature = "http2")] + pub fn http2_only(mut self) -> Self { + assert!(self.version.is_none()); + self.version = Some(Version::H2); + self + } + + /// Only accepts HTTP/1 + /// + /// Does not do anything if used with [`serve_connection_with_upgrades`] + #[cfg(feature = "http1")] + pub fn http1_only(mut self) -> Self { + assert!(self.version.is_none()); + self.version = Some(Version::H1); + self + } + /// Bind a connection together with a [`Service`]. pub fn serve_connection(&self, io: I, service: S) -> Connection<'_, I, S, E> where @@ -112,13 +136,28 @@ impl Builder { I: Read + Write + Unpin + 'static, E: HttpServerConnExec, { - Connection { - state: ConnState::ReadVersion { + let state = match self.version { + #[cfg(feature = "http1")] + Some(Version::H1) => { + let io = Rewind::new_buffered(io, Bytes::new()); + let conn = self.http1.serve_connection(io, service); + ConnState::H1 { conn } + } + #[cfg(feature = "http2")] + Some(Version::H2) => { + let io = Rewind::new_buffered(io, Bytes::new()); + let conn = self.http2.serve_connection(io, service); + ConnState::H2 { conn } + } + #[cfg(any(feature = "http1", feature = "http2"))] + _ => ConnState::ReadVersion { read_version: read_version(io), builder: self, service: Some(service), }, - } + }; + + Connection { state } } /// Bind a connection together with a [`Service`], with the ability to @@ -139,16 +178,33 @@ impl Builder { E: HttpServerConnExec, { UpgradeableConnection { - state: UpgradeableConnState::ReadVersion { - read_version: read_version(io), - builder: self, - service: Some(service), + state: match self.version { + #[cfg(feature = "http1")] + Some(Version::H1) => { + let io = Rewind::new_buffered(io, Bytes::new()); + UpgradeableConnState::H1 { + conn: self.http1.serve_connection(io, service).with_upgrades(), + } + }, + #[cfg(feature = "http2")] + Some(Version::H2) => { + let io = Rewind::new_buffered(io, Bytes::new()); + UpgradeableConnState::H2 { + conn: self.http2.serve_connection(io, service), + } + }, + #[cfg(any(feature = "http1", feature = "http2"))] + _ => UpgradeableConnState::ReadVersion { + read_version: read_version(io), + builder: self, + service: Some(service), + } }, } } } -#[derive(Copy, Clone)] +#[derive(Copy, Clone, Debug)] enum Version { H1, H2, @@ -906,7 +962,7 @@ mod tests { #[cfg(not(miri))] #[tokio::test] async fn http1() { - let addr = start_server().await; + let addr = start_server(false, false).await; let mut sender = connect_h1(addr).await; let response = sender @@ -922,7 +978,7 @@ mod tests { #[cfg(not(miri))] #[tokio::test] async fn http2() { - let addr = start_server().await; + let addr = start_server(false, false).await; let mut sender = connect_h2(addr).await; let response = sender @@ -935,6 +991,62 @@ mod tests { assert_eq!(body, BODY); } + #[cfg(not(miri))] + #[tokio::test] + async fn http2_only() { + let addr = start_server(false, true).await; + let mut sender = connect_h2(addr).await; + + let response = sender + .send_request(Request::new(Empty::::new())) + .await + .unwrap(); + + let body = response.into_body().collect().await.unwrap().to_bytes(); + + assert_eq!(body, BODY); + } + + #[cfg(not(miri))] + #[tokio::test] + async fn http2_only_fail_if_client_is_http1() { + let addr = start_server(false, true).await; + let mut sender = connect_h1(addr).await; + + let _ = sender + .send_request(Request::new(Empty::::new())) + .await + .expect_err("should fail"); + } + + #[cfg(not(miri))] + #[tokio::test] + async fn http1_only() { + let addr = start_server(true, false).await; + let mut sender = connect_h1(addr).await; + + let response = sender + .send_request(Request::new(Empty::::new())) + .await + .unwrap(); + + let body = response.into_body().collect().await.unwrap().to_bytes(); + + assert_eq!(body, BODY); + } + + #[cfg(not(miri))] + #[tokio::test] + async fn http1_only_fail_if_client_is_http2() { + let addr = start_server(true, false).await; + let mut sender = connect_h2(addr).await; + + let _ = sender + .send_request(Request::new(Empty::::new())) + .await + .expect_err("should fail"); + } + #[cfg(not(miri))] #[tokio::test] async fn graceful_shutdown() { @@ -1000,7 +1112,7 @@ mod tests { sender } - async fn start_server() -> SocketAddr { + async fn start_server(h1_only: bool, h2_only: bool) -> SocketAddr { let addr: SocketAddr = ([127, 0, 0, 1], 0).into(); let listener = TcpListener::bind(addr).await.unwrap(); @@ -1011,9 +1123,19 @@ mod tests { let (stream, _) = listener.accept().await.unwrap(); let stream = TokioIo::new(stream); tokio::task::spawn(async move { - let _ = auto::Builder::new(TokioExecutor::new()) + let mut builder = auto::Builder::new(TokioExecutor::new()); + + builder .http2() - .max_header_list_size(4096) + .max_header_list_size(4096); + + if h1_only { + builder = builder.http1_only(); + } else if h2_only { + builder = builder.http2_only(); + } + + builder .serve_connection_with_upgrades(stream, service_fn(hello)) .await; });