diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index bbcc0e9e0..c34e3c07e 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -22,7 +22,7 @@ if (TARGET tiledarray) if (TARGET CUDA::cublas) add_ttg_executable(bspmm-cuda spmm/spmm_cuda.cc LINK_LIBRARIES tiledarray TiledArray_Eigen BTAS CUDA::cublas - COMPILE_DEFINITIONS BLOCK_SPARSE_GEMM=1;BTAS_TARGET_MAX_INDEX_RANK=2 + COMPILE_DEFINITIONS BLOCK_SPARSE_GEMM=1;BTAS_TARGET_MAX_INDEX_RANK=2;TTG_ENABLE_CUDA=1 RUNTIMES "parsec") if (TARGET CUDA::cusolver) @@ -34,7 +34,7 @@ if (TARGET tiledarray) elseif (TARGET roc::hipblas) add_ttg_executable(bspmm-hip spmm/spmm_cuda.cc LINK_LIBRARIES tiledarray TiledArray_Eigen roc::hipblas - COMPILE_DEFINITIONS BLOCK_SPARSE_GEMM=1;BTAS_TARGET_MAX_INDEX_RANK=2 + COMPILE_DEFINITIONS BLOCK_SPARSE_GEMM=1;BTAS_TARGET_MAX_INDEX_RANK=2;TTG_ENABLE_HIP=1 RUNTIMES "parsec") if (TARGET roc::hipsolver) add_ttg_executable(testing_dpotrf_hip potrf/testing_dpotrf.cc @@ -45,7 +45,7 @@ if (TARGET tiledarray) elseif (TARGET MKL::MKL_DPCPP) add_ttg_executable(bspmm-lz spmm/spmm_cuda.cc LINK_LIBRARIES tiledarray TiledArray_Eigen BTAS MKL::MKL_DPCPP level_zero::ze_loader m - COMPILE_DEFINITIONS BLOCK_SPARSE_GEMM=1;BTAS_TARGET_MAX_INDEX_RANK=2 + COMPILE_DEFINITIONS BLOCK_SPARSE_GEMM=1;BTAS_TARGET_MAX_INDEX_RANK=2;TTG_ENABLE_LEVEL_ZERO=1 RUNTIMES "parsec") endif() diff --git a/examples/devblas_helper.h b/examples/devblas_helper.h index 5f2ad3e89..2d2139a54 100644 --- a/examples/devblas_helper.h +++ b/examples/devblas_helper.h @@ -10,7 +10,7 @@ #include #include -#ifdef TTG_HAVE_CUDART +#ifdef TTG_ENABLE_CUDA #include #include @@ -87,9 +87,9 @@ inline const cusolverDnHandle_t& cusolver_handle(T _ = 0) { return it->second; } -#endif // TTG_HAVE_CUDART +#endif // TTG_ENABLE_CUDA -#ifdef TTG_HAVE_HIPBLAS +#ifdef TTG_ENABLE_HIP #include #include @@ -162,4 +162,4 @@ inline const hipsolverDnHandle_t& hipsolver_handle(T _ = 0) { } return it->second; } -#endif // TTG_HAVE_HIPBLAS +#endif // TTG_ENABLE_HIP diff --git a/examples/potrf/potrf.h b/examples/potrf/potrf.h index f6ba6e147..78769fe30 100644 --- a/examples/potrf/potrf.h +++ b/examples/potrf/potrf.h @@ -15,7 +15,7 @@ #define ES ttg::ExecutionSpace::CUDA #define TASKRET -> ttg::device::Task #include -#elif defined(TTG_HAVE_HIP) +#elif defined(TTG_ENABLE_HIP) #define ES ttg::ExecutionSpace::HIP #define TASKRET -> ttg::device::Task #include @@ -35,13 +35,13 @@ namespace potrf { #if defined(ENABLE_DEVICE_KERNEL) static int device_potrf_workspace_size(MatrixTile &A) { int Lwork; - #if defined(TTG_HAVE_CUDA) + #if defined(TTG_ENABLE_CUDA) cusolverDnDpotrf_bufferSize(cusolver_handle(), CUBLAS_FILL_MODE_LOWER, A.cols(), nullptr, A.lda(), &Lwork); return Lwork; - #elif defined(TTG_HAVE_HIPBLAS) + #elif defined(TTG_ENABLE_HIP) hipsolverDnDpotrf_bufferSize(hipsolver_handle(), HIPSOLVER_FILL_MODE_LOWER, A.cols(), nullptr, A.lda(), @@ -55,7 +55,7 @@ namespace potrf { static void device_potrf(MatrixTile &A, double *workspace, int Lwork, int *devInfo) { int device = ttg::device::current_device(); assert(device >= 0); -#if defined(TTG_HAVE_CUDA) +#if defined(TTG_ENABLE_CUDA) //std::cout << "POTRF A " << A.buffer().device_ptr_on(device) << " device " << device << " cols " << A.cols() << " lda " << A.lda() << " Lwork " << Lwork << " WS " << workspace << " devInfo " << devInfo << std::endl; auto handle = cusolver_handle(); //std::cout << "POTRF handle " << handle << " device " << device << " stream " << ttg::device::current_stream() << std::endl; @@ -64,7 +64,7 @@ namespace potrf { A.buffer().current_device_ptr(), A.lda(), workspace, Lwork, devInfo); - #elif defined(TTG_HAVE_HIPBLAS) + #elif defined(TTG_ENABLE_HIP) hipsolverDpotrf(hipsolver_handle(), HIPSOLVER_FILL_MODE_LOWER, A.cols(), A.buffer().current_device_ptr(), A.lda(), @@ -77,11 +77,11 @@ namespace potrf { auto size = A.size(); auto buffer = A.buffer().current_device_ptr(); //std::cout << "device_norm ptr " << buffer << " device " << ttg::device::current_device() << std::endl; -#if defined(TTG_HAVE_CUDA) +#if defined(TTG_ENABLE_CUDA) auto handle = cublas_handle(); //double n = 1.0; cublasDnrm2(handle, size, buffer, 1, norm); - #elif defined(TTG_HAVE_HIPBLAS) + #elif defined(TTG_ENABLE_HIP) hipblasDnrm2(hipblas_handle(), size, buffer, 1, norm); #endif } @@ -288,14 +288,14 @@ namespace potrf { //std::cout << "TRSM [" << K << ", " << M << "] on " << device << std::endl; -#if defined(TTG_HAVE_CUDA) +#if defined(TTG_ENABLE_CUDA) cublasDtrsm(cublas_handle(), CUBLAS_SIDE_RIGHT, CUBLAS_FILL_MODE_LOWER, CUBLAS_OP_T, CUBLAS_DIAG_NON_UNIT, mb, nb, &alpha, tile_kk.buffer().current_device_ptr(), tile_kk.lda(), tile_mk.buffer().current_device_ptr(), tile_mk.lda()); -#elif defined(TTG_HAVE_HIPBLAS) +#elif defined(TTG_ENABLE_HIP) hipblasDtrsm(hipblas_handle(), HIPBLAS_SIDE_RIGHT, HIPBLAS_FILL_MODE_LOWER, HIPBLAS_OP_T, HIPBLAS_DIAG_NON_UNIT, @@ -418,14 +418,14 @@ namespace potrf { //std::cout << "SYRK [" << K << ", " << M << "] on " << device << std::endl; -#if defined(TTG_HAVE_CUDA) +#if defined(TTG_ENABLE_CUDA) cublasDsyrk(cublas_handle(), CUBLAS_FILL_MODE_LOWER, CUBLAS_OP_N, mb, nb, &alpha, tile_mk.buffer().current_device_ptr(), tile_mk.lda(), &beta, tile_kk.buffer().current_device_ptr(), tile_kk.lda()); -#elif defined(TTG_HAVE_HIPBLAS) +#elif defined(TTG_ENABLE_HIP) hipblasDsyrk(hipblas_handle(), HIPBLAS_FILL_MODE_LOWER, HIPBLAS_OP_N, @@ -543,7 +543,7 @@ namespace potrf { double alpha = -1.0; double beta = 1.0; -#if defined(TTG_HAVE_CUDA) +#if defined(TTG_ENABLE_CUDA) cublasDgemm(cublas_handle(), CUBLAS_OP_N, CUBLAS_OP_T, tile_mk.rows(), tile_nk.rows(), @@ -551,7 +551,7 @@ namespace potrf { tile_mk.buffer().current_device_ptr(), tile_mk.lda(), tile_nk.buffer().current_device_ptr(), tile_nk.lda(), &beta, tile_mn.buffer().current_device_ptr(), tile_mn.lda()); -#elif defined(TTG_HAVE_HIPBLAS) +#elif defined(TTG_ENABLE_HIP) hipblasDgemm(hipblas_handle(), HIPBLAS_OP_N, HIPBLAS_OP_T, tile_mk.rows(), tile_nk.rows(), diff --git a/examples/spmm/spmm_cuda.cc b/examples/spmm/spmm_cuda.cc index 90cca96f8..9dcf5928e 100644 --- a/examples/spmm/spmm_cuda.cc +++ b/examples/spmm/spmm_cuda.cc @@ -49,7 +49,7 @@ using namespace ttg; #include "ttg/serialization/std/pair.h" -#if defined(TTG_HAVE_LEVEL_ZERO) +#if defined(TTG_ENABLE_LEVEL_ZERO) #include #include #endif @@ -254,7 +254,7 @@ struct DeviceTensor : public ttg::TTValue> }; using scalar_t = double; -#if defined(TTG_HAVE_CUDA) || defined(TTG_HAVE_HIPBLAS) +#if defined(TTG_ENABLE_CUDA) || defined(TTG_ENABLE_HIP) using blk_t = DeviceTensor>, btas::Handle::shared_ptr>>; @@ -284,7 +284,7 @@ static void device_gemm(Blk &C, const Blk &A, const Blk &B) { //assert(B.b.get_current_device() != 0); auto device = ttg::device::current_device(); assert(device.is_device()); -#if defined(TTG_HAVE_CUDA) +#if defined(TTG_ENABLE_CUDA) if constexpr (std::is_same_v) { cublasDgemm(cublas_handle(), CUBLAS_OP_N, CUBLAS_OP_N, C.extent(0), C.extent(1), A.extent(1), &alpha, A.b.current_device_ptr(), A.extent(0), B.b.current_device_ptr(), B.extent(0), &beta, @@ -295,7 +295,7 @@ static void device_gemm(Blk &C, const Blk &A, const Blk &B) { &alpha, A.b.current_device_ptr(), A.extent(0), B.b.current_device_ptr(), B.extent(0), &beta, C.b.current_device_ptr(), C.extent(0)); } -#elif defined(TTG_HAVE_HIPBLAS) +#elif defined(TTG_ENABLE_HIP) if constexpr (std::is_same_v) { hipblasDgemm(hipblas_handle(), HIPBLAS_OP_N, HIPBLAS_OP_N, @@ -311,7 +311,7 @@ static void device_gemm(Blk &C, const Blk &A, const Blk &B) { B.b.current_device_ptr(), B.extent(0), &beta, C.b.current_device_ptr(), C.extent(0)); } -#elif defined(TTG_HAVE_LEVEL_ZERO) +#elif defined(TTG_ENABLE_LEVEL_ZERO) #if defined(DEBUG_SYNCHRONOUS) try { @@ -765,13 +765,13 @@ class SpMM25D { public: using baseT = typename MultiplyAdd::ttT; -#if defined(TTG_HAVE_CUDA) +#if defined(TTG_ENABLE_CUDA) static constexpr bool have_cuda_op = true; #warning SPMM using CUDA implementation -#elif defined(TTG_HAVE_HIPBLAS) +#elif defined(TTG_ENABLE_HIP) static constexpr bool have_hip_op = true; #warning SPMM using HIP implementation -#elif defined(TTG_HAVE_LEVEL_ZERO) +#elif defined(TTG_ENABLE_LEVEL_ZERO) static constexpr bool have_level_zero_op = true; #warning SPMM using LEVEL_ZERO implementation #else