Skip to content

Commit

Permalink
Style changes
Browse files Browse the repository at this point in the history
  • Loading branch information
jimouris committed Dec 18, 2023
1 parent 2f63a4e commit 7150112
Show file tree
Hide file tree
Showing 8 changed files with 64 additions and 65 deletions.
77 changes: 39 additions & 38 deletions src/bin/driver.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@ use mastic::{
GetProofsRequest, ResetRequest, RunFlpQueriesRequest, TreeCrawlLastRequest,
TreeCrawlRequest, TreeInitRequest, TreePruneRequest,
},
vidpf, BetaType, CollectorClient,
vidpf::{self, VidpfKey},
CollectorClient,
};
use prio::{
field::{random_vector, Field64},
Expand All @@ -23,9 +24,6 @@ use rand_core::RngCore;
use rayon::prelude::*;
use tarpc::{client, context, serde_transport::tcp, tokio_serde::formats::Bincode};

type Key = vidpf::VIDPFKey;
type Client = CollectorClient;

fn long_context() -> context::Context {
let mut ctx = context::current();

Expand All @@ -45,28 +43,31 @@ fn sample_string(len: usize) -> String {
fn generate_keys(
cfg: &config::Config,
typ: &Sum<Field64>,
) -> ((Vec<Key>, Vec<Key>), Vec<Vec<Field64>>) {
let (keys, values): ((Vec<Key>, Vec<Key>), Vec<Vec<Field64>>) = rayon::iter::repeat(0)
.take(cfg.unique_buckets)
.map(|_| {
// Generate a random number in the specified range
let beta = rand::thread_rng().gen_range(1..(1 << cfg.range_bits));
let input_beta: BetaType = typ.encode_measurement(&beta).unwrap();

(
Key::gen_from_str(&sample_string(cfg.data_bytes * 8), &input_beta),
input_beta,
)
})
.unzip();
) -> ((Vec<VidpfKey>, Vec<VidpfKey>), Vec<Vec<Field64>>) {
let (keys, values): ((Vec<VidpfKey>, Vec<VidpfKey>), Vec<Vec<Field64>>) =
rayon::iter::repeat(0)
.take(cfg.unique_buckets)
.map(|_| {
// Generate a random number in the specified range
let beta = rand::thread_rng().gen_range(1..(1 << cfg.range_bits));
let input_beta: Vec<Field64> = typ.encode_measurement(&beta).unwrap();

(
VidpfKey::gen_from_str(&sample_string(cfg.data_bytes * 8), &input_beta),
input_beta,
)
})
.unzip();

let encoded: Vec<u8> = bincode::serialize(&keys.0[0]).unwrap();
println!("Key size: {:?} bytes", encoded.len());
println!("VIDPFKey size: {:?} bytes", encoded.len());

(keys, values)
}

fn generate_randomness(keys: (&Vec<Key>, &Vec<Key>)) -> (Vec<[u8; 16]>, Vec<[[u8; 16]; 2]>) {
fn generate_randomness(
keys: (&Vec<VidpfKey>, &Vec<VidpfKey>),
) -> (Vec<[u8; 16]>, Vec<[[u8; 16]; 2]>) {
keys.0
.par_iter()
.zip(keys.1.par_iter())
Expand Down Expand Up @@ -127,8 +128,8 @@ fn generate_proofs(
}

async fn reset_servers(
client_0: &Client,
client_1: &Client,
client_0: &CollectorClient,
client_1: &CollectorClient,
verify_key: &[u8; 16],
) -> io::Result<()> {
let req = ResetRequest {
Expand All @@ -141,7 +142,7 @@ async fn reset_servers(
Ok(())
}

async fn tree_init(client_0: &Client, client_1: &Client) -> io::Result<()> {
async fn tree_init(client_0: &CollectorClient, client_1: &CollectorClient) -> io::Result<()> {
let req = TreeInitRequest {};
let resp_0 = client_0.tree_init(long_context(), req.clone());
let resp_1 = client_1.tree_init(long_context(), req);
Expand All @@ -152,9 +153,9 @@ async fn tree_init(client_0: &Client, client_1: &Client) -> io::Result<()> {

async fn add_keys(
cfg: &config::Config,
clients: (&Client, &Client),
keys: (&[vidpf::VIDPFKey], &[vidpf::VIDPFKey]),
proofs: (&[Vec<Field64>], &[Vec<Field64>]),
clients: (&CollectorClient, &CollectorClient),
all_keys: (&[vidpf::VidpfKey], &[vidpf::VidpfKey]),
all_proofs: (&[Vec<Field64>], &[Vec<Field64>]),
all_nonces: &[[u8; 16]],
all_jr_parts: &[[[u8; 16]; 2]],
num_clients: usize,
Expand Down Expand Up @@ -184,11 +185,11 @@ async fn add_keys(
}
println!("Malicious {}", r);
}
add_keys_0.push(keys.0[idx_1].clone());
add_keys_1.push(keys.1[idx_2 % cfg.unique_buckets].clone());
add_keys_0.push(all_keys.0[idx_1].clone());
add_keys_1.push(all_keys.1[idx_2 % cfg.unique_buckets].clone());

flp_proof_shares_0.push(proofs.0[idx_1].clone());
flp_proof_shares_1.push(proofs.1[idx_3 % cfg.unique_buckets].clone());
flp_proof_shares_0.push(all_proofs.0[idx_1].clone());
flp_proof_shares_1.push(all_proofs.1[idx_3 % cfg.unique_buckets].clone());

nonces.push(all_nonces[idx_1]);
jr_parts.push(all_jr_parts[idx_1]);
Expand Down Expand Up @@ -226,8 +227,8 @@ async fn add_keys(
async fn run_flp_queries(
cfg: &config::Config,
typ: &Sum<Field64>,
client_0: &Client,
client_1: &Client,
client_0: &CollectorClient,
client_1: &CollectorClient,
num_clients: usize,
) -> io::Result<()> {
// Receive FLP query responses in chunks of cfg.flp_batch_size to avoid having huge RPC messages.
Expand Down Expand Up @@ -273,8 +274,8 @@ async fn run_flp_queries(
async fn run_level(
cfg: &config::Config,
typ: &Sum<Field64>,
client_0: &Client,
client_1: &Client,
client_0: &CollectorClient,
client_1: &CollectorClient,
num_clients: usize,
) -> io::Result<()> {
let threshold = core::cmp::max(1, (cfg.threshold * (num_clients as f64)) as u64);
Expand Down Expand Up @@ -345,8 +346,8 @@ async fn run_level(
async fn run_level_last(
cfg: &config::Config,
typ: &Sum<Field64>,
client_0: &Client,
client_1: &Client,
client_0: &CollectorClient,
client_1: &CollectorClient,
num_clients: usize,
) -> io::Result<()> {
let threshold = core::cmp::max(1, (cfg.threshold * (num_clients as f64)) as u64);
Expand Down Expand Up @@ -409,12 +410,12 @@ async fn main() -> io::Result<()> {
let (cfg, _, num_clients, malicious) = config::get_args("Driver", false, true, true);
assert!((0.0..0.8).contains(&malicious));
println!("Running with {}% malicious clients", malicious * 100.0);
let client_0 = Client::new(
let client_0 = CollectorClient::new(
client::Config::default(),
tcp::connect(cfg.server_0, Bincode::default).await?,
)
.spawn();
let client_1 = Client::new(
let client_1 = CollectorClient::new(
client::Config::default(),
tcp::connect(cfg.server_1, Bincode::default).await?,
)
Expand Down
6 changes: 3 additions & 3 deletions src/bin/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use mastic::{
GetProofsRequest, ResetRequest, RunFlpQueriesRequest, TreeCrawlLastRequest,
TreeCrawlRequest, TreeInitRequest, TreePruneRequest,
},
BetaType, HASH_SIZE,
HASH_SIZE,
};
use prio::{field::Field64, flp::types::Sum};
use tarpc::{
Expand Down Expand Up @@ -85,7 +85,7 @@ impl Collector for CollectorServer {
self,
_: context::Context,
req: TreeCrawlRequest,
) -> (Vec<BetaType>, Vec<Vec<u8>>, Vec<usize>) {
) -> (Vec<Vec<Field64>>, Vec<Vec<u8>>, Vec<usize>) {
let start = Instant::now();
let split_by = req.split_by;
let malicious = req.malicious;
Expand Down Expand Up @@ -119,7 +119,7 @@ impl Collector for CollectorServer {
self,
_: context::Context,
_req: TreeCrawlLastRequest,
) -> Vec<BetaType> {
) -> Vec<Vec<Field64>> {
let start = Instant::now();
let mut coll = self.arc.lock().unwrap();

Expand Down
8 changes: 4 additions & 4 deletions src/collect.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use rayon::prelude::*;
use rs_merkle::{Hasher, MerkleTree};
use serde::{Deserialize, Serialize};

use crate::{prg, vec_add, vec_sub, vidpf, xor_in_place, xor_vec, BetaType, HASH_SIZE};
use crate::{prg, vec_add, vec_sub, vidpf, xor_in_place, xor_vec, HASH_SIZE};

#[derive(Clone)]
pub struct HashAlg {}
Expand Down Expand Up @@ -40,7 +40,7 @@ pub struct KeyCollection {
server_id: i8,
verify_key: [u8; 16],
depth: usize,
pub keys: Vec<(bool, vidpf::VIDPFKey)>,
pub keys: Vec<(bool, vidpf::VidpfKey)>,
nonces: Vec<[u8; 16]>,
jr_parts: Vec<[[u8; 16]; 2]>,
all_flp_proof_shares: Vec<Vec<Field64>>,
Expand All @@ -52,7 +52,7 @@ pub struct KeyCollection {
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Result {
pub path: Vec<bool>,
pub value: BetaType,
pub value: Vec<Field64>,
}

impl KeyCollection {
Expand All @@ -78,7 +78,7 @@ impl KeyCollection {
}
}

pub fn add_key(&mut self, key: vidpf::VIDPFKey) {
pub fn add_key(&mut self, key: vidpf::VidpfKey) {
self.keys.push((true, key));
}

Expand Down
2 changes: 0 additions & 2 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,6 @@ use prio::field::Field64;

pub use crate::rpc::CollectorClient;

pub type BetaType = Vec<Field64>;

pub const HASH_SIZE: usize = 16;

impl crate::prg::FromRng for Field64 {
Expand Down
8 changes: 4 additions & 4 deletions src/rpc.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use prio::field::Field64;
use serde::{Deserialize, Serialize};

use crate::{collect, vidpf, BetaType, HASH_SIZE};
use crate::{collect, vidpf, HASH_SIZE};

#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ResetRequest {
Expand All @@ -10,7 +10,7 @@ pub struct ResetRequest {

#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct AddKeysRequest {
pub keys: Vec<vidpf::VIDPFKey>,
pub keys: Vec<vidpf::VidpfKey>,
}

#[derive(Clone, Debug, Serialize, Deserialize)]
Expand Down Expand Up @@ -65,8 +65,8 @@ pub trait Collector {
async fn add_all_flp_proof_shares(req: AddFLPsRequest) -> String;
async fn run_flp_queries(req: RunFlpQueriesRequest) -> Vec<Vec<Field64>>;
async fn apply_flp_results(req: ApplyFLPResultsRequest) -> String;
async fn tree_crawl(req: TreeCrawlRequest) -> (Vec<BetaType>, Vec<Vec<u8>>, Vec<usize>);
async fn tree_crawl_last(req: TreeCrawlLastRequest) -> Vec<BetaType>;
async fn tree_crawl(req: TreeCrawlRequest) -> (Vec<Vec<Field64>>, Vec<Vec<u8>>, Vec<usize>);
async fn tree_crawl_last(req: TreeCrawlLastRequest) -> Vec<Vec<Field64>>;
async fn get_proofs(req: GetProofsRequest) -> Vec<[u8; HASH_SIZE]>;
async fn tree_init(req: TreeInitRequest) -> String;
async fn tree_prune(req: TreePruneRequest) -> String;
Expand Down
20 changes: 10 additions & 10 deletions src/vidpf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,17 @@ use blake3::Hasher;
use prio::field::Field64;
use serde::{Deserialize, Serialize};

use crate::{prg, vec_add, vec_neg, vec_sub, xor_three_vecs, xor_vec, BetaType, HASH_SIZE};
use crate::{prg, vec_add, vec_neg, vec_sub, xor_three_vecs, xor_vec, HASH_SIZE};

#[derive(Clone, Debug, Serialize, Deserialize)]
struct CorWord {
seed: prg::PrgSeed,
bits: (bool, bool),
word: BetaType,
word: Vec<Field64>,
}

#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct VIDPFKey {
pub struct VidpfKey {
pub key_idx: bool,
root_seed: prg::PrgSeed,
cor_words: Vec<CorWord>,
Expand Down Expand Up @@ -81,7 +81,7 @@ impl<T> TupleExt<T> for (T, T) {

fn gen_cor_word(
bit: bool,
beta: &BetaType,
beta: &Vec<Field64>,
bits: &mut (bool, bool),
seeds: &mut (prg::PrgSeed, prg::PrgSeed),
) -> CorWord {
Expand Down Expand Up @@ -135,12 +135,12 @@ fn gen_cor_word(
}

/// All-prefix DPF implementation.
impl VIDPFKey {
impl VidpfKey {
pub fn get_root_seed(&self) -> prg::PrgSeed {
self.root_seed.clone()
}

pub fn gen(alpha_bits: &[bool], beta: &BetaType) -> (VIDPFKey, VIDPFKey) {
pub fn gen(alpha_bits: &[bool], beta: &Vec<Field64>) -> (VidpfKey, VidpfKey) {
let root_seeds = (prg::PrgSeed::random(), prg::PrgSeed::random());
let root_bits = (false, true);

Expand Down Expand Up @@ -176,13 +176,13 @@ impl VIDPFKey {
}

(
VIDPFKey {
VidpfKey {
key_idx: false,
root_seed: root_seeds.0,
cor_words: cor_words.clone(),
cs: cs.clone(),
},
VIDPFKey {
VidpfKey {
key_idx: true,
root_seed: root_seeds.1,
cor_words,
Expand Down Expand Up @@ -297,9 +297,9 @@ impl VIDPFKey {
(out, last)
}

pub fn gen_from_str(s: &str, beta: &BetaType) -> (Self, Self) {
pub fn gen_from_str(s: &str, beta: &Vec<Field64>) -> (Self, Self) {
let bits = crate::string_to_bits(s);
VIDPFKey::gen(&bits, beta)
VidpfKey::gen(&bits, beta)
}

pub fn domain_size(&self) -> usize {
Expand Down
6 changes: 3 additions & 3 deletions tests/collect_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ fn collect_test_eval_groups() {

for cstr in &client_strings {
let input_beta = typ.encode_measurement(&3u64).unwrap();
let (keys_0, keys_1) = vidpf::VIDPFKey::gen_from_str(&cstr, &input_beta);
let (keys_0, keys_1) = vidpf::VidpfKey::gen_from_str(&cstr, &input_beta);
col_0.add_key(keys_0);
col_1.add_key(keys_1);
}
Expand Down Expand Up @@ -109,7 +109,7 @@ fn collect_test_eval_full_groups() {
println!("Starting to generate keys");
for s in &client_strings {
let input_beta = typ.encode_measurement(&1u64).unwrap();
keys.push(vidpf::VIDPFKey::gen_from_str(&s, &input_beta));
keys.push(vidpf::VidpfKey::gen_from_str(&s, &input_beta));
}
println!("Done generating keys");

Expand All @@ -119,7 +119,7 @@ fn collect_test_eval_full_groups() {
col_0.add_key(copy_0);
col_1.add_key(copy_1);
if i % 50 == 0 {
println!(" Key {:?}", i);
println!(" VIDPFKey {:?}", i);
}
}

Expand Down
2 changes: 1 addition & 1 deletion tests/dpf_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ fn dpf_complete() {
let num_bits = 5;
let alpha = u32_to_bits(num_bits, 21);
let beta = vec![Field64::from(7u64)];
let (key_0, key_1) = VIDPFKey::gen(&alpha, &beta);
let (key_0, key_1) = VidpfKey::gen(&alpha, &beta);

let mut pi_0: [u8; HASH_SIZE] = hash(b"0").as_bytes()[0..HASH_SIZE].try_into().unwrap();
let mut pi_1: [u8; HASH_SIZE] = pi_0.clone();
Expand Down

0 comments on commit 7150112

Please sign in to comment.