Skip to content

Commit

Permalink
Merge pull request #116 from czgdp1807/pytorch
Browse files Browse the repository at this point in the history
Initial support for PyTorch C++ API
  • Loading branch information
czgdp1807 authored Mar 22, 2024
2 parents 6dcebb2 + 6664dbd commit 8b52363
Show file tree
Hide file tree
Showing 5 changed files with 260 additions and 18 deletions.
13 changes: 13 additions & 0 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -83,3 +83,16 @@ jobs:
export CPATH=$CONDA_PREFIX/include:$CPATH
./integration_tests/run_tests.py -b gcc llvm wasm c
./integration_tests/run_tests.py -b gcc llvm wasm c -f
- name: Test4 (Linux)
shell: bash -l -e {0}
if: contains(matrix.os, 'ubuntu')
run: |
export CPATH=$CONDA_PREFIX/include:$CPATH
conda install --yes pytorch::pytorch
cp -r $CONDA_PREFIX/lib/python3.12/site-packages/torch/include/* $CONDA_PREFIX/include/
cp -r $CONDA_PREFIX/lib/python3.12/site-packages/torch/lib/* $CONDA_PREFIX/lib/
cp -r $CONDA_PREFIX/lib/python3.12/site-packages/torch/share/* $CONDA_PREFIX/share/
./integration_tests/run_tests.py -b pytorch
export CPATH=$CPATH:$CONDA_PREFIX/include/torch/csrc/api/include
./integration_tests/run_tests.py -b llvmPytorch
16 changes: 14 additions & 2 deletions integration_tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ macro(RUN_UTIL RUN_FAIL RUN_NAME RUN_FILE_NAME RUN_LABELS RUN_EXTRAFILES RUN_EXT

if (ADD_TEST)
if ((LC_BACKEND STREQUAL "llvm") OR (LC_BACKEND STREQUAL "cpp") OR (LC_BACKEND STREQUAL "x86")
OR (LC_BACKEND STREQUAL "c") OR (LC_BACKEND STREQUAL "fortran"))
OR (LC_BACKEND STREQUAL "c") OR (LC_BACKEND STREQUAL "fortran") OR (LC_BACKEND STREQUAL "llvmPytorch"))
add_custom_command(
OUTPUT ${name}.o
COMMAND ${LC} -c ${CMAKE_CURRENT_SOURCE_DIR}/${file_name} -o ${name}.o ${extra_args}
Expand All @@ -110,6 +110,16 @@ macro(RUN_UTIL RUN_FAIL RUN_NAME RUN_FILE_NAME RUN_LABELS RUN_EXTRAFILES RUN_EXT
endif()
set(WASM_EXEC_FLAGS ${WASM_EXEC_FLAGS} "--experimental-wasi-unstable-preview1")
add_test(${name} ${WASM_EXEC_RUNTIME} ${WASM_EXEC_FLAGS} ${CURRENT_BINARY_DIR}/${name}.js)
elseif (LC_BACKEND STREQUAL "pytorch")
# PyTorch C++ API testing
find_package(Torch REQUIRED)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}")

add_executable(${name} ${file_name} ${extra_files})
target_link_libraries(${name} "${TORCH_LIBRARIES}")
target_compile_options(${name} PUBLIC ${gcc_args})
set_property(TARGET ${name} PROPERTY CXX_STANDARD 17)
add_test(${name} ${CURRENT_BINARY_DIR}/${name})
else ()
add_executable(${name} ${file_name} ${extra_files})
target_compile_options(${name} PUBLIC ${gcc_args})
Expand Down Expand Up @@ -138,7 +148,7 @@ macro(RUN)
"${multiValueArgs}" ${ARGN} )

foreach(b ${RUN_LABELS})
if (NOT (b MATCHES "^(llvm|llvm2|llvm_rtlib|gcc|c|cpp|x86|wasm|gfortran|llvmImplicit|llvmStackArray|fortran|c_nopragma|llvm_nopragma)$"))
if (NOT (b MATCHES "^(llvm|llvm2|llvm_rtlib|gcc|c|cpp|x86|wasm|gfortran|llvmImplicit|llvmStackArray|fortran|c_nopragma|llvm_nopragma|pytorch|llvmPytorch)$"))
message(FATAL_ERROR "Unsupported backend: ${b}")
endif()
endforeach()
Expand Down Expand Up @@ -241,3 +251,5 @@ RUN(NAME vector_02.cpp LABELS gcc llvm)
RUN(NAME loop_01.cpp LABELS gcc llvm NOFAST)

RUN(NAME test_pkg_lnn_01.cpp LABELS gcc llvm NOFAST)

RUN(NAME pytorch_01.cpp LABELS pytorch llvmPytorch)
19 changes: 19 additions & 0 deletions integration_tests/pytorch_01.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
#include <torch/torch.h>
#include <iostream>

void check(const torch::Tensor& tensor=torch::empty({1})) {
float array[5] = {4.0, 2.0, 2.0, 12.0, 2.0};
std::cout << tensor << "\n";
if( torch::any(torch::abs(tensor - torch::from_blob(array, {5})) > 1e-8).item<bool>() ) {
exit(2);
}
}

int main() {
torch::Tensor tensor = torch::ones(5);
tensor[0] = 2.0;
tensor[3] = 6.0;
tensor = 2 * tensor;
check(tensor);
std::cout << tensor << "\n";
}
2 changes: 1 addition & 1 deletion integration_tests/run_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
NO_OF_THREADS = 8 # default no of threads is 8
SUPPORTED_BACKENDS = ['llvm', 'llvm2', 'llvm_rtlib', 'c', 'cpp', 'x86', 'wasm',
'gcc', 'llvmImplicit', 'llvmStackArray', 'fortran',
'c_nopragma', 'llvm_nopragma']
'c_nopragma', 'llvm_nopragma', 'pytorch', 'llvmPytorch']
BASE_DIR = os.path.dirname(os.path.realpath(__file__))
LC_PATH = f"{BASE_DIR}/../src/bin:$PATH"

Expand Down
Loading

0 comments on commit 8b52363

Please sign in to comment.