Commit ded4d47
fix function decomposition opset and model functions (#2543)
Signed-off-by: Soren Lassen <[email protected]>
sorenlassen authored Oct 2, 2023
1 parent 1a15f58 commit ded4d47
Showing 3 changed files with 23 additions and 26 deletions.
9 changes: 5 additions & 4 deletions src/Builder/FrontendDialectTransformer.cpp
@@ -1192,9 +1192,6 @@ class FrontendGenImpl {
onnx::OpSchemaRegistry::Instance(),
/*options=*/{}, in_model_functions_);

-    std::string scopeName =
-        node.name() + ":" + node.op_type() + ":" + functionProto.name();
-
// Save caller context, while generating function body.
ModelLocalFunctionsMap callerModelFunctions;
if (schema) {
@@ -1209,6 +1206,10 @@

// TODO: Reuse importGraph() logic.
{
+      opset_map_ = std::move(function_opset_map);
+
+      std::string scopeName =
+          node.name() + ":" + node.op_type() + ":" + functionProto.name();
frontend_symbols_.pushScope(scopeName);
onnx_type_map.pushScope(scopeName);

@@ -1246,7 +1247,7 @@

// Restore caller context.
if (schema) {
-      callerModelFunctions = std::move(in_model_functions_);
+      in_model_functions_ = std::move(callerModelFunctions);
}
opset_map_ = std::move(callerOpsetMap);
frontend_symbols_ = std::move(callerScope);
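What the C++ hunks change: the function body is now imported under the function's own opset map (opset_map_ = std::move(function_opset_map)), the scopeName computation moves next to its use, and the restore of the caller's model-local functions had its std::move reversed. Below is a minimal sketch of that save/restore pattern and the one-line fix; the simplified FunctionsMap type and Importer struct are stand-ins, not the importer's real ModelLocalFunctionsMap or class.

#include <map>
#include <string>
#include <utility>

// Simplified stand-in for the importer's model-local-functions map.
using FunctionsMap = std::map<std::string, int>;

struct Importer {
  FunctionsMap in_model_functions_;

  void importFunctionBody(FunctionsMap functionLocalFunctions) {
    // Save the caller's context, then switch to the function's context.
    FunctionsMap callerModelFunctions = std::move(in_model_functions_);
    in_model_functions_ = std::move(functionLocalFunctions);

    // ... import the function body against in_model_functions_ ...

    // Restore the caller's context. The pre-fix code repeated the save,
    //   callerModelFunctions = std::move(in_model_functions_);
    // re-saving the member and losing the caller's map; the fix moves
    // in the opposite direction:
    in_model_functions_ = std::move(callerModelFunctions);
  }
};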
2 changes: 1 addition & 1 deletion test/mlir/onnx/parse/fun_model_test.onnxtext
@@ -35,7 +35,7 @@ square (x) => (y) {
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<?x128xf32>, [[PARAM_1_:%.+]]: tensor<128x10xf32>, [[PARAM_2_:%.+]]: tensor<10xf32>) -> tensor<?x10xf32> attributes {input_names = ["X", "W", "B"], output_names = ["C"]} {
// CHECK: [[VAR_0_:%.+]] = "onnx.MatMul"([[PARAM_0_]], [[PARAM_1_]]) : (tensor<?x128xf32>, tensor<128x10xf32>) -> tensor<?x10xf32>
// CHECK: [[VAR_1_:%.+]] = "onnx.Add"([[VAR_0_]], [[PARAM_2_]]) : (tensor<?x10xf32>, tensor<10xf32>) -> tensor<?x10xf32>
-// CHECK: [[VAR_2_:%.+]] = "onnx.Softmax"([[VAR_1_]]) {axis = -1 : si64} : (tensor<?x10xf32>) -> tensor<?x10xf32>
+// CHECK: [[VAR_2_:%.+]] = "onnx.SoftmaxV11"([[VAR_1_]]) {axis = 1 : si64} : (tensor<?x10xf32>) -> tensor<?x10xf32>
// CHECK: [[VAR_3_:%.+]] = "onnx.Mul"([[VAR_2_]], [[VAR_2_]]) : (tensor<?x10xf32>, tensor<?x10xf32>) -> tensor<?x10xf32>
// CHECK: onnx.Return [[VAR_3_]] : tensor<?x10xf32>
// CHECK: }
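The updated expectation follows from the fix: the function body is now resolved against the opset the model declares rather than the importer's current one. ONNX revised Softmax at opset 13 (per-axis semantics, axis default moved from 1 to -1), and onnx-mlir keeps the pre-13 form as a distinct op, onnx.SoftmaxV11. A sketch of that opset-keyed choice, using a hypothetical helper rather than the importer's actual API:

#include <cstdint>
#include <string>

// Softmax changed at ONNX opset 13: per-axis semantics and an axis
// default of -1 instead of 1. Pick the versioned dialect op accordingly.
struct SoftmaxForm {
  std::string opName;
  std::int64_t defaultAxis;
};

SoftmaxForm pickSoftmax(int opsetVersion) {
  if (opsetVersion < 13)
    return {"onnx.SoftmaxV11", 1};
  return {"onnx.Softmax", -1};
}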
38 changes: 17 additions & 21 deletions
@@ -27,26 +27,22 @@ agraph (float[12,3,5] X, float[5] S) => (float[12,3,5] LN) {
// CHECK: [[VAR_11_:%.+]] = onnx.ConstantOfShape([[VAR_10_]]) {value = dense<1> : tensor<1xi64>} : (tensor<1xi64>) -> tensor<?xi64>
// CHECK-DAG: [[VAR_12_:%.+]] = "onnx.Concat"([[VAR_9_]], [[VAR_11_]]) {axis = 0 : si64} : (tensor<2xi64>, tensor<?xi64>) -> tensor<?xi64>
// CHECK-DAG: [[VAR_13_:%.+]] = "onnx.Flatten"([[PARAM_0_]]) {axis = -1 : si64} : (tensor<12x3x5xf32>) -> tensor<36x5xf32>
+// CHECK: [[VAR_14_:%.+]] = "onnx.Cast"([[VAR_13_]]) {saturate = 1 : si64, to = f32} : (tensor<36x5xf32>) -> tensor<36x5xf32>
+// CHECK-DAG: [[VAR_15_:%.+]] = "onnx.ReduceMeanV13"([[VAR_14_]]) {axes = [1], keepdims = 1 : si64} : (tensor<36x5xf32>) -> tensor<36x1xf32>
+// CHECK-DAG: [[VAR_16_:%.+]] = "onnx.Mul"([[VAR_14_]], [[VAR_14_]]) : (tensor<36x5xf32>, tensor<36x5xf32>) -> tensor<36x5xf32>
// CHECK-NOT: separator of consecutive DAGs
-// CHECK-DAG: [[VAR_14_:%.+]] = "onnx.Cast"([[VAR_13_]]) {saturate = 1 : si64, to = f32} : (tensor<36x5xf32>) -> tensor<36x5xf32>
-// CHECK-DAG: [[VAR_15_:%.+]] = "onnx.NoValue"() {value} : () -> none
-// CHECK-NOT: separator of consecutive DAGs
-// CHECK-DAG: [[VAR_16_:%.+]] = "onnx.ReduceMean"([[VAR_14_]], [[VAR_15_]]) {axes = [1], keepdims = 1 : si64, noop_with_empty_axes = 0 : si64} : (tensor<36x5xf32>, none) -> tensor<36x1xf32>
-// CHECK-DAG: [[VAR_17_:%.+]] = "onnx.Mul"([[VAR_14_]], [[VAR_14_]]) : (tensor<36x5xf32>, tensor<36x5xf32>) -> tensor<36x5xf32>
-// CHECK-DAG: [[VAR_18_:%.+]] = "onnx.NoValue"() {value} : () -> none
-// CHECK-NOT: separator of consecutive DAGs
-// CHECK-DAG: [[VAR_19_:%.+]] = "onnx.ReduceMean"([[VAR_17_]], [[VAR_18_]]) {axes = [1], keepdims = 1 : si64, noop_with_empty_axes = 0 : si64} : (tensor<36x5xf32>, none) -> tensor<36x1xf32>
-// CHECK-DAG: [[VAR_20_:%.+]] = "onnx.Mul"([[VAR_16_]], [[VAR_16_]]) : (tensor<36x1xf32>, tensor<36x1xf32>) -> tensor<36x1xf32>
-// CHECK: [[VAR_21_:%.+]] = "onnx.Sub"([[VAR_19_]], [[VAR_20_]]) : (tensor<36x1xf32>, tensor<36x1xf32>) -> tensor<36x1xf32>
-// CHECK: [[VAR_22_:%.+]] = "onnx.Add"([[VAR_21_]], [[VAR_2_]]) : (tensor<36x1xf32>, tensor<f32>) -> tensor<36x1xf32>
-// CHECK-DAG: [[VAR_23_:%.+]] = "onnx.Sqrt"([[VAR_22_]]) : (tensor<36x1xf32>) -> tensor<36x1xf32>
-// CHECK-DAG: [[VAR_24_:%.+]] = "onnx.Sub"([[VAR_14_]], [[VAR_16_]]) : (tensor<36x5xf32>, tensor<36x1xf32>) -> tensor<36x5xf32>
-// CHECK: [[VAR_25_:%.+]] = "onnx.Div"([[VAR_24_]], [[VAR_23_]]) : (tensor<36x5xf32>, tensor<36x1xf32>) -> tensor<36x5xf32>
-// CHECK-DAG: [[VAR_26_:%.+]] = "onnx.Cast"([[VAR_25_]]) {saturate = 1 : si64, to = f32} : (tensor<36x5xf32>) -> tensor<36x5xf32>
-// CHECK-DAG: [[VAR_27_:%.+]] = "onnx.Flatten"([[PARAM_1_]]) {axis = 0 : si64} : (tensor<5xf32>) -> tensor<1x5xf32>
-// CHECK: [[VAR_28_:%.+]] = "onnx.Mul"([[VAR_26_]], [[VAR_27_]]) : (tensor<36x5xf32>, tensor<1x5xf32>) -> tensor<36x5xf32>
-// CHECK: [[VAR_29_:%.+]] = "onnx.Identity"([[VAR_28_]]) : (tensor<36x5xf32>) -> tensor<36x5xf32>
-// CHECK-DAG: [[VAR_30_:%.+]] = "onnx.Reshape"([[VAR_29_]], [[VAR_3_]]) {allowzero = 0 : si64} : (tensor<36x5xf32>, tensor<3xi64>) -> tensor<12x3x5xf32>
-// CHECK-DAG: [[VAR_31_:%.+]] = "onnx.Reciprocal"([[VAR_23_]]) : (tensor<36x1xf32>) -> tensor<36x1xf32>
-// CHECK: onnx.Return [[VAR_30_]] : tensor<12x3x5xf32>
+// CHECK-DAG: [[VAR_17_:%.+]] = "onnx.ReduceMeanV13"([[VAR_16_]]) {axes = [1], keepdims = 1 : si64} : (tensor<36x5xf32>) -> tensor<36x1xf32>
+// CHECK-DAG: [[VAR_18_:%.+]] = "onnx.Mul"([[VAR_15_]], [[VAR_15_]]) : (tensor<36x1xf32>, tensor<36x1xf32>) -> tensor<36x1xf32>
+// CHECK: [[VAR_19_:%.+]] = "onnx.Sub"([[VAR_17_]], [[VAR_18_]]) : (tensor<36x1xf32>, tensor<36x1xf32>) -> tensor<36x1xf32>
+// CHECK: [[VAR_20_:%.+]] = "onnx.Add"([[VAR_19_]], [[VAR_2_]]) : (tensor<36x1xf32>, tensor<f32>) -> tensor<36x1xf32>
+// CHECK-DAG: [[VAR_21_:%.+]] = "onnx.Sqrt"([[VAR_20_]]) : (tensor<36x1xf32>) -> tensor<36x1xf32>
+// CHECK-DAG: [[VAR_22_:%.+]] = "onnx.Sub"([[VAR_14_]], [[VAR_15_]]) : (tensor<36x5xf32>, tensor<36x1xf32>) -> tensor<36x5xf32>
+// CHECK: [[VAR_23_:%.+]] = "onnx.Div"([[VAR_22_]], [[VAR_21_]]) : (tensor<36x5xf32>, tensor<36x1xf32>) -> tensor<36x5xf32>
+// CHECK-DAG: [[VAR_24_:%.+]] = "onnx.Cast"([[VAR_23_]]) {saturate = 1 : si64, to = f32} : (tensor<36x5xf32>) -> tensor<36x5xf32>
+// CHECK-DAG: [[VAR_25_:%.+]] = "onnx.Flatten"([[PARAM_1_]]) {axis = 0 : si64} : (tensor<5xf32>) -> tensor<1x5xf32>
+// CHECK: [[VAR_26_:%.+]] = "onnx.Mul"([[VAR_24_]], [[VAR_25_]]) : (tensor<36x5xf32>, tensor<1x5xf32>) -> tensor<36x5xf32>
+// CHECK: [[VAR_27_:%.+]] = "onnx.Identity"([[VAR_26_]]) : (tensor<36x5xf32>) -> tensor<36x5xf32>
+// CHECK-DAG: [[VAR_28_:%.+]] = "onnx.Reshape"([[VAR_27_]], [[VAR_3_]]) {allowzero = 0 : si64} : (tensor<36x5xf32>, tensor<3xi64>) -> tensor<12x3x5xf32>
+// CHECK-DAG: [[VAR_29_:%.+]] = "onnx.Reciprocal"([[VAR_21_]]) : (tensor<36x1xf32>) -> tensor<36x1xf32>
+// CHECK: onnx.Return [[VAR_28_]] : tensor<12x3x5xf32>
// CHECK: }
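The LN graph (a layer-normalization-style decomposition) changes for the same reason. At ONNX opset 18, ReduceMean's axes moved from an attribute to an optional second input, which is why each deleted onnx.ReduceMean expectation is paired with an onnx.NoValue operand; resolved against the model's older opset, the importer now emits onnx.ReduceMeanV13 with axes kept as an attribute. A sketch of the two forms, again with a hypothetical helper:

#include <string>

// Before opset 18, ReduceMean carries `axes` as an attribute and is
// imported as onnx.ReduceMeanV13; from opset 18 on, axes is an optional
// second input, materialized as onnx.NoValue when absent.
struct ReduceMeanForm {
  std::string opName;
  bool axesIsInput;
};

ReduceMeanForm pickReduceMean(int opsetVersion) {
  if (opsetVersion < 18)
    return {"onnx.ReduceMeanV13", /*axesIsInput=*/false};
  return {"onnx.ReduceMean", /*axesIsInput=*/true};
}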
