From 175361842b2d2d5c12ed1d712fc19094b1ac2b64 Mon Sep 17 00:00:00 2001 From: "romain.biessy" Date: Tue, 6 Feb 2024 13:59:23 +0000 Subject: [PATCH] WIP cuSPARSE backend --- CMakeLists.txt | 8 +- cmake/FindCompiler.cmake | 4 +- .../sparse_blas_gemv_usm_mklcpu.cpp | 2 +- .../run_time_dispatching/CMakeLists.txt | 2 +- include/oneapi/mkl/detail/backends.hpp | 25 +- include/oneapi/mkl/detail/backends_table.hpp | 6 + include/oneapi/mkl/sparse_blas.hpp | 3 + .../cusparse/onemkl_sparse_blas_cusparse.hpp | 33 +++ .../detail/cusparse/sparse_blas_ct.hpp | 40 +++ .../mkl/sparse_blas/detail/data_types.hpp | 40 +++ .../mkl/sparse_blas/detail/helper_types.hpp | 2 - .../mklcpu/onemkl_sparse_blas_mklcpu.hpp | 1 + .../detail/mklcpu/sparse_blas_ct.hpp | 1 - .../mklgpu/onemkl_sparse_blas_mklgpu.hpp | 1 + .../detail/mklgpu/sparse_blas_ct.hpp | 1 - .../detail/onemkl_sparse_blas_backends.hxx | 124 ++++++--- .../sparse_blas/detail/operation_types.hpp | 36 +++ .../mkl/sparse_blas/detail/sparse_blas_ct.hxx | 145 +++++++--- .../mkl/sparse_blas/detail/sparse_blas_rt.hpp | 100 +++++-- include/oneapi/mkl/sparse_blas/types.hpp | 24 +- src/config.hpp.in | 1 + src/sparse_blas/backends/CMakeLists.txt | 4 + src/sparse_blas/backends/backend_wrappers.cxx | 72 +++-- .../backends/cusparse/CMakeLists.txt | 85 ++++++ .../backends/cusparse/cusparse_data_types.cpp | 258 ++++++++++++++++++ .../backends/cusparse/cusparse_error.hpp | 94 +++++++ .../backends/cusparse/cusparse_handle.hpp | 63 +++++ .../backends/cusparse/cusparse_helper.hpp | 144 ++++++++++ .../cusparse/cusparse_internal_containers.hpp | 211 ++++++++++++++ .../cusparse/cusparse_scope_handle.cpp | 128 +++++++++ .../cusparse/cusparse_scope_handle.hpp | 82 ++++++ .../backends/cusparse/cusparse_task.hpp | 130 +++++++++ .../backends/cusparse/cusparse_wrappers.cpp | 32 +++ .../cusparse/operations/cusparse_gemm.cpp | 74 +++++ .../cusparse/operations/cusparse_gemv.cpp | 60 ++++ .../cusparse/operations/cusparse_trsv.cpp | 196 +++++++++++++ .../backends/mkl_common/mkl_helper.hpp | 5 + src/sparse_blas/function_table.hpp | 124 ++++++--- src/sparse_blas/sparse_blas_loader.cpp | 177 ++++++++---- tests/unit_tests/CMakeLists.txt | 5 + tests/unit_tests/include/test_helper.hpp | 8 + tests/unit_tests/main_test.cpp | 3 +- .../sparse_blas/include/test_common.hpp | 7 - .../sparse_blas/source/CMakeLists.txt | 11 +- .../sparse_blas/source/sparse_trsv_buffer.cpp | 106 +++---- 45 files changed, 2382 insertions(+), 296 deletions(-) create mode 100644 include/oneapi/mkl/sparse_blas/detail/cusparse/onemkl_sparse_blas_cusparse.hpp create mode 100644 include/oneapi/mkl/sparse_blas/detail/cusparse/sparse_blas_ct.hpp create mode 100644 include/oneapi/mkl/sparse_blas/detail/data_types.hpp create mode 100644 include/oneapi/mkl/sparse_blas/detail/operation_types.hpp create mode 100644 src/sparse_blas/backends/cusparse/CMakeLists.txt create mode 100644 src/sparse_blas/backends/cusparse/cusparse_data_types.cpp create mode 100644 src/sparse_blas/backends/cusparse/cusparse_error.hpp create mode 100644 src/sparse_blas/backends/cusparse/cusparse_handle.hpp create mode 100644 src/sparse_blas/backends/cusparse/cusparse_helper.hpp create mode 100644 src/sparse_blas/backends/cusparse/cusparse_internal_containers.hpp create mode 100644 src/sparse_blas/backends/cusparse/cusparse_scope_handle.cpp create mode 100644 src/sparse_blas/backends/cusparse/cusparse_scope_handle.hpp create mode 100644 src/sparse_blas/backends/cusparse/cusparse_task.hpp create mode 100644 src/sparse_blas/backends/cusparse/cusparse_wrappers.cpp create mode 100644 src/sparse_blas/backends/cusparse/operations/cusparse_gemm.cpp create mode 100644 src/sparse_blas/backends/cusparse/operations/cusparse_gemv.cpp create mode 100644 src/sparse_blas/backends/cusparse/operations/cusparse_trsv.cpp diff --git a/CMakeLists.txt b/CMakeLists.txt index ddd4eb3d3..251a5ffee 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -66,6 +66,9 @@ option(ENABLE_CUFFT_BACKEND "Enable the cuFFT backend for the DFT interface" OFF option(ENABLE_ROCFFT_BACKEND "Enable the rocFFT backend for the DFT interface" OFF) option(ENABLE_PORTFFT_BACKEND "Enable the portFFT DFT backend for the DFT interface. Cannot be used with other DFT backends." OFF) +# sparse +option(ENABLE_CUSPARSE_BACKEND "Enable the cuSPARSE backend for the SPARSE_BLAS interface" OFF) + set(ONEMKL_SYCL_IMPLEMENTATION "dpc++" CACHE STRING "Name of the SYCL compiler") set(HIP_TARGETS "" CACHE STRING "Target HIP architectures") @@ -108,7 +111,8 @@ if(ENABLE_MKLGPU_BACKEND list(APPEND DOMAINS_LIST "dft") endif() if(ENABLE_MKLCPU_BACKEND - OR ENABLE_MKLGPU_BACKEND) + OR ENABLE_MKLGPU_BACKEND + OR ENABLE_CUSPARSE_BACKEND) list(APPEND DOMAINS_LIST "sparse_blas") endif() @@ -135,7 +139,7 @@ if(CMAKE_CXX_COMPILER OR NOT ONEMKL_SYCL_IMPLEMENTATION STREQUAL "dpc++") string(REPLACE "\\" "/" CMAKE_CXX_COMPILER ${CMAKE_CXX_COMPILER}) endif() else() - if(ENABLE_CUBLAS_BACKEND OR ENABLE_CURAND_BACKEND OR ENABLE_CUSOLVER_BACKEND OR ENABLE_CUFFT_BACKEND + if(ENABLE_CUBLAS_BACKEND OR ENABLE_CURAND_BACKEND OR ENABLE_CUSOLVER_BACKEND OR ENABLE_CUFFT_BACKEND OR ENABLE_CUSPARSE_BACKEND OR ENABLE_ROCBLAS_BACKEND OR ENABLE_ROCRAND_BACKEND OR ENABLE_ROCSOLVER_BACKEND OR ENABLE_ROCFFT_BACKEND) set(CMAKE_CXX_COMPILER "clang++") elseif(ENABLE_MKLGPU_BACKEND) diff --git a/cmake/FindCompiler.cmake b/cmake/FindCompiler.cmake index 9b1e54a9f..905f26456 100644 --- a/cmake/FindCompiler.cmake +++ b/cmake/FindCompiler.cmake @@ -36,7 +36,7 @@ if(is_dpcpp) if(UNIX) set(UNIX_INTERFACE_COMPILE_OPTIONS -fsycl) set(UNIX_INTERFACE_LINK_OPTIONS -fsycl) - if(ENABLE_CURAND_BACKEND OR ENABLE_CUSOLVER_BACKEND) + if(ENABLE_CURAND_BACKEND OR ENABLE_CUSOLVER_BACKEND OR ENABLE_CUSPARSE_BACKEND) list(APPEND UNIX_INTERFACE_COMPILE_OPTIONS -fsycl-targets=nvptx64-nvidia-cuda -fsycl-unnamed-lambda) list(APPEND UNIX_INTERFACE_LINK_OPTIONS @@ -50,7 +50,7 @@ if(is_dpcpp) -fsycl-targets=amdgcn-amd-amdhsa -Xsycl-target-backend --offload-arch=${HIP_TARGETS}) endif() - if(ENABLE_CURAND_BACKEND OR ENABLE_CUSOLVER_BACKEND OR ENABLE_ROCBLAS_BACKEND + if(ENABLE_CURAND_BACKEND OR ENABLE_CUSOLVER_BACKEND OR ENABLE_CUSPARSE_BACKEND OR ENABLE_ROCBLAS_BACKEND OR ENABLE_ROCRAND_BACKEND OR ENABLE_ROCSOLVER_BACKEND) set_target_properties(ONEMKL::SYCL::SYCL PROPERTIES INTERFACE_COMPILE_OPTIONS "${UNIX_INTERFACE_COMPILE_OPTIONS}" diff --git a/examples/sparse_blas/compile_time_dispatching/sparse_blas_gemv_usm_mklcpu.cpp b/examples/sparse_blas/compile_time_dispatching/sparse_blas_gemv_usm_mklcpu.cpp index edb6d7e1f..79ea3673d 100644 --- a/examples/sparse_blas/compile_time_dispatching/sparse_blas_gemv_usm_mklcpu.cpp +++ b/examples/sparse_blas/compile_time_dispatching/sparse_blas_gemv_usm_mklcpu.cpp @@ -229,7 +229,7 @@ int main(int /*argc*/, char ** /*argv*/) { print_example_banner(); try { - // TODO: Add cuSPARSE compile-time dispatcher in this example once it is supported. + // TODO(Romain): Add cuSPARSE compile-time dispatcher in this example once it is supported. sycl::device cpu_dev(sycl::cpu_selector_v); std::cout << "Running Sparse BLAS GEMV USM example on CPU device." << std::endl; diff --git a/examples/sparse_blas/run_time_dispatching/CMakeLists.txt b/examples/sparse_blas/run_time_dispatching/CMakeLists.txt index fe16587a4..09ffe57b5 100644 --- a/examples/sparse_blas/run_time_dispatching/CMakeLists.txt +++ b/examples/sparse_blas/run_time_dispatching/CMakeLists.txt @@ -30,7 +30,7 @@ set(DEVICE_FILTERS "") if(ENABLE_MKLCPU_BACKEND) list(APPEND DEVICE_FILTERS "cpu") endif() -if(ENABLE_MKLGPU_BACKEND) +if(ENABLE_MKLGPU_BACKEND OR ENABLE_CUSPARSE_BACKEND) list(APPEND DEVICE_FILTERS "gpu") endif() diff --git a/include/oneapi/mkl/detail/backends.hpp b/include/oneapi/mkl/detail/backends.hpp index 32b7c2614..ded06c2e8 100644 --- a/include/oneapi/mkl/detail/backends.hpp +++ b/include/oneapi/mkl/detail/backends.hpp @@ -40,20 +40,27 @@ enum class backend { cufft, rocfft, portfft, + cusparse, unsupported }; typedef std::map backendmap; -static backendmap backend_map = { - { backend::mklcpu, "mklcpu" }, { backend::mklgpu, "mklgpu" }, - { backend::cublas, "cublas" }, { backend::cusolver, "cusolver" }, - { backend::curand, "curand" }, { backend::netlib, "netlib" }, - { backend::rocblas, "rocblas" }, { backend::rocrand, "rocrand" }, - { backend::rocsolver, "rocsolver" }, { backend::portblas, "portblas" }, - { backend::cufft, "cufft" }, { backend::rocfft, "rocfft" }, - { backend::portfft, "portfft" }, { backend::unsupported, "unsupported" } -}; +static backendmap backend_map = { { backend::mklcpu, "mklcpu" }, + { backend::mklgpu, "mklgpu" }, + { backend::cublas, "cublas" }, + { backend::cusolver, "cusolver" }, + { backend::curand, "curand" }, + { backend::netlib, "netlib" }, + { backend::rocblas, "rocblas" }, + { backend::rocrand, "rocrand" }, + { backend::rocsolver, "rocsolver" }, + { backend::portblas, "portblas" }, + { backend::cufft, "cufft" }, + { backend::rocfft, "rocfft" }, + { backend::portfft, "portfft" }, + { backend::cusparse, "cusparse" }, + { backend::unsupported, "unsupported" } }; } //namespace mkl } //namespace oneapi diff --git a/include/oneapi/mkl/detail/backends_table.hpp b/include/oneapi/mkl/detail/backends_table.hpp index 8e68674cc..8a79c5c06 100644 --- a/include/oneapi/mkl/detail/backends_table.hpp +++ b/include/oneapi/mkl/detail/backends_table.hpp @@ -186,6 +186,12 @@ static std::map>> libraries = { #ifdef ENABLE_MKLGPU_BACKEND LIB_NAME("sparse_blas_mklgpu") +#endif + } }, + { device::nvidiagpu, + { +#ifdef ENABLE_CUSPARSE_BACKEND + LIB_NAME("sparse_blas_cusparse") #endif } } } }, }; diff --git a/include/oneapi/mkl/sparse_blas.hpp b/include/oneapi/mkl/sparse_blas.hpp index 912a20eb8..73e6753c7 100644 --- a/include/oneapi/mkl/sparse_blas.hpp +++ b/include/oneapi/mkl/sparse_blas.hpp @@ -34,6 +34,9 @@ #ifdef ENABLE_MKLGPU_BACKEND #include "sparse_blas/detail/mklgpu/sparse_blas_ct.hpp" #endif +#ifdef ENABLE_CUSPARSE_BACKEND +#include "sparse_blas/detail/cusparse/sparse_blas_ct.hpp" +#endif #include "sparse_blas/detail/sparse_blas_rt.hpp" diff --git a/include/oneapi/mkl/sparse_blas/detail/cusparse/onemkl_sparse_blas_cusparse.hpp b/include/oneapi/mkl/sparse_blas/detail/cusparse/onemkl_sparse_blas_cusparse.hpp new file mode 100644 index 000000000..c8e816eeb --- /dev/null +++ b/include/oneapi/mkl/sparse_blas/detail/cusparse/onemkl_sparse_blas_cusparse.hpp @@ -0,0 +1,33 @@ +/*************************************************************************** +* Copyright (C) Codeplay Software Limited +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* For your convenience, a copy of the License has been included in this +* repository. +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +* +**************************************************************************/ + +#ifndef _ONEMKL_SPARSE_BLAS_DETAIL_CUSPARSE_ONEMKL_SPARSE_BLAS_CUSPARSE_HPP_ +#define _ONEMKL_SPARSE_BLAS_DETAIL_CUSPARSE_ONEMKL_SPARSE_BLAS_CUSPARSE_HPP_ + +#include "oneapi/mkl/detail/export.hpp" +#include "oneapi/mkl/sparse_blas/detail/helper_types.hpp" +#include "oneapi/mkl/sparse_blas/types.hpp" + +namespace oneapi::mkl::sparse::cusparse { + +#include "oneapi/mkl/sparse_blas/detail/onemkl_sparse_blas_backends.hxx" + +} // namespace oneapi::mkl::sparse::cusparse + +#endif // _ONEMKL_SPARSE_BLAS_DETAIL_CUSPARSE_ONEMKL_SPARSE_BLAS_CUSPARSE_HPP_ diff --git a/include/oneapi/mkl/sparse_blas/detail/cusparse/sparse_blas_ct.hpp b/include/oneapi/mkl/sparse_blas/detail/cusparse/sparse_blas_ct.hpp new file mode 100644 index 000000000..11abb9a6f --- /dev/null +++ b/include/oneapi/mkl/sparse_blas/detail/cusparse/sparse_blas_ct.hpp @@ -0,0 +1,40 @@ +/*************************************************************************** +* Copyright (C) Codeplay Software Limited +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* For your convenience, a copy of the License has been included in this +* repository. +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +* +**************************************************************************/ + +#ifndef _ONEMKL_SPARSE_BLAS_DETAIL_CUSPARSE_SPARSE_BLAS_CT_HPP_ +#define _ONEMKL_SPARSE_BLAS_DETAIL_CUSPARSE_SPARSE_BLAS_CT_HPP_ + +#include "oneapi/mkl/detail/backends.hpp" +#include "oneapi/mkl/detail/backend_selector.hpp" + +#include "onemkl_sparse_blas_cusparse.hpp" + +namespace oneapi { +namespace mkl { +namespace sparse { + +#define BACKEND cusparse +#include "oneapi/mkl/sparse_blas/detail/sparse_blas_ct.hxx" +#undef BACKEND + +} //namespace sparse +} //namespace mkl +} //namespace oneapi + +#endif // _ONEMKL_SPARSE_BLAS_DETAIL_CUSPARSE_SPARSE_BLAS_CT_HPP_ diff --git a/include/oneapi/mkl/sparse_blas/detail/data_types.hpp b/include/oneapi/mkl/sparse_blas/detail/data_types.hpp new file mode 100644 index 000000000..f351c3493 --- /dev/null +++ b/include/oneapi/mkl/sparse_blas/detail/data_types.hpp @@ -0,0 +1,40 @@ +/*************************************************************************** +* Copyright (C) Codeplay Software Limited +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* For your convenience, a copy of the License has been included in this +* repository. +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +* +**************************************************************************/ + +#ifndef _ONEMKL_SPARSE_BLAS_DETAIL_DATA_TYPES_HPP_ +#define _ONEMKL_SPARSE_BLAS_DETAIL_DATA_TYPES_HPP_ + +namespace oneapi::mkl::sparse { + +namespace detail { + +// Each backend can create its own handle type or re-use the native handle types that will be reinterpret_cast'ed to the types below +struct dense_matrix_handle; +struct dense_vector_handle; +struct matrix_handle; + +} // namespace detail + +typedef struct detail::dense_matrix_handle *dense_matrix_handle_t; +typedef struct detail::dense_vector_handle *dense_vector_handle_t; +typedef struct detail::matrix_handle *matrix_handle_t; + +} // namespace oneapi::mkl::sparse + +#endif // _ONEMKL_SPARSE_BLAS_DETAIL_DATA_TYPES_HPP_ diff --git a/include/oneapi/mkl/sparse_blas/detail/helper_types.hpp b/include/oneapi/mkl/sparse_blas/detail/helper_types.hpp index 4964b1eff..62411cbe5 100644 --- a/include/oneapi/mkl/sparse_blas/detail/helper_types.hpp +++ b/include/oneapi/mkl/sparse_blas/detail/helper_types.hpp @@ -29,8 +29,6 @@ namespace mkl { namespace sparse { namespace detail { -struct matrix_handle; - template inline constexpr bool is_fp_supported_v = std::is_same_v || std::is_same_v || diff --git a/include/oneapi/mkl/sparse_blas/detail/mklcpu/onemkl_sparse_blas_mklcpu.hpp b/include/oneapi/mkl/sparse_blas/detail/mklcpu/onemkl_sparse_blas_mklcpu.hpp index 2535e61f6..8686d35bc 100644 --- a/include/oneapi/mkl/sparse_blas/detail/mklcpu/onemkl_sparse_blas_mklcpu.hpp +++ b/include/oneapi/mkl/sparse_blas/detail/mklcpu/onemkl_sparse_blas_mklcpu.hpp @@ -22,6 +22,7 @@ #include "oneapi/mkl/detail/export.hpp" #include "oneapi/mkl/sparse_blas/detail/helper_types.hpp" +#include "oneapi/mkl/sparse_blas/types.hpp" namespace oneapi::mkl::sparse::mklcpu { diff --git a/include/oneapi/mkl/sparse_blas/detail/mklcpu/sparse_blas_ct.hpp b/include/oneapi/mkl/sparse_blas/detail/mklcpu/sparse_blas_ct.hpp index bc0089c57..ee127c3f8 100644 --- a/include/oneapi/mkl/sparse_blas/detail/mklcpu/sparse_blas_ct.hpp +++ b/include/oneapi/mkl/sparse_blas/detail/mklcpu/sparse_blas_ct.hpp @@ -20,7 +20,6 @@ #ifndef _ONEMKL_SPARSE_BLAS_DETAIL_MKLCPU_SPARSE_BLAS_CT_HPP_ #define _ONEMKL_SPARSE_BLAS_DETAIL_MKLCPU_SPARSE_BLAS_CT_HPP_ -#include "oneapi/mkl/sparse_blas/types.hpp" #include "oneapi/mkl/detail/backends.hpp" #include "oneapi/mkl/detail/backend_selector.hpp" diff --git a/include/oneapi/mkl/sparse_blas/detail/mklgpu/onemkl_sparse_blas_mklgpu.hpp b/include/oneapi/mkl/sparse_blas/detail/mklgpu/onemkl_sparse_blas_mklgpu.hpp index 1ca336b9b..eb3aaa5ff 100644 --- a/include/oneapi/mkl/sparse_blas/detail/mklgpu/onemkl_sparse_blas_mklgpu.hpp +++ b/include/oneapi/mkl/sparse_blas/detail/mklgpu/onemkl_sparse_blas_mklgpu.hpp @@ -22,6 +22,7 @@ #include "oneapi/mkl/detail/export.hpp" #include "oneapi/mkl/sparse_blas/detail/helper_types.hpp" +#include "oneapi/mkl/sparse_blas/types.hpp" namespace oneapi::mkl::sparse::mklgpu { diff --git a/include/oneapi/mkl/sparse_blas/detail/mklgpu/sparse_blas_ct.hpp b/include/oneapi/mkl/sparse_blas/detail/mklgpu/sparse_blas_ct.hpp index 00c01346f..d3b0d365f 100644 --- a/include/oneapi/mkl/sparse_blas/detail/mklgpu/sparse_blas_ct.hpp +++ b/include/oneapi/mkl/sparse_blas/detail/mklgpu/sparse_blas_ct.hpp @@ -20,7 +20,6 @@ #ifndef _ONEMKL_SPARSE_BLAS_DETAIL_MKLGPU_SPARSE_BLAS_CT_HPP_ #define _ONEMKL_SPARSE_BLAS_DETAIL_MKLGPU_SPARSE_BLAS_CT_HPP_ -#include "oneapi/mkl/sparse_blas/types.hpp" #include "oneapi/mkl/detail/backends.hpp" #include "oneapi/mkl/detail/backend_selector.hpp" diff --git a/include/oneapi/mkl/sparse_blas/detail/onemkl_sparse_blas_backends.hxx b/include/oneapi/mkl/sparse_blas/detail/onemkl_sparse_blas_backends.hxx index 03beaa4b4..2954d6370 100644 --- a/include/oneapi/mkl/sparse_blas/detail/onemkl_sparse_blas_backends.hxx +++ b/include/oneapi/mkl/sparse_blas/detail/onemkl_sparse_blas_backends.hxx @@ -20,22 +20,66 @@ // This file is meant to be included in each backend onemkl_sparse_blas_BACKEND.hpp files. // It is used to exports each symbol to the onemkl_sparse_blas_BACKEND library. -ONEMKL_EXPORT void init_matrix_handle(sycl::queue &queue, matrix_handle_t *p_handle); +// Dense vector +template +ONEMKL_EXPORT void create_dense_vector(sycl::queue &queue, dense_vector_handle_t *p_dvhandle, + std::int64_t size, sycl::buffer &val); +template +ONEMKL_EXPORT sycl::event create_dense_vector(sycl::queue &queue, dense_vector_handle_t *p_dvhandle, + std::int64_t size, fpType *val, + const std::vector &dependencies = {}); -ONEMKL_EXPORT sycl::event release_matrix_handle(sycl::queue &queue, matrix_handle_t *p_handle, - const std::vector &dependencies = {}); +// Dense matrix +template +ONEMKL_EXPORT void create_dense_matrix(sycl::queue &queue, dense_matrix_handle_t *p_dmhandle, + std::int64_t num_rows, std::int64_t num_cols, + std::int64_t ld, layout dense_layout, + sycl::buffer &val); +template +ONEMKL_EXPORT sycl::event create_dense_matrix(sycl::queue &queue, dense_matrix_handle_t *p_dmhandle, + std::int64_t num_rows, std::int64_t num_cols, + std::int64_t ld, layout dense_layout, fpType *val, + const std::vector &dependencies = {}); +// CSR matrix template -ONEMKL_EXPORT std::enable_if_t> set_csr_data( - sycl::queue &queue, matrix_handle_t handle, intType num_rows, intType num_cols, intType nnz, - index_base index, sycl::buffer &row_ptr, sycl::buffer &col_ind, - sycl::buffer &val); - +ONEMKL_EXPORT void create_csr_matrix(sycl::queue &queue, matrix_handle_t *p_smhandle, + std::int64_t num_rows, std::int64_t num_cols, std::int64_t nnz, + index_base index, sycl::buffer &row_ptr, + sycl::buffer &col_ind, + sycl::buffer &val); template -ONEMKL_EXPORT std::enable_if_t, sycl::event> -set_csr_data(sycl::queue &queue, matrix_handle_t handle, intType num_rows, intType num_cols, - intType nnz, index_base index, intType *row_ptr, intType *col_ind, fpType *val, - const std::vector &dependencies = {}); +ONEMKL_EXPORT sycl::event create_csr_matrix(sycl::queue &queue, matrix_handle_t *p_smhandle, + std::int64_t num_rows, std::int64_t num_cols, + std::int64_t nnz, index_base index, intType *row_ptr, + intType *col_ind, fpType *val, + const std::vector &dependencies = {}); + +// Destroy data types +ONEMKL_EXPORT sycl::event destroy_dense_vector(sycl::queue &queue, dense_vector_handle_t dvhandle, + const std::vector &dependencies = {}); +ONEMKL_EXPORT sycl::event destroy_dense_matrix(sycl::queue &queue, dense_matrix_handle_t dmhandle, + const std::vector &dependencies = {}); +ONEMKL_EXPORT sycl::event destroy_csr_matrix(sycl::queue &queue, matrix_handle_t smhandle, + const std::vector &dependencies = {}); + +// Matrix property +ONEMKL_EXPORT void set_matrix_property(sycl::queue &queue, matrix_handle_t smhandle, + matrix_property property_value); + +// Operation descriptor +template +ONEMKL_EXPORT void init_trsv_descr(sycl::queue &queue, trsv_descr_t *p_trsv_descr); +ONEMKL_EXPORT sycl::event release_trsv_descr(sycl::queue &queue, trsv_descr_t trsv_descr, + const std::vector &dependencies = {}); + +// Temporary buffer size +ONEMKL_EXPORT sycl::event trsv_buffer_size(sycl::queue &queue, uplo uplo_val, + transpose transpose_val, diag diag_val, + matrix_handle_t A_handle, dense_vector_handle_t x, + dense_vector_handle_t y, trsv_alg alg, + trsv_descr_t trsv_descr, std::int64_t &temp_buffer_size, + const std::vector &dependencies = {}); ONEMKL_EXPORT sycl::event optimize_gemm(sycl::queue &queue, transpose transpose_A, matrix_handle_t handle, @@ -50,42 +94,42 @@ ONEMKL_EXPORT sycl::event optimize_gemv(sycl::queue &queue, transpose transpose_ matrix_handle_t handle, const std::vector &dependencies = {}); +ONEMKL_EXPORT void optimize_trsv(sycl::queue &queue, uplo uplo_val, transpose transpose_val, + diag diag_val, matrix_handle_t A_handle, trsv_alg alg, + trsv_descr_t trsv_descr, std::int64_t temp_buffer_size, + sycl::buffer temp_buffer); ONEMKL_EXPORT sycl::event optimize_trsv(sycl::queue &queue, uplo uplo_val, transpose transpose_val, - diag diag_val, matrix_handle_t handle, + diag diag_val, matrix_handle_t A_handle, trsv_alg alg, + trsv_descr_t trsv_descr, std::int64_t temp_buffer_size, + void *temp_buffer, const std::vector &dependencies = {}); template -ONEMKL_EXPORT std::enable_if_t> gemv( - sycl::queue &queue, transpose transpose_val, const fpType alpha, matrix_handle_t A_handle, - sycl::buffer &x, const fpType beta, sycl::buffer &y); +ONEMKL_EXPORT void gemv(sycl::queue &queue, transpose transpose_val, const fpType alpha, + matrix_handle_t A_handle, sycl::buffer &x, const fpType beta, + sycl::buffer &y); template -ONEMKL_EXPORT std::enable_if_t, sycl::event> gemv( - sycl::queue &queue, transpose transpose_val, const fpType alpha, matrix_handle_t A_handle, - const fpType *x, const fpType beta, fpType *y, - const std::vector &dependencies = {}); +ONEMKL_EXPORT sycl::event gemv(sycl::queue &queue, transpose transpose_val, const fpType alpha, + matrix_handle_t A_handle, const fpType *x, const fpType beta, + fpType *y, const std::vector &dependencies = {}); -template -ONEMKL_EXPORT std::enable_if_t> trsv( - sycl::queue &queue, uplo uplo_val, transpose transpose_val, diag diag_val, - matrix_handle_t A_handle, sycl::buffer &x, sycl::buffer &y); - -template -ONEMKL_EXPORT std::enable_if_t, sycl::event> trsv( - sycl::queue &queue, uplo uplo_val, transpose transpose_val, diag diag_val, - matrix_handle_t A_handle, const fpType *x, fpType *y, - const std::vector &dependencies = {}); +ONEMKL_EXPORT sycl::event trsv(sycl::queue &queue, uplo uplo_val, transpose transpose_val, + diag diag_val, matrix_handle_t A_handle, dense_vector_handle_t x, + dense_vector_handle_t y, trsv_alg alg, trsv_descr_t trsv_descr, + const std::vector &dependencies = {}); template -ONEMKL_EXPORT std::enable_if_t> gemm( - sycl::queue &queue, layout dense_matrix_layout, transpose transpose_A, transpose transpose_B, - const fpType alpha, matrix_handle_t A_handle, sycl::buffer &B, - const std::int64_t columns, const std::int64_t ldb, const fpType beta, - sycl::buffer &C, const std::int64_t ldc); +ONEMKL_EXPORT void gemm(sycl::queue &queue, layout dense_matrix_layout, transpose transpose_A, + transpose transpose_B, const fpType alpha, matrix_handle_t A_handle, + sycl::buffer &B, const std::int64_t columns, + const std::int64_t ldb, const fpType beta, sycl::buffer &C, + const std::int64_t ldc); template -ONEMKL_EXPORT std::enable_if_t, sycl::event> gemm( - sycl::queue &queue, layout dense_matrix_layout, transpose transpose_A, transpose transpose_B, - const fpType alpha, matrix_handle_t A_handle, const fpType *B, const std::int64_t columns, - const std::int64_t ldb, const fpType beta, fpType *C, const std::int64_t ldc, - const std::vector &dependencies = {}); +ONEMKL_EXPORT sycl::event gemm(sycl::queue &queue, layout dense_matrix_layout, + transpose transpose_A, transpose transpose_B, const fpType alpha, + matrix_handle_t A_handle, const fpType *B, + const std::int64_t columns, const std::int64_t ldb, + const fpType beta, fpType *C, const std::int64_t ldc, + const std::vector &dependencies = {}); diff --git a/include/oneapi/mkl/sparse_blas/detail/operation_types.hpp b/include/oneapi/mkl/sparse_blas/detail/operation_types.hpp new file mode 100644 index 000000000..1f52fbec8 --- /dev/null +++ b/include/oneapi/mkl/sparse_blas/detail/operation_types.hpp @@ -0,0 +1,36 @@ +/*************************************************************************** +* Copyright (C) Codeplay Software Limited +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* For your convenience, a copy of the License has been included in this +* repository. +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +* +**************************************************************************/ + +#ifndef _ONEMKL_SPARSE_BLAS_DETAIL_OPERATION_TYPES_HPP_ +#define _ONEMKL_SPARSE_BLAS_DETAIL_OPERATION_TYPES_HPP_ + +namespace oneapi::mkl::sparse { + +namespace detail { + +// Each backend can create its own descriptor type or re-use the native descriptor types that will be reinterpret_cast'ed to the types below +struct trsv_descr; + +} // namespace detail + +typedef struct detail::trsv_descr *trsv_descr_t; + +} // namespace oneapi::mkl::sparse + +#endif // _ONEMKL_SPARSE_BLAS_DETAIL_OPERATION_TYPES_HPP_ diff --git a/include/oneapi/mkl/sparse_blas/detail/sparse_blas_ct.hxx b/include/oneapi/mkl/sparse_blas/detail/sparse_blas_ct.hxx index 41fe51c49..1380a0f44 100644 --- a/include/oneapi/mkl/sparse_blas/detail/sparse_blas_ct.hxx +++ b/include/oneapi/mkl/sparse_blas/detail/sparse_blas_ct.hxx @@ -24,33 +24,101 @@ #error "BACKEND is not defined" #endif -inline void init_matrix_handle(backend_selector selector, - matrix_handle_t *p_handle) { - BACKEND::init_matrix_handle(selector.get_queue(), p_handle); +// Dense vector +template +std::enable_if_t> create_dense_vector( + backend_selector selector, dense_vector_handle_t *p_dvhandle, + std::int64_t size, sycl::buffer &val) { + BACKEND::create_dense_vector(selector.get_queue(), p_dvhandle, size, val); +} +template +std::enable_if_t, sycl::event> create_dense_vector( + backend_selector selector, dense_vector_handle_t *p_dvhandle, + std::int64_t size, fpType *val, const std::vector &dependencies = {}) { + return BACKEND::create_dense_vector(selector.get_queue(), p_dvhandle, size, val, dependencies); } -inline sycl::event release_matrix_handle(backend_selector selector, - matrix_handle_t *p_handle, - const std::vector &dependencies = {}) { - return BACKEND::release_matrix_handle(selector.get_queue(), p_handle, dependencies); +// Dense matrix +template +std::enable_if_t> create_dense_matrix( + backend_selector selector, dense_matrix_handle_t *p_dmhandle, + std::int64_t num_rows, std::int64_t num_cols, std::int64_t ld, layout dense_layout, + sycl::buffer &val) { + BACKEND::create_dense_matrix(selector.get_queue(), p_dmhandle, num_rows, num_cols, ld, + dense_layout, val); +} +template +std::enable_if_t, sycl::event> create_dense_matrix( + backend_selector selector, dense_matrix_handle_t *p_dmhandle, + std::int64_t num_rows, std::int64_t num_cols, std::int64_t ld, layout dense_layout, fpType *val, + const std::vector &dependencies = {}) { + return BACKEND::create_dense_matrix(selector.get_queue(), p_dmhandle, num_rows, num_cols, ld, + dense_layout, val, dependencies); } +// CSR matrix template -std::enable_if_t> set_csr_data( - backend_selector selector, matrix_handle_t handle, intType num_rows, - intType num_cols, intType nnz, index_base index, sycl::buffer &row_ptr, +std::enable_if_t> create_csr_matrix( + backend_selector selector, matrix_handle_t *p_smhandle, std::int64_t num_rows, + std::int64_t num_cols, std::int64_t nnz, index_base index, sycl::buffer &row_ptr, sycl::buffer &col_ind, sycl::buffer &val) { - BACKEND::set_csr_data(selector.get_queue(), handle, num_rows, num_cols, nnz, index, row_ptr, - col_ind, val); + BACKEND::create_csr_matrix(selector.get_queue(), p_smhandle, num_rows, num_cols, nnz, index, + row_ptr, col_ind, val); } - template -std::enable_if_t, sycl::event> set_csr_data( - backend_selector selector, matrix_handle_t handle, intType num_rows, - intType num_cols, intType nnz, index_base index, intType *row_ptr, intType *col_ind, +std::enable_if_t, sycl::event> create_csr_matrix( + backend_selector selector, matrix_handle_t *p_smhandle, std::int64_t num_rows, + std::int64_t num_cols, std::int64_t nnz, index_base index, intType *row_ptr, intType *col_ind, fpType *val, const std::vector &dependencies = {}) { - return BACKEND::set_csr_data(selector.get_queue(), handle, num_rows, num_cols, nnz, index, - row_ptr, col_ind, val, dependencies); + return BACKEND::create_csr_matrix(selector.get_queue(), p_smhandle, num_rows, num_cols, nnz, + index, row_ptr, col_ind, val, dependencies); +} + +// Destroy data types +inline sycl::event destroy_dense_vector(backend_selector selector, + dense_vector_handle_t dvhandle, + const std::vector &dependencies = {}) { + return BACKEND::destroy_dense_vector(selector.get_queue(), dvhandle, dependencies); +} +inline sycl::event destroy_dense_matrix(backend_selector selector, + dense_matrix_handle_t dmhandle, + const std::vector &dependencies = {}) { + return BACKEND::destroy_dense_matrix(selector.get_queue(), dmhandle, dependencies); +} +inline sycl::event destroy_csr_matrix(backend_selector selector, + matrix_handle_t smhandle, + const std::vector &dependencies = {}) { + return BACKEND::destroy_csr_matrix(selector.get_queue(), smhandle, dependencies); +} + +// Matrix property +inline void set_matrix_property(backend_selector selector, + matrix_handle_t smhandle, matrix_property property_value) { + BACKEND::set_matrix_property(selector.get_queue(), smhandle, property_value); +} + +// Operation descriptor +template +std::enable_if_t> init_trsv_descr( + backend_selector selector, trsv_descr_t *p_trsv_descr) { + BACKEND::init_trsv_descr(selector.get_queue(), p_trsv_descr); +} +inline sycl::event release_trsv_descr(backend_selector selector, + trsv_descr_t trsv_descr, + const std::vector &dependencies = {}) { + return BACKEND::release_trsv_descr(selector.get_queue(), trsv_descr, dependencies); +} + +// Temporary buffer size +inline sycl::event trsv_buffer_size(backend_selector selector, uplo uplo_val, + transpose transpose_val, diag diag_val, + matrix_handle_t A_handle, dense_vector_handle_t x, + dense_vector_handle_t y, trsv_alg alg, trsv_descr_t trsv_descr, + std::int64_t &temp_buffer_size, + const std::vector &dependencies = {}) { + return BACKEND::trsv_buffer_size(selector.get_queue(), uplo_val, transpose_val, diag_val, + A_handle, x, y, alg, trsv_descr, temp_buffer_size, + dependencies); } inline sycl::event optimize_gemm(backend_selector selector, transpose transpose_A, @@ -73,11 +141,20 @@ inline sycl::event optimize_gemv(backend_selector selector, return BACKEND::optimize_gemv(selector.get_queue(), transpose_val, handle, dependencies); } +inline void optimize_trsv(backend_selector selector, uplo uplo_val, + transpose transpose_val, diag diag_val, matrix_handle_t A_handle, + trsv_alg alg, trsv_descr_t trsv_descr, std::int64_t temp_buffer_size, + sycl::buffer temp_buffer) { + BACKEND::optimize_trsv(selector.get_queue(), uplo_val, transpose_val, diag_val, A_handle, alg, + trsv_descr, temp_buffer_size, temp_buffer); +} inline sycl::event optimize_trsv(backend_selector selector, uplo uplo_val, - transpose transpose_val, diag diag_val, matrix_handle_t handle, + transpose transpose_val, diag diag_val, matrix_handle_t A_handle, + trsv_alg alg, trsv_descr_t trsv_descr, + std::int64_t temp_buffer_size, void *temp_buffer, const std::vector &dependencies = {}) { - return BACKEND::optimize_trsv(selector.get_queue(), uplo_val, transpose_val, diag_val, handle, - dependencies); + return BACKEND::optimize_trsv(selector.get_queue(), uplo_val, transpose_val, diag_val, A_handle, + alg, trsv_descr, temp_buffer_size, temp_buffer, dependencies); } template @@ -97,23 +174,6 @@ std::enable_if_t, sycl::event> gemv( dependencies); } -template -std::enable_if_t> trsv( - backend_selector selector, uplo uplo_val, transpose transpose_val, - diag diag_val, matrix_handle_t A_handle, sycl::buffer &x, - sycl::buffer &y) { - BACKEND::trsv(selector.get_queue(), uplo_val, transpose_val, diag_val, A_handle, x, y); -} - -template -std::enable_if_t, sycl::event> trsv( - backend_selector selector, uplo uplo_val, transpose transpose_val, - diag diag_val, matrix_handle_t A_handle, const fpType *x, fpType *y, - const std::vector &dependencies = {}) { - return BACKEND::trsv(selector.get_queue(), uplo_val, transpose_val, diag_val, A_handle, x, y, - dependencies); -} - template std::enable_if_t> gemm( backend_selector selector, layout dense_matrix_layout, transpose transpose_A, @@ -133,3 +193,12 @@ std::enable_if_t, sycl::event> gemm( return BACKEND::gemm(selector.get_queue(), dense_matrix_layout, transpose_A, transpose_B, alpha, A_handle, B, columns, ldb, beta, C, ldc, dependencies); } + +inline sycl::event trsv(backend_selector selector, uplo uplo_val, + transpose transpose_val, diag diag_val, matrix_handle_t A_handle, + dense_vector_handle_t x, dense_vector_handle_t y, trsv_alg alg, + trsv_descr_t trsv_descr, + const std::vector &dependencies = {}) { + return BACKEND::trsv(selector.get_queue(), uplo_val, transpose_val, diag_val, A_handle, x, y, + alg, trsv_descr, dependencies); +} diff --git a/include/oneapi/mkl/sparse_blas/detail/sparse_blas_rt.hpp b/include/oneapi/mkl/sparse_blas/detail/sparse_blas_rt.hpp index 131e0545a..2d99b15bd 100644 --- a/include/oneapi/mkl/sparse_blas/detail/sparse_blas_rt.hpp +++ b/include/oneapi/mkl/sparse_blas/detail/sparse_blas_rt.hpp @@ -20,29 +20,80 @@ #ifndef _ONEMKL_SPARSE_BLAS_DETAIL_SPARSE_BLAS_RT_HPP_ #define _ONEMKL_SPARSE_BLAS_DETAIL_SPARSE_BLAS_RT_HPP_ +#include "oneapi/mkl/sparse_blas/detail/helper_types.hpp" #include "oneapi/mkl/sparse_blas/types.hpp" namespace oneapi { namespace mkl { namespace sparse { -void init_matrix_handle(sycl::queue &queue, matrix_handle_t *p_handle); +// TODO(Romain): Make functions create_* and destroy_* synchronous +// TODO(Romain): Update API for gemm and gemv operations -sycl::event release_matrix_handle(sycl::queue &queue, matrix_handle_t *p_handle, - const std::vector &dependencies = {}); - -template -std::enable_if_t> set_csr_data( - sycl::queue &queue, matrix_handle_t handle, intType num_rows, intType num_cols, intType nnz, - index_base index, sycl::buffer &row_ptr, sycl::buffer &col_ind, +// Dense vector +template +std::enable_if_t> create_dense_vector( + sycl::queue &queue, dense_vector_handle_t *p_dvhandle, std::int64_t size, sycl::buffer &val); +template +std::enable_if_t, sycl::event> create_dense_vector( + sycl::queue &queue, dense_vector_handle_t *p_dvhandle, std::int64_t size, fpType *val, + const std::vector &dependencies = {}); +// Dense matrix +template +std::enable_if_t> create_dense_matrix( + sycl::queue &queue, dense_matrix_handle_t *p_dmhandle, std::int64_t num_rows, + std::int64_t num_cols, std::int64_t ld, layout dense_layout, sycl::buffer &val); +template +std::enable_if_t, sycl::event> create_dense_matrix( + sycl::queue &queue, dense_matrix_handle_t *p_dmhandle, std::int64_t num_rows, + std::int64_t num_cols, std::int64_t ld, layout dense_layout, fpType *val, + const std::vector &dependencies = {}); + +// CSR matrix template -std::enable_if_t, sycl::event> set_csr_data( - sycl::queue &queue, matrix_handle_t handle, intType num_rows, intType num_cols, intType nnz, - index_base index, intType *row_ptr, intType *col_ind, fpType *val, +std::enable_if_t> create_csr_matrix( + sycl::queue &queue, matrix_handle_t *p_smhandle, std::int64_t num_rows, std::int64_t num_cols, + std::int64_t nnz, index_base index, sycl::buffer &row_ptr, + sycl::buffer &col_ind, sycl::buffer &val); +template +std::enable_if_t, sycl::event> create_csr_matrix( + sycl::queue &queue, matrix_handle_t *p_smhandle, std::int64_t num_rows, std::int64_t num_cols, + std::int64_t nnz, index_base index, intType *row_ptr, intType *col_ind, fpType *val, const std::vector &dependencies = {}); +// Destroy data types +sycl::event destroy_dense_vector(sycl::queue &queue, dense_vector_handle_t dvhandle, + const std::vector &dependencies = {}); +sycl::event destroy_dense_matrix(sycl::queue &queue, dense_matrix_handle_t dmhandle, + const std::vector &dependencies = {}); +sycl::event destroy_csr_matrix(sycl::queue &queue, matrix_handle_t smhandle, + const std::vector &dependencies = {}); + +// Matrix property +void set_matrix_property(sycl::queue &queue, matrix_handle_t smhandle, + matrix_property property_value); + +// TODO(Romain): Add support for setting matrices and vector data + +// Operation descriptor +template +std::enable_if_t> init_trsv_descr(sycl::queue &queue, + trsv_descr_t *p_trsv_descr); +// TODO(Romain): Make synchronous +sycl::event release_trsv_descr(sycl::queue &queue, trsv_descr_t trsv_descr, + const std::vector &dependencies = {}); + +// Temporary buffer size +// TODO(Romain): Make synchronous +sycl::event trsv_buffer_size(sycl::queue &queue, uplo uplo_val, transpose transpose_val, + diag diag_val, matrix_handle_t A_handle, dense_vector_handle_t x, + dense_vector_handle_t y, trsv_alg alg, trsv_descr_t trsv_descr, + std::int64_t &temp_buffer_size, + const std::vector &dependencies = {}); + +// Optimize sycl::event optimize_gemm(sycl::queue &queue, transpose transpose_A, matrix_handle_t handle, const std::vector &dependencies = {}); @@ -54,10 +105,17 @@ sycl::event optimize_gemm(sycl::queue &queue, transpose transpose_A, transpose t sycl::event optimize_gemv(sycl::queue &queue, transpose transpose_val, matrix_handle_t handle, const std::vector &dependencies = {}); +// TODO(Romain): Make temp_buffer_size unsigned +void optimize_trsv(sycl::queue &queue, uplo uplo_val, transpose transpose_val, diag diag_val, + matrix_handle_t A_handle, trsv_alg alg, trsv_descr_t trsv_descr, + std::int64_t temp_buffer_size, sycl::buffer temp_buffer); +// TODO(Romain): Make synchronous? sycl::event optimize_trsv(sycl::queue &queue, uplo uplo_val, transpose transpose_val, diag diag_val, - matrix_handle_t handle, + matrix_handle_t A_handle, trsv_alg alg, trsv_descr_t trsv_descr, + std::int64_t temp_buffer_size, void *temp_buffer, const std::vector &dependencies = {}); +// Operations template std::enable_if_t> gemv( sycl::queue &queue, transpose transpose_val, const fpType alpha, matrix_handle_t A_handle, @@ -69,19 +127,6 @@ std::enable_if_t, sycl::event> gemv( const fpType *x, const fpType beta, fpType *y, const std::vector &dependencies = {}); -template -std::enable_if_t> trsv(sycl::queue &queue, uplo uplo_val, - transpose transpose_val, diag diag_val, - matrix_handle_t A_handle, - sycl::buffer &x, - sycl::buffer &y); - -template -std::enable_if_t, sycl::event> trsv( - sycl::queue &queue, uplo uplo_val, transpose transpose_val, diag diag_val, - matrix_handle_t A_handle, const fpType *x, fpType *y, - const std::vector &dependencies = {}); - template std::enable_if_t> gemm( sycl::queue &queue, layout dense_matrix_layout, transpose transpose_A, transpose transpose_B, @@ -96,6 +141,11 @@ std::enable_if_t, sycl::event> gemm( const std::int64_t ldb, const fpType beta, fpType *C, const std::int64_t ldc, const std::vector &dependencies = {}); +sycl::event trsv(sycl::queue &queue, uplo uplo_val, transpose transpose_val, diag diag_val, + matrix_handle_t A_handle, dense_vector_handle_t x, dense_vector_handle_t y, + trsv_alg alg, trsv_descr_t trsv_descr, + const std::vector &dependencies = {}); + } // namespace sparse } // namespace mkl } // namespace oneapi diff --git a/include/oneapi/mkl/sparse_blas/types.hpp b/include/oneapi/mkl/sparse_blas/types.hpp index 406c7dd1f..12dbfebed 100644 --- a/include/oneapi/mkl/sparse_blas/types.hpp +++ b/include/oneapi/mkl/sparse_blas/types.hpp @@ -20,22 +20,26 @@ #ifndef _ONEMKL_SPARSE_BLAS_TYPES_HPP_ #define _ONEMKL_SPARSE_BLAS_TYPES_HPP_ -#if __has_include() -#include -#else -#include -#endif - -#include - #include "oneapi/mkl/types.hpp" -#include "detail/helper_types.hpp" +#include "detail/data_types.hpp" +#include "detail/operation_types.hpp" + +/** + * @file Include and define the sparse types that are common between close-source MKL API and oneMKL API. +*/ namespace oneapi { namespace mkl { namespace sparse { -using matrix_handle_t = detail::matrix_handle*; +enum class matrix_property : char { + symmetric = 0x00, + sorted = 0x01, /* CSR, CSC, BSR only */ +}; + +enum class trsv_alg : char { + default_alg = 0x00, +}; } // namespace sparse } // namespace mkl diff --git a/src/config.hpp.in b/src/config.hpp.in index 5698abf9b..fd55006a6 100644 --- a/src/config.hpp.in +++ b/src/config.hpp.in @@ -24,6 +24,7 @@ #cmakedefine ENABLE_CUFFT_BACKEND #cmakedefine ENABLE_CURAND_BACKEND #cmakedefine ENABLE_CUSOLVER_BACKEND +#cmakedefine ENABLE_CUSPARSE_BACKEND #cmakedefine ENABLE_MKLCPU_BACKEND #cmakedefine ENABLE_MKLGPU_BACKEND #cmakedefine ENABLE_NETLIB_BACKEND diff --git a/src/sparse_blas/backends/CMakeLists.txt b/src/sparse_blas/backends/CMakeLists.txt index ef606c6e1..9a226d78d 100644 --- a/src/sparse_blas/backends/CMakeLists.txt +++ b/src/sparse_blas/backends/CMakeLists.txt @@ -27,3 +27,7 @@ endif() if(ENABLE_MKLGPU_BACKEND) add_subdirectory(mklgpu) endif() + +if(ENABLE_CUSPARSE_BACKEND) + add_subdirectory(cusparse) +endif() diff --git a/src/sparse_blas/backends/backend_wrappers.cxx b/src/sparse_blas/backends/backend_wrappers.cxx index 2c8161249..11099733a 100644 --- a/src/sparse_blas/backends/backend_wrappers.cxx +++ b/src/sparse_blas/backends/backend_wrappers.cxx @@ -33,31 +33,58 @@ extern "C" sparse_blas_function_table_t mkl_sparse_blas_table = { Changes to this file should be matched to changes in sparse_blas/function_table.hpp. The required function template instantiations must be added to backend_sparse_blas_instantiations.cxx. + +Functions that cannot have their template parameter automatically deduced must be explicitly specified here. */ // clang-format off -oneapi::mkl::sparse::BACKEND::init_matrix_handle, -oneapi::mkl::sparse::BACKEND::release_matrix_handle, -oneapi::mkl::sparse::BACKEND::set_csr_data, -oneapi::mkl::sparse::BACKEND::set_csr_data, -oneapi::mkl::sparse::BACKEND::set_csr_data, -oneapi::mkl::sparse::BACKEND::set_csr_data, -oneapi::mkl::sparse::BACKEND::set_csr_data, -oneapi::mkl::sparse::BACKEND::set_csr_data, -oneapi::mkl::sparse::BACKEND::set_csr_data, -oneapi::mkl::sparse::BACKEND::set_csr_data, -oneapi::mkl::sparse::BACKEND::set_csr_data, -oneapi::mkl::sparse::BACKEND::set_csr_data, -oneapi::mkl::sparse::BACKEND::set_csr_data, -oneapi::mkl::sparse::BACKEND::set_csr_data, -oneapi::mkl::sparse::BACKEND::set_csr_data, -oneapi::mkl::sparse::BACKEND::set_csr_data, -oneapi::mkl::sparse::BACKEND::set_csr_data, -oneapi::mkl::sparse::BACKEND::set_csr_data, +oneapi::mkl::sparse::BACKEND::create_dense_vector, +oneapi::mkl::sparse::BACKEND::create_dense_vector, +oneapi::mkl::sparse::BACKEND::create_dense_vector, +oneapi::mkl::sparse::BACKEND::create_dense_vector, +oneapi::mkl::sparse::BACKEND::create_dense_vector, +oneapi::mkl::sparse::BACKEND::create_dense_vector, +oneapi::mkl::sparse::BACKEND::create_dense_vector, +oneapi::mkl::sparse::BACKEND::create_dense_vector, +oneapi::mkl::sparse::BACKEND::create_dense_matrix, +oneapi::mkl::sparse::BACKEND::create_dense_matrix, +oneapi::mkl::sparse::BACKEND::create_dense_matrix, +oneapi::mkl::sparse::BACKEND::create_dense_matrix, +oneapi::mkl::sparse::BACKEND::create_dense_matrix, +oneapi::mkl::sparse::BACKEND::create_dense_matrix, +oneapi::mkl::sparse::BACKEND::create_dense_matrix, +oneapi::mkl::sparse::BACKEND::create_dense_matrix, +oneapi::mkl::sparse::BACKEND::create_csr_matrix, +oneapi::mkl::sparse::BACKEND::create_csr_matrix, +oneapi::mkl::sparse::BACKEND::create_csr_matrix, +oneapi::mkl::sparse::BACKEND::create_csr_matrix, +oneapi::mkl::sparse::BACKEND::create_csr_matrix, +oneapi::mkl::sparse::BACKEND::create_csr_matrix, +oneapi::mkl::sparse::BACKEND::create_csr_matrix, +oneapi::mkl::sparse::BACKEND::create_csr_matrix, +oneapi::mkl::sparse::BACKEND::create_csr_matrix, +oneapi::mkl::sparse::BACKEND::create_csr_matrix, +oneapi::mkl::sparse::BACKEND::create_csr_matrix, +oneapi::mkl::sparse::BACKEND::create_csr_matrix, +oneapi::mkl::sparse::BACKEND::create_csr_matrix, +oneapi::mkl::sparse::BACKEND::create_csr_matrix, +oneapi::mkl::sparse::BACKEND::create_csr_matrix, +oneapi::mkl::sparse::BACKEND::create_csr_matrix, +oneapi::mkl::sparse::BACKEND::destroy_dense_vector, +oneapi::mkl::sparse::BACKEND::destroy_dense_matrix, +oneapi::mkl::sparse::BACKEND::destroy_csr_matrix, +oneapi::mkl::sparse::BACKEND::set_matrix_property, +oneapi::mkl::sparse::BACKEND::init_trsv_descr, +oneapi::mkl::sparse::BACKEND::init_trsv_descr, +oneapi::mkl::sparse::BACKEND::init_trsv_descr>, +oneapi::mkl::sparse::BACKEND::init_trsv_descr>, +oneapi::mkl::sparse::BACKEND::release_trsv_descr, +oneapi::mkl::sparse::BACKEND::trsv_buffer_size, oneapi::mkl::sparse::BACKEND::optimize_gemm, oneapi::mkl::sparse::BACKEND::optimize_gemm, oneapi::mkl::sparse::BACKEND::optimize_gemv, oneapi::mkl::sparse::BACKEND::optimize_trsv, +oneapi::mkl::sparse::BACKEND::optimize_trsv, oneapi::mkl::sparse::BACKEND::gemv, oneapi::mkl::sparse::BACKEND::gemv, oneapi::mkl::sparse::BACKEND::gemv, @@ -66,14 +93,6 @@ oneapi::mkl::sparse::BACKEND::gemv, oneapi::mkl::sparse::BACKEND::gemv, oneapi::mkl::sparse::BACKEND::gemv, oneapi::mkl::sparse::BACKEND::gemv, -oneapi::mkl::sparse::BACKEND::trsv, -oneapi::mkl::sparse::BACKEND::trsv, -oneapi::mkl::sparse::BACKEND::trsv, -oneapi::mkl::sparse::BACKEND::trsv, -oneapi::mkl::sparse::BACKEND::trsv, -oneapi::mkl::sparse::BACKEND::trsv, -oneapi::mkl::sparse::BACKEND::trsv, -oneapi::mkl::sparse::BACKEND::trsv, oneapi::mkl::sparse::BACKEND::gemm, oneapi::mkl::sparse::BACKEND::gemm, oneapi::mkl::sparse::BACKEND::gemm, @@ -82,4 +101,5 @@ oneapi::mkl::sparse::BACKEND::gemm, oneapi::mkl::sparse::BACKEND::gemm, oneapi::mkl::sparse::BACKEND::gemm, oneapi::mkl::sparse::BACKEND::gemm, +oneapi::mkl::sparse::BACKEND::trsv, // clang-format on diff --git a/src/sparse_blas/backends/cusparse/CMakeLists.txt b/src/sparse_blas/backends/cusparse/CMakeLists.txt new file mode 100644 index 000000000..b528a67af --- /dev/null +++ b/src/sparse_blas/backends/cusparse/CMakeLists.txt @@ -0,0 +1,85 @@ +#=============================================================================== +# Copyright 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions +# and limitations under the License. +# +# +# SPDX-License-Identifier: Apache-2.0 +#=============================================================================== + +set(LIB_NAME onemkl_sparse_blas_cusparse) +set(LIB_OBJ ${LIB_NAME}_obj) + +include(WarningsUtils) + +add_library(${LIB_NAME}) +add_library(${LIB_OBJ} OBJECT + cusparse_data_types.cpp + cusparse_scope_handle.cpp + operations/cusparse_gemm.cpp + operations/cusparse_gemv.cpp + operations/cusparse_trsv.cpp + $<$: cusparse_wrappers.cpp> +) +add_dependencies(onemkl_backend_libs_sparse_blas ${LIB_NAME}) + +target_include_directories(${LIB_OBJ} + PRIVATE ${PROJECT_SOURCE_DIR}/include + ${PROJECT_SOURCE_DIR}/src + ${CMAKE_BINARY_DIR}/bin + ${ONEMKL_GENERATED_INCLUDE_PATH} +) + +target_compile_options(${LIB_OBJ} PRIVATE ${ONEMKL_BUILD_COPT}) + +if (${CMAKE_VERSION} VERSION_LESS "3.17.0") + find_package(CUDA REQUIRED) + target_include_directories(${LIB_OBJ} PRIVATE ${CUDA_INCLUDE_DIRS}) + target_link_libraries(${LIB_OBJ} PRIVATE cuda ${CUDA_CUSPARSE_LIBRARIES}) +else() + find_package(CUDAToolkit REQUIRED) + target_link_libraries(${LIB_OBJ} PRIVATE CUDA::cusparse CUDA::cuda_driver) +endif() + +target_link_libraries(${LIB_OBJ} + PUBLIC ONEMKL::SYCL::SYCL + PRIVATE onemkl_warnings +) + +set_target_properties(${LIB_OBJ} PROPERTIES + POSITION_INDEPENDENT_CODE ON +) +target_link_libraries(${LIB_NAME} PUBLIC ${LIB_OBJ}) + +#Set oneMKL libraries as not transitive for dynamic +if(BUILD_SHARED_LIBS) + set_target_properties(${LIB_NAME} PROPERTIES + INTERFACE_LINK_LIBRARIES ONEMKL::SYCL::SYCL + ) +endif() + +# Add major version to the library +set_target_properties(${LIB_NAME} PROPERTIES + SOVERSION ${PROJECT_VERSION_MAJOR} +) + +# Add dependencies rpath to the library +list(APPEND CMAKE_BUILD_RPATH $) + +# Add the library to install package +install(TARGETS ${LIB_OBJ} EXPORT oneMKLTargets) +install(TARGETS ${LIB_NAME} EXPORT oneMKLTargets + RUNTIME DESTINATION bin + ARCHIVE DESTINATION lib + LIBRARY DESTINATION lib +) diff --git a/src/sparse_blas/backends/cusparse/cusparse_data_types.cpp b/src/sparse_blas/backends/cusparse/cusparse_data_types.cpp new file mode 100644 index 000000000..212654732 --- /dev/null +++ b/src/sparse_blas/backends/cusparse/cusparse_data_types.cpp @@ -0,0 +1,258 @@ +/*************************************************************************** +* Copyright (C) Codeplay Software Limited +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* For your convenience, a copy of the License has been included in this +* repository. +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +* +**************************************************************************/ + +#include "oneapi/mkl/sparse_blas/detail/cusparse/onemkl_sparse_blas_cusparse.hpp" + +#include "cusparse_error.hpp" +#include "cusparse_helper.hpp" +#include "cusparse_internal_containers.hpp" +#include "cusparse_task.hpp" +#include "sparse_blas/macros.hpp" + +namespace oneapi::mkl::sparse::cusparse { + +using namespace oneapi::mkl::sparse::detail; + +/** + * In this file CusparseScopedContextHandler are used to ensure that a cusparseHandle_t is created before any other cuSPARSE call, as required by the specification. +*/ + +// Dense vector +template +void create_dense_vector(sycl::queue &queue, dense_vector_handle_t *p_dvhandle, std::int64_t size, + sycl::buffer &val) { + auto event = queue.submit([&](sycl::handler &cgh) { + auto acc = val.template get_access(cgh); + detail::submit_host_task(cgh, queue, [=](detail::CusparseScopedContextHandler &sc) { + // Ensure that a cusparse handle is created before any other cuSPARSE function is called. + sc.get_handle(queue); + auto cuda_type = detail::CudaEnumType::value; + cusparseDnVecDescr_t cu_dvhandle; + auto status = cusparseCreateDnVec(&cu_dvhandle, size, sc.get_mem(acc), cuda_type); + detail::check_status(status, "create_dense_vector"); + auto internal_dvhandle = new detail::dense_vector_handle(cu_dvhandle, val); + *p_dvhandle = reinterpret_cast(internal_dvhandle); + }); + }); + event.wait_and_throw(); +} + +template +sycl::event create_dense_vector(sycl::queue &queue, dense_vector_handle_t *p_dvhandle, + std::int64_t size, fpType *val, + const std::vector &dependencies) { + return queue.submit([&](sycl::handler &cgh) { + cgh.depends_on(dependencies); + detail::submit_host_task(cgh, queue, [=](detail::CusparseScopedContextHandler &sc) { + // Ensure that a cusparse handle is created before any other cuSPARSE function is called. + sc.get_handle(queue); + auto cuda_type = detail::CudaEnumType::value; + cusparseDnVecDescr_t cu_dvhandle; + auto status = cusparseCreateDnVec(&cu_dvhandle, size, sc.get_mem(val), cuda_type); + detail::check_status(status, "create_dense_vector"); + auto internal_dvhandle = new detail::dense_vector_handle(cu_dvhandle, val); + *p_dvhandle = reinterpret_cast(internal_dvhandle); + }); + }); +} + +#define INSTANTIATE_CREATE_DENSE_VECTOR(FP_TYPE, FP_SUFFIX) \ + template std::enable_if_t> create_dense_vector( \ + sycl::queue & queue, dense_vector_handle_t * p_dvhandle, std::int64_t size, \ + sycl::buffer & val); \ + template std::enable_if_t, sycl::event> \ + create_dense_vector(sycl::queue & queue, dense_vector_handle_t * p_dvhandle, \ + std::int64_t size, FP_TYPE * val, \ + const std::vector &dependencies) +FOR_EACH_FP_TYPE(INSTANTIATE_CREATE_DENSE_VECTOR); +#undef INSTANTIATE_CREATE_DENSE_VECTOR + +// Dense matrix +template +void create_dense_matrix(sycl::queue &queue, dense_matrix_handle_t *p_dmhandle, + std::int64_t num_rows, std::int64_t num_cols, std::int64_t ld, + layout dense_layout, sycl::buffer &val) { + auto event = queue.submit([&](sycl::handler &cgh) { + auto acc = val.template get_access(cgh); + detail::submit_host_task(cgh, queue, [=](detail::CusparseScopedContextHandler &sc) { + // Ensure that a cusparse handle is created before any other cuSPARSE function is called. + sc.get_handle(queue); + auto cuda_type = detail::CudaEnumType::value; + auto cuda_order = detail::get_cuda_order(dense_layout); + cusparseDnMatDescr_t cu_dmhandle; + auto status = cusparseCreateDnMat(&cu_dmhandle, num_rows, num_cols, ld, sc.get_mem(acc), + cuda_type, cuda_order); + detail::check_status(status, "create_dense_matrix"); + auto internal_dmhandle = new detail::dense_matrix_handle(cu_dmhandle, val); + *p_dmhandle = reinterpret_cast(internal_dmhandle); + }); + }); + event.wait_and_throw(); +} +template +sycl::event create_dense_matrix(sycl::queue &queue, dense_matrix_handle_t *p_dmhandle, + std::int64_t num_rows, std::int64_t num_cols, std::int64_t ld, + layout dense_layout, fpType *val, + const std::vector &dependencies) { + return queue.submit([&](sycl::handler &cgh) { + cgh.depends_on(dependencies); + detail::submit_host_task(cgh, queue, [=](detail::CusparseScopedContextHandler &sc) { + // Ensure that a cusparse handle is created before any other cuSPARSE function is called. + sc.get_handle(queue); + auto cuda_type = detail::CudaEnumType::value; + auto cuda_order = detail::get_cuda_order(dense_layout); + cusparseDnMatDescr_t cu_dmhandle; + auto status = cusparseCreateDnMat(&cu_dmhandle, num_rows, num_cols, ld, sc.get_mem(val), + cuda_type, cuda_order); + detail::check_status(status, "create_dense_matrix"); + auto internal_dmhandle = new detail::dense_matrix_handle(cu_dmhandle, val); + *p_dmhandle = reinterpret_cast(internal_dmhandle); + }); + }); +} + +#define INSTANTIATE_CREATE_DENSE_MATRIX(FP_TYPE, FP_SUFFIX) \ + template std::enable_if_t> create_dense_matrix( \ + sycl::queue & queue, dense_matrix_handle_t * p_dmhandle, std::int64_t num_rows, \ + std::int64_t num_cols, std::int64_t ld, layout dense_layout, \ + sycl::buffer & val); \ + template std::enable_if_t, sycl::event> \ + create_dense_matrix(sycl::queue & queue, dense_matrix_handle_t * p_dmhandle, \ + std::int64_t num_rows, std::int64_t num_cols, std::int64_t ld, \ + layout dense_layout, FP_TYPE * val, \ + const std::vector &dependencies) +FOR_EACH_FP_TYPE(INSTANTIATE_CREATE_DENSE_MATRIX); +#undef INSTANTIATE_CREATE_DENSE_MATRIX + +// CSR matrix +template +void create_csr_matrix(sycl::queue &queue, matrix_handle_t *p_smhandle, std::int64_t num_rows, + std::int64_t num_cols, std::int64_t nnz, index_base index, + sycl::buffer &row_ptr, sycl::buffer &col_ind, + sycl::buffer &val) { + auto event = queue.submit([&](sycl::handler &cgh) { + auto row_acc = row_ptr.template get_access(cgh); + auto col_acc = col_ind.template get_access(cgh); + auto val_acc = val.template get_access(cgh); + detail::submit_host_task(cgh, queue, [=](detail::CusparseScopedContextHandler &sc) { + // Ensure that a cusparse handle is created before any other cuSPARSE function is called. + sc.get_handle(queue); + auto cuda_type = detail::CudaEnumType::value; + auto cudaIndexType = detail::CudaIndexEnumType::value; + auto cudaIndexBase = detail::get_cuda_index_base(index); + cusparseSpMatDescr_t cu_smhandle; + auto status = cusparseCreateCsr( + &cu_smhandle, num_rows, num_cols, nnz, sc.get_mem(row_acc), sc.get_mem(col_acc), + sc.get_mem(val_acc), cudaIndexType, cudaIndexType, cudaIndexBase, cuda_type); + detail::check_status(status, "create_csr_matrix"); + auto internal_smhandle = + new detail::sparse_matrix_handle(cu_smhandle, val, row_ptr, col_ind); + *p_smhandle = reinterpret_cast(internal_smhandle); + }); + }); + event.wait_and_throw(); +} + +template +sycl::event create_csr_matrix(sycl::queue &queue, matrix_handle_t *p_smhandle, + std::int64_t num_rows, std::int64_t num_cols, std::int64_t nnz, + index_base index, intType *row_ptr, intType *col_ind, fpType *val, + const std::vector &dependencies) { + return queue.submit([&](sycl::handler &cgh) { + cgh.depends_on(dependencies); + detail::submit_host_task(cgh, queue, [=](detail::CusparseScopedContextHandler &sc) { + // Ensure that a cusparse handle is created before any other cuSPARSE function is called. + sc.get_handle(queue); + auto cuda_type = detail::CudaEnumType::value; + auto cudaIndexType = detail::CudaIndexEnumType::value; + auto cudaIndexBase = detail::get_cuda_index_base(index); + cusparseSpMatDescr_t cu_smhandle; + auto status = cusparseCreateCsr( + &cu_smhandle, num_rows, num_cols, nnz, sc.get_mem(row_ptr), sc.get_mem(col_ind), + sc.get_mem(val), cudaIndexType, cudaIndexType, cudaIndexBase, cuda_type); + detail::check_status(status, "create_csr_matrix"); + auto internal_smhandle = + new detail::sparse_matrix_handle(cu_smhandle, val, row_ptr, col_ind); + *p_smhandle = reinterpret_cast(internal_smhandle); + }); + }); +} + +#define INSTANTIATE_CREATE_CSR_MATRIX(FP_TYPE, FP_SUFFIX, INT_TYPE, INT_SUFFIX) \ + template std::enable_if_t> \ + create_csr_matrix( \ + sycl::queue & queue, matrix_handle_t * p_smhandle, std::int64_t num_rows, \ + std::int64_t num_cols, std::int64_t nnz, index_base index, \ + sycl::buffer & row_ptr, sycl::buffer & col_ind, \ + sycl::buffer & val); \ + template std::enable_if_t, sycl::event> \ + create_csr_matrix( \ + sycl::queue & queue, matrix_handle_t * p_smhandle, std::int64_t num_rows, \ + std::int64_t num_cols, std::int64_t nnz, index_base index, INT_TYPE * row_ptr, \ + INT_TYPE * col_ind, FP_TYPE * val, const std::vector &dependencies) +FOR_EACH_FP_AND_INT_TYPE(INSTANTIATE_CREATE_CSR_MATRIX); +#undef INSTANTIATE_CREATE_CSR_MATRIX + +// Destroy data types +sycl::event destroy_dense_vector(sycl::queue &queue, dense_vector_handle_t dvhandle, + const std::vector &dependencies) { + return queue.submit([&](sycl::handler &cgh) { + cgh.depends_on(dependencies); + cgh.host_task([=]() { + auto internal_dvhandle = reinterpret_cast(dvhandle); + auto cu_dvhandle = internal_dvhandle->cu_handle; + auto status = cusparseDestroyDnVec(cu_dvhandle); + detail::check_status(status, "destroy_dense_vector"); + delete internal_dvhandle; + }); + }); +} +sycl::event destroy_dense_matrix(sycl::queue &queue, dense_matrix_handle_t dmhandle, + const std::vector &dependencies) { + return queue.submit([&](sycl::handler &cgh) { + cgh.depends_on(dependencies); + cgh.host_task([=]() { + auto internal_dmhandle = reinterpret_cast(dmhandle); + auto cu_dmhandle = internal_dmhandle->cu_handle; + auto status = cusparseDestroyDnMat(cu_dmhandle); + detail::check_status(status, "destroy_dense_matrix"); + delete internal_dmhandle; + }); + }); +} +sycl::event destroy_csr_matrix(sycl::queue &queue, matrix_handle_t smhandle, + const std::vector &dependencies) { + return queue.submit([&](sycl::handler &cgh) { + cgh.depends_on(dependencies); + cgh.host_task([=]() { + auto internal_smhandle = reinterpret_cast(smhandle); + auto cu_smhandle = internal_smhandle->cu_handle; + auto status = cusparseDestroySpMat(cu_smhandle); + detail::check_status(status, "destroy_csr_matrix"); + delete internal_smhandle; + }); + }); +} + +// Matrix property +void set_matrix_property(sycl::queue &, matrix_handle_t, matrix_property) { + // No equivalent in cuSPARSE +} + +} // namespace oneapi::mkl::sparse::cusparse diff --git a/src/sparse_blas/backends/cusparse/cusparse_error.hpp b/src/sparse_blas/backends/cusparse/cusparse_error.hpp new file mode 100644 index 000000000..7964514fe --- /dev/null +++ b/src/sparse_blas/backends/cusparse/cusparse_error.hpp @@ -0,0 +1,94 @@ +/*************************************************************************** +* Copyright (C) Codeplay Software Limited +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* For your convenience, a copy of the License has been included in this +* repository. +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +* +**************************************************************************/ + +#ifndef _ONEMKL_SPARSE_BLAS_BACKENDS_CUSPARSE_ERROR_HPP_ +#define _ONEMKL_SPARSE_BLAS_BACKENDS_CUSPARSE_ERROR_HPP_ + +#include + +#include +#include + +#include "oneapi/mkl/exceptions.hpp" + +namespace oneapi::mkl::sparse::cusparse::detail { + +inline std::string cuda_result_to_str(CUresult result) { + switch (result) { +#define ONEMKL_CUSPARSE_CASE(STATUS) \ + case STATUS: return #STATUS + ONEMKL_CUSPARSE_CASE(CUDA_SUCCESS); + ONEMKL_CUSPARSE_CASE(CUDA_ERROR_NOT_PERMITTED); + ONEMKL_CUSPARSE_CASE(CUDA_ERROR_INVALID_CONTEXT); + ONEMKL_CUSPARSE_CASE(CUDA_ERROR_INVALID_DEVICE); + ONEMKL_CUSPARSE_CASE(CUDA_ERROR_INVALID_VALUE); + ONEMKL_CUSPARSE_CASE(CUDA_ERROR_OUT_OF_MEMORY); + ONEMKL_CUSPARSE_CASE(CUDA_ERROR_LAUNCH_OUT_OF_RESOURCES); + default: return ""; + } +} + +#define CUDA_ERROR_FUNC(func, ...) \ + do { \ + auto res = func(__VA_ARGS__); \ + if (res != CUDA_SUCCESS) { \ + throw oneapi::mkl::exception("sparse_blas", #func, \ + "cuda error: " + cuda_result_to_str(res)); \ + } \ + } while (0) + +inline std::string cusparse_status_to_str(cusparseStatus_t status) { + switch (status) { +#define ONEMKL_CUSPARSE_CASE(STATUS) \ + case STATUS: return #STATUS + ONEMKL_CUSPARSE_CASE(CUSPARSE_STATUS_SUCCESS); + ONEMKL_CUSPARSE_CASE(CUSPARSE_STATUS_NOT_INITIALIZED); + ONEMKL_CUSPARSE_CASE(CUSPARSE_STATUS_ALLOC_FAILED); + ONEMKL_CUSPARSE_CASE(CUSPARSE_STATUS_INVALID_VALUE); + ONEMKL_CUSPARSE_CASE(CUSPARSE_STATUS_ARCH_MISMATCH); + ONEMKL_CUSPARSE_CASE(CUSPARSE_STATUS_EXECUTION_FAILED); + ONEMKL_CUSPARSE_CASE(CUSPARSE_STATUS_INTERNAL_ERROR); + ONEMKL_CUSPARSE_CASE(CUSPARSE_STATUS_MATRIX_TYPE_NOT_SUPPORTED); + ONEMKL_CUSPARSE_CASE(CUSPARSE_STATUS_NOT_SUPPORTED); + ONEMKL_CUSPARSE_CASE(CUSPARSE_STATUS_INSUFFICIENT_RESOURCES); +#undef ONEMKL_CUSPARSE_CASE + default: return ""; + } +} + +inline void check_status(cusparseStatus_t status, const std::string& function, + std::string error_str = "") { + if (status != CUSPARSE_STATUS_SUCCESS) { + if (!error_str.empty()) { + error_str += "; "; + } + error_str += "cuSPARSE status: " + cusparse_status_to_str(status); + throw oneapi::mkl::exception("sparse_blas", function, error_str); + } +} + +#define CUSPARSE_ERR_FUNC(func, ...) \ + do { \ + auto status = func(__VA_ARGS__); \ + check_status(status, #func); \ + } while (0) + +} // namespace oneapi::mkl::sparse::cusparse::detail + +#endif // _ONEMKL_SPARSE_BLAS_BACKENDS_CUSPARSE_ERROR_HPP_ diff --git a/src/sparse_blas/backends/cusparse/cusparse_handle.hpp b/src/sparse_blas/backends/cusparse/cusparse_handle.hpp new file mode 100644 index 000000000..40da16a08 --- /dev/null +++ b/src/sparse_blas/backends/cusparse/cusparse_handle.hpp @@ -0,0 +1,63 @@ +/*************************************************************************** +* Copyright (C) Codeplay Software Limited +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* For your convenience, a copy of the License has been included in this +* repository. +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +* +**************************************************************************/ + +#ifndef _ONEMKL_SPARSE_BLAS_BACKENDS_CUSPARSE_HANDLE_HPP_ +#define _ONEMKL_SPARSE_BLAS_BACKENDS_CUSPARSE_HANDLE_HPP_ + +/** + * @file Similar to blas_handle.hpp + * Provides a map from a pi_context (or equivalent) to a cusparseHandle_t. + * @see cusparse_scope_handle.hpp +*/ + +#include +#include + +namespace oneapi::mkl::sparse::cusparse::detail { + +template +struct cusparse_handle { + using handle_container_t = std::unordered_map *>; + handle_container_t cusparse_handle_mapper_{}; + + ~cusparse_handle() noexcept(false) { + for (auto &handle_pair : cusparse_handle_mapper_) { + if (handle_pair.second != nullptr) { + auto handle = handle_pair.second->exchange(nullptr); + if (handle != nullptr) { + CUSPARSE_ERR_FUNC(cusparseDestroy, handle); + handle = nullptr; + } + else { + // if the handle is nullptr it means the handle was already + // destroyed by the ContextCallback and we're free to delete the + // atomic object. + delete handle_pair.second; + } + + handle_pair.second = nullptr; + } + } + cusparse_handle_mapper_.clear(); + } +}; + +} // namespace oneapi::mkl::sparse::cusparse::detail + +#endif // _ONEMKL_SPARSE_BLAS_BACKENDS_CUSPARSE_HANDLE_HPP_ diff --git a/src/sparse_blas/backends/cusparse/cusparse_helper.hpp b/src/sparse_blas/backends/cusparse/cusparse_helper.hpp new file mode 100644 index 000000000..53c009f2e --- /dev/null +++ b/src/sparse_blas/backends/cusparse/cusparse_helper.hpp @@ -0,0 +1,144 @@ +/*************************************************************************** +* Copyright (C) Codeplay Software Limited +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* For your convenience, a copy of the License has been included in this +* repository. +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +* +**************************************************************************/ +#ifndef _ONEMKL_SPARSE_BLAS_BACKENDS_CUSPARSE_HELPER_HPP_ +#define _ONEMKL_SPARSE_BLAS_BACKENDS_CUSPARSE_HELPER_HPP_ + +#include +#include +#include +#include + +#include + +#include "oneapi/mkl/sparse_blas/types.hpp" +#include "cusparse_error.hpp" + +namespace oneapi::mkl::sparse::cusparse::detail { + +template +struct CudaEnumType; +template <> +struct CudaEnumType { + static constexpr cudaDataType_t value = CUDA_R_32F; +}; +template <> +struct CudaEnumType { + static constexpr cudaDataType_t value = CUDA_R_64F; +}; +template <> +struct CudaEnumType> { + static constexpr cudaDataType_t value = CUDA_C_32F; +}; +template <> +struct CudaEnumType> { + static constexpr cudaDataType_t value = CUDA_C_64F; +}; + +template +struct CudaIndexEnumType; +template <> +struct CudaIndexEnumType { + static constexpr cusparseIndexType_t value = CUSPARSE_INDEX_32I; +}; +template <> +struct CudaIndexEnumType { + static constexpr cusparseIndexType_t value = CUSPARSE_INDEX_64I; +}; + +template +inline std::string enum_to_str(E e) { + return std::to_string(static_cast(e)); +} + +inline std::int64_t safe_cast(std::size_t x, const std::string& func_name) { + if (x >= std::numeric_limits::max()) { + throw oneapi::mkl::exception( + "sparse_blas", func_name, + "Integer overflow: " + std::to_string(x) + " does not fit in std::int64_t"); + } + return static_cast(x); +} + +inline cusparseOrder_t get_cuda_order(layout l) { + switch (l) { + case layout::row_major: return CUSPARSE_ORDER_ROW; + case layout::col_major: return CUSPARSE_ORDER_COL; + default: + throw oneapi::mkl::invalid_argument("sparse_blas", "get_cuda_order", + "Unknown layout: " + enum_to_str(l)); + } +} + +inline cusparseIndexBase_t get_cuda_index_base(index_base index) { + switch (index) { + case index_base::zero: return CUSPARSE_INDEX_BASE_ZERO; + case index_base::one: return CUSPARSE_INDEX_BASE_ONE; + default: + throw oneapi::mkl::invalid_argument("sparse_blas", "get_cuda_index_base", + "Unknown index_base: " + enum_to_str(index)); + } +} + +inline cusparseOperation_t get_cuda_operation(transpose op) { + switch (op) { + case transpose::nontrans: return CUSPARSE_OPERATION_NON_TRANSPOSE; + case transpose::trans: return CUSPARSE_OPERATION_TRANSPOSE; + case transpose::conjtrans: return CUSPARSE_OPERATION_CONJUGATE_TRANSPOSE; + default: + throw oneapi::mkl::invalid_argument("sparse_blas", "get_cuda_operation", + "Unknown transpose operation: " + enum_to_str(op)); + } +} + +inline auto get_cuda_uplo(uplo uplo_val) { + switch (uplo_val) { + case uplo::upper: return CUSPARSE_FILL_MODE_UPPER; + case uplo::lower: return CUSPARSE_FILL_MODE_LOWER; + default: + throw oneapi::mkl::invalid_argument("sparse_blas", "get_cuda_uplo", + "Unknown uplo: " + enum_to_str(uplo_val)); + } +} + +inline auto get_cuda_diag(diag diag_val) { + switch (diag_val) { + case diag::nonunit: return CUSPARSE_DIAG_TYPE_NON_UNIT; + case diag::unit: return CUSPARSE_DIAG_TYPE_UNIT; + default: + throw oneapi::mkl::invalid_argument("sparse_blas", "get_cuda_diag", + "Unknown diag: " + enum_to_str(diag_val)); + } +} + +inline void set_matrix_attributes(const std::string& func_name, cusparseSpMatDescr_t cu_a, + uplo uplo_val, diag diag_val) { + auto cu_fill_mode = get_cuda_uplo(uplo_val); + auto status = cusparseSpMatSetAttribute(cu_a, CUSPARSE_SPMAT_FILL_MODE, &cu_fill_mode, + sizeof(cu_fill_mode)); + check_status(status, func_name + "/set_uplo"); + + auto cu_diag_type = get_cuda_diag(diag_val); + status = cusparseSpMatSetAttribute(cu_a, CUSPARSE_SPMAT_DIAG_TYPE, &cu_diag_type, + sizeof(cu_diag_type)); + check_status(status, func_name + "/set_diag"); +} + +} // namespace oneapi::mkl::sparse::cusparse::detail + +#endif //_ONEMKL_SPARSE_BLAS_BACKENDS_CUSPARSE_HELPER_HPP_ diff --git a/src/sparse_blas/backends/cusparse/cusparse_internal_containers.hpp b/src/sparse_blas/backends/cusparse/cusparse_internal_containers.hpp new file mode 100644 index 000000000..a93d8a47b --- /dev/null +++ b/src/sparse_blas/backends/cusparse/cusparse_internal_containers.hpp @@ -0,0 +1,211 @@ +/*************************************************************************** +* Copyright (C) Codeplay Software Limited +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* For your convenience, a copy of the License has been included in this +* repository. +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +* +**************************************************************************/ + +#ifndef _ONEMKL_SPARSE_BLAS_BACKENDS_CUSPARSE_INTERNAL_CONTAINERS_HPP_ +#define _ONEMKL_SPARSE_BLAS_BACKENDS_CUSPARSE_INTERNAL_CONTAINERS_HPP_ + +#include + +#if __has_include() +#include +#else +#include +#endif + +#include + +namespace oneapi::mkl::sparse::cusparse::detail { + +enum data_type { int32, int64, real_fp32, real_fp64, complex_fp32, complex_fp64 }; + +inline std::string data_type_to_str(data_type data_type) { + switch (data_type) { + case int32: return "int32"; + case int64: return "int64"; + case real_fp32: return "real_fp32"; + case real_fp64: return "real_fp64"; + case complex_fp32: return "complex_fp32"; + case complex_fp64: return "complex_fp64"; + default: return "unknown"; + } +} + +template +data_type get_data_type() { + if constexpr (std::is_same_v) { + return data_type::int32; + } + else if constexpr (std::is_same_v) { + return data_type::int64; + } + else if constexpr (std::is_same_v) { + return data_type::real_fp32; + } + else if constexpr (std::is_same_v) { + return data_type::real_fp64; + } + else if constexpr (std::is_same_v>) { + return data_type::complex_fp32; + } + else if constexpr (std::is_same_v>) { + return data_type::complex_fp64; + } + else { + static_assert(false, "Unsupported type"); + } +} + +/** + * Represent a non-templated container for USM or buffer. +*/ +struct generic_container { + // Store the buffer to properly handle the dependencies when the handle is used. This is not needed for USM pointers. + // Use a void* type for the buffer to avoid using template arguments in every function using data handles. + // Using reinterpret does not solve the issue as the returned buffer needs the type of the original buffer for the aligned_allocator. + std::shared_ptr buffer_ptr; + + // Underlying USM or buffer data type + data_type data_type; + + template + generic_container(T* /*ptr*/) : buffer_ptr(), + data_type(get_data_type()) {} + + template + generic_container(const sycl::buffer& buffer) + : buffer_ptr(std::make_shared>(buffer)), + data_type(get_data_type()) {} +}; + +template +struct dense_handle { + CuHandle cu_handle; + + generic_container value_container; + + template + dense_handle(CuHandle cu_handle, T* ptr) + : cu_handle(cu_handle), + value_container(generic_container(ptr)) {} + + template + dense_handle(CuHandle cu_handle, const sycl::buffer& value_buffer) + : cu_handle(cu_handle), + value_container(value_buffer) {} + + bool use_buffer() const { + return static_cast(value_container.buffer_ptr); + } + + data_type get_value_type() const { + return value_container.data_type; + } +}; + +template +struct sparse_handle { + CuHandle cu_handle; + + generic_container value_container; + generic_container row_container; + generic_container col_container; + + template + sparse_handle(CuHandle cu_handle, fpType* value_ptr, intType* row_ptr, intType* col_ptr) + : cu_handle(cu_handle), + value_container(generic_container(value_ptr)), + row_container(generic_container(row_ptr)), + col_container(generic_container(col_ptr)) {} + + template + sparse_handle(CuHandle cu_handle, const sycl::buffer& value_buffer, + const sycl::buffer& row_buffer, + const sycl::buffer& col_buffer) + : cu_handle(cu_handle), + value_container(value_buffer), + row_container(row_buffer), + col_container(col_buffer) {} + + bool use_buffer() const { + return static_cast(value_container.buffer_ptr); + } + + data_type get_value_type() const { + return value_container.data_type; + } + + data_type get_int_type() const { + return row_container.data_type; + } +}; + +using dense_vector_handle = dense_handle; +using dense_matrix_handle = dense_handle; +using sparse_matrix_handle = sparse_handle; + +/** + * Check that all internal containers use the same container. +*/ +template +void check_all_containers_use_buffers(const std::string& function, + sparse_matrix_handle first_internal_container, + Ts... internal_containers) { + bool first_use_buffer = first_internal_container.use_buffer(); + for (const auto internal_container : { internal_containers... }) { + if (internal_container.use_buffer() != first_use_buffer) { + throw oneapi::mkl::invalid_argument( + "sparse_blas", function, + "Incompatible container types. All inputs and outputs must use the same container: buffer or USM"); + } + } +} + +/** + * Check that all internal containers use the same container type, data type and integer type. +*/ +template +void check_all_containers_compatible(const std::string& function, + sparse_matrix_handle first_internal_container, + Ts... internal_containers) { + check_all_containers_use_buffers(function, first_internal_container, internal_containers...); + data_type first_value_type = first_internal_container.get_value_type(); + data_type first_int_type = first_internal_container.get_int_type(); + for (const auto internal_container : { internal_containers... }) { + const data_type other_value_type = internal_container.get_value_type(); + if (other_value_type != first_value_type) { + throw oneapi::mkl::invalid_argument( + "sparse_blas", function, + "Incompatible data types expected " + data_type_to_str(first_value_type) + + " but got " + data_type_to_str(other_value_type)); + } + if constexpr (std::is_same_v) { + const data_type other_int_type = internal_container.get_int_type(); + if (other_int_type != first_int_type) { + throw oneapi::mkl::invalid_argument( + "sparse_blas", function, + "Incompatible integer types expected " + data_type_to_str(first_int_type) + + " but got " + data_type_to_str(other_int_type)); + } + } + } +} + +} // namespace oneapi::mkl::sparse::cusparse::detail + +#endif // _ONEMKL_SPARSE_BLAS_BACKENDS_CUSPARSE_INTERNAL_CONTAINERS_HPP_ diff --git a/src/sparse_blas/backends/cusparse/cusparse_scope_handle.cpp b/src/sparse_blas/backends/cusparse/cusparse_scope_handle.cpp new file mode 100644 index 000000000..f38c7da51 --- /dev/null +++ b/src/sparse_blas/backends/cusparse/cusparse_scope_handle.cpp @@ -0,0 +1,128 @@ +/*************************************************************************** +* Copyright (C) Codeplay Software Limited +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* For your convenience, a copy of the License has been included in this +* repository. +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +* +**************************************************************************/ + +/** + * @file Similar to cublas_scope_handle.cpp +*/ + +#include "cusparse_scope_handle.hpp" + +namespace oneapi::mkl::sparse::cusparse::detail { + +/** + * Inserts a new element in the map if its key is unique. This new element + * is constructed in place using args as the arguments for the construction + * of a value_type (which is an object of a pair type). The insertion only + * takes place if no other element in the container has a key equivalent to + * the one being emplaced (keys in a map container are unique). + */ +thread_local cusparse_handle CusparseScopedContextHandler::handle_helper = + cusparse_handle{}; + +CusparseScopedContextHandler::CusparseScopedContextHandler(sycl::queue queue) + : needToRecover_(false) { + placedContext_ = new sycl::context(queue.get_context()); + auto device = queue.get_device(); + auto desired = sycl::get_native(*placedContext_); + CUDA_ERROR_FUNC(cuCtxGetCurrent, &original_); + if (original_ != desired) { + // Sets the desired context as the active one for the thread + CUDA_ERROR_FUNC(cuCtxSetCurrent, desired); + // No context is installed and the suggested context is primary + // This is the most common case. We can activate the context in the + // thread and leave it there until all the PI context referring to the + // same underlying CUDA primary context are destroyed. This emulates + // the behaviour of the CUDA runtime api, and avoids costly context + // switches. No action is required on this side of the if. + needToRecover_ = !(original_ == nullptr); + } +} + +CusparseScopedContextHandler::~CusparseScopedContextHandler() noexcept(false) { + if (needToRecover_) { + CUDA_ERROR_FUNC(cuCtxSetCurrent, original_); + } + delete placedContext_; +} + +void ContextCallback(void *userData) { + auto *ptr = static_cast *>(userData); + if (!ptr) { + return; + } + auto handle = ptr->exchange(nullptr); + if (handle != nullptr) { + CUSPARSE_ERR_FUNC(cusparseDestroy, handle); + handle = nullptr; + } + else { + // if the handle is nullptr it means the handle was already destroyed by + // the cusparse_handle destructor and we're free to delete the atomic + // object. + delete ptr; + } +} + +cusparseHandle_t CusparseScopedContextHandler::get_handle(const sycl::queue &queue) { + auto piPlacedContext_ = reinterpret_cast( + sycl::get_native(*placedContext_)); + CUstream streamId = get_stream(queue); + auto it = handle_helper.cusparse_handle_mapper_.find(piPlacedContext_); + if (it != handle_helper.cusparse_handle_mapper_.end()) { + if (it->second == nullptr) { + handle_helper.cusparse_handle_mapper_.erase(it); + } + else { + auto handle = it->second->load(); + if (handle != nullptr) { + cudaStream_t currentStreamId; + CUSPARSE_ERR_FUNC(cusparseGetStream, handle, ¤tStreamId); + if (currentStreamId != streamId) { + CUSPARSE_ERR_FUNC(cusparseSetStream, handle, streamId); + } + return handle; + } + else { + handle_helper.cusparse_handle_mapper_.erase(it); + } + } + } + + cusparseHandle_t handle; + CUSPARSE_ERR_FUNC(cusparseCreate, &handle); + CUSPARSE_ERR_FUNC(cusparseSetStream, handle, streamId); + + auto insert_iter = handle_helper.cusparse_handle_mapper_.insert( + std::make_pair(piPlacedContext_, new std::atomic(handle))); + + sycl::detail::pi::contextSetExtendedDeleter(*placedContext_, ContextCallback, + insert_iter.first->second); + + return handle; +} + +CUstream CusparseScopedContextHandler::get_stream(const sycl::queue &queue) { + return sycl::get_native(queue); +} + +sycl::context CusparseScopedContextHandler::get_context(const sycl::queue &queue) { + return queue.get_context(); +} + +} // namespace oneapi::mkl::sparse::cusparse::detail diff --git a/src/sparse_blas/backends/cusparse/cusparse_scope_handle.hpp b/src/sparse_blas/backends/cusparse/cusparse_scope_handle.hpp new file mode 100644 index 000000000..b7df97bff --- /dev/null +++ b/src/sparse_blas/backends/cusparse/cusparse_scope_handle.hpp @@ -0,0 +1,82 @@ +/*************************************************************************** +* Copyright (C) Codeplay Software Limited +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* For your convenience, a copy of the License has been included in this +* repository. +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +* +**************************************************************************/ +#ifndef _ONEMKL_SPARSE_BLAS_BACKENDS_CUSPARSE_SCOPE_HANDLE_HPP_ +#define _ONEMKL_SPARSE_BLAS_BACKENDS_CUSPARSE_SCOPE_HANDLE_HPP_ + +/** + * @file Similar to cusparse_scope_handle.hpp +*/ + +#if __has_include() +#include +#else +#include +#endif + +#include + +#include "cusparse_error.hpp" +#include "cusparse_handle.hpp" +#include "cusparse_helper.hpp" + +namespace oneapi::mkl::sparse::cusparse::detail { + +class CusparseScopedContextHandler { + CUcontext original_; + sycl::context *placedContext_; + bool needToRecover_; + static thread_local cusparse_handle handle_helper; + + CUstream get_stream(const sycl::queue &queue); + sycl::context get_context(const sycl::queue &queue); + +public: + CusparseScopedContextHandler(sycl::queue queue); + + ~CusparseScopedContextHandler() noexcept(false); + + /** + * @brief get_handle: creates the handle by implicitly impose the advice + * given by nvidia for creating a cusparse_handle. (e.g. one cuStream per device + * per thread). + * @param queue sycl queue. + * @return cusparseHandle_t a handle to construct cusparse routines + */ + cusparseHandle_t get_handle(const sycl::queue &queue); + + // This is a work-around function for reinterpret_casting the memory. This + // will be fixed when SYCL-2020 has been implemented for Pi backend. + template + inline void *get_mem(AccT acc) { + return reinterpret_cast(&acc[0]); + } + + template + inline void *get_mem(T *ptr) { + return reinterpret_cast(ptr); + } + + void wait_stream(const sycl::queue &queue) { + cuStreamSynchronize(get_stream(queue)); + } +}; + +} // namespace oneapi::mkl::sparse::cusparse::detail + +#endif //_ONEMKL_SPARSE_BLAS_BACKENDS_CUSPARSE_SCOPE_HANDLE_HPP_ diff --git a/src/sparse_blas/backends/cusparse/cusparse_task.hpp b/src/sparse_blas/backends/cusparse/cusparse_task.hpp new file mode 100644 index 000000000..3ef6db0a1 --- /dev/null +++ b/src/sparse_blas/backends/cusparse/cusparse_task.hpp @@ -0,0 +1,130 @@ +/*************************************************************************** +* Copyright (C) Codeplay Software Limited +* Copyright (C) 2022 Heidelberg University, Engineering Mathematics and Computing Lab (EMCL) and Computing Centre (URZ) +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* For your convenience, a copy of the License has been included in this +* repository. +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +* +**************************************************************************/ + +#ifndef _ONEMKL_SPARSE_BLAS_BACKENDS_CUSPARSE_TASKS_HPP_ +#define _ONEMKL_SPARSE_BLAS_BACKENDS_CUSPARSE_TASKS_HPP_ + +#include "cusparse_internal_containers.hpp" +#include "cusparse_scope_handle.hpp" + +namespace oneapi::mkl::sparse::cusparse::detail { + +template +auto get_value_accessor(sycl::handler &cgh, Container container) { + auto buffer_ptr = + reinterpret_cast *>(container.value_container.buffer_ptr.get()); + return buffer_ptr->template get_access(cgh); +} + +template +auto get_fp_accessors(sycl::handler &cgh, Ts... containers) { + return std::array, sizeof...(containers)>{ get_value_accessor( + cgh, containers)... }; +} + +template +auto get_row_accessor(sycl::handler &cgh, sparse_matrix_handle smhandle) { + auto buffer_ptr = + reinterpret_cast *>(smhandle.row_container.buffer_ptr.get()); + return buffer_ptr->template get_access(cgh); +} + +template +auto get_col_accessor(sycl::handler &cgh, sparse_matrix_handle smhandle) { + auto buffer_ptr = + reinterpret_cast *>(smhandle.col_container.buffer_ptr.get()); + return buffer_ptr->template get_access(cgh); +} + +template +auto get_int_accessors(sycl::handler &cgh, sparse_matrix_handle smhandle) { + // TODO(Romain): Support possibly multiple sparse_matrix_handle + return std::array, 2>{ get_row_accessor(cgh, smhandle), + get_col_accessor(cgh, smhandle) }; +} + +template +void submit_host_task(sycl::handler &cgh, sycl::queue &queue, Functor functor, + CaptureAcc... accessors) { + // Only capture the accessors to ensure the dependencies are properly handled + // The accessors's pointer have already been set to the native container types in previous functions + cgh.host_task([functor, queue, accessors...]() { + auto unused = std::make_tuple(accessors...); + (void)unused; + auto sc = CusparseScopedContextHandler(queue); + functor(sc); + }); +} + +template +sycl::event dispatch_submit(sycl::queue queue, const std::vector &dependencies, + Functor functor, sparse_matrix_handle sm_handle, + Ts... other_containers) { + if (sm_handle.use_buffer()) { + data_type value_type = sm_handle.get_value_type(); + data_type int_type = sm_handle.get_int_type(); + +#define ONEMKL_CUSPARSE_SUBMIT(FP_TYPE, INT_TYPE) \ + return queue.submit([&](sycl::handler &cgh) { \ + cgh.depends_on(dependencies); \ + auto fp_accs = get_fp_accessors(cgh, sm_handle, other_containers...); \ + auto int_accs = get_int_accessors(cgh, sm_handle); \ + submit_host_task(cgh, queue, functor, fp_accs, int_accs); \ + }) +#define ONEMKL_CUSPARSE_SUBMIT_INT(FP_TYPE) \ + if (int_type == data_type::int32) { \ + ONEMKL_CUSPARSE_SUBMIT(FP_TYPE, std::int32_t); \ + } \ + else if (int_type == data_type::int64) { \ + ONEMKL_CUSPARSE_SUBMIT(FP_TYPE, std::int64_t); \ + } + + if (value_type == data_type::real_fp32) { + ONEMKL_CUSPARSE_SUBMIT_INT(float); + } + else if (value_type == data_type::real_fp64) { + ONEMKL_CUSPARSE_SUBMIT_INT(double); + } + else if (value_type == data_type::complex_fp32) { + ONEMKL_CUSPARSE_SUBMIT_INT(std::complex); + } + else if (value_type == data_type::complex_fp64) { + ONEMKL_CUSPARSE_SUBMIT_INT(std::complex); + } + +#undef ONEMKL_CUSPARSE_SUBMIT_INT +#undef ONEMKL_CUSPARSE_SUBMIT + + throw oneapi::mkl::exception("sparse_blas", "dispatch_submit", + "Could not dispatch buffer kernel to a supported type"); + } + else { + // USM submit does not need to capture accessors + return queue.submit([&](sycl::handler &cgh) { + cgh.depends_on(dependencies); + submit_host_task(cgh, queue, functor); + }); + } +} + +} // namespace oneapi::mkl::sparse::cusparse::detail + +#endif // _ONEMKL_SPARSE_BLAS_BACKENDS_CUSPARSE_TASKS_HPP_ diff --git a/src/sparse_blas/backends/cusparse/cusparse_wrappers.cpp b/src/sparse_blas/backends/cusparse/cusparse_wrappers.cpp new file mode 100644 index 000000000..278aec296 --- /dev/null +++ b/src/sparse_blas/backends/cusparse/cusparse_wrappers.cpp @@ -0,0 +1,32 @@ +/*************************************************************************** +* Copyright (C) Codeplay Software Limited +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* For your convenience, a copy of the License has been included in this +* repository. +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +* +**************************************************************************/ + +#include "oneapi/mkl/sparse_blas/types.hpp" + +#include "oneapi/mkl/sparse_blas/detail/cusparse/onemkl_sparse_blas_cusparse.hpp" + +#include "sparse_blas/function_table.hpp" + +#define WRAPPER_VERSION 1 +#define BACKEND cusparse + +extern "C" sparse_blas_function_table_t mkl_sparse_blas_table = { + WRAPPER_VERSION, +#include "sparse_blas/backends/backend_wrappers.cxx" +}; diff --git a/src/sparse_blas/backends/cusparse/operations/cusparse_gemm.cpp b/src/sparse_blas/backends/cusparse/operations/cusparse_gemm.cpp new file mode 100644 index 000000000..3a34cea99 --- /dev/null +++ b/src/sparse_blas/backends/cusparse/operations/cusparse_gemm.cpp @@ -0,0 +1,74 @@ +/*************************************************************************** +* Copyright (C) Codeplay Software Limited +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* For your convenience, a copy of the License has been included in this +* repository. +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +* +**************************************************************************/ + +#include "oneapi/mkl/sparse_blas/detail/cusparse/onemkl_sparse_blas_cusparse.hpp" + +#include "sparse_blas/backends/cusparse/cusparse_error.hpp" +#include "sparse_blas/backends/cusparse/cusparse_helper.hpp" +#include "sparse_blas/backends/cusparse/cusparse_task.hpp" +#include "sparse_blas/backends/cusparse/cusparse_internal_containers.hpp" +#include "sparse_blas/macros.hpp" + +namespace oneapi::mkl::sparse::cusparse { + +// TODO(Romain): Update to new API + +sycl::event optimize_gemm(sycl::queue& queue, transpose transpose_A, matrix_handle_t handle, + const std::vector& dependencies) { + return {}; +} + +sycl::event optimize_gemm(sycl::queue& queue, transpose transpose_A, transpose transpose_B, + layout dense_matrix_layout, const std::int64_t columns, + matrix_handle_t handle, const std::vector& dependencies) { + return {}; +} + +template +void gemm(sycl::queue& queue, layout dense_matrix_layout, transpose transpose_A, + transpose transpose_B, const fpType alpha, matrix_handle_t A_handle, + sycl::buffer& B, const std::int64_t columns, const std::int64_t ldb, + const fpType beta, sycl::buffer& C, const std::int64_t ldc) {} + +template +sycl::event gemm(sycl::queue& queue, layout dense_matrix_layout, transpose transpose_A, + transpose transpose_B, const fpType alpha, matrix_handle_t A_handle, + const fpType* B, const std::int64_t columns, const std::int64_t ldb, + const fpType beta, fpType* C, const std::int64_t ldc, + const std::vector& dependencies) { + return {}; +} + +#define INSTANTIATE_GEMM(FP_TYPE, FP_SUFFIX) \ + template void gemm(sycl::queue& queue, layout dense_matrix_layout, transpose transpose_A, \ + transpose transpose_B, const FP_TYPE alpha, matrix_handle_t A_handle, \ + sycl::buffer& B, const std::int64_t columns, \ + const std::int64_t ldb, const FP_TYPE beta, sycl::buffer& C, \ + const std::int64_t ldc); \ + template sycl::event gemm( \ + sycl::queue& queue, layout dense_matrix_layout, transpose transpose_A, \ + transpose transpose_B, const FP_TYPE alpha, matrix_handle_t A_handle, const FP_TYPE* B, \ + const std::int64_t columns, const std::int64_t ldb, const FP_TYPE beta, FP_TYPE* C, \ + const std::int64_t ldc, const std::vector& dependencies) + +FOR_EACH_FP_TYPE(INSTANTIATE_GEMM); + +#undef INSTANTIATE_GEMM + +} // namespace oneapi::mkl::sparse::cusparse diff --git a/src/sparse_blas/backends/cusparse/operations/cusparse_gemv.cpp b/src/sparse_blas/backends/cusparse/operations/cusparse_gemv.cpp new file mode 100644 index 000000000..14808cddd --- /dev/null +++ b/src/sparse_blas/backends/cusparse/operations/cusparse_gemv.cpp @@ -0,0 +1,60 @@ +/*************************************************************************** +* Copyright (C) Codeplay Software Limited +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* For your convenience, a copy of the License has been included in this +* repository. +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +* +**************************************************************************/ + +#include "oneapi/mkl/sparse_blas/detail/cusparse/onemkl_sparse_blas_cusparse.hpp" + +#include "sparse_blas/backends/cusparse/cusparse_error.hpp" +#include "sparse_blas/backends/cusparse/cusparse_helper.hpp" +#include "sparse_blas/backends/cusparse/cusparse_task.hpp" +#include "sparse_blas/backends/cusparse/cusparse_internal_containers.hpp" +#include "sparse_blas/macros.hpp" + +namespace oneapi::mkl::sparse::cusparse { + +// TODO(Romain): Update to new API + +sycl::event optimize_gemv(sycl::queue& queue, transpose transpose_val, matrix_handle_t handle, + const std::vector& dependencies) { + return {}; +} + +template +void gemv(sycl::queue& queue, transpose transpose_val, const fpType alpha, matrix_handle_t A_handle, + sycl::buffer& x, const fpType beta, sycl::buffer& y) {} + +template +sycl::event gemv(sycl::queue& queue, transpose transpose_val, const fpType alpha, + matrix_handle_t A_handle, const fpType* x, const fpType beta, fpType* y, + const std::vector& dependencies) { + return {}; +} + +#define INSTANTIATE_GEMV(FP_TYPE, FP_SUFFIX) \ + template void gemv(sycl::queue& queue, transpose transpose_val, const FP_TYPE alpha, \ + matrix_handle_t A_handle, sycl::buffer& x, const FP_TYPE beta, \ + sycl::buffer& y); \ + template sycl::event gemv(sycl::queue& queue, transpose transpose_val, const FP_TYPE alpha, \ + matrix_handle_t A_handle, const FP_TYPE* x, const FP_TYPE beta, \ + FP_TYPE* y, const std::vector& dependencies) + +FOR_EACH_FP_TYPE(INSTANTIATE_GEMV); + +#undef INSTANTIATE_GEMV + +} // namespace oneapi::mkl::sparse::cusparse diff --git a/src/sparse_blas/backends/cusparse/operations/cusparse_trsv.cpp b/src/sparse_blas/backends/cusparse/operations/cusparse_trsv.cpp new file mode 100644 index 000000000..af3c5faec --- /dev/null +++ b/src/sparse_blas/backends/cusparse/operations/cusparse_trsv.cpp @@ -0,0 +1,196 @@ +/*************************************************************************** +* Copyright (C) Codeplay Software Limited +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* For your convenience, a copy of the License has been included in this +* repository. +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +* +**************************************************************************/ + +#include "oneapi/mkl/sparse_blas/detail/cusparse/onemkl_sparse_blas_cusparse.hpp" + +#include "sparse_blas/backends/cusparse/cusparse_error.hpp" +#include "sparse_blas/backends/cusparse/cusparse_helper.hpp" +#include "sparse_blas/backends/cusparse/cusparse_task.hpp" +#include "sparse_blas/backends/cusparse/cusparse_internal_containers.hpp" +#include "sparse_blas/macros.hpp" + +namespace oneapi::mkl::sparse::cusparse { + +namespace detail { + +struct trsv_descr { + cusparseSpSVDescr_t cu_descr; + cudaDataType compute_type; + + float alpha_one_fp32 = 1.f; + double alpha_one_fp64 = 1.0; + cuComplex alpha_one_cx32 = make_cuComplex(1.f, 0.f); + cuDoubleComplex alpha_one_cx64 = make_cuDoubleComplex(1.0, 0.0); + + void *get_alpha() { + switch (compute_type) { + case CUDA_R_32F: return &alpha_one_fp32; + case CUDA_R_64F: return &alpha_one_fp64; + case CUDA_C_32F: return &alpha_one_cx32; + case CUDA_C_64F: return &alpha_one_cx64; + default: return nullptr; + } + } +}; + +inline auto get_cuda_trsv_alg(trsv_alg /*alg*/) { + return CUSPARSE_SPSV_ALG_DEFAULT; +} + +void optimize_trsv_impl(cusparseHandle_t cu_handle, uplo uplo_val, transpose transpose_val, + diag diag_val, matrix_handle_t A_handle, trsv_alg alg, + trsv_descr_t trsv_descr, void *temp_buffer_ptr) { + auto cu_a = reinterpret_cast(A_handle)->cu_handle; + detail::set_matrix_attributes("optimize_trsv", cu_a, uplo_val, diag_val); + cusparseConstDnVecDescr_t cu_x = nullptr; + cusparseDnVecDescr_t cu_y = nullptr; + auto cu_op = detail::get_cuda_operation(transpose_val); + auto internal_trsv_descr = reinterpret_cast(trsv_descr); + auto cu_alpha = internal_trsv_descr->get_alpha(); + auto cu_type = internal_trsv_descr->compute_type; + auto cu_alg = detail::get_cuda_trsv_alg(alg); + auto cu_descr = internal_trsv_descr->cu_descr; + auto status = cusparseSpSV_analysis(cu_handle, cu_op, cu_alpha, cu_a, cu_x, cu_y, cu_type, + cu_alg, cu_descr, temp_buffer_ptr); + detail::check_status(status, "optimize_trsv"); +} + +} // namespace detail + +template +void init_trsv_descr(sycl::queue &queue, trsv_descr_t *p_trsv_descr) { + // Ensure that a cusparse handle is created before any other cuSPARSE function is called. + detail::CusparseScopedContextHandler sc(queue); + sc.get_handle(queue); + + auto internal_trsv_descr = new detail::trsv_descr(); + internal_trsv_descr->compute_type = detail::CudaEnumType::value; + auto status = cusparseSpSV_createDescr(&internal_trsv_descr->cu_descr); + detail::check_status(status, "init_trsv_descr"); + *p_trsv_descr = reinterpret_cast(internal_trsv_descr); +} + +#define INSTANTIATE_INIT_TRSV_DESCR(FP_TYPE, FP_SUFFIX) \ + template void init_trsv_descr(sycl::queue & queue, trsv_descr_t * p_trsv_descr) +FOR_EACH_FP_TYPE(INSTANTIATE_INIT_TRSV_DESCR); +#undef INSTANTIATE_INIT_TRSV_DESCR + +sycl::event release_trsv_descr(sycl::queue &queue, trsv_descr_t trsv_descr, + const std::vector &dependencies) { + return queue.submit([&](sycl::handler &cgh) { + cgh.depends_on(dependencies); + cgh.host_task([=]() { + auto internal_trsv_descr = reinterpret_cast(trsv_descr); + auto status = cusparseSpSV_destroyDescr(internal_trsv_descr->cu_descr); + detail::check_status(status, "release_trsv_descr"); + delete internal_trsv_descr; + }); + }); +} + +sycl::event trsv_buffer_size(sycl::queue &queue, uplo uplo_val, transpose transpose_val, + diag diag_val, matrix_handle_t A_handle, dense_vector_handle_t x, + dense_vector_handle_t y, trsv_alg alg, trsv_descr_t trsv_descr, + std::int64_t &temp_buffer_size, + const std::vector &dependencies) { + auto internal_A = *reinterpret_cast(A_handle); + auto internal_x = *reinterpret_cast(x); + auto internal_y = *reinterpret_cast(y); + detail::check_all_containers_compatible(__FUNCTION__, internal_A, internal_x, internal_y); + auto functor = [=, &temp_buffer_size](detail::CusparseScopedContextHandler &sc) { + auto cu_handle = sc.get_handle(queue); + auto cu_a = internal_A.cu_handle; + auto cu_x = internal_x.cu_handle; + auto cu_y = internal_y.cu_handle; + detail::set_matrix_attributes(__FUNCTION__, cu_a, uplo_val, diag_val); + auto cu_op = detail::get_cuda_operation(transpose_val); + auto internal_trsv_descr = reinterpret_cast(trsv_descr); + auto cu_alpha = internal_trsv_descr->get_alpha(); + auto cu_type = internal_trsv_descr->compute_type; + auto cu_alg = detail::get_cuda_trsv_alg(alg); + auto cu_descr = internal_trsv_descr->cu_descr; + std::size_t cu_buffer_size; + auto status = cusparseSpSV_bufferSize(cu_handle, cu_op, cu_alpha, cu_a, cu_x, cu_y, cu_type, + cu_alg, cu_descr, &cu_buffer_size); + detail::check_status(status, __FUNCTION__); + temp_buffer_size = detail::safe_cast(cu_buffer_size, __FUNCTION__); + }; + return detail::dispatch_submit(queue, dependencies, functor, internal_A, internal_x, + internal_y); +} + +void optimize_trsv(sycl::queue &queue, uplo uplo_val, transpose transpose_val, diag diag_val, + matrix_handle_t A_handle, trsv_alg alg, trsv_descr_t trsv_descr, + std::int64_t /*temp_buffer_size*/, sycl::buffer temp_buffer) { + auto event = queue.submit([&](sycl::handler &cgh) { + // TODO(Romain): Get accessor from data handles + auto temp_buffer_acc = temp_buffer.template get_access(cgh); + detail::submit_host_task(cgh, queue, [=](detail::CusparseScopedContextHandler &sc) { + auto cu_handle = sc.get_handle(queue); + auto temp_buffer_ptr = sc.get_mem(temp_buffer_acc); + detail::optimize_trsv_impl(cu_handle, uplo_val, transpose_val, diag_val, A_handle, alg, + trsv_descr, temp_buffer_ptr); + }); + }); + event.wait_and_throw(); +} +sycl::event optimize_trsv(sycl::queue &queue, uplo uplo_val, transpose transpose_val, diag diag_val, + matrix_handle_t A_handle, trsv_alg alg, trsv_descr_t trsv_descr, + std::int64_t /*temp_buffer_size*/, void *temp_buffer, + const std::vector &dependencies) { + return queue.submit([&](sycl::handler &cgh) { + cgh.depends_on(dependencies); + detail::submit_host_task(cgh, queue, [=](detail::CusparseScopedContextHandler &sc) { + auto cu_handle = sc.get_handle(queue); + detail::optimize_trsv_impl(cu_handle, uplo_val, transpose_val, diag_val, A_handle, alg, + trsv_descr, temp_buffer); + }); + }); +} + +sycl::event trsv(sycl::queue &queue, uplo uplo_val, transpose transpose_val, diag diag_val, + matrix_handle_t A_handle, dense_vector_handle_t x, dense_vector_handle_t y, + trsv_alg alg, trsv_descr_t trsv_descr, + const std::vector &dependencies) { + auto internal_A = *reinterpret_cast(A_handle); + auto internal_x = *reinterpret_cast(x); + auto internal_y = *reinterpret_cast(y); + detail::check_all_containers_compatible(__FUNCTION__, internal_A, internal_x, internal_y); + auto functor = [=](detail::CusparseScopedContextHandler &sc) { + auto cu_handle = sc.get_handle(queue); + auto cu_a = internal_A.cu_handle; + auto cu_x = internal_x.cu_handle; + auto cu_y = internal_y.cu_handle; + detail::set_matrix_attributes(__FUNCTION__, cu_a, uplo_val, diag_val); + auto cu_op = detail::get_cuda_operation(transpose_val); + auto internal_trsv_descr = reinterpret_cast(trsv_descr); + auto cu_alpha = internal_trsv_descr->get_alpha(); + auto cu_type = internal_trsv_descr->compute_type; + auto cu_alg = detail::get_cuda_trsv_alg(alg); + auto cu_descr = internal_trsv_descr->cu_descr; + //TODO(Romain): Investigate floating point exception + auto status = cusparseSpSV_solve(cu_handle, cu_op, cu_alpha, cu_a, cu_x, cu_y, cu_type, + cu_alg, cu_descr); + detail::check_status(status, __FUNCTION__); + }; + return detail::dispatch_submit(queue, dependencies, functor, internal_A, internal_x, + internal_y); +} + +} // namespace oneapi::mkl::sparse::cusparse diff --git a/src/sparse_blas/backends/mkl_common/mkl_helper.hpp b/src/sparse_blas/backends/mkl_common/mkl_helper.hpp index da5235ee0..af51d4cf0 100644 --- a/src/sparse_blas/backends/mkl_common/mkl_helper.hpp +++ b/src/sparse_blas/backends/mkl_common/mkl_helper.hpp @@ -17,6 +17,9 @@ * SPDX-License-Identifier: Apache-2.0 *******************************************************************************/ +#ifndef _ONEMKL_SPARSE_BLAS_BACKENDS_MKL_COMMON_MKL_HELPER_HPP_ +#define _ONEMKL_SPARSE_BLAS_BACKENDS_MKL_COMMON_MKL_HELPER_HPP_ + // MKLCPU and MKLGPU backends include // This include defines its own oneapi::mkl::sparse namespace with some of the types that are used here: matrix_handle_t, index_base, transpose, uolo, diag. #include @@ -54,3 +57,5 @@ inline auto get_handle(detail::matrix_handle *handle) { #define FOR_EACH_FP_AND_INT_TYPE(INSTANTIATE_MACRO) \ FOR_EACH_FP_AND_INT_TYPE_HELPER(INSTANTIATE_MACRO, std::int32_t); \ FOR_EACH_FP_AND_INT_TYPE_HELPER(INSTANTIATE_MACRO, std::int64_t) + +#endif // _ONEMKL_SPARSE_BLAS_BACKENDS_MKL_COMMON_MKL_HELPER_HPP_ diff --git a/src/sparse_blas/function_table.hpp b/src/sparse_blas/function_table.hpp index 57279fb3f..f427f72fe 100644 --- a/src/sparse_blas/function_table.hpp +++ b/src/sparse_blas/function_table.hpp @@ -23,16 +23,40 @@ #include "oneapi/mkl/sparse_blas/types.hpp" #include "sparse_blas/macros.hpp" -#define DEFINE_SET_CSR_DATA(FP_TYPE, FP_SUFFIX, INT_TYPE, INT_SUFFIX) \ - void (*set_csr_data_buffer##FP_SUFFIX##INT_SUFFIX)( \ - sycl::queue & queue, oneapi::mkl::sparse::matrix_handle_t handle, INT_TYPE num_rows, \ - INT_TYPE num_cols, INT_TYPE nnz, oneapi::mkl::index_base index, \ - sycl::buffer & row_ptr, sycl::buffer & col_ind, \ - sycl::buffer & val); \ - sycl::event (*set_csr_data_usm##FP_SUFFIX##INT_SUFFIX)( \ - sycl::queue & queue, oneapi::mkl::sparse::matrix_handle_t handle, INT_TYPE num_rows, \ - INT_TYPE num_cols, INT_TYPE nnz, oneapi::mkl::index_base index, INT_TYPE * row_ptr, \ - INT_TYPE * col_ind, FP_TYPE * val, const std::vector &dependencies) +#define DEFINE_CREATE_DENSE_VECTOR(FP_TYPE, FP_SUFFIX) \ + void (*create_dense_vector_buffer##FP_SUFFIX)( \ + sycl::queue & queue, oneapi::mkl::sparse::dense_vector_handle_t * p_dvhandle, \ + std::int64_t size, sycl::buffer & val); \ + sycl::event (*create_dense_vector_usm##FP_SUFFIX)( \ + sycl::queue & queue, oneapi::mkl::sparse::dense_vector_handle_t * p_dvhandle, \ + std::int64_t size, FP_TYPE * val, const std::vector &dependencies) + +#define DEFINE_CREATE_DENSE_MATRIX(FP_TYPE, FP_SUFFIX) \ + void (*create_dense_matrix_buffer##FP_SUFFIX)( \ + sycl::queue & queue, oneapi::mkl::sparse::dense_matrix_handle_t * p_dmhandle, \ + std::int64_t num_rows, std::int64_t num_cols, std::int64_t ld, \ + oneapi::mkl::layout dense_layout, sycl::buffer & val); \ + sycl::event (*create_dense_matrix_usm##FP_SUFFIX)( \ + sycl::queue & queue, oneapi::mkl::sparse::dense_matrix_handle_t * p_dmhandle, \ + std::int64_t num_rows, std::int64_t num_cols, std::int64_t ld, \ + oneapi::mkl::layout dense_layout, FP_TYPE * val, \ + const std::vector &dependencies) + +#define DEFINE_CREATE_CSR_MATRIX(FP_TYPE, FP_SUFFIX, INT_TYPE, INT_SUFFIX) \ + void (*create_csr_matrix_buffer##FP_SUFFIX##INT_SUFFIX)( \ + sycl::queue & queue, oneapi::mkl::sparse::matrix_handle_t * p_smhandle, \ + std::int64_t num_rows, std::int64_t num_cols, std::int64_t nnz, \ + oneapi::mkl::index_base index, sycl::buffer & row_ptr, \ + sycl::buffer & col_ind, sycl::buffer & val); \ + sycl::event (*create_csr_matrix_usm##FP_SUFFIX##INT_SUFFIX)( \ + sycl::queue & queue, oneapi::mkl::sparse::matrix_handle_t * p_smhandle, \ + std::int64_t num_rows, std::int64_t num_cols, std::int64_t nnz, \ + oneapi::mkl::index_base index, INT_TYPE * row_ptr, INT_TYPE * col_ind, FP_TYPE * val, \ + const std::vector &dependencies) + +#define DEFINE_INIT_TRSV_DESCR(FP_TYPE, FP_SUFFIX) \ + void (*init_trsv_descr##FP_SUFFIX)(sycl::queue & queue, \ + oneapi::mkl::sparse::trsv_descr_t * p_trsv_descr) #define DEFINE_GEMV(FP_TYPE, FP_SUFFIX) \ void (*gemv_buffer##FP_SUFFIX)( \ @@ -44,16 +68,6 @@ oneapi::mkl::sparse::matrix_handle_t A_handle, const FP_TYPE *x, const FP_TYPE beta, \ FP_TYPE *y, const std::vector &dependencies) -#define DEFINE_TRSV(FP_TYPE, FP_SUFFIX) \ - void (*trsv_buffer##FP_SUFFIX)( \ - sycl::queue & queue, oneapi::mkl::uplo uplo_val, oneapi::mkl::transpose transpose_val, \ - oneapi::mkl::diag diag_val, oneapi::mkl::sparse::matrix_handle_t A_handle, \ - sycl::buffer & x, sycl::buffer & y); \ - sycl::event (*trsv_usm##FP_SUFFIX)( \ - sycl::queue & queue, oneapi::mkl::uplo uplo_val, oneapi::mkl::transpose transpose_val, \ - oneapi::mkl::diag diag_val, oneapi::mkl::sparse::matrix_handle_t A_handle, \ - const FP_TYPE *x, FP_TYPE *y, const std::vector &dependencies) - #define DEFINE_GEMM(FP_TYPE, FP_SUFFIX) \ void (*gemm_buffer##FP_SUFFIX)( \ sycl::queue & queue, oneapi::mkl::layout dense_matrix_layout, \ @@ -70,13 +84,39 @@ typedef struct { int version; - void (*init_matrix_handle)(sycl::queue &queue, oneapi::mkl::sparse::matrix_handle_t *p_handle); - sycl::event (*release_matrix_handle)(sycl::queue &queue, - oneapi::mkl::sparse::matrix_handle_t *p_handle, - const std::vector &dependencies); + FOR_EACH_FP_TYPE(DEFINE_CREATE_DENSE_VECTOR); + FOR_EACH_FP_TYPE(DEFINE_CREATE_DENSE_MATRIX); + FOR_EACH_FP_AND_INT_TYPE(DEFINE_CREATE_CSR_MATRIX); + + // Destroy data types + sycl::event (*destroy_dense_vector)(sycl::queue &queue, + oneapi::mkl::sparse::dense_vector_handle_t dvhandle, + const std::vector &dependencies); + sycl::event (*destroy_dense_matrix)(sycl::queue &queue, + oneapi::mkl::sparse::dense_matrix_handle_t dmhandle, + const std::vector &dependencies); + sycl::event (*destroy_csr_matrix)(sycl::queue &queue, + oneapi::mkl::sparse::matrix_handle_t smhandle, + const std::vector &dependencies); + + // Matrix property + void (*set_matrix_property)(sycl::queue &queue, oneapi::mkl::sparse::matrix_handle_t smhandle, + oneapi::mkl::sparse::matrix_property property_value); - FOR_EACH_FP_AND_INT_TYPE(DEFINE_SET_CSR_DATA); + // Operation descriptor + FOR_EACH_FP_TYPE(DEFINE_INIT_TRSV_DESCR); + sycl::event (*release_trsv_descr)(sycl::queue &queue, + oneapi::mkl::sparse::trsv_descr_t trsv_descr, + const std::vector &dependencies); + + // Temporary buffer size + sycl::event (*trsv_buffer_size)( + sycl::queue &queue, oneapi::mkl::uplo uplo_val, oneapi::mkl::transpose transpose_val, + oneapi::mkl::diag diag_val, oneapi::mkl::sparse::matrix_handle_t A_handle, + oneapi::mkl::sparse::dense_vector_handle_t x, oneapi::mkl::sparse::dense_vector_handle_t y, + oneapi::mkl::sparse::trsv_alg alg, oneapi::mkl::sparse::trsv_descr_t trsv_descr, + std::int64_t &temp_buffer_size, const std::vector &dependencies); // optimize_* sycl::event (*optimize_gemm_v1)(sycl::queue &queue, oneapi::mkl::transpose transpose_A, @@ -91,19 +131,41 @@ typedef struct { sycl::event (*optimize_gemv)(sycl::queue &queue, oneapi::mkl::transpose transpose_val, oneapi::mkl::sparse::matrix_handle_t handle, const std::vector &dependencies); - sycl::event (*optimize_trsv)(sycl::queue &queue, oneapi::mkl::uplo uplo_val, + void (*optimize_trsv_buffer)(sycl::queue &queue, oneapi::mkl::uplo uplo_val, oneapi::mkl::transpose transpose_val, oneapi::mkl::diag diag_val, - oneapi::mkl::sparse::matrix_handle_t handle, - const std::vector &dependencies); + oneapi::mkl::sparse::matrix_handle_t A_handle, + oneapi::mkl::sparse::trsv_alg alg, + oneapi::mkl::sparse::trsv_descr_t trsv_descr, + std::int64_t temp_buffer_size, + sycl::buffer temp_buffer); + sycl::event (*optimize_trsv_usm)(sycl::queue &queue, oneapi::mkl::uplo uplo_val, + oneapi::mkl::transpose transpose_val, + oneapi::mkl::diag diag_val, + oneapi::mkl::sparse::matrix_handle_t A_handle, + oneapi::mkl::sparse::trsv_alg alg, + oneapi::mkl::sparse::trsv_descr_t trsv_descr, + std::int64_t temp_buffer_size, void *temp_buffer, + const std::vector &dependencies); FOR_EACH_FP_TYPE(DEFINE_GEMV); - FOR_EACH_FP_TYPE(DEFINE_TRSV); FOR_EACH_FP_TYPE(DEFINE_GEMM); + + sycl::event (*trsv)(sycl::queue &queue, oneapi::mkl::uplo uplo_val, + oneapi::mkl::transpose transpose_val, oneapi::mkl::diag diag_val, + oneapi::mkl::sparse::matrix_handle_t A_handle, + oneapi::mkl::sparse::dense_vector_handle_t x, + oneapi::mkl::sparse::dense_vector_handle_t y, + oneapi::mkl::sparse::trsv_alg alg, + oneapi::mkl::sparse::trsv_descr_t trsv_descr, + const std::vector &dependencies); + } sparse_blas_function_table_t; -#undef DEFINE_SET_CSR_DATA +#undef DEFINE_CREATE_DENSE_VECTOR +#undef DEFINE_CREATE_DENSE_MATRIX +#undef DEFINE_CREATE_CSR_MATRIX +#undef DEFINE_INIT_TRSV_DESCR #undef DEFINE_GEMV -#undef DEFINE_TRSV #undef DEFINE_GEMM #endif // _ONEMKL_SPARSE_BLAS_FUNCTION_TABLE_HPP_ diff --git a/src/sparse_blas/sparse_blas_loader.cpp b/src/sparse_blas/sparse_blas_loader.cpp index 95da6df9c..9b7036c62 100644 --- a/src/sparse_blas/sparse_blas_loader.cpp +++ b/src/sparse_blas/sparse_blas_loader.cpp @@ -30,39 +30,120 @@ static oneapi::mkl::detail::table_initializer function_tables; -void init_matrix_handle(sycl::queue &queue, matrix_handle_t *p_handle) { - auto libkey = get_device_id(queue); - function_tables[libkey].init_matrix_handle(queue, p_handle); -} +#define DEFINE_CREATE_DENSE_VECTOR(FP_TYPE, FP_SUFFIX) \ + template <> \ + void create_dense_vector(sycl::queue &queue, dense_vector_handle_t *p_dvhandle, \ + std::int64_t size, sycl::buffer &val) { \ + auto libkey = get_device_id(queue); \ + function_tables[libkey].create_dense_vector_buffer##FP_SUFFIX(queue, p_dvhandle, size, \ + val); \ + } \ + template <> \ + sycl::event create_dense_vector(sycl::queue &queue, dense_vector_handle_t *p_dvhandle, \ + std::int64_t size, FP_TYPE *val, \ + const std::vector &dependencies) { \ + auto libkey = get_device_id(queue); \ + return function_tables[libkey].create_dense_vector_usm##FP_SUFFIX(queue, p_dvhandle, size, \ + val, dependencies); \ + } +FOR_EACH_FP_TYPE(DEFINE_CREATE_DENSE_VECTOR) +#undef DEFINE_CREATE_DENSE_VECTOR -sycl::event release_matrix_handle(sycl::queue &queue, matrix_handle_t *p_handle, - const std::vector &dependencies) { - auto libkey = get_device_id(queue); - return function_tables[libkey].release_matrix_handle(queue, p_handle, dependencies); -} +#define DEFINE_CREATE_DENSE_MATRIX(FP_TYPE, FP_SUFFIX) \ + template <> \ + void create_dense_matrix(sycl::queue &queue, dense_matrix_handle_t *p_dmhandle, \ + std::int64_t num_rows, std::int64_t num_cols, std::int64_t ld, \ + layout dense_layout, sycl::buffer &val) { \ + auto libkey = get_device_id(queue); \ + function_tables[libkey].create_dense_matrix_buffer##FP_SUFFIX( \ + queue, p_dmhandle, num_rows, num_cols, ld, dense_layout, val); \ + } \ + template <> \ + sycl::event create_dense_matrix(sycl::queue &queue, dense_matrix_handle_t *p_dmhandle, \ + std::int64_t num_rows, std::int64_t num_cols, std::int64_t ld, \ + layout dense_layout, FP_TYPE *val, \ + const std::vector &dependencies) { \ + auto libkey = get_device_id(queue); \ + return function_tables[libkey].create_dense_matrix_usm##FP_SUFFIX( \ + queue, p_dmhandle, num_rows, num_cols, ld, dense_layout, val, dependencies); \ + } +FOR_EACH_FP_TYPE(DEFINE_CREATE_DENSE_MATRIX) +#undef DEFINE_CREATE_DENSE_MATRIX -#define DEFINE_SET_CSR_DATA(FP_TYPE, FP_SUFFIX, INT_TYPE, INT_SUFFIX) \ +#define DEFINE_CREATE_CSR_MATRIX(FP_TYPE, FP_SUFFIX, INT_TYPE, INT_SUFFIX) \ template <> \ - void set_csr_data(sycl::queue &queue, matrix_handle_t handle, INT_TYPE num_rows, \ - INT_TYPE num_cols, INT_TYPE nnz, index_base index, \ - sycl::buffer &row_ptr, sycl::buffer &col_ind, \ - sycl::buffer &val) { \ + void create_csr_matrix(sycl::queue &queue, matrix_handle_t *p_smhandle, std::int64_t num_rows, \ + std::int64_t num_cols, std::int64_t nnz, index_base index, \ + sycl::buffer &row_ptr, sycl::buffer &col_ind, \ + sycl::buffer &val) { \ auto libkey = get_device_id(queue); \ - function_tables[libkey].set_csr_data_buffer##FP_SUFFIX##INT_SUFFIX( \ - queue, handle, num_rows, num_cols, nnz, index, row_ptr, col_ind, val); \ + function_tables[libkey].create_csr_matrix_buffer##FP_SUFFIX##INT_SUFFIX( \ + queue, p_smhandle, num_rows, num_cols, nnz, index, row_ptr, col_ind, val); \ } \ template <> \ - sycl::event set_csr_data(sycl::queue &queue, matrix_handle_t handle, INT_TYPE num_rows, \ - INT_TYPE num_cols, INT_TYPE nnz, index_base index, INT_TYPE *row_ptr, \ - INT_TYPE *col_ind, FP_TYPE *val, \ - const std::vector &dependencies) { \ + sycl::event create_csr_matrix(sycl::queue &queue, matrix_handle_t *p_smhandle, \ + std::int64_t num_rows, std::int64_t num_cols, std::int64_t nnz, \ + index_base index, INT_TYPE *row_ptr, INT_TYPE *col_ind, \ + FP_TYPE *val, const std::vector &dependencies) { \ auto libkey = get_device_id(queue); \ - return function_tables[libkey].set_csr_data_usm##FP_SUFFIX##INT_SUFFIX( \ - queue, handle, num_rows, num_cols, nnz, index, row_ptr, col_ind, val, dependencies); \ + return function_tables[libkey].create_csr_matrix_usm##FP_SUFFIX##INT_SUFFIX( \ + queue, p_smhandle, num_rows, num_cols, nnz, index, row_ptr, col_ind, val, \ + dependencies); \ } -FOR_EACH_FP_AND_INT_TYPE(DEFINE_SET_CSR_DATA) -#undef DEFINE_SET_CSR_DATA +FOR_EACH_FP_AND_INT_TYPE(DEFINE_CREATE_CSR_MATRIX) +#undef DEFINE_CREATE_CSR_MATRIX + +sycl::event destroy_dense_vector(sycl::queue &queue, dense_vector_handle_t p_dvhandle, + const std::vector &dependencies) { + auto libkey = get_device_id(queue); + return function_tables[libkey].destroy_dense_vector(queue, p_dvhandle, dependencies); +} + +sycl::event destroy_dense_matrix(sycl::queue &queue, dense_matrix_handle_t p_dmhandle, + const std::vector &dependencies) { + auto libkey = get_device_id(queue); + return function_tables[libkey].destroy_dense_matrix(queue, p_dmhandle, dependencies); +} + +sycl::event destroy_csr_matrix(sycl::queue &queue, matrix_handle_t p_smhandle, + const std::vector &dependencies) { + auto libkey = get_device_id(queue); + return function_tables[libkey].destroy_csr_matrix(queue, p_smhandle, dependencies); +} + +void set_matrix_property(sycl::queue &queue, matrix_handle_t smhandle, + matrix_property property_value) { + auto libkey = get_device_id(queue); + return function_tables[libkey].set_matrix_property(queue, smhandle, property_value); +} + +#define DEFINE_INIT_TRSV_DESCR(FP_TYPE, FP_SUFFIX) \ + template <> \ + void init_trsv_descr(sycl::queue & queue, trsv_descr_t * p_trsv_descr) { \ + auto libkey = get_device_id(queue); \ + return function_tables[libkey].init_trsv_descr##FP_SUFFIX(queue, p_trsv_descr); \ + } +FOR_EACH_FP_TYPE(DEFINE_INIT_TRSV_DESCR) +#undef DEFINE_INIT_TRSV_DESCR + +sycl::event release_trsv_descr(sycl::queue &queue, trsv_descr_t trsv_descr, + const std::vector &dependencies) { + auto libkey = get_device_id(queue); + return function_tables[libkey].release_trsv_descr(queue, trsv_descr, dependencies); +} + +// Temporary buffer size +sycl::event trsv_buffer_size(sycl::queue &queue, uplo uplo_val, transpose transpose_val, + diag diag_val, matrix_handle_t A_handle, dense_vector_handle_t x, + dense_vector_handle_t y, trsv_alg alg, trsv_descr_t trsv_descr, + std::int64_t &temp_buffer_size, + const std::vector &dependencies) { + auto libkey = get_device_id(queue); + return function_tables[libkey].trsv_buffer_size(queue, uplo_val, transpose_val, diag_val, + A_handle, x, y, alg, trsv_descr, + temp_buffer_size, dependencies); +} sycl::event optimize_gemm(sycl::queue &queue, transpose transpose_A, matrix_handle_t handle, const std::vector &dependencies) { @@ -84,11 +165,23 @@ sycl::event optimize_gemv(sycl::queue &queue, transpose transpose_val, matrix_ha return function_tables[libkey].optimize_gemv(queue, transpose_val, handle, dependencies); } +void optimize_trsv(sycl::queue &queue, uplo uplo_val, transpose transpose_val, diag diag_val, + matrix_handle_t A_handle, trsv_alg alg, trsv_descr_t trsv_descr, + std::int64_t temp_buffer_size, sycl::buffer temp_buffer) { + auto libkey = get_device_id(queue); + return function_tables[libkey].optimize_trsv_buffer(queue, uplo_val, transpose_val, diag_val, + A_handle, alg, trsv_descr, temp_buffer_size, + temp_buffer); +} + sycl::event optimize_trsv(sycl::queue &queue, uplo uplo_val, transpose transpose_val, diag diag_val, - matrix_handle_t handle, const std::vector &dependencies) { + matrix_handle_t A_handle, trsv_alg alg, trsv_descr_t trsv_descr, + std::int64_t temp_buffer_size, void *temp_buffer, + const std::vector &dependencies) { auto libkey = get_device_id(queue); - return function_tables[libkey].optimize_trsv(queue, uplo_val, transpose_val, diag_val, handle, - dependencies); + return function_tables[libkey].optimize_trsv_usm(queue, uplo_val, transpose_val, diag_val, + A_handle, alg, trsv_descr, temp_buffer_size, + temp_buffer, dependencies); } #define DEFINE_GEMV(FP_TYPE, FP_SUFFIX) \ @@ -112,27 +205,6 @@ sycl::event optimize_trsv(sycl::queue &queue, uplo uplo_val, transpose transpose FOR_EACH_FP_TYPE(DEFINE_GEMV) #undef DEFINE_GEMV -#define DEFINE_TRSV(FP_TYPE, FP_SUFFIX) \ - template <> \ - void trsv(sycl::queue &queue, uplo uplo_val, transpose transpose_val, diag diag_val, \ - matrix_handle_t A_handle, sycl::buffer &x, \ - sycl::buffer &y) { \ - auto libkey = get_device_id(queue); \ - function_tables[libkey].trsv_buffer##FP_SUFFIX(queue, uplo_val, transpose_val, diag_val, \ - A_handle, x, y); \ - } \ - template <> \ - sycl::event trsv(sycl::queue &queue, uplo uplo_val, transpose transpose_val, diag diag_val, \ - matrix_handle_t A_handle, const FP_TYPE *x, FP_TYPE *y, \ - const std::vector &dependencies) { \ - auto libkey = get_device_id(queue); \ - return function_tables[libkey].trsv_usm##FP_SUFFIX( \ - queue, uplo_val, transpose_val, diag_val, A_handle, x, y, dependencies); \ - } - -FOR_EACH_FP_TYPE(DEFINE_TRSV) -#undef DEFINE_TRSV - #define DEFINE_GEMM(FP_TYPE, FP_SUFFIX) \ template <> \ void gemm(sycl::queue &queue, layout dense_matrix_layout, transpose transpose_A, \ @@ -159,4 +231,13 @@ FOR_EACH_FP_TYPE(DEFINE_TRSV) FOR_EACH_FP_TYPE(DEFINE_GEMM) #undef DEFINE_GEMM +sycl::event trsv(sycl::queue &queue, uplo uplo_val, transpose transpose_val, diag diag_val, + matrix_handle_t A_handle, dense_vector_handle_t x, dense_vector_handle_t y, + trsv_alg alg, trsv_descr_t trsv_descr, + const std::vector &dependencies) { + auto libkey = get_device_id(queue); + return function_tables[libkey].trsv(queue, uplo_val, transpose_val, diag_val, A_handle, x, y, + alg, trsv_descr, dependencies); +} + } // namespace oneapi::mkl::sparse diff --git a/tests/unit_tests/CMakeLists.txt b/tests/unit_tests/CMakeLists.txt index e7fe8e110..5fc56d04a 100644 --- a/tests/unit_tests/CMakeLists.txt +++ b/tests/unit_tests/CMakeLists.txt @@ -178,6 +178,11 @@ foreach(domain ${TARGET_DOMAINS}) list(APPEND ONEMKL_LIBRARIES_${domain} onemkl_dft_portfft) endif() + if(domain STREQUAL "sparse_blas" AND ENABLE_CUSPARSE_BACKEND) + add_dependencies(test_main_${domain}_ct onemkl_${domain}_cusparse) + list(APPEND ONEMKL_LIBRARIES_${domain} onemkl_${domain}_cusparse) + endif() + target_link_libraries(test_main_${domain}_ct PUBLIC gtest gtest_main diff --git a/tests/unit_tests/include/test_helper.hpp b/tests/unit_tests/include/test_helper.hpp index 7e0024195..cdeceec1e 100644 --- a/tests/unit_tests/include/test_helper.hpp +++ b/tests/unit_tests/include/test_helper.hpp @@ -176,6 +176,13 @@ #define TEST_RUN_PORTFFT_SELECT(q, func, ...) #endif +#ifdef ENABLE_CUSPARSE_BACKEND +#define TEST_RUN_NVIDIAGPU_CUSPARSE_SELECT(q, func, ...) \ + func(oneapi::mkl::backend_selector{ q }, __VA_ARGS__) +#else +#define TEST_RUN_NVIDIAGPU_CUSPARSE_SELECT(q, func, ...) +#endif + #ifndef __HIPSYCL__ #define CHECK_HOST_OR_CPU(q) q.get_device().is_cpu() #else @@ -216,6 +223,7 @@ TEST_RUN_NVIDIAGPU_CUBLAS_SELECT(q, func, __VA_ARGS__); \ TEST_RUN_NVIDIAGPU_CUSOLVER_SELECT(q, func, __VA_ARGS__); \ TEST_RUN_NVIDIAGPU_CURAND_SELECT(q, func, __VA_ARGS__); \ + TEST_RUN_NVIDIAGPU_CUSPARSE_SELECT(q, func, __VA_ARGS__); \ } \ else if (vendor_id == AMD_ID) { \ TEST_RUN_AMDGPU_ROCBLAS_SELECT(q, func, __VA_ARGS__); \ diff --git a/tests/unit_tests/main_test.cpp b/tests/unit_tests/main_test.cpp index bac3f8c83..fc208da09 100644 --- a/tests/unit_tests/main_test.cpp +++ b/tests/unit_tests/main_test.cpp @@ -122,7 +122,8 @@ int main(int argc, char** argv) { #endif #if !defined(ENABLE_CUBLAS_BACKEND) && !defined(ENABLE_CURAND_BACKEND) && \ !defined(ENABLE_CUSOLVER_BACKEND) && !defined(ENABLE_PORTBLAS_BACKEND_NVIDIA_GPU) && \ - !defined(ENABLE_CUFFT_BACKEND) && !defined(ENABLE_PORTFFT_BACKEND) + !defined(ENABLE_CUFFT_BACKEND) && !defined(ENABLE_PORTFFT_BACKEND) && \ + !defined(ENABLE_CUSPARSE_BACKEND) if (dev.is_gpu() && vendor_id == NVIDIA_ID) continue; #endif diff --git a/tests/unit_tests/sparse_blas/include/test_common.hpp b/tests/unit_tests/sparse_blas/include/test_common.hpp index fd1e91a47..1d0f9172d 100644 --- a/tests/unit_tests/sparse_blas/include/test_common.hpp +++ b/tests/unit_tests/sparse_blas/include/test_common.hpp @@ -209,13 +209,6 @@ void shuffle_data(const intType *ia, intType *ja, fpType *a, const std::size_t n } } -inline void wait_and_free(sycl::queue &main_queue, oneapi::mkl::sparse::matrix_handle_t *p_handle) { - main_queue.wait(); - sycl::event ev_release; - CALL_RT_OR_CT(ev_release = oneapi::mkl::sparse::release_matrix_handle, main_queue, p_handle); - ev_release.wait(); -} - template bool check_equal(fpType x, fpType x_ref, double abs_error_margin, double rel_error_margin, std::ostream &out) { diff --git a/tests/unit_tests/sparse_blas/source/CMakeLists.txt b/tests/unit_tests/sparse_blas/source/CMakeLists.txt index 3a1fcb288..542b1c063 100644 --- a/tests/unit_tests/sparse_blas/source/CMakeLists.txt +++ b/tests/unit_tests/sparse_blas/source/CMakeLists.txt @@ -18,12 +18,13 @@ #=============================================================================== set(SPBLAS_SOURCES - "sparse_gemm_buffer.cpp" - "sparse_gemm_usm.cpp" - "sparse_gemv_buffer.cpp" - "sparse_gemv_usm.cpp" + #TODO(Romain): Enable tests later + #"sparse_gemm_buffer.cpp" + #"sparse_gemm_usm.cpp" + #"sparse_gemv_buffer.cpp" + #"sparse_gemv_usm.cpp" "sparse_trsv_buffer.cpp" - "sparse_trsv_usm.cpp" + #"sparse_trsv_usm.cpp" ) include(WarningsUtils) diff --git a/tests/unit_tests/sparse_blas/source/sparse_trsv_buffer.cpp b/tests/unit_tests/sparse_blas/source/sparse_trsv_buffer.cpp index 00ec6e5ed..02f579cce 100644 --- a/tests/unit_tests/sparse_blas/source/sparse_trsv_buffer.cpp +++ b/tests/unit_tests/sparse_blas/source/sparse_trsv_buffer.cpp @@ -42,9 +42,10 @@ namespace { template int test(sycl::device *dev, intType m, double density_A_matrix, oneapi::mkl::index_base index, oneapi::mkl::uplo uplo_val, oneapi::mkl::transpose transpose_val, - oneapi::mkl::diag diag_val, bool use_optimize) { + oneapi::mkl::diag diag_val) { sycl::queue main_queue(*dev, exception_handler_t()); + oneapi::mkl::sparse::trsv_alg alg = oneapi::mkl::sparse::trsv_alg::default_alg; intType int_index = (index == oneapi::mkl::index_base::zero) ? 0 : 1; const std::size_t mu = static_cast(m); @@ -67,12 +68,8 @@ int test(sycl::device *dev, intType m, double density_A_matrix, oneapi::mkl::ind std::vector y_host(mu, -2.0f); std::vector y_ref_host(y_host); - // Intel oneMKL does not support unsorted data if - // `sparse::optimize_trsv()` is not called first. - if (use_optimize) { - // Shuffle ordering of column indices/values to test sortedness - shuffle_data(ia_host.data(), ja_host.data(), a_host.data(), mu); - } + // Shuffle ordering of column indices/values to test sortedness + shuffle_data(ia_host.data(), ja_host.data(), a_host.data(), mu); auto ia_buf = make_buffer(ia_host); auto ja_buf = make_buffer(ja_host); @@ -80,23 +77,40 @@ int test(sycl::device *dev, intType m, double density_A_matrix, oneapi::mkl::ind auto x_buf = make_buffer(x_host); auto y_buf = make_buffer(y_host); - sycl::event ev_release; - oneapi::mkl::sparse::matrix_handle_t handle = nullptr; + sycl::event ev_release, destroy_x, destroy_y, destroy_a; + oneapi::mkl::sparse::dense_vector_handle_t x_handle = nullptr, y_handle = nullptr; + oneapi::mkl::sparse::matrix_handle_t a_handle = nullptr; + oneapi::mkl::sparse::trsv_descr_t trsv_descr = nullptr; try { - CALL_RT_OR_CT(oneapi::mkl::sparse::init_matrix_handle, main_queue, &handle); - - CALL_RT_OR_CT(oneapi::mkl::sparse::set_csr_data, main_queue, handle, m, m, nnz, index, - ia_buf, ja_buf, a_buf); - - if (use_optimize) { - CALL_RT_OR_CT(oneapi::mkl::sparse::optimize_trsv, main_queue, uplo_val, transpose_val, - diag_val, handle); - } - - CALL_RT_OR_CT(oneapi::mkl::sparse::trsv, main_queue, uplo_val, transpose_val, diag_val, - handle, x_buf, y_buf); - - CALL_RT_OR_CT(ev_release = oneapi::mkl::sparse::release_matrix_handle, main_queue, &handle); + CALL_RT_OR_CT(oneapi::mkl::sparse::create_dense_vector, main_queue, &x_handle, m, x_buf); + CALL_RT_OR_CT(oneapi::mkl::sparse::create_dense_vector, main_queue, &y_handle, m, y_buf); + + CALL_RT_OR_CT(oneapi::mkl::sparse::create_csr_matrix, main_queue, &a_handle, m, m, nnz, + index, ia_buf, ja_buf, a_buf); + + CALL_RT_OR_CT(oneapi::mkl::sparse::init_trsv_descr, main_queue, &trsv_descr); + + sycl::event buffer_size_ev; + std::int64_t temp_buffer_size = 0; + CALL_RT_OR_CT(buffer_size_ev = oneapi::mkl::sparse::trsv_buffer_size, main_queue, uplo_val, + transpose_val, diag_val, a_handle, x_handle, y_handle, alg, trsv_descr, + temp_buffer_size); + buffer_size_ev.wait_and_throw(); + + sycl::buffer trsv_temp_buffer( + (sycl::range<1>(static_cast(temp_buffer_size)))); + CALL_RT_OR_CT(oneapi::mkl::sparse::optimize_trsv, main_queue, uplo_val, transpose_val, + diag_val, a_handle, alg, trsv_descr, temp_buffer_size, trsv_temp_buffer); + + sycl::event trsv_ev; + CALL_RT_OR_CT(trsv_ev = oneapi::mkl::sparse::trsv, main_queue, uplo_val, transpose_val, + diag_val, a_handle, x_handle, y_handle, alg, trsv_descr); + trsv_ev.wait_and_throw(); + + CALL_RT_OR_CT(ev_release = oneapi::mkl::sparse::release_trsv_descr, main_queue, trsv_descr); + CALL_RT_OR_CT(destroy_x = oneapi::mkl::sparse::destroy_dense_vector, main_queue, x_handle); + CALL_RT_OR_CT(destroy_y = oneapi::mkl::sparse::destroy_dense_vector, main_queue, y_handle); + CALL_RT_OR_CT(destroy_a = oneapi::mkl::sparse::destroy_csr_matrix, main_queue, a_handle); } catch (const sycl::exception &e) { std::cout << "Caught synchronous SYCL exception during sparse TRSV:\n" @@ -105,7 +119,15 @@ int test(sycl::device *dev, intType m, double density_A_matrix, oneapi::mkl::ind return 0; } catch (const oneapi::mkl::unimplemented &e) { - wait_and_free(main_queue, &handle); + main_queue.wait(); + CALL_RT_OR_CT(ev_release = oneapi::mkl::sparse::release_trsv_descr, main_queue, trsv_descr); + CALL_RT_OR_CT(destroy_x = oneapi::mkl::sparse::destroy_dense_vector, main_queue, x_handle); + CALL_RT_OR_CT(destroy_y = oneapi::mkl::sparse::destroy_dense_vector, main_queue, y_handle); + CALL_RT_OR_CT(destroy_a = oneapi::mkl::sparse::destroy_csr_matrix, main_queue, a_handle); + ev_release.wait(); + destroy_x.wait(); + destroy_y.wait(); + destroy_a.wait(); return test_skipped; } catch (const std::runtime_error &error) { @@ -123,6 +145,9 @@ int test(sycl::device *dev, intType m, double density_A_matrix, oneapi::mkl::ind bool valid = check_equal_vector(y_acc, y_ref_host); ev_release.wait_and_throw(); + destroy_x.wait_and_throw(); + destroy_y.wait_and_throw(); + destroy_a.wait_and_throw(); return static_cast(valid); } @@ -145,41 +170,30 @@ auto test_helper(sycl::device *dev, oneapi::mkl::transpose transpose_val, int &n oneapi::mkl::uplo lower = oneapi::mkl::uplo::lower; oneapi::mkl::diag nonunit = oneapi::mkl::diag::nonunit; int m = 277; - bool use_optimize = true; // Basic test - EXPECT_TRUE_OR_FUTURE_SKIP(test(dev, m, density_A_matrix, index_zero, lower, - transpose_val, nonunit, use_optimize), - num_passed, num_skipped); + EXPECT_TRUE_OR_FUTURE_SKIP( + test(dev, m, density_A_matrix, index_zero, lower, transpose_val, nonunit), + num_passed, num_skipped); // Test index_base 1 EXPECT_TRUE_OR_FUTURE_SKIP(test(dev, m, density_A_matrix, oneapi::mkl::index_base::one, - lower, transpose_val, nonunit, use_optimize), + lower, transpose_val, nonunit), num_passed, num_skipped); // Test upper triangular matrix - EXPECT_TRUE_OR_FUTURE_SKIP( - test(dev, m, density_A_matrix, index_zero, oneapi::mkl::uplo::upper, transpose_val, - nonunit, use_optimize), - num_passed, num_skipped); + EXPECT_TRUE_OR_FUTURE_SKIP(test(dev, m, density_A_matrix, index_zero, + oneapi::mkl::uplo::upper, transpose_val, nonunit), + num_passed, num_skipped); // Test unit diagonal matrix EXPECT_TRUE_OR_FUTURE_SKIP(test(dev, m, density_A_matrix, index_zero, lower, - transpose_val, oneapi::mkl::diag::unit, use_optimize), + transpose_val, oneapi::mkl::diag::unit), num_passed, num_skipped); // Temporarily disable trsv using long indices on GPU if (!dev->is_gpu()) { // Test int64 indices - EXPECT_TRUE_OR_FUTURE_SKIP(test(dev, 15L, density_A_matrix, index_zero, lower, - transpose_val, nonunit, use_optimize), - num_passed, num_skipped); + EXPECT_TRUE_OR_FUTURE_SKIP( + test(dev, 15L, density_A_matrix, index_zero, lower, transpose_val, nonunit), + num_passed, num_skipped); } - // Test lower without optimize_trsv - EXPECT_TRUE_OR_FUTURE_SKIP( - test(dev, m, density_A_matrix, index_zero, lower, transpose_val, nonunit, false), - num_passed, num_skipped); - // Test upper without optimize_trsv - EXPECT_TRUE_OR_FUTURE_SKIP( - test(dev, m, density_A_matrix, index_zero, oneapi::mkl::uplo::upper, transpose_val, - nonunit, false), - num_passed, num_skipped); } TEST_P(SparseTrsvBufferTests, RealSinglePrecision) {