From 14e897dac9e369aa9c9a0bb06180c60f024a5561 Mon Sep 17 00:00:00 2001 From: Dimitris Mouris Date: Thu, 7 Dec 2023 19:06:28 -0500 Subject: [PATCH] Make beta a vector --- src/bin/config.json | 1 + src/bin/leader.rs | 178 ++++++++++++++++++++++--------- src/bin/server.rs | 26 +++-- src/collect.rs | 242 ++++++++++++++++++++++++------------------ src/config.rs | 30 +++--- src/lib.rs | 29 ++++- src/prg.rs | 13 ++- src/rpc.rs | 11 +- src/vidpf.rs | 81 ++++++++------ tests/collect_test.rs | 39 ++++--- tests/dpf_test.rs | 12 +-- tests/flp_test.rs | 55 ++++++---- 12 files changed, 451 insertions(+), 266 deletions(-) diff --git a/src/bin/config.json b/src/bin/config.json index fd19ee7..afa6fcc 100644 --- a/src/bin/config.json +++ b/src/bin/config.json @@ -1,5 +1,6 @@ { "data_bytes": 1, + "range_bits": 2, "threshold": 0.01, "server_0": "0.0.0.0:8000", "server_1": "0.0.0.0:8001", diff --git a/src/bin/leader.rs b/src/bin/leader.rs index 9501b77..e2569b1 100644 --- a/src/bin/leader.rs +++ b/src/bin/leader.rs @@ -11,17 +11,19 @@ use mastic::{ GetProofsRequest, ResetRequest, RunFlpQueriesRequest, TreeCrawlLastRequest, TreeCrawlRequest, TreeInitRequest, TreePruneRequest, }, - vidpf, CollectorClient, + vidpf, BetaType, CollectorClient, }; use prio::{ field::{random_vector, Field64}, - flp::{types::Count, Type}, + flp::{types::Sum, Type}, + vdaf::xof::{IntoFieldVec, Xof, XofShake128}, }; use rand::{distributions::Alphanumeric, thread_rng, Rng}; +use rand_core::RngCore; use rayon::prelude::*; use tarpc::{client, context, serde_transport::tcp, tokio_serde::formats::Bincode}; -type Key = vidpf::VIDPFKey; +type Key = vidpf::VIDPFKey; type Client = CollectorClient; fn long_context() -> context::Context { @@ -42,23 +44,72 @@ fn sample_string(len: usize) -> String { fn generate_keys( cfg: &config::Config, -) -> ((Vec, Vec), (Vec>, Vec>)) { - let beta = 1u64; - let count = Count::new(); - let input_beta: Vec = count.encode_measurement(&beta).unwrap(); - - let (keys_0, keys_1): (Vec, Vec) = rayon::iter::repeat(0) + typ: &Sum, +) -> ((Vec, Vec), Vec>) { + let (keys, values): ((Vec, Vec), Vec>) = rayon::iter::repeat(0) .take(cfg.unique_buckets) .map(|_| { - vidpf::VIDPFKey::gen_from_str(&sample_string(cfg.data_bytes * 8), Field64::from(beta)) + // Generate a random number in the specified range + let beta = rand::thread_rng().gen_range(1..(1 << cfg.range_bits)); + let input_beta: BetaType = typ.encode_measurement(&beta).unwrap(); + + ( + Key::gen_from_str(&sample_string(cfg.data_bytes * 8), &input_beta), + input_beta, + ) }) .unzip(); - let (proofs_0, proofs_1): (Vec>, Vec>) = rayon::iter::repeat(0) - .take(cfg.unique_buckets) - .map(|_| { - let prove_rand = random_vector(count.prove_rand_len()).unwrap(); - let proof = count.prove(&input_beta, &prove_rand, &[]).unwrap(); + let encoded: Vec = bincode::serialize(&keys.0[0]).unwrap(); + println!("Key size: {:?} bytes", encoded.len()); + + (keys, values) +} + +fn generate_randomness(keys: (&Vec, &Vec)) -> (Vec<[u8; 16]>, Vec<[[u8; 16]; 2]>) { + keys.0 + .par_iter() + .zip(keys.1.par_iter()) + .map(|(key_0, key_1)| { + let nonce = rand::thread_rng().gen::().to_le_bytes(); + let vidpf_seeds = (key_0.get_root_seed().key, key_1.get_root_seed().key); + + let mut jr_parts = [[0u8; 16]; 2]; + let mut jr_part_0_xof = XofShake128::init(&vidpf_seeds.0, &[0u8; 16]); + jr_part_0_xof.update(&[0]); // Aggregator ID + jr_part_0_xof.update(&nonce); + jr_part_0_xof + .into_seed_stream() + .fill_bytes(&mut jr_parts[0]); + + let mut jr_part_1_xof = XofShake128::init(&vidpf_seeds.1, &[0u8; 16]); + jr_part_1_xof.update(&[1]); // Aggregator ID + jr_part_1_xof.update(&nonce); + jr_part_1_xof + .into_seed_stream() + .fill_bytes(&mut jr_parts[1]); + + (nonce, jr_parts) + }) + .unzip() +} + +fn generate_proofs( + typ: &Sum, + beta_values: &Vec>, + all_jr_parts: &Vec<[[u8; 16]; 2]>, +) -> (Vec>, Vec>) { + all_jr_parts + .par_iter() + .zip_eq(beta_values.par_iter()) + .map(|(jr_parts, input_beta)| { + let joint_rand_xof = XofShake128::init(&jr_parts[0], &jr_parts[1]); + let joint_rand: Vec = joint_rand_xof + .into_seed_stream() + .into_field_vec(typ.joint_rand_len()); + + let prove_rand = random_vector(typ.prove_rand_len()).unwrap(); + let proof = typ.prove(input_beta, &prove_rand, &joint_rand).unwrap(); let proof_0 = proof .iter() @@ -69,14 +120,10 @@ fn generate_keys( .zip(proof_0.par_iter()) .map(|(p_0, p_1)| p_0 - p_1) .collect::>(); + (proof_0, proof_1) }) - .unzip(); - - let encoded: Vec = bincode::serialize(&keys_0[0]).unwrap(); - println!("Key size: {:?} bytes", encoded.len()); - - ((keys_0, keys_1), (proofs_0, proofs_1)) + .unzip() } async fn reset_servers( @@ -105,12 +152,11 @@ async fn tree_init(client_0: &Client, client_1: &Client) -> io::Result<()> { async fn add_keys( cfg: &config::Config, - client_0: &Client, - client_1: &Client, - keys_0: &[vidpf::VIDPFKey], - keys_1: &[vidpf::VIDPFKey], - proofs_0: &[Vec], - proofs_1: &[Vec], + clients: (&Client, &Client), + keys: (&[vidpf::VIDPFKey], &[vidpf::VIDPFKey]), + proofs: (&[Vec], &[Vec]), + all_nonces: &[[u8; 16]], + all_jr_parts: &[[[u8; 16]; 2]], num_clients: usize, malicious_percentage: f32, ) -> io::Result<()> { @@ -123,6 +169,7 @@ async fn add_keys( let mut flp_proof_shares_0 = Vec::with_capacity(num_clients); let mut flp_proof_shares_1 = Vec::with_capacity(num_clients); let mut nonces = Vec::with_capacity(num_clients); + let mut jr_parts = Vec::with_capacity(num_clients); for r in 0..num_clients { let idx_1 = zipf.sample(&mut rng) - 1; let mut idx_2 = idx_1; @@ -137,30 +184,38 @@ async fn add_keys( } println!("Malicious {}", r); } - add_keys_0.push(keys_0[idx_1].clone()); - add_keys_1.push(keys_1[idx_2 % cfg.unique_buckets].clone()); + add_keys_0.push(keys.0[idx_1].clone()); + add_keys_1.push(keys.1[idx_2 % cfg.unique_buckets].clone()); + + flp_proof_shares_0.push(proofs.0[idx_1].clone()); + flp_proof_shares_1.push(proofs.1[idx_3 % cfg.unique_buckets].clone()); - flp_proof_shares_0.push(proofs_0[idx_1].clone()); - flp_proof_shares_1.push(proofs_1[idx_3 % cfg.unique_buckets].clone()); - nonces.push(rand::thread_rng().gen::().to_le_bytes()); + nonces.push(all_nonces[idx_1]); + jr_parts.push(all_jr_parts[idx_1]); } - let resp_0 = client_0.add_keys(long_context(), AddKeysRequest { keys: add_keys_0 }); - let resp_1 = client_1.add_keys(long_context(), AddKeysRequest { keys: add_keys_1 }); + let resp_0 = clients + .0 + .add_keys(long_context(), AddKeysRequest { keys: add_keys_0 }); + let resp_1 = clients + .1 + .add_keys(long_context(), AddKeysRequest { keys: add_keys_1 }); try_join!(resp_0, resp_1).unwrap(); - let resp_0 = client_0.add_all_flp_proof_shares( + let resp_0 = clients.0.add_all_flp_proof_shares( long_context(), AddFLPsRequest { flp_proof_shares: flp_proof_shares_0, nonces: nonces.clone(), + jr_parts: jr_parts.clone(), }, ); - let resp_1 = client_1.add_all_flp_proof_shares( + let resp_1 = clients.1.add_all_flp_proof_shares( long_context(), AddFLPsRequest { flp_proof_shares: flp_proof_shares_1, - nonces, + nonces: nonces.clone(), + jr_parts: jr_parts.clone(), }, ); try_join!(resp_0, resp_1).unwrap(); @@ -170,12 +225,12 @@ async fn add_keys( async fn run_flp_queries( cfg: &config::Config, + typ: &Sum, client_0: &Client, client_1: &Client, num_clients: usize, ) -> io::Result<()> { // Receive FLP query responses in chunks of cfg.flp_batch_size to avoid having huge RPC messages. - let count = Count::new(); let mut keep = vec![]; let mut start = 0; while start < num_clients { @@ -198,7 +253,7 @@ async fn run_flp_queries( .map(|(&v1, &v2)| v1 + v2) .collect::>(); - count.decide(&flp_verifier).unwrap() + typ.decide(&flp_verifier).unwrap() }) .collect::>(), ); @@ -217,6 +272,7 @@ async fn run_flp_queries( async fn run_level( cfg: &config::Config, + typ: &Sum, client_0: &Client, client_1: &Client, num_clients: usize, @@ -240,8 +296,12 @@ async fn run_level( try_join!(resp_0, resp_1).unwrap(); assert_eq!(cnt_values_0.len(), cnt_values_1.len()); - keep = - collect::KeyCollection::::keep_values(threshold, &cnt_values_0, &cnt_values_1); + keep = collect::KeyCollection::keep_values( + typ.input_len(), + threshold, + &cnt_values_0, + &cnt_values_1, + ); if mt_root_0.is_empty() { break; } @@ -284,6 +344,7 @@ async fn run_level( async fn run_level_last( cfg: &config::Config, + typ: &Sum, client_0: &Client, client_1: &Client, num_clients: usize, @@ -295,8 +356,12 @@ async fn run_level_last( let resp_1 = client_1.tree_crawl_last(long_context(), req); let (cnt_values_0, cnt_values_1) = try_join!(resp_0, resp_1).unwrap(); assert_eq!(cnt_values_0.len(), cnt_values_1.len()); - let keep = - collect::KeyCollection::::keep_values(threshold, &cnt_values_0, &cnt_values_1); + let keep = collect::KeyCollection::keep_values( + typ.input_len(), + threshold, + &cnt_values_0, + &cnt_values_1, + ); // Receive counters in chunks to avoid having huge RPC messages. let mut start = 0; @@ -329,9 +394,9 @@ async fn run_level_last( let resp_0 = client_0.final_shares(long_context(), req.clone()); let resp_1 = client_1.final_shares(long_context(), req); let (shares_0, shares_1) = try_join!(resp_0, resp_1).unwrap(); - for res in &collect::KeyCollection::::final_values(&shares_0, &shares_1) { + for res in &collect::KeyCollection::final_values(typ.input_len(), &shares_0, &shares_1) { let bits = mastic::bits_to_bitstring(&res.path); - if res.value > Field64::from(0) { + if res.value[typ.input_len() - 1] > Field64::from(0) { println!("Value ({}) \t Count: {:?}", bits, res.value); } } @@ -341,9 +406,6 @@ async fn run_level_last( #[tokio::main] async fn main() -> io::Result<()> { - // println!("Using only one thread!"); - // rayon::ThreadPoolBuilder::new().num_threads(1).build_global().unwrap(); - let (cfg, _, num_clients, malicious) = config::get_args("Leader", false, true, true); assert!((0.0..0.8).contains(&malicious)); println!("Running with {}% malicious clients", malicious * 100.0); @@ -358,10 +420,14 @@ async fn main() -> io::Result<()> { ) .spawn(); + let typ = Sum::::new(cfg.range_bits).unwrap(); + let start = Instant::now(); println!("Generating keys..."); - let ((keys_0, keys_1), (proofs_0, proofs_1)) = generate_keys(&cfg); + let ((keys_0, keys_1), beta_values) = generate_keys(&cfg, &typ); let delta = start.elapsed().as_secs_f64(); + let (nonces, jr_parts) = generate_randomness((&keys_0, &keys_1)); + let (proofs_0, proofs_1) = generate_proofs(&typ, &beta_values, &jr_parts); println!( "Generated {:?} keys in {:?} seconds ({:?} sec/key)", keys_0.len(), @@ -385,7 +451,13 @@ async fn main() -> io::Result<()> { if this_batch > 0 { responses.push(add_keys( - &cfg, &client_0, &client_1, &keys_0, &keys_1, &proofs_0, &proofs_1, this_batch, + &cfg, + (&client_0, &client_1), + (&keys_0, &keys_1), + (&proofs_0, &proofs_1), + &nonces, + &jr_parts, + this_batch, malicious, )); } @@ -403,9 +475,9 @@ async fn main() -> io::Result<()> { for level in 0..bit_len - 1 { let start_level = Instant::now(); if level == 0 { - run_flp_queries(&cfg, &client_0, &client_1, num_clients).await?; + run_flp_queries(&cfg, &typ, &client_0, &client_1, num_clients).await?; } - run_level(&cfg, &client_0, &client_1, num_clients).await?; + run_level(&cfg, &typ, &client_0, &client_1, num_clients).await?; println!( "Time for level {}: {:?}", level, @@ -419,7 +491,7 @@ async fn main() -> io::Result<()> { ); let start_last = Instant::now(); - run_level_last(&cfg, &client_0, &client_1, num_clients).await?; + run_level_last(&cfg, &typ, &client_0, &client_1, num_clients).await?; println!( "Time for last level: {:?}", start_last.elapsed().as_secs_f64() diff --git a/src/bin/server.rs b/src/bin/server.rs index ed6a1b4..2f3c856 100644 --- a/src/bin/server.rs +++ b/src/bin/server.rs @@ -12,9 +12,9 @@ use mastic::{ GetProofsRequest, ResetRequest, RunFlpQueriesRequest, TreeCrawlLastRequest, TreeCrawlRequest, TreeInitRequest, TreePruneRequest, }, - HASH_SIZE, + BetaType, HASH_SIZE, }; -use prio::field::Field64; +use prio::{field::Field64, flp::types::Sum}; use tarpc::{ context, serde_transport::tcp, @@ -29,7 +29,7 @@ struct CollectorServer { server_id: i8, seed: prg::PrgSeed, data_bytes: usize, - arc: Arc>>, + arc: Arc>, } #[tarpc::server] @@ -37,6 +37,7 @@ impl Collector for CollectorServer { async fn reset(self, _: context::Context, req: ResetRequest) -> String { let mut coll = self.arc.lock().unwrap(); *coll = collect::KeyCollection::new( + Sum::::new(2).unwrap(), self.server_id, &self.seed, self.data_bytes, @@ -58,8 +59,13 @@ impl Collector for CollectorServer { async fn add_all_flp_proof_shares(self, _: context::Context, req: AddFLPsRequest) -> String { let mut coll = self.arc.lock().unwrap(); - for (flp_proof_share, nonce) in req.flp_proof_shares.into_iter().zip(req.nonces) { - coll.add_flp_proof_share(flp_proof_share, nonce); + for ((flp_proof_share, nonce), jr_parts) in req + .flp_proof_shares + .into_iter() + .zip(req.nonces) + .zip(req.jr_parts) + { + coll.add_flp_proof_share(flp_proof_share, nonce, jr_parts); } if coll.keys.len() % 10000 == 0 { println!("Number of keys: {:?}", coll.keys.len()); @@ -79,7 +85,7 @@ impl Collector for CollectorServer { self, _: context::Context, req: TreeCrawlRequest, - ) -> (Vec, Vec>, Vec) { + ) -> (Vec, Vec>, Vec) { let start = Instant::now(); let split_by = req.split_by; let malicious = req.malicious; @@ -113,7 +119,7 @@ impl Collector for CollectorServer { self, _: context::Context, _req: TreeCrawlLastRequest, - ) -> Vec { + ) -> Vec { let start = Instant::now(); let mut coll = self.arc.lock().unwrap(); @@ -140,7 +146,7 @@ impl Collector for CollectorServer { self, _: context::Context, _req: FinalSharesRequest, - ) -> Vec> { + ) -> Vec { let coll = self.arc.lock().unwrap(); coll.final_shares() } @@ -156,8 +162,10 @@ async fn main() -> io::Result<()> { }; let seed = prg::PrgSeed { key: [1u8; 16] }; + let typ = Sum::::new(cfg.range_bits).unwrap(); - let coll = collect::KeyCollection::new(server_id, &seed, cfg.data_bytes * 8, [0u8; 16]); + let coll = + collect::KeyCollection::new(typ.clone(), server_id, &seed, cfg.data_bytes * 8, [0u8; 16]); let arc = Arc::new(Mutex::new(coll)); println!("Server {} running at {:?}", server_id, server_addr); diff --git a/src/collect.rs b/src/collect.rs index e9827b2..35e815a 100644 --- a/src/collect.rs +++ b/src/collect.rs @@ -1,13 +1,16 @@ use blake3::hash; use prio::{ - flp::{types::Count, Type}, + codec::Encode, + field::Field64, + flp::{types::Sum, Type}, vdaf::xof::{IntoFieldVec, Xof, XofShake128}, }; +use rand_core::RngCore; use rayon::prelude::*; use rs_merkle::{Hasher, MerkleTree}; use serde::{Deserialize, Serialize}; -use crate::{prg, vidpf, xor_in_place, xor_vec, HASH_SIZE}; +use crate::{prg, vec_add, vec_sub, vidpf, xor_in_place, xor_vec, BetaType, HASH_SIZE}; #[derive(Clone)] pub struct HashAlg {} @@ -21,113 +24,114 @@ impl Hasher for HashAlg { } #[derive(Clone)] -struct TreeNode { +struct TreeNode { path: Vec, - value: T, + value: Vec, key_states: Vec, - key_values: Vec, + key_values: Vec>, } -unsafe impl Send for TreeNode {} -unsafe impl Sync for TreeNode {} +unsafe impl Send for TreeNode {} +unsafe impl Sync for TreeNode {} #[derive(Clone)] -pub struct KeyCollection { +pub struct KeyCollection { + typ: Sum, server_id: i8, verify_key: [u8; 16], depth: usize, - pub keys: Vec<(bool, vidpf::VIDPFKey)>, + pub keys: Vec<(bool, vidpf::VIDPFKey)>, nonces: Vec<[u8; 16]>, - all_flp_proof_shares: Vec>, - frontier: Vec>, - prev_frontier: Vec>, - count: Count, + jr_parts: Vec<[[u8; 16]; 2]>, + all_flp_proof_shares: Vec>, + frontier: Vec>, + prev_frontier: Vec>, final_proofs: Vec<[u8; HASH_SIZE]>, } #[derive(Debug, Clone, Serialize, Deserialize)] -pub struct Result { +pub struct Result { pub path: Vec, - pub value: T, + pub value: BetaType, } -impl KeyCollection -where - T: prio::field::FieldElement - + prio::field::FftFriendlyFieldElement - + std::fmt::Debug - + std::cmp::PartialOrd - + Send - + Sync - + prg::FromRng - + 'static, - u64: From, -{ +impl KeyCollection { pub fn new( + typ: Sum, server_id: i8, _seed: &prg::PrgSeed, depth: usize, verify_key: [u8; 16], - ) -> KeyCollection { - KeyCollection:: { + ) -> KeyCollection { + KeyCollection { + typ, server_id, verify_key, depth, keys: vec![], nonces: vec![], + jr_parts: vec![], all_flp_proof_shares: vec![], frontier: vec![], prev_frontier: vec![], - count: Count::new(), final_proofs: vec![], } } - pub fn add_key(&mut self, key: vidpf::VIDPFKey) { + pub fn add_key(&mut self, key: vidpf::VIDPFKey) { self.keys.push((true, key)); } - pub fn add_flp_proof_share(&mut self, flp_proof_share: Vec, nonce: [u8; 16]) { + pub fn add_flp_proof_share( + &mut self, + flp_proof_share: Vec, + nonce: [u8; 16], + jr_parts: [[u8; 16]; 2], + ) { self.all_flp_proof_shares.push(flp_proof_share); self.nonces.push(nonce); + self.jr_parts.push(jr_parts); } pub fn tree_init(&mut self) { - let mut root = TreeNode:: { + let mut root = TreeNode:: { path: vec![], - value: T::zero(), + value: vec![Field64::from(0)], key_states: vec![], key_values: vec![], }; for k in &self.keys { root.key_states.push(k.1.eval_init()); - root.key_values.push(T::zero()); + root.key_values.push(vec![Field64::from(0)]); } self.frontier.clear(); self.frontier.push(root); } - fn make_tree_node(&self, parent: &TreeNode, dir: bool) -> TreeNode { + fn make_tree_node(&self, parent: &TreeNode, dir: bool) -> TreeNode { let mut bit_str = crate::bits_to_bitstring(parent.path.as_slice()); bit_str.push(if dir { '1' } else { '0' }); - let (key_states, key_values): (Vec, Vec) = self + let (key_states, key_values): (Vec, Vec>) = self .keys .par_iter() .enumerate() - .map(|(i, key)| key.1.eval_bit(&parent.key_states[i], dir, &bit_str)) + .map(|(i, key)| { + key.1 + .eval_bit(&parent.key_states[i], dir, &bit_str, self.typ.input_len()) + }) .unzip(); - let mut child_val = T::zero(); + let mut child_val = vec![Field64::from(0); &self.typ.input_len() + 1]; key_values .iter() .zip(&self.keys) .filter(|&(_, key)| key.0) - .for_each(|(&v, _)| child_val.add_assign(v)); + .for_each(|(v, _)| vec_add(&mut child_val, v)); - let mut child = TreeNode:: { + let mut child = TreeNode:: { path: parent.path.clone(), value: child_val, key_states, @@ -139,7 +143,7 @@ where child } - pub fn run_flp_queries(&mut self, start: usize, end: usize) -> Vec> { + pub fn run_flp_queries(&mut self, start: usize, end: usize) -> Vec> { let level = self.frontier[0].path.len(); assert_eq!(level, 0); @@ -160,25 +164,49 @@ where .enumerate() .filter(|(client_index, _)| *client_index >= start && *client_index < end) .map(|(client_index, _)| { - let y_p0 = node_left.key_values[client_index]; - let y_p1 = node_right.key_values[client_index]; + let y_p0 = &node_left.key_values[client_index]; + let y_p1 = &node_right.key_values[client_index]; - let mut beta_share = T::zero(); - beta_share.add_assign(y_p0); - beta_share.add_assign(y_p1); + let mut beta_share = vec![Field64::from(0); self.typ.input_len()]; + vec_add(&mut beta_share, y_p0); + vec_add(&mut beta_share, y_p1); let flp_proof_share = &self.all_flp_proof_shares[client_index]; let query_rand_xof = XofShake128::init(&self.verify_key, &self.nonces[client_index]); - let query_rand: Vec = query_rand_xof + let query_rand: Vec = query_rand_xof .clone() .into_seed_stream() - .into_field_vec(self.count.query_rand_len()); + .into_field_vec(self.typ.query_rand_len()); + + let mut jr_parts = self.jr_parts[client_index]; + if self.server_id == 0 { + let mut jr_part_xof = XofShake128::init( + &self.keys[client_index].1.get_root_seed().key, + &[0u8; 16], + ); + jr_part_xof.update(&[0]); // Aggregator ID + jr_part_xof.update(&self.nonces[client_index]); + jr_part_xof.into_seed_stream().fill_bytes(&mut jr_parts[0]); + } else { + let mut jr_part_xof = XofShake128::init( + &self.keys[client_index].1.get_root_seed().key, + &[0u8; 16], + ); + jr_part_xof.update(&[1]); // Aggregator ID + jr_part_xof.update(&self.nonces[client_index]); + jr_part_xof.into_seed_stream().fill_bytes(&mut jr_parts[1]); + } + + let joint_rand_xof = XofShake128::init(&jr_parts[0], &jr_parts[1]); + let joint_rand: Vec = joint_rand_xof + .into_seed_stream() + .into_field_vec(self.typ.joint_rand_len()); // Compute the flp_verifier_share. - self.count - .query(&[beta_share], flp_proof_share, &query_rand, &[], 2) + self.typ + .query(&beta_share, flp_proof_share, &query_rand, &joint_rand, 2) .unwrap() }) .collect::>() @@ -189,7 +217,7 @@ where mut split_by: usize, malicious: &Vec, is_last: bool, - ) -> (Vec, Vec>, Vec) { + ) -> (Vec>, Vec>, Vec) { if !malicious.is_empty() { println!("Malicious is not empty!!"); @@ -215,13 +243,13 @@ where vec![child_0, child_1] }) - .collect::>>(); + .collect::>>(); // These are summed evaluations y for different prefixes. let cnt_values = next_frontier .par_iter() - .map(|node| node.value) - .collect::>(); + .map(|node| node.value.clone()) + .collect::>>(); // For all prefixes, compute the checks for each client. let all_y_checks = self @@ -239,31 +267,34 @@ where node.key_values .par_iter() .enumerate() - .map(|(client_index, &y_p)| { - let y_p0 = node_left.key_values[client_index]; - let y_p1 = node_right.key_values[client_index]; + .map(|(client_index, y_p)| { + let y_p0 = &node_left.key_values[client_index]; + let y_p1 = &node_right.key_values[client_index]; - let mut value_check = T::zero(); + let mut value_check = vec![Field64::from(0); &self.typ.input_len() + 1]; if level == 0 { // (1 - server_id) + (-1)^server_id * (- y^{p||0} - y^{p||1}) if self.server_id == 0 { - value_check.add_assign(T::one()); - value_check.sub_assign(y_p0); - value_check.sub_assign(y_p1); + vec_add( + &mut value_check, + &vec![Field64::from(1); &self.typ.input_len() + 1], + ); + vec_sub(&mut value_check, y_p0); + vec_sub(&mut value_check, y_p1); } else { - value_check.add_assign(y_p0); - value_check.add_assign(y_p1); + vec_add(&mut value_check, y_p0); + vec_add(&mut value_check, y_p1); } } else { // (-1)^server_id * (y^{p} - y^{p||0} - y^{p||1}) if self.server_id == 0 { - value_check.add_assign(y_p); - value_check.sub_assign(y_p0); - value_check.sub_assign(y_p1); + vec_add(&mut value_check, y_p); + vec_sub(&mut value_check, y_p0); + vec_sub(&mut value_check, y_p1); } else { - value_check.add_assign(y_p0); - value_check.add_assign(y_p1); - value_check.sub_assign(y_p); + vec_add(&mut value_check, y_p0); + vec_add(&mut value_check, y_p1); + vec_sub(&mut value_check, y_p); } } @@ -290,7 +321,19 @@ where // for each client. let mut check = [0u8; 8]; all_y_checks.iter().for_each(|checks_for_prefix| { - xor_in_place(&mut check, &checks_for_prefix[client_index].get_encoded()); + if level == 0 { + xor_in_place( + &mut check, + &checks_for_prefix[client_index][self.typ.input_len()].get_encoded(), + ); + } else { + for i in 0..self.typ.input_len() { + xor_in_place( + &mut check, + &checks_for_prefix[client_index][i].get_encoded(), + ); + } + } }); xor_vec( @@ -346,7 +389,7 @@ where (cnt_values, mtree_roots, mtree_indices) } - pub fn tree_crawl_last(&mut self) -> Vec { + pub fn tree_crawl_last(&mut self) -> Vec> { let next_frontier = self .frontier .par_iter() @@ -357,7 +400,7 @@ where vec![child_0, child_1] }) - .collect::>>(); + .collect::>>(); self.final_proofs = self .keys @@ -378,8 +421,8 @@ where // These are summed evaluations y for different prefixes. self.frontier .par_iter() - .map(|node| node.value) - .collect::>() + .map(|node| node.value.clone()) + .collect::>>() } pub fn get_proofs(&self, start: usize, end: usize) -> Vec<[u8; HASH_SIZE]> { @@ -414,46 +457,41 @@ where } } - pub fn keep_values(threshold: u64, cnt_values_0: &[T], cnt_values_1: &[T]) -> Vec { + pub fn keep_values( + input_len: usize, + threshold: u64, + cnt_values_0: &[Vec], + cnt_values_1: &[Vec], + ) -> Vec { cnt_values_0 .par_iter() .zip(cnt_values_1.par_iter()) - .map(|(&value_0, &value_1)| { - let v = value_0 + value_1; + .map(|(value_0, value_1)| { + // Note, this assumes that the pruning happens based on the last element (i.e., + // the counter). + let v = value_0[input_len] + value_1[input_len]; u64::from(v) >= threshold }) .collect::>() } - pub fn final_shares(&self) -> Vec> { + pub fn final_shares(&self) -> Vec { self.frontier .par_iter() - .map(|n| Result:: { + .map(|n| Result { path: n.path.clone(), - value: n.value, + value: n.value.clone(), }) .collect::>() } // Reconstruct counters based on shares - pub fn reconstruct_shares(results_0: &[T], results_1: &[T]) -> Vec { - assert_eq!(results_0.len(), results_1.len()); - - results_0 - .par_iter() - .zip_eq(results_1) - .map(|(&v1, &v2)| { - let mut v = T::zero(); - v.add_assign(v1); - v.add_assign(v2); - v - }) - .collect() - } - - // Reconstruct counters based on shares - pub fn final_values(results_0: &[Result], results_1: &[Result]) -> Vec> { + pub fn final_values( + input_len: usize, + results_0: &[Result], + results_1: &[Result], + ) -> Vec { assert_eq!(results_0.len(), results_1.len()); results_0 @@ -462,16 +500,16 @@ where .map(|(r0, r1)| { assert_eq!(r0.path, r1.path); - let mut v = T::zero(); - v.add_assign(r0.value); - v.add_assign(r1.value); + let mut v = vec![Field64::from(0); input_len + 1]; + vec_add(&mut v, &r0.value); + vec_add(&mut v, &r1.value); Result { path: r0.path.clone(), value: v, } }) - .filter(|result| result.value > T::zero()) + .filter(|result| result.value[input_len] > Field64::from(0)) .collect::>() } } diff --git a/src/config.rs b/src/config.rs index 5e9440d..b886682 100644 --- a/src/config.rs +++ b/src/config.rs @@ -5,6 +5,7 @@ use serde_json::Value; pub struct Config { pub data_bytes: usize, + pub range_bits: usize, pub add_key_batch_size: usize, pub flp_batch_size: usize, pub unique_buckets: usize, @@ -22,6 +23,7 @@ pub fn get_config(filename: &str) -> Config { let json_data = &fs::read_to_string(filename).expect("Cannot open JSON file"); let v: Value = serde_json::from_str(json_data).expect("Cannot parse JSON config"); + let range_bits: usize = v["range_bits"].as_u64().expect("Can't parse range_bits") as usize; let data_bytes: usize = v["data_bytes"].as_u64().expect("Can't parse data_bytes") as usize; let add_key_batch_size: usize = v["add_key_batch_size"] .as_u64() @@ -41,6 +43,7 @@ pub fn get_config(filename: &str) -> Config { Config { data_bytes, + range_bits, add_key_batch_size, flp_batch_size, unique_buckets, @@ -57,18 +60,15 @@ pub fn get_args( get_n_reqs: bool, get_malicious: bool, ) -> (Config, i8, usize, f32) { - let mut flags = App::new(name) - .version("0.1") - .about("Privacy-preserving heavy-hitters for location data.") - .arg( - Arg::with_name("config") - .short("c") - .long("config") - .value_name("FILENAME") - .help("Location of JSON config file") - .required(true) - .takes_value(true), - ); + let mut flags = App::new(name).version("0.1").about("Mastic.").arg( + Arg::with_name("config") + .short("c") + .long("config") + .value_name("FILENAME") + .help("Location of JSON config file") + .required(true) + .takes_value(true), + ); if get_server_id { flags = flags.arg( Arg::with_name("server_id") @@ -82,9 +82,9 @@ pub fn get_args( } if get_n_reqs { flags = flags.arg( - Arg::with_name("num_requests") + Arg::with_name("num_clients") .short("n") - .long("num_requests") + .long("num_clients") .value_name("NUMBER") .help("Number of client requests to generate") .required(true) @@ -112,7 +112,7 @@ pub fn get_args( let mut n_reqs = 0; if get_n_reqs { - n_reqs = flags.value_of("num_requests").unwrap().parse().unwrap(); + n_reqs = flags.value_of("num_clients").unwrap().parse().unwrap(); } let mut malicious = 0.0; diff --git a/src/lib.rs b/src/lib.rs index 30eb072..c26c791 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -10,7 +10,9 @@ use prio::field::Field64; pub use crate::rpc::CollectorClient; -pub const HASH_SIZE: usize = 12; +pub type BetaType = Vec; + +pub const HASH_SIZE: usize = 16; impl crate::prg::FromRng for Field64 { fn from_rng(&mut self, rng: &mut impl rand::Rng) { @@ -81,6 +83,31 @@ pub fn xor_vec(v1: &[u8], v2: &[u8]) -> Vec { v1.iter().zip(v2.iter()).map(|(&x1, &x2)| x1 ^ x2).collect() } +pub fn vec_add(v1: &mut [T], v2: &[T]) +where + T: prg::FromRng + Clone + prio::field::FieldElement + std::fmt::Debug, +{ + v1.iter_mut() + .zip(v2.iter()) + .for_each(|(x1, &x2)| x1.add_assign(x2)); +} + +pub fn vec_sub(v1: &mut [T], v2: &[T]) +where + T: prg::FromRng + Clone + prio::field::FieldElement + std::fmt::Debug, +{ + v1.iter_mut() + .zip(v2.iter()) + .for_each(|(x1, &x2)| x1.sub_assign(x2)); +} + +pub fn vec_neg(v1: &mut [T]) +where + T: prg::FromRng + Clone + prio::field::FieldElement + std::fmt::Debug, +{ + v1.iter_mut().for_each(|x| *x = x.neg()); +} + pub fn xor_in_place(v1: &mut [u8], v2: &[u8]) { for (x1, &x2) in v1.iter_mut().zip(v2.iter()) { *x1 ^= x2; diff --git a/src/prg.rs b/src/prg.rs index 8cd0f53..dea4553 100644 --- a/src/prg.rs +++ b/src/prg.rs @@ -88,10 +88,13 @@ impl PrgSeed { self.expand_dir(true, true) } - pub fn convert(self: &PrgSeed) -> ConvertOutput { + pub fn convert( + self: &PrgSeed, + input_len: usize, + ) -> ConvertOutput { let mut out = ConvertOutput { seed: PrgSeed::zero(), - word: T::zero(), + word: vec![T::zero(); input_len + 1], }; FIXED_KEY_STREAM.with(|s_in| { @@ -100,7 +103,9 @@ impl PrgSeed { s.fill_bytes(&mut out.seed.key); unsafe { let sp = s_in.as_ptr(); - out.word.from_rng(&mut *sp); + for i in 0..input_len { + out.word[i].from_rng(&mut *sp); + } } }); @@ -191,7 +196,7 @@ pub struct PrgOutput { pub struct ConvertOutput { pub seed: PrgSeed, - pub word: T, + pub word: Vec, } impl FixedKeyPrgStream { diff --git a/src/rpc.rs b/src/rpc.rs index 702aabd..2e68202 100644 --- a/src/rpc.rs +++ b/src/rpc.rs @@ -1,7 +1,7 @@ use prio::field::Field64; use serde::{Deserialize, Serialize}; -use crate::{collect, vidpf, HASH_SIZE}; +use crate::{collect, vidpf, BetaType, HASH_SIZE}; #[derive(Clone, Debug, Serialize, Deserialize)] pub struct ResetRequest { @@ -10,13 +10,14 @@ pub struct ResetRequest { #[derive(Clone, Debug, Serialize, Deserialize)] pub struct AddKeysRequest { - pub keys: Vec>, + pub keys: Vec, } #[derive(Clone, Debug, Serialize, Deserialize)] pub struct AddFLPsRequest { pub flp_proof_shares: Vec>, pub nonces: Vec<[u8; 16]>, + pub jr_parts: Vec<[[u8; 16]; 2]>, } #[derive(Clone, Debug, Serialize, Deserialize)] @@ -64,10 +65,10 @@ pub trait Collector { async fn add_all_flp_proof_shares(req: AddFLPsRequest) -> String; async fn run_flp_queries(req: RunFlpQueriesRequest) -> Vec>; async fn apply_flp_results(req: ApplyFLPResultsRequest) -> String; - async fn tree_crawl(req: TreeCrawlRequest) -> (Vec, Vec>, Vec); - async fn tree_crawl_last(req: TreeCrawlLastRequest) -> Vec; + async fn tree_crawl(req: TreeCrawlRequest) -> (Vec, Vec>, Vec); + async fn tree_crawl_last(req: TreeCrawlLastRequest) -> Vec; async fn get_proofs(req: GetProofsRequest) -> Vec<[u8; HASH_SIZE]>; async fn tree_init(req: TreeInitRequest) -> String; async fn tree_prune(req: TreePruneRequest) -> String; - async fn final_shares(req: FinalSharesRequest) -> Vec>; + async fn final_shares(req: FinalSharesRequest) -> Vec; } diff --git a/src/vidpf.rs b/src/vidpf.rs index df8f10f..ebb4f1f 100644 --- a/src/vidpf.rs +++ b/src/vidpf.rs @@ -1,20 +1,21 @@ use blake3::Hasher; +use prio::field::Field64; use serde::{Deserialize, Serialize}; -use crate::{prg, xor_three_vecs, xor_vec, HASH_SIZE}; +use crate::{prg, vec_add, vec_neg, vec_sub, xor_three_vecs, xor_vec, BetaType, HASH_SIZE}; #[derive(Clone, Debug, Serialize, Deserialize)] -struct CorWord { +struct CorWord { seed: prg::PrgSeed, bits: (bool, bool), - word: T, + word: BetaType, } #[derive(Clone, Debug, Serialize, Deserialize)] -pub struct VIDPFKey { +pub struct VIDPFKey { pub key_idx: bool, root_seed: prg::PrgSeed, - cor_words: Vec>, + cor_words: Vec, pub cs: Vec<[u8; HASH_SIZE]>, } @@ -78,15 +79,12 @@ impl TupleExt for (T, T) { } } -fn gen_cor_word( +fn gen_cor_word( bit: bool, - beta: W, + beta: &BetaType, bits: &mut (bool, bool), seeds: &mut (prg::PrgSeed, prg::PrgSeed), -) -> CorWord -where - W: prg::FromRng + Clone + prio::field::FieldElement + std::fmt::Debug, -{ +) -> CorWord { let data = seeds.map(|s| s.expand()); // If alpha[i] = 0: @@ -102,7 +100,7 @@ where data.0.bits.0 ^ data.1.bits.0 ^ bit ^ true, data.0.bits.1 ^ data.1.bits.1 ^ bit, ), - word: W::zero(), + word: beta.clone(), }; for (b, seed) in seeds.iter_mut() { @@ -120,15 +118,16 @@ where *bits.get_mut(b) = newbit; } - let converted = seeds.map(|s| s.convert()); - cw.word = beta; - cw.word.sub_assign(converted.0.word); - cw.word.add_assign(converted.1.word); + let input_len = beta.len(); + let converted = seeds.map(|s| s.convert(input_len)); + // Counter is last + cw.word.push(Field64::from(1)); + vec_sub(&mut cw.word, &converted.0.word); + vec_add(&mut cw.word, &converted.1.word); if bits.1 { - cw.word = cw.word.neg(); + vec_neg(&mut cw.word); } - seeds.0 = converted.0.seed; seeds.1 = converted.1.seed; @@ -136,11 +135,12 @@ where } /// All-prefix DPF implementation. -impl VIDPFKey -where - T: prg::FromRng + Clone + prio::field::FieldElement + std::fmt::Debug, -{ - pub fn gen(alpha_bits: &[bool], beta: T) -> (VIDPFKey, VIDPFKey) { +impl VIDPFKey { + pub fn get_root_seed(&self) -> prg::PrgSeed { + self.root_seed.clone() + } + + pub fn gen(alpha_bits: &[bool], beta: &BetaType) -> (VIDPFKey, VIDPFKey) { let root_seeds = (prg::PrgSeed::random(), prg::PrgSeed::random()); let root_bits = (false, true); @@ -148,12 +148,12 @@ where let mut bits = root_bits; let mut hasher = Hasher::new(); - let mut cor_words: Vec> = Vec::new(); + let mut cor_words: Vec = Vec::new(); let mut cs: Vec<[u8; HASH_SIZE]> = Vec::new(); let mut bit_str: String = "".to_string(); for &bit in alpha_bits { bit_str.push_str(if bit { "1" } else { "0" }); - let cw = gen_cor_word::(bit, beta, &mut bits, &mut seeds); + let cw = gen_cor_word(bit, beta, &mut bits, &mut seeds); cor_words.push(cw); let pi_0 = { @@ -176,13 +176,13 @@ where } ( - VIDPFKey:: { + VIDPFKey { key_idx: false, root_seed: root_seeds.0, cor_words: cor_words.clone(), cs: cs.clone(), }, - VIDPFKey:: { + VIDPFKey { key_idx: true, root_seed: root_seeds.1, cor_words, @@ -191,7 +191,13 @@ where ) } - pub fn eval_bit(&self, state: &EvalState, dir: bool, bit_str: &String) -> (EvalState, T) { + pub fn eval_bit( + &self, + state: &EvalState, + dir: bool, + bit_str: &String, + input_len: usize, + ) -> (EvalState, Vec) { let tau = state.seed.expand_dir(!dir, dir); let mut seed = tau.seeds.get(dir).clone(); let mut new_bit = *tau.bits.get(dir); @@ -201,16 +207,16 @@ where new_bit ^= self.cor_words[state.level].bits.get(dir); } - let converted = seed.convert::(); + let converted = seed.convert(input_len); let new_seed = converted.seed; let mut word = converted.word; if new_bit { - word.add_assign(self.cor_words[state.level].word); + vec_add(&mut word, &self.cor_words[state.level].word); } if self.key_idx { - word = word.neg(); + vec_neg(&mut word); } // Compute proofs @@ -263,7 +269,12 @@ where } } - pub fn eval(&self, idx: &[bool], pi: &mut [u8; HASH_SIZE]) -> (Vec, T) { + pub fn eval( + &self, + idx: &[bool], + pi: &mut [u8; HASH_SIZE], + input_len: usize, + ) -> (Vec>, Vec) { debug_assert!(idx.len() <= self.domain_size()); debug_assert!(!idx.is_empty()); let mut out = vec![]; @@ -275,18 +286,18 @@ where for &bit in idx.iter().take(idx.len() - 1) { bit_str.push(if bit { '1' } else { '0' }); - let (state_new, word) = self.eval_bit(&state, bit, &bit_str); + let (state_new, word) = self.eval_bit(&state, bit, &bit_str, input_len); out.push(word); state = state_new; } - let (_, last) = self.eval_bit(&state, *idx.last().unwrap(), &bit_str); + let (_, last) = self.eval_bit(&state, *idx.last().unwrap(), &bit_str, input_len); *pi = state.proof; (out, last) } - pub fn gen_from_str(s: &str, beta: T) -> (Self, Self) { + pub fn gen_from_str(s: &str, beta: &BetaType) -> (Self, Self) { let bits = crate::string_to_bits(s); VIDPFKey::gen(&bits, beta) } diff --git a/tests/collect_test.rs b/tests/collect_test.rs index 3e5154b..63cb865 100644 --- a/tests/collect_test.rs +++ b/tests/collect_test.rs @@ -1,5 +1,8 @@ use mastic::{collect::*, prg, *}; -use prio::field::{Field64, FieldElement}; +use prio::{ + field::Field64, + flp::{types::Sum, Type}, +}; use rand::{thread_rng, Rng}; use rayon::prelude::*; @@ -15,11 +18,13 @@ fn collect_test_eval_groups() { let mut verify_key = [0; 16]; thread_rng().fill(&mut verify_key); - let mut col_0 = KeyCollection::new(0, &seed, strlen, verify_key); - let mut col_1 = KeyCollection::new(1, &seed, strlen, verify_key); + let typ = Sum::::new(2).unwrap(); + let mut col_0 = KeyCollection::new(typ.clone(), 0, &seed, strlen, verify_key); + let mut col_1 = KeyCollection::new(typ.clone(), 1, &seed, strlen, verify_key); for cstr in &client_strings { - let (keys_0, keys_1) = vidpf::VIDPFKey::::gen_from_str(&cstr, Field64::one()); + let input_beta = typ.encode_measurement(&3u64).unwrap(); + let (keys_0, keys_1) = vidpf::VIDPFKey::gen_from_str(&cstr, &input_beta); col_0.add_key(keys_0); col_1.add_key(keys_1); } @@ -35,7 +40,8 @@ fn collect_test_eval_groups() { let (cnt_values_1, _, _) = col_1.tree_crawl(1usize, &malicious, false); assert_eq!(cnt_values_0.len(), cnt_values_1.len()); - let keep = KeyCollection::::keep_values(threshold, &cnt_values_0, &cnt_values_1); + let keep = + KeyCollection::keep_values(typ.input_len(), threshold, &cnt_values_0, &cnt_values_1); col_0.tree_prune(&keep); col_1.tree_prune(&keep); @@ -56,7 +62,7 @@ fn collect_test_eval_groups() { .all(|(&h0, &h1)| h0 == h1); assert!(verified); - let keep = KeyCollection::::keep_values(threshold, &cnt_values_0, &cnt_values_1); + let keep = KeyCollection::keep_values(typ.input_len(), threshold, &cnt_values_0, &cnt_values_1); col_0.tree_prune(&keep); col_1.tree_prune(&keep); @@ -64,14 +70,14 @@ fn collect_test_eval_groups() { let shares_0 = col_0.final_shares(); let shares_1 = col_1.final_shares(); - for res in &KeyCollection::::final_values(&shares_0, &shares_1) { + for res in &KeyCollection::final_values(typ.input_len(), &shares_0, &shares_1) { println!("Path = {:?}", res.path); let s = crate::bits_to_string(&res.path); println!("fast: {:?} = {:?}", s, res.value); match &s[..] { - "abdef" => assert_eq!(res.value, 4u64), - "gZ???" => assert_eq!(res.value, 3u64), + "abdef" => assert_eq!(res.value, vec![4u64, 4u64, 4u64]), + "gZ???" => assert_eq!(res.value, vec![3u64, 3u64, 3u64]), _ => { println!("Unexpected string: '{:?}' = {:?}", s, res.value); assert!(false); @@ -95,13 +101,15 @@ fn collect_test_eval_full_groups() { let seed = prg::PrgSeed::random(); let mut verify_key = [0; 16]; thread_rng().fill(&mut verify_key); - let mut col_0 = KeyCollection::new(0, &seed, strlen, verify_key); - let mut col_1 = KeyCollection::new(1, &seed, strlen, verify_key); + let typ = Sum::::new(2).unwrap(); + let mut col_0 = KeyCollection::new(typ.clone(), 0, &seed, strlen, verify_key); + let mut col_1 = KeyCollection::new(typ.clone(), 1, &seed, strlen, verify_key); let mut keys = vec![]; println!("Starting to generate keys"); for s in &client_strings { - keys.push(vidpf::VIDPFKey::::gen_from_str(&s, Field64::one())); + let input_beta = typ.encode_measurement(&1u64).unwrap(); + keys.push(vidpf::VIDPFKey::gen_from_str(&s, &input_beta)); } println!("Done generating keys"); @@ -128,7 +136,8 @@ fn collect_test_eval_full_groups() { println!("At level {:?} (size: {:?})", level, cnt_values_0.len()); assert_eq!(cnt_values_0.len(), cnt_values_1.len()); - let keep = KeyCollection::::keep_values(threshold, &cnt_values_0, &cnt_values_1); + let keep = + KeyCollection::keep_values(typ.input_len(), threshold, &cnt_values_0, &cnt_values_1); col_0.tree_prune(&keep); col_1.tree_prune(&keep); @@ -149,7 +158,7 @@ fn collect_test_eval_full_groups() { .all(|(&h0, &h1)| h0 == h1); assert!(verified); - let keep = KeyCollection::::keep_values(threshold, &cnt_values_0, &cnt_values_1); + let keep = KeyCollection::keep_values(typ.input_len(), threshold, &cnt_values_0, &cnt_values_1); col_0.tree_prune(&keep); col_1.tree_prune(&keep); @@ -157,7 +166,7 @@ fn collect_test_eval_full_groups() { let s0 = col_0.final_shares(); let s1 = col_1.final_shares(); - for res in &KeyCollection::::final_values(&s0, &s1) { + for res in &KeyCollection::final_values(typ.input_len(), &s0, &s1) { println!("Path = {:?}", res.path); let s = crate::bits_to_string(&res.path); println!("Value: {:?} = {:?}", s, res.value); diff --git a/tests/dpf_test.rs b/tests/dpf_test.rs index f46f024..bb31b2c 100644 --- a/tests/dpf_test.rs +++ b/tests/dpf_test.rs @@ -8,8 +8,8 @@ use prio::field::Field64; fn dpf_complete() { let num_bits = 5; let alpha = u32_to_bits(num_bits, 21); - let beta = Field64::from(7u64); - let (key_0, key_1) = VIDPFKey::gen(&alpha, beta); + let beta = vec![Field64::from(7u64)]; + let (key_0, key_1) = VIDPFKey::gen(&alpha, &beta); let mut pi_0: [u8; HASH_SIZE] = hash(b"0").as_bytes()[0..HASH_SIZE].try_into().unwrap(); let mut pi_1: [u8; HASH_SIZE] = pi_0.clone(); @@ -19,14 +19,14 @@ fn dpf_complete() { println!("Alpha: {:?}", alpha); for j in 2..((num_bits - 1) as usize) { - let eval_0 = key_0.eval(&alpha_eval[0..j].to_vec(), &mut pi_0); - let eval_1 = key_1.eval(&alpha_eval[0..j].to_vec(), &mut pi_1); + let eval_0 = key_0.eval(&alpha_eval[0..j].to_vec(), &mut pi_0, 1); + let eval_1 = key_1.eval(&alpha_eval[0..j].to_vec(), &mut pi_1, 1); - let tmp = eval_0.0[j - 2].add(eval_1.0[j - 2]); + let tmp = eval_0.0[j - 2][0].add(eval_1.0[j - 2][0]); println!("[{:?}] Tmp {:?} = {:?}", alpha_eval, j, tmp); if alpha[0..j - 1] == alpha_eval[0..j - 1] { assert_eq!( - beta, tmp, + beta[0], tmp, "[Level {:?}] Value incorrect at {:?}", j, alpha_eval ); diff --git a/tests/flp_test.rs b/tests/flp_test.rs index 40cad80..21e4252 100644 --- a/tests/flp_test.rs +++ b/tests/flp_test.rs @@ -7,6 +7,7 @@ use prio::{ vdaf::xof::{IntoFieldVec, Xof, XofShake128}, }; use rand::{thread_rng, Rng}; +use rand_core::RngCore; use rayon::prelude::{IndexedParallelIterator, IntoParallelRefIterator, ParallelIterator}; #[test] @@ -44,12 +45,12 @@ fn flp_random_beta_in_range() { assert!(run_flp_with_input(&verify_key, &sum, 100).is_err()); } -fn run_flp_with_input(verify_key: &[u8; 16], sum: &T, input: u64) -> Result +fn run_flp_with_input(verify_key: &[u8; 16], typ: &T, input: u64) -> Result where T: prio::flp::Type, { // 1. The Prover chooses a measurement and secret shares the input. - let input: Vec = sum.encode_measurement(&input)?; + let input: Vec = typ.encode_measurement(&input)?; let input_0 = input .iter() .map(|_| Field64::from(rand::thread_rng().gen::())) @@ -65,26 +66,38 @@ where // 2. The Prover generates prove_rand and query_rand (should be unique per proof). The Prover // uses prover_rand to generate the proof. Finally, the Prover secret shares the proof. - let prove_rand = random_vector(sum.prove_rand_len()).unwrap(); + let prove_rand = random_vector(typ.prove_rand_len()).unwrap(); let query_rand_xof = XofShake128::init(&verify_key, &nonce); let query_rand: Vec = query_rand_xof .clone() .into_seed_stream() - .into_field_vec(sum.query_rand_len()); - let joint_rand: Vec = { - // Assume that we have the two VIDPF seeds. - let mut vidpf_seeds = ([0u8; 16], [0u8; 16]); - thread_rng().fill(&mut vidpf_seeds.0); - thread_rng().fill(&mut vidpf_seeds.1); - let mut joint_rand_xof = XofShake128::init(&vidpf_seeds.0, &vidpf_seeds.1); - joint_rand_xof.update(&nonce); - joint_rand_xof - .clone() - .into_seed_stream() - .into_field_vec(sum.joint_rand_len()) - }; - - let proof = sum.prove(&input, &prove_rand, &joint_rand).unwrap(); + .into_field_vec(typ.query_rand_len()); + let mut vidpf_seeds = ([0u8; 16], [0u8; 16]); + thread_rng().fill(&mut vidpf_seeds.0); + thread_rng().fill(&mut vidpf_seeds.1); + + let mut jr_parts = [[0u8; 16]; 2]; + // Assume that we have the two VIDPF seeds. + let mut jr_part_0_xof = XofShake128::init(&vidpf_seeds.0, &[0u8; 16]); + jr_part_0_xof.update(&[0]); // Aggregator ID + jr_part_0_xof.update(&nonce); + jr_part_0_xof + .into_seed_stream() + .fill_bytes(&mut jr_parts[0]); + + let mut jr_part_1_xof = XofShake128::init(&vidpf_seeds.1, &[0u8; 16]); + jr_part_1_xof.update(&[1]); // Aggregator ID + jr_part_1_xof.update(&nonce); + jr_part_1_xof + .into_seed_stream() + .fill_bytes(&mut jr_parts[1]); + + let joint_rand_xof = XofShake128::init(&jr_parts[0], &jr_parts[1]); + let joint_rand: Vec = joint_rand_xof + .into_seed_stream() + .into_field_vec(typ.joint_rand_len()); + + let proof = typ.prove(&input, &prove_rand, &joint_rand).unwrap(); let proof_0 = proof .iter() .map(|_| Field64::from(rand::thread_rng().gen::())) @@ -98,10 +111,10 @@ where // 3. The Verifiers are provided with the nonce for each Client and can generate the query_rand // (should be the same between the Verifiers). Each Verifier queries the input and proof // shares and receives a verifier_share. - let verifier_0 = sum + let verifier_0 = typ .query(&input_0, &proof_0, &query_rand, &joint_rand, 2) .unwrap(); - let verifier_1 = sum + let verifier_1 = typ .query(&input_1, &proof_1, &query_rand, &joint_rand, 2) .unwrap(); @@ -112,5 +125,5 @@ where .map(|(v1, v2)| v1 + v2) .collect::>(); - sum.decide(&verifier) + typ.decide(&verifier) }