diff --git a/Cargo.toml b/Cargo.toml index d8194f96..93b093b7 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,7 @@ [workspace] members = [ - "host-program" + "host-program", + "sdk" ] resolver = "2" \ No newline at end of file diff --git a/rust-toolchain.toml b/rust-toolchain.toml new file mode 100644 index 00000000..e3d87d08 --- /dev/null +++ b/rust-toolchain.toml @@ -0,0 +1,2 @@ +[toolchain] +channel = "nightly-2024-07-20" diff --git a/sdk/Cargo.toml b/sdk/Cargo.toml new file mode 100644 index 00000000..4b5b8e8e --- /dev/null +++ b/sdk/Cargo.toml @@ -0,0 +1,64 @@ +[package] +name = "zkm-sdk" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +##zkm-emulator = { path = "../emulator" } +#plonky2 = { path = "../plonky2/plonky2" } +##starky = { path = "../plonky2/starky" } +#plonky2_util = { path = "../plonky2/util" } +#plonky2_maybe_rayon = { path = "../plonky2/maybe_rayon" } + +bincode = "1.3.3" + +async-trait = "0.1" + +stage-service = {package = "service", git = "https://github.com/zkMIPS/zkm-prover", branch = "main", default-features = false } +zkm-prover = { git = "https://github.com/zkMIPS/zkm", branch = "main", default-features = false } +zkm-emulator = { git = "https://github.com/zkMIPS/zkm", branch = "main", default-features = false } +common = { git = "https://github.com/zkMIPS/zkm-prover", branch = "main", default-features = false } +plonky2 = { git = "https://github.com/zkMIPS/plonky2.git", branch = "zkm_dev" } +#starky = { git = "https://github.com/zkMIPS/plonky2.git", branch = "zkm_dev" } +plonky2_util = { git = "https://github.com/zkMIPS/plonky2.git", branch = "zkm_dev" } +plonky2_maybe_rayon = { git = "https://github.com/zkMIPS/plonky2.git", branch = "zkm_dev" } + +tonic = "0.8.1" +prost = "0.11.0" +tokio = { version = "1.21.0", features = ["macros", "rt-multi-thread", "signal"] } +ethers = "2.0.14" + +itertools = "0.11.0" +log = { version = "0.4.14", default-features = false } +anyhow = "1.0.75" +num = "0.4.0" +num-bigint = "0.4.3" +serde = { version = "1.0.144", features = ["derive"] } +serde_json = "1.0" +tiny-keccak = "2.0.2" +rand = "0.8.5" +rand_chacha = "0.3.1" +once_cell = "1.13.0" +static_assertions = "1.1.0" +byteorder = "1.5.0" +hex = "0.4" +hashbrown = { version = "0.14.0", default-features = false, features = ["ahash", "serde"] } # NOTE: When upgrading, see `ahash` dependency. +lazy_static = "1.4.0" + +elf = { version = "0.7", default-features = false } +uuid = { version = "1.2", features = ["v4", "fast-rng", "macro-diagnostics"] } + +##[dev-dependencies] +env_logger = "0.10.0" +keccak-hash = "0.10.0" +plonky2x = { git = "https://github.com/zkMIPS/succinctx.git", package = "plonky2x", branch = "zkm" } +plonky2x-derive = { git = "https://github.com/zkMIPS/succinctx.git", package = "plonky2x-derive", branch = "zkm" } + +[build-dependencies] +tonic-build = "0.8.0" + +[features] +test = [] + diff --git a/sdk/build.rs b/sdk/build.rs new file mode 100644 index 00000000..168f8964 --- /dev/null +++ b/sdk/build.rs @@ -0,0 +1,4 @@ +fn main() -> Result<(), Box> { + tonic_build::compile_protos("src/proto/stage.proto")?; + Ok(()) +} \ No newline at end of file diff --git a/sdk/src/lib.rs b/sdk/src/lib.rs new file mode 100644 index 00000000..c5ed955f --- /dev/null +++ b/sdk/src/lib.rs @@ -0,0 +1,43 @@ + +pub mod prover; +pub mod local; +pub mod network; + +use std::env; +use prover::Prover; +use network::prover::NetworkProver; + +pub struct ProverClient { + pub prover: Box, +} + +impl ProverClient { + pub fn new() -> Self { + #[allow(unreachable_code)] + match env::var("ZKM_PROVER").unwrap_or("network".to_string()).to_lowercase().as_str() { + // "local" => Self { + // prover: Box::new(CpuProver::new()), + // }, + "network" => Self { + prover: Box::new(NetworkProver::default()), + }, + _ => panic!( + "invalid value for ZKM_PROVER enviroment variable: expected 'local', or 'network'" + ), + } + } + + // pub fn local() -> Self { + // Self { prover: Box::new(CpuProver::new()) } + // } + + pub fn network() -> Self { + Self { prover: Box::new(NetworkProver::default()) } + } +} + +impl Default for ProverClient { + fn default() -> Self { + Self::new() + } +} \ No newline at end of file diff --git a/sdk/src/local/cpu.rs b/sdk/src/local/cpu.rs new file mode 100644 index 00000000..c62a0c0e --- /dev/null +++ b/sdk/src/local/cpu.rs @@ -0,0 +1,455 @@ +// use serde::{Deserialize, Serialize}; +// use std::env; +// use std::fs::File; +// use std::io::prelude::*; +// use std::io::BufReader; +// use std::ops::Range; +// use std::time::Duration; + +// use plonky2::field::goldilocks_field::GoldilocksField; +// use plonky2::plonk::config::{GenericConfig, PoseidonGoldilocksConfig}; +// use plonky2::util::timing::TimingTree; +// use plonky2x::backend::circuit::Groth16WrapperParameters; +// use plonky2x::backend::wrapper::wrap::WrappedCircuit; +// use plonky2x::frontend::builder::CircuitBuilder as WrapperBuilder; +// use plonky2x::prelude::DefaultParameters; +// use zkm_emulator::utils::{ +// get_block_path, load_elf_with_patch, split_prog_into_segs, SEGMENT_STEPS, +// }; +// use zkm_prover::all_stark::AllStark; +// use zkm_prover::config::StarkConfig; +// use zkm_prover::cpu::kernel::assembler::segment_kernel; +// use zkm_prover::fixed_recursive_verifier::AllRecursiveCircuits; +// use zkm_prover::proof; +// use zkm_prover::proof::PublicValues; +// use zkm_prover::prover::prove; +// use zkm_prover::verifier::verify_proof; + +// const DEGREE_BITS_RANGE: [Range; 6] = [10..21, 12..22, 12..21, 8..21, 6..21, 13..23]; + +// fn split_segments() { +// // 1. split ELF into segs +// let basedir = env::var("BASEDIR").unwrap_or("/tmp/cannon".to_string()); +// let elf_path = env::var("ELF_PATH").expect("ELF file is missing"); +// let block_no = env::var("BLOCK_NO").unwrap_or("".to_string()); +// let seg_path = env::var("SEG_OUTPUT").expect("Segment output path is missing"); +// let seg_size = env::var("SEG_SIZE").unwrap_or(format!("{SEGMENT_STEPS}")); +// let seg_size = seg_size.parse::<_>().unwrap_or(SEGMENT_STEPS); +// let args = env::var("ARGS").unwrap_or("".to_string()); +// let args = args.split_whitespace().collect(); + +// let mut state = load_elf_with_patch(&elf_path, args); +// let block_path = get_block_path(&basedir, &block_no, ""); +// if !block_no.is_empty() { +// state.load_input(&block_path); +// } +// let _ = split_prog_into_segs(state, &seg_path, &block_path, seg_size); +// } + +// fn prove_single_seg_common( +// seg_file: &str, +// basedir: &str, +// block: &str, +// file: &str, +// seg_size: usize, +// ) { +// let seg_reader = BufReader::new(File::open(seg_file).unwrap()); +// let kernel = segment_kernel(basedir, block, file, seg_reader, seg_size); + +// const D: usize = 2; +// type C = PoseidonGoldilocksConfig; +// type F = >::F; + +// let allstark: AllStark = AllStark::default(); +// let config = StarkConfig::standard_fast_config(); +// let mut timing = TimingTree::new("prove", log::Level::Info); +// let allproof: proof::AllProof = +// prove(&allstark, &kernel, &config, &mut timing).unwrap(); +// let mut count_bytes = 0; +// for (row, proof) in allproof.stark_proofs.clone().iter().enumerate() { +// let proof_str = serde_json::to_string(&proof.proof).unwrap(); +// log::info!("row:{} proof bytes:{}", row, proof_str.len()); +// count_bytes += proof_str.len(); +// } +// timing.filter(Duration::from_millis(100)).print(); +// log::info!("total proof bytes:{}KB", count_bytes / 1024); +// verify_proof(&allstark, allproof, &config).unwrap(); +// log::info!("Prove done"); +// } + +// fn prove_multi_seg_common( +// seg_dir: &str, +// basedir: &str, +// block: &str, +// file: &str, +// seg_size: usize, +// seg_file_number: usize, +// seg_start_id: usize, +// ) -> anyhow::Result<()> { +// type InnerParameters = DefaultParameters; +// type OuterParameters = Groth16WrapperParameters; + +// type F = GoldilocksField; +// const D: usize = 2; +// type C = PoseidonGoldilocksConfig; + +// if seg_file_number < 2 { +// panic!("seg file number must >= 2\n"); +// } + +// let total_timing = TimingTree::new("prove total time", log::Level::Info); +// let all_stark = AllStark::::default(); +// let config = StarkConfig::standard_fast_config(); +// // Preprocess all circuits. +// let all_circuits = +// AllRecursiveCircuits::::new(&all_stark, &DEGREE_BITS_RANGE, &config); + +// let seg_file = format!("{}/{}", seg_dir, seg_start_id); +// log::info!("Process segment {}", seg_file); +// let seg_reader = BufReader::new(File::open(seg_file)?); +// let input_first = segment_kernel(basedir, block, file, seg_reader, seg_size); +// let mut timing = TimingTree::new("prove root first", log::Level::Info); +// let (mut agg_proof, mut updated_agg_public_values) = +// all_circuits.prove_root(&all_stark, &input_first, &config, &mut timing)?; + +// timing.filter(Duration::from_millis(100)).print(); +// all_circuits.verify_root(agg_proof.clone())?; + +// let mut base_seg = seg_start_id + 1; +// let mut seg_num = seg_file_number - 1; +// let mut is_agg = false; + +// if seg_file_number % 2 == 0 { +// let seg_file = format!("{}/{}", seg_dir, seg_start_id + 1); +// log::info!("Process segment {}", seg_file); +// let seg_reader = BufReader::new(File::open(seg_file)?); +// let input = segment_kernel(basedir, block, file, seg_reader, seg_size); +// timing = TimingTree::new("prove root second", log::Level::Info); +// let (root_proof, public_values) = +// all_circuits.prove_root(&all_stark, &input, &config, &mut timing)?; +// timing.filter(Duration::from_millis(100)).print(); + +// all_circuits.verify_root(root_proof.clone())?; + +// // Update public values for the aggregation. +// let agg_public_values = PublicValues { +// roots_before: updated_agg_public_values.roots_before, +// roots_after: public_values.roots_after, +// userdata: public_values.userdata, +// }; +// timing = TimingTree::new("prove aggression", log::Level::Info); +// // We can duplicate the proofs here because the state hasn't mutated. +// (agg_proof, updated_agg_public_values) = all_circuits.prove_aggregation( +// false, +// &agg_proof, +// false, +// &root_proof, +// agg_public_values.clone(), +// )?; +// timing.filter(Duration::from_millis(100)).print(); +// all_circuits.verify_aggregation(&agg_proof)?; + +// is_agg = true; +// base_seg = seg_start_id + 2; +// seg_num -= 1; +// } + +// for i in 0..seg_num / 2 { +// let seg_file = format!("{}/{}", seg_dir, base_seg + (i << 1)); +// log::info!("Process segment {}", seg_file); +// let seg_reader = BufReader::new(File::open(&seg_file)?); +// let input_first = segment_kernel(basedir, block, file, seg_reader, seg_size); +// let mut timing = TimingTree::new("prove root first", log::Level::Info); +// let (root_proof_first, first_public_values) = +// all_circuits.prove_root(&all_stark, &input_first, &config, &mut timing)?; + +// timing.filter(Duration::from_millis(100)).print(); +// all_circuits.verify_root(root_proof_first.clone())?; + +// let seg_file = format!("{}/{}", seg_dir, base_seg + (i << 1) + 1); +// log::info!("Process segment {}", seg_file); +// let seg_reader = BufReader::new(File::open(&seg_file)?); +// let input = segment_kernel(basedir, block, file, seg_reader, seg_size); +// let mut timing = TimingTree::new("prove root second", log::Level::Info); +// let (root_proof, public_values) = +// all_circuits.prove_root(&all_stark, &input, &config, &mut timing)?; +// timing.filter(Duration::from_millis(100)).print(); + +// all_circuits.verify_root(root_proof.clone())?; + +// // Update public values for the aggregation. +// let new_agg_public_values = PublicValues { +// roots_before: first_public_values.roots_before, +// roots_after: public_values.roots_after, +// userdata: public_values.userdata, +// }; +// timing = TimingTree::new("prove aggression", log::Level::Info); +// // We can duplicate the proofs here because the state hasn't mutated. +// let (new_agg_proof, new_updated_agg_public_values) = all_circuits.prove_aggregation( +// false, +// &root_proof_first, +// false, +// &root_proof, +// new_agg_public_values, +// )?; +// timing.filter(Duration::from_millis(100)).print(); +// all_circuits.verify_aggregation(&new_agg_proof)?; + +// // Update public values for the nested aggregation. +// let agg_public_values = PublicValues { +// roots_before: updated_agg_public_values.roots_before, +// roots_after: new_updated_agg_public_values.roots_after, +// userdata: new_updated_agg_public_values.userdata, +// }; +// timing = TimingTree::new("prove nested aggression", log::Level::Info); + +// // We can duplicate the proofs here because the state hasn't mutated. +// (agg_proof, updated_agg_public_values) = all_circuits.prove_aggregation( +// is_agg, +// &agg_proof, +// true, +// &new_agg_proof, +// agg_public_values.clone(), +// )?; +// is_agg = true; +// timing.filter(Duration::from_millis(100)).print(); + +// all_circuits.verify_aggregation(&agg_proof)?; +// } + +// let (block_proof, _block_public_values) = +// all_circuits.prove_block(None, &agg_proof, updated_agg_public_values)?; + +// log::info!( +// "proof size: {:?}", +// serde_json::to_string(&block_proof.proof).unwrap().len() +// ); +// let result = all_circuits.verify_block(&block_proof); + +// let build_path = "../verifier/data".to_string(); +// let path = format!("{}/test_circuit/", build_path); +// let builder = WrapperBuilder::::new(); +// let mut circuit = builder.build(); +// circuit.set_data(all_circuits.block.circuit); +// let mut bit_size = vec![32usize; 16]; +// bit_size.extend(vec![8; 32]); +// bit_size.extend(vec![64; 68]); +// let wrapped_circuit = WrappedCircuit::::build( +// circuit, +// Some((vec![], bit_size)), +// ); +// log::info!("build finish"); + +// let wrapped_proof = wrapped_circuit.prove(&block_proof).unwrap(); +// wrapped_proof.save(path).unwrap(); + +// total_timing.filter(Duration::from_millis(100)).print(); +// result +// } + +// fn prove_sha2_bench() { +// // 1. split ELF into segs +// let elf_path = env::var("ELF_PATH").expect("ELF file is missing"); +// let seg_path = env::var("SEG_OUTPUT").expect("Segment output path is missing"); + +// let mut state = load_elf_with_patch(&elf_path, vec![]); +// // load input +// let args = env::var("ARGS").unwrap_or("data-to-hash".to_string()); +// // assume the first arg is the hash output(which is a public input), and the second is the input. +// let args: Vec<&str> = args.split_whitespace().collect(); +// assert_eq!(args.len(), 2); + +// let public_input: Vec = hex::decode(args[0]).unwrap(); +// state.add_input_stream(&public_input); +// log::info!("expected public value in hex: {:X?}", args[0]); +// log::info!("expected public value: {:X?}", public_input); + +// let private_input = args[1].as_bytes().to_vec(); +// log::info!("private input value: {:X?}", private_input); +// state.add_input_stream(&private_input); + +// let (total_steps, mut state) = split_prog_into_segs(state, &seg_path, "", 0); + +// let value = state.read_public_values::<[u8; 32]>(); +// log::info!("public value: {:X?}", value); +// log::info!("public value: {} in hex", hex::encode(value)); + +// let seg_file = format!("{seg_path}/{}", 0); +// prove_single_seg_common(&seg_file, "", "", "", total_steps); +// } + +// fn prove_revm() { +// // 1. split ELF into segs +// let elf_path = env::var("ELF_PATH").expect("ELF file is missing"); +// let seg_path = env::var("SEG_OUTPUT").expect("Segment output path is missing"); +// let json_path = env::var("JSON_PATH").expect("JSON file is missing"); +// let seg_size = env::var("SEG_SIZE").unwrap_or("0".to_string()); +// let seg_size = seg_size.parse::<_>().unwrap_or(0); +// let mut f = File::open(json_path).unwrap(); +// let mut data = vec![]; +// f.read_to_end(&mut data).unwrap(); + +// let mut state = load_elf_with_patch(&elf_path, vec![]); +// // load input +// state.add_input_stream(&data); + +// let (total_steps, mut _state) = split_prog_into_segs(state, &seg_path, "", seg_size); + +// let mut seg_num = 1usize; +// if seg_size != 0 { +// seg_num = (total_steps + seg_size - 1) / seg_size; +// } + +// if seg_num == 1 { +// let seg_file = format!("{seg_path}/{}", 0); +// prove_single_seg_common(&seg_file, "", "", "", total_steps) +// } else { +// prove_multi_seg_common(&seg_path, "", "", "", seg_size, seg_num, 0).unwrap() +// } +// } + +// #[derive(Debug, Clone, Deserialize, Serialize)] +// pub enum DataId { +// TYPE1, +// TYPE2, +// TYPE3, +// } + +// #[derive(Debug, Clone, Deserialize, Serialize)] +// pub struct Data { +// pub input1: [u8; 10], +// pub input2: u8, +// pub input3: i8, +// pub input4: u16, +// pub input5: i16, +// pub input6: u32, +// pub input7: i32, +// pub input8: u64, +// pub input9: i64, +// pub input10: Vec, +// pub input11: DataId, +// pub input12: String, +// } + +// impl Default for Data { +// fn default() -> Self { +// Self::new() +// } +// } + +// impl Data { +// pub fn new() -> Self { +// let array = [1u8, 2u8, 3u8, 4u8, 5u8, 6u8, 7u8, 8u8, 9u8, 10u8]; +// Self { +// input1: array, +// input2: 0x11u8, +// input3: -1i8, +// input4: 0x1122u16, +// input5: -1i16, +// input6: 0x112233u32, +// input7: -1i32, +// input8: 0x1122334455u64, +// input9: -1i64, +// input10: array[1..3].to_vec(), +// input11: DataId::TYPE3, +// input12: "hello".to_string(), +// } +// } +// } + +// fn prove_add_example() { +// // 1. split ELF into segs +// let elf_path = env::var("ELF_PATH").expect("ELF file is missing"); +// let seg_path = env::var("SEG_OUTPUT").expect("Segment output path is missing"); +// let seg_size = env::var("SEG_SIZE").unwrap_or("0".to_string()); +// let seg_size = seg_size.parse::<_>().unwrap_or(0); + +// let mut state = load_elf_with_patch(&elf_path, vec![]); + +// let data = Data::new(); +// state.add_input_stream(&data); +// log::info!( +// "enum {} {} {}", +// DataId::TYPE1 as u8, +// DataId::TYPE2 as u8, +// DataId::TYPE3 as u8 +// ); +// log::info!("public input: {:X?}", data); + +// let (total_steps, mut state) = split_prog_into_segs(state, &seg_path, "", seg_size); + +// let value = state.read_public_values::(); +// log::info!("public value: {:X?}", value); + +// let mut seg_num = 1usize; +// if seg_size != 0 { +// seg_num = (total_steps + seg_size - 1) / seg_size; +// } + +// if seg_num == 1 { +// let seg_file = format!("{seg_path}/{}", 0); +// prove_single_seg_common(&seg_file, "", "", "", total_steps) +// } else { +// prove_multi_seg_common(&seg_path, "", "", "", seg_size, seg_num, 0).unwrap() +// } +// } + +// fn prove_host() { +// let host_program = env::var("HOST_PROGRAM").expect("host_program name is missing"); +// match host_program.as_str() { +// "sha2_bench" => prove_sha2_bench(), +// "revm" => prove_revm(), +// "add_example" => prove_add_example(), +// _ => log::error!("Host program {} is not supported!", host_program), +// }; +// } + +// fn prove_segments() { +// let basedir = env::var("BASEDIR").unwrap_or("/tmp/cannon".to_string()); +// let block = env::var("BLOCK_NO").unwrap_or("".to_string()); +// let file = env::var("BLOCK_FILE").unwrap_or(String::from("")); +// let seg_dir = env::var("SEG_FILE_DIR").expect("segment file dir is missing"); +// let seg_num = env::var("SEG_NUM").unwrap_or("1".to_string()); +// let seg_num = seg_num.parse::<_>().unwrap_or(1usize); +// let seg_start_id = env::var("SEG_START_ID").unwrap_or("0".to_string()); +// let seg_start_id = seg_start_id.parse::<_>().unwrap_or(0usize); +// let seg_size = env::var("SEG_SIZE").unwrap_or(format!("{SEGMENT_STEPS}")); +// let seg_size = seg_size.parse::<_>().unwrap_or(SEGMENT_STEPS); + +// if seg_num == 1 { +// let seg_file = format!("{seg_dir}/{}", seg_start_id); +// prove_single_seg_common(&seg_file, &basedir, &block, &file, seg_size) +// } else { +// prove_multi_seg_common( +// &seg_dir, +// &basedir, +// &block, +// &file, +// seg_size, +// seg_num, +// seg_start_id, +// ) +// .unwrap() +// } +// } + +// // fn main() { +// // env_logger::try_init().unwrap_or_default(); +// // let args: Vec = env::args().collect(); +// // let helper = || { +// // log::info!( +// // "Help: {} split | prove_segments | prove_host_program", +// // args[0] +// // ); +// // std::process::exit(-1); +// // }; +// // if args.len() < 2 { +// // helper(); +// // } +// // match args[1].as_str() { +// // "split" => split_segments(), +// // "prove_segments" => prove_segments(), +// // "prove_host_program" => prove_host(), +// // _ => helper(), +// // }; +// // } diff --git a/sdk/src/local/mod.rs b/sdk/src/local/mod.rs new file mode 100644 index 00000000..e69de29b diff --git a/sdk/src/network/mod.rs b/sdk/src/network/mod.rs new file mode 100644 index 00000000..4cd3d7d8 --- /dev/null +++ b/sdk/src/network/mod.rs @@ -0,0 +1,2 @@ + +pub mod prover; \ No newline at end of file diff --git a/sdk/src/network/prover.rs b/sdk/src/network/prover.rs new file mode 100644 index 00000000..7747c4ad --- /dev/null +++ b/sdk/src/network/prover.rs @@ -0,0 +1,145 @@ +use common::tls::Config; +use stage_service::stage_service_client::StageServiceClient; +use stage_service::{GenerateProofRequest, GetStatusRequest}; +use std::env; +use std::time::Instant; +use tonic::transport::{Channel, ClientTlsConfig}; +use tonic::transport::Endpoint; + +use ethers::signers::{LocalWallet, Signer}; +use crate::prover::{Prover, ProverInput, ProverResult}; +use tokio::time::Duration; +use tokio::time::sleep; + +use async_trait::async_trait; + +pub mod stage_service { + tonic::include_proto!("stage.v1"); +} + +use crate::network::prover::stage_service::Status; + +pub const DEFAULT_PROVER_NETWORK_RPC: &str = "https://152.32.186.45:20002"; +pub const DEFALUT_PROVER_NETWORK_DOMAIN: &str = "stage"; + +pub struct NetworkProver { + pub stage_client: StageServiceClient, + pub wallet: LocalWallet, +} + +impl NetworkProver { + pub async fn new() -> anyhow::Result { + let endpoint = env::var("ENDPOINT").unwrap_or(DEFAULT_PROVER_NETWORK_RPC.to_string()); + let ca_cert_path = env::var("CA_CERT_PATH").unwrap_or("".to_string()); + let cert_path = env::var("CERT_PATH").unwrap_or("".to_string()); + let key_path = env::var("KEY_PATH").unwrap_or("".to_string()); + let domain_name = env::var("DOMAIN_NAME").unwrap_or(DEFALUT_PROVER_NETWORK_DOMAIN.to_string()); + let private_key = env::var("PRIVATE_KEY").unwrap_or("".to_string()); + + let ssl_config = if ca_cert_path.is_empty() { + None + } else { + Some(Config::new(ca_cert_path, cert_path, key_path).await?) + }; + + let endpoint = match ssl_config { + Some(config) => { + let mut tls_config = ClientTlsConfig::new().domain_name(domain_name); + if let Some(ca_cert) = config.ca_cert { + tls_config = tls_config.ca_certificate(ca_cert); + } + if let Some(identity) = config.identity { + tls_config = tls_config.identity(identity); + } + Endpoint::new(endpoint)?.tls_config(tls_config)? + } + None => Endpoint::new(endpoint)?, + }; + let stage_client = StageServiceClient::connect(endpoint).await?; + let wallet = private_key.parse::().unwrap(); + Ok(NetworkProver { stage_client, wallet }) + } + + pub async fn sign_ecdsa(&self, request: &mut GenerateProofRequest) { + let sign_data = format!( + "{}&{}&{}&{}", + request.proof_id, request.block_no, request.seg_size, request.args + ); + let signature = self.wallet.sign_message(sign_data).await.unwrap(); + request.signature = signature.to_string(); + } +} + +impl Default for NetworkProver { + fn default() -> Self { + let rt = tokio::runtime::Runtime::new().unwrap(); + let result = rt.block_on(Self::new()); + result.unwrap() + } +} + +#[async_trait] +impl Prover for NetworkProver { + async fn request_proof<'a>(&self, input: &'a ProverInput) -> anyhow::Result { + let proof_id = uuid::Uuid::new_v4().to_string(); + let mut request = GenerateProofRequest { + proof_id: proof_id.clone(), + elf_data: input.elf.clone(), + seg_size: input.seg_size, + public_input_stream: input.public_inputstream.clone(), + private_input_stream: input.private_inputstream.clone(), + execute_only: input.execute_only, + ..Default::default() + }; + self.sign_ecdsa(&mut request).await; + let mut client = self.stage_client.clone(); + let response = client.generate_proof(request).await?.into_inner(); + Ok(response.proof_id) + } + + async fn wait_proof<'a>(&self, proof_id: &'a str, timeout: Option) -> anyhow::Result > { + let start_time = Instant::now(); + let mut client = self.stage_client.clone(); + loop { + if let Some(timeout) = timeout { + if start_time.elapsed() > timeout { + return Err(anyhow::anyhow!("Proof generation timed out.")); + } + } + + let get_status_request = GetStatusRequest { + proof_id: proof_id.to_string(), + }; + let get_status_response = client + .get_status(get_status_request) + .await? + .into_inner(); + + match Status::from_i32(get_status_response.status as i32) { + Some(Status::Computing) => { + sleep(Duration::from_secs(2)).await; + } + Some(Status::Success) => { + let proof_result = ProverResult { + output_stream: get_status_response.output_stream, + proof_with_public_inputs: get_status_response.proof_with_public_inputs, + ..Default::default() + }; + return Ok(Some(proof_result)); + } + _ => { + log::error!( + "generate_proof failed status: {}", + get_status_response.status + ); + return Ok(None); + } + } + } + } + + async fn prover<'a>(&self, input: &'a ProverInput, timeout: Option) -> anyhow::Result > { + let proof_id = self.request_proof(input).await?; + self.wait_proof(&proof_id, timeout).await + } +} \ No newline at end of file diff --git a/sdk/src/proto/stage.proto b/sdk/src/proto/stage.proto new file mode 100644 index 00000000..8003d106 --- /dev/null +++ b/sdk/src/proto/stage.proto @@ -0,0 +1,65 @@ +syntax = "proto3"; + +package stage.v1; + + +service StageService { + rpc GenerateProof(GenerateProofRequest) returns (GenerateProofResponse) {} + rpc GetStatus(GetStatusRequest) returns (GetStatusResponse) {} +} + +enum Status { + SUCCESS = 0; + UNSPECIFIED = 1; + COMPUTING = 2; + INVALID_PARAMETER = 3; + INTERNAL_ERROR = 4; + SPLIT_ERROR = 5; + PROVE_ERROR = 6; + AGG_ERROR = 7; + FINAL_ERROR = 8; +} + +message BlockFileItem { + string file_name = 1; + bytes file_content = 2; +} + +message GenerateProofRequest { + uint64 chain_id = 1; + uint64 timestamp = 2; + string proof_id = 3; + bytes elf_data = 4; + repeated BlockFileItem block_data = 5; + uint64 block_no = 6; + uint32 seg_size = 7; + string args = 8; + string signature = 9; + bytes public_input_stream = 10; + bytes private_input_stream = 11; + bool execute_only = 12; +} + +message GenerateProofResponse { + uint32 status = 1; + string error_message = 2; + string proof_id = 3; + string proof_url = 4; + string stark_proof_url = 5; + string solidity_verifier_url = 6; + bytes output_stream = 7; +} + +message GetStatusRequest { + string proof_id = 1; +} + +message GetStatusResponse { + string proof_id = 1; + uint32 status = 2; + bytes proof_with_public_inputs = 3; + string proof_url = 4; + string stark_proof_url = 5; + string solidity_verifier_url = 6; + bytes output_stream = 7; +} \ No newline at end of file diff --git a/sdk/src/prover.rs b/sdk/src/prover.rs new file mode 100644 index 00000000..60786b51 --- /dev/null +++ b/sdk/src/prover.rs @@ -0,0 +1,28 @@ + +use tokio::time::Duration; +use async_trait::async_trait; +use serde::Serialize; +use serde::Deserialize; + +#[derive(Debug, Default, Deserialize, Serialize)] +pub struct ProverInput { + pub elf: Vec, + pub public_inputstream: Vec, + pub private_inputstream: Vec, + pub seg_size: u32, + pub execute_only: bool, +} + +#[derive(Debug, Default, Deserialize, Serialize)] +pub struct ProverResult { + pub output_stream: Vec, + // pub stark_proof: Vec, + pub proof_with_public_inputs: Vec, +} + +#[async_trait] +pub trait Prover { + async fn request_proof<'a>(&self, input: &'a ProverInput) -> anyhow::Result; + async fn wait_proof<'a>(&self, proof_id: &'a str, timeout: Option) -> anyhow::Result>; + async fn prover<'a>(&self, input: &'a ProverInput, timeout: Option) -> anyhow::Result>; +} \ No newline at end of file