diff --git a/ipa-core/src/helpers/gateway/send.rs b/ipa-core/src/helpers/gateway/send.rs index 07018fb14..c0ed20cb9 100644 --- a/ipa-core/src/helpers/gateway/send.rs +++ b/ipa-core/src/helpers/gateway/send.rs @@ -1,5 +1,6 @@ use std::{ borrow::Borrow, + cmp::min, fmt::Debug, marker::PhantomData, num::NonZeroUsize, @@ -248,28 +249,71 @@ impl Stream for GatewaySendStream { impl SendChannelConfig { fn new(gateway_config: GatewayConfig, total_records: TotalRecords) -> Self { - debug_assert!(M::Size::USIZE > 0, "Message size cannot be 0"); + Self::new_with(gateway_config, total_records, M::Size::USIZE) + } - let record_size = M::Size::USIZE; - let total_capacity = gateway_config.active.get() * record_size; - Self { - total_capacity: total_capacity.try_into().unwrap(), - record_size: record_size.try_into().unwrap(), - read_size: if total_records.is_indeterminate() - || gateway_config.read_size.get() <= record_size - { + fn new_with( + gateway_config: GatewayConfig, + total_records: TotalRecords, + record_size: usize, + ) -> Self { + debug_assert!(record_size > 0, "Message size cannot be 0"); + // The absolute minimum of capacity we reserve for this channel. We can't go + // below that number, otherwise a deadlock is almost guaranteed. + let min_capacity = gateway_config.active.get() * record_size; + + // first, compute the read size. It must be a multiple of `record_size` to prevent + // misaligned reads and deadlocks. For indeterminate channels, read size must be + // set to the size of one record, to trigger buffer flush on every write + let read_size = + if total_records.is_indeterminate() || gateway_config.read_size.get() <= record_size { record_size } else { - std::cmp::min( - total_capacity, - // closest multiple of record_size to read_size + // closest multiple of record_size to read_size + let proposed_read_size = min( gateway_config.read_size.get() / record_size * record_size, - ) - } - .try_into() - .unwrap(), + min_capacity, + ); + // if min capacity is not a multiple of read size. + // we must adjust read size. Adjusting total capacity is not possible due to + // possible deadlocks - it must be strictly aligned with active work. + // read size goes in `record_size` increments to keep it aligned. + // rem is aligned with both capacity and read_size, so subtracting + // it will keep read_size and capacity aligned + // Here is an example how it may work: + // lets say the active work is set to 10, record size is 512 bytes + // and read size in gateway config is set to 2048 bytes (default value). + // the math above will compute total_capacity to 5120 bytes and + // proposed_read_size to 2048 because it is aligned with 512 record size. + // Now, if we don't adjust then we have an issue as 5120 % 2048 = 1024 != 0. + // Keeping read size like this will cause a deadlock, so we adjust it to + // 1024. + proposed_read_size - min_capacity % proposed_read_size + }; + + // total capacity must be a multiple of both read size and record size. + // Record size is easy to justify: misalignment here leads to either waste of memory + // or deadlock on the last write. Aligning read size and total capacity + // has the same reasoning behind it: reading less than total capacity + // can leave the last chunk half-written and backpressure from active work + // preventing the protocols to make further progress. + let total_capacity = min_capacity / read_size * read_size; + + let this = Self { + total_capacity: total_capacity.try_into().unwrap(), + record_size: record_size.try_into().unwrap(), + read_size: read_size.try_into().unwrap(), total_records, - } + }; + + // make sure we've set these values correctly. + debug_assert_eq!(0, this.total_capacity.get() % this.read_size.get()); + debug_assert_eq!(0, this.total_capacity.get() % this.record_size.get()); + debug_assert!(this.total_capacity.get() >= this.read_size.get()); + debug_assert!(this.total_capacity.get() >= this.record_size.get()); + debug_assert!(this.read_size.get() >= this.record_size.get()); + + this } } @@ -277,6 +321,7 @@ impl SendChannelConfig { mod test { use std::num::NonZeroUsize; + use proptest::proptest; use typenum::Unsigned; use crate::{ @@ -285,7 +330,7 @@ mod test { Serializable, }, helpers::{gateway::send::SendChannelConfig, GatewayConfig, TotalRecords}, - secret_sharing::SharedValue, + secret_sharing::{Sendable, StdArray}, }; impl Default for SendChannelConfig { @@ -300,7 +345,7 @@ mod test { } #[allow(clippy::needless_update)] // to allow progress_check_interval to be conditionally compiled - fn send_config( + fn send_config( total_records: TotalRecords, ) -> SendChannelConfig { let gateway_config = GatewayConfig { @@ -390,4 +435,61 @@ mod test { .get() ); } + + /// This test reproduces ipa/#1300. PRF evaluation sent 32*16 = 512 (`record_size` * vectorization) + /// chunks through a channel with total capacity 5120 (active work = 10 records) and read size + /// of 2048 bytes. + /// The problem was that read size of 2048 does not divide 5120, so the last chunk was not sent. + #[test] + fn total_capacity_is_a_multiple_of_read_size() { + let config = + send_config::, 10, 2048>(TotalRecords::specified(43).unwrap()); + + assert_eq!(0, config.total_capacity.get() % config.read_size.get()); + assert_eq!(config.total_capacity.get(), 10 * config.record_size.get()); + } + + fn ensure_config( + total_records: Option, + active: usize, + read_size: usize, + record_size: usize, + ) { + let gateway_config = GatewayConfig { + active: active.try_into().unwrap(), + read_size: read_size.try_into().unwrap(), + ..Default::default() + }; + let config = SendChannelConfig::new_with( + gateway_config, + total_records + .map_or(TotalRecords::Indeterminate, |v| TotalRecords::specified(v).unwrap()), + record_size, + ); + + // total capacity checks + assert!(config.total_capacity.get() > 0); + assert!(config.total_capacity.get() >= record_size); + assert!(config.total_capacity.get() <= record_size * active); + assert!(config.total_capacity.get() >= config.read_size.get()); + assert_eq!(0, config.total_capacity.get() % config.record_size.get()); + + // read size checks + assert!(config.read_size.get() > 0); + assert!(config.read_size.get() >= config.record_size.get()); + assert_eq!(0, config.total_capacity.get() % config.read_size.get()); + assert_eq!(0, config.read_size.get() % config.record_size.get()); + } + + proptest! { + #[test] + fn config_prop( + total_records in proptest::option::of(1_usize..1 << 32), + active in 1_usize..100_000, + read_size in 1_usize..32768, + record_size in 1_usize..4096, + ) { + ensure_config(total_records, active, read_size, record_size); + } + } } diff --git a/ipa-core/src/protocol/context/batcher.rs b/ipa-core/src/protocol/context/batcher.rs index cdbfddbce..1c530a68f 100644 --- a/ipa-core/src/protocol/context/batcher.rs +++ b/ipa-core/src/protocol/context/batcher.rs @@ -1,5 +1,5 @@ use std::{cmp::min, collections::VecDeque, future::Future}; - +use std::fmt::Debug; use bitvec::{bitvec, prelude::BitVec}; use tokio::sync::watch; @@ -78,7 +78,7 @@ enum Ready { }, } -impl<'a, B> Batcher<'a, B> { +impl<'a, B: Debug> Batcher<'a, B> { pub fn new>( records_per_batch: usize, total_records: T, @@ -170,7 +170,7 @@ impl<'a, B> Batcher<'a, B> { "Expected batch of {total_count} records to be ready for validation, but only have {:?}.", &batch.pending_records[0..total_count], ); - tracing::info!("batch {batch_index} is ready for validation"); + tracing::info!("is_ready_for_validation: batch {batch_index} is ready for validation"); let batch; if batch_offset == 0 { batch = self.batches.pop_front().unwrap(); @@ -252,7 +252,12 @@ impl<'a, B> Batcher<'a, B> { /// If the batcher contains more than one batch. pub fn into_single_batch(mut self) -> B { assert!(self.first_batch == 0); - assert!(self.batches.len() <= 1); + assert!(self.batches.len() <= 1, "Number of batches must be 1, got {}. Total records: {:?}/records per batch: {}. debug: {:?}", + self.batches.len(), + self.total_records, + self.records_per_batch, + self.batches + ); let batch_index = 0; match self.batches.pop_back() { Some(state) => { diff --git a/ipa-core/src/protocol/context/dzkp_validator.rs b/ipa-core/src/protocol/context/dzkp_validator.rs index f40a7d805..93fdb4177 100644 --- a/ipa-core/src/protocol/context/dzkp_validator.rs +++ b/ipa-core/src/protocol/context/dzkp_validator.rs @@ -551,9 +551,9 @@ impl Batch { .generate_challenges(ctx.narrow(&Step::Challenge)) .await; + let m = self.get_number_of_multiplications(); let (sum_of_uv, p_r_right_prover, q_r_left_prover) = { // get number of multiplications - let m = self.get_number_of_multiplications(); tracing::info!("validating {m} multiplications"); debug_assert_eq!( m, diff --git a/ipa-core/src/protocol/ipa_prf/prf_sharding/mod.rs b/ipa-core/src/protocol/ipa_prf/prf_sharding/mod.rs index 682841c8b..426d945a7 100644 --- a/ipa-core/src/protocol/ipa_prf/prf_sharding/mod.rs +++ b/ipa-core/src/protocol/ipa_prf/prf_sharding/mod.rs @@ -543,7 +543,9 @@ where protocol: &Step::Aggregate, validate: &Step::AggregateValidate, }, - aggregate_values_proof_chunk(B, usize::try_from(TV::BITS).unwrap()), + // aggregate_values_proof_chunk(B, usize::try_from(TV::BITS).unwrap()), + // 1B batch size, suboptimal. But only to test that it works for 3M + 1 << 30, ); let user_contributions = flattened_user_results.try_collect::>().await?; let result = breakdown_reveal_aggregation::<_, _, _, HV, B>( diff --git a/ipa-core/src/query/executor.rs b/ipa-core/src/query/executor.rs index b3e197e4d..f2e0b8793 100644 --- a/ipa-core/src/query/executor.rs +++ b/ipa-core/src/query/executor.rs @@ -4,7 +4,7 @@ use std::{ future::{ready, Future}, pin::Pin, }; - +use std::sync::OnceLock; use ::tokio::{ runtime::{Handle, RuntimeFlavor}, sync::oneshot, @@ -17,6 +17,7 @@ use rand::rngs::StdRng; use rand_core::SeedableRng; #[cfg(all(feature = "shuttle", test))] use shuttle::future as tokio; +use tokio::runtime::{Builder, Runtime}; use typenum::Unsigned; #[cfg(any( @@ -71,6 +72,13 @@ where } } +static QUERY_RUNTIME: OnceLock = OnceLock::new(); +fn get_query_runtime() -> &'static Runtime { + QUERY_RUNTIME.get_or_init(|| { + Builder::new_multi_thread().worker_threads(10).thread_name("query_runtime").enable_all().build().unwrap() + }) +} + /// Needless pass by value because IPA v3 does not make use of key registry yet. #[allow(clippy::too_many_lines, clippy::needless_pass_by_value)] pub fn execute( @@ -180,7 +188,7 @@ where { let (tx, rx) = oneshot::channel(); - let join_handle = tokio::spawn(async move { + let join_handle = get_query_runtime().spawn(async move { let gateway = gateway.borrow(); // TODO: make it a generic argument for this function let mut rng = StdRng::from_entropy(); diff --git a/ipa-core/tests/common/mod.rs b/ipa-core/tests/common/mod.rs index ca1d5e08a..6fa3139ac 100644 --- a/ipa-core/tests/common/mod.rs +++ b/ipa-core/tests/common/mod.rs @@ -217,8 +217,12 @@ pub fn test_network(https: bool) { T::execute(path, https); } -pub fn test_ipa(mode: IpaSecurityModel, https: bool, encrypted_inputs: bool) { - test_ipa_with_config( +pub fn test_ipa( + mode: IpaSecurityModel, + https: bool, + encrypted_inputs: bool, +) { + test_ipa_with_config::( mode, https, IpaQueryConfig { @@ -228,7 +232,7 @@ pub fn test_ipa(mode: IpaSecurityModel, https: bool, encrypted_inputs: bool) { ); } -pub fn test_ipa_with_config( +pub fn test_ipa_with_config( mode: IpaSecurityModel, https: bool, config: IpaQueryConfig, @@ -238,7 +242,6 @@ pub fn test_ipa_with_config( panic!("encrypted_input requires https") }; - const INPUT_SIZE: usize = 100; // set to true to always keep the temp dir after test finishes let dir = TempDir::new_delete_on_drop(); let path = dir.path(); diff --git a/ipa-core/tests/helper_networks.rs b/ipa-core/tests/helper_networks.rs index 7775ffba4..caa10335e 100644 --- a/ipa-core/tests/helper_networks.rs +++ b/ipa-core/tests/helper_networks.rs @@ -45,20 +45,20 @@ fn http_network_large_input() { #[test] #[cfg(all(test, web_test))] fn http_semi_honest_ipa() { - test_ipa(IpaSecurityModel::SemiHonest, false, false); + test_ipa::<100>(IpaSecurityModel::SemiHonest, false, false); } #[test] #[cfg(all(test, web_test))] fn https_semi_honest_ipa() { - test_ipa(IpaSecurityModel::SemiHonest, true, true); + test_ipa::<100>(IpaSecurityModel::SemiHonest, true, true); } #[test] #[cfg(all(test, web_test))] #[ignore] fn https_malicious_ipa() { - test_ipa(IpaSecurityModel::Malicious, true, true); + test_ipa::<100>(IpaSecurityModel::Malicious, true, true); } /// Similar to [`network`] tests, but it uses keygen + confgen CLIs to generate helper client config diff --git a/ipa-core/tests/ipa_with_relaxed_dp.rs b/ipa-core/tests/ipa_with_relaxed_dp.rs index 84c4c2a7b..f4d15e5e1 100644 --- a/ipa-core/tests/ipa_with_relaxed_dp.rs +++ b/ipa-core/tests/ipa_with_relaxed_dp.rs @@ -20,7 +20,7 @@ fn relaxed_dp_semi_honest() { let encrypted_input = false; let config = build_config(); - test_ipa_with_config( + test_ipa_with_config::<100>( IpaSecurityModel::SemiHonest, encrypted_input, config, @@ -33,7 +33,7 @@ fn relaxed_dp_malicious() { let encrypted_input = false; let config = build_config(); - test_ipa_with_config( + test_ipa_with_config::<100>( IpaSecurityModel::Malicious, encrypted_input, config, @@ -44,5 +44,11 @@ fn relaxed_dp_malicious() { #[test] #[cfg(all(test, web_test))] fn relaxed_dp_https_malicious_ipa() { - test_ipa(IpaSecurityModel::Malicious, true, true); + test_ipa::<100>(IpaSecurityModel::Malicious, true, true); +} + +#[test] +#[cfg(all(test, web_test))] +fn relaxed_dp_https_malicious_ipa_10_rows() { + test_ipa::<10>(IpaSecurityModel::Malicious, true, true); }