From be3e073b673a7a266b583701acb8f5345060cf46 Mon Sep 17 00:00:00 2001 From: kata Date: Wed, 2 Oct 2024 15:43:47 +0800 Subject: [PATCH 1/4] support builtin operator < --- examples/comparator.no | 4 ++ src/circuit_writer/writer.rs | 3 + src/constraints/field.rs | 115 +++++++++++++++++++++++++++++++++++ src/mast/mod.rs | 5 +- src/parser/expr.rs | 3 + src/tests/examples.rs | 12 ++++ src/type_checker/checker.rs | 5 +- 7 files changed, 143 insertions(+), 4 deletions(-) create mode 100644 examples/comparator.no diff --git a/examples/comparator.no b/examples/comparator.no new file mode 100644 index 000000000..d508fbf53 --- /dev/null +++ b/examples/comparator.no @@ -0,0 +1,4 @@ +fn main(pub xx: Field, yy: Field) { + let res = xx < yy; + assert(res); +} diff --git a/src/circuit_writer/writer.rs b/src/circuit_writer/writer.rs index 924c57b6f..83b493150 100644 --- a/src/circuit_writer/writer.rs +++ b/src/circuit_writer/writer.rs @@ -621,6 +621,9 @@ impl CircuitWriter { Op2::Multiplication => field::mul(self, &lhs[0], &rhs[0], expr.span), Op2::Equality => field::equal(self, &lhs, &rhs, expr.span), Op2::Inequality => field::not_equal(self, &lhs, &rhs, expr.span), + // todo: refactor the input vars from Var to VarInfo, + // which contain the type to provide the info about the bit length + Op2::LessThan => field::less_than(self, None, &lhs[0], &rhs[0], expr.span), Op2::BoolAnd => boolean::and(self, &lhs[0], &rhs[0], expr.span), Op2::BoolOr => boolean::or(self, &lhs[0], &rhs[0], expr.span), Op2::Division => todo!(), diff --git a/src/constraints/field.rs b/src/constraints/field.rs index 38e120117..1d2f582ed 100644 --- a/src/constraints/field.rs +++ b/src/constraints/field.rs @@ -254,6 +254,121 @@ pub fn not_equal( acc } +/// Returns 1 if lhs < rhs, 0 otherwise +pub fn less_than( + compiler: &mut CircuitWriter, + bitlen: Option, + lhs: &ConstOrCell, + rhs: &ConstOrCell, + span: Span, +) -> Var { + let one = B::Field::one(); + let zero = B::Field::zero(); + + // Instead of comparing bit by bit, we check the carry bit: + // lhs + (1 << LEN) - rhs + // proof: + // lhs + (1 << LEN) will add a carry bit, valued 1, to the bit array representing lhs, + // resulted in a bit array of length LEN + 1, named as sum_bits. + // if `lhs < rhs``, then `lhs - rhs < 0`, thus `(1 << LEN) + lhs - rhs < (1 << LEN)` + // then, the carry bit of sum_bits is 0. + // otherwise, the carry bit of sum_bits is 1. + + /* + psuedo code: + let carry_bit_len = LEN + 1; + + # 1 << LEN + let mut pow2 = 1; + for ii in 0..LEN { + pow2 = pow2 + pow2; + } + + let sum = (pow2 + lhs) - rhs; + let sum_bit = bits::to_bits(carry_bit_len, sum); + + let b1 = false; + let b2 = true; + let res = if sum_bit[LEN] { b1 } else { b2 }; + + */ + + let modulus_bits: usize = B::Field::modulus_biguint() + .bits() + .try_into() + .expect("can't determine the number of bits in the modulus"); + + let bitlen_upper_bound = modulus_bits - 2; + let bit_len = bitlen.unwrap_or(bitlen_upper_bound); + + assert!(bit_len <= (bitlen_upper_bound)); + + + let carry_bit_len = bit_len + 1; + + + // let pow2 = (1 << bit_len) as u32; + // let pow2 = B::Field::from(pow2); + let two = B::Field::from(2u32); + let pow2 = two.pow([bit_len as u64]); + + // let pow2_lhs = compiler.backend.add_const(lhs, &pow2, span); + match (lhs, rhs) { + (ConstOrCell::Const(lhs), ConstOrCell::Const(rhs)) => { + let res = if lhs < rhs { one } else { zero }; + + Var::new_constant(res, span) + } + (_, _) => { + let pow2_lhs = match lhs { + // todo: we really should refactor the backend to handle ConstOrCell + ConstOrCell::Const(lhs) => compiler.backend.add_constant( + Some("wrap a constant as var"), + *lhs + pow2, + span, + ), + ConstOrCell::Cell(lhs) => compiler.backend.add_const(lhs, &pow2, span), + }; + + let rhs = match rhs { + ConstOrCell::Const(rhs) => compiler.backend.add_constant( + Some("wrap a constant as var"), + *rhs, + span, + ), + ConstOrCell::Cell(rhs) => rhs.clone(), + }; + + let sum = compiler.backend.sub(&pow2_lhs, &rhs, span); + + // todo: this api call is kind of weird here, maybe these bulitin shouldn't get inputs from the `GenericParameters` + let generic_var_name = "LEN".to_string(); + let mut gens = GenericParameters::default(); + gens.add(generic_var_name.clone()); + gens.assign(&generic_var_name, carry_bit_len as u32, span) + .unwrap(); + + // construct var info for sum + let cbl_var = Var::new_constant(B::Field::from(carry_bit_len as u32), span); + let cbl_var = VarInfo::new(cbl_var, false, Some(TyKind::Field { constant: true })); + + let sum_var = Var::new_var(sum, span); + let sum_var = VarInfo::new(sum_var, false, Some(TyKind::Field { constant: false })); + + let sum_bits = to_bits(compiler, &gens, &[cbl_var, sum_var], span).unwrap().unwrap(); + // convert to cell vars + let sum_bits: Vec<_> = sum_bits.cvars.into_iter().collect(); + + // if sum_bits[LEN] == 0, then lhs < rhs + let res = &is_zero_cell(compiler, &sum_bits[bit_len], span)[0]; + let res = res + .cvar() + .unwrap(); + Var::new_var(res.clone(), span) + } + } +} + /// Returns 1 if var is zero, 0 otherwise fn is_zero_cell( compiler: &mut CircuitWriter, diff --git a/src/mast/mod.rs b/src/mast/mod.rs index 60039c45d..3159fd452 100644 --- a/src/mast/mod.rs +++ b/src/mast/mod.rs @@ -549,8 +549,9 @@ fn monomorphize_expr( let rhs_mono = monomorphize_expr(ctx, rhs, mono_fn_env)?; let typ = match op { - Op2::Equality => Some(TyKind::Bool), - Op2::Inequality => Some(TyKind::Bool), + Op2::Equality + | Op2::Inequality + | Op2::LessThan => Some(TyKind::Bool), Op2::Addition | Op2::Subtraction | Op2::Multiplication diff --git a/src/parser/expr.rs b/src/parser/expr.rs index 9fed71992..74144a4f7 100644 --- a/src/parser/expr.rs +++ b/src/parser/expr.rs @@ -131,6 +131,7 @@ pub enum Op2 { Division, Equality, Inequality, + LessThan, BoolAnd, BoolOr, } @@ -438,6 +439,7 @@ impl Expr { | TokenKind::Slash | TokenKind::DoubleEqual | TokenKind::NotEqual + | TokenKind::Less | TokenKind::DoubleAmpersand | TokenKind::DoublePipe | TokenKind::Exclamation, @@ -452,6 +454,7 @@ impl Expr { TokenKind::Slash => Op2::Division, TokenKind::DoubleEqual => Op2::Equality, TokenKind::NotEqual => Op2::Inequality, + TokenKind::Less => Op2::LessThan, TokenKind::DoubleAmpersand => Op2::BoolAnd, TokenKind::DoublePipe => Op2::BoolOr, _ => unreachable!(), diff --git a/src/tests/examples.rs b/src/tests/examples.rs index 3aade9de0..95ce5fa11 100644 --- a/src/tests/examples.rs +++ b/src/tests/examples.rs @@ -309,6 +309,18 @@ fn test_not_equal(#[case] backend: BackendKind) -> miette::Result<()> { Ok(()) } +#[rstest] +#[case::kimchi_vesta(BackendKind::KimchiVesta(KimchiVesta::new(false)))] +#[case::r1cs(BackendKind::R1csBls12_381(R1CS::new()))] +fn test_comparator(#[case] backend: BackendKind) -> miette::Result<()> { + let public_inputs = r#"{"xx": "1"}"#; + let private_inputs = r#"{"yy": "2"}"#; + + test_file("comparator", public_inputs, private_inputs, vec![], backend)?; + + Ok(()) +} + #[rstest] #[case::kimchi_vesta(BackendKind::KimchiVesta(KimchiVesta::new(false)))] #[case::r1cs(BackendKind::R1csBls12_381(R1CS::new()))] diff --git a/src/type_checker/checker.rs b/src/type_checker/checker.rs index 7046d8141..e7f437ee0 100644 --- a/src/type_checker/checker.rs +++ b/src/type_checker/checker.rs @@ -302,8 +302,9 @@ impl TypeChecker { } let typ = match op { - Op2::Equality => TyKind::Bool, - Op2::Inequality => TyKind::Bool, + Op2::Equality + | Op2::Inequality + | Op2::LessThan => TyKind::Bool, Op2::Addition | Op2::Subtraction | Op2::Multiplication From d2870e40acee3ccf7a5636b23897900bfbf44371 Mon Sep 17 00:00:00 2001 From: kata Date: Wed, 2 Oct 2024 16:26:28 +0800 Subject: [PATCH 2/4] support builtin operator / % --- examples/arithmetic.no | 17 +++++ src/backends/mod.rs | 64 ++++++++++++++++ src/circuit_writer/writer.rs | 3 +- src/constraints/field.rs | 144 ++++++++++++++++++++++++++++++++++- src/lexer/mod.rs | 5 ++ src/mast/mod.rs | 1 + src/parser/expr.rs | 3 + src/type_checker/checker.rs | 1 + src/var.rs | 22 ++++++ 9 files changed, 257 insertions(+), 3 deletions(-) diff --git a/examples/arithmetic.no b/examples/arithmetic.no index ea3073c60..6b0c673f3 100644 --- a/examples/arithmetic.no +++ b/examples/arithmetic.no @@ -2,4 +2,21 @@ fn main(pub public_input: Field, private_input: Field) { let xx = private_input + public_input; let yy = private_input * public_input; assert_eq(xx, yy); + + // modulus constant + let mod_res = (xx + yy) % 3; + assert_eq(mod_res, 2); + + // modulus var + let mod_var_res = (xx + yy) % (private_input + 1); + assert_eq(mod_var_res, 2); + + // divide constant + let div_res = xx / 2; + assert_eq(div_res, 2); + + // divide var + let div_var_res = xx / private_input; + assert_eq(div_res, 2); + } diff --git a/src/backends/mod.rs b/src/backends/mod.rs index fabfa2b83..0fa3368e9 100644 --- a/src/backends/mod.rs +++ b/src/backends/mod.rs @@ -198,6 +198,70 @@ pub trait Backend: Clone { Ok(res) } + Value::VarDivVar(dividend, divisor) => { + let dividend = self.compute_var(env, dividend)?.to_biguint(); + let divisor = self.compute_var(env, divisor)?.to_biguint(); + let res = Self::Field::from(dividend / divisor); + + env.cached_values.insert(cache_key, res); + Ok(res) + } + Value::CstDivVar(dividend, divisor) => { + let divisor = self.compute_var(env, divisor)?; + let res = *dividend / divisor; + Ok(res) + } + Value::VarDivCst(dividend, divisor) => { + let dividend = self.compute_var(env, dividend)?; + // convert to bigint + let dividend = dividend.to_biguint(); + let divisor = divisor.to_biguint(); + + let res = Self::Field::from(dividend / divisor); + env.cached_values.insert(cache_key, res); + Ok(res) + } + Value::CstDivCst(dividend, divisor) => { + let res = *dividend / *divisor; + env.cached_values.insert(cache_key, res); + Ok(res) + } + Value::VarModVar(dividend, divisor) => { + let dividend = self.compute_var(env, dividend)?; + let divisor = self.compute_var(env, divisor)?; + // convert to bigint + let dividend = dividend.to_biguint(); + let divisor = divisor.to_biguint(); + let res = Self::Field::from(dividend % divisor); + + env.cached_values.insert(cache_key, res); + Ok(res) + } + Value::CstModVar(dividend, divisor) => { + let divisor = self.compute_var(env, divisor)?; + // convert to bigint + let dividend = dividend.to_biguint(); + let divisor = divisor.to_biguint(); + let res = Self::Field::from(dividend % divisor); + env.cached_values.insert(cache_key, res); + Ok(res) + } + Value::VarModCst(dividend, divisor) => { + let dividend = self.compute_var(env, dividend)?; + // convert to bigint + let dividend = dividend.to_biguint(); + let divisor = divisor.to_biguint(); + let res = Self::Field::from(dividend % divisor); + env.cached_values.insert(cache_key, res); + Ok(res) + } + Value::CstModCst(dividend, divisor) => { + let dividend = dividend.to_biguint(); + let divisor = divisor.to_biguint(); + let res = Self::Field::from(dividend % divisor); + env.cached_values.insert(cache_key, res); + Ok(res) + } } } diff --git a/src/circuit_writer/writer.rs b/src/circuit_writer/writer.rs index 83b493150..c8b69565d 100644 --- a/src/circuit_writer/writer.rs +++ b/src/circuit_writer/writer.rs @@ -619,6 +619,8 @@ impl CircuitWriter { Op2::Addition => field::add(self, &lhs[0], &rhs[0], expr.span), Op2::Subtraction => field::sub(self, &lhs[0], &rhs[0], expr.span), Op2::Multiplication => field::mul(self, &lhs[0], &rhs[0], expr.span), + Op2::Modulus => field::modulus(self, &lhs[0], &rhs[0], expr.span), + Op2::Division => field::div(self, &lhs[0], &rhs[0], expr.span), Op2::Equality => field::equal(self, &lhs, &rhs, expr.span), Op2::Inequality => field::not_equal(self, &lhs, &rhs, expr.span), // todo: refactor the input vars from Var to VarInfo, @@ -626,7 +628,6 @@ impl CircuitWriter { Op2::LessThan => field::less_than(self, None, &lhs[0], &rhs[0], expr.span), Op2::BoolAnd => boolean::and(self, &lhs[0], &rhs[0], expr.span), Op2::BoolOr => boolean::or(self, &lhs[0], &rhs[0], expr.span), - Op2::Division => todo!(), }; Ok(Some(VarOrRef::Var(res))) diff --git a/src/constraints/field.rs b/src/constraints/field.rs index 1d2f582ed..8807114fa 100644 --- a/src/constraints/field.rs +++ b/src/constraints/field.rs @@ -1,13 +1,16 @@ use crate::{ backends::Backend, - circuit_writer::CircuitWriter, + circuit_writer::{CircuitWriter, VarInfo}, constants::Span, + parser::types::{GenericParameters, TyKind}, + stdlib::bits::to_bits, var::{ConstOrCell, Value, Var}, }; use super::boolean; -use ark_ff::{One, Zero}; +use ark_ff::{Field, One, PrimeField, Zero}; +use kimchi::o1_utils::FieldHelpers; use std::ops::Neg; @@ -99,6 +102,143 @@ pub fn mul( } } +fn constrain_div_mod( + compiler: &mut CircuitWriter, + lhs: &ConstOrCell, + rhs: &ConstOrCell, + span: Span, +) -> (B::Var, B::Var) { + // to constrain lhs − q * rhs − rem = 0 + // where rhs is the modulus + // so 0 <= rem < rhs + + let one = B::Field::one(); + + // todo: to avoid duplicating a lot of code due the different combinations of the input types + // until we refactor the backend to handle ConstOrCell or some kind of wrapper that encapsulate the different variable types + // convert cst to var for easier handling + let lhs = match lhs { + ConstOrCell::Const(lhs) => compiler.backend.add_constant( + Some("wrap a constant as var"), + *lhs, + span, + ), + ConstOrCell::Cell(lhs) => lhs.clone(), + }; + + let rhs = match rhs { + ConstOrCell::Const(rhs) => compiler.backend.add_constant( + Some("wrap a constant as var"), + *rhs, + span, + ), + ConstOrCell::Cell(rhs) => rhs.clone(), + }; + + // witness var for quotient + let q = Value::VarDivVar(lhs.clone(), rhs.clone()); + let q_var = compiler.backend.new_internal_var(q, span); + + // witness var for remainder + let rem = Value::VarModVar(lhs.clone(), rhs.clone()); + let rem_var = compiler.backend.new_internal_var(rem, span); + + // rem < rhs + let lt_rem = &less_than(compiler, None, &ConstOrCell::Cell(rem_var.clone()), &ConstOrCell::Cell(rhs.clone()), span)[0]; + let lt_rem = lt_rem.cvar().expect("expected a cell var"); + compiler.backend.assert_eq_const(lt_rem, one, span); + + // foundamental constraint: lhs - q * rhs - rem = 0 + let q_mul_rhs = compiler.backend.mul(&q_var, &rhs, span); + let lhs_sub_q_mul_rhs = compiler.backend.sub(&lhs, &q_mul_rhs, span); + + // cell representing the foundamental constraint + let fc_var = compiler.backend.sub(&lhs_sub_q_mul_rhs, &rem_var, span); + compiler.backend.assert_eq_const(&fc_var, B::Field::zero(), span); + + (rem_var, q_var) +} + +/// Divide operation +pub fn div( + compiler: &mut CircuitWriter, + lhs: &ConstOrCell, + rhs: &ConstOrCell, + span: Span, +) -> Var { + // to constrain lhs − q * rhs − rem = 0 + // rhs can't be zero + match rhs { + ConstOrCell::Const(rhs) => { + if rhs.is_zero() { + panic!("division by zero"); + } + } + _ => { + let is_zero = is_zero_cell(compiler, rhs, span); + let is_zero = is_zero[0].cvar().unwrap(); + compiler.backend.assert_eq_const(is_zero, B::Field::zero(), span); + } + }; + + match (lhs, rhs) { + // if rhs is a constant, we can just divide lhs by rhs + (ConstOrCell::Const(lhs), ConstOrCell::Const(rhs)) => { + // to bigint + let lhs = lhs.to_biguint(); + let rhs = rhs.to_biguint(); + let res = lhs / rhs; + + Var::new_constant(B::Field::from(res), span) + } + _ => { + let (_, q) = constrain_div_mod(compiler, lhs, rhs, span); + Var::new_var(q, span) + }, + } +} + +/// Modulus operation +pub fn modulus( + compiler: &mut CircuitWriter, + lhs: &ConstOrCell, + rhs: &ConstOrCell, + span: Span, +) -> Var { + // to constrain lhs − q * rhs − rem = 0 + + let zero = B::Field::zero(); + + // rhs can't be zero + match &rhs { + ConstOrCell::Const(rhs) => { + if rhs.is_zero() { + panic!("modulus by zero"); + } + } + _ => { + let is_zero = is_zero_cell(compiler, rhs, span); + let is_zero = is_zero[0].cvar().unwrap(); + compiler.backend.assert_eq_const(is_zero, zero, span); + } + }; + + match (lhs, rhs) { + // if rhs is a constant, we can just divide lhs by rhs + (ConstOrCell::Const(lhs), ConstOrCell::Const(rhs)) => { + let lhs = lhs.to_biguint(); + let rhs = rhs.to_biguint(); + let res = lhs % rhs; + + Var::new_constant(res.into(), span) + } + _ => { + let (rem, _) = constrain_div_mod(compiler, lhs, rhs, span); + Var::new_var(rem, span) + } + } +} + /// This takes variables that can be anything, and returns a boolean // TODO: so perhaps it's not really relevant in this file? pub fn equal( diff --git a/src/lexer/mod.rs b/src/lexer/mod.rs index 89784b213..75aae065b 100644 --- a/src/lexer/mod.rs +++ b/src/lexer/mod.rs @@ -143,6 +143,7 @@ pub enum TokenKind { Minus, // - RightArrow, // -> Star, // * + Percent, // % Ampersand, // & DoubleAmpersand, // && Pipe, // | @@ -184,6 +185,7 @@ impl Display for TokenKind { Minus => "`-`", RightArrow => "`->`", Star => "`*`", + Percent => "%", Ampersand => "`&`", DoubleAmpersand => "`&&`", Pipe => "`|`", @@ -392,6 +394,9 @@ impl Token { '*' => { tokens.push(TokenKind::Star.new_token(ctx, 1)); } + '%' => { + tokens.push(TokenKind::Percent.new_token(ctx, 1)); + } '&' => { let next_c = chars.peek(); if matches!(next_c, Some(&'&')) { diff --git a/src/mast/mod.rs b/src/mast/mod.rs index 3159fd452..3d658b5fd 100644 --- a/src/mast/mod.rs +++ b/src/mast/mod.rs @@ -556,6 +556,7 @@ fn monomorphize_expr( | Op2::Subtraction | Op2::Multiplication | Op2::Division + | Op2::Modulus | Op2::BoolAnd | Op2::BoolOr => lhs_mono.typ, }; diff --git a/src/parser/expr.rs b/src/parser/expr.rs index 74144a4f7..63a3c853e 100644 --- a/src/parser/expr.rs +++ b/src/parser/expr.rs @@ -129,6 +129,7 @@ pub enum Op2 { Subtraction, Multiplication, Division, + Modulus, Equality, Inequality, LessThan, @@ -437,6 +438,7 @@ impl Expr { | TokenKind::Minus | TokenKind::Star | TokenKind::Slash + | TokenKind::Percent | TokenKind::DoubleEqual | TokenKind::NotEqual | TokenKind::Less @@ -452,6 +454,7 @@ impl Expr { TokenKind::Minus => Op2::Subtraction, TokenKind::Star => Op2::Multiplication, TokenKind::Slash => Op2::Division, + TokenKind::Percent => Op2::Modulus, TokenKind::DoubleEqual => Op2::Equality, TokenKind::NotEqual => Op2::Inequality, TokenKind::Less => Op2::LessThan, diff --git a/src/type_checker/checker.rs b/src/type_checker/checker.rs index e7f437ee0..8e46e34b4 100644 --- a/src/type_checker/checker.rs +++ b/src/type_checker/checker.rs @@ -309,6 +309,7 @@ impl TypeChecker { | Op2::Subtraction | Op2::Multiplication | Op2::Division + | Op2::Modulus | Op2::BoolAnd | Op2::BoolOr => lhs_node.typ, }; diff --git a/src/var.rs b/src/var.rs index e183a7085..e1ea00076 100644 --- a/src/var.rs +++ b/src/var.rs @@ -57,6 +57,20 @@ where // but does this make sense to all different backends? is it possible that some backend doesn't allow certain out of circuit calculations like this? NthBit(B::Var, usize), + /// Divide + // todo: refactor to use a argument wrapper to encapsulate its own type, + // so that a variant can have an argument to be either B::Var or B::Field + CstDivVar(B::Field, B::Var), + VarDivCst(B::Var, B::Field), + VarDivVar(B::Var, B::Var), + CstDivCst(B::Field, B::Field), + + /// Modulo + VarModVar(B::Var, B::Var), + CstModVar(B::Field, B::Var), + VarModCst(B::Var, B::Field), + CstModCst(B::Field, B::Field), + /// A public or private input to the function /// There's an index associated to a variable name, as the variable could be composed of several field elements. External(String, usize), @@ -78,6 +92,14 @@ impl std::fmt::Debug for Value { Value::PublicOutput(..) => write!(f, "PublicOutput"), Value::Scale(..) => write!(f, "Scaling"), Value::NthBit(_, _) => write!(f, "NthBit"), + Value::CstDivVar(_, _) => write!(f, "CstDivVar"), + Value::VarDivCst(_, _) => write!(f, "VarDivCst"), + Value::VarDivVar(_, _) => write!(f, "VarDivVar"), + Value::CstDivCst(_, _) => write!(f, "CstDivCst"), + Value::VarModVar(_, _) => write!(f, "VarModVar"), + Value::CstModVar(_, _) => write!(f, "CstModVar"), + Value::VarModCst(_, _) => write!(f, "VarModCst"), + Value::CstModCst(_, _) => write!(f, "CstModCst"), } } } From 5b955003d384ee7e1e2b34315dbe70081b7efb67 Mon Sep 17 00:00:00 2001 From: kata Date: Wed, 2 Oct 2024 16:26:41 +0800 Subject: [PATCH 3/4] support builtin operator << --- examples/arithmetic.no | 3 +++ src/circuit_writer/writer.rs | 1 + src/constraints/field.rs | 27 +++++++++++++++++++++++++++ src/lexer/mod.rs | 20 ++++++++++++++++++-- src/mast/mod.rs | 3 ++- src/parser/expr.rs | 5 ++++- src/stdlib/bits.rs | 2 +- src/type_checker/checker.rs | 3 ++- 8 files changed, 58 insertions(+), 6 deletions(-) diff --git a/examples/arithmetic.no b/examples/arithmetic.no index 6b0c673f3..8dc3bdca0 100644 --- a/examples/arithmetic.no +++ b/examples/arithmetic.no @@ -19,4 +19,7 @@ fn main(pub public_input: Field, private_input: Field) { let div_var_res = xx / private_input; assert_eq(div_res, 2); + // left shift + let lf_res = xx << 2; + assert_eq(lf_res, 16); } diff --git a/src/circuit_writer/writer.rs b/src/circuit_writer/writer.rs index c8b69565d..116698a25 100644 --- a/src/circuit_writer/writer.rs +++ b/src/circuit_writer/writer.rs @@ -626,6 +626,7 @@ impl CircuitWriter { // todo: refactor the input vars from Var to VarInfo, // which contain the type to provide the info about the bit length Op2::LessThan => field::less_than(self, None, &lhs[0], &rhs[0], expr.span), + Op2::LeftShift => field::left_shift(self, &lhs[0], &rhs[0], expr.span), Op2::BoolAnd => boolean::and(self, &lhs[0], &rhs[0], expr.span), Op2::BoolOr => boolean::or(self, &lhs[0], &rhs[0], expr.span), }; diff --git a/src/constraints/field.rs b/src/constraints/field.rs index 8807114fa..05017d9dd 100644 --- a/src/constraints/field.rs +++ b/src/constraints/field.rs @@ -239,6 +239,33 @@ pub fn modulus( } } +/// Left shift operation +pub fn left_shift( + compiler: &mut CircuitWriter, + lhs: &ConstOrCell, + rhs: &ConstOrCell, + span: Span, +) -> Var { + // to constrain lhs * (1 << rhs) = res + + match (lhs, rhs) { + (ConstOrCell::Const(lhs), ConstOrCell::Const(rhs)) => { + // convert to bigint + let pow2 = B::Field::from(2u32).pow(rhs.into_repr()); + let res = *lhs * pow2; + Var::new_constant(res, span) + } + (ConstOrCell::Cell(lhs_), ConstOrCell::Const(rhs_)) => { + let pow2 = B::Field::from(2u32).pow(rhs_.into_repr()); + let res = compiler.backend.mul_const(lhs_, &pow2, span); + Var::new_var(res, span) + } + // todo: wrap rhs in a symbolic value + (ConstOrCell::Const(_), ConstOrCell::Cell(_)) => todo!(), + (ConstOrCell::Cell(_), ConstOrCell::Cell(_)) => todo!(), + } +} + /// This takes variables that can be anything, and returns a boolean // TODO: so perhaps it's not really relevant in this file? pub fn equal( diff --git a/src/lexer/mod.rs b/src/lexer/mod.rs index 75aae065b..9f3435700 100644 --- a/src/lexer/mod.rs +++ b/src/lexer/mod.rs @@ -148,6 +148,8 @@ pub enum TokenKind { DoubleAmpersand, // && Pipe, // | DoublePipe, // || + LeftShift, // << + RightShift, // >> Exclamation, // ! Question, // ? // Literal, // "thing" @@ -192,6 +194,8 @@ impl Display for TokenKind { DoublePipe => "`||`", Exclamation => "`!`", Question => "`?`", + LeftShift => "`<<`", + RightShift => "`>>`", // TokenType::Literal => "`\"something\"", }; @@ -365,10 +369,22 @@ impl Token { } } '>' => { - tokens.push(TokenKind::Greater.new_token(ctx, 1)); + let next_c = chars.peek(); + if matches!(next_c, Some(&'>')) { + tokens.push(TokenKind::RightShift.new_token(ctx, 2)); + chars.next(); + } else { + tokens.push(TokenKind::Greater.new_token(ctx, 1)); + } } '<' => { - tokens.push(TokenKind::Less.new_token(ctx, 1)); + let next_c = chars.peek(); + if matches!(next_c, Some(&'<')) { + tokens.push(TokenKind::LeftShift.new_token(ctx, 2)); + chars.next(); + } else { + tokens.push(TokenKind::Less.new_token(ctx, 1)); + } } '=' => { let next_c = chars.peek(); diff --git a/src/mast/mod.rs b/src/mast/mod.rs index 3d658b5fd..9f568ee02 100644 --- a/src/mast/mod.rs +++ b/src/mast/mod.rs @@ -558,7 +558,8 @@ fn monomorphize_expr( | Op2::Division | Op2::Modulus | Op2::BoolAnd - | Op2::BoolOr => lhs_mono.typ, + | Op2::BoolOr + | Op2::LeftShift => lhs_mono.clone().typ, }; let ExprMonoInfo { expr: lhs_expr, .. } = lhs_mono; diff --git a/src/parser/expr.rs b/src/parser/expr.rs index 63a3c853e..eca42c91a 100644 --- a/src/parser/expr.rs +++ b/src/parser/expr.rs @@ -133,6 +133,7 @@ pub enum Op2 { Equality, Inequality, LessThan, + LeftShift, BoolAnd, BoolOr, } @@ -444,7 +445,8 @@ impl Expr { | TokenKind::Less | TokenKind::DoubleAmpersand | TokenKind::DoublePipe - | TokenKind::Exclamation, + | TokenKind::Exclamation + | TokenKind::LeftShift, .. }) => { // lhs + rhs @@ -458,6 +460,7 @@ impl Expr { TokenKind::DoubleEqual => Op2::Equality, TokenKind::NotEqual => Op2::Inequality, TokenKind::Less => Op2::LessThan, + TokenKind::LeftShift => Op2::LeftShift, TokenKind::DoubleAmpersand => Op2::BoolAnd, TokenKind::DoublePipe => Op2::BoolOr, _ => unreachable!(), diff --git a/src/stdlib/bits.rs b/src/stdlib/bits.rs index 3fe256ee2..d29258a90 100644 --- a/src/stdlib/bits.rs +++ b/src/stdlib/bits.rs @@ -28,7 +28,7 @@ impl Module for BitsLib { } } -fn to_bits( +pub fn to_bits( compiler: &mut CircuitWriter, generics: &GenericParameters, vars: &[VarInfo], diff --git a/src/type_checker/checker.rs b/src/type_checker/checker.rs index 8e46e34b4..ab355f4a3 100644 --- a/src/type_checker/checker.rs +++ b/src/type_checker/checker.rs @@ -311,7 +311,8 @@ impl TypeChecker { | Op2::Division | Op2::Modulus | Op2::BoolAnd - | Op2::BoolOr => lhs_node.typ, + | Op2::BoolOr + | Op2::LeftShift => lhs_node.typ, }; Some(ExprTyInfo::new_anon(typ)) From 726e31cf3001d66c75727c1ca0134ea12e189e97 Mon Sep 17 00:00:00 2001 From: kata Date: Wed, 2 Oct 2024 16:26:51 +0800 Subject: [PATCH 4/4] fmt --- src/circuit_writer/writer.rs | 2 +- src/constraints/field.rs | 68 ++++++++++++++++++++---------------- src/mast/mod.rs | 6 ++-- src/type_checker/checker.rs | 6 ++-- src/var.rs | 4 +-- 5 files changed, 45 insertions(+), 41 deletions(-) diff --git a/src/circuit_writer/writer.rs b/src/circuit_writer/writer.rs index 116698a25..183b713a6 100644 --- a/src/circuit_writer/writer.rs +++ b/src/circuit_writer/writer.rs @@ -623,7 +623,7 @@ impl CircuitWriter { Op2::Division => field::div(self, &lhs[0], &rhs[0], expr.span), Op2::Equality => field::equal(self, &lhs, &rhs, expr.span), Op2::Inequality => field::not_equal(self, &lhs, &rhs, expr.span), - // todo: refactor the input vars from Var to VarInfo, + // todo: refactor the input vars from Var to VarInfo, // which contain the type to provide the info about the bit length Op2::LessThan => field::less_than(self, None, &lhs[0], &rhs[0], expr.span), Op2::LeftShift => field::left_shift(self, &lhs[0], &rhs[0], expr.span), diff --git a/src/constraints/field.rs b/src/constraints/field.rs index 05017d9dd..5ee221572 100644 --- a/src/constraints/field.rs +++ b/src/constraints/field.rs @@ -118,20 +118,20 @@ fn constrain_div_mod( // until we refactor the backend to handle ConstOrCell or some kind of wrapper that encapsulate the different variable types // convert cst to var for easier handling let lhs = match lhs { - ConstOrCell::Const(lhs) => compiler.backend.add_constant( - Some("wrap a constant as var"), - *lhs, - span, - ), + ConstOrCell::Const(lhs) => { + compiler + .backend + .add_constant(Some("wrap a constant as var"), *lhs, span) + } ConstOrCell::Cell(lhs) => lhs.clone(), }; let rhs = match rhs { - ConstOrCell::Const(rhs) => compiler.backend.add_constant( - Some("wrap a constant as var"), - *rhs, - span, - ), + ConstOrCell::Const(rhs) => { + compiler + .backend + .add_constant(Some("wrap a constant as var"), *rhs, span) + } ConstOrCell::Cell(rhs) => rhs.clone(), }; @@ -144,7 +144,13 @@ fn constrain_div_mod( let rem_var = compiler.backend.new_internal_var(rem, span); // rem < rhs - let lt_rem = &less_than(compiler, None, &ConstOrCell::Cell(rem_var.clone()), &ConstOrCell::Cell(rhs.clone()), span)[0]; + let lt_rem = &less_than( + compiler, + None, + &ConstOrCell::Cell(rem_var.clone()), + &ConstOrCell::Cell(rhs.clone()), + span, + )[0]; let lt_rem = lt_rem.cvar().expect("expected a cell var"); compiler.backend.assert_eq_const(lt_rem, one, span); @@ -154,7 +160,9 @@ fn constrain_div_mod( // cell representing the foundamental constraint let fc_var = compiler.backend.sub(&lhs_sub_q_mul_rhs, &rem_var, span); - compiler.backend.assert_eq_const(&fc_var, B::Field::zero(), span); + compiler + .backend + .assert_eq_const(&fc_var, B::Field::zero(), span); (rem_var, q_var) } @@ -177,7 +185,9 @@ pub fn div( _ => { let is_zero = is_zero_cell(compiler, rhs, span); let is_zero = is_zero[0].cvar().unwrap(); - compiler.backend.assert_eq_const(is_zero, B::Field::zero(), span); + compiler + .backend + .assert_eq_const(is_zero, B::Field::zero(), span); } }; @@ -194,7 +204,7 @@ pub fn div( _ => { let (_, q) = constrain_div_mod(compiler, lhs, rhs, span); Var::new_var(q, span) - }, + } } } @@ -470,10 +480,8 @@ pub fn less_than( assert!(bit_len <= (bitlen_upper_bound)); - let carry_bit_len = bit_len + 1; - // let pow2 = (1 << bit_len) as u32; // let pow2 = B::Field::from(pow2); let two = B::Field::from(2u32); @@ -489,20 +497,20 @@ pub fn less_than( (_, _) => { let pow2_lhs = match lhs { // todo: we really should refactor the backend to handle ConstOrCell - ConstOrCell::Const(lhs) => compiler.backend.add_constant( - Some("wrap a constant as var"), - *lhs + pow2, - span, - ), + ConstOrCell::Const(lhs) => { + compiler + .backend + .add_constant(Some("wrap a constant as var"), *lhs + pow2, span) + } ConstOrCell::Cell(lhs) => compiler.backend.add_const(lhs, &pow2, span), }; let rhs = match rhs { - ConstOrCell::Const(rhs) => compiler.backend.add_constant( - Some("wrap a constant as var"), - *rhs, - span, - ), + ConstOrCell::Const(rhs) => { + compiler + .backend + .add_constant(Some("wrap a constant as var"), *rhs, span) + } ConstOrCell::Cell(rhs) => rhs.clone(), }; @@ -522,15 +530,15 @@ pub fn less_than( let sum_var = Var::new_var(sum, span); let sum_var = VarInfo::new(sum_var, false, Some(TyKind::Field { constant: false })); - let sum_bits = to_bits(compiler, &gens, &[cbl_var, sum_var], span).unwrap().unwrap(); + let sum_bits = to_bits(compiler, &gens, &[cbl_var, sum_var], span) + .unwrap() + .unwrap(); // convert to cell vars let sum_bits: Vec<_> = sum_bits.cvars.into_iter().collect(); // if sum_bits[LEN] == 0, then lhs < rhs let res = &is_zero_cell(compiler, &sum_bits[bit_len], span)[0]; - let res = res - .cvar() - .unwrap(); + let res = res.cvar().unwrap(); Var::new_var(res.clone(), span) } } diff --git a/src/mast/mod.rs b/src/mast/mod.rs index 9f568ee02..4933215e1 100644 --- a/src/mast/mod.rs +++ b/src/mast/mod.rs @@ -549,16 +549,14 @@ fn monomorphize_expr( let rhs_mono = monomorphize_expr(ctx, rhs, mono_fn_env)?; let typ = match op { - Op2::Equality - | Op2::Inequality - | Op2::LessThan => Some(TyKind::Bool), + Op2::Equality | Op2::Inequality | Op2::LessThan => Some(TyKind::Bool), Op2::Addition | Op2::Subtraction | Op2::Multiplication | Op2::Division | Op2::Modulus | Op2::BoolAnd - | Op2::BoolOr + | Op2::BoolOr | Op2::LeftShift => lhs_mono.clone().typ, }; diff --git a/src/type_checker/checker.rs b/src/type_checker/checker.rs index ab355f4a3..281985ecd 100644 --- a/src/type_checker/checker.rs +++ b/src/type_checker/checker.rs @@ -302,16 +302,14 @@ impl TypeChecker { } let typ = match op { - Op2::Equality - | Op2::Inequality - | Op2::LessThan => TyKind::Bool, + Op2::Equality | Op2::Inequality | Op2::LessThan => TyKind::Bool, Op2::Addition | Op2::Subtraction | Op2::Multiplication | Op2::Division | Op2::Modulus | Op2::BoolAnd - | Op2::BoolOr + | Op2::BoolOr | Op2::LeftShift => lhs_node.typ, }; diff --git a/src/var.rs b/src/var.rs index e1ea00076..fcee95943 100644 --- a/src/var.rs +++ b/src/var.rs @@ -58,8 +58,8 @@ where NthBit(B::Var, usize), /// Divide - // todo: refactor to use a argument wrapper to encapsulate its own type, - // so that a variant can have an argument to be either B::Var or B::Field + // todo: refactor to use a argument wrapper to encapsulate its own type, + // so that a variant can have an argument to be either B::Var or B::Field CstDivVar(B::Field, B::Var), VarDivCst(B::Var, B::Field), VarDivVar(B::Var, B::Var),