diff --git a/Cargo.lock b/Cargo.lock index 983282c40..3ae0611a8 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1652,6 +1652,7 @@ version = "0.0.1" dependencies = [ "alloy-primitives", "anyhow", + "async-trait", "cfg-if", "kona-common", "tempfile", diff --git a/crates/derive/Cargo.toml b/crates/derive/Cargo.toml index 667233eb0..3abe1afc5 100644 --- a/crates/derive/Cargo.toml +++ b/crates/derive/Cargo.toml @@ -23,7 +23,7 @@ kona-primitives = { path = "../primitives", version = "0.0.1" } alloy-sol-types = { version = "0.7.1", default-features = false } op-alloy-consensus = { git = "https://github.com/clabby/op-alloy", branch = "refcell/consensus-port", default-features = false } alloy-eips = { git = "https://github.com/alloy-rs/alloy", rev = "e3f2f07", default-features = false } -async-trait = "0.1.77" +async-trait = "0.1.80" hashbrown = "0.14.3" unsigned-varint = "0.8.0" miniz_oxide = { version = "0.7.2" } diff --git a/crates/plasma/Cargo.toml b/crates/plasma/Cargo.toml index 8b5fd8cb7..3ce24311d 100644 --- a/crates/plasma/Cargo.toml +++ b/crates/plasma/Cargo.toml @@ -20,7 +20,7 @@ kona-derive = { path = "../derive" } # External alloy-consensus = { git = "https://github.com/alloy-rs/alloy", rev = "e3f2f07", default-features = false } alloy-primitives = { workspace = true, features = ["rlp"] } -async-trait = "0.1.77" +async-trait = "0.1.80" # `serde` feature dependencies serde = { version = "1.0.197", default-features = false, features = ["derive"], optional = true } diff --git a/crates/preimage/Cargo.toml b/crates/preimage/Cargo.toml index 975b2fe83..480474060 100644 --- a/crates/preimage/Cargo.toml +++ b/crates/preimage/Cargo.toml @@ -18,6 +18,9 @@ alloy-primitives.workspace = true # local kona-common = { path = "../common", version = "0.0.1" } +# external +async-trait = "0.1.80" + [dev-dependencies] tokio = { version = "1.36.0", features = ["full"] } tempfile = "3.10.0" diff --git a/crates/preimage/src/hint.rs b/crates/preimage/src/hint.rs index 4ed6159b5..35a2f07b8 100644 --- a/crates/preimage/src/hint.rs +++ b/crates/preimage/src/hint.rs @@ -1,6 +1,7 @@ use crate::{traits::HintWriterClient, HintReaderServer, PipeHandle}; -use alloc::{string::String, vec}; +use alloc::{boxed::Box, string::String, vec}; use anyhow::Result; +use core::future::Future; use tracing::{debug, error}; /// A [HintWriter] is a high-level interface to the hint pipe. It provides a way to write hints to @@ -58,8 +59,13 @@ impl HintReader { } } +#[async_trait::async_trait] impl HintReaderServer for HintReader { - fn next_hint(&self, mut route_hint: impl FnMut(String) -> Result<()>) -> Result<()> { + async fn next_hint(&mut self, mut route_hint: F) -> Result<()> + where + F: FnMut(String) -> Fut + Send, + Fut: Future> + Send, + { // Read the length of the raw hint payload. let mut len_buf = [0u8; 4]; self.pipe_handle.read_exact(&mut len_buf)?; @@ -74,7 +80,7 @@ impl HintReaderServer for HintReader { debug!(target: "hint_reader", "Successfully read hint: \"{payload}\""); // Route the hint - if let Err(e) = route_hint(payload) { + if let Err(e) = route_hint(payload).await { // Write back on error to prevent blocking the client. self.pipe_handle.write(&[0x00])?; @@ -90,15 +96,18 @@ impl HintReaderServer for HintReader { Ok(()) } } + #[cfg(test)] mod test { extern crate std; use super::*; - use alloc::vec::Vec; + use alloc::{sync::Arc, vec::Vec}; + use core::pin::Pin; use kona_common::FileDescriptor; use std::{fs::File, os::fd::AsRawFd}; use tempfile::tempfile; + use tokio::sync::Mutex; /// Test struct containing the [HintReader] and [HintWriter]. The [File]s are stored in this /// struct so that they are not dropped until the end of the test. @@ -132,20 +141,27 @@ mod test { const MOCK_DATA: &str = "test-hint 0xfacade"; let sys = client_and_host(); - let (hint_writer, hint_reader) = (sys.hint_writer, sys.hint_reader); + let (hint_writer, mut hint_reader) = (sys.hint_writer, sys.hint_reader); + let incoming_hints = Arc::new(Mutex::new(Vec::new())); let client = tokio::task::spawn(async move { hint_writer.write(MOCK_DATA) }); - let host = tokio::task::spawn(async move { - let mut v = Vec::new(); - let route_hint = |hint: String| { - v.push(hint.clone()); - Ok(()) - }; - hint_reader.next_hint(route_hint).unwrap(); - - assert_eq!(v.len(), 1); - - v.remove(0) + let host = tokio::task::spawn({ + let incoming_hints_ref = Arc::clone(&incoming_hints); + async move { + let route_hint = + move |hint: String| -> Pin> + Send>> { + let hints = Arc::clone(&incoming_hints_ref); + Box::pin(async move { + hints.lock().await.push(hint.clone()); + Ok(()) + }) + }; + hint_reader.next_hint(&route_hint).await.unwrap(); + + let mut hints = incoming_hints.lock().await; + assert_eq!(hints.len(), 1); + hints.remove(0) + } }); let (_, h) = tokio::join!(client, host); diff --git a/crates/preimage/src/oracle.rs b/crates/preimage/src/oracle.rs index 5d05bbc94..2fa6eb3c4 100644 --- a/crates/preimage/src/oracle.rs +++ b/crates/preimage/src/oracle.rs @@ -1,6 +1,7 @@ use crate::{PipeHandle, PreimageKey, PreimageOracleClient, PreimageOracleServer}; -use alloc::vec::Vec; +use alloc::{boxed::Box, sync::Arc, vec::Vec}; use anyhow::{bail, Result}; +use core::future::Future; use tracing::debug; /// An [OracleReader] is a high-level interface to the preimage oracle. @@ -85,11 +86,13 @@ impl OracleServer { } } +#[async_trait::async_trait] impl PreimageOracleServer for OracleServer { - fn next_preimage_request<'a>( - &self, - mut get_preimage: impl FnMut(PreimageKey) -> Result<&'a Vec>, - ) -> Result<()> { + async fn next_preimage_request(&mut self, mut get_preimage: F) -> Result<()> + where + F: FnMut(PreimageKey) -> Fut + Send, + Fut: Future>>> + Send, + { // Read the preimage request from the client, and throw early if there isn't is any. let mut buf = [0u8; 32]; self.pipe_handle.read_exact(&mut buf)?; @@ -98,7 +101,7 @@ impl PreimageOracleServer for OracleServer { debug!(target: "oracle_server", "Fetching preimage for key {preimage_key}"); // Fetch the preimage value from the preimage getter. - let value = get_preimage(preimage_key)?; + let value = get_preimage(preimage_key).await?; // Write the length as a big-endian u64 followed by the data. let data = [(value.len() as u64).to_be_bytes().as_ref(), value.as_ref()] @@ -121,9 +124,11 @@ mod test { use super::*; use crate::PreimageKeyType; use alloy_primitives::keccak256; + use core::pin::Pin; use kona_common::FileDescriptor; use std::{collections::HashMap, fs::File, os::fd::AsRawFd}; use tempfile::tempfile; + use tokio::sync::Mutex; /// Test struct containing the [OracleReader] and a [OracleServer] for the host, plus the open /// [File]s. The [File]s are stored in this struct so that they are not dropped until the @@ -167,12 +172,15 @@ mod test { let key_b: PreimageKey = PreimageKey::new(*keccak256(MOCK_DATA_B), PreimageKeyType::Keccak256); - let mut preimages = HashMap::new(); - preimages.insert(key_a, MOCK_DATA_A.to_vec()); - preimages.insert(key_b, MOCK_DATA_B.to_vec()); + let preimages = { + let mut preimages = HashMap::new(); + preimages.insert(key_a, Arc::new(MOCK_DATA_A.to_vec())); + preimages.insert(key_b, Arc::new(MOCK_DATA_B.to_vec())); + Arc::new(Mutex::new(preimages)) + }; let sys = client_and_host(); - let (oracle_reader, oracle_server) = (sys.oracle_reader, sys.oracle_server); + let (oracle_reader, mut oracle_server) = (sys.oracle_reader, sys.oracle_server); let client = tokio::task::spawn(async move { let contents_a = oracle_reader.get(key_a).unwrap(); @@ -185,11 +193,24 @@ mod test { (contents_a, contents_b) }); let host = tokio::task::spawn(async move { - let get_preimage = - |key| preimages.get(&key).ok_or(anyhow::anyhow!("Preimage not available")); + #[allow(clippy::type_complexity)] + let get_preimage = move |key: PreimageKey| -> Pin< + Box>>> + Send>, + > { + let preimages = Arc::clone(&preimages); + Box::pin(async move { + // Simulate fetching preimage data + preimages + .lock() + .await + .get(&key) + .ok_or(anyhow::anyhow!("Preimage not available")) + .cloned() + }) + }; loop { - if oracle_server.next_preimage_request(get_preimage).is_err() { + if oracle_server.next_preimage_request(&get_preimage).await.is_err() { break; } } diff --git a/crates/preimage/src/traits.rs b/crates/preimage/src/traits.rs index c2a20d4eb..88f79ef0c 100644 --- a/crates/preimage/src/traits.rs +++ b/crates/preimage/src/traits.rs @@ -1,6 +1,7 @@ use crate::PreimageKey; -use alloc::{string::String, vec::Vec}; +use alloc::{boxed::Box, string::String, sync::Arc, vec::Vec}; use anyhow::Result; +use core::future::Future; /// A [PreimageOracleClient] is a high-level interface to read data from the host, keyed by a /// [PreimageKey]. @@ -36,20 +37,22 @@ pub trait HintWriterClient { /// A [PreimageOracleServer] is a high-level interface to accept read requests from the client and /// write the preimage data to the client pipe. +#[async_trait::async_trait] pub trait PreimageOracleServer { /// Get the next preimage request and return the response to the client. /// /// # Returns /// - `Ok(())` if the data was successfully written into the client pipe. /// - `Err(_)` if the data could not be written to the client. - fn next_preimage_request<'a>( - &self, - get_preimage: impl FnMut(PreimageKey) -> Result<&'a Vec>, - ) -> Result<()>; + async fn next_preimage_request(&mut self, get_preimage: F) -> Result<()> + where + F: FnMut(PreimageKey) -> Fut + Send, + Fut: Future>>> + Send; } /// A [HintReaderServer] is a high-level interface to read preimage hints from the /// [HintWriterClient] and prepare them for consumption by the client program. +#[async_trait::async_trait] pub trait HintReaderServer { /// Get the next hint request and return the acknowledgement to the client. /// @@ -57,5 +60,8 @@ pub trait HintReaderServer { /// - `Ok(())` if the hint was received and the client was notified of the host's /// acknowledgement. /// - `Err(_)` if the hint was not received correctly. - fn next_hint(&self, route_hint: impl FnMut(String) -> Result<()>) -> Result<()>; + async fn next_hint(&mut self, route_hint: F) -> Result<()> + where + F: FnMut(String) -> Fut + Send, + Fut: Future> + Send; }