From b066967967a85a6a3156066b2b33397888a74898 Mon Sep 17 00:00:00 2001 From: Rain Jiang <96632942+rainj-me@users.noreply.github.com> Date: Thu, 4 Jan 2024 01:33:53 +0000 Subject: [PATCH] adapt proxy-protocol implementation with async fn in trait (#53) * adapt proxy-protocol implementation with async fn in trait * fix clippy error --- monolake-core/src/server/mod.rs | 6 +- monolake-services/src/proxy_protocol/mod.rs | 206 ++++++++++---------- 2 files changed, 102 insertions(+), 110 deletions(-) diff --git a/monolake-core/src/server/mod.rs b/monolake-core/src/server/mod.rs index 4d93cd3..f1ae38c 100644 --- a/monolake-core/src/server/mod.rs +++ b/monolake-core/src/server/mod.rs @@ -242,10 +242,8 @@ where fn execute(self, controller: &WorkerController) -> Result<(), AnyError> { match self { Command::Update(name, factory) => { - match { - let sites = unsafe { &mut *controller.sites.get() }; - sites.get(&name).map(|sh| sh.handler_slot.clone()) - } { + let sites = unsafe { &mut *controller.sites.get() }; + match sites.get(&name).map(|sh| sh.handler_slot.clone()) { Some(svc_slot) => { let svc = factory .make_via_ref(Some(&svc_slot.get_svc())) diff --git a/monolake-services/src/proxy_protocol/mod.rs b/monolake-services/src/proxy_protocol/mod.rs index ee635c4..577e76e 100644 --- a/monolake-services/src/proxy_protocol/mod.rs +++ b/monolake-services/src/proxy_protocol/mod.rs @@ -1,4 +1,4 @@ -use std::{fmt::Display, future::Future, net::SocketAddr}; +use std::{fmt::Display, net::SocketAddr}; use monoio::{ buf::IoBufMut, @@ -34,123 +34,117 @@ where { type Response = T::Response; type Error = AnyError; - type Future<'cx> = impl Future> + 'cx - where - Self: 'cx, - Accept: 'cx; - fn call(&self, (mut stream, ctx): Accept) -> Self::Future<'_> { - async move { - const MAX_HEADER_SIZE: usize = 230; - let mut buffer = Vec::with_capacity(MAX_HEADER_SIZE); - let mut pos = 0; + async fn call(&self, (mut stream, ctx): Accept) -> Result { + const MAX_HEADER_SIZE: usize = 230; + let mut buffer = Vec::with_capacity(MAX_HEADER_SIZE); + let mut pos = 0; - // read at-least 1 byte - let (res, buf) = stream - .read(unsafe { buffer.slice_mut_unchecked(0..MAX_HEADER_SIZE) }) - .await; - buffer = buf.into_inner(); - pos += res.map_err(AnyError::from)?; - // match version magic header - let parsed = if let Some(target_header) = match buffer[0] { - b'P' => { - let end = pos.min(V1HEADER.len()); - if buffer[1..end] == V1HEADER[1..end] { - Some(&V1HEADER[..]) - } else { - tracing::warn!("proxy-protocol: v1 magic only partly matched"); - None - } + // read at-least 1 byte + let (res, buf) = stream + .read(unsafe { buffer.slice_mut_unchecked(0..MAX_HEADER_SIZE) }) + .await; + buffer = buf.into_inner(); + pos += res.map_err(AnyError::from)?; + // match version magic header + let parsed = if let Some(target_header) = match buffer[0] { + b'P' => { + let end = pos.min(V1HEADER.len()); + if buffer[1..end] == V1HEADER[1..end] { + Some(&V1HEADER[..]) + } else { + tracing::warn!("proxy-protocol: v1 magic only partly matched"); + None } - 0x0D => { - let end = pos.min(V2HEADER.len()); - if buffer[1..end] == V2HEADER[1..end] { - Some(&V2HEADER[..]) - } else { - tracing::warn!("proxy-protocol: v2 magic only partly matched"); - None - } + } + 0x0D => { + let end = pos.min(V2HEADER.len()); + if buffer[1..end] == V2HEADER[1..end] { + Some(&V2HEADER[..]) + } else { + tracing::warn!("proxy-protocol: v2 magic only partly matched"); + None } - _ => None, - } { - // loop {parse; read; check_full;} - let header = loop { - let mut cursor = std::io::Cursor::new(&buffer); - let e = match parse(&mut cursor) { - Ok(header) => break Ok((header, cursor.position())), - // data is not enough to parse version, we should read again - Err( - e @ ParseError::NotProxyHeader - | e @ ParseError::Version1 { - source: version1::ParseError::UnexpectedEof, - } - | e @ ParseError::Version2 { - source: version2::ParseError::UnexpectedEof, - }, - ) => e, - Err(e) => break Err(e), - }; - - let buf = unsafe { buffer.slice_mut_unchecked(pos..MAX_HEADER_SIZE) }; - let (res, buf) = stream.read(buf).await; - buffer = buf.into_inner(); - let read = res.map_err(AnyError::from)?; - // if we are reading magic header, we have to check if the magic header matches - // because ParseError::NotProxyHeader does not always mean data is not enough - if pos < target_header.len() { - let end = target_header.len().min(pos + read); - if buffer[pos..end] != target_header[pos..end] { - break Err(e); + } + _ => None, + } { + // loop {parse; read; check_full;} + let header = loop { + let mut cursor = std::io::Cursor::new(&buffer); + let e = match parse(&mut cursor) { + Ok(header) => break Ok((header, cursor.position())), + // data is not enough to parse version, we should read again + Err( + e @ ParseError::NotProxyHeader + | e @ ParseError::Version1 { + source: version1::ParseError::UnexpectedEof, } - } - pos += read; - if pos == MAX_HEADER_SIZE { - return Err(ParseError::NotProxyHeader.into()); - } + | e @ ParseError::Version2 { + source: version2::ParseError::UnexpectedEof, + }, + ) => e, + Err(e) => break Err(e), }; - Some(header) - } else { - tracing::debug!("proxy-protocol: not proxy protocol at first glance"); - None - }; - let mut cursor = std::io::Cursor::new(buffer); - let remote_addr = match parsed { - Some(Ok((header, idx))) => { - // advance proxy-protocol length on success parsing - cursor.set_position(idx); - match header { - ProxyHeader::Version1 { - addresses: version1::ProxyAddresses::Ipv4 { source, .. }, - } - | ProxyHeader::Version2 { - addresses: version2::ProxyAddresses::Ipv4 { source, .. }, - .. - } => Some(RemoteAddr(AcceptedAddr::from(SocketAddr::from(source)))), - ProxyHeader::Version1 { - addresses: version1::ProxyAddresses::Ipv6 { source, .. }, - } - | ProxyHeader::Version2 { - addresses: version2::ProxyAddresses::Ipv6 { source, .. }, - .. - } => Some(RemoteAddr(AcceptedAddr::from(SocketAddr::from(source)))), - _ => { - tracing::warn!("proxy protocol get source failed"); - None - } + let buf = unsafe { buffer.slice_mut_unchecked(pos..MAX_HEADER_SIZE) }; + let (res, buf) = stream.read(buf).await; + buffer = buf.into_inner(); + let read = res.map_err(AnyError::from)?; + // if we are reading magic header, we have to check if the magic header matches + // because ParseError::NotProxyHeader does not always mean data is not enough + if pos < target_header.len() { + let end = target_header.len().min(pos + read); + if buffer[pos..end] != target_header[pos..end] { + break Err(e); } } - _ => None, + pos += read; + if pos == MAX_HEADER_SIZE { + return Err(ParseError::NotProxyHeader.into()); + } }; + Some(header) + } else { + tracing::debug!("proxy-protocol: not proxy protocol at first glance"); + None + }; + + let mut cursor = std::io::Cursor::new(buffer); + let remote_addr = match parsed { + Some(Ok((header, idx))) => { + // advance proxy-protocol length on success parsing + cursor.set_position(idx); + match header { + ProxyHeader::Version1 { + addresses: version1::ProxyAddresses::Ipv4 { source, .. }, + } + | ProxyHeader::Version2 { + addresses: version2::ProxyAddresses::Ipv4 { source, .. }, + .. + } => Some(RemoteAddr(AcceptedAddr::from(SocketAddr::from(source)))), + ProxyHeader::Version1 { + addresses: version1::ProxyAddresses::Ipv6 { source, .. }, + } + | ProxyHeader::Version2 { + addresses: version2::ProxyAddresses::Ipv6 { source, .. }, + .. + } => Some(RemoteAddr(AcceptedAddr::from(SocketAddr::from(source)))), + _ => { + tracing::warn!("proxy protocol get source failed"); + None + } + } + } + _ => None, + }; - let ctx = ctx.param_set(remote_addr); - let prefix_io = PrefixedReadIo::new(stream, cursor); + let ctx = ctx.param_set(remote_addr); + let prefix_io = PrefixedReadIo::new(stream, cursor); - self.inner - .call((prefix_io, ctx)) - .await - .map_err(|e| e.into()) - } + self.inner + .call((prefix_io, ctx)) + .await + .map_err(|e| e.into()) } }