From f3818c4b31cddee9687efb891578f93fc1e51993 Mon Sep 17 00:00:00 2001 From: Ouadie EL FAROUKI Date: Tue, 15 Aug 2023 11:25:29 +0100 Subject: [PATCH] Fixed minor issue in make_gemm & updated Gemm doc --- doc/Gemm.md | 60 ++++++++++++++++++++------------ include/operations/blas3_trees.h | 4 +-- src/interface/gemm_launcher.hpp | 3 +- 3 files changed, 41 insertions(+), 26 deletions(-) diff --git a/doc/Gemm.md b/doc/Gemm.md index 4424cd2cc..73c52c9b5 100644 --- a/doc/Gemm.md +++ b/doc/Gemm.md @@ -191,34 +191,41 @@ namespace blas { /*! * @brief Wrapper around Gemm. Creates the views, then makes and launches Gemm */ -template +template -template -typename SB_Handle::event_t Gemm_Launcher< - WgSize, DoubleBuffer, ConflictA, ConflictB, ClSize, TileT, TransA, TransB, - GemmMemoryType, GemmAlgorithm, GemmVectorization, is_beta_zero, VectorSize, - BatchType>::_select_gemm(SB_Handle& sb_handle, index_t _M, index_t _N, index_t _K, - element_t _alpha, container_t0 a_, index_t _lda, - container_t1 b_, index_t _ldb, element_t _beta, - container_t2 _C, index_t _ldc, - index_t batch_size) { +typename sb_handle_t::event_t +Gemm_Launcher::_select_gemm(sb_handle_t& sb_handle, index_t _M, + index_t _N, index_t _K, + element_t _alpha, container_t0 a_, + index_t _lda, index_t _stridea, + container_t1 b_, index_t _ldb, + index_t _strideb, element_t _beta, + container_t2 _C, index_t _ldc, + index_t _stridec, + index_t batch_size) { //Helper functions used to make matrix views - auto buffer_a = make_matrix_view(a_, _M, _K, _lda); - auto buffer_b = make_matrix_view(b_, _K, _N, _ldb); + auto buffer_a = make_matrix_view(a_, _M, _K, _lda); + auto buffer_b = make_matrix_view(b_, _K, _N, _ldb); auto buffer_c = make_matrix_view(_C, _M, _N, _ldc); //Helper function to construct the Gemm object - auto gemm = make_gemm( + auto gemm = make_gemm( buffer_a, buffer_b, buffer_c, element_t(_alpha), element_t(_beta), - batch_size); + batch_size, index_t(_stridea), index_t(_strideb), index_t(_stridec)); //Execute the gemm and return the associated event return sb_handle.execute(gemm); @@ -259,6 +266,14 @@ template typename SB_Handle::event_t _gemm_batched( ${INDEX_TYPE} _lda, ${INDEX_TYPE} _stridea, ${container_t1} b_, ${INDEX_TYPE} _ldb, ${INDEX_TYPE} _strideb, ${DATA_TYPE} _beta, ${container_t2} _C, ${INDEX_TYPE} _ldc, ${INDEX_TYPE} _stridec, ${INDEX_TYPE} batch_size, gemm_batch_type_t batch_type); +// strided batched gemm +template typename SB_Handle::event_t _gemm_strided_batched( + SB_Handle& sb_handle, char _TransA, char _TransB, ${INDEX_TYPE} _M, + ${INDEX_TYPE} _N, ${INDEX_TYPE} _K, ${DATA_TYPE} _alpha, ${container_t0} a_, + ${INDEX_TYPE} _lda, ${INDEX_TYPE} _stridea, ${container_t1} b_, + ${INDEX_TYPE} _ldb, ${INDEX_TYPE} _strideb, ${DATA_TYPE} _beta, + ${container_t2} _C, ${INDEX_TYPE} _ldc, ${INDEX_TYPE} _stridec, + ${INDEX_TYPE} batch_size); } // namespace internal } // namespace blas ``` @@ -300,9 +315,10 @@ template typename sb_handle_t::event_t _gemm( - sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K, element_t _alpha, - container_0_t _a, index_t _lda, container_1_t _b, index_t _ldb, - element_t _beta, container_2_t _c, index_t _ldc, index_t batch_size, + sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K, + element_t _alpha, container_0_t _a, index_t _lda, index_t _stridea, + container_1_t _b, index_t _ldb, index_t _strideb, element_t _beta, + container_2_t _c, index_t _ldc, index_t _stridec, index_t batch_size, gemm_batch_type_t batch_type) { if (batch_type == gemm_batch_type_t::interleaved) { return blas::Gemm_Launcher< diff --git a/include/operations/blas3_trees.h b/include/operations/blas3_trees.h index 2d6113ffc..3f151f24e 100644 --- a/include/operations/blas3_trees.h +++ b/include/operations/blas3_trees.h @@ -267,8 +267,8 @@ inline Gemm make_gemm(input_t buffer_a, input_t buffer_b, output_t buffer_c, - element_t alpha, element_t beta, index_t batch_size, - element_t _stridea, element_t _strideb, element_t _stridec) { + element_t alpha, element_t beta, index_t batch_size, index_t _stridea, + index_t _strideb, index_t _stridec) { return Gemm( buffer_a, buffer_b, buffer_c, element_t(_alpha), element_t(_beta), - batch_size, element_t(_stridea), element_t(_strideb), - element_t(_stridec)); + batch_size, index_t(_stridea), index_t(_strideb), index_t(_stridec)); return sb_handle.execute(gemm); }