diff --git a/Cargo.toml b/Cargo.toml index a2856900b..a1bade447 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -24,6 +24,7 @@ members = [ "capnp-rpc/examples/hello-world", "capnp-rpc/examples/calculator", "capnp-rpc/examples/pubsub", + "capnp-rpc/examples/streaming", "capnp-rpc/test", "example/addressbook", "example/addressbook_send", diff --git a/capnp-rpc/examples/streaming/Cargo.toml b/capnp-rpc/examples/streaming/Cargo.toml new file mode 100644 index 000000000..809fc2581 --- /dev/null +++ b/capnp-rpc/examples/streaming/Cargo.toml @@ -0,0 +1,24 @@ +[package] +name = "streaming" +version = "0.1.0" +edition = "2021" + +build = "build.rs" + +[[bin]] +name = "streaming" +path = "main.rs" + +[build-dependencies] +capnpc = { path = "../../../capnpc" } + +[dependencies] +capnp = { path = "../../../capnp" } +futures = "0.3.0" +rand = "0.8.5" +sha2 = { version = "0.10.8" } +tokio = { version = "1.0.0", features = ["net", "rt", "macros"]} +tokio-util = { version = "0.7.4", features = ["compat"] } + +[dependencies.capnp-rpc] +path = "../.." diff --git a/capnp-rpc/examples/streaming/build.rs b/capnp-rpc/examples/streaming/build.rs new file mode 100644 index 000000000..a2005e039 --- /dev/null +++ b/capnp-rpc/examples/streaming/build.rs @@ -0,0 +1,6 @@ +fn main() -> Result<(), Box> { + capnpc::CompilerCommand::new() + .file("streaming.capnp") + .run()?; + Ok(()) +} diff --git a/capnp-rpc/examples/streaming/client.rs b/capnp-rpc/examples/streaming/client.rs new file mode 100644 index 000000000..006461940 --- /dev/null +++ b/capnp-rpc/examples/streaming/client.rs @@ -0,0 +1,70 @@ +use crate::streaming_capnp::receiver; +use capnp_rpc::{rpc_twoparty_capnp, twoparty, RpcSystem}; + +use futures::AsyncReadExt; +use rand::Rng; +use sha2::{Digest, Sha256}; + +pub async fn main() -> Result<(), Box> { + use std::net::ToSocketAddrs; + let args: Vec = ::std::env::args().collect(); + if args.len() != 5 { + println!( + "usage: {} client HOST:PORT STREAM_SIZE WINDOW_SIZE", + args[0] + ); + return Ok(()); + } + + let stream_size: usize = str::parse(&args[3]).unwrap(); + let window_size: usize = str::parse(&args[4]).unwrap(); + + let addr = args[2] + .to_socket_addrs()? + .next() + .expect("could not parse address"); + + tokio::task::LocalSet::new() + .run_until(async move { + let stream = tokio::net::TcpStream::connect(&addr).await?; + stream.set_nodelay(true)?; + let (reader, writer) = + tokio_util::compat::TokioAsyncReadCompatExt::compat(stream).split(); + let mut rpc_network = Box::new(twoparty::VatNetwork::new( + futures::io::BufReader::new(reader), + futures::io::BufWriter::new(writer), + rpc_twoparty_capnp::Side::Client, + Default::default(), + )); + rpc_network.set_window_size(window_size); + let mut rpc_system = RpcSystem::new(rpc_network, None); + let receiver: receiver::Client = rpc_system.bootstrap(rpc_twoparty_capnp::Side::Server); + tokio::task::spawn_local(rpc_system); + + let capnp::capability::RemotePromise { promise, pipeline } = + receiver.write_stream_request().send(); + + let mut rng = rand::thread_rng(); + let mut hasher = Sha256::new(); + let bytestream = pipeline.get_stream(); + let mut bytes_written: u32 = 0; + const CHUNK_SIZE: u32 = 4096; + while bytes_written < stream_size as u32 { + let mut request = bytestream.write_request(); + let body = request.get(); + let buf = body.init_bytes(CHUNK_SIZE); + rng.fill(buf); + hasher.update(buf); + request.send().await?; + bytes_written += CHUNK_SIZE; + } + bytestream.end_request().send().promise.await?; + let response = promise.await?; + + let sha256 = response.get()?.get_sha256()?; + let local_sha256 = hasher.finalize(); + assert_eq!(sha256, &local_sha256[..]); + Ok(()) + }) + .await +} diff --git a/capnp-rpc/examples/streaming/main.rs b/capnp-rpc/examples/streaming/main.rs new file mode 100644 index 000000000..9907ccf0e --- /dev/null +++ b/capnp-rpc/examples/streaming/main.rs @@ -0,0 +1,21 @@ +pub mod streaming_capnp { + include!(concat!(env!("OUT_DIR"), "/streaming_capnp.rs")); +} + +pub mod client; +pub mod server; + +#[tokio::main(flavor = "current_thread")] +async fn main() -> Result<(), Box> { + let args: Vec = ::std::env::args().collect(); + if args.len() >= 2 { + match &args[1][..] { + "client" => return client::main().await, + "server" => return server::main().await, + _ => (), + } + } + + println!("usage: {} [client | server] ADDRESS", args[0]); + Ok(()) +} diff --git a/capnp-rpc/examples/streaming/server.rs b/capnp-rpc/examples/streaming/server.rs new file mode 100644 index 000000000..2bc71b0fc --- /dev/null +++ b/capnp-rpc/examples/streaming/server.rs @@ -0,0 +1,112 @@ +use std::net::ToSocketAddrs; + +use crate::streaming_capnp::{byte_stream, receiver}; +use capnp_rpc::{pry, rpc_twoparty_capnp, twoparty, RpcSystem}; + +use capnp::capability::Promise; +use capnp::Error; + +use futures::channel::oneshot; +use futures::AsyncReadExt; +use sha2::{Digest, Sha256}; + +struct ByteStreamImpl { + hasher: Sha256, + hash_sender: Option>>, +} + +impl ByteStreamImpl { + fn new(hash_sender: oneshot::Sender>) -> Self { + Self { + hasher: Sha256::new(), + hash_sender: Some(hash_sender), + } + } +} + +impl byte_stream::Server for ByteStreamImpl { + fn write(&mut self, params: byte_stream::WriteParams) -> Promise<(), Error> { + let bytes = pry!(pry!(params.get()).get_bytes()); + self.hasher.update(bytes); + Promise::ok(()) + } + + fn end( + &mut self, + _params: byte_stream::EndParams, + _results: byte_stream::EndResults, + ) -> Promise<(), Error> { + let hasher = std::mem::take(&mut self.hasher); + if let Some(sender) = self.hash_sender.take() { + let _ = sender.send(hasher.finalize()[..].to_vec()); + } + Promise::ok(()) + } +} + +struct ReceiverImpl {} + +impl ReceiverImpl { + fn new() -> Self { + Self {} + } +} + +impl receiver::Server for ReceiverImpl { + fn write_stream( + &mut self, + _params: receiver::WriteStreamParams, + mut results: receiver::WriteStreamResults, + ) -> Promise<(), Error> { + let (snd, rcv) = oneshot::channel(); + let client: byte_stream::Client = capnp_rpc::new_client(ByteStreamImpl::new(snd)); + results.get().set_stream(client); + pry!(results.set_pipeline()); + Promise::from_future(async move { + match rcv.await { + Ok(v) => { + results.get().set_sha256(&v[..]); + Ok(()) + } + Err(_) => Err(Error::failed("failed to get hash".into())), + } + }) + } +} + +pub async fn main() -> Result<(), Box> { + let args: Vec = ::std::env::args().collect(); + if args.len() != 3 { + println!("usage: {} server ADDRESS[:PORT]", args[0]); + return Ok(()); + } + + let addr = args[2] + .to_socket_addrs()? + .next() + .expect("could not parse address"); + + tokio::task::LocalSet::new() + .run_until(async move { + let listener = tokio::net::TcpListener::bind(&addr).await?; + let client: receiver::Client = capnp_rpc::new_client(ReceiverImpl::new()); + + loop { + let (stream, _) = listener.accept().await?; + stream.set_nodelay(true)?; + let (reader, writer) = + tokio_util::compat::TokioAsyncReadCompatExt::compat(stream).split(); + let network = twoparty::VatNetwork::new( + futures::io::BufReader::new(reader), + futures::io::BufWriter::new(writer), + rpc_twoparty_capnp::Side::Server, + Default::default(), + ); + + let rpc_system = RpcSystem::new(Box::new(network), Some(client.clone().client)); + + tokio::task::spawn_local(rpc_system); + } + }) + .await +} diff --git a/capnp-rpc/examples/streaming/streaming.capnp b/capnp-rpc/examples/streaming/streaming.capnp new file mode 100644 index 000000000..a49e1ead2 --- /dev/null +++ b/capnp-rpc/examples/streaming/streaming.capnp @@ -0,0 +1,16 @@ +@0x9fedc87e438cde81; + +interface ByteStream { + write @0 (bytes :Data) -> stream; + # Writes a chunk. + + end @1 (); + # Ends the stream. +} + +interface Receiver { + writeStream @0 () -> (stream :ByteStream, sha256 :Data); + # Uses set_pipeline() to set up `stream` immediately. + # Actually returns when `end()` is called on that stream. + # `sha256` is the SHA256 checksum of the received data. +} diff --git a/capnp-rpc/src/broken.rs b/capnp-rpc/src/broken.rs index 980ad4868..d7575d881 100644 --- a/capnp-rpc/src/broken.rs +++ b/capnp-rpc/src/broken.rs @@ -81,6 +81,9 @@ impl RequestHook for Request { pipeline: any_pointer::Pipeline::new(Box::new(pipeline)), } } + fn send_streaming(self: Box) -> Promise<(), Error> { + Promise::err(self.error) + } fn tail_send(self: Box) -> Option<(u32, Promise<(), Error>, Box)> { None } diff --git a/capnp-rpc/src/flow_control.rs b/capnp-rpc/src/flow_control.rs new file mode 100644 index 000000000..afbe3cbf7 --- /dev/null +++ b/capnp-rpc/src/flow_control.rs @@ -0,0 +1,162 @@ +use capnp::capability::Promise; +use capnp::Error; + +use futures::channel::oneshot; +use futures::TryFutureExt; +use std::cell::RefCell; +use std::rc::Rc; + +use crate::task_set::{TaskReaper, TaskSet, TaskSetHandle}; + +pub const DEFAULT_WINDOW_SIZE: usize = 65536; + +enum State { + Running(Vec>>), + Failed(Error), +} + +struct FixedWindowFlowControllerInner { + window_size: usize, + in_flight: usize, + max_message_size: usize, + state: State, + empty_fulfiller: Option>>, +} + +impl FixedWindowFlowControllerInner { + fn is_ready(&self) -> bool { + // We extend the window by maxMessageSize to avoid a pathological situation when a message + // is larger than the window size. Otherwise, after sending that message, we would end up + // not sending any others until the ack was received, wasting a round trip's worth of + // bandwidth. + + self.in_flight < self.window_size + self.max_message_size + } +} + +pub struct FixedWindowFlowController { + inner: Rc>, + tasks: TaskSetHandle, +} + +struct Reaper { + inner: Rc>, +} + +impl TaskReaper for Reaper { + fn task_failed(&mut self, error: Error) { + let mut inner = self.inner.borrow_mut(); + if let State::Running(ref mut blocked_sends) = &mut inner.state { + for s in std::mem::take(blocked_sends) { + let _ = s.send(Err(error.clone())); + } + inner.state = State::Failed(error) + } + } +} + +impl FixedWindowFlowController { + pub fn new(window_size: usize) -> (Self, Promise<(), Error>) { + let inner = FixedWindowFlowControllerInner { + window_size, + in_flight: 0, + max_message_size: 0, + state: State::Running(vec![]), + empty_fulfiller: None, + }; + let inner = Rc::new(RefCell::new(inner)); + let (tasks, task_future) = TaskSet::new(Box::new(Reaper { + inner: inner.clone(), + })); + (Self { inner, tasks }, Promise::from_future(task_future)) + } +} + +impl crate::FlowController for FixedWindowFlowController { + fn send( + &mut self, + message: Box, + ack: Promise<(), Error>, + ) -> Promise<(), Error> { + let size = message.size_in_words() * 8; + { + let mut inner = self.inner.borrow_mut(); + let prev_max_size = inner.max_message_size; + inner.max_message_size = usize::max(size, prev_max_size); + + // We are REQUIRED to send the message NOW to maintain correct ordering. + let _ = message.send(); + + inner.in_flight += size; + } + let inner = self.inner.clone(); + let mut tasks = self.tasks.clone(); + self.tasks.add(async move { + ack.await?; + let mut inner = inner.borrow_mut(); + inner.in_flight -= size; + let is_ready = inner.is_ready(); + match inner.state { + State::Running(ref mut blocked_sends) => { + if is_ready { + for s in std::mem::take(blocked_sends) { + let _ = s.send(Ok(())); + } + } + + if inner.in_flight == 0 { + if let Some(f) = inner.empty_fulfiller.take() { + let _ = f.send(Promise::from_future( + tasks.on_empty().map_err(crate::canceled_to_error), + )); + } + } + } + State::Failed(_) => { + // A previous call failed, but this one -- which was already in-flight at the + // time -- ended up succeeding. That may indicate that the server side is not + // properly handling streaming error propagation. Nothing much we can do about + // it here though. + } + } + Ok(()) + }); + + let mut inner = self.inner.borrow_mut(); + let is_ready = inner.is_ready(); + match inner.state { + State::Running(ref mut blocked_sends) => { + if is_ready { + Promise::ok(()) + } else { + let (snd, rcv) = oneshot::channel(); + blocked_sends.push(snd); + Promise::from_future(async { + match rcv.await { + Ok(r) => r, + Err(e) => Err(crate::canceled_to_error(e)), + } + }) + } + } + State::Failed(ref e) => Promise::err(e.clone()), + } + } + + fn wait_all_acked(&mut self) -> Promise<(), Error> { + let mut inner = self.inner.borrow_mut(); + if let State::Running(ref blocked_sends) = inner.state { + if !blocked_sends.is_empty() { + let (snd, rcv) = oneshot::channel(); + inner.empty_fulfiller = Some(snd); + return Promise::from_future(async move { + match rcv.await { + Ok(r) => r.await, + Err(e) => Err(crate::canceled_to_error(e)), + } + }); + } + } + Promise::from_future(self.tasks.on_empty().map_err(crate::canceled_to_error)) + } +} diff --git a/capnp-rpc/src/lib.rs b/capnp-rpc/src/lib.rs index b6e23f29e..bbbe1bd12 100644 --- a/capnp-rpc/src/lib.rs +++ b/capnp-rpc/src/lib.rs @@ -99,6 +99,7 @@ macro_rules! pry { mod attach; mod broken; +mod flow_control; mod local; mod queued; mod reconnect; @@ -132,6 +133,12 @@ pub trait OutgoingMessage { /// Takes the inner message out of `self`. fn take(self: Box) -> ::capnp::message::Builder<::capnp::message::HeapAllocator>; + + /// Gets the total size of the message, for flow control purposes. Although the caller + /// could also call get_body().target_size(0, doing that would walk th emessage tree, + /// whereas typical implementations can compute the size more cheaply by summing + /// segment sizes. + fn size_in_words(&self) -> usize; } /// A message received from a [`VatNetwork`]. @@ -164,11 +171,32 @@ pub trait Connection { /// returns None. If any other problem occurs, returns an Error. fn receive_incoming_message(&mut self) -> Promise>, Error>; + /// Constructs a flow controller for a new stream on this connection. + /// + /// Returns (fc, p), where fc is the new flow controller and p is a promise + /// that must be polled in order to drive the flow controller. + fn new_stream(&mut self) -> (Box, Promise<(), Error>) { + let (fc, f) = crate::flow_control::FixedWindowFlowController::new( + crate::flow_control::DEFAULT_WINDOW_SIZE, + ); + (Box::new(fc), f) + } + /// Waits until all outgoing messages have been sent, then shuts down the outgoing stream. The /// returned promise resolves after shutdown is complete. fn shutdown(&mut self, result: ::capnp::Result<()>) -> Promise<(), Error>; } +/// Tracks a particular RPC stream in order to implement a flow control algorithm. +pub trait FlowController { + fn send( + &mut self, + message: Box, + ack: Promise<(), Error>, + ) -> Promise<(), Error>; + fn wait_all_acked(&mut self) -> Promise<(), Error>; +} + /// Network facility between vats, it determines how to form connections between /// vats. /// diff --git a/capnp-rpc/src/local.rs b/capnp-rpc/src/local.rs index e8aa77d8e..377749e7f 100644 --- a/capnp-rpc/src/local.rs +++ b/capnp-rpc/src/local.rs @@ -279,6 +279,13 @@ impl RequestHook for Request { pipeline, } } + fn send_streaming(self: Box) -> Promise<(), Error> { + // No special handling for streaming in this case. Calls are delivered one at a time. + Promise::from_future(async { + let _ = self.send().promise.await?; + Ok(()) + }) + } fn tail_send(self: Box) -> Option<(u32, Promise<(), Error>, Box)> { unimplemented!() } @@ -332,6 +339,10 @@ where S: capability::Server, { inner: Rc>, + + /// If a streaming call on this capability has returned an error, + /// this contains a copy of that error. + broken_error: Rc>>, } impl Client @@ -341,11 +352,15 @@ where pub fn new(server: S) -> Self { Self { inner: Rc::new(RefCell::new(server)), + broken_error: Rc::new(RefCell::new(None)), } } pub fn from_rc(inner: Rc>) -> Self { - Self { inner } + Self { + inner, + broken_error: Rc::new(RefCell::new(None)), + } } } @@ -356,6 +371,7 @@ where fn clone(&self) -> Self { Self { inner: self.inner.clone(), + broken_error: self.broken_error.clone(), } } } @@ -388,6 +404,12 @@ where params: Box, results: Box, ) -> Promise<(), Error> { + let streaming_error = self.broken_error.clone(); + if let Some(e) = &*streaming_error.borrow() { + // Previous streaming call threw, so everything fails from now on. + return Promise::err(e.clone()); + } + // We don't want to actually dispatch the call synchronously, because we don't want the callee // to have any side effects before the promise is returned to the caller. This helps avoid // race conditions. @@ -407,7 +429,11 @@ where ::capnp::capability::Results::new(results), ) }; - f.await + let result = f.promise.await; + if let (true, Err(e)) = (f.is_streaming, &result) { + *streaming_error.borrow_mut() = Some(e.clone()); + } + result }) } diff --git a/capnp-rpc/src/reconnect.rs b/capnp-rpc/src/reconnect.rs index 4ecbced32..cb2c8b159 100644 --- a/capnp-rpc/src/reconnect.rs +++ b/capnp-rpc/src/reconnect.rs @@ -203,6 +203,10 @@ where result } + fn send_streaming(self: Box) -> Promise<(), capnp::Error> { + todo!() + } + fn tail_send( self: Box, ) -> Option<( diff --git a/capnp-rpc/src/rpc.rs b/capnp-rpc/src/rpc.rs index fcc5a8b73..b1751a278 100644 --- a/capnp-rpc/src/rpc.rs +++ b/capnp-rpc/src/rpc.rs @@ -1797,6 +1797,71 @@ where (question_ref, promise2) } + + fn send_streaming_internal( + connection_state: &Rc>, + mut message: Box, + cap_table: &[Option>], + flow: Rc>>>, + ) -> Promise<(), Error> { + // Build the cap table. + let exports = ConnectionState::write_descriptors( + connection_state, + cap_table, + get_call(&mut message).unwrap().get_params().unwrap(), + ); + + // Init the question table. Do this after writing descriptors to avoid interference. + let mut question = Question::::new(); + question.is_awaiting_return = true; + question.param_exports = exports; + question.is_tail_call = false; + + let question_id = connection_state.questions.borrow_mut().push(question); + { + let mut call_builder: call::Builder = get_call(&mut message).unwrap(); + call_builder.reborrow().set_question_id(question_id); + } + + // Make the result promise. + let (fulfiller, promise) = oneshot::channel::, Error>>(); + let promise = promise.map_err(crate::canceled_to_error).and_then(|x| x); + let question_ref = Rc::new(RefCell::new(QuestionRef::new( + connection_state.clone(), + question_id, + fulfiller, + ))); + + match connection_state.questions.borrow_mut().slots[question_id as usize] { + Some(ref mut q) => { + q.self_ref = Some(Rc::downgrade(&question_ref)); + } + None => unreachable!(), + } + let promise = promise.attach(question_ref.clone()); + + let mut flow = flow.borrow_mut(); + if flow.is_none() { + match connection_state.connection.borrow_mut().as_mut() { + Err(_) => return Promise::err(Error::failed("no connection".into())), + Ok(connection) => { + let (s, p) = connection.new_stream(); + connection_state.add_task(p); + *flow = Some(s); + } + }; + } + let Some(ref mut flow) = *flow else { + unreachable!() + }; + flow.send( + message, + Promise::from_future(async move { + let _ = promise.await?; + Ok(()) + }), + ) + } } impl RequestHook for Request { @@ -1872,6 +1937,46 @@ impl RequestHook for Request { pipeline: any_pointer::Pipeline::new(Box::new(pipeline)), } } + fn send_streaming(self: Box) -> Promise<(), Error> { + let tmp = *self; + let Self { + connection_state, + target, + mut message, + cap_table, + } = tmp; + let write_target_result = { + let call_builder: call::Builder = get_call(&mut message).unwrap(); + target.write_target(call_builder.get_target().unwrap()) + }; + if let Some(redirect) = write_target_result { + // Whoops, this capability has been redirected while we were building the request! + // We'll have to make a new request and do a copy. Ick. + let mut call_builder: call::Builder = get_call(&mut message).unwrap(); + let mut replacement = redirect.new_call( + call_builder.reborrow().get_interface_id(), + call_builder.reborrow().get_method_id(), + None, + ); + + replacement + .set( + call_builder + .get_params() + .unwrap() + .get_content() + .into_reader(), + ) + .unwrap(); + return replacement.hook.send_streaming(); + } + Self::send_streaming_internal( + &connection_state, + message, + &cap_table, + target.flow_controller, + ) + } fn tail_send(self: Box) -> Option<(u32, Promise<(), Error>, Box)> { let tmp = *self; let Self { @@ -2520,6 +2625,7 @@ where { connection_state: Rc>, variant: ClientVariant, + flow_controller: Rc>>>, } enum WeakClientVariant @@ -2538,6 +2644,7 @@ where { connection_state: Weak>, variant: WeakClientVariant, + flow_controller: Weak>>>, } impl WeakClient @@ -2552,9 +2659,11 @@ where WeakClientVariant::__NoIntercept(()) => ClientVariant::__NoIntercept(()), }; let connection_state = self.connection_state.upgrade()?; + let flow_controller = self.flow_controller.upgrade()?; Some(Client { connection_state, variant, + flow_controller, }) } } @@ -2825,6 +2934,7 @@ impl Client { let client = Self { connection_state: connection_state.clone(), variant, + flow_controller: Rc::new(RefCell::new(None)), }; let weak = client.downgrade(); @@ -2853,6 +2963,7 @@ impl Client { WeakClient { connection_state: Rc::downgrade(&self.connection_state), variant, + flow_controller: Rc::downgrade(&self.flow_controller), } } @@ -2960,6 +3071,7 @@ impl Clone for Client { Self { connection_state: self.connection_state.clone(), variant, + flow_controller: self.flow_controller.clone(), } } } diff --git a/capnp-rpc/src/task_set.rs b/capnp-rpc/src/task_set.rs index 67107fd6a..1de4ade82 100644 --- a/capnp-rpc/src/task_set.rs +++ b/capnp-rpc/src/task_set.rs @@ -18,7 +18,7 @@ // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN // THE SOFTWARE. -use futures::channel::mpsc; +use futures::channel::{mpsc, oneshot}; use futures::stream::FuturesUnordered; use futures::{Future, FutureExt, Stream}; use std::pin::Pin; @@ -30,6 +30,7 @@ use std::rc::Rc; enum EnqueuedTask { Task(Pin>>>), Terminate(Result<(), E>), + OnEmpty(oneshot::Sender<()>), } enum TaskInProgress { @@ -62,6 +63,7 @@ impl Future for TaskInProgress { pub struct TaskSet { enqueued: Option>>, in_progress: FuturesUnordered>, + on_empty_fulfillers: Vec>, reaper: Rc>>>, } @@ -79,6 +81,7 @@ where let set = Self { enqueued: Some(receiver), in_progress: FuturesUnordered::new(), + on_empty_fulfillers: vec![], reaper: Rc::new(RefCell::new(reaper)), }; @@ -91,6 +94,15 @@ where (handle, set) } + + fn update_on_empty_fulfillers(&mut self) { + // There is always the one pending() future that we added in `new()`. + if self.in_progress.len() <= 1 { + for f in std::mem::take(&mut self.on_empty_fulfillers) { + let _ = f.send(()); + } + } + } } #[derive(Clone)] @@ -112,6 +124,14 @@ where pub fn terminate(&mut self, result: Result<(), E>) { let _ = self.sender.unbounded_send(EnqueuedTask::Terminate(result)); } + + /// Returns a future that finishes at the next time when the task set + /// is empty. If the task set is termined, the oneshot will be canceled. + pub fn on_empty(&mut self) -> oneshot::Receiver<()> { + let (s, r) = oneshot::channel(); + let _ = self.sender.unbounded_send(EnqueuedTask::OnEmpty(s)); + r + } } /// For a specific kind of task, `TaskReaper` defines the procedure that should @@ -136,7 +156,7 @@ where enqueued: Some(ref mut enqueued), ref mut in_progress, ref reaper, - .. + ref mut on_empty_fulfillers, } = self.as_mut().get_mut() { loop { @@ -161,6 +181,9 @@ where } })))); } + Poll::Ready(Some(EnqueuedTask::OnEmpty(f))) => { + on_empty_fulfillers.push(f); + } } } } @@ -173,9 +196,15 @@ where Poll::Pending => return Poll::Pending, Poll::Ready(v) => match v { None => return Poll::Ready(Ok(())), - Some(TaskDone::Continue) => (), - Some(TaskDone::Terminate(Ok(()))) => return Poll::Ready(Ok(())), - Some(TaskDone::Terminate(Err(e))) => return Poll::Ready(Err(e)), + Some(TaskDone::Continue) => self.update_on_empty_fulfillers(), + Some(TaskDone::Terminate(Ok(()))) => { + self.on_empty_fulfillers.clear(); + return Poll::Ready(Ok(())); + } + Some(TaskDone::Terminate(Err(e))) => { + self.on_empty_fulfillers.clear(); + return Poll::Ready(Err(e)); + } }, } } diff --git a/capnp-rpc/src/twoparty.rs b/capnp-rpc/src/twoparty.rs index 362363f17..daf754a51 100644 --- a/capnp-rpc/src/twoparty.rs +++ b/capnp-rpc/src/twoparty.rs @@ -80,6 +80,10 @@ impl crate::OutgoingMessage for OutgoingMessage { fn take(self: Box) -> ::capnp::message::Builder<::capnp::message::HeapAllocator> { self.message } + + fn size_in_words(&self) -> usize { + self.message.size_in_words() + } } struct ConnectionInner @@ -91,6 +95,7 @@ where side: crate::rpc_twoparty_capnp::Side, receive_options: ReaderOptions, on_disconnect_fulfiller: Option>, + window_size_in_bytes: usize, } struct Connection @@ -134,6 +139,7 @@ where side, receive_options, on_disconnect_fulfiller: Some(on_disconnect_fulfiller), + window_size_in_bytes: crate::flow_control::DEFAULT_WINDOW_SIZE, })), } } @@ -189,6 +195,13 @@ where } } + fn new_stream(&mut self) -> (Box, Promise<(), capnp::Error>) { + let (fc, f) = crate::flow_control::FixedWindowFlowController::new( + self.inner.borrow().window_size_in_bytes, + ); + (Box::new(fc), f) + } + fn shutdown(&mut self, result: ::capnp::Result<()>) -> Promise<(), ::capnp::Error> { Promise::from_future(self.inner.borrow_mut().sender.terminate(result)) } @@ -262,6 +275,14 @@ where side, } } + + /// Set the number of bytes in the flow control window for each stream created + /// on this connection. + pub fn set_window_size(&mut self, window_size: usize) { + if let Some(ref mut conn) = self.connection { + conn.inner.borrow_mut().window_size_in_bytes = window_size; + } + } } impl crate::VatNetwork for VatNetwork diff --git a/capnp-rpc/test/impls.rs b/capnp-rpc/test/impls.rs index 16046d9ad..e5758da28 100644 --- a/capnp-rpc/test/impls.rs +++ b/capnp-rpc/test/impls.rs @@ -21,7 +21,7 @@ use crate::test_capnp::{ bootstrap, test_call_order, test_capability_server_set, test_extends, test_handle, - test_interface, test_more_stuff, test_pipeline, + test_interface, test_more_stuff, test_pipeline, test_streaming, }; use capnp::capability::{FromClientHook, Promise}; @@ -541,6 +541,17 @@ impl test_more_stuff::Server for TestMoreStuff { Promise::from_future(::futures::future::try_join_all(results).map_ok(|_| ())) } + + fn get_test_streaming( + &mut self, + _params: test_more_stuff::GetTestStreamingParams, + mut results: test_more_stuff::GetTestStreamingResults, + ) -> Promise<(), Error> { + results + .get() + .set_cap(capnp_rpc::new_client(TestStreamingImpl::new())); + Promise::ok(()) + } } struct Handle { @@ -611,6 +622,47 @@ impl test_interface::Server for TestCapDestructor { } } +#[derive(Default)] +pub struct TestStreamingImpl { + i_sum: u32, + j_sum: u32, +} + +impl TestStreamingImpl { + pub fn new() -> Self { + Self::default() + } +} + +impl test_streaming::Server for TestStreamingImpl { + fn do_stream_i(&mut self, params: test_streaming::DoStreamIParams) -> Promise<(), Error> { + let params = pry!(params.get()); + if params.get_throw_error() { + return Promise::err(Error::failed("throw requested".to_string())); + } + self.i_sum += params.get_i(); + Promise::ok(()) + } + fn do_stream_j(&mut self, params: test_streaming::DoStreamJParams) -> Promise<(), Error> { + let params = pry!(params.get()); + if params.get_throw_error() { + return Promise::err(Error::failed("throw requested".to_string())); + } + self.j_sum += params.get_j(); + Promise::ok(()) + } + fn finish_stream( + &mut self, + _params: test_streaming::FinishStreamParams, + mut results: test_streaming::FinishStreamResults, + ) -> Promise<(), Error> { + let mut results = results.get(); + results.set_total_i(self.i_sum); + results.set_total_j(self.j_sum); + Promise::ok(()) + } +} + #[derive(Default)] pub struct CssHandle {} diff --git a/capnp-rpc/test/test.capnp b/capnp-rpc/test/test.capnp index 931ad3b5c..88da7d660 100644 --- a/capnp-rpc/test/test.capnp +++ b/capnp-rpc/test/test.capnp @@ -133,6 +133,13 @@ interface TestTailCaller { foo @0 (i :Int32, callee :TestTailCallee) -> TestTailCallee.TailResult; } +interface TestStreaming { + doStreamI @0 (i :UInt32, throwError :Bool) -> stream; + doStreamJ @1 (j :UInt32, throwError :Bool) -> stream; + finishStream @2 () -> (totalI :UInt32, totalJ :UInt32); + # Test streaming. finishStream() returns the totals of the values streamed to the other calls. +} + interface TestHandle {} interface TestMoreStuff extends(TestCallOrder) { @@ -178,6 +185,8 @@ interface TestMoreStuff extends(TestCallOrder) { callEachCapability @13 (caps :List(TestInterface)) -> (); # Calls TestInterface::foo(123, true) on each cap. + + getTestStreaming @14 () -> (cap :TestStreaming); } interface TestCapabilityServerSet { diff --git a/capnp-rpc/test/test.rs b/capnp-rpc/test/test.rs index e9de9db61..b782658c3 100644 --- a/capnp-rpc/test/test.rs +++ b/capnp-rpc/test/test.rs @@ -1162,3 +1162,71 @@ fn capability_server_set_rpc() { Ok(()) }) } + +#[test] +fn basic_streaming() { + rpc_and_local_top_level(|_spawner, client| async move { + let response = client.test_more_stuff_request().send().promise.await?; + let client = response.get()?.get_cap()?; + let response = client.get_test_streaming_request().send().promise.await?; + let client = response.get()?.get_cap()?; + + const EACH: u32 = 10; + const ITERS: u32 = 100; + for _ in 0..ITERS { + let mut request = client.do_stream_i_request(); + request.get().set_i(EACH); + request.send().await?; + } + + let r = client.finish_stream_request().send().promise.await?; + let results = r.get()?; + assert_eq!(results.get_total_i(), ITERS * EACH); + Ok(()) + }); +} + +#[test] +fn basic_streaming_on_pipeline() { + rpc_and_local_top_level(|_spawner, client| async move { + let response = client.test_more_stuff_request().send().pipeline; + let client = response.get_cap(); + let response = client.get_test_streaming_request().send().pipeline; + let client = response.get_cap(); + + const EACH: u32 = 3; + const ITERS: u32 = 1000; + for _ in 0..ITERS { + let mut request = client.do_stream_i_request(); + request.get().set_i(EACH); + request.send().await?; + } + + let r = client.finish_stream_request().send().promise.await?; + let results = r.get()?; + assert_eq!(results.get_total_i(), ITERS * EACH); + Ok(()) + }); +} + +#[test] +fn stream_error_gets_reported() { + rpc_and_local_top_level(|_spawner, client| async move { + let response = client.test_more_stuff_request().send().promise.await?; + let client = response.get()?.get_cap()?; + let response = client.get_test_streaming_request().send().promise.await?; + let client = response.get()?.get_cap()?; + + let mut request = client.do_stream_i_request(); + request.get().set_throw_error(true); + + let _ = request.send().await; + + let r = client.finish_stream_request().send().promise.await; + let Err(e) = r else { + panic!("expected error"); + }; + assert!(e.to_string().contains("throw requested")); + Ok(()) + }); +} diff --git a/capnp/src/capability.rs b/capnp/src/capability.rs index d87dc248b..1858385eb 100644 --- a/capnp/src/capability.rs +++ b/capnp/src/capability.rs @@ -214,6 +214,27 @@ where } } +/// A method call that has not been sent yet. +#[cfg(feature = "alloc")] +pub struct StreamingRequest { + pub marker: PhantomData, + pub hook: alloc::boxed::Box, +} + +#[cfg(feature = "alloc")] +impl StreamingRequest +where + Params: Owned, +{ + pub fn get(&mut self) -> Params::Builder<'_> { + self.hook.get().get_as().unwrap() + } + + pub fn send(self) -> Promise<(), Error> { + self.hook.send_streaming() + } +} + /// The values of the parameters passed to a method call, as seen by the server. #[cfg(feature = "alloc")] pub struct Params { @@ -324,6 +345,19 @@ impl Client { } } + pub fn new_streaming_call( + &self, + interface_id: u64, + method_id: u16, + size_hint: Option, + ) -> StreamingRequest { + let typeless = self.hook.new_call(interface_id, method_id, size_hint); + StreamingRequest { + hook: typeless.hook, + marker: PhantomData, + } + } + /// If the capability is actually only a promise, the returned promise resolves once the /// capability itself has resolved to its final destination (or propagates the exception if /// the capability promise is rejected). This is mainly useful for error-checking in the case @@ -334,6 +368,28 @@ impl Client { } } +/// The return value of Server::dispatch_call(). +#[cfg(feature = "alloc")] +pub struct DispatchCallResult { + /// Promise for completion of the call. + pub promise: Promise<(), Error>, + + /// If true, this method was declared as `-> stream;`. If this call throws + /// an exception, then all future calls on the capability with throw the + /// same exception. + pub is_streaming: bool, +} + +#[cfg(feature = "alloc")] +impl DispatchCallResult { + pub fn new(promise: Promise<(), Error>, is_streaming: bool) -> Self { + Self { + promise, + is_streaming, + } + } +} + /// An untyped server. #[cfg(feature = "alloc")] pub trait Server { @@ -343,7 +399,7 @@ pub trait Server { method_id: u16, params: Params, results: Results, - ) -> Promise<(), Error>; + ) -> DispatchCallResult; } /// Trait to track the relationship between generated Server traits and Client structs. diff --git a/capnp/src/message.rs b/capnp/src/message.rs index c381b3a56..1a38f51dd 100644 --- a/capnp/src/message.rs +++ b/capnp/src/message.rs @@ -294,6 +294,10 @@ where pub fn into_typed(self) -> TypedReader { TypedReader::new(self) } + + pub fn size_in_words(&self) -> usize { + self.arena.size_in_words() + } } /// A message reader whose value is known to be of type `T`. @@ -520,6 +524,10 @@ where pub fn into_allocator(self) -> A { self.arena.into_allocator() } + + pub fn size_in_words(&self) -> usize { + self.arena.size_in_words() + } } impl ReaderSegments for Builder diff --git a/capnp/src/private/arena.rs b/capnp/src/private/arena.rs index e7c246e42..76939f277 100644 --- a/capnp/src/private/arena.rs +++ b/capnp/src/private/arena.rs @@ -45,6 +45,8 @@ pub trait ReaderArena { fn nesting_limit(&self) -> i32; + fn size_in_words(&self) -> usize; + // TODO(apibump): Consider putting extract_cap(), inject_cap(), drop_cap() here // and on message::Reader. Then we could get rid of Imbue and ImbueMut, and // layout::StructReader, layout::ListReader, etc. could drop their `cap_table` fields. @@ -146,6 +148,16 @@ where fn nesting_limit(&self) -> i32 { self.nesting_limit } + + fn size_in_words(&self) -> usize { + let mut result = 0; + for ii in 0..self.segments.len() { + if let Some(seg) = self.segments.get_segment(ii as u32) { + result += seg.len(); + } + } + result + } } pub trait BuilderArena: ReaderArena { @@ -333,6 +345,14 @@ where fn nesting_limit(&self) -> i32 { 0x7fffffff } + + fn size_in_words(&self) -> usize { + let mut result = 0; + for ii in 0..self.inner.segments.len() { + result += self.inner.segments[ii].allocated as usize + } + result + } } impl BuilderArenaImplInner @@ -464,4 +484,8 @@ impl ReaderArena for NullArena { fn nesting_limit(&self) -> i32 { 0x7fffffff } + + fn size_in_words(&self) -> usize { + 0 + } } diff --git a/capnp/src/private/capability.rs b/capnp/src/private/capability.rs index 996f5678f..ce5dfdbfa 100644 --- a/capnp/src/private/capability.rs +++ b/capnp/src/private/capability.rs @@ -33,6 +33,7 @@ pub trait RequestHook { fn get(&mut self) -> any_pointer::Builder<'_>; fn get_brand(&self) -> usize; fn send(self: alloc::boxed::Box) -> RemotePromise; + fn send_streaming(self: alloc::boxed::Box) -> Promise<(), crate::Error>; fn tail_send( self: alloc::boxed::Box, ) -> Option<( diff --git a/capnpc/src/codegen.rs b/capnpc/src/codegen.rs index 2f55830db..c2102efc0 100644 --- a/capnpc/src/codegen.rs +++ b/capnpc/src/codegen.rs @@ -520,6 +520,9 @@ const NAME_ANNOTATION_ID: u64 = 0xc2fe4c6d100166d0; const PARENT_MODULE_ANNOTATION_ID: u64 = 0xabee386cd1450364; const OPTION_ANNOTATION_ID: u64 = 0xabfef22c4ee1964e; +// StreamResult type ID, as defined in stream.capnp. +const STREAM_RESULT_ID: u64 = 0x995f9a3377c0b16e; + fn name_annotation_value(annotation: schema_capnp::annotation::Reader) -> capnp::Result<&str> { if let schema_capnp::value::Text(t) = annotation.get_value()?.which()? { let name = t?.to_str()?; @@ -2551,32 +2554,6 @@ fn generate_node( ¶m_scopes.join("::"), )?; - let result_id = method.get_result_struct_type(); - let result_node = &ctx.node_map[&result_id]; - let (result_scopes, results_ty_params) = if result_node.get_scope_id() == 0 { - let mut names = names.clone(); - let local_name = module_name(&format!("{name}Results")); - nested_output.push(generate_node(ctx, result_id, &local_name)?); - names.push(local_name); - (names, params.params.clone()) - } else { - ( - ctx.scope_map[&result_node.get_id()].clone(), - get_ty_params_of_brand(ctx, method.get_result_brand()?)?, - ) - }; - let result_type = do_branding( - ctx, - result_id, - method.get_result_brand()?, - Leaf::Owned, - &result_scopes.join("::"), - )?; - - dispatch_arms.push( - Line(fmt!(ctx, - "{ordinal} => server.{}({capnp}::private::capability::internal_get_typed_params(params), {capnp}::private::capability::internal_get_typed_results(results)),", - module_name(name)))); mod_interior.push(Line(fmt!( ctx, "pub type {}Params<{}> = {capnp}::capability::Params<{}>;", @@ -2584,34 +2561,88 @@ fn generate_node( params_ty_params, param_type ))); - mod_interior.push(Line(fmt!( - ctx, - "pub type {}Results<{}> = {capnp}::capability::Results<{}>;", - capitalize_first_letter(name), - results_ty_params, - result_type - ))); - server_interior.push( - Line(fmt!(ctx, - "fn {}(&mut self, _: {}Params<{}>, _: {}Results<{}>) -> {capnp}::capability::Promise<(), {capnp}::Error> {{ {capnp}::capability::Promise::err({capnp}::Error::unimplemented(\"method {}::Server::{} not implemented\".to_string())) }}", - module_name(name), - capitalize_first_letter(name), params_ty_params, - capitalize_first_letter(name), results_ty_params, - node_name, module_name(name) + + let result_id = method.get_result_struct_type(); + if result_id != STREAM_RESULT_ID { + dispatch_arms.push( + Line(fmt!(ctx, + "{ordinal} => {capnp}::capability::DispatchCallResult::new(server.{}({capnp}::private::capability::internal_get_typed_params(params), {capnp}::private::capability::internal_get_typed_results(results)), false),", + module_name(name)))); + + let result_node = &ctx.node_map[&result_id]; + let (result_scopes, results_ty_params) = if result_node.get_scope_id() == 0 { + let mut names = names.clone(); + let local_name = module_name(&format!("{name}Results")); + nested_output.push(generate_node(ctx, result_id, &local_name)?); + names.push(local_name); + (names, params.params.clone()) + } else { + ( + ctx.scope_map[&result_node.get_id()].clone(), + get_ty_params_of_brand(ctx, method.get_result_brand()?)?, + ) + }; + let result_type = do_branding( + ctx, + result_id, + method.get_result_brand()?, + Leaf::Owned, + &result_scopes.join("::"), + )?; + mod_interior.push(Line(fmt!( + ctx, + "pub type {}Results<{}> = {capnp}::capability::Results<{}>;", + capitalize_first_letter(name), + results_ty_params, + result_type ))); + server_interior.push( + Line(fmt!(ctx, + "fn {}(&mut self, _: {}Params<{}>, _: {}Results<{}>) -> {capnp}::capability::Promise<(), {capnp}::Error> {{ {capnp}::capability::Promise::err({capnp}::Error::unimplemented(\"method {}::Server::{} not implemented\".to_string())) }}", + module_name(name), + capitalize_first_letter(name), params_ty_params, + capitalize_first_letter(name), results_ty_params, + node_name, module_name(name) + ))); - client_impl_interior.push(Line(fmt!( - ctx, - "pub fn {}_request(&self) -> {capnp}::capability::Request<{},{}> {{", - camel_to_snake_case(name), - param_type, - result_type - ))); + client_impl_interior.push(Line(fmt!( + ctx, + "pub fn {}_request(&self) -> {capnp}::capability::Request<{},{}> {{", + camel_to_snake_case(name), + param_type, + result_type + ))); + + client_impl_interior.push(indent(Line(format!( + "self.client.new_call(_private::TYPE_ID, {ordinal}, ::core::option::Option::None)" + )))); + client_impl_interior.push(line("}")); + } else { + // It's a streaming method. + dispatch_arms.push( + Line(fmt!(ctx, + "{ordinal} => {capnp}::capability::DispatchCallResult::new(server.{}({capnp}::private::capability::internal_get_typed_params(params)), true),", + module_name(name)))); + + server_interior.push( + Line(fmt!(ctx, + "fn {}(&mut self, _: {}Params<{}>) -> {capnp}::capability::Promise<(), {capnp}::Error> {{ {capnp}::capability::Promise::err({capnp}::Error::unimplemented(\"method {}::Server::{} not implemented\".to_string())) }}", + module_name(name), + capitalize_first_letter(name), params_ty_params, + node_name, module_name(name) + ))); + client_impl_interior.push(Line(fmt!( + ctx, + "pub fn {}_request(&self) -> {capnp}::capability::StreamingRequest<{}> {{", + camel_to_snake_case(name), + param_type + ))); + client_impl_interior.push(indent(Line(format!( + "self.client.new_streaming_call(_private::TYPE_ID, {ordinal}, ::core::option::Option::None)" + )))); - client_impl_interior.push(indent(Line(format!( - "self.client.new_call(_private::TYPE_ID, {ordinal}, ::core::option::Option::None)" - )))); - client_impl_interior.push(line("}")); + client_impl_interior.push(line("}")); + } method.get_annotations()?; } @@ -2835,11 +2866,11 @@ fn generate_node( } else { Line(fmt!(ctx,"impl <_T: Server> {capnp}::capability::Server for ServerDispatch<_T> {{")) }), - indent(Line(fmt!(ctx,"fn dispatch_call(&mut self, interface_id: u64, method_id: u16, params: {capnp}::capability::Params<{capnp}::any_pointer::Owned>, results: {capnp}::capability::Results<{capnp}::any_pointer::Owned>) -> {capnp}::capability::Promise<(), {capnp}::Error> {{"))), + indent(Line(fmt!(ctx,"fn dispatch_call(&mut self, interface_id: u64, method_id: u16, params: {capnp}::capability::Params<{capnp}::any_pointer::Owned>, results: {capnp}::capability::Results<{capnp}::any_pointer::Owned>) -> {capnp}::capability::DispatchCallResult {{"))), indent(indent(line("match interface_id {"))), indent(indent(indent(line("_private::TYPE_ID => Self::dispatch_call_internal(&mut self.server, method_id, params, results),")))), indent(indent(indent(base_dispatch_arms))), - indent(indent(indent(Line(fmt!(ctx,"_ => {{ {capnp}::capability::Promise::err({capnp}::Error::unimplemented(\"Method not implemented.\".to_string())) }}"))))), + indent(indent(indent(Line(fmt!(ctx,"_ => {{ {capnp}::capability::DispatchCallResult::new({capnp}::capability::Promise::err({capnp}::Error::unimplemented(\"Method not implemented.\".to_string())), false) }}"))))), indent(indent(line("}"))), indent(line("}")), line("}")])); @@ -2851,10 +2882,10 @@ fn generate_node( } else { line("impl <_T :Server> ServerDispatch<_T> {") }), - indent(Line(fmt!(ctx,"pub fn dispatch_call_internal(server: &mut _T, method_id: u16, params: {capnp}::capability::Params<{capnp}::any_pointer::Owned>, results: {capnp}::capability::Results<{capnp}::any_pointer::Owned>) -> {capnp}::capability::Promise<(), {capnp}::Error> {{"))), + indent(Line(fmt!(ctx,"pub fn dispatch_call_internal(server: &mut _T, method_id: u16, params: {capnp}::capability::Params<{capnp}::any_pointer::Owned>, results: {capnp}::capability::Results<{capnp}::any_pointer::Owned>) -> {capnp}::capability::DispatchCallResult {{"))), indent(indent(line("match method_id {"))), indent(indent(indent(dispatch_arms))), - indent(indent(indent(Line(fmt!(ctx,"_ => {{ ::capnp::capability::Promise::err({capnp}::Error::unimplemented(\"Method not implemented.\".to_string())) }}"))))), + indent(indent(indent(Line(fmt!(ctx,"_ => {{ {capnp}::capability::DispatchCallResult::new({capnp}::capability::Promise::err({capnp}::Error::unimplemented(\"Method not implemented.\".to_string())), false) }}"))))), indent(indent(line("}"))), indent(line("}")), line("}")]));