diff --git a/src/compile.rs b/src/compile.rs index 1b27684..1c4af3d 100644 --- a/src/compile.rs +++ b/src/compile.rs @@ -7,8 +7,8 @@ use std::{ use crate::{ ast::{ - ConstExpr, ConstExprEnum, EnumDef, ExprEnum, Op, Pattern, PatternEnum, StmtEnum, StructDef, - Type, UnaryOp, VariantExprEnum, + ConstExpr, ConstExprEnum, EnumDef, Expr, ExprEnum, Op, Pattern, PatternEnum, StmtEnum, + StructDef, Type, UnaryOp, VariantExprEnum, }, circuit::{Circuit, CircuitBuilder, GateIndex, PanicReason, USIZE_BITS}, env::Env, @@ -782,6 +782,49 @@ impl TypedExpr { bits_unshifted } ExprEnum::Op(op, x, y) => { + if let Op::Mul = op { + for (x, y) in [(x, y), (y, x)] { + let (n, bits, is_neg) = match x.inner { + ExprEnum::NumUnsigned(n, size) => ( + n, + Type::Unsigned(size) + .size_in_bits_for_defs(prg, circuit.const_sizes()) + as u64, + false, + ), + ExprEnum::NumSigned(n, size) => ( + n.abs() as u64, + Type::Signed(size).size_in_bits_for_defs(prg, circuit.const_sizes()) + as u64, + n < 0, + ), + _ => continue, + }; + if n == 0 { + continue; + } + if n < bits { + let mut expr = y.clone(); + for _ in 0..n - 1 { + expr = Box::new(Expr { + inner: ExprEnum::Op(Op::Add, expr, y.clone()), + meta, + ty: ty.clone(), + }); + } + if is_neg { + return Expr { + inner: ExprEnum::UnaryOp(UnaryOp::Neg, expr), + meta, + ty: ty.clone(), + } + .compile(prg, env, circuit); + } else { + return expr.compile(prg, env, circuit); + } + } + } + } let ty_x = &x.ty; let ty_y = &y.ty; let mut x = x.compile(prg, env, circuit); diff --git a/tests/circuit.rs b/tests/circuit.rs index 59cb4ef..c2ffbbb 100644 --- a/tests/circuit.rs +++ b/tests/circuit.rs @@ -232,6 +232,48 @@ pub fn main(arr1: [(u16, u16, u32); 8]) -> [((u16, u16), u32); 8] { Ok(()) } +#[test] +fn optimize_constant_mul() -> Result<(), String> { + let unoptimized = " +pub fn main(x: i32) -> i32 { + 2 * x +} +"; + let optimized = " +pub fn main(x: i32) -> i32 { + x + x +} +"; + let unoptimized = compile(unoptimized).map_err(|e| e.prettify(unoptimized))?; + let optimized = compile(optimized).map_err(|e| e.prettify(optimized))?; + assert_eq!( + unoptimized.circuit.gates.len(), + optimized.circuit.gates.len() + ); + Ok(()) +} + +#[test] +fn optimize_constant_mul_signed() -> Result<(), String> { + let unoptimized = " +pub fn main(x: i32) -> i32 { + -2 * x +} +"; + let optimized = " +pub fn main(x: i32) -> i32 { + -(x + x) +} +"; + let unoptimized = compile(unoptimized).map_err(|e| e.prettify(unoptimized))?; + let optimized = compile(optimized).map_err(|e| e.prettify(optimized))?; + assert_eq!( + unoptimized.circuit.gates.len(), + optimized.circuit.gates.len() + ); + Ok(()) +} + // Run the following test using `cargo test plot --features=plot --release -- --nocapture` #[test]