From 7f9cb61297da8abea835859166f905ddcb7a1319 Mon Sep 17 00:00:00 2001 From: Steven Gu Date: Fri, 18 Oct 2024 16:02:19 +0800 Subject: [PATCH 01/16] Update the circuits of cells tree, rows tree and block tree for generic extraction. --- Cargo.lock | 19 +- mp2-common/src/poseidon.rs | 3 + mp2-common/src/utils.rs | 7 + verifiable-db/Cargo.toml | 4 +- verifiable-db/src/block_tree/api.rs | 2 +- verifiable-db/src/block_tree/leaf.rs | 20 +- verifiable-db/src/block_tree/mod.rs | 81 +++- verifiable-db/src/block_tree/parent.rs | 18 +- verifiable-db/src/cells_tree/api.rs | 44 +- verifiable-db/src/cells_tree/empty_node.rs | 10 +- verifiable-db/src/cells_tree/full_node.rs | 48 ++- verifiable-db/src/cells_tree/leaf.rs | 46 ++- verifiable-db/src/cells_tree/mod.rs | 114 ++++-- verifiable-db/src/cells_tree/partial_node.rs | 65 +-- verifiable-db/src/cells_tree/public_inputs.rs | 295 ++++++++++---- verifiable-db/src/revelation/api.rs | 2 +- verifiable-db/src/row_tree/api.rs | 68 +++- verifiable-db/src/row_tree/full_node.rs | 92 +++-- verifiable-db/src/row_tree/leaf.rs | 70 ++-- verifiable-db/src/row_tree/mod.rs | 19 +- verifiable-db/src/row_tree/partial_node.rs | 81 ++-- verifiable-db/src/row_tree/public_inputs.rs | 382 ++++++++++++------ verifiable-db/src/row_tree/row.rs | 122 ++++++ 23 files changed, 1110 insertions(+), 502 deletions(-) create mode 100644 verifiable-db/src/row_tree/row.rs diff --git a/Cargo.lock b/Cargo.lock index e3c0e1dc4..ade5da76e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4574,13 +4574,13 @@ dependencies = [ "serde", "serde_json", "serde_plain", - "serde_with 3.9.0", + "serde_with 3.11.0", "sha2", "sha256", "starkyx", "tokio", "tracing", - "uuid 1.10.0", + "uuid 1.11.0", ] [[package]] @@ -5693,9 +5693,9 @@ dependencies = [ [[package]] name = "serde_with" -version = "3.9.0" +version = "3.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "69cecfa94848272156ea67b2b1a53f20fc7bc638c4a46d2f8abde08f05f4b857" +checksum = "8e28bdad6db2b8340e449f7108f020b3b092e8583a9e3fb82713e1d4e71fe817" dependencies = [ "base64 0.22.1", "chrono", @@ -5705,7 +5705,7 @@ dependencies = [ "serde", "serde_derive", "serde_json", - "serde_with_macros 3.9.0", + "serde_with_macros 3.11.0", "time", ] @@ -5723,9 +5723,9 @@ dependencies = [ [[package]] name = "serde_with_macros" -version = "3.9.0" +version = "3.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a8fee4991ef4f274617a51ad4af30519438dacb2f56ac773b08a1922ff743350" +checksum = "9d846214a9854ef724f3da161b426242d8de7c1fc7de2f89bb1efcb154dca79d" dependencies = [ "darling", "proc-macro2", @@ -6797,9 +6797,9 @@ dependencies = [ [[package]] name = "uuid" -version = "1.10.0" +version = "1.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "81dfa00651efa65069b0b6b651f4aaa31ba9e3c3ce0137aaad053604ee7e0314" +checksum = "f8c5f0a0af699448548ad1a2fbf920fb4bee257eae39953ba95cb84891a0446a" dependencies = [ "serde", ] @@ -6830,6 +6830,7 @@ dependencies = [ "log", "mp2_common", "mp2_test", + "num", "plonky2", "plonky2_crypto", "plonky2_ecdsa", diff --git a/mp2-common/src/poseidon.rs b/mp2-common/src/poseidon.rs index b64755fd6..cf2a84eda 100644 --- a/mp2-common/src/poseidon.rs +++ b/mp2-common/src/poseidon.rs @@ -35,6 +35,9 @@ pub type H = >::Hasher; pub type P = >::AlgebraicPermutation; pub type HashPermutation = >::Permutation; +/// The result of hash to integer has 4 Uint32 (128 bits). +pub const HASH_TO_INT_LEN: usize = 4; + /// The flattened length of Poseidon hash, each original field is splitted from an /// Uint64 into two Uint32. pub const FLATTEN_POSEIDON_LEN: usize = NUM_HASH_OUT_ELTS * 2; diff --git a/mp2-common/src/utils.rs b/mp2-common/src/utils.rs index a09f60a90..76b3d6ec0 100644 --- a/mp2-common/src/utils.rs +++ b/mp2-common/src/utils.rs @@ -12,6 +12,7 @@ use plonky2::plonk::circuit_builder::CircuitBuilder; use plonky2::plonk::circuit_data::VerifierCircuitData; use plonky2::plonk::config::{GenericConfig, GenericHashOut, Hasher}; use plonky2_crypto::u32::arithmetic_u32::U32Target; +use plonky2_ecdsa::gadgets::biguint::BigUintTarget; use plonky2_ecgfp5::gadgets::{base_field::QuinticExtensionTarget, curve::CurveTarget}; use sha3::Digest; @@ -439,6 +440,12 @@ impl ToTargets for &[Target] { } } +impl ToTargets for BigUintTarget { + fn to_targets(&self) -> Vec { + self.limbs.iter().map(|u| u.0).collect() + } +} + pub trait TargetsConnector { fn connect_targets(&mut self, e1: T, e2: T); fn is_equal_targets(&mut self, e1: T, e2: T) -> BoolTarget; diff --git a/verifiable-db/Cargo.toml b/verifiable-db/Cargo.toml index 28b41c55e..6e27d0299 100644 --- a/verifiable-db/Cargo.toml +++ b/verifiable-db/Cargo.toml @@ -5,6 +5,7 @@ edition = "2021" [dependencies] mp2_common = { path = "../mp2-common" } +num.workspace = true plonky2_crypto.workspace = true recursion_framework = { path = "../recursion-framework" } ryhope = { path = "../ryhope" } @@ -29,4 +30,5 @@ serial_test.workspace = true tokio.workspace = true [features] -original_poseidon = ["mp2_common/original_poseidon"] \ No newline at end of file +original_poseidon = ["mp2_common/original_poseidon"] + diff --git a/verifiable-db/src/block_tree/api.rs b/verifiable-db/src/block_tree/api.rs index 33f5c6e54..023494840 100644 --- a/verifiable-db/src/block_tree/api.rs +++ b/verifiable-db/src/block_tree/api.rs @@ -294,7 +294,7 @@ mod tests { use std::iter; const EXTRACTION_IO_LEN: usize = extraction::test::PublicInputs::::TOTAL_LEN; - const ROWS_TREE_IO_LEN: usize = row_tree::PublicInputs::::TOTAL_LEN; + const ROWS_TREE_IO_LEN: usize = row_tree::PublicInputs::::total_len(); struct TestBuilder where diff --git a/verifiable-db/src/block_tree/leaf.rs b/verifiable-db/src/block_tree/leaf.rs index d2e1e055c..b6966047a 100644 --- a/verifiable-db/src/block_tree/leaf.rs +++ b/verifiable-db/src/block_tree/leaf.rs @@ -2,7 +2,7 @@ //! an existing node (or if there is no existing node, which happens for the //! first block number). -use super::{compute_index_digest, public_inputs::PublicInputs}; +use super::{compute_final_digest, compute_index_digest, public_inputs::PublicInputs}; use crate::{ extraction::{ExtractionPI, ExtractionPIWrap}, row_tree, @@ -10,7 +10,6 @@ use crate::{ use anyhow::Result; use mp2_common::{ default_config, - group_hashing::CircuitBuilderGroupHashing, poseidon::{empty_poseidon_hash, H}, proof::ProofWithVK, public_inputs::PublicInputCommon, @@ -55,15 +54,12 @@ impl LeafCircuit { let extraction_pi = E::PI::from_slice(extraction_pi); let rows_tree_pi = row_tree::PublicInputs::::from_slice(rows_tree_pi); + let final_digest = compute_final_digest::(b, &extraction_pi, &rows_tree_pi); // in our case, the extraction proofs extracts from the blockchain and sets // the block number as the primary index let index_value = extraction_pi.primary_index_value(); - // Enforce that the data extracted from the blockchain is the same as the data - // employed to build the rows tree for this node. - b.connect_curve_points(extraction_pi.value_set_digest(), rows_tree_pi.rows_digest()); - // Compute the hash of table metadata, to be exposed as public input to prove to // the verifier that we extracted the correct storage slots and we place the data // in the expected columns of the constructed tree; we add also the identifier @@ -82,7 +78,7 @@ impl LeafCircuit { let inputs = iter::once(index_identifier) .chain(index_value.iter().cloned()) .collect(); - let node_digest = compute_index_digest(b, inputs, rows_tree_pi.rows_digest()); + let node_digest = compute_index_digest(b, inputs, final_digest); // Compute hash of the inserted node // node_min = block_number @@ -103,7 +99,7 @@ impl LeafCircuit { // check that the rows tree built is for a merged table iff we extract data from MPT for a merged table b.connect( - rows_tree_pi.is_merge_case().target, + rows_tree_pi.merge_flag_target().target, extraction_pi.is_merge_case().target, ); @@ -170,7 +166,7 @@ where _verified_proofs: [&ProofWithPublicInputsTarget; 0], builder_parameters: Self::CircuitBuilderParams, ) -> Self { - const ROWS_TREE_IO: usize = row_tree::PublicInputs::::TOTAL_LEN; + const ROWS_TREE_IO: usize = row_tree::PublicInputs::::total_len(); let extraction_verifier = RecursiveCircuitsVerifierGagdet::::new( @@ -262,7 +258,7 @@ pub mod tests { let hash = H::hash_no_pad(&inputs); let int = hash_to_int_value(hash); let scalar = Scalar::from_noncanonical_biguint(int); - let point = rows_tree_pi.rows_digest_field(); + let point = rows_tree_pi.individual_digest_point(); let point = weierstrass_to_point(&point); point * scalar } @@ -279,7 +275,7 @@ pub mod tests { fn build(b: &mut CBuilder) -> Self::Wires { let extraction_pi = b.add_virtual_targets(TestPITargets::TOTAL_LEN); - let rows_tree_pi = b.add_virtual_targets(row_tree::PublicInputs::::TOTAL_LEN); + let rows_tree_pi = b.add_virtual_targets(row_tree::PublicInputs::::total_len()); let leaf_wires = LeafCircuit::build::(b, &extraction_pi, &rows_tree_pi); @@ -292,7 +288,7 @@ pub mod tests { assert_eq!(wires.1.len(), TestPITargets::TOTAL_LEN); pw.set_target_arr(&wires.1, self.extraction_pi); - assert_eq!(wires.2.len(), row_tree::PublicInputs::::TOTAL_LEN); + assert_eq!(wires.2.len(), row_tree::PublicInputs::::total_len()); pw.set_target_arr(&wires.2, self.rows_tree_pi); } } diff --git a/verifiable-db/src/block_tree/mod.rs b/verifiable-db/src/block_tree/mod.rs index 34f172404..6a65418fa 100644 --- a/verifiable-db/src/block_tree/mod.rs +++ b/verifiable-db/src/block_tree/mod.rs @@ -4,9 +4,18 @@ mod membership; mod parent; mod public_inputs; +use crate::{ + extraction::{ExtractionPI, ExtractionPIWrap}, + row_tree, +}; pub use api::{CircuitInput, PublicParameters}; -use mp2_common::{poseidon::hash_to_int_target, CHasher, D, F}; -use plonky2::{iop::target::Target, plonk::circuit_builder::CircuitBuilder}; +use mp2_common::{ + group_hashing::{circuit_hashed_scalar_mul, CircuitBuilderGroupHashing}, + poseidon::hash_to_int_target, + types::CBuilder, + CHasher, D, F, +}; +use plonky2::{field::types::Field, iop::target::Target, plonk::circuit_builder::CircuitBuilder}; use plonky2_ecdsa::gadgets::nonnative::CircuitBuilderNonNative; use plonky2_ecgfp5::gadgets::curve::{CircuitBuilderEcGFp5, CurveTarget}; @@ -25,10 +34,62 @@ pub(crate) fn compute_index_digest( b.curve_scalar_mul(base, &scalar) } +/// Compute the final digest. +pub(crate) fn compute_final_digest<'a, E>( + b: &mut CBuilder, + extraction_pi: &E::PI<'a>, + rows_tree_pi: &row_tree::PublicInputs, +) -> CurveTarget +where + E: ExtractionPIWrap, +{ + // Compute the final row digest from rows_tree_proof for merge case: + // multiplier_digest = rows_tree_proof.row_id_multiplier * rows_tree_proof.multiplier_vd + let multiplier_vd = rows_tree_pi.multiplier_digest_target(); + let row_id_multiplier = b.biguint_to_nonnative(&rows_tree_pi.row_id_multiplier_target()); + let multiplier_digest = b.curve_scalar_mul(multiplier_vd, &row_id_multiplier); + // rows_digest_merge = multiplier_digest * rows_tree_proof.DR + let individual_digest = rows_tree_pi.individual_digest_target(); + let rows_digest_merge = circuit_hashed_scalar_mul(b, multiplier_digest, individual_digest); + + // Choose the final row digest depending on whether we are in merge case or not: + // final_digest = extraction_proof.is_merge ? rows_digest_merge : rows_tree_proof.DR + let final_digest = b.curve_select( + extraction_pi.is_merge_case(), + rows_digest_merge, + individual_digest, + ); + + // Enforce that the data extracted from the blockchain is the same as the data + // employed to build the rows tree for this node: + // assert final_digest == extraction_proof.DV + b.connect_curve_points(final_digest, extraction_pi.value_set_digest()); + + // Enforce that if we aren't in merge case, then no cells were accumulated in + // multiplier digest: + // assert extraction_proof.is_merge or rows_tree_proof.multiplier_vd != 0 + // => (1 - is_merge) * is_multiplier_vd_zero == false + let ffalse = b._false(); + let curve_zero = b.curve_zero(); + let is_multiplier_vd_zero = b + .curve_eq(rows_tree_pi.multiplier_digest_target(), curve_zero) + .target; + let should_be_false = b.arithmetic( + F::NEG_ONE, + F::ONE, + extraction_pi.is_merge_case().target, + is_multiplier_vd_zero, + is_multiplier_vd_zero, + ); + b.connect(should_be_false, ffalse.target); + + final_digest +} + #[cfg(test)] pub(crate) mod tests { use alloy::primitives::U256; - use mp2_common::{keccak::PACKED_HASH_LEN, utils::ToFields, F}; + use mp2_common::{keccak::PACKED_HASH_LEN, poseidon::HASH_TO_INT_LEN, utils::ToFields, F}; use mp2_test::utils::random_vector; use plonky2::{ field::types::{Field, Sample}, @@ -79,7 +140,19 @@ pub(crate) mod tests { let h = random_vector::(NUM_HASH_OUT_ELTS).to_fields(); let [min, max] = [0; 2].map(|_| U256::from_limbs(rng.gen::<[u64; 4]>()).to_fields()); let is_merge = [F::from_canonical_usize(is_merge_case as usize)]; - row_tree::PublicInputs::new(&h, row_digest, &min, &max, &is_merge).to_vec() + let multiplier_digest = Point::sample(rng).to_weierstrass().to_fields(); + let row_id_multiplier = random_vector::(HASH_TO_INT_LEN).to_fields(); + + row_tree::PublicInputs::new( + &h, + row_digest, + &min, + &max, + &is_merge, + &multiplier_digest, + &row_id_multiplier, + ) + .to_vec() } /// Generate a random extraction public inputs. diff --git a/verifiable-db/src/block_tree/parent.rs b/verifiable-db/src/block_tree/parent.rs index ca7b8af66..68988e87f 100644 --- a/verifiable-db/src/block_tree/parent.rs +++ b/verifiable-db/src/block_tree/parent.rs @@ -1,7 +1,7 @@ //! This circuit is employed when the new node is inserted as parent of an existing node, //! referred to as old node. -use super::{compute_index_digest, public_inputs::PublicInputs}; +use super::{compute_final_digest, compute_index_digest, public_inputs::PublicInputs}; use crate::{ extraction::{ExtractionPI, ExtractionPIWrap}, row_tree, @@ -10,7 +10,6 @@ use alloy::primitives::U256; use anyhow::Result; use mp2_common::{ default_config, - group_hashing::CircuitBuilderGroupHashing, poseidon::{empty_poseidon_hash, H}, proof::ProofWithVK, public_inputs::PublicInputCommon, @@ -84,13 +83,10 @@ impl ParentCircuit { let extraction_pi = E::PI::from_slice(extraction_pi); let rows_tree_pi = row_tree::PublicInputs::::from_slice(rows_tree_pi); + let final_digest = compute_final_digest::(b, &extraction_pi, &rows_tree_pi); let block_number = extraction_pi.primary_index_value(); - // Enforce that the data extracted from the blockchain is the same as the data - // employed to build the rows tree for this node. - b.connect_curve_points(extraction_pi.value_set_digest(), rows_tree_pi.rows_digest()); - // Compute the hash of table metadata, to be exposed as public input to prove to // the verifier that we extracted the correct storage slots and we place the data // in the expected columns of the constructed tree; we add also the identifier @@ -110,7 +106,7 @@ impl ParentCircuit { let inputs = iter::once(index_identifier) .chain(block_number.iter().cloned()) .collect(); - let node_digest = compute_index_digest(b, inputs, rows_tree_pi.rows_digest()); + let node_digest = compute_index_digest(b, inputs, final_digest); // We recompute the hash of the old node to bind the `old_min` and `old_max` // values to the hash of the old tree. @@ -154,7 +150,7 @@ impl ParentCircuit { // check that the rows tree built is for a merged table iff we extract data from MPT for a merged table b.connect( - rows_tree_pi.is_merge_case().target, + rows_tree_pi.merge_flag_target().target, extraction_pi.is_merge_case().target, ); @@ -236,7 +232,7 @@ where _verified_proofs: [&ProofWithPublicInputsTarget; 0], builder_parameters: Self::CircuitBuilderParams, ) -> Self { - const ROWS_TREE_IO: usize = row_tree::PublicInputs::::TOTAL_LEN; + const ROWS_TREE_IO: usize = row_tree::PublicInputs::::total_len(); let extraction_verifier = RecursiveCircuitsVerifierGagdet::::new( @@ -315,7 +311,7 @@ mod tests { fn build(b: &mut CBuilder) -> Self::Wires { let extraction_pi = b.add_virtual_targets(TestPITargets::TOTAL_LEN); - let rows_tree_pi = b.add_virtual_targets(row_tree::PublicInputs::::TOTAL_LEN); + let rows_tree_pi = b.add_virtual_targets(row_tree::PublicInputs::::total_len()); let parent_wires = ParentCircuit::build::(b, &extraction_pi, &rows_tree_pi); @@ -329,7 +325,7 @@ mod tests { assert_eq!(wires.1.len(), TestPITargets::TOTAL_LEN); pw.set_target_arr(&wires.1, self.extraction_pi); - assert_eq!(wires.2.len(), row_tree::PublicInputs::::TOTAL_LEN); + assert_eq!(wires.2.len(), row_tree::PublicInputs::::total_len()); pw.set_target_arr(&wires.2, self.rows_tree_pi); } } diff --git a/verifiable-db/src/cells_tree/api.rs b/verifiable-db/src/cells_tree/api.rs index 1a6487fa6..8b7a84740 100644 --- a/verifiable-db/src/cells_tree/api.rs +++ b/verifiable-db/src/cells_tree/api.rs @@ -2,9 +2,9 @@ use super::{ empty_node::{EmptyNodeCircuit, EmptyNodeWires}, - full_node::{FullNodeCircuit, FullNodeWires}, + full_node::FullNodeWires, leaf::{LeafCircuit, LeafWires}, - partial_node::{PartialNodeCircuit, PartialNodeWires}, + partial_node::PartialNodeWires, public_inputs::PublicInputs, Cell, }; @@ -39,12 +39,13 @@ impl CircuitInput { /// Create a circuit input for proving a leaf node. /// It is not considered a multiplier column. Please use `leaf_multiplier` for registering a /// multiplier column. - pub fn leaf(identifier: u64, value: U256) -> Self { + pub fn leaf(identifier: u64, value: U256, mpt_metadata: HashOut) -> Self { CircuitInput::Leaf( Cell { identifier: F::from_canonical_u64(identifier), value, is_multiplier: false, + mpt_metadata, } .into(), ) @@ -52,12 +53,18 @@ impl CircuitInput { /// Create a circuit input for proving a leaf node whose value is considered as a multiplier /// depending on the boolean value. /// i.e. it means it's one of the repeated value amongst all the rows - pub fn leaf_multiplier(identifier: u64, value: U256, is_multiplier: bool) -> Self { + pub fn leaf_multiplier( + identifier: u64, + value: U256, + is_multiplier: bool, + mpt_metadata: HashOut, + ) -> Self { CircuitInput::Leaf( Cell { identifier: F::from_canonical_u64(identifier), value, is_multiplier, + mpt_metadata, } .into(), ) @@ -66,11 +73,17 @@ impl CircuitInput { /// Create a circuit input for proving a full node of 2 children. /// It is not considered a multiplier column. Please use `full_multiplier` for registering a /// multiplier column. - pub fn full(identifier: u64, value: U256, child_proofs: [Vec; 2]) -> Self { + pub fn full( + identifier: u64, + value: U256, + mpt_metadata: HashOut, + child_proofs: [Vec; 2], + ) -> Self { CircuitInput::FullNode(new_child_input( F::from_canonical_u64(identifier), value, false, + mpt_metadata, child_proofs.to_vec(), )) } @@ -80,23 +93,31 @@ impl CircuitInput { identifier: u64, value: U256, is_multiplier: bool, + mpt_metadata: HashOut, child_proofs: [Vec; 2], ) -> Self { CircuitInput::FullNode(new_child_input( F::from_canonical_u64(identifier), value, is_multiplier, + mpt_metadata, child_proofs.to_vec(), )) } /// Create a circuit input for proving a partial node of 1 child. /// It is not considered a multiplier column. Please use `partial_multiplier` for registering a /// multiplier column. - pub fn partial(identifier: u64, value: U256, child_proof: Vec) -> Self { + pub fn partial( + identifier: u64, + value: U256, + mpt_metadata: HashOut, + child_proof: Vec, + ) -> Self { CircuitInput::PartialNode(new_child_input( F::from_canonical_u64(identifier), value, false, + mpt_metadata, vec![child_proof], )) } @@ -104,12 +125,14 @@ impl CircuitInput { identifier: u64, value: U256, is_multiplier: bool, + mpt_metadata: HashOut, child_proof: Vec, ) -> Self { CircuitInput::PartialNode(new_child_input( F::from_canonical_u64(identifier), value, is_multiplier, + mpt_metadata, vec![child_proof], )) } @@ -120,6 +143,7 @@ fn new_child_input( identifier: F, value: U256, is_multiplier: bool, + mpt_metadata: HashOut, serialized_child_proofs: Vec>, ) -> ChildInput { ChildInput { @@ -127,6 +151,7 @@ fn new_child_input( identifier, value, is_multiplier, + mpt_metadata, }, serialized_child_proofs, } @@ -148,7 +173,7 @@ pub fn build_circuits_params() -> PublicParameters { PublicParameters::build() } -const NUM_IO: usize = PublicInputs::::TOTAL_LEN; +const NUM_IO: usize = PublicInputs::::total_len(); /// Number of circuits in the set /// 1 leaf + 1 full node + 1 partial node + 1 empty node @@ -246,8 +271,10 @@ impl PublicParameters { pub fn extract_hash_from_proof(proof: &[u8]) -> Result> { let p = ProofWithVK::deserialize(proof)?; - Ok(PublicInputs::from_slice(&p.proof.public_inputs).root_hash_hashout()) + Ok(PublicInputs::from_slice(&p.proof.public_inputs).node_hash()) } + +/* #[cfg(test)] mod tests { use super::*; @@ -452,3 +479,4 @@ mod tests { proof } } +*/ diff --git a/verifiable-db/src/cells_tree/empty_node.rs b/verifiable-db/src/cells_tree/empty_node.rs index d0d770b1f..b86013f3f 100644 --- a/verifiable-db/src/cells_tree/empty_node.rs +++ b/verifiable-db/src/cells_tree/empty_node.rs @@ -23,11 +23,11 @@ impl EmptyNodeCircuit { let empty_hash = empty_poseidon_hash(); let h = b.constant_hash(*empty_hash).elements; - // dc = CURVE_ZERO - let dc = b.curve_zero().to_targets(); + // CURVE_ZERO + let curve_zero = b.curve_zero().to_targets(); // Register the public inputs. - PublicInputs::new(&h, &dc, &dc).register(b); + PublicInputs::new(&h, &curve_zero, &curve_zero, &curve_zero, &curve_zero).register(b); EmptyNodeWires } @@ -39,7 +39,7 @@ impl CircuitLogicWires for EmptyNodeWires { type Inputs = EmptyNodeCircuit; - const NUM_PUBLIC_INPUTS: usize = PublicInputs::::TOTAL_LEN; + const NUM_PUBLIC_INPUTS: usize = PublicInputs::::total_len(); fn circuit_logic( builder: &mut CBuilder, @@ -54,6 +54,7 @@ impl CircuitLogicWires for EmptyNodeWires { } } +/* #[cfg(test)] mod tests { use super::*; @@ -87,3 +88,4 @@ mod tests { } } } +*/ diff --git a/verifiable-db/src/cells_tree/full_node.rs b/verifiable-db/src/cells_tree/full_node.rs index 3a5bb4f3f..983ec071f 100644 --- a/verifiable-db/src/cells_tree/full_node.rs +++ b/verifiable-db/src/cells_tree/full_node.rs @@ -4,8 +4,7 @@ use super::{public_inputs::PublicInputs, Cell, CellWire}; use anyhow::Result; use derive_more::{From, Into}; use mp2_common::{ - group_hashing::CircuitBuilderGroupHashing, public_inputs::PublicInputCommon, types::CBuilder, - u256::CircuitBuilderU256, utils::ToTargets, CHasher, D, F, + poseidon::H, public_inputs::PublicInputCommon, types::CBuilder, utils::ToTargets, D, F, }; use plonky2::{ iop::{target::Target, witness::PartialWitness}, @@ -13,7 +12,7 @@ use plonky2::{ }; use recursion_framework::circuit_builder::CircuitLogicWires; use serde::{Deserialize, Serialize}; -use std::{array, iter}; +use std::{array, iter::once}; #[derive(Clone, Debug, Serialize, Deserialize, Into, From)] pub struct FullNodeWires(CellWire); @@ -23,30 +22,35 @@ pub struct FullNodeCircuit(Cell); impl FullNodeCircuit { pub fn build(b: &mut CBuilder, child_proofs: [PublicInputs; 2]) -> FullNodeWires { + let [p1, p2] = child_proofs; + let cell = CellWire::new(b); + let metadata_digests = cell.split_metadata_digest(b); + let values_digests = cell.split_values_digest(b); + + let metadata_digests = metadata_digests.accumulate(b, &p1.split_metadata_digest_target()); + let metadata_digests = metadata_digests.accumulate(b, &p2.split_metadata_digest_target()); - // h = Poseidon(p1.H || p2.H || identifier || value) - let [p1_hash, p2_hash] = [0, 1].map(|i| child_proofs[i].node_hash()); - let inputs: Vec<_> = p1_hash - .elements - .iter() - .cloned() - .chain(p2_hash.elements) - .chain(iter::once(cell.identifier)) + let values_digests = values_digests.accumulate(b, &p1.split_values_digest_target()); + let values_digests = values_digests.accumulate(b, &p2.split_values_digest_target()); + + // H(p1.H || p2.H || identifier || value) + let inputs = p1 + .node_hash_target() + .into_iter() + .chain(p2.node_hash_target()) + .chain(once(cell.identifier)) .chain(cell.value.to_targets()) .collect(); - let h = b.hash_n_to_hash_no_pad::(inputs).elements; - - // digest_cell = p1.digest_cell + p2.digest_cell + D(identifier || value) - let split_digest = cell.split_digest(b); - let split_digest = split_digest.accumulate(b, &child_proofs[0].split_digest_target()); - let split_digest = split_digest.accumulate(b, &child_proofs[1].split_digest_target()); + let h = b.hash_n_to_hash_no_pad::(inputs); // Register the public inputs. PublicInputs::new( - &h, - &split_digest.individual.to_targets(), - &split_digest.multiplier.to_targets(), + &h.to_targets(), + &values_digests.individual.to_targets(), + &values_digests.multiplier.to_targets(), + &metadata_digests.individual.to_targets(), + &metadata_digests.multiplier.to_targets(), ) .register(b); @@ -65,7 +69,7 @@ impl CircuitLogicWires for FullNodeWires { type Inputs = FullNodeCircuit; - const NUM_PUBLIC_INPUTS: usize = PublicInputs::::TOTAL_LEN; + const NUM_PUBLIC_INPUTS: usize = PublicInputs::::total_len(); fn circuit_logic( builder: &mut CBuilder, @@ -83,6 +87,7 @@ impl CircuitLogicWires for FullNodeWires { } } +/* #[cfg(test)] mod tests { use super::*; @@ -195,3 +200,4 @@ mod tests { } } } +*/ diff --git a/verifiable-db/src/cells_tree/leaf.rs b/verifiable-db/src/cells_tree/leaf.rs index 72fefca14..180643fc1 100644 --- a/verifiable-db/src/cells_tree/leaf.rs +++ b/verifiable-db/src/cells_tree/leaf.rs @@ -3,8 +3,11 @@ use super::{public_inputs::PublicInputs, Cell, CellWire}; use derive_more::{From, Into}; use mp2_common::{ - poseidon::empty_poseidon_hash, public_inputs::PublicInputCommon, types::CBuilder, - utils::ToTargets, CHasher, D, F, + poseidon::{empty_poseidon_hash, H}, + public_inputs::PublicInputCommon, + types::CBuilder, + utils::ToTargets, + D, F, }; use plonky2::{ iop::witness::PartialWitness, @@ -12,7 +15,7 @@ use plonky2::{ }; use recursion_framework::circuit_builder::CircuitLogicWires; use serde::{Deserialize, Serialize}; -use std::iter; +use std::iter::once; #[derive(Clone, Debug, Serialize, Deserialize, From, Into)] pub struct LeafWires(CellWire); @@ -23,28 +26,27 @@ pub struct LeafCircuit(Cell); impl LeafCircuit { fn build(b: &mut CBuilder) -> LeafWires { let cell = CellWire::new(b); - - // h = Poseidon(Poseidon("") || Poseidon("") || identifier || value) - let empty_hash = empty_poseidon_hash(); - let empty_hash = b.constant_hash(*empty_hash); - let inputs: Vec<_> = empty_hash - .elements - .iter() - .cloned() - .chain(empty_hash.elements) - .chain(iter::once(cell.identifier)) + let metadata_digests = cell.split_metadata_digest(b); + let values_digests = cell.split_values_digest(b); + + // H(H("") || H("") || identifier || pack_u32(value)) + let empty_hash = b.constant_hash(*empty_poseidon_hash()).to_targets(); + let inputs = empty_hash + .clone() + .into_iter() + .chain(empty_hash) + .chain(once(cell.identifier)) .chain(cell.value.to_targets()) .collect(); - let h = b.hash_n_to_hash_no_pad::(inputs).elements; - - // digest_cell = D(identifier || value) - let split_digest = cell.split_digest(b); + let h = b.hash_n_to_hash_no_pad::(inputs); // Register the public inputs. PublicInputs::new( - &h, - &split_digest.individual.to_targets(), - &split_digest.multiplier.to_targets(), + &h.to_targets(), + &values_digests.individual.to_targets(), + &values_digests.multiplier.to_targets(), + &metadata_digests.individual.to_targets(), + &metadata_digests.multiplier.to_targets(), ) .register(b); @@ -63,7 +65,7 @@ impl CircuitLogicWires for LeafWires { type Inputs = LeafCircuit; - const NUM_PUBLIC_INPUTS: usize = PublicInputs::::TOTAL_LEN; + const NUM_PUBLIC_INPUTS: usize = PublicInputs::::total_len(); fn circuit_logic( builder: &mut CircuitBuilder, @@ -79,6 +81,7 @@ impl CircuitLogicWires for LeafWires { } } +/* #[cfg(test)] mod tests { use super::*; @@ -153,3 +156,4 @@ mod tests { } } } +*/ diff --git a/verifiable-db/src/cells_tree/mod.rs b/verifiable-db/src/cells_tree/mod.rs index af0e85846..a5ba1d0dc 100644 --- a/verifiable-db/src/cells_tree/mod.rs +++ b/verifiable-db/src/cells_tree/mod.rs @@ -5,11 +5,10 @@ mod leaf; mod partial_node; mod public_inputs; -use serde::{Deserialize, Serialize}; - use alloy::primitives::U256; pub use api::{build_circuits_params, extract_hash_from_proof, CircuitInput, PublicParameters}; use derive_more::Constructor; +use itertools::Itertools; use mp2_common::{ digest::{Digest, SplitDigestPoint, SplitDigestTarget}, group_hashing::{map_to_curve_point, CircuitBuilderGroupHashing}, @@ -17,15 +16,17 @@ use mp2_common::{ types::CBuilder, u256::{CircuitBuilderU256, UInt256Target, WitnessWriteU256}, utils::{ToFields, ToTargets}, - D, F, + F, }; +use serde::{Deserialize, Serialize}; +use std::iter::once; use plonky2::{ + hash::hash_types::{HashOut, HashOutTarget}, iop::{ target::{BoolTarget, Target}, witness::{PartialWitness, WitnessWrite}, }, - plonk::circuit_builder::CircuitBuilder, }; use plonky2_ecgfp5::gadgets::curve::CurveTarget; pub use public_inputs::PublicInputs; @@ -40,6 +41,8 @@ pub(crate) struct Cell { pub(crate) value: U256, /// is the secondary value should be included in multiplier digest or not pub(crate) is_multiplier: bool, + /// Hash of the metadata associated to this cell, as computed in MPT extraction circuits + pub(crate) mpt_metadata: HashOut, } impl Cell { @@ -47,29 +50,48 @@ impl Cell { pw.set_u256_target(&wires.value, self.value); pw.set_target(wires.identifier, self.identifier); pw.set_bool_target(wires.is_multiplier, self.is_multiplier); + pw.set_hash_target(wires.mpt_metadata, self.mpt_metadata); } - pub(crate) fn digest(&self) -> Digest { - map_to_curve_point(&self.to_fields()) + pub(crate) fn split_metadata_digest(&self) -> SplitDigestPoint { + let digest = self.metadata_digest(); + SplitDigestPoint::from_single_digest_point(digest, self.is_multiplier) } - pub(crate) fn split_digest(&self) -> SplitDigestPoint { - let digest = self.digest(); + pub(crate) fn split_values_digest(&self) -> SplitDigestPoint { + let digest = self.values_digest(); SplitDigestPoint::from_single_digest_point(digest, self.is_multiplier) } - pub(crate) fn split_and_accumulate_digest( + pub(crate) fn split_and_accumulate_metadata_digest( &self, child_digest: SplitDigestPoint, ) -> SplitDigestPoint { - let sd = self.split_digest(); - sd.accumulate(&child_digest) + let split_digest = self.split_metadata_digest(); + split_digest.accumulate(&child_digest) } -} - -impl ToFields for Cell { - fn to_fields(&self) -> Vec { - [self.identifier] + pub(crate) fn split_and_accumulate_values_digest( + &self, + child_digest: SplitDigestPoint, + ) -> SplitDigestPoint { + let split_digest = self.split_values_digest(); + split_digest.accumulate(&child_digest) + } + fn metadata_digest(&self) -> Digest { + // D(mpt_metadata || identifier) + let inputs = self + .mpt_metadata + .to_fields() .into_iter() + .chain(once(self.identifier)) + .collect_vec(); + + map_to_curve_point(&inputs) + } + fn values_digest(&self) -> Digest { + // D(identifier || pack_u32(value)) + let inputs = once(self.identifier) .chain(self.value.to_fields()) - .collect() + .collect_vec(); + + map_to_curve_point(&inputs) } } @@ -80,44 +102,60 @@ pub(crate) struct CellWire { pub(crate) identifier: Target, #[serde(serialize_with = "serialize", deserialize_with = "deserialize")] pub(crate) is_multiplier: BoolTarget, + #[serde(serialize_with = "serialize", deserialize_with = "deserialize")] + pub(crate) mpt_metadata: HashOutTarget, } impl CellWire { - pub(crate) fn new(b: &mut CircuitBuilder) -> Self { + pub(crate) fn new(b: &mut CBuilder) -> Self { Self { value: b.add_virtual_u256(), identifier: b.add_virtual_target(), is_multiplier: b.add_virtual_bool_target_safe(), + mpt_metadata: b.add_virtual_hash(), } } - /// Returns the digest of the cell - pub(crate) fn digest(&self, b: &mut CircuitBuilder) -> CurveTarget { - b.map_to_curve_point(&self.to_targets()) + pub(crate) fn split_metadata_digest(&self, b: &mut CBuilder) -> SplitDigestTarget { + let digest = self.metadata_digest(b); + SplitDigestTarget::from_single_digest_target(b, digest, self.is_multiplier) } - /// Returns the different digest, multiplier or individual - pub(crate) fn split_digest(&self, c: &mut CBuilder) -> SplitDigestTarget { - let d = self.digest(c); - SplitDigestTarget::from_single_digest_target(c, d, self.is_multiplier) + pub(crate) fn split_values_digest(&self, b: &mut CBuilder) -> SplitDigestTarget { + let digest = self.values_digest(b); + SplitDigestTarget::from_single_digest_target(b, digest, self.is_multiplier) } - /// Returns the split digest from this cell added with the one from the proof. - /// NOTE: it calls agains split_digest, so call that first if you need the individual - /// SplitDigestTarget - pub(crate) fn split_and_accumulate_digest( + pub(crate) fn split_and_accumulate_metadata_digest( &self, - c: &mut CBuilder, + b: &mut CBuilder, child_digest: SplitDigestTarget, ) -> SplitDigestTarget { - let sd = self.split_digest(c); - sd.accumulate(c, &child_digest) + let split_digest = self.split_metadata_digest(b); + split_digest.accumulate(b, &child_digest) } -} - -impl ToTargets for CellWire { - fn to_targets(&self) -> Vec { - self.identifier + pub(crate) fn split_and_accumulate_values_digest( + &self, + b: &mut CBuilder, + child_digest: SplitDigestTarget, + ) -> SplitDigestTarget { + let split_digest = self.split_values_digest(b); + split_digest.accumulate(b, &child_digest) + } + fn metadata_digest(&self, b: &mut CBuilder) -> CurveTarget { + // D(mpt_metadata || identifier) + let inputs = self + .mpt_metadata .to_targets() .into_iter() + .chain(once(self.identifier)) + .collect_vec(); + + b.map_to_curve_point(&inputs) + } + fn values_digest(&self, b: &mut CBuilder) -> CurveTarget { + // D(identifier || pack_u32(value)) + let inputs = once(self.identifier) .chain(self.value.to_targets()) - .collect::>() + .collect_vec(); + + b.map_to_curve_point(&inputs) } } diff --git a/verifiable-db/src/cells_tree/partial_node.rs b/verifiable-db/src/cells_tree/partial_node.rs index d9b5bf45b..2e5363bc3 100644 --- a/verifiable-db/src/cells_tree/partial_node.rs +++ b/verifiable-db/src/cells_tree/partial_node.rs @@ -1,29 +1,22 @@ //! Module handling the intermediate node with 1 child inside a cells tree use super::{public_inputs::PublicInputs, Cell, CellWire}; -use alloy::primitives::U256; use anyhow::Result; use derive_more::{From, Into}; use mp2_common::{ - group_hashing::CircuitBuilderGroupHashing, - poseidon::empty_poseidon_hash, + poseidon::{empty_poseidon_hash, H}, public_inputs::PublicInputCommon, types::CBuilder, - u256::{CircuitBuilderU256, UInt256Target, WitnessWriteU256}, utils::ToTargets, - CHasher, D, F, + D, F, }; use plonky2::{ - iop::{ - target::{BoolTarget, Target}, - witness::{PartialWitness, WitnessWrite}, - }, + iop::{target::Target, witness::PartialWitness}, plonk::proof::ProofWithPublicInputsTarget, }; -use plonky2_ecgfp5::gadgets::curve::CircuitBuilderEcGFp5; use recursion_framework::circuit_builder::CircuitLogicWires; use serde::{Deserialize, Serialize}; -use std::iter; +use std::iter::once; #[derive(Clone, Debug, Serialize, Deserialize, From, Into)] pub struct PartialNodeWires(CellWire); @@ -32,32 +25,38 @@ pub struct PartialNodeWires(CellWire); pub struct PartialNodeCircuit(Cell); impl PartialNodeCircuit { - pub fn build(b: &mut CBuilder, child_proof: PublicInputs) -> PartialNodeWires { + pub fn build(b: &mut CBuilder, p: PublicInputs) -> PartialNodeWires { let cell = CellWire::new(b); - - // h = Poseidon(p.H || Poseidon("") || identifier || value) - let child_hash = child_proof.node_hash(); - let empty_hash = empty_poseidon_hash(); - let empty_hash = b.constant_hash(*empty_hash); - let inputs: Vec<_> = child_hash - .elements - .iter() - .cloned() - .chain(empty_hash.elements) - .chain(iter::once(cell.identifier)) + let metadata_digests = cell.split_metadata_digest(b); + let values_digests = cell.split_values_digest(b); + + let metadata_digests = metadata_digests.accumulate(b, &p.split_metadata_digest_target()); + let values_digests = values_digests.accumulate(b, &p.split_values_digest_target()); + + /* + # since there is no sorting constraint among the nodes of this tree, to simplify + # the circuits, when we build a node with only one child, we can always place + # it as the left child + # NOTE: this is true only if we the "block" tree + h = H(p.H || H("") || identifier || value) + */ + let empty_hash = b.constant_hash(*empty_poseidon_hash()).to_targets(); + let inputs = p + .node_hash_target() + .into_iter() + .chain(empty_hash) + .chain(once(cell.identifier)) .chain(cell.value.to_targets()) .collect(); - let h = b.hash_n_to_hash_no_pad::(inputs).elements; - - // aggregate the digest of the child proof in the right digest - // digest_cell = p.digest_cell + D(identifier || value) - let split_digest = cell.split_and_accumulate_digest(b, child_proof.split_digest_target()); + let h = b.hash_n_to_hash_no_pad::(inputs); // Register the public inputs. PublicInputs::new( - &h, - &split_digest.individual.to_targets(), - &split_digest.multiplier.to_targets(), + &h.to_targets(), + &values_digests.individual.to_targets(), + &values_digests.multiplier.to_targets(), + &metadata_digests.individual.to_targets(), + &metadata_digests.multiplier.to_targets(), ) .register(b); @@ -76,7 +75,7 @@ impl CircuitLogicWires for PartialNodeWires { type Inputs = PartialNodeCircuit; - const NUM_PUBLIC_INPUTS: usize = PublicInputs::::TOTAL_LEN; + const NUM_PUBLIC_INPUTS: usize = PublicInputs::::total_len(); fn circuit_logic( builder: &mut CBuilder, @@ -93,6 +92,7 @@ impl CircuitLogicWires for PartialNodeWires { } } +/* #[cfg(test)] mod tests { use super::*; @@ -191,3 +191,4 @@ mod tests { } } } +*/ diff --git a/verifiable-db/src/cells_tree/public_inputs.rs b/verifiable-db/src/cells_tree/public_inputs.rs index a8ccfdafe..e2c2f5b3c 100644 --- a/verifiable-db/src/cells_tree/public_inputs.rs +++ b/verifiable-db/src/cells_tree/public_inputs.rs @@ -1,122 +1,225 @@ //! Public inputs for Cells Tree Construction circuits + use mp2_common::{ digest::{SplitDigestPoint, SplitDigestTarget}, group_hashing::weierstrass_to_point, public_inputs::{PublicInputCommon, PublicInputRange}, - types::{CBuilder, GFp, CURVE_TARGET_LEN}, + types::{CBuilder, CURVE_TARGET_LEN}, utils::{FromFields, FromTargets}, F, }; use plonky2::{ - hash::hash_types::{HashOut, HashOutTarget, NUM_HASH_OUT_ELTS}, + hash::hash_types::{HashOut, NUM_HASH_OUT_ELTS}, iop::target::Target, }; use plonky2_ecgfp5::{curve::curve::WeierstrassPoint, gadgets::curve::CurveTarget}; -use std::{array, fmt::Debug}; -// Cells Tree Construction public inputs: -// - `H : [4]F` : Poseidon hash of the subtree at this node -// - `DI : Digest[F]` : Cells digests accumulated up so far for INDIVIDUAL digest -// - `DM: Digest[F]` : Cells digests accumulated up so far for MULTIPLIER digest -const H_RANGE: PublicInputRange = 0..NUM_HASH_OUT_ELTS; -const DI_RANGE: PublicInputRange = H_RANGE.end..H_RANGE.end + CURVE_TARGET_LEN; -const DM_RANGE: PublicInputRange = DI_RANGE.end..DI_RANGE.end + CURVE_TARGET_LEN; +pub enum CellsTreePublicInputs { + // `H : F[4]` - Poseidon hash of the subtree at this node + NodeHash, + // - `individual_vd : Digest` - Cumulative digest of values of cells accumulated as individual + IndividualValuesDigest, + // - `multiplier_vd : Digest` - Cumulative digest of values of cells accumulated as multiplier + MultiplierValuesDigest, + // - `individual_md : Digest` - Cumulative digest of metadata of cells accumulated as individual + IndividualMetadataDigest, + // - `multiplier_md : Digest` - Cumulative digest of metadata of cells accumulated as multiplier + MultiplierMetadataDigest, +} /// Public inputs for Cells Tree Construction #[derive(Clone, Debug)] pub struct PublicInputs<'a, T> { pub(crate) h: &'a [T], - pub(crate) ind: &'a [T], - pub(crate) mul: &'a [T], + pub(crate) individual_vd: &'a [T], + pub(crate) multiplier_vd: &'a [T], + pub(crate) individual_md: &'a [T], + pub(crate) multiplier_md: &'a [T], } -impl<'a> PublicInputCommon for PublicInputs<'a, Target> { - const RANGES: &'static [PublicInputRange] = &[H_RANGE, DI_RANGE, DM_RANGE]; +const NUM_PUBLIC_INPUTS: usize = CellsTreePublicInputs::MultiplierMetadataDigest as usize + 1; - fn register_args(&self, cb: &mut CBuilder) { - cb.register_public_inputs(self.h); - cb.register_public_inputs(self.ind); - cb.register_public_inputs(self.mul); - } -} +impl<'a, T: Clone> PublicInputs<'a, T> { + const PI_RANGES: [PublicInputRange; NUM_PUBLIC_INPUTS] = [ + Self::to_range(CellsTreePublicInputs::NodeHash), + Self::to_range(CellsTreePublicInputs::IndividualValuesDigest), + Self::to_range(CellsTreePublicInputs::MultiplierValuesDigest), + Self::to_range(CellsTreePublicInputs::IndividualMetadataDigest), + Self::to_range(CellsTreePublicInputs::MultiplierMetadataDigest), + ]; -impl<'a> PublicInputs<'a, GFp> { - /// Get the cells digest point. - pub fn individual_digest_point(&self) -> WeierstrassPoint { - WeierstrassPoint::from_fields(self.ind) + const SIZES: [usize; NUM_PUBLIC_INPUTS] = [ + // Poseidon hash of the subtree at this node + NUM_HASH_OUT_ELTS, + // Cumulative digest of values of cells accumulated as individual + CURVE_TARGET_LEN, + // Cumulative digest of values of cells accumulated as multiplier + CURVE_TARGET_LEN, + // Cumulative digest of metadata of cells accumulated as individual + CURVE_TARGET_LEN, + // Cumulative digest of metadata of cells accumulated as multiplier + CURVE_TARGET_LEN, + ]; + + pub(crate) const fn to_range(pi: CellsTreePublicInputs) -> PublicInputRange { + let mut i = 0; + let mut offset = 0; + let pi_pos = pi as usize; + while i < pi_pos { + offset += Self::SIZES[i]; + i += 1; + } + offset..offset + Self::SIZES[pi_pos] } - pub fn multiplier_digest_point(&self) -> WeierstrassPoint { - WeierstrassPoint::from_fields(self.mul) + + pub(crate) const fn total_len() -> usize { + Self::to_range(CellsTreePublicInputs::MultiplierMetadataDigest).end } - pub fn split_digest_point(&self) -> SplitDigestPoint { - SplitDigestPoint { - individual: weierstrass_to_point(&self.individual_digest_point()), - multiplier: weierstrass_to_point(&self.multiplier_digest_point()), - } + + pub(crate) fn to_node_hash_raw(&self) -> &[T] { + self.h } -} -impl<'a> PublicInputs<'a, Target> { - /// Get the Poseidon hash of the subtree at this node. - pub fn node_hash(&self) -> HashOutTarget { - self.h.try_into().unwrap() + pub(crate) fn to_individual_values_digest_raw(&self) -> &[T] { + self.individual_vd } - /// Get the individual digest target. - pub fn individual_digest_target(&self) -> CurveTarget { - CurveTarget::from_targets(self.ind) + pub(crate) fn to_multiplier_values_digest_raw(&self) -> &[T] { + self.multiplier_vd } - /// Get the cells multiplier digest - pub fn multiplier_digest_target(&self) -> CurveTarget { - CurveTarget::from_targets(self.mul) + pub(crate) fn to_individual_metadata_digest_raw(&self) -> &[T] { + self.individual_md } - pub fn split_digest_target(&self) -> SplitDigestTarget { - SplitDigestTarget { - individual: self.individual_digest_target(), - multiplier: self.multiplier_digest_target(), - } + + pub(crate) fn to_multiplier_metadata_digest_raw(&self) -> &[T] { + self.multiplier_md } -} -impl<'a, T: Copy> PublicInputs<'a, T> { - /// Total length of the public inputs - pub(crate) const TOTAL_LEN: usize = DM_RANGE.end; + pub fn from_slice(input: &'a [T]) -> Self { + assert!( + input.len() >= Self::total_len(), + "Input slice too short to build cells tree public inputs, must be at least {} elements", + Self::total_len(), + ); - /// Create a new public inputs. - pub fn new(h: &'a [T], ind: &'a [T], mul: &'a [T]) -> Self { - Self { h, ind, mul } + Self { + h: &input[Self::PI_RANGES[0].clone()], + individual_vd: &input[Self::PI_RANGES[1].clone()], + multiplier_vd: &input[Self::PI_RANGES[2].clone()], + individual_md: &input[Self::PI_RANGES[3].clone()], + multiplier_md: &input[Self::PI_RANGES[4].clone()], + } } - /// Create from a slice. - pub fn from_slice(pi: &'a [T]) -> Self { - assert!(pi.len() >= Self::TOTAL_LEN); + pub fn new( + h: &'a [T], + individual_vd: &'a [T], + multiplier_vd: &'a [T], + individual_md: &'a [T], + multiplier_md: &'a [T], + ) -> Self { Self { - h: &pi[H_RANGE], - ind: &pi[DI_RANGE], - mul: &pi[DM_RANGE], + h, + individual_vd, + multiplier_vd, + individual_md, + multiplier_md, } } - /// Combine to a vector. pub fn to_vec(&self) -> Vec { self.h .iter() - .chain(self.ind) - .chain(self.mul) + .chain(self.individual_vd) + .chain(self.multiplier_vd) + .chain(self.individual_md) + .chain(self.multiplier_md) .cloned() .collect() } +} - pub fn h_raw(&self) -> &'a [T] { - self.h +impl<'a> PublicInputCommon for PublicInputs<'a, Target> { + const RANGES: &'static [PublicInputRange] = &Self::PI_RANGES; + + fn register_args(&self, cb: &mut CBuilder) { + cb.register_public_inputs(self.h); + cb.register_public_inputs(self.individual_vd); + cb.register_public_inputs(self.multiplier_vd); + cb.register_public_inputs(self.individual_md); + cb.register_public_inputs(self.multiplier_md); + } +} + +impl<'a> PublicInputs<'a, Target> { + pub fn node_hash_target(&self) -> [Target; NUM_HASH_OUT_ELTS] { + self.to_node_hash_raw().try_into().unwrap() + } + + pub fn individual_values_digest_target(&self) -> CurveTarget { + CurveTarget::from_targets(self.individual_vd) + } + + pub fn multiplier_values_digest_target(&self) -> CurveTarget { + CurveTarget::from_targets(self.multiplier_vd) + } + + pub fn individual_metadata_digest_target(&self) -> CurveTarget { + CurveTarget::from_targets(self.individual_md) + } + + pub fn multiplier_metadata_digest_target(&self) -> CurveTarget { + CurveTarget::from_targets(self.multiplier_md) + } + + pub fn split_values_digest_target(&self) -> SplitDigestTarget { + SplitDigestTarget { + individual: self.individual_values_digest_target(), + multiplier: self.multiplier_values_digest_target(), + } + } + + pub fn split_metadata_digest_target(&self) -> SplitDigestTarget { + SplitDigestTarget { + individual: self.individual_metadata_digest_target(), + multiplier: self.multiplier_metadata_digest_target(), + } } } impl<'a> PublicInputs<'a, F> { - pub fn root_hash_hashout(&self) -> HashOut { - HashOut { - elements: array::from_fn(|i| self.h[i]), + pub fn node_hash(&self) -> HashOut { + HashOut::from_partial(self.to_node_hash_raw()) + } + + pub fn individual_values_digest_point(&self) -> WeierstrassPoint { + WeierstrassPoint::from_fields(self.individual_vd) + } + + pub fn multiplier_values_digest_point(&self) -> WeierstrassPoint { + WeierstrassPoint::from_fields(self.multiplier_vd) + } + + pub fn individual_metadata_digest_point(&self) -> WeierstrassPoint { + WeierstrassPoint::from_fields(self.individual_md) + } + + pub fn multiplier_metadata_digest_point(&self) -> WeierstrassPoint { + WeierstrassPoint::from_fields(self.multiplier_md) + } + + pub fn split_values_digest_point(&self) -> SplitDigestPoint { + SplitDigestPoint { + individual: weierstrass_to_point(&self.individual_values_digest_point()), + multiplier: weierstrass_to_point(&self.multiplier_values_digest_point()), + } + } + + pub fn split_metadata_digest_point(&self) -> SplitDigestPoint { + SplitDigestPoint { + individual: weierstrass_to_point(&&self.individual_metadata_digest_point()), + multiplier: weierstrass_to_point(&self.multiplier_metadata_digest_point()), } } } @@ -138,20 +241,21 @@ mod tests { }; use plonky2_ecgfp5::curve::curve::Point; use rand::thread_rng; + use std::array; #[derive(Clone, Debug)] - struct TestPICircuit<'a> { + struct TestPublicInputs<'a> { exp_pi: &'a [F], } - impl<'a> UserCircuit for TestPICircuit<'a> { + impl<'a> UserCircuit for TestPublicInputs<'a> { type Wires = Vec; fn build(b: &mut CBuilder) -> Self::Wires { - let pi = b.add_virtual_targets(PublicInputs::::TOTAL_LEN); - PublicInputs::from_slice(&pi).register(b); + let exp_pi = b.add_virtual_targets(PublicInputs::::total_len()); + PublicInputs::from_slice(&exp_pi).register(b); - pi + exp_pi } fn prove(&self, pw: &mut PartialWitness, wires: &Self::Wires) { @@ -161,21 +265,46 @@ mod tests { #[test] fn test_cells_tree_public_inputs() { - let mut rng = thread_rng(); + let rng = &mut thread_rng(); // Prepare the public inputs. - let h = &random_vector::(NUM_HASH_OUT_ELTS).to_fields(); - let dc = &Point::sample(&mut rng).to_weierstrass().to_fields(); - let exp_pi = PublicInputs { - h, - ind: dc, - mul: dc, - }; + let h = random_vector::(NUM_HASH_OUT_ELTS).to_fields(); + let [individual_vd, multiplier_vd, individual_md, multiplier_md] = + array::from_fn(|_| Point::sample(rng).to_weierstrass().to_fields()); + let exp_pi = PublicInputs::new( + &h, + &individual_vd, + &multiplier_vd, + &individual_md, + &multiplier_md, + ); let exp_pi = &exp_pi.to_vec(); - let test_circuit = TestPICircuit { exp_pi }; + let test_circuit = TestPublicInputs { exp_pi }; let proof = run_circuit::(test_circuit); - assert_eq!(&proof.public_inputs, exp_pi); + + // Check if the public inputs are constructed correctly. + let pi = PublicInputs::from_slice(&proof.public_inputs); + assert_eq!( + &exp_pi[PublicInputs::::to_range(CellsTreePublicInputs::NodeHash)], + pi.to_node_hash_raw(), + ); + assert_eq!( + &exp_pi[PublicInputs::::to_range(CellsTreePublicInputs::IndividualValuesDigest)], + pi.to_individual_values_digest_raw(), + ); + assert_eq!( + &exp_pi[PublicInputs::::to_range(CellsTreePublicInputs::MultiplierValuesDigest)], + pi.to_multiplier_values_digest_raw(), + ); + assert_eq!( + &exp_pi[PublicInputs::::to_range(CellsTreePublicInputs::IndividualMetadataDigest)], + pi.to_individual_metadata_digest_raw(), + ); + assert_eq!( + &exp_pi[PublicInputs::::to_range(CellsTreePublicInputs::MultiplierMetadataDigest)], + pi.to_multiplier_metadata_digest_raw(), + ); } } diff --git a/verifiable-db/src/revelation/api.rs b/verifiable-db/src/revelation/api.rs index d0581135d..664bfb8d4 100644 --- a/verifiable-db/src/revelation/api.rs +++ b/verifiable-db/src/revelation/api.rs @@ -213,7 +213,7 @@ pub enum CircuitInput< [(); ROW_TREE_MAX_DEPTH - 1]:, [(); INDEX_TREE_MAX_DEPTH - 1]:, [(); MAX_NUM_ITEMS_PER_OUTPUT * MAX_NUM_OUTPUTS]:, - [(); { 2 * (MAX_NUM_PREDICATE_OPS + MAX_NUM_RESULT_OPS) }]:, + [(); 2 * (MAX_NUM_PREDICATE_OPS + MAX_NUM_RESULT_OPS)]:, { NoResultsTree { query_proof: ProofWithVK, diff --git a/verifiable-db/src/row_tree/api.rs b/verifiable-db/src/row_tree/api.rs index 7a17398b5..a56c24fe3 100644 --- a/verifiable-db/src/row_tree/api.rs +++ b/verifiable-db/src/row_tree/api.rs @@ -14,6 +14,7 @@ use super::{ full_node::{self, FullNodeCircuit}, leaf::{self, LeafCircuit}, partial_node::{self, PartialNodeCircuit}, + row::Row, PublicInputs, }; @@ -38,7 +39,7 @@ pub struct PublicParameters { row_set: RecursiveCircuits, } -const ROW_IO_LEN: usize = super::public_inputs::TOTAL_LEN; +const ROW_IO_LEN: usize = super::PublicInputs::::total_len(); impl PublicParameters { pub fn build(cells_set: &RecursiveCircuits) -> Self { @@ -180,18 +181,39 @@ pub enum CircuitInput { } impl CircuitInput { - pub fn leaf(identifier: u64, value: U256, cells_proof: Vec) -> Result { - Self::leaf_multiplier(identifier, value, false, cells_proof) + pub fn leaf( + identifier: u64, + value: U256, + mpt_metadata: HashOut, + row_unique_data: HashOut, + cells_proof: Vec, + ) -> Result { + Self::leaf_multiplier( + identifier, + value, + false, + mpt_metadata, + row_unique_data, + cells_proof, + ) } pub fn leaf_multiplier( identifier: u64, value: U256, is_multiplier: bool, + mpt_metadata: HashOut, + row_unique_data: HashOut, cells_proof: Vec, ) -> Result { - let circuit = Cell::new(F::from_canonical_u64(identifier), value, is_multiplier); + let cell = Cell::new( + F::from_canonical_u64(identifier), + value, + is_multiplier, + mpt_metadata, + ); + let row = Row::new(cell, row_unique_data); Ok(CircuitInput::Leaf { - witness: circuit.into(), + witness: row.into(), cells_proof, }) } @@ -199,6 +221,8 @@ impl CircuitInput { pub fn full( identifier: u64, value: U256, + mpt_metadata: HashOut, + row_unique_data: HashOut, left_proof: Vec, right_proof: Vec, cells_proof: Vec, @@ -207,6 +231,8 @@ impl CircuitInput { identifier, value, false, + mpt_metadata, + row_unique_data, left_proof, right_proof, cells_proof, @@ -216,13 +242,21 @@ impl CircuitInput { identifier: u64, value: U256, is_multiplier: bool, + mpt_metadata: HashOut, + row_unique_data: HashOut, left_proof: Vec, right_proof: Vec, cells_proof: Vec, ) -> Result { - let circuit = Cell::new(F::from_canonical_u64(identifier), value, is_multiplier); + let cell = Cell::new( + F::from_canonical_u64(identifier), + value, + is_multiplier, + mpt_metadata, + ); + let row = Row::new(cell, row_unique_data); Ok(CircuitInput::Full { - witness: circuit.into(), + witness: row.into(), left_proof, right_proof, cells_proof, @@ -232,6 +266,8 @@ impl CircuitInput { identifier: u64, value: U256, is_child_left: bool, + mpt_metadata: HashOut, + row_unique_data: HashOut, child_proof: Vec, cells_proof: Vec, ) -> Result { @@ -240,6 +276,8 @@ impl CircuitInput { value, false, is_child_left, + mpt_metadata, + row_unique_data, child_proof, cells_proof, ) @@ -249,11 +287,19 @@ impl CircuitInput { value: U256, is_multiplier: bool, is_child_left: bool, + mpt_metadata: HashOut, + row_unique_data: HashOut, child_proof: Vec, cells_proof: Vec, ) -> Result { - let tuple = Cell::new(F::from_canonical_u64(identifier), value, is_multiplier); - let witness = PartialNodeCircuit::new(tuple, is_child_left); + let cell = Cell::new( + F::from_canonical_u64(identifier), + value, + is_multiplier, + mpt_metadata, + ); + let row = Row::new(cell, row_unique_data); + let witness = PartialNodeCircuit::new(row, is_child_left); Ok(CircuitInput::Partial { witness, child_proof, @@ -264,9 +310,10 @@ impl CircuitInput { pub fn extract_hash_from_proof(proof: &[u8]) -> Result> { let p = ProofWithVK::deserialize(proof)?; - Ok(PublicInputs::from_slice(&p.proof.public_inputs).root_hash_hashout()) + Ok(PublicInputs::from_slice(&p.proof.public_inputs).root_hash()) } +/* #[cfg(test)] mod test { use crate::{cells_tree, row_tree::public_inputs::PublicInputs}; @@ -533,3 +580,4 @@ mod test { Ok(proof) } } +*/ diff --git a/verifiable-db/src/row_tree/full_node.rs b/verifiable-db/src/row_tree/full_node.rs index d672e6145..4bd983214 100644 --- a/verifiable-db/src/row_tree/full_node.rs +++ b/verifiable-db/src/row_tree/full_node.rs @@ -1,19 +1,15 @@ +use super::row::{Row, RowWire}; +use crate::cells_tree; use derive_more::{From, Into}; use mp2_common::{ - default_config, - group_hashing::{cond_circuit_hashed_scalar_mul, CircuitBuilderGroupHashing}, - poseidon::H, - proof::ProofWithVK, - public_inputs::PublicInputCommon, - u256::CircuitBuilderU256, - utils::ToTargets, - C, D, F, + default_config, group_hashing::CircuitBuilderGroupHashing, poseidon::H, proof::ProofWithVK, + public_inputs::PublicInputCommon, u256::CircuitBuilderU256, utils::ToTargets, C, D, F, }; use plonky2::{ iop::{target::Target, witness::PartialWitness}, plonk::{circuit_builder::CircuitBuilder, proof::ProofWithPublicInputsTarget}, }; -use plonky2_ecgfp5::gadgets::curve::CircuitBuilderEcGFp5; +use plonky2_ecdsa::gadgets::biguint::CircuitBuilderBiguint; use recursion_framework::{ circuit_builder::CircuitLogicWires, framework::{ @@ -21,19 +17,17 @@ use recursion_framework::{ }, }; use serde::{Deserialize, Serialize}; -use std::array::from_fn as create_array; - -use crate::cells_tree::{self, Cell, CellWire}; +use std::{array::from_fn as create_array, iter::once}; use super::public_inputs::PublicInputs; // Arity not strictly needed now but may be an easy way to increase performance // easily down the line with less recursion. Best to provide code which is easily // amenable to a different arity rather than hardcoding binary tree only #[derive(Clone, Debug, From, Into)] -pub struct FullNodeCircuit(Cell); +pub struct FullNodeCircuit(Row); #[derive(Clone, Serialize, Deserialize, From, Into)] -pub(crate) struct FullNodeWires(CellWire); +pub(crate) struct FullNodeWires(RowWire); impl FullNodeCircuit { pub(crate) fn build( @@ -42,52 +36,64 @@ impl FullNodeCircuit { right_pi: &[Target], cells_pi: &[Target], ) -> FullNodeWires { - let cells_pi = cells_tree::PublicInputs::from_slice(cells_pi); let min_child = PublicInputs::from_slice(left_pi); let max_child = PublicInputs::from_slice(right_pi); - let tuple = CellWire::new(b); - let node_min = min_child.min_value(); - let node_max = max_child.max_value(); + let cells_pi = cells_tree::PublicInputs::from_slice(cells_pi); + let row = RowWire::new(b); + let id = row.identifier(); + let value = row.value(); + let digest = row.digest(b, &cells_pi); + + // Check multiplier_vd and row_id_multiplier are the same as children proofs. + // assert multiplier_vd == p1.multiplier_vd == p2.multiplier_vd + b.connect_curve_points(digest.multiplier_vd, min_child.multiplier_digest_target()); + b.connect_curve_points(digest.multiplier_vd, max_child.multiplier_digest_target()); + // assert row_id_multiplier == p1.row_id_multiplier == p2.row_id_multiplier + b.connect_biguint( + &digest.row_id_multiplier, + &min_child.row_id_multiplier_target(), + ); + b.connect_biguint( + &digest.row_id_multiplier, + &max_child.row_id_multiplier_target(), + ); + + let node_min = min_child.min_value_target(); + let node_max = max_child.max_value_target(); // enforcing BST property let _true = b._true(); - let left_comparison = b.is_less_or_equal_than_u256(&min_child.max_value(), &tuple.value); - let right_comparison = b.is_less_or_equal_than_u256(&tuple.value, &max_child.min_value()); + let left_comparison = b.is_less_or_equal_than_u256(&min_child.max_value_target(), value); + let right_comparison = b.is_less_or_equal_than_u256(value, &max_child.min_value_target()); b.connect(left_comparison.target, _true.target); b.connect(right_comparison.target, _true.target); // Poseidon(p1.H || p2.H || node_min || node_max || index_id || index_value ||p.H)) as H let inputs = min_child - .root_hash() - .to_targets() + .root_hash_target() .iter() - .chain(max_child.root_hash().to_targets().iter()) + .chain(max_child.root_hash_target().iter()) .chain(node_min.to_targets().iter()) .chain(node_max.to_targets().iter()) - .chain(tuple.to_targets().iter()) - .chain(cells_pi.node_hash().to_targets().iter()) + .chain(once(&id)) + .chain(cells_pi.node_hash_target().iter()) .cloned() .collect::>(); let hash = b.hash_n_to_hash_no_pad::(inputs); - // final_digest = HashToInt(mul_digest) * D(ind_digest) + left.digest() + right.digest() - let split_digest = tuple.split_and_accumulate_digest(b, cells_pi.split_digest_target()); - let (row_digest, is_merge) = split_digest.cond_combine_to_row_digest(b); - - // add this row digest with the rest - let final_digest = b.curve_add(min_child.rows_digest(), max_child.rows_digest()); - let final_digest = b.curve_add(final_digest, row_digest); // assert `is_merge` is the same as the flags in children pis - b.connect(min_child.is_merge_case().target, is_merge.target); - b.connect(max_child.is_merge_case().target, is_merge.target); + b.connect(min_child.merge_flag_target().target, digest.is_merge.target); + b.connect(max_child.merge_flag_target().target, digest.is_merge.target); PublicInputs::new( &hash.to_targets(), - &final_digest.to_targets(), + &digest.individual_vd.to_targets(), + &digest.multiplier_vd.to_targets(), + &digest.row_id_multiplier.to_targets(), &node_min.to_targets(), &node_max.to_targets(), - &[is_merge.target], + &[digest.is_merge.target], ) .register(b); - FullNodeWires(tuple) + FullNodeWires(row) } fn assign(&self, pw: &mut PartialWitness, wires: &FullNodeWires) { self.0.assign_wires(pw, &wires.0); @@ -113,14 +119,14 @@ impl CircuitLogicWires for RecursiveFullWires { type Inputs = RecursiveFullInput; - const NUM_PUBLIC_INPUTS: usize = PublicInputs::::TOTAL_LEN; + const NUM_PUBLIC_INPUTS: usize = PublicInputs::::total_len(); fn circuit_logic( builder: &mut CircuitBuilder, verified_proofs: [&ProofWithPublicInputsTarget; NUM_CHILDREN], builder_parameters: Self::CircuitBuilderParams, ) -> Self { - const CELLS_IO: usize = cells_tree::PublicInputs::::TOTAL_LEN; + const CELLS_IO: usize = cells_tree::PublicInputs::::total_len(); let verifier_gadget = RecursiveCircuitsVerifierGagdet::::new( default_config(), &builder_parameters, @@ -144,6 +150,7 @@ impl CircuitLogicWires for RecursiveFullWires { } } +/* #[cfg(test)] pub(crate) mod test { @@ -185,9 +192,9 @@ pub(crate) mod test { type Wires = (FullNodeWires, Vec, Vec, Vec); fn build(c: &mut CircuitBuilder) -> Self::Wires { - let cells_pi = c.add_virtual_targets(cells_tree::PublicInputs::::TOTAL_LEN); - let left_pi = c.add_virtual_targets(PublicInputs::::TOTAL_LEN); - let right_pi = c.add_virtual_targets(PublicInputs::::TOTAL_LEN); + let cells_pi = c.add_virtual_targets(cells_tree::PublicInputs::::total_len()); + let left_pi = c.add_virtual_targets(PublicInputs::::total_len()); + let right_pi = c.add_virtual_targets(PublicInputs::::total_len()); ( FullNodeCircuit::build(c, &left_pi, &right_pi, &cells_pi), left_pi, @@ -302,3 +309,4 @@ pub(crate) mod test { test_row_tree_full_circuit(true, true); } } +*/ diff --git a/verifiable-db/src/row_tree/leaf.rs b/verifiable-db/src/row_tree/leaf.rs index 4d6e0a4d9..e9c6a34f6 100644 --- a/verifiable-db/src/row_tree/leaf.rs +++ b/verifiable-db/src/row_tree/leaf.rs @@ -1,7 +1,11 @@ +use super::{ + public_inputs::PublicInputs, + row::{Row, RowWire}, +}; +use crate::cells_tree; use derive_more::{From, Into}; use mp2_common::{ default_config, - group_hashing::{cond_circuit_hashed_scalar_mul, CircuitBuilderGroupHashing}, poseidon::{empty_poseidon_hash, H}, proof::ProofWithVK, public_inputs::PublicInputCommon, @@ -9,10 +13,7 @@ use mp2_common::{ C, D, F, }; use plonky2::{ - iop::{ - target::{BoolTarget, Target}, - witness::PartialWitness, - }, + iop::{target::Target, witness::PartialWitness}, plonk::{circuit_builder::CircuitBuilder, proof::ProofWithPublicInputsTarget}, }; use recursion_framework::{ @@ -22,57 +23,50 @@ use recursion_framework::{ }, }; use serde::{Deserialize, Serialize}; - -use crate::cells_tree::{self, Cell, CellWire}; - -use super::public_inputs::PublicInputs; +use std::iter::once; // new type to implement the circuit logic on each differently // deref to access directly the same members - read only so it's ok #[derive(Clone, Debug, From, Into)] -pub struct LeafCircuit(Cell); +pub struct LeafCircuit(Row); #[derive(Clone, Serialize, Deserialize, From, Into)] -pub(crate) struct LeafWires(CellWire); +pub(crate) struct LeafWires(RowWire); impl LeafCircuit { pub(crate) fn build(b: &mut CircuitBuilder, cells_pis: &[Target]) -> LeafWires { let cells_pis = cells_tree::PublicInputs::from_slice(cells_pis); - // D(index_id||pack_u32(index_value) - let tuple = CellWire::new(b); - // set the right digest depending on the multiplier and accumulate the ones from the public - // inputs of the cell root proof - let split_digest = tuple.split_and_accumulate_digest(b, cells_pis.split_digest_target()); - // final_digest = HashToInt(D(mul_digest)) * D(ind_digest) - // NOTE This additional digest is necessary since the individual digest is supposed to be a - // full row, that is how it is extracted from MPT - let (final_digest, is_merge) = split_digest.cond_combine_to_row_digest(b); + let row = RowWire::new(b); + let id = row.identifier(); + let value = row.value().to_targets(); + let digest = row.digest(b, &cells_pis); // H(left_child_hash,right_child_hash,min,max,index_identifier,index_value,cells_tree_hash) // in our case, min == max == index_value // left_child_hash == right_child_hash == empty_hash since there is not children - let empty_hash = b.constant_hash(*empty_poseidon_hash()); + let empty_hash = b.constant_hash(*empty_poseidon_hash()).to_targets(); let inputs = empty_hash - .to_targets() - .iter() - .chain(empty_hash.to_targets().iter()) - .chain(tuple.value.to_targets().iter()) - .chain(tuple.value.to_targets().iter()) - .chain(tuple.to_targets().iter()) - .chain(cells_pis.node_hash().to_targets().iter()) - .cloned() + .clone() + .into_iter() + .chain(empty_hash) + .chain(value.clone()) + .chain(value.clone()) + .chain(once(id)) + .chain(cells_pis.node_hash_target()) .collect::>(); let row_hash = b.hash_n_to_hash_no_pad::(inputs); - let value_fields = tuple.value.to_targets(); PublicInputs::new( &row_hash.elements, - &final_digest.to_targets(), - &value_fields, - &value_fields, - &[is_merge.target], + &digest.individual_vd.to_targets(), + &digest.multiplier_vd.to_targets(), + &digest.row_id_multiplier.to_targets(), + &value, + &value, + &[digest.is_merge.target], ) .register(b); - LeafWires(tuple) + + LeafWires(row) } fn assign(&self, pw: &mut PartialWitness, wires: &LeafWires) { @@ -102,14 +96,14 @@ impl CircuitLogicWires for RecursiveLeafWires { type Inputs = RecursiveLeafInput; - const NUM_PUBLIC_INPUTS: usize = PublicInputs::::TOTAL_LEN; + const NUM_PUBLIC_INPUTS: usize = PublicInputs::::total_len(); fn circuit_logic( builder: &mut CircuitBuilder, _verified_proofs: [&ProofWithPublicInputsTarget; 0], builder_parameters: Self::CircuitBuilderParams, ) -> Self { - const CELLS_IO: usize = cells_tree::PublicInputs::::TOTAL_LEN; + const CELLS_IO: usize = cells_tree::PublicInputs::::total_len(); let verifier_gadget = RecursiveCircuitsVerifierGagdet::::new( default_config(), &builder_parameters, @@ -131,6 +125,7 @@ impl CircuitLogicWires for RecursiveLeafWires { } } +/* #[cfg(test)] mod test { @@ -243,3 +238,4 @@ mod test { test_row_tree_leaf_circuit(true, true); } } +*/ diff --git a/verifiable-db/src/row_tree/mod.rs b/verifiable-db/src/row_tree/mod.rs index c45a72292..82f0247e5 100644 --- a/verifiable-db/src/row_tree/mod.rs +++ b/verifiable-db/src/row_tree/mod.rs @@ -1,26 +1,9 @@ -use alloy::primitives::U256; -use derive_more::Constructor; -use mp2_common::{ - group_hashing::CircuitBuilderGroupHashing, - u256::{CircuitBuilderU256, UInt256Target, WitnessWriteU256}, - utils::{ToFields, ToTargets}, - D, F, -}; -use plonky2::{ - iop::{ - target::{BoolTarget, Target}, - witness::{PartialWitness, WitnessWrite}, - }, - plonk::circuit_builder::CircuitBuilder, -}; -use plonky2_ecgfp5::gadgets::curve::CurveTarget; -use serde::{Deserialize, Serialize}; - mod api; mod full_node; mod leaf; mod partial_node; mod public_inputs; +mod row; pub use api::{extract_hash_from_proof, CircuitInput, PublicParameters}; pub use public_inputs::PublicInputs; diff --git a/verifiable-db/src/row_tree/partial_node.rs b/verifiable-db/src/row_tree/partial_node.rs index 00af074c8..2b2d6bde2 100644 --- a/verifiable-db/src/row_tree/partial_node.rs +++ b/verifiable-db/src/row_tree/partial_node.rs @@ -1,8 +1,8 @@ -use plonky2::plonk::proof::ProofWithPublicInputsTarget; - +use super::row::{Row, RowWire}; +use crate::cells_tree; use mp2_common::{ default_config, - group_hashing::{cond_circuit_hashed_scalar_mul, CircuitBuilderGroupHashing}, + group_hashing::CircuitBuilderGroupHashing, hash::hash_maybe_first, poseidon::empty_poseidon_hash, proof::ProofWithVK, @@ -18,9 +18,9 @@ use plonky2::{ target::{BoolTarget, Target}, witness::{PartialWitness, WitnessWrite}, }, - plonk::circuit_builder::CircuitBuilder, + plonk::{circuit_builder::CircuitBuilder, proof::ProofWithPublicInputsTarget}, }; -use plonky2_ecgfp5::gadgets::curve::CircuitBuilderEcGFp5; +use plonky2_ecdsa::gadgets::biguint::CircuitBuilderBiguint; use recursion_framework::{ circuit_builder::CircuitLogicWires, framework::{ @@ -28,28 +28,27 @@ use recursion_framework::{ }, }; use serde::{Deserialize, Serialize}; - -use crate::cells_tree::{self, Cell, CellWire}; +use std::iter::once; use super::public_inputs::PublicInputs; #[derive(Clone, Debug)] pub struct PartialNodeCircuit { - pub(crate) tuple: Cell, + pub(crate) row: Row, pub(crate) is_child_at_left: bool, } #[derive(Clone, Debug, Serialize, Deserialize)] struct PartialNodeWires { - tuple: CellWire, + row: RowWire, #[serde(serialize_with = "serialize", deserialize_with = "deserialize")] is_child_at_left: BoolTarget, } impl PartialNodeCircuit { - pub(crate) fn new(tuple: Cell, is_child_at_left: bool) -> Self { + pub(crate) fn new(row: Row, is_child_at_left: bool) -> Self { Self { - tuple, + row, is_child_at_left, } } @@ -58,22 +57,35 @@ impl PartialNodeCircuit { child_pi: &[Target], cells_pi: &[Target], ) -> PartialNodeWires { + let child_pi = PublicInputs::from_slice(child_pi); let cells_pi = cells_tree::PublicInputs::from_slice(cells_pi); - let tuple = CellWire::new(b); + let row = RowWire::new(b); + let id = row.identifier(); + let value = row.value(); + let digest = row.digest(b, &cells_pi); + + // Check multiplier_vd and row_id_multiplier are the same as child proof + // assert multiplier_vd == child_proof.multiplier_vd + b.connect_curve_points(digest.multiplier_vd, child_pi.multiplier_digest_target()); + //assert row_id_multiplier == child_proof.row_id_multiplier + b.connect_biguint( + &digest.row_id_multiplier, + &child_pi.row_id_multiplier_target(), + ); + // bool target range checked in poseidon gate let is_child_at_left = b.add_virtual_bool_target_unsafe(); - let child_pi = PublicInputs::from_slice(child_pi); // max_left = left ? child_proof.max : index_value // min_right = left ? index_value : child_proof.min - let max_left = b.select_u256(is_child_at_left, &child_pi.max_value(), &tuple.value); - let min_right = b.select_u256(is_child_at_left, &tuple.value, &child_pi.min_value()); + let max_left = b.select_u256(is_child_at_left, &child_pi.max_value_target(), value); + let min_right = b.select_u256(is_child_at_left, value, &child_pi.min_value_target()); let bst_enforced = b.is_less_or_equal_than_u256(&max_left, &min_right); let _true = b._true(); b.connect(bst_enforced.target, _true.target); // node_min = left ? child_proof.min : index_value // node_max = left ? index_value : child_proof.max - let node_min = b.select_u256(is_child_at_left, &child_pi.min_value(), &tuple.value); - let node_max = b.select_u256(is_child_at_left, &tuple.value, &child_pi.max_value()); + let node_min = b.select_u256(is_child_at_left, &child_pi.min_value_target(), value); + let node_max = b.select_u256(is_child_at_left, value, &child_pi.max_value_target()); let empty_hash = b.constant_hash(*empty_poseidon_hash()); // left_hash = left ? child_proof.H : H("") @@ -85,8 +97,8 @@ impl PartialNodeCircuit { .to_targets() .iter() .chain(node_max.to_targets().iter()) - .chain(tuple.to_targets().iter()) - .chain(cells_pi.node_hash().to_targets().iter()) + .chain(once(&id)) + .chain(cells_pi.node_hash_target().iter()) .cloned() .collect::>(); // if child at left, then hash should be child_proof.H || H("") || rest @@ -95,34 +107,31 @@ impl PartialNodeCircuit { b, is_child_at_left, empty_hash.elements, - child_pi.root_hash().elements, + child_pi.root_hash_target(), &rest, ); - // final_digest = HashToInt(mul_digest) * D(ind_digest) - let split_digest = tuple.split_and_accumulate_digest(b, cells_pi.split_digest_target()); - let (row_digest, is_merge) = split_digest.cond_combine_to_row_digest(b); - - // and add the digest of the row other rows - let final_digest = b.curve_add(child_pi.rows_digest(), row_digest); // assert is_merge is the same between this row and `child_pi` - b.connect(is_merge.target, child_pi.is_merge_case().target); + b.connect(digest.is_merge.target, child_pi.merge_flag_target().target); + PublicInputs::new( &node_hash, - &final_digest.to_targets(), + &digest.individual_vd.to_targets(), &node_min.to_targets(), &node_max.to_targets(), - &[is_merge.target], + &[digest.is_merge.target], + &digest.multiplier_vd.to_targets(), + &digest.row_id_multiplier.to_targets(), ) .register(b); PartialNodeWires { - tuple, + row, is_child_at_left, } } fn assign(&self, pw: &mut PartialWitness, wires: &PartialNodeWires) { - self.tuple.assign_wires(pw, &wires.tuple); + self.row.assign_wires(pw, &wires.row); pw.set_bool_target(wires.is_child_at_left, self.is_child_at_left); } } @@ -145,14 +154,14 @@ impl CircuitLogicWires for RecursivePartialWires { type Inputs = RecursivePartialInput; - const NUM_PUBLIC_INPUTS: usize = PublicInputs::::TOTAL_LEN; + const NUM_PUBLIC_INPUTS: usize = PublicInputs::::total_len(); fn circuit_logic( builder: &mut CircuitBuilder, verified_proofs: [&ProofWithPublicInputsTarget; NUM_CHILDREN], builder_parameters: Self::CircuitBuilderParams, ) -> Self { - const CELLS_IO: usize = cells_tree::PublicInputs::::TOTAL_LEN; + const CELLS_IO: usize = cells_tree::PublicInputs::::total_len(); let verifier_gadget = RecursiveCircuitsVerifierGagdet::::new( default_config(), &builder_parameters, @@ -175,6 +184,7 @@ impl CircuitLogicWires for RecursivePartialWires { } } +/* #[cfg(test)] pub mod test { use mp2_common::{ @@ -322,8 +332,8 @@ pub mod test { // node_min = left ? child_proof.min : index_value // node_max = left ? index_value : child_proof.max let (node_min, node_max) = match child_at_left { - true => (pi.min_value_u256(), tuple.value), - false => (tuple.value, pi.max_value_u256()), + true => (pi.min_value(), tuple.value), + false => (tuple.value, pi.max_value()), }; // Poseidon(p1.H || p2.H || node_min || node_max || index_id || index_value ||p.H)) as H let child_hash = PublicInputs::from_slice(&child_pi).root_hash_hashout(); @@ -352,3 +362,4 @@ pub mod test { assert_eq!(split_digest.is_merge_case(), pi.is_merge_flag()); } } +*/ diff --git a/verifiable-db/src/row_tree/public_inputs.rs b/verifiable-db/src/row_tree/public_inputs.rs index a775af1af..06eb726d1 100644 --- a/verifiable-db/src/row_tree/public_inputs.rs +++ b/verifiable-db/src/row_tree/public_inputs.rs @@ -1,183 +1,299 @@ //! Public inputs for rows trees creation circuits -//! + use alloy::primitives::U256; +use itertools::Itertools; use mp2_common::{ + poseidon::HASH_TO_INT_LEN, public_inputs::{PublicInputCommon, PublicInputRange}, - types::CURVE_TARGET_LEN, + types::{CBuilder, CURVE_TARGET_LEN}, u256::{self, UInt256Target}, utils::{FromFields, FromTargets, TryIntoBool}, - D, F, + F, }; +use num::BigUint; use plonky2::{ - hash::hash_types::{HashOut, HashOutTarget, NUM_HASH_OUT_ELTS}, + field::types::PrimeField64, + hash::hash_types::{HashOut, NUM_HASH_OUT_ELTS}, iop::target::{BoolTarget, Target}, - plonk::circuit_builder::CircuitBuilder, }; +use plonky2_crypto::u32::arithmetic_u32::U32Target; +use plonky2_ecdsa::gadgets::biguint::BigUintTarget; use plonky2_ecgfp5::{curve::curve::WeierstrassPoint, gadgets::curve::CurveTarget}; -use std::array::from_fn as create_array; - -// Contract extraction public Inputs: -// - `H : [4]F` : Poseidon hash of the leaf -// - `DR : Digest[F]` : accumulated digest of all the rows up to this node -// - `min : Uint256` : min value of the secondary index stored up to this node -// - `max : Uint256` : max value of the secondary index stored up to this node -// - `merge : bool` : Flag specifying whether we are building rows for a merge table or not -const H_RANGE: PublicInputRange = 0..NUM_HASH_OUT_ELTS; -const DR_RANGE: PublicInputRange = H_RANGE.end..H_RANGE.end + CURVE_TARGET_LEN; -const MIN_RANGE: PublicInputRange = DR_RANGE.end..DR_RANGE.end + u256::NUM_LIMBS; -const MAX_RANGE: PublicInputRange = MIN_RANGE.end..MIN_RANGE.end + u256::NUM_LIMBS; -const MERGE_RANGE: PublicInputRange = MAX_RANGE.end..MAX_RANGE.end + 1; - -/// Public inputs for contract extraction +use std::iter::once; + +pub enum RowsTreePublicInputs { + // `H : F[4]` - Poseidon hash of the leaf + RootHash, + // `individual_digest : Digest` - Cumulative digest of the values of the cells which are accumulated in individual digest + IndividualDigest, + // `multiplier_digest : Digest` - Cumulative digest of the values of the cells which are accumulated in multiplier digest + MultiplierDigest, + // `row_id_multiplier : F[4]` - `H2Int(H("") || multiplier_md)`, where `multiplier_md` is the metadata digest of cells accumulated in `multiplier_digest` + RowIdMultiplier, + // `min : Uint256` - Minimum alue of the secondary index stored up to this node + MinValue, + // `max : Uint256` - Maximum value of the secondary index stored up to this node + MaxValue, + // `merge : bool` - Flag specifying whether we are building rows for a merge table or not + MergeFlag, +} + +/// Public inputs for Rows Tree Construction #[derive(Clone, Debug)] pub struct PublicInputs<'a, T> { pub(crate) h: &'a [T], - pub(crate) dr: &'a [T], + pub(crate) individual_digest: &'a [T], + pub(crate) multiplier_digest: &'a [T], + pub(crate) row_id_multiplier: &'a [T], pub(crate) min: &'a [T], pub(crate) max: &'a [T], - pub(crate) merge: &'a [T], + pub(crate) merge: &'a T, } -impl<'a> PublicInputCommon for PublicInputs<'a, Target> { - const RANGES: &'static [PublicInputRange] = - &[H_RANGE, DR_RANGE, MIN_RANGE, MAX_RANGE, MERGE_RANGE]; +const NUM_PUBLIC_INPUTS: usize = RowsTreePublicInputs::MergeFlag as usize + 1; - fn register_args(&self, cb: &mut CircuitBuilder) { - cb.register_public_inputs(self.h); - cb.register_public_inputs(self.dr); - cb.register_public_inputs(self.min); - cb.register_public_inputs(self.max); - cb.register_public_input(self.merge[0]); - } -} +impl<'a, T: Clone> PublicInputs<'a, T> { + const PI_RANGES: [PublicInputRange; NUM_PUBLIC_INPUTS] = [ + Self::to_range(RowsTreePublicInputs::RootHash), + Self::to_range(RowsTreePublicInputs::IndividualDigest), + Self::to_range(RowsTreePublicInputs::MultiplierDigest), + Self::to_range(RowsTreePublicInputs::RowIdMultiplier), + Self::to_range(RowsTreePublicInputs::MinValue), + Self::to_range(RowsTreePublicInputs::MaxValue), + Self::to_range(RowsTreePublicInputs::MergeFlag), + ]; -// mostly used for testing -impl<'a> PublicInputs<'a, F> { - /// Get the metadata point. - pub fn rows_digest_field(&self) -> WeierstrassPoint { - WeierstrassPoint::from_fields(self.dr) - } - /// minimum index value - pub fn min_value_u256(&self) -> U256 { - U256::from_fields(self.min) - } - /// maximum index value - pub fn max_value_u256(&self) -> U256 { - U256::from_fields(self.max) - } - /// hash of the subtree at this node - pub fn root_hash_hashout(&self) -> HashOut { - HashOut { - elements: create_array(|i| self.h[i]), + const SIZES: [usize; NUM_PUBLIC_INPUTS] = [ + // Poseidon hash of the leaf + NUM_HASH_OUT_ELTS, + // Cumulative digest of the values of the cells which are accumulated in individual digest + CURVE_TARGET_LEN, + // Cumulative digest of the values of the cells which are accumulated in multiplier digest + CURVE_TARGET_LEN, + // `H2Int(H("") || multiplier_md)`, where `multiplier_md` is the metadata digest of cells accumulated in `multiplier_digest` + HASH_TO_INT_LEN, + // Minimum alue of the secondary index stored up to this node + u256::NUM_LIMBS, + // Maximum value of the secondary index stored up to this node + u256::NUM_LIMBS, + // Flag specifying whether we are building rows for a merge table or not + 1, + ]; + + pub(crate) const fn to_range(pi: RowsTreePublicInputs) -> PublicInputRange { + let mut i = 0; + let mut offset = 0; + let pi_pos = pi as usize; + while i < pi_pos { + offset += Self::SIZES[i]; + i += 1; } + offset..offset + Self::SIZES[pi_pos] } - pub fn is_merge_flag(&self) -> bool { - self.merge[0].try_into_bool().unwrap() + pub(crate) const fn total_len() -> usize { + Self::to_range(RowsTreePublicInputs::RowIdMultiplier).end } -} -impl<'a> PublicInputs<'a, Target> { - /// Get the hash corresponding to the root of the subtree of this node - pub fn root_hash(&self) -> HashOutTarget { - HashOutTarget::from_targets(self.h) + pub(crate) fn to_root_hash_raw(&self) -> &[T] { + self.h } - pub fn rows_digest(&self) -> CurveTarget { - let dv = self.dr; - CurveTarget::from_targets(dv) + pub(crate) fn to_individual_digest_raw(&self) -> &[T] { + self.individual_digest } - pub fn min_value(&self) -> UInt256Target { - UInt256Target::from_targets(self.min) + pub(crate) fn to_multiplier_digest_raw(&self) -> &[T] { + self.multiplier_digest } - pub fn max_value(&self) -> UInt256Target { - UInt256Target::from_targets(self.max) + + pub(crate) fn to_row_id_multiplier_raw(&self) -> &[T] { + self.row_id_multiplier } - pub fn is_merge_case(&self) -> BoolTarget { - BoolTarget::new_unsafe(self.merge[0]) + pub(crate) fn to_min_value_raw(&self) -> &[T] { + self.min } -} -pub const TOTAL_LEN: usize = PublicInputs::::TOTAL_LEN; + pub(crate) fn to_max_value_raw(&self) -> &[T] { + self.max + } -impl<'a, T: Copy> PublicInputs<'a, T> { - /// Total length of the public inputs - pub(crate) const TOTAL_LEN: usize = MERGE_RANGE.end; + pub(crate) fn to_merge_flag_raw(&self) -> &T { + self.merge + } - /// Create from a slice. - pub fn from_slice(pi: &'a [T]) -> Self { - assert!(pi.len() >= Self::TOTAL_LEN); + pub fn from_slice(input: &'a [T]) -> Self { + assert!( + input.len() >= Self::total_len(), + "Input slice too short to build rows tree public inputs, must be at least {} elements", + Self::total_len(), + ); Self { - h: &pi[H_RANGE], - dr: &pi[DR_RANGE], - min: &pi[MIN_RANGE], - max: &pi[MAX_RANGE], - merge: &pi[MERGE_RANGE], + h: &input[Self::PI_RANGES[0].clone()], + individual_digest: &input[Self::PI_RANGES[1].clone()], + multiplier_digest: &input[Self::PI_RANGES[2].clone()], + row_id_multiplier: &input[Self::PI_RANGES[3].clone()], + min: &input[Self::PI_RANGES[4].clone()], + max: &input[Self::PI_RANGES[5].clone()], + merge: &input[Self::PI_RANGES[6].clone()][0], } } - /// Create a new public inputs. - pub fn new(h: &'a [T], dr: &'a [T], min: &'a [T], max: &'a [T], merge: &'a [T]) -> Self { - assert_eq!(h.len(), NUM_HASH_OUT_ELTS); - assert_eq!(dr.len(), CURVE_TARGET_LEN); - assert_eq!(min.len(), u256::NUM_LIMBS); - assert_eq!(max.len(), u256::NUM_LIMBS); - assert_eq!(merge.len(), 1); + pub fn new( + h: &'a [T], + individual_digest: &'a [T], + multiplier_digest: &'a [T], + row_id_multiplier: &'a [T], + min: &'a [T], + max: &'a [T], + merge: &'a [T], + ) -> Self { Self { h, - dr, + individual_digest, + multiplier_digest, + row_id_multiplier, min, max, - merge, + merge: &merge[0], } } - /// Combine to a vector. pub fn to_vec(&self) -> Vec { self.h .iter() - .chain(self.dr) + .chain(self.individual_digest) + .chain(self.multiplier_digest) + .chain(self.row_id_multiplier) .chain(self.min) .chain(self.max) - .chain(self.merge) + .chain(once(self.merge)) .cloned() .collect() } } +impl<'a> PublicInputCommon for PublicInputs<'a, Target> { + const RANGES: &'static [PublicInputRange] = &Self::PI_RANGES; + + fn register_args(&self, cb: &mut CBuilder) { + cb.register_public_inputs(self.h); + cb.register_public_inputs(self.individual_digest); + cb.register_public_inputs(self.multiplier_digest); + cb.register_public_inputs(self.row_id_multiplier); + cb.register_public_inputs(self.min); + cb.register_public_inputs(self.max); + cb.register_public_input(*self.merge); + } +} + +impl<'a> PublicInputs<'a, Target> { + pub fn root_hash_target(&self) -> [Target; NUM_HASH_OUT_ELTS] { + self.to_root_hash_raw().try_into().unwrap() + } + + pub fn individual_digest_target(&self) -> CurveTarget { + CurveTarget::from_targets(self.individual_digest) + } + + pub fn multiplier_digest_target(&self) -> CurveTarget { + CurveTarget::from_targets(self.multiplier_digest) + } + + pub fn row_id_multiplier_target(&self) -> BigUintTarget { + let limbs = self + .row_id_multiplier + .iter() + .cloned() + .map(U32Target) + .collect(); + + BigUintTarget { limbs } + } + + pub fn min_value_target(&self) -> UInt256Target { + UInt256Target::from_targets(self.min) + } + + pub fn max_value_target(&self) -> UInt256Target { + UInt256Target::from_targets(self.max) + } + + pub fn merge_flag_target(&self) -> BoolTarget { + BoolTarget::new_unsafe(*self.merge) + } +} + +impl<'a> PublicInputs<'a, F> { + pub fn root_hash(&self) -> HashOut { + HashOut::from_partial(self.h) + } + + pub fn individual_digest_point(&self) -> WeierstrassPoint { + WeierstrassPoint::from_fields(self.individual_digest) + } + + pub fn multiplier_digest_point(&self) -> WeierstrassPoint { + WeierstrassPoint::from_fields(self.multiplier_digest) + } + + pub fn row_id_multiplier(&self) -> BigUint { + let limbs = self + .row_id_multiplier + .iter() + .map(|f| u32::try_from(f.to_canonical_u64()).unwrap()) + .collect_vec(); + + BigUint::from_slice(&limbs) + } + + pub fn min_value(&self) -> U256 { + U256::from_fields(self.min) + } + + pub fn max_value(&self) -> U256 { + U256::from_fields(self.max) + } + + pub fn merge_flag(&self) -> bool { + self.merge.try_into_bool().unwrap() + } +} + #[cfg(test)] mod tests { use super::*; - use alloy::primitives::U256; - use mp2_common::{public_inputs::PublicInputCommon, utils::ToFields, C, D, F}; - use mp2_test::circuit::{run_circuit, UserCircuit}; + use mp2_common::{utils::ToFields, C, D, F}; + use mp2_test::{ + circuit::{run_circuit, UserCircuit}, + utils::random_vector, + }; use plonky2::{ field::types::{Field, Sample}, iop::{ target::Target, witness::{PartialWitness, WitnessWrite}, }, - plonk::config::GenericHashOut, }; use plonky2_ecgfp5::curve::curve::Point; use rand::{thread_rng, Rng}; + use std::{array, slice}; #[derive(Clone, Debug)] - struct TestPICircuit<'a> { + struct TestPublicInputs<'a> { exp_pi: &'a [F], } - impl<'a> UserCircuit for TestPICircuit<'a> { + impl<'a> UserCircuit for TestPublicInputs<'a> { type Wires = Vec; - fn build(b: &mut CircuitBuilder) -> Self::Wires { - let pi = b.add_virtual_targets(PublicInputs::::TOTAL_LEN); - let pi = PublicInputs::from_slice(&pi); - pi.register(b); - pi.to_vec() + fn build(b: &mut CBuilder) -> Self::Wires { + let exp_pi = b.add_virtual_targets(PublicInputs::::total_len()); + PublicInputs::from_slice(&exp_pi).register(b); + + exp_pi } fn prove(&self, pw: &mut PartialWitness, wires: &Self::Wires) { @@ -187,21 +303,59 @@ mod tests { #[test] fn test_rows_tree_public_inputs() { - let mut rng = thread_rng(); + let rng = &mut thread_rng(); // Prepare the public inputs. - let h = HashOut::rand().to_vec(); - let dr = Point::sample(&mut rng); - let drw = dr.to_weierstrass().to_fields(); - let min = U256::from_limbs(rng.gen::<[u64; 4]>()).to_fields(); - let max = U256::from_limbs(rng.gen::<[u64; 4]>()).to_fields(); - let merge = [F::from_canonical_usize(rng.gen_bool(0.5) as usize)]; - let exp_pi = PublicInputs::new(&h, &drw, &min, &max, &merge); + let h = random_vector::(NUM_HASH_OUT_ELTS).to_fields(); + let [individual_digest, multiplier_digest] = + array::from_fn(|_| Point::sample(rng).to_weierstrass().to_fields()); + let row_id_multiplier = rng.gen::<[u32; 4]>().map(F::from_canonical_u32); + let [min, max] = array::from_fn(|_| U256::from_limbs(rng.gen()).to_fields()); + let merge = [F::from_bool(rng.gen_bool(0.5))]; + let exp_pi = PublicInputs::new( + &h, + &individual_digest, + &multiplier_digest, + &row_id_multiplier, + &min, + &max, + &merge, + ); let exp_pi = &exp_pi.to_vec(); - assert_eq!(exp_pi.len(), PublicInputs::::TOTAL_LEN); - let test_circuit = TestPICircuit { exp_pi }; - let proof = run_circuit::(test_circuit); + let test_circuit = TestPublicInputs { exp_pi }; + let proof = run_circuit::(test_circuit); assert_eq!(&proof.public_inputs, exp_pi); + + // Check if the public inputs are constructed correctly. + let pi = PublicInputs::from_slice(&proof.public_inputs); + assert_eq!( + &exp_pi[PublicInputs::::to_range(RowsTreePublicInputs::RootHash)], + pi.to_root_hash_raw(), + ); + assert_eq!( + &exp_pi[PublicInputs::::to_range(RowsTreePublicInputs::IndividualDigest)], + pi.to_individual_digest_raw(), + ); + assert_eq!( + &exp_pi[PublicInputs::::to_range(RowsTreePublicInputs::MultiplierDigest)], + pi.to_multiplier_digest_raw(), + ); + assert_eq!( + &exp_pi[PublicInputs::::to_range(RowsTreePublicInputs::RowIdMultiplier)], + pi.to_row_id_multiplier_raw(), + ); + assert_eq!( + &exp_pi[PublicInputs::::to_range(RowsTreePublicInputs::MinValue)], + pi.to_min_value_raw(), + ); + assert_eq!( + &exp_pi[PublicInputs::::to_range(RowsTreePublicInputs::MaxValue)], + pi.to_max_value_raw(), + ); + assert_eq!( + &exp_pi[PublicInputs::::to_range(RowsTreePublicInputs::MergeFlag)], + slice::from_ref(pi.to_merge_flag_raw()), + ); } } diff --git a/verifiable-db/src/row_tree/row.rs b/verifiable-db/src/row_tree/row.rs new file mode 100644 index 000000000..df882e5a6 --- /dev/null +++ b/verifiable-db/src/row_tree/row.rs @@ -0,0 +1,122 @@ +//! Row information for the rows tree + +use crate::cells_tree::{Cell, CellWire, PublicInputs}; +use derive_more::Constructor; +use mp2_common::{ + poseidon::{empty_poseidon_hash, hash_to_int_target, H, HASH_TO_INT_LEN}, + serialization::{deserialize, serialize}, + types::CBuilder, + u256::UInt256Target, + utils::ToTargets, + F, +}; +use plonky2::{ + hash::hash_types::{HashOut, HashOutTarget}, + iop::{ + target::{BoolTarget, Target}, + witness::{PartialWitness, WitnessWrite}, + }, +}; +use plonky2_ecdsa::gadgets::{biguint::BigUintTarget, nonnative::CircuitBuilderNonNative}; +use plonky2_ecgfp5::gadgets::curve::{CircuitBuilderEcGFp5, CurveTarget}; +use serde::{Deserialize, Serialize}; + +#[derive(Clone, Debug, Serialize, Deserialize, Constructor)] +pub(crate) struct Row { + pub(crate) cell: Cell, + pub(crate) row_unique_data: HashOut, +} + +impl Row { + pub(crate) fn assign_wires(&self, pw: &mut PartialWitness, wires: &RowWire) { + self.cell.assign_wires(pw, &wires.cell); + pw.set_hash_target(wires.row_unique_data, self.row_unique_data); + } +} + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub(crate) struct RowWire { + pub(crate) cell: CellWire, + #[serde(serialize_with = "serialize", deserialize_with = "deserialize")] + pub(crate) row_unique_data: HashOutTarget, +} + +/// Row digest result +#[derive(Clone, Debug)] +pub(crate) struct RowDigest { + pub(crate) is_merge: BoolTarget, + pub(crate) row_id_multiplier: BigUintTarget, + pub(crate) individual_vd: CurveTarget, + pub(crate) multiplier_vd: CurveTarget, +} + +impl RowWire { + pub(crate) fn new(b: &mut CBuilder) -> Self { + Self { + cell: CellWire::new(b), + row_unique_data: b.add_virtual_hash(), + } + } + + pub(crate) fn identifier(&self) -> Target { + self.cell.identifier + } + + pub(crate) fn value(&self) -> &UInt256Target { + &self.cell.value + } + + pub(crate) fn digest(&self, b: &mut CBuilder, cells_pi: &PublicInputs) -> RowDigest { + let metadata_digests = self.cell.split_metadata_digest(b); + let values_digests = self.cell.split_values_digest(b); + + let metadata_digests = + metadata_digests.accumulate(b, &cells_pi.split_metadata_digest_target()); + let values_digests = values_digests.accumulate(b, &cells_pi.split_values_digest_target()); + + // Compute row ID for individual cells: + // row_id_individual = H2Int(row_unique_data || individual_md) + let inputs = self + .row_unique_data + .to_targets() + .into_iter() + .chain(metadata_digests.individual.to_targets()) + .collect(); + let hash = b.hash_n_to_hash_no_pad::(inputs); + let row_id_individual = hash_to_int_target(b, hash); + let row_id_individual = b.biguint_to_nonnative(&row_id_individual); + + // Multiply row ID to individual value digest: + // individual_vd = row_id_individual * individual_vd + let individual_vd = b.curve_scalar_mul(values_digests.individual, &row_id_individual); + + // Multiplier is always employed for set of scalar variables, and `row_unique_data` + // for such a set is always `H("")``, so we can hardocode it in the circuit: + // row_id_multiplier = H2Int(H("") || multiplier_md) + let empty_hash = b.constant_hash(*empty_poseidon_hash()); + let inputs = empty_hash + .to_targets() + .into_iter() + .chain(metadata_digests.multiplier.to_targets()) + .collect(); + let hash = b.hash_n_to_hash_no_pad::(inputs); + let row_id_multiplier = hash_to_int_target(b, hash); + assert_eq!(row_id_multiplier.num_limbs(), HASH_TO_INT_LEN); + + let is_merge = values_digests.is_merge_case(b); + let multiplier_vd = values_digests.multiplier; + + RowDigest { + is_merge, + row_id_multiplier, + individual_vd, + multiplier_vd, + } + } +} + +/* +#[cfg(test)] +mod test { +} +*/ From 63b3b2dbb6e8226a54c05a1b8ddc9ac8f37f93d1 Mon Sep 17 00:00:00 2001 From: Steven Gu Date: Tue, 22 Oct 2024 17:18:19 +0800 Subject: [PATCH 02/16] Update tests for cells tree. --- verifiable-db/src/cells_tree/api.rs | 201 ++++++++++++------ verifiable-db/src/cells_tree/empty_node.rs | 26 ++- verifiable-db/src/cells_tree/full_node.rs | 122 +++++------ verifiable-db/src/cells_tree/leaf.rs | 70 +++--- verifiable-db/src/cells_tree/mod.rs | 158 +++++++++++++- verifiable-db/src/cells_tree/partial_node.rs | 107 +++++----- verifiable-db/src/cells_tree/public_inputs.rs | 69 ++++-- verifiable-db/src/row_tree/public_inputs.rs | 16 +- verifiable-db/src/row_tree/row.rs | 3 +- 9 files changed, 502 insertions(+), 270 deletions(-) diff --git a/verifiable-db/src/cells_tree/api.rs b/verifiable-db/src/cells_tree/api.rs index 8b7a84740..5353ad197 100644 --- a/verifiable-db/src/cells_tree/api.rs +++ b/verifiable-db/src/cells_tree/api.rs @@ -274,20 +274,18 @@ pub fn extract_hash_from_proof(proof: &[u8]) -> Result> { Ok(PublicInputs::from_slice(&p.proof.public_inputs).node_hash()) } -/* #[cfg(test)] mod tests { use super::*; + use itertools::Itertools; use mp2_common::{ - group_hashing::{add_curve_point, map_to_curve_point}, poseidon::{empty_poseidon_hash, H}, - utils::{Fieldable, ToFields}, + utils::ToFields, }; use plonky2::{field::types::PrimeField64, plonk::config::Hasher}; - use plonky2_ecgfp5::curve::curve::{Point, WeierstrassPoint}; - use rand::{thread_rng, Rng}; + use plonky2_ecgfp5::curve::curve::WeierstrassPoint; use serial_test::serial; - use std::iter; + use std::iter::once; #[test] #[serial] @@ -310,11 +308,13 @@ mod tests { fn generate_leaf_proof(params: &PublicParameters) -> Vec { // Build the circuit input. - let mut rng = thread_rng(); - let identifier: F = rng.gen::().to_field(); - let value = U256::from_limbs(rng.gen::<[u64; 4]>()); - let value_fields = value.to_fields(); - let input = CircuitInput::leaf(identifier.to_canonical_u64(), value); + let cell = Cell::sample(false); + let id = cell.identifier; + let value = cell.value; + let mpt_metadata = cell.mpt_metadata; + let values_digests = cell.split_values_digest(); + let metadata_digests = cell.split_metadata_digest(); + let input = CircuitInput::leaf(id.to_canonical_u64(), value, mpt_metadata); // Generate proof. let proof = params.generate_proof(input).unwrap(); @@ -325,27 +325,42 @@ mod tests { .proof .public_inputs; let pi = PublicInputs::from_slice(&pi); + // Check the node hash { let empty_hash = empty_poseidon_hash(); - let inputs: Vec<_> = empty_hash + let inputs = empty_hash .elements .iter() .cloned() .chain(empty_hash.elements) - .chain(iter::once(identifier)) - .chain(value_fields.clone()) - .collect(); + .chain(once(id)) + .chain(value.to_fields()) + .collect_vec(); // TODO: Fix to employ the same hash method in the ryhope tree library. let exp_hash = H::hash_no_pad(&inputs); assert_eq!(pi.h, exp_hash.elements); } - { - let inputs: Vec<_> = iter::once(identifier).chain(value_fields).collect(); - let exp_digest = map_to_curve_point(&inputs).to_weierstrass(); - - assert_eq!(pi.individual_digest_point(), exp_digest); - } + // Check individual values digest + assert_eq!( + pi.individual_values_digest_point(), + values_digests.individual.to_weierstrass(), + ); + // Check multiplier values digest + assert_eq!( + pi.multiplier_values_digest_point(), + values_digests.multiplier.to_weierstrass(), + ); + // Check individual metadata digest + assert_eq!( + pi.individual_metadata_digest_point(), + metadata_digests.individual.to_weierstrass(), + ); + // Check multiplier metadata digest + assert_eq!( + pi.multiplier_metadata_digest_point(), + metadata_digests.multiplier.to_weierstrass(), + ); proof } @@ -364,9 +379,26 @@ mod tests { let empty_hash = empty_poseidon_hash(); assert_eq!(pi.h, empty_hash.elements); } - { - assert_eq!(pi.individual_digest_point(), WeierstrassPoint::NEUTRAL); - } + // Check individual values digest + assert_eq!( + pi.individual_values_digest_point(), + WeierstrassPoint::NEUTRAL + ); + // Check multiplier values digest + assert_eq!( + pi.multiplier_values_digest_point(), + WeierstrassPoint::NEUTRAL + ); + // Check individual metadata digest + assert_eq!( + pi.individual_metadata_digest_point(), + WeierstrassPoint::NEUTRAL + ); + // Check multiplier metadata digest + assert_eq!( + pi.multiplier_metadata_digest_point(), + WeierstrassPoint::NEUTRAL + ); proof } @@ -383,11 +415,13 @@ mod tests { .collect(); // Build the circuit input. - let mut rng = thread_rng(); - let identifier: F = rng.gen::().to_field(); - let value = U256::from_limbs(rng.gen::<[u64; 4]>()); - let packed_value = value.to_fields(); - let input = CircuitInput::full(identifier.to_canonical_u64(), value, child_proofs); + let cell = Cell::sample(false); + let id = cell.identifier; + let value = cell.value; + let mpt_metadata = cell.mpt_metadata; + let values_digests = cell.split_values_digest(); + let metadata_digests = cell.split_metadata_digest(); + let input = CircuitInput::full(id.to_canonical_u64(), value, mpt_metadata, child_proofs); // Generate proof. let proof = params.generate_proof(input).unwrap(); @@ -398,32 +432,49 @@ mod tests { .proof .public_inputs; let pi = PublicInputs::from_slice(&pi); + + let values_digests = child_pis.iter().fold(values_digests, |acc, pi| { + acc.accumulate(&pi.split_values_digest_point()) + }); + let metadata_digests = child_pis.iter().fold(metadata_digests, |acc, pi| { + acc.accumulate(&pi.split_metadata_digest_point()) + }); + + // Check the node hash { - let inputs: Vec<_> = child_pis[0] - .h_raw() + let inputs = child_pis[0] + .to_node_hash_raw() .iter() - .chain(child_pis[1].h_raw()) + .chain(child_pis[1].to_node_hash_raw()) .cloned() - .chain(iter::once(identifier)) - .chain(packed_value.clone()) - .collect(); + .chain(once(id)) + .chain(value.to_fields()) + .collect_vec(); // TODO: Fix to employ the same hash method in the ryhope tree library. let exp_hash = H::hash_no_pad(&inputs); assert_eq!(pi.h, exp_hash.elements); } - { - let child_digests: Vec<_> = child_pis - .iter() - .map(|pi| Point::decode(pi.individual_digest_point().encode()).unwrap()) - .collect(); - let inputs: Vec<_> = iter::once(identifier).chain(packed_value).collect(); - let exp_digest = map_to_curve_point(&inputs); - let exp_digest = - add_curve_point(&[exp_digest, child_digests[0], child_digests[1]]).to_weierstrass(); - - assert_eq!(pi.individual_digest_point(), exp_digest); - } + // Check individual values digest + assert_eq!( + pi.individual_values_digest_point(), + values_digests.individual.to_weierstrass(), + ); + // Check multiplier values digest + assert_eq!( + pi.multiplier_values_digest_point(), + values_digests.multiplier.to_weierstrass(), + ); + // Check individual metadata digest + assert_eq!( + pi.individual_metadata_digest_point(), + metadata_digests.individual.to_weierstrass(), + ); + // Check multiplier metadata digest + assert_eq!( + pi.multiplier_metadata_digest_point(), + metadata_digests.multiplier.to_weierstrass(), + ); proof } @@ -437,11 +488,13 @@ mod tests { let child_pi = PublicInputs::from_slice(&child_pi); // Build the circuit input. - let mut rng = thread_rng(); - let identifier: F = rng.gen::().to_field(); - let value = U256::from_limbs(rng.gen::<[u64; 4]>()); - let packed_value = value.to_fields(); - let input = CircuitInput::partial(identifier.to_canonical_u64(), value, child_proof); + let cell = Cell::sample(false); + let id = cell.identifier; + let value = cell.value; + let mpt_metadata = cell.mpt_metadata; + let values_digests = cell.split_values_digest(); + let metadata_digests = cell.split_metadata_digest(); + let input = CircuitInput::partial(id.to_canonical_u64(), value, mpt_metadata, child_proof); // Generate proof. let proof = params.generate_proof(input).unwrap(); @@ -452,31 +505,47 @@ mod tests { .proof .public_inputs; let pi = PublicInputs::from_slice(&pi); + + let values_digests = values_digests.accumulate(&child_pi.split_values_digest_point()); + let metadata_digests = metadata_digests.accumulate(&child_pi.split_metadata_digest_point()); + + // Check the node hash { let empty_hash = empty_poseidon_hash(); - let inputs: Vec<_> = child_pi - .h_raw() + let inputs = child_pi + .to_node_hash_raw() .iter() .cloned() .chain(empty_hash.elements) - .chain(iter::once(identifier)) - .chain(packed_value.clone()) - .collect(); + .chain(once(id)) + .chain(value.to_fields()) + .collect_vec(); // TODO: Fix to employ the same hash method in the ryhope tree library. let exp_hash = H::hash_no_pad(&inputs); assert_eq!(pi.h, exp_hash.elements); } - { - let child_digest = Point::decode(child_pi.individual_digest_point().encode()).unwrap(); - let inputs: Vec<_> = iter::once(identifier).chain(packed_value).collect(); - let exp_digest = map_to_curve_point(&inputs); - let exp_digest = add_curve_point(&[exp_digest, child_digest]).to_weierstrass(); - - assert_eq!(pi.individual_digest_point(), exp_digest); - } + // Check individual values digest + assert_eq!( + pi.individual_values_digest_point(), + values_digests.individual.to_weierstrass(), + ); + // Check multiplier values digest + assert_eq!( + pi.multiplier_values_digest_point(), + values_digests.multiplier.to_weierstrass(), + ); + // Check individual metadata digest + assert_eq!( + pi.individual_metadata_digest_point(), + metadata_digests.individual.to_weierstrass(), + ); + // Check multiplier metadata digest + assert_eq!( + pi.multiplier_metadata_digest_point(), + metadata_digests.multiplier.to_weierstrass(), + ); proof } } -*/ diff --git a/verifiable-db/src/cells_tree/empty_node.rs b/verifiable-db/src/cells_tree/empty_node.rs index b86013f3f..f1f936ddb 100644 --- a/verifiable-db/src/cells_tree/empty_node.rs +++ b/verifiable-db/src/cells_tree/empty_node.rs @@ -54,7 +54,6 @@ impl CircuitLogicWires for EmptyNodeWires { } } -/* #[cfg(test)] mod tests { use super::*; @@ -82,10 +81,25 @@ mod tests { let empty_hash = empty_poseidon_hash(); assert_eq!(pi.h, empty_hash.elements); } - // Check the cells digest - { - assert_eq!(pi.individual_digest_point(), WeierstrassPoint::NEUTRAL); - } + // Check individual values digest + assert_eq!( + pi.individual_values_digest_point(), + WeierstrassPoint::NEUTRAL + ); + // Check multiplier values digest + assert_eq!( + pi.multiplier_values_digest_point(), + WeierstrassPoint::NEUTRAL + ); + // Check individual metadata digest + assert_eq!( + pi.individual_metadata_digest_point(), + WeierstrassPoint::NEUTRAL + ); + // Check multiplier metadata digest + assert_eq!( + pi.multiplier_metadata_digest_point(), + WeierstrassPoint::NEUTRAL + ); } } -*/ diff --git a/verifiable-db/src/cells_tree/full_node.rs b/verifiable-db/src/cells_tree/full_node.rs index 983ec071f..a35b29408 100644 --- a/verifiable-db/src/cells_tree/full_node.rs +++ b/verifiable-db/src/cells_tree/full_node.rs @@ -87,27 +87,13 @@ impl CircuitLogicWires for FullNodeWires { } } -/* #[cfg(test)] mod tests { use super::*; - use alloy::primitives::U256; - use mp2_common::{ - group_hashing::{add_curve_point, map_to_curve_point}, - poseidon::H, - utils::{Fieldable, ToFields}, - C, - }; - use mp2_test::{ - circuit::{run_circuit, UserCircuit}, - utils::random_vector, - }; - use plonky2::{ - field::types::Sample, hash::hash_types::NUM_HASH_OUT_ELTS, iop::witness::WitnessWrite, - plonk::config::Hasher, - }; - use plonky2_ecgfp5::curve::curve::Point; - use rand::{thread_rng, Rng}; + use itertools::Itertools; + use mp2_common::{poseidon::H, utils::ToFields, C}; + use mp2_test::circuit::{run_circuit, UserCircuit}; + use plonky2::{iop::witness::WitnessWrite, plonk::config::Hasher}; #[derive(Clone, Debug)] struct TestFullNodeCircuit<'a> { @@ -121,7 +107,7 @@ mod tests { fn build(b: &mut CBuilder) -> Self::Wires { let child_pis = - [0; 2].map(|_| b.add_virtual_targets(PublicInputs::::TOTAL_LEN)); + [0; 2].map(|_| b.add_virtual_targets(PublicInputs::::total_len())); let wires = FullNodeCircuit::build( b, @@ -143,61 +129,71 @@ mod tests { #[test] fn test_cells_tree_full_node_circuit() { - let mut rng = thread_rng(); - - let identifier = rng.gen::().to_field(); - let value = U256::from_limbs(rng.gen::<[u64; 4]>()); - let value_fields = value.to_fields(); - - // Create the child public inputs. - let child_hashs = [0; 2].map(|_| random_vector::(NUM_HASH_OUT_ELTS).to_fields()); - let child_digests = [0; 2].map(|_| Point::sample(&mut rng)); - let child_pis = &array::from_fn(|i| { - let h = &child_hashs[i]; - let ind = &child_digests[i].to_weierstrass().to_fields(); - let neutral = Point::NEUTRAL.to_fields(); - - PublicInputs { - h, - ind, - mul: &neutral, - } - .to_vec() - }); + test_cells_tree_full_multiplier(true); + test_cells_tree_full_multiplier(false); + } + + fn test_cells_tree_full_multiplier(is_multiplier: bool) { + let cell = Cell::sample(is_multiplier); + let id = cell.identifier; + let value = cell.value; + let values_digests = cell.split_values_digest(); + let metadata_digests = cell.split_metadata_digest(); + + let child_pis = &array::from_fn(|_| PublicInputs::::sample(is_multiplier)); let test_circuit = TestFullNodeCircuit { - c: Cell { - identifier, - value, - is_multiplier: false, - } - .into(), + c: cell.into(), child_pis, }; let proof = run_circuit::(test_circuit); let pi = PublicInputs::from_slice(&proof.public_inputs); - // Check the node Poseidon hash + + let child_pis = child_pis + .iter() + .map(|pi| PublicInputs::from_slice(pi)) + .collect_vec(); + + let values_digests = child_pis.iter().fold(values_digests, |acc, pi| { + acc.accumulate(&pi.split_values_digest_point()) + }); + let metadata_digests = child_pis.iter().fold(metadata_digests, |acc, pi| { + acc.accumulate(&pi.split_metadata_digest_point()) + }); + + // Check the node hash { - let inputs: Vec<_> = child_hashs[0] - .clone() + let inputs = child_pis[0] + .node_hash() + .to_fields() .into_iter() - .chain(child_hashs[1].clone()) - .chain(iter::once(identifier)) - .chain(value_fields.clone()) - .collect(); + .chain(child_pis[1].node_hash().to_fields()) + .chain(once(id)) + .chain(value.to_fields()) + .collect_vec(); let exp_hash = H::hash_no_pad(&inputs); assert_eq!(pi.h, exp_hash.elements); } - // Check the cells digest - { - let inputs: Vec<_> = iter::once(identifier).chain(value_fields).collect(); - let exp_digest = map_to_curve_point(&inputs); - let exp_digest = - add_curve_point(&[exp_digest, child_digests[0], child_digests[1]]).to_weierstrass(); - - assert_eq!(pi.individual_digest_point(), exp_digest); - } + // Check individual values digest + assert_eq!( + pi.individual_values_digest_point(), + values_digests.individual.to_weierstrass(), + ); + // Check multiplier values digest + assert_eq!( + pi.multiplier_values_digest_point(), + values_digests.multiplier.to_weierstrass(), + ); + // Check individual metadata digest + assert_eq!( + pi.individual_metadata_digest_point(), + metadata_digests.individual.to_weierstrass(), + ); + // Check multiplier metadata digest + assert_eq!( + pi.multiplier_metadata_digest_point(), + metadata_digests.multiplier.to_weierstrass(), + ); } } -*/ diff --git a/verifiable-db/src/cells_tree/leaf.rs b/verifiable-db/src/cells_tree/leaf.rs index 180643fc1..908dcfc41 100644 --- a/verifiable-db/src/cells_tree/leaf.rs +++ b/verifiable-db/src/cells_tree/leaf.rs @@ -81,20 +81,13 @@ impl CircuitLogicWires for LeafWires { } } -/* #[cfg(test)] mod tests { use super::*; - use alloy::primitives::U256; - use mp2_common::{ - group_hashing::map_to_curve_point, - poseidon::H, - utils::{Fieldable, ToFields}, - C, - }; + use itertools::Itertools; + use mp2_common::{poseidon::H, utils::ToFields, C}; use mp2_test::circuit::{run_circuit, UserCircuit}; use plonky2::plonk::config::Hasher; - use rand::{thread_rng, Rng}; impl UserCircuit for LeafCircuit { type Wires = LeafWires; @@ -115,45 +108,50 @@ mod tests { } fn test_cells_tree_leaf_multiplier(is_multiplier: bool) { - let mut rng = thread_rng(); - - let identifier = rng.gen::().to_field(); - let value = U256::from_limbs(rng.gen::<[u64; 4]>()); - let value_fields = value.to_fields(); - - let test_circuit: LeafCircuit = Cell { - identifier, - value, - is_multiplier, - } - .into(); + let cell = Cell::sample(is_multiplier); + let id = cell.identifier; + let value = cell.value; + let values_digests = cell.split_values_digest(); + let metadata_digests = cell.split_metadata_digest(); + let test_circuit: LeafCircuit = cell.into(); let proof = run_circuit::(test_circuit); let pi = PublicInputs::from_slice(&proof.public_inputs); - // Check the node Poseidon hash + + // Check the node hash { let empty_hash = empty_poseidon_hash(); - let inputs: Vec<_> = empty_hash + let inputs = empty_hash .elements .iter() .cloned() .chain(empty_hash.elements) - .chain(iter::once(identifier)) - .chain(value_fields.clone()) - .collect(); + .chain(once(id)) + .chain(value.to_fields()) + .collect_vec(); let exp_hash = H::hash_no_pad(&inputs); assert_eq!(pi.h, exp_hash.elements); } - // Check the cells digest - { - let inputs: Vec<_> = iter::once(identifier).chain(value_fields).collect(); - let exp_digest = map_to_curve_point(&inputs).to_weierstrass(); - match is_multiplier { - true => assert_eq!(pi.multiplier_digest_point(), exp_digest), - false => assert_eq!(pi.individual_digest_point(), exp_digest), - } - } + // Check individual values digest + assert_eq!( + pi.individual_values_digest_point(), + values_digests.individual.to_weierstrass(), + ); + // Check multiplier values digest + assert_eq!( + pi.multiplier_values_digest_point(), + values_digests.multiplier.to_weierstrass(), + ); + // Check individual metadata digest + assert_eq!( + pi.individual_metadata_digest_point(), + metadata_digests.individual.to_weierstrass(), + ); + // Check multiplier metadata digest + assert_eq!( + pi.multiplier_metadata_digest_point(), + metadata_digests.multiplier.to_weierstrass(), + ); } } -*/ diff --git a/verifiable-db/src/cells_tree/mod.rs b/verifiable-db/src/cells_tree/mod.rs index a5ba1d0dc..6c05a41d3 100644 --- a/verifiable-db/src/cells_tree/mod.rs +++ b/verifiable-db/src/cells_tree/mod.rs @@ -34,7 +34,7 @@ pub use public_inputs::PublicInputs; /// A cell represents a column || value tuple. it can be given in the cells tree or as the /// secondary index value in the row tree. #[derive(Clone, Debug, Serialize, Deserialize, Constructor)] -pub(crate) struct Cell { +pub struct Cell { /// identifier of the column for the secondary index pub(crate) identifier: F, /// secondary index value @@ -52,22 +52,22 @@ impl Cell { pw.set_bool_target(wires.is_multiplier, self.is_multiplier); pw.set_hash_target(wires.mpt_metadata, self.mpt_metadata); } - pub(crate) fn split_metadata_digest(&self) -> SplitDigestPoint { + pub fn split_metadata_digest(&self) -> SplitDigestPoint { let digest = self.metadata_digest(); SplitDigestPoint::from_single_digest_point(digest, self.is_multiplier) } - pub(crate) fn split_values_digest(&self) -> SplitDigestPoint { + pub fn split_values_digest(&self) -> SplitDigestPoint { let digest = self.values_digest(); SplitDigestPoint::from_single_digest_point(digest, self.is_multiplier) } - pub(crate) fn split_and_accumulate_metadata_digest( + pub fn split_and_accumulate_metadata_digest( &self, child_digest: SplitDigestPoint, ) -> SplitDigestPoint { let split_digest = self.split_metadata_digest(); split_digest.accumulate(&child_digest) } - pub(crate) fn split_and_accumulate_values_digest( + pub fn split_and_accumulate_values_digest( &self, child_digest: SplitDigestPoint, ) -> SplitDigestPoint { @@ -97,7 +97,7 @@ impl Cell { /// The basic wires generated for each circuit of the row tree #[derive(Clone, Debug, Serialize, Deserialize)] -pub(crate) struct CellWire { +pub struct CellWire { pub(crate) value: UInt256Target, pub(crate) identifier: Target, #[serde(serialize_with = "serialize", deserialize_with = "deserialize")] @@ -107,7 +107,7 @@ pub(crate) struct CellWire { } impl CellWire { - pub(crate) fn new(b: &mut CBuilder) -> Self { + pub fn new(b: &mut CBuilder) -> Self { Self { value: b.add_virtual_u256(), identifier: b.add_virtual_target(), @@ -115,15 +115,15 @@ impl CellWire { mpt_metadata: b.add_virtual_hash(), } } - pub(crate) fn split_metadata_digest(&self, b: &mut CBuilder) -> SplitDigestTarget { + pub fn split_metadata_digest(&self, b: &mut CBuilder) -> SplitDigestTarget { let digest = self.metadata_digest(b); SplitDigestTarget::from_single_digest_target(b, digest, self.is_multiplier) } - pub(crate) fn split_values_digest(&self, b: &mut CBuilder) -> SplitDigestTarget { + pub fn split_values_digest(&self, b: &mut CBuilder) -> SplitDigestTarget { let digest = self.values_digest(b); SplitDigestTarget::from_single_digest_target(b, digest, self.is_multiplier) } - pub(crate) fn split_and_accumulate_metadata_digest( + pub fn split_and_accumulate_metadata_digest( &self, b: &mut CBuilder, child_digest: SplitDigestTarget, @@ -131,7 +131,7 @@ impl CellWire { let split_digest = self.split_metadata_digest(b); split_digest.accumulate(b, &child_digest) } - pub(crate) fn split_and_accumulate_values_digest( + pub fn split_and_accumulate_values_digest( &self, b: &mut CBuilder, child_digest: SplitDigestTarget, @@ -159,3 +159,139 @@ impl CellWire { b.map_to_curve_point(&inputs) } } + +#[cfg(test)] +pub(crate) mod tests { + use super::*; + use mp2_common::{ + types::CURVE_TARGET_LEN, + utils::{Fieldable, FromFields, ToFields}, + C, D, F, + }; + use mp2_test::{ + circuit::{run_circuit, UserCircuit}, + utils::random_vector, + }; + use plonky2::{field::types::Sample, hash::hash_types::NUM_HASH_OUT_ELTS}; + use plonky2_ecgfp5::{ + curve::curve::Point, + gadgets::curve::{CircuitBuilderEcGFp5, PartialWitnessCurve}, + }; + use rand::{thread_rng, Rng}; + use std::array; + + impl Cell { + pub(crate) fn sample(is_multiplier: bool) -> Self { + let rng = &mut thread_rng(); + + let identifier = rng.gen::().to_field(); + let value = U256::from_limbs(rng.gen()); + let mpt_metadata = + HashOut::from_vec(random_vector::(NUM_HASH_OUT_ELTS).to_fields()); + + Cell::new(identifier, value, is_multiplier, mpt_metadata) + } + } + + #[derive(Clone, Debug)] + struct TestCellCircuit<'a> { + cell: &'a Cell, + child_values_digest: &'a SplitDigestPoint, + child_metadata_digest: &'a SplitDigestPoint, + } + + impl<'a> UserCircuit for TestCellCircuit<'a> { + // Cell wires + child values digest + child metadata digest + type Wires = (CellWire, SplitDigestTarget, SplitDigestTarget); + + fn build(b: &mut CBuilder) -> Self::Wires { + let [values_individual, values_multiplier, metadata_individual, metadata_multiplier] = + array::from_fn(|_| b.add_virtual_curve_target()); + + let child_values_digest = SplitDigestTarget { + individual: values_individual, + multiplier: values_multiplier, + }; + let child_metadata_digest = SplitDigestTarget { + individual: metadata_individual, + multiplier: metadata_multiplier, + }; + + let cell = CellWire::new(b); + let values_digest = cell.split_values_digest(b); + let metadata_digest = cell.split_metadata_digest(b); + + let values_digest = values_digest.accumulate(b, &child_values_digest); + let metadata_digest = metadata_digest.accumulate(b, &child_metadata_digest); + + b.register_curve_public_input(values_digest.individual); + b.register_curve_public_input(values_digest.multiplier); + b.register_curve_public_input(metadata_digest.individual); + b.register_curve_public_input(metadata_digest.multiplier); + + (cell, child_values_digest, child_metadata_digest) + } + + fn prove(&self, pw: &mut PartialWitness, wires: &Self::Wires) { + self.cell.assign_wires(pw, &wires.0); + pw.set_curve_target( + wires.1.individual, + self.child_values_digest.individual.to_weierstrass(), + ); + pw.set_curve_target( + wires.1.multiplier, + self.child_values_digest.multiplier.to_weierstrass(), + ); + pw.set_curve_target( + wires.2.individual, + self.child_metadata_digest.individual.to_weierstrass(), + ); + pw.set_curve_target( + wires.2.multiplier, + self.child_metadata_digest.multiplier.to_weierstrass(), + ); + } + } + + #[test] + fn test_cells_tree_cell_circuit() { + let rng = &mut thread_rng(); + + let [values_individual, values_multiplier, metadata_individual, metadata_multiplier] = + array::from_fn(|_| Point::sample(rng)); + let child_values_digest = &SplitDigestPoint { + individual: values_individual, + multiplier: values_multiplier, + }; + let child_metadata_digest = &SplitDigestPoint { + individual: metadata_individual, + multiplier: metadata_multiplier, + }; + + let cell = &Cell::sample(rng.gen()); + let values_digests = cell.split_values_digest(); + let metadata_digests = cell.split_metadata_digest(); + let exp_values_digests = values_digests.accumulate(child_values_digest); + let exp_metadata_digests = metadata_digests.accumulate(child_metadata_digest); + + let test_circuit = TestCellCircuit { + cell, + child_values_digest, + child_metadata_digest, + }; + + let proof = run_circuit::(test_circuit); + + let [values_individual, values_multiplier, metadata_individual, metadata_multiplier] = + array::from_fn(|i| { + Point::from_fields( + &proof.public_inputs[i * CURVE_TARGET_LEN..(i + 1) * CURVE_TARGET_LEN], + ) + }); + + assert_eq!(values_individual, exp_values_digests.individual); + assert_eq!(values_multiplier, exp_values_digests.multiplier); + assert_eq!(metadata_individual, exp_metadata_digests.individual); + assert_eq!(metadata_multiplier, exp_metadata_digests.multiplier); + } +} diff --git a/verifiable-db/src/cells_tree/partial_node.rs b/verifiable-db/src/cells_tree/partial_node.rs index 2e5363bc3..f7b8025f2 100644 --- a/verifiable-db/src/cells_tree/partial_node.rs +++ b/verifiable-db/src/cells_tree/partial_node.rs @@ -92,26 +92,13 @@ impl CircuitLogicWires for PartialNodeWires { } } -/* #[cfg(test)] mod tests { use super::*; - use mp2_common::{ - group_hashing::{add_curve_point, map_to_curve_point}, - poseidon::H, - utils::{Fieldable, ToFields}, - C, - }; - use mp2_test::{ - circuit::{run_circuit, UserCircuit}, - utils::random_vector, - }; - use plonky2::{ - field::types::Sample, hash::hash_types::NUM_HASH_OUT_ELTS, iop::witness::WitnessWrite, - plonk::config::Hasher, - }; - use plonky2_ecgfp5::curve::curve::Point; - use rand::{thread_rng, Rng}; + use itertools::Itertools; + use mp2_common::{poseidon::H, utils::ToFields, C}; + use mp2_test::circuit::{run_circuit, UserCircuit}; + use plonky2::{iop::witness::WitnessWrite, plonk::config::Hasher}; #[derive(Clone, Debug)] struct TestPartialNodeCircuit<'a> { @@ -124,8 +111,7 @@ mod tests { type Wires = (PartialNodeWires, Vec); fn build(b: &mut CBuilder) -> Self::Wires { - let child_pi = b.add_virtual_targets(PublicInputs::::TOTAL_LEN); - + let child_pi = b.add_virtual_targets(PublicInputs::::total_len()); let wires = PartialNodeCircuit::build(b, PublicInputs::from_slice(&child_pi)); (wires, child_pi) @@ -139,56 +125,65 @@ mod tests { #[test] fn test_cells_tree_partial_node_circuit() { - let mut rng = thread_rng(); - - let identifier = rng.gen::().to_field(); - let value = U256::from_limbs(rng.gen::<[u64; 4]>()); - let value_fields = value.to_fields(); - - // Create the child public inputs. - let child_hash = random_vector::(NUM_HASH_OUT_ELTS).to_fields(); - let child_digest = Point::sample(&mut rng); - let dc = &child_digest.to_weierstrass().to_fields(); - let neutral = Point::NEUTRAL.to_fields(); - let child_pi = &PublicInputs { - h: &child_hash, - ind: dc, - mul: &neutral, - } - .to_vec(); + test_cells_tree_partial_multiplier(true); + test_cells_tree_partial_multiplier(false); + } + + fn test_cells_tree_partial_multiplier(is_multiplier: bool) { + let cell = Cell::sample(is_multiplier); + let id = cell.identifier; + let value = cell.value; + let values_digests = cell.split_values_digest(); + let metadata_digests = cell.split_metadata_digest(); + + let child_pi = &PublicInputs::::sample(is_multiplier); let test_circuit = TestPartialNodeCircuit { - c: Cell { - identifier, - value, - is_multiplier: false, - } - .into(), + c: cell.into(), child_pi, }; + let proof = run_circuit::(test_circuit); let pi = PublicInputs::from_slice(&proof.public_inputs); - // Check the node Poseidon hash + let child_pi = PublicInputs::from_slice(child_pi); + + let values_digests = values_digests.accumulate(&child_pi.split_values_digest_point()); + let metadata_digests = metadata_digests.accumulate(&child_pi.split_metadata_digest_point()); + + // Check the node hash { let empty_hash = empty_poseidon_hash(); - let inputs: Vec<_> = child_hash + let inputs = child_pi + .node_hash() + .to_fields() .into_iter() .chain(empty_hash.elements) - .chain(iter::once(identifier)) - .chain(value_fields.clone()) - .collect(); + .chain(once(id)) + .chain(value.to_fields()) + .collect_vec(); let exp_hash = H::hash_no_pad(&inputs); assert_eq!(pi.h, exp_hash.elements); } - // Check the cells digest - { - let inputs: Vec<_> = iter::once(identifier).chain(value_fields).collect(); - let exp_digest = map_to_curve_point(&inputs); - let exp_digest = add_curve_point(&[exp_digest, child_digest]).to_weierstrass(); - - assert_eq!(pi.individual_digest_point(), exp_digest); - } + // Check individual values digest + assert_eq!( + pi.individual_values_digest_point(), + values_digests.individual.to_weierstrass(), + ); + // Check multiplier values digest + assert_eq!( + pi.multiplier_values_digest_point(), + values_digests.multiplier.to_weierstrass(), + ); + // Check individual metadata digest + assert_eq!( + pi.individual_metadata_digest_point(), + metadata_digests.individual.to_weierstrass(), + ); + // Check multiplier metadata digest + assert_eq!( + pi.multiplier_metadata_digest_point(), + metadata_digests.multiplier.to_weierstrass(), + ); } } -*/ diff --git a/verifiable-db/src/cells_tree/public_inputs.rs b/verifiable-db/src/cells_tree/public_inputs.rs index e2c2f5b3c..422cfbc38 100644 --- a/verifiable-db/src/cells_tree/public_inputs.rs +++ b/verifiable-db/src/cells_tree/public_inputs.rs @@ -72,27 +72,27 @@ impl<'a, T: Clone> PublicInputs<'a, T> { offset..offset + Self::SIZES[pi_pos] } - pub(crate) const fn total_len() -> usize { + pub const fn total_len() -> usize { Self::to_range(CellsTreePublicInputs::MultiplierMetadataDigest).end } - pub(crate) fn to_node_hash_raw(&self) -> &[T] { + pub fn to_node_hash_raw(&self) -> &[T] { self.h } - pub(crate) fn to_individual_values_digest_raw(&self) -> &[T] { + pub fn to_individual_values_digest_raw(&self) -> &[T] { self.individual_vd } - pub(crate) fn to_multiplier_values_digest_raw(&self) -> &[T] { + pub fn to_multiplier_values_digest_raw(&self) -> &[T] { self.multiplier_vd } - pub(crate) fn to_individual_metadata_digest_raw(&self) -> &[T] { + pub fn to_individual_metadata_digest_raw(&self) -> &[T] { self.individual_md } - pub(crate) fn to_multiplier_metadata_digest_raw(&self) -> &[T] { + pub fn to_multiplier_metadata_digest_raw(&self) -> &[T] { self.multiplier_md } @@ -218,14 +218,14 @@ impl<'a> PublicInputs<'a, F> { pub fn split_metadata_digest_point(&self) -> SplitDigestPoint { SplitDigestPoint { - individual: weierstrass_to_point(&&self.individual_metadata_digest_point()), + individual: weierstrass_to_point(&self.individual_metadata_digest_point()), multiplier: weierstrass_to_point(&self.multiplier_metadata_digest_point()), } } } #[cfg(test)] -mod tests { +pub(crate) mod tests { use super::*; use mp2_common::{utils::ToFields, C, D, F}; use mp2_test::{ @@ -240,9 +240,45 @@ mod tests { }, }; use plonky2_ecgfp5::curve::curve::Point; - use rand::thread_rng; + use rand::{thread_rng, Rng}; use std::array; + impl<'a> PublicInputs<'a, F> { + pub(crate) fn sample(is_multiplier: bool) -> Vec { + let rng = &mut thread_rng(); + + let h = random_vector::(NUM_HASH_OUT_ELTS).to_fields(); + + let point_zero = WeierstrassPoint::NEUTRAL.to_fields(); + let [values_digest, metadata_digest] = + array::from_fn(|_| Point::sample(rng).to_weierstrass().to_fields()); + let [individual_vd, multiplier_vd, individual_md, multiplier_md] = if is_multiplier { + [ + point_zero.clone(), + values_digest, + point_zero, + metadata_digest, + ] + } else { + [ + values_digest, + point_zero.clone(), + metadata_digest, + point_zero, + ] + }; + + PublicInputs::new( + &h, + &individual_vd, + &multiplier_vd, + &individual_md, + &multiplier_md, + ) + .to_vec() + } + } + #[derive(Clone, Debug)] struct TestPublicInputs<'a> { exp_pi: &'a [F], @@ -266,20 +302,9 @@ mod tests { #[test] fn test_cells_tree_public_inputs() { let rng = &mut thread_rng(); + let is_multiplier = rng.gen(); - // Prepare the public inputs. - let h = random_vector::(NUM_HASH_OUT_ELTS).to_fields(); - let [individual_vd, multiplier_vd, individual_md, multiplier_md] = - array::from_fn(|_| Point::sample(rng).to_weierstrass().to_fields()); - let exp_pi = PublicInputs::new( - &h, - &individual_vd, - &multiplier_vd, - &individual_md, - &multiplier_md, - ); - let exp_pi = &exp_pi.to_vec(); - + let exp_pi = &PublicInputs::sample(is_multiplier); let test_circuit = TestPublicInputs { exp_pi }; let proof = run_circuit::(test_circuit); assert_eq!(&proof.public_inputs, exp_pi); diff --git a/verifiable-db/src/row_tree/public_inputs.rs b/verifiable-db/src/row_tree/public_inputs.rs index 06eb726d1..bccaf83d8 100644 --- a/verifiable-db/src/row_tree/public_inputs.rs +++ b/verifiable-db/src/row_tree/public_inputs.rs @@ -91,35 +91,35 @@ impl<'a, T: Clone> PublicInputs<'a, T> { offset..offset + Self::SIZES[pi_pos] } - pub(crate) const fn total_len() -> usize { + pub const fn total_len() -> usize { Self::to_range(RowsTreePublicInputs::RowIdMultiplier).end } - pub(crate) fn to_root_hash_raw(&self) -> &[T] { + pub fn to_root_hash_raw(&self) -> &[T] { self.h } - pub(crate) fn to_individual_digest_raw(&self) -> &[T] { + pub fn to_individual_digest_raw(&self) -> &[T] { self.individual_digest } - pub(crate) fn to_multiplier_digest_raw(&self) -> &[T] { + pub fn to_multiplier_digest_raw(&self) -> &[T] { self.multiplier_digest } - pub(crate) fn to_row_id_multiplier_raw(&self) -> &[T] { + pub fn to_row_id_multiplier_raw(&self) -> &[T] { self.row_id_multiplier } - pub(crate) fn to_min_value_raw(&self) -> &[T] { + pub fn to_min_value_raw(&self) -> &[T] { self.min } - pub(crate) fn to_max_value_raw(&self) -> &[T] { + pub fn to_max_value_raw(&self) -> &[T] { self.max } - pub(crate) fn to_merge_flag_raw(&self) -> &T { + pub fn to_merge_flag_raw(&self) -> &T { self.merge } diff --git a/verifiable-db/src/row_tree/row.rs b/verifiable-db/src/row_tree/row.rs index df882e5a6..db19effcc 100644 --- a/verifiable-db/src/row_tree/row.rs +++ b/verifiable-db/src/row_tree/row.rs @@ -115,8 +115,7 @@ impl RowWire { } } -/* #[cfg(test)] mod test { + // gupeng } -*/ From c7291cfb003cf183a9f5a641a82ed25f10c7bf98 Mon Sep 17 00:00:00 2001 From: Steven Gu Date: Wed, 23 Oct 2024 09:38:42 +0800 Subject: [PATCH 03/16] Update tests for rows tree. --- mp2-common/src/utils.rs | 1 + verifiable-db/src/cells_tree/mod.rs | 14 +- verifiable-db/src/row_tree/api.rs | 275 ++++++++++++-------- verifiable-db/src/row_tree/full_node.rs | 157 +++++------ verifiable-db/src/row_tree/leaf.rs | 115 ++++---- verifiable-db/src/row_tree/partial_node.rs | 155 ++++++----- verifiable-db/src/row_tree/public_inputs.rs | 36 ++- verifiable-db/src/row_tree/row.rs | 189 ++++++++++++-- 8 files changed, 566 insertions(+), 376 deletions(-) diff --git a/mp2-common/src/utils.rs b/mp2-common/src/utils.rs index 76b3d6ec0..4d9ca1ad9 100644 --- a/mp2-common/src/utils.rs +++ b/mp2-common/src/utils.rs @@ -373,6 +373,7 @@ impl ToFields for HashOut { self.elements.to_vec() } } + pub trait Fieldable { fn to_field(&self) -> F; } diff --git a/verifiable-db/src/cells_tree/mod.rs b/verifiable-db/src/cells_tree/mod.rs index 6c05a41d3..3cafa8c70 100644 --- a/verifiable-db/src/cells_tree/mod.rs +++ b/verifiable-db/src/cells_tree/mod.rs @@ -165,14 +165,11 @@ pub(crate) mod tests { use super::*; use mp2_common::{ types::CURVE_TARGET_LEN, - utils::{Fieldable, FromFields, ToFields}, + utils::{Fieldable, FromFields}, C, D, F, }; - use mp2_test::{ - circuit::{run_circuit, UserCircuit}, - utils::random_vector, - }; - use plonky2::{field::types::Sample, hash::hash_types::NUM_HASH_OUT_ELTS}; + use mp2_test::circuit::{run_circuit, UserCircuit}; + use plonky2::field::types::Sample; use plonky2_ecgfp5::{ curve::curve::Point, gadgets::curve::{CircuitBuilderEcGFp5, PartialWitnessCurve}, @@ -186,8 +183,7 @@ pub(crate) mod tests { let identifier = rng.gen::().to_field(); let value = U256::from_limbs(rng.gen()); - let mpt_metadata = - HashOut::from_vec(random_vector::(NUM_HASH_OUT_ELTS).to_fields()); + let mpt_metadata = HashOut::rand(); Cell::new(identifier, value, is_multiplier, mpt_metadata) } @@ -201,7 +197,7 @@ pub(crate) mod tests { } impl<'a> UserCircuit for TestCellCircuit<'a> { - // Cell wires + child values digest + child metadata digest + // Cell wire + child values digest + child metadata digest type Wires = (CellWire, SplitDigestTarget, SplitDigestTarget); fn build(b: &mut CBuilder) -> Self::Wires { diff --git a/verifiable-db/src/row_tree/api.rs b/verifiable-db/src/row_tree/api.rs index a56c24fe3..56d5613c4 100644 --- a/verifiable-db/src/row_tree/api.rs +++ b/verifiable-db/src/row_tree/api.rs @@ -313,19 +313,17 @@ pub fn extract_hash_from_proof(proof: &[u8]) -> Result> { Ok(PublicInputs::from_slice(&p.proof.public_inputs).root_hash()) } -/* #[cfg(test)] mod test { - use crate::{cells_tree, row_tree::public_inputs::PublicInputs}; - use super::*; + use crate::cells_tree; + use itertools::Itertools; use mp2_common::{ - group_hashing::{cond_field_hashed_scalar_mul, map_to_curve_point}, poseidon::{empty_poseidon_hash, H}, utils::ToFields, F, }; - use mp2_test::{log::init_logging, utils::weierstrass_to_point}; + use mp2_test::log::init_logging; use partial_node::test::partial_safety_check; use plonky2::{ field::types::{PrimeField64, Sample}, @@ -334,10 +332,10 @@ mod test { circuit_data::VerifierOnlyCircuitData, config::Hasher, proof::ProofWithPublicInputs, }, }; - use plonky2_ecgfp5::curve::curve::Point; use recursion_framework::framework_testing::TestingRecursiveCircuits; + use std::iter::once; - const CELL_IO_LEN: usize = cells_tree::PublicInputs::::TOTAL_LEN; + const CELL_IO_LEN: usize = cells_tree::PublicInputs::::total_len(); struct TestParams { cells_test: TestingRecursiveCircuits, @@ -346,17 +344,17 @@ mod test { // to save on test time cells_proof: ProofWithPublicInputs, cells_vk: VerifierOnlyCircuitData, - leaf1: Cell, - leaf2: Cell, - full: Cell, - partial: Cell, + leaf1: Row, + leaf2: Row, + full: Row, + partial: Row, } impl TestParams { fn build() -> Result { let cells_test = TestingRecursiveCircuits::::default(); let params = PublicParameters::build(cells_test.get_recursive_circuit_set()); - let cells_pi = Self::rand_cells_pi(); + let cells_pi = cells_tree::PublicInputs::sample(false); let cells_proof = cells_test.generate_input_proofs::<1>([cells_pi.clone().try_into().unwrap()])?; let cells_vk = cells_test.verifier_data_for_input_proofs::<1>()[0].clone(); @@ -378,10 +376,22 @@ mod test { params, cells_proof: cells_proof[0].clone(), cells_vk, - leaf1: Cell::new(identifier, v1, false), - leaf2: Cell::new(identifier, v2, false), - full: Cell::new(identifier, v_full, false), - partial: Cell::new(identifier, v_partial, false), + leaf1: Row::new( + Cell::new(identifier, v1, false, HashOut::rand()), + HashOut::rand(), + ), + leaf2: Row::new( + Cell::new(identifier, v2, false, HashOut::rand()), + HashOut::rand(), + ), + full: Row::new( + Cell::new(identifier, v_full, false, HashOut::rand()), + HashOut::rand(), + ), + partial: Row::new( + Cell::new(identifier, v_partial, false, HashOut::rand()), + HashOut::rand(), + ), }) } @@ -391,19 +401,6 @@ mod test { fn cells_proof_vk(&self) -> ProofWithVK { ProofWithVK::new(self.cells_proof.clone(), self.cells_vk.clone()) } - - fn rand_cells_pi() -> Vec { - // generate cells tree input and fake proof - let cells_hash = HashOut::rand().to_fields(); - let cells_digest = Point::rand().to_weierstrass().to_fields(); - let cells_pi = cells_tree::PublicInputs::new( - &cells_hash, - &cells_digest, - &Point::NEUTRAL.to_fields(), - ) - .to_vec(); - cells_pi - } } #[test] @@ -426,21 +423,29 @@ mod test { fn generate_partial_proof( p: &TestParams, - tuple: Cell, + row: Row, is_left: bool, child_proof_buff: Vec, ) -> Result> { + let id = row.cell.identifier; + let value = row.cell.value; + let mpt_metadata = row.cell.mpt_metadata; + let row_unique_data = row.row_unique_data; + let row_digest = row.digest(&p.cells_pi()); + let child_proof = ProofWithVK::deserialize(&child_proof_buff)?; let child_pi = PublicInputs::from_slice(&child_proof.proof.public_inputs); - let child_min = child_pi.min_value_u256(); - let child_max = child_pi.max_value_u256(); + let child_min = child_pi.min_value(); + let child_max = child_pi.max_value(); - partial_safety_check(child_min, child_max, tuple.value, is_left); + partial_safety_check(child_min, child_max, value, is_left); let input = CircuitInput::partial( - tuple.identifier.to_canonical_u64(), - tuple.value, + id.to_canonical_u64(), + value, is_left, + mpt_metadata, + row_unique_data, child_proof_buff.clone(), p.cells_proof_vk().serialize()?, )?; @@ -449,46 +454,67 @@ mod test { .generate_proof(input, p.cells_test.get_recursive_circuit_set().clone())?; let pi = ProofWithVK::deserialize(&proof)?.proof.public_inputs; let pi = PublicInputs::from_slice(&pi); + + // Check root hash { // node_min = left ? child_proof.min : index_value // node_max = left ? index_value : child_proof.max let (node_min, node_max) = match is_left { - true => (pi.min_value_u256(), tuple.value), - false => (tuple.value, pi.max_value_u256()), + true => (pi.min_value(), value), + false => (value, pi.max_value()), }; - - let child_hash = child_pi.root_hash_hashout(); - let empty_hash = empty_poseidon_hash(); + // Poseidon(p1.H || p2.H || node_min || node_max || index_id || index_value ||p.H)) as H + let child_hash = child_pi.root_hash().to_fields(); + let empty_hash = empty_poseidon_hash().to_fields(); let input_hash = match is_left { - true => [child_hash.to_fields(), empty_hash.to_fields()].concat(), - false => [empty_hash.to_fields(), child_hash.to_fields()].concat(), + true => [child_hash, empty_hash].concat(), + false => [empty_hash, child_hash].concat(), }; let inputs = input_hash - .iter() - .chain(node_min.to_fields().iter()) - .chain(node_max.to_fields().iter()) - .chain(tuple.to_fields().iter()) - .chain(p.cells_pi().h_raw().iter()) - .cloned() - .collect::>(); - let hash = H::hash_no_pad(&inputs); - assert_eq!(hash, pi.root_hash_hashout()); - - // final_digest = HashToInt(mul_digest) * D(ind_digest) + row_proof.digest() - let split_digest = tuple.split_and_accumulate_digest(p.cells_pi().split_digest_point()); - let res = split_digest.cond_combine_to_row_digest(); - // then adding with the rest of the rows digest, the other nodes - let res = res + weierstrass_to_point(&child_pi.rows_digest_field()); - assert_eq!(res.to_weierstrass(), pi.rows_digest_field()); + .into_iter() + .chain(node_min.to_fields()) + .chain(node_max.to_fields()) + .chain(once(id)) + .chain(p.cells_pi().node_hash().to_fields()) + .collect_vec(); + let exp_root_hash = H::hash_no_pad(&inputs); + assert_eq!(pi.root_hash(), exp_root_hash); } + // Check individual digest + assert_eq!( + pi.individual_digest_point(), + row_digest.individual_vd.to_weierstrass() + ); + // Check multiplier digest + assert_eq!( + pi.multiplier_digest_point(), + row_digest.multiplier_vd.to_weierstrass() + ); + // Check row ID multiplier + assert_eq!(pi.row_id_multiplier(), row_digest.row_id_multiplier); + // Check minimum value + assert_eq!(pi.min_value(), value.min(child_min)); + // Check maximum value + assert_eq!(pi.max_value(), value.max(child_max)); + // Check merge flag + assert_eq!(pi.merge_flag(), row_digest.is_merge); + Ok(vec![]) } fn generate_full_proof(p: &TestParams, child_proof: [Vec; 2]) -> Result> { - let tuple = p.full.clone(); + let row = &p.full; + let id = row.cell.identifier; + let value = row.cell.value; + let mpt_metadata = row.cell.mpt_metadata; + let row_unique_data = row.row_unique_data; + let row_digest = row.digest(&p.cells_pi()); + let input = CircuitInput::full( - tuple.identifier.to_canonical_u64(), - tuple.value, + id.to_canonical_u64(), + value, + mpt_metadata, + row_unique_data, child_proof[0].to_vec(), child_proof[1].to_vec(), p.cells_proof_vk().serialize()?, @@ -497,52 +523,61 @@ mod test { let left_pi = PublicInputs::from_slice(&left_proof.proof.public_inputs); let right_proof = ProofWithVK::deserialize(&child_proof[1])?; let right_pi = PublicInputs::from_slice(&right_proof.proof.public_inputs); - assert!(left_pi.max_value_u256() < tuple.value); - assert!(tuple.value < right_pi.min_value_u256()); + assert!(left_pi.max_value() < value); + assert!(value < right_pi.min_value()); let proof = p .params .generate_proof(input, p.cells_test.get_recursive_circuit_set().clone())?; let pi = ProofWithVK::deserialize(&proof)?.proof.public_inputs; let pi = PublicInputs::from_slice(&pi); + + // Check root hash { - // H(left_child_hash,right_child_hash,min,max,index_identifier,index_value,cells_tree_hash) - // min coming from left - // max coming from right - let inputs: Vec<_> = left_pi - .root_hash_hashout() + // Poseidon(p1.H || p2.H || node_min || node_max || index_id || index_value ||p.H)) as H + let inputs = left_pi + .root_hash() .to_fields() - .iter() - .chain(right_pi.root_hash_hashout().to_fields().iter()) - .chain(left_pi.min_value_u256().to_fields().iter()) - .chain(right_pi.max_value_u256().to_fields().iter()) - .chain(tuple.to_fields().iter()) - .chain(p.cells_pi().h_raw().iter()) - .cloned() - .collect(); - let exp_hash = H::hash_no_pad(&inputs); - assert_eq!(pi.root_hash_hashout(), exp_hash); - - { - // final_digest = HashToInt(mul_digest) * D(ind_digest) + p1.digest() + p2.digest() - let split_digest = - tuple.split_and_accumulate_digest(p.cells_pi().split_digest_point()); - let row_digest = split_digest.cond_combine_to_row_digest(); - - let p1dr = weierstrass_to_point(&left_pi.rows_digest_field()); - let p2dr = weierstrass_to_point(&right_pi.rows_digest_field()); - let result_digest = p1dr + p2dr + row_digest; - assert_eq!(result_digest.to_weierstrass(), pi.rows_digest_field()); - } + .into_iter() + .chain(right_pi.root_hash().to_fields()) + .chain(left_pi.min_value().to_fields()) + .chain(right_pi.max_value().to_fields()) + .chain(once(id)) + .chain(p.cells_pi().node_hash().to_fields()) + .collect_vec(); + let hash = H::hash_no_pad(&inputs); + assert_eq!(hash, pi.root_hash()); } + // Check individual digest + assert_eq!( + pi.individual_digest_point(), + row_digest.individual_vd.to_weierstrass() + ); + // Check multiplier digest + assert_eq!( + pi.multiplier_digest_point(), + row_digest.multiplier_vd.to_weierstrass() + ); + // Check row ID multiplier + assert_eq!(pi.row_id_multiplier(), row_digest.row_id_multiplier); + // Check merge flag + assert_eq!(pi.merge_flag(), row_digest.is_merge); + Ok(proof) } - fn generate_leaf_proof(p: &TestParams, tuple: &Cell) -> Result> { - let cells_pi = p.cells_pi(); + fn generate_leaf_proof(p: &TestParams, row: &Row) -> Result> { + let id = row.cell.identifier; + let value = row.cell.value; + let mpt_metadata = row.cell.mpt_metadata; + let row_unique_data = row.row_unique_data; + let row_digest = row.digest(&p.cells_pi()); + // generate row leaf proof let input = CircuitInput::leaf( - tuple.identifier.to_canonical_u64(), - tuple.value, + id.to_canonical_u64(), + value, + mpt_metadata, + row_unique_data, p.cells_proof_vk().serialize()?, )?; @@ -554,30 +589,42 @@ mod test { .proof .public_inputs; let pi = PublicInputs::from_slice(&pi); - let tuple = tuple.clone(); + + // Check root hash { - let empty_hash = empty_poseidon_hash(); - // H(left_child_hash,right_child_hash,min,max,index_identifier,index_value,cells_tree_hash) - let inputs: Vec<_> = empty_hash - .to_fields() + let value = value.to_fields(); + let empty_hash = empty_poseidon_hash().to_fields(); + let inputs = empty_hash .iter() - .chain(empty_hash.to_fields().iter()) - .chain(tuple.value.to_fields().iter()) - .chain(tuple.value.to_fields().iter()) - .chain(tuple.to_fields().iter()) - .chain(cells_pi.h_raw().iter()) + .chain(empty_hash.iter()) + .chain(value.iter()) + .chain(value.iter()) + .chain(once(&id)) + .chain(p.cells_pi().to_node_hash_raw()) .cloned() - .collect(); - let exp_hash = H::hash_no_pad(&inputs); - assert_eq!(pi.root_hash_hashout(), exp_hash); - } - { - // final_digest = HashToInt(mul_digest) * D(ind_digest) - let split_digest = tuple.split_and_accumulate_digest(cells_pi.split_digest_point()); - let result = split_digest.cond_combine_to_row_digest(); - assert_eq!(result.to_weierstrass(), pi.rows_digest_field()); + .collect_vec(); + let exp_root_hash = H::hash_no_pad(&inputs); + assert_eq!(pi.root_hash(), exp_root_hash); } + // Check individual digest + assert_eq!( + pi.individual_digest_point(), + row_digest.individual_vd.to_weierstrass() + ); + // Check multiplier digest + assert_eq!( + pi.multiplier_digest_point(), + row_digest.multiplier_vd.to_weierstrass() + ); + // Check row ID multiplier + assert_eq!(pi.row_id_multiplier(), row_digest.row_id_multiplier); + // Check minimum value + assert_eq!(pi.min_value(), value); + // Check maximum value + assert_eq!(pi.max_value(), value); + // Check merge flag + assert_eq!(pi.merge_flag(), row_digest.is_merge); + Ok(proof) } } -*/ diff --git a/verifiable-db/src/row_tree/full_node.rs b/verifiable-db/src/row_tree/full_node.rs index 4bd983214..272c7b779 100644 --- a/verifiable-db/src/row_tree/full_node.rs +++ b/verifiable-db/src/row_tree/full_node.rs @@ -150,35 +150,14 @@ impl CircuitLogicWires for RecursiveFullWires { } } -/* #[cfg(test)] pub(crate) mod test { - + use super::*; use alloy::primitives::U256; - use mp2_common::{ - group_hashing::{cond_field_hashed_scalar_mul, map_to_curve_point}, - poseidon::H, - utils::ToFields, - C, D, F, - }; - use mp2_test::{ - circuit::{run_circuit, UserCircuit}, - utils::weierstrass_to_point, - }; - use plonky2::{ - field::types::{Field, Sample}, - hash::hash_types::HashOut, - iop::{ - target::Target, - witness::{PartialWitness, WitnessWrite}, - }, - plonk::{circuit_builder::CircuitBuilder, config::Hasher}, - }; - use plonky2_ecgfp5::curve::curve::Point; - - use crate::{cells_tree, row_tree::public_inputs::PublicInputs}; - - use super::{FullNodeCircuit, FullNodeWires, *}; + use itertools::Itertools; + use mp2_common::{utils::ToFields, C, D, F}; + use mp2_test::circuit::{run_circuit, UserCircuit}; + use plonky2::{iop::witness::WitnessWrite, plonk::config::Hasher}; #[derive(Clone, Debug)] struct TestFullNodeCircuit { @@ -211,82 +190,77 @@ pub(crate) mod test { } } - pub(crate) fn generate_random_pi(min: usize, max: usize, is_merge: bool) -> Vec { - let hash = HashOut::rand(); - let digest = Point::rand(); - let min = U256::from(min); - let max = U256::from(max); - let merge = F::from_canonical_usize(is_merge as usize); - PublicInputs::new( - &hash.to_fields(), - &digest.to_weierstrass().to_fields(), - &min.to_fields(), - &max.to_fields(), - &[merge], - ) - .to_vec() - } - fn test_row_tree_full_circuit(is_multiplier: bool, cells_multiplier: bool) { - let cells_point = Point::rand(); - let ind_cell_digest = cells_point.to_weierstrass().to_fields(); - let mul_cell_digest = if cells_multiplier { - cells_point.to_weierstrass().to_fields() - } else { - Point::NEUTRAL.to_fields() - }; - let cells_hash = HashOut::rand().to_fields(); - let cells_pi_struct = - cells_tree::PublicInputs::new(&cells_hash, &ind_cell_digest, &mul_cell_digest); - let cells_pi = cells_pi_struct.to_vec(); - + let mut row = Row::sample(is_multiplier); + row.cell.value = U256::from(18); + let id = row.cell.identifier; + let cells_pi = cells_tree::PublicInputs::sample(cells_multiplier); + // Compute the row digest. + let row_digest = row.digest(&cells_tree::PublicInputs::from_slice(&cells_pi)); + let node_circuit = FullNodeCircuit::from(row.clone()); let (left_min, left_max) = (10, 15); // this should work since we allow multipleicities of indexes in the row tree let (right_min, right_max) = (18, 30); - let value = U256::from(18); // 15 < 18 < 23 - let identifier = F::rand(); - let tuple = Cell::new(identifier, value, is_multiplier); - let node_circuit = FullNodeCircuit::from(tuple.clone()); - let left_pi = generate_random_pi(left_min, left_max, is_multiplier || cells_multiplier); - let right_pi = generate_random_pi(right_min, right_max, is_multiplier || cells_multiplier); + let left_pi = PublicInputs::sample( + row_digest.multiplier_vd, + row_digest.row_id_multiplier.clone(), + left_min, + left_max, + is_multiplier || cells_multiplier, + ); + let right_pi = PublicInputs::sample( + row_digest.multiplier_vd, + row_digest.row_id_multiplier.clone(), + right_min, + right_max, + is_multiplier || cells_multiplier, + ); let test_circuit = TestFullNodeCircuit { circuit: node_circuit, left_pi: left_pi.clone(), right_pi: right_pi.clone(), - cells_pi, + cells_pi: cells_pi.clone(), }; let proof = run_circuit::(test_circuit); let pi = PublicInputs::from_slice(&proof.public_inputs); - let left_pis = PublicInputs::from_slice(&left_pi); - let right_pis = PublicInputs::from_slice(&right_pi); - - assert_eq!(U256::from(left_min), pi.min_value_u256()); - assert_eq!(U256::from(right_max), pi.max_value_u256()); - // Poseidon(p1.H || p2.H || node_min || node_max || index_id || index_value ||p.H)) as H - let left_hash = PublicInputs::from_slice(&left_pi).root_hash_hashout(); - let right_hash = PublicInputs::from_slice(&right_pi).root_hash_hashout(); - let inputs = left_hash - .to_fields() - .iter() - .chain(right_hash.to_fields().iter()) - .chain(left_pis.min_value_u256().to_fields().iter()) - .chain(right_pis.max_value_u256().to_fields().iter()) - .chain(Cell::new(identifier, value, false).to_fields().iter()) - .chain(cells_hash.iter()) - .cloned() - .collect::>(); - let hash = H::hash_no_pad(&inputs); - assert_eq!(hash, pi.root_hash_hashout()); - - // final_digest = HashToInt(mul_digest) * D(ind_digest) + p1.digest() + p2.digest() - let split_digest = tuple.split_and_accumulate_digest(cells_pi_struct.split_digest_point()); - let row_digest = split_digest.cond_combine_to_row_digest(); - - let p1dr = weierstrass_to_point(&PublicInputs::from_slice(&left_pi).rows_digest_field()); - let p2dr = weierstrass_to_point(&PublicInputs::from_slice(&right_pi).rows_digest_field()); - let result_digest = p1dr + p2dr + row_digest; - assert_eq!(result_digest.to_weierstrass(), pi.rows_digest_field()); - assert_eq!(split_digest.is_merge_case(), pi.is_merge_flag()); + let left_pi = PublicInputs::from_slice(&left_pi); + let right_pi = PublicInputs::from_slice(&right_pi); + let cells_pi = cells_tree::PublicInputs::from_slice(&cells_pi); + + // Check root hash + { + // Poseidon(p1.H || p2.H || node_min || node_max || index_id || index_value ||p.H)) as H + let inputs = left_pi + .root_hash() + .to_fields() + .into_iter() + .chain(right_pi.root_hash().to_fields()) + .chain(left_pi.min_value().to_fields()) + .chain(right_pi.max_value().to_fields()) + .chain(once(id)) + .chain(cells_pi.node_hash().to_fields()) + .collect_vec(); + let hash = H::hash_no_pad(&inputs); + assert_eq!(hash, pi.root_hash()); + } + // Check individual digest + assert_eq!( + pi.individual_digest_point(), + row_digest.individual_vd.to_weierstrass() + ); + // Check multiplier digest + assert_eq!( + pi.multiplier_digest_point(), + row_digest.multiplier_vd.to_weierstrass() + ); + // Check row ID multiplier + assert_eq!(pi.row_id_multiplier(), row_digest.row_id_multiplier); + // Check minimum value + assert_eq!(pi.min_value(), U256::from(left_min)); + // Check maximum value + assert_eq!(pi.max_value(), U256::from(right_max)); + // Check merge flag + assert_eq!(pi.merge_flag(), row_digest.is_merge); } #[test] @@ -309,4 +283,3 @@ pub(crate) mod test { test_row_tree_full_circuit(true, true); } } -*/ diff --git a/verifiable-db/src/row_tree/leaf.rs b/verifiable-db/src/row_tree/leaf.rs index e9c6a34f6..48dd243aa 100644 --- a/verifiable-db/src/row_tree/leaf.rs +++ b/verifiable-db/src/row_tree/leaf.rs @@ -125,33 +125,20 @@ impl CircuitLogicWires for RecursiveLeafWires { } } -/* #[cfg(test)] mod test { - - use alloy::primitives::U256; - use mp2_common::{ - group_hashing::{cond_field_hashed_scalar_mul, map_to_curve_point}, - poseidon::empty_poseidon_hash, - utils::ToFields, - CHasher, C, D, F, + use super::*; + use crate::{ + cells_tree::PublicInputs as CellsPublicInputs, row_tree::public_inputs::PublicInputs, }; + use itertools::Itertools; + use mp2_common::{poseidon::empty_poseidon_hash, utils::ToFields, C, D, F}; use mp2_test::circuit::{run_circuit, UserCircuit}; use plonky2::{ - field::types::Sample, - hash::{hash_types::HashOut, hashing::hash_n_to_hash_no_pad}, iop::{target::Target, witness::WitnessWrite}, plonk::{circuit_builder::CircuitBuilder, config::Hasher}, }; - use plonky2_ecgfp5::curve::curve::Point; - use rand::{thread_rng, Rng}; - - use crate::{ - cells_tree::{self, Cell}, - row_tree::public_inputs::PublicInputs, - }; - - use super::{LeafCircuit, LeafWires}; + use std::iter::once; #[derive(Debug, Clone)] struct TestLeafCircuit { @@ -163,7 +150,7 @@ mod test { type Wires = (LeafWires, Vec); fn build(c: &mut CircuitBuilder) -> Self::Wires { - let cells_pi = c.add_virtual_targets(cells_tree::PublicInputs::::TOTAL_LEN); + let cells_pi = c.add_virtual_targets(cells_tree::PublicInputs::::total_len()); (LeafCircuit::build(c, &cells_pi), cells_pi) } @@ -174,48 +161,57 @@ mod test { } fn test_row_tree_leaf_circuit(is_multiplier: bool, cells_multiplier: bool) { - let mut rng = thread_rng(); - let value = U256::from_limbs(rng.gen::<[u64; 4]>()); - let identifier = F::rand(); - let row_cell = Cell::new(identifier, value, is_multiplier); - let circuit = LeafCircuit::from(row_cell.clone()); - let tuple = row_cell.clone(); - - let ind_cells_digest = Point::rand().to_fields(); - // TODO: test with other than neutral - let mul_cells_digest = if cells_multiplier { - Point::rand().to_fields() - } else { - Point::NEUTRAL.to_fields() + let cells_pi = CellsPublicInputs::sample(cells_multiplier); + + let row = Row::sample(is_multiplier); + let id = row.cell.identifier; + let value = row.cell.value; + let row_digest = row.digest(&CellsPublicInputs::from_slice(&cells_pi)); + + let circuit = LeafCircuit::from(row); + let test_circuit = TestLeafCircuit { + circuit, + cells_pi: cells_pi.clone(), }; - let cells_hash = HashOut::rand().to_fields(); - let cells_pi_struct = - cells_tree::PublicInputs::new(&cells_hash, &ind_cells_digest, &mul_cells_digest); - let cells_pi = cells_pi_struct.to_vec(); - let test_circuit = TestLeafCircuit { circuit, cells_pi }; + let cells_pi = CellsPublicInputs::from_slice(&cells_pi); + let proof = run_circuit::(test_circuit); let pi = PublicInputs::from_slice(&proof.public_inputs); - assert_eq!(value, pi.max_value_u256()); - assert_eq!(value, pi.min_value_u256()); - let empty_hash = empty_poseidon_hash(); - let inputs = empty_hash - .to_fields() - .iter() - .chain(empty_hash.to_fields().iter()) - .chain(tuple.value.to_fields().iter()) - .chain(tuple.value.to_fields().iter()) - .chain(tuple.to_fields().iter()) - .chain(cells_hash.iter()) - .cloned() - .collect::>(); - let row_hash = hash_n_to_hash_no_pad::>::Permutation>(&inputs); - assert_eq!(row_hash, pi.root_hash_hashout()); - // final_digest = HashToInt(mul_digest) * D(ind_digest) - let split_digest = - row_cell.split_and_accumulate_digest(cells_pi_struct.split_digest_point()); - let result = split_digest.cond_combine_to_row_digest(); - assert_eq!(result.to_weierstrass(), pi.rows_digest_field()); - assert_eq!(split_digest.is_merge_case(), pi.is_merge_flag()); + + // Check root hash + { + let value = value.to_fields(); + let empty_hash = empty_poseidon_hash().to_fields(); + let inputs = empty_hash + .iter() + .chain(empty_hash.iter()) + .chain(value.iter()) + .chain(value.iter()) + .chain(once(&id)) + .chain(cells_pi.to_node_hash_raw()) + .cloned() + .collect_vec(); + let exp_root_hash = H::hash_no_pad(&inputs); + assert_eq!(pi.root_hash(), exp_root_hash); + } + // Check individual digest + assert_eq!( + pi.individual_digest_point(), + row_digest.individual_vd.to_weierstrass() + ); + // Check multiplier digest + assert_eq!( + pi.multiplier_digest_point(), + row_digest.multiplier_vd.to_weierstrass() + ); + // Check row ID multiplier + assert_eq!(pi.row_id_multiplier(), row_digest.row_id_multiplier); + // Check minimum value + assert_eq!(pi.min_value(), value); + // Check maximum value + assert_eq!(pi.max_value(), value); + // Check merge flag + assert_eq!(pi.merge_flag(), row_digest.is_merge); } #[test] @@ -238,4 +234,3 @@ mod test { test_row_tree_leaf_circuit(true, true); } } -*/ diff --git a/verifiable-db/src/row_tree/partial_node.rs b/verifiable-db/src/row_tree/partial_node.rs index 2b2d6bde2..24bc9143f 100644 --- a/verifiable-db/src/row_tree/partial_node.rs +++ b/verifiable-db/src/row_tree/partial_node.rs @@ -117,11 +117,11 @@ impl PartialNodeCircuit { PublicInputs::new( &node_hash, &digest.individual_vd.to_targets(), + &digest.multiplier_vd.to_targets(), + &digest.row_id_multiplier.to_targets(), &node_min.to_targets(), &node_max.to_targets(), &[digest.is_merge.target], - &digest.multiplier_vd.to_targets(), - &digest.row_id_multiplier.to_targets(), ) .register(b); PartialNodeWires { @@ -184,40 +184,20 @@ impl CircuitLogicWires for RecursivePartialWires { } } -/* #[cfg(test)] pub mod test { + use super::*; + use alloy::primitives::U256; + use itertools::Itertools; use mp2_common::{ - group_hashing::{cond_field_hashed_scalar_mul, map_to_curve_point}, - poseidon::empty_poseidon_hash, + poseidon::{empty_poseidon_hash, H}, + types::CBuilder, utils::ToFields, - CHasher, + C, D, F, }; - use plonky2::{hash::hash_types::HashOut, plonk::config::Hasher}; - use plonky2_ecgfp5::curve::curve::Point; - - use alloy::primitives::U256; - use mp2_common::{C, D, F}; - use mp2_test::{ - circuit::{run_circuit, UserCircuit}, - utils::weierstrass_to_point, - }; - use plonky2::{ - field::types::Sample, - hash::hashing::hash_n_to_hash_no_pad, - iop::{target::Target, witness::WitnessWrite}, - plonk::circuit_builder::CircuitBuilder, - }; - - use crate::{ - cells_tree::{self, Cell}, - row_tree::{ - full_node::test::generate_random_pi, partial_node::PartialNodeCircuit, - public_inputs::PublicInputs, - }, - }; - - use super::PartialNodeWires; + use mp2_test::circuit::{run_circuit, UserCircuit}; + use plonky2::plonk::config::Hasher; + use std::iter::once; #[derive(Clone, Debug)] struct TestPartialNodeCircuit { @@ -229,10 +209,11 @@ pub mod test { impl UserCircuit for TestPartialNodeCircuit { type Wires = (PartialNodeWires, Vec, Vec); - fn build(c: &mut CircuitBuilder) -> Self::Wires { - let child_pi = c.add_virtual_targets(PublicInputs::::TOTAL_LEN); - let cells_pi = c.add_virtual_targets(cells_tree::PublicInputs::::TOTAL_LEN); + fn build(c: &mut CBuilder) -> Self::Wires { + let child_pi = c.add_virtual_targets(PublicInputs::::total_len()); + let cells_pi = c.add_virtual_targets(cells_tree::PublicInputs::::total_len()); let wires = PartialNodeCircuit::build(c, &child_pi, &cells_pi); + (wires, child_pi, cells_pi) } @@ -299,29 +280,26 @@ pub mod test { } fn partial_node_circuit(child_at_left: bool, is_multiplier: bool, is_cell_multiplier: bool) { - let tuple = Cell::new(F::rand(), U256::from(18), is_multiplier); + let mut row = Row::sample(is_multiplier); + row.cell.value = U256::from(18); + let id = row.cell.identifier; + let value = row.cell.value; + let cells_pi = cells_tree::PublicInputs::sample(is_cell_multiplier); + // Compute the row digest. + let row_digest = row.digest(&cells_tree::PublicInputs::from_slice(&cells_pi)); let (child_min, child_max) = match child_at_left { true => (U256::from(10), U256::from(15)), false => (U256::from(20), U256::from(25)), }; - partial_safety_check(child_min, child_max, tuple.value, child_at_left); - let node_circuit = PartialNodeCircuit::new(tuple.clone(), child_at_left); - let child_pi = generate_random_pi( + partial_safety_check(child_min, child_max, value, child_at_left); + let node_circuit = PartialNodeCircuit::new(row.clone(), child_at_left); + let child_pi = PublicInputs::sample( + row_digest.multiplier_vd, + row_digest.row_id_multiplier.clone(), child_min.to(), child_max.to(), is_cell_multiplier || is_multiplier, ); - let cells_point = Point::rand(); - let ind_cell_digest = cells_point.to_weierstrass().to_fields(); - let cells_hash = HashOut::rand().to_fields(); - let mul_cell_digest = if is_cell_multiplier { - cells_point.to_weierstrass().to_fields() - } else { - Point::NEUTRAL.to_fields() - }; - let cells_pi_struct = - cells_tree::PublicInputs::new(&cells_hash, &ind_cell_digest, &mul_cell_digest); - let cells_pi = cells_pi_struct.to_vec(); let test_circuit = TestPartialNodeCircuit { circuit: node_circuit, cells_pi: cells_pi.clone(), @@ -329,37 +307,52 @@ pub mod test { }; let proof = run_circuit::(test_circuit); let pi = PublicInputs::from_slice(&proof.public_inputs); - // node_min = left ? child_proof.min : index_value - // node_max = left ? index_value : child_proof.max - let (node_min, node_max) = match child_at_left { - true => (pi.min_value(), tuple.value), - false => (tuple.value, pi.max_value()), - }; - // Poseidon(p1.H || p2.H || node_min || node_max || index_id || index_value ||p.H)) as H - let child_hash = PublicInputs::from_slice(&child_pi).root_hash_hashout(); - let empty_hash = empty_poseidon_hash(); - let input_hash = match child_at_left { - true => [child_hash.to_fields(), empty_hash.to_fields()].concat(), - false => [empty_hash.to_fields(), child_hash.to_fields()].concat(), - }; - let inputs = input_hash - .iter() - .chain(node_min.to_fields().iter()) - .chain(node_max.to_fields().iter()) - .chain(tuple.to_fields().iter()) - .chain(cells_hash.iter()) - .cloned() - .collect::>(); - let hash = hash_n_to_hash_no_pad::>::Permutation>(&inputs); - assert_eq!(hash, pi.root_hash_hashout()); - // final_digest = HashToInt(mul_digest) * D(ind_digest) + row_proof.digest() - let split_digest = tuple.split_and_accumulate_digest(cells_pi_struct.split_digest_point()); - let res = split_digest.cond_combine_to_row_digest(); - // then adding with the rest of the rows digest, the other nodes - let res = - res + weierstrass_to_point(&PublicInputs::from_slice(&child_pi).rows_digest_field()); - assert_eq!(res.to_weierstrass(), pi.rows_digest_field()); - assert_eq!(split_digest.is_merge_case(), pi.is_merge_flag()); + + let child_pi = PublicInputs::from_slice(&child_pi); + let cells_pi = cells_tree::PublicInputs::from_slice(&cells_pi); + + // Check root hash + { + // node_min = left ? child_proof.min : index_value + // node_max = left ? index_value : child_proof.max + let (node_min, node_max) = match child_at_left { + true => (pi.min_value(), value), + false => (value, pi.max_value()), + }; + // Poseidon(p1.H || p2.H || node_min || node_max || index_id || index_value ||p.H)) as H + let child_hash = child_pi.root_hash().to_fields(); + let empty_hash = empty_poseidon_hash().to_fields(); + let input_hash = match child_at_left { + true => [child_hash, empty_hash].concat(), + false => [empty_hash, child_hash].concat(), + }; + let inputs = input_hash + .into_iter() + .chain(node_min.to_fields()) + .chain(node_max.to_fields()) + .chain(once(id)) + .chain(cells_pi.node_hash().to_fields()) + .collect_vec(); + let exp_root_hash = H::hash_no_pad(&inputs); + assert_eq!(pi.root_hash(), exp_root_hash); + } + // Check individual digest + assert_eq!( + pi.individual_digest_point(), + row_digest.individual_vd.to_weierstrass() + ); + // Check multiplier digest + assert_eq!( + pi.multiplier_digest_point(), + row_digest.multiplier_vd.to_weierstrass() + ); + // Check row ID multiplier + assert_eq!(pi.row_id_multiplier(), row_digest.row_id_multiplier); + // Check minimum value + assert_eq!(pi.min_value(), value.min(child_min)); + // Check maximum value + assert_eq!(pi.max_value(), value.max(child_max)); + // Check merge flag + assert_eq!(pi.merge_flag(), row_digest.is_merge); } } -*/ diff --git a/verifiable-db/src/row_tree/public_inputs.rs b/verifiable-db/src/row_tree/public_inputs.rs index bccaf83d8..52ed490dd 100644 --- a/verifiable-db/src/row_tree/public_inputs.rs +++ b/verifiable-db/src/row_tree/public_inputs.rs @@ -92,7 +92,7 @@ impl<'a, T: Clone> PublicInputs<'a, T> { } pub const fn total_len() -> usize { - Self::to_range(RowsTreePublicInputs::RowIdMultiplier).end + Self::to_range(RowsTreePublicInputs::MergeFlag).end } pub fn to_root_hash_raw(&self) -> &[T] { @@ -263,7 +263,7 @@ impl<'a> PublicInputs<'a, F> { } #[cfg(test)] -mod tests { +pub(crate) mod tests { use super::*; use mp2_common::{utils::ToFields, C, D, F}; use mp2_test::{ @@ -281,6 +281,38 @@ mod tests { use rand::{thread_rng, Rng}; use std::{array, slice}; + impl<'a> PublicInputs<'a, F> { + pub(crate) fn sample( + multiplier_digest: Point, + row_id_multiplier: BigUint, + min: usize, + max: usize, + is_merge: bool, + ) -> Vec { + let h = HashOut::rand().to_fields(); + let individual_digest = Point::rand(); + let [individual_digest, multiplier_digest] = + [individual_digest, multiplier_digest].map(|p| p.to_weierstrass().to_fields()); + let row_id_multiplier = row_id_multiplier + .to_u32_digits() + .into_iter() + .map(F::from_canonical_u32) + .collect_vec(); + let [min, max] = [min, max].map(|v| U256::from(v).to_fields()); + let merge = F::from_bool(is_merge); + PublicInputs::new( + &h, + &individual_digest, + &multiplier_digest, + &row_id_multiplier, + &min, + &max, + &[merge], + ) + .to_vec() + } + } + #[derive(Clone, Debug)] struct TestPublicInputs<'a> { exp_pi: &'a [F], diff --git a/verifiable-db/src/row_tree/row.rs b/verifiable-db/src/row_tree/row.rs index db19effcc..c4cb9569b 100644 --- a/verifiable-db/src/row_tree/row.rs +++ b/verifiable-db/src/row_tree/row.rs @@ -1,26 +1,78 @@ //! Row information for the rows tree -use crate::cells_tree::{Cell, CellWire, PublicInputs}; +use crate::cells_tree::{Cell, CellWire, PublicInputs as CellsPublicInputs}; use derive_more::Constructor; +use itertools::Itertools; use mp2_common::{ - poseidon::{empty_poseidon_hash, hash_to_int_target, H, HASH_TO_INT_LEN}, + poseidon::{empty_poseidon_hash, hash_to_int_target, hash_to_int_value, H, HASH_TO_INT_LEN}, serialization::{deserialize, serialize}, - types::CBuilder, + types::{CBuilder, CURVE_TARGET_LEN}, u256::UInt256Target, - utils::ToTargets, + utils::{FromFields, ToFields, ToTargets}, F, }; +use num::BigUint; use plonky2::{ + field::types::{Field, PrimeField64}, hash::hash_types::{HashOut, HashOutTarget}, iop::{ target::{BoolTarget, Target}, witness::{PartialWitness, WitnessWrite}, }, + plonk::config::Hasher, }; use plonky2_ecdsa::gadgets::{biguint::BigUintTarget, nonnative::CircuitBuilderNonNative}; -use plonky2_ecgfp5::gadgets::curve::{CircuitBuilderEcGFp5, CurveTarget}; +use plonky2_ecgfp5::{ + curve::{curve::Point, scalar_field::Scalar}, + gadgets::curve::{CircuitBuilderEcGFp5, CurveTarget}, +}; use serde::{Deserialize, Serialize}; +#[derive(Clone, Debug, Eq, PartialEq)] +pub(crate) struct RowDigest { + pub(crate) is_merge: bool, + pub(crate) row_id_multiplier: BigUint, + pub(crate) individual_vd: Point, + pub(crate) multiplier_vd: Point, +} + +impl FromFields for RowDigest { + fn from_fields(t: &[F]) -> Self { + let mut pos = 0; + + let is_merge = t[pos].is_nonzero(); + pos += 1; + + let row_id_multiplier = BigUint::new( + t[pos..pos + HASH_TO_INT_LEN] + .iter() + .map(|f| u32::try_from(f.to_canonical_u64()).unwrap()) + .collect_vec(), + ); + pos += HASH_TO_INT_LEN; + + let individual_vd = Point::from_fields(&t[pos..pos + CURVE_TARGET_LEN]); + pos += CURVE_TARGET_LEN; + + let multiplier_vd = Point::from_fields(&t[pos..pos + CURVE_TARGET_LEN]); + + Self { + is_merge, + row_id_multiplier, + individual_vd, + multiplier_vd, + } + } +} + +#[derive(Clone, Debug)] +pub(crate) struct RowDigestTarget { + pub(crate) is_merge: BoolTarget, + pub(crate) row_id_multiplier: BigUintTarget, + pub(crate) individual_vd: CurveTarget, + pub(crate) multiplier_vd: CurveTarget, +} + #[derive(Clone, Debug, Serialize, Deserialize, Constructor)] pub(crate) struct Row { pub(crate) cell: Cell, @@ -32,6 +84,52 @@ impl Row { self.cell.assign_wires(pw, &wires.cell); pw.set_hash_target(wires.row_unique_data, self.row_unique_data); } + + pub(crate) fn digest(&self, cells_pi: &CellsPublicInputs) -> RowDigest { + let metadata_digests = self.cell.split_metadata_digest(); + let values_digests = self.cell.split_values_digest(); + + let metadata_digests = metadata_digests.accumulate(&cells_pi.split_metadata_digest_point()); + let values_digests = values_digests.accumulate(&cells_pi.split_values_digest_point()); + + // Compute row ID for individual cells: + // row_id_individual = H2Int(row_unique_data || individual_md) + let inputs = self + .row_unique_data + .to_fields() + .into_iter() + .chain(metadata_digests.individual.to_fields()) + .collect_vec(); + let hash = H::hash_no_pad(&inputs); + let row_id_individual = hash_to_int_value(hash); + let row_id_individual = Scalar::from_noncanonical_biguint(row_id_individual); + + // Multiply row ID to individual value digest: + // individual_vd = row_id_individual * individual_vd + let individual_vd = values_digests.individual * row_id_individual; + + // Multiplier is always employed for set of scalar variables, and `row_unique_data` + // for such a set is always `H("")``, so we can hardocode it in the circuit: + // row_id_multiplier = H2Int(H("") || multiplier_md) + let empty_hash = empty_poseidon_hash(); + let inputs = empty_hash + .to_fields() + .into_iter() + .chain(metadata_digests.multiplier.to_fields()) + .collect_vec(); + let hash = H::hash_no_pad(&inputs); + let row_id_multiplier = hash_to_int_value(hash); + + let is_merge = values_digests.is_merge_case(); + let multiplier_vd = values_digests.multiplier; + + RowDigest { + is_merge, + row_id_multiplier, + individual_vd, + multiplier_vd, + } + } } #[derive(Clone, Debug, Serialize, Deserialize)] @@ -41,15 +139,6 @@ pub(crate) struct RowWire { pub(crate) row_unique_data: HashOutTarget, } -/// Row digest result -#[derive(Clone, Debug)] -pub(crate) struct RowDigest { - pub(crate) is_merge: BoolTarget, - pub(crate) row_id_multiplier: BigUintTarget, - pub(crate) individual_vd: CurveTarget, - pub(crate) multiplier_vd: CurveTarget, -} - impl RowWire { pub(crate) fn new(b: &mut CBuilder) -> Self { Self { @@ -66,7 +155,11 @@ impl RowWire { &self.cell.value } - pub(crate) fn digest(&self, b: &mut CBuilder, cells_pi: &PublicInputs) -> RowDigest { + pub(crate) fn digest( + &self, + b: &mut CBuilder, + cells_pi: &CellsPublicInputs, + ) -> RowDigestTarget { let metadata_digests = self.cell.split_metadata_digest(b); let values_digests = self.cell.split_values_digest(b); @@ -106,7 +199,7 @@ impl RowWire { let is_merge = values_digests.is_merge_case(b); let multiplier_vd = values_digests.multiplier; - RowDigest { + RowDigestTarget { is_merge, row_id_multiplier, individual_vd, @@ -116,6 +209,66 @@ impl RowWire { } #[cfg(test)] -mod test { - // gupeng +pub(crate) mod tests { + use super::*; + use mp2_common::{utils::FromFields, C, D, F}; + use mp2_test::circuit::{run_circuit, UserCircuit}; + use plonky2::field::types::Sample; + use rand::{thread_rng, Rng}; + + impl Row { + pub(crate) fn sample(is_multiplier: bool) -> Self { + let cell = Cell::sample(is_multiplier); + let row_unique_data = HashOut::rand(); + + Row::new(cell, row_unique_data) + } + } + + #[derive(Clone, Debug)] + struct TestRowCircuit<'a> { + row: &'a Row, + cells_pi: &'a [F], + } + + impl<'a> UserCircuit for TestRowCircuit<'a> { + // Row wire + cells PI + type Wires = (RowWire, Vec); + + fn build(b: &mut CBuilder) -> Self::Wires { + let row = RowWire::new(b); + let cells_proof = b.add_virtual_targets(CellsPublicInputs::::total_len()); + let cells_pi = CellsPublicInputs::from_slice(&cells_proof); + + let digest = row.digest(b, &cells_pi); + + b.register_public_input(digest.is_merge.target); + b.register_public_inputs(&digest.row_id_multiplier.to_targets()); + b.register_public_inputs(&digest.individual_vd.to_targets()); + b.register_public_inputs(&digest.multiplier_vd.to_targets()); + + (row, cells_proof) + } + + fn prove(&self, pw: &mut PartialWitness, wires: &Self::Wires) { + self.row.assign_wires(pw, &wires.0); + pw.set_target_arr(&wires.1, self.cells_pi); + } + } + + #[test] + fn test_rows_tree_row_circuit() { + let rng = &mut thread_rng(); + + let cells_pi = &CellsPublicInputs::sample(rng.gen()); + let row = &Row::sample(rng.gen()); + let exp_row_digest = row.digest(&CellsPublicInputs::from_slice(cells_pi)); + + let test_circuit = TestRowCircuit { row, cells_pi }; + + let proof = run_circuit::(test_circuit); + let row_digest = RowDigest::from_fields(&proof.public_inputs); + + assert_eq!(row_digest, exp_row_digest); + } } From d9f8825baf48b0719ba3a271c9cfba67b9b01c6f Mon Sep 17 00:00:00 2001 From: Steven Gu Date: Wed, 23 Oct 2024 12:13:49 +0800 Subject: [PATCH 04/16] Update tests of block tree. --- mp2-v1/tests/common/celltree.rs | 10 +++- mp2-v1/tests/common/index_tree.rs | 2 +- mp2-v1/tests/common/mod.rs | 4 +- mp2-v1/tests/common/rowtree.rs | 20 ++++++-- verifiable-db/src/block_tree/api.rs | 50 ++++++++++++------- verifiable-db/src/block_tree/leaf.rs | 58 +++++++++++++--------- verifiable-db/src/block_tree/mod.rs | 68 +++++++++++++++++--------- verifiable-db/src/block_tree/parent.rs | 41 ++++++++++------ 8 files changed, 165 insertions(+), 88 deletions(-) diff --git a/mp2-v1/tests/common/celltree.rs b/mp2-v1/tests/common/celltree.rs index ed21fd2ea..2a075c62d 100644 --- a/mp2-v1/tests/common/celltree.rs +++ b/mp2-v1/tests/common/celltree.rs @@ -10,7 +10,7 @@ use mp2_v1::{ row::{CellCollection, Row, RowPayload, RowTreeKey}, }, }; -use plonky2::plonk::config::GenericHashOut; +use plonky2::{field::types::Sample, hash::hash_types::HashOut, plonk::config::GenericHashOut}; use ryhope::storage::{ updatetree::{Next, UpdateTree}, RoEpochKvStorage, @@ -69,6 +69,8 @@ impl TestContext { cell.identifier(), cell.value(), column.multiplier, + // TODO: Check mpt_metadata = cell.hash? + HashOut::rand(), ), ); self.b.bench("indexing::cell_tree::leaf", || { @@ -94,6 +96,8 @@ impl TestContext { cell.identifier(), cell.value(), column.multiplier, + // TODO: Check mpt_metadata = cell.hash? + HashOut::rand(), left_proof.clone(), ), ); @@ -149,6 +153,8 @@ impl TestContext { cell.identifier(), cell.value(), column.multiplier, + // TODO: Check mpt_metadata = cell.hash? + HashOut::rand(), [left_proof, right_proof], ), ); @@ -171,7 +177,7 @@ impl TestContext { "[+] [+] Merkle SLOT identifier {:?} -> value {} value.digest() = {:?}", cell.identifier(), cell.value(), - pi.individual_digest_point() + pi.individual_values_digest_point() ); self.storage diff --git a/mp2-v1/tests/common/index_tree.rs b/mp2-v1/tests/common/index_tree.rs index 5bfda6bc6..63e043e1e 100644 --- a/mp2-v1/tests/common/index_tree.rs +++ b/mp2-v1/tests/common/index_tree.rs @@ -86,7 +86,7 @@ impl TestContext { // TODO: Fix the rows digest in rows tree according to values extraction update. // assert_eq!( - row_pi.rows_digest_field(), + row_pi.individual_digest_point(), ext_pi.value_point(), "values extracted vs value in db don't match (left row, right mpt (block {})", node.value.0.to::() diff --git a/mp2-v1/tests/common/mod.rs b/mp2-v1/tests/common/mod.rs index 77eac7752..11409fecc 100644 --- a/mp2-v1/tests/common/mod.rs +++ b/mp2-v1/tests/common/mod.rs @@ -51,7 +51,7 @@ fn cell_tree_proof_to_hash(proof: &[u8]) -> HashOutput { .proof .public_inputs; verifiable_db::cells_tree::PublicInputs::from_slice(&root_pi) - .root_hash_hashout() + .node_hash() .to_bytes() .try_into() .unwrap() @@ -63,7 +63,7 @@ fn row_tree_proof_to_hash(proof: &[u8]) -> HashOutput { .proof .public_inputs; verifiable_db::row_tree::PublicInputs::from_slice(&root_pi) - .root_hash_hashout() + .root_hash() .to_bytes() .try_into() .unwrap() diff --git a/mp2-v1/tests/common/rowtree.rs b/mp2-v1/tests/common/rowtree.rs index fd0881162..88131e69f 100644 --- a/mp2-v1/tests/common/rowtree.rs +++ b/mp2-v1/tests/common/rowtree.rs @@ -11,7 +11,7 @@ use mp2_v1::{ row::{RowPayload, RowTree, RowTreeKey, ToNonce}, }, }; -use plonky2::plonk::config::GenericHashOut; +use plonky2::{field::types::Sample, hash::hash_types::HashOut, plonk::config::GenericHashOut}; use ryhope::{ storage::{ pgsql::PgsqlStorage, @@ -132,8 +132,8 @@ impl TestContext { let pis = cells_tree::PublicInputs::from_slice(&pvk.proof().public_inputs); debug!( " Cell Root SPLIT digest: multiplier {:?}, individual {:?}", - pis.multiplier_digest_point(), - pis.individual_digest_point() + pis.multiplier_values_digest_point(), + pis.individual_values_digest_point() ); } @@ -151,6 +151,10 @@ impl TestContext { id, value, multiplier, + // TODO: mpt_metadata + HashOut::rand(), + // TODO: row_unique_data + HashOut::rand(), cell_tree_proof, ) .unwrap(), @@ -187,6 +191,10 @@ impl TestContext { value, multiplier, context.left.is_some(), + // TODO: mpt_metadata + HashOut::rand(), + // TODO: row_unique_data + HashOut::rand(), child_proof, cell_tree_proof, ) @@ -229,6 +237,10 @@ impl TestContext { id, value, multiplier, + // TODO: mpt_metadata + HashOut::rand(), + // TODO: row_unique_data + HashOut::rand(), left_proof, right_proof, cell_tree_proof, @@ -277,7 +289,7 @@ impl TestContext { let pi = verifiable_db::row_tree::PublicInputs::from_slice(&pproof.proof().public_inputs); debug!( "[--] FINAL MERKLE DIGEST VALUE --> {:?} ", - pi.rows_digest_field() + pi.individual_digest_point() ); if root_proof_key.primary != primary { debug!("[--] NO UPDATES on row this turn? row.root().primary = {} vs new primary proving step {}",root_proof_key.primary,primary); diff --git a/verifiable-db/src/block_tree/api.rs b/verifiable-db/src/block_tree/api.rs index 023494840..e8c7d33dd 100644 --- a/verifiable-db/src/block_tree/api.rs +++ b/verifiable-db/src/block_tree/api.rs @@ -274,7 +274,10 @@ mod tests { *, }; use crate::{ - block_tree::leaf::tests::{compute_expected_hash, compute_expected_set_digest}, + block_tree::{ + compute_final_digest, + leaf::tests::{compute_expected_hash, compute_expected_set_digest}, + }, extraction, row_tree, }; use mp2_common::{ @@ -288,7 +291,6 @@ mod tests { iop::target::Target, plonk::config::Hasher, }; - use plonky2_ecgfp5::curve::curve::Point; use rand::{rngs::ThreadRng, thread_rng, Rng}; use recursion_framework::framework_testing::TestingRecursiveCircuits; use std::iter; @@ -346,10 +348,9 @@ mod tests { fn generate_rows_tree_proof( &self, rng: &mut ThreadRng, - row_digest: &[F], is_merge_case: bool, ) -> Result { - let pi = random_rows_tree_pi(rng, row_digest, is_merge_case); + let pi = random_rows_tree_pi(rng, is_merge_case); let proof = self .rows_tree_set @@ -358,20 +359,23 @@ mod tests { Ok(ProofWithVK::from((proof[0].clone(), vk))) } + fn generate_leaf_proof( &self, rng: &mut ThreadRng, block_id: F, block_number: U256, ) -> Result { - let row_digest = Point::sample(rng).to_weierstrass().to_fields(); + let rows_tree_proof = self.generate_rows_tree_proof(rng, true)?; + let rows_tree_pi = + row_tree::PublicInputs::from_slice(&rows_tree_proof.proof.public_inputs); + let final_digest = compute_final_digest(true, &rows_tree_pi) + .to_weierstrass() + .to_fields(); let extraction_proof = - self.generate_extraction_proof(rng, block_number, &row_digest, true)?; - let rows_tree_proof = self.generate_rows_tree_proof(rng, &row_digest, true)?; + self.generate_extraction_proof(rng, block_number, &final_digest, true)?; let extraction_pi = extraction::test::PublicInputs::from_slice(&extraction_proof.proof.public_inputs); - let rows_tree_pi = - row_tree::PublicInputs::from_slice(&rows_tree_proof.proof.public_inputs); let input = CircuitInput::new_leaf( block_id.to_canonical_u64(), @@ -438,8 +442,12 @@ mod tests { } // Check new node digest { - let exp_digest = - compute_expected_set_digest(block_id, block_number.to_vec(), rows_tree_pi); + let exp_digest = compute_expected_set_digest( + true, + block_id, + block_number.to_vec(), + rows_tree_pi, + ); assert_eq!(pi.new_value_set_digest_point(), exp_digest.to_weierstrass()); } @@ -458,14 +466,16 @@ mod tests { left_child: HashOut, right_child: HashOut, ) -> Result { - let row_digest = Point::sample(rng).to_weierstrass().to_fields(); + let rows_tree_proof = self.generate_rows_tree_proof(rng, false)?; + let rows_tree_pi = + row_tree::PublicInputs::from_slice(&rows_tree_proof.proof.public_inputs); + let final_digest = compute_final_digest(false, &rows_tree_pi) + .to_weierstrass() + .to_fields(); let extraction_proof = - self.generate_extraction_proof(rng, block_number, &row_digest, false)?; - let rows_tree_proof = self.generate_rows_tree_proof(rng, &row_digest, false)?; + self.generate_extraction_proof(rng, block_number, &final_digest, false)?; let extraction_pi = extraction::test::PublicInputs::from_slice(&extraction_proof.proof.public_inputs); - let rows_tree_pi = - row_tree::PublicInputs::from_slice(&rows_tree_proof.proof.public_inputs); let old_rows_tree_hash = HashOut::from_vec(random_vector::(NUM_HASH_OUT_ELTS).to_fields()); @@ -553,8 +563,12 @@ mod tests { } // Check new node digest { - let exp_digest = - compute_expected_set_digest(block_id, block_number.to_vec(), rows_tree_pi); + let exp_digest = compute_expected_set_digest( + false, + block_id, + block_number.to_vec(), + rows_tree_pi, + ); assert_eq!(pi.new_value_set_digest_point(), exp_digest.to_weierstrass()); } diff --git a/verifiable-db/src/block_tree/leaf.rs b/verifiable-db/src/block_tree/leaf.rs index b6966047a..d9080e58e 100644 --- a/verifiable-db/src/block_tree/leaf.rs +++ b/verifiable-db/src/block_tree/leaf.rs @@ -2,7 +2,7 @@ //! an existing node (or if there is no existing node, which happens for the //! first block number). -use super::{compute_final_digest, compute_index_digest, public_inputs::PublicInputs}; +use super::{compute_final_digest_target, compute_index_digest, public_inputs::PublicInputs}; use crate::{ extraction::{ExtractionPI, ExtractionPIWrap}, row_tree, @@ -54,7 +54,7 @@ impl LeafCircuit { let extraction_pi = E::PI::from_slice(extraction_pi); let rows_tree_pi = row_tree::PublicInputs::::from_slice(rows_tree_pi); - let final_digest = compute_final_digest::(b, &extraction_pi, &rows_tree_pi); + let final_digest = compute_final_digest_target::(b, &extraction_pi, &rows_tree_pi); // in our case, the extraction proofs extracts from the blockchain and sets // the block number as the primary index @@ -208,29 +208,24 @@ where #[cfg(test)] pub mod tests { - use crate::{ - block_tree::tests::{TestPIField, TestPITargets}, - extraction, - }; - use super::{ super::tests::{random_extraction_pi, random_rows_tree_pi}, *, }; + use crate::{ + block_tree::{ + compute_final_digest, + tests::{TestPIField, TestPITargets}, + }, + extraction, + }; use alloy::primitives::U256; use mp2_common::{ poseidon::{hash_to_int_value, H}, utils::{Fieldable, ToFields}, }; - use mp2_test::{ - circuit::{run_circuit, UserCircuit}, - utils::weierstrass_to_point, - }; - use plonky2::{ - field::types::{Field, Sample}, - hash::hash_types::HashOut, - plonk::config::Hasher, - }; + use mp2_test::circuit::{run_circuit, UserCircuit}; + use plonky2::{field::types::Field, hash::hash_types::HashOut, plonk::config::Hasher}; use plonky2_ecgfp5::curve::{curve::Point, scalar_field::Scalar}; use rand::{thread_rng, Rng}; @@ -248,6 +243,7 @@ pub mod tests { } pub fn compute_expected_set_digest( + is_merge_case: bool, identifier: F, value: Vec, rows_tree_pi: row_tree::PublicInputs, @@ -258,8 +254,7 @@ pub mod tests { let hash = H::hash_no_pad(&inputs); let int = hash_to_int_value(hash); let scalar = Scalar::from_noncanonical_biguint(int); - let point = rows_tree_pi.individual_digest_point(); - let point = weierstrass_to_point(&point); + let point = compute_final_digest(is_merge_case, &rows_tree_pi); point * scalar } #[derive(Clone, Debug)] @@ -295,14 +290,27 @@ pub mod tests { #[test] fn test_block_index_leaf_circuit() { + test_leaf_circuit(true); + test_leaf_circuit(false); + } + + fn test_leaf_circuit(is_merge_case: bool) { let mut rng = thread_rng(); let block_id = rng.gen::().to_field(); let block_number = U256::from_limbs(rng.gen::<[u64; 4]>()); - let row_digest = Point::sample(&mut rng).to_weierstrass().to_fields(); - let extraction_pi = &random_extraction_pi(&mut rng, block_number, &row_digest, false); - let rows_tree_pi = &random_rows_tree_pi(&mut rng, &row_digest, false); + let rows_tree_pi = &random_rows_tree_pi(&mut rng, is_merge_case); + let final_digest = compute_final_digest( + is_merge_case, + &row_tree::PublicInputs::from_slice(rows_tree_pi), + ); + let extraction_pi = &random_extraction_pi( + &mut rng, + block_number, + &final_digest.to_fields(), + is_merge_case, + ); let test_circuit = TestLeafCircuit { c: LeafCircuit { @@ -368,8 +376,12 @@ pub mod tests { } // Check new node digest { - let exp_digest = - compute_expected_set_digest(block_id, block_number.to_vec(), rows_tree_pi); + let exp_digest = compute_expected_set_digest( + is_merge_case, + block_id, + block_number.to_vec(), + rows_tree_pi, + ); assert_eq!(pi.new_value_set_digest_point(), exp_digest.to_weierstrass()); } } diff --git a/verifiable-db/src/block_tree/mod.rs b/verifiable-db/src/block_tree/mod.rs index 6a65418fa..946b133ee 100644 --- a/verifiable-db/src/block_tree/mod.rs +++ b/verifiable-db/src/block_tree/mod.rs @@ -10,15 +10,22 @@ use crate::{ }; pub use api::{CircuitInput, PublicParameters}; use mp2_common::{ - group_hashing::{circuit_hashed_scalar_mul, CircuitBuilderGroupHashing}, + group_hashing::{ + circuit_hashed_scalar_mul, field_hashed_scalar_mul, weierstrass_to_point, + CircuitBuilderGroupHashing, + }, poseidon::hash_to_int_target, types::CBuilder, + utils::ToFields, CHasher, D, F, }; use plonky2::{field::types::Field, iop::target::Target, plonk::circuit_builder::CircuitBuilder}; use plonky2_ecdsa::gadgets::nonnative::CircuitBuilderNonNative; -use plonky2_ecgfp5::gadgets::curve::{CircuitBuilderEcGFp5, CurveTarget}; +use plonky2_ecgfp5::{ + curve::{curve::Point, scalar_field::Scalar}, + gadgets::curve::{CircuitBuilderEcGFp5, CurveTarget}, +}; pub use public_inputs::PublicInputs; /// Common function to compute the digest of the block tree which uses a special format using @@ -34,8 +41,28 @@ pub(crate) fn compute_index_digest( b.curve_scalar_mul(base, &scalar) } -/// Compute the final digest. -pub(crate) fn compute_final_digest<'a, E>( +/// Compute the final digest value. +pub(crate) fn compute_final_digest( + is_merge_case: bool, + rows_tree_pi: &row_tree::PublicInputs, +) -> Point { + let individual_digest = weierstrass_to_point(&rows_tree_pi.individual_digest_point()); + if !is_merge_case { + return individual_digest; + } + + // Compute the final row digest from rows_tree_proof for merge case: + // multiplier_digest = rows_tree_proof.row_id_multiplier * rows_tree_proof.multiplier_vd + let multiplier_vd = weierstrass_to_point(&rows_tree_pi.multiplier_digest_point()); + let row_id_multiplier = Scalar::from_noncanonical_biguint(rows_tree_pi.row_id_multiplier()); + let multiplier_digest = multiplier_vd * row_id_multiplier; + // rows_digest_merge = multiplier_digest * rows_tree_proof.DR + let individual_digest = weierstrass_to_point(&rows_tree_pi.individual_digest_point()); + field_hashed_scalar_mul(multiplier_digest.to_fields(), individual_digest) +} + +/// Compute the final digest target. +pub(crate) fn compute_final_digest_target<'a, E>( b: &mut CBuilder, extraction_pi: &E::PI<'a>, rows_tree_pi: &row_tree::PublicInputs, @@ -91,6 +118,7 @@ pub(crate) mod tests { use alloy::primitives::U256; use mp2_common::{keccak::PACKED_HASH_LEN, poseidon::HASH_TO_INT_LEN, utils::ToFields, F}; use mp2_test::utils::random_vector; + use num::BigUint; use plonky2::{ field::types::{Field, Sample}, hash::hash_types::NUM_HASH_OUT_ELTS, @@ -98,6 +126,7 @@ pub(crate) mod tests { }; use plonky2_ecgfp5::curve::curve::Point; use rand::{rngs::ThreadRng, Rng}; + use std::array; use crate::row_tree; @@ -132,27 +161,18 @@ pub(crate) mod tests { } /// Generate a random rows tree public inputs. - pub(crate) fn random_rows_tree_pi( - rng: &mut ThreadRng, - row_digest: &[F], - is_merge_case: bool, - ) -> Vec { - let h = random_vector::(NUM_HASH_OUT_ELTS).to_fields(); - let [min, max] = [0; 2].map(|_| U256::from_limbs(rng.gen::<[u64; 4]>()).to_fields()); - let is_merge = [F::from_canonical_usize(is_merge_case as usize)]; - let multiplier_digest = Point::sample(rng).to_weierstrass().to_fields(); - let row_id_multiplier = random_vector::(HASH_TO_INT_LEN).to_fields(); - - row_tree::PublicInputs::new( - &h, - row_digest, - &min, - &max, - &is_merge, - &multiplier_digest, - &row_id_multiplier, + pub(crate) fn random_rows_tree_pi(rng: &mut ThreadRng, is_merge_case: bool) -> Vec { + let [min, max] = array::from_fn(|_| rng.gen()); + let multiplier_digest = Point::rand(); + let row_id_multiplier = BigUint::from_slice(&random_vector::(HASH_TO_INT_LEN)); + + row_tree::PublicInputs::sample( + multiplier_digest, + row_id_multiplier, + min, + max, + is_merge_case, ) - .to_vec() } /// Generate a random extraction public inputs. diff --git a/verifiable-db/src/block_tree/parent.rs b/verifiable-db/src/block_tree/parent.rs index 68988e87f..0518a7692 100644 --- a/verifiable-db/src/block_tree/parent.rs +++ b/verifiable-db/src/block_tree/parent.rs @@ -1,7 +1,7 @@ //! This circuit is employed when the new node is inserted as parent of an existing node, //! referred to as old node. -use super::{compute_final_digest, compute_index_digest, public_inputs::PublicInputs}; +use super::{compute_final_digest_target, compute_index_digest, public_inputs::PublicInputs}; use crate::{ extraction::{ExtractionPI, ExtractionPIWrap}, row_tree, @@ -83,7 +83,7 @@ impl ParentCircuit { let extraction_pi = E::PI::from_slice(extraction_pi); let rows_tree_pi = row_tree::PublicInputs::::from_slice(rows_tree_pi); - let final_digest = compute_final_digest::(b, &extraction_pi, &rows_tree_pi); + let final_digest = compute_final_digest_target::(b, &extraction_pi, &rows_tree_pi); let block_number = extraction_pi.primary_index_value(); @@ -276,6 +276,7 @@ where #[cfg(test)] mod tests { use crate::block_tree::{ + compute_final_digest, leaf::tests::{compute_expected_hash, compute_expected_set_digest}, tests::{TestPIField, TestPITargets}, }; @@ -292,10 +293,7 @@ mod tests { circuit::{run_circuit, UserCircuit}, utils::random_vector, }; - use plonky2::{ - field::types::Sample, hash::hash_types::NUM_HASH_OUT_ELTS, plonk::config::Hasher, - }; - use plonky2_ecgfp5::curve::curve::Point; + use plonky2::{hash::hash_types::NUM_HASH_OUT_ELTS, plonk::config::Hasher}; use rand::{thread_rng, Rng}; #[derive(Clone, Debug)] @@ -332,18 +330,29 @@ mod tests { #[test] fn test_block_index_parent_circuit() { + test_parent_circuit(true); + test_parent_circuit(false); + } + + fn test_parent_circuit(is_merge_case: bool) { let mut rng = thread_rng(); let index_identifier = rng.gen::().to_field(); - let [old_index_value, old_min, old_max] = - [0; 3].map(|_| U256::from_limbs(rng.gen::<[u64; 4]>())); + let [old_index_value, old_min, old_max] = [0; 3].map(|_| U256::from_limbs(rng.gen())); let [left_child, right_child, old_rows_tree_hash] = [0; 3].map(|_| HashOut::from_vec(random_vector::(NUM_HASH_OUT_ELTS).to_fields())); - let row_digest = Point::sample(&mut rng).to_weierstrass().to_fields(); - let extraction_pi = - &random_extraction_pi(&mut rng, old_max + U256::from(1), &row_digest, true); - let rows_tree_pi = &random_rows_tree_pi(&mut rng, &row_digest, true); + let rows_tree_pi = &random_rows_tree_pi(&mut rng, is_merge_case); + let final_digest = compute_final_digest( + is_merge_case, + &row_tree::PublicInputs::from_slice(rows_tree_pi), + ); + let extraction_pi = &random_extraction_pi( + &mut rng, + old_max + U256::from(1), + &final_digest.to_fields(), + is_merge_case, + ); let test_circuit = TestParentCircuit { c: ParentCircuit { @@ -429,8 +438,12 @@ mod tests { } // Check new node digest { - let exp_digest = - compute_expected_set_digest(index_identifier, block_number.to_vec(), rows_tree_pi); + let exp_digest = compute_expected_set_digest( + is_merge_case, + index_identifier, + block_number.to_vec(), + rows_tree_pi, + ); assert_eq!(pi.new_value_set_digest_point(), exp_digest.to_weierstrass()); } From f59c5d70f39a32e7aed1077ef85550232177acbd Mon Sep 17 00:00:00 2001 From: Steven Gu Date: Wed, 23 Oct 2024 13:38:41 +0800 Subject: [PATCH 05/16] Fix test. --- parsil/src/tests.rs | 1 + verifiable-db/src/block_tree/mod.rs | 88 +++++++++++++++++++++++++++-- 2 files changed, 83 insertions(+), 6 deletions(-) diff --git a/parsil/src/tests.rs b/parsil/src/tests.rs index 84f0fa766..dc456f663 100644 --- a/parsil/src/tests.rs +++ b/parsil/src/tests.rs @@ -149,6 +149,7 @@ fn test_serde_circuit_pis() { } #[test] +#[ignore = "wait for non-aggregation SELECT to come back"] fn isolation() { fn isolated_to_string(q: &str, lo_sec: bool, hi_sec: bool) -> String { let settings = ParsilSettings { diff --git a/verifiable-db/src/block_tree/mod.rs b/verifiable-db/src/block_tree/mod.rs index 946b133ee..72c79fd4d 100644 --- a/verifiable-db/src/block_tree/mod.rs +++ b/verifiable-db/src/block_tree/mod.rs @@ -115,21 +115,33 @@ where #[cfg(test)] pub(crate) mod tests { + use super::*; + use crate::row_tree; use alloy::primitives::U256; - use mp2_common::{keccak::PACKED_HASH_LEN, poseidon::HASH_TO_INT_LEN, utils::ToFields, F}; - use mp2_test::utils::random_vector; + use mp2_common::{ + keccak::PACKED_HASH_LEN, + poseidon::HASH_TO_INT_LEN, + types::CBuilder, + utils::{FromFields, ToFields}, + C, F, + }; + use mp2_test::{ + circuit::{run_circuit, UserCircuit}, + utils::random_vector, + }; use num::BigUint; use plonky2::{ field::types::{Field, Sample}, hash::hash_types::NUM_HASH_OUT_ELTS, - iop::target::Target, + iop::{ + target::Target, + witness::{PartialWitness, WitnessWrite}, + }, }; use plonky2_ecgfp5::curve::curve::Point; - use rand::{rngs::ThreadRng, Rng}; + use rand::{rngs::ThreadRng, thread_rng, Rng}; use std::array; - use crate::row_tree; - pub(crate) type TestPITargets<'a> = crate::extraction::test::PublicInputs<'a, Target>; pub(crate) type TestPIField<'a> = crate::extraction::test::PublicInputs<'a, F>; @@ -196,4 +208,68 @@ pub(crate) mod tests { ) .to_vec() } + + #[derive(Clone, Debug)] + struct TestFinalDigestCircuit<'a> { + extraction_pi: &'a [F], + rows_tree_pi: &'a [F], + } + + impl<'a> UserCircuit for TestFinalDigestCircuit<'a> { + // Extraction PI + rows tree PI + type Wires = (Vec, Vec); + + fn build(b: &mut CBuilder) -> Self::Wires { + let extraction_pi = b.add_virtual_targets(TestPITargets::TOTAL_LEN); + let rows_tree_pi = b.add_virtual_targets(row_tree::PublicInputs::::total_len()); + + let final_digest = compute_final_digest_target::( + b, + &TestPITargets::from_slice(&extraction_pi), + &row_tree::PublicInputs::from_slice(&rows_tree_pi), + ); + + b.register_curve_public_input(final_digest); + + (extraction_pi, rows_tree_pi) + } + + fn prove(&self, pw: &mut PartialWitness, wires: &Self::Wires) { + pw.set_target_arr(&wires.0, self.extraction_pi); + pw.set_target_arr(&wires.1, self.rows_tree_pi); + } + } + + #[test] + fn test_block_tree_final_digest() { + test_final_digest(true); + test_final_digest(false); + } + + fn test_final_digest(is_merge_case: bool) { + let rng = &mut thread_rng(); + + let rows_tree_pi = &random_rows_tree_pi(rng, is_merge_case); + let exp_final_digest = compute_final_digest( + is_merge_case, + &row_tree::PublicInputs::from_slice(rows_tree_pi), + ); + let block_number = U256::from_limbs(rng.gen()); + let extraction_pi = &random_extraction_pi( + rng, + block_number, + &exp_final_digest.to_fields(), + is_merge_case, + ); + + let test_circuit = TestFinalDigestCircuit { + extraction_pi, + rows_tree_pi, + }; + + let proof = run_circuit::(test_circuit); + let final_digest = Point::from_fields(&proof.public_inputs); + + assert_eq!(final_digest, exp_final_digest); + } } From e651ab9d0793ec9154f4c0f6da6ea67211740349 Mon Sep 17 00:00:00 2001 From: Steven Gu Date: Fri, 25 Oct 2024 16:38:50 +0800 Subject: [PATCH 06/16] Add missing `value` for hash in rows tree. --- verifiable-db/src/row_tree/api.rs | 3 +++ verifiable-db/src/row_tree/full_node.rs | 3 +++ verifiable-db/src/row_tree/leaf.rs | 2 ++ verifiable-db/src/row_tree/partial_node.rs | 2 ++ 4 files changed, 10 insertions(+) diff --git a/verifiable-db/src/row_tree/api.rs b/verifiable-db/src/row_tree/api.rs index 56d5613c4..81f28c394 100644 --- a/verifiable-db/src/row_tree/api.rs +++ b/verifiable-db/src/row_tree/api.rs @@ -475,6 +475,7 @@ mod test { .chain(node_min.to_fields()) .chain(node_max.to_fields()) .chain(once(id)) + .chain(value.to_fields()) .chain(p.cells_pi().node_hash().to_fields()) .collect_vec(); let exp_root_hash = H::hash_no_pad(&inputs); @@ -542,6 +543,7 @@ mod test { .chain(left_pi.min_value().to_fields()) .chain(right_pi.max_value().to_fields()) .chain(once(id)) + .chain(value.to_fields()) .chain(p.cells_pi().node_hash().to_fields()) .collect_vec(); let hash = H::hash_no_pad(&inputs); @@ -600,6 +602,7 @@ mod test { .chain(value.iter()) .chain(value.iter()) .chain(once(&id)) + .chain(value.iter()) .chain(p.cells_pi().to_node_hash_raw()) .cloned() .collect_vec(); diff --git a/verifiable-db/src/row_tree/full_node.rs b/verifiable-db/src/row_tree/full_node.rs index 272c7b779..cedb65a1f 100644 --- a/verifiable-db/src/row_tree/full_node.rs +++ b/verifiable-db/src/row_tree/full_node.rs @@ -75,6 +75,7 @@ impl FullNodeCircuit { .chain(node_min.to_targets().iter()) .chain(node_max.to_targets().iter()) .chain(once(&id)) + .chain(value.to_targets().iter()) .chain(cells_pi.node_hash_target().iter()) .cloned() .collect::>(); @@ -194,6 +195,7 @@ pub(crate) mod test { let mut row = Row::sample(is_multiplier); row.cell.value = U256::from(18); let id = row.cell.identifier; + let value = row.cell.value; let cells_pi = cells_tree::PublicInputs::sample(cells_multiplier); // Compute the row digest. let row_digest = row.digest(&cells_tree::PublicInputs::from_slice(&cells_pi)); @@ -238,6 +240,7 @@ pub(crate) mod test { .chain(left_pi.min_value().to_fields()) .chain(right_pi.max_value().to_fields()) .chain(once(id)) + .chain(value.to_fields()) .chain(cells_pi.node_hash().to_fields()) .collect_vec(); let hash = H::hash_no_pad(&inputs); diff --git a/verifiable-db/src/row_tree/leaf.rs b/verifiable-db/src/row_tree/leaf.rs index 48dd243aa..53bd99a10 100644 --- a/verifiable-db/src/row_tree/leaf.rs +++ b/verifiable-db/src/row_tree/leaf.rs @@ -52,6 +52,7 @@ impl LeafCircuit { .chain(value.clone()) .chain(value.clone()) .chain(once(id)) + .chain(value.clone()) .chain(cells_pis.node_hash_target()) .collect::>(); let row_hash = b.hash_n_to_hash_no_pad::(inputs); @@ -188,6 +189,7 @@ mod test { .chain(value.iter()) .chain(value.iter()) .chain(once(&id)) + .chain(value.iter()) .chain(cells_pi.to_node_hash_raw()) .cloned() .collect_vec(); diff --git a/verifiable-db/src/row_tree/partial_node.rs b/verifiable-db/src/row_tree/partial_node.rs index 24bc9143f..1044b7ad0 100644 --- a/verifiable-db/src/row_tree/partial_node.rs +++ b/verifiable-db/src/row_tree/partial_node.rs @@ -98,6 +98,7 @@ impl PartialNodeCircuit { .iter() .chain(node_max.to_targets().iter()) .chain(once(&id)) + .chain(value.to_targets().iter()) .chain(cells_pi.node_hash_target().iter()) .cloned() .collect::>(); @@ -331,6 +332,7 @@ pub mod test { .chain(node_min.to_fields()) .chain(node_max.to_fields()) .chain(once(id)) + .chain(value.to_fields()) .chain(cells_pi.node_hash().to_fields()) .collect_vec(); let exp_root_hash = H::hash_no_pad(&inputs); From 2717f9830b547957532ea319ce1886dfb3d3c7bd Mon Sep 17 00:00:00 2001 From: Steven Date: Wed, 30 Oct 2024 08:41:09 +0800 Subject: [PATCH 07/16] Update verifiable-db/src/row_tree/public_inputs.rs Co-authored-by: nicholas-mainardi --- verifiable-db/src/row_tree/public_inputs.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/verifiable-db/src/row_tree/public_inputs.rs b/verifiable-db/src/row_tree/public_inputs.rs index 52ed490dd..a67951aec 100644 --- a/verifiable-db/src/row_tree/public_inputs.rs +++ b/verifiable-db/src/row_tree/public_inputs.rs @@ -72,7 +72,7 @@ impl<'a, T: Clone> PublicInputs<'a, T> { CURVE_TARGET_LEN, // `H2Int(H("") || multiplier_md)`, where `multiplier_md` is the metadata digest of cells accumulated in `multiplier_digest` HASH_TO_INT_LEN, - // Minimum alue of the secondary index stored up to this node + // Minimum value of the secondary index stored up to this node u256::NUM_LIMBS, // Maximum value of the secondary index stored up to this node u256::NUM_LIMBS, From a5ae612f2a36e1470dce1e0ac715ec3c1d27bc80 Mon Sep 17 00:00:00 2001 From: Steven Gu Date: Wed, 30 Oct 2024 08:27:25 +0800 Subject: [PATCH 08/16] Fix to ensure `extraction_proof.is_merge or rows_tree_proof.multiplier_vd == 0`. --- verifiable-db/src/block_tree/mod.rs | 18 ++++-------------- 1 file changed, 4 insertions(+), 14 deletions(-) diff --git a/verifiable-db/src/block_tree/mod.rs b/verifiable-db/src/block_tree/mod.rs index 72c79fd4d..7b7541218 100644 --- a/verifiable-db/src/block_tree/mod.rs +++ b/verifiable-db/src/block_tree/mod.rs @@ -94,21 +94,11 @@ where // Enforce that if we aren't in merge case, then no cells were accumulated in // multiplier digest: - // assert extraction_proof.is_merge or rows_tree_proof.multiplier_vd != 0 - // => (1 - is_merge) * is_multiplier_vd_zero == false - let ffalse = b._false(); + // assert extraction_proof.is_merge or rows_tree_proof.multiplier_vd == 0 let curve_zero = b.curve_zero(); - let is_multiplier_vd_zero = b - .curve_eq(rows_tree_pi.multiplier_digest_target(), curve_zero) - .target; - let should_be_false = b.arithmetic( - F::NEG_ONE, - F::ONE, - extraction_pi.is_merge_case().target, - is_multiplier_vd_zero, - is_multiplier_vd_zero, - ); - b.connect(should_be_false, ffalse.target); + let is_multiplier_vd_zero = b.curve_eq(rows_tree_pi.multiplier_digest_target(), curve_zero); + let acc = b.or(extraction_pi.is_merge_case(), is_multiplier_vd_zero); + b.assert_one(acc.target); final_digest } From ca0a97126c93b827a645ec4f17e3a049939c9eb6 Mon Sep 17 00:00:00 2001 From: Steven Gu Date: Wed, 30 Oct 2024 15:39:25 +0800 Subject: [PATCH 09/16] Fix to use `split_and_accumulate`. --- verifiable-db/src/cells_tree/mod.rs | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/verifiable-db/src/cells_tree/mod.rs b/verifiable-db/src/cells_tree/mod.rs index 3cafa8c70..10a01d846 100644 --- a/verifiable-db/src/cells_tree/mod.rs +++ b/verifiable-db/src/cells_tree/mod.rs @@ -126,18 +126,18 @@ impl CellWire { pub fn split_and_accumulate_metadata_digest( &self, b: &mut CBuilder, - child_digest: SplitDigestTarget, + child_digest: &SplitDigestTarget, ) -> SplitDigestTarget { let split_digest = self.split_metadata_digest(b); - split_digest.accumulate(b, &child_digest) + split_digest.accumulate(b, child_digest) } pub fn split_and_accumulate_values_digest( &self, b: &mut CBuilder, - child_digest: SplitDigestTarget, + child_digest: &SplitDigestTarget, ) -> SplitDigestTarget { let split_digest = self.split_values_digest(b); - split_digest.accumulate(b, &child_digest) + split_digest.accumulate(b, child_digest) } fn metadata_digest(&self, b: &mut CBuilder) -> CurveTarget { // D(mpt_metadata || identifier) @@ -214,11 +214,9 @@ pub(crate) mod tests { }; let cell = CellWire::new(b); - let values_digest = cell.split_values_digest(b); - let metadata_digest = cell.split_metadata_digest(b); - - let values_digest = values_digest.accumulate(b, &child_values_digest); - let metadata_digest = metadata_digest.accumulate(b, &child_metadata_digest); + let values_digest = cell.split_and_accumulate_values_digest(b, &child_values_digest); + let metadata_digest = + cell.split_and_accumulate_metadata_digest(b, &child_metadata_digest); b.register_curve_public_input(values_digest.individual); b.register_curve_public_input(values_digest.multiplier); From 2c94c126402036c2639511c7151bff3747bd73a8 Mon Sep 17 00:00:00 2001 From: Steven Gu Date: Wed, 30 Oct 2024 15:53:02 +0800 Subject: [PATCH 10/16] Fix the test coverage for the all cells node circuits. --- verifiable-db/src/cells_tree/full_node.rs | 39 +++++++++++++++++--- verifiable-db/src/cells_tree/partial_node.rs | 16 +++++--- 2 files changed, 45 insertions(+), 10 deletions(-) diff --git a/verifiable-db/src/cells_tree/full_node.rs b/verifiable-db/src/cells_tree/full_node.rs index a35b29408..42584a534 100644 --- a/verifiable-db/src/cells_tree/full_node.rs +++ b/verifiable-db/src/cells_tree/full_node.rs @@ -128,19 +128,48 @@ mod tests { } #[test] - fn test_cells_tree_full_node_circuit() { - test_cells_tree_full_multiplier(true); - test_cells_tree_full_multiplier(false); + fn test_cells_tree_full_node_individual() { + [true, false] + .into_iter() + .cartesian_product([true, false]) + .for_each(|(is_left_child_multiplier, is_right_child_multiplier)| { + test_cells_tree_full_multiplier( + false, + is_left_child_multiplier, + is_right_child_multiplier, + ); + }); + } + + #[test] + fn test_cells_tree_full_node_multiplier() { + [true, false] + .into_iter() + .cartesian_product([true, false]) + .for_each(|(is_left_child_multiplier, is_right_child_multiplier)| { + test_cells_tree_full_multiplier( + true, + is_left_child_multiplier, + is_right_child_multiplier, + ); + }); } - fn test_cells_tree_full_multiplier(is_multiplier: bool) { + fn test_cells_tree_full_multiplier( + is_multiplier: bool, + is_left_child_multiplier: bool, + is_right_child_multiplier: bool, + ) { let cell = Cell::sample(is_multiplier); let id = cell.identifier; let value = cell.value; let values_digests = cell.split_values_digest(); let metadata_digests = cell.split_metadata_digest(); - let child_pis = &array::from_fn(|_| PublicInputs::::sample(is_multiplier)); + let child_pis = &[ + PublicInputs::::sample(is_left_child_multiplier), + PublicInputs::::sample(is_right_child_multiplier), + ]; let test_circuit = TestFullNodeCircuit { c: cell.into(), diff --git a/verifiable-db/src/cells_tree/partial_node.rs b/verifiable-db/src/cells_tree/partial_node.rs index f7b8025f2..6be53d584 100644 --- a/verifiable-db/src/cells_tree/partial_node.rs +++ b/verifiable-db/src/cells_tree/partial_node.rs @@ -124,19 +124,25 @@ mod tests { } #[test] - fn test_cells_tree_partial_node_circuit() { - test_cells_tree_partial_multiplier(true); - test_cells_tree_partial_multiplier(false); + fn test_cells_tree_partial_node_individual() { + test_cells_tree_partial_multiplier(false, true); + test_cells_tree_partial_multiplier(false, false); } - fn test_cells_tree_partial_multiplier(is_multiplier: bool) { + #[test] + fn test_cells_tree_partial_node_multiplier() { + test_cells_tree_partial_multiplier(true, true); + test_cells_tree_partial_multiplier(true, false); + } + + fn test_cells_tree_partial_multiplier(is_multiplier: bool, is_child_multiplier: bool) { let cell = Cell::sample(is_multiplier); let id = cell.identifier; let value = cell.value; let values_digests = cell.split_values_digest(); let metadata_digests = cell.split_metadata_digest(); - let child_pi = &PublicInputs::::sample(is_multiplier); + let child_pi = &PublicInputs::::sample(is_child_multiplier); let test_circuit = TestPartialNodeCircuit { c: cell.into(), From 6b791c92f78b89f4f92e72e6b557c6a3d6b6cc10 Mon Sep 17 00:00:00 2001 From: Steven Gu Date: Wed, 30 Oct 2024 16:14:41 +0800 Subject: [PATCH 11/16] Remove merge flag in rows tree public inputs. --- verifiable-db/src/block_tree/leaf.rs | 6 --- verifiable-db/src/block_tree/mod.rs | 14 +++---- verifiable-db/src/block_tree/parent.rs | 6 --- verifiable-db/src/row_tree/api.rs | 6 --- verifiable-db/src/row_tree/full_node.rs | 8 ---- verifiable-db/src/row_tree/leaf.rs | 3 -- verifiable-db/src/row_tree/partial_node.rs | 7 ---- verifiable-db/src/row_tree/public_inputs.rs | 43 +++------------------ verifiable-db/src/row_tree/row.rs | 13 +------ 9 files changed, 12 insertions(+), 94 deletions(-) diff --git a/verifiable-db/src/block_tree/leaf.rs b/verifiable-db/src/block_tree/leaf.rs index d9080e58e..809113fe0 100644 --- a/verifiable-db/src/block_tree/leaf.rs +++ b/verifiable-db/src/block_tree/leaf.rs @@ -97,12 +97,6 @@ impl LeafCircuit { .collect(); let h_new = b.hash_n_to_hash_no_pad::(inputs).to_targets(); - // check that the rows tree built is for a merged table iff we extract data from MPT for a merged table - b.connect( - rows_tree_pi.merge_flag_target().target, - extraction_pi.is_merge_case().target, - ); - // Register the public inputs. PublicInputs::new( &h_new, diff --git a/verifiable-db/src/block_tree/mod.rs b/verifiable-db/src/block_tree/mod.rs index 7b7541218..862cd8c12 100644 --- a/verifiable-db/src/block_tree/mod.rs +++ b/verifiable-db/src/block_tree/mod.rs @@ -165,16 +165,14 @@ pub(crate) mod tests { /// Generate a random rows tree public inputs. pub(crate) fn random_rows_tree_pi(rng: &mut ThreadRng, is_merge_case: bool) -> Vec { let [min, max] = array::from_fn(|_| rng.gen()); - let multiplier_digest = Point::rand(); + let multiplier_digest = if is_merge_case { + Point::rand() + } else { + Point::NEUTRAL + }; let row_id_multiplier = BigUint::from_slice(&random_vector::(HASH_TO_INT_LEN)); - row_tree::PublicInputs::sample( - multiplier_digest, - row_id_multiplier, - min, - max, - is_merge_case, - ) + row_tree::PublicInputs::sample(multiplier_digest, row_id_multiplier, min, max) } /// Generate a random extraction public inputs. diff --git a/verifiable-db/src/block_tree/parent.rs b/verifiable-db/src/block_tree/parent.rs index 0518a7692..fd0b9330c 100644 --- a/verifiable-db/src/block_tree/parent.rs +++ b/verifiable-db/src/block_tree/parent.rs @@ -148,12 +148,6 @@ impl ParentCircuit { .collect(); let h_new = b.hash_n_to_hash_no_pad::(inputs).elements; - // check that the rows tree built is for a merged table iff we extract data from MPT for a merged table - b.connect( - rows_tree_pi.merge_flag_target().target, - extraction_pi.is_merge_case().target, - ); - // Register the public inputs. PublicInputs::new( &h_new, diff --git a/verifiable-db/src/row_tree/api.rs b/verifiable-db/src/row_tree/api.rs index 81f28c394..2bafefd5f 100644 --- a/verifiable-db/src/row_tree/api.rs +++ b/verifiable-db/src/row_tree/api.rs @@ -497,8 +497,6 @@ mod test { assert_eq!(pi.min_value(), value.min(child_min)); // Check maximum value assert_eq!(pi.max_value(), value.max(child_max)); - // Check merge flag - assert_eq!(pi.merge_flag(), row_digest.is_merge); Ok(vec![]) } @@ -561,8 +559,6 @@ mod test { ); // Check row ID multiplier assert_eq!(pi.row_id_multiplier(), row_digest.row_id_multiplier); - // Check merge flag - assert_eq!(pi.merge_flag(), row_digest.is_merge); Ok(proof) } @@ -625,8 +621,6 @@ mod test { assert_eq!(pi.min_value(), value); // Check maximum value assert_eq!(pi.max_value(), value); - // Check merge flag - assert_eq!(pi.merge_flag(), row_digest.is_merge); Ok(proof) } diff --git a/verifiable-db/src/row_tree/full_node.rs b/verifiable-db/src/row_tree/full_node.rs index cedb65a1f..de61e4d88 100644 --- a/verifiable-db/src/row_tree/full_node.rs +++ b/verifiable-db/src/row_tree/full_node.rs @@ -81,9 +81,6 @@ impl FullNodeCircuit { .collect::>(); let hash = b.hash_n_to_hash_no_pad::(inputs); - // assert `is_merge` is the same as the flags in children pis - b.connect(min_child.merge_flag_target().target, digest.is_merge.target); - b.connect(max_child.merge_flag_target().target, digest.is_merge.target); PublicInputs::new( &hash.to_targets(), &digest.individual_vd.to_targets(), @@ -91,7 +88,6 @@ impl FullNodeCircuit { &digest.row_id_multiplier.to_targets(), &node_min.to_targets(), &node_max.to_targets(), - &[digest.is_merge.target], ) .register(b); FullNodeWires(row) @@ -208,14 +204,12 @@ pub(crate) mod test { row_digest.row_id_multiplier.clone(), left_min, left_max, - is_multiplier || cells_multiplier, ); let right_pi = PublicInputs::sample( row_digest.multiplier_vd, row_digest.row_id_multiplier.clone(), right_min, right_max, - is_multiplier || cells_multiplier, ); let test_circuit = TestFullNodeCircuit { circuit: node_circuit, @@ -262,8 +256,6 @@ pub(crate) mod test { assert_eq!(pi.min_value(), U256::from(left_min)); // Check maximum value assert_eq!(pi.max_value(), U256::from(right_max)); - // Check merge flag - assert_eq!(pi.merge_flag(), row_digest.is_merge); } #[test] diff --git a/verifiable-db/src/row_tree/leaf.rs b/verifiable-db/src/row_tree/leaf.rs index 53bd99a10..bba75a84e 100644 --- a/verifiable-db/src/row_tree/leaf.rs +++ b/verifiable-db/src/row_tree/leaf.rs @@ -63,7 +63,6 @@ impl LeafCircuit { &digest.row_id_multiplier.to_targets(), &value, &value, - &[digest.is_merge.target], ) .register(b); @@ -212,8 +211,6 @@ mod test { assert_eq!(pi.min_value(), value); // Check maximum value assert_eq!(pi.max_value(), value); - // Check merge flag - assert_eq!(pi.merge_flag(), row_digest.is_merge); } #[test] diff --git a/verifiable-db/src/row_tree/partial_node.rs b/verifiable-db/src/row_tree/partial_node.rs index 1044b7ad0..5c7833f61 100644 --- a/verifiable-db/src/row_tree/partial_node.rs +++ b/verifiable-db/src/row_tree/partial_node.rs @@ -112,9 +112,6 @@ impl PartialNodeCircuit { &rest, ); - // assert is_merge is the same between this row and `child_pi` - b.connect(digest.is_merge.target, child_pi.merge_flag_target().target); - PublicInputs::new( &node_hash, &digest.individual_vd.to_targets(), @@ -122,7 +119,6 @@ impl PartialNodeCircuit { &digest.row_id_multiplier.to_targets(), &node_min.to_targets(), &node_max.to_targets(), - &[digest.is_merge.target], ) .register(b); PartialNodeWires { @@ -299,7 +295,6 @@ pub mod test { row_digest.row_id_multiplier.clone(), child_min.to(), child_max.to(), - is_cell_multiplier || is_multiplier, ); let test_circuit = TestPartialNodeCircuit { circuit: node_circuit, @@ -354,7 +349,5 @@ pub mod test { assert_eq!(pi.min_value(), value.min(child_min)); // Check maximum value assert_eq!(pi.max_value(), value.max(child_max)); - // Check merge flag - assert_eq!(pi.merge_flag(), row_digest.is_merge); } } diff --git a/verifiable-db/src/row_tree/public_inputs.rs b/verifiable-db/src/row_tree/public_inputs.rs index a67951aec..cb2100ac3 100644 --- a/verifiable-db/src/row_tree/public_inputs.rs +++ b/verifiable-db/src/row_tree/public_inputs.rs @@ -7,19 +7,18 @@ use mp2_common::{ public_inputs::{PublicInputCommon, PublicInputRange}, types::{CBuilder, CURVE_TARGET_LEN}, u256::{self, UInt256Target}, - utils::{FromFields, FromTargets, TryIntoBool}, + utils::{FromFields, FromTargets}, F, }; use num::BigUint; use plonky2::{ field::types::PrimeField64, hash::hash_types::{HashOut, NUM_HASH_OUT_ELTS}, - iop::target::{BoolTarget, Target}, + iop::target::Target, }; use plonky2_crypto::u32::arithmetic_u32::U32Target; use plonky2_ecdsa::gadgets::biguint::BigUintTarget; use plonky2_ecgfp5::{curve::curve::WeierstrassPoint, gadgets::curve::CurveTarget}; -use std::iter::once; pub enum RowsTreePublicInputs { // `H : F[4]` - Poseidon hash of the leaf @@ -34,8 +33,6 @@ pub enum RowsTreePublicInputs { MinValue, // `max : Uint256` - Maximum value of the secondary index stored up to this node MaxValue, - // `merge : bool` - Flag specifying whether we are building rows for a merge table or not - MergeFlag, } /// Public inputs for Rows Tree Construction @@ -47,10 +44,9 @@ pub struct PublicInputs<'a, T> { pub(crate) row_id_multiplier: &'a [T], pub(crate) min: &'a [T], pub(crate) max: &'a [T], - pub(crate) merge: &'a T, } -const NUM_PUBLIC_INPUTS: usize = RowsTreePublicInputs::MergeFlag as usize + 1; +const NUM_PUBLIC_INPUTS: usize = RowsTreePublicInputs::MaxValue as usize + 1; impl<'a, T: Clone> PublicInputs<'a, T> { const PI_RANGES: [PublicInputRange; NUM_PUBLIC_INPUTS] = [ @@ -60,7 +56,6 @@ impl<'a, T: Clone> PublicInputs<'a, T> { Self::to_range(RowsTreePublicInputs::RowIdMultiplier), Self::to_range(RowsTreePublicInputs::MinValue), Self::to_range(RowsTreePublicInputs::MaxValue), - Self::to_range(RowsTreePublicInputs::MergeFlag), ]; const SIZES: [usize; NUM_PUBLIC_INPUTS] = [ @@ -76,8 +71,6 @@ impl<'a, T: Clone> PublicInputs<'a, T> { u256::NUM_LIMBS, // Maximum value of the secondary index stored up to this node u256::NUM_LIMBS, - // Flag specifying whether we are building rows for a merge table or not - 1, ]; pub(crate) const fn to_range(pi: RowsTreePublicInputs) -> PublicInputRange { @@ -92,7 +85,7 @@ impl<'a, T: Clone> PublicInputs<'a, T> { } pub const fn total_len() -> usize { - Self::to_range(RowsTreePublicInputs::MergeFlag).end + Self::to_range(RowsTreePublicInputs::MaxValue).end } pub fn to_root_hash_raw(&self) -> &[T] { @@ -119,10 +112,6 @@ impl<'a, T: Clone> PublicInputs<'a, T> { self.max } - pub fn to_merge_flag_raw(&self) -> &T { - self.merge - } - pub fn from_slice(input: &'a [T]) -> Self { assert!( input.len() >= Self::total_len(), @@ -137,7 +126,6 @@ impl<'a, T: Clone> PublicInputs<'a, T> { row_id_multiplier: &input[Self::PI_RANGES[3].clone()], min: &input[Self::PI_RANGES[4].clone()], max: &input[Self::PI_RANGES[5].clone()], - merge: &input[Self::PI_RANGES[6].clone()][0], } } @@ -148,7 +136,6 @@ impl<'a, T: Clone> PublicInputs<'a, T> { row_id_multiplier: &'a [T], min: &'a [T], max: &'a [T], - merge: &'a [T], ) -> Self { Self { h, @@ -157,7 +144,6 @@ impl<'a, T: Clone> PublicInputs<'a, T> { row_id_multiplier, min, max, - merge: &merge[0], } } @@ -169,7 +155,6 @@ impl<'a, T: Clone> PublicInputs<'a, T> { .chain(self.row_id_multiplier) .chain(self.min) .chain(self.max) - .chain(once(self.merge)) .cloned() .collect() } @@ -185,7 +170,6 @@ impl<'a> PublicInputCommon for PublicInputs<'a, Target> { cb.register_public_inputs(self.row_id_multiplier); cb.register_public_inputs(self.min); cb.register_public_inputs(self.max); - cb.register_public_input(*self.merge); } } @@ -220,10 +204,6 @@ impl<'a> PublicInputs<'a, Target> { pub fn max_value_target(&self) -> UInt256Target { UInt256Target::from_targets(self.max) } - - pub fn merge_flag_target(&self) -> BoolTarget { - BoolTarget::new_unsafe(*self.merge) - } } impl<'a> PublicInputs<'a, F> { @@ -256,10 +236,6 @@ impl<'a> PublicInputs<'a, F> { pub fn max_value(&self) -> U256 { U256::from_fields(self.max) } - - pub fn merge_flag(&self) -> bool { - self.merge.try_into_bool().unwrap() - } } #[cfg(test)] @@ -279,7 +255,7 @@ pub(crate) mod tests { }; use plonky2_ecgfp5::curve::curve::Point; use rand::{thread_rng, Rng}; - use std::{array, slice}; + use std::array; impl<'a> PublicInputs<'a, F> { pub(crate) fn sample( @@ -287,7 +263,6 @@ pub(crate) mod tests { row_id_multiplier: BigUint, min: usize, max: usize, - is_merge: bool, ) -> Vec { let h = HashOut::rand().to_fields(); let individual_digest = Point::rand(); @@ -299,7 +274,6 @@ pub(crate) mod tests { .map(F::from_canonical_u32) .collect_vec(); let [min, max] = [min, max].map(|v| U256::from(v).to_fields()); - let merge = F::from_bool(is_merge); PublicInputs::new( &h, &individual_digest, @@ -307,7 +281,6 @@ pub(crate) mod tests { &row_id_multiplier, &min, &max, - &[merge], ) .to_vec() } @@ -343,7 +316,6 @@ pub(crate) mod tests { array::from_fn(|_| Point::sample(rng).to_weierstrass().to_fields()); let row_id_multiplier = rng.gen::<[u32; 4]>().map(F::from_canonical_u32); let [min, max] = array::from_fn(|_| U256::from_limbs(rng.gen()).to_fields()); - let merge = [F::from_bool(rng.gen_bool(0.5))]; let exp_pi = PublicInputs::new( &h, &individual_digest, @@ -351,7 +323,6 @@ pub(crate) mod tests { &row_id_multiplier, &min, &max, - &merge, ); let exp_pi = &exp_pi.to_vec(); @@ -385,9 +356,5 @@ pub(crate) mod tests { &exp_pi[PublicInputs::::to_range(RowsTreePublicInputs::MaxValue)], pi.to_max_value_raw(), ); - assert_eq!( - &exp_pi[PublicInputs::::to_range(RowsTreePublicInputs::MergeFlag)], - slice::from_ref(pi.to_merge_flag_raw()), - ); } } diff --git a/verifiable-db/src/row_tree/row.rs b/verifiable-db/src/row_tree/row.rs index c4cb9569b..a99907903 100644 --- a/verifiable-db/src/row_tree/row.rs +++ b/verifiable-db/src/row_tree/row.rs @@ -16,7 +16,7 @@ use plonky2::{ field::types::{Field, PrimeField64}, hash::hash_types::{HashOut, HashOutTarget}, iop::{ - target::{BoolTarget, Target}, + target::Target, witness::{PartialWitness, WitnessWrite}, }, plonk::config::Hasher, @@ -30,7 +30,6 @@ use serde::{Deserialize, Serialize}; #[derive(Clone, Debug, Eq, PartialEq)] pub(crate) struct RowDigest { - pub(crate) is_merge: bool, pub(crate) row_id_multiplier: BigUint, pub(crate) individual_vd: Point, pub(crate) multiplier_vd: Point, @@ -40,9 +39,6 @@ impl FromFields for RowDigest { fn from_fields(t: &[F]) -> Self { let mut pos = 0; - let is_merge = t[pos].is_nonzero(); - pos += 1; - let row_id_multiplier = BigUint::new( t[pos..pos + HASH_TO_INT_LEN] .iter() @@ -57,7 +53,6 @@ impl FromFields for RowDigest { let multiplier_vd = Point::from_fields(&t[pos..pos + CURVE_TARGET_LEN]); Self { - is_merge, row_id_multiplier, individual_vd, multiplier_vd, @@ -67,7 +62,6 @@ impl FromFields for RowDigest { #[derive(Clone, Debug)] pub(crate) struct RowDigestTarget { - pub(crate) is_merge: BoolTarget, pub(crate) row_id_multiplier: BigUintTarget, pub(crate) individual_vd: CurveTarget, pub(crate) multiplier_vd: CurveTarget, @@ -120,11 +114,9 @@ impl Row { let hash = H::hash_no_pad(&inputs); let row_id_multiplier = hash_to_int_value(hash); - let is_merge = values_digests.is_merge_case(); let multiplier_vd = values_digests.multiplier; RowDigest { - is_merge, row_id_multiplier, individual_vd, multiplier_vd, @@ -196,11 +188,9 @@ impl RowWire { let row_id_multiplier = hash_to_int_target(b, hash); assert_eq!(row_id_multiplier.num_limbs(), HASH_TO_INT_LEN); - let is_merge = values_digests.is_merge_case(b); let multiplier_vd = values_digests.multiplier; RowDigestTarget { - is_merge, row_id_multiplier, individual_vd, multiplier_vd, @@ -242,7 +232,6 @@ pub(crate) mod tests { let digest = row.digest(b, &cells_pi); - b.register_public_input(digest.is_merge.target); b.register_public_inputs(&digest.row_id_multiplier.to_targets()); b.register_public_inputs(&digest.individual_vd.to_targets()); b.register_public_inputs(&digest.multiplier_vd.to_targets()); From 16bea456b05cefcfd0ccbbc12552cbd86bf7c0b2 Mon Sep 17 00:00:00 2001 From: Steven Gu Date: Wed, 30 Oct 2024 16:19:19 +0800 Subject: [PATCH 12/16] Rename `assign_wires` to `assign`. --- mp2-common/src/mpt_sequential/mod.rs | 4 ++-- verifiable-db/src/cells_tree/full_node.rs | 2 +- verifiable-db/src/cells_tree/leaf.rs | 2 +- verifiable-db/src/cells_tree/mod.rs | 4 ++-- verifiable-db/src/cells_tree/partial_node.rs | 2 +- verifiable-db/src/row_tree/full_node.rs | 2 +- verifiable-db/src/row_tree/leaf.rs | 2 +- verifiable-db/src/row_tree/partial_node.rs | 2 +- verifiable-db/src/row_tree/row.rs | 6 +++--- 9 files changed, 13 insertions(+), 13 deletions(-) diff --git a/mp2-common/src/mpt_sequential/mod.rs b/mp2-common/src/mpt_sequential/mod.rs index 4acb7a46f..a887c9d06 100644 --- a/mp2-common/src/mpt_sequential/mod.rs +++ b/mp2-common/src/mpt_sequential/mod.rs @@ -236,7 +236,7 @@ where /// Assign the nodes to the wires. The reason we have the output wires /// as well is due to the keccak circuit that requires some special assignement /// from the raw vectors. - pub fn assign_wires, const D: usize>( + pub fn assign, const D: usize>( &self, p: &mut PartialWitness, inputs: &InputWires, @@ -497,7 +497,7 @@ mod test { } fn prove(&self, pw: &mut PartialWitness, wires: &Self::Wires) { - self.c.assign_wires(pw, &wires.0, &wires.1).unwrap(); + self.c.assign(pw, &wires.0, &wires.1).unwrap(); wires.2.assign( pw, &create_array(|i| F::from_canonical_u8(self.exp_root[i])), diff --git a/verifiable-db/src/cells_tree/full_node.rs b/verifiable-db/src/cells_tree/full_node.rs index 42584a534..2705fbb0e 100644 --- a/verifiable-db/src/cells_tree/full_node.rs +++ b/verifiable-db/src/cells_tree/full_node.rs @@ -59,7 +59,7 @@ impl FullNodeCircuit { /// Assign the wires. fn assign(&self, pw: &mut PartialWitness, wires: &FullNodeWires) { - self.0.assign_wires(pw, &wires.0); + self.0.assign(pw, &wires.0); } } diff --git a/verifiable-db/src/cells_tree/leaf.rs b/verifiable-db/src/cells_tree/leaf.rs index 908dcfc41..4d8d7663e 100644 --- a/verifiable-db/src/cells_tree/leaf.rs +++ b/verifiable-db/src/cells_tree/leaf.rs @@ -55,7 +55,7 @@ impl LeafCircuit { /// Assign the wires. fn assign(&self, pw: &mut PartialWitness, wires: &LeafWires) { - self.0.assign_wires(pw, &wires.0); + self.0.assign(pw, &wires.0); } } diff --git a/verifiable-db/src/cells_tree/mod.rs b/verifiable-db/src/cells_tree/mod.rs index 10a01d846..13a974dbb 100644 --- a/verifiable-db/src/cells_tree/mod.rs +++ b/verifiable-db/src/cells_tree/mod.rs @@ -46,7 +46,7 @@ pub struct Cell { } impl Cell { - pub(crate) fn assign_wires(&self, pw: &mut PartialWitness, wires: &CellWire) { + pub(crate) fn assign(&self, pw: &mut PartialWitness, wires: &CellWire) { pw.set_u256_target(&wires.value, self.value); pw.set_target(wires.identifier, self.identifier); pw.set_bool_target(wires.is_multiplier, self.is_multiplier); @@ -227,7 +227,7 @@ pub(crate) mod tests { } fn prove(&self, pw: &mut PartialWitness, wires: &Self::Wires) { - self.cell.assign_wires(pw, &wires.0); + self.cell.assign(pw, &wires.0); pw.set_curve_target( wires.1.individual, self.child_values_digest.individual.to_weierstrass(), diff --git a/verifiable-db/src/cells_tree/partial_node.rs b/verifiable-db/src/cells_tree/partial_node.rs index 6be53d584..eb73ec92e 100644 --- a/verifiable-db/src/cells_tree/partial_node.rs +++ b/verifiable-db/src/cells_tree/partial_node.rs @@ -65,7 +65,7 @@ impl PartialNodeCircuit { /// Assign the wires. fn assign(&self, pw: &mut PartialWitness, wires: &PartialNodeWires) { - self.0.assign_wires(pw, &wires.0); + self.0.assign(pw, &wires.0); } } diff --git a/verifiable-db/src/row_tree/full_node.rs b/verifiable-db/src/row_tree/full_node.rs index de61e4d88..36c5a3760 100644 --- a/verifiable-db/src/row_tree/full_node.rs +++ b/verifiable-db/src/row_tree/full_node.rs @@ -93,7 +93,7 @@ impl FullNodeCircuit { FullNodeWires(row) } fn assign(&self, pw: &mut PartialWitness, wires: &FullNodeWires) { - self.0.assign_wires(pw, &wires.0); + self.0.assign(pw, &wires.0); } } diff --git a/verifiable-db/src/row_tree/leaf.rs b/verifiable-db/src/row_tree/leaf.rs index bba75a84e..4738c537a 100644 --- a/verifiable-db/src/row_tree/leaf.rs +++ b/verifiable-db/src/row_tree/leaf.rs @@ -70,7 +70,7 @@ impl LeafCircuit { } fn assign(&self, pw: &mut PartialWitness, wires: &LeafWires) { - self.0.assign_wires(pw, &wires.0); + self.0.assign(pw, &wires.0); } } diff --git a/verifiable-db/src/row_tree/partial_node.rs b/verifiable-db/src/row_tree/partial_node.rs index 5c7833f61..047bab775 100644 --- a/verifiable-db/src/row_tree/partial_node.rs +++ b/verifiable-db/src/row_tree/partial_node.rs @@ -128,7 +128,7 @@ impl PartialNodeCircuit { } fn assign(&self, pw: &mut PartialWitness, wires: &PartialNodeWires) { - self.row.assign_wires(pw, &wires.row); + self.row.assign(pw, &wires.row); pw.set_bool_target(wires.is_child_at_left, self.is_child_at_left); } } diff --git a/verifiable-db/src/row_tree/row.rs b/verifiable-db/src/row_tree/row.rs index a99907903..ca57a86c4 100644 --- a/verifiable-db/src/row_tree/row.rs +++ b/verifiable-db/src/row_tree/row.rs @@ -74,8 +74,8 @@ pub(crate) struct Row { } impl Row { - pub(crate) fn assign_wires(&self, pw: &mut PartialWitness, wires: &RowWire) { - self.cell.assign_wires(pw, &wires.cell); + pub(crate) fn assign(&self, pw: &mut PartialWitness, wires: &RowWire) { + self.cell.assign(pw, &wires.cell); pw.set_hash_target(wires.row_unique_data, self.row_unique_data); } @@ -240,7 +240,7 @@ pub(crate) mod tests { } fn prove(&self, pw: &mut PartialWitness, wires: &Self::Wires) { - self.row.assign_wires(pw, &wires.0); + self.row.assign(pw, &wires.0); pw.set_target_arr(&wires.1, self.cells_pi); } } From 1f70eb41381717f8e2dd64821bb070942b12a09e Mon Sep 17 00:00:00 2001 From: Steven Gu Date: Wed, 30 Oct 2024 16:27:50 +0800 Subject: [PATCH 13/16] Fix to use `PublicInputs::sample`. --- verifiable-db/src/row_tree/public_inputs.rs | 17 ++++------------- 1 file changed, 4 insertions(+), 13 deletions(-) diff --git a/verifiable-db/src/row_tree/public_inputs.rs b/verifiable-db/src/row_tree/public_inputs.rs index cb2100ac3..3415ca0bd 100644 --- a/verifiable-db/src/row_tree/public_inputs.rs +++ b/verifiable-db/src/row_tree/public_inputs.rs @@ -311,19 +311,10 @@ pub(crate) mod tests { let rng = &mut thread_rng(); // Prepare the public inputs. - let h = random_vector::(NUM_HASH_OUT_ELTS).to_fields(); - let [individual_digest, multiplier_digest] = - array::from_fn(|_| Point::sample(rng).to_weierstrass().to_fields()); - let row_id_multiplier = rng.gen::<[u32; 4]>().map(F::from_canonical_u32); - let [min, max] = array::from_fn(|_| U256::from_limbs(rng.gen()).to_fields()); - let exp_pi = PublicInputs::new( - &h, - &individual_digest, - &multiplier_digest, - &row_id_multiplier, - &min, - &max, - ); + let multiplier_digest = Point::sample(rng); + let row_id_multiplier = BigUint::from_slice(&random_vector::(HASH_TO_INT_LEN)); + let [min, max] = array::from_fn(|_| rng.gen()); + let exp_pi = PublicInputs::sample(multiplier_digest, row_id_multiplier, min, max); let exp_pi = &exp_pi.to_vec(); let test_circuit = TestPublicInputs { exp_pi }; From 21bb63723cd7c7b422b3f7926112105abb55e342 Mon Sep 17 00:00:00 2001 From: Steven Gu Date: Wed, 30 Oct 2024 16:32:40 +0800 Subject: [PATCH 14/16] Fix to `p.partial`. --- verifiable-db/src/row_tree/api.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/verifiable-db/src/row_tree/api.rs b/verifiable-db/src/row_tree/api.rs index 2bafefd5f..6fd85ffb2 100644 --- a/verifiable-db/src/row_tree/api.rs +++ b/verifiable-db/src/row_tree/api.rs @@ -416,17 +416,17 @@ mod test { log::info!("Generating full proof (from leaf 1 and leaf 2)"); let full_proof = generate_full_proof(¶ms, children_proof)?; log::info!("Generating partial proof (from full proof)"); - let _ = generate_partial_proof(¶ms, params.partial.clone(), true, full_proof)?; + let _ = generate_partial_proof(¶ms, true, full_proof)?; log::info!("Test done"); Ok(()) } fn generate_partial_proof( p: &TestParams, - row: Row, is_left: bool, child_proof_buff: Vec, ) -> Result> { + let row = &p.partial; let id = row.cell.identifier; let value = row.cell.value; let mpt_metadata = row.cell.mpt_metadata; From 31f76855547f2c0e42fa1990239450d6fa687f22 Mon Sep 17 00:00:00 2001 From: Steven Gu Date: Wed, 30 Oct 2024 16:44:43 +0800 Subject: [PATCH 15/16] Delete the `ignore` comment for test `isolution`. --- parsil/src/tests.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/parsil/src/tests.rs b/parsil/src/tests.rs index 653a492ca..13b0e28fe 100644 --- a/parsil/src/tests.rs +++ b/parsil/src/tests.rs @@ -149,7 +149,6 @@ fn test_serde_circuit_pis() { } #[test] -#[ignore = "wait for non-aggregation SELECT to come back"] fn isolation() { fn isolated_to_string(q: &str, lo_sec: bool, hi_sec: bool) -> String { let settings = ParsilSettings { From e37e98e503bf534194ea4fbc10694c647e6b273e Mon Sep 17 00:00:00 2001 From: Steven Gu Date: Wed, 30 Oct 2024 22:55:43 +0800 Subject: [PATCH 16/16] Fix to use `split_and_accumulate`. --- verifiable-db/src/cells_tree/full_node.rs | 10 ++++------ verifiable-db/src/cells_tree/partial_node.rs | 9 ++++----- 2 files changed, 8 insertions(+), 11 deletions(-) diff --git a/verifiable-db/src/cells_tree/full_node.rs b/verifiable-db/src/cells_tree/full_node.rs index 2705fbb0e..6dcc64754 100644 --- a/verifiable-db/src/cells_tree/full_node.rs +++ b/verifiable-db/src/cells_tree/full_node.rs @@ -25,13 +25,11 @@ impl FullNodeCircuit { let [p1, p2] = child_proofs; let cell = CellWire::new(b); - let metadata_digests = cell.split_metadata_digest(b); - let values_digests = cell.split_values_digest(b); - - let metadata_digests = metadata_digests.accumulate(b, &p1.split_metadata_digest_target()); + let metadata_digests = + cell.split_and_accumulate_metadata_digest(b, &p1.split_metadata_digest_target()); + let values_digests = + cell.split_and_accumulate_values_digest(b, &p1.split_values_digest_target()); let metadata_digests = metadata_digests.accumulate(b, &p2.split_metadata_digest_target()); - - let values_digests = values_digests.accumulate(b, &p1.split_values_digest_target()); let values_digests = values_digests.accumulate(b, &p2.split_values_digest_target()); // H(p1.H || p2.H || identifier || value) diff --git a/verifiable-db/src/cells_tree/partial_node.rs b/verifiable-db/src/cells_tree/partial_node.rs index eb73ec92e..2724e5554 100644 --- a/verifiable-db/src/cells_tree/partial_node.rs +++ b/verifiable-db/src/cells_tree/partial_node.rs @@ -27,11 +27,10 @@ pub struct PartialNodeCircuit(Cell); impl PartialNodeCircuit { pub fn build(b: &mut CBuilder, p: PublicInputs) -> PartialNodeWires { let cell = CellWire::new(b); - let metadata_digests = cell.split_metadata_digest(b); - let values_digests = cell.split_values_digest(b); - - let metadata_digests = metadata_digests.accumulate(b, &p.split_metadata_digest_target()); - let values_digests = values_digests.accumulate(b, &p.split_values_digest_target()); + let metadata_digests = + cell.split_and_accumulate_metadata_digest(b, &p.split_metadata_digest_target()); + let values_digests = + cell.split_and_accumulate_values_digest(b, &p.split_values_digest_target()); /* # since there is no sorting constraint among the nodes of this tree, to simplify