From b86bf90cf83a741233821d6426e7e13f4f065880 Mon Sep 17 00:00:00 2001 From: Andy Leiserson Date: Mon, 14 Oct 2024 13:19:08 -0700 Subject: [PATCH] Review follow-up from #1323 --- .../src/protocol/context/dzkp_malicious.rs | 4 ++ .../ipa_prf/aggregation/breakdown_reveal.rs | 13 +++-- ipa-core/src/utils/mod.rs | 1 - ipa-core/src/utils/vec_chunks.rs | 58 ------------------- 4 files changed, 11 insertions(+), 65 deletions(-) delete mode 100644 ipa-core/src/utils/vec_chunks.rs diff --git a/ipa-core/src/protocol/context/dzkp_malicious.rs b/ipa-core/src/protocol/context/dzkp_malicious.rs index 6511663d7..671dfa08d 100644 --- a/ipa-core/src/protocol/context/dzkp_malicious.rs +++ b/ipa-core/src/protocol/context/dzkp_malicious.rs @@ -43,6 +43,10 @@ impl<'a, B: ShardBinding> DZKPUpgraded<'a, B> { // in tests; there shouldn't be a risk of deadlocks with one record per // batch; and UnorderedReceiver capacity (which is set from active_work) // must be at least two. + // + // Also rely on the protocol to ensure an appropriate active_work if + // records_per_batch is `usize::MAX` (unlimited batch size). Allocating + // storage for `usize::MAX` active records won't work. base_ctx.active_work() } else { // Adjust active_work to match records_per_batch. If it is less, we will diff --git a/ipa-core/src/protocol/ipa_prf/aggregation/breakdown_reveal.rs b/ipa-core/src/protocol/ipa_prf/aggregation/breakdown_reveal.rs index 6ff1b19c9..65ec97230 100644 --- a/ipa-core/src/protocol/ipa_prf/aggregation/breakdown_reveal.rs +++ b/ipa-core/src/protocol/ipa_prf/aggregation/breakdown_reveal.rs @@ -1,4 +1,4 @@ -use std::{convert::Infallible, mem, pin::pin}; +use std::{convert::Infallible, pin::pin}; use futures::stream; use futures_util::{StreamExt, TryStreamExt}; @@ -34,7 +34,6 @@ use crate::{ TransposeFrom, Vectorizable, }, seq_join::seq_join, - utils::vec_chunks::vec_chunks, }; /// Improved Aggregation a.k.a Aggregation revealing breakdown. @@ -100,7 +99,8 @@ where while intermediate_results.len() > 1 { let mut record_ids = [RecordId::FIRST; AGGREGATE_DEPTH]; - for chunk in vec_chunks(mem::take(&mut intermediate_results), agg_proof_chunk) { + let mut next_intermediate_results = Vec::new(); + for chunk in intermediate_results.chunks(agg_proof_chunk) { let chunk_len = chunk.len(); let validator = ctx.clone().dzkp_validator( MaliciousProtocolSteps { @@ -109,21 +109,22 @@ where }, // We have to specify usize::MAX here because the procession through // record IDs is different at each step of the reduction. The batch - // size is limited by `vec_chunks`, above. + // size is limited by `intermediate_results.chunks()`, above. usize::MAX, ); let result = aggregate_values::<_, HV, B>( validator.context(), - stream::iter(chunk).map(Ok).boxed(), + stream::iter(chunk).map(|v| Ok(v.clone())).boxed(), chunk_len, Some(&mut record_ids), ) .await?; validator.validate().await?; chunk_counter += 1; - intermediate_results.push(result); + next_intermediate_results.push(result); } depth += 1; + intermediate_results = next_intermediate_results; } Ok(intermediate_results diff --git a/ipa-core/src/utils/mod.rs b/ipa-core/src/utils/mod.rs index 5bbfd8c87..e8dfd95ae 100644 --- a/ipa-core/src/utils/mod.rs +++ b/ipa-core/src/utils/mod.rs @@ -2,7 +2,6 @@ pub mod array; pub mod arraychunks; #[cfg(target_pointer_width = "64")] mod power_of_two; -pub mod vec_chunks; #[cfg(target_pointer_width = "64")] pub use power_of_two::NonZeroU32PowerOfTwo; diff --git a/ipa-core/src/utils/vec_chunks.rs b/ipa-core/src/utils/vec_chunks.rs deleted file mode 100644 index 9732bf184..000000000 --- a/ipa-core/src/utils/vec_chunks.rs +++ /dev/null @@ -1,58 +0,0 @@ -use std::cmp::min; - -pub struct VecChunks { - vec: Vec, - pos: usize, - chunk_size: usize, -} - -impl Iterator for VecChunks { - type Item = Vec; - - fn next(&mut self) -> Option { - let start = self.pos; - let len = min(self.vec.len() - start, self.chunk_size); - (len != 0).then(|| { - self.pos += len; - self.vec[start..start + len].to_vec() - }) - } -} - -pub fn vec_chunks(vec: Vec, chunk_size: usize) -> impl Iterator> { - assert!(chunk_size != 0); - VecChunks { - vec, - pos: 0, - chunk_size, - } -} - -#[cfg(all(test, unit_test))] -mod tests { - use super::vec_chunks; - use crate::ff::{Field, Fp61BitPrime}; - - #[test] - fn vec_chunk_iter() { - let elements = vec![Fp61BitPrime::ONE; 4]; - - let mut vec_chunk_iterator = vec_chunks(elements, 3); - - assert_eq!( - vec_chunk_iterator.next().unwrap(), - vec![Fp61BitPrime::ONE; 3] - ); - assert_eq!( - vec_chunk_iterator.next().unwrap(), - vec![Fp61BitPrime::ONE; 1] - ); - assert!(vec_chunk_iterator.next().is_none()); - } - - #[test] - fn vec_chunk_empty() { - let vec = Vec::::new(); - assert!(vec_chunks(vec, 1).next().is_none()); - } -}