diff --git a/docs/doc_example/main.cpp b/docs/doc_example/main.cpp index d0407bc81d..6102372b02 100644 --- a/docs/doc_example/main.cpp +++ b/docs/doc_example/main.cpp @@ -18,8 +18,8 @@ std::string readArgs(int argc, char *argv[]) { int main(int argc, char *argv[]) { // Read compiler options from command line and compile the doc example into a // model library. - const char *errorMessage = NULL; - const char *compiledFilename; + char *errorMessage = nullptr; + char *compiledFilename = nullptr; std::string flags = readArgs(argc, argv); flags += "-o add-cpp-interface"; std::cout << "Compile with options \"" << flags << "\"\n"; @@ -30,11 +30,15 @@ int main(int argc, char *argv[]) { if (errorMessage) std::cerr << " and message \"" << errorMessage << "\""; std::cerr << "." << std::endl; + free(compiledFilename); + free(errorMessage); return rc; } std::string libFilename(compiledFilename); std::cout << "Compiled succeeded with results in file: " << libFilename << std::endl; + free(compiledFilename); + free(errorMessage); // Prepare the execution session. onnx_mlir::ExecutionSession *session; diff --git a/docs/mnist_example/README.md b/docs/mnist_example/README.md index 3fa33fcbfe..03c7c9194d 100644 --- a/docs/mnist_example/README.md +++ b/docs/mnist_example/README.md @@ -279,8 +279,7 @@ The runtime use an `OMExecutionSession` object to hold a specific model and entr ```Python #Load the model mnist.so compiled with onnx-mlir. -model = 'mnist.so' -session = OMExecutionSession(model) +session = OMExecutionSession('mnist.so') #Print the models input / output signature, for display. #If there are problems with the signature functions, \ they can be simply commented out. @@ -295,10 +294,10 @@ outputs = session.run([input]) The outputs can then be analyzed by inspecting the values inside the `output` list of numpy arrays. -The full code is available [here](mnist.py). It finds that `0` is the most likely digit for the given input. 
The command is: +The full code is available [here](mnist-runPyRuntime.py). It finds that `0` is the most likely digit for the given input. The command is: ```shell -./mnist.py +./mnist-runPyRuntime.py ``` and produces an output similar to the following (you may see slightly different prediction numbers if you train the model yourself): @@ -321,6 +320,22 @@ prediction 9 = 8.650948e-15 The digit is 0 ``` +We provide two additional Python interfaces. +The second interface extends the above execution session by simply compiling a model before loading it for execution (see [here](mnist-runPyCompileAndRuntime.py)). +The user simply passes the `.onnx` model and the flags needed to compile the model. +Unless explicitly disabled by setting `reuse_compiled_model=0`, the execution session will reuse a previously compiled model whose name matches the name of the output file generated by the compiler. +Note that the execution session does not check if the cached version was compiled using identical compiler flags; it is the responsibility of the user to then clear the cached version, or disable the reuse using the provided optional flag. + +For example, the code below will compile and load the `mnist.onnx` model, compiling only when the `mnist2.so` binary file cannot be located. Model inference can then proceed using the `session.run(...)` command. + +```Python +# Load onnx model and create CompileExecutionSession object, +# by first compiling the mnist.onnx model with the "-O3" options. +session = OMCompileExecutionSession("./mnist.onnx" ,"-O3 -o=mnist2") +``` + +The third interface provides a simple interface to explicitly compile an onnx model (see [here](mnist-compile.py)). + ## Write a Java Driver Code Inference APIs and data structures for Java closely mirror those for C/C++. 
Documentation of the APIs are found [here](https://onnx.ai/onnx-mlir/doxygen_html/OMModel_java/classcom_1_1ibm_1_1onnxmlir_1_1_o_m_model.html), with the Java interface for Tensor [here](https://onnx.ai/onnx-mlir/doxygen_html/OMTensor_java/classcom_1_1ibm_1_1onnxmlir_1_1_o_m_tensor.html) and TensorList [here](https://onnx.ai/onnx-mlir/doxygen_html/OMTensorList_java/classcom_1_1ibm_1_1onnxmlir_1_1_o_m_tensor_list.html). diff --git a/docs/mnist_example/mnist-compile.py b/docs/mnist_example/mnist-compile.py index c854dd6069..5c408b64e5 100755 --- a/docs/mnist_example/mnist-compile.py +++ b/docs/mnist_example/mnist-compile.py @@ -7,7 +7,7 @@ file = './mnist.onnx' compiler = OMCompileSession(file) # Generate the library file. Success when rc == 0 while set the opt as "-O3" -rc = compiler.compile("-O3") +rc = compiler.compile("-O3 -o mnist") # Get the output file name model = compiler.get_compiled_file_name() if rc: diff --git a/docs/mnist_example/mnist-runPyCompileAndRuntime.py b/docs/mnist_example/mnist-runPyCompileAndRuntime.py index 0640750cfa..a8b5a4b077 100755 --- a/docs/mnist_example/mnist-runPyCompileAndRuntime.py +++ b/docs/mnist_example/mnist-runPyCompileAndRuntime.py @@ -3,13 +3,13 @@ import numpy as np from PyCompileAndRuntime import OMCompileExecutionSession -# Load onnx model and create CompileExecutionSession object. -inputFileName = './mnist.onnx' -# Set the full name of compiled model -sharedLibPath = './mnist.so' -# Set the compile option as "-O3" -session = OMCompileExecutionSession(inputFileName,sharedLibPath,"-O3") - +# Load onnx model and create CompileExecutionSession object, +# by first compiling the mnist.onnx model with the "-O3" options. +session = OMCompileExecutionSession("./mnist.onnx" ,"-O3 -o=mnist2", + reuse_compiled_model=1) +if session.get_compiled_result(): + print("error with :" + session.get_error_message()) + exit(1) # Print the models input/output signature, for display. 
# Signature functions for info only, commented out if they cause problems. print("input signature in json", session.input_signature()) diff --git a/docs/mnist_example/mnist-runPyRuntime.py b/docs/mnist_example/mnist-runPyRuntime.py index 5523519ef8..3ba90bc681 100755 --- a/docs/mnist_example/mnist-runPyRuntime.py +++ b/docs/mnist_example/mnist-runPyRuntime.py @@ -4,8 +4,7 @@ from PyRuntime import OMExecutionSession # Load the model mnist.so compiled with onnx-mlir. -model = './mnist.so' -session = OMExecutionSession(model) +session = OMExecutionSession('./mnist.so') # Print the models input/output signature, for display. # Signature functions for info only, commented out if they cause problems. print("input signature in json", session.input_signature()) diff --git a/include/OnnxMlirCompiler.h b/include/OnnxMlirCompiler.h index 98781d4863..f5e79ef6b4 100644 --- a/include/OnnxMlirCompiler.h +++ b/include/OnnxMlirCompiler.h @@ -60,18 +60,18 @@ namespace onnx_mlir { * Name may include a path, and must include the file name and its extention. * * @param outputFilename Output file name of the compiled output for the given - * emission target. User is responsible for freeing the string. + * emission target. User is responsible for freeing the string. * * @param flags A char * contains all the options provided to compile the - * model. + * model. * * @param errorMessage Output error message, if any. User is responsible for - * freeing the string. + * freeing the string. * * @return 0 on success or OnnxMlirCompilerErrorCodes on failure. */ ONNX_MLIR_EXPORT int64_t omCompileFromFile(const char *inputFilename, - const char *flags, const char **outputFilename, const char **errorMessage); + const char *flags, char **outputFilename, char **errorMessage); /*! * Compile an onnx model from an ONNX protobuf array. This method is not thread @@ -85,18 +85,33 @@ ONNX_MLIR_EXPORT int64_t omCompileFromFile(const char *inputFilename, * @param bufferSize Size of ONNX protobuf array. 
* @param outputBaseName File name without extension to write output. * Name may include a path, must include the file name, and should not include - * an extention. + * an extention. * @param emissionTarget Target format to compile to. * @param outputFilename Output file name of the compiled output for the given - * emission target. User is responsible for freeing the string. - * @param errorMessage Error message. + * emission target. User is responsible for freeing the string. + * @param errorMessage Error message, if any. User is responsible for freeing + * the string. * @return 0 on success or OnnxMlirCompilerErrorCodes failure. User is - * responsible for freeing the string. + * responsible for freeing the string. */ ONNX_MLIR_EXPORT int64_t omCompileFromArray(const void *inputBuffer, int64_t bufferSize, const char *outputBaseName, - EmissionTargetType emissionTarget, const char **outputFilename, - const char **errorMessage); + EmissionTargetType emissionTarget, char **outputFilename, + char **errorMessage); + +/*! + * Compute the file name of the compiled output for the given + * emission target. User is responsible for freeing the string. + * + * @param inputFilename File name pointing onnx model protobuf or MLIR. + * Name may include a path, and must include the file name and its extention. + * @param flags A char * contains all the options provided to compile the + * model. + * @return string containing the file name. User is responsible for freeing the + * string. 
+ */ +ONNX_MLIR_EXPORT char *omCompileOutputFileName( + const char *inputFilename, const char *flags); #ifdef __cplusplus } // namespace onnx_mlir diff --git a/src/Compiler/OnnxMlirCompiler.cpp b/src/Compiler/OnnxMlirCompiler.cpp index 11a2a996a3..71176b5881 100644 --- a/src/Compiler/OnnxMlirCompiler.cpp +++ b/src/Compiler/OnnxMlirCompiler.cpp @@ -18,16 +18,27 @@ using namespace onnx_mlir; namespace onnx_mlir { +// Derive the name; base name is either given by a "-o" option, or is taken as +// the model name. The extention depends on the target; e.g. -EmitLib will +// generate a .so, other targets may generate a .mlir. static std::string deriveOutputFileName( std::vector &flagVect, std::string inputFilename) { // Get output file name. std::string outputBasename; int num = flagVect.size(); - for (int i = 0; i < num - 1; - ++i) { // Skip last as need 2 consecutive entries. - if (flagVect[i].find("-o") == 0) { - outputBasename = flagVect[i + 1]; - break; + for (int i = 0; i < num; ++i) { + if (flagVect[i].find("-o=", 0, 3) == 0) { + if (flagVect[i].length() > 3) { + outputBasename = flagVect[i].substr(3); + break; + } else + llvm::errs() << "Parsing `-o=` option, expected a name. Use default.\n"; + } else if (flagVect[i].find("-o") == 0) { + if (i < num - 1) { + outputBasename = flagVect[i + 1]; + break; + } else + llvm::errs() << "Parsing `-o` option, expected a name. Use default.\n"; } } // If no output file name, derive it from input file name @@ -56,12 +67,7 @@ static std::string deriveOutputFileName( return getTargetFilename(outputBasename, emissionTarget); } -extern "C" { - -ONNX_MLIR_EXPORT int64_t omCompileFromFile(const char *inputFilename, - const char *flags, const char **outputFilename, const char **errorMessage) { - // Process the flags, saving each space-separated text in a separate - // entry in the string vector flagVect. 
+static std::vector parseFlags(const char *flags) { std::vector flagVect; const char *str = flags; do { @@ -76,6 +82,23 @@ ONNX_MLIR_EXPORT int64_t omCompileFromFile(const char *inputFilename, if (begin != str) flagVect.push_back(std::string(begin, str)); } while (*str); + return flagVect; +} + +extern "C" { + +ONNX_MLIR_EXPORT int64_t omCompileFromFile(const char *inputFilename, + const char *flags, char **outputFilename, char **errorMessage) { + // Ensure known values in filename and error message if provided. + if (outputFilename) + *outputFilename = nullptr; + if (errorMessage) + *errorMessage = nullptr; + + // Process the flags, saving each space-separated text in a separate + // entry in the string vector flagVect. + std::vector flagVect = parseFlags(flags); + // Use 'onnx-mlir' command to compile the model. std::string onnxMlirPath; const auto &envDir = getEnvVar("ONNX_MLIR_BIN_PATH"); @@ -90,17 +113,33 @@ ONNX_MLIR_EXPORT int64_t omCompileFromFile(const char *inputFilename, onnxMlirCompile.appendStr(inputFilenameStr); // Run command. int rc = onnxMlirCompile.exec(); - if (rc == CompilerSuccess && outputFilename) { + if (rc != CompilerSuccess) { + // Failure to compile. + if (errorMessage) { + std::string errorStr = + "Compiler failed with error code " + std::to_string(rc); + *errorMessage = strdup(errorStr.c_str()); + } + return CompilerFailureInLLVMOpt; + } + // Success. + if (outputFilename) { std::string name = deriveOutputFileName(flagVect, inputFilenameStr); *outputFilename = strdup(name.c_str()); } - return rc != 0 ? CompilerFailureInLLVMOpt : CompilerSuccess; + return CompilerSuccess; } ONNX_MLIR_EXPORT int64_t omCompileFromArray(const void *inputBuffer, int64_t bufferSize, const char *outputBaseName, - EmissionTargetType emissionTarget, const char **outputFilename, - const char **errorMessage) { + EmissionTargetType emissionTarget, char **outputFilename, + char **errorMessage) { + // Ensure known values in filename and error message if provided. 
+ if (outputFilename) + *outputFilename = nullptr; + if (errorMessage) + *errorMessage = nullptr; + mlir::OwningOpRef module; mlir::MLIRContext context; registerDialects(context); @@ -124,5 +163,13 @@ ONNX_MLIR_EXPORT int64_t omCompileFromArray(const void *inputBuffer, return rc; } +ONNX_MLIR_EXPORT char *omCompileOutputFileName( + const char *inputFilename, const char *flags) { + std::vector flagVect = parseFlags(flags); + std::string inputFilenameStr(inputFilename); + std::string name = deriveOutputFileName(flagVect, inputFilenameStr); + return strdup(name.c_str()); +} + } // extern C } // namespace onnx_mlir diff --git a/src/Compiler/PyOMCompileSession.cpp b/src/Compiler/PyOMCompileSession.cpp index 001a7f9a1c..1cb3faffeb 100644 --- a/src/Compiler/PyOMCompileSession.cpp +++ b/src/Compiler/PyOMCompileSession.cpp @@ -35,7 +35,8 @@ int64_t PyOMCompileSession::pyCompileFromFile(std::string flags) { "No OMCompileSession was created with the input file name specified."; return -1; } - const char *outputName, *errorMsg; + char *outputName = nullptr; + char *errorMsg = nullptr; int64_t rc; rc = omCompileFromFile( inputFileName.c_str(), flags.c_str(), &outputName, &errorMsg); @@ -50,6 +51,8 @@ int64_t PyOMCompileSession::pyCompileFromFile(std::string flags) { // Empty output file name. outputFileName = std::string(); } + free(outputName); + free(errorMsg); return rc; } @@ -60,7 +63,8 @@ int64_t PyOMCompileSession::pyCompileFromArray( "No OMCompileSession was created with the input buffer specified."; return -1; } - const char *outputName, *errorMsg; + char *outputName = nullptr; + char *errorMsg = nullptr; int64_t rc; rc = omCompileFromArray(inputBuffer, inputBufferSize, outputBaseName.c_str(), emissionTarget, &outputName, &errorMsg); @@ -75,6 +79,8 @@ int64_t PyOMCompileSession::pyCompileFromArray( // Empty output file name. 
outputFileName = std::string(); } + free(outputName); + free(errorMsg); return rc; } diff --git a/src/Runtime/ExecutionSession.cpp b/src/Runtime/ExecutionSession.cpp index 283202c656..b7e82e7e20 100644 --- a/src/Runtime/ExecutionSession.cpp +++ b/src/Runtime/ExecutionSession.cpp @@ -33,17 +33,24 @@ const std::string ExecutionSession::_queryEntryPointsName = const std::string ExecutionSession::_inputSignatureName = "omInputSignature"; const std::string ExecutionSession::_outputSignatureName = "omOutputSignature"; +// ============================================================================= +// Constructor, destructor, and init. + ExecutionSession::ExecutionSession( std::string sharedLibPath, bool defaultEntryPoint) { + Init(sharedLibPath, defaultEntryPoint); +} +void ExecutionSession::Init(std::string sharedLibPath, bool defaultEntryPoint) { + if (isInitialized) + throw std::runtime_error(reportInitError()); + + // Init symbols used by execution session. _sharedLibraryHandle = llvm::sys::DynamicLibrary::getLibrary(sharedLibPath.c_str()); if (!_sharedLibraryHandle.isValid()) throw std::runtime_error(reportLibraryOpeningError(sharedLibPath)); - if (defaultEntryPoint) - setEntryPoint("run_main_graph"); - _queryEntryPointsFunc = reinterpret_cast( _sharedLibraryHandle.getAddressOfSymbol(_queryEntryPointsName.c_str())); if (!_queryEntryPointsFunc) @@ -70,14 +77,32 @@ ExecutionSession::ExecutionSession( setenv("OM_CONSTANT_PATH", basePath.c_str(), /*overwrite=*/0); #endif } + // Successful completion of initialization. + isInitialized = true; + + // Set default entry point if requested. + if (defaultEntryPoint) + setEntryPoint("run_main_graph"); +} + +ExecutionSession::~ExecutionSession() { + if (_sharedLibraryHandle.isValid()) + llvm::sys::DynamicLibrary::closeLibrary(_sharedLibraryHandle); } +// ============================================================================= +// Setter and getter. 
+ const std::string *ExecutionSession::queryEntryPoints( int64_t *numOfEntryPoints) const { + if (!isInitialized) + throw std::runtime_error(reportInitError()); return (const std::string *)_queryEntryPointsFunc(numOfEntryPoints); } void ExecutionSession::setEntryPoint(const std::string &entryPointName) { + if (!isInitialized) + throw std::runtime_error(reportInitError()); _entryPointFunc = reinterpret_cast( _sharedLibraryHandle.getAddressOfSymbol(entryPointName.c_str())); if (!_entryPointFunc) @@ -86,8 +111,31 @@ void ExecutionSession::setEntryPoint(const std::string &entryPointName) { errno = 0; // No errors. } +const std::string ExecutionSession::inputSignature() const { + if (!isInitialized) + throw std::runtime_error(reportInitError()); + if (!_entryPointFunc) + throw std::runtime_error(reportUndefinedEntryPointIn("signature")); + errno = 0; // No errors. + return _inputSignatureFunc(_entryPointName.c_str()); +} + +const std::string ExecutionSession::outputSignature() const { + if (!isInitialized) + throw std::runtime_error(reportInitError()); + if (!_entryPointFunc) + throw std::runtime_error(reportUndefinedEntryPointIn("signature")); + errno = 0; // No errors. + return _outputSignatureFunc(_entryPointName.c_str()); +} + +// ============================================================================= +// Run. + std::vector ExecutionSession::run( std::vector ins) { + if (!isInitialized) + throw std::runtime_error(reportInitError()); if (!_entryPointFunc) throw std::runtime_error(reportUndefinedEntryPointIn("run")); @@ -96,6 +144,7 @@ std::vector ExecutionSession::run( omts.emplace_back(inOmt.get()); auto *wrappedInput = omTensorListCreate(omts.data(), (int64_t)omts.size()); + // Run inference. auto *wrappedOutput = _entryPointFunc(wrappedInput); // We created a wrapper for the input list, but the input list does not really @@ -103,11 +152,10 @@ std::vector ExecutionSession::run( // need to simply deallocate the list structure without touching the // OMTensors. 
omTensorListDestroyShallow(wrappedInput); - if (!wrappedOutput) throw std::runtime_error(reportErrnoError()); - std::vector outs; + std::vector outs; for (int64_t i = 0; i < omTensorListGetSize(wrappedOutput); i++) { outs.emplace_back(OMTensorUniquePtr( omTensorListGetOmtByIndex(wrappedOutput, i), omTensorDestroy)); @@ -125,42 +173,27 @@ std::vector ExecutionSession::run( // Run using public interface. Explicit calls are needed to free tensor & tensor // lists. OMTensorList *ExecutionSession::run(OMTensorList *input) { - if (!_entryPointFunc) { - std::stringstream errStr; - errStr << "Must set the entry point before calling run function" - << std::endl; - errno = EINVAL; - throw std::runtime_error(errStr.str()); - } + if (!isInitialized) + throw std::runtime_error(reportInitError()); + if (!_entryPointFunc) + throw std::runtime_error(reportUndefinedEntryPointIn("run")); + + // Run inference. OMTensorList *output = _entryPointFunc(input); - if (!output) { - std::stringstream errStr; - std::string errMessageStr = std::string(strerror(errno)); - errStr << "Runtime error during inference returning with ERRNO code '" - << errMessageStr << "'" << std::endl; - throw std::runtime_error(errStr.str()); - } + if (!output) + throw std::runtime_error(reportErrnoError()); errno = 0; // No errors. return output; } -const std::string ExecutionSession::inputSignature() const { - if (!_entryPointFunc) - throw std::runtime_error(reportUndefinedEntryPointIn("signature")); - errno = 0; // No errors. - return _inputSignatureFunc(_entryPointName.c_str()); -} - -const std::string ExecutionSession::outputSignature() const { - if (!_entryPointFunc) - throw std::runtime_error(reportUndefinedEntryPointIn("signature")); - errno = 0; // No errors. 
- return _outputSignatureFunc(_entryPointName.c_str()); -} +// ============================================================================= +// Error reporting -ExecutionSession::~ExecutionSession() { - if (_sharedLibraryHandle.isValid()) - llvm::sys::DynamicLibrary::closeLibrary(_sharedLibraryHandle); +std::string ExecutionSession::reportInitError() const { + errno = EFAULT; // Bad Address. + std::stringstream errStr; + errStr << "Execution session must be initialized once." << std::endl; + return errStr.str(); } std::string ExecutionSession::reportLibraryOpeningError( @@ -196,4 +229,13 @@ std::string ExecutionSession::reportErrnoError() const { return errStr.str(); } +std::string ExecutionSession::reportCompilerError( + const std::string &errorMessage) const { + errno = EFAULT; // Bad Address. + std::stringstream errStr; + errStr << "Compiler failed with error message '" << errorMessage << "'." + << std::endl; + return errStr.str(); +} + } // namespace onnx_mlir diff --git a/src/Runtime/ExecutionSession.hpp b/src/Runtime/ExecutionSession.hpp index d7e2d2c310..75b5ca3f71 100644 --- a/src/Runtime/ExecutionSession.hpp +++ b/src/Runtime/ExecutionSession.hpp @@ -47,6 +47,7 @@ class ExecutionSession { // Create an execution session using the model given in sharedLibPath. // This path must point to the actual file, local directory is not searched. ExecutionSession(std::string sharedLibPath, bool defaultEntryPoint = true); + ~ExecutionSession(); // Get a NULL-terminated array of entry point names. // For example {"run_addition, "run_subtraction", NULL} @@ -75,18 +76,27 @@ class ExecutionSession { const std::string inputSignature() const; const std::string outputSignature() const; - ~ExecutionSession(); - protected: + // Constructor that build the object without initialization (for use by + // subclass only). + ExecutionSession() = default; + + // Initialization of library. Called by public constructor, or by subclasses. 
+ void Init(std::string sharedLibPath, bool defaultEntryPoint); + // Error reporting processing when throwing runtime errors. Set errno as // appropriate. + std::string reportInitError() const; std::string reportLibraryOpeningError(const std::string &libraryName) const; std::string reportSymbolLoadingError(const std::string &symbolName) const; std::string reportUndefinedEntryPointIn( const std::string &functionName) const; std::string reportErrnoError() const; + std::string reportCompilerError(const std::string &errorMessage) const; + + // Track if Init was called or not. + bool isInitialized = false; -protected: // Handler to the shared library file being loaded. llvm::sys::DynamicLibrary _sharedLibraryHandle; diff --git a/src/Runtime/PyExecutionSession.hpp b/src/Runtime/PyExecutionSession.hpp index 363a44c007..23947296cd 100644 --- a/src/Runtime/PyExecutionSession.hpp +++ b/src/Runtime/PyExecutionSession.hpp @@ -27,9 +27,8 @@ class PyExecutionSession : public onnx_mlir::PyExecutionSessionBase { PYBIND11_MODULE(PyRuntime, m) { py::class_(m, "OMExecutionSession") - .def(py::init(), py::arg("shared_lib_path")) .def(py::init(), - py::arg("shared_lib_path"), py::arg("use_default_entry_point")) + py::arg("shared_lib_path"), py::arg("use_default_entry_point") = 1) .def("entry_points", &onnx_mlir::PyExecutionSession::pyQueryEntryPoints) .def("set_entry_point", &onnx_mlir::PyExecutionSession::pySetEntryPoint, py::arg("name")) diff --git a/src/Runtime/PyExecutionSessionBase.cpp b/src/Runtime/PyExecutionSessionBase.cpp index b456fe623d..7ac0e1dd41 100644 --- a/src/Runtime/PyExecutionSessionBase.cpp +++ b/src/Runtime/PyExecutionSessionBase.cpp @@ -58,14 +58,22 @@ PyExecutionSessionBase::PyExecutionSessionBase( std::string sharedLibPath, bool defaultEntryPoint) : onnx_mlir::ExecutionSession(sharedLibPath, defaultEntryPoint) {} +// ============================================================================= +// Run. 
+ std::vector PyExecutionSessionBase::pyRun( const std::vector &inputsPyArray) { - assert(_entryPointFunc && "Entry point not loaded."); + if (!isInitialized) + throw std::runtime_error(reportInitError()); + if (!_entryPointFunc) + throw std::runtime_error(reportUndefinedEntryPointIn("run")); + // 1. Process inputs. std::vector omts; for (auto inputPyArray : inputsPyArray) { - assert(inputPyArray.flags() && py::array::c_style && - "Expect contiguous python array."); + if (!inputPyArray.flags() || !py::array::c_style) + throw std::runtime_error( + reportPythonError("Expect contiguous python array.")); void *dataPtr; int64_t ownData = 0; @@ -114,9 +122,10 @@ std::vector PyExecutionSessionBase::pyRun( dtype = ONNX_TYPE_COMPLEX128; // Missing bfloat16 support else { - std::cerr << "Numpy type not supported: " << inputPyArray.dtype() - << ".\n"; - exit(1); + std::stringstream errStr; + errStr << "Numpy type not supported: " << inputPyArray.dtype() + << std::endl; + throw std::runtime_error(reportPythonError(errStr.str())); } // Convert Py_ssize_t to int64_t if necessary @@ -139,10 +148,13 @@ std::vector PyExecutionSessionBase::pyRun( omts.emplace_back(inputOMTensor); } + // 2. Call entry point. auto *wrappedInput = omTensorListCreate(&omts[0], omts.size()); auto *wrappedOutput = _entryPointFunc(wrappedInput); if (!wrappedOutput) throw std::runtime_error(reportErrnoError()); + + // 3. Process outputs. 
std::vector outputPyArrays; for (int64_t i = 0; i < omTensorListGetSize(wrappedOutput); i++) { auto *omt = omTensorListGetOmtByIndex(wrappedOutput, i); @@ -197,10 +209,13 @@ std::vector PyExecutionSessionBase::pyRun( case (OM_DATA_TYPE)onnx::TensorProto::COMPLEX128: dtype = py::dtype("cdouble"); break; - default: - std::cerr << "Unsupported ONNX type in OMTensor: " - << omTensorGetDataType(omt) << ".\n"; - exit(1); + default: { + std::stringstream errStr; + errStr << "Unsupported ONNX type in OMTensor: " + << omTensorGetDataType(omt) << std::endl; + + throw std::runtime_error(reportPythonError(errStr.str())); + } } outputPyArrays.emplace_back( @@ -212,11 +227,16 @@ std::vector PyExecutionSessionBase::pyRun( return outputPyArrays; } +// ============================================================================= +// Setter and getter. + void PyExecutionSessionBase::pySetEntryPoint(std::string entryPointName) { setEntryPoint(entryPointName); } std::vector PyExecutionSessionBase::pyQueryEntryPoints() { + if (!isInitialized) + throw std::runtime_error(reportInitError()); assert(_queryEntryPointsFunc && "Query entry point not loaded."); const char **entryPointArr = _queryEntryPointsFunc(NULL); @@ -230,13 +250,23 @@ std::vector PyExecutionSessionBase::pyQueryEntryPoints() { } std::string PyExecutionSessionBase::pyInputSignature() { - assert(_inputSignatureFunc && "Input signature entry point not loaded."); return inputSignature(); } std::string PyExecutionSessionBase::pyOutputSignature() { - assert(_outputSignatureFunc && "Output signature entry point not loaded."); return outputSignature(); } +// ============================================================================= +// Error reporting + +std::string PyExecutionSessionBase::reportPythonError( + std::string errorStr) const { + errno = EFAULT; // Bad Address. + std::stringstream errStr; + errStr << "Execution session: encountered python error `" << errorStr << "'." 
+ << std::endl; + return errStr.str(); +} + } // namespace onnx_mlir diff --git a/src/Runtime/PyExecutionSessionBase.hpp b/src/Runtime/PyExecutionSessionBase.hpp index eca4220279..67139c914f 100644 --- a/src/Runtime/PyExecutionSessionBase.hpp +++ b/src/Runtime/PyExecutionSessionBase.hpp @@ -39,5 +39,11 @@ class PyExecutionSessionBase std::vector pyRun(const std::vector &inputsPyArray); std::string pyInputSignature(); std::string pyOutputSignature(); + +protected: + // Constructor that build the object without initialization (for use by + // subclass only). + PyExecutionSessionBase() : onnx_mlir::ExecutionSession() {} + std::string reportPythonError(std::string errorStr) const; }; } // namespace onnx_mlir diff --git a/src/Runtime/PyOMCompileExecutionSession.cpp b/src/Runtime/PyOMCompileExecutionSession.cpp index f20966b549..d67931b8e4 100644 --- a/src/Runtime/PyOMCompileExecutionSession.cpp +++ b/src/Runtime/PyOMCompileExecutionSession.cpp @@ -23,40 +23,65 @@ SUPPRESS_WARNINGS_POP namespace onnx_mlir { +// ============================================================================= +// Constructor + PyOMCompileExecutionSession::PyOMCompileExecutionSession( - std::string inputFileName, std::string sharedLibPath, std::string flags, - bool defaultEntryPoint) - : onnx_mlir::PyExecutionSessionBase(sharedLibPath, defaultEntryPoint) { + std::string inputFileName, std::string flags, bool defaultEntryPoint, + bool reuseCompiledModel) + : onnx_mlir::PyExecutionSessionBase() /* constructor without Init */ { + // First compile the onnx file. this->inputFileName = inputFileName; - if (this->inputFileName.empty()) { - errorMessage = "No OMCompileExecuteSession was created with the input file " - "name specified."; + if (this->inputFileName.empty()) + throw std::runtime_error(reportLibraryOpeningError(inputFileName)); + + char *outputName = nullptr; + char *errorMsg = nullptr; + if (reuseCompiledModel) { + // see if there is a model to reuse. 
+ outputName = omCompileOutputFileName(inputFileName.c_str(), flags.c_str()); + FILE *file = fopen(outputName, "r"); + if (file) + // File exists, we are ok. + fclose(file); + else + // File does not exist, cannot reuse compilation. + reuseCompiledModel = false; } - const char *outputName, *errorMsg; - int64_t rc; - rc = omCompileFromFile( - inputFileName.c_str(), flags.c_str(), &outputName, &errorMsg); - if (rc == 0) { - // Compilation success: save output file name. - this->sharedLibPath = std::string(outputName); - // Empty error. - errorMessage = std::string(); - } else { - // Compilation failure: save error message. - errorMessage = std::string(errorMsg); - // Empty output file name. - this->sharedLibPath = std::string(); + if (!reuseCompiledModel) { + int64_t rc; + rc = omCompileFromFile( + inputFileName.c_str(), flags.c_str(), &outputName, &errorMsg); + if (rc != 0) { + // Compilation failure: save error message. + errorMessage = std::string(errorMsg); + // Empty output file name. + this->outputFileName = std::string(); + free(outputName); + free(errorMsg); + throw std::runtime_error(reportCompilerError(errorMessage)); + } } + // Compilation success: save output file name. + this->outputFileName = std::string(outputName); + errorMessage = std::string(); + // Now that we have a .so, initialize execution session. 
+ Init(this->outputFileName, defaultEntryPoint); + free(outputName); + free(errorMsg); } -int64_t PyOMCompileExecutionSession::pyGetCompiledResult() { return this->rc; } +// ============================================================================= +// Custom getters + +int64_t PyOMCompileExecutionSession::pyGetCompiledResult() { return rc; } std::string PyOMCompileExecutionSession::pyGetCompiledFileName() { - return this->sharedLibPath; + return outputFileName; } std::string PyOMCompileExecutionSession::pyGetErrorMessage() { - return this->errorMessage; + return errorMessage; } } // namespace onnx_mlir diff --git a/src/Runtime/PyOMCompileExecutionSession.hpp b/src/Runtime/PyOMCompileExecutionSession.hpp index 1629236a2d..879e0b1578 100644 --- a/src/Runtime/PyOMCompileExecutionSession.hpp +++ b/src/Runtime/PyOMCompileExecutionSession.hpp @@ -28,16 +28,15 @@ namespace onnx_mlir { class PyOMCompileExecutionSession : public onnx_mlir::PyExecutionSessionBase { public: - PyOMCompileExecutionSession(std::string inputFileName, - std::string sharedLibPath, std::string flags, - bool defaultEntryPoint = true); + PyOMCompileExecutionSession(std::string inputFileName, std::string flags, + bool defaultEntryPoint = true, bool reuseCompiledModel = true); std::string pyGetCompiledFileName(); std::string pyGetErrorMessage(); int64_t pyGetCompiledResult(); private: std::string inputFileName; - std::string sharedLibPath; + std::string outputFileName; std::string errorMessage; int64_t rc; }; @@ -46,14 +45,11 @@ class PyOMCompileExecutionSession : public onnx_mlir::PyExecutionSessionBase { PYBIND11_MODULE(PyCompileAndRuntime, m) { py::class_( m, "OMCompileExecutionSession") - .def(py::init(), - py::arg("input_model_path"), py::arg("compiled_file_path"), - py::arg("flags")) - .def(py::init(), - py::arg("input_model_path"), py::arg("compiled_file_path"), - py::arg("flags"), py::arg("use_default_entry_point")) + .def(py::init(), + py::arg("input_model_name"), py::arg("flags"), + 
py::arg("use_default_entry_point") = 1, + py::arg("reuse_compiled_model") = 1) .def("get_compiled_result", &onnx_mlir::PyOMCompileExecutionSession::pyGetCompiledResult) .def("get_compiled_file_name", diff --git a/test/compilerlib/CompilerLibTest.cpp b/test/compilerlib/CompilerLibTest.cpp index 2b08f6e8ff..e57bd0bcf8 100644 --- a/test/compilerlib/CompilerLibTest.cpp +++ b/test/compilerlib/CompilerLibTest.cpp @@ -71,8 +71,8 @@ void readArgsFromCommandLine(int argc, char *argv[]) { int main(int argc, char *argv[]) { int retVal = 0; - const char *errorMessage = NULL; - const char *compiledFilename; + char *errorMessage = nullptr; + char *compiledFilename = nullptr; readArgsFromCommandLine(argc, argv); @@ -86,7 +86,7 @@ int main(int argc, char *argv[]) { // Compile. retVal = onnx_mlir::omCompileFromFile( testFileName.c_str(), flags.c_str(), &compiledFilename, &errorMessage); - if (retVal != CompilerSuccess && errorMessage != NULL) + if (retVal != CompilerSuccess && errorMessage != nullptr) std::cerr << errorMessage; } else { std::ifstream inFile( @@ -96,12 +96,14 @@ int main(int argc, char *argv[]) { retVal = omCompileFromArray(test.data(), test.size(), outputBaseName.c_str(), onnx_mlir::EmitLib, &compiledFilename, &errorMessage); - if (retVal != CompilerSuccess && errorMessage != NULL) { + if (retVal != CompilerSuccess && errorMessage != nullptr) { std::cerr << errorMessage; } } if (retVal != 0) { std::cerr << "Compiling " << testFileName << "failed with code" << retVal; } + free(compiledFilename); + free(errorMessage); return retVal; } \ No newline at end of file