Skip to content

Commit

Permalink
feat: generic protocol detector
Browse files Browse the repository at this point in the history
  • Loading branch information
ihciah committed Oct 24, 2024
1 parent 5d7198f commit 5c24e28
Show file tree
Hide file tree
Showing 4 changed files with 128 additions and 65 deletions.
118 changes: 118 additions & 0 deletions monolake-services/src/common/detect.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
use std::{future::Future, io, io::Cursor};

use monoio::{
buf::IoBufMut,
io::{AsyncReadRent, AsyncReadRentExt, PrefixedReadIo},
};
use service_async::Service;

/// Detect is a trait for detecting a certain pattern in the input stream.
///
/// It accepts an input stream and returns a tuple of the detected pattern and the wrapped input
/// stream which is usually a `PrefixedReadIo`. The implementation can choose to whether add the
/// prefix data.
/// If it fails to detect the pattern, it should represent the error inside the `DetOut`.
pub trait Detect<IO> {
type DetOut;
type IOOut;

fn detect(&self, io: IO) -> impl Future<Output = io::Result<(Self::DetOut, Self::IOOut)>>;
}

/// DetectService is a service that detects a certain pattern in the input stream and forwards the
/// detected pattern and the wrapped input stream to the inner service.
pub struct DetectService<D, S> {
pub detector: D,
pub inner: S,
}

#[derive(thiserror::Error, Debug)]
pub enum DetectError<E> {
#[error("service error: {0:?}")]
Svc(E),
#[error("io error: {0:?}")]
Io(std::io::Error),
}

impl<R, S, D, CX> Service<(R, CX)> for DetectService<D, S>
where
D: Detect<R>,
S: Service<(D::DetOut, D::IOOut, CX)>,
{
type Response = S::Response;
type Error = DetectError<S::Error>;

async fn call(&self, (io, cx): (R, CX)) -> Result<Self::Response, Self::Error> {
let (det, io) = self.detector.detect(io).await.map_err(DetectError::Io)?;
self.inner
.call((det, io, cx))
.await
.map_err(DetectError::Svc)
}
}

/// FixedLengthDetector detects a fixed length of bytes from the input stream.
pub struct FixedLengthDetector<const N: usize, F>(pub F);

impl<const N: usize, F, IO, DetOut> Detect<IO> for FixedLengthDetector<N, F>
where
F: Fn(&mut [u8]) -> DetOut,
IO: AsyncReadRent,
{
type DetOut = DetOut;
type IOOut = PrefixedReadIo<IO, Cursor<Vec<u8>>>;

async fn detect(&self, mut io: IO) -> io::Result<(Self::DetOut, Self::IOOut)> {
let buf = Vec::with_capacity(N).slice_mut(..N);
let (r, buf) = io.read_exact(buf).await;
r?;

let mut buf = buf.into_inner();
let r = (self.0)(&mut buf);
Ok((r, PrefixedReadIo::new(io, Cursor::new(buf))))
}
}

/// PrefixDetector detects a certain prefix from the input stream.
///
/// If the prefix matches, it returns true and the wrapped input stream with the prefix data.
/// Otherwise, it returns false and the input stream with the prefix data(the prefix maybe less than
/// the static str's length).
pub struct PrefixDetector(pub &'static [u8]);

impl<IO> Detect<IO> for PrefixDetector
where
IO: AsyncReadRent,
{
type DetOut = bool;
type IOOut = PrefixedReadIo<IO, Cursor<Vec<u8>>>;

async fn detect(&self, mut io: IO) -> io::Result<(Self::DetOut, Self::IOOut)> {
let l = self.0.len();
let mut written = 0;
let mut buf: Vec<u8> = Vec::with_capacity(l);
let mut eq = true;
loop {
// # Safety
// The buf must have enough capacity to write the data.
let buf_slice = unsafe { buf.slice_mut_unchecked(written..l) };
let (result, buf_slice) = io.read(buf_slice).await;
buf = buf_slice.into_inner();
match result? {
0 => {
break;
}
n => {
let curr = written;
written += n;
if self.0[curr..written] != buf[curr..written] {
eq = false;
break;
}
}
}
}
let io = PrefixedReadIo::new(io, Cursor::new(buf));
Ok((eq && written == l, io))
}
}
2 changes: 2 additions & 0 deletions monolake-services/src/common/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
mod cancel;
mod context;
mod delay;
mod detect;
mod erase;
mod map;
mod panic;
Expand All @@ -10,6 +11,7 @@ mod timeout;
pub use cancel::{linked_list, Canceller, CancellerDropper, Waiter};
pub use context::ContextService;
pub use delay::{Delay, DelayService};
pub use detect::{Detect, DetectService, FixedLengthDetector, PrefixDetector};
pub use erase::EraseResp;
pub use map::{FnSvc, Map, MapErr};
pub use panic::{CatchPanicError, CatchPanicService};
Expand Down
72 changes: 8 additions & 64 deletions monolake-services/src/http/detect.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,19 +39,13 @@
//!
//! - Uses efficient buffering to minimize I/O operations during version detection
//! - Implements zero-copy techniques where possible to reduce memory overhead
use std::io::Cursor;
use monoio::{
buf::IoBufMut,
io::{AsyncReadRent, AsyncWriteRent, PrefixedReadIo},
};
use monolake_core::http::HttpAccept;
use service_async::{
layer::{layer_fn, FactoryLayer},
AsyncMakeService, MakeService, Service,
AsyncMakeService, MakeService,
};

use crate::tcp::Accept;
use crate::common::{DetectService, PrefixDetector};

const PREFACE: &[u8; 24] = b"PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n";

Expand All @@ -77,26 +71,28 @@ pub enum HttpVersionDetectError<E> {
}

impl<F: MakeService> MakeService for HttpVersionDetect<F> {
type Service = HttpVersionDetect<F::Service>;
type Service = DetectService<PrefixDetector, F::Service>;
type Error = F::Error;

fn make_via_ref(&self, old: Option<&Self::Service>) -> Result<Self::Service, Self::Error> {
Ok(HttpVersionDetect {
Ok(DetectService {
inner: self.inner.make_via_ref(old.map(|o| &o.inner))?,
detector: PrefixDetector(PREFACE),
})
}
}

impl<F: AsyncMakeService> AsyncMakeService for HttpVersionDetect<F> {
type Service = HttpVersionDetect<F::Service>;
type Service = DetectService<PrefixDetector, F::Service>;
type Error = F::Error;

async fn make_via_ref(
&self,
old: Option<&Self::Service>,
) -> Result<Self::Service, Self::Error> {
Ok(HttpVersionDetect {
Ok(DetectService {
inner: self.inner.make_via_ref(old.map(|o| &o.inner)).await?,
detector: PrefixDetector(PREFACE),
})
}
}
Expand All @@ -106,55 +102,3 @@ impl<F> HttpVersionDetect<F> {
layer_fn(|_: &C, inner| HttpVersionDetect { inner })
}
}

impl<T, Stream, CX> Service<Accept<Stream, CX>> for HttpVersionDetect<T>
where
Stream: AsyncReadRent + AsyncWriteRent,
T: Service<HttpAccept<PrefixedReadIo<Stream, Cursor<Vec<u8>>>, CX>>,
{
type Response = T::Response;
type Error = HttpVersionDetectError<T::Error>;

async fn call(
&self,
incoming_stream: Accept<Stream, CX>,
) -> Result<Self::Response, Self::Error> {
let (mut stream, addr) = incoming_stream;
let mut buf = vec![0; PREFACE.len()];
let mut pos = 0;
let mut h2_detect = false;

loop {
let buf_slice = unsafe { buf.slice_mut_unchecked(pos..PREFACE.len()) };
let (result, buf_slice) = stream.read(buf_slice).await;
buf = buf_slice.into_inner();
match result {
Ok(0) => {
break;
}
Ok(n) => {
if PREFACE[pos..pos + n] != buf[pos..pos + n] {
break;
}
pos += n;
}
Err(e) => {
return Err(HttpVersionDetectError::Io(e));
}
}

if pos == PREFACE.len() {
h2_detect = true;
break;
}
}

let preface_buf = std::io::Cursor::new(buf);
let rewind_io = monoio::io::PrefixedReadIo::new(stream, preface_buf);

self.inner
.call((h2_detect, rewind_io, addr))
.await
.map_err(HttpVersionDetectError::Inner)
}
}
1 change: 0 additions & 1 deletion monolake-services/src/thrift/handlers/proxy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,6 @@ impl<C, E> ConnectorMapper<C, E> for ThriftConnectorMapper {
/// It manages connections to upstream servers using a connection pool for efficiency.
/// For implementation details and example usage, see the
/// [module level documentation](crate::thrift::handlers::proxy).
pub struct ProxyHandler {
connector: PoolThriftConnector,
routes: Vec<RouteConfig>,
Expand Down

0 comments on commit 5c24e28

Please sign in to comment.