Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: optimize segment proof aggregation #410

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 20 additions & 5 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -176,8 +176,22 @@ jobs:
CARGO_INCREMENTAL: 1
RUST_BACKTRACE: 1

simple_proof_witness_only:
name: Test witness for a small block.
runs-on: zero-ci

steps:
- name: Checkout code
uses: actions/checkout@v3

- name: Run the script
run: |
pushd zero_bin/tools
./prove_stdio.sh artifacts/witness_b19240705.json test_only


simple_proof_regular:
name: Execute bash script to generate and verify a proof for a small block.
name: Test real proof for a small block.
runs-on: zero-ci

steps:
Expand All @@ -189,8 +203,8 @@ jobs:
pushd zero_bin/tools
./prove_stdio.sh artifacts/witness_b19240705.json

simple_proof_witness_only:
name: Execute bash script to generate the proof witness for a small block.
simple_proof_using_continuations:
name: Test proof for a small block with custom batch and segment chunk size
runs-on: zero-ci

steps:
Expand All @@ -200,10 +214,11 @@ jobs:
- name: Run the script
run: |
pushd zero_bin/tools
./prove_stdio.sh artifacts/witness_b19240705.json test_only
PROVER_BATCH_SIZE=3 PROVER_SEGMENT_CHUNK_SIZE=4 PROVER_MAX_CPU_LEN_LOG=17 ./prove_stdio.sh ./artifacts/witness_b19240705.json


multi_blocks_proof_regular:
name: Execute bash script to generate and verify a proof for multiple blocks using parallel proving.
name: Test proof for multiple blocks using parallel proving.
runs-on: zero-ci

steps:
Expand Down
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

37 changes: 37 additions & 0 deletions evm_arithmetization/src/prover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -613,6 +613,43 @@ impl<F: RichField> Iterator for SegmentDataIterator<F> {
}
}

pub struct SegmentDataChunkIterator<'a, F: RichField> {
segment_data_iter: &'a mut SegmentDataIterator<F>,
chunk_size: usize,
}

impl<'a, F: RichField> SegmentDataChunkIterator<'a, F> {
pub fn new(segment_data_iter: &'a mut SegmentDataIterator<F>, chunk_size: usize) -> Self {
Self {
segment_data_iter,
chunk_size,
}
}
}

impl<'a, F: RichField> Iterator for SegmentDataChunkIterator<'a, F> {
type Item = Vec<(TrimmedGenerationInputs, GenerationSegmentData)>;

fn next(&mut self) -> Option<Self::Item> {
let mut chunk_empty_space = self.chunk_size as isize;
let mut chunk = Vec::with_capacity(self.chunk_size);
while chunk_empty_space > 0 {
chunk_empty_space -= 1;
if let Some(it) = self.segment_data_iter.next() {
chunk.push(it);
} else {
break;
}
}

if chunk.is_empty() {
None
} else {
Some(chunk)
}
}
}

/// A utility module designed to test witness generation externally.
pub mod testing {
use super::*;
Expand Down
8 changes: 4 additions & 4 deletions proof_gen/src/proof_gen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@ use plonky2::{

use crate::{
proof_types::{
GeneratedBlockProof, GeneratedSegmentAggProof, GeneratedSegmentProof, GeneratedTxnAggProof,
SegmentAggregatableProof, TxnAggregatableProof,
BatchAggregatableProof, GeneratedBlockProof, GeneratedSegmentAggProof,
GeneratedSegmentProof, GeneratedTxnAggProof, SegmentAggregatableProof,
},
prover_state::ProverState,
types::{Field, PlonkyProofIntern, EXTENSION_DEGREE},
Expand Down Expand Up @@ -121,8 +121,8 @@ pub fn generate_segment_agg_proof(
/// Note that the child proofs may be either transaction or aggregation proofs.
pub fn generate_transaction_agg_proof(
p_state: &ProverState,
lhs_child: &TxnAggregatableProof,
rhs_child: &TxnAggregatableProof,
lhs_child: &BatchAggregatableProof,
rhs_child: &BatchAggregatableProof,
) -> ProofGenResult<GeneratedTxnAggProof> {
let (b_proof_intern, p_vals) = p_state
.state
Expand Down
32 changes: 16 additions & 16 deletions proof_gen/src/proof_types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ pub enum SegmentAggregatableProof {
/// we can combine it into an agg proof. For these cases, we want to abstract
/// away whether or not the proof was a txn or agg proof.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub enum TxnAggregatableProof {
pub enum BatchAggregatableProof {
/// The underlying proof is a segment proof. It first needs to be aggregated
/// with another segment proof, or a dummy one.
Segment(GeneratedSegmentProof),
Expand Down Expand Up @@ -100,28 +100,28 @@ impl SegmentAggregatableProof {
}
}

impl TxnAggregatableProof {
impl BatchAggregatableProof {
pub(crate) fn public_values(&self) -> PublicValues {
match self {
TxnAggregatableProof::Segment(info) => info.p_vals.clone(),
TxnAggregatableProof::Txn(info) => info.p_vals.clone(),
TxnAggregatableProof::Agg(info) => info.p_vals.clone(),
BatchAggregatableProof::Segment(info) => info.p_vals.clone(),
BatchAggregatableProof::Txn(info) => info.p_vals.clone(),
BatchAggregatableProof::Agg(info) => info.p_vals.clone(),
}
}

pub(crate) fn is_agg(&self) -> bool {
match self {
TxnAggregatableProof::Segment(_) => false,
TxnAggregatableProof::Txn(_) => false,
TxnAggregatableProof::Agg(_) => true,
BatchAggregatableProof::Segment(_) => false,
BatchAggregatableProof::Txn(_) => false,
BatchAggregatableProof::Agg(_) => true,
}
}

pub(crate) fn intern(&self) -> &PlonkyProofIntern {
match self {
TxnAggregatableProof::Segment(info) => &info.intern,
TxnAggregatableProof::Txn(info) => &info.intern,
TxnAggregatableProof::Agg(info) => &info.intern,
BatchAggregatableProof::Segment(info) => &info.intern,
BatchAggregatableProof::Txn(info) => &info.intern,
BatchAggregatableProof::Agg(info) => &info.intern,
}
}
}
Expand All @@ -138,23 +138,23 @@ impl From<GeneratedSegmentAggProof> for SegmentAggregatableProof {
}
}

impl From<GeneratedSegmentAggProof> for TxnAggregatableProof {
impl From<GeneratedSegmentAggProof> for BatchAggregatableProof {
fn from(v: GeneratedSegmentAggProof) -> Self {
Self::Txn(v)
}
}

impl From<GeneratedTxnAggProof> for TxnAggregatableProof {
impl From<GeneratedTxnAggProof> for BatchAggregatableProof {
fn from(v: GeneratedTxnAggProof) -> Self {
Self::Agg(v)
}
}

impl From<SegmentAggregatableProof> for TxnAggregatableProof {
impl From<SegmentAggregatableProof> for BatchAggregatableProof {
fn from(v: SegmentAggregatableProof) -> Self {
match v {
SegmentAggregatableProof::Agg(agg) => TxnAggregatableProof::Txn(agg),
SegmentAggregatableProof::Seg(seg) => TxnAggregatableProof::Segment(seg),
SegmentAggregatableProof::Agg(agg) => BatchAggregatableProof::Txn(agg),
SegmentAggregatableProof::Seg(seg) => BatchAggregatableProof::Segment(seg),
}
}
}
27 changes: 5 additions & 22 deletions zero_bin/leader/src/cli.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use std::path::PathBuf;

use alloy::transports::http::reqwest::Url;
use clap::{Parser, Subcommand, ValueHint};
use prover::cli::CliProverConfig;
use rpc::RpcType;
use zero_bin_common::prover_state::cli::CliProverStateConfig;

Expand All @@ -14,6 +15,9 @@ pub(crate) struct Cli {
#[clap(flatten)]
pub(crate) paladin: paladin::config::Config,

#[clap(flatten)]
pub(crate) prover_config: CliProverConfig,

// Note this is only relevant for the leader when running in in-memory
// mode.
#[clap(flatten)]
Expand All @@ -22,18 +26,12 @@ pub(crate) struct Cli {

#[derive(Subcommand)]
pub(crate) enum Command {
//TODO unify parameters for all use cases
/// Reads input from stdin and writes output to stdout.
Stdio {
/// The previous proof output.
#[arg(long, short = 'f', value_hint = ValueHint::FilePath)]
previous_proof: Option<PathBuf>,
#[arg(short, long, default_value_t = 20)]
max_cpu_len_log: usize,
#[arg(short, long, default_value_t = 1)]
batch_size: usize,
/// If true, save the public inputs to disk on error.
#[arg(short, long, default_value_t = false)]
save_inputs_on_error: bool,
},
/// Reads input from a node rpc and writes output to stdout.
Rpc {
Expand All @@ -56,14 +54,6 @@ pub(crate) enum Command {
/// stdout.
#[arg(long, short = 'o', value_hint = ValueHint::FilePath)]
proof_output_dir: Option<PathBuf>,
/// The log of the max number of CPU cycles per proof.
#[arg(short, long, default_value_t = 20)]
max_cpu_len_log: usize,
#[arg(short, long, default_value_t = 1)]
batch_size: usize,
/// If true, save the public inputs to disk on error.
#[arg(short, long, default_value_t = false)]
save_inputs_on_error: bool,
/// Network block time in milliseconds. This value is used
/// to determine the blockchain node polling interval.
#[arg(short, long, env = "ZERO_BIN_BLOCK_TIME", default_value_t = 2000)]
Expand Down Expand Up @@ -92,12 +82,5 @@ pub(crate) enum Command {
/// The directory to which output should be written.
#[arg(short, long, value_hint = ValueHint::DirPath)]
output_dir: PathBuf,
#[arg(short, long, default_value_t = 20)]
max_cpu_len_log: usize,
#[arg(short, long, default_value_t = 1)]
batch_size: usize,
/// If true, save the public inputs to disk on error.
#[arg(short, long, default_value_t = false)]
save_inputs_on_error: bool,
},
}
11 changes: 4 additions & 7 deletions zero_bin/leader/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use alloy::transports::http::reqwest::Url;
use anyhow::Result;
use paladin::runtime::Runtime;
use proof_gen::proof_types::GeneratedBlockProof;
use prover::ProverConfig;
use rpc::{retry::build_http_retry_provider, RpcType};
use tracing::{error, info, warn};
use zero_bin_common::block_interval::BlockInterval;
Expand All @@ -18,14 +19,12 @@ pub struct RpcParams {
pub max_retries: u32,
}

#[derive(Debug, Default)]
#[derive(Debug)]
pub struct ProofParams {
pub checkpoint_block_number: u64,
pub previous_proof: Option<GeneratedBlockProof>,
pub proof_output_dir: Option<PathBuf>,
pub max_cpu_len_log: usize,
pub batch_size: usize,
pub save_inputs_on_error: bool,
pub prover_config: ProverConfig,
pub keep_intermediate_proofs: bool,
}

Expand Down Expand Up @@ -56,10 +55,8 @@ pub(crate) async fn client_main(
let proved_blocks = prover_input
.prove(
&runtime,
params.max_cpu_len_log,
params.previous_proof.take(),
params.batch_size,
params.save_inputs_on_error,
params.prover_config,
params.proof_output_dir.clone(),
)
.await;
Expand Down
25 changes: 5 additions & 20 deletions zero_bin/leader/src/http.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use anyhow::{bail, Result};
use axum::{http::StatusCode, routing::post, Json, Router};
use paladin::runtime::Runtime;
use proof_gen::proof_types::GeneratedBlockProof;
use prover::BlockProverInput;
use prover::{BlockProverInput, ProverConfig};
use serde::{Deserialize, Serialize};
use serde_json::to_writer;
use tracing::{debug, error, info};
Expand All @@ -15,9 +15,7 @@ pub(crate) async fn http_main(
runtime: Runtime,
port: u16,
output_dir: PathBuf,
max_cpu_len_log: usize,
batch_size: usize,
save_inputs_on_error: bool,
prover_config: ProverConfig,
) -> Result<()> {
let addr = SocketAddr::from(([0, 0, 0, 0], port));
debug!("listening on {}", addr);
Expand All @@ -27,16 +25,7 @@ pub(crate) async fn http_main(
"/prove",
post({
let runtime = runtime.clone();
move |body| {
prove(
body,
runtime,
output_dir.clone(),
max_cpu_len_log,
batch_size,
save_inputs_on_error,
)
}
move |body| prove(body, runtime, output_dir.clone(), prover_config)
}),
);
let listener = tokio::net::TcpListener::bind(&addr).await?;
Expand Down Expand Up @@ -76,9 +65,7 @@ async fn prove(
Json(payload): Json<HttpProverInput>,
runtime: Arc<Runtime>,
output_dir: PathBuf,
max_cpu_len_log: usize,
batch_size: usize,
save_inputs_on_error: bool,
prover_config: ProverConfig,
) -> StatusCode {
debug!("Received payload: {:#?}", payload);

Expand All @@ -88,10 +75,8 @@ async fn prove(
.prover_input
.prove(
&runtime,
max_cpu_len_log,
payload.previous.map(futures::future::ok),
batch_size,
save_inputs_on_error,
prover_config,
)
.await
{
Expand Down
Loading
Loading