Skip to content

Commit

Permalink
Further simplify rocblas_trmm usage
Browse files: browse the repository at this point in the history
  • Loading branch information
msimberg committed Sep 11, 2023
1 parent 1e46a36 commit b12b618
Showing 1 changed file with 0 additions and 17 deletions.
17 changes: 0 additions & 17 deletions include/dlaf/blas/tile.h
Original file line number Diff line number Diff line change
Expand Up @@ -149,12 +149,7 @@ DLAF_MAKE_GPUBLAS_OP(Trmm, trmm);
#pragma GCC diagnostic ignored "-Wdeprecated-declarations"
#endif

// TODO: Confirm the upper HIP_VERSION bound (currently 50700000, exclusive) of the range that needs trmm_outofplace.
#if HIP_VERSION < 50000000 || 50700000 <= HIP_VERSION
DLAF_MAKE_GPUBLAS_OP(Trmm, trmm);
#else
DLAF_MAKE_GPUBLAS_OP(Trmm, trmm_outofplace);
#endif

#if defined(__clang__)
#pragma clang diagnostic pop
Expand Down Expand Up @@ -449,9 +444,7 @@ void trmm(cublasHandle_t handle, const blas::Side side, const blas::Uplo uplo, c
gpublas::internal::Trmm<T>::call(handle, blasToCublas(side), blasToCublas(uplo), blasToCublas(op),
blasToCublas(diag), to_int(s.m), to_int(s.n),
blasToCublasCast(&alpha), blasToCublasCast(a.ptr()), to_int(a.ld()),
#if defined(DLAF_WITH_CUDA) || (defined(DLAF_WITH_HIP) && HIP_VERSION >= 50000000)
blasToCublasCast(b.ptr()), to_int(b.ld()),
#endif
blasToCublasCast(b.ptr()), to_int(b.ld()));
}

Expand All @@ -464,19 +457,9 @@ void trmm3(cublasHandle_t handle, const blas::Side side, const blas::Uplo uplo,
auto s = tile::internal::getTrmm3Sizes(side, a, b, c);
DLAF_ASSERT(b.ptr() == nullptr || b.ptr() != c.ptr(), b.ptr(), c.ptr());

#if defined(DLAF_WITH_HIP) && HIP_VERSION < 50000000
whip::stream_t stream;
DLAF_GPUBLAS_CHECK_ERROR(cublasGetStream(handle, &stream));
matrix::internal::copy(b, c, stream);
#endif

gpublas::internal::Trmm<T>::call(handle, blasToCublas(side), blasToCublas(uplo), blasToCublas(op),
blasToCublas(diag), to_int(s.m), to_int(s.n),
blasToCublasCast(&alpha), blasToCublasCast(a.ptr()), to_int(a.ld()),
#if defined(DLAF_WITH_CUDA) || (defined(DLAF_WITH_HIP) && HIP_VERSION >= 50000000)
blasToCublasCast(b.ptr()), to_int(b.ld()),
#endif
blasToCublasCast(c.ptr()), to_int(c.ld()));
}

template <class T>
Expand Down

0 comments on commit b12b618

Please sign in to comment.