From 9690b083aab3b9bc4ffb4b5c849735cdff544fbe Mon Sep 17 00:00:00 2001 From: Marko Atanasievski Date: Thu, 7 Nov 2024 17:02:25 +0100 Subject: [PATCH 1/8] feat: revamp zero prove function --- zero/src/ops.rs | 6 +- zero/src/prover.rs | 157 ++++++++++++++++++++++++++++++++++++++------- 2 files changed, 137 insertions(+), 26 deletions(-) diff --git a/zero/src/ops.rs b/zero/src/ops.rs index 6d6dd8403..4be8ffe83 100644 --- a/zero/src/ops.rs +++ b/zero/src/ops.rs @@ -21,7 +21,7 @@ use crate::{debug_utils::save_inputs_to_disk, prover_state::p_state}; registry!(); -#[derive(Deserialize, Serialize, RemoteExecute)] +#[derive(Deserialize, Serialize, RemoteExecute, Clone)] pub struct SegmentProof { pub save_inputs_on_error: bool, } @@ -207,7 +207,7 @@ impl Drop for SegmentProofSpan { } } -#[derive(Deserialize, Serialize, RemoteExecute)] +#[derive(Deserialize, Serialize, RemoteExecute, Clone)] pub struct SegmentAggProof { pub save_inputs_on_error: bool, } @@ -289,7 +289,7 @@ impl Monoid for SegmentAggProof { } } -#[derive(Deserialize, Serialize, RemoteExecute)] +#[derive(Deserialize, Serialize, RemoteExecute, Clone)] pub struct BatchAggProof { pub save_inputs_on_error: bool, } diff --git a/zero/src/prover.rs b/zero/src/prover.rs index 7cc840f02..02fd12f0c 100644 --- a/zero/src/prover.rs +++ b/zero/src/prover.rs @@ -11,7 +11,8 @@ use anyhow::{Context, Result}; use evm_arithmetization::Field; use evm_arithmetization::SegmentDataIterator; use futures::{ - future, future::BoxFuture, stream::FuturesUnordered, FutureExt, TryFutureExt, TryStreamExt, + future, future::BoxFuture, stream::FuturesUnordered, FutureExt, StreamExt, TryFutureExt, + TryStreamExt, }; use hashbrown::HashMap; use num_traits::ToPrimitive as _; @@ -23,10 +24,10 @@ use plonky2::plonk::circuit_data::CircuitConfig; use serde::{Deserialize, Serialize}; use tokio::io::AsyncWriteExt; use tokio::sync::mpsc::Receiver; -use tokio::sync::{oneshot, Semaphore}; +use tokio::sync::{mpsc, oneshot, Semaphore}; use trace_decoder::observer::DummyObserver; use trace_decoder::{BlockTrace, OtherBlockData, WireDisposition}; -use tracing::{error, info}; +use tracing::{debug, error, info}; use crate::fs::generate_block_proof_file_name; use crate::ops; @@ -116,6 +117,8 @@ impl BlockProverInput { WIRE_DISPOSITION, )?; + let batch_count = block_generation_inputs.len(); + // Create segment proof. let seg_prove_ops = ops::SegmentProof { save_inputs_on_error, @@ -131,29 +134,137 @@ impl BlockProverInput { save_inputs_on_error, }; - // Segment the batches, prove segments and aggregate them to resulting batch - // proofs. - let batch_proof_futs: FuturesUnordered<_> = block_generation_inputs - .iter() - .enumerate() - .map(|(idx, txn_batch)| { - let segment_data_iterator = - SegmentDataIterator::::new(txn_batch, Some(max_cpu_len_log)); - - Directive::map(IndexedStream::from(segment_data_iterator), &seg_prove_ops) - .fold(&seg_agg_ops) - .run(&proof_runtime.heavy_proof) - .map(move |e| { - e.map(|p| (idx, crate::proof_types::BatchAggregatableProof::from(p))) - }) + // Generate channels to communicate segments of each batch to proving task + let (segment_senders, segment_receivers): (Vec<_>, Vec<_>) = (0..batch_count) + .map(|_idx| { + let (segment_tx, segment_rx) = + mpsc::channel::>(1); + (segment_tx, segment_rx) }) - .collect(); + .unzip(); + + let (batch_proof_tx, mut batch_proof_rx) = + mpsc::channel::<(usize, crate::proof_types::BatchAggregatableProof)>(32); + + // Span a task for each batch to generate segments for that batch + // and send them to the proving task. + let _segment_generation_task = tokio::spawn(async move { + let mut batch_segment_futures: FuturesUnordered<_> = FuturesUnordered::new(); + + for (batch_idx, (txn_batch, segment_tx)) in block_generation_inputs + .into_iter() + .zip(segment_senders) + .enumerate() + { + batch_segment_futures.push(async move { + let segment_data_iterator = + SegmentDataIterator::::new(&txn_batch, Some(max_cpu_len_log)); + for (segment_idx, segment_data) in segment_data_iterator.enumerate() { + segment_tx + .send(Some(segment_data)) + .await + .context(format!("failed to send segment data for batch {batch_idx} segment {segment_idx}"))?; + } + // Mark the end of the batch segments by sending None + segment_tx + .send(None) + .await + .context(format!("failed to send end segment data indicator for batch {batch_idx}"))?; + Ok::<(), anyhow::Error>(()) + }); + } + while let Some(it) = batch_segment_futures.next().await { + // In case of error, propagate the error to the main task + it?; + } + Ok::<(), anyhow::Error>(()) + }); + + let proof_runtime_ = proof_runtime.clone(); + let _proving_task = tokio::spawn(async move { + let mut batch_proving_futures: FuturesUnordered<_> = FuturesUnordered::new(); + // Span a proving task for each batch to generate segment proofs + // and aggregate them to batch proof. + for (batch_idx, mut segment_rx) in segment_receivers.into_iter().enumerate() { + let batch_proof_tx = batch_proof_tx.clone(); + let seg_prove_ops = seg_prove_ops.clone(); + let seg_agg_ops = seg_agg_ops.clone(); + let proof_runtime = proof_runtime_.clone(); + // Tasks to dispatch and aggregate one batch + batch_proving_futures.push(async move { + // let mut pair_segment_data = Vec::new(); + let mut seg_aggregatable_proofs = Vec::new(); + // Wait for segments and dispatch them to the segment proof worker task + // There will always be pair number of segments, so we dispatch two segments + // and aggregate them as one chained directive to save + // a bit on local and transported data size. + let mut segment_counter = 0; + while let Some(segment_data) = segment_rx.recv().await { + // Prove the segment + if let Some(segment_data) = segment_data { + debug!("proving the batch {batch_idx} segment data {segment_counter}"); + let seg_aggregatable_proof = Directive::map( + IndexedStream::from([segment_data]), + &seg_prove_ops, + ) + // .fold(&seg_agg_ops) + .run(&proof_runtime.heavy_proof) + .await?; + let seg_aggregatable_proof = seg_aggregatable_proof.into_values_sorted().await?; + seg_aggregatable_proofs.extend(seg_aggregatable_proof); + } + segment_counter += 1; + } + debug!(block_number=%block_number, batch=%batch_idx, "finished proving all segments"); + // We have received and proved all the segments, + // now we need to aggregate to the batch proof. + // Fold the batch aggregated proof stream into a single proof. + let batch_proof = if seg_aggregatable_proofs.len() > 1 { + Directive::fold(IndexedStream::from(seg_aggregatable_proofs), &seg_agg_ops) + .run(&proof_runtime.light_proof) + .map(move |e| { + e.map(|p| { + ( + batch_idx, + crate::proof_types::BatchAggregatableProof::from(p), + ) + }) + }) + .await? + } else { + // If there is only one segment aggregated proof, just transform it to batch proof + (batch_idx, crate::proof_types::BatchAggregatableProof::from( + seg_aggregatable_proofs.pop().unwrap(), + )) + }; + debug!(block_number=%block_number, batch=%batch_idx, "generated batch proof for block"); + batch_proof_tx.send(batch_proof).await.context(format!( + "unable to send batch proof, block: {block_number}, batch: {batch_idx}" + ))?; + Ok::<(), anyhow::Error>(()) + }); + } + // Wait for all the batch proving tasks to finish + while let Some(it) = batch_proving_futures.next().await { + it.context("Unable to send batch proof {}")?; + } + Ok::<(), anyhow::Error>(()) + }); + + // Collect all the batch proofs for proving tasks + let mut batch_proofs: Vec<(usize, crate::proof_types::BatchAggregatableProof)> = Vec::new(); + while let Some((batch_idx, batch_proof)) = batch_proof_rx.recv().await { + batch_proofs.push((batch_idx, batch_proof)); + } + batch_proofs.sort_by(|(a, _), (b, _)| a.cmp(b)); // Fold the batch aggregated proof stream into a single proof. - let final_batch_proof = - Directive::fold(IndexedStream::new(batch_proof_futs), &batch_agg_ops) - .run(&proof_runtime.light_proof) - .await?; + let final_batch_proof = Directive::fold( + IndexedStream::from(batch_proofs.into_iter().map(|(_, it)| it)), + &batch_agg_ops, + ) + .run(&proof_runtime.light_proof) + .await?; if let crate::proof_types::BatchAggregatableProof::BatchAgg(proof) = final_batch_proof { let block_number = block_number From 827b905e99246bd5b79abd6515f70e355168a126 Mon Sep 17 00:00:00 2001 From: Marko Atanasievski Date: Tue, 12 Nov 2024 14:38:10 +0100 Subject: [PATCH 2/8] fix: improvements --- zero/src/prover.rs | 69 +++++++++++++++++++++++++--------------------- 1 file changed, 38 insertions(+), 31 deletions(-) diff --git a/zero/src/prover.rs b/zero/src/prover.rs index 02fd12f0c..cb2f2bfa0 100644 --- a/zero/src/prover.rs +++ b/zero/src/prover.rs @@ -10,6 +10,7 @@ use alloy::primitives::U256; use anyhow::{Context, Result}; use evm_arithmetization::Field; use evm_arithmetization::SegmentDataIterator; +use futures::future::try_join_all; use futures::{ future, future::BoxFuture, stream::FuturesUnordered, FutureExt, StreamExt, TryFutureExt, TryStreamExt, @@ -134,7 +135,10 @@ impl BlockProverInput { save_inputs_on_error, }; - // Generate channels to communicate segments of each batch to proving task + // Generate channels to communicate segments of each batch to a batch proving + // task. We generate a segment and send it to the proving task, then + // wait for it to be proved before we send another segment. This is done + // to avoid memory exhaustion. let (segment_senders, segment_receivers): (Vec<_>, Vec<_>) = (0..batch_count) .map(|_idx| { let (segment_tx, segment_rx) = @@ -143,12 +147,14 @@ impl BlockProverInput { }) .unzip(); + // Size of this channel does not matter mach, as it is used to collect batch + // proofs. let (batch_proof_tx, mut batch_proof_rx) = mpsc::channel::<(usize, crate::proof_types::BatchAggregatableProof)>(32); - // Span a task for each batch to generate segments for that batch + // Spin up a task for each batch to generate segments for that batch // and send them to the proving task. - let _segment_generation_task = tokio::spawn(async move { + let segment_generation_task = tokio::spawn(async move { let mut batch_segment_futures: FuturesUnordered<_> = FuturesUnordered::new(); for (batch_idx, (txn_batch, segment_tx)) in block_generation_inputs @@ -174,53 +180,54 @@ impl BlockProverInput { }); } while let Some(it) = batch_segment_futures.next().await { - // In case of error, propagate the error to the main task + // In case of an error, propagate the error to the main task it?; } Ok::<(), anyhow::Error>(()) }); let proof_runtime_ = proof_runtime.clone(); - let _proving_task = tokio::spawn(async move { + let proving_task = tokio::spawn(async move { let mut batch_proving_futures: FuturesUnordered<_> = FuturesUnordered::new(); - // Span a proving task for each batch to generate segment proofs + // Span a proving subtask for each batch where we generate segment proofs // and aggregate them to batch proof. for (batch_idx, mut segment_rx) in segment_receivers.into_iter().enumerate() { let batch_proof_tx = batch_proof_tx.clone(); let seg_prove_ops = seg_prove_ops.clone(); let seg_agg_ops = seg_agg_ops.clone(); let proof_runtime = proof_runtime_.clone(); - // Tasks to dispatch and aggregate one batch + // Tasks to dispatch proving jobs and aggregate segment proofs of one batch batch_proving_futures.push(async move { - // let mut pair_segment_data = Vec::new(); - let mut seg_aggregatable_proofs = Vec::new(); - // Wait for segments and dispatch them to the segment proof worker task - // There will always be pair number of segments, so we dispatch two segments - // and aggregate them as one chained directive to save - // a bit on local and transported data size. + let mut batch_segment_aggregatable_proofs = Vec::new(); + // Wait for segments and dispatch them to the segment proof worker task. + // The segment proof worker task will prove the segment and send it back. let mut segment_counter = 0; while let Some(segment_data) = segment_rx.recv().await { - // Prove the segment + // Prove the segment. if let Some(segment_data) = segment_data { debug!("proving the batch {batch_idx} segment data {segment_counter}"); let seg_aggregatable_proof = Directive::map( IndexedStream::from([segment_data]), &seg_prove_ops, ) - // .fold(&seg_agg_ops) - .run(&proof_runtime.heavy_proof) - .await?; + .run(&proof_runtime.heavy_proof) + .await?; let seg_aggregatable_proof = seg_aggregatable_proof.into_values_sorted().await?; - seg_aggregatable_proofs.extend(seg_aggregatable_proof); + batch_segment_aggregatable_proofs.extend(seg_aggregatable_proof); } segment_counter += 1; } debug!(block_number=%block_number, batch=%batch_idx, "finished proving all segments"); - // We have received and proved all the segments, - // now we need to aggregate to the batch proof. - // Fold the batch aggregated proof stream into a single proof. - let batch_proof = if seg_aggregatable_proofs.len() > 1 { - Directive::fold(IndexedStream::from(seg_aggregatable_proofs), &seg_agg_ops) + // We have proved all the segments in a batch, + // now we need to aggregate them to the batch proof. + // Fold the segment aggregated proof stream into a single batch proof. + let batch_proof = if batch_segment_aggregatable_proofs.len() == 1 { + // If there is only one segment aggregated proof, just transform it to batch proof. + (batch_idx, crate::proof_types::BatchAggregatableProof::from( + batch_segment_aggregatable_proofs.pop().unwrap(), + )) + } else { + Directive::fold(IndexedStream::from(batch_segment_aggregatable_proofs), &seg_agg_ops) .run(&proof_runtime.light_proof) .map(move |e| { e.map(|p| { @@ -231,11 +238,6 @@ impl BlockProverInput { }) }) .await? - } else { - // If there is only one segment aggregated proof, just transform it to batch proof - (batch_idx, crate::proof_types::BatchAggregatableProof::from( - seg_aggregatable_proofs.pop().unwrap(), - )) }; debug!(block_number=%block_number, batch=%batch_idx, "generated batch proof for block"); batch_proof_tx.send(batch_proof).await.context(format!( @@ -244,18 +246,23 @@ impl BlockProverInput { Ok::<(), anyhow::Error>(()) }); } - // Wait for all the batch proving tasks to finish + // Wait for all the batch proving tasks to finish. Exit early on error. while let Some(it) = batch_proving_futures.next().await { - it.context("Unable to send batch proof {}")?; + it?; } Ok::<(), anyhow::Error>(()) }); - // Collect all the batch proofs for proving tasks + // Collect all the batch proofs. let mut batch_proofs: Vec<(usize, crate::proof_types::BatchAggregatableProof)> = Vec::new(); while let Some((batch_idx, batch_proof)) = batch_proof_rx.recv().await { batch_proofs.push((batch_idx, batch_proof)); } + debug!(block_number=%block_number, "all the batch proofs are collected"); + + // Wait for the segment generation and proving tasks to finish. + try_join_all([segment_generation_task, proving_task]).await?; + batch_proofs.sort_by(|(a, _), (b, _)| a.cmp(b)); // Fold the batch aggregated proof stream into a single proof. From 0116e978a703ff60754a1e2803f0706922bdb685 Mon Sep 17 00:00:00 2001 From: Marko Atanasievski Date: Tue, 12 Nov 2024 16:06:50 +0100 Subject: [PATCH 3/8] fix: formatting --- zero/src/prover.rs | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/zero/src/prover.rs b/zero/src/prover.rs index cb2f2bfa0..f404442e3 100644 --- a/zero/src/prover.rs +++ b/zero/src/prover.rs @@ -206,13 +206,12 @@ impl BlockProverInput { // Prove the segment. if let Some(segment_data) = segment_data { debug!("proving the batch {batch_idx} segment data {segment_counter}"); - let seg_aggregatable_proof = Directive::map( - IndexedStream::from([segment_data]), - &seg_prove_ops, - ) - .run(&proof_runtime.heavy_proof) - .await?; - let seg_aggregatable_proof = seg_aggregatable_proof.into_values_sorted().await?; + let seg_aggregatable_proof = + Directive::map(IndexedStream::from([segment_data]), &seg_prove_ops) + .run(&proof_runtime.heavy_proof) + .await? + .into_values_sorted() + .await?; batch_segment_aggregatable_proofs.extend(seg_aggregatable_proof); } segment_counter += 1; From 3075ec87993c22d986b5888dd8a629c49fa3fb71 Mon Sep 17 00:00:00 2001 From: Marko Atanasievski Date: Tue, 12 Nov 2024 16:58:00 +0100 Subject: [PATCH 4/8] fix: further paralelize segment proofs --- zero/src/prover.rs | 67 +++++++++++++++++++++++++++++++++++----------- 1 file changed, 51 insertions(+), 16 deletions(-) diff --git a/zero/src/prover.rs b/zero/src/prover.rs index f404442e3..e47d16e56 100644 --- a/zero/src/prover.rs +++ b/zero/src/prover.rs @@ -147,8 +147,8 @@ impl BlockProverInput { }) .unzip(); - // Size of this channel does not matter mach, as it is used to collect batch - // proofs. + // The size of this channel does not matter much, as it is only used to collect + // batch proofs. let (batch_proof_tx, mut batch_proof_rx) = mpsc::channel::<(usize, crate::proof_types::BatchAggregatableProof)>(32); @@ -171,7 +171,7 @@ impl BlockProverInput { .await .context(format!("failed to send segment data for batch {batch_idx} segment {segment_idx}"))?; } - // Mark the end of the batch segments by sending None + // Mark the end of the batch segments by sending `None` segment_tx .send(None) .await @@ -187,7 +187,7 @@ impl BlockProverInput { }); let proof_runtime_ = proof_runtime.clone(); - let proving_task = tokio::spawn(async move { + let batches_proving_task = tokio::spawn(async move { let mut batch_proving_futures: FuturesUnordered<_> = FuturesUnordered::new(); // Span a proving subtask for each batch where we generate segment proofs // and aggregate them to batch proof. @@ -199,23 +199,58 @@ impl BlockProverInput { // Tasks to dispatch proving jobs and aggregate segment proofs of one batch batch_proving_futures.push(async move { let mut batch_segment_aggregatable_proofs = Vec::new(); + + // This channel collects segment proofs from the one batch + // proven in parallel. The size of this channel does not matter much, + // as it is only used to collect segment aggregatable proofs. + let (segment_proof_tx, mut segment_proof_rx) = + mpsc::channel::<(usize, crate::proof_types::SegmentAggregatableProof)>(32); + // Wait for segments and dispatch them to the segment proof worker task. // The segment proof worker task will prove the segment and send it back. let mut segment_counter = 0; + let mut segment_proving_tasks = Vec::new(); while let Some(segment_data) = segment_rx.recv().await { - // Prove the segment. - if let Some(segment_data) = segment_data { - debug!("proving the batch {batch_idx} segment data {segment_counter}"); - let seg_aggregatable_proof = - Directive::map(IndexedStream::from([segment_data]), &seg_prove_ops) + let seg_prove_ops = seg_prove_ops.clone(); + let proof_runtime = proof_runtime.clone(); + let segment_proof_tx = segment_proof_tx.clone(); + // Prove one segment in a dedicated async task. + let segment_proving_task = tokio::spawn(async move { + if let Some(segment_data) = segment_data { + debug!("proving the batch {batch_idx} segment data {segment_counter}"); + let seg_aggregatable_proof= Directive::map( + IndexedStream::from([segment_data]), + &seg_prove_ops, + ) .run(&proof_runtime.heavy_proof) .await? .into_values_sorted() - .await?; - batch_segment_aggregatable_proofs.extend(seg_aggregatable_proof); - } + .await? + .into_iter() + .next() + .context(format!( + "failed to get segment proof, batch: {batch_idx}, segment: {segment_counter}" + ))?; + + segment_proof_tx + .send((segment_counter, seg_aggregatable_proof)) + .await + .context(format!( + "unable to send segment proof, batch: {batch_idx}, segment: {segment_counter}" + ))?; + }; + Ok::<(), anyhow::Error>(()) + }); + + segment_proving_tasks.push(segment_proving_task); segment_counter += 1; } + // Wait for all the segment proving tasks of one batch to finish. + while let Some((segment_idx, segment_aggregatable_proof)) = segment_proof_rx.recv().await { + batch_segment_aggregatable_proofs.push((segment_idx, segment_aggregatable_proof)); + } + try_join_all(segment_proving_tasks).await?; + batch_segment_aggregatable_proofs.sort_by(|(a, _), (b, _)| a.cmp(b)); debug!(block_number=%block_number, batch=%batch_idx, "finished proving all segments"); // We have proved all the segments in a batch, // now we need to aggregate them to the batch proof. @@ -223,10 +258,10 @@ impl BlockProverInput { let batch_proof = if batch_segment_aggregatable_proofs.len() == 1 { // If there is only one segment aggregated proof, just transform it to batch proof. (batch_idx, crate::proof_types::BatchAggregatableProof::from( - batch_segment_aggregatable_proofs.pop().unwrap(), + batch_segment_aggregatable_proofs.pop().map(|(_, it)| it).unwrap(), )) } else { - Directive::fold(IndexedStream::from(batch_segment_aggregatable_proofs), &seg_agg_ops) + Directive::fold(IndexedStream::from(batch_segment_aggregatable_proofs.into_iter().map(|(_, it)| it)), &seg_agg_ops) .run(&proof_runtime.light_proof) .map(move |e| { e.map(|p| { @@ -257,10 +292,10 @@ impl BlockProverInput { while let Some((batch_idx, batch_proof)) = batch_proof_rx.recv().await { batch_proofs.push((batch_idx, batch_proof)); } - debug!(block_number=%block_number, "all the batch proofs are collected"); + debug!(block_number=%block_number, "collected all batch proofs"); // Wait for the segment generation and proving tasks to finish. - try_join_all([segment_generation_task, proving_task]).await?; + try_join_all([segment_generation_task, batches_proving_task]).await?; batch_proofs.sort_by(|(a, _), (b, _)| a.cmp(b)); From f428d11a8c40e56debfd89549fc39585eaa3b613 Mon Sep 17 00:00:00 2001 From: Marko Atanasievski Date: Tue, 12 Nov 2024 19:25:18 +0100 Subject: [PATCH 5/8] fix: remove worker number limit --- scripts/prove_stdio.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/prove_stdio.sh b/scripts/prove_stdio.sh index 440f4692b..04bbbbe6b 100755 --- a/scripts/prove_stdio.sh +++ b/scripts/prove_stdio.sh @@ -121,7 +121,7 @@ cargo build --release --jobs "$num_procs" start_time=$(date +%s%N) cmd=("${REPO_ROOT}/target/release/leader" --runtime in-memory \ - --load-strategy on-demand -n 1 \ + --load-strategy on-demand \ --block-batch-size "$BLOCK_BATCH_SIZE") if [[ "$USE_TEST_CONFIG" == "use_test_config" ]]; then From a8d5b62e2fd7b3db17e94781d51bf39d80d7df8b Mon Sep 17 00:00:00 2001 From: Marko Atanasievski Date: Tue, 12 Nov 2024 20:42:35 +0100 Subject: [PATCH 6/8] fix: deadlock --- zero/src/prover.rs | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/zero/src/prover.rs b/zero/src/prover.rs index e47d16e56..d458d30f4 100644 --- a/zero/src/prover.rs +++ b/zero/src/prover.rs @@ -211,6 +211,9 @@ impl BlockProverInput { let mut segment_counter = 0; let mut segment_proving_tasks = Vec::new(); while let Some(segment_data) = segment_rx.recv().await { + if segment_data.is_none() { + break; + } let seg_prove_ops = seg_prove_ops.clone(); let proof_runtime = proof_runtime.clone(); let segment_proof_tx = segment_proof_tx.clone(); @@ -245,6 +248,7 @@ impl BlockProverInput { segment_proving_tasks.push(segment_proving_task); segment_counter += 1; } + drop(segment_proof_tx); // Wait for all the segment proving tasks of one batch to finish. while let Some((segment_idx, segment_aggregatable_proof)) = segment_proof_rx.recv().await { batch_segment_aggregatable_proofs.push((segment_idx, segment_aggregatable_proof)); From 86d619a365abfa3acd206536262bafae852fe950 Mon Sep 17 00:00:00 2001 From: Marko Atanasievski Date: Wed, 13 Nov 2024 04:14:27 +0100 Subject: [PATCH 7/8] fix: comment --- zero/src/prover.rs | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/zero/src/prover.rs b/zero/src/prover.rs index d458d30f4..4962adce5 100644 --- a/zero/src/prover.rs +++ b/zero/src/prover.rs @@ -136,9 +136,8 @@ impl BlockProverInput { }; // Generate channels to communicate segments of each batch to a batch proving - // task. We generate a segment and send it to the proving task, then - // wait for it to be proved before we send another segment. This is done - // to avoid memory exhaustion. + // task. We generate segments and send them to the proving task, where they + // are proven in parallel. let (segment_senders, segment_receivers): (Vec<_>, Vec<_>) = (0..batch_count) .map(|_idx| { let (segment_tx, segment_rx) = @@ -220,7 +219,7 @@ impl BlockProverInput { // Prove one segment in a dedicated async task. let segment_proving_task = tokio::spawn(async move { if let Some(segment_data) = segment_data { - debug!("proving the batch {batch_idx} segment data {segment_counter}"); + debug!("proving the batch {batch_idx} segment nr. {segment_counter}"); let seg_aggregatable_proof= Directive::map( IndexedStream::from([segment_data]), &seg_prove_ops, From 02e0b8d325f64e7817fac9309329fc7cf1d53147 Mon Sep 17 00:00:00 2001 From: Marko Atanasievski Date: Thu, 14 Nov 2024 13:36:20 +0100 Subject: [PATCH 8/8] fix: review --- zero/src/prover.rs | 77 ++++++++++++++++++++++------------------------ 1 file changed, 37 insertions(+), 40 deletions(-) diff --git a/zero/src/prover.rs b/zero/src/prover.rs index 4962adce5..5b6975180 100644 --- a/zero/src/prover.rs +++ b/zero/src/prover.rs @@ -10,10 +10,11 @@ use alloy::primitives::U256; use anyhow::{Context, Result}; use evm_arithmetization::Field; use evm_arithmetization::SegmentDataIterator; -use futures::future::try_join_all; use futures::{ - future, future::BoxFuture, stream::FuturesUnordered, FutureExt, StreamExt, TryFutureExt, - TryStreamExt, + future::BoxFuture, + future::{self, try_join, try_join_all}, + stream::FuturesUnordered, + FutureExt as _, StreamExt as _, TryFutureExt as _, TryStreamExt as _, }; use hashbrown::HashMap; use num_traits::ToPrimitive as _; @@ -175,19 +176,20 @@ impl BlockProverInput { .send(None) .await .context(format!("failed to send end segment data indicator for batch {batch_idx}"))?; - Ok::<(), anyhow::Error>(()) + anyhow::Ok(()) }); } while let Some(it) = batch_segment_futures.next().await { // In case of an error, propagate the error to the main task it?; } - Ok::<(), anyhow::Error>(()) + let () = batch_segment_futures.try_collect().await?; + anyhow::Ok(()) }); let proof_runtime_ = proof_runtime.clone(); let batches_proving_task = tokio::spawn(async move { - let mut batch_proving_futures: FuturesUnordered<_> = FuturesUnordered::new(); + let mut batch_proving_futures = FuturesUnordered::new(); // Span a proving subtask for each batch where we generate segment proofs // and aggregate them to batch proof. for (batch_idx, mut segment_rx) in segment_receivers.into_iter().enumerate() { @@ -209,39 +211,34 @@ impl BlockProverInput { // The segment proof worker task will prove the segment and send it back. let mut segment_counter = 0; let mut segment_proving_tasks = Vec::new(); - while let Some(segment_data) = segment_rx.recv().await { - if segment_data.is_none() { - break; - } + while let Some(Some(segment_data)) = segment_rx.recv().await { let seg_prove_ops = seg_prove_ops.clone(); let proof_runtime = proof_runtime.clone(); let segment_proof_tx = segment_proof_tx.clone(); // Prove one segment in a dedicated async task. let segment_proving_task = tokio::spawn(async move { - if let Some(segment_data) = segment_data { - debug!("proving the batch {batch_idx} segment nr. {segment_counter}"); - let seg_aggregatable_proof= Directive::map( - IndexedStream::from([segment_data]), - &seg_prove_ops, - ) - .run(&proof_runtime.heavy_proof) - .await? - .into_values_sorted() - .await? - .into_iter() - .next() - .context(format!( - "failed to get segment proof, batch: {batch_idx}, segment: {segment_counter}" - ))?; - - segment_proof_tx - .send((segment_counter, seg_aggregatable_proof)) - .await - .context(format!( - "unable to send segment proof, batch: {batch_idx}, segment: {segment_counter}" - ))?; - }; - Ok::<(), anyhow::Error>(()) + debug!(%batch_idx, %segment_counter, "proving batch segment"); + let seg_aggregatable_proof= Directive::map( + IndexedStream::from([segment_data]), + &seg_prove_ops, + ) + .run(&proof_runtime.heavy_proof) + .await? + .into_values_sorted() + .await? + .into_iter() + .next() + .context(format!( + "failed to get segment proof, batch: {batch_idx}, segment: {segment_counter}" + ))?; + + segment_proof_tx + .send((segment_counter, seg_aggregatable_proof)) + .await + .context(format!( + "unable to send segment proof, batch: {batch_idx}, segment: {segment_counter}" + ))?; + anyhow::Ok(()) }); segment_proving_tasks.push(segment_proving_task); @@ -254,7 +251,7 @@ impl BlockProverInput { } try_join_all(segment_proving_tasks).await?; batch_segment_aggregatable_proofs.sort_by(|(a, _), (b, _)| a.cmp(b)); - debug!(block_number=%block_number, batch=%batch_idx, "finished proving all segments"); + debug!(%block_number, batch=%batch_idx, "finished proving all segments"); // We have proved all the segments in a batch, // now we need to aggregate them to the batch proof. // Fold the segment aggregated proof stream into a single batch proof. @@ -276,18 +273,18 @@ impl BlockProverInput { }) .await? }; - debug!(block_number=%block_number, batch=%batch_idx, "generated batch proof for block"); + debug!(%block_number, batch=%batch_idx, "generated batch proof for block"); batch_proof_tx.send(batch_proof).await.context(format!( "unable to send batch proof, block: {block_number}, batch: {batch_idx}" ))?; - Ok::<(), anyhow::Error>(()) + anyhow::Ok(()) }); } // Wait for all the batch proving tasks to finish. Exit early on error. while let Some(it) = batch_proving_futures.next().await { it?; } - Ok::<(), anyhow::Error>(()) + anyhow::Ok(()) }); // Collect all the batch proofs. @@ -295,10 +292,10 @@ impl BlockProverInput { while let Some((batch_idx, batch_proof)) = batch_proof_rx.recv().await { batch_proofs.push((batch_idx, batch_proof)); } - debug!(block_number=%block_number, "collected all batch proofs"); + debug!(%block_number, "collected all batch proofs"); // Wait for the segment generation and proving tasks to finish. - try_join_all([segment_generation_task, batches_proving_task]).await?; + let _ = try_join(segment_generation_task, batches_proving_task).await?; batch_proofs.sort_by(|(a, _), (b, _)| a.cmp(b));