diff --git a/circom_algebra/src/algebra.rs b/circom_algebra/src/algebra.rs index 854b6c2e0..2451a5b73 100644 --- a/circom_algebra/src/algebra.rs +++ b/circom_algebra/src/algebra.rs @@ -2,7 +2,7 @@ use super::modular_arithmetic; pub use super::modular_arithmetic::ArithmeticError; use num_bigint::BigInt; use num_traits::{ToPrimitive, Zero}; -use std::collections::{HashMap, HashSet}; +use std::collections::{HashMap, HashSet, BTreeSet}; use std::fmt::{Display, Formatter}; use std::hash::Hash; @@ -984,6 +984,19 @@ impl Substitution { } } +impl Substitution { + pub fn take_cloned_signals_ordered(&self) -> BTreeSet { + let cq: C = ArithmeticExpression::constant_coefficient(); + let mut signals = BTreeSet::new(); + for s in self.to.keys() { + if cq != *s { + signals.insert(s.clone()); + } + } + signals + } +} + impl Substitution { pub fn apply_offset(&self, offset: usize) -> Substitution { let constant: usize = Substitution::constant_coefficient(); @@ -1172,6 +1185,24 @@ impl Constraint { } +impl Constraint { + pub fn take_cloned_signals_ordered(&self) -> BTreeSet { + let mut signals = BTreeSet::new(); + for signal in self.a().keys() { + signals.insert(signal.clone()); + } + for signal in self.b().keys() { + signals.insert(signal.clone()); + } + for signal in self.c().keys() { + signals.insert(signal.clone()); + } + signals.remove(&Constraint::constant_coefficient()); + signals + } + +} + impl Constraint { pub fn apply_offset(&self, offset: usize) -> Constraint { let a = apply_raw_offset(&self.a, offset); diff --git a/constraint_list/src/constraint_simplification.rs b/constraint_list/src/constraint_simplification.rs index 3b4c5c60d..106894a14 100644 --- a/constraint_list/src/constraint_simplification.rs +++ b/constraint_list/src/constraint_simplification.rs @@ -3,7 +3,7 @@ use super::{ConstraintStorage, EncodingIterator, SEncoded, Simplifier, A, C, S}; use crate::SignalMap; use circom_algebra::num_bigint::BigInt; use constraint_writers::json_writer::SubstitutionJSON; -use std::collections::{HashMap, HashSet, LinkedList}; +use std::collections::{HashMap, HashSet, LinkedList, BTreeSet}; use std::sync::Arc; const SUB_LOG: &str = "./log_substitution.json"; @@ -72,7 +72,7 @@ fn build_clusters(linear: LinkedList, no_vars: usize) -> Vec { let mut cluster_to_current = ClusterPath::with_capacity(no_linear); let mut signal_to_cluster = vec![no_linear; no_vars]; for constraint in linear { - let signals = C::take_cloned_signals(&constraint); + let signals = C::take_cloned_signals_ordered(&constraint); let dest = ClusterArena::len(&arena); ClusterArena::push(&mut arena, Some(Cluster::new(constraint))); Vec::push(&mut cluster_to_current, dest); @@ -128,7 +128,7 @@ fn eq_cluster_simplification( let mut substitutions = LinkedList::new(); let mut constraints = LinkedList::new(); let constraint = LinkedList::pop_back(&mut cluster.constraints).unwrap(); - let signals: Vec<_> = C::take_cloned_signals(&constraint).iter().cloned().collect(); + let signals: Vec<_> = C::take_cloned_signals_ordered(&constraint).iter().cloned().collect(); let s_0 = signals[0]; let s_1 = signals[1]; if HashSet::contains(forbidden, &s_0) && HashSet::contains(forbidden, &s_1) { @@ -151,12 +151,12 @@ fn eq_cluster_simplification( } else { let mut cons = LinkedList::new(); let mut subs = LinkedList::new(); - let (mut remains, mut min_remains) = (HashSet::new(), None); + let (mut remains, mut min_remains) = (BTreeSet::new(), None); let (mut remove, mut min_remove) = (HashSet::new(), None); for c in cluster.constraints { - for signal in C::take_cloned_signals(&c) { + for signal in C::take_cloned_signals_ordered(&c) { if HashSet::contains(&forbidden, &signal) { - HashSet::insert(&mut remains, signal); + BTreeSet::insert(&mut remains, signal); min_remains = Some(min_remains.map_or(signal, |s| std::cmp::min(s, signal))); } else { min_remove = Some(min_remove.map_or(signal, |s| std::cmp::min(s, signal))); @@ -166,7 +166,7 @@ fn eq_cluster_simplification( } let rh_signal = if let Some(signal) = min_remains { - HashSet::remove(&mut remains, &signal); + BTreeSet::remove(&mut remains, &signal); signal } else { let signal = min_remove.unwrap(); @@ -200,7 +200,6 @@ fn eq_simplification( ) -> (LinkedList, LinkedList) { use std::sync::mpsc; use threadpool::ThreadPool; - let field = Arc::new(field.clone()); let mut constraints = LinkedList::new(); let mut substitutions = LinkedList::new(); @@ -211,11 +210,12 @@ fn eq_simplification( // println!("Clusters: {}", no_clusters); let mut single_clusters = 0; let mut id = 0; + let mut aux_constraints = vec![LinkedList::new(); clusters.len()]; for cluster in clusters { if Cluster::size(&cluster) == 1 { - let (mut subs, mut cons) = eq_cluster_simplification(cluster, &forbidden, &field); + let (mut subs, cons) = eq_cluster_simplification(cluster, &forbidden, &field); + aux_constraints[id] = cons; LinkedList::append(&mut substitutions, &mut subs); - LinkedList::append(&mut constraints, &mut cons); single_clusters += 1; } else { let cluster_tx = cluster_tx.clone(); @@ -225,7 +225,7 @@ fn eq_simplification( //println!("Cluster: {}", id); let result = eq_cluster_simplification(cluster, &forbidden, &field); //println!("End of cluster: {}", id); - cluster_tx.send(result).unwrap(); + cluster_tx.send((id, result)).unwrap(); }; ThreadPool::execute(&pool, job); } @@ -235,9 +235,12 @@ fn eq_simplification( // println!("{} clusters were of size 1", single_clusters); ThreadPool::join(&pool); for _ in 0..(no_clusters - single_clusters) { - let (mut subs, mut cons) = simplified_rx.recv().unwrap(); + let (id, (mut subs, cons)) = simplified_rx.recv().unwrap(); + aux_constraints[id] = cons; LinkedList::append(&mut substitutions, &mut subs); - LinkedList::append(&mut constraints, &mut cons); + } + for id in 0..no_clusters { + LinkedList::append(&mut constraints, &mut aux_constraints[id]); } log_substitutions(&substitutions, substitution_log); (substitutions, constraints) @@ -252,7 +255,7 @@ fn constant_eq_simplification( let mut cons = LinkedList::new(); let mut subs = LinkedList::new(); for constraint in c_eq { - let mut signals: Vec<_> = C::take_cloned_signals(&constraint).iter().cloned().collect(); + let mut signals: Vec<_> = C::take_cloned_signals_ordered(&constraint).iter().cloned().collect(); let signal = signals.pop().unwrap(); if HashSet::contains(&forbidden, &signal) { LinkedList::push_back(&mut cons, constraint); @@ -517,7 +520,7 @@ pub fn simplification(smp: &mut Simplifier) -> (ConstraintStorage, SignalMap) { relevant }; - let linear_substitutions = if apply_linear { + let linear_substitutions = if remove_unused { let now = SystemTime::now(); let (subs, mut cons) = linear_simplification( &mut substitution_log, @@ -563,15 +566,17 @@ pub fn simplification(smp: &mut Simplifier) -> (ConstraintStorage, SignalMap) { crate::state_utils::empty_encoding_constraints(&mut smp.dag_encoding); let _dur = now.elapsed().unwrap().as_millis(); // println!("Storages built in {} ms", dur); - no_rounds -= 1; + if remove_unused { + no_rounds -= 1; + } (with_linear, storage) }; let mut round_id = 0; let _ = round_id; let mut linear = with_linear; - let mut apply_round = apply_linear && no_rounds > 0 && !linear.is_empty(); - let mut non_linear_map = if apply_round || remove_unused{ + let mut apply_round = remove_unused && no_rounds > 0 && !linear.is_empty(); + let mut non_linear_map = if apply_round || remove_unused { // println!("Building non-linear map"); let now = SystemTime::now(); let non_linear_map = build_non_linear_signal_map(&constraint_storage); @@ -615,7 +620,7 @@ pub fn simplification(smp: &mut Simplifier) -> (ConstraintStorage, SignalMap) { } for constraint in linear { - if remove_unused{ + if remove_unused { let signals = C::take_cloned_signals(&constraint); let c_id = constraint_storage.add_constraint(constraint); for signal in signals { @@ -632,8 +637,9 @@ pub fn simplification(smp: &mut Simplifier) -> (ConstraintStorage, SignalMap) { constraint_storage.add_constraint(constraint); } } - for constraint in lconst { + for mut constraint in lconst { if remove_unused{ + C::fix_constraint(&mut constraint, &field); let signals = C::take_cloned_signals(&constraint); let c_id = constraint_storage.add_constraint(constraint); for signal in signals { @@ -647,6 +653,7 @@ pub fn simplification(smp: &mut Simplifier) -> (ConstraintStorage, SignalMap) { } } else{ + C::fix_constraint(&mut constraint, &field); constraint_storage.add_constraint(constraint); } } @@ -678,3 +685,6 @@ pub fn simplification(smp: &mut Simplifier) -> (ConstraintStorage, SignalMap) { // println!("NO CONSTANTS: {}", constraint_storage.no_constants()); (constraint_storage, signal_map) } + + +