From d89211a951d885f457add607c7c71ef1e91f2575 Mon Sep 17 00:00:00 2001 From: dante <45801863+alexander-camuto@users.noreply.github.com> Date: Fri, 13 Sep 2024 21:11:57 +0200 Subject: [PATCH] refactor: lookup-less `sign` `relu` `abs` --- Cargo.toml | 4 + benches/accum_matmul_relu.rs | 16 ++- benches/accum_matmul_relu_overflow.rs | 16 ++- benches/relu.rs | 8 +- benches/relu_lookupless.rs | 147 ++++++++++++++++++++++++++ examples/conv2d_mnist/main.rs | 8 +- examples/mlp_4d_einsum.rs | 14 ++- src/circuit/ops/layouts.rs | 144 +++++++++++++++++++++++++ src/circuit/ops/lookup.rs | 16 +-- src/circuit/ops/poly.rs | 27 ++++- src/circuit/tests.rs | 86 +++++++++++---- src/graph/utilities.rs | 17 ++- src/lib.rs | 13 +++ src/tensor/mod.rs | 73 +++++++++++++ src/tensor/ops.rs | 112 ++++++++++++++++++++ src/tensor/val.rs | 67 +++++++++++- 16 files changed, 712 insertions(+), 56 deletions(-) create mode 100644 benches/relu_lookupless.rs diff --git a/Cargo.toml b/Cargo.toml index 30dcfb0aa..a3287e822 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -162,6 +162,10 @@ harness = false name = "relu" harness = false +[[bench]] +name = "relu_lookupless" +harness = false + [[bench]] name = "accum_matmul_relu" harness = false diff --git a/benches/accum_matmul_relu.rs b/benches/accum_matmul_relu.rs index 0aa8901ef..7f223a0a2 100644 --- a/benches/accum_matmul_relu.rs +++ b/benches/accum_matmul_relu.rs @@ -57,7 +57,15 @@ impl Circuit for MyCircuit { // sets up a new relu table base_config - .configure_lookup(cs, &b, &output, &a, BITS, K, &LookupOp::ReLU) + .configure_lookup( + cs, + &b, + &output, + &a, + BITS, + K, + &&LookupOp::LeakyReLU { slope: 0.0.into() }, + ) .unwrap(); MyConfig { base_config } @@ -82,7 +90,11 @@ impl Circuit for MyCircuit { .unwrap(); let _output = config .base_config - .layout(&mut region, &[output.unwrap()], Box::new(LookupOp::ReLU)) + .layout( + &mut region, + &[output.unwrap()], + Box::new(LookupOp::LeakyReLU { slope: 0.0.into() }), + ) .unwrap(); Ok(()) }, diff --git a/benches/accum_matmul_relu_overflow.rs b/benches/accum_matmul_relu_overflow.rs index b69ae7068..f59c2f69b 100644 --- a/benches/accum_matmul_relu_overflow.rs +++ b/benches/accum_matmul_relu_overflow.rs @@ -58,7 +58,15 @@ impl Circuit for MyCircuit { // sets up a new relu table base_config - .configure_lookup(cs, &b, &output, &a, BITS, k, &LookupOp::ReLU) + .configure_lookup( + cs, + &b, + &output, + &a, + BITS, + k, + &LookupOp::LeakyReLU { slope: 0.0.into() }, + ) .unwrap(); MyConfig { base_config } @@ -83,7 +91,11 @@ impl Circuit for MyCircuit { .unwrap(); let _output = config .base_config - .layout(&mut region, &[output.unwrap()], Box::new(LookupOp::ReLU)) + .layout( + &mut region, + &[output.unwrap()], + Box::new(LookupOp::LeakyReLU { slope: 0.0.into() }), + ) .unwrap(); Ok(()) }, diff --git a/benches/relu.rs b/benches/relu.rs index ed39735e1..93e53bf84 100644 --- a/benches/relu.rs +++ b/benches/relu.rs @@ -42,7 +42,7 @@ impl Circuit for NLCircuit { .map(|_| VarTensor::new_advice(cs, K, 1, LEN)) .collect::>(); - let nl = LookupOp::ReLU; + let nl = LookupOp::LeakyReLU { slope: 0.0.into() }; let mut config = Config::default(); @@ -65,7 +65,11 @@ impl Circuit for NLCircuit { |region| { let mut region = RegionCtx::new(region, 0, 1); config - .layout(&mut region, &[self.input.clone()], Box::new(LookupOp::ReLU)) + .layout( + &mut region, + &[self.input.clone()], + Box::new(LookupOp::LeakyReLU { slope: 0.0.into() }), + ) .unwrap(); Ok(()) }, diff --git a/benches/relu_lookupless.rs b/benches/relu_lookupless.rs new file mode 100644 index 000000000..c5ee1144d --- /dev/null +++ b/benches/relu_lookupless.rs @@ -0,0 +1,147 @@ +use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion, Throughput}; +use ezkl::circuit::poly::PolyOp; +use ezkl::circuit::region::RegionCtx; +use ezkl::circuit::{BaseConfig as Config, CheckMode}; +use ezkl::fieldutils::IntegerRep; +use ezkl::pfsys::create_proof_circuit; +use ezkl::pfsys::TranscriptType; +use ezkl::pfsys::{create_keys, srs::gen_srs}; +use ezkl::tensor::*; +use halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme; +use halo2_proofs::poly::kzg::multiopen::{ProverSHPLONK, VerifierSHPLONK}; +use halo2_proofs::poly::kzg::strategy::SingleStrategy; +use halo2_proofs::{ + circuit::{Layouter, SimpleFloorPlanner, Value}, + plonk::{Circuit, ConstraintSystem, Error}, +}; +use halo2curves::bn256::{Bn256, Fr}; +use rand::Rng; +use snark_verifier::system::halo2::transcript::evm::EvmTranscript; + +static mut LEN: usize = 4; +const K: usize = 16; + +#[derive(Clone)] +struct NLCircuit { + pub input: ValTensor, +} + +impl Circuit for NLCircuit { + type Config = Config; + type FloorPlanner = SimpleFloorPlanner; + type Params = (); + + fn without_witnesses(&self) -> Self { + self.clone() + } + + fn configure(cs: &mut ConstraintSystem) -> Self::Config { + unsafe { + let advices = (0..3) + .map(|_| VarTensor::new_advice(cs, K, 1, LEN)) + .collect::>(); + + let mut config = Config::default(); + + config + .configure_range_check(cs, &advices[0], &advices[1], (-1, 1), K) + .unwrap(); + + config + .configure_range_check(cs, &advices[0], &advices[1], (0, 1023), K) + .unwrap(); + + let _constant = VarTensor::constant_cols(cs, K, LEN, false); + + config + } + } + + fn synthesize( + &self, + mut config: Self::Config, + mut layouter: impl Layouter, // layouter is our 'write buffer' for the circuit + ) -> Result<(), Error> { + config.layout_range_checks(&mut layouter).unwrap(); + layouter.assign_region( + || "", + |region| { + let mut region = RegionCtx::new(region, 0, 1); + config + .layout( + &mut region, + &[self.input.clone()], + Box::new(PolyOp::ReLU { base: 1024, n: 2 }), + ) + .unwrap(); + Ok(()) + }, + )?; + Ok(()) + } +} + +fn runrelu(c: &mut Criterion) { + let mut group = c.benchmark_group("relu"); + + let mut rng = rand::thread_rng(); + let params = gen_srs::>(17); + for &len in [4, 8].iter() { + unsafe { + LEN = len; + }; + + let input: Tensor> = + Tensor::::from((0..len).map(|_| rng.gen_range(0..10))).into(); + + let circuit = NLCircuit { + input: ValTensor::from(input.clone()), + }; + + group.throughput(Throughput::Elements(len as u64)); + group.bench_with_input(BenchmarkId::new("pk", len), &len, |b, &_| { + b.iter(|| { + create_keys::, NLCircuit>(&circuit, ¶ms, true) + .unwrap(); + }); + }); + + let pk = + create_keys::, NLCircuit>(&circuit, ¶ms, true).unwrap(); + + group.throughput(Throughput::Elements(len as u64)); + group.bench_with_input(BenchmarkId::new("prove", len), &len, |b, &_| { + b.iter(|| { + let prover = create_proof_circuit::< + KZGCommitmentScheme<_>, + NLCircuit, + ProverSHPLONK<_>, + VerifierSHPLONK<_>, + SingleStrategy<_>, + _, + EvmTranscript<_, _, _, _>, + EvmTranscript<_, _, _, _>, + >( + circuit.clone(), + vec![], + ¶ms, + &pk, + CheckMode::UNSAFE, + ezkl::Commitments::KZG, + TranscriptType::EVM, + None, + None, + ); + prover.unwrap(); + }); + }); + } + group.finish(); +} + +criterion_group! { + name = benches; + config = Criterion::default().with_plots(); + targets = runrelu +} +criterion_main!(benches); diff --git a/examples/conv2d_mnist/main.rs b/examples/conv2d_mnist/main.rs index 25aa64ca3..48452f9f4 100644 --- a/examples/conv2d_mnist/main.rs +++ b/examples/conv2d_mnist/main.rs @@ -163,7 +163,7 @@ where ¶ms, (LOOKUP_MIN, LOOKUP_MAX), K, - &LookupOp::ReLU, + &&LookupOp::LeakyReLU { slope: 0.0.into() }, ) .unwrap(); @@ -221,7 +221,11 @@ where let x = config .layer_config - .layout(&mut region, &[x.unwrap()], Box::new(LookupOp::ReLU)) + .layout( + &mut region, + &[x.unwrap()], + Box::new(LookupOp::LeakyReLU { slope: 0.0.into() }), + ) .unwrap(); let mut x = config diff --git a/examples/mlp_4d_einsum.rs b/examples/mlp_4d_einsum.rs index 30a6f31dc..4052bb090 100644 --- a/examples/mlp_4d_einsum.rs +++ b/examples/mlp_4d_einsum.rs @@ -69,7 +69,7 @@ impl( .map_err(|e| e.into()) } +pub(crate) fn decompose( + config: &BaseConfig, + region: &mut RegionCtx, + values: &[ValTensor; 1], + base: &usize, + n: &usize, +) -> Result, CircuitError> { + let input = values[0].clone(); + + let is_assigned = !input.any_unknowns()?; + + let bases: ValTensor = Tensor::from( + (0..*n) + .rev() + .map(|x| ValType::Constant(integer_rep_to_felt(base.pow(x as u32) as IntegerRep))), + ) + .into(); + + let mut claimed_output: ValTensor = if is_assigned { + let input_evals = input.int_evals()?; + tensor::ops::decompose(&input_evals, *base, *n)? + .par_enum_map(|_, x| Ok::<_, TensorError>(Value::known(integer_rep_to_felt::(x))))? + .into() + } else { + let mut dims = input.dims().to_vec(); + dims.push(n + 1); + + Tensor::new( + Some(&vec![Value::::unknown(); input.len() * (n + 1)]), + &dims, + )? + .into() + }; + + claimed_output = region.assign(&config.custom_gates.inputs[0], &claimed_output)?; + region.increment(claimed_output.len()); + + println!("claimed output dims {:?}", claimed_output.dims()); + + let cartesian_coord = input + .dims() + .iter() + .map(|x| 0..*x) + .multi_cartesian_product() + .collect::>(); + + let mut dummy_iterator = Tensor::from(0..cartesian_coord.len()); + + let inner_loop_function = |i: usize, region: &mut RegionCtx| { + let coord = cartesian_coord[i].clone(); + let slice = coord.iter().map(|x| *x..*x + 1).collect::>(); + + let mut claimed_output_slice = claimed_output.get_slice(&slice)?; + claimed_output_slice.flatten(); + + println!("claimed_output_slice {:?}", claimed_output_slice.dims()); + + println!("claimed_output_slice len {}", claimed_output_slice.len()); + + let sliced_input = input.get_slice(&slice)?; + // get the sign bit and make sure it is valid + let sign = claimed_output.first()?; + let sign = range_check(config, region, &[sign], &(-1, 1))?; + + // get the rest of the thing and make sure it is in the correct range + let rest = claimed_output_slice.get_slice(&[1..claimed_output_slice.len()])?; + + println!("rest len {}, bases len {}", rest.len(), bases.len()); + + let rest = range_check(config, region, &[rest], &(0, (base - 1) as i128))?; + + let prod_decomp = dot(config, region, &[rest, bases.clone()])?; + + println!("prod_decomp {:?}", prod_decomp.dims()); + + println!("sign {:?}", sign.dims()); + + let signed_decomp = pairwise(config, region, &[prod_decomp, sign], BaseOp::Mult)?; + + println!("signed decomp {:?}", signed_decomp.dims()); + + println!("sliced_input {:?}", sliced_input.dims()); + + enforce_equality(config, region, &[sliced_input, signed_decomp])?; + + Ok(usize::default()) + }; + + region.apply_in_loop(&mut dummy_iterator, inner_loop_function)?; + + Ok(claimed_output) +} + +pub(crate) fn sign( + config: &BaseConfig, + region: &mut RegionCtx, + values: &[ValTensor; 1], + base: &usize, + n: &usize, +) -> Result, CircuitError> { + let mut decomp = decompose(config, region, values, base, n)?; + // get every n elements now, which correspond to the sign bit + println!("decomp dims {:?}", decomp.dims()); + + decomp.get_every_n(*n + 1)?; + decomp.reshape(values[0].dims())?; + + Ok(decomp) +} + +pub(crate) fn abs( + config: &BaseConfig, + region: &mut RegionCtx, + values: &[ValTensor; 1], + base: &usize, + n: &usize, +) -> Result, CircuitError> { + let sign = sign(config, region, values, base, n)?; + + pairwise(config, region, &[values[0].clone(), sign], BaseOp::Mult) +} + +pub(crate) fn relu( + config: &BaseConfig, + region: &mut RegionCtx, + values: &[ValTensor; 1], + base: &usize, + n: &usize, +) -> Result, CircuitError> { + let sign = sign(config, region, values, base, n)?; + + let mut unit = create_unit_tensor(sign.len()); + unit.reshape(sign.dims())?; + + let relu_mask = equals(config, region, &[sign, unit])?; + + pairwise( + config, + region, + &[values[0].clone(), relu_mask], + BaseOp::Mult, + ) +} + fn multi_dim_axes_op( config: &BaseConfig, region: &mut RegionCtx, diff --git a/src/circuit/ops/lookup.rs b/src/circuit/ops/lookup.rs index 2d164ea0e..1cf99d72a 100644 --- a/src/circuit/ops/lookup.rs +++ b/src/circuit/ops/lookup.rs @@ -15,14 +15,12 @@ use halo2curves::ff::PrimeField; /// An enum representing the operations that can be used to express more complex operations via accumulation #[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, Deserialize, Serialize)] pub enum LookupOp { - Abs, Div { denom: utils::F32, }, Cast { scale: utils::F32, }, - ReLU, Max { scale: utils::F32, a: utils::F32, @@ -116,7 +114,6 @@ pub enum LookupOp { LessThanEqual { a: utils::F32, }, - Sign, KroneckerDelta, Pow { scale: utils::F32, @@ -138,7 +135,6 @@ impl LookupOp { /// as path pub fn as_path(&self) -> String { match self { - LookupOp::Abs => "abs".into(), LookupOp::Ceil { scale } => format!("ceil_{}", scale), LookupOp::Floor { scale } => format!("floor_{}", scale), LookupOp::Round { scale } => format!("round_{}", scale), @@ -147,7 +143,6 @@ impl LookupOp { LookupOp::KroneckerDelta => "kronecker_delta".into(), LookupOp::Max { scale, a } => format!("max_{}_{}", scale, a), LookupOp::Min { scale, a } => format!("min_{}_{}", scale, a), - LookupOp::Sign => "sign".into(), LookupOp::LessThan { a } => format!("less_than_{}", a), LookupOp::LessThanEqual { a } => format!("less_than_equal_{}", a), LookupOp::GreaterThan { a } => format!("greater_than_{}", a), @@ -158,7 +153,6 @@ impl LookupOp { input_scale, output_scale, } => format!("recip_{}_{}", input_scale, output_scale), - LookupOp::ReLU => "relu".to_string(), LookupOp::LeakyReLU { slope: a } => format!("leaky_relu_{}", a), LookupOp::Sigmoid { scale } => format!("sigmoid_{}", scale), LookupOp::Sqrt { scale } => format!("sqrt_{}", scale), @@ -189,7 +183,6 @@ impl LookupOp { ) -> Result, TensorError> { let x = x[0].clone().map(|x| felt_to_integer_rep(x)); let res = match &self { - LookupOp::Abs => Ok(tensor::ops::abs(&x)?), LookupOp::Ceil { scale } => Ok(tensor::ops::nonlinearities::ceil(&x, scale.into())), LookupOp::Floor { scale } => Ok(tensor::ops::nonlinearities::floor(&x, scale.into())), LookupOp::Round { scale } => Ok(tensor::ops::nonlinearities::round(&x, scale.into())), @@ -212,7 +205,6 @@ impl LookupOp { scale.0.into(), a.0.into(), )), - LookupOp::Sign => Ok(tensor::ops::nonlinearities::sign(&x)), LookupOp::LessThan { a } => Ok(tensor::ops::nonlinearities::less_than( &x, f32::from(*a).into(), @@ -244,8 +236,6 @@ impl LookupOp { input_scale.into(), output_scale.into(), )), - LookupOp::ReLU => Ok(tensor::ops::nonlinearities::leakyrelu(&x, 0_f64)), - LookupOp::LeakyReLU { slope: a } => { Ok(tensor::ops::nonlinearities::leakyrelu(&x, a.0.into())) } @@ -289,7 +279,6 @@ impl Op for Lookup /// Returns the name of the operation fn as_string(&self) -> String { match self { - LookupOp::Abs => "ABS".into(), LookupOp::Ceil { scale } => format!("CEIL(scale={})", scale), LookupOp::Floor { scale } => format!("FLOOR(scale={})", scale), LookupOp::Round { scale } => format!("ROUND(scale={})", scale), @@ -298,7 +287,6 @@ impl Op for Lookup LookupOp::KroneckerDelta => "K_DELTA".into(), LookupOp::Max { scale, a } => format!("MAX(scale={}, a={})", scale, a), LookupOp::Min { scale, a } => format!("MIN(scale={}, a={})", scale, a), - LookupOp::Sign => "SIGN".into(), LookupOp::GreaterThan { a } => format!("GREATER_THAN(a={})", a), LookupOp::GreaterThanEqual { a } => format!("GREATER_THAN_EQUAL(a={})", a), LookupOp::LessThan { a } => format!("LESS_THAN(a={})", a), @@ -313,7 +301,6 @@ impl Op for Lookup LookupOp::Div { denom, .. } => format!("DIV(denom={})", denom), LookupOp::Cast { scale } => format!("CAST(scale={})", scale), LookupOp::Ln { scale } => format!("LN(scale={})", scale), - LookupOp::ReLU => "RELU".to_string(), LookupOp::LeakyReLU { slope: a } => format!("L_RELU(slope={})", a), LookupOp::Sigmoid { scale } => format!("SIGMOID(scale={})", scale), LookupOp::Sqrt { scale } => format!("SQRT(scale={})", scale), @@ -358,8 +345,7 @@ impl Op for Lookup in_scale + multiplier_to_scale(1. / scale.0 as f64) } LookupOp::Recip { output_scale, .. } => multiplier_to_scale(output_scale.into()), - LookupOp::Sign - | LookupOp::GreaterThan { .. } + LookupOp::GreaterThan { .. } | LookupOp::LessThan { .. } | LookupOp::GreaterThanEqual { .. } | LookupOp::LessThanEqual { .. } diff --git a/src/circuit/ops/poly.rs b/src/circuit/ops/poly.rs index 4bd31883d..58c9aa56d 100644 --- a/src/circuit/ops/poly.rs +++ b/src/circuit/ops/poly.rs @@ -9,6 +9,18 @@ use super::{base::BaseOp, *}; /// An enum representing the operations that can be expressed as arithmetic (non lookup) operations. #[derive(Clone, Debug, Serialize, Deserialize)] pub enum PolyOp { + ReLU { + base: usize, + n: usize, + }, + Abs { + base: usize, + n: usize, + }, + Sign { + base: usize, + n: usize, + }, GatherElements { dim: usize, constant_idx: Option>, @@ -99,8 +111,7 @@ impl< + PartialOrd + std::hash::Hash + Serialize - + for<'de> Deserialize<'de> - , + + for<'de> Deserialize<'de>, > Op for PolyOp { /// Returns a reference to the Any trait. @@ -110,6 +121,9 @@ impl< fn as_string(&self) -> String { match &self { + PolyOp::Abs { base, n } => format!("ABS (base={}, n={})", base, n), + PolyOp::Sign { base, n } => format!("SIGN (base={}, n={})", base, n), + PolyOp::ReLU { base, n } => format!("RELU (base={}, n={})", base, n), PolyOp::GatherElements { dim, constant_idx } => format!( "GATHERELEMENTS (dim={}, constant_idx{})", dim, @@ -191,6 +205,15 @@ impl< values: &[ValTensor], ) -> Result>, CircuitError> { Ok(Some(match self { + PolyOp::Abs { base, n } => { + layouts::abs(config, region, values[..].try_into()?, base, n)? + } + PolyOp::Sign { base, n } => { + layouts::sign(config, region, values[..].try_into()?, base, n)? + } + PolyOp::ReLU { base, n } => { + layouts::relu(config, region, values[..].try_into()?, base, n)? + } PolyOp::MultiBroadcastTo { shape } => { layouts::expand(config, region, values[..].try_into()?, shape)? } diff --git a/src/circuit/tests.rs b/src/circuit/tests.rs index de4f80db2..1e470ca95 100644 --- a/src/circuit/tests.rs +++ b/src/circuit/tests.rs @@ -1297,7 +1297,7 @@ mod conv_relu_col_ultra_overflow { use super::*; - const K: usize = 4; + const K: usize = 8; const LEN: usize = 15; #[derive(Clone)] @@ -1317,15 +1317,23 @@ mod conv_relu_col_ultra_overflow { } fn configure(cs: &mut ConstraintSystem) -> Self::Config { - let a = VarTensor::new_advice(cs, K, 1, LEN * LEN * LEN); - let b = VarTensor::new_advice(cs, K, 1, LEN * LEN * LEN); - let output = VarTensor::new_advice(cs, K, 1, LEN * LEN * LEN); + let a = VarTensor::new_advice(cs, K, 1, LEN * LEN * LEN * 4); + let b = VarTensor::new_advice(cs, K, 1, LEN * LEN * LEN * 4); + let output = VarTensor::new_advice(cs, K, 1, LEN * LEN * LEN * 4); let mut base_config = Self::Config::configure(cs, &[a.clone(), b.clone()], &output, CheckMode::SAFE); // sets up a new relu table + + base_config + .configure_range_check(cs, &a, &b, (-1, 1), K) + .unwrap(); + base_config - .configure_lookup(cs, &b, &output, &a, (-3, 3), K, &LookupOp::ReLU) + .configure_range_check(cs, &a, &b, (0, 1), K) .unwrap(); + + let _constant = VarTensor::constant_cols(cs, K, 8, false); + base_config.clone() } @@ -1334,7 +1342,7 @@ mod conv_relu_col_ultra_overflow { mut config: Self::Config, mut layouter: impl Layouter, ) -> Result<(), Error> { - config.layout_tables(&mut layouter).unwrap(); + config.layout_range_checks(&mut layouter).unwrap(); layouter .assign_region( || "", @@ -1355,7 +1363,7 @@ mod conv_relu_col_ultra_overflow { .layout( &mut region, &[output.unwrap().unwrap()], - Box::new(LookupOp::ReLU), + Box::new(PolyOp::ReLU { base: 2, n: 2 }), ) .unwrap(); Ok(()) @@ -2258,7 +2266,6 @@ mod matmul_relu { const K: usize = 18; const LEN: usize = 32; - use crate::circuit::lookup::LookupOp; #[derive(Clone)] struct MyCircuit { @@ -2288,11 +2295,17 @@ mod matmul_relu { let mut base_config = BaseConfig::configure(cs, &[a.clone(), b.clone()], &output, CheckMode::SAFE); - // sets up a new relu table + base_config - .configure_lookup(cs, &b, &output, &a, (-32768, 32768), K, &LookupOp::ReLU) + .configure_range_check(cs, &a, &b, (-1, 1), K) .unwrap(); + base_config + .configure_range_check(cs, &a, &b, (0, 1023), K) + .unwrap(); + + let _constant = VarTensor::constant_cols(cs, K, 8, false); + MyConfig { base_config } } @@ -2301,7 +2314,10 @@ mod matmul_relu { mut config: Self::Config, mut layouter: impl Layouter, ) -> Result<(), Error> { - config.base_config.layout_tables(&mut layouter).unwrap(); + config + .base_config + .layout_range_checks(&mut layouter) + .unwrap(); layouter.assign_region( || "", |region| { @@ -2315,7 +2331,11 @@ mod matmul_relu { .unwrap(); let _output = config .base_config - .layout(&mut region, &[output.unwrap()], Box::new(LookupOp::ReLU)) + .layout( + &mut region, + &[output.unwrap()], + Box::new(PolyOp::ReLU { base: 1024, n: 2 }), + ) .unwrap(); Ok(()) }, @@ -2354,6 +2374,8 @@ mod relu { plonk::{Circuit, ConstraintSystem, Error}, }; + const K: u32 = 8; + #[derive(Clone)] struct ReLUCircuit { pub input: ValTensor, @@ -2370,16 +2392,26 @@ mod relu { fn configure(cs: &mut ConstraintSystem) -> Self::Config { let advices = (0..3) - .map(|_| VarTensor::new_advice(cs, 4, 1, 3)) + .map(|_| VarTensor::new_advice(cs, 8, 1, 3)) .collect::>(); - let nl = LookupOp::ReLU; + let mut config = BaseConfig::configure( + cs, + &[advices[0].clone(), advices[1].clone()], + &advices[2], + CheckMode::SAFE, + ); - let mut config = BaseConfig::default(); + config + .configure_range_check(cs, &advices[0], &advices[1], (-1, 1), K as usize) + .unwrap(); config - .configure_lookup(cs, &advices[0], &advices[1], &advices[2], (-6, 6), 4, &nl) + .configure_range_check(cs, &advices[0], &advices[1], (0, 1), K as usize) .unwrap(); + + let _constant = VarTensor::constant_cols(cs, K as usize, 8, false); + config } @@ -2388,15 +2420,19 @@ mod relu { mut config: Self::Config, mut layouter: impl Layouter, // layouter is our 'write buffer' for the circuit ) -> Result<(), Error> { - config.layout_tables(&mut layouter).unwrap(); + config.layout_range_checks(&mut layouter).unwrap(); layouter .assign_region( || "", |region| { let mut region = RegionCtx::new(region, 0, 1); - config - .layout(&mut region, &[self.input.clone()], Box::new(LookupOp::ReLU)) - .map_err(|_| Error::Synthesis) + Ok(config + .layout( + &mut region, + &[self.input.clone()], + Box::new(PolyOp::ReLU { base: 2, n: 2 }), + ) + .unwrap()) }, ) .unwrap(); @@ -2414,7 +2450,7 @@ mod relu { input: ValTensor::from(input), }; - let prover = MockProver::run(4_u32, &circuit, vec![]).unwrap(); + let prover = MockProver::run(K, &circuit, vec![]).unwrap(); prover.assert_satisfied(); } } @@ -2453,7 +2489,7 @@ mod lookup_ultra_overflow { .map(|_| VarTensor::new_advice(cs, 4, 1, 3)) .collect::>(); - let nl = LookupOp::ReLU; + let nl = LookupOp::LeakyReLU { slope: 0.0.into() }; let mut config = BaseConfig::default(); @@ -2483,7 +2519,11 @@ mod lookup_ultra_overflow { |region| { let mut region = RegionCtx::new(region, 0, 1); config - .layout(&mut region, &[self.input.clone()], Box::new(LookupOp::ReLU)) + .layout( + &mut region, + &[self.input.clone()], + Box::new(LookupOp::LeakyReLU { slope: 0.0.into() }), + ) .map_err(|_| Error::Synthesis) }, ) diff --git a/src/graph/utilities.rs b/src/graph/utilities.rs index 0ec224440..7e663a704 100644 --- a/src/graph/utilities.rs +++ b/src/graph/utilities.rs @@ -281,7 +281,7 @@ pub fn new_op_from_onnx( ) -> Result<(SupportedOp, Vec), GraphError> { use tract_onnx::tract_core::ops::array::Trilu; - use crate::circuit::InputType; + use crate::{circuit::InputType, EZKL_DECOMP_BASE, EZKL_DECOMP_LEN}; let input_scales = inputs .iter() @@ -782,7 +782,10 @@ pub fn new_op_from_onnx( deleted_indices.push(const_idx); } if unit == 0. { - SupportedOp::Nonlinear(LookupOp::ReLU) + SupportedOp::Linear(PolyOp::ReLU { + base: *EZKL_DECOMP_BASE, + n: *EZKL_DECOMP_LEN, + }) } else { // get the non-constant index let non_const_idx = if const_idx == 0 { 1 } else { 0 }; @@ -871,7 +874,10 @@ pub fn new_op_from_onnx( "QuantizeLinearU8" | "DequantizeLinearF32" => { SupportedOp::Linear(PolyOp::Identity { out_scale: None }) } - "Abs" => SupportedOp::Nonlinear(LookupOp::Abs), + "Abs" => SupportedOp::Linear(PolyOp::Abs { + base: *EZKL_DECOMP_BASE, + n: *EZKL_DECOMP_LEN, + }), "Neg" => SupportedOp::Linear(PolyOp::Neg), "HardSwish" => SupportedOp::Nonlinear(LookupOp::HardSwish { scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(), @@ -1127,7 +1133,10 @@ pub fn new_op_from_onnx( "RoundHalfToEven" => SupportedOp::Nonlinear(LookupOp::RoundHalfToEven { scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(), }), - "Sign" => SupportedOp::Nonlinear(LookupOp::Sign), + "Sign" => SupportedOp::Linear(PolyOp::Sign { + base: *EZKL_DECOMP_BASE, + n: *EZKL_DECOMP_LEN, + }), "Pow" => { // Extract the slope layer hyperparams from a const diff --git a/src/lib.rs b/src/lib.rs index 3645afff7..4c5bc1a92 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -150,6 +150,19 @@ lazy_static! { /// The serialization format for the keys pub static ref EZKL_KEY_FORMAT: String = std::env::var("EZKL_KEY_FORMAT") .unwrap_or("raw-bytes".to_string()); + + /// The base used to decompose operations like abs, sign, relu + pub static ref EZKL_DECOMP_BASE: usize = std::env::var("EZKL_DECOMP_BASE") + // this is 2**14 + .unwrap_or("16384".to_string()) + .parse() + .unwrap(); + + /// The length of the decomposition for operations like abs, sign, relu + pub static ref EZKL_DECOMP_LEN: usize = std::env::var("EZKL_DECOMP_LEN") + .unwrap_or("2".to_string()) + .parse() + .unwrap(); } #[cfg(target_arch = "wasm32")] diff --git a/src/tensor/mod.rs b/src/tensor/mod.rs index 16842092e..7e31ca3df 100644 --- a/src/tensor/mod.rs +++ b/src/tensor/mod.rs @@ -764,6 +764,54 @@ impl Tensor { index } + /// Fetches every nth element + /// + /// ``` + /// use ezkl::tensor::Tensor; + /// use ezkl::fieldutils::IntegerRep; + /// let a = Tensor::::new(Some(&[1, 2, 3, 4, 5, 6]), &[6]).unwrap(); + /// let expected = Tensor::::new(Some(&[1, 3, 5]), &[3]).unwrap(); + /// assert_eq!(a.get_every_n(2).unwrap(), expected); + /// assert_eq!(a.get_every_n(1).unwrap(), a); + /// + /// let expected = Tensor::::new(Some(&[1, 6]), &[2]).unwrap(); + /// assert_eq!(a.duplicate_every_n(5).unwrap(), expected); + /// + /// ``` + pub fn get_every_n(&self, n: usize) -> Result, TensorError> { + let mut inner: Vec = vec![]; + for (i, elem) in self.inner.clone().into_iter().enumerate() { + if i % n == 0 { + inner.push(elem.clone()); + } + } + Tensor::new(Some(&inner), &[inner.len()]) + } + + /// Excludes every nth element + /// + /// ``` + /// use ezkl::tensor::Tensor; + /// use ezkl::fieldutils::IntegerRep; + /// let a = Tensor::::new(Some(&[1, 2, 3, 4, 5, 6]), &[6]).unwrap(); + /// let expected = Tensor::::new(Some(&[2, 4, 6]), &[3]).unwrap(); + /// assert_eq!(a.exclude_every_n(2).unwrap(), expected); + /// assert_eq!(a.exclude_every_n(7).unwrap(), a); + /// + /// let expected = Tensor::::new(Some(&[2, 3, 4, 5]), &[9]).unwrap(); + /// assert_eq!(a.duplicate_every_n(5).unwrap(), expected); + /// + /// ``` + pub fn exclude_every_n(&self, n: usize) -> Result, TensorError> { + let mut inner: Vec = vec![]; + for (i, elem) in self.inner.clone().into_iter().enumerate() { + if !(i % n == 0) { + inner.push(elem.clone()); + } + } + Tensor::new(Some(&inner), &[inner.len()]) + } + /// Duplicates every nth element /// /// ``` @@ -1217,6 +1265,31 @@ impl Tensor { Tensor::new(Some(&[res]), &[1]) } + /// Get first elem from Tensor + /// ``` + /// use ezkl::tensor::Tensor; + /// use ezkl::fieldutils::IntegerRep; + /// let mut a = Tensor::::new(Some(&[1, 2, 3]), &[3]).unwrap(); + /// let mut b = Tensor::::new(Some(&[1]), &[1]).unwrap(); + /// + /// assert_eq!(a.first().unwrap(), b); + /// ``` + pub fn first(&self) -> Result, TensorError> + where + T: Send + Sync, + { + let res = match self.inner.first() { + Some(e) => e.clone(), + None => { + return Err(TensorError::DimError( + "Cannot get first element of empty tensor".to_string(), + )) + } + }; + + Tensor::new(Some(&[res]), &[1]) + } + /// Maps a function to tensors and enumerates in parallel /// ``` /// use ezkl::tensor::{Tensor, TensorError}; diff --git a/src/tensor/ops.rs b/src/tensor/ops.rs index 8fc78a940..917fbd3a3 100644 --- a/src/tensor/ops.rs +++ b/src/tensor/ops.rs @@ -7,6 +7,118 @@ use itertools::Itertools; use maybe_rayon::{iter::ParallelIterator, prelude::IntoParallelRefIterator}; pub use std::ops::{Add, Mul, Neg, Sub}; +/// Helper function to get the base decomp of an integer +/// # Arguments +/// * `x` - IntegerRep +/// * `n` - usize +/// * `base` - usize +/// +fn get_rep(x: &IntegerRep, base: usize, n: usize) -> Vec { + let mut rep = vec![0; n + 1]; + // sign bit + rep[0] = if *x < 0 { + -1 + } else if *x > 0 { + 1 + } else { + 0 + }; + + let mut x = x.abs(); + // + for i in (1..rep.len()).rev() { + rep[i] = x % base as i128; + x /= base as i128; + } + + rep +} + +/// Decompose a tensor of integers into a larger tensor with added dimension of size `n + 1` with the binary (or OTHER base) representation of the integer. +/// # Arguments +/// * `x` - Tensor +/// * `n` - usize +/// * `base` - usize +/// # Examples +/// ``` +/// use ezkl::tensor::Tensor; +/// use ezkl::fieldutils::IntegerRep; +/// use ezkl::tensor::ops::decompose; +/// let x = Tensor::::new( +/// Some(&[0, 1, 2, -1]), +/// &[2, 2]).unwrap(); +/// +/// let result = decompose(&x, 2, 2).unwrap(); +/// // result will have dims [2, 2, 3] +/// let expected = Tensor::::new(Some(&[0, 0, 0, +/// 1, 0, 1, +/// 1, 1, 0, +/// -1, 0, 1]), &[2, 2, 3]).unwrap(); +/// assert_eq!(result, expected); +/// +/// let result = decompose(&x, 3, 1).unwrap(); +/// +/// +/// // result will have dims [2, 2, 2] +/// let expected = Tensor::::new(Some(&[0, 0, +/// 1, 1, +/// 1, 2, +/// -1, 1]), &[2, 2, 2]).unwrap(); +/// +/// assert_eq!(result, expected); +/// +/// let x = Tensor::::new( +/// Some(&[0, 11, 23, -1]), +/// &[2, 2]).unwrap(); +/// +/// let result = decompose(&x, 2, 5).unwrap(); +/// // result will have dims [2, 2, 6] +/// let expected = Tensor::::new(Some(&[0, 0, 0, 0, 0, 0, +/// 1, 0, 1, 0, 1, 1, +/// 1, 1, 0, 1, 1, 1, +/// -1, 0, 0, 0, 0, 1]), &[2, 2, 6]).unwrap(); +/// assert_eq!(result, expected); +/// +/// let result = decompose(&x, 16, 2).unwrap(); +/// // result will have dims [2, 2, 3] +/// let expected = Tensor::::new(Some(&[0, 0, 0, +/// 1, 0, 11, +/// 1, 1, 7, +/// -1, 0, 1]), &[2, 2, 3]).unwrap(); +/// assert_eq!(result, expected); +/// ``` +/// +pub fn decompose( + x: &Tensor, + base: usize, + n: usize, +) -> Result, TensorError> { + let mut dims = x.dims().to_vec(); + dims.push(n + 1); + + if n == 0 { + let mut x = x.clone(); + x.reshape(&dims)?; + return Ok(x); + } + + println!("{} {}", base, n); + + let resp = x + .iter() + .map(|val| get_rep(val, base, n)) + .flatten() + .collect::>(); + + println!("{} {} {:?} ", base, n, resp); + + let output = Tensor::::new(Some(&resp), &dims)?; + + println!("{} {} {:?} {:?}", base, n, resp, output); + + Ok(output) +} + /// Trilu operation. /// # Arguments /// * `a` - Tensor diff --git a/src/tensor/val.rs b/src/tensor/val.rs index f9bd6d29d..c7dbc2627 100644 --- a/src/tensor/val.rs +++ b/src/tensor/val.rs @@ -520,6 +520,18 @@ impl ValTensor { } } + /// Get the sign of the inner values + pub fn sign(&self) -> Result { + let evals = self.int_evals()?; + Ok(evals + .par_enum_map(|_, val| { + Ok(ValType::Value(Value::known(integer_rep_to_felt( + val.signum(), + )))) + })? + .into()) + } + /// Calls `int_evals` on the inner tensor. pub fn int_evals(&self) -> Result, TensorError> { // finally convert to vector of integers @@ -574,7 +586,7 @@ impl ValTensor { Ok(()) } - /// Calls `get_slice` on the inner tensor. + /// Calls `last` on the inner tensor. pub fn last(&self) -> Result, TensorError> { let slice = match self { ValTensor::Value { @@ -595,6 +607,27 @@ impl ValTensor { Ok(slice) } + /// Calls `first` + pub fn first(&self) -> Result, TensorError> { + let slice = match self { + ValTensor::Value { + inner: v, + dims: _, + scale, + } => { + let inner = v.first()?; + let dims = inner.dims().to_vec(); + ValTensor::Value { + inner, + dims, + scale: *scale, + } + } + _ => return Err(TensorError::WrongMethod), + }; + Ok(slice) + } + /// Calls `get_slice` on the inner tensor. pub fn get_slice(&self, indices: &[Range]) -> Result, TensorError> { if indices.iter().map(|x| x.end - x.start).collect::>() == self.dims() { @@ -775,6 +808,38 @@ impl ValTensor { Ok(()) } + /// Calls `get_every_n` on the inner [Tensor]. + pub fn get_every_n(&mut self, n: usize) -> Result<(), TensorError> { + match self { + ValTensor::Value { + inner: v, dims: d, .. + } => { + *v = v.get_every_n(n)?; + *d = v.dims().to_vec(); + } + ValTensor::Instance { .. } => { + return Err(TensorError::WrongMethod); + } + } + Ok(()) + } + + /// Calls `exclude_every_n` on the inner [Tensor]. + pub fn exclude_every_n(&mut self, n: usize) -> Result<(), TensorError> { + match self { + ValTensor::Value { + inner: v, dims: d, .. + } => { + *v = v.exclude_every_n(n)?; + *d = v.dims().to_vec(); + } + ValTensor::Instance { .. } => { + return Err(TensorError::WrongMethod); + } + } + Ok(()) + } + /// remove constant zero values constants pub fn remove_const_zero_values(&mut self) { match self {