diff --git a/capnp-rpc/src/local.rs b/capnp-rpc/src/local.rs index 50ac0d128..bd8dbe8b3 100644 --- a/capnp-rpc/src/local.rs +++ b/capnp-rpc/src/local.rs @@ -85,14 +85,19 @@ struct Results { message: Option>, cap_table: Vec>>, results_done_fulfiller: Option>>, + pipeline_sender: Option, } impl Results { - fn new(fulfiller: oneshot::Sender>) -> Self { + fn new( + fulfiller: oneshot::Sender>, + pipeline_sender: crate::queued::PipelineInnerSender, + ) -> Self { Self { message: Some(::capnp::message::Builder::new_default()), cap_table: Vec::new(), results_done_fulfiller: Some(fulfiller), + pipeline_sender: Some(pipeline_sender), } } } @@ -126,6 +131,25 @@ impl ResultsHook for Results { } } + fn set_pipeline(&mut self) -> capnp::Result<()> { + use ::capnp::traits::ImbueMut; + let root = self.get()?; + let size = root.target_size()?; + let mut message2 = capnp::message::Builder::new( + capnp::message::HeapAllocator::new().first_segment_words(size.word_count as u32 + 1), + ); + let mut root2: capnp::any_pointer::Builder = message2.init_root(); + let mut cap_table2 = vec![]; + root2.imbue_mut(&mut cap_table2); + root2.set_as(root.into_reader())?; + let hook = Box::new(ResultsDone::new(message2, cap_table2)) as Box; + let Some(sender) = self.pipeline_sender.take() else { + return Err(Error::failed("set_pipeline() called twice".into())); + }; + sender.complete(Box::new(Pipeline::new(hook))); + Ok(()) + } + fn tail_call(self: Box, _request: Box) -> Promise<(), Error> { unimplemented!() } @@ -147,12 +171,12 @@ struct ResultsDoneInner { cap_table: Vec>>, } -struct ResultsDone { +pub(crate) struct ResultsDone { inner: Rc, } impl ResultsDone { - fn new( + pub(crate) fn new( message: message::Builder, cap_table: Vec>>, ) -> Self { @@ -181,6 +205,8 @@ pub struct Request { interface_id: u64, method_id: u16, client: Box, + pipeline: crate::queued::Pipeline, + pipeline_sender: crate::queued::PipelineInnerSender, } impl Request { @@ -190,12 +216,15 @@ impl Request { _size_hint: Option<::capnp::MessageSize>, client: Box, ) -> Self { + let (pipeline_sender, pipeline) = crate::queued::Pipeline::new(); Self { message: message::Builder::new_default(), cap_table: Vec::new(), interface_id, method_id, client, + pipeline, + pipeline_sender, } } } @@ -217,17 +246,17 @@ impl RequestHook for Request { interface_id, method_id, client, + mut pipeline, + pipeline_sender, } = tmp; let params = Params::new(message, cap_table); let (results_done_fulfiller, results_done_promise) = oneshot::channel::>(); let results_done_promise = results_done_promise.map_err(crate::canceled_to_error); - let results = Results::new(results_done_fulfiller); + let results = Results::new(results_done_fulfiller, pipeline_sender.weak_clone()); let promise = client.call(interface_id, method_id, Box::new(params), Box::new(results)); - let (pipeline_sender, mut pipeline) = crate::queued::Pipeline::new(); - let p = futures::future::try_join(promise, results_done_promise).and_then( move |((), results_done_hook)| { pipeline_sender diff --git a/capnp-rpc/src/queued.rs b/capnp-rpc/src/queued.rs index 40322f352..701c0b3bc 100644 --- a/capnp-rpc/src/queued.rs +++ b/capnp-rpc/src/queued.rs @@ -44,7 +44,11 @@ pub struct PipelineInner { impl PipelineInner { fn resolve(this: &Rc>, result: Result, Error>) { - assert!(this.borrow().redirect.is_none()); + if this.borrow().redirect.is_some() { + // Already resolved, probably by set_pipeline(). + return; + } + let pipeline = match result { Ok(pipeline_hook) => pipeline_hook, Err(e) => Box::new(broken::Pipeline::new(e)), @@ -66,18 +70,30 @@ impl PipelineInner { pub struct PipelineInnerSender { inner: Option>>, + resolve_on_drop: bool, +} + +impl PipelineInnerSender { + pub(crate) fn weak_clone(&self) -> Self { + Self { + inner: self.inner.clone(), + resolve_on_drop: false, + } + } } impl Drop for PipelineInnerSender { fn drop(&mut self) { - if let Some(weak_queued) = self.inner.take() { - if let Some(pipeline_inner) = weak_queued.upgrade() { - PipelineInner::resolve( - &pipeline_inner, - Ok(Box::new(crate::broken::Pipeline::new(Error::failed( - "PipelineInnerSender was canceled".into(), - )))), - ); + if self.resolve_on_drop { + if let Some(weak_queued) = self.inner.take() { + if let Some(pipeline_inner) = weak_queued.upgrade() { + PipelineInner::resolve( + &pipeline_inner, + Ok(Box::new(crate::broken::Pipeline::new(Error::failed( + "PipelineInnerSender was canceled".into(), + )))), + ); + } } } } @@ -108,6 +124,7 @@ impl Pipeline { ( PipelineInnerSender { inner: Some(Rc::downgrade(&inner)), + resolve_on_drop: true, }, Self { inner }, ) @@ -271,9 +288,22 @@ impl ClientHook for Client { .attach(inner_clone) .and_then(|x| x); + // We need to drive `promise_to_drive` until we have a result. match self.inner.borrow().promise_to_drive { Some(ref p) => { - Promise::from_future(futures::future::try_join(p.clone(), promise).map_ok(|v| v.1)) + let p1 = p.clone(); + Promise::from_future(async move { + match futures::future::select(p1, promise).await { + futures::future::Either::Left((Ok(()), promise)) => promise.await, + futures::future::Either::Left((Err(e), _)) => Err(e), + futures::future::Either::Right((r, _)) => { + // Don't bother waiting for `promise_to_drive` to resolve. + // If we're here because set_pipeline() was called, then + // `promise_to_drive` might in fact never resolve. + r + } + } + }) } None => Promise::from_future(promise), } diff --git a/capnp-rpc/src/rpc.rs b/capnp-rpc/src/rpc.rs index 86d39af67..6d8081076 100644 --- a/capnp-rpc/src/rpc.rs +++ b/capnp-rpc/src/rpc.rs @@ -938,12 +938,15 @@ impl ConnectionState { let (results_inner_fulfiller, results_inner_promise) = oneshot::channel(); let results_inner_promise = results_inner_promise.map_err(crate::canceled_to_error); + + let (pipeline_sender, mut pipeline) = queued::Pipeline::new(); let results = Results::new( &connection_state, question_id, redirect_results, results_inner_fulfiller, answer.received_finish.clone(), + Some(pipeline_sender.weak_clone()), ); let (redirected_results_done_promise, redirected_results_done_fulfiller) = @@ -965,7 +968,6 @@ impl ConnectionState { let call_promise = capability.call(interface_id, method_id, Box::new(params), Box::new(results)); - let (pipeline_sender, mut pipeline) = queued::Pipeline::new(); let promise = call_promise .then(move |call_result| { @@ -2141,6 +2143,7 @@ where redirect_results: bool, answer_id: AnswerId, finish_received: Rc>, + pipeline_sender: Option, } impl ResultsInner @@ -2195,6 +2198,7 @@ where redirect_results: bool, fulfiller: oneshot::Sender>, finish_received: Rc>, + pipeline_sender: Option, ) -> Self { Self { inner: Some(ResultsInner { @@ -2203,6 +2207,7 @@ where redirect_results, answer_id, finish_received, + pipeline_sender, }), results_done_fulfiller: Some(fulfiller), } @@ -2250,6 +2255,29 @@ impl ResultsHook for Results { } } + fn set_pipeline(&mut self) -> ::capnp::Result<()> { + use ::capnp::traits::ImbueMut; + let root = self.get()?; + let size = root.target_size()?; + let mut message2 = capnp::message::Builder::new( + capnp::message::HeapAllocator::new().first_segment_words(size.word_count as u32 + 1), + ); + let mut root2: capnp::any_pointer::Builder = message2.init_root(); + let mut cap_table2 = vec![]; + root2.imbue_mut(&mut cap_table2); + root2.set_as(root.into_reader())?; + let hook = + Box::new(local::ResultsDone::new(message2, cap_table2)) as Box; + let Some(ref mut inner) = self.inner else { + unreachable!(); + }; + let Some(sender) = inner.pipeline_sender.take() else { + return Err(Error::failed("set_pipeline() called twice".into())); + }; + sender.complete(Box::new(local::Pipeline::new(hook))); + Ok(()) + } + fn tail_call(self: Box, _request: Box) -> Promise<(), Error> { unimplemented!() } diff --git a/capnp-rpc/test/impls.rs b/capnp-rpc/test/impls.rs index 59fa4c5e8..16046d9ad 100644 --- a/capnp-rpc/test/impls.rs +++ b/capnp-rpc/test/impls.rs @@ -276,6 +276,19 @@ impl test_pipeline::Server for TestPipeline { ) -> Promise<(), Error> { Promise::ok(()) } + + fn get_cap_pipeline_only( + &mut self, + _params: test_pipeline::GetCapPipelineOnlyParams, + mut results: test_pipeline::GetCapPipelineOnlyResults, + ) -> Promise<(), Error> { + results + .get() + .init_out_box() + .set_cap(capnp_rpc::new_client::(TestExtends).cast_to()); + pry!(results.set_pipeline()); + Promise::from_future(::futures::future::pending()) + } } #[derive(Default)] diff --git a/capnp-rpc/test/test.capnp b/capnp-rpc/test/test.capnp index cbd34a932..931ad3b5c 100644 --- a/capnp-rpc/test/test.capnp +++ b/capnp-rpc/test/test.capnp @@ -100,6 +100,9 @@ interface TestPipeline { getNullCap @1 () -> (cap :TestInterface); testPointers @2 (cap :TestInterface, obj :AnyPointer, list :List(TestInterface)) -> (); + getCapPipelineOnly @3 () -> (outBox :Box); + # Never returns, but uses setPipeline() to make the pipeline work. + struct Box { cap @0 :TestInterface; diff --git a/capnp-rpc/test/test.rs b/capnp-rpc/test/test.rs index 58bf4a581..e9de9db61 100644 --- a/capnp-rpc/test/test.rs +++ b/capnp-rpc/test/test.rs @@ -235,6 +235,9 @@ fn disconnector_disconnects_2() { }); } +/// Sets up a test_capnp::bootstrap capability on the remote side of a +/// two-party RPC connection and provides a local reference to it, along +/// with a handle that supports spawning of tasks. fn rpc_top_level(main: F) where F: FnOnce(futures::executor::LocalSpawner, test_capnp::bootstrap::Client) -> G, @@ -281,6 +284,30 @@ where join_handle.join().unwrap(); } +/// Like rpc_top_level(), but sets up the bootstrap::Client as a local capability, +/// i.e. not on the other side of an RPC connection. +fn local_top_level(main: F) +where + F: FnOnce(futures::executor::LocalSpawner, test_capnp::bootstrap::Client) -> G, + G: Future> + 'static, +{ + let mut pool = futures::executor::LocalPool::new(); + let spawner = pool.spawner(); + let client: test_capnp::bootstrap::Client = capnp_rpc::new_client(impls::Bootstrap); + pool.run_until(main(spawner, client)).unwrap(); +} + +/// Runs both rpc_top_level() and local_top_level() on the function `main`. +fn rpc_and_local_top_level(main: F) +where + F: FnOnce(futures::executor::LocalSpawner, test_capnp::bootstrap::Client) -> G, + F: Send + 'static + Clone, + G: Future> + 'static, +{ + rpc_top_level(main.clone()); + local_top_level(main); +} + #[test] fn do_nothing() { rpc_top_level(|_spawner, _client| async { Ok(()) }); @@ -325,7 +352,7 @@ fn basic_rpc_calls() { #[test] fn basic_pipelining() { - rpc_top_level(|_spawner, client| async move { + rpc_and_local_top_level(|_spawner, client| async move { let response = client.test_pipeline_request().send().promise.await?; let client = response.get()?.get_cap()?; @@ -412,6 +439,78 @@ fn null_capability() { assert!(root.get_interface_field().is_err()); } +struct WaitNTicks { + remaining: u32, +} + +impl WaitNTicks { + fn new(n: u32) -> Self { + Self { remaining: n } + } +} + +impl Future for WaitNTicks { + type Output = (); + + fn poll( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll { + if self.remaining > 0 { + self.remaining -= 1; + cx.waker().wake_by_ref(); // Wake up the task again + std::task::Poll::Pending + } else { + std::task::Poll::Ready(()) + } + } +} + +#[test] +fn set_pipeline() { + use std::cell::Cell; + use std::rc::Rc; + rpc_and_local_top_level(|mut spawner, client| async move { + let response = client.test_pipeline_request().send().promise.await?; + let client = response.get()?.get_cap()?; + let capnp::capability::RemotePromise { promise, pipeline } = + client.get_cap_pipeline_only_request().send(); + let promise_completed = Rc::new(Cell::new(false)); + let promise_completed2 = promise_completed.clone(); + spawn( + &mut spawner, + promise.map(move |_| { + promise_completed2.set(true); + Ok(()) + }), + ); + + let mut pipeline_request = pipeline.get_out_box().get_cap().foo_request(); + pipeline_request.get().set_i(321); + let pipeline_promise = pipeline_request.send().promise; + + let pipeline_request2 = pipeline + .get_out_box() + .get_cap() + .cast_to::() + .grault_request(); + let pipeline_promise2 = pipeline_request2.send().promise; + + let response = pipeline_promise.await?; + assert_eq!(response.get()?.get_x()?, "bar"); + + let response2 = pipeline_promise2.await?; + crate::test_util::CheckTestMessage::check_test_message(response2.get()?); + + // Give the original promise an opportunity to complete. + WaitNTicks::new(5).await; + + // The original promise never completed. + assert!(!promise_completed.get()); + Ok(()) + }); +} + #[test] fn release_simple() { rpc_top_level(|_spawner, client| async move { diff --git a/capnp/src/capability.rs b/capnp/src/capability.rs index b692b76d4..d87dc248b 100644 --- a/capnp/src/capability.rs +++ b/capnp/src/capability.rs @@ -263,6 +263,13 @@ where pub fn set(&mut self, other: T::Reader<'_>) -> crate::Result<()> { self.hook.get().unwrap().set_as(other) } + + /// Call this method to signal that all of the capabilities have been filled in for this + /// `Results` and that pipelined calls should be allowed to start using those capabilities. + /// (Usually pipelined calls are enqueued until the initial call completes.) + pub fn set_pipeline(&mut self) -> crate::Result<()> { + self.hook.set_pipeline() + } } pub trait FromTypelessPipeline { diff --git a/capnp/src/private/capability.rs b/capnp/src/private/capability.rs index f5f393e7e..996f5678f 100644 --- a/capnp/src/private/capability.rs +++ b/capnp/src/private/capability.rs @@ -92,6 +92,12 @@ impl Clone for alloc::boxed::Box { pub trait ResultsHook { fn get(&mut self) -> crate::Result>; + + // TODO(version bump): remove this default impl. + fn set_pipeline(&mut self) -> crate::Result<()> { + unimplemented!() + } + fn allow_cancellation(&self); fn tail_call( self: alloc::boxed::Box,