Skip to content

Commit

Permalink
Optimize multiplication by const numbers
Browse files Browse the repository at this point in the history
  • Loading branch information
fkettelhoit committed Dec 11, 2024
1 parent 749af4b commit 4823a76
Show file tree
Hide file tree
Showing 2 changed files with 87 additions and 2 deletions.
47 changes: 45 additions & 2 deletions src/compile.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ use std::{

use crate::{
ast::{
ConstExpr, ConstExprEnum, EnumDef, ExprEnum, Op, Pattern, PatternEnum, StmtEnum, StructDef,
Type, UnaryOp, VariantExprEnum,
ConstExpr, ConstExprEnum, EnumDef, Expr, ExprEnum, Op, Pattern, PatternEnum, StmtEnum,
StructDef, Type, UnaryOp, VariantExprEnum,
},
circuit::{Circuit, CircuitBuilder, GateIndex, PanicReason, USIZE_BITS},
env::Env,
Expand Down Expand Up @@ -782,6 +782,49 @@ impl TypedExpr {
bits_unshifted
}
ExprEnum::Op(op, x, y) => {
if let Op::Mul = op {
for (x, y) in [(x, y), (y, x)] {
let (n, bits, is_neg) = match x.inner {
ExprEnum::NumUnsigned(n, size) => (
n,
Type::Unsigned(size)
.size_in_bits_for_defs(prg, circuit.const_sizes())
as u64,
false,
),
ExprEnum::NumSigned(n, size) => (
n.abs() as u64,
Type::Signed(size).size_in_bits_for_defs(prg, circuit.const_sizes())
as u64,
n < 0,
),
_ => continue,
};
if n == 0 {
continue;
}
if n < bits {
let mut expr = y.clone();
for _ in 0..n - 1 {
expr = Box::new(Expr {
inner: ExprEnum::Op(Op::Add, expr, y.clone()),
meta,
ty: ty.clone(),
});
}
if is_neg {
return Expr {
inner: ExprEnum::UnaryOp(UnaryOp::Neg, expr),
meta,
ty: ty.clone(),
}
.compile(prg, env, circuit);
} else {
return expr.compile(prg, env, circuit);
}
}
}
}
let ty_x = &x.ty;
let ty_y = &y.ty;
let mut x = x.compile(prg, env, circuit);
Expand Down
42 changes: 42 additions & 0 deletions tests/circuit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,48 @@ pub fn main(arr1: [(u16, u16, u32); 8]) -> [((u16, u16), u32); 8] {
Ok(())
}

#[test]
fn optimize_constant_mul() -> Result<(), String> {
let unoptimized = "
pub fn main(x: i32) -> i32 {
2 * x
}
";
let optimized = "
pub fn main(x: i32) -> i32 {
x + x
}
";
let unoptimized = compile(unoptimized).map_err(|e| e.prettify(unoptimized))?;
let optimized = compile(optimized).map_err(|e| e.prettify(optimized))?;
assert_eq!(
unoptimized.circuit.gates.len(),
optimized.circuit.gates.len()
);
Ok(())
}

#[test]
fn optimize_constant_mul_signed() -> Result<(), String> {
let unoptimized = "
pub fn main(x: i32) -> i32 {
-2 * x
}
";
let optimized = "
pub fn main(x: i32) -> i32 {
-(x + x)
}
";
let unoptimized = compile(unoptimized).map_err(|e| e.prettify(unoptimized))?;
let optimized = compile(optimized).map_err(|e| e.prettify(optimized))?;
assert_eq!(
unoptimized.circuit.gates.len(),
optimized.circuit.gates.len()
);
Ok(())
}

// Run the following test using `cargo test plot --features=plot --release -- --nocapture`

#[test]
Expand Down

0 comments on commit 4823a76

Please sign in to comment.