From e5aa48fbd6e0a33825bd27dfaedf86c9c5fe509c Mon Sep 17 00:00:00 2001
From: dante <45801863+alexander-camuto@users.noreply.github.com>
Date: Sat, 5 Oct 2024 10:43:12 -0400
Subject: [PATCH] chore: support all padding types (#848)

---
 Cargo.lock                         | 18 +++++--------
 examples/onnx/boolean/gen.py       |  4 ++-
 examples/onnx/boolean/input.json   |  2 +-
 examples/onnx/boolean/network.onnx | 26 ++++++++----------
 src/graph/node.rs                  | 14 ++++++++++
 src/graph/utilities.rs             | 43 +++++++++++++++---------------
 6 files changed, 57 insertions(+), 50 deletions(-)

diff --git a/Cargo.lock b/Cargo.lock
index 1ac569874..0fdf7eb27 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -2950,11 +2950,11 @@ dependencies = [
 
 [[package]]
 name = "lazy_static"
-version = "1.4.0"
+version = "1.5.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646"
+checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe"
 dependencies = [
- "spin 0.5.2",
+ "spin",
 ]
 
 [[package]]
@@ -3532,9 +3532,9 @@ dependencies = [
 
 [[package]]
 name = "parking_lot"
-version = "0.12.1"
+version = "0.12.3"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "3742b2c103b9f06bc9fff0a37ff4912935851bee6d36f3c02bcc755bcfec228f"
+checksum = "f1bf18183cf54e8d6059647fc3063646a1801cf30896933ec2311622cc4b9a27"
 dependencies = [
  "lock_api",
  "parking_lot_core",
@@ -4426,7 +4426,7 @@ dependencies = [
  "cfg-if",
  "getrandom",
  "libc",
- "spin 0.9.8",
+ "spin",
  "untrusted",
  "windows-sys 0.52.0",
 ]
@@ -4980,12 +4980,6 @@ dependencies = [
  "unicode-xid",
 ]
 
-[[package]]
-name = "spin"
-version = "0.5.2"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "6e63cff320ae2c57904679ba7cb63280a3dc4613885beafb148ee7bf9aa9042d"
-
 [[package]]
 name = "spin"
 version = "0.9.8"
diff --git a/examples/onnx/boolean/gen.py b/examples/onnx/boolean/gen.py
index 0a434ccf4..c6564fd87 100644
--- a/examples/onnx/boolean/gen.py
+++ b/examples/onnx/boolean/gen.py
@@ -9,7 +9,9 @@ def __init__(self):
         super(MyModel, self).__init__()
 
     def forward(self, w, x, y, z):
-        return [((x & y)) == (x & (y | (z ^ w)))]
+        a = (x & y)
+        b = (y & (z ^ w))
+        return [a & b]
 
 
 circuit = MyModel()
diff --git a/examples/onnx/boolean/input.json b/examples/onnx/boolean/input.json
index f4c5b9b85..18d044f3c 100644
--- a/examples/onnx/boolean/input.json
+++ b/examples/onnx/boolean/input.json
@@ -1 +1 @@
-{"input_data": [[false, true, false], [true, false, false], [true, false, false], [false, false, false]]}
\ No newline at end of file
+{"input_data": [[false, true, true], [false, true, true], [true, false, false], [false, true, true]]}
\ No newline at end of file
diff --git a/examples/onnx/boolean/network.onnx b/examples/onnx/boolean/network.onnx
index b3b9a95e8..c16e16602 100644
Binary files a/examples/onnx/boolean/network.onnx and b/examples/onnx/boolean/network.onnx differ
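The regenerated network.onnx is a binary protobuf; from its readable residue it was re-exported with PyTorch 2.2.2 (previously 1.12.1) and now lowers the new forward pass to And/Xor/And/And nodes under a main_graph name. As a sanity check on the example data, the sketch below (not part of the patch, written in Rust for consistency with the other sketches here) mirrors gen.py's new forward pass; mapping the four arrays in input.json onto (w, x, y, z) in argument order is an assumption.

    // Mirrors MyModel.forward from gen.py: a = (x & y), b = (y & (z ^ w)), out = a & b.
    fn boolean_model(w: &[bool], x: &[bool], y: &[bool], z: &[bool]) -> Vec<bool> {
        w.iter()
            .zip(x)
            .zip(y)
            .zip(z)
            .map(|(((w, x), y), z)| {
                let a = x & y;       // a = (x & y)
                let b = y & (z ^ w); // b = (y & (z ^ w))
                a & b                // return [a & b]
            })
            .collect()
    }

    fn main() {
        // The four vectors from the updated examples/onnx/boolean/input.json,
        // assumed to bind to forward's arguments in order.
        let (w, x) = ([false, true, true], [false, true, true]);
        let (y, z) = ([true, false, false], [false, true, true]);
        println!("{:?}", boolean_model(&w, &x, &y, &z));
    }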
diff --git a/src/graph/node.rs b/src/graph/node.rs
index 539f10d79..a46654752 100644
--- a/src/graph/node.rs
+++ b/src/graph/node.rs
@@ -125,6 +125,7 @@ impl RebaseScale {
         if (op_out_scale > (global_scale * scale_rebase_multiplier as i32))
             && !inner.is_constant()
             && !inner.is_input()
+            && !inner.is_identity()
         {
             let multiplier =
                 scale_to_multiplier(op_out_scale - global_scale * scale_rebase_multiplier as i32);
@@ -326,6 +327,19 @@ impl SupportedOp {
             SupportedOp::RebaseScale(op) => op,
         }
     }
+
+    /// Check whether this is the identity operation.
+    /// # Returns
+    /// * `true` if the operation is the identity operation
+    /// * `false` otherwise
+    pub fn is_identity(&self) -> bool {
+        match self {
+            SupportedOp::Linear(op) => matches!(op, PolyOp::Identity { .. }),
+            SupportedOp::Rescaled(op) => op.inner.is_identity(),
+            SupportedOp::RebaseScale(op) => op.inner.is_identity(),
+            _ => false,
+        }
+    }
 }
 
 impl From<Box<dyn Op<Fp>>> for SupportedOp {
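For intuition on the guard above: RebaseScale wraps an op whose output scale exceeds the global target and divides the result back down. Identity ops already carry an explicit out_scale (see the utilities.rs hunks below), so rebasing them would only insert a useless division; constants, inputs, and now identities are therefore skipped. The standalone sketch below shows the power-of-two scale arithmetic involved; scale_to_multiplier and multiplier_to_scale are simplified stand-ins for the versions in src/graph/utilities.rs, not the library's exact code.

    // Simplified stand-ins for the fixed-point helpers in src/graph/utilities.rs:
    // a "scale" is a power-of-two exponent over the quantized values.
    fn scale_to_multiplier(scale: i32) -> f64 {
        2f64.powi(scale)
    }

    fn multiplier_to_scale(mult: f64) -> i32 {
        mult.log2().round() as i32
    }

    fn main() {
        let global_scale: i32 = 7;
        let scale_rebase_multiplier: u32 = 1;
        // e.g. a conv output: input scale + kernel scale = 7 + 7.
        let op_out_scale: i32 = 14;

        // The same comparison RebaseScale performs before wrapping an op.
        if op_out_scale > global_scale * scale_rebase_multiplier as i32 {
            let multiplier =
                scale_to_multiplier(op_out_scale - global_scale * scale_rebase_multiplier as i32);
            // The wrapped op's result is divided by 128 (= 2^7) to land back on
            // the global scale; pointless for an Identity that already declares
            // its own out_scale, hence the new !inner.is_identity() guard.
            assert_eq!(multiplier, 128.0);
            assert_eq!(multiplier_to_scale(multiplier), 7);
        }
    }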
diff --git a/src/graph/utilities.rs b/src/graph/utilities.rs
index 75b8610c5..8bb2eb193 100644
--- a/src/graph/utilities.rs
+++ b/src/graph/utilities.rs
@@ -41,7 +41,7 @@ use tract_onnx::tract_hir::{
     ops::konst::Const,
     ops::nn::DataFormat,
     tract_core::ops::cast::Cast,
-    tract_core::ops::cnn::{conv::KernelFormat, MaxPool, PaddingSpec, SumPool},
+    tract_core::ops::cnn::{conv::KernelFormat, MaxPool, SumPool},
 };
 
 /// Quantizes an iterable of f32s to a [Tensor] of i32s using a fixed point representation.
@@ -94,17 +94,18 @@ pub fn multiplier_to_scale(mult: f64) -> crate::Scale {
 /// extract padding from a onnx node.
 pub fn extract_padding(
     pool_spec: &PoolSpec,
-    num_dims: usize,
+    image_size: &[usize],
 ) -> Result<Vec<(usize, usize)>, GraphError> {
-    let padding = match &pool_spec.padding {
-        PaddingSpec::Explicit(b, a) | PaddingSpec::ExplicitOnnxPool(b, a, _) => {
-            b.iter().zip(a.iter()).map(|(b, a)| (*b, *a)).collect()
-        }
-        PaddingSpec::Valid => vec![(0, 0); num_dims],
-        _ => {
-            return Err(GraphError::MissingParams("padding".to_string()));
-        }
-    };
+    let num_relevant_dims = pool_spec.kernel_shape.len();
+
+    // get the last num_relevant_dims of the image size
+    let image_size = &image_size[image_size.len() - num_relevant_dims..];
+
+    let dims = pool_spec.computed_padding(image_size);
+    let mut padding = Vec::new();
+    for dim in dims {
+        padding.push((dim.pad_before, dim.pad_after));
+    }
     Ok(padding)
 }
@@ -1016,8 +1017,13 @@ pub fn new_op_from_onnx(
                 if raw_values.log2().fract() == 0.0 {
                     inputs[const_idx].decrement_use();
                     deleted_indices.push(const_idx);
+                    // get the non constant index
+                    let non_const_idx = if const_idx == 0 { 1 } else { 0 };
+
                     op = SupportedOp::Linear(PolyOp::Identity {
-                        out_scale: Some(input_scales[0] + raw_values.log2() as i32),
+                        out_scale: Some(
+                            input_scales[non_const_idx] + raw_values.log2() as i32,
+                        ),
                     });
                 }
             }
@@ -1108,7 +1114,7 @@ pub fn new_op_from_onnx(
             }
 
             let stride = extract_strides(pool_spec)?;
-            let padding = extract_padding(pool_spec, input_dims[0].len())?;
+            let padding = extract_padding(pool_spec, &input_dims[0])?;
             let kernel_shape = &pool_spec.kernel_shape;
 
             SupportedOp::Hybrid(HybridOp::MaxPool {
@@ -1178,7 +1184,7 @@ pub fn new_op_from_onnx(
             let pool_spec = &conv_node.pool_spec;
 
             let stride = extract_strides(pool_spec)?;
-            let padding = extract_padding(pool_spec, input_dims[0].len())?;
+            let padding = extract_padding(pool_spec, &input_dims[0])?;
 
             // if bias exists then rescale it to the input + kernel scale
             if input_scales.len() == 3 {
@@ -1236,7 +1242,7 @@ pub fn new_op_from_onnx(
             let pool_spec = &deconv_node.pool_spec;
             let stride =
                 extract_strides(pool_spec)?;
-            let padding = extract_padding(pool_spec, input_dims[0].len())?;
+            let padding = extract_padding(pool_spec, &input_dims[0])?;
             // if bias exists then rescale it to the input + kernel scale
             if input_scales.len() == 3 {
                 let bias_scale = input_scales[2];
@@ -1349,7 +1355,7 @@ pub fn new_op_from_onnx(
             }
 
             let stride = extract_strides(pool_spec)?;
-            let padding = extract_padding(pool_spec, input_dims[0].len())?;
+            let padding = extract_padding(pool_spec, &input_dims[0])?;
 
             SupportedOp::Hybrid(HybridOp::SumPool {
                 padding,
                 stride,
                 kernel_shape: pool_spec.kernel_shape.to_vec(),
                 normalized: sumpool_node.normalize,
             })
         }
-        // "GlobalAvgPool" => SupportedOp::Linear(PolyOp::SumPool {
-        //     padding: [(0, 0); 2],
-        //     stride: (1, 1),
-        //     kernel_shape: (inputs[0].out_dims()[0][1], inputs[0].out_dims()[0][2]),
-        // }),
         "Pad" => {
            let pad_node: &Pad = match node.op().downcast_ref::<Pad>() {
                Some(b) => b,
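The substance of the change: extract_padding no longer pattern-matches PaddingSpec (which errored on anything but Explicit/Valid, i.e. on the SAME-style variants) and instead asks tract to resolve the spec against the actual spatial dimensions via pool_spec.computed_padding, which is why every caller now passes &input_dims[0] instead of a bare rank. The sketch below illustrates, under the standard ONNX SAME_UPPER convention rather than tract's exact internals, how "same"-style padding resolves to the explicit (pad_before, pad_after) pairs the circuit ops consume.

    // Illustrative only: the standard SAME_UPPER resolution of one spatial
    // dimension into explicit padding, analogous to the ComputedPaddedDim
    // values (pad_before, pad_after) that PoolSpec::computed_padding returns.
    fn same_upper_padding(input: usize, kernel: usize, stride: usize, dilation: usize) -> (usize, usize) {
        let dilated_kernel = dilation * (kernel - 1) + 1;
        let output = input.div_ceil(stride); // SAME keeps ceil(input / stride) outputs
        let total = ((output - 1) * stride + dilated_kernel).saturating_sub(input);
        let before = total / 2;  // SAME_UPPER puts the odd leftover unit after;
        (before, total - before) // SAME_LOWER would put it before
    }

    fn main() {
        // A 3x3 kernel with stride 2 over a 1x1x28x28 image: only the last
        // kernel_shape.len() dims matter, mirroring the slice taken at the
        // top of the new extract_padding.
        let image_size = [1usize, 1, 28, 28];
        let spatial = &image_size[image_size.len() - 2..];
        let padding: Vec<(usize, usize)> = spatial
            .iter()
            .map(|&dim| same_upper_padding(dim, 3, 2, 1))
            .collect();
        println!("{padding:?}"); // [(0, 1), (0, 1)]
    }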