diff --git a/src/interface/blas2/backend/amd_gpu.hpp b/src/interface/blas2/backend/amd_gpu.hpp index 6818afa2a..156d1a4eb 100644 --- a/src/interface/blas2/backend/amd_gpu.hpp +++ b/src/interface/blas2/backend/amd_gpu.hpp @@ -37,15 +37,15 @@ typename SB_Handle::event_t _gemv(SB_Handle& sb_handle, index_t _M, index_t _N, index_t _lda, container_t1 _vx, increment_t _incx, element_t _beta, container_t2 _vy, increment_t _incy) { - static constexpr uint32_t cache_line_size = 256; + static constexpr uint32_t cache_line_size = 128; if (trn == transpose_type::Normal) { return blas::internal::_gemv_impl<256, cache_line_size, gemv_memory_t::local, trn>( sb_handle, _M, _N, _alpha, _mA, _lda, _vx, _incx, _beta, _vy, _incy); } else { - return blas::internal::_gemv_impl<64, cache_line_size, gemv_memory_t::local, - trn>(sb_handle, _M, _N, _alpha, _mA, _lda, - _vx, _incx, _beta, _vy, _incy); + return blas::internal::_gemv_impl<128, cache_line_size, + gemv_memory_t::local, trn>( + sb_handle, _M, _N, _alpha, _mA, _lda, _vx, _incx, _beta, _vy, _incy); } } } // namespace backend diff --git a/src/interface/blas2/backend/intel_gpu.hpp b/src/interface/blas2/backend/intel_gpu.hpp index ea4585ba3..3571e37cb 100644 --- a/src/interface/blas2/backend/intel_gpu.hpp +++ b/src/interface/blas2/backend/intel_gpu.hpp @@ -38,10 +38,18 @@ typename SB_Handle::event_t _gemv(SB_Handle& sb_handle, index_t _M, index_t _N, increment_t _incx, element_t _beta, container_t2 _vy, increment_t _incy) { if (trn == transpose_type::Normal) { - return blas::internal::_gemv_impl<256, 32, gemv_memory_t::local, trn>( - sb_handle, _M, _N, _alpha, _mA, _lda, _vx, _incx, _beta, _vy, _incy); + if (_N < 8192) { + return blas::internal::_gemv_impl<128, 64, gemv_memory_t::local, trn>( + sb_handle, _M, _N, _alpha, _mA, _lda, _vx, _incx, _beta, _vy, _incy); + } else if (_N < 16384) { + return blas::internal::_gemv_impl<256, 64, gemv_memory_t::local, trn>( + sb_handle, _M, _N, _alpha, _mA, _lda, _vx, _incx, _beta, _vy, _incy); + } else { + return blas::internal::_gemv_impl<512, 64, gemv_memory_t::local, trn>( + sb_handle, _M, _N, _alpha, _mA, _lda, _vx, _incx, _beta, _vy, _incy); + } } else { - return blas::internal::_gemv_impl<128, 32, gemv_memory_t::local, trn>( + return blas::internal::_gemv_impl<128, 64, gemv_memory_t::local, trn>( sb_handle, _M, _N, _alpha, _mA, _lda, _vx, _incx, _beta, _vy, _incy); } } diff --git a/src/interface/blas2/backend/nvidia_gpu.hpp b/src/interface/blas2/backend/nvidia_gpu.hpp index 2b1c84790..df968fedc 100644 --- a/src/interface/blas2/backend/nvidia_gpu.hpp +++ b/src/interface/blas2/backend/nvidia_gpu.hpp @@ -38,10 +38,10 @@ typename SB_Handle::event_t _gemv(SB_Handle& sb_handle, index_t _M, index_t _N, increment_t _incx, element_t _beta, container_t2 _vy, increment_t _incy) { if (trn == transpose_type::Normal) { - return blas::internal::_gemv_impl<256, 32, gemv_memory_t::local, trn>( + return blas::internal::_gemv_impl<256, 128, gemv_memory_t::local, trn>( sb_handle, _M, _N, _alpha, _mA, _lda, _vx, _incx, _beta, _vy, _incy); } else { - return blas::internal::_gemv_impl<128, 32, gemv_memory_t::local, trn>( + return blas::internal::_gemv_impl<128, 128, gemv_memory_t::local, trn>( sb_handle, _M, _N, _alpha, _mA, _lda, _vx, _incx, _beta, _vy, _incy); } } diff --git a/src/interface/blas2/gemv.cpp.in b/src/interface/blas2/gemv.cpp.in index 3e81984b5..d19fe0a60 100644 --- a/src/interface/blas2/gemv.cpp.in +++ b/src/interface/blas2/gemv.cpp.in @@ -22,49 +22,13 @@ * @filename gemv.cpp.in * **************************************************************************/ -#include "container/sycl_iterator.hpp" -#include "sb_handle/sycl_blas_handle.hpp" -#include "sb_handle/kernel_constructor.hpp" #include "interface/blas2_interface.hpp" -#include "operations/blas1_trees.hpp" -#include "operations/blas2_trees.hpp" -#include "operations/blas_constants.hpp" -#include "views/view_sycl.hpp" +#include "sb_handle/kernel_constructor.hpp" +#include "sb_handle/sycl_blas_handle.hpp" namespace blas { namespace internal { -/*! - @brief Generalised matrix vector product with rectangular non-symmetric - matrices. - - Generalised matrix vector product with rectangular non-symmetric matrices, i.e. - computing the mathematical operation: - - y = alpha*A*x + beta*y - - See the netlib blas interface documentation for more details of the high level - interface: http://www.netlib.org/lapack/explore-html/db/d58/sgemv_8f.html - SB_Handle& sb_handle, // SB_Handle (sycl, parallel, serial, etc) - char _trans, // The transposition of the matrix ('n', 't', 'c') - index_t _M, // The size of dimension M of the matrix (rows) - index_t _N, // The size of dimension N of the matrix (columns) - element_t _alpha, // Scalar parameter Alpha - container_t0 _mA, // An array (LDA,N), with the first m*n elements - index_t _lda, // Specifies the first dimension of a, max(1, m) - container_t1 _vx, // An array of dimension at least: - (1+(n-1)*abs(incx)) - // when trans = 'n' and (1+(m-1)*abs(incx) otherwise, - // containing the vector "x" - increment_t _incx, // The increment for elements in x (nonzero). - element_t _beta, // Scalar parameter Beta - container_t2 _vy, // An array of dimension at least: - (1+(m-1)*abs(incy)) - // when trans = "n" and (1+(n-1)*abs(incy) otherwise, - // containing the vector "y" (if beta is nonzero). When - // finished, y is overwritten with the updated vector. - increment_t _incy // The increment for elements in y (nonzero). - */ template typename SB_Handle::event_t _gemv( SB_Handle& sb_handle, char _trans, ${INDEX_TYPE} _M, ${INDEX_TYPE} _N, ${DATA_TYPE} _alpha, ${container_t0} _mA, ${INDEX_TYPE} _lda, diff --git a/src/interface/blas2/symv.cpp.in b/src/interface/blas2/symv.cpp.in index a15ca00a2..963632ece 100644 --- a/src/interface/blas2/symv.cpp.in +++ b/src/interface/blas2/symv.cpp.in @@ -22,23 +22,18 @@ * @filename symv.cpp.in * **************************************************************************/ -#include "container/sycl_iterator.hpp" -#include "sb_handle/sycl_blas_handle.hpp" -#include "sb_handle/kernel_constructor.hpp" #include "interface/blas2_interface.hpp" -#include "operations/blas1_trees.hpp" -#include "operations/blas2_trees.hpp" -#include "operations/blas_constants.hpp" -#include "views/view_sycl.hpp" +#include "sb_handle/kernel_constructor.hpp" +#include "sb_handle/sycl_blas_handle.hpp" namespace blas { namespace internal { template typename SB_Handle::event_t _symv( - SB_Handle& sb_handle, char _Uplo, ${INDEX_TYPE} _N, - ${DATA_TYPE} _alpha, ${container_t0} _mA, ${INDEX_TYPE} _lda, - ${container_t1} _vx, ${INCREMENT_TYPE} _incx, ${DATA_TYPE} _beta, - ${container_t2} _vy, ${INCREMENT_TYPE} _incy); + SB_Handle& sb_handle, char _Uplo, ${INDEX_TYPE} _N, ${DATA_TYPE} _alpha, + ${container_t0} _mA, ${INDEX_TYPE} _lda, ${container_t1} _vx, + ${INCREMENT_TYPE} _incx, ${DATA_TYPE} _beta, ${container_t2} _vy, + ${INCREMENT_TYPE} _incy); } // namespace internal } // namespace blas diff --git a/src/interface/blas2/trmv.cpp.in b/src/interface/blas2/trmv.cpp.in index 5dcc34f66..62cda27fc 100644 --- a/src/interface/blas2/trmv.cpp.in +++ b/src/interface/blas2/trmv.cpp.in @@ -22,21 +22,16 @@ * @filename trmv.cpp.in * **************************************************************************/ -#include "container/sycl_iterator.hpp" -#include "sb_handle/sycl_blas_handle.hpp" -#include "sb_handle/kernel_constructor.hpp" #include "interface/blas2_interface.hpp" -#include "operations/blas1_trees.hpp" -#include "operations/blas2_trees.hpp" -#include "operations/blas_constants.hpp" -#include "views/view_sycl.hpp" +#include "sb_handle/kernel_constructor.hpp" +#include "sb_handle/sycl_blas_handle.hpp" namespace blas { namespace internal { template typename SB_Handle::event_t _trmv( - SB_Handle& sb_handle, char _Uplo, char _trans, char _Diag, - ${INDEX_TYPE} _N, ${container_t0} _mA, ${INDEX_TYPE} _lda, - ${container_t1} _vx, ${INCREMENT_TYPE} _incx); + SB_Handle& sb_handle, char _Uplo, char _trans, char _Diag, ${INDEX_TYPE} _N, + ${container_t0} _mA, ${INDEX_TYPE} _lda, ${container_t1} _vx, + ${INCREMENT_TYPE} _incx); } // namespace internal } // end namespace blas