From 0111ce4043851be5865ed88961f48e73aa924f43 Mon Sep 17 00:00:00 2001 From: dante <45801863+alexander-camuto@users.noreply.github.com> Date: Tue, 29 Oct 2024 11:56:50 -0400 Subject: [PATCH 1/2] chore: rm lookup recip --- examples/onnx/rsqrt/gen.py | 42 ++++++++++ examples/onnx/rsqrt/input.json | 1 + examples/onnx/rsqrt/network.onnx | 17 ++++ src/circuit/ops/hybrid.rs | 38 +++------ src/circuit/ops/lookup.rs | 129 +++++++------------------------ src/graph/utilities.rs | 16 ++-- tests/integration_tests.rs | 7 +- 7 files changed, 110 insertions(+), 140 deletions(-) create mode 100644 examples/onnx/rsqrt/gen.py create mode 100644 examples/onnx/rsqrt/input.json create mode 100644 examples/onnx/rsqrt/network.onnx diff --git a/examples/onnx/rsqrt/gen.py b/examples/onnx/rsqrt/gen.py new file mode 100644 index 000000000..c52776966 --- /dev/null +++ b/examples/onnx/rsqrt/gen.py @@ -0,0 +1,42 @@ +from torch import nn +import torch +import json +import numpy as np + + +class MyModel(nn.Module): + def __init__(self): + super(MyModel, self).__init__() + + def forward(self, x): + # reciprocal sqrt + m = 1 / torch.sqrt(x) + return m + + +circuit = MyModel() + +x = torch.empty(1, 8).uniform_(0, 1) + +out = circuit(x) + +print(out) + +torch.onnx.export(circuit, x, "network.onnx", + export_params=True, # store the trained parameter weights inside the model file + opset_version=17, # the ONNX version to export the model to + do_constant_folding=True, # whether to execute constant folding for optimization + input_names=['input'], # the model's input names + output_names=['output'], # the model's output names + dynamic_axes={'input': {0: 'batch_size'}, # variable length axes + 'output': {0: 'batch_size'}}) + + +d1 = ((x).detach().numpy()).reshape([-1]).tolist() + +data = dict( + input_data=[d1], +) + +# Serialize data into file: +json.dump(data, open("input.json", 'w')) diff --git a/examples/onnx/rsqrt/input.json b/examples/onnx/rsqrt/input.json new file mode 100644 index 000000000..da4b8fb68 --- /dev/null +++ b/examples/onnx/rsqrt/input.json @@ -0,0 +1 @@ +{"input_data": [[0.8590779900550842, 0.4029041528701782, 0.6507361531257629, 0.9782488942146301, 0.37392884492874146, 0.6867020726203918, 0.11407750844955444, 0.362740159034729]]} \ No newline at end of file diff --git a/examples/onnx/rsqrt/network.onnx b/examples/onnx/rsqrt/network.onnx new file mode 100644 index 000000000..b306e3c15 --- /dev/null +++ b/examples/onnx/rsqrt/network.onnx @@ -0,0 +1,17 @@ +pytorch2.2.2:¬ +$ +input/Sqrt_output_0/Sqrt"Sqrt +1 +/Sqrt_output_0output /Reciprocal" +Reciprocal +main_graphZ! 
+input + +  +batch_size +b" +output + +  +batch_size +B \ No newline at end of file diff --git a/src/circuit/ops/hybrid.rs b/src/circuit/ops/hybrid.rs index 4a081426d..a145e5997 100644 --- a/src/circuit/ops/hybrid.rs +++ b/src/circuit/ops/hybrid.rs @@ -16,7 +16,6 @@ pub enum HybridOp { Recip { input_scale: utils::F32, output_scale: utils::F32, - use_range_check_for_int: bool, }, Div { denom: utils::F32, @@ -102,10 +101,9 @@ impl Op for Hybrid HybridOp::Recip { input_scale, output_scale, - use_range_check_for_int, } => format!( - "RECIP (input_scale={}, output_scale={}, use_range_check_for_int={})", - input_scale, output_scale, use_range_check_for_int + "RECIP (input_scale={}, output_scale={})", + input_scale, output_scale ), HybridOp::Div { denom, @@ -187,31 +185,13 @@ impl Op for Hybrid HybridOp::Recip { input_scale, output_scale, - use_range_check_for_int, - } => { - if input_scale.0.fract() == 0.0 - && output_scale.0.fract() == 0.0 - && *use_range_check_for_int - { - layouts::recip( - config, - region, - values[..].try_into()?, - integer_rep_to_felt(input_scale.0 as i128), - integer_rep_to_felt(output_scale.0 as i128), - )? - } else { - layouts::nonlinearity( - config, - region, - values.try_into()?, - &LookupOp::Recip { - input_scale: *input_scale, - output_scale: *output_scale, - }, - )? - } - } + } => layouts::recip( + config, + region, + values[..].try_into()?, + integer_rep_to_felt(input_scale.0 as i128), + integer_rep_to_felt(output_scale.0 as i128), + )?, HybridOp::Div { denom, use_range_check_for_int, diff --git a/src/circuit/ops/lookup.rs b/src/circuit/ops/lookup.rs index 73caf81c3..825eb0806 100644 --- a/src/circuit/ops/lookup.rs +++ b/src/circuit/ops/lookup.rs @@ -15,89 +15,32 @@ use halo2curves::ff::PrimeField; /// An enum representing the operations that can be used to express more complex operations via accumulation #[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, Deserialize, Serialize)] pub enum LookupOp { - Div { - denom: utils::F32, - }, - Cast { - scale: utils::F32, - }, - Ceil { - scale: utils::F32, - }, - Floor { - scale: utils::F32, - }, - Round { - scale: utils::F32, - }, - RoundHalfToEven { - scale: utils::F32, - }, - Sqrt { - scale: utils::F32, - }, - Rsqrt { - scale: utils::F32, - }, - Recip { - input_scale: utils::F32, - output_scale: utils::F32, - }, - Sigmoid { - scale: utils::F32, - }, - Ln { - scale: utils::F32, - }, - Exp { - scale: utils::F32, - }, - Cos { - scale: utils::F32, - }, - ACos { - scale: utils::F32, - }, - Cosh { - scale: utils::F32, - }, - ACosh { - scale: utils::F32, - }, - Sin { - scale: utils::F32, - }, - ASin { - scale: utils::F32, - }, - Sinh { - scale: utils::F32, - }, - ASinh { - scale: utils::F32, - }, - Tan { - scale: utils::F32, - }, - ATan { - scale: utils::F32, - }, - Tanh { - scale: utils::F32, - }, - ATanh { - scale: utils::F32, - }, - Erf { - scale: utils::F32, - }, - Pow { - scale: utils::F32, - a: utils::F32, - }, - HardSwish { - scale: utils::F32, - }, + Div { denom: utils::F32 }, + Cast { scale: utils::F32 }, + Ceil { scale: utils::F32 }, + Floor { scale: utils::F32 }, + Round { scale: utils::F32 }, + RoundHalfToEven { scale: utils::F32 }, + Sqrt { scale: utils::F32 }, + Rsqrt { scale: utils::F32 }, + Sigmoid { scale: utils::F32 }, + Ln { scale: utils::F32 }, + Exp { scale: utils::F32 }, + Cos { scale: utils::F32 }, + ACos { scale: utils::F32 }, + Cosh { scale: utils::F32 }, + ACosh { scale: utils::F32 }, + Sin { scale: utils::F32 }, + ASin { scale: utils::F32 }, + Sinh { scale: utils::F32 }, + ASinh { scale: 
utils::F32 }, + Tan { scale: utils::F32 }, + ATan { scale: utils::F32 }, + Tanh { scale: utils::F32 }, + ATanh { scale: utils::F32 }, + Erf { scale: utils::F32 }, + Pow { scale: utils::F32, a: utils::F32 }, + HardSwish { scale: utils::F32 }, } impl LookupOp { @@ -118,10 +61,6 @@ impl LookupOp { LookupOp::Pow { scale, a } => format!("pow_{}_{}", scale, a), LookupOp::Div { denom } => format!("div_{}", denom), LookupOp::Cast { scale } => format!("cast_{}", scale), - LookupOp::Recip { - input_scale, - output_scale, - } => format!("recip_{}_{}", input_scale, output_scale), LookupOp::Sigmoid { scale } => format!("sigmoid_{}", scale), LookupOp::Sqrt { scale } => format!("sqrt_{}", scale), LookupOp::Rsqrt { scale } => format!("rsqrt_{}", scale), @@ -173,14 +112,6 @@ impl LookupOp { LookupOp::Cast { scale } => Ok::<_, TensorError>( tensor::ops::nonlinearities::const_div(&x, f32::from(*scale).into()), ), - LookupOp::Recip { - input_scale, - output_scale, - } => Ok::<_, TensorError>(tensor::ops::nonlinearities::recip( - &x, - input_scale.into(), - output_scale.into(), - )), LookupOp::Sigmoid { scale } => { Ok::<_, TensorError>(tensor::ops::nonlinearities::sigmoid(&x, scale.into())) } @@ -260,13 +191,6 @@ impl Op for Lookup LookupOp::Round { scale } => format!("ROUND(scale={})", scale), LookupOp::RoundHalfToEven { scale } => format!("ROUND_HALF_TO_EVEN(scale={})", scale), LookupOp::Pow { a, scale } => format!("POW(scale={}, exponent={})", scale, a), - LookupOp::Recip { - input_scale, - output_scale, - } => format!( - "RECIP(input_scale={}, output_scale={})", - input_scale, output_scale - ), LookupOp::Div { denom, .. } => format!("DIV(denom={})", denom), LookupOp::Cast { scale } => format!("CAST(scale={})", scale), LookupOp::Ln { scale } => format!("LN(scale={})", scale), @@ -312,7 +236,6 @@ impl Op for Lookup let in_scale = inputs_scale[0]; in_scale + multiplier_to_scale(1. / scale.0 as f64) } - LookupOp::Recip { output_scale, .. 
} => multiplier_to_scale(output_scale.into()), _ => inputs_scale[0], }; Ok(scale) diff --git a/src/graph/utilities.rs b/src/graph/utilities.rs index 8aa9de6ab..aa6b73f8e 100644 --- a/src/graph/utilities.rs +++ b/src/graph/utilities.rs @@ -809,7 +809,6 @@ pub fn new_op_from_onnx( SupportedOp::Hybrid(HybridOp::Recip { input_scale: (scale_to_multiplier(in_scale) as f32).into(), output_scale: (scale_to_multiplier(max_scale) as f32).into(), - use_range_check_for_int: true, }) } @@ -1107,10 +1106,17 @@ pub fn new_op_from_onnx( if c.raw_values.len() > 1 { unimplemented!("only support scalar pow") } - SupportedOp::Nonlinear(LookupOp::Pow { - scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(), - a: crate::circuit::utils::F32(c.raw_values[0]), - }) + + let exponent = c.raw_values[0]; + + if exponent.fract() == 0.0 { + SupportedOp::Linear(PolyOp::Pow(exponent as u32)) + } else { + SupportedOp::Nonlinear(LookupOp::Pow { + scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(), + a: crate::circuit::utils::F32(exponent), + }) + } } else { unimplemented!("only support constant pow for now") } diff --git a/tests/integration_tests.rs b/tests/integration_tests.rs index ff7752dbe..ac85672bc 100644 --- a/tests/integration_tests.rs +++ b/tests/integration_tests.rs @@ -205,7 +205,7 @@ mod native_tests { "1l_tiny_div", ]; - const TESTS: [&str; 94] = [ + const TESTS: [&str; 95] = [ "1l_mlp", //0 "1l_slice", "1l_concat", @@ -304,6 +304,7 @@ mod native_tests { "lstm_large", // 91 "lstm_medium", // 92 "lenet_5", // 93 + "rqsrt", // 94 ]; const WASM_TESTS: [&str; 46] = [ @@ -542,7 +543,7 @@ mod native_tests { } }); - seq!(N in 0..=93 { + seq!(N in 0..=94 { #(#[test_case(TESTS[N])])* #[ignore] @@ -1118,7 +1119,7 @@ mod native_tests { }); - seq!(N in 0..=93 { + seq!(N in 0..4 { #(#[test_case(TESTS[N])])* fn kzg_evm_prove_and_verify_reusable_verifier_(test: &str) { crate::native_tests::init_binary(); From 1dd9ef4d5bde342f96eac64a523a7a45e3116b14 Mon Sep 17 00:00:00 2001 From: dante <45801863+alexander-camuto@users.noreply.github.com> Date: Tue, 29 Oct 2024 12:41:34 -0400 Subject: [PATCH 2/2] Update integration_tests.rs --- tests/integration_tests.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/integration_tests.rs b/tests/integration_tests.rs index ac85672bc..ecfb9daab 100644 --- a/tests/integration_tests.rs +++ b/tests/integration_tests.rs @@ -304,7 +304,7 @@ mod native_tests { "lstm_large", // 91 "lstm_medium", // 92 "lenet_5", // 93 - "rqsrt", // 94 + "rsqrt", // 94 ]; const WASM_TESTS: [&str; 46] = [
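
Note on the RECIP change in patch 1: with `LookupOp::Recip` and the `use_range_check_for_int` flag gone, `layouts::recip` is the only reciprocal path, and it verifies a prover-supplied witness with a range check on a residual rather than a table lookup. Below is a minimal out-of-circuit sketch of that relation in plain Rust; the rounding rule and the tolerance are assumptions for illustration, and the actual in-circuit constraint lives in `src/circuit/ops/layouts.rs`.

// Sketch: range-check-style verification of a fixed-point reciprocal.
// `claimed_recip` stands in for the prover's witness generation; the
// assert stands in for the in-circuit range check on the residual.
fn claimed_recip(x: i128, input_scale: i128, output_scale: i128) -> i128 {
    // Round-to-nearest fixed-point division (positive x assumed).
    let num = input_scale * output_scale;
    (num + x / 2) / x
}

fn main() {
    let (input_scale, output_scale) = (128_i128, 128_i128);
    let x = 100_i128; // fixed-point encoding of 100/128, i.e. ~0.78

    let y = claimed_recip(x, input_scale, output_scale);

    // A correctly rounded witness satisfies
    //   |x * y - input_scale * output_scale| <= x/2 + 1,
    // so a small range check replaces a full lookup table.
    let residual = (x * y - input_scale * output_scale).abs();
    assert!(residual <= x / 2 + 1, "reciprocal witness out of tolerance");
    println!("x = {x}, witness = {y}, residual = {residual}");
}

Because the check is arithmetic rather than tabular, it works at any input and output scale without growing a lookup table, which is presumably why the `use_range_check_for_int` escape hatch could be deleted outright.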
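
The `Pow` hunk in `src/graph/utilities.rs` makes a similar trade: a constant integer exponent is now lowered to `SupportedOp::Linear(PolyOp::Pow(exponent as u32))`, i.e. plain in-circuit multiplication that is exact at any scale, while only fractional exponents still pay for a `LookupOp::Pow` table. Here is a hedged sketch of a multiplication-based lowering in plain Rust; the function name and the square-and-multiply schedule are illustrative, not necessarily how `PolyOp::Pow` is actually laid out:

// Exponentiation by squaring: O(log exp) multiplications, the usual
// way to keep the multiplication count small when unrolling an
// integer power. Non-negative exponents assumed, matching the
// `exponent as u32` cast in the patch.
fn pow_by_mul(base: i128, exp: u32) -> i128 {
    let (mut acc, mut b, mut e) = (1_i128, base, exp);
    while e > 0 {
        if e & 1 == 1 {
            acc *= b;
        }
        b = b.saturating_mul(b); // avoid overflow on the final unused square
        e >>= 1;
    }
    acc
}

fn main() {
    assert_eq!(pow_by_mul(3, 4), 81);
    assert_eq!(pow_by_mul(-2, 5), -32); // sign falls out of odd exponents
    println!("3^4 = {}", pow_by_mul(3, 4));
}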