diff --git a/capnp-rpc/examples/streaming/Cargo.toml b/capnp-rpc/examples/streaming/Cargo.toml index 809fc2581..bc73685e8 100644 --- a/capnp-rpc/examples/streaming/Cargo.toml +++ b/capnp-rpc/examples/streaming/Cargo.toml @@ -17,6 +17,7 @@ capnp = { path = "../../../capnp" } futures = "0.3.0" rand = "0.8.5" sha2 = { version = "0.10.8" } +base16 = { version = "0.2" } tokio = { version = "1.0.0", features = ["net", "rt", "macros"]} tokio-util = { version = "0.7.4", features = ["compat"] } diff --git a/capnp-rpc/examples/streaming/client.rs b/capnp-rpc/examples/streaming/client.rs index 006461940..b8c84181d 100644 --- a/capnp-rpc/examples/streaming/client.rs +++ b/capnp-rpc/examples/streaming/client.rs @@ -16,8 +16,8 @@ pub async fn main() -> Result<(), Box> { return Ok(()); } - let stream_size: usize = str::parse(&args[3]).unwrap(); - let window_size: usize = str::parse(&args[4]).unwrap(); + let stream_size: u64 = str::parse(&args[3]).unwrap(); + let window_size: u64 = str::parse(&args[4]).unwrap(); let addr = args[2] .to_socket_addrs()? @@ -36,7 +36,7 @@ pub async fn main() -> Result<(), Box> { rpc_twoparty_capnp::Side::Client, Default::default(), )); - rpc_network.set_window_size(window_size); + rpc_network.set_window_size(window_size as usize); let mut rpc_system = RpcSystem::new(rpc_network, None); let receiver: receiver::Client = rpc_system.bootstrap(rpc_twoparty_capnp::Side::Server); tokio::task::spawn_local(rpc_system); @@ -47,22 +47,27 @@ pub async fn main() -> Result<(), Box> { let mut rng = rand::thread_rng(); let mut hasher = Sha256::new(); let bytestream = pipeline.get_stream(); - let mut bytes_written: u32 = 0; + let mut bytes_written: u64 = 0; const CHUNK_SIZE: u32 = 4096; - while bytes_written < stream_size as u32 { + while bytes_written < stream_size { let mut request = bytestream.write_request(); let body = request.get(); - let buf = body.init_bytes(CHUNK_SIZE); + let this_chunk_size = u64::min(CHUNK_SIZE as u64, stream_size - bytes_written); + let buf = body.init_bytes(this_chunk_size as u32); rng.fill(buf); hasher.update(buf); request.send().await?; - bytes_written += CHUNK_SIZE; + bytes_written += this_chunk_size; } + let local_sha256 = hasher.finalize(); + println!( + "wrote {bytes_written} bytes with hash {}", + base16::encode_lower(&local_sha256[..]) + ); bytestream.end_request().send().promise.await?; let response = promise.await?; let sha256 = response.get()?.get_sha256()?; - let local_sha256 = hasher.finalize(); assert_eq!(sha256, &local_sha256[..]); Ok(()) }) diff --git a/capnp-rpc/examples/streaming/server.rs b/capnp-rpc/examples/streaming/server.rs index 2bc71b0fc..6bd312fe8 100644 --- a/capnp-rpc/examples/streaming/server.rs +++ b/capnp-rpc/examples/streaming/server.rs @@ -12,6 +12,7 @@ use sha2::{Digest, Sha256}; struct ByteStreamImpl { hasher: Sha256, + bytes_received: u32, hash_sender: Option>>, } @@ -19,6 +20,7 @@ impl ByteStreamImpl { fn new(hash_sender: oneshot::Sender>) -> Self { Self { hasher: Sha256::new(), + bytes_received: 0, hash_sender: Some(hash_sender), } } @@ -28,6 +30,7 @@ impl byte_stream::Server for ByteStreamImpl { fn write(&mut self, params: byte_stream::WriteParams) -> Promise<(), Error> { let bytes = pry!(pry!(params.get()).get_bytes()); self.hasher.update(bytes); + self.bytes_received += bytes.len() as u32; Promise::ok(()) } @@ -37,8 +40,14 @@ impl byte_stream::Server for ByteStreamImpl { _results: byte_stream::EndResults, ) -> Promise<(), Error> { let hasher = std::mem::take(&mut self.hasher); + let hash = hasher.finalize()[..].to_vec(); + println!( + "received {} bytes with hash {}", + self.bytes_received, + base16::encode_lower(&hash[..]) + ); if let Some(sender) = self.hash_sender.take() { - let _ = sender.send(hasher.finalize()[..].to_vec()); + let _ = sender.send(hash); } Promise::ok(()) }