From 490ec6330a8f502f84c9d76a223fcca640a5299f Mon Sep 17 00:00:00 2001 From: "xander.z" <162873981+xander42280@users.noreply.github.com> Date: Sat, 14 Sep 2024 13:57:20 +0800 Subject: [PATCH] add local prover (#10) * add local prover * fmt & clippy * Add compile_and_install.sh * CI support libsnark * Simplify code * fmt * Abstract util --- .github/workflows/ci.yml | 8 + host-program/src/bin/add-go-prove.rs | 232 +------------ host-program/src/bin/revme-prove.rs | 230 +------------ sdk/Cargo.toml | 8 +- sdk/build.rs | 2 + sdk/src/lib.rs | 15 +- sdk/src/local/cpu.rs | 455 ------------------------- sdk/src/local/libsnark/compile.sh | 23 ++ sdk/src/local/libsnark/contract.go | 215 ++++++++++++ sdk/src/local/libsnark/go.mod | 35 ++ sdk/src/local/libsnark/go.sum | 76 +++++ sdk/src/local/libsnark/libsnark.go | 22 ++ sdk/src/local/libsnark/snark_prover.go | 223 ++++++++++++ sdk/src/local/mod.rs | 6 +- sdk/src/local/prover.rs | 120 +++++++ sdk/src/local/snark.rs | 14 + sdk/src/local/stark.rs | 37 ++ sdk/src/local/util.rs | 223 ++++++++++++ sdk/src/prover.rs | 4 +- 19 files changed, 1027 insertions(+), 921 deletions(-) delete mode 100644 sdk/src/local/cpu.rs create mode 100755 sdk/src/local/libsnark/compile.sh create mode 100644 sdk/src/local/libsnark/contract.go create mode 100644 sdk/src/local/libsnark/go.mod create mode 100644 sdk/src/local/libsnark/go.sum create mode 100644 sdk/src/local/libsnark/libsnark.go create mode 100644 sdk/src/local/libsnark/snark_prover.go create mode 100644 sdk/src/local/prover.rs create mode 100644 sdk/src/local/snark.rs create mode 100644 sdk/src/local/stark.rs create mode 100644 sdk/src/local/util.rs diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index b5791f8e..e31bd76e 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -24,6 +24,14 @@ jobs: - run: rustup update ${{ matrix.toolchain }} && rustup default ${{ matrix.toolchain }} - name: Install Dependencies run: sudo apt install protobuf-compiler + - name: Install Go + uses: actions/setup-go@v3 + with: + go-version: '1.22' + - name: Build Go library + run: | + chmod +x ./sdk/src/local/libsnark/compile.sh + ./sdk/src/local/libsnark/compile.sh - name: Cargo build run: cargo build --verbose --release - name: Cargo test diff --git a/host-program/src/bin/add-go-prove.rs b/host-program/src/bin/add-go-prove.rs index ce1dfd84..a6f8970d 100644 --- a/host-program/src/bin/add-go-prove.rs +++ b/host-program/src/bin/add-go-prove.rs @@ -1,230 +1,6 @@ use serde::{Deserialize, Serialize}; use std::env; -use std::fs::File; - -use std::io::BufReader; -use std::ops::Range; -use std::time::Duration; - -use plonky2::field::goldilocks_field::GoldilocksField; -use plonky2::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; -use plonky2::util::timing::TimingTree; -use plonky2x::backend::circuit::Groth16WrapperParameters; -use plonky2x::backend::wrapper::wrap::WrappedCircuit; -use plonky2x::frontend::builder::CircuitBuilder as WrapperBuilder; -use plonky2x::prelude::DefaultParameters; use zkm_emulator::utils::{load_elf_with_patch, split_prog_into_segs}; -use zkm_prover::all_stark::AllStark; -use zkm_prover::config::StarkConfig; -use zkm_prover::cpu::kernel::assembler::segment_kernel; -use zkm_prover::fixed_recursive_verifier::AllRecursiveCircuits; -use zkm_prover::proof; -use zkm_prover::proof::PublicValues; -use zkm_prover::prover::prove; -use zkm_prover::verifier::verify_proof; - -const DEGREE_BITS_RANGE: [Range; 6] = [10..21, 12..22, 12..21, 8..21, 6..21, 13..23]; - -fn prove_single_seg_common( - seg_file: &str, - basedir: &str, - block: &str, - file: &str, - seg_size: usize, -) { - let seg_reader = BufReader::new(File::open(seg_file).unwrap()); - let kernel = segment_kernel(basedir, block, file, seg_reader, seg_size); - - const D: usize = 2; - type C = PoseidonGoldilocksConfig; - type F = >::F; - - let allstark: AllStark = AllStark::default(); - let config = StarkConfig::standard_fast_config(); - let mut timing = TimingTree::new("prove", log::Level::Info); - let allproof: proof::AllProof = - prove(&allstark, &kernel, &config, &mut timing).unwrap(); - let mut count_bytes = 0; - for (row, proof) in allproof.stark_proofs.clone().iter().enumerate() { - let proof_str = serde_json::to_string(&proof.proof).unwrap(); - log::info!("row:{} proof bytes:{}", row, proof_str.len()); - count_bytes += proof_str.len(); - } - timing.filter(Duration::from_millis(100)).print(); - log::info!("total proof bytes:{}KB", count_bytes / 1024); - verify_proof(&allstark, allproof, &config).unwrap(); - log::info!("Prove done"); -} - -fn prove_multi_seg_common( - seg_dir: &str, - basedir: &str, - block: &str, - file: &str, - seg_size: usize, - seg_file_number: usize, - seg_start_id: usize, -) -> anyhow::Result<()> { - type InnerParameters = DefaultParameters; - type OuterParameters = Groth16WrapperParameters; - - type F = GoldilocksField; - const D: usize = 2; - type C = PoseidonGoldilocksConfig; - - if seg_file_number < 2 { - panic!("seg file number must >= 2\n"); - } - - let total_timing = TimingTree::new("prove total time", log::Level::Info); - let all_stark = AllStark::::default(); - let config = StarkConfig::standard_fast_config(); - // Preprocess all circuits. - let all_circuits = - AllRecursiveCircuits::::new(&all_stark, &DEGREE_BITS_RANGE, &config); - - let seg_file = format!("{}/{}", seg_dir, seg_start_id); - log::info!("Process segment {}", seg_file); - let seg_reader = BufReader::new(File::open(seg_file)?); - let input_first = segment_kernel(basedir, block, file, seg_reader, seg_size); - let mut timing = TimingTree::new("prove root first", log::Level::Info); - let (mut agg_proof, mut updated_agg_public_values) = - all_circuits.prove_root(&all_stark, &input_first, &config, &mut timing)?; - - timing.filter(Duration::from_millis(100)).print(); - all_circuits.verify_root(agg_proof.clone())?; - - let mut base_seg = seg_start_id + 1; - let mut seg_num = seg_file_number - 1; - let mut is_agg = false; - - if seg_file_number % 2 == 0 { - let seg_file = format!("{}/{}", seg_dir, seg_start_id + 1); - log::info!("Process segment {}", seg_file); - let seg_reader = BufReader::new(File::open(seg_file)?); - let input = segment_kernel(basedir, block, file, seg_reader, seg_size); - timing = TimingTree::new("prove root second", log::Level::Info); - let (root_proof, public_values) = - all_circuits.prove_root(&all_stark, &input, &config, &mut timing)?; - timing.filter(Duration::from_millis(100)).print(); - - all_circuits.verify_root(root_proof.clone())?; - - // Update public values for the aggregation. - let agg_public_values = PublicValues { - roots_before: updated_agg_public_values.roots_before, - roots_after: public_values.roots_after, - userdata: public_values.userdata, - }; - timing = TimingTree::new("prove aggression", log::Level::Info); - // We can duplicate the proofs here because the state hasn't mutated. - (agg_proof, updated_agg_public_values) = all_circuits.prove_aggregation( - false, - &agg_proof, - false, - &root_proof, - agg_public_values.clone(), - )?; - timing.filter(Duration::from_millis(100)).print(); - all_circuits.verify_aggregation(&agg_proof)?; - - is_agg = true; - base_seg = seg_start_id + 2; - seg_num -= 1; - } - - for i in 0..seg_num / 2 { - let seg_file = format!("{}/{}", seg_dir, base_seg + (i << 1)); - log::info!("Process segment {}", seg_file); - let seg_reader = BufReader::new(File::open(&seg_file)?); - let input_first = segment_kernel(basedir, block, file, seg_reader, seg_size); - let mut timing = TimingTree::new("prove root first", log::Level::Info); - let (root_proof_first, first_public_values) = - all_circuits.prove_root(&all_stark, &input_first, &config, &mut timing)?; - - timing.filter(Duration::from_millis(100)).print(); - all_circuits.verify_root(root_proof_first.clone())?; - - let seg_file = format!("{}/{}", seg_dir, base_seg + (i << 1) + 1); - log::info!("Process segment {}", seg_file); - let seg_reader = BufReader::new(File::open(&seg_file)?); - let input = segment_kernel(basedir, block, file, seg_reader, seg_size); - let mut timing = TimingTree::new("prove root second", log::Level::Info); - let (root_proof, public_values) = - all_circuits.prove_root(&all_stark, &input, &config, &mut timing)?; - timing.filter(Duration::from_millis(100)).print(); - - all_circuits.verify_root(root_proof.clone())?; - - // Update public values for the aggregation. - let new_agg_public_values = PublicValues { - roots_before: first_public_values.roots_before, - roots_after: public_values.roots_after, - userdata: public_values.userdata, - }; - timing = TimingTree::new("prove aggression", log::Level::Info); - // We can duplicate the proofs here because the state hasn't mutated. - let (new_agg_proof, new_updated_agg_public_values) = all_circuits.prove_aggregation( - false, - &root_proof_first, - false, - &root_proof, - new_agg_public_values, - )?; - timing.filter(Duration::from_millis(100)).print(); - all_circuits.verify_aggregation(&new_agg_proof)?; - - // Update public values for the nested aggregation. - let agg_public_values = PublicValues { - roots_before: updated_agg_public_values.roots_before, - roots_after: new_updated_agg_public_values.roots_after, - userdata: new_updated_agg_public_values.userdata, - }; - timing = TimingTree::new("prove nested aggression", log::Level::Info); - - // We can duplicate the proofs here because the state hasn't mutated. - (agg_proof, updated_agg_public_values) = all_circuits.prove_aggregation( - is_agg, - &agg_proof, - true, - &new_agg_proof, - agg_public_values.clone(), - )?; - is_agg = true; - timing.filter(Duration::from_millis(100)).print(); - - all_circuits.verify_aggregation(&agg_proof)?; - } - - let (block_proof, _block_public_values) = - all_circuits.prove_block(None, &agg_proof, updated_agg_public_values)?; - - log::info!( - "proof size: {:?}", - serde_json::to_string(&block_proof.proof).unwrap().len() - ); - let result = all_circuits.verify_block(&block_proof); - - let build_path = "verifier/data".to_string(); - let path = format!("{}/test_circuit/", build_path); - let builder = WrapperBuilder::::new(); - let mut circuit = builder.build(); - circuit.set_data(all_circuits.block.circuit); - let mut bit_size = vec![32usize; 16]; - bit_size.extend(vec![8; 32]); - bit_size.extend(vec![64; 68]); - let wrapped_circuit = WrappedCircuit::::build( - circuit, - Some((vec![], bit_size)), - ); - log::info!("build finish"); - - let wrapped_proof = wrapped_circuit.prove(&block_proof).unwrap(); - wrapped_proof.save(path).unwrap(); - - total_timing.filter(Duration::from_millis(100)).print(); - result -} #[derive(Debug, Clone, Deserialize, Serialize)] pub enum DataId { @@ -302,8 +78,12 @@ fn main() { } if seg_num == 1 { let seg_file = format!("{seg_path}/{}", 0); - prove_single_seg_common(&seg_file, "", "", "", total_steps) + zkm_sdk::local::util::prove_single_seg_common(&seg_file, "", "", "", total_steps) } else { - prove_multi_seg_common(&seg_path, "", "", "", seg_size, seg_num, 0).unwrap() + let outdir = "verifier/data/test_circuit/".to_string(); + zkm_sdk::local::util::prove_multi_seg_common( + &seg_path, "", "", "", &outdir, seg_size, seg_num, 0, + ) + .unwrap() } } diff --git a/host-program/src/bin/revme-prove.rs b/host-program/src/bin/revme-prove.rs index 1e7ec208..30835525 100644 --- a/host-program/src/bin/revme-prove.rs +++ b/host-program/src/bin/revme-prove.rs @@ -2,229 +2,7 @@ use serde::{Deserialize, Serialize}; use std::env; use std::fs::File; use std::io::prelude::*; -use std::io::BufReader; -use std::ops::Range; -use std::time::Duration; - -use plonky2::field::goldilocks_field::GoldilocksField; -use plonky2::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; -use plonky2::util::timing::TimingTree; -use plonky2x::backend::circuit::Groth16WrapperParameters; -use plonky2x::backend::wrapper::wrap::WrappedCircuit; -use plonky2x::frontend::builder::CircuitBuilder as WrapperBuilder; -use plonky2x::prelude::DefaultParameters; use zkm_emulator::utils::{load_elf_with_patch, split_prog_into_segs}; -use zkm_prover::all_stark::AllStark; -use zkm_prover::config::StarkConfig; -use zkm_prover::cpu::kernel::assembler::segment_kernel; -use zkm_prover::fixed_recursive_verifier::AllRecursiveCircuits; -use zkm_prover::proof; -use zkm_prover::proof::PublicValues; -use zkm_prover::prover::prove; -use zkm_prover::verifier::verify_proof; - -const DEGREE_BITS_RANGE: [Range; 6] = [10..21, 12..22, 12..21, 8..21, 6..21, 13..23]; - -fn prove_single_seg_common( - seg_file: &str, - basedir: &str, - block: &str, - file: &str, - seg_size: usize, -) { - let seg_reader = BufReader::new(File::open(seg_file).unwrap()); - let kernel = segment_kernel(basedir, block, file, seg_reader, seg_size); - - const D: usize = 2; - type C = PoseidonGoldilocksConfig; - type F = >::F; - - let allstark: AllStark = AllStark::default(); - let config = StarkConfig::standard_fast_config(); - let mut timing = TimingTree::new("prove", log::Level::Info); - let allproof: proof::AllProof = - prove(&allstark, &kernel, &config, &mut timing).unwrap(); - let mut count_bytes = 0; - for (row, proof) in allproof.stark_proofs.clone().iter().enumerate() { - let proof_str = serde_json::to_string(&proof.proof).unwrap(); - log::info!("row:{} proof bytes:{}", row, proof_str.len()); - count_bytes += proof_str.len(); - } - timing.filter(Duration::from_millis(100)).print(); - log::info!("total proof bytes:{}KB", count_bytes / 1024); - verify_proof(&allstark, allproof, &config).unwrap(); - log::info!("Prove done"); -} - -fn prove_multi_seg_common( - seg_dir: &str, - basedir: &str, - block: &str, - file: &str, - seg_size: usize, - seg_file_number: usize, - seg_start_id: usize, -) -> anyhow::Result<()> { - type InnerParameters = DefaultParameters; - type OuterParameters = Groth16WrapperParameters; - - type F = GoldilocksField; - const D: usize = 2; - type C = PoseidonGoldilocksConfig; - - if seg_file_number < 2 { - panic!("seg file number must >= 2\n"); - } - - let total_timing = TimingTree::new("prove total time", log::Level::Info); - let all_stark = AllStark::::default(); - let config = StarkConfig::standard_fast_config(); - // Preprocess all circuits. - let all_circuits = - AllRecursiveCircuits::::new(&all_stark, &DEGREE_BITS_RANGE, &config); - - let seg_file = format!("{}/{}", seg_dir, seg_start_id); - log::info!("Process segment {}", seg_file); - let seg_reader = BufReader::new(File::open(seg_file)?); - let input_first = segment_kernel(basedir, block, file, seg_reader, seg_size); - let mut timing = TimingTree::new("prove root first", log::Level::Info); - let (mut agg_proof, mut updated_agg_public_values) = - all_circuits.prove_root(&all_stark, &input_first, &config, &mut timing)?; - - timing.filter(Duration::from_millis(100)).print(); - all_circuits.verify_root(agg_proof.clone())?; - - let mut base_seg = seg_start_id + 1; - let mut seg_num = seg_file_number - 1; - let mut is_agg = false; - - if seg_file_number % 2 == 0 { - let seg_file = format!("{}/{}", seg_dir, seg_start_id + 1); - log::info!("Process segment {}", seg_file); - let seg_reader = BufReader::new(File::open(seg_file)?); - let input = segment_kernel(basedir, block, file, seg_reader, seg_size); - timing = TimingTree::new("prove root second", log::Level::Info); - let (root_proof, public_values) = - all_circuits.prove_root(&all_stark, &input, &config, &mut timing)?; - timing.filter(Duration::from_millis(100)).print(); - - all_circuits.verify_root(root_proof.clone())?; - - // Update public values for the aggregation. - let agg_public_values = PublicValues { - roots_before: updated_agg_public_values.roots_before, - roots_after: public_values.roots_after, - userdata: public_values.userdata, - }; - timing = TimingTree::new("prove aggression", log::Level::Info); - // We can duplicate the proofs here because the state hasn't mutated. - (agg_proof, updated_agg_public_values) = all_circuits.prove_aggregation( - false, - &agg_proof, - false, - &root_proof, - agg_public_values.clone(), - )?; - timing.filter(Duration::from_millis(100)).print(); - all_circuits.verify_aggregation(&agg_proof)?; - - is_agg = true; - base_seg = seg_start_id + 2; - seg_num -= 1; - } - - for i in 0..seg_num / 2 { - let seg_file = format!("{}/{}", seg_dir, base_seg + (i << 1)); - log::info!("Process segment {}", seg_file); - let seg_reader = BufReader::new(File::open(&seg_file)?); - let input_first = segment_kernel(basedir, block, file, seg_reader, seg_size); - let mut timing = TimingTree::new("prove root first", log::Level::Info); - let (root_proof_first, first_public_values) = - all_circuits.prove_root(&all_stark, &input_first, &config, &mut timing)?; - - timing.filter(Duration::from_millis(100)).print(); - all_circuits.verify_root(root_proof_first.clone())?; - - let seg_file = format!("{}/{}", seg_dir, base_seg + (i << 1) + 1); - log::info!("Process segment {}", seg_file); - let seg_reader = BufReader::new(File::open(&seg_file)?); - let input = segment_kernel(basedir, block, file, seg_reader, seg_size); - let mut timing = TimingTree::new("prove root second", log::Level::Info); - let (root_proof, public_values) = - all_circuits.prove_root(&all_stark, &input, &config, &mut timing)?; - timing.filter(Duration::from_millis(100)).print(); - - all_circuits.verify_root(root_proof.clone())?; - - // Update public values for the aggregation. - let new_agg_public_values = PublicValues { - roots_before: first_public_values.roots_before, - roots_after: public_values.roots_after, - userdata: public_values.userdata, - }; - timing = TimingTree::new("prove aggression", log::Level::Info); - // We can duplicate the proofs here because the state hasn't mutated. - let (new_agg_proof, new_updated_agg_public_values) = all_circuits.prove_aggregation( - false, - &root_proof_first, - false, - &root_proof, - new_agg_public_values, - )?; - timing.filter(Duration::from_millis(100)).print(); - all_circuits.verify_aggregation(&new_agg_proof)?; - - // Update public values for the nested aggregation. - let agg_public_values = PublicValues { - roots_before: updated_agg_public_values.roots_before, - roots_after: new_updated_agg_public_values.roots_after, - userdata: new_updated_agg_public_values.userdata, - }; - timing = TimingTree::new("prove nested aggression", log::Level::Info); - - // We can duplicate the proofs here because the state hasn't mutated. - (agg_proof, updated_agg_public_values) = all_circuits.prove_aggregation( - is_agg, - &agg_proof, - true, - &new_agg_proof, - agg_public_values.clone(), - )?; - is_agg = true; - timing.filter(Duration::from_millis(100)).print(); - - all_circuits.verify_aggregation(&agg_proof)?; - } - - let (block_proof, _block_public_values) = - all_circuits.prove_block(None, &agg_proof, updated_agg_public_values)?; - - log::info!( - "proof size: {:?}", - serde_json::to_string(&block_proof.proof).unwrap().len() - ); - let result = all_circuits.verify_block(&block_proof); - - let build_path = "verifier/data".to_string(); - let path = format!("{}/test_circuit/", build_path); - let builder = WrapperBuilder::::new(); - let mut circuit = builder.build(); - circuit.set_data(all_circuits.block.circuit); - let mut bit_size = vec![32usize; 16]; - bit_size.extend(vec![8; 32]); - bit_size.extend(vec![64; 68]); - let wrapped_circuit = WrappedCircuit::::build( - circuit, - Some((vec![], bit_size)), - ); - log::info!("build finish"); - - let wrapped_proof = wrapped_circuit.prove(&block_proof).unwrap(); - wrapped_proof.save(path).unwrap(); - - total_timing.filter(Duration::from_millis(100)).print(); - result -} #[derive(Debug, Clone, Deserialize, Serialize)] pub enum DataId { @@ -301,8 +79,12 @@ fn main() { if seg_num == 1 { let seg_file = format!("{seg_path}/{}", 0); - prove_single_seg_common(&seg_file, "", "", "", total_steps) + zkm_sdk::local::util::prove_single_seg_common(&seg_file, "", "", "", total_steps) } else { - prove_multi_seg_common(&seg_path, "", "", "", seg_size, seg_num, 0).unwrap() + let outdir = "verifier/data/test_circuit/".to_string(); + zkm_sdk::local::util::prove_multi_seg_common( + &seg_path, "", "", "", &outdir, seg_size, seg_num, 0, + ) + .unwrap() } } diff --git a/sdk/Cargo.toml b/sdk/Cargo.toml index b9e29337..454c39bc 100644 --- a/sdk/Cargo.toml +++ b/sdk/Cargo.toml @@ -6,17 +6,11 @@ edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] -##zkm-emulator = { path = "../emulator" } -#plonky2 = { path = "../plonky2/plonky2" } -##starky = { path = "../plonky2/starky" } -#plonky2_util = { path = "../plonky2/util" } -#plonky2_maybe_rayon = { path = "../plonky2/maybe_rayon" } - +libc = "0.2" bincode = "1.3.3" async-trait = "0.1" -stage-service = {package = "service", git = "https://github.com/zkMIPS/zkm-prover", branch = "main", default-features = false } zkm-prover = { git = "https://github.com/zkMIPS/zkm", branch = "main", default-features = false } zkm-emulator = { git = "https://github.com/zkMIPS/zkm", branch = "main", default-features = false } common = { git = "https://github.com/zkMIPS/zkm-prover", branch = "main", default-features = false } diff --git a/sdk/build.rs b/sdk/build.rs index 2c94645c..5150fdaa 100644 --- a/sdk/build.rs +++ b/sdk/build.rs @@ -1,4 +1,6 @@ fn main() -> Result<(), Box> { + println!("cargo:rustc-link-search=native=./sdk/src/local/libsnark"); + println!("cargo:rustc-link-lib=dylib=snark"); tonic_build::configure() .protoc_arg("--experimental_allow_proto3_optional") .compile(&["src/proto/stage.proto"], &["src/proto"])?; diff --git a/sdk/src/lib.rs b/sdk/src/lib.rs index a5166e71..ace583dd 100644 --- a/sdk/src/lib.rs +++ b/sdk/src/lib.rs @@ -2,6 +2,7 @@ pub mod local; pub mod network; pub mod prover; +use local::prover::LocalProver; use network::prover::NetworkProver; use prover::Prover; use std::env; @@ -18,9 +19,9 @@ impl ProverClient { .to_lowercase() .as_str() { - // "local" => Self { - // prover: Box::new(CpuProver::new()), - // }, + "local" => Self { + prover: Box::new(LocalProver::new()), + }, "network" => Self { prover: Box::new(NetworkProver::new().await.unwrap()), }, @@ -30,9 +31,11 @@ impl ProverClient { } } - // pub fn local() -> Self { - // Self { prover: Box::new(CpuProver::new()) } - // } + pub fn local() -> Self { + Self { + prover: Box::new(LocalProver::new()), + } + } pub async fn network() -> Self { Self { diff --git a/sdk/src/local/cpu.rs b/sdk/src/local/cpu.rs deleted file mode 100644 index c62a0c0e..00000000 --- a/sdk/src/local/cpu.rs +++ /dev/null @@ -1,455 +0,0 @@ -// use serde::{Deserialize, Serialize}; -// use std::env; -// use std::fs::File; -// use std::io::prelude::*; -// use std::io::BufReader; -// use std::ops::Range; -// use std::time::Duration; - -// use plonky2::field::goldilocks_field::GoldilocksField; -// use plonky2::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; -// use plonky2::util::timing::TimingTree; -// use plonky2x::backend::circuit::Groth16WrapperParameters; -// use plonky2x::backend::wrapper::wrap::WrappedCircuit; -// use plonky2x::frontend::builder::CircuitBuilder as WrapperBuilder; -// use plonky2x::prelude::DefaultParameters; -// use zkm_emulator::utils::{ -// get_block_path, load_elf_with_patch, split_prog_into_segs, SEGMENT_STEPS, -// }; -// use zkm_prover::all_stark::AllStark; -// use zkm_prover::config::StarkConfig; -// use zkm_prover::cpu::kernel::assembler::segment_kernel; -// use zkm_prover::fixed_recursive_verifier::AllRecursiveCircuits; -// use zkm_prover::proof; -// use zkm_prover::proof::PublicValues; -// use zkm_prover::prover::prove; -// use zkm_prover::verifier::verify_proof; - -// const DEGREE_BITS_RANGE: [Range; 6] = [10..21, 12..22, 12..21, 8..21, 6..21, 13..23]; - -// fn split_segments() { -// // 1. split ELF into segs -// let basedir = env::var("BASEDIR").unwrap_or("/tmp/cannon".to_string()); -// let elf_path = env::var("ELF_PATH").expect("ELF file is missing"); -// let block_no = env::var("BLOCK_NO").unwrap_or("".to_string()); -// let seg_path = env::var("SEG_OUTPUT").expect("Segment output path is missing"); -// let seg_size = env::var("SEG_SIZE").unwrap_or(format!("{SEGMENT_STEPS}")); -// let seg_size = seg_size.parse::<_>().unwrap_or(SEGMENT_STEPS); -// let args = env::var("ARGS").unwrap_or("".to_string()); -// let args = args.split_whitespace().collect(); - -// let mut state = load_elf_with_patch(&elf_path, args); -// let block_path = get_block_path(&basedir, &block_no, ""); -// if !block_no.is_empty() { -// state.load_input(&block_path); -// } -// let _ = split_prog_into_segs(state, &seg_path, &block_path, seg_size); -// } - -// fn prove_single_seg_common( -// seg_file: &str, -// basedir: &str, -// block: &str, -// file: &str, -// seg_size: usize, -// ) { -// let seg_reader = BufReader::new(File::open(seg_file).unwrap()); -// let kernel = segment_kernel(basedir, block, file, seg_reader, seg_size); - -// const D: usize = 2; -// type C = PoseidonGoldilocksConfig; -// type F = >::F; - -// let allstark: AllStark = AllStark::default(); -// let config = StarkConfig::standard_fast_config(); -// let mut timing = TimingTree::new("prove", log::Level::Info); -// let allproof: proof::AllProof = -// prove(&allstark, &kernel, &config, &mut timing).unwrap(); -// let mut count_bytes = 0; -// for (row, proof) in allproof.stark_proofs.clone().iter().enumerate() { -// let proof_str = serde_json::to_string(&proof.proof).unwrap(); -// log::info!("row:{} proof bytes:{}", row, proof_str.len()); -// count_bytes += proof_str.len(); -// } -// timing.filter(Duration::from_millis(100)).print(); -// log::info!("total proof bytes:{}KB", count_bytes / 1024); -// verify_proof(&allstark, allproof, &config).unwrap(); -// log::info!("Prove done"); -// } - -// fn prove_multi_seg_common( -// seg_dir: &str, -// basedir: &str, -// block: &str, -// file: &str, -// seg_size: usize, -// seg_file_number: usize, -// seg_start_id: usize, -// ) -> anyhow::Result<()> { -// type InnerParameters = DefaultParameters; -// type OuterParameters = Groth16WrapperParameters; - -// type F = GoldilocksField; -// const D: usize = 2; -// type C = PoseidonGoldilocksConfig; - -// if seg_file_number < 2 { -// panic!("seg file number must >= 2\n"); -// } - -// let total_timing = TimingTree::new("prove total time", log::Level::Info); -// let all_stark = AllStark::::default(); -// let config = StarkConfig::standard_fast_config(); -// // Preprocess all circuits. -// let all_circuits = -// AllRecursiveCircuits::::new(&all_stark, &DEGREE_BITS_RANGE, &config); - -// let seg_file = format!("{}/{}", seg_dir, seg_start_id); -// log::info!("Process segment {}", seg_file); -// let seg_reader = BufReader::new(File::open(seg_file)?); -// let input_first = segment_kernel(basedir, block, file, seg_reader, seg_size); -// let mut timing = TimingTree::new("prove root first", log::Level::Info); -// let (mut agg_proof, mut updated_agg_public_values) = -// all_circuits.prove_root(&all_stark, &input_first, &config, &mut timing)?; - -// timing.filter(Duration::from_millis(100)).print(); -// all_circuits.verify_root(agg_proof.clone())?; - -// let mut base_seg = seg_start_id + 1; -// let mut seg_num = seg_file_number - 1; -// let mut is_agg = false; - -// if seg_file_number % 2 == 0 { -// let seg_file = format!("{}/{}", seg_dir, seg_start_id + 1); -// log::info!("Process segment {}", seg_file); -// let seg_reader = BufReader::new(File::open(seg_file)?); -// let input = segment_kernel(basedir, block, file, seg_reader, seg_size); -// timing = TimingTree::new("prove root second", log::Level::Info); -// let (root_proof, public_values) = -// all_circuits.prove_root(&all_stark, &input, &config, &mut timing)?; -// timing.filter(Duration::from_millis(100)).print(); - -// all_circuits.verify_root(root_proof.clone())?; - -// // Update public values for the aggregation. -// let agg_public_values = PublicValues { -// roots_before: updated_agg_public_values.roots_before, -// roots_after: public_values.roots_after, -// userdata: public_values.userdata, -// }; -// timing = TimingTree::new("prove aggression", log::Level::Info); -// // We can duplicate the proofs here because the state hasn't mutated. -// (agg_proof, updated_agg_public_values) = all_circuits.prove_aggregation( -// false, -// &agg_proof, -// false, -// &root_proof, -// agg_public_values.clone(), -// )?; -// timing.filter(Duration::from_millis(100)).print(); -// all_circuits.verify_aggregation(&agg_proof)?; - -// is_agg = true; -// base_seg = seg_start_id + 2; -// seg_num -= 1; -// } - -// for i in 0..seg_num / 2 { -// let seg_file = format!("{}/{}", seg_dir, base_seg + (i << 1)); -// log::info!("Process segment {}", seg_file); -// let seg_reader = BufReader::new(File::open(&seg_file)?); -// let input_first = segment_kernel(basedir, block, file, seg_reader, seg_size); -// let mut timing = TimingTree::new("prove root first", log::Level::Info); -// let (root_proof_first, first_public_values) = -// all_circuits.prove_root(&all_stark, &input_first, &config, &mut timing)?; - -// timing.filter(Duration::from_millis(100)).print(); -// all_circuits.verify_root(root_proof_first.clone())?; - -// let seg_file = format!("{}/{}", seg_dir, base_seg + (i << 1) + 1); -// log::info!("Process segment {}", seg_file); -// let seg_reader = BufReader::new(File::open(&seg_file)?); -// let input = segment_kernel(basedir, block, file, seg_reader, seg_size); -// let mut timing = TimingTree::new("prove root second", log::Level::Info); -// let (root_proof, public_values) = -// all_circuits.prove_root(&all_stark, &input, &config, &mut timing)?; -// timing.filter(Duration::from_millis(100)).print(); - -// all_circuits.verify_root(root_proof.clone())?; - -// // Update public values for the aggregation. -// let new_agg_public_values = PublicValues { -// roots_before: first_public_values.roots_before, -// roots_after: public_values.roots_after, -// userdata: public_values.userdata, -// }; -// timing = TimingTree::new("prove aggression", log::Level::Info); -// // We can duplicate the proofs here because the state hasn't mutated. -// let (new_agg_proof, new_updated_agg_public_values) = all_circuits.prove_aggregation( -// false, -// &root_proof_first, -// false, -// &root_proof, -// new_agg_public_values, -// )?; -// timing.filter(Duration::from_millis(100)).print(); -// all_circuits.verify_aggregation(&new_agg_proof)?; - -// // Update public values for the nested aggregation. -// let agg_public_values = PublicValues { -// roots_before: updated_agg_public_values.roots_before, -// roots_after: new_updated_agg_public_values.roots_after, -// userdata: new_updated_agg_public_values.userdata, -// }; -// timing = TimingTree::new("prove nested aggression", log::Level::Info); - -// // We can duplicate the proofs here because the state hasn't mutated. -// (agg_proof, updated_agg_public_values) = all_circuits.prove_aggregation( -// is_agg, -// &agg_proof, -// true, -// &new_agg_proof, -// agg_public_values.clone(), -// )?; -// is_agg = true; -// timing.filter(Duration::from_millis(100)).print(); - -// all_circuits.verify_aggregation(&agg_proof)?; -// } - -// let (block_proof, _block_public_values) = -// all_circuits.prove_block(None, &agg_proof, updated_agg_public_values)?; - -// log::info!( -// "proof size: {:?}", -// serde_json::to_string(&block_proof.proof).unwrap().len() -// ); -// let result = all_circuits.verify_block(&block_proof); - -// let build_path = "../verifier/data".to_string(); -// let path = format!("{}/test_circuit/", build_path); -// let builder = WrapperBuilder::::new(); -// let mut circuit = builder.build(); -// circuit.set_data(all_circuits.block.circuit); -// let mut bit_size = vec![32usize; 16]; -// bit_size.extend(vec![8; 32]); -// bit_size.extend(vec![64; 68]); -// let wrapped_circuit = WrappedCircuit::::build( -// circuit, -// Some((vec![], bit_size)), -// ); -// log::info!("build finish"); - -// let wrapped_proof = wrapped_circuit.prove(&block_proof).unwrap(); -// wrapped_proof.save(path).unwrap(); - -// total_timing.filter(Duration::from_millis(100)).print(); -// result -// } - -// fn prove_sha2_bench() { -// // 1. split ELF into segs -// let elf_path = env::var("ELF_PATH").expect("ELF file is missing"); -// let seg_path = env::var("SEG_OUTPUT").expect("Segment output path is missing"); - -// let mut state = load_elf_with_patch(&elf_path, vec![]); -// // load input -// let args = env::var("ARGS").unwrap_or("data-to-hash".to_string()); -// // assume the first arg is the hash output(which is a public input), and the second is the input. -// let args: Vec<&str> = args.split_whitespace().collect(); -// assert_eq!(args.len(), 2); - -// let public_input: Vec = hex::decode(args[0]).unwrap(); -// state.add_input_stream(&public_input); -// log::info!("expected public value in hex: {:X?}", args[0]); -// log::info!("expected public value: {:X?}", public_input); - -// let private_input = args[1].as_bytes().to_vec(); -// log::info!("private input value: {:X?}", private_input); -// state.add_input_stream(&private_input); - -// let (total_steps, mut state) = split_prog_into_segs(state, &seg_path, "", 0); - -// let value = state.read_public_values::<[u8; 32]>(); -// log::info!("public value: {:X?}", value); -// log::info!("public value: {} in hex", hex::encode(value)); - -// let seg_file = format!("{seg_path}/{}", 0); -// prove_single_seg_common(&seg_file, "", "", "", total_steps); -// } - -// fn prove_revm() { -// // 1. split ELF into segs -// let elf_path = env::var("ELF_PATH").expect("ELF file is missing"); -// let seg_path = env::var("SEG_OUTPUT").expect("Segment output path is missing"); -// let json_path = env::var("JSON_PATH").expect("JSON file is missing"); -// let seg_size = env::var("SEG_SIZE").unwrap_or("0".to_string()); -// let seg_size = seg_size.parse::<_>().unwrap_or(0); -// let mut f = File::open(json_path).unwrap(); -// let mut data = vec![]; -// f.read_to_end(&mut data).unwrap(); - -// let mut state = load_elf_with_patch(&elf_path, vec![]); -// // load input -// state.add_input_stream(&data); - -// let (total_steps, mut _state) = split_prog_into_segs(state, &seg_path, "", seg_size); - -// let mut seg_num = 1usize; -// if seg_size != 0 { -// seg_num = (total_steps + seg_size - 1) / seg_size; -// } - -// if seg_num == 1 { -// let seg_file = format!("{seg_path}/{}", 0); -// prove_single_seg_common(&seg_file, "", "", "", total_steps) -// } else { -// prove_multi_seg_common(&seg_path, "", "", "", seg_size, seg_num, 0).unwrap() -// } -// } - -// #[derive(Debug, Clone, Deserialize, Serialize)] -// pub enum DataId { -// TYPE1, -// TYPE2, -// TYPE3, -// } - -// #[derive(Debug, Clone, Deserialize, Serialize)] -// pub struct Data { -// pub input1: [u8; 10], -// pub input2: u8, -// pub input3: i8, -// pub input4: u16, -// pub input5: i16, -// pub input6: u32, -// pub input7: i32, -// pub input8: u64, -// pub input9: i64, -// pub input10: Vec, -// pub input11: DataId, -// pub input12: String, -// } - -// impl Default for Data { -// fn default() -> Self { -// Self::new() -// } -// } - -// impl Data { -// pub fn new() -> Self { -// let array = [1u8, 2u8, 3u8, 4u8, 5u8, 6u8, 7u8, 8u8, 9u8, 10u8]; -// Self { -// input1: array, -// input2: 0x11u8, -// input3: -1i8, -// input4: 0x1122u16, -// input5: -1i16, -// input6: 0x112233u32, -// input7: -1i32, -// input8: 0x1122334455u64, -// input9: -1i64, -// input10: array[1..3].to_vec(), -// input11: DataId::TYPE3, -// input12: "hello".to_string(), -// } -// } -// } - -// fn prove_add_example() { -// // 1. split ELF into segs -// let elf_path = env::var("ELF_PATH").expect("ELF file is missing"); -// let seg_path = env::var("SEG_OUTPUT").expect("Segment output path is missing"); -// let seg_size = env::var("SEG_SIZE").unwrap_or("0".to_string()); -// let seg_size = seg_size.parse::<_>().unwrap_or(0); - -// let mut state = load_elf_with_patch(&elf_path, vec![]); - -// let data = Data::new(); -// state.add_input_stream(&data); -// log::info!( -// "enum {} {} {}", -// DataId::TYPE1 as u8, -// DataId::TYPE2 as u8, -// DataId::TYPE3 as u8 -// ); -// log::info!("public input: {:X?}", data); - -// let (total_steps, mut state) = split_prog_into_segs(state, &seg_path, "", seg_size); - -// let value = state.read_public_values::(); -// log::info!("public value: {:X?}", value); - -// let mut seg_num = 1usize; -// if seg_size != 0 { -// seg_num = (total_steps + seg_size - 1) / seg_size; -// } - -// if seg_num == 1 { -// let seg_file = format!("{seg_path}/{}", 0); -// prove_single_seg_common(&seg_file, "", "", "", total_steps) -// } else { -// prove_multi_seg_common(&seg_path, "", "", "", seg_size, seg_num, 0).unwrap() -// } -// } - -// fn prove_host() { -// let host_program = env::var("HOST_PROGRAM").expect("host_program name is missing"); -// match host_program.as_str() { -// "sha2_bench" => prove_sha2_bench(), -// "revm" => prove_revm(), -// "add_example" => prove_add_example(), -// _ => log::error!("Host program {} is not supported!", host_program), -// }; -// } - -// fn prove_segments() { -// let basedir = env::var("BASEDIR").unwrap_or("/tmp/cannon".to_string()); -// let block = env::var("BLOCK_NO").unwrap_or("".to_string()); -// let file = env::var("BLOCK_FILE").unwrap_or(String::from("")); -// let seg_dir = env::var("SEG_FILE_DIR").expect("segment file dir is missing"); -// let seg_num = env::var("SEG_NUM").unwrap_or("1".to_string()); -// let seg_num = seg_num.parse::<_>().unwrap_or(1usize); -// let seg_start_id = env::var("SEG_START_ID").unwrap_or("0".to_string()); -// let seg_start_id = seg_start_id.parse::<_>().unwrap_or(0usize); -// let seg_size = env::var("SEG_SIZE").unwrap_or(format!("{SEGMENT_STEPS}")); -// let seg_size = seg_size.parse::<_>().unwrap_or(SEGMENT_STEPS); - -// if seg_num == 1 { -// let seg_file = format!("{seg_dir}/{}", seg_start_id); -// prove_single_seg_common(&seg_file, &basedir, &block, &file, seg_size) -// } else { -// prove_multi_seg_common( -// &seg_dir, -// &basedir, -// &block, -// &file, -// seg_size, -// seg_num, -// seg_start_id, -// ) -// .unwrap() -// } -// } - -// // fn main() { -// // env_logger::try_init().unwrap_or_default(); -// // let args: Vec = env::args().collect(); -// // let helper = || { -// // log::info!( -// // "Help: {} split | prove_segments | prove_host_program", -// // args[0] -// // ); -// // std::process::exit(-1); -// // }; -// // if args.len() < 2 { -// // helper(); -// // } -// // match args[1].as_str() { -// // "split" => split_segments(), -// // "prove_segments" => prove_segments(), -// // "prove_host_program" => prove_host(), -// // _ => helper(), -// // }; -// // } diff --git a/sdk/src/local/libsnark/compile.sh b/sdk/src/local/libsnark/compile.sh new file mode 100755 index 00000000..0d7ff334 --- /dev/null +++ b/sdk/src/local/libsnark/compile.sh @@ -0,0 +1,23 @@ +#!/bin/bash + +cd "$(dirname "$0")" + +# Determine the operating system +OS="$(uname)" + +case "$OS" in + Linux) + echo "Running on Linux" + # Compile for Linux + go build -o libsnark.so -buildmode=c-shared *.go + ;; + Darwin) + echo "Running on macOS" + # Compile for macOS + go build -o libsnark.dylib -buildmode=c-shared *.go + ;; + *) + echo "Unsupported OS: $OS" + exit 1 + ;; +esac diff --git a/sdk/src/local/libsnark/contract.go b/sdk/src/local/libsnark/contract.go new file mode 100644 index 00000000..6ffe93a4 --- /dev/null +++ b/sdk/src/local/libsnark/contract.go @@ -0,0 +1,215 @@ +package main + +var Gtemplate = `// This file is MIT Licensed. +// +// Copyright 2017 Christian Reitwiessner +// Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: +// The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. +pragma solidity ^0.8.0; +library Pairing { + struct G1Point { + uint X; + uint Y; + } + // Encoding of field elements is: X[0] * z + X[1] + struct G2Point { + uint[2] X; + uint[2] Y; + } + /// @return the generator of G1 + function P1() pure internal returns (G1Point memory) { + return G1Point(1, 2); + } + /// @return the generator of G2 + function P2() pure internal returns (G2Point memory) { + return G2Point( + [10857046999023057135944570762232829481370756359578518086990519993285655852781, + 11559732032986387107991004021392285783925812861821192530917403151452391805634], + [8495653923123431417604973247489272438418190587263600148770280649306958101930, + 4082367875863433681332203403145435568316851327593401208105741076214120093531] + ); + } + /// @return the negation of p, i.e. p.addition(p.negate()) should be zero. + function negate(G1Point memory p) pure internal returns (G1Point memory) { + // The prime q in the base field F_q for G1 + uint q = 21888242871839275222246405745257275088696311157297823662689037894645226208583; + if (p.X == 0 && p.Y == 0) + return G1Point(0, 0); + return G1Point(p.X, q - (p.Y % q)); + } + /// @return r the sum of two points of G1 + function addition(G1Point memory p1, G1Point memory p2) internal view returns (G1Point memory r) { + uint[4] memory input; + input[0] = p1.X; + input[1] = p1.Y; + input[2] = p2.X; + input[3] = p2.Y; + bool success; + assembly { + success := staticcall(sub(gas(), 2000), 6, input, 0xc0, r, 0x60) + // Use "invalid" to make gas estimation work + switch success case 0 { invalid() } + } + require(success); + } + + + /// @return r the product of a point on G1 and a scalar, i.e. + /// p == p.scalar_mul(1) and p.addition(p) == p.scalar_mul(2) for all points p. + function scalar_mul(G1Point memory p, uint s) internal view returns (G1Point memory r) { + uint[3] memory input; + input[0] = p.X; + input[1] = p.Y; + input[2] = s; + bool success; + + assembly { + success := staticcall(sub(gas(), 2000), 7, input, 0x80, r, 0x60) + // Use "invalid" to make gas estimation work + switch success case 0 { invalid() } + } + + require (success); + } + /// @return the result of computing the pairing check + /// e(p1[0], p2[0]) * .... * e(p1[n], p2[n]) == 1 + /// For example pairing([P1(), P1().negate()], [P2(), P2()]) should + /// return true. + function pairing(G1Point[] memory p1, G2Point[] memory p2) internal view returns (bool) { + require(p1.length == p2.length); + uint elements = p1.length; + uint inputSize = elements * 6; + uint[] memory input = new uint[](inputSize); + for (uint i = 0; i < elements; i++) + { + input[i * 6 + 0] = p1[i].X; + input[i * 6 + 1] = p1[i].Y; + input[i * 6 + 2] = p2[i].X[1]; + input[i * 6 + 3] = p2[i].X[0]; + input[i * 6 + 4] = p2[i].Y[1]; + input[i * 6 + 5] = p2[i].Y[0]; + } + uint[1] memory out; + bool success; + + assembly { + success := staticcall(sub(gas(), 2000), 8, add(input, 0x20), mul(inputSize, 0x20), out, 0x20) + // Use "invalid" to make gas estimation work + // switch success case 0 { invalid() } + } + + require(success,"no"); + return out[0] != 0; + } + /// Convenience method for a pairing check for two pairs. + function pairingProd2(G1Point memory a1, G2Point memory a2, G1Point memory b1, G2Point memory b2) internal view returns (bool) { + G1Point[] memory p1 = new G1Point[](2); + G2Point[] memory p2 = new G2Point[](2); + p1[0] = a1; + p1[1] = b1; + p2[0] = a2; + p2[1] = b2; + return pairing(p1, p2); + } + /// Convenience method for a pairing check for three pairs. + function pairingProd3( + G1Point memory a1, G2Point memory a2, + G1Point memory b1, G2Point memory b2, + G1Point memory c1, G2Point memory c2 + ) internal view returns (bool) { + G1Point[] memory p1 = new G1Point[](3); + G2Point[] memory p2 = new G2Point[](3); + p1[0] = a1; + p1[1] = b1; + p1[2] = c1; + p2[0] = a2; + p2[1] = b2; + p2[2] = c2; + return pairing(p1, p2); + } + /// Convenience method for a pairing check for four pairs. + function pairingProd4( + G1Point memory a1, G2Point memory a2, + G1Point memory b1, G2Point memory b2, + G1Point memory c1, G2Point memory c2, + G1Point memory d1, G2Point memory d2 + ) internal view returns (bool) { + G1Point[] memory p1 = new G1Point[](4); + G2Point[] memory p2 = new G2Point[](4); + p1[0] = a1; + p1[1] = b1; + p1[2] = c1; + p1[3] = d1; + p2[0] = a2; + p2[1] = b2; + p2[2] = c2; + p2[3] = d2; + return pairing(p1, p2); + } +} + +contract Verifier { + event VerifyEvent(address user); + event Value(uint x, uint y); + + using Pairing for *; + struct VerifyingKey { + Pairing.G1Point alpha; + Pairing.G2Point beta; + Pairing.G2Point gamma; + Pairing.G2Point delta; + Pairing.G1Point[] gamma_abc; + } + struct Proof { + Pairing.G1Point a; + Pairing.G2Point b; + Pairing.G1Point c; + } + function verifyingKey() pure internal returns (VerifyingKey memory vk) { + vk.alpha = {{.Alpha}}; + vk.beta = {{.Beta}}; + vk.gamma = {{.Gamma}}; + vk.delta = {{.Delta}}; + {{.Gamma_abc}} + } + function verify(uint[65] memory input, Proof memory proof, uint[2] memory proof_commitment) public view returns (uint) { + uint256 snark_scalar_field = 21888242871839275222246405745257275088548364400416034343698204186575808495617; + + VerifyingKey memory vk = verifyingKey(); + require(input.length + 1 == vk.gamma_abc.length); + // Compute the linear combination vk_x + Pairing.G1Point memory vk_x = Pairing.G1Point(0, 0); + for (uint i = 0; i < input.length; i++) { + require(input[i] < snark_scalar_field); + vk_x = Pairing.addition(vk_x, Pairing.scalar_mul(vk.gamma_abc[i + 1], input[i])); + } + Pairing.G1Point memory p_c = Pairing.G1Point(proof_commitment[0], proof_commitment[1]); + + vk_x = Pairing.addition(vk_x, vk.gamma_abc[0]); + vk_x = Pairing.addition(vk_x, p_c); + + if(!Pairing.pairingProd4( + proof.a, proof.b, + Pairing.negate(vk_x), vk.gamma, + Pairing.negate(proof.c), vk.delta, + Pairing.negate(vk.alpha), vk.beta)) { + return 1; + } + + return 0; + } + function verifyTx( + Proof memory proof, uint[65] memory input + ,uint[2] memory proof_commitment) public returns (bool r) { + + if (verify(input, proof , proof_commitment) == 0) { + emit VerifyEvent(msg.sender); + return true; + } else { + return false; + } + + } +} +` diff --git a/sdk/src/local/libsnark/go.mod b/sdk/src/local/libsnark/go.mod new file mode 100644 index 00000000..c9d0d436 --- /dev/null +++ b/sdk/src/local/libsnark/go.mod @@ -0,0 +1,35 @@ +module zkm-project-template/sdk/libsnark + +go 1.21 + +require ( + github.com/consensys/gnark v0.9.1 + github.com/consensys/gnark-crypto v0.12.2-0.20231023220848-538dff926c15 + github.com/succinctlabs/gnark-plonky2-verifier v0.1.0 +) + +replace github.com/consensys/gnark v0.9.1 => github.com/zkMIPS/gnark v0.9.2-0.20240114074717-11112539ed1e + +require ( + github.com/bits-and-blooms/bitset v1.13.0 // indirect + github.com/blang/semver/v4 v4.0.0 // indirect + github.com/consensys/bavard v0.1.13 // indirect + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/fxamacker/cbor/v2 v2.5.0 // indirect + github.com/google/pprof v0.0.0-20230817174616-7a8ec2ada47b // indirect + github.com/ingonyama-zk/icicle v0.0.0-20230928131117-97f0079e5c71 // indirect + github.com/ingonyama-zk/iciclegnark v0.1.0 // indirect + github.com/mattn/go-colorable v0.1.13 // indirect + github.com/mattn/go-isatty v0.0.19 // indirect + github.com/mmcloughlin/addchain v0.4.0 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/rs/zerolog v1.30.0 // indirect + github.com/stretchr/testify v1.8.4 // indirect + github.com/x448/float16 v0.8.4 // indirect + golang.org/x/crypto v0.22.0 // indirect + golang.org/x/exp v0.0.0-20231110203233-9a3e6036ecaa // indirect + golang.org/x/sync v0.7.0 // indirect + golang.org/x/sys v0.19.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect + rsc.io/tmplfunc v0.0.3 // indirect +) diff --git a/sdk/src/local/libsnark/go.sum b/sdk/src/local/libsnark/go.sum new file mode 100644 index 00000000..97722b23 --- /dev/null +++ b/sdk/src/local/libsnark/go.sum @@ -0,0 +1,76 @@ +github.com/bits-and-blooms/bitset v1.13.0 h1:bAQ9OPNFYbGHV6Nez0tmNI0RiEu7/hxlYJRUA0wFAVE= +github.com/bits-and-blooms/bitset v1.13.0/go.mod h1:7hO7Gc7Pp1vODcmWvKMRA9BNmbv6a/7QIWpPxHddWR8= +github.com/blang/semver/v4 v4.0.0 h1:1PFHFE6yCCTv8C1TeyNNarDzntLi7wMI5i/pzqYIsAM= +github.com/blang/semver/v4 v4.0.0/go.mod h1:IbckMUScFkM3pff0VJDNKRiT6TG/YpiHIM2yvyW5YoQ= +github.com/consensys/bavard v0.1.13 h1:oLhMLOFGTLdlda/kma4VOJazblc7IM5y5QPd2A/YjhQ= +github.com/consensys/bavard v0.1.13/go.mod h1:9ItSMtA/dXMAiL7BG6bqW2m3NdSEObYWoH223nGHukI= +github.com/consensys/gnark-crypto v0.12.2-0.20231023220848-538dff926c15 h1:fu5ienFKWWqrfMPbWnhw4zfIFZW3pzVIbv3KtASymbU= +github.com/consensys/gnark-crypto v0.12.2-0.20231023220848-538dff926c15/go.mod h1:v2Gy7L/4ZRosZ7Ivs+9SfUDr0f5UlG+EM5t7MPHiLuY= +github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/fxamacker/cbor/v2 v2.5.0 h1:oHsG0V/Q6E/wqTS2O1Cozzsy69nqCiguo5Q1a1ADivE= +github.com/fxamacker/cbor/v2 v2.5.0/go.mod h1:TA1xS00nchWmaBnEIxPSE5oHLuJBAVvqrtAnWBwBCVo= +github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= +github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= +github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/google/pprof v0.0.0-20230817174616-7a8ec2ada47b h1:h9U78+dx9a4BKdQkBBos92HalKpaGKHrp+3Uo6yTodo= +github.com/google/pprof v0.0.0-20230817174616-7a8ec2ada47b/go.mod h1:czg5+yv1E0ZGTi6S6vVK1mke0fV+FaUhNGcd6VRS9Ik= +github.com/google/subcommands v1.2.0/go.mod h1:ZjhPrFU+Olkh9WazFPsl27BQ4UPiG37m3yTrtFlrHVk= +github.com/ingonyama-zk/icicle v0.0.0-20230928131117-97f0079e5c71 h1:YxI1RTPzpFJ3MBmxPl3Bo0F7ume7CmQEC1M9jL6CT94= +github.com/ingonyama-zk/icicle v0.0.0-20230928131117-97f0079e5c71/go.mod h1:kAK8/EoN7fUEmakzgZIYdWy1a2rBnpCaZLqSHwZWxEk= +github.com/ingonyama-zk/iciclegnark v0.1.0 h1:88MkEghzjQBMjrYRJFxZ9oR9CTIpB8NG2zLeCJSvXKQ= +github.com/ingonyama-zk/iciclegnark v0.1.0/go.mod h1:wz6+IpyHKs6UhMMoQpNqz1VY+ddfKqC/gRwR/64W6WU= +github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= +github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/leanovate/gopter v0.2.9 h1:fQjYxZaynp97ozCzfOyOuAGOU4aU/z37zf/tOujFk7c= +github.com/leanovate/gopter v0.2.9/go.mod h1:U2L/78B+KVFIx2VmW6onHJQzXtFb+p5y3y2Sh+Jxxv8= +github.com/mattn/go-colorable v0.1.12/go.mod h1:u5H1YNBxpqRaxsYJYSkiCWKzEfiAb1Gb520KVy5xxl4= +github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA= +github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= +github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94= +github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= +github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA= +github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/mmcloughlin/addchain v0.4.0 h1:SobOdjm2xLj1KkXN5/n0xTIWyZA2+s99UCY1iPfkHRY= +github.com/mmcloughlin/addchain v0.4.0/go.mod h1:A86O+tHqZLMNO4w6ZZ4FlVQEadcoqkyU72HC5wJ4RlU= +github.com/mmcloughlin/profile v0.1.1/go.mod h1:IhHD7q1ooxgwTgjxQYkACGA77oFTDdFVejUS1/tS/qU= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/rogpeppe/go-internal v1.11.0 h1:cWPaGQEPrBb5/AsnsZesgZZ9yb1OQ+GOISoDNXVBh4M= +github.com/rogpeppe/go-internal v1.11.0/go.mod h1:ddIwULY96R17DhadqLgMfk9H9tvdUzkipdSkR5nkCZA= +github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg= +github.com/rs/zerolog v1.30.0 h1:SymVODrcRsaRaSInD9yQtKbtWqwsfoPcRff/oRXLj4c= +github.com/rs/zerolog v1.30.0/go.mod h1:/tk+P47gFdPXq4QYjvCmT5/Gsug2nagsFWBWhAiSi1w= +github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= +github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +github.com/succinctlabs/gnark-plonky2-verifier v0.1.0 h1:5mohIEl5iZj1CIFgX4fOwqc18AQFPiOQTypoJb+OPyk= +github.com/succinctlabs/gnark-plonky2-verifier v0.1.0/go.mod h1:c144MdRU1b0w/khA+lTrTFGcRHiKp1obwv8VGv/LQzI= +github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM= +github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg= +github.com/zkMIPS/gnark v0.9.2-0.20240114074717-11112539ed1e h1:T2nl8q9iCxNe5ay+D0OAC+HrQaIj5PmzkbJVuJbn7lk= +github.com/zkMIPS/gnark v0.9.2-0.20240114074717-11112539ed1e/go.mod h1:s681Fp+KgDAK1Ix0FQgTMkmGOr9yjJLF9M8DOt5TKtA= +golang.org/x/crypto v0.22.0 h1:g1v0xeRhjcugydODzvb3mEM9SQ0HGp9s/nh3COQ/C30= +golang.org/x/crypto v0.22.0/go.mod h1:vr6Su+7cTlO45qkww3VDJlzDn0ctJvRgYbC2NvXHt+M= +golang.org/x/exp v0.0.0-20231110203233-9a3e6036ecaa h1:FRnLl4eNAQl8hwxVVC17teOw8kdjVDVAiFMtgUdTSRQ= +golang.org/x/exp v0.0.0-20231110203233-9a3e6036ecaa/go.mod h1:zk2irFbV9DP96SEBUUAy67IdHUaZuSnrz1n472HUCLE= +golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M= +golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20210927094055-39ccf1dd6fa6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.19.0 h1:q5f1RH2jigJ1MoAWp2KTp3gm5zAGFUTarQZ5U386+4o= +golang.org/x/sys v0.19.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= +gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= +gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +rsc.io/tmplfunc v0.0.3 h1:53XFQh69AfOa8Tw0Jm7t+GV7KZhOi6jzsCzTtKbMvzU= +rsc.io/tmplfunc v0.0.3/go.mod h1:AG3sTPzElb1Io3Yg4voV9AGZJuleGAwaVRxL9M49PhA= diff --git a/sdk/src/local/libsnark/libsnark.go b/sdk/src/local/libsnark/libsnark.go new file mode 100644 index 00000000..71277657 --- /dev/null +++ b/sdk/src/local/libsnark/libsnark.go @@ -0,0 +1,22 @@ +package main + +import ( + "C" +) +import "fmt" + +//export Stark2Snark +func Stark2Snark(inputdir *C.char, outputdir *C.char) C.int { + // Convert C strings to Go strings + inputDir := C.GoString(inputdir) + outputDir := C.GoString(outputdir) + var prover SnarkProver + err := prover.Prove(inputDir, outputDir) + if err != nil { + fmt.Printf("Stark2Snark error: %v\n", err) + return -1 + } + return 0 +} + +func main() {} diff --git a/sdk/src/local/libsnark/snark_prover.go b/sdk/src/local/libsnark/snark_prover.go new file mode 100644 index 00000000..0e8a7f58 --- /dev/null +++ b/sdk/src/local/libsnark/snark_prover.go @@ -0,0 +1,223 @@ +package main + +import ( + "bytes" + "encoding/json" + "fmt" + "path/filepath" + "text/template" + + "math/big" + "os" + "time" + + "github.com/succinctlabs/gnark-plonky2-verifier/types" + "github.com/succinctlabs/gnark-plonky2-verifier/variables" + "github.com/succinctlabs/gnark-plonky2-verifier/verifier" + + "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark/backend/groth16" + groth16_bn254 "github.com/consensys/gnark/backend/groth16/bn254" + "github.com/consensys/gnark/constraint" + "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/frontend/cs/r1cs" +) + +type SnarkProver struct { + r1cs_circuit constraint.ConstraintSystem + pk groth16.ProvingKey + vk groth16.VerifyingKey +} + +func (obj *SnarkProver) init_circuit_keys(inputdir string) error { + if obj.r1cs_circuit != nil { + return nil + } + + circuitPath := inputdir + "/circuit" + pkPath := inputdir + "/proving.key" + vkPath := inputdir + "/verifying.key" + _, err := os.Stat(circuitPath) + + if os.IsNotExist(err) { + commonCircuitData := types.ReadCommonCircuitData(inputdir + "/common_circuit_data.json") + proofWithPisData := types.ReadProofWithPublicInputs(inputdir + "/proof_with_public_inputs.json") + proofWithPis := variables.DeserializeProofWithPublicInputs(proofWithPisData) + + verifierOnlyCircuitRawData := types.ReadVerifierOnlyCircuitData(inputdir + "/verifier_only_circuit_data.json") + verifierOnlyCircuitData := variables.DeserializeVerifierOnlyCircuitData(verifierOnlyCircuitRawData) + + circuit := verifier.ExampleVerifierCircuit{ + Proof: proofWithPis.Proof, + PublicInputs: proofWithPis.PublicInputs, + VerifierOnlyCircuitData: verifierOnlyCircuitData, + CommonCircuitData: commonCircuitData, + } + + var builder frontend.NewBuilder = r1cs.NewBuilder + obj.r1cs_circuit, _ = frontend.Compile(ecc.BN254.ScalarField(), builder, &circuit) + fR1CS, _ := os.Create(circuitPath) + obj.r1cs_circuit.WriteTo(fR1CS) + fR1CS.Close() + } else { + fCircuit, err := os.Open(circuitPath) + if err != nil { + fmt.Println(err) + os.Exit(1) + } + + obj.r1cs_circuit = groth16.NewCS(ecc.BN254) + obj.r1cs_circuit.ReadFrom(fCircuit) + fCircuit.Close() + } + + _, err = os.Stat(pkPath) + if os.IsNotExist(err) { + obj.pk, obj.vk, err = groth16.Setup(obj.r1cs_circuit) + if err != nil { + fmt.Println(err) + os.Exit(1) + } + + fPK, _ := os.Create(pkPath) + obj.pk.WriteTo(fPK) + fPK.Close() + + if obj.vk != nil { + fVK, _ := os.Create(vkPath) + obj.vk.WriteTo(fVK) + fVK.Close() + } + } else { + obj.pk = groth16.NewProvingKey(ecc.BN254) + obj.vk = groth16.NewVerifyingKey(ecc.BN254) + fPk, err := os.Open(pkPath) + if err != nil { + fmt.Println(err) + os.Exit(1) + } + obj.pk.ReadFrom(fPk) + + fVk, err := os.Open(vkPath) + if err != nil { + fmt.Println(err) + os.Exit(1) + } + obj.vk.ReadFrom(fVk) + defer fVk.Close() + } + return nil +} + +func (obj *SnarkProver) groth16ProofWithCache(r1cs constraint.ConstraintSystem, inputdir, outputdir string) error { + proofWithPisData := types.ReadProofWithPublicInputs(inputdir + "/proof_with_public_inputs.json") + proofWithPis := variables.DeserializeProofWithPublicInputs(proofWithPisData) + + verifierOnlyCircuitRawData := types.ReadVerifierOnlyCircuitData(inputdir + "/verifier_only_circuit_data.json") + verifierOnlyCircuitData := variables.DeserializeVerifierOnlyCircuitData(verifierOnlyCircuitRawData) + + assignment := verifier.ExampleVerifierCircuit{ + Proof: proofWithPis.Proof, + PublicInputs: proofWithPis.PublicInputs, + VerifierOnlyCircuitData: verifierOnlyCircuitData, + } + + start := time.Now() + fmt.Println("Generating witness", start) + witness, _ := frontend.NewWitness(&assignment, ecc.BN254.ScalarField()) + fmt.Printf("frontend.NewWitness cost time: %v ms\n", time.Since(start).Milliseconds()) + publicWitness, _ := witness.Public() + + start = time.Now() + fmt.Println("Creating proof", start) + proof, err := groth16.Prove(r1cs, obj.pk, witness) + fmt.Printf("groth16.Prove cost time: %v ms\n", time.Since(start).Milliseconds()) + if err != nil { + return err + } + + if obj.vk == nil { + return fmt.Errorf("vk is nil, means you're using dummy setup and we skip verification of proof") + } + + start = time.Now() + fmt.Println("Verifying proof", start) + err = groth16.Verify(proof, obj.vk, publicWitness) + fmt.Printf("groth16.Verify cost time: %v ms\n", time.Since(start).Milliseconds()) + if err != nil { + return err + } + + fContractProof, _ := os.Create(outputdir + "/snark_proof_with_public_inputs.json") + _, bPublicWitness, _, _ := groth16.GetBn254Witness(proof, obj.vk, publicWitness) + nbInputs := len(bPublicWitness) + + type ProofPublicData struct { + Proof groth16.Proof + PublicWitness []string + } + proofPublicData := ProofPublicData{ + Proof: proof, + PublicWitness: make([]string, nbInputs), + } + for i := 0; i < nbInputs; i++ { + input := new(big.Int) + bPublicWitness[i].BigInt(input) + proofPublicData.PublicWitness[i] = input.String() + } + proofData, _ := json.Marshal(proofPublicData) + fContractProof.Write(proofData) + fContractProof.Close() + return nil +} + +func (obj *SnarkProver) generateVerifySol(outputDir string) error { + tmpl, err := template.New("contract").Parse(Gtemplate) + if err != nil { + return err + } + + type VerifyingKeyConfig struct { + Alpha string + Beta string + Gamma string + Delta string + Gamma_abc string + } + + var config VerifyingKeyConfig + vk := obj.vk.(*groth16_bn254.VerifyingKey) + + config.Alpha = fmt.Sprint("Pairing.G1Point(uint256(", vk.G1.Alpha.X.String(), "), uint256(", vk.G1.Alpha.Y.String(), "))") + config.Beta = fmt.Sprint("Pairing.G2Point([uint256(", vk.G2.Beta.X.A0.String(), "), uint256(", vk.G2.Beta.X.A1.String(), ")], [uint256(", vk.G2.Beta.Y.A0.String(), "), uint256(", vk.G2.Beta.Y.A1.String(), ")])") + config.Gamma = fmt.Sprint("Pairing.G2Point([uint256(", vk.G2.Gamma.X.A0.String(), "), uint256(", vk.G2.Gamma.X.A1.String(), ")], [uint256(", vk.G2.Gamma.Y.A0.String(), "), uint256(", vk.G2.Gamma.Y.A1.String(), ")])") + config.Delta = fmt.Sprint("Pairing.G2Point([uint256(", vk.G2.Delta.X.A0.String(), "), uint256(", vk.G2.Delta.X.A1.String(), ")], [uint256(", vk.G2.Delta.Y.A0.String(), "), uint256(", vk.G2.Delta.Y.A1.String(), ")])") + config.Gamma_abc = fmt.Sprint("vk.gamma_abc = new Pairing.G1Point[](", len(vk.G1.K), ");\n") + for k, v := range vk.G1.K { + config.Gamma_abc += fmt.Sprint(" vk.gamma_abc[", k, "] = Pairing.G1Point(uint256(", v.X.String(), "), uint256(", v.Y.String(), "));\n") + } + var buf bytes.Buffer + err = tmpl.Execute(&buf, config) + if err != nil { + return err + } + fSol, _ := os.Create(filepath.Join(outputDir, "verifier.sol")) + _, err = fSol.Write(buf.Bytes()) + if err != nil { + return err + } + fSol.Close() + return nil +} + +func (obj *SnarkProver) Prove(inputdir string, outputdir string) error { + if err := obj.init_circuit_keys(inputdir); err != nil { + return err + } + + if err := obj.generateVerifySol(outputdir); err != nil { + return err + } + + return obj.groth16ProofWithCache(obj.r1cs_circuit, inputdir, outputdir) +} diff --git a/sdk/src/local/mod.rs b/sdk/src/local/mod.rs index 8b137891..003fe2bf 100644 --- a/sdk/src/local/mod.rs +++ b/sdk/src/local/mod.rs @@ -1 +1,5 @@ - +#[allow(clippy::module_inception)] +pub mod prover; +pub mod snark; +pub mod stark; +pub mod util; diff --git a/sdk/src/local/prover.rs b/sdk/src/local/prover.rs new file mode 100644 index 00000000..ae4926d8 --- /dev/null +++ b/sdk/src/local/prover.rs @@ -0,0 +1,120 @@ +use std::collections::HashMap; +use std::sync::{Arc, Mutex}; +use std::thread; +use std::time::Duration; + +use crate::prover::{Prover, ProverInput, ProverResult}; +use async_trait::async_trait; +use std::fs; +use std::time::Instant; + +pub struct ProverTask { + proof_id: String, + input: ProverInput, + result: Option, + is_done: bool, +} + +impl ProverTask { + fn new(proof_id: &str, input: &ProverInput) -> ProverTask { + ProverTask { + proof_id: proof_id.to_string(), + input: input.clone(), + result: None, + is_done: false, + } + } + + fn run(&mut self) { + let mut result = ProverResult::default(); + let inputdir = format!("/tmp/{}/input", self.proof_id); + let outputdir = format!("/tmp/{}/output", self.proof_id); + fs::create_dir_all(&inputdir).unwrap(); + fs::create_dir_all(&outputdir).unwrap(); + crate::local::stark::prove_stark(&self.input, &inputdir, &mut result); + if self.input.execute_only { + result.proof_with_public_inputs = vec![]; + result.stark_proof = vec![]; + result.solidity_verifier = vec![]; + } else if crate::local::snark::prove_snark(&inputdir, &outputdir) { + result.stark_proof = + std::fs::read(format!("{}/proof_with_public_inputs.json", inputdir)).unwrap(); + result.proof_with_public_inputs = + std::fs::read(format!("{}/snark_proof_with_public_inputs.json", outputdir)) + .unwrap(); + result.solidity_verifier = + std::fs::read(format!("{}/verifier.sol", outputdir)).unwrap(); + } else { + log::error!("Failed to generate snark proof."); + } + self.result = Some(result); + self.is_done = true; + } + + fn is_done(&self) -> bool { + self.is_done + } +} + +pub struct LocalProver { + tasks: Arc>>>>, +} + +impl Default for LocalProver { + fn default() -> Self { + Self::new() + } +} + +impl LocalProver { + pub fn new() -> LocalProver { + LocalProver { + tasks: Arc::new(Mutex::new(HashMap::new())), + } + } +} + +#[async_trait] +impl Prover for LocalProver { + async fn request_proof<'a>(&self, input: &'a ProverInput) -> anyhow::Result { + let proof_id: String = uuid::Uuid::new_v4().to_string(); + let task: Arc> = Arc::new(Mutex::new(ProverTask::new(&proof_id, input))); + self.tasks + .lock() + .unwrap() + .insert(proof_id.clone(), task.clone()); + thread::spawn(move || { + task.lock().unwrap().run(); + }); + Ok(proof_id) + } + + async fn wait_proof<'a>( + &self, + proof_id: &'a str, + timeout: Option, + ) -> anyhow::Result> { + let task = self.tasks.lock().unwrap().get(proof_id).unwrap().clone(); + let start_time = Instant::now(); + loop { + if let Some(timeout) = timeout { + if start_time.elapsed() > timeout { + return Err(anyhow::anyhow!("Proof generation timed out.")); + } + } + if task.lock().unwrap().is_done() { + self.tasks.lock().unwrap().remove(proof_id); + return Ok(task.lock().unwrap().result.clone()); + } + } + } + + async fn prove<'a>( + &self, + input: &'a ProverInput, + timeout: Option, + ) -> anyhow::Result> { + let proof_id = self.request_proof(input).await?; + self.wait_proof(&proof_id, timeout).await + } +} diff --git a/sdk/src/local/snark.rs b/sdk/src/local/snark.rs new file mode 100644 index 00000000..43c1ac4f --- /dev/null +++ b/sdk/src/local/snark.rs @@ -0,0 +1,14 @@ +extern crate libc; +use libc::c_int; +use std::os::raw::c_char; + +extern "C" { + fn Stark2Snark(inputdir: *const c_char, outputdir: *const c_char) -> c_int; +} + +pub fn prove_snark(inputdir: &str, outputdir: &str) -> bool { + let inputdir = std::ffi::CString::new(inputdir).unwrap(); + let outputdir = std::ffi::CString::new(outputdir).unwrap(); + let ret = unsafe { Stark2Snark(inputdir.as_ptr(), outputdir.as_ptr()) }; + ret == 0 +} diff --git a/sdk/src/local/stark.rs b/sdk/src/local/stark.rs new file mode 100644 index 00000000..ed46a109 --- /dev/null +++ b/sdk/src/local/stark.rs @@ -0,0 +1,37 @@ +use super::util; +use crate::prover::{ProverInput, ProverResult}; +use elf::{endian::AnyEndian, ElfBytes}; +use zkm_emulator::state::State; +use zkm_emulator::utils::split_prog_into_segs; + +pub fn prove_stark(input: &ProverInput, storedir: &str, result: &mut ProverResult) { + let seg_path = format!("{}/segments", storedir); + let seg_size = input.seg_size as usize; + let file = ElfBytes::::minimal_parse(input.elf.as_slice()) + .expect("opening elf file failed"); + let mut state = State::load_elf(&file); + state.patch_elf(&file); + state.patch_stack(vec![]); + + state.add_input_stream(&input.public_inputstream); + state.add_input_stream(&input.private_inputstream); + + let (total_steps, state) = split_prog_into_segs(state, &seg_path, "", seg_size); + result + .output_stream + .copy_from_slice(&state.public_values_stream); + if input.execute_only { + return; + } + + let mut seg_num = 1usize; + if seg_size != 0 { + seg_num = (total_steps + seg_size - 1) / seg_size; + } + if seg_num == 1 { + let seg_file = format!("{seg_path}/{}", 0); + util::prove_single_seg_common(&seg_file, "", "", "", total_steps) + } else { + util::prove_multi_seg_common(&seg_path, "", "", "", storedir, seg_size, seg_num, 0).unwrap() + } +} diff --git a/sdk/src/local/util.rs b/sdk/src/local/util.rs new file mode 100644 index 00000000..1ae4601a --- /dev/null +++ b/sdk/src/local/util.rs @@ -0,0 +1,223 @@ +use std::fs::File; +use std::io::BufReader; +use std::ops::Range; +use std::time::Duration; + +use plonky2::field::goldilocks_field::GoldilocksField; +use plonky2::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; +use plonky2::util::timing::TimingTree; +use plonky2x::backend::circuit::Groth16WrapperParameters; +use plonky2x::backend::wrapper::wrap::WrappedCircuit; +use plonky2x::frontend::builder::CircuitBuilder as WrapperBuilder; +use plonky2x::prelude::DefaultParameters; +use zkm_prover::all_stark::AllStark; +use zkm_prover::config::StarkConfig; +use zkm_prover::cpu::kernel::assembler::segment_kernel; +use zkm_prover::fixed_recursive_verifier::AllRecursiveCircuits; +use zkm_prover::proof; +use zkm_prover::proof::PublicValues; +use zkm_prover::prover::prove; +use zkm_prover::verifier::verify_proof; + +const DEGREE_BITS_RANGE: [Range; 6] = [10..21, 12..22, 12..21, 8..21, 6..21, 13..23]; + +pub fn prove_single_seg_common( + seg_file: &str, + basedir: &str, + block: &str, + file: &str, + seg_size: usize, +) { + let seg_reader = BufReader::new(File::open(seg_file).unwrap()); + let kernel = segment_kernel(basedir, block, file, seg_reader, seg_size); + + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + type F = >::F; + + let allstark: AllStark = AllStark::default(); + let config = StarkConfig::standard_fast_config(); + let mut timing = TimingTree::new("prove", log::Level::Info); + let allproof: proof::AllProof = + prove(&allstark, &kernel, &config, &mut timing).unwrap(); + let mut count_bytes = 0; + for (row, proof) in allproof.stark_proofs.clone().iter().enumerate() { + let proof_str = serde_json::to_string(&proof.proof).unwrap(); + log::info!("row:{} proof bytes:{}", row, proof_str.len()); + count_bytes += proof_str.len(); + } + timing.filter(Duration::from_millis(100)).print(); + log::info!("total proof bytes:{}KB", count_bytes / 1024); + verify_proof(&allstark, allproof, &config).unwrap(); + log::info!("Prove done"); +} + +#[allow(clippy::too_many_arguments)] +pub fn prove_multi_seg_common( + seg_dir: &str, + basedir: &str, + block: &str, + file: &str, + outdir: &str, + seg_size: usize, + seg_file_number: usize, + seg_start_id: usize, +) -> anyhow::Result<()> { + type InnerParameters = DefaultParameters; + type OuterParameters = Groth16WrapperParameters; + + type F = GoldilocksField; + const D: usize = 2; + type C = PoseidonGoldilocksConfig; + + if seg_file_number < 2 { + panic!("seg file number must >= 2\n"); + } + + let total_timing = TimingTree::new("prove total time", log::Level::Info); + let all_stark = AllStark::::default(); + let config = StarkConfig::standard_fast_config(); + // Preprocess all circuits. + let all_circuits = + AllRecursiveCircuits::::new(&all_stark, &DEGREE_BITS_RANGE, &config); + + let seg_file = format!("{}/{}", seg_dir, seg_start_id); + log::info!("Process segment {}", seg_file); + let seg_reader = BufReader::new(File::open(seg_file)?); + let input_first = segment_kernel(basedir, block, file, seg_reader, seg_size); + let mut timing = TimingTree::new("prove root first", log::Level::Info); + let (mut agg_proof, mut updated_agg_public_values) = + all_circuits.prove_root(&all_stark, &input_first, &config, &mut timing)?; + + timing.filter(Duration::from_millis(100)).print(); + all_circuits.verify_root(agg_proof.clone())?; + + let mut base_seg = seg_start_id + 1; + let mut seg_num = seg_file_number - 1; + let mut is_agg: bool = false; + + if seg_file_number % 2 == 0 { + let seg_file = format!("{}/{}", seg_dir, seg_start_id + 1); + log::info!("Process segment {}", seg_file); + let seg_reader = BufReader::new(File::open(seg_file)?); + let input = segment_kernel(basedir, block, file, seg_reader, seg_size); + timing = TimingTree::new("prove root second", log::Level::Info); + let (root_proof, public_values) = + all_circuits.prove_root(&all_stark, &input, &config, &mut timing)?; + timing.filter(Duration::from_millis(100)).print(); + + all_circuits.verify_root(root_proof.clone())?; + + // Update public values for the aggregation. + let agg_public_values = PublicValues { + roots_before: updated_agg_public_values.roots_before, + roots_after: public_values.roots_after, + userdata: public_values.userdata, + }; + timing = TimingTree::new("prove aggression", log::Level::Info); + // We can duplicate the proofs here because the state hasn't mutated. + (agg_proof, updated_agg_public_values) = all_circuits.prove_aggregation( + false, + &agg_proof, + false, + &root_proof, + agg_public_values.clone(), + )?; + timing.filter(Duration::from_millis(100)).print(); + all_circuits.verify_aggregation(&agg_proof)?; + + is_agg = true; + base_seg = seg_start_id + 2; + seg_num -= 1; + } + + for i in 0..seg_num / 2 { + let seg_file = format!("{}/{}", seg_dir, base_seg + (i << 1)); + log::info!("Process segment {}", seg_file); + let seg_reader = BufReader::new(File::open(&seg_file)?); + let input_first = segment_kernel(basedir, block, file, seg_reader, seg_size); + let mut timing = TimingTree::new("prove root first", log::Level::Info); + let (root_proof_first, first_public_values) = + all_circuits.prove_root(&all_stark, &input_first, &config, &mut timing)?; + + timing.filter(Duration::from_millis(100)).print(); + all_circuits.verify_root(root_proof_first.clone())?; + + let seg_file = format!("{}/{}", seg_dir, base_seg + (i << 1) + 1); + log::info!("Process segment {}", seg_file); + let seg_reader = BufReader::new(File::open(&seg_file)?); + let input = segment_kernel(basedir, block, file, seg_reader, seg_size); + let mut timing = TimingTree::new("prove root second", log::Level::Info); + let (root_proof, public_values) = + all_circuits.prove_root(&all_stark, &input, &config, &mut timing)?; + timing.filter(Duration::from_millis(100)).print(); + + all_circuits.verify_root(root_proof.clone())?; + + // Update public values for the aggregation. + let new_agg_public_values = PublicValues { + roots_before: first_public_values.roots_before, + roots_after: public_values.roots_after, + userdata: public_values.userdata, + }; + timing = TimingTree::new("prove aggression", log::Level::Info); + // We can duplicate the proofs here because the state hasn't mutated. + let (new_agg_proof, new_updated_agg_public_values) = all_circuits.prove_aggregation( + false, + &root_proof_first, + false, + &root_proof, + new_agg_public_values, + )?; + timing.filter(Duration::from_millis(100)).print(); + all_circuits.verify_aggregation(&new_agg_proof)?; + + // Update public values for the nested aggregation. + let agg_public_values = PublicValues { + roots_before: updated_agg_public_values.roots_before, + roots_after: new_updated_agg_public_values.roots_after, + userdata: new_updated_agg_public_values.userdata, + }; + timing = TimingTree::new("prove nested aggression", log::Level::Info); + + // We can duplicate the proofs here because the state hasn't mutated. + (agg_proof, updated_agg_public_values) = all_circuits.prove_aggregation( + is_agg, + &agg_proof, + true, + &new_agg_proof, + agg_public_values.clone(), + )?; + is_agg = true; + timing.filter(Duration::from_millis(100)).print(); + + all_circuits.verify_aggregation(&agg_proof)?; + } + + let (block_proof, _block_public_values) = + all_circuits.prove_block(None, &agg_proof, updated_agg_public_values)?; + + log::info!( + "proof size: {:?}", + serde_json::to_string(&block_proof.proof).unwrap().len() + ); + let result = all_circuits.verify_block(&block_proof); + + let builder = WrapperBuilder::::new(); + let mut circuit = builder.build(); + circuit.set_data(all_circuits.block.circuit); + let mut bit_size = vec![32usize; 16]; + bit_size.extend(vec![8; 32]); + bit_size.extend(vec![64; 68]); + let wrapped_circuit = WrappedCircuit::::build( + circuit, + Some((vec![], bit_size)), + ); + log::info!("build finish"); + + let wrapped_proof = wrapped_circuit.prove(&block_proof).unwrap(); + wrapped_proof.save(outdir).unwrap(); + + total_timing.filter(Duration::from_millis(100)).print(); + result +} diff --git a/sdk/src/prover.rs b/sdk/src/prover.rs index 1bcf5fe6..09dff3ae 100644 --- a/sdk/src/prover.rs +++ b/sdk/src/prover.rs @@ -3,7 +3,7 @@ use serde::Deserialize; use serde::Serialize; use tokio::time::Duration; -#[derive(Debug, Default, Deserialize, Serialize)] +#[derive(Debug, Default, Deserialize, Serialize, Clone)] pub struct ProverInput { pub elf: Vec, pub public_inputstream: Vec, @@ -12,7 +12,7 @@ pub struct ProverInput { pub execute_only: bool, } -#[derive(Debug, Default, Deserialize, Serialize)] +#[derive(Debug, Default, Deserialize, Serialize, Clone)] pub struct ProverResult { pub output_stream: Vec, pub proof_with_public_inputs: Vec,