Skip to content

Commit

Permalink
improve: use preallocated context to avoid stack copying (#119)
Browse files Browse the repository at this point in the history
Signed-off-by: ihciah <[email protected]>
  • Loading branch information
ihciah authored Sep 24, 2024
1 parent b4b068a commit 1d26079
Show file tree
Hide file tree
Showing 9 changed files with 130 additions and 77 deletions.
9 changes: 5 additions & 4 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ monoio-native-tls = "0.3.0"
monoio-rustls = "0.3.0"
native-tls = "0.2"
service-async = "0.2.3"
certain-map = "0.2.4"
certain-map = "0.3.1"
local-sync = "0.1"
http = "1.0"
anyhow = "1"
Expand Down
1 change: 1 addition & 0 deletions monolake-services/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ anyhow = { workspace = true }
thiserror = { workspace = true }
serde = { workspace = true }
tracing = { workspace = true }
certain-map = { workspace = true }

# tls
monoio-rustls = { workspace = true, optional = true }
Expand Down
71 changes: 50 additions & 21 deletions monolake-services/src/common/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,52 +38,81 @@
//!
//! In this example, `ContextService` is used to transform an `EmptyContext` into a `FullContext`
//! by setting the `peer_addr` field.
use std::marker::PhantomData;

use certain_map::Handler;
use monolake_core::{context::PeerAddr, listener::AcceptedAddr};
use service_async::{
layer::{layer_fn, FactoryLayer},
AsyncMakeService, MakeService, ParamSet, Service,
};

/// A service to insert Context into the request processing pipeline, compatible with `certain_map`.
#[derive(Debug, Clone, Copy)]
pub struct ContextService<CX, T> {
#[derive(Debug)]
pub struct ContextService<CXStore, T> {
pub inner: T,
pub ctx: CX,
pub ctx: PhantomData<CXStore>,
}

unsafe impl<CXStore, T: Send> Send for ContextService<CXStore, T> {}
unsafe impl<CXStore, T: Sync> Sync for ContextService<CXStore, T> {}

// Manually impl Clone because CXStore does not have to impl Clone.
impl<CXStore, T> Clone for ContextService<CXStore, T>
where
T: Clone,
{
fn clone(&self) -> Self {
Self {
inner: self.inner.clone(),
ctx: PhantomData,
}
}
}

impl<R, T, CX> Service<(R, AcceptedAddr)> for ContextService<CX, T>
// Manually impl Copy because CXStore does not have to impl Copy.
impl<CXStore, T> Copy for ContextService<CXStore, T> where T: Copy {}

impl<R, T, CXStore, Resp, Err> Service<(R, AcceptedAddr)> for ContextService<CXStore, T>
where
T: Service<(R, CX::Transformed)>,
CX: ParamSet<PeerAddr> + Clone,
CXStore: Default + Handler,
// HRTB is your friend!
// Please pay attention to when to use bound associated types and when to use associated types
// directly(here `Transformed` is not bound but `Response` and `Error` are).
for<'a> CXStore::Hdr<'a>: ParamSet<PeerAddr>,
for<'a> T: Service<
(R, <CXStore::Hdr<'a> as ParamSet<PeerAddr>>::Transformed),
Response = Resp,
Error = Err,
>,
{
type Response = T::Response;
type Error = T::Error;
type Response = Resp;
type Error = Err;

async fn call(&self, (req, addr): (R, AcceptedAddr)) -> Result<Self::Response, Self::Error> {
let ctx = self.ctx.clone().param_set(PeerAddr(addr));
self.inner.call((req, ctx)).await
let mut store = CXStore::default();
let hdr = store.handler();
let hdr = hdr.param_set(PeerAddr(addr));
self.inner.call((req, hdr)).await
}
}

impl<CX, F> ContextService<CX, F> {
pub fn layer<C>() -> impl FactoryLayer<C, F, Factory = Self>
where
CX: Default,
{
pub fn layer<C>() -> impl FactoryLayer<C, F, Factory = Self> {
layer_fn(|_: &C, inner| ContextService {
inner,
ctx: Default::default(),
ctx: PhantomData,
})
}
}

impl<CX: Clone, F: MakeService> MakeService for ContextService<CX, F> {
type Service = ContextService<CX, F::Service>;
impl<CXStore, F: MakeService> MakeService for ContextService<CXStore, F> {
type Service = ContextService<CXStore, F::Service>;
type Error = F::Error;

fn make_via_ref(&self, old: Option<&Self::Service>) -> Result<Self::Service, Self::Error> {
Ok(ContextService {
ctx: self.ctx.clone(),
ctx: PhantomData,
inner: self
.inner
.make_via_ref(old.map(|o| &o.inner))
Expand All @@ -92,16 +121,16 @@ impl<CX: Clone, F: MakeService> MakeService for ContextService<CX, F> {
}
}

impl<CX: Clone, F: AsyncMakeService> AsyncMakeService for ContextService<CX, F> {
type Service = ContextService<CX, F::Service>;
impl<CXStore, F: AsyncMakeService> AsyncMakeService for ContextService<CXStore, F> {
type Service = ContextService<CXStore, F::Service>;
type Error = F::Error;

async fn make_via_ref(
&self,
old: Option<&Self::Service>,
) -> Result<Self::Service, Self::Error> {
Ok(ContextService {
ctx: self.ctx.clone(),
ctx: PhantomData,
inner: self
.inner
.make_via_ref(old.map(|o| &o.inner))
Expand Down
63 changes: 44 additions & 19 deletions monolake-services/src/http/core.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@
use std::{convert::Infallible, fmt::Debug, pin::Pin, time::Duration};

use bytes::Bytes;
use certain_map::{Attach, Fork};
use futures::{stream::FuturesUnordered, StreamExt};
use http::StatusCode;
use monoio::io::{sink::SinkExt, stream::Stream, AsyncReadRent, AsyncWriteRent, Split, Splitable};
Expand Down Expand Up @@ -114,12 +115,19 @@ impl<H> HttpCoreService<H> {
}
}

async fn h1_svc<S, CX>(&self, stream: S, ctx: CX)
async fn h1_svc<S, CXIn, CXStore, CXState, Err>(&self, stream: S, ctx: CXIn)
where
CXIn: ParamRef<PeerAddr> + Fork<Store = CXStore, State = CXState>,
CXStore: 'static,
for<'a> CXState: Attach<CXStore>,
for<'a> H: HttpHandler<
<CXState as Attach<CXStore>>::Hdr<'a>,
HttpBody,
Body = HttpBody,
Error = Err,
>,
Err: Into<AnyError> + Debug,
S: Split + AsyncReadRent + AsyncWriteRent,
H: HttpHandler<CX, HttpBody, Body = HttpBody>,
H::Error: Into<AnyError> + Debug,
CX: ParamRef<PeerAddr> + Clone,
{
let (reader, writer) = stream.into_split();
let mut decoder = RequestDecoder::new(reader);
Expand Down Expand Up @@ -161,10 +169,14 @@ impl<H> HttpCoreService<H> {
}
};

// fork ctx
let (mut store, state) = ctx.fork();
let forked_ctx = unsafe { state.attach(&mut store) };

// handle request and reply response
// 1. do these things simultaneously: read body and send + handle request
let mut acc_fut = AccompanyPair::new(
self.handler_chain.handle(req, ctx.clone()),
self.handler_chain.handle(req, forked_ctx),
decoder.fill_payload(),
);
let res = unsafe { Pin::new_unchecked(&mut acc_fut) }.await;
Expand Down Expand Up @@ -272,12 +284,19 @@ impl<H> HttpCoreService<H> {
}
}

async fn h2_svc<S, CX>(&self, stream: S, ctx: CX)
async fn h2_svc<S, CXIn, CXStore, CXState, Err>(&self, stream: S, ctx: CXIn)
where
CXIn: ParamRef<PeerAddr> + Fork<Store = CXStore, State = CXState>,
CXStore: 'static,
for<'a> CXState: Attach<CXStore>,
for<'a> H: HttpHandler<
<CXState as Attach<CXStore>>::Hdr<'a>,
HttpBody,
Body = HttpBody,
Error = Err,
>,
Err: Into<AnyError> + Debug,
S: Split + AsyncReadRent + AsyncWriteRent + Unpin + 'static,
H: HttpHandler<CX, HttpBody, Body = HttpBody>,
H::Error: Into<AnyError> + Debug,
CX: ParamRef<PeerAddr> + Clone,
{
let mut connection = match monoio_http::h2::server::Builder::new()
.initial_window_size(1_000_000)
Expand Down Expand Up @@ -316,13 +335,15 @@ impl<H> HttpCoreService<H> {
});

loop {
let ctx = ctx.clone();
monoio::select! {
Some(Ok((request, response_handle))) = rx.recv() => {
let request = HttpBody::request(request);
backend_resp_stream.push( async move {
(self.handler_chain.handle(request, ctx).await, response_handle)
});
let request = HttpBody::request(request);
// fork ctx
let (mut store, state) = ctx.fork();
backend_resp_stream.push(async move {
let forked_ctx = unsafe { state.attach(&mut store) };
(self.handler_chain.handle(request, forked_ctx).await, response_handle)
});
}
Some(result) = backend_resp_stream.next() => {
match result {
Expand Down Expand Up @@ -354,19 +375,23 @@ impl<H> HttpCoreService<H> {
}
}

impl<H, Stream, CX> Service<HttpAccept<Stream, CX>> for HttpCoreService<H>
impl<H, Stream, CXIn, CXStore, CXState, Err> Service<HttpAccept<Stream, CXIn>>
for HttpCoreService<H>
where
CXIn: ParamRef<PeerAddr> + Fork<Store = CXStore, State = CXState>,
CXStore: 'static,
for<'a> CXState: Attach<CXStore>,
for<'a> H:
HttpHandler<<CXState as Attach<CXStore>>::Hdr<'a>, HttpBody, Body = HttpBody, Error = Err>,
Stream: Split + AsyncReadRent + AsyncWriteRent + Unpin + 'static,
H: HttpHandler<CX, HttpBody, Body = HttpBody>,
H::Error: Into<AnyError> + Debug,
CX: ParamRef<PeerAddr> + Clone,
Err: Into<AnyError> + Debug,
{
type Response = ();
type Error = Infallible;

async fn call(
&self,
incoming_stream: HttpAccept<Stream, CX>,
incoming_stream: HttpAccept<Stream, CXIn>,
) -> Result<Self::Response, Self::Error> {
let (use_h2, stream, ctx) = incoming_stream;
if use_h2 {
Expand Down
2 changes: 1 addition & 1 deletion monolake-services/src/proxy_protocol/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ use crate::tcp::Accept;

// Ref: https://www.haproxy.org/download/1.8/doc/proxy-protocol.txt
// V1 max length is 107-byte.
const V1HEADER: &[u8; 6] = &[b'P', b'R', b'O', b'X', b'Y', b' '];
const V1HEADER: &[u8; 6] = b"PROXY ";
// V2 max length is 14+216 = 230 bytes.
const V2HEADER: &[u8; 12] = &[
0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49, 0x54, 0x0A,
Expand Down
43 changes: 20 additions & 23 deletions monolake-services/src/thrift/ttheader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
use std::{convert::Infallible, fmt::Debug, time::Duration};

use certain_map::{Attach, Fork};
use monoio::io::{sink::SinkExt, stream::Stream, AsyncReadRent, AsyncWriteRent};
use monoio_codec::Framed;
use monoio_thrift::codec::ttheader::{RawPayloadCodec, TTHeaderPayloadCodec};
Expand Down Expand Up @@ -84,14 +85,21 @@ impl<H> TtheaderCoreService<H> {
thrift_timeout,
}
}
}

async fn svc<S, CX>(&self, stream: S, ctx: CX)
where
S: AsyncReadRent + AsyncWriteRent,
H: ThriftHandler<CX>,
H::Error: Into<AnyError> + Debug,
CX: ParamRef<PeerAddr> + Clone,
{
impl<H, Stream, CXIn, CXStore, CXState, ERR> Service<(Stream, CXIn)> for TtheaderCoreService<H>
where
CXIn: ParamRef<PeerAddr> + Fork<Store = CXStore, State = CXState>,
CXStore: 'static,
for<'a> CXState: Attach<CXStore>,
for<'a> H: ThriftHandler<<CXState as Attach<CXStore>>::Hdr<'a>, Error = ERR>,
ERR: Into<AnyError> + Debug,
Stream: AsyncReadRent + AsyncWriteRent + Unpin + 'static,
{
type Response = ();
type Error = Infallible;

async fn call(&self, (stream, ctx): (Stream, CXIn)) -> Result<Self::Response, Self::Error> {
let mut codec = Framed::new(stream, TTHeaderPayloadCodec::new(RawPayloadCodec::new()));
loop {
if let Some(keepalive_timeout) = self.thrift_timeout.keepalive_timeout {
Expand Down Expand Up @@ -150,8 +158,12 @@ impl<H> TtheaderCoreService<H> {
}
};

// fork ctx
let (mut store, state) = ctx.fork();
let forked_ctx = unsafe { state.attach(&mut store) };

// handle request and reply response
match self.handler_chain.handle(req, ctx.clone()).await {
match self.handler_chain.handle(req, forked_ctx).await {
Ok(resp) => {
if let Err(e) = codec.send_and_flush(resp).await {
warn!("error when reply client: {e}");
Expand All @@ -173,21 +185,6 @@ impl<H> TtheaderCoreService<H> {
}
}
}
}
}

impl<H, Stream, CX> Service<(Stream, CX)> for TtheaderCoreService<H>
where
Stream: AsyncReadRent + AsyncWriteRent + Unpin + 'static,
H: ThriftHandler<CX>,
H::Error: Into<AnyError> + Debug,
CX: ParamRef<PeerAddr> + Clone,
{
type Response = ();
type Error = Infallible;

async fn call(&self, incoming_stream: (Stream, CX)) -> Result<Self::Response, Self::Error> {
self.svc(incoming_stream.0, incoming_stream.1).await;
Ok(())
}
}
Expand Down
10 changes: 5 additions & 5 deletions monolake/src/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,7 @@ use monolake_core::context::{PeerAddr, RemoteAddr};
// This struct should be a app-defined struct.
// Framework should not bind it.
certain_map::certain_map! {
#[derive(Debug, Clone)]
#[empty(EmptyContext)]
#[derive(Clone)]
#[full(FullContext)]
pub struct Context {
// Set by ContextService
Expand All @@ -26,11 +25,12 @@ mod test {

#[test]
pub fn test_add_entries_to_context() {
let ctx = Context::new();
let mut ctx = Context::new();
let handler = ctx.handler();
let addr: SocketAddr = "127.0.0.1:8080".parse().unwrap();
let peer_addr = PeerAddr::from(AcceptedAddr::from(addr));
let ctx = ctx.param_set(peer_addr);
match ParamRef::<PeerAddr>::param_ref(&ctx).0 {
let handler = handler.param_set(peer_addr);
match ParamRef::<PeerAddr>::param_ref(&handler).0 {
AcceptedAddr::Tcp(socket_addr) => assert_eq!(addr, socket_addr),
_ => unreachable!(),
}
Expand Down
Loading

0 comments on commit 1d26079

Please sign in to comment.