Skip to content

Commit

Permalink
removed cublas v1 + SPMM tiles use pinned allocator ... does not compile
Browse files Browse the repository at this point in the history
yet
  • Loading branch information
evaleev committed Jul 14, 2023
1 parent f6d0afa commit e07140e
Show file tree
Hide file tree
Showing 7 changed files with 29 additions and 71 deletions.
2 changes: 1 addition & 1 deletion examples/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ if (TARGET eigen3)
if (NOT TARGET tiledarray)
message(FATAL_ERROR "TiledArray is not found; it is required for CUDA-based block-sparse SPMM")
endif()
add_ttg_executable(bspmm-cuda spmm/spmm_cuda.cc spmm/cuda_gemm.cc
add_ttg_executable(bspmm-cuda spmm/spmm_cuda.cc
LINK_LIBRARIES tiledarray eigen3 BTAS Boost::boost CUDA::cublas
COMPILE_DEFINITIONS BLOCK_SPARSE_GEMM=1;BTAS_TARGET_MAX_INDEX_RANK=2
RUNTIMES "parsec")
Expand Down
29 changes: 0 additions & 29 deletions examples/spmm/cuda_gemm.cc

This file was deleted.

24 changes: 0 additions & 24 deletions examples/spmm/cuda_gemm.h

This file was deleted.

5 changes: 3 additions & 2 deletions examples/spmm/spmm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,9 @@
#include <btas/features.h>
#ifdef BTAS_IS_USABLE
#include <btas/btas.h>
#include <btas/optimize/contract.h>
#include <btas/util/mohndle.h>
#include <btas/optimize/contract.h>
#include <TiledArray/cuda/allocators.h>
#else
#warning "found btas/features.h but Boost.Iterators is missing, hence BTAS is unusable ... add -I/path/to/boost"
#endif
Expand All @@ -40,7 +41,7 @@ using namespace ttg;
#include "ttg/util/bug.h"

#if defined(BLOCK_SPARSE_GEMM) && defined(BTAS_IS_USABLE)
using blk_t = btas::Tensor<double, btas::DEFAULT::range, btas::mohndle<btas::varray<double>, btas::Handle::shared_ptr>>;
using blk_t = btas::Tensor<double, btas::DEFAULT::range, btas::mohndle<btas::varray<double, TiledArray::cuda_pinned_allocator<double>>, btas::Handle::shared_ptr>>;

#if defined(TTG_USE_PARSEC)
namespace ttg {
Expand Down
33 changes: 21 additions & 12 deletions ttg/ttg/device/cublas_helper.cpp
Original file line number Diff line number Diff line change
@@ -1,23 +1,32 @@
#ifdef TTG_HAVE_CUDART
#include <cublas.h>
#endif // TTG_HAVE_CUDART

#include <exception>
#include <stdexcept>

#include "ttg/config.h"
#include "ttg/device/cublas_helper.h"

#include <memory>
#include <stdexcept>
#include <optional>

namespace ttg::detail {

/* shim wrapper to work around the fact that cublas
* deliberately breaks its API depending on the order
* in which header are included */
#ifdef TTG_HAVE_CUDART
/// \brief Returns the cuBLAS handle to be used for launching cuBLAS kernels from the current thread
/// \return the cuBLAS handle for the current thread
inline const cublasHandle_t& cublas_get_handle() {
static thread_local std::optional<cublasHandle_t> handle;
if (!handle.has_value()) {
auto status = cublasCreate_v2(&handle.emplace());
if (CUBLAS_STATUS_SUCCESS != status) {
throw std::runtime_error("cublasCreate_v2 failed");
}
}
return *handle;
}
#endif // TTG_HAVE_CUDART

void cublas_set_kernel_stream(cudaStream_t stream) {
#ifdef TTG_HAVE_CUDART
cublasStatus_t status = cublasSetKernelStream(stream);
cublasStatus_t status = cublasSetStream_v2(cublas_get_handle(), stream);
if (CUBLAS_STATUS_SUCCESS != status) {
throw std::runtime_error("cublasSetKernelStream failed");
throw std::runtime_error("cublasSetStream_v2 failed");
}
#else
throw std::runtime_error("Support for cublas missing during installation!");
Expand Down
3 changes: 2 additions & 1 deletion ttg/ttg/device/cublas_helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,11 @@
#include "ttg/config.h"

#ifdef TTG_HAVE_CUDART
#include <cublas.h>
#include <cublas_v2.h>

namespace ttg::detail {

/// \brief Returns the current CUDA stream used by cuBLAS
void cublas_set_kernel_stream(cudaStream_t stream);

} // namespace ttg::detail
Expand Down
4 changes: 2 additions & 2 deletions ttg/ttg/parsec/ttg.h
Original file line number Diff line number Diff line change
Expand Up @@ -1362,8 +1362,8 @@ namespace ttg_parsec {
/* TODO: is this the right place to set the mask? */
task->parsec_task.chore_mask = PARSEC_DEV_ALL;
/* get a device and come back if we need another one */
int64_t task_load = 1;
dev_index = parsec_get_best_device(parsec_task, &task_load);
double task_load = 1.;
dev_index = parsec_get_best_device(parsec_task, task_load);
assert(dev_index >= 0);
if (dev_index < 2) {
return PARSEC_HOOK_RETURN_NEXT; /* Fall back */
Expand Down

0 comments on commit e07140e

Please sign in to comment.