diff --git a/cmake/CmakeFunctionHelper.cmake b/cmake/CmakeFunctionHelper.cmake index 8fd935691..ddd1e61f3 100644 --- a/cmake/CmakeFunctionHelper.cmake +++ b/cmake/CmakeFunctionHelper.cmake @@ -469,18 +469,28 @@ elseif(${TUNING_TARGET} STREQUAL "NVIDIA_GPU") add_gemm_configuration( "float" 256 "false" "true" "true" 128 8 8 16 16 16 2 1 1 1 1 16 16 16 cl::sycl::half float "local" "standard" "none" 1 "strided" "true") - endif() - # Non-Joint Matrix specific GEMM Configurations - add_gemm_configuration( - "${data}" 128 "false" "false" "true" - 128 2 2 8 8 1 1 1 1 1 1 1 1 1 float float "local" "standard" "full" 1 "strided" "false") - add_gemm_configuration( + add_gemm_configuration( "${data}" 64 "false" "false" "true" 64 8 8 8 8 1 1 2 2 1 1 1 1 1 float float "local" "standard" "full" 1 "strided" "false") + endif() + + # Non-Joint Matrix specific GEMM Configurations add_gemm_configuration( "${data}" 64 "false" "false" "false" 64 2 2 4 4 1 1 1 1 4 4 1 1 1 float float "no_local" "standard" "full" 4 "interleaved" "false") - endforeach() + add_gemm_configuration( + "${data}" 128 "false" "true" "true" + 128 2 2 16 8 1 1 1 1 1 1 1 1 1 float float "local" "standard" "full" 1 "strided" "false") + add_gemm_configuration( + "${data}" 128 "false" "true" "true" + 128 4 4 16 8 1 1 1 1 1 1 1 1 1 float float "local" "standard" "full" 1 "strided" "false") + add_gemm_configuration( + "${data}" 128 "false" "true" "true" + 128 8 8 16 8 1 1 1 1 1 1 1 1 1 float float "local" "standard" "full" 1 "strided" "false") + add_gemm_configuration( + "${data}" 256 "false" "true" "true" + 128 8 8 16 16 1 1 1 1 1 1 1 1 1 float float "local" "standard" "full" 1 "strided" "false") + endforeach() else() # default cpu backend set(supported_types "float" diff --git a/src/interface/blas3/backend/nvidia_gpu.hpp b/src/interface/blas3/backend/nvidia_gpu.hpp index d0d6df51a..15e502dcc 100644 --- a/src/interface/blas3/backend/nvidia_gpu.hpp +++ b/src/interface/blas3/backend/nvidia_gpu.hpp @@ -100,11 +100,14 @@ typename sb_handle_t::event_t _gemm( _ldc, _stridec, batch_size, _dependencies); } - } else { + } +#endif // SB_ENABLE_JOINT_MATRIX + + if (batch_size > 1) { return blas::Gemm_Launcher< - container_0_t, container_1_t, container_2_t, 64, false, false, true, 64, - Tile<8, 8, 8, 8, 1, 1, 2, 2, 1, 1, 1, 1, 1, float, float>, _t_a, _t_b, - s_a, s_b, static_cast(gemm_memory_t::local), + container_0_t, container_1_t, container_2_t, 256, false, true, true, + 128, Tile<8, 8, 16, 16, 1, 1, 1, 1, 1, 1, 1, 1, 1, float, float>, _t_a, + _t_b, s_a, s_b, static_cast(gemm_memory_t::local), static_cast(gemm_algorithm_t::standard), static_cast(gemm_vectorization_t::full), is_beta_zero, 1, static_cast(gemm_batch_type_t::strided), @@ -112,14 +115,35 @@ typename sb_handle_t::event_t _gemm( _stridea, _b, _ldb, _strideb, _beta, _c, _ldc, _stridec, batch_size, _dependencies); - } - -#else // SB_ENABLE_JOINT_MATRIX - else { + } else if (_M <= 256 && _N <= 256) { + return blas::Gemm_Launcher< + container_0_t, container_1_t, container_2_t, 128, false, true, true, + 128, Tile<2, 2, 16, 8, 1, 1, 1, 1, 1, 1, 1, 1, 1, float, float>, _t_a, + _t_b, s_a, s_b, static_cast(gemm_memory_t::local), + static_cast(gemm_algorithm_t::standard), + static_cast(gemm_vectorization_t::full), is_beta_zero, 1, + static_cast(gemm_batch_type_t::strided), + false>::template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda, + _stridea, _b, _ldb, _strideb, _beta, _c, + _ldc, _stridec, batch_size, + _dependencies); + } else if (_M <= 1024 && _N <= 1024) { return blas::Gemm_Launcher< - container_0_t, container_1_t, container_2_t, 64, false, false, true, 64, - Tile<8, 8, 8, 8, 1, 1, 2, 2, 1, 1, 1, 1, 1, float, float>, _t_a, _t_b, - s_a, s_b, static_cast(gemm_memory_t::local), + container_0_t, container_1_t, container_2_t, 128, false, true, true, + 128, Tile<4, 4, 16, 8, 1, 1, 1, 1, 1, 1, 1, 1, 1, float, float>, _t_a, + _t_b, s_a, s_b, static_cast(gemm_memory_t::local), + static_cast(gemm_algorithm_t::standard), + static_cast(gemm_vectorization_t::full), is_beta_zero, 1, + static_cast(gemm_batch_type_t::strided), + false>::template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda, + _stridea, _b, _ldb, _strideb, _beta, _c, + _ldc, _stridec, batch_size, + _dependencies); + } else if (_M <= 2048 && _N <= 2048) { + return blas::Gemm_Launcher< + container_0_t, container_1_t, container_2_t, 128, false, true, true, + 128, Tile<8, 8, 16, 8, 1, 1, 1, 1, 1, 1, 1, 1, 1, float, float>, _t_a, + _t_b, s_a, s_b, static_cast(gemm_memory_t::local), static_cast(gemm_algorithm_t::standard), static_cast(gemm_vectorization_t::full), is_beta_zero, 1, static_cast(gemm_batch_type_t::strided), @@ -128,7 +152,17 @@ typename sb_handle_t::event_t _gemm( _ldc, _stridec, batch_size, _dependencies); } -#endif + + return blas::Gemm_Launcher< + container_0_t, container_1_t, container_2_t, 256, false, true, true, 128, + Tile<8, 8, 16, 16, 1, 1, 1, 1, 1, 1, 1, 1, 1, float, float>, _t_a, _t_b, + s_a, s_b, static_cast(gemm_memory_t::local), + static_cast(gemm_algorithm_t::standard), + static_cast(gemm_vectorization_t::full), is_beta_zero, 1, + static_cast(gemm_batch_type_t::strided), + false>::template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda, + _stridea, _b, _ldb, _strideb, _beta, _c, + _ldc, _stridec, batch_size, _dependencies); } } // namespace backend } // namespace gemm diff --git a/src/operations/blas3/gemm_local.hpp b/src/operations/blas3/gemm_local.hpp index 8ef8a4452..9b1c1c98b 100644 --- a/src/operations/blas3/gemm_local.hpp +++ b/src/operations/blas3/gemm_local.hpp @@ -136,10 +136,6 @@ class Gemm 1) + ? ((4 * compute_units - 1) / get_workgroup_cluster() + 1) + : 1; } /*! @@ -289,7 +287,7 @@ class Gemm= m || wg_col >= n); const bool internal = m - wg_row >= block_rows && n - wg_col >= block_cols; const index_t vector_offset = internal ? packetize_t::packet_size : 1; @@ -373,15 +371,12 @@ class Gemm(j * wg_rows * offset < mc) && do_check(i < nc); if (in_range) { -#pragma unroll for (index_t l = 0; l < offset; ++l) { reg_res[i * item_rows + j * offset + l] = beta_ * *(C + j * (wg_rows * offset) + l); @@ -397,7 +392,6 @@ class Gemm::type scaling_c( element_t *reg_res, InputPointerType, const index_t &, const index_t &, const index_t &, const bool) { -#pragma unroll for (index_t i = 0; i < item_cols * item_rows; ++i) { reg_res[i] = 0; } @@ -560,9 +554,7 @@ class Gemm(j * wg_rows * offset < mc) && @@ -744,6 +736,9 @@ class Gemm