Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[DRAFT] [DO NOT MERGE] evaluating attribution fix in draft #1306

Open
wants to merge 14 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
140 changes: 121 additions & 19 deletions ipa-core/src/helpers/gateway/send.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use std::{
borrow::Borrow,
cmp::min,
fmt::Debug,
marker::PhantomData,
num::NonZeroUsize,
Expand Down Expand Up @@ -248,35 +249,79 @@ impl<I: Debug> Stream for GatewaySendStream<I> {

impl SendChannelConfig {
fn new<M: Message>(gateway_config: GatewayConfig, total_records: TotalRecords) -> Self {
debug_assert!(M::Size::USIZE > 0, "Message size cannot be 0");
Self::new_with(gateway_config, total_records, M::Size::USIZE)
}

let record_size = M::Size::USIZE;
let total_capacity = gateway_config.active.get() * record_size;
Self {
total_capacity: total_capacity.try_into().unwrap(),
record_size: record_size.try_into().unwrap(),
read_size: if total_records.is_indeterminate()
|| gateway_config.read_size.get() <= record_size
{
fn new_with(
gateway_config: GatewayConfig,
total_records: TotalRecords,
record_size: usize,
) -> Self {
debug_assert!(record_size > 0, "Message size cannot be 0");
// The absolute minimum of capacity we reserve for this channel. We can't go
// below that number, otherwise a deadlock is almost guaranteed.
let min_capacity = gateway_config.active.get() * record_size;

// first, compute the read size. It must be a multiple of `record_size` to prevent
// misaligned reads and deadlocks. For indeterminate channels, read size must be
// set to the size of one record, to trigger buffer flush on every write
let read_size =
if total_records.is_indeterminate() || gateway_config.read_size.get() <= record_size {
record_size
} else {
std::cmp::min(
total_capacity,
// closest multiple of record_size to read_size
// closest multiple of record_size to read_size
let proposed_read_size = min(
gateway_config.read_size.get() / record_size * record_size,
)
}
.try_into()
.unwrap(),
min_capacity,
);
// if min capacity is not a multiple of read size.
// we must adjust read size. Adjusting total capacity is not possible due to
// possible deadlocks - it must be strictly aligned with active work.
// read size goes in `record_size` increments to keep it aligned.
// rem is aligned with both capacity and read_size, so subtracting
// it will keep read_size and capacity aligned
// Here is an example how it may work:
// lets say the active work is set to 10, record size is 512 bytes
// and read size in gateway config is set to 2048 bytes (default value).
// the math above will compute total_capacity to 5120 bytes and
// proposed_read_size to 2048 because it is aligned with 512 record size.
// Now, if we don't adjust then we have an issue as 5120 % 2048 = 1024 != 0.
// Keeping read size like this will cause a deadlock, so we adjust it to
// 1024.
proposed_read_size - min_capacity % proposed_read_size
};

// total capacity must be a multiple of both read size and record size.
// Record size is easy to justify: misalignment here leads to either waste of memory
// or deadlock on the last write. Aligning read size and total capacity
// has the same reasoning behind it: reading less than total capacity
// can leave the last chunk half-written and backpressure from active work
// preventing the protocols to make further progress.
let total_capacity = min_capacity / read_size * read_size;

let this = Self {
total_capacity: total_capacity.try_into().unwrap(),
record_size: record_size.try_into().unwrap(),
read_size: read_size.try_into().unwrap(),
total_records,
}
};

// make sure we've set these values correctly.
debug_assert_eq!(0, this.total_capacity.get() % this.read_size.get());
debug_assert_eq!(0, this.total_capacity.get() % this.record_size.get());
debug_assert!(this.total_capacity.get() >= this.read_size.get());
debug_assert!(this.total_capacity.get() >= this.record_size.get());
debug_assert!(this.read_size.get() >= this.record_size.get());

this
}
}

#[cfg(test)]
mod test {
use std::num::NonZeroUsize;

use proptest::proptest;
use typenum::Unsigned;

use crate::{
Expand All @@ -285,7 +330,7 @@ mod test {
Serializable,
},
helpers::{gateway::send::SendChannelConfig, GatewayConfig, TotalRecords},
secret_sharing::SharedValue,
secret_sharing::{Sendable, StdArray},
};

impl Default for SendChannelConfig {
Expand All @@ -300,7 +345,7 @@ mod test {
}

#[allow(clippy::needless_update)] // to allow progress_check_interval to be conditionally compiled
fn send_config<V: SharedValue, const A: usize, const R: usize>(
fn send_config<V: Sendable, const A: usize, const R: usize>(
total_records: TotalRecords,
) -> SendChannelConfig {
let gateway_config = GatewayConfig {
Expand Down Expand Up @@ -390,4 +435,61 @@ mod test {
.get()
);
}

/// This test reproduces ipa/#1300. PRF evaluation sent 32*16 = 512 (`record_size` * vectorization)
/// chunks through a channel with total capacity 5120 (active work = 10 records) and read size
/// of 2048 bytes.
/// The problem was that read size of 2048 does not divide 5120, so the last chunk was not sent.
#[test]
fn total_capacity_is_a_multiple_of_read_size() {
let config =
send_config::<StdArray<BA256, 16>, 10, 2048>(TotalRecords::specified(43).unwrap());

assert_eq!(0, config.total_capacity.get() % config.read_size.get());
assert_eq!(config.total_capacity.get(), 10 * config.record_size.get());
}

fn ensure_config(
total_records: Option<usize>,
active: usize,
read_size: usize,
record_size: usize,
) {
let gateway_config = GatewayConfig {
active: active.try_into().unwrap(),
read_size: read_size.try_into().unwrap(),
..Default::default()
};
let config = SendChannelConfig::new_with(
gateway_config,
total_records
.map_or(TotalRecords::Indeterminate, |v| TotalRecords::specified(v).unwrap()),
record_size,
);

// total capacity checks
assert!(config.total_capacity.get() > 0);
assert!(config.total_capacity.get() >= record_size);
assert!(config.total_capacity.get() <= record_size * active);
assert!(config.total_capacity.get() >= config.read_size.get());
assert_eq!(0, config.total_capacity.get() % config.record_size.get());

// read size checks
assert!(config.read_size.get() > 0);
assert!(config.read_size.get() >= config.record_size.get());
assert_eq!(0, config.total_capacity.get() % config.read_size.get());
assert_eq!(0, config.read_size.get() % config.record_size.get());
}

proptest! {
#[test]
fn config_prop(
total_records in proptest::option::of(1_usize..1 << 32),
active in 1_usize..100_000,
read_size in 1_usize..32768,
record_size in 1_usize..4096,
) {
ensure_config(total_records, active, read_size, record_size);
}
}
}
13 changes: 9 additions & 4 deletions ipa-core/src/protocol/context/batcher.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use std::{cmp::min, collections::VecDeque, future::Future};

use std::fmt::Debug;
use bitvec::{bitvec, prelude::BitVec};
use tokio::sync::watch;

Expand Down Expand Up @@ -78,7 +78,7 @@ enum Ready<B> {
},
}

impl<'a, B> Batcher<'a, B> {
impl<'a, B: Debug> Batcher<'a, B> {
pub fn new<T: Into<TotalRecords>>(
records_per_batch: usize,
total_records: T,
Expand Down Expand Up @@ -170,7 +170,7 @@ impl<'a, B> Batcher<'a, B> {
"Expected batch of {total_count} records to be ready for validation, but only have {:?}.",
&batch.pending_records[0..total_count],
);
tracing::info!("batch {batch_index} is ready for validation");
tracing::info!("is_ready_for_validation: batch {batch_index} is ready for validation");
let batch;
if batch_offset == 0 {
batch = self.batches.pop_front().unwrap();
Expand Down Expand Up @@ -252,7 +252,12 @@ impl<'a, B> Batcher<'a, B> {
/// If the batcher contains more than one batch.
pub fn into_single_batch(mut self) -> B {
assert!(self.first_batch == 0);
assert!(self.batches.len() <= 1);
assert!(self.batches.len() <= 1, "Number of batches must be 1, got {}. Total records: {:?}/records per batch: {}. debug: {:?}",
self.batches.len(),
self.total_records,
self.records_per_batch,
self.batches
);
let batch_index = 0;
match self.batches.pop_back() {
Some(state) => {
Expand Down
2 changes: 1 addition & 1 deletion ipa-core/src/protocol/context/dzkp_validator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -551,9 +551,9 @@ impl Batch {
.generate_challenges(ctx.narrow(&Step::Challenge))
.await;

let m = self.get_number_of_multiplications();
let (sum_of_uv, p_r_right_prover, q_r_left_prover) = {
// get number of multiplications
let m = self.get_number_of_multiplications();
tracing::info!("validating {m} multiplications");
debug_assert_eq!(
m,
Expand Down
4 changes: 3 additions & 1 deletion ipa-core/src/protocol/ipa_prf/prf_sharding/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -543,7 +543,9 @@ where
protocol: &Step::Aggregate,
validate: &Step::AggregateValidate,
},
aggregate_values_proof_chunk(B, usize::try_from(TV::BITS).unwrap()),
// aggregate_values_proof_chunk(B, usize::try_from(TV::BITS).unwrap()),
// 1B batch size, suboptimal. But only to test that it works for 3M
1 << 30,
);
let user_contributions = flattened_user_results.try_collect::<Vec<_>>().await?;
let result = breakdown_reveal_aggregation::<_, _, _, HV, B>(
Expand Down
12 changes: 10 additions & 2 deletions ipa-core/src/query/executor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use std::{
future::{ready, Future},
pin::Pin,
};

use std::sync::OnceLock;
use ::tokio::{
runtime::{Handle, RuntimeFlavor},
sync::oneshot,
Expand All @@ -17,6 +17,7 @@ use rand::rngs::StdRng;
use rand_core::SeedableRng;
#[cfg(all(feature = "shuttle", test))]
use shuttle::future as tokio;
use tokio::runtime::{Builder, Runtime};
use typenum::Unsigned;

#[cfg(any(
Expand Down Expand Up @@ -71,6 +72,13 @@ where
}
}

static QUERY_RUNTIME: OnceLock<Runtime> = OnceLock::new();
fn get_query_runtime() -> &'static Runtime {
QUERY_RUNTIME.get_or_init(|| {
Builder::new_multi_thread().worker_threads(10).thread_name("query_runtime").enable_all().build().unwrap()
})
}

/// Needless pass by value because IPA v3 does not make use of key registry yet.
#[allow(clippy::too_many_lines, clippy::needless_pass_by_value)]
pub fn execute<R: PrivateKeyRegistry>(
Expand Down Expand Up @@ -180,7 +188,7 @@ where
{
let (tx, rx) = oneshot::channel();

let join_handle = tokio::spawn(async move {
let join_handle = get_query_runtime().spawn(async move {
let gateway = gateway.borrow();
// TODO: make it a generic argument for this function
let mut rng = StdRng::from_entropy();
Expand Down
11 changes: 7 additions & 4 deletions ipa-core/tests/common/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -217,8 +217,12 @@ pub fn test_network<T: NetworkTest>(https: bool) {
T::execute(path, https);
}

pub fn test_ipa(mode: IpaSecurityModel, https: bool, encrypted_inputs: bool) {
test_ipa_with_config(
pub fn test_ipa<const INPUT_SIZE: usize>(
mode: IpaSecurityModel,
https: bool,
encrypted_inputs: bool,
) {
test_ipa_with_config::<INPUT_SIZE>(
mode,
https,
IpaQueryConfig {
Expand All @@ -228,7 +232,7 @@ pub fn test_ipa(mode: IpaSecurityModel, https: bool, encrypted_inputs: bool) {
);
}

pub fn test_ipa_with_config(
pub fn test_ipa_with_config<const INPUT_SIZE: usize>(
mode: IpaSecurityModel,
https: bool,
config: IpaQueryConfig,
Expand All @@ -238,7 +242,6 @@ pub fn test_ipa_with_config(
panic!("encrypted_input requires https")
};

const INPUT_SIZE: usize = 100;
// set to true to always keep the temp dir after test finishes
let dir = TempDir::new_delete_on_drop();
let path = dir.path();
Expand Down
6 changes: 3 additions & 3 deletions ipa-core/tests/helper_networks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,20 +45,20 @@ fn http_network_large_input() {
#[test]
#[cfg(all(test, web_test))]
fn http_semi_honest_ipa() {
test_ipa(IpaSecurityModel::SemiHonest, false, false);
test_ipa::<100>(IpaSecurityModel::SemiHonest, false, false);
}

#[test]
#[cfg(all(test, web_test))]
fn https_semi_honest_ipa() {
test_ipa(IpaSecurityModel::SemiHonest, true, true);
test_ipa::<100>(IpaSecurityModel::SemiHonest, true, true);
}

#[test]
#[cfg(all(test, web_test))]
#[ignore]
fn https_malicious_ipa() {
test_ipa(IpaSecurityModel::Malicious, true, true);
test_ipa::<100>(IpaSecurityModel::Malicious, true, true);
}

/// Similar to [`network`] tests, but it uses keygen + confgen CLIs to generate helper client config
Expand Down
12 changes: 9 additions & 3 deletions ipa-core/tests/ipa_with_relaxed_dp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ fn relaxed_dp_semi_honest() {
let encrypted_input = false;
let config = build_config();

test_ipa_with_config(
test_ipa_with_config::<100>(
IpaSecurityModel::SemiHonest,
encrypted_input,
config,
Expand All @@ -33,7 +33,7 @@ fn relaxed_dp_malicious() {
let encrypted_input = false;
let config = build_config();

test_ipa_with_config(
test_ipa_with_config::<100>(
IpaSecurityModel::Malicious,
encrypted_input,
config,
Expand All @@ -44,5 +44,11 @@ fn relaxed_dp_malicious() {
#[test]
#[cfg(all(test, web_test))]
fn relaxed_dp_https_malicious_ipa() {
test_ipa(IpaSecurityModel::Malicious, true, true);
test_ipa::<100>(IpaSecurityModel::Malicious, true, true);
}

#[test]
#[cfg(all(test, web_test))]
fn relaxed_dp_https_malicious_ipa_10_rows() {
test_ipa::<10>(IpaSecurityModel::Malicious, true, true);
}
Loading