Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Circuit optimizations #166

Merged
merged 8 commits into from
Jan 8, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
211 changes: 186 additions & 25 deletions src/circuit.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
//! The [`Circuit`] representation used by the compiler.

use crate::{compile::wires_as_unsigned, env::Env, token::MetaInfo};
use std::collections::HashMap;
use std::collections::{HashMap, HashSet};

#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
Expand Down Expand Up @@ -258,6 +258,7 @@ pub(crate) struct CircuitBuilder {
shift: usize,
input_gates: Vec<usize>,
gates: Vec<BuilderGate>,
used: HashSet<GateIndex>,
cache: HashMap<BuilderGate, GateIndex>,
negated: HashMap<GateIndex, GateIndex>,
gates_optimized: usize,
Expand Down Expand Up @@ -403,6 +404,7 @@ impl CircuitBuilder {
shift: gate_counter,
input_gates,
gates: vec![],
used: HashSet::new(),
cache: HashMap::new(),
negated: HashMap::new(),
gates_optimized: 0,
Expand All @@ -419,6 +421,16 @@ impl CircuitBuilder {
&self.consts
}

fn get_cached(&self, gate: &BuilderGate) -> Option<&usize> {
match self.cache.get(gate) {
Some(wire) => Some(wire),
None => match gate {
BuilderGate::Xor(x, y) => self.cache.get(&BuilderGate::Xor(*y, *x)),
BuilderGate::And(x, y) => self.cache.get(&BuilderGate::And(*y, *x)),
},
}
}

// Pruning of useless gates (gates that are not part of the output nor used by other gates):
fn remove_unused_gates(&mut self, output_gates: Vec<GateIndex>) -> Vec<GateIndex> {
// To find all unused gates, we start at the output gates and recursively mark all their
Expand Down Expand Up @@ -764,12 +776,128 @@ impl CircuitBuilder {
}
}
// Sub-expression sharing:
if let Some(&wire) = self.cache.get(&BuilderGate::Xor(x, y)) {
if let Some(&wire) = self.get_cached(&BuilderGate::Xor(x, y)) {
return Some(wire);
}
None
}

pub fn push_xor(&mut self, x: GateIndex, y: GateIndex) -> GateIndex {
if let Some(optimized) = self.optimize_xor(x, y) {
self.gates_optimized += 1;
optimized
} else {
if x >= self.shift && y >= self.shift {
let gate_x = self.gates[x - self.shift];
let gate_y = self.gates[y - self.shift];
if let (BuilderGate::Xor(x1, x2), BuilderGate::Xor(y1, y2)) = (gate_x, gate_y) {
if x1 == y1 {
return self.push_xor(x2, y2);
} else if x1 == y2 {
return self.push_xor(x2, y1);
} else if x2 == y1 {
return self.push_xor(x1, y2);
} else if x2 == y2 {
return self.push_xor(x1, y1);
}
} else if let (BuilderGate::And(x1, x2), BuilderGate::And(y1, y2)) =
(gate_x, gate_y)
{
for (a1, a2, b1, b2) in [
(x1, x2, y1, y2),
(x1, x2, y2, y1),
(x2, x1, y1, y2),
(x2, x1, y2, y1),
] {
if a1 == b1 {
if let Some(&a2_xor_b2) = self.get_cached(&BuilderGate::Xor(a2, b2)) {
if let Some(&wire) =
self.get_cached(&BuilderGate::And(a1, a2_xor_b2))
{
self.gates_optimized += 1;
return wire;
}
}
}
}
for (a1, a2, b1, b2) in [
(x1, x2, y1, y2),
(x1, x2, y2, y1),
(x2, x1, y1, y2),
(x2, x1, y2, y1),
] {
if a1 == b1 && !self.used.contains(&x) && !self.used.contains(&y) {
let a2_xor_b2_gate = BuilderGate::Xor(a2, b2);
self.gate_counter += 1;
self.gates.push(a2_xor_b2_gate);
let a2_xor_b2 = self.gate_counter - 1;
self.cache.insert(a2_xor_b2_gate, a2_xor_b2);

let gate = BuilderGate::And(a1, a2_xor_b2);
self.gate_counter += 1;
self.gates.push(gate);
self.cache.insert(gate, self.gate_counter - 1);
self.used.insert(a2_xor_b2);
return self.gate_counter - 1;
}
}
}
}
if x >= self.shift {
let gate_x = self.gates[x - self.shift];
if let BuilderGate::Xor(x1, x2) = gate_x {
if x1 == y {
self.gates_optimized += 1;
return x2;
} else if x2 == y {
self.gates_optimized += 1;
return x1;
} else if let Some(&y_negated) = self.negated.get(&y) {
if x1 == y_negated {
return self.push_xor(x2, 1);
} else if x2 == y_negated {
return self.push_xor(x1, 1);
}
}
}
}
if y >= self.shift {
let gate_y = self.gates[y - self.shift];
if let BuilderGate::Xor(y1, y2) = gate_y {
if x == y1 {
self.gates_optimized += 1;
return y2;
} else if x == y2 {
self.gates_optimized += 1;
return y1;
} else if let Some(&x_negated) = self.negated.get(&x) {
if x_negated == y1 {
return self.push_xor(1, y2);
} else if x_negated == y2 {
return self.push_xor(1, y1);
}
}
}
}
let gate = BuilderGate::Xor(x, y);
self.gate_counter += 1;
self.gates.push(gate);
let gate_index = self.gate_counter - 1;
self.cache.insert(gate, gate_index);
self.used.insert(x);
self.used.insert(y);
if x == 1 {
self.negated.insert(y, gate_index);
self.negated.insert(gate_index, y);
}
if y == 1 {
self.negated.insert(x, gate_index);
self.negated.insert(gate_index, x);
}
gate_index
}
}

// - Constant evaluation (e.g. x & x == x; x & 1 == x; x & 0 == 0)
// - Sub-expression sharing (wires are re-used if a gate with the same type and inputs exists)
fn optimize_and(&self, x: GateIndex, y: GateIndex) -> Option<GateIndex> {
Expand All @@ -789,43 +917,76 @@ impl CircuitBuilder {
}
}
// Sub-expression sharing:
if let Some(&wire) = self.cache.get(&BuilderGate::And(x, y)) {
if let Some(&wire) = self.get_cached(&BuilderGate::And(x, y)) {
return Some(wire);
}
None
}

pub fn push_xor(&mut self, x: GateIndex, y: GateIndex) -> GateIndex {
if let Some(optimized) = self.optimize_xor(x, y) {
self.gates_optimized += 1;
optimized
} else {
let gate = BuilderGate::Xor(x, y);
self.gate_counter += 1;
self.gates.push(gate);
let gate_index = self.gate_counter - 1;
self.cache.insert(gate, gate_index);
if x == 1 {
self.negated.insert(y, gate_index);
self.negated.insert(gate_index, y);
}
if y == 1 {
self.negated.insert(x, gate_index);
self.negated.insert(gate_index, x);
}
gate_index
}
}

pub fn push_and(&mut self, x: GateIndex, y: GateIndex) -> GateIndex {
if let Some(optimized) = self.optimize_and(x, y) {
self.gates_optimized += 1;
optimized
} else {
if x >= self.shift && y >= self.shift {
let gate_x = self.gates[x - self.shift];
let gate_y = self.gates[y - self.shift];
if let (BuilderGate::And(x1, x2), BuilderGate::And(y1, y2)) = (gate_x, gate_y) {
if x1 == y1 || x2 == y1 {
return self.push_and(x, y2);
} else if x1 == y2 || x2 == y2 {
return self.push_and(x, y1);
}
}
}
if x >= self.shift {
let gate_x = self.gates[x - self.shift];
if let BuilderGate::And(x1, x2) = gate_x {
if x1 == y || x2 == y {
self.gates_optimized += 1;
return x;
}
if let Some(&y_negated) = self.negated.get(&y) {
if x1 == y_negated || x2 == y_negated {
return 0;
}
}
} else if let BuilderGate::Xor(x1, x2) = gate_x {
if let (Some(&x1_and_y), Some(&x2_and_y)) = (
self.get_cached(&BuilderGate::And(x1, y)),
self.get_cached(&BuilderGate::And(x2, y)),
) {
return self.push_xor(x1_and_y, x2_and_y);
}
}
}
if y >= self.shift {
let gate_y = self.gates[y - self.shift];
if let BuilderGate::And(y1, y2) = gate_y {
if x == y1 || x == y2 {
self.gates_optimized += 1;
return y;
}
if let Some(&x_negated) = self.negated.get(&x) {
if x_negated == y1 || x_negated == y2 {
return 0;
}
}
} else if let BuilderGate::Xor(y1, y2) = gate_y {
if let (Some(&x_and_y1), Some(&x_and_y2)) = (
self.get_cached(&BuilderGate::And(x, y1)),
self.get_cached(&BuilderGate::And(x, y2)),
) {
return self.push_xor(x_and_y1, x_and_y2);
}
}
}
let gate = BuilderGate::And(x, y);
self.gate_counter += 1;
self.gates.push(gate);
self.cache.insert(gate, self.gate_counter - 1);
self.used.insert(x);
self.used.insert(y);
self.gate_counter - 1
}
}
Expand Down
47 changes: 45 additions & 2 deletions src/compile.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ use std::{

use crate::{
ast::{
ConstExpr, ConstExprEnum, EnumDef, ExprEnum, Op, Pattern, PatternEnum, StmtEnum, StructDef,
Type, UnaryOp, VariantExprEnum,
ConstExpr, ConstExprEnum, EnumDef, Expr, ExprEnum, Op, Pattern, PatternEnum, StmtEnum,
StructDef, Type, UnaryOp, VariantExprEnum,
},
circuit::{Circuit, CircuitBuilder, GateIndex, PanicReason, USIZE_BITS},
env::Env,
Expand Down Expand Up @@ -782,6 +782,49 @@ impl TypedExpr {
bits_unshifted
}
ExprEnum::Op(op, x, y) => {
if let Op::Mul = op {
for (x, y) in [(x, y), (y, x)] {
let (n, bits, is_neg) = match x.inner {
ExprEnum::NumUnsigned(n, size) => (
n,
Type::Unsigned(size)
.size_in_bits_for_defs(prg, circuit.const_sizes())
as u64,
false,
),
ExprEnum::NumSigned(n, size) => (
n.unsigned_abs(),
Type::Signed(size).size_in_bits_for_defs(prg, circuit.const_sizes())
as u64,
n < 0,
),
_ => continue,
};
if n == 0 {
continue;
}
if n < bits {
let mut expr = y.clone();
for _ in 0..n - 1 {
expr = Box::new(Expr {
inner: ExprEnum::Op(Op::Add, expr, y.clone()),
meta,
ty: ty.clone(),
});
}
if is_neg {
return Expr {
inner: ExprEnum::UnaryOp(UnaryOp::Neg, expr),
meta,
ty: ty.clone(),
}
.compile(prg, env, circuit);
} else {
return expr.compile(prg, env, circuit);
}
}
}
}
let ty_x = &x.ty;
let ty_y = &y.ty;
let mut x = x.compile(prg, env, circuit);
Expand Down
Loading
Loading