Skip to content

Commit

Permalink
Update gemv tuning target parameters (#449)
Browse files Browse the repository at this point in the history
This patch adjusts the gemv tuning target parameters and removes unnecessary headers for the generation of matrix-vector multiplication routines.
  • Loading branch information
pgorlani authored Jul 20, 2023
1 parent 0355a58 commit 2c6b203
Show file tree
Hide file tree
Showing 6 changed files with 30 additions and 68 deletions.
8 changes: 4 additions & 4 deletions src/interface/blas2/backend/amd_gpu.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,15 +37,15 @@ typename SB_Handle::event_t _gemv(SB_Handle& sb_handle, index_t _M, index_t _N,
index_t _lda, container_t1 _vx,
increment_t _incx, element_t _beta,
container_t2 _vy, increment_t _incy) {
static constexpr uint32_t cache_line_size = 256;
static constexpr uint32_t cache_line_size = 128;
if (trn == transpose_type::Normal) {
return blas::internal::_gemv_impl<256, cache_line_size,
gemv_memory_t::local, trn>(
sb_handle, _M, _N, _alpha, _mA, _lda, _vx, _incx, _beta, _vy, _incy);
} else {
return blas::internal::_gemv_impl<64, cache_line_size, gemv_memory_t::local,
trn>(sb_handle, _M, _N, _alpha, _mA, _lda,
_vx, _incx, _beta, _vy, _incy);
return blas::internal::_gemv_impl<128, cache_line_size,
gemv_memory_t::local, trn>(
sb_handle, _M, _N, _alpha, _mA, _lda, _vx, _incx, _beta, _vy, _incy);
}
}
} // namespace backend
Expand Down
14 changes: 11 additions & 3 deletions src/interface/blas2/backend/intel_gpu.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,18 @@ typename SB_Handle::event_t _gemv(SB_Handle& sb_handle, index_t _M, index_t _N,
increment_t _incx, element_t _beta,
container_t2 _vy, increment_t _incy) {
if (trn == transpose_type::Normal) {
return blas::internal::_gemv_impl<256, 32, gemv_memory_t::local, trn>(
sb_handle, _M, _N, _alpha, _mA, _lda, _vx, _incx, _beta, _vy, _incy);
if (_N < 8192) {
return blas::internal::_gemv_impl<128, 64, gemv_memory_t::local, trn>(
sb_handle, _M, _N, _alpha, _mA, _lda, _vx, _incx, _beta, _vy, _incy);
} else if (_N < 16384) {
return blas::internal::_gemv_impl<256, 64, gemv_memory_t::local, trn>(
sb_handle, _M, _N, _alpha, _mA, _lda, _vx, _incx, _beta, _vy, _incy);
} else {
return blas::internal::_gemv_impl<512, 64, gemv_memory_t::local, trn>(
sb_handle, _M, _N, _alpha, _mA, _lda, _vx, _incx, _beta, _vy, _incy);
}
} else {
return blas::internal::_gemv_impl<128, 32, gemv_memory_t::local, trn>(
return blas::internal::_gemv_impl<128, 64, gemv_memory_t::local, trn>(
sb_handle, _M, _N, _alpha, _mA, _lda, _vx, _incx, _beta, _vy, _incy);
}
}
Expand Down
4 changes: 2 additions & 2 deletions src/interface/blas2/backend/nvidia_gpu.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,10 @@ typename SB_Handle::event_t _gemv(SB_Handle& sb_handle, index_t _M, index_t _N,
increment_t _incx, element_t _beta,
container_t2 _vy, increment_t _incy) {
if (trn == transpose_type::Normal) {
return blas::internal::_gemv_impl<256, 32, gemv_memory_t::local, trn>(
return blas::internal::_gemv_impl<256, 128, gemv_memory_t::local, trn>(
sb_handle, _M, _N, _alpha, _mA, _lda, _vx, _incx, _beta, _vy, _incy);
} else {
return blas::internal::_gemv_impl<128, 32, gemv_memory_t::local, trn>(
return blas::internal::_gemv_impl<128, 128, gemv_memory_t::local, trn>(
sb_handle, _M, _N, _alpha, _mA, _lda, _vx, _incx, _beta, _vy, _incy);
}
}
Expand Down
40 changes: 2 additions & 38 deletions src/interface/blas2/gemv.cpp.in
Original file line number Diff line number Diff line change
Expand Up @@ -22,49 +22,13 @@
* @filename gemv.cpp.in
*
**************************************************************************/
#include "container/sycl_iterator.hpp"
#include "sb_handle/sycl_blas_handle.hpp"
#include "sb_handle/kernel_constructor.hpp"
#include "interface/blas2_interface.hpp"
#include "operations/blas1_trees.hpp"
#include "operations/blas2_trees.hpp"
#include "operations/blas_constants.hpp"
#include "views/view_sycl.hpp"
#include "sb_handle/kernel_constructor.hpp"
#include "sb_handle/sycl_blas_handle.hpp"

namespace blas {
namespace internal {

/*!
@brief Generalised matrix vector product with rectangular non-symmetric
matrices.

Generalised matrix vector product with rectangular non-symmetric matrices, i.e.
computing the mathematical operation:

y = alpha*A*x + beta*y

See the netlib blas interface documentation for more details of the high level
interface: http://www.netlib.org/lapack/explore-html/db/d58/sgemv_8f.html
SB_Handle& sb_handle, // SB_Handle (sycl, parallel, serial, etc)
char _trans, // The transposition of the matrix ('n', 't', 'c')
index_t _M, // The size of dimension M of the matrix (rows)
index_t _N, // The size of dimension N of the matrix (columns)
element_t _alpha, // Scalar parameter Alpha
container_t0 _mA, // An array (LDA,N), with the first m*n elements
index_t _lda, // Specifies the first dimension of a, max(1, m)
container_t1 _vx, // An array of dimension at least:
(1+(n-1)*abs(incx))
// when trans = 'n' and (1+(m-1)*abs(incx) otherwise,
// containing the vector "x"
increment_t _incx, // The increment for elements in x (nonzero).
element_t _beta, // Scalar parameter Beta
container_t2 _vy, // An array of dimension at least:
(1+(m-1)*abs(incy))
// when trans = "n" and (1+(n-1)*abs(incy) otherwise,
// containing the vector "y" (if beta is nonzero). When
// finished, y is overwritten with the updated vector.
increment_t _incy // The increment for elements in y (nonzero).
*/
template typename SB_Handle::event_t _gemv(
SB_Handle& sb_handle, char _trans, ${INDEX_TYPE} _M, ${INDEX_TYPE} _N,
${DATA_TYPE} _alpha, ${container_t0} _mA, ${INDEX_TYPE} _lda,
Expand Down
17 changes: 6 additions & 11 deletions src/interface/blas2/symv.cpp.in
Original file line number Diff line number Diff line change
Expand Up @@ -22,23 +22,18 @@
* @filename symv.cpp.in
*
**************************************************************************/
#include "container/sycl_iterator.hpp"
#include "sb_handle/sycl_blas_handle.hpp"
#include "sb_handle/kernel_constructor.hpp"
#include "interface/blas2_interface.hpp"
#include "operations/blas1_trees.hpp"
#include "operations/blas2_trees.hpp"
#include "operations/blas_constants.hpp"
#include "views/view_sycl.hpp"
#include "sb_handle/kernel_constructor.hpp"
#include "sb_handle/sycl_blas_handle.hpp"

namespace blas {
namespace internal {

template typename SB_Handle::event_t _symv(
SB_Handle& sb_handle, char _Uplo, ${INDEX_TYPE} _N,
${DATA_TYPE} _alpha, ${container_t0} _mA, ${INDEX_TYPE} _lda,
${container_t1} _vx, ${INCREMENT_TYPE} _incx, ${DATA_TYPE} _beta,
${container_t2} _vy, ${INCREMENT_TYPE} _incy);
SB_Handle& sb_handle, char _Uplo, ${INDEX_TYPE} _N, ${DATA_TYPE} _alpha,
${container_t0} _mA, ${INDEX_TYPE} _lda, ${container_t1} _vx,
${INCREMENT_TYPE} _incx, ${DATA_TYPE} _beta, ${container_t2} _vy,
${INCREMENT_TYPE} _incy);

} // namespace internal
} // namespace blas
15 changes: 5 additions & 10 deletions src/interface/blas2/trmv.cpp.in
Original file line number Diff line number Diff line change
Expand Up @@ -22,21 +22,16 @@
* @filename trmv.cpp.in
*
**************************************************************************/
#include "container/sycl_iterator.hpp"
#include "sb_handle/sycl_blas_handle.hpp"
#include "sb_handle/kernel_constructor.hpp"
#include "interface/blas2_interface.hpp"
#include "operations/blas1_trees.hpp"
#include "operations/blas2_trees.hpp"
#include "operations/blas_constants.hpp"
#include "views/view_sycl.hpp"
#include "sb_handle/kernel_constructor.hpp"
#include "sb_handle/sycl_blas_handle.hpp"

namespace blas {
namespace internal {

template typename SB_Handle::event_t _trmv(
SB_Handle& sb_handle, char _Uplo, char _trans, char _Diag,
${INDEX_TYPE} _N, ${container_t0} _mA, ${INDEX_TYPE} _lda,
${container_t1} _vx, ${INCREMENT_TYPE} _incx);
SB_Handle& sb_handle, char _Uplo, char _trans, char _Diag, ${INDEX_TYPE} _N,
${container_t0} _mA, ${INDEX_TYPE} _lda, ${container_t1} _vx,
${INCREMENT_TYPE} _incx);
} // namespace internal
} // end namespace blas

0 comments on commit 2c6b203

Please sign in to comment.