Skip to content

Commit

Permalink
Receive verifier shares and proofs in RPC batches
Browse files Browse the repository at this point in the history
  • Loading branch information
jimouris committed Nov 1, 2023
1 parent fee03a6 commit 7da2ba5
Show file tree
Hide file tree
Showing 7 changed files with 161 additions and 93 deletions.
8 changes: 5 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,9 @@ system. An example of one such file is in `src/bin/config.json`. The contents of
"threshold": 0.01,
"server_0": "0.0.0.0:8000",
"server_1": "0.0.0.0:8001",
"addkey_batch_size": 100,
"unique_buckets": 10,
"add_key_batch_size": 1000,
"flp_batch_size": 100000,
"unique_buckets": 1000,
"zipf_exponent": 1.03
}
```
Expand All @@ -72,9 +73,10 @@ The parameters are:
clients hold.
* `server0`, `server1`, and `server2`: The `IP:port` of tuple for the two servers. The servers can
run on different IP addresses, but these IPs must be publicly addressable.
* `addkey_batch_size`: The number of each type of RPC request to bundle together. The underlying RPC
* `add_key_batch_size`: The number of each type of RPC request to bundle together. The underlying RPC
library has an annoying limit on the size of each RPC request, so you cannot set these values too
large.
* `flp_batch_size`: Similar to `add_key_batch_size` but with a greater threshold.
* `unique_buckets` and `zipf_exponent`: Each simulated client samples its private string from a Zipf
distribution over strings with parameter `zipf_exponent` and support `unique_buckets`.

Expand Down
5 changes: 3 additions & 2 deletions src/bin/config.json
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
"threshold": 0.01,
"server_0": "0.0.0.0:8000",
"server_1": "0.0.0.0:8001",
"addkey_batch_size": 100,
"unique_buckets": 10,
"add_key_batch_size": 1000,
"flp_batch_size": 100000,
"unique_buckets": 1000,
"zipf_exponent": 1.03
}
154 changes: 90 additions & 64 deletions src/bin/leader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@ use futures::try_join;
use mastic::{
collect, config, dpf,
rpc::{
AddFLPsRequest, AddKeysRequest, ApplyFLPResultsRequest, FinalSharesRequest, ResetRequest,
RunFlpQueriesRequest, TreeCrawlLastRequest, TreeCrawlRequest, TreeInitRequest,
TreePruneRequest,
AddFLPsRequest, AddKeysRequest, ApplyFLPResultsRequest, FinalSharesRequest,
GetProofsRequest, ResetRequest, RunFlpQueriesRequest, TreeCrawlLastRequest,
TreeCrawlRequest, TreeInitRequest, TreePruneRequest,
},
CollectorClient,
};
Expand Down Expand Up @@ -85,18 +85,18 @@ async fn reset_servers(
let req = ResetRequest {
verify_key: *verify_key,
};
let response_0 = client_0.reset(long_context(), req.clone());
let response_1 = client_1.reset(long_context(), req);
try_join!(response_0, response_1).unwrap();
let resp_0 = client_0.reset(long_context(), req.clone());
let resp_1 = client_1.reset(long_context(), req);
try_join!(resp_0, resp_1).unwrap();

Ok(())
}

async fn tree_init(client_0: &Client, client_1: &Client) -> io::Result<()> {
let req = TreeInitRequest {};
let response_0 = client_0.tree_init(long_context(), req.clone());
let response_1 = client_1.tree_init(long_context(), req);
try_join!(response_0, response_1).unwrap();
let resp_0 = client_0.tree_init(long_context(), req.clone());
let resp_1 = client_1.tree_init(long_context(), req);
try_join!(resp_0, resp_1).unwrap();

Ok(())
}
Expand Down Expand Up @@ -141,55 +141,70 @@ async fn add_keys(
flp_proof_shares_1.push(proofs_1[idx_3 % cfg.unique_buckets].clone());
}

let response_0 = client_0.add_keys(long_context(), AddKeysRequest { keys: add_keys_0 });
let response_1 = client_1.add_keys(long_context(), AddKeysRequest { keys: add_keys_1 });
try_join!(response_0, response_1).unwrap();
let resp_0 = client_0.add_keys(long_context(), AddKeysRequest { keys: add_keys_0 });
let resp_1 = client_1.add_keys(long_context(), AddKeysRequest { keys: add_keys_1 });
try_join!(resp_0, resp_1).unwrap();

let response_0 = client_0.add_all_flp_proof_shares(
let resp_0 = client_0.add_all_flp_proof_shares(
long_context(),
AddFLPsRequest {
flp_proof_shares: flp_proof_shares_0,
},
);
let response_1 = client_1.add_all_flp_proof_shares(
let resp_1 = client_1.add_all_flp_proof_shares(
long_context(),
AddFLPsRequest {
flp_proof_shares: flp_proof_shares_1,
},
);
try_join!(response_0, response_1).unwrap();
try_join!(resp_0, resp_1).unwrap();

Ok(())
}

async fn run_flp_queries(client_0: &Client, client_1: &Client) -> io::Result<()> {
let req = RunFlpQueriesRequest {};
let response_0 = client_0.run_flp_queries(long_context(), req.clone());
let response_1 = client_1.run_flp_queries(long_context(), req);
let (flp_verifier_shares_0, flp_verifier_shares_1) = try_join!(response_0, response_1).unwrap();

assert_eq!(flp_verifier_shares_0.len(), flp_verifier_shares_1.len());

async fn run_flp_queries(
cfg: &config::Config,
client_0: &Client,
client_1: &Client,
num_clients: usize,
) -> io::Result<()> {
// Receive FLP query responses in chunks of cfg.flp_batch_size to avoid having huge RPC messages.
let count = Count::new();
let keep = flp_verifier_shares_0
.par_iter()
.zip(flp_verifier_shares_1.par_iter())
.map(|(flp_verifier_share_0, flp_verifier_share_1)| {
let flp_verifier = flp_verifier_share_0
let mut keep = vec![];
let mut start = 0;
while start < num_clients {
let end = std::cmp::min(num_clients, start + cfg.flp_batch_size);

let req = RunFlpQueriesRequest { start, end };
let resp_0 = client_0.run_flp_queries(long_context(), req.clone());
let resp_1 = client_1.run_flp_queries(long_context(), req);
let (flp_verifier_shares_0, flp_verifier_shares_1) = try_join!(resp_0, resp_1).unwrap();
debug_assert_eq!(flp_verifier_shares_0.len(), flp_verifier_shares_1.len());

keep.extend(
flp_verifier_shares_0
.par_iter()
.zip(flp_verifier_share_1.par_iter())
.map(|(&v1, &v2)| v1 + v2)
.collect::<Vec<_>>();
.zip(flp_verifier_shares_1.par_iter())
.map(|(flp_verifier_share_0, flp_verifier_share_1)| {
let flp_verifier = flp_verifier_share_0
.par_iter()
.zip(flp_verifier_share_1.par_iter())
.map(|(&v1, &v2)| v1 + v2)
.collect::<Vec<_>>();

count.decide(&flp_verifier).unwrap()
})
.collect::<Vec<_>>(),
);

count.decide(&flp_verifier).unwrap()
})
.collect::<Vec<_>>();
start += cfg.flp_batch_size;
}

// Tree prune
let req = ApplyFLPResultsRequest { keep };
let response_0 = client_0.apply_flp_results(long_context(), req.clone());
let response_1 = client_1.apply_flp_results(long_context(), req);
try_join!(response_0, response_1).unwrap();
let resp_0 = client_0.apply_flp_results(long_context(), req.clone());
let resp_1 = client_1.apply_flp_results(long_context(), req);
try_join!(resp_0, resp_1).unwrap();

Ok(())
}
Expand All @@ -213,10 +228,10 @@ async fn run_level(
is_last,
};

let response_0 = client_0.tree_crawl(long_context(), req.clone());
let response_1 = client_1.tree_crawl(long_context(), req);
let resp_0 = client_0.tree_crawl(long_context(), req.clone());
let resp_1 = client_1.tree_crawl(long_context(), req);
let ((cnt_values_0, mt_root_0, indices_0), (cnt_values_1, mt_root_1, indices_1)) =
try_join!(response_0, response_1).unwrap();
try_join!(resp_0, resp_1).unwrap();

assert_eq!(cnt_values_0.len(), cnt_values_1.len());
keep =
Expand Down Expand Up @@ -254,9 +269,9 @@ async fn run_level(

// Tree prune
let req = TreePruneRequest { keep };
let response_0 = client_0.tree_prune(long_context(), req.clone());
let response_1 = client_1.tree_prune(long_context(), req);
try_join!(response_0, response_1).unwrap();
let resp_0 = client_0.tree_prune(long_context(), req.clone());
let resp_1 = client_1.tree_prune(long_context(), req);
try_join!(resp_0, resp_1).unwrap();

Ok(())
}
Expand All @@ -270,33 +285,44 @@ async fn run_level_last(
let threshold = core::cmp::max(1, (cfg.threshold * (num_clients as f64)) as u64);

let req = TreeCrawlLastRequest {};
let response_0 = client_0.tree_crawl_last(long_context(), req.clone());
let response_1 = client_1.tree_crawl_last(long_context(), req);
let ((cnt_values_0, hashes_0), (cnt_values_1, hashes_1)) =
try_join!(response_0, response_1).unwrap();

let resp_0 = client_0.tree_crawl_last(long_context(), req.clone());
let resp_1 = client_1.tree_crawl_last(long_context(), req);
let (cnt_values_0, cnt_values_1) = try_join!(resp_0, resp_1).unwrap();
assert_eq!(cnt_values_0.len(), cnt_values_1.len());
assert_eq!(hashes_0.len(), hashes_1.len());

let verified = hashes_0
.par_iter()
.zip(hashes_1.par_iter())
.all(|(&h0, &h1)| h0 == h1);
assert!(verified);

let keep =
collect::KeyCollection::<Field64>::keep_values(threshold, &cnt_values_0, &cnt_values_1);

// Receive counters in chunks to avoid having huge RPC messages.
let mut start = 0;
while start < num_clients {
let end = std::cmp::min(num_clients, start + cfg.flp_batch_size);

let req = GetProofsRequest { start, end };
let resp_0 = client_0.get_proofs(long_context(), req.clone());
let resp_1 = client_1.get_proofs(long_context(), req);
let (hashes_0, hashes_1) = try_join!(resp_0, resp_1).unwrap();

assert_eq!(hashes_0.len(), hashes_1.len());

let verified = hashes_0
.par_iter()
.zip(hashes_1.par_iter())
.all(|(&h0, &h1)| h0 == h1);
assert!(verified);

start += cfg.flp_batch_size;
}

// Tree prune
let req = TreePruneRequest { keep };
let response_0 = client_0.tree_prune(long_context(), req.clone());
let response_1 = client_1.tree_prune(long_context(), req);
try_join!(response_0, response_1).unwrap();
let resp_0 = client_0.tree_prune(long_context(), req.clone());
let resp_1 = client_1.tree_prune(long_context(), req);
try_join!(resp_0, resp_1).unwrap();

let req = FinalSharesRequest {};
let response_0 = client_0.final_shares(long_context(), req.clone());
let response_1 = client_1.final_shares(long_context(), req);
let (shares_0, shares_1) = try_join!(response_0, response_1).unwrap();
let resp_0 = client_0.final_shares(long_context(), req.clone());
let resp_1 = client_1.final_shares(long_context(), req);
let (shares_0, shares_1) = try_join!(resp_0, resp_1).unwrap();
for res in &collect::KeyCollection::<Field64>::final_values(&shares_0, &shares_1) {
let bits = mastic::bits_to_bitstring(&res.path);
if res.value > Field64::from(0) {
Expand Down Expand Up @@ -348,7 +374,7 @@ async fn main() -> io::Result<()> {
let mut responses = vec![];

for _ in 0..reqs_in_flight {
let this_batch = std::cmp::min(left_to_go, cfg.addkey_batch_size);
let this_batch = std::cmp::min(left_to_go, cfg.add_key_batch_size);
left_to_go -= this_batch;

if this_batch > 0 {
Expand All @@ -371,7 +397,7 @@ async fn main() -> io::Result<()> {
for level in 0..bit_len - 1 {
let start_level = Instant::now();
if level == 0 {
run_flp_queries(&client_0, &client_1).await?;
run_flp_queries(&cfg, &client_0, &client_1, num_clients).await?;
}
run_level(&cfg, &client_0, &client_1, num_clients).await?;
println!(
Expand Down
28 changes: 21 additions & 7 deletions src/bin/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@ use mastic::{
collect, config, prg,
rpc::{
AddFLPsRequest, AddKeysRequest, ApplyFLPResultsRequest, Collector, FinalSharesRequest,
ResetRequest, RunFlpQueriesRequest, TreeCrawlLastRequest, TreeCrawlRequest,
TreeInitRequest, TreePruneRequest,
GetProofsRequest, ResetRequest, RunFlpQueriesRequest, TreeCrawlLastRequest,
TreeCrawlRequest, TreeInitRequest, TreePruneRequest,
},
};
use prio::field::Field64;
Expand Down Expand Up @@ -79,22 +79,27 @@ impl Collector for CollectorServer {
_: context::Context,
req: TreeCrawlRequest,
) -> (Vec<Field64>, Vec<Vec<u8>>, Vec<usize>) {
let start = Instant::now();
let split_by = req.split_by;
let malicious = req.malicious;
let is_last = req.is_last;
let mut coll = self.arc.lock().unwrap();

coll.tree_crawl(split_by, &malicious, is_last)
let res = coll.tree_crawl(split_by, &malicious, is_last);
println!("Tree crawl: {:?} sec.", start.elapsed().as_secs_f64());

res
}

async fn run_flp_queries(
self,
_: context::Context,
_req: RunFlpQueriesRequest,
req: RunFlpQueriesRequest,
) -> Vec<Vec<Field64>> {
let mut coll = self.arc.lock().unwrap();
debug_assert!(req.start < req.end);

coll.run_flp_queries()
coll.run_flp_queries(req.start, req.end)
}

async fn apply_flp_results(self, _: context::Context, req: ApplyFLPResultsRequest) -> String {
Expand All @@ -107,14 +112,23 @@ impl Collector for CollectorServer {
self,
_: context::Context,
_req: TreeCrawlLastRequest,
) -> (Vec<Field64>, Vec<[u8; 32]>) {
) -> Vec<Field64> {
let start = Instant::now();
let mut coll = self.arc.lock().unwrap();

let res = coll.tree_crawl_last();
println!("tree_crawl_last: {:?}", start.elapsed().as_secs_f64());
println!("Tree crawl last: {:?} sec.", start.elapsed().as_secs_f64());

res
}

async fn get_proofs(self, _: context::Context, req: GetProofsRequest) -> Vec<[u8; 32]> {
let coll = self.arc.lock().unwrap();
debug_assert!(req.start < req.end);

coll.get_proofs(req.start, req.end)
}

async fn tree_prune(self, _: context::Context, req: TreePruneRequest) -> String {
let mut coll = self.arc.lock().unwrap();
coll.tree_prune(&req.keep);
Expand Down
Loading

0 comments on commit 7da2ba5

Please sign in to comment.