Skip to content

Commit

Permalink
Add more XOR / AND optimizations
Browse files Browse the repository at this point in the history
  • Loading branch information
fkettelhoit committed Jan 6, 2025
1 parent decc761 commit 5943556
Show file tree
Hide file tree
Showing 2 changed files with 138 additions and 2 deletions.
22 changes: 22 additions & 0 deletions src/circuit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -852,6 +852,12 @@ impl CircuitBuilder {
} else if x2 == y {
self.gates_optimized += 1;
return x1;
} else if let Some(&y_negated) = self.negated.get(&y) {
if x1 == y_negated {
return self.push_xor(x2, 1);
} else if x2 == y_negated {
return self.push_xor(x1, 1);
}
}
}
}
Expand All @@ -864,6 +870,12 @@ impl CircuitBuilder {
} else if x == y2 {
self.gates_optimized += 1;
return y1;
} else if let Some(&x_negated) = self.negated.get(&x) {
if x_negated == y1 {
return self.push_xor(1, y2);
} else if x_negated == y2 {
return self.push_xor(1, y1);
}
}
}
}
Expand Down Expand Up @@ -934,6 +946,11 @@ impl CircuitBuilder {
self.gates_optimized += 1;
return x;
}
if let Some(&y_negated) = self.negated.get(&y) {
if x1 == y_negated || x2 == y_negated {
return 0;
}
}
} else if let BuilderGate::Xor(x1, x2) = gate_x {
if let (Some(&x1_and_y), Some(&x2_and_y)) = (
self.get_cached(&BuilderGate::And(x1, y)),
Expand All @@ -950,6 +967,11 @@ impl CircuitBuilder {
self.gates_optimized += 1;
return y;
}
if let Some(&x_negated) = self.negated.get(&x) {
if x_negated == y1 || x_negated == y2 {
return 0;
}
}
} else if let BuilderGate::Xor(y1, y2) = gate_y {
if let (Some(&x_and_y1), Some(&x_and_y2)) = (
self.get_cached(&BuilderGate::And(x, y1)),
Expand Down
118 changes: 116 additions & 2 deletions tests/circuit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,63 @@ pub fn main(a: u32, b: u32, c: u32) -> u32 {
Ok(())
}

#[test]
fn optimize_and_pattern4() -> Result<(), String> {
let naive = "
pub fn main(a: u32, b: u32) -> u32 {
(a & b) & !a
}
";
let optimized = "
pub fn main(a: u32, b: u32) -> u32 {
0
}
";
let naive = compile(naive).map_err(|e| e.prettify(naive))?;
let optimized = compile(optimized).map_err(|e| e.prettify(optimized))?;

assert_eq!(naive.circuit.gates.len(), optimized.circuit.gates.len());
Ok(())
}

#[test]
fn optimize_and_pattern5() -> Result<(), String> {
let naive = "
pub fn main(a: u32, b: u32) -> u32 {
a & (b & !a)
}
";
let optimized = "
pub fn main(a: u32, b: u32) -> u32 {
0
}
";
let naive = compile(naive).map_err(|e| e.prettify(naive))?;
let optimized = compile(optimized).map_err(|e| e.prettify(optimized))?;

assert_eq!(naive.circuit.gates.len(), optimized.circuit.gates.len());
Ok(())
}

#[test]
fn optimize_and_pattern6() -> Result<(), String> {
let naive = "
pub fn main(a: u32, b: u32) -> u32 {
(a & b) & (!a & b)
}
";
let optimized = "
pub fn main(a: u32, b: u32) -> u32 {
0
}
";
let naive = compile(naive).map_err(|e| e.prettify(naive))?;
let optimized = compile(optimized).map_err(|e| e.prettify(optimized))?;

assert_eq!(naive.circuit.gates.len(), optimized.circuit.gates.len());
Ok(())
}

#[test]
fn optimize_xor_pattern1() -> Result<(), String> {
let naive = "
Expand Down Expand Up @@ -192,7 +249,64 @@ pub fn main(a: u32, b: u32, c: u32) -> u32 {
}

#[test]
fn optimize_and_xor_pattern() -> Result<(), String> {
fn optimize_xor_pattern4() -> Result<(), String> {
let naive = "
pub fn main(a: u32, b: u32) -> u32 {
(a ^ b) ^ !a
}
";
let optimized = "
pub fn main(a: u32, b: u32) -> u32 {
b ^ !0
}
";
let naive = compile(naive).map_err(|e| e.prettify(naive))?;
let optimized = compile(optimized).map_err(|e| e.prettify(optimized))?;

assert_eq!(naive.circuit.gates.len(), optimized.circuit.gates.len());
Ok(())
}

#[test]
fn optimize_xor_pattern5() -> Result<(), String> {
let naive = "
pub fn main(a: u32, b: u32) -> u32 {
a ^ (b ^ !a)
}
";
let optimized = "
pub fn main(a: u32, b: u32) -> u32 {
b ^ !0
}
";
let naive = compile(naive).map_err(|e| e.prettify(naive))?;
let optimized = compile(optimized).map_err(|e| e.prettify(optimized))?;

assert_eq!(naive.circuit.gates.len(), optimized.circuit.gates.len());
Ok(())
}

#[test]
fn optimize_xor_pattern6() -> Result<(), String> {
let naive = "
pub fn main(a: u32, b: u32) -> u32 {
(a ^ b) ^ (!a ^ b)
}
";
let optimized = "
pub fn main(a: u32, b: u32) -> u32 {
!0
}
";
let naive = compile(naive).map_err(|e| e.prettify(naive))?;
let optimized = compile(optimized).map_err(|e| e.prettify(optimized))?;

assert_eq!(naive.circuit.gates.len(), optimized.circuit.gates.len());
Ok(())
}

#[test]
fn optimize_and_xor_pattern1() -> Result<(), String> {
let naive = "
pub fn main(a: u32, b: u32, c: u32) -> [u32; 2] {
[a & (b ^ c), (a & b) ^ (a & c)]
Expand All @@ -212,7 +326,7 @@ pub fn main(a: u32, b: u32, c: u32) -> [u32; 2] {
}

#[test]
fn optimize_xor_and_pattern() -> Result<(), String> {
fn optimize_and_xor_pattern2() -> Result<(), String> {
let naive = "
pub fn main(a: u32, b: u32, c: u32) -> [u32; 2] {
[(a & b) ^ (a & c), a & (b ^ c)]
Expand Down

0 comments on commit 5943556

Please sign in to comment.