Skip to content

Commit

Permalink
POTRF: use TTG_ENABLE_CUDA/HIP; the TTG_HAVE_* macros are internal
Browse files: browse the repository at this point in the history
Signed-off-by: Joseph Schuchart <[email protected]>
  • Loading branch information
devreal committed Mar 13, 2024
1 parent 1fd7cf3 commit d2563df
Showing 1 changed file with 14 additions and 14 deletions.
28 changes: 14 additions & 14 deletions examples/potrf/potrf.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,11 @@
#define ENABLE_DEVICE_KERNEL 1
#endif

#if defined(TTG_HAVE_CUDART)
#if defined(TTG_ENABLE_CUDA)
#define ES ttg::ExecutionSpace::CUDA
#define TASKRET -> ttg::device::Task
#include <cusolverDn.h>
#elif defined(TTG_HAVE_HIP)
#elif defined(TTG_ENABLE_HIP)
#define ES ttg::ExecutionSpace::HIP
#define TASKRET -> ttg::device::Task
#include <hipsolver/hipsolver.h>
Expand All @@ -35,13 +35,13 @@ namespace potrf {
#if defined(ENABLE_DEVICE_KERNEL)
static int device_potrf_workspace_size(MatrixTile<double> &A) {
int Lwork;
#if defined(TTG_HAVE_CUDA)
#if defined(TTG_ENABLE_CUDA)
cusolverDnDpotrf_bufferSize(cusolver_handle(),
CUBLAS_FILL_MODE_LOWER, A.cols(),
nullptr, A.lda(),
&Lwork);
return Lwork;
#elif defined(TTG_HAVE_HIPBLAS)
#elif defined(TTG_ENABLE_HIPBLAS)
hipsolverDnDpotrf_bufferSize(hipsolver_handle(),
HIPSOLVER_FILL_MODE_LOWER, A.cols(),
nullptr, A.lda(),
Expand All @@ -55,7 +55,7 @@ namespace potrf {
static void device_potrf(MatrixTile<double> &A, double *workspace, int Lwork, int *devInfo) {
int device = ttg::device::current_device();
assert(device >= 0);
#if defined(TTG_HAVE_CUDA)
#if defined(TTG_ENABLE_CUDA)
//std::cout << "POTRF A " << A.buffer().device_ptr_on(device) << " device " << device << " cols " << A.cols() << " lda " << A.lda() << " Lwork " << Lwork << " WS " << workspace << " devInfo " << devInfo << std::endl;
auto handle = cusolver_handle();
//std::cout << "POTRF handle " << handle << " device " << device << " stream " << ttg::device::current_stream() << std::endl;
Expand All @@ -64,7 +64,7 @@ namespace potrf {
A.buffer().current_device_ptr(), A.lda(),
workspace, Lwork,
devInfo);
#elif defined(TTG_HAVE_HIPBLAS)
#elif defined(TTG_ENABLE_HIPBLAS)
hipsolverDpotrf(hipsolver_handle(),
HIPSOLVER_FILL_MODE_LOWER, A.cols(),
A.buffer().current_device_ptr(), A.lda(),
Expand All @@ -77,11 +77,11 @@ namespace potrf {
auto size = A.size();
auto buffer = A.buffer().current_device_ptr();
//std::cout << "device_norm ptr " << buffer << " device " << ttg::device::current_device() << std::endl;
#if defined(TTG_HAVE_CUDA)
#if defined(TTG_ENABLE_CUDA)
auto handle = cublas_handle();
//double n = 1.0;
cublasDnrm2(handle, size, buffer, 1, norm);
#elif defined(TTG_HAVE_HIPBLAS)
#elif defined(TTG_ENABLE_HIPBLAS)
hipblasDnrm2(hipblas_handle(), size, buffer, 1, norm);
#endif
}
Expand Down Expand Up @@ -288,14 +288,14 @@ namespace potrf {

//std::cout << "TRSM [" << K << ", " << M << "] on " << device << std::endl;

#if defined(TTG_HAVE_CUDA)
#if defined(TTG_ENABLE_CUDA)
cublasDtrsm(cublas_handle(),
CUBLAS_SIDE_RIGHT, CUBLAS_FILL_MODE_LOWER,
CUBLAS_OP_T, CUBLAS_DIAG_NON_UNIT,
mb, nb, &alpha,
tile_kk.buffer().current_device_ptr(), tile_kk.lda(),
tile_mk.buffer().current_device_ptr(), tile_mk.lda());
#elif defined(TTG_HAVE_HIPBLAS)
#elif defined(TTG_ENABLE_HIPBLAS)
hipblasDtrsm(hipblas_handle(),
HIPBLAS_SIDE_RIGHT, HIPBLAS_FILL_MODE_LOWER,
HIPBLAS_OP_T, HIPBLAS_DIAG_NON_UNIT,
Expand Down Expand Up @@ -418,14 +418,14 @@ namespace potrf {

//std::cout << "SYRK [" << K << ", " << M << "] on " << device << std::endl;

#if defined(TTG_HAVE_CUDA)
#if defined(TTG_ENABLE_CUDA)
cublasDsyrk(cublas_handle(),
CUBLAS_FILL_MODE_LOWER,
CUBLAS_OP_N,
mb, nb, &alpha,
tile_mk.buffer().current_device_ptr(), tile_mk.lda(), &beta,
tile_kk.buffer().current_device_ptr(), tile_kk.lda());
#elif defined(TTG_HAVE_HIPBLAS)
#elif defined(TTG_ENABLE_HIPBLAS)
hipblasDsyrk(hipblas_handle(),
HIPBLAS_FILL_MODE_LOWER,
HIPBLAS_OP_N,
Expand Down Expand Up @@ -543,15 +543,15 @@ namespace potrf {
double alpha = -1.0;
double beta = 1.0;

#if defined(TTG_HAVE_CUDA)
#if defined(TTG_ENABLE_CUDA)
cublasDgemm(cublas_handle(),
CUBLAS_OP_N, CUBLAS_OP_T,
tile_mk.rows(), tile_nk.rows(),
tile_nk.cols(), &alpha,
tile_mk.buffer().current_device_ptr(), tile_mk.lda(),
tile_nk.buffer().current_device_ptr(), tile_nk.lda(), &beta,
tile_mn.buffer().current_device_ptr(), tile_mn.lda());
#elif defined(TTG_HAVE_HIPBLAS)
#elif defined(TTG_ENABLE_HIPBLAS)
hipblasDgemm(hipblas_handle(),
HIPBLAS_OP_N, HIPBLAS_OP_T,
tile_mk.rows(), tile_nk.rows(),
Expand Down

0 comments on commit d2563df

Please sign in to comment.