Fix the compile and run execution session in Python (#2373)

Signed-off-by: Alexandre Eichenberger <[email protected]>
Co-authored-by: Tung D. Le <[email protected]>
AlexandreEichenberger and tungld authored Jul 22, 2023
1 parent 6bc651e commit 4b8f25b
Showing 16 changed files with 330 additions and 134 deletions.
8 changes: 6 additions & 2 deletions docs/doc_example/main.cpp
Original file line number Diff line number Diff line change
@@ -18,8 +18,8 @@ std::string readArgs(int argc, char *argv[]) {
int main(int argc, char *argv[]) {
// Read compiler options from command line and compile the doc example into a
// model library.
const char *errorMessage = NULL;
const char *compiledFilename;
char *errorMessage = nullptr;
char *compiledFilename = nullptr;
std::string flags = readArgs(argc, argv);
flags += "-o add-cpp-interface";
std::cout << "Compile with options \"" << flags << "\"\n";
@@ -30,11 +30,15 @@ int main(int argc, char *argv[]) {
if (errorMessage)
std::cerr << " and message \"" << errorMessage << "\"";
std::cerr << "." << std::endl;
free(compiledFilename);
free(errorMessage);
return rc;
}
std::string libFilename(compiledFilename);
std::cout << "Compiled succeeded with results in file: " << libFilename
<< std::endl;
free(compiledFilename);
free(errorMessage);

// Prepare the execution session.
onnx_mlir::ExecutionSession *session;
23 changes: 19 additions & 4 deletions docs/mnist_example/README.md
@@ -279,8 +279,7 @@ The runtime use an `OMExecutionSession` object to hold a specific model and entr

```Python
#Load the model mnist.so compiled with onnx-mlir.
model = 'mnist.so'
session = OMExecutionSession(model)
session = OMExecutionSession('mnist.so')
#Print the models input / output signature, for display.
#If there are problems with the signature functions, \
they can be simply commented out.
@@ -295,10 +294,10 @@ outputs = session.run([input])

The outputs can then be analyzed by inspecting the values inside the `output` list of numpy arrays.
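As a concrete sketch of that inspection step (hypothetical scores, assuming the MNIST model returns a single 10-element score vector per image), the predicted digit is simply the index of the largest score:

```python
import numpy as np

# Hypothetical result of outputs = session.run([input]): a list holding
# one numpy array of 10 per-digit scores.
outputs = [np.array([0.95, 1e-4, 2e-3, 1e-5, 3e-4,
                     1e-6, 2e-5, 1e-3, 4e-4, 1e-7])]

# The most likely digit is the argmax over the score vector.
digit = int(np.argmax(outputs[0]))
print("The digit is", digit)  # prints "The digit is 0" for these scores
```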

The full code is available [here](mnist.py). It finds that `0` is the most likely digit for the given input. The command is:
The full code is available [here](mnist-runPyRuntime.py). It finds that `0` is the most likely digit for the given input. The command is:

```shell
./mnist.py
./mnist-runPyRuntime.py
```

and produces an output similar to the following (you may see slightly different prediction numbers if you train the model yourself):
@@ -321,6 +320,22 @@ prediction 9 = 8.650948e-15
The digit is 0
```
We provide two additional Python interfaces.
The second interface extends the above execution session by simply compiling a model before loading it for execution (see [here](mnist-runPyCompileAndRuntime.py)).
The user simply passes the `.onnx` model and the flags needed to compile the model.
Unless explicitly disabled with `reuse_compiled_model=0`, the execution session reuses a previously compiled model whose name matches the name of the output file generated by the compiler.
Note that the execution session does not check whether the cached version was compiled with identical compiler flags; it is the user's responsibility to clear the cached version, or to disable reuse with the optional flag.
For example, the code below will compile and load the `mnist.onnx` model, compiling only when the `mnist2.so` binary file cannot be located. Model inference can then proceed using the `session.run(...)` command.
```Python
# Load onnx model and create CompileExecutionSession object,
# by first compiling the mnist.onnx model with the "-O3" options.
session = OMCompileExecutionSession("./mnist.onnx" ,"-O3 -o=mnist2")
```
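The reuse rule described above amounts to an existence check on the compiled binary. A minimal sketch with a hypothetical helper (`should_recompile` is not part of the PyCompileAndRuntime API; it only mirrors the documented behavior):

```python
import os

def should_recompile(compiled_path, reuse_compiled_model=1):
    """Decide whether the .onnx model must be recompiled.

    Mirrors the documented behavior: a cached binary with the expected
    name is reused, and its compile flags are NOT inspected, so a stale
    cache must be cleared (or reuse disabled) by the user.
    """
    if not reuse_compiled_model:
        return True  # reuse explicitly disabled: always recompile
    return not os.path.exists(compiled_path)  # compile only if missing
```

For the example above, once a previous compilation has produced `mnist2.so`, `should_recompile("./mnist2.so")` returns `False` and the cached binary is loaded directly.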
The third interface provides a simple way to explicitly compile an onnx model without running it (see [here](mnist-compile.py)).
## Write a Java Driver Code
Inference APIs and data structures for Java closely mirror those for C/C++. Documentation of the APIs are found [here](https://onnx.ai/onnx-mlir/doxygen_html/OMModel_java/classcom_1_1ibm_1_1onnxmlir_1_1_o_m_model.html), with the Java interface for Tensor [here](https://onnx.ai/onnx-mlir/doxygen_html/OMTensor_java/classcom_1_1ibm_1_1onnxmlir_1_1_o_m_tensor.html) and TensorList [here](https://onnx.ai/onnx-mlir/doxygen_html/OMTensorList_java/classcom_1_1ibm_1_1onnxmlir_1_1_o_m_tensor_list.html).
2 changes: 1 addition & 1 deletion docs/mnist_example/mnist-compile.py
@@ -7,7 +7,7 @@
file = './mnist.onnx'
compiler = OMCompileSession(file)
# Generate the library file with the option set to "-O3". Success when rc == 0.
rc = compiler.compile("-O3")
rc = compiler.compile("-O3 -o mnist")
# Get the output file name
model = compiler.get_compiled_file_name()
if rc:
14 changes: 7 additions & 7 deletions docs/mnist_example/mnist-runPyCompileAndRuntime.py
@@ -3,13 +3,13 @@
import numpy as np
from PyCompileAndRuntime import OMCompileExecutionSession

# Load onnx model and create CompileExecutionSession object.
inputFileName = './mnist.onnx'
# Set the full name of compiled model
sharedLibPath = './mnist.so'
# Set the compile option as "-O3"
session = OMCompileExecutionSession(inputFileName,sharedLibPath,"-O3")

# Load onnx model and create CompileExecutionSession object,
# by first compiling the mnist.onnx model with the "-O3" options.
session = OMCompileExecutionSession("./mnist.onnx" ,"-O3 -o=mnist2",
reuse_compiled_model=1)
if session.get_compiled_result():
print("error with :" + session.get_error_message())
exit(1)
# Print the models input/output signature, for display.
# Signature functions for info only, commented out if they cause problems.
print("input signature in json", session.input_signature())
3 changes: 1 addition & 2 deletions docs/mnist_example/mnist-runPyRuntime.py
@@ -4,8 +4,7 @@
from PyRuntime import OMExecutionSession

# Load the model mnist.so compiled with onnx-mlir.
model = './mnist.so'
session = OMExecutionSession(model)
session = OMExecutionSession('./mnist.so')
# Print the models input/output signature, for display.
# Signature functions for info only, commented out if they cause problems.
print("input signature in json", session.input_signature())
35 changes: 25 additions & 10 deletions include/OnnxMlirCompiler.h
@@ -60,18 +60,18 @@ namespace onnx_mlir {
 * Name may include a path, and must include the file name and its extension.
*
* @param outputFilename Output file name of the compiled output for the given
* emission target. User is responsible for freeing the string.
* emission target. User is responsible for freeing the string.
*
 * @param flags A char * containing all the options provided to compile the
* model.
* model.
*
* @param errorMessage Output error message, if any. User is responsible for
* freeing the string.
* freeing the string.
*
* @return 0 on success or OnnxMlirCompilerErrorCodes on failure.
*/
ONNX_MLIR_EXPORT int64_t omCompileFromFile(const char *inputFilename,
const char *flags, const char **outputFilename, const char **errorMessage);
const char *flags, char **outputFilename, char **errorMessage);

/*!
* Compile an onnx model from an ONNX protobuf array. This method is not thread
@@ -85,18 +85,33 @@ ONNX_MLIR_EXPORT int64_t omCompileFromFile(const char *inputFilename,
* @param bufferSize Size of ONNX protobuf array.
* @param outputBaseName File name without extension to write output.
* Name may include a path, must include the file name, and should not include
 * an extension.
 * an extension.
* @param emissionTarget Target format to compile to.
* @param outputFilename Output file name of the compiled output for the given
* emission target. User is responsible for freeing the string.
* @param errorMessage Error message.
* emission target. User is responsible for freeing the string.
* @param errorMessage Error message, if any. User is responsible for freeing
* the string.
 * @return 0 on success or OnnxMlirCompilerErrorCodes on failure. User is
* responsible for freeing the string.
* responsible for freeing the string.
*/
ONNX_MLIR_EXPORT int64_t omCompileFromArray(const void *inputBuffer,
int64_t bufferSize, const char *outputBaseName,
EmissionTargetType emissionTarget, const char **outputFilename,
const char **errorMessage);
EmissionTargetType emissionTarget, char **outputFilename,
char **errorMessage);

/*!
* Compute the file name of the compiled output for the given
* emission target. User is responsible for freeing the string.
*
 * @param inputFilename File name pointing to an onnx model protobuf or MLIR.
 * Name may include a path, and must include the file name and its extension.
 * @param flags A char * containing all the options provided to compile the
 * model.
* @return string containing the file name. User is responsible for freeing the
* string.
*/
ONNX_MLIR_EXPORT char *omCompileOutputFileName(
const char *inputFilename, const char *flags);

#ifdef __cplusplus
} // namespace onnx_mlir
77 changes: 62 additions & 15 deletions src/Compiler/OnnxMlirCompiler.cpp
@@ -18,16 +18,27 @@ using namespace onnx_mlir;

namespace onnx_mlir {

// Derive the name; base name is either given by a "-o" option, or is taken as
// the model name. The extension depends on the target; e.g. -EmitLib will
// generate a .so, other targets may generate a .mlir.
static std::string deriveOutputFileName(
std::vector<std::string> &flagVect, std::string inputFilename) {
// Get output file name.
std::string outputBasename;
int num = flagVect.size();
for (int i = 0; i < num - 1;
++i) { // Skip last as need 2 consecutive entries.
if (flagVect[i].find("-o") == 0) {
outputBasename = flagVect[i + 1];
break;
for (int i = 0; i < num; ++i) {
if (flagVect[i].find("-o=", 0, 3) == 0) {
if (flagVect[i].length() > 3) {
outputBasename = flagVect[i].substr(3);
break;
} else
llvm::errs() << "Parsing `-o=` option, expected a name. Use default.\n";
} else if (flagVect[i].find("-o") == 0) {
if (i < num - 1) {
outputBasename = flagVect[i + 1];
break;
} else
llvm::errs() << "Parsing `-o` option, expected a name. Use default.\n";
}
}
// If no output file name, derive it from input file name
@@ -56,12 +67,7 @@ static std::string deriveOutputFileName(
return getTargetFilename(outputBasename, emissionTarget);
}

extern "C" {

ONNX_MLIR_EXPORT int64_t omCompileFromFile(const char *inputFilename,
const char *flags, const char **outputFilename, const char **errorMessage) {
// Process the flags, saving each space-separated text in a separate
// entry in the string vector flagVect.
static std::vector<std::string> parseFlags(const char *flags) {
std::vector<std::string> flagVect;
const char *str = flags;
do {
@@ -76,6 +82,23 @@ ONNX_MLIR_EXPORT int64_t omCompileFromFile,
if (begin != str)
flagVect.push_back(std::string(begin, str));
} while (*str);
return flagVect;
}
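The space-separated splitting performed by `parseFlags` can be sketched in Python, mirroring the C++ loop (which splits on single spaces and drops empty entries; `flags.split(" ")` with a filter would behave the same):

```python
def parse_flags(flags):
    """Split a flags string on runs of spaces, skipping empty
    entries, like the parseFlags loop above."""
    flag_vect = []
    i, n = 0, len(flags)
    while i < n:
        while i < n and flags[i] == " ":   # skip separator spaces
            i += 1
        begin = i
        while i < n and flags[i] != " ":   # consume one token
            i += 1
        if begin != i:                     # ignore empty tokens
            flag_vect.append(flags[begin:i])
    return flag_vect
```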

extern "C" {

ONNX_MLIR_EXPORT int64_t omCompileFromFile(const char *inputFilename,
const char *flags, char **outputFilename, char **errorMessage) {
// Ensure known values in filename and error message if provided.
if (outputFilename)
*outputFilename = nullptr;
if (errorMessage)
*errorMessage = nullptr;

// Process the flags, saving each space-separated text in a separate
// entry in the string vector flagVect.
std::vector<std::string> flagVect = parseFlags(flags);

// Use 'onnx-mlir' command to compile the model.
std::string onnxMlirPath;
const auto &envDir = getEnvVar("ONNX_MLIR_BIN_PATH");
@@ -90,17 +113,33 @@ ONNX_MLIR_EXPORT int64_t omCompileFromFile(const char *inputFilename,
onnxMlirCompile.appendStr(inputFilenameStr);
// Run command.
int rc = onnxMlirCompile.exec();
if (rc == CompilerSuccess && outputFilename) {
if (rc != CompilerSuccess) {
// Failure to compile.
if (errorMessage) {
std::string errorStr =
"Compiler failed with error code " + std::to_string(rc);
*errorMessage = strdup(errorStr.c_str());
}
return CompilerFailureInLLVMOpt;
}
// Success.
if (outputFilename) {
std::string name = deriveOutputFileName(flagVect, inputFilenameStr);
*outputFilename = strdup(name.c_str());
}
return rc != 0 ? CompilerFailureInLLVMOpt : CompilerSuccess;
return CompilerSuccess;
}

ONNX_MLIR_EXPORT int64_t omCompileFromArray(const void *inputBuffer,
int64_t bufferSize, const char *outputBaseName,
EmissionTargetType emissionTarget, const char **outputFilename,
const char **errorMessage) {
EmissionTargetType emissionTarget, char **outputFilename,
char **errorMessage) {
// Ensure known values in filename and error message if provided.
if (outputFilename)
*outputFilename = nullptr;
if (errorMessage)
*errorMessage = nullptr;

mlir::OwningOpRef<mlir::ModuleOp> module;
mlir::MLIRContext context;
registerDialects(context);
@@ -124,5 +163,13 @@ ONNX_MLIR_EXPORT int64_t omCompileFromArray(const void *inputBuffer,
return rc;
}

ONNX_MLIR_EXPORT char *omCompileOutputFileName(
const char *inputFilename, const char *flags) {
std::vector<std::string> flagVect = parseFlags(flags);
std::string inputFilenameStr(inputFilename);
std::string name = deriveOutputFileName(flagVect, inputFilenameStr);
return strdup(name.c_str());
}

} // extern C
} // namespace onnx_mlir
10 changes: 8 additions & 2 deletions src/Compiler/PyOMCompileSession.cpp
@@ -35,7 +35,8 @@ int64_t PyOMCompileSession::pyCompileFromFile(std::string flags) {
"No OMCompileSession was created with the input file name specified.";
return -1;
}
const char *outputName, *errorMsg;
char *outputName = nullptr;
char *errorMsg = nullptr;
int64_t rc;
rc = omCompileFromFile(
inputFileName.c_str(), flags.c_str(), &outputName, &errorMsg);
@@ -50,6 +51,8 @@ int64_t PyOMCompileSession::pyCompileFromFile(std::string flags) {
// Empty output file name.
outputFileName = std::string();
}
free(outputName);
free(errorMsg);
return rc;
}

@@ -60,7 +63,8 @@ int64_t PyOMCompileSession::pyCompileFromArray(
"No OMCompileSession was created with the input buffer specified.";
return -1;
}
const char *outputName, *errorMsg;
char *outputName = nullptr;
char *errorMsg = nullptr;
int64_t rc;
rc = omCompileFromArray(inputBuffer, inputBufferSize, outputBaseName.c_str(),
emissionTarget, &outputName, &errorMsg);
@@ -75,6 +79,8 @@ int64_t PyOMCompileSession::pyCompileFromArray(
// Empty output file name.
outputFileName = std::string();
}
free(outputName);
free(errorMsg);
return rc;
}
