diff --git a/src/Compiler/CompilerUtils.cpp b/src/Compiler/CompilerUtils.cpp index 284e717b35..4b8b1c4de9 100644 --- a/src/Compiler/CompilerUtils.cpp +++ b/src/Compiler/CompilerUtils.cpp @@ -14,6 +14,7 @@ #include "CompilerUtils.hpp" +#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Parser/Parser.h" #include "mlir/Pass/PassManager.h" @@ -35,6 +36,7 @@ #include "src/Accelerators/Accelerator.hpp" #include "src/Builder/FrontendDialectTransformer.hpp" +#include "src/Builder/ModelInputShaper.hpp" #include "src/Compiler/CompilerDialects.hpp" #include "src/Compiler/CompilerOptions.hpp" #include "src/Compiler/CompilerPasses.hpp" @@ -183,6 +185,35 @@ static void loadMLIR(std::string inputFilename, mlir::MLIRContext &context, llvm::errs() << "Error can't load file " << inputFilename << "\n"; exit(1); } + + // Set shape information if required. + // Only set shape if the module has a single function. + uint64_t numOfFuncOp = 0; + func::FuncOp funcOp; + module->walk([&](func::FuncOp f) { + funcOp = f; + numOfFuncOp++; + }); + if ((numOfFuncOp == 1) && (!shapeInformation.empty())) { + ModelInputShaper modelInputShaper_; + modelInputShaper_.setShapeInformation(shapeInformation); + auto funcType = dyn_cast(funcOp.getFunctionType()); + ArrayRef argTypes = funcType.getInputs(); + SmallVector newArgTypes; + for (uint64_t i = 0; i < argTypes.size(); ++i) { + Type argTy = argTypes[i]; + // Get user's shape information. + argTy = modelInputShaper_.reshape(i, argTy); + // Update the arguments. + funcOp.getBody().back().getArgument(i).setType(argTy); + newArgTypes.emplace_back(argTy); + } + // Update the function type. + FunctionType newType = + FunctionType::get(&context, newArgTypes, funcType.getResults()); + ConversionPatternRewriter rewriter(&context); + rewriter.updateRootInPlace(funcOp, [&] { funcOp.setType(newType); }); + } } // Tailor LLVMIR to add features that cannot be done with MLIR LLVMIR. diff --git a/test/mlir/driver/shape_information.mlir b/test/mlir/driver/shape_information.mlir new file mode 100644 index 0000000000..9904bf9444 --- /dev/null +++ b/test/mlir/driver/shape_information.mlir @@ -0,0 +1,11 @@ +// RUN: onnx-mlir --EmitONNXIR --shapeInformation=0:3x-1 --printIR %s | FileCheck %s + +module { +func.func @main_graph(%arg0: tensor<3x2xi64>, %arg1: tensor<3x2xi64>) -> tensor<3x2xi64> { + %0 = "onnx.Add"(%arg0, %arg1) : (tensor<3x2xi64>, tensor<3x2xi64>) -> tensor<3x2xi64> + onnx.Return %0 : tensor<3x2xi64> + +// CHECK-LABEL main_graph +// CHECK: "onnx.Add"(%arg0, %arg1) : (tensor<3x?xi64>, tensor<3x2xi64>) -> tensor<3x2xi64 +} +}