diff --git a/include/dlaf/blas/tile.h b/include/dlaf/blas/tile.h
index 47954f1f57..9c340d0b1d 100644
--- a/include/dlaf/blas/tile.h
+++ b/include/dlaf/blas/tile.h
@@ -137,7 +137,10 @@ DLAF_MAKE_GPUBLAS_SYHE_OP(Her2k, r2k);
 
 DLAF_MAKE_GPUBLAS_SYHE_OP(Herk, rk);
 
-#if defined(DLAF_WITH_HIP)
+#if defined(DLAF_WITH_CUDA)
+DLAF_MAKE_GPUBLAS_OP(Trmm, trmm);
+#elif defined(DLAF_WITH_HIP)
+
 #if defined(__clang__)
 #pragma clang diagnostic push
 #pragma clang diagnostic ignored "-Wdeprecated-declarations"
@@ -145,14 +148,20 @@ DLAF_MAKE_GPUBLAS_SYHE_OP(Herk, rk);
 #pragma GCC diagnostic push
 #pragma GCC diagnostic ignored "-Wdeprecated-declarations"
 #endif
-#endif
+
+// TODO: What will be the upper bound?
+#if HIP_VERSION < 50000000 || 50700000 <= HIP_VERSION
 DLAF_MAKE_GPUBLAS_OP(Trmm, trmm);
-#if defined(DLAF_WITH_HIP)
+#else
+DLAF_MAKE_GPUBLAS_OP(Trmm, trmm_outofplace);
+#endif
+
 #if defined(__clang__)
 #pragma clang diagnostic pop
 #elif defined(__GNUC__)
 #pragma GCC diagnostic pop
 #endif
+
 #endif
 
 DLAF_MAKE_GPUBLAS_OP(Trsm, trsm);
@@ -406,12 +415,7 @@ void her2k(cublasHandle_t handle, const blas::Uplo uplo, blas::Op op, const T al
   using util::blasToCublas;
   using util::blasToCublasCast;
   auto s = getHer2kSizes(op, a, b, c);
-#ifdef DLAF_WITH_HIP
-  // Note:
-  // Up to date the fix for this problem is on rocblas@develop, which should be included in
-  // the next 5.2.0 release.
-  //
-  // https://github.com/ROCmSoftwarePlatform/rocBLAS/commit/e714f1f29ab71dfcdfa4add4462548b34d1cd9e8
+#if defined(DLAF_WITH_HIP) && HIP_VERSION < 50200000
   if (!isComplex_v && op == blas::Op::ConjTrans)
     op = blas::Op::Trans;
 #endif
@@ -445,7 +449,7 @@ void trmm(cublasHandle_t handle, const blas::Side side, const blas::Uplo uplo, c
   gpublas::internal::Trmm::call(handle, blasToCublas(side), blasToCublas(uplo), blasToCublas(op),
                                 blasToCublas(diag), to_int(s.m), to_int(s.n),
                                 blasToCublasCast(&alpha), blasToCublasCast(a.ptr()), to_int(a.ld()),
-#ifdef DLAF_WITH_CUDA
+#if defined(DLAF_WITH_CUDA) || (defined(DLAF_WITH_HIP) && HIP_VERSION >= 50000000)
                                 blasToCublasCast(b.ptr()), to_int(b.ld()),
 #endif
                                 blasToCublasCast(b.ptr()), to_int(b.ld()));
@@ -460,15 +464,16 @@ void trmm3(cublasHandle_t handle, const blas::Side side, const blas::Uplo uplo,
   auto s = tile::internal::getTrmm3Sizes(side, a, b, c);
   DLAF_ASSERT(b.ptr() == nullptr || b.ptr() != c.ptr(), b.ptr(), c.ptr());
 
-#ifdef DLAF_WITH_HIP
+#if defined(DLAF_WITH_HIP) && HIP_VERSION < 50000000
   whip::stream_t stream;
   DLAF_GPUBLAS_CHECK_ERROR(cublasGetStream(handle, &stream));
   matrix::internal::copy(b, c, stream);
 #endif
+
   gpublas::internal::Trmm::call(handle, blasToCublas(side), blasToCublas(uplo), blasToCublas(op),
                                 blasToCublas(diag), to_int(s.m), to_int(s.n),
                                 blasToCublas
Cast(&alpha), blasToCublasCast(a.ptr()), to_int(a.ld()),
-#ifdef DLAF_WITH_CUDA
+#if defined(DLAF_WITH_CUDA) || (defined(DLAF_WITH_HIP) && HIP_VERSION >= 50000000)
                                 blasToCublasCast(b.ptr()), to_int(b.ld()),
 #endif
                                 blasToCublasCast(c.ptr()), to_int(c.ld()));