diff --git a/CMakeLists.txt b/CMakeLists.txt index 80e7e1316fb..ed6f7ad86f5 100755 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -53,6 +53,12 @@ include(CTest) find_package(ROCM REQUIRED) find_package(Threads REQUIRED) +if(WIN32) +option(MIGRAPHX_ENABLE_PYTHON "Enable python bindings" OFF) +else() +option(MIGRAPHX_ENABLE_PYTHON "Enable python bindings" ON) +endif() + find_path(HALF_INCLUDE_DIR half.hpp PATH_SUFFIXES half) if (NOT HALF_INCLUDE_DIR) message(FATAL_ERROR "Could not find half.hpp - Please check that the install path of half.hpp has been added to CMAKE_PREFIX_PATH") diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index a00242f4de8..be00add7e18 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -282,7 +282,9 @@ add_subdirectory(driver) add_subdirectory(onnx) add_subdirectory(tf) +if(MIGRAPHX_ENABLE_PYTHON) add_subdirectory(py) +endif() add_subdirectory(targets/ref) target_link_libraries(migraphx_all_targets INTERFACE migraphx_ref) if(MIGRAPHX_ENABLE_CPU) diff --git a/src/driver/CMakeLists.txt b/src/driver/CMakeLists.txt index 995daecb3f5..d1bb44d55f4 100755 --- a/src/driver/CMakeLists.txt +++ b/src/driver/CMakeLists.txt @@ -48,7 +48,12 @@ rocm_clang_tidy_check(driver) file(STRINGS "${CMAKE_SOURCE_DIR}/test/onnx/.onnxrt-commit" String_output) target_compile_definitions(driver PUBLIC MIGRAPHX_ORT_SHA1="${String_output}") -target_link_libraries(driver migraphx_all_targets migraphx_onnx migraphx_tf migraphx_py) +target_link_libraries(driver migraphx_all_targets migraphx_onnx migraphx_tf) + +if(MIGRAPHX_ENABLE_PYTHON) + target_link_libraries(driver migraphx_py) + target_compile_definitions(driver PRIVATE MIGRAPHX_ENABLE_PYTHON) +endif() rocm_install_targets( TARGETS driver diff --git a/src/driver/main.cpp b/src/driver/main.cpp index 4a920d66618..d4cf4acd3ba 100644 --- a/src/driver/main.cpp +++ b/src/driver/main.cpp @@ -32,7 +32,9 @@ #include #include +#ifdef MIGRAPHX_ENABLE_PYTHON #include +#endif #include #include #include @@ -281,10 +283,12 @@ struct loader options.format = "json"; p = migraphx::load(file, options); } +#ifdef MIGRAPHX_ENABLE_PYTHON else if(file_type == "py") { p = migraphx::load_py(file); } +#endif else if(file_type == "migraphx") { p = migraphx::load(file); diff --git a/src/py/CMakeLists.txt b/src/py/CMakeLists.txt index 2da0f7058aa..a939d642d62 100644 --- a/src/py/CMakeLists.txt +++ b/src/py/CMakeLists.txt @@ -22,27 +22,24 @@ # THE SOFTWARE. ##################################################################################### -option(MIGRAPHX_ENABLE_PYTHON "Enable python bindings" ON) add_library(migraphx_py py_loader.cpp) migraphx_generate_export_header(migraphx_py) target_include_directories(migraphx_py PRIVATE include) target_link_libraries(migraphx_py PUBLIC migraphx) rocm_install_targets(TARGETS migraphx_py INCLUDE include) -if(MIGRAPHX_ENABLE_PYTHON) - include(PythonModules) +include(PythonModules) - foreach(PYTHON_VERSION ${PYTHON_VERSIONS}) - py_add_module(migraphx_pybind_${PYTHON_VERSION} migraphx_py.cpp PYTHON_VERSION ${PYTHON_VERSION} PYTHON_MODULE migraphx) - target_link_libraries(migraphx_pybind_${PYTHON_VERSION} PRIVATE migraphx migraphx_tf migraphx_onnx migraphx_all_targets) - rocm_install_targets(TARGETS migraphx_pybind_${PYTHON_VERSION}) - add_dependencies(migraphx_py migraphx_pybind_${PYTHON_VERSION}) - - add_library(migraphx_py_${PYTHON_VERSION} py.cpp) - target_include_directories(migraphx_py_${PYTHON_VERSION} PRIVATE include) - target_link_libraries(migraphx_py_${PYTHON_VERSION} PUBLIC migraphx) - target_link_libraries(migraphx_py_${PYTHON_VERSION} PRIVATE pybind11::pybind11 python${PYTHON_VERSION}::runtime) - rocm_install_targets(TARGETS migraphx_py_${PYTHON_VERSION}) - add_dependencies(migraphx_py migraphx_py_${PYTHON_VERSION}) - endforeach() -endif() +foreach(PYTHON_VERSION ${PYTHON_VERSIONS}) + py_add_module(migraphx_pybind_${PYTHON_VERSION} migraphx_py.cpp PYTHON_VERSION ${PYTHON_VERSION} PYTHON_MODULE migraphx) + target_link_libraries(migraphx_pybind_${PYTHON_VERSION} PRIVATE migraphx migraphx_tf migraphx_onnx migraphx_all_targets) + rocm_install_targets(TARGETS migraphx_pybind_${PYTHON_VERSION}) + add_dependencies(migraphx_py migraphx_pybind_${PYTHON_VERSION}) + + add_library(migraphx_py_${PYTHON_VERSION} py.cpp) + target_include_directories(migraphx_py_${PYTHON_VERSION} PRIVATE include) + target_link_libraries(migraphx_py_${PYTHON_VERSION} PUBLIC migraphx) + target_link_libraries(migraphx_py_${PYTHON_VERSION} PRIVATE pybind11::pybind11 python${PYTHON_VERSION}::runtime) + rocm_install_targets(TARGETS migraphx_py_${PYTHON_VERSION}) + add_dependencies(migraphx_py migraphx_py_${PYTHON_VERSION}) +endforeach()