From aec268cefdc12662231ef73445c8eb29c0fb9d37 Mon Sep 17 00:00:00 2001 From: ihciah Date: Mon, 4 Nov 2024 07:30:05 +0000 Subject: [PATCH] feat: generic protocol detector --- monolake-services/src/common/detect.rs | 118 ++++++++++++++++++ monolake-services/src/common/mod.rs | 2 + monolake-services/src/http/core.rs | 10 +- monolake-services/src/http/detect.rs | 92 +++----------- .../http/handlers/connection_persistence.rs | 4 +- .../src/http/handlers/content_handler.rs | 4 +- monolake-services/src/http/handlers/mod.rs | 4 +- monolake-services/src/http/handlers/route.rs | 4 +- monolake-services/src/lib.rs | 10 +- monolake/src/factory.rs | 4 +- 10 files changed, 158 insertions(+), 94 deletions(-) create mode 100644 monolake-services/src/common/detect.rs diff --git a/monolake-services/src/common/detect.rs b/monolake-services/src/common/detect.rs new file mode 100644 index 0000000..e4763bc --- /dev/null +++ b/monolake-services/src/common/detect.rs @@ -0,0 +1,118 @@ +use std::{future::Future, io, io::Cursor}; + +use monoio::{ + buf::IoBufMut, + io::{AsyncReadRent, AsyncReadRentExt, PrefixedReadIo}, +}; +use service_async::Service; + +/// Detect is a trait for detecting a certain pattern in the input stream. +/// +/// It accepts an input stream and returns a tuple of the detected pattern and the wrapped input +/// stream which is usually a `PrefixedReadIo`. The implementation can choose to whether add the +/// prefix data. +/// If it fails to detect the pattern, it should represent the error inside the `DetOut`. +pub trait Detect { + type DetOut; + type IOOut; + + fn detect(&self, io: IO) -> impl Future>; +} + +/// DetectService is a service that detects a certain pattern in the input stream and forwards the +/// detected pattern and the wrapped input stream to the inner service. +pub struct DetectService { + pub detector: D, + pub inner: S, +} + +#[derive(thiserror::Error, Debug)] +pub enum DetectError { + #[error("service error: {0:?}")] + Svc(E), + #[error("io error: {0:?}")] + Io(std::io::Error), +} + +impl Service<(R, CX)> for DetectService +where + D: Detect, + S: Service<(D::DetOut, D::IOOut, CX)>, +{ + type Response = S::Response; + type Error = DetectError; + + async fn call(&self, (io, cx): (R, CX)) -> Result { + let (det, io) = self.detector.detect(io).await.map_err(DetectError::Io)?; + self.inner + .call((det, io, cx)) + .await + .map_err(DetectError::Svc) + } +} + +/// FixedLengthDetector detects a fixed length of bytes from the input stream. +pub struct FixedLengthDetector(pub F); + +impl Detect for FixedLengthDetector +where + F: Fn(&mut [u8]) -> DetOut, + IO: AsyncReadRent, +{ + type DetOut = DetOut; + type IOOut = PrefixedReadIo>>; + + async fn detect(&self, mut io: IO) -> io::Result<(Self::DetOut, Self::IOOut)> { + let buf = Vec::with_capacity(N).slice_mut(..N); + let (r, buf) = io.read_exact(buf).await; + r?; + + let mut buf = buf.into_inner(); + let r = (self.0)(&mut buf); + Ok((r, PrefixedReadIo::new(io, Cursor::new(buf)))) + } +} + +/// PrefixDetector detects a certain prefix from the input stream. +/// +/// If the prefix matches, it returns true and the wrapped input stream with the prefix data. +/// Otherwise, it returns false and the input stream with the prefix data(the prefix maybe less than +/// the static str's length). +pub struct PrefixDetector(pub &'static [u8]); + +impl Detect for PrefixDetector +where + IO: AsyncReadRent, +{ + type DetOut = bool; + type IOOut = PrefixedReadIo>>; + + async fn detect(&self, mut io: IO) -> io::Result<(Self::DetOut, Self::IOOut)> { + let l = self.0.len(); + let mut written = 0; + let mut buf: Vec = Vec::with_capacity(l); + let mut eq = true; + loop { + // # Safety + // The buf must have enough capacity to write the data. + let buf_slice = unsafe { buf.slice_mut_unchecked(written..l) }; + let (result, buf_slice) = io.read(buf_slice).await; + buf = buf_slice.into_inner(); + match result? { + 0 => { + break; + } + n => { + let curr = written; + written += n; + if self.0[curr..written] != buf[curr..written] { + eq = false; + break; + } + } + } + } + let io = PrefixedReadIo::new(io, Cursor::new(buf)); + Ok((eq && written == l, io)) + } +} diff --git a/monolake-services/src/common/mod.rs b/monolake-services/src/common/mod.rs index 0bc6fc7..516773a 100644 --- a/monolake-services/src/common/mod.rs +++ b/monolake-services/src/common/mod.rs @@ -2,6 +2,7 @@ mod cancel; mod context; mod delay; +mod detect; mod erase; mod map; mod panic; @@ -10,6 +11,7 @@ mod timeout; pub use cancel::{linked_list, Canceller, CancellerDropper, Waiter}; pub use context::ContextService; pub use delay::{Delay, DelayService}; +pub use detect::{Detect, DetectService, FixedLengthDetector, PrefixDetector}; pub use erase::EraseResp; pub use map::{FnSvc, Map, MapErr}; pub use panic::{CatchPanicError, CatchPanicService}; diff --git a/monolake-services/src/http/core.rs b/monolake-services/src/http/core.rs index e32d3d0..ad8236b 100644 --- a/monolake-services/src/http/core.rs +++ b/monolake-services/src/http/core.rs @@ -16,7 +16,7 @@ //! //! - Support for HTTP/1, HTTP/1.1, and HTTP/2 protocols //! - Composable design allowing a stack of `HttpHandler` implementations -//! - Automatic protocol detection when combined with `HttpVersionDetect` +//! - Automatic protocol detection when combined with `H2Detect` //! - Efficient handling of concurrent requests using asynchronous I/O //! - Configurable timeout settings for different stages of request processing //! - Integration with `service_async` for easy composition in service stacks @@ -25,17 +25,17 @@ //! # Usage //! //! `HttpCoreService` is typically used as part of a larger service stack, often in combination -//! with `HttpVersionDetect` for automatic protocol detection. Here's a basic example: +//! with `H2Detect` for automatic protocol detection. Here's a basic example: //! //! ```ignore //! use service_async::{layer::FactoryLayer, stack::FactoryStack}; //! -//! use crate::http::{HttpCoreService, HttpVersionDetect}; +//! use crate::http::{HttpCoreService, H2Detect}; //! //! let config = Config { /* ... */ }; //! let stack = FactoryStack::new(config) //! .push(HttpCoreService::layer()) -//! .push(HttpVersionDetect::layer()) +//! .push(H2Detect::layer()) //! // ... other handlers implementing HttpHandler ... //! ; //! @@ -52,7 +52,7 @@ //! //! # Automatic Protocol Detection //! -//! When used in conjunction with `HttpVersionDetect`, `HttpCoreService` can automatically +//! When used in conjunction with `H2Detect`, `HttpCoreService` can automatically //! detect whether an incoming connection is using HTTP/1, HTTP/1.1, or HTTP/2, and handle //! it appropriately. This allows for seamless support of multiple HTTP versions without //! the need for separate server configurations. diff --git a/monolake-services/src/http/detect.rs b/monolake-services/src/http/detect.rs index 14e14a0..8af5e20 100644 --- a/monolake-services/src/http/detect.rs +++ b/monolake-services/src/http/detect.rs @@ -6,8 +6,8 @@ //! //! # Key Components //! -//! - [`HttpVersionDetect`]: The main service component responsible for HTTP version detection. -//! - [`HttpVersionDetectError`]: Error type for version detection operations. +//! - [`H2Detect`]: The main service component responsible for HTTP version detection. +//! - [`H2DetectError`]: Error type for version detection operations. //! //! # Features //! @@ -27,7 +27,7 @@ //! let config = Config { /* ... */ }; //! let stack = FactoryStack::new(config) //! .push(HttpCoreService::layer()) -//! .push(HttpVersionDetect::layer()) +//! .push(H2Detect::layer()) //! // ... other layers ... //! ; //! @@ -39,122 +39,66 @@ //! //! - Uses efficient buffering to minimize I/O operations during version detection //! - Implements zero-copy techniques where possible to reduce memory overhead -use std::io::Cursor; -use monoio::{ - buf::IoBufMut, - io::{AsyncReadRent, AsyncWriteRent, PrefixedReadIo}, -}; -use monolake_core::http::HttpAccept; use service_async::{ layer::{layer_fn, FactoryLayer}, - AsyncMakeService, MakeService, Service, + AsyncMakeService, MakeService, }; -use crate::tcp::Accept; +use crate::common::{DetectService, PrefixDetector}; const PREFACE: &[u8; 24] = b"PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n"; /// Service for detecting HTTP version and routing connections accordingly. /// -/// `HttpVersionDetect` examines the initial bytes of an incoming connection to +/// `H2Detect` examines the initial bytes of an incoming connection to /// determine whether it's an HTTP/2 connection (by checking for the HTTP/2 preface) /// or an HTTP/1.x connection. It then forwards the connection to the inner service /// with appropriate version information. /// For implementation details and example usage, see the /// [module level documentation](crate::http::detect). #[derive(Clone)] -pub struct HttpVersionDetect { +pub struct H2Detect { inner: T, } #[derive(thiserror::Error, Debug)] -pub enum HttpVersionDetectError { +pub enum H2DetectError { #[error("inner error: {0:?}")] Inner(E), #[error("io error: {0:?}")] Io(std::io::Error), } -impl MakeService for HttpVersionDetect { - type Service = HttpVersionDetect; +impl MakeService for H2Detect { + type Service = DetectService; type Error = F::Error; fn make_via_ref(&self, old: Option<&Self::Service>) -> Result { - Ok(HttpVersionDetect { + Ok(DetectService { inner: self.inner.make_via_ref(old.map(|o| &o.inner))?, + detector: PrefixDetector(PREFACE), }) } } -impl AsyncMakeService for HttpVersionDetect { - type Service = HttpVersionDetect; +impl AsyncMakeService for H2Detect { + type Service = DetectService; type Error = F::Error; async fn make_via_ref( &self, old: Option<&Self::Service>, ) -> Result { - Ok(HttpVersionDetect { + Ok(DetectService { inner: self.inner.make_via_ref(old.map(|o| &o.inner)).await?, + detector: PrefixDetector(PREFACE), }) } } -impl HttpVersionDetect { +impl H2Detect { pub fn layer() -> impl FactoryLayer { - layer_fn(|_: &C, inner| HttpVersionDetect { inner }) - } -} - -impl Service> for HttpVersionDetect -where - Stream: AsyncReadRent + AsyncWriteRent, - T: Service>>, CX>>, -{ - type Response = T::Response; - type Error = HttpVersionDetectError; - - async fn call( - &self, - incoming_stream: Accept, - ) -> Result { - let (mut stream, addr) = incoming_stream; - let mut buf = vec![0; PREFACE.len()]; - let mut pos = 0; - let mut h2_detect = false; - - loop { - let buf_slice = unsafe { buf.slice_mut_unchecked(pos..PREFACE.len()) }; - let (result, buf_slice) = stream.read(buf_slice).await; - buf = buf_slice.into_inner(); - match result { - Ok(0) => { - break; - } - Ok(n) => { - if PREFACE[pos..pos + n] != buf[pos..pos + n] { - break; - } - pos += n; - } - Err(e) => { - return Err(HttpVersionDetectError::Io(e)); - } - } - - if pos == PREFACE.len() { - h2_detect = true; - break; - } - } - - let preface_buf = std::io::Cursor::new(buf); - let rewind_io = monoio::io::PrefixedReadIo::new(stream, preface_buf); - - self.inner - .call((h2_detect, rewind_io, addr)) - .await - .map_err(HttpVersionDetectError::Inner) + layer_fn(|_: &C, inner| H2Detect { inner }) } } diff --git a/monolake-services/src/http/handlers/connection_persistence.rs b/monolake-services/src/http/handlers/connection_persistence.rs index dd4c075..a89c730 100644 --- a/monolake-services/src/http/handlers/connection_persistence.rs +++ b/monolake-services/src/http/handlers/connection_persistence.rs @@ -25,7 +25,7 @@ //! common::ContextService, //! http::{ //! core::HttpCoreService, -//! detect::HttpVersionDetect, +//! detect::H2Detect, //! handlers::{ //! route::RouteConfig, ConnectionReuseHandler, ContentHandler, RewriteAndRouteHandler, //! UpstreamHandler, @@ -60,7 +60,7 @@ //! .push(RewriteAndRouteHandler::layer()) //! .push(ConnectionReuseHandler::layer()) //! .push(HttpCoreService::layer()) -//! .push(HttpVersionDetect::layer()); +//! .push(H2Detect::layer()); //! //! // Use the service to handle HTTP requests //! ``` diff --git a/monolake-services/src/http/handlers/content_handler.rs b/monolake-services/src/http/handlers/content_handler.rs index 96f6f1e..919afa6 100644 --- a/monolake-services/src/http/handlers/content_handler.rs +++ b/monolake-services/src/http/handlers/content_handler.rs @@ -25,7 +25,7 @@ //! common::ContextService, //! http::{ //! core::HttpCoreService, -//! detect::HttpVersionDetect, +//! detect::H2Detect, //! handlers::{ //! route::RouteConfig, ConnectionReuseHandler, ContentHandler, RewriteAndRouteHandler, //! UpstreamHandler, @@ -60,7 +60,7 @@ //! .push(RewriteAndRouteHandler::layer()) //! .push(ConnectionReuseHandler::layer()) //! .push(HttpCoreService::layer()) -//! .push(HttpVersionDetect::layer()); +//! .push(H2Detect::layer()); //! //! // Use the service to handle HTTP requests //! ``` diff --git a/monolake-services/src/http/handlers/mod.rs b/monolake-services/src/http/handlers/mod.rs index 6a47b3c..9e6a6fd 100644 --- a/monolake-services/src/http/handlers/mod.rs +++ b/monolake-services/src/http/handlers/mod.rs @@ -61,7 +61,7 @@ //! common::ContextService, //! http::{ //! core::HttpCoreService, -//! detect::HttpVersionDetect, +//! detect::H2Detect, //! handlers::{ //! route::RouteConfig, ConnectionReuseHandler, ContentHandler, RewriteAndRouteHandler, //! UpstreamHandler, @@ -96,7 +96,7 @@ //! .push(RewriteAndRouteHandler::layer()) //! .push(ConnectionReuseHandler::layer()) //! .push(HttpCoreService::layer()) -//! .push(HttpVersionDetect::layer()); +//! .push(H2Detect::layer()); //! //! // Use the service to handle HTTP requests //! ``` diff --git a/monolake-services/src/http/handlers/route.rs b/monolake-services/src/http/handlers/route.rs index b945c40..bb38ac4 100644 --- a/monolake-services/src/http/handlers/route.rs +++ b/monolake-services/src/http/handlers/route.rs @@ -32,7 +32,7 @@ //! common::ContextService, //! http::{ //! core::HttpCoreService, -//! detect::HttpVersionDetect, +//! detect::H2Detect, //! handlers::{ //! route::RouteConfig, ConnectionReuseHandler, ContentHandler, RewriteAndRouteHandler, //! UpstreamHandler, @@ -67,7 +67,7 @@ //! .push(RewriteAndRouteHandler::layer()) //! .push(ConnectionReuseHandler::layer()) //! .push(HttpCoreService::layer()) -//! .push(HttpVersionDetect::layer()); +//! .push(H2Detect::layer()); //! //! // Use the service to handle HTTP requests //! ``` diff --git a/monolake-services/src/lib.rs b/monolake-services/src/lib.rs index b2a5031..2388ae6 100644 --- a/monolake-services/src/lib.rs +++ b/monolake-services/src/lib.rs @@ -15,9 +15,9 @@ //! //! - [`HttpCoreService`](http::core): The main service for handling HTTP/1.1 and HTTP/2 //! connections. -//! - [`HttpVersionDetect`](http::detect): Automatic detection of HTTP protocol versions. -//! #[cfg_attr(feature = "hyper", doc = "- [`HyperCoreService`](hyper::HyperCoreService): A -//! high-performance HTTP service built on top of the Hyper library.")] +//! - [`H2Detect`](http::detect): Automatic detection of HTTP protocol versions. #[cfg_attr(feature +//! = "hyper", doc = "- [`HyperCoreService`](hyper::HyperCoreService): A high-performance HTTP +//! service built on top of the Hyper library.")] //! //! #### Request Handlers //! @@ -108,7 +108,7 @@ //! //! ```ignore //! use monolake_services::{ -//! HttpCoreService, HttpVersionDetect, ConnectionReuseHandler, +//! HttpCoreService, H2Detect, ConnectionReuseHandler, //! ContentHandler, RewriteAndRouteHandler, UpstreamHandler, UnifiedTlsService, //! ProxyProtocolService, HyperCoreService //! }; @@ -124,7 +124,7 @@ //! .push(ContentHandler::layer()) //! .push(ConnectionReuseHandler::layer()) //! .push(HyperCoreService::layer()); -//! .push(HttpVersionDetect::layer()) +//! .push(H2Detect::layer()) //! .push(UnifiedTlsService::layer()) //! .push(ContextService::layer()); //! diff --git a/monolake/src/factory.rs b/monolake/src/factory.rs index 125bbd6..e29e556 100644 --- a/monolake/src/factory.rs +++ b/monolake/src/factory.rs @@ -12,7 +12,7 @@ use monolake_services::{ common::ContextService, http::{ core::HttpCoreService, - detect::HttpVersionDetect, + detect::H2Detect, handlers::{ upstream::HttpUpstreamTimeout, ConnectionReuseHandler, ContentHandler, RewriteAndRouteHandler, UpstreamHandler, @@ -55,7 +55,7 @@ pub fn l7_factory( let stacks = stacks .push(ConnectionReuseHandler::layer()) .push(HttpCoreService::layer()) - .push(HttpVersionDetect::layer()); + .push(H2Detect::layer()); #[cfg(feature = "tls")] let stacks = stacks.push(monolake_services::tls::UnifiedTlsFactory::layer());