diff --git a/Cargo.lock b/Cargo.lock index fd028a99c..a90b0506e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -13389,6 +13389,7 @@ name = "strata-prover-client-rpc-api" version = "0.1.0" dependencies = [ "jsonrpsee", + "strata-primitives", "strata-rpc-types", ] diff --git a/bin/prover-client/src/errors.rs b/bin/prover-client/src/errors.rs index 68cef01c9..c44414774 100644 --- a/bin/prover-client/src/errors.rs +++ b/bin/prover-client/src/errors.rs @@ -50,6 +50,14 @@ pub enum ProvingTaskError { #[error("Witness not found")] WitnessNotFound, + /// Occurs when a newly created proving task is expected but none is found. + #[error("No tasks found after creation; at least one was expected")] + NoTasksFound, + + /// Occurs when the witness data provided is invalid. + #[error("{0}")] + InvalidWitness(String), + /// Represents a generic database error. #[error("Database error: {0:?}")] DatabaseError(DbError), diff --git a/bin/prover-client/src/hosts/native.rs b/bin/prover-client/src/hosts/native.rs index 0d00af62c..86aa10a13 100644 --- a/bin/prover-client/src/hosts/native.rs +++ b/bin/prover-client/src/hosts/native.rs @@ -5,7 +5,7 @@ use strata_primitives::proof::ProofContext; use strata_proofimpl_btc_blockspace::logic::process_blockspace_proof_outer; use strata_proofimpl_checkpoint::process_checkpoint_proof_outer; use strata_proofimpl_cl_agg::process_cl_agg; -use strata_proofimpl_cl_stf::process_cl_stf; +use strata_proofimpl_cl_stf::batch_process_cl_stf; use strata_proofimpl_evm_ee_stf::process_block_transaction_outer; use strata_proofimpl_l1_batch::process_l1_batch_proof; @@ -22,37 +22,37 @@ const MOCK_VK: [u32; 8] = [0u32; 8]; /// allowing for efficient host selection for different proof types. pub fn get_host(id: &ProofContext) -> NativeHost { match id { - ProofContext::BtcBlockspace(_) => NativeHost { + ProofContext::BtcBlockspace(..) => NativeHost { process_proof: Arc::new(Box::new(move |zkvm: &NativeMachine| { process_blockspace_proof_outer(zkvm); Ok(()) })), }, - ProofContext::L1Batch(_, _) => NativeHost { + ProofContext::L1Batch(..) => NativeHost { process_proof: Arc::new(Box::new(move |zkvm: &NativeMachine| { process_l1_batch_proof(zkvm, &MOCK_VK); Ok(()) })), }, - ProofContext::EvmEeStf(_, _) => NativeHost { + ProofContext::EvmEeStf(..) => NativeHost { process_proof: Arc::new(Box::new(move |zkvm: &NativeMachine| { process_block_transaction_outer(zkvm); Ok(()) })), }, - ProofContext::ClStf(_) => NativeHost { + ProofContext::ClStf(..) => NativeHost { process_proof: Arc::new(Box::new(move |zkvm: &NativeMachine| { - process_cl_stf(zkvm, &MOCK_VK); + batch_process_cl_stf(zkvm, &MOCK_VK); Ok(()) })), }, - ProofContext::ClAgg(_, _) => NativeHost { + ProofContext::ClAgg(..) => NativeHost { process_proof: Arc::new(Box::new(move |zkvm: &NativeMachine| { process_cl_agg(zkvm, &MOCK_VK); Ok(()) })), }, - ProofContext::Checkpoint(_) => NativeHost { + ProofContext::Checkpoint(..) => NativeHost { process_proof: Arc::new(Box::new(move |zkvm: &NativeMachine| { process_checkpoint_proof_outer(zkvm, &MOCK_VK, &MOCK_VK); Ok(()) diff --git a/bin/prover-client/src/hosts/risc0.rs b/bin/prover-client/src/hosts/risc0.rs index 8d342f4ed..dfea06ab0 100644 --- a/bin/prover-client/src/hosts/risc0.rs +++ b/bin/prover-client/src/hosts/risc0.rs @@ -32,11 +32,11 @@ static CHECKPOINT_HOST: LazyLock = /// instance, allowing for efficient host selection for different proof types. pub fn get_host(id: &ProofContext) -> &'static Risc0Host { match id { - ProofContext::BtcBlockspace(_) => &BTC_BLOCKSPACE_HOST, - ProofContext::L1Batch(_, _) => &L1_BATCH_HOST, - ProofContext::EvmEeStf(_, _) => &EVM_EE_STF_HOST, - ProofContext::ClStf(_) => &CL_STF_HOST, - ProofContext::ClAgg(_, _) => &CL_AGG_HOST, - ProofContext::Checkpoint(_) => &CHECKPOINT_HOST, + ProofContext::BtcBlockspace(..) => &BTC_BLOCKSPACE_HOST, + ProofContext::L1Batch(..) => &L1_BATCH_HOST, + ProofContext::EvmEeStf(..) => &EVM_EE_STF_HOST, + ProofContext::ClStf(..) => &CL_STF_HOST, + ProofContext::ClAgg(..) => &CL_AGG_HOST, + ProofContext::Checkpoint(..) => &CHECKPOINT_HOST, } } diff --git a/bin/prover-client/src/hosts/sp1.rs b/bin/prover-client/src/hosts/sp1.rs index f4e808a01..63666ad16 100644 --- a/bin/prover-client/src/hosts/sp1.rs +++ b/bin/prover-client/src/hosts/sp1.rs @@ -58,11 +58,11 @@ pub static CHECKPOINT_HOST: LazyLock = std::sync::LazyLock::new(|| { /// instance, allowing for efficient host selection for different proof types. pub fn get_host(id: &ProofContext) -> &'static SP1Host { match id { - ProofContext::BtcBlockspace(_) => &BTC_BLOCKSPACE_HOST, - ProofContext::L1Batch(_, _) => &L1_BATCH_HOST, - ProofContext::EvmEeStf(_, _) => &EVM_EE_STF_HOST, - ProofContext::ClStf(_) => &CL_STF_HOST, - ProofContext::ClAgg(_, _) => &CL_AGG_HOST, - ProofContext::Checkpoint(_) => &CHECKPOINT_HOST, + ProofContext::BtcBlockspace(..) => &BTC_BLOCKSPACE_HOST, + ProofContext::L1Batch(..) => &L1_BATCH_HOST, + ProofContext::EvmEeStf(..) => &EVM_EE_STF_HOST, + ProofContext::ClStf(..) => &CL_STF_HOST, + ProofContext::ClAgg(..) => &CL_AGG_HOST, + ProofContext::Checkpoint(..) => &CHECKPOINT_HOST, } } diff --git a/bin/prover-client/src/operators/checkpoint.rs b/bin/prover-client/src/operators/checkpoint.rs index 2c199efc5..85a6b796a 100644 --- a/bin/prover-client/src/operators/checkpoint.rs +++ b/bin/prover-client/src/operators/checkpoint.rs @@ -3,6 +3,7 @@ use std::sync::Arc; use jsonrpsee::http_client::HttpClient; use strata_db::traits::ProofDatabase; use strata_primitives::{ + buf::Buf32, params::RollupParams, proof::{ProofContext, ProofKey}, }; @@ -10,6 +11,7 @@ use strata_proofimpl_checkpoint::prover::{CheckpointProver, CheckpointProverInpu use strata_rocksdb::prover::db::ProofDb; use strata_rpc_api::StrataApiClient; use strata_rpc_types::RpcCheckpointInfo; +use strata_state::id::L2BlockId; use strata_zkvm::AggregationInput; use tokio::sync::Mutex; use tracing::error; @@ -59,6 +61,31 @@ impl CheckpointOperator { .ok_or(ProvingTaskError::WitnessNotFound) } + /// Retrieves the [`L2BlockId`] for the given `block_num` + pub async fn get_l2id(&self, block_num: u64) -> Result { + let l2_headers = self + .cl_client + .get_headers_at_idx(block_num) + .await + .inspect_err(|_| error!(%block_num, "Failed to fetch l2_headers")) + .map_err(|e| ProvingTaskError::RpcError(e.to_string()))?; + + let headers = l2_headers.ok_or_else(|| { + error!(%block_num, "Failed to fetch L2 block"); + ProvingTaskError::InvalidWitness(format!("Invalid L2 block height {}", block_num)) + })?; + + let first_header: Buf32 = headers + .first() + .ok_or_else(|| { + ProvingTaskError::InvalidWitness(format!("Invalid L2 block height {}", block_num)) + })? + .block_id + .into(); + + Ok(first_header.into()) + } + /// Retrieves the latest checkpoint index pub async fn fetch_latest_ckp_idx(&self) -> Result { self.cl_client @@ -88,13 +115,27 @@ impl ProvingOp for CheckpointOperator { .l1_batch_operator .create_task(checkpoint_info.l1_range, task_tracker.clone(), db) .await?; - let l1_batch_id = l1_batch_keys.first().expect("at least one").context(); + let l1_batch_id = l1_batch_keys + .first() + .ok_or_else(|| ProvingTaskError::NoTasksFound)? + .context(); + + // Doing the manual block idx to id transformation. Will be removed once checkpoint_info + // include the range in terms of block_id. + // https://alpenlabs.atlassian.net/browse/STR-756 + let start_l2_idx = self.get_l2id(checkpoint_info.l2_range.0).await?; + let end_l2_idx = self.get_l2id(checkpoint_info.l2_range.1).await?; + let l2_range = vec![(start_l2_idx, end_l2_idx)]; let l2_batch_keys = self .l2_batch_operator - .create_task(checkpoint_info.l2_range, task_tracker.clone(), db) + .create_task(l2_range, task_tracker.clone(), db) .await?; - let l2_batch_id = l2_batch_keys.first().expect("at least one").context(); + + let l2_batch_id = l2_batch_keys + .first() + .ok_or_else(|| ProvingTaskError::NoTasksFound)? + .context(); let deps = vec![*l1_batch_id, *l2_batch_id]; diff --git a/bin/prover-client/src/operators/cl_agg.rs b/bin/prover-client/src/operators/cl_agg.rs index b9a3b5025..4b0078e31 100644 --- a/bin/prover-client/src/operators/cl_agg.rs +++ b/bin/prover-client/src/operators/cl_agg.rs @@ -4,7 +4,9 @@ use strata_db::traits::ProofDatabase; use strata_primitives::proof::{ProofContext, ProofKey}; use strata_proofimpl_cl_agg::{ClAggInput, ClAggProver}; use strata_rocksdb::prover::db::ProofDb; +use strata_state::id::L2BlockId; use tokio::sync::Mutex; +use tracing::error; use super::{cl_stf::ClStfOperator, ProvingOp}; use crate::{errors::ProvingTaskError, hosts, task_tracker::TaskTracker}; @@ -29,28 +31,33 @@ impl ClAggOperator { impl ProvingOp for ClAggOperator { type Prover = ClAggProver; - type Params = (u64, u64); + type Params = Vec<(L2BlockId, L2BlockId)>; async fn create_task( &self, - params: (u64, u64), + batches: Self::Params, task_tracker: Arc>, db: &ProofDb, ) -> Result, ProvingTaskError> { - let (start_height, end_height) = params; + let mut cl_stf_deps = Vec::with_capacity(batches.len()); - let len = (end_height - start_height) as usize + 1; - let mut cl_stf_deps = Vec::with_capacity(len); + // Extract first and last block IDs from batches, error if empty + let (start_blkid, end_blkid) = match (batches.first(), batches.last()) { + (Some(first), Some(last)) => (first.0, last.1), + _ => { + error!("Aggregation task with empty batch"); + return Err(ProvingTaskError::InvalidInput( + "Aggregation task with empty batch".into(), + )); + } + }; - let start_blkid = self.cl_stf_operator.get_id(start_height).await?; - let end_blkid = self.cl_stf_operator.get_id(end_height).await?; let cl_agg_proof_id = ProofContext::ClAgg(start_blkid, end_blkid); - for height in start_height..=end_height { - let blkid = self.cl_stf_operator.get_id(height).await?; - let proof_id = ProofContext::ClStf(blkid); + for (start_blkid, end_blkid) in batches { + let proof_id = ProofContext::ClStf(start_blkid, end_blkid); self.cl_stf_operator - .create_task(height, task_tracker.clone(), db) + .create_task((start_blkid, end_blkid), task_tracker.clone(), db) .await?; cl_stf_deps.push(proof_id); } @@ -67,7 +74,7 @@ impl ProvingOp for ClAggOperator { task_id: &ProofKey, db: &ProofDb, ) -> Result { - let (start_blkid, _) = match task_id.context() { + let (start_blkid, end_blkid) = match task_id.context() { ProofContext::ClAgg(start, end) => (start, end), _ => return Err(ProvingTaskError::InvalidInput("ClAgg".to_string())), }; @@ -88,7 +95,7 @@ impl ProvingOp for ClAggOperator { } let cl_stf_vk = hosts::get_verification_key(&ProofKey::new( - ProofContext::ClStf(*start_blkid), + ProofContext::ClStf(*start_blkid, *end_blkid), *task_id.host(), )); Ok(ClAggInput { batch, cl_stf_vk }) diff --git a/bin/prover-client/src/operators/cl_stf.rs b/bin/prover-client/src/operators/cl_stf.rs index 04bd8f5e9..b02d41618 100644 --- a/bin/prover-client/src/operators/cl_stf.rs +++ b/bin/prover-client/src/operators/cl_stf.rs @@ -10,6 +10,7 @@ use strata_primitives::{ use strata_proofimpl_cl_stf::prover::{ClStfInput, ClStfProver}; use strata_rocksdb::prover::db::ProofDb; use strata_rpc_api::StrataApiClient; +use strata_rpc_types::RpcBlockHeader; use strata_state::id::L2BlockId; use tokio::sync::Mutex; use tracing::error; @@ -49,56 +50,97 @@ impl ClStfOperator { } } - /// Retrieves the [`L2BlockId`] for the given `block_num` - pub async fn get_id(&self, block_num: u64) -> Result { - let l2_headers = self + async fn get_l2_block_header( + &self, + blkid: L2BlockId, + ) -> Result { + let header = self .cl_client - .get_headers_at_idx(block_num) + .get_header_by_id(blkid) .await - .inspect_err(|_| error!(%block_num, "Failed to fetch l2_headers")) - .map_err(|e| ProvingTaskError::RpcError(e.to_string()))?; + .inspect_err(|_| error!(%blkid, "Failed to fetch corresponding ee data")) + .map_err(|e| ProvingTaskError::RpcError(e.to_string()))? + .ok_or_else(|| { + error!(%blkid, "L2 Block not found"); + ProvingTaskError::InvalidWitness(format!("L2 Block {} not found", blkid)) + })?; - let cl_stf_id_buf: Buf32 = l2_headers - .expect("invalid height") - .first() - .expect("at least one l2 blockid") - .block_id - .into(); - Ok(cl_stf_id_buf.into()) + Ok(header) } - /// Retrieves the slot num of the given [`L2BlockId`] - pub async fn get_slot(&self, id: L2BlockId) -> Result { - let header = self + /// Retrieves the evm_ee block hash corresponding to the given L2 block ID + pub async fn get_exec_id(&self, cl_block_id: L2BlockId) -> Result { + let header = self.get_l2_block_header(cl_block_id).await?; + let block = self.evm_ee_operator.get_block(header.block_idx).await?; + Ok(block.header.hash.into()) + } + + /// Retrieves the specified number of ancestor block IDs for the given block ID. + pub async fn get_block_ancestors( + &self, + blkid: L2BlockId, + n_ancestors: u64, + ) -> Result, ProvingTaskError> { + let mut ancestors = Vec::with_capacity(n_ancestors as usize); + let mut blkid = blkid; + for _ in 0..=n_ancestors { + blkid = self.get_prev_block_id(blkid).await?; + ancestors.push(blkid); + } + Ok(ancestors) + } + + /// Retrieves the previous [`L2BlockId`] for the given `L2BlockId` + pub async fn get_prev_block_id(&self, blkid: L2BlockId) -> Result { + let l2_block = self .cl_client - .get_header_by_id(id) + .get_header_by_id(blkid) .await - .map_err(|e| ProvingTaskError::RpcError(e.to_string()))? - .expect("invalid blkid"); - Ok(header.block_idx) + .inspect_err(|_| error!(%blkid, "Failed to fetch l2_header")) + .map_err(|e| ProvingTaskError::RpcError(e.to_string()))?; + + let prev_block: Buf32 = l2_block + .ok_or_else(|| { + error!(%blkid, "L2 Block not found"); + ProvingTaskError::InvalidWitness(format!("L2 Block {} not found", blkid)) + })? + .prev_block + .into(); + + Ok(prev_block.into()) } } impl ProvingOp for ClStfOperator { type Prover = ClStfProver; - type Params = u64; + type Params = (L2BlockId, L2BlockId); async fn create_task( &self, - block_num: u64, + block_range: Self::Params, task_tracker: Arc>, db: &ProofDb, ) -> Result, ProvingTaskError> { + let (start_block_id, end_block_id) = block_range; + + let el_start_block_id = self.get_exec_id(start_block_id).await?; + let el_end_block_id = self.get_exec_id(end_block_id).await?; + let evm_ee_tasks = self .evm_ee_operator - .create_task((block_num, block_num), task_tracker.clone(), db) + .create_task( + (el_start_block_id, el_end_block_id), + task_tracker.clone(), + db, + ) .await?; + let evm_ee_id = evm_ee_tasks .first() - .expect("creation of task should result on at least one key") + .ok_or_else(|| ProvingTaskError::NoTasksFound)? .context(); - let cl_stf_id = ProofContext::ClStf(self.get_id(block_num).await?); + let cl_stf_id = ProofContext::ClStf(start_block_id, end_block_id); db.put_proof_deps(cl_stf_id, vec![*evm_ee_id]) .map_err(ProvingTaskError::DatabaseError)?; @@ -112,18 +154,28 @@ impl ProvingOp for ClStfOperator { task_id: &ProofKey, db: &ProofDb, ) -> Result { - let block_id = match task_id.context() { - ProofContext::ClStf(id) => id, - _ => return Err(ProvingTaskError::InvalidInput("EvmEe".to_string())), + let (start_block_hash, end_block_hash) = match task_id.context() { + ProofContext::ClStf(start, end) => (*start, *end), + _ => return Err(ProvingTaskError::InvalidInput("CL_STF".to_string())), }; - let block_num = self.get_slot(*block_id).await?; - let raw_witness: Option> = self - .cl_client - .get_cl_block_witness_raw(block_num) - .await - .map_err(|e| ProvingTaskError::RpcError(e.to_string()))?; - let witness = raw_witness.ok_or(ProvingTaskError::WitnessNotFound)?; - let (pre_state, l2_block) = borsh::from_slice(&witness)?; + + let start_block = self.get_l2_block_header(start_block_hash).await?; + let end_block = self.get_l2_block_header(end_block_hash).await?; + let num_blocks = end_block.block_idx - start_block.block_idx; + + // Get ancestor blocks and reverse to oldest-first order + let mut l2_block_ids = self.get_block_ancestors(end_block_hash, num_blocks).await?; + l2_block_ids.reverse(); + + let mut stf_witness_payloads = Vec::new(); + for l2_block_id in l2_block_ids { + let raw_witness: Vec = self + .cl_client + .get_cl_block_witness_raw(l2_block_id) + .await + .map_err(|e| ProvingTaskError::RpcError(e.to_string()))?; + stf_witness_payloads.push(raw_witness); + } let evm_ee_ids = db .get_proof_deps(*task_id.context()) @@ -131,7 +183,7 @@ impl ProvingOp for ClStfOperator { .ok_or(ProvingTaskError::DependencyNotFound(*task_id))?; let evm_ee_id = evm_ee_ids .first() - .expect("should have at least a dependency"); + .ok_or_else(|| ProvingTaskError::NoTasksFound)?; let evm_ee_key = ProofKey::new(*evm_ee_id, *task_id.host()); let evm_ee_proof = db .get_proof(evm_ee_key) @@ -142,8 +194,7 @@ impl ProvingOp for ClStfOperator { let rollup_params = self.rollup_params.as_ref().clone(); Ok(ClStfInput { rollup_params, - pre_state, - l2_block, + stf_witness_payloads, evm_ee_proof, evm_ee_vk, }) diff --git a/bin/prover-client/src/operators/evm_ee.rs b/bin/prover-client/src/operators/evm_ee.rs index 13e7010ef..f9d7ce534 100644 --- a/bin/prover-client/src/operators/evm_ee.rs +++ b/bin/prover-client/src/operators/evm_ee.rs @@ -58,22 +58,15 @@ impl EvmEeOperator { impl ProvingOp for EvmEeOperator { type Prover = EvmEeProver; - type Params = (u64, u64); + type Params = (Buf32, Buf32); async fn create_task( &self, - block_range: (u64, u64), + block_range: Self::Params, task_tracker: Arc>, _db: &ProofDb, ) -> Result, ProvingTaskError> { - let (start_block_num, end_block_num) = block_range; - - let start_block = self.get_block(start_block_num).await?; - let start_blkid: Buf32 = start_block.header.hash.into(); - - let end_block = self.get_block(end_block_num).await?; - let end_blkid: Buf32 = end_block.header.hash.into(); - + let (start_blkid, end_blkid) = block_range; let context = ProofContext::EvmEeStf(start_blkid, end_blkid); let mut task_tracker = task_tracker.lock().await; diff --git a/bin/prover-client/src/operators/operator.rs b/bin/prover-client/src/operators/operator.rs index b5c42c460..d31b22f76 100644 --- a/bin/prover-client/src/operators/operator.rs +++ b/bin/prover-client/src/operators/operator.rs @@ -128,7 +128,9 @@ impl ProofOperator { ProofContext::EvmEeStf(_, _) => { Self::prove(&self.evm_ee_operator, proof_key, db, host).await } - ProofContext::ClStf(_) => Self::prove(&self.cl_stf_operator, proof_key, db, host).await, + ProofContext::ClStf(_, _) => { + Self::prove(&self.cl_stf_operator, proof_key, db, host).await + } ProofContext::ClAgg(_, _) => { Self::prove(&self.cl_agg_operator, proof_key, db, host).await } diff --git a/bin/prover-client/src/rpc_server.rs b/bin/prover-client/src/rpc_server.rs index 7602f1a0d..1939b42a2 100644 --- a/bin/prover-client/src/rpc_server.rs +++ b/bin/prover-client/src/rpc_server.rs @@ -5,9 +5,11 @@ use std::sync::Arc; use anyhow::Context; use async_trait::async_trait; use jsonrpsee::{core::RpcResult, RpcModule}; +use strata_primitives::buf::Buf32; use strata_prover_client_rpc_api::StrataProverClientApiServer; use strata_rocksdb::prover::db::ProofDb; use strata_rpc_types::ProofKey; +use strata_state::id::L2BlockId; use tokio::sync::{oneshot, Mutex}; use tracing::{info, warn}; @@ -87,7 +89,7 @@ impl StrataProverClientApiServer for ProverClientRpc { .expect("failed to create task")) } - async fn prove_el_block(&self, el_block_range: (u64, u64)) -> RpcResult> { + async fn prove_el_blocks(&self, el_block_range: (Buf32, Buf32)) -> RpcResult> { Ok(self .operator .evm_ee_operator() @@ -96,11 +98,14 @@ impl StrataProverClientApiServer for ProverClientRpc { .expect("failed to create task")) } - async fn prove_cl_block(&self, cl_block_num: u64) -> RpcResult> { + async fn prove_cl_blocks( + &self, + cl_block_range: (L2BlockId, L2BlockId), + ) -> RpcResult> { Ok(self .operator .cl_stf_operator() - .create_task(cl_block_num, self.task_tracker.clone(), &self.db) + .create_task(cl_block_range, self.task_tracker.clone(), &self.db) .await .expect("failed to create task")) } @@ -114,7 +119,10 @@ impl StrataProverClientApiServer for ProverClientRpc { .expect("failed to create task")) } - async fn prove_l2_batch(&self, l2_range: (u64, u64)) -> RpcResult> { + async fn prove_l2_batch( + &self, + l2_range: Vec<(L2BlockId, L2BlockId)>, + ) -> RpcResult> { Ok(self .operator .cl_agg_operator() diff --git a/bin/strata-client/src/rpc_server.rs b/bin/strata-client/src/rpc_server.rs index ccc045745..e0ae05ac2 100644 --- a/bin/strata-client/src/rpc_server.rs +++ b/bin/strata-client/src/rpc_server.rs @@ -299,23 +299,7 @@ impl StrataApiServer for StrataRpcImpl { } } - async fn get_cl_block_witness_raw(&self, idx: u64) -> RpcResult>> { - let blk_manifest_db = self.database.clone(); - let blk_ids: Vec = wait_blocking("l2_blockid", move || { - blk_manifest_db - .clone() - .l2_db() - .get_blocks_at_height(idx) - .map_err(Error::Db) - }) - .await?; - - // Check if blk_ids is empty - let blkid = match blk_ids.first() { - Some(id) => id.to_owned(), - None => return Ok(None), - }; - + async fn get_cl_block_witness_raw(&self, blkid: L2BlockId) -> RpcResult> { let l2_blk_db = self.database.clone(); let l2_blk_bundle = wait_blocking("l2_block", move || { let l2_db = l2_blk_db.l2_db(); @@ -323,14 +307,16 @@ impl StrataApiServer for StrataRpcImpl { }) .await?; + let prev_slot = l2_blk_bundle.block().header().header().blockidx() - 1; + let chain_state_db = self.database.clone(); let chain_state = wait_blocking("l2_chain_state", move || { let chs_db = chain_state_db.chain_state_db(); chs_db - .get_toplevel_state(idx - 1) + .get_toplevel_state(prev_slot) .map_err(Error::Db)? - .ok_or(Error::MissingChainstate(idx - 1)) + .ok_or(Error::MissingChainstate(prev_slot)) }) .await?; @@ -338,7 +324,7 @@ impl StrataApiServer for StrataRpcImpl { let raw_cl_block_witness = borsh::to_vec(&cl_block_witness) .map_err(|_| Error::Other("Failed to get raw cl block witness".to_string()))?; - Ok(Some(raw_cl_block_witness)) + Ok(raw_cl_block_witness) } async fn get_current_deposits(&self) -> RpcResult> { diff --git a/crates/primitives/src/proof.rs b/crates/primitives/src/proof.rs index 45f3c28ea..388bdfad4 100644 --- a/crates/primitives/src/proof.rs +++ b/crates/primitives/src/proof.rs @@ -59,9 +59,9 @@ pub enum ProofContext { /// Transition Function (STF) proof. EvmEeStf(Buf32, Buf32), - /// Identifier for the Consensus Layer (CL) block used in generating the State Transition + /// Identifier for the Consensus Layer (CL) blocks used in generating the State Transition /// Function (STF) proof. - ClStf(L2BlockId), + ClStf(L2BlockId, L2BlockId), /// Identifier for a batch of Consensus Layer (CL) blocks being proven. /// Includes the starting and ending block heights. diff --git a/crates/proof-impl/cl-stf/src/lib.rs b/crates/proof-impl/cl-stf/src/lib.rs index 6eee08d35..f6a822fa6 100644 --- a/crates/proof-impl/cl-stf/src/lib.rs +++ b/crates/proof-impl/cl-stf/src/lib.rs @@ -84,23 +84,16 @@ fn apply_state_transition( state_cache.state().to_owned() } -pub fn process_cl_stf(zkvm: &impl ZkVmEnv, el_vkey: &[u32; 8]) { - let rollup_params: RollupParams = zkvm.read_serde(); - let (prev_state, block): (Chainstate, L2Block) = zkvm.read_borsh(); - let el_pp_deserialized: Vec = zkvm.read_verified_borsh(el_vkey); - - // The CL block currently includes only a single ExecSegment - assert_eq!( - el_pp_deserialized.len(), - 1, - "execsegment: expected exactly one" - ); - - let exec_update = el_pp_deserialized - .first() - .expect("execsegment: failed to fetch the first"); - - let new_state = verify_and_transition(prev_state.clone(), block, exec_update, &rollup_params); +#[inline] +fn process_cl_stf( + prev_state: Chainstate, + new_block: L2Block, + exec_update: &ExecSegment, + rollup_params: &RollupParams, + rollup_params_commitment: &Buf32, +) -> L2BatchProofOutput { + let new_state = + verify_and_transition(prev_state.clone(), new_block, exec_update, rollup_params); let initial_snapshot = ChainStateSnapshot { hash: prev_state.compute_state_root(), @@ -114,13 +107,65 @@ pub fn process_cl_stf(zkvm: &impl ZkVmEnv, el_vkey: &[u32; 8]) { l2_blockid: new_state.chain_tip_blockid(), }; - let cl_stf_public_params = L2BatchProofOutput { + L2BatchProofOutput { // TODO: Accumulate the deposits deposits: Vec::new(), - final_snapshot, initial_snapshot, - rollup_params_commitment: rollup_params.compute_hash(), + final_snapshot, + rollup_params_commitment: *rollup_params_commitment, + } +} + +pub fn batch_process_cl_stf(zkvm: &impl ZkVmEnv, el_vkey: &[u32; 8]) { + let rollup_params: RollupParams = zkvm.read_serde(); + let exec_updates: Vec = zkvm.read_verified_borsh(el_vkey); + let num_blocks: u32 = zkvm.read_serde(); + + assert!(num_blocks > 0, "At least one block is required."); + assert_eq!( + num_blocks as usize, + exec_updates.len(), + "Number of blocks and execution updates differ." + ); + + let (prev_state, new_block): (Chainstate, L2Block) = zkvm.read_borsh(); + let rollup_params_commitment = rollup_params.compute_hash(); + let initial_cl_update = process_cl_stf( + prev_state, + new_block, + &exec_updates[0], + &rollup_params, + &rollup_params_commitment, + ); + + let mut deposits = initial_cl_update.deposits.clone(); + let mut cl_update_acc = initial_cl_update.clone(); + + for exec_update in &exec_updates[1..] { + let (prev_state, new_block): (Chainstate, L2Block) = zkvm.read_borsh(); + let cl_update = process_cl_stf( + prev_state, + new_block, + exec_update, + &rollup_params, + &rollup_params_commitment, + ); + + assert_eq!( + cl_update.initial_snapshot.hash, cl_update_acc.final_snapshot.hash, + "Snapshot hash mismatch between consecutive updates." + ); + + deposits.extend_from_slice(&cl_update.deposits); + cl_update_acc = cl_update; + } + + let output = L2BatchProofOutput { + deposits, + initial_snapshot: initial_cl_update.initial_snapshot, + final_snapshot: cl_update_acc.final_snapshot, + rollup_params_commitment: cl_update_acc.rollup_params_commitment, }; - zkvm.commit_borsh(&cl_stf_public_params); + zkvm.commit_borsh(&output); } diff --git a/crates/proof-impl/cl-stf/src/prover.rs b/crates/proof-impl/cl-stf/src/prover.rs index 2c318825e..79af2f3b3 100644 --- a/crates/proof-impl/cl-stf/src/prover.rs +++ b/crates/proof-impl/cl-stf/src/prover.rs @@ -1,5 +1,4 @@ use strata_primitives::params::RollupParams; -use strata_state::{block::L2Block, chain_state::Chainstate}; use strata_zkvm::{ AggregationInput, ProofReceipt, PublicValues, VerificationKey, ZkVmInputResult, ZkVmProver, ZkVmResult, @@ -9,8 +8,7 @@ use crate::L2BatchProofOutput; pub struct ClStfInput { pub rollup_params: RollupParams, - pub pre_state: Chainstate, - pub l2_block: L2Block, + pub stf_witness_payloads: Vec>, pub evm_ee_proof: ProofReceipt, pub evm_ee_vk: VerificationKey, } @@ -29,14 +27,19 @@ impl ZkVmProver for ClStfProver { where B: strata_zkvm::ZkVmInputBuilder<'a>, { - B::new() - .write_serde(&input.rollup_params)? - .write_borsh(&(&input.pre_state, &input.l2_block))? - .write_proof(&AggregationInput::new( - input.evm_ee_proof.clone(), - input.evm_ee_vk.clone(), - ))? - .build() + let mut input_builder = B::new(); + input_builder.write_serde(&input.rollup_params)?; + input_builder.write_proof(&AggregationInput::new( + input.evm_ee_proof.clone(), + input.evm_ee_vk.clone(), + ))?; + + input_builder.write_serde(&input.stf_witness_payloads.len())?; + for cl_stf_input in &input.stf_witness_payloads { + input_builder.write_buf(cl_stf_input)?; + } + + input_builder.build() } fn process_output(public_values: &PublicValues) -> ZkVmResult diff --git a/crates/rpc/api/src/lib.rs b/crates/rpc/api/src/lib.rs index 4a72cf7c1..e2f96ab00 100644 --- a/crates/rpc/api/src/lib.rs +++ b/crates/rpc/api/src/lib.rs @@ -45,7 +45,7 @@ pub trait StrataApi { async fn get_exec_update_by_id(&self, block_id: L2BlockId) -> RpcResult>; #[method(name = "getCLBlockWitness")] - async fn get_cl_block_witness_raw(&self, index: u64) -> RpcResult>>; + async fn get_cl_block_witness_raw(&self, block_id: L2BlockId) -> RpcResult>; #[method(name = "getCurrentDeposits")] async fn get_current_deposits(&self) -> RpcResult>; diff --git a/crates/rpc/prover-client-api/Cargo.toml b/crates/rpc/prover-client-api/Cargo.toml index da65aee34..c2a598019 100644 --- a/crates/rpc/prover-client-api/Cargo.toml +++ b/crates/rpc/prover-client-api/Cargo.toml @@ -13,7 +13,8 @@ rust.unused_must_use = "deny" rustdoc.all = "warn" [dependencies] -strata-rpc-types = { workspace = true } +strata-primitives.workspace = true +strata-rpc-types.workspace = true jsonrpsee = { workspace = true, features = ["server", "macros"] } diff --git a/crates/rpc/prover-client-api/src/lib.rs b/crates/rpc/prover-client-api/src/lib.rs index e3ad19c2b..99f546588 100644 --- a/crates/rpc/prover-client-api/src/lib.rs +++ b/crates/rpc/prover-client-api/src/lib.rs @@ -1,6 +1,7 @@ //! Provides prover-client related APIs for the RPC server. use jsonrpsee::{core::RpcResult, proc_macros::rpc}; +use strata_primitives::{buf::Buf32, l2::L2BlockId}; use strata_rpc_types::ProofKey; /// RPCs related to information about the client itself. @@ -12,12 +13,15 @@ pub trait StrataProverClientApi { async fn prove_btc_block(&self, el_block_num: u64) -> RpcResult>; /// Start proving the given el block - #[method(name = "proveELBlock")] - async fn prove_el_block(&self, el_block_range: (u64, u64)) -> RpcResult>; + #[method(name = "proveElBlocks")] + async fn prove_el_blocks(&self, el_block_range: (Buf32, Buf32)) -> RpcResult>; /// Start proving the given cl block - #[method(name = "proveCLBlock")] - async fn prove_cl_block(&self, cl_block_num: u64) -> RpcResult>; + #[method(name = "proveClBlocks")] + async fn prove_cl_blocks( + &self, + cl_block_range: (L2BlockId, L2BlockId), + ) -> RpcResult>; /// Start proving the given l1 Batch #[method(name = "proveL1Batch")] @@ -25,7 +29,10 @@ pub trait StrataProverClientApi { /// Start proving the given l2 batch #[method(name = "proveL2Batch")] - async fn prove_l2_batch(&self, l2_range: (u64, u64)) -> RpcResult>; + async fn prove_l2_batch( + &self, + l2_range: Vec<(L2BlockId, L2BlockId)>, + ) -> RpcResult>; /// Start proving the given checkpoint info #[method(name = "proveCheckpointRaw")] diff --git a/crates/test-utils/src/evm_ee.rs b/crates/test-utils/src/evm_ee.rs index fa25fb569..2e1a4a39e 100644 --- a/crates/test-utils/src/evm_ee.rs +++ b/crates/test-utils/src/evm_ee.rs @@ -1,11 +1,14 @@ -use std::{collections::HashMap, path::PathBuf}; +use std::path::PathBuf; use strata_consensus_logic::genesis::make_genesis_block; use strata_primitives::buf::{Buf32, Buf64}; use strata_proofimpl_cl_stf::{Chainstate, StateCache}; use strata_proofimpl_evm_ee_stf::{ - process_block_transaction, processor::EvmConfig, utils::generate_exec_update, EvmBlockStfInput, - EvmBlockStfOutput, + primitives::{EvmEeProofInput, EvmEeProofOutput}, + process_block_transaction, + processor::EvmConfig, + utils::generate_exec_update, + EvmBlockStfInput, }; use strata_state::{ block::{L1Segment, L2Block, L2BlockBody}, @@ -19,8 +22,8 @@ use crate::l2::{gen_params, get_genesis_chainstate}; /// generation and processing for testing STF proofs. #[derive(Debug, Clone)] pub struct EvmSegment { - inputs: HashMap, - outputs: HashMap, + inputs: EvmEeProofInput, + outputs: EvmEeProofOutput, } impl EvmSegment { @@ -39,8 +42,8 @@ impl EvmSegment { spec_id: SpecId::SHANGHAI, }; - let mut inputs = HashMap::new(); - let mut outputs = HashMap::new(); + let mut inputs = Vec::new(); + let mut outputs = Vec::new(); let dir = PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("data/evm_ee/"); for height in start_height..=end_height { @@ -48,27 +51,24 @@ impl EvmSegment { let json_file = std::fs::read_to_string(witness_path).expect("Expected JSON file"); let el_proof_input: EvmBlockStfInput = serde_json::from_str(&json_file).expect("Invalid JSON file"); - inputs.insert(height, el_proof_input.clone()); + inputs.push(el_proof_input.clone()); - let output = process_block_transaction(el_proof_input, EVM_CONFIG); - outputs.insert(height, output); + let block_stf_output = process_block_transaction(el_proof_input, EVM_CONFIG); + let exec_output = generate_exec_update(&block_stf_output); + outputs.push(exec_output); } Self { inputs, outputs } } - /// Retrieves the [`EvmBlockStfInput`] associated with the given block height. - /// - /// Panics if no input is found for the specified height. - pub fn get_input(&self, height: &u64) -> &EvmBlockStfInput { - self.inputs.get(height).expect("No input found at height") + /// Retrieves the [`EvmEeProofInput`] + pub fn get_inputs(&self) -> &EvmEeProofInput { + &self.inputs } - /// Retrieves the [`EvmBlockStfOutput`] associated with the given block height. - /// - /// Panics if no output is found for the specified height. - pub fn get_output(&self, height: &u64) -> &EvmBlockStfOutput { - self.outputs.get(height).expect("No output found at height") + /// Retrieves the [`EvmEeProofOutput`] + pub fn get_outputs(&self) -> &EvmEeProofOutput { + &self.outputs } } @@ -76,9 +76,9 @@ impl EvmSegment { /// This struct stores L2 blocks, pre-state, and post-state data, simulating /// the block processing for testing STF proofs. pub struct L2Segment { - blocks: HashMap, - pre_states: HashMap, - post_states: HashMap, + pub blocks: Vec, + pub pre_states: Vec, + pub post_states: Vec, } impl L2Segment { @@ -88,23 +88,23 @@ impl L2Segment { /// /// This function ensures that valid L2 segments and blocks are generated for testing /// the STF proofs by simulating state transitions from a starting genesis state. - pub fn initialize_from_saved_evm_ee_data(end_height: u64) -> Self { - let evm_segment = EvmSegment::initialize_from_saved_ee_data(1, 4); + pub fn initialize_from_saved_evm_ee_data(start_block: u64, end_block: u64) -> Self { + let evm_segment = EvmSegment::initialize_from_saved_ee_data(start_block, end_block); let params = gen_params(); - let mut blocks = HashMap::new(); - let mut pre_states = HashMap::new(); - let mut post_states = HashMap::new(); + let mut blocks = Vec::new(); + let mut pre_states = Vec::new(); + let mut post_states = Vec::new(); let mut prev_block = make_genesis_block(¶ms).block().clone(); let mut prev_chainstate = get_genesis_chainstate(); - for height in 1..=end_height { - let el_proof_in = evm_segment.get_input(&height); - let el_proof_out = evm_segment.get_output(&height); - let evm_ee_segment = generate_exec_update(el_proof_out); + let el_proof_ins = evm_segment.get_inputs(); + let el_proof_outs = evm_segment.get_outputs(); + + for (el_proof_in, el_proof_out) in el_proof_ins.iter().zip(el_proof_outs.iter()) { let l1_segment = L1Segment::new_empty(); - let body = L2BlockBody::new(l1_segment, evm_ee_segment); + let body = L2BlockBody::new(l1_segment, el_proof_out.clone()); let slot = prev_block.header().blockidx() + 1; let ts = el_proof_in.timestamp; @@ -139,9 +139,9 @@ impl L2Segment { .unwrap(); let (post_state, _) = state_cache.finalize(); - blocks.insert(height, block.clone()); - pre_states.insert(height, pre_state); - post_states.insert(height, post_state.clone()); + blocks.push(block.clone()); + pre_states.push(pre_state); + post_states.push(post_state.clone()); prev_block = block; prev_chainstate = post_state; @@ -153,31 +153,6 @@ impl L2Segment { post_states, } } - - /// Retrieves the L2Block associated with the given block height. - /// - /// Panics if no block is found for the specified height. - pub fn get_block(&self, height: u64) -> &L2Block { - self.blocks.get(&height).expect("Not block found at height") - } - - /// Retrieves the pre-state Chainstate for the given block height. - /// - /// Panics if no pre-state is found for the specified height. - pub fn get_pre_state(&self, height: u64) -> &Chainstate { - self.pre_states - .get(&height) - .expect("Not chain state found at height") - } - - /// Retrieves the post-state Chainstate for the given block height. - /// - /// Panics if no post-state is found for the specified height. - pub fn get_post_state(&self, height: u64) -> &Chainstate { - self.post_states - .get(&height) - .expect("Not chain state found at height") - } } #[cfg(test)] @@ -186,12 +161,13 @@ mod tests { #[test] fn test_chaintsn() { + let start_height = 1; let end_height = 4; - let l2_segment = L2Segment::initialize_from_saved_evm_ee_data(end_height); + let l2_segment = L2Segment::initialize_from_saved_evm_ee_data(start_height, end_height); - for height in 1..end_height { - let pre_state = l2_segment.get_pre_state(height + 1); - let post_state = l2_segment.get_post_state(height); + for height in start_height..end_height - 1 { + let pre_state = &l2_segment.pre_states[height as usize + 1]; + let post_state = &l2_segment.post_states[height as usize]; assert_eq!(pre_state, post_state); } } diff --git a/crates/zkvm/hosts/src/native.rs b/crates/zkvm/hosts/src/native.rs index 424192034..155422741 100644 --- a/crates/zkvm/hosts/src/native.rs +++ b/crates/zkvm/hosts/src/native.rs @@ -4,7 +4,7 @@ use strata_native_zkvm_adapter::{NativeHost, NativeMachine}; use strata_proofimpl_btc_blockspace::logic::process_blockspace_proof_outer; use strata_proofimpl_checkpoint::process_checkpoint_proof_outer; use strata_proofimpl_cl_agg::process_cl_agg; -use strata_proofimpl_cl_stf::process_cl_stf; +use strata_proofimpl_cl_stf::batch_process_cl_stf; use strata_proofimpl_evm_ee_stf::process_block_transaction_outer; use strata_proofimpl_l1_batch::process_l1_batch_proof; @@ -43,7 +43,7 @@ static EVM_EE_STF_HOST: LazyLock = std::sync::LazyLock::new(|| Nativ /// A native host for [`ProofVm::CLProving`] prover. static CL_STF_HOST: LazyLock = std::sync::LazyLock::new(|| NativeHost { process_proof: Arc::new(Box::new(move |zkvm: &NativeMachine| { - process_cl_stf(zkvm, &MOCK_VK); + batch_process_cl_stf(zkvm, &MOCK_VK); Ok(()) })), }); diff --git a/functional-tests/fn_cl_block_witness.py b/functional-tests/fn_cl_block_witness.py index 4a3f7a37c..08c147406 100644 --- a/functional-tests/fn_cl_block_witness.py +++ b/functional-tests/fn_cl_block_witness.py @@ -23,11 +23,17 @@ def main(self, ctx: flexitest.RunContext): error_with="Sequencer did not start on time", ) - witness_1 = seqrpc.strata_getCLBlockWitness(1) + witness_1 = self.get_witness(seqrpc, 1) assert witness_1 is not None time.sleep(1) - witness_2 = seqrpc.strata_getCLBlockWitness(2) + witness_2 = self.get_witness(seqrpc, 2) assert witness_2 is not None return True + + def get_witness(self, seqrpc, idx): + block_ids = seqrpc.strata_getHeadersAtIdx(idx) + block_id = block_ids[0]["block_id"] + witness = seqrpc.strata_getCLBlockWitness(block_id) + return witness diff --git a/functional-tests/fn_prover_cl_batch_dispatch.py b/functional-tests/fn_prover_cl_batch_dispatch.py index 9e98a4d3c..0ae711e6b 100644 --- a/functional-tests/fn_prover_cl_batch_dispatch.py +++ b/functional-tests/fn_prover_cl_batch_dispatch.py @@ -5,6 +5,18 @@ import testenv from utils import wait_for_proof_with_time_out +# Parameters defining the aggeration of the CL blocks +CL_AGG_PARAMS = [ + { + "start_block": 1, + "end_block": 2, + }, + { + "start_block": 3, + "end_block": 4, + }, +] + @flexitest.register class ProverClientTest(testenv.StrataTester): @@ -13,12 +25,21 @@ def __init__(self, ctx: flexitest.InitContext): def main(self, ctx: flexitest.RunContext): prover_client = ctx.get_service("prover_client") + seq = ctx.get_service("sequencer") + prover_client_rpc = prover_client.create_rpc() + seqrpc = seq.create_rpc() # Wait for the Prover Manager setup time.sleep(5) - task_ids = prover_client_rpc.dev_strata_proveL2Batch((1, 2)) + batches = [] + for batch_info in CL_AGG_PARAMS: + start_block_id = self.blockidx_2_blockid(seqrpc, batch_info["start_block"]) + end_block_id = self.blockidx_2_blockid(seqrpc, batch_info["end_block"]) + batches.append((start_block_id, end_block_id)) + + task_ids = prover_client_rpc.dev_strata_proveL2Batch(batches) self.debug(f"got task ids: {task_ids}") task_id = task_ids[0] self.debug(f"using task id: {task_id}") @@ -26,3 +47,7 @@ def main(self, ctx: flexitest.RunContext): time_out = 10 * 60 wait_for_proof_with_time_out(prover_client_rpc, task_id, time_out=time_out) + + def blockidx_2_blockid(self, seqrpc, blockidx): + l2_blks = seqrpc.strata_getHeadersAtIdx(blockidx) + return l2_blks[0]["block_id"] diff --git a/functional-tests/fn_prover_cl_dispatch.py b/functional-tests/fn_prover_cl_dispatch.py index 4a01d5591..fb7e73fa6 100644 --- a/functional-tests/fn_prover_cl_dispatch.py +++ b/functional-tests/fn_prover_cl_dispatch.py @@ -3,7 +3,13 @@ import flexitest import testenv -from utils import wait_for_proof_with_time_out +from utils import cl_slot_to_block_id, wait_for_proof_with_time_out + +# Parameters defining the range of Execution Engine (EE) blocks to be proven. +CL_PROVER_PARAMS = { + "start_block": 1, + "end_block": 2, +} @flexitest.register @@ -13,15 +19,21 @@ def __init__(self, ctx: flexitest.InitContext): def main(self, ctx: flexitest.RunContext): prover_client = ctx.get_service("prover_client") + seq = ctx.get_service("sequencer") + prover_client_rpc = prover_client.create_rpc() + seqrpc = seq.create_rpc() # Wait for the Prover Manager setup time.sleep(5) # Dispatch the prover task - task_ids = prover_client_rpc.dev_strata_proveCLBlock(1) - self.debug(f"got task ids: {task_ids}") + start_block_id = cl_slot_to_block_id(seqrpc, CL_PROVER_PARAMS["start_block"]) + end_block_id = cl_slot_to_block_id(seqrpc, CL_PROVER_PARAMS["end_block"]) + + task_ids = prover_client_rpc.dev_strata_proveClBlocks((start_block_id, end_block_id)) task_id = task_ids[0] + self.debug(f"using task id: {task_id}") assert task_id is not None diff --git a/functional-tests/fn_prover_el_dispatch.py b/functional-tests/fn_prover_el_dispatch.py index c612e3997..4c247dda9 100644 --- a/functional-tests/fn_prover_el_dispatch.py +++ b/functional-tests/fn_prover_el_dispatch.py @@ -3,7 +3,7 @@ import flexitest import testenv -from utils import wait_for_proof_with_time_out +from utils import el_slot_to_block_id, wait_for_proof_with_time_out # Parameters defining the range of Execution Engine (EE) blocks to be proven. EE_PROVER_PARAMS = { @@ -20,12 +20,17 @@ def __init__(self, ctx: flexitest.InitContext): def main(self, ctx: flexitest.RunContext): prover_client = ctx.get_service("prover_client") prover_client_rpc = prover_client.create_rpc() + reth = ctx.get_service("reth") + rethrpc = reth.create_rpc() # Wait for the some block building time.sleep(5) # Dispatch the prover task - task_ids = prover_client_rpc.dev_strata_proveELBlock((1, 2)) + start_block_id = el_slot_to_block_id(rethrpc, EE_PROVER_PARAMS["start_block"]) + end_block_id = el_slot_to_block_id(rethrpc, EE_PROVER_PARAMS["end_block"]) + + task_ids = prover_client_rpc.dev_strata_proveElBlocks((start_block_id, end_block_id)) self.debug(f"got task ids: {task_ids}") task_id = task_ids[0] self.debug(f"using task id: {task_id}") diff --git a/functional-tests/utils.py b/functional-tests/utils.py index 56b79435f..4d61eb356 100644 --- a/functional-tests/utils.py +++ b/functional-tests/utils.py @@ -507,3 +507,14 @@ def submit_da_blob(btcrpc: BitcoindClient, seqrpc: JsonrpcClient, blobdata: str) timeout=10, ) return tx + + +def cl_slot_to_block_id(seqrpc, slot): + """Convert L2 slot number to block ID.""" + l2_blocks = seqrpc.strata_getHeadersAtIdx(slot) + return l2_blocks[0]["block_id"] + + +def el_slot_to_block_id(rethrpc, block_num): + """Get EL block hash from block number using Ethereum RPC.""" + return rethrpc.eth_getBlockByNumber(hex(block_num), False)["hash"] diff --git a/provers/perf/src/main.rs b/provers/perf/src/main.rs index 9a961f840..46c84fdab 100644 --- a/provers/perf/src/main.rs +++ b/provers/perf/src/main.rs @@ -97,7 +97,7 @@ fn run_generator_programs( let btc_block_id = 40321; let btc_chain = get_btc_chain(); let btc_block = btc_chain.get_block(btc_block_id); - let strata_block_id = 1; + let evmee_block_range = (1, 1); // btc_blockspace println!("Generating a report for BTC_BLOCKSPACE"); @@ -112,7 +112,7 @@ fn run_generator_programs( println!("Generating a report for EL_BLOCK"); let el_block = generator.el_block(); let el_block_report = el_block - .gen_proof_report(&strata_block_id, "EL_BLOCK".to_owned()) + .gen_proof_report(&evmee_block_range, "EL_BLOCK".to_owned()) .unwrap(); reports.push(el_block_report.into()); @@ -121,7 +121,7 @@ fn run_generator_programs( println!("Generating a report for CL_BLOCK"); let cl_block = generator.cl_block(); let cl_block_report = cl_block - .gen_proof_report(&strata_block_id, "CL_BLOCK".to_owned()) + .gen_proof_report(&evmee_block_range, "CL_BLOCK".to_owned()) .unwrap(); reports.push(cl_block_report.into()); @@ -138,8 +138,9 @@ fn run_generator_programs( // l2_block println!("Generating a report for L2_BATCH"); let l2_block = generator.l2_batch(); + let l2_mini_batches = vec![(l2_start_height, l2_end_height)]; let l2_block_report = l2_block - .gen_proof_report(&(l2_start_height, l2_end_height), "L2_BATCH".to_owned()) + .gen_proof_report(&l2_mini_batches, "L2_BATCH".to_owned()) .unwrap(); reports.push(l2_block_report.into()); diff --git a/provers/risc0/guest-cl-stf/src/main.rs b/provers/risc0/guest-cl-stf/src/main.rs index 47d9027b8..f12f62555 100644 --- a/provers/risc0/guest-cl-stf/src/main.rs +++ b/provers/risc0/guest-cl-stf/src/main.rs @@ -1,4 +1,4 @@ -use strata_proofimpl_cl_stf::process_cl_stf; +use strata_proofimpl_cl_stf::batch_process_cl_stf; use strata_risc0_adapter::Risc0ZkVmEnv; // TODO: replace this with vks file that'll generated by build.rs script similar to how things are @@ -6,5 +6,5 @@ use strata_risc0_adapter::Risc0ZkVmEnv; pub const GUEST_EVM_EE_STF_ELF_ID: &[u32; 8] = &[0, 0, 0, 0, 0, 0, 0, 0]; fn main() { - process_cl_stf(&Risc0ZkVmEnv, GUEST_EVM_EE_STF_ELF_ID); + batch_process_cl_stf(&Risc0ZkVmEnv, GUEST_EVM_EE_STF_ELF_ID); } diff --git a/provers/sp1/guest-cl-stf/src/main.rs b/provers/sp1/guest-cl-stf/src/main.rs index 1a7048c7b..3000d87fa 100644 --- a/provers/sp1/guest-cl-stf/src/main.rs +++ b/provers/sp1/guest-cl-stf/src/main.rs @@ -1,8 +1,8 @@ -use strata_proofimpl_cl_stf::process_cl_stf; +use strata_proofimpl_cl_stf::batch_process_cl_stf; use strata_sp1_adapter::Sp1ZkVmEnv; mod vks; fn main() { - process_cl_stf(&Sp1ZkVmEnv, vks::GUEST_EVM_EE_STF_ELF_ID); + batch_process_cl_stf(&Sp1ZkVmEnv, vks::GUEST_EVM_EE_STF_ELF_ID); } diff --git a/provers/tests/src/checkpoint.rs b/provers/tests/src/checkpoint.rs index 482368eaf..6f7271a12 100644 --- a/provers/tests/src/checkpoint.rs +++ b/provers/tests/src/checkpoint.rs @@ -42,6 +42,7 @@ impl ProofGenerator for CheckpointProofGenerator { let (l1_start_height, l1_end_height) = batch_info.l1_range; let (l2_start_height, l2_end_height) = batch_info.l2_range; + let cl_batches = vec![(l2_start_height, l2_end_height)]; let l1_batch_proof = self .l1_batch_prover @@ -50,10 +51,7 @@ impl ProofGenerator for CheckpointProofGenerator { let l1_batch_vk = self.l1_batch_prover.get_host().get_verification_key(); let l1_batch = AggregationInput::new(l1_batch_proof, l1_batch_vk); - let l2_batch_proof = self - .l2_batch_prover - .get_proof(&(l2_start_height, l2_end_height)) - .unwrap(); + let l2_batch_proof = self.l2_batch_prover.get_proof(&cl_batches).unwrap(); let l2_batch_vk = self.l2_batch_prover.get_host().get_verification_key(); let l2_batch = AggregationInput::new(l2_batch_proof, l2_batch_vk); diff --git a/provers/tests/src/cl.rs b/provers/tests/src/cl.rs index 2b7f0c319..db28c61d0 100644 --- a/provers/tests/src/cl.rs +++ b/provers/tests/src/cl.rs @@ -20,33 +20,38 @@ impl ClProofGenerator { } impl ProofGenerator for ClProofGenerator { - type Input = u64; + type Input = (u64, u64); type P = ClStfProver; type H = H; - fn get_input(&self, block_num: &u64) -> ZkVmResult { + fn get_input(&self, block_range: &(u64, u64)) -> ZkVmResult { // Generate EL proof required for aggregation - let el_proof = self.el_proof_generator.get_proof(block_num)?; + let el_proof = self.el_proof_generator.get_proof(block_range)?; // Read CL witness data let params = gen_params(); let rollup_params = params.rollup(); - let l2_segment = L2Segment::initialize_from_saved_evm_ee_data(*block_num); - let l2_block = l2_segment.get_block(*block_num); - let pre_state = l2_segment.get_pre_state(*block_num); + let l2_segment = L2Segment::initialize_from_saved_evm_ee_data(block_range.0, block_range.1); + let l2_blocks = l2_segment.blocks; + let pre_states = l2_segment.pre_states; + + let mut stf_witness_payloads = Vec::new(); + for (block, pre_state) in l2_blocks.iter().zip(pre_states.iter()) { + let witness = borsh::to_vec(&(pre_state, block)).unwrap(); + stf_witness_payloads.push(witness); + } Ok(ClStfInput { rollup_params: rollup_params.clone(), - pre_state: pre_state.clone(), - l2_block: l2_block.clone(), + stf_witness_payloads, evm_ee_proof: el_proof, evm_ee_vk: self.el_proof_generator.get_host().get_verification_key(), }) } - fn get_proof_id(&self, block_num: &u64) -> String { - format!("cl_block_{}", block_num) + fn get_proof_id(&self, block_range: &(u64, u64)) -> String { + format!("cl_block_{}_{}", block_range.0, block_range.1) } fn get_host(&self) -> H { @@ -59,9 +64,10 @@ mod tests { use super::*; fn test_proof(cl_prover: &ClProofGenerator) { - let height = 1; + let start_height = 1; + let end_height = 3; - let _ = cl_prover.get_proof(&height).unwrap(); + let _ = cl_prover.get_proof(&(start_height, end_height)).unwrap(); } #[test] diff --git a/provers/tests/src/el.rs b/provers/tests/src/el.rs index 08b1a264e..9acaa43a4 100644 --- a/provers/tests/src/el.rs +++ b/provers/tests/src/el.rs @@ -16,19 +16,19 @@ impl ElProofGenerator { } impl ProofGenerator for ElProofGenerator { - type Input = u64; + type Input = (u64, u64); type P = EvmEeProver; type H = H; - fn get_input(&self, block_num: &u64) -> ZkVmResult { - let input = EvmSegment::initialize_from_saved_ee_data(*block_num, *block_num) - .get_input(block_num) - .clone(); - Ok(vec![input]) + fn get_input(&self, block_range: &(u64, u64)) -> ZkVmResult { + let (start_block, end_block) = block_range; + let evm_segment = EvmSegment::initialize_from_saved_ee_data(*start_block, *end_block); + + Ok(evm_segment.get_inputs().clone()) } - fn get_proof_id(&self, block_num: &u64) -> String { - format!("el_{}", block_num) + fn get_proof_id(&self, block_range: &(u64, u64)) -> String { + format!("el_{}_{}", block_range.0, block_range.1) } fn get_host(&self) -> H { @@ -42,8 +42,9 @@ mod tests { use super::*; fn test_proof(el_prover: &ElProofGenerator) { - let height = 1; - let _ = el_prover.get_proof(&height).unwrap(); + let start_height = 1; + let end_height = 2; + let _ = el_prover.get_proof(&(start_height, end_height)).unwrap(); } #[test] diff --git a/provers/tests/src/l2_batch.rs b/provers/tests/src/l2_batch.rs index 39705aaa8..5cc13c3f0 100644 --- a/provers/tests/src/l2_batch.rs +++ b/provers/tests/src/l2_batch.rs @@ -19,16 +19,18 @@ impl L2BatchProofGenerator { } impl ProofGenerator for L2BatchProofGenerator { - type Input = (u64, u64); + type Input = Vec<(u64, u64)>; type P = ClAggProver; type H = H; - fn get_input(&self, heights: &(u64, u64)) -> ZkVmResult { - let (start_height, end_height) = *heights; + fn get_input(&self, batches: &Self::Input) -> ZkVmResult { let mut batch = Vec::new(); - for block_num in start_height..=end_height { - let cl_proof = self.cl_proof_generator.get_proof(&block_num)?; + for mini_batch_range in batches { + let (start_height, end_height) = *mini_batch_range; + let cl_proof = self + .cl_proof_generator + .get_proof(&(start_height, end_height))?; batch.push(cl_proof); } @@ -36,9 +38,12 @@ impl ProofGenerator for L2BatchProofGenerator { Ok(ClAggInput { batch, cl_stf_vk }) } - fn get_proof_id(&self, heights: &(u64, u64)) -> String { - let (start_height, end_height) = *heights; - format!("l2_batch_{}_{}", start_height, end_height) + fn get_proof_id(&self, batches: &Self::Input) -> String { + if let (Some(first), Some(last)) = (batches.first(), batches.last()) { + format!("cl_batch_{}_{}", first.0, last.1) + } else { + "cl_batch_empty".to_string() + } } fn get_host(&self) -> H { @@ -51,7 +56,7 @@ mod tests { use super::*; fn test_proof(cl_agg_prover: &L2BatchProofGenerator) { - let _ = cl_agg_prover.get_proof(&(1, 3)).unwrap(); + let _ = cl_agg_prover.get_proof(&vec![(1, 3)]).unwrap(); } #[test]