From 6e475ec3d5fdbf052da1594ff0fce3fde06b2197 Mon Sep 17 00:00:00 2001 From: clabby Date: Sun, 2 Jun 2024 00:52:42 -0400 Subject: [PATCH] feat(preimage): Async client handles ## Overview Makes the `HintWriterClient` + `PreimageOracleClient` traits asynchronous to prevent blocking of the host program when executing a client program natively. Previously, since the preimage oracle bindings for the client were entirely synchronous, the loops in `PipeHandle` could cause a deadlock. Now that oracle IO is asynchronous, the runtime can interrupt a future when it yields execution (i.e. `tokio::select` works.) In the client program, synchronous execution is still guaranteed. It can run async colored functions in a minimal runtime, such as the `block_on` runtime in `kona_common`. `simple-revm` had to be changed as a part of this PR, which has an example of this. --- bin/programs/simple-revm/src/main.rs | 39 ++++++----- crates/preimage/src/hint.rs | 20 +++--- crates/preimage/src/oracle.rs | 27 +++---- crates/preimage/src/pipe.rs | 101 +++++++++++++++++++++------ crates/preimage/src/traits.rs | 13 ++-- 5 files changed, 133 insertions(+), 67 deletions(-) diff --git a/bin/programs/simple-revm/src/main.rs b/bin/programs/simple-revm/src/main.rs index 0c24f56d4..8a2444f89 100644 --- a/bin/programs/simple-revm/src/main.rs +++ b/bin/programs/simple-revm/src/main.rs @@ -33,45 +33,48 @@ static CLIENT_HINT_PIPE: PipeHandle = #[client_entry(0xFFFFFFF)] fn main() { - let mut oracle = OracleReader::new(CLIENT_PREIMAGE_PIPE); - let hint_writer = HintWriter::new(CLIENT_HINT_PIPE); - - io::print("Booting EVM and checking hash...\n"); - let (digest, code) = boot(&mut oracle).expect("Failed to boot"); - - match run_evm(&mut oracle, &hint_writer, digest, code) { - Ok(_) => io::print("Success, hashes matched!\n"), - Err(e) => { - io::print_err(alloc::format!("Error: {}\n", e).as_ref()); - io::exit(1); + kona_common::block_on(async { + let mut oracle = OracleReader::new(CLIENT_PREIMAGE_PIPE); + let hint_writer = HintWriter::new(CLIENT_HINT_PIPE); + + io::print("Booting EVM and checking hash...\n"); + let (digest, code) = boot(&mut oracle).await.expect("Failed to boot"); + + match run_evm(&mut oracle, &hint_writer, digest, code).await { + Ok(_) => io::print("Success, hashes matched!\n"), + Err(e) => { + io::print_err(alloc::format!("Error: {}\n", e).as_ref()); + io::exit(1); + } } - } + }) } /// Boot the program and load bootstrap information. #[inline] -fn boot(oracle: &mut OracleReader) -> Result<([u8; 32], Vec)> { +async fn boot(oracle: &mut OracleReader) -> Result<([u8; 32], Vec)> { let digest = oracle - .get(PreimageKey::new_local(DIGEST_IDENT))? + .get(PreimageKey::new_local(DIGEST_IDENT)) + .await? .try_into() .map_err(|_| anyhow!("Failed to convert digest to [u8; 32]"))?; - let code = oracle.get(PreimageKey::new_local(CODE_IDENT))?; + let code = oracle.get(PreimageKey::new_local(CODE_IDENT)).await?; Ok((digest, code)) } /// Call the SHA-256 precompile and assert that the input and output match the expected values #[inline] -fn run_evm( +async fn run_evm( oracle: &mut OracleReader, hint_writer: &HintWriter, digest: [u8; 32], code: Vec, ) -> Result<()> { // Send a hint for the preimage of the digest to the host so that it can prepare the preimage. - hint_writer.write(&alloc::format!("sha2-preimage {}", hex::encode(digest)))?; + hint_writer.write(&alloc::format!("sha2-preimage {}", hex::encode(digest))).await?; // Get the preimage of `digest` from the host. - let input = oracle.get(PreimageKey::new_local(INPUT_IDENT))?; + let input = oracle.get(PreimageKey::new_local(INPUT_IDENT)).await?; let mut cache_db = CacheDB::new(EmptyDB::default()); diff --git a/crates/preimage/src/hint.rs b/crates/preimage/src/hint.rs index 49f871588..18ebf2d57 100644 --- a/crates/preimage/src/hint.rs +++ b/crates/preimage/src/hint.rs @@ -1,6 +1,7 @@ use crate::{traits::HintWriterClient, HintReaderServer, PipeHandle}; use alloc::{boxed::Box, string::String, vec}; use anyhow::Result; +use async_trait::async_trait; use core::future::Future; use tracing::{debug, error}; @@ -18,10 +19,11 @@ impl HintWriter { } } +#[async_trait] impl HintWriterClient for HintWriter { /// Write a hint to the host. This will overwrite any existing hint in the pipe, and block until /// all data has been written. - fn write(&self, hint: &str) -> Result<()> { + async fn write(&self, hint: &str) -> Result<()> { // Form the hint into a byte buffer. The format is a 4-byte big-endian length prefix // followed by the hint string. let mut hint_bytes = vec![0u8; hint.len() + 4]; @@ -31,13 +33,13 @@ impl HintWriterClient for HintWriter { debug!(target: "hint_writer", "Writing hint \"{hint}\""); // Write the hint to the host. - self.pipe_handle.write(&hint_bytes)?; + self.pipe_handle.write(&hint_bytes).await?; debug!(target: "hint_writer", "Successfully wrote hint"); // Read the hint acknowledgement from the host. let mut hint_ack = [0u8; 1]; - self.pipe_handle.read_exact(&mut hint_ack)?; + self.pipe_handle.read_exact(&mut hint_ack).await?; debug!(target: "hint_writer", "Received hint acknowledgement"); @@ -59,7 +61,7 @@ impl HintReader { } } -#[async_trait::async_trait] +#[async_trait] impl HintReaderServer for HintReader { async fn next_hint(&self, mut route_hint: F) -> Result<()> where @@ -68,12 +70,12 @@ impl HintReaderServer for HintReader { { // Read the length of the raw hint payload. let mut len_buf = [0u8; 4]; - self.pipe_handle.read_exact(&mut len_buf)?; + self.pipe_handle.read_exact(&mut len_buf).await?; let len = u32::from_be_bytes(len_buf); // Read the raw hint payload. let mut raw_payload = vec![0u8; len as usize]; - self.pipe_handle.read_exact(raw_payload.as_mut_slice())?; + self.pipe_handle.read_exact(raw_payload.as_mut_slice()).await?; let payload = String::from_utf8(raw_payload) .map_err(|e| anyhow::anyhow!("Failed to decode hint payload: {e}"))?; @@ -82,14 +84,14 @@ impl HintReaderServer for HintReader { // Route the hint if let Err(e) = route_hint(payload).await { // Write back on error to prevent blocking the client. - self.pipe_handle.write(&[0x00])?; + self.pipe_handle.write(&[0x00]).await?; error!("Failed to route hint: {e}"); anyhow::bail!("Failed to rout hint: {e}"); } // Write back an acknowledgement to the client to unblock their process. - self.pipe_handle.write(&[0x00])?; + self.pipe_handle.write(&[0x00]).await?; debug!(target: "hint_reader", "Successfully routed and acknowledged hint"); @@ -144,7 +146,7 @@ mod test { let (hint_writer, hint_reader) = (sys.hint_writer, sys.hint_reader); let incoming_hints = Arc::new(Mutex::new(Vec::new())); - let client = tokio::task::spawn(async move { hint_writer.write(MOCK_DATA) }); + let client = tokio::task::spawn(async move { hint_writer.write(MOCK_DATA).await }); let host = tokio::task::spawn({ let incoming_hints_ref = Arc::clone(&incoming_hints); async move { diff --git a/crates/preimage/src/oracle.rs b/crates/preimage/src/oracle.rs index 87494d231..4f6eea4fc 100644 --- a/crates/preimage/src/oracle.rs +++ b/crates/preimage/src/oracle.rs @@ -19,31 +19,32 @@ impl OracleReader { /// Set the preimage key for the global oracle reader. This will overwrite any existing key, and /// block until the host has prepared the preimage and responded with the length of the /// preimage. - fn write_key(&self, key: PreimageKey) -> Result { + async fn write_key(&self, key: PreimageKey) -> Result { // Write the key to the host so that it can prepare the preimage. let key_bytes: [u8; 32] = key.into(); - self.pipe_handle.write(&key_bytes)?; + self.pipe_handle.write(&key_bytes).await?; // Read the length prefix and reset the cursor. let mut length_buffer = [0u8; 8]; - self.pipe_handle.read_exact(&mut length_buffer)?; + self.pipe_handle.read_exact(&mut length_buffer).await?; Ok(u64::from_be_bytes(length_buffer) as usize) } } +#[async_trait::async_trait] impl PreimageOracleClient for OracleReader { /// Get the data corresponding to the currently set key from the host. Return the data in a new /// heap allocated `Vec` - fn get(&self, key: PreimageKey) -> Result> { + async fn get(&self, key: PreimageKey) -> Result> { debug!(target: "oracle_client", "Requesting data from preimage oracle. Key {key}"); - let length = self.write_key(key)?; + let length = self.write_key(key).await?; let mut data_buffer = alloc::vec![0; length]; debug!(target: "oracle_client", "Reading data from preimage oracle. Key {key}"); // Grab a read lock on the preimage pipe to read the data. - self.pipe_handle.read_exact(&mut data_buffer)?; + self.pipe_handle.read_exact(&mut data_buffer).await?; debug!(target: "oracle_client", "Successfully read data from preimage oracle. Key: {key}"); @@ -52,11 +53,11 @@ impl PreimageOracleClient for OracleReader { /// Get the data corresponding to the currently set key from the host. Write the data into the /// provided buffer - fn get_exact(&self, key: PreimageKey, buf: &mut [u8]) -> Result<()> { + async fn get_exact(&self, key: PreimageKey, buf: &mut [u8]) -> Result<()> { debug!(target: "oracle_client", "Requesting data from preimage oracle. Key {key}"); // Write the key to the host and read the length of the preimage. - let length = self.write_key(key)?; + let length = self.write_key(key).await?; debug!(target: "oracle_client", "Reading data from preimage oracle. Key {key}"); @@ -65,7 +66,7 @@ impl PreimageOracleClient for OracleReader { bail!("Buffer size {} does not match preimage size {}", buf.len(), length); } - self.pipe_handle.read_exact(buf)?; + self.pipe_handle.read_exact(buf).await?; debug!(target: "oracle_client", "Successfully read data from preimage oracle. Key: {key}"); @@ -95,7 +96,7 @@ impl PreimageOracleServer for OracleServer { { // Read the preimage request from the client, and throw early if there isn't is any. let mut buf = [0u8; 32]; - self.pipe_handle.read_exact(&mut buf)?; + self.pipe_handle.read_exact(&mut buf).await?; let preimage_key = PreimageKey::try_from(buf)?; debug!(target: "oracle_server", "Fetching preimage for key {preimage_key}"); @@ -109,7 +110,7 @@ impl PreimageOracleServer for OracleServer { .flatten() .copied() .collect::>(); - self.pipe_handle.write(data.as_slice())?; + self.pipe_handle.write(data.as_slice()).await?; debug!(target: "oracle_server", "Successfully wrote preimage data for key {preimage_key}"); @@ -184,8 +185,8 @@ mod test { let (oracle_reader, oracle_server) = (sys.oracle_reader, sys.oracle_server); let client = tokio::task::spawn(async move { - let contents_a = oracle_reader.get(key_a).unwrap(); - let contents_b = oracle_reader.get(key_b).unwrap(); + let contents_a = oracle_reader.get(key_a).await.unwrap(); + let contents_b = oracle_reader.get(key_b).await.unwrap(); // Drop the file descriptors to close the pipe, stopping the host's blocking loop on // waiting for client requests. diff --git a/crates/preimage/src/pipe.rs b/crates/preimage/src/pipe.rs index 4fa8ea8ba..6355aab8e 100644 --- a/crates/preimage/src/pipe.rs +++ b/crates/preimage/src/pipe.rs @@ -1,7 +1,14 @@ //! This module contains a rudamentary pipe between two file descriptors, using [kona_common::io] //! for reading and writing from the file descriptors. -use anyhow::{bail, Result}; +use anyhow::{anyhow, Result}; +use core::{ + cell::RefCell, + cmp::Ordering, + future::Future, + pin::Pin, + task::{Context, Poll}, +}; use kona_common::{io, FileDescriptor}; /// [PipeHandle] is a handle for one end of a bidirectional pipe. @@ -24,30 +31,14 @@ impl PipeHandle { io::read(self.read_handle, buf) } - /// Reads exactly `buf.len()` bytes into `buf`, blocking until all bytes are read. - pub fn read_exact(&self, buf: &mut [u8]) -> Result { - let mut read = 0; - while read < buf.len() { - let chunk_read = self.read(&mut buf[read..])?; - read += chunk_read; - } - Ok(read) + /// Reads exactly `buf.len()` bytes into `buf`. + pub fn read_exact<'a>(&self, buf: &'a mut [u8]) -> impl Future> + 'a { + ReadFuture { pipe_handle: *self, buf: RefCell::new(buf), read: 0 } } /// Write the given buffer to the pipe. - pub fn write(&self, buf: &[u8]) -> Result { - let mut written = 0; - loop { - match io::write(self.write_handle, &buf[written..]) { - Ok(0) => break, - Ok(n) => { - written += n; - continue; - } - Err(e) => bail!("Failed to write preimage key: {}", e), - } - } - Ok(written) + pub fn write<'a>(&self, buf: &'a [u8]) -> impl Future> + 'a { + WriteFuture { pipe_handle: *self, buf, written: 0 } } /// Returns the read handle for the pipe. @@ -60,3 +51,69 @@ impl PipeHandle { self.write_handle } } + +/// A future that reads from a pipe, returning [Poll::Ready] when the buffer is full. +struct ReadFuture<'a> { + /// The pipe handle to read from + pipe_handle: PipeHandle, + /// The buffer to read into + buf: RefCell<&'a mut [u8]>, + /// The number of bytes read so far + read: usize, +} + +impl Future for ReadFuture<'_> { + type Output = Result; + + fn poll(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll { + let mut buf = self.buf.borrow_mut(); + let buf_len = buf.len(); + let chunk_read = self.pipe_handle.read(&mut buf[self.read..])?; + + // Drop the borrow on self. + drop(buf); + + self.read += chunk_read; + + match self.read.cmp(&buf_len) { + Ordering::Equal => Poll::Ready(Ok(self.read)), + Ordering::Greater => Poll::Ready(Err(anyhow!("Read more bytes than buffer size"))), + Ordering::Less => { + // Register the current task to be woken up when it can make progress + ctx.waker().wake_by_ref(); + Poll::Pending + } + } + } +} + +/// A future that writes to a pipe, returning [Poll::Ready] when the full buffer has been written. +struct WriteFuture<'a> { + /// The pipe handle to write to + pipe_handle: PipeHandle, + /// The buffer to write + buf: &'a [u8], + /// The number of bytes written so far + written: usize, +} + +impl Future for WriteFuture<'_> { + type Output = Result; + + fn poll(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll { + loop { + match io::write(self.pipe_handle.write_handle(), &self.buf[self.written..]) { + Ok(0) => return Poll::Ready(Ok(self.written)), // Finished writing + Ok(n) => { + self.written += n; + continue; + } + Err(_) => { + // Register the current task to be woken up when it can make progress + ctx.waker().wake_by_ref(); + return Poll::Pending; + } + } + } + } +} diff --git a/crates/preimage/src/traits.rs b/crates/preimage/src/traits.rs index bede547e2..c62d1c5bc 100644 --- a/crates/preimage/src/traits.rs +++ b/crates/preimage/src/traits.rs @@ -1,10 +1,12 @@ use crate::PreimageKey; use alloc::{boxed::Box, string::String, vec::Vec}; use anyhow::Result; +use async_trait::async_trait; use core::future::Future; /// A [PreimageOracleClient] is a high-level interface to read data from the host, keyed by a /// [PreimageKey]. +#[async_trait] pub trait PreimageOracleClient { /// Get the data corresponding to the currently set key from the host. Return the data in a new /// heap allocated `Vec` @@ -12,7 +14,7 @@ pub trait PreimageOracleClient { /// # Returns /// - `Ok(Vec)` if the data was successfully fetched from the host. /// - `Err(_)` if the data could not be fetched from the host. - fn get(&self, key: PreimageKey) -> Result>; + async fn get(&self, key: PreimageKey) -> Result>; /// Get the data corresponding to the currently set key from the host. Writes the data into the /// provided buffer. @@ -20,11 +22,12 @@ pub trait PreimageOracleClient { /// # Returns /// - `Ok(())` if the data was successfully written into the buffer. /// - `Err(_)` if the data could not be written into the buffer. - fn get_exact(&self, key: PreimageKey, buf: &mut [u8]) -> Result<()>; + async fn get_exact(&self, key: PreimageKey, buf: &mut [u8]) -> Result<()>; } /// A [HintWriterClient] is a high-level interface to the hint pipe. It provides a way to write /// hints to the host. +#[async_trait] pub trait HintWriterClient { /// Write a hint to the host. This will overwrite any existing hint in the pipe, and block until /// all data has been written. @@ -32,12 +35,12 @@ pub trait HintWriterClient { /// # Returns /// - `Ok(())` if the hint was successfully written to the host. /// - `Err(_)` if the hint could not be written to the host. - fn write(&self, hint: &str) -> Result<()>; + async fn write(&self, hint: &str) -> Result<()>; } /// A [PreimageOracleServer] is a high-level interface to accept read requests from the client and /// write the preimage data to the client pipe. -#[async_trait::async_trait] +#[async_trait] pub trait PreimageOracleServer { /// Get the next preimage request and return the response to the client. /// @@ -52,7 +55,7 @@ pub trait PreimageOracleServer { /// A [HintReaderServer] is a high-level interface to read preimage hints from the /// [HintWriterClient] and prepare them for consumption by the client program. -#[async_trait::async_trait] +#[async_trait] pub trait HintReaderServer { /// Get the next hint request and return the acknowledgement to the client. ///