diff --git a/src/bin/driver.rs b/src/bin/driver.rs index 0fb4524..17cc3c2 100644 --- a/src/bin/driver.rs +++ b/src/bin/driver.rs @@ -11,7 +11,8 @@ use mastic::{ GetProofsRequest, ResetRequest, RunFlpQueriesRequest, TreeCrawlLastRequest, TreeCrawlRequest, TreeInitRequest, TreePruneRequest, }, - vidpf, BetaType, CollectorClient, + vidpf::{self, VidpfKey}, + CollectorClient, }; use prio::{ field::{random_vector, Field64}, @@ -23,9 +24,6 @@ use rand_core::RngCore; use rayon::prelude::*; use tarpc::{client, context, serde_transport::tcp, tokio_serde::formats::Bincode}; -type Key = vidpf::VIDPFKey; -type Client = CollectorClient; - fn long_context() -> context::Context { let mut ctx = context::current(); @@ -45,28 +43,31 @@ fn sample_string(len: usize) -> String { fn generate_keys( cfg: &config::Config, typ: &Sum, -) -> ((Vec, Vec), Vec>) { - let (keys, values): ((Vec, Vec), Vec>) = rayon::iter::repeat(0) - .take(cfg.unique_buckets) - .map(|_| { - // Generate a random number in the specified range - let beta = rand::thread_rng().gen_range(1..(1 << cfg.range_bits)); - let input_beta: BetaType = typ.encode_measurement(&beta).unwrap(); - - ( - Key::gen_from_str(&sample_string(cfg.data_bytes * 8), &input_beta), - input_beta, - ) - }) - .unzip(); +) -> ((Vec, Vec), Vec>) { + let (keys, values): ((Vec, Vec), Vec>) = + rayon::iter::repeat(0) + .take(cfg.unique_buckets) + .map(|_| { + // Generate a random number in the specified range + let beta = rand::thread_rng().gen_range(1..(1 << cfg.range_bits)); + let input_beta: Vec = typ.encode_measurement(&beta).unwrap(); + + ( + VidpfKey::gen_from_str(&sample_string(cfg.data_bytes * 8), &input_beta), + input_beta, + ) + }) + .unzip(); let encoded: Vec = bincode::serialize(&keys.0[0]).unwrap(); - println!("Key size: {:?} bytes", encoded.len()); + println!("VIDPFKey size: {:?} bytes", encoded.len()); (keys, values) } -fn generate_randomness(keys: (&Vec, &Vec)) -> (Vec<[u8; 16]>, Vec<[[u8; 16]; 2]>) { +fn generate_randomness( + keys: (&Vec, &Vec), +) -> (Vec<[u8; 16]>, Vec<[[u8; 16]; 2]>) { keys.0 .par_iter() .zip(keys.1.par_iter()) @@ -127,8 +128,8 @@ fn generate_proofs( } async fn reset_servers( - client_0: &Client, - client_1: &Client, + client_0: &CollectorClient, + client_1: &CollectorClient, verify_key: &[u8; 16], ) -> io::Result<()> { let req = ResetRequest { @@ -141,7 +142,7 @@ async fn reset_servers( Ok(()) } -async fn tree_init(client_0: &Client, client_1: &Client) -> io::Result<()> { +async fn tree_init(client_0: &CollectorClient, client_1: &CollectorClient) -> io::Result<()> { let req = TreeInitRequest {}; let resp_0 = client_0.tree_init(long_context(), req.clone()); let resp_1 = client_1.tree_init(long_context(), req); @@ -152,9 +153,9 @@ async fn tree_init(client_0: &Client, client_1: &Client) -> io::Result<()> { async fn add_keys( cfg: &config::Config, - clients: (&Client, &Client), - keys: (&[vidpf::VIDPFKey], &[vidpf::VIDPFKey]), - proofs: (&[Vec], &[Vec]), + clients: (&CollectorClient, &CollectorClient), + all_keys: (&[vidpf::VidpfKey], &[vidpf::VidpfKey]), + all_proofs: (&[Vec], &[Vec]), all_nonces: &[[u8; 16]], all_jr_parts: &[[[u8; 16]; 2]], num_clients: usize, @@ -184,11 +185,11 @@ async fn add_keys( } println!("Malicious {}", r); } - add_keys_0.push(keys.0[idx_1].clone()); - add_keys_1.push(keys.1[idx_2 % cfg.unique_buckets].clone()); + add_keys_0.push(all_keys.0[idx_1].clone()); + add_keys_1.push(all_keys.1[idx_2 % cfg.unique_buckets].clone()); - flp_proof_shares_0.push(proofs.0[idx_1].clone()); - flp_proof_shares_1.push(proofs.1[idx_3 % cfg.unique_buckets].clone()); + flp_proof_shares_0.push(all_proofs.0[idx_1].clone()); + flp_proof_shares_1.push(all_proofs.1[idx_3 % cfg.unique_buckets].clone()); nonces.push(all_nonces[idx_1]); jr_parts.push(all_jr_parts[idx_1]); @@ -226,8 +227,8 @@ async fn add_keys( async fn run_flp_queries( cfg: &config::Config, typ: &Sum, - client_0: &Client, - client_1: &Client, + client_0: &CollectorClient, + client_1: &CollectorClient, num_clients: usize, ) -> io::Result<()> { // Receive FLP query responses in chunks of cfg.flp_batch_size to avoid having huge RPC messages. @@ -273,8 +274,8 @@ async fn run_flp_queries( async fn run_level( cfg: &config::Config, typ: &Sum, - client_0: &Client, - client_1: &Client, + client_0: &CollectorClient, + client_1: &CollectorClient, num_clients: usize, ) -> io::Result<()> { let threshold = core::cmp::max(1, (cfg.threshold * (num_clients as f64)) as u64); @@ -345,8 +346,8 @@ async fn run_level( async fn run_level_last( cfg: &config::Config, typ: &Sum, - client_0: &Client, - client_1: &Client, + client_0: &CollectorClient, + client_1: &CollectorClient, num_clients: usize, ) -> io::Result<()> { let threshold = core::cmp::max(1, (cfg.threshold * (num_clients as f64)) as u64); @@ -409,12 +410,12 @@ async fn main() -> io::Result<()> { let (cfg, _, num_clients, malicious) = config::get_args("Driver", false, true, true); assert!((0.0..0.8).contains(&malicious)); println!("Running with {}% malicious clients", malicious * 100.0); - let client_0 = Client::new( + let client_0 = CollectorClient::new( client::Config::default(), tcp::connect(cfg.server_0, Bincode::default).await?, ) .spawn(); - let client_1 = Client::new( + let client_1 = CollectorClient::new( client::Config::default(), tcp::connect(cfg.server_1, Bincode::default).await?, ) diff --git a/src/bin/server.rs b/src/bin/server.rs index 2f3c856..2780af4 100644 --- a/src/bin/server.rs +++ b/src/bin/server.rs @@ -12,7 +12,7 @@ use mastic::{ GetProofsRequest, ResetRequest, RunFlpQueriesRequest, TreeCrawlLastRequest, TreeCrawlRequest, TreeInitRequest, TreePruneRequest, }, - BetaType, HASH_SIZE, + HASH_SIZE, }; use prio::{field::Field64, flp::types::Sum}; use tarpc::{ @@ -85,7 +85,7 @@ impl Collector for CollectorServer { self, _: context::Context, req: TreeCrawlRequest, - ) -> (Vec, Vec>, Vec) { + ) -> (Vec>, Vec>, Vec) { let start = Instant::now(); let split_by = req.split_by; let malicious = req.malicious; @@ -119,7 +119,7 @@ impl Collector for CollectorServer { self, _: context::Context, _req: TreeCrawlLastRequest, - ) -> Vec { + ) -> Vec> { let start = Instant::now(); let mut coll = self.arc.lock().unwrap(); diff --git a/src/collect.rs b/src/collect.rs index 35e815a..2630d7c 100644 --- a/src/collect.rs +++ b/src/collect.rs @@ -10,7 +10,7 @@ use rayon::prelude::*; use rs_merkle::{Hasher, MerkleTree}; use serde::{Deserialize, Serialize}; -use crate::{prg, vec_add, vec_sub, vidpf, xor_in_place, xor_vec, BetaType, HASH_SIZE}; +use crate::{prg, vec_add, vec_sub, vidpf, xor_in_place, xor_vec, HASH_SIZE}; #[derive(Clone)] pub struct HashAlg {} @@ -40,7 +40,7 @@ pub struct KeyCollection { server_id: i8, verify_key: [u8; 16], depth: usize, - pub keys: Vec<(bool, vidpf::VIDPFKey)>, + pub keys: Vec<(bool, vidpf::VidpfKey)>, nonces: Vec<[u8; 16]>, jr_parts: Vec<[[u8; 16]; 2]>, all_flp_proof_shares: Vec>, @@ -52,7 +52,7 @@ pub struct KeyCollection { #[derive(Debug, Clone, Serialize, Deserialize)] pub struct Result { pub path: Vec, - pub value: BetaType, + pub value: Vec, } impl KeyCollection { @@ -78,7 +78,7 @@ impl KeyCollection { } } - pub fn add_key(&mut self, key: vidpf::VIDPFKey) { + pub fn add_key(&mut self, key: vidpf::VidpfKey) { self.keys.push((true, key)); } diff --git a/src/lib.rs b/src/lib.rs index c26c791..a88702c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -10,8 +10,6 @@ use prio::field::Field64; pub use crate::rpc::CollectorClient; -pub type BetaType = Vec; - pub const HASH_SIZE: usize = 16; impl crate::prg::FromRng for Field64 { diff --git a/src/rpc.rs b/src/rpc.rs index 2e68202..3a6dae1 100644 --- a/src/rpc.rs +++ b/src/rpc.rs @@ -1,7 +1,7 @@ use prio::field::Field64; use serde::{Deserialize, Serialize}; -use crate::{collect, vidpf, BetaType, HASH_SIZE}; +use crate::{collect, vidpf, HASH_SIZE}; #[derive(Clone, Debug, Serialize, Deserialize)] pub struct ResetRequest { @@ -10,7 +10,7 @@ pub struct ResetRequest { #[derive(Clone, Debug, Serialize, Deserialize)] pub struct AddKeysRequest { - pub keys: Vec, + pub keys: Vec, } #[derive(Clone, Debug, Serialize, Deserialize)] @@ -65,8 +65,8 @@ pub trait Collector { async fn add_all_flp_proof_shares(req: AddFLPsRequest) -> String; async fn run_flp_queries(req: RunFlpQueriesRequest) -> Vec>; async fn apply_flp_results(req: ApplyFLPResultsRequest) -> String; - async fn tree_crawl(req: TreeCrawlRequest) -> (Vec, Vec>, Vec); - async fn tree_crawl_last(req: TreeCrawlLastRequest) -> Vec; + async fn tree_crawl(req: TreeCrawlRequest) -> (Vec>, Vec>, Vec); + async fn tree_crawl_last(req: TreeCrawlLastRequest) -> Vec>; async fn get_proofs(req: GetProofsRequest) -> Vec<[u8; HASH_SIZE]>; async fn tree_init(req: TreeInitRequest) -> String; async fn tree_prune(req: TreePruneRequest) -> String; diff --git a/src/vidpf.rs b/src/vidpf.rs index ebb4f1f..3288764 100644 --- a/src/vidpf.rs +++ b/src/vidpf.rs @@ -2,17 +2,17 @@ use blake3::Hasher; use prio::field::Field64; use serde::{Deserialize, Serialize}; -use crate::{prg, vec_add, vec_neg, vec_sub, xor_three_vecs, xor_vec, BetaType, HASH_SIZE}; +use crate::{prg, vec_add, vec_neg, vec_sub, xor_three_vecs, xor_vec, HASH_SIZE}; #[derive(Clone, Debug, Serialize, Deserialize)] struct CorWord { seed: prg::PrgSeed, bits: (bool, bool), - word: BetaType, + word: Vec, } #[derive(Clone, Debug, Serialize, Deserialize)] -pub struct VIDPFKey { +pub struct VidpfKey { pub key_idx: bool, root_seed: prg::PrgSeed, cor_words: Vec, @@ -81,7 +81,7 @@ impl TupleExt for (T, T) { fn gen_cor_word( bit: bool, - beta: &BetaType, + beta: &Vec, bits: &mut (bool, bool), seeds: &mut (prg::PrgSeed, prg::PrgSeed), ) -> CorWord { @@ -135,12 +135,12 @@ fn gen_cor_word( } /// All-prefix DPF implementation. -impl VIDPFKey { +impl VidpfKey { pub fn get_root_seed(&self) -> prg::PrgSeed { self.root_seed.clone() } - pub fn gen(alpha_bits: &[bool], beta: &BetaType) -> (VIDPFKey, VIDPFKey) { + pub fn gen(alpha_bits: &[bool], beta: &Vec) -> (VidpfKey, VidpfKey) { let root_seeds = (prg::PrgSeed::random(), prg::PrgSeed::random()); let root_bits = (false, true); @@ -176,13 +176,13 @@ impl VIDPFKey { } ( - VIDPFKey { + VidpfKey { key_idx: false, root_seed: root_seeds.0, cor_words: cor_words.clone(), cs: cs.clone(), }, - VIDPFKey { + VidpfKey { key_idx: true, root_seed: root_seeds.1, cor_words, @@ -297,9 +297,9 @@ impl VIDPFKey { (out, last) } - pub fn gen_from_str(s: &str, beta: &BetaType) -> (Self, Self) { + pub fn gen_from_str(s: &str, beta: &Vec) -> (Self, Self) { let bits = crate::string_to_bits(s); - VIDPFKey::gen(&bits, beta) + VidpfKey::gen(&bits, beta) } pub fn domain_size(&self) -> usize { diff --git a/tests/collect_test.rs b/tests/collect_test.rs index 63cb865..f3b4423 100644 --- a/tests/collect_test.rs +++ b/tests/collect_test.rs @@ -24,7 +24,7 @@ fn collect_test_eval_groups() { for cstr in &client_strings { let input_beta = typ.encode_measurement(&3u64).unwrap(); - let (keys_0, keys_1) = vidpf::VIDPFKey::gen_from_str(&cstr, &input_beta); + let (keys_0, keys_1) = vidpf::VidpfKey::gen_from_str(&cstr, &input_beta); col_0.add_key(keys_0); col_1.add_key(keys_1); } @@ -109,7 +109,7 @@ fn collect_test_eval_full_groups() { println!("Starting to generate keys"); for s in &client_strings { let input_beta = typ.encode_measurement(&1u64).unwrap(); - keys.push(vidpf::VIDPFKey::gen_from_str(&s, &input_beta)); + keys.push(vidpf::VidpfKey::gen_from_str(&s, &input_beta)); } println!("Done generating keys"); @@ -119,7 +119,7 @@ fn collect_test_eval_full_groups() { col_0.add_key(copy_0); col_1.add_key(copy_1); if i % 50 == 0 { - println!(" Key {:?}", i); + println!(" VIDPFKey {:?}", i); } } diff --git a/tests/dpf_test.rs b/tests/dpf_test.rs index bb31b2c..ff7d393 100644 --- a/tests/dpf_test.rs +++ b/tests/dpf_test.rs @@ -9,7 +9,7 @@ fn dpf_complete() { let num_bits = 5; let alpha = u32_to_bits(num_bits, 21); let beta = vec![Field64::from(7u64)]; - let (key_0, key_1) = VIDPFKey::gen(&alpha, &beta); + let (key_0, key_1) = VidpfKey::gen(&alpha, &beta); let mut pi_0: [u8; HASH_SIZE] = hash(b"0").as_bytes()[0..HASH_SIZE].try_into().unwrap(); let mut pi_1: [u8; HASH_SIZE] = pi_0.clone();