Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update zkm receipt #43

Merged
merged 2 commits into from
Dec 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 2 additions & 8 deletions sdk/src/local/stark.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,6 @@ pub fn prove_stark(
return Ok(false);
}
log::info!("[The seg_num is:{} ]", &seg_num);
if seg_num == 1 {
let seg_file = format!("{seg_path}/{}", 0);
util::prove_single_seg_common(&seg_file, "", "", "")?;
Ok(false)
} else {
util::prove_multi_seg_common(&seg_path, "", "", "", storedir, seg_num, 0)?;
Ok(true)
}
util::prove_segments(&seg_path, "", storedir, "", "", seg_num, 0, vec![])?;
Ok(seg_num > 1)
}
198 changes: 67 additions & 131 deletions sdk/src/local/util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,76 +3,38 @@ use std::io::BufReader;
use std::ops::Range;
use std::time::Duration;

use plonky2::field::goldilocks_field::GoldilocksField;
use plonky2::plonk::config::{GenericConfig, PoseidonGoldilocksConfig};
use plonky2::util::timing::TimingTree;
use plonky2x::backend::circuit::Groth16WrapperParameters;
use plonky2x::backend::wrapper::wrap::WrappedCircuit;
use plonky2x::frontend::builder::CircuitBuilder as WrapperBuilder;
use plonky2x::prelude::DefaultParameters;

use zkm_prover::all_stark::AllStark;
use zkm_prover::config::StarkConfig;
use zkm_prover::cpu::kernel::assembler::segment_kernel;
use zkm_prover::fixed_recursive_verifier::AllRecursiveCircuits;
use zkm_prover::proof;
use zkm_prover::proof::PublicValues;
use zkm_prover::prover::prove;
use zkm_prover::verifier::verify_proof;
use zkm_prover::generation::state::{AssumptionReceipts, Receipt};

const DEGREE_BITS_RANGE: [Range<usize>; 6] = [10..21, 12..22, 10..21, 10..21, 6..21, 13..23];

pub fn prove_single_seg_common(
seg_file: &str,
basedir: &str,
block: &str,
file: &str,
) -> anyhow::Result<()> {
let seg_reader = BufReader::new(File::open(seg_file)?);
let kernel = segment_kernel(basedir, block, file, seg_reader);

const D: usize = 2;
type C = PoseidonGoldilocksConfig;
type F = <C as GenericConfig<D>>::F;

let allstark: AllStark<F, D> = AllStark::default();
let config = StarkConfig::standard_fast_config();
let mut timing = TimingTree::new("prove", log::Level::Info);
let allproof: proof::AllProof<GoldilocksField, C, D> =
prove(&allstark, &kernel, &config, &mut timing)?;
let mut count_bytes = 0;
for (row, proof) in allproof.stark_proofs.clone().iter().enumerate() {
let proof_str = serde_json::to_string(&proof.proof)?;
log::info!("row:{} proof bytes:{}", row, proof_str.len());
count_bytes += proof_str.len();
}
timing.filter(Duration::from_millis(100)).print();
log::info!("total proof bytes:{}KB", count_bytes / 1024);
verify_proof(&allstark, allproof, &config)?;
log::info!("Prove done");
Ok(())
}
const D: usize = 2;
type C = PoseidonGoldilocksConfig;
type F = <C as GenericConfig<D>>::F;

#[allow(clippy::too_many_arguments)]
pub fn prove_multi_seg_common(
pub fn prove_segments(
seg_dir: &str,
basedir: &str,
outdir: &str,
block: &str,
file: &str,
outdir: &str,
seg_file_number: usize,
seg_start_id: usize,
) -> anyhow::Result<()> {
assumptions: AssumptionReceipts<F, C, D>,
) -> anyhow::Result<Receipt<F, C, D>> {
type InnerParameters = DefaultParameters;
type OuterParameters = Groth16WrapperParameters;

type F = GoldilocksField;
const D: usize = 2;
type C = PoseidonGoldilocksConfig;

if seg_file_number < 2 {
panic!("seg file number must >= 2\n");
}

let total_timing = TimingTree::new("prove total time", log::Level::Info);
let all_stark = AllStark::<F, D>::default();
let config = StarkConfig::standard_fast_config();
Expand All @@ -85,45 +47,37 @@ pub fn prove_multi_seg_common(
let seg_reader = BufReader::new(File::open(seg_file)?);
let input_first = segment_kernel(basedir, block, file, seg_reader);
let mut timing = TimingTree::new("prove root first", log::Level::Info);
let (mut agg_proof, mut updated_agg_public_values) =
all_circuits.prove_root(&all_stark, &input_first, &config, &mut timing)?;
let mut agg_receipt = all_circuits.prove_root_with_assumption(
&all_stark,
&input_first,
&config,
&mut timing,
assumptions,
)?;

timing.filter(Duration::from_millis(100)).print();
all_circuits.verify_root(agg_proof.clone())?;
all_circuits.verify_root(agg_receipt.clone())?;

let mut base_seg = seg_start_id + 1;
let mut seg_num = seg_file_number - 1;
let mut is_agg: bool = false;
let mut is_agg = false;

if seg_file_number % 2 == 0 {
let seg_file = format!("{}/{}", seg_dir, seg_start_id + 1);
log::info!("Process segment {}", seg_file);
let seg_reader = BufReader::new(File::open(seg_file)?);
let input = segment_kernel(basedir, block, file, seg_reader);
timing = TimingTree::new("prove root second", log::Level::Info);
let (root_proof, public_values) =
all_circuits.prove_root(&all_stark, &input, &config, &mut timing)?;
let receipt = all_circuits.prove_root(&all_stark, &input, &config, &mut timing)?;
timing.filter(Duration::from_millis(100)).print();

all_circuits.verify_root(root_proof.clone())?;
all_circuits.verify_root(receipt.clone())?;

// Update public values for the aggregation.
let agg_public_values = PublicValues {
roots_before: updated_agg_public_values.roots_before,
roots_after: public_values.roots_after,
userdata: public_values.userdata,
};
timing = TimingTree::new("prove aggression", log::Level::Info);
// We can duplicate the proofs here because the state hasn't mutated.
(agg_proof, updated_agg_public_values) = all_circuits.prove_aggregation(
false,
&agg_proof,
false,
&root_proof,
agg_public_values.clone(),
)?;
agg_receipt = all_circuits.prove_aggregation(false, &agg_receipt, false, &receipt)?;
timing.filter(Duration::from_millis(100)).print();
all_circuits.verify_aggregation(&agg_proof)?;
all_circuits.verify_aggregation(&agg_receipt)?;

is_agg = true;
base_seg = seg_start_id + 2;
Expand All @@ -136,96 +90,78 @@ pub fn prove_multi_seg_common(
let seg_reader = BufReader::new(File::open(&seg_file)?);
let input_first = segment_kernel(basedir, block, file, seg_reader);
let mut timing = TimingTree::new("prove root first", log::Level::Info);
let (root_proof_first, first_public_values) =
let root_receipt_first =
all_circuits.prove_root(&all_stark, &input_first, &config, &mut timing)?;

timing.filter(Duration::from_millis(100)).print();
all_circuits.verify_root(root_proof_first.clone())?;
all_circuits.verify_root(root_receipt_first.clone())?;

let seg_file = format!("{}/{}", seg_dir, base_seg + (i << 1) + 1);
log::info!("Process segment {}", seg_file);
let seg_reader = BufReader::new(File::open(&seg_file)?);
let input = segment_kernel(basedir, block, file, seg_reader);
let mut timing = TimingTree::new("prove root second", log::Level::Info);
let (root_proof, public_values) =
all_circuits.prove_root(&all_stark, &input, &config, &mut timing)?;
let root_receipt = all_circuits.prove_root(&all_stark, &input, &config, &mut timing)?;
timing.filter(Duration::from_millis(100)).print();

all_circuits.verify_root(root_proof.clone())?;
all_circuits.verify_root(root_receipt.clone())?;

// Update public values for the aggregation.
let new_agg_public_values = PublicValues {
roots_before: first_public_values.roots_before,
roots_after: public_values.roots_after,
userdata: public_values.userdata,
};
timing = TimingTree::new("prove aggression", log::Level::Info);
// We can duplicate the proofs here because the state hasn't mutated.
let (new_agg_proof, new_updated_agg_public_values) = all_circuits.prove_aggregation(
false,
&root_proof_first,
false,
&root_proof,
new_agg_public_values,
)?;
let new_agg_receipt =
all_circuits.prove_aggregation(false, &root_receipt_first, false, &root_receipt)?;
timing.filter(Duration::from_millis(100)).print();
all_circuits.verify_aggregation(&new_agg_proof)?;

// Update public values for the nested aggregation.
let agg_public_values = PublicValues {
roots_before: updated_agg_public_values.roots_before,
roots_after: new_updated_agg_public_values.roots_after,
userdata: new_updated_agg_public_values.userdata,
};
all_circuits.verify_aggregation(&new_agg_receipt)?;

timing = TimingTree::new("prove nested aggression", log::Level::Info);

// We can duplicate the proofs here because the state hasn't mutated.
(agg_proof, updated_agg_public_values) = all_circuits.prove_aggregation(
is_agg,
&agg_proof,
true,
&new_agg_proof,
agg_public_values.clone(),
)?;
agg_receipt =
all_circuits.prove_aggregation(is_agg, &agg_receipt, true, &new_agg_receipt)?;
is_agg = true;
timing.filter(Duration::from_millis(100)).print();

all_circuits.verify_aggregation(&agg_proof)?;
all_circuits.verify_aggregation(&agg_receipt)?;
}

let (block_proof, _block_public_values) =
all_circuits.prove_block(None, &agg_proof, updated_agg_public_values.clone())?;

log::info!(
"proof size: {:?}",
serde_json::to_string(&block_proof.proof).unwrap().len()
);
let result = all_circuits.verify_block(&block_proof);

let builder = WrapperBuilder::<DefaultParameters, 2>::new();
let mut circuit = builder.build();
circuit.set_data(all_circuits.block.circuit);
let mut bit_size = vec![32usize; 16];
bit_size.extend(vec![8; 32]);
bit_size.extend(vec![64; 68]);
let wrapped_circuit = WrappedCircuit::<InnerParameters, OuterParameters, D>::build(
circuit,
Some((vec![], bit_size)),
serde_json::to_string(&agg_receipt.proof().proof)
.unwrap()
.len()
);
log::info!("build finish");
let final_receipt = if seg_file_number > 1 {
let block_receipt = all_circuits.prove_block(None, &agg_receipt)?;
all_circuits.verify_block(&block_receipt)?;
let builder = WrapperBuilder::<DefaultParameters, 2>::new();
let mut circuit = builder.build();
circuit.set_data(all_circuits.block.circuit);
let mut bit_size = vec![32usize; 16];
bit_size.extend(vec![8; 32]);
bit_size.extend(vec![64; 68]);
let wrapped_circuit = WrappedCircuit::<InnerParameters, OuterParameters, D>::build(
circuit,
Some((vec![], bit_size)),
);
let wrapped_proof = wrapped_circuit.prove(&block_receipt.proof()).unwrap();
wrapped_proof.save(outdir).unwrap();

let block_public_inputs = serde_json::json!({
"public_inputs": wrapped_proof.proof.public_inputs,
});
let outdir_path = std::path::Path::new(outdir);
let public_values_file = File::create(outdir_path.join("public_values.json"))?;
serde_json::to_writer(&public_values_file, &block_receipt.values())?;
let block_public_inputs_file = File::create(outdir_path.join("block_public_inputs.json"))?;
serde_json::to_writer(&block_public_inputs_file, &block_public_inputs)?;

block_receipt
} else {
agg_receipt
};

let wrapped_proof = wrapped_circuit.prove(&block_proof)?;
wrapped_proof.save(outdir)?;

let block_public_inputs = serde_json::json!({
"public_inputs": block_proof.public_inputs,
});
let outdir_path = std::path::Path::new(outdir);
let public_values_file = File::create(outdir_path.join("public_values.json"))?;
serde_json::to_writer(&public_values_file, &updated_agg_public_values)?;
let block_public_inputs_file = File::create(outdir_path.join("block_public_inputs.json"))?;
serde_json::to_writer(&block_public_inputs_file, &block_public_inputs)?;
log::info!("build finish");

total_timing.filter(Duration::from_millis(100)).print();
result
Ok(final_receipt)
}
Loading