Skip to content

Commit

Permalink
Merge pull request #166 from sine-fdn/circuit-optimizations
Browse files Browse the repository at this point in the history
Circuit optimizations
  • Loading branch information
fkettelhoit authored Jan 8, 2025
2 parents ba556d3 + 5943556 commit 7a1f850
Show file tree
Hide file tree
Showing 4 changed files with 625 additions and 75 deletions.
211 changes: 186 additions & 25 deletions src/circuit.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
//! The [`Circuit`] representation used by the compiler.
use crate::{compile::wires_as_unsigned, env::Env, token::MetaInfo};
use std::collections::HashMap;
use std::collections::{HashMap, HashSet};

#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
Expand Down Expand Up @@ -258,6 +258,7 @@ pub(crate) struct CircuitBuilder {
shift: usize,
input_gates: Vec<usize>,
gates: Vec<BuilderGate>,
used: HashSet<GateIndex>,
cache: HashMap<BuilderGate, GateIndex>,
negated: HashMap<GateIndex, GateIndex>,
gates_optimized: usize,
Expand Down Expand Up @@ -403,6 +404,7 @@ impl CircuitBuilder {
shift: gate_counter,
input_gates,
gates: vec![],
used: HashSet::new(),
cache: HashMap::new(),
negated: HashMap::new(),
gates_optimized: 0,
Expand All @@ -419,6 +421,16 @@ impl CircuitBuilder {
&self.consts
}

fn get_cached(&self, gate: &BuilderGate) -> Option<&usize> {
match self.cache.get(gate) {
Some(wire) => Some(wire),
None => match gate {
BuilderGate::Xor(x, y) => self.cache.get(&BuilderGate::Xor(*y, *x)),
BuilderGate::And(x, y) => self.cache.get(&BuilderGate::And(*y, *x)),
},
}
}

// Pruning of useless gates (gates that are not part of the output nor used by other gates):
fn remove_unused_gates(&mut self, output_gates: Vec<GateIndex>) -> Vec<GateIndex> {
// To find all unused gates, we start at the output gates and recursively mark all their
Expand Down Expand Up @@ -764,12 +776,128 @@ impl CircuitBuilder {
}
}
// Sub-expression sharing:
if let Some(&wire) = self.cache.get(&BuilderGate::Xor(x, y)) {
if let Some(&wire) = self.get_cached(&BuilderGate::Xor(x, y)) {
return Some(wire);
}
None
}

pub fn push_xor(&mut self, x: GateIndex, y: GateIndex) -> GateIndex {
if let Some(optimized) = self.optimize_xor(x, y) {
self.gates_optimized += 1;
optimized
} else {
if x >= self.shift && y >= self.shift {
let gate_x = self.gates[x - self.shift];
let gate_y = self.gates[y - self.shift];
if let (BuilderGate::Xor(x1, x2), BuilderGate::Xor(y1, y2)) = (gate_x, gate_y) {
if x1 == y1 {
return self.push_xor(x2, y2);
} else if x1 == y2 {
return self.push_xor(x2, y1);
} else if x2 == y1 {
return self.push_xor(x1, y2);
} else if x2 == y2 {
return self.push_xor(x1, y1);
}
} else if let (BuilderGate::And(x1, x2), BuilderGate::And(y1, y2)) =
(gate_x, gate_y)
{
for (a1, a2, b1, b2) in [
(x1, x2, y1, y2),
(x1, x2, y2, y1),
(x2, x1, y1, y2),
(x2, x1, y2, y1),
] {
if a1 == b1 {
if let Some(&a2_xor_b2) = self.get_cached(&BuilderGate::Xor(a2, b2)) {
if let Some(&wire) =
self.get_cached(&BuilderGate::And(a1, a2_xor_b2))
{
self.gates_optimized += 1;
return wire;
}
}
}
}
for (a1, a2, b1, b2) in [
(x1, x2, y1, y2),
(x1, x2, y2, y1),
(x2, x1, y1, y2),
(x2, x1, y2, y1),
] {
if a1 == b1 && !self.used.contains(&x) && !self.used.contains(&y) {
let a2_xor_b2_gate = BuilderGate::Xor(a2, b2);
self.gate_counter += 1;
self.gates.push(a2_xor_b2_gate);
let a2_xor_b2 = self.gate_counter - 1;
self.cache.insert(a2_xor_b2_gate, a2_xor_b2);

let gate = BuilderGate::And(a1, a2_xor_b2);
self.gate_counter += 1;
self.gates.push(gate);
self.cache.insert(gate, self.gate_counter - 1);
self.used.insert(a2_xor_b2);
return self.gate_counter - 1;
}
}
}
}
if x >= self.shift {
let gate_x = self.gates[x - self.shift];
if let BuilderGate::Xor(x1, x2) = gate_x {
if x1 == y {
self.gates_optimized += 1;
return x2;
} else if x2 == y {
self.gates_optimized += 1;
return x1;
} else if let Some(&y_negated) = self.negated.get(&y) {
if x1 == y_negated {
return self.push_xor(x2, 1);
} else if x2 == y_negated {
return self.push_xor(x1, 1);
}
}
}
}
if y >= self.shift {
let gate_y = self.gates[y - self.shift];
if let BuilderGate::Xor(y1, y2) = gate_y {
if x == y1 {
self.gates_optimized += 1;
return y2;
} else if x == y2 {
self.gates_optimized += 1;
return y1;
} else if let Some(&x_negated) = self.negated.get(&x) {
if x_negated == y1 {
return self.push_xor(1, y2);
} else if x_negated == y2 {
return self.push_xor(1, y1);
}
}
}
}
let gate = BuilderGate::Xor(x, y);
self.gate_counter += 1;
self.gates.push(gate);
let gate_index = self.gate_counter - 1;
self.cache.insert(gate, gate_index);
self.used.insert(x);
self.used.insert(y);
if x == 1 {
self.negated.insert(y, gate_index);
self.negated.insert(gate_index, y);
}
if y == 1 {
self.negated.insert(x, gate_index);
self.negated.insert(gate_index, x);
}
gate_index
}
}

// - Constant evaluation (e.g. x & x == x; x & 1 == x; x & 0 == 0)
// - Sub-expression sharing (wires are re-used if a gate with the same type and inputs exists)
fn optimize_and(&self, x: GateIndex, y: GateIndex) -> Option<GateIndex> {
Expand All @@ -789,43 +917,76 @@ impl CircuitBuilder {
}
}
// Sub-expression sharing:
if let Some(&wire) = self.cache.get(&BuilderGate::And(x, y)) {
if let Some(&wire) = self.get_cached(&BuilderGate::And(x, y)) {
return Some(wire);
}
None
}

pub fn push_xor(&mut self, x: GateIndex, y: GateIndex) -> GateIndex {
if let Some(optimized) = self.optimize_xor(x, y) {
self.gates_optimized += 1;
optimized
} else {
let gate = BuilderGate::Xor(x, y);
self.gate_counter += 1;
self.gates.push(gate);
let gate_index = self.gate_counter - 1;
self.cache.insert(gate, gate_index);
if x == 1 {
self.negated.insert(y, gate_index);
self.negated.insert(gate_index, y);
}
if y == 1 {
self.negated.insert(x, gate_index);
self.negated.insert(gate_index, x);
}
gate_index
}
}

pub fn push_and(&mut self, x: GateIndex, y: GateIndex) -> GateIndex {
if let Some(optimized) = self.optimize_and(x, y) {
self.gates_optimized += 1;
optimized
} else {
if x >= self.shift && y >= self.shift {
let gate_x = self.gates[x - self.shift];
let gate_y = self.gates[y - self.shift];
if let (BuilderGate::And(x1, x2), BuilderGate::And(y1, y2)) = (gate_x, gate_y) {
if x1 == y1 || x2 == y1 {
return self.push_and(x, y2);
} else if x1 == y2 || x2 == y2 {
return self.push_and(x, y1);
}
}
}
if x >= self.shift {
let gate_x = self.gates[x - self.shift];
if let BuilderGate::And(x1, x2) = gate_x {
if x1 == y || x2 == y {
self.gates_optimized += 1;
return x;
}
if let Some(&y_negated) = self.negated.get(&y) {
if x1 == y_negated || x2 == y_negated {
return 0;
}
}
} else if let BuilderGate::Xor(x1, x2) = gate_x {
if let (Some(&x1_and_y), Some(&x2_and_y)) = (
self.get_cached(&BuilderGate::And(x1, y)),
self.get_cached(&BuilderGate::And(x2, y)),
) {
return self.push_xor(x1_and_y, x2_and_y);
}
}
}
if y >= self.shift {
let gate_y = self.gates[y - self.shift];
if let BuilderGate::And(y1, y2) = gate_y {
if x == y1 || x == y2 {
self.gates_optimized += 1;
return y;
}
if let Some(&x_negated) = self.negated.get(&x) {
if x_negated == y1 || x_negated == y2 {
return 0;
}
}
} else if let BuilderGate::Xor(y1, y2) = gate_y {
if let (Some(&x_and_y1), Some(&x_and_y2)) = (
self.get_cached(&BuilderGate::And(x, y1)),
self.get_cached(&BuilderGate::And(x, y2)),
) {
return self.push_xor(x_and_y1, x_and_y2);
}
}
}
let gate = BuilderGate::And(x, y);
self.gate_counter += 1;
self.gates.push(gate);
self.cache.insert(gate, self.gate_counter - 1);
self.used.insert(x);
self.used.insert(y);
self.gate_counter - 1
}
}
Expand Down
47 changes: 45 additions & 2 deletions src/compile.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ use std::{

use crate::{
ast::{
ConstExpr, ConstExprEnum, EnumDef, ExprEnum, Op, Pattern, PatternEnum, StmtEnum, StructDef,
Type, UnaryOp, VariantExprEnum,
ConstExpr, ConstExprEnum, EnumDef, Expr, ExprEnum, Op, Pattern, PatternEnum, StmtEnum,
StructDef, Type, UnaryOp, VariantExprEnum,
},
circuit::{Circuit, CircuitBuilder, GateIndex, PanicReason, USIZE_BITS},
env::Env,
Expand Down Expand Up @@ -782,6 +782,49 @@ impl TypedExpr {
bits_unshifted
}
ExprEnum::Op(op, x, y) => {
if let Op::Mul = op {
for (x, y) in [(x, y), (y, x)] {
let (n, bits, is_neg) = match x.inner {
ExprEnum::NumUnsigned(n, size) => (
n,
Type::Unsigned(size)
.size_in_bits_for_defs(prg, circuit.const_sizes())
as u64,
false,
),
ExprEnum::NumSigned(n, size) => (
n.unsigned_abs(),
Type::Signed(size).size_in_bits_for_defs(prg, circuit.const_sizes())
as u64,
n < 0,
),
_ => continue,
};
if n == 0 {
continue;
}
if n < bits {
let mut expr = y.clone();
for _ in 0..n - 1 {
expr = Box::new(Expr {
inner: ExprEnum::Op(Op::Add, expr, y.clone()),
meta,
ty: ty.clone(),
});
}
if is_neg {
return Expr {
inner: ExprEnum::UnaryOp(UnaryOp::Neg, expr),
meta,
ty: ty.clone(),
}
.compile(prg, env, circuit);
} else {
return expr.compile(prg, env, circuit);
}
}
}
}
let ty_x = &x.ty;
let ty_y = &y.ty;
let mut x = x.compile(prg, env, circuit);
Expand Down
Loading

0 comments on commit 7a1f850

Please sign in to comment.