Skip to content

Commit

Permalink
Make beta a vector
Browse files Browse the repository at this point in the history
  • Loading branch information
jimouris committed Dec 11, 2023
1 parent 8174f79 commit 14e897d
Show file tree
Hide file tree
Showing 12 changed files with 451 additions and 266 deletions.
1 change: 1 addition & 0 deletions src/bin/config.json
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
{
"data_bytes": 1,
"range_bits": 2,
"threshold": 0.01,
"server_0": "0.0.0.0:8000",
"server_1": "0.0.0.0:8001",
Expand Down
178 changes: 125 additions & 53 deletions src/bin/leader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,19 @@ use mastic::{
GetProofsRequest, ResetRequest, RunFlpQueriesRequest, TreeCrawlLastRequest,
TreeCrawlRequest, TreeInitRequest, TreePruneRequest,
},
vidpf, CollectorClient,
vidpf, BetaType, CollectorClient,
};
use prio::{
field::{random_vector, Field64},
flp::{types::Count, Type},
flp::{types::Sum, Type},
vdaf::xof::{IntoFieldVec, Xof, XofShake128},
};
use rand::{distributions::Alphanumeric, thread_rng, Rng};
use rand_core::RngCore;
use rayon::prelude::*;
use tarpc::{client, context, serde_transport::tcp, tokio_serde::formats::Bincode};

type Key = vidpf::VIDPFKey<Field64>;
type Key = vidpf::VIDPFKey;
type Client = CollectorClient;

fn long_context() -> context::Context {
Expand All @@ -42,23 +44,72 @@ fn sample_string(len: usize) -> String {

fn generate_keys(
cfg: &config::Config,
) -> ((Vec<Key>, Vec<Key>), (Vec<Vec<Field64>>, Vec<Vec<Field64>>)) {
let beta = 1u64;
let count = Count::new();
let input_beta: Vec<Field64> = count.encode_measurement(&beta).unwrap();

let (keys_0, keys_1): (Vec<Key>, Vec<Key>) = rayon::iter::repeat(0)
typ: &Sum<Field64>,
) -> ((Vec<Key>, Vec<Key>), Vec<Vec<Field64>>) {
let (keys, values): ((Vec<Key>, Vec<Key>), Vec<Vec<Field64>>) = rayon::iter::repeat(0)
.take(cfg.unique_buckets)
.map(|_| {
vidpf::VIDPFKey::gen_from_str(&sample_string(cfg.data_bytes * 8), Field64::from(beta))
// Generate a random number in the specified range
let beta = rand::thread_rng().gen_range(1..(1 << cfg.range_bits));
let input_beta: BetaType = typ.encode_measurement(&beta).unwrap();

(
Key::gen_from_str(&sample_string(cfg.data_bytes * 8), &input_beta),
input_beta,
)
})
.unzip();

let (proofs_0, proofs_1): (Vec<Vec<Field64>>, Vec<Vec<Field64>>) = rayon::iter::repeat(0)
.take(cfg.unique_buckets)
.map(|_| {
let prove_rand = random_vector(count.prove_rand_len()).unwrap();
let proof = count.prove(&input_beta, &prove_rand, &[]).unwrap();
let encoded: Vec<u8> = bincode::serialize(&keys.0[0]).unwrap();
println!("Key size: {:?} bytes", encoded.len());

(keys, values)
}

fn generate_randomness(keys: (&Vec<Key>, &Vec<Key>)) -> (Vec<[u8; 16]>, Vec<[[u8; 16]; 2]>) {
keys.0
.par_iter()
.zip(keys.1.par_iter())
.map(|(key_0, key_1)| {
let nonce = rand::thread_rng().gen::<u128>().to_le_bytes();
let vidpf_seeds = (key_0.get_root_seed().key, key_1.get_root_seed().key);

let mut jr_parts = [[0u8; 16]; 2];
let mut jr_part_0_xof = XofShake128::init(&vidpf_seeds.0, &[0u8; 16]);
jr_part_0_xof.update(&[0]); // Aggregator ID
jr_part_0_xof.update(&nonce);
jr_part_0_xof
.into_seed_stream()
.fill_bytes(&mut jr_parts[0]);

let mut jr_part_1_xof = XofShake128::init(&vidpf_seeds.1, &[0u8; 16]);
jr_part_1_xof.update(&[1]); // Aggregator ID
jr_part_1_xof.update(&nonce);
jr_part_1_xof
.into_seed_stream()
.fill_bytes(&mut jr_parts[1]);

(nonce, jr_parts)
})
.unzip()
}

fn generate_proofs(
typ: &Sum<Field64>,
beta_values: &Vec<Vec<Field64>>,
all_jr_parts: &Vec<[[u8; 16]; 2]>,
) -> (Vec<Vec<Field64>>, Vec<Vec<Field64>>) {
all_jr_parts
.par_iter()
.zip_eq(beta_values.par_iter())
.map(|(jr_parts, input_beta)| {
let joint_rand_xof = XofShake128::init(&jr_parts[0], &jr_parts[1]);
let joint_rand: Vec<Field64> = joint_rand_xof
.into_seed_stream()
.into_field_vec(typ.joint_rand_len());

let prove_rand = random_vector(typ.prove_rand_len()).unwrap();
let proof = typ.prove(input_beta, &prove_rand, &joint_rand).unwrap();

let proof_0 = proof
.iter()
Expand All @@ -69,14 +120,10 @@ fn generate_keys(
.zip(proof_0.par_iter())
.map(|(p_0, p_1)| p_0 - p_1)
.collect::<Vec<_>>();

(proof_0, proof_1)
})
.unzip();

let encoded: Vec<u8> = bincode::serialize(&keys_0[0]).unwrap();
println!("Key size: {:?} bytes", encoded.len());

((keys_0, keys_1), (proofs_0, proofs_1))
.unzip()
}

async fn reset_servers(
Expand Down Expand Up @@ -105,12 +152,11 @@ async fn tree_init(client_0: &Client, client_1: &Client) -> io::Result<()> {

async fn add_keys(
cfg: &config::Config,
client_0: &Client,
client_1: &Client,
keys_0: &[vidpf::VIDPFKey<Field64>],
keys_1: &[vidpf::VIDPFKey<Field64>],
proofs_0: &[Vec<Field64>],
proofs_1: &[Vec<Field64>],
clients: (&Client, &Client),
keys: (&[vidpf::VIDPFKey], &[vidpf::VIDPFKey]),
proofs: (&[Vec<Field64>], &[Vec<Field64>]),
all_nonces: &[[u8; 16]],
all_jr_parts: &[[[u8; 16]; 2]],
num_clients: usize,
malicious_percentage: f32,
) -> io::Result<()> {
Expand All @@ -123,6 +169,7 @@ async fn add_keys(
let mut flp_proof_shares_0 = Vec::with_capacity(num_clients);
let mut flp_proof_shares_1 = Vec::with_capacity(num_clients);
let mut nonces = Vec::with_capacity(num_clients);
let mut jr_parts = Vec::with_capacity(num_clients);
for r in 0..num_clients {
let idx_1 = zipf.sample(&mut rng) - 1;
let mut idx_2 = idx_1;
Expand All @@ -137,30 +184,38 @@ async fn add_keys(
}
println!("Malicious {}", r);
}
add_keys_0.push(keys_0[idx_1].clone());
add_keys_1.push(keys_1[idx_2 % cfg.unique_buckets].clone());
add_keys_0.push(keys.0[idx_1].clone());
add_keys_1.push(keys.1[idx_2 % cfg.unique_buckets].clone());

flp_proof_shares_0.push(proofs.0[idx_1].clone());
flp_proof_shares_1.push(proofs.1[idx_3 % cfg.unique_buckets].clone());

flp_proof_shares_0.push(proofs_0[idx_1].clone());
flp_proof_shares_1.push(proofs_1[idx_3 % cfg.unique_buckets].clone());
nonces.push(rand::thread_rng().gen::<u128>().to_le_bytes());
nonces.push(all_nonces[idx_1]);
jr_parts.push(all_jr_parts[idx_1]);
}

let resp_0 = client_0.add_keys(long_context(), AddKeysRequest { keys: add_keys_0 });
let resp_1 = client_1.add_keys(long_context(), AddKeysRequest { keys: add_keys_1 });
let resp_0 = clients
.0
.add_keys(long_context(), AddKeysRequest { keys: add_keys_0 });
let resp_1 = clients
.1
.add_keys(long_context(), AddKeysRequest { keys: add_keys_1 });
try_join!(resp_0, resp_1).unwrap();

let resp_0 = client_0.add_all_flp_proof_shares(
let resp_0 = clients.0.add_all_flp_proof_shares(
long_context(),
AddFLPsRequest {
flp_proof_shares: flp_proof_shares_0,
nonces: nonces.clone(),
jr_parts: jr_parts.clone(),
},
);
let resp_1 = client_1.add_all_flp_proof_shares(
let resp_1 = clients.1.add_all_flp_proof_shares(
long_context(),
AddFLPsRequest {
flp_proof_shares: flp_proof_shares_1,
nonces,
nonces: nonces.clone(),
jr_parts: jr_parts.clone(),
},
);
try_join!(resp_0, resp_1).unwrap();
Expand All @@ -170,12 +225,12 @@ async fn add_keys(

async fn run_flp_queries(
cfg: &config::Config,
typ: &Sum<Field64>,
client_0: &Client,
client_1: &Client,
num_clients: usize,
) -> io::Result<()> {
// Receive FLP query responses in chunks of cfg.flp_batch_size to avoid having huge RPC messages.
let count = Count::new();
let mut keep = vec![];
let mut start = 0;
while start < num_clients {
Expand All @@ -198,7 +253,7 @@ async fn run_flp_queries(
.map(|(&v1, &v2)| v1 + v2)
.collect::<Vec<_>>();

count.decide(&flp_verifier).unwrap()
typ.decide(&flp_verifier).unwrap()
})
.collect::<Vec<_>>(),
);
Expand All @@ -217,6 +272,7 @@ async fn run_flp_queries(

async fn run_level(
cfg: &config::Config,
typ: &Sum<Field64>,
client_0: &Client,
client_1: &Client,
num_clients: usize,
Expand All @@ -240,8 +296,12 @@ async fn run_level(
try_join!(resp_0, resp_1).unwrap();

assert_eq!(cnt_values_0.len(), cnt_values_1.len());
keep =
collect::KeyCollection::<Field64>::keep_values(threshold, &cnt_values_0, &cnt_values_1);
keep = collect::KeyCollection::keep_values(
typ.input_len(),
threshold,
&cnt_values_0,
&cnt_values_1,
);
if mt_root_0.is_empty() {
break;
}
Expand Down Expand Up @@ -284,6 +344,7 @@ async fn run_level(

async fn run_level_last(
cfg: &config::Config,
typ: &Sum<Field64>,
client_0: &Client,
client_1: &Client,
num_clients: usize,
Expand All @@ -295,8 +356,12 @@ async fn run_level_last(
let resp_1 = client_1.tree_crawl_last(long_context(), req);
let (cnt_values_0, cnt_values_1) = try_join!(resp_0, resp_1).unwrap();
assert_eq!(cnt_values_0.len(), cnt_values_1.len());
let keep =
collect::KeyCollection::<Field64>::keep_values(threshold, &cnt_values_0, &cnt_values_1);
let keep = collect::KeyCollection::keep_values(
typ.input_len(),
threshold,
&cnt_values_0,
&cnt_values_1,
);

// Receive counters in chunks to avoid having huge RPC messages.
let mut start = 0;
Expand Down Expand Up @@ -329,9 +394,9 @@ async fn run_level_last(
let resp_0 = client_0.final_shares(long_context(), req.clone());
let resp_1 = client_1.final_shares(long_context(), req);
let (shares_0, shares_1) = try_join!(resp_0, resp_1).unwrap();
for res in &collect::KeyCollection::<Field64>::final_values(&shares_0, &shares_1) {
for res in &collect::KeyCollection::final_values(typ.input_len(), &shares_0, &shares_1) {
let bits = mastic::bits_to_bitstring(&res.path);
if res.value > Field64::from(0) {
if res.value[typ.input_len() - 1] > Field64::from(0) {
println!("Value ({}) \t Count: {:?}", bits, res.value);
}
}
Expand All @@ -341,9 +406,6 @@ async fn run_level_last(

#[tokio::main]
async fn main() -> io::Result<()> {
// println!("Using only one thread!");
// rayon::ThreadPoolBuilder::new().num_threads(1).build_global().unwrap();

let (cfg, _, num_clients, malicious) = config::get_args("Leader", false, true, true);
assert!((0.0..0.8).contains(&malicious));
println!("Running with {}% malicious clients", malicious * 100.0);
Expand All @@ -358,10 +420,14 @@ async fn main() -> io::Result<()> {
)
.spawn();

let typ = Sum::<Field64>::new(cfg.range_bits).unwrap();

let start = Instant::now();
println!("Generating keys...");
let ((keys_0, keys_1), (proofs_0, proofs_1)) = generate_keys(&cfg);
let ((keys_0, keys_1), beta_values) = generate_keys(&cfg, &typ);
let delta = start.elapsed().as_secs_f64();
let (nonces, jr_parts) = generate_randomness((&keys_0, &keys_1));
let (proofs_0, proofs_1) = generate_proofs(&typ, &beta_values, &jr_parts);
println!(
"Generated {:?} keys in {:?} seconds ({:?} sec/key)",
keys_0.len(),
Expand All @@ -385,7 +451,13 @@ async fn main() -> io::Result<()> {

if this_batch > 0 {
responses.push(add_keys(
&cfg, &client_0, &client_1, &keys_0, &keys_1, &proofs_0, &proofs_1, this_batch,
&cfg,
(&client_0, &client_1),
(&keys_0, &keys_1),
(&proofs_0, &proofs_1),
&nonces,
&jr_parts,
this_batch,
malicious,
));
}
Expand All @@ -403,9 +475,9 @@ async fn main() -> io::Result<()> {
for level in 0..bit_len - 1 {
let start_level = Instant::now();
if level == 0 {
run_flp_queries(&cfg, &client_0, &client_1, num_clients).await?;
run_flp_queries(&cfg, &typ, &client_0, &client_1, num_clients).await?;
}
run_level(&cfg, &client_0, &client_1, num_clients).await?;
run_level(&cfg, &typ, &client_0, &client_1, num_clients).await?;
println!(
"Time for level {}: {:?}",
level,
Expand All @@ -419,7 +491,7 @@ async fn main() -> io::Result<()> {
);

let start_last = Instant::now();
run_level_last(&cfg, &client_0, &client_1, num_clients).await?;
run_level_last(&cfg, &typ, &client_0, &client_1, num_clients).await?;
println!(
"Time for last level: {:?}",
start_last.elapsed().as_secs_f64()
Expand Down
Loading

0 comments on commit 14e897d

Please sign in to comment.