diff --git a/capnp-rpc/src/rpc.rs b/capnp-rpc/src/rpc.rs index 62b5c01c1..785d844c6 100644 --- a/capnp-rpc/src/rpc.rs +++ b/capnp-rpc/src/rpc.rs @@ -278,6 +278,11 @@ impl Answer { pub struct Export { refcount: u32, + + /// If true, this is the canonical export entry for this clientHook, that is, + /// `exports_by_cap[clientHook]` points to this entry. + canonical: bool, + client_hook: Box, // If this export is a promise (not a settled capability), the `resolve_op` represents the @@ -289,6 +294,7 @@ impl Export { fn new(client_hook: Box) -> Self { Self { refcount: 1, + canonical: false, client_hook, resolve_op: Promise::err(Error::failed("no resolve op".to_string())), } @@ -1184,8 +1190,10 @@ impl ConnectionState { e.refcount -= refcount; if e.refcount == 0 { let client_ptr = e.client_hook.get_ptr(); + if e.canonical { + self.exports_by_cap.borrow_mut().remove(&client_ptr); + } exports.erase(id); - self.exports_by_cap.borrow_mut().remove(&client_ptr); } Ok(()) } @@ -1286,7 +1294,8 @@ impl ConnectionState { promise: Promise, Error>, ) -> Promise<(), Error> { let weak_connection_state = Rc::downgrade(state); - state.eagerly_evaluate(promise.map(move |resolution_result| { + state.eagerly_evaluate(Promise::from_future(async move { + let resolution_result = promise.await; let connection_state = weak_connection_state .upgrade() .expect("dangling connection state?"); @@ -1300,29 +1309,68 @@ impl ConnectionState { // Update the export table to point at this object instead. We know that our // entry in the export table is still live because when it is destroyed the // asynchronous resolution task (i.e. this code) is canceled. - if let Some(exp) = connection_state.exports.borrow_mut().find(export_id) { + let mut exports = connection_state.exports.borrow_mut(); + let Some(exp) = exports.find(export_id) else { + return Err(Error::failed("export table entry not found".to_string())); + }; + + if exp.canonical { connection_state .exports_by_cap .borrow_mut() .remove(&exp.client_hook.get_ptr()); - exp.client_hook = resolution.clone(); - } else { - return Err(Error::failed("export table entry not found".to_string())); } + exp.client_hook = resolution.clone(); + + // The export now points to `resolution`, but it is not necessarily the + // canonical export for `resolution`. The export itself still represents + // the promise that ended up resolving to `resolution`, but `resolution` + // itself also needs to be exported under a separate export ID to + // distinguish from the promise. (Unless it's also a promise, see the next + // bit...) + exp.canonical = false; if brand != connection_state.get_brand() { // We're resolving to a local capability. If we're resolving to a promise, // we might be able to reuse our export table entry and avoid sending a // message. - if let Some(_promise) = resolution.when_more_resolved() { + if let Some(promise) = resolution.when_more_resolved() { // We're replacing a promise with another local promise. In this case, // we might actually be able to just reuse the existing export table // entry to represent the new promise -- unless it already has an entry. // Let's check. - unimplemented!() + let mut exports_by_cap = connection_state.exports_by_cap.borrow_mut(); + + let replacement_export_id = + match exports_by_cap.entry(exp.client_hook.get_ptr()) { + hash_map::Entry::Occupied(occ) => *occ.get(), + hash_map::Entry::Vacant(vac) => { + // The replacement capability isn't previously exported, + // so assign it to the existing table entry. + vac.insert(export_id); + export_id + } + }; + if replacement_export_id == export_id { + // The new promise was not already in the table, therefore the existing + // export table entry has now been repurposed to represent it. There is + // no need to send a resolve message at all. We do, however, have to + // start resolving the next promise. + exp.canonical = true; + drop(exports); + drop(exports_by_cap); + return Self::resolve_exported_promise( + &connection_state, + export_id, + promise, + ) + .await; + } } } + // Prevent a double borrow in write_descriptor() below. + drop(exports); // OK, we have to send a `Resolve` message. let mut message = connection_state.new_outgoing_message(15)?; @@ -1383,7 +1431,8 @@ impl ConnectionState { } else { // This is the first time we've seen this capability. - let exp = Export::new(inner.clone()); + let mut exp = Export::new(inner.clone()); + exp.canonical = true; let export_id = state.exports.borrow_mut().push(exp); state.exports_by_cap.borrow_mut().insert(ptr, export_id); match inner.when_more_resolved() { diff --git a/capnp-rpc/test/impls.rs b/capnp-rpc/test/impls.rs index e5758da28..88174a7d1 100644 --- a/capnp-rpc/test/impls.rs +++ b/capnp-rpc/test/impls.rs @@ -21,13 +21,14 @@ use crate::test_capnp::{ bootstrap, test_call_order, test_capability_server_set, test_extends, test_handle, - test_interface, test_more_stuff, test_pipeline, test_streaming, + test_interface, test_more_stuff, test_pipeline, test_promise_resolve, test_streaming, }; use capnp::capability::{FromClientHook, Promise}; use capnp::Error; use capnp_rpc::pry; +use futures::channel::oneshot; use futures::{FutureExt, TryFutureExt}; use std::cell::{Cell, RefCell}; @@ -113,6 +114,17 @@ impl bootstrap::Server for Bootstrap { .set_cap(capnp_rpc::new_client(TestCapabilityServerSet::new())); Promise::ok(()) } + + fn test_promise_resolve( + &mut self, + _params: bootstrap::TestPromiseResolveParams, + mut results: bootstrap::TestPromiseResolveResults, + ) -> Promise<(), Error> { + results + .get() + .set_cap(capnp_rpc::new_client(TestPromiseResolveImpl {})); + Promise::ok(()) + } } #[derive(Default)] @@ -575,12 +587,12 @@ impl Drop for Handle { impl test_handle::Server for Handle {} pub struct TestCapDestructor { - fulfiller: Option<::futures::channel::oneshot::Sender<()>>, + fulfiller: Option>, imp: TestInterface, } impl TestCapDestructor { - pub fn new(fulfiller: ::futures::channel::oneshot::Sender<()>) -> Self { + pub fn new(fulfiller: oneshot::Sender<()>) -> Self { Self { fulfiller: Some(fulfiller), imp: TestInterface::new(), @@ -718,3 +730,58 @@ impl test_capability_server_set::Server for TestCapabilityServerSet { }) } } + +pub struct ResolverImpl { + sender: Option>, +} + +impl test_promise_resolve::resolver::Server for ResolverImpl { + fn resolve_to_another_promise( + &mut self, + _params: test_promise_resolve::resolver::ResolveToAnotherPromiseParams, + _results: test_promise_resolve::resolver::ResolveToAnotherPromiseResults, + ) -> Promise<(), Error> { + let Some(sender) = self.sender.take() else { + return Promise::err(Error::failed("no sender".into())); + }; + let (snd, rcv) = oneshot::channel(); + let _ = sender.send(capnp_rpc::new_promise_client( + rcv.map_err(|_| Error::failed("oneshot was canceled".to_string())) + .map_ok(|x: test_interface::Client| x.client), + )); + self.sender = Some(snd); + Promise::ok(()) + } + + fn resolve_to_cap( + &mut self, + _params: test_promise_resolve::resolver::ResolveToCapParams, + _results: test_promise_resolve::resolver::ResolveToCapResults, + ) -> Promise<(), Error> { + let Some(sender) = self.sender.take() else { + return Promise::err(Error::failed("no sender".into())); + }; + let _ = sender.send(capnp_rpc::new_client(TestInterface::new())); + Promise::ok(()) + } +} + +pub struct TestPromiseResolveImpl {} + +impl test_promise_resolve::Server for TestPromiseResolveImpl { + fn foo( + &mut self, + _params: test_promise_resolve::FooParams, + mut results: test_promise_resolve::FooResults, + ) -> Promise<(), Error> { + let (snd, rcv) = oneshot::channel(); + let resolver = ResolverImpl { sender: Some(snd) }; + let mut results_root = results.get(); + results_root.set_cap(capnp_rpc::new_promise_client( + rcv.map_err(|_| Error::failed("oneshot was canceled".to_string())) + .map_ok(|x| x.client), + )); + results_root.set_resolver(capnp_rpc::new_client(resolver)); + Promise::ok(()) + } +} diff --git a/capnp-rpc/test/test.capnp b/capnp-rpc/test/test.capnp index 88da7d660..a61e4ebf7 100644 --- a/capnp-rpc/test/test.capnp +++ b/capnp-rpc/test/test.capnp @@ -79,6 +79,7 @@ interface Bootstrap { testCallOrder @4 () -> (cap: TestCallOrder); testMoreStuff @5 () -> (cap: TestMoreStuff); testCapabilityServerSet @6 () -> (cap: TestCapabilityServerSet); + testPromiseResolve @7 () -> (cap: TestPromiseResolve); } interface TestInterface { @@ -195,3 +196,14 @@ interface TestCapabilityServerSet { createHandle @0 () -> (handle :Handle); checkHandle @1 (handle: Handle) -> (isOurs :Bool); } + +interface TestPromiseResolve { + interface Resolver { + resolveToAnotherPromise @0 (); + resolveToCap @1 (); + } + + foo @0 () -> (cap: TestInterface, resolver: Resolver); + # Teturns a promise capability whose resolution can + # be triggered by a `resolver` capability. +} diff --git a/capnp-rpc/test/test.rs b/capnp-rpc/test/test.rs index b782658c3..4fb0d360f 100644 --- a/capnp-rpc/test/test.rs +++ b/capnp-rpc/test/test.rs @@ -1230,3 +1230,31 @@ fn stream_error_gets_reported() { Ok(()) }); } + +#[test] +fn promise_resolve_twice() { + rpc_top_level(|_spawner, client| async move { + let response1 = client.test_promise_resolve_request().send().promise.await?; + let client1 = response1.get()?.get_cap()?; + + let response = client1.foo_request().send().promise.await?; + let resolver = response.get()?.get_resolver()?; + + resolver + .resolve_to_another_promise_request() + .send() + .promise + .await?; + + resolver.resolve_to_cap_request().send().promise.await?; + + let cap = response.get()?.get_cap()?; + let mut request = cap.foo_request(); + request.get().set_i(123); + request.get().set_j(true); + let response2 = request.send().promise.await?; + let x = response2.get()?.get_x()?.to_str()?; + assert_eq!(x, "foo"); + Ok(()) + }); +}