Skip to content

Commit

Permalink
adapt proxy-protocol implementation with async fn in trait (#53)
Browse files Browse the repository at this point in the history
* adapt proxy-protocol implementation with async fn in trait

* fix clippy error
  • Loading branch information
rainj-me authored Jan 4, 2024
1 parent 5a770d9 commit b066967
Show file tree
Hide file tree
Showing 2 changed files with 102 additions and 110 deletions.
6 changes: 2 additions & 4 deletions monolake-core/src/server/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -242,10 +242,8 @@ where
fn execute(self, controller: &WorkerController<S>) -> Result<(), AnyError> {
match self {
Command::Update(name, factory) => {
match {
let sites = unsafe { &mut *controller.sites.get() };
sites.get(&name).map(|sh| sh.handler_slot.clone())
} {
let sites = unsafe { &mut *controller.sites.get() };
match sites.get(&name).map(|sh| sh.handler_slot.clone()) {
Some(svc_slot) => {
let svc = factory
.make_via_ref(Some(&svc_slot.get_svc()))
Expand Down
206 changes: 100 additions & 106 deletions monolake-services/src/proxy_protocol/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::{fmt::Display, future::Future, net::SocketAddr};
use std::{fmt::Display, net::SocketAddr};

use monoio::{
buf::IoBufMut,
Expand Down Expand Up @@ -34,123 +34,117 @@ where
{
type Response = T::Response;
type Error = AnyError;
type Future<'cx> = impl Future<Output = Result<Self::Response, Self::Error>> + 'cx
where
Self: 'cx,
Accept<S, CX>: 'cx;

fn call(&self, (mut stream, ctx): Accept<S, CX>) -> Self::Future<'_> {
async move {
const MAX_HEADER_SIZE: usize = 230;
let mut buffer = Vec::with_capacity(MAX_HEADER_SIZE);
let mut pos = 0;
async fn call(&self, (mut stream, ctx): Accept<S, CX>) -> Result<Self::Response, Self::Error> {
const MAX_HEADER_SIZE: usize = 230;
let mut buffer = Vec::with_capacity(MAX_HEADER_SIZE);
let mut pos = 0;

// read at-least 1 byte
let (res, buf) = stream
.read(unsafe { buffer.slice_mut_unchecked(0..MAX_HEADER_SIZE) })
.await;
buffer = buf.into_inner();
pos += res.map_err(AnyError::from)?;
// match version magic header
let parsed = if let Some(target_header) = match buffer[0] {
b'P' => {
let end = pos.min(V1HEADER.len());
if buffer[1..end] == V1HEADER[1..end] {
Some(&V1HEADER[..])
} else {
tracing::warn!("proxy-protocol: v1 magic only partly matched");
None
}
// read at-least 1 byte
let (res, buf) = stream
.read(unsafe { buffer.slice_mut_unchecked(0..MAX_HEADER_SIZE) })
.await;
buffer = buf.into_inner();
pos += res.map_err(AnyError::from)?;
// match version magic header
let parsed = if let Some(target_header) = match buffer[0] {
b'P' => {
let end = pos.min(V1HEADER.len());
if buffer[1..end] == V1HEADER[1..end] {
Some(&V1HEADER[..])
} else {
tracing::warn!("proxy-protocol: v1 magic only partly matched");
None
}
0x0D => {
let end = pos.min(V2HEADER.len());
if buffer[1..end] == V2HEADER[1..end] {
Some(&V2HEADER[..])
} else {
tracing::warn!("proxy-protocol: v2 magic only partly matched");
None
}
}
0x0D => {
let end = pos.min(V2HEADER.len());
if buffer[1..end] == V2HEADER[1..end] {
Some(&V2HEADER[..])
} else {
tracing::warn!("proxy-protocol: v2 magic only partly matched");
None
}
_ => None,
} {
// loop {parse; read; check_full;}
let header = loop {
let mut cursor = std::io::Cursor::new(&buffer);
let e = match parse(&mut cursor) {
Ok(header) => break Ok((header, cursor.position())),
// data is not enough to parse version, we should read again
Err(
e @ ParseError::NotProxyHeader
| e @ ParseError::Version1 {
source: version1::ParseError::UnexpectedEof,
}
| e @ ParseError::Version2 {
source: version2::ParseError::UnexpectedEof,
},
) => e,
Err(e) => break Err(e),
};

let buf = unsafe { buffer.slice_mut_unchecked(pos..MAX_HEADER_SIZE) };
let (res, buf) = stream.read(buf).await;
buffer = buf.into_inner();
let read = res.map_err(AnyError::from)?;
// if we are reading magic header, we have to check if the magic header matches
// because ParseError::NotProxyHeader does not always mean data is not enough
if pos < target_header.len() {
let end = target_header.len().min(pos + read);
if buffer[pos..end] != target_header[pos..end] {
break Err(e);
}
_ => None,
} {
// loop {parse; read; check_full;}
let header = loop {
let mut cursor = std::io::Cursor::new(&buffer);
let e = match parse(&mut cursor) {
Ok(header) => break Ok((header, cursor.position())),
// data is not enough to parse version, we should read again
Err(
e @ ParseError::NotProxyHeader
| e @ ParseError::Version1 {
source: version1::ParseError::UnexpectedEof,
}
}
pos += read;
if pos == MAX_HEADER_SIZE {
return Err(ParseError::NotProxyHeader.into());
}
| e @ ParseError::Version2 {
source: version2::ParseError::UnexpectedEof,
},
) => e,
Err(e) => break Err(e),
};
Some(header)
} else {
tracing::debug!("proxy-protocol: not proxy protocol at first glance");
None
};

let mut cursor = std::io::Cursor::new(buffer);
let remote_addr = match parsed {
Some(Ok((header, idx))) => {
// advance proxy-protocol length on success parsing
cursor.set_position(idx);
match header {
ProxyHeader::Version1 {
addresses: version1::ProxyAddresses::Ipv4 { source, .. },
}
| ProxyHeader::Version2 {
addresses: version2::ProxyAddresses::Ipv4 { source, .. },
..
} => Some(RemoteAddr(AcceptedAddr::from(SocketAddr::from(source)))),
ProxyHeader::Version1 {
addresses: version1::ProxyAddresses::Ipv6 { source, .. },
}
| ProxyHeader::Version2 {
addresses: version2::ProxyAddresses::Ipv6 { source, .. },
..
} => Some(RemoteAddr(AcceptedAddr::from(SocketAddr::from(source)))),
_ => {
tracing::warn!("proxy protocol get source failed");
None
}
let buf = unsafe { buffer.slice_mut_unchecked(pos..MAX_HEADER_SIZE) };
let (res, buf) = stream.read(buf).await;
buffer = buf.into_inner();
let read = res.map_err(AnyError::from)?;
// if we are reading magic header, we have to check if the magic header matches
// because ParseError::NotProxyHeader does not always mean data is not enough
if pos < target_header.len() {
let end = target_header.len().min(pos + read);
if buffer[pos..end] != target_header[pos..end] {
break Err(e);
}
}
_ => None,
pos += read;
if pos == MAX_HEADER_SIZE {
return Err(ParseError::NotProxyHeader.into());
}
};
Some(header)
} else {
tracing::debug!("proxy-protocol: not proxy protocol at first glance");
None
};

let mut cursor = std::io::Cursor::new(buffer);
let remote_addr = match parsed {
Some(Ok((header, idx))) => {
// advance proxy-protocol length on success parsing
cursor.set_position(idx);
match header {
ProxyHeader::Version1 {
addresses: version1::ProxyAddresses::Ipv4 { source, .. },
}
| ProxyHeader::Version2 {
addresses: version2::ProxyAddresses::Ipv4 { source, .. },
..
} => Some(RemoteAddr(AcceptedAddr::from(SocketAddr::from(source)))),
ProxyHeader::Version1 {
addresses: version1::ProxyAddresses::Ipv6 { source, .. },
}
| ProxyHeader::Version2 {
addresses: version2::ProxyAddresses::Ipv6 { source, .. },
..
} => Some(RemoteAddr(AcceptedAddr::from(SocketAddr::from(source)))),
_ => {
tracing::warn!("proxy protocol get source failed");
None
}
}
}
_ => None,
};

let ctx = ctx.param_set(remote_addr);
let prefix_io = PrefixedReadIo::new(stream, cursor);
let ctx = ctx.param_set(remote_addr);
let prefix_io = PrefixedReadIo::new(stream, cursor);

self.inner
.call((prefix_io, ctx))
.await
.map_err(|e| e.into())
}
self.inner
.call((prefix_io, ctx))
.await
.map_err(|e| e.into())
}
}

Expand Down

0 comments on commit b066967

Please sign in to comment.