diff --git a/examples/fixture/asm/kimchi/generic_builtin_bits.asm b/examples/fixture/asm/kimchi/generic_builtin_bits.asm index 671fcfc5e..a447ae7b1 100644 --- a/examples/fixture/asm/kimchi/generic_builtin_bits.asm +++ b/examples/fixture/asm/kimchi/generic_builtin_bits.asm @@ -2,19 +2,34 @@ @ public inputs: 1 DoubleGeneric<1> +DoubleGeneric<1,0,-1> +DoubleGeneric<1,1> +DoubleGeneric<1,0,-1,0,1> DoubleGeneric<1,0,-1,0,-1> DoubleGeneric<0,0,-1,1> DoubleGeneric<1> DoubleGeneric<1,0,-1> +DoubleGeneric<1,1> +DoubleGeneric<1,0,-1,0,1> +DoubleGeneric<1,0,-1> +DoubleGeneric<1,1> +DoubleGeneric<1,0,-1,0,1> DoubleGeneric<1,0,-1,0,-1> DoubleGeneric<0,0,-1,1> DoubleGeneric<1> DoubleGeneric<2,0,-1> +DoubleGeneric<1,1> +DoubleGeneric<1,0,-1,0,1> DoubleGeneric<1,1,-1> +DoubleGeneric<1,0,-1> +DoubleGeneric<1,1> +DoubleGeneric<1,0,-1,0,1> DoubleGeneric<1,0,-1,0,-1> DoubleGeneric<0,0,-1,1> DoubleGeneric<1> DoubleGeneric<4,0,-1> +DoubleGeneric<1,1> +DoubleGeneric<1,0,-1,0,1> DoubleGeneric<1,1,-1> DoubleGeneric<1,-1> DoubleGeneric<1,1> @@ -25,47 +40,51 @@ DoubleGeneric<1,1> DoubleGeneric<1,0,-1,0,1> DoubleGeneric<1,0,0,0,-1> DoubleGeneric<1,0,-1> +DoubleGeneric<1,1> +DoubleGeneric<1,0,-1,0,1> DoubleGeneric<2,0,-1> +DoubleGeneric<1,1> +DoubleGeneric<1,0,-1,0,1> DoubleGeneric<1,1,-1> DoubleGeneric<4,0,-1> +DoubleGeneric<1,1> +DoubleGeneric<1,0,-1,0,1> DoubleGeneric<1,1,-1> DoubleGeneric<1,-1> -DoubleGeneric<1> -DoubleGeneric<1,0,0,0,-1> -DoubleGeneric<1,0,-1> -DoubleGeneric<2,0,-1> -DoubleGeneric<1,1,-1> -DoubleGeneric<4,0,-1> -DoubleGeneric<1,1,-1> -DoubleGeneric<1,-1> -(0,0) -> (15,0) -> (28,1) -> (36,1) -(1,0) -> (2,0) -> (4,0) -> (16,0) -> (23,0) -(1,2) -> (2,1) -(2,2) -> (3,0) -(4,2) -> (9,0) -(5,0) -> (6,0) -> (8,0) -> (19,0) -> (24,0) -(5,2) -> (6,1) -(6,2) -> (7,0) -(8,2) -> (9,1) -(9,2) -> (14,0) -(10,0) -> (11,0) -> (13,0) -> (20,0) -> (26,0) -(10,2) -> (11,1) -(11,2) -> (12,0) +DoubleGeneric<1,0,0,0,-2> +(0,0) -> (30,1) -> (49,1) -> (50,0) +(1,0) -> (2,0) -> (7,0) -> (8,0) -> 
(31,0) -> (38,0) -> (39,0) +(1,2) -> (4,0) -> (5,0) +(2,1) -> (3,0) +(4,2) -> (5,1) +(5,2) -> (6,0) +(7,2) -> (19,0) +(8,1) -> (9,0) +(10,0) -> (11,0) -> (16,0) -> (17,0) -> (34,0) -> (41,0) -> (42,0) +(10,2) -> (13,0) -> (14,0) +(11,1) -> (12,0) (13,2) -> (14,1) -(14,2) -> (15,1) -(16,1) -> (17,0) -(17,2) -> (18,0) -(20,1) -> (21,0) -(21,2) -> (22,0) -(23,2) -> (25,0) -(24,2) -> (25,1) -(25,2) -> (27,0) -(26,2) -> (27,1) -(27,2) -> (28,0) -(29,0) -> (31,0) -> (34,0) -(30,0) -> (32,0) -(31,2) -> (33,0) -(32,2) -> (33,1) -(33,2) -> (35,0) -(34,2) -> (35,1) -(35,2) -> (36,0) +(14,2) -> (15,0) +(16,2) -> (19,1) +(17,1) -> (18,0) +(19,2) -> (29,0) +(20,0) -> (21,0) -> (26,0) -> (27,0) -> (35,0) -> (45,0) -> (46,0) +(20,2) -> (23,0) -> (24,0) +(21,1) -> (22,0) +(23,2) -> (24,1) +(24,2) -> (25,0) +(26,2) -> (29,1) +(27,1) -> (28,0) +(29,2) -> (30,0) +(31,1) -> (32,0) +(32,2) -> (33,0) +(35,1) -> (36,0) +(36,2) -> (37,0) +(38,2) -> (44,0) +(39,1) -> (40,0) +(41,2) -> (44,1) +(42,1) -> (43,0) +(44,2) -> (48,0) +(45,2) -> (48,1) +(46,1) -> (47,0) +(48,2) -> (49,0) diff --git a/examples/fixture/asm/r1cs/generic_builtin_bits.asm b/examples/fixture/asm/r1cs/generic_builtin_bits.asm index c216dc416..43d7a4f02 100644 --- a/examples/fixture/asm/r1cs/generic_builtin_bits.asm +++ b/examples/fixture/asm/r1cs/generic_builtin_bits.asm @@ -7,12 +7,9 @@ v_5 == (v_4) * (v_4 + -1) 0 == (v_5) * (1) v_7 == (v_6) * (v_6 + -1) 0 == (v_7) * (1) -v_2 + 2 * v_4 + 4 * v_6 == (v_1) * (1) +v_1 == (v_2 + 2 * v_4 + 4 * v_6) * (1) 1 == (-1 * v_2 + 1) * (1) 1 == (v_4) * (1) 1 == (-1 * v_6 + 1) * (1) v_1 == (v_2 + 2 * v_4 + 4 * v_6) * (1) -0 == (v_8) * (1) -1 == (v_9) * (1) -0 == (v_10) * (1) -v_1 == (v_8 + 2 * v_9 + 4 * v_10) * (1) +2 == (v_1) * (1) diff --git a/examples/generic_array_access.no b/examples/generic_array_access.no index 2935d0cd9..a43458317 100644 --- a/examples/generic_array_access.no +++ b/examples/generic_array_access.no @@ -1,5 +1,6 @@ fn last(arr: [Field; LEN]) -> Field { - return 
arr[LEN - 1]; + let last = LEN - 1; + return arr[last]; } fn main(pub xx: Field) { diff --git a/examples/generic_builtin_bits.no b/examples/generic_builtin_bits.no index db4830d07..62c412c08 100644 --- a/examples/generic_builtin_bits.no +++ b/examples/generic_builtin_bits.no @@ -1,22 +1,111 @@ +// circom versions: +// template Num2Bits(n) { +// signal input in; +// signal output out[n]; +// var lc1=0; + +// var e2=1; +// for (var i = 0; i> i) & 1; +// out[i] * (out[i] -1 ) === 0; +// lc1 += out[i] * e2; +// e2 = e2+e2; +// } + +// lc1 === in; +// } + +// template Bits2Num(n) { +// signal input in[n]; +// signal output out; +// var lc1=0; + +// var e2 = 1; +// for (var i = 0; i out; +// } + use std::bits; -// 010 = xx, where xx = 2 +// obviously writing this in native is much simpler than the builtin version +fn to_bits(const LEN: Field, value: Field) -> [Bool; LEN] { + let mut bits = [false; LEN]; + let mut lc1 = 0; + let mut e2 = 1; + + let one = 1; + let zero = 0; + + for index in 0..LEN { + // maybe add a unconstrained / unsafe attribute before bits::nth_bit, such that: + // bits[index] = unsafe bits::nth_bit(value, index); + // here we need to ensure the related variables are constrained: + // 1. value: constrained to be equal with the sum of bits, which involves the index as well + // 2. index: a cell index in bits + // 3. bits: the output bits + // beyond the notation purpose, what security measures can we take to help guide this unsafe operation? + // one idea is to rely on this unsafe attribute to check if it is non-deterministic when constraining the bits[index] + // eg. + // - bits::nth_bit(value, index) is non-deterministic + // - a metadata can be added to the var of the bits as non-deterministic + // - when CS trying to constrain the non-deterministic var, + // it will raise an error if the var is not marked unsafe via the attribute unsafe + // thus, it seems we also need to add the attribute to the builtin function signature + // eg. 
`unsafe nth_bit(val: Field, const nth: Field) -> Bool` + // while the unsafe attribute in `bits[index] = unsafe bits::nth_bit(value, index);` + // is for the users to acknowledge they are responsible for having additional constraints. + // This approach makes it explicit whether an expression is non-deterministic at the first place. + // On the other hand, circom lang determines whether it is non-deterministic by folding the arithmetic operation. + + bits[index] = bits::nth_bit(value, index); + // nth_bit is a hint function, and it doesn't constraint the value of the bits as boolean, + // although its return type is boolean. + // can we make the arithmetic operation compatible with boolean? + // or just make a stdlib to convert boolean to Field while adding the constraint? + let bit_num = if bits[index] {one} else {zero}; + assert_eq(bit_num * (bit_num - 1), 0); + + lc1 = lc1 + if bits[index] {e2} else {zero}; + e2 = e2 + e2; + } + assert_eq(lc1, value); + return bits; +} + +fn from_bits(bits: [Bool; LEN]) -> Field { + let mut lc1 = 0; + let mut e2 = 1; + let zero = 0; + + for index in 0..LEN { + lc1 = lc1 + if bits[index] {e2} else {zero}; + e2 = e2 + e2; + } + return lc1; +} + fn main(pub xx: Field) { - // var - let bits = bits::to_bits(3, xx); + // calculate on a cell var + let bits = to_bits(3, xx); assert(!bits[0]); assert(bits[1]); assert(!bits[2]); - let val = bits::from_bits(bits); + let val = from_bits(bits); assert_eq(val, xx); - // constant - let cst_bits = bits::to_bits(3, 2); + // calculate on a constant + let cst_bits = to_bits(3, 2); assert(!cst_bits[0]); assert(cst_bits[1]); assert(!cst_bits[2]); - let cst = bits::from_bits(cst_bits); + let cst = from_bits(cst_bits); assert_eq(cst, xx); } + +// ^ the asm diffs originated from the fact that the builtin version stored constant as cell vars. 
\ No newline at end of file diff --git a/src/backends/mod.rs b/src/backends/mod.rs index fabfa2b83..fe411ae90 100644 --- a/src/backends/mod.rs +++ b/src/backends/mod.rs @@ -196,6 +196,82 @@ pub trait Backend: Clone { Self::Field::zero() }; + env.cached_values.insert(cache_key, res); + Ok(res) + } + Value::LeftShift(var, shift) => { + let var = self.compute_var(env, var)?; + let shifted = var.to_biguint() << *shift; + let res = Self::Field::from(shifted); + Ok(res) + } + Value::VarDivVar(dividend, divisor) => { + let dividend = self.compute_var(env, dividend)?; + let divisor = self.compute_var(env, divisor)?; + let res = dividend / divisor; + + env.cached_values.insert(cache_key, res); + Ok(res) + } + Value::CstDivVar(dividend, divisor) => { + let divisor = self.compute_var(env, divisor)?; + let res = *dividend / divisor; + Ok(res) + } + Value::VarDivCst(dividend, divisor) => { + let dividend = self.compute_var(env, dividend)?; + // convert to bigint + let dividend = dividend.to_biguint(); + let divisor = divisor.to_biguint(); + + let res = dividend / divisor; + let res = Self::Field::from(res); + env.cached_values.insert(cache_key, res); + Ok(res) + } + Value::CstDivCst(dividend, divisor) => { + let res = *dividend / *divisor; + env.cached_values.insert(cache_key, res); + Ok(res) + } + Value::VarModVar(dividend, divisor) => { + let dividend = self.compute_var(env, dividend)?; + let divisor = self.compute_var(env, divisor)?; + // convert to bigint + let dividend = dividend.to_biguint(); + let divisor = divisor.to_biguint(); + let res = dividend % divisor; + + let res = Self::Field::from(res); + env.cached_values.insert(cache_key, res); + Ok(res) + } + Value::CstModVar(dividend, divisor) => { + let divisor = self.compute_var(env, divisor)?; + // convert to bigint + let dividend = dividend.to_biguint(); + let divisor = divisor.to_biguint(); + let res = dividend % divisor; + let res = Self::Field::from(res); + env.cached_values.insert(cache_key, res); + Ok(res) + } + 
Value::VarModCst(dividend, divisor) => { + let dividend = self.compute_var(env, dividend)?; + // convert to bigint + let dividend = dividend.to_biguint(); + let divisor = divisor.to_biguint(); + let res = dividend % divisor; + let res = Self::Field::from(res); + env.cached_values.insert(cache_key, res); + Ok(res) + } + Value::CstModCst(dividend, divisor) => { + let dividend = dividend.to_biguint(); + let divisor = divisor.to_biguint(); + let res = dividend % divisor; + let res = Self::Field::from(res); + env.cached_values.insert(cache_key, res); Ok(res) } } diff --git a/src/backends/r1cs/mod.rs b/src/backends/r1cs/mod.rs index 17f7fdac2..facac9bd1 100644 --- a/src/backends/r1cs/mod.rs +++ b/src/backends/r1cs/mod.rs @@ -407,7 +407,10 @@ where ); return Err(err); } else { - panic!("there's a bug in the circuit_writer, some cellvar does not end up being a cellvar in the circuit!"); + // todo: lift this restriction for now + // this is due to the intermediate cell vars introduced by the hint builtins. + // in order to get computed in symbolic value, it needs to introduce intermediate cell vars and pass them around even it is hint builtins. 
+ // panic!("there's a bug in the circuit_writer, some cellvar does not end up being a cellvar in the circuit!"); } } } diff --git a/src/error.rs b/src/error.rs index 9c53036f5..9670830e8 100644 --- a/src/error.rs +++ b/src/error.rs @@ -130,7 +130,7 @@ pub enum ErrorKind { #[error("invalid array size, expected [_; x] with x in [0,2^32]")] InvalidArraySize, - #[error("only allow a single generic parameter for the size of an array argument")] + #[error("Invalid expression in symbolic size")] InvalidSymbolicSize, #[error("invalid generic parameter, expected single uppercase letter, such as N, M, etc.")] @@ -148,6 +148,9 @@ pub enum ErrorKind { #[error("calling generic functions in for loop is not allowed")] GenericInForLoop, + #[error("variable `{0}` is forbidden to be accessed in the for loop, likely due to a generic function call")] + VarAccessForbiddenInForLoop(String), + #[error("the value passed could not be converted to a field element")] InvalidField(String), @@ -349,4 +352,7 @@ pub enum ErrorKind { #[error("invalid range, the end value can't be smaller than the start value")] InvalidRange, + + #[error("division by zero")] + DivisionByZero, } diff --git a/src/mast/mod.rs b/src/mast/mod.rs index 60039c45d..b4ba259a0 100644 --- a/src/mast/mod.rs +++ b/src/mast/mod.rs @@ -27,10 +27,21 @@ pub struct ExprMonoInfo { /// The generic types shouldn't be presented in this field. pub typ: Option, - // todo: see if we can do constant folding on the expression nodes. - // - it is possible to remove this field, as the constant value can be extracted from folded expression node - /// Numeric value of the expression - /// applicable to BigInt type + /// Propagated constant value + /// - For BigUInt expression type, this corresponds to the same inner value. + /// - For constant field type, this corresponds to the propagated constant value. 
+ /// The reason why we can't just fold all the constants to BigUInt expression node is because + /// there are cases where folding can't be done even though the expression node is a constant. + /// For example: + /// ``` + /// let mut var = 1; + /// for idx in 0..10 { + /// var = var + var; + /// } + /// ``` + /// If the `var` is folded to a BigUInt expression node, + /// it won't represent the expression node with the same intention at the synthesizer phase, + /// as it loses the recursive nature of the expression node. pub constant: Option, } @@ -152,10 +163,10 @@ where { tast: TypeChecker, generic_func_scope: Option, - // fully qualified function name - functions_to_delete: Vec, - // fully qualified struct name, method name - methods_to_delete: Vec<(FullyQualified, String)>, + // new fully qualified function name as the key, old fully qualified function name as the value + functions_instantiated: HashMap, + // new method name as the key, old method name as the value + methods_instantiated: HashMap<(FullyQualified, String), String>, } impl MastCtx { @@ -163,8 +174,8 @@ impl MastCtx { Self { tast, generic_func_scope: Some(0), - functions_to_delete: vec![], - methods_to_delete: vec![], + functions_instantiated: HashMap::new(), + methods_instantiated: HashMap::new(), } } @@ -190,9 +201,8 @@ impl MastCtx { ) { self.tast .add_monomorphized_fn(new_qualified.clone(), fn_info); - if new_qualified != old_qualified { - self.functions_to_delete.push(old_qualified); - } + self.functions_instantiated + .insert(new_qualified, old_qualified); } pub fn add_monomorphized_method( @@ -205,20 +215,25 @@ impl MastCtx { self.tast .add_monomorphized_method(struct_qualified.clone(), method_name, fn_info); - if method_name != old_method_name { - self.methods_to_delete - .push((struct_qualified, old_method_name.to_string())); - } + self.methods_instantiated.insert( + (struct_qualified, method_name.to_string()), + old_method_name.to_string(), + ); } pub fn clear_generic_fns(&mut self) 
{ - for qualified in &self.functions_to_delete { - self.tast.remove_fn(qualified); + for (new, old) in &self.functions_instantiated { + // don't remove the instantiated function with no generic arguments + if new != old { + self.tast.remove_fn(old); + } } - self.functions_to_delete.clear(); - for (struct_qualified, method_name) in &self.methods_to_delete { - self.tast.remove_method(struct_qualified, method_name); + for ((struct_qualified, new), old) in &self.methods_instantiated { + // don't remove the instantiated method with no generic arguments + if new != old { + self.tast.remove_method(struct_qualified, old); + } } } } @@ -238,6 +253,7 @@ impl Symbolic { } Symbolic::Generic(g) => mono_fn_env.get_type_info(&g.value).unwrap().value.unwrap(), Symbolic::Add(a, b) => a.eval(mono_fn_env, tast) + b.eval(mono_fn_env, tast), + Symbolic::Sub(a, b) => a.eval(mono_fn_env, tast) - b.eval(mono_fn_env, tast), Symbolic::Mul(a, b) => a.eval(mono_fn_env, tast) * b.eval(mono_fn_env, tast), } } @@ -422,21 +438,35 @@ fn monomorphize_expr( let args_mono = observed.clone().into_iter().map(|e| e.expr).collect(); - let fn_name_mono = &fn_info_mono.sig().name; - let mexpr = Expr { - kind: ExprKind::FnCall { - module: module.clone(), - fn_name: fn_name_mono.clone(), - args: args_mono, - }, - ..expr.clone() - }; + // check if this function is already monomorphized + if ctx.functions_instantiated.contains_key(&old_qualified) { + // todo: cache the propagated constant from instantiated function, + // so it doesn't need to re-instantiate the function + let mexpr = Expr { + kind: ExprKind::FnCall { + module: module.clone(), + fn_name: fn_name.clone(), + args: args_mono, + }, + ..expr.clone() + }; + ExprMonoInfo::new(mexpr, typ, None) + } else { + let fn_name_mono = &fn_info_mono.sig().name; + let mexpr = Expr { + kind: ExprKind::FnCall { + module: module.clone(), + fn_name: fn_name_mono.clone(), + args: args_mono, + }, + ..expr.clone() + }; - let qualified = FullyQualified::new(module, 
&fn_name_mono.value); - ctx.add_monomorphized_fn(old_qualified, qualified, fn_info_mono); + let qualified = FullyQualified::new(module, &fn_name_mono.value); + ctx.add_monomorphized_fn(old_qualified, qualified, fn_info_mono); - // assume the function call won't return constant value - ExprMonoInfo::new(mexpr, typ, None) + ExprMonoInfo::new(mexpr, typ, None) + } } // `lhs.method_name(args)` @@ -493,22 +523,39 @@ fn monomorphize_expr( // monomorphize the function call let (fn_info_mono, typ) = instantiate_fn_call(ctx, fn_info, &observed, expr.span)?; - let fn_name_mono = &fn_info_mono.sig().name; - let mexpr = Expr { - kind: ExprKind::MethodCall { - lhs: Box::new(lhs_mono.expr), - method_name: fn_name_mono.clone(), - args: args_mono, - }, - ..expr.clone() - }; + // check if this function is already monomorphized + if ctx + .methods_instantiated + .contains_key(&(struct_qualified.clone(), method_name.value.clone())) + { + // todo: cache the propagated constant from instantiated method, + // so it doesn't need to re-instantiate the function + let mexpr = Expr { + kind: ExprKind::MethodCall { + lhs: Box::new(lhs_mono.expr), + method_name: method_name.clone(), + args: args_mono, + }, + ..expr.clone() + }; + ExprMonoInfo::new(mexpr, typ, None) + } else { + let fn_name_mono = &fn_info_mono.sig().name; + let mexpr = Expr { + kind: ExprKind::MethodCall { + lhs: Box::new(lhs_mono.expr), + method_name: fn_name_mono.clone(), + args: args_mono, + }, + ..expr.clone() + }; - let fn_def = fn_info_mono.native(); - ctx.tast - .add_monomorphized_method(struct_qualified, &fn_name_mono.value, fn_def); + let fn_def = fn_info_mono.native(); + ctx.tast + .add_monomorphized_method(struct_qualified, &fn_name_mono.value, fn_def); - // assume the function call won't return constant value - ExprMonoInfo::new(mexpr, typ, None) + ExprMonoInfo::new(mexpr, typ, None) + } } ExprKind::Assignment { lhs, rhs } => { @@ -556,31 +603,70 @@ fn monomorphize_expr( | Op2::Multiplication | Op2::Division | 
Op2::BoolAnd - | Op2::BoolOr => lhs_mono.typ, + | Op2::BoolOr => lhs_mono.clone().typ, }; let ExprMonoInfo { expr: lhs_expr, .. } = lhs_mono; let ExprMonoInfo { expr: rhs_expr, .. } = rhs_mono; - // fold constants - let cst = match (&lhs_expr.kind, &rhs_expr.kind) { - (ExprKind::BigUInt(lhs), ExprKind::BigUInt(rhs)) => match op { - Op2::Addition => Some(lhs + rhs), - Op2::Subtraction => Some(lhs - rhs), - Op2::Multiplication => Some(lhs * rhs), - Op2::Division => Some(lhs / rhs), - _ => None, - }, - _ => None, - }; + match (&lhs_expr.kind, &rhs_expr.kind) { + // fold constants + (ExprKind::BigUInt(lhs), ExprKind::BigUInt(rhs)) => { + let cst = match op { + Op2::Addition => Some(lhs + rhs), + Op2::Subtraction => Some(lhs - rhs), + Op2::Multiplication => Some(lhs * rhs), + Op2::Division => Some(lhs / rhs), + _ => None, + }; + + match cst { + // fold only if the operation is supported + Some(cst) => { + let mexpr = expr.to_mast(ctx, &ExprKind::BigUInt(cst.clone())); + ExprMonoInfo::new(mexpr, typ, cst.to_u32()) + } + None => { + let mexpr = expr.to_mast( + ctx, + &ExprKind::BinaryOp { + op: op.clone(), + protected: *protected, + lhs: Box::new(lhs_expr), + rhs: Box::new(rhs_expr), + }, + ); + + ExprMonoInfo::new(mexpr, typ, None) + } + } + } + // not folding, but propagate the updated constant value + (_, _) if lhs_mono.constant.is_some() && rhs_mono.constant.is_some() => { + let lhs = lhs_mono.constant.unwrap(); + let rhs = rhs_mono.constant.unwrap(); + let cst = match op { + Op2::Addition => Some(lhs + rhs), + Op2::Subtraction => Some(lhs - rhs), + Op2::Multiplication => Some(lhs * rhs), + Op2::Division => Some(lhs / rhs), + _ => None, + }; - match cst { - Some(v) => { - let mexpr = expr.to_mast(ctx, &ExprKind::BigUInt(v.clone())); + let mexpr = expr.to_mast( + ctx, + &ExprKind::BinaryOp { + op: op.clone(), + protected: *protected, + lhs: Box::new(lhs_expr), + rhs: Box::new(rhs_expr), + }, + ); - ExprMonoInfo::new(mexpr, typ, v.to_u32()) + ExprMonoInfo::new(mexpr, 
typ, cst) } - None => { + // keep as is + _ => { let mexpr = expr.to_mast( ctx, &ExprKind::BinaryOp { diff --git a/src/negative_tests.rs b/src/negative_tests.rs index 9c7013d4d..eab751309 100644 --- a/src/negative_tests.rs +++ b/src/negative_tests.rs @@ -105,7 +105,11 @@ fn test_generic_const_for_loop() { "#; let res = tast_pass(code).0; - assert!(matches!(res.unwrap_err().kind, ErrorKind::GenericInForLoop)); + + assert!(matches!( + res.unwrap_err().kind, + ErrorKind::VarAccessForbiddenInForLoop(..) + )); } #[test] @@ -125,7 +129,10 @@ fn test_generic_array_for_loop() { "#; let res = tast_pass(code).0; - assert!(matches!(res.unwrap_err().kind, ErrorKind::GenericInForLoop)); + assert!(matches!( + res.unwrap_err().kind, + ErrorKind::VarAccessForbiddenInForLoop(..) + )); } #[test] diff --git a/src/parser/expr.rs b/src/parser/expr.rs index 9fed71992..c3db48638 100644 --- a/src/parser/expr.rs +++ b/src/parser/expr.rs @@ -529,9 +529,11 @@ impl Expr { // sanity check if !matches!( self.kind, - ExprKind::Variable { .. } | ExprKind::FieldAccess { .. } + ExprKind::Variable { .. } + | ExprKind::FieldAccess { .. } + | ExprKind::ArrayAccess { .. 
} ) { - panic!("an array access can only follow a variable"); + panic!("an array access can only follow a variable or another array access"); } // array[idx] diff --git a/src/parser/types.rs b/src/parser/types.rs index 9c60dc65d..fce537fe3 100644 --- a/src/parser/types.rs +++ b/src/parser/types.rs @@ -186,7 +186,9 @@ pub enum Symbolic { Constant(Ident), /// Generic parameter Generic(Ident), + /// Binary operation with protected flag Add(Box, Box), + Sub(Box, Box), Mul(Box, Box), } @@ -197,6 +199,7 @@ impl Display for Symbolic { Symbolic::Constant(ident) => write!(f, "{}", ident.value), Symbolic::Generic(ident) => write!(f, "{}", ident.value), Symbolic::Add(lhs, rhs) => write!(f, "{} + {}", lhs, rhs), + Symbolic::Sub(lhs, rhs) => write!(f, "{} - {}", lhs, rhs), Symbolic::Mul(lhs, rhs) => write!(f, "{} * {}", lhs, rhs), } } @@ -222,6 +225,10 @@ impl Symbolic { generics.extend(lhs.extract_generics()); generics.extend(rhs.extract_generics()); } + Symbolic::Sub(lhs, rhs) => { + generics.extend(lhs.extract_generics()); + generics.extend(rhs.extract_generics()); + } } generics @@ -250,6 +257,7 @@ impl Symbolic { // no protected flags are needed, as this is based on expression nodes which already ordered the operations match op { Op2::Addition => Ok(Symbolic::Add(Box::new(lhs), Box::new(rhs?))), + Op2::Subtraction => Ok(Symbolic::Sub(Box::new(lhs), Box::new(rhs?))), Op2::Multiplication => Ok(Symbolic::Mul(Box::new(lhs), Box::new(rhs?))), _ => Err(Error::new( "mast", @@ -520,30 +528,8 @@ impl FnSig { // extract generic parameters from arguments let mut generics = GenericParameters::default(); for arg in &arguments { - match &arg.typ.kind { - TyKind::Field { .. 
} => { - // extract from const argument - if is_generic_parameter(&arg.name.value) && arg.is_constant() { - generics.add(arg.name.value.to_string()); - } - } - TyKind::Array(ty, _) => { - // recursively extract all generic parameters from the item type - let extracted = ty.extract_generics(); - - for name in extracted { - generics.add(name); - } - } - TyKind::GenericSizedArray(_, _) => { - // recursively extract all generic parameters from the symbolic size - let extracted = arg.typ.kind.extract_generics(); - - for name in extracted { - generics.add(name); - } - } - _ => (), + for name in arg.extract_generic_names() { + generics.add(name); } } @@ -863,6 +849,38 @@ impl FnArg { .map(|attr| attr.is_constant()) .unwrap_or(false) } + + pub fn extract_generic_names(&self) -> HashSet { + let mut generics = HashSet::new(); + + match &self.typ.kind { + TyKind::Field { .. } => { + // extract from const argument + if is_generic_parameter(&self.name.value) && self.is_constant() { + generics.insert(self.name.value.to_string()); + } + } + TyKind::Array(ty, _) => { + // recursively extract all generic parameters from the item type + let extracted = ty.extract_generics(); + + for name in extracted { + generics.insert(name); + } + } + TyKind::GenericSizedArray(_, _) => { + // recursively extract all generic parameters from the symbolic size + let extracted = self.typ.kind.extract_generics(); + + for name in extracted { + generics.insert(name); + } + } + _ => (), + } + + generics + } } impl FuncOrMethod { diff --git a/src/stdlib/bits.rs b/src/stdlib/bits.rs index 3fe256ee2..35712f05a 100644 --- a/src/stdlib/bits.rs +++ b/src/stdlib/bits.rs @@ -1,7 +1,7 @@ use std::vec; -use ark_ff::One; -use kimchi::o1_utils::FieldHelpers; +use ark_ff::{BigInteger, One}; +use kimchi::{o1_utils::FieldHelpers, turshi::helper::CairoFieldHelpers}; use crate::{ backends::Backend, @@ -9,7 +9,7 @@ use crate::{ constants::Span, constraints::boolean, error::Result, - parser::types::GenericParameters, + 
parser::types::{GenericParameters, TyKind}, var::{ConstOrCell, Value, Var}, }; @@ -18,16 +18,115 @@ use super::{FnInfoType, Module}; const TO_BITS_FN: &str = "to_bits(const LEN: Field, val: Field) -> [Bool; LEN]"; const FROM_BITS_FN: &str = "from_bits(bits: [Bool; LEN]) -> Field"; +const NTH_BIT_FN: &str = "nth_bit(val: Field, const nth: Field) -> Bool"; +const LEFT_SHIFT_FN: &str = "left_shift(val: Field, const shift: Field) -> Field"; + +const BIT_LEN_FN: &str = "bit_len(const value: Field) -> Field"; + pub struct BitsLib {} impl Module for BitsLib { const MODULE: &'static str = "bits"; fn get_fns() -> Vec<(&'static str, FnInfoType)> { - vec![(TO_BITS_FN, to_bits), (FROM_BITS_FN, from_bits)] + vec![ + (TO_BITS_FN, to_bits), + (FROM_BITS_FN, from_bits), + (NTH_BIT_FN, nth_bit), + (LEFT_SHIFT_FN, left_shift), + (BIT_LEN_FN, bit_len), + ] } } +fn nth_bit( + compiler: &mut CircuitWriter, + _generics: &GenericParameters, + vars: &[VarInfo], + span: Span, +) -> Result>> { + // should be two input vars + assert_eq!(vars.len(), 2); + + // these should be type checked already, unless it is called by other low level functions + // eg. builtins + let var_info = &vars[0]; + let val = &var_info.var; + assert_eq!(val.len(), 1); + + let var_info = &vars[1]; + let nth = &var_info.var; + assert_eq!(nth.len(), 1); + + let nth: usize = match &nth[0] { + ConstOrCell::Cell(_) => unreachable!("nth should be a constant"), + ConstOrCell::Const(cst) => cst.to_u64() as usize, + }; + + let val = match &val[0] { + ConstOrCell::Cell(cvar) => cvar.clone(), + ConstOrCell::Const(cst) => { + // directly return the nth bit without adding symbolic value as it doesn't depend on a cell var + let bit = cst.to_bits(); + return Ok(Some(Var::new_cvar( + ConstOrCell::Const(B::Field::from(bit[nth])), + span, + ))); + } + }; + + // create a cell var for the symbolic value representing the nth bit. 
+ // it seems we will always have to create cell vars to allocate the symbolic values that involve non-deterministic calculations. + // it is non-deterministic because it involves non-deterministic arithmetic on a cell var. + let bit = compiler + .backend + .new_internal_var(Value::NthBit(val.clone(), nth), span); + + Ok(Some(Var::new(vec![ConstOrCell::Cell(bit)], span))) +} + +fn left_shift( + compiler: &mut CircuitWriter, + _generics: &GenericParameters, + vars: &[VarInfo], + span: Span, +) -> Result>> { + // should be two input vars + assert_eq!(vars.len(), 2); + + let var_info = &vars[0]; + let val = &var_info.var; + assert_eq!(val.len(), 1); + + let var_info = &vars[1]; + let shift = &var_info.var; + assert_eq!(shift.len(), 1); + + let shift: usize = match &shift[0] { + ConstOrCell::Cell(_) => unreachable!("shift should be a constant"), + ConstOrCell::Const(cst) => cst.to_u64() as usize, + }; + + let val = match &val[0] { + ConstOrCell::Cell(cvar) => cvar.clone(), + ConstOrCell::Const(cst) => { + // directly return the shifted constant without adding symbolic value as it doesn't depend on a cell var + let shifted = cst.to_biguint() << shift; + return Ok(Some(Var::new_cvar( + ConstOrCell::Const(B::Field::from(shifted)), + span, + ))); + } + }; + + // create a cell var for the symbolic value that depends on another cell var. + let shifted = compiler + .backend + .new_internal_var(Value::LeftShift(val.clone(), shift), span); + + Ok(Some(Var::new(vec![ConstOrCell::Cell(shifted)], span))) +} + fn to_bits( compiler: &mut CircuitWriter, generics: &GenericParameters, @@ -164,3 +263,41 @@ fn from_bits( Ok(Some(Var::new_cvar(cvar, span))) } + +/// Unconstrained bit length of a constant value (cell vars are not supported yet). +fn bit_len( + compiler: &mut CircuitWriter, + generics: &GenericParameters, + vars: &[VarInfo], + span: Span, +) -> Result>> { + // we get a single var + assert_eq!(vars.len(), 1); + + // of type field + let var_info = &vars[0]; + if !matches!(var_info.typ, Some(TyKind::Field { .. 
})) { + panic!( + "the var of log_ceil must be of type Field. It was of type {:?}", + var_info.typ + ); + } + + // of only one field element + let var = &var_info.var; + assert_eq!(var.len(), 1); + let val = &var[0]; + + match val { + ConstOrCell::Const(cst) => { + let res = cst.to_biguint().bits(); + Ok(Some(Var::new( + vec![ConstOrCell::Const(B::Field::from(res))], + span, + ))) + } + _ => { + todo!() + } + } +} diff --git a/src/stdlib/builtins.rs b/src/stdlib/builtins.rs index 2ece6c5de..e110ec1a1 100644 --- a/src/stdlib/builtins.rs +++ b/src/stdlib/builtins.rs @@ -1,6 +1,8 @@ //! Builtins are imported by default. use ark_ff::One; +use kimchi::o1_utils::FieldHelpers; +use num_traits::{ToPrimitive, Zero}; use crate::{ backends::Backend, @@ -9,19 +11,22 @@ use crate::{ error::{Error, ErrorKind, Result}, helpers::PrettyField, parser::types::{GenericParameters, TyKind}, - var::{ConstOrCell, Var}, + var::{ConstOrCell, Value, Var}, }; use super::{FnInfoType, Module}; pub const QUALIFIED_BUILTINS: &str = "std/builtins"; -pub const BUILTIN_FN_NAMES: [&str; 3] = ["assert", "assert_eq", "log"]; +pub const BUILTIN_FN_NAMES: [&str; 2] = ["assert", "assert_eq"]; const ASSERT_FN: &str = "assert(condition: Bool)"; const ASSERT_EQ_FN: &str = "assert_eq(lhs: Field, rhs: Field)"; // todo: currently only supports a single field var // to support all the types, we can bypass the type check for this log function for now const LOG_FN: &str = "log(var: Field)"; +const DIV_EQ_FN: &str = "div(lhs: Field, rhs: Field) -> Field"; +const MOD_EQ_FN: &str = "mod(lhs: Field, rhs: Field) -> Field"; +const POW_EQ_FN: &str = "pow(base: Field, exp: Field) -> Field"; pub struct BuiltinsLib {} @@ -33,6 +38,9 @@ impl Module for BuiltinsLib { (ASSERT_FN, assert_fn), (ASSERT_EQ_FN, assert_eq_fn), (LOG_FN, log_fn), + (DIV_EQ_FN, div_fn), + (MOD_EQ_FN, mod_fn), + (POW_EQ_FN, pow_fn), ] } } @@ -52,14 +60,14 @@ fn assert_eq_fn( // they are both of type field if !matches!(lhs_info.typ, Some(TyKind::Field { 
.. })) { panic!( - "the lhs of assert_eq must be of type Field or BigInt. It was of type {:?}", + "the lhs of assert_eq must be of type Field. It was of type {:?}", lhs_info.typ ); } if !matches!(rhs_info.typ, Some(TyKind::Field { .. })) { panic!( - "the rhs of assert_eq must be of type Field or BigInt. It was of type {:?}", + "the rhs of assert_eq must be of type Field. It was of type {:?}", rhs_info.typ ); } @@ -130,6 +138,75 @@ fn assert_fn( Ok(None) } +/// Unconstrained division. +fn div_fn( + compiler: &mut CircuitWriter, + generics: &GenericParameters, + vars: &[VarInfo], + span: Span, +) -> Result>> { + // we get two vars + assert_eq!(vars.len(), 2); + let lhs_info = &vars[0]; + let rhs_info = &vars[1]; + + // they are both of type field + if !matches!(lhs_info.typ, Some(TyKind::Field { .. })) { + panic!( + "the lhs of div must be of type Field. It was of type {:?}", + lhs_info.typ + ); + } + + if !matches!(rhs_info.typ, Some(TyKind::Field { .. })) { + panic!( + "the rhs of div must be of type Field. 
It was of type {:?}", + rhs_info.typ + ); + } + + // retrieve the values + let lhs_var = &lhs_info.var; + assert_eq!(lhs_var.len(), 1); + let lhs_cvar = &lhs_var[0]; + + let rhs_var = &rhs_info.var; + assert_eq!(rhs_var.len(), 1); + let rhs_cvar = &rhs_var[0]; + + match (lhs_cvar, rhs_cvar) { + // two constants + (ConstOrCell::Const(a), ConstOrCell::Const(b)) => { + if b.is_zero() { + return Err(Error::new( + "constraint-generation", + ErrorKind::DivisionByZero, + span, + )); + } + let res = *a / b; + Ok(Some(Var::new(vec![ConstOrCell::Const(res)], span))) + } + + // a const and a var + (ConstOrCell::Const(cst), ConstOrCell::Cell(cvar)) => { + let val = Value::CstDivVar(*cst, cvar.clone()); + let res = compiler.backend.new_internal_var(val, span); + Ok(Some(Var::new(vec![ConstOrCell::Cell(res)], span))) + } + (ConstOrCell::Cell(cvar), ConstOrCell::Const(cst)) => { + let val = Value::VarDivCst(cvar.clone(), *cst); + let res = compiler.backend.new_internal_var(val, span); + Ok(Some(Var::new(vec![ConstOrCell::Cell(res)], span))) + } + (ConstOrCell::Cell(lhs), ConstOrCell::Cell(rhs)) => { + let val = Value::VarDivVar(lhs.clone(), rhs.clone()); + let res = compiler.backend.new_internal_var(val, span); + Ok(Some(Var::new(vec![ConstOrCell::Cell(res)], span))) + } + } +} + /// Logging fn log_fn( compiler: &mut CircuitWriter, @@ -156,3 +233,130 @@ fn log_fn( Ok(None) } + +/// Unconstrained modulo. +fn mod_fn( + compiler: &mut CircuitWriter, + generics: &GenericParameters, + vars: &[VarInfo], + span: Span, +) -> Result>> { + // we get two vars + assert_eq!(vars.len(), 2); + let lhs_info = &vars[0]; + let rhs_info = &vars[1]; + + // they are both of type field + if !matches!(lhs_info.typ, Some(TyKind::Field { .. })) { + panic!( + "the lhs of mod must be of type Field. It was of type {:?}", + lhs_info.typ + ); + } + + if !matches!(rhs_info.typ, Some(TyKind::Field { .. })) { + panic!( + "the rhs of mod must be of type Field. 
It was of type {:?}", + rhs_info.typ + ); + } + + // retrieve the values + let lhs_var = &lhs_info.var; + assert_eq!(lhs_var.len(), 1); + let lhs_cvar = &lhs_var[0]; + + let rhs_var = &rhs_info.var; + assert_eq!(rhs_var.len(), 1); + let rhs_cvar = &rhs_var[0]; + + match (lhs_cvar, rhs_cvar) { + // two constants + (ConstOrCell::Const(a), ConstOrCell::Const(b)) => { + if b.is_zero() { + return Err(Error::new( + "constraint-generation", + ErrorKind::DivisionByZero, + span, + )); + } + // convert to bigint + let a = a.to_biguint(); + let b = b.to_biguint(); + let res = a % b; + Ok(Some(Var::new( + vec![ConstOrCell::Const(B::Field::from(res))], + span, + ))) + } + (ConstOrCell::Const(cst), ConstOrCell::Cell(cvar)) => { + let val = Value::CstModVar(*cst, cvar.clone()); + let res = compiler.backend.new_internal_var(val, span); + Ok(Some(Var::new(vec![ConstOrCell::Cell(res)], span))) + } + (ConstOrCell::Cell(cvar), ConstOrCell::Const(cst)) => { + let val = Value::VarModCst(cvar.clone(), *cst); + let res = compiler.backend.new_internal_var(val, span); + Ok(Some(Var::new(vec![ConstOrCell::Cell(res)], span))) + } + (ConstOrCell::Cell(lhs), ConstOrCell::Cell(rhs)) => { + let val = Value::VarModVar(lhs.clone(), rhs.clone()); + let res = compiler.backend.new_internal_var(val, span); + Ok(Some(Var::new(vec![ConstOrCell::Cell(res)], span))) + } + } +} + +/// Unconstrained exponentiation. +fn pow_fn( + compiler: &mut CircuitWriter, + generics: &GenericParameters, + vars: &[VarInfo], + span: Span, +) -> Result>> { + // we get two vars + assert_eq!(vars.len(), 2); + let base_info = &vars[0]; + let exp_info = &vars[1]; + + // they are both of type field + if !matches!(base_info.typ, Some(TyKind::Field { .. })) { + panic!( + "the base of pow must be of type Field. It was of type {:?}", + base_info.typ + ); + } + + if !matches!(exp_info.typ, Some(TyKind::Field { .. })) { + panic!( + "the exp of pow must be of type Field. 
It was of type {:?}", + exp_info.typ + ); + } + + // retrieve the values + let base_var = &base_info.var; + assert_eq!(base_var.len(), 1); + let base_cvar = &base_var[0]; + + let exp_var = &exp_info.var; + assert_eq!(exp_var.len(), 1); + let exp_cvar = &exp_var[0]; + + match (base_cvar, exp_cvar) { + // two constants + (ConstOrCell::Const(a), ConstOrCell::Const(b)) => { + // convert to bigint + let a = a.to_biguint(); + let b = b.to_biguint(); + let res = a.pow(b.to_u32().expect("expects u32 number")); + Ok(Some(Var::new( + vec![ConstOrCell::Const(B::Field::from(res))], + span, + ))) + } + _ => { + todo!() + } + } +} diff --git a/src/stdlib/native/bigint.no b/src/stdlib/native/bigint.no new file mode 100644 index 000000000..e3fd71491 --- /dev/null +++ b/src/stdlib/native/bigint.no @@ -0,0 +1,295 @@ +use std::bits; +use std::builtins; +use std::comparator; +use std::int; + +struct Result { + sum: Field, + carry: Field, +} + +// sum % 2**LEN with carry bit +fn module_sum(const LEN: Field, lhs: Field, rhs: Field) -> Result { + // assumptions: + // 1. LEN has to be less than the size of fields lhs and rhs + // otherwise the sum of lhs and rhs is always less than the max value of LEN bits. + // 2. 
the size of a b value have to be upper bounded by LEN + + let one = 1; + let zero = 0; + + // absolute sum + let abs_sum = lhs + rhs; + + // one more bit for the carry + let carry_len = LEN + 1; + + // convert the sum to bits + // this also does the range check for the assumption 2 + let abs_bits = bits::to_bits(carry_len, abs_sum); + + // extract the carry bit + let carry_bit = abs_bits[LEN]; + let carry_num = if carry_bit { one } else { zero }; + + // calculate the module sum by removing contributions from carries + let mod_sum = abs_sum - (carry_num * bits::left_shift(1, LEN)); + + return Result { + sum: mod_sum, + carry: carry_num + }; +} + +// todo: encapsulate module_sum* functions into a single function, eg: +// fn module_sum(const LEN: Field, elms: [Field; SIZE]) -> Result +// LEN is the number of bits of the modulo, SIZE is the number of elements +// according to the math: a + b + c <= SIZE * (2**LEN - 1) +// the number of carries can be determined by ceil(SIZE/2) +fn module_sum_three (const LEN: Field, aa: Field, bb: Field, cc: Field) -> Result { + let one = 1; + let zero = 0; + + // absolute sum + let abs_sum = (aa + bb) + cc; + + // carry bits + let carry_len = LEN + 2; + + // convert the sum to bits + let abs_bits = bits::to_bits(carry_len, abs_sum); + + // extract the carry bit + let first_carry_num = if abs_bits[LEN] { one } else { zero }; + let second_carry_num = if abs_bits[LEN + 1] { one } else { zero }; + + // total carry in sum + let total_carry_num = first_carry_num + (2 * second_carry_num); + + // calculate the module sum by removing contributions from carries + let mod_sum = abs_sum - (total_carry_num * bits::left_shift(1, LEN)); + + return Result { + sum: mod_sum, + carry: total_carry_num + }; +} + +fn add_limbs(const LEN: Field, aa: [Field; SIZE], bb: [Field; SIZE]) -> [Field; SIZE + 1] { + let mut out = [0; SIZE + 1]; + + let mut res = module_sum(LEN, aa[0], bb[0]); + out[0] = res.sum; + + let mut carry = res.carry; + + for ii in 1..SIZE { 
+ res = module_sum_three(LEN, aa[ii], bb[ii], carry); + out[ii] = res.sum; + carry = res.carry; + } + + out[SIZE] = carry; + return out; +} + +fn mult_limbs_no_carry(const BITLEN: Field, lhs: [Field; REGLHS], rhs: [Field; REGRHS]) -> [Field; (REGLHS + REGRHS) - 1] { + let mut prod_val = [0; REGLHS + REGRHS]; + for ii in 0..REGLHS { + for jj in 0..REGRHS { + prod_val[ii + jj] = prod_val[ii + jj] + (lhs[ii] * rhs[jj]); + } + } + + let mut prod = [0; (REGLHS + REGRHS) - 1]; + // is the value or prod[REGLHS + REGRHS] always 0? + for ii in 0..(REGLHS + REGRHS) - 1 { + prod[ii] = prod_val[ii]; + } + + // prove prod_val calculated correctly via polynomial convolution + // P_out(x) = P_lhs(x) * P_rhs(x) + // lhs[i] and rhs[j] are the coefficients of P_lhs(x) and P_rhs(x) respectively + // coefficients of P_out(x), which is out[i], are the convolution of lhs and rhs + let mut lhs_poly = [0; (REGLHS + REGRHS) - 1]; + let mut rhs_poly = [0; (REGLHS + REGRHS) - 1]; + let mut prod_poly = [0; (REGLHS + REGRHS) - 1]; + for point in 0..(REGLHS + REGRHS) - 1 { + for limb_idx in 0..(REGLHS + REGRHS) - 1 { + prod_poly[point] = prod_poly[point] + (prod[limb_idx] * builtins::pow(point, limb_idx)); + } + for limb_idx in 0..REGLHS { + lhs_poly[point] = lhs_poly[point] + (lhs[limb_idx] * builtins::pow(point, limb_idx)); + } + for limb_idx in 0..REGRHS { + rhs_poly[point] = rhs_poly[point] + (rhs[limb_idx] * builtins::pow(point, limb_idx)); + } + } + + for ii in 0..(REGLHS + REGRHS) - 1 { + assert_eq(prod_poly[ii], lhs_poly[ii] * rhs_poly[ii]); + } + + return prod; +} + +fn split_fn(in_val: Field, const small_bitlen: Field, const big_bitlen: Field) -> [Field; 2] { + let small = builtins::mod(in_val, bits::left_shift(1, small_bitlen)); + + // right shift by small_bitlen to get the remaining bits + let right_shifted = builtins::div(in_val, bits::left_shift(1, small_bitlen)); + let big = builtins::mod(right_shifted, bits::left_shift(1, big_bitlen)); + + return [small, big]; +} + +fn 
split_three_fn(in_val: Field, const small_bitlen: Field, const mid_bitlen: Field, const big_bitlen: Field) -> [Field; 3] { + let small = builtins::mod(in_val, bits::left_shift(1, small_bitlen)); + + // right shift by small_bitlen to get the remaining bits + let right_shifted_mid = builtins::div(in_val, bits::left_shift(1, small_bitlen)); + let mid = builtins::mod(right_shifted_mid, bits::left_shift(1, mid_bitlen)); + + // right shift by small_bitlen + mid_bitlen to get the remaining bits + let right_shifted_big = builtins::div(in_val, bits::left_shift(1, small_bitlen + mid_bitlen)); + let big = builtins::mod(right_shifted_big, bits::left_shift(1, big_bitlen)); + + return [small, mid, big]; +} + +fn long_to_short(const SHORTLEN: Field, long_ints: [Field; SIZE]) -> [Field; SIZE + 1] { + let zero = 0; + let mut short_ints = [0; SIZE + 1]; + + let mut split = [[0; 3]; SIZE]; + for ii in 0..SIZE { + split[ii] = split_three_fn(long_ints[ii], SHORTLEN, SHORTLEN, SHORTLEN); + } + + let mut carry = [0; SIZE]; + short_ints[0] = split[0][0]; + + // todo: support const attribute inside scope, so it can instruct whether the return is constant field + // then it can add a checking at synthesizer phase to check if the returned value is constant + // let size_bit_len = bits::bit_len(SIZE); + let size_bit_len = 2; + + // Compute the short integers and carries + // each of long ints are split into 3 short ints + // then they align diagonally to compute the sum vertically + // the reason to align diagonally is to align with the bit magnitude in each long int + // each long int has left shifted SHORTLEN bits compared to the previous long int + // l0: s0, s1, s2 + // l1: s0, s1, s2 + // l2: s0, s1, s2 + // assumption: no carry is possible in the highest order register + // for example, + // a * b = l1, then l1.s3 won't have carry bits more than len(a) + len(b) + // !need a proof for this assumption + + short_ints[0] = split[0][0]; + let short_ints_1_sum = split_fn(split[0][1] + 
split[1][0], SHORTLEN, SHORTLEN); + short_ints[1] = short_ints_1_sum[0]; + carry[1] = short_ints_1_sum[1]; + + for ii in 2..SIZE { + // relatively, this is from the last long int + let l2_s0 = split[ii][0]; + // relatively, this is from the middle long int + // todo: without branching, ii - 1 could overview the range when ii is 0 + // todo: thus for now, enforce SIZE >= 3 + let l1_s1 = if ii == 0 { zero } else { split[ii - 1][1] }; + // relatively, this is from the first long int + let l0_s2 = if comparator::less_than(size_bit_len, ii, 2) { zero } else { split[ii - 2][2] }; + let c_prev = carry[ii - 1]; + + // Compute sum and carry for indices less than SIZE + let sum = ((l2_s0 + l1_s1) + l0_s2) + c_prev; + let sum_and_carry = split_fn(sum, SHORTLEN, SHORTLEN); + short_ints[ii] = sum_and_carry[0]; + carry[ii] = sum_and_carry[1]; + } + + // Compute the l2_s1 and l1_s2 + let l2_s1 = if SIZE == 1 { split[0][1] } else { split[SIZE - 1][1] }; + let l1_s2 = if comparator::less_than(size_bit_len, SIZE, 2) { zero } else { split[SIZE - 2][2] }; + + short_ints[SIZE] = (l2_s1 + l1_s2) + carry[SIZE - 1]; + + // range checks + for ii in 0..SIZE + 1 { + // todo: update the type checker to allow unused return value? + // or a wrapper function to simply do the range check without return? 
+ let ignore = bits::to_bits(SHORTLEN, short_ints[ii]); + } + + // prove the relationship between long ints and short ints + // the running carry should be the same as the last short int + let mut running_carry = [0; SIZE]; + // this can be also seen as a right shift + running_carry[0] = builtins::div(long_ints[0] - short_ints[0], bits::left_shift(1, SHORTLEN)); + // range check + let ignore = bits::to_bits(SHORTLEN + size_bit_len, running_carry[0]); + assert_eq(running_carry[0] * bits::left_shift(1, SHORTLEN), long_ints[0] - short_ints[0]); + + for ii in 1..SIZE { + let remaining = (long_ints[ii] - short_ints[ii]) + running_carry[ii - 1]; + running_carry[ii] = builtins::div(remaining, bits::left_shift(1, SHORTLEN)); + builtins::log(SHORTLEN + size_bit_len); + builtins::log(remaining); + // range check + let ignore_ = bits::to_bits(SHORTLEN + size_bit_len, running_carry[ii]); + assert_eq(running_carry[ii] * bits::left_shift(1, SHORTLEN), remaining); + } + + assert_eq(running_carry[SIZE - 1], short_ints[SIZE]); + + return short_ints; +} + +fn mult_limbs(const BITLEN: Field, lhs: [Field; SIZE], rhs: [Field; SIZE]) -> [Field; 2 * SIZE] { + let mut out = [0; 2 * SIZE]; + + let longs = mult_limbs_no_carry(BITLEN, lhs, rhs); + let shorts = long_to_short(BITLEN, longs); + + return shorts; +} + +fn uint8_to_fields(val: [int::Uint8; REGISTERS]) -> [Field; REGISTERS] { + let mut out = [0; REGISTERS]; + for ii in 0..REGISTERS { + out[ii] = val[ii].inner; + } + return out; +} + +fn uint8_add(lhs: [int::Uint8; REGISTERS], rhs: [int::Uint8; REGISTERS]) -> [int::Uint8; REGISTERS + 1] { + let mut res = [int::Uint8.new(0); REGISTERS + 1]; + + let lhs_fields = uint8_to_fields(lhs); + let rhs_fields = uint8_to_fields(rhs); + + let raw_res = add_limbs(8, lhs_fields, rhs_fields); + + for ii in 0..REGISTERS + 1 { + res[ii] = int::Uint8.new(raw_res[ii]); + } + + return res; +} + +fn uint8_mult(lhs: [int::Uint8; REGISTERS], rhs: [int::Uint8; REGISTERS]) -> [int::Uint8; REGISTERS * 2] { + 
let mut res = [int::Uint8.new(0); REGISTERS * 2]; + + let lhs_fields = uint8_to_fields(lhs); + let rhs_fields = uint8_to_fields(rhs); + + let raw_res = mult_limbs(8, lhs_fields, rhs_fields); + + for ii in 0..REGISTERS * 2 { + res[ii] = int::Uint8.new(raw_res[ii]); + } + + return res; +} \ No newline at end of file diff --git a/src/stdlib/native/comparator.no b/src/stdlib/native/comparator.no new file mode 100644 index 000000000..43a4e14bf --- /dev/null +++ b/src/stdlib/native/comparator.no @@ -0,0 +1,44 @@ +use std::bits; +use std::int; + +// Instead of comparing bit by bit, we check the carry bit: +// lhs + (1 << LEN) - rhs +// proof: +// lhs + (1 << LEN) will add a carry bit, valued 1, to the bit array representing lhs, +// resulted in a bit array of length LEN + 1, named as sum_bits. +// if `lhs < rhs``, then `lhs - rhs < 0`, thus `(1 << LEN) + lhs - rhs < (1 << LEN)` +// then, the carry bit of sum_bits is 0. +// otherwise, the carry bit of sum_bits is 1. +fn less_than(const LEN: Field, lhs: Field, rhs: Field) -> Bool { + let carry_bit_len = LEN + 1; + + // 1 << LEN + let mut pow2 = 1; + for ii in 0..LEN { + pow2 = pow2 + pow2; + } + + let sum = (pow2 + lhs) - rhs; + let sum_bit = bits::to_bits(carry_bit_len, sum); + + // todo: modify the ife to allow literals + let b1 = false; + let b2 = true; + let res = if sum_bit[LEN] { b1 } else { b2 }; + + return res; +} + +// Less than or equal to. 
+// based on the proof of less_than(): +// adding 1 to the rhs, can upper bound by 1 for the lhs: +// lhs < rhs + 1 +// is equivalent to +// lhs <= rhs +fn less_eq_than(const LEN: Field, lhs: Field, rhs: Field) -> Bool { + return less_than(LEN, lhs, rhs + 1); +} + +fn uint8_less_than(lhs: int::Uint8, rhs: int::Uint8) -> Bool { + return less_than(8, lhs.inner, rhs.inner); +} \ No newline at end of file diff --git a/src/stdlib/native/int.no b/src/stdlib/native/int.no new file mode 100644 index 000000000..9b0bc1067 --- /dev/null +++ b/src/stdlib/native/int.no @@ -0,0 +1,17 @@ +use std::bits; + +struct Uint8 { + // todo: maybe add a const attribute to Field to forbid reassignment + inner: Field, + bit_len: Field, +} + +fn Uint8.new(val: Field) -> Uint8 { + // range check + let ignore_ = bits::to_bits(8, val); + + return Uint8 { + inner: val, + bit_len: 8 + }; +} \ No newline at end of file diff --git a/src/tests/examples.rs b/src/tests/examples.rs index 3aade9de0..05a4b4768 100644 --- a/src/tests/examples.rs +++ b/src/tests/examples.rs @@ -13,6 +13,8 @@ use crate::{ type_checker::TypeChecker, }; +use super::init_stdlib_dep; + fn test_file( file_name: &str, public_inputs: &str, @@ -36,6 +38,7 @@ fn test_file( // compile let mut sources = Sources::new(); let mut tast = TypeChecker::new(); + init_stdlib_dep(&mut sources, &mut tast); let this_module = None; let _node_id = typecheck_next_file( &mut tast, @@ -98,6 +101,7 @@ fn test_file( // compile let mut sources = Sources::new(); let mut tast = TypeChecker::new(); + init_stdlib_dep(&mut sources, &mut tast); let this_module = None; let _node_id = typecheck_next_file( &mut tast, diff --git a/src/tests/mod.rs b/src/tests/mod.rs index 3d2702bcb..34c0d577b 100644 --- a/src/tests/mod.rs +++ b/src/tests/mod.rs @@ -1,2 +1,25 @@ +use std::path::Path; + +use crate::{ + backends::Backend, + cli::packages::UserRepo, + compiler::{typecheck_next_file, Sources}, + type_checker::TypeChecker, +}; + mod examples; mod modules; +mod stdlib; + 
+fn init_stdlib_dep(sources: &mut Sources, tast: &mut TypeChecker) { + let libs = vec!["int", "comparator", "bigint"]; + + // read stdlib files from src/stdlib/native/ + for lib in libs { + let module = UserRepo::new(&format!("std/{}", lib)); + let prefix_stdlib = Path::new("src/stdlib/native/"); + let code = std::fs::read_to_string(prefix_stdlib.join(format!("{lib}.no"))).unwrap(); + let _node_id = + typecheck_next_file(tast, Some(module), sources, lib.to_string(), code, 0).unwrap(); + } +} diff --git a/src/tests/stdlib/bigint/add_limbs/main.no b/src/tests/stdlib/bigint/add_limbs/main.no new file mode 100644 index 000000000..ff0a45d86 --- /dev/null +++ b/src/tests/stdlib/bigint/add_limbs/main.no @@ -0,0 +1,24 @@ +use std::bigint; +use std::builtins; +use std::int; + +// todo: allow constant var be symbolic value in main function sig +fn main(pub lhs: [Field; 3], rhs: [Field; 3]) -> [Field; 4] { + let registers = 3; + let mut lhs_uint8 = [int::Uint8.new(0); registers]; + let mut rhs_uint8 = [int::Uint8.new(0); registers]; + // todo: the following throws error, might due to parsing issue + // this happens when int::Uint8.new() returns a struct + // update the log to print out the type to debug + // let mut rhs_uint8 = [int::Uint8.new(0); 32]; + + for ii in 0..registers { + lhs_uint8[ii] = int::Uint8.new(lhs[ii]); + rhs_uint8[ii] = int::Uint8.new(rhs[ii]); + } + + let res_bigint = bigint::uint8_add(lhs_uint8, rhs_uint8); + let res = bigint::uint8_to_fields(res_bigint); + + return res; +} \ No newline at end of file diff --git a/src/tests/stdlib/bigint/mod.rs b/src/tests/stdlib/bigint/mod.rs new file mode 100644 index 000000000..be01f0610 --- /dev/null +++ b/src/tests/stdlib/bigint/mod.rs @@ -0,0 +1,162 @@ +use crate::error; + +use super::test_stdlib; +use ark_ff::One; +use error::Result; +use num_bigint::BigInt; + +#[test] +fn test_module_sum_with_carry() -> Result<()> { + let public_inputs = r#"{"lhs": "4"}"#; + let private_inputs = r#"{"rhs": "5"}"#; + + 
test_stdlib( + "bigint/module_sum/main.no", + "bigint/module_sum/main.asm", + public_inputs, + private_inputs, + vec!["1", "1"], + )?; + + Ok(()) +} + +#[test] +fn test_module_sum_without_carry() -> Result<()> { + let public_inputs = r#"{"lhs": "4"}"#; + let private_inputs = r#"{"rhs": "3"}"#; + + test_stdlib( + "bigint/module_sum/main.no", + "bigint/module_sum/main.asm", + public_inputs, + private_inputs, + vec!["7", "0"], + )?; + + Ok(()) +} + +#[test] +fn test_module_sum_three_with_carry() -> Result<()> { + let public_inputs = r#"{"lhs": "4"}"#; + let private_inputs = r#"{"rhs": "2"}"#; + + test_stdlib( + "bigint/module_sum_three/main.no", + "bigint/module_sum_three/main.asm", + public_inputs, + private_inputs, + vec!["1", "1"], + )?; + + Ok(()) +} + +#[test] +fn test_module_sum_three_with_carry_two() -> Result<()> { + let public_inputs = r#"{"lhs": "7"}"#; + let private_inputs = r#"{"rhs": "7"}"#; + + test_stdlib( + "bigint/module_sum_three/main.no", + "bigint/module_sum_three/main.asm", + public_inputs, + private_inputs, + // 2 carry bits + vec!["1", "2"], + )?; + + Ok(()) +} + +#[test] +fn test_module_sum_three_without_carry() -> Result<()> { + let public_inputs = r#"{"lhs": "1"}"#; + let private_inputs = r#"{"rhs": "1"}"#; + + test_stdlib( + "bigint/module_sum_three/main.no", + "bigint/module_sum_three/main.asm", + public_inputs, + private_inputs, + vec!["5", "0"], + )?; + + Ok(()) +} + +fn bigint_to_array(n: u32, k: u32, x: BigInt) -> Vec { + // Compute modulus as 2^n + let modulus = BigInt::one() << n; + + let mut ret = Vec::new(); + let mut x_temp = x; + + for _ in 0..k { + // Get the remainder of x_temp divided by modulus + let remainder = &x_temp % &modulus; + ret.push(remainder.clone()); + + // Divide x_temp by modulus (equivalent to right-shifting by n bits) + x_temp >>= n; + } + + ret +} + +#[test] +fn test_add_limbs() -> Result<()> { + let a = bigint_to_array(8, 3, BigInt::from(16)); + let b = bigint_to_array(8, 3, BigInt::from(17)); + let c = 
bigint_to_array(8, 4, BigInt::from(33)); + + let lhs_strings: Vec = a.iter().map(|num| num.to_string()).collect(); + let rhs_strings: Vec = b.iter().map(|num| num.to_string()).collect(); + let sum_strings: Vec = c.iter().map(|num| num.to_string()).collect(); + + // Serialize to JSON string + let lhs_strings = serde_json::to_string(&lhs_strings).unwrap(); + let rhs_strings = serde_json::to_string(&rhs_strings).unwrap(); + + let public_inputs = format!(r#"{{"lhs": {}}}"#, lhs_strings); + let private_inputs = format!(r#"{{"rhs": {}}}"#, rhs_strings); + + test_stdlib( + "bigint/add_limbs/main.no", + "bigint/add_limbs/main.asm", + &public_inputs, + &private_inputs, + sum_strings.iter().map(|s| s.as_str()).collect(), + )?; + + Ok(()) +} + +#[test] +fn test_mult_limbs() -> Result<()> { + let a = bigint_to_array(8, 3, BigInt::from(10)); + let b = bigint_to_array(8, 3, BigInt::from(10)); + let c = bigint_to_array(8, 6, BigInt::from(100)); + + let lhs_strings: Vec = a.iter().map(|num| num.to_string()).collect(); + let rhs_strings: Vec = b.iter().map(|num| num.to_string()).collect(); + let res_strings: Vec = c.iter().map(|num| num.to_string()).collect(); + + // Serialize to JSON string + let lhs_strings = serde_json::to_string(&lhs_strings).unwrap(); + let rhs_strings = serde_json::to_string(&rhs_strings).unwrap(); + + let public_inputs = format!(r#"{{"lhs": {}}}"#, lhs_strings); + let private_inputs = format!(r#"{{"rhs": {}}}"#, rhs_strings); + + test_stdlib( + "bigint/mult_limbs/main.no", + "bigint/mult_limbs/main.asm", + &public_inputs, + &private_inputs, + res_strings.iter().map(|s| s.as_str()).collect(), + )?; + + Ok(()) +} diff --git a/src/tests/stdlib/bigint/module_sum/main.asm b/src/tests/stdlib/bigint/module_sum/main.asm new file mode 100644 index 000000000..a647be1a4 --- /dev/null +++ b/src/tests/stdlib/bigint/module_sum/main.asm @@ -0,0 +1,14 @@ +@ noname.0.7.0 +@ public inputs: 3 + +v_6 == (v_5) * (v_5 + -1) +0 == (v_6) * (1) +v_8 == (v_7) * (v_7 + -1) +0 == (v_8) 
* (1) +v_10 == (v_9) * (v_9 + -1) +0 == (v_10) * (1) +v_12 == (v_11) * (v_11 + -1) +0 == (v_12) * (1) +v_5 + 2 * v_7 + 4 * v_9 + 8 * v_11 == (v_3 + v_4) * (1) +v_3 + v_4 + -8 * v_11 == (v_1) * (1) +v_11 == (v_2) * (1) diff --git a/src/tests/stdlib/bigint/module_sum/main.no b/src/tests/stdlib/bigint/module_sum/main.no new file mode 100644 index 000000000..12e0f659d --- /dev/null +++ b/src/tests/stdlib/bigint/module_sum/main.no @@ -0,0 +1,7 @@ +use std::bigint; + +fn main(pub lhs: Field, rhs: Field) -> bigint::Result { + let mod_bits = 3; + let res = bigint::module_sum(mod_bits, lhs, rhs); + return res; +} \ No newline at end of file diff --git a/src/tests/stdlib/bigint/module_sum_three/main.asm b/src/tests/stdlib/bigint/module_sum_three/main.asm new file mode 100644 index 000000000..b84d8cea0 --- /dev/null +++ b/src/tests/stdlib/bigint/module_sum_three/main.asm @@ -0,0 +1,16 @@ +@ noname.0.7.0 +@ public inputs: 3 + +v_6 == (v_5) * (v_5 + -1) +0 == (v_6) * (1) +v_8 == (v_7) * (v_7 + -1) +0 == (v_8) * (1) +v_10 == (v_9) * (v_9 + -1) +0 == (v_10) * (1) +v_12 == (v_11) * (v_11 + -1) +0 == (v_12) * (1) +v_14 == (v_13) * (v_13 + -1) +0 == (v_14) * (1) +v_5 + 2 * v_7 + 4 * v_9 + 8 * v_11 + 16 * v_13 == (v_3 + v_4 + 3) * (1) +v_3 + v_4 + -8 * v_11 + -16 * v_13 + 3 == (v_1) * (1) +v_11 + 2 * v_13 == (v_2) * (1) diff --git a/src/tests/stdlib/bigint/module_sum_three/main.no b/src/tests/stdlib/bigint/module_sum_three/main.no new file mode 100644 index 000000000..1b0770a1e --- /dev/null +++ b/src/tests/stdlib/bigint/module_sum_three/main.no @@ -0,0 +1,8 @@ +use std::bigint; + +fn main(pub lhs: Field, rhs: Field) -> bigint::Result { + let mod_bits = 3; + let another_val = 3; + let res = bigint::module_sum_three(mod_bits, lhs, rhs, another_val); + return res; +} \ No newline at end of file diff --git a/src/tests/stdlib/bigint/mult_limbs/main.no b/src/tests/stdlib/bigint/mult_limbs/main.no new file mode 100644 index 000000000..3a171818c --- /dev/null +++ 
b/src/tests/stdlib/bigint/mult_limbs/main.no @@ -0,0 +1,19 @@ +use std::bigint; +use std::int; + +fn main(pub lhs: [Field; 3], rhs: [Field; 3]) -> [Field; 6] { + let registers = 3; + let val = int::Uint8.new(0); + // todo: bug: when assigning using [int::Uint8.new(0); registers], it throws error "method call only work on custom types" + let mut lhs_uint8 = [val; registers]; + let mut rhs_uint8 = [val; registers]; + + for ii in 0..registers { + lhs_uint8[ii] = int::Uint8.new(lhs[ii]); + rhs_uint8[ii] = int::Uint8.new(rhs[ii]); + } + + let res_uint8 = bigint::uint8_mult(lhs_uint8, rhs_uint8); + let res = bigint::uint8_to_fields(res_uint8); + return res; +} \ No newline at end of file diff --git a/src/tests/stdlib/comparator/less_eq_than/less_eq_than.asm b/src/tests/stdlib/comparator/less_eq_than/less_eq_than.asm new file mode 100644 index 000000000..396f070da --- /dev/null +++ b/src/tests/stdlib/comparator/less_eq_than/less_eq_than.asm @@ -0,0 +1,11 @@ +@ noname.0.7.0 +@ public inputs: 2 + +v_5 == (v_4) * (v_4 + -1) +0 == (v_5) * (1) +v_7 == (v_6) * (v_6 + -1) +0 == (v_7) * (1) +v_9 == (v_8) * (v_8 + -1) +0 == (v_9) * (1) +v_4 + 2 * v_6 + 4 * v_8 == (v_2 + -1 * v_3 + 3) * (1) +-1 * v_8 + 1 == (v_1) * (1) diff --git a/src/tests/stdlib/comparator/less_eq_than/less_eq_than_main.no b/src/tests/stdlib/comparator/less_eq_than/less_eq_than_main.no new file mode 100644 index 000000000..5f998ae9f --- /dev/null +++ b/src/tests/stdlib/comparator/less_eq_than/less_eq_than_main.no @@ -0,0 +1,6 @@ +use std::comparator; + +fn main(pub lhs: Field, rhs: Field) -> Bool { + let res = comparator::less_eq_than(2, lhs, rhs); + return res; +} \ No newline at end of file diff --git a/src/tests/stdlib/comparator/less_than/less_than.asm b/src/tests/stdlib/comparator/less_than/less_than.asm new file mode 100644 index 000000000..8c33ac348 --- /dev/null +++ b/src/tests/stdlib/comparator/less_than/less_than.asm @@ -0,0 +1,11 @@ +@ noname.0.7.0 +@ public inputs: 2 + +v_5 == (v_4) * (v_4 + -1) +0 
== (v_5) * (1) +v_7 == (v_6) * (v_6 + -1) +0 == (v_7) * (1) +v_9 == (v_8) * (v_8 + -1) +0 == (v_9) * (1) +v_4 + 2 * v_6 + 4 * v_8 == (v_2 + -1 * v_3 + 4) * (1) +-1 * v_8 + 1 == (v_1) * (1) diff --git a/src/tests/stdlib/comparator/less_than/less_than_main.no b/src/tests/stdlib/comparator/less_than/less_than_main.no new file mode 100644 index 000000000..f0cba5462 --- /dev/null +++ b/src/tests/stdlib/comparator/less_than/less_than_main.no @@ -0,0 +1,10 @@ +use std::comparator; +use std::int; + +fn main(pub lhs: Field, rhs: Field) -> Bool { + // todo bug: this also throws error "method call only work on custom types" + let lhs_bigint = int::Uint8.new(lhs); + let rhs_bigint = int::Uint8.new(rhs); + // let res = comparator::uint8_less_than(lhs_bigint, rhs_bigint); + return true; +} \ No newline at end of file diff --git a/src/tests/stdlib/comparator/mod.rs b/src/tests/stdlib/comparator/mod.rs new file mode 100644 index 000000000..bc85f91fe --- /dev/null +++ b/src/tests/stdlib/comparator/mod.rs @@ -0,0 +1,89 @@ +use crate::error; + +use super::test_stdlib; +use error::Result; + +#[test] +fn test_less_than_true() -> Result<()> { + let public_inputs = r#"{"lhs": "0"}"#; + let private_inputs = r#"{"rhs": "1"}"#; + + test_stdlib( + "comparator/less_than/less_than_main.no", + "comparator/less_than/less_than.asm", + public_inputs, + private_inputs, + vec!["1"], + )?; + + Ok(()) +} + +// test false +#[test] +fn test_less_than_false() -> Result<()> { + let public_inputs = r#"{"lhs": "1"}"#; + let private_inputs = r#"{"rhs": "0"}"#; + + test_stdlib( + "comparator/less_than/less_than_main.no", + "comparator/less_than/less_than.asm", + public_inputs, + private_inputs, + vec!["0"], + )?; + + Ok(()) +} + +#[test] +fn test_less_eq_than_true_1() -> Result<()> { + let public_inputs = r#"{"lhs": "0"}"#; + let private_inputs = r#"{"rhs": "1"}"#; + + test_stdlib( + "comparator/less_eq_than/less_eq_than_main.no", + "comparator/less_eq_than/less_eq_than.asm", + public_inputs, + 
private_inputs, + vec!["1"], + )?; + + Ok(()) +} + +#[test] +fn test_less_eq_than_true_2() -> Result<()> { + let public_inputs = r#"{"lhs": "1"}"#; + let private_inputs = r#"{"rhs": "1"}"#; + + test_stdlib( + "comparator/less_eq_than/less_eq_than_main.no", + "comparator/less_eq_than/less_eq_than.asm", + public_inputs, + private_inputs, + vec!["1"], + )?; + + Ok(()) +} + +#[test] +fn test_less_eq_than_false() -> Result<()> { + let public_inputs = r#"{"lhs": "1"}"#; + let private_inputs = r#"{"rhs": "0"}"#; + + test_stdlib( + "comparator/less_eq_than/less_eq_than_main.no", + "comparator/less_eq_than/less_eq_than.asm", + public_inputs, + private_inputs, + vec!["0"], + )?; + + Ok(()) +} + +// test value overflow modulus +// it shouldn't need user to enter the bit length +// should have a way to restrict and type check the value to a certain bit length diff --git a/src/tests/stdlib/mod.rs b/src/tests/stdlib/mod.rs new file mode 100644 index 000000000..3955a474f --- /dev/null +++ b/src/tests/stdlib/mod.rs @@ -0,0 +1,94 @@ +mod bigint; +mod comparator; + +use std::{path::Path, str::FromStr}; + +use crate::{ + backends::r1cs::{R1csBn254Field, R1CS}, + circuit_writer::CircuitWriter, + compiler::{typecheck_next_file, Sources}, + error::Result, + inputs::parse_inputs, + mast, + tests::init_stdlib_dep, + type_checker::TypeChecker, + witness::CompiledCircuit, +}; + +fn test_stdlib( + path: &str, + asm_path: &str, + public_inputs: &str, + private_inputs: &str, + expected_public_output: Vec<&str>, +) -> Result>> { + let r1cs = R1CS::new(); + let root = env!("CARGO_MANIFEST_DIR"); + let prefix_path = Path::new(root).join("src/tests/stdlib"); + + // read noname file + let code = std::fs::read_to_string(prefix_path.clone().join(path)).unwrap(); + + // parse inputs + let public_inputs = parse_inputs(public_inputs).unwrap(); + let private_inputs = parse_inputs(private_inputs).unwrap(); + + // compile + let mut sources = Sources::new(); + let mut tast = TypeChecker::new(); + 
init_stdlib_dep(&mut sources, &mut tast); + + let this_module = None; + let _node_id = typecheck_next_file( + &mut tast, + this_module, + &mut sources, + path.to_string(), + code.clone(), + 0, + ) + .unwrap(); + + let mast = mast::monomorphize(tast)?; + let compiled_circuit = CircuitWriter::generate_circuit(mast, r1cs)?; + + // this should check the constraints + let generated_witness = compiled_circuit + .generate_witness(public_inputs.clone(), private_inputs.clone()) + .unwrap(); + + let expected_public_output = expected_public_output + .iter() + .map(|x| crate::backends::r1cs::R1csBn254Field::from_str(x).unwrap()) + .collect::>(); + + if generated_witness.outputs != expected_public_output { + eprintln!("obtained by executing the circuit:"); + generated_witness + .outputs + .iter() + .for_each(|x| eprintln!("- {x}")); + eprintln!("passed as output by the verifier:"); + expected_public_output + .iter() + .for_each(|x| eprintln!("- {x}")); + panic!("Obtained output does not match expected output"); + } + + // check the ASM + if compiled_circuit.circuit.backend.num_constraints() < 100 { + let prefix_asm = Path::new(root).join("src/tests/stdlib/"); + let expected_asm = std::fs::read_to_string(prefix_asm.clone().join(asm_path)).unwrap(); + let obtained_asm = compiled_circuit.asm(&Sources::new(), false); + + if obtained_asm != expected_asm { + eprintln!("obtained:"); + eprintln!("{obtained_asm}"); + eprintln!("expected:"); + eprintln!("{expected_asm}"); + panic!("Obtained ASM does not match expected ASM"); + } + } + + Ok(compiled_circuit) +} diff --git a/src/type_checker/checker.rs b/src/type_checker/checker.rs index 7046d8141..4cbe5572a 100644 --- a/src/type_checker/checker.rs +++ b/src/type_checker/checker.rs @@ -143,7 +143,20 @@ impl TypeChecker { // check if generic is allowed if fn_sig.require_monomorphization() && typed_fn_env.is_in_forloop() { - return Err(self.error(ErrorKind::GenericInForLoop, expr.span)); + for (observed_arg, expected_arg) in 
args.iter().zip(fn_sig.arguments.iter()) { + // check if the arg involves generic vars + if !expected_arg.extract_generic_names().is_empty() { + let mut forbidden_env = typed_fn_env.clone(); + forbidden_env.forbid_forloop_scope(); + + // rewalk the observed arg expression + // it should throw an error if the arg contains generic vars relating to the variables in the forloop scope + self.compute_type(observed_arg, &mut forbidden_env)?; + + // release the forbidden flag + forbidden_env.allow_forloop_scope(); + } + } } // type check the function call @@ -255,7 +268,7 @@ impl TypeChecker { // check that the var exists locally let lhs_info = typed_fn_env - .get_type_info(&lhs_name) + .get_type_info(&lhs_name)? .expect("variable not found (TODO: replace with error") .clone(); @@ -370,7 +383,7 @@ impl TypeChecker { // otherwise it's a local variable // generic parameter is also checked as a local variable typed_fn_env - .get_type(&name.value) + .get_type(&name.value)? .ok_or_else(|| self.error(ErrorKind::UndefinedVariable, name.span))? .clone() }; @@ -579,9 +592,12 @@ impl TypeChecker { typed_fn_env: &mut TypedFnEnv, stmts: &[Stmt], expected_return: Option<&Ty>, + new_scope: bool, ) -> Result<()> { // enter the scope - typed_fn_env.nest(); + if new_scope { + typed_fn_env.nest(); + } let mut return_typ = None; @@ -613,7 +629,9 @@ impl TypeChecker { }; // exit the scope - typed_fn_env.pop(); + if new_scope { + typed_fn_env.pop(); + } Ok(()) } @@ -688,7 +706,7 @@ impl TypeChecker { typed_fn_env.start_forloop(); // check block - self.check_block(typed_fn_env, body, None)?; + self.check_block(typed_fn_env, body, None, false)?; // exit the scope typed_fn_env.pop(); @@ -725,7 +743,7 @@ impl TypeChecker { span: Span, ) -> Result> { // check if a function names is in use already by another variable - match typed_fn_env.get_type_info(&fn_sig.name.value) { + match typed_fn_env.get_type_info(&fn_sig.name.value)? 
{ Some(_) => { return Err(self.error( ErrorKind::FunctionNameInUsebyVariable(fn_sig.name.value), @@ -766,6 +784,7 @@ impl TypeChecker { // compare argument types with the function signature for (sig_arg, (typ, span)) in expected.iter().zip(observed) { + // todo: disable the constant for now until fixed https://github.com/zksecurity/noname/issues/192 // when const attribute presented, the argument must be a constant if sig_arg.is_constant() && !matches!(typ, TyKind::Field { constant: true }) { return Err(self.error( diff --git a/src/type_checker/fn_env.rs b/src/type_checker/fn_env.rs index 7f0783a37..65ecc0c38 100644 --- a/src/type_checker/fn_env.rs +++ b/src/type_checker/fn_env.rs @@ -50,8 +50,11 @@ pub struct TypedFnEnv { // TODO: there's an output_type field that's a reserved keyword? vars: HashMap, - /// Whether it is in a for loop or not. - forloop: bool, + /// The forloop scope if it is within a for loop. + forloop_scope: Option, + + /// Determines if forloop variables are allowed to be accessed. + forbid_forloop_scope: bool, } impl TypedFnEnv { @@ -75,19 +78,31 @@ impl TypedFnEnv { .retain(|_name, (scope, _type_info)| *scope <= current_scope); } + pub fn forbid_forloop_scope(&mut self) { + self.forbid_forloop_scope = true; + } + + pub fn allow_forloop_scope(&mut self) { + self.forbid_forloop_scope = false; + } + /// Returns whether it is in a for loop. pub fn is_in_forloop(&self) -> bool { - self.forloop + if let Some(scope) = self.forloop_scope { + self.current_scope >= scope + } else { + false + } } /// Flags it as in the for loop. pub fn start_forloop(&mut self) { - self.forloop = true; + self.forloop_scope = Some(self.current_scope); } /// Flags it as not in the for loop. pub fn end_forloop(&mut self) { - self.forloop = false; + self.forloop_scope = None; } /// Returns true if a scope is a prefix of our scope. 
@@ -95,6 +110,16 @@ impl TypedFnEnv { self.current_scope >= prefix_scope } + pub fn is_forbidden(&self, scope: usize) -> bool { + let in_forbidden_scope = if let Some(forloop_scope) = self.forloop_scope { + scope >= forloop_scope + } else { + false + }; + + self.forbid_forloop_scope && in_forbidden_scope + } + /// Stores type information about a local variable. /// Note that we forbid shadowing at all scopes. pub fn store_type(&mut self, ident: String, type_info: TypeInfo) -> Result<()> { @@ -111,26 +136,36 @@ impl TypedFnEnv { } } - pub fn get_type(&self, ident: &str) -> Option<&TyKind> { - self.get_type_info(ident).map(|type_info| &type_info.typ) + pub fn get_type(&self, ident: &str) -> Result> { + Ok(self.get_type_info(ident)?.map(|type_info| &type_info.typ)) } - pub fn mutable(&self, ident: &str) -> Option { - self.get_type_info(ident).map(|type_info| type_info.mutable) + pub fn mutable(&self, ident: &str) -> Result> { + Ok(self + .get_type_info(ident)? + .map(|type_info| type_info.mutable)) } /// Retrieves type information on a variable, given a name. /// If the variable is not in scope, return false. // TODO: return an error no? 
- pub fn get_type_info(&self, ident: &str) -> Option<&TypeInfo> { + pub fn get_type_info(&self, ident: &str) -> Result> { if let Some((scope, type_info)) = self.vars.get(ident) { + if self.is_forbidden(*scope) { + return Err(Error::new( + "type-checker", + ErrorKind::VarAccessForbiddenInForLoop(ident.to_string()), + type_info.span, + )); + } + if self.is_in_scope(*scope) { - Some(type_info) + Ok(Some(type_info)) } else { - None + Ok(None) } } else { - None + Ok(None) } } } diff --git a/src/type_checker/mod.rs b/src/type_checker/mod.rs index 7306f2d39..a2b7469ef 100644 --- a/src/type_checker/mod.rs +++ b/src/type_checker/mod.rs @@ -415,6 +415,7 @@ impl TypeChecker { &mut typed_fn_env, &function.body, function.sig.return_type.as_ref(), + true, )?; } diff --git a/src/var.rs b/src/var.rs index e183a7085..e65756c96 100644 --- a/src/var.rs +++ b/src/var.rs @@ -8,6 +8,7 @@ use crate::{ circuit_writer::{CircuitWriter, FnEnv, VarInfo}, constants::Span, error::Result, + helpers::PrettyField, type_checker::ConstInfo, witness::WitnessEnv, }; @@ -57,6 +58,22 @@ where // but does this make sense to all different backends? is it possible that some backend doesn't allow certain out of circuit calculations like this? NthBit(B::Var, usize), + /// Left shift the variable. + LeftShift(B::Var, usize), + + /// Divide + // todo: refactor to use a argument wrapper to encapsulate its type, so that it can be either B::Var or B::Field + CstDivVar(B::Field, B::Var), + VarDivCst(B::Var, B::Field), + VarDivVar(B::Var, B::Var), + CstDivCst(B::Field, B::Field), + + /// Modulo + VarModVar(B::Var, B::Var), + CstModVar(B::Field, B::Var), + VarModCst(B::Var, B::Field), + CstModCst(B::Field, B::Field), + /// A public or private input to the function /// There's an index associated to a variable name, as the variable could be composed of several field elements. External(String, usize), @@ -78,6 +95,15 @@ impl std::fmt::Debug for Value { Value::PublicOutput(..) 
=> write!(f, "PublicOutput"), Value::Scale(..) => write!(f, "Scaling"), Value::NthBit(_, _) => write!(f, "NthBit"), + Value::LeftShift(_, _) => write!(f, "LeftShift"), + Value::CstDivVar(lhs, rhs) => write!(f, "CstDivVar({:?}, {:?})", lhs, rhs), + Value::VarDivCst(lhs, rhs) => write!(f, "{:?} / {:?}", lhs, rhs.pretty()), + Value::VarDivVar(lhs, rhs) => write!(f, "VarDivVar({:?}, {:?})", lhs, rhs), + Value::CstDivCst(lhs, rhs) => write!(f, "CstDivCst({:?}, {:?})", lhs, rhs), + Value::VarModVar(lhs, rhs) => write!(f, "VarModVar({:?}, {:?})", lhs, rhs), + Value::CstModVar(lhs, rhs) => write!(f, "CstModVar({:?}, {:?})", lhs, rhs), + Value::VarModCst(lhs, rhs) => write!(f, "{:?} % {:?}", lhs, rhs.pretty()), + Value::CstModCst(lhs, rhs) => write!(f, "CstModCst({:?}, {:?})", lhs, rhs), } } }