Skip to content

Commit

Permalink
refactor: lookup-less sign relu abs
Browse files Browse the repository at this point in the history
  • Loading branch information
alexander-camuto committed Sep 13, 2024
1 parent c9f9d17 commit d89211a
Show file tree
Hide file tree
Showing 16 changed files with 712 additions and 56 deletions.
4 changes: 4 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,10 @@ harness = false
name = "relu"
harness = false

[[bench]]
name = "relu_lookupless"
harness = false

[[bench]]
name = "accum_matmul_relu"
harness = false
Expand Down
16 changes: 14 additions & 2 deletions benches/accum_matmul_relu.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,15 @@ impl Circuit<Fr> for MyCircuit {

// sets up a new relu table
base_config
.configure_lookup(cs, &b, &output, &a, BITS, K, &LookupOp::ReLU)
.configure_lookup(
cs,
&b,
&output,
&a,
BITS,
K,
&&LookupOp::LeakyReLU { slope: 0.0.into() },
)
.unwrap();

MyConfig { base_config }
Expand All @@ -82,7 +90,11 @@ impl Circuit<Fr> for MyCircuit {
.unwrap();
let _output = config
.base_config
.layout(&mut region, &[output.unwrap()], Box::new(LookupOp::ReLU))
.layout(
&mut region,
&[output.unwrap()],
Box::new(LookupOp::LeakyReLU { slope: 0.0.into() }),
)
.unwrap();
Ok(())
},
Expand Down
16 changes: 14 additions & 2 deletions benches/accum_matmul_relu_overflow.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,15 @@ impl Circuit<Fr> for MyCircuit {

// sets up a new relu table
base_config
.configure_lookup(cs, &b, &output, &a, BITS, k, &LookupOp::ReLU)
.configure_lookup(
cs,
&b,
&output,
&a,
BITS,
k,
&LookupOp::LeakyReLU { slope: 0.0.into() },
)
.unwrap();

MyConfig { base_config }
Expand All @@ -83,7 +91,11 @@ impl Circuit<Fr> for MyCircuit {
.unwrap();
let _output = config
.base_config
.layout(&mut region, &[output.unwrap()], Box::new(LookupOp::ReLU))
.layout(
&mut region,
&[output.unwrap()],
Box::new(LookupOp::LeakyReLU { slope: 0.0.into() }),
)
.unwrap();
Ok(())
},
Expand Down
8 changes: 6 additions & 2 deletions benches/relu.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ impl Circuit<Fr> for NLCircuit {
.map(|_| VarTensor::new_advice(cs, K, 1, LEN))
.collect::<Vec<_>>();

let nl = LookupOp::ReLU;
let nl = LookupOp::LeakyReLU { slope: 0.0.into() };

let mut config = Config::default();

Expand All @@ -65,7 +65,11 @@ impl Circuit<Fr> for NLCircuit {
|region| {
let mut region = RegionCtx::new(region, 0, 1);
config
.layout(&mut region, &[self.input.clone()], Box::new(LookupOp::ReLU))
.layout(
&mut region,
&[self.input.clone()],
Box::new(LookupOp::LeakyReLU { slope: 0.0.into() }),
)
.unwrap();
Ok(())
},
Expand Down
147 changes: 147 additions & 0 deletions benches/relu_lookupless.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion, Throughput};
use ezkl::circuit::poly::PolyOp;
use ezkl::circuit::region::RegionCtx;
use ezkl::circuit::{BaseConfig as Config, CheckMode};
use ezkl::fieldutils::IntegerRep;
use ezkl::pfsys::create_proof_circuit;
use ezkl::pfsys::TranscriptType;
use ezkl::pfsys::{create_keys, srs::gen_srs};
use ezkl::tensor::*;
use halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme;
use halo2_proofs::poly::kzg::multiopen::{ProverSHPLONK, VerifierSHPLONK};
use halo2_proofs::poly::kzg::strategy::SingleStrategy;
use halo2_proofs::{
circuit::{Layouter, SimpleFloorPlanner, Value},
plonk::{Circuit, ConstraintSystem, Error},
};
use halo2curves::bn256::{Bn256, Fr};
use rand::Rng;
use snark_verifier::system::halo2::transcript::evm::EvmTranscript;

static mut LEN: usize = 4;
const K: usize = 16;

#[derive(Clone)]
struct NLCircuit {
pub input: ValTensor<Fr>,
}

impl Circuit<Fr> for NLCircuit {
type Config = Config<Fr>;
type FloorPlanner = SimpleFloorPlanner;
type Params = ();

fn without_witnesses(&self) -> Self {
self.clone()
}

fn configure(cs: &mut ConstraintSystem<Fr>) -> Self::Config {
unsafe {
let advices = (0..3)
.map(|_| VarTensor::new_advice(cs, K, 1, LEN))
.collect::<Vec<_>>();

let mut config = Config::default();

config
.configure_range_check(cs, &advices[0], &advices[1], (-1, 1), K)
.unwrap();

config
.configure_range_check(cs, &advices[0], &advices[1], (0, 1023), K)
.unwrap();

let _constant = VarTensor::constant_cols(cs, K, LEN, false);

config
}
}

fn synthesize(
&self,
mut config: Self::Config,
mut layouter: impl Layouter<Fr>, // layouter is our 'write buffer' for the circuit
) -> Result<(), Error> {
config.layout_range_checks(&mut layouter).unwrap();
layouter.assign_region(
|| "",
|region| {
let mut region = RegionCtx::new(region, 0, 1);
config
.layout(
&mut region,
&[self.input.clone()],
Box::new(PolyOp::ReLU { base: 1024, n: 2 }),
)
.unwrap();
Ok(())
},
)?;
Ok(())
}
}

fn runrelu(c: &mut Criterion) {
let mut group = c.benchmark_group("relu");

let mut rng = rand::thread_rng();
let params = gen_srs::<KZGCommitmentScheme<_>>(17);
for &len in [4, 8].iter() {
unsafe {
LEN = len;
};

let input: Tensor<Value<Fr>> =
Tensor::<IntegerRep>::from((0..len).map(|_| rng.gen_range(0..10))).into();

let circuit = NLCircuit {
input: ValTensor::from(input.clone()),
};

group.throughput(Throughput::Elements(len as u64));
group.bench_with_input(BenchmarkId::new("pk", len), &len, |b, &_| {
b.iter(|| {
create_keys::<KZGCommitmentScheme<Bn256>, NLCircuit>(&circuit, &params, true)
.unwrap();
});
});

let pk =
create_keys::<KZGCommitmentScheme<Bn256>, NLCircuit>(&circuit, &params, true).unwrap();

group.throughput(Throughput::Elements(len as u64));
group.bench_with_input(BenchmarkId::new("prove", len), &len, |b, &_| {
b.iter(|| {
let prover = create_proof_circuit::<
KZGCommitmentScheme<_>,
NLCircuit,
ProverSHPLONK<_>,
VerifierSHPLONK<_>,
SingleStrategy<_>,
_,
EvmTranscript<_, _, _, _>,
EvmTranscript<_, _, _, _>,
>(
circuit.clone(),
vec![],
&params,
&pk,
CheckMode::UNSAFE,
ezkl::Commitments::KZG,
TranscriptType::EVM,
None,
None,
);
prover.unwrap();
});
});
}
group.finish();
}

criterion_group! {
name = benches;
config = Criterion::default().with_plots();
targets = runrelu
}
criterion_main!(benches);
8 changes: 6 additions & 2 deletions examples/conv2d_mnist/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ where
&params,
(LOOKUP_MIN, LOOKUP_MAX),
K,
&LookupOp::ReLU,
&&LookupOp::LeakyReLU { slope: 0.0.into() },
)
.unwrap();

Expand Down Expand Up @@ -221,7 +221,11 @@ where

let x = config
.layer_config
.layout(&mut region, &[x.unwrap()], Box::new(LookupOp::ReLU))
.layout(
&mut region,
&[x.unwrap()],
Box::new(LookupOp::LeakyReLU { slope: 0.0.into() }),
)
.unwrap();

let mut x = config
Expand Down
14 changes: 11 additions & 3 deletions examples/mlp_4d_einsum.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ impl<const LEN: usize, const LOOKUP_MIN: IntegerRep, const LOOKUP_MAX: IntegerRe
&params,
(LOOKUP_MIN, LOOKUP_MAX),
K,
&LookupOp::ReLU,
&&LookupOp::LeakyReLU { slope: 0.0.into() },
)
.unwrap();

Expand Down Expand Up @@ -141,7 +141,11 @@ impl<const LEN: usize, const LOOKUP_MIN: IntegerRep, const LOOKUP_MAX: IntegerRe
println!("x shape: {:?}", x.dims());
let mut x = config
.layer_config
.layout(&mut region, &[x], Box::new(LookupOp::ReLU))
.layout(
&mut region,
&[x],
Box::new(LookupOp::LeakyReLU { slope: 0.0.into() }),
)
.unwrap()
.unwrap();
println!("3");
Expand Down Expand Up @@ -177,7 +181,11 @@ impl<const LEN: usize, const LOOKUP_MIN: IntegerRep, const LOOKUP_MAX: IntegerRe
println!("x shape: {:?}", x.dims());
let x = config
.layer_config
.layout(&mut region, &[x], Box::new(LookupOp::ReLU))
.layout(
&mut region,
&[x],
Box::new(LookupOp::LeakyReLU { slope: 0.0.into() }),
)
.unwrap();
println!("6");
println!("offset: {}", region.row());
Expand Down
Loading

0 comments on commit d89211a

Please sign in to comment.