diff --git a/examples/arithmetic.no b/examples/arithmetic.no index ea3073c60..8dc3bdca0 100644 --- a/examples/arithmetic.no +++ b/examples/arithmetic.no @@ -2,4 +2,24 @@ fn main(pub public_input: Field, private_input: Field) { let xx = private_input + public_input; let yy = private_input * public_input; assert_eq(xx, yy); + + // modulus constant + let mod_res = (xx + yy) % 3; + assert_eq(mod_res, 2); + + // modulus var + let mod_var_res = (xx + yy) % (private_input + 1); + assert_eq(mod_var_res, 2); + + // divide constant + let div_res = xx / 2; + assert_eq(div_res, 2); + + // divide var + let div_var_res = xx / private_input; + assert_eq(div_res, 2); + + // left shift + let lf_res = xx << 2; + assert_eq(lf_res, 16); } diff --git a/examples/comparator.no b/examples/comparator.no new file mode 100644 index 000000000..d508fbf53 --- /dev/null +++ b/examples/comparator.no @@ -0,0 +1,4 @@ +fn main(pub xx: Field, yy: Field) { + let res = xx < yy; + assert(res); +} diff --git a/src/backends/mod.rs b/src/backends/mod.rs index fabfa2b83..0fa3368e9 100644 --- a/src/backends/mod.rs +++ b/src/backends/mod.rs @@ -198,6 +198,70 @@ pub trait Backend: Clone { Ok(res) } + Value::VarDivVar(dividend, divisor) => { + let dividend = self.compute_var(env, dividend)?.to_biguint(); + let divisor = self.compute_var(env, divisor)?.to_biguint(); + let res = Self::Field::from(dividend / divisor); + + env.cached_values.insert(cache_key, res); + Ok(res) + } + Value::CstDivVar(dividend, divisor) => { + let divisor = self.compute_var(env, divisor)?; + let res = *dividend / divisor; + Ok(res) + } + Value::VarDivCst(dividend, divisor) => { + let dividend = self.compute_var(env, dividend)?; + // convert to bigint + let dividend = dividend.to_biguint(); + let divisor = divisor.to_biguint(); + + let res = Self::Field::from(dividend / divisor); + env.cached_values.insert(cache_key, res); + Ok(res) + } + Value::CstDivCst(dividend, divisor) => { + let res = *dividend / *divisor; + env.cached_values.insert(cache_key, res); + Ok(res) + } + Value::VarModVar(dividend, divisor) => { + let dividend = self.compute_var(env, dividend)?; + let divisor = self.compute_var(env, divisor)?; + // convert to bigint + let dividend = dividend.to_biguint(); + let divisor = divisor.to_biguint(); + let res = Self::Field::from(dividend % divisor); + + env.cached_values.insert(cache_key, res); + Ok(res) + } + Value::CstModVar(dividend, divisor) => { + let divisor = self.compute_var(env, divisor)?; + // convert to bigint + let dividend = dividend.to_biguint(); + let divisor = divisor.to_biguint(); + let res = Self::Field::from(dividend % divisor); + env.cached_values.insert(cache_key, res); + Ok(res) + } + Value::VarModCst(dividend, divisor) => { + let dividend = self.compute_var(env, dividend)?; + // convert to bigint + let dividend = dividend.to_biguint(); + let divisor = divisor.to_biguint(); + let res = Self::Field::from(dividend % divisor); + env.cached_values.insert(cache_key, res); + Ok(res) + } + Value::CstModCst(dividend, divisor) => { + let dividend = dividend.to_biguint(); + let divisor = divisor.to_biguint(); + let res = Self::Field::from(dividend % divisor); + env.cached_values.insert(cache_key, res); + Ok(res) + } } } diff --git a/src/circuit_writer/writer.rs b/src/circuit_writer/writer.rs index 924c57b6f..183b713a6 100644 --- a/src/circuit_writer/writer.rs +++ b/src/circuit_writer/writer.rs @@ -619,11 +619,16 @@ impl CircuitWriter { Op2::Addition => field::add(self, &lhs[0], &rhs[0], expr.span), Op2::Subtraction => field::sub(self, &lhs[0], &rhs[0], expr.span), Op2::Multiplication => field::mul(self, &lhs[0], &rhs[0], expr.span), + Op2::Modulus => field::modulus(self, &lhs[0], &rhs[0], expr.span), + Op2::Division => field::div(self, &lhs[0], &rhs[0], expr.span), Op2::Equality => field::equal(self, &lhs, &rhs, expr.span), Op2::Inequality => field::not_equal(self, &lhs, &rhs, expr.span), + // todo: refactor the input vars from Var to VarInfo, + // which contain the type to provide the info about the bit length + Op2::LessThan => field::less_than(self, None, &lhs[0], &rhs[0], expr.span), + Op2::LeftShift => field::left_shift(self, &lhs[0], &rhs[0], expr.span), Op2::BoolAnd => boolean::and(self, &lhs[0], &rhs[0], expr.span), Op2::BoolOr => boolean::or(self, &lhs[0], &rhs[0], expr.span), - Op2::Division => todo!(), }; Ok(Some(VarOrRef::Var(res))) diff --git a/src/constraints/field.rs b/src/constraints/field.rs index 38e120117..5ee221572 100644 --- a/src/constraints/field.rs +++ b/src/constraints/field.rs @@ -1,13 +1,16 @@ use crate::{ backends::Backend, - circuit_writer::CircuitWriter, + circuit_writer::{CircuitWriter, VarInfo}, constants::Span, + parser::types::{GenericParameters, TyKind}, + stdlib::bits::to_bits, var::{ConstOrCell, Value, Var}, }; use super::boolean; -use ark_ff::{One, Zero}; +use ark_ff::{Field, One, PrimeField, Zero}; +use kimchi::o1_utils::FieldHelpers; use std::ops::Neg; @@ -99,6 +102,180 @@ pub fn mul( } } +fn constrain_div_mod( + compiler: &mut CircuitWriter, + lhs: &ConstOrCell, + rhs: &ConstOrCell, + span: Span, +) -> (B::Var, B::Var) { + // to constrain lhs − q * rhs − rem = 0 + // where rhs is the modulus + // so 0 <= rem < rhs + + let one = B::Field::one(); + + // todo: to avoid duplicating a lot of code due the different combinations of the input types + // until we refactor the backend to handle ConstOrCell or some kind of wrapper that encapsulate the different variable types + // convert cst to var for easier handling + let lhs = match lhs { + ConstOrCell::Const(lhs) => { + compiler + .backend + .add_constant(Some("wrap a constant as var"), *lhs, span) + } + ConstOrCell::Cell(lhs) => lhs.clone(), + }; + + let rhs = match rhs { + ConstOrCell::Const(rhs) => { + compiler + .backend + .add_constant(Some("wrap a constant as var"), *rhs, span) + } + ConstOrCell::Cell(rhs) => rhs.clone(), + }; + + // witness var for quotient + let q = Value::VarDivVar(lhs.clone(), rhs.clone()); + let q_var = compiler.backend.new_internal_var(q, span); + + // witness var for remainder + let rem = Value::VarModVar(lhs.clone(), rhs.clone()); + let rem_var = compiler.backend.new_internal_var(rem, span); + + // rem < rhs + let lt_rem = &less_than( + compiler, + None, + &ConstOrCell::Cell(rem_var.clone()), + &ConstOrCell::Cell(rhs.clone()), + span, + )[0]; + let lt_rem = lt_rem.cvar().expect("expected a cell var"); + compiler.backend.assert_eq_const(lt_rem, one, span); + + // foundamental constraint: lhs - q * rhs - rem = 0 + let q_mul_rhs = compiler.backend.mul(&q_var, &rhs, span); + let lhs_sub_q_mul_rhs = compiler.backend.sub(&lhs, &q_mul_rhs, span); + + // cell representing the foundamental constraint + let fc_var = compiler.backend.sub(&lhs_sub_q_mul_rhs, &rem_var, span); + compiler + .backend + .assert_eq_const(&fc_var, B::Field::zero(), span); + + (rem_var, q_var) +} + +/// Divide operation +pub fn div( + compiler: &mut CircuitWriter, + lhs: &ConstOrCell, + rhs: &ConstOrCell, + span: Span, +) -> Var { + // to constrain lhs − q * rhs − rem = 0 + // rhs can't be zero + match rhs { + ConstOrCell::Const(rhs) => { + if rhs.is_zero() { + panic!("division by zero"); + } + } + _ => { + let is_zero = is_zero_cell(compiler, rhs, span); + let is_zero = is_zero[0].cvar().unwrap(); + compiler + .backend + .assert_eq_const(is_zero, B::Field::zero(), span); + } + }; + + match (lhs, rhs) { + // if rhs is a constant, we can just divide lhs by rhs + (ConstOrCell::Const(lhs), ConstOrCell::Const(rhs)) => { + // to bigint + let lhs = lhs.to_biguint(); + let rhs = rhs.to_biguint(); + let res = lhs / rhs; + + Var::new_constant(B::Field::from(res), span) + } + _ => { + let (_, q) = constrain_div_mod(compiler, lhs, rhs, span); + Var::new_var(q, span) + } + } +} + +/// Modulus operation +pub fn modulus( + compiler: &mut CircuitWriter, + lhs: &ConstOrCell, + rhs: &ConstOrCell, + span: Span, +) -> Var { + // to constrain lhs − q * rhs − rem = 0 + + let zero = B::Field::zero(); + + // rhs can't be zero + match &rhs { + ConstOrCell::Const(rhs) => { + if rhs.is_zero() { + panic!("modulus by zero"); + } + } + _ => { + let is_zero = is_zero_cell(compiler, rhs, span); + let is_zero = is_zero[0].cvar().unwrap(); + compiler.backend.assert_eq_const(is_zero, zero, span); + } + }; + + match (lhs, rhs) { + // if rhs is a constant, we can just divide lhs by rhs + (ConstOrCell::Const(lhs), ConstOrCell::Const(rhs)) => { + let lhs = lhs.to_biguint(); + let rhs = rhs.to_biguint(); + let res = lhs % rhs; + + Var::new_constant(res.into(), span) + } + _ => { + let (rem, _) = constrain_div_mod(compiler, lhs, rhs, span); + Var::new_var(rem, span) + } + } +} + +/// Left shift operation +pub fn left_shift( + compiler: &mut CircuitWriter, + lhs: &ConstOrCell, + rhs: &ConstOrCell, + span: Span, +) -> Var { + // to constrain lhs * (1 << rhs) = res + + match (lhs, rhs) { + (ConstOrCell::Const(lhs), ConstOrCell::Const(rhs)) => { + // convert to bigint + let pow2 = B::Field::from(2u32).pow(rhs.into_repr()); + let res = *lhs * pow2; + Var::new_constant(res, span) + } + (ConstOrCell::Cell(lhs_), ConstOrCell::Const(rhs_)) => { + let pow2 = B::Field::from(2u32).pow(rhs_.into_repr()); + let res = compiler.backend.mul_const(lhs_, &pow2, span); + Var::new_var(res, span) + } + // todo: wrap rhs in a symbolic value + (ConstOrCell::Const(_), ConstOrCell::Cell(_)) => todo!(), + (ConstOrCell::Cell(_), ConstOrCell::Cell(_)) => todo!(), + } +} + /// This takes variables that can be anything, and returns a boolean // TODO: so perhaps it's not really relevant in this file? pub fn equal( @@ -254,6 +431,119 @@ pub fn not_equal( acc } +/// Returns 1 if lhs < rhs, 0 otherwise +pub fn less_than( + compiler: &mut CircuitWriter, + bitlen: Option, + lhs: &ConstOrCell, + rhs: &ConstOrCell, + span: Span, +) -> Var { + let one = B::Field::one(); + let zero = B::Field::zero(); + + // Instead of comparing bit by bit, we check the carry bit: + // lhs + (1 << LEN) - rhs + // proof: + // lhs + (1 << LEN) will add a carry bit, valued 1, to the bit array representing lhs, + // resulted in a bit array of length LEN + 1, named as sum_bits. + // if `lhs < rhs``, then `lhs - rhs < 0`, thus `(1 << LEN) + lhs - rhs < (1 << LEN)` + // then, the carry bit of sum_bits is 0. + // otherwise, the carry bit of sum_bits is 1. + + /* + psuedo code: + let carry_bit_len = LEN + 1; + + # 1 << LEN + let mut pow2 = 1; + for ii in 0..LEN { + pow2 = pow2 + pow2; + } + + let sum = (pow2 + lhs) - rhs; + let sum_bit = bits::to_bits(carry_bit_len, sum); + + let b1 = false; + let b2 = true; + let res = if sum_bit[LEN] { b1 } else { b2 }; + + */ + + let modulus_bits: usize = B::Field::modulus_biguint() + .bits() + .try_into() + .expect("can't determine the number of bits in the modulus"); + + let bitlen_upper_bound = modulus_bits - 2; + let bit_len = bitlen.unwrap_or(bitlen_upper_bound); + + assert!(bit_len <= (bitlen_upper_bound)); + + let carry_bit_len = bit_len + 1; + + // let pow2 = (1 << bit_len) as u32; + // let pow2 = B::Field::from(pow2); + let two = B::Field::from(2u32); + let pow2 = two.pow([bit_len as u64]); + + // let pow2_lhs = compiler.backend.add_const(lhs, &pow2, span); + match (lhs, rhs) { + (ConstOrCell::Const(lhs), ConstOrCell::Const(rhs)) => { + let res = if lhs < rhs { one } else { zero }; + + Var::new_constant(res, span) + } + (_, _) => { + let pow2_lhs = match lhs { + // todo: we really should refactor the backend to handle ConstOrCell + ConstOrCell::Const(lhs) => { + compiler + .backend + .add_constant(Some("wrap a constant as var"), *lhs + pow2, span) + } + ConstOrCell::Cell(lhs) => compiler.backend.add_const(lhs, &pow2, span), + }; + + let rhs = match rhs { + ConstOrCell::Const(rhs) => { + compiler + .backend + .add_constant(Some("wrap a constant as var"), *rhs, span) + } + ConstOrCell::Cell(rhs) => rhs.clone(), + }; + + let sum = compiler.backend.sub(&pow2_lhs, &rhs, span); + + // todo: this api call is kind of weird here, maybe these bulitin shouldn't get inputs from the `GenericParameters` + let generic_var_name = "LEN".to_string(); + let mut gens = GenericParameters::default(); + gens.add(generic_var_name.clone()); + gens.assign(&generic_var_name, carry_bit_len as u32, span) + .unwrap(); + + // construct var info for sum + let cbl_var = Var::new_constant(B::Field::from(carry_bit_len as u32), span); + let cbl_var = VarInfo::new(cbl_var, false, Some(TyKind::Field { constant: true })); + + let sum_var = Var::new_var(sum, span); + let sum_var = VarInfo::new(sum_var, false, Some(TyKind::Field { constant: false })); + + let sum_bits = to_bits(compiler, &gens, &[cbl_var, sum_var], span) + .unwrap() + .unwrap(); + // convert to cell vars + let sum_bits: Vec<_> = sum_bits.cvars.into_iter().collect(); + + // if sum_bits[LEN] == 0, then lhs < rhs + let res = &is_zero_cell(compiler, &sum_bits[bit_len], span)[0]; + let res = res.cvar().unwrap(); + Var::new_var(res.clone(), span) + } + } +} + /// Returns 1 if var is zero, 0 otherwise fn is_zero_cell( compiler: &mut CircuitWriter, diff --git a/src/lexer/mod.rs b/src/lexer/mod.rs index 89784b213..9f3435700 100644 --- a/src/lexer/mod.rs +++ b/src/lexer/mod.rs @@ -143,10 +143,13 @@ pub enum TokenKind { Minus, // - RightArrow, // -> Star, // * + Percent, // % Ampersand, // & DoubleAmpersand, // && Pipe, // | DoublePipe, // || + LeftShift, // << + RightShift, // >> Exclamation, // ! Question, // ? // Literal, // "thing" @@ -184,12 +187,15 @@ impl Display for TokenKind { Minus => "`-`", RightArrow => "`->`", Star => "`*`", + Percent => "%", Ampersand => "`&`", DoubleAmpersand => "`&&`", Pipe => "`|`", DoublePipe => "`||`", Exclamation => "`!`", Question => "`?`", + LeftShift => "`<<`", + RightShift => "`>>`", // TokenType::Literal => "`\"something\"", }; @@ -363,10 +369,22 @@ impl Token { } } '>' => { - tokens.push(TokenKind::Greater.new_token(ctx, 1)); + let next_c = chars.peek(); + if matches!(next_c, Some(&'>')) { + tokens.push(TokenKind::RightShift.new_token(ctx, 2)); + chars.next(); + } else { + tokens.push(TokenKind::Greater.new_token(ctx, 1)); + } } '<' => { - tokens.push(TokenKind::Less.new_token(ctx, 1)); + let next_c = chars.peek(); + if matches!(next_c, Some(&'<')) { + tokens.push(TokenKind::LeftShift.new_token(ctx, 2)); + chars.next(); + } else { + tokens.push(TokenKind::Less.new_token(ctx, 1)); + } } '=' => { let next_c = chars.peek(); @@ -392,6 +410,9 @@ impl Token { '*' => { tokens.push(TokenKind::Star.new_token(ctx, 1)); } + '%' => { + tokens.push(TokenKind::Percent.new_token(ctx, 1)); + } '&' => { let next_c = chars.peek(); if matches!(next_c, Some(&'&')) { diff --git a/src/mast/mod.rs b/src/mast/mod.rs index 60039c45d..4933215e1 100644 --- a/src/mast/mod.rs +++ b/src/mast/mod.rs @@ -549,14 +549,15 @@ fn monomorphize_expr( let rhs_mono = monomorphize_expr(ctx, rhs, mono_fn_env)?; let typ = match op { - Op2::Equality => Some(TyKind::Bool), - Op2::Inequality => Some(TyKind::Bool), + Op2::Equality | Op2::Inequality | Op2::LessThan => Some(TyKind::Bool), Op2::Addition | Op2::Subtraction | Op2::Multiplication | Op2::Division + | Op2::Modulus | Op2::BoolAnd - | Op2::BoolOr => lhs_mono.typ, + | Op2::BoolOr + | Op2::LeftShift => lhs_mono.clone().typ, }; let ExprMonoInfo { expr: lhs_expr, .. } = lhs_mono; diff --git a/src/parser/expr.rs b/src/parser/expr.rs index 9fed71992..eca42c91a 100644 --- a/src/parser/expr.rs +++ b/src/parser/expr.rs @@ -129,8 +129,11 @@ pub enum Op2 { Subtraction, Multiplication, Division, + Modulus, Equality, Inequality, + LessThan, + LeftShift, BoolAnd, BoolOr, } @@ -436,11 +439,14 @@ impl Expr { | TokenKind::Minus | TokenKind::Star | TokenKind::Slash + | TokenKind::Percent | TokenKind::DoubleEqual | TokenKind::NotEqual + | TokenKind::Less | TokenKind::DoubleAmpersand | TokenKind::DoublePipe - | TokenKind::Exclamation, + | TokenKind::Exclamation + | TokenKind::LeftShift, .. }) => { // lhs + rhs @@ -450,8 +456,11 @@ impl Expr { TokenKind::Minus => Op2::Subtraction, TokenKind::Star => Op2::Multiplication, TokenKind::Slash => Op2::Division, + TokenKind::Percent => Op2::Modulus, TokenKind::DoubleEqual => Op2::Equality, TokenKind::NotEqual => Op2::Inequality, + TokenKind::Less => Op2::LessThan, + TokenKind::LeftShift => Op2::LeftShift, TokenKind::DoubleAmpersand => Op2::BoolAnd, TokenKind::DoublePipe => Op2::BoolOr, _ => unreachable!(), diff --git a/src/stdlib/bits.rs b/src/stdlib/bits.rs index 3fe256ee2..d29258a90 100644 --- a/src/stdlib/bits.rs +++ b/src/stdlib/bits.rs @@ -28,7 +28,7 @@ impl Module for BitsLib { } } -fn to_bits( +pub fn to_bits( compiler: &mut CircuitWriter, generics: &GenericParameters, vars: &[VarInfo], diff --git a/src/tests/examples.rs b/src/tests/examples.rs index 3aade9de0..95ce5fa11 100644 --- a/src/tests/examples.rs +++ b/src/tests/examples.rs @@ -309,6 +309,18 @@ fn test_not_equal(#[case] backend: BackendKind) -> miette::Result<()> { Ok(()) } +#[rstest] +#[case::kimchi_vesta(BackendKind::KimchiVesta(KimchiVesta::new(false)))] +#[case::r1cs(BackendKind::R1csBls12_381(R1CS::new()))] +fn test_comparator(#[case] backend: BackendKind) -> miette::Result<()> { + let public_inputs = r#"{"xx": "1"}"#; + let private_inputs = r#"{"yy": "2"}"#; + + test_file("comparator", public_inputs, private_inputs, vec![], backend)?; + + Ok(()) +} + #[rstest] #[case::kimchi_vesta(BackendKind::KimchiVesta(KimchiVesta::new(false)))] #[case::r1cs(BackendKind::R1csBls12_381(R1CS::new()))] diff --git a/src/type_checker/checker.rs b/src/type_checker/checker.rs index 7046d8141..281985ecd 100644 --- a/src/type_checker/checker.rs +++ b/src/type_checker/checker.rs @@ -302,14 +302,15 @@ impl TypeChecker { } let typ = match op { - Op2::Equality => TyKind::Bool, - Op2::Inequality => TyKind::Bool, + Op2::Equality | Op2::Inequality | Op2::LessThan => TyKind::Bool, Op2::Addition | Op2::Subtraction | Op2::Multiplication | Op2::Division + | Op2::Modulus | Op2::BoolAnd - | Op2::BoolOr => lhs_node.typ, + | Op2::BoolOr + | Op2::LeftShift => lhs_node.typ, }; Some(ExprTyInfo::new_anon(typ)) diff --git a/src/var.rs b/src/var.rs index e183a7085..fcee95943 100644 --- a/src/var.rs +++ b/src/var.rs @@ -57,6 +57,20 @@ where // but does this make sense to all different backends? is it possible that some backend doesn't allow certain out of circuit calculations like this? NthBit(B::Var, usize), + /// Divide + // todo: refactor to use a argument wrapper to encapsulate its own type, + // so that a variant can have an argument to be either B::Var or B::Field + CstDivVar(B::Field, B::Var), + VarDivCst(B::Var, B::Field), + VarDivVar(B::Var, B::Var), + CstDivCst(B::Field, B::Field), + + /// Modulo + VarModVar(B::Var, B::Var), + CstModVar(B::Field, B::Var), + VarModCst(B::Var, B::Field), + CstModCst(B::Field, B::Field), + /// A public or private input to the function /// There's an index associated to a variable name, as the variable could be composed of several field elements. External(String, usize), @@ -78,6 +92,14 @@ impl std::fmt::Debug for Value { Value::PublicOutput(..) => write!(f, "PublicOutput"), Value::Scale(..) => write!(f, "Scaling"), Value::NthBit(_, _) => write!(f, "NthBit"), + Value::CstDivVar(_, _) => write!(f, "CstDivVar"), + Value::VarDivCst(_, _) => write!(f, "VarDivCst"), + Value::VarDivVar(_, _) => write!(f, "VarDivVar"), + Value::CstDivCst(_, _) => write!(f, "CstDivCst"), + Value::VarModVar(_, _) => write!(f, "VarModVar"), + Value::CstModVar(_, _) => write!(f, "CstModVar"), + Value::VarModCst(_, _) => write!(f, "VarModCst"), + Value::CstModCst(_, _) => write!(f, "CstModCst"), } } }