Skip to content

Commit

Permalink
support optional outputs in function decomposition (#2493)
Browse files Browse the repository at this point in the history
Signed-off-by: Soren Lassen <[email protected]>
  • Loading branch information
sorenlassen authored Sep 12, 2023
1 parent c69721b commit 971f39a
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 1 deletion.
5 changes: 4 additions & 1 deletion src/Builder/FrontendDialectTransformer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1235,7 +1235,10 @@ class FrontendGenImpl {
}

for (auto &name : functionProto.output()) {
outputs.push_back(LookupOnnxName(name));
// Skip missing optional outputs: they are not mapped.
if (const Value *valuePtr = frontend_symbols_.GetByOnnxName(name)) {
outputs.push_back(*valuePtr);
}
}

frontend_symbols_.popScope(scopeName);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
// RUN: onnx-mlir --functions-to-decompose=LayerNormalization --EmitONNXBasic --printIR %s | FileCheck %s

// from onnx-mlir issue #2492
<
ir_version: 8,
opset_import: ["" : 17]
>
agraph (float[12,3,5] X, float[5] S) => (float[12,3,5] LN) {
LN = LayerNormalization (X, S)
}
// CHECK-LABEL: func.func @main_graph
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<12x3x5xf32>, [[PARAM_1_:%.+]]: tensor<5xf32>) -> tensor<12x3x5xf32> attributes {input_names = ["X", "S"], output_names = ["LN"]} {
// CHECK-DAG: [[VAR_0_:%.+]] = "onnx.NoValue"() {value} : () -> none
// CHECK-DAG: [[VAR_1_:%.+]] = onnx.Constant dense<9.99999974E-6> : tensor<f32>
// CHECK-NOT: separator of consecutive DAGs
// CHECK-DAG: [[VAR_2_:%.+]] = "onnx.Cast"([[VAR_1_]]) {saturate = 1 : si64, to = f32} : (tensor<f32>) -> tensor<f32>
// CHECK-DAG: [[VAR_3_:%.+]] = "onnx.Shape"([[PARAM_0_]]) {start = 0 : si64} : (tensor<12x3x5xf32>) -> tensor<3xi64>
// CHECK-NOT: separator of consecutive DAGs
// CHECK-DAG: [[VAR_4_:%.+]] = "onnx.Size"([[VAR_3_]]) : (tensor<3xi64>) -> tensor<i64>
// CHECK-DAG: [[VAR_5_:%.+]] = onnx.Constant dense<0> : tensor<1xi64>
// CHECK-DAG: [[VAR_6_:%.+]] = onnx.Constant dense<-1> : tensor<1xi64>
// CHECK-DAG: [[VAR_7_:%.+]] = "onnx.NoValue"() {value} : () -> none
// CHECK-DAG: [[VAR_8_:%.+]] = "onnx.NoValue"() {value} : () -> none
// CHECK-NOT: separator of consecutive DAGs
// CHECK-DAG: [[VAR_9_:%.+]] = "onnx.Slice"([[VAR_3_]], [[VAR_5_]], [[VAR_6_]], [[VAR_7_]], [[VAR_8_]]) : (tensor<3xi64>, tensor<1xi64>, tensor<1xi64>, none, none) -> tensor<2xi64>
// CHECK-DAG: [[VAR_10_:%.+]] = "onnx.Neg"([[VAR_6_]]) : (tensor<1xi64>) -> tensor<1xi64>
// CHECK: [[VAR_11_:%.+]] = onnx.ConstantOfShape([[VAR_10_]]) {value = dense<1> : tensor<1xi64>} : (tensor<1xi64>) -> tensor<?xi64>
// CHECK-DAG: [[VAR_12_:%.+]] = "onnx.Concat"([[VAR_9_]], [[VAR_11_]]) {axis = 0 : si64} : (tensor<2xi64>, tensor<?xi64>) -> tensor<?xi64>
// CHECK-DAG: [[VAR_13_:%.+]] = "onnx.Flatten"([[PARAM_0_]]) {axis = -1 : si64} : (tensor<12x3x5xf32>) -> tensor<36x5xf32>
// CHECK-NOT: separator of consecutive DAGs
// CHECK-DAG: [[VAR_14_:%.+]] = "onnx.Cast"([[VAR_13_]]) {saturate = 1 : si64, to = f32} : (tensor<36x5xf32>) -> tensor<36x5xf32>
// CHECK-DAG: [[VAR_15_:%.+]] = "onnx.NoValue"() {value} : () -> none
// CHECK-NOT: separator of consecutive DAGs
// CHECK-DAG: [[VAR_16_:%.+]] = "onnx.ReduceMean"([[VAR_14_]], [[VAR_15_]]) {axes = [1], keepdims = 1 : si64, noop_with_empty_axes = 0 : si64} : (tensor<36x5xf32>, none) -> tensor<36x1xf32>
// CHECK-DAG: [[VAR_17_:%.+]] = "onnx.Mul"([[VAR_14_]], [[VAR_14_]]) : (tensor<36x5xf32>, tensor<36x5xf32>) -> tensor<36x5xf32>
// CHECK-DAG: [[VAR_18_:%.+]] = "onnx.NoValue"() {value} : () -> none
// CHECK-NOT: separator of consecutive DAGs
// CHECK-DAG: [[VAR_19_:%.+]] = "onnx.ReduceMean"([[VAR_17_]], [[VAR_18_]]) {axes = [1], keepdims = 1 : si64, noop_with_empty_axes = 0 : si64} : (tensor<36x5xf32>, none) -> tensor<36x1xf32>
// CHECK-DAG: [[VAR_20_:%.+]] = "onnx.Mul"([[VAR_16_]], [[VAR_16_]]) : (tensor<36x1xf32>, tensor<36x1xf32>) -> tensor<36x1xf32>
// CHECK: [[VAR_21_:%.+]] = "onnx.Sub"([[VAR_19_]], [[VAR_20_]]) : (tensor<36x1xf32>, tensor<36x1xf32>) -> tensor<36x1xf32>
// CHECK: [[VAR_22_:%.+]] = "onnx.Add"([[VAR_21_]], [[VAR_2_]]) : (tensor<36x1xf32>, tensor<f32>) -> tensor<36x1xf32>
// CHECK-DAG: [[VAR_23_:%.+]] = "onnx.Sqrt"([[VAR_22_]]) : (tensor<36x1xf32>) -> tensor<36x1xf32>
// CHECK-DAG: [[VAR_24_:%.+]] = "onnx.Sub"([[VAR_14_]], [[VAR_16_]]) : (tensor<36x5xf32>, tensor<36x1xf32>) -> tensor<36x5xf32>
// CHECK: [[VAR_25_:%.+]] = "onnx.Div"([[VAR_24_]], [[VAR_23_]]) : (tensor<36x5xf32>, tensor<36x1xf32>) -> tensor<36x5xf32>
// CHECK-DAG: [[VAR_26_:%.+]] = "onnx.Cast"([[VAR_25_]]) {saturate = 1 : si64, to = f32} : (tensor<36x5xf32>) -> tensor<36x5xf32>
// CHECK-DAG: [[VAR_27_:%.+]] = "onnx.Flatten"([[PARAM_1_]]) {axis = 0 : si64} : (tensor<5xf32>) -> tensor<1x5xf32>
// CHECK: [[VAR_28_:%.+]] = "onnx.Mul"([[VAR_26_]], [[VAR_27_]]) : (tensor<36x5xf32>, tensor<1x5xf32>) -> tensor<36x5xf32>
// CHECK: [[VAR_29_:%.+]] = "onnx.Identity"([[VAR_28_]]) : (tensor<36x5xf32>) -> tensor<36x5xf32>
// CHECK-DAG: [[VAR_30_:%.+]] = "onnx.Reshape"([[VAR_29_]], [[VAR_3_]]) {allowzero = 0 : si64} : (tensor<36x5xf32>, tensor<3xi64>) -> tensor<12x3x5xf32>
// CHECK-DAG: [[VAR_31_:%.+]] = "onnx.Reciprocal"([[VAR_23_]]) : (tensor<36x1xf32>) -> tensor<36x1xf32>
// CHECK: onnx.Return [[VAR_30_]] : tensor<12x3x5xf32>
// CHECK: }

0 comments on commit 971f39a

Please sign in to comment.