Skip to content

Commit

Permalink
removing non determinism in linear constraints that cannot be removed…
Browse files Browse the repository at this point in the history
… after simplification
  • Loading branch information
clararod9 committed Jun 17, 2022
1 parent bd1a102 commit 2c56dc5
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 21 deletions.
33 changes: 32 additions & 1 deletion circom_algebra/src/algebra.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use super::modular_arithmetic;
pub use super::modular_arithmetic::ArithmeticError;
use num_bigint::BigInt;
use num_traits::{ToPrimitive, Zero};
use std::collections::{HashMap, HashSet};
use std::collections::{HashMap, HashSet, BTreeSet};
use std::fmt::{Display, Formatter};
use std::hash::Hash;

Expand Down Expand Up @@ -984,6 +984,19 @@ impl<C: Default + Clone + Display + Hash + Eq> Substitution<C> {
}
}

impl<C: Default + Clone + Display + Hash + Eq + std::cmp::Ord> Substitution<C> {
pub fn take_cloned_signals_ordered(&self) -> BTreeSet<C> {
let cq: C = ArithmeticExpression::constant_coefficient();
let mut signals = BTreeSet::new();
for s in self.to.keys() {
if cq != *s {
signals.insert(s.clone());
}
}
signals
}
}

impl Substitution<usize> {
pub fn apply_offset(&self, offset: usize) -> Substitution<usize> {
let constant: usize = Substitution::constant_coefficient();
Expand Down Expand Up @@ -1172,6 +1185,24 @@ impl<C: Default + Clone + Display + Hash + Eq> Constraint<C> {

}

impl<C: Default + Clone + Display + Hash + Eq + std::cmp::Ord> Constraint<C> {
pub fn take_cloned_signals_ordered(&self) -> BTreeSet<C> {
let mut signals = BTreeSet::new();
for signal in self.a().keys() {
signals.insert(signal.clone());
}
for signal in self.b().keys() {
signals.insert(signal.clone());
}
for signal in self.c().keys() {
signals.insert(signal.clone());
}
signals.remove(&Constraint::constant_coefficient());
signals
}

}

impl Constraint<usize> {
pub fn apply_offset(&self, offset: usize) -> Constraint<usize> {
let a = apply_raw_offset(&self.a, offset);
Expand Down
50 changes: 30 additions & 20 deletions constraint_list/src/constraint_simplification.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use super::{ConstraintStorage, EncodingIterator, SEncoded, Simplifier, A, C, S};
use crate::SignalMap;
use circom_algebra::num_bigint::BigInt;
use constraint_writers::json_writer::SubstitutionJSON;
use std::collections::{HashMap, HashSet, LinkedList};
use std::collections::{HashMap, HashSet, LinkedList, BTreeSet};
use std::sync::Arc;

const SUB_LOG: &str = "./log_substitution.json";
Expand Down Expand Up @@ -72,7 +72,7 @@ fn build_clusters(linear: LinkedList<C>, no_vars: usize) -> Vec<Cluster> {
let mut cluster_to_current = ClusterPath::with_capacity(no_linear);
let mut signal_to_cluster = vec![no_linear; no_vars];
for constraint in linear {
let signals = C::take_cloned_signals(&constraint);
let signals = C::take_cloned_signals_ordered(&constraint);
let dest = ClusterArena::len(&arena);
ClusterArena::push(&mut arena, Some(Cluster::new(constraint)));
Vec::push(&mut cluster_to_current, dest);
Expand Down Expand Up @@ -128,7 +128,7 @@ fn eq_cluster_simplification(
let mut substitutions = LinkedList::new();
let mut constraints = LinkedList::new();
let constraint = LinkedList::pop_back(&mut cluster.constraints).unwrap();
let signals: Vec<_> = C::take_cloned_signals(&constraint).iter().cloned().collect();
let signals: Vec<_> = C::take_cloned_signals_ordered(&constraint).iter().cloned().collect();
let s_0 = signals[0];
let s_1 = signals[1];
if HashSet::contains(forbidden, &s_0) && HashSet::contains(forbidden, &s_1) {
Expand All @@ -151,12 +151,12 @@ fn eq_cluster_simplification(
} else {
let mut cons = LinkedList::new();
let mut subs = LinkedList::new();
let (mut remains, mut min_remains) = (HashSet::new(), None);
let (mut remains, mut min_remains) = (BTreeSet::new(), None);
let (mut remove, mut min_remove) = (HashSet::new(), None);
for c in cluster.constraints {
for signal in C::take_cloned_signals(&c) {
for signal in C::take_cloned_signals_ordered(&c) {
if HashSet::contains(&forbidden, &signal) {
HashSet::insert(&mut remains, signal);
BTreeSet::insert(&mut remains, signal);
min_remains = Some(min_remains.map_or(signal, |s| std::cmp::min(s, signal)));
} else {
min_remove = Some(min_remove.map_or(signal, |s| std::cmp::min(s, signal)));
Expand All @@ -166,7 +166,7 @@ fn eq_cluster_simplification(
}

let rh_signal = if let Some(signal) = min_remains {
HashSet::remove(&mut remains, &signal);
BTreeSet::remove(&mut remains, &signal);
signal
} else {
let signal = min_remove.unwrap();
Expand Down Expand Up @@ -200,7 +200,6 @@ fn eq_simplification(
) -> (LinkedList<S>, LinkedList<C>) {
use std::sync::mpsc;
use threadpool::ThreadPool;

let field = Arc::new(field.clone());
let mut constraints = LinkedList::new();
let mut substitutions = LinkedList::new();
Expand All @@ -211,11 +210,12 @@ fn eq_simplification(
// println!("Clusters: {}", no_clusters);
let mut single_clusters = 0;
let mut id = 0;
let mut aux_constraints = vec![LinkedList::new(); clusters.len()];
for cluster in clusters {
if Cluster::size(&cluster) == 1 {
let (mut subs, mut cons) = eq_cluster_simplification(cluster, &forbidden, &field);
let (mut subs, cons) = eq_cluster_simplification(cluster, &forbidden, &field);
aux_constraints[id] = cons;
LinkedList::append(&mut substitutions, &mut subs);
LinkedList::append(&mut constraints, &mut cons);
single_clusters += 1;
} else {
let cluster_tx = cluster_tx.clone();
Expand All @@ -225,7 +225,7 @@ fn eq_simplification(
//println!("Cluster: {}", id);
let result = eq_cluster_simplification(cluster, &forbidden, &field);
//println!("End of cluster: {}", id);
cluster_tx.send(result).unwrap();
cluster_tx.send((id, result)).unwrap();
};
ThreadPool::execute(&pool, job);
}
Expand All @@ -235,9 +235,12 @@ fn eq_simplification(
// println!("{} clusters were of size 1", single_clusters);
ThreadPool::join(&pool);
for _ in 0..(no_clusters - single_clusters) {
let (mut subs, mut cons) = simplified_rx.recv().unwrap();
let (id, (mut subs, cons)) = simplified_rx.recv().unwrap();
aux_constraints[id] = cons;
LinkedList::append(&mut substitutions, &mut subs);
LinkedList::append(&mut constraints, &mut cons);
}
for id in 0..no_clusters {
LinkedList::append(&mut constraints, &mut aux_constraints[id]);
}
log_substitutions(&substitutions, substitution_log);
(substitutions, constraints)
Expand All @@ -252,7 +255,7 @@ fn constant_eq_simplification(
let mut cons = LinkedList::new();
let mut subs = LinkedList::new();
for constraint in c_eq {
let mut signals: Vec<_> = C::take_cloned_signals(&constraint).iter().cloned().collect();
let mut signals: Vec<_> = C::take_cloned_signals_ordered(&constraint).iter().cloned().collect();
let signal = signals.pop().unwrap();
if HashSet::contains(&forbidden, &signal) {
LinkedList::push_back(&mut cons, constraint);
Expand Down Expand Up @@ -517,7 +520,7 @@ pub fn simplification(smp: &mut Simplifier) -> (ConstraintStorage, SignalMap) {
relevant
};

let linear_substitutions = if apply_linear {
let linear_substitutions = if remove_unused {
let now = SystemTime::now();
let (subs, mut cons) = linear_simplification(
&mut substitution_log,
Expand Down Expand Up @@ -563,15 +566,17 @@ pub fn simplification(smp: &mut Simplifier) -> (ConstraintStorage, SignalMap) {
crate::state_utils::empty_encoding_constraints(&mut smp.dag_encoding);
let _dur = now.elapsed().unwrap().as_millis();
// println!("Storages built in {} ms", dur);
no_rounds -= 1;
if remove_unused {
no_rounds -= 1;
}
(with_linear, storage)
};

let mut round_id = 0;
let _ = round_id;
let mut linear = with_linear;
let mut apply_round = apply_linear && no_rounds > 0 && !linear.is_empty();
let mut non_linear_map = if apply_round || remove_unused{
let mut apply_round = remove_unused && no_rounds > 0 && !linear.is_empty();
let mut non_linear_map = if apply_round || remove_unused {
// println!("Building non-linear map");
let now = SystemTime::now();
let non_linear_map = build_non_linear_signal_map(&constraint_storage);
Expand Down Expand Up @@ -615,7 +620,7 @@ pub fn simplification(smp: &mut Simplifier) -> (ConstraintStorage, SignalMap) {
}

for constraint in linear {
if remove_unused{
if remove_unused {
let signals = C::take_cloned_signals(&constraint);
let c_id = constraint_storage.add_constraint(constraint);
for signal in signals {
Expand All @@ -632,8 +637,9 @@ pub fn simplification(smp: &mut Simplifier) -> (ConstraintStorage, SignalMap) {
constraint_storage.add_constraint(constraint);
}
}
for constraint in lconst {
for mut constraint in lconst {
if remove_unused{
C::fix_constraint(&mut constraint, &field);
let signals = C::take_cloned_signals(&constraint);
let c_id = constraint_storage.add_constraint(constraint);
for signal in signals {
Expand All @@ -647,6 +653,7 @@ pub fn simplification(smp: &mut Simplifier) -> (ConstraintStorage, SignalMap) {
}
}
else{
C::fix_constraint(&mut constraint, &field);
constraint_storage.add_constraint(constraint);
}
}
Expand Down Expand Up @@ -678,3 +685,6 @@ pub fn simplification(smp: &mut Simplifier) -> (ConstraintStorage, SignalMap) {
// println!("NO CONSTANTS: {}", constraint_storage.no_constants());
(constraint_storage, signal_map)
}



0 comments on commit 2c56dc5

Please sign in to comment.