From 971f39afa22f28dd062a5a010e31ccbf3350123b Mon Sep 17 00:00:00 2001
From: Soren Lassen
Date: Mon, 11 Sep 2023 22:09:19 -0700
Subject: [PATCH] support optional outputs in function decomposition (#2493)

The importer looked up every output name that a decomposed function
declares, but optional outputs that the function body never produces are
not mapped in the symbol table. Skip unmapped names instead of looking
them up unconditionally.

Signed-off-by: Soren Lassen
---
 src/Builder/FrontendDialectTransformer.cpp    |  5 +-
 ...malization_function_decomposition.onnxtext | 52 +++++++++++++++++++
 2 files changed, 56 insertions(+), 1 deletion(-)
 create mode 100644 test/mlir/onnx/parse/layer_normalization_function_decomposition.onnxtext

diff --git a/src/Builder/FrontendDialectTransformer.cpp b/src/Builder/FrontendDialectTransformer.cpp
index 004df5abc9..ab51de133e 100644
--- a/src/Builder/FrontendDialectTransformer.cpp
+++ b/src/Builder/FrontendDialectTransformer.cpp
@@ -1235,7 +1235,10 @@ class FrontendGenImpl {
     }
 
     for (auto &name : functionProto.output()) {
-      outputs.push_back(LookupOnnxName(name));
+      // Skip missing optional outputs: they are not mapped.
+      if (const Value *valuePtr = frontend_symbols_.GetByOnnxName(name)) {
+        outputs.push_back(*valuePtr);
+      }
     }
 
     frontend_symbols_.popScope(scopeName);
diff --git a/test/mlir/onnx/parse/layer_normalization_function_decomposition.onnxtext b/test/mlir/onnx/parse/layer_normalization_function_decomposition.onnxtext
new file mode 100644
index 0000000000..29a9831c08
--- /dev/null
+++ b/test/mlir/onnx/parse/layer_normalization_function_decomposition.onnxtext
@@ -0,0 +1,52 @@
+// RUN: onnx-mlir --functions-to-decompose=LayerNormalization --EmitONNXBasic --printIR %s | FileCheck %s
+
+// from onnx-mlir issue #2492
+<
+   ir_version: 8,
+   opset_import: ["" : 17]
+>
+agraph (float[12,3,5] X, float[5] S) => (float[12,3,5] LN) {
+   LN = LayerNormalization (X, S)
+}
+// CHECK-LABEL: func.func @main_graph
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<12x3x5xf32>, [[PARAM_1_:%.+]]: tensor<5xf32>) -> tensor<12x3x5xf32> attributes {input_names = ["X", "S"], output_names = ["LN"]} {
+// CHECK-DAG: [[VAR_0_:%.+]] = "onnx.NoValue"() {value} : () -> none
+// CHECK-DAG: [[VAR_1_:%.+]] = onnx.Constant dense<9.99999974E-6> : tensor<f32>
+// CHECK-NOT: separator of consecutive DAGs
+// CHECK-DAG: [[VAR_2_:%.+]] = "onnx.Cast"([[VAR_1_]]) {saturate = 1 : si64, to = f32} : (tensor<f32>) -> tensor<f32>
+// CHECK-DAG: [[VAR_3_:%.+]] = "onnx.Shape"([[PARAM_0_]]) {start = 0 : si64} : (tensor<12x3x5xf32>) -> tensor<3xi64>
+// CHECK-NOT: separator of consecutive DAGs
+// CHECK-DAG: [[VAR_4_:%.+]] = "onnx.Size"([[VAR_3_]]) : (tensor<3xi64>) -> tensor<i64>
+// CHECK-DAG: [[VAR_5_:%.+]] = onnx.Constant dense<0> : tensor<1xi64>
+// CHECK-DAG: [[VAR_6_:%.+]] = onnx.Constant dense<-1> : tensor<1xi64>
+// CHECK-DAG: [[VAR_7_:%.+]] = "onnx.NoValue"() {value} : () -> none
+// CHECK-DAG: [[VAR_8_:%.+]] = "onnx.NoValue"() {value} : () -> none
+// CHECK-NOT: separator of consecutive DAGs
+// CHECK-DAG: [[VAR_9_:%.+]] = "onnx.Slice"([[VAR_3_]], [[VAR_5_]], [[VAR_6_]], [[VAR_7_]], [[VAR_8_]]) : (tensor<3xi64>, tensor<1xi64>, tensor<1xi64>, none, none) -> tensor<2xi64>
+// CHECK-DAG: [[VAR_10_:%.+]] = "onnx.Neg"([[VAR_6_]]) : (tensor<1xi64>) -> tensor<1xi64>
+// CHECK: [[VAR_11_:%.+]] = onnx.ConstantOfShape([[VAR_10_]]) {value = dense<1> : tensor<1xi64>} : (tensor<1xi64>) -> tensor<?xi64>
+// CHECK-DAG: [[VAR_12_:%.+]] = "onnx.Concat"([[VAR_9_]], [[VAR_11_]]) {axis = 0 : si64} : (tensor<2xi64>, tensor<?xi64>) -> tensor<?xi64>
+// CHECK-DAG: [[VAR_13_:%.+]] = "onnx.Flatten"([[PARAM_0_]]) {axis = -1 : si64} : (tensor<12x3x5xf32>) -> tensor<36x5xf32>
+// CHECK-NOT: separator of consecutive DAGs
+// CHECK-DAG: [[VAR_14_:%.+]] = "onnx.Cast"([[VAR_13_]]) {saturate = 1 : si64, to = f32} : (tensor<36x5xf32>) -> tensor<36x5xf32>
+// CHECK-DAG: [[VAR_15_:%.+]] = "onnx.NoValue"() {value} : () -> none
+// CHECK-NOT: separator of consecutive DAGs
+// CHECK-DAG: [[VAR_16_:%.+]] = "onnx.ReduceMean"([[VAR_14_]], [[VAR_15_]]) {axes = [1], keepdims = 1 : si64, noop_with_empty_axes = 0 : si64} : (tensor<36x5xf32>, none) -> tensor<36x1xf32>
+// CHECK-DAG: [[VAR_17_:%.+]] = "onnx.Mul"([[VAR_14_]], [[VAR_14_]]) : (tensor<36x5xf32>, tensor<36x5xf32>) -> tensor<36x5xf32>
+// CHECK-DAG: [[VAR_18_:%.+]] = "onnx.NoValue"() {value} : () -> none
+// CHECK-NOT: separator of consecutive DAGs
+// CHECK-DAG: [[VAR_19_:%.+]] = "onnx.ReduceMean"([[VAR_17_]], [[VAR_18_]]) {axes = [1], keepdims = 1 : si64, noop_with_empty_axes = 0 : si64} : (tensor<36x5xf32>, none) -> tensor<36x1xf32>
+// CHECK-DAG: [[VAR_20_:%.+]] = "onnx.Mul"([[VAR_16_]], [[VAR_16_]]) : (tensor<36x1xf32>, tensor<36x1xf32>) -> tensor<36x1xf32>
+// CHECK: [[VAR_21_:%.+]] = "onnx.Sub"([[VAR_19_]], [[VAR_20_]]) : (tensor<36x1xf32>, tensor<36x1xf32>) -> tensor<36x1xf32>
+// CHECK: [[VAR_22_:%.+]] = "onnx.Add"([[VAR_21_]], [[VAR_2_]]) : (tensor<36x1xf32>, tensor<f32>) -> tensor<36x1xf32>
+// CHECK-DAG: [[VAR_23_:%.+]] = "onnx.Sqrt"([[VAR_22_]]) : (tensor<36x1xf32>) -> tensor<36x1xf32>
+// CHECK-DAG: [[VAR_24_:%.+]] = "onnx.Sub"([[VAR_14_]], [[VAR_16_]]) : (tensor<36x5xf32>, tensor<36x1xf32>) -> tensor<36x5xf32>
+// CHECK: [[VAR_25_:%.+]] = "onnx.Div"([[VAR_24_]], [[VAR_23_]]) : (tensor<36x5xf32>, tensor<36x1xf32>) -> tensor<36x5xf32>
+// CHECK-DAG: [[VAR_26_:%.+]] = "onnx.Cast"([[VAR_25_]]) {saturate = 1 : si64, to = f32} : (tensor<36x5xf32>) -> tensor<36x5xf32>
+// CHECK-DAG: [[VAR_27_:%.+]] = "onnx.Flatten"([[PARAM_1_]]) {axis = 0 : si64} : (tensor<5xf32>) -> tensor<1x5xf32>
+// CHECK: [[VAR_28_:%.+]] = "onnx.Mul"([[VAR_26_]], [[VAR_27_]]) : (tensor<36x5xf32>, tensor<1x5xf32>) -> tensor<36x5xf32>
+// CHECK: [[VAR_29_:%.+]] = "onnx.Identity"([[VAR_28_]]) : (tensor<36x5xf32>) -> tensor<36x5xf32>
+// CHECK-DAG: [[VAR_30_:%.+]] = "onnx.Reshape"([[VAR_29_]], [[VAR_3_]]) {allowzero = 0 : si64} : (tensor<36x5xf32>, tensor<3xi64>) -> tensor<12x3x5xf32>
+// CHECK-DAG: [[VAR_31_:%.+]] = "onnx.Reciprocal"([[VAR_23_]]) : (tensor<36x1xf32>) -> tensor<36x1xf32>
+// CHECK: onnx.Return [[VAR_30_]] : tensor<12x3x5xf32>
+// CHECK: }
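
Note (illustration, not part of the patch): the fix rests on a lookup that may
fail, and the test exercises exactly that case: the decomposed body computes
[[VAR_31_]] (the InvStdDev output) but onnx.Return only returns Y. Below is a
minimal, self-contained C++ sketch of the lookup-may-fail contract; the
SymbolTable and Value types and the output names are hypothetical stand-ins,
not the real FrontendDialectTransformer types.

#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>

// Stand-in for mlir::Value; a name is enough for this sketch.
using Value = std::string;

// Map-backed symbol table keyed by ONNX value name. The real importer's
// frontend_symbols_ is scoped (pushScope/popScope); a flat map suffices
// to show the contract of GetByOnnxName.
class SymbolTable {
public:
  void Set(const std::string &name, Value v) { map_[name] = std::move(v); }
  // Returns nullptr when the name was never mapped, e.g. an optional
  // output that the decomposed function body did not produce.
  const Value *GetByOnnxName(const std::string &name) const {
    auto it = map_.find(name);
    return it == map_.end() ? nullptr : &it->second;
  }

private:
  std::unordered_map<std::string, Value> map_;
};

int main() {
  SymbolTable syms;
  // LayerNormalization declares outputs Y, Mean, InvStdDev; here the
  // caller requested only Y, so only Y gets mapped.
  syms.Set("Y", "%normalized");

  std::vector<std::string> functionOutputs = {"Y", "Mean", "InvStdDev"};
  std::vector<Value> outputs;
  for (const auto &name : functionOutputs) {
    // Skip missing optional outputs: they are not mapped.
    if (const Value *valuePtr = syms.GetByOnnxName(name))
      outputs.push_back(*valuePtr);
  }
  std::cout << "collected " << outputs.size() << " output(s)\n"; // prints 1
  return 0;
}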
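
For reference, the CHECK lines trace the decomposition of LayerNormalization
over the flattened [36,5] view: mean = E[x] (VAR_16), var = E[x*x] - mean*mean
(VAR_17 through VAR_21), y = (x - mean) / sqrt(var + eps) * scale (VAR_22
through VAR_28). A scalar sketch of the same arithmetic, purely illustrative,
using the ONNX default epsilon 1e-5 (the 9.99999974E-6 constant above is its
float32 value):

#include <cmath>
#include <cstdio>
#include <vector>

// Normalizes each row of a [rows x cols] buffer in place, mirroring the
// decomposed graph: mean = E[x], var = E[x*x] - mean*mean,
// y = (x - mean) / sqrt(var + eps) * scale.
void layerNormRows(std::vector<float> &x, const std::vector<float> &scale,
                   size_t rows, size_t cols, float eps = 1e-5f) {
  for (size_t r = 0; r < rows; ++r) {
    float sum = 0.0f, sumSq = 0.0f;
    for (size_t c = 0; c < cols; ++c) {
      float v = x[r * cols + c];
      sum += v;
      sumSq += v * v;
    }
    float mean = sum / cols;
    float var = sumSq / cols - mean * mean; // E[x^2] - E[x]^2
    float invStdDev = 1.0f / std::sqrt(var + eps);
    for (size_t c = 0; c < cols; ++c)
      x[r * cols + c] = (x[r * cols + c] - mean) * invStdDev * scale[c];
  }
}

int main() {
  // One [1,5] row with unit scale, like one row of the test's [36,5] view.
  std::vector<float> x = {1, 2, 3, 4, 5};
  std::vector<float> scale = {1, 1, 1, 1, 1};
  layerNormRows(x, scale, 1, 5);
  std::printf("%f .. %f\n", x[0], x[4]); // approx -1.414207 .. 1.414207
  return 0;
}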