chore: unify leakyrelu and relu #858

Merged · 4 commits · Oct 29, 2024
Changes from all commits
6 changes: 3 additions & 3 deletions Cargo.toml
@@ -169,20 +169,20 @@ harness = false


[[bench]]
-name = "relu"
+name = "sigmoid"
harness = false

[[bench]]
name = "relu_lookupless"
harness = false

[[bench]]
-name = "accum_matmul_relu"
+name = "accum_matmul_sigmoid"
harness = false


[[bench]]
-name = "accum_matmul_relu_overflow"
+name = "accum_matmul_sigmoid_overflow"
harness = false

[[bin]]
@@ -64,7 +64,7 @@ impl Circuit<Fr> for MyCircuit {
&a,
BITS,
K,
-&LookupOp::LeakyReLU { slope: 0.0.into() },
+&LookupOp::Sigmoid { scale: 1.0.into() },
)
.unwrap();

@@ -93,7 +93,7 @@ impl Circuit<Fr> for MyCircuit {
.layout(
&mut region,
&[output.unwrap()],
-Box::new(LookupOp::LeakyReLU { slope: 0.0.into() }),
+Box::new(LookupOp::Sigmoid { scale: 1.0.into() }),
)
.unwrap();
Ok(())
@@ -65,7 +65,7 @@ impl Circuit<Fr> for MyCircuit {
&a,
BITS,
k,
-&LookupOp::LeakyReLU { slope: 0.0.into() },
+&LookupOp::Sigmoid { scale: 1.0.into() },
)
.unwrap();

@@ -94,7 +94,7 @@ impl Circuit<Fr> for MyCircuit {
.layout(
&mut region,
&[output.unwrap()],
-Box::new(LookupOp::LeakyReLU { slope: 0.0.into() }),
+Box::new(LookupOp::Sigmoid { scale: 1.0.into() }),
)
.unwrap();
Ok(())
9 changes: 8 additions & 1 deletion benches/relu_lookupless.rs
@@ -68,7 +68,14 @@ impl Circuit<Fr> for NLCircuit {
|region| {
let mut region = RegionCtx::new(region, 0, 1, 1024, 2);
config
-.layout(&mut region, &[self.input.clone()], Box::new(PolyOp::ReLU))
+.layout(
+&mut region,
+&[self.input.clone()],
+Box::new(PolyOp::LeakyReLU {
+slope: 0.0.into(),
+scale: 1,
+}),
+)
.unwrap();
Ok(())
},
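Since `PolyOp::ReLU` is gone, the bench expresses plain ReLU as a leaky ReLU with slope 0 and scale 1. The two coincide because a zero slope kills the negative branch; a quick sketch of the identity (plain Rust over floats for clarity; the free function below is illustrative, not the crate's API):

```rust
// Leaky ReLU degenerates to ReLU when the slope is zero: the negative
// branch contributes slope * x = 0, leaving max(x, 0).
fn leaky_relu(x: f32, slope: f32) -> f32 {
    if x >= 0.0 { x } else { slope * x }
}

fn main() {
    for x in [-2.0f32, -0.5, 0.0, 3.0] {
        assert_eq!(leaky_relu(x, 0.0), x.max(0.0));
    }
}
```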
4 changes: 2 additions & 2 deletions benches/relu.rs → benches/sigmoid.rs
@@ -42,7 +42,7 @@ impl Circuit<Fr> for NLCircuit {
.map(|_| VarTensor::new_advice(cs, K, 1, LEN))
.collect::<Vec<_>>();

-let nl = LookupOp::LeakyReLU { slope: 0.0.into() };
+let nl = LookupOp::Sigmoid { scale: 1.0.into() };

let mut config = Config::default();

@@ -68,7 +68,7 @@ impl Circuit<Fr> for NLCircuit {
.layout(
&mut region,
&[self.input.clone()],
-Box::new(LookupOp::LeakyReLU { slope: 0.0.into() }),
+Box::new(LookupOp::Sigmoid { scale: 1.0.into() }),
)
.unwrap();
Ok(())
26 changes: 16 additions & 10 deletions examples/conv2d_mnist/main.rs
@@ -146,6 +146,8 @@ where
let params = VarTensor::new_advice(cs, K, NUM_INNER_COLS, LEN);
let output = VarTensor::new_advice(cs, K, NUM_INNER_COLS, LEN);

+let _constant = VarTensor::constant_cols(cs, K, LEN, false);

println!("INPUT COL {:#?}", input);

let mut layer_config = PolyConfig::configure(
@@ -156,15 +158,11 @@
);

layer_config
-.configure_lookup(
-cs,
-&input,
-&output,
-&params,
-(LOOKUP_MIN, LOOKUP_MAX),
-K,
-&LookupOp::LeakyReLU { slope: 0.0.into() },
-)
+.configure_range_check(cs, &input, &params, (-1, 1), K)
.unwrap();

+layer_config
+.configure_range_check(cs, &input, &params, (0, 1023), K)
+.unwrap();

layer_config
@@ -195,6 +193,11 @@
) -> Result<(), Error> {
config.layer_config.layout_tables(&mut layouter).unwrap();

+config
+.layer_config
+.layout_range_checks(&mut layouter)
+.unwrap();

let x = layouter
.assign_region(
|| "mlp_4d",
@@ -224,7 +227,10 @@
.layout(
&mut region,
&[x.unwrap()],
-Box::new(LookupOp::LeakyReLU { slope: 0.0.into() }),
+Box::new(PolyOp::LeakyReLU {
+slope: 0.0.into(),
+scale: 1,
+}),
)
.unwrap();

34 changes: 22 additions & 12 deletions examples/mlp_4d_einsum.rs
@@ -53,24 +53,23 @@ impl<const LEN: usize, const LOOKUP_MIN: IntegerRep, const LOOKUP_MAX: IntegerRe
let output = VarTensor::new_advice(cs, K, 1, LEN);
// tells the config layer to add an affine op to the circuit gate

+let _constant = VarTensor::constant_cols(cs, K, LEN, false);

println!("INPUT COL {:#?}", input);

let mut layer_config = PolyConfig::<F>::configure(
cs,
&[input.clone(), params.clone()],
&output,
CheckMode::SAFE,
);

-// sets up a new ReLU table and resuses it for l1 and l3 non linearities
layer_config
-.configure_lookup(
-cs,
-&input,
-&output,
-&params,
-(LOOKUP_MIN, LOOKUP_MAX),
-K,
-&LookupOp::LeakyReLU { slope: 0.0.into() },
-)
+.configure_range_check(cs, &input, &params, (-1, 1), K)
.unwrap();

+layer_config
+.configure_range_check(cs, &input, &params, (0, 1023), K)
+.unwrap();

// sets up a new ReLU table and resuses it for l1 and l3 non linearities
@@ -104,6 +103,11 @@ impl<const LEN: usize, const LOOKUP_MIN: IntegerRep, const LOOKUP_MAX: IntegerRe
) -> Result<(), Error> {
config.layer_config.layout_tables(&mut layouter).unwrap();

+config
+.layer_config
+.layout_range_checks(&mut layouter)
+.unwrap();

let x = layouter
.assign_region(
|| "mlp_4d",
@@ -144,7 +148,10 @@ impl<const LEN: usize, const LOOKUP_MIN: IntegerRep, const LOOKUP_MAX: IntegerRe
.layout(
&mut region,
&[x],
-Box::new(LookupOp::LeakyReLU { slope: 0.0.into() }),
+Box::new(PolyOp::LeakyReLU {
+scale: 1,
+slope: 0.0.into(),
+}),
)
.unwrap()
.unwrap();
@@ -184,7 +191,10 @@ impl<const LEN: usize, const LOOKUP_MIN: IntegerRep, const LOOKUP_MAX: IntegerRe
.layout(
&mut region,
&[x],
-Box::new(LookupOp::LeakyReLU { slope: 0.0.into() }),
+Box::new(PolyOp::LeakyReLU {
+scale: 1,
+slope: 0.0.into(),
+}),
)
.unwrap();
println!("6");
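Both examples migrate the same way: the ReLU lookup (`configure_lookup` with `LookupOp::LeakyReLU`) is replaced by two range checks plus the lookup-free `PolyOp::LeakyReLU`, and `layout_range_checks` is called next to `layout_tables`. A plausible reading of the two bounds, assuming they back the sign/decomposition logic the lookup-free ops rely on (the bounds are taken from the diff; the interpretation and the helper below are ours):

```rust
// Sketch: how range checks can stand in for a ReLU lookup table.
// Instead of tabulating every (x, relu(x)) pair, decompose x into a sign
// in {-1, 0, 1} (covered by the (-1, 1) check) and base-1024 limbs
// (each covered by the (0, 1023) check), then build relu arithmetically.
fn relu_via_sign(x: i64) -> i64 {
    let sign = x.signum(); // would be constrained by the (-1, 1) check
    let mut limbs = Vec::new();
    let mut v = x.abs();
    for _ in 0..2 {
        limbs.push(v % 1024); // each limb lands in the (0, 1023) range
        v /= 1024;
    }
    debug_assert!(limbs.iter().all(|l| (0..1024).contains(l)));
    // relu(x) = x * [sign == 1], mirroring the mask in layouts.rs
    x * i64::from(sign == 1)
}
```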
3 changes: 3 additions & 0 deletions src/circuit/ops/errors.rs
@@ -94,4 +94,7 @@ pub enum CircuitError {
#[error("[io] {0}")]
/// IO error
IoError(#[from] std::io::Error),
+/// Invalid scale
+#[error("negative scale for an op that requires positive inputs {0}")]
+NegativeScale(String),
}
44 changes: 39 additions & 5 deletions src/circuit/ops/layouts.rs
@@ -4305,7 +4305,6 @@ pub(crate) fn sign<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
) -> Result<ValTensor<F>, CircuitError> {
let mut decomp = decompose(config, region, values, &region.base(), &region.legs())?;
// get every n elements now, which correspond to the sign bit

decomp.get_every_n(region.legs() + 1)?;
decomp.reshape(values[0].dims())?;

@@ -4322,10 +4321,12 @@ pub(crate) fn abs<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
pairwise(config, region, &[values[0].clone(), sign], BaseOp::Mult)
}

-pub(crate) fn relu<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
+pub(crate) fn leaky_relu<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 1],
+alpha: &utils::F32,
+input_scale: &i32,
) -> Result<ValTensor<F>, CircuitError> {
let sign = sign(config, region, values)?;

@@ -4334,12 +4335,45 @@ pub(crate) fn relu<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(

let relu_mask = equals(config, region, &[sign, unit])?;

-pairwise(
+let positive = pairwise(
config,
region,
-&[values[0].clone(), relu_mask],
+&[values[0].clone(), relu_mask.clone()],
BaseOp::Mult,
-)
+)?;

+if alpha.0 == 0. {
+return Ok(positive);
+}

+if input_scale < &0 {
+return Err(CircuitError::NegativeScale("leaky_relu".to_string()));
+}

+let scale_constant = create_constant_tensor(F::from(2_i32.pow(*input_scale as u32) as u64), 1);

+let rescaled_positive = pairwise(config, region, &[positive, scale_constant], BaseOp::Mult)?;

+let neg_mask = not(config, region, &[relu_mask])?;

+let quantized_alpha = quantize_tensor(
+Tensor::from([alpha.0; 1].into_iter()),
+*input_scale,
+&crate::graph::Visibility::Fixed,
+)?;

+let alpha_tensor = create_constant_tensor(quantized_alpha[0], 1);

+let scaled_neg_mask = pairwise(config, region, &[neg_mask, alpha_tensor], BaseOp::Mult)?;

+let neg_part = pairwise(
+config,
+region,
+&[values[0].clone(), scaled_neg_mask],
+BaseOp::Mult,
+)?;

+pairwise(config, region, &[rescaled_positive, neg_part], BaseOp::Add)
}

fn multi_dim_axes_op<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
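The renamed `leaky_relu` keeps the old `relu` masking for the positive branch and, when the slope is nonzero, adds a negative branch: the positive part is rescaled by `2^input_scale`, the slope is quantized at that same scale, and the two branches are summed. A plain-Rust reference model of the intended integer semantics (a sketch mirroring the diff, not the circuit code; `leaky_relu_ref` is a hypothetical name):

```rust
// Reference model of the leaky ReLU arithmetic above (sketch only).
fn leaky_relu_ref(x: i64, slope: f32, input_scale: i32) -> Result<i64, String> {
    let positive = x.max(0);
    if slope == 0.0 {
        return Ok(positive); // early return: plain ReLU, no rescale
    }
    if input_scale < 0 {
        return Err("negative scale for an op that requires positive inputs".into());
    }
    let rescale = 1i64 << input_scale; // 2^input_scale
    // quantize the slope at the input scale, as quantize_tensor would
    let q_slope = (slope * rescale as f32).round() as i64;
    // both branches end up scaled by 2^input_scale relative to x
    Ok(positive * rescale + x.min(0) * q_slope)
}

fn main() {
    // slope 0.25 at scale 4: rescale = 16, quantized slope = 4
    assert_eq!(leaky_relu_ref(10, 0.25, 4), Ok(160));
    assert_eq!(leaky_relu_ref(-8, 0.25, 4), Ok(-32));
    assert_eq!(leaky_relu_ref(-8, 0.0, 4), Ok(0)); // ReLU path
}
```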
15 changes: 0 additions & 15 deletions src/circuit/ops/lookup.rs
@@ -43,9 +43,6 @@ pub enum LookupOp {
input_scale: utils::F32,
output_scale: utils::F32,
},
-LeakyReLU {
-slope: utils::F32,
-},
Sigmoid {
scale: utils::F32,
},
@@ -94,7 +91,6 @@ pub enum LookupOp {
Erf {
scale: utils::F32,
},
-KroneckerDelta,
Pow {
scale: utils::F32,
a: utils::F32,
@@ -120,14 +116,12 @@ impl LookupOp {
LookupOp::Round { scale } => format!("round_{}", scale),
LookupOp::RoundHalfToEven { scale } => format!("round_half_to_even_{}", scale),
LookupOp::Pow { scale, a } => format!("pow_{}_{}", scale, a),
-LookupOp::KroneckerDelta => "kronecker_delta".into(),
LookupOp::Div { denom } => format!("div_{}", denom),
LookupOp::Cast { scale } => format!("cast_{}", scale),
LookupOp::Recip {
input_scale,
output_scale,
} => format!("recip_{}_{}", input_scale, output_scale),
-LookupOp::LeakyReLU { slope: a } => format!("leaky_relu_{}", a),
LookupOp::Sigmoid { scale } => format!("sigmoid_{}", scale),
LookupOp::Sqrt { scale } => format!("sqrt_{}", scale),
LookupOp::Rsqrt { scale } => format!("rsqrt_{}", scale),
@@ -173,9 +167,6 @@ impl LookupOp {
LookupOp::Pow { scale, a } => Ok::<_, TensorError>(
tensor::ops::nonlinearities::pow(&x, scale.0.into(), a.0.into()),
),
-LookupOp::KroneckerDelta => {
-Ok::<_, TensorError>(tensor::ops::nonlinearities::kronecker_delta(&x))
-}
LookupOp::Div { denom } => Ok::<_, TensorError>(
tensor::ops::nonlinearities::const_div(&x, f32::from(*denom).into()),
),
@@ -190,9 +181,6 @@
input_scale.into(),
output_scale.into(),
)),
-LookupOp::LeakyReLU { slope: a } => {
-Ok::<_, TensorError>(tensor::ops::nonlinearities::leakyrelu(&x, a.0.into()))
-}
LookupOp::Sigmoid { scale } => {
Ok::<_, TensorError>(tensor::ops::nonlinearities::sigmoid(&x, scale.into()))
}
@@ -272,7 +260,6 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Lookup
LookupOp::Round { scale } => format!("ROUND(scale={})", scale),
LookupOp::RoundHalfToEven { scale } => format!("ROUND_HALF_TO_EVEN(scale={})", scale),
LookupOp::Pow { a, scale } => format!("POW(scale={}, exponent={})", scale, a),
-LookupOp::KroneckerDelta => "K_DELTA".into(),
LookupOp::Recip {
input_scale,
output_scale,
@@ -283,7 +270,6 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Lookup
LookupOp::Div { denom, .. } => format!("DIV(denom={})", denom),
LookupOp::Cast { scale } => format!("CAST(scale={})", scale),
LookupOp::Ln { scale } => format!("LN(scale={})", scale),
-LookupOp::LeakyReLU { slope: a } => format!("L_RELU(slope={})", a),
LookupOp::Sigmoid { scale } => format!("SIGMOID(scale={})", scale),
LookupOp::Sqrt { scale } => format!("SQRT(scale={})", scale),
LookupOp::Erf { scale } => format!("ERF(scale={})", scale),
@@ -327,7 +313,6 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Lookup
in_scale + multiplier_to_scale(1. / scale.0 as f64)
}
LookupOp::Recip { output_scale, .. } => multiplier_to_scale(output_scale.into()),
-LookupOp::KroneckerDelta => 0,
_ => inputs_scale[0],
};
Ok(scale)