diff --git a/projects/pt1/python/torch_mlir_e2e_test/configs/onnx_backend.py b/projects/pt1/python/torch_mlir_e2e_test/configs/onnx_backend.py index 7f630074e756..5402c7243e00 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/configs/onnx_backend.py +++ b/projects/pt1/python/torch_mlir_e2e_test/configs/onnx_backend.py @@ -27,6 +27,7 @@ def import_onnx(contents): # Import the ONNX model proto from the file contents: raw_model = onnx.load_from_string(contents) + # since it does not affect current e2e tests, data_prop is left false here model_proto = onnx.shape_inference.infer_shapes(raw_model) # Import the ONNX module into an MLIR module: diff --git a/python/torch_mlir/tools/import_onnx/__main__.py b/python/torch_mlir/tools/import_onnx/__main__.py index 6fbabb09a2ef..92ae3c7eb356 100644 --- a/python/torch_mlir/tools/import_onnx/__main__.py +++ b/python/torch_mlir/tools/import_onnx/__main__.py @@ -85,7 +85,9 @@ def load_onnx_model(args: argparse.Namespace) -> onnx.ModelProto: # in-memory shape inference. If not, go ahead and do the shape inference. try: onnx.checker.check_model(raw_model) - inferred_model = onnx.shape_inference.infer_shapes(raw_model) + inferred_model = onnx.shape_inference.infer_shapes( + raw_model, data_prop=args.data_prop + ) return inferred_model except ValueError: pass @@ -103,7 +105,9 @@ def load_onnx_model(args: argparse.Namespace) -> onnx.ModelProto: # Model is too big for in-memory inference: do file-based shape inference # to a temp file. temp_inferred_file = temp_dir / "inferred.onnx" - onnx.shape_inference.infer_shapes_path(args.input_file, temp_inferred_file) + onnx.shape_inference.infer_shapes_path( + args.input_file, temp_inferred_file, data_prop=args.data_prop + ) # Sanity check the shape-inferred model to be sure we have a good model # for the importer. 
This call uses the file-based method, as the @@ -138,6 +142,13 @@ def parse_arguments(argv=None) -> argparse.Namespace: action="store_true", help="Disable verification prior to printing", ) + parser.add_argument( + "--data-prop", + dest="data_prop", + default=True, + action=argparse.BooleanOptionalAction, + help="Toggle data propagation for onnx shape inference", + ) parser.add_argument( "--keep-temps", action="store_true", help="Keep intermediate files" )