Merge pull request #281 from devreal/cuda_hip_guard
Change guards for CUDA/HIP/L0 to TTG_ENABLE_*
devreal authored Jun 4, 2024
2 parents b6c87e5 + 335b8ac commit b7190e9
Showing 4 changed files with 28 additions and 28 deletions.
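The change replaces the availability-style TTG_HAVE_* guards (TTG_HAVE_CUDART, TTG_HAVE_HIPBLAS, TTG_HAVE_CUDA, TTG_HAVE_HIP, TTG_HAVE_LEVEL_ZERO) with TTG_ENABLE_* macros that the example CMake now sets explicitly per target. A minimal sketch of the selection pattern the guarded sources end up with; the printed messages are illustrative, not code from this repository:

#include <cstdio>

int main() {
#if defined(TTG_ENABLE_CUDA)
  std::puts("built with the CUDA backend enabled");
#elif defined(TTG_ENABLE_HIP)
  std::puts("built with the HIP backend enabled");
#elif defined(TTG_ENABLE_LEVEL_ZERO)
  std::puts("built with the Level Zero backend enabled");
#else
  std::puts("no device backend enabled");
#endif
}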
6 changes: 3 additions & 3 deletions examples/CMakeLists.txt
@@ -22,7 +22,7 @@ if (TARGET tiledarray)
if (TARGET CUDA::cublas)
add_ttg_executable(bspmm-cuda spmm/spmm_cuda.cc
LINK_LIBRARIES tiledarray TiledArray_Eigen BTAS CUDA::cublas
-COMPILE_DEFINITIONS BLOCK_SPARSE_GEMM=1;BTAS_TARGET_MAX_INDEX_RANK=2
+COMPILE_DEFINITIONS BLOCK_SPARSE_GEMM=1;BTAS_TARGET_MAX_INDEX_RANK=2;TTG_ENABLE_CUDA=1
RUNTIMES "parsec")

if (TARGET CUDA::cusolver)
@@ -34,7 +34,7 @@ if (TARGET tiledarray)
elseif (TARGET roc::hipblas)
add_ttg_executable(bspmm-hip spmm/spmm_cuda.cc
LINK_LIBRARIES tiledarray TiledArray_Eigen roc::hipblas
-COMPILE_DEFINITIONS BLOCK_SPARSE_GEMM=1;BTAS_TARGET_MAX_INDEX_RANK=2
+COMPILE_DEFINITIONS BLOCK_SPARSE_GEMM=1;BTAS_TARGET_MAX_INDEX_RANK=2;TTG_ENABLE_HIP=1
RUNTIMES "parsec")
if (TARGET roc::hipsolver)
add_ttg_executable(testing_dpotrf_hip potrf/testing_dpotrf.cc
@@ -45,7 +45,7 @@ if (TARGET tiledarray)
elseif (TARGET MKL::MKL_DPCPP)
add_ttg_executable(bspmm-lz spmm/spmm_cuda.cc
LINK_LIBRARIES tiledarray TiledArray_Eigen BTAS MKL::MKL_DPCPP level_zero::ze_loader m
-COMPILE_DEFINITIONS BLOCK_SPARSE_GEMM=1;BTAS_TARGET_MAX_INDEX_RANK=2
+COMPILE_DEFINITIONS BLOCK_SPARSE_GEMM=1;BTAS_TARGET_MAX_INDEX_RANK=2;TTG_ENABLE_LEVEL_ZERO=1
RUNTIMES "parsec")
endif()

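Each COMPILE_DEFINITIONS list above is semicolon-separated, so the backend macro rides along with the existing definitions and reaches the compiler as -DTTG_ENABLE_CUDA=1 (likewise for HIP and Level Zero). A tiny illustrative check that the definition arrived, assuming the =1 value set above:

// Hypothetical translation unit; the static_assert holds only if CMake
// passed -DTTG_ENABLE_CUDA=1 as in the diff above.
#if defined(TTG_ENABLE_CUDA)
static_assert(TTG_ENABLE_CUDA == 1, "TTG_ENABLE_CUDA expected to be 1");
#endif
int main() { return 0; }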
8 changes: 4 additions & 4 deletions examples/devblas_helper.h
@@ -10,7 +10,7 @@
#include <map>
#include <mutex>

-#ifdef TTG_HAVE_CUDART
+#ifdef TTG_ENABLE_CUDA

#include <cuda_runtime.h>
#include <cublas_v2.h>
@@ -87,9 +87,9 @@ inline const cusolverDnHandle_t& cusolver_handle(T _ = 0) {

return it->second;
}
-#endif // TTG_HAVE_CUDART
+#endif // TTG_ENABLE_CUDA

-#ifdef TTG_HAVE_HIPBLAS
+#ifdef TTG_ENABLE_HIP

#include <hip/hip_runtime.h>
#include <hipblas/hipblas.h>
@@ -162,4 +162,4 @@ inline const hipsolverDnHandle_t& hipsolver_handle(T _ = 0) {
}
return it->second;
}
-#endif // TTG_HAVE_HIPBLAS
+#endif // TTG_ENABLE_HIP
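The regions guarded here lazily create and cache vendor library handles, which is why devblas_helper.h includes <map> and <mutex> and returns it->second from a lookup. A self-contained sketch of that caching shape; Handle and make_handle are stand-ins, and keying by device id is an assumption rather than the file's exact scheme:

#include <cstdio>
#include <map>
#include <mutex>

struct Handle { int device; };                  // stand-in for cublasHandle_t etc.
static Handle make_handle(int device) { return Handle{device}; }

inline const Handle& device_handle(int device) {
  static std::map<int, Handle> handles;         // one cached handle per device
  static std::mutex mtx;
  std::lock_guard<std::mutex> lock(mtx);        // serialize lazy creation
  auto it = handles.find(device);
  if (it == handles.end())
    it = handles.emplace(device, make_handle(device)).first;
  return it->second;                            // cached on every later call
}

int main() {
  std::printf("handle bound to device %d\n", device_handle(0).device);
}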
26 changes: 13 additions & 13 deletions examples/potrf/potrf.h
@@ -15,7 +15,7 @@
#define ES ttg::ExecutionSpace::CUDA
#define TASKRET -> ttg::device::Task
#include <cusolverDn.h>
-#elif defined(TTG_HAVE_HIP)
+#elif defined(TTG_ENABLE_HIP)
#define ES ttg::ExecutionSpace::HIP
#define TASKRET -> ttg::device::Task
#include <hipsolver/hipsolver.h>
@@ -35,13 +35,13 @@ namespace potrf {
#if defined(ENABLE_DEVICE_KERNEL)
static int device_potrf_workspace_size(MatrixTile<double> &A) {
int Lwork;
-#if defined(TTG_HAVE_CUDA)
+#if defined(TTG_ENABLE_CUDA)
cusolverDnDpotrf_bufferSize(cusolver_handle(),
CUBLAS_FILL_MODE_LOWER, A.cols(),
nullptr, A.lda(),
&Lwork);
return Lwork;
-#elif defined(TTG_HAVE_HIPBLAS)
+#elif defined(TTG_ENABLE_HIP)
hipsolverDnDpotrf_bufferSize(hipsolver_handle(),
HIPSOLVER_FILL_MODE_LOWER, A.cols(),
nullptr, A.lda(),
@@ -55,7 +55,7 @@ namespace potrf {
static void device_potrf(MatrixTile<double> &A, double *workspace, int Lwork, int *devInfo) {
int device = ttg::device::current_device();
assert(device >= 0);
-#if defined(TTG_HAVE_CUDA)
+#if defined(TTG_ENABLE_CUDA)
//std::cout << "POTRF A " << A.buffer().device_ptr_on(device) << " device " << device << " cols " << A.cols() << " lda " << A.lda() << " Lwork " << Lwork << " WS " << workspace << " devInfo " << devInfo << std::endl;
auto handle = cusolver_handle();
//std::cout << "POTRF handle " << handle << " device " << device << " stream " << ttg::device::current_stream() << std::endl;
@@ -64,7 +64,7 @@ namespace potrf {
A.buffer().current_device_ptr(), A.lda(),
workspace, Lwork,
devInfo);
-#elif defined(TTG_HAVE_HIPBLAS)
+#elif defined(TTG_ENABLE_HIP)
hipsolverDpotrf(hipsolver_handle(),
HIPSOLVER_FILL_MODE_LOWER, A.cols(),
A.buffer().current_device_ptr(), A.lda(),
@@ -77,11 +77,11 @@ namespace potrf {
auto size = A.size();
auto buffer = A.buffer().current_device_ptr();
//std::cout << "device_norm ptr " << buffer << " device " << ttg::device::current_device() << std::endl;
-#if defined(TTG_HAVE_CUDA)
+#if defined(TTG_ENABLE_CUDA)
auto handle = cublas_handle();
//double n = 1.0;
cublasDnrm2(handle, size, buffer, 1, norm);
-#elif defined(TTG_HAVE_HIPBLAS)
+#elif defined(TTG_ENABLE_HIP)
hipblasDnrm2(hipblas_handle(), size, buffer, 1, norm);
#endif
}
@@ -288,14 +288,14 @@ namespace potrf {

//std::cout << "TRSM [" << K << ", " << M << "] on " << device << std::endl;

-#if defined(TTG_HAVE_CUDA)
+#if defined(TTG_ENABLE_CUDA)
cublasDtrsm(cublas_handle(),
CUBLAS_SIDE_RIGHT, CUBLAS_FILL_MODE_LOWER,
CUBLAS_OP_T, CUBLAS_DIAG_NON_UNIT,
mb, nb, &alpha,
tile_kk.buffer().current_device_ptr(), tile_kk.lda(),
tile_mk.buffer().current_device_ptr(), tile_mk.lda());
-#elif defined(TTG_HAVE_HIPBLAS)
+#elif defined(TTG_ENABLE_HIP)
hipblasDtrsm(hipblas_handle(),
HIPBLAS_SIDE_RIGHT, HIPBLAS_FILL_MODE_LOWER,
HIPBLAS_OP_T, HIPBLAS_DIAG_NON_UNIT,
@@ -418,14 +418,14 @@ namespace potrf {

//std::cout << "SYRK [" << K << ", " << M << "] on " << device << std::endl;

-#if defined(TTG_HAVE_CUDA)
+#if defined(TTG_ENABLE_CUDA)
cublasDsyrk(cublas_handle(),
CUBLAS_FILL_MODE_LOWER,
CUBLAS_OP_N,
mb, nb, &alpha,
tile_mk.buffer().current_device_ptr(), tile_mk.lda(), &beta,
tile_kk.buffer().current_device_ptr(), tile_kk.lda());
-#elif defined(TTG_HAVE_HIPBLAS)
+#elif defined(TTG_ENABLE_HIP)
hipblasDsyrk(hipblas_handle(),
HIPBLAS_FILL_MODE_LOWER,
HIPBLAS_OP_N,
@@ -543,15 +543,15 @@ namespace potrf {
double alpha = -1.0;
double beta = 1.0;

-#if defined(TTG_HAVE_CUDA)
+#if defined(TTG_ENABLE_CUDA)
cublasDgemm(cublas_handle(),
CUBLAS_OP_N, CUBLAS_OP_T,
tile_mk.rows(), tile_nk.rows(),
tile_nk.cols(), &alpha,
tile_mk.buffer().current_device_ptr(), tile_mk.lda(),
tile_nk.buffer().current_device_ptr(), tile_nk.lda(), &beta,
tile_mn.buffer().current_device_ptr(), tile_mn.lda());
-#elif defined(TTG_HAVE_HIPBLAS)
+#elif defined(TTG_ENABLE_HIP)
hipblasDgemm(hipblas_handle(),
HIPBLAS_OP_N, HIPBLAS_OP_T,
tile_mk.rows(), tile_nk.rows(),
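Across backends the potrf.h wrappers follow the same two-phase solver protocol seen in the hunks above: query the workspace size (cusolverDnDpotrf_bufferSize / hipsolverDnDpotrf_bufferSize), then factorize with that workspace and an info flag. A host-only sketch of the shape with stand-in calls; the fake_* functions are placeholders, not the vendor APIs:

#include <cassert>
#include <vector>

// Placeholders standing in for cusolverDnDpotrf_bufferSize / cusolverDnDpotrf.
static int fake_potrf_buffer_size(int n) { return 64 * n; }
static int fake_potrf(double*, int, double*, int) { return 0; }

void cholesky_tile(std::vector<double>& A, int n) {
  int lwork = fake_potrf_buffer_size(n);        // phase 1: ask for scratch size
  std::vector<double> workspace(lwork);
  int info = fake_potrf(A.data(), n, workspace.data(), lwork);  // phase 2: factorize
  assert(info == 0);                            // mirrors the devInfo convention
}

int main() {
  std::vector<double> A(16, 1.0);
  cholesky_tile(A, 4);
}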
16 changes: 8 additions & 8 deletions examples/spmm/spmm_cuda.cc
@@ -49,7 +49,7 @@ using namespace ttg;

#include "ttg/serialization/std/pair.h"

-#if defined(TTG_HAVE_LEVEL_ZERO)
+#if defined(TTG_ENABLE_LEVEL_ZERO)
#include <oneapi/mkl.hpp>
#include <sys/time.h>
#endif
@@ -254,7 +254,7 @@ struct DeviceTensor : public ttg::TTValue<DeviceTensor<_T, _Range, _Storage>>
};

using scalar_t = double;
-#if defined(TTG_HAVE_CUDA) || defined(TTG_HAVE_HIPBLAS)
+#if defined(TTG_ENABLE_CUDA) || defined(TTG_ENABLE_HIP)
using blk_t = DeviceTensor<scalar_t, btas::DEFAULT::range,
btas::mohndle<btas::varray<scalar_t, TiledArray::device_pinned_allocator<scalar_t>>,
btas::Handle::shared_ptr>>;
@@ -284,7 +284,7 @@ static void device_gemm(Blk &C, const Blk &A, const Blk &B) {
//assert(B.b.get_current_device() != 0);
auto device = ttg::device::current_device();
assert(device.is_device());
-#if defined(TTG_HAVE_CUDA)
+#if defined(TTG_ENABLE_CUDA)
if constexpr (std::is_same_v<T,double>) {
cublasDgemm(cublas_handle(), CUBLAS_OP_N, CUBLAS_OP_N, C.extent(0), C.extent(1), A.extent(1),
&alpha, A.b.current_device_ptr(), A.extent(0), B.b.current_device_ptr(), B.extent(0), &beta,
@@ -295,7 +295,7 @@ static void device_gemm(Blk &C, const Blk &A, const Blk &B) {
&alpha, A.b.current_device_ptr(), A.extent(0), B.b.current_device_ptr(), B.extent(0), &beta,
C.b.current_device_ptr(), C.extent(0));
}
-#elif defined(TTG_HAVE_HIPBLAS)
+#elif defined(TTG_ENABLE_HIP)
if constexpr (std::is_same_v<T,double>) {
hipblasDgemm(hipblas_handle(),
HIPBLAS_OP_N, HIPBLAS_OP_N,
@@ -311,7 +311,7 @@ static void device_gemm(Blk &C, const Blk &A, const Blk &B) {
B.b.current_device_ptr(), B.extent(0), &beta,
C.b.current_device_ptr(), C.extent(0));
}
-#elif defined(TTG_HAVE_LEVEL_ZERO)
+#elif defined(TTG_ENABLE_LEVEL_ZERO)

#if defined(DEBUG_SYNCHRONOUS)
try {
@@ -765,13 +765,13 @@ class SpMM25D {
public:
using baseT = typename MultiplyAdd::ttT;

-#if defined(TTG_HAVE_CUDA)
+#if defined(TTG_ENABLE_CUDA)
static constexpr bool have_cuda_op = true;
#warning SPMM using CUDA implementation
-#elif defined(TTG_HAVE_HIPBLAS)
+#elif defined(TTG_ENABLE_HIP)
static constexpr bool have_hip_op = true;
#warning SPMM using HIP implementation
-#elif defined(TTG_HAVE_LEVEL_ZERO)
+#elif defined(TTG_ENABLE_LEVEL_ZERO)
static constexpr bool have_level_zero_op = true;
#warning SPMM using LEVEL_ZERO implementation
#else
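The final hunk shows how SpMM25D advertises exactly one compile-time capability flag per build. A reduced, free-standing sketch of that selection; the flag names come from the diff, while the struct and output are illustrative:

#include <cstdio>

struct BackendFlags {
#if defined(TTG_ENABLE_CUDA)
  static constexpr bool have_cuda_op = true;
#elif defined(TTG_ENABLE_HIP)
  static constexpr bool have_hip_op = true;
#elif defined(TTG_ENABLE_LEVEL_ZERO)
  static constexpr bool have_level_zero_op = true;
#else
  static constexpr bool host_only = true;       // CPU fallback, no device op
#endif
};

int main() {
#if defined(TTG_ENABLE_CUDA)
  std::printf("have_cuda_op = %d\n", BackendFlags::have_cuda_op);
#elif defined(TTG_ENABLE_HIP)
  std::printf("have_hip_op = %d\n", BackendFlags::have_hip_op);
#elif defined(TTG_ENABLE_LEVEL_ZERO)
  std::printf("have_level_zero_op = %d\n", BackendFlags::have_level_zero_op);
#else
  std::puts("host-only build");
#endif
}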
